# Source code for xorbits._mars.tensor.fft.rfft

# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from ... import opcodes as OperandDef
from ..datasource import tensor as astensor
from .core import TensorFFTMixin, TensorRealFFT, validate_fft


class TensorRFFT(TensorRealFFT, TensorFFTMixin):
    """Operand describing a one-dimensional FFT over real-valued input."""

    _op_type_ = OperandDef.RFFT

    def __init__(self, n=None, axis=-1, norm=None, **kw):
        # The base initializer receives the parameters under their
        # underscore-prefixed keyword names.
        super().__init__(_n=n, _axis=axis, _norm=norm, **kw)

    @classmethod
    def _get_shape(cls, op, shape):
        """Return the output shape for ``op`` applied to an input of ``shape``.

        The length along ``op.axis`` is taken from ``op.n`` when it is set,
        otherwise from the input shape; only ``length // 2 + 1`` entries are
        kept along that axis. All other dimensions are unchanged.
        """
        out_shape = list(shape)
        length = op.n
        if length is None:
            # No explicit transform length: fall back to the input extent.
            length = out_shape[op.axis]
        out_shape[op.axis] = length // 2 + 1
        return tuple(out_shape)


def rfft(a, n=None, axis=-1, norm=None):
    """
    Compute the one-dimensional discrete Fourier Transform for real input.

    This function computes the one-dimensional *n*-point discrete Fourier
    Transform (DFT) of a real-valued array by means of an efficient algorithm
    called the Fast Fourier Transform (FFT).

    Parameters
    ----------
    a : array_like
        Input tensor
    n : int, optional
        Number of points along transformation axis in the input to use.
        If `n` is smaller than the length of the input, the input is cropped.
        If it is larger, the input is padded with zeros. If `n` is not given,
        the length of the input along the axis specified by `axis` is used.
    axis : int, optional
        Axis over which to compute the FFT. If not given, the last axis is
        used.
    norm : {None, "ortho"}, optional
        Normalization mode (see `mt.fft`). Default is None.

    Returns
    -------
    out : complex Tensor
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        If `n` is even, the length of the transformed axis is ``(n/2)+1``.
        If `n` is odd, the length is ``(n+1)/2``.

    Raises
    ------
    IndexError
        If `axis` is larger than the last axis of `a`.

    See Also
    --------
    mt.fft : For definition of the DFT and conventions used.
    irfft : The inverse of `rfft`.
    fft : The one-dimensional FFT of general (complex) input.
    fftn : The *n*-dimensional FFT.
    rfftn : The *n*-dimensional FFT of real input.

    Notes
    -----
    When the DFT is computed for purely real input, the output is
    Hermitian-symmetric, i.e. the negative frequency terms are just the complex
    conjugates of the corresponding positive-frequency terms, and the
    negative-frequency terms are therefore redundant.  This function does not
    compute the negative frequency terms, and the length of the transformed
    axis of the output is therefore ``n//2 + 1``.

    When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains
    the zero-frequency term 0*fs, which is real due to Hermitian symmetry.

    If `n` is even, ``A[-1]`` contains the term representing both positive
    and negative Nyquist frequency (+fs/2 and -fs/2), and must also be purely
    real. If `n` is odd, there is no term at fs/2; ``A[-1]`` contains
    the largest positive frequency (fs/2*(n-1)/n), and is complex in the
    general case.

    If the input `a` contains an imaginary part, it is silently discarded.

    Examples
    --------
    >>> import mars.tensor as mt

    >>> mt.fft.fft([0, 1, 0, 0]).execute()
    array([ 1.+0.j,  0.-1.j, -1.+0.j,  0.+1.j])
    >>> mt.fft.rfft([0, 1, 0, 0]).execute()
    array([ 1.+0.j,  0.-1.j, -1.+0.j])

    Notice how the final element of the `fft` output is the complex conjugate
    of the second element, for real input. For `rfft`, this symmetry is
    exploited to compute only the non-negative frequency terms.

    """
    a = astensor(a)
    validate_fft(a, axis=axis, norm=norm)
    # ``np.complex_`` was only an alias of ``np.complex128`` and was removed in
    # NumPy 2.0; naming the concrete type works on every NumPy version.
    op = TensorRFFT(n=n, axis=axis, norm=norm, dtype=np.dtype(np.complex128))
    return op(a)