Source code for xorbits._mars.tensor.fft.irfft

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ... import opcodes as OperandDef
from ..datasource import tensor as astensor
from .core import TensorFFTMixin, TensorRealFFT, validate_fft


class TensorIRFFT(TensorRealFFT, TensorFFTMixin):
    """Operand registered under the ``IRFFT`` opcode (inverse real FFT)."""

    _op_type_ = OperandDef.IRFFT

    def __init__(self, n=None, axis=-1, norm=None, **kw):
        # The parent initializer receives these under underscore-prefixed
        # keyword names.
        super().__init__(_n=n, _axis=axis, _norm=norm, **kw)

    @classmethod
    def _get_shape(cls, op, shape):
        """Return the output shape for ``op`` given the input ``shape``.

        Only the entry at ``op.axis`` changes: it becomes ``op.n`` when that
        is set, otherwise ``2 * (m - 1)`` where ``m`` is the input length
        along that axis.
        """
        axis = op.axis
        dims = list(shape)
        out_len = op.n
        if out_len is None:
            # No explicit length requested: derive it from the input length.
            out_len = 2 * (dims[axis] - 1)
        dims[axis] = out_len
        return tuple(dims)


def irfft(a, n=None, axis=-1, norm=None):
    """
    Compute the inverse of the n-point DFT for real input.

    This function computes the inverse of the one-dimensional *n*-point
    discrete Fourier Transform of real input computed by `rfft`.
    In other words, ``irfft(rfft(a), len(a)) == a`` to within numerical
    accuracy. (See Notes below for why ``len(a)`` is necessary here.)

    The input is expected to be in the form returned by `rfft`, i.e. the
    real zero-frequency term followed by the complex positive frequency terms
    in order of increasing frequency.  Since the discrete Fourier Transform of
    real input is Hermitian-symmetric, the negative frequency terms are taken
    to be the complex conjugates of the corresponding positive frequency terms.

    Parameters
    ----------
    a : array_like
        The input tensor.
    n : int, optional
        Length of the transformed axis of the output.
        For `n` output points, ``n//2+1`` input points are necessary.  If the
        input is longer than this, it is cropped.  If it is shorter than this,
        it is padded with zeros.  If `n` is not given, it is determined from
        the length of the input along the axis specified by `axis`.
    axis : int, optional
        Axis over which to compute the inverse FFT. If not given, the last
        axis is used.
    norm : {None, "ortho"}, optional
        Normalization mode (see `mt.fft`). Default is None.

    Returns
    -------
    out : Tensor
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        The length of the transformed axis is `n`, or, if `n` is not given,
        ``2*(m-1)`` where ``m`` is the length of the transformed axis of the
        input. To get an odd number of output points, `n` must be specified.

    Raises
    ------
    IndexError
        If `axis` is larger than the last axis of `a`.

    See Also
    --------
    mt.fft : For definition of the DFT and conventions used.
    rfft : The one-dimensional FFT of real input, of which `irfft` is inverse.
    fft : The one-dimensional FFT.
    irfft2 : The inverse of the two-dimensional FFT of real input.
    irfftn : The inverse of the *n*-dimensional FFT of real input.

    Notes
    -----
    Returns the real valued `n`-point inverse discrete Fourier transform
    of `a`, where `a` contains the non-negative frequency terms of a
    Hermitian-symmetric sequence. `n` is the length of the result, not the
    input.

    If you specify an `n` such that `a` must be zero-padded or truncated, the
    extra/removed values will be added/removed at high frequencies. One can
    thus resample a series to `m` points via Fourier interpolation by:
    ``a_resamp = irfft(rfft(a), m)``.

    Examples
    --------
    >>> import mars.tensor as mt

    >>> mt.fft.ifft([1, -1j, -1, 1j]).execute()
    array([ 0.+0.j,  1.+0.j,  0.+0.j,  0.+0.j])
    >>> mt.fft.irfft([1, -1j, -1]).execute()
    array([ 0.,  1.,  0.,  0.])

    Notice how the last term in the input to the ordinary `ifft` is the
    complex conjugate of the second term, and the output has zero imaginary
    part everywhere.  When calling `irfft`, the negative frequencies are not
    specified, and the output array is purely real.

    """
    a = astensor(a)
    validate_fft(a, axis=axis, norm=norm)
    # np.float_ was removed in NumPy 2.0; np.float64 is the type it aliased,
    # so the declared output dtype is unchanged on older NumPy versions.
    op = TensorIRFFT(n=n, axis=axis, norm=norm, dtype=np.dtype(np.float64))
    return op(a)