# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ... import opcodes as OperandDef
from ..datasource import tensor as astensor
from .core import TensorFFTMixin, TensorRealFFT, validate_fft
class TensorIRFFT(TensorRealFFT, TensorFFTMixin):
    """Operand for the one-dimensional inverse FFT of real input (``irfft``).

    Parameters ``n``, ``axis`` and ``norm`` are forwarded to the base class
    as the private ``_n``, ``_axis`` and ``_norm`` fields.
    """

    _op_type_ = OperandDef.IRFFT

    def __init__(self, n=None, axis=-1, norm=None, **kw):
        super().__init__(_n=n, _axis=axis, _norm=norm, **kw)

    @classmethod
    def _get_shape(cls, op, shape):
        """Return the output shape of ``op`` applied to an input of ``shape``.

        Only the transformed axis (``op.axis``) changes. Its length becomes
        ``op.n`` when that is given. Otherwise it becomes ``2 * (m - 1)``,
        where ``m`` is the input length along that axis.
        """
        if op.n is not None:
            axis_len = op.n
        else:
            # With no explicit n, an even-length output is assumed:
            # m half-spectrum points correspond to 2*(m-1) real samples.
            axis_len = 2 * (shape[op.axis] - 1)
        result = list(shape)
        result[op.axis] = axis_len
        return tuple(result)
def irfft(a, n=None, axis=-1, norm=None):
    """
    Compute the inverse of the n-point DFT for real input.

    This function computes the inverse of the one-dimensional *n*-point
    discrete Fourier Transform of real input computed by `rfft`.
    In other words, ``irfft(rfft(a), len(a)) == a`` to within numerical
    accuracy. (See Notes below for why ``len(a)`` is necessary here.)

    The input is expected to be in the form returned by `rfft`, i.e. the
    real zero-frequency term followed by the complex positive frequency terms
    in order of increasing frequency. Since the discrete Fourier Transform of
    real input is Hermitian-symmetric, the negative frequency terms are taken
    to be the complex conjugates of the corresponding positive frequency terms.

    Parameters
    ----------
    a : array_like
        The input tensor.
    n : int, optional
        Length of the transformed axis of the output.
        For `n` output points, ``n//2+1`` input points are necessary. If the
        input is longer than this, it is cropped. If it is shorter than this,
        it is padded with zeros. If `n` is not given, it is determined from
        the length of the input along the axis specified by `axis`.
    axis : int, optional
        Axis over which to compute the inverse FFT. If not given, the last
        axis is used.
    norm : {None, "ortho"}, optional
        Normalization mode (see `mt.fft`). Default is None.

    Returns
    -------
    out : Tensor
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        The length of the transformed axis is `n`, or, if `n` is not given,
        ``2*(m-1)`` where ``m`` is the length of the transformed axis of the
        input. To get an odd number of output points, `n` must be specified.

    Raises
    ------
    IndexError
        If `axis` is larger than the last axis of `a`.

    See Also
    --------
    mt.fft : For definition of the DFT and conventions used.
    rfft : The one-dimensional FFT of real input, of which `irfft` is inverse.
    fft : The one-dimensional FFT.
    irfft2 : The inverse of the two-dimensional FFT of real input.
    irfftn : The inverse of the *n*-dimensional FFT of real input.

    Notes
    -----
    Returns the real valued `n`-point inverse discrete Fourier transform
    of `a`, where `a` contains the non-negative frequency terms of a
    Hermitian-symmetric sequence. `n` is the length of the result, not the
    input.

    If you specify an `n` such that `a` must be zero-padded or truncated, the
    extra/removed values will be added/removed at high frequencies. One can
    thus resample a series to `m` points via Fourier interpolation by:
    ``a_resamp = irfft(rfft(a), m)``.

    Examples
    --------
    >>> import mars.tensor as mt

    >>> mt.fft.ifft([1, -1j, -1, 1j]).execute()
    array([ 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j])
    >>> mt.fft.irfft([1, -1j, -1]).execute()
    array([ 0., 1., 0., 0.])

    Notice how the last term in the input to the ordinary `ifft` is the
    complex conjugate of the second term, and the output has zero imaginary
    part everywhere. When calling `irfft`, the negative frequencies are not
    specified, and the output array is purely real.
    """
    a = astensor(a)
    validate_fft(a, axis=axis, norm=norm)
    # np.float_ was removed in NumPy 2.0; np.float64 is the same dtype on
    # NumPy 1.x and keeps working on 2.x. The result of irfft is real.
    op = TensorIRFFT(n=n, axis=axis, norm=norm, dtype=np.dtype(np.float64))
    return op(a)