FXC-4927 enable source differentiation #3197
base: develop
tests/test_components/autograd/numerical/test_autograd_source_numerical.py
2 files reviewed, 2 comments
Cursor Bugbot has reviewed your changes and found 1 potential issue.
                        scale = scale * step
        if dim not in dims_to_integrate and field_data.sizes.get(dim, 0) > 1:
            scale = scale / field_data.sizes[dim]
    return weights * scale
Unused exported function never used in production
Low Severity
The function compute_source_weights is defined, added to __all__, and has a unit test, but it is never actually imported or used anywhere in the production code. CustomCurrentSource._compute_derivatives only imports transpose_interp_field_to_dataset, and CustomFieldSource._compute_derivatives imports compute_spatial_weights (not compute_source_weights), get_frequency_omega, and transpose_interp_field_to_dataset. This is dead code that adds maintenance burden without being utilized.
Diff Coverage
Diff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/autograd/derivative_utils.py
Lines 1063-1071
1063 """
1064
1065 def _cell_size_weights(coord: np.ndarray) -> np.ndarray:
1066 if coord.size <= 1:
! 1067 return np.array([1.0], dtype=float)
1068 deltas = np.diff(coord)
1069 diff_left = np.pad(deltas, (1, 0), mode="edge")
1070 diff_right = np.pad(deltas, (0, 1), mode="edge")
1071 return 0.5 * (diff_left + diff_right)
Lines 1073-1081
1073 weight_dims = []
1074 weight_arrays = []
1075 for dim in dims:
1076 if dim not in arr.coords:
! 1077 continue
1078 coord = np.asarray(arr.coords[dim].data)
1079 if coord.size <= 1:
1080 continue
1081 weight_dims.append(dim)
Lines 1081-1089
1081 weight_dims.append(dim)
1082 weight_arrays.append(_cell_size_weights(coord))
1083
1084 if not weight_dims:
! 1085 return SpatialDataArray(1.0)
1086
1087 weights = np.ix_(*weight_arrays)
1088 weights_data = weights[0]
1089 for weight_array in weights[1:]:
Lines 1105-1117
1105 weights = compute_spatial_weights(field_data, dims=dims_to_integrate)
1106 scale = 1.0
1107 for axis, dim in enumerate("xyz"):
1108 if dim not in field_data.coords:
! 1109 continue
1110 if dim in dims_to_integrate and field_data.sizes.get(dim, 0) == 1:
1111 axis_size = float(source_size[axis])
1112 if axis_size > 0.0:
! 1113 scale = scale * axis_size
1114 elif axis_size == 0.0 and dim in adjoint_field.coords:
1115 coord_vals = np.asarray(adjoint_field.coords[dim].data)
1116 if coord_vals.size > 1:
1117 step = np.min(np.abs(np.diff(coord_vals)))
Lines 1117-1125
1117 step = np.min(np.abs(np.diff(coord_vals)))
1118 if np.isfinite(step) and step > 0.0:
1119 scale = scale * step
1120 if dim not in dims_to_integrate and field_data.sizes.get(dim, 0) > 1:
! 1121 scale = scale / field_data.sizes[dim]
1122 return weights * scale
1123
1124
1125 def transpose_interp_field_to_dataset(
Lines 1136-1145
1136 if target_freqs.size == source_freqs.size and np.allclose(
1137 target_freqs, source_freqs, rtol=1e-12, atol=0.0
1138 ):
1139 return field
! 1140 method = "nearest" if target_freqs.size <= 1 or source_freqs.size <= 1 else "linear"
! 1141 return field.interp(
1142 {"f": target_freqs},
1143 method=method,
1144 kwargs={"bounds_error": False, "fill_value": 0.0},
1145 ).fillna(0.0)
Lines 1149-1157
1149 ) -> np.ndarray:
1150 if param_coords_1d.size == 1:
1151 return field_values.sum(axis=0, keepdims=True)
1152 if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 1153 raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
1154
1155 n_param = param_coords_1d.size
1156 n_field = field_values.shape[0]
1157 field_values_2d = field_values.reshape(n_field, -1)
Lines 1199-1207
1199 values = np.asarray(weighted.data)
1200 dims = list(weighted.dims)
1201 for dim in "xyz":
1202 if dim not in field_coords or dim not in param_coords:
! 1203 continue
1204 axis_index = dims.index(dim)
1205 values = _interp_axis(values, axis_index, field_coords[dim], param_coords[dim])
1206
1207 out_coords = {dim: np.asarray(dataset_field.coords[dim].data) for dim in dataset_field.dims}
Lines 1206-1214
1206
1207 out_coords = {dim: np.asarray(dataset_field.coords[dim].data) for dim in dataset_field.dims}
1208 result = SpatialDataArray(values, coords=out_coords, dims=tuple(dims))
1209 if tuple(dims) != tuple(dataset_field.dims):
! 1210 result = result.transpose(*dataset_field.dims)
1211 return result
1212
1213
1214 def get_frequency_omega(
Lines 1217-1225
1217 """Return angular frequency aligned with field_data frequencies."""
1218 if "f" in field_data.dims:
1219 omega = 2 * np.pi * np.asarray(field_data.coords["f"].data)
1220 return FreqDataArray(omega, coords={"f": np.asarray(field_data.coords["f"].data)})
! 1221 return 2 * np.pi * float(np.asarray(frequencies).squeeze())
1222
1223
1224 __all__ = [
1225 "DerivativeInfo",
tidy3d/components/base.py
Lines 1222-1230
1222 # Handle multiple starting paths
1223 if paths:
1224 # If paths is a single tuple, convert to tuple of tuples
1225 if isinstance(paths[0], str):
! 1226 paths = (paths,)
1227
1228 # Process each starting path
1229 for starting_path in paths:
1230 # Navigate to the starting path in the dictionary
tidy3d/components/simulation.py
Lines 4894-4902
4894 structure_index_to_keys[index].append(fields)
4895 elif component_type == "sources":
4896 source_index_to_keys[index].append(fields)
4897 else:
! 4898 raise ValueError(
4899 f"Unknown component type '{component_type}' encountered while "
4900 "constructing adjoint monitors. "
4901 "Expected one of: 'structures', 'sources'."
4902 )
tidy3d/components/source/base.py
Lines 67-75
67 _warn_traced_size = _warn_unsupported_traced_argument("size")
68
69 def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
70 """Compute adjoint derivatives for source parameters."""
! 71 raise NotImplementedError(f"Can't compute derivative for 'Source': '{type(self)}'.")
72
73 @pydantic.validator("source_time", always=True)
74 def _freqs_lower_bound(cls, val):
75 """Raise validation error if central frequency is too low."""
tidy3d/components/source/current.py
Lines 230-238
230 transpose_interp_field_to_dataset,
231 )
232
233 if self.current_dataset is None:
! 234 return {tuple(path): 0.0 for path in derivative_info.paths}
235
236 derivative_map = {}
237 center = tuple(self.center)
238 h_adj = derivative_info.H_adj or {}
Lines 240-252
240
241 for field_path in derivative_info.paths:
242 field_path = tuple(field_path)
243 if len(field_path) < 2 or field_path[0] != "current_dataset":
! 244 log.warning(
245 f"Unsupported traced source path '{field_path}' for CustomCurrentSource."
246 )
! 247 derivative_map[field_path] = 0.0
! 248 continue
249
250 field_name = field_path[1]
251 if (
252 len(field_name) != 2
Lines 252-270
252 len(field_name) != 2
253 or field_name[0] not in ("E", "H")
254 or field_name[1] not in ("x", "y", "z")
255 ):
! 256 log.warning(f"Unsupported field component '{field_name}' in CustomCurrentSource.")
! 257 derivative_map[field_path] = 0.0
! 258 continue
259
260 field_data = getattr(self.current_dataset, field_name, None)
261 if field_data is None:
! 262 raise ValueError(f"Cannot find field '{field_name}' in current dataset.")
263
264 if field_name.startswith("H"):
! 265 adjoint_field = h_adj.get(field_name)
! 266 component_sign = -1.0
267 else: # "E" case
268 adjoint_field = e_adj.get(field_name)
269 component_sign = 1.0
tidy3d/components/source/field.py
Lines 254-262
254 transpose_interp_field_to_dataset,
255 )
256
257 if self.field_dataset is None:
! 258 return {tuple(path): 0.0 for path in derivative_info.paths}
259
260 derivative_map = {}
261 center = tuple(self.center)
262 e_adj = derivative_info.E_adj or {}
Lines 261-282
261 center = tuple(self.center)
262 e_adj = derivative_info.E_adj or {}
263 h_adj = derivative_info.H_adj or {}
264 if self.injection_axis is None:
! 265 return {tuple(path): 0.0 for path in derivative_info.paths}
266
267 for field_path in derivative_info.paths:
268 field_path = tuple(field_path)
269 if len(field_path) < 2 or field_path[0] != "field_dataset":
! 270 log.warning(f"Unsupported traced source path '{field_path}' for CustomFieldSource.")
! 271 derivative_map[field_path] = 0.0
! 272 continue
273
274 field_name = field_path[1]
275 field_data = getattr(self.field_dataset, field_name, None)
276 if field_data is None:
! 277 derivative_map[field_path] = 0.0
! 278 continue
279
280 if (
281 len(field_name) != 2
282 or field_name[0] not in ("E", "H")
Lines 281-296
281 len(field_name) != 2
282 or field_name[0] not in ("E", "H")
283 or field_name[1] not in ("x", "y", "z")
284 ):
! 285 log.warning(f"Unsupported field component '{field_name}' in CustomFieldSource.")
! 286 derivative_map[field_path] = 0.0
! 287 continue
288
289 component_axis = "xyz".index(field_name[1])
290 if component_axis == self.injection_axis:
! 291 derivative_map[field_path] = np.zeros_like(field_data.data)
! 292 continue
293
294 def _get_adjoint_and_sign(
295 *,
296 field_name: str,
Lines 304-312
304 e_vec = np.eye(3)[component_axis]
305 cross = np.cross(n_vec, e_vec)
306
307 if not np.any(cross):
! 308 return None, 0.0 # indicates "no gradient"
309
310 target_axis = int(np.flatnonzero(cross)[0])
311 component_sign = float(cross[target_axis])
Lines 313-322
313 if field_name.startswith("E"):
314 target_component = f"H{'xyz'[target_axis]}"
315 adjoint_field = h_adj.get(target_component)
316 else:
! 317 target_component = f"E{'xyz'[target_axis]}"
! 318 adjoint_field = e_adj.get(target_component)
319
320 return adjoint_field, component_sign
321
322 adjoint_field, component_sign = _get_adjoint_and_sign(
Lines 328-337
328 )
329
330 if component_sign == 0.0:
331 # no gradient for injection_axis == component_axis
! 332 derivative_map[field_path] = np.zeros_like(field_data.data)
! 333 continue
334
335 adjoint_on_dataset = transpose_interp_field_to_dataset(
336 adjoint_field, field_data, center=center
337 )
tidy3d/web/api/autograd/backward.py
Lines 132-140
132 sim_data_adj, sim_data_orig, sim_data_fwd, component_index, component_paths
133 )
134 )
135 else:
! 136 raise ValueError(
137 f"Unexpected component_type='{component_type}' for component_index={component_index}. "
138 "Expected 'structures' or 'sources'."
139 )
Lines 170-178
170 monitor_freqs = np.array(fld_adj.monitor.freqs)
171 if len(adjoint_frequencies) != len(monitor_freqs) or not np.allclose(
172 np.sort(adjoint_frequencies), np.sort(monitor_freqs), rtol=1e-10, atol=0
173 ):
! 174 raise ValueError(
175 f"Frequency mismatch in adjoint postprocessing for source {source_index}. "
176 f"Expected frequencies from monitor: {monitor_freqs}, "
177 f"but derivative map has: {adjoint_frequencies}. "
178 )
Lines 266-274
266 monitor_freqs = np.array(fld_adj.monitor.freqs)
267 if len(adjoint_frequencies) != len(monitor_freqs) or not np.allclose(
268 np.sort(adjoint_frequencies), np.sort(monitor_freqs), rtol=1e-10, atol=0
269 ):
! 270 raise ValueError(
271 f"Frequency mismatch in adjoint postprocessing for structure {structure_index}. "
272 f"Expected frequencies from monitor: {monitor_freqs}, "
273 f"but derivative map has: {adjoint_frequencies}. "
274 )
Lines 314-326
314 geometry_box = structure.geometry.bounding_box
315 background_structures_2d = []
316 sim_inf_background_medium = sim_orig.medium
317 if np.any(np.array(geometry_box.size) == 0.0):
! 318 zero_coordinate = tuple(geometry_box.size).index(0.0)
! 319 new_size = [td.inf, td.inf, td.inf]
! 320 new_size[zero_coordinate] = 0.0
321
! 322 background_structures_2d = [
323 structure.updated_copy(geometry=geometry_box.updated_copy(size=new_size))
324 ]
325 else:
326 sim_inf_background_medium = structure.medium
Lines 356-364
356 n_freqs = len(adjoint_frequencies)
357 if not freq_chunk_size or freq_chunk_size <= 0:
358 freq_chunk_size = n_freqs
359 else:
! 360 freq_chunk_size = min(freq_chunk_size, n_freqs)
361
362 # process in chunks
363 vjp_value_map = {}
Lines 431-443
431
432 # accumulate results
433 for path, value in vjp_chunk.items():
434 if path in vjp_value_map:
! 435 val = vjp_value_map[path]
! 436 if isinstance(val, (list, tuple)) and isinstance(value, (list, tuple)):
! 437 vjp_value_map[path] = type(val)(x + y for x, y in zip(val, value))
438 else:
! 439 vjp_value_map[path] += value
440 else:
441 vjp_value_map[path] = value
442 sim_fields_vjp = {}
443 # store vjps in output map
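The `_cell_size_weights` helper listed in the coverage report for derivative_utils.py assigns each grid point the average of its neighbouring cell widths, and the surrounding code combines the per-axis weights into a grid with `np.ix_`. The following is a minimal standalone sketch of that idea, assuming plain NumPy arrays instead of the `SpatialDataArray` type used in the PR. The names `cell_size_weights` and `outer_weights` are illustrative and are not part of tidy3d.

```python
import numpy as np


def cell_size_weights(coord):
    """Per-point weights: mean of the cell widths to the left and right.

    Edge points reuse the nearest interior width (np.pad with mode="edge"),
    mirroring the listing in the coverage report.
    """
    coord = np.asarray(coord, dtype=float)
    if coord.size <= 1:
        return np.array([1.0])
    deltas = np.diff(coord)
    diff_left = np.pad(deltas, (1, 0), mode="edge")
    diff_right = np.pad(deltas, (0, 1), mode="edge")
    return 0.5 * (diff_left + diff_right)


def outer_weights(*coords):
    """Combine 1D weights along several axes into one N-dimensional array."""
    arrays = [cell_size_weights(c) for c in coords]
    grids = np.ix_(*arrays)  # open mesh: each array reshaped for broadcasting
    out = grids[0]
    for grid in grids[1:]:
        out = out * grid
    return out


weights_x = cell_size_weights([0.0, 1.0, 3.0])
weights_xy = outer_weights([0.0, 1.0, 3.0], [0.0, 2.0])
```

For the non-uniform coordinate `[0.0, 1.0, 3.0]` the interior point receives the mean of its two adjacent widths, while the end points receive the full width of their single neighbouring cell.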
Implemented adjoint gradients for CustomCurrentSource.current_dataset and CustomFieldSource.field_dataset.
Here are some raw results from the numerical tests.
Note
Enables differentiation w.r.t. source field data and wires sources into the autograd forward/adjoint flow.
- _compute_derivatives for CustomCurrentSource.current_dataset and CustomFieldSource.field_dataset
- transpose_interp_field_to_dataset, compute_spatial_weights, compute_source_weights, get_frequency_omega
- Simulation._make_adjoint_monitors to create field monitors for sources; no eps monitors for sources
- _strip_traced_fields to accept multiple starting_paths and update autograd API (setup_run, forward/backward paths) to include sources
- postprocess_adj to process structures and sources separately; add source-time scaling for VJP
- SIM_FIELDS_KEYS
Written by Cursor Bugbot for commit 360ac7a.
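The change to `_strip_traced_fields` lets a caller pass either one starting path or several. The coverage listing for base.py shows the key step: a single tuple of strings is wrapped into a tuple of tuples before each path is walked. A minimal sketch of that normalization on a plain nested dict follows. The helper names `normalize_starting_paths` and `select_subtrees` are hypothetical and do not reflect the actual tidy3d signature.

```python
def normalize_starting_paths(paths):
    """Return a tuple of path tuples, whether one path or many were given."""
    if not paths:
        return ()
    # A single path such as ("structures",) has a string as its first element.
    if isinstance(paths[0], str):
        return (tuple(paths),)
    return tuple(tuple(path) for path in paths)


def select_subtrees(tree, paths):
    """Walk each starting path in a nested dict and collect the subtree found."""
    selected = {}
    for path in normalize_starting_paths(paths):
        node = tree
        for key in path:
            node = node[key]
        selected[path] = node
    return selected


sim_dict = {"structures": {"0": "box"}, "sources": {"0": "custom"}}
single = select_subtrees(sim_dict, ("structures",))
both = select_subtrees(sim_dict, (("structures",), ("sources",)))
```

Accepting both forms keeps existing single-path callers working while allowing structures and sources to be requested in one call.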
Greptile Overview
Greptile Summary
This PR implements adjoint gradient computation for CustomCurrentSource.current_dataset and CustomFieldSource.field_dataset, enabling automatic differentiation with respect to source field data. The implementation extends the existing autograd infrastructure to support sources in addition to structures.
Key Changes:
- _compute_derivatives() methods to CustomCurrentSource and CustomFieldSource that compute vector-Jacobian products (VJPs) by interpolating adjoint fields onto source datasets
- _make_adjoint_monitors() in Simulation to create field monitors for sources alongside existing structure monitors
- postprocess_adj() in backward.py to handle both structures and sources through separate processing functions
- transpose_interp_field_to_dataset(), compute_source_weights(), and get_frequency_omega() for source gradient computations
- _strip_traced_fields() in base.py to support multiple starting paths instead of a single path
Implementation Details:
For CustomCurrentSource, the gradient is computed as 0.5 * Re(source_time_scaling * adjoint_field * sign), where the sign depends on whether the component is E (+1) or H (-1).
For CustomFieldSource, the implementation uses the equivalence principle with cross products to determine the relationship between field components and injected currents, scaled by omega * epsilon_0 / cell_size.
The numerical test results in the PR description show angle differences between adjoint and finite-difference gradients ranging from 0.02° to 3.0°, indicating good agreement.
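The three statements above can be sketched with plain NumPy. This is an illustration under stated assumptions, not the PR's implementation: the function names are hypothetical, the adjoint field is assumed to be already interpolated onto the dataset grid, and the `omega * epsilon_0 / cell_size` scaling for CustomFieldSource is omitted. The cross-product logic follows the `_get_adjoint_and_sign` listing in the coverage report for field.py.

```python
import numpy as np


def current_source_vjp(adjoint_field, source_time_scaling, field_name):
    """0.5 * Re(scaling * adjoint * sign), with sign +1 for E and -1 for H."""
    sign = 1.0 if field_name.startswith("E") else -1.0
    return 0.5 * np.real(source_time_scaling * np.asarray(adjoint_field) * sign)


def field_source_target_and_sign(injection_axis, field_name):
    """Pick the adjoint component and sign from n x e (unit vectors).

    Returns (None, 0.0) when the component is parallel to the injection
    axis, which means "no gradient" for that component.
    """
    n_vec = np.eye(3)[injection_axis]
    e_vec = np.eye(3)["xyz".index(field_name[1])]
    cross = np.cross(n_vec, e_vec)
    if not np.any(cross):
        return None, 0.0
    target_axis = int(np.flatnonzero(cross)[0])
    sign = float(cross[target_axis])
    # An E dataset component pairs with an adjoint H component and vice versa.
    other = "H" if field_name.startswith("E") else "E"
    return f"{other}{'xyz'[target_axis]}", sign


def gradient_angle_deg(grad_a, grad_b):
    """Angle in degrees between two gradients, flattened to vectors."""
    a = np.ravel(grad_a).astype(float)
    b = np.ravel(grad_b).astype(float)
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))


vjp_e = current_source_vjp(np.array([1.0 + 2.0j]), 2.0, "Ex")
target, sign = field_source_target_and_sign(2, "Ey")
angle = gradient_angle_deg([1.0, 0.0], [0.0, 1.0])
```

An angle near zero means the adjoint and finite-difference gradients point in the same direction regardless of overall scale, which is why it is a convenient agreement metric for checks like the ones quoted above.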
Confidence Score: 4/5
- tidy3d/components/source/current.py and tidy3d/web/api/autograd/backward.py to align with coding standards
Important Files Changed
- _compute_derivatives method to CustomCurrentSource for adjoint gradient computation with proper field interpolation and scaling
- _compute_derivatives method to CustomFieldSource for adjoint gradient computation with cross-product based current scaling
- _process_source_gradients function with source time scaling
- _make_adjoint_monitors to create field monitors for sources in addition to structures
- _strip_traced_fields to support multiple starting paths instead of single path
- compute_source_weights, transpose_interp_field_to_dataset, and get_frequency_omega for source gradient computation
Sequence Diagram
sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant Simulation
    participant Source as CustomSource
    participant BackwardPass as Backward Pass
    participant DerivativeInfo
    User->>AutogradAPI: run with traced source parameters
    AutogradAPI->>Simulation: execute forward simulation
    Simulation->>Simulation: _make_adjoint_monitors()
    Simulation->>Simulation: create source field monitors
    Note over Simulation: Forward simulation runs
    User->>AutogradAPI: compute gradients (backward pass)
    AutogradAPI->>BackwardPass: setup_adj(data_fields_vjp)
    BackwardPass->>BackwardPass: filter traced fields
    BackwardPass->>Simulation: _make_adjoint_sims()
    Note over Simulation: Adjoint simulation runs
    BackwardPass->>BackwardPass: postprocess_adj()
    BackwardPass->>BackwardPass: _process_source_gradients()
    BackwardPass->>DerivativeInfo: create DerivativeInfo with E_adj, H_adj
    BackwardPass->>Source: _compute_derivatives(derivative_info)
    alt CustomCurrentSource
        Source->>Source: transpose_interp_field_to_dataset()
        Source->>Source: compute VJP with source_time_scaling
        Source-->>BackwardPass: derivative_map
    else CustomFieldSource
        Source->>Source: compute cross products (n x E, n x H)
        Source->>Source: transpose_interp_field_to_dataset()
        Source->>Source: apply current_scale (omega * epsilon_0)
        Source-->>BackwardPass: derivative_map
    end
    BackwardPass-->>AutogradAPI: sim_fields_vjp
    AutogradAPI-->>User: gradients w.r.t. source parameters