Skip to content
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from keras.src.ops.math import segment_sum as segment_sum
from keras.src.ops.math import stft as stft
from keras.src.ops.math import top_k as top_k
from keras.src.ops.math import view_as_complex as view_as_complex
from keras.src.ops.math import view_as_real as view_as_real
from keras.src.ops.nn import average_pool as average_pool
from keras.src.ops.nn import batch_normalization as batch_normalization
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
Expand Down
2 changes: 2 additions & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from keras.src.ops.math import segment_sum as segment_sum
from keras.src.ops.math import stft as stft
from keras.src.ops.math import top_k as top_k
from keras.src.ops.math import view_as_complex as view_as_complex
from keras.src.ops.math import view_as_real as view_as_real
from keras.src.ops.nn import average_pool as average_pool
from keras.src.ops.nn import batch_normalization as batch_normalization
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy
Expand Down
98 changes: 98 additions & 0 deletions keras/src/ops/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,3 +1044,101 @@ def logdet(x):
if any_symbolic_tensors((x,)):
return Logdet().symbolic_call(x)
return backend.math.logdet(x)


class ViewAsComplex(Operation):
def call(self, x):
x = backend.convert_to_tensor(x)
if len(x.shape) < 1 or x.shape[-1] != 2:
raise ValueError(
"Input tensor's last dimension must be 2 (real and imaginary)."
)
return x[..., 0] + 1j * x[..., 1]

def compute_output_spec(self, x):
return KerasTensor(shape=x.shape[:-1], dtype="complex64")


class ViewAsReal(Operation):
def call(self, x):
x = backend.convert_to_tensor(x)
real_part = backend.numpy.real(x)
imag_part = backend.numpy.imag(x)
return backend.numpy.stack((real_part, imag_part), axis=-1)

def compute_output_spec(self, x):
return KerasTensor(shape=x.shape + (2,), dtype="float32")


@keras_export("keras.ops.view_as_complex")
def view_as_complex(x):
"""Converts a real tensor with shape `(..., 2)` to a complex tensor,
where the last dimension represents the real and imaginary components
of a complex tensor.

Args:
x: A real tensor with last dimension of size 2.

Returns:
A complex tensor with shape `x.shape[:-1]`.

Example:

```
>>> import numpy as np
>>> from keras import ops

>>> real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])
>>> complex_tensor = ops.view_as_complex(real_imag)
>>> complex_tensor
array([1.+2.j, 3.+4.j])
```
"""
if any_symbolic_tensors((x,)):
return ViewAsComplex().symbolic_call(x)

x = backend.convert_to_tensor(x)
if len(x.shape) < 1 or x.shape[-1] != 2:
raise ValueError(
"Last dimension of input must be size 2 (real and imaginary). "
f"Received shape: {x.shape}"
)
real_part = x[..., 0]
imag_part = x[..., 1]

return backend.cast(real_part, dtype="complex64") + 1j * backend.cast(
imag_part, dtype="complex64"
)


@keras_export("keras.ops.view_as_real")
def view_as_real(x):
"""Converts a complex tensor to a real tensor with shape `(..., 2)`,
where the last dimension represents the real and imaginary components.

Args:
x: A complex tensor.

Returns:
A real tensor where the last dimension contains the
real and imaginary parts.

Example:
```
>>> import numpy as np
>>> from keras import ops

>>> complex_tensor = np.array([1 + 2j, 3 + 4j])
>>> real = ops.view_as_real(complex_tensor)
>>> real
array([[1., 2.],
[3., 4.]])
```
"""
if any_symbolic_tensors((x,)):
return ViewAsReal().symbolic_call(x)

x = backend.convert_to_tensor(x)
real_part = backend.numpy.real(x)
imag_part = backend.numpy.imag(x)
return backend.numpy.stack((real_part, imag_part), axis=-1)
68 changes: 68 additions & 0 deletions keras/src/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from keras.src import backend
from keras.src import testing
from keras.src.backend.common import dtypes
from keras.src.backend.common import standardize_dtype
from keras.src.backend.common.keras_tensor import KerasTensor
from keras.src.ops import math as kmath

Expand Down Expand Up @@ -1494,3 +1495,70 @@ def test_istft_invalid_window_shape_2D_inputs(self):
fft_length,
window=incorrect_window,
)


@pytest.mark.skipif(
backend.backend() == "openvino",
reason="Complex dtype is not supported on OpenVINO backend.",
)
class ViewAsComplexRealTest(testing.TestCase):
def test_view_as_complex_basic(self):
real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])
expected = np.array([1.0 + 2.0j, 3.0 + 4.0j], dtype=np.complex64)

result = kmath.view_as_complex(real_imag)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)

def test_view_as_real_basic(self):
complex_tensor = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

result = kmath.view_as_real(complex_tensor)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)

def test_view_as_complex_invalid_shape(self):
bad_input = np.array([1.0, 2.0, 3.0]) # Last dimension not size 2
with self.assertRaisesRegex(
ValueError, "Last dimension of input must be size 2"
):
kmath.view_as_complex(bad_input)

def test_view_as_complex_symbolic_input(self):
x = KerasTensor(shape=(None, 2), dtype="float32")
result = kmath.view_as_complex(x)

self.assertEqual(result.shape, (None,))
self.assertEqual(standardize_dtype(result.dtype), "complex64")

def test_view_as_real_symbolic_input(self):
x = KerasTensor(shape=(None,), dtype="complex64")
result = kmath.view_as_real(x)

self.assertEqual(result.shape, (None, 2))
self.assertEqual(standardize_dtype(result.dtype), "float32")

def test_view_as_complex_multi_dimensional(self):
x = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)
expected = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)

result = kmath.view_as_complex(x)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)

def test_view_as_real_multi_dimensional(self):
x = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)
expected = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)

result = kmath.view_as_real(x)

self.assertEqual(result.shape, expected.shape)
self.assertEqual(standardize_dtype(result.dtype), expected.dtype)
self.assertAllClose(result, expected)