Source code for xorbits._mars.tensor.base.partition

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np

from ... import opcodes as OperandDef
from ...core import ENTITY_TYPE, ExecutableTuple, recursive_tile
from ...serialization.serializables import (
    AnyField,
    BoolField,
    FieldTypes,
    Int32Field,
    KeyField,
    ListField,
    StringField,
)
from ...utils import flatten, has_unknown_shape, stack_back
from ..array_utils import as_same_device, device
from ..core import TENSOR_CHUNK_TYPE, TENSOR_TYPE, TensorOrder
from ..datasource import tensor as astensor
from ..operands import TensorOperand, TensorShuffleProxy
from ..utils import validate_axis, validate_order
from .psrs import TensorPSRSOperandMixin


class ParallelPartitionMixin(TensorPSRSOperandMixin):
    """Chunk-graph builders for stage 5 and stage 6 of a parallel partition."""

    @classmethod
    def calc_paritions_info(cls, op, kth, size, sort_info_chunks):
        """
        Stage 5: create the chunks that compute partition info from sort infos.

        Parameters
        ----------
        op : operand
            Operand being tiled; only its ``gpu`` attribute is read here.
        kth : tensor or sequence
            Positions to partition by. When it is a tensor, its first chunk is
            used and is prepended to the inputs of the created operand.
        size : int
            Forwarded unchanged to ``CalcPartitionsInfo`` as ``size``.
        sort_info_chunks : iterable of chunks
            One sort-info chunk per partition.

        Returns
        -------
        Whatever ``CalcPartitionsInfo.new_chunks`` returns for one output per
        sort-info chunk.
        """
        kth_is_tensor = isinstance(kth, TENSOR_TYPE)
        if kth_is_tensor:
            kth = kth.chunks[0]

        info_op = CalcPartitionsInfo(
            kth=kth, size=size, dtype=np.dtype(np.int32), gpu=op.gpu
        )

        # every output mirrors its sort-info chunk, with one extra trailing
        # dimension whose length is the number of kth entries
        chunk_kws = [
            {
                "shape": info_chunk.shape + (len(kth),),
                "order": info_chunk.order,
                "index": info_chunk.index,
                "pos": pos,
            }
            for pos, info_chunk in enumerate(sort_info_chunks)
        ]

        op_inputs = list(sort_info_chunks)
        if kth_is_tensor:
            op_inputs = [kth, *op_inputs]
        return info_op.new_chunks(
            op_inputs, kws=chunk_kws, output_limit=len(chunk_kws)
        )

    @classmethod
    def partition_on_merged(
        cls,
        op,
        need_align,
        partition_merged_chunks,
        partition_indices_chunks,
        partition_info_chunks,
    ):
        """
        Stage 6: create one ``PartitionMerged`` operand per merged partition.

        Returns
        -------
        tuple of two lists
            The value chunks and the indices chunks. A list stays empty when
            the matching ``op.return_value`` / ``op.return_indices`` flag is
            falsy.
        """
        want_value, want_indices = op.return_value, op.return_indices
        value_results, index_results = [], []

        for pos, (merged_chunk, info_chunk) in enumerate(
            zip(partition_merged_chunks, partition_info_chunks)
        ):
            merged_op = PartitionMerged(
                return_value=want_value,
                return_indices=want_indices,
                order=op.order,
                kind=op.kind,
                need_align=need_align,
                dtype=merged_chunk.dtype,
                gpu=op.gpu,
            )

            op_inputs, out_kws = [], []
            if want_value or want_indices:
                # the merged values are an input even if only indices are returned
                op_inputs.append(merged_chunk)
            if want_value:
                out_kws.append(
                    {
                        "shape": merged_chunk.shape,
                        "order": merged_chunk.order,
                        "index": merged_chunk.index,
                        "dtype": merged_chunk.dtype,
                        "type": "partitioned",
                    }
                )
            if want_indices:
                op_inputs.append(partition_indices_chunks[pos])
                out_kws.append(
                    {
                        "shape": merged_chunk.shape,
                        "order": TensorOrder.C_ORDER,
                        "index": merged_chunk.index,
                        "dtype": np.dtype(np.int64),
                        "type": "argpartition",
                    }
                )
            # the partition info is always the last input
            op_inputs.append(info_chunk)

            produced = merged_op.new_chunks(op_inputs, kws=out_kws)
            if want_value:
                value_results.append(produced[0])
            if want_indices:
                index_results.append(produced[-1])

        return value_results, index_results


