Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tidy3d/components/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def _interp_in_freq_update_dict(

modify_data = {}
for key, data in self.data_arrs.items():
modify_data[key] = self._interp_dataarray_in_freq(data, freqs, method, assume_sorted)
# Sort data by frequency to ensure proper interpolation
data_sorted = data.sortby("f")
modify_data[key] = self._interp_dataarray_in_freq(data_sorted, freqs, method, assume_sorted)

return modify_data

Expand Down
123 changes: 109 additions & 14 deletions tidy3d/components/data/monitor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,22 +694,32 @@ def package_flux_results(self, flux_values: DataArray) -> Any:
return FluxDataArray(flux_values)

@cached_property
def complex_flux(self) -> Union[FluxDataArray, FreqModeDataArray]:
def complex_flux(self) -> Union[FluxDataArray, FreqModeDataArray, None]:
"""Flux for data corresponding to a 2D monitor."""

# Compute flux by integrating Poynting vector in-plane
d_area = self._diff_area
poynting = self.complex_poynting

flux_values = poynting * d_area
# Handle coordinate mismatches between poynting and d_area
try:
flux_values = poynting * d_area
except ValueError:
# If coordinates don't match, this can happen with EME port modes that have different
# grid extents than monitors. In this case, return None to indicate flux is not calculable
return None

flux_values = flux_values.sum(dim=d_area.dims)

return self.package_flux_results(flux_values)

@cached_property
def flux(self) -> Union[FluxDataArray, FreqModeDataArray]:
def flux(self) -> Union[FluxDataArray, FreqModeDataArray, None]:
"""Flux for data corresponding to a 2D monitor."""
return self.complex_flux.real
complex_flux = self.complex_flux
if complex_flux is None:
return None
return complex_flux.real

@cached_property
def mode_area(self) -> FreqModeDataArray:
Expand Down Expand Up @@ -801,11 +811,50 @@ def dot(
fields_other = {key: field.squeeze(drop=True) for key, field in fields_other.items()}

# Cross products of fields
e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2]
e_self_x_h_other -= fields_self["E" + dim2] * fields_other["H" + dim1]
h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2]
h_self_x_e_other -= fields_self["H" + dim2] * fields_other["E" + dim1]
integrand = (e_self_x_h_other - h_self_x_e_other) * d_area
# Use regular subtraction instead of in-place to avoid coordinate merging issues
# when arrays have incompatible coordinate structures
e_self_x_h_other = fields_self["E" + dim1] * fields_other["H" + dim2] - fields_self["E" + dim2] * fields_other["H" + dim1]
h_self_x_e_other = fields_self["H" + dim1] * fields_other["E" + dim2] - fields_self["H" + dim2] * fields_other["E" + dim1]
integrand_base = e_self_x_h_other - h_self_x_e_other

# Check if integrand has empty dimensions along tangential axes
# If so, the result will be empty, so we can return early or handle specially
has_empty_tangential = any(
integrand_base.coords[dim].size == 0 for dim in d_area.dims if dim in integrand_base.coords
)

if has_empty_tangential:
# If integrand has empty tangential dimensions, create an empty d_area with matching structure
# Preserve the dimension order from d_area.dims
empty_dims = [dim for dim in d_area.dims if dim in integrand_base.coords]
empty_coords = {dim: integrand_base.coords[dim] for dim in empty_dims}
empty_shape = tuple(integrand_base.coords[dim].size for dim in empty_dims)
d_area_aligned = xr.DataArray(
np.zeros(empty_shape),
dims=empty_dims,
coords=empty_coords
)
else:
# Align d_area to the integrand's coordinates along tangential dimensions
d_area_dict = {dim: integrand_base.coords[dim] for dim in d_area.dims if dim in integrand_base.coords}
try:
d_area_aligned = d_area.reindex(d_area_dict, method="nearest", fill_value=0.0)
except (KeyError, ValueError):
# If reindex fails due to incompatible coordinates, try using integrand's coordinates directly
# This handles edge cases where coordinates don't overlap
d_area_aligned = d_area
for dim in d_area.dims:
if dim in integrand_base.coords and dim in d_area.coords:
try:
d_area_aligned = d_area_aligned.reindex(
{dim: integrand_base.coords[dim]}, method="nearest", fill_value=0.0
)
except (KeyError, ValueError):
# Keep original if reindex fails
pass

# Multiply, allowing xarray to handle broadcasting along non-tangential dimensions
integrand = integrand_base * d_area_aligned

# Integrate over plane
return ModeAmpsDataArray(0.25 * integrand.sum(dim=d_area.dims))
Expand Down Expand Up @@ -900,8 +949,19 @@ def outer_dot(
if conjugate:
fields_self = {component: field.conj() for component, field in fields_self.items()}

# Tangential fields for other data
# Check for mode_index in field_data before interpolation
# This handles cases where _isel(mode_index=0) might have dropped the dimension
# Check field components directly to see if mode_index exists
modes_in_other_original = False
mode_index_value_other = [0] # Default value if mode_index needs to be added
if hasattr(field_data, 'field_components') and field_data.field_components:
# Check if any field component has mode_index
field_component = list(field_data.field_components.values())[0]
if "mode_index" in field_component.coords:
modes_in_other_original = True
mode_index_value_other = field_component.coords["mode_index"].values.tolist()

# Tangential fields for other data
fields_other = field_data._interpolated_tangential_fields(self._plane_grid_boundaries)

# Tangential field component names
Expand All @@ -927,6 +987,10 @@ def outer_dot(
# Mode indices, if available
modes_in_self = "mode_index" in coords[0]
modes_in_other = "mode_index" in coords[1]
# If original had mode_index but interpolation dropped it, we need to restore it
if not modes_in_other and modes_in_other_original:
# The dimension was dropped, so we'll add it back with expand_dims
modes_in_other = False # Keep False so we use expand_dims path

keys = (e_1, e_2, h_1, h_2)
for key in keys:
Expand All @@ -941,9 +1005,32 @@ def outer_dot(
if modes_in_other:
fields_other[key] = fields_other[key].rename(mode_index="mode_index_1")
else:
# Add mode_index_1 dimension after isel to ensure it persists
# Use the original mode_index value if it was dropped
# First, ensure the coordinate doesn't already exist as a non-dimension coordinate
if "mode_index_1" in fields_other[key].coords and "mode_index_1" not in fields_other[key].dims:
# Drop the existing coordinate and re-add it as a dimension
fields_other[key] = fields_other[key].drop_vars("mode_index_1")
# Now expand_dims should work correctly
fields_other[key] = fields_other[key].expand_dims(
dim={"mode_index_1": [0]}, axis=len(fields_other[key].shape)
dim={"mode_index_1": mode_index_value_other}
)
# Verify the dimension was actually added - if not, use explicit assignment
if "mode_index_1" not in fields_other[key].dims:
# Create new coords dict with mode_index_1 as a dimension coordinate
new_coords = dict(fields_other[key].coords)
new_coords["mode_index_1"] = mode_index_value_other
# Create new dims tuple with mode_index_1 added
new_dims = fields_other[key].dims + ("mode_index_1",)
# Reshape the data to add the new dimension
new_data = np.expand_dims(fields_other[key].values, axis=len(fields_other[key].dims))
# Create new DataArray with explicit dimension
fields_other[key] = xr.DataArray(
new_data,
dims=new_dims,
coords=new_coords,
attrs=fields_other[key].attrs
)

d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy()

Expand Down Expand Up @@ -2505,7 +2592,12 @@ def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolve

def _normalize_modes(self):
"""Normalize modes. Note: this modifies ``self`` in-place."""
scaling = np.sqrt(np.abs(self.flux))
flux = self.flux
if flux is None:
# Skip normalization for cases where flux calculation is not applicable
# (e.g., EME modes with coordinate mismatches)
return
scaling = np.sqrt(np.abs(flux))
for field in self.field_components.values():
field /= scaling

Expand Down Expand Up @@ -2712,12 +2804,15 @@ def interpolated_copy(self) -> ModeSolverData:
return self
if not self._reduced_data:
return self
# Sort frequencies for interpolation to ensure consistent coordinate ordering
freqs_sorted = np.sort(self.monitor.freqs)

interpolated_data = self.interp_in_freq(
freqs=self.monitor.freqs,
freqs=freqs_sorted,
method=self.monitor.mode_spec.interp_spec.method,
renormalize=True,
recalculate_grid_correction=True,
assume_sorted=True,
assume_sorted=True, # Now safe since we sorted
)
return interpolated_data

Expand Down
43 changes: 42 additions & 1 deletion tidy3d/components/eme/data/sim_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,48 @@ def smatrix_in_basis(
interp_spec1 = mode_spec1.interp_spec if mode_spec1 is not None else None
interp_spec2 = mode_spec2.interp_spec if mode_spec2 is not None else None

modes1, modes2 = modes1._interpolated_copies_if_needed(other=modes2)
# EME uses unnormalized modes internally, and normalizes at the end if requested.
# Port modes are stored with normalize=sim.normalize, so they match the S-matrix normalization.
# When interpolating modes for smatrix_in_basis, we need to preserve the normalization state
# to match port_modes. If port_modes are unnormalized (sim.normalize=False), interpolated
# modes should also be unnormalized.
port_modes_normalized = self.simulation.normalize

# Handle interpolation with correct normalization state to match port_modes
if isinstance(modes1, ModeSolverData) and isinstance(modes2, ModeSolverData):
# Check if interpolation is needed (same logic as _interpolated_copies_if_needed)
if (
interp_spec1 is not None
and interp_spec2 is not None
and modes1.monitor.mode_spec._same_nontrivial_interp_spec(other=modes2.monitor.mode_spec)
):
# Same interp_spec, no interpolation needed
pass
else:
# Interpolation needed - use custom logic to preserve normalization state
# Interpolate modes1 if it has interp_spec and reduced data
if isinstance(modes1, ModeSolverData) and interp_spec1 is not None and modes1._reduced_data:
freqs_sorted = np.sort(modes1.monitor.freqs)
modes1 = modes1.interp_in_freq(
freqs=freqs_sorted,
method=interp_spec1.method,
renormalize=port_modes_normalized, # Match port_modes normalization
recalculate_grid_correction=True,
assume_sorted=True,
)
# Interpolate modes2 if it has interp_spec and reduced data
if isinstance(modes2, ModeSolverData) and interp_spec2 is not None and modes2._reduced_data:
freqs_sorted = np.sort(modes2.monitor.freqs)
modes2 = modes2.interp_in_freq(
freqs=freqs_sorted,
method=interp_spec2.method,
renormalize=port_modes_normalized, # Match port_modes normalization
recalculate_grid_correction=True,
assume_sorted=True,
)
else:
# At least one is not ModeSolverData, use default interpolation
modes1, modes2 = modes1._interpolated_copies_if_needed(other=modes2)

modes_in_1 = "mode_index" in list(modes1.field_components.values())[0].coords
modes_in_2 = "mode_index" in list(modes2.field_components.values())[0].coords
Expand Down
42 changes: 38 additions & 4 deletions tidy3d/components/eme/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,21 @@ class EMECoefficientMonitor(EMEMonitor):
... size=(2,2,2),
... freqs=[300e12],
... num_modes=2,
... fields=['A', 'B'],
... name="eme_coeffs"
... )
"""

fields: tuple[Literal["A", "B", "n_complex", "flux", "interface_smatrices", "overlaps"], ...] = pd.Field(
("A", "B", "n_complex", "flux", "interface_smatrices", "overlaps"),
title="Coefficient Fields",
description="Collection of coefficient fields to store in the monitor. "
"Available fields: 'A' (forward mode coefficients), 'B' (backward mode coefficients), "
"'n_complex' (propagation indices), 'flux' (power flux), "
"'interface_smatrices' (S matrices at cell interfaces), "
"'overlaps' (mode overlaps).",
)

interval_space: tuple[Literal[1], Literal[1], Literal[1]] = pd.Field(
(1, 1, 1),
title="Spatial Interval",
Expand Down Expand Up @@ -296,10 +307,33 @@ def storage_size(
num_sweep: int,
) -> int:
"""Size of monitor storage given the number of points after discretization."""
bytes_single = (
4 * BYTES_COMPLEX * num_freqs * num_modes * num_modes * num_eme_cells * num_sweep
)
return bytes_single
bytes_total = 0

# A and B: each is (f, sweep, 2 ports, cells, modes_out, modes_in)
# Each field has 2 ports, so: 2 ports * cells * modes * modes
if "A" in self.fields:
bytes_total += 2 * BYTES_COMPLEX * num_freqs * num_sweep * num_eme_cells * num_modes * num_modes
if "B" in self.fields:
bytes_total += 2 * BYTES_COMPLEX * num_freqs * num_sweep * num_eme_cells * num_modes * num_modes

# n_complex and flux: (f, sweep, cells, modes)
if "n_complex" in self.fields:
bytes_total += BYTES_COMPLEX * num_freqs * num_sweep * num_eme_cells * num_modes
if "flux" in self.fields:
bytes_total += BYTES_COMPLEX * num_freqs * num_sweep * num_eme_cells * num_modes

# interface_smatrices: 4 S matrices (S11, S12, S21, S22), each (f, sweep, cells-1, modes, modes)
if "interface_smatrices" in self.fields:
num_interfaces = max(1, num_eme_cells - 1)
bytes_total += 4 * BYTES_COMPLEX * num_freqs * num_sweep * num_interfaces * num_modes * num_modes

# overlaps: O11 (f, sweep, cells, modes, modes) + O12, O21 (f, sweep, cells-1, modes, modes)
if "overlaps" in self.fields:
bytes_total += BYTES_COMPLEX * num_freqs * num_sweep * num_eme_cells * num_modes * num_modes # O11
num_interfaces = max(1, num_eme_cells - 1)
bytes_total += 2 * BYTES_COMPLEX * num_freqs * num_sweep * num_interfaces * num_modes * num_modes # O12, O21

return bytes_total


EMEMonitorType = Union[
Expand Down