|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import sys |
| 4 | +from collections.abc import Mapping, Sequence |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import Callable |
| 7 | + |
| 8 | +import autograd as ag |
| 9 | +import autograd.numpy as anp |
| 10 | +import numpy as np |
| 11 | +import pytest |
| 12 | + |
| 13 | +import tidy3d as td |
| 14 | +import tidy3d.web as web |
| 15 | + |
| 16 | + |
| 17 | +@pytest.fixture(autouse=True) |
| 18 | +def _enable_local_cache(monkeypatch): |
| 19 | + # monkeypatch.setattr(td.config.local_cache, "enabled", True) |
| 20 | + pass |
| 21 | + |
| 22 | + |
| 23 | +td.config.local_cache.enabled = True |
| 24 | +WVL0 = 1 |
| 25 | +FREQ0 = td.C_0 / WVL0 |
| 26 | +FWIDTH = FREQ0 / 10 |
| 27 | +PULSE = td.GaussianPulse(freq0=FREQ0, fwidth=FWIDTH) |
| 28 | +SIM_SIZE = (3 * WVL0, 3 * WVL0, 3 * WVL0) |
| 29 | +MONITOR_CENTER = (-0.3, 0.1, 0.2) |
| 30 | +MONITOR_SIZE = (0.5, 0.5, 0) |
| 31 | +SOURCE_SIZE = (1, 1, 0.1) |
| 32 | +SOURCE_CENTER = (0.1, 0.4, -0.2) |
| 33 | +DATASET_SPACING = 0.1 |
| 34 | + |
| 35 | + |
| 36 | +def _axis_coords(size: float, spacing: float) -> np.ndarray: |
| 37 | + """Return 1D coords for an axis; if size==0, return a single point at 0.""" |
| 38 | + if size <= 0: |
| 39 | + return np.array([0.0]) |
| 40 | + # Prefer rounding to avoid silent off-by-one due to float -> int truncation |
| 41 | + n = max(2, int(np.round(size / spacing))) |
| 42 | + return np.linspace(-size / 2, size / 2, n) |
| 43 | + |
| 44 | + |
| 45 | +def _make_coords( |
| 46 | + size_xyz: tuple[float, float, float], spacing: float, freq0: float |
| 47 | +) -> dict[str, object]: |
| 48 | + x = _axis_coords(size_xyz[0], spacing) |
| 49 | + y = _axis_coords(size_xyz[1], spacing) |
| 50 | + z = _axis_coords(size_xyz[2], spacing) |
| 51 | + return {"x": x, "y": y, "z": z, "f": [freq0]} |
| 52 | + |
| 53 | + |
| 54 | +def _make_field_data( |
| 55 | + amp: float, |
| 56 | + shape: tuple[int, int, int, int], |
| 57 | + *, |
| 58 | + add_noise: bool, |
| 59 | + noise_scale: float = 1.0, |
| 60 | + seed: int = 12345, |
| 61 | +) -> np.ndarray: |
| 62 | + """Uniform field with optional deterministic Gaussian noise.""" |
| 63 | + base = amp * np.ones(shape, dtype=float) |
| 64 | + if not add_noise: |
| 65 | + return base |
| 66 | + rng = np.random.RandomState(seed) |
| 67 | + noise = rng.normal(size=shape) |
| 68 | + return base + noise_scale * noise |
| 69 | + |
| 70 | + |
| 71 | +def _make_field_dataset_from_amplitudes( |
| 72 | + amplitudes: Sequence[float], |
| 73 | + field_prefix: str, |
| 74 | + coords: Mapping[str, object], |
| 75 | + *, |
| 76 | + add_noise: bool, |
| 77 | + noise_scale: float = 1.0, |
| 78 | + noise_seed: int = 12345, |
| 79 | +) -> td.FieldDataset: |
| 80 | + x = np.asarray(coords["x"]) |
| 81 | + y = np.asarray(coords["y"]) |
| 82 | + z = np.asarray(coords["z"]) |
| 83 | + shape = (len(x), len(y), len(z), 1) |
| 84 | + |
| 85 | + field_components = {} |
| 86 | + for amp, axis in zip(amplitudes, "xyz"): |
| 87 | + data = _make_field_data( |
| 88 | + amp, |
| 89 | + shape, |
| 90 | + add_noise=add_noise, |
| 91 | + noise_scale=noise_scale, |
| 92 | + seed=noise_seed, |
| 93 | + ) |
| 94 | + field_components[f"{field_prefix}{axis}"] = td.ScalarFieldDataArray(data, coords=coords) |
| 95 | + |
| 96 | + return td.FieldDataset(**field_components) |
| 97 | + |
| 98 | + |
| 99 | +def _make_custom_field_source_components( |
| 100 | + amplitudes: Sequence[float], |
| 101 | + field_prefix: str, |
| 102 | + *, |
| 103 | + add_noise: bool, |
| 104 | +) -> td.CustomFieldSource: |
| 105 | + coords = _make_coords((SOURCE_SIZE[0], SOURCE_SIZE[1], 0.0), DATASET_SPACING, FREQ0) |
| 106 | + field_dataset = _make_field_dataset_from_amplitudes( |
| 107 | + amplitudes, |
| 108 | + field_prefix, |
| 109 | + coords, |
| 110 | + add_noise=add_noise, |
| 111 | + noise_scale=1.0, |
| 112 | + noise_seed=12345, |
| 113 | + ) |
| 114 | + return td.CustomFieldSource( |
| 115 | + center=SOURCE_CENTER, |
| 116 | + size=(SOURCE_SIZE[0], SOURCE_SIZE[1], 0.0), |
| 117 | + source_time=PULSE, |
| 118 | + field_dataset=field_dataset, |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +def _make_custom_current_source_components( |
| 123 | + amplitudes: Sequence[float], |
| 124 | + field_prefix: str, |
| 125 | + *, |
| 126 | + add_noise: bool, |
| 127 | +) -> td.CustomCurrentSource: |
| 128 | + coords = _make_coords(SOURCE_SIZE, DATASET_SPACING, FREQ0) |
| 129 | + field_dataset = _make_field_dataset_from_amplitudes( |
| 130 | + amplitudes, |
| 131 | + field_prefix, |
| 132 | + coords, |
| 133 | + add_noise=add_noise, |
| 134 | + noise_scale=1.0, |
| 135 | + noise_seed=12345, |
| 136 | + ) |
| 137 | + return td.CustomCurrentSource( |
| 138 | + center=SOURCE_CENTER, |
| 139 | + size=SOURCE_SIZE, |
| 140 | + source_time=PULSE, |
| 141 | + current_dataset=field_dataset, |
| 142 | + ) |
| 143 | + |
| 144 | + |
| 145 | +def angled_overlap_deg(v1, v2): |
| 146 | + norm_v1 = np.linalg.norm(v1) |
| 147 | + norm_v2 = np.linalg.norm(v2) |
| 148 | + |
| 149 | + if np.isclose(norm_v1, 0.0) or np.isclose(norm_v2, 0.0): |
| 150 | + if not (np.isclose(norm_v1, 0.0) and np.isclose(norm_v2, 0.0)): |
| 151 | + return np.inf |
| 152 | + |
| 153 | + return 0.0 |
| 154 | + |
| 155 | + dot = np.minimum(1.0, np.sum((v1 / np.linalg.norm(v1)) * (v2 / np.linalg.norm(v2)))) |
| 156 | + angle_deg = np.arccos(dot) * 180.0 / np.pi |
| 157 | + |
| 158 | + return angle_deg |
| 159 | + |
| 160 | + |
| 161 | +def _make_sim(source: td.Source) -> td.Simulation: |
| 162 | + monitor = td.FieldMonitor( |
| 163 | + name="field_monitor", |
| 164 | + center=MONITOR_CENTER, |
| 165 | + size=MONITOR_SIZE, |
| 166 | + freqs=[FREQ0], |
| 167 | + ) |
| 168 | + return td.Simulation( |
| 169 | + size=SIM_SIZE, |
| 170 | + run_time=1e-12, |
| 171 | + grid_spec=td.GridSpec.auto(min_steps_per_wvl=40, wavelength=WVL0), |
| 172 | + sources=[source], |
| 173 | + monitors=[monitor], |
| 174 | + boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()), |
| 175 | + ) |
| 176 | + |
| 177 | + |
| 178 | +def _eval_objective(sim_data: td.SimulationData, field_component: str) -> float: |
| 179 | + field_data = sim_data.load_field_monitor("field_monitor") |
| 180 | + component = getattr(field_data, field_component, None) |
| 181 | + if component is None: |
| 182 | + return 0.0 |
| 183 | + indexers = {} |
| 184 | + for dim in ("x", "y", "z", "f"): |
| 185 | + if dim in component.dims: |
| 186 | + indexers[dim] = component.sizes[dim] // 2 |
| 187 | + field_value = component.isel(**indexers).values |
| 188 | + return anp.abs(field_value) ** 2 |
| 189 | + |
| 190 | + |
| 191 | +def _eval_objective_components( |
| 192 | + sim_data: td.SimulationData, field_components: Sequence[str] |
| 193 | +) -> float: |
| 194 | + total = 0.0 |
| 195 | + for component in field_components: |
| 196 | + total += _eval_objective(sim_data, component) |
| 197 | + return total |
| 198 | + |
| 199 | + |
| 200 | +@dataclass(frozen=True) |
| 201 | +class SourceCase: |
| 202 | + name: str |
| 203 | + monitor_components: tuple[str, str, str] |
| 204 | + delta: float |
| 205 | + make_source: Callable |
| 206 | + |
| 207 | + |
| 208 | +VECTOR_SOURCE_CASES = [ |
| 209 | + # ------------------ |
| 210 | + # custom field source |
| 211 | + # ------------------ |
| 212 | + SourceCase( |
| 213 | + name="custom_field_vec_e", |
| 214 | + monitor_components=("Ex", "Ey", "Ez"), |
| 215 | + delta=1e-4, |
| 216 | + make_source=lambda amps, *, add_noise: _make_custom_field_source_components( |
| 217 | + amps, "E", add_noise=add_noise |
| 218 | + ), |
| 219 | + ), |
| 220 | + SourceCase( |
| 221 | + name="custom_field_vec_h", |
| 222 | + monitor_components=("Hx", "Hy", "Hz"), |
| 223 | + delta=1e-4, |
| 224 | + make_source=lambda amps, *, add_noise: _make_custom_field_source_components( |
| 225 | + amps, "H", add_noise=add_noise |
| 226 | + ), |
| 227 | + ), |
| 228 | + # ------------------ |
| 229 | + # custom current source |
| 230 | + # ------------------ |
| 231 | + SourceCase( |
| 232 | + name="custom_current_vec_e", |
| 233 | + monitor_components=("Ex", "Ey", "Ez"), |
| 234 | + delta=1e-4, |
| 235 | + make_source=lambda amps, *, add_noise: _make_custom_current_source_components( |
| 236 | + amps, "E", add_noise=add_noise |
| 237 | + ), |
| 238 | + ), |
| 239 | + SourceCase( |
| 240 | + name="custom_current_vec_h", |
| 241 | + monitor_components=("Hx", "Hy", "Hz"), |
| 242 | + delta=1e-4, |
| 243 | + make_source=lambda amps, *, add_noise: _make_custom_current_source_components( |
| 244 | + amps, "H", add_noise=add_noise |
| 245 | + ), |
| 246 | + ), |
| 247 | +] |
| 248 | +PARAM_VECTORS = [ |
| 249 | + pytest.param((1.0, -0.5, 0.25), id="p1"), |
| 250 | + pytest.param((0.2, 0.7, -0.9), id="p2"), |
| 251 | +] |
| 252 | + |
| 253 | +NOISE_CASES = [ |
| 254 | + pytest.param(False, id="no_noise"), |
| 255 | + pytest.param(True, id="noise"), |
| 256 | +] |
| 257 | + |
| 258 | + |
| 259 | +@pytest.mark.numerical |
| 260 | +@pytest.mark.parametrize("params", PARAM_VECTORS) |
| 261 | +@pytest.mark.parametrize("add_noise", NOISE_CASES) |
| 262 | +@pytest.mark.parametrize("case", VECTOR_SOURCE_CASES, ids=lambda case: case.name) |
| 263 | +def test_custom_source_gradients( |
| 264 | + _enable_local_cache, |
| 265 | + tmp_path, |
| 266 | + case, |
| 267 | + params, |
| 268 | + add_noise, |
| 269 | +): |
| 270 | + delta = case.delta |
| 271 | + monitor_components = case.monitor_components |
| 272 | + label = f"{case.name}_{'noise' if add_noise else 'clean'}_{params!r}" |
| 273 | + |
| 274 | + make_source = lambda amps: case.make_source(amps, add_noise=add_noise) |
| 275 | + |
| 276 | + def objective_adj(ax, ay, az): |
| 277 | + sim = _make_sim(make_source((ax, ay, az))) |
| 278 | + sim_data = web.run( |
| 279 | + sim, |
| 280 | + task_name=f"{label}_adj", |
| 281 | + path=tmp_path / f"{label}_adj.hdf5", |
| 282 | + local_gradient=True, |
| 283 | + verbose=False, |
| 284 | + ) |
| 285 | + return _eval_objective_components(sim_data, monitor_components) |
| 286 | + |
| 287 | + grad_adjoint = np.array( |
| 288 | + [ |
| 289 | + ag.grad(objective_adj, 0)(*params), |
| 290 | + ag.grad(objective_adj, 1)(*params), |
| 291 | + ag.grad(objective_adj, 2)(*params), |
| 292 | + ], |
| 293 | + dtype=float, |
| 294 | + ) |
| 295 | + |
| 296 | + sims = {} |
| 297 | + for idx, axis in enumerate("xyz"): |
| 298 | + params_plus = list(params) |
| 299 | + params_plus[idx] += delta |
| 300 | + params_minus = list(params) |
| 301 | + params_minus[idx] -= delta |
| 302 | + sims[f"{label}_fd_{axis}_plus"] = _make_sim(make_source(tuple(params_plus))) |
| 303 | + sims[f"{label}_fd_{axis}_minus"] = _make_sim(make_source(tuple(params_minus))) |
| 304 | + |
| 305 | + sim_data_map = web.run_async( |
| 306 | + sims, |
| 307 | + path_dir=tmp_path, |
| 308 | + local_gradient=False, |
| 309 | + verbose=False, |
| 310 | + ) |
| 311 | + |
| 312 | + grad_fd = np.zeros(3, dtype=float) |
| 313 | + for idx, axis in enumerate("xyz"): |
| 314 | + obj_plus = _eval_objective_components( |
| 315 | + sim_data_map[f"{label}_fd_{axis}_plus"], monitor_components |
| 316 | + ) |
| 317 | + obj_minus = _eval_objective_components( |
| 318 | + sim_data_map[f"{label}_fd_{axis}_minus"], monitor_components |
| 319 | + ) |
| 320 | + grad_fd[idx] = (obj_plus - obj_minus) / (2 * delta) |
| 321 | + |
| 322 | + angle = angled_overlap_deg(grad_adjoint, grad_fd) |
| 323 | + print(f"[{label}] grad_adjoint = {grad_adjoint}", file=sys.stderr) |
| 324 | + print(f"[{label}] grad_fd = {grad_fd}", file=sys.stderr) |
| 325 | + print(f"[{label}] angle_deg = {angle}", file=sys.stderr) |
| 326 | + |
| 327 | + assert angle < 5.0 |
0 commit comments