class TensorPartition(TensorOperand, ParallelPartitionMixin):
    """
    Operand that partitions a tensor along one axis around the given kth
    position(s), producing the partitioned values and/or the indices.
    """

    _op_type_ = OperandDef.PARTITION

    _input = KeyField("input")
    _kth = AnyField("kth")
    _axis = Int32Field("axis")
    _kind = StringField("kind")
    _order = ListField("order", FieldTypes.string)
    _need_align = BoolField("need_align")
    _return_value = BoolField("return_value")
    _return_indices = BoolField("return_indices")

    def __init__(
        self,
        kth=None,
        axis=None,
        kind=None,
        order=None,
        need_align=None,
        return_value=None,
        return_indices=None,
        dtype=None,
        gpu=None,
        **kw,
    ):
        """Store the arguments in the underscore-prefixed fields of the operand."""
        super().__init__(
            _kth=kth,
            _axis=axis,
            _kind=kind,
            _order=order,
            _need_align=need_align,
            _return_value=return_value,
            _return_indices=return_indices,
            dtype=dtype,
            gpu=gpu,
            **kw,
        )

    def _set_inputs(self, inputs):
        """Rebind ``_input`` (and ``_kth`` if there is a second input) to the inputs."""
        super()._set_inputs(inputs)
        self._input = self._inputs[0]
        if len(self._inputs) > 1:
            # a second input only exists when kth was given as a tensor
            self._kth = self._inputs[1]

    @property
    def psrs_kinds(self):
        # to keep compatibility with PSRS
        # remember when merging data in PSRSShuffle(reduce),
        # we don't need sort, thus set psrs_kinds[2] to None
        return ["quicksort", "mergesort", None]

    @property
    def need_align(self):
        return self._need_align

    @property
    def input(self):
        return self._input

    @property
    def kth(self):
        return self._kth

    @property
    def axis(self):
        return self._axis

    @property
    def kind(self):
        return self._kind

    @property
    def order(self):
        return self._order

    @property
    def return_value(self):
        return self._return_value

    @property
    def return_indices(self):
        return self._return_indices

    @property
    def output_limit(self):
        # one output for the values plus one for the indices, each only if requested
        return int(bool(self._return_value)) + int(bool(self._return_indices))

    def __call__(self, a, kth):
        """
        Create the output tensor(s) of this operand for input ``a``.

        ``kth`` is added to the inputs only when it is a tensor. Both outputs
        have the shape of ``a``; the indices output is int64 in C order.

        Returns
        -------
        A single tensor when exactly one output is described, otherwise an
        ``ExecutableTuple`` of the created tensors.
        """
        inputs = [a]
        if isinstance(kth, TENSOR_TYPE):
            inputs.append(kth)
        kws = []
        if self._return_value:
            kws.append(
                {
                    "shape": a.shape,
                    "order": a.order,
                    "type": "sorted",
                    "dtype": a.dtype,
                }
            )
        if self._return_indices:
            kws.append(
                {
                    "shape": a.shape,
                    "order": TensorOrder.C_ORDER,
                    "type": "argsort",
                    "dtype": np.dtype(np.int64),
                }
            )
        ret = self.new_tensors(inputs, kws=kws)
        if len(kws) == 1:
            return ret[0]
        return ExecutableTuple(ret)

    @classmethod
    def _tile_psrs(cls, op, kth):
        """
        Approach here would be almost like PSRSSorter, but there are definitely some differences
        Main processes are listed below:
        Stage 1, local sort and regular samples collected
        Stage 2, gather and merge samples, choose and broadcast p-1 pivots
        Stage 3, Local data is partitioned
        Stage 4: all *ith* classes are gathered and merged, sizes should be calculated as well
        Stage 5: collect sizes from partitions, calculate how to partition given kth
        Stage 6: partition on each partitions
        Stage 7: align if axis is given, and more than 1 dimension

        This is a generator: it delegates to ``cls.preprocess`` through
        ``yield from`` and returns the tensors created from a copy of ``op``.
        """
        out_tensor = op.outputs[0]
        return_value, return_indices = op.return_value, op.return_indices
        # preprocess, to make sure chunk shape on axis are approximately same
        in_tensor, axis_chunk_shape, out_idxes, need_align = yield from cls.preprocess(
            op
        )
        # start offset of every chunk along the partition axis
        axis_offsets = [0] + np.cumsum(in_tensor.nsplits[op.axis]).tolist()[:-1]

        out_chunks, out_indices_chunks = [], []
        for out_idx in out_idxes:
            # stage 1: local sort and regular samples collected
            (
                sorted_chunks,
                indices_chunks,
                sampled_chunks,
            ) = cls.local_sort_and_regular_sample(
                op, in_tensor, axis_chunk_shape, axis_offsets, out_idx
            )

            # stage 2: gather and merge samples, choose and broadcast p-1 pivots
            concat_pivot_chunk = cls.concat_and_pivot(
                op, axis_chunk_shape, out_idx, sorted_chunks, sampled_chunks
            )

            # stage 3: Local data is partitioned
            partition_chunks = cls.partition_local_data(
                op, axis_chunk_shape, sorted_chunks, indices_chunks, concat_pivot_chunk
            )

            proxy_chunk = TensorShuffleProxy(dtype=partition_chunks[0].dtype).new_chunk(
                partition_chunks, shape=()
            )

            # stage 4: all *ith* classes are gathered and merged,
            # note that we don't need sort here, op.psrs_kinds[2] is None
            # force need_align=True to get sort info
            (
                partition_merged_chunks,
                partition_indices_chunks,
                sort_info_chunks,
            ) = cls.partition_merge_data(op, True, True, partition_chunks, proxy_chunk)

            # stage5, collect sort infos and calculate partition info for each partitions
            partition_info_chunks = cls.calc_paritions_info(
                op, kth, in_tensor.shape[op.axis], sort_info_chunks
            )

            # Stage 6: partition on each partitions
            partitioned_chunks, partitioned_indices_chunks = cls.partition_on_merged(
                op,
                need_align,
                partition_merged_chunks,
                partition_indices_chunks,
                partition_info_chunks,
            )

            if not need_align:
                # without alignment the stage 6 chunks are the final chunks
                if return_value:
                    out_chunks.extend(partitioned_chunks)
                if return_indices:
                    out_indices_chunks.extend(partitioned_indices_chunks)
            else:
                # stage 7: align the partitioned data
                (
                    align_reduce_chunks,
                    align_reduce_indices_chunks,
                ) = cls.align_partitions_data(
                    op,
                    out_idx,
                    in_tensor,
                    partitioned_chunks,
                    partitioned_indices_chunks,
                    sort_info_chunks,
                )
                if return_value:
                    out_chunks.extend(align_reduce_chunks)
                if return_indices:
                    out_indices_chunks.extend(align_reduce_indices_chunks)

        new_op = op.copy()
        nsplits = list(in_tensor.nsplits)
        if not need_align:
            # chunk sizes along the axis are recorded as unknown (nan) when not aligned
            nsplits[op.axis] = (np.nan,) * axis_chunk_shape
        kws = []
        if return_value:
            kws.append(
                {
                    "shape": out_tensor.shape,
                    "order": out_tensor.order,
                    "chunks": out_chunks,
                    "nsplits": tuple(nsplits),
                    "dtype": out_tensor.dtype,
                    "type": "partitioned",
                }
            )
        if return_indices:
            kws.append(
                {
                    "shape": out_tensor.shape,
                    "order": TensorOrder.C_ORDER,
                    "chunks": out_indices_chunks,
                    "nsplits": tuple(nsplits),
                    "dtype": np.dtype(np.int64),
                    "type": "argpartition",
                }
            )
        return new_op.new_tensors(op.inputs, kws=kws)

    @classmethod
    def tile(cls, op):
        """
        Tile the operand into chunks (generator).

        When the input has a single chunk along ``op.axis``, each input chunk
        gets its own copy of this operand. Otherwise the work is delegated to
        ``_tile_psrs``.
        """
        in_tensor = op.input
        if np.isnan(in_tensor.shape[op.axis]):
            # size along the axis is unknown
            # NOTE(review): presumably the bare yield lets the tiling framework
            # resolve the shape before resuming; confirm against the tile protocol
            yield

        kth = op.kth
        if isinstance(kth, TENSOR_TYPE):
            # if `kth` is a tensor, make sure no unknown shape
            if has_unknown_shape(kth):
                yield
            # rechunk with the full shape as chunk size; kth.chunks[0] is used below
            kth = yield from recursive_tile(kth.rechunk(kth.shape))

        return_value, return_indices = op.return_value, op.return_indices
        if in_tensor.chunk_shape[op.axis] == 1:
            out_chunks, out_indices_chunks = [], []
            for chunk in in_tensor.chunks:
                chunk_op = op.copy().reset_key()
                kws = []
                if return_value:
                    kws.append(
                        {
                            "shape": chunk.shape,
                            "index": chunk.index,
                            "order": chunk.order,
                            "dtype": chunk.dtype,
                            "type": "partitioned",
                        }
                    )
                if return_indices:
                    kws.append(
                        {
                            "shape": chunk.shape,
                            "index": chunk.index,
                            "order": TensorOrder.C_ORDER,
                            "dtype": np.dtype(np.int64),
                            "type": "argpartition",
                        }
                    )
                chunk_inputs = [chunk]
                if isinstance(kth, TENSOR_TYPE):
                    chunk_inputs.append(kth.chunks[0])
                chunks = chunk_op.new_chunks(chunk_inputs, kws=kws)
                # values come first and indices last in the created chunks
                if return_value:
                    out_chunks.append(chunks[0])
                if return_indices:
                    out_indices_chunks.append(chunks[-1])

            new_op = op.copy()
            kws = [out.params for out in op.outputs]
            if return_value:
                kws[0]["nsplits"] = in_tensor.nsplits
                kws[0]["chunks"] = out_chunks
            if return_indices:
                kws[-1]["nsplits"] = in_tensor.nsplits
                kws[-1]["chunks"] = out_indices_chunks
            return new_op.new_tensors([in_tensor], kws=kws)
        else:
            return (yield from cls._tile_psrs(op, kth))

    @classmethod
    def execute(cls, ctx, op):
        """
        Run the partition on one chunk and store the result(s) in ``ctx``.

        ``kth`` is taken from the second input when there are two inputs,
        otherwise from ``op.kth``.
        """
        inputs, device_id, xp = as_same_device(
            [ctx[inp.key] for inp in op.inputs], device=op.device, ret_extra=True
        )
        a = inputs[0]
        if len(inputs) == 2:
            kth = inputs[1]
        else:
            kth = op.kth
        return_value, return_indices = op.return_value, op.return_indices

        with device(device_id):
            # forward `kind` / `order` only when they are set
            kw = {}
            if op.kind is not None:
                kw["kind"] = op.kind
            if op.order is not None:
                kw["order"] = op.order

            if return_indices:
                if not return_value:
                    ctx[op.outputs[0].key] = xp.argpartition(a, kth, axis=op.axis, **kw)
                else:
                    # compute indices once and gather the values through them
                    argparts = ctx[op.outputs[1].key] = xp.argpartition(
                        a, kth, axis=op.axis, **kw
                    )
                    ctx[op.outputs[0].key] = xp.take_along_axis(a, argparts, op.axis)
            else:
                ctx[op.outputs[0].key] = xp.partition(a, kth, axis=op.axis, **kw)


