diff --git a/CHANGELOG.md b/CHANGELOG.md index c8462720a8..4d0094e2ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -59,6 +59,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - For `HeatChargeSimulation` objects, the `plot` function now adds the simulation boundary conditions. +- Improved memory performance in `postprocess_adj` by using `isel` instead of `sel` for xarray slicing. ### Fixed - Fixed `AutoImpedanceSpec` validation to check path intersections against all conductors, not just filtered ones, as well as the mode plane bounds. diff --git a/pyproject.toml b/pyproject.toml index 66341822d9..d6e42e52b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -309,9 +309,11 @@ banned-module-level-imports = ["scipy", "matplotlib"] [tool.pytest.ini_options] # TODO: remove --assert=plain when https://github.com/scipy/scipy/issues/22236 is resolved -addopts = "--cov=tidy3d --doctest-modules -n auto --dist worksteal --assert=plain -m 'not numerical and not perf'" +addopts = "--cov=tidy3d --doctest-modules -n auto --dist worksteal --assert=plain -m 'not numerical and not perf and not generate_profile and not performance_profile'" markers = [ "numerical: marks numerical tests for adjoint gradients that require running simulations (deselect with '-m \"not numerical\"')", + "generate_profile: regenerate fixture data required by the performance tests (select with '-m \"generate_profile\"')", + "performance_profile: enable CPU and memory profiling based on data generated with generate_profile (select with '-m \"performance_profile\"')", "perf: marks tests which test the runtime of operations (deselect with '-m \"not perf\"')", "slow: marks tests as slow (deselect with -m 'not slow')", ] diff --git a/tests/test_components/autograd/numerical/test_autograd_frequency_selection.py b/tests/test_components/autograd/numerical/test_autograd_frequency_selection.py new file mode 100644 index 0000000000..a5e6295276 --- /dev/null +++ b/tests/test_components/autograd/numerical/test_autograd_frequency_selection.py @@ -0,0 +1,438 @@ +# test autograd postprocess for frequency slicing to ensure gradients are consistent +from __future__ import annotations + +import operator + +import autograd as ag +import matplotlib.pylab as plt +import numpy as np +import pytest + +import tidy3d as td +import tidy3d.web as web +from tidy3d.config import config + +PLOT_ADJ_COMPARISON = True # False +NUM_FINITE_DIFFERENCE = 10 +SAVE_FD_ADJ_DATA = False +SAVE_FD_LOC = 0 +SAVE_ADJ_LOC = 1 +LOCAL_GRADIENT = True +VERBOSE = False +NUMERICAL_RESULTS_SUBDIR = "numerical_test_frequency_selection" + +if PLOT_ADJ_COMPARISON: + pytestmark = pytest.mark.usefixtures("mpl_config_interactive") +else: + pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") + +MESH_FACTOR_DESIGN = 60.0 + + +def get_sim_geometry(mesh_wvl_um, offset_y_size_wvl=0): + return td.Box(size=(7 * mesh_wvl_um, 7 * mesh_wvl_um, 3 * mesh_wvl_um), center=(0, 0, 0)) + + +def make_base_sim( + mesh_wvl_um, + adj_wvl_um, + geometry_size_wvl, + box_for_override, + required_freqs, + num_extra_freqs, + rng, + run_time=5e-11, +): + """Creates a base simulation with input/output waveguides, mode sources, and mode monitors.""" + sim_geometry = get_sim_geometry(mesh_wvl_um) + sim_size_um = sim_geometry.size + sim_center_um = sim_geometry.center + + boundary_spec = td.BoundarySpec( + x=td.Boundary.pml(), + y=td.Boundary.pml(), + z=td.Boundary.pml(), + ) + + dl_design = mesh_wvl_um / MESH_FACTOR_DESIGN + + mesh_overrides = [] + mesh_overrides.extend( + [ + td.MeshOverrideStructure( + geometry=box_for_override, + dl=[dl_design, dl_design, dl_design], + ), + ] + ) + + src_size = sim_size_um[0:2] + (0,) + + wl_min_src_um = 0.9 * adj_wvl_um + wl_max_src_um = 1.1 * adj_wvl_um + + fwidth_src = td.C_0 * ((1.0 / wl_min_src_um) - (1.0 / wl_max_src_um)) + freq0 = td.C_0 / adj_wvl_um + + wg_input_left = -0.75 * sim_size_um[0] + wg_input_right = sim_center_um[0] - 0.5 * geometry_size_wvl[0] * mesh_wvl_um + + wg_output_left = sim_center_um[0] + 0.5 * geometry_size_wvl[0] * mesh_wvl_um + wg_output_right = 0.75 * sim_size_um[0] + + wg_input_center = 0.5 * (wg_input_left + wg_input_right) + wg_output_center = 0.5 * (wg_output_left + wg_output_right) + + wg_input_length = wg_input_right - wg_input_left + wg_output_length = wg_output_right - wg_output_left + + src_input_center = 0.5 * (-0.5 * sim_size_um[0] + wg_input_right) + monitor_output_center = 0.5 * (0.5 * sim_size_um[0] + wg_output_left) + + output_wg_y_offset_um = 0.5 * mesh_wvl_um + + mode_layer_height_um = MODE_LAYER_HEIGHT_WVL * adj_wvl_um + input_waveguide_geometry = td.Box( + center=(wg_input_center, 0, 0.5 * mode_layer_height_um), + size=(wg_input_length, WG_WIDTH_WVL * adj_wvl_um, mode_layer_height_um), + ) + output_waveguide_geometry = td.Box( + center=(wg_output_center, output_wg_y_offset_um, 0.5 * mode_layer_height_um), + size=(wg_output_length, WG_WIDTH_WVL * adj_wvl_um, mode_layer_height_um), + ) + output_waveguide_geometry2 = td.Box( + center=(wg_output_center, -output_wg_y_offset_um, 0.5 * mode_layer_height_um), + size=(wg_output_length, WG_WIDTH_WVL * adj_wvl_um, mode_layer_height_um), + ) + + input_waveguide = td.Structure( + geometry=input_waveguide_geometry, medium=td.Medium(permittivity=WG_INDEX**2) + ) + output_waveguide = td.Structure( + geometry=output_waveguide_geometry, medium=td.Medium(permittivity=WG_INDEX**2) + ) + output_waveguide2 = td.Structure( + geometry=output_waveguide_geometry2, medium=td.Medium(permittivity=WG_INDEX**2) + ) + + substrate_max = 0 + substrate_min = -0.75 * sim_size_um[2] + substrate = td.Structure( + geometry=td.Box( + center=(sim_center_um[0], sim_center_um[1], 0.5 * (substrate_max + substrate_min)), + size=(1.5 * sim_size_um[0], 1.5 * sim_size_um[1], (substrate_max - substrate_min)), + ), + medium=td.Medium(permittivity=SUBSTRATE_INDEX**2), + ) + + min_required_freq = np.min(required_freqs) + max_required_freq = np.max(required_freqs) + + required_freq_span = max_required_freq - min_required_freq + + random_other_freqs = rng.uniform( + low=min_required_freq - 0.1 * required_freq_span, + high=max_required_freq + 0.1 * required_freq_span, + size=num_extra_freqs, + ) + + mode_monitor_freqs = sorted(list(required_freqs) + list(random_other_freqs)) + mode_monitor_freqs = np.flip(mode_monitor_freqs) + mode_monitor_top = td.ModeMonitor( + center=( + monitor_output_center + 0.15 * mesh_wvl_um, + output_wg_y_offset_um, + 0.5 * mode_layer_height_um, + ), + size=(0, 5 * WG_WIDTH_WVL * mesh_wvl_um, 5 * mode_layer_height_um), + name="monitor_mode_top", + freqs=mode_monitor_freqs, + ) + + mode_monitor_bottom = td.ModeMonitor( + center=(monitor_output_center, -output_wg_y_offset_um, 0.5 * mode_layer_height_um), + size=(0, 5 * WG_WIDTH_WVL * mesh_wvl_um, 5 * mode_layer_height_um), + name="monitor_mode_bottom", + freqs=mode_monitor_freqs, + ) + + pulse = td.GaussianPulse(freq0=freq0, fwidth=fwidth_src) + mode_src = td.ModeSource( + center=(src_input_center, 0, 0.5 * mode_layer_height_um), + size=(0, 5 * WG_WIDTH_WVL * mesh_wvl_um, 5 * mode_layer_height_um), + name="src_mode", + source_time=pulse, + direction="+", + ) + + sim_base = td.Simulation( + center=sim_center_um, + size=sim_size_um, + grid_spec=td.GridSpec.auto( + min_steps_per_wvl=20, + wavelength=mesh_wvl_um, + override_structures=mesh_overrides, + ), + structures=[input_waveguide, output_waveguide, output_waveguide2, substrate], + sources=[mode_src], + monitors=[mode_monitor_top, mode_monitor_bottom], + run_time=run_time, + boundary_spec=boundary_spec, + subpixel=True, + ) + + return sim_base + + +def create_objective_function( + create_sim_base, + eval_fn, + sim_path_dir, + mode_layer_height_um, + polyslab_height_um, + polyslab_permittivity, +): + """Create an objective function for the test based on the base simulation, type of evaluation function on + the electromagnetic data, and geometric and material parameters.""" + + def objective(vertices): + sim_base = create_sim_base() + + simulation_dict = {} + for idx in range(len(vertices)): + vertices_x = vertices[idx][0:NUM_VERTICES] + vertices_y = vertices[idx][NUM_VERTICES:] + + make_polyslab = td.PolySlab( + slab_bounds=( + 0.5 * mode_layer_height_um - 0.5 * polyslab_height_um, + 0.5 * mode_layer_height_um + 0.5 * polyslab_height_um, + ), + axis=2, + vertices=tuple(zip(vertices_x, vertices_y)), + ) + + polyslab_structure = td.Structure( + geometry=make_polyslab, medium=td.Medium(permittivity=polyslab_permittivity) + ) + + sim_with_polyslab = sim_base.updated_copy( + structures=(*sim_base.structures, polyslab_structure) + ) + + simulation_dict[f"numerical_mode_polyslab_testing_{idx}"] = sim_with_polyslab.copy() + + sim_data = web.run_async( + simulation_dict, path_dir=sim_path_dir, local_gradient=LOCAL_GRADIENT, verbose=VERBOSE + ) + + objective_vals = [] + for idx in range(len(vertices)): + objective_vals.append(eval_fn(sim_data[f"numerical_mode_polyslab_testing_{idx}"])) + + if len(vertices) == 1: + return objective_vals[0] + + return objective_vals + + return objective + + +# Parameters for controlling the test geometry and material parameters as well as the +# array of tests to run. +MODE_LAYER_HEIGHT_WVL = 0.25 +POLYSLAB_HEIGHT_WVL = MODE_LAYER_HEIGHT_WVL / 8.0 +WG_WIDTH_WVL = 0.275 +SUBSTRATE_INDEX = 1.5 + +WG_INDEX = 3.5 + +# Number of vertices to put in the test polyslab. +NUM_VERTICES = 15 +NUM_EXTRA_FREQS = 10 +NUM_OBJECTIVES = 5 + +mesh_wvls_um = [1.55] +adj_wvls_um = [1.5] +geometry_sizes_wvl = [(3.0, 3.0, MODE_LAYER_HEIGHT_WVL)] +polyslab_indices = np.linspace(SUBSTRATE_INDEX, WG_INDEX, 3) + +mode_data_test_parameters = [] + +test_number = 0 +for idx in range(len(mesh_wvls_um)): + mesh_wvl_um = mesh_wvls_um[idx] + adj_wvl_um = adj_wvls_um[idx] + + for geometry_size_wvl in geometry_sizes_wvl: + for polyslab_index in polyslab_indices: + polyslab_permittivity = polyslab_index**2 + + mode_data_test_parameters.append( + { + "mesh_wvl_um": mesh_wvl_um, + "adj_wvl_um": adj_wvl_um, + "geometry_size_wvl": geometry_size_wvl, + "polyslab_permittivity": polyslab_permittivity, + "test_number": test_number, + } + ) + + test_number += 1 + + +@pytest.mark.numerical +@pytest.mark.parametrize("mode_data_test_parameters", mode_data_test_parameters) +def test_finite_difference_mode_data_polyslab( + mode_data_test_parameters, rng, monkeypatch, numerical_case_dir, redirect_stdout_to_stderr +): + """Test a variety of frequency combinations in the forward monitor to ensure we get + the same adjoint gradient when using a fixed set of frequencies in the objective function.""" + + monkeypatch.setattr(config.adjoint, "solver_freq_chunk_size", 2) + + test_results = np.zeros((2, NUM_FINITE_DIFFERENCE)) + + test_number = mode_data_test_parameters["test_number"] + + ( + mesh_wvl_um, + adj_wvl_um, + geometry_size_wvl, + polyslab_permittivity, + test_number, + ) = operator.itemgetter( + "mesh_wvl_um", + "adj_wvl_um", + "geometry_size_wvl", + "polyslab_permittivity", + "test_number", + )(mode_data_test_parameters) + + adj_freq = td.C_0 / adj_wvl_um + + dim_x_um = geometry_size_wvl[0] * mesh_wvl_um * 2 + dim_y_um = geometry_size_wvl[1] * mesh_wvl_um * 2 + thickness_um = geometry_size_wvl[2] * mesh_wvl_um + + dim_x = 1 + int(dim_x_um / (mesh_wvl_um / MESH_FACTOR_DESIGN)) + dim_y = 1 + int(dim_y_um / (mesh_wvl_um / MESH_FACTOR_DESIGN)) + Nz = 1 + int(thickness_um / (mesh_wvl_um / MESH_FACTOR_DESIGN)) + + sim_geometry = get_sim_geometry(mesh_wvl_um) + + box_for_override = td.Box( + center=(0, 0, 0), + size=(np.inf, np.inf, MODE_LAYER_HEIGHT_WVL * mesh_wvl_um + mesh_wvl_um), + ) + + sim_path_dir = numerical_case_dir / "simulations" / f"test{test_number}" + sim_path_dir.mkdir(parents=True, exist_ok=True) + + # Weights for creating a random objective function over multiple frequencies by + # summing their contributions by random weights. This helps verify gradient errors + # due to a multifrequency objective function. + monitor_top_weights = rng.random(2) + monitor_bottom_weights = rng.random(3) + + def make_random_objective(): + required_freqs = [0.95 * adj_freq, 1.01 * adj_freq, 1.03 * adj_freq] + + def eval_fn(sim_data): + return np.sum( + monitor_top_weights + * np.abs( + sim_data["monitor_mode_top"] + .amps.sel(direction="+") + .sel(f=required_freqs[0:2]) + .data + ) + ** 2 + ) + np.sum( + monitor_bottom_weights + * np.abs( + sim_data["monitor_mode_bottom"] + .amps.sel(direction="+") + .sel(f=required_freqs) + .data + ) + ** 2 + ) + + polyslab_height_um = POLYSLAB_HEIGHT_WVL * adj_wvl_um + + objective = create_objective_function( + lambda mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + geometry_size_wvl=geometry_size_wvl, + polyslab_permittivity=polyslab_permittivity, + box_for_override=box_for_override, + required_freqs=required_freqs, + rng=rng: make_base_sim( + mesh_wvl_um=mesh_wvl_um, + adj_wvl_um=adj_wvl_um, + geometry_size_wvl=geometry_size_wvl, + box_for_override=box_for_override, + required_freqs=required_freqs, + num_extra_freqs=NUM_EXTRA_FREQS, + rng=rng, + run_time=2e-11, + ), + eval_fn, + sim_path_dir=str(sim_path_dir), + mode_layer_height_um=MODE_LAYER_HEIGHT_WVL * mesh_wvl_um, + polyslab_height_um=polyslab_height_um, + polyslab_permittivity=polyslab_permittivity, + ) + + return objective + + objectives = [make_random_objective() for idx in range(NUM_OBJECTIVES)] + objective_grads = [ag.grad(obj) for obj in objectives] + + angles = np.linspace(0, 2 * np.pi, NUM_VERTICES + 1)[0:-1] + vertex_centers_x = 1.1 * mesh_wvl_um * np.cos(angles) + vertex_centers_y = 0.8 * mesh_wvl_um * np.sin(angles) + + input_data = [list(vertex_centers_x) + list(vertex_centers_y)] + gradients = [np.squeeze(np.array(grad_fn(input_data)).flatten()) for grad_fn in objective_grads] + + if PLOT_ADJ_COMPARISON: + for grad in gradients: + plt.plot(grad) + plt.xlabel("vertex") + plt.ylabel("gradient") + plt.show() + + cumulative_rms_error = 0.0 + num = 0 + for grad_1_idx in range(len(gradients)): + for grad_2_idx in range(grad_1_idx + 1, len(gradients)): + grad_1 = gradients[grad_1_idx] + grad_2 = gradients[grad_2_idx] + + cumulative_rms_error += np.sqrt(np.mean((grad_1 - grad_2) ** 2)) + num += 1 + + avg_rms_error = cumulative_rms_error / num + + print("\n" * 3) + print("-" * 20) + print(f"Numerical test #{test_number}") + print(f"Mesh and adjoint wavelengths: {mesh_wvl_um}, {adj_wvl_um}") + print(f"Geometry size: {geometry_size_wvl}") + print(f"Average RMS Error: {avg_rms_error}") + print("-" * 20) + print("\n" * 3) + + save_path = None + if SAVE_FD_ADJ_DATA: + results_dir = numerical_case_dir / NUMERICAL_RESULTS_SUBDIR + results_dir.mkdir(parents=True, exist_ok=True) + save_path = results_dir / f"results_{test_number}.npy" + + try: + assert np.isclose(avg_rms_error, 0.0), "RMS error magnitude too large" + finally: + if save_path is not None: + np.save(save_path, test_results) diff --git a/tests/test_components/autograd/performance/__init__.py b/tests/test_components/autograd/performance/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_components/autograd/performance/postprocess_adj_utils.py b/tests/test_components/autograd/performance/postprocess_adj_utils.py new file mode 100644 index 0000000000..362a39647c --- /dev/null +++ b/tests/test_components/autograd/performance/postprocess_adj_utils.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import json +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + +import autograd.numpy as np + +import tidy3d as td +from tidy3d.web import run + + +@dataclass +class PostprocessAdjInputs: + """Container for the inputs required by ``postprocess_adj``.""" + + sim_data_adj: td.SimulationData + sim_data_orig: td.SimulationData + sim_data_fwd: td.SimulationData + sim_fields_keys: tuple[tuple, ...] + + +DATASET_ROOT = Path(__file__).resolve().parents[3] / "_test_data" / "autograd" / "postprocess_adj" +SIM_DATA_ADJ_NAME = "sim_data_adj.hdf5" +SIM_DATA_ORIG_NAME = "sim_data_orig.hdf5" +SIM_DATA_FWD_NAME = "sim_data_fwd.hdf5" +SIM_FIELDS_KEYS_NAME = "sim_fields_keys.json" + + +def generate_postprocess_adj_inputs(output_dir: Path) -> PostprocessAdjInputs: + """Generate inputs for ``postprocess_adj``. + + Parameters + ---------- + output_dir: + Directory that can be used to persist intermediate artefacts while generating the dataset. + + Returns + ------- + PostprocessAdjInputs + Fully populated structure ready to be passed to ``postprocess_adj``. + + Notes + ----- + Return ``None`` or raise ``NotImplementedError`` if the dataset cannot be generated in the + current environment; the caller will convert that into a skipped test instead of a failure. + """ + + N_structures_per_dim = 35 + spacing_per_structure = 0.5 + structure_buffer_per_side = 1 + + dim = (N_structures_per_dim - 1) * spacing_per_structure + 2 * structure_buffer_per_side + dim_z = 4.0 + + N_freqs = 2 + wl = 0.65 + freq0 = td.C_0 / wl + fwidth = 0.2 * freq0 + + freqs = np.linspace(freq0 - 0.25 * fwidth, freq0 + 0.25 * fwidth, N_freqs) + + refr_index = 2.5 + permittivity = refr_index**2 + + fwd_source = td.PlaneWave( + center=(0.0, 0.0, -0.25 * dim_z), + size=(td.inf, td.inf, 0.0), + source_time=td.GaussianPulse(freq0=freq0, fwidth=fwidth), + direction="+", + ) + + # dummy adjoint source + adj_source = td.PlaneWave( + center=(0.0, 0.0, 0.25 * dim_z), + size=(td.inf, td.inf, 0.0), + source_time=td.GaussianPulse(freq0=freq0, fwidth=fwidth), + direction="-", + ) + + geometries = [] + for x_idx in range(N_structures_per_dim): + x_pos = -0.5 * dim + structure_buffer_per_side + x_idx * spacing_per_structure + for y_idx in range(N_structures_per_dim): + y_pos = -0.5 * dim + structure_buffer_per_side + y_idx * spacing_per_structure + + geometries.append( + td.Cylinder( + center=(x_pos, y_pos, 0.0), radius=0.25 * spacing_per_structure, length=0.5 * wl + ) + ) + + geom_group = td.GeometryGroup(geometries=geometries) + + structure = td.Structure(geometry=geom_group, medium=td.Medium(permittivity=permittivity)) + + adj_fld_monitor = td.FieldMonitor( + center=(0.0, 0.0, 0.0), + size=structure.geometry.bounding_box.size, + freqs=freqs, + fields=["Ex", "Ey", "Ez"], + name="adjoint_fld_0", + colocate=False, + ) + + adj_perm_monitor = td.PermittivityMonitor( + center=(0.0, 0.0, 0.0), + size=structure.geometry.bounding_box.size, + freqs=freqs, + name="adjoint_eps_0", + ) + + fwd_sim = td.Simulation( + center=(0.0, 0.0, 0.0), + size=(dim, dim, dim_z), + structures=[structure], + monitors=[adj_fld_monitor, adj_perm_monitor], + sources=[fwd_source], + run_time=1e-11, + boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()), + grid_spec=td.GridSpec.auto(wavelength=wl, min_steps_per_wvl=10), + ) + + adj_sim = td.Simulation( + center=(0.0, 0.0, 0.0), + size=(dim, dim, dim_z), + structures=[structure], + monitors=[adj_fld_monitor, adj_perm_monitor], + sources=[adj_source], + run_time=1e-11, + boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()), + grid_spec=td.GridSpec.auto(wavelength=wl, min_steps_per_wvl=10), + ) + + sim_data_fwd = run(fwd_sim, task_name="perf_sim_fwd") + sim_data_adj = run(adj_sim, task_name="perf_sim_adj") + + sim_fields_keys = [] + for idx, _ in enumerate(geom_group.geometries): + sim_fields_keys.append(("structure", 0, "geometry", "geometries", idx, "radius")) + + return PostprocessAdjInputs( + sim_data_adj=sim_data_adj, + sim_data_orig=sim_data_fwd, + sim_data_fwd=sim_data_fwd, + sim_fields_keys=tuple(sim_fields_keys), + ) + + +def dataset_exists(dataset_dir: Path = DATASET_ROOT) -> bool: + """Return ``True`` when all persisted artefacts for a dataset are present.""" + required = [ + dataset_dir / SIM_DATA_ADJ_NAME, + dataset_dir / SIM_DATA_ORIG_NAME, + dataset_dir / SIM_DATA_FWD_NAME, + dataset_dir / SIM_FIELDS_KEYS_NAME, + ] + return all(path.exists() for path in required) + + +def persist_postprocess_adj_inputs( + inputs: PostprocessAdjInputs, dataset_dir: Path = DATASET_ROOT +) -> None: + """Persist ``postprocess_adj`` inputs to disk for reuse in performance tests.""" + dataset_dir.mkdir(parents=True, exist_ok=True) + + inputs.sim_data_adj.to_file(dataset_dir / SIM_DATA_ADJ_NAME) + inputs.sim_data_orig.to_file(dataset_dir / SIM_DATA_ORIG_NAME) + inputs.sim_data_fwd.to_file(dataset_dir / SIM_DATA_FWD_NAME) + + serializable_paths = [_serialize_path(path) for path in inputs.sim_fields_keys] + (dataset_dir / SIM_FIELDS_KEYS_NAME).write_text(json.dumps(serializable_paths, indent=2)) + + +def load_postprocess_adj_inputs(dataset_dir: Path = DATASET_ROOT) -> PostprocessAdjInputs: + """Load a persisted dataset for ``postprocess_adj`` performance tests.""" + if not dataset_exists(dataset_dir): + raise FileNotFoundError( + "Persisted postprocess_adj dataset is missing. " + "Run the generation flow first to create it." + ) + + sim_data_adj = td.SimulationData.from_file(dataset_dir / SIM_DATA_ADJ_NAME) + sim_data_orig = td.SimulationData.from_file(dataset_dir / SIM_DATA_ORIG_NAME) + sim_data_fwd = td.SimulationData.from_file(dataset_dir / SIM_DATA_FWD_NAME) + + serializable = json.loads((dataset_dir / SIM_FIELDS_KEYS_NAME).read_text()) + sim_fields_keys = [_deserialize_path(path) for path in serializable] + + return PostprocessAdjInputs( + sim_data_adj=sim_data_adj, + sim_data_orig=sim_data_orig, + sim_data_fwd=sim_data_fwd, + sim_fields_keys=tuple(sim_fields_keys), + ) + + +def _serialize_path(path: Iterable) -> list: + """Convert a tuple path into JSON serializable data.""" + serialized = [] + for value in path: + if isinstance(value, (str, int)): + serialized.append(value) + elif hasattr(value, "tolist"): + serialized.append(value.tolist()) + else: + serialized.append(str(value)) + return serialized + + +def _deserialize_path(path: Iterable) -> tuple: + """Convert JSON serialized path data back into tuple form.""" + deserialized = [] + for value in path: + deserialized.append(value) + return tuple(deserialized) diff --git a/tests/test_components/autograd/performance/test_shape_performance.py b/tests/test_components/autograd/performance/test_shape_performance.py new file mode 100644 index 0000000000..913df3878d --- /dev/null +++ b/tests/test_components/autograd/performance/test_shape_performance.py @@ -0,0 +1,109 @@ +# test autograd and compares to numerically computed finite difference gradients +from __future__ import annotations + +import cProfile +import pstats +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import pytest + +from tidy3d.web.api.autograd.backward import postprocess_adj + +from .postprocess_adj_utils import ( + DATASET_ROOT, + PostprocessAdjInputs, + dataset_exists, + generate_postprocess_adj_inputs, + load_postprocess_adj_inputs, + persist_postprocess_adj_inputs, +) + +MARK_GENERATE_NAME = "generate_profile" + + +@dataclass +class ProfileArtifacts: + cpu_profile_path: Path + stats_text_path: Path + peak_memory_bytes: int + + +def _run_postprocess_adj(inputs: PostprocessAdjInputs) -> dict: + return postprocess_adj( + sim_data_adj=inputs.sim_data_adj, + sim_data_orig=inputs.sim_data_orig, + sim_data_fwd=inputs.sim_data_fwd, + sim_fields_keys=inputs.sim_fields_keys, + ) + + +def _load_inputs_or_skip() -> PostprocessAdjInputs: + if not dataset_exists(DATASET_ROOT): + pytest.skip( + f"Persisted postprocess_adj dataset missing at {DATASET_ROOT}. " + f"Run with @{MARK_GENERATE_NAME} first." + ) + return load_postprocess_adj_inputs(DATASET_ROOT) + + +def _profile_callable(func: Callable[[], dict], output_dir: Path) -> ProfileArtifacts: + output_dir.mkdir(parents=True, exist_ok=True) + + try: + import tracemalloc + except ImportError: + tracemalloc = None + + if tracemalloc: + tracemalloc.start() + + profile = cProfile.Profile() + profile.enable() + result = func() + profile.disable() + + cpu_profile_path = output_dir / "postprocess_adj.prof" + profile.dump_stats(cpu_profile_path) + + peak_memory_bytes = -1 + if tracemalloc: + _, peak_memory_bytes = tracemalloc.get_traced_memory() + tracemalloc.stop() + + stats_text_path = output_dir / "postprocess_adj_profile.txt" + with stats_text_path.open("w", encoding="utf8") as stats_file: + stats = pstats.Stats(profile, stream=stats_file) + stats.sort_stats("cumulative") + stats.print_stats(40) + stats_file.write(f"\nResult keys: {sorted(result.keys())}\n") + stats_file.write(f"\nPeak memory GB: {peak_memory_bytes / (1024.0 * 1024.0 * 1024.0)}\n") + + return ProfileArtifacts( + cpu_profile_path=cpu_profile_path, + stats_text_path=stats_text_path, + peak_memory_bytes=peak_memory_bytes, + ) + + +@pytest.mark.generate_profile +def test_generate_postprocess_adj_dataset(tmp_path: Path, redirect_stdout_to_stderr): + """Generate test data for running test_postprocess_adj_profile.""" + print(f"Profile data in {tmp_path}") + + inputs = generate_postprocess_adj_inputs(tmp_path) + + persist_postprocess_adj_inputs(inputs, DATASET_ROOT) + assert dataset_exists(DATASET_ROOT) + + +@pytest.mark.performance_profile +def test_postprocess_adj_profile(tmp_path: Path, redirect_stdout_to_stderr): + """Run profiling for postprocess_adj after having generated test data with test_generate_postprocess_adj_dataset""" + print(f"Profile data in {tmp_path}") + + inputs = _load_inputs_or_skip() + artifacts = _profile_callable(lambda: _run_postprocess_adj(inputs), tmp_path / "profile") + assert artifacts.cpu_profile_path.exists() + assert artifacts.stats_text_path.exists() diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index c3491de611..78e6e1c982 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -3229,22 +3229,15 @@ def test_frequency_coordinate_alignment(): field_data = {"Ex": data, "Ey": data, "Ez": data} # Test 1: Exact match should work - freqs_exact = np.array([freq]) - result = _slice_field_data(field_data, freqs_exact) + result = _slice_field_data(field_data, slice(0, 1)) assert len(result) == 3 assert all(k in result for k in ["Ex", "Ey", "Ez"]) - # Test 2: Tiny FP drift (within typical precision) should fail with KeyError - # This demonstrates the original bug - even 0.1 Hz difference causes failure - freqs_drifted = np.array([freq + 0.1]) # 0.1 Hz drift at 2e14 Hz scale - with pytest.raises(KeyError): - _slice_field_data(field_data, freqs_drifted) - - # Test 3: Component indicator filtering works - result_e_only = _slice_field_data(field_data, freqs_exact, component_indicator="E") + # Test 2: Component indicator filtering works + result_e_only = _slice_field_data(field_data, slice(0, 1), component_indicator="E") assert len(result_e_only) == 3 - # Test 4: Multiple frequencies + # Test 3: Multiple frequencies freqs_multi = [1e14, 2e14, 3e14] data_multi = xr.DataArray( np.array([1.0, 2.0, 3.0]), @@ -3254,12 +3247,17 @@ def test_frequency_coordinate_alignment(): field_data_multi = {"Ex": data_multi} # Selecting subset should work - result_subset = _slice_field_data(field_data_multi, np.array([2e14])) + result_subset = _slice_field_data( + field_data_multi, slice(freqs_multi.index(2e14), 1 + freqs_multi.index(2e14)) + ) assert result_subset["Ex"].sizes["f"] == 1 # Selecting non-existent frequency should fail - with pytest.raises(KeyError): - _slice_field_data(field_data_multi, np.array([1.5e14])) + with pytest.raises(IndexError): + _slice_field_data(field_data_multi, slice(len(freqs_multi), len(freqs_multi) + 1)) + + with pytest.raises(IndexError): + _slice_field_data(field_data_multi, slice(-1, len(freqs_multi))) def test_geometry_group_passes_intersected_bounds_to_children(): diff --git a/tidy3d/web/api/autograd/backward.py b/tidy3d/web/api/autograd/backward.py index d1481f4013..5226673b25 100644 --- a/tidy3d/web/api/autograd/backward.py +++ b/tidy3d/web/api/autograd/backward.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections import defaultdict +from typing import Union import numpy as np import xarray as xr @@ -94,13 +95,30 @@ def _compute_eps_array(medium: Medium, frequencies: list[float]) -> DataArray: def _slice_field_data( - field_data: dict, freqs: np.ndarray, component_indicator: str | None = None + field_data: dict, freq_indices: slice, component_indicator: str | None = None ) -> dict: - """Slice field data dictionary along frequency dimension.""" - if component_indicator: - return {k: v.sel(f=freqs) for k, v in field_data.items() if component_indicator in k} - else: - return {k: v.sel(f=freqs) for k, v in field_data.items()} + """ + Slice field data dictionary along frequency dimension using `isel` + and freq_indices. + """ + sliced_data = {} + + # filter keys first to avoid unnecessary looping + keys_to_process = ( + k for k in field_data.keys() if component_indicator is None or component_indicator in k + ) + + num_freqs = next(iter(field_data.values())).sizes["f"] + + start = freq_indices.start + stop = freq_indices.stop + if (start < 0) or (start >= num_freqs): + raise IndexError(f"Frequency slice ({start}, {stop}) is out of bounds for size {num_freqs}") + + for k in keys_to_process: + sliced_data[k] = field_data[k].isel(f=freq_indices) + + return sliced_data @disable_local_subpixel @@ -127,6 +145,23 @@ def postprocess_adj( fld_adj = sim_data_adj._get_adjoint_data(structure_index, data_type="fld") eps_adj = sim_data_adj._get_adjoint_data(structure_index, data_type="eps") + def sort_by_freq_ascending( + dataset: Union[td.PermittivityData, td.FieldData], + ) -> Union[td.PermittivityData, td.FieldData]: + dataset_sort = {} + for key, val in dataset.field_components.items(): + dataset_sort[key] = val.sortby("f", ascending=True) + + return dataset.updated_copy(**dataset_sort) + + # sort data by ascending frequency value to ensure data ordering is consistent + fld_fwd = sort_by_freq_ascending(fld_fwd) + eps_fwd = sort_by_freq_ascending(eps_fwd) + fld_adj = sort_by_freq_ascending(fld_adj) + eps_adj = sort_by_freq_ascending(eps_adj) + + freqs_adj = np.array(fld_adj.monitor.freqs) + # post normalize the adjoint fields if a single, broadband source fwd_flds_adj_normed = {} for key, val in fld_adj.field_components.items(): @@ -147,6 +182,18 @@ def postprocess_adj( H_info_exists = H_der_map is not None + def filter_adj_freq( + dataset: Union[td.PermittityData, td.FieldData], filter_freqs: np.ndarray + ) -> Union[td.PermittityData, td.FieldData]: + dataset_filter_freq = {} + for key, val in dataset.field_components.items(): + dataset_filter_freq[key] = val.sel(f=filter_freqs) + + return dataset.updated_copy(**dataset_filter_freq) + + fld_fwd = filter_adj_freq(fld_fwd, freqs_adj) + eps_fwd = filter_adj_freq(eps_fwd, freqs_adj) + D_fwd = E_to_D(fld_fwd, eps_fwd) D_adj = E_to_D(fld_adj, eps_fwd) @@ -266,43 +313,37 @@ def postprocess_adj( select_adjoint_freqs = adjoint_frequencies[freq_slice] # slice field data for current chunk - E_der_map_chunk = _slice_field_data(E_der_map.field_components, select_adjoint_freqs) - D_der_map_chunk = _slice_field_data(D_der_map.field_components, select_adjoint_freqs) + E_der_map_chunk = _slice_field_data(E_der_map.field_components, freq_slice) + D_der_map_chunk = _slice_field_data(D_der_map.field_components, freq_slice) E_fwd_chunk = _slice_field_data( - fld_fwd.field_components, select_adjoint_freqs, component_indicator="E" + fld_fwd.field_components, freq_slice, component_indicator="E" ) E_adj_chunk = _slice_field_data( - fld_adj.field_components, select_adjoint_freqs, component_indicator="E" + fld_adj.field_components, freq_slice, component_indicator="E" ) - D_fwd_chunk = _slice_field_data(D_fwd.field_components, select_adjoint_freqs) - D_adj_chunk = _slice_field_data(D_adj.field_components, select_adjoint_freqs) - eps_data_chunk = _slice_field_data(eps_fwd.field_components, select_adjoint_freqs) + D_fwd_chunk = _slice_field_data(D_fwd.field_components, freq_slice) + D_adj_chunk = _slice_field_data(D_adj.field_components, freq_slice) + eps_data_chunk = _slice_field_data(eps_fwd.field_components, freq_slice) H_der_map_chunk = None H_fwd_chunk = None H_adj_chunk = None if H_info_exists: - H_der_map_chunk = _slice_field_data( - H_der_map.field_components, select_adjoint_freqs - ) + H_der_map_chunk = _slice_field_data(H_der_map.field_components, freq_slice) H_fwd_chunk = _slice_field_data( - fld_fwd.field_components, select_adjoint_freqs, component_indicator="H" + fld_fwd.field_components, freq_slice, component_indicator="H" ) H_adj_chunk = _slice_field_data( - fld_adj.field_components, select_adjoint_freqs, component_indicator="H" + fld_adj.field_components, freq_slice, component_indicator="H" ) # slice epsilon arrays eps_no_structure_chunk = ( - eps_no_structure.sel(f=select_adjoint_freqs) - if eps_no_structure is not None - else None + eps_no_structure.isel(f=freq_slice) if eps_no_structure is not None else None ) eps_inf_structure_chunk = ( - eps_inf_structure.sel(f=select_adjoint_freqs) - if eps_inf_structure is not None - else None + eps_inf_structure.isel(f=freq_slice) if eps_inf_structure is not None else None ) # create derivative info with sliced data