Skip to content

Conversation

@marcorudolphflex
Copy link
Contributor

@marcorudolphflex marcorudolphflex commented Jan 27, 2026

Note

Medium Risk
Touches the autograd forward/backward execution path and adds new batching/scheduling logic for adjoint simulations; while gated behind local_gradient + config flags with fallbacks, issues could affect gradient correctness or task/file handling.

Overview
Adds an optional parallel-adjoint execution path for autograd when local_gradient=True, launching canonical “unit” adjoint simulations alongside the forward run and reusing/scaling their results during the backward pass (with automatic fallback to the existing sequential adjoint pipeline when unsupported or incomplete).

Introduces new adjoint configuration knobs (config.adjoint.parallel_all_port and config.adjoint.parallel_adjoint_mode_direction_policy), adds deterministic parallel-adjoint basis generation for mode/diffraction/point-field monitors (and corresponding source factories), refactors adjoint-simulation construction into make_adjoint_simulation, and centralizes VJP filtering/field-map accumulation utilities. Includes documentation updates and a new comprehensive test suite covering equivalence vs sequential gradients, fallback/limits behavior, and task launching/file relocation.

Written by Cursor Bugbot for commit 6b77f75. This will update automatically on new commits. Configure here.

Greptile Overview

Greptile Summary

This PR implements parallel adjoint scheduling for autograd simulations, allowing eligible adjoint simulations to run concurrently with forward simulations when local_gradient=True. The feature launches canonical "unit" adjoint solves up front and scales them during the backward pass, reducing gradient computation wall-clock time.

Key changes:

  • Added config.adjoint.parallel_all_port configuration flag to enable the feature
  • Added config.adjoint.parallel_adjoint_mode_direction_policy to control mode direction handling
  • Created ParallelAdjointDescriptor classes for mode, diffraction, and point-field monitors
  • Implemented source factory functions for generating adjoint sources deterministically
  • Extended monitor data classes with supports_parallel_adjoint() and parallel_adjoint_descriptors() methods
  • Refactored adjoint simulation creation into reusable make_adjoint_simulation() function
  • Added comprehensive test suite verifying parallel vs sequential gradient equivalence
  • Updated documentation with detailed feature description

The implementation includes proper fallback mechanisms when monitors are unsupported or limits are exceeded, ensuring backward compatibility.

Confidence Score: 3/5

  • This PR introduces significant new functionality with good test coverage but has floating-point comparison issues that need addressing.
  • Score reflects well-architected feature with comprehensive tests and documentation, but critical floating-point equality comparisons (5 instances) need tolerance-based checks per project standards. The refactoring is clean and maintains backward compatibility with proper fallback mechanisms.
  • Pay close attention to tidy3d/web/api/autograd/parallel_adjoint.py (lines 322, 327, 329) and tidy3d/components/autograd/source_factory.py (lines 94, 207) for floating-point comparison fixes.

Important Files Changed

Filename Overview
tidy3d/web/api/autograd/parallel_adjoint.py New file implementing parallel adjoint scheduling. Contains floating-point comparison issues (lines 322, 327, 329) that need tolerance-based checks.
tidy3d/components/autograd/parallel_adjoint_descriptors.py New file with descriptor classes for parallel adjoint; well-structured with proper error handling and type checking.
tidy3d/components/autograd/source_factory.py New source factory utilities with floating-point equality issues (lines 94, 207) that should use tolerance-based comparisons.
tidy3d/web/api/autograd/autograd.py Extended autograd pipeline with parallel adjoint integration; adds helper functions for VJP filtering, field map accumulation, and batch processing.
tidy3d/components/data/monitor_data.py Refactored to support parallel adjoint via new supports_parallel_adjoint() and parallel_adjoint_descriptors() methods; extracted mode source creation to factory.
tidy3d/components/data/sim_data.py Extracted adjoint simulation creation into standalone make_adjoint_simulation() function for reuse; clean refactoring with no logic changes.

Sequence Diagram