class CalcPartitionsInfo(TensorOperand, TensorPSRSOperandMixin):
    """
    Operand that turns the sort infos of all partitions into per-partition
    kth positions (stage 5 of the parallel partition).
    """

    _op_type_ = OperandDef.CALC_PARTITIONS_INFO

    _kth = AnyField("kth")
    _size = Int32Field("size")

    def __init__(self, kth=None, size=None, dtype=None, gpu=None, **kw):
        super().__init__(_kth=kth, _size=size, dtype=dtype, gpu=gpu, **kw)

    @property
    def kth(self):
        return self._kth

    @property
    def size(self):
        return self._size

    @property
    def output_limit(self):
        return np.inf

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        # an entity kth is always the first input
        if isinstance(self._kth, ENTITY_TYPE):
            self._kth = self._inputs[0]

    @classmethod
    def execute(cls, ctx, op):
        """
        Compute one partition-info array per sort-info input.

        Each output has the shape of a sort info plus a trailing dimension of
        length ``len(kth)``. An entry holds the kth position local to that
        partition, or -1 when that kth does not fall into the partition.
        """
        fetched = [ctx[inp.key] for inp in op.inputs]
        arrays, device_id, xp = as_same_device(
            fetched, device=op.device, ret_extra=True
        )

        with device(device_id):
            if not isinstance(op.kth, TENSOR_CHUNK_TYPE):
                kth = op.kth
                sort_infos = arrays
            else:
                # kth arrived as the first input; make its values all positive
                kth = _validate_kth_value(arrays[0], op.size)
                sort_infos = arrays[1:]

            info_shape = sort_infos[0].shape
            result_shape = info_shape + (len(kth),)
            # -1 everywhere means "nothing to partition here"
            partition_infos = [xp.full(result_shape, -1) for _ in sort_infos]

            stacked = xp.stack([info.ravel() for info in sort_infos])
            running_sizes = xp.cumsum(stacked, axis=0)

            for flat_pos in range(running_sizes.shape[1]):
                nd_pos = tuple(xp.unravel_index(flat_pos, info_shape))
                sizes = running_sizes[:, flat_pos]
                target_chunks = xp.searchsorted(sizes, kth, side="right")
                for kth_pos, target in enumerate(target_chunks):
                    k = kth[kth_pos]
                    if target == 0:
                        # falls into the first chunk: kth is already local
                        local_k = k
                    else:
                        # subtract the total size of the preceding chunks
                        local_k = k - sizes[target - 1]
                    partition_infos[target][nd_pos + (kth_pos,)] = local_k

            for out_chunk, info in zip(op.outputs, partition_infos):
                ctx[out_chunk.key] = info


