Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions tests/test_package/test_lazy_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""Tests that heavy optional packages are not imported at top-level tidy3d import.
These tests ensure that packages like matplotlib, scipy, trimesh, and vtk
are only imported when actually needed (lazy imports), not at module load time.
This keeps the initial import time low for users who don't need these features.
"""

from __future__ import annotations

import subprocess
import sys

import pytest

# List of packages that should NOT be imported on top-level tidy3d import
LAZY_PACKAGES = [
"matplotlib",
"scipy",
"trimesh",
"vtk",
"networkx", # transitive dependency of trimesh
]


@pytest.mark.parametrize("package", LAZY_PACKAGES)
def test_package_not_imported_on_tidy3d_import(package: str) -> None:
"""Test that a package is not imported when importing tidy3d.
We run this in a subprocess to ensure a clean Python environment
without any prior imports that might pollute sys.modules.
"""
# Create a script that imports tidy3d and checks if the package is in sys.modules
script = f"""
import sys
# Import tidy3d
import tidy3d
# Check if {package} was imported
if "{package}" in sys.modules:
print(f"FAIL: {package} was imported")
sys.exit(1)
else:
print(f"OK: {package} was not imported")
sys.exit(0)
"""
result = subprocess.run(
[sys.executable, "-c", script],
capture_output=True,
text=True,
)

# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr, file=sys.stderr)

assert result.returncode == 0, (
f"Package '{package}' was imported on top-level tidy3d import. "
f"This increases import time. Consider making the import lazy."
)


def test_all_lazy_packages_not_imported() -> None:
"""Test that none of the lazy packages are imported when importing tidy3d.
This is a combined test that checks all packages at once, which is faster
than running separate subprocesses for each package.
"""
packages_str = ", ".join(f'"{p}"' for p in LAZY_PACKAGES)
script = f"""
import sys
# Import tidy3d
import tidy3d
# Check which packages were imported
lazy_packages = [{packages_str}]
imported = [p for p in lazy_packages if p in sys.modules]
if imported:
print(f"FAIL: These packages were imported: {{imported}}")
sys.exit(1)
else:
print(f"OK: None of the lazy packages were imported")
sys.exit(0)
"""
result = subprocess.run(
[sys.executable, "-c", script],
capture_output=True,
text=True,
)

# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr, file=sys.stderr)

assert result.returncode == 0, (
f"Some lazy packages were imported on top-level tidy3d import: {result.stdout}"
)


def test_matplotlib_imported_on_plot() -> None:
"""Test that matplotlib IS imported when a plot function is called."""
script = """
import sys
import tidy3d as td
# matplotlib should not be imported yet
assert "matplotlib" not in sys.modules, "matplotlib imported too early"
# Create a simple simulation and try to plot it
sim = td.Simulation(
size=(1, 1, 1),
grid_spec=td.GridSpec.auto(wavelength=1.0),
run_time=1e-12,
)
# This should trigger matplotlib import
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for testing
ax = sim.plot(z=0)
# Now matplotlib should be imported
assert "matplotlib" in sys.modules, "matplotlib should be imported after plotting"
print("OK: matplotlib imported only when needed")
"""
result = subprocess.run(
[sys.executable, "-c", script],
capture_output=True,
text=True,
)

# Print output for debugging
if result.stdout:
print(result.stdout)
if result.stderr:
print(result.stderr, file=sys.stderr)

assert result.returncode == 0, f"Test failed: {result.stderr}"
14 changes: 7 additions & 7 deletions tidy3d/components/data/unstructured/triangular.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@

from __future__ import annotations

from typing import Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import numpy as np
import pydantic.v1 as pd
from xarray import DataArray as XrDataArray

try:
from matplotlib import pyplot as plt
if TYPE_CHECKING:
from matplotlib.tri import Triangulation
except ImportError:
pass

from xarray import DataArray as XrDataArray

from tidy3d.components.base import cached_property
from tidy3d.components.data.data_array import (
Expand Down Expand Up @@ -573,6 +569,8 @@ def does_cover(self, bounds: Bound) -> bool:
@property
def _triangulation_obj(self) -> Triangulation:
"""Matplotlib triangular representation of the grid to use in plotting."""
from matplotlib.tri import Triangulation

return Triangulation(self.points[:, 0], self.points[:, 1], self.cells)

@equal_aspect
Expand Down Expand Up @@ -649,6 +647,8 @@ def plot(
)

if cbar:
from matplotlib import pyplot as plt

label_kwargs = {}
if "label" not in cbar_kwargs:
label_kwargs["label"] = self.values.name
Expand Down
10 changes: 6 additions & 4 deletions tidy3d/components/eme/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@

from typing import Any, Literal, Optional, Union

try:
import matplotlib as mpl
except ImportError:
pass
import numpy as np
import pydantic.v1 as pd

Expand Down Expand Up @@ -325,6 +321,8 @@ def plot_eme_ports(
**kwargs: Any,
) -> Ax:
"""Plot the EME ports."""
import matplotlib as mpl

kwargs.setdefault("linewidth", 0.4)
kwargs.setdefault("colors", "black")
rmin = self.geometry.bounds[0][self.axis]
Expand Down Expand Up @@ -372,6 +370,8 @@ def plot_eme_subgrid_boundaries(
Does nothing if ``eme_grid_spec`` is not :class:`.EMECompositeGrid`.
Operates recursively on subgrids.
"""
import matplotlib as mpl

if not isinstance(eme_grid_spec, EMECompositeGrid):
return ax
kwargs.setdefault("linewidth", 0.4)
Expand Down Expand Up @@ -421,6 +421,8 @@ def plot_eme_grid(
**kwargs: Any,
) -> Ax:
"""Plot the EME grid."""
import matplotlib as mpl

kwargs.setdefault("linewidth", 0.2)
kwargs.setdefault("colors", "black")
cell_boundaries = self.eme_grid.boundaries
Expand Down
10 changes: 4 additions & 6 deletions tidy3d/components/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
from numpy._typing import ArrayLike, NDArray
from typing_extensions import Self

try:
from matplotlib import patches
except ImportError:
pass

from tidy3d.compat import _shapely_is_older_than
from tidy3d.components.autograd import (
AutogradFieldMap,
Expand Down Expand Up @@ -2399,6 +2394,7 @@ def _plot_arrow(
matplotlib.axes._subplots.Axes
The matplotlib axes with the arrow added.
"""
from matplotlib import patches

plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)
_, (dx, dy) = self.pop_axis(direction, axis=plot_axis)
Expand Down Expand Up @@ -2433,7 +2429,7 @@ def _plot_arrow(
arrow = patches.FancyArrowPatch(
(x0, y0),
(x0 + v_x, y0 + v_y),
arrowstyle=arrow_style,
arrowstyle=arrow_style(),
color=color,
alpha=alpha,
zorder=np.inf,
Expand Down Expand Up @@ -2461,6 +2457,8 @@ def _arrow_shape_cb(
sign: float,
bend_radius: float | None,
) -> Callable[[Event], None]:
from matplotlib import patches

def _cb(event: Event) -> None:
# We only want to set the shape once, so we disconnect ourselves
event.canvas.mpl_disconnect(arrow.set_shape_cb[0])
Expand Down
Loading
Loading