Skip to content

Conversation

@marcorudolphflex
Copy link
Contributor

@marcorudolphflex marcorudolphflex commented Jan 22, 2026

Implemented adjoint gradients for CustomCurrentSource.current_dataset and CustomFieldSource.field_dataset.

here some raw results from the numerical tests


[custom_field_vec_e_noise_(1.0, -0.5, 0.25)] grad_adjoint = [ 0.17309422 -0.08330589  0.        ]
[custom_field_vec_e_noise_(1.0, -0.5, 0.25)] grad_fd      = [ 0.15962869 -0.07763505  0.        ]
[custom_field_vec_e_noise_(1.0, -0.5, 0.25)] angle_deg    = 0.23552397635418437

[custom_field_vec_e_noise_(0.2, 0.7, -0.9)] grad_adjoint = [0.04982851 0.13126495 0.        ]
[custom_field_vec_e_noise_(0.2, 0.7, -0.9)] grad_fd      = [0.04645437 0.12250617 0.        ]
[custom_field_vec_e_noise_(0.2, 0.7, -0.9)] angle_deg    = 0.02015207770121852

[custom_field_vec_h_noise_(1.0, -0.5, 0.25)] grad_adjoint = [ 0.17501316 -0.08719919  0.        ]
[custom_field_vec_h_noise_(1.0, -0.5, 0.25)] grad_fd      = [ 0.16290694 -0.07942319  0.        ]
[custom_field_vec_h_noise_(1.0, -0.5, 0.25)] angle_deg    = 0.49353584832671316

[custom_current_vec_e_clean_(0.2, 0.7, -0.9)] grad_adjoint = [ 99.45885067 263.29361833 -67.95866614]
[custom_current_vec_e_clean_(0.2, 0.7, -0.9)] grad_fd      = [ 89.9887085  250.85449219 -50.69732666]
[custom_current_vec_e_clean_(0.2, 0.7, -0.9)] angle_deg    = 2.9567202337319016

[custom_current_vec_h_clean_(1.0, -0.5, 0.25)] grad_adjoint = [ 2.31835088e-08 -3.51533076e-09  3.38876971e-09]
[custom_current_vec_h_clean_(1.0, -0.5, 0.25)] grad_fd      = [ 2.15072404e-08 -3.56603636e-09  3.82804899e-09]
[custom_current_vec_h_clean_(1.0, -0.5, 0.25)] angle_deg    = 1.9038311935062815

[custom_current_vec_h_noise_(0.2, 0.7, -0.9)] grad_adjoint = [ 6.10464396e-09  1.36642137e-08 -4.35298592e-09]
[custom_current_vec_h_noise_(0.2, 0.7, -0.9)] grad_fd      = [ 6.24833518e-09  1.31894495e-08 -3.64153152e-09]
[custom_current_vec_h_noise_(0.2, 0.7, -0.9)] angle_deg    = 2.5278326000493307

[custom_field_vec_h_clean_(1.0, -0.5, 0.25)] grad_adjoint = [ 0.20265503 -0.057224    0.        ]
[custom_field_vec_h_clean_(1.0, -0.5, 0.25)] grad_fd      = [ 0.18961728 -0.05118549  0.        ]
[custom_field_vec_h_clean_(1.0, -0.5, 0.25)] angle_deg    = 0.6617374116909891

[custom_field_vec_h_noise_(0.2, 0.7, -0.9)] grad_adjoint = [0.0495895  0.12431687 0.        ]
[custom_field_vec_h_noise_(0.2, 0.7, -0.9)] grad_fd      = [0.04731119 0.11537224 0.        ]
[custom_field_vec_h_noise_(0.2, 0.7, -0.9)] angle_deg    = 0.5504151622847467

[custom_current_vec_h_clean_(0.2, 0.7, -0.9)] grad_adjoint = [ 5.26759494e-09  1.29291713e-08 -3.78763451e-09]
[custom_current_vec_h_clean_(0.2, 0.7, -0.9)] grad_fd      = [ 5.35793632e-09  1.24344979e-08 -3.17523785e-09]
[custom_current_vec_h_clean_(0.2, 0.7, -0.9)] angle_deg    = 2.2701833801929077