class PartitionMerged(TensorOperand, TensorPSRSOperandMixin):
    """
    Operand that partitions the merged data of one partition according to the
    computed partition info (stage 6 of the parallel partition).
    """

    _op_type_ = OperandDef.PARTITION_MERGED

    _return_value = BoolField("return_value")
    _return_indices = BoolField("return_indices")
    _order = ListField("order", FieldTypes.string)
    _kind = StringField("kind")
    _need_align = BoolField("need_align")

    def __init__(
        self,
        return_value=None,
        return_indices=None,
        order=None,
        kind=None,
        need_align=None,
        **kw,
    ):
        super().__init__(
            _return_value=return_value,
            _return_indices=return_indices,
            _order=order,
            _kind=kind,
            _need_align=need_align,
            **kw,
        )

    @property
    def return_value(self):
        return self._return_value

    @property
    def return_indices(self):
        return self._return_indices

    @property
    def order(self):
        return self._order

    @property
    def kind(self):
        return self._kind

    @property
    def need_align(self):
        return self._need_align

    @property
    def output_limit(self):
        return int(bool(self._return_value)) + int(bool(self._return_indices))

    @classmethod
    def execute(cls, ctx, op):
        """
        Partition every merged vector by its own kth entries.

        The last input is the partition info. The first input is the merged
        data and, when indices are requested, the second is the merged
        indices. Vectors whose info row holds no entry above -1 pass through
        unchanged.
        """
        want_value, want_indices = op.return_value, op.return_indices

        raw_inputs = [ctx[inp.key] for inp in op.inputs]
        same_device_inputs, device_id, xp = as_same_device(
            flatten(raw_inputs), device=op.device, ret_extra=True
        )
        inputs = stack_back(same_device_inputs, raw_inputs)
        partition_info = inputs[-1]

        merged_data = merged_indices = None
        if want_indices:
            # if return indices, value should be returned
            assert len(inputs) == 3
            merged_data, merged_indices = inputs[0], inputs[1]
        elif want_value:
            merged_data = inputs[0]

        outs, out_indices = [], []
        with device(device_id):
            select_kw = {}
            if op.kind is not None:
                select_kw["kind"] = op.kind
            if op.order is not None:
                select_kw["order"] = op.order

            # one row of kth entries per merged vector
            info_rows = partition_info.reshape(-1, partition_info.shape[-1])
            for pos, (vec, kth) in enumerate(zip(merged_data, info_rows)):
                kth = kth[kth > -1]
                if kth.size == 0:
                    # nothing to partition for this vector
                    if want_value:
                        outs.append(vec)
                    if want_indices:
                        out_indices.append(merged_indices[pos])
                    continue

                if not want_indices:
                    outs.append(xp.partition(vec, kth, **select_kw))
                    continue

                argparts = xp.argpartition(vec, kth, **select_kw)
                if want_value:
                    outs.append(xp.take(vec, argparts))
                out_indices.append(xp.take(merged_indices[pos], argparts))

        aligned = op.need_align
        if not aligned:
            assert len(outs or out_indices) == 1

        # values go to the first output, indices to the next free one
        out_pos = 0
        if want_value:
            ctx[op.outputs[0].key] = tuple(outs) if aligned else outs[0]
            out_pos += 1
        if want_indices:
            ctx[op.outputs[out_pos].key] = (
                tuple(out_indices) if aligned else out_indices[0]
            )


