diff --git a/keras/src/backend/jax/math.py b/keras/src/backend/jax/math.py index 18ba91862a9..ea60e9434db 100644 --- a/keras/src/backend/jax/math.py +++ b/keras/src/backend/jax/math.py @@ -123,6 +123,15 @@ def fft2(x): return jnp.real(complex_output), jnp.imag(complex_output) +def ifft2(x): + real, imag = x + H = cast(jnp.shape(real)[-2], jnp.float32) + W = cast(jnp.shape(real)[-1], jnp.float32) + real_conj, imag_conj = real, -imag + fft_real, fft_imag = fft2((real_conj, imag_conj)) + return fft_real / (H * W), fft_imag / (H * W) + + def rfft(x, fft_length=None): complex_output = jnp.fft.rfft(x, n=fft_length, axis=-1, norm="backward") return jnp.real(complex_output), jnp.imag(complex_output) diff --git a/keras/src/backend/numpy/math.py b/keras/src/backend/numpy/math.py index f9448c92b93..7ea91250ecd 100644 --- a/keras/src/backend/numpy/math.py +++ b/keras/src/backend/numpy/math.py @@ -144,6 +144,15 @@ def fft2(x): return np.array(real), np.array(imag) +def ifft2(x): + real, imag = x + H = np.float32(real.shape[-2]) + W = np.float32(real.shape[-1]) + real_conj, imag_conj = real, -imag + fft_real, fft_imag = fft2((real_conj, imag_conj)) + return fft_real / (H * W), -fft_imag / (H * W) + + def rfft(x, fft_length=None): complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm="backward") # numpy always outputs complex128, so we need to recast the dtype diff --git a/keras/src/backend/tensorflow/math.py b/keras/src/backend/tensorflow/math.py index 4f920ae1eb6..029a991caa9 100644 --- a/keras/src/backend/tensorflow/math.py +++ b/keras/src/backend/tensorflow/math.py @@ -113,6 +113,15 @@ def fft2(x): return tf.math.real(complex_output), tf.math.imag(complex_output) +def ifft2(x): + real, imag = x + H = cast(tf.shape(real)[-2], "float32") + W = cast(tf.shape(real)[-1], "float32") + real_conj, imag_conj = real, -imag + fft_real, fft_imag = fft2((real_conj, imag_conj)) + return fft_real / (H * W), -fft_imag / (H * W) + + def rfft(x, fft_length=None): if fft_length is not None: fft_length = [fft_length] diff --git a/keras/src/backend/torch/math.py b/keras/src/backend/torch/math.py index 4531ff673cb..b66032066d4 100644 --- a/keras/src/backend/torch/math.py +++ b/keras/src/backend/torch/math.py @@ -203,6 +203,15 @@ def fft2(x): return complex_output.real, complex_output.imag +def ifft2(x): + real, imag = x + H = cast(torch.tensor(real.shape[-2]), "float32") + W = cast(torch.tensor(real.shape[-1]), "float32") + real_conj, imag_conj = real, -imag + fft_real, fft_imag = fft2((real_conj, imag_conj)) + return fft_real / (H * W), -fft_imag / (H * W) + + def rfft(x, fft_length=None): x = convert_to_tensor(x) complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm="backward") diff --git a/keras/src/ops/math.py b/keras/src/ops/math.py index fd0a41d5177..6fa6f31d0c6 100644 --- a/keras/src/ops/math.py +++ b/keras/src/ops/math.py @@ -475,6 +475,81 @@ def fft2(x): return backend.math.fft2(x) +class IFFT2(Operation): + def __init__(self): + super().__init__() + self.axes = (-2, -1) + + def compute_output_spec(self, x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + f"imaginary. Received: x={x}" + ) + + real, imag = x + # Both real and imaginary parts should have the same shape. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and " + "imaginary. Both the real and imaginary parts should have the " + f"same shape. Received: x[0].shape = {real.shape}, " + f"x[1].shape = {imag.shape}" + ) + # We are calculating 2D IFFT. Hence, rank >= 2. + if len(real.shape) < 2: + raise ValueError( + f"Input should have rank >= 2. " + f"Received: input.shape = {real.shape}" + ) + + # The axes along which we are calculating IFFT should be fully-defined. + m = real.shape[self.axes[0]] + n = real.shape[self.axes[1]] + if m is None or n is None: + raise ValueError( + f"Input should have its {self.axes} axes fully-defined. " + f"Received: input.shape = {real.shape}" + ) + + return ( + KerasTensor(shape=real.shape, dtype=real.dtype), + KerasTensor(shape=imag.shape, dtype=imag.dtype), + ) + + def call(self, x): + return backend.math.ifft2(x) + + +@keras_export("keras.ops.ifft2") +def ifft2(x): + """Computes the 2D Inverse Fast Fourier Transform along the last two axes of + input. + + Args: + x: Tuple of the real and imaginary parts of the input tensor. Both + tensors in the tuple should be of floating type. + + Returns: + A tuple containing two tensors - the real and imaginary parts of the + output. + + Example: + + >>> x = ( + ... keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]), + ... keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]), + ... ) + >>> ifft2(x) + (array([[ 6., 0.], + [ 0., -2.]], dtype=float32), array([[ 2., 0.], + [ 0., -2.]], dtype=float32)) + """ + if any_symbolic_tensors(x): + return IFFT2().symbolic_call(x) + return backend.math.ifft2(x) + + class RFFT(Operation): def __init__(self, fft_length=None): super().__init__() diff --git a/keras/src/ops/math_test.py b/keras/src/ops/math_test.py index 09c87514c78..3f54e5159e8 100644 --- a/keras/src/ops/math_test.py +++ b/keras/src/ops/math_test.py @@ -219,6 +219,15 @@ def test_fft2(self): self.assertEqual(real_output.shape, ref_shape) self.assertEqual(imag_output.shape, ref_shape) + def test_ifft2(self): + real = KerasTensor((None, 4, 3), dtype="float32") + imag = KerasTensor((None, 4, 3), dtype="float32") + real_output, imag_output = kmath.ifft2((real, imag)) + ref = np.fft.ifft2(np.ones((2, 4, 3))) + ref_shape = (None,) + ref.shape[1:] + self.assertEqual(real_output.shape, ref_shape) + self.assertEqual(imag_output.shape, ref_shape) + @parameterized.parameters([(None,), (1,), (5,)]) def test_rfft(self, fft_length): x = KerasTensor((None, 4, 3), dtype="float32") @@ -355,6 +364,14 @@ def test_fft2(self): self.assertEqual(real_output.shape, ref.shape) self.assertEqual(imag_output.shape, ref.shape) + def test_ifft2(self): + real = KerasTensor((2, 4, 3), dtype="float32") + imag = KerasTensor((2, 4, 3), dtype="float32") + real_output, imag_output = kmath.ifft2((real, imag)) + ref = np.fft.ifft2(np.ones((2, 4, 3))) + self.assertEqual(real_output.shape, ref.shape) + self.assertEqual(imag_output.shape, ref.shape) + def test_rfft(self): x = KerasTensor((2, 4, 3), dtype="float32") real_output, imag_output = kmath.rfft(x) @@ -717,6 +734,18 @@ def test_fft2(self): self.assertAllClose(real_ref, real_output) self.assertAllClose(imag_ref, imag_output) + def test_ifft2(self): + real = np.random.random((2, 4, 3)).astype(np.float32) + imag = np.random.random((2, 4, 3)).astype(np.float32) + complex_arr = real + 1j * imag + + real_output, imag_output = kmath.ifft2((real, imag)) + ref = np.fft.ifft2(complex_arr) + real_ref = np.real(ref) + imag_ref = np.imag(ref) + self.assertAllClose(real_ref, real_output) + self.assertAllClose(imag_ref, imag_output) + @parameterized.parameters([(None,), (3,), (15,)]) def test_rfft(self, n): # Test 1D.