# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ... import opcodes as OperandDef
from ..datasource import tensor as astensor
from .core import TensorFFTMixin, TensorRealFFT, validate_fft
class TensorRFFT(TensorRealFFT, TensorFFTMixin):
_op_type_ = OperandDef.RFFT
def __init__(self, n=None, axis=-1, norm=None, **kw):
super().__init__(_n=n, _axis=axis, _norm=norm, **kw)
@classmethod
def _get_shape(cls, op, shape):
new_shape = list(shape)
if op.n is not None:
new_shape[op.axis] = op.n
new_shape[op.axis] = new_shape[op.axis] // 2 + 1
return tuple(new_shape)
[docs]def rfft(a, n=None, axis=-1, norm=None):
"""
Compute the one-dimensional discrete Fourier Transform for real input.
This function computes the one-dimensional *n*-point discrete Fourier
Transform (DFT) of a real-valued array by means of an efficient algorithm
called the Fast Fourier Transform (FFT).
Parameters
----------
a : array_like
Input tensor
n : int, optional
Number of points along transformation axis in the input to use.
If `n` is smaller than the length of the input, the input is cropped.
If it is larger, the input is padded with zeros. If `n` is not given,
the length of the input along the axis specified by `axis` is used.
axis : int, optional
Axis over which to compute the FFT. If not given, the last axis is
used.
norm : {None, "ortho"}, optional
Normalization mode (see `mt.fft`). Default is None.
Returns
-------
out : complex Tensor
The truncated or zero-padded input, transformed along the axis
indicated by `axis`, or the last one if `axis` is not specified.
If `n` is even, the length of the transformed axis is ``(n/2)+1``.
If `n` is odd, the length is ``(n+1)/2``.
Raises
------
IndexError
If `axis` is larger than the last axis of `a`.
See Also
--------
mt.fft : For definition of the DFT and conventions used.
irfft : The inverse of `rfft`.
fft : The one-dimensional FFT of general (complex) input.
fftn : The *n*-dimensional FFT.
rfftn : The *n*-dimensional FFT of real input.
Notes
-----
When the DFT is computed for purely real input, the output is
Hermitian-symmetric, i.e. the negative frequency terms are just the complex
conjugates of the corresponding positive-frequency terms, and the
negative-frequency terms are therefore redundant. This function does not
compute the negative frequency terms, and the length of the transformed
axis of the output is therefore ``n//2 + 1``.
When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains
the zero-frequency term 0*fs, which is real due to Hermitian symmetry.
If `n` is even, ``A[-1]`` contains the term representing both positive
and negative Nyquist frequency (+fs/2 and -fs/2), and must also be purely
real. If `n` is odd, there is no term at fs/2; ``A[-1]`` contains
the largest positive frequency (fs/2*(n-1)/n), and is complex in the
general case.
If the input `a` contains an imaginary part, it is silently discarded.
Examples
--------
>>> import mars.tensor as mt
>>> mt.fft.fft([0, 1, 0, 0]).execute()
array([ 1.+0.j, 0.-1.j, -1.+0.j, 0.+1.j])
>>> mt.fft.rfft([0, 1, 0, 0]).execute()
array([ 1.+0.j, 0.-1.j, -1.+0.j])
Notice how the final element of the `fft` output is the complex conjugate
of the second element, for real input. For `rfft`, this symmetry is
exploited to compute only the non-negative frequency terms.
"""
a = astensor(a)
validate_fft(a, axis=axis, norm=norm)
op = TensorRFFT(n=n, axis=axis, norm=norm, dtype=np.dtype(np.complex_))
return op(a)