def _check_kth_dtype(dtype):
    """Raise ``TypeError`` unless ``dtype`` is an integer dtype."""
    if np.issubdtype(dtype, np.integer):
        return
    raise TypeError("Partition index must be integer")


def _validate_kth_value(kth, size):
    """
    Map negative kth entries to their positive equivalents and bounds-check.

    Returns the normalized array; raises ``ValueError`` naming the first entry
    that lies outside ``[0, size)`` after normalization.
    """
    normalized = np.where(kth < 0, kth + size, kth)
    out_of_bounds = (normalized < 0) | (normalized >= size)
    if np.any(out_of_bounds):
        for invalid_kth in normalized:
            if invalid_kth < 0 or invalid_kth >= size:
                break
        else:
            # keeps the behavior of an exhausted ``next()`` call
            raise StopIteration
        raise ValueError(f"kth(={invalid_kth}) out of bounds ({size})")
    return normalized


def _validate_partition_arguments(a, kth, axis, kind, order, kw):
    """
    Normalize and validate the arguments of a partition call.

    Returns
    -------
    tuple
        ``(a, kth, axis, kind, order, need_align)`` where ``a`` is a tensor
        (flattened when ``axis`` is None) and ``need_align`` is popped from
        ``kw``.

    Raises
    ------
    ValueError
        If ``kth`` has more than one dimension or ``kind`` is not
        ``"introselect"``.
    TypeError
        If ``kw`` still holds entries after ``need_align`` has been removed.
    """
    a = astensor(a)
    if axis is not None:
        axis = validate_axis(a.ndim, axis)
    else:
        # no axis: partition the flattened tensor along its only axis
        a, axis = a.flatten(), 0

    if not isinstance(kth, ENTITY_TYPE):
        kth = _validate_kth_value(np.atleast_1d(kth), a.shape[axis])
    else:
        # values of an entity kth cannot be checked here, only its dtype
        kth = astensor(kth)
        _check_kth_dtype(kth.dtype)

    if kth.ndim > 1:
        raise ValueError("object too deep for desired array")
    if kind != "introselect":
        raise ValueError(f"{kind} is an unrecognized kind of select")

    # if a is structure type and order is not None
    order = validate_order(a.dtype, order)
    need_align = kw.pop("need_align", None)
    if kw:
        unexpected = next(iter(kw))
        raise TypeError(
            f"partition() got an unexpected keyword argument '{unexpected}'"
        )

    return a, kth, axis, kind, order, need_align


