Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_components/test_IO.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import tidy3d as td
from tidy3d import __version__
from tidy3d.components.data.data_array import DATA_ARRAY_MAP
from tidy3d.components.data.data_array import is_data_array_name
from tidy3d.components.data.sim_data import DATA_TYPE_MAP

from ..test_data.test_monitor_data import make_flux_data
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_to_json_data():
# type saved in the combined json file?
data = make_flux_data()
json_dict = json.loads(data._json_string)
assert json_dict["flux"] in DATA_ARRAY_MAP
assert is_data_array_name(json_dict["flux"])


def test_to_hdf5_group_path_sim_data(tmp_path):
Expand Down
7 changes: 5 additions & 2 deletions tests/test_components/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import ValidationError

import tidy3d as td
from tidy3d.components.data.data_array import _isinstance
from tidy3d.components.data.dataset import PermittivityDataset
from tidy3d.components.data.utils import UnstructuredGridDataset, _get_numpy_array
from tidy3d.components.medium import (
Expand Down Expand Up @@ -554,15 +555,17 @@ def verify_custom_medium_methods(mat, reduced_fields):

# data fields in medium classes could be SpatialArrays or 2d tuples of spatial arrays
# lets convert everything into 2d tuples of spatial arrays for uniform handling
if isinstance(original, (td.SpatialDataArray, UnstructuredGridDataset)):
if _isinstance(original, td.SpatialDataArray) or isinstance(
original, UnstructuredGridDataset
):
original = [[original]]
reduced = [[reduced]]

for or_set, re_set in zip(original, reduced):
assert len(or_set) == len(re_set)

for ind in range(len(or_set)):
if isinstance(or_set[ind], td.SpatialDataArray):
if _isinstance(or_set[ind], td.SpatialDataArray):
diff = (or_set[ind] - re_set[ind]).abs
assert diff.does_cover(subsection.bounds)
assert np.allclose(diff, 0)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_components/test_microwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from shapely import LineString

import tidy3d as td
from tidy3d.components.data.data_array import FreqModeDataArray
from tidy3d.components.data.data_array import FreqModeDataArray, _isinstance
from tidy3d.components.data.monitor_data import FreqDataArray
from tidy3d.components.microwave.formulas.circuit_parameters import (
capacitance_colinear_cylindrical_wire_segments,
Expand Down Expand Up @@ -440,8 +440,8 @@ def test_antenna_parameters():
)

# Test that all essential parameters exist and are correct type
assert isinstance(antenna_params.radiation_efficiency, FreqDataArray)
assert isinstance(antenna_params.reflection_efficiency, FreqDataArray)
assert _isinstance(antenna_params.radiation_efficiency, FreqDataArray)
assert _isinstance(antenna_params.reflection_efficiency, FreqDataArray)
assert np.allclose(antenna_params.reflection_efficiency, 0.75)
assert isinstance(antenna_params.gain, xr.DataArray)
assert isinstance(antenna_params.realized_gain, xr.DataArray)
Expand Down
93 changes: 90 additions & 3 deletions tests/test_data/test_data_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,17 @@
import autograd.numpy as np
import numpy
import pytest
import xarray as xr
import xarray.testing as xrt
from autograd.test_util import check_grads
from pydantic import BaseModel, ValidationError

import tidy3d as td
from tidy3d._common.components.data.data_array import DataArray
from tidy3d.components.data.data_array import (
data_array_annotated_type,
)
from tidy3d.components.data.dataset import TimeDataset
from tidy3d.exceptions import DataError

np.random.seed(4)
Expand Down Expand Up @@ -343,13 +350,93 @@ def test_abs():


def test_angle():
# Make sure works on real data and the type is correct
# Make sure works on real data and preserves DataArray structure
data = make_scalar_field_time_data_array("Ex")
angle_data = data.angle
assert type(data) is type(angle_data)
assert isinstance(angle_data, xr.DataArray)
assert angle_data.dims == data.dims
assert angle_data.coords.equals(data.coords)
data = make_mode_amps_data_array()
angle_data = data.angle
assert type(data) is type(angle_data)
assert isinstance(angle_data, xr.DataArray)
assert angle_data.dims == data.dims
assert angle_data.coords.equals(data.coords)


def test_annotated_data_array_spec():
ScalarFieldSpec = data_array_annotated_type(td.ScalarFieldDataArray)

class Model(BaseModel):
field: ScalarFieldSpec

data = make_scalar_field_data_array("Ex")
data_plain = xr.DataArray(data.data, coords=data.coords, dims=data.dims)
model = Model(field=data_plain)
assert model.field.dims == data.dims
assert "tidy3d.data.scalar_field" in model.model_dump_json()

with pytest.raises(ValidationError):
Model(field=xr.DataArray(np.zeros((2, 2)), dims=("x", "y")))


def test_annotated_accepts_legacy_class():
ScalarFieldSpec = data_array_annotated_type(td.ScalarFieldDataArray)

class Model(BaseModel):
field: ScalarFieldSpec

data = make_scalar_field_data_array("Ex")
model = Model(field=data)
assert model.field.dims == data.dims


def test_legacy_data_array_shims():
arr = xr.DataArray(
np.random.random((3, 4, 5)),
coords={
"x": np.linspace(0, 1, 3),
"y": np.linspace(1, 2, 4),
"z": np.linspace(2, 3, 5),
},
)
bounds = ((0.2, 1.1, 2.1), (0.9, 1.9, 2.9))
selected = arr.sel_inside(bounds)
assert selected.dims == arr.dims
reflected = arr.reflect(axis=0, center=-0.5, reflection_only=True)
assert reflected.dims == arr.dims
updated = arr._with_updated_data(data=np.zeros((1, 1, 1)), coords={"x": 0, "y": 1, "z": 2})
assert updated.dims == arr.dims


def test_annotated_dataset_hdf5_roundtrip(tmp_path):
times = np.linspace(0, 1e-12, 4)
values = np.random.random(len(times))
data = xr.DataArray(values, coords={"t": times}, dims=("t",))
dataset = TimeDataset(values=data)

path = tmp_path / "time_dataset.hdf5"
dataset.to_hdf5(path)
loaded = TimeDataset.from_hdf5(path)

assert type(loaded.values) is DataArray
assert loaded.values.dims == data.dims
assert loaded.values.coords["t"].equals(data.coords["t"])


def test_legacy_class_spec_validation():
class Model(BaseModel):
field: data_array_annotated_type(td.ScalarFieldDataArray)

data = xr.DataArray(
np.random.random((len(FS), 2, 3, 4)),
coords={"f": FS, "x": [0, 1], "y": [0, 1, 2], "z": [0, 1, 2, 3]},
dims=("f", "x", "y", "z"),
)
model = Model(field=data)
assert model.field.dims == ("x", "y", "z", "f")

with pytest.raises(ValidationError):
Model(field=xr.DataArray(np.zeros((2, 2)), dims=("x", "y")))


def test_heat_data_array():
Expand Down
14 changes: 7 additions & 7 deletions tests/test_plugins/smatrix/test_terminal_component_modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tidy3d.plugins.smatrix.utils
from tidy3d import SimulationDataMap
from tidy3d.components.boundary import BroadbandModeABCSpec
from tidy3d.components.data.data_array import FreqDataArray
from tidy3d.components.data.data_array import FreqDataArray, _isinstance
from tidy3d.exceptions import SetupError, Tidy3dError, Tidy3dKeyError
from tidy3d.plugins.smatrix import (
CoaxialLumpedPort,
Expand Down Expand Up @@ -1236,8 +1236,8 @@ def test_antenna_helpers(monkeypatch, tmp_path):

# Test power wave amplitude computation
a, b = modeler_data.compute_power_wave_amplitudes_at_each_port(sim_data=sim_data)
assert isinstance(a, PortDataArray)
assert isinstance(b, PortDataArray)
assert _isinstance(a, PortDataArray)
assert _isinstance(b, PortDataArray)


@pytest.mark.parametrize("port_type", ["lumped", "wave"])
Expand Down Expand Up @@ -1288,8 +1288,8 @@ def test_antenna_parameters(monkeypatch, port_type):
antenna_params = modeler_data.get_antenna_metrics_data()

# Test that all essential parameters exist and are correct type
assert isinstance(antenna_params.radiation_efficiency, FreqDataArray)
assert isinstance(antenna_params.reflection_efficiency, FreqDataArray)
assert _isinstance(antenna_params.radiation_efficiency, FreqDataArray)
assert _isinstance(antenna_params.reflection_efficiency, FreqDataArray)
assert isinstance(antenna_params.gain, xr.DataArray)
assert isinstance(antenna_params.realized_gain, xr.DataArray)

Expand Down Expand Up @@ -1345,8 +1345,8 @@ def test_get_combined_antenna_parameters_data(monkeypatch, tmp_path):
)

# Check that essential properties exist and are correct type
assert isinstance(antenna_params.radiation_efficiency, FreqDataArray)
assert isinstance(antenna_params.reflection_efficiency, FreqDataArray)
assert _isinstance(antenna_params.radiation_efficiency, FreqDataArray)
assert _isinstance(antenna_params.reflection_efficiency, FreqDataArray)
assert isinstance(antenna_params.partial_gain(), xr.Dataset)
assert isinstance(antenna_params.gain, xr.DataArray)
assert isinstance(antenna_params.partial_realized_gain(), xr.Dataset)
Expand Down
13 changes: 9 additions & 4 deletions tidy3d/_common/components/autograd/derivative_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import xarray as xr
from numpy.typing import NDArray

from tidy3d._common.components.data.data_array import FreqDataArray, ScalarFieldDataArray
from tidy3d._common.components.data.data_array import (
FreqDataArray,
ScalarFieldDataArray,
_isinstance,
data_array_annotated_type,
)
from tidy3d._common.components.types.base import ArrayLike, Bound, Complex
from tidy3d._common.config import config
from tidy3d._common.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0
Expand All @@ -27,7 +32,7 @@

FieldData = dict[str, ScalarFieldDataArray]
PermittivityData = dict[str, ScalarFieldDataArray]
EpsType = Union[Complex, FreqDataArray]
EpsType = Union[Complex, data_array_annotated_type(FreqDataArray)]
ArrayFloat = NDArray[np.floating]
ArrayComplex = NDArray[np.complexfloating]

Expand Down Expand Up @@ -706,7 +711,7 @@ def _prepare_epsilon(eps: EpsType) -> np.ndarray:
For FreqDataArray, extracts values and broadcasts to shape (1, n_freqs).
For scalar values, broadcasts to shape (1, 1) for consistency with multi-frequency.
"""
if isinstance(eps, FreqDataArray):
if _isinstance(eps, FreqDataArray):
# data is already sliced, just extract values
eps_values = eps.values
# shape: (n_freqs,) - need to broadcast to (1, n_freqs)
Expand Down Expand Up @@ -812,7 +817,7 @@ def adaptive_vjp_spacing(
min_allowed_spacing_fraction = config.adjoint.minimum_spacing_fraction

# handle FreqDataArray or scalar eps_in
if isinstance(self.eps_in, FreqDataArray):
if _isinstance(self.eps_in, FreqDataArray):
eps_real = np.asarray(self.eps_in.values, dtype=np.complex128).real
else:
eps_real = np.asarray(self.eps_in, dtype=np.complex128).real
Expand Down
27 changes: 18 additions & 9 deletions tidy3d/_common/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator

from tidy3d._common.components.autograd.utils import get_static
from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP
from tidy3d._common.components.data.data_array import (
data_array_spec_from_name,
is_data_array_name,
iter_data_array_names,
write_data_array_to_hdf5,
)
from tidy3d._common.components.file_util import compress_file_to_gzip, extract_gzip_file
from tidy3d._common.components.types.base import TYPE_TAG_STR, Undefined
from tidy3d._common.exceptions import FileError
Expand Down Expand Up @@ -1053,7 +1058,7 @@ def to_yaml(self, fname: PathLike) -> None:
@staticmethod
def _warn_if_contains_data(json_str: str) -> None:
"""Log a warning if the json string contains data, used in '.json' and '.yaml' file."""
if any((key in json_str for key, _ in DATA_ARRAY_MAP.items())):
if any(name in json_str for name in iter_data_array_names()):
log.warning(
"Data contents found in the model to be written to file. "
"Note that this data will not be included in '.json' or '.yaml' formats. "
Expand Down Expand Up @@ -1155,7 +1160,7 @@ def dict_from_hdf5(

def is_data_array(value: Any) -> bool:
"""Whether a value is supposed to be a data array based on the contents."""
return isinstance(value, str) and value in DATA_ARRAY_MAP
return is_data_array_name(value)

fname_path = Path(fname)

Expand All @@ -1178,10 +1183,10 @@ def load_data_from_file(model_dict: dict, group_path: str = "") -> None:

# write the path to the element of the json dict where the data_array should be
if is_data_array(value):
data_array_type = DATA_ARRAY_MAP[value]
model_dict[key] = data_array_type.from_hdf5(
fname=fname_path, group_path=subpath
)
spec = data_array_spec_from_name(value)
if spec is None:
raise FileError(f"Unrecognized DataArray schema '{value}'.")
model_dict[key] = spec.from_hdf5(fname=fname_path, group_path=subpath)
continue

# if a list, assign each element a unique key, recurse
Expand Down Expand Up @@ -1291,7 +1296,7 @@ def add_data_to_file(data_dict: dict, group_path: str = "") -> None:

# write the path to the element of the json dict where the data_array should be
if isinstance(value, xr.DataArray):
value.to_hdf5(fname=f_handle, group_path=subpath)
write_data_array_to_hdf5(value, f_handle=f_handle, group_path=subpath)

# if a tuple, assign each element a unique key
if isinstance(value, (list, tuple)):
Expand Down Expand Up @@ -1444,7 +1449,11 @@ def _fields_equal(a: Any, b: Any) -> bool:
if a is b:
return True
if type(a) is not type(b):
if not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))):
if isinstance(a, (xr.DataArray, xr.Dataset)) and isinstance(
b, (xr.DataArray, xr.Dataset)
):
pass
elif not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))):
return False
if isinstance(a, np.ndarray):
return np.array_equal(a, b)
Expand Down
Loading