Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions tests/test_package/test_lazy_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests that heavy optional packages are not imported at top-level tidy3d import.

These tests ensure that packages like matplotlib, scipy, trimesh, and vtk
are only imported when actually needed (lazy imports), not at module load time.
This keeps the initial import time low for users who don't need these features.
"""

from __future__ import annotations

import subprocess
import sys

import pytest

# List of packages that should NOT be imported on top-level tidy3d import
LAZY_PACKAGES = [
"matplotlib",
"scipy",
"trimesh",
"vtk",
"networkx", # transitive dependency of trimesh
]


@pytest.mark.parametrize("package", LAZY_PACKAGES)
def test_package_not_imported_on_tidy3d_import(package: str) -> None:
"""Test that a package is not imported when importing tidy3d.

We run this in a subprocess to ensure a clean Python environment
without any prior imports that might pollute sys.modules.
"""
# Create a script that imports tidy3d and checks if the package is in sys.modules
script = f"""
import sys
# Import tidy3d
import tidy3d
# Check if {package} was imported
if "{package}" in sys.modules:
print(f"FAIL: {package} was imported")
sys.exit(1)
else:
print(f"OK: {package} was not imported")
sys.exit(0)
"""
result = subprocess.run(
[sys.executable, "-c", script],
capture_output=True,
text=True,
)

# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr, file=sys.stderr)

assert result.returncode == 0, (
f"Package '{package}' was imported on top-level tidy3d import. "
f"This increases import time. Consider making the import lazy."
)


def test_all_lazy_packages_not_imported() -> None:
"""Test that none of the lazy packages are imported when importing tidy3d.

This is a combined test that checks all packages at once, which is faster
than running separate subprocesses for each package.
"""
packages_str = ", ".join(f'"{p}"' for p in LAZY_PACKAGES)
script = f"""
import sys
# Import tidy3d
import tidy3d
# Check which packages were imported
lazy_packages = [{packages_str}]
imported = [p for p in lazy_packages if p in sys.modules]
if imported:
print(f"FAIL: These packages were imported: {{imported}}")
sys.exit(1)
else:
print(f"OK: None of the lazy packages were imported")
sys.exit(0)
"""
result = subprocess.run(
[sys.executable, "-c", script],
capture_output=True,
text=True,
)

# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr, file=sys.stderr)

assert result.returncode == 0, (
f"Some lazy packages were imported on top-level tidy3d import: {result.stdout}"
)


def test_matplotlib_imported_on_plot() -> None:
"""Test that matplotlib IS imported when a plot function is called."""
script = """
import sys
import tidy3d as td

# matplotlib should not be imported yet
assert "matplotlib" not in sys.modules, "matplotlib imported too early"

# Create a simple simulation and try to plot it
sim = td.Simulation(
size=(1, 1, 1),
grid_spec=td.GridSpec.auto(wavelength=1.0),
run_time=1e-12,
)

# This should trigger matplotlib import
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for testing
ax = sim.plot(z=0)

# Now matplotlib should be imported
assert "matplotlib" in sys.modules, "matplotlib should be imported after plotting"
print("OK: matplotlib imported only when needed")
"""
result = subprocess.run(
[sys.executable, "-c", script],
capture_output=True,
text=True,
)

# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr, file=sys.stderr)

assert result.returncode == 0, f"Test failed: {result.stderr}"
14 changes: 7 additions & 7 deletions tidy3d/components/data/unstructured/triangular.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@

from __future__ import annotations

from typing import Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import numpy as np
import pydantic.v1 as pd
from xarray import DataArray as XrDataArray

try:
from matplotlib import pyplot as plt
if TYPE_CHECKING:
from matplotlib.tri import Triangulation
except ImportError:
pass

from xarray import DataArray as XrDataArray

from tidy3d.components.base import cached_property
from tidy3d.components.data.data_array import (
Expand Down Expand Up @@ -573,6 +569,8 @@ def does_cover(self, bounds: Bound) -> bool:
@property
def _triangulation_obj(self) -> Triangulation:
"""Matplotlib triangular representation of the grid to use in plotting."""
from matplotlib.tri import Triangulation

return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)

@equal_aspect
Expand Down Expand Up @@ -649,6 +647,8 @@ def plot(
)

if cbar:
from matplotlib import pyplot as plt

label_kwargs = {}
if "label" not in cbar_kwargs:
label_kwargs["label"] = self.values.name
Expand Down
10 changes: 6 additions & 4 deletions tidy3d/components/eme/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@

from typing import Any, Literal, Optional, Union

try:
import matplotlib as mpl
except ImportError:
pass
import numpy as np
import pydantic.v1 as pd

Expand Down Expand Up @@ -325,6 +321,8 @@ def plot_eme_ports(
**kwargs: Any,
) -> Ax:
"""Plot the EME ports."""
import matplotlib as mpl

kwargs.setdefault("linewidth", 0.4)
kwargs.setdefault("colors", "black")
rmin = self.geometry.bounds[0][self.axis]
Expand Down Expand Up @@ -372,6 +370,8 @@ def plot_eme_subgrid_boundaries(
Does nothing if ``eme_grid_spec`` is not :class:`.EMECompositeGrid`.
Operates recursively on subgrids.
"""
import matplotlib as mpl

if not isinstance(eme_grid_spec, EMECompositeGrid):
return ax
kwargs.setdefault("linewidth", 0.4)
Expand Down Expand Up @@ -421,6 +421,8 @@ def plot_eme_grid(
**kwargs: Any,
) -> Ax:
"""Plot the EME grid."""
import matplotlib as mpl

kwargs.setdefault("linewidth", 0.2)
kwargs.setdefault("colors", "black")
cell_boundaries = self.eme_grid.boundaries
Expand Down
10 changes: 4 additions & 6 deletions tidy3d/components/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
from numpy._typing import ArrayLike, NDArray
from typing_extensions import Self

try:
from matplotlib import patches
except ImportError:
pass

from tidy3d.compat import _shapely_is_older_than
from tidy3d.components.autograd import (
AutogradFieldMap,
Expand Down Expand Up @@ -2399,6 +2394,7 @@ def _plot_arrow(
matplotlib.axes._subplots.Axes
The matplotlib axes with the arrow added.
"""
from matplotlib import patches

plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)
_, (dx, dy) = self.pop_axis(direction, axis=plot_axis)
Expand Down Expand Up @@ -2433,7 +2429,7 @@ def _plot_arrow(
arrow = patches.FancyArrowPatch(
(x0, y0),
(x0 + v_x, y0 + v_y),
arrowstyle=arrow_style,
arrowstyle=arrow_style(),
color=color,
alpha=alpha,
zorder=np.inf,
Expand Down Expand Up @@ -2461,6 +2457,8 @@ def _arrow_shape_cb(
sign: float,
bend_radius: float | None,
) -> Callable[[Event], None]:
from matplotlib import patches

def _cb(event: Event) -> None:
# We only want to set the shape once, so we disconnect ourselves
event.canvas.mpl_disconnect(arrow.set_shape_cb[0])
Expand Down
45 changes: 19 additions & 26 deletions tidy3d/components/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,13 @@

from __future__ import annotations

from typing import Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import autograd.numpy as np
import pydantic.v1 as pd

try:
if TYPE_CHECKING:
import matplotlib as mpl
import matplotlib.pylab as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
except ImportError:
mpl = None
import pydantic.v1 as pd

from tidy3d.components.material.tcad.charge import (
ChargeConductorMedium,
Expand Down Expand Up @@ -643,9 +639,13 @@ def _add_cbar(
label: str,
cmap: str,
ax: Ax = None,
norm: mpl.colors.Normalize = None,
norm: mpl.colors.Normalize | None = None,
) -> None:
"""Add a colorbar to plot."""
import matplotlib as mpl
import matplotlib.pylab as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.15)
if norm is None:
Expand Down Expand Up @@ -1018,6 +1018,7 @@ def plot_structures_property(
matplotlib.axes._subplots.Axes
The supplied or created matplotlib axes.
"""
import matplotlib as mpl

structures = self.sorted_structures

Expand Down Expand Up @@ -1197,7 +1198,7 @@ def _add_cbar_eps(
eps_max: float,
ax: Ax = None,
reverse: bool = False,
norm: Optional[mpl.colors.Normalize] = None,
norm: mpl.colors.Normalize | None = None,
) -> None:
"""Add a permittivity colorbar to plot."""
Scene._add_cbar(
Expand Down Expand Up @@ -1268,7 +1269,7 @@ def _pcolormesh_shape_custom_medium_structure_eps(
ax: Ax,
grid: Grid,
eps_component: Optional[PermittivityComponent] = None,
norm: mpl.colors.Normalize = None,
norm: mpl.colors.Normalize | None = None,
) -> None:
"""
Plot shape made of custom medium with ``pcolormesh``.
Expand Down Expand Up @@ -1420,9 +1421,10 @@ def _get_structure_eps_plot_params(
reverse: bool = False,
alpha: Optional[float] = None,
eps_component: Optional[PermittivityComponent] = None,
norm: Optional[mpl.colors.Normalize] = None,
norm: mpl.colors.Normalize | None = None,
) -> PlotParams:
"""Constructs the plot parameters for a given medium in scene.plot_eps()."""
import matplotlib as mpl

plot_params = plot_params_structure.copy(update={"linewidth": 0})
if isinstance(medium, AbstractMedium):
Expand All @@ -1448,22 +1450,13 @@ def _get_structure_eps_plot_params(
eps_medium = medium._eps_plot(frequency=freq, eps_component=eps_component)
if norm is not None:
color_value = float(norm(eps_medium))
elif mpl is not None:
else:
active_norm = mpl.colors.Normalize(vmin=eps_min, vmax=eps_max)
color_value = float(active_norm(eps_medium))
else:
if eps_max == eps_min:
color_value = 0.5
else:
color_value = (eps_medium - eps_min) / (eps_max - eps_min)
color_value = min(1.0, max(0.0, color_value))
if mpl is not None:
cmap_name = _get_colormap(reverse=reverse)
cmap = mpl.cm.get_cmap(cmap_name)
rgba = tuple(float(component) for component in cmap(color_value))
else:
gray_value = color_value if reverse else 1.0 - color_value
rgba = (gray_value, gray_value, gray_value, 1.0)
cmap_name = _get_colormap(reverse=reverse)
cmap = mpl.cm.get_cmap(cmap_name)
rgba = tuple(float(component) for component in cmap(color_value))
plot_params = plot_params.copy(update={"facecolor": rgba})

return plot_params
Expand All @@ -1479,7 +1472,7 @@ def _plot_shape_structure_eps(
reverse: bool = False,
alpha: Optional[float] = None,
eps_component: Optional[PermittivityComponent] = None,
norm: Optional[mpl.colors.Normalize] = None,
norm: mpl.colors.Normalize | None = None,
) -> Ax:
"""Plot a structure's cross section shape for a given medium, grayscale for permittivity."""
plot_params = self._get_structure_eps_plot_params(
Expand Down Expand Up @@ -2114,7 +2107,7 @@ def _pcolormesh_shape_doping_box(
shape: Shapely,
ax: Ax,
plt_type: str = "doping",
norm: mpl.colors.Normalize = None,
norm: mpl.colors.Normalize | None = None,
) -> None:
"""
Plot shape made of structure defined with doping.
Expand Down
Loading
Loading