def partition(a, kth, axis=-1, kind="introselect", order=None, **kw):
    r"""
    Return a partitioned copy of a tensor.

    Creates a copy of the tensor with its elements rearranged in such a
    way that the value of the element in k-th position is in the
    position it would be in a sorted tensor. All elements smaller than
    the k-th element are moved before this element and all equal or
    greater are moved behind it. The ordering of the elements in the two
    partitions is undefined.

    Parameters
    ----------
    a : array_like
        Tensor to be sorted.
    kth : int or sequence of ints
        Element index to partition by. The k-th value of the element
        will be in its final sorted position and all smaller elements
        will be moved before it and all equal or greater elements behind
        it. The order of all elements in the partitions is undefined. If
        provided with a sequence of k-th it will partition all elements
        indexed by k-th of them into their sorted position at once.
    axis : int or None, optional
        Axis along which to sort. If None, the tensor is flattened before
        sorting. The default is -1, which sorts along the last axis.
    kind : {'introselect'}, optional
        Selection algorithm. Default is 'introselect'.
    order : str or list of str, optional
        When `a` is a tensor with fields defined, this argument
        specifies which fields to compare first, second, etc. A single
        field can be specified as a string. Not all fields need be
        specified, but unspecified fields will still be used, in the
        order in which they come up in the dtype, to break ties.

    Returns
    -------
    partitioned_tensor : Tensor
        Tensor of the same type and shape as `a`.

    See Also
    --------
    Tensor.partition : Method to sort a tensor in-place.
    argpartition : Indirect partition.
    sort : Full sorting

    Notes
    -----
    The various selection algorithms are characterized by their average
    speed, worst case performance, work space size, and whether they are
    stable. A stable sort keeps items with the same key in the same
    relative order. The available algorithms have the following
    properties:

    ================= ======= ============= ============ =======
       kind            speed   worst case    work space  stable
    ================= ======= ============= ============ =======
    'introselect'        1        O(n)           0         no
    ================= ======= ============= ============ =======

    All the partition algorithms make temporary copies of the data when
    partitioning along any but the last axis. Consequently,
    partitioning along the last axis is faster and uses less space than
    partitioning along any other axis.

    The sort order for complex numbers is lexicographic. If both the
    real and imaginary parts are non-nan then the order is determined by
    the real parts except when they are equal, in which case the order
    is determined by the imaginary parts.

    Examples
    --------
    >>> import mars.tensor as mt
    >>> a = mt.array([3, 4, 2, 1])
    >>> mt.partition(a, 3).execute()
    array([2, 1, 3, 4])

    >>> mt.partition(a, (1, 3)).execute()
    array([1, 2, 3, 4])
    """
    # `return_index` is taken from the extra keywords; when truthy the operand
    # is also asked for the indices output
    return_indices = kw.pop("return_index", False)
    # validation consumes `need_align` from kw and rejects any other leftover key
    a, kth, axis, kind, order, need_align = _validate_partition_arguments(
        a, kth, axis, kind, order, kw
    )
    op = TensorPartition(
        kth=kth,
        axis=axis,
        kind=kind,
        order=order,
        need_align=need_align,
        return_value=True,
        return_indices=return_indices,
        dtype=a.dtype,
        gpu=a.op.gpu,
    )
    return op(a, kth)