-
Notifications
You must be signed in to change notification settings - Fork 70
feat(tidy3d): FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint #3208
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
|
@greptile |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
1 file reviewed, 1 comment
8386a06 to
0c9fe1b
Compare
|
Technically this is still semi-drafty; it is marked as ready only so that Cursor Bugbot runs. I will re-request review when it is really ready. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 5 comments
Diff Coverage
Diff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/autograd/parallel_adjoint_bases.pyLines 18-26 18
19 def _coord_index(coord_values: np.ndarray, target: object) -> int:
20 values = np.asarray(coord_values)
21 if values.size == 0:
! 22 raise ValueError("No coordinate values available to index.")
23 if values.dtype.kind in ("f", "c"):
24 matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
25 else:
26 matches = np.where(values == target)[0]Lines 24-32 24 matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
25 else:
26 matches = np.where(values == target)[0]
27 if matches.size == 0:
! 28 raise ValueError(f"Could not find coordinate value {target!r} in {values}.")
29 return int(matches[0])
30
31
32 def _index_for_dims(data_array: DataArray, coord_map: dict[str, object]) -> tuple[int, ...]:Lines 57-65 57 self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
58 ) -> complex:
59 vjp = data_fields_vjp.get(self.data_path)
60 if vjp is None:
! 61 return 0.0 + 0.0j
62 data_index = self._data_index_from_sim_data(sim_data_orig)
63 vjp_array = np.asarray(vjp)
64 value = complex(vjp_array[data_index])
65 return valueLines 68-80 68 self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
69 ) -> None:
70 vjp = data_fields_vjp.get(self.data_path)
71 if vjp is None:
! 72 return
73 vjp_array = np.asarray(vjp)
74 vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
75 if vjp_array is not vjp:
! 76 data_fields_vjp[self.data_path] = vjp_array
77
78
79 @dataclass(frozen=True)
80 class DiffractionAdjointBasis:Lines 86-129 86 polarization: Literal["s", "p"]
87 data_path: tuple
88
89 def _data_index_from_sim_data(self, sim_data_orig: SimulationData) -> tuple[int, ...]:
! 90 diff_data = sim_data_orig.data[self.monitor_index]
! 91 dataset_name = self.data_path[-1]
! 92 field_data = getattr(diff_data, dataset_name)
! 93 coord_map = {
94 "orders_x": int(self.order_x),
95 "orders_y": int(self.order_y),
96 "f": float(self.freq),
97 }
! 98 return _index_for_dims(field_data, coord_map)
99
100 def vjp_value(
101 self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData, norm: np.ndarray
102 ) -> complex:
! 103 vjp = data_fields_vjp.get(self.data_path)
! 104 if vjp is None:
! 105 return 0.0 + 0.0j
! 106 try:
! 107 data_index = self._data_index_from_sim_data(sim_data_orig)
! 108 except ValueError:
! 109 return 0.0 + 0.0j
! 110 return complex(np.asarray(vjp)[data_index] * norm[data_index])
111
112 def zero_vjp_entry(
113 self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
114 ) -> None:
! 115 vjp = data_fields_vjp.get(self.data_path)
! 116 if vjp is None:
! 117 return
! 118 try:
! 119 data_index = self._data_index_from_sim_data(sim_data_orig)
! 120 except ValueError:
! 121 return
! 122 vjp_array = np.asarray(vjp)
! 123 vjp_array[data_index] = 0.0
! 124 if vjp_array is not vjp:
! 125 data_fields_vjp[self.data_path] = vjp_array
126
127
128 @dataclass(frozen=True)
129 class PointFieldAdjointBasis:Lines 156-168 156 self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
157 ) -> None:
158 vjp = data_fields_vjp.get(self.data_path)
159 if vjp is None:
! 160 return
161 vjp_array = np.asarray(vjp)
162 vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
163 if vjp_array is not vjp:
! 164 data_fields_vjp[self.data_path] = vjp_array
165
166
167 ParallelAdjointBasis = ModeAdjointBasis | DiffractionAdjointBasis | PointFieldAdjointBasisLines 200-208 200 ) -> list[PointFieldAdjointBasis]:
201 bases: list[PointFieldAdjointBasis] = []
202 for component, freqs in component_freqs:
203 if component not in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
! 204 continue
205 for freq in freqs:
206 bases.append(
207 PointFieldAdjointBasis(
208 monitor_index=monitor_index,Lines 229-241 229 for order_x in orders_x:
230 for order_y in orders_y:
231 angle_theta = float(theta_for(int(order_x), int(order_y)))
232 if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 233 continue
234 for pol in pols:
235 pol_str = str(pol)
236 if pol_str not in ("s", "p"):
! 237 continue
238 dataset_name = "Ephi" if pol_str == "s" else "Etheta"
239 bases.append(
240 DiffractionAdjointBasis(
241 monitor_index=monitor_index,tidy3d/components/autograd/source_factory.pyLines 20-28 20 def flip_direction(direction: object) -> str:
21 if hasattr(direction, "values"):
22 direction = str(direction.values)
23 if direction not in ("+", "-"):
! 24 raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
25 return "-" if direction == "+" else "+"
26
27
28 def adjoint_fwidth_from_simulation(simulation: Simulation) -> float:Lines 75-85 75 coefficient: complex,
76 fwidth: float,
77 ) -> CustomCurrentSource | None:
78 if any(simulation.symmetry):
! 79 raise ValueError("Point-field adjoint sources require symmetry to be disabled.")
80 if not monitor.colocate:
! 81 raise ValueError("Point-field adjoint sources require colocated field monitors.")
82
83 grid = simulation.discretize_monitor(monitor)
84 coords = {}
85 spatial_coords = grid.boundariesLines 87-95 87 for axis, dim in enumerate("xyz"):
88 if monitor.size[axis] == 0:
89 coords[dim] = np.array([monitor.center[axis]])
90 else:
! 91 coords[dim] = np.array(spatial_coords_dict[dim][:-1])
92 values = (
93 2
94 * -1j
95 * coefficientLines 116-124 116 values *= scaling_factor
117 values = np.nan_to_num(values, nan=0.0)
118
119 if np.all(values == 0):
! 120 return None
121
122 dataset = FieldDataset(**{component: ScalarFieldDataArray(values, coords=coords)})
123 return CustomCurrentSource(
124 center=monitor.geometry.center,Lines 129-149 129 )
130
131
132 def diffraction_monitor_medium(simulation: Simulation, monitor: DiffractionMonitor) -> object:
! 133 structures = [simulation.scene.background_structure, *list(simulation.structures or ())]
! 134 mediums = simulation.scene.intersecting_media(monitor, structures)
! 135 if len(mediums) != 1:
! 136 raise ValueError("Diffraction monitor plane must be homogeneous to build adjoint sources.")
! 137 return list(mediums)[0]
138
139
140 def bloch_vec_for_axis(simulation: Simulation, axis_name: str) -> float:
! 141 boundary = simulation.boundary_spec[axis_name]
! 142 plus = boundary.plus
! 143 if hasattr(plus, "bloch_vec"):
! 144 return float(plus.bloch_vec)
! 145 return 0.0
146
147
148 def diffraction_order_range(
149 size: float, bloch_vec: float, freq: float, medium: objectLines 147-164 147
148 def diffraction_order_range(
149 size: float, bloch_vec: float, freq: float, medium: object
150 ) -> np.ndarray:
! 151 if size == 0:
! 152 return np.array([0], dtype=int)
! 153 eps = medium.eps_model(freq)
! 154 index = np.real(np.sqrt(eps))
! 155 limit = abs(index) * freq * size / C_0
! 156 order_min = int(np.ceil(-limit - bloch_vec))
! 157 order_max = int(np.floor(limit - bloch_vec))
! 158 if order_max < order_min:
! 159 return np.array([], dtype=int)
! 160 return np.arange(order_min, order_max + 1, dtype=int)
161
162
163 def diffraction_source_from_simulation(
164 simulation: Simulation,Lines 169-188 169 polarization: Literal["s", "p"],
170 coefficient: complex,
171 fwidth: float,
172 ) -> PlaneWave:
! 173 medium = diffraction_monitor_medium(simulation, monitor)
! 174 axis_names = ("x", "y", "z")
! 175 normal_axis = monitor.normal_axis
! 176 transverse_axes = [axis_names[i] for i in range(3) if i != normal_axis]
! 177 axis_x, axis_y = transverse_axes
178
! 179 size_x = simulation.size[axis_names.index(axis_x)]
! 180 size_y = simulation.size[axis_names.index(axis_y)]
! 181 bloch_vec_x = bloch_vec_for_axis(simulation, axis_x)
! 182 bloch_vec_y = bloch_vec_for_axis(simulation, axis_y)
183
! 184 ux = DiffractionData.reciprocal_coords(
185 orders=np.array([order_x]),
186 size=size_x,
187 bloch_vec=bloch_vec_x,
188 f=freq,Lines 187-195 187 bloch_vec=bloch_vec_x,
188 f=freq,
189 medium=medium,
190 )
! 191 uy = DiffractionData.reciprocal_coords(
192 orders=np.array([order_y]),
193 size=size_y,
194 bloch_vec=bloch_vec_y,
195 f=freq,Lines 194-210 194 bloch_vec=bloch_vec_y,
195 f=freq,
196 medium=medium,
197 )
! 198 theta_vals, phi_vals = DiffractionData.compute_angles((ux, uy))
! 199 angle_theta = float(theta_vals[0, 0, 0])
! 200 angle_phi = float(phi_vals[0, 0, 0])
! 201 if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 202 raise ValueError("Adjoint source not available for evanescent diffraction order.")
203
! 204 pol_angle = 0.0 if polarization == "p" else np.pi / 2
! 205 bck_eps = medium.eps_model(freq)
! 206 return _diffraction_plane_wave(
207 monitor=monitor,
208 freq=freq,
209 angle_theta=angle_theta,
210 angle_phi=angle_phi,Lines 230-242 230 angle_theta = float(theta_data.sel(**angle_sel_kwargs))
231 angle_phi = float(phi_data.sel(**angle_sel_kwargs))
232
233 if np.isnan(angle_theta):
! 234 return None
235
236 pol_str = str(polarization)
237 if pol_str not in ("p", "s"):
! 238 raise ValueError(f"Something went wrong, given pol='{pol_str}' in adjoint source.")
239
240 pol_angle = 0.0 if pol_str == "p" else np.pi / 2
241 bck_eps = diff_data.medium.eps_model(freq)
242 return _diffraction_plane_wave(Lines 283-291 283 )
284
285
286 def diffraction_norm(diffraction_data: DiffractionData) -> np.ndarray:
! 287 theta_data, _ = diffraction_data.angles
! 288 cos_theta = np.cos(np.nan_to_num(theta_data))
! 289 cos_theta[cos_theta <= COS_THETA_THRESH] = np.inf
! 290 return 1.0 / np.sqrt(2.0 * np.asarray(diffraction_data.eta)) / np.sqrt(cos_theta)tidy3d/components/autograd/utils.pyLines 62-70 62 if k in target:
63 val = target[k]
64 if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
65 if len(val) != len(v):
! 66 raise ValueError(
67 f"Cannot accumulate field map for key '{k}': "
68 f"length mismatch ({len(val)} vs {len(v)})."
69 )
70 target[k] = type(val)(x + y for x, y in zip(val, v))tidy3d/components/data/monitor_data.pyLines 174-186 174 return []
175
176 def supports_parallel_adjoint(self) -> bool:
177 """Return ``True`` if this monitor data supports parallel adjoint sources."""
! 178 return False
179
180 def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
181 """Return parallel adjoint bases for this monitor data."""
! 182 return []
183
184 @staticmethod
185 def get_amplitude(x) -> complex:
186 """Get the complex amplitude out of some data."""Lines 1463-1471 1463
1464 def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
1465 """Return parallel adjoint bases for single-point field monitors."""
1466 if not self.supports_parallel_adjoint():
! 1467 return []
1468 component_freqs = [
1469 (str(component), data_array.coords["f"].values)
1470 for component, data_array in self.field_components.items()
1471 ]Lines 1875-1883 1875 return val
1876
1877 def supports_parallel_adjoint(self) -> bool:
1878 """Return ``True`` for mode monitor amplitude adjoints."""
! 1879 return True
1880
1881 def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
1882 """Return parallel adjoint bases for mode monitor amplitudes."""
1883 amps = self.ampsLines 4106-4114 4106 return DataArray(np.stack([amp_phi, amp_theta], axis=3), coords=coords)
4107
4108 def supports_parallel_adjoint(self) -> bool:
4109 """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 4110 return True
4111
4112 def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
4113 """Return parallel adjoint bases for diffraction monitor amplitudes."""
4114 amps = self.ampstidy3d/components/monitor.pyLines 64-72 64 WINDOW_FACTOR = 15
65
66
67 def _shifted_orders(orders: np.ndarray, bloch_vec: float) -> np.ndarray:
! 68 return bloch_vec + np.atleast_2d(orders).T
69
70
71 def _reciprocal_coords(
72 orders: np.ndarray, size: float, bloch_vec: float, freq: float, medium: MediumTypeLines 70-82 70
71 def _reciprocal_coords(
72 orders: np.ndarray, size: float, bloch_vec: float, freq: float, medium: MediumType
73 ) -> np.ndarray:
! 74 if size == 0:
! 75 return np.atleast_2d(0)
! 76 epsilon = medium.eps_model(freq)
! 77 bloch_array = _shifted_orders(orders, bloch_vec)
! 78 return bloch_array / size * C_0 / freq / np.real(np.sqrt(epsilon))
79
80
81 def _compute_angles(
82 reciprocal_vectors: tuple[np.ndarray, np.ndarray],Lines 80-94 80
81 def _compute_angles(
82 reciprocal_vectors: tuple[np.ndarray, np.ndarray],
83 ) -> tuple[np.ndarray, np.ndarray]:
! 84 with warnings.catch_warnings():
! 85 warnings.filterwarnings(
86 "ignore", message="invalid value encountered in arcsin", category=RuntimeWarning
87 )
! 88 ux, uy = reciprocal_vectors
! 89 thetas, phis = Geometry.kspace_2_sph(ux[:, None, :], uy[None, :, :], axis=2)
! 90 return (thetas, phis)
91
92
93 class Monitor(AbstractMonitor):
94 """Abstract base class for monitors."""Lines 1112-1120 1112 return amps_size + fields_size
1113
1114 def supports_parallel_adjoint(self) -> bool:
1115 """Return ``True`` for mode monitor amplitude adjoints."""
! 1116 return True
1117
1118 def parallel_adjoint_bases(
1119 self, simulation: Simulation, monitor_index: int
1120 ) -> list[ParallelAdjointBasis]:Lines 1853-1861 1853 return BYTES_COMPLEX * len(self.ux) * len(self.uy) * len(self.freqs) * 6
1854
1855 def supports_parallel_adjoint(self) -> bool:
1856 """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 1857 return True
1858
1859 def parallel_adjoint_bases(
1860 self, simulation: Simulation, monitor_index: int
1861 ) -> list[ParallelAdjointBasis]:Lines 1859-1867 1859 def parallel_adjoint_bases(
1860 self, simulation: Simulation, monitor_index: int
1861 ) -> list[ParallelAdjointBasis]:
1862 """Return parallel adjoint bases for diffraction monitor amplitudes."""
! 1863 from tidy3d.components.autograd.source_factory import (
1864 bloch_vec_for_axis,
1865 diffraction_monitor_medium,
1866 diffraction_order_range,
1867 )Lines 1865-1902 1865 diffraction_monitor_medium,
1866 diffraction_order_range,
1867 )
1868
! 1869 medium = diffraction_monitor_medium(simulation, self)
1870
! 1871 axis_names = ("x", "y", "z")
! 1872 normal_axis = self.normal_axis
! 1873 transverse_axes = [axis_names[i] for i in range(3) if i != normal_axis]
! 1874 axis_x, axis_y = transverse_axes
1875
! 1876 size_x = simulation.size[axis_names.index(axis_x)]
! 1877 size_y = simulation.size[axis_names.index(axis_y)]
! 1878 bloch_vec_x = bloch_vec_for_axis(simulation, axis_x)
! 1879 bloch_vec_y = bloch_vec_for_axis(simulation, axis_y)
1880
! 1881 bases: list[DiffractionAdjointBasis] = []
! 1882 freqs = [float(freq) for freq in self.freqs]
! 1883 for freq in freqs:
! 1884 orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
! 1885 orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
! 1886 if orders_x.size == 0 or orders_y.size == 0:
! 1887 continue
1888
! 1889 ux = _reciprocal_coords(
1890 orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
1891 )
! 1892 uy = _reciprocal_coords(
1893 orders=orders_y, size=size_y, bloch_vec=bloch_vec_y, freq=freq, medium=medium
1894 )
! 1895 theta_vals, _ = _compute_angles((ux, uy))
! 1896 order_x_index = {int(val): idx for idx, val in enumerate(orders_x)}
! 1897 order_y_index = {int(val): idx for idx, val in enumerate(orders_y)}
! 1898 bases.extend(
1899 _build_diffraction_bases_for_freq(
1900 monitor_name=self.name,
1901 monitor_index=monitor_index,
1902 freq=freq,Lines 1911-1919 1911 order_x_index[ox], order_y_index[oy], 0
1912 ],
1913 )
1914 )
! 1915 return bases
1916
1917
1918 class DiffractionMonitor(PlanarMonitor, FreqMonitor):
1919 """:class:`Monitor` that uses a 2D Fourier transform to compute thetidy3d/web/api/autograd/parallel_adjoint.pyLines 40-48 40 def _scale_field_map(field_map: AutogradFieldMap, scale: float) -> AutogradFieldMap:
41 scaled = {}
42 for k, v in field_map.items():
43 if isinstance(v, (list, tuple)):
! 44 scaled[k] = type(v)(scale * x for x in v)
45 else:
46 scaled[k] = scale * v
47 return scaledLines 59-69 59 unsupported: list[str] = []
60 for monitor_index, monitor in enumerate(simulation.monitors):
61 try:
62 bases_for_monitor = monitor.parallel_adjoint_bases(simulation, monitor_index)
! 63 except ValueError:
! 64 unsupported.append(monitor.name)
! 65 continue
66 if bases_for_monitor:
67 bases.extend(bases_for_monitor)
68 elif not monitor.supports_parallel_adjoint():
69 unsupported.append(monitor.name)Lines 105-117 105 basis_spec: object,
106 ) -> object:
107 post_norm = sim_data_adj.simulation.post_norm
108 if not hasattr(basis_spec, "freq"):
! 109 return post_norm
110 freqs = np.asarray(post_norm.coords["f"].values)
111 idx = int(np.argmin(np.abs(freqs - basis_spec.freq)))
112 if not np.isclose(freqs[idx], basis_spec.freq):
! 113 raise td.exceptions.AdjointError(
114 "Parallel adjoint basis frequency not found in adjoint post-normalization."
115 )
116 return post_norm.isel(f=[idx])Lines 134-152 134 for key, data_array in monitor_data.field_components.items():
135 if "f" in data_array.dims:
136 freqs = np.asarray(data_array.coords["f"].values)
137 if freqs.size == 0:
! 138 raise td.exceptions.AdjointError(
139 "Parallel adjoint expected frequency data but no frequencies were found."
140 )
141 idx = int(np.argmin(np.abs(freqs - freq)))
142 if not np.isclose(freqs[idx], freq, rtol=1e-10, atol=0.0):
! 143 raise td.exceptions.AdjointError(
144 "Parallel adjoint basis frequency not found in monitor data."
145 )
146 updates[key] = data_array.isel(f=[idx])
147 return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False, **updates)
! 148 return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False)
149
150
151 def _select_sim_data_freq(
152 sim_data_adj: td.SimulationData,Lines 158-166 158 for monitor in sim.monitors:
159 if hasattr(monitor, "freqs"):
160 monitor_updated = monitor.updated_copy(freqs=[freq])
161 else:
! 162 monitor_updated = monitor
163 monitors.append(monitor_updated)
164 monitor_map[monitor.name] = monitor_updated
165 sim_updated = sim.updated_copy(monitors=monitors)Lines 219-227 219 simulation: td.Simulation,
220 basis_sources: list[tuple[ParallelAdjointBasis, Any]],
221 ) -> list[tuple[list[ParallelAdjointBasis], AdjointSourceInfo]]:
222 if not basis_sources:
! 223 return []
224
225 sim_data_stub = td.SimulationData(simulation=simulation, data=())
226 sources = [source for _, source in basis_sources]
227 sources_processed = td.SimulationData._adjoint_src_width_single(sources)Lines 272-282 272 monitor = simulation.monitors[basis.monitor_index]
273 fwidth = adjoint_fwidth_from_simulation(simulation)
274
275 if isinstance(basis, DiffractionAdjointBasis):
! 276 if not isinstance(monitor, DiffractionMonitor):
! 277 raise ValueError("Diffraction basis monitor mismatch.")
! 278 source = diffraction_source_from_simulation(
279 simulation=simulation,
280 monitor=monitor,
281 freq=basis.freq,
282 order_x=basis.order_x,Lines 284-296 284 polarization=basis.polarization,
285 coefficient=coefficient,
286 fwidth=fwidth,
287 )
! 288 return adjoint_source_info_single(source)
289
290 if isinstance(basis, ModeAdjointBasis):
291 if not isinstance(monitor, ModeMonitor):
! 292 raise ValueError("Mode basis monitor mismatch.")
293 source = mode_source_from_monitor(
294 monitor=monitor,
295 freq=basis.freq,
296 direction=basis.direction,Lines 301-309 301 return adjoint_source_info_single(source)
302
303 if isinstance(basis, PointFieldAdjointBasis):
304 if not isinstance(monitor, FieldMonitor):
! 305 raise ValueError("Point-field basis monitor mismatch.")
306 source = point_current_source_from_simulation(
307 simulation=simulation,
308 monitor=monitor,
309 component=basis.component,Lines 311-322 311 coefficient=coefficient,
312 fwidth=fwidth,
313 )
314 if source is None:
! 315 raise ValueError("Adjoint point source has zero amplitude.")
316 return adjoint_source_info_single(source)
317
! 318 raise ValueError("Unsupported parallel adjoint basis.")
319
320
321 @dataclass(frozen=True)
322 class ParallelAdjointPayload:Lines 370-387 370 simulation=simulation,
371 basis=basis,
372 coefficient=1.0 + 0.0j,
373 )
! 374 except ValueError as exc:
! 375 td.log.info(
376 f"Skipping parallel adjoint basis for monitor '{basis.monitor_name}': {exc}"
377 )
! 378 continue
379 basis_sources.append((basis, source_info.sources[0]))
380
381 if not basis_sources:
382 if basis_specs:
! 383 td.log.info("Parallel adjoint produced no simulations for this task.")
384 else:
385 td.log.warning(
386 "Parallel adjoint disabled because no eligible monitor outputs were found."
387 )Lines 411-425 411 task_map[adj_task_name] = bases
412 used_bases.extend(bases)
413
414 if not sims_adj_dict:
! 415 if basis_specs:
! 416 td.log.info("Parallel adjoint produced no simulations for this task.")
417 else:
! 418 td.log.warning(
419 "Parallel adjoint disabled because no eligible monitor outputs were found."
420 )
! 421 return None
422
423 td.log.info(
424 "Parallel adjoint enabled: launched "
425 f"{len(sims_adj_dict)} canonical adjoint simulations for task '{task_name}'."Lines 437-445 437 task_paths: dict[str, str],
438 base_dir: PathLike,
439 ) -> None:
440 if not task_names:
! 441 return
442 target_dir = Path(base_dir) / config.adjoint.local_adjoint_dir
443 target_dir.mkdir(parents=True, exist_ok=True)
444 for task_name in task_names:
445 src_path = task_paths.get(task_name)Lines 444-459 444 for task_name in task_names:
445 src_path = task_paths.get(task_name)
446 if not src_path:
447 continue
! 448 src = Path(src_path)
! 449 if not src.exists():
! 450 continue
! 451 dst = target_dir / src.name
! 452 if src.resolve() == dst.resolve():
! 453 continue
! 454 dst.parent.mkdir(parents=True, exist_ok=True)
! 455 src.replace(dst)
456
457
458 def apply_parallel_adjoint(
459 data_fields_vjp: AutogradFieldMap,Lines 461-469 461 sim_data_orig: td.SimulationData,
462 ) -> tuple[AutogradFieldMap, AutogradFieldMap]:
463 basis_maps = parallel_info.get("basis_maps")
464 if basis_maps is None:
! 465 return {}, data_fields_vjp
466
467 data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
468 vjp_parallel: AutogradFieldMap = {}
469 norm_cache: dict[int, np.ndarray] = {}Lines 471-490 471 basis_specs = list(parallel_info.get("basis_specs", []))
472 for basis in basis_specs:
473 basis_map = basis_maps.get(basis)
474 if basis_map is None:
! 475 continue
476 basis_real = basis_map.get("real")
477 basis_imag = basis_map.get("imag")
478 if basis_real is None or basis_imag is None:
! 479 continue
480 if isinstance(basis, DiffractionAdjointBasis):
! 481 norm = norm_cache.get(basis.monitor_index)
! 482 if norm is None:
! 483 diff_data = sim_data_orig.data[basis.monitor_index]
! 484 norm = diffraction_norm(diff_data)
! 485 norm_cache[basis.monitor_index] = norm
! 486 coefficient = basis.vjp_value(data_fields_vjp, sim_data_orig, norm)
487 else:
488 coefficient = basis.vjp_value(data_fields_vjp, sim_data_orig)
489
490 if coefficient == 0: |
0c9fe1b to
4d59014
Compare
4d59014 to
7cabb37
Compare
yaugenst-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @marcorudolphflex this is pretty great, had a cursory glance at the PR to try to understand a bit what's going on and left some questions/comments, but I'll look deeper into the implementation when I find some time. I guess one thing to note is that this introduces a lot of new code, even modules. Not a problem in itself but I'd maybe have a closer look whether any of this can be simplified.
Also, could you show some plots/verification against the non-parallel adjoint?
| - Mode direction policy (for mode monitors): `config.adjoint.parallel_adjoint_mode_direction_policy` | ||
| - `"assume_outgoing"` (default): pick the mode direction based on monitor position relative to the simulation center and flip it for the adjoint. | ||
| - `"run_both_directions"`: launch parallel adjoint sources for both `+` and `-` directions. | ||
| - `"no_parallel"`: disable parallel adjoint entirely. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do mode monitors separately have a flag to turn parallel adjoint off, in addition to the global config?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TBD whether users need that, e.g. in case they want to override the global toggle for this less-determined mode-monitor case... Since we have a config field anyway, I think it doesn't hurt. Or could its interaction with the global toggle be confusing for users?
| - Only effective when: `config.adjoint.local_gradient = True` | ||
| - If `local_gradient=False`, the flag is ignored and behavior remains unchanged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why only local gradients? Couldn't this be supported for remote gradients too? It may be fine as an initial version, but I don't see why this couldn't be done for remote as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably yes, this was the "easy" start.
|
|
||
| #### Limits and guardrails you should expect | ||
|
|
||
| - Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are all parallel simulations counted as adjoint toward this cap?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
tidy3d/plugins/autograd/README.md
Outdated
| #### Limits and guardrails you should expect | ||
|
|
||
| - Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`. | ||
| - If enabling parallel adjoint would exceed the cap, the run logs a warning and proceeds with the sequential path for that forward run (or a safe subset, depending on policy). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We might not want to proceed at all in that case, not sure. Since this is a flag that we wouldn't turn on by default, the user will generally have requested it explicitly, so they might prefer to increase the cap instead of running sequentially.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True. Changed it to raise an `AdjointError`, as we currently do for the sequential adjoint.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be important to explain here how many adjoint simulations get launched in which scenarios in the parallel case, and what the edge cases are, so there are no surprises.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a section on this to the README.
7cabb37 to
f351142
Compare
f351142 to
9abfd55
Compare
9abfd55 to
ac73e25
Compare
ac73e25 to
dac5685
Compare
dac5685 to
29306de
Compare
29306de to
e05cfd7
Compare
e05cfd7 to
e13d19f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cursor Bugbot has reviewed your changes and found 2 potential issues.
Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.
e13d19f to
6b77f75
Compare
Note
Medium Risk
Touches the autograd forward/backward execution path and adds new batching/scheduling logic for adjoint simulations; while gated behind `local_gradient` + config flags with fallbacks, issues could affect gradient correctness or task/file handling.

Overview

Adds an optional parallel-adjoint execution path for autograd when `local_gradient=True`, launching canonical “unit” adjoint simulations alongside the forward run and reusing/scaling their results during the backward pass (with automatic fallback to the existing sequential adjoint pipeline when unsupported or incomplete).

Introduces new adjoint configuration knobs (`config.adjoint.parallel_all_port` and `config.adjoint.parallel_adjoint_mode_direction_policy`), adds deterministic parallel-adjoint basis generation for mode/diffraction/point-field monitors (and corresponding source factories), refactors adjoint-simulation construction into `make_adjoint_simulation`, and centralizes VJP filtering/field-map accumulation utilities. Includes documentation updates and a new comprehensive test suite covering equivalence vs sequential gradients, fallback/limits behavior, and task launching/file relocation.

Written by Cursor Bugbot for commit 6b77f75. This will update automatically on new commits. Configure here.
Greptile Overview
Greptile Summary
This PR implements parallel adjoint scheduling for autograd simulations, allowing eligible adjoint simulations to run concurrently with forward simulations when `local_gradient=True`. The feature launches canonical "unit" adjoint solves up front and scales them during the backward pass, reducing gradient computation wall-clock time.

Key changes:
- `config.adjoint.parallel_all_port` configuration flag to enable the feature
- `config.adjoint.parallel_adjoint_mode_direction_policy` to control mode direction handling
- `ParallelAdjointDescriptor` classes for mode, diffraction, and point-field monitors
- `supports_parallel_adjoint()` and `parallel_adjoint_descriptors()` methods
- `make_adjoint_simulation()` function

The implementation includes proper fallback mechanisms when monitors are unsupported or limits are exceeded, ensuring backward compatibility.
Confidence Score: 3/5
- `tidy3d/web/api/autograd/parallel_adjoint.py` (lines 322, 327, 329) and `tidy3d/components/autograd/source_factory.py` (lines 94, 207) for floating-point comparison fixes.

Important Files Changed
- `supports_parallel_adjoint()` and `parallel_adjoint_descriptors()` methods; extracted mode source creation to factory.
- `make_adjoint_simulation()` function for reuse; clean refactoring with no logic changes.

Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant ParallelAdjoint as Parallel Adjoint
    participant Batch as Batch Executor
    participant Solver as FDTD Solver

    User->>AutogradAPI: run(sim, local_gradient=True)
    AutogradAPI->>ParallelAdjoint: prepare_parallel_adjoint(sim)
    ParallelAdjoint->>ParallelAdjoint: collect descriptors from monitors
    ParallelAdjoint->>ParallelAdjoint: filter by direction policy
    ParallelAdjoint->>ParallelAdjoint: create canonical adjoint sims
    ParallelAdjoint-->>AutogradAPI: ParallelAdjointPayload

    alt Parallel Adjoint Enabled
        AutogradAPI->>Batch: run_async({fwd, adj_1, adj_2, ...})
        Batch->>Solver: run forward sim
        Batch->>Solver: run adjoint sim 1
        Batch->>Solver: run adjoint sim 2
        Batch-->>AutogradAPI: BatchData
        AutogradAPI->>AutogradAPI: populate_parallel_adjoint_bases()
        AutogradAPI-->>User: SimulationData + aux_data
    else Parallel Adjoint Disabled
        AutogradAPI->>Solver: run forward sim only
        AutogradAPI-->>User: SimulationData
    end

    User->>AutogradAPI: backward pass (VJP)
    AutogradAPI->>ParallelAdjoint: apply_parallel_adjoint(vjp, bases)
    ParallelAdjoint->>ParallelAdjoint: compute coefficients from VJP
    ParallelAdjoint->>ParallelAdjoint: scale and accumulate basis maps
    ParallelAdjoint-->>AutogradAPI: vjp_parallel + vjp_fallback

    alt Has fallback VJPs
        AutogradAPI->>Solver: run sequential adjoint for remaining
        Solver-->>AutogradAPI: adjoint field data
        AutogradAPI->>AutogradAPI: combine vjp_parallel + sequential
    end

    AutogradAPI-->>User: gradient
```