Skip to content

Commit

Permalink
adding ifft2 method to ops
Browse files Browse the repository at this point in the history
  • Loading branch information
rohithpudari committed Nov 4, 2024
1 parent c052cea commit 5d51865
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 0 deletions.
9 changes: 9 additions & 0 deletions keras/src/backend/jax/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ def fft2(x):
return jnp.real(complex_output), jnp.imag(complex_output)


def ifft2(x):
real, imag = x
H = cast(jnp.shape(real)[-2], jnp.float32)
W = cast(jnp.shape(real)[-1], jnp.float32)
real_conj, imag_conj = real, -imag
fft_real, fft_imag = fft2((real_conj, imag_conj))
return fft_real / (H * W), fft_imag / (H * W)


def rfft(x, fft_length=None):
complex_output = jnp.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
return jnp.real(complex_output), jnp.imag(complex_output)
Expand Down
9 changes: 9 additions & 0 deletions keras/src/backend/numpy/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,15 @@ def fft2(x):
return np.array(real), np.array(imag)


def ifft2(x):
real, imag = x
H = np.float32(real.shape[-2])
W = np.float32(real.shape[-1])
real_conj, imag_conj = real, -imag
fft_real, fft_imag = fft2((real_conj, imag_conj))
return fft_real / (H * W), -fft_imag / (H * W)


def rfft(x, fft_length=None):
complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward")
# numpy always outputs complex128, so we need to recast the dtype
Expand Down
9 changes: 9 additions & 0 deletions keras/src/backend/tensorflow/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def fft2(x):
return tf.math.real(complex_output), tf.math.imag(complex_output)


def ifft2(x):
real, imag = x
H = cast(tf.shape(real)[-2], "float32")
W = cast(tf.shape(real)[-1], "float32")
real_conj, imag_conj = real, -imag
fft_real, fft_imag = fft2((real_conj, imag_conj))
return fft_real / (H * W), -fft_imag / (H * W)


def rfft(x, fft_length=None):
if fft_length is not None:
fft_length = [fft_length]
Expand Down
9 changes: 9 additions & 0 deletions keras/src/backend/torch/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ def fft2(x):
return complex_output.real, complex_output.imag


def ifft2(x):
real, imag = x
H = cast(torch.tensor(real.shape[-2]), "float32")
W = cast(torch.tensor(real.shape[-1]), "float32")
real_conj, imag_conj = real, -imag
fft_real, fft_imag = fft2((real_conj, imag_conj))
return fft_real / (H * W), -fft_imag / (H * W)


def rfft(x, fft_length=None):
x = convert_to_tensor(x)
complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm="backward")
Expand Down
75 changes: 75 additions & 0 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,81 @@ def fft2(x):
return backend.math.fft2(x)


class IFFT2(Operation):
def __init__(self):
super().__init__()
self.axes = (-2, -1)

def compute_output_spec(self, x):
if not isinstance(x, (tuple, list)) or len(x) != 2:
raise ValueError(
"Input `x` should be a tuple of two tensors - real and "
f"imaginary. Received: x={x}"
)

real, imag = x
# Both real and imaginary parts should have the same shape.
if real.shape != imag.shape:
raise ValueError(
"Input `x` should be a tuple of two tensors - real and "
"imaginary. Both the real and imaginary parts should have the "
f"same shape. Received: x[0].shape = {real.shape}, "
f"x[1].shape = {imag.shape}"
)
# We are calculating 2D IFFT. Hence, rank >= 2.
if len(real.shape) < 2:
raise ValueError(
f"Input should have rank >= 2. "
f"Received: input.shape = {real.shape}"
)

# The axes along which we are calculating IFFT should be fully-defined.
m = real.shape[self.axes[0]]
n = real.shape[self.axes[1]]
if m is None or n is None:
raise ValueError(
f"Input should have its {self.axes} axes fully-defined. "
f"Received: input.shape = {real.shape}"
)

return (
KerasTensor(shape=real.shape, dtype=real.dtype),
KerasTensor(shape=imag.shape, dtype=imag.dtype),
)

def call(self, x):
return backend.math.ifft2(x)


@keras_export("keras.ops.ifft2")
def ifft2(x):
"""Computes the 2D Inverse Fast Fourier Transform along the last two axes of
input.
Args:
x: Tuple of the real and imaginary parts of the input tensor. Both
tensors in the tuple should be of floating type.
Returns:
A tuple containing two tensors - the real and imaginary parts of the
output.
Example:
>>> x = (
... keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),
... keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),
... )
>>> ifft2(x)
(array([[ 6., 0.],
[ 0., -2.]], dtype=float32), array([[ 2., 0.],
[ 0., -2.]], dtype=float32))
"""
if any_symbolic_tensors(x):
return IFFT2().symbolic_call(x)
return backend.math.ifft2(x)


class RFFT(Operation):
def __init__(self, fft_length=None):
super().__init__()
Expand Down
29 changes: 29 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,15 @@ def test_fft2(self):
self.assertEqual(real_output.shape, ref_shape)
self.assertEqual(imag_output.shape, ref_shape)

def test_ifft2(self):
real = KerasTensor((None, 4, 3), dtype="float32")
imag = KerasTensor((None, 4, 3), dtype="float32")
real_output, imag_output = kmath.ifft2((real, imag))
ref = np.fft.ifft2(np.ones((2, 4, 3)))
ref_shape = (None,) + ref.shape[1:]
self.assertEqual(real_output.shape, ref_shape)
self.assertEqual(imag_output.shape, ref_shape)

@parameterized.parameters([(None,), (1,), (5,)])
def test_rfft(self, fft_length):
x = KerasTensor((None, 4, 3), dtype="float32")
Expand Down Expand Up @@ -355,6 +364,14 @@ def test_fft2(self):
self.assertEqual(real_output.shape, ref.shape)
self.assertEqual(imag_output.shape, ref.shape)

def test_ifft2(self):
real = KerasTensor((2, 4, 3), dtype="float32")
imag = KerasTensor((2, 4, 3), dtype="float32")
real_output, imag_output = kmath.ifft2((real, imag))
ref = np.fft.ifft2(np.ones((2, 4, 3)))
self.assertEqual(real_output.shape, ref.shape)
self.assertEqual(imag_output.shape, ref.shape)

def test_rfft(self):
x = KerasTensor((2, 4, 3), dtype="float32")
real_output, imag_output = kmath.rfft(x)
Expand Down Expand Up @@ -717,6 +734,18 @@ def test_fft2(self):
self.assertAllClose(real_ref, real_output)
self.assertAllClose(imag_ref, imag_output)

def test_ifft2(self):
real = np.random.random((2, 4, 3)).astype(np.float32)
imag = np.random.random((2, 4, 3)).astype(np.float32)
complex_arr = real + 1j * imag

real_output, imag_output = kmath.ifft2((real, imag))
ref = np.fft.ifft2(complex_arr)
real_ref = np.real(ref)
imag_ref = np.imag(ref)
self.assertAllClose(real_ref, real_output)
self.assertAllClose(imag_ref, imag_output)

@parameterized.parameters([(None,), (3,), (15,)])
def test_rfft(self, n):
# Test 1D.
Expand Down

0 comments on commit 5d51865

Please sign in to comment.