[custom_field_vec_h_clean_(0.2, 0.7, -0.9)] grad_adjoint = [0.07723136 0.15429205 0.        ]
[custom_field_vec_h_clean_(0.2, 0.7, -0.9)] grad_fd      = [0.07404014 0.14368445 0.        ]
[custom_field_vec_h_clean_(0.2, 0.7, -0.9)] angle_deg    = 0.6715144060362552

[custom_current_vec_e_clean_(1.0, -0.5, 0.25)] grad_adjoint = [474.49067083 -74.59588806  76.7800313 ]
[custom_current_vec_e_clean_(1.0, -0.5, 0.25)] grad_fd      = [444.25964355 -76.37023926  91.62902832]
[custom_current_vec_e_clean_(1.0, -0.5, 0.25)] angle_deg    = 2.539341918234507

[custom_current_vec_e_noise_(1.0, -0.5, 0.25)] grad_adjoint = [477.40864118 -40.94562373  65.26307893]
[custom_current_vec_e_noise_(1.0, -0.5, 0.25)] grad_fd      = [447.15881348 -44.02160645  84.38110352]
[custom_current_vec_e_noise_(1.0, -0.5, 0.25)] angle_deg    = 2.9664772611479786

[custom_field_vec_e_clean_(1.0, -0.5, 0.25)] grad_adjoint = [ 0.20166847 -0.05835947  0.        ]
[custom_field_vec_e_clean_(1.0, -0.5, 0.25)] grad_fd      = [ 0.18898398 -0.05148351  0.        ]
[custom_field_vec_e_clean_(1.0, -0.5, 0.25)] angle_deg    = 0.900684114155899

[custom_current_vec_h_noise_(1.0, -0.5, 0.25)] grad_adjoint = [ 2.40205526e-08 -2.78028553e-09  2.82341683e-09]
[custom_current_vec_h_noise_(1.0, -0.5, 0.25)] grad_fd      = [ 2.24620322e-08 -2.76223489e-09  3.33510997e-09]
[custom_current_vec_h_noise_(1.0, -0.5, 0.25)] angle_deg    = 1.7702686329253479

[custom_current_vec_e_noise_(0.2, 0.7, -0.9)] grad_adjoint = [102.37676179 296.94389099 -79.4756258 ]
[custom_current_vec_e_noise_(0.2, 0.7, -0.9)] grad_fd      = [ 93.15490723 283.50830078 -59.66186523]
[custom_current_vec_e_noise_(0.2, 0.7, -0.9)] angle_deg    = 3.0055478992711557

[custom_field_vec_e_clean_(0.2, 0.7, -0.9)] grad_adjoint = [0.07840275 0.15621138 0.        ]
[custom_field_vec_e_clean_(0.2, 0.7, -0.9)] grad_fd      = [0.07607043 0.14856458 0.        ]
[custom_field_vec_e_clean_(0.2, 0.7, -0.9)] angle_deg    = 0.46193485235168164


Note

Enables differentiation w.r.t. source field data and wires sources into the autograd forward/adjoint flow.

  • Implement _compute_derivatives for CustomCurrentSource.current_dataset and CustomFieldSource.field_dataset
  • Add adjoint utilities: transpose_interp_field_to_dataset, compute_spatial_weights, compute_source_weights, get_frequency_omega
  • Extend Simulation._make_adjoint_monitors to create field monitors for sources; no eps monitors for sources
  • Generalize _strip_traced_fields to accept multiple starting_paths and update autograd API (setup_run, forward/backward paths) to include sources
  • Refactor postprocess_adj to process structures and sources separately; add source-time scaling for VJP
  • Add extensive analytical and numerical tests validating source gradients; minor test fix for SIM_FIELDS_KEYS
  • Update changelog to mention new adjoint support

Written by Cursor Bugbot for commit 360ac7a. This will update automatically on new commits. Configure here.

Greptile Overview

Greptile Summary

This PR implements adjoint gradient computation for CustomCurrentSource.current_dataset and CustomFieldSource.field_dataset, enabling automatic differentiation with respect to source field data. The implementation extends the existing autograd infrastructure to support sources in addition to structures.

