# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from numpy.linalg import LinAlgError

from ... import opcodes as OperandDef
from ...core import recursive_tile
from ...serialization.serializables import BoolField, KeyField
from ...utils import has_unknown_shape, is_same_module
from ..array_utils import as_same_device, device
from ..core import TensorOrder
from ..datasource import tensor as astensor
from ..operands import TensorHasInput, TensorOperand, TensorOperandMixin


class TensorCholesky(TensorHasInput, TensorOperandMixin):
    _op_type_ = OperandDef.CHOLESKY

    _input = KeyField("input")
    _lower = BoolField("lower")

    def __init__(self, lower=None, **kw):
        super().__init__(_lower=lower, **kw)

    @property
    def lower(self):
        return self._lower

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        self._input = self._inputs[0]

    def __call__(self, a):
        return self.new_tensor([a], a.shape, order=TensorOrder.C_ORDER)

    @classmethod
    def tile(cls, op):
        from ..base import TensorTranspose
        from ..datasource.zeros import TensorZeros
        from ..utils import reverse_order
        from .dot import TensorDot
        from .solve_triangular import TensorSolveTriangular
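        # Blocked Cholesky over the chunk grid: writing L for the lower
        # factor and U = L.T for the upper one, each diagonal block is
        #   L[i, i] = cholesky(A[i, i] - sum(L[i, p] @ U[p, i] for p < i))
        # and each block above the diagonal (j < i) is the triangular solve
        #   U[j, i] = solve(L[j, j], A[j, i] - sum(L[j, p] @ U[p, i] for p < j))
        # with L[i, j] = U[j, i].T; all remaining blocks are zero.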

        tensor = op.outputs[0]
        in_tensor = op.input
        if has_unknown_shape(in_tensor):
            yield
        if in_tensor.nsplits[0] != in_tensor.nsplits[1]:
            # all chunks on diagonal should be square
            nsplits = in_tensor.nsplits[0]
            in_tensor = yield from recursive_tile(in_tensor.rechunk([nsplits, nsplits]))

        lower_chunks, upper_chunks = {}, {}
        for i in range(in_tensor.chunk_shape[0]):
            for j in range(in_tensor.chunk_shape[1]):
                if i < j:
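                    # the part of L above the diagonal (and of U below it)
                    # is filled with zero chunks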
                    lower_shape = (in_tensor.nsplits[0][i], in_tensor.nsplits[1][j])
                    lower_chunk = TensorZeros(
                        dtype=tensor.dtype, shape=lower_shape, order=tensor.order.value
                    ).new_chunk(
                        None,
                        shape=lower_shape,
                        index=(i, j),
                        order=tensor.order,
                    )
                    upper_shape = (in_tensor.nsplits[1][j], in_tensor.nsplits[0][i])
                    upper_chunk = TensorZeros(
                        dtype=tensor.dtype, shape=upper_shape, order=tensor.order.value
                    ).new_chunk(
                        None,
                        shape=upper_shape,
                        index=(j, i),
                        order=tensor.order,
                    )
                    lower_chunks[lower_chunk.index] = lower_chunk
                    upper_chunks[upper_chunk.index] = upper_chunk
                elif i == j:
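                    # diagonal block: subtract the contributions of the
                    # previous block columns, then factor the result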
                    target = in_tensor.cix[i, j]
                    if i > 0:
                        prev_chunks = []
                        for p in range(i):
                            a, b = lower_chunks[i, p], upper_chunks[p, j]
                            prev_chunk = TensorDot(dtype=tensor.dtype).new_chunk(
                                [a, b],
                                shape=(a.shape[0], b.shape[1]),
                                order=tensor.order,
                            )
                            prev_chunks.append(prev_chunk)

                        cholesky_fuse_op = TensorCholeskyFuse()
                        lower_chunk = cholesky_fuse_op.new_chunk(
                            [target] + prev_chunks,
                            shape=target.shape,
                            index=(i, j),
                            order=tensor.order,
                        )
                    else:
                        lower_chunk = TensorCholesky(
                            lower=True, dtype=tensor.dtype
                        ).new_chunk(
                            [target],
                            shape=target.shape,
                            index=(i, j),
                            order=tensor.order,
                        )

                    upper_chunk = TensorTranspose(dtype=lower_chunk.dtype).new_chunk(
                        [lower_chunk],
                        shape=lower_chunk.shape[::-1],
                        index=lower_chunk.index[::-1],
                        order=reverse_order(lower_chunk.order),
                    )
                    lower_chunks[lower_chunk.index] = lower_chunk
                    upper_chunks[upper_chunk.index] = upper_chunk
                else:
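                    # block above the diagonal (j < i): update A[j, i] with
                    # the previous block columns, then solve against L[j, j]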
                    target = in_tensor.cix[j, i]
                    if j > 0:
                        prev_chunks = []
                        for p in range(j):
                            a, b = lower_chunks[j, p], upper_chunks[p, i]
                            prev_chunk = TensorDot(dtype=tensor.dtype).new_chunk(
                                [a, b],
                                shape=(a.shape[0], b.shape[1]),
                                order=tensor.order,
                            )
                            prev_chunks.append(prev_chunk)
                        cholesky_fuse_op = TensorCholeskyFuse(by_solve_triangular=True)
                        upper_chunk = cholesky_fuse_op.new_chunk(
                            [target] + [lower_chunks[j, j]] + prev_chunks,
                            shape=target.shape,
                            index=(j, i),
                            order=tensor.order,
                        )
                    else:
                        upper_chunk = TensorSolveTriangular(
                            lower=True, dtype=tensor.dtype
                        ).new_chunk(
                            [lower_chunks[j, j], target],
                            shape=target.shape,
                            index=(j, i),
                            order=tensor.order,
                        )
                    lower_chunk = TensorTranspose(dtype=upper_chunk.dtype).new_chunk(
                        [upper_chunk],
                        shape=upper_chunk.shape[::-1],
                        index=upper_chunk.index[::-1],
                        order=reverse_order(upper_chunk.order),
                    )
                    lower_chunks[lower_chunk.index] = lower_chunk
                    upper_chunks[upper_chunk.index] = upper_chunk

        new_op = op.copy()
        if op.lower:
            return new_op.new_tensors(
                op.inputs,
                tensor.shape,
                order=tensor.order,
                chunks=list(lower_chunks.values()),
                nsplits=in_tensor.nsplits,
            )
        else:
            return new_op.new_tensors(
                op.inputs,
                tensor.shape,
                order=tensor.order,
                chunks=list(upper_chunks.values()),
                nsplits=in_tensor.nsplits,
            )

    @classmethod
    def execute(cls, ctx, op):
        chunk = op.outputs[0]
        (a,), device_id, xp = as_same_device(
            [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
        )

        with device(device_id):
            if is_same_module(xp, np):
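                # on CPU, scipy.linalg.cholesky computes either factor
                # directly via the ``lower`` argument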
                import scipy.linalg

                ctx[chunk.key] = scipy.linalg.cholesky(a, lower=op.lower)
                return

            # GPU array modules only return the lower factor; conjugate
            # transpose to obtain the upper factor when requested
            r = xp.linalg.cholesky(a)
            if not chunk.op.lower:
                r = r.T.conj()

            ctx[chunk.key] = r


class TensorCholeskyFuse(TensorOperand, TensorOperandMixin):
    _op_type_ = OperandDef.CHOLESKY_FUSE

    _by_solve_triangular = BoolField("by_solve_triangular")

    def __init__(self, by_solve_triangular=None, **kw):
        super().__init__(_by_solve_triangular=by_solve_triangular, **kw)

    @property
    def by_solve_triangular(self):
        return self._by_solve_triangular

    @classmethod
    def _execute_by_cholesky(cls, inputs):
        import scipy.linalg

        target = inputs[0]
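        # factor the Schur complement: the diagonal block minus the
        # accumulated products of the previous block columns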
        return scipy.linalg.cholesky((target - sum(inputs[1:])), lower=True)

    @classmethod
    def _execute_by_solve_triangular(cls, inputs):
        import scipy.linalg

        target = inputs[0]
        lower = inputs[1]
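        # triangular solve for one off-diagonal block: inputs[2:] hold the
        # already-computed products of the previous block columns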
        return scipy.linalg.solve_triangular(
            lower, (target - sum(inputs[2:])), lower=True
        )

    @classmethod
    def execute(cls, ctx, op):
        inputs = [ctx[c.key] for c in op.inputs]
        if op.by_solve_triangular:
            ret = cls._execute_by_solve_triangular(inputs)
        else:
            ret = cls._execute_by_cholesky(inputs)
        ctx[op.outputs[0].key] = ret


def cholesky(a, lower=False):
    """
    Cholesky decomposition.

    Return the Cholesky decomposition, `L * L.H`, of the square matrix `a`,
    where `L` is lower-triangular and .H is the conjugate transpose operator
    (which is the ordinary transpose if `a` is real-valued). `a` must be
    Hermitian (symmetric if real-valued) and positive-definite. Only `L` is
    actually returned.

    Parameters
    ----------
    a : (..., M, M) array_like
        Hermitian (symmetric if all elements are real), positive-definite
        input matrix.
    lower : bool
        Whether to compute the upper- or lower-triangular Cholesky
        factorization. Default is upper-triangular.

    Returns
    -------
    L : (..., M, M) array_like
        Upper- or lower-triangular Cholesky factor of `a`.

    Raises
    ------
    LinAlgError
        If the decomposition fails, for example, if `a` is not
        positive-definite.

    Notes
    -----
    Broadcasting rules apply, see the `mt.linalg` documentation for details.

    The Cholesky decomposition is often used as a fast way of solving

    .. math:: A \\mathbf{x} = \\mathbf{b}

    (when `A` is both Hermitian/symmetric and positive-definite).

    First, we solve for :math:`\\mathbf{y}` in

    .. math:: L \\mathbf{y} = \\mathbf{b},

    and then for :math:`\\mathbf{x}` in

    .. math:: L.H \\mathbf{x} = \\mathbf{y}.

    Examples
    --------
    >>> import mars.tensor as mt

    >>> A = mt.array([[1,-2j],[2j,5]])
    >>> A.execute()
    array([[ 1.+0.j,  0.-2.j],
           [ 0.+2.j,  5.+0.j]])
    >>> L = mt.linalg.cholesky(A, lower=True)
    >>> L.execute()
    array([[ 1.+0.j,  0.+0.j],
           [ 0.+2.j,  1.+0.j]])
    >>> mt.dot(L, L.T.conj()).execute()  # verify that L * L.H = A
    array([[ 1.+0.j,  0.-2.j],
           [ 0.+2.j,  5.+0.j]])
    >>> A = [[1,-2j],[2j,5]]  # what happens if A is only array_like?
    >>> mt.linalg.cholesky(A, lower=True).execute()
    array([[ 1.+0.j,  0.+0.j],
           [ 0.+2.j,  1.+0.j]])
    """
    a = astensor(a)
    if a.ndim != 2:  # pragma: no cover
        raise LinAlgError(
            f"{a.ndim}-dimensional array given. Tensor must be two-dimensional"
        )
    if a.shape[0] != a.shape[1]:  # pragma: no cover
        raise LinAlgError("Input must be square")

    # run numpy's cholesky on a tiny positive-definite sample only to
    # infer the output dtype
    cho = np.linalg.cholesky(np.array([[1, 2], [2, 5]], dtype=a.dtype))

    op = TensorCholesky(lower=lower, dtype=cho.dtype)
    return op(a)
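

# A minimal NumPy/SciPy sketch (illustration only, not part of the library)
# of the blocked recurrence implemented in ``TensorCholesky.tile`` above:
# factor the first diagonal block, triangular-solve for the off-diagonal
# block, then factor the Schur complement of the trailing diagonal block.
if __name__ == "__main__":  # pragma: no cover
    import scipy.linalg

    rng = np.random.default_rng(0)
    m = rng.standard_normal((4, 4))
    a_mat = m @ m.T + 4 * np.eye(4)  # symmetric positive-definite input
    a11, a12, a22 = a_mat[:2, :2], a_mat[:2, 2:], a_mat[2:, 2:]

    l11 = scipy.linalg.cholesky(a11, lower=True)
    # u12 solves l11 @ u12 = a12, i.e. u12 = inv(l11) @ a12
    u12 = scipy.linalg.solve_triangular(l11, a12, lower=True)
    l22 = scipy.linalg.cholesky(a22 - u12.T @ u12, lower=True)

    lower = np.block([[l11, np.zeros((2, 2))], [u12.T, l22]])
    assert np.allclose(lower, np.linalg.cholesky(a_mat))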