# Source code for xorbits._mars.tensor.linalg.tensordot

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from collections.abc import Iterable

import numpy as np

from ... import opcodes as OperandDef
from ...serialization.serializables import FieldTypes, KeyField, TupleField
from ...utils import has_unknown_shape
from ..arithmetic.utils import chunk_tree_add
from ..array_utils import as_same_device, device, is_sparse_module
from ..core import TensorOrder
from ..datasource import tensor as astensor
from ..operands import TensorOperand, TensorOperandMixin
from ..utils import unify_chunks


class TensorTensorDot(TensorOperand, TensorOperandMixin):
    """Operand that contracts two tensors over paired axes (``tensordot``).

    Axis ``a_axes[k]`` of the first input is summed against axis
    ``b_axes[k]`` of the second input. The output keeps the remaining axes
    of ``a`` followed by the remaining axes of ``b``.
    """

    _op_type_ = OperandDef.TENSORDOT

    # The two inputs; re-bound from ``inputs`` in ``_set_inputs``.
    _a = KeyField("a")
    _b = KeyField("b")
    # Axes to sum over; entry k of ``_a_axes`` pairs with entry k of ``_b_axes``.
    _a_axes = TupleField("a_axes", FieldTypes.int32)
    _b_axes = TupleField("b_axes", FieldTypes.int32)

    def __init__(self, a_axes=None, b_axes=None, **kw):
        super().__init__(_a_axes=a_axes, _b_axes=b_axes, **kw)

    @property
    def a(self):
        """First input of the contraction."""
        return self._a

    @property
    def b(self):
        """Second input of the contraction."""
        return self._b

    @property
    def a_axes(self):
        """Axes of ``a`` that are summed over."""
        return self._a_axes

    @property
    def b_axes(self):
        """Axes of ``b`` that are summed over."""
        return self._b_axes

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        self._a = self._inputs[0]
        self._b = self._inputs[1]

    def __call__(self, a, b):
        """Create the output tensor for inputs ``a`` and ``b``.

        The output shape is the non-contracted dimensions of ``a`` followed
        by the non-contracted dimensions of ``b``; the output is C-ordered.
        """
        # Build the lookup sets once instead of once per dimension.
        a_contracted = set(self._a_axes)
        b_contracted = set(self._b_axes)
        shape = tuple(
            s for i, s in enumerate(a.shape) if i not in a_contracted
        ) + tuple(s for i, s in enumerate(b.shape) if i not in b_contracted)
        return self.new_tensor([a, b], shape, order=TensorOrder.C_ORDER)

    @classmethod
    def estimate_size(cls, ctx, op):
        """Record ``(result size, calculation size)`` for the output chunk in ``ctx``.

        Sparse outputs are delegated to the parent implementation.
        """
        chunk = op.outputs[0]
        if chunk.is_sparse():
            return super().estimate_size(ctx, op)

        # empirical value in real environments
        calc_usage = chunk.nbytes

        # add input sizes when sparse-to-dense is needed
        for inp in chunk.inputs:
            if inp.is_sparse():
                calc_usage += inp.nbytes

        ctx[chunk.key] = (chunk.nbytes, calc_usage)

    @classmethod
    def tile(cls, op):
        """Split the contraction into chunk-level operands.

        This is a generator: it issues a bare ``yield`` when any input has an
        unknown shape (presumably so the tiling framework can resolve shapes
        first -- confirm against the framework) and delegates to
        ``unify_chunks`` via ``yield from``. Its return value is the list of
        tiled tensors produced by ``new_tensors``.

        For every output chunk index, one partial ``tensordot`` chunk is
        created per combination of chunk indexes along the contracted axes;
        several partial chunks are summed with ``chunk_tree_add``.
        """
        a, b, a_axes, b_axes = op.a, op.b, op.a_axes, op.b_axes

        # Give every axis a label: a contracted axis gets its position in
        # ``a_axes``/``b_axes`` (so paired axes share a label), every other
        # axis gets a fresh label starting at max(a.ndim, b.ndim), which
        # cannot collide with the positional labels.
        axis_counter = itertools.count(max(a.ndim, b.ndim))
        a_ax = tuple(
            a_axes.index(i) if i in a_axes else next(axis_counter)
            for i in range(a.ndim)
        )
        b_ax = tuple(
            b_axes.index(i) if i in b_axes else next(axis_counter)
            for i in range(b.ndim)
        )
        if has_unknown_shape(*op.inputs):
            yield
        # NOTE(review): presumably aligns the chunk splits of axes that share
        # a label -- confirm against ``unify_chunks``.
        a, b = yield from unify_chunks((a, a_ax), (b, b_ax))
        out = op.outputs[0]

        a_output_indexes = [
            range(len(a.nsplits[i])) for i in range(a.ndim) if i not in a_axes
        ]
        b_output_indexes = [
            range(len(b.nsplits[i])) for i in range(b.ndim) if i not in b_axes
        ]
        # (input number, axis) for each output axis: 0 -> ``a``, 1 -> ``b``.
        output_axes = [(0, i) for i in range(a.ndim) if i not in a_axes] + [
            (1, i) for i in range(b.ndim) if i not in b_axes
        ]
        # Chunk-index ranges along the contracted axes; identical for every
        # output chunk, so computed once. ``b``'s contracted axes are indexed
        # with the same values as ``a``'s.
        contract_ranges = [range(len(a.nsplits[ax])) for ax in a_axes]

        out_chunks = []
        for out_idx in itertools.product(
            *itertools.chain(a_output_indexes, b_output_indexes)
        ):
            # Chunk indexes into ``a``/``b``; contracted positions stay None
            # until filled in per contraction step below.
            a_indexes = [None] * a.ndim
            b_indexes = [None] * b.ndim
            tensor_shape = []
            for i, idx in enumerate(out_idx):
                t_idx, axis = output_axes[i]
                t = (a, b)[t_idx]
                (a_indexes if t_idx == 0 else b_indexes)[axis] = idx
                tensor_shape.append(t.nsplits[axis][idx])
            tensor_shape = tuple(tensor_shape)

            tensordot_chunks = []
            for contract_indexes in itertools.product(*contract_ranges):
                a_indices, b_indices = list(a_indexes), list(b_indexes)
                for a_axis, contract_index in zip(a_axes, contract_indexes):
                    a_indices[a_axis] = contract_index
                for b_axis, contract_index in zip(b_axes, contract_indexes):
                    b_indices[b_axis] = contract_index

                tensordot_chunk_op = op.copy().reset_key()
                tensordot_chunk = tensordot_chunk_op.new_chunk(
                    [a.cix[tuple(a_indices)], b.cix[tuple(b_indices)]],
                    shape=tensor_shape,
                    order=out.order,
                )
                tensordot_chunks.append(tensordot_chunk)

            if len(tensordot_chunks) == 1:
                # Nothing to sum: re-create the single partial chunk with the
                # output index attached.
                only_chunk = tensordot_chunks[0]
                chunk_op = only_chunk.op.copy()
                chunk = chunk_op.new_chunk(
                    only_chunk.inputs,
                    shape=only_chunk.shape,
                    index=out_idx,
                    order=out.order,
                )
            else:
                chunk = chunk_tree_add(
                    op.dtype, tensordot_chunks, out_idx, tensor_shape, sparse=op.sparse
                )
            out_chunks.append(chunk)

        nsplits = [(a, b)[t_idx].nsplits[i] for t_idx, i in output_axes]
        new_op = op.copy()
        return new_op.new_tensors([a, b], out.shape, chunks=out_chunks, nsplits=nsplits)

    @classmethod
    def execute(cls, ctx, op):
        """Compute one chunk: read both inputs from ``ctx`` and store the result under the output key."""
        (a, b), device_id, xp = as_same_device(
            [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
        )

        axes = op.a_axes, op.b_axes
        with device(device_id):
            if not op.sparse and is_sparse_module(xp):
                # tell sparse to do calculation on numpy or cupy dot
                ctx[op.outputs[0].key] = xp.tensordot(a, b, axes, sparse=False)
            else:
                ret = xp.tensordot(a, b, axes)
                out = op.outputs[0]
                # Same dtype; only enforces the output's memory order,
                # without copying when no conversion is needed.
                ctx[out.key] = ret.astype(ret.dtype, order=out.order.value, copy=False)


def tensordot(a, b, axes=2, sparse=None):
    """
    Compute tensor dot product along specified axes for tensors >= 1-D.

    Given two tensors (arrays of dimension greater than or equal to one),
    `a` and `b`, and an array_like object containing two array_like
    objects, ``(a_axes, b_axes)``, sum the products of `a`'s and `b`'s
    elements (components) over the axes specified by ``a_axes`` and
    ``b_axes``. The third argument can be a single non-negative
    integer_like scalar, ``N``; if it is such, then the last ``N``
    dimensions of `a` and the first ``N`` dimensions of `b` are summed
    over.

    Parameters
    ----------
    a, b : array_like, len(shape) >= 1
        Tensors to "dot".

    axes : int or (2,) array_like
        * integer_like
          If an int N, sum over the last N axes of `a` and the first N axes
          of `b` in order. The sizes of the corresponding axes must match.
        * (2,) array_like
          Or, a list of axes to be summed over, first sequence applying to `a`,
          second to `b`. Both elements array_like must be of the same length.

    sparse : bool, optional
        Passed to the operand as its ``sparse`` flag. When None (default),
        it is True only if both `a` and `b` are sparse.

    Returns
    -------
    output : Tensor
        The tensor dot product of the inputs. Its shape is the
        non-contracted dimensions of `a` followed by those of `b`.

    Raises
    ------
    ValueError
        If both inputs have a non-empty shape and the sizes of the paired
        axes differ ("shape-mismatch for sum").

    See Also
    --------
    dot, einsum

    Notes
    -----
    Three common use cases are:
        * ``axes = 0`` : tensor product :math:`a\\otimes b`
        * ``axes = 1`` : tensor dot product :math:`a\\cdot b`
        * ``axes = 2`` : (default) tensor double contraction :math:`a:b`

    When `axes` is integer_like, the sequence for evaluation will be: first
    the -Nth axis in `a` and 0th axis in `b`, and the -1th axis in `a` and
    Nth axis in `b` last.

    When there is more than one axis to sum over - and they are not the last
    (first) axes of `a` (`b`) - the argument `axes` should consist of
    two sequences of the same length, with the first axis to sum over given
    first in both sequences, the second axis second, and so forth.

    Examples
    --------
    >>> import mars.tensor as mt

    A "traditional" example:

    >>> a = mt.arange(60.).reshape(3,4,5)
    >>> b = mt.arange(24.).reshape(4,3,2)
    >>> c = mt.tensordot(a,b, axes=([1,0],[0,1]))
    >>> c.shape
    (5, 2)
    >>> r = c.execute()
    >>> r
    array([[ 4400.,  4730.],
           [ 4532.,  4874.],
           [ 4664.,  5018.],
           [ 4796.,  5162.],
           [ 4928.,  5306.]])
    >>> # A slower but equivalent way of computing the same...
    >>> ra = np.arange(60.).reshape(3,4,5)
    >>> rb = np.arange(24.).reshape(4,3,2)
    >>> d = np.zeros((5,2))
    >>> for i in range(5):
    ...   for j in range(2):
    ...     for k in range(3):
    ...       for n in range(4):
    ...         d[i,j] += ra[k,n,i] * rb[n,k,j]
    >>> r == d
    array([[ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True],
           [ True,  True]], dtype=bool)

    An extended example taking advantage of the overloading of + and \\*:

    >>> a = mt.array(range(1, 9))
    >>> a.shape = (2, 2, 2)
    >>> A = mt.array(('a', 'b', 'c', 'd'), dtype=object)
    >>> A.shape = (2, 2)
    >>> a.execute(); A.execute()
    array([[[1, 2],
            [3, 4]],
           [[5, 6],
            [7, 8]]])
    array([[a, b],
           [c, d]], dtype=object)

    >>> mt.tensordot(a, A).execute() # third argument default is 2 for double-contraction
    array([abbcccdddd, aaaaabbbbbbcccccccdddddddd], dtype=object)

    >>> mt.tensordot(a, A, 1).execute()
    array([[[acc, bdd],
            [aaacccc, bbbdddd]],
           [[aaaaacccccc, bbbbbdddddd],
            [aaaaaaacccccccc, bbbbbbbdddddddd]]], dtype=object)

    >>> mt.tensordot(a, A, 0).execute() # tensor product (result too long to incl.)
    array([[[[[a, b],
              [c, d]],
              ...

    >>> mt.tensordot(a, A, (0, 1)).execute()
    array([[[abbbbb, cddddd],
            [aabbbbbb, ccdddddd]],
           [[aaabbbbbbb, cccddddddd],
            [aaaabbbbbbbb, ccccdddddddd]]], dtype=object)

    >>> mt.tensordot(a, A, (2, 1)).execute()
    array([[[abb, cdd],
            [aaabbbb, cccdddd]],
           [[aaaaabbbbbb, cccccdddddd],
            [aaaaaaabbbbbbbb, cccccccdddddddd]]], dtype=object)

    >>> mt.tensordot(a, A, ((0, 1), (0, 1))).execute()
    array([abbbcccccddddddd, aabbbbccccccdddddddd], dtype=object)

    >>> mt.tensordot(a, A, ((2, 1), (1, 0))).execute()
    array([acccbbdddd, aaaaacccccccbbbbbbdddddddd], dtype=object)
    """
    a = astensor(a)
    b = astensor(b)

    if isinstance(axes, Iterable):
        a_axes, b_axes = axes
    else:
        # Integer N: pair the last N axes of ``a`` with the first N axes of
        # ``b`` *in order* (a's axis -N with b's axis 0, ..., a's axis -1
        # with b's axis N-1), as documented above and as numpy.tensordot
        # does. Listing a's axes in reverse would contract the wrong pairs
        # for N >= 2.
        a_axes = tuple(range(a.ndim - axes, a.ndim))
        b_axes = tuple(range(0, axes))

    # Accept a bare int or any iterable for each side, then map negative
    # axes to their non-negative equivalents.
    a_axes = tuple(a_axes) if isinstance(a_axes, Iterable) else (a_axes,)
    a_axes = tuple(axis if axis >= 0 else a.ndim + axis for axis in a_axes)
    b_axes = tuple(b_axes) if isinstance(b_axes, Iterable) else (b_axes,)
    b_axes = tuple(axis if axis >= 0 else b.ndim + axis for axis in b_axes)

    # The sizes of paired axes must agree; skipped when either shape is empty.
    if (
        a.shape
        and b.shape
        and not np.array_equal(
            np.array(a.shape)[list(a_axes)], np.array(b.shape)[list(b_axes)]
        )
    ):
        raise ValueError("shape-mismatch for sum")

    # Unless requested explicitly, the result is sparse only when both
    # inputs are sparse.
    sparse = sparse if sparse is not None else (a.issparse() and b.issparse())
    op = TensorTensorDot(
        a_axes=a_axes,
        b_axes=b_axes,
        dtype=np.promote_types(a.dtype, b.dtype),
        sparse=sparse,
    )
    return op(a, b)