# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ... import opcodes as OperandDef
from ...core import ExecutableTuple, recursive_tile
from ...lib.sparse.core import get_array_module
from ...serialization.serializables import AnyField, Int32Field, KeyField
from ..core import Tensor
from ..datasource import tensor as astensor
from ..operands import TensorHasInput, TensorOperandMixin
from ..utils import calc_sliced_size
class TensorSplit(TensorHasInput, TensorOperandMixin):
    """Operand that cuts one tensor into several sub-tensors along one axis.

    The split points are given either as a section count (an integer) or as
    a 1-D sequence / tensor of indices, stored in ``indices_or_sections``.
    """

    _op_type_ = OperandDef.ARRAY_SPLIT

    # Tensor being split.
    _input = KeyField("input")
    # Section count, index array, or (via ``_set_inputs``) an index tensor.
    _indices_or_sections = AnyField("indices_or_sections")
    # Axis along which the input is split.
    _axis = Int32Field("axis")

    def __init__(self, axis=None, **kw):
        super().__init__(_axis=axis, **kw)

    @property
    def indices_or_sections(self):
        """Section count or split indices attached to this operand."""
        return self._indices_or_sections

    @property
    def axis(self):
        """Axis to split along; falls back to 0 when ``_axis`` is not set."""
        return getattr(self, "_axis", 0)

    @property
    def output_limit(self):
        """Upper bound on the number of outputs (infinite for a split)."""
        return float("inf")

    def _set_inputs(self, inputs):
        """Bind ``_input`` (and the indices tensor, if any) to ``inputs``."""
        super()._set_inputs(inputs)
        self._input = self._inputs[0]
        if len(self._inputs) > 1:
            # ``__call__`` appends the indices tensor as the second input.
            self._indices_or_sections = self._inputs[1]

    def __call__(self, a, indices_or_sections, is_split=False):
        """Build the output sub-tensors for splitting ``a``.

        Parameters
        ----------
        a : Tensor
            Tensor to split.
        indices_or_sections : int, sequence of int or Tensor
            Number of sections, or the indices at which to cut.
        is_split : bool, optional
            When True an integer section count must divide the axis length
            exactly; when False the leading sections absorb the remainder.

        Returns
        -------
        ExecutableTuple
            One tensor per resulting section.

        Raises
        ------
        ValueError
            If the axis length is unknown, if the section count is not
            positive, or if ``is_split`` is set and the division is uneven.
        TypeError
            If the indices are not a 1-D array of integers.
        """
        axis = self._axis
        # Index first so an out-of-range axis still fails exactly as before.
        size = a.shape[axis]
        if axis < 0:
            # The shape arithmetic below and the slicing in ``tile`` both
            # require a non-negative axis.
            axis += len(a.shape)
            self._axis = axis
        if np.isnan(size):
            raise ValueError(
                "cannot split array with unknown shape, "
                "call `.execute()` on input tensor first"
            )

        # A tensor that merely wraps in-memory data is replaced by that data.
        if (
            isinstance(indices_or_sections, Tensor)
            and hasattr(indices_or_sections.op, "data")
            and indices_or_sections.op.data is not None
        ):
            indices_or_sections = indices_or_sections.op.data

        try:
            indices_or_sections = int(indices_or_sections)
        except TypeError:
            # Not convertible to an int: the argument holds split indices.
            if isinstance(indices_or_sections, Tensor):
                # Index values are not available yet, so extents are unknown.
                nparts = indices_or_sections.shape[0] + 1
                nsplit = (np.nan,) * nparts
            else:
                ind = indices_or_sections = get_array_module(
                    indices_or_sections
                ).asarray(indices_or_sections)
                if ind.ndim != 1 or not np.issubdtype(ind.dtype, np.integer):
                    raise TypeError("slice indices must be integers or None")
                nparts = ind.shape[0] + 1
                # Open-ended bounds on both sides: [None, i0, i1, ..., None].
                bounds = [None] + [ind[i] for i in range(len(ind))] + [None]
                nsplit = [
                    calc_sliced_size(size, slice(bounds[j], bounds[j + 1]))
                    for j in range(nparts)
                ]
        else:
            nparts = indices_or_sections
            if nparts <= 0:
                raise ValueError("number sections must be larger than 0.")
            base, extras = divmod(size, nparts)
            if is_split and extras:
                raise ValueError("tensor split does not result in an equal division")
            # The first ``extras`` sections get one extra element; exactly
            # ``nparts`` sizes are produced even when ``nparts > size``.
            nsplit = (base + 1,) * extras + (base,) * (nparts - extras)

        inputs = [a]
        if isinstance(indices_or_sections, Tensor):
            inputs.append(indices_or_sections)
        else:
            self._indices_or_sections = indices_or_sections

        kws = [
            {
                "i": i,
                "shape": a.shape[:axis] + (nsplit[i],) + a.shape[axis + 1 :],
                "order": a.order,
            }
            for i in range(nparts)
        ]
        return ExecutableTuple(self.new_tensors(inputs, kws=kws, output_limit=nparts))

    @classmethod
    def tile(cls, op):
        """Tile each output by slicing the input tensor along ``op.axis``."""
        in_tensor = op.input
        splits = op.outputs
        axis = op.axis

        # Running end offsets of the outputs along the split axis.
        acc_shapes = np.cumsum([s.shape[axis] for s in splits])
        out_kws = [dict() for _ in splits]
        for i, split in enumerate(splits):
            slc = slice(0 if i == 0 else acc_shapes[i - 1], acc_shapes[i])
            new_s = yield from recursive_tile(in_tensor[(slice(None),) * axis + (slc,)])
            out_kws[i]["chunks"] = new_s.chunks
            out_kws[i]["nsplits"] = new_s.nsplits
            out_kws[i]["shape"] = split.shape
            out_kws[i]["order"] = op.outputs[i].order

        new_op = op.copy()
        return new_op.new_tensors(op.inputs, kws=out_kws, output_limit=len(out_kws))
def _split(a, indices_or_sections, axis=0, is_split=False):
    """Create a ``TensorSplit`` operand for ``a`` and apply it.

    ``is_split`` is forwarded unchanged to the operand call.
    """
    split_op = TensorSplit(axis=axis, dtype=a.dtype)
    sub_tensors = split_op(a, indices_or_sections, is_split=is_split)
    return sub_tensors
def split(ary, indices_or_sections, axis=0):
    """
    Split a tensor into multiple sub-tensors.

    Parameters
    ----------
    ary : Tensor
        Tensor to be divided into sub-tensors.
    indices_or_sections : int or 1-D tensor
        If `indices_or_sections` is an integer, N, the array will be divided
        into N equal tensors along `axis`.  If such a split is not possible,
        an error is raised.

        If `indices_or_sections` is a 1-D tensor of sorted integers, the entries
        indicate where along `axis` the array is split.  For example,
        ``[2, 3]`` would, for ``axis=0``, result in

          - ary[:2]
          - ary[2:3]
          - ary[3:]

        If an index exceeds the dimension of the tensor along `axis`,
        an empty sub-tensor is returned correspondingly.
    axis : int, optional
        The axis along which to split, default is 0.

    Returns
    -------
    sub-tensors : list of Tensors
        A list of sub-tensors.

    Raises
    ------
    ValueError
        If `indices_or_sections` is given as an integer, but it is not
        positive or a split does not result in equal division.

    See Also
    --------
    array_split : Split a tensor into multiple sub-tensors of equal or
                  near-equal size.  Does not raise an exception if
                  an equal division cannot be made.
    hsplit : Split into multiple sub-arrays horizontally (column-wise).
    vsplit : Split tensor into multiple sub-tensors vertically (row wise).
    dsplit : Split tensor into multiple sub-tensors along the 3rd axis (depth).
    concatenate : Join a sequence of tensors along an existing axis.
    stack : Join a sequence of tensors along a new axis.
    hstack : Stack tensors in sequence horizontally (column wise).
    vstack : Stack tensors in sequence vertically (row wise).
    dstack : Stack tensors in sequence depth wise (along third dimension).

    Examples
    --------
    >>> import mars.tensor as mt

    >>> x = mt.arange(9.0)
    >>> mt.split(x, 3).execute()
    [array([ 0.,  1.,  2.]), array([ 3.,  4.,  5.]), array([ 6.,  7.,  8.])]

    >>> x = mt.arange(8.0)
    >>> mt.split(x, [3, 5, 6, 10]).execute()
    [array([ 0.,  1.,  2.]),
     array([ 3.,  4.]),
     array([ 5.]),
     array([ 6.,  7.]),
     array([], dtype=float64)]

    """
    # ``is_split=True`` enforces the equal-division requirement.
    return _split(astensor(ary), indices_or_sections, axis=axis, is_split=True)