# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ... import opcodes as OperandDef
from ...core import ENTITY_TYPE, recursive_tile
from ...serialization.serializables import AnyField, Int32Field, KeyField, TupleField
from ...utils import has_unknown_shape
from ..datasource import tensor as astensor
from ..operands import TensorHasInput, TensorOperandMixin
from ..utils import calc_object_length, filter_inputs, validate_axis
class TensorInsert(TensorHasInput, TensorOperandMixin):
    """Operand that inserts ``values`` into a tensor before ``index_obj``.

    At chunk level the result is computed with :func:`numpy.insert` (see
    :meth:`execute`).  ``index_obj`` and ``values`` may each be a plain Python
    object or a Mars entity; entities are registered as extra inputs after the
    main input tensor.

    ``range_on_axis`` is only set on chunk operands whose indices could not be
    localized during tiling.  It is the half-open interval ``(start, stop)`` of
    global positions along ``axis`` that the chunk is responsible for.
    """

    _op_type_ = OperandDef.INSERT

    _index_obj = AnyField("index_obj")
    _values = AnyField("values")
    _axis = Int32Field("axis")
    _input = KeyField("input")

    # for chunk
    _range_on_axis = TupleField("range_on_axis")

    def __init__(
        self, index_obj=None, values=None, axis=None, range_on_axis=None, **kw
    ):
        super().__init__(
            _index_obj=index_obj,
            _values=values,
            _axis=axis,
            _range_on_axis=range_on_axis,
            **kw
        )

    @property
    def index_obj(self):
        """Index (int, slice, sequence or entity) before which to insert."""
        return self._index_obj

    @property
    def values(self):
        """Values to insert (plain object or entity)."""
        return self._values

    @property
    def axis(self):
        """Axis along which to insert; ``None`` means the input is flattened."""
        return self._axis

    @property
    def range_on_axis(self):
        """``(start, stop)`` of global positions handled by a chunk, or ``None``."""
        return self._range_on_axis

    def _set_inputs(self, inputs):
        """Rebind entity-typed ``index_obj`` / ``values`` to the given inputs.

        ``inputs[0]`` is the tensor being inserted into.  The remaining inputs
        are consumed in order: first ``index_obj`` (if it is an entity), then
        ``values`` (if it is an entity).
        """
        super()._set_inputs(inputs)
        inputs_iter = iter(self._inputs[1:])
        if isinstance(self._index_obj, ENTITY_TYPE):
            self._index_obj = next(inputs_iter)
        if isinstance(self._values, ENTITY_TYPE):
            self._values = next(inputs_iter)

    @classmethod
    def tile(cls, op: "TensorInsert"):
        """Split the insert into per-chunk inserts along ``axis``.

        The input is rechunked so that only ``axis`` is split (or flattened when
        ``axis`` is ``None``).  Three cases are handled depending on the type of
        ``op.index_obj``:

        * ``int``: only the chunk containing the position gets an insert
          operand; all other chunks are passed through unchanged.
        * entity: every chunk gets an insert operand together with its
          ``range_on_axis``; the output size on ``axis`` is unknown (``nan``).
        * slice / sequence of ints: indices are bucketed per chunk; chunks
          without any index are passed through unchanged.
        """
        inp = op.inputs[0]
        axis = op.axis
        if axis is None:
            inp = yield from recursive_tile(inp.flatten())
            axis = 0
        else:
            # keep the splits on ``axis`` and merge every other dimension
            # into a single chunk
            new_splits = [s if i == axis else sum(s) for i, s in enumerate(inp.nsplits)]
            inp = yield from recursive_tile(inp.rechunk(new_splits))

        if has_unknown_shape(inp):
            yield

        index_obj = op.index_obj
        values = op.values
        if isinstance(values, ENTITY_TYPE):
            # if values is Mars type, we rechunk it into one chunk and
            # all insert chunks depend on it
            values = yield from recursive_tile(values.rechunk(values.shape))

        nsplits_on_axis = []
        if isinstance(index_obj, int):
            splits = inp.nsplits[axis]
            cum_splits = np.cumsum([0] + list(splits))
            # add 1 for last split, so that an index equal to the total
            # length is attributed to the last chunk
            cum_splits[-1] = cum_splits[-1] + 1
            in_idx = cum_splits.searchsorted(index_obj, side="right") - 1

            out_chunks = []
            for chunk in inp.chunks:
                if chunk.index[axis] == in_idx:
                    chunk_op = op.copy().reset_key()
                    # translate the global position into a chunk-local one
                    chunk_op._index_obj = index_obj - cum_splits[in_idx]
                    if isinstance(values, ENTITY_TYPE):
                        chunk_values = values.chunks[0]
                    else:
                        chunk_values = values
                    inputs = filter_inputs([chunk, chunk_values])
                    shape = tuple(
                        s + calc_object_length(index_obj) if i == axis else s
                        for i, s in enumerate(chunk.shape)
                    )
                    out_chunks.append(
                        chunk_op.new_chunk(inputs, shape=shape, index=chunk.index)
                    )
                    nsplits_on_axis.append(shape[axis])
                else:
                    out_chunks.append(chunk)
                    nsplits_on_axis.append(chunk.shape[axis])
        elif isinstance(index_obj, ENTITY_TYPE):
            index_obj = yield from recursive_tile(index_obj.rechunk(index_obj.shape))
            offset = 0
            out_chunks = []
            for chunk in inp.chunks:
                chunk_op = op.copy().reset_key()
                chunk_op._index_obj = index_obj.chunks[0]
                if isinstance(values, ENTITY_TYPE):
                    chunk_values = values.chunks[0]
                else:
                    chunk_values = values
                chunk_op._values = chunk_values
                if chunk.index[axis] + 1 == len(inp.nsplits[axis]):
                    # the last chunk on axis
                    chunk_op._range_on_axis = (offset, offset + chunk.shape[axis] + 1)
                else:
                    chunk_op._range_on_axis = (offset, offset + chunk.shape[axis])
                # indices are only known at execution time, hence the
                # unknown size on ``axis``
                shape = tuple(
                    np.nan if j == axis else s for j, s in enumerate(chunk.shape)
                )
                inputs = filter_inputs([chunk, index_obj.chunks[0], chunk_values])
                out_chunks.append(
                    chunk_op.new_chunk(inputs, shape=shape, index=chunk.index)
                )
                offset += chunk.shape[axis]
                nsplits_on_axis.append(np.nan)
        else:
            # index object is slice or sequence of ints
            if isinstance(index_obj, slice):
                index_obj = range(
                    index_obj.start or 0, index_obj.stop, index_obj.step or 1
                )

            splits = inp.nsplits[axis]
            cum_splits = np.cumsum([0] + list(splits))
            # add 1 for last split
            cum_splits[-1] = cum_splits[-1] + 1
            # per chunk: [chunk-local indices, positions inside ``index_obj``]
            chunk_idx_params = [[[], []] for _ in splits]
            for i, int_idx in enumerate(index_obj):
                in_idx = cum_splits.searchsorted(int_idx, side="right") - 1
                chunk_idx_params[in_idx][0].append(int_idx - cum_splits[in_idx])
                chunk_idx_params[in_idx][1].append(i)

            out_chunks = []
            offset = 0
            for chunk in inp.chunks:
                idx_on_axis = chunk.index[axis]
                if len(chunk_idx_params[idx_on_axis][0]) > 0:
                    chunk_op = op.copy().reset_key()
                    chunk_index_obj = chunk_idx_params[idx_on_axis][0]
                    shape = tuple(
                        s + len(chunk_index_obj) if j == axis else s
                        for j, s in enumerate(chunk.shape)
                    )
                    if isinstance(values, int):
                        chunk_op._index_obj = chunk_index_obj
                        out_chunks.append(
                            chunk_op.new_chunk([chunk], shape=shape, index=chunk.index)
                        )
                    elif isinstance(values, ENTITY_TYPE):
                        # values cannot be sliced at tile time: keep the
                        # global indices on the operand and let ``execute``
                        # select the ones inside ``range_on_axis``
                        chunk_op._values = values.chunks[0]
                        if chunk.index[axis] + 1 == len(inp.nsplits[axis]):
                            chunk_op._range_on_axis = (
                                offset,
                                offset + chunk.shape[axis] + 1,
                            )
                        else:
                            chunk_op._range_on_axis = (
                                offset,
                                offset + chunk.shape[axis],
                            )
                        out_chunks.append(
                            chunk_op.new_chunk(
                                [chunk, values.chunks[0]],
                                shape=shape,
                                index=chunk.index,
                            )
                        )
                    else:
                        chunk_op._index_obj = chunk_index_obj
                        values = np.asarray(values)
                        to_shape = [
                            calc_object_length(index_obj, chunk.shape[axis])
                        ] + [s for j, s in enumerate(inp.shape) if j != axis]
                        if all(j == k for j, k in zip(to_shape, values.shape)):
                            # one value (row) per index: hand each chunk only
                            # the values that belong to its own indices
                            chunk_values = np.asarray(values)[
                                chunk_idx_params[idx_on_axis][1]
                            ]
                            chunk_op._values = chunk_values
                            out_chunks.append(
                                chunk_op.new_chunk(
                                    [chunk], shape=shape, index=chunk.index
                                )
                            )
                        else:
                            out_chunks.append(
                                chunk_op.new_chunk(
                                    [chunk], shape=shape, index=chunk.index
                                )
                            )
                    nsplits_on_axis.append(shape[axis])
                else:
                    out_chunks.append(chunk)
                    nsplits_on_axis.append(chunk.shape[axis])
                # Advance the global start position for EVERY chunk, including
                # those without insertions; otherwise ``range_on_axis`` of a
                # later chunk would start too low.
                offset += chunk.shape[axis]

        nsplits = tuple(
            s if i != axis else tuple(nsplits_on_axis)
            for i, s in enumerate(inp.nsplits)
        )
        out = op.outputs[0]
        new_op = op.copy()
        return new_op.new_tensors(
            op.inputs,
            shape=out.shape,
            order=out.order,
            chunks=out_chunks,
            nsplits=nsplits,
        )

    @classmethod
    def execute(cls, ctx, op: "TensorInsert"):
        """Run the insert for one chunk with :func:`numpy.insert`.

        When ``op.range_on_axis`` is ``None`` the operand already carries
        chunk-local indices and is executed directly.  Otherwise the indices
        are global: only those falling inside ``[start, stop)`` are kept and
        shifted by ``start`` to become chunk-local.
        """
        inp = ctx[op.input.key]
        index_obj = (
            ctx[op.index_obj.key] if hasattr(op.index_obj, "key") else op.index_obj
        )
        values = ctx[op.values.key] if hasattr(op.values, "key") else op.values

        if op.range_on_axis is None:
            ctx[op.outputs[0].key] = np.insert(inp, index_obj, values, axis=op.axis)
            return

        range_start, range_stop = op.range_on_axis[0], op.range_on_axis[1]
        if isinstance(index_obj, slice):
            # expand the slice from its *start* (not its step)
            index_obj = np.arange(
                index_obj.start or 0, index_obj.stop, index_obj.step or 1
            )
        else:
            index_obj = np.array(index_obj)
        values = np.asarray(values)

        # positions (inside ``index_obj``) of the indices owned by this chunk
        part_index = [
            i
            for i, idx in enumerate(index_obj)
            if ((idx >= range_start) and idx < range_stop)
        ]
        if (
            (values.ndim > 0)
            and len(index_obj) == len(values)
            and (values[0].ndim > 0 or inp.ndim == 1)
        ):
            # one value (row) per index: pick the matching ones
            part_values = values[part_index]
        else:
            # values are broadcast by numpy.insert
            part_values = values
        ctx[op.outputs[0].key] = np.insert(
            inp,
            index_obj[part_index] - range_start,
            part_values,
            axis=op.axis,
        )

    def __call__(self, arr, obj, values, shape):
        """Create the output tensor of the given ``shape``.

        ``filter_inputs`` decides which of ``arr``, ``obj`` and ``values``
        are registered as operand inputs.
        """
        return self.new_tensor(
            filter_inputs([arr, obj, values]), shape=shape, order=arr.order
        )
def insert(arr, obj, values, axis=None):
    """
    Insert values along the given axis before the given indices.

    Parameters
    ----------
    arr : array like
        Input array.
    obj : int, slice or sequence of ints
        Object that defines the index or indices before which `values` is
        inserted.
    values : array_like
        Values to insert into `arr`. If the type of `values` is different
        from that of `arr`, `values` is converted to the type of `arr`.
        `values` should be shaped so that ``arr[...,obj,...] = values``
        is legal.
    axis : int, optional
        Axis along which to insert `values`.  If `axis` is None then `arr`
        is flattened first.

    Returns
    -------
    out : Tensor
        A copy of `arr` with `values` inserted.  Note that `insert`
        does not occur in-place: a new tensor is returned. If
        `axis` is None, `out` is a flattened tensor.

    Raises
    ------
    ValueError
        If `obj` has more than one dimension.

    See Also
    --------
    append : Append elements at the end of an array.
    concatenate : Join a sequence of arrays along an existing axis.
    delete : Delete elements from an array.

    Notes
    -----
    Note that for higher dimensional inserts `obj=0` behaves very different
    from `obj=[0]` just like `arr[:,0,:] = values` is different from
    `arr[:,[0],:] = values`.

    Examples
    --------
    >>> import mars.tensor as mt
    >>> a = mt.array([[1, 1], [2, 2], [3, 3]])
    >>> a.execute()
    array([[1, 1],
           [2, 2],
           [3, 3]])
    >>> mt.insert(a, 1, 5).execute()
    array([1, 5, 1, ..., 2, 3, 3])
    >>> mt.insert(a, 1, 5, axis=1).execute()
    array([[1, 5, 1],
           [2, 5, 2],
           [3, 5, 3]])

    Difference between sequence and scalars:

    >>> mt.insert(a, [1], [[1],[2],[3]], axis=1).execute()
    array([[1, 1, 1],
           [2, 2, 2],
           [3, 3, 3]])

    >>> b = a.flatten()
    >>> b.execute()
    array([1, 1, 2, 2, 3, 3])
    >>> mt.insert(b, [2, 2], [5, 6]).execute()
    array([1, 1, 5, ..., 2, 3, 3])

    >>> mt.insert(b, slice(2, 4), [5, 6]).execute()
    array([1, 1, 5, ..., 2, 3, 3])

    >>> mt.insert(b, [2, 2], [7.13, False]).execute() # type casting
    array([1, 1, 7, ..., 2, 3, 3])

    >>> x = mt.arange(8).reshape(2, 4)
    >>> idx = (1, 3)
    >>> mt.insert(x, idx, 999, axis=1).execute()
    array([[  0, 999,   1,   2, 999,   3],
           [  4, 999,   5,   6, 999,   7]])
    """
    arr = astensor(arr)

    if getattr(obj, "ndim", 0) > 1:  # pragma: no cover
        raise ValueError(
            "index array argument obj to insert must be one dimensional or scalar"
        )

    if axis is None:
        # if axis is None, array will be flatten
        arr_size = arr.size
        idx_length = calc_object_length(obj, size=arr_size)
        shape = (arr_size + idx_length,)
    else:
        validate_axis(arr.ndim, axis)
        idx_length = calc_object_length(obj, size=arr.shape[axis])
        # only the insertion axis grows, by the number of inserted indices
        shape = tuple(
            s + idx_length if i == axis else s for i, s in enumerate(arr.shape)
        )

    op = TensorInsert(index_obj=obj, values=values, axis=axis, dtype=arr.dtype)
    return op(arr, obj, values, shape)