# Source code for xorbits._mars.tensor.base.repeat

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from numbers import Integral

import numpy as np

from ... import opcodes as OperandDef
from ...core import recursive_tile
from ...serialization.serializables import AnyField, Int32Field, KeyField
from ...utils import has_unknown_shape
from ..array_utils import as_same_device, device
from ..core import TENSOR_CHUNK_TYPE, TENSOR_TYPE, Tensor, TensorOrder
from ..datasource import tensor as astensor
from ..operands import TensorHasInput, TensorOperandMixin
from ..utils import broadcast_shape, unify_chunks
from .ravel import ravel


class TensorRepeat(TensorHasInput, TensorOperandMixin):
    """Operand that repeats the elements of a tensor along one axis.

    ``repeats`` is either stored directly on the operand (a Python integer or
    a NumPy array) or, when it is itself a tensor, passed as the second input.
    """

    _op_type_ = OperandDef.REPEAT

    _input = KeyField("input")
    _repeats = AnyField("repeats")
    _axis = Int32Field("axis")

    def __init__(self, axis=None, dtype=None, sparse=False, **kw):
        super().__init__(_axis=axis, dtype=dtype, sparse=sparse, **kw)

    @property
    def repeats(self):
        """Repeat counts: an integer, a NumPy array, or a tensor/chunk input."""
        return self._repeats

    @property
    def axis(self):
        """Axis along which to repeat; ``None`` means the flattened input."""
        return self._axis

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        self._input = self._inputs[0]
        if len(inputs) > 1:
            # A second input is only present when repeats is a tensor/chunk.
            self._repeats = self._inputs[1]

    def __call__(self, a, repeats):
        """Build the output tensor for repeating ``a`` by ``repeats``.

        Parameters
        ----------
        a : array_like
            Input data, converted with ``astensor``. It is raveled first when
            the operand's axis is ``None``.
        repeats : int, array_like of ints, or Tensor
            Number of repetitions, either a single count or one per element
            along the axis.

        Returns
        -------
        Tensor
            The new tensor. Its length along the axis is ``nan`` when
            ``repeats`` is a tensor, because the size cannot be known yet.

        Raises
        ------
        IndexError
            If the axis is out of bounds for ``a``.
        ValueError
            If a non-scalar ``repeats`` is not 1-d.
        """
        axis = self._axis
        a = astensor(a)
        if axis is None:
            a = ravel(a)
            ax = 0
        else:
            ndim = len(a.shape)
            # Normalize negative axes: the tuple slicing used below (and in
            # ``tile``) is only correct for a non-negative axis.
            ax = axis + ndim if axis < 0 else axis
            if not 0 <= ax < ndim:
                raise IndexError(
                    f"axis {axis} is out of bounds for tensor of dimension {ndim}"
                )
            self._axis = ax

        if not isinstance(repeats, Integral):
            if not isinstance(repeats, Tensor):
                repeats = np.asarray(repeats)
                if repeats.size == 1:
                    # ``item()`` also handles 0-d arrays, which cannot be
                    # indexed with ``[0]``.
                    repeats = int(repeats.item())
                    size = repeats * a.shape[ax]
                elif a.shape[ax] == 1:
                    # A single element repeated by several counts collapses
                    # to one scalar count.
                    size = repeats = int(repeats.sum())
                else:
                    size = int(repeats.sum())
            else:
                size = np.nan
            if not isinstance(repeats, Integral):
                if repeats.ndim != 1:
                    raise ValueError("repeats should be 1-d tensor")
                # Presumably raises when the length of repeats is not
                # compatible with the axis length -- verify in broadcast_shape.
                broadcast_shape(repeats.shape, a.shape[ax : ax + 1])
        else:
            size = a.shape[ax] * repeats

        shape = a.shape[:ax] + (size,) + a.shape[ax + 1 :]
        self.dtype = a.dtype
        self.sparse = a.issparse()

        inputs = [a]
        if isinstance(repeats, Tensor):
            inputs.append(repeats)
        else:
            self._repeats = repeats

        return self.new_tensor(inputs, shape, order=TensorOrder.C_ORDER)

    @classmethod
    def tile(cls, op):
        """Split the repeat operation into one chunk operand per input chunk.

        This is a generator: it yields while input shapes are unknown and
        delegates to ``unify_chunks`` / ``recursive_tile`` via ``yield from``.
        """
        a = op.input
        repeats = op.repeats
        axis = op.axis
        ax = axis or 0
        out = op.outputs[0]

        if has_unknown_shape(*op.inputs):
            # NOTE(review): a bare yield presumably asks the tiling framework
            # to resume once the input shapes are known -- confirm.
            yield

        if isinstance(repeats, TENSOR_TYPE):
            # Align the chunks of repeats with those of ``a`` along the axis.
            a, repeats = yield from unify_chunks(a, (repeats, (ax,)))

        nsplit = a.nsplits[ax]

        # A zero count is valid (empty result) but would divide by zero here;
        # there is nothing to shrink in that case, so skip the rechunk.
        if isinstance(repeats, Integral) and repeats != 0:
            # Split each input chunk into pieces of roughly ``split // repeats``
            # so that a repeated chunk is about as large as the original one.
            new_nsplit = []
            for split in nsplit:
                s = max(split // repeats, 1)
                c = split // s
                new_nsplit.extend([s] * c)
                if split % s != 0:
                    new_nsplit.append(split % s)

            # ``new_nsplit`` is empty when every split along the axis is 0;
            # keep the existing chunks then instead of rechunking to nothing.
            if new_nsplit:
                a = yield from recursive_tile(a.rechunk({ax: new_nsplit}))

        out_chunks = []
        # Cumulative offsets of the chunks along the axis, used to slice an
        # ndarray ``repeats`` per chunk.
        ax_cum_count = np.cumsum((0,) + a.nsplits[ax])
        is_repeats_ndarray = isinstance(repeats, np.ndarray)
        for out_idx in itertools.product(*[range(len(s)) for s in a.nsplits]):
            in_chunk = a.cix[out_idx]
            ax_idx = out_idx[ax]
            if is_repeats_ndarray:
                start = ax_cum_count[ax_idx]
                stop = ax_cum_count[ax_idx + 1]
                rp = repeats[start:stop]
                size = int(rp.sum())
            elif not isinstance(repeats, Integral):
                # repeats is a tiled tensor: take the matching chunk; the
                # output size is unknown until execution.
                rp = repeats.cix[ax_idx,]
                size = np.nan
            else:
                rp = repeats
                size = in_chunk.shape[ax] * rp

            chunk_inputs = [in_chunk]
            if isinstance(rp, TENSOR_CHUNK_TYPE):
                chunk_inputs.append(rp)

            chunk_shape = in_chunk.shape[:ax] + (size,) + in_chunk.shape[ax + 1 :]
            chunk_op = op.copy().reset_key()
            if len(chunk_inputs) < 2:
                # repeats is not chunk
                chunk_op._repeats = rp
            out_chunk = chunk_op.new_chunk(
                chunk_inputs, shape=chunk_shape, index=out_idx, order=out.order
            )
            out_chunks.append(out_chunk)

        # For each dimension, collect the chunk sizes along it from the chunks
        # whose index is 0 in every other dimension.
        nsplits = [
            tuple(
                c.shape[i]
                for c in out_chunks
                if all(idx == 0 for j, idx in enumerate(c.index) if j != i)
            )
            for i in range(len(out_chunks[0].shape))
        ]
        new_op = op.copy()
        return new_op.new_tensors(
            op.inputs, out.shape, order=out.order, chunks=out_chunks, nsplits=nsplits
        )

    @classmethod
    def execute(cls, ctx, op):
        """Compute one output chunk with the array module's ``repeat``."""
        inputs, device_id, xp = as_same_device(
            [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
        )

        a = inputs[0]
        if len(inputs) > 1:
            # repeats was supplied as a chunk input
            repeats = inputs[1]
        else:
            repeats = op.repeats

        with device(device_id):
            ctx[op.outputs[0].key] = xp.repeat(a, repeats=repeats, axis=op.axis)


def repeat(a, repeats, axis=None):
    """
    Repeat elements of a tensor.

    Parameters
    ----------
    a : array_like
        Input tensor.
    repeats : int or tensor of ints
        The number of repetitions for each element.  `repeats` is broadcasted
        to fit the shape of the given axis.
    axis : int, optional
        The axis along which to repeat values.  By default, use the
        flattened input tensor, and return a flat output tensor.

    Returns
    -------
    repeated_tensor : Tensor
        Output array which has the same shape as `a`, except along
        the given axis.

    See Also
    --------
    tile : Tile a tensor.

    Examples
    --------
    >>> import mars.tensor as mt

    >>> mt.repeat(3, 4).execute()
    array([3, 3, 3, 3])
    >>> x = mt.array([[1,2],[3,4]])
    >>> mt.repeat(x, 2).execute()
    array([1, 1, 2, 2, 3, 3, 4, 4])
    >>> mt.repeat(x, 3, axis=1).execute()
    array([[1, 1, 1, 2, 2, 2],
           [3, 3, 3, 4, 4, 4]])
    >>> mt.repeat(x, [1, 2], axis=0).execute()
    array([[1, 2],
           [3, 4],
           [3, 4]])

    """
    op = TensorRepeat(axis=axis)
    return op(a, repeats)