Key Changes:

  • Added _compute_derivatives() methods to CustomCurrentSource and CustomFieldSource that compute vector-Jacobian products (VJPs) by interpolating adjoint fields onto source datasets
  • Extended _make_adjoint_monitors() in Simulation to create field monitors for sources alongside existing structure monitors
  • Refactored postprocess_adj() in backward.py to handle both structures and sources through separate processing functions
  • Added utility functions transpose_interp_field_to_dataset(), compute_source_weights(), and get_frequency_omega() for source gradient computations
  • Modified _strip_traced_fields() in base.py to support multiple starting paths instead of a single path
  • Included comprehensive analytical and numerical tests validating the gradient implementation

Implementation Details:

For CustomCurrentSource, the gradient is computed as 0.5 * Re(source_time_scaling * adjoint_field * sign) where the sign depends on whether the component is E (+1) or H (-1).

For CustomFieldSource, the implementation uses the equivalence principle with cross products to determine the relationship between field components and injected currents, scaled by omega * epsilon_0 / cell_size.

The numerical test results in the PR description show angle differences between adjoint and finite-difference gradients ranging from 0.02° to 3.0°, indicating good agreement.

Confidence Score: 4/5

  • This PR is generally safe to merge with minor style improvements recommended
  • The implementation is well-structured with comprehensive test coverage (analytical and numerical tests). The gradient computation follows established patterns from structure gradients. Two minor style issues with error message formatting were identified. The numerical results show good agreement between adjoint and finite-difference methods (angles < 3°). The refactoring properly separates concerns between structures and sources.
  • Pay attention to error message formatting in tidy3d/components/source/current.py and tidy3d/web/api/autograd/backward.py to align with coding standards

Important Files Changed

Filename Overview
tidy3d/components/source/current.py Added _compute_derivatives method to CustomCurrentSource for adjoint gradient computation with proper field interpolation and scaling
tidy3d/components/source/field.py Added _compute_derivatives method to CustomFieldSource for adjoint gradient computation with cross-product based current scaling
tidy3d/web/api/autograd/backward.py Refactored adjoint processing to support both structures and sources, added _process_source_gradients function with source time scaling
tidy3d/components/simulation.py Extended _make_adjoint_monitors to create field monitors for sources in addition to structures
tidy3d/components/base.py Changed _strip_traced_fields to support multiple starting paths instead of single path
tidy3d/components/autograd/derivative_utils.py Added helper functions compute_source_weights, transpose_interp_field_to_dataset, and get_frequency_omega for source gradient computation

Sequence Diagram

sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant Simulation
    participant Source as CustomSource
    participant BackwardPass as Backward Pass
    participant DerivativeInfo
    
    User->>AutogradAPI: run with traced source parameters
    AutogradAPI->>Simulation: execute forward simulation
    Simulation->>Simulation: _make_adjoint_monitors()
    Simulation->>Simulation: create source field monitors
    
    Note over Simulation: Forward simulation runs
    
    User->>AutogradAPI: compute gradients (backward pass)
    AutogradAPI->>BackwardPass: setup_adj(data_fields_vjp)
    BackwardPass->>BackwardPass: filter traced fields
    BackwardPass->>Simulation: _make_adjoint_sims()
    
    Note over Simulation: Adjoint simulation runs
    
    BackwardPass->>BackwardPass: postprocess_adj()
    BackwardPass->>BackwardPass: _process_source_gradients()
    BackwardPass->>DerivativeInfo: create DerivativeInfo with E_adj, H_adj
    BackwardPass->>Source: _compute_derivatives(derivative_info)
    
    alt CustomCurrentSource
        Source->>Source: transpose_interp_field_to_dataset()
        Source->>Source: compute VJP with source_time_scaling
        Source-->>BackwardPass: derivative_map
    else CustomFieldSource
        Source->>Source: compute cross products (n x E, n x H)
        Source->>Source: transpose_interp_field_to_dataset()
        Source->>Source: apply current_scale (omega * epsilon_0)
        Source-->>BackwardPass: derivative_map
    end
    
    BackwardPass-->>AutogradAPI: sim_fields_vjp
    AutogradAPI-->>User: gradients w.r.t. source parameters
Loading

@marcorudolphflex marcorudolphflex force-pushed the FXC-4927-enable-source-differentiation branch from 9542484 to 5dec731 Compare January 23, 2026 10:18
@marcorudolphflex marcorudolphflex marked this pull request as ready for review January 23, 2026 10:23
Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@marcorudolphflex marcorudolphflex force-pushed the FXC-4927-enable-source-differentiation branch from 5dec731 to 6aa0c0b Compare January 23, 2026 10:29
cursor[bot]

This comment was marked as outdated.

@marcorudolphflex marcorudolphflex force-pushed the FXC-4927-enable-source-differentiation branch from 6aa0c0b to 7d69d2a Compare January 23, 2026 11:33
cursor[bot]

This comment was marked as outdated.

@marcorudolphflex marcorudolphflex force-pushed the FXC-4927-enable-source-differentiation branch from 7d69d2a to e93bb93 Compare January 23, 2026 11:45
cursor[bot]

This comment was marked as outdated.

@marcorudolphflex marcorudolphflex force-pushed the FXC-4927-enable-source-differentiation branch from e93bb93 to 360ac7a Compare January 23, 2026 12:30
Copy link

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

scale = scale * step
if dim not in dims_to_integrate and field_data.sizes.get(dim, 0) > 1:
scale = scale / field_data.sizes[dim]
return weights * scale
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused exported function never used in production

Low Severity

The function compute_source_weights is defined, added to __all__, and has a unit test, but it is never actually imported or used anywhere in the production code. CustomCurrentSource._compute_derivatives only imports transpose_interp_field_to_dataset, and CustomFieldSource._compute_derivatives imports compute_spatial_weights (not compute_source_weights), get_frequency_omega, and transpose_interp_field_to_dataset. This is dead code that adds maintenance burden without being utilized.

Fix in Cursor Fix in Web

@github-actions
Copy link
Contributor

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/derivative_utils.py (89.1%): Missing lines 1067,1077,1085,1109,1113,1121,1140-1141,1153,1203,1210,1221
  • tidy3d/components/base.py (90.9%): Missing lines 1226
  • tidy3d/components/simulation.py (94.4%): Missing lines 4898
  • tidy3d/components/source/base.py (66.7%): Missing lines 71
  • tidy3d/components/source/current.py (71.4%): Missing lines 234,244,247-248,256-258,262,265-266
  • tidy3d/components/source/field.py (71.2%): Missing lines 258,265,270-272,277-278,285-287,291-292,308,317-318,332-333
  • tidy3d/web/api/autograd/backward.py (91.0%): Missing lines 136,174,270,318-320,322,360,435-437,439

Summary

  • Total: 370 lines
  • Missing: 54 lines
  • Coverage: 85%

tidy3d/components/autograd/derivative_utils.py

Lines 1063-1071

  1063     """
  1064 
  1065     def _cell_size_weights(coord: np.ndarray) -> np.ndarray:
  1066         if coord.size <= 1:
! 1067             return np.array([1.0], dtype=float)
  1068         deltas = np.diff(coord)
  1069         diff_left = np.pad(deltas, (1, 0), mode="edge")
  1070         diff_right = np.pad(deltas, (0, 1), mode="edge")
  1071         return 0.5 * (diff_left + diff_right)

Lines 1073-1081

  1073     weight_dims = []
  1074     weight_arrays = []
  1075     for dim in dims:
  1076         if dim not in arr.coords:
! 1077             continue
  1078         coord = np.asarray(arr.coords[dim].data)
  1079         if coord.size <= 1:
  1080             continue
  1081         weight_dims.append(dim)

Lines 1081-1089

  1081         weight_dims.append(dim)
  1082         weight_arrays.append(_cell_size_weights(coord))
  1083 
  1084     if not weight_dims:
! 1085         return SpatialDataArray(1.0)
  1086 
  1087     weights = np.ix_(*weight_arrays)
  1088     weights_data = weights[0]
  1089     for weight_array in weights[1:]:

Lines 1105-1117

  1105     weights = compute_spatial_weights(field_data, dims=dims_to_integrate)
  1106     scale = 1.0
  1107     for axis, dim in enumerate("xyz"):
  1108         if dim not in field_data.coords:
! 1109             continue
  1110         if dim in dims_to_integrate and field_data.sizes.get(dim, 0) == 1:
  1111             axis_size = float(source_size[axis])
  1112             if axis_size > 0.0:
! 1113                 scale = scale * axis_size
  1114             elif axis_size == 0.0 and dim in adjoint_field.coords:
  1115                 coord_vals = np.asarray(adjoint_field.coords[dim].data)
  1116                 if coord_vals.size > 1:
  1117                     step = np.min(np.abs(np.diff(coord_vals)))

Lines 1117-1125

  1117                     step = np.min(np.abs(np.diff(coord_vals)))
  1118                     if np.isfinite(step) and step > 0.0:
  1119                         scale = scale * step
  1120         if dim not in dims_to_integrate and field_data.sizes.get(dim, 0) > 1:
! 1121             scale = scale / field_data.sizes[dim]
  1122     return weights * scale
  1123 
  1124 
  1125 def transpose_interp_field_to_dataset(

Lines 1136-1145

  1136         if target_freqs.size == source_freqs.size and np.allclose(
  1137             target_freqs, source_freqs, rtol=1e-12, atol=0.0
  1138         ):
  1139             return field
! 1140         method = "nearest" if target_freqs.size <= 1 or source_freqs.size <= 1 else "linear"
! 1141         return field.interp(
  1142             {"f": target_freqs},
  1143             method=method,
  1144             kwargs={"bounds_error": False, "fill_value": 0.0},
  1145         ).fillna(0.0)

Lines 1149-1157

  1149     ) -> np.ndarray:
  1150         if param_coords_1d.size == 1:
  1151             return field_values.sum(axis=0, keepdims=True)
  1152         if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 1153             raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
  1154 
  1155         n_param = param_coords_1d.size
  1156         n_field = field_values.shape[0]
  1157         field_values_2d = field_values.reshape(n_field, -1)

Lines 1199-1207

  1199     values = np.asarray(weighted.data)
  1200     dims = list(weighted.dims)
  1201     for dim in "xyz":
  1202         if dim not in field_coords or dim not in param_coords:
! 1203             continue
  1204         axis_index = dims.index(dim)
  1205         values = _interp_axis(values, axis_index, field_coords[dim], param_coords[dim])
  1206 
  1207     out_coords = {dim: np.asarray(dataset_field.coords[dim].data) for dim in dataset_field.dims}

Lines 1206-1214

  1206 
  1207     out_coords = {dim: np.asarray(dataset_field.coords[dim].data) for dim in dataset_field.dims}
  1208     result = SpatialDataArray(values, coords=out_coords, dims=tuple(dims))
  1209     if tuple(dims) != tuple(dataset_field.dims):
! 1210         result = result.transpose(*dataset_field.dims)
  1211     return result
  1212 
  1213 
  1214 def get_frequency_omega(

Lines 1217-1225

  1217     """Return angular frequency aligned with field_data frequencies."""
  1218     if "f" in field_data.dims:
  1219         omega = 2 * np.pi * np.asarray(field_data.coords["f"].data)
  1220         return FreqDataArray(omega, coords={"f": np.asarray(field_data.coords["f"].data)})
! 1221     return 2 * np.pi * float(np.asarray(frequencies).squeeze())
  1222 
  1223 
  1224 __all__ = [
  1225     "DerivativeInfo",

tidy3d/components/base.py

Lines 1222-1230

  1222         # Handle multiple starting paths
  1223         if paths:
  1224             # If paths is a single tuple, convert to tuple of tuples
  1225             if isinstance(paths[0], str):
! 1226                 paths = (paths,)
  1227 
  1228             # Process each starting path
  1229             for starting_path in paths:
  1230                 # Navigate to the starting path in the dictionary

tidy3d/components/simulation.py

Lines 4894-4902

  4894                 structure_index_to_keys[index].append(fields)
  4895             elif component_type == "sources":
  4896                 source_index_to_keys[index].append(fields)
  4897             else:
! 4898                 raise ValueError(
  4899                     f"Unknown component type '{component_type}' encountered while "
  4900                     "constructing adjoint monitors. "
  4901                     "Expected one of: 'structures', 'sources'."
  4902                 )

tidy3d/components/source/base.py

Lines 67-75

  67     _warn_traced_size = _warn_unsupported_traced_argument("size")
  68 
  69     def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
  70         """Compute adjoint derivatives for source parameters."""
! 71         raise NotImplementedError(f"Can't compute derivative for 'Source': '{type(self)}'.")
  72 
  73     @pydantic.validator("source_time", always=True)
  74     def _freqs_lower_bound(cls, val):
  75         """Raise validation error if central frequency is too low."""

tidy3d/components/source/current.py

Lines 230-238

  230             transpose_interp_field_to_dataset,
  231         )
  232 
  233         if self.current_dataset is None:
! 234             return {tuple(path): 0.0 for path in derivative_info.paths}
  235 
  236         derivative_map = {}
  237         center = tuple(self.center)
  238         h_adj = derivative_info.H_adj or {}

Lines 240-252

  240 
  241         for field_path in derivative_info.paths:
  242             field_path = tuple(field_path)
  243             if len(field_path) < 2 or field_path[0] != "current_dataset":
! 244                 log.warning(
  245                     f"Unsupported traced source path '{field_path}' for CustomCurrentSource."
  246                 )
! 247                 derivative_map[field_path] = 0.0
! 248                 continue
  249 
  250             field_name = field_path[1]
  251             if (
  252                 len(field_name) != 2

Lines 252-270

  252                 len(field_name) != 2
  253                 or field_name[0] not in ("E", "H")
  254                 or field_name[1] not in ("x", "y", "z")
  255             ):
! 256                 log.warning(f"Unsupported field component '{field_name}' in CustomCurrentSource.")
! 257                 derivative_map[field_path] = 0.0
! 258                 continue
  259 
  260             field_data = getattr(self.current_dataset, field_name, None)
  261             if field_data is None:
! 262                 raise ValueError(f"Cannot find field '{field_name}' in current dataset.")
  263 
  264             if field_name.startswith("H"):
! 265                 adjoint_field = h_adj.get(field_name)
! 266                 component_sign = -1.0
  267             else:  # "E" case
  268                 adjoint_field = e_adj.get(field_name)
  269                 component_sign = 1.0

tidy3d/components/source/field.py

Lines 254-262

  254             transpose_interp_field_to_dataset,
  255         )
  256 
  257         if self.field_dataset is None:
! 258             return {tuple(path): 0.0 for path in derivative_info.paths}
  259 
  260         derivative_map = {}
  261         center = tuple(self.center)
  262         e_adj = derivative_info.E_adj or {}

Lines 261-282

  261         center = tuple(self.center)
  262         e_adj = derivative_info.E_adj or {}
  263         h_adj = derivative_info.H_adj or {}
  264         if self.injection_axis is None:
! 265             return {tuple(path): 0.0 for path in derivative_info.paths}
  266 
  267         for field_path in derivative_info.paths:
  268             field_path = tuple(field_path)
  269             if len(field_path) < 2 or field_path[0] != "field_dataset":
! 270                 log.warning(f"Unsupported traced source path '{field_path}' for CustomFieldSource.")
! 271                 derivative_map[field_path] = 0.0
! 272                 continue
  273 
  274             field_name = field_path[1]
  275             field_data = getattr(self.field_dataset, field_name, None)
  276             if field_data is None:
! 277                 derivative_map[field_path] = 0.0
! 278                 continue
  279 
  280             if (
  281                 len(field_name) != 2
  282                 or field_name[0] not in ("E", "H")

Lines 281-296

  281                 len(field_name) != 2
  282                 or field_name[0] not in ("E", "H")
  283                 or field_name[1] not in ("x", "y", "z")
  284             ):
! 285                 log.warning(f"Unsupported field component '{field_name}' in CustomFieldSource.")
! 286                 derivative_map[field_path] = 0.0
! 287                 continue
  288 
  289             component_axis = "xyz".index(field_name[1])
  290             if component_axis == self.injection_axis:
! 291                 derivative_map[field_path] = np.zeros_like(field_data.data)
! 292                 continue
  293 
  294             def _get_adjoint_and_sign(
  295                 *,
  296                 field_name: str,

Lines 304-312

  304                 e_vec = np.eye(3)[component_axis]
  305                 cross = np.cross(n_vec, e_vec)
  306 
  307                 if not np.any(cross):
! 308                     return None, 0.0  # indicates "no gradient"
  309 
  310                 target_axis = int(np.flatnonzero(cross)[0])
  311                 component_sign = float(cross[target_axis])

Lines 313-322

  313                 if field_name.startswith("E"):
  314                     target_component = f"H{'xyz'[target_axis]}"
  315                     adjoint_field = h_adj.get(target_component)
  316                 else:
! 317                     target_component = f"E{'xyz'[target_axis]}"
! 318                     adjoint_field = e_adj.get(target_component)
  319 
  320                 return adjoint_field, component_sign
  321 
  322             adjoint_field, component_sign = _get_adjoint_and_sign(

Lines 328-337

  328             )
  329 
  330             if component_sign == 0.0:
  331                 # no gradient for injection_axis == component_axis
! 332                 derivative_map[field_path] = np.zeros_like(field_data.data)
! 333                 continue
  334 
  335             adjoint_on_dataset = transpose_interp_field_to_dataset(
  336                 adjoint_field, field_data, center=center
  337             )

tidy3d/web/api/autograd/backward.py

Lines 132-140

  132                     sim_data_adj, sim_data_orig, sim_data_fwd, component_index, component_paths
  133                 )
  134             )
  135         else:
! 136             raise ValueError(
  137                 f"Unexpected component_type='{component_type}' for component_index={component_index}. "
  138                 "Expected 'structures' or 'sources'."
  139             )

Lines 170-178

  170     monitor_freqs = np.array(fld_adj.monitor.freqs)
  171     if len(adjoint_frequencies) != len(monitor_freqs) or not np.allclose(
  172         np.sort(adjoint_frequencies), np.sort(monitor_freqs), rtol=1e-10, atol=0
  173     ):
! 174         raise ValueError(
  175             f"Frequency mismatch in adjoint postprocessing for source {source_index}. "
  176             f"Expected frequencies from monitor: {monitor_freqs}, "
  177             f"but derivative map has: {adjoint_frequencies}. "
  178         )

Lines 266-274

  266     monitor_freqs = np.array(fld_adj.monitor.freqs)
  267     if len(adjoint_frequencies) != len(monitor_freqs) or not np.allclose(
  268         np.sort(adjoint_frequencies), np.sort(monitor_freqs), rtol=1e-10, atol=0
  269     ):
! 270         raise ValueError(
  271             f"Frequency mismatch in adjoint postprocessing for structure {structure_index}. "
  272             f"Expected frequencies from monitor: {monitor_freqs}, "
  273             f"but derivative map has: {adjoint_frequencies}. "
  274         )

Lines 314-326

  314         geometry_box = structure.geometry.bounding_box
  315         background_structures_2d = []
  316         sim_inf_background_medium = sim_orig.medium
  317         if np.any(np.array(geometry_box.size) == 0.0):
! 318             zero_coordinate = tuple(geometry_box.size).index(0.0)
! 319             new_size = [td.inf, td.inf, td.inf]
! 320             new_size[zero_coordinate] = 0.0
  321 
! 322             background_structures_2d = [
  323                 structure.updated_copy(geometry=geometry_box.updated_copy(size=new_size))
  324             ]
  325         else:
  326             sim_inf_background_medium = structure.medium

Lines 356-364

  356     n_freqs = len(adjoint_frequencies)
  357     if not freq_chunk_size or freq_chunk_size <= 0:
  358         freq_chunk_size = n_freqs
  359     else:
! 360         freq_chunk_size = min(freq_chunk_size, n_freqs)
  361 
  362     # process in chunks
  363     vjp_value_map = {}

Lines 431-443

  431 
  432         # accumulate results
  433         for path, value in vjp_chunk.items():
  434             if path in vjp_value_map:
! 435                 val = vjp_value_map[path]
! 436                 if isinstance(val, (list, tuple)) and isinstance(value, (list, tuple)):
! 437                     vjp_value_map[path] = type(val)(x + y for x, y in zip(val, value))
  438                 else:
! 439                     vjp_value_map[path] += value
  440             else:
  441                 vjp_value_map[path] = value
  442     sim_fields_vjp = {}
  443     # store vjps in output map

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants