# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ..core import Tensor
from ..datasource import tensor as astensor
from ..utils import validate_axis
def compress(condition, a, axis=None, out=None):
    """
    Return selected slices of a tensor along given axis.

    When working along a given axis, a slice along that axis is returned in
    `output` for each index where `condition` evaluates to True. When
    working on a 1-D array, `compress` is equivalent to `extract`.

    Parameters
    ----------
    condition : 1-D tensor of bools
        Tensor that selects which entries to return. If len(condition)
        is less than the size of `a` along the given axis, then output is
        truncated to the length of the condition tensor.
    a : array_like
        Tensor from which to extract a part.
    axis : int, optional
        Axis along which to take slices. If None (default), work on the
        flattened tensor.
    out : Tensor, optional
        Output tensor. The result is cast to its dtype (honouring its
        memory order) and its data is replaced by the result. Its shape is
        not checked, because the result shape may be unknown. `out` is
        used both when `axis` is given and when working on the flattened
        tensor.

    Returns
    -------
    compressed_array : Tensor
        A copy of `a` without the slices along axis for which `condition`
        is false. If `out` is given, `out` itself is returned.

    Raises
    ------
    ValueError
        If `condition` is not 1-D.
    numpy AxisError
        If `axis` is out of bounds for `a`, or if indexing along `axis`
        with `condition` raises ``IndexError``.
    TypeError
        If `out` is not a Tensor, or if
        ``np.can_cast(out.dtype, result.dtype, "safe")`` is False.

    See Also
    --------
    take, choose, diag, diagonal, select
    Tensor.compress : Equivalent method in Tensor
    mt.extract: Equivalent method when working on 1-D arrays

    Examples
    --------
    >>> import mars.tensor as mt

    >>> a = mt.array([[1, 2], [3, 4], [5, 6]])
    >>> a.execute()
    array([[1, 2],
           [3, 4],
           [5, 6]])
    >>> mt.compress([0, 1], a, axis=0).execute()
    array([[3, 4]])
    >>> mt.compress([False, True, True], a, axis=0).execute()
    array([[3, 4],
           [5, 6]])
    >>> mt.compress([False, True], a, axis=1).execute()
    array([[2],
           [4],
           [6]])

    Working on the flattened tensor does not return slices along an axis but
    selects elements.

    >>> mt.compress([False, True], a).execute()
    array([2])
    """
    # ``AxisError`` lives in ``np.exceptions`` since NumPy 1.25 and was
    # removed from the top-level namespace in NumPy 2.0; older releases
    # only expose ``np.AxisError``. Both names refer to the same class.
    axis_error = getattr(np, "exceptions", np).AxisError

    a = astensor(a)
    condition = astensor(condition, dtype=bool)
    if condition.ndim != 1:
        raise ValueError("condition must be an 1-d tensor")

    if axis is None:
        # Work on the flattened tensor, truncated to the condition length.
        a = a.ravel()
        if len(condition) < a.size:
            a = a[: len(condition)]
        t = a[condition]
    else:
        try:
            axis = validate_axis(a.ndim, axis)
        except ValueError as err:
            raise axis_error(
                f"axis {axis} is out of bounds for tensor of dimension {a.ndim}"
            ) from err

        # Full slices for every dimension before ``axis`` so the condition
        # (or truncating slice) is applied to ``axis`` only.
        leading = (slice(None),) * axis
        try:
            if len(condition) < a.shape[axis]:
                a = a[leading + (slice(len(condition)),)]
            t = a[leading + (condition,)]
        except IndexError as err:
            raise axis_error(
                f"axis {len(condition)} is out of bounds for tensor of dimension 1"
            ) from err

    if out is None:
        return t

    if not isinstance(out, Tensor):
        raise TypeError(f"out should be Tensor object, got {type(out)} instead")
    if not np.can_cast(out.dtype, t.dtype, "safe"):
        raise TypeError(
            f"Cannot cast array data from dtype('{out.dtype}') to dtype('{t.dtype}') "
            "according to the rule 'safe'"
        )
    # skip shape check because out shape is unknown
    out.data = t.astype(out.dtype, order=out.order.value).data
    return out