sequenceDiagram
    participant User
    participant AutogradAPI as Autograd API
    participant ParallelAdjoint as Parallel Adjoint
    participant Batch as Batch Executor
    participant Solver as FDTD Solver

    User->>AutogradAPI: run(sim, local_gradient=True)
    AutogradAPI->>ParallelAdjoint: prepare_parallel_adjoint(sim)
    ParallelAdjoint->>ParallelAdjoint: collect descriptors from monitors
    ParallelAdjoint->>ParallelAdjoint: filter by direction policy
    ParallelAdjoint->>ParallelAdjoint: create canonical adjoint sims
    ParallelAdjoint-->>AutogradAPI: ParallelAdjointPayload
    
    alt Parallel Adjoint Enabled
        AutogradAPI->>Batch: run_async({fwd, adj_1, adj_2, ...})
        Batch->>Solver: run forward sim
        Batch->>Solver: run adjoint sim 1
        Batch->>Solver: run adjoint sim 2
        Batch-->>AutogradAPI: BatchData
        AutogradAPI->>AutogradAPI: populate_parallel_adjoint_bases()
        AutogradAPI-->>User: SimulationData + aux_data
    else Parallel Adjoint Disabled
        AutogradAPI->>Solver: run forward sim only
        AutogradAPI-->>User: SimulationData
    end

    User->>AutogradAPI: backward pass (VJP)
    AutogradAPI->>ParallelAdjoint: apply_parallel_adjoint(vjp, bases)
    ParallelAdjoint->>ParallelAdjoint: compute coefficients from VJP
    ParallelAdjoint->>ParallelAdjoint: scale and accumulate basis maps
    ParallelAdjoint-->>AutogradAPI: vjp_parallel + vjp_fallback
    
    alt Has fallback VJPs
        AutogradAPI->>Solver: run sequential adjoint for remaining
        Solver-->>AutogradAPI: adjoint field data
        AutogradAPI->>AutogradAPI: combine vjp_parallel + sequential
    end
    
    AutogradAPI-->>User: gradient
Loading

@marcorudolphflex
Copy link
Contributor Author

@greptile

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 file reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch 4 times, most recently from 8386a06 to 0c9fe1b Compare January 27, 2026 12:37
@marcorudolphflex marcorudolphflex marked this pull request as ready for review January 27, 2026 12:46
@marcorudolphflex
Copy link
Contributor Author

technical still semi-drafty, marked as ready for cursor bugbot. Will re-request review when really ready.

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

@github-actions
Copy link
Contributor

github-actions bot commented Jan 27, 2026

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/parallel_adjoint_bases.py (75.9%): Missing lines 22,28,61,72,76,90-93,98,103-110,115-125,160,164,204,233,237
  • tidy3d/components/autograd/source_factory.py (63.0%): Missing lines 24,79,81,91,120,133-137,141-145,151-160,173-177,179-182,184,191,198-202,204-206,234,238,287-290
  • tidy3d/components/autograd/utils.py (90.0%): Missing lines 66
  • tidy3d/components/data/monitor_data.py (85.3%): Missing lines 178,182,1467,1879,4110
  • tidy3d/components/data/sim_data.py (100%)
  • tidy3d/components/monitor.py (42.2%): Missing lines 68,74-78,84-85,88-90,1116,1857,1863,1869,1871-1874,1876-1879,1881-1887,1889,1892,1895-1898,1915
  • tidy3d/config/sections.py (100%)
  • tidy3d/web/api/autograd/autograd.py (100%)
  • tidy3d/web/api/autograd/backward.py (100%)
  • tidy3d/web/api/autograd/constants.py (100%)
  • tidy3d/web/api/autograd/parallel_adjoint.py (83.7%): Missing lines 44,63-65,109,113,138,143,148,162,223,276-278,288,292,305,315,318,374-375,378,383,415-416,418,421,441,448-455,465,475,479,481-486
  • tidy3d/web/api/autograd/utils.py (100%)

Summary

  • Total: 801 lines
  • Missing: 172 lines
  • Coverage: 78%

tidy3d/components/autograd/parallel_adjoint_bases.py

Lines 18-26

  18 
  19 def _coord_index(coord_values: np.ndarray, target: object) -> int:
  20     values = np.asarray(coord_values)
  21     if values.size == 0:
! 22         raise ValueError("No coordinate values available to index.")
  23     if values.dtype.kind in ("f", "c"):
  24         matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
  25     else:
  26         matches = np.where(values == target)[0]

Lines 24-32

  24         matches = np.where(np.isclose(values, float(target), rtol=1e-10, atol=0.0))[0]
  25     else:
  26         matches = np.where(values == target)[0]
  27     if matches.size == 0:
! 28         raise ValueError(f"Could not find coordinate value {target!r} in {values}.")
  29     return int(matches[0])
  30 
  31 
  32 def _index_for_dims(data_array: DataArray, coord_map: dict[str, object]) -> tuple[int, ...]:

Lines 57-65

  57         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  58     ) -> complex:
  59         vjp = data_fields_vjp.get(self.data_path)
  60         if vjp is None:
! 61             return 0.0 + 0.0j
  62         data_index = self._data_index_from_sim_data(sim_data_orig)
  63         vjp_array = np.asarray(vjp)
  64         value = complex(vjp_array[data_index])
  65         return value

Lines 68-80

  68         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  69     ) -> None:
  70         vjp = data_fields_vjp.get(self.data_path)
  71         if vjp is None:
! 72             return
  73         vjp_array = np.asarray(vjp)
  74         vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  75         if vjp_array is not vjp:
! 76             data_fields_vjp[self.data_path] = vjp_array
  77 
  78 
  79 @dataclass(frozen=True)
  80 class DiffractionAdjointBasis:

Lines 86-129

   86     polarization: Literal["s", "p"]
   87     data_path: tuple
   88 
   89     def _data_index_from_sim_data(self, sim_data_orig: SimulationData) -> tuple[int, ...]:
!  90         diff_data = sim_data_orig.data[self.monitor_index]
!  91         dataset_name = self.data_path[-1]
!  92         field_data = getattr(diff_data, dataset_name)
!  93         coord_map = {
   94             "orders_x": int(self.order_x),
   95             "orders_y": int(self.order_y),
   96             "f": float(self.freq),
   97         }
!  98         return _index_for_dims(field_data, coord_map)
   99 
  100     def vjp_value(
  101         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData, norm: np.ndarray
  102     ) -> complex:
! 103         vjp = data_fields_vjp.get(self.data_path)
! 104         if vjp is None:
! 105             return 0.0 + 0.0j
! 106         try:
! 107             data_index = self._data_index_from_sim_data(sim_data_orig)
! 108         except ValueError:
! 109             return 0.0 + 0.0j
! 110         return complex(np.asarray(vjp)[data_index] * norm[data_index])
  111 
  112     def zero_vjp_entry(
  113         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  114     ) -> None:
! 115         vjp = data_fields_vjp.get(self.data_path)
! 116         if vjp is None:
! 117             return
! 118         try:
! 119             data_index = self._data_index_from_sim_data(sim_data_orig)
! 120         except ValueError:
! 121             return
! 122         vjp_array = np.asarray(vjp)
! 123         vjp_array[data_index] = 0.0
! 124         if vjp_array is not vjp:
! 125             data_fields_vjp[self.data_path] = vjp_array
  126 
  127 
  128 @dataclass(frozen=True)
  129 class PointFieldAdjointBasis:

Lines 156-168

  156         self, data_fields_vjp: AutogradFieldMap, sim_data_orig: SimulationData
  157     ) -> None:
  158         vjp = data_fields_vjp.get(self.data_path)
  159         if vjp is None:
! 160             return
  161         vjp_array = np.asarray(vjp)
  162         vjp_array[self._data_index_from_sim_data(sim_data_orig)] = 0.0
  163         if vjp_array is not vjp:
! 164             data_fields_vjp[self.data_path] = vjp_array
  165 
  166 
  167 ParallelAdjointBasis = ModeAdjointBasis | DiffractionAdjointBasis | PointFieldAdjointBasis

Lines 200-208

  200 ) -> list[PointFieldAdjointBasis]:
  201     bases: list[PointFieldAdjointBasis] = []
  202     for component, freqs in component_freqs:
  203         if component not in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"):
! 204             continue
  205         for freq in freqs:
  206             bases.append(
  207                 PointFieldAdjointBasis(
  208                     monitor_index=monitor_index,

Lines 229-241

  229     for order_x in orders_x:
  230         for order_y in orders_y:
  231             angle_theta = float(theta_for(int(order_x), int(order_y)))
  232             if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 233                 continue
  234             for pol in pols:
  235                 pol_str = str(pol)
  236                 if pol_str not in ("s", "p"):
! 237                     continue
  238                 dataset_name = "Ephi" if pol_str == "s" else "Etheta"
  239                 bases.append(
  240                     DiffractionAdjointBasis(
  241                         monitor_index=monitor_index,

tidy3d/components/autograd/source_factory.py

Lines 20-28

  20 def flip_direction(direction: object) -> str:
  21     if hasattr(direction, "values"):
  22         direction = str(direction.values)
  23     if direction not in ("+", "-"):
! 24         raise ValueError(f"Direction must be in {('+', '-')}, got '{direction}'.")
  25     return "-" if direction == "+" else "+"
  26 
  27 
  28 def adjoint_fwidth_from_simulation(simulation: Simulation) -> float:

Lines 75-85

  75     coefficient: complex,
  76     fwidth: float,
  77 ) -> CustomCurrentSource | None:
  78     if any(simulation.symmetry):
! 79         raise ValueError("Point-field adjoint sources require symmetry to be disabled.")
  80     if not monitor.colocate:
! 81         raise ValueError("Point-field adjoint sources require colocated field monitors.")
  82 
  83     grid = simulation.discretize_monitor(monitor)
  84     coords = {}
  85     spatial_coords = grid.boundaries

Lines 87-95

  87     for axis, dim in enumerate("xyz"):
  88         if monitor.size[axis] == 0:
  89             coords[dim] = np.array([monitor.center[axis]])
  90         else:
! 91             coords[dim] = np.array(spatial_coords_dict[dim][:-1])
  92     values = (
  93         2
  94         * -1j
  95         * coefficient

Lines 116-124

  116     values *= scaling_factor
  117     values = np.nan_to_num(values, nan=0.0)
  118 
  119     if np.all(values == 0):
! 120         return None
  121 
  122     dataset = FieldDataset(**{component: ScalarFieldDataArray(values, coords=coords)})
  123     return CustomCurrentSource(
  124         center=monitor.geometry.center,

Lines 129-149

  129     )
  130 
  131 
  132 def diffraction_monitor_medium(simulation: Simulation, monitor: DiffractionMonitor) -> object:
! 133     structures = [simulation.scene.background_structure, *list(simulation.structures or ())]
! 134     mediums = simulation.scene.intersecting_media(monitor, structures)
! 135     if len(mediums) != 1:
! 136         raise ValueError("Diffraction monitor plane must be homogeneous to build adjoint sources.")
! 137     return list(mediums)[0]
  138 
  139 
  140 def bloch_vec_for_axis(simulation: Simulation, axis_name: str) -> float:
! 141     boundary = simulation.boundary_spec[axis_name]
! 142     plus = boundary.plus
! 143     if hasattr(plus, "bloch_vec"):
! 144         return float(plus.bloch_vec)
! 145     return 0.0
  146 
  147 
  148 def diffraction_order_range(
  149     size: float, bloch_vec: float, freq: float, medium: object

Lines 147-164

  147 
  148 def diffraction_order_range(
  149     size: float, bloch_vec: float, freq: float, medium: object
  150 ) -> np.ndarray:
! 151     if size == 0:
! 152         return np.array([0], dtype=int)
! 153     eps = medium.eps_model(freq)
! 154     index = np.real(np.sqrt(eps))
! 155     limit = abs(index) * freq * size / C_0
! 156     order_min = int(np.ceil(-limit - bloch_vec))
! 157     order_max = int(np.floor(limit - bloch_vec))
! 158     if order_max < order_min:
! 159         return np.array([], dtype=int)
! 160     return np.arange(order_min, order_max + 1, dtype=int)
  161 
  162 
  163 def diffraction_source_from_simulation(
  164     simulation: Simulation,

Lines 169-188

  169     polarization: Literal["s", "p"],
  170     coefficient: complex,
  171     fwidth: float,
  172 ) -> PlaneWave:
! 173     medium = diffraction_monitor_medium(simulation, monitor)
! 174     axis_names = ("x", "y", "z")
! 175     normal_axis = monitor.normal_axis
! 176     transverse_axes = [axis_names[i] for i in range(3) if i != normal_axis]
! 177     axis_x, axis_y = transverse_axes
  178 
! 179     size_x = simulation.size[axis_names.index(axis_x)]
! 180     size_y = simulation.size[axis_names.index(axis_y)]
! 181     bloch_vec_x = bloch_vec_for_axis(simulation, axis_x)
! 182     bloch_vec_y = bloch_vec_for_axis(simulation, axis_y)
  183 
! 184     ux = DiffractionData.reciprocal_coords(
  185         orders=np.array([order_x]),
  186         size=size_x,
  187         bloch_vec=bloch_vec_x,
  188         f=freq,

Lines 187-195

  187         bloch_vec=bloch_vec_x,
  188         f=freq,
  189         medium=medium,
  190     )
! 191     uy = DiffractionData.reciprocal_coords(
  192         orders=np.array([order_y]),
  193         size=size_y,
  194         bloch_vec=bloch_vec_y,
  195         f=freq,

Lines 194-210

  194         bloch_vec=bloch_vec_y,
  195         f=freq,
  196         medium=medium,
  197     )
! 198     theta_vals, phi_vals = DiffractionData.compute_angles((ux, uy))
! 199     angle_theta = float(theta_vals[0, 0, 0])
! 200     angle_phi = float(phi_vals[0, 0, 0])
! 201     if np.isnan(angle_theta) or np.cos(angle_theta) <= COS_THETA_THRESH:
! 202         raise ValueError("Adjoint source not available for evanescent diffraction order.")
  203 
! 204     pol_angle = 0.0 if polarization == "p" else np.pi / 2
! 205     bck_eps = medium.eps_model(freq)
! 206     return _diffraction_plane_wave(
  207         monitor=monitor,
  208         freq=freq,
  209         angle_theta=angle_theta,
  210         angle_phi=angle_phi,

Lines 230-242

  230     angle_theta = float(theta_data.sel(**angle_sel_kwargs))
  231     angle_phi = float(phi_data.sel(**angle_sel_kwargs))
  232 
  233     if np.isnan(angle_theta):
! 234         return None
  235 
  236     pol_str = str(polarization)
  237     if pol_str not in ("p", "s"):
! 238         raise ValueError(f"Something went wrong, given pol='{pol_str}' in adjoint source.")
  239 
  240     pol_angle = 0.0 if pol_str == "p" else np.pi / 2
  241     bck_eps = diff_data.medium.eps_model(freq)
  242     return _diffraction_plane_wave(

Lines 283-291

  283     )
  284 
  285 
  286 def diffraction_norm(diffraction_data: DiffractionData) -> np.ndarray:
! 287     theta_data, _ = diffraction_data.angles
! 288     cos_theta = np.cos(np.nan_to_num(theta_data))
! 289     cos_theta[cos_theta <= COS_THETA_THRESH] = np.inf
! 290     return 1.0 / np.sqrt(2.0 * np.asarray(diffraction_data.eta)) / np.sqrt(cos_theta)

tidy3d/components/autograd/utils.py

Lines 62-70

  62         if k in target:
  63             val = target[k]
  64             if isinstance(val, (list, tuple)) and isinstance(v, (list, tuple)):
  65                 if len(val) != len(v):
! 66                     raise ValueError(
  67                         f"Cannot accumulate field map for key '{k}': "
  68                         f"length mismatch ({len(val)} vs {len(v)})."
  69                     )
  70                 target[k] = type(val)(x + y for x, y in zip(val, v))

tidy3d/components/data/monitor_data.py

Lines 174-186

  174         return []
  175 
  176     def supports_parallel_adjoint(self) -> bool:
  177         """Return ``True`` if this monitor data supports parallel adjoint sources."""
! 178         return False
  179 
  180     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  181         """Return parallel adjoint bases for this monitor data."""
! 182         return []
  183 
  184     @staticmethod
  185     def get_amplitude(x) -> complex:
  186         """Get the complex amplitude out of some data."""

Lines 1463-1471

  1463 
  1464     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  1465         """Return parallel adjoint bases for single-point field monitors."""
  1466         if not self.supports_parallel_adjoint():
! 1467             return []
  1468         component_freqs = [
  1469             (str(component), data_array.coords["f"].values)
  1470             for component, data_array in self.field_components.items()
  1471         ]

Lines 1875-1883

  1875         return val
  1876 
  1877     def supports_parallel_adjoint(self) -> bool:
  1878         """Return ``True`` for mode monitor amplitude adjoints."""
! 1879         return True
  1880 
  1881     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  1882         """Return parallel adjoint bases for mode monitor amplitudes."""
  1883         amps = self.amps

Lines 4106-4114

  4106         return DataArray(np.stack([amp_phi, amp_theta], axis=3), coords=coords)
  4107 
  4108     def supports_parallel_adjoint(self) -> bool:
  4109         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 4110         return True
  4111 
  4112     def parallel_adjoint_bases(self, monitor_index: int) -> list[ParallelAdjointBasis]:
  4113         """Return parallel adjoint bases for diffraction monitor amplitudes."""
  4114         amps = self.amps

tidy3d/components/monitor.py

Lines 64-72

  64 WINDOW_FACTOR = 15
  65 
  66 
  67 def _shifted_orders(orders: np.ndarray, bloch_vec: float) -> np.ndarray:
! 68     return bloch_vec + np.atleast_2d(orders).T
  69 
  70 
  71 def _reciprocal_coords(
  72     orders: np.ndarray, size: float, bloch_vec: float, freq: float, medium: MediumType

Lines 70-82

  70 
  71 def _reciprocal_coords(
  72     orders: np.ndarray, size: float, bloch_vec: float, freq: float, medium: MediumType
  73 ) -> np.ndarray:
! 74     if size == 0:
! 75         return np.atleast_2d(0)
! 76     epsilon = medium.eps_model(freq)
! 77     bloch_array = _shifted_orders(orders, bloch_vec)
! 78     return bloch_array / size * C_0 / freq / np.real(np.sqrt(epsilon))
  79 
  80 
  81 def _compute_angles(
  82     reciprocal_vectors: tuple[np.ndarray, np.ndarray],

Lines 80-94

  80 
  81 def _compute_angles(
  82     reciprocal_vectors: tuple[np.ndarray, np.ndarray],
  83 ) -> tuple[np.ndarray, np.ndarray]:
! 84     with warnings.catch_warnings():
! 85         warnings.filterwarnings(
  86             "ignore", message="invalid value encountered in arcsin", category=RuntimeWarning
  87         )
! 88         ux, uy = reciprocal_vectors
! 89         thetas, phis = Geometry.kspace_2_sph(ux[:, None, :], uy[None, :, :], axis=2)
! 90     return (thetas, phis)
  91 
  92 
  93 class Monitor(AbstractMonitor):
  94     """Abstract base class for monitors."""

Lines 1112-1120

  1112         return amps_size + fields_size
  1113 
  1114     def supports_parallel_adjoint(self) -> bool:
  1115         """Return ``True`` for mode monitor amplitude adjoints."""
! 1116         return True
  1117 
  1118     def parallel_adjoint_bases(
  1119         self, simulation: Simulation, monitor_index: int
  1120     ) -> list[ParallelAdjointBasis]:

Lines 1853-1861

  1853         return BYTES_COMPLEX * len(self.ux) * len(self.uy) * len(self.freqs) * 6
  1854 
  1855     def supports_parallel_adjoint(self) -> bool:
  1856         """Return ``True`` for diffraction monitor adjoints based on amplitude data."""
! 1857         return True
  1858 
  1859     def parallel_adjoint_bases(
  1860         self, simulation: Simulation, monitor_index: int
  1861     ) -> list[ParallelAdjointBasis]:

Lines 1859-1867

  1859     def parallel_adjoint_bases(
  1860         self, simulation: Simulation, monitor_index: int
  1861     ) -> list[ParallelAdjointBasis]:
  1862         """Return parallel adjoint bases for diffraction monitor amplitudes."""
! 1863         from tidy3d.components.autograd.source_factory import (
  1864             bloch_vec_for_axis,
  1865             diffraction_monitor_medium,
  1866             diffraction_order_range,
  1867         )

Lines 1865-1902

  1865             diffraction_monitor_medium,
  1866             diffraction_order_range,
  1867         )
  1868 
! 1869         medium = diffraction_monitor_medium(simulation, self)
  1870 
! 1871         axis_names = ("x", "y", "z")
! 1872         normal_axis = self.normal_axis
! 1873         transverse_axes = [axis_names[i] for i in range(3) if i != normal_axis]
! 1874         axis_x, axis_y = transverse_axes
  1875 
! 1876         size_x = simulation.size[axis_names.index(axis_x)]
! 1877         size_y = simulation.size[axis_names.index(axis_y)]
! 1878         bloch_vec_x = bloch_vec_for_axis(simulation, axis_x)
! 1879         bloch_vec_y = bloch_vec_for_axis(simulation, axis_y)
  1880 
! 1881         bases: list[DiffractionAdjointBasis] = []
! 1882         freqs = [float(freq) for freq in self.freqs]
! 1883         for freq in freqs:
! 1884             orders_x = diffraction_order_range(size_x, bloch_vec_x, freq, medium)
! 1885             orders_y = diffraction_order_range(size_y, bloch_vec_y, freq, medium)
! 1886             if orders_x.size == 0 or orders_y.size == 0:
! 1887                 continue
  1888 
! 1889             ux = _reciprocal_coords(
  1890                 orders=orders_x, size=size_x, bloch_vec=bloch_vec_x, freq=freq, medium=medium
  1891             )
! 1892             uy = _reciprocal_coords(
  1893                 orders=orders_y, size=size_y, bloch_vec=bloch_vec_y, freq=freq, medium=medium
  1894             )
! 1895             theta_vals, _ = _compute_angles((ux, uy))
! 1896             order_x_index = {int(val): idx for idx, val in enumerate(orders_x)}
! 1897             order_y_index = {int(val): idx for idx, val in enumerate(orders_y)}
! 1898             bases.extend(
  1899                 _build_diffraction_bases_for_freq(
  1900                     monitor_name=self.name,
  1901                     monitor_index=monitor_index,
  1902                     freq=freq,

Lines 1911-1919

  1911                         order_x_index[ox], order_y_index[oy], 0
  1912                     ],
  1913                 )
  1914             )
! 1915         return bases
  1916 
  1917 
  1918 class DiffractionMonitor(PlanarMonitor, FreqMonitor):
  1919     """:class:`Monitor` that uses a 2D Fourier transform to compute the

tidy3d/web/api/autograd/parallel_adjoint.py

Lines 40-48

  40 def _scale_field_map(field_map: AutogradFieldMap, scale: float) -> AutogradFieldMap:
  41     scaled = {}
  42     for k, v in field_map.items():
  43         if isinstance(v, (list, tuple)):
! 44             scaled[k] = type(v)(scale * x for x in v)
  45         else:
  46             scaled[k] = scale * v
  47     return scaled

Lines 59-69

  59     unsupported: list[str] = []
  60     for monitor_index, monitor in enumerate(simulation.monitors):
  61         try:
  62             bases_for_monitor = monitor.parallel_adjoint_bases(simulation, monitor_index)
! 63         except ValueError:
! 64             unsupported.append(monitor.name)
! 65             continue
  66         if bases_for_monitor:
  67             bases.extend(bases_for_monitor)
  68         elif not monitor.supports_parallel_adjoint():
  69             unsupported.append(monitor.name)

Lines 105-117

  105     basis_spec: object,
  106 ) -> object:
  107     post_norm = sim_data_adj.simulation.post_norm
  108     if not hasattr(basis_spec, "freq"):
! 109         return post_norm
  110     freqs = np.asarray(post_norm.coords["f"].values)
  111     idx = int(np.argmin(np.abs(freqs - basis_spec.freq)))
  112     if not np.isclose(freqs[idx], basis_spec.freq):
! 113         raise td.exceptions.AdjointError(
  114             "Parallel adjoint basis frequency not found in adjoint post-normalization."
  115         )
  116     return post_norm.isel(f=[idx])

Lines 134-152

  134         for key, data_array in monitor_data.field_components.items():
  135             if "f" in data_array.dims:
  136                 freqs = np.asarray(data_array.coords["f"].values)
  137                 if freqs.size == 0:
! 138                     raise td.exceptions.AdjointError(
  139                         "Parallel adjoint expected frequency data but no frequencies were found."
  140                     )
  141                 idx = int(np.argmin(np.abs(freqs - freq)))
  142                 if not np.isclose(freqs[idx], freq, rtol=1e-10, atol=0.0):
! 143                     raise td.exceptions.AdjointError(
  144                         "Parallel adjoint basis frequency not found in monitor data."
  145                     )
  146                 updates[key] = data_array.isel(f=[idx])
  147         return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False, **updates)
! 148     return monitor_data.updated_copy(monitor=monitor, deep=False, validate=False)
  149 
  150 
  151 def _select_sim_data_freq(
  152     sim_data_adj: td.SimulationData,

Lines 158-166

  158     for monitor in sim.monitors:
  159         if hasattr(monitor, "freqs"):
  160             monitor_updated = monitor.updated_copy(freqs=[freq])
  161         else:
! 162             monitor_updated = monitor
  163         monitors.append(monitor_updated)
  164         monitor_map[monitor.name] = monitor_updated
  165     sim_updated = sim.updated_copy(monitors=monitors)

Lines 219-227

  219     simulation: td.Simulation,
  220     basis_sources: list[tuple[ParallelAdjointBasis, Any]],
  221 ) -> list[tuple[list[ParallelAdjointBasis], AdjointSourceInfo]]:
  222     if not basis_sources:
! 223         return []
  224 
  225     sim_data_stub = td.SimulationData(simulation=simulation, data=())
  226     sources = [source for _, source in basis_sources]
  227     sources_processed = td.SimulationData._adjoint_src_width_single(sources)

Lines 272-282

  272     monitor = simulation.monitors[basis.monitor_index]
  273     fwidth = adjoint_fwidth_from_simulation(simulation)
  274 
  275     if isinstance(basis, DiffractionAdjointBasis):
! 276         if not isinstance(monitor, DiffractionMonitor):
! 277             raise ValueError("Diffraction basis monitor mismatch.")
! 278         source = diffraction_source_from_simulation(
  279             simulation=simulation,
  280             monitor=monitor,
  281             freq=basis.freq,
  282             order_x=basis.order_x,

Lines 284-296

  284             polarization=basis.polarization,
  285             coefficient=coefficient,
  286             fwidth=fwidth,
  287         )
! 288         return adjoint_source_info_single(source)
  289 
  290     if isinstance(basis, ModeAdjointBasis):
  291         if not isinstance(monitor, ModeMonitor):
! 292             raise ValueError("Mode basis monitor mismatch.")
  293         source = mode_source_from_monitor(
  294             monitor=monitor,
  295             freq=basis.freq,
  296             direction=basis.direction,

Lines 301-309

  301         return adjoint_source_info_single(source)
  302 
  303     if isinstance(basis, PointFieldAdjointBasis):
  304         if not isinstance(monitor, FieldMonitor):
! 305             raise ValueError("Point-field basis monitor mismatch.")
  306         source = point_current_source_from_simulation(
  307             simulation=simulation,
  308             monitor=monitor,
  309             component=basis.component,

Lines 311-322

  311             coefficient=coefficient,
  312             fwidth=fwidth,
  313         )
  314         if source is None:
! 315             raise ValueError("Adjoint point source has zero amplitude.")
  316         return adjoint_source_info_single(source)
  317 
! 318     raise ValueError("Unsupported parallel adjoint basis.")
  319 
  320 
  321 @dataclass(frozen=True)
  322 class ParallelAdjointPayload:

Lines 370-387

  370                 simulation=simulation,
  371                 basis=basis,
  372                 coefficient=1.0 + 0.0j,
  373             )
! 374         except ValueError as exc:
! 375             td.log.info(
  376                 f"Skipping parallel adjoint basis for monitor '{basis.monitor_name}': {exc}"
  377             )
! 378             continue
  379         basis_sources.append((basis, source_info.sources[0]))
  380 
  381     if not basis_sources:
  382         if basis_specs:
! 383             td.log.info("Parallel adjoint produced no simulations for this task.")
  384         else:
  385             td.log.warning(
  386                 "Parallel adjoint disabled because no eligible monitor outputs were found."
  387             )

Lines 411-425

  411         task_map[adj_task_name] = bases
  412         used_bases.extend(bases)
  413 
  414     if not sims_adj_dict:
! 415         if basis_specs:
! 416             td.log.info("Parallel adjoint produced no simulations for this task.")
  417         else:
! 418             td.log.warning(
  419                 "Parallel adjoint disabled because no eligible monitor outputs were found."
  420             )
! 421         return None
  422 
  423     td.log.info(
  424         "Parallel adjoint enabled: launched "
  425         f"{len(sims_adj_dict)} canonical adjoint simulations for task '{task_name}'."

Lines 437-445

  437     task_paths: dict[str, str],
  438     base_dir: PathLike,
  439 ) -> None:
  440     if not task_names:
! 441         return
  442     target_dir = Path(base_dir) / config.adjoint.local_adjoint_dir
  443     target_dir.mkdir(parents=True, exist_ok=True)
  444     for task_name in task_names:
  445         src_path = task_paths.get(task_name)

Lines 444-459

  444     for task_name in task_names:
  445         src_path = task_paths.get(task_name)
  446         if not src_path:
  447             continue
! 448         src = Path(src_path)
! 449         if not src.exists():
! 450             continue
! 451         dst = target_dir / src.name
! 452         if src.resolve() == dst.resolve():
! 453             continue
! 454         dst.parent.mkdir(parents=True, exist_ok=True)
! 455         src.replace(dst)
  456 
  457 
  458 def apply_parallel_adjoint(
  459     data_fields_vjp: AutogradFieldMap,

Lines 461-469

  461     sim_data_orig: td.SimulationData,
  462 ) -> tuple[AutogradFieldMap, AutogradFieldMap]:
  463     basis_maps = parallel_info.get("basis_maps")
  464     if basis_maps is None:
! 465         return {}, data_fields_vjp
  466 
  467     data_fields_vjp_fallback = {k: np.array(v, copy=True) for k, v in data_fields_vjp.items()}
  468     vjp_parallel: AutogradFieldMap = {}
  469     norm_cache: dict[int, np.ndarray] = {}

Lines 471-490

  471     basis_specs = list(parallel_info.get("basis_specs", []))
  472     for basis in basis_specs:
  473         basis_map = basis_maps.get(basis)
  474         if basis_map is None:
! 475             continue
  476         basis_real = basis_map.get("real")
  477         basis_imag = basis_map.get("imag")
  478         if basis_real is None or basis_imag is None:
! 479             continue
  480         if isinstance(basis, DiffractionAdjointBasis):
! 481             norm = norm_cache.get(basis.monitor_index)
! 482             if norm is None:
! 483                 diff_data = sim_data_orig.data[basis.monitor_index]
! 484                 norm = diffraction_norm(diff_data)
! 485                 norm_cache[basis.monitor_index] = norm
! 486             coefficient = basis.vjp_value(data_fields_vjp, sim_data_orig, norm)
  487         else:
  488             coefficient = basis.vjp_value(data_fields_vjp, sim_data_orig)
  489 
  490         if coefficient == 0:

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 0c9fe1b to 4d59014 Compare January 27, 2026 15:56
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 4d59014 to 7cabb37 Compare January 27, 2026 16:10
Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @marcorudolphflex this is pretty great, had a cursory glance at the PR to try to understand a bit what's going on and left some questions/comments, but I'll look deeper into the implementation when I find some time. I guess one thing to note is that this introduces a lot of new code, even modules. Not a problem in itself but I'd maybe have a closer look whether any of this can be simplified.
Also, could you show some plots/verification against the non-parallel adjoint?

- Mode direction policy (for mode monitors): `config.adjoint.parallel_adjoint_mode_direction_policy`
- `"assume_outgoing"` (default): pick the mode direction based on monitor position relative to the simulation center and flip it for the adjoint.
- `"run_both_directions"`: launch parallel adjoint sources for both `+` and `-` directions.
- `"no_parallel"`: disable parallel adjoint entirely.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do mode monitors separately have a flag to turn parallel adjoint off, in addition to the global config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbd if users need that in case they want to override the global toggle for this less-determined mode monitor... As we do have a config field anyways, I think it doesn't hurt. Or could that be confusing for users regarding its effect along with the global toggle?

Comment on lines +69 to +70
- Only effective when: `config.adjoint.local_gradient = True`
- If `local_gradient=False`, the flag is ignored and behavior remains unchanged.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only local gradients? Couldn't this be supported in remote too? Maybe it's fine as an initial version but I don't see how this couldnt be done for remote?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably yes, this was the "easy" start.


#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are all parallel simulations counted as adjoint toward this cap?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

#### Limits and guardrails you should expect

- Hard cap: the feature will not exceed `config.adjoint.max_adjoint_per_fwd`.
- If enabling parallel adjoint would exceed the cap, the run logs a warning and proceeds with the sequential path for that forward run (or a safe subset, depending on policy).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might not want to proceed at all in that case, not sure. Since this a flag that we wouldn't turn on by default, it means that generally the user will have requested it, so they might want to choose to increase the cap instead of running sequentially.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, changed it to raising an AdjointError as we do it currently for sequential adjoint

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be important to explain/understand here in which scenarios how many adjoint simulations would get launched in the parallel case and what the edge cases are so there are no surprises.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some section in the readme

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 7cabb37 to f351142 Compare January 28, 2026 08:19
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from f351142 to 9abfd55 Compare January 28, 2026 09:46
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 9abfd55 to ac73e25 Compare January 28, 2026 11:34
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from ac73e25 to dac5685 Compare January 28, 2026 15:12
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from dac5685 to 29306de Compare January 28, 2026 16:16
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from 29306de to e05cfd7 Compare January 29, 2026 13:54
@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from e05cfd7 to e13d19f Compare January 29, 2026 15:05
Copy link

@cursor cursor bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

@marcorudolphflex marcorudolphflex force-pushed the FXC-3343-all-port-excitation-with-post-normalization-parallel-adjoint branch from e13d19f to 6b77f75 Compare January 29, 2026 16:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants