diff --git a/.github/workflows/tidy3d-python-client-tests.yml b/.github/workflows/tidy3d-python-client-tests.yml index 7c3ab7c71f..824f6f24c7 100644 --- a/.github/workflows/tidy3d-python-client-tests.yml +++ b/.github/workflows/tidy3d-python-client-tests.yml @@ -101,6 +101,9 @@ on: permissions: contents: read +env: + VERIFICATIONS_PY_VERSION: '3.11' + jobs: determine-test-scope: runs-on: ubuntu-latest @@ -223,6 +226,39 @@ jobs: echo "version_match_tests=$version_match_tests" echo "extras_integration_tests=$extras_integration_tests" echo "test_type=$test_type" + + + move-type-imports: + name: move-type-imports + needs: determine-test-scope + if: needs.determine-test-scope.outputs.code_quality_tests == 'true' + runs-on: ubuntu-latest + steps: + - name: checkout-branch + uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} + fetch-depth: 0 + persist-credentials: false + + - name: setup-python + uses: actions/setup-python@v5 + with: + python-version: ${{ env.VERIFICATIONS_PY_VERSION }} + + - name: install-dependencies + run: | + set -euo pipefail + python -m venv .venv + source .venv/bin/activate + pip install libcst + + - name: verify-type-import-guards + run: | + set -euo pipefail + source .venv/bin/activate + python scripts/move_type_imports.py --mode check_on_change lint: needs: determine-test-scope @@ -256,10 +292,10 @@ jobs: submodules: false persist-credentials: false - - name: set-python-3.10 + - name: set-python uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: ${{ env.VERIFICATIONS_PY_VERSION }} - name: Install mypy run: | @@ -270,6 +306,27 @@ jobs: run: | mypy --config-file=pyproject.toml + ensure-common-imports: + name: ensure-common-imports + needs: determine-test-scope + if: needs.determine-test-scope.outputs.code_quality_tests == 'true' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 1 + submodules: false + persist-credentials: false + + - name: set-python-3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Run tidy3d._common import check + run: | + python scripts/ensure_imports_from_common.py + zizmor: name: Run zizmor 🌈 runs-on: ubuntu-latest @@ -287,7 +344,7 @@ jobs: uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0 - name: Run zizmor 🌈 - run: uvx zizmor .github/workflows/*.y* --format=sarif . > results.sarif + run: uvx zizmor==1.19.0 .github/workflows/*.y* --format=sarif . 
> results.sarif env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -298,7 +355,7 @@ jobs: category: zizmor - name: run zizmor directly # this gets a success or fail result - run: uvx zizmor .github/workflows/*.y* + run: uvx zizmor==1.19.0 .github/workflows/*.y* env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -370,6 +427,8 @@ jobs: runs-on: ubuntu-latest if: needs.determine-test-scope.outputs.code_quality_tests == 'true' name: lint-commit-messages + # Soft-fail on PRs (early feedback), hard-fail in merge queue (enforced) + continue-on-error: ${{ github.event_name == 'pull_request' }} steps: - name: Check out source code uses: actions/checkout@v4 @@ -391,7 +450,18 @@ jobs: node --version npm --version npx commitlint --version - + + - name: Check commit messages (pull_request) + if: github.event_name == 'pull_request' + env: + BASE_SHA: ${{ github.event.pull_request.base.sha }} + HEAD_SHA: ${{ github.event.pull_request.head.sha }} + run: | + npx commitlint --from $BASE_SHA --to $HEAD_SHA --verbose || { + echo "::warning::Commit messages don't follow conventional commits format. Fix before merge queue." + exit 1 + } + - name: Check commit messages (merge_group) if: github.event_name == 'merge_group' env: @@ -429,7 +499,7 @@ jobs: - name: install-depedencies run: | - uv venv $GITHUB_WORKSPACE/.venv -p 3.11 + uv venv $GITHUB_WORKSPACE/.venv -p "$VERIFICATIONS_PY_VERSION" source $GITHUB_WORKSPACE/.venv/bin/activate uv pip install -e "$GITHUB_WORKSPACE" @@ -981,8 +1051,10 @@ jobs: - determine-test-scope - local-tests - remote-tests + - move-type-imports - lint - mypy + - ensure-common-imports - verify-schema-change - lint-commit-messages - lint-branch-name @@ -993,6 +1065,12 @@ jobs: - extras-integration-tests runs-on: ubuntu-latest steps: + - name: move-type-imports + if: ${{ needs.determine-test-scope.outputs.code_quality_tests == 'true' && needs.move-type-imports.result != 'success' && needs.move-type-imports.result != 'skipped' }} + run: | + echo "❌ Found imports used only for typing that are not guarded by if TYPE_CHECKING." + exit 1 + - name: check-linting-result if: ${{ needs.determine-test-scope.outputs.code_quality_tests == 'true' && needs.lint.result != 'success' && needs.lint.result != 'skipped' }} run: | @@ -1004,6 +1082,12 @@ jobs: run: | echo "❌ Mypy type checking failed." exit 1 + + - name: check-common-imports-result + if: ${{ needs.determine-test-scope.outputs.code_quality_tests == 'true' && needs.ensure-common-imports.result != 'success' && needs.ensure-common-imports.result != 'skipped' }} + run: | + echo "❌ tidy3d._common import check failed." 
+          exit 1
 
       - name: check-schema-change-verification
         if: ${{ needs.determine-test-scope.outputs.code_quality_tests == 'true' && needs.verify-schema-change.result != 'success' && needs.verify-schema-change.result != 'skipped' }}
diff --git a/.gitignore b/.gitignore
index 1e972cd884..ad1d34319c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
 # Batch JSON files
 batch*.json
 *.vtu
+simulation.json
 
 # Byte-compiled / optimized / DLL files
 *$py.class
@@ -135,5 +136,6 @@ htmlcov/
 .idea
 .vscode
 
-# cProfile output
+# profile outputs
 *.prof
+pytest_profile_stats.txt
diff --git a/.gitmodules b/.gitmodules
index 2ce36b88a2..e69de29bb2 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +0,0 @@
-[submodule "docs/notebooks"]
-	path = docs/notebooks
-	url = git@github.com:flexcompute/tidy3d-notebooks.git
-[submodule "docs/faq"]
-	path = docs/faq
-	url = https://github.com/flexcompute/tidy3d-faq
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fce5cbc0cc..8a66d9438e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,6 +3,14 @@ default_install_hook_types:
   - pre-commit
   - commit-msg
 repos:
+  - repo: local
+    hooks:
+      - id: move-type-imports
+        name: move type-only imports under TYPE_CHECKING
+        entry: poetry run python scripts/move_type_imports.py --mode fix --only-changed
+        language: system
+        pass_filenames: false
+        stages: [pre-commit]
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: "v0.11.11"
     hooks:
@@ -20,7 +28,7 @@ repos:
         entry: bash -c 'commitlint --edit || exit 0'
   - repo: https://github.com/zizmorcore/zizmor-pre-commit
     # Zizmor version.
-    rev: v1.15.2
+    rev: v1.19.0
     hooks:
       - id: zizmor
         stages: [pre-commit]
diff --git a/AGENTS.md b/AGENTS.md
index 0eaf29d80c..43498ad651 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -45,5 +45,6 @@
 - Follow Conventional Commits per `.commitlintrc.json`.
 - Branch names must use an allowed prefix (`chore`, `hotfix`, `daily-chore`) or include a Jira key to satisfy CI.
 - PRs should link issues, summarize behavior changes, list the `poetry run …` checks you executed, and call out docs/schema updates.
+- Add a changelog entry under `## [Unreleased]` in `CHANGELOG.md` for user-facing changes (new features, bug fixes, breaking changes).
 
 _Reminder: update this AGENTS.md whenever workflow, tooling, or review expectations change so agents stay in sync with the repo._
diff --git a/CHANGELOG.md b/CHANGELOG.md
index c320443a39..55a3d98c96 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,14 +5,47 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [Unreleased]
+## [2.10.2] - 2026-01-21
 
 ### Added
+- Added `warn_once` option to logging configuration (`td.config.logging.warn_once`) that causes each unique warning message to be shown only once per process, reducing noise from repeated validation warnings.
 
 ### Changed
+- Unified inside/outside permittivity handling for all geometries when computing shape gradients.
+- Enabled PEC gradients for dielectric structures embedded in PEC through the `background_medium` field in `Structure`, which covers shapes with combinations of dielectric-dielectric and dielectric-PEC boundaries.
+- Improved `ComponentModeler` task monitoring.
+
+### Fixed
+- Fixed interpolation handling for permittivity and conductivity gradients in `CustomPoleResidue`.
+- Fixed docstrings with missing Notes sections causing spurious parameters in Sphinx documentation.
+
+## [2.10.1] - 2026-01-14
+
+### Added
+- Added `priority` attribute to `TopologyDesignRegion` to enable manual control of overlapping structures.
+- `to_mat_file()` method is now available on `ModeSimulationData` and `HeatChargeSimulationData` for exporting results to MATLAB `.mat` files.
+- Added autograd support for diagonal `AnisotropicMedium` and `CustomAnisotropicMedium` with diagonal permittivity tensor.
+- Added support for numpy 2.4.
+- Added validation to `DCVoltageSource` that warns when duplicate voltage values are detected in the `voltage` array, including treating `0` and `-0` as the same value.
+
+### Changed
+- For `HeatChargeSimulation` objects, the `plot` function now adds the simulation boundary conditions.
 
 ### Fixed
 - Fixed `AutoImpedanceSpec` validation to check path intersections against all conductors, not just filtered ones, as well as the mode plane bounds.
+- Fixed `WavePort` validation so invalid `mode_spec` errors are no longer masked by a `KeyError`.
+- Fixed adjoint gradients being treated as zero due to scale-dependent `np.allclose(..., atol=1e-8)` checks, which could skip adjoint simulations and return zero gradients.
+- Fixed adjoint setup crashing when traced monitor outputs produce no adjoint sources; no adjoint simulations are returned instead.
+- Fixed autograd `interpn` compatibility with newer SciPy versions by avoiding dtype coercion during interpolation setup.
+- Fixed interpolation handling for permittivity and conductivity gradients in `CustomMedium`.
+- Restored original batch-load logging by suppressing per-task “Loading simulation…” messages.
+- Fixed output range of `tidy3d.plugins.invdes.FilterAndProject` to be between 0 and 1.
+- Cropped adjoint monitor sizes in 2D simulations to the planar geometry intersection.
+- Fixed `Batch.download()` silently succeeding when background downloads fail (e.g., gzip extraction errors).
+- Fixed handling of zero values when using `sim_data.plot_field` with `scale=dB`.
+- Fixed `intersections_plane` method in `PolySlab`, which sometimes missed vertices for planes coincident with `PolySlab` side faces.
+- Fixed `http_interceptor` crash on non-dict JSON responses.
+- Fixed gradient regression for `Box` geometries where outside permittivity was not being correctly sampled.
 
 ## [2.10.0] - 2025-12-18
 
@@ -1907,7 +1940,9 @@ which fields are to be projected is now determined automatically based on the me
 
 - Job and Batch classes for better simulation handling (eventually to fully replace webapi functions).
 - A large number of small improvements and bug fixes.
 
-[Unreleased]: https://github.com/flexcompute/tidy3d/compare/v2.10.0...develop +[Unreleased]: https://github.com/flexcompute/tidy3d/compare/v2.10.2...develop +[2.10.2]: https://github.com/flexcompute/tidy3d/compare/v2.10.1...v2.10.2 +[2.10.1]: https://github.com/flexcompute/tidy3d/compare/v2.10.0...v2.10.1 [2.10.0]: https://github.com/flexcompute/tidy3d/compare/v2.9.3...v2.10.0 [2.9.3]: https://github.com/flexcompute/tidy3d/compare/v2.9.2...v2.9.3 [2.9.2]: https://github.com/flexcompute/tidy3d/compare/v2.9.1...v2.9.2 diff --git a/docs/api/index.rst b/docs/api/index.rst index 3ea92ea88e..b3d8cc0e00 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -12,7 +12,7 @@ API |:computer:| mediums material_library boundary_conditions - discretization + discretization/index sources monitors output_data diff --git a/docs/api/plugins/index.rst b/docs/api/plugins/index.rst index d99b2994a6..5134e1aa0e 100644 --- a/docs/api/plugins/index.rst +++ b/docs/api/plugins/index.rst @@ -1,6 +1,11 @@ Plugins ======= +.. warning:: + + |:warning:| The 'adjoint' plugin (legacy JAX-based adjoint plugin) + was deprecated in Tidy3D 'v2.7.0' and is disabled as of 'v2.9.0'. + .. toctree:: :maxdepth: 1 diff --git a/docs/development/usage.rst b/docs/development/usage.rst index 37a706fc98..7b380ba5c4 100644 --- a/docs/development/usage.rst +++ b/docs/development/usage.rst @@ -67,6 +67,12 @@ There are a range of handy development functions that you might want to use to s * - Running ``pytest`` commands inside the ``poetry`` environment. - Make sure you have already installed ``tidy3d`` in ``poetry`` and you are in the root directory. - ``poetry run pytest`` + * - Analyze slow ``pytest`` runs with durations / cProfile / debug subset helpers. + - Use ``--debug`` to run only the first N collected tests or ``--profile`` to capture call stacks. + - ``python scripts/profile_pytest.py [options]`` + * - Track ``pytest`` RAM usage over time and per test. + - Defaults to pytest's configured parallelism; use ``--single-process`` to disable xdist. + - ``poetry run python scripts/pytest_ram_profile.py [options]`` * - Run ``coverage`` testing from the ``poetry`` environment. - - ``poetry run coverage run -m pytest`` @@ -83,5 +89,3 @@ There are a range of handy development functions that you might want to use to s - - ``poetry run tidy3d develop replace-in-files`` - - diff --git a/poetry.lock b/poetry.lock index 5474070f8b..9598484868 100644 --- a/poetry.lock +++ b/poetry.lock @@ -71,15 +71,15 @@ files = [ [[package]] name = "anyio" -version = "4.12.0" +version = "4.12.1" description = "High-level concurrency and networking framework on top of asyncio or Trio" optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "anyio-4.12.0-py3-none-any.whl", hash = "sha256:dad2376a628f98eeca4881fc56cd06affd18f659b17a747d3ff0307ced94b1bb"}, - {file = "anyio-4.12.0.tar.gz", hash = "sha256:73c693b567b0c55130c104d0b43a9baf3aa6a31fc6110116509f27bf75e21ec0"}, + {file = "anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c"}, + {file = "anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703"}, ] [package.dependencies] @@ -182,15 +182,15 @@ test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock [[package]] name = "astroid" -version = "4.0.2" +version = "4.0.3" description = "An abstract syntax tree for Python with inference support." 
optional = true python-versions = ">=3.10.0" groups = ["main"] markers = "extra == \"dev\" or extra == \"tests\" or extra == \"docs\"" files = [ - {file = "astroid-4.0.2-py3-none-any.whl", hash = "sha256:d7546c00a12efc32650b19a2bb66a153883185d3179ab0d4868086f807338b9b"}, - {file = "astroid-4.0.2.tar.gz", hash = "sha256:ac8fb7ca1c08eb9afec91ccc23edbd8ac73bb22cbdd7da1d488d9fb8d6579070"}, + {file = "astroid-4.0.3-py3-none-any.whl", hash = "sha256:864a0a34af1bd70e1049ba1e61cee843a7252c826d97825fcee9b2fcbd9e1b14"}, + {file = "astroid-4.0.3.tar.gz", hash = "sha256:08d1de40d251cc3dc4a7a12726721d475ac189e4e583d596ece7422bc176bda3"}, ] [package.dependencies] @@ -215,15 +215,15 @@ test = ["astroid (>=2,<5)", "pytest (<9.0)", "pytest-cov", "pytest-xdist"] [[package]] name = "async-lru" -version = "2.0.5" +version = "2.1.0" description = "Simple LRU cache for asyncio" optional = true -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "async_lru-2.0.5-py3-none-any.whl", hash = "sha256:ab95404d8d2605310d345932697371a5f40def0487c03d6d0ad9138de52c9943"}, - {file = "async_lru-2.0.5.tar.gz", hash = "sha256:481d52ccdd27275f42c43a928b4a50c3bfb2d67af4e78b170e3e0bb39c66e5bb"}, + {file = "async_lru-2.1.0-py3-none-any.whl", hash = "sha256:fa12dcf99a42ac1280bc16c634bbaf06883809790f6304d85cdab3f666f33a7e"}, + {file = "async_lru-2.1.0.tar.gz", hash = "sha256:9eeb2fecd3fe42cc8a787fc32ead53a3a7158cc43d039c3c55ab3e4e5b2a80ed"}, ] [package.dependencies] @@ -322,47 +322,47 @@ lxml = ["lxml"] [[package]] name = "black" -version = "25.12.0" +version = "26.1.0" description = "The uncompromising code formatter." optional = true python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "black-25.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f85ba1ad15d446756b4ab5f3044731bf68b777f8f9ac9cdabd2425b97cd9c4e8"}, - {file = "black-25.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:546eecfe9a3a6b46f9d69d8a642585a6eaf348bcbbc4d87a19635570e02d9f4a"}, - {file = "black-25.12.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:17dcc893da8d73d8f74a596f64b7c98ef5239c2cd2b053c0f25912c4494bf9ea"}, - {file = "black-25.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:09524b0e6af8ba7a3ffabdfc7a9922fb9adef60fed008c7cd2fc01f3048e6e6f"}, - {file = "black-25.12.0-cp310-cp310-win_arm64.whl", hash = "sha256:b162653ed89eb942758efeb29d5e333ca5bb90e5130216f8369857db5955a7da"}, - {file = "black-25.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d0cfa263e85caea2cff57d8f917f9f51adae8e20b610e2b23de35b5b11ce691a"}, - {file = "black-25.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1a2f578ae20c19c50a382286ba78bfbeafdf788579b053d8e4980afb079ab9be"}, - {file = "black-25.12.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e1b65634b0e471d07ff86ec338819e2ef860689859ef4501ab7ac290431f9b"}, - {file = "black-25.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a3fa71e3b8dd9f7c6ac4d818345237dfb4175ed3bf37cd5a581dbc4c034f1ec5"}, - {file = "black-25.12.0-cp311-cp311-win_arm64.whl", hash = "sha256:51e267458f7e650afed8445dc7edb3187143003d52a1b710c7321aef22aa9655"}, - {file = "black-25.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:31f96b7c98c1ddaeb07dc0f56c652e25bdedaac76d5b68a059d998b57c55594a"}, - {file = "black-25.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:05dd459a19e218078a1f98178c13f861fe6a9a5f88fc969ca4d9b49eb1809783"}, - {file = "black-25.12.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1f68c5eff61f226934be6b5b80296cf6939e5d2f0c2f7d543ea08b204bfaf59"}, - {file = "black-25.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:274f940c147ddab4442d316b27f9e332ca586d39c85ecf59ebdea82cc9ee8892"}, - {file = "black-25.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:169506ba91ef21e2e0591563deda7f00030cb466e747c4b09cb0a9dae5db2f43"}, - {file = "black-25.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a05ddeb656534c3e27a05a29196c962877c83fa5503db89e68857d1161ad08a5"}, - {file = "black-25.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9ec77439ef3e34896995503865a85732c94396edcc739f302c5673a2315e1e7f"}, - {file = "black-25.12.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e509c858adf63aa61d908061b52e580c40eae0dfa72415fa47ac01b12e29baf"}, - {file = "black-25.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:252678f07f5bac4ff0d0e9b261fbb029fa530cfa206d0a636a34ab445ef8ca9d"}, - {file = "black-25.12.0-cp313-cp313-win_arm64.whl", hash = "sha256:bc5b1c09fe3c931ddd20ee548511c64ebf964ada7e6f0763d443947fd1c603ce"}, - {file = "black-25.12.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:0a0953b134f9335c2434864a643c842c44fba562155c738a2a37a4d61f00cad5"}, - {file = "black-25.12.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:2355bbb6c3b76062870942d8cc450d4f8ac71f9c93c40122762c8784df49543f"}, - {file = "black-25.12.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9678bd991cc793e81d19aeeae57966ee02909877cb65838ccffef24c3ebac08f"}, - {file = "black-25.12.0-cp314-cp314-win_amd64.whl", hash = "sha256:97596189949a8aad13ad12fcbb4ae89330039b96ad6742e6f6b45e75ad5cfd83"}, - {file = "black-25.12.0-cp314-cp314-win_arm64.whl", hash = "sha256:778285d9ea197f34704e3791ea9404cd6d07595745907dd2ce3da7a13627b29b"}, - {file = "black-25.12.0-py3-none-any.whl", hash = "sha256:48ceb36c16dbc84062740049eef990bb2ce07598272e673c17d1a7720c71c828"}, - {file = "black-25.12.0.tar.gz", hash = "sha256:8d3dd9cea14bff7ddc0eb243c811cdb1a011ebb4800a5f0335a01a68654796a7"}, + {file = "black-26.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ca699710dece84e3ebf6e92ee15f5b8f72870ef984bf944a57a777a48357c168"}, + {file = "black-26.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5e8e75dabb6eb83d064b0db46392b25cabb6e784ea624219736e8985a6b3675d"}, + {file = "black-26.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb07665d9a907a1a645ee41a0df8a25ffac8ad9c26cdb557b7b88eeeeec934e0"}, + {file = "black-26.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:7ed300200918147c963c87700ccf9966dceaefbbb7277450a8d646fc5646bf24"}, + {file = "black-26.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:c5b7713daea9bf943f79f8c3b46f361cc5229e0e604dcef6a8bb6d1c37d9df89"}, + {file = "black-26.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3cee1487a9e4c640dc7467aaa543d6c0097c391dc8ac74eb313f2fbf9d7a7cb5"}, + {file = "black-26.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d62d14ca31c92adf561ebb2e5f2741bf8dea28aef6deb400d49cca011d186c68"}, + {file = "black-26.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb1dafbbaa3b1ee8b4550a84425aac8874e5f390200f5502cf3aee4a2acb2f14"}, + {file = 
"black-26.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:101540cb2a77c680f4f80e628ae98bd2bd8812fb9d72ade4f8995c5ff019e82c"}, + {file = "black-26.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:6f3977a16e347f1b115662be07daa93137259c711e526402aa444d7a88fdc9d4"}, + {file = "black-26.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6eeca41e70b5f5c84f2f913af857cf2ce17410847e1d54642e658e078da6544f"}, + {file = "black-26.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:dd39eef053e58e60204f2cdf059e2442e2eb08f15989eefe259870f89614c8b6"}, + {file = "black-26.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9459ad0d6cd483eacad4c6566b0f8e42af5e8b583cee917d90ffaa3778420a0a"}, + {file = "black-26.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:a19915ec61f3a8746e8b10adbac4a577c6ba9851fa4a9e9fbfbcf319887a5791"}, + {file = "black-26.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:643d27fb5facc167c0b1b59d0315f2674a6e950341aed0fc05cf307d22bf4954"}, + {file = "black-26.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ba1d768fbfb6930fc93b0ecc32a43d8861ded16f47a40f14afa9bb04ab93d304"}, + {file = "black-26.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2b807c240b64609cb0e80d2200a35b23c7df82259f80bef1b2c96eb422b4aac9"}, + {file = "black-26.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1de0f7d01cc894066a1153b738145b194414cc6eeaad8ef4397ac9abacf40f6b"}, + {file = "black-26.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:91a68ae46bf07868963671e4d05611b179c2313301bd756a89ad4e3b3db2325b"}, + {file = "black-26.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:be5e2fe860b9bd9edbf676d5b60a9282994c03fbbd40fe8f5e75d194f96064ca"}, + {file = "black-26.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9dc8c71656a79ca49b8d3e2ce8103210c9481c57798b48deeb3a8bb02db5f115"}, + {file = "black-26.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:b22b3810451abe359a964cc88121d57f7bce482b53a066de0f1584988ca36e79"}, + {file = "black-26.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:53c62883b3f999f14e5d30b5a79bd437236658ad45b2f853906c7cbe79de00af"}, + {file = "black-26.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:f016baaadc423dc960cdddf9acae679e71ee02c4c341f78f3179d7e4819c095f"}, + {file = "black-26.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:66912475200b67ef5a0ab665011964bf924745103f51977a78b4fb92a9fc1bf0"}, + {file = "black-26.1.0-py3-none-any.whl", hash = "sha256:1054e8e47ebd686e078c0bb0eaf31e6ce69c966058d122f2c0c950311f9f3ede"}, + {file = "black-26.1.0.tar.gz", hash = "sha256:d294ac3340eef9c9eb5d29288e96dc719ff269a88e27b396340459dd85da4c58"}, ] [package.dependencies] click = ">=8.0.0" mypy-extensions = ">=0.4.3" packaging = ">=22.0" -pathspec = ">=0.9.0" +pathspec = ">=1.0.0" platformdirs = ">=2" pytokens = ">=0.3.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} @@ -396,18 +396,18 @@ css = ["tinycss2 (>=1.1.0,<1.5)"] [[package]] name = "boto3" -version = "1.42.14" +version = "1.42.30" description = "The AWS SDK for Python" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "boto3-1.42.14-py3-none-any.whl", hash = "sha256:bfcc665227bb4432a235cb4adb47719438d6472e5ccbf7f09512046c3f749670"}, - {file = "boto3-1.42.14.tar.gz", hash = "sha256:a5d005667b480c844ed3f814a59f199ce249d0f5669532a17d06200c0a93119c"}, + {file = "boto3-1.42.30-py3-none-any.whl", hash = 
"sha256:d7e548bea65e0ae2c465c77de937bc686b591aee6a352d5a19a16bc751e591c1"}, + {file = "boto3-1.42.30.tar.gz", hash = "sha256:ba9cd2f7819637d15bfbeb63af4c567fcc8a7dcd7b93dd12734ec58601169538"}, ] [package.dependencies] -botocore = ">=1.42.14,<1.43.0" +botocore = ">=1.42.30,<1.43.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.16.0,<0.17.0" @@ -416,14 +416,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.42.14" +version = "1.42.30" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "botocore-1.42.14-py3-none-any.whl", hash = "sha256:efe89adfafa00101390ec2c371d453b3359d5f9690261bc3bd70131e0d453e8e"}, - {file = "botocore-1.42.14.tar.gz", hash = "sha256:cf5bebb580803c6cfd9886902ca24834b42ecaa808da14fb8cd35ad523c9f621"}, + {file = "botocore-1.42.30-py3-none-any.whl", hash = "sha256:97070a438cac92430bb7b65f8ebd7075224f4a289719da4ee293d22d1e98db02"}, + {file = "botocore-1.42.30.tar.gz", hash = "sha256:9bf1662b8273d5cc3828a49f71ca85abf4e021011c1f0a71f41a2ea5769a5116"}, ] [package.dependencies] @@ -449,14 +449,14 @@ files = [ [[package]] name = "certifi" -version = "2025.11.12" +version = "2026.1.4" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b"}, - {file = "certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316"}, + {file = "certifi-2026.1.4-py3-none-any.whl", hash = "sha256:9943707519e4add1115f44c2bc244f782c0249876bf51b6599fee1ffbedd685c"}, + {file = "certifi-2026.1.4.tar.gz", hash = "sha256:ac726dd470482006e014ad384921ed6438c457018f4b3d204aea4281258b2120"}, ] [[package]] @@ -1001,105 +1001,105 @@ test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist" [[package]] name = "coverage" -version = "7.13.0" +version = "7.13.1" description = "Code coverage measurement for Python" optional = false python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"tests\"" files = [ - {file = "coverage-7.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:02d9fb9eccd48f6843c98a37bd6817462f130b86da8660461e8f5e54d4c06070"}, - {file = "coverage-7.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:367449cf07d33dc216c083f2036bb7d976c6e4903ab31be400ad74ad9f85ce98"}, - {file = "coverage-7.13.0-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cdb3c9f8fef0a954c632f64328a3935988d33a6604ce4bf67ec3e39670f12ae5"}, - {file = "coverage-7.13.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d10fd186aac2316f9bbb46ef91977f9d394ded67050ad6d84d94ed6ea2e8e54e"}, - {file = "coverage-7.13.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f88ae3e69df2ab62fb0bc5219a597cb890ba5c438190ffa87490b315190bb33"}, - {file = "coverage-7.13.0-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c4be718e51e86f553bcf515305a158a1cd180d23b72f07ae76d6017c3cc5d791"}, - {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a00d3a393207ae12f7c49bb1c113190883b500f48979abb118d8b72b8c95c032"}, - {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_i686.whl", hash = 
"sha256:3a7b1cd820e1b6116f92c6128f1188e7afe421c7e1b35fa9836b11444e53ebd9"}, - {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:37eee4e552a65866f15dedd917d5e5f3d59805994260720821e2c1b51ac3248f"}, - {file = "coverage-7.13.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:62d7c4f13102148c78d7353c6052af6d899a7f6df66a32bddcc0c0eb7c5326f8"}, - {file = "coverage-7.13.0-cp310-cp310-win32.whl", hash = "sha256:24e4e56304fdb56f96f80eabf840eab043b3afea9348b88be680ec5986780a0f"}, - {file = "coverage-7.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:74c136e4093627cf04b26a35dab8cbfc9b37c647f0502fc313376e11726ba303"}, - {file = "coverage-7.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0dfa3855031070058add1a59fdfda0192fd3e8f97e7c81de0596c145dea51820"}, - {file = "coverage-7.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fdb6f54f38e334db97f72fa0c701e66d8479af0bc3f9bfb5b90f1c30f54500f"}, - {file = "coverage-7.13.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7e442c013447d1d8d195be62852270b78b6e255b79b8675bad8479641e21fd96"}, - {file = "coverage-7.13.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ed5630d946859de835a85e9a43b721123a8a44ec26e2830b296d478c7fd4259"}, - {file = "coverage-7.13.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7f15a931a668e58087bc39d05d2b4bf4b14ff2875b49c994bbdb1c2217a8daeb"}, - {file = "coverage-7.13.0-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:30a3a201a127ea57f7e14ba43c93c9c4be8b7d17a26e03bb49e6966d019eede9"}, - {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7a485ff48fbd231efa32d58f479befce52dcb6bfb2a88bb7bf9a0b89b1bc8030"}, - {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:22486cdafba4f9e471c816a2a5745337742a617fef68e890d8baf9f3036d7833"}, - {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:263c3dbccc78e2e331e59e90115941b5f53e85cfcc6b3b2fbff1fd4e3d2c6ea8"}, - {file = "coverage-7.13.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e5330fa0cc1f5c3c4c3bb8e101b742025933e7848989370a1d4c8c5e401ea753"}, - {file = "coverage-7.13.0-cp311-cp311-win32.whl", hash = "sha256:0f4872f5d6c54419c94c25dd6ae1d015deeb337d06e448cd890a1e89a8ee7f3b"}, - {file = "coverage-7.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51a202e0f80f241ccb68e3e26e19ab5b3bf0f813314f2c967642f13ebcf1ddfe"}, - {file = "coverage-7.13.0-cp311-cp311-win_arm64.whl", hash = "sha256:d2a9d7f1c11487b1c69367ab3ac2d81b9b3721f097aa409a3191c3e90f8f3dd7"}, - {file = "coverage-7.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0b3d67d31383c4c68e19a88e28fc4c2e29517580f1b0ebec4a069d502ce1e0bf"}, - {file = "coverage-7.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:581f086833d24a22c89ae0fe2142cfaa1c92c930adf637ddf122d55083fb5a0f"}, - {file = "coverage-7.13.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0a3a30f0e257df382f5f9534d4ce3d4cf06eafaf5192beb1a7bd066cb10e78fb"}, - {file = "coverage-7.13.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:583221913fbc8f53b88c42e8dbb8fca1d0f2e597cb190ce45916662b8b9d9621"}, - {file = "coverage-7.13.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f5d9bd30756fff3e7216491a0d6d520c448d5124d3d8e8f56446d6412499e74"}, - 
{file = "coverage-7.13.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a23e5a1f8b982d56fa64f8e442e037f6ce29322f1f9e6c2344cd9e9f4407ee57"}, - {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9b01c22bc74a7fb44066aaf765224c0d933ddf1f5047d6cdfe4795504a4493f8"}, - {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:898cce66d0836973f48dda4e3514d863d70142bdf6dfab932b9b6a90ea5b222d"}, - {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:3ab483ea0e251b5790c2aac03acde31bff0c736bf8a86829b89382b407cd1c3b"}, - {file = "coverage-7.13.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1d84e91521c5e4cb6602fe11ece3e1de03b2760e14ae4fcf1a4b56fa3c801fcd"}, - {file = "coverage-7.13.0-cp312-cp312-win32.whl", hash = "sha256:193c3887285eec1dbdb3f2bd7fbc351d570ca9c02ca756c3afbc71b3c98af6ef"}, - {file = "coverage-7.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:4f3e223b2b2db5e0db0c2b97286aba0036ca000f06aca9b12112eaa9af3d92ae"}, - {file = "coverage-7.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:086cede306d96202e15a4b77ace8472e39d9f4e5f9fd92dd4fecdfb2313b2080"}, - {file = "coverage-7.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:28ee1c96109974af104028a8ef57cec21447d42d0e937c0275329272e370ebcf"}, - {file = "coverage-7.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d1e97353dcc5587b85986cda4ff3ec98081d7e84dd95e8b2a6d59820f0545f8a"}, - {file = "coverage-7.13.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:99acd4dfdfeb58e1937629eb1ab6ab0899b131f183ee5f23e0b5da5cba2fec74"}, - {file = "coverage-7.13.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ff45e0cd8451e293b63ced93161e189780baf444119391b3e7d25315060368a6"}, - {file = "coverage-7.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f4f72a85316d8e13234cafe0a9f81b40418ad7a082792fa4165bd7d45d96066b"}, - {file = "coverage-7.13.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:11c21557d0e0a5a38632cbbaca5f008723b26a89d70db6315523df6df77d6232"}, - {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:76541dc8d53715fb4f7a3a06b34b0dc6846e3c69bc6204c55653a85dd6220971"}, - {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:6e9e451dee940a86789134b6b0ffbe31c454ade3b849bb8a9d2cca2541a8e91d"}, - {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:5c67dace46f361125e6b9cace8fe0b729ed8479f47e70c89b838d319375c8137"}, - {file = "coverage-7.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f59883c643cb19630500f57016f76cfdcd6845ca8c5b5ea1f6e17f74c8e5f511"}, - {file = "coverage-7.13.0-cp313-cp313-win32.whl", hash = "sha256:58632b187be6f0be500f553be41e277712baa278147ecb7559983c6d9faf7ae1"}, - {file = "coverage-7.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:73419b89f812f498aca53f757dd834919b48ce4799f9d5cad33ca0ae442bdb1a"}, - {file = "coverage-7.13.0-cp313-cp313-win_arm64.whl", hash = "sha256:eb76670874fdd6091eedcc856128ee48c41a9bbbb9c3f1c7c3cf169290e3ffd6"}, - {file = "coverage-7.13.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6e63ccc6e0ad8986386461c3c4b737540f20426e7ec932f42e030320896c311a"}, - {file = "coverage-7.13.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:494f5459ffa1bd45e18558cd98710c36c0b8fbfa82a5eabcbe671d80ecffbfe8"}, - {file = 
"coverage-7.13.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:06cac81bf10f74034e055e903f5f946e3e26fc51c09fc9f584e4a1605d977053"}, - {file = "coverage-7.13.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f2ffc92b46ed6e6760f1d47a71e56b5664781bc68986dbd1836b2b70c0ce2071"}, - {file = "coverage-7.13.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0602f701057c6823e5db1b74530ce85f17c3c5be5c85fc042ac939cbd909426e"}, - {file = "coverage-7.13.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:25dc33618d45456ccb1d37bce44bc78cf269909aa14c4db2e03d63146a8a1493"}, - {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:71936a8b3b977ddd0b694c28c6a34f4fff2e9dd201969a4ff5d5fc7742d614b0"}, - {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:936bc20503ce24770c71938d1369461f0c5320830800933bc3956e2a4ded930e"}, - {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:af0a583efaacc52ae2521f8d7910aff65cdb093091d76291ac5820d5e947fc1c"}, - {file = "coverage-7.13.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f1c23e24a7000da892a312fb17e33c5f94f8b001de44b7cf8ba2e36fbd15859e"}, - {file = "coverage-7.13.0-cp313-cp313t-win32.whl", hash = "sha256:5f8a0297355e652001015e93be345ee54393e45dc3050af4a0475c5a2b767d46"}, - {file = "coverage-7.13.0-cp313-cp313t-win_amd64.whl", hash = "sha256:6abb3a4c52f05e08460bd9acf04fec027f8718ecaa0d09c40ffbc3fbd70ecc39"}, - {file = "coverage-7.13.0-cp313-cp313t-win_arm64.whl", hash = "sha256:3ad968d1e3aa6ce5be295ab5fe3ae1bf5bb4769d0f98a80a0252d543a2ef2e9e"}, - {file = "coverage-7.13.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:453b7ec753cf5e4356e14fe858064e5520c460d3bbbcb9c35e55c0d21155c256"}, - {file = "coverage-7.13.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:af827b7cbb303e1befa6c4f94fd2bf72f108089cfa0f8abab8f4ca553cf5ca5a"}, - {file = "coverage-7.13.0-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:9987a9e4f8197a1000280f7cc089e3ea2c8b3c0a64d750537809879a7b4ceaf9"}, - {file = "coverage-7.13.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3188936845cd0cb114fa6a51842a304cdbac2958145d03be2377ec41eb285d19"}, - {file = "coverage-7.13.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2bdb3babb74079f021696cb46b8bb5f5661165c385d3a238712b031a12355be"}, - {file = "coverage-7.13.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7464663eaca6adba4175f6c19354feea61ebbdd735563a03d1e472c7072d27bb"}, - {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8069e831f205d2ff1f3d355e82f511eb7c5522d7d413f5db5756b772ec8697f8"}, - {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:6fb2d5d272341565f08e962cce14cdf843a08ac43bd621783527adb06b089c4b"}, - {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:5e70f92ef89bac1ac8a99b3324923b4749f008fdbd7aa9cb35e01d7a284a04f9"}, - {file = "coverage-7.13.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:4b5de7d4583e60d5fd246dd57fcd3a8aa23c6e118a8c72b38adf666ba8e7e927"}, - {file = "coverage-7.13.0-cp314-cp314-win32.whl", hash = "sha256:a6c6e16b663be828a8f0b6c5027d36471d4a9f90d28444aa4ced4d48d7d6ae8f"}, - {file = 
"coverage-7.13.0-cp314-cp314-win_amd64.whl", hash = "sha256:0900872f2fdb3ee5646b557918d02279dc3af3dfb39029ac4e945458b13f73bc"}, - {file = "coverage-7.13.0-cp314-cp314-win_arm64.whl", hash = "sha256:3a10260e6a152e5f03f26db4a407c4c62d3830b9af9b7c0450b183615f05d43b"}, - {file = "coverage-7.13.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9097818b6cc1cfb5f174e3263eba4a62a17683bcfe5c4b5d07f4c97fa51fbf28"}, - {file = "coverage-7.13.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0018f73dfb4301a89292c73be6ba5f58722ff79f51593352759c1790ded1cabe"}, - {file = "coverage-7.13.0-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:166ad2a22ee770f5656e1257703139d3533b4a0b6909af67c6b4a3adc1c98657"}, - {file = "coverage-7.13.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f6aaef16d65d1787280943f1c8718dc32e9cf141014e4634d64446702d26e0ff"}, - {file = "coverage-7.13.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e999e2dcc094002d6e2c7bbc1fb85b58ba4f465a760a8014d97619330cdbbbf3"}, - {file = "coverage-7.13.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:00c3d22cf6fb1cf3bf662aaaa4e563be8243a5ed2630339069799835a9cc7f9b"}, - {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22ccfe8d9bb0d6134892cbe1262493a8c70d736b9df930f3f3afae0fe3ac924d"}, - {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:9372dff5ea15930fea0445eaf37bbbafbc771a49e70c0aeed8b4e2c2614cc00e"}, - {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:69ac2c492918c2461bc6ace42d0479638e60719f2a4ef3f0815fa2df88e9f940"}, - {file = "coverage-7.13.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:739c6c051a7540608d097b8e13c76cfa85263ced467168dc6b477bae3df7d0e2"}, - {file = "coverage-7.13.0-cp314-cp314t-win32.whl", hash = "sha256:fe81055d8c6c9de76d60c94ddea73c290b416e061d40d542b24a5871bad498b7"}, - {file = "coverage-7.13.0-cp314-cp314t-win_amd64.whl", hash = "sha256:445badb539005283825959ac9fa4a28f712c214b65af3a2c464f1adc90f5fcbc"}, - {file = "coverage-7.13.0-cp314-cp314t-win_arm64.whl", hash = "sha256:de7f6748b890708578fc4b7bb967d810aeb6fcc9bff4bb77dbca77dab2f9df6a"}, - {file = "coverage-7.13.0-py3-none-any.whl", hash = "sha256:850d2998f380b1e266459ca5b47bc9e7daf9af1d070f66317972f382d46f1904"}, - {file = "coverage-7.13.0.tar.gz", hash = "sha256:a394aa27f2d7ff9bc04cf703817773a59ad6dfbd577032e690f961d2460ee936"}, + {file = "coverage-7.13.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e1fa280b3ad78eea5be86f94f461c04943d942697e0dac889fa18fff8f5f9147"}, + {file = "coverage-7.13.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c3d8c679607220979434f494b139dfb00131ebf70bb406553d69c1ff01a5c33d"}, + {file = "coverage-7.13.1-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:339dc63b3eba969067b00f41f15ad161bf2946613156fb131266d8debc8e44d0"}, + {file = "coverage-7.13.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:db622b999ffe49cb891f2fff3b340cdc2f9797d01a0a202a0973ba2562501d90"}, + {file = "coverage-7.13.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1443ba9acbb593fa7c1c29e011d7c9761545fe35e7652e85ce7f51a16f7e08d"}, + {file = "coverage-7.13.1-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = 
"sha256:c832ec92c4499ac463186af72f9ed4d8daec15499b16f0a879b0d1c8e5cf4a3b"}, + {file = "coverage-7.13.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:562ec27dfa3f311e0db1ba243ec6e5f6ab96b1edfcfc6cf86f28038bc4961ce6"}, + {file = "coverage-7.13.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4de84e71173d4dada2897e5a0e1b7877e5eefbfe0d6a44edee6ce31d9b8ec09e"}, + {file = "coverage-7.13.1-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:a5a68357f686f8c4d527a2dc04f52e669c2fc1cbde38f6f7eb6a0e58cbd17cae"}, + {file = "coverage-7.13.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:77cc258aeb29a3417062758975521eae60af6f79e930d6993555eeac6a8eac29"}, + {file = "coverage-7.13.1-cp310-cp310-win32.whl", hash = "sha256:bb4f8c3c9a9f34423dba193f241f617b08ffc63e27f67159f60ae6baf2dcfe0f"}, + {file = "coverage-7.13.1-cp310-cp310-win_amd64.whl", hash = "sha256:c8e2706ceb622bc63bac98ebb10ef5da80ed70fbd8a7999a5076de3afaef0fb1"}, + {file = "coverage-7.13.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:1a55d509a1dc5a5b708b5dad3b5334e07a16ad4c2185e27b40e4dba796ab7f88"}, + {file = "coverage-7.13.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4d010d080c4888371033baab27e47c9df7d6fb28d0b7b7adf85a4a49be9298b3"}, + {file = "coverage-7.13.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d938b4a840fb1523b9dfbbb454f652967f18e197569c32266d4d13f37244c3d9"}, + {file = "coverage-7.13.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bf100a3288f9bb7f919b87eb84f87101e197535b9bd0e2c2b5b3179633324fee"}, + {file = "coverage-7.13.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef6688db9bf91ba111ae734ba6ef1a063304a881749726e0d3575f5c10a9facf"}, + {file = "coverage-7.13.1-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0b609fc9cdbd1f02e51f67f51e5aee60a841ef58a68d00d5ee2c0faf357481a3"}, + {file = "coverage-7.13.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c43257717611ff5e9a1d79dce8e47566235ebda63328718d9b65dd640bc832ef"}, + {file = "coverage-7.13.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e09fbecc007f7b6afdfb3b07ce5bd9f8494b6856dd4f577d26c66c391b829851"}, + {file = "coverage-7.13.1-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:a03a4f3a19a189919c7055098790285cc5c5b0b3976f8d227aea39dbf9f8bfdb"}, + {file = "coverage-7.13.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3820778ea1387c2b6a818caec01c63adc5b3750211af6447e8dcfb9b6f08dbba"}, + {file = "coverage-7.13.1-cp311-cp311-win32.whl", hash = "sha256:ff10896fa55167371960c5908150b434b71c876dfab97b69478f22c8b445ea19"}, + {file = "coverage-7.13.1-cp311-cp311-win_amd64.whl", hash = "sha256:a998cc0aeeea4c6d5622a3754da5a493055d2d95186bad877b0a34ea6e6dbe0a"}, + {file = "coverage-7.13.1-cp311-cp311-win_arm64.whl", hash = "sha256:fea07c1a39a22614acb762e3fbbb4011f65eedafcb2948feeef641ac78b4ee5c"}, + {file = "coverage-7.13.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6f34591000f06e62085b1865c9bc5f7858df748834662a51edadfd2c3bfe0dd3"}, + {file = "coverage-7.13.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b67e47c5595b9224599016e333f5ec25392597a89d5744658f837d204e16c63e"}, + {file = "coverage-7.13.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3e7b8bd70c48ffb28461ebe092c2345536fb18bbbf19d287c8913699735f505c"}, + {file = 
"coverage-7.13.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c223d078112e90dc0e5c4e35b98b9584164bea9fbbd221c0b21c5241f6d51b62"}, + {file = "coverage-7.13.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:794f7c05af0763b1bbd1b9e6eff0e52ad068be3b12cd96c87de037b01390c968"}, + {file = "coverage-7.13.1-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0642eae483cc8c2902e4af7298bf886d605e80f26382124cddc3967c2a3df09e"}, + {file = "coverage-7.13.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9f5e772ed5fef25b3de9f2008fe67b92d46831bd2bc5bdc5dd6bfd06b83b316f"}, + {file = "coverage-7.13.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:45980ea19277dc0a579e432aef6a504fe098ef3a9032ead15e446eb0f1191aee"}, + {file = "coverage-7.13.1-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:e4f18eca6028ffa62adbd185a8f1e1dd242f2e68164dba5c2b74a5204850b4cf"}, + {file = "coverage-7.13.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f8dca5590fec7a89ed6826fce625595279e586ead52e9e958d3237821fbc750c"}, + {file = "coverage-7.13.1-cp312-cp312-win32.whl", hash = "sha256:ff86d4e85188bba72cfb876df3e11fa243439882c55957184af44a35bd5880b7"}, + {file = "coverage-7.13.1-cp312-cp312-win_amd64.whl", hash = "sha256:16cc1da46c04fb0fb128b4dc430b78fa2aba8a6c0c9f8eb391fd5103409a6ac6"}, + {file = "coverage-7.13.1-cp312-cp312-win_arm64.whl", hash = "sha256:8d9bc218650022a768f3775dd7fdac1886437325d8d295d923ebcfef4892ad5c"}, + {file = "coverage-7.13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cb237bfd0ef4d5eb6a19e29f9e528ac67ac3be932ea6b44fb6cc09b9f3ecff78"}, + {file = "coverage-7.13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1dcb645d7e34dcbcc96cd7c132b1fc55c39263ca62eb961c064eb3928997363b"}, + {file = "coverage-7.13.1-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3d42df8201e00384736f0df9be2ced39324c3907607d17d50d50116c989d84cd"}, + {file = "coverage-7.13.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:fa3edde1aa8807de1d05934982416cb3ec46d1d4d91e280bcce7cca01c507992"}, + {file = "coverage-7.13.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9edd0e01a343766add6817bc448408858ba6b489039eaaa2018474e4001651a4"}, + {file = "coverage-7.13.1-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:985b7836931d033570b94c94713c6dba5f9d3ff26045f72c3e5dbc5fe3361e5a"}, + {file = "coverage-7.13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ffed1e4980889765c84a5d1a566159e363b71d6b6fbaf0bebc9d3c30bc016766"}, + {file = "coverage-7.13.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:8842af7f175078456b8b17f1b73a0d16a65dcbdc653ecefeb00a56b3c8c298c4"}, + {file = "coverage-7.13.1-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:ccd7a6fca48ca9c131d9b0a2972a581e28b13416fc313fb98b6d24a03ce9a398"}, + {file = "coverage-7.13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:0403f647055de2609be776965108447deb8e384fe4a553c119e3ff6bfbab4784"}, + {file = "coverage-7.13.1-cp313-cp313-win32.whl", hash = "sha256:549d195116a1ba1e1ae2f5ca143f9777800f6636eab917d4f02b5310d6d73461"}, + {file = "coverage-7.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:5899d28b5276f536fcf840b18b61a9fce23cc3aec1d114c44c07fe94ebeaa500"}, + {file = "coverage-7.13.1-cp313-cp313-win_arm64.whl", hash = 
"sha256:868a2fae76dfb06e87291bcbd4dcbcc778a8500510b618d50496e520bd94d9b9"}, + {file = "coverage-7.13.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:67170979de0dacac3f3097d02b0ad188d8edcea44ccc44aaa0550af49150c7dc"}, + {file = "coverage-7.13.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f80e2bb21bfab56ed7405c2d79d34b5dc0bc96c2c1d2a067b643a09fb756c43a"}, + {file = "coverage-7.13.1-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f83351e0f7dcdb14d7326c3d8d8c4e915fa685cbfdc6281f9470d97a04e9dfe4"}, + {file = "coverage-7.13.1-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bb3f6562e89bad0110afbe64e485aac2462efdce6232cdec7862a095dc3412f6"}, + {file = "coverage-7.13.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:77545b5dcda13b70f872c3b5974ac64c21d05e65b1590b441c8560115dc3a0d1"}, + {file = "coverage-7.13.1-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a4d240d260a1aed814790bbe1f10a5ff31ce6c21bc78f0da4a1e8268d6c80dbd"}, + {file = "coverage-7.13.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:d2287ac9360dec3837bfdad969963a5d073a09a85d898bd86bea82aa8876ef3c"}, + {file = "coverage-7.13.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:0d2c11f3ea4db66b5cbded23b20185c35066892c67d80ec4be4bab257b9ad1e0"}, + {file = "coverage-7.13.1-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:3fc6a169517ca0d7ca6846c3c5392ef2b9e38896f61d615cb75b9e7134d4ee1e"}, + {file = "coverage-7.13.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d10a2ed46386e850bb3de503a54f9fe8192e5917fcbb143bfef653a9355e9a53"}, + {file = "coverage-7.13.1-cp313-cp313t-win32.whl", hash = "sha256:75a6f4aa904301dab8022397a22c0039edc1f51e90b83dbd4464b8a38dc87842"}, + {file = "coverage-7.13.1-cp313-cp313t-win_amd64.whl", hash = "sha256:309ef5706e95e62578cda256b97f5e097916a2c26247c287bbe74794e7150df2"}, + {file = "coverage-7.13.1-cp313-cp313t-win_arm64.whl", hash = "sha256:92f980729e79b5d16d221038dbf2e8f9a9136afa072f9d5d6ed4cb984b126a09"}, + {file = "coverage-7.13.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:97ab3647280d458a1f9adb85244e81587505a43c0c7cff851f5116cd2814b894"}, + {file = "coverage-7.13.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8f572d989142e0908e6acf57ad1b9b86989ff057c006d13b76c146ec6a20216a"}, + {file = "coverage-7.13.1-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d72140ccf8a147e94274024ff6fd8fb7811354cf7ef88b1f0a988ebaa5bc774f"}, + {file = "coverage-7.13.1-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d3c9f051b028810f5a87c88e5d6e9af3c0ff32ef62763bf15d29f740453ca909"}, + {file = "coverage-7.13.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f398ba4df52d30b1763f62eed9de5620dcde96e6f491f4c62686736b155aa6e4"}, + {file = "coverage-7.13.1-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:132718176cc723026d201e347f800cd1a9e4b62ccd3f82476950834dad501c75"}, + {file = "coverage-7.13.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9e549d642426e3579b3f4b92d0431543b012dcb6e825c91619d4e93b7363c3f9"}, + {file = "coverage-7.13.1-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:90480b2134999301eea795b3a9dbf606c6fbab1b489150c501da84a959442465"}, + {file = "coverage-7.13.1-cp314-cp314-musllinux_1_2_riscv64.whl", hash = 
"sha256:e825dbb7f84dfa24663dd75835e7257f8882629fc11f03ecf77d84a75134b864"}, + {file = "coverage-7.13.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:623dcc6d7a7ba450bbdbeedbaa0c42b329bdae16491af2282f12a7e809be7eb9"}, + {file = "coverage-7.13.1-cp314-cp314-win32.whl", hash = "sha256:6e73ebb44dca5f708dc871fe0b90cf4cff1a13f9956f747cc87b535a840386f5"}, + {file = "coverage-7.13.1-cp314-cp314-win_amd64.whl", hash = "sha256:be753b225d159feb397bd0bf91ae86f689bad0da09d3b301478cd39b878ab31a"}, + {file = "coverage-7.13.1-cp314-cp314-win_arm64.whl", hash = "sha256:228b90f613b25ba0019361e4ab81520b343b622fc657daf7e501c4ed6a2366c0"}, + {file = "coverage-7.13.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:60cfb538fe9ef86e5b2ab0ca8fc8d62524777f6c611dcaf76dc16fbe9b8e698a"}, + {file = "coverage-7.13.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:57dfc8048c72ba48a8c45e188d811e5efd7e49b387effc8fb17e97936dde5bf6"}, + {file = "coverage-7.13.1-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3f2f725aa3e909b3c5fdb8192490bdd8e1495e85906af74fe6e34a2a77ba0673"}, + {file = "coverage-7.13.1-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9ee68b21909686eeb21dfcba2c3b81fee70dcf38b140dcd5aa70680995fa3aa5"}, + {file = "coverage-7.13.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:724b1b270cb13ea2e6503476e34541a0b1f62280bc997eab443f87790202033d"}, + {file = "coverage-7.13.1-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:916abf1ac5cf7eb16bc540a5bf75c71c43a676f5c52fcb9fe75a2bd75fb944e8"}, + {file = "coverage-7.13.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:776483fd35b58d8afe3acbd9988d5de592ab6da2d2a865edfdbc9fdb43e7c486"}, + {file = "coverage-7.13.1-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:b6f3b96617e9852703f5b633ea01315ca45c77e879584f283c44127f0f1ec564"}, + {file = "coverage-7.13.1-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:bd63e7b74661fed317212fab774e2a648bc4bb09b35f25474f8e3325d2945cd7"}, + {file = "coverage-7.13.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:933082f161bbb3e9f90d00990dc956120f608cdbcaeea15c4d897f56ef4fe416"}, + {file = "coverage-7.13.1-cp314-cp314t-win32.whl", hash = "sha256:18be793c4c87de2965e1c0f060f03d9e5aff66cfeae8e1dbe6e5b88056ec153f"}, + {file = "coverage-7.13.1-cp314-cp314t-win_amd64.whl", hash = "sha256:0e42e0ec0cd3e0d851cb3c91f770c9301f48647cb2877cb78f74bdaa07639a79"}, + {file = "coverage-7.13.1-cp314-cp314t-win_arm64.whl", hash = "sha256:eaecf47ef10c72ece9a2a92118257da87e460e113b83cc0d2905cbbe931792b4"}, + {file = "coverage-7.13.1-py3-none-any.whl", hash = "sha256:2016745cb3ba554469d02819d78958b571792bb68e31302610e898f80dd3a573"}, + {file = "coverage-7.13.1.tar.gz", hash = "sha256:b7593fe7eb5feaa3fbb461ac79aac9f9fc0387a5ca8080b0c6fe2ca27b091afd"}, ] [package.dependencies] @@ -1223,15 +1223,15 @@ files = [ [[package]] name = "diff-cover" -version = "10.0.0" +version = "10.2.0" description = "Run coverage and linting reports on diffs" optional = true -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\"" files = [ - {file = "diff_cover-10.0.0-py3-none-any.whl", hash = "sha256:b3a095d733ba715df6098f51d9155607e4752f82be8a6cbba9bfcf77df736852"}, - {file = "diff_cover-10.0.0.tar.gz", hash = "sha256:92ead026726055bf4c1a90cd7ff83544049d467840e07c66289a4351126dbe25"}, + {file = "diff_cover-10.2.0-py3-none-any.whl", 
hash = "sha256:59c328595e0b8948617cc5269af9e484c86462e2844bfcafa3fb37f8fca0af87"}, + {file = "diff_cover-10.2.0.tar.gz", hash = "sha256:61bf83025f10510c76ef6a5820680cf61b9b974e8f81de70c57ac926fa63872a"}, ] [package.dependencies] @@ -1245,15 +1245,15 @@ toml = ["tomli (>=1.2.1)"] [[package]] name = "dill" -version = "0.4.0" +version = "0.4.1" description = "serialize all of Python" optional = true -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"tests\" or extra == \"docs\"" files = [ - {file = "dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049"}, - {file = "dill-0.4.0.tar.gz", hash = "sha256:0633f1d2df477324f53a895b02c901fb961bdbf65a17122586ea7019292cbcf0"}, + {file = "dill-0.4.1-py3-none-any.whl", hash = "sha256:1e1ce33e978ae97fcfcff5638477032b801c46c7c65cf717f95fbc2248f79a9d"}, + {file = "dill-0.4.1.tar.gz", hash = "sha256:423092df4182177d4d8ba8290c8a5b640c66ab35ec7da59ccfa00f6fa3eea5fa"}, ] [package.extras] @@ -1390,15 +1390,15 @@ tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipyth [[package]] name = "fastcore" -version = "1.9.5" +version = "1.12.2" description = "Python supercharged for fastai development" optional = true python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "fastcore-1.9.5-py3-none-any.whl", hash = "sha256:8c306dec73b3a0069e986d5a3d982947cf3e651a206b20892f657c574c04bb1b"}, - {file = "fastcore-1.9.5.tar.gz", hash = "sha256:e9f06b8c80bc29a2ec21113566ec975a600a1f20951815589bc3ab31b2f79b9b"}, + {file = "fastcore-1.12.2-py3-none-any.whl", hash = "sha256:11fccdc9dc0a13a0d15f6b63bed3f45e6bc8ab73458f8795300d2bd8e35d8ba6"}, + {file = "fastcore-1.12.2.tar.gz", hash = "sha256:01c0c9e19cd0c3cb7a208311d2826da3f0269babc32bd9aa21c7b7976ad5626e"}, ] [package.dependencies] @@ -1425,15 +1425,15 @@ devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benc [[package]] name = "filelock" -version = "3.20.1" +version = "3.20.3" description = "A platform independent file lock." 
optional = true python-versions = ">=3.10" groups = ["main"] -markers = "sys_platform == \"darwin\" and (extra == \"dev\" or extra == \"pytorch\" or extra == \"tests\" or extra == \"scikit-rf\" or extra == \"docs\") or extra == \"dev\" or extra == \"tests\" or extra == \"scikit-rf\" or extra == \"docs\" or extra == \"pytorch\"" +markers = "extra == \"dev\" or extra == \"tests\" or extra == \"scikit-rf\" or extra == \"docs\" or extra == \"pytorch\"" files = [ - {file = "filelock-3.20.1-py3-none-any.whl", hash = "sha256:15d9e9a67306188a44baa72f569d2bfd803076269365fdea0934385da4dc361a"}, - {file = "filelock-3.20.1.tar.gz", hash = "sha256:b8360948b351b80f420878d8516519a2204b07aefcdcfd24912a5d33127f188c"}, + {file = "filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1"}, + {file = "filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1"}, ] [[package]] @@ -1484,7 +1484,7 @@ jax = ">=0.8.1" msgpack = "*" numpy = [ {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, ] optax = "*" orbax-checkpoint = "*" @@ -1587,14 +1587,14 @@ files = [ [[package]] name = "fsspec" -version = "2025.12.0" +version = "2026.1.0" description = "File-system specification" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "fsspec-2025.12.0-py3-none-any.whl", hash = "sha256:8bf1fe301b7d8acfa6e8571e3b1c3d158f909666642431cc78a1b7b4dbc5ec5b"}, - {file = "fsspec-2025.12.0.tar.gz", hash = "sha256:c505de011584597b1060ff778bb664c1bc022e87921b0e4f10cc9c44f9635973"}, + {file = "fsspec-2026.1.0-py3-none-any.whl", hash = "sha256:cb76aa913c2285a3b49bdd5fc55b1d7c708d7208126b60f2eb8194fe1b4cbdcc"}, + {file = "fsspec-2026.1.0.tar.gz", hash = "sha256:e987cb0496a0d81bba3a9d1cee62922fb395e7d4c3b575e57f547953334fe07b"}, ] [package.extras] @@ -1605,7 +1605,7 @@ dask = ["dask", "distributed"] dev = ["pre-commit", "ruff (>=0.5)"] doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] -full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs (>2024.2.0)", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs (>2024.2.0)", "smbprotocol", "tqdm"] fuse = ["fusepy"] gcs = ["gcsfs"] git = ["pygit2"] @@ -1622,7 +1622,7 @@ smb = ["smbprotocol"] ssh = ["paramiko"] test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] -test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", 
"pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard ; python_version < \"3.14\""] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "backports-zstd ; python_version < \"3.14\"", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr"] tqdm = ["tqdm"] [[package]] @@ -1640,75 +1640,55 @@ files = [ [[package]] name = "gdstk" -version = "0.9.61" +version = "0.9.62" description = "Python module for creation and manipulation of GDSII files." optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\" or extra == \"gdstk\"" files = [ - {file = "gdstk-0.9.61-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8db8120b5b8864de074ed773d4c0788100b76eecd2bf327a6de338f011745e3f"}, - {file = "gdstk-0.9.61-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ad942da613f6274e391371771b8cfef2854eb69f628914f716f518929567dcd4"}, - {file = "gdstk-0.9.61-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3b49ff5e43764783d2053b129fe1eac152910e2d062dfc2fd2408c9b91a043d5"}, - {file = "gdstk-0.9.61-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5e53c0f765796b4fc449b72c800924df2e936820087816686e987962b3f0452a"}, - {file = "gdstk-0.9.61-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc601258b850595b34e22b5c0fd1d98724a053faa4b1a23517c693b6eb01e275"}, - {file = "gdstk-0.9.61-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8ab2644f04a3d91e158bfce7c5dbdc60f09745cf7dc7fc19e9255cb6e6d9547b"}, - {file = "gdstk-0.9.61-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:4aa897a629d20bca211cacf36e35a7316a5d6cfe03effb6af19c0eb7fd225421"}, - {file = "gdstk-0.9.61-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9d20a09f06596ff2926e6b4ad12f3b0ae0ce545bf60211b96c2f9791f1df37fe"}, - {file = "gdstk-0.9.61-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:688cc52aa1a5b9016eb0787a9cea4943a1aa2cc3d8d3cbeeaa44b3203f71e38f"}, - {file = "gdstk-0.9.61-cp310-cp310-win_amd64.whl", hash = "sha256:5214c4f89fb9ff60ced79f6d2d28de4c5d5b588c9ef930fe72333edaa5e0bcf2"}, - {file = "gdstk-0.9.61-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5fab80fa1e5ac4d956a04fdc78fb6971cb32a43418553939ee4ccf4eba6d4496"}, - {file = "gdstk-0.9.61-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:82706a72f37c70340978fb70777cf94119408593f5a8c73c0700c0b84486a3fe"}, - {file = "gdstk-0.9.61-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6857359fc517fa91d6c0cd179cf09290aaebf538869d825585d9a0ed3cec754d"}, - {file = "gdstk-0.9.61-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:459b1f28a6283bb61ed28c745aba3d49c5cbd9424fb81f76023d3f44b92c6257"}, - {file = "gdstk-0.9.61-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3812aadf36764bb6ca86f1b9f4bdf8f8c41749bcdf1e3b45d6263e48b4f97eab"}, - {file = "gdstk-0.9.61-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:f3c6f0df208263039851ac5d3d94fcddbc80029a69918d53c0b7dc392725d8fb"}, - {file = "gdstk-0.9.61-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:7e166ef1c26fc0f48fa8194e54683e61ca43b72d3342708d4229855dcad137ed"}, - {file = "gdstk-0.9.61-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:79dc9f0f0c5f6860199c9af09564bbfed4c34885d3f5b46ab9514ab0716cff39"}, - {file = "gdstk-0.9.61-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4b3e2b367e5962db05845eaaf3f9d8bcfa3914738c6e174455a152a63d78904c"}, - {file = "gdstk-0.9.61-cp311-cp311-win_amd64.whl", hash = "sha256:0c3866dc287d657f78ae587e2e10de2747ebbf5d2824dc6ba4f9ece89c36a35a"}, - {file = "gdstk-0.9.61-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:61f0ee05cdce9b4163ea812cbf2e2f5d8d01a293fa118ff98348280306bd91d6"}, - {file = "gdstk-0.9.61-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fff1b104b6775e4c27ab2751b3f4ac6c1ce86a4e9afd5e5535ac4acefa6a7a07"}, - {file = "gdstk-0.9.61-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5218f8c5ab13b6e979665c0a7dc1272768003a1cb7add0682483837f7485faed"}, - {file = "gdstk-0.9.61-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4e79f3881d3b3666a600efd5b2c131454507f69d3c9b9eaf383d106cfbd6e7bc"}, - {file = "gdstk-0.9.61-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e90a6e24c2145320e53e953a59c6297fd25c17c6ef098fa8602e64e64a5390ea"}, - {file = "gdstk-0.9.61-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3a49401cbd26c5a17a4152d1befa73efb21af694524557bf09d15f4c8a874e6"}, - {file = "gdstk-0.9.61-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:8738ac63bbe29dcb5abae6a19d207c4e0857f9dc1bd405c85af8a87f0dcfb348"}, - {file = "gdstk-0.9.61-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:23bb023a49f3321673d0e32cdce2e2705a51d9e12328c928723ded49af970520"}, - {file = "gdstk-0.9.61-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81c2f19cab89623d1f56848e7a16e2fab82a93c61c8f7aa73f5ff59840b60c0f"}, - {file = "gdstk-0.9.61-cp312-cp312-win_amd64.whl", hash = "sha256:4474f015ecc228b210165287cb7eea65639ea6308f60105cb49e970079bddc2b"}, - {file = "gdstk-0.9.61-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:3beeae846fc523c7e3a01c47edcd3b7dd83c29650e56b82a371e528f9cb0ec3e"}, - {file = "gdstk-0.9.61-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:575a21639b31e2fab4d9e918468b8b40a58183028db563e5963be594bff1403d"}, - {file = "gdstk-0.9.61-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:90d54b48223dcbb8257769faaa87542d12a749d8486e8d1187a45d06e9422859"}, - {file = "gdstk-0.9.61-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:35405bed95542a0b10f343b165ce0ad80740bf8127a4507565ec74222e6ec8d3"}, - {file = "gdstk-0.9.61-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b311ddf8982995b52ac3bf3b32a6cf6d918afc4e66dea527d531e8af73896231"}, - {file = "gdstk-0.9.61-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6dcbfc60fba92d10f1c7d612b5409c343fcaf2a380640e9fb01c504ca948b412"}, - {file = "gdstk-0.9.61-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:fab67ccdd8029ef7eb873f8c98f875dc2665a5e45af7cf3d2a7a0f401826a1d3"}, - {file = "gdstk-0.9.61-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:5852749e203d6978e06d02f8ef9e29ce4512cb1aedeb62c37b8e8b2c10c4f529"}, - {file = "gdstk-0.9.61-cp313-cp313-musllinux_1_2_x86_64.whl", hash = 
"sha256:7ee38a54c799e77dbe219266f765bbd3b2906b62bc7b6fb64b1387e6db3dd187"}, - {file = "gdstk-0.9.61-cp313-cp313-win_amd64.whl", hash = "sha256:6abb396873b2660dd7863d664b3822f00547bf7f216af27be9f1f812bc5e8027"}, - {file = "gdstk-0.9.61-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:a674af8be5cf1f8ea9f6c5b5f165f797d7e2ed74cbca68b4a22adb92b515fb35"}, - {file = "gdstk-0.9.61-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:38ec0b7285d6c9bf8cbc279731dc0d314633cda2ce9e6f9053554b3e5f004fcd"}, - {file = "gdstk-0.9.61-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3b63a77b57fb441c8017217aaf1e8b13d93cbee822031a8e2826adb716e01dd4"}, - {file = "gdstk-0.9.61-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7fae6eee627e837d1405b47d381ccd33dbba85473b1bb3822bdc8ae41dbc0dc"}, - {file = "gdstk-0.9.61-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9e396694cac24bd87d0e38c37e6740d9ba0c13f6c9f2211a871d62288430f069"}, - {file = "gdstk-0.9.61-cp314-cp314-win_amd64.whl", hash = "sha256:7ea0c1200dc53b794e9c0cc6fe3ea51e49113dfdd9c3109e1961cda3cc2197c7"}, - {file = "gdstk-0.9.61-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:616dd1c3e7aea4a98aeb03db7cf76a853d134c54690790eaa25c63eede7b869a"}, - {file = "gdstk-0.9.61-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b0e898202fbb7fd4c39f8404831415a0aa0445656342102c4e77d4a7c2c15a1d"}, - {file = "gdstk-0.9.61-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:29bb862a1a814f5bbd6f8bbc2f99e1163df9e6307071cb6e11251dbe7542feb5"}, - {file = "gdstk-0.9.61-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6c2a08d82a683aff50dc63f2943ed805d32d46bd984cbd4ac9cf876146d0ef9"}, - {file = "gdstk-0.9.61-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3ba52f95763052a6968583942e6531ceca20c14c762d44fe2bd887445e2f73b6"}, - {file = "gdstk-0.9.61-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1033d4ddd2af34461c1133ef62213a4861f23d07d64d66e92fe8d2554a85ba6d"}, - {file = "gdstk-0.9.61-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bad94f74dff3efaa5ade7bab5040464e575839fa65b935c8f872a47e1658f535"}, - {file = "gdstk-0.9.61-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c9c8738b57cb6100cb0d4346272af489d05f9b9908e0018a5ecbcb5ee485fa97"}, - {file = "gdstk-0.9.61-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7f9dd95da53d3cdbc3dcaed446b7404d8d4dfbdbd68628eeddde6285bc5a5"}, - {file = "gdstk-0.9.61-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f73637dc2abe3754906f2911557b563281f868f5d153332edea681d963b2a22"}, - {file = "gdstk-0.9.61-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:56d493bb7fc3fb33de63d8c1107ff3d645b62596d0c2073f1a390d90bef73233"}, - {file = "gdstk-0.9.61-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:7905572cc10b2b85a960317eadb5cf95197b5a52b1ef9358336d5cd224e08314"}, - {file = "gdstk-0.9.61-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:a4bc70f308653d63c26d15637b27e2435f7bdaa50d072db410c1f573db6d985b"}, - {file = "gdstk-0.9.61-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:3794115a278d5a38db5c5a8f0cfeff3e1701263bcfb58b7e1934e199578e14f1"}, - {file = "gdstk-0.9.61-cp39-cp39-win_amd64.whl", hash = "sha256:24c83250e8d6c6ced0d8e3946c096b2944564dc3cca53a9e75a7350eda2538b7"}, - {file = "gdstk-0.9.61.tar.gz", hash = "sha256:2967935fdf455c56ca77ad5c703c87cb88644ab75e752dcac866a36499879c6f"}, + {file = 
"gdstk-0.9.62-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:49dae046c5fd71b9eddf2b3f17e3efd4fadb926d258a8a570f2b04c127da96d8"}, + {file = "gdstk-0.9.62-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1a9694ddc6419f622a2bf710ab16912bbfd9d31e3931354d6f586b544012380f"}, + {file = "gdstk-0.9.62-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4b1f7876471e4761b167f6c10305341a8ed269ae98a2378372bbe74150cb79e5"}, + {file = "gdstk-0.9.62-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ec371cf06176deeebb96e3855fffbce1d1fab0cf9e937a904c2cc51cad17665"}, + {file = "gdstk-0.9.62-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c04f4b1cbeaa5508bbfa784c5db65d9c05cb5d48bea9846a0ed3bbf1233f02da"}, + {file = "gdstk-0.9.62-cp310-cp310-win_amd64.whl", hash = "sha256:aa23f91036eb86bae45dbaca92ea0ea0d78a951e76b47a9a728122ad3440991a"}, + {file = "gdstk-0.9.62-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ebf47814ae21481a5edba1d4a66e33d5185e5be41a4aedba6abd2aee93ce9c5e"}, + {file = "gdstk-0.9.62-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9738331712914118d2821f3ec6b724cd17bf713cde82d96be482046c7886df21"}, + {file = "gdstk-0.9.62-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fad7bf2d811ac417a78781f5f22ec1ce1aadb5bd3212ee8076bd695547d0598d"}, + {file = "gdstk-0.9.62-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:39d4fb2e4e855ca0dfc06ce3328288313c91ef0a68d3cc5d5a61f9ad495d25e8"}, + {file = "gdstk-0.9.62-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:95a13152c16044493a413584e3a3e90c1d8defb13d49a2bf2f835778a7d57c54"}, + {file = "gdstk-0.9.62-cp311-cp311-win_amd64.whl", hash = "sha256:18e7df66921cc97a70de0f12fa11b14448dcb651e7f76663a204bf0eeb770759"}, + {file = "gdstk-0.9.62-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:61c4fb771e12ec3bee3be6da989cc3f8e46fd0922d4af6d2b4ec7a76c8ed94f1"}, + {file = "gdstk-0.9.62-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:635edcb30a2651c042c0b70c4b46f3bead010fb638a0ef7161a1021c43308602"}, + {file = "gdstk-0.9.62-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b3422d3e8a201f8a2a160cd8fbcc6b38a7e8cae856cf0ffe51b8554af53b60b1"}, + {file = "gdstk-0.9.62-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0dd7dbfa62ddbb6550a714523019ec2aaf973f376707405918be08bc243d916b"}, + {file = "gdstk-0.9.62-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:07368f1075cbeaaee2eb5c60d66ae06fab507010feb8b4a294e5631abe9422ad"}, + {file = "gdstk-0.9.62-cp312-cp312-win_amd64.whl", hash = "sha256:ef5ab2d9a3638abcac684223577a8705ea1a537789f22e2a67fdbc036cdec992"}, + {file = "gdstk-0.9.62-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:39395bb333dc13068ef9fae231008aea9faf381bc2d7616d4461eec9b71f23b4"}, + {file = "gdstk-0.9.62-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6010171b2023be75d566f2b5adbd2dd921ec11d97c99472e8bda27b092365da9"}, + {file = "gdstk-0.9.62-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5fc7516a1a07f440e56d37ffcb845ba8784ee9d975ebb07bf185a1e9ea6b5a0e"}, + {file = "gdstk-0.9.62-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c3c9197c31ad541f725b6f907a6ad98d56fbd718ff0e54ee57336adae85d11c3"}, + {file = "gdstk-0.9.62-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa3773c5f0ce1c55b2a079f4c0a0620069f4ae20cec59951e4cdaf94321a0993"}, + {file = "gdstk-0.9.62-cp313-cp313-win_amd64.whl", 
hash = "sha256:388fde959049c6ad83bf9e2e62d0895f294a61231b2f0531eb826dd753626e27"}, + {file = "gdstk-0.9.62-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:6ee06dfc46ad63ca095046896ef99d62d75e43fc7d0b162b36d8b6db3432f064"}, + {file = "gdstk-0.9.62-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:27ed2dc2054c0496096bfc9043ba9497160f324c15d8c376186890d6adcd659b"}, + {file = "gdstk-0.9.62-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:73d233f88aac027f389f0e7b29f7b73a6f76ae77bae7538f25c649984f876492"}, + {file = "gdstk-0.9.62-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2601a7c4e3bb3a15464b584bca6122f25f70b44f6500345353498df850eb8e74"}, + {file = "gdstk-0.9.62-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2d6741f8930c722345c319d747df4af8d92e4558fdeec32214d45dde822736a6"}, + {file = "gdstk-0.9.62-cp314-cp314-win_amd64.whl", hash = "sha256:dc198b829a1eef65590a5b88dfc3507e0f058343e75fd849070e3354e84a1f66"}, + {file = "gdstk-0.9.62-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:18b00f30d5cf63d9a4e608cd5d69e1096cf8d199c67af9c51f30530b0f95da97"}, + {file = "gdstk-0.9.62-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:8f8c703bdaa087b702dd4cf40e57ad337238c6dde5bc9528c83c5457e34612d6"}, + {file = "gdstk-0.9.62-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:22f60ebc23b5eb7be611b13692aca8adc1a2a42f7970016dcc3fba84214baedd"}, + {file = "gdstk-0.9.62-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22cf47ee90bb8c063574ce6d75eafd3cef919172bc4736f982abfd8cec15cb10"}, + {file = "gdstk-0.9.62-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f5a292e72016a263283d37a6dc54fd38a3b6d2b57c21478c781d621acc9a2125"}, + {file = "gdstk-0.9.62-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:10d40cd31f8db6adacd24b7b2e3bbd7a99427312feacd1850d45476c8a57031f"}, + {file = "gdstk-0.9.62-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a4ac0816aeb57cd881a1aa3e6a2a7c88421d11b82b23df45cc901d253270012b"}, + {file = "gdstk-0.9.62-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cacd06c432dcbaf9bb3c0a54047588e4c16aa9e8d69d7f72c791a86433ffb6e4"}, + {file = "gdstk-0.9.62-cp39-cp39-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:396be125aee6596642292afb3c8893cd67de2bb72457d5a4784319401b6631e1"}, + {file = "gdstk-0.9.62-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:07a114ec2d2c18d682506074a193d3e0a98921d1d3f29217e9cbe4c164b563a5"}, + {file = "gdstk-0.9.62-cp39-cp39-win_amd64.whl", hash = "sha256:d7b3edc6b820fdaef1c03dc416117e5e8a92b1fff024afeaf0c72a93119f8b1e"}, + {file = "gdstk-0.9.62.tar.gz", hash = "sha256:d69d048a82c98b309a09058ba7fab3fd52c7fc35ef2df46a8e3466cde5ca98e6"}, ] [package.dependencies] @@ -1736,15 +1716,15 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.45" +version = "3.1.46" description = "GitPython is a Python library used to interact with Git repositories" optional = true python-versions = ">=3.7" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77"}, - {file = "gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c"}, + {file = "gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058"}, + {file = "gitpython-3.1.46.tar.gz", 
hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f"}, ] [package.dependencies] @@ -1752,7 +1732,7 @@ gitdb = ">=4.0.1,<5" [package.extras] doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] -test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] +test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock ; python_version < \"3.8\"", "mypy (==1.18.2) ; python_version >= \"3.9\"", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions ; python_version < \"3.11\""] [[package]] name = "grcwa" @@ -1905,15 +1885,15 @@ tests = ["freezegun", "pytest", "pytest-cov"] [[package]] name = "identify" -version = "2.6.15" +version = "2.6.16" description = "File identification library for Python" optional = true -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"tests\" or extra == \"scikit-rf\"" files = [ - {file = "identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757"}, - {file = "identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf"}, + {file = "identify-2.6.16-py2.py3-none-any.whl", hash = "sha256:391ee4d77741d994189522896270b787aed8670389bfd60f326d677d64a6dfb0"}, + {file = "identify-2.6.16.tar.gz", hash = "sha256:846857203b5511bbe94d5a352a48ef2359532bc8f6727b5544077a0dcfb24980"}, ] [package.extras] @@ -2041,15 +2021,15 @@ test = ["flaky", "ipyparallel", "pre-commit", "pytest (>=7.0,<9)", "pytest-async [[package]] name = "ipython" -version = "8.37.0" +version = "8.38.0" description = "IPython: Productive Interactive Computing" optional = true python-versions = ">=3.10" groups = ["main"] -markers = "python_version < \"3.11\" and (extra == \"dev\" or extra == \"docs\")" +markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "ipython-8.37.0-py3-none-any.whl", hash = "sha256:ed87326596b878932dbcb171e3e698845434d8c61b8d8cd474bf663041a9dcf2"}, - {file = "ipython-8.37.0.tar.gz", hash = "sha256:ca815841e1a41a1e6b73a0b08f3038af9b2252564d01fc405356d34033012216"}, + {file = "ipython-8.38.0-py3-none-any.whl", hash = "sha256:750162629d800ac65bb3b543a14e7a74b0e88063eac9b92124d4b2aa3f6d8e86"}, + {file = "ipython-8.38.0.tar.gz", hash = "sha256:9cfea8c903ce0867cc2f23199ed8545eb741f3a69420bfcf3743ad1cec856d39"}, ] [package.dependencies] @@ -2079,56 +2059,6 @@ qtconsole = ["qtconsole"] test = ["packaging", "pickleshare", "pytest", "pytest-asyncio (<0.22)", "testpath"] test-extra = ["curio", "ipython[test]", "jupyter_ai", "matplotlib (!=3.2.0)", "nbformat", "numpy (>=1.23)", "pandas", "trio"] -[[package]] -name = "ipython" -version = "9.8.0" -description = "IPython: Productive Interactive Computing" -optional = true -python-versions = ">=3.11" -groups = ["main"] -markers = "python_version >= \"3.11\" and (extra == \"dev\" or extra == \"docs\")" -files = [ - {file = "ipython-9.8.0-py3-none-any.whl", hash = "sha256:ebe6d1d58d7d988fbf23ff8ff6d8e1622cfdb194daf4b7b73b792c4ec3b85385"}, - {file = "ipython-9.8.0.tar.gz", hash = "sha256:8e4ce129a627eb9dd221c41b1d2cdaed4ef7c9da8c17c63f6f578fe231141f83"}, -] - -[package.dependencies] -colorama = {version = ">=0.4.4", markers = "sys_platform == \"win32\""} -decorator = ">=4.3.2" 
-ipython-pygments-lexers = ">=1.0.0" -jedi = ">=0.18.1" -matplotlib-inline = ">=0.1.5" -pexpect = {version = ">4.3", markers = "sys_platform != \"win32\" and sys_platform != \"emscripten\""} -prompt_toolkit = ">=3.0.41,<3.1.0" -pygments = ">=2.11.0" -stack_data = ">=0.6.0" -traitlets = ">=5.13.0" -typing_extensions = {version = ">=4.6", markers = "python_version < \"3.12\""} - -[package.extras] -all = ["ipython[doc,matplotlib,test,test-extra]"] -black = ["black"] -doc = ["docrepr", "exceptiongroup", "intersphinx_registry", "ipykernel", "ipython[matplotlib,test]", "setuptools (>=70.0)", "sphinx (>=8.0)", "sphinx-rtd-theme (>=0.1.8)", "sphinx_toml (==0.0.4)", "typing_extensions"] -matplotlib = ["matplotlib (>3.9)"] -test = ["packaging (>=20.1.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=1.0.0)", "setuptools (>=61.2)", "testpath (>=0.2)"] -test-extra = ["curio", "ipykernel (>6.30)", "ipython[matplotlib]", "ipython[test]", "jupyter_ai", "nbclient", "nbformat", "numpy (>=1.27)", "pandas (>2.1)", "trio (>=0.1.0)"] - -[[package]] -name = "ipython-pygments-lexers" -version = "1.1.1" -description = "Defines a variety of Pygments lexers for highlighting IPython code." -optional = true -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version >= \"3.11\" and (extra == \"dev\" or extra == \"docs\")" -files = [ - {file = "ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c"}, - {file = "ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81"}, -] - -[package.dependencies] -pygments = "*" - [[package]] name = "ipywidgets" version = "8.1.8" @@ -2323,15 +2253,15 @@ scipy = ">=1.13" [[package]] name = "jaxtyping" -version = "0.3.4" +version = "0.3.5" description = "Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays." optional = true python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "jaxtyping-0.3.4-py3-none-any.whl", hash = "sha256:70e438db2f361575d04cccea50f77f9c8fe92f8b2086dc0ce89e5f1658bebaab"}, - {file = "jaxtyping-0.3.4.tar.gz", hash = "sha256:b4aac576a1b6c62a363f76f543f21c7cd4c7bb8714816c2c875f28b7abcdb770"}, + {file = "jaxtyping-0.3.5-py3-none-any.whl", hash = "sha256:862c39fa2e526274e82dc96ee8dbe9369dadb651ab1e05d95bd685acb4e2ef02"}, + {file = "jaxtyping-0.3.5.tar.gz", hash = "sha256:8150ad5b72b62fa63f573d492a79e9e455f070abe3b260f7dc15270b3eb9bba6"}, ] [package.dependencies] @@ -2408,20 +2338,17 @@ files = [ [[package]] name = "json5" -version = "0.12.1" +version = "0.13.0" description = "A Python implementation of the JSON5 data format." 
optional = true python-versions = ">=3.8.0" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "json5-0.12.1-py3-none-any.whl", hash = "sha256:d9c9b3bc34a5f54d43c35e11ef7cb87d8bdd098c6ace87117a7b7e83e705c1d5"}, - {file = "json5-0.12.1.tar.gz", hash = "sha256:b2743e77b3242f8d03c143dd975a6ec7c52e2f2afe76ed934e53503dd4ad4990"}, + {file = "json5-0.13.0-py3-none-any.whl", hash = "sha256:9a08e1dd65f6a4d4c6fa82d216cf2477349ec2346a38fd70cc11d2557499fbcc"}, + {file = "json5-0.13.0.tar.gz", hash = "sha256:b1edf8d487721c0bf64d83c28e91280781f6e21f4a797d3261c7c828d4c165bf"}, ] -[package.extras] -dev = ["build (==1.2.2.post1)", "coverage (==7.5.4) ; python_version < \"3.9\"", "coverage (==7.8.0) ; python_version >= \"3.9\"", "mypy (==1.14.1) ; python_version < \"3.9\"", "mypy (==1.15.0) ; python_version >= \"3.9\"", "pip (==25.0.1)", "pylint (==3.2.7) ; python_version < \"3.9\"", "pylint (==3.3.6) ; python_version >= \"3.9\"", "ruff (==0.11.2)", "twine (==6.1.0)", "uv (==0.6.11)"] - [[package]] name = "jsonpointer" version = "3.0.0" @@ -2437,15 +2364,15 @@ files = [ [[package]] name = "jsonschema" -version = "4.25.1" +version = "4.26.0" description = "An implementation of JSON Schema validation for Python" optional = true -python-versions = ">=3.9" +python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "jsonschema-4.25.1-py3-none-any.whl", hash = "sha256:3fba0169e345c7175110351d456342c364814cfcf3b964ba4587f22915230a63"}, - {file = "jsonschema-4.25.1.tar.gz", hash = "sha256:e4a9655ce0da0c0b67a085847e00a3a51449e1157f4f75e9fb5aa545e122eb85"}, + {file = "jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce"}, + {file = "jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326"}, ] [package.dependencies] @@ -2459,7 +2386,7 @@ referencing = ">=0.28.4" rfc3339-validator = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} rfc3986-validator = {version = ">0.1.0", optional = true, markers = "extra == \"format-nongpl\""} rfc3987-syntax = {version = ">=1.1.0", optional = true, markers = "extra == \"format-nongpl\""} -rpds-py = ">=0.7.1" +rpds-py = ">=0.25.0" uri-template = {version = "*", optional = true, markers = "extra == \"format-nongpl\""} webcolors = {version = ">=24.6.0", optional = true, markers = "extra == \"format-nongpl\""} @@ -2506,15 +2433,15 @@ notebook = "*" [[package]] name = "jupyter-client" -version = "8.7.0" +version = "8.8.0" description = "Jupyter protocol implementation and client libraries" optional = true python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "jupyter_client-8.7.0-py3-none-any.whl", hash = "sha256:3671a94fd25e62f5f2f554f5e95389c2294d89822378a5f2dd24353e1494a9e0"}, - {file = "jupyter_client-8.7.0.tar.gz", hash = "sha256:3357212d9cbe01209e59190f67a3a7e1f387a4f4e88d1e0433ad84d7b262531d"}, + {file = "jupyter_client-8.8.0-py3-none-any.whl", hash = "sha256:f93a5b99c5e23a507b773d3a1136bd6e16c67883ccdbd9a829b0bbdb98cd7d7a"}, + {file = "jupyter_client-8.8.0.tar.gz", hash = "sha256:d556811419a4f2d96c869af34e854e3f059b7cc2d6d01a9cd9c85c267691be3e"}, ] [package.dependencies] @@ -2526,7 +2453,8 @@ traitlets = ">=5.3" [package.extras] docs = ["ipykernel", "myst-parser", "pydata-sphinx-theme", "sphinx (>=4)", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling"] -test = 
["anyio", "coverage", "ipykernel (>=6.14)", "mypy", "paramiko ; sys_platform == \"win32\"", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.6.2)", "pytest-timeout"] +orjson = ["orjson"] +test = ["anyio", "coverage", "ipykernel (>=6.14)", "msgpack", "mypy ; platform_python_implementation != \"PyPy\"", "paramiko ; sys_platform == \"win32\"", "pre-commit", "pytest", "pytest-cov", "pytest-jupyter[client] (>=0.6.2)", "pytest-timeout"] [[package]] name = "jupyter-console" @@ -2678,15 +2606,15 @@ test = ["jupyter-server[test]", "pytest"] [[package]] name = "jupyter-server-terminals" -version = "0.5.3" +version = "0.5.4" description = "A Jupyter Server Extension Providing Terminals." optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "jupyter_server_terminals-0.5.3-py3-none-any.whl", hash = "sha256:41ee0d7dc0ebf2809c668e0fc726dfaf258fcd3e769568996ca731b6194ae9aa"}, - {file = "jupyter_server_terminals-0.5.3.tar.gz", hash = "sha256:5ae0295167220e9ace0edcfdb212afd2b01ee8d179fe6f23c899590e9b8a5269"}, + {file = "jupyter_server_terminals-0.5.4-py3-none-any.whl", hash = "sha256:55be353fc74a80bc7f3b20e6be50a55a61cd525626f578dcb66a5708e2007d14"}, + {file = "jupyter_server_terminals-0.5.4.tar.gz", hash = "sha256:bbda128ed41d0be9020349f9f1f2a4ab9952a73ed5f5ac9f1419794761fb87f5"}, ] [package.dependencies] @@ -2699,15 +2627,15 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (> [[package]] name = "jupyterlab" -version = "4.5.1" +version = "4.5.2" description = "JupyterLab computational environment" optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "jupyterlab-4.5.1-py3-none-any.whl", hash = "sha256:31b059de96de0754ff1f2ce6279774b6aab8c34d7082e9752db58207c99bd514"}, - {file = "jupyterlab-4.5.1.tar.gz", hash = "sha256:09da1ddfbd9eec18b5101dbb8515612aa1e47443321fb99503725a88e93d20d9"}, + {file = "jupyterlab-4.5.2-py3-none-any.whl", hash = "sha256:76466ebcfdb7a9bb7e2fbd6459c0e2c032ccf75be673634a84bee4b3e6b13ab6"}, + {file = "jupyterlab-4.5.2.tar.gz", hash = "sha256:c80a6b9f6dace96a566d590c65ee2785f61e7cd4aac5b4d453dcc7d0d5e069b7"}, ] [package.dependencies] @@ -2904,7 +2832,7 @@ description = "a KLU solver for JAX" optional = true python-versions = ">=3.10" groups = ["main"] -markers = "extra == \"dev\" or extra == \"docs\"" +markers = "python_version < \"3.11\" and (extra == \"dev\" or extra == \"docs\")" files = [ {file = "klujax-0.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:26127b2f76f76d76e26bd20ea516e3b42cab020b6eeafc7bdb17c1e9258b11e1"}, {file = "klujax-0.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e5cd54acdd452407343d52076aa54cf33538ef7a3c57a66643d5a9d90b1c75"}, @@ -2937,6 +2865,46 @@ numpy = ">=2" [package.extras] dev = ["build (>=1.2.2)", "ipykernel (>=6.29.5)", "pre-commit (>=4.1.0)", "pytest (>=8.3.4)", "ruff (>=0.9.7)", "setuptools (>=75.8.0)", "tbump (>=6.11.0)"] +[[package]] +name = "klujax" +version = "0.4.6" +description = "a KLU solver for JAX" +optional = true +python-versions = ">=3.11" +groups = ["main"] +markers = "python_version >= \"3.11\" and (extra == \"dev\" or extra == \"docs\")" +files = [ + {file = "klujax-0.4.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d39900db9e2f76876cfac4bf1961be84ac625c667985bacd882185150deffdf9"}, + {file = "klujax-0.4.6-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:09f16b4fe4ecfbc6af28346470e14bc2da56933868b0c3277f0720784c0dbd04"}, + {file = "klujax-0.4.6-cp311-cp311-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de99500b9a9b34259b9a91fbca305b0ed7716382134d5558633b704685676efb"}, + {file = "klujax-0.4.6-cp311-cp311-win_amd64.whl", hash = "sha256:28b89841698fe55258605e70df51ddb06922c6cdea544bb62e70cac2b9710433"}, + {file = "klujax-0.4.6-cp311-cp311-win_arm64.whl", hash = "sha256:64a72ea00d9b77dacb5570f62e4953a74c663908da0ca1393cec1c3d9b4d1840"}, + {file = "klujax-0.4.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:18312d5f2cb85bace3e58dd203d0d97892bf56175d0909263dce238d8cd34dd0"}, + {file = "klujax-0.4.6-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f2dd8c7cb581f55c1984ea43f350c82a33f81624336efdbe3e4d7ed6bf6cc1dc"}, + {file = "klujax-0.4.6-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ac4ca525d1fd53d252e059d9225dd3a61622e1782419cd47ef42b56efde0f5d"}, + {file = "klujax-0.4.6-cp312-cp312-win_amd64.whl", hash = "sha256:67bc8d9303d2c6b641a3418f7cc997a96464b4433ca398dd28f8b6c65c632ca1"}, + {file = "klujax-0.4.6-cp312-cp312-win_arm64.whl", hash = "sha256:5f0b4e5f08668305e25cc420f89e03772201ce8cdc4e7fbc38830d0f93b8f200"}, + {file = "klujax-0.4.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:32b526267e55357eec91ac77b808eeb733e60b267bc591f510a9812238d29087"}, + {file = "klujax-0.4.6-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dbbe5cfb4bf42f95be6162847b668457c217501f20b569483d3f32d8b2ee2afa"}, + {file = "klujax-0.4.6-cp313-cp313-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1aa011e7fa6829b86b7a7d39f6d493570e734dde8134af2aba7304a748724c9e"}, + {file = "klujax-0.4.6-cp313-cp313-win_amd64.whl", hash = "sha256:8e40b2c52b84c45c47abef294181fc06c22f06dab0874d0cbbd5bc7cb9b7d8ea"}, + {file = "klujax-0.4.6-cp313-cp313-win_arm64.whl", hash = "sha256:43898b6844b49cf2ad734826f14f026a8714b34d0a3bcaa3936d9a4eca98e7bc"}, + {file = "klujax-0.4.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3c448ae1e16d36fdb4913eecc0932f7e091c1829069fc58945e8c054a4f978a0"}, + {file = "klujax-0.4.6-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1fee72f5d678a715a3495edb1dba5671356020f0a65004685b052a117d22f679"}, + {file = "klujax-0.4.6-cp314-cp314-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dca25a378629c1439224eff353bf6d1ebba4577c98fe3a37dc561d0f1f4d02fb"}, + {file = "klujax-0.4.6-cp314-cp314-win_amd64.whl", hash = "sha256:fcecb318448fae0c413c80db1d40d62f5037b1ceb1f5bd14757f6f2d56629d60"}, + {file = "klujax-0.4.6-cp314-cp314-win_arm64.whl", hash = "sha256:e04d85cb00bd168a9e75a2595827faa234fd590ba0ba157185d8f9be1077ca6e"}, +] + +[package.dependencies] +jax = ">=0.5.0" +jaxlib = ">=0.5.0" +jaxtyping = ">=0.2.38" +numpy = ">=2" + +[package.extras] +dev = ["build (>=1.2.2)", "ipykernel (>=6.29.5)", "pre-commit (>=4.1.0)", "pytest (>=8.3.4)", "ruff (>=0.9.7)", "setuptools (>=75.8.0)", "tbump (>=6.11.0)"] + [[package]] name = "lark" version = "1.3.1" @@ -3255,15 +3223,15 @@ psutil = "*" [[package]] name = "mistune" -version = "3.1.4" +version = "3.2.0" description = "A sane and fast Markdown parser with useful plugins and renderers" optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "mistune-3.1.4-py3-none-any.whl", hash = "sha256:93691da911e5d9d2e23bc54472892aff676df27a75274962ff9edc210364266d"}, - {file = 
"mistune-3.1.4.tar.gz", hash = "sha256:b5a7f801d389f724ec702840c11d8fc48f2b33519102fc7ee739e8177b672164"}, + {file = "mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1"}, + {file = "mistune-3.2.0.tar.gz", hash = "sha256:708487c8a8cdd99c9d90eb3ed4c3ed961246ff78ac82f03418f5183ab70e398a"}, ] [package.dependencies] @@ -3320,10 +3288,10 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.21.2", markers = "python_version >= \"3.10\""}, {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, {version = ">=1.26.0", markers = "python_version == \"3.12\""}, {version = ">=1.23.3", markers = "python_version >= \"3.11\""}, + {version = ">=1.21.2", markers = "python_version >= \"3.10\""}, ] [package.extras] @@ -3535,15 +3503,15 @@ icu = ["PyICU (>=1.0.0)"] [[package]] name = "nbclient" -version = "0.10.3" +version = "0.10.4" description = "A client library for executing notebooks. Formerly nbconvert's ExecutePreprocessor." optional = true python-versions = ">=3.10.0" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "nbclient-0.10.3-py3-none-any.whl", hash = "sha256:39e9bd403504dd2484dd0fd25235bb6a683ce8cd9873356e40d880696adc9e35"}, - {file = "nbclient-0.10.3.tar.gz", hash = "sha256:0baf171ee246e3bb2391da0635e719f27dc77d99aef59e0b04dcb935ee04c575"}, + {file = "nbclient-0.10.4-py3-none-any.whl", hash = "sha256:9162df5a7373d70d606527300a95a975a47c137776cd942e52d9c7e29ff83440"}, + {file = "nbclient-0.10.4.tar.gz", hash = "sha256:1e54091b16e6da39e297b0ece3e10f6f29f4ac4e8ee515d29f8a7099bd6553c9"}, ] [package.dependencies] @@ -3597,15 +3565,15 @@ webpdf = ["playwright"] [[package]] name = "nbdime" -version = "4.0.2" +version = "4.0.3" description = "Diff and merge of Jupyter Notebooks" optional = true python-versions = ">=3.6" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "nbdime-4.0.2-py3-none-any.whl", hash = "sha256:e5a43aca669c576c66e757071c0e882de05ac305311d79aded99bfb5a3e9419e"}, - {file = "nbdime-4.0.2.tar.gz", hash = "sha256:d8279f8f4b236c0b253b20d60c4831bb67843ed8dbd6e09f234eb011d36f1bf2"}, + {file = "nbdime-4.0.3-py3-none-any.whl", hash = "sha256:28bec44cb3d67356e1156b9a948c4045aa123a5d5d83c26ca1801b380c4258e8"}, + {file = "nbdime-4.0.3.tar.gz", hash = "sha256:62ab50a758282523c4501144b9f314221dbbaed0403c12fd70f6a4fcc532ec24"}, ] [package.dependencies] @@ -3714,20 +3682,20 @@ files = [ [[package]] name = "notebook" -version = "7.5.1" +version = "7.5.2" description = "Jupyter Notebook - A web-based notebook environment for interactive computing" optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "notebook-7.5.1-py3-none-any.whl", hash = "sha256:f4e2451c19910c33b88709b84537e11f6368c1cdff1aa0c43db701aea535dd44"}, - {file = "notebook-7.5.1.tar.gz", hash = "sha256:b2fb4cef4d47d08c33aecce1c6c6e84be05436fbd791f88fce8df9fbca088b75"}, + {file = "notebook-7.5.2-py3-none-any.whl", hash = "sha256:17d078a98603d70d62b6b4b3fcb67e87d7a68c398a7ae9b447eb2d7d9aec9979"}, + {file = "notebook-7.5.2.tar.gz", hash = "sha256:83e82f93c199ca730313bea1bb24bc279ea96f74816d038a92d26b6b9d5f3e4a"}, ] [package.dependencies] jupyter-server = ">=2.4.0,<3" -jupyterlab = ">=4.5.1,<4.6" +jupyterlab = ">=4.5.2,<4.6" jupyterlab-server = ">=2.28.0,<3" notebook-shim = ">=0.2,<0.3" tornado = ">=6.2.0" @@ -3824,87 +3792,85 @@ files = [ [[package]] name = "numpy" -version = "2.3.5" +version = "2.4.1" 
description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.11" groups = ["main"] markers = "python_version >= \"3.11\"" files = [ - {file = "numpy-2.3.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:de5672f4a7b200c15a4127042170a694d4df43c992948f5e1af57f0174beed10"}, - {file = "numpy-2.3.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:acfd89508504a19ed06ef963ad544ec6664518c863436306153e13e94605c218"}, - {file = "numpy-2.3.5-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:ffe22d2b05504f786c867c8395de703937f934272eb67586817b46188b4ded6d"}, - {file = "numpy-2.3.5-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:872a5cf366aec6bb1147336480fef14c9164b154aeb6542327de4970282cd2f5"}, - {file = "numpy-2.3.5-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3095bdb8dd297e5920b010e96134ed91d852d81d490e787beca7e35ae1d89cf7"}, - {file = "numpy-2.3.5-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8cba086a43d54ca804ce711b2a940b16e452807acebe7852ff327f1ecd49b0d4"}, - {file = "numpy-2.3.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6cf9b429b21df6b99f4dee7a1218b8b7ffbbe7df8764dc0bd60ce8a0708fed1e"}, - {file = "numpy-2.3.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:396084a36abdb603546b119d96528c2f6263921c50df3c8fd7cb28873a237748"}, - {file = "numpy-2.3.5-cp311-cp311-win32.whl", hash = "sha256:b0c7088a73aef3d687c4deef8452a3ac7c1be4e29ed8bf3b366c8111128ac60c"}, - {file = "numpy-2.3.5-cp311-cp311-win_amd64.whl", hash = "sha256:a414504bef8945eae5f2d7cb7be2d4af77c5d1cb5e20b296c2c25b61dff2900c"}, - {file = "numpy-2.3.5-cp311-cp311-win_arm64.whl", hash = "sha256:0cd00b7b36e35398fa2d16af7b907b65304ef8bb4817a550e06e5012929830fa"}, - {file = "numpy-2.3.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:74ae7b798248fe62021dbf3c914245ad45d1a6b0cb4a29ecb4b31d0bfbc4cc3e"}, - {file = "numpy-2.3.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ee3888d9ff7c14604052b2ca5535a30216aa0a58e948cdd3eeb8d3415f638769"}, - {file = "numpy-2.3.5-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:612a95a17655e213502f60cfb9bf9408efdc9eb1d5f50535cc6eb365d11b42b5"}, - {file = "numpy-2.3.5-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3101e5177d114a593d79dd79658650fe28b5a0d8abeb8ce6f437c0e6df5be1a4"}, - {file = "numpy-2.3.5-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b973c57ff8e184109db042c842423ff4f60446239bd585a5131cc47f06f789d"}, - {file = "numpy-2.3.5-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d8163f43acde9a73c2a33605353a4f1bc4798745a8b1d73183b28e5b435ae28"}, - {file = "numpy-2.3.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:51c1e14eb1e154ebd80e860722f9e6ed6ec89714ad2db2d3aa33c31d7c12179b"}, - {file = "numpy-2.3.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b46b4ec24f7293f23adcd2d146960559aaf8020213de8ad1909dba6c013bf89c"}, - {file = "numpy-2.3.5-cp312-cp312-win32.whl", hash = "sha256:3997b5b3c9a771e157f9aae01dd579ee35ad7109be18db0e85dbdbe1de06e952"}, - {file = "numpy-2.3.5-cp312-cp312-win_amd64.whl", hash = "sha256:86945f2ee6d10cdfd67bcb4069c1662dd711f7e2a4343db5cecec06b87cf31aa"}, - {file = "numpy-2.3.5-cp312-cp312-win_arm64.whl", hash = "sha256:f28620fe26bee16243be2b7b874da327312240a7cdc38b769a697578d2100013"}, - {file = "numpy-2.3.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d0f23b44f57077c1ede8c5f26b30f706498b4862d3ff0a7298b8411dd2f043ff"}, - {file = 
"numpy-2.3.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:aa5bc7c5d59d831d9773d1170acac7893ce3a5e130540605770ade83280e7188"}, - {file = "numpy-2.3.5-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:ccc933afd4d20aad3c00bcef049cb40049f7f196e0397f1109dba6fed63267b0"}, - {file = "numpy-2.3.5-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:afaffc4393205524af9dfa400fa250143a6c3bc646c08c9f5e25a9f4b4d6a903"}, - {file = "numpy-2.3.5-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c75442b2209b8470d6d5d8b1c25714270686f14c749028d2199c54e29f20b4d"}, - {file = "numpy-2.3.5-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11e06aa0af8c0f05104d56450d6093ee639e15f24ecf62d417329d06e522e017"}, - {file = "numpy-2.3.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ed89927b86296067b4f81f108a2271d8926467a8868e554eaf370fc27fa3ccaf"}, - {file = "numpy-2.3.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:51c55fe3451421f3a6ef9a9c1439e82101c57a2c9eab9feb196a62b1a10b58ce"}, - {file = "numpy-2.3.5-cp313-cp313-win32.whl", hash = "sha256:1978155dd49972084bd6ef388d66ab70f0c323ddee6f693d539376498720fb7e"}, - {file = "numpy-2.3.5-cp313-cp313-win_amd64.whl", hash = "sha256:00dc4e846108a382c5869e77c6ed514394bdeb3403461d25a829711041217d5b"}, - {file = "numpy-2.3.5-cp313-cp313-win_arm64.whl", hash = "sha256:0472f11f6ec23a74a906a00b48a4dcf3849209696dff7c189714511268d103ae"}, - {file = "numpy-2.3.5-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:414802f3b97f3c1eef41e530aaba3b3c1620649871d8cb38c6eaff034c2e16bd"}, - {file = "numpy-2.3.5-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5ee6609ac3604fa7780e30a03e5e241a7956f8e2fcfe547d51e3afa5247ac47f"}, - {file = "numpy-2.3.5-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:86d835afea1eaa143012a2d7a3f45a3adce2d7adc8b4961f0b362214d800846a"}, - {file = "numpy-2.3.5-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:30bc11310e8153ca664b14c5f1b73e94bd0503681fcf136a163de856f3a50139"}, - {file = "numpy-2.3.5-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1062fde1dcf469571705945b0f221b73928f34a20c904ffb45db101907c3454e"}, - {file = "numpy-2.3.5-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ce581db493ea1a96c0556360ede6607496e8bf9b3a8efa66e06477267bc831e9"}, - {file = "numpy-2.3.5-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:cc8920d2ec5fa99875b670bb86ddeb21e295cb07aa331810d9e486e0b969d946"}, - {file = "numpy-2.3.5-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9ee2197ef8c4f0dfe405d835f3b6a14f5fee7782b5de51ba06fb65fc9b36e9f1"}, - {file = "numpy-2.3.5-cp313-cp313t-win32.whl", hash = "sha256:70b37199913c1bd300ff6e2693316c6f869c7ee16378faf10e4f5e3275b299c3"}, - {file = "numpy-2.3.5-cp313-cp313t-win_amd64.whl", hash = "sha256:b501b5fa195cc9e24fe102f21ec0a44dffc231d2af79950b451e0d99cea02234"}, - {file = "numpy-2.3.5-cp313-cp313t-win_arm64.whl", hash = "sha256:a80afd79f45f3c4a7d341f13acbe058d1ca8ac017c165d3fa0d3de6bc1a079d7"}, - {file = "numpy-2.3.5-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:bf06bc2af43fa8d32d30fae16ad965663e966b1a3202ed407b84c989c3221e82"}, - {file = "numpy-2.3.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:052e8c42e0c49d2575621c158934920524f6c5da05a1d3b9bab5d8e259e045f0"}, - {file = "numpy-2.3.5-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:1ed1ec893cff7040a02c8aa1c8611b94d395590d553f6b53629a4461dc7f7b63"}, - {file = "numpy-2.3.5-cp314-cp314-macosx_14_0_x86_64.whl", hash = 
"sha256:2dcd0808a421a482a080f89859a18beb0b3d1e905b81e617a188bd80422d62e9"}, - {file = "numpy-2.3.5-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:727fd05b57df37dc0bcf1a27767a3d9a78cbbc92822445f32cc3436ba797337b"}, - {file = "numpy-2.3.5-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fffe29a1ef00883599d1dc2c51aa2e5d80afe49523c261a74933df395c15c520"}, - {file = "numpy-2.3.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8f7f0e05112916223d3f438f293abf0727e1181b5983f413dfa2fefc4098245c"}, - {file = "numpy-2.3.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2e2eb32ddb9ccb817d620ac1d8dae7c3f641c1e5f55f531a33e8ab97960a75b8"}, - {file = "numpy-2.3.5-cp314-cp314-win32.whl", hash = "sha256:66f85ce62c70b843bab1fb14a05d5737741e74e28c7b8b5a064de10142fad248"}, - {file = "numpy-2.3.5-cp314-cp314-win_amd64.whl", hash = "sha256:e6a0bc88393d65807d751a614207b7129a310ca4fe76a74e5c7da5fa5671417e"}, - {file = "numpy-2.3.5-cp314-cp314-win_arm64.whl", hash = "sha256:aeffcab3d4b43712bb7a60b65f6044d444e75e563ff6180af8f98dd4b905dfd2"}, - {file = "numpy-2.3.5-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:17531366a2e3a9e30762c000f2c43a9aaa05728712e25c11ce1dbe700c53ad41"}, - {file = "numpy-2.3.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:d21644de1b609825ede2f48be98dfde4656aefc713654eeee280e37cadc4e0ad"}, - {file = "numpy-2.3.5-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:c804e3a5aba5460c73955c955bdbd5c08c354954e9270a2c1565f62e866bdc39"}, - {file = "numpy-2.3.5-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:cc0a57f895b96ec78969c34f682c602bf8da1a0270b09bc65673df2e7638ec20"}, - {file = "numpy-2.3.5-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:900218e456384ea676e24ea6a0417f030a3b07306d29d7ad843957b40a9d8d52"}, - {file = "numpy-2.3.5-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:09a1bea522b25109bf8e6f3027bd810f7c1085c64a0c7ce050c1676ad0ba010b"}, - {file = "numpy-2.3.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:04822c00b5fd0323c8166d66c701dc31b7fbd252c100acd708c48f763968d6a3"}, - {file = "numpy-2.3.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d6889ec4ec662a1a37eb4b4fb26b6100841804dac55bd9df579e326cdc146227"}, - {file = "numpy-2.3.5-cp314-cp314t-win32.whl", hash = "sha256:93eebbcf1aafdf7e2ddd44c2923e2672e1010bddc014138b229e49725b4d6be5"}, - {file = "numpy-2.3.5-cp314-cp314t-win_amd64.whl", hash = "sha256:c8a9958e88b65c3b27e22ca2a076311636850b612d6bbfb76e8d156aacde2aaf"}, - {file = "numpy-2.3.5-cp314-cp314t-win_arm64.whl", hash = "sha256:6203fdf9f3dc5bdaed7319ad8698e685c7a3be10819f41d32a0723e611733b42"}, - {file = "numpy-2.3.5-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:f0963b55cdd70fad460fa4c1341f12f976bb26cb66021a5580329bd498988310"}, - {file = "numpy-2.3.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:f4255143f5160d0de972d28c8f9665d882b5f61309d8362fdd3e103cf7bf010c"}, - {file = "numpy-2.3.5-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:a4b9159734b326535f4dd01d947f919c6eefd2d9827466a696c44ced82dfbc18"}, - {file = "numpy-2.3.5-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:2feae0d2c91d46e59fcd62784a3a83b3fb677fead592ce51b5a6fbb4f95965ff"}, - {file = "numpy-2.3.5-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ffac52f28a7849ad7576293c0cb7b9f08304e8f7d738a8cb8a90ec4c55a998eb"}, - {file = 
"numpy-2.3.5-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63c0e9e7eea69588479ebf4a8a270d5ac22763cc5854e9a7eae952a3908103f7"}, - {file = "numpy-2.3.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f16417ec91f12f814b10bafe79ef77e70113a2f5f7018640e7425ff979253425"}, - {file = "numpy-2.3.5.tar.gz", hash = "sha256:784db1dcdab56bf0517743e746dfb0f885fc68d948aba86eeec2cba234bdf1c0"}, + {file = "numpy-2.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0cce2a669e3c8ba02ee563c7835f92c153cf02edff1ae05e1823f1dde21b16a5"}, + {file = "numpy-2.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:899d2c18024984814ac7e83f8f49d8e8180e2fbe1b2e252f2e7f1d06bea92425"}, + {file = "numpy-2.4.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:09aa8a87e45b55a1c2c205d42e2808849ece5c484b2aab11fecabec3841cafba"}, + {file = "numpy-2.4.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:edee228f76ee2dab4579fad6f51f6a305de09d444280109e0f75df247ff21501"}, + {file = "numpy-2.4.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a92f227dbcdc9e4c3e193add1a189a9909947d4f8504c576f4a732fd0b54240a"}, + {file = "numpy-2.4.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:538bf4ec353709c765ff75ae616c34d3c3dca1a68312727e8f2676ea644f8509"}, + {file = "numpy-2.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ac08c63cb7779b85e9d5318e6c3518b424bc1f364ac4cb2c6136f12e5ff2dccc"}, + {file = "numpy-2.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4f9c360ecef085e5841c539a9a12b883dff005fbd7ce46722f5e9cef52634d82"}, + {file = "numpy-2.4.1-cp311-cp311-win32.whl", hash = "sha256:0f118ce6b972080ba0758c6087c3617b5ba243d806268623dc34216d69099ba0"}, + {file = "numpy-2.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:18e14c4d09d55eef39a6ab5b08406e84bc6869c1e34eef45564804f90b7e0574"}, + {file = "numpy-2.4.1-cp311-cp311-win_arm64.whl", hash = "sha256:6461de5113088b399d655d45c3897fa188766415d0f568f175ab071c8873bd73"}, + {file = "numpy-2.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d3703409aac693fa82c0aee023a1ae06a6e9d065dba10f5e8e80f642f1e9d0a2"}, + {file = "numpy-2.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7211b95ca365519d3596a1d8688a95874cc94219d417504d9ecb2df99fa7bfa8"}, + {file = "numpy-2.4.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:5adf01965456a664fc727ed69cc71848f28d063217c63e1a0e200a118d5eec9a"}, + {file = "numpy-2.4.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:26f0bcd9c79a00e339565b303badc74d3ea2bd6d52191eeca5f95936cad107d0"}, + {file = "numpy-2.4.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0093e85df2960d7e4049664b26afc58b03236e967fb942354deef3208857a04c"}, + {file = "numpy-2.4.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ad270f438cbdd402c364980317fb6b117d9ec5e226fff5b4148dd9aa9fc6e02"}, + {file = "numpy-2.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:297c72b1b98100c2e8f873d5d35fb551fce7040ade83d67dd51d38c8d42a2162"}, + {file = "numpy-2.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cf6470d91d34bf669f61d515499859fa7a4c2f7c36434afb70e82df7217933f9"}, + {file = "numpy-2.4.1-cp312-cp312-win32.whl", hash = "sha256:b6bcf39112e956594b3331316d90c90c90fb961e39696bda97b89462f5f3943f"}, + {file = "numpy-2.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:e1a27bb1b2dee45a2a53f5ca6ff2d1a7f135287883a1689e930d44d1ff296c87"}, + {file = "numpy-2.4.1-cp312-cp312-win_arm64.whl", hash = 
"sha256:0e6e8f9d9ecf95399982019c01223dc130542960a12edfa8edd1122dfa66a8a8"}, + {file = "numpy-2.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d797454e37570cfd61143b73b8debd623c3c0952959adb817dd310a483d58a1b"}, + {file = "numpy-2.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:82c55962006156aeef1629b953fd359064aa47e4d82cfc8e67f0918f7da3344f"}, + {file = "numpy-2.4.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:71abbea030f2cfc3092a0ff9f8c8fdefdc5e0bf7d9d9c99663538bb0ecdac0b9"}, + {file = "numpy-2.4.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:5b55aa56165b17aaf15520beb9cbd33c9039810e0d9643dd4379e44294c7303e"}, + {file = "numpy-2.4.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0faba4a331195bfa96f93dd9dfaa10b2c7aa8cda3a02b7fd635e588fe821bf5"}, + {file = "numpy-2.4.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e3087f53e2b4428766b54932644d148613c5a595150533ae7f00dab2f319a8"}, + {file = "numpy-2.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:49e792ec351315e16da54b543db06ca8a86985ab682602d90c60ef4ff4db2a9c"}, + {file = "numpy-2.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:79e9e06c4c2379db47f3f6fc7a8652e7498251789bf8ff5bd43bf478ef314ca2"}, + {file = "numpy-2.4.1-cp313-cp313-win32.whl", hash = "sha256:3d1a100e48cb266090a031397863ff8a30050ceefd798f686ff92c67a486753d"}, + {file = "numpy-2.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:92a0e65272fd60bfa0d9278e0484c2f52fe03b97aedc02b357f33fe752c52ffb"}, + {file = "numpy-2.4.1-cp313-cp313-win_arm64.whl", hash = "sha256:20d4649c773f66cc2fc36f663e091f57c3b7655f936a4c681b4250855d1da8f5"}, + {file = "numpy-2.4.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f93bc6892fe7b0663e5ffa83b61aab510aacffd58c16e012bb9352d489d90cb7"}, + {file = "numpy-2.4.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:178de8f87948163d98a4c9ab5bee4ce6519ca918926ec8df195af582de28544d"}, + {file = "numpy-2.4.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:98b35775e03ab7f868908b524fc0a84d38932d8daf7b7e1c3c3a1b6c7a2c9f15"}, + {file = "numpy-2.4.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:941c2a93313d030f219f3a71fd3d91a728b82979a5e8034eb2e60d394a2b83f9"}, + {file = "numpy-2.4.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:529050522e983e00a6c1c6b67411083630de8b57f65e853d7b03d9281b8694d2"}, + {file = "numpy-2.4.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2302dc0224c1cbc49bb94f7064f3f923a971bfae45c33870dcbff63a2a550505"}, + {file = "numpy-2.4.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9171a42fcad32dcf3fa86f0a4faa5e9f8facefdb276f54b8b390d90447cff4e2"}, + {file = "numpy-2.4.1-cp313-cp313t-win32.whl", hash = "sha256:382ad67d99ef49024f11d1ce5dcb5ad8432446e4246a4b014418ba3a1175a1f4"}, + {file = "numpy-2.4.1-cp313-cp313t-win_amd64.whl", hash = "sha256:62fea415f83ad8fdb6c20840578e5fbaf5ddd65e0ec6c3c47eda0f69da172510"}, + {file = "numpy-2.4.1-cp313-cp313t-win_arm64.whl", hash = "sha256:a7870e8c5fc11aef57d6fea4b4085e537a3a60ad2cdd14322ed531fdca68d261"}, + {file = "numpy-2.4.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:3869ea1ee1a1edc16c29bbe3a2f2a4e515cc3a44d43903ad41e0cacdbaf733dc"}, + {file = "numpy-2.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e867df947d427cdd7a60e3e271729090b0f0df80f5f10ab7dd436f40811699c3"}, + {file = "numpy-2.4.1-cp314-cp314-macosx_14_0_arm64.whl", hash = 
"sha256:e3bd2cb07841166420d2fa7146c96ce00cb3410664cbc1a6be028e456c4ee220"}, + {file = "numpy-2.4.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:f0a90aba7d521e6954670550e561a4cb925713bd944445dbe9e729b71f6cabee"}, + {file = "numpy-2.4.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d558123217a83b2d1ba316b986e9248a1ed1971ad495963d555ccd75dcb1556"}, + {file = "numpy-2.4.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2f44de05659b67d20499cbc96d49f2650769afcb398b79b324bb6e297bfe3844"}, + {file = "numpy-2.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:69e7419c9012c4aaf695109564e3387f1259f001b4326dfa55907b098af082d3"}, + {file = "numpy-2.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:2ffd257026eb1b34352e749d7cc1678b5eeec3e329ad8c9965a797e08ccba205"}, + {file = "numpy-2.4.1-cp314-cp314-win32.whl", hash = "sha256:727c6c3275ddefa0dc078524a85e064c057b4f4e71ca5ca29a19163c607be745"}, + {file = "numpy-2.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:7d5d7999df434a038d75a748275cd6c0094b0ecdb0837342b332a82defc4dc4d"}, + {file = "numpy-2.4.1-cp314-cp314-win_arm64.whl", hash = "sha256:ce9ce141a505053b3c7bce3216071f3bf5c182b8b28930f14cd24d43932cd2df"}, + {file = "numpy-2.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4e53170557d37ae404bf8d542ca5b7c629d6efa1117dac6a83e394142ea0a43f"}, + {file = "numpy-2.4.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:a73044b752f5d34d4232f25f18160a1cc418ea4507f5f11e299d8ac36875f8a0"}, + {file = "numpy-2.4.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:fb1461c99de4d040666ca0444057b06541e5642f800b71c56e6ea92d6a853a0c"}, + {file = "numpy-2.4.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:423797bdab2eeefbe608d7c1ec7b2b4fd3c58d51460f1ee26c7500a1d9c9ee93"}, + {file = "numpy-2.4.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52b5f61bdb323b566b528899cc7db2ba5d1015bda7ea811a8bcf3c89c331fa42"}, + {file = "numpy-2.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:42d7dd5fa36d16d52a84f821eb96031836fd405ee6955dd732f2023724d0aa01"}, + {file = "numpy-2.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e7b6b5e28bbd47b7532698e5db2fe1db693d84b58c254e4389d99a27bb9b8f6b"}, + {file = "numpy-2.4.1-cp314-cp314t-win32.whl", hash = "sha256:5de60946f14ebe15e713a6f22850c2372fa72f4ff9a432ab44aa90edcadaa65a"}, + {file = "numpy-2.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:8f085da926c0d491ffff3096f91078cc97ea67e7e6b65e490bc8dcda65663be2"}, + {file = "numpy-2.4.1-cp314-cp314t-win_arm64.whl", hash = "sha256:6436cffb4f2bf26c974344439439c95e152c9a527013f26b3577be6c2ca64295"}, + {file = "numpy-2.4.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8ad35f20be147a204e28b6a0575fbf3540c5e5f802634d4258d55b1ff5facce1"}, + {file = "numpy-2.4.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:8097529164c0f3e32bb89412a0905d9100bf434d9692d9fc275e18dcf53c9344"}, + {file = "numpy-2.4.1-pp311-pypy311_pp73-macosx_14_0_arm64.whl", hash = "sha256:ea66d2b41ca4a1630aae5507ee0a71647d3124d1741980138aa8f28f44dac36e"}, + {file = "numpy-2.4.1-pp311-pypy311_pp73-macosx_14_0_x86_64.whl", hash = "sha256:d3f8f0df9f4b8be57b3bf74a1d087fec68f927a2fab68231fdb442bf2c12e426"}, + {file = "numpy-2.4.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2023ef86243690c2791fd6353e5b4848eedaa88ca8a2d129f462049f6d484696"}, + {file = 
"numpy-2.4.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8361ea4220d763e54cff2fbe7d8c93526b744f7cd9ddab47afeff7e14e8503be"}, + {file = "numpy-2.4.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:4f1b68ff47680c2925f8063402a693ede215f0257f02596b1318ecdfb1d79e33"}, + {file = "numpy-2.4.1.tar.gz", hash = "sha256:a1ceafc5042451a858231588a104093474c6a5c57dcc724841f5c888d237d690"}, ] [[package]] @@ -4401,9 +4367,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -4485,17 +4451,23 @@ complete = ["blosc", "numpy (>=1.20.0)", "pandas (>=1.3)", "pyzmq"] [[package]] name = "pathspec" -version = "0.12.1" +version = "1.0.3" description = "Utility library for gitignore style pattern matching of file paths." optional = true -python-versions = ">=3.8" +python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, - {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, + {file = "pathspec-1.0.3-py3-none-any.whl", hash = "sha256:e80767021c1cc524aa3fb14bedda9c34406591343cc42797b386ce7b9354fb6c"}, + {file = "pathspec-1.0.3.tar.gz", hash = "sha256:bac5cf97ae2c2876e2d25ebb15078eb04d76e4b98921ee31c6f85ade8b59444d"}, ] +[package.extras] +hyperscan = ["hyperscan (>=0.7)"] +optional = ["typing-extensions (>=4)"] +re2 = ["google-re2 (>=1.1)"] +tests = ["pytest (>=9)", "typing-extensions (>=4.15)"] + [[package]] name = "pexpect" version = "4.9.0" @@ -4514,103 +4486,103 @@ ptyprocess = ">=0.5" [[package]] name = "pillow" -version = "12.0.0" +version = "12.1.0" description = "Python Imaging Library (fork)" optional = false python-versions = ">=3.10" groups = ["main"] files = [ - {file = "pillow-12.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b"}, - {file = "pillow-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1"}, - {file = "pillow-12.0.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d49e2314c373f4c2b39446fb1a45ed333c850e09d0c59ac79b72eb3b95397363"}, - {file = "pillow-12.0.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c7b2a63fd6d5246349f3d3f37b14430d73ee7e8173154461785e43036ffa96ca"}, - {file = "pillow-12.0.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d64317d2587c70324b79861babb9c09f71fbb780bad212018874b2c013d8600e"}, - {file = "pillow-12.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d77153e14b709fd8b8af6f66a3afbb9ed6e9fc5ccf0b6b7e1ced7b036a228782"}, - {file = "pillow-12.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:32ed80ea8a90ee3e6fa08c21e2e091bba6eda8eccc83dbc34c95169507a91f10"}, - {file = "pillow-12.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c828a1ae702fc712978bda0320ba1b9893d99be0badf2647f693cc01cf0f04fa"}, - {file = "pillow-12.0.0-cp310-cp310-win32.whl", hash = "sha256:bd87e140e45399c818fac4247880b9ce719e4783d767e030a883a970be632275"}, 
- {file = "pillow-12.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:455247ac8a4cfb7b9bc45b7e432d10421aea9fc2e74d285ba4072688a74c2e9d"}, - {file = "pillow-12.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:6ace95230bfb7cd79ef66caa064bbe2f2a1e63d93471c3a2e1f1348d9f22d6b7"}, - {file = "pillow-12.0.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:0fd00cac9c03256c8b2ff58f162ebcd2587ad3e1f2e397eab718c47e24d231cc"}, - {file = "pillow-12.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3475b96f5908b3b16c47533daaa87380c491357d197564e0ba34ae75c0f3257"}, - {file = "pillow-12.0.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:110486b79f2d112cf6add83b28b627e369219388f64ef2f960fef9ebaf54c642"}, - {file = "pillow-12.0.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5269cc1caeedb67e6f7269a42014f381f45e2e7cd42d834ede3c703a1d915fe3"}, - {file = "pillow-12.0.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:aa5129de4e174daccbc59d0a3b6d20eaf24417d59851c07ebb37aeb02947987c"}, - {file = "pillow-12.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bee2a6db3a7242ea309aa7ee8e2780726fed67ff4e5b40169f2c940e7eb09227"}, - {file = "pillow-12.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:90387104ee8400a7b4598253b4c406f8958f59fcf983a6cea2b50d59f7d63d0b"}, - {file = "pillow-12.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc91a56697869546d1b8f0a3ff35224557ae7f881050e99f615e0119bf934b4e"}, - {file = "pillow-12.0.0-cp311-cp311-win32.whl", hash = "sha256:27f95b12453d165099c84f8a8bfdfd46b9e4bda9e0e4b65f0635430027f55739"}, - {file = "pillow-12.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:b583dc9070312190192631373c6c8ed277254aa6e6084b74bdd0a6d3b221608e"}, - {file = "pillow-12.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:759de84a33be3b178a64c8ba28ad5c135900359e85fb662bc6e403ad4407791d"}, - {file = "pillow-12.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:53561a4ddc36facb432fae7a9d8afbfaf94795414f5cdc5fc52f28c1dca90371"}, - {file = "pillow-12.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:71db6b4c1653045dacc1585c1b0d184004f0d7e694c7b34ac165ca70c0838082"}, - {file = "pillow-12.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2fa5f0b6716fc88f11380b88b31fe591a06c6315e955c096c35715788b339e3f"}, - {file = "pillow-12.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:82240051c6ca513c616f7f9da06e871f61bfd7805f566275841af15015b8f98d"}, - {file = "pillow-12.0.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:55f818bd74fe2f11d4d7cbc65880a843c4075e0ac7226bc1a23261dbea531953"}, - {file = "pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b87843e225e74576437fd5b6a4c2205d422754f84a06942cfaf1dc32243e45a8"}, - {file = "pillow-12.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c607c90ba67533e1b2355b821fef6764d1dd2cbe26b8c1005ae84f7aea25ff79"}, - {file = "pillow-12.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:21f241bdd5080a15bc86d3466a9f6074a9c2c2b314100dd896ac81ee6db2f1ba"}, - {file = "pillow-12.0.0-cp312-cp312-win32.whl", hash = "sha256:dd333073e0cacdc3089525c7df7d39b211bcdf31fc2824e49d01c6b6187b07d0"}, - {file = "pillow-12.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9fe611163f6303d1619bbcb653540a4d60f9e55e622d60a3108be0d5b441017a"}, - {file = "pillow-12.0.0-cp312-cp312-win_arm64.whl", hash = 
"sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad"}, - {file = "pillow-12.0.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:0869154a2d0546545cde61d1789a6524319fc1897d9ee31218eae7a60ccc5643"}, - {file = "pillow-12.0.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:a7921c5a6d31b3d756ec980f2f47c0cfdbce0fc48c22a39347a895f41f4a6ea4"}, - {file = "pillow-12.0.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:1ee80a59f6ce048ae13cda1abf7fbd2a34ab9ee7d401c46be3ca685d1999a399"}, - {file = "pillow-12.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c50f36a62a22d350c96e49ad02d0da41dbd17ddc2e29750dbdba4323f85eb4a5"}, - {file = "pillow-12.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5193fde9a5f23c331ea26d0cf171fbf67e3f247585f50c08b3e205c7aeb4589b"}, - {file = "pillow-12.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bde737cff1a975b70652b62d626f7785e0480918dece11e8fef3c0cf057351c3"}, - {file = "pillow-12.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a6597ff2b61d121172f5844b53f21467f7082f5fb385a9a29c01414463f93b07"}, - {file = "pillow-12.0.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b817e7035ea7f6b942c13aa03bb554fc44fea70838ea21f8eb31c638326584e"}, - {file = "pillow-12.0.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f4f1231b7dec408e8670264ce63e9c71409d9583dd21d32c163e25213ee2a344"}, - {file = "pillow-12.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6e51b71417049ad6ab14c49608b4a24d8fb3fe605e5dfabfe523b58064dc3d27"}, - {file = "pillow-12.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d120c38a42c234dc9a8c5de7ceaaf899cf33561956acb4941653f8bdc657aa79"}, - {file = "pillow-12.0.0-cp313-cp313-win32.whl", hash = "sha256:4cc6b3b2efff105c6a1656cfe59da4fdde2cda9af1c5e0b58529b24525d0a098"}, - {file = "pillow-12.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:4cf7fed4b4580601c4345ceb5d4cbf5a980d030fd5ad07c4d2ec589f95f09905"}, - {file = "pillow-12.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:9f0b04c6b8584c2c193babcccc908b38ed29524b29dd464bc8801bf10d746a3a"}, - {file = "pillow-12.0.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:7fa22993bac7b77b78cae22bad1e2a987ddf0d9015c63358032f84a53f23cdc3"}, - {file = "pillow-12.0.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:f135c702ac42262573fe9714dfe99c944b4ba307af5eb507abef1667e2cbbced"}, - {file = "pillow-12.0.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c85de1136429c524e55cfa4e033b4a7940ac5c8ee4d9401cc2d1bf48154bbc7b"}, - {file = "pillow-12.0.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:38df9b4bfd3db902c9c2bd369bcacaf9d935b2fff73709429d95cc41554f7b3d"}, - {file = "pillow-12.0.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7d87ef5795da03d742bf49439f9ca4d027cde49c82c5371ba52464aee266699a"}, - {file = "pillow-12.0.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aff9e4d82d082ff9513bdd6acd4f5bd359f5b2c870907d2b0a9c5e10d40c88fe"}, - {file = "pillow-12.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:8d8ca2b210ada074d57fcee40c30446c9562e542fc46aedc19baf758a93532ee"}, - {file = "pillow-12.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:99a7f72fb6249302aa62245680754862a44179b545ded638cf1fef59befb57ef"}, - {file = "pillow-12.0.0-cp313-cp313t-win32.whl", hash = 
"sha256:4078242472387600b2ce8d93ade8899c12bf33fa89e55ec89fe126e9d6d5d9e9"}, - {file = "pillow-12.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:2c54c1a783d6d60595d3514f0efe9b37c8808746a66920315bfd34a938d7994b"}, - {file = "pillow-12.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:26d9f7d2b604cd23aba3e9faf795787456ac25634d82cd060556998e39c6fa47"}, - {file = "pillow-12.0.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:beeae3f27f62308f1ddbcfb0690bf44b10732f2ef43758f169d5e9303165d3f9"}, - {file = "pillow-12.0.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:d4827615da15cd59784ce39d3388275ec093ae3ee8d7f0c089b76fa87af756c2"}, - {file = "pillow-12.0.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:3e42edad50b6909089750e65c91aa09aaf1e0a71310d383f11321b27c224ed8a"}, - {file = "pillow-12.0.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e5d8efac84c9afcb40914ab49ba063d94f5dbdf5066db4482c66a992f47a3a3b"}, - {file = "pillow-12.0.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:266cd5f2b63ff316d5a1bba46268e603c9caf5606d44f38c2873c380950576ad"}, - {file = "pillow-12.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:58eea5ebe51504057dd95c5b77d21700b77615ab0243d8152793dc00eb4faf01"}, - {file = "pillow-12.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f13711b1a5ba512d647a0e4ba79280d3a9a045aaf7e0cc6fbe96b91d4cdf6b0c"}, - {file = "pillow-12.0.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6846bd2d116ff42cba6b646edf5bf61d37e5cbd256425fa089fee4ff5c07a99e"}, - {file = "pillow-12.0.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c98fa880d695de164b4135a52fd2e9cd7b7c90a9d8ac5e9e443a24a95ef9248e"}, - {file = "pillow-12.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fa3ed2a29a9e9d2d488b4da81dcb54720ac3104a20bf0bd273f1e4648aff5af9"}, - {file = "pillow-12.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d034140032870024e6b9892c692fe2968493790dd57208b2c37e3fb35f6df3ab"}, - {file = "pillow-12.0.0-cp314-cp314-win32.whl", hash = "sha256:1b1b133e6e16105f524a8dec491e0586d072948ce15c9b914e41cdadd209052b"}, - {file = "pillow-12.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:8dc232e39d409036af549c86f24aed8273a40ffa459981146829a324e0848b4b"}, - {file = "pillow-12.0.0-cp314-cp314-win_arm64.whl", hash = "sha256:d52610d51e265a51518692045e372a4c363056130d922a7351429ac9f27e70b0"}, - {file = "pillow-12.0.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:1979f4566bb96c1e50a62d9831e2ea2d1211761e5662afc545fa766f996632f6"}, - {file = "pillow-12.0.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b2e4b27a6e15b04832fe9bf292b94b5ca156016bbc1ea9c2c20098a0320d6cf6"}, - {file = "pillow-12.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fb3096c30df99fd01c7bf8e544f392103d0795b9f98ba71a8054bcbf56b255f1"}, - {file = "pillow-12.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7438839e9e053ef79f7112c881cef684013855016f928b168b81ed5835f3e75e"}, - {file = "pillow-12.0.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5d5c411a8eaa2299322b647cd932586b1427367fd3184ffbb8f7a219ea2041ca"}, - {file = "pillow-12.0.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d7e091d464ac59d2c7ad8e7e08105eaf9dafbc3883fd7265ffccc2baad6ac925"}, - {file = "pillow-12.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = 
"sha256:792a2c0be4dcc18af9d4a2dfd8a11a17d5e25274a1062b0ec1c2d79c76f3e7f8"}, - {file = "pillow-12.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:afbefa430092f71a9593a99ab6a4e7538bc9eabbf7bf94f91510d3503943edc4"}, - {file = "pillow-12.0.0-cp314-cp314t-win32.whl", hash = "sha256:3830c769decf88f1289680a59d4f4c46c72573446352e2befec9a8512104fa52"}, - {file = "pillow-12.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:905b0365b210c73afb0ebe9101a32572152dfd1c144c7e28968a331b9217b94a"}, - {file = "pillow-12.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:99353a06902c2e43b43e8ff74ee65a7d90307d82370604746738a1e0661ccca7"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b22bd8c974942477156be55a768f7aa37c46904c175be4e158b6a86e3a6b7ca8"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:805ebf596939e48dbb2e4922a1d3852cfc25c38160751ce02da93058b48d252a"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cae81479f77420d217def5f54b5b9d279804d17e982e0f2fa19b1d1e14ab5197"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aeaefa96c768fc66818730b952a862235d68825c178f1b3ffd4efd7ad2edcb7c"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09f2d0abef9e4e2f349305a4f8cc784a8a6c2f58a8c4892eea13b10a943bd26e"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bdee52571a343d721fb2eb3b090a82d959ff37fc631e3f70422e0c2e029f3e76"}, - {file = "pillow-12.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:b290fd8aa38422444d4b50d579de197557f182ef1068b75f5aa8558638b8d0a5"}, - {file = "pillow-12.0.0.tar.gz", hash = "sha256:87d4f8125c9988bfbed67af47dd7a953e2fc7b0cc1e7800ec6d2080d490bb353"}, + {file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"}, + {file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"}, + {file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"}, + {file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"}, + {file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"}, + {file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"}, + {file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"}, + {file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = 
"sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"}, + {file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"}, + {file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"}, + {file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"}, + {file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"}, + {file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"}, + {file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"}, + {file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"}, + {file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"}, + {file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"}, + {file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"}, + {file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"}, + {file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"}, + {file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"}, + {file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"}, + {file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"}, + {file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"}, + {file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = 
"sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"}, + {file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"}, + {file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"}, + {file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"}, + {file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"}, + {file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"}, + {file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"}, + {file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"}, + {file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"}, + {file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"}, + {file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"}, + {file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"}, + {file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"}, + {file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"}, + {file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"}, + {file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"}, + {file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"}, + {file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = 
"sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"}, + {file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"}, + {file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"}, + {file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"}, + {file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"}, + {file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"}, + {file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"}, + {file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"}, + {file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"}, + {file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"}, + {file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"}, + {file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"}, + {file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"}, + {file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"}, + {file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"}, + {file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"}, + {file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"}, + {file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = 
"sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"}, + {file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"}, + {file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"}, + {file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"}, + {file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"}, + {file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"}, ] [package.extras] @@ -4678,18 +4650,20 @@ virtualenv = ">=20.10.0" [[package]] name = "prometheus-client" -version = "0.23.1" +version = "0.24.1" description = "Python client for the Prometheus monitoring system." 
optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "prometheus_client-0.23.1-py3-none-any.whl", hash = "sha256:dd1913e6e76b59cfe44e7a4b83e01afc9873c1bdfd2ed8739f1e76aeca115f99"}, - {file = "prometheus_client-0.23.1.tar.gz", hash = "sha256:6ae8f9081eaaaf153a2e959d2e6c4f4fb57b12ef76c8c7980202f1e57b48b2ce"}, + {file = "prometheus_client-0.24.1-py3-none-any.whl", hash = "sha256:150db128af71a5c2482b36e588fc8a6b95e498750da4b17065947c16070f4055"}, + {file = "prometheus_client-0.24.1.tar.gz", hash = "sha256:7e0ced7fbbd40f7b84962d5d2ab6f17ef88a72504dcf7c0b40737b43b2a461f9"}, ] [package.extras] +aiohttp = ["aiohttp"] +django = ["django"] twisted = ["twisted"] [[package]] @@ -4710,56 +4684,58 @@ wcwidth = "*" [[package]] name = "protobuf" -version = "6.33.2" +version = "6.33.4" description = "" optional = true python-versions = ">=3.9" groups = ["main"] files = [ - {file = "protobuf-6.33.2-cp310-abi3-win32.whl", hash = "sha256:87eb388bd2d0f78febd8f4c8779c79247b26a5befad525008e49a6955787ff3d"}, - {file = "protobuf-6.33.2-cp310-abi3-win_amd64.whl", hash = "sha256:fc2a0e8b05b180e5fc0dd1559fe8ebdae21a27e81ac77728fb6c42b12c7419b4"}, - {file = "protobuf-6.33.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d9b19771ca75935b3a4422957bc518b0cecb978b31d1dd12037b088f6bcc0e43"}, - {file = "protobuf-6.33.2-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:b5d3b5625192214066d99b2b605f5783483575656784de223f00a8d00754fc0e"}, - {file = "protobuf-6.33.2-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:8cd7640aee0b7828b6d03ae518b5b4806fdfc1afe8de82f79c3454f8aef29872"}, - {file = "protobuf-6.33.2-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:1f8017c48c07ec5859106533b682260ba3d7c5567b1ca1f24297ce03384d1b4f"}, - {file = "protobuf-6.33.2-cp39-cp39-win32.whl", hash = "sha256:7109dcc38a680d033ffb8bf896727423528db9163be1b6a02d6a49606dcadbfe"}, - {file = "protobuf-6.33.2-cp39-cp39-win_amd64.whl", hash = "sha256:2981c58f582f44b6b13173e12bb8656711189c2a70250845f264b877f00b1913"}, - {file = "protobuf-6.33.2-py3-none-any.whl", hash = "sha256:7636aad9bb01768870266de5dc009de2d1b936771b38a793f73cbbf279c91c5c"}, - {file = "protobuf-6.33.2.tar.gz", hash = "sha256:56dc370c91fbb8ac85bc13582c9e373569668a290aa2e66a590c2a0d35ddb9e4"}, + {file = "protobuf-6.33.4-cp310-abi3-win32.whl", hash = "sha256:918966612c8232fc6c24c78e1cd89784307f5814ad7506c308ee3cf86662850d"}, + {file = "protobuf-6.33.4-cp310-abi3-win_amd64.whl", hash = "sha256:8f11ffae31ec67fc2554c2ef891dcb561dae9a2a3ed941f9e134c2db06657dbc"}, + {file = "protobuf-6.33.4-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2fe67f6c014c84f655ee06f6f66213f9254b3a8b6bda6cda0ccd4232c73c06f0"}, + {file = "protobuf-6.33.4-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:757c978f82e74d75cba88eddec479df9b99a42b31193313b75e492c06a51764e"}, + {file = "protobuf-6.33.4-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c7c64f259c618f0bef7bee042075e390debbf9682334be2b67408ec7c1c09ee6"}, + {file = "protobuf-6.33.4-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:3df850c2f8db9934de4cf8f9152f8dc2558f49f298f37f90c517e8e5c84c30e9"}, + {file = "protobuf-6.33.4-cp39-cp39-win32.whl", hash = "sha256:955478a89559fa4568f5a81dce77260eabc5c686f9e8366219ebd30debf06aa6"}, + {file = "protobuf-6.33.4-cp39-cp39-win_amd64.whl", hash = "sha256:0f12ddbf96912690c3582f9dffb55530ef32015ad8e678cd494312bd78314c4f"}, + {file = "protobuf-6.33.4-py3-none-any.whl", hash = 
"sha256:1fe3730068fcf2e595816a6c34fe66eeedd37d51d0400b72fabc848811fdc1bc"}, + {file = "protobuf-6.33.4.tar.gz", hash = "sha256:dc2e61bca3b10470c1912d166fe0af67bfc20eb55971dcef8dfa48ce14f0ed91"}, ] [[package]] name = "psutil" -version = "7.1.3" +version = "7.2.1" description = "Cross-platform lib for process and system monitoring." optional = true python-versions = ">=3.6" groups = ["main"] files = [ - {file = "psutil-7.1.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0005da714eee687b4b8decd3d6cc7c6db36215c9e74e5ad2264b90c3df7d92dc"}, - {file = "psutil-7.1.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:19644c85dcb987e35eeeaefdc3915d059dac7bd1167cdcdbf27e0ce2df0c08c0"}, - {file = "psutil-7.1.3-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:95ef04cf2e5ba0ab9eaafc4a11eaae91b44f4ef5541acd2ee91d9108d00d59a7"}, - {file = "psutil-7.1.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1068c303be3a72f8e18e412c5b2a8f6d31750fb152f9cb106b54090296c9d251"}, - {file = "psutil-7.1.3-cp313-cp313t-win_amd64.whl", hash = "sha256:18349c5c24b06ac5612c0428ec2a0331c26443d259e2a0144a9b24b4395b58fa"}, - {file = "psutil-7.1.3-cp313-cp313t-win_arm64.whl", hash = "sha256:c525ffa774fe4496282fb0b1187725793de3e7c6b29e41562733cae9ada151ee"}, - {file = "psutil-7.1.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b403da1df4d6d43973dc004d19cee3b848e998ae3154cc8097d139b77156c353"}, - {file = "psutil-7.1.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad81425efc5e75da3f39b3e636293360ad8d0b49bed7df824c79764fb4ba9b8b"}, - {file = "psutil-7.1.3-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f33a3702e167783a9213db10ad29650ebf383946e91bc77f28a5eb083496bc9"}, - {file = "psutil-7.1.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fac9cd332c67f4422504297889da5ab7e05fd11e3c4392140f7370f4208ded1f"}, - {file = "psutil-7.1.3-cp314-cp314t-win_amd64.whl", hash = "sha256:3792983e23b69843aea49c8f5b8f115572c5ab64c153bada5270086a2123c7e7"}, - {file = "psutil-7.1.3-cp314-cp314t-win_arm64.whl", hash = "sha256:31d77fcedb7529f27bb3a0472bea9334349f9a04160e8e6e5020f22c59893264"}, - {file = "psutil-7.1.3-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab"}, - {file = "psutil-7.1.3-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880"}, - {file = "psutil-7.1.3-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3"}, - {file = "psutil-7.1.3-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b"}, - {file = "psutil-7.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd"}, - {file = "psutil-7.1.3-cp37-abi3-win_arm64.whl", hash = "sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1"}, - {file = "psutil-7.1.3.tar.gz", hash = "sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74"}, + {file = "psutil-7.2.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:ba9f33bb525b14c3ea563b2fd521a84d2fa214ec59e3e6a2858f78d0844dd60d"}, + {file = "psutil-7.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = 
"sha256:81442dac7abfc2f4f4385ea9e12ddf5a796721c0f6133260687fec5c3780fa49"}, + {file = "psutil-7.2.1-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ea46c0d060491051d39f0d2cff4f98d5c72b288289f57a21556cc7d504db37fc"}, + {file = "psutil-7.2.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:35630d5af80d5d0d49cfc4d64c1c13838baf6717a13effb35869a5919b854cdf"}, + {file = "psutil-7.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:923f8653416604e356073e6e0bccbe7c09990acef442def2f5640dd0faa9689f"}, + {file = "psutil-7.2.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cfbe6b40ca48019a51827f20d830887b3107a74a79b01ceb8cc8de4ccb17b672"}, + {file = "psutil-7.2.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:494c513ccc53225ae23eec7fe6e1482f1b8a44674241b54561f755a898650679"}, + {file = "psutil-7.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:3fce5f92c22b00cdefd1645aa58ab4877a01679e901555067b1bd77039aa589f"}, + {file = "psutil-7.2.1-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:93f3f7b0bb07711b49626e7940d6fe52aa9940ad86e8f7e74842e73189712129"}, + {file = "psutil-7.2.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d34d2ca888208eea2b5c68186841336a7f5e0b990edec929be909353a202768a"}, + {file = "psutil-7.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:2ceae842a78d1603753561132d5ad1b2f8a7979cb0c283f5b52fb4e6e14b1a79"}, + {file = "psutil-7.2.1-cp314-cp314t-win_arm64.whl", hash = "sha256:08a2f175e48a898c8eb8eace45ce01777f4785bc744c90aa2cc7f2fa5462a266"}, + {file = "psutil-7.2.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:b2e953fcfaedcfbc952b44744f22d16575d3aa78eb4f51ae74165b4e96e55f42"}, + {file = "psutil-7.2.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:05cc68dbb8c174828624062e73078e7e35406f4ca2d0866c272c2410d8ef06d1"}, + {file = "psutil-7.2.1-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e38404ca2bb30ed7267a46c02f06ff842e92da3bb8c5bfdadbd35a5722314d8"}, + {file = "psutil-7.2.1-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab2b98c9fc19f13f59628d94df5cc4cc4844bc572467d113a8b517d634e362c6"}, + {file = "psutil-7.2.1-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:f78baafb38436d5a128f837fab2d92c276dfb48af01a240b861ae02b2413ada8"}, + {file = "psutil-7.2.1-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:99a4cd17a5fdd1f3d014396502daa70b5ec21bf4ffe38393e152f8e449757d67"}, + {file = "psutil-7.2.1-cp37-abi3-win_amd64.whl", hash = "sha256:b1b0671619343aa71c20ff9767eced0483e4fc9e1f489d50923738caf6a03c17"}, + {file = "psutil-7.2.1-cp37-abi3-win_arm64.whl", hash = "sha256:0d67c1822c355aa6f7314d92018fb4268a76668a536f133599b91edd48759442"}, + {file = "psutil-7.2.1.tar.gz", hash = "sha256:f7583aec590485b43ca601dd9cea0dcd65bd7bb21d30ef4ddbf4ea6b5ed1bdd3"}, ] [package.extras] -dev = ["abi3audit", "black", "check-manifest", "colorama ; os_name == \"nt\"", "coverage", "packaging", "pylint", "pyperf", "pypinfo", "pyreadline ; os_name == \"nt\"", "pytest", "pytest-cov", "pytest-instafail", "pytest-subtests", "pytest-xdist", "pywin32 ; os_name == \"nt\" and platform_python_implementation != \"PyPy\"", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "validate-pyproject[all]", "virtualenv", "vulture", "wheel", "wheel ; os_name == \"nt\" and 
platform_python_implementation != \"PyPy\"", "wmi ; os_name == \"nt\" and platform_python_implementation != \"PyPy\""] -test = ["pytest", "pytest-instafail", "pytest-subtests", "pytest-xdist", "pywin32 ; os_name == \"nt\" and platform_python_implementation != \"PyPy\"", "setuptools", "wheel ; os_name == \"nt\" and platform_python_implementation != \"PyPy\"", "wmi ; os_name == \"nt\" and platform_python_implementation != \"PyPy\""] +dev = ["abi3audit", "black", "check-manifest", "coverage", "packaging", "psleak", "pylint", "pyperf", "pypinfo", "pytest", "pytest-cov", "pytest-instafail", "pytest-xdist", "requests", "rstcheck", "ruff", "setuptools", "sphinx", "sphinx_rtd_theme", "toml-sort", "twine", "validate-pyproject[all]", "virtualenv", "vulture", "wheel"] +test = ["psleak", "pytest", "pytest-instafail", "pytest-xdist", "setuptools"] [[package]] name = "ptyprocess" @@ -4959,6 +4935,30 @@ files = [ [package.dependencies] typing-extensions = ">=4.14.1" +[[package]] +name = "pydantic-settings" +version = "2.12.0" +description = "Settings management using Pydantic" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809"}, + {file = "pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0"}, +] + +[package.dependencies] +pydantic = ">=2.7.0" +python-dotenv = ">=0.21.0" +typing-inspection = ">=0.4.0" + +[package.extras] +aws-secrets-manager = ["boto3 (>=1.35.0)", "boto3-stubs[secretsmanager]"] +azure-key-vault = ["azure-identity (>=1.16.0)", "azure-keyvault-secrets (>=4.8.0)"] +gcp-secret-manager = ["google-cloud-secret-manager (>=2.23.1)"] +toml = ["tomli (>=2.0.1)"] +yaml = ["pyyaml (>=6.0.1)"] + [[package]] name = "pydata-sphinx-theme" version = "0.15.4" @@ -5057,9 +5057,9 @@ files = [ astroid = ">=4.0.2,<=4.1.dev0" colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ - {version = ">=0.2", markers = "python_version < \"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version == \"3.11\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\""}, + {version = ">=0.2", markers = "python_version < \"3.11\""}, ] isort = ">=5,<5.13 || >5.13,<8" mccabe = ">=0.6,<0.8" @@ -5073,14 +5073,14 @@ testutils = ["gitpython (>3)"] [[package]] name = "pyparsing" -version = "3.2.5" +version = "3.3.1" description = "pyparsing - Classes and methods to define and execute parsing grammars" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "pyparsing-3.2.5-py3-none-any.whl", hash = "sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e"}, - {file = "pyparsing-3.2.5.tar.gz", hash = "sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6"}, + {file = "pyparsing-3.3.1-py3-none-any.whl", hash = "sha256:023b5e7e5520ad96642e2c6db4cb683d3970bd640cdf7115049a6e9c3682df82"}, + {file = "pyparsing-3.3.1.tar.gz", hash = "sha256:47fad0f17ac1e2cad3de3b458570fbc9b03560aa029ed5e16ee5554da9a2251c"}, ] [package.extras] @@ -5263,6 +5263,21 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.2.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "python_dotenv-1.2.1-py3-none-any.whl", 
hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61"}, + {file = "python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "python-json-logger" version = "4.0.0" @@ -5281,15 +5296,55 @@ dev = ["backports.zoneinfo ; python_version < \"3.9\"", "black", "build", "freez [[package]] name = "pytokens" -version = "0.3.0" +version = "0.4.0" description = "A Fast, spec compliant Python 3.14+ tokenizer that runs on older Pythons." optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "pytokens-0.3.0-py3-none-any.whl", hash = "sha256:95b2b5eaf832e469d141a378872480ede3f251a5a5041b8ec6e581d3ac71bbf3"}, - {file = "pytokens-0.3.0.tar.gz", hash = "sha256:2f932b14ed08de5fcf0b391ace2642f858f1394c0857202959000b68ed7a458a"}, + {file = "pytokens-0.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:af0c3166aea367a9e755a283171befb92dd3043858b94ae9b3b7efbe9def26a3"}, + {file = "pytokens-0.4.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:daae524ed14ca459932cbf51d74325bea643701ba8a8b0cc2d10f7cd4b3e2b63"}, + {file = "pytokens-0.4.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e95cb158c44d642ed62f555bf8136bbe780dbd64d2fb0b9169e11ffb944664c3"}, + {file = "pytokens-0.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:df58d44630eaf25f587540e94bdf1fc50b4e6d5f212c786de0fb024bfcb8753a"}, + {file = "pytokens-0.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:55efcc36f9a2e0e930cfba0ce7f83445306b02f8326745585ed5551864eba73a"}, + {file = "pytokens-0.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:92eb3ef88f27c22dc9dbab966ace4d61f6826e02ba04dac8e2d65ea31df56c8e"}, + {file = "pytokens-0.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f4b77858a680635ee9904306f54b0ee4781effb89e211ba0a773d76539537165"}, + {file = "pytokens-0.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25cacc20c2ad90acb56f3739d87905473c54ca1fa5967ffcd675463fe965865e"}, + {file = "pytokens-0.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:628fab535ebc9079e4db35cd63cb401901c7ce8720a9834f9ad44b9eb4e0f1d4"}, + {file = "pytokens-0.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:4d0f568d7e82b7e96be56d03b5081de40e43c904eb6492bf09aaca47cd55f35b"}, + {file = "pytokens-0.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cd8da894e5a29ba6b6da8be06a4f7589d7220c099b5e363cb0643234b9b38c2a"}, + {file = "pytokens-0.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:237ba7cfb677dbd3b01b09860810aceb448871150566b93cd24501d5734a04b1"}, + {file = "pytokens-0.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01d1a61e36812e4e971cfe2c0e4c1f2d66d8311031dac8bf168af8a249fa04dd"}, + {file = "pytokens-0.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e47e2ef3ec6ee86909e520d79f965f9b23389fda47460303cf715d510a6fe544"}, + {file = "pytokens-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3d36954aba4557fd5a418a03cf595ecbb1cdcce119f91a49b19ef09d691a22ae"}, + {file = "pytokens-0.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73eff3bdd8ad08da679867992782568db0529b887bed4c85694f84cdf35eafc6"}, + {file = 
"pytokens-0.4.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d97cc1f91b1a8e8ebccf31c367f28225699bea26592df27141deade771ed0afb"}, + {file = "pytokens-0.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a2c8952c537cb73a1a74369501a83b7f9d208c3cf92c41dd88a17814e68d48ce"}, + {file = "pytokens-0.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5dbf56f3c748aed9310b310d5b8b14e2c96d3ad682ad5a943f381bdbbdddf753"}, + {file = "pytokens-0.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:e131804513597f2dff2b18f9911d9b6276e21ef3699abeffc1c087c65a3d975e"}, + {file = "pytokens-0.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0d7374c917197106d3c4761374718bc55ea2e9ac0fb94171588ef5840ee1f016"}, + {file = "pytokens-0.4.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cd3fa1caf9e47a72ee134a29ca6b5bea84712724bba165d6628baa190c6ea5b"}, + {file = "pytokens-0.4.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9c6986576b7b07fe9791854caa5347923005a80b079d45b63b0be70d50cce5f1"}, + {file = "pytokens-0.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9940f7c2e2f54fb1cb5fe17d0803c54da7a2bf62222704eb4217433664a186a7"}, + {file = "pytokens-0.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:54691cf8f299e7efabcc25adb4ce715d3cef1491e1c930eaf555182f898ef66a"}, + {file = "pytokens-0.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:94ff5db97a0d3cd7248a5b07ba2167bd3edc1db92f76c6db00137bbaf068ddf8"}, + {file = "pytokens-0.4.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d0dd6261cd9cc95fae1227b1b6ebee023a5fd4a4b6330b071c73a516f5f59b63"}, + {file = "pytokens-0.4.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0cdca8159df407dbd669145af4171a0d967006e0be25f3b520896bc7068f02c4"}, + {file = "pytokens-0.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4b5770abeb2a24347380a1164a558f0ebe06e98aedbd54c45f7929527a5fb26e"}, + {file = "pytokens-0.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:74500d72c561dad14c037a9e86a657afd63e277dd5a3bb7570932ab7a3b12551"}, + {file = "pytokens-0.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e368e0749e4e9d86a6e08763310dc92bc69ad73d9b6db5243b30174c71a8a534"}, + {file = "pytokens-0.4.0-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:865cc65c75c8f2e9e0d8330338f649b12bfd9442561900ebaf58c596a72107d2"}, + {file = "pytokens-0.4.0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dbb9338663b3538f31c4ca7afe4f38d9b9b3a16a8be18a273a5704a1bc7a2367"}, + {file = "pytokens-0.4.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:658f870523ac1a5f4733d7db61ce9af61a0c23b2aeea3d03d1800c93f760e15f"}, + {file = "pytokens-0.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:d69a2491190a74e4b6f87f3b9dfce7a6873de3f3bf330d20083d374380becac0"}, + {file = "pytokens-0.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8cd795191c4127fcb3d7b76d84006a07748c390226f47657869235092eedbc05"}, + {file = "pytokens-0.4.0-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef2bcbddb73ac18599a86c8c549d5145130f2cd9d83dc2b5482fd8322b7806cd"}, + {file = "pytokens-0.4.0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", 
hash = "sha256:06ac081c1187389762b58823d90d6339e6880ce0df912f71fb9022d81d7fd429"}, + {file = "pytokens-0.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:278129d54573efdc79e75c6082e73ebd19858e22a2e848359f93629323186ca6"}, + {file = "pytokens-0.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:9380fb6d96fa5ab83ed606ebad27b6171930cc14a8a8d215f6adb187ba428690"}, + {file = "pytokens-0.4.0-py3-none-any.whl", hash = "sha256:0508d11b4de157ee12063901603be87fb0253e8f4cb9305eb168b1202ab92068"}, + {file = "pytokens-0.4.0.tar.gz", hash = "sha256:6b0b03e6ea7c9f9d47c5c61164b69ad30f4f0d70a5d9fe7eac4d19f24f77af2d"}, ] [package.extras] @@ -6116,101 +6171,100 @@ test = ["Cython", "array-api-strict (>=2.0,<2.1.1)", "asv", "gmpy2", "hypothesis [[package]] name = "scipy" -version = "1.16.3" +version = "1.17.0" description = "Fundamental algorithms for scientific computing in Python" optional = false python-versions = ">=3.11" groups = ["main"] markers = "python_version >= \"3.11\"" files = [ - {file = "scipy-1.16.3-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:40be6cf99e68b6c4321e9f8782e7d5ff8265af28ef2cd56e9c9b2638fa08ad97"}, - {file = "scipy-1.16.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:8be1ca9170fcb6223cc7c27f4305d680ded114a1567c0bd2bfcbf947d1b17511"}, - {file = "scipy-1.16.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:bea0a62734d20d67608660f69dcda23e7f90fb4ca20974ab80b6ed40df87a005"}, - {file = "scipy-1.16.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:2a207a6ce9c24f1951241f4693ede2d393f59c07abc159b2cb2be980820e01fb"}, - {file = "scipy-1.16.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:532fb5ad6a87e9e9cd9c959b106b73145a03f04c7d57ea3e6f6bb60b86ab0876"}, - {file = "scipy-1.16.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0151a0749efeaaab78711c78422d413c583b8cdd2011a3c1d6c794938ee9fdb2"}, - {file = "scipy-1.16.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b7180967113560cca57418a7bc719e30366b47959dd845a93206fbed693c867e"}, - {file = "scipy-1.16.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:deb3841c925eeddb6afc1e4e4a45e418d19ec7b87c5df177695224078e8ec733"}, - {file = "scipy-1.16.3-cp311-cp311-win_amd64.whl", hash = "sha256:53c3844d527213631e886621df5695d35e4f6a75f620dca412bcd292f6b87d78"}, - {file = "scipy-1.16.3-cp311-cp311-win_arm64.whl", hash = "sha256:9452781bd879b14b6f055b26643703551320aa8d79ae064a71df55c00286a184"}, - {file = "scipy-1.16.3-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:81fc5827606858cf71446a5e98715ba0e11f0dbc83d71c7409d05486592a45d6"}, - {file = "scipy-1.16.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:c97176013d404c7346bf57874eaac5187d969293bf40497140b0a2b2b7482e07"}, - {file = "scipy-1.16.3-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2b71d93c8a9936046866acebc915e2af2e292b883ed6e2cbe5c34beb094b82d9"}, - {file = "scipy-1.16.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:3d4a07a8e785d80289dfe66b7c27d8634a773020742ec7187b85ccc4b0e7b686"}, - {file = "scipy-1.16.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0553371015692a898e1aa858fed67a3576c34edefa6b7ebdb4e9dde49ce5c203"}, - {file = "scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:72d1717fd3b5e6ec747327ce9bda32d5463f472c9dce9f54499e81fbd50245a1"}, - {file = "scipy-1.16.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1fb2472e72e24d1530debe6ae078db70fb1605350c88a3d14bc401d6306dbffe"}, - {file = 
"scipy-1.16.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c5192722cffe15f9329a3948c4b1db789fbb1f05c97899187dcf009b283aea70"}, - {file = "scipy-1.16.3-cp312-cp312-win_amd64.whl", hash = "sha256:56edc65510d1331dae01ef9b658d428e33ed48b4f77b1d51caf479a0253f96dc"}, - {file = "scipy-1.16.3-cp312-cp312-win_arm64.whl", hash = "sha256:a8a26c78ef223d3e30920ef759e25625a0ecdd0d60e5a8818b7513c3e5384cf2"}, - {file = "scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c"}, - {file = "scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d"}, - {file = "scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9"}, - {file = "scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4"}, - {file = "scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959"}, - {file = "scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88"}, - {file = "scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234"}, - {file = "scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d"}, - {file = "scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304"}, - {file = "scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2"}, - {file = "scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b"}, - {file = "scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079"}, - {file = "scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a"}, - {file = "scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119"}, - {file = "scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c"}, - {file = "scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e"}, - {file = "scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135"}, - {file = "scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6"}, - {file = "scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc"}, - {file = "scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a"}, - {file = "scipy-1.16.3-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:875555ce62743e1d54f06cdf22c1e0bc47b91130ac40fe5d783b6dfa114beeb6"}, - {file = 
"scipy-1.16.3-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:bb61878c18a470021fb515a843dc7a76961a8daceaaaa8bad1332f1bf4b54657"}, - {file = "scipy-1.16.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f2622206f5559784fa5c4b53a950c3c7c1cf3e84ca1b9c4b6c03f062f289ca26"}, - {file = "scipy-1.16.3-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:7f68154688c515cdb541a31ef8eb66d8cd1050605be9dcd74199cbd22ac739bc"}, - {file = "scipy-1.16.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3c820ddb80029fe9f43d61b81d8b488d3ef8ca010d15122b152db77dc94c22"}, - {file = "scipy-1.16.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d3837938ae715fc0fe3c39c0202de3a8853aff22ca66781ddc2ade7554b7e2cc"}, - {file = "scipy-1.16.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aadd23f98f9cb069b3bd64ddc900c4d277778242e961751f77a8cb5c4b946fb0"}, - {file = "scipy-1.16.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b7c5f1bda1354d6a19bc6af73a649f8285ca63ac6b52e64e658a5a11d4d69800"}, - {file = "scipy-1.16.3-cp314-cp314-win_amd64.whl", hash = "sha256:e5d42a9472e7579e473879a1990327830493a7047506d58d73fc429b84c1d49d"}, - {file = "scipy-1.16.3-cp314-cp314-win_arm64.whl", hash = "sha256:6020470b9d00245926f2d5bb93b119ca0340f0d564eb6fbaad843eaebf9d690f"}, - {file = "scipy-1.16.3-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:e1d27cbcb4602680a49d787d90664fa4974063ac9d4134813332a8c53dbe667c"}, - {file = "scipy-1.16.3-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:9b9c9c07b6d56a35777a1b4cc8966118fb16cfd8daf6743867d17d36cfad2d40"}, - {file = "scipy-1.16.3-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:3a4c460301fb2cffb7f88528f30b3127742cff583603aa7dc964a52c463b385d"}, - {file = "scipy-1.16.3-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:f667a4542cc8917af1db06366d3f78a5c8e83badd56409f94d1eac8d8d9133fa"}, - {file = "scipy-1.16.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f379b54b77a597aa7ee5e697df0d66903e41b9c85a6dd7946159e356319158e8"}, - {file = "scipy-1.16.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4aff59800a3b7f786b70bfd6ab551001cb553244988d7d6b8299cb1ea653b353"}, - {file = "scipy-1.16.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:da7763f55885045036fabcebd80144b757d3db06ab0861415d1c3b7c69042146"}, - {file = "scipy-1.16.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ffa6eea95283b2b8079b821dc11f50a17d0571c92b43e2b5b12764dc5f9b285d"}, - {file = "scipy-1.16.3-cp314-cp314t-win_amd64.whl", hash = "sha256:d9f48cafc7ce94cf9b15c6bffdc443a81a27bf7075cf2dcd5c8b40f85d10c4e7"}, - {file = "scipy-1.16.3-cp314-cp314t-win_arm64.whl", hash = "sha256:21d9d6b197227a12dcbf9633320a4e34c6b0e51c57268df255a0942983bac562"}, - {file = "scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb"}, -] - -[package.dependencies] -numpy = ">=1.25.2,<2.6" - -[package.extras] -dev = ["cython-lint (>=0.12.2)", "doit (>=0.36.0)", "mypy (==1.10.0)", "pycodestyle", "pydevtool", "rich-click", "ruff (>=0.0.292)", "types-psutil", "typing_extensions"] -doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.19.1)", "jupytext", "linkify-it-py", "matplotlib (>=3.5)", "myst-nb (>=1.2.0)", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.2.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)"] + {file = "scipy-1.17.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = 
"sha256:2abd71643797bd8a106dff97894ff7869eeeb0af0f7a5ce02e4227c6a2e9d6fd"}, + {file = "scipy-1.17.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:ef28d815f4d2686503e5f4f00edc387ae58dfd7a2f42e348bb53359538f01558"}, + {file = "scipy-1.17.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:272a9f16d6bb4667e8b50d25d71eddcc2158a214df1b566319298de0939d2ab7"}, + {file = "scipy-1.17.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:7204fddcbec2fe6598f1c5fdf027e9f259106d05202a959a9f1aecf036adc9f6"}, + {file = "scipy-1.17.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fc02c37a5639ee67d8fb646ffded6d793c06c5622d36b35cfa8fe5ececb8f042"}, + {file = "scipy-1.17.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dac97a27520d66c12a34fd90a4fe65f43766c18c0d6e1c0a80f114d2260080e4"}, + {file = "scipy-1.17.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb7446a39b3ae0fe8f416a9a3fdc6fba3f11c634f680f16a239c5187bc487c0"}, + {file = "scipy-1.17.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:474da16199f6af66601a01546144922ce402cb17362e07d82f5a6cf8f963e449"}, + {file = "scipy-1.17.0-cp311-cp311-win_amd64.whl", hash = "sha256:255c0da161bd7b32a6c898e7891509e8a9289f0b1c6c7d96142ee0d2b114c2ea"}, + {file = "scipy-1.17.0-cp311-cp311-win_arm64.whl", hash = "sha256:85b0ac3ad17fa3be50abd7e69d583d98792d7edc08367e01445a1e2076005379"}, + {file = "scipy-1.17.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:0d5018a57c24cb1dd828bcf51d7b10e65986d549f52ef5adb6b4d1ded3e32a57"}, + {file = "scipy-1.17.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:88c22af9e5d5a4f9e027e26772cc7b5922fab8bcc839edb3ae33de404feebd9e"}, + {file = "scipy-1.17.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f3cd947f20fe17013d401b64e857c6b2da83cae567adbb75b9dcba865abc66d8"}, + {file = "scipy-1.17.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:e8c0b331c2c1f531eb51f1b4fc9ba709521a712cce58f1aa627bc007421a5306"}, + {file = "scipy-1.17.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5194c445d0a1c7a6c1a4a4681b6b7c71baad98ff66d96b949097e7513c9d6742"}, + {file = "scipy-1.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9eeb9b5f5997f75507814ed9d298ab23f62cf79f5a3ef90031b1ee2506abdb5b"}, + {file = "scipy-1.17.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:40052543f7bbe921df4408f46003d6f01c6af109b9e2c8a66dd1cf6cf57f7d5d"}, + {file = "scipy-1.17.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0cf46c8013fec9d3694dc572f0b54100c28405d55d3e2cb15e2895b25057996e"}, + {file = "scipy-1.17.0-cp312-cp312-win_amd64.whl", hash = "sha256:0937a0b0d8d593a198cededd4c439a0ea216a3f36653901ea1f3e4be949056f8"}, + {file = "scipy-1.17.0-cp312-cp312-win_arm64.whl", hash = "sha256:f603d8a5518c7426414d1d8f82e253e454471de682ce5e39c29adb0df1efb86b"}, + {file = "scipy-1.17.0-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:65ec32f3d32dfc48c72df4291345dae4f048749bc8d5203ee0a3f347f96c5ce6"}, + {file = "scipy-1.17.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:1f9586a58039d7229ce77b52f8472c972448cded5736eaf102d5658bbac4c269"}, + {file = "scipy-1.17.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:9fad7d3578c877d606b1150135c2639e9de9cecd3705caa37b66862977cc3e72"}, + {file = "scipy-1.17.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:423ca1f6584fc03936972b5f7c06961670dbba9f234e71676a7c7ccf938a0d61"}, + {file = "scipy-1.17.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash 
= "sha256:fe508b5690e9eaaa9467fc047f833af58f1152ae51a0d0aed67aa5801f4dd7d6"}, + {file = "scipy-1.17.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6680f2dfd4f6182e7d6db161344537da644d1cf85cf293f015c60a17ecf08752"}, + {file = "scipy-1.17.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eec3842ec9ac9de5917899b277428886042a93db0b227ebbe3a333b64ec7643d"}, + {file = "scipy-1.17.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d7425fcafbc09a03731e1bc05581f5fad988e48c6a861f441b7ab729a49a55ea"}, + {file = "scipy-1.17.0-cp313-cp313-win_amd64.whl", hash = "sha256:87b411e42b425b84777718cc41516b8a7e0795abfa8e8e1d573bf0ef014f0812"}, + {file = "scipy-1.17.0-cp313-cp313-win_arm64.whl", hash = "sha256:357ca001c6e37601066092e7c89cca2f1ce74e2a520ca78d063a6d2201101df2"}, + {file = "scipy-1.17.0-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:ec0827aa4d36cb79ff1b81de898e948a51ac0b9b1c43e4a372c0508c38c0f9a3"}, + {file = "scipy-1.17.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:819fc26862b4b3c73a60d486dbb919202f3d6d98c87cf20c223511429f2d1a97"}, + {file = "scipy-1.17.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:363ad4ae2853d88ebcde3ae6ec46ccca903ea9835ee8ba543f12f575e7b07e4e"}, + {file = "scipy-1.17.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:979c3a0ff8e5ba254d45d59ebd38cde48fce4f10b5125c680c7a4bfe177aab07"}, + {file = "scipy-1.17.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:130d12926ae34399d157de777472bf82e9061c60cc081372b3118edacafe1d00"}, + {file = "scipy-1.17.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6e886000eb4919eae3a44f035e63f0fd8b651234117e8f6f29bad1cd26e7bc45"}, + {file = "scipy-1.17.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:13c4096ac6bc31d706018f06a49abe0485f96499deb82066b94d19b02f664209"}, + {file = "scipy-1.17.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cacbaddd91fcffde703934897c5cd2c7cb0371fac195d383f4e1f1c5d3f3bd04"}, + {file = "scipy-1.17.0-cp313-cp313t-win_amd64.whl", hash = "sha256:edce1a1cf66298cccdc48a1bdf8fb10a3bf58e8b58d6c3883dd1530e103f87c0"}, + {file = "scipy-1.17.0-cp313-cp313t-win_arm64.whl", hash = "sha256:30509da9dbec1c2ed8f168b8d8aa853bc6723fede1dbc23c7d43a56f5ab72a67"}, + {file = "scipy-1.17.0-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:c17514d11b78be8f7e6331b983a65a7f5ca1fd037b95e27b280921fe5606286a"}, + {file = "scipy-1.17.0-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:4e00562e519c09da34c31685f6acc3aa384d4d50604db0f245c14e1b4488bfa2"}, + {file = "scipy-1.17.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:f7df7941d71314e60a481e02d5ebcb3f0185b8d799c70d03d8258f6c80f3d467"}, + {file = "scipy-1.17.0-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:aabf057c632798832f071a8dde013c2e26284043934f53b00489f1773b33527e"}, + {file = "scipy-1.17.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a38c3337e00be6fd8a95b4ed66b5d988bac4ec888fd922c2ea9fe5fb1603dd67"}, + {file = "scipy-1.17.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00fb5f8ec8398ad90215008d8b6009c9db9fa924fd4c7d6be307c6f945f9cd73"}, + {file = "scipy-1.17.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f2a4942b0f5f7c23c7cd641a0ca1955e2ae83dedcff537e3a0259096635e186b"}, + {file = "scipy-1.17.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:dbf133ced83889583156566d2bdf7a07ff89228fe0c0cb727f777de92092ec6b"}, + {file = 
"scipy-1.17.0-cp314-cp314-win_amd64.whl", hash = "sha256:3625c631a7acd7cfd929e4e31d2582cf00f42fcf06011f59281271746d77e061"}, + {file = "scipy-1.17.0-cp314-cp314-win_arm64.whl", hash = "sha256:9244608d27eafe02b20558523ba57f15c689357c85bdcfe920b1828750aa26eb"}, + {file = "scipy-1.17.0-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:2b531f57e09c946f56ad0b4a3b2abee778789097871fc541e267d2eca081cff1"}, + {file = "scipy-1.17.0-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:13e861634a2c480bd237deb69333ac79ea1941b94568d4b0efa5db5e263d4fd1"}, + {file = "scipy-1.17.0-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:eb2651271135154aa24f6481cbae5cc8af1f0dd46e6533fb7b56aa9727b6a232"}, + {file = "scipy-1.17.0-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:c5e8647f60679790c2f5c76be17e2e9247dc6b98ad0d3b065861e082c56e078d"}, + {file = "scipy-1.17.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5fb10d17e649e1446410895639f3385fd2bf4c3c7dfc9bea937bddcbc3d7b9ba"}, + {file = "scipy-1.17.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8547e7c57f932e7354a2319fab613981cde910631979f74c9b542bb167a8b9db"}, + {file = "scipy-1.17.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:33af70d040e8af9d5e7a38b5ed3b772adddd281e3062ff23fec49e49681c38cf"}, + {file = "scipy-1.17.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f9eb55bb97d00f8b7ab95cb64f873eb0bf54d9446264d9f3609130381233483f"}, + {file = "scipy-1.17.0-cp314-cp314t-win_amd64.whl", hash = "sha256:1ff269abf702f6c7e67a4b7aad981d42871a11b9dd83c58d2d2ea624efbd1088"}, + {file = "scipy-1.17.0-cp314-cp314t-win_arm64.whl", hash = "sha256:031121914e295d9791319a1875444d55079885bbae5bdc9c5e0f2ee5f09d34ff"}, + {file = "scipy-1.17.0.tar.gz", hash = "sha256:2591060c8e648d8b96439e111ac41fd8342fdeff1876be2e19dea3fe8930454e"}, +] + +[package.dependencies] +numpy = ">=1.26.4,<2.7" + +[package.extras] +dev = ["click (<8.3.0)", "cython-lint (>=0.12.2)", "mypy (==1.10.0)", "pycodestyle", "ruff (>=0.12.0)", "spin", "types-psutil", "typing_extensions"] +doc = ["intersphinx_registry", "jupyterlite-pyodide-kernel", "jupyterlite-sphinx (>=0.19.1)", "jupytext", "linkify-it-py", "matplotlib (>=3.5)", "myst-nb (>=1.2.0)", "numpydoc", "pooch", "pydata-sphinx-theme (>=0.15.2)", "sphinx (>=5.0.0,<8.2.0)", "sphinx-copybutton", "sphinx-design (>=0.4.0)", "tabulate"] test = ["Cython", "array-api-strict (>=2.3.1)", "asv", "gmpy2", "hypothesis (>=6.30)", "meson", "mpmath", "ninja ; sys_platform != \"emscripten\"", "pooch", "pytest (>=8.0.0)", "pytest-cov", "pytest-timeout", "pytest-xdist", "scikit-umfpack", "threadpoolctl"] [[package]] name = "send2trash" -version = "1.8.3" +version = "2.1.0" description = "Send file to trash natively under Mac OS X, Windows and Linux" optional = true -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "Send2Trash-1.8.3-py3-none-any.whl", hash = "sha256:0c31227e0bd08961c7665474a3d1ef7193929fedda4233843689baa056be46c9"}, - {file = "Send2Trash-1.8.3.tar.gz", hash = "sha256:b18e7a3966d99871aefeb00cfbcfdced55ce4871194810fc71f4aa484b953abf"}, + {file = "send2trash-2.1.0-py3-none-any.whl", hash = "sha256:0da2f112e6d6bb22de6aa6daa7e144831a4febf2a87261451c4ad849fe9a873c"}, + {file = "send2trash-2.1.0.tar.gz", hash = "sha256:1c72b39f09457db3c05ce1d19158c2cbef4c32b8bedd02c155e49282b7ea7459"}, ] [package.extras] -nativelib = 
["pyobjc-framework-Cocoa ; sys_platform == \"darwin\"", "pywin32 ; sys_platform == \"win32\""] -objc = ["pyobjc-framework-Cocoa ; sys_platform == \"darwin\""] -win32 = ["pywin32 ; sys_platform == \"win32\""] +nativelib = ["pyobjc (>=9.0) ; sys_platform == \"darwin\"", "pywin32 (>=305) ; sys_platform == \"win32\""] +test = ["pytest (>=8)"] [[package]] name = "setuptools" @@ -6219,7 +6273,7 @@ description = "Easily download, build, install, upgrade, and uninstall Python pa optional = true python-versions = ">=3.9" groups = ["main"] -markers = "(extra == \"dev\" or extra == \"docs\") and (extra == \"dev\" or extra == \"pytorch\" or extra == \"docs\") or python_version >= \"3.12\" and (extra == \"dev\" or extra == \"docs\" or extra == \"pytorch\")" +markers = "python_version >= \"3.12\" and (extra == \"dev\" or extra == \"docs\" or extra == \"pytorch\") or (extra == \"dev\" or extra == \"docs\") and (extra == \"dev\" or extra == \"docs\" or extra == \"pytorch\")" files = [ {file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"}, {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, @@ -6490,15 +6544,15 @@ files = [ [[package]] name = "soupsieve" -version = "2.8.1" +version = "2.8.3" description = "A modern CSS selector implementation for Beautiful Soup." optional = true python-versions = ">=3.9" groups = ["main"] markers = "extra == \"dev\" or extra == \"docs\"" files = [ - {file = "soupsieve-2.8.1-py3-none-any.whl", hash = "sha256:a11fe2a6f3d76ab3cf2de04eb339c1be5b506a8a47f2ceb6d139803177f85434"}, - {file = "soupsieve-2.8.1.tar.gz", hash = "sha256:4cf733bc50fa805f5df4b8ef4740fc0e0fa6218cf3006269afd3f9d6d80fd350"}, + {file = "soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95"}, + {file = "soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349"}, ] [[package]] @@ -7036,76 +7090,81 @@ files = [ [[package]] name = "tidy3d-extras" -version = "2.10.0" +version = "2.10.2" description = "tidy3d-extras is an optional plugin for Tidy3D providing addtional, more advanced local functionality." 
optional = true
python-versions = ">=3.9"
groups = ["main"]
markers = "extra == \"extras\""
files = [
- {file = "tidy3d_extras-2.10.0-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:56f73ac274df2a3d580df1086c088242715d6fd06873d31794baf7e2550f63b9"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-macosx_15_0_x86_64.whl", hash = "sha256:b3ab140a8524435671cd9323a639dd26e4c67424ea027fc5047570769596d158"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:db229d6c840a995ee3a91a323d4cd3fe94b49e7a291ba30039033226ba43a538"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77c60fc7cbd9d7b802b9e0c186fababa23fd8ec8a51742afcf0c5d95ae1f7e1"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8606f6aab68d1db1fafb11fe0b075449b24b08c1700c5d972f08334816565cab"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:d9df52c8ab5bc60963e19207a5271ee735c3308d9db577edacdbbb7e889e3836"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:2bf8f4e217b9060a7131bcba2914680ddd1e5e34d081141131e5f635a03d33fe"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:168d5f5d37b343a8c1af120b715882c0d481ca0c5c8292952d0d44ebd2ce579b"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d0b5f7a82c0a70fa49e99e9168fe001bf67c6e18c4a78bece82e7e71a8034386"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7279e82ec835b41060215fa3034002dcc2caba3e9073b81dce321bb5b9e0a313"},
- {file = "tidy3d_extras-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:078166cce2db6e1685d1ae06712808e89bc8abd180caf8273a62fc07d2a82b74"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:642aeb88f05ee5628255ae601ede70cd27f0e05e8a2bc9ff847bb79bf8c37d2e"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-macosx_15_0_x86_64.whl", hash = "sha256:ffbca1f645471a9b71e668cf3fd70b2cb9ce203fa5a621b16dc2be631c757492"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6648314187f8c5b4db50fc50062b50c5be816d4f46eb2300472998b3f6289221"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5532f756081e479761106635f86f89eeedb28fd347a2802ef1dd984338d756dc"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1aaceaaecacdd3b38f4cc641860e1af30da58cf26aef8319e872a26a2fecd89a"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:52908c8ed3795f02f387506d3eeb401977b3d0de1be1fa41f60d73ae065e8f65"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:5a3c3153da5ff46b7438c0510c77524003aba18ee28a4690b4c9c71e7c49784b"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:81eadbf113f179201f8c93d2109ed4864669aff49b891e227f07fc066ea77993"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:98d5daac4498423f60a804c303d31f9901c0608263552df4478c2836985a2c1d"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eb942efbeba3fd67996debae35257eb12c7050ac4d86bef4c814cbae40a359cd"},
- {file = "tidy3d_extras-2.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:edd60a2c0dcb1575db9f1562c010694e9ae6b7bb67fbac4af522d8f0b43a1f61"},
"sha256:edd60a2c0dcb1575db9f1562c010694e9ae6b7bb67fbac4af522d8f0b43a1f61"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:fb1b348e46e27d87b31b273efb678d06db8023cbff2a038973755376e9767c79"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-macosx_15_0_x86_64.whl", hash = "sha256:ce8f2af8a994f65a68dd34322cde7886df03da2596af1ac00dac6a22c3491a77"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e94339c205d3944c71b396c08b4f47db6bc6ce8295da34e65dad574782dcdeb"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7879167ffc8355d3e2bf24376d3c635aa0b36023835e85ea5c30c0eed9e8dcb"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9076d74a47fca0ac10cbb93b70740006cc0b94983ae0e889f1a9d24d1512216a"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:08882ab911797c3d44406a1432801f18d9c84ec092a2d0b4fd2191c8844665af"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:a7fbb841f2a68cd7c03afe9a4524fd8b17c47372a78e81107e09c5db41690153"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3b2afd25ef30169a6810ea91f311db25144ec13516c2ea9039935e6c1345e151"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7b1c6600892349213a917bbe59297ec74d72994da48ef0c198b2d86b9fcc303"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6f8fc41c832fc11eb914f3010b47969cb2a484c349b87be5ffb7b5e085ae6182"}, - {file = "tidy3d_extras-2.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:bba8f26536245535f7ad69e0a7918a68ad5673dda48e612ace42c08d2c1a9138"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:55a7aa26cc1f73596021958b8e497371a4bdad00711b183e0388c13b6cd2a3b5"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-macosx_15_0_x86_64.whl", hash = "sha256:3fab7a87776e5b38ead144d07185fe5234939659ad4c0359803605ae99b5570e"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28fa5697d36dc6cec401fa879a0a27e61baea29a217188bf15a0610fff539032"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7850c38b2bfad5069dca366564e317961b5a7e63224b396dca19c7fe258d9244"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c82320ada13af84a3bcf2ceedfa916cb696bc048621b458f21e9aa7277b1b8d5"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:6bf82620142ccaafef48e02abbecacc37951ded585e70baa8fbaf46790642bda"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:f066da860b275c3735531ece4f77818fcade0a0fd16191659bd6d8c5d03e099b"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4933d79f61cbf58048071db19ac6befa2e8a7e0a2dd4577a2e83cbb099c29ee6"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7d41a0582b7b90aaad26a9f9a7c9540436b3849bf294d5b677d07b04990187f2"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1750f2888e2ab843e2911797e72a09581464335d38e16736c7e938a249cd5334"}, - {file = "tidy3d_extras-2.10.0-cp313-cp313-win_amd64.whl", hash = 
"sha256:c9214a2b03b3c6c4849c1a0cb0d12378dafcb98eb5cd0d38dd684698f0e23aa9"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-macosx_15_0_x86_64.whl", hash = "sha256:6d981baece7829494bb98bcf855e74e385aa65d92ec7883e03f3fafaf3a8cf45"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de5f3210db6e5234ef329a4de84c92ee23bb251da0ac09f9a5e78e6520a1a99e"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b23177f78b4e15d4ba01e139a239332029aa9966a61391e2667d9f539e0a0af"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:744c7cdc43e6a839384d8a52a2119eb2c634d26a897eb03c9a3752e7528c9cf9"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:264119370a2f70d684a3af86ede0c3887723bb8b9476df460890d01023deb0fc"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:349badc37b98f4a0adc704ba57a8cc80c97615f0f74dc9d7411e4e6fce6d8f6d"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:43f047299d25c6834cc2169efd53b180407712f562ec4f3063e1d04ce4b1d125"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:0c6a084354070c5eb8f04d4d265b012515915021f2873e31f99273254e5d34be"}, - {file = "tidy3d_extras-2.10.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:8cf87a3ab1aef5d02a7b3b287228b01c573625b298253b5961b2ba49559f426b"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-macosx_15_0_arm64.whl", hash = "sha256:f2b7c760e00cf649598c531ac97be3579348a6b89521196c18ea508efad60679"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-macosx_15_0_x86_64.whl", hash = "sha256:d7d196b3dee8cd67c813517090e9acfea36605da8a61fefecc1ac37b183eb4b2"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f6e9b1588addf7fa05120362b03bf01da257180b69a8d67eba57649d26af16f"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:20f14fb02a85292a594b173f9562721f9e402c016d9e2be576f8617294050a4e"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df5668989bd8e5da9f700293f1c47985d75180c71ea61f0a9a956c826a04e607"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:ccc69ea3e6dd56b73a84210fd0559a0332cf08e2574eb74cd6718e1e8967bd04"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ae2460451c0489b7a8b2c4e79586d63c03fb3bf5701cb5d8b15f775d5c906e09"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a02f9c3daf9ae2ab85a13a22d4281e1c4c0e8d80682c345afd4ada8101778dd5"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:8604a714206a52384aaea68ae2e022c511ca4e7502b94a1af010aba8761ccda0"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:36ffadc70de550eadb7da963ff8f3d4f973c4c179a7c41ea9944735e97e731bb"}, + {file = "tidy3d_extras-2.10.2-cp310-cp310-win_amd64.whl", hash = "sha256:dcab4a30cbee161981242b1e4d54bbbf6ce7b9a406b2bfd95e588985ee690342"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:6301f723cf6c0510f582c5280f1c39f59c9bb9c2ee65cb97d6c2faf87a7367f4"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-macosx_15_0_x86_64.whl", hash = "sha256:43266bc780317e73d6e1deec2a37196d31dab574c26f4226b947098576ab3a71"}, 
+ {file = "tidy3d_extras-2.10.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:399be5304fba509859f121c2e7129386d3613a1a67b35c2d1ab2b0112fa4459e"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f02ad9bda1d87680690cf7bf195382c3561aff0215c732f7e9b6a6e78af9a9ec"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3278ce81b744ba099d795a30881b815e4372a585741b51d36da2b5a41511b28"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ae1db86adeea22fce320110208e9fffdaa9ffe9976a0e10f6835440a3d2f255b"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:7e79af69edc0a91d88e147047cb00d48bbeaec6bb8d4ec42a45894639f1aecfa"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7c7873bd1b2d14f97b016a6ee65ed85756f86e09df55195de59b5d8470849f43"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:91ccfda7d5d51331ce914e8a88160027fdefdc09bcb645da6ca17dcbefa49ae4"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6532d49df9c2eed164719ca518966affc198fd1e52eb56430d282b69c786fc4a"}, + {file = "tidy3d_extras-2.10.2-cp311-cp311-win_amd64.whl", hash = "sha256:cee1bbc888f1b917c5127d9c92e1c2d4bbb3c19f8da7947965c6dbb7650798b5"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:5af1752afe98e61e61370970ce40895baa1588eb6928cf773bdd495552dfb939"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-macosx_15_0_x86_64.whl", hash = "sha256:c7c50ac667780c5f66bf7045adc378ca2ca6137408b56284e63627b5a1a8dc8d"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e80a495dfce5362f6b58f7bacecffc9143cc9c7a9b08b47eefe0b87e508360c0"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad1a5dd5ec7451def46c952e2282fb4359b4a380cc10396e23d86788b98b577f"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b46d287a2ed18cbf9b5a0b7de117cd08163de98b5114f1dcdf7de74c1c9e5a4"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:175ab618efecbe749498da50cf205a992739f747985617c70645b509f2efc172"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:583445bd9d9e8a46a629bcff612cc32ac58aecebc1ffe4c9fd512a8afd9a1eb6"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4d2432f95372e1b0d33c957b672512a180338b843c2d9cc83690c7b7be1bb36d"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0d2cd8f354ebae333b69e3289ea92bee63d107fc1c08cf38eeb1bdb8f54ed578"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6fa57b30045df89e99066f21b950202edbdc6633579979763acfb8bf89691831"}, + {file = "tidy3d_extras-2.10.2-cp312-cp312-win_amd64.whl", hash = "sha256:8b75a74de0340c46300c7b659c8bee579b68e67ffae46a72bb64c111cb76d1b5"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:296d483aec0e0f93933bb6a932fc19c3991d3b6a5f279fd25c6bc5f966bbbe5b"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-macosx_15_0_x86_64.whl", hash = "sha256:1e7f747ec8e493f3f025a4810ab5e25883a90de5b65050d41d55edd1ff6bc4d0"}, + {file = 
"tidy3d_extras-2.10.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7c2e04fc175c7290d567c4e704853d02d918b5f1a68b2f2a95398f87ebd76bf"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a09ade7414a3dafd09dd1f2ef07a7aad74824576c60c42b7af9b97b2a6ee8b76"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f71ffea884fcbbb0830ccd495fd3eedc00ad4ecc7a8dcf09f304cf580de378b3"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a45e02e0746055f580637d4723209abc4887338ad757807927c94af34965f4d3"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:8bbd2a963a65ebc5dc8bd06bf05e362de7a9f10bcdee69261f81af784e02f2f7"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3341102734b6f05c09ee972e86934d800204568fffc25ea63d6a946baf02e29a"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:aad0dcf73edac613b9baba2f36fc2859994ffe070081fc46605623579cab5632"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ac263a66c04b408eb4f77cf856e3d437ea15772ef238f22bcbe802ff44366507"}, + {file = "tidy3d_extras-2.10.2-cp313-cp313-win_amd64.whl", hash = "sha256:668961025869049abbf6b13202c6071bf504c4979538d2b2799625b6a3997a39"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-macosx_15_0_x86_64.whl", hash = "sha256:577021c52731d8db5aef451b49c0b244f01440a4aaa5b35b5df95ce0a8125fb0"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76825cc6f0d6c8e081bc02797f66a7c6895f9ecfcb4948403b46ca7b5edb0ff7"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d87fba521651098707aa3862c009005c018b15ec1a6378dadf4cbea38f4d375"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:683644e95e1b17a08d6acb7ebf753e08dccf698c844dc53ee4aaecca6cf23c6a"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:ee0f0273edab7dbbef1c9112e232cb681ec3a930cd7f54e864c2590336399629"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:cb584505ac73a6b4781ebcd026f14b6731322ee93e4650f3e059d27e1c3793bb"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a229c8309bbbdbae1cbcd4e764dd4d003b621048379256a30466ef652c1f092"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:9699f8b08cf7a7f8a0f62f07a69670d896f6173d6fca87728629b0f6dbfb13ec"}, + {file = "tidy3d_extras-2.10.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:bc93009ff6197869b2949f20587f0fbc7fa1404d229002e961414da1eba1d274"}, ] [package.dependencies] numpy = ">=2.0" -tidy3d = "2.10.0" +tidy3d = "2.10.2" xarray = ">=2024.6" [package.extras] test = ["pytest (>=7.2)"] +[package.source] +type = "legacy" +url = "https://flexcompute-625554095313.d.codeartifact.us-east-1.amazonaws.com/pypi/pypi-releases/simple" +reference = "codeartifact" + [[package]] name = "tinycss2" version = "1.4.0" @@ -7153,55 +7212,60 @@ files = [ [[package]] name = "tomli" -version = "2.3.0" +version = "2.4.0" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" groups = ["main"] markers = "python_version < \"3.11\" and (extra == \"dev\" or extra == \"tests\" or extra == \"docs\")" files = [ - {file = 
"tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45"}, - {file = "tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba"}, - {file = "tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf"}, - {file = "tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441"}, - {file = "tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845"}, - {file = "tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c"}, - {file = "tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456"}, - {file = "tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be"}, - {file = "tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac"}, - {file = "tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22"}, - {file = "tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f"}, - {file = "tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52"}, - {file = "tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8"}, - {file = "tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6"}, - {file = "tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876"}, - {file = "tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878"}, - {file = "tomli-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5192f562738228945d7b13d4930baffda67b69425a7f0da96d360b0a3888136b"}, - {file = "tomli-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:be71c93a63d738597996be9528f4abe628d1adf5e6eb11607bc8fe1a510b5dae"}, - {file = "tomli-2.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4665508bcbac83a31ff8ab08f424b665200c0e1e645d2bd9ab3d3e557b6185b"}, - {file = "tomli-2.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4021923f97266babc6ccab9f5068642a0095faa0a51a246a6a02fccbb3514eaf"}, - {file = "tomli-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4ea38c40145a357d513bffad0ed869f13c1773716cf71ccaa83b0fa0cc4e42f"}, - {file = "tomli-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ad805ea85eda330dbad64c7ea7a4556259665bdf9d2672f5dccc740eb9d3ca05"}, - {file = "tomli-2.3.0-cp313-cp313-win32.whl", hash = "sha256:97d5eec30149fd3294270e889b4234023f2c69747e555a27bd708828353ab606"}, - {file = 
"tomli-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0c95ca56fbe89e065c6ead5b593ee64b84a26fca063b5d71a1122bf26e533999"}, - {file = "tomli-2.3.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:cebc6fe843e0733ee827a282aca4999b596241195f43b4cc371d64fc6639da9e"}, - {file = "tomli-2.3.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4c2ef0244c75aba9355561272009d934953817c49f47d768070c3c94355c2aa3"}, - {file = "tomli-2.3.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c22a8bf253bacc0cf11f35ad9808b6cb75ada2631c2d97c971122583b129afbc"}, - {file = "tomli-2.3.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0eea8cc5c5e9f89c9b90c4896a8deefc74f518db5927d0e0e8d4a80953d774d0"}, - {file = "tomli-2.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b74a0e59ec5d15127acdabd75ea17726ac4c5178ae51b85bfe39c4f8a278e879"}, - {file = "tomli-2.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b5870b50c9db823c595983571d1296a6ff3e1b88f734a4c8f6fc6188397de005"}, - {file = "tomli-2.3.0-cp314-cp314-win32.whl", hash = "sha256:feb0dacc61170ed7ab602d3d972a58f14ee3ee60494292d384649a3dc38ef463"}, - {file = "tomli-2.3.0-cp314-cp314-win_amd64.whl", hash = "sha256:b273fcbd7fc64dc3600c098e39136522650c49bca95df2d11cf3b626422392c8"}, - {file = "tomli-2.3.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:940d56ee0410fa17ee1f12b817b37a4d4e4dc4d27340863cc67236c74f582e77"}, - {file = "tomli-2.3.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:f85209946d1fe94416debbb88d00eb92ce9cd5266775424ff81bc959e001acaf"}, - {file = "tomli-2.3.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a56212bdcce682e56b0aaf79e869ba5d15a6163f88d5451cbde388d48b13f530"}, - {file = "tomli-2.3.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c5f3ffd1e098dfc032d4d3af5c0ac64f6d286d98bc148698356847b80fa4de1b"}, - {file = "tomli-2.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5e01decd096b1530d97d5d85cb4dff4af2d8347bd35686654a004f8dea20fc67"}, - {file = "tomli-2.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8a35dd0e643bb2610f156cca8db95d213a90015c11fee76c946aa62b7ae7e02f"}, - {file = "tomli-2.3.0-cp314-cp314t-win32.whl", hash = "sha256:a1f7f282fe248311650081faafa5f4732bdbfef5d45fe3f2e702fbc6f2d496e0"}, - {file = "tomli-2.3.0-cp314-cp314t-win_amd64.whl", hash = "sha256:70a251f8d4ba2d9ac2542eecf008b3c8a9fc5c3f9f02c56a9d7952612be2fdba"}, - {file = "tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b"}, - {file = "tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549"}, + {file = "tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867"}, + {file = "tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9"}, + {file = "tomli-2.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:551e321c6ba03b55676970b47cb1b73f14a0a4dce6a3e1a9458fd6d921d72e95"}, + {file = "tomli-2.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e3f639a7a8f10069d0e15408c0b96a2a828cfdec6fca05296ebcdcc28ca7c76"}, + {file = "tomli-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = 
"sha256:1b168f2731796b045128c45982d3a4874057626da0e2ef1fdd722848b741361d"}, + {file = "tomli-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:133e93646ec4300d651839d382d63edff11d8978be23da4cc106f5a18b7d0576"}, + {file = "tomli-2.4.0-cp311-cp311-win32.whl", hash = "sha256:b6c78bdf37764092d369722d9946cb65b8767bfa4110f902a1b2542d8d173c8a"}, + {file = "tomli-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:d3d1654e11d724760cdb37a3d7691f0be9db5fbdaef59c9f532aabf87006dbaa"}, + {file = "tomli-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:cae9c19ed12d4e8f3ebf46d1a75090e4c0dc16271c5bce1c833ac168f08fb614"}, + {file = "tomli-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:920b1de295e72887bafa3ad9f7a792f811847d57ea6b1215154030cf131f16b1"}, + {file = "tomli-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6d9a4aee98fac3eab4952ad1d73aee87359452d1c086b5ceb43ed02ddb16b8"}, + {file = "tomli-2.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36b9d05b51e65b254ea6c2585b59d2c4cb91c8a3d91d0ed0f17591a29aaea54a"}, + {file = "tomli-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1c8a885b370751837c029ef9bc014f27d80840e48bac415f3412e6593bbc18c1"}, + {file = "tomli-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8768715ffc41f0008abe25d808c20c3d990f42b6e2e58305d5da280ae7d1fa3b"}, + {file = "tomli-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b438885858efd5be02a9a133caf5812b8776ee0c969fea02c45e8e3f296ba51"}, + {file = "tomli-2.4.0-cp312-cp312-win32.whl", hash = "sha256:0408e3de5ec77cc7f81960c362543cbbd91ef883e3138e81b729fc3eea5b9729"}, + {file = "tomli-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:685306e2cc7da35be4ee914fd34ab801a6acacb061b6a7abca922aaf9ad368da"}, + {file = "tomli-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:5aa48d7c2356055feef06a43611fc401a07337d5b006be13a30f6c58f869e3c3"}, + {file = "tomli-2.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:84d081fbc252d1b6a982e1870660e7330fb8f90f676f6e78b052ad4e64714bf0"}, + {file = "tomli-2.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9a08144fa4cba33db5255f9b74f0b89888622109bd2776148f2597447f92a94e"}, + {file = "tomli-2.4.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c73add4bb52a206fd0c0723432db123c0c75c280cbd67174dd9d2db228ebb1b4"}, + {file = "tomli-2.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fb2945cbe303b1419e2706e711b7113da57b7db31ee378d08712d678a34e51e"}, + {file = "tomli-2.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bbb1b10aa643d973366dc2cb1ad94f99c1726a02343d43cbc011edbfac579e7c"}, + {file = "tomli-2.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4cbcb367d44a1f0c2be408758b43e1ffb5308abe0ea222897d6bfc8e8281ef2f"}, + {file = "tomli-2.4.0-cp313-cp313-win32.whl", hash = "sha256:7d49c66a7d5e56ac959cb6fc583aff0651094ec071ba9ad43df785abc2320d86"}, + {file = "tomli-2.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:3cf226acb51d8f1c394c1b310e0e0e61fecdd7adcb78d01e294ac297dd2e7f87"}, + {file = "tomli-2.4.0-cp313-cp313-win_arm64.whl", hash = "sha256:d20b797a5c1ad80c516e41bc1fb0443ddb5006e9aaa7bda2d71978346aeb9132"}, + {file = "tomli-2.4.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:26ab906a1eb794cd4e103691daa23d95c6919cc2fa9160000ac02370cc9dd3f6"}, + {file = "tomli-2.4.0-cp314-cp314-macosx_11_0_arm64.whl", hash = 
"sha256:20cedb4ee43278bc4f2fee6cb50daec836959aadaf948db5172e776dd3d993fc"}, + {file = "tomli-2.4.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:39b0b5d1b6dd03684b3fb276407ebed7090bbec989fa55838c98560c01113b66"}, + {file = "tomli-2.4.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a26d7ff68dfdb9f87a016ecfd1e1c2bacbe3108f4e0f8bcd2228ef9a766c787d"}, + {file = "tomli-2.4.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:20ffd184fb1df76a66e34bd1b36b4a4641bd2b82954befa32fe8163e79f1a702"}, + {file = "tomli-2.4.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:75c2f8bbddf170e8effc98f5e9084a8751f8174ea6ccf4fca5398436e0320bc8"}, + {file = "tomli-2.4.0-cp314-cp314-win32.whl", hash = "sha256:31d556d079d72db7c584c0627ff3a24c5d3fb4f730221d3444f3efb1b2514776"}, + {file = "tomli-2.4.0-cp314-cp314-win_amd64.whl", hash = "sha256:43e685b9b2341681907759cf3a04e14d7104b3580f808cfde1dfdb60ada85475"}, + {file = "tomli-2.4.0-cp314-cp314-win_arm64.whl", hash = "sha256:3d895d56bd3f82ddd6faaff993c275efc2ff38e52322ea264122d72729dca2b2"}, + {file = "tomli-2.4.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:5b5807f3999fb66776dbce568cc9a828544244a8eb84b84b9bafc080c99597b9"}, + {file = "tomli-2.4.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c084ad935abe686bd9c898e62a02a19abfc9760b5a79bc29644463eaf2840cb0"}, + {file = "tomli-2.4.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f2e3955efea4d1cfbcb87bc321e00dc08d2bcb737fd1d5e398af111d86db5df"}, + {file = "tomli-2.4.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e0fe8a0b8312acf3a88077a0802565cb09ee34107813bba1c7cd591fa6cfc8d"}, + {file = "tomli-2.4.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:413540dce94673591859c4c6f794dfeaa845e98bf35d72ed59636f869ef9f86f"}, + {file = "tomli-2.4.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0dc56fef0e2c1c470aeac5b6ca8cc7b640bb93e92d9803ddaf9ea03e198f5b0b"}, + {file = "tomli-2.4.0-cp314-cp314t-win32.whl", hash = "sha256:d878f2a6707cc9d53a1be1414bbb419e629c3d6e67f69230217bb663e76b5087"}, + {file = "tomli-2.4.0-cp314-cp314t-win_amd64.whl", hash = "sha256:2add28aacc7425117ff6364fe9e06a183bb0251b03f986df0e78e974047571fd"}, + {file = "tomli-2.4.0-cp314-cp314t-win_arm64.whl", hash = "sha256:2b1e3b80e1d5e52e40e9b924ec43d81570f0e7d09d11081b797bc4692765a3d4"}, + {file = "tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a"}, + {file = "tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c"}, ] [[package]] @@ -7376,29 +7440,29 @@ files = [ [[package]] name = "tox" -version = "4.32.0" +version = "4.34.1" description = "tox is a generic virtualenv management and test command line tool" optional = true python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\"" files = [ - {file = "tox-4.32.0-py3-none-any.whl", hash = "sha256:451e81dc02ba8d1ed20efd52ee409641ae4b5d5830e008af10fe8823ef1bd551"}, - {file = "tox-4.32.0.tar.gz", hash = "sha256:1ad476b5f4d3679455b89a992849ffc3367560bbc7e9495ee8a3963542e7c8ff"}, + {file = "tox-4.34.1-py3-none-any.whl", hash = "sha256:5610d69708bab578d618959b023f8d7d5d3386ed14a2392aeebf9c583615af60"}, + {file = "tox-4.34.1.tar.gz", hash = "sha256:ef1e82974c2f5ea02954d590ee0b967fad500c3879b264ea19efb9a554f3cc60"}, ] [package.dependencies] 
-cachetools = ">=6.2" +cachetools = ">=6.2.4" chardet = ">=5.2" colorama = ">=0.4.6" -filelock = ">=3.20" +filelock = ">=3.20.2" packaging = ">=25" -platformdirs = ">=4.5" +platformdirs = ">=4.5.1" pluggy = ">=1.6" -pyproject-api = ">=1.9.1" +pyproject-api = ">=1.10" tomli = {version = ">=2.3", markers = "python_version < \"3.11\""} typing-extensions = {version = ">=4.15", markers = "python_version < \"3.11\""} -virtualenv = ">=20.34" +virtualenv = ">=20.35.4" [[package]] name = "tqdm" @@ -7463,15 +7527,15 @@ test = ["absl-py (>=1.4.0)", "jax (>=0.4.23)", "omegaconf (>=2.0.0)", "pydantic [[package]] name = "trimesh" -version = "4.10.1" +version = "4.11.1" description = "Import, export, process, analyze and view triangular meshes." optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"tests\" or extra == \"trimesh\" or extra == \"heatcharge\"" files = [ - {file = "trimesh-4.10.1-py3-none-any.whl", hash = "sha256:4e81fae696683dfe912ef54ce124869487d35d267b87e10fe07fc05ab62aaadb"}, - {file = "trimesh-4.10.1.tar.gz", hash = "sha256:2067ebb8dcde0d7f00c2a85bfcae4aa891c40898e5f14232592429025ee2c593"}, + {file = "trimesh-4.11.1-py3-none-any.whl", hash = "sha256:bcc082ced94610ecd2c09b031431d0f3ad74352525e23a41b5688a2897b3e3e0"}, + {file = "trimesh-4.11.1.tar.gz", hash = "sha256:9a10040ca5d1c4438e0b7af94433edf6b043f5204393fc97bb85c9159a8bf21e"}, ] [package.dependencies] @@ -7572,14 +7636,14 @@ dev = ["flake8", "flake8-annotations", "flake8-bandit", "flake8-bugbear", "flake [[package]] name = "urllib3" -version = "2.6.2" +version = "2.6.3" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "urllib3-2.6.2-py3-none-any.whl", hash = "sha256:ec21cddfe7724fc7cb4ba4bea7aa8e2ef36f607a4bab81aa6ce42a13dc3f03dd"}, - {file = "urllib3-2.6.2.tar.gz", hash = "sha256:016f9c98bb7e98085cb2b4b17b87d2c702975664e4f060c6532e64d1c1a5e797"}, + {file = "urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4"}, + {file = "urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed"}, ] [package.extras] @@ -7590,20 +7654,20 @@ zstd = ["backports-zstd (>=1.0.0) ; python_version < \"3.14\""] [[package]] name = "virtualenv" -version = "20.35.4" +version = "20.36.1" description = "Virtual Python Environment builder" optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"tests\" or extra == \"scikit-rf\"" files = [ - {file = "virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b"}, - {file = "virtualenv-20.35.4.tar.gz", hash = "sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c"}, + {file = "virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f"}, + {file = "virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba"}, ] [package.dependencies] distlib = ">=0.3.7,<1" -filelock = ">=3.12.2,<4" +filelock = {version = ">=3.20.1,<4", markers = "python_version >= \"3.10\""} platformdirs = ">=3.9.1,<5" typing-extensions = {version = ">=4.13.2", markers = "python_version < \"3.11\""} @@ -7821,24 +7885,24 @@ type = ["pytest-mypy"] [[package]] name = "zizmor" -version = "1.19.0" +version = "1.22.0" description = "Static analysis for GitHub 
Actions" optional = true python-versions = ">=3.10" groups = ["main"] markers = "extra == \"dev\"" files = [ - {file = "zizmor-1.19.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:7c61760405296e160a35175b5d6e00edfe15360e084153cb40d4ad87f623d2af"}, - {file = "zizmor-1.19.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b7a3c55c0b48a19a23079d24bc73e3f4766747d0a2250187b4b16d6b88d43b5f"}, - {file = "zizmor-1.19.0-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:3a2322a58d3aac92d8bcb8f016715f2ba09628bd09aaab2976c51b9385c2c56c"}, - {file = "zizmor-1.19.0-py3-none-manylinux_2_28_armv7l.whl", hash = "sha256:70bdf165c9c9d1f00d820d752b06d4ab07aeed97d8de751bc3f7c1607c35ac09"}, - {file = "zizmor-1.19.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:1ceb91a99947a8a86eba89828b8879996e88046f8948b94afeb623584d7f6155"}, - {file = "zizmor-1.19.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6830c2b2765fdd98f07ed0cbbaf8fbcac718ba19aa44354cc18d9d87c12d210b"}, - {file = "zizmor-1.19.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e8e7d914944f739ba06db24d856b9dde2815469651a4d4895cc3f6296a9db59b"}, - {file = "zizmor-1.19.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5724a355128ab3904e1992e59c98eb88cc99df1e7593364cb587058f99fc7377"}, - {file = "zizmor-1.19.0-py3-none-win32.whl", hash = "sha256:9299537d37aff38545bfcd93a31c251d9b0af29ae18fbe02999128d0b2909656"}, - {file = "zizmor-1.19.0-py3-none-win_amd64.whl", hash = "sha256:91bd22257ed37b573beb350402c253d255f97f836d37a2f4616df73921a3777a"}, - {file = "zizmor-1.19.0.tar.gz", hash = "sha256:625fd810a976dca7e5b2c43469103dcc0f123504907276f6d41b1e596a374563"}, + {file = "zizmor-1.22.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:80c62a9503a4235091dd076c11925f98f8d0e3fcaa8a6e2e1153e784ff3e0062"}, + {file = "zizmor-1.22.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:05c62d8cc0c16e0c0551e5085ef1f367d3b70e18df6ff1dcb212f6a6355bee4e"}, + {file = "zizmor-1.22.0-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:92d90fbebdcbd865ff8d48a8579a23102c788d09e378b60d3452603eb15e69bc"}, + {file = "zizmor-1.22.0-py3-none-manylinux_2_28_armv7l.whl", hash = "sha256:4f81f527ea5f628537f1a80c63a832f6ee1d529201ecd473f1c06fe28498ecc1"}, + {file = "zizmor-1.22.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:ccc9fe012ca9068add4104fcf75513fb740d41056c238f98dcf97d693feaa743"}, + {file = "zizmor-1.22.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:aad09a0788df8d8e9941d47743029f182434599548a463733240b09983b9870e"}, + {file = "zizmor-1.22.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a2f3f9d2dc341921543374f75cbd99cad130676809854c3c0fb38fa44fb9ebee"}, + {file = "zizmor-1.22.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:8fbc43d9feec0e6f85dfb324956f41a58e649763fe8b269ca1c2c4d61f27764d"}, + {file = "zizmor-1.22.0-py3-none-win32.whl", hash = "sha256:05b906d08b7fa70474ad99b95ae6e332e6d527089821df71b739107488aad303"}, + {file = "zizmor-1.22.0-py3-none-win_amd64.whl", hash = "sha256:2d52bc8367986c046c116134800117bd5ad4816922e890e81b9bd37a7b28c28f"}, + {file = "zizmor-1.22.0.tar.gz", hash = "sha256:b36a24cf52af902fbf5526d9ed6b6e665cc418b7c76c93dd018225a385691200"}, ] [extras] @@ -7858,4 +7922,4 @@ vtk = ["vtk"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "be03db7155e8e244cf8f9b362e2cc25fb71a557bc7c4d1fe19a1ee21358868e7" +content-hash = "b849e1cb129ea8566392dad720649283c89b385a589e3df3b21b161c693436bb" diff --git a/pyproject.toml b/pyproject.toml index 6c34a48e93..e47dfe1408 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tidy3d" -version = "2.10.0" +version = "2.10.2" description = "A fast FDTD solver" authors = ["Tyler Hughes "] license = "LGPLv2+" @@ -23,60 +23,61 @@ include = [{ path = "tidy3d/style.mplstyle", format = ["sdist", "wheel"] }] [tool.poetry.dependencies] python = ">=3.10,<3.14" -pyroots = ">=0.5.0" -xarray = ">=2023.08" -importlib-metadata = ">=6.0.0" +pyroots = ">=0.5.0,<0.6.0" +xarray = ">=2023.08,<2026.0.0" +importlib-metadata = ">=6.0.0,<9.0.0" h5netcdf = "1.0.2" h5py = ">=3.0.0,<3.15" rich = "^13.0" -numpy = "<2.4.0" -matplotlib = "*" +numpy = ">=2.2.6,<2.5.0" +matplotlib = "^3.10.0" shapely = "^2.0" -pandas = "*" -pydantic = "^2.0" -PyYAML = "*" -dask = "*" -toml = "*" +pandas = "^2.2" +pydantic = ">=2.9,<3" +pydantic-settings = "^2.10.0" +PyYAML = "^6.0.3" +dask = "^2025.12.0" +toml = "^0.10.2" tomlkit = "^0.13.2" -autograd = ">=1.7.0" -scipy = "*" +autograd = ">=1.7.0,<2.0.0" +scipy = "^1.14" ### NOT CORE boto3 = "^1.28.0" requests = "^2.31.0" -pyjwt = "*" +pyjwt = "^2.10.1" click = "^8.1.0" -responses = "*" -joblib = "*" -typing-extensions = "*" +responses = "^0.25.8" +joblib = "^1.5.3" +typing-extensions = "^4.15.0" ### END NOT CORE ### Optional dependencies ### # development core ruff = { version = "0.11.11", optional = true } -coverage = { version = "*", optional = true } -dill = { version = "*", optional = true } -ipython = { version = "*", optional = true } -memory_profiler = { version = "*", optional = true } -psutil = { version = "*", optional = true } -pre-commit = { version = ">=4.*", optional = true } -pylint = { version = "*", optional = true } -pytest = { version = ">=8.1", optional = true } -pytest-timeout = { version = "*", optional = true } +coverage = { version = "^7.13.1", optional = true } +dill = { version = "^0.4.0", optional = true } +ipython = { version = "^8.38.0", optional = true } +memory_profiler = { version = "^0.61.0", optional = true } +psutil = { version = "^7.2.1", optional = true } +pre-commit = { version = ">=4,<5", optional = true } +pylint = { version = "^4.0.4", optional = true } +pytest = { version = ">=8.1,<10.0.0", optional = true } +pytest-timeout = { version = "^2.4.0", optional = true } pytest-xdist = "^3.6.1" pytest-cov = "^6.0.0" pytest-env = "^1.1.5" -tox = { version = "*", optional = true } -diff-cover = { version = "*", optional = true } -zizmor = { version = "*", optional = true } +tox = { version = "^4.33.0", optional = true } +diff-cover = { version = "^10.1.0", optional = true } +zizmor = { version = "^1.20.0", optional = true } mypy = { version = "1.13.0", optional = true } # gdstk -gdstk = { version = ">=0.9.49", optional = true } +gdstk = { version = ">=0.9.49,<0.10.0", optional = true } # design bayesian-optimization = { version = "<2", optional = true } pygad = { version = "3.3.1", optional = true } -pyswarms = { version = "*", optional = true } +pyswarms = { version = "^1.3.0", optional = true } # pytorch torch = [ @@ -85,41 +86,41 @@ torch = [ ] # scikit-rf -scikit-rf = { version = "*", optional = true } +scikit-rf = { version = "^1.9.0", optional = true } # trimesh networkx = { version = "^2.6.3", optional = true } rtree = { version = "1.2.0", optional = true } -trimesh = { version = ">=4.6", optional = true } +trimesh = { version = ">=4.6,<5.0.0", optional = true } # docs -jupyter = { version = "*", optional = true } -jinja2 = { version = ">=3.1.2", optional = true } -nbconvert = { version = ">=7.11.0", optional = true } -sphinx = { version = ">=6", optional = true 
} -nbsphinx = { version = ">=0.8.7", optional = true } -sphinx-copybutton = { version = ">=0.5.2", optional = true } -sphinx-book-theme = { version = ">=1.0.1", optional = true } -pydata-sphinx-theme = { version = ">=0.13.3", optional = true } +jupyter = { version = "^1.1.1", optional = true } +jinja2 = { version = ">=3.1.2,<4.0.0", optional = true } +nbconvert = { version = ">=7.11.0,<8.0.0", optional = true } +sphinx = { version = ">=6,<9.0.0", optional = true } +nbsphinx = { version = ">=0.8.7,<0.10.0", optional = true } +sphinx-copybutton = { version = ">=0.5.2,<0.6.0", optional = true } +sphinx-book-theme = { version = ">=1.0.1,<2.0.0", optional = true } +pydata-sphinx-theme = { version = ">=0.13.3,<0.16.0", optional = true } # divparams = {optional = true, git = "https://github.com/daquinteroflex/sphinxcontrib-divparams.git"} # TODO FIX -tmm = { version = "*", optional = true } -grcwa = { version = "*", optional = true } -sphinx-design = { version = "*", optional = true } -sphinx-favicon = { version = "*", optional = true } -sphinx-sitemap = { version = ">=2.5.1", optional = true } -sphinx-notfound-page = { version = "*", optional = true } -sphinx-tabs = { version = "*", optional = true } -nbdime = { version = "*", optional = true } -myst-parser = { version = "*", optional = true } -optax = { version = ">=0.2.2", optional = true } -signac = { version = "*", optional = true } -flax = { version = ">=0.8.2", optional = true } +tmm = { version = "^0.2.0", optional = true } +grcwa = { version = "^0.1.2", optional = true } +sphinx-design = { version = "^0.6.1", optional = true } +sphinx-favicon = { version = "^1.0.1", optional = true } +sphinx-sitemap = { version = ">=2.5.1,<3.0.0", optional = true } +sphinx-notfound-page = { version = "^1.1.0", optional = true } +sphinx-tabs = { version = "^3.4.7", optional = true } +nbdime = { version = "^4.0.2", optional = true } +myst-parser = { version = "^4.0.1", optional = true } +optax = { version = ">=0.2.2,<0.3.0", optional = true } +signac = { version = "^2.3.0", optional = true } +flax = { version = ">=0.8.2,<0.13", optional = true } sax = { version = "^0.11", optional = true } -vtk = { version = ">=9.2.6", optional = true } -sphinxemoji = { version = "*", optional = true } -cma = { version = "*", optional = true } -openpyxl = { version = "*", optional = true } -tidy3d-extras = { version = "2.10.0", optional = true } +vtk = { version = ">=9.2.6,<10.0.0", optional = true } +sphinxemoji = { version = "^0.3.2", optional = true } +cma = { version = "^4.4.1", optional = true } +openpyxl = { version = "^3.1.5", optional = true } +tidy3d-extras = { version = "2.10.2", optional = true } [tool.poetry.extras] dev = [ @@ -175,6 +176,7 @@ dev = [ 'diff-cover', 'openpyxl', 'zizmor', + 'libcst', ] tests = [ 'psutil', @@ -309,9 +311,16 @@ banned-module-level-imports = ["scipy", "matplotlib"] [tool.pytest.ini_options] # TODO: remove --assert=plain when https://github.com/scipy/scipy/issues/22236 is resolved -addopts = "--cov=tidy3d --doctest-modules -n auto --dist worksteal --assert=plain -m 'not numerical'" +# TODO(yaugenst-flex): Revisit adjoint plugin for pydantic v2 +addopts = """ +--doctest-modules -n auto --dist worksteal --assert=plain -m 'not numerical and not perf' \ +--ignore=tests/test_plugins/test_adjoint.py \ +--ignore=tidy3d/plugins/adjoint/ +""" markers = [ "numerical: marks numerical tests for adjoint gradients that require running simulations (deselect with '-m \"not numerical\"')", "perf: marks
tests which test the runtime of operations (deselect with '-m \"not perf\"')", + "slow: marks tests as slow (deselect with -m 'not slow')", ] env = ["MPLBACKEND=Agg", "OMP_NUM_THREADS=1", "TIDY3D_MICROWAVE__SUPPRESS_RF_LICENSE_WARNING=true"] doctest_optionflags = "NORMALIZE_WHITESPACE ELLIPSIS" @@ -332,82 +342,24 @@ python_files = "*.py" [tool.mypy] python_version = "3.10" files = [ - "tidy3d/web", - "tidy3d/config", - "tidy3d/material_library", - "tidy3d/components/geometry" + "tidy3d", ] ignore_missing_imports = true follow_imports = "skip" disallow_untyped_defs = true disable_error_code = [ - "abstract", - "annotation-unchecked", "arg-type", - "assert-type", "assignment", "attr-defined", - "await-not-async", "call-arg", "call-overload", - "comparison-overlap", - "dict-item", - "empty-body", - "exit-return", - "explicit-override", - "func-returns-value", - "has-type", - "ignore-without-code", - "import", - "import-not-found", "import-untyped", "index", - "list-item", - "literal-required", - "method-assign", "misc", - "mutable-override", - "name-defined", - "name-match", - "narrowed-type-not-subtype", - "no-any-return", - "no-any-unimported", - "no-overload-impl", - "no-redef", - "no-untyped-call", "operator", - "overload-cannot-match", - "overload-overlap", "override", - "possibly-undefined", - "prop-decorator", - "redundant-cast", - "redundant-expr", - "redundant-self", - "return", "return-value", - "safe-super", - "str-bytes-safe", - "str-format", - "syntax", - "top-level-await", - "truthy-bool", - "truthy-function", - "truthy-iterable", - "type-abstract", - "type-arg", - "type-var", - "typeddict-item", - "typeddict-readonly-mutated", - "typeddict-unknown-key", - "unimported-reveal", "union-attr", - "unreachable", - "unused-awaitable", - "unused-coroutine", - "unused-ignore", - "used-before-def", - "valid-newtype", "valid-type", "var-annotated", ] diff --git a/schemas/EMESimulation.json b/schemas/EMESimulation.json index 9a8c933b07..373bf4723a 100644 --- a/schemas/EMESimulation.json +++ b/schemas/EMESimulation.json @@ -13001,7 +13001,7 @@ "type": "string" }, "version": { - "default": "2.10.0", + "default": "2.10.2", "type": "string" } }, diff --git a/schemas/HeatChargeSimulation.json b/schemas/HeatChargeSimulation.json index e1514a8e0d..400ba8370e 100644 --- a/schemas/HeatChargeSimulation.json +++ b/schemas/HeatChargeSimulation.json @@ -10146,7 +10146,7 @@ "type": "string" }, "version": { - "default": "2.10.0", + "default": "2.10.2", "type": "string" } }, diff --git a/schemas/HeatSimulation.json b/schemas/HeatSimulation.json index f1167d86f1..56ccff0d17 100644 --- a/schemas/HeatSimulation.json +++ b/schemas/HeatSimulation.json @@ -10146,7 +10146,7 @@ "type": "string" }, "version": { - "default": "2.10.0", + "default": "2.10.2", "type": "string" } }, diff --git a/schemas/ModeSimulation.json b/schemas/ModeSimulation.json index dda93ce073..0c76b54488 100644 --- a/schemas/ModeSimulation.json +++ b/schemas/ModeSimulation.json @@ -12815,7 +12815,7 @@ "type": "string" }, "version": { - "default": "2.10.0", + "default": "2.10.2", "type": "string" } }, diff --git a/schemas/Simulation.json b/schemas/Simulation.json index 9b86e9efdb..fd2dcd7745 100644 --- a/schemas/Simulation.json +++ b/schemas/Simulation.json @@ -17566,7 +17566,7 @@ "type": "string" }, "version": { - "default": "2.10.0", + "default": "2.10.2", "type": "string" } }, diff --git a/schemas/TerminalComponentModeler.json b/schemas/TerminalComponentModeler.json index 64e8bfabf9..64d738d724 100644 --- 
a/schemas/TerminalComponentModeler.json +++ b/schemas/TerminalComponentModeler.json @@ -16177,7 +16177,7 @@ "type": "string" }, "version": { - "default": "2.10.0", + "default": "2.10.2", "type": "string" } }, diff --git a/scripts/ensure_imports_from_common.py b/scripts/ensure_imports_from_common.py new file mode 100644 index 0000000000..f44e93e5ff --- /dev/null +++ b/scripts/ensure_imports_from_common.py @@ -0,0 +1,114 @@ +"""Ensure tidy3d._common modules avoid importing from tidy3d outside tidy3d._common.""" + +from __future__ import annotations + +import argparse +import ast +import sys +from collections.abc import Iterable +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(frozen=True) +class ImportViolation: + file: str + line: int + statement: str + + +def parse_args(argv: Iterable[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Ensure tidy3d._common does not import from tidy3d modules outside tidy3d._common." + ) + ) + parser.add_argument( + "--root", + default="tidy3d/_common", + help="Root directory to scan (relative to repo root).", + ) + return parser.parse_args(argv) + + +def main(argv: Iterable[str]) -> None: + args = parse_args(argv) + repo_root = Path.cwd().resolve() + root = (repo_root / args.root).resolve() + if not root.exists(): + print(f"No directory found at {root}. Skipping check.") + return + + violations: list[ImportViolation] = [] + for path in sorted(root.rglob("*.py")): + violations.extend(_violations_in_file(path, repo_root)) + + if violations: + print("Invalid tidy3d imports found in tidy3d._common:") + for violation in violations: + print(f"{violation.file}:{violation.line}: {violation.statement}") + raise SystemExit(1) + + print("No invalid tidy3d imports found in tidy3d._common.") + + +def _violations_in_file(path: Path, repo_root: Path) -> list[ImportViolation]: + source = path.read_text(encoding="utf-8") + try: + tree = ast.parse(source) + except SyntaxError as exc: + raise SystemExit(f"Syntax error parsing {path}: {exc}") from exc + + rel_path = str(path.relative_to(repo_root)) + violations: list[ImportViolation] = [] + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + name = alias.name + if name == "tidy3d" or ( + name.startswith("tidy3d.") and not name.startswith("tidy3d._common") + ): + violations.append( + ImportViolation( + file=rel_path, + line=node.lineno, + statement=_statement(source, node), + ) + ) + elif isinstance(node, ast.ImportFrom): + if node.level: + continue + module = node.module + if not module: + continue + if module == "tidy3d": + for alias in node.names: + if alias.name != "_common": + violations.append( + ImportViolation( + file=rel_path, + line=node.lineno, + statement=_statement(source, node), + ) + ) + continue + if module.startswith("tidy3d.") and not module.startswith("tidy3d._common"): + violations.append( + ImportViolation( + file=rel_path, + line=node.lineno, + statement=_statement(source, node), + ) + ) + return violations + + +def _statement(source: str, node: ast.AST) -> str: + segment = ast.get_source_segment(source, node) + if segment: + return " ".join(segment.strip().splitlines()) + return node.__class__.__name__ + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/tests/config/test_legacy_env.py b/tests/config/test_legacy_env.py index cb9cba06a9..0fe27c121e 100644 --- a/tests/config/test_legacy_env.py +++ b/tests/config/test_legacy_env.py @@ -74,11 +74,9 @@ def 
test_env_vars_follow_profile_switch(mock_config_dir, monkeypatch, config_man def test_web_core_environment_reexports(): """Legacy `tidy3d.web.core.environment` exports remain available via config shim.""" - - import tidy3d.web as web + from tidy3d._common.web.core import environment from tidy3d.config import Env as ConfigEnv - environment = web.core.environment assert environment.Env is ConfigEnv with warnings.catch_warnings(record=True) as caught: diff --git a/tests/config/test_loader.py b/tests/config/test_loader.py index 2131357fce..7e161e8608 100644 --- a/tests/config/test_loader.py +++ b/tests/config/test_loader.py @@ -5,8 +5,8 @@ from click.testing import CliRunner from pydantic import Field +from tidy3d._common.config import loader as config_loader # import from common as it is patched from tidy3d.config import get_manager, reload_config -from tidy3d.config import loader as config_loader from tidy3d.config import registry as config_registry from tidy3d.config.legacy import finalize_legacy_migration from tidy3d.config.loader import migrate_legacy_config diff --git a/tests/conftest.py b/tests/conftest.py index c4acd71cae..1154d0f3f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import sys from pathlib import Path import autograd @@ -113,3 +114,46 @@ def dir_name(request): def create_directory(dir_name): if dir_name is not None: directory = Path(dir_name).mkdir(parents=True, exist_ok=True) + + +class OutputTee: + """Helper class to write to two streams at once.""" + + def __init__(self, original_stdout, stderr): + self.original_stdout = original_stdout + self.stderr = stderr + + def write(self, message): + # Write to the original stdout (so pytest capture works) + self.original_stdout.write(message) + # Write to stderr (so you see it immediately) + # We generally want to flush immediately for debug prints + self.stderr.write(message) + self.stderr.flush() + + def flush(self): + self.original_stdout.flush() + self.stderr.flush() + + def __getattr__(self, attr): + # Pass any other method calls (like isatty) to the original stream + return getattr(self.original_stdout, attr) + + +@pytest.fixture() +def redirect_stdout_to_stderr(request): + """ + Automatically wraps sys.stdout to write to both stdout and stderr. + This ensures output is visible during parallel execution without + breaking pytest capturing. + """ + # 1. Capture the current stdout (which might be pytest's capture buffer) + original_stdout = sys.stdout + + # 2. Replace stdout with our Tee + sys.stdout = OutputTee(original_stdout, sys.stderr) + + yield + + # 3. Restore original stdout after test finishes + sys.stdout = original_stdout diff --git a/tests/test_components/autograd/numerical/test_autograd_box_polyslab_numerical.py b/tests/test_components/autograd/numerical/test_autograd_box_polyslab_numerical.py index 9ab680efc8..2816ed71fe 100644 --- a/tests/test_components/autograd/numerical/test_autograd_box_polyslab_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_box_polyslab_numerical.py @@ -2,7 +2,6 @@ # and PolySlab geometries. 
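# NOTE: a minimal usage sketch for the new conftest fixture above (the test name
# below is hypothetical, not part of this diff). Requesting the fixture swaps
# sys.stdout for an OutputTee for the duration of the test, so prints reach both
# pytest's capture buffer and the worker's live stderr under `-n auto`:
#
#     def test_example(redirect_stdout_to_stderr):
#         print("visible immediately under pytest-xdist")  # mirrored by OutputTee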
from __future__ import annotations -import sys from pathlib import Path import autograd.numpy as anp @@ -24,7 +23,6 @@ FINITE_DIFFERENCE_STEP = MESH_SPACING_UM LOCAL_GRADIENT = True VERBOSE = False -SHOW_PRINT_STATEMENTS = True PLOT_FD_ADJ_COMPARISON = False SAVE_OUTPUT_DATA = True COMPARE_TO_FINITE_DIFFERENCE = True @@ -37,9 +35,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - def angled_overlap_deg(v1, v2): norm_v1 = np.linalg.norm(v1) @@ -336,7 +331,7 @@ def squeeze_dimension(array: np.ndarray, is_3d: bool, infinite_dim: int | None) ) @pytest.mark.parametrize("shift_box_center", (True, False)) def test_box_and_polyslab_gradients_match( - is_3d, infinite_dim_2d, shift_box_center, numerical_case_dir + is_3d, infinite_dim_2d, shift_box_center, numerical_case_dir, redirect_stdout_to_stderr ): """Test that the box and polyslab gradients match for rectangular slab geometries. Allow comparison as well to finite difference values.""" diff --git a/tests/test_components/autograd/numerical/test_autograd_conductivity_numerical.py b/tests/test_components/autograd/numerical/test_autograd_conductivity_numerical.py index c5dc51256d..455b4bc81b 100644 --- a/tests/test_components/autograd/numerical/test_autograd_conductivity_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_conductivity_numerical.py @@ -19,7 +19,6 @@ from __future__ import annotations import operator -import sys import autograd as ag import matplotlib.pylab as plt @@ -38,7 +37,6 @@ LOCAL_GRADIENT = True VERBOSE = False NUMERICAL_RESULTS_SUBDIR = "numerical_conductivity_test" -SHOW_PRINT_STATEMENTS = False RMS_THRESHOLD = 0.6 @@ -47,10 +45,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - - # Constants for conductivity testing CONDUCTIVITY_SEED = 0.01 MESH_FACTOR_DESIGN = 30.0 @@ -327,7 +321,7 @@ def flux(sim_data): @pytest.mark.numerical @pytest.mark.parametrize("conductivity_data_test_parameters", conductivity_data_test_parameters) def test_finite_difference_conductivity_data( - conductivity_data_test_parameters, rng, numerical_case_dir + conductivity_data_test_parameters, rng, numerical_case_dir, redirect_stdout_to_stderr ): """Test autograd conductivity gradients by comparing to numerical finite difference. 
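For reference, a minimal sketch of what the new scripts/ensure_imports_from_common.py guard accepts and flags. The module below is hypothetical and only illustrates the AST rules: relative imports are skipped, and `_common` is the only name importable from bare `tidy3d`.

# tidy3d/_common/_example.py (hypothetical, for illustration only)
from tidy3d._common.config import loader  # OK: target stays inside tidy3d._common
from .core import environment             # OK: relative import, skipped by the check
from tidy3d import _common                # OK: the one allowed name from bare tidy3d
import tidy3d.components.geometry         # flagged: tidy3d module outside tidy3d._common
from tidy3d import web                    # flagged: non-_common name from bare tidy3d

Run as `python scripts/ensure_imports_from_common.py` from the repo root (optionally with `--root tidy3d/_common`); each violation is printed as `file:line: statement` and the script exits non-zero.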
diff --git a/tests/test_components/autograd/numerical/test_autograd_mode_polyslab_numerical.py b/tests/test_components/autograd/numerical/test_autograd_mode_polyslab_numerical.py index f8433e617f..896b2cc508 100644 --- a/tests/test_components/autograd/numerical/test_autograd_mode_polyslab_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_mode_polyslab_numerical.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator -import sys import autograd as ag import matplotlib.pylab as plt @@ -21,7 +20,6 @@ LOCAL_GRADIENT = True VERBOSE = False NUMERICAL_RESULTS_SUBDIR = "numerical_mode_polyslab_test" -SHOW_PRINT_STATEMENTS = False NUM_MODE_MONITOR_FREQUENCIES = 4 @@ -32,9 +30,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - MESH_FACTOR_DESIGN = 60.0 @@ -271,8 +266,6 @@ def objective(vertices): mesh_wvls_um = [1.55, 1.55, 10 * 1.55, 10 * 1.55] adj_wvls_um = [1.5, 2.0, 10 * 1.55, 10 * 2.0] -# mesh_wvls_um = [1.55, 1.55, 10 * 1.55, 10 * 1.55] -# adj_wvls_um = [2.0, 2.0, 10 * 1.55, 10 * 2.0] geometry_sizes_wvl = [(3.0, 3.0, MODE_LAYER_HEIGHT_WVL)] polyslab_indices = np.linspace(SUBSTRATE_INDEX, WG_INDEX, 5) @@ -302,7 +295,9 @@ def objective(vertices): @pytest.mark.numerical @pytest.mark.parametrize("mode_data_test_parameters", mode_data_test_parameters) -def test_finite_difference_mode_data_polyslab(mode_data_test_parameters, rng, numerical_case_dir): +def test_finite_difference_mode_data_polyslab( + mode_data_test_parameters, rng, numerical_case_dir, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for ModeData in combination with polyslab by""" """comparing them to numerical finite difference.""" diff --git a/tests/test_components/autograd/numerical/test_autograd_numerical.py b/tests/test_components/autograd/numerical/test_autograd_numerical.py index a785845d3b..d1e4af6358 100644 --- a/tests/test_components/autograd/numerical/test_autograd_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_numerical.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator -import sys import autograd as ag import matplotlib.pylab as plt @@ -18,10 +17,9 @@ SAVE_FD_ADJ_DATA = False SAVE_FD_LOC = 0 SAVE_ADJ_LOC = 1 -LOCAL_GRADIENT = False +LOCAL_GRADIENT = True VERBOSE = False NUMERICAL_RESULTS_SUBDIR = "numerical_field_test" -SHOW_PRINT_STATEMENTS = False RMS_THRESHOLD = 0.25 @@ -30,9 +28,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - FINITE_DIFF_PERM_SEED = 1.5**2 MESH_FACTOR_DESIGN = 30.0 @@ -219,7 +214,9 @@ def flux(sim_data): @pytest.mark.numerical @pytest.mark.parametrize("field_data_test_parameters", field_data_test_parameters) -def test_finite_difference_field_data(field_data_test_parameters, rng, numerical_case_dir): +def test_finite_difference_field_data( + field_data_test_parameters, rng, numerical_case_dir, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for FieldData by""" """comparing them to numerical finite difference.""" diff --git a/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py b/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py index 1ef3607ad2..6fb612dafd 100644 --- a/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_periodic_numerical.py @@ -2,7 +2,6 
@@ from __future__ import annotations import operator -import sys import autograd as ag import matplotlib.pylab as plt @@ -18,10 +17,9 @@ SAVE_FD_ADJ_DATA = True SAVE_FD_LOC = 0 SAVE_ADJ_LOC = 1 -LOCAL_GRADIENT = False +LOCAL_GRADIENT = True VERBOSE = False NUMERICAL_RESULTS_SUBDIR = "numerical_periodic_test" -SHOW_PRINT_STATEMENTS = False RMS_THRESHOLD = 0.25 @@ -30,9 +28,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - FINITE_DIFF_PERM_SEED = 1.5**2 MESH_FACTOR_DESIGN = 30.0 @@ -266,7 +261,9 @@ def transmission_order_pol_amp_sq(sim_data): @pytest.mark.numerical @pytest.mark.parametrize("periodic_test_parameters", periodic_test_parameters) -def test_finite_difference_diffraction_data(periodic_test_parameters, rng, numerical_case_dir): +def test_finite_difference_diffraction_data( + periodic_test_parameters, rng, numerical_case_dir, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for DiffractionData by""" """comparing them to numerical finite difference.""" @@ -410,19 +407,17 @@ def test_finite_difference_diffraction_data(periodic_test_parameters, rng, numer results_dir.mkdir(parents=True, exist_ok=True) save_path = results_dir / f"results_{save_idx}.npy" - try: - assert rms_error < RMS_THRESHOLD * fd_mag, "RMS error magnitude too large" - finally: - if save_path is not None: - np.save(save_path, test_results) - - test_number += 1 - if PLOT_FD_ADJ_COMPARISON: plt.plot(pattern_dot_adj_gradient, color="g", linewidth=2.0) plt.plot(fd_grad, color="b", linewidth=1.5, linestyle="--") plt.title(f"Gradient for objective: {eval_fn_name}") - plt.legend(["Finite difference", "Adjoint"]) + plt.legend(["Adjoint", "Finite difference"]) plt.xlabel("Sample number") plt.ylabel("Gradient value") plt.show() + + try: + assert rms_error < RMS_THRESHOLD * fd_mag, "RMS error magnitude too large" + finally: + if save_path is not None: + np.save(save_path, test_results) diff --git a/tests/test_components/autograd/numerical/test_autograd_polyslab_sidewall_numerical.py b/tests/test_components/autograd/numerical/test_autograd_polyslab_sidewall_numerical.py index 23280c4a78..7a1964bb71 100644 --- a/tests/test_components/autograd/numerical/test_autograd_polyslab_sidewall_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_polyslab_sidewall_numerical.py @@ -4,6 +4,7 @@ import numpy as np import pytest from autograd import value_and_grad +from autograd.tracer import getval import tidy3d as td import tidy3d.web as web @@ -17,6 +18,7 @@ MSPW = 20 # min steps per wavelength for auto grid UNIFORM_DX = None # set to float to force uniform grid DOMAIN = (4.0, 4.0, 4.0) +LOCAL_GRADIENT = True def _build_sim(theta_deg: float, axis: int, reference_plane: str) -> td.Simulation: @@ -73,23 +75,25 @@ def _build_sim(theta_deg: float, axis: int, reference_plane: str) -> td.Simulati def _objective(theta_deg: float, axis: int, reference_plane: str, case_dir, verbose: bool) -> float: sim = _build_sim(theta_deg, axis=axis, reference_plane=reference_plane) - task_name = f"obj_axis{axis}_ref{reference_plane}_t{float(theta_deg):+0.3f}" + task_name = f"obj_axis{axis}_ref{reference_plane}_t{float(getval(theta_deg)):+0.3f}" out_path = case_dir / f"{task_name}.hdf5" data = web.run( sim, task_name=task_name, - local_gradient=True, + local_gradient=LOCAL_GRADIENT, verbose=verbose, path=str(out_path), ) - return float(data["field"].flux.item()) + return data["field"].flux.item() @pytest.mark.numerical 
@pytest.mark.parametrize("axis", AXES) @pytest.mark.parametrize("reference_plane", REF_PLANES) @pytest.mark.parametrize("theta0_deg", THETAS_DEG) -def test_autograd_polyslab_sidewall_vs_fd(axis, reference_plane, theta0_deg, numerical_case_dir): +def test_autograd_polyslab_sidewall_vs_fd( + axis, reference_plane, theta0_deg, numerical_case_dir, redirect_stdout_to_stderr +): """Adjoint dJ/dtheta matches centered FD for PolySlab.sidewall_angle across axes/ref planes.""" verbose = False @@ -113,7 +117,7 @@ def test_autograd_polyslab_sidewall_vs_fd(axis, reference_plane, theta0_deg, num datas = web.run_async( sims, path_dir=str(fd_dir), - local_gradient=True, + local_gradient=LOCAL_GRADIENT, verbose=verbose, ) @@ -121,5 +125,15 @@ def test_autograd_polyslab_sidewall_vs_fd(axis, reference_plane, theta0_deg, num obj_minus = float(datas[f"minus_{uid}"]["field"].flux.item()) grad_fd = (obj_plus - obj_minus) / (2 * H) + print("\n" * 3) + print("-" * 20) + print(f"axis: {axis}") + print(f"reference_plane: {reference_plane}") + print(f"theta0_deg: {theta0_deg}") + print(f"Grad (adjoint): {grad_adj}") + print(f"Grad (finite difference): {grad_fd}") + print("-" * 20) + print("\n" * 3) + assert np.isfinite(grad_adj) assert np.isclose(grad_adj, grad_fd, rtol=RTOL, atol=ATOL) diff --git a/tests/test_components/autograd/numerical/test_autograd_polyslab_trianglemesh_numerical.py b/tests/test_components/autograd/numerical/test_autograd_polyslab_trianglemesh_numerical.py index 3f5bb0614f..0def8371c9 100644 --- a/tests/test_components/autograd/numerical/test_autograd_polyslab_trianglemesh_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_polyslab_trianglemesh_numerical.py @@ -2,8 +2,6 @@ # PolySlab and TriangleMesh geometries representing the same rectangular slab. from __future__ import annotations -import sys - import autograd.numpy as anp import numpy as np import pytest @@ -32,7 +30,6 @@ FINITE_DIFFERENCE_STEP = MESH_SPACING_UM LOCAL_GRADIENT = True VERBOSE = False -SHOW_PRINT_STATEMENTS = True PLOT_FD_ADJ_COMPARISON = False SAVE_OUTPUT_DATA = True COMPARE_TO_FINITE_DIFFERENCE = True @@ -77,9 +74,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - def _triangles_from_params(params, box_center): params_arr = anp.array(params) @@ -151,7 +145,7 @@ def objective(parameters): ) @pytest.mark.parametrize("shift_box_center", (True, False)) def test_polyslab_and_trianglemesh_gradients_match( - is_3d, infinite_dim_2d, shift_box_center, tmp_path + is_3d, infinite_dim_2d, shift_box_center, tmp_path, redirect_stdout_to_stderr ): """Test that the triangle mesh and polyslab gradients match for rectangular slab geometries. Allow comparison as well to finite difference values.""" diff --git a/tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py b/tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py index b455576d39..63f9664ee5 100644 --- a/tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py @@ -2,7 +2,6 @@ # geometries, comparing to finite differences for validation. 
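# NOTE: a schematic of the check these numerical tests share (names illustrative,
# not part of this diff): the adjoint gradient from autograd is validated against
# a centered finite difference of the same scalar objective.
#
#     import autograd
#
#     def check_gradient(objective, x0, h=1e-3, rtol=0.25):
#         grad_adj = autograd.grad(objective)(x0)  # adjoint gradient via autograd
#         grad_fd = (objective(x0 + h) - objective(x0 - h)) / (2 * h)  # centered FD
#         assert abs(grad_adj - grad_fd) <= rtol * abs(grad_fd)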
from __future__ import annotations -import sys from collections.abc import Sequence from pathlib import Path from typing import Callable @@ -38,7 +37,6 @@ ICOSAHEDRON_SUBDIVISIONS = 3 LOCAL_GRADIENT = True VERBOSE = False -SHOW_PRINT_STATEMENTS = True SAVE_OUTPUT_DATA = True ANGLE_OVERLAP_FD_ADJ_THRESH_DEG = 10.0 VERTEX_FD_STEP = 1e-3 @@ -50,9 +48,6 @@ freqs = td.C_0 / np.linspace(0.6, 0.7, 101) -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - def make_base_simulation( radii: list[float], *, extra_structures: Sequence[td.Structure] | None = None @@ -228,7 +223,7 @@ def finite_difference_params(objective, params: anp.ndarray, finite_diff_step) - valid_indices.append(idx) objectives = objective(anp.stack(perturbations)) - objectives = np.asarray(objectives, dtype=float) + objectives = np.squeeze(np.asarray(objectives, dtype=float)) fd = np.zeros_like(np.asarray(params, dtype=float)) for pair_idx, param_idx in enumerate(valid_indices): obj_up = objectives[2 * pair_idx] @@ -326,7 +321,7 @@ def objective(parameters): @pytest.mark.parametrize("scale_axis", (0,)) @pytest.mark.parametrize("overlap_cube", (False,)) def test_sphere_triangles_match_fd( - scale_factor, scale_axis, overlap_cube, tmp_path, numerical_case_dir + scale_factor, scale_axis, overlap_cube, tmp_path, numerical_case_dir, redirect_stdout_to_stderr ): """ Compares FD gradients with gradients from _compute_derivatives in TriangleMesh. @@ -397,7 +392,9 @@ def test_sphere_triangles_match_fd( @pytest.mark.skip -def test_grad_insensitive_to_face_splitting(tmp_path, numerical_case_dir): +def test_grad_insensitive_to_face_splitting( + tmp_path, numerical_case_dir, redirect_stdout_to_stderr +): scale_factor = 1 scale_axis = 0 @@ -493,7 +490,7 @@ def test_grad_insensitive_to_face_splitting(tmp_path, numerical_case_dir): @pytest.mark.parametrize("scale_axis", SCALE_AXES) @pytest.mark.parametrize("overlap_cube", (False, True)) def test_triangle_sphere_fd_step_sweep_ref( - tmp_path, scale_factor, scale_axis, overlap_cube, numerical_case_dir + tmp_path, scale_factor, scale_axis, overlap_cube, numerical_case_dir, redirect_stdout_to_stderr ): global SPHERE_RADIUS_UM SPHERE_RADIUS_UM = SPHERE_RADIUS_UM * 4 diff --git a/tests/test_components/autograd/numerical/test_autograd_symmetry_numerical.py b/tests/test_components/autograd/numerical/test_autograd_symmetry_numerical.py index f51700f10c..ff08bfd77a 100644 --- a/tests/test_components/autograd/numerical/test_autograd_symmetry_numerical.py +++ b/tests/test_components/autograd/numerical/test_autograd_symmetry_numerical.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator -import sys from pathlib import Path import autograd as ag @@ -20,7 +19,6 @@ SAVE_ADJ_LOC = 1 LOCAL_GRADIENT = False VERBOSE = False -SHOW_PRINT_STATEMENTS = True RMS_THRESHOLD = 0.25 @@ -29,9 +27,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - FINITE_DIFF_PERM_SEED = 1.5**2 MESH_FACTOR_DESIGN = 30.0 @@ -229,7 +224,9 @@ def flux(sim_data): @pytest.mark.numerical @pytest.mark.parametrize("field_symmetry_test_parameters", field_symmetry_test_parameters) -def test_adjoint_difference_symmetry(field_symmetry_test_parameters, rng, numerical_case_dir): +def test_adjoint_difference_symmetry( + field_symmetry_test_parameters, rng, numerical_case_dir, redirect_stdout_to_stderr +): """Test the gradient is not affected by symmetry when using field sources.""" num_tests = 0 @@ -334,10 +331,9 @@ def 
test_adjoint_difference_symmetry(field_symmetry_test_parameters, rng, numeri mag_compare = np.sqrt(np.mean(grad_data**2)) rms_error = np.sqrt(np.mean((grad_data_base - grad_data) ** 2)) - if SHOW_PRINT_STATEMENTS: - print(f"Testing {eval_fn_name} objective") - print(f"Symmetry comparison: {symmetries[0]}, {symmetries[idx]}") - print(f"RMS error (normalized): {rms_error / np.sqrt(mag_base * mag_compare)}") + print(f"Testing {eval_fn_name} objective") + print(f"Symmetry comparison: {symmetries[0]}, {symmetries[idx]}") + print(f"RMS error (normalized): {rms_error / np.sqrt(mag_base * mag_compare)}") assert np.isclose(rms_error / np.sqrt(mag_base * mag_compare), 0.0, atol=0.075), ( "Expected adjoint gradients to be the same with and without symmetry" diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index b56a6345cb..59daa9fb93 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -34,7 +34,7 @@ from tidy3d.web import run, run_async from tidy3d.web.api.autograd import autograd as autograd_module -from ...utils import SIM_FULL, AssertLogLevel, run_emulated, tracer_arr +from ...utils import SIM_FULL, AssertLogLevel, custom_poleresidue_u, run_emulated, tracer_arr """ Test configuration """ @@ -83,6 +83,9 @@ def _patch_cmp_to_const(monkeypatch, cls, dJ_const): def _make_di(paths, freq): """Construct a minimal DerivativeInfo shared by custom dispersive tests.""" + + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + return DerivativeInfo( paths=paths, E_der_map={}, @@ -91,16 +94,19 @@ def _make_di(paths, freq): D_fwd={}, E_adj={}, D_adj={}, - eps_data={}, - eps_in=2.0, - eps_out=1.0, + eps_data={ + key: td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + }, frequencies=[freq], bounds=((-1, -1, -1), (1, 1, 1)), - eps_no_structure=td.ScalarFieldDataArray( - [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.0]} + eps_out=td.ScalarFieldDataArray( + [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} ), - eps_inf_structure=td.ScalarFieldDataArray( - [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.0]} + eps_in=td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), @@ -163,8 +169,8 @@ def _make_di(paths, freq): SIM_BASE = td.Simulation( size=(LX, 3.15, LZ), run_time=200 / FWIDTH, - sources=[PLANE_WAVE], - structures=[ + sources=(PLANE_WAVE,), + structures=( td.Structure( geometry=td.Box( size=(0.5, 0.5, LZ / 2), @@ -173,16 +179,16 @@ def _make_di(paths, freq): .rotated(ROT_ANGLE_WG, axis=0) .translated(x=0, y=-np.tan(ROT_ANGLE_WG) * MODE_FIELD_SPC, z=LZ / 2), medium=td.Medium(permittivity=2.0), - ) - ], - monitors=[ + ), + ), + monitors=( td.FieldMonitor( center=(0, 0, 0), size=(0, 0, 0), freqs=[FREQ0], name="extraneous", - ) - ], + ), + ), boundary_spec=td.BoundarySpec.pml(x=PML_X, y=True, z=True), grid_spec=td.GridSpec.uniform(dl=0.01 * td.C_0 / FREQ0), ) @@ -1054,8 +1060,8 @@ def test_autograd_speed_num_structures(use_emulated_run): def make_sim(*args): structure = make_structures(*args)[structure_key] - structures = num_structures_test * [structure] - return SIM_BASE.updated_copy(structures=structures, monitors=[monitor]) + structures = tuple(num_structures_test * [structure]) + return SIM_BASE.updated_copy(structures=structures, monitors=(monitor,)) def 
objective(*args): """Objective function.""" @@ -1111,7 +1117,7 @@ def make_sim(params, geo_maker): geo = geo_maker(*params) structure = td.Structure(geometry=geo, medium=td.Medium(permittivity=2)) - return SIM_BASE.updated_copy(structures=[structure], monitors=[monitor]) + return SIM_BASE.updated_copy(structures=(structure,), monitors=(monitor,)) p0 = [1.0, 0.0, 0.0, t0] @@ -1188,7 +1194,7 @@ def test_sim_full_ops(structure_key): def objective(*params): s = make_structures(*params)[structure_key] s = s.updated_copy(geometry=s.geometry.updated_copy(center=(2, 2, 2), size=(0, 0, 0))) - sim_full_traced = SIM_FULL.updated_copy(structures=[*list(SIM_FULL.structures), s]) + sim_full_traced = SIM_FULL.updated_copy(structures=(*SIM_FULL.structures, s)) sim_full_static = sim_full_traced.to_static() @@ -1332,7 +1338,7 @@ def f(x): geometry=td.Box(center=(0, 0, 0), size=(1, 1, x)), dl=[1, 1, 1], ) - sim = SIM_FULL.updated_copy(override_structures=[override_structure], path="grid_spec") + sim = SIM_FULL.updated_copy(override_structures=(override_structure,), path="grid_spec") return sim.grid_spec.override_structures[0].geometry.size[2] with AssertLogLevel("WARNING", contains_str="override structures"): @@ -1345,7 +1351,7 @@ def test_sim_fields_io(structure_key, tmp_path): from file, and then converting back, returns the same object.""" s = make_structures(params0)[structure_key] s = s.updated_copy(geometry=s.geometry.updated_copy(center=(2, 2, 2), size=(0, 0, 0))) - sim_full_traced = SIM_FULL.updated_copy(structures=[*list(SIM_FULL.structures), s]) + sim_full_traced = SIM_FULL.updated_copy(structures=(*SIM_FULL.structures, s)) sim_fields = sim_full_traced._strip_traced_fields() field_map = FieldMap.from_autograd_field_map(sim_fields) @@ -1401,8 +1407,8 @@ def test_too_many_traced_structures(monkeypatch, use_emulated_run): def make_sim(*args): structure = make_structures(*args)[structure_key] return SIM_BASE.updated_copy( - structures=(config.adjoint.max_traced_structures + 1) * [structure], - monitors=[monitor], + structures=(config.adjoint.max_traced_structures + 1) * (structure,), + monitors=(monitor,), ) def objective(*args): @@ -1428,7 +1434,7 @@ def objective(args): sim = SIM_BASE.updated_copy( structures=structures, - monitors=[td.FieldTimeMonitor(size=(0, 0, 0), name="time_monitor_only")], + monitors=(td.FieldTimeMonitor(size=(0, 0, 0), name="time_monitor_only"),), ) # doesn't need to be a valid objective since this should error when calling web.run return web.run(sim, task_name="autograd_test", verbose=False) @@ -1647,7 +1653,7 @@ def objective(args): for structure_key in structure_keys_: structures.append(structures_traced_dict[structure_key]) - sim = SIM_BASE.updated_copy(monitors=[monitor], structures=structures) + sim = SIM_BASE.updated_copy(monitors=(monitor,), structures=tuple(structures)) data = run(sim, task_name="autograd_test", verbose=False) if objtype == "flux": @@ -1716,7 +1722,7 @@ def setup(far_field_approx, projection_type, sim_2d): name="far_field", ) - sim = SIM_BASE.updated_copy(monitors=[monitor]) + sim = SIM_BASE.updated_copy(monitors=(monitor,)) if sim_2d and IS_3D: sim = sim.updated_copy(size=(0, *sim.size[1:])) @@ -1750,7 +1756,7 @@ def objective(args): for structure_key in structure_keys_: structures.append(structures_traced_dict[structure_key]) - sim = sim_base.updated_copy(structures=structures) + sim = sim_base.updated_copy(structures=tuple(structures)) sim_data = run(sim, task_name="field_projection_test") return self.objective(sim_data, monitor_far) @@ -1778,14 
+1784,14 @@ def test_error_if_server_side_projection( """Using a far field monitor directly should error""" # build a projection‐only monitor sim sim_base, monitor_far = self.setup(far_field_approx, projection_type, sim_2d) - sim_base = sim_base.updated_copy(monitors=[monitor_far]) + sim_base = sim_base.updated_copy(monitors=(monitor_far,)) def objective(args): structures_traced_dict = make_structures(args) structures = list(SIM_BASE.structures) for structure_key in structure_keys_: structures.append(structures_traced_dict[structure_key]) - sim = sim_base.updated_copy(structures=structures) + sim = sim_base.updated_copy(structures=tuple(structures)) sim_data = run(sim, task_name="field_projection_test") return sim_data["far_field"].power.sum().item() @@ -1859,6 +1865,8 @@ def J(eps): for j in range(2): field_paths.append(("poles", i, j)) + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + info = DerivativeInfo( paths=field_paths, E_der_map={}, @@ -1867,15 +1875,18 @@ def J(eps): D_fwd={}, E_adj={}, D_adj={}, - eps_data={}, - eps_in=2.0, - eps_out=1.0, + eps_data={ + key: td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + }, frequencies=[freq], bounds=((-1, -1, -1), (1, 1, 1)), - eps_no_structure=td.ScalarFieldDataArray( + eps_out=td.ScalarFieldDataArray( [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]} ), - eps_inf_structure=td.ScalarFieldDataArray( + eps_in=td.ScalarFieldDataArray( [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]} ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), @@ -1905,6 +1916,8 @@ def f(eps_inf, poles): def test_adaptive_spacing(eps_real): freq = 5e9 + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + info = DerivativeInfo( paths={}, E_der_map={}, @@ -1913,13 +1926,20 @@ def test_adaptive_spacing(eps_real): D_fwd={}, E_adj={}, D_adj={}, - eps_data={}, - eps_in=eps_real, - eps_out=1.0, + eps_data={ + key: td.ScalarFieldDataArray( + [[[[eps_real]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + }, + eps_in=td.ScalarFieldDataArray( + [[[[eps_real]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} + ), + eps_out=td.ScalarFieldDataArray( + [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} + ), frequencies=[freq], bounds=((-1, -1, -1), (1, 1, 1)), - eps_no_structure={}, - eps_inf_structure={}, bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), ) @@ -1935,6 +1955,8 @@ def test_adaptive_spacing(eps_real): def test_cylinder_discretization(eps_real): freq = 5e9 + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + info = DerivativeInfo( paths={}, E_der_map={}, @@ -1943,13 +1965,20 @@ def test_cylinder_discretization(eps_real): D_fwd={}, E_adj={}, D_adj={}, - eps_data={}, - eps_in=eps_real, - eps_out=1.0, + eps_data={ + key: td.ScalarFieldDataArray( + [[[[eps_real]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + }, + eps_in=td.ScalarFieldDataArray( + [[[[eps_real]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} + ), + eps_out=td.ScalarFieldDataArray( + [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} + ), frequencies=[freq], bounds=((-1, -1, -1), (1, 1, 1)), - eps_no_structure={}, - eps_inf_structure={}, bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), ) @@ -1998,8 +2027,15 @@ def J(eps): monkeypatch.setattr( td.CustomPoleResidue, - "_derivative_field_cmp", - lambda self, E_der_map, spatial_data, dim, freqs, 
component="real": dJ_deps / 3.0, + "_derivative_field_cmp_custom", + lambda self, + E_der_map, + spatial_data, + dim, + freqs, + bounds=None, + component="real", + interp_method=None: dJ_deps / 3.0, ) import importlib @@ -2012,6 +2048,8 @@ def J(eps): for j in range(2): field_paths.append(("poles", i, j)) + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + info = DerivativeInfo( paths=field_paths, E_der_map={}, @@ -2020,16 +2058,19 @@ def J(eps): D_fwd={}, E_adj={}, D_adj={}, - eps_data={}, - eps_in=2.0, - eps_out=1.0, + eps_data={ + key: td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + }, frequencies=[freq], bounds=((-1, -1, -1), (1, 1, 1)), - eps_no_structure=td.ScalarFieldDataArray( - [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]} + eps_in=td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} ), - eps_inf_structure=td.ScalarFieldDataArray( - [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [1.94e14]} + eps_out=td.ScalarFieldDataArray( + [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [freq]} ), bounds_intersect=((-1, -1, -1), (1, 1, 1)), simulation_bounds=((-2, -2, -2), (2, 2, 2)), @@ -2060,6 +2101,43 @@ def f(eps_inf, poles): assert np.allclose(grads_computed[field_path], np.conj(grad_poles[i][j])) +def test_custom_pole_residue_unstructured_derivatives(): + """Ensure unstructured pole residue adjoints are explicitly unsupported.""" + pr = custom_poleresidue_u + field_paths = [("eps_inf",), ("poles", 0, 0), ("poles", 0, 1)] + + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + + info = DerivativeInfo( + paths=field_paths, + E_der_map={}, + D_der_map={}, + E_fwd={}, + D_fwd={}, + E_adj={}, + D_adj={}, + eps_data={ + key: td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + }, + frequencies=[3e8], + bounds=((-1, -1, -1), (1, 1, 1)), + eps_out=td.ScalarFieldDataArray( + [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [3e8]} + ), + eps_in=td.ScalarFieldDataArray( + [[[[2.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [3e8]} + ), + bounds_intersect=((-1, -1, -1), (1, 1, 1)), + simulation_bounds=((-2, -2, -2), (2, 2, 2)), + ) + + with pytest.raises(NotImplementedError, match="unstructured"): + pr._compute_derivatives(derivative_info=info) + + def test_custom_sellmeier(monkeypatch): """Test that computed CustomSellmeier derivatives match analytic mapping.""" @@ -2422,8 +2500,8 @@ def make_objective(postprocess_fn: typing.Callable, structure_key: str) -> typin def objective(params): structure_traced = make_structures(params)[structure_key] sim = SIM_BASE.updated_copy( - structures=[structure_traced], - monitors=[*list(SIM_BASE.monitors), mnt_single, mnt_multi], + structures=(structure_traced,), + monitors=(*SIM_BASE.monitors, mnt_single, mnt_multi), ) data = run(sim, task_name="multifreq_test") return postprocess_fn(data) @@ -2550,8 +2628,8 @@ def test_multi_freq_edge_cases(use_emulated_run, structure_key, label, check_fn, def objective(params): structure_traced = make_structures(params)[structure_key] sim = SIM_BASE.updated_copy( - structures=[structure_traced], - monitors=[*list(SIM_BASE.monitors), mnt_single, mnt_multi], + structures=(structure_traced,), + monitors=(*SIM_BASE.monitors, mnt_single, mnt_multi), ) data = run(sim, task_name="multifreq_test") return postprocess_fn(data) @@ -2574,8 +2652,8 @@ def objective_indi(params, structure_key) -> float: for f in mnt_multi.freqs: 
structure_traced = make_structures(params)[structure_key] sim = SIM_BASE.updated_copy( - structures=[structure_traced], - monitors=[*list(SIM_BASE.monitors), mnt_multi], + structures=(structure_traced,), + monitors=(*SIM_BASE.monitors, mnt_multi), ) sim_data = web.run(sim, task_name="multifreq_test") @@ -2588,8 +2666,8 @@ def objective_indi(params, structure_key) -> float: def objective_multi(params, structure_key) -> float: structure_traced = make_structures(params)[structure_key] sim = SIM_BASE.updated_copy( - structures=[structure_traced], - monitors=[*list(SIM_BASE.monitors), mnt_multi], + structures=(structure_traced,), + monitors=(*SIM_BASE.monitors, mnt_multi), ) sim_data = web.run(sim, task_name="multifreq_test") amps = get_amps(sim_data, "multi").sel(mode_index=0, direction="+") @@ -2615,11 +2693,11 @@ def test_error_flux(use_emulated_run): def objective(params): structure_traced = make_structures(params)["medium"] sim = SIM_BASE.updated_copy( - structures=[structure_traced], - monitors=[ + structures=(structure_traced,), + monitors=( td.FluxMonitor(size=(1, 1, 0), center=(0, 0, 0), freqs=[FREQ0], name="flux"), td.FieldMonitor(size=(1, 1, 0), center=(0, 0, 0), freqs=[FREQ0], name="field"), - ], + ), ) data = run(sim, task_name="flux_error") return anp.sum(data["flux"].flux.values) @@ -2636,8 +2714,8 @@ def test_extraneous_field(use_emulated_run): def objective(params): structure_traced = make_structures(params)["medium"] sim = SIM_BASE.updated_copy( - structures=[structure_traced], - monitors=[ + structures=(structure_traced,), + monitors=( SIM_BASE.monitors[0], td.ModeMonitor( size=(1, 1, 0), @@ -2646,7 +2724,7 @@ def objective(params): freqs=[FREQ0 * 0.9, FREQ0 * 1.1], name="mode", ), - ], + ), ) data = run(sim, task_name="extra_field") amp = data["mode"].amps.sel(direction="+", f=FREQ0 * 0.9, mode_index=0).values @@ -2882,21 +2960,26 @@ def test_flux_monitor_freq_exclusion(use_emulated_run): """Checks if we are excluding flux monitor frequencies from the adjoint frequencies since we cannot differentiate through flux data.""" - monitors_just_field = [ - td.FieldMonitor(size=(1, 1, 0), center=(0, 0, 0), freqs=[FREQ0], name="field") - ] + monitors_just_field = ( + td.FieldMonitor( + size=(1, 1, 0), + center=(0, 0, 0), + freqs=[FREQ0], + name="field", + ), + ) - monitors_with_flux = [ + monitors_with_flux = ( td.FieldMonitor(size=(1, 1, 0), center=(0, 0, 0), freqs=[FREQ0], name="field"), td.FluxMonitor( size=(1, 1, 0), center=(0, 0, 0), freqs=[FREQ0 - FWIDTH, FREQ0 + FWIDTH], name="flux" ), - ] + ) def objective_with_monitors(monitors): def objective(params): structure_traced = make_structures(params)["medium"] - sim = SIM_BASE.updated_copy(structures=[structure_traced], monitors=monitors) + sim = SIM_BASE.updated_copy(structures=(structure_traced,), monitors=monitors) data = run(sim, task_name="adjoint_freq_test") assert data.simulation._freqs_adjoint == [FREQ0] return anp.sum(data["field"].flux.values) @@ -2918,7 +3001,7 @@ def test_dispersive_no_inf(use_emulated_run): def objective(args): structure_traced = make_structures(args)["polyslab_dispersive"] - sim = make_sim(args).updated_copy(structures=[structure_traced]) + sim = make_sim(args).updated_copy(structures=(structure_traced,)) sim_data = run(sim, task_name="adjoint_test", verbose=False) return postprocess(sim_data) @@ -2961,10 +3044,10 @@ def objective(x): union = td.ClipOperation(operation="union", geometry_a=box1, geometry_b=box2) structure = td.Structure(geometry=union, medium=td.Medium(permittivity=2)) sim = 
SIM_BASE.updated_copy( - structures=[structure], - monitors=[ + structures=(structure,), + monitors=( td.FieldMonitor(size=(0, 0, 0), center=(0, 0, 0), freqs=[FREQ0], name="field"), - ], + ), ) data = run(sim, task_name="clip_error") return anp.sum(data["field"].intensity.item()) @@ -3022,6 +3105,68 @@ def objective(params): assert anp.all(grad != 0.0), "some gradients are 0 for conductivity-only test" +@pytest.mark.parametrize("use_run_async", (False, True)) +def test_error_custom_medium_and_geometry_traced(rng, use_run_async, use_emulated_run, tmp_path): + """Test that we properly error when there is a combination of custom medium and + geometry gradients.""" + monitor, postprocess = make_monitors()["field_point"] + + def objective(all_params): + """Objective function testing only conductivity gradient (constant permittivity).""" + params = all_params[0:-3] + size_params = all_params[-3:] + + len_arr = np.prod(DA_SHAPE) + matrix = rng.random((len_arr, N_PARAMS)) + + # variable permittivity + eps_arr = 1.5 + 1.5 * (anp.tanh(3 * matrix @ params).reshape(DA_SHAPE) + 1) + + nx, ny, nz = DA_SHAPE + coords = { + "x": np.linspace(-0.5, 0.5, nx), + "y": np.linspace(-0.5, 0.5, ny), + "z": np.linspace(-0.5, 0.5, nz), + } + + custom_med_struct = td.Structure( + geometry=td.Box(center=(0, 0, 0), size=tuple(size_params)), + medium=td.CustomMedium( + permittivity=td.SpatialDataArray(eps_arr, coords=coords), + ), + ) + + sim = SIM_BASE.updated_copy( + structures=[custom_med_struct], + monitors=[monitor], + ) + + if use_run_async: + data = run_async( + [sim], + path_dir=str(tmp_path), + verbose=False, + )[0] + else: + data = run( + sim, + path=str(tmp_path / "sim_test.hdf5"), + task_name="error_custom_medium_and_geometry_traced_test", + verbose=False, + ) + return postprocess(data, data[monitor.name]) + + box_sizes = [1.0, 1.0, 1.0] + all_params = np.array(list(params0) + box_sizes) + + with pytest.raises( + AdjointError, + match="Detected structure at index 0 containing a CustomMedium " + "type and traced geometry attributes.", + ): + val, grad = ag.value_and_grad(objective)(all_params) + + @pytest.mark.parametrize("structure_key, monitor_key", args) def test_vjp_nan(use_emulated_run, structure_key, monitor_key): """Test vjp data that has nan in it is flagged as an error.""" diff --git a/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py b/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py index 3f23a7e98f..3e75b5fb98 100644 --- a/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py +++ b/tests/test_components/autograd/test_autograd_custom_dispersive_vjps.py @@ -40,12 +40,10 @@ def _deriv_info(freq): "E_adj": {}, "D_adj": {}, "eps_data": {}, - "eps_in": 2.0, - "eps_out": 1.0, "frequencies": [freq], "bounds": ((-1, -1, -1), (1, 1, 1)), - "eps_no_structure": eps_no, - "eps_inf_structure": eps_inf, + "eps_out": eps_no, + "eps_in": eps_inf, "bounds_intersect": ((-1, -1, -1), (1, 1, 1)), "simulation_bounds": ((-2, -2, -2), (2, 2, 2)), } diff --git a/tests/test_components/autograd/test_autograd_polyslab.py b/tests/test_components/autograd/test_autograd_polyslab.py index 3aa372aabc..5f60becb20 100644 --- a/tests/test_components/autograd/test_autograd_polyslab.py +++ b/tests/test_components/autograd/test_autograd_polyslab.py @@ -105,7 +105,19 @@ def __init__( self.c = coeffs["c"] self.d = coeffs["d"] self.frequencies = [200e12] - self.eps_in = 12.0 + self.eps_in = td.ScalarFieldDataArray( + [[[[12.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": 
[200e12]} + ) + self.eps_out = td.ScalarFieldDataArray( + [[[[1.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + eps_keys = ["eps_xx", "eps_yy", "eps_zz"] + self.eps_data = { + key: td.ScalarFieldDataArray( + [[[[12.0]]]], coords={"x": [0], "y": [0], "z": [0], "f": [200e12]} + ) + for key in eps_keys + } self.interpolators = None self.bounds_intersect = ( diff --git a/tests/test_components/autograd/test_autograd_rf_box.py b/tests/test_components/autograd/test_autograd_rf_box.py index 9f88a136ff..66be46d058 100644 --- a/tests/test_components/autograd/test_autograd_rf_box.py +++ b/tests/test_components/autograd/test_autograd_rf_box.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator -import sys import autograd as ag import matplotlib.pylab as plt @@ -25,7 +24,6 @@ SAVE_ADJ_LOC = 1 LOCAL_GRADIENT = True VERBOSE = False -SHOW_PRINT_STATEMENTS = True USE_POLYSLAB_FOR_BOX = False NUMERICAL_RESULTS_DATA_DIR = ( "./numerical_rf_box_polyslab_test/" if USE_POLYSLAB_FOR_BOX else "./numerical_rf_box_box_test/" @@ -38,9 +36,6 @@ else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - def get_sim_geometry(mesh_wvl_um): return td.Box(size=(4 * mesh_wvl_um, 4 * mesh_wvl_um, 7 * mesh_wvl_um), center=(0, 0, 0)) @@ -446,7 +441,9 @@ def run_and_process_fd(all_box_parameters, fd_step, objective): ), indirect=["dir_name"], ) -def test_finite_difference_2d_box_pec(rf_2d_test_parameters, rng, tmp_path, create_directory): +def test_finite_difference_2d_box_pec( + rf_2d_test_parameters, rng, tmp_path, create_directory, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for 2D PEC boxes by""" """comparing them to numerical finite difference.""" @@ -571,17 +568,16 @@ def test_finite_difference_2d_box_pec(rf_2d_test_parameters, rng, tmp_path, crea width_data[SAVE_FD_LOC, :] = all_width_fd width_data[SAVE_ADJ_LOC, :] = all_width_adj - if SHOW_PRINT_STATEMENTS: - print(f"\n2D PEC Box Test {test_number} Summary:") - print(f"Mesh wavelength (um): {mesh_wvl_um}") - print(f"Adjoint wavelength (um): {adj_wvl_um}") - print(f"Monitor size (wavelengths): {monitor_size_wvl}") - print(f"Mesh refinement factor: {mesh_refinement_factor}") - print(f"Eval function: {eval_fn_name}") - print(f"Width mean (std): {width_error_mean} ({width_error_std})") - print(f"Width norm mean (std): {width_error_norm_mean} ({width_error_norm_std})") - print(f"Width overlap deg: {width_overlap_deg}") - print("\n") + print(f"\n2D PEC Box Test {test_number} Summary:") + print(f"Mesh wavelength (um): {mesh_wvl_um}") + print(f"Adjoint wavelength (um): {adj_wvl_um}") + print(f"Monitor size (wavelengths): {monitor_size_wvl}") + print(f"Mesh refinement factor: {mesh_refinement_factor}") + print(f"Eval function: {eval_fn_name}") + print(f"Width mean (std): {width_error_mean} ({width_error_std})") + print(f"Width norm mean (std): {width_error_norm_mean} ({width_error_norm_std})") + print(f"Width overlap deg: {width_overlap_deg}") + print("\n") if SAVE_FD_ADJ_DATA: np.save( @@ -626,7 +622,9 @@ def test_finite_difference_2d_box_pec(rf_2d_test_parameters, rng, tmp_path, crea ), indirect=["dir_name"], ) -def test_finite_difference_3d_box_pec(rf_3d_test_parameters, rng, tmp_path, create_directory): +def test_finite_difference_3d_box_pec( + rf_3d_test_parameters, rng, tmp_path, create_directory, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for 3D PEC boxes by""" """comparing them to numerical 
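
The RF-box tests drop the module-level `sys.stdout = sys.stderr` hack (gated by `SHOW_PRINT_STATEMENTS`) in favor of a `redirect_stdout_to_stderr` fixture requested per test. The fixture's definition is not part of this diff; a minimal sketch of what it plausibly looks like, assuming it lives in a shared conftest:

```python
import sys

import pytest


@pytest.fixture
def redirect_stdout_to_stderr(monkeypatch):
    """Hypothetical sketch: route print() output to stderr for one test only."""
    monkeypatch.setattr(sys, "stdout", sys.stderr)
    yield
```

Scoping the redirect through `monkeypatch` keeps it per-test and automatically undone at teardown, instead of mutating global interpreter state at import time.
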
finite difference.""" @@ -773,20 +771,19 @@ def test_finite_difference_3d_box_pec(rf_3d_test_parameters, rng, tmp_path, crea z_coord_data[SAVE_FD_LOC, :] = all_z_coord_fd z_coord_data[SAVE_ADJ_LOC, :] = all_z_coord_adj - if SHOW_PRINT_STATEMENTS: - print(f"\n3D PEC Box Test {test_number} Summary:") - print(f"Mesh wavelength (um): {mesh_wvl_um}") - print(f"Adjoint wavelength (um): {adj_wvl_um}") - print(f"Monitor size (wavelengths): {monitor_size_wvl}") - print(f"Box z thickness (wavelengths): {box_z_thickness_wvl}") - print(f"Eval function: {eval_fn_name}") - print(f"Width mean (std): {width_error_mean} ({width_error_std})") - print(f"Width norm mean (std): {width_error_norm_mean} ({width_error_norm_std})") - print(f"Width overlap deg: {width_overlap_deg}") - print(f"Z mean (std): {z_coord_error_mean} ({z_coord_error_std})") - print(f"Z norm mean (std): {z_coord_error_norm_mean} ({z_coord_error_norm_std})") - print(f"Z overlap deg: {z_coord_overlap_deg}") - print("\n") + print(f"\n3D PEC Box Test {test_number} Summary:") + print(f"Mesh wavelength (um): {mesh_wvl_um}") + print(f"Adjoint wavelength (um): {adj_wvl_um}") + print(f"Monitor size (wavelengths): {monitor_size_wvl}") + print(f"Box z thickness (wavelengths): {box_z_thickness_wvl}") + print(f"Eval function: {eval_fn_name}") + print(f"Width mean (std): {width_error_mean} ({width_error_std})") + print(f"Width norm mean (std): {width_error_norm_mean} ({width_error_norm_std})") + print(f"Width overlap deg: {width_overlap_deg}") + print(f"Z mean (std): {z_coord_error_mean} ({z_coord_error_std})") + print(f"Z norm mean (std): {z_coord_error_norm_mean} ({z_coord_error_norm_std})") + print(f"Z overlap deg: {z_coord_overlap_deg}") + print("\n") if SAVE_FD_ADJ_DATA: np.save( diff --git a/tests/test_components/autograd/test_autograd_rf_polyslab.py b/tests/test_components/autograd/test_autograd_rf_polyslab.py index 30ede80272..7be728e88a 100644 --- a/tests/test_components/autograd/test_autograd_rf_polyslab.py +++ b/tests/test_components/autograd/test_autograd_rf_polyslab.py @@ -2,7 +2,6 @@ from __future__ import annotations import operator -import sys import autograd as ag import matplotlib.pylab as plt @@ -26,16 +25,12 @@ LOCAL_GRADIENT = True VERBOSE = False NUMERICAL_RESULTS_DATA_DIR = "./numerical_rf_polyslab_test/" -SHOW_PRINT_STATEMENTS = False if PLOT_FD_ADJ_COMPARISON: pytestmark = pytest.mark.usefixtures("mpl_config_interactive") else: pytestmark = pytest.mark.usefixtures("mpl_config_noninteractive") -if SHOW_PRINT_STATEMENTS: - sys.stdout = sys.stderr - def get_sim_geometry(mesh_wvl_um): return td.Box(size=(5 * mesh_wvl_um, 5 * mesh_wvl_um, 7 * mesh_wvl_um), center=(0, 0, 0)) @@ -127,6 +122,8 @@ def create_objective_function_2D(create_sim_base, eval_fn, polyslab_z_value, sim def objective(polyslab_param_arrays): sim_base = create_sim_base() + layer_refinement_specs = [] + simulation_dict = {} for idx in range(len(polyslab_param_arrays)): get_polyslab_params = polyslab_param_arrays[idx] @@ -142,8 +139,18 @@ def objective(polyslab_param_arrays): ) ] + layer_refinement_specs.append( + td.LayerRefinementSpec.from_layer_bounds( + axis=2, + bounds=(polyslab_z_value, polyslab_z_value), + ) + ) + sim_with_block = sim_base.updated_copy( - structures=tuple(list(sim_base.structures) + polyslab_structures) + structures=tuple(list(sim_base.structures) + polyslab_structures), + grid_spec=sim_base.grid_spec.updated_copy( + layer_refinement_specs=layer_refinement_specs + ), ) simulation_dict[f"numerical_rf_polyslab_2d_testing_{idx}"] = 
sim_with_block.copy() @@ -177,6 +184,8 @@ def create_objective_function_3D( def objective(polyslab_param_arrays): sim_base = create_sim_base() + layer_refinement_specs = [] + simulation_dict = {} for idx in range(len(polyslab_param_arrays)): get_polyslab_params = polyslab_param_arrays[idx] @@ -195,8 +204,21 @@ def objective(polyslab_param_arrays): ) ] + layer_refinement_specs.append( + td.LayerRefinementSpec.from_layer_bounds( + axis=2, + bounds=( + polyslab_z_value - 0.5 * polyslab_z_thickness, + polyslab_z_value + 0.5 * polyslab_z_thickness, + ), + ) + ) + sim_with_block = sim_base.updated_copy( - structures=tuple(list(sim_base.structures) + polyslab_structures) + structures=tuple(list(sim_base.structures) + polyslab_structures), + grid_spec=sim_base.grid_spec.updated_copy( + layer_refinement_specs=layer_refinement_specs + ), ) simulation_dict[f"numerical_rf_polyslab_3d_testing_{idx}"] = sim_with_block.copy() @@ -396,7 +418,9 @@ def run_and_process_fd(polyslab_parameters, fd_step, objective): ), indirect=["dir_name"], ) -def test_finite_difference_2d_polyslab_pec(rf_2d_test_parameters, rng, tmp_path, create_directory): +def test_finite_difference_2d_polyslab_pec( + rf_2d_test_parameters, rng, tmp_path, create_directory, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for 2D `PolySlab` PEC by""" """comparing them to numerical finite difference.""" @@ -496,17 +520,16 @@ def test_finite_difference_2d_polyslab_pec(rf_2d_test_parameters, rng, tmp_path, vertex_data[SAVE_FD_LOC, :] = vertex_fd vertex_data[SAVE_ADJ_LOC, :] = vertex_adj - if SHOW_PRINT_STATEMENTS: - print(f"\n2D PEC PolySlab Test {test_number} Summary:") - print(f"Mesh wavelength (um): {mesh_wvl_um}") - print(f"Adjoint wavelength (um): {adj_wvl_um}") - print(f"Monitor size (wavelengths): {monitor_size_wvl}") - print(f"Mesh refinement factor: {mesh_refinement_factor}") - print(f"Eval function: {eval_fn_name}") - print(f"Vertex mean (std): {vertex_error_mean} ({vertex_error_std})") - print(f"Vertex norm mean (std): {vertex_error_norm_mean} ({vertex_error_norm_std})") - print(f"Vertex overlap deg: {vertex_overlap_deg}") - print("\n") + print(f"\n2D PEC PolySlab Test {test_number} Summary:") + print(f"Mesh wavelength (um): {mesh_wvl_um}") + print(f"Adjoint wavelength (um): {adj_wvl_um}") + print(f"Monitor size (wavelengths): {monitor_size_wvl}") + print(f"Mesh refinement factor: {mesh_refinement_factor}") + print(f"Eval function: {eval_fn_name}") + print(f"Vertex mean (std): {vertex_error_mean} ({vertex_error_std})") + print(f"Vertex norm mean (std): {vertex_error_norm_mean} ({vertex_error_norm_std})") + print(f"Vertex overlap deg: {vertex_overlap_deg}") + print("\n") if SAVE_FD_ADJ_DATA: np.save( @@ -553,7 +576,9 @@ def test_finite_difference_2d_polyslab_pec(rf_2d_test_parameters, rng, tmp_path, ), indirect=["dir_name"], ) -def test_finite_difference_3d_polyslab_pec(rf_3d_test_parameters, rng, tmp_path, create_directory): +def test_finite_difference_3d_polyslab_pec( + rf_3d_test_parameters, rng, tmp_path, create_directory, redirect_stdout_to_stderr +): """Test a variety of autograd permittivity gradients for 3D PEC `PolySlab` by""" """comparing them to numerical finite difference.""" @@ -652,17 +677,16 @@ def test_finite_difference_3d_polyslab_pec(rf_3d_test_parameters, rng, tmp_path, vertex_data[SAVE_FD_LOC, :] = vertex_fd vertex_data[SAVE_ADJ_LOC, :] = vertex_adj - if SHOW_PRINT_STATEMENTS: - print(f"\n3D PEC PolySlab Test {test_number} Summary:") - print(f"Mesh wavelength (um): {mesh_wvl_um}") - 
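
The 3D objective does the same, but its refinement layer spans the slab thickness around the slab center, mirroring the bounds arithmetic in the hunk above:

```python
import tidy3d as td

polyslab_z_value = 0.0       # slab center along z (stand-in values)
polyslab_z_thickness = 0.2   # slab thickness
spec = td.LayerRefinementSpec.from_layer_bounds(
    axis=2,
    bounds=(
        polyslab_z_value - 0.5 * polyslab_z_thickness,
        polyslab_z_value + 0.5 * polyslab_z_thickness,
    ),
)
```
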
print(f"Adjoint wavelength (um): {adj_wvl_um}") - print(f"Monitor size (wavelengths): {monitor_size_wvl}") - print(f"Polyslab z thickness (wavelengths): {polyslab_z_thickness_wvl}") - print(f"Eval function: {eval_fn_name}") - print(f"Vertex mean (std): {vertex_error_mean} ({vertex_error_std})") - print(f"Vertex norm mean (std): {vertex_error_norm_mean} ({vertex_error_norm_std})") - print(f"Vertex overlap deg: {vertex_overlap_deg}") - print("\n") + print(f"\n3D PEC PolySlab Test {test_number} Summary:") + print(f"Mesh wavelength (um): {mesh_wvl_um}") + print(f"Adjoint wavelength (um): {adj_wvl_um}") + print(f"Monitor size (wavelengths): {monitor_size_wvl}") + print(f"Polyslab z thickness (wavelengths): {polyslab_z_thickness_wvl}") + print(f"Eval function: {eval_fn_name}") + print(f"Vertex mean (std): {vertex_error_mean} ({vertex_error_std})") + print(f"Vertex norm mean (std): {vertex_error_norm_mean} ({vertex_error_norm_std})") + print(f"Vertex overlap deg: {vertex_overlap_deg}") + print("\n") if SAVE_FD_ADJ_DATA: np.save( diff --git a/tests/test_components/test_IO.py b/tests/test_components/test_IO.py index 83114f5678..3de9b8dbdb 100644 --- a/tests/test_components/test_IO.py +++ b/tests/test_components/test_IO.py @@ -13,7 +13,7 @@ import tidy3d as td from tidy3d import __version__ -from tidy3d.components.base import DATA_ARRAY_MAP +from tidy3d.components.data.data_array import DATA_ARRAY_MAP from tidy3d.components.data.sim_data import DATA_TYPE_MAP from ..test_data.test_monitor_data import make_flux_data @@ -24,6 +24,7 @@ # Store an example of every minor release simulation to test updater in the future SIM_DIR = "tests/sims" +SIM_STATIC = SIM.to_static() @pytest.fixture @@ -35,7 +36,7 @@ def split_string(monkeypatch): def set_datasets_to_none(sim): - sim_dict = sim.dict() + sim_dict = sim.model_dump() for src in sim_dict["sources"]: if src["type"] == "CustomFieldSource": src["field_dataset"] = None @@ -62,21 +63,21 @@ def set_datasets_to_none(sim): structure["medium"]["poles"] = [] else: structure["medium"]["coeffs"] = [] - return td.Simulation.parse_obj(sim_dict) + return td.Simulation.model_validate(sim_dict) def test_simulation_load_export(split_string, tmp_path): major, minor, patch = __version__.split(".") path = os.path.join(tmp_path, f"simulation_{major}_{minor}_{patch}.json") path_hdf5 = os.path.join(tmp_path, f"simulation_{major}_{minor}_{patch}.h5") - SIM.to_file(path) - SIM.to_hdf5(path_hdf5) + SIM_STATIC.to_file(path) + SIM_STATIC.to_hdf5(path_hdf5) SIM2 = td.Simulation.from_file(path) SIM_HDF5 = td.Simulation.from_hdf5(path_hdf5) - assert set_datasets_to_none(SIM)._json_string == SIM2._json_string, ( + assert set_datasets_to_none(SIM_STATIC)._json_string == SIM2._json_string, ( "original and loaded simulations are not the same" ) - assert SIM == SIM_HDF5, "original and loaded from hdf5 simulations are not the same" + assert SIM_STATIC == SIM_HDF5, "original and loaded from hdf5 simulations are not the same" def test_simulation_load_export_yaml(tmp_path): @@ -104,30 +105,30 @@ def test_component_load_export_yaml(tmp_path): def test_simulation_load_export_hdf5(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5") - SIM.to_file(path) + SIM_STATIC.to_file(path) SIM2 = td.Simulation.from_file(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_hdf5_gz(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5.gz") - 
SIM.to_file(path) + SIM_STATIC.to_file(path) SIM2 = td.Simulation.from_file(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_hdf5_explicit(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5") - SIM.to_hdf5(path) + SIM_STATIC.to_hdf5(path) SIM2 = td.Simulation.from_hdf5(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_hdf5_gz_explicit(split_string, tmp_path): path = str(tmp_path / "simulation.hdf5.gz") - SIM.to_hdf5_gz(path) + SIM_STATIC.to_hdf5_gz(path) SIM2 = td.Simulation.from_hdf5_gz(path) - assert SIM == SIM2, "original and loaded simulations are not the same" + assert SIM_STATIC == SIM2, "original and loaded simulations are not the same" def test_simulation_load_export_pckl(tmp_path): @@ -176,23 +177,25 @@ def test_1a_simulation_load_export2(tmp_path): assert SIM2 == SIM3, "original and loaded simulations are not the same" +@pytest.mark.perf def test_validation_speed(tmp_path): sizes_bytes = [] times_sec = [] path = str(tmp_path / "simulation.json") _ = SIM - N_tests = 10 + N_tests = 5 # may be increased temporarily, makes it slow for routine tests + max_structures = np.log10(100) # may be increased temporarily, makes it slow for routine tests # adjust as needed, keeping small to speed tests up - num_structures = np.logspace(0, 2, N_tests).astype(int) + num_structures = np.logspace(0, max_structures, N_tests).astype(int) for n in num_structures: new_structures = [] for i in range(n): new_structure = SIM.structures[0].copy(update={"name": str(i)}) new_structures.append(new_structure) - S = SIM.copy(update={"structures": new_structures}) + S = SIM.copy(update={"structures": tuple(new_structures)}) S.to_file(path) time_start = time() @@ -226,7 +229,9 @@ def test_simulation_updater(sim_file): def test_yaml(tmp_path): path = str(tmp_path / "simulation.json") SIM.to_file(path) + SIM.to_file("simulation.json") sim = td.Simulation.from_file(path) + path1 = str(tmp_path / "simulation.yaml") sim.to_yaml(path1) sim1 = td.Simulation.from_yaml(path1) diff --git a/tests/test_components/test_absorbers.py b/tests/test_components/test_absorbers.py index 9a385a52f0..5d6cbbd38c 100644 --- a/tests/test_components/test_absorbers.py +++ b/tests/test_components/test_absorbers.py @@ -3,8 +3,8 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.components.boundary import DEFAULT_BROADBAND_MODE_ABC_NUM_FREQS @@ -18,12 +18,12 @@ def test_port_absorbers_alone(): direction="+", size=(1, 1, 0), boundary_spec=td.ABCBoundary(permittivity=1) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.InternalAbsorber( direction="+", size=(1, 1, 1), boundary_spec=td.ABCBoundary(permittivity=1) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.InternalAbsorber(direction="+", size=(1, 1, 0), boundary_spec=td.ABCBoundary()) absorber = td.InternalAbsorber( @@ -84,7 +84,7 @@ def test_port_absorbers_simulations(): sim.plot(x=0) # validate no fully anisotropic mediums - with pytest.raises(SetupError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 1], @@ -103,7 +103,7 @@ def 
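
The validation-speed test is now marked `perf` and trimmed to five log-spaced structure counts, while the sweep still tops out at 100 structures:

```python
import numpy as np

# Structure counts swept by the perf test above: 5 log-spaced points
# from 10**0 up to 10**log10(100).
N_tests = 5
max_structures = np.log10(100)
num_structures = np.logspace(0, max_structures, N_tests).astype(int)
print(num_structures)  # [  1   3  10  31 100]
```
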
test_port_absorbers_simulations(): ) # disallow ABC ports in zero dimensions - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 0], @@ -193,13 +193,13 @@ def test_abc_boundaries_alone(): _ = td.ABCBoundary(permittivity=2) _ = td.ABCBoundary(permittivity=2, conductivity=0.1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ABCBoundary(permittivity=0) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ABCBoundary(permittivity=2, conductivity=-0.1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ABCBoundary(permittivity=None, conductivity=-0.1) # test mode abc @@ -212,7 +212,7 @@ def test_abc_boundaries_alone(): freq_spec=freq0, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ModeABCBoundary( plane=td.Box(size=(1, 1, 0)), mode_spec=td.ModeSpec(num_modes=2), @@ -220,7 +220,7 @@ def test_abc_boundaries_alone(): freq_spec=-1, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ModeABCBoundary( plane=td.Box(size=(1, 1, 0)), mode_spec=td.ModeSpec(num_modes=2), @@ -228,7 +228,7 @@ def test_abc_boundaries_alone(): freq_spec=freq0, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ModeABCBoundary( plane=td.Box(size=(1, 1, 1)), mode_spec=td.ModeSpec(num_modes=2), @@ -274,10 +274,10 @@ def test_abc_boundaries_alone(): assert abc_boundary == abc_boundary_from_source assert abc_boundary == abc_boundary_from_monitor - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Boundary(minus=td.Periodic(), plus=td.ABCBoundary()) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Boundary(minus=td.Periodic(), plus=td.ModeABCBoundary(plane=td.Box(size=(1, 1, 0)))) @@ -307,7 +307,7 @@ def test_abc_boundaries_simulations(): sim.plot(x=0) # validate ABC medium is not anisotorpic - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 1], @@ -357,7 +357,7 @@ def test_abc_boundaries_simulations(): boundary_spec=td.BoundarySpec.all_sides(td.ABCBoundary(permittivity=2)), ) # not ok if ABC boundary is crossed - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 1], @@ -470,7 +470,7 @@ def test_abc_boundaries_simulations(): ), ) # error if no frequency - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 1], @@ -485,7 +485,7 @@ def test_abc_boundaries_simulations(): ), ) # error if no frequency for automatic abc from interesected mediums - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 1], @@ -521,7 +521,7 @@ def test_abc_boundaries_simulations(): boundary_spec=td.BoundarySpec.all_sides(td.ABCBoundary(permittivity=2, conductivity=0)), ) # not ok if non-zerp conductivity - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=[0, 0, 0], size=[1, 1, 1], @@ -544,21 +544,21 @@ def test_abc_boundaries_broadband(): ) # test max num poles > 0 - with pytest.raises(pydantic.ValidationError): + with 
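
The absorber suite, and every test file below it in this diff, swaps `pydantic.v1.ValidationError` for the top-level pydantic v2 `ValidationError` in its `pytest.raises` checks. The failure modes are unchanged; only the exception import moves:

```python
import pytest
from pydantic import ValidationError  # was: import pydantic.v1 as pydantic

import tidy3d as td

# Same rejection as before (non-physical permittivity), new exception type:
with pytest.raises(ValidationError):
    _ = td.ABCBoundary(permittivity=0)
```
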
pytest.raises(ValidationError): _ = td.BroadbandModeABCFitterParam(max_num_poles=0) # test max num poles <= 10 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.BroadbandModeABCFitterParam(max_num_poles=11) # test tolerance rms >= 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.BroadbandModeABCFitterParam(tolerance_rms=-1) # test frequency sampling points > 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.BroadbandModeABCFitterParam(frequency_sampling_points=0) # test frequency sampling points <= 21 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.BroadbandModeABCFitterParam(frequency_sampling_points=102) # test basic instance @@ -571,11 +571,11 @@ def test_abc_boundaries_broadband(): ) # test max frequency > min frequency - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.BroadbandModeABCSpec(frequency_range=(fmax, fmin)) # test min frequency > 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.BroadbandModeABCSpec(frequency_range=(0, fmax)) # test from_wavelength_range diff --git a/tests/test_components/test_apodization.py b/tests/test_components/test_apodization.py index 4fadd253d3..20b4044d3b 100644 --- a/tests/test_components/test_apodization.py +++ b/tests/test_components/test_apodization.py @@ -3,8 +3,8 @@ from __future__ import annotations import matplotlib.pyplot as plt -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td @@ -17,27 +17,27 @@ def test_apodization(): def test_end_lt_start(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=2, end=1, width=0.2) def test_no_width(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1, end=2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(end=2) def test_negative_times(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=-2, end=-1, width=0.2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1, end=2, width=-0.2) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ApodizationSpec(start=1, end=2, width=0) diff --git a/tests/test_components/test_base.py b/tests/test_components/test_base.py index b996025aa6..32903a43cf 100644 --- a/tests/test_components/test_base.py +++ b/tests/test_components/test_base.py @@ -2,18 +2,49 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal, Optional import numpy as np import pytest -from pydantic.v1 import ValidationError +from pydantic import Field, ValidationError +from pydantic_core import PydanticSerializationError import tidy3d as td from tidy3d.components.base import Tidy3dBaseModel +from tidy3d.components.types import Undefined M = td.Medium() +class LeafModel(Tidy3dBaseModel): + leaf_attr: Optional[str] = None + common_attr: int = 0 + value_attr: Optional[float] = None + + +class NodeModel(Tidy3dBaseModel): + node_attr: Optional[str] = None + leaf_child: 
Optional[LeafModel] = None + leaf_list: list[LeafModel] = Field(default_factory=list) + leaf_tuple: tuple[LeafModel, ...] = Field(default_factory=tuple) + common_attr: float = 0.0 + value_attr: Optional[int] = None + + +class RootModel(Tidy3dBaseModel): + root_attr: Optional[str] = None + node_child: Optional[NodeModel] = None + node_list: list[NodeModel] = Field(default_factory=list) + node_tuple: tuple[NodeModel, ...] = Field(default_factory=tuple) + mixed_list: list[Any] = Field(default_factory=list) + common_attr: bool = False + value_attr: Optional[str] = None + + +class SpecialNodeModel(NodeModel): + special_attr: bool = True + + def test_shallow_copy(): _ = M.copy(deep=False) @@ -39,20 +70,6 @@ def test_comparisons(): M == M2 -def _test_version(tmp_path): - """ensure there's a version in simulation""" - - sim = td.Simulation( - size=(1, 1, 1), - run_time=1e-12, - ) - path = str(tmp_path / "simulation.json") - sim.to_file(path) - with open(path) as f: - s = f.read() - assert '"version": ' in s - - def test_deep_copy(): """Make sure deep copying works as expected with defaults.""" b = td.Box(size=(1, 1, 1)) @@ -63,11 +80,6 @@ def test_deep_copy(): medium=m, ) - # s_shallow = s.copy(deep=False) - # with shallow copy, these should be the same objects - # assert id(s.geometry) == id(s_shallow.geometry) - # assert id(s.medium) == id(s_shallow.medium) - s_deep = s.copy(deep=True) # with deep copy, these should be different objects @@ -79,28 +91,15 @@ def test_deep_copy(): assert id(s.geometry) != id(s_default.geometry) assert id(s.medium) != id(s_default.medium) - # make sure other kwargs work, here we update the geometry to a sphere and shallow copy medium - # s_kwargs = s.copy(deep=False, update=dict(geometry=Sphere(radius=1.0))) - # assert id(s.medium) == id(s_kwargs.medium) - # assert id(s.geometry) != id(s_kwargs.geometry) - # behavior of modifying attributes s_default = s.copy(update={"geometry": td.Sphere(radius=1.0)}) assert id(s.geometry) != id(s_default.geometry) - # s_shallow = s.copy(deep=False, update=dict(geometry=Sphere(radius=1.0))) - # assert id(s.geometry) != id(s_shallow.geometry) - # behavior of modifying attributes of attributes new_geometry = s.geometry.copy(update={"size": (2, 2, 2)}) s_default = s.copy(update={"geometry": new_geometry}) assert id(s.geometry) != id(s_default.geometry) - # s_shallow = s.copy(deep=False) - # new_geometry = s.geometry.copy(update=dict(size=(2,2,2))) - # s_shallow = s_shallow.copy(update=dict(geometry=new_geometry)) - # assert id(s.geometry) == id(s_shallow.geometry) - def test_updated_copy(): """Make sure updated copying shortcut works as expected with defaults.""" @@ -172,7 +171,7 @@ def test_updated_copy_path(): ) # forgot path - with pytest.raises(ValueError): + with pytest.raises(KeyError): assert sim == sim.updated_copy(permittivity=2.0) assert sim.updated_copy(size=(6, 6, 6)) == sim.updated_copy(size=(6, 6, 6), path=None) @@ -204,7 +203,7 @@ def test_attrs(tmp_path): assert obj.attrs == {"foo": "attr"} # this is still not allowed though - with pytest.raises(TypeError): + with pytest.raises(ValidationError): obj.attrs = {} # attrs can be modified @@ -220,11 +219,11 @@ def test_attrs(tmp_path): assert obj3.attrs == obj2.attrs # attrs are in the json strings - obj_json = obj3.json() - assert '{"foo": "bar"}' in obj_json + obj_json = obj3.model_dump_json() + assert '{"foo":"bar"}' in obj_json - # attrs are in the dict() - obj_dict = obj3.dict() + # attrs are in the dict + obj_dict = obj3.model_dump() assert obj_dict["attrs"] == 
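
The `attrs` assertions change because pydantic v2's `model_dump_json` emits compact JSON with no space after separators, whereas v1's `.json()` rendered with spaces. A generic illustration:

```python
from pydantic import BaseModel


class WithAttrs(BaseModel):
    attrs: dict = {}


obj = WithAttrs(attrs={"foo": "bar"})

# v1 .json() produced '{"attrs": {"foo": "bar"}}'; v2 is compact:
assert obj.model_dump_json() == '{"attrs":{"foo":"bar"}}'
assert obj.model_dump() == {"attrs": {"foo": "bar"}}
```
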
{"foo": "bar"} # objects saved and loaded from file still have attrs @@ -236,8 +235,8 @@ def test_attrs(tmp_path): # test attrs that can't be serialized obj.attrs["not_serializable"] = type - with pytest.raises(TypeError): - obj.json() + with pytest.raises(PydanticSerializationError): + obj.model_dump_json() @pytest.mark.parametrize( @@ -289,8 +288,308 @@ class DispatchChild(DispatchBase): type: Literal["DispatchChild"] = "DispatchChild" data = {"type": "DispatchChild", "value": 1} - parsed = Tidy3dBaseModel._parse_model_dict(data) + parsed = Tidy3dBaseModel._model_validate(data) assert isinstance(parsed, DispatchChild) with pytest.raises(ValidationError): - DispatchChild.parse_obj({"type": "DispatchBase", "value": 2}) + DispatchChild.model_validate({"type": "DispatchBase", "value": 2}) + + +def test_find_paths_empty_model(): + empty_leaf = LeafModel() + # field 'leaf_attr' exists on LeafModel, even if its value is None + assert empty_leaf.find_paths("leaf_attr") == [""] + assert empty_leaf.find_paths("non_existent_attr") == [] + + +def test_find_paths_top_level(): + leaf1 = LeafModel(leaf_attr="test_val", common_attr=5) + assert leaf1.find_paths("leaf_attr") == [""] + assert leaf1.find_paths("leaf_attr", "test_val") == [""] + assert leaf1.find_paths("leaf_attr", "wrong_value") == [] + assert leaf1.find_paths("common_attr", 5) == [""] + assert leaf1.find_paths("common_attr", 0) == [] # default is 0, but instance has 5 + + +def test_find_paths_nested(): + leaf_inner = LeafModel(leaf_attr="inner_leaf_val", value_attr=3.14) + node = NodeModel(leaf_child=leaf_inner, node_attr="node_val") + + assert node.find_paths("leaf_attr") == ["leaf_child"] + assert node.find_paths("leaf_attr", "inner_leaf_val") == ["leaf_child"] + assert node.find_paths("value_attr", 3.14) == ["leaf_child"] # leaf_inner.value_attr + assert node.find_paths("leaf_attr", "wrong_value") == [] + assert node.find_paths("node_attr") == [""] + assert node.find_paths("node_attr", "node_val") == [""] + + +def test_find_paths_list_and_tuple(): + leaf1 = LeafModel(leaf_attr="l1_val", common_attr=1) + leaf2 = LeafModel(leaf_attr="l2_val", common_attr=2) + leaf3 = LeafModel(leaf_attr="l1_val", common_attr=3) # Same leaf_attr as leaf1 + + node = NodeModel(leaf_list=[leaf1, leaf2], leaf_tuple=(leaf3,), common_attr=0.5) + + # Search for 'leaf_attr' without value filter + expected_paths_leaf_attr = sorted(["leaf_list/0", "leaf_list/1", "leaf_tuple/0"]) + assert node.find_paths("leaf_attr") == expected_paths_leaf_attr + + # Search for 'leaf_attr' with value "l1_val" + expected_paths_leaf_attr_l1 = sorted(["leaf_list/0", "leaf_tuple/0"]) + assert node.find_paths("leaf_attr", "l1_val") == expected_paths_leaf_attr_l1 + + # Search for 'common_attr' (exists on NodeModel and LeafModel) + # NodeModel.common_attr=0.5 (float) + # LeafModel.common_attr (int) + expected_paths_common_attr = sorted(["", "leaf_list/0", "leaf_list/1", "leaf_tuple/0"]) + assert node.find_paths("common_attr") == expected_paths_common_attr + + # Search for 'common_attr' with specific values + assert node.find_paths("common_attr", 1) == ["leaf_list/0"] # leaf1.common_attr + assert node.find_paths("common_attr", 0.5) == [""] # node.common_attr + assert node.find_paths("common_attr", 3) == ["leaf_tuple/0"] # leaf3.common_attr + + +def test_find_paths_no_match(): + leaf = LeafModel() + node = NodeModel(leaf_child=leaf, leaf_list=[leaf]) + root = RootModel(node_child=node, node_list=[node]) + + assert root.find_paths("non_existent_field") == [] + assert 
root.find_paths("leaf_attr", "value_that_does_not_exist") == [] + # 'leaf_attr' exists, but not with this value (all are None or unset) + assert root.find_paths("leaf_attr", "specific_value") == [] + + +def test_find_paths_value_is_none(): + l_none = LeafModel(leaf_attr=None) + l_set = LeafModel(leaf_attr="set") + node = NodeModel(leaf_list=[l_none, l_set]) + + assert node.find_paths("leaf_attr", None) == ["leaf_list/0"] + assert node.find_paths("leaf_attr", Undefined) == sorted(["leaf_list/0", "leaf_list/1"]) + + +def test_find_paths_complex_structure(): + l1 = LeafModel(leaf_attr="target_leaf", common_attr=10) + l2 = LeafModel(leaf_attr="other_leaf", common_attr=20) + l3 = LeafModel(common_attr=10, value_attr=10.0) # common_attr matches l1 + + n1 = NodeModel(node_attr="n1_val", leaf_child=l1, leaf_list=[l2, l3]) + n2 = NodeModel(node_attr="target_node_val", common_attr=5.5) + n3 = NodeModel(leaf_child=LeafModel(leaf_attr="target_leaf")) # New LeafModel instance + + root = RootModel( + root_attr="root_val", + node_child=n1, + node_list=[n2, n3], + mixed_list=[l1, "a_string_item", n2, LeafModel(leaf_attr="target_leaf")], + ) + + # Find 'leaf_attr' == "target_leaf" + expected = sorted( + [ + "node_child/leaf_child", # n1.leaf_child (l1) + "node_list/1/leaf_child", # n3.leaf_child + "mixed_list/0", # l1 in mixed_list + "mixed_list/3", # new LeafModel in mixed_list + ] + ) + assert root.find_paths("leaf_attr", "target_leaf") == expected + + # Find 'common_attr' == 10 (int) + expected_common_int = sorted( + [ + "node_child/leaf_child", # l1 in n1.leaf_child (l1.common_attr is int) + "node_child/leaf_list/1", # l3 in n1.leaf_list (l3.common_attr is int) + "mixed_list/0", # l1 in mixed_list + ] + ) + assert root.find_paths("common_attr", 10) == expected_common_int + + # Find 'node_attr' (any value) + expected_node_attr = sorted( + [ + "node_child", # n1 + "node_list/0", # n2 + "node_list/1", # n3 + "mixed_list/2", # n2 in mixed_list + ] + ) + assert root.find_paths("node_attr") == expected_node_attr + + # Find 'node_attr' == "target_node_val" + expected_target_node = sorted( + [ + "node_list/0", # n2 + "mixed_list/2", # n2 in mixed_list + ] + ) + assert root.find_paths("node_attr", "target_node_val") == expected_target_node + + # Find 'root_attr' (on self) + assert root.find_paths("root_attr") == [""] + assert root.find_paths("root_attr", "root_val") == [""] + + +def test_find_submodels_find_self_and_empty(): + leaf = LeafModel() + assert leaf.find_submodels(LeafModel) == [leaf] # Finds self + assert leaf.find_submodels(NodeModel) == [] # Does not find other types + + node = NodeModel() + # NodeModel itself has leaf_child: Optional[LeafModel] = None etc. + # these are not instantiated if not provided. + assert node.find_submodels(NodeModel) == [node] + assert node.find_submodels(LeafModel) == [] + + +def test_find_submodels_nested(): + leaf_inner = LeafModel() + node = NodeModel(leaf_child=leaf_inner) + root = RootModel(node_child=node) + + found_leafs = root.find_submodels(LeafModel) + assert found_leafs == [leaf_inner] + + found_nodes = root.find_submodels(NodeModel) + assert found_nodes == [node] + + assert root.find_submodels(RootModel) == [root] + + +def test_find_submodels_list_and_tuple_uniqueness_and_order(): + # Instances + l1 = LeafModel(common_attr=1) + l2 = LeafModel(common_attr=2) + # l1 and l2 are distinct objects with different content. 
+    n1 = NodeModel(leaf_child=l1, common_attr=0.1)
+    n2 = NodeModel(leaf_list=[l1, l2], common_attr=0.2)  # l1 is reused here
+    # n1 and n2 are distinct objects with different content.
+
+    root = RootModel(node_list=[n1, n2], node_tuple=(n1,))  # n1 is reused here
+
+    # Expected order of first encounter during depth-first traversal:
+    # root (RootModel instance)
+    # n1 (from root.node_list[0])
+    # l1 (from n1.leaf_child)
+    # n2 (from root.node_list[1])
+    # (l1 from n2.leaf_list[0] is already seen)
+    # l2 (from n2.leaf_list[1])
+    # (n1 from root.node_tuple[0] is already seen)
+
+    # Find LeafModel instances
+    found_leafs = root.find_submodels(LeafModel)
+    assert found_leafs == [l1, l2]  # Order: l1 then l2
+
+    # Find NodeModel instances
+    found_nodes = root.find_submodels(NodeModel)
+    assert found_nodes == [n1, n2]  # Order: n1 then n2
+
+    # Find RootModel
+    assert root.find_submodels(RootModel) == [root]
+
+
+def test_find_submodels_uniqueness_identical_content_distinct_instances():
+    leaf_val_equiv1 = LeafModel(leaf_attr="same_val")
+    leaf_val_equiv2 = LeafModel(leaf_attr="same_val")  # Different instance, same content
+    assert leaf_val_equiv1 is not leaf_val_equiv2
+    assert leaf_val_equiv1 == leaf_val_equiv2  # Relies on Tidy3dBaseModel.__eq__
+    # Hashes should be the same due to frozen=True and content-based Pydantic hash.
+
+    node = NodeModel(leaf_list=[leaf_val_equiv1, leaf_val_equiv2, LeafModel(leaf_attr="diff_val")])
+
+    # Expected order: leaf_val_equiv1 (first one encountered of the "same_val" pair), then the "diff_val" one.
+    found_leafs = node.find_submodels(LeafModel)
+
+    assert len(found_leafs) == 2
+    assert found_leafs[0] is leaf_val_equiv1  # First instance of "same_val"
+    assert found_leafs[1].leaf_attr == "diff_val"
+
+
+def test_find_submodels_no_match_type():
+    root = RootModel(node_child=NodeModel(leaf_child=LeafModel()))
+    # Use a type not present in the structure (e.g. a built-in or other library type)
+    assert root.find_submodels(td.Simulation) == []
+
+
+def test_find_submodels_subclassing():
+    leaf = LeafModel()
+    node_base_instance = NodeModel(leaf_child=leaf, node_attr="base_node")
+    node_special_instance = SpecialNodeModel(
+        leaf_child=leaf, node_attr="special_node", special_attr=True
+    )
+
+    root = RootModel(
+        node_list=[node_base_instance, node_special_instance],
+        mixed_list=[node_special_instance, leaf],
+    )
+
+    # Expected order for NodeModel: node_base_instance, node_special_instance
+    # Expected order for SpecialNodeModel: node_special_instance
+    # Expected order for LeafModel: leaf
+
+    # Find all NodeModels (should include SpecialNodeModel instances)
+    found_nodes = root.find_submodels(NodeModel)
+    assert found_nodes == [node_base_instance, node_special_instance]
+
+    # Find only SpecialNodeModels
+    found_special_nodes = root.find_submodels(SpecialNodeModel)
+    assert found_special_nodes == [node_special_instance]
+
+    # Find LeafModels (only one unique instance 'leaf' is involved)
+    found_leafs = root.find_submodels(LeafModel)
+    assert found_leafs == [leaf]
+
+
+def test_find_submodels_complex_structure_and_order():
+    l1 = LeafModel(common_attr=1)
+    l2 = LeafModel(common_attr=2)
+    l_shared = LeafModel(common_attr=99)
+
+    n1 = NodeModel(leaf_child=l1, common_attr=1.0, node_attr="N1")
+    n2_special = SpecialNodeModel(leaf_list=[l2, l_shared], common_attr=2.0, node_attr="N2S")
+    n3 = NodeModel(leaf_child=l_shared, common_attr=3.0, node_attr="N3")  # l_shared is re-used
+
+    # The root model itself is an instance of RootModel.
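
`find_submodels`, the companion query, returns unique sub-models of a requested type in depth-first, first-encounter order; models reachable through several paths (or equal-content duplicates, given the frozen, hash-by-content models) are reported once:

```python
l1 = LeafModel(common_attr=1)
n1 = NodeModel(leaf_child=l1)

# n1 and l1 are each reachable twice, but each is reported once.
root = RootModel(node_child=n1, node_list=[n1], mixed_list=[l1])
assert root.find_submodels(LeafModel) == [l1]
assert root.find_submodels(NodeModel) == [n1]
```
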
+ # It will be found if searching for RootModel or Tidy3dBaseModel. + root = RootModel( + node_child=n1, + node_list=[n2_special, n3], + mixed_list=[l1, "string_element", n1, l_shared], # n1, l1, l_shared re-used + ) + + # Expected order of first encounter of unique models: + # root (RootModel) + # n1 (NodeModel, from root.node_child) + # l1 (LeafModel, from n1.leaf_child) + # n2_special (SpecialNodeModel, from root.node_list[0]) + # l2 (LeafModel, from n2_special.leaf_list[0]) + # l_shared (LeafModel, from n2_special.leaf_list[1]) + # n3 (NodeModel, from root.node_list[1]) + # (l_shared from n3.leaf_child is already seen) + # (l1 from root.mixed_list[0] is already seen) + # (n1 from root.mixed_list[2] is already seen) + # (l_shared from root.mixed_list[3] is already seen) + + # Test finding RootModel + assert root.find_submodels(RootModel) == [root] + + # Test finding NodeModel (includes SpecialNodeModel) + # This search should yield n1, n2_special, n3 in that order. + found_all_nodes = root.find_submodels(NodeModel) + assert found_all_nodes == [n1, n2_special, n3] + + # Test finding SpecialNodeModel + assert root.find_submodels(SpecialNodeModel) == [n2_special] + + # Test finding LeafModel + # This search should yield l1, l2, l_shared in that order. + found_leafs = root.find_submodels(LeafModel) + assert found_leafs == [l1, l2, l_shared] + + # Test finding Tidy3dBaseModel (should return all unique model instances in order of first encounter) + all_models = root.find_submodels(Tidy3dBaseModel) + expected_all_models = [root, n1, l1, n2_special, l2, l_shared, n3] + assert all_models == expected_all_models diff --git a/tests/test_components/test_beam.py b/tests/test_components/test_beam.py index d55c776314..4e7bcb91d1 100644 --- a/tests/test_components/test_beam.py +++ b/tests/test_components/test_beam.py @@ -3,8 +3,8 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pd import pytest +from pydantic import ValidationError from tidy3d.components.beam import ( AstigmaticGaussianBeamProfile, @@ -156,7 +156,7 @@ def test_invalid_beam_size(): center = (0, 0, 0) size = (10, 10, 10) resolution = 100 - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): GaussianBeamProfile(center=center, size=size, resolution=resolution, freqs=FREQS) diff --git a/tests/test_components/test_boundaries.py b/tests/test_components/test_boundaries.py index 3c190919ca..f428ec8027 100644 --- a/tests/test_components/test_boundaries.py +++ b/tests/test_components/test_boundaries.py @@ -2,8 +2,8 @@ from __future__ import annotations -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.components.boundary import ( @@ -84,11 +84,11 @@ def test_boundary_validators(): periodic = Periodic() # test `bloch_on_both_sides` - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = Boundary(plus=bloch, minus=pec) # test `periodic_with_pml` - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = Boundary(plus=periodic, minus=pml) @@ -196,7 +196,7 @@ def test_boundaryspec_classmethods(): @pytest.mark.parametrize("absorber_type", [PML, StablePML, Absorber]) def test_num_layers_validator(absorber_type): """Test the Field validators that enforce ``num_layers>0``.""" - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = absorber_type(num_layers=0) diff --git a/tests/test_components/test_custom.py 
b/tests/test_components/test_custom.py index e52adf321e..0039f43362 100644 --- a/tests/test_components/test_custom.py +++ b/tests/test_components/test_custom.py @@ -4,9 +4,9 @@ import dill as pickle import numpy as np -import pydantic.v1 as pydantic import pytest import xarray as xr +from pydantic import ValidationError import tidy3d as td from tidy3d.components.data.dataset import PermittivityDataset @@ -119,7 +119,7 @@ def test_validator_tangential_field(): """Test that it errors if no tangential field defined.""" field_dataset = FIELD_SRC.field_dataset field_dataset = field_dataset.copy(update={"Ex": None, "Ez": None, "Hx": None, "Hz": None}) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CustomFieldSource(size=SIZE, source_time=ST, field_dataset=field_dataset) @@ -127,7 +127,7 @@ def test_validator_non_planar(): """Test that it errors if the source geometry has a volume.""" field_dataset = FIELD_SRC.field_dataset field_dataset = field_dataset.copy(update={"Ex": None, "Ez": None, "Hx": None, "Hz": None}) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CustomFieldSource(size=(1, 1, 1), source_time=ST, field_dataset=field_dataset) @@ -137,7 +137,7 @@ def test_validator_freq_out_of_range_src(source): key, dataset = get_dataset(source) Ex_new = td.ScalarFieldDataArray(dataset.Ex.data, coords={"x": X, "y": Y, "z": Z, "f": [0]}) dataset_fail = dataset.copy(update={"Ex": Ex_new}) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = source.updated_copy(size=SIZE, source_time=ST, **{key: dataset_fail}) @@ -148,7 +148,7 @@ def test_validator_freq_multiple(source): new_data = np.concatenate((dataset.Ex.data, dataset.Ex.data), axis=-1) Ex_new = td.ScalarFieldDataArray(new_data, coords={"x": X, "y": Y, "z": Z, "f": [1, 2]}) dataset_fail = dataset.copy(update={"Ex": Ex_new}) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = source.copy(update={key: dataset_fail}) @@ -422,7 +422,7 @@ def test_medium_smaller_than_one_positive_sigma(unstructured): if unstructured: n_dataarray = cartesian_to_unstructured(n_dataarray.isel(f=0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = CustomMedium.from_nk(n_dataarray) # negative sigma @@ -436,7 +436,7 @@ def test_medium_smaller_than_one_positive_sigma(unstructured): n_dataarray = cartesian_to_unstructured(n_dataarray.isel(f=0), seed=1) k_dataarray = cartesian_to_unstructured(k_dataarray.isel(f=0), seed=1) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = CustomMedium.from_nk(n_dataarray, k_dataarray, freq=freqs[0]) @@ -473,9 +473,9 @@ def test_medium_nk(unstructured): assert np.isclose(med.eps_model(1e14), meds.eps_model(1e14), rtol=RTOL) # gain - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): med = CustomMedium.from_nk(n=n, k=-k) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): meds = CustomMedium.from_nk(n=ns, k=-ks, freq=freqs[0]) med = CustomMedium.from_nk(n=n, k=-k, allow_gain=True) meds = CustomMedium.from_nk(n=ns, k=-ks, freq=freqs[0], allow_gain=True) @@ -500,7 +500,7 @@ def test_medium_eps_model(): med.eps_model(frequency=freqs[0]) # error with multifrequency data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): med = make_custom_medium(make_scalar_data_multifreqs()) @@ -555,16 +555,8 @@ def 
verify_custom_medium_methods(mat, reduced_fields): # data fields in medium classes could be SpatialArrays or 2d tuples of spatial arrays # lets convert everything into 2d tuples of spatial arrays for uniform handling if isinstance(original, (td.SpatialDataArray, UnstructuredGridDataset)): - original = [ - [ - original, - ], - ] - reduced = [ - [ - reduced, - ], - ] + original = [[original]] + reduced = [[reduced]] for or_set, re_set in zip(original, reduced): assert len(or_set) == len(re_set) @@ -647,30 +639,30 @@ def test_custom_isotropic_medium(unstructured): conductivity = make_spatial_data(value=1, unstructured=unstructured, seed=seed) # some terms in permittivity are complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=1 + 0.1j, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=epstmp, conductivity=conductivity) # some terms in permittivity are < 1 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=0, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=epstmp, conductivity=conductivity) # some terms in conductivity are complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sigmatmp = make_spatial_data(value=0.1j, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp) # some terms in conductivity are negative sigmatmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp) mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp, allow_gain=True) verify_custom_medium_methods(mat, ["permittivity", "conductivity"]) assert not mat.is_spatially_uniform # inconsistent coords - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sigmatmp = make_spatial_data(value=0, dx=1, unstructured=unstructured, seed=seed) mat = CustomMedium(permittivity=permittivity, conductivity=sigmatmp) @@ -718,6 +710,7 @@ def verify_custom_dispersive_medium_methods(mat, reduced_fields): @pytest.mark.parametrize("unstructured", [False, True]) +@pytest.mark.slow def test_custom_pole_residue(unstructured): """Custom pole residue medium.""" seed = 98345 @@ -726,27 +719,27 @@ def test_custom_pole_residue(unstructured): c = 1j * make_spatial_data(value=1, unstructured=unstructured, seed=seed) # some terms in eps_inf are negative - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # some terms in eps_inf are complex - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=0.1j, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # inconsistent coords of eps_inf with a,c - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=1, dx=1, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): epstmp = make_spatial_data(value=1, dx=1, unstructured=(not 
unstructured), seed=seed) mat = CustomPoleResidue(eps_inf=epstmp, poles=((a, c),)) # break causality - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): atmp = make_spatial_data(value=0, unstructured=unstructured, seed=seed) mat = CustomPoleResidue(eps_inf=eps_inf, poles=((atmp, c),)) @@ -762,7 +755,7 @@ def test_custom_pole_residue(unstructured): # non-dispersive but gain a = 0 * c mat = CustomPoleResidue(eps_inf=eps_inf, poles=((a, c - 0.1),)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat_medium = mat.to_medium() mat = CustomPoleResidue(eps_inf=eps_inf, poles=((a, c - 0.1),), allow_gain=True) mat_medium = mat.to_medium() @@ -776,6 +769,7 @@ def test_custom_pole_residue(unstructured): @pytest.mark.parametrize("unstructured", [False, True]) +@pytest.mark.slow def test_custom_sellmeier(unstructured): """Custom Sellmeier medium.""" seed = 897245 @@ -786,39 +780,39 @@ def test_custom_sellmeier(unstructured): c2 = make_spatial_data(value=0.1, unstructured=unstructured, seed=seed) # complex b - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): btmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c1), (btmp, c2))) # complex c - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ctmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c1), (b2, ctmp))) # negative c - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ctmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c1), (b2, ctmp))) # negative b btmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = CustomSellmeier(coeffs=((b1, c1), (btmp, c2))) mat = CustomSellmeier(coeffs=((b1, c1), (btmp, c2)), allow_gain=True) assert mat.pole_residue.allow_gain # inconsistent coord - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): btmp = make_spatial_data(value=0, dx=1, unstructured=unstructured, seed=seed) mat = CustomSellmeier(coeffs=((b1, c2), (btmp, c2))) # mixing Cartesian and unstructured data - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): btmp = make_spatial_data(value=0, dx=1, unstructured=(not unstructured), seed=seed) mat = CustomSellmeier(coeffs=((b1, c2), (btmp, c2))) # some of C is close to 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): ctmp = make_spatial_data( value=0, unstructured=unstructured, seed=seed, random_magnitude=1e-7 ) @@ -838,6 +832,7 @@ def test_custom_sellmeier(unstructured): @pytest.mark.parametrize("unstructured", [False, True]) +@pytest.mark.slow def test_custom_lorentz(unstructured): """Custom Lorentz medium.""" seed = 31342 @@ -852,32 +847,32 @@ def test_custom_lorentz(unstructured): delta2 = make_spatial_data(value=0, unstructured=unstructured, seed=seed) # complex de - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): detmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed) mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (detmp, f2, delta2))) # mixed delta > f and delta < f over spatial points - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError):
         deltatmp = make_spatial_data(value=1, unstructured=unstructured, seed=seed)
         mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, f2, deltatmp)))

     # inconsistent coords
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         ftmp = make_spatial_data(value=1, dx=1, unstructured=unstructured, seed=seed)
         mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, ftmp, delta2)))

     # mixing Cartesian and unstructured data
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         ftmp = make_spatial_data(value=1, dx=1, unstructured=(not unstructured), seed=seed)
         mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, ftmp, delta2)))

     # break causality with negative delta
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         deltatmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed)
         mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (de2, f2, deltatmp)))

     # gain medium with negative delta epsilon
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         detmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed)
         mat = CustomLorentz(eps_inf=eps_inf, coeffs=((de1, f1, delta1), (detmp, f2, delta2)))

     mat = CustomLorentz(
@@ -908,22 +903,22 @@ def test_custom_drude(unstructured):
     delta2 = make_spatial_data(value=0, unstructured=unstructured, seed=seed)

     # complex delta
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         deltatmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed)
         mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (f2, deltatmp)))

     # negative delta
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         deltatmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed)
         mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (f2, deltatmp)))

     # inconsistent coords
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         ftmp = make_spatial_data(value=1, dx=1, unstructured=unstructured, seed=seed)
         mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (ftmp, delta2)))

     # mixing Cartesian and unstructured data
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         ftmp = make_spatial_data(value=1, dx=1, unstructured=(not unstructured), seed=seed)
         mat = CustomDrude(eps_inf=eps_inf, coeffs=((f1, delta1), (ftmp, delta2)))

@@ -946,38 +941,38 @@ def test_custom_debye(unstructured):
     tau2 = make_spatial_data(value=0.1, unstructured=unstructured, seed=seed)

     # complex eps
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         epstmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed)
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2)))

     # complex tau
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         tautmp = make_spatial_data(value=-0.5j, unstructured=unstructured, seed=seed)
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (eps2, tautmp)))

     # negative tau
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         tautmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed)
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (eps2, tautmp)))

     # some of tau is close to 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         tautmp = make_spatial_data(
             value=0, unstructured=unstructured, seed=seed, random_magnitude=1e-38
         )
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (eps2, tautmp)))

     # inconsistent coords
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         epstmp = make_spatial_data(value=0, dx=1, unstructured=unstructured, seed=seed)
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2)))

     # mixing Cartesian and unstructured data
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         epstmp = make_spatial_data(value=0, dx=1, unstructured=(not unstructured), seed=seed)
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2)))

     # negative delta epsilon
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         epstmp = make_spatial_data(value=-0.5, unstructured=unstructured, seed=seed)
         mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2)))
     mat = CustomDebye(eps_inf=eps_inf, coeffs=((eps1, tau1), (epstmp, tau2)), allow_gain=True)
@@ -991,6 +986,7 @@ def test_custom_debye(unstructured):


 @pytest.mark.parametrize("unstructured", [True])
+@pytest.mark.slow
 def test_custom_anisotropic_medium(unstructured):
     """Custom anisotropic medium."""
     seed = 43243
@@ -1029,7 +1025,7 @@ def test_custom_anisotropic_medium(unstructured):
     # so that xx-component is using "nearest"
     freq = 2e14
     dist_coeff = 0.7
-    coord_test = td.Coords(x=[X[0] * dist_coeff + X[1] * (1 - dist_coeff)], y=Y[0], z=Z[0])
+    coord_test = td.Coords(x=[X[0] * dist_coeff + X[1] * (1 - dist_coeff)], y=[Y[0]], z=[Z[0]])
     eps_nearest = mat.eps_sigma_to_eps_complex(
         permittivity.interp(x=X[0], y=Y[0], z=Z[0], method="nearest"),
         conductivity.interp(x=X[0], y=Y[0], z=Z[0], method="nearest"),
@@ -1080,11 +1076,11 @@ def test_custom_anisotropic_medium(unstructured):
     field_components = {f"eps_{d}{d}": make_scalar_data() for d in "xyz"}
     eps_dataset = PermittivityDataset(**field_components)
     mat_tmp = CustomMedium(eps_dataset=eps_dataset)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         mat = CustomAnisotropicMedium(xx=mat_tmp, yy=mat_yy, zz=mat_zz)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         mat = CustomAnisotropicMedium(xx=mat_xx, yy=mat_tmp, zz=mat_zz)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         mat = CustomAnisotropicMedium(xx=mat_xx, yy=mat_yy, zz=mat_tmp)

@@ -1175,7 +1171,7 @@ def test_warn_planewave_intersection():
         medium=mat,
     )
     with AssertLogLevel("WARNING"):
-        sim.updated_copy(structures=[box])
+        sim.updated_copy(structures=(box,))


 def test_warn_diffraction_monitor_intersection():
@@ -1204,10 +1200,10 @@ def test_warn_diffraction_monitor_intersection():
     with AssertLogLevel(None):
         sim = td.Simulation(
             size=(1, 1, 2),
-            structures=[box],
+            structures=(box,),
             grid_spec=td.GridSpec.auto(wavelength=1),
-            monitors=[monitor],
-            sources=[src],
+            monitors=(monitor,),
+            sources=(src,),
             run_time=1e-12,
             boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
         )
@@ -1220,7 +1216,7 @@ def test_warn_diffraction_monitor_intersection():
         medium=mat,
     )
     with AssertLogLevel("WARNING"):
-        sim.updated_copy(structures=[box])
+        sim.updated_copy(structures=(box,))


 @pytest.mark.parametrize(
@@ -1247,7 +1243,7 @@ def test_custom_medium_duplicate_coords(custom_class, data_key):
     spatial_data = td.SpatialDataArray(data, coords=coords)

     if custom_class == CustomMedium:
-        with pytest.raises(pydantic.ValidationError, match="duplicate coordinates"):
+        with pytest.raises(ValidationError, match="duplicate coordinates"):
             _ = custom_class(permittivity=spatial_data)
     else:
         field_components = {
@@ -1255,5 +1251,5 @@ def test_custom_medium_duplicate_coords(custom_class, data_key):
         }
         field_dataset = td.FieldDataset(**field_components)

-        with pytest.raises(pydantic.ValidationError, match="duplicate coordinates"):
+        with pytest.raises(ValidationError, match="duplicate coordinates"):
             _ = custom_class(size=SIZE, source_time=ST, **{data_key: field_dataset})
diff --git a/tests/test_components/test_eme.py b/tests/test_components/test_eme.py
index 5400d7bd36..18a3fc2e19 100644
--- a/tests/test_components/test_eme.py
+++ b/tests/test_components/test_eme.py
@@ -1,7 +1,7 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pd
+import pydantic as pd
 import pytest
 from matplotlib import pyplot as plt

@@ -91,11 +91,11 @@ def make_eme_sim():

 def test_sim_version_update():
     sim = make_eme_sim()
-    sim_dict = sim.dict()
+    sim_dict = sim.model_dump()
     sim_dict["version"] = "ancient_version"

     with AssertLogLevel("WARNING"):
-        sim_new = td.EMESimulation.parse_obj(sim_dict)
+        sim_new = td.EMESimulation.model_validate(sim_dict)

     assert sim_new.version == td.__version__

@@ -309,7 +309,7 @@ def test_eme_simulation():
         _ = sim.updated_copy(freqs=None)

     # no symmetry in propagation direction
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(symmetry=(0, 0, 1))

     # test warning for not providing wavelength in autogrid
@@ -327,8 +327,8 @@ def test_eme_simulation():
     )

     # test port offsets
-    with pytest.raises(ValidationError):
-        _ = sim.updated_copy(port_offsets=[sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3])
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(port_offsets=(sim.size[sim.axis] * 2 / 3, sim.size[sim.axis] * 2 / 3))

     # test duplicate freqs
     with pytest.raises(pd.ValidationError):
@@ -347,7 +347,7 @@ def test_eme_simulation():
     med = td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond)
     struct = sim.structures[0].updated_copy(medium=med)
     with pytest.raises(pd.ValidationError):
-        _ = sim.updated_copy(structures=[struct])
+        _ = sim.updated_copy(structures=(struct,))
     # warn for time modulated
     FREQ_MODULATE = 1e12
     AMP_TIME = 1.1
@@ -364,7 +364,7 @@ def test_eme_simulation():
         _ = td.EMESimulation(
             size=sim.size,
             monitors=sim.monitors,
-            structures=[struct],
+            structures=(struct,),
             grid_spec=grid_spec,
             axis=sim.axis,
             eme_grid_spec=sim.eme_grid_spec,
@@ -372,7 +372,8 @@ def test_eme_simulation():
     )
     # warn for nonlinear
     nonlinear = td.Medium(
-        permittivity=2, nonlinear_spec=td.NonlinearSpec(models=[td.NonlinearSusceptibility(chi3=1)])
+        permittivity=2,
+        nonlinear_spec=td.NonlinearSpec(models=(td.NonlinearSusceptibility(chi3=1),)),
     )
     struct = sim.structures[0].updated_copy(medium=nonlinear)
     with AssertLogLevel("WARNING"):
@@ -397,28 +398,28 @@ def test_eme_simulation():

     # test monitor setup
     monitor = sim.monitors[0].updated_copy(freqs=[sim.freqs[0], sim.freqs[0]])
-    with pytest.raises(SetupError):
-        _ = sim.updated_copy(monitors=[monitor])
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(monitors=(monitor,))
     monitor = sim.monitors[0].updated_copy(freqs=[5e10])
-    with pytest.raises(SetupError):
-        _ = sim.updated_copy(monitors=[monitor])
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(monitors=(monitor,))
     monitor = sim.monitors[0].updated_copy(num_modes=1000)
-    with pytest.raises(SetupError):
-        _ = sim.updated_copy(monitors=[monitor])
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(monitors=(monitor,))
     monitor = sim.monitors[2].updated_copy(num_modes=6)
-    with pytest.raises(SetupError):
-        _ = sim.updated_copy(monitors=[monitor])
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(monitors=(monitor,))

     # test monitor at simulation bounds
     monitor = sim.monitors[-1].updated_copy(center=[0, 0, -sim.size[2] / 2])
     with pytest.raises(pd.ValidationError):
-        _ = sim.updated_copy(monitors=[monitor])
+        _ = sim.updated_copy(monitors=(monitor,))

     # test max sim size and freqs
     sim_bad = sim.updated_copy(size=(150, 150, 3))
     with pytest.raises(SetupError):
         sim_bad.validate_pre_upload()
-    sim_bad = sim.updated_copy(size=(50, 50, 3), monitors=[])
+    sim_bad = sim.updated_copy(size=(50, 50, 3), monitors=())
     with AssertLogLevel("WARNING", "slow-down"):
         sim_bad.validate_pre_upload()
@@ -451,13 +452,13 @@ def test_eme_simulation():
     large_monitor = sim.monitors[2].updated_copy(size=(td.inf, td.inf, td.inf))
     _ = sim.updated_copy(
         size=(10, 10, 10),
-        monitors=[large_monitor],
+        monitors=(large_monitor,),
         freqs=list(1e14 * np.linspace(1, 2, 1)),
         grid_spec=sim.grid_spec.updated_copy(wavelength=1),
     )
     sim_bad = sim.updated_copy(
         size=(10, 10, 10),
-        monitors=[large_monitor],
+        monitors=(large_monitor,),
         freqs=list(1e14 * np.linspace(1, 2, 5)),
         grid_spec=sim.grid_spec.updated_copy(wavelength=1),
     )
@@ -476,7 +477,7 @@ def test_eme_simulation():
         sim_bad.updated_copy(store_coeffs=True).validate_pre_upload()
     sim_bad = sim.updated_copy(
         size=(10, 10, 10),
-        monitors=[large_monitor],
+        monitors=(large_monitor,),
         freqs=list(1e14 * np.linspace(1, 2, 20)),
         grid_spec=sim.grid_spec.updated_copy(wavelength=1),
     )
@@ -484,7 +485,7 @@
         sim_bad.validate_pre_upload()
     sim_bad = sim.updated_copy(
         size=(10, 10, 10),
-        monitors=[large_monitor, large_monitor.updated_copy(name="lmon2")],
+        monitors=(large_monitor, large_monitor.updated_copy(name="lmon2")),
         freqs=list(1e14 * np.linspace(1, 2, 5)),
         grid_spec=sim.grid_spec.updated_copy(wavelength=1),
     )
@@ -497,25 +498,25 @@
         center=(0, 0, -1.5),
         name="modes",
     )
-    with pytest.raises(SetupError):
-        _ = sim.updated_copy(monitors=[mode_monitor], port_offsets=(0.5, 0.5))
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(monitors=(mode_monitor,), port_offsets=(0.5, 0.5))

     # test eme cell interval space
     mode_monitor = mode_monitor.updated_copy(
         size=(td.inf, td.inf, td.inf), eme_cell_interval_space=8
     )
-    sim2 = sim.updated_copy(monitors=[mode_monitor])
+    sim2 = sim.updated_copy(monitors=(mode_monitor,))
     assert sim2._monitor_num_eme_cells(monitor=mode_monitor) == 2

     # test monitor num modes
-    sim_tmp = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(num_modes=1)])
+    sim_tmp = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(num_modes=1),))
     assert sim_tmp._monitor_num_modes_cell(monitor=sim_tmp.monitors[0], cell_index=0) == 1

     # test monitor num freqs
-    sim_tmp = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(freqs=[sim.freqs[0]])])
+    sim_tmp = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(freqs=[sim.freqs[0]]),))
     assert sim_tmp._monitor_num_freqs(monitor=sim_tmp.monitors[0]) == 1

     # test sweep
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(
             sweep_spec=td.EMELengthSweep(scale_factors=list(np.linspace(1, 2, 10)))
         )
@@ -535,9 +536,9 @@ def test_eme_simulation():
                 scale_factors=np.stack((np.linspace(1, 2, 7), np.linspace(1, 2, 7)))
             ),
     )
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim_no_field.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[]))
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim_no_field.updated_copy(
             sweep_spec=td.EMELengthSweep(
                 scale_factors=np.stack(
@@ -549,13 +550,13 @@
                 )
             )
         )
     # second shape of length sweep must equal number of cells
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim_no_field.updated_copy(
             sweep_spec=td.EMELengthSweep(scale_factors=np.array([[1, 2], [3, 4]]))
         )
     _ = sim.updated_copy(sweep_spec=td.EMEModeSweep(num_modes=list(np.arange(1, 5))))
     # test sweep size limit
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim_no_field.updated_copy(sweep_spec=td.EMELengthSweep(scale_factors=[]))
     sim_bad = sim_no_field.updated_copy(
         sweep_spec=td.EMELengthSweep(scale_factors=list(np.linspace(1, 2, 200)))
@@ -563,7 +564,7 @@
     with pytest.raises(SetupError):
         sim_bad.validate_pre_upload()
     # can't exceed max num modes
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(sweep_spec=td.EMEModeSweep(num_modes=list(np.arange(150, 200))))

     # don't warn in these two cases
@@ -604,38 +605,38 @@ def test_eme_simulation():
     assert sim._sweep_modes
     assert sim._num_sweep == 2
     assert sim._monitor_num_sweep(sim.monitors[0]) == 1
-    sim = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(num_sweep=None)])
+    sim = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(num_sweep=None),))
     assert sim._monitor_num_sweep(sim.monitors[0]) == 2
-    with pytest.raises(SetupError):
-        _ = sim.updated_copy(monitors=[sim.monitors[0].updated_copy(num_sweep=4)])
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
+        _ = sim.updated_copy(monitors=(sim.monitors[0].updated_copy(num_sweep=4),))
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(sweep_spec=td.EMEFreqSweep(freq_scale_factors=[1e-10, 2]))
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(
             eme_grid_spec=td.EMEExplicitGrid(
-                boundaries=[-sim.size[2] / 2 + 0.001],
-                mode_specs=[td.EMEModeSpec(), td.EMEModeSpec()],
+                boundaries=(-sim.size[2] / 2 + 0.001,),
+                mode_specs=(td.EMEModeSpec(), td.EMEModeSpec()),
             )
         )
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(
             eme_grid_spec=td.EMEExplicitGrid(
-                boundaries=[sim.size[2] / 2 - 0.001],
-                mode_specs=[td.EMEModeSpec(), td.EMEModeSpec()],
+                boundaries=(sim.size[2] / 2 - 0.001,),
+                mode_specs=(td.EMEModeSpec(), td.EMEModeSpec()),
             )
         )
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(
-            monitors=[
+            monitors=(
                 td.ModeSolverMonitor(
-                    center=[0, 0, sim.size[2] / 2 - 0.001],
-                    size=[td.inf, td.inf, 0],
+                    center=(0, 0, sim.size[2] / 2 - 0.001),
+                    size=(td.inf, td.inf, 0),
                     name="modes",
                     freqs=sim.freqs,
                     mode_spec=td.ModeSpec(),
-                )
-            ]
+                ),
+            )
         )

@@ -1032,6 +1033,7 @@ def _get_eme_mode_solver_data(num_sweep=0):
     )


+@pytest.mark.slow
 def _get_eme_field_data(num_sweep=0):
     dataset = _get_eme_field_dataset(num_sweep=num_sweep)
     kwargs = dataset.field_components
@@ -1099,6 +1101,7 @@ def _get_eme_port_modes(num_sweep=0):
     return mode_data.updated_copy(n_complex=n_complex, **kwargs)


+@pytest.mark.slow
 def test_eme_sim_data():
     sim = make_eme_sim()
    mode_monitor_data = _get_eme_mode_solver_data()
@@ -1376,7 +1379,7 @@ def test_eme_sim_data():
     # test field in basis with freq sweep
     field_monitor_data = _get_eme_field_data(num_sweep=10)
     data[2] = field_monitor_data
-    sim_data = sim_data.updated_copy(data=data)
+    sim_data = sim_data.updated_copy(data=tuple(data))
     field_in_basis = sim_data.field_in_basis(field=sim_data["field"], port_index=0)
     assert len(field_in_basis.Ex.sweep_index) == 10
     assert "mode_index" in field_in_basis.Ex.coords
@@ -1420,7 +1423,7 @@ def test_eme_periodicity():
     # directly give it num_reps

     # can't have field monitor
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(num_reps=2, path="eme_grid_spec/subgrids/1")

     # EMEPeriodicitySweep validation
@@ -1428,25 +1431,25 @@ def test_eme_periodicity():
         _ = td.EMEPeriodicitySweep(num_reps=[{"a": n} for n in range(150000, 150003)])
     sweep_spec = td.EMEPeriodicitySweep(num_reps=[{"a": n} for n in range(1, 4)])
     # still can't have field monitor
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(sweep_spec=sweep_spec)
     # remove the field monitor, now it passes
     desired_cell_index_pairs = set([(i, i + 1) for i in range(6)] + [(5, 1)])
     sim = sim.updated_copy(
-        monitors=[m for m in sim.monitors if not isinstance(m, td.EMEFieldMonitor)]
+        monitors=tuple(m for m in sim.monitors if not isinstance(m, td.EMEFieldMonitor))
     )
     sim2 = sim.updated_copy(num_reps=2, path="eme_grid_spec/subgrids/1")
     assert set(sim2._cell_index_pairs) == desired_cell_index_pairs
     # sweep can't have coeff monitor
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(sweep_spec=sweep_spec)
-    with pytest.raises(SetupError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(sweep_spec=sweep_spec, store_coeffs=True, monitors=[])
     # remove coeff monitor too, now it passes
     with AssertLogLevel(None):
         sim = sim.updated_copy(
-            monitors=[m for m in sim.monitors if not isinstance(m, td.EMECoefficientMonitor)]
+            monitors=tuple(m for m in sim.monitors if not isinstance(m, td.EMECoefficientMonitor))
         )
     sim2 = sim.updated_copy(sweep_spec=sweep_spec)
     assert set(sim2._cell_index_pairs) == desired_cell_index_pairs
@@ -1465,10 +1468,10 @@ def test_eme_grid_from_structures():
         names=[None, "wg", None],
         num_reps=[1, 2, 1],
     )
-    sim = sim.updated_copy(eme_grid_spec=eme_grid_spec, monitors=[])
+    sim = sim.updated_copy(eme_grid_spec=eme_grid_spec, monitors=())
     with pytest.raises(ValidationError):
         _ = td.EMECompositeGrid.from_structure_groups(
-            structure_groups=[],
+            structure_groups=(),
             axis=2,
             mode_specs=[],
             names=[None, "wg", None],
@@ -1476,7 +1479,7 @@ def test_eme_grid_from_structures():
         )
     with pytest.raises(ValidationError):
         _ = td.EMECompositeGrid.from_structure_groups(
-            structure_groups=[[], [td.Box(center=(0, 0, 0), size=(1, 1, 1))], []],
+            structure_groups=([], [td.Box(center=(0, 0, 0), size=(1, 1, 1))], []),
            axis=2,
            mode_specs=[td.EMEModeSpec(num_modes=1)] * 2,
            names=[None, "wg", None],
@@ -1541,6 +1544,6 @@ def test_eme_sim_2d():
         axis=2,
         freqs=[freq0],
         eme_grid_spec=eme_grid_spec,
-        monitors=[monitor],
+        monitors=(monitor,),
         port_offsets=(0.5, 0),
     )
diff --git a/tests/test_components/test_field_projection.py b/tests/test_components/test_field_projection.py
index 7d74bf02b1..683c1387e5 100644
--- a/tests/test_components/test_field_projection.py
+++ b/tests/test_components/test_field_projection.py
@@ -4,9 +4,9 @@

 import autograd.numpy as anp
 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
 from autograd import make_vjp
+from pydantic import ValidationError

 import tidy3d as td
 from tidy3d.components.field_projection import FieldProjector, _far_field_integral
@@ -287,7 +287,7 @@ def test_proj_clientside():
     sim = td.Simulation(
         size=sim_size,
         grid_spec=td.GridSpec.auto(wavelength=td.C_0 / f0),
-        monitors=[monitor],
+        monitors=(monitor,),
         run_time=1e-12,
     )
@@ -622,7 +622,7 @@ def test_2d_sim_with_proj_monitors_near():

     # Modify only proj_distance and far_field_approx
     proj_monitors_near = [
-        monitor.__class__(
+        type(monitor)(
             proj_distance=R_FAR / 50,  # Adjust projection distance
             far_field_approx=False,  # Disable far-field approximation
             **{
@@ -635,7 +635,7 @@ def test_2d_sim_with_proj_monitors_near():
     ]

     with pytest.raises(
-        pydantic.ValidationError,
+        ValidationError,
         match="Exact far-field projection for 2D simulations is not yet available",
     ):
         _ = td.Simulation(
diff --git a/tests/test_components/test_geometry.py b/tests/test_components/test_geometry.py
index 08fead6b09..92a3fd66c3 100644
--- a/tests/test_components/test_geometry.py
+++ b/tests/test_components/test_geometry.py
@@ -8,7 +8,7 @@
 import gdstk
 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pydantic
+import pydantic as pd
 import pytest
 import shapely
 import trimesh
@@ -23,7 +23,7 @@
 )

 import tidy3d as td
-from tidy3d.compat import _shapely_is_older_than
+from tidy3d.compat import _package_is_older_than
 from tidy3d.components.geometry.base import cleanup_shapely_object
 from tidy3d.components.geometry.mesh import AREA_SIZE_THRESHOLD
 from tidy3d.components.geometry.utils import (
@@ -50,10 +50,10 @@
 CYLINDER = td.Cylinder(axis=2, length=1, radius=1)

 GROUP = td.GeometryGroup(
-    geometries=[
+    geometries=(
         td.Box(center=(-0.25, 0, 0), size=(0.5, 1, 1)),
         td.Box(center=(0.25, 0, 0), size=(0.5, 1, 1)),
-    ]
+    )
 )
 UNION = td.ClipOperation(
     operation="union",
@@ -266,16 +266,16 @@ def test_intersections_plane_quad_segs(component, quad_segs):

 def test_center_not_inf_validate():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Box(center=(td.inf, 0, 0))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Box(center=(-td.inf, 0, 0))


 def test_radius_not_inf_validate():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Sphere(radius=td.inf)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(radius=td.inf, center=(0, 0, 0), axis=1, length=1)

@@ -292,7 +292,7 @@ def test_slanted_cylinder_infinite_length_validate():
         sidewall_angle=0.1,
         reference_plane="middle",
     )
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(
             radius=1,
             center=(0, 0, 0),
@@ -301,7 +301,7 @@ def test_slanted_cylinder_infinite_length_validate():
             sidewall_angle=0.1,
             reference_plane="top",
         )
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(
             radius=1,
             center=(0, 0, 0),
@@ -362,7 +362,7 @@ def test_polyslab_inf_bounds(lower_bound, upper_bound):

 def test_polyslab_bounds():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.PolySlab(vertices=((0, 0), (1, 0), (1, 1)), slab_bounds=(0.5, -0.5), axis=2)

@@ -398,15 +398,15 @@ def test_polyslab_inf_to_finite_bounds(axis):

 def test_validate_polyslab_vertices_valid():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         POLYSLAB.copy(update={"vertices": (1, 2, 3)})
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         crossing_verts = ((0, 0), (1, 1), (0, 1), (1, 0))
         POLYSLAB.copy(update={"vertices": crossing_verts})


 def test_sidewall_failed_validation():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         POLYSLAB.copy(update={"sidewall_angle": 1000})

@@ -440,7 +440,7 @@ def test_gdstk_cell():

 def make_geo_group():
     """Make a generic Geometry Group."""
-    boxes = [td.Box(size=(1, 1, 1), center=(i, 0, 0)) for i in range(-5, 5)]
+    boxes = tuple(td.Box(size=(1, 1, 1), center=(i, 0, 0)) for i in range(-5, 5))
     return td.GeometryGroup(geometries=boxes)

@@ -468,8 +468,8 @@ def test_geo_group_methods():

 def test_geo_group_empty():
     """don't allow empty geometry list."""
-    with pytest.raises(pydantic.ValidationError):
-        _ = td.GeometryGroup(geometries=[])
+    with pytest.raises(pd.ValidationError):
+        _ = td.GeometryGroup(geometries=())


 def test_geo_group_volume():
@@ -644,22 +644,22 @@ def test_flattening():
     flat = list(
         flatten_groups(
             td.GeometryGroup(
-                geometries=[
+                geometries=(
                     td.Box(size=(1, 1, 1)),
                     td.Box(size=(0, 1, 0)),
                     td.ClipOperation(
                         operation="union",
                         geometry_a=td.Box(size=(0, 0, 1)),
                         geometry_b=td.GeometryGroup(
-                            geometries=[
+                            geometries=(
                                 td.Box(size=(2, 2, 2)),
                                 td.GeometryGroup(
-                                    geometries=[td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3))]
+                                    geometries=(td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3)))
                                 ),
-                            ]
+                            )
                         ),
                     ),
-                ]
+                )
             )
         )
     )
@@ -669,22 +669,22 @@ def test_flattening():
     flat = list(
         flatten_groups(
             td.GeometryGroup(
-                geometries=[
+                geometries=(
                     td.Box(size=(1, 1, 1)),
                     td.Box(size=(0, 1, 0)),
                     td.ClipOperation(
                         operation="intersection",
                         geometry_a=td.Box(size=(0, 0, 1)),
                         geometry_b=td.GeometryGroup(
-                            geometries=[
+                            geometries=(
                                 td.Box(size=(2, 2, 2)),
                                 td.GeometryGroup(
-                                    geometries=[td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3))]
+                                    geometries=(td.Box(size=(3, 3, 3)), td.Box(size=(3, 0, 3)))
                                 ),
-                            ]
+                            )
                         ),
                     ),
-                ]
+                )
             )
         )
     )
@@ -727,15 +727,15 @@ def test_geometry_traversal():
     assert len(geometries) == 1

     geo_tree = td.GeometryGroup(
-        geometries=[
+        geometries=(
             td.Box(size=(1, 0, 0)),
             td.ClipOperation(
                 operation="intersection",
                 geometry_a=td.GeometryGroup(
-                    geometries=[
+                    geometries=(
                         td.Box(size=(5, 0, 0)),
                         td.Box(size=(6, 0, 0)),
-                    ]
+                    )
                 ),
                 geometry_b=td.ClipOperation(
                     operation="difference",
@@ -744,13 +744,13 @@ def test_geometry_traversal():
                 ),
             ),
             td.GeometryGroup(
-                geometries=[
+                geometries=(
                     td.Box(size=(3, 0, 0)),
                     td.Box(size=(4, 0, 0)),
-                ]
+                )
             ),
             td.Box(size=(2, 0, 0)),
-        ]
+        )
     )
     geometries = list(traverse_geometries(geo_tree))
     assert len(geometries) == 13
@@ -768,34 +768,34 @@ def test_geometry():
     #     _ = PolySlab(vertices=vertices_np, slab_bounds=(-1, 1), axis=1)

     # make sure wrong axis arguments error
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(radius=1, center=(0, 0, 0), axis=-1, length=1)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.PolySlab(radius=1, center=(0, 0, 0), axis=-1, slab_bounds=(-0.5, 0.5))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(radius=1, center=(0, 0, 0), axis=3, length=1)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.PolySlab(radius=1, center=(0, 0, 0), axis=3, slab_bounds=(-0.5, 0.5))

     # make sure negative values error
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Sphere(radius=-1, center=(0, 0, 0))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(radius=-1, center=(0, 0, 0), axis=3, length=1)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Cylinder(radius=1, center=(0, 0, 0), axis=3, length=-1)


 def test_geometry_sizes():
     # negative in size kwargs errors
     for size in (-1, 1, 1), (1, -1, 1), (1, 1, -1):
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(pd.ValidationError):
             _ = td.Box(size=size, center=(0, 0, 0))
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(pd.ValidationError):
             _ = td.Simulation(size=size, run_time=1e-12, grid_spec=td.GridSpec(wavelength=1.0))

     # negative grid sizes error?
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Simulation(size=(1, 1, 1), grid_spec=td.GridSpec.uniform(dl=-1.0), run_time=1e-12)

@@ -936,11 +936,80 @@ def test_polyslab_intersection_inf_bounds():
     assert poly.intersections_plane(x=0)[0] == shapely.box(-1, 0.0, 1, LARGE_NUMBER)

     # 2) [-inf, 0]
-    poly = poly.updated_copy(slab_bounds=[-td.inf, 0])
+    poly = poly.updated_copy(slab_bounds=(-td.inf, 0))
     assert len(poly.intersections_plane(x=0)) == 1
     assert poly.intersections_plane(x=0)[0] == shapely.box(-1, -LARGE_NUMBER, 1, 0)


+def test_polyslab_intersection_with_coincident_plane():
+    """Test if intersection returns the correct shape when the plane is coincident with the side face."""
+    poly = td.PolySlab(
+        vertices=[[500.0, -7500.0], [500.0, 7500.0], [-500.0, 7500.0], [-500.0, -7500.0]],
+        slab_bounds=[0, 50],
+        axis=2,
+    )
+    # Each case should give one side face of the polyslab
+    expected_x_face = shapely.box(-7500, 0, 7500, 50)  # y-extent × z-extent
+    expected_y_face = shapely.box(-500, 0, 500, 50)  # x-extent × z-extent
+
+    assert poly.intersections_plane(x=-500) == [expected_x_face]
+    assert poly.intersections_plane(x=500) == [expected_x_face]
+    assert poly.intersections_plane(y=-7500) == [expected_y_face]
+    assert poly.intersections_plane(y=7500) == [expected_y_face]
+
+
+def test_polyslab_intersection_rotated_square():
+    """Test PolySlab plane intersection with a rotated square (diamond shape)."""
+    # Create a diamond by rotating a square 45 degrees
+    size = 2.0
+    angle = np.pi / 4
+    base_vertices = np.array(
+        [[-size / 2, -size / 2], [size / 2, -size / 2], [size / 2, size / 2], [-size / 2, size / 2]]
+    )
+    cos_a, sin_a = np.cos(angle), np.sin(angle)
+    rotation = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
+    rotated = base_vertices @ rotation.T
+    rotated = rotated - rotated.min(axis=0) + 0.5  # shift to positive quadrant
+    vertices = [tuple(v) for v in rotated]
+
+    polyslab = td.PolySlab(vertices=vertices, slab_bounds=(0, 3), axis=2)
+
+    all_verts = np.array(vertices)
+    left_tip_x = all_verts[:, 0].min()
+    bottom_tip_y = all_verts[:, 1].min()
+    x_center = (all_verts[:, 0].min() + all_verts[:, 0].max()) / 2
+
+    # Test 1: Cut at z=1.5 (middle of slab) - should give full diamond
+    cross_section = polyslab.intersections_plane(z=1.5)
+    assert len(cross_section) == 1
+    assert np.isclose(cross_section[0].area, 4.0)
+
+    # Test 2: Cut through center at x=x_center - should give rectangle
+    cross_section = polyslab.intersections_plane(x=x_center)
+    assert len(cross_section) == 1
+    assert cross_section[0].area > 0
+
+    # Test 3: Cut at left corner tip (tangent touch) - should give degenerate shape
+    cross_section = polyslab.intersections_plane(x=left_tip_x)
+    assert len(cross_section) == 1
+    assert np.isclose(cross_section[0].area, 0.0)
+
+    # Test 4: Cut near left corner (slightly inside) - should give small shape
+    cross_section = polyslab.intersections_plane(x=left_tip_x + 0.3)
+    assert len(cross_section) == 1
+    assert cross_section[0].area > 0
+
+    # Test 5: Cut at bottom corner (tangent touch) - should give degenerate shape
+    cross_section = polyslab.intersections_plane(y=bottom_tip_y)
+    assert len(cross_section) == 1
+    assert np.isclose(cross_section[0].area, 0.0)
+
+    # Test 6: Cut at z=0 (bottom boundary) - should give full diamond
+    cross_section = polyslab.intersections_plane(z=0)
+    assert len(cross_section) == 1
+    assert np.isclose(cross_section[0].area, 4.0)
+
+
 def test_from_shapely():
     ring = shapely.LinearRing([(-16, 9), (-8, 9), (-12, 2)])
     poly = shapely.Polygon([(-2, 0), (-10, 0), (-6, 7)])
@@ -1123,7 +1192,7 @@ def test_triangle_mesh_to_stl_roundtrip(tmp_path, binary):
 def test_geo_group_sim():
     geo_grp = td.TriangleMesh.from_stl("tests/data/two_boxes_separate.stl")
     geos_orig = list(geo_grp.geometries)
-    geo_grp_full = geo_grp.updated_copy(geometries=[*geos_orig, td.Box(size=(1, 1, 1))])
+    geo_grp_full = geo_grp.updated_copy(geometries=(*geos_orig, td.Box(size=(1, 1, 1))))

     sim = td.Simulation(
         size=(10, 10, 10),
@@ -1140,7 +1209,7 @@ def test_geo_group_sim():

 def test_finite_geometry_transformation():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Box(size=(td.inf, 0, 1)).scaled(1, 1, 1)

@@ -1199,7 +1268,7 @@ def test_subdivide():
     overlapping_boxes = td.GeometryGroup(geometries=(box, overlap_box))
     background_structure = td.Structure(medium=td.Medium(), geometry=td.Box(size=(10, 10, 10)))

-    subdivisions = subdivide(geom=overlapping_boxes, structures=[background_structure])
+    subdivisions = subdivide(geom=overlapping_boxes, structures=(background_structure,))
     assert len(subdivisions) == 1

     # Test that when a small sliver is created during subdivide
@@ -1207,7 +1276,7 @@ def test_subdivide():
     box_sliver = td.Structure(
         medium=td.Medium(), geometry=td.Box(size=(1, 1, 1), center=(1 - fp_eps, 0, 0))
     )
-    subdivisions = subdivide(geom=overlapping_boxes, structures=[background_structure, box_sliver])
+    subdivisions = subdivide(geom=overlapping_boxes, structures=(background_structure, box_sliver))


 def test_subdivide_geometry_group_with_polygon_holes():
@@ -1470,7 +1539,7 @@ def wrong_shape_height_func(x, y):

 def test_cleanup_shapely_object():
-    if _shapely_is_older_than("2.1"):
+    if _package_is_older_than("shapely", "2.1"):
         # (Old versions of shapely don't support `shapely.make_valid()` with the correct arguments.
         # However older alternatives like `.buffer(0)` are not as robust. `.buffer(0)` is likely
         # to generate polygons which look correct, but have extra vertices, causing test to fail.
diff --git a/tests/test_components/test_grid.py b/tests/test_components/test_grid.py
index f3f3b6310b..3eda1b7d3a 100644
--- a/tests/test_components/test_grid.py
+++ b/tests/test_components/test_grid.py
@@ -25,6 +25,99 @@ def test_coords():
     _ = Coords(x=x, y=y, z=z)


+def test_coords_arrays_are_immutable():
+    """Test that arrays in Coords objects are immutable.
+
+    This ensures that numpy arrays in Pydantic models cannot be modified,
+    enforcing true immutability for these data structures.
+ """ + + # Create original arrays + x_orig = np.array([1.0, 2.0, 3.0]) + y_orig = np.array([4.0, 5.0, 6.0]) + z_orig = np.array([7.0, 8.0, 9.0]) + + # Create Coords object + coords = Coords(x=x_orig, y=y_orig, z=z_orig) + + # Get dictionary + coord_dict = coords.to_dict + + # Verify we got the right values + assert np.array_equal(coord_dict["x"], x_orig) + assert np.array_equal(coord_dict["y"], y_orig) + assert np.array_equal(coord_dict["z"], z_orig) + + # Verify arrays are not writeable + assert not coord_dict["x"].flags.writeable + assert not coord_dict["y"].flags.writeable + assert not coord_dict["z"].flags.writeable + + # Attempting to modify the arrays should raise an error + with pytest.raises(ValueError, match="output array is read-only"): + coord_dict["x"] -= 10 + + with pytest.raises(ValueError, match="output array is read-only"): + coord_dict["y"] *= 2 + + with pytest.raises(ValueError, match="output array is read-only"): + coord_dict["z"] += 100 + + # Arrays should still have original values + assert np.array_equal(coord_dict["x"], x_orig) + assert np.array_equal(coord_dict["y"], y_orig) + assert np.array_equal(coord_dict["z"], z_orig) + + +def test_grid_boundaries_modification_pattern(): + """Test the pattern of modifying grid boundaries after retrieval. + + This demonstrates that arrays are immutable and shows the correct + pattern for creating modified versions. + """ + + # Create a grid for testing boundary modification + boundaries_x = np.array([-1.0, 0.0, 1.0]) + boundaries_y = np.array([-1.0, 0.0, 1.0]) + boundaries_z = np.array([-1.0, 0.0, 1.0]) + coords = Coords(x=boundaries_x, y=boundaries_y, z=boundaries_z) + grid = Grid(boundaries=coords) + + # Store original boundary values + original_x = grid.boundaries.x.copy() + original_y = grid.boundaries.y.copy() + original_z = grid.boundaries.z.copy() + + # Get boundaries dictionary + boundaries = grid.boundaries.to_dict + center = [0.5, 0.5, 0.5] # Simulate an offset value + + # Verify that direct modification fails due to immutability + with pytest.raises(ValueError, match="output array is read-only"): + boundaries["x"] -= center[0] + + # Show the correct pattern: make copies when modification is needed + boundaries_copy = {k: v.copy() for k, v in boundaries.items()} + + # Now we can modify the copies + for dim, dim_name in enumerate(boundaries_copy.keys()): + boundaries_copy[dim_name] -= center[dim] + + # Create a new grid with modified boundaries + offset_coords = Coords(**boundaries_copy) + offset_grid = Grid(boundaries=offset_coords) + + # Verify original grid is unchanged + assert np.array_equal(grid.boundaries.x, original_x) + assert np.array_equal(grid.boundaries.y, original_y) + assert np.array_equal(grid.boundaries.z, original_z) + + # Verify offset grid has the expected modified values + assert np.array_equal(offset_grid.boundaries.x, original_x - 0.5) + assert np.array_equal(offset_grid.boundaries.y, original_y - 0.5) + assert np.array_equal(offset_grid.boundaries.z, original_z - 0.5) + + def test_field_grid(): x = np.linspace(-1, 1, 100) y = np.linspace(-1, 1, 100) @@ -45,7 +138,7 @@ def test_grid(): assert np.all(g.centers.z == np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5])) for dim in "xyz": - s = g.sizes.dict()[dim] + s = g.sizes.model_dump()[dim] assert np.all(np.array(s) == 1.0) assert np.all(g.yee.E.x.x == np.array([-0.5, 0.5])) @@ -212,11 +305,11 @@ def test_sim_grid(): ) for dim in "xyz": - c = sim.grid.centers.dict()[dim] + c = sim.grid.centers.model_dump()[dim] assert np.all(c == np.array([-1.5, -0.5, 
0.5, 1.5])) for dim in "xyz": - b = sim.grid.boundaries.dict()[dim] + b = sim.grid.boundaries.model_dump()[dim] assert np.all(b == np.array([-2, -1, 0, 1, 2])) @@ -265,11 +358,11 @@ def test_sim_pml_grid(): ) for dim in "xyz": - c = sim.grid.centers.dict()[dim] + c = sim.grid.centers.model_dump()[dim] assert np.all(c == np.arange(-7.5, 8, 1)) for dim in "xyz": - b = sim.grid.boundaries.dict()[dim] + b = sim.grid.boundaries.model_dump()[dim] assert np.all(b == np.arange(-8, 8.5, 1)) @@ -286,11 +379,11 @@ def test_sim_discretize_vol(): subgrid = sim.discretize(vol) for dim in "xyz": - b = subgrid.boundaries.dict()[dim] + b = subgrid.boundaries.model_dump()[dim] assert np.all(b == np.array([-1, 0, 1])) for dim in "xyz": - c = subgrid.centers.dict()[dim] + c = subgrid.centers.model_dump()[dim] assert np.all(c == np.array([-0.5, 0.5])) _ = td.Box(size=(6, 6, 0)) diff --git a/tests/test_components/test_grid_spec.py b/tests/test_components/test_grid_spec.py index 7ac2b7eb71..444ff2c21d 100644 --- a/tests/test_components/test_grid_spec.py +++ b/tests/test_components/test_grid_spec.py @@ -3,8 +3,8 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.exceptions import SetupError @@ -317,7 +317,7 @@ def test_zerosize_dimensions(): assert np.allclose(sim.grid.boundaries.y, [-dl / 2, dl / 2]) - with pytest.raises(SetupError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(5, 0, 10), boundary_spec=td.BoundarySpec.pec( @@ -333,7 +333,7 @@ def test_zerosize_dimensions(): run_time=1e-12, ) - with pytest.raises(SetupError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(5, 3, 10), boundary_spec=td.BoundarySpec.pec( @@ -532,7 +532,7 @@ def test_domain_mismatch(): def test_uniform_grid_dl_validation(dl, expect_exception): """Test the validator that checks 'dl' is between 1e-7 and 3e8 µm.""" if expect_exception: - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), grid_spec=td.GridSpec.uniform(dl=dl), @@ -549,8 +549,8 @@ def test_uniform_grid_dl_validation(dl, expect_exception): def test_custom_grid_boundary_validation(): """Tests that the 'coords' is at least length 2 and sorted in ascending order.""" - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CustomGridBoundaries(coords=[10]) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.CustomGridBoundaries(coords=[9, 10, 9, 10, 11, 9, 8]) diff --git a/tests/test_components/test_heat.py b/tests/test_components/test_heat.py index 4b4ab27338..b35f9def11 100644 --- a/tests/test_components/test_heat.py +++ b/tests/test_components/test_heat.py @@ -1,9 +1,9 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pd import pytest from matplotlib import pyplot as plt +from pydantic import ValidationError import tidy3d as td from tidy3d import ( @@ -54,10 +54,10 @@ def make_heat_mediums(): def test_heat_medium(): _, solid_medium = make_heat_mediums() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = solid_medium.heat_spec.updated_copy(capacity=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = solid_medium.heat_spec.updated_copy(conductivity=-1) # check we can create solid medium from SI units @@ -116,13 +116,13 @@ def make_heat_bcs(): def test_heat_bcs(): bc_temp, 
bc_flux, bc_conv = make_heat_bcs() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = TemperatureBC(temperature=-10) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = ConvectionBC(ambient_temperature=-400, transfer_coeff=0.2) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = ConvectionBC(ambient_temperature=400, transfer_coeff=-0.2) # Test vertical natural convection model in ConvectionBC @@ -137,11 +137,11 @@ def test_heat_bcs(): name="air", ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.VerticalNaturalConvectionCoeffModel(medium=air.heat, plate_length=-10) _, solid_medium = make_heat_mediums() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): td.VerticalNaturalConvectionCoeffModel(medium=solid_medium.heat_spec, plate_length=1e5) @@ -163,10 +163,10 @@ def make_heat_mnts(): def test_heat_mnt(): temp_mnt, _, _, _, _, _ = make_heat_mnts() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = temp_mnt.updated_copy(name=None) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = temp_mnt.updated_copy(size=(-1, 2, 3)) @@ -276,20 +276,20 @@ def make_distance_grid_spec(): def test_grid_spec(): grid_spec = make_uniform_grid_spec() - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(dl=0) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(min_edges_per_circumference=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(min_edges_per_side=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(relative_min_dl=-1e-4) grid_spec = make_distance_grid_spec() _ = grid_spec.updated_copy(relative_min_dl=0) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(dl_interface=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = grid_spec.updated_copy(distance_interface=2, distance_bulk=1) @@ -309,8 +309,8 @@ def make_custom_heat_source(): def test_heat_source(): source = make_heat_source() source = make_custom_heat_source() - with pytest.raises(pd.ValidationError): - _ = source.updated_copy(structures=[]) + with pytest.raises(ValidationError): + _ = source.updated_copy(structures=()) def make_heat_sim(include_custom_source: bool = True): @@ -374,23 +374,25 @@ def test_heat_sim(): condition=bc_temp, placement=StructureSimulationBoundary(structure="no_mesh") ), ]: - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(boundary_spec=[pl]) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(boundary_spec=(pl,)) - with pytest.raises(pd.ValidationError): - _ = heat_sim.updated_copy(sources=[HeatSource(structures=["noname"])], rate=-10) + with pytest.raises(ValidationError): + _ = heat_sim.updated_copy(sources=(HeatSource(structures=["noname"]),), rate=-10) # run 2D case - _ = heat_sim.updated_copy(center=(0.7, 0, 0), size=(0, 2, 2), monitors=heat_sim.monitors[:5]) + _ = heat_sim.updated_copy( + center=(0.7, 0, 0), size=(0, 2, 2), monitors=tuple(heat_sim.monitors[:5]) + ) # test unsupported 1D heat domains - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = heat_sim.updated_copy(center=(1, 1, 1), size=(1, 0, 0)) 
     temp_mnt = heat_sim.monitors[0]
-    with pytest.raises(pd.ValidationError):
-        heat_sim.updated_copy(monitors=[temp_mnt, temp_mnt])
+    with pytest.raises(ValidationError):
+        heat_sim.updated_copy(monitors=(temp_mnt, temp_mnt))

     _ = heat_sim.plot(x=0)
     plt.close()
@@ -403,7 +405,7 @@ def test_heat_sim():
     plt.close()

     # no negative symmetry
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = heat_sim.updated_copy(symmetry=(-1, 0, 1))

     # no SolidSpec in the entire simulation
@@ -412,16 +414,16 @@ def test_heat_sim():
     )

     solid_med = heat_sim.structures[1].medium
-    _ = heat_sim.updated_copy(structures=[], medium=solid_med, sources=[], boundary_spec=[bc_spec])
-    with pytest.raises(pd.ValidationError):
-        _ = heat_sim.updated_copy(structures=[], sources=[], boundary_spec=[bc_spec], monitors=[])
+    _ = heat_sim.updated_copy(structures=(), medium=solid_med, sources=(), boundary_spec=(bc_spec,))
+    with pytest.raises(ValidationError):
+        _ = heat_sim.updated_copy(structures=(), sources=(), boundary_spec=(bc_spec,), monitors=())
     _ = heat_sim.updated_copy(
-        structures=[heat_sim.structures[0]], medium=solid_med, boundary_spec=[bc_spec], sources=[]
+        structures=(heat_sim.structures[0],), medium=solid_med, boundary_spec=(bc_spec,), sources=()
     )
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = heat_sim.updated_copy(
-            structures=[heat_sim.structures[0]], boundary_spec=[bc_spec], sources=[], monitors=[]
+            structures=(heat_sim.structures[0],), boundary_spec=(bc_spec,), sources=(), monitors=()
         )

     # 1D and 2D structures
@@ -433,18 +435,18 @@ def test_heat_sim():
         geometry=td.Box(size=(1, 0, 1)),
         medium=heat_sim.medium,
     )
-    with pytest.raises(pd.ValidationError):
-        _ = heat_sim.updated_copy(structures=[*list(heat_sim.structures), struct_1d])
+    with pytest.raises(ValidationError):
+        _ = heat_sim.updated_copy(structures=(*heat_sim.structures, struct_1d))

-    with pytest.raises(pd.ValidationError):
-        _ = heat_sim.updated_copy(structures=[*list(heat_sim.structures), struct_2d])
+    with pytest.raises(ValidationError):
+        _ = heat_sim.updated_copy(structures=(*heat_sim.structures, struct_2d))

     # no data expected inside a monitor
     for mnt_size in [(0.2, 0.2, 0.2), (0, 1, 1), (0, 2, 0), (0, 0, 0)]:
         temp_mnt = td.TemperatureMonitor(center=(0, 0, 0), size=mnt_size, name="test")
-        with pytest.raises(pd.ValidationError):
-            _ = heat_sim.updated_copy(monitors=[temp_mnt])
+        with pytest.raises(ValidationError):
+            _ = heat_sim.updated_copy(monitors=(temp_mnt,))


 @pytest.mark.parametrize("shift_amount, log_level", ((1, None), (2, "WARNING")))
@@ -561,15 +563,15 @@ def test_sim_data():
     with pytest.raises(KeyError):
         _ = heat_sim_data.plot_field("test3", x=0)

-    with pytest.raises(pd.ValidationError):
-        _ = heat_sim_data.updated_copy(data=[heat_sim_data.data[0]] * 2)
+    with pytest.raises(ValidationError):
+        _ = heat_sim_data.updated_copy(data=(heat_sim_data.data[0],) * 2)

     temp_mnt = TemperatureMonitor(size=(1, 2, 3), name="test")
     temp_mnt = temp_mnt.updated_copy(name="test2")

-    sim = heat_sim_data.simulation.updated_copy(monitors=[temp_mnt])
+    sim = heat_sim_data.simulation.updated_copy(monitors=(temp_mnt,))

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = heat_sim_data.updated_copy(simulation=sim)

@@ -649,11 +651,11 @@ def test_relative_min_dl_warning():

 def test_sim_version_update():
     heat_sim = make_heat_sim()
-    heat_sim_dict = heat_sim.dict()
+    heat_sim_dict = heat_sim.model_dump()
     heat_sim_dict["version"] = "ancient_version"

     with AssertLogLevel("WARNING"):
-        heat_sim_new = td.HeatSimulation.parse_obj(heat_sim_dict)
+        heat_sim_new = td.HeatSimulation.model_validate(heat_sim_dict)

     assert heat_sim_new.version == td.__version__

@@ -746,13 +748,13 @@ def test_unsteady_setup():
     )

     heat_sim = heat_sim.updated_copy(
-        structures=[solid_structure],
+        structures=(solid_structure,),
         analysis_spec=unsteady_spec,
-        monitors=[temp_mnt],
-        boundary_spec=[bc],
+        monitors=(temp_mnt,),
+        boundary_spec=(bc,),
     )

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         solid_medium = td.MultiPhysicsMedium(
             heat=td.SolidMedium(
                 conductivity=3,
@@ -760,9 +762,9 @@ def test_unsteady_setup():
             name="solid_medium",
         )
         new_struct = solid_structure.updated_copy(medium=solid_medium)
-        _ = heat_sim.updated_copy(structures=[new_struct])
+        _ = heat_sim.updated_copy(structures=(new_struct,))

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         solid_medium = td.MultiPhysicsMedium(
             heat=td.SolidMedium(
                 conductivity=3,
@@ -771,9 +773,9 @@ def test_unsteady_setup():
             name="solid_medium",
         )
         new_struct = solid_structure.updated_copy(medium=solid_medium)
-        _ = heat_sim.updated_copy(structures=[new_struct])
+        _ = heat_sim.updated_copy(structures=(new_struct,))

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         solid_medium = td.MultiPhysicsMedium(
             heat=td.SolidMedium(
                 conductivity=3,
@@ -782,4 +784,4 @@ def test_unsteady_setup():
             name="solid_medium",
         )
         new_struct = solid_structure.updated_copy(medium=solid_medium)
-        _ = heat_sim.updated_copy(structures=[new_struct])
+        _ = heat_sim.updated_copy(structures=(new_struct,))
diff --git a/tests/test_components/test_heat_charge.py b/tests/test_components/test_heat_charge.py
index 6d0ed00655..bad5a1fed6 100644
--- a/tests/test_components/test_heat_charge.py
+++ b/tests/test_components/test_heat_charge.py
@@ -3,9 +3,9 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pd
 import pytest
 from matplotlib import pyplot as plt
+from pydantic import ValidationError

 import tidy3d as td
 from tidy3d.components.tcad.simulation.heat_charge import TCADAnalysisTypes
@@ -837,22 +837,22 @@ def test_heat_charge_medium_validation(mediums):
     solid_medium = mediums["solid_medium"]

     # Test invalid capacity
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         solid_medium.heat_spec.updated_copy(capacity=-1)

     # Test invalid conductivity
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         solid_medium.heat_spec.updated_copy(conductivity=-1)

     # Test invalid charge conductivity
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         solid_medium.charge.updated_copy(conductivity=-1)


 def test_constant_mobility():
     constant_mobility = td.ConstantMobilityModel(mu=1500)

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = constant_mobility.updated_copy(mu=-1)

@@ -876,33 +876,52 @@ def test_heat_charge_bcs_validation(boundary_conditions):
     bc_temp, bc_flux, bc_conv, bc_volt, bc_current = boundary_conditions

     # Invalid TemperatureBC
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.TemperatureBC(temperature=-10)

     # Invalid ConvectionBC: negative ambient temperature
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.ConvectionBC(ambient_temperature=-400, transfer_coeff=0.2)

     # Invalid ConvectionBC: negative transfer coefficient
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.ConvectionBC(ambient_temperature=400, transfer_coeff=-0.2)

     # Invalid VoltageBC: infinite voltage
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.VoltageBC(source=td.DCVoltageSource(voltage=[td.inf]))

     # Invalid CurrentBC: infinite current density
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.CurrentBC(source=td.DCCurrentSource(current=td.inf))

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.VoltageBC(source=td.DCVoltageSource(voltage=np.array([td.inf, 0, 1])))

     # Invalid SSACVoltageSource: infinite voltage
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.VoltageBC(source=td.SSACVoltageSource(voltage=np.array([td.inf, 0, 1]), amplitude=1e-2))


+def test_repeated_voltage_warning():
+    """Test that a warning is raised when repeated voltage values are present."""
+    # No warning for unique values
+    with AssertLogLevel(None):
+        td.DCVoltageSource(voltage=[0, 1, 2, 3])
+
+    # Warning for repeated values
+    with AssertLogLevel("WARNING"):
+        td.DCVoltageSource(voltage=[1, 2, 2, 3])
+
+    # Warning for 0 and -0 (treated as duplicates)
+    with AssertLogLevel("WARNING"):
+        td.DCVoltageSource(voltage=[0.0, -0.0, 1, 2])
+
+    # Warning for multiple repeated values
+    with AssertLogLevel("WARNING"):
+        td.DCVoltageSource(voltage=[1, 1, 2, 2, 3])
+
+
 def test_freqs_validation():
     """Test validation that freqs requires SSACVoltageSource."""
     solid_box_1 = td.Box(center=(0, 0, 0), size=(2, 2, 2))
@@ -968,7 +987,7 @@ def test_freqs_validation():

     # Test that freqs without SSACVoltageSource raises error
     with pytest.raises(
-        pd.ValidationError,
+        ValidationError,
         match="If 'freqs' is provided and not empty, at least one 'SSACVoltageSource' must be present in the boundary conditions.",
     ):
         sim.updated_copy(
@@ -985,10 +1004,10 @@ def test_freqs_validation():
     assert np.isclose(freqs, freqs_input).all()
     assert np.isclose(1e-3, amplitude)

-    with pytest.raises(pd.ValidationError, match="'freqs' cannot contain infinite frequencies."):
+    with pytest.raises(ValidationError, match="'freqs' cannot contain infinite frequencies."):
         sim.updated_copy(analysis_spec=sim.analysis_spec.updated_copy(freqs=[1e2, np.inf]))

-    with pytest.raises(pd.ValidationError, match="'freqs' cannot contain negative frequencies."):
+    with pytest.raises(ValidationError, match="'freqs' cannot contain negative frequencies."):
         sim.updated_copy(analysis_spec=sim.analysis_spec.updated_copy(freqs=[1e2, -1e2]))

@@ -1061,7 +1080,7 @@ def test_vertical_natural_convection():

     # Verify that placing the model on an interface between two solid media
     # correctly raises a validation error.
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         sim.updated_copy(
             structures=[solid_structure_l, solid_structure_r],
             boundary_spec=[
@@ -1077,7 +1096,7 @@ def test_vertical_natural_convection():
     incomplete_air = td.MultiPhysicsMedium(
         heat=td.FluidMedium(expansivity=1 / 300.0), name="incomplete_air"
     )
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         new_fluid_structure_r = fluid_structure_r.updated_copy(medium=incomplete_air)
         sim.updated_copy(
             structures=[solid_structure_l, new_fluid_structure_r],
@@ -1105,7 +1124,7 @@ def test_vertical_natural_convection():
     # Verify that a validation error is raised if the medium supplied directly to the
     # coefficient model has incomplete properties for the natural convection calculation.
    incomplete_coeff_model = coeff_model.updated_copy(medium=incomplete_air.heat)
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         sim.updated_copy(
             boundary_spec=[
                 td.HeatBoundarySpec(
@@ -1124,15 +1143,15 @@ def test_heat_charge_monitors_validation(monitors):
     mesh_mnt = monitors[11]

     # Invalid monitor name
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         temp_mnt.updated_copy(name=None)

     # Invalid monitor size (negative dimension)
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         temp_mnt.updated_copy(size=(-1, 2, 3))

     # Mesh monitor 1D
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         mesh_mnt.updated_copy(size=(0, 1, 0))

@@ -1148,9 +1167,9 @@ def test_monitor_crosses_medium(mediums, structures, heat_simulation, conduction
         center=(0, 0, 0), size=(td.inf, td.inf, td.inf), name="voltage"
     )
     # A voltage monitor in a heat simulation should throw error if no ChargeConductorMedium is present
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         heat_simulation.updated_copy(
-            medium=solid_no_elect, structures=[solid_struct_no_elect], monitors=[volt_monitor]
+            medium=solid_no_elect, structures=(solid_struct_no_elect,), monitors=(volt_monitor,)
         )

     # Temperature monitor
@@ -1158,15 +1177,15 @@ def test_monitor_crosses_medium(mediums, structures, heat_simulation, conduction
         center=(0, 0, 0), size=(td.inf, td.inf, td.inf), name="temperature"
     )
     # A temperature monitor should throw error in a conduction simulation if no SolidSpec is present
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         conduction_simulation.updated_copy(
-            medium=solid_no_heat, structures=[solid_struct_no_heat], monitors=[temp_monitor]
+            medium=solid_no_heat, structures=(solid_struct_no_heat,), monitors=(temp_monitor,)
         )

     # check error is raised if voltage monitor doesn't cross a conducting medium
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         volt_mnt = td.SteadyPotentialMonitor(center=(0, 0, 0), size=(0, td.inf, td.inf))
-        _ = conduction_simulation.updated_copy(monitors=[volt_mnt])
+        _ = conduction_simulation.updated_copy(monitors=(volt_mnt,))


 def test_heat_charge_mnt_data(
@@ -1227,7 +1246,7 @@ def test_heat_charge_mnt_data(
         values=tri_grid_values,
     )

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         if var == "E":
             _ = mnt_data.updated_copy(E=tri_grid)
         elif var == "J":
@@ -1238,18 +1257,18 @@ def test_grid_spec_validation(grid_specs):
     """Tests whether unstructured grids can be created and different validators for them."""
     # Test UniformUnstructuredGrid
     uniform_grid = grid_specs["uniform"]
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         uniform_grid.updated_copy(dl=0)
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         uniform_grid.updated_copy(min_edges_per_circumference=-1)
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         uniform_grid.updated_copy(min_edges_per_side=-1)

     # Test DistanceUnstructuredGrid
     distance_grid = grid_specs["distance"]
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         distance_grid.updated_copy(dl_interface=-1)
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         distance_grid.updated_copy(distance_interface=2, distance_bulk=1)

@@ -1342,16 +1361,16 @@ def test_sim_data_plotting(simulation_data):
        heat_sim_data.plot_field("test3", x=0)

     # Test updating simulation data with duplicate data
-    with pytest.raises(pd.ValidationError):
-        heat_sim_data.updated_copy(data=[heat_sim_data.data[0]] * 2)
+    with pytest.raises(ValidationError):
+        heat_sim_data.updated_copy(data=(heat_sim_data.data[0],) * 2)

     # Test updating simulation data with invalid simulation
     temp_mnt = td.TemperatureMonitor(size=(1, 2, 3), name="test")
     temp_mnt = temp_mnt.updated_copy(name="test2")

-    sim = heat_sim_data.simulation.updated_copy(monitors=[temp_mnt])
+    sim = heat_sim_data.simulation.updated_copy(monitors=(temp_mnt,))

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         heat_sim_data.updated_copy(simulation=sim)

@@ -1396,21 +1415,21 @@ def test_mesh_plotting(simulation_data):

 def test_conduction_simulation_has_conductors(conduction_simulation, structures):
     """Test whether error is raised if conduction simulation has no conductors."""
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = conduction_simulation.updated_copy(
-            monitors=[],
-            structures=[structures["insulator_structure"]],
+            monitors=(),
+            structures=(structures["insulator_structure"],),
         )


 def test_coupling_source(conduction_simulation, heat_simulation):
     """Test whether the coupling source can be applied."""
-    with pytest.raises(pd.ValidationError):
-        _ = conduction_simulation.updated_copy(sources=[td.HeatFromElectricSource()])
+    with pytest.raises(ValidationError):
+        _ = conduction_simulation.updated_copy(sources=(td.HeatFromElectricSource(),))

-    with pytest.raises(pd.ValidationError):
-        _ = heat_simulation.updated_copy(sources=[td.HeatFromElectricSource()])
+    with pytest.raises(ValidationError):
+        _ = heat_simulation.updated_copy(sources=(td.HeatFromElectricSource(),))


 # --------------------------
@@ -1569,18 +1588,18 @@ def test_charge_simulation(
     )

     # At least one ChargeSimulationMonitor should be added
-    with pytest.raises(pd.ValidationError):
-        sim.updated_copy(monitors=[])
+    with pytest.raises(ValidationError):
+        sim.updated_copy(monitors=())

     # At least 2 VoltageBCs should be defined
-    with pytest.raises(pd.ValidationError):
-        sim.updated_copy(boundary_spec=[bc_n])
+    with pytest.raises(ValidationError):
+        sim.updated_copy(boundary_spec=(bc_n,))

     condition_ssac_n = td.VoltageBC(source=td.SSACVoltageSource(voltage=[0, 1], amplitude=1e-3))
     condition_ssac_p = td.VoltageBC(source=td.SSACVoltageSource(voltage=[0, 1], amplitude=1e-3))
     # Two AC sources cannot be defined
     with pytest.raises(
-        pd.ValidationError, match="Only a single 'SSACVoltageSource' source can be supplied."
+        ValidationError, match="Only a single 'SSACVoltageSource' source can be supplied."
     ):
         analysis = td.IsothermalSSACAnalysis(freqs=[1e2, 1e3], temperature=300)
         sim.updated_copy(
@@ -1593,7 +1612,7 @@ def test_charge_simulation(

     # Test SSACAnalysis as well
     with pytest.raises(
-        pd.ValidationError, match="Only a single 'SSACVoltageSource' source can be supplied."
+        ValidationError, match="Only a single 'SSACVoltageSource' source can be supplied."
    ):
         analysis_ssac = td.SSACAnalysis(freqs=[1e2, 1e3], tolerance_settings=charge_tolerance)
         sim.updated_copy(
@@ -1611,28 +1630,28 @@ def test_charge_simulation(
     )
     new_structures = [struct.updated_copy(medium=medium) for struct in sim.structures]

-    with pytest.raises(pd.ValidationError):
-        sim.updated_copy(structures=new_structures)
+    with pytest.raises(ValidationError):
+        sim.updated_copy(structures=tuple(new_structures))

     # test a voltage array is provided when a capacitance monitor is present
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         new_bc_n = bc_n.updated_copy(
             condition=td.VoltageBC(source=td.DCVoltageSource(voltage=1))
         )
-        _ = sim.updated_copy(boundary_spec=[bc_p, new_bc_n])
+        _ = sim.updated_copy(boundary_spec=(bc_p, new_bc_n))

     # test error is raised when more than one voltage array is provided
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         new_bc_p = bc_p.updated_copy(
             condition=td.VoltageBC(source=td.DCVoltageSource(voltage=[1, 2]))
         )
-        _ = sim.updated_copy(boundary_spec=[new_bc_p, bc_n])
+        _ = sim.updated_copy(boundary_spec=(new_bc_p, bc_n))

     # test non isothermal spec
     non_isothermal_spec = td.SteadyChargeDCAnalysis(tolerance_settings=charge_tolerance)
     sim = sim.updated_copy(analysis_spec=non_isothermal_spec)

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         # remove heat from mediums
         new_structs = []
         for struct in sim.structures:
@@ -1641,7 +1660,7 @@ def test_charge_simulation(
             )
         _ = sim.updated_copy(structures=new_structs)

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         # remove charge from mediums
         new_structs = []
         for struct in sim.structures:
@@ -1650,7 +1669,7 @@ def test_charge_simulation(
             )
         _ = sim.updated_copy(structures=new_structs)

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         # make sure there is at least one semiconductor
         new_structs = []
         for struct in sim.structures:
@@ -1816,7 +1835,7 @@ def test_sim_structure_extent(box_size, log_level):
     with AssertLogLevel(log_level):
         _ = td.HeatChargeSimulation(
             size=(1, 1, 1),
-            structures=[box],
+            structures=(box,),
             medium=td.MultiPhysicsMedium(charge=td.ChargeConductorMedium(conductivity=1)),
             boundary_spec=[
                 td.HeatChargeBoundarySpec(
@@ -1938,7 +1957,7 @@ def test_simulation_initialization_invalid_parameters(
 ):
     """Test simulation initialization with invalid parameters."""
     # Invalid simulation size
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.HeatChargeSimulation(
             medium=mediums["fluid_medium"],
             structures=[structures["fluid_structure"]],
@@ -1951,7 +1970,7 @@ def test_simulation_initialization_invalid_parameters(
         )

     # Invalid monitor type
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         td.HeatChargeSimulation(
             medium=mediums["fluid_medium"],
             structures=[structures["fluid_structure"]],
@@ -2018,9 +2037,7 @@ def test_dynamic_simulation_updates(heat_simulation):

     # Add a new monitor
     new_monitor = td.TemperatureMonitor(size=(1, 1, 1), name="new_temp_mnt")
-    updated_sim = heat_simulation.updated_copy(
-        monitors=(*list(heat_simulation.monitors), new_monitor)
-    )
+    updated_sim = heat_simulation.updated_copy(monitors=(*heat_simulation.monitors, new_monitor))

     assert len(updated_sim.monitors) == len(heat_simulation.monitors) + 1
     assert updated_sim.monitors[-1].name == "new_temp_mnt"
@@ -2304,21 +2321,21 @@ def test_unsteady_parameters():
     )

     # test non-positive initial temperature raises error
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.UnsteadyHeatAnalysis(
             initial_temperature=0,
             unsteady_spec=td.UnsteadySpec(time_step=0.1, total_time_steps=1),
         )

     # test negative time step raises error
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.UnsteadyHeatAnalysis(
             initial_temperature=10,
             unsteady_spec=td.UnsteadySpec(time_step=-0.1, total_time_steps=1),
         )

     # test negative total time steps raises error
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.UnsteadyHeatAnalysis(
             initial_temperature=10,
             unsteady_spec=td.UnsteadySpec(time_step=0.1, total_time_steps=-1),
@@ -2343,19 +2360,19 @@ def test_unsteady_heat_analysis(heat_simulation):

     # this should work since the monitor is unstructured
     unsteady_sim = heat_simulation.updated_copy(
-        analysis_spec=unsteady_analysis_spec, monitors=[temp_mnt]
+        analysis_spec=unsteady_analysis_spec, monitors=(temp_mnt,)
     )

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         temp_mnt = temp_mnt.updated_copy(unstructured=False)
-        _ = unsteady_sim.updated_copy(monitors=[temp_mnt])
+        _ = unsteady_sim.updated_copy(monitors=(temp_mnt,))

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         temp_mnt = temp_mnt.updated_copy(unstructured=True, interval=0)
-        _ = unsteady_sim.updated_copy(monitors=[temp_mnt])
+        _ = unsteady_sim.updated_copy(monitors=(temp_mnt,))

     # try simulation with excessive time steps
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         mew_spex = td.UnsteadyHeatAnalysis(
             initial_temperature=300,
             unsteady_spec=td.UnsteadySpec(time_step=0.1, total_time_steps=100000),
@@ -2412,11 +2429,11 @@ def test_heat_conduction_simulations():
         monitors=[temp_monitor, voltage_monitor],
     )

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         # no thermal monitors
         _ = sim.updated_copy(monitors=[voltage_monitor])

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         # voltage array in electric BC
         _ = sim.updated_copy(
             boundary_spec=[
@@ -2430,17 +2447,17 @@ def test_heat_conduction_simulations():

     # this doesn't raise error
     _ = sim.updated_copy(sources=[td.HeatFromElectricSource()])

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         # This should error since the conduction simulation doesn't have a monitor
         _ = sim.updated_copy(monitors=[temp_monitor])

     # test error if structures defined with Medium instead of MultiPhysicsMedium
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         struct_error = struct1.updated_copy(medium=td.Medium(conductivity=1))
         _ = sim.updated_copy(structures=[struct_error])

     # test error if structures aren't conducting
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         struct_error = struct1.updated_copy(
             medium=struct1.medium.updated_copy(charge=td.ChargeInsulatorMedium())
         )
@@ -2597,3 +2614,133 @@ def test_heat_only_simulation_with_semiconductor():
     assert TCADAnalysisTypes.CONDUCTION not in simulation_types, (
         "Conduction simulation should NOT be triggered when no electric BCs are present."
) + + +def test_heat_charge_simulation_plot(): + """Test the HeatChargeSimulation.plot() method adds BCs based on simulation type.""" + + # Create mediums + solid_medium = td.MultiPhysicsMedium( + heat=td.SolidMedium(conductivity=1, capacity=1), + name="solid", + ) + fluid_medium = td.MultiPhysicsMedium( + heat=td.FluidMedium(), + name="fluid", + ) + + # Create structures + solid_structure = td.Structure( + geometry=td.Box(size=(1, 1, 1), center=(0, 0, 0)), + medium=solid_medium, + name="solid_structure", + ) + + # Create boundary conditions for heat simulation + bc_temp = td.HeatChargeBoundarySpec( + condition=td.TemperatureBC(temperature=300), + placement=td.StructureBoundary(structure="solid_structure"), + ) + + # Create heat source + heat_source = td.UniformHeatSource(rate=1e3, structures=["solid_structure"]) + + # Create monitor + temp_monitor = td.TemperatureMonitor( + center=(0, 0, 0), + size=(1, 1, 0), + name="temp_mnt", + ) + + # Create a HEAT simulation + heat_sim = td.HeatChargeSimulation( + medium=fluid_medium, + structures=[solid_structure], + center=(0, 0, 0), + size=(2, 2, 2), + boundary_spec=[bc_temp], + grid_spec=td.UniformUnstructuredGrid(dl=0.1), + sources=[heat_source], + monitors=[temp_monitor], + ) + + # Test plot for HEAT simulation - should add heat BCs + _, ax_scene_only = plt.subplots() + heat_sim.scene.plot(z=0, ax=ax_scene_only) + num_children_scene_only = len(ax_scene_only.get_children()) + plt.close() + + _, ax_with_bc = plt.subplots() + heat_sim.plot(z=0, ax=ax_with_bc) + num_children_with_bc = len(ax_with_bc.get_children()) + plt.close() + + # heat_sim.plot() should have more visual elements than scene.plot() + # because it adds monitors and heat boundaries for HEAT simulations + assert num_children_with_bc - num_children_scene_only >= 2, ( + "heat_sim.plot() should add at least monitors and heat boundaries " + "for HEAT simulations, resulting in at least 2 more visual elements " + "than heat_sim.scene.plot()" + ) + + # Now test with a CHARGE simulation + semicon = td.material_library["cSi"].variants["Si_MultiPhysics"].medium.charge + Si_n = semicon.updated_copy(N_d=[td.ConstantDoping(concentration=1e16)], name="Si_n") + Si_p = semicon.updated_copy(N_a=[td.ConstantDoping(concentration=1e16)], name="Si_p") + + n_side = td.Structure( + geometry=td.Box(center=(-0.25, 0, 0), size=(0.5, 1, 1)), + medium=Si_n, + name="n_side", + ) + p_side = td.Structure( + geometry=td.Box(center=(0.25, 0, 0), size=(0.5, 1, 1)), + medium=Si_p, + name="p_side", + ) + + bc_v1 = td.HeatChargeBoundarySpec( + condition=td.VoltageBC(source=td.DCVoltageSource(voltage=0)), + placement=td.MediumMediumInterface(mediums=[fluid_medium.name, Si_n.name]), + ) + bc_v2 = td.HeatChargeBoundarySpec( + condition=td.VoltageBC(source=td.DCVoltageSource(voltage=0.5)), + placement=td.MediumMediumInterface(mediums=[fluid_medium.name, Si_p.name]), + ) + + volt_monitor = td.SteadyPotentialMonitor( + center=(0, 0, 0), + size=(1, 1, 0), + name="volt_mnt", + unstructured=True, + ) + + charge_sim = td.HeatChargeSimulation( + structures=[n_side, p_side], + medium=fluid_medium, + monitors=[volt_monitor], + center=(0, 0, 0), + size=(2, 2, 2), + grid_spec=td.UniformUnstructuredGrid(dl=0.05), + boundary_spec=[bc_v1, bc_v2], + analysis_spec=td.IsothermalSteadyChargeDCAnalysis(temperature=300), + ) + + # Test plot for CHARGE simulation - should add electric BCs + _, ax_scene_only = plt.subplots() + charge_sim.scene.plot(z=0, ax=ax_scene_only) + num_children_scene_only = len(ax_scene_only.get_children()) + plt.close() 
+ + _, ax_with_bc = plt.subplots() + charge_sim.plot(z=0, ax=ax_with_bc) + num_children_with_bc = len(ax_with_bc.get_children()) + plt.close() + + # charge_sim.plot() should have more visual elements than scene.plot() + # because it adds monitors and electric boundaries for CHARGE simulations + assert num_children_with_bc - num_children_scene_only >= 2, ( + "charge_sim.plot() should add at least monitors and electric boundaries " + "for CHARGE simulations, resulting in at least 2 more visual elements " + "than charge_sim.scene.plot()" + ) diff --git a/tests/test_components/test_layerrefinement.py b/tests/test_components/test_layerrefinement.py index 8716f17f33..59daaf7907 100644 --- a/tests/test_components/test_layerrefinement.py +++ b/tests/test_components/test_layerrefinement.py @@ -3,8 +3,8 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.components.grid.corner_finder import CornerFinderSpec @@ -94,7 +94,7 @@ def test_layerrefinement(): """Test LayerRefinementSpec is working as expected.""" # size along axis must be inf - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec(axis=0, size=(td.inf, 0, 0)) # classmethod @@ -119,19 +119,19 @@ def test_layerrefinement(): assert layer._is_inplane_bounded(layer) assert layer.axis == 1 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): structures = [ td.Structure(geometry=td.Box(size=(td.inf, td.inf, td.inf)), medium=td.Medium()) ] layer = LayerRefinementSpec.from_structures(structures) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(0, td.inf)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(td.inf, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(-td.inf, 0)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(1, -1)) @@ -573,7 +573,9 @@ def test_gap_meshing(): reentry_gap = td.Structure( geometry=td.PolySlab( - slab_bounds=[-0.2, 0.2], axis=1, vertices=[(-0.3, 0.52), (-0.05, 0.3), (0.2, 0.52)] + slab_bounds=(-0.2, 0.2), + axis=1, + vertices=[(-0.3, 0.52), (-0.05, 0.3), (0.2, 0.52)], ), medium=td.Medium(), ) diff --git a/tests/test_components/test_low_freq_smoothing.py b/tests/test_components/test_low_freq_smoothing.py index 341e8be992..2c481e12f3 100644 --- a/tests/test_components/test_low_freq_smoothing.py +++ b/tests/test_components/test_low_freq_smoothing.py @@ -2,8 +2,8 @@ from __future__ import annotations -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td @@ -19,7 +19,7 @@ def test_low_freq_smoothing_spec_initialization_default_values(): def test_empty_monitors(): """Test that LowFrequencySmoothingSpec raises an error if monitors are not provided.""" - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.LowFrequencySmoothingSpec(monitors=[]) @@ -40,7 +40,7 @@ def test_monitors_exist(): low_freq_smoothing=td.LowFrequencySmoothingSpec(monitors=["monitor1"]), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): 
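[Editorial note: the single change repeated across every test module in this diff is the move from `pydantic.v1` to native pydantic v2, so `pytest.raises` now targets `pydantic.ValidationError` imported from the top-level package. A minimal sketch of the pattern, using an illustrative model rather than tidy3d's actual classes:]

# Sketch of the import change these tests rely on (assumes pydantic v2).
import pytest
from pydantic import BaseModel, Field, ValidationError  # was: import pydantic.v1 as pd


class Monitor(BaseModel):
    name: str
    freqs: tuple[float, ...] = Field(min_length=1)


def test_empty_freqs_rejected():
    # v2 raises pydantic.ValidationError, not pydantic.v1's ValidationError,
    # so the old `pytest.raises(pd.ValidationError)` checks would never match
    with pytest.raises(ValidationError):
        Monitor(name="m", freqs=())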
diff --git a/tests/test_components/test_layerrefinement.py b/tests/test_components/test_layerrefinement.py
index 8716f17f33..59daaf7907 100644
--- a/tests/test_components/test_layerrefinement.py
+++ b/tests/test_components/test_layerrefinement.py
@@ -3,8 +3,8 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 import tidy3d as td
 from tidy3d.components.grid.corner_finder import CornerFinderSpec
@@ -94,7 +94,7 @@ def test_layerrefinement():
     """Test LayerRefinementSpec is working as expected."""

     # size along axis must be inf
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = LayerRefinementSpec(axis=0, size=(td.inf, 0, 0))

     # classmethod
@@ -119,19 +119,19 @@ def test_layerrefinement():
     assert layer._is_inplane_bounded(layer)
     assert layer.axis == 1

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         structures = [
             td.Structure(geometry=td.Box(size=(td.inf, td.inf, td.inf)), medium=td.Medium())
         ]
         layer = LayerRefinementSpec.from_structures(structures)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(0, td.inf))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(td.inf, 0))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
        _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(-td.inf, 0))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = LayerRefinementSpec.from_layer_bounds(axis=axis, bounds=(1, -1))

@@ -573,7 +573,9 @@ def test_gap_meshing():

     reentry_gap = td.Structure(
         geometry=td.PolySlab(
-            slab_bounds=[-0.2, 0.2], axis=1, vertices=[(-0.3, 0.52), (-0.05, 0.3), (0.2, 0.52)]
+            slab_bounds=(-0.2, 0.2),
+            axis=1,
+            vertices=[(-0.3, 0.52), (-0.05, 0.3), (0.2, 0.52)],
         ),
         medium=td.Medium(),
     )
diff --git a/tests/test_components/test_low_freq_smoothing.py b/tests/test_components/test_low_freq_smoothing.py
index 341e8be992..2c481e12f3 100644
--- a/tests/test_components/test_low_freq_smoothing.py
+++ b/tests/test_components/test_low_freq_smoothing.py
@@ -2,8 +2,8 @@

 from __future__ import annotations

-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 import tidy3d as td

@@ -19,7 +19,7 @@ def test_low_freq_smoothing_spec_initialization_default_values():

 def test_empty_monitors():
     """Test that LowFrequencySmoothingSpec raises an error if monitors are not provided."""
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         td.LowFrequencySmoothingSpec(monitors=[])

@@ -40,7 +40,7 @@ def test_monitors_exist():
         low_freq_smoothing=td.LowFrequencySmoothingSpec(monitors=["monitor1"]),
     )

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         sim = td.Simulation(
             size=(1, 1, 1),
             monitors=[],
diff --git a/tests/test_components/test_lumped_element.py b/tests/test_components/test_lumped_element.py
index 957a4e8aeb..cabfdc79eb 100644
--- a/tests/test_components/test_lumped_element.py
+++ b/tests/test_components/test_lumped_element.py
@@ -3,8 +3,8 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 import tidy3d as td
 from tidy3d.components.lumped_element import network_complex_permittivity
@@ -36,7 +36,7 @@ def test_lumped_resistor():
     assert monitor.name == resistor.monitor_name

     # error if voltage axis is not in plane with the resistor
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.LumpedResistor(
             resistance=50.0,
             center=[0, 0, 0],
@@ -46,7 +46,7 @@ def test_lumped_resistor():
         )

     # error if not planar
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.LumpedResistor(
             resistance=50.0,
             center=[0, 0, 0],
@@ -54,7 +54,7 @@ def test_lumped_resistor():
             voltage_axis=2,
             name="R",
         )
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.LumpedResistor(
             resistance=50.0,
             center=[0, 0, 0],
@@ -89,7 +89,7 @@ def test_lumped_resistor_snapping():

     # snapped version
     resistor_snapped = resistor.updated_copy(enable_snapping_points=True)
-    sim_snapped = sim.updated_copy(lumped_elements=[resistor_snapped])
+    sim_snapped = sim.updated_copy(lumped_elements=(resistor_snapped,))
     # whether lumped element is snapped along normal axis
     assert not any(np.isclose(sim.grid.boundaries.z, 0.1))
     assert any(np.isclose(sim_snapped.grid.boundaries.z, 0.1))
@@ -127,7 +127,7 @@ def test_coaxial_lumped_resistor_snapping():

     # snapped version
     resistor_snapped = resistor.updated_copy(enable_snapping_points=True)
-    sim_snapped = sim.updated_copy(lumped_elements=[resistor_snapped])
+    sim_snapped = sim.updated_copy(lumped_elements=(resistor_snapped,))
     # whether lumped element is snapped along normal axis
     assert not any(np.isclose(sim.grid.boundaries.z, 0.1))
     assert any(np.isclose(sim_snapped.grid.boundaries.z, 0.1))
@@ -156,7 +156,7 @@ def test_coaxial_lumped_resistor():
     _ = resistor.to_snapping_points()

     # error if inner diameter is larger
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.CoaxialLumpedResistor(
             resistance=50.0,
             center=[0, 0, 0],
@@ -166,7 +166,7 @@ def test_coaxial_lumped_resistor():
             name="R",
         )

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.CoaxialLumpedResistor(
             resistance=50.0,
             center=[0, 0, np.inf],
@@ -180,11 +180,11 @@ def test_coaxial_lumped_resistor():
 def test_validators_RLC_network():
     """Test that ``RLCNetwork`` is validated correctly."""
     # Must have a defined value for R,L,or C
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.RLCNetwork()

     # Must have a valid topology
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.RLCNetwork(
             capacitance=0.2e-12,
             network_topology="left",
@@ -193,13 +193,13 @@ def test_validators_RLC_network():

 def test_validators_admittance_network():
     """Test that ``AdmittanceNetwork`` is validated correctly."""
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.AdmittanceNetwork()

     a = (0, -1, 2)
     b = (1, 1, 2)
     # non negative a and b
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.AdmittanceNetwork(
             a=a,
             b=b,
@@ -208,7 +208,7 @@ def test_validators_admittance_network():
     a = (0, complex(1, 2), 2)
     b = (1, 1, 2)
     # real a and b
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.AdmittanceNetwork(
             a=a,
             b=b,
@@ -268,17 +268,12 @@ def test_RLC_and_lumped_network_agreement(Rval, Lval, Cval, topology):
     if configuration_includes_parallel_inductor:
         return

-    network = td.AdmittanceNetwork(
-        a=a,
-        b=b,
-    )
+    network = td.AdmittanceNetwork(a=a, b=b)
     (a, b) = network._as_admittance_function
     med_network = network._to_medium(sf)

     # Check conversion to geometry and to structure
-    linear_element = linear_element.updated_copy(
-        network=network,
-    )
+    linear_element = linear_element.updated_copy(network=network)
     _ = linear_element.to_geometry()

     assert np.allclose(med_RLC.eps_model(freqs), med_network.eps_model(freqs), rtol=rtol)
diff --git a/tests/test_components/test_map.py b/tests/test_components/test_map.py
index 059fbcdf4c..345ebfc9bf 100644
--- a/tests/test_components/test_map.py
+++ b/tests/test_components/test_map.py
@@ -2,8 +2,8 @@

 import collections.abc

-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 from tidy3d import SimulationMap

@@ -33,7 +33,7 @@ def test_simulation_map_creation():
 def test_simulation_map_invalid_type_raises_error():
     """Tests that a ValidationError is raised for incorrect value types."""
     invalid_data = {"sim_A": "not a simulation"}
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         SimulationMap(keys=tuple(invalid_data.keys()), values=tuple(invalid_data.values()))
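[Editorial note: the other recurring change is swapping list literals for tuples wherever a field of an immutable model is set, e.g. `lumped_elements=(resistor_snapped,)` above. A plausible reason, sketched below with an illustrative model (tidy3d's actual model config is an assumption here), is that pydantic v2 models validated in strict mode no longer coerce lists into tuple fields, and tuples keep frozen models hashable:]

# Why the diffs swap [...] for (...,): a frozen, strict pydantic-v2 model
# rejects a list where a tuple is declared.
import pytest
from pydantic import BaseModel, ConfigDict, ValidationError


class Sim(BaseModel):
    model_config = ConfigDict(frozen=True, strict=True)
    monitors: tuple[str, ...] = ()


Sim(monitors=("plane",))  # ok: exact type match
with pytest.raises(ValidationError):
    Sim(monitors=["plane"])  # list is not coerced to tuple in strict mode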
diff --git a/tests/test_components/test_medium.py b/tests/test_components/test_medium.py
index 320d1bda2b..df6b019349 100644
--- a/tests/test_components/test_medium.py
+++ b/tests/test_components/test_medium.py
@@ -4,7 +4,7 @@

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pydantic
+import pydantic as pd
 import pytest

 import tidy3d as td
@@ -56,18 +56,18 @@ def test_from_n_less_than_1():

 def test_medium():
     # mediums error with unacceptable values
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Medium(permittivity=0.0)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Medium(conductivity=-1.0)


 def test_validate_largest_pole_parameters():
     # error for large pole parameters
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.PoleResidue(poles=[((-1e50 + 2j), (1 + 3j))])
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.PoleResidue(poles=[((-1 + 2j), (1e50 + 3j))])

@@ -158,30 +158,30 @@ def test_PMC():

 def test_lossy_metal():
     # frequency_range shouldn't be None
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=1)

-    # frequency_range shouldn't contain non-postive values
-    with pytest.raises(pydantic.ValidationError):
+    # frequency_range shouldn't contain non-positive values
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=1, frequency_range=(0, 10))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=1, frequency_range=(-10, 10))

     # frequency_range should be finite
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=1, frequency_range=(10, np.inf))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=1, frequency_range=(-np.inf, 10))

     # allow_gain cannot be true
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(allow_gain=True, conductivity=1, frequency_range=(10, 20))

     # conductivity cannot be negative
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=-1, frequency_range=(10, 20))

     # conductivity cannot be 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.LossyMetalMedium(conductivity=0, frequency_range=(10, 20))

     # default fitting
@@ -232,13 +232,13 @@ def test_medium_dispersion():
     m_DR = td.Drude(eps_inf=1.0, coeffs=[(1, 3), (2, 4)])
     m_DB = td.Debye(eps_inf=1.0, coeffs=[(1, 3), (2, 4)])

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Sellmeier(coeffs=[(2, 0), (2, 4)])

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Drude(eps_inf=1.0, coeffs=[(1, 0), (2, 4)])

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Debye(eps_inf=1.0, coeffs=[(1, 0), (2, 4)])

     freqs = np.linspace(0.01, 1, 1001)
@@ -442,27 +442,27 @@ def test_n_cfl():
 def test_gain_medium():
     """Test passive and gain medium validations."""
     # non-dispersive
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Medium(conductivity=-0.1)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Medium(conductivity=-1.0, allow_gain=False)
     _ = td.Medium(conductivity=-1.0, allow_gain=True)

     # pole residue, causality
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.PoleResidue(eps_inf=0.16, poles=[(1 + 1j, 2 + 2j)])

     # Sellmeier
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Sellmeier(coeffs=((-1, 1),))
     mS = td.Sellmeier(coeffs=((-1, 1),), allow_gain=True)

     # Lorentz
     # causality, negative gamma
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Lorentz(eps_inf=0.04, coeffs=[(1, 2, -3)])
     # gain, negative Delta epsilon
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Lorentz(eps_inf=0.04, coeffs=[(-1, 2, 3)])
     mL = td.Lorentz(eps_inf=0.04, coeffs=[(-1, 2, 3)], allow_gain=True)
     assert mL.pole_residue.allow_gain
@@ -471,7 +471,7 @@ def test_gain_medium():
     _ = td.Lorentz(eps_inf=0.04, coeffs=[(1, -2, 3)])

     # Drude, only causality constraint
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Drude(eps_inf=0.04, coeffs=[(1, -2)])

     # anisotropic medium, warn allow_gain is ignored
@@ -516,7 +516,7 @@ def test_medium2d():
     _ = medium.plot(freqs=[2e14, 3e14], ax=AX)
     plt.close()

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Medium2D(ss=td.PECMedium(), tt=td.Medium())

@@ -559,24 +559,24 @@ def test_fully_anisotropic_media():
     _ = td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond)

     # check that tensors are provided
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FullyAnisotropicMedium(permittivity=2)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FullyAnisotropicMedium(permittivity=[3, 4, 2])

     # check that permittivity >= 1 and conductivity >= 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FullyAnisotropicMedium(permittivity=[[3, 0, 0], [0, 0.5, 0], [0, 0, 1]])
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FullyAnisotropicMedium(conductivity=[[-3, 0, 0], [0, 0.5, 0], [0, 0, 1]])
     td.FullyAnisotropicMedium(conductivity=[[-3, 0, 0], [0, 0.5, 0], [0, 0, 1]], allow_gain=True)

     # check that permittivity needs to be symmetric
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FullyAnisotropicMedium(permittivity=[[3, 0.1, 0], [0.2, 2, 0], [0, 0, 1]])

     # check that differently oriented permittivity and conductivity are not accepted
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond2)

     # check creation from diagonal medium
@@ -660,7 +660,7 @@ def test_nonlinear_medium():
     med = td.Medium(nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5))

     # don't use deprecated numiters
-    with pytest.raises(ValidationError):
+    with pytest.raises(pd.ValidationError):
         med = td.Medium(
             nonlinear_spec=td.NonlinearSpec(models=[td.NonlinearSusceptibility(chi3=1, numiters=2)])
         )
@@ -669,15 +669,15 @@ def test_nonlinear_medium():
     med = td.PoleResidue(poles=[(-1, 1)], nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5))

     # unsupported material types
-    with pytest.raises(ValidationError):
+    with pytest.raises(pd.ValidationError):
         med = td.AnisotropicMedium(
             xx=med, yy=med, zz=med, nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5)
         )

     # numiters too large
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         med = td.Medium(nonlinear_spec=td.NonlinearSusceptibility(chi3=1.5, numiters=200))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         med = td.Medium(
             nonlinear_spec=td.NonlinearSpec(
                 num_iters=200, models=[td.NonlinearSusceptibility(chi3=1.5)]
@@ -685,7 +685,7 @@ def test_nonlinear_medium():
         )

     # duplicate models
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         med = td.Medium(
             nonlinear_spec=td.NonlinearSpec(
                 models=[
@@ -696,7 +696,7 @@ def test_nonlinear_medium():
         )

     # active materials
-    with pytest.raises(ValidationError):
+    with pytest.raises(pd.ValidationError):
         med = td.Medium(
             nonlinear_spec=td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=-1, n0=1, freq0=1)])
         )
@@ -724,28 +724,28 @@ def test_nonlinear_medium():
     # subsection with nonlinear materials preserves sources
     sim2 = sim.updated_copy(center=(-4, -4, -4), path="sources/0")
     sim2 = sim2.updated_copy(
-        models=[td.TwoPhotonAbsorption(beta=1)], path="structures/0/medium/nonlinear_spec"
+        models=(td.TwoPhotonAbsorption(beta=1),), path="structures/0/medium/nonlinear_spec"
     )
     sim2 = sim2.subsection(region=td.Box(center=(0, 0, 0), size=(1, 1, 0)))

     nonlinear_spec = td.NonlinearSpec(models=[td.KerrNonlinearity(n2=1, n0=1)])
     structure = structure.updated_copy(medium=medium.updated_copy(nonlinear_spec=nonlinear_spec))
-    sim = sim.updated_copy(structures=[structure])
+    sim = sim.updated_copy(structures=(structure,))

     nonlinear_spec = td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=1, n0=1)])
     structure = structure.updated_copy(medium=medium.updated_copy(nonlinear_spec=nonlinear_spec))
     sim = sim.updated_copy(structures=[structure])

     nonlinear_spec = td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=1, n0=1, freq0=1)])
     structure = structure.updated_copy(medium=medium.updated_copy(nonlinear_spec=nonlinear_spec))
-    sim = sim.updated_copy(structures=[structure])
+    sim = sim.updated_copy(structures=(structure,))

     # active materials with automatic detection of n0
     nonlinear_spec_active = td.NonlinearSpec(models=[td.TwoPhotonAbsorption(beta=-1)])
-    with pytest.raises(ValidationError):
+    with pytest.raises(pd.ValidationError):
         medium_active = medium.updated_copy(nonlinear_spec=nonlinear_spec_active)

     # inconsistent n0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.NonlinearSpec(
             models=[td.KerrNonlinearity(n0=1, n2=1), td.TwoPhotonAbsorption(beta=1, n0=2)]
         )
@@ -762,9 +762,9 @@ def test_nonlinear_medium():
     MODULATION_SPEC = td.ModulationSpec()
     modulation_spec = MODULATION_SPEC.updated_copy(permittivity=ST)
     modulated = td.Medium(permittivity=2, modulation_spec=modulation_spec)
-    with pytest.raises(ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.Medium2D(ss=medium, tt=medium)
-    with pytest.raises(ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.Medium2D(ss=modulated, tt=modulated)

     grid_spec = td.GridSpec.auto(min_steps_per_wvl=10, wavelength=1)
@@ -778,7 +778,7 @@ def test_nonlinear_medium():
         interval=1, size=(0, 0, 0), name="aux_field_time", fields=aux_fields
     )
     sim = sim.updated_copy(medium=med, path="structures/0")
-    sim = sim.updated_copy(monitors=[monitor])
+    sim = sim.updated_copy(monitors=(monitor,))

     with AssertLogLevel("WARNING", contains_str="stores field"):
         med = td.Medium(
@@ -821,7 +821,7 @@ def create_mediums(n_dataset):
     with AssertLogLevel(None):
         create_mediums(n_dataset=n_dataset)

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         # repeat some entries so data cannot be interpolated
         X2 = [X[0], *list(X)]
         n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data))
diff --git a/tests/test_components/test_meshgenerate.py b/tests/test_components/test_meshgenerate.py
index 2a5ec3d59f..32ca32fab6 100644
--- a/tests/test_components/test_meshgenerate.py
+++ b/tests/test_components/test_meshgenerate.py
@@ -505,7 +505,7 @@ def test_mesh_direct_override():
     assert np.isclose(sizes[len(sizes) // 2], 0.05)

     # default override has no effect when coarser than enclosing structure
-    override_coarse = override_fine.copy(update={"dl": [0.2] * 3})
+    override_coarse = override_fine.copy(update={"dl": (0.2,) * 3})

     sim = td.Simulation(
         size=(3, 3, 3),
         grid_spec=td.GridSpec.auto(
@@ -678,13 +678,13 @@ def test_small_structure_size():
     # Warning not raised if structure is higher index
     box2 = box.updated_copy(medium=td.Medium(permittivity=300))
     with AssertLogLevel(None):
-        sim.updated_copy(structures=[box2])
+        sim.updated_copy(structures=(box2,))

     # Warning not raised if structure is covered by an override structure
     override = td.MeshOverrideStructure(geometry=box.geometry, dl=(box_size, td.inf, td.inf))
     with AssertLogLevel(None):
         sim3 = sim.updated_copy(
-            grid_spec=sim.grid_spec.updated_copy(override_structures=[override])
+            grid_spec=sim.grid_spec.updated_copy(override_structures=(override,))
         )
     # Also check that the structure boundaries are in the grid
     ind_mid_cell = int(sim3.grid.num_cells[0] // 2)
@@ -696,7 +696,7 @@ def test_small_structure_size():
         geometry=td.Box(center=(box_size, 0, 0), size=(box_size, td.inf, td.inf)), medium=medium
     )
     with AssertLogLevel("WARNING"):
-        sim.updated_copy(structures=[box3, box])
+        sim.updated_copy(structures=(box3, box))


 def test_shapely_strtree_warnings():
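[Editorial note: the `copy(update={"dl": (0.2,) * 3})` change in test_meshgenerate.py above hints at a v2 gotcha worth a sketch. In pydantic v2, `model_copy(update=...)` bypasses validation entirely, so update values must already have the stored type; whether tidy3d's `.copy()` wrapper re-validates is an assumption, but the raw v2 behavior looks like this:]

# Sketch: pydantic v2's model_copy(update=...) does not validate the update,
# so the caller must supply the exact stored type (illustrative model).
from pydantic import BaseModel, ConfigDict


class Override(BaseModel):
    model_config = ConfigDict(frozen=True)
    dl: tuple[float, float, float] = (0.1, 0.1, 0.1)


fine = Override()
coarse = fine.model_copy(update={"dl": (0.2,) * 3})  # stays a tuple
sloppy = fine.model_copy(update={"dl": [0.2] * 3})   # silently stores a list!
assert isinstance(coarse.dl, tuple) and isinstance(sloppy.dl, list)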
diff --git a/tests/test_components/test_microwave.py b/tests/test_components/test_microwave.py
index 3d14664a56..5ea56e0349 100644
--- a/tests/test_components/test_microwave.py
+++ b/tests/test_components/test_microwave.py
@@ -7,7 +7,7 @@

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pd
+import pydantic as pd
 import pytest
 import xarray as xr
 from shapely import LineString
diff --git a/tests/test_components/test_mode.py b/tests/test_components/test_mode.py
index 77041abd38..48a405c860 100644
--- a/tests/test_components/test_mode.py
+++ b/tests/test_components/test_mode.py
@@ -5,7 +5,7 @@
 import importlib

 import numpy as np
-import pydantic.v1 as pydantic
+import pydantic as pd
 import pytest
 from matplotlib import pyplot as plt

@@ -37,29 +37,29 @@ def test_modes():
     for opt in ["lowest", "highest", "central"]:
         _ = td.ModeSpec(num_modes=3, sort_spec=td.ModeSortSpec(track_freq=opt))

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.ModeSpec(num_modes=3, track_freq="middle")

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.ModeSpec(num_modes=3, track_freq=4)


 def test_bend_axis_not_given():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.ModeSpec(bend_radius=1.0, bend_axis=None)


 def test_zero_radius():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.ModeSpec(bend_radius=0.0, bend_axis=1)


 def test_glancing_incidence():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.ModeSpec(angle_theta=np.pi / 2)


 def test_group_index_step_validation():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.ModeSpec(group_index_step=1.0)

     ms = td.ModeSpec(group_index_step=True)
@@ -76,7 +76,7 @@ def test_angle_rotation_with_phi():
     td.ModeSpec(angle_phi=np.pi, angle_rotation=True)

     # Case where angle_phi is not a multiple of np.pi and angle_rotation is True
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.ModeSpec(angle_phi=np.pi / 3, angle_rotation=True)

@@ -128,22 +128,22 @@ def test_validation_from_simulation():
     _ = sim.updated_copy(structures=[reg_geometry], monitors=[rot_monitor])

     # Test that transforming a geometry with an infinite extent raises an error
-    with pytest.raises(SetupError):
+    with pytest.raises((SetupError, pd.ValidationError)):
         sim.updated_copy(structures=[inf_geometry], monitors=[rot_monitor])

     # Test that transforming an anisotropic medium raises an error
-    with pytest.raises(SetupError):
+    with pytest.raises((SetupError, pd.ValidationError)):
         sim.updated_copy(structures=[anisotropic_geometry], monitors=[rot_monitor])

     # Same thing with a ModeSource
-    with pytest.raises(SetupError):
+    with pytest.raises((SetupError, pd.ValidationError)):
         sim.updated_copy(structures=[inf_geometry], sources=[rot_source])
-    with pytest.raises(SetupError):
+    with pytest.raises((SetupError, pd.ValidationError)):
         sim.updated_copy(structures=[anisotropic_geometry], sources=[rot_source])

     # Same thing with ModeSimulation
-    with pytest.raises(SetupError):
+    with pytest.raises((SetupError, pd.ValidationError)):
         td.ModeSimulation(
             structures=[inf_geometry],
             size=(0, 5, 5),
@@ -151,7 +151,7 @@ def test_validation_from_simulation():
             freqs=[td.C_0],
         )

-    with pytest.raises(SetupError):
+    with pytest.raises((SetupError, pd.ValidationError)):
         td.ModeSimulation(
             structures=[anisotropic_geometry],
             size=(0, 5, 5),
@@ -172,7 +172,7 @@ def get_mode_sim():
         freqs=FS,
         mode_spec=mode_spec,
         grid_spec=td.GridSpec.auto(wavelength=td.C_0 / FS[0]),
-        monitors=[permittivity_monitor],
+        monitors=(permittivity_monitor,),
     )
     return sim

@@ -201,14 +201,14 @@ def test_mode_sim():
     assert sim.plane == sim.geometry

     # must be planar or have plane
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(size=(3, 3, 3), plane=None)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(size=(3, 3, 3), plane=td.Box(size=(3, 3, 3)))
     _ = sim.updated_copy(size=(3, 3, 3), plane=td.Box(size=(3, 3, 0)))

     # plane must intersect sim geometry
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = sim.updated_copy(size=(3, 3, 3), plane=td.Box(center=(5, 5, 5), size=(1, 1, 0)))

     # test warning for not providing wavelength in autogrid
@@ -257,7 +257,7 @@ def test_mode_sim():
     )
     assert td.ModeSimulation.from_simulation(sim) == sim

-    assert td.ModeSimulation.from_mode_solver(sim._mode_solver) == sim.updated_copy(monitors=[])
+    assert td.ModeSimulation.from_mode_solver(sim._mode_solver) == sim.updated_copy(monitors=())
     _ = td.ModeSimulation.from_simulation(
         simulation=fdtd_sim,
         plane=td.Box(size=(4, 4, 0)),
diff --git a/tests/test_components/test_mode_interp.py b/tests/test_components/test_mode_interp.py
index 67a087372c..6653f22c00 100644
--- a/tests/test_components/test_mode_interp.py
+++ b/tests/test_components/test_mode_interp.py
@@ -3,15 +3,15 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 import tidy3d as td
 from tidy3d.plugins.mode import ModeSolver

 # from tidy3d.plugins.smatrix.ports.wave import DEFAULT_WAVE_PORT_MODE_SPEC
 from ..test_data.test_data_arrays import MODE_SPEC, SIZE_2D
-from ..utils import AssertLogLevel
+from ..utils import AssertLogLevel, AssertLogStr

 # Shared test constants
 FREQS_DENSE = np.linspace(1e14, 2e14, 20)
@@ -47,7 +47,7 @@ def test_interp_spec_default_method():

 def test_interp_spec_cubic_needs_4_points():
     """Test that cubic interpolation requires at least 4 points."""
-    with pytest.raises(pydantic.ValidationError, match="Cubic interpolation requires at least 4"):
+    with pytest.raises(ValidationError, match="Cubic interpolation requires at least 4"):
         td.ModeInterpSpec.uniform(num_points=3, method="cubic")

@@ -60,9 +60,7 @@ def test_interp_spec_valid_poly():

 def test_interp_spec_poly_needs_3_points():
     """Test that polynomial interpolation requires at least 3 points."""
-    with pytest.raises(
-        pydantic.ValidationError, match="Polynomial interpolation requires at least 3"
-    ):
+    with pytest.raises(ValidationError, match="Polynomial interpolation requires at least 3"):
         td.ModeInterpSpec.uniform(num_points=2, method="poly")

@@ -154,22 +152,22 @@ def test_interp_spec_sampling_points_custom():

 def test_interp_spec_min_2_points():
     """Test that at least 2 points are required."""
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         td.ModeInterpSpec.uniform(num_points=1, method="linear")


 def test_interp_spec_positive_points():
     """Test that num_points must be positive."""
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         td.ModeInterpSpec.uniform(num_points=0, method="linear")

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         td.ModeInterpSpec.uniform(num_points=-5, method="linear")


 def test_interp_spec_invalid_method():
     """Test that invalid interpolation method is rejected."""
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         td.ModeInterpSpec.uniform(num_points=5, method="quadratic")

@@ -192,7 +190,7 @@ def test_interp_spec_reduce_data_false():

 def test_interp_spec_requires_tracking():
     """Test that ModeMonitor with interp_spec requires track_freq."""
-    with pytest.raises(pydantic.ValidationError, match="tracking"):
+    with pytest.raises(ValidationError, match="tracking"):
         mode_spec_no_track = td.ModeSpec(
             num_modes=2,
             track_freq=None,
@@ -363,7 +361,7 @@ def test_mode_solver_warns_num_points():
     )
     plane = td.Box(center=(0, 0, 0), size=SIZE_2D)

-    with AssertLogLevel(None):
+    with AssertLogStr(None, excludes_str=["has bounds that extend"]):
         ms = ModeSolver(
             simulation=sim,
             plane=plane,
@@ -620,12 +618,10 @@ def test_mode_solver_data_interp_single_frequency():
         field_data = getattr(data_interp, field_name)
         assert field_data is not None
         assert field_data.coords["f"].size == 1
-        assert float(field_data.coords["f"]) == 1.5e14
+        assert float(field_data.coords["f"].item()) == 1.5e14

     # Check n_group_raw and dispersion_raw if present
     if data_interp.n_group_raw is not None:
-        print(data_interp.n_group_raw.shape)
-        print((1, original_num_modes))
         assert data_interp.n_group_raw.shape == (1, original_num_modes)
     if data_interp.dispersion_raw is not None:
         assert data_interp.dispersion_raw.shape == (1, original_num_modes)
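[Editorial note: the test_mode.py changes above widen `pytest.raises(SetupError)` to `pytest.raises((SetupError, pd.ValidationError))`, presumably because pydantic v2 wraps ValueError-derived exceptions raised inside validators into a ValidationError. A minimal sketch under that assumption, with `SetupError` as a stand-in for tidy3d's exception:]

# Sketch: v2 wraps ValueError subclasses raised in validators, so tests accept
# either the original exception or ValidationError (illustrative model).
import pytest
from pydantic import BaseModel, ValidationError, field_validator


class SetupError(ValueError):
    pass


class Plane(BaseModel):
    size: tuple[float, float, float]

    @field_validator("size")
    @classmethod
    def _planar(cls, val):
        # raised inside validation, this surfaces as a ValidationError;
        # raised from a plain helper, it would surface as SetupError itself
        if min(val) != 0:
            raise SetupError("plane must have one zero dimension")
        return val


with pytest.raises((SetupError, ValidationError)):
    Plane(size=(1.0, 1.0, 1.0))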
diff --git a/tests/test_components/test_monitor.py b/tests/test_components/test_monitor.py
index b8b9faa60b..42948ab3d9 100644
--- a/tests/test_components/test_monitor.py
+++ b/tests/test_components/test_monitor.py
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pydantic
+import pydantic as pd
 import pytest

 import tidy3d as td
@@ -13,7 +13,7 @@


 def test_stop_start():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FluxTimeMonitor(size=(1, 1, 0), name="f", start=2, stop=1)

@@ -60,13 +60,13 @@ def test_downsampled():


 def test_excluded_surfaces_flat():
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.FluxMonitor(size=(1, 1, 0), name="f", freqs=[1e12], exclude_surfaces=("x-",))


 def test_fld_mnt_freqs_none():
     """Test that validation errors if freqs=[None]."""
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         td.FieldMonitor(center=(0, 0, 0), size=(0, 0, 0), freqs=[None], name="test")

@@ -150,7 +150,7 @@ def test_fieldproj_surfaces():
 def test_fieldproj_surfaces_in_simulaiton():
     # test error if all projection surfaces are outside the simulation domain
     M = td.FieldProjectionAngleMonitor(size=(3, 3, 3), theta=[1], phi=[0], name="f", freqs=[2e12])
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Simulation(
             size=(2, 2, 2),
             run_time=1e-12,
@@ -162,13 +162,13 @@ def test_fieldproj_surfaces_in_simulaiton():
     _ = td.Simulation(
         size=(2, 2, 2),
         run_time=1e-12,
-        monitors=[M],
+        monitors=(M,),
         grid_spec=td.GridSpec.uniform(0.1),
     )

     # error when the surfaces that are in are excluded
-    M = M.updated_copy(exclude_surfaces=["x-", "x+"])
-    with pytest.raises(pydantic.ValidationError):
+    M = M.updated_copy(exclude_surfaces=("x-", "x+"))
+    with pytest.raises(pd.ValidationError):
         _ = td.Simulation(
             size=(2, 2, 2),
             run_time=1e-12,
@@ -179,11 +179,11 @@ def test_fieldproj_surfaces_in_simulaiton():

 def test_fieldproj_kspace_range():
     # make sure ux, uy are in [-1, 1] for k-space projection monitors
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.FieldProjectionKSpaceMonitor(
             size=(2, 0, 2), ux=[0.1, 2], uy=[0], name="f", freqs=[2e12], proj_axis=1
         )
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.FieldProjectionKSpaceMonitor(
             size=(2, 0, 2), ux=[0.1, 0.2], uy=[1.1], name="f", freqs=[2e12], proj_axis=1
         )
@@ -212,12 +212,12 @@ def test_fieldproj_window():
     points = np.linspace(0, 10, 100)
     _ = M.window_function(points, window_size, window_minus, window_plus, 2)
     # do not allow a window size larger than 1
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.FieldProjectionAngleMonitor(
             size=(2, 0, 2), theta=[1, 2], phi=[0], name="f", freqs=[2e12], window_size=(0.2, 1.1)
         )
     # do not allow non-zero windows for volume monitors
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.FieldProjectionAngleMonitor(
             size=(2, 1, 2), theta=[1, 2], phi=[0], name="f", freqs=[2e12], window_size=(0.2, 0)
         )
@@ -242,7 +242,7 @@ def test_storage_sizes(proj_mnt):

 def test_monitor_freqs_empty():
     # errors when no frequencies supplied
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.FieldMonitor(
             size=(td.inf, td.inf, td.inf),
             freqs=[],
@@ -326,7 +326,7 @@ def test_diffraction_validators():
         y=td.Boundary.periodic(),
         z=td.Boundary.pml(),
     )
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.Simulation(
             size=(2, 2, 2),
             run_time=1e-12,
@@ -337,7 +337,7 @@ def test_diffraction_validators():
         )

     # ensure error if monitor isn't infinite in two directions
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.DiffractionMonitor(size=[td.inf, 4, 0], freqs=[1e12], name="de")

@@ -395,11 +395,11 @@ def test_monitor():
 def test_monitor_plane():
     # make sure flux, mode and diffraction monitors fail with non planar geometries
     for size in ((0, 0, 0), (1, 0, 0), (1, 1, 1)):
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(pd.ValidationError):
             td.ModeMonitor(size=size, freqs=FREQS, modes=[])
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(pd.ValidationError):
             td.ModeSolverMonitor(size=size, freqs=FREQS, modes=[])
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(pd.ValidationError):
             td.DiffractionMonitor(size=size, freqs=FREQS, name="de")

@@ -452,18 +452,18 @@ def test_directivity_monitor():
     size = (1, 2, 3)
     center = (1, 2, 3)
-    pd = np.atleast_1d(40000)
+    pd_arr = np.atleast_1d(40000)
     thetas = np.linspace(0, 2 * np.pi, 100)
     phis = np.linspace(0, np.pi, 100)

     # far_field_approx cannot be set to False
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         _ = td.DirectivityMonitor(
             size=size,
             center=center,
             theta=thetas,
             phi=phis,
-            proj_distance=pd,
+            proj_distance=pd_arr,
             freqs=FREQS,
             name="directivity",
             far_field_approx=False,
diff --git a/tests/test_components/test_parameter_perturbation.py b/tests/test_components/test_parameter_perturbation.py
index 4fbd5d2091..69a6a0f8e8 100644
--- a/tests/test_components/test_parameter_perturbation.py
+++ b/tests/test_components/test_parameter_perturbation.py
@@ -4,8 +4,8 @@

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 import tidy3d as td

@@ -45,14 +45,14 @@ def test_heat_perturbation():
     # test complex type detection
     assert not perturb.is_complex

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         perturb = td.LinearHeatPerturbation(
             coeff=0.01,
             temperature_ref=-300,
             temperature_range=(200, 400),
         )

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         perturb = td.LinearHeatPerturbation(
             coeff=0.01,
             temperature_ref=300,
@@ -141,7 +141,7 @@ def test_heat_perturbation():
     assert test_value_out == perturb_data.data[2]

     # test not allowed interpolation method
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         perturb = td.CustomHeatPerturbation(
             perturbation_values=perturb_data,
             interp_method="quadratic",
@@ -166,7 +166,7 @@ def test_charge_perturbation():
     # test complex type detection
     assert not perturb.is_complex

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         perturb = td.LinearChargePerturbation(
             electron_coeff=1e-21,
             electron_ref=0,
@@ -176,7 +176,7 @@ def test_charge_perturbation():
             hole_range=(0, 0.5e20),
         )

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         perturb = td.LinearChargePerturbation(
             electron_coeff=1e-21,
             electron_ref=0,
@@ -348,7 +348,7 @@ def test_sample(perturb):
     assert test_value_out == perturb_data[-1, -1].item()

     # test not allowed interpolation method
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         perturb = td.CustomChargePerturbation(
             perturbation_values=perturb_data,
             interp_method="quadratic",
@@ -671,7 +671,7 @@ def test_delta_model():
     delta_model = td.NedeljkovicSorefMashanovich(ref_freq=freq)

     # make sure it serializes
-    delta_model.json()
+    delta_model.model_dump_json()

     # make sure it's interpolating correctly
     coeffs_3_5 = np.array([3.10e-21, 1.210, 6.05e-20, 1.145, 6.95e-21, 0.986, 9.28e-18, 0.834])
diff --git a/tests/test_components/test_perturbation_medium.py b/tests/test_components/test_perturbation_medium.py
index a9bcbc5c74..767fd0c090 100644
--- a/tests/test_components/test_perturbation_medium.py
+++ b/tests/test_components/test_perturbation_medium.py
@@ -3,8 +3,8 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError

 import tidy3d as td

@@ -103,7 +103,7 @@ def test_perturbation_medium(unstructured):
     assert cmed.allow_gain == pmed.allow_gain

     # permittivity < 1
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = pmed.perturbed_copy(1.1 * temperature)

     # conductivity validators
@@ -130,7 +130,7 @@ def test_perturbation_medium(unstructured):
     for pmed in [pmed_direct, pmed_perm, pmed_index]:
         cmed = pmed.perturbed_copy(0.9 * temperature)  # positive conductivity
         assert not cmed.subpixel
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(ValidationError):
             _ = pmed.perturbed_copy(1.1 * temperature)  # negative conductivity

     # negative conductivity but allow gain
@@ -138,11 +138,11 @@ def test_perturbation_medium(unstructured):
     _ = pmed.perturbed_copy(1.1 * temperature)

     # complex perturbation
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         pmed = td.PerturbationMedium(permittivity=3, permittivity_perturbation=pp_complex)

     # overdefinition
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.PerturbationMedium(
             permittivity=1.21,
             permittivity_perturbation=pp_real,
@@ -252,18 +252,18 @@ def test_perturbation_medium(unstructured):
     assert cmed.allow_gain == pmed.allow_gain

     # eps_inf < 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = pmed.perturbed_copy(1.1 * temperature)

     # mismatch between base parameter and perturbations
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         pmed = td.PerturbationPoleResidue(
             poles=[(1j, 3), (2j, 4)],
             poles_perturbation=[(None, pp_real)],
         )

     # overdefinition
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.PerturbationPoleResidue(
             eps_inf=1.21,
             poles=[(1j, 3), (2j, 4)],
diff --git a/tests/test_components/test_scene.py b/tests/test_components/test_scene.py
index b13dac7473..9f964f8b49 100644
--- a/tests/test_components/test_scene.py
+++ b/tests/test_components/test_scene.py
@@ -5,12 +5,12 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pd
 import pytest
+from pydantic import ValidationError

 import tidy3d as td
 import tidy3d.components.scene as scene_mod
-from tidy3d.components.scene import MAX_NUM_MEDIUMS
+from tidy3d.components import scene
 from tidy3d.components.viz import STRUCTURE_EPS_CMAP, STRUCTURE_EPS_CMAP_R
 from tidy3d.exceptions import SetupError

@@ -19,6 +19,7 @@
 SCENE = td.Scene()

 SCENE_FULL = SIM_FULL.scene
+TEST_MAX_NUM_MEDIUMS = 3


 def test_scene_init():
@@ -51,7 +52,7 @@ def test_scene_init():


 def test_validate_components_none():
-    assert SCENE._validate_mediums(val=None) is None
+    assert type(SCENE)._validate_mediums(val=None) is None


 def test_plot_eps():
@@ -120,7 +121,7 @@ def test_structure_alpha():
     new_structs = [
         td.Structure(geometry=s.geometry, medium=SCENE_FULL.medium) for s in SCENE_FULL.structures
     ]
-    S2 = SCENE_FULL.copy(update={"structures": new_structs})
+    S2 = SCENE_FULL.copy(update={"structures": tuple(new_structs)})
     _ = S2.plot_structures_eps(x=0, alpha=0.5)
     plt.close()

@@ -240,11 +241,11 @@ def test_structure_eps_color_mapping_no_matplotlib(
     assert np.allclose(params.facecolor, expected)


-def test_num_mediums():
+def test_num_mediums(monkeypatch):
     """Make sure we error if too many mediums supplied."""
-
+    monkeypatch.setattr(scene, "MAX_NUM_MEDIUMS", TEST_MAX_NUM_MEDIUMS)
     structures = []
-    for i in range(MAX_NUM_MEDIUMS):
+    for i in range(TEST_MAX_NUM_MEDIUMS):
         structures.append(
             td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium(permittivity=i + 1))
         )
@@ -252,7 +253,7 @@ def test_num_mediums():
         structures=structures,
     )

-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         structures.append(
             td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium(permittivity=i + 2))
         )
@@ -288,7 +289,7 @@ def _test_names_default():


 def test_names_unique():
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         _ = td.Scene(
             structures=[
                 td.Structure(
@@ -402,7 +403,7 @@ def test_perturbed_mediums_copy(unstructured, z):
 #             medium=td.Medium(permittivity=2.0),
 #         ),
 #     ]
-#     with pytest.raises(pd.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "):
+#     with pytest.raises(ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "):
 #         _ = td.Scene(structures=not_fine)
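[Editorial note: the test_scene.py change above stops building `MAX_NUM_MEDIUMS` structures and instead patches the module constant down to 3, which keeps the test cheap. The trick only works if the validator reads the constant through its module at call time, which is presumably why the diff imports the scene module rather than the constant. A self-contained sketch of the pattern:]

# Sketch of the monkeypatch pattern test_num_mediums now uses.
import pytest


class limits:  # stand-in for a module exposing a module-level constant
    MAX_ITEMS = 200


def count_ok(n):
    # reads limits.MAX_ITEMS at call time, so a patched value takes effect
    return n <= limits.MAX_ITEMS


def test_limit(monkeypatch):
    monkeypatch.setattr(limits, "MAX_ITEMS", 3)
    assert not count_ok(4)  # only 4 items needed to trigger the limit now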
diff --git a/tests/test_components/test_sidewall.py b/tests/test_components/test_sidewall.py
index 31252a603e..f2f4048a46 100644
--- a/tests/test_components/test_sidewall.py
+++ b/tests/test_components/test_sidewall.py
@@ -3,8 +3,8 @@
 from __future__ import annotations

 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
+from pydantic import ValidationError
 from shapely import Point, Polygon

 import tidy3d as td
@@ -132,17 +132,17 @@ def test_valid_polygon():

     # area = 0
     vertices = ((0, 0), (1, 0), (2, 0))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

     # only two points
     vertices = ((0, 0), (1, 0), (1, 0))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

     # intersecting edges
     vertices = ((0, 0), (1, 0), (1, 1), (0, 1), (0.5, -1))
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

@@ -161,13 +161,13 @@ def test_crossing_square_poly():
     dilation = -1.1
     angle = 0
     for ref_plane in ["bottom", "middle", "top"]:
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(ValidationError):
             _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane=ref_plane)

     # angle too large, self-intersecting
     dilation = 0
     angle = np.pi / 3
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)
     _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="top")
     # middle plane
@@ -176,13 +176,13 @@ def test_crossing_square_poly():

     # angle too large for middle reference plane
     angle = np.arctan(2.001)
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="middle")

     # combines both
     dilation = -0.1
     angle = np.pi / 4
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

@@ -196,26 +196,26 @@ def test_crossing_concave_poly():
     vertices = ((-0.5, 1), (-0.5, -1), (1, -1), (0, -0.1), (0, 0.1), (1, 1))
     dilation = 0.5
     angle = 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

     # polygon splitting
     dilation = -0.3
     angle = 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

     # polygon fully eroded
     dilation = -0.5
     angle = 0
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         _ = setup_polyslab(vertices, dilation, angle, bounds)

     # # or, effectively
     dilation = 0
     angle = -np.pi / 4
     for bounds in [(0, 0.3), (0, 0.5)]:
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(ValidationError):
             _ = setup_polyslab(vertices, dilation, angle, bounds)
         _ = setup_polyslab(vertices, dilation, -angle, bounds, reference_plane="top")

@@ -224,7 +224,7 @@ def test_crossing_concave_poly():
     bounds = (0, 0.44)
     _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="middle")
     _ = setup_polyslab(vertices, dilation, -angle, bounds, reference_plane="middle")
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         # vertices degenerate
         bounds = (0, 0.45)
         _ = setup_polyslab(vertices, dilation, angle, bounds, reference_plane="middle")
], + ), + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), - sources=[ + sources=( td.PointDipole( center=CENTER_SHIFT, polarization="Ex", source_time=td.GaussianPulse(freq0=td.C_0, fwidth=td.C_0), - ) - ], + ), + ), ) # create all permutations of squares being shifted 1, -1, or zero in all three directions @@ -290,7 +290,7 @@ def test_sim_size(): s._validate_size() # check too many time steps - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): s = td.Simulation( size=(1, 1, 1), run_time=1e-7, @@ -303,11 +303,11 @@ def _test_monitor_size(): s = td.Simulation( size=(1, 1, 1), grid_spec=td.GridSpec.uniform(1e-3), - monitors=[ + monitors=( td.FieldMonitor( size=(td.inf, td.inf, td.inf), freqs=np.linspace(0, 200e12, 10001), name="test" - ) - ], + ), + ), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -329,9 +329,9 @@ def test_monitor_medium_frequency_range(freq, log_level): with AssertLogLevel(log_level): _ = td.Simulation( size=(1, 1, 1), - structures=[box], - monitors=[mnt], - sources=[src], + structures=(box,), + monitors=(mnt,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -353,8 +353,8 @@ def test_monitor_simulation_frequency_range(monitor_freq, log_level): with AssertLogLevel(log_level): _ = td.Simulation( size=(1, 1, 1), - monitors=[mnt], - sources=[src], + monitors=(mnt,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -372,8 +372,8 @@ def test_validate_monitor_simulation_frequency_range(): mnt = td.FieldMonitor(size=(0, 0, 0), name="freq", freqs=[2e12]) s = td.Simulation( size=(1, 1, 1), - monitors=[mnt], - sources=[src], + monitors=(mnt,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -383,8 +383,8 @@ def test_validate_monitor_simulation_frequency_range(): mnt = td.FieldMonitor(size=(0, 0, 0), name="freq", freqs=[5e10]) s = td.Simulation( size=(1, 1, 1), - monitors=[mnt], - sources=[src], + monitors=(mnt,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -394,8 +394,8 @@ def test_validate_monitor_simulation_frequency_range(): mnt = td.FieldMonitor(size=(0, 0, 0), name="freq", freqs=[5e13]) s = td.Simulation( size=(1, 1, 1), - monitors=[mnt], - sources=[src], + monitors=(mnt,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -403,7 +403,7 @@ def test_validate_monitor_simulation_frequency_range(): def test_validate_bloch_with_symmetry(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -430,7 +430,7 @@ def test_validate_normalize_index(): ) # negative normalize index - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, @@ -439,12 +439,12 @@ def test_validate_normalize_index(): ) # normalize index out of bounds - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, grid_spec=td.GridSpec.uniform(dl=0.1), - sources=[src], + sources=(src,), normalize_index=1, ) # skipped if no sources @@ -453,12 +453,12 @@ def test_validate_normalize_index(): ) # normalize by zero-amplitude source - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): 
td.Simulation( size=(1, 1, 1), run_time=1e-12, grid_spec=td.GridSpec.uniform(dl=0.1), - sources=[src0], + sources=(src0,), ) @@ -516,16 +516,16 @@ def test_validate_plane_wave_boundaries(): td.Simulation( size=(1, 1, 1), run_time=1e-12, - sources=[src1], + sources=(src1,), boundary_spec=bspec1, ) # angled incidence plane wave with PMLs / absorbers should error - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 1, 1), run_time=1e-12, - sources=[src2], + sources=(src2,), boundary_spec=bspec1, ) @@ -534,7 +534,7 @@ def test_validate_plane_wave_boundaries(): td.Simulation( size=(1, 1, 1), run_time=1e-12, - sources=[src2], + sources=(src2,), boundary_spec=td.BoundarySpec.all_sides(td.Periodic()), ) @@ -543,9 +543,9 @@ def test_validate_plane_wave_boundaries(): td.Simulation( size=(1, 1, 1), run_time=1e-12, - sources=[src2], + sources=(src2,), boundary_spec=bspec3, - monitors=[mnt], + monitors=(mnt,), ) # angled incidence plane wave with wrong Bloch vector should warn @@ -553,7 +553,7 @@ def test_validate_plane_wave_boundaries(): td.Simulation( size=(1, 1, 1), run_time=1e-12, - sources=[src2], + sources=(src2,), boundary_spec=bspec4, ) @@ -587,7 +587,7 @@ def test_validate_zero_dim_boundaries(): td.Simulation( size=(1, 1, 0), run_time=1e-12, - sources=[src], + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(), y=td.Boundary.stable_pml(), @@ -613,7 +613,7 @@ def test_validate_symmetry_boundaries(): z=td.Boundary.pml(), ), ) - with pytest.raises(pydantic.ValidationError, match="Symmetry"): + with pytest.raises(ValidationError, match="Symmetry"): td.Simulation( size=(1, 1, 1), symmetry=(1, 1, 1), @@ -628,26 +628,12 @@ def test_validate_symmetry_boundaries(): def test_validate_components_none(): - assert SIM._structures_not_at_edges(val=None, values=SIM.dict()) is None - assert SIM._validate_num_sources(val=None) is None - assert SIM._warn_monitor_mediums_frequency_range(val=None, values=SIM.dict()) is None - assert SIM._warn_monitor_simulation_frequency_range(val=None, values=SIM.dict()) is None - assert SIM._warn_grid_size_too_small(val=None, values=SIM.dict()) is None - assert SIM._source_homogeneous_isotropic(val=None, values=SIM.dict()) is None - - -def test_sources_edge_case_validation(): - values = SIM.dict() - values.pop("sources") - with AssertLogLevel("WARNING"): - SIM._warn_monitor_simulation_frequency_range(val="test", values=values) - - -def test_validate_size_run_time(monkeypatch): - monkeypatch.setattr(simulation, "MAX_TIME_STEPS", 1) - with pytest.raises(SetupError): - s = SIM.copy(update={"run_time": 1e-12}) - s._validate_size() + assert type(SIM)._validate_num_sources(val=None) is None + assert SIM._structures_not_at_edges() is SIM + assert SIM._warn_monitor_mediums_frequency_range() is SIM + assert SIM._warn_monitor_simulation_frequency_range() is SIM + assert SIM._warn_grid_size_too_small() is SIM + assert SIM._source_homogeneous_isotropic() is SIM def test_validate_size_spatial_and_time(monkeypatch): @@ -703,7 +689,7 @@ def test_validate_mnt_size(monkeypatch): # medium=td.Medium(permittivity=2.0), # ), # ] -# with pytest.raises(pydantic.ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): +# with pytest.raises(ValidationError, match=f" {MAX_GEOMETRY_COUNT + 2} "): # _ = td.Simulation(size=(1, 1, 1), run_time=1, grid_spec=gs, structures=not_fine) @@ -803,7 +789,7 @@ def make_sim(self, medium): size=(L, L, L), grid_spec=td.GridSpec.uniform(dl=0.01), structures=structures, - sources=[source], + 
sources=(source,), run_time=1e-12, ) @@ -975,7 +961,7 @@ def test_structure_alpha(): new_structs = [ td.Structure(geometry=s.geometry, medium=SIM_FULL.medium) for s in SIM_FULL.structures ] - S2 = SIM_FULL.copy(update={"structures": new_structs}) + S2 = SIM_FULL.copy(update={"structures": tuple(new_structs)}) _ = S2.plot_structures_eps(x=0, alpha=0.5) plt.close() @@ -990,8 +976,8 @@ def test_plot_eps_with_default_frequency(): box = td.Structure(medium=chromium, geometry=td.Box(size=(0.2, 0.2, 0.2), center=(0, 0, 0))) sim = td.Simulation( size=(1, 1, 1), - structures=[box], - sources=[src], + structures=(box,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.PECBoundary()), grid_spec=td.GridSpec.uniform(dl=0.01), @@ -1025,7 +1011,7 @@ def test_plot_symmetries(): def test_plot_grid(): override = td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium()) S2 = SIM_FULL.copy( - update={"grid_spec": td.GridSpec(wavelength=1.0, override_structures=[override])} + update={"grid_spec": td.GridSpec(wavelength=1.0, override_structures=(override,))} ) S2.plot_grid(x=0) plt.close() @@ -1049,7 +1035,7 @@ def test_plot_with_lumped_elements(): load = td.LumpedResistor( center=(0, 0, 0), size=(1, 2, 0), name="resistor", voltage_axis=0, resistance=50 ) - sim_test = SIM_FULL.updated_copy(lumped_elements=[load]) + sim_test = SIM_FULL.updated_copy(lumped_elements=(load,)) sim_test.plot(z=0) plt.close() @@ -1086,7 +1072,7 @@ def test_nyquist(): # nyquist step decreses to 1 when the frequency-domain monitor is at high frequency S_MONITOR = S.copy( - update={"monitors": [td.FluxMonitor(size=(1, 1, 0), freqs=[1e14, 1e20], name="flux")]} + update={"monitors": (td.FluxMonitor(size=(1, 1, 0), freqs=[1e14, 1e20], name="flux"),)} ) assert S_MONITOR.nyquist_step == 1 @@ -1136,8 +1122,8 @@ def test_large_grid_size(grid_size, log_level): _ = td.Simulation( size=(1, 1, 1), grid_spec=td.GridSpec.uniform(dl=grid_size), - structures=[box], - sources=[src], + structures=(box,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1157,8 +1143,8 @@ def test_sim_structure_gap(box_size, log_level): with AssertLogLevel(log_level): _ = td.Simulation( size=(10, 10, 10), - structures=[box], - sources=[src], + structures=(box,), + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -1196,36 +1182,36 @@ def test_sim_plane_wave_error(): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, - structures=[box_transparent], - sources=[src], + structures=(box_transparent,), + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) # with non-transparent box, raise - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, - structures=[box_transparent, box], - sources=[src], + structures=(box_transparent, box), + sources=(src,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) # raise with anisotropic medium - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg_diag, - sources=[src], + sources=(src,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg_full, - sources=[src], + sources=(src,),
boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1279,21 +1265,21 @@ def test_sim_monitor_homogeneous(): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, - structures=[box_transparent], - sources=[src], + structures=(box_transparent,), + sources=(src,), run_time=1e-12, - monitors=[monitor], + monitors=(monitor,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) # with non-transparent box, raise - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, - structures=[box], - sources=[src], - monitors=[monitor], + structures=(box,), + sources=(src,), + monitors=(monitor,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1324,9 +1310,9 @@ def test_sim_monitor_homogeneous(): _ = td.Simulation( size=(1, 1, 1), medium=medium_bg, - structures=[box_transparent, box], - sources=[src], - monitors=[monitor_n2f_vol_exclude], + structures=(box_transparent, box), + sources=(src,), + monitors=(monitor_n2f_vol_exclude,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1382,10 +1368,10 @@ def test_proj_monitor_distance(): with AssertLogLevel("WARNING"): _ = td.Simulation( size=(1, 1, 0.3), - structures=[], - sources=[src], + structures=(), + sources=(src,), run_time=1e-12, - monitors=[monitor_n2f_far], + monitors=(monitor_n2f_far,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1393,10 +1379,10 @@ def test_proj_monitor_distance(): with AssertLogLevel(None): _ = td.Simulation( size=(1, 1, 0.3), - structures=[], - sources=[src], + structures=(), + sources=(src,), run_time=1e-12, - monitors=[monitor_n2f], + monitors=(monitor_n2f,), grid_spec=td.GridSpec.auto(wavelength=src.source_time.freq0), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1405,10 +1391,10 @@ def test_proj_monitor_distance(): with AssertLogLevel(None): _ = td.Simulation( size=(1, 1, 0.3), - structures=[], - sources=[src], + structures=(), + sources=(src,), run_time=1e-12, - monitors=[monitor_n2f_approx], + monitors=(monitor_n2f_approx,), grid_spec=td.GridSpec.auto(wavelength=src.source_time.freq0), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1497,10 +1483,10 @@ def test_proj_monitor_warnings(monitor_type, monitor_kwargs, custom_origin, norm with AssertLogLevel("WARNING"): _ = td.Simulation( size=(1, 1, 1), - structures=[], - sources=[src], + structures=(), + sources=(src,), run_time=1e-12, - monitors=[monitor], + monitors=(monitor,), ) @@ -1529,22 +1515,22 @@ def test_diffraction_medium(): pol_angle=-1.0, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2, 2, 2), - structures=[box_cond], - sources=[src], + structures=(box_cond,), + sources=(src,), run_time=1e-12, - monitors=[monitor], + monitors=(monitor,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2, 2, 2), - structures=[box_disp], - sources=[src], - monitors=[monitor], + structures=(box_disp,), + sources=(src,), + monitors=(monitor,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1572,8 +1558,8 @@ def test_sim_structure_extent(box_size, log_level): with AssertLogLevel(log_level): _ = td.Simulation( size=(1, 1, 1), - structures=[box], - sources=[src], + structures=(box,), + 
sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1602,9 +1588,9 @@ def test_warn_lumped_elements_outside_sim_bounds(): sim_good = td.Simulation( size=sim_size, center=sim_center, - sources=[src], + sources=(src,), run_time=1e-12, - lumped_elements=[resistor_in], + lumped_elements=(resistor_in,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) assert len(sim_good.volumetric_structures) == 1 @@ -1621,9 +1607,9 @@ def test_warn_lumped_elements_outside_sim_bounds(): sim_good = td.Simulation( size=sim_size, center=sim_center, - sources=[src], + sources=(src,), run_time=1e-12, - lumped_elements=[resistor_in], + lumped_elements=(resistor_in,), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) assert len(sim_good.volumetric_structures) == 1 @@ -1637,7 +1623,7 @@ def test_warn_lumped_elements_outside_sim_bounds(): name="resistor_outside", ) with AssertLogStr("WARNING", contains_str="not completely inside"): - sim_bad = sim_good.updated_copy(lumped_elements=[resistor_out]) + sim_bad = sim_good.updated_copy(lumped_elements=(resistor_out,)) assert len(sim_bad.volumetric_structures) == 0 # Lumped element is flush against boundary along its zero size dimension @@ -1649,7 +1635,7 @@ def test_warn_lumped_elements_outside_sim_bounds(): name="resistor_edge", ) with AssertLogStr("WARNING", contains_str="not completely inside"): - sim_bad = sim_good.updated_copy(lumped_elements=[resistor_edge]) + sim_bad = sim_good.updated_copy(lumped_elements=(resistor_edge,)) assert len(sim_bad.volumetric_structures) == 0 @@ -1680,9 +1666,9 @@ def test_sim_validate_structure_bounds_pml(box_length, absorb_type, log_level): with AssertLogLevel(log_level): _ = td.Simulation( size=(1, 1, 1), - structures=[box], + structures=(box,), grid_spec=td.GridSpec.auto(wavelength=0.001), - sources=[src], + sources=(src,), run_time=1e-12, boundary_spec=td.BoundarySpec( x=td.Boundary(plus=boundary, minus=boundary), @@ -1694,12 +1680,10 @@ def test_sim_validate_structure_bounds_pml(box_length, absorb_type, log_level): def test_num_mediums(monkeypatch): """Make sure we error if too many mediums supplied.""" - - max_num_mediums = 10 - monkeypatch.setattr(simulation, "MAX_NUM_MEDIUMS", max_num_mediums) + monkeypatch.setattr(simulation, "MAX_NUM_MEDIUMS", TEST_MAX_NUM_MEDIUMS) structures = [] grid_spec = td.GridSpec.auto(wavelength=1.0) - for i in range(max_num_mediums): + for i in range(TEST_MAX_NUM_MEDIUMS): structures.append( td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium(permittivity=i + 1)) ) @@ -1711,7 +1695,7 @@ def test_num_mediums(monkeypatch): boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): structures.append( td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=td.Medium(permittivity=i + 2)) ) @@ -1751,10 +1735,10 @@ def test_num_sources(): direction="+", ) - _ = td.Simulation(size=(5, 5, 5), run_time=1e-12, sources=[src] * MAX_NUM_SOURCES) + _ = td.Simulation(size=(5, 5, 5), run_time=1e-12, sources=(src,) * MAX_NUM_SOURCES) - with pytest.raises(pydantic.ValidationError): - _ = td.Simulation(size=(5, 5, 5), run_time=1e-12, sources=[src] * (MAX_NUM_SOURCES + 1)) + with pytest.raises(ValidationError): + _ = td.Simulation(size=(5, 5, 5), run_time=1e-12, sources=(src,) * (MAX_NUM_SOURCES + 1)) def _test_names_default(): @@ -1763,7 +1747,7 @@ def _test_names_default(): sim = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, - 
structures=[ + structures=( td.Structure( geometry=td.Box(size=(1, 1, 1), center=(-1, 0, 0)), medium=td.Medium(permittivity=2.0), @@ -1779,8 +1763,8 @@ def _test_names_default(): geometry=td.Cylinder(radius=1.4, length=2.0, center=(1.0, 0.0, -1.0), axis=1), medium=td.Medium(), ), - ], - sources=[ + ), + sources=( td.UniformCurrentSource( size=(0, 0, 0), center=(0, -0.5, 0), @@ -1799,12 +1783,12 @@ def _test_names_default(): polarization="Ey", source_time=td.GaussianPulse(freq0=1e14, fwidth=1e12), ), - ], - monitors=[ + ), + monitors=( td.FluxMonitor(size=(1, 1, 0), center=(0, -0.5, 0), freqs=[1e12], name="mon1"), td.FluxMonitor(size=(0, 1, 1), center=(0, -0.5, 0), freqs=[1e12], name="mon2"), td.FluxMonitor(size=(1, 0, 1), center=(0, -0.5, 0), freqs=[1e12], name="mon3"), - ], + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1816,11 +1800,11 @@ def _test_names_default(): def test_names_unique(): - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, - structures=[ + structures=( td.Structure( geometry=td.Box(size=(1, 1, 1), center=(-1, 0, 0)), medium=td.Medium(permittivity=2.0), @@ -1831,15 +1815,15 @@ def test_names_unique(): medium=td.Medium(permittivity=2.0), name="struct1", ), - ], + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, - sources=[ + sources=( td.UniformCurrentSource( size=(0, 0, 0), center=(0, -0.5, 0), @@ -1854,18 +1838,18 @@ def test_names_unique(): source_time=td.GaussianPulse(freq0=1e14, fwidth=1e12), name="source1", ), - ], + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, - monitors=[ + monitors=( td.FluxMonitor(size=(1, 1, 0), center=(0, -0.5, 0), freqs=[1e12], name="mon1"), td.FluxMonitor(size=(0, 1, 1), center=(0, -0.5, 0), freqs=[1e12], name="mon1"), - ], + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1875,28 +1859,28 @@ def test_mode_object_syms(): g = td.GaussianPulse(freq0=1e12, fwidth=0.1e12) # wrong mode source - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=(1.0, -1.0, 0.5), size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), run_time=1e-12, symmetry=(1, -1, 0), - sources=[td.ModeSource(size=(2, 2, 0), direction="+", source_time=g)], + sources=(td.ModeSource(size=(2, 2, 0), direction="+", source_time=g),), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) # wrong mode monitor - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=(1.0, -1.0, 0.5), size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), run_time=1e-12, symmetry=(1, -1, 0), - monitors=[ - td.ModeMonitor(size=(2, 2, 0), name="mnt", freqs=[2e12], mode_spec=td.ModeSpec()) - ], + monitors=( + td.ModeMonitor(size=(2, 2, 0), name="mnt", freqs=[2e12], mode_spec=td.ModeSpec()), + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1907,7 +1891,7 @@ def test_mode_object_syms(): grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), run_time=1e-12, symmetry=(1, -1, 0), - sources=[td.ModeSource(center=(1, -1, 1), size=(2, 2, 0), 
direction="+", source_time=g)], + sources=(td.ModeSource(center=(1, -1, 1), size=(2, 2, 0), direction="+", source_time=g),), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1918,11 +1902,11 @@ def test_mode_object_syms(): grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), run_time=1e-12, symmetry=(1, -1, 0), - monitors=[ + monitors=( td.ModeMonitor( center=(2, 0, 1), size=(2, 2, 0), name="mnt", freqs=[2e12], mode_spec=td.ModeSpec() - ) - ], + ), + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) @@ -1941,13 +1925,13 @@ def test_tfsf_symmetry(): injection_axis=2, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), run_time=1e-12, symmetry=(0, -1, 0), - sources=[source], + sources=(source,), ) @@ -1966,12 +1950,12 @@ def test_tfsf_aux_source_outside_domain(): injection_axis=2, ) - with pytest.raises(SetupError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 1.01), grid_spec=td.GridSpec.auto(wavelength=td.C_0 / 1.0), run_time=1e-12, - sources=[source], + sources=(source,), ) @@ -1999,7 +1983,7 @@ def test_tfsf_boundaries(): z=td.Boundary.periodic(), ), run_time=1e-12, - sources=[source], + sources=(source,), ) # can cross Bloch boundaries in the transverse directions @@ -2007,7 +1991,7 @@ def test_tfsf_boundaries(): size=(0.5, 0.5, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], + sources=(source,), boundary_spec=td.BoundarySpec( x=td.Boundary.bloch_from_source(source=source, domain_size=0.5, axis=0, medium=None), y=td.Boundary.bloch_from_source(source=source, domain_size=0.5, axis=1, medium=None), @@ -2022,7 +2006,7 @@ def test_tfsf_boundaries(): size=(0.5, 0.5, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], + sources=(source,), boundary_spec=td.BoundarySpec( x=td.Boundary.bloch_from_source( source=source, @@ -2041,22 +2025,22 @@ def test_tfsf_boundaries(): ) # cannot cross any boundary in the direction of injection - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(2.0, 2.0, 0.5), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], + sources=(source,), ) # cannot cross any non-periodic boundary in the transverse direction - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( center=(0.5, 0, 0), # also check the case when the boundary is crossed only on one side size=(0.5, 0.5, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], + sources=(source,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(), y=td.Boundary.absorber(), @@ -2084,13 +2068,13 @@ def test_tfsf_structures_grid(): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], - structures=[ + sources=(source,), + structures=( td.Structure( geometry=td.Box(center=(0, 0, -1), size=(0.5, 0.5, 0.5)), medium=td.Medium(permittivity=2), - ) - ], + ), + ), ) sim.validate_pre_upload() @@ -2100,13 +2084,13 @@ def test_tfsf_structures_grid(): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], - structures=[ + sources=(source,), + structures=( td.Structure( geometry=td.Box(center=(0.5, 0, 0), size=(0.25, 0.25, 0.25)), medium=td.Medium(permittivity=2), - ) - ], + ), + ), ) with 
pytest.raises(SetupError): sim.validate_pre_upload() @@ -2116,12 +2100,12 @@ def test_tfsf_structures_grid(): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], - structures=[ + sources=(source,), + structures=( td.Structure( geometry=td.Box(center=(0.5, 0, 0), size=(0.25, 0.25, 0.25)), medium=td.Medium() - ) - ], + ), + ), ) # TFSF box must not intersect a custom medium @@ -2140,13 +2124,13 @@ def test_tfsf_structures_grid(): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], - structures=[ + sources=(source,), + structures=( td.Structure( geometry=td.Box(center=(0.5, 0, 0), size=(td.inf, td.inf, 0.25)), medium=custom_medium, - ) - ], + ), + ), ) with pytest.raises(SetupError): sim.validate_pre_upload() @@ -2159,26 +2143,27 @@ def test_tfsf_structures_grid(): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], - structures=[ + sources=(source,), + structures=( td.Structure( geometry=td.Box(center=(0.5, 0, 0), size=(td.inf, td.inf, 0.25)), medium=anisotropic_medium, - ) - ], + ), + ), ) with pytest.raises(SetupError): sim.validate_pre_upload() @pytest.mark.parametrize( - "size, num_struct, log_level", [(1, 1, None), (50, 1, "WARNING"), (1, 11, "WARNING")] + "size, num_struct, log_level", [(1, 1, None), (2, 1, "WARNING"), (1, 11, "WARNING")] ) @td.packaging.disable_local_subpixel def test_warn_large_epsilon(monkeypatch, size, num_struct, log_level): """Make sure we get a warning if the epsilon grid is too large.""" monkeypatch.setattr(simulation, "NUM_STRUCTURES_WARN_EPSILON", 10) + monkeypatch.setattr(simulation, "NUM_CELLS_WARN_EPSILON", 2_000) structures = [ td.Structure( geometry=td.Box(center=(0, 0, 0), size=(0.1, 0.1, 0.1)), @@ -2191,15 +2176,15 @@ def test_warn_large_epsilon(monkeypatch, size, num_struct, log_level): grid_spec=td.GridSpec.uniform(dl=0.1), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), - sources=[ + sources=( td.ModeSource( center=(0, 0, 0), size=(td.inf, td.inf, 0), direction="+", source_time=td.GaussianPulse(freq0=1e12, fwidth=0.1e12), - ) - ], - structures=structures, + ), + ), + structures=tuple(structures), ) with AssertLogLevel(log_level): @@ -2214,18 +2199,18 @@ def test_warn_large_mode_monitor(dl, log_level): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.uniform(dl=dl), run_time=1e-12, - sources=[ + sources=( td.ModeSource( size=(0.4, 0.4, 0), direction="+", source_time=td.GaussianPulse(freq0=1e12, fwidth=0.1e12), - ) - ], - monitors=[ + ), + ), + monitors=( td.ModeMonitor( size=(td.inf, 0, td.inf), freqs=[1e12], name="test", mode_spec=td.ModeSpec() - ) - ], + ), + ), ) with AssertLogLevel(log_level): @@ -2240,13 +2225,13 @@ def test_warn_large_mode_source(dl, log_level): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.uniform(dl=dl), run_time=1e-12, - sources=[ + sources=( td.ModeSource( size=(td.inf, td.inf, 0), direction="+", source_time=td.GaussianPulse(freq0=1e12, fwidth=0.1e12), - ) - ], + ), + ), ) with AssertLogLevel(log_level): @@ -2274,14 +2259,14 @@ def test_error_large_monitors(monitor): grid_spec=td.GridSpec.uniform(dl=0.001), run_time=1e-12, boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), - sources=[ + sources=( td.ModeSource( size=(0.1, 0.1, 0), direction="+", source_time=td.GaussianPulse(freq0=1e12, fwidth=0.1e12), - ) - ], - monitors=[monitor], + ), + ), + monitors=(monitor,), ) # small sim should not error @@ -2300,29 +2285,29 @@ def 
test_error_max_time_monitor_steps(): size=(5, 5, 5), run_time=1e-12, grid_spec=td.GridSpec.uniform(dl=0.01), - sources=[ + sources=( td.ModeSource( size=(0.1, 0.1, 0), direction="+", source_time=td.GaussianPulse(freq0=2e14, fwidth=0.1e14), - ) - ], + ), + ), ) # simulation with a 0D time monitor should not error monitor = td.FieldTimeMonitor(center=(0, 0, 0), size=(0, 0, 0), name="time") - sim = sim.updated_copy(monitors=[monitor]) + sim = sim.updated_copy(monitors=(monitor,)) sim.validate_pre_upload() # 1D monitor should error with pytest.raises(SetupError): monitor = monitor.updated_copy(size=(1, 0, 0)) - sim = sim.updated_copy(monitors=[monitor]) + sim = sim.updated_copy(monitors=(monitor,)) sim.validate_pre_upload() # setting a large enough interval should again not error monitor = monitor.updated_copy(interval=20) - sim = sim.updated_copy(monitors=[monitor]) + sim = sim.updated_copy(monitors=(monitor,)) sim.validate_pre_upload() @@ -2355,14 +2340,14 @@ def test_warn_time_monitor_outside_run_time(start, log_level): size=(2.0, 2.0, 2.0), grid_spec=td.GridSpec.uniform(dl=0.1), run_time=1e-12, - sources=[ + sources=( td.ModeSource( size=(0.4, 0.4, 0), direction="+", source_time=td.GaussianPulse(freq0=1e12, fwidth=0.1e12), - ) - ], - monitors=[td.FieldTimeMonitor(size=(td.inf, 0, td.inf), start=start, name="test")], + ), + ), + monitors=(td.FieldTimeMonitor(size=(td.inf, 0, td.inf), start=start, name="test"),), ) with AssertLogLevel(log_level, contains_str="start time"): sim.validate_pre_upload() @@ -2382,7 +2367,7 @@ def test_dt(): geometry=td.Box(size=(1, 1, 1), center=(-1, 0, 0)), medium=td.PoleResidue(eps_inf=0.16, poles=[(-1 + 1j, 2 + 2j)]), ) - sim_new = sim.copy(update={"structures": [structure]}) + sim_new = sim.copy(update={"structures": (structure,)}) assert sim_new.dt == 0.4 * dt @@ -2395,7 +2380,7 @@ def test_conformal_dt(): sim = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, - structures=[box], + structures=(box,), grid_spec=td.GridSpec.uniform(dl=0.1), subpixel=td.SubpixelSpec(pec=td.Staircasing()), ) @@ -2457,8 +2442,8 @@ def test_sim_volumetric_structures(tmp_path): for struct in [box, cyl, pslab]: sim = td.Simulation( size=(10, 10, 10), - structures=[struct], - sources=[src], + structures=(struct,), + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -2498,9 +2483,9 @@ def test_sim_volumetric_structures(tmp_path): ) sim = td.Simulation( size=(10, 10, 10), - structures=[below, box], - sources=[src], - monitors=[monitor], + structures=(below, box), + sources=(src,), + monitors=(monitor,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -2523,9 +2508,9 @@ def test_sim_volumetric_structures(tmp_path): ) sim = td.Simulation( size=(10, 10, 10), - structures=[below, box], - sources=[src], - monitors=[monitor], + structures=(below, box), + sources=(src,), + monitors=(monitor,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -2555,8 +2540,8 @@ def test_sim_volumetric_structures(tmp_path): sim = td.Simulation( size=(10, 10, 10), - structures=[below_half, box], - sources=[src], + structures=(below_half, box), + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -2571,8 +2556,8 @@ def test_sim_volumetric_structures(tmp_path): # structure overlaying the 2D material should overwrite it like normal sim = td.Simulation( size=(10, 10, 10), - 
structures=[box, below], - sources=[src], + structures=(box, below), + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -2585,11 +2570,11 @@ def test_sim_volumetric_structures(tmp_path): assert np.isclose(sim.volumetric_structures[1].medium.xx.permittivity, 2, rtol=RTOL) # test simulation.medium can't be Medium2D - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(10, 10, 10), - structures=[], - sources=[src], + structures=(), + sources=(src,), medium=box.medium, boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), @@ -2601,16 +2586,16 @@ def test_sim_volumetric_structures(tmp_path): ) # test 2d medium is added to 2d geometry - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=td.Box(center=(0, 0, 0), size=(1, 1, 1)), medium=box.medium) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=td.Cylinder(radius=1, length=1), medium=box.medium) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure( geometry=td.PolySlab(vertices=[(0, 0), (1, 0), (1, 1)], slab_bounds=(-1, 1)), medium=box.medium, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=td.Sphere(radius=1), medium=box.medium) # test warning for 2d geometry in simulation without Medium2D @@ -2618,8 +2603,8 @@ def test_sim_volumetric_structures(tmp_path): struct = td.Structure(medium=td.Medium(), geometry=td.Box(size=(1, 0, 1))) sim = td.Simulation( size=(10, 10, 10), - structures=[struct], - sources=[src], + structures=(struct,), + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), @@ -2642,7 +2627,7 @@ def test_pml_boxes_2D(normal_axis): size=sim_size, run_time=1e-12, grid_spec=td.GridSpec(wavelength=1.0), - sources=[ + sources=( td.PointDipole( center=(0, 0, 0), polarization="Ex", @@ -2650,8 +2635,8 @@ def test_pml_boxes_2D(normal_axis): freq0=1e14, fwidth=1e12, ), - ) - ], + ), + ), boundary_spec=td.BoundarySpec.pml(**pml_on_kwargs), ) @@ -2685,10 +2670,10 @@ def test_allow_gain(): run_time=1e-12, medium=medium, grid_spec=td.GridSpec.uniform(dl=0.1), - structures=[struct], + structures=(struct,), ) assert not sim.allow_gain - sim = sim.updated_copy(structures=[struct_gain]) + sim = sim.updated_copy(structures=(struct_gain,)) assert sim.allow_gain @@ -2745,7 +2730,7 @@ def test_perturbed_mediums_copy(unstructured, z): run_time=1e-12, medium=pmed1, grid_spec=td.GridSpec.uniform(dl=0.1), - structures=[struct], + structures=(struct,), ) # no perturbations provided -> regular mediums @@ -2770,7 +2755,7 @@ def test_scene_from_scene(): sim = td.Simulation.from_scene( scene=scene, - **SIM_FULL.dict(exclude={"structures", "medium"}), + **SIM_FULL.model_dump(exclude={"structures", "medium"}), ) assert sim == SIM_FULL @@ -2780,7 +2765,7 @@ def test_to_gds(tmp_path): sim = td.Simulation( size=(2.0, 2.0, 2.0), run_time=1e-12, - structures=[ + structures=( td.Structure( geometry=td.Box(size=(1, 1, 1), center=(-1, 0, 0)), medium=td.Medium(permittivity=2.0), @@ -2793,17 +2778,17 @@ def test_to_gds(tmp_path): geometry=td.Cylinder(radius=1.4, length=2.0, center=(1.0, 0.0, -1.0), axis=1), medium=td.Medium(), ), - ], - sources=[ + ), + sources=( td.PointDipole( center=(0, 0, 0), polarization="Ex", 
source_time=td.GaussianPulse(freq0=1e14, fwidth=1e12), - ) - ], - monitors=[ + ), + ), + monitors=( td.FieldMonitor(size=(0, 0, 0), center=(0, 0, 0), freqs=[1e12, 2e12], name="point"), - ], + ), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=20), y=td.Boundary.stable_pml(num_layers=30), @@ -2865,7 +2850,7 @@ def test_sim_subsection(unstructured, nz): sim_red = sim_full_sym.subsection( region=region, symmetry=(1, 0, -1), - monitors=[mnt for mnt in SIM_FULL.monitors if not isinstance(mnt, td.ModeMonitor)], + monitors=tuple(mnt for mnt in SIM_FULL.monitors if not isinstance(mnt, td.ModeMonitor)), ) assert sim_red.symmetry == (1, 0, -1) sim_red = SIM_FULL.subsection( @@ -2873,11 +2858,11 @@ def test_sim_subsection(unstructured, nz): ) sim_red = SIM_FULL.subsection( region=region, - sources=[], + sources=(), grid_spec=td.GridSpec.uniform(dl=20), ) assert len(sim_red.sources) == 0 - sim_red = SIM_FULL.subsection(region=region, monitors=[]) + sim_red = SIM_FULL.subsection(region=region, monitors=()) assert len(sim_red.monitors) == 0 sim_red = SIM_FULL.subsection(region=region, remove_outside_structures=False) assert len(sim_red.structures) == len(SIM_FULL.structures) @@ -2901,12 +2886,12 @@ def test_sim_subsection(unstructured, nz): fine_custom_medium = td.CustomMedium(permittivity=perm) sim = SIM_FULL.updated_copy( - structures=[ + structures=( td.Structure( geometry=td.Box(size=(1, 2, 3)), medium=fine_custom_medium, - ) - ], + ), + ), medium=fine_custom_medium, ) sim_red = sim.subsection(region=region, remove_outside_custom_mediums=True) @@ -2914,7 +2899,7 @@ def test_sim_subsection(unstructured, nz): # check automatic symmetry expansion sim_sym = sim_full_sym.updated_copy( symmetry=(-1, 0, 1), - sources=[src for src in SIM_FULL.sources if not isinstance(src, td.TFSF)], + sources=tuple(src for src in SIM_FULL.sources if not isinstance(src, td.TFSF)), ) sim_red = sim_sym.subsection(region=region) assert np.allclose(sim_red.center, (0, 0.05, 0.0)) @@ -2935,21 +2920,26 @@ def test_sim_subsection(unstructured, nz): # compare assert np.allclose(red_grid, full_grid[ind : ind + len(red_grid)]) - subsection_monitors = [mnt for mnt in SIM_FULL.monitors if region_xy.intersects(mnt)] + subsection_monitors = ( + mnt + for mnt in SIM_FULL.monitors + if region_xy.intersects(mnt) + and getattr(mnt, "far_field_approx", True) # unsupported in 2d + and not isinstance( + mnt, (td.FieldProjectionCartesianMonitor, td.FieldProjectionKSpaceMonitor) + ) + ) sim_red = SIM_FULL.subsection( region=region_xy, grid_spec="identical", boundary_spec=td.BoundarySpec.all_sides(td.Periodic()), # Set theta to 'pi/2' for 2D simulation in the x-y plane - monitors=[ + monitors=tuple( mnt.updated_copy(theta=np.pi / 2) if isinstance(mnt, td.FieldProjectionAngleMonitor) else mnt for mnt in subsection_monitors - if not isinstance( - mnt, (td.FieldProjectionCartesianMonitor, td.FieldProjectionKSpaceMonitor) - ) - ], + ), ) assert sim_red.size[2] == 0 assert isinstance(sim_red.boundary_spec.z.minus, td.Periodic) @@ -3064,8 +3054,8 @@ def test_2d_material_subdivision(): size=size_sim, grid_spec=td.GridSpec(grid_x=uni_grid, grid_y=uni_grid, grid_z=uni_grid), structures=structures, - sources=[], - monitors=[], + sources=(), + monitors=(), run_time=1e-12, ) @@ -3200,22 +3190,22 @@ def test_advanced_material_intersection(): size=(4.0, 4.0, 4.0), grid_spec=td.GridSpec.auto(wavelength=1.0), run_time=1e-12, - sources=[source], - structures=[], + sources=(source,), + structures=(), ) for pair in compatible_pairs: struct1 = 
td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, 0.5)), medium=pair[0]) struct2 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, -0.5)), medium=pair[1]) # this pair can intersect - sim = sim.updated_copy(structures=[struct1, struct2]) + sim = sim.updated_copy(structures=(struct1, struct2)) for pair in incompatible_pairs: struct1 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, 0.5)), medium=pair[0]) struct2 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, -0.5)), medium=pair[1]) # this pair cannot intersect - with pytest.raises(pydantic.ValidationError): - sim = sim.updated_copy(structures=[struct1, struct2]) + with pytest.raises(ValidationError): + sim = sim.updated_copy(structures=(struct1, struct2)) for pair in incompatible_pairs: struct1 = td.Structure(geometry=td.Box(size=(1, 1, 1), center=(0, 0, 0.75)), medium=pair[0]) @@ -3223,12 +3213,12 @@ def test_advanced_material_intersection(): geometry=td.Box(size=(1, 1, 1), center=(0, 0, -0.75)), medium=pair[1] ) # it's ok if these are both present as long as they don't intersect - sim = sim.updated_copy(structures=[struct1, struct2]) + sim = sim.updated_copy(structures=(struct1, struct2)) -def test_num_lumped_elements(): +def test_num_lumped_elements(monkeypatch): """Make sure we error if too many lumped elements supplied.""" - + monkeypatch.setattr(simulation, "MAX_NUM_MEDIUMS", TEST_MAX_NUM_MEDIUMS) resistor = td.LumpedResistor( size=(0, 1, 2), center=(0, 0, 0), name="R1", voltage_axis=2, resistance=75 ) @@ -3237,16 +3227,16 @@ def test_num_lumped_elements(): _ = td.Simulation( size=(5, 5, 5), grid_spec=grid_spec, - structures=[], - lumped_elements=[resistor] * MAX_NUM_MEDIUMS, + structures=(), + lumped_elements=(resistor,) * TEST_MAX_NUM_MEDIUMS, run_time=1e-12, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.Simulation( size=(5, 5, 5), grid_spec=grid_spec, - structures=[], - lumped_elements=[resistor] * (MAX_NUM_MEDIUMS + 1), + structures=(), + lumped_elements=(resistor,) * (TEST_MAX_NUM_MEDIUMS + 1), run_time=1e-12, ) @@ -3260,23 +3250,23 @@ def test_validate_lumped_elements(): size=(1, 2, 3), run_time=1e-12, grid_spec=td.GridSpec.uniform(dl=0.1), - lumped_elements=[resistor], + lumped_elements=(resistor,), ) # error for 1D/2D simulation with lumped elements - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 0, 3), run_time=1e-12, grid_spec=td.GridSpec.uniform(dl=0.1), - lumped_elements=[resistor], + lumped_elements=(resistor,), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Simulation( size=(1, 0, 0), run_time=1e-12, grid_spec=td.GridSpec.uniform(dl=0.1), - lumped_elements=[resistor], + lumped_elements=(resistor,), ) @@ -3288,7 +3278,7 @@ def test_suggested_mesh_overrides(): size=(1, 2, 3), run_time=1e-12, grid_spec=td.GridSpec.auto(wavelength=1), - lumped_elements=[resistor], + lumped_elements=(resistor,), ) assert len(sim.internal_override_structures) == 1 assert len(sim.internal_snapping_points) == 3 @@ -3303,7 +3293,7 @@ def test_suggested_mesh_overrides(): ) sim = sim.updated_copy( - lumped_elements=[coax_resistor], + lumped_elements=(coax_resistor,), ) assert len(sim.internal_override_structures) == 1 assert len(sim.internal_snapping_points) == 1 @@ -3331,8 +3321,8 @@ def test_run_time_spec_lossy_metal(): sim = td.Simulation( run_time=run_time_spec, size=(1e4, 1e4, 2e3), - sources=[source], - structures=[box], + sources=(source,), + 
structures=(box,), ) assert max(sim.get_refractive_indices(freq0)) < 2 # if lossymetal is not handled properly, _run_time can approach 1e-6 @@ -3352,7 +3342,7 @@ def test_validate_low_num_cells_in_mode_objects(): direction="+", ) - sim = SIM.updated_copy(sources=[mode_source]) + sim = SIM.updated_copy(sources=(mode_source,)) # check with mode source that is too small with pytest.raises(SetupError): @@ -3365,7 +3355,7 @@ def test_validate_low_num_cells_in_mode_objects(): size=sim_2d_size, run_time=1e-12, grid_spec=td.GridSpec(wavelength=1.0), - sources=[mode_source], + sources=(mode_source,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pec(), @@ -3382,7 +3372,7 @@ def test_validate_low_num_cells_in_mode_objects(): mode_spec=mode_spec, freqs=[1e12], ) - sim = SIM.updated_copy(monitors=[mode_monitor]) + sim = SIM.updated_copy(monitors=(mode_monitor,)) with pytest.raises(SetupError): sim._validate_num_cells_in_mode_objects() @@ -3404,20 +3394,20 @@ def test_validate_sources_monitors_in_bounds(): ) # check that a source at y- simulation domain edge errors - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(2, 2, 2), run_time=1e-12, grid_spec=td.GridSpec(wavelength=1.0), - sources=[mode_source], + sources=(mode_source,), ) # check that a monitor at y+ simulation domain edge errors - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sim = td.Simulation( size=(2, 2, 2), run_time=1e-12, grid_spec=td.GridSpec(wavelength=1.0), - monitors=[mode_monitor], + monitors=(mode_monitor,), ) @@ -3436,48 +3426,48 @@ def test_mode_pml_warning(): sim = td.Simulation( size=sim_size, medium=sio2, - structures=[wg], + structures=(wg,), grid_spec=grid_spec, run_time=1e-30, - monitors=[ + monitors=( td.ModeSolverMonitor( size=(2, 2, 0), name="mode", freqs=[freq0], mode_spec=mode_spec.updated_copy(num_pml=(10, 10)), - ) - ], + ), + ), symmetry=symmetry, ) with AssertLogLevel("WARNING", contains_str="covers more than"): sim = td.Simulation( size=sim_size, medium=sio2, - structures=[wg], + structures=(wg,), grid_spec=grid_spec, run_time=1e-30, - monitors=[ + monitors=( td.ModeSolverMonitor( size=(2, 2, 0), name="mode", freqs=[freq0], mode_spec=mode_spec - ) - ], + ), + ), symmetry=symmetry, ) with AssertLogLevel("WARNING", contains_str="covers more than"): sim = td.Simulation( size=sim_size, medium=sio2, - structures=[wg], + structures=(wg,), grid_spec=grid_spec, run_time=1e-30, - sources=[ + sources=( td.ModeSource( size=(2, 2, 0), direction="+", source_time=td.GaussianPulse(freq0=freq0, fwidth=0.1 * freq0), mode_spec=mode_spec, - ) - ], + ), + ), symmetry=symmetry, ) with AssertLogLevel("WARNING", contains_str="covers more than"): @@ -3494,7 +3484,7 @@ def test_mode_pml_warning(): mode_sim = td.ModeSimulation( size=sim_size, medium=sio2, - structures=[wg], + structures=(wg,), grid_spec=grid_spec, plane=mode_plane, mode_spec=mode_spec, @@ -3528,9 +3518,9 @@ def test_fixed_angle_sim(): ) sim_size = (2.2, 2.2, 2.2) sim = td.Simulation( - structures=[sphere], - sources=[source], - monitors=[flux_r_mnt], + structures=(sphere,), + sources=(source,), + monitors=(flux_r_mnt,), size=sim_size, grid_spec=td.GridSpec.auto(min_steps_per_wvl=15), boundary_spec=td.BoundarySpec( @@ -3541,7 +3531,7 @@ def test_fixed_angle_sim(): assert sim._is_fixed_angle - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy( boundary_spec=td.BoundarySpec( 
x=td.Boundary.pml(), @@ -3550,26 +3540,26 @@ def test_fixed_angle_sim(): ) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(KeyError): _ = sim.updated_copy(med=td.Medium(conductivity=0.001)) anisotropic_med = td.FullyAnisotropicMedium(permittivity=[[2, 0, 0], [0, 1, 0], [0, 0, 3]]) - with pytest.raises(pydantic.ValidationError): - _ = sim.updated_copy(structures=[sphere.updated_copy(medium=anisotropic_med)]) + with pytest.raises(ValidationError): + _ = sim.updated_copy(structures=(sphere.updated_copy(medium=anisotropic_med),)) - with pytest.raises(pydantic.ValidationError): - _ = sim.updated_copy(sources=[source, source]) + with pytest.raises(ValidationError): + _ = sim.updated_copy(sources=(source, source)) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim.updated_copy( - structures=[sphere.updated_copy(medium=td.Medium(conductivity=-0.1, allow_gain=True))] + structures=(sphere.updated_copy(medium=td.Medium(conductivity=-0.1, allow_gain=True)),) ) - with pytest.raises(pydantic.ValidationError): - _ = sim.updated_copy(monitors=[td.FieldTimeMonitor(size=[td.inf, td.inf, 0], name="time")]) + with pytest.raises(ValidationError): + _ = sim.updated_copy(monitors=(td.FieldTimeMonitor(size=[td.inf, td.inf, 0], name="time"),)) - with pytest.raises(pydantic.ValidationError): - _ = sim.updated_copy(monitors=[td.FluxTimeMonitor(size=[td.inf, td.inf, 0], name="time")]) + with pytest.raises(ValidationError): + _ = sim.updated_copy(monitors=(td.FluxTimeMonitor(size=[td.inf, td.inf, 0], name="time"),)) nonlinear_med = td.Medium( permittivity=3, @@ -3580,8 +3570,8 @@ def test_fixed_angle_sim(): num_iters=20, ), ) - with pytest.raises(pydantic.ValidationError): - _ = sim.updated_copy(structures=[sphere.updated_copy(medium=nonlinear_med)]) + with pytest.raises(ValidationError): + _ = sim.updated_copy(structures=(sphere.updated_copy(medium=nonlinear_med),)) time_modulated_med = td.Medium( permittivity=2, @@ -3591,8 +3581,8 @@ def test_fixed_angle_sim(): ) ), ) - with pytest.raises(pydantic.ValidationError): - _ = sim.updated_copy(structures=[sphere.updated_copy(medium=time_modulated_med)]) + with pytest.raises(ValidationError): + _ = sim.updated_copy(structures=(sphere.updated_copy(medium=time_modulated_med),)) def test_sim_volumetric_structures_with_lumped_elements(tmp_path): @@ -3625,14 +3615,14 @@ def test_sim_volumetric_structures_with_lumped_elements(tmp_path): for element in [resistor, coax_resistor, linear_element]: sim = td.Simulation( size=(10, 10, 10), - structures=[substrate], - sources=[src], + structures=(substrate,), + sources=(src,), boundary_spec=td.BoundarySpec( x=td.Boundary.pml(num_layers=6), y=td.Boundary.pml(num_layers=6), z=td.Boundary.pml(num_layers=6), ), - lumped_elements=[element], + lumped_elements=(element,), grid_spec=td.GridSpec.uniform(dl=grid_dl), run_time=1e-12, ) @@ -3647,7 +3637,7 @@ def test_create_sim_multiphysics(): size=(10, 10, 10), grid_spec=td.GridSpec(wavelength=1.0), medium=td.Medium(permittivity=1.0), - structures=[ + structures=( td.Structure( geometry=td.Box(size=(1, 1, 1), center=(-1, 0.5, 0.5)), medium=td.MultiPhysicsMedium( @@ -3656,7 +3646,7 @@ def test_create_sim_multiphysics(): name="SiO2", ), ), - ], + ), ) @@ -3681,13 +3671,13 @@ def test_create_sim_multiphysics_with_incompatibilities(): num_iters=20, ) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): s = td.Simulation( run_time=1e-12, size=(10, 10, 10), grid_spec=td.GridSpec(wavelength=1.0), 
medium=td.Medium(permittivity=1.0), - structures=[ + structures=( td.Structure( geometry=td.Box(size=(1, 1, 1), center=(-1, 0.5, 0.5)), medium=nonlinear, @@ -3700,7 +3690,7 @@ def test_create_sim_multiphysics_with_incompatibilities(): name="SiO2", ), ), - ], + ), ) @@ -3735,20 +3725,19 @@ def test_messages_contain_object_names(): polarization="Ex", source_time=td.GaussianPulse(freq0=100e14, fwidth=10e14), ) - with pytest.raises(pydantic.ValidationError, match=name) as e: + with pytest.raises(ValidationError, match=name) as e: _ = sim.updated_copy(sources=[source]) # Test 3) Create a monitor lying outside the simulation boundary. # Check that an error message is generated containing the monitor's `name`. name = "monitor_123" monitor = td.FieldMonitor(name=name, center=(-1.0, 0, 0), size=(0.5, 0, 1), freqs=[100e14]) - with pytest.raises(pydantic.ValidationError, match=name) as e: + with pytest.raises(ValidationError, match=name) as e: _ = sim.updated_copy(monitors=[monitor]) def test_structures_per_medium(monkeypatch): """Test if structures that share the same medium warn or error appropriately.""" - import tidy3d.components.scene as scene # Set low thresholds to keep the test fast; ensure len(structures) > MAX to avoid early return monkeypatch.setattr(scene, "WARN_STRUCTURES_PER_MEDIUM", 2) @@ -3778,7 +3767,7 @@ def test_structures_per_medium(monkeypatch): monkeypatch.setattr(scene, "MAX_STRUCTURES_PER_MEDIUM", 3, raising=False) structs = [td.Structure(geometry=td.Box(size=(1, 1, 1)), medium=shared_med) for _ in range(4)] - with pytest.raises(pydantic.ValidationError, match="use the same medium"): + with pytest.raises(ValidationError, match="use the same medium"): _ = td.Simulation( size=(10, 10, 10), run_time=1e-12, @@ -3857,7 +3846,7 @@ def test_validate_microwave_mode_spec(): path="mode_spec/", impedance_specs=(custom_spec, td.AutoImpedanceSpec()) ) # check that validation error is in the MicrowaveModeSpec - with pytest.raises(SetupError): + with pytest.raises(ValidationError): sim = sim.updated_copy( monitors=[mode_mon], ) diff --git a/tests/test_components/test_source.py b/tests/test_components/test_source.py index 3160207f79..a2ffa9c5c6 100644 --- a/tests/test_components/test_source.py +++ b/tests/test_components/test_source.py @@ -4,8 +4,8 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.components.source.field import CHEB_GRID_WIDTH, DirectionalSource @@ -214,14 +214,14 @@ def test_dipole(): # p.plot(y=2) # plt.close() - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PointDipole(size=(1, 1, 1), source_time=g, center=(1, 2, 3), polarization="Ex") def test_dipole_sources_from_angles(): g = td.GaussianPulse(freq0=1e12, fwidth=0.1e12) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PointDipole.sources_from_angles( size=(1, 1, 1), source_time=g, @@ -311,11 +311,11 @@ def test_FieldSource(): # plt.close() # test that non-planar geometry crashes plane wave and gaussian beams - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PlaneWave(size=(1, 1, 1), source_time=g, pol_angle=np.pi / 2, direction="+") - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.GaussianBeam(size=(1, 1, 1), source_time=g, pol_angle=np.pi / 2, direction="+") - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): _ = td.AstigmaticGaussianBeam( size=(1, 1, 1), source_time=g, @@ -324,14 +324,14 @@ def test_FieldSource(): waist_sizes=(0.2, 0.4), waist_distances=(0.1, 0.3), ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.ModeSource(size=(1, 1, 1), source_time=g, mode_spec=mode_spec) tfsf = td.TFSF(size=(1, 1, 1), direction="+", source_time=g, injection_axis=2) _ = tfsf.injection_plane_center # assert that TFSF must be volumetric - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.TFSF(size=(1, 1, 0), direction="+", source_time=g, injection_axis=2) # s.plot(z=0) @@ -481,11 +481,11 @@ def check_freq_grid(freq_grid, num_freqs): check_freq_grid(freq_grid, num_freqs) # check validators for num_freqs - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): s = td.GaussianBeam( size=(0, 1, 1), source_time=g, pol_angle=np.pi / 2, direction="+", num_freqs=200 ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): s = td.AstigmaticGaussianBeam( size=(0, 1, 1), source_time=g, @@ -495,7 +495,7 @@ def check_freq_grid(freq_grid, num_freqs): waist_distances=(0.1, 0.3), num_freqs=100, ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): s = td.ModeSource( size=(0, 1, 1), direction="+", @@ -541,7 +541,7 @@ def test_custom_source_time(): ) cst = td.CustomSourceTime.from_values(freq0=freq0, fwidth=0.1e12, values=[0, 1], dt=sim.dt) source = td.PointDipole(center=(0, 0, 0), source_time=cst, polarization="Ex") - sim = sim.updated_copy(sources=[source]) + sim = sim.updated_copy(sources=(source,)) assert np.allclose(cst.amp_time(sim.tmesh[0]), [0], rtol=0, atol=ATOL) assert np.allclose( cst.amp_time(sim.tmesh[1:]), @@ -560,7 +560,7 @@ def test_custom_source_time(): cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12) source = td.PointDipole(center=(0, 0, 0), source_time=cst, polarization="Ex") with AssertLogLevel("WARNING", contains_str="defined over a time range"): - sim = sim.updated_copy(sources=[source]) + sim = sim.updated_copy(sources=(source,)) # test normalization warning with AssertLogLevel("WARNING"): @@ -568,10 +568,10 @@ def test_custom_source_time(): with AssertLogLevel("WARNING"): source = source.updated_copy(source_time=td.ContinuousWave(freq0=freq0, fwidth=0.1e12)) - sim = sim.updated_copy(sources=[source]) + sim = sim.updated_copy(sources=(source,)) # test single value validation error - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): vals = td.components.data.data_array.TimeDataArray([1], coords={"t": [0]}) dataset = td.components.data.dataset.TimeDataset(values=vals) cst = td.CustomSourceTime(source_time_dataset=dataset, freq0=freq0, fwidth=0.1e12) @@ -597,7 +597,7 @@ def make_custom_field_source(field_ds): with AssertLogLevel(None): make_custom_field_source(field_dataset) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): # repeat some entries so data cannot be interpolated X2 = [X[0], *list(X)] n_data2 = np.vstack((n_data[0, :, :, :].reshape(1, Ny, Nz, Nf), n_data)) @@ -724,7 +724,7 @@ def test_broadband_angled_gaussian_warning(): def test_source_frame(): _ = td.PECFrame() _ = td.PECFrame(length=4) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PECFrame(length=0) _ = td.ModeSource( diff --git a/tests/test_components/test_source_frames.py 
b/tests/test_components/test_source_frames.py index f644d07a4b..d0bfc49d9c 100644 --- a/tests/test_components/test_source_frames.py +++ b/tests/test_components/test_source_frames.py @@ -2,15 +2,15 @@ from __future__ import annotations -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td def test_source_absorber_frames(): _ = td.PECFrame() - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.PECFrame(length=0) wvl_um = 1 diff --git a/tests/test_components/test_structure.py b/tests/test_components/test_structure.py index d45931aadf..4beae9f174 100644 --- a/tests/test_components/test_structure.py +++ b/tests/test_components/test_structure.py @@ -4,8 +4,8 @@ import autograd.numpy as anp import gdstk import numpy as np -import pydantic.v1 as pd import pytest +from pydantic import ValidationError import tidy3d as td @@ -119,23 +119,23 @@ def test_invalid_polyslab(axis): _ = td.Structure(geometry=geo1, medium=medium) geo2 = ps.rotated(np.pi / 4, i) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo2, medium=medium) geo3 = ps.rotated(np.pi / 5, j) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo3, medium=medium) geo4 = ps.rotated(np.pi / 6, (1, 1, 1)) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo4, medium=medium) geo5 = td.GeometryGroup(geometries=[ps]).rotated(np.pi / 2, j) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo5, medium=medium) geo6 = td.GeometryGroup(geometries=[ps - box]).rotated(np.pi / 2, i) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo6, medium=medium) geo7 = td.GeometryGroup(geometries=[(ps - box).rotated(np.pi / 4, j)]).rotated(-np.pi / 4, j) @@ -156,11 +156,11 @@ def test_invalid_polyslab(axis): _ = td.Structure(geometry=geo10, medium=medium) geo11 = ps.reflected(n2) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo11, medium=medium) geo12 = td.GeometryGroup(geometries=[ps]).reflected(n2) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geo12, medium=medium) geo13 = td.GeometryGroup(geometries=[(ps - box).reflected(n2)]).reflected(n2) @@ -239,7 +239,7 @@ def test_validation_of_structures_with_2d_materials(): ] for geom in not_allowed_geometries: - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.Structure(geometry=geom, medium=med2d) diff --git a/tests/test_components/test_time_modulation.py b/tests/test_components/test_time_modulation.py index 5d6b741f46..387004d31c 100644 --- a/tests/test_components/test_time_modulation.py +++ b/tests/test_components/test_time_modulation.py @@ -5,8 +5,8 @@ from math import isclose import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td @@ -115,7 +115,7 @@ def test_space_modulation(): check_sp_reduction(SP_UNIFORM) # uniform in phase, but custom in amplitude - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sp = SP_UNIFORM.updated_copy(amplitude=ARRAY_CMP) sp = SP_UNIFORM.updated_copy(amplitude=ARRAY) @@ -123,14 +123,14 @@ def test_space_modulation(): check_sp_reduction(sp) # uniform in 
amplitude, but custom in phase - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sp = SP_UNIFORM.updated_copy(phase=ARRAY_CMP) sp = SP_UNIFORM.updated_copy(phase=ARRAY) assert isclose(sp.max_modulation, 1) check_sp_reduction(sp) # custom in both - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): sp = SP_UNIFORM.updated_copy(phase=ARRAY_CMP, amplitude=ARRAY_CMP) sp = SP_UNIFORM.updated_copy(phase=ARRAY, amplitude=ARRAY) check_sp_reduction(sp) @@ -174,7 +174,7 @@ def test_modulated_medium(): # permittivity modulated modulation_spec = MODULATION_SPEC.updated_copy(permittivity=ST) # modulated permitivity <= 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): medium = td.Medium(modulation_spec=modulation_spec) medium = td.Medium(permittivity=2, modulation_spec=modulation_spec) assert isclose(medium.n_cfl, np.sqrt(2 - AMP_TIME)) @@ -183,7 +183,7 @@ def test_modulated_medium(): # conductivity modulated modulation_spec = MODULATION_SPEC.updated_copy(conductivity=ST) # modulated conductivity <= 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): medium = td.Medium(modulation_spec=modulation_spec) medium_sometimes_active = td.Medium(modulation_spec=modulation_spec, allow_gain=True) medium = td.Medium(conductivity=2, modulation_spec=modulation_spec) @@ -194,7 +194,7 @@ def test_modulated_medium(): st_freq2 = ST.updated_copy( time_modulation=td.ContinuousWaveTimeModulation(freq0=2e12, amplitude=2) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): modulation_spec = MODULATION_SPEC.updated_copy(permittivity=ST, conductivity=st_freq2) # both modulated, but different space modulation: fine st_space2 = ST.updated_copy(space_modulation=td.SpaceModulation(amplitude=0.1)) @@ -212,30 +212,30 @@ def test_unsupported_modulated_medium_types(): modulation_spec = MODULATION_SPEC.updated_copy(permittivity=ST) # PEC cannot be modulated - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.PECMedium(modulation_spec=modulation_spec) # PMC cannot be modulated - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.PMCMedium(modulation_spec=modulation_spec) # For Anisotropic medium, one should modulate the components, not the whole medium - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.AnisotropicMedium( xx=td.Medium(), yy=td.Medium(), zz=td.Medium(), modulation_spec=modulation_spec ) # Modulation to fully Anisotropic medium unsupported - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.FullyAnisotropicMedium(modulation_spec=modulation_spec) # 2D material - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): drude_medium = td.Drude(eps_inf=2.0, coeffs=[(1, 2), (3, 4)]) td.Medium2D(ss=drude_medium, tt=drude_medium, modulation_spec=modulation_spec) # together with nonlinear_spec - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.Medium( permittivity=2, nonlinear_spec=td.NonlinearSusceptibility(chi3=1), @@ -257,10 +257,10 @@ def test_supported_modulated_medium_types(unstructured, z): assert mat_p.is_time_modulated assert isclose(mat_p.n_cfl, np.sqrt(2 - AMP_TIME)) # too much modulation resulting in eps_inf < 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = 
mat_p.updated_copy(eps_inf=1.0) # conductivity modulation - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = mat_p.updated_copy(modulation_spec=modulation_both_spec) mat = mat_p.updated_copy(modulation_spec=modulation_both_spec, allow_gain=True) check_med_reduction(mat) @@ -276,10 +276,10 @@ def test_supported_modulated_medium_types(unstructured, z): assert mat_c.is_time_modulated assert isclose(mat_c.n_cfl, np.sqrt(2 - AMP_TIME)) # too much modulation resulting in eps_inf < 0 - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = mat_c.updated_copy(permittivity=permittivity * 0.5) # conductivity modulation - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): mat = mat_c.updated_copy(modulation_spec=modulation_both_spec) mat = mat_c.updated_copy(modulation_spec=modulation_both_spec, allow_gain=True) check_med_reduction(mat_c) diff --git a/tests/test_components/test_types.py b/tests/test_components/test_types.py index dbf93c88f9..caab5253fa 100644 --- a/tests/test_components/test_types.py +++ b/tests/test_components/test_types.py @@ -3,50 +3,32 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import ArrayLike, Complex, constrained_array - - -def _test_validate_array_like(): - class S(Tidy3dBaseModel): - f: ArrayLike[float, 2] - - _ = S(f=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - with pytest.raises(pydantic.ValidationError): - _ = S(f=np.array([1.0, 2.0, 3.0])) - - class MyClass(Tidy3dBaseModel): - f: constrained_array(ndim=3, shape=(1, 2, 3)) - - with pytest.raises(pydantic.ValidationError): - _ = MyClass(f=np.ones((2, 2, 3))) - - with pytest.raises(pydantic.ValidationError): - _ = MyClass(f=np.ones((1, 2, 3, 4))) +from tidy3d.components.types import ArrayLike, Complex +from tidy3d.components.types.base import array_alias def test_schemas(): class S(Tidy3dBaseModel): f: ArrayLike - ca: constrained_array(ndim=1, dtype=complex) + ca: array_alias(ndim=1, dtype=complex) c: Complex - # TODO: unexpected behavior, if list with more than one element, it fails. 
- _ = S(f=[13], c=1 + 1j, ca=1 + 1j) + _ = S(f=[13], c=1 + 1j, ca=[1 + 1j]) S.schema() def test_array_like(): class MyClass(Tidy3dBaseModel): a: ArrayLike = None # can be any array-like thing - b: constrained_array(ndim=2) = None # must be 2D - c: constrained_array(dtype=float) = None # must be float-like - d: constrained_array(ndim=1, dtype=complex) = None # 1D complex + b: array_alias(ndim=2) = None # must be 2D + c: array_alias(dtype=float) = None # must be float-like + d: array_alias(ndim=1, dtype=complex) = None # 1D complex e: ArrayLike - f: constrained_array(ndim=3, shape=(1, 2, 3)) = None # must have certain shape + f: array_alias(ndim=3, shape=(1, 2, 3)) = None # must have certain shape my_obj = MyClass( a=1.0 + 2j, @@ -62,37 +44,185 @@ class MyClass(Tidy3dBaseModel): assert np.all(my_obj.c == [1.0, 3.0]) # converted to float assert np.all(my_obj.d == [1.0 + 0.0j]) # converted to complex - my_obj.json() + my_obj.model_dump_json() -def test_array_like_field_name(): +def test_hash(): class MyClass(Tidy3dBaseModel): - a: ArrayLike # can be any array-like thing - b: constrained_array(ndim=2) # must be 2D - c: constrained_array(dtype=float) # must be float-like - d: constrained_array(ndim=1, dtype=complex) # 1D complex - e: constrained_array(ndim=3, shape=(1, 2, 3)) # must have certain shape - f: ArrayLike = None + a: ArrayLike + b: array_alias(ndim=1) + c: tuple[ArrayLike, ...] - fields = MyClass.__fields__ + c = MyClass(a=[1.0], b=[2.0, 1.0], c=([2.0, 1.0])) + hash(c.model_dump_json()) - def correct_field_display(field_name, display_name): - """Make sure the field has the expected name.""" - assert fields[field_name]._type_display() == display_name - correct_field_display("a", "ArrayLike") - correct_field_display("b", "ArrayLike[ndim=2]") - correct_field_display("c", "ArrayLike[dtype=float]") - correct_field_display("d", "ArrayLike[dtype=complex, ndim=1]") - correct_field_display("e", "ArrayLike[ndim=3, shape=(1, 2, 3)]") - correct_field_display("f", "Optional[ArrayLike]") +def test_array_like_validation_errors(): + """Tests that appropriate ValidationErrors are raised for array constraints.""" + # input that cannot be converted to a NumPy array at all (with specific dtype) + class ModelDtypeConversionFail(Tidy3dBaseModel): + a: array_alias(dtype=int) -def test_hash(): - class MyClass(Tidy3dBaseModel): + with pytest.raises(ValidationError, match="cannot convert"): + ModelDtypeConversionFail(a="not an int") + + # ndim mismatch + class ModelNdimMismatch(Tidy3dBaseModel): + a: array_alias(ndim=1) + + with pytest.raises(ValidationError, match="expected"): + ModelNdimMismatch(a=[[1, 2], [3, 4]]) + + # ndim mismatch (scalar for ndim=1, scalar_to_1d=False by default) + class ModelNdimScalarDefault(Tidy3dBaseModel): + a: array_alias(ndim=1) + + with pytest.raises(ValidationError, match="expected"): + ModelNdimScalarDefault(a=5) + + # shape mismatch + class ModelShapeMismatch(Tidy3dBaseModel): + a: array_alias(shape=(2, 2)) + + with pytest.raises(ValidationError, match=r"expected shape"): + ModelShapeMismatch(a=[[1, 2, 3], [4, 5, 6]]) + + # forbid_nan=True (default) and array contains NaN + class ModelForbidNan(Tidy3dBaseModel): + a: array_alias(dtype=float) + + with pytest.raises(ValidationError, match="array contains NaN"): + ModelForbidNan(a=[1.0, np.nan, 3.0]) + + # strict=True and a scalar is provided + class ModelStrictScalar(Tidy3dBaseModel): + a: array_alias(strict=True) + + with pytest.raises(ValidationError, match="strict mode"): + ModelStrictScalar(a=10) + + # input results in an 
array with dtype=object + class ModelObjectDtype(Tidy3dBaseModel): a: ArrayLike - b: constrained_array(ndim=1) - c: tuple[ArrayLike, ...] - c = MyClass(a=[1.0], b=[2.0, 1.0], c=([2.0, 1.0])) - hash(c.json()) + with pytest.raises(ValidationError, match=r"unsupported element type"): + ModelObjectDtype(a=[1, "string", object()]) + + # general conversion failure for an unhandled type + class ModelGeneralConversionFail(Tidy3dBaseModel): + a: ArrayLike + + class UnconvertibleObject: + pass + + with pytest.raises(ValidationError, match="unsupported element type"): + ModelGeneralConversionFail(a=UnconvertibleObject()) + + # _from_complex_dict receives a dict it doesn't understand, passes it to _coerce, + # which then fails because dict becomes an object array or direct conversion fails + class ModelComplexInvalidDict(Tidy3dBaseModel): + a: array_alias(dtype=complex) + + with pytest.raises(ValidationError, match=r"cannot convert"): + ModelComplexInvalidDict(a={"real_part": 1, "imag_part": 2}) + + # scalar_to_1d=True with ndim=1 successfully converts scalar + class ModelScalarTo1DSuccess(Tidy3dBaseModel): + a: array_alias(ndim=1, scalar_to_1d=True) + + obj_s21d = ModelScalarTo1DSuccess(a=5.0) + assert np.array_equal(obj_s21d.a, np.array([5.0])) + assert obj_s21d.a.ndim == 1 + + # scalar_to_1d=True but ndim is incompatible with 1D array (e.g. ndim=2) + class ModelScalarTo1DWrongNdim(Tidy3dBaseModel): + a: array_alias(ndim=2, scalar_to_1d=True, dtype=float) + + with pytest.raises(ValidationError, match="expected"): + ModelScalarTo1DWrongNdim(a=5.0) + + # strict=True takes precedence over scalar_to_1d=True if input is scalar + class ModelStrictAndScalarTo1D(Tidy3dBaseModel): + a: array_alias(strict=True, scalar_to_1d=True, dtype=float) + + with pytest.raises(ValidationError, match="strict mode"): + ModelStrictAndScalarTo1D(a=5.0) + + # allow NaN when forbid_nan=False + class ModelAllowNan(Tidy3dBaseModel): + a: array_alias(dtype=float, forbid_nan=False) + + obj_allow_nan = ModelAllowNan(a=[1.0, np.nan]) + assert np.array_equal(obj_allow_nan.a, np.array([1.0, np.nan]), equal_nan=True) + + # strict=False (default) allows non-array if it can be coerced + class ModelStrictFalseCoercion(Tidy3dBaseModel): + a: array_alias(dtype=int, ndim=1) + + # should pass because [1.0, 2.0] can be coerced to np.array([1,2]) of dtype int, ndim 1 + obj_sf_coerce = ModelStrictFalseCoercion(a=[1.0, 2.0]) + assert np.array_equal(obj_sf_coerce.a, np.array([1, 2])) + assert obj_sf_coerce.a.dtype == np.dtype(int) + assert obj_sf_coerce.a.ndim == 1 + + # scalar_to_1d=False (default), ndim=None, scalar input -> 0D array + class ModelScalarTo0D(Tidy3dBaseModel): + a: array_alias(scalar_to_1d=False) + + obj_s0d = ModelScalarTo0D(a=10) + assert np.array_equal(obj_s0d.a, np.array(10)) + assert obj_s0d.a.ndim == 0 + + # scalar_to_1d=True, ndim=None, scalar input -> 1D array + class ModelScalarTo1DNoNdim(Tidy3dBaseModel): + a: array_alias(scalar_to_1d=True) + + obj_s1d_no_ndim = ModelScalarTo1DNoNdim(a=10) + assert np.array_equal(obj_s1d_no_ndim.a, np.array([10])), obj_s1d_no_ndim.a + assert obj_s1d_no_ndim.a.ndim == 1 + + +def test_complex_type(): + """Tests the Complex type for parsing and serialization.""" + + class ComplexModel(Tidy3dBaseModel): + val: Complex + + inputs = [ + (1 + 2j, 1 + 2j), + ({"real": 3, "imag": -4}, 3 - 4j), + ({"real": 3.5, "imag": 0}, 3.5 + 0j), + (5, 5 + 0j), # int + (6.7, 6.7 + 0j), # float + (True, 1 + 0j), # bool (subclass of int, numbers.Number) + (np.float32(2.5), 2.5 + 0j), # numpy float + 
(np.int64(-3), -3 + 0j), # numpy int + ([10, -2], 10 - 2j), # list of two numbers + ((0.5, 1.5), 0.5 + 1.5j), # tuple of two numbers + ] + + class ObjWithComplexMethod: + def __complex__(self): + return -1 - 1j + + class ObjWithComplexMethodNumeric: + def __init__(self, val): + self._val = val + + def __complex__(self): + return self._val + + inputs.append((ObjWithComplexMethod(), -1 - 1j)) + inputs.append((ObjWithComplexMethodNumeric(3 + 7j), 3 + 7j)) + inputs.append((ObjWithComplexMethodNumeric(5), 5 + 0j)) + + for input_val, expected_complex in inputs: + model = ComplexModel(val=input_val) + assert model.val == expected_complex, f"Input: {input_val}" + assert isinstance(model.val, complex), f"Input: {input_val}" + + expected_json_val = {"real": expected_complex.real, "imag": expected_complex.imag} + assert model.model_dump(mode="json")["val"] == expected_json_val, ( + f"Input for serialization: {input_val}" + ) diff --git a/tests/test_components/test_viz.py b/tests/test_components/test_viz.py index 4a93435b53..3be8f59cfb 100644 --- a/tests/test_components/test_viz.py +++ b/tests/test_components/test_viz.py @@ -4,8 +4,8 @@ import matplotlib as mpl import matplotlib.pyplot as plt -import pydantic.v1 as pd import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d import Box, Medium, Simulation, Structure @@ -29,7 +29,7 @@ def test_0d_plot(center_z, len_collections): sim = td.Simulation( size=(1, 1, 1), - sources=[ + sources=( td.PointDipole( center=(0, 0, center_z), source_time=td.GaussianPulse( @@ -37,8 +37,8 @@ def test_0d_plot(center_z, len_collections): fwidth=td.C_0 / 5.0, ), polarization="Ez", - ) - ], + ), + ), run_time=1e-13, ) @@ -122,9 +122,9 @@ def test_unallowed_colors(): """ Tests validator for visualization spec for colors not recognized by matplotlib. """ - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.VisualizationSpec(facecolor="rr", edgecolor="green", alpha=0.5) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.VisualizationSpec(facecolor="red", edgecolor="gg", alpha=0.5) @@ -132,9 +132,9 @@ def test_unallowed_alpha(): """ Tests validator for disallowed alpha values. 
""" - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=-0.5) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = td.VisualizationSpec(facecolor="red", edgecolor="green", alpha=2.5) @@ -224,7 +224,9 @@ def plot_with_multi_viz_spec(alphas, facecolors, edgecolors, rng, use_viz_spec=T def test_no_matlab_install(monkeypatch): """Test that the `VisualizationSpec` only throws a warning on validation if matplotlib is not installed.""" - monkeypatch.setattr("tidy3d.components.viz.visualization_spec.MATPLOTLIB_IMPORTED", False) + monkeypatch.setattr( + "tidy3d._common.components.viz.visualization_spec.MATPLOTLIB_IMPORTED", False + ) EXPECTED_WARNING_MSG_PIECE = ( "matplotlib was not successfully imported, but is required to validate colors" @@ -247,13 +249,13 @@ def test_plot_from_structure_local(): plot_with_viz_spec(alpha=0.75, facecolor="red", edgecolor="blue") plot_with_viz_spec(alpha=0.75, facecolor="red", edgecolor="blue", use_viz_spec=False) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): plot_with_viz_spec(alpha=0.5, facecolor="dark green", edgecolor="blue") - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): plot_with_viz_spec(alpha=0.5, facecolor="red", edgecolor="ble") - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): plot_with_viz_spec(alpha=1.5, facecolor="red", edgecolor="blue") - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): plot_with_viz_spec(alpha=-0.5, facecolor="red", edgecolor="blue") diff --git a/tests/test_data/test_datasets.py b/tests/test_data/test_datasets.py index b3602ff507..877ecc840e 100644 --- a/tests/test_data/test_datasets.py +++ b/tests/test_data/test_datasets.py @@ -3,9 +3,9 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pd import pytest from matplotlib import pyplot as plt +from pydantic import ValidationError from ..utils import AssertLogLevel, cartesian_to_unstructured @@ -62,7 +62,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): assert tri_grid.name == ds_name # wrong points dimensionality - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): tri_grid_points_bad = td.PointDataArray( np.random.random((4, 3)), coords={"index": np.arange(4), "axis": np.arange(3)}, @@ -117,7 +117,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 2, 3]], coords={"cell_index": np.arange(1), "vertex_index": np.arange(4)}, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( normal_axis=2, normal_pos=-3, @@ -130,7 +130,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 5], [1, 2, 3]], coords={"cell_index": np.arange(2), "vertex_index": np.arange(3)}, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( normal_axis=2, normal_pos=-3, @@ -144,7 +144,7 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): np.random.rand(3, *[len(coord) for coord in extra_dims.values()]), coords=dict(index=np.arange(3), **extra_dims), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( normal_axis=0, normal_pos=0, @@ -289,7 +289,6 @@ def test_triangular_dataset(tmp_path, ds_name, dataset_type_ind, 
no_vtk=False): # writing/reading tri_grid.to_file(tmp_path / "tri_grid_test.hdf5") - tri_grid_loaded = dataset_type.from_file(tmp_path / "tri_grid_test.hdf5") assert tri_grid == tri_grid_loaded @@ -387,7 +386,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): np.random.random((8, 2)), coords={"index": np.arange(8), "axis": np.arange(2)}, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points_bad, cells=tet_grid_cells, @@ -433,7 +432,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 3], [0, 2, 3], [0, 2, 6], [0, 4, 6], [0, 4, 5], [0, 1, 5]], coords={"cell_index": np.arange(6), "vertex_index": np.arange(3)}, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points, cells=tet_grid_cells_bad, @@ -444,7 +443,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): [[0, 1, 3, 17], [0, 2, 3, 7], [0, 2, 6, 7], [0, 4, 6, 7], [0, 4, 5, 7], [0, 1, 5, 7]], coords={"cell_index": np.arange(6), "vertex_index": np.arange(4)}, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points, cells=tet_grid_cells_bad, @@ -456,7 +455,7 @@ def test_tetrahedral_dataset(tmp_path, ds_name, dataset_type_ind, no_vtk=False): np.random.rand(5, *[len(coord) for coord in extra_dims.values()]), coords=dict(index=np.arange(5), **extra_dims), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = dataset_type( points=tet_grid_points, cells=tet_grid_cells_bad, diff --git a/tests/test_data/test_map_data.py b/tests/test_data/test_map_data.py index 12aa2c2156..37f8390eb6 100644 --- a/tests/test_data/test_map_data.py +++ b/tests/test_data/test_map_data.py @@ -2,8 +2,8 @@ import collections.abc -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError from tidy3d import SimulationDataMap @@ -42,7 +42,7 @@ def test_simulation_data_map_creation(): def test_simulation_data_map_invalid_type_raises_error(): """Tests that a ValidationError is raised for incorrect value types.""" invalid_data = {"data_A": "not simulation data"} - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): SimulationDataMap(keys=tuple(invalid_data.keys()), values=tuple(invalid_data.values())) diff --git a/tests/test_data/test_monitor_data.py b/tests/test_data/test_monitor_data.py index dedaa1d494..33113ab07d 100644 --- a/tests/test_data/test_monitor_data.py +++ b/tests/test_data/test_monitor_data.py @@ -4,15 +4,12 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest import xarray as xr +from pydantic import ValidationError import tidy3d as td -from tidy3d.components.data.data_array import ( - FreqDataArray, - FreqModeDataArray, -) +from tidy3d.components.data.data_array import FreqDataArray, FreqModeDataArray from tidy3d.components.data.monitor_data import ( AuxFieldTimeData, DiffractionData, @@ -411,10 +408,10 @@ def test_mode_data_with_fields(): _ = data.updated_copy(eps_spec=["tensorial_real"] * num_freqs) _ = data.updated_copy(eps_spec=["tensorial_complex"] * num_freqs) # wrong keyword - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = data.updated_copy(eps_spec=["tensorial"] * num_freqs) # wrong number - with pytest.raises(pydantic.ValidationError): + with 
pytest.raises(ValidationError): _ = data.updated_copy(eps_spec=["diagonal"] * (num_freqs + 1)) # check monitor direction changes upon time reversal data_reversed = data.time_reversed_copy @@ -717,7 +714,7 @@ def test_field_data_symmetry_present(): _ = td.FieldTimeData(monitor=monitor, **fields) # fails if symmetry specified but missing symmetry center - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.FieldTimeData( monitor=monitor, symmetry=(1, -1, 0), @@ -726,7 +723,7 @@ def test_field_data_symmetry_present(): ) # fails if symmetry specified but missing etended grid - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = td.FieldTimeData( monitor=monitor, symmetry=(1, -1, 1), symmetry_center=(0, 0, 0), **fields ) @@ -977,7 +974,7 @@ def test_no_nans(): eps_dataset_nan = td.PermittivityDataset( **dict.fromkeys(["eps_xx", "eps_yy", "eps_zz"], eps_nan) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): td.CustomMedium(eps_dataset=eps_dataset_nan) diff --git a/tests/test_data/test_sim_data.py b/tests/test_data/test_sim_data.py index 7dc22d5866..9118a80583 100644 --- a/tests/test_data/test_sim_data.py +++ b/tests/test_data/test_sim_data.py @@ -2,11 +2,13 @@ from __future__ import annotations +import warnings + import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest from matplotlib import colors as mcolors +from pydantic import ValidationError import tidy3d as td from tidy3d.components.data.data_array import ScalarFieldTimeDataArray @@ -16,6 +18,7 @@ from tidy3d.components.monitor import FieldMonitor, FieldTimeMonitor, ModeMonitor from tidy3d.exceptions import DataError, SetupError, Tidy3dKeyError +from ..test_components.test_mode import get_mode_sim_data from ..utils import get_nested_shape from .test_data_arrays import FIELD_MONITOR, SIM, SIM_SYM from .test_monitor_data import ( @@ -96,6 +99,59 @@ def make_sim_data(symmetry: bool = True): ) +def make_heat_charge_sim_data(): + """Create a simple HeatChargeSimulationData for testing.""" + temp_mnt = td.TemperatureMonitor(size=(1, 2, 3), name="temperature") + + tet_grid_points = td.PointDataArray( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dims=("index", "axis"), + ) + tet_grid_cells = td.CellDataArray( + [[0, 1, 2, 4], [1, 2, 3, 4]], + dims=("cell_index", "vertex_index"), + ) + tet_grid_values = td.IndexedDataArray( + np.linspace(300, 350, tet_grid_points.shape[0]), + dims=("index",), + name="T", + ) + + tet_grid = td.TetrahedralGridDataset( + points=tet_grid_points, + cells=tet_grid_cells, + values=tet_grid_values, + ) + + temp_data = td.TemperatureData(monitor=temp_mnt, temperature=tet_grid) + + heat_sim = td.HeatChargeSimulation( + size=(3.0, 3.0, 3.0), + structures=[ + td.Structure( + geometry=td.Box(size=(1, 1, 1), center=(0, 0, 0)), + medium=td.Medium( + permittivity=2.0, + heat_spec=td.SolidSpec(conductivity=1, capacity=1), + ), + name="box", + ), + ], + medium=td.Medium(permittivity=3.0, heat_spec=td.FluidSpec()), + grid_spec=td.UniformUnstructuredGrid(dl=0.1), + sources=[td.HeatSource(rate=1, structures=["box"])], + boundary_spec=[ + td.HeatChargeBoundarySpec( + placement=td.StructureBoundary(structure="box"), + condition=td.TemperatureBC(temperature=500), + ) + ], + monitors=[temp_mnt], + ) + + return td.HeatChargeSimulationData(simulation=heat_sim, data=(temp_data,)) + + def test_sim_data(): _ = make_sim_data() @@ -285,7 +341,7 @@ 
def test_field_decay_log_none(): def test_to_dict(): sim_data = make_sim_data() - j = sim_data.dict() + j = sim_data.model_dump() sim_data2 = SimulationData(**j) assert sim_data == sim_data2 @@ -296,7 +352,7 @@ def test_to_json(tmp_path): sim_data.to_file(fname=FNAME) # saving to json does not store data, so trying to load from file will trigger custom error. - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = SimulationData.from_file(fname=FNAME) @@ -485,9 +541,9 @@ def test_plot_field_title(): def test_missing_monitor(): sim_data = make_sim_data() - new_monitors = list(sim_data.simulation.monitors)[:-1] + new_monitors = tuple(sim_data.simulation.monitors)[:-1] new_sim = sim_data.simulation.copy(update={"monitors": new_monitors}) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = sim_data.copy(update={"simulation": new_sim}) @@ -504,7 +560,7 @@ def test_replace_values_dict(): """ # Create data for test test_data = make_sim_data() - test_data_dict = test_data.dict() + test_data_dict = test_data.model_dump() test_data_dict["none_test"] = None # Check that replace works at top level of nested dict # Get the original shape of nested dict @@ -542,11 +598,16 @@ def test_replace_values_list(): assert original_shape == new_shape and new_list[3] == [] and new_list[2][5][0] == [] -def test_to_mat_file(tmp_path): +@pytest.mark.parametrize( + "make_sim_data_fn", + [make_sim_data, get_mode_sim_data, make_heat_charge_sim_data], + ids=["SimulationData", "ModeSimulationData", "HeatChargeSimulationData"], +) +def test_to_mat_file(tmp_path, make_sim_data_fn): """ - Test output of ``.mat`` file completes without error. + Test output of ``.mat`` file completes without error for all simulation data types. 
""" - sim_data = make_sim_data() + sim_data = make_sim_data_fn() path = str(tmp_path / "test.mat") sim_data.to_mat_file(path) @@ -565,3 +626,49 @@ def test_plot_field_monitor_data_unsupported_scale(): val="real", scale="invalid", ) + + +def test_plot_field_with_zeros_db_scale(): + """Test that plotting field data with zeros using dB scale doesn't produce warnings.""" + sim_data = make_sim_data() + + # Get the existing field data and modify it to include zeros + field_monitor_data = sim_data["field"] + ex_data = field_monitor_data.Ex + + # Create modified data with some zeros + values = ex_data.values.copy() + values[np.abs(values) < np.abs(values).max() * 0.3] = 0 # Set small values to zero + ex_with_zeros = ex_data.copy(data=values) + + # Create new field data with zeros + field_data_with_zeros = field_monitor_data.updated_copy(Ex=ex_with_zeros) + f_sel = ex_data.f.values[0] + x_sel = ex_data.x.values[0] + + # Plot with dB scale - this should not produce divide by zero warnings + with warnings.catch_warnings(): + warnings.filterwarnings("error", message="divide by zero") + sim_data.plot_field_monitor_data( + field_monitor_data=field_data_with_zeros, + field_name="Ex", + val="abs", + scale="dB", + f=f_sel, + x=x_sel, + ) + plt.close() + + # Also test with vmin specified (zeros replaced with floor value) + with warnings.catch_warnings(): + warnings.filterwarnings("error", message="divide by zero") + sim_data.plot_field_monitor_data( + field_monitor_data=field_data_with_zeros, + field_name="Ex", + val="abs", + scale="dB", + f=f_sel, + x=x_sel, + vmin=-50, + ) + plt.close() diff --git a/tests/test_material_library/test_material_library.py b/tests/test_material_library/test_material_library.py index 681bf86c29..942f290368 100644 --- a/tests/test_material_library/test_material_library.py +++ b/tests/test_material_library/test_material_library.py @@ -71,12 +71,8 @@ def test_medium_repr(): repr_noname_medium = test_media[0].__repr__() str_noname_medium_dict = str(noname_medium_in_dict) - assert "type='Medium' permittivity=2.25 conductivity=0.0" in str_noname_medium, ( - "Expected medium information in string" - ) - assert "Medium(attrs={}, name=None, frequency_range=None" in repr_noname_medium, ( - "Expcted medium information in repr" - ) + assert "name=None," in str_noname_medium, "Expected medium information in string" + assert "name=None" in repr_noname_medium, "Expected medium information in repr" assert repr_noname_medium in str_noname_medium_dict, "Expected repr in dictionary string" for medium in test_media: diff --git a/tests/test_package/test_log.py b/tests/test_package/test_log.py index 0ea792838e..cdc068cb1f 100644 --- a/tests/test_package/test_log.py +++ b/tests/test_package/test_log.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from pydantic import ValidationError +from pydantic import ValidationError, model_validator import tidy3d as td from tidy3d.exceptions import Tidy3dError @@ -66,6 +66,7 @@ def test_logging_warning_capture(): # create sim with warnings domain_size = 12 + td.log.set_capture(True) wavelength = 1 f0 = td.C_0 / wavelength fwidth = f0 / 10.0 @@ -204,16 +205,15 @@ def test_logging_warning_capture(): ) # parse the entire simulation at once to capture warnings hierarchically - sim_dict = sim.dict() + sim_dict = sim.model_dump() # re-add projection monitors because it has been overwritten in validators (far_field_approx=False -> True) monitors = list(sim_dict["monitors"]) - monitors[2] = proj_mnt.dict() + monitors[2] = proj_mnt.model_dump() sim_dict["monitors"] = 
monitors - td.log.set_capture(True) - sim = td.Simulation.parse_obj(sim_dict) + sim = td.Simulation.model_validate(sim_dict) print(sim.monitors_data_size) sim.validate_pre_upload() warning_list = td.log.captured_warnings() @@ -224,18 +224,18 @@ def test_logging_warning_capture(): # check that capture doesn't change validation errors - # validation error during parse_obj() - sim_dict_no_source = sim.dict() + # validation error during model_validate() + sim_dict_no_source = sim.model_dump() sim_dict_no_source.update({"sources": []}) # validation error during validate_pre_upload() - sim_dict_large_mnt = sim.dict() + sim_dict_large_mnt = sim.model_dump() sim_dict_large_mnt.update({"monitors": [monitor_time.updated_copy(size=(10, 10, 10))]}) # for sim_dict in [sim_dict_no_source, sim_dict_large_mnt]: for sim_dict in [sim_dict_no_source]: try: - sim = td.Simulation.parse_obj(sim_dict) + sim = td.Simulation.model_validate(sim_dict) sim.validate_pre_upload() except ValidationError as e: error_without = e.errors() @@ -244,7 +244,7 @@ def test_logging_warning_capture(): td.log.set_capture(True) try: - sim = td.Simulation.parse_obj(sim_dict) + sim = td.Simulation.model_validate(sim_dict) sim.validate_pre_upload() except ValidationError as e: error_with = e.errors() @@ -252,10 +252,36 @@ def test_logging_warning_capture(): error_with = str(e) td.log.set_capture(False) - print(error_without) - print(error_with) + assert str(error_without) == str(error_with) + + +def test_warning_capture_during_model_validation(): + from tidy3d.components.base import Tidy3dBaseModel + from tidy3d.log import log - assert error_without == error_with + class _CaptureChild(Tidy3dBaseModel): + x: int + + @model_validator(mode="after") + def _warn_child(self): + log.warning("child warning") + return self + + class _CaptureParent(Tidy3dBaseModel): + child: _CaptureChild + + @model_validator(mode="after") + def _warn_parent(self): + log.warning("parent warning") + return self + + td.log.set_capture(True) + _CaptureParent(child={"x": 1}) + warning_list = td.log.captured_warnings() + td.log.set_capture(False) + + assert {"loc": [], "msg": "parent warning"} in warning_list + assert {"loc": ["child"], "msg": "child warning"} in warning_list def test_log_suppression(): @@ -275,6 +301,59 @@ def test_log_suppression(): td.config.log_suppression = True +def test_warn_once(): + """Test that warn_once setting causes each unique warning to only be shown once.""" + + # Clear the static cache to ensure clean test state + td.log._static_cache.clear() + + # By default, warn_once should be False + assert td.log.warn_once is False + assert td.config.logging.warn_once is False + + # Enable warn_once via config + td.config.logging.warn_once = True + assert td.log.warn_once is True + + # First warning should go through + initial_cache_size = len(td.log._static_cache) + td.log.warning("unique_test_warning_message_1234") + assert len(td.log._static_cache) == initial_cache_size + 1 + assert "unique_test_warning_message_1234" in td.log._static_cache + + # Same warning should be skipped (cache doesn't grow) + td.log.warning("unique_test_warning_message_1234") + assert len(td.log._static_cache) == initial_cache_size + 1 + + # Different warning should go through + td.log.warning("different_warning_message_5678") + assert len(td.log._static_cache) == initial_cache_size + 2 + assert "different_warning_message_5678" in td.log._static_cache + + # Info messages should NOT be affected by warn_once + td.log.info("info_message_should_not_cache") + 
td.log.info("info_message_should_not_cache") + # Info messages don't use the cache when warn_once is enabled (only warnings do) + assert "info_message_should_not_cache" not in td.log._static_cache + + # Error messages should NOT be affected by warn_once + td.log.error("error_message_should_not_cache") + td.log.error("error_message_should_not_cache") + assert "error_message_should_not_cache" not in td.log._static_cache + + # Critical messages should NOT be affected by warn_once + td.log.critical("critical_message_should_not_cache") + td.log.critical("critical_message_should_not_cache") + assert "critical_message_should_not_cache" not in td.log._static_cache + + # Disable warn_once + td.config.logging.warn_once = False + assert td.log.warn_once is False + + # Clear cache for cleanup + td.log._static_cache.clear() + + def test_assert_log_level(): """Test features of the assert_log_level""" diff --git a/tests/test_package/test_material_library.py b/tests/test_package/test_material_library.py index 1b988c43f4..60bacacd4b 100644 --- a/tests/test_package/test_material_library.py +++ b/tests/test_package/test_material_library.py @@ -1,8 +1,8 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.components.material.multi_physics import MultiPhysicsMedium @@ -55,7 +55,7 @@ def test_MaterialItem(): ) assert material["v1"] == material.medium - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): material = MaterialItem( name="material", variants={"v1": variant1, "v2": variant2}, default="v3" ) diff --git a/tests/test_package/test_parametric_variants.py b/tests/test_package/test_parametric_variants.py index ddd0d94bbc..f514ab7c3a 100644 --- a/tests/test_package/test_parametric_variants.py +++ b/tests/test_package/test_parametric_variants.py @@ -32,7 +32,8 @@ def test_graphene_defaults(): _ = graphene.numerical_conductivity(freqs) -@pytest.mark.parametrize("rng_seed", np.arange(0, 15)) +@pytest.mark.parametrize("rng_seed", np.arange(0, 8)) +@pytest.mark.slow def test_graphene(rng_seed): """test graphene for range of physical parameters""" rng = default_rng(rng_seed) diff --git a/tests/test_plugins/autograd/invdes/test_parametrizations.py b/tests/test_plugins/autograd/invdes/test_parametrizations.py index 1a07a73c9f..106adb86f9 100644 --- a/tests/test_plugins/autograd/invdes/test_parametrizations.py +++ b/tests/test_plugins/autograd/invdes/test_parametrizations.py @@ -10,20 +10,29 @@ @pytest.mark.parametrize("radius", [1, 2, (1, 2)]) @pytest.mark.parametrize("dl", [0.1, 0.2, (0.1, 0.2)]) @pytest.mark.parametrize("size_px", [None, 5, (5, 7)]) +@pytest.mark.parametrize("beta", [1.0, 10.0]) @pytest.mark.parametrize("filter_type", ["circular", "conic", "gaussian"]) @pytest.mark.parametrize("padding", PaddingType.__args__) -def test_make_filter_and_project(rng, radius, dl, size_px, filter_type, padding): +@pytest.mark.parametrize("init_type", ("random", "binary_low_border", "binary_high_border")) +def test_make_filter_and_project(rng, radius, dl, size_px, beta, filter_type, padding, init_type): """Test make_filter_and_project function for various parameters.""" filter_and_project_func = make_filter_and_project( radius=radius, dl=dl, size_px=size_px, - beta=10, + beta=beta, eta=0.5, filter_type=filter_type, padding=padding, ) - array = rng.random((51, 51)) + if init_type == "random": + array = rng.random((51, 51)) + elif init_type == "binary_low_border": + array 
= np.zeros((100, 100)) + array[40:60, 40:60] = 1.0 + else: + array = np.ones((100, 100)) + array[40:60, 40:60] = 0.0 result = filter_and_project_func(array) assert result.shape == array.shape assert np.all(result >= 0) and np.all(result <= 1) diff --git a/tests/test_plugins/autograd/test_functions.py b/tests/test_plugins/autograd/test_functions.py index 432a36db0f..534b7528b2 100644 --- a/tests/test_plugins/autograd/test_functions.py +++ b/tests/test_plugins/autograd/test_functions.py @@ -9,6 +9,7 @@ from autograd.test_util import check_grads from scipy.signal import convolve as convolve_sp +from tidy3d.compat import np_trapezoid from tidy3d.plugins.autograd import ( add_at, convolve, @@ -552,7 +553,7 @@ def test_trapz_val(self, rng, shape, axis, use_x): """Test trapz values against NumPy for different array dimensions and integration axes.""" y, x, dx = self.generate_y_x_dx(rng, shape, use_x) result_custom = trapz(y, x=x, dx=dx, axis=axis) - result_numpy = np.trapz(y, x=x, dx=dx, axis=axis) + result_numpy = np_trapezoid(y, x=x, dx=dx, axis=axis) npt.assert_allclose(result_custom, result_numpy) def test_trapz_grad(self, rng, shape, axis, use_x): diff --git a/tests/test_plugins/expressions/test_dispatch.py b/tests/test_plugins/expressions/test_dispatch.py index ee7fbec31c..1ffe233698 100644 --- a/tests/test_plugins/expressions/test_dispatch.py +++ b/tests/test_plugins/expressions/test_dispatch.py @@ -8,7 +8,7 @@ def test_expression_parse_obj_round_trip(): expr = Constant(3.14) - parsed = Expression.parse_obj(expr.dict()) + parsed = Expression.model_validate(expr.dict()) assert isinstance(parsed, Constant) assert parsed.value == pytest.approx(3.14) @@ -16,4 +16,4 @@ def test_expression_parse_obj_round_trip(): def test_expression_parse_obj_rejects_unrelated_types(): # Simulation registers a distinct type in the global map; parsing via Expression should fail. 
with pytest.raises(ValueError, match="Cannot parse type"): - Expression.parse_obj({"type": "Simulation"}) + Expression.model_validate({"type": "Simulation"}) diff --git a/tests/test_plugins/expressions/test_variables.py b/tests/test_plugins/expressions/test_variables.py index da74b03dc0..d50ce5a2a9 100644 --- a/tests/test_plugins/expressions/test_variables.py +++ b/tests/test_plugins/expressions/test_variables.py @@ -20,7 +20,7 @@ def test_constant_evaluate(value): def test_constant_type(value): constant = Constant(value) result = constant.evaluate() - assert isinstance(result, type(value)) + assert isinstance(result, type(value)), f"{type(value)}, {type(result)}" def test_variable_evaluate_positional(value): diff --git a/tests/test_plugins/klayout/drc/test_drc.py b/tests/test_plugins/klayout/drc/test_drc.py index f33a813eb8..2a2d50e329 100644 --- a/tests/test_plugins/klayout/drc/test_drc.py +++ b/tests/test_plugins/klayout/drc/test_drc.py @@ -6,8 +6,8 @@ import xml.etree.ElementTree as ET from pathlib import Path -import pydantic.v1 as pd import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d.exceptions import FileError @@ -174,7 +174,7 @@ def test_drc_config_args_require_mapping(tmp_path): """drc_args must be a mapping and refuses other iterables.""" kwargs = _basic_drc_config_kwargs(tmp_path) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): DRCConfig(**kwargs, drc_args=["not", "a", "mapping"]) @@ -182,7 +182,7 @@ def test_drc_config_args_reject_reserved_keys(tmp_path): """Reserved keys such as gdsfile cannot be overridden via drc_args.""" kwargs = _basic_drc_config_kwargs(tmp_path) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): DRCConfig(**kwargs, drc_args={"gdsfile": "custom.gds"}) @@ -205,7 +205,7 @@ def __str__(self): kwargs = _basic_drc_config_kwargs(tmp_path) with pytest.raises( - pd.ValidationError, match="Could not coerce keys and values of drc_args to strings." + ValidationError, match="Could not coerce keys and values of drc_args to strings." 
): DRCConfig(**kwargs, drc_args={"bad": Unstringifiable()}) @@ -425,7 +425,7 @@ def test_check_drcfile_format_invalid( if drc_file_suffix == ".lydrc": drc_content = TestDRCRunner.wrap_drc_to_lydrc(drc_content) self.write_drcrunset(tmp_path, f"bad_drcrunset{drc_file_suffix}", drc_content) - with pytest.raises(pd.ValidationError) as e: + with pytest.raises(ValidationError) as e: self.run( monkeypatch=monkeypatch, drc_runsetfile=tmp_path / f"bad_drcrunset{drc_file_suffix}", @@ -438,7 +438,7 @@ def test_check_drcfile_format_invalid( def test_check_gdsfile_exists(self, monkeypatch, tmp_path, good_drcrunset_content): """Test gdsfile existence checking works""" self.write_drcrunset(tmp_path, "good_drcfile.drc", good_drcrunset_content) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): self.run( monkeypatch=monkeypatch, drc_runsetfile=tmp_path / "good_drcfile.drc", @@ -454,7 +454,7 @@ def test_check_gdsfile_filetype( """Test gdsfile filetype checking works""" self.write_drcrunset(tmp_path, "good_drcfile.drc", good_drcrunset_content) geom.to_gds_file(tmp_path / "test.g2ds", **geom_to_gds_kwargs) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): self.run( monkeypatch=monkeypatch, drc_runsetfile=tmp_path / "good_drcfile.drc", @@ -467,7 +467,7 @@ def test_check_gdsfile_filetype( def test_check_designrulefile_exists(self, monkeypatch, tmp_path, geom, geom_to_gds_kwargs): """Test design rule file existence checking works""" geom.to_gds_file(tmp_path / "test.gds", **geom_to_gds_kwargs) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): self.run( monkeypatch=monkeypatch, drc_runsetfile=tmp_path / "not_a_drc_file.drc", @@ -483,7 +483,7 @@ def test_check_designrulefile_filetype( """Test design rule file filetype checking works""" geom.to_gds_file(tmp_path / "test.gds", **geom_to_gds_kwargs) self.write_drcrunset(tmp_path, "good_drcfile.drc2", good_drcrunset_content) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): self.run( monkeypatch=monkeypatch, drc_runsetfile=tmp_path / "good_drcfile.drc2", diff --git a/tests/test_plugins/smatrix/test_component_modeler.py b/tests/test_plugins/smatrix/test_component_modeler.py index 830955a31c..74c021782e 100644 --- a/tests/test_plugins/smatrix/test_component_modeler.py +++ b/tests/test_plugins/smatrix/test_component_modeler.py @@ -3,8 +3,8 @@ import gdstk import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td from tidy3d import SimulationDataMap @@ -127,7 +127,7 @@ def offset(u): # in-plane field monitor (optional, increases required data storage) domain_monitor = td.FieldMonitor( - center=[0, 0, wg_height / 2], size=[td.inf, td.inf, 0], freqs=freqs, name="field" + center=(0, 0, wg_height / 2), size=(td.inf, td.inf, 0), freqs=freqs, name="field" ) # initialize the simulation @@ -213,13 +213,13 @@ def test_validate_no_sources(): source_time=td.GaussianPulse(freq0=2e14, fwidth=1e14), polarization="Ex" ) sim_w_source = modeler.simulation.copy(update={"sources": (source,)}) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = modeler.copy(update={"simulation": sim_w_source}) def test_element_mappings_none(): modeler = make_component_modeler() - modeler = modeler.updated_copy(ports=[], element_mappings=()) + modeler = modeler.updated_copy(ports=(), element_mappings=()) _ = modeler.matrix_indices_run_sim @@ -301,7 +301,7 
@@ def test_component_modeler_run_only(monkeypatch): _ = make_coupler() _ = make_ports() ONLY_SOURCE = (port_run_only, mode_index_run_only) = ("right_bot", 0) - run_only = [ONLY_SOURCE] + run_only = (ONLY_SOURCE,) modeler = make_component_modeler(run_only=run_only) modeler_data = run_component_modeler(monkeypatch, modeler=modeler) s_matrix = modeler_data.smatrix() @@ -380,14 +380,14 @@ def test_mapping_exclusion(monkeypatch): mapping = ((("right_bot", 1), ("right_bot", 1)), (EXCLUDE_INDEX, EXCLUDE_INDEX), +1) element_mappings.append(mapping) - modeler = make_component_modeler(element_mappings=element_mappings) + modeler = make_component_modeler(element_mappings=tuple(element_mappings)) modeler_data = run_component_modeler(monkeypatch, modeler=modeler) s_matrix = modeler_data.smatrix() run_sim_indices = modeler.matrix_indices_run_sim assert EXCLUDE_INDEX not in run_sim_indices, "mapping didnt exclude row properly" - _test_mappings(element_mappings, s_matrix) + _test_mappings(tuple(element_mappings), s_matrix) def test_mapping_with_run_only(): @@ -416,7 +416,7 @@ def test_mapping_with_run_only(): _ = make_component_modeler(element_mappings=element_mappings, run_only=run_only) run_only.remove(EXCLUDE_INDEX) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): _ = make_component_modeler(element_mappings=element_mappings, run_only=run_only) @@ -482,7 +482,7 @@ def test_validate_run_only_uniqueness_modal(): port1_idx = (modeler.ports[1].name, 0) # Test with duplicate entries - should raise ValidationError - with pytest.raises(pydantic.ValidationError, match="duplicate entries"): + with pytest.raises(ValidationError, match="duplicate entries"): modeler.updated_copy(run_only=(port0_idx, port0_idx, port1_idx)) @@ -491,11 +491,11 @@ def test_validate_run_only_membership_modal(): modeler = make_component_modeler() # Test with invalid port name - with pytest.raises(pydantic.ValidationError, match="not present in"): + with pytest.raises(ValidationError, match="not present in"): modeler.updated_copy(run_only=(("invalid_port", 0),)) # Test with invalid mode index port0_name = modeler.ports[0].name invalid_mode = modeler.ports[0].mode_spec.num_modes + 1 - with pytest.raises(pydantic.ValidationError, match="not present in"): + with pytest.raises(ValidationError, match="not present in"): modeler.updated_copy(run_only=((port0_name, invalid_mode),)) diff --git a/tests/test_plugins/smatrix/test_run_functions.py b/tests/test_plugins/smatrix/test_run_functions.py index a4207165e5..494764e318 100644 --- a/tests/test_plugins/smatrix/test_run_functions.py +++ b/tests/test_plugins/smatrix/test_run_functions.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock -import pydantic.v1 as pd import pytest +from pydantic import ValidationError import tidy3d from tests.test_plugins.smatrix.terminal_component_modeler_def import ( @@ -109,7 +109,7 @@ def test_compose_modeler_data_keys_mismatch(): ), ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): TerminalComponentModelerData( modeler=make_terminal_component_modeler(planar_pec=True), data=dummy_sim_data_map ) diff --git a/tests/test_plugins/smatrix/test_terminal_component_modeler.py b/tests/test_plugins/smatrix/test_terminal_component_modeler.py index 4bb4649cee..95bf5d92f1 100644 --- a/tests/test_plugins/smatrix/test_terminal_component_modeler.py +++ b/tests/test_plugins/smatrix/test_terminal_component_modeler.py @@ -2,10 +2,10 @@ import matplotlib.pyplot as plt import numpy as np -import pydantic.v1 as pd 
import pytest import skrf import xarray as xr +from pydantic import ValidationError import tidy3d as td import tidy3d.plugins.smatrix.analysis.terminal @@ -14,7 +14,7 @@ from tidy3d import SimulationDataMap from tidy3d.components.boundary import BroadbandModeABCSpec from tidy3d.components.data.data_array import FreqDataArray -from tidy3d.exceptions import SetupError, Tidy3dError, Tidy3dKeyError, ValidationError +from tidy3d.exceptions import SetupError, Tidy3dError, Tidy3dKeyError from tidy3d.plugins.smatrix import ( CoaxialLumpedPort, LumpedPort, @@ -254,7 +254,7 @@ def test_validate_no_sources(tmp_path): source_time=td.GaussianPulse(freq0=2e14, fwidth=1e14), polarization="Ex" ) sim_w_source = modeler.simulation.copy(update={"sources": (source,)}) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = modeler.copy(update={"simulation": sim_w_source}) @@ -267,14 +267,21 @@ def test_validate_freqs(): _ = modeler._source_time # Negative frequencies are not allowed freqs = np.array([-1.0, 5]) * 1e9 - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): + _ = modeler.updated_copy(freqs=freqs) + freqs = np.array([-1.0, 5]) + with pytest.raises(ValidationError): _ = modeler.updated_copy(freqs=freqs) + freqs = np.array([1, 2, 1.9]) + with pytest.raises(ValidationError): + _ = modeler.updated_copy(freqs=freqs) + # Test case with non-unique value f_min, f_max = (0.5e9, 1.5e9) f0 = (f_min + f_max) / 2 f_target = 1.35e9 freqs = np.sort(np.append(np.linspace(f_min, f_max, 21), f_target)) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = modeler.updated_copy(freqs=freqs) @@ -292,7 +299,7 @@ def test_validate_3D_sim(tmp_path): ), run_time=1e-10, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = modeler.updated_copy(simulation=sim) @@ -558,7 +565,7 @@ def test_coarse_grid_at_port(monkeypatch): def test_validate_port_voltage_axis(): - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): LumpedPort(center=(0, 0, 0), size=(0, 1, 2), voltage_axis=0, impedance=50) @@ -625,11 +632,11 @@ def test_lumped_port_from_structures(): # test port width with lateral coords lp_options["lateral_coord"] = None - with pytest.raises(ValidationError): + with pytest.raises(ValueError): LP5 = LumpedPort.from_structures(x=-WL / 2 - LL1, name="LP5", **lp_options) lp_options["lateral_coord"] = 10 * WL - with pytest.raises(ValidationError): + with pytest.raises(ValueError): LP6 = LumpedPort.from_structures(x=-WL / 2 - LL1, name="LP6", **lp_options) # ensure that validation error is raised when specified port width exceeds terminal overlap in lateral direction. 
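Review note: the `from_structures` assertions in the hunks above now expect a bare `ValueError` instead of a `ValidationError`. Presumably this is because those checks run inside a factory classmethod before any model is constructed, so pydantic never wraps the error. A minimal sketch with a hypothetical `Port` model (not tidy3d's actual class) under that assumption:

    from pydantic import BaseModel, ValidationError, field_validator

    class Port(BaseModel):
        width: float

        @field_validator("width")
        @classmethod
        def _check_width(cls, val: float) -> float:
            if val <= 0:
                raise ValueError("width must be positive")
            return val

        @classmethod
        def from_overlap(cls, overlap: float) -> "Port":
            # raised before pydantic validation runs -> stays a plain ValueError
            if overlap <= 0:
                raise ValueError("no terminal overlap")
            return cls(width=overlap)

    # Port(width=-1.0)        -> pydantic.ValidationError (wrapped by the model)
    # Port.from_overlap(-1.0) -> ValueError (never reaches the model)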
@@ -642,7 +649,7 @@ def test_lumped_port_from_structures(): str_gnd_new = str_gnd.updated_copy(medium=td.Medium(conductivity=1e2)) lp_options["lateral_coord"] = -2 * LL2 - WL / 2 lp_options["ground_terminal"] = str_gnd_new - with pytest.raises(ValidationError): + with pytest.raises(ValueError): LP7 = LumpedPort.from_structures(x=-WL / 2 - LL1, name="LP7", **lp_options) @@ -704,7 +711,7 @@ def test_coarse_grid_at_coaxial_port(monkeypatch, tmp_path, grid_spec): def test_validate_coaxial_center_not_inf(): - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): CoaxialLumpedPort( center=(td.inf, 0, 0), outer_diameter=8, @@ -718,7 +725,7 @@ def test_validate_coaxial_center_not_inf(): def test_validate_coaxial_port_diameters(): - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): CoaxialLumpedPort( center=(0, 0, 0), outer_diameter=1, @@ -839,7 +846,7 @@ def test_run_coaxial_component_modeler_with_wave_ports( xy_grid = td.UniformGrid(dl=0.1 * 1e3) grid_spec = td.GridSpec(grid_x=xy_grid, grid_y=xy_grid, grid_z=z_grid) if not (voltage_enabled or current_enabled): - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): modeler = make_coaxial_component_modeler( port_types=(WavePort, WavePort), grid_spec=grid_spec, @@ -966,7 +973,7 @@ def test_wave_port_path_integral_validation(): direction="+", ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): mw_mode_spec = td.MicrowaveModeSpec( num_modes=1, target_neff=1.8, @@ -981,7 +988,7 @@ def test_wave_port_path_integral_validation(): ) voltage_path = voltage_path.updated_copy(size=(4, 0, 0)) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): mode_spec = td.MicrowaveModeSpec( num_modes=1, target_neff=1.8, @@ -999,7 +1006,7 @@ def test_wave_port_path_integral_validation(): center=center_port, radius=3, num_points=21, normal_axis=2, clockwise=False ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): mode_spec = td.MicrowaveModeSpec( num_modes=1, target_neff=1.8, @@ -1080,7 +1087,7 @@ def test_wave_port_grid_validation(tmp_path): num_grid_cells=None, ) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = WavePort( center=center_port, size=size_port, @@ -1134,7 +1141,7 @@ def test_port_source_snapped_to_PML(tmp_path): mode_spec=mw_mode_spec, direction="-", ) - modeler = modeler.updated_copy(ports=[port]) + modeler = modeler.updated_copy(ports=(port,)) # Error because port is snapped to PML layers; but the error message might not # be very informative, e.g. "simulation.sources[0]' is outside of the simulation domain". 
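Review note: the repeated `ports=[...]` -> `ports=(...,)` and `radiation_monitors=[...]` -> `(...,)` changes in this file follow one pattern: tuple-typed fields now receive tuples directly. A minimal sketch of the behavior this guards against, assuming a strict pydantic v2 config (the strict/frozen settings below are an illustration, not taken from tidy3d's actual base model):

    from pydantic import BaseModel, ConfigDict, ValidationError

    class Modeler(BaseModel):
        model_config = ConfigDict(strict=True, frozen=True)
        ports: tuple[str, ...]

    Modeler(ports=("P1", "P2"))      # ok
    try:
        Modeler(ports=["P1", "P2"])  # strict mode: a list is not a tuple
    except ValidationError:
        pass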
@@ -1158,7 +1165,7 @@ def test_port_source_snapped_to_PML(tmp_path): mode_spec=mw_mode_spec, direction="+", ) - modeler = modeler.updated_copy(ports=[port]) + modeler = modeler.updated_copy(ports=(port,)) with pytest.raises(SetupError): modeler.sim_dict @@ -1195,7 +1202,9 @@ def test_antenna_helpers(monkeypatch, tmp_path): theta=theta, phi=phi, ) - modeler: TerminalComponentModeler = modeler.updated_copy(radiation_monitors=[radiation_monitor]) + modeler: TerminalComponentModeler = modeler.updated_copy( + radiation_monitors=(radiation_monitor,) + ) # Run simulation to get data modeler_data = run_component_modeler(monkeypatch, modeler) @@ -1258,11 +1267,11 @@ def test_antenna_parameters(monkeypatch, port_type): theta=theta, phi=phi, ) - with pytest.raises(pd.ValidationError): - modeler = modeler.updated_copy(radiation_monitors=[radiation_monitor]) + with pytest.raises(ValidationError): + modeler = modeler.updated_copy(radiation_monitors=(radiation_monitor,)) radiation_monitor = radiation_monitor.updated_copy(freqs=modeler.freqs) - modeler = modeler.updated_copy(radiation_monitors=[radiation_monitor]) + modeler = modeler.updated_copy(radiation_monitors=(radiation_monitor,)) # Run simulation and get antenna parameters modeler_data = run_component_modeler(monkeypatch, modeler) @@ -1324,7 +1333,7 @@ def test_get_combined_antenna_parameters_data(monkeypatch, tmp_path): theta=theta, phi=phi, ) - modeler = modeler.updated_copy(radiation_monitors=[radiation_monitor]) + modeler = modeler.updated_copy(radiation_monitors=(radiation_monitor,)) modeler_data = run_component_modeler(monkeypatch=monkeypatch, modeler=modeler) # Define port amplitudes @@ -1636,10 +1645,10 @@ def test_low_freq_smoothing_spec_validation_order_bounds(): ModelerLowFrequencySmoothingSpec(order=3) # Test invalid orders - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): ModelerLowFrequencySmoothingSpec(order=-1) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): ModelerLowFrequencySmoothingSpec(order=4) @@ -1652,7 +1661,7 @@ def test_low_freq_smoothing_spec_validation_max_deviation_bounds(): ModelerLowFrequencySmoothingSpec(max_deviation=1.0) # Test invalid max_deviation - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): ModelerLowFrequencySmoothingSpec(max_deviation=-0.1) @@ -1775,7 +1784,7 @@ def test_wave_port_extrusion_coaxial(): port_1 = port_1.updated_copy(center=(0, 0, -50000), extrude_structures=True) # test that structure extrusion requires an internal absorber (should raise ValidationError) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = port_2.updated_copy(center=(0, 0, 50000), extrude_structures=True, absorber=False) # define a valid waveport @@ -1853,7 +1862,7 @@ def test_wave_port_extrusion_differential_stripline(): port_1 = port_1.updated_copy(extrude_structures=True) # test that structure extrusion requires an internal absorber (should raise ValidationError) - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): _ = port_2.updated_copy(extrude_structures=True, absorber=False) # define a valid waveport @@ -1965,7 +1974,7 @@ def test_validate_run_only_uniqueness(): port1_idx = modeler.network_index(modeler.ports[1]) # Test with duplicate entries - should raise ValidationError - with pytest.raises(pd.ValidationError, match="duplicate entries"): + with pytest.raises(ValidationError, match="duplicate entries"): modeler.updated_copy(run_only=(port0_idx, port0_idx, 
port1_idx)) @@ -1974,12 +1983,12 @@ def test_validate_run_only_membership(): modeler = make_component_modeler(planar_pec=True) # Test with invalid index - should raise ValidationError - with pytest.raises(pd.ValidationError, match="not present in"): + with pytest.raises(ValidationError, match="not present in"): modeler.updated_copy(run_only=("invalid_port_name",)) # Test with partially invalid indices port0_idx = modeler.network_index(modeler.ports[0]) - with pytest.raises(pd.ValidationError, match="not present in"): + with pytest.raises(ValidationError, match="not present in"): modeler.updated_copy(run_only=(port0_idx, "invalid_port")) @@ -2002,7 +2011,7 @@ def test_validate_run_only_with_wave_ports(): assert modeler_updated.run_only == (port0_idx,) # Invalid case - with pytest.raises(pd.ValidationError, match="not present in"): + with pytest.raises(ValidationError, match="not present in"): modeler.updated_copy(run_only=("nonexistent_wave_port",)) @@ -2144,7 +2153,7 @@ def test_wave_port_mode_index_validation(): assert port._mode_indices == (0,) # Invalid: index greater than number of modes - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): WavePort( center=(0, 0, -10), size=(0, 2, 2), @@ -2177,7 +2186,7 @@ def test_wave_port_mode_index_validation(): assert port._mode_indices == (0, 1, 2) # Invalid: negative index - with pytest.raises(pd.ValidationError, match="non-negative"): + with pytest.raises(ValidationError, match="non-negative"): WavePort( center=(0, 0, -10), size=(0, 2, 2), @@ -2188,7 +2197,7 @@ def test_wave_port_mode_index_validation(): ) # Invalid: index >= num_modes - with pytest.raises(pd.ValidationError, match="mode_spec.num_modes"): + with pytest.raises(ValidationError, match="mode_spec.num_modes"): WavePort( center=(0, 0, -10), size=(0, 2, 2), @@ -2199,7 +2208,7 @@ def test_wave_port_mode_index_validation(): ) # Invalid: duplicate indices - with pytest.raises(pd.ValidationError, match="duplicate"): + with pytest.raises(ValidationError, match="duplicate"): WavePort( center=(0, 0, -10), size=(0, 2, 2), @@ -2272,7 +2281,7 @@ def test_get_task_name(): """Test get_task_name with RF ports.""" # First make sure ports cannot have @ in their name - with pytest.raises(pd.ValidationError): + with pytest.raises(ValidationError): lumped_port = LumpedPort( center=(0, 0, 0), size=(1, 0, 0.5), diff --git a/tests/test_plugins/test_array_factor.py b/tests/test_plugins/test_array_factor.py index 959e1dc7b0..93de0a4782 100644 --- a/tests/test_plugins/test_array_factor.py +++ b/tests/test_plugins/test_array_factor.py @@ -3,8 +3,8 @@ from __future__ import annotations import numpy as np -import pydantic.v1 as pydantic import pytest +from pydantic import ValidationError import tidy3d as td import tidy3d.plugins.microwave as mw @@ -96,19 +96,19 @@ def test_rectangular_array_calculator_basic(): # test ._extend_dims assert array_calculator._extend_dims == [1, 2] - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): # Test invalid array size mw.RectangularAntennaArrayCalculator( array_size=(0, 4, 5), spacings=(0.5, 0.5, 0.5), phase_shifts=(0, 0, 0) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): # Test invalid spacings mw.RectangularAntennaArrayCalculator( array_size=(3, 4, 5), spacings=(-0.5, 0.5, 0.5), phase_shifts=(0, 0, 0) ) - with pytest.raises(pydantic.ValidationError): + with pytest.raises(ValidationError): # Test wrong length of amp_multipliers mw.RectangularAntennaArrayCalculator( 
             array_size=(3, 4, 5),
@@ -118,7 +118,7 @@ def test_rectangular_array_calculator_basic():
         )

     for i in range(3):
-        with pytest.raises(pydantic.ValidationError):
+        with pytest.raises(ValidationError):
             # Test wrong length of amp_multipliers components
             amps = [np.ones(3), np.ones(4), np.ones(5)]
             amps[i] = np.ones(10)
@@ -737,7 +737,7 @@ def test_rectangular_array_calculator_array_factor_taper():
     phi_y = np.pi / 4
     phi_z = np.pi / 3

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         # Test for type mismatch
         taper = mw.RadialTaper(window=mw.ChebWindow(attenuation=45))

@@ -814,7 +814,7 @@ def test_rectangular_array_calculator_array_factor_taper():

     assert af_amps_1d.shape == (100, 3)

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         # assert that Rectangular Taper has at least one set window
         taper = mw.RectangularTaper()

diff --git a/tests/test_plugins/test_design.py b/tests/test_plugins/test_design.py
index 376c68849b..5d05f2610c 100644
--- a/tests/test_plugins/test_design.py
+++ b/tests/test_plugins/test_design.py
@@ -18,7 +18,7 @@ SWEEP_METHODS = {
     "grid": tdd.MethodGrid(),
     "monte_carlo": tdd.MethodMonteCarlo(num_points=5, seed=1),
-    "bay_opt": tdd.MethodBayOpt(initial_iter=5, n_iter=2, seed=1),
+    "bay_opt": tdd.MethodBayOpt(initial_iter=3, n_iter=2, seed=2),
     "gen_alg": tdd.MethodGenAlg(
         solutions_per_pop=6,
         n_generations=2,
@@ -323,15 +323,15 @@ def init_design_space(sweep_method):
     radius_variable = tdd.ParameterFloat(
         name="radius",
         span=(0, 1.5),
-        num_points=5,  # note: only used for MethodGrid
+        num_points=3,  # note: only used for MethodGrid
     )

     num_spheres_variable = tdd.ParameterInt(
         name="num_spheres",
-        span=(0, 3),
+        span=(0, 2),
     )

-    tag_variable = tdd.ParameterAny(name="tag", allowed_values=("tag1", "tag2", "tag3"))
+    tag_variable = tdd.ParameterAny(name="tag", allowed_values=("tag1", "tag2"))

     design_space = tdd.DesignSpace(
         parameters=[radius_variable, num_spheres_variable, tag_variable],
@@ -344,6 +344,7 @@ def init_design_space(sweep_method):

 @pytest.mark.parametrize("sweep_method", SWEEP_METHODS.values())
+@pytest.mark.slow
 def test_sweep(sweep_method, monkeypatch):
     # Problem, simulate scattering cross section of sphere ensemble
     # simulation consists of `num_spheres` spheres of radius `radius`.
diff --git a/tests/test_plugins/test_dispersion_fitter.py b/tests/test_plugins/test_dispersion_fitter.py
index fbdf541c4e..c77d894e5e 100644
--- a/tests/test_plugins/test_dispersion_fitter.py
+++ b/tests/test_plugins/test_dispersion_fitter.py
@@ -5,7 +5,7 @@

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pydantic
+import pydantic as pd
 import pytest
 import responses
 import rich
@@ -45,7 +45,7 @@ def mock_url(*args, **kwargs):
     responses.add(
         responses.POST,
         f"{mock_url()}/dispersion/fit",
-        json={"message": td.PoleResidue().json(), "rms": 1e-16},
+        json={"message": td.PoleResidue().model_dump_json(), "rms": 1e-16},
         status=200,
     )
@@ -80,13 +80,13 @@ def test_lossless_dispersion(random_data, mock_remote_api):
     """perform fitting on random data"""

     # wrong input data
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         fitter = DispersionFitter(wvl_um=[], n_data=())

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         fitter = DispersionFitter(wvl_um=[1.0], n_data=(1.0, 1.1))

-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(pd.ValidationError):
         fitter = DispersionFitter(wvl_um=[1.0], n_data=(1.0), k_data=(0, 1))

     with pytest.raises(SetupError):
diff --git a/tests/test_plugins/test_invdes.py b/tests/test_plugins/test_invdes.py
index f3ba2f3b96..57dcedf9ea 100644
--- a/tests/test_plugins/test_invdes.py
+++ b/tests/test_plugins/test_invdes.py
@@ -5,9 +5,11 @@

 import numpy as np
 import numpy.testing as npt
 import pytest
+from pydantic import ValidationError

 import tidy3d as td
 import tidy3d.plugins.invdes as tdi
+from tidy3d.exceptions import SetupError
 from tidy3d.plugins.expressions import ModeAmp, ModePower
 from tidy3d.plugins.invdes.initialization import (
     CustomInitializationSpec,
@@ -156,6 +158,19 @@ def test_region_inf_size():
         _ = region.to_structure(params_0_inf)


+def test_region_priority():
+    region = make_design_region()
+
+    # Test default priority (None)
+    structure_default = region.to_structure(region.params_zeros)
+    assert structure_default.priority is None
+
+    # Test explicit priority value
+    region = region.updated_copy(priority=1)
+    structure = region.to_structure(region.params_zeros)
+    assert structure.priority == 1
+
+
 def post_process_fn(sim_data: td.SimulationData, **kwargs) -> float:
     """Define a post-processing function with extra kwargs (recommended)."""
     intensity = sim_data.get_intensity(MNT_NAME1)
@@ -256,7 +271,6 @@ def make_invdes_multi():
     region = make_design_region()

     simulations = n * [simulation]
-    # post_process_fns = n * [post_process_fn]

     invdes = tdi.InverseDesignMulti(
         design_region=region,
@@ -273,12 +287,12 @@ def test_invdes_multi_same_length():
     invdes = make_invdes_multi()

     n = len(invdes.simulations)
-    output_monitor_names = (n + 1) * [["test"]]
+    output_monitor_names = (n + 1) * [("test",)]

     with pytest.raises(ValueError):
         _ = invdes.updated_copy(output_monitor_names=output_monitor_names)

-    output_monitor_names = [([MNT_NAME1, MNT_NAME2], None)[i % 2] for i in range(n)]
+    output_monitor_names = [((MNT_NAME1, MNT_NAME2), None)[i % 2] for i in range(n)]
     invdes = invdes.updated_copy(output_monitor_names=output_monitor_names)

     _ = invdes.designs
@@ -316,10 +330,34 @@ def test_warn_zero_grad(use_emulated_run):  # noqa: F811
     """Test default paramns running the optimization defined in the ``InverseDesign`` object."""

     optimizer = make_optimizer()
-    with AssertLogLevel("WARNING", contains_str="All elements of the gradient are almost zero"):
+    design_region = optimizer.design.design_region.updated_copy(penalties=())
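+    # with the penalty terms stripped, the gradient comes solely from the objective,
+    # so the untraced post-process below should fail with an all-zero-gradient SetupError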
+    design = optimizer.design.updated_copy(design_region=design_region)
+    optimizer = optimizer.updated_copy(design=design)
+
+    with pytest.raises(SetupError, match="All elements of the gradient are exactly zero"):
         optimizer.run(post_process_fn=post_process_fn_untraced)


+def test_scaled_objective_grad_not_filtered(use_emulated_run):  # noqa: F811
+    """Test that scaled objectives do not result in an all-zero gradient."""
+
+    optimizer = make_optimizer()
+    design_region = optimizer.design.design_region.updated_copy(penalties=())
+    design = optimizer.design.updated_copy(design_region=design_region)
+    optimizer = optimizer.updated_copy(design=design)
+
+    scale = 1e-10
+
+    def post_process_fn_scaled(sim_data: td.SimulationData, **kwargs) -> float:
+        intensity = sim_data.get_intensity(MNT_NAME1)
+        return scale * anp.sum(intensity.values)
+
+    result = optimizer.run(post_process_fn=post_process_fn_scaled)
+
+    grad = result.grad[-1]
+    assert np.count_nonzero(grad) > 0
+
+
 def make_result_multi(use_emulated_run):  # noqa: F811
     """Test running the optimization defined in the ``InverseDesignMulti`` object."""
@@ -368,6 +406,7 @@ def test_continue_run_fns(use_emulated_run):  # noqa: F811
     )


+@pytest.mark.slow
 def test_continue_run_from_file(use_emulated_run):  # noqa: F811
     """Test continuing an already run inverse design from file."""
     result_orig = make_result(use_emulated_run)
@@ -445,13 +484,13 @@ def test_invdes_io(tmp_path, use_emulated_run):  # noqa: F811
     design = optimizer.design

     for obj in (design, optimizer, result):
-        obj.json()
+        obj.model_dump_json()

         path = str(tmp_path / "obj.hdf5")
         obj.to_file(path)
         obj2 = obj.from_file(path)

-        assert obj2.json() == obj.json()
+        assert obj2.model_dump_json() == obj.model_dump_json()


 def test_objective_utilities(use_emulated_run):  # noqa: F811
@@ -601,29 +640,29 @@ def test_validate_invdes_metric():
     """Test the _validate_metric_monitor_name validator."""
     invdes = make_invdes()
     metric = ModePower(monitor_name="invalid_monitor", f=[FREQ0])
-    with pytest.raises(ValueError, match="monitors"):
+    with pytest.raises(ValidationError, match="monitors"):
         invdes.updated_copy(metric=metric)

     metric = ModePower(monitor_name=MNT_NAME2, mode_index=10, f=[FREQ0])
-    with pytest.raises(ValueError, match="mode index"):
+    with pytest.raises(ValidationError, match="mode index"):
         invdes.updated_copy(metric=metric)

     metric = ModePower(monitor_name=MNT_NAME2, mode_index=0, f=[FREQ0 / 2])
-    with pytest.raises(ValueError, match="frequencies"):
+    with pytest.raises(ValidationError, match="frequencies"):
         invdes.updated_copy(metric=metric)

     metric = ModePower(monitor_name=MNT_NAME2, mode_index=0)
     monitor = mnt2.updated_copy(freqs=[FREQ0, FREQ0 / 2])
-    invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=[monitor]))
-    with pytest.raises(ValueError, match="single frequency"):
+    invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=(monitor,)))
+    with pytest.raises(ValidationError, match="single frequency"):
         invdes.updated_copy(metric=metric)

     metric = ModeAmp(monitor_name=MNT_NAME2, mode_index=0) + ModePower(
         monitor_name=MNT_NAME2, mode_index=0
     )
     monitor = mnt2.updated_copy(freqs=[FREQ0])
-    invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=[monitor]))
-    with pytest.raises(ValueError, match="must return a real"):
+    invdes = invdes.updated_copy(simulation=simulation.updated_copy(monitors=(monitor,)))
+    with pytest.raises(ValidationError, match="must return a real"):
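+        # the ModeAmp term makes the combined metric complex-valued, which the
+        # metric validator rejects as non-real output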
         invdes.updated_copy(metric=metric)
@@ -677,3 +716,26 @@ def test_result_params_out_of_bounds():

     # get_sim should work without issues
     sim = result.get_sim(index=0)
+
+
+@pytest.mark.parametrize("check_low", [False, True])
+def test_transformation_clipping(check_low):
+    """Test the output of `FilterProject` is between 0 and 1."""
+
+    filter_project = tdi.FilterProject(radius=0.5, beta=1.0)
+    design_region_dl = 0.02
+
+    if check_low:
+        test_region = np.zeros((100, 100))
+        test_region[40:60, 40:60] = 1.0
+
+        output_region = filter_project.evaluate(test_region, design_region_dl)
+
+        assert np.min(output_region) >= 0.0, "Output region minimum below 0.0"
+    else:
+        test_region = np.ones((100, 100))
+        test_region[40:60, 40:60] = 0.0
+
+        output_region = filter_project.evaluate(test_region, design_region_dl)
+
+        assert np.max(output_region) <= 1.0, "Output region maximum above 1.0"
diff --git a/tests/test_plugins/test_microwave.py b/tests/test_plugins/test_microwave.py
index 3132bfdd5d..3fc67998a9 100644
--- a/tests/test_plugins/test_microwave.py
+++ b/tests/test_plugins/test_microwave.py
@@ -6,8 +6,8 @@

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pd
 import pytest
+from pydantic import ValidationError
 from skrf import Frequency
 from skrf.media import MLine

@@ -84,7 +84,7 @@ def test_lobe_measurer_validation():
     Urad = np.cos(theta)

     # Raise error when radiation pattern is negative
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         mw.LobeMeasurer(
             angle=theta,
             radiation_pattern=Urad,
@@ -92,7 +92,7 @@ def test_lobe_measurer_validation():

     Urad = np.cos(theta) + 1j * np.sin(theta)
     # Raise error when radiation pattern is complex
-    with pytest.raises(pd.ValidationError), pytest.warns(np.exceptions.ComplexWarning):
+    with pytest.raises(ValidationError), pytest.warns(np.exceptions.ComplexWarning):
         mw.LobeMeasurer(
             angle=theta,
             radiation_pattern=Urad,
@@ -104,7 +104,7 @@ def test_lobe_measurer_validation():
         mw.LobeMeasurer(angle=theta, radiation_pattern=Urad, apply_cyclic_extension=False)

     # Raise error when cyclic extension is enabled and angle array is not in [0, 2π)
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         mw.LobeMeasurer(
             angle=theta,
             radiation_pattern=Urad,
@@ -114,7 +114,7 @@ def test_lobe_measurer_validation():
     theta[10] = theta[75]
     Urad = np.cos(theta) ** 2
     # Make sure array is sorted
-    with pytest.raises(pd.ValidationError):
+    with pytest.raises(ValidationError):
         mw.LobeMeasurer(angle=theta, radiation_pattern=Urad, apply_cyclic_extension=False)

diff --git a/tests/test_plugins/test_mode_solver.py b/tests/test_plugins/test_mode_solver.py
index cdc412513d..9c8b2c2b1c 100644
--- a/tests/test_plugins/test_mode_solver.py
+++ b/tests/test_plugins/test_mode_solver.py
@@ -4,9 +4,9 @@

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pydantic
 import pytest
 import responses
+from pydantic import ValidationError

 import tidy3d as td
 import tidy3d.plugins.mode.web as msweb
@@ -47,11 +47,11 @@ def mock_download(resource_id, remote_filename, to_file, *args, **kwargs):
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(1, 0, -1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=3,
@@ -273,7 +273,7 @@ def test_mode_solver_validation():
     )

     # frequency is too low
-    with pytest.raises(pydantic.ValidationError):
+    with pytest.raises(ValidationError):
         ms = ModeSolver(
             simulation=simulation,
             plane=PLANE,
@@ -293,7 +293,7 @@ def test_mode_solver_validation():

     # num of modes * plane grid points too large
     # 1) number of modes too big
-    with pytest.raises(SetupError):
+    with pytest.raises(ValidationError):
         ms = ModeSolver(
             simulation=simulation,
             plane=PLANE,
@@ -302,7 +302,7 @@ def test_mode_solver_validation():
             direction="+",
         )
     # 2) number of grid points too big
-    with pytest.raises(SetupError):
+    with pytest.raises(ValidationError):
         ms = ModeSolver(
             simulation=simulation.updated_copy(grid_spec=td.GridSpec.uniform(dl=0.0001)),
             plane=PLANE,
@@ -356,9 +356,7 @@ def test_mode_solver_fields():
         grid_spec=td.GridSpec(wavelength=1.0),
         run_time=1e-12,
     )
-    mode_spec = td.ModeSpec(
-        num_modes=1,
-    )
+    mode_spec = td.ModeSpec(num_modes=1)
     ms = ModeSolver(
         simulation=simulation,
         plane=PLANE,
@@ -384,17 +382,17 @@ def test_mode_solver_fields():

 @pytest.mark.parametrize("local", [True, False])
 @responses.activate
-def test_mode_solver_simple(mock_remote_api, local):
+def test_mode_solver_simple(mock_remote_api, local, tmp_path):
     """Simple mode solver run (with symmetry)"""

     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=3,
@@ -423,7 +421,7 @@ def test_mode_solver_simple(mock_remote_api, local):
         check_ms_reduction(ms)
     else:
-        _ = msweb.run(ms)
+        _ = msweb.run(ms, results_file=tmp_path / "tmp.hdf5")

     # Testing issue 807 functions
     freq0 = td.C_0 / 1.55
@@ -444,11 +442,11 @@ def test_mode_solver_remote_after_local(mock_remote_api, tmp_path):
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=3,
@@ -495,7 +493,7 @@ def test_mode_solver_custom_medium(mock_remote_api, local, tmp_path):
     simulation = td.Simulation(
         size=(2, 2, 2),
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[waveguide],
+        structures=(waveguide,),
         run_time=1e-12,
     )
     mode_spec = td.ModeSpec(
@@ -569,7 +567,7 @@ def test_mode_solver_unstructured_custom_medium(nx, cond_factor, interp, tol, tm
     simulation = td.Simulation(
         size=(2, 2, 2),
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[waveguide],
+        structures=(waveguide,),
         run_time=1e-12,
     )
     mode_spec = td.ModeSpec(num_modes=1)
@@ -605,11 +603,11 @@ def test_mode_solver_straight_vs_angled():
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec.auto(wavelength=1.0, min_steps_per_wvl=16),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(num_modes=5, group_index_step=True)
     freqs = [td.C_0 / 0.9, td.C_0 / 1.0, td.C_0 / 1.1]
@@ -679,11 +677,11 @@ def test_mode_solver_angle_bend():
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(-1, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=3,
@@ -726,7 +724,7 @@ def test_mode_bend_radius():
         size=(10, 10, 10),
         grid_spec=td.GridSpec(wavelength=1.0),
         # grid_spec=td.GridSpec.uniform(dl=0.04),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
     )
     mode_spec1 = td.ModeSpec(
@@ -780,10 +778,10 @@ def test_mode_solver_2D():
     simulation = td.Simulation(
         size=(0, SIM_SIZE[1], SIM_SIZE[2]),
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     ms = ModeSolver(
         simulation=simulation, plane=PLANE, mode_spec=mode_spec, freqs=[td.C_0 / 1.0], direction="-"
@@ -803,10 +801,10 @@ def test_mode_solver_2D():
     simulation = td.Simulation(
         size=(SIM_SIZE[0], SIM_SIZE[1], 0),
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         boundary_spec=td.BoundarySpec.pml(z=False),
-        sources=[SRC],
+        sources=(SRC,),
     )
     ms = ModeSolver(
         simulation=simulation, plane=PLANE, mode_spec=mode_spec, freqs=[td.C_0 / 1.0], direction="+"
@@ -822,7 +820,7 @@ def test_mode_solver_2D():
         grid_spec=td.GridSpec(wavelength=1.0),
         run_time=1e-12,
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     ms = ModeSolver(simulation=simulation, plane=PLANE, mode_spec=mode_spec, freqs=[td.C_0 / 1.0])
     compare_colocation(ms)
@@ -839,15 +837,15 @@ def test_group_index(mock_remote_api, local, tmp_path):
     simulation = td.Simulation(
         size=(5, 5, 1),
         grid_spec=td.GridSpec(wavelength=1.55),
-        structures=[
+        structures=(
             td.Structure(
                 geometry=td.Box(size=(0.5, 0.22, td.inf)), medium=td.Medium(permittivity=3.48**2)
-            )
-        ],
+            ),
+        ),
         medium=td.Medium(permittivity=1.44**2),
         run_time=1e-12,
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=2,
@@ -888,7 +886,7 @@ def test_group_index(mock_remote_api, local, tmp_path):
         mode_spec=mode_spec.copy(update={"group_index_step": True}),
         freqs=freqs,
     )
-    modes = ms.solve() if local else msweb.run(ms)
+    modes = ms.solve() if local else msweb.run(ms, results_file=tmp_path / "tmp.hdf5")
     if local:
         assert (modes.n_group.sel(mode_index=0).values > 3.9).all()
         assert (modes.n_group.sel(mode_index=0).values < 4.2).all()
@@ -936,11 +934,11 @@ def test_mode_solver_nan_pol_fraction():
         medium=td.Medium(permittivity=2),
         size=SIM_SIZE,
         grid_spec=td.GridSpec.auto(wavelength=1.55, min_steps_per_wvl=15),
-        structures=[wg],
+        structures=(wg,),
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )

     mode_spec = td.ModeSpec(
@@ -986,7 +984,7 @@ def test_mode_solver_method_defaults():
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )

     mode_spec = td.ModeSpec(
@@ -1030,7 +1028,7 @@ def test_mode_solver_method_defaults():

 @responses.activate
-def test_mode_solver_web_run_batch(mock_remote_api):
+def test_mode_solver_web_run_batch(mock_remote_api, tmp_path):
     """Testing run_batch function for the web mode solver."""

     wav = 1.5
@@ -1043,7 +1041,7 @@
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=wav),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.PML()),
     )
@@ -1065,7 +1063,13 @@ def test_mode_solver_web_run_batch(mock_remote_api):
     )

     # Run mode solver one at a time
-    results = msweb.run_batch(mode_solver_list, verbose=False, folder_name="Mode Solver")
+    results_files = [tmp_path / f"ms_batch_{i}.hdf5" for i in range(num_of_sims)]
+    results = msweb.run_batch(
+        mode_solver_list,
+        verbose=False,
+        folder_name="Mode Solver",
+        results_files=results_files,
+    )
     print(*results, sep="\n")
     assert all(isinstance(x, ModeSolverData) for x in results)
     assert (results[i].n_eff.shape == (num_freqs, i + 1) for i in range(num_of_sims))
@@ -1077,11 +1081,11 @@ def test_mode_solver_relative():
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=3,
@@ -1111,11 +1115,11 @@ def test_mode_solver_plot():
     simulation = td.Simulation(
         size=SIM_SIZE,
         grid_spec=td.GridSpec(wavelength=1.0),
-        structures=[WAVEGUIDE],
+        structures=(WAVEGUIDE,),
         run_time=1e-12,
         symmetry=(0, 0, 1),
         boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()),
-        sources=[SRC],
+        sources=(SRC,),
     )
     mode_spec = td.ModeSpec(
         num_modes=3,
@@ -1144,7 +1148,7 @@ def test_mode_solver_plot():

 @pytest.mark.parametrize("local", [True, False])
 @responses.activate
-def test_modes_eme_sim(mock_remote_api, local):
+def test_modes_eme_sim(mock_remote_api, local, tmp_path):
     lambda0 = 1
     freq0 = td.C_0 / lambda0
     sim_size = (1, 1, 1)
@@ -1161,8 +1165,10 @@ def test_modes_eme_sim(mock_remote_api, local):
         _ = solver.data
     else:
         with pytest.raises(SetupError):
-            _ = msweb.run(solver)
-        _ = msweb.run(solver.to_fdtd_mode_solver())
+            _ = msweb.run(solver, results_file=tmp_path / "eme_solver_remote.hdf5")
+        _ = msweb.run(
+            solver.to_fdtd_mode_solver(), results_file=tmp_path / "eme_solver_fdtd_remote.hdf5"
+        )

     _ = solver.reduced_simulation_copy
@@ -1221,7 +1227,7 @@ def make_high_order_mode_solver(sign, dim=3):
         grid_spec=td.GridSpec.auto(
             min_steps_per_wvl=20, wavelength=1.55, override_structures=[refine_box]
         ),
-        structures=[waveguide],
+        structures=(waveguide,),
         medium=td.Medium(permittivity=1.44**2),
         boundary_spec=td.BoundarySpec(x=pml, y=pml, z=pml if dim == 3 else periodic),
         run_time=1e-12,
@@ -1282,7 +1288,7 @@ def test_translated_dot():
     grid_spec = td.GridSpec.auto(wavelength=lambda0, min_steps_per_wvl=20)

     sim = td.Simulation(
-        size=sim_size, medium=sio2, structures=[wg], grid_spec=grid_spec, run_time=1e-30
+        size=sim_size, medium=sio2, structures=(wg,), grid_spec=grid_spec, run_time=1e-30
     )
     mode_plane = td.Box(size=(3, 3, 0))
     mode_solver = ModeSolver(simulation=sim, plane=mode_plane, mode_spec=mode_spec, freqs=[freq0])
@@ -1315,7 +1321,7 @@ def test_translated_dot():

 def test_mode_spec_filter_pol_sort_spec_exclusive():
     """Ensure ModeSpec errors when both filter_pol and sort_spec are set."""
-    with pytest.raises(pydantic.ValidationError, match="simultaneously"):
+    with pytest.raises(ValidationError, match="simultaneously"):
         _ = td.ModeSpec(num_modes=1, filter_pol="te", sort_spec=td.ModeSortSpec(sort_key="n_eff"))

diff --git a/tests/test_plugins/test_waveguide.py b/tests/test_plugins/test_waveguide.py
index 2708559251..3949d3e86c 100644
--- a/tests/test_plugins/test_waveguide.py
+++ b/tests/test_plugins/test_waveguide.py
@@ -2,7 +2,7 @@

 import numpy as np
 import pytest
-from pydantic.v1 import ValidationError
+from pydantic import ValidationError

 import tidy3d as td
 from tidy3d.plugins import waveguide
diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py
index 3c8a9c7ee6..923d04beb2 100644
--- a/tests/test_web/test_local_cache.py
+++ b/tests/test_web/test_local_cache.py
@@ -11,7 +11,6 @@

 import autograd as ag
 import pytest
-import xarray as xr
 from autograd.core import defvjp
 from click.testing import CliRunner
 from rich.console import Console
@@ -19,6 +18,7 @@
 import tidy3d as td
 from tests.test_components.autograd.test_autograd import ALL_KEY, get_functions, params0
 from tests.test_web.test_webapi_mode import make_mode_sim
+from tests.utils import run_emulated
 from tidy3d import config
 from tidy3d.components.autograd.field_map import FieldMap
 from tidy3d.config import get_manager
@@ -61,24 +61,20 @@ def _reset_fake_maps():
 class _FakeStubData:
     def __init__(self, simulation: td.Simulation):
         self.simulation = simulation
+        self.sim_data = run_emulated(self.simulation)

     def __getitem__(self, key):
-        if key == "mode":
-            params = self.simulation.attrs["params_autograd"]
-            return SimpleNamespace(
-                amps=xr.DataArray(params, dims=["x"], coords={"x": list(range(len(params)))})
-            )
+        return self.sim_data[key]

     def _strip_traced_fields(self, *args, **kwargs):
-        """Fake _strip_traced_fields: return minimal valid autograd-style mapping."""
-        return {"params": self.simulation.attrs["params"]}
+        return self.sim_data._strip_traced_fields(*args, **kwargs)

     def _insert_traced_fields(self, field_mapping, *args, **kwargs):
-        self.simulation.attrs["params_autograd"] = field_mapping["params"]
-        return self
+        return self.sim_data._insert_traced_fields(field_mapping, *args, **kwargs)

     def _make_adjoint_sims(self, **kwargs):
-        return [self.simulation.updated_copy(run_time=self.simulation.run_time * 2)]
+        return self.sim_data._make_adjoint_sims(**kwargs)


 @pytest.fixture
@@ -285,7 +281,7 @@ def _test_load_simulation_if_cached(monkeypatch, tmp_path, basic_simulation):
     assert counters == {"upload": 1, "start": 1, "monitor": 1, "download": 1}
     assert len(cache) == 1

-    sim_data_from_cache = load_simulation_if_cached(basic_simulation)
+    sim_data_from_cache = load_simulation_if_cached(basic_simulation, path=tmp_path / "tmp.hdf5")
     assert sim_data_from_cache is not None
     assert sim_data_from_cache.simulation == basic_simulation
@@ -296,20 +292,20 @@ def _test_mode_solver_caching(monkeypatch, tmp_path):
     counters = _patch_run_pipeline(monkeypatch)
-
+    tmp_file = tmp_path / "tmp.hdf5"
     # store in cache
     mode_sim = make_mode_sim()
-    mode_sim_data = web.run(mode_sim)
+    mode_sim_data = web.run(mode_sim, path=tmp_file)

     # test basic loading from cache
-    from_cache_data = load_simulation_if_cached(mode_sim)
+    from_cache_data = load_simulation_if_cached(mode_sim, path=tmp_file)
     assert from_cache_data is not None
     assert isinstance(from_cache_data, _FakeStubData)
     assert mode_sim_data.simulation == from_cache_data.simulation

     # test loading from run
     _reset_counters(counters)
-    mode_sim_data_run = web.run(mode_sim)
+    mode_sim_data_run = web.run(mode_sim, path=tmp_file)
     assert counters["download"] == 0
     assert isinstance(mode_sim_data_run, _FakeStubData)
     assert mode_sim_data.simulation == mode_sim_data_run.simulation
@@ -317,7 +313,7 @@ def _test_mode_solver_caching(monkeypatch, tmp_path):
     # test loading from job
     _reset_counters(counters)
     job = Job(simulation=mode_sim, task_name="test")
-    job_data = job.run()
+    job_data = job.run(path=tmp_file)
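+    # the result should be served from the local cache, so no new download occurs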
     assert counters["download"] == 0
     assert isinstance(job_data, _FakeStubData)
     assert mode_sim_data.simulation == job_data.simulation
@@ -334,14 +330,14 @@ def _test_mode_solver_caching(monkeypatch, tmp_path):
     cache = resolve_local_cache(True)
     # test storing via job
     cache.clear()
-    Job(simulation=mode_sim, task_name="test").run()
-    assert load_simulation_if_cached(mode_sim) is not None
+    Job(simulation=mode_sim, task_name="test").run(path=tmp_file)
+    assert load_simulation_if_cached(mode_sim, path=tmp_file) is not None

     # test storing via batch
     cache.clear()
     batch_mode_data = Batch(simulations={"sim1": mode_sim}).run(path_dir=tmp_path)
     _ = batch_mode_data["sim1"]  # access to store
-    assert load_simulation_if_cached(mode_sim) is not None
+    assert load_simulation_if_cached(mode_sim, path=tmp_file) is not None


 def _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path):
@@ -382,7 +378,7 @@
     assert len(cache) == 3


-def _test_verbosity(monkeypatch, basic_simulation):
+def _test_verbosity(monkeypatch, basic_simulation, tmp_path):
     _CSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")  # ANSI CSI
     _OSC8_RE = re.compile(r"\x1b\]8;.*?(?:\x1b\\|\x07)", re.DOTALL)  # OSC-8 hyperlinks
@@ -403,8 +399,8 @@ def _normalize_console_text(s: str) -> str:
     _reset_counters(counters)
     sim2 = basic_simulation.updated_copy(shutoff=1e-4)
     sim3 = basic_simulation.updated_copy(shutoff=1e-3)
-
-    run(basic_simulation, verbose=True)  # seed cache
+    tmp_file = tmp_path / "tmp.hdf5"
+    run(basic_simulation, verbose=True, path=tmp_file)  # seed cache

     log_mod = importlib.import_module("tidy3d.log")
@@ -424,7 +420,7 @@ def _normalize_console_text(s: str) -> str:
         buf.seek(0)

         # test for load_simulation_if_cached
-        sim_data = load_simulation_if_cached(basic_simulation, verbose=True)
+        sim_data = load_simulation_if_cached(basic_simulation, verbose=True, path=tmp_file)
         assert sim_data is not None
         assert "Loading simulation from" in buf.getvalue(), (
             f"Expected 'Loading simulation from' in log, got '{buf.getvalue()}'"
@@ -432,14 +428,14 @@ def _normalize_console_text(s: str) -> str:
         buf.truncate(0)
         buf.seek(0)

-        load_simulation_if_cached(basic_simulation, verbose=False)
+        load_simulation_if_cached(basic_simulation, verbose=False, path=tmp_file)
         assert sim_data is not None
         assert buf.getvalue().strip() == "", f"Expected empty log, got '{buf.getvalue()}'"

         # test for batched runs
         buf.truncate(0)
         buf.seek(0)
-        run([basic_simulation, sim3], verbose=True)
+        run([basic_simulation, sim3], verbose=True, path=tmp_path)
         txt = _normalize_console_text(buf.getvalue())
         assert "Got 1 simulation from cache" in txt, (
             f"Expected 'Got 1 simulation from cache' in log, got '{buf.getvalue()}'"
@@ -448,13 +444,13 @@ def _normalize_console_text(s: str) -> str:
         # if some found
         buf.truncate(0)
         buf.seek(0)
-        run([basic_simulation, sim2], verbose=False)
+        run([basic_simulation, sim2], verbose=False, path=tmp_path)
         assert buf.getvalue().strip() == "", f"Expected empty log, got '{buf.getvalue()}'"

         # if all found
         buf.truncate(0)
         buf.seek(0)
-        run([basic_simulation, sim2], verbose=False)
+        run([basic_simulation, sim2], verbose=False, path=tmp_path)
         assert buf.getvalue().strip() == "", f"Expected empty log, got '{buf.getvalue()}'"

     finally:
@@ -467,7 +463,7 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
     cache = resolve_local_cache(use_cache=True)
     cache.clear()
     job = Job(simulation=basic_simulation, task_name="test")
-    job.run()
+    job.run(path=tmp_path / "tmp.hdf5")
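+    # a single Job.run should store exactly one entry in the local cache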
     assert len(cache) == 1
@@ -485,7 +481,7 @@ def _test_job_run_cache(monkeypatch, basic_simulation, tmp_path):
     assert os.path.exists(out2_path)


-def _test_autograd_cache(monkeypatch, request):
+def _test_autograd_cache(monkeypatch, request, tmp_path):
     counters = _patch_run_pipeline(monkeypatch)

     # "Original" rule: the one autograd uses by default
@@ -525,8 +521,7 @@ def _restore_make_dict_vjp():

     def objective(params):
         sim = make_sim(params)
-        sim.attrs["params"] = params
-        sim_data = run_autograd(sim)
+        sim_data = run_autograd(sim, path=tmp_path / "tmp.hdf5")
         value = postprocess(sim_data)
         return value
@@ -580,13 +575,13 @@ def _test_cache_eviction_by_entries(monkeypatch, tmp_path_factory, basic_simulat
     file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME
     file1.write_text("a")
-    cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD")
+    cache.store_result(MOCK_TASK_ID, str(file1), "FDTD", simulation=basic_simulation)

     assert len(cache) == 1

     sim2 = basic_simulation.updated_copy(shutoff=1e-4)
     file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME
     file2.write_text("b")
-    cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD")
+    cache.store_result(MOCK_TASK_ID, str(file2), "FDTD", simulation=sim2)

     entries = cache.list()
     assert len(entries) == 1
@@ -600,13 +595,13 @@ def _test_cache_eviction_by_size(monkeypatch, tmp_path_factory, basic_simulation
     file1 = tmp_path_factory.mktemp("art1") / CACHE_ARTIFACT_NAME
     file1.write_text("a" * 8_000)
-    cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(file1), "FDTD")
+    cache.store_result(MOCK_TASK_ID, str(file1), "FDTD", simulation=basic_simulation)

     assert len(cache) == 1

     sim2 = basic_simulation.updated_copy(shutoff=1e-4)
     file2 = tmp_path_factory.mktemp("art2") / CACHE_ARTIFACT_NAME
     file2.write_text("b" * 8_000)
-    cache.store_result(_FakeStubData(sim2), MOCK_TASK_ID, str(file2), "FDTD")
+    cache.store_result(MOCK_TASK_ID, str(file2), "FDTD", simulation=sim2)

     entries = cache.list()
     assert len(cache) == 1
@@ -622,7 +617,7 @@ def _test_cache_stats_tracking(monkeypatch, tmp_path_factory, basic_simulation):
     payload = "stats-payload"
     artifact.write_text(payload)

-    cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(artifact), "FDTD")
+    cache.store_result(MOCK_TASK_ID, str(artifact), "FDTD", simulation=basic_simulation)

     stats_path = cache.root / CACHE_STATS_NAME
     assert stats_path.exists()
@@ -661,12 +656,12 @@ def _test_cache_stats_sync(monkeypatch, tmp_path_factory, basic_simulation):
     artifact1 = tmp_path_factory.mktemp("artifact_sync1") / CACHE_ARTIFACT_NAME
     payload1 = "sync-one"
     artifact1.write_text(payload1)
-    cache.store_result(_FakeStubData(sim1), f"{MOCK_TASK_ID}-1", str(artifact1), "FDTD")
+    cache.store_result(f"{MOCK_TASK_ID}-1", str(artifact1), "FDTD", simulation=sim1)

     artifact2 = tmp_path_factory.mktemp("artifact_sync2") / CACHE_ARTIFACT_NAME
     payload2 = "sync-two"
     artifact2.write_text(payload2)
-    cache.store_result(_FakeStubData(sim2), f"{MOCK_TASK_ID}-2", str(artifact2), "FDTD")
+    cache.store_result(f"{MOCK_TASK_ID}-2", str(artifact2), "FDTD", simulation=sim2)

     stats_path = cache.root / CACHE_STATS_NAME
     assert stats_path.exists()
@@ -703,7 +698,7 @@ def _counting_iter():
     artifact = tmp_path / "iter_guard.hdf5"
     artifact.write_text("payload")

-    cache.store_result(_FakeStubData(basic_simulation), MOCK_TASK_ID, str(artifact), "FDTD")
+    cache.store_result(MOCK_TASK_ID, str(artifact), "FDTD", simulation=basic_simulation)

     assert iter_calls["count"] == 0

     entry_dirs = []
@@ -778,9 +773,7 @@ def _test_cache_cli_commands(monkeypatch, tmp_path_factory, basic_simulation):
     artifact = artifact_dir / CACHE_ARTIFACT_NAME
     artifact.write_text("payload_cli")

-    cache.store_result(
-        _FakeStubData(basic_simulation), f"{MOCK_TASK_ID}-cli", str(artifact), "FDTD"
-    )
+    cache.store_result(f"{MOCK_TASK_ID}-cli", str(artifact), "FDTD", simulation=basic_simulation)

     info_result = runner.invoke(tidy3d_cli, ["cache", "info"])
     assert info_result.exit_code == 0
@@ -823,9 +816,9 @@ def test_cache_sequential(
     _test_cache_stats_sync(monkeypatch, tmp_path_factory, basic_simulation)
     _test_run_cache_hit_async(monkeypatch, basic_simulation, tmp_path)
     _test_job_run_cache(monkeypatch, basic_simulation, tmp_path)
-    _test_autograd_cache(monkeypatch, request)
+    _test_autograd_cache(monkeypatch, request, tmp_path)
     _test_configure_cache_roundtrip(monkeypatch, tmp_path)
     _test_store_and_fetch_do_not_iterate(monkeypatch, tmp_path, basic_simulation)
     _test_mode_solver_caching(monkeypatch, tmp_path)
-    _test_verbosity(monkeypatch, basic_simulation)
+    _test_verbosity(monkeypatch, basic_simulation, tmp_path)
     _test_cache_cli_commands(monkeypatch, tmp_path_factory, basic_simulation)
diff --git a/tests/test_web/test_s3utils.py b/tests/test_web/test_s3utils.py
index a6552737a4..d1849bb32a 100644
--- a/tests/test_web/test_s3utils.py
+++ b/tests/test_web/test_s3utils.py
@@ -4,9 +4,11 @@

 import pytest

-import tidy3d
+from tidy3d._common.web.core import s3utils as s3utils_common
 from tidy3d.web.core import s3utils

+s3_utils_path = "tidy3d._common.web.core.s3utils"
+

 @pytest.fixture
 def mock_S3STSToken(monkeypatch):
@@ -16,9 +18,9 @@ def mock_S3STSToken(monkeypatch):
     mock_token.get_bucket = lambda: ""
     mock_token.get_s3_key = lambda: ""
     mock_token.is_expired = lambda: False
-    mock_token.get_client = lambda: tidy3d.web.core.s3utils.boto3.client()
+    mock_token.get_client = lambda: s3utils_common.boto3.client()
     monkeypatch.setattr(
-        target=tidy3d.web.core.s3utils, name="_S3STSToken", value=MagicMock(return_value=mock_token)
+        target=s3utils_common, name="_S3STSToken", value=MagicMock(return_value=mock_token)
     )
     return mock_token
@@ -26,10 +28,10 @@ def mock_S3STSToken(monkeypatch):
 @pytest.fixture
 def mock_get_s3_sts_token(monkeypatch):
     def _mock_get_s3_sts_token(resource_id, remote_filename):
-        return s3utils._S3STSToken(resource_id, remote_filename)
+        return s3utils_common._S3STSToken(resource_id, remote_filename)

     monkeypatch.setattr(
-        target=tidy3d.web.core.s3utils, name="get_s3_sts_token", value=_mock_get_s3_sts_token
+        target=s3utils_common, name="get_s3_sts_token", value=_mock_get_s3_sts_token
     )
     return _mock_get_s3_sts_token
@@ -44,7 +46,7 @@ def mock_s3_client(monkeypatch):
-    # Patch the `client` as it is imported within `tidy3d.web.core.s3utils.boto3` so that
-    # whenever it's invoked (for example with "s3"), it returns our `mock_client`.
+    # Patch the `client` as it is imported within `tidy3d._common.web.core.s3utils.boto3` so
+    # that whenever it's invoked (for example with "s3"), it returns our `mock_client`.
     monkeypatch.setattr(
-        target=tidy3d.web.core.s3utils.boto3,
+        target=s3utils_common.boto3,
         name="client",
         value=MagicMock(return_value=mock_client),
     )
@@ -148,11 +150,11 @@ def test_s3_token_get_client_with_custom_endpoint(tmp_path, monkeypatch):

     # Mock boto3.client
     mock_boto_client = MagicMock()
-    monkeypatch.setattr("tidy3d.web.core.s3utils.boto3.client", mock_boto_client)
+    monkeypatch.setattr(f"{s3_utils_path}.boto3.client", mock_boto_client)

     # Test 1: Without custom endpoint - use fresh config
     test_config = ConfigManager(config_dir=tmp_path)
-    monkeypatch.setattr("tidy3d.web.core.s3utils.config", test_config)
+    monkeypatch.setattr(f"{s3_utils_path}.config", test_config)

     token.get_client()

     # Verify boto3.client was called without endpoint_url
@@ -195,12 +197,12 @@ def test_s3_token_get_client_respects_ssl_verify(tmp_path, monkeypatch):
     token = _S3STSToken(**token_data)

     mock_boto_client = MagicMock()
-    monkeypatch.setattr("tidy3d.web.core.s3utils.boto3.client", mock_boto_client)
+    monkeypatch.setattr(f"{s3_utils_path}.boto3.client", mock_boto_client)

     # Use fresh config with ssl_verify=False
     test_config = ConfigManager(config_dir=tmp_path)
     test_config.update_section("web", ssl_verify=False)
-    monkeypatch.setattr("tidy3d.web.core.s3utils.config", test_config)
+    monkeypatch.setattr(f"{s3_utils_path}.config", test_config)

     token.get_client()
diff --git a/tests/test_web/test_tidy3d_material_library.py b/tests/test_web/test_tidy3d_material_library.py
index 25790b96dc..badeb2c35c 100644
--- a/tests/test_web/test_tidy3d_material_library.py
+++ b/tests/test_web/test_tidy3d_material_library.py
@@ -4,7 +4,7 @@
 import responses

 import tidy3d as td
-from tidy3d.web.api.material_libray import MaterialLibray
+from tidy3d.web.api.material_library import MaterialLibrary
 from tidy3d.web.core.environment import Env

 Env.dev.active()
@@ -27,6 +27,6 @@ def test_lib(set_api_key):
         json={"data": [{"id": "3eb06d16-208b-487b-864b-e9b1d3e010a7", "name": "medium1"}]},
         status=200,
     )
-    libs = MaterialLibray.list()
+    libs = MaterialLibrary.list()
     lib = libs[0]
     assert lib.name == "medium1"
diff --git a/tests/test_web/test_tidy3d_stub.py b/tests/test_web/test_tidy3d_stub.py
index 9d0f283ab7..217a18d646 100644
--- a/tests/test_web/test_tidy3d_stub.py
+++ b/tests/test_web/test_tidy3d_stub.py
@@ -118,20 +118,22 @@ def test_stub_data_to_file(tmp_path):
 def test_stub_data_postprocess_logs(tmp_path):
     """Tests the postprocess method of Tidy3dStubData when simulation diverged."""
     td.log.set_capture(True)
-
-    # test diverged
-    sim_data = make_sim_data()
-    sim_data = sim_data.updated_copy(diverged=True, log="The simulation has diverged!")
-    file_path = os.path.join(tmp_path, "test_diverged.hdf5")
-    sim_data.to_file(file_path)
-    Tidy3dStubData.postprocess(file_path)
-
-    # test warnings
-    sim_data = make_sim_data()
-    sim_data = sim_data.updated_copy(log="WARNING: messages were found in the solver log.")
-    file_path = os.path.join(tmp_path, "test_warnings.hdf5")
-    sim_data.to_file(file_path)
-    Tidy3dStubData.postprocess(file_path)
+    try:
+        # test diverged
+        sim_data = make_sim_data()
+        sim_data = sim_data.updated_copy(diverged=True, log="The simulation has diverged!")
+        file_path = os.path.join(tmp_path, "test_diverged.hdf5")
+        sim_data.to_file(file_path)
+        Tidy3dStubData.postprocess(file_path)
+
+        # test warnings
+        sim_data = make_sim_data()
+        sim_data = sim_data.updated_copy(log="WARNING: messages were found in the solver log.")
+        file_path = os.path.join(tmp_path, "test_warnings.hdf5")
+        sim_data.to_file(file_path)
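+        # postprocess should surface the stored solver-log warning again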
+        Tidy3dStubData.postprocess(file_path)
+    finally:
+        td.log.set_capture(False)


 @responses.activate
@@ -139,32 +141,34 @@ def test_stub_data_lazy_loading(tmp_path):
     """Tests the postprocess method with lazy loading of Tidy3dStubData when simulation diverged."""
     td.log.set_capture(True)
     sim_diverged_log = "The simulation has diverged!"
-
-    # make sim data where test diverged
-    sim_data = make_sim_data()
-    sim_data = sim_data.updated_copy(diverged=True, log=sim_diverged_log)
-    file_path = os.path.join(tmp_path, "test_diverged.hdf5")
-    sim_data.to_file(file_path)
-
-    # default case with lazy=False should output a warning
-    with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
-        Tidy3dStubData.postprocess(file_path, lazy=False)
-
-    # we expect no warning in lazy mode as object should not be loaded
-    with AssertLogLevel(None):
-        sim_data = Tidy3dStubData.postprocess(file_path, lazy=True)
-
-    sim_data_copy = sim_data.copy()
-
-    # variable dict should only contain metadata to load the data, not the data itself
-    assert is_lazy_object(sim_data)
-
-    # the type should be still SimulationData despite being lazy
-    assert isinstance(sim_data, SimulationData)
-
-    # we expect a warning from the lazy object if some field is accessed
-    with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
-        _ = sim_data_copy.monitor_data
+    try:
+        # make sim data where test diverged
+        sim_data = make_sim_data()
+        sim_data = sim_data.updated_copy(diverged=True, log=sim_diverged_log)
+        file_path = os.path.join(tmp_path, "test_diverged.hdf5")
+        sim_data.to_file(file_path)
+
+        # default case with lazy=False should output a warning
+        with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
+            Tidy3dStubData.postprocess(file_path, lazy=False)
+
+        # we expect no warning in lazy mode as object should not be loaded
+        with AssertLogLevel(None):
+            sim_data = Tidy3dStubData.postprocess(file_path, lazy=True)
+
+        sim_data_copy = sim_data.copy()
+
+        # variable dict should only contain metadata to load the data, not the data itself
+        assert is_lazy_object(sim_data)
+
+        # the type should be still SimulationData despite being lazy
+        assert isinstance(sim_data, SimulationData)
+
+        # we expect a warning from the lazy object if some field is accessed
+        with AssertLogLevel("WARNING", contains_str=sim_diverged_log):
+            _ = sim_data_copy.monitor_data
+    finally:
+        td.log.set_capture(False)


 @pytest.mark.parametrize(
diff --git a/tests/test_web/test_tidy3d_task.py b/tests/test_web/test_tidy3d_task.py
index 5270964df8..27a4b65e66 100644
--- a/tests/test_web/test_tidy3d_task.py
+++ b/tests/test_web/test_tidy3d_task.py
@@ -7,6 +7,7 @@
 from responses import matchers

 import tidy3d as td
+from tests.test_web.test_webapi import task_core_path
 from tidy3d.web.core import http_util
 from tidy3d.web.core.environment import Env, EnvironmentConfig
 from tidy3d.web.core.task_core import Folder, SimulationTask
@@ -37,7 +38,7 @@ def make_sim():
 @pytest.fixture
 def set_api_key(monkeypatch):
     """Set the api key."""
-    import tidy3d.web.core.http_util as httputil
+    import tidy3d._common.web.core.http_util as httputil

     monkeypatch.setattr(httputil, "api_key", lambda: "apikey")
     monkeypatch.setattr(httputil, "get_version", lambda: td.version.__version__)
@@ -91,7 +92,7 @@ def mock_download(*args, **kwargs):
         to_file = kwargs["to_file"]
         sim.to_file(to_file)

-    monkeypatch.setattr("tidy3d.web.core.task_core.download_gz_file", mock_download)
+    monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download)

     responses.add(
         responses.GET,
@@ -127,7 +128,7 @@ def test_upload(monkeypatch, set_api_key):
     def mock_download(*args, **kwargs):
         pass

-    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_download)
+    monkeypatch.setattr(f"{task_core_path}.upload_file", mock_download)
     task = SimulationTask.get("3eb06d16-208b-487b-864b-e9b1d3e010a7")
     with tempfile.NamedTemporaryFile() as temp:
         task.upload_file(temp.name, "temp.json")
@@ -359,7 +360,7 @@ def mock(*args, **kwargs):
         with open(file_path, "w") as f:
             f.write("0.3,5.7")

-    monkeypatch.setattr("tidy3d.web.core.task_core.download_file", mock)
+    monkeypatch.setattr(f"{task_core_path}.download_file", mock)
     responses.add(
         responses.GET,
         f"{Env.current.web_api_endpoint}/tidy3d/tasks/3eb06d16-208b-487b-864b-e9b1d3e010a7/detail",
diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py
index 4fdd280bf0..ab9da2eee3 100644
--- a/tests/test_web/test_webapi.py
+++ b/tests/test_web/test_webapi.py
@@ -64,7 +64,7 @@
 common.CONNECTION_RETRY_TIME = 0.1

 INVALID_TASK_ID = "INVALID_TASK_ID"
-task_core_path = "tidy3d.web.core.task_core"
+task_core_path = "tidy3d._common.web.core.task_core"
 api_path = "tidy3d.web.api.webapi"

 Env.dev.active()
@@ -204,7 +204,7 @@ def mock_upload(monkeypatch, set_api_key):
     def mock_upload_file(*args, **kwargs):
         pass

-    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
+    monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file)


 @pytest.fixture
@@ -358,7 +358,7 @@ def mock_webapi(
 @responses.activate
 def test_source_validation(monkeypatch, mock_upload, mock_get_info, mock_metadata):
-    sim = make_sim().copy(update={"sources": []})
+    sim = make_sim().copy(update={"sources": ()})
     assert upload(sim, TASK_NAME, PROJECT_NAME, source_required=False)

     with pytest.raises(SetupError):
@@ -498,7 +498,7 @@ def mock_download(*args, **kwargs):
         pass

     def get_str(*args, **kwargs):
-        return sim.json().encode("utf-8")
+        return sim.model_dump_json().encode("utf-8")

     monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download)
     monkeypatch.setattr(f"{task_core_path}.read_simulation_from_hdf5", get_str)
@@ -816,15 +816,32 @@ def test_batch_monitor_skips_existing_download(monkeypatch, tmp_path):
     assert downloads == [("task_b_id", "download", os.path.join(str(tmp_path), "task_b_id.hdf5"))]


+def test_batch_download_surfaces_download_errors(monkeypatch, tmp_path):
+    monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success"))
+    monkeypatch.setattr("tidy3d.web.api.container.Job.load_if_cached", property(lambda self: False))
+    monkeypatch.setattr("tidy3d.web.api.container.Job.task_id", property(lambda self: "task_a_id"))
+
+    def _raise_download(self, path):
+        raise RuntimeError("gzip extraction failed")
+
+    monkeypatch.setattr("tidy3d.web.api.container.Job.download", _raise_download)
+
+    sims = {"task_a": make_sim()}
+    batch = Batch(simulations=sims, folder_name=PROJECT_NAME, verbose=False)
+
+    with pytest.raises(RuntimeError, match="gzip extraction failed"):
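+        # Batch.download should propagate, not swallow, per-job download failures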
+        batch.download(path_dir=str(tmp_path))
+
+
 """ Async """


 @responses.activate
 @pytest.mark.parametrize("task_name", [TASK_NAME, None])
-def test_async(mock_webapi, mock_job_status, task_name):
+def test_async(mock_webapi, mock_job_status, tmp_path, task_name):
     # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success"))
     sims = {TASK_NAME: make_sim()} if task_name else [make_sim()]
-    _ = run_async(sims, folder_name=PROJECT_NAME)
+    _ = run_async(sims, folder_name=PROJECT_NAME, path_dir=str(tmp_path))


 """ Main """
@@ -852,6 +869,8 @@ def save_sim_to_path(path: str) -> None:
             PROJECT_NAME,
             "--inspect_credits",
             "--inspect_sim",
+            "-o",
+            str(tmp_path / "tmp.hdf5"),
         ]
     )
@@ -865,6 +884,8 @@ def save_sim_to_path(path: str) -> None:
             "--folder_name",
             PROJECT_NAME,
             "--inspect_credits",
+            "-o",
+            str(tmp_path / "tmp.hdf5"),
         ]
     )
@@ -877,6 +898,8 @@ def save_sim_to_path(path: str) -> None:
             "--folder_name",
             PROJECT_NAME,
             "--inspect_sim",
+            "-o",
+            str(tmp_path / "tmp.hdf5"),
         ]
     )
@@ -1061,18 +1084,22 @@ def test_job_run_accepts_pathlikes(monkeypatch, tmp_path, path_builder):
     [_pathlib_builder, _posix_builder, _str_builder, _fspath_builder],
     ids=["pathlib.Path", "posixpath_str", "str", "PathLike"],
 )
+@pytest.mark.slow
 def test_batch_run_accepts_pathlike_dir(monkeypatch, tmp_path, dir_builder):
     """Batch.run(path_dir=...) accepts any PathLike directory location."""
-    sims = {"A": make_sim(), "B": make_sim()}
+    sims = {"A": make_sim()}
     out_dir = dir_builder(tmp_path, "batch_out")

     # Map task_ids to sims: upload() is patched to return task_name, which for dict input
-    # corresponds to the dict keys ("A", "B"), so we map those.
-    apply_common_patches(monkeypatch, tmp_path, taskid_to_sim={"A": sims["A"], "B": sims["B"]})
+    # corresponds to the dict key ("A"), so we map that.
+    apply_common_patches(monkeypatch, tmp_path, taskid_to_sim={"A": sims["A"]})

     b = Batch(simulations=sims, folder_name=PROJECT_NAME)
     b.run(path_dir=out_dir)

-    # Directory created and two .hdf5 outputs produced
+    # Directory created and .hdf5 output produced
     out_dir_str = os.fspath(out_dir)
     assert os.path.isdir(out_dir_str)
+
+    batch_file = Path(out_dir) / "batch.hdf5"
+    assert batch_file.is_file()
diff --git a/tests/test_web/test_webapi_account.py b/tests/test_web/test_webapi_account.py
index bb31a54010..3ee1b002f6 100644
--- a/tests/test_web/test_webapi_account.py
+++ b/tests/test_web/test_webapi_account.py
@@ -10,7 +10,6 @@
 )
 from tidy3d.web.core.environment import Env

-task_core_path = "tidy3d.web.core.task_core"
 api_path = "tidy3d.web.api.webapi"

 Env.dev.active()
diff --git a/tests/test_web/test_webapi_eme.py b/tests/test_web/test_webapi_eme.py
index c44cd19b41..b847788409 100644
--- a/tests/test_web/test_webapi_eme.py
+++ b/tests/test_web/test_webapi_eme.py
@@ -7,6 +7,7 @@
 from responses import matchers

 import tidy3d as td
+from tests.test_web.test_webapi import task_core_path
 from tidy3d import EMESimulation
 from tidy3d.exceptions import SetupError
 from tidy3d.web.api.asynchronous import run_async
@@ -37,7 +38,6 @@
 EST_FLEX_UNIT = 11.11
 FILE_SIZE_GB = 4.0

-task_core_path = "tidy3d.web.core.task_core"
 api_path = "tidy3d.web.api.webapi"
@@ -89,7 +89,7 @@ def mock_upload(monkeypatch, set_api_key):
     def mock_upload_file(*args, **kwargs):
         pass

-    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
+    monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file)


 @pytest.fixture
@@ -277,7 +277,7 @@ def mock_download(*args, **kwargs):
         pass

     def get_str(*args, **kwargs):
-        return sim.json().encode("utf-8")
+        return sim.model_dump_json().encode("utf-8")

     monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download)
     monkeypatch.setattr(f"{task_core_path}.read_simulation_from_hdf5", get_str)
@@ -378,7 +378,7 @@ def test_batch(mock_webapi, mock_job_status, tmp_path):

 @responses.activate
-def test_async(mock_webapi, mock_job_status):
+def test_async(mock_webapi, mock_job_status, tmp_path):
     # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success"))
     sims = {TASK_NAME: make_eme_sim()}
-    _ = run_async(sims, folder_name=PROJECT_NAME)
+    _ = run_async(sims, folder_name=PROJECT_NAME, path_dir=str(tmp_path))
diff --git a/tests/test_web/test_webapi_extra.py b/tests/test_web/test_webapi_extra.py
index 5e16ef42e3..315a40e9c7 100644
--- a/tests/test_web/test_webapi_extra.py
+++ b/tests/test_web/test_webapi_extra.py
@@ -5,15 +5,14 @@
 import pytest
 import responses

+from tests.test_web.test_webapi import task_core_path
 from tidy3d.web.api.webapi import delete, get_info, get_tasks, real_cost, start


 @responses.activate
 def test_get_info_not_found(monkeypatch):
     """Tests that get_info raises a ValueError when the task is not found."""
-    monkeypatch.setattr(
-        "tidy3d.web.core.task_core.SimulationTask.get", lambda *args, **kwargs: None
-    )
+    monkeypatch.setattr(f"{task_core_path}.SimulationTask.get", lambda *args, **kwargs: None)
     with pytest.raises(ValueError, match="Task not found."):
         get_info("non_existent_task_id")
@@ -21,9 +20,7 @@ def test_get_info_not_found(monkeypatch):
 @responses.activate
 def test_start_not_found(monkeypatch):
     """Tests that start raises a ValueError when the task is not found."""
-    monkeypatch.setattr(
-        "tidy3d.web.core.task_core.SimulationTask.get", lambda *args, **kwargs: None
-    )
+    monkeypatch.setattr(f"{task_core_path}.SimulationTask.get", lambda *args, **kwargs: None)
     with pytest.raises(ValueError, match="Task not found."):
         start("non_existent_task_id")
@@ -42,9 +39,7 @@ class MockFolder:
         def list_tasks(self):
             return []

-    monkeypatch.setattr(
-        "tidy3d.web.core.task_core.Folder.get", lambda *args, **kwargs: MockFolder()
-    )
+    monkeypatch.setattr(f"{task_core_path}.Folder.get", lambda *args, **kwargs: MockFolder())

     assert get_tasks() == []
@@ -60,6 +55,9 @@ def __init__(self, created_at, task_id):
         def dict(self):
             return {"task_id": self.task_id, "created_at": self.created_at}

+        def model_dump(self):
+            return self.dict()
+
     class MockFolder:
         def list_tasks(self):
             return [
@@ -68,9 +66,7 @@ def list_tasks(self):
                 MockTask(datetime(2023, 1, 3), "3"),
             ]

-    monkeypatch.setattr(
-        "tidy3d.web.core.task_core.Folder.get", lambda *args, **kwargs: MockFolder()
-    )
+    monkeypatch.setattr(f"{task_core_path}.Folder.get", lambda *args, **kwargs: MockFolder())

     tasks = get_tasks(order="old")
     assert [t["task_id"] for t in tasks] == ["1", "2", "3"]
diff --git a/tests/test_web/test_webapi_heat.py b/tests/test_web/test_webapi_heat.py
index 571aad371f..46135cdac6 100644
--- a/tests/test_web/test_webapi_heat.py
+++ b/tests/test_web/test_webapi_heat.py
@@ -7,6 +7,7 @@
 from responses import matchers

 import tidy3d as td
+from tests.test_web.test_webapi import task_core_path
 from tidy3d import HeatSimulation
 from tidy3d.web.api.asynchronous import run_async
 from tidy3d.web.api.container import Batch, Job
@@ -34,7 +35,6 @@
 EST_FLEX_UNIT = 11.11
 FILE_SIZE_GB = 4.0

-task_core_path = "tidy3d.web.core.task_core"
 api_path = "tidy3d.web.api.webapi"
@@ -86,7 +86,7 @@ def mock_upload(monkeypatch, set_api_key):
     def mock_upload_file(*args, **kwargs):
         pass

-    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
+    monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file)


 @pytest.fixture
@@ -267,7 +267,7 @@ def mock_download(*args, **kwargs):
         pass

     def get_str(*args, **kwargs):
-        return sim.json().encode("utf-8")
+        return sim.model_dump_json().encode("utf-8")

     monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download)
     monkeypatch.setattr(f"{task_core_path}.read_simulation_from_hdf5", get_str)
@@ -357,7 +357,7 @@ def test_batch(mock_webapi, mock_job_status, tmp_path):

 @responses.activate
-def test_async(mock_webapi, mock_job_status):
+def test_async(mock_webapi, mock_job_status, tmp_path):
     # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success"))
     sims = {TASK_NAME: make_heat_sim()}
-    _ = run_async(sims, folder_name=PROJECT_NAME)
+    _ = run_async(sims, folder_name=PROJECT_NAME, path_dir=str(tmp_path))
diff --git a/tests/test_web/test_webapi_mode.py b/tests/test_web/test_webapi_mode.py
index 93e4a2a8bf..ebace6a4ac 100644
--- a/tests/test_web/test_webapi_mode.py
+++ b/tests/test_web/test_webapi_mode.py
@@ -8,7 +8,8 @@
 from responses import matchers

 import tidy3d as td
-from tidy3d.components.data.dataset import ModeIndexDataArray
+from tests.test_web.test_webapi import task_core_path
+from tidy3d.components.data.data_array import ModeIndexDataArray
 from tidy3d.plugins.mode import ModeSolver
 from tidy3d.web.api.asynchronous import run_async
 from tidy3d.web.api.container import Batch, Job
@@ -35,12 +36,18 @@
 EST_FLEX_UNIT = 11.11
 FILE_SIZE_GB = 4.0

-task_core_path = "tidy3d.web.core.task_core"
 api_path = "tidy3d.web.api.webapi"

 f, AX = plt.subplots()


+@pytest.fixture
+def unique_project_name(tmp_path):
+    """Generate unique project name using tmp_path's unique directory."""
+    # tmp_path.name gives us something like 'test_upload0'
+    return f"default_{tmp_path.name}"
+
+
 def make_mode_sim():
     """Simple mode solver"""
@@ -57,7 +64,7 @@ def make_mode_sim():
         simulation=simulation,
         plane=td.Box(center=(0, 0, 0), size=(1, 1, 0)),
         mode_spec=mode_spec,
-        freqs=[2e14],
+        freqs=(2e14,),
         direction="-",
     )
     return ms
@@ -73,13 +80,13 @@ def set_api_key(monkeypatch):

 @pytest.fixture
-def mock_upload(monkeypatch, set_api_key):
+def mock_upload(monkeypatch, set_api_key, unique_project_name):
     """Mocks webapi.upload."""

     responses.add(
         responses.GET,
         f"{Env.current.web_api_endpoint}/tidy3d/project",
-        match=[matchers.query_param_matcher({"projectName": PROJECT_NAME})],
-        json={"data": {"projectId": FOLDER_ID, "projectName": PROJECT_NAME}},
+        match=[matchers.query_param_matcher({"projectName": unique_project_name})],
+        json={"data": {"projectId": FOLDER_ID, "projectName": unique_project_name}},
         status=200,
     )
@@ -118,9 +125,9 @@ def mock_upload_file(*args, **kwargs):
         pass

     monkeypatch.setattr(
-        "tidy3d.web.core.task_core.SimulationTask.upload_simulation", mock_upload_simulation
+        f"{task_core_path}.SimulationTask.upload_simulation", mock_upload_simulation
     )
-    monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file)
+    monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file)

     return uploaded_stub
@@ -274,21 +281,21 @@ def mock_webapi(

 @responses.activate
-def test_upload(monkeypatch, mock_upload, mock_get_info, mock_metadata):
+def test_upload(monkeypatch, mock_upload, mock_get_info, mock_metadata, unique_project_name):
     sim = make_mode_sim()
     assert sim != get_reduced_simulation(sim, reduce_simulation=True)
-    assert upload(sim, TASK_NAME, PROJECT_NAME, reduce_simulation=True)
+    assert upload(sim, TASK_NAME, unique_project_name, reduce_simulation=True)


 @pytest.mark.parametrize("reduce_simulation", [True, False])
 @responses.activate
 def test_upload_with_reduction_parameter(
-    monkeypatch, mock_upload, mock_get_info, mock_metadata, reduce_simulation
+    monkeypatch, mock_upload, mock_get_info, mock_metadata, reduce_simulation, unique_project_name
 ):
     """Test that simulation reduction is properly applied before upload based on reduce_simulation
     parameter."""
     sim = make_mode_sim()
-    upload(sim, TASK_NAME, PROJECT_NAME, reduce_simulation=reduce_simulation)
reduce_simulation=reduce_simulation) + upload(sim, TASK_NAME, unique_project_name, reduce_simulation=reduce_simulation) if reduce_simulation: expected_sim = get_reduced_simulation(sim, reduce_simulation=True) @@ -324,13 +331,14 @@ def mock_download(*args, **kwargs): pass def get_str(*args, **kwargs): - return sim.json().encode("utf-8") + return sim.model_dump_json().encode("utf-8") monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download) monkeypatch.setattr(f"{task_core_path}.read_simulation_from_hdf5", get_str) fname_tmp = str(tmp_path / "web_test_tmp.json") download_json(TASK_ID, fname_tmp) + assert ModeSolver.from_file(fname_tmp) == sim @@ -345,13 +353,13 @@ def mock_download(*args, **kwargs): @responses.activate -def test_run(mock_webapi, monkeypatch, tmp_path): +def test_run(mock_webapi, monkeypatch, tmp_path, unique_project_name): sim = make_mode_sim() monkeypatch.setattr(f"{api_path}.load", lambda *args, **kwargs: True) assert run( sim, task_name=TASK_NAME, - folder_name=PROJECT_NAME, + folder_name=unique_project_name, path=str(tmp_path / "web_test_tmp.json"), ) @@ -379,10 +387,10 @@ def test_abort_task(set_api_key, mock_get_info): @responses.activate -def test_job(mock_webapi, monkeypatch, tmp_path): +def test_job(mock_webapi, monkeypatch, tmp_path, unique_project_name): monkeypatch.setattr("tidy3d.web.api.container.Job.load", lambda *args, **kwargs: True) sim = make_mode_sim() - j = Job(simulation=sim, task_name=TASK_NAME, folder_name=PROJECT_NAME) + j = Job(simulation=sim, task_name=TASK_NAME, folder_name=unique_project_name) _ = j.run(path=str(tmp_path / "web_test_tmp.json")) _ = j.status @@ -399,13 +407,13 @@ def mock_job_status(monkeypatch): @responses.activate -def test_batch(mock_webapi, mock_job_status, tmp_path, monkeypatch): +def test_batch(mock_webapi, mock_job_status, tmp_path, monkeypatch, unique_project_name): # monkeypatch.setattr("tidy3d.web.api.container.Batch.monitor", lambda self: time.sleep(0.1)) # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success")) monkeypatch.setattr(f"{api_path}.load", lambda *args, **kwargs: True) sims = {TASK_NAME: make_mode_sim()} - b = Batch(simulations=sims, folder_name=PROJECT_NAME) + b = Batch(simulations=sims, folder_name=unique_project_name) b.estimate_cost() _ = b.run(path_dir=str(tmp_path)) assert b.real_cost() == FLEX_UNIT * len(sims) @@ -415,16 +423,16 @@ def test_batch(mock_webapi, mock_job_status, tmp_path, monkeypatch): @responses.activate -def test_async(mock_webapi, mock_job_status, monkeypatch): +def test_async(mock_webapi, mock_job_status, monkeypatch, tmp_path): # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success")) monkeypatch.setattr(f"{api_path}.load", lambda *args, **kwargs: True) sims = {TASK_NAME: make_mode_sim()} - _ = run_async(sims, folder_name=PROJECT_NAME) + _ = run_async(sims, folder_name=PROJECT_NAME, path_dir=str(tmp_path)) @responses.activate -def test_patch_data(mock_webapi, monkeypatch, tmp_path): +def test_patch_data(mock_webapi, monkeypatch, tmp_path, unique_project_name): """Test that mode solver is patched with remote data after run""" def get_sim_and_data(): @@ -453,7 +461,7 @@ def check_patched(result, sim, data_local, data_remote): result = run( sim, task_name=TASK_NAME, - folder_name=PROJECT_NAME, + folder_name=unique_project_name, path=str(tmp_path / "web_test_tmp.json"), ) @@ -463,7 +471,7 @@ def check_patched(result, sim, data_local, data_remote): sim, data_local, data_remote = get_sim_and_data() - 
j = Job(simulation=sim, task_name=TASK_NAME, folder_name=PROJECT_NAME) + j = Job(simulation=sim, task_name=TASK_NAME, folder_name=unique_project_name) result = j.run(path=str(tmp_path / "web_test_tmp.json")) @@ -474,7 +482,7 @@ def check_patched(result, sim, data_local, data_remote): sim, data_local, data_remote = get_sim_and_data() sims = {TASK_NAME: sim} - b = Batch(simulations=sims, folder_name=PROJECT_NAME) + b = Batch(simulations=sims, folder_name=unique_project_name) result = b.run(path_dir=str(tmp_path)) check_patched(result[TASK_NAME], sim, data_local, data_remote) diff --git a/tests/test_web/test_webapi_mode_sim.py b/tests/test_web/test_webapi_mode_sim.py index 052591beab..332e49d324 100644 --- a/tests/test_web/test_webapi_mode_sim.py +++ b/tests/test_web/test_webapi_mode_sim.py @@ -7,6 +7,7 @@ from responses import matchers import tidy3d as td +from tests.test_web.test_webapi import task_core_path from tidy3d.plugins.mode import ModeSolver from tidy3d.web.api.asynchronous import run_async from tidy3d.web.api.container import Batch, Job @@ -33,7 +34,6 @@ EST_FLEX_UNIT = 11.11 FILE_SIZE_GB = 4.0 -task_core_path = "tidy3d.web.core.task_core" api_path = "tidy3d.web.api.webapi" @@ -59,6 +59,13 @@ def make_mode_sim(): return td.ModeSimulation.from_mode_solver(ms) +@pytest.fixture +def unique_project_name(tmp_path): + """Generate unique project name using tmp_path's unique directory.""" + # tmp_path.name gives us something like 'test_upload0' + return f"default_{tmp_path.name}" + + @pytest.fixture def set_api_key(monkeypatch): """Set the api key.""" @@ -69,13 +76,13 @@ def set_api_key(monkeypatch): @pytest.fixture -def mock_upload(monkeypatch, set_api_key): +def mock_upload(monkeypatch, set_api_key, unique_project_name): """Mocks webapi.upload.""" responses.add( responses.GET, f"{Env.current.web_api_endpoint}/tidy3d/project", - match=[matchers.query_param_matcher({"projectName": PROJECT_NAME})], - json={"data": {"projectId": FOLDER_ID, "projectName": PROJECT_NAME}}, + match=[matchers.query_param_matcher({"projectName": unique_project_name})], + json={"data": {"projectId": FOLDER_ID, "projectName": unique_project_name}}, status=200, ) @@ -114,9 +121,9 @@ def mock_upload_file(*args, **kwargs): pass monkeypatch.setattr( - "tidy3d.web.core.task_core.SimulationTask.upload_simulation", mock_upload_simulation + f"{task_core_path}.SimulationTask.upload_simulation", mock_upload_simulation ) - monkeypatch.setattr("tidy3d.web.core.task_core.upload_file", mock_upload_file) + monkeypatch.setattr(f"{task_core_path}.upload_file", mock_upload_file) return uploaded_stub @@ -270,21 +277,21 @@ def mock_webapi( @responses.activate -def test_upload(monkeypatch, mock_upload, mock_get_info, mock_metadata): +def test_upload(monkeypatch, mock_upload, mock_get_info, mock_metadata, unique_project_name): sim = make_mode_sim() assert sim != get_reduced_simulation(sim, reduce_simulation=True) - assert upload(sim, TASK_NAME, PROJECT_NAME, reduce_simulation=True) + assert upload(sim, TASK_NAME, unique_project_name, reduce_simulation=True) @pytest.mark.parametrize("reduce_simulation", [True, False]) @responses.activate def test_upload_with_reduction_parameter( - monkeypatch, mock_upload, mock_get_info, mock_metadata, reduce_simulation + monkeypatch, mock_upload, mock_get_info, mock_metadata, reduce_simulation, unique_project_name ): """Test that simulation reduction is properly applied before upload based on reduce_simulation parameter.""" sim = make_mode_sim() - upload(sim, TASK_NAME, PROJECT_NAME, 
reduce_simulation=reduce_simulation) + upload(sim, TASK_NAME, unique_project_name, reduce_simulation=reduce_simulation) if reduce_simulation: expected_sim = get_reduced_simulation(sim, reduce_simulation=True) @@ -320,7 +327,7 @@ def mock_download(*args, **kwargs): pass def get_str(*args, **kwargs): - return sim.json().encode("utf-8") + return sim.model_dump_json().encode("utf-8") monkeypatch.setattr(f"{task_core_path}.download_gz_file", mock_download) monkeypatch.setattr(f"{task_core_path}.read_simulation_from_hdf5", get_str) @@ -341,13 +348,13 @@ def mock_download(*args, **kwargs): @responses.activate -def test_run(mock_webapi, monkeypatch, tmp_path): +def test_run(mock_webapi, monkeypatch, tmp_path, unique_project_name): sim = make_mode_sim() monkeypatch.setattr(f"{api_path}.load", lambda *args, **kwargs: True) assert run( sim, task_name=TASK_NAME, - folder_name=PROJECT_NAME, + folder_name=unique_project_name, path=str(tmp_path / "web_test_tmp.json"), ) @@ -375,10 +382,10 @@ def test_abort_task(set_api_key, mock_get_info): @responses.activate -def test_job(mock_webapi, monkeypatch, tmp_path): +def test_job(mock_webapi, monkeypatch, tmp_path, unique_project_name): monkeypatch.setattr("tidy3d.web.api.container.Job.load", lambda *args, **kwargs: True) sim = make_mode_sim() - j = Job(simulation=sim, task_name=TASK_NAME, folder_name=PROJECT_NAME) + j = Job(simulation=sim, task_name=TASK_NAME, folder_name=unique_project_name) _ = j.run(path=str(tmp_path / "web_test_tmp.json")) _ = j.status @@ -395,12 +402,12 @@ def mock_job_status(monkeypatch): @responses.activate -def test_batch(mock_webapi, mock_job_status, tmp_path): +def test_batch(mock_webapi, mock_job_status, tmp_path, unique_project_name): # monkeypatch.setattr("tidy3d.web.api.container.Batch.monitor", lambda self: time.sleep(0.1)) # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success")) sims = {TASK_NAME: make_mode_sim()} - b = Batch(simulations=sims, folder_name=PROJECT_NAME) + b = Batch(simulations=sims, folder_name=unique_project_name) b.estimate_cost() _ = b.run(path_dir=str(tmp_path)) assert b.real_cost() == FLEX_UNIT * len(sims) @@ -410,7 +417,7 @@ def test_batch(mock_webapi, mock_job_status, tmp_path): @responses.activate -def test_async(mock_webapi, mock_job_status): +def test_async(mock_webapi, mock_job_status, tmp_path): # monkeypatch.setattr("tidy3d.web.api.container.Job.status", property(lambda self: "success")) sims = {TASK_NAME: make_mode_sim()} - _ = run_async(sims, folder_name=PROJECT_NAME) + _ = run_async(sims, folder_name=PROJECT_NAME, path_dir=str(tmp_path)) diff --git a/tests/utils.py b/tests/utils.py index adf28676ad..bf641d29a1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,11 +5,11 @@ from typing import Any, Optional, Union import numpy as np -import pydantic.v1 as pd import trimesh import xarray as xr from autograd.core import VJPNode from autograd.tracer import new_box +from pydantic import Field import tidy3d as td from tidy3d import ModeIndexDataArray @@ -34,7 +34,7 @@ size=(10.0, 10.0, 10.0), grid_spec=td.GridSpec(wavelength=1.0), run_time=1e-13, - monitors=[ + monitors=( td.FieldMonitor(size=(1, 1, 1), center=(0, 1, 0), freqs=FREQS, name="field_freq"), td.FieldTimeMonitor(size=(1, 1, 0), center=(1, 0, 0), interval=10, name="field_time"), td.FluxMonitor(size=(1, 1, 0), center=(0, 0, 0), freqs=FREQS, name="flux_freq"), @@ -46,7 +46,7 @@ mode_spec=td.ModeSpec(num_modes=3), name="mode", ), - ], + ), boundary_spec=td.BoundarySpec.all_sides(boundary=td.Periodic()), ) 
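The list-to-tuple conversions above are part of the pydantic v2 migration applied throughout this PR (pd.Field → Field, .dict() → .model_dump(), .json() → .model_dump_json(), update_forward_refs() → model_rebuild()). A minimal sketch of why tuple-valued fields matter for immutable models follows; the plain BaseModel with frozen=True is an assumption standing in for Tidy3dBaseModel here, not its actual definition:

    from pydantic import BaseModel, ConfigDict, Field

    class Example(BaseModel):
        # frozen models are hashable only if every field value is hashable,
        # so container fields are declared as tuples rather than lists
        model_config = ConfigDict(frozen=True)
        monitors: tuple[str, ...] = Field(default=(), title="Monitors")

    ex = Example(monitors=("field", "flux"))
    hash(ex)  # works because tuples are hashable; a list-valued field would raise TypeError here
    # ex.monitors = ()  # raises a pydantic error: frozen instances reject mutation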
@@ -411,7 +411,7 @@ def make_custom_data(lims, unstructured): SIM_FULL = td.Simulation( size=(8.0, 8.0, 8.0), run_time=1e-12, - structures=[ + structures=( td.Structure( geometry=td.Cylinder(length=1, center=(-1 * tracer, 0, 0), radius=tracer, axis=2), medium=td.Medium(permittivity=1 + tracer, name="dieletric"), @@ -491,7 +491,7 @@ def make_custom_data(lims, unstructured): name="fully_anisotropic_box", ), td.Structure( - geometry=td.GeometryGroup(geometries=[td.Box(size=(1, 1, 1), center=(-1, 0, 0))]), + geometry=td.GeometryGroup(geometries=(td.Box(size=(1, 1, 1), center=(-1, 0, 0)),)), medium=td.PEC, name="pec_group", ), @@ -510,7 +510,7 @@ def make_custom_data(lims, unstructured): ), td.Structure( geometry=td.PolySlab( - vertices=[(-1.5, -1.5), (-0.5, -1.5), (-0.5, -0.5)], slab_bounds=[-1, 1] + vertices=[(-1.5, -1.5), (-0.5, -1.5), (-0.5, -0.5)], slab_bounds=(-1, 1) ), medium=td.PoleResidue( eps_inf=1.0, poles=((6206417594288582j, (-3.311074436985222e16j)),) @@ -630,7 +630,7 @@ def make_custom_data(lims, unstructured): ), td.Structure( geometry=td.PolySlab( - vertices=[(-1.5, -1.5), (-0.5, -1.5), (-0.5, -0.5)], slab_bounds=[-1, 1] + vertices=[(-1.5, -1.5), (-0.5, -1.5), (-0.5, -0.5)], slab_bounds=(-1, 1) ), medium=td.PoleResidue( eps_inf=1.0, poles=((6206417594288582j, (-3.311074436985222e16j)),) @@ -689,8 +689,8 @@ def make_custom_data(lims, unstructured): name="SiO2", ), ), - ], - sources=[ + ), + sources=( td.UniformCurrentSource( size=(0, 0, 0), center=(0, 0.5, 0), @@ -806,10 +806,10 @@ def make_custom_data(lims, unstructured): freq0=2e14, fwidth=4e13, values=np.linspace(0, 10, 1000), dt=1e-12 / 100 ), ), - ], + ), monitors=( td.FieldMonitor( - size=(0, 0, 0), center=(0, 0, 0), fields=["Ex"], freqs=[1.5e14, 2e14], name="field" + size=(0, 0, 0), center=(0, 0, 0), fields=("Ex",), freqs=[1.5e14, 2e14], name="field" ), td.FieldTimeMonitor(size=(0, 0, 0), center=(0, 0, 0), name="field_time", interval=100), td.AuxFieldTimeMonitor( @@ -882,7 +882,7 @@ def make_custom_data(lims, unstructured): freqs=[1e14, 2e14], ), ), - lumped_elements=[ + lumped_elements=( td.LumpedResistor( center=(2, 2, 0), size=(0.2, 0.2, 0), name="Resistor", resistance=42, voltage_axis=0 ), @@ -901,7 +901,7 @@ def make_custom_data(lims, unstructured): network=td.RLCNetwork(inductance=1e-9, capacitance=10e-12, network_topology="parallel"), voltage_axis=0, ), - ], + ), symmetry=(0, 0, 0), boundary_spec=td.BoundarySpec( x=td.Boundary(plus=td.PML(num_layers=20), minus=td.Absorber(num_layers=100)), @@ -915,12 +915,12 @@ def make_custom_data(lims, unstructured): grid_x=td.AutoGrid(), grid_y=td.CustomGrid(dl=100 * [0.04]), grid_z=td.UniformGrid(dl=0.05), - override_structures=[ + override_structures=( td.Structure( geometry=td.Box(size=(1, 1, 1), center=(-1, 0, 0)), medium=td.Medium(permittivity=2.0), - ) - ], + ), + ), ), ) @@ -1131,7 +1131,7 @@ def get_spatial_coords_dict(simulation: td.Simulation, monitor: td.Monitor, fiel """Returns MonitorData coordinates associated with a Monitor object""" grid = simulation.discretize_monitor(monitor) spatial_coords = grid.boundaries if monitor.colocate else grid[field_name] - spatial_coords_dict = spatial_coords.dict() + spatial_coords_dict = spatial_coords.model_dump() coords = {} for axis, dim in enumerate("xyz"): @@ -1156,7 +1156,7 @@ def make_data( ) -> td.components.data.data_array.DataArray: """make a random DataArray out of supplied coordinates and data_type.""" data_shape = [len(coords[k]) for k in data_array_type._dims] - np.random.seed(1) + np.random.seed(0) data = 
DATA_GEN_FN(data_shape) data = (1 + 0.5j) * data if is_complex else data @@ -1292,7 +1292,7 @@ def make_microwave_mode_solver_data( def make_eps_data(monitor: td.PermittivityMonitor) -> td.PermittivityData: """make a random PermittivityData from a PermittivityMonitor.""" - field_mnt = td.FieldMonitor(**monitor.dict(exclude={"type", "fields"})) + field_mnt = td.FieldMonitor(**monitor.model_dump(exclude={"type", "fields"})) field_data = make_field_data(monitor=field_mnt) return td.PermittivityData( monitor=monitor, @@ -1599,7 +1599,7 @@ def make_flux_time_data(monitor: td.FluxTimeMonitor) -> td.FluxTimeData: td.FluxTimeMonitor: make_flux_time_data, } - data = [MONITOR_MAKER_MAP[type(mnt)](mnt) for mnt in simulation.monitors] + data = tuple(MONITOR_MAKER_MAP[type(mnt)](mnt) for mnt in simulation.monitors) sim_data = td.SimulationData(simulation=simulation, data=data) if path is not None: @@ -1611,14 +1611,14 @@ def make_flux_time_data(monitor: td.FluxTimeMonitor) -> td.FluxTimeData: class BatchDataTest(Tidy3dBaseModel): """Holds a collection of :class:`.SimulationData` returned by :class:`.Batch`.""" - task_paths: dict[str, str] = pd.Field( - ..., + task_paths: dict[str, str] = Field( title="Data Paths", description="Mapping of task_name to path to corresponding data for each task in batch.", ) - task_ids: dict[str, str] = pd.Field( - ..., title="Task IDs", description="Mapping of task_name to task_id for each task in batch." + task_ids: dict[str, str] = Field( + title="Task IDs", + description="Mapping of task_name to task_id for each task in batch.", ) sim_data: dict[str, td.SimulationData] diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 13934377f7..7923290bad 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -2,6 +2,9 @@ from __future__ import annotations +# ruff: noqa: I001 - ensure config is imported first +from .config import config + from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.boundary import BroadbandModeABCFitterParam, BroadbandModeABCSpec from tidy3d.components.data.index import SimulationDataMap @@ -455,8 +458,6 @@ from .components.transformation import RotationAroundAxis from .components.viz import VisualizationSpec, restore_matplotlib_rcparams -# config -from .config import config # constants imported as `C_0 = td.C_0` or `td.constants.C_0` from .constants import C_0, EPSILON_0, ETA_0, HBAR, K_B, MU_0, Q_e, inf @@ -484,9 +485,9 @@ def set_logging_level(level: str) -> None: log.info(f"Using client version: {__version__}") -Transformed.update_forward_refs() -ClipOperation.update_forward_refs() -GeometryGroup.update_forward_refs() +Transformed.model_rebuild() +ClipOperation.model_rebuild() +GeometryGroup.model_rebuild() # Backwards compatibility: Remove 2.11 renamed integral classes VoltageIntegralAxisAligned = AxisAlignedVoltageIntegral diff --git a/tidy3d/__main__.py b/tidy3d/__main__.py index 743a552d02..7aa5fa4310 100644 --- a/tidy3d/__main__.py +++ b/tidy3d/__main__.py @@ -9,7 +9,7 @@ from tidy3d.web import Job -def main(args) -> None: +def main(args: list[str]) -> None: """Parse args and run the corresponding tidy3d simulation.""" parser = argparse.ArgumentParser(description="Tidy3D") diff --git a/tidy3d/_common/__init__.py b/tidy3d/_common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/_runtime.py b/tidy3d/_common/_runtime.py new file mode 100644 index 0000000000..6dbf61accd --- /dev/null +++ b/tidy3d/_common/_runtime.py @@ -0,0 +1,12 @@ +"""Runtime environment detection for tidy3d. 
+ +This module must have ZERO dependencies on other tidy3d modules to avoid +circular imports. It is imported very early in the initialization chain. +""" + +from __future__ import annotations + +import sys + +# Detect WASM/Pyodide environment where web and filesystem features are unavailable +WASM_BUILD = "pyodide" in sys.modules or sys.platform == "emscripten" diff --git a/tidy3d/_common/compat.py b/tidy3d/_common/compat.py new file mode 100644 index 0000000000..a616a41895 --- /dev/null +++ b/tidy3d/_common/compat.py @@ -0,0 +1,31 @@ +"""Compatibility layer for handling differences between package versions.""" + +from __future__ import annotations + +import importlib.metadata + from functools import cache + +from packaging.version import parse + +try: + from xarray.structure import alignment +except ImportError: + from xarray.core import alignment + +try: + from numpy import trapezoid as np_trapezoid +except ImportError: # NumPy < 2.0 + from numpy import trapz as np_trapezoid + +try: + from typing import Self, TypeAlias # Python >= 3.11 +except ImportError: # Python < 3.11 + from typing_extensions import Self, TypeAlias + + +@cache +def _package_is_older_than(package: str, version: str) -> bool: + return parse(importlib.metadata.version(package)) < parse(version) + + +__all__ = ["Self", "TypeAlias", "_package_is_older_than", "alignment", "np_trapezoid"] diff --git a/tidy3d/_common/components/__init__.py b/tidy3d/_common/components/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/autograd/__init__.py b/tidy3d/_common/components/autograd/__init__.py new file mode 100644 index 0000000000..3b5b36e033 --- /dev/null +++ b/tidy3d/_common/components/autograd/__init__.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from tidy3d._common.components.autograd.boxes import TidyArrayBox +from tidy3d._common.components.autograd.functions import interpn +from tidy3d._common.components.autograd.types import ( + AutogradFieldMap, + InterpolationType, + PathType, + TracedArrayFloat2D, + TracedArrayLike, + TracedComplex, + TracedCoordinate, + TracedFloat, + TracedPoleAndResidue, + TracedPolesAndResidues, + TracedPositiveFloat, + TracedSize, + TracedSize1D, +) +from tidy3d._common.components.autograd.utils import get_static, hasbox, is_tidy_box, split_list + +__all__ = [ + "AutogradFieldMap", + "InterpolationType", + "PathType", + "TidyArrayBox", + "TracedArrayFloat2D", + "TracedArrayLike", + "TracedComplex", + "TracedCoordinate", + "TracedFloat", + "TracedPoleAndResidue", + "TracedPolesAndResidues", + "TracedPositiveFloat", + "TracedSize", + "TracedSize1D", + "get_static", + "hasbox", + "interpn", + "is_tidy_box", + "split_list", +] diff --git a/tidy3d/_common/components/autograd/boxes.py b/tidy3d/_common/components/autograd/boxes.py new file mode 100644 index 0000000000..d51e948a85 --- /dev/null +++ b/tidy3d/_common/components/autograd/boxes.py @@ -0,0 +1,162 @@ +# Adds some functionality to the autograd arraybox and related autograd patches +# NOTE: we do not subclass ArrayBox since that would break autograd's internal checks +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +from autograd.extend import VJPNode, defjvp, register_notrace +from autograd.numpy.numpy_boxes import ArrayBox +from autograd.numpy.numpy_wrapper import _astype + +if TYPE_CHECKING: + from typing import Callable + +TidyArrayBox = ArrayBox # NOT a subclass + +_autograd_module_cache = {} # cache for imported 
autograd modules + +register_notrace(VJPNode, anp.full_like) + +defjvp( + _astype, + lambda g, ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: _astype(g, dtype), +) + +anp.astype = _astype +anp.permute_dims = anp.transpose + + +@classmethod +def from_arraybox(cls: Any, box: ArrayBox) -> TidyArrayBox: + """Construct a TidyArrayBox from an ArrayBox.""" + return cls(box._value, box._trace, box._node) + + +def __array_function__( + self: Any, + func: Callable, + types: list[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> Any: + """ + Handle the dispatch of NumPy functions to autograd's numpy implementation. + + Parameters + ---------- + self : Any + The instance of the class. + func : Callable + The NumPy function being called. + types : list[Any] + The types of the arguments that implement __array_function__. + args : tuple[Any, ...] + The positional arguments to the function. + kwargs : dict[str, Any] + The keyword arguments to the function. + + Returns + ------- + Any + The result of the function call, or NotImplemented. + + Raises + ------ + NotImplementedError + If the function is not implemented in autograd.numpy. + + See Also + -------- + https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_function__ + """ + if not all(t in TidyArrayBox.type_mappings for t in types): + return NotImplemented + + module_name = func.__module__ + + if module_name.startswith("numpy"): + anp_module_name = "autograd." + module_name + else: + return NotImplemented + + # Use the cached module if available + anp_module = _autograd_module_cache.get(anp_module_name) + if anp_module is None: + try: + anp_module = importlib.import_module(anp_module_name) + _autograd_module_cache[anp_module_name] = anp_module + except ImportError: + return NotImplemented + + f = getattr(anp_module, func.__name__, None) + if f is None: + return NotImplemented + + if f.__name__ == "nanmean": # somehow xarray always dispatches to nanmean + f = anp.mean + kwargs.pop("dtype", None) # autograd mean vjp doesn't support dtype + + return f(*args, **kwargs) + + +def __array_ufunc__( + self: Any, + ufunc: Callable, + method: str, + *inputs: Any, + **kwargs: dict[str, Any], +) -> Any: + """ + Handle the dispatch of NumPy ufuncs to autograd's numpy implementation. + + Parameters + ---------- + self : Any + The instance of the class. + ufunc : Callable + The universal function being called. + method : str + The method of the ufunc being called. + inputs : Any + The input arguments to the ufunc. + kwargs : dict[str, Any] + The keyword arguments to the ufunc. + + Returns + ------- + Any + The result of the ufunc call, or NotImplemented. 
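+ + Examples + -------- + A minimal sketch of the dispatch path, assuming ``x`` is a traced + ``TidyArrayBox`` inside an autograd computation:: + + out = np.exp(x) # unary ufunc, forwarded to autograd.numpy.exp + out = np.add(x, 1.0) # binary ufuncs follow the same path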
+ + See Also + -------- + https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__ + """ + if method != "__call__": + return NotImplemented + + ufunc_name = ufunc.__name__ + + anp_ufunc = getattr(anp, ufunc_name, None) + if anp_ufunc is not None: + return anp_ufunc(*inputs, **kwargs) + + return NotImplemented + + +def item(self: Any) -> Any: + if self.size != 1: + raise ValueError("Can only convert an array of size 1 to a scalar") + return anp.ravel(self)[0] + + +TidyArrayBox._tidy = True +TidyArrayBox.from_arraybox = from_arraybox +TidyArrayBox.__array_namespace__ = lambda self, *, api_version=None: anp +TidyArrayBox.__array_ufunc__ = __array_ufunc__ +TidyArrayBox.__array_function__ = __array_function__ +TidyArrayBox.real = property(anp.real) +TidyArrayBox.imag = property(anp.imag) +TidyArrayBox.conj = anp.conj +TidyArrayBox.item = item diff --git a/tidy3d/_common/components/autograd/derivative_utils.py b/tidy3d/_common/components/autograd/derivative_utils.py new file mode 100644 index 0000000000..7b2ff1d1da --- /dev/null +++ b/tidy3d/_common/components/autograd/derivative_utils.py @@ -0,0 +1,1057 @@ +"""Utilities for autograd derivative computation and field gradient evaluation.""" + +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from functools import reduce +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +import xarray as xr +from numpy.typing import NDArray + +from tidy3d._common.components.data.data_array import FreqDataArray, ScalarFieldDataArray +from tidy3d._common.components.types.base import ArrayLike, Bound +from tidy3d._common.config import config +from tidy3d._common.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0 +from tidy3d._common.log import log + +from .types import PathType +from .utils import get_static + +if TYPE_CHECKING: + from typing import Callable, Union + + from tidy3d._common.compat import Self + from tidy3d._common.components.types.base import xyz + + +FieldData = dict[str, ScalarFieldDataArray] +PermittivityData = dict[str, ScalarFieldDataArray] +EpsType = FreqDataArray +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] + + +class LazyInterpolator: + """Lazy wrapper for interpolators that creates them on first access.""" + + def __init__(self, creator_func: Callable[[], Callable[[ArrayFloat], ArrayComplex]]) -> None: + """Initialize with a function that creates the interpolator when called.""" + self.creator_func = creator_func + self._interpolator: Optional[Callable[[ArrayFloat], ArrayComplex]] = None + + def __call__(self, *args: Any, **kwargs: Any) -> ArrayComplex: + """Create interpolator on first call and delegate to it.""" + if self._interpolator is None: + self._interpolator = self.creator_func() + return self._interpolator(*args, **kwargs) + + +@dataclass +class DerivativeInfo: + """Stores derivative information passed to the ``._compute_derivatives`` methods. + + This dataclass contains all the field data and parameters needed for computing + gradients with respect to geometry perturbations. + """ + + # Required fields + paths: list[PathType] + """List of paths to the traced fields that need derivatives calculated.""" + + E_der_map: FieldData + """Electric field gradient map. + Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication + of the forward and adjoint electric fields. The tangential components of this + dataset are used when computing adjoint gradients for shifting boundaries. 
+ All components are used when computing volume-based gradients.""" + + D_der_map: FieldData + """Displacement field gradient map. + Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication + of the forward and adjoint displacement fields. The normal component of this + dataset is used when computing adjoint gradients for shifting boundaries.""" + + E_fwd: FieldData + """Forward electric fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the forward + electric fields used for computing gradients for a given structure.""" + + E_adj: FieldData + """Adjoint electric fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint + electric fields used for computing gradients for a given structure.""" + + D_fwd: FieldData + """Forward displacement fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the forward + displacement fields used for computing gradients for a given structure.""" + + D_adj: FieldData + """Adjoint displacement fields. + Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint + displacement fields used for computing gradients for a given structure.""" + + eps_data: PermittivityData + """Permittivity dataset. + Dataset of relative permittivity values along all three dimensions. + Used for automatically computing permittivity inside or outside of a simple geometry.""" + + eps_in: EpsType + """Permittivity inside the Structure. + Typically computed from Structure.medium.eps_model. + Used when it cannot be computed from eps_data or when eps_approx=True.""" + + eps_out: EpsType + """Permittivity outside the Structure. + Typically computed from Simulation.medium.eps_model. + Used when it cannot be computed from eps_data or when eps_approx=True.""" + + bounds: Bound + """Geometry bounds. + Bounds corresponding to the structure, used in Medium calculations.""" + + bounds_intersect: Bound + """Geometry and simulation intersection bounds. + Bounds corresponding to the minimum intersection between the structure + and the simulation it is contained in.""" + + simulation_bounds: Bound + """Simulation bounds. + Bounds corresponding to the simulation domain containing this structure. + Unlike bounds_intersect, this is independent of the structure's bounds and + is purely based on the simulation geometry.""" + + frequencies: ArrayLike + """Frequencies at which the adjoint gradient should be computed.""" + + # Optional fields with defaults + + H_der_map: Optional[FieldData] = None + """Magnetic field gradient map. + Dataset where the field components ("Hx", "Hy", "Hz") store the multiplication + of the forward and adjoint magnetic fields. The tangential component of this + dataset is used when computing adjoint gradients for shifting boundaries of + structures composed of PEC mediums.""" + + H_fwd: Optional[FieldData] = None + """Forward magnetic fields. + Dataset where the field components ("Hx", "Hy", "Hz") represent the forward + magnetic fields used for computing gradients for a given structure.""" + + H_adj: Optional[FieldData] = None + """Adjoint magnetic fields. + Dataset where the field components ("Hx", "Hy", "Hz") represent the adjoint + magnetic fields used for computing gradients for a given structure.""" + + is_medium_pec: bool = False + """Indicates if structure material is PEC. 
+ If True, the structure contains a PEC material which changes the gradient + formulation at the boundary compared to the dielectric case.""" + + background_medium_is_pec: bool = False + """Indicates if the background medium is PEC. + If True, the structure is partially surrounded by a PEC material.""" + + interpolators: Optional[dict] = None + """Pre-computed interpolators. + Optional pre-computed interpolators for field components and permittivity data. + When provided, avoids redundant interpolator creation for multiple geometries + sharing the same field data. This significantly improves performance for + GeometryGroup processing.""" + + # private cache for interpolators + _interpolators_cache: dict = field(default_factory=dict, init=False, repr=False) + + def updated_copy(self, **kwargs: Any) -> Self: + """Create a copy with updated fields.""" + kwargs.pop("deep", None) + kwargs.pop("validate", None) + return replace(self, **kwargs) + + @staticmethod + def _nan_to_num_if_needed( + coords: Union[ArrayFloat, ArrayComplex], + ) -> Union[ArrayFloat, ArrayComplex]: + """Convert NaN and infinite values to finite numbers, optimized for finite inputs.""" + # skip check for small arrays + if coords.size < 1000: + return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) + + if np.isfinite(coords).all(): + return coords + return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) + + @staticmethod + def _evaluate_with_interpolators( + interpolators: dict[str, Callable[[ArrayFloat], ArrayComplex]], + coords: ArrayFloat, + ) -> dict[str, ArrayComplex]: + """Evaluate field components at coordinates using cached interpolators. + + Parameters + ---------- + interpolators : dict + Dictionary mapping field component names to ``RegularGridInterpolator`` objects. + coords : np.ndarray + Spatial coordinates (N, 3) where fields are evaluated. + + Returns + ------- + dict[str, np.ndarray] + Dictionary mapping component names to field values at coordinates. + """ + auto_cfg = config.adjoint + float_dtype = auto_cfg.gradient_dtype_float + complex_dtype = auto_cfg.gradient_dtype_complex + + coords = DerivativeInfo._nan_to_num_if_needed(coords) + if coords.dtype != float_dtype and coords.dtype != complex_dtype: + coords = coords.astype(float_dtype, copy=False) + return {name: interp(coords) for name, interp in interpolators.items()} + + def create_interpolators(self, dtype: Optional[np.dtype[Any]] = None) -> dict[str, Any]: + """Create interpolators for field components and permittivity data. + + Creates and caches ``RegularGridInterpolator`` objects for all field components + (E_fwd, E_adj, D_fwd, D_adj) and permittivity data (eps_in, eps_out). + This caching strategy significantly improves performance by avoiding + repeated interpolator construction in gradient evaluation loops. + + Parameters + ---------- + dtype : np.dtype[Any], optional = None + Data type for interpolation coordinates and values. Defaults to the + current ``config.adjoint.gradient_dtype_float``. 
+ + Returns + ------- + dict + Nested dictionary structure: + - Field data: {"E_fwd": {"Ex": interpolator, ...}, ...} + - Permittivity: {"eps_in": interpolator, "eps_out": interpolator} + """ + from scipy.interpolate import RegularGridInterpolator + + auto_cfg = config.adjoint + if dtype is None: + dtype = auto_cfg.gradient_dtype_float + complex_dtype = auto_cfg.gradient_dtype_complex + + cache_key = str(dtype) + if cache_key in self._interpolators_cache: + return self._interpolators_cache[cache_key] + + interpolators = {} + coord_cache = {} + + def _make_lazy_interpolator_group( + field_data_dict: Optional[FieldData], + group_key: Optional[str], + is_field_group: bool = True, + override_method: Optional[str] = None, + ) -> None: + """Helper to create a group of lazy interpolators.""" + if not field_data_dict: + return + if is_field_group: + interpolators[group_key] = {} + + for component_name, arr in field_data_dict.items(): + # use object ID for caching to handle shared grids + arr_id = id(arr.data) + if arr_id not in coord_cache: + points = tuple(c.data.astype(dtype, copy=False) for c in (arr.x, arr.y, arr.z)) + coord_cache[arr_id] = points + points = coord_cache[arr_id] + + def creator_func( + arr: ScalarFieldDataArray = arr, + points: tuple[np.ndarray, ...] = points, + ) -> Callable[[ArrayFloat], ArrayComplex]: + data = arr.data.astype( + complex_dtype if np.iscomplexobj(arr.data) else dtype, copy=False + ) + # create interpolator with frequency dimension + if "f" in arr.dims: + freq_coords = arr.coords["f"].data.astype(dtype, copy=False) + # ensure frequency dimension is last + if arr.dims != ("x", "y", "z", "f"): + freq_dim_idx = arr.dims.index("f") + axes = list(range(data.ndim)) + axes.append(axes.pop(freq_dim_idx)) + data = np.transpose(data, axes) + else: + # single frequency case - add singleton dimension + freq_coords = np.array([0.0], dtype=dtype) + data = data[..., np.newaxis] + + points_with_freq = (*points, freq_coords) + # If PEC, use nearest interpolation instead of linear to avoid interpolating + # with field values inside the PEC (which are 0). Instead, we make sure to + # choose interpolation points such that their nearest location is outside of + # the PEC surface. The same applies if the background_medium is marked as PEC + # since we will need to use the same interpolation strategy inside the structure + # border. 
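+ # Nearest sampling pairs with the sub-cell shift applied in + # _snap_spatial_coords_boundary, so each snapped query point lands on the + # intended side of the interface instead of averaging in the zero fields + # from inside the PEC.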
+ method = ( + "nearest" + if (self.is_medium_pec or self.background_medium_is_pec) + else "linear" + ) + if override_method is not None: + method = override_method + interpolator_obj = RegularGridInterpolator( + points_with_freq, data, method=method, bounds_error=False, fill_value=None + ) + + def interpolator(coords: ArrayFloat) -> ArrayComplex: + # coords: (N, 3) spatial points + n_points = coords.shape[0] + n_freqs = len(freq_coords) + + # build coordinates with frequency dimension + coords_with_freq = np.empty((n_points * n_freqs, 4), dtype=coords.dtype) + coords_with_freq[:, :3] = np.repeat(coords, n_freqs, axis=0) + coords_with_freq[:, 3] = np.tile(freq_coords, n_points) + + result = interpolator_obj(coords_with_freq) + return result.reshape(n_points, n_freqs) + + return interpolator + + if is_field_group: + interpolators[group_key][component_name] = LazyInterpolator(creator_func) + else: + interpolators[component_name] = LazyInterpolator(creator_func) + + # process field interpolators (nested dictionaries) + interpolator_groups = [ + ("E_fwd", self.E_fwd), + ("E_adj", self.E_adj), + ("D_fwd", self.D_fwd), + ("D_adj", self.D_adj), + ] + if self.is_medium_pec or self.background_medium_is_pec: + interpolator_groups += [("H_fwd", self.H_fwd), ("H_adj", self.H_adj)] # type: ignore[list-item] + for group_key, data_dict in interpolator_groups: + _make_lazy_interpolator_group( + data_dict, f"{group_key}_linear", is_field_group=True, override_method="linear" + ) + _make_lazy_interpolator_group( + data_dict, f"{group_key}_nearest", is_field_group=True, override_method="nearest" + ) + + if self.eps_data is not None: + _make_lazy_interpolator_group( + self.eps_data, "eps_data", is_field_group=True, override_method="nearest" + ) + + if self.eps_in is not None: + _make_lazy_interpolator_group( + {"eps_in": self.eps_in}, None, is_field_group=False, override_method="nearest" + ) + if self.eps_out is not None: + _make_lazy_interpolator_group( + {"eps_out": self.eps_out}, None, is_field_group=False, override_method="nearest" + ) + + self._interpolators_cache[cache_key] = interpolators + return interpolators + + def evaluate_gradient_at_points( + self, + spatial_coords: np.ndarray, + normals: np.ndarray, + perps1: np.ndarray, + perps2: np.ndarray, + interpolators: Optional[dict] = None, + ) -> np.ndarray: + """Compute adjoint gradients at surface points for shape optimization. + + Implements the surface integral formulation for computing gradients with respect + to geometry perturbations. + + Parameters + ---------- + spatial_coords : np.ndarray + (N, 3) array of surface evaluation points. + normals : np.ndarray + (N, 3) array of outward-pointing normal vectors at each surface point. + perps1 : np.ndarray + (N, 3) array of first tangent vectors perpendicular to normals. + perps2 : np.ndarray + (N, 3) array of second tangent vectors perpendicular to both normals and perps1. + interpolators : dict = None + Pre-computed field interpolators for efficiency. + + Returns + ------- + np.ndarray + (N,) array of gradient values at each surface point. Must be integrated + with appropriate quadrature weights to get total gradient. + """ + if interpolators is None: + raise NotImplementedError( + "Direct field evaluation without interpolators is not implemented. " + "Please create interpolators using 'create_interpolators()' first." + ) + + # In all paths below, we need to have computed the gradient integration for a + # dielectric-dielectric interface. 
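+ # For reference, the dielectric-dielectric integrand evaluated at each surface + # point is g = delta_eps * (E1_fwd * E1_adj + E2_fwd * E2_adj) - delta_eps_inv * Dn_fwd * Dn_adj, + # with delta_eps = eps_in - eps_out and delta_eps_inv = 1 / eps_in - 1 / eps_out + # (see _evaluate_dielectric_gradient_at_points).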
+ vjps_dielectric = self._evaluate_dielectric_gradient_at_points( + spatial_coords, + normals, + perps1, + perps2, + interpolators, + self.eps_in, + self.eps_out, + ) + + if self.is_medium_pec: + # The structure medium is PEC, but there may be a part of the interface that has + # dielectric placed on top of or around it where we want to use the dielectric + # gradient integration. We use the mask to choose between the PEC-dielectric and + # dielectric-dielectric parts of the border. + + # Detect PEC by looking just inside the boundary + mask_pec = self._detect_pec_gradient_points( + spatial_coords, + normals, + self.eps_in, + interpolators["eps_data"], + is_outside=False, + ) + + # Compute PEC gradients, pulling fields outside of the boundary + vjps_pec = self._evaluate_pec_gradient_at_points( + spatial_coords, + normals, + perps1, + perps2, + interpolators, + ("eps_out", self.eps_out), + is_outside=True, + ) + + vjps = mask_pec * vjps_pec + (1.0 - mask_pec) * vjps_dielectric + elif self.background_medium_is_pec: + # The structure medium is dielectric, but there may be a part of the interface that has + # PEC placed on top of or around it where we want to use the PEC gradient integration. + # We use the mask to choose between the dielectric-dielectric and PEC-dielectric parts + # of the border. + + # Detect PEC by looking just outside the boundary + mask_pec = self._detect_pec_gradient_points( + spatial_coords, + normals, + self.eps_out, + interpolators["eps_data"], + is_outside=True, + ) + + # Compute PEC gradients, pulling fields inside of the boundary and applying a negative + # sign compared to above because inside and outside definitions are switched + vjps_pec = -self._evaluate_pec_gradient_at_points( + spatial_coords, + normals, + perps1, + perps2, + interpolators, + ("eps_in", self.eps_in), + is_outside=False, + ) + + vjps = mask_pec * vjps_pec + (1.0 - mask_pec) * vjps_dielectric + else: + # The structure and its background are both assumed to be dielectric, so we use the + # dielectric-dielectric gradient integration. 
+ vjps = vjps_dielectric + + # sum over frequency dimension + vjps = np.sum(vjps, axis=-1) + + return vjps + + def _evaluate_dielectric_gradient_at_points( + self, + spatial_coords: ArrayFloat, + normals: ArrayFloat, + perps1: ArrayFloat, + perps2: ArrayFloat, + interpolators: dict[str, dict[str, Callable[[ArrayFloat], ArrayComplex]]], + eps_in_data: ScalarFieldDataArray, + eps_out_data: ScalarFieldDataArray, + ) -> ArrayComplex: + eps_out_coords = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=True, + data_array=eps_out_data, + ) + eps_in_coords = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=False, + data_array=eps_in_data, + ) + + eps_out = interpolators["eps_out"](eps_out_coords) + eps_in = interpolators["eps_in"](eps_in_coords) + + # evaluate all field components at surface points + E_fwd_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["E_fwd_linear"].items() + } + E_adj_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["E_adj_linear"].items() + } + D_fwd_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["D_fwd_linear"].items() + } + D_adj_at_coords = { + name: interp(spatial_coords) for name, interp in interpolators["D_adj_linear"].items() + } + + delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out + delta_eps = eps_in - eps_out + + # project fields onto local surface basis (normal + two tangents) + D_fwd_norm = self._project_in_basis(D_fwd_at_coords, basis_vector=normals) + D_adj_norm = self._project_in_basis(D_adj_at_coords, basis_vector=normals) + + E_fwd_perp1 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps1) + E_adj_perp1 = self._project_in_basis(E_adj_at_coords, basis_vector=perps1) + + E_fwd_perp2 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps2) + E_adj_perp2 = self._project_in_basis(E_adj_at_coords, basis_vector=perps2) + + D_der_norm = D_fwd_norm * D_adj_norm + E_der_perp1 = E_fwd_perp1 * E_adj_perp1 + E_der_perp2 = E_fwd_perp2 * E_adj_perp2 + + vjps = -delta_eps_inv * D_der_norm + E_der_perp1 * delta_eps + E_der_perp2 * delta_eps + + return vjps + + def _snap_spatial_coords_boundary( + self, + spatial_coords: ArrayFloat, + normals: ArrayFloat, + is_outside: bool, + data_array: ScalarFieldDataArray, + ) -> np.ndarray: + """Assuming a nearest interpolation, adjust the interpolation points given the grid + defined by `grid_centers` and using `spatial_coords` as a starting point such that we + select a point inside/outside the boundary depending on is_outside. + + *** (nearest point outside boundary) + ^ + | n (normal direction) + | + _.-~'`-._.-~'`-._ (boundary) + * (nearest point) + + Parameters + ---------- + spatial_coords : np.ndarray + (N, 3) array of surface evaluation points. + normals : np.ndarray + (N, 3) array of outward-pointing normal vectors at each surface point. + is_outside: bool + Indicator specifying if coordinates should be snapped inside or outside the boundary. + data_array: ScalarFieldDataArray + Data array to pull grid centers from when snapping coordinates. 
+ + Returns + ------- + np.ndarray + (N, 3) array of coordinate centers at which to interpolate such that they line up + with a grid center and are inside/outside the boundary + """ + coords = data_array.coords + grid_centers = {key: np.array(coords[key].values) for key in coords} + + grid_ddim = np.zeros_like(normals) + for idx, dim in enumerate("xyz"): + expanded_coords = np.expand_dims(spatial_coords[:, idx], axis=1) + grid_centers_select = grid_centers[dim] + + diff = np.abs(expanded_coords - grid_centers_select) + + nearest_grid = np.argmin(diff, axis=-1) + nearest_grid = np.minimum(np.maximum(nearest_grid, 1), len(grid_centers_select) - 1) + + # compute the local grid spacing near the boundary + grid_ddim[:, idx] = ( + grid_centers_select[nearest_grid] - grid_centers_select[nearest_grid - 1] + ) + + # assuming we move in the normal direction, finds which dimension we need to move the least + # in order to ensure we snap to a point outside the boundary in the worst case (i.e. - the + # nearest point is just inside the surface) + coords_dn = np.min( + np.abs(grid_ddim) / (np.abs(normals) + np.finfo(normals.dtype).eps), + axis=1, + keepdims=True, + ) + + # adjust coordinates by half a grid point outside boundary such that nearest interpolation + # point snaps to outside the boundary + normal_direction = 1.0 if is_outside else -1.0 + adjust_spatial_coords = ( + spatial_coords + + normal_direction * normals * config.adjoint.boundary_snapping_fraction * coords_dn + ) + + return adjust_spatial_coords + + def _compute_edge_distance( + self, + spatial_coords: np.ndarray, + grid_centers: dict[str, np.ndarray], + adjust_spatial_coords: np.ndarray, + ) -> np.ndarray: + """Assuming nearest neighbor interpolation, computes the edge distance after interpolation when using the + adjust_spatial_coords computed from _snap_spatial_coords_boundary. + + Parameters + ---------- + spatial_coords : np.ndarray + (N, 3) array of surface evaluation points. + grid_centers: dict[str, np.ndarray] + The grid points for a given field component indexed by dimension. These grid points + are used to find the nearest snapping point and adjust the interpolation coordinates + to ensure we fall inside/outside of a boundary. + adjust_spatial_coords: np.ndarray + (N, 3) array of snapped interpolation coordinates returned by + _snap_spatial_coords_boundary for these surface points. 
+ + Returns + ------- + np.ndarray + (N,) array of distances from the nearest interpolation points to the desired surface + edge points specified by `spatial_coords` + """ + + edge_distance_squared_sum = np.zeros_like(adjust_spatial_coords[:, 0]) + for idx, dim in enumerate("xyz"): + expanded_adjusted_coords = np.expand_dims(adjust_spatial_coords[:, idx], axis=1) + grid_centers_select = grid_centers[dim] + + # find nearest grid point from the adjusted coordinates + diff = np.abs(expanded_adjusted_coords - grid_centers_select) + nearest_grid = np.argmin(diff, axis=-1) + + # compute edge distance from the nearest interpolated point to the boundary edge + edge_distance_squared_sum += ( + np.abs(spatial_coords[:, idx] - grid_centers_select[nearest_grid]) ** 2 + ) + + # this edge distance is useful when correcting for edge singularities like those from a PEC + # material and is used when the PEC PolySlab structure has zero thickness, for example + edge_distance = np.sqrt(edge_distance_squared_sum) + + return edge_distance + + def _detect_pec_gradient_points( + self, + spatial_coords: np.ndarray, + normals: np.ndarray, + eps_data: ScalarFieldDataArray, + interpolator: dict[str, LazyInterpolator], + is_outside: bool, + ) -> np.ndarray: + def _detect_pec(eps_mask: np.ndarray) -> np.ndarray: + return 1.0 * (eps_mask < config.adjoint.pec_detection_threshold) + + adjusted_coords = self._snap_spatial_coords_boundary( + spatial_coords=spatial_coords, + normals=normals, + is_outside=is_outside, + data_array=eps_data, + ) + + eps_adjusted_all = [ + component_interpolator(adjusted_coords) + for _, component_interpolator in interpolator.items() + ] + eps_detect_pec = reduce(np.minimum, eps_adjusted_all) + + return _detect_pec(eps_detect_pec) + + def _evaluate_pec_gradient_at_points( + self, + spatial_coords: np.ndarray, + normals: np.ndarray, + perps1: np.ndarray, + perps2: np.ndarray, + interpolators: dict, + eps_dielectric: tuple[str, ScalarFieldDataArray], + is_outside: bool, + ) -> np.ndarray: + eps_dielectric_key, eps_dielectric_data = eps_dielectric + + def _snap_coordinate_outside( + field_components: FieldData, + ) -> dict[str, dict[str, ArrayFloat]]: + """Helper function to perform coordinate adjustment and compute edge distance for each + component in `field_components`. + + Parameters + ---------- + field_components: FieldData + The field components (i.e - Ex, Ey, Ez, Hx, Hy, Hz) that we would like to sample just + outside the PEC surface using nearest interpolation. + + Returns + ------- + dict[str, dict[str, np.ndarray]] + Dictionary mapping each field component name to a dictionary of adjusted coordinates + and edge distances for that component. 
+ """ + adjustment = {} + for name in field_components: + field_component = field_components[name] + field_component_coords = field_component.coords + + grid_centers = { + key: np.array(field_component_coords[key].values) + for key in field_component_coords + } + + adjusted_coords = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=is_outside, + data_array=field_component, + ) + + edge_distance = self._compute_edge_distance( + spatial_coords=spatial_coords, + grid_centers=grid_centers, + adjust_spatial_coords=adjusted_coords, + ) + adjustment[name] = {"coords": adjusted_coords, "edge_distance": edge_distance} + + return adjustment + + def _interpolate_field_components( + interp_coords: dict[str, dict[str, ArrayFloat]], field_name: str + ) -> dict[str, ArrayComplex]: + return { + name: interp(interp_coords[name]["coords"]) + for name, interp in interpolators[field_name].items() + } + + # adjust coordinates for PEC to be outside structure bounds and get edge distance for singularity correction. + E_fwd_coords_adjusted = _snap_coordinate_outside(self.E_fwd) + E_adj_coords_adjusted = _snap_coordinate_outside(self.E_adj) + + H_fwd_coords_adjusted = _snap_coordinate_outside(self.H_fwd) + H_adj_coords_adjusted = _snap_coordinate_outside(self.H_adj) + + # using the adjusted coordinates, evaluate all field components at surface points + E_fwd_at_coords = _interpolate_field_components( + E_fwd_coords_adjusted, field_name="E_fwd_nearest" + ) + E_adj_at_coords = _interpolate_field_components( + E_adj_coords_adjusted, field_name="E_adj_nearest" + ) + H_fwd_at_coords = _interpolate_field_components( + H_fwd_coords_adjusted, field_name="H_fwd_nearest" + ) + H_adj_at_coords = _interpolate_field_components( + H_adj_coords_adjusted, field_name="H_adj_nearest" + ) + + eps_coords_adjusted = self._snap_spatial_coords_boundary( + spatial_coords, + normals, + is_outside=is_outside, + data_array=eps_dielectric_data, + ) + eps_dielectric = interpolators[eps_dielectric_key](eps_coords_adjusted) + + structure_sizes = np.array( + [self.bounds[1][idx] - self.bounds[0][idx] for idx in range(len(self.bounds[0]))] + ) + + is_flat_perp_dim1 = np.isclose(np.abs(np.sum(perps1[0] * structure_sizes)), 0.0) + is_flat_perp_dim2 = np.isclose(np.abs(np.sum(perps2[0] * structure_sizes)), 0.0) + flat_perp_dims = [is_flat_perp_dim1, is_flat_perp_dim2] + + # check if this integration is happening along an edge in which case we will eliminate + # on of the H field integration components and apply singularity correction + pec_line_integration = is_flat_perp_dim1 or is_flat_perp_dim2 + + def _compute_singularity_correction( + adjustment_: dict[str, dict[str, ArrayFloat]], + ) -> ArrayFloat: + """ + Given the `adjustment_` which contains the distance from the PEC edge each field + component is nearest interpolated at, computes the singularity correction when + working with 2D PEC using the average edge_distance for each component. In the case + of 3D PEC gradients, no singularity correction is applied so an array of ones is returned. + + Parameters + ---------- + adjustment_: dict[str, dict[str, np.ndarray]] + Dictionary that maps field component name to a dictionary containing the coordinate + adjustment and the distance to the PEC edge for those coordinates. The edge distance + is used for 2D PEC singularity correction. 
+ + Returns + ------- + np.ndarray + Returns the singularity correction which has shape (N,) where there are N points in + `spatial_coords` + """ + return ( + ( + 0.5 + * np.pi + * np.mean([adjustment_[name]["edge_distance"] for name in adjustment_], axis=0) + ) + if pec_line_integration + else np.ones_like(spatial_coords, shape=spatial_coords.shape[0]) + ) + + E_norm_singularity_correction = np.expand_dims( + _compute_singularity_correction(E_fwd_coords_adjusted), axis=1 + ) + H_perp_singularity_correction = np.expand_dims( + _compute_singularity_correction(H_fwd_coords_adjusted), axis=1 + ) + + E_fwd_norm = self._project_in_basis(E_fwd_at_coords, basis_vector=normals) + E_adj_norm = self._project_in_basis(E_adj_at_coords, basis_vector=normals) + + # compute the normal E contribution to the gradient (the tangential E contribution + # is 0 in the case of PEC since this field component is continuous and thus 0 at + # the boundary) + contrib_E = E_norm_singularity_correction * eps_dielectric * E_fwd_norm * E_adj_norm + vjps = contrib_E + + # compute the tangential H contribution to the gradient (the normal H contribution + # is 0 for PEC) + H_fwd_perp1 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps1) + H_adj_perp1 = self._project_in_basis(H_adj_at_coords, basis_vector=perps1) + + H_fwd_perp2 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps2) + H_adj_perp2 = self._project_in_basis(H_adj_at_coords, basis_vector=perps2) + + H_der_perp1 = H_perp_singularity_correction * H_fwd_perp1 * H_adj_perp1 + H_der_perp2 = H_perp_singularity_correction * H_fwd_perp2 * H_adj_perp2 + + H_integration_components = (H_der_perp1, H_der_perp2) + if pec_line_integration: + # if we are integrating along the line, we choose the H component normal to + # the edge which corresponds to a surface current along the edge whereas the other + # tangential component corresponds to a surface current along the flat dimension. + H_integration_components = tuple( + H_comp for idx, H_comp in enumerate(H_integration_components) if flat_perp_dims[idx] + ) + + # for each of the tangential components we are integrating the H fields over, + # adjust weighting to account for pre-weighting of the source by `EPSILON_0` + # and multiply by appropriate `MU_0` factor + for H_perp in H_integration_components: + contrib_H = MU_0 * H_perp / EPSILON_0 + vjps += contrib_H + + return vjps + + @staticmethod + def _project_in_basis( + field_components: dict[str, np.ndarray], + basis_vector: np.ndarray, + ) -> np.ndarray: + """Project 3D field components onto a basis vector. + + Parameters + ---------- + field_components : dict[str, np.ndarray] + Dictionary with keys like "Ex", "Ey", "Ez" or "Dx", "Dy", "Dz" containing field values. + Values have shape (N, F) where F is the number of frequencies. + basis_vector : np.ndarray + (N, 3) array of basis vectors, one per evaluation point. + + Returns + ------- + np.ndarray + Projected field values with shape (N, F). + """ + prefix = next(iter(field_components.keys()))[0] + field_matrix = np.stack([field_components[f"{prefix}{dim}"] for dim in "xyz"], axis=0) + + # always expect (3, N, F) shape, transpose to (N, 3, F) + field_matrix = np.transpose(field_matrix, (1, 0, 2)) + return np.einsum("ij...,ij->i...", field_matrix, basis_vector) + + def project_der_map_to_axis( + self, axis: xyz, field_type: str = "E" + ) -> dict[str, ScalarFieldDataArray] | None: + """Return a copy of the selected derivative map with only one axis kept. 
+ + Parameters + ---------- + axis: + Axis to keep (``"x"``, ``"y"``, ``"z"``, case-insensitive). + field_type: + Map selector: ``"E"`` (``self.E_der_map``) or ``"D"`` (``self.D_der_map``). + + Returns + ------- + dict[str, ScalarFieldDataArray] | None + Copied map where non-selected components are replaced by zeros, or ``None`` + if the requested map is unavailable. + """ + field_map = {"E": self.E_der_map, "D": self.D_der_map}.get(field_type) + if field_map is None: + raise ValueError("field type must be 'D' or 'E'.") + + axis = axis.lower() + projected = dict(field_map) + if not field_map: + return projected + for dim in "xyz": + key = f"E{dim}" + if key not in field_map: + continue + if dim != axis: + projected[key] = xr.zeros_like(field_map[key]) + else: + projected[key] = field_map[key] + return projected + + def adaptive_vjp_spacing( + self, + wl_fraction: Optional[float] = None, + min_allowed_spacing_fraction: Optional[float] = None, + ) -> float: + """Compute adaptive spacing for finite-difference gradient evaluation. + + Determines an appropriate spatial resolution based on the material + properties and electromagnetic wavelength/skin depth. + + Parameters + ---------- + wl_fraction : float, optional + Fraction of wavelength/skin depth to use as spacing. Defaults to the configured + ``config.adjoint.default_wavelength_fraction`` when ``None``. + min_allowed_spacing_fraction : float, optional + Minimum allowed spacing fraction of free space wavelength used to + prevent numerical issues. Defaults to ``config.adjoint.minimum_spacing_fraction`` + when not specified. + + Returns + ------- + float + Adaptive spacing value for gradient evaluation. + """ + if wl_fraction is None or min_allowed_spacing_fraction is None: + from tidy3d._common.config import config + + if wl_fraction is None: + wl_fraction = config.adjoint.default_wavelength_fraction + if min_allowed_spacing_fraction is None: + min_allowed_spacing_fraction = config.adjoint.minimum_spacing_fraction + + def spacing_by_permittivity(eps_array: ScalarFieldDataArray) -> float: + eps_real = np.asarray(eps_array.values, dtype=np.complex128).real + + dx_candidates = [] + max_frequency = np.max(self.frequencies) + + # wavelength-based sampling for dielectrics + if np.any(eps_real > 0): + eps_max = eps_real[eps_real > 0].max() + lambda_min = self.wavelength_min / np.sqrt(eps_max) + dx_candidates.append(wl_fraction * lambda_min) + + # skin depth sampling for metals + if np.any(eps_real <= 0): + omega = 2 * np.pi * max_frequency + eps_neg = eps_real[eps_real <= 0] + delta_min = C_0 / (omega * np.sqrt(np.abs(eps_neg).max())) + dx_candidates.append(wl_fraction * delta_min) + + computed_spacing = min(dx_candidates) + + return computed_spacing + + eps_spacings = [ + spacing_by_permittivity(eps_array) for _, eps_array in self.eps_data.items() + ] + computed_spacing = np.min(eps_spacings) + + min_allowed_spacing = self.wavelength_min * min_allowed_spacing_fraction + + if computed_spacing < min_allowed_spacing: + log.warning( + f"Based on the material, the adaptive spacing for integrating the polyslab surface " + f"would be {computed_spacing:.3e} μm. 
The spacing has been clipped to {min_allowed_spacing:.3e} μm " + f"to prevent a performance degradation.", + log_once=True, + ) + + return max(computed_spacing, min_allowed_spacing) + + @property + def wavelength_min(self) -> float: + return C_0 / np.max(self.frequencies) + + @property + def wavelength_max(self) -> float: + return C_0 / np.min(self.frequencies) + + +def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: + """Integrate a data array within specified spatial bounds. + + Clips the integration domain to the specified bounds and performs + numerical integration using the trapezoidal rule. + + Parameters + ---------- + arr : xr.DataArray + Data array to integrate. + dims : list[str] + Dimensions to integrate over (e.g., ['x', 'y', 'z']). + bounds : Bound + Integration bounds as [[xmin, ymin, zmin], [xmax, ymax, zmax]]. + + Returns + ------- + xr.DataArray + Result of integration with specified dimensions removed. + + Notes + ----- + - Coordinates outside bounds are clipped, effectively setting dL=0 + - Only integrates dimensions with more than one coordinate point + - Uses xarray's integrate method (trapezoidal rule) + """ + bounds = np.asarray(bounds).T + all_coords = {} + + for dim, (bmin, bmax) in zip(dims, bounds): + bmin = get_static(bmin) + bmax = get_static(bmax) + + # clip coordinates to bounds (sets dL=0 outside bounds) + coord_values = arr.coords[dim].data + all_coords[dim] = np.clip(coord_values, bmin, bmax) + + _arr = arr.assign_coords(**all_coords) + + # only integrate dimensions with multiple points + dims_integrate = [dim for dim in dims if len(_arr.coords[dim]) > 1] + return _arr.integrate(coord=dims_integrate) + + +__all__ = [ + "DerivativeInfo", + "integrate_within_bounds", +] diff --git a/tidy3d/_common/components/autograd/field_map.py b/tidy3d/_common/components/autograd/field_map.py new file mode 100644 index 0000000000..159a7d1527 --- /dev/null +++ b/tidy3d/_common/components/autograd/field_map.py @@ -0,0 +1,77 @@ +"""Typed containers for autograd traced field metadata.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Union + +from pydantic import Field + +from tidy3d._common.components.autograd.types import TracedArrayLike, TracedComplex, TracedFloat +from tidy3d._common.components.base import Tidy3dBaseModel + +if TYPE_CHECKING: + from typing import Callable + + from tidy3d._common.components.autograd.types import AutogradFieldMap + + +class Tracer(Tidy3dBaseModel): + """Representation of a single traced element within a model.""" + + path: tuple[Any, ...] = Field( + title="Path to the traced object in the model dictionary.", + ) + data: Union[TracedFloat, TracedComplex, TracedArrayLike] = Field(title="Tracing data") + + +class FieldMap(Tidy3dBaseModel): + """Collection of traced elements.""" + + tracers: tuple[Tracer, ...] 
= Field( + title="Collection of Tracers.", + ) + + @property + def to_autograd_field_map(self) -> AutogradFieldMap: + """Convert to ``AutogradFieldMap`` autograd dictionary.""" + return {tracer.path: tracer.data for tracer in self.tracers} + + @classmethod + def from_autograd_field_map(cls, autograd_field_map: AutogradFieldMap) -> FieldMap: + """Initialize from an ``AutogradFieldMap`` autograd dictionary.""" + tracers = [] + for path, data in autograd_field_map.items(): + tracers.append(Tracer(path=path, data=data)) + return cls(tracers=tuple(tracers)) + + +def _encoded_path(path: tuple[Any, ...]) -> str: + """Return a stable JSON representation for a traced path.""" + return json.dumps(list(path), separators=(",", ":"), ensure_ascii=True) + + +class TracerKeys(Tidy3dBaseModel): + """Collection of traced field paths.""" + + keys: tuple[tuple[Any, ...], ...] = Field( + title="Collection of tracer keys.", + ) + + def encoded_keys(self) -> list[str]: + """Return the JSON-encoded representation of keys.""" + return [_encoded_path(path) for path in self.keys] + + @classmethod + def from_field_mapping( + cls, + field_mapping: AutogradFieldMap, + *, + sort_key: Callable[[tuple[Any, ...]], str] | None = None, + ) -> TracerKeys: + """Construct keys from an autograd field mapping.""" + if sort_key is None: + sort_key = _encoded_path + + sorted_paths = tuple(sorted(field_mapping.keys(), key=sort_key)) + return cls(keys=sorted_paths) diff --git a/tidy3d/_common/components/autograd/functions.py b/tidy3d/_common/components/autograd/functions.py new file mode 100644 index 0000000000..86beaec421 --- /dev/null +++ b/tidy3d/_common/components/autograd/functions.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +import numpy as np +from autograd.extend import defjvp, defvjp, primitive +from autograd.numpy.numpy_jvps import broadcast +from autograd.numpy.numpy_vjps import unbroadcast_f + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from tidy3d._common.components.autograd.types import InterpolationType + + +def _evaluate_nearest( + indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] +) -> NDArray[np.float64]: + """Perform nearest neighbor interpolation in an n-dimensional space. + + This function determines the nearest neighbor in a grid for a given point + and returns the corresponding value from the input array. + + Parameters + ---------- + indices : np.ndarray[np.int64] + Indices of the lower bounds of the grid cell containing the interpolation point. + norm_distances : np.ndarray[np.float64] + Normalized distances from the lower bounds of the grid cell to the + interpolation point, for each dimension. + values : np.ndarray[np.float64] + The n-dimensional array of values to interpolate from. + + Returns + ------- + np.ndarray[np.float64] + The value of the nearest neighbor to the interpolation point. + """ + idx_res = tuple(anp.where(yi <= 0.5, i, i + 1) for i, yi in zip(indices, norm_distances)) + return values[idx_res] + + +def _evaluate_linear( + indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] +) -> NDArray[np.float64]: + """Perform linear interpolation in an n-dimensional space. + + This function calculates the linearly interpolated value at a point in an + n-dimensional grid, given the indices of the surrounding grid points and + the normalized distances to these points. 
+
+    The multi-linear interpolation is implemented by computing a weighted
+    average of the values at the vertices of the hypercube surrounding the
+    interpolation point: each of the ``2**n`` vertices is weighted by a product,
+    over the axes, of ``t`` (upper neighbor) or ``1 - t`` (lower neighbor), where
+    ``t`` is the normalized distance along that axis. In one dimension this
+    reduces to ``(1 - t) * values[i] + t * values[i + 1]``.
+
+    Parameters
+    ----------
+    indices : np.ndarray[np.int64]
+        Indices of the lower bounds of the grid cell containing the interpolation point.
+    norm_distances : np.ndarray[np.float64]
+        Normalized distances from the lower bounds of the grid cell to the
+        interpolation point, for each dimension.
+    values : np.ndarray[np.float64]
+        The n-dimensional array of values to interpolate from.
+
+    Returns
+    -------
+    np.ndarray[np.float64]
+        The interpolated value at the desired point.
+    """
+    # Create a slice object for broadcasting over trailing dimensions
+    _slice = (slice(None),) + (None,) * (values.ndim - len(indices))
+
+    # Prepare iterables for lower and upper bounds of the hypercube
+    ix = zip(indices, (1 - yi for yi in norm_distances))
+    iy = zip((i + 1 for i in indices), norm_distances)
+
+    # Initialize the result
+    value = anp.zeros(1)
+
+    # Iterate over all vertices of the hypercube
+    for h in itertools.product(*zip(ix, iy)):
+        edge_indices, weights = zip(*h)
+
+        # Compute the weight for this vertex
+        weight = anp.ones(1)
+        for w in weights:
+            weight = weight * w
+
+        # Compute the contribution of this vertex and add it to the result
+        term = values[edge_indices] * weight[_slice]
+        value = value + term
+
+    return value
+
+
+def interpn(
+    points: tuple[NDArray[np.float64], ...],
+    values: NDArray[np.float64],
+    xi: tuple[NDArray[np.float64], ...],
+    *,
+    method: InterpolationType = "linear",
+    **kwargs: Any,
+) -> NDArray[np.float64]:
+    """Interpolate over a rectilinear grid in arbitrary dimensions.
+
+    This function mirrors the interface of `scipy.interpolate.interpn` but is
+    differentiable with autograd.
+
+    Parameters
+    ----------
+    points : tuple[np.ndarray[np.float64], ...]
+        The points defining the rectilinear grid in n dimensions.
+    values : np.ndarray[np.float64]
+        The data values on the rectilinear grid.
+    xi : tuple[np.ndarray[np.float64], ...]
+        The coordinates to sample the gridded data at.
+    method : InterpolationType = "linear"
+        The method of interpolation to perform. Supported are "linear" and "nearest".
+    **kwargs : Any
+        Additional keyword arguments; currently only ``fill_value="extrapolate"``
+        is recognized, which disables bounds checking and extrapolates outside
+        the grid.
+
+    Returns
+    -------
+    np.ndarray[np.float64]
+        The interpolated values.
+
+    Raises
+    ------
+    ValueError
+        If the interpolation method is not supported.
+
+    See Also
+    --------
+    `scipy.interpolate.interpn
+    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interpn.html>`_
+    """
+    from scipy.interpolate import RegularGridInterpolator
+
+    if method == "nearest":
+        interp_fn = _evaluate_nearest
+    elif method == "linear":
+        interp_fn = _evaluate_linear
+    else:
+        raise ValueError(f"Unsupported interpolation method: {method}")
+
+    # Avoid SciPy coercing autograd ArrayBox values during _check_values.
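+    # RegularGridInterpolator validates its `values` argument at construction,
+    # which would unwrap or reject autograd ArrayBox tracers. Build it with an
+    # all-zero placeholder of the same shape and reuse only its grid helpers
+    # (`_prepare_xi`, `_find_indices`); the (possibly traced) `values` are
+    # evaluated separately through the differentiable `interp_fn`.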
+ dummy_values = np.zeros(np.shape(values), dtype=float) + if kwargs.get("fill_value") == "extrapolate": + itrp = RegularGridInterpolator( + points, dummy_values, method=method, fill_value=None, bounds_error=False + ) + else: + itrp = RegularGridInterpolator(points, dummy_values, method=method) + + # Prepare the grid for interpolation + # This step reshapes the grid, checks for NaNs and out-of-bounds values + # It returns: + # - reshaped grid + # - original shape + # - number of dimensions + # - boolean array indicating NaN positions + # - (discarded) boolean array for out-of-bounds values + xi, shape, ndim, nans, _ = itrp._prepare_xi(xi) + + # Find the indices of the grid cells containing the interpolation points + # and calculate the normalized distances (ranging from 0 at lower grid point to 1 + # at upper grid point) within these cells + indices, norm_distances = itrp._find_indices(xi.T) + + result = interp_fn(indices, norm_distances, values) + nans = anp.reshape(nans, (-1,) + (1,) * (result.ndim - 1)) + result = anp.where(nans, np.nan, result) + return anp.reshape(result, shape[:-1] + values.shape[ndim:]) + + +def trapz(y: NDArray, x: NDArray = None, dx: float = 1.0, axis: int = -1) -> float: + """ + Integrate along the given axis using the composite trapezoidal rule. + + Parameters + ---------- + y : np.ndarray + Input array to integrate. + x : np.ndarray = None + The sample points corresponding to the y values. If None, the sample points are assumed to be evenly spaced + with spacing `dx`. + dx : float = 1.0 + The spacing between sample points when `x` is None. Default is 1.0. + axis : int = -1 + The axis along which to integrate. Default is the last axis. + + Returns + ------- + float + Definite integral as approximated by the trapezoidal rule. + """ + if x is None: + d = dx + elif x.ndim == 1: + d = np.diff(x) + shape = [1] * y.ndim + shape[axis] = d.shape[0] + d = np.reshape(d, shape) + else: + d = np.diff(x, axis=axis) + + slice1 = [slice(None)] * y.ndim + slice2 = [slice(None)] * y.ndim + slice1[axis] = slice(1, None) + slice2[axis] = slice(None, -1) + + return anp.sum((y[tuple(slice1)] + y[tuple(slice2)]) * d / 2, axis=axis) + + +@primitive +def _add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: + """ + Add values to specified indices of an array. + + Autograd requires that arguments to primitives are passed in positionally. + ``add_at`` is the public-facing wrapper for this function, + which allows keyword arguments in case users pass in kwargs. + """ + out = np.copy(x) # Copy to preserve 'x' for gradient computation + out[tuple(indices_x)] += y + return out + + +defvjp( + _add_at, + lambda ans, x, indices_x, y: unbroadcast_f(x, lambda g: g), + lambda ans, x, indices_x, y: lambda g: g[tuple(indices_x)], + argnums=(0, 2), +) + +defjvp( + _add_at, + lambda g, ans, x, indices_x, y: broadcast(g, ans), + lambda g, ans, x, indices_x, y: _add_at(anp.zeros_like(ans), indices_x, g), + argnums=(0, 2), +) + + +def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: + """ + Add values to specified indices of an array. + + This function creates a copy of the input array `x`, adds the values from `y` to the specified + indices `indices_x`, and returns the modified array. + + Parameters + ---------- + x : np.ndarray + Input array to which values will be added. + indices_x : tuple + Indices of `x` where values from `y` will be added. + y : np.ndarray + Values to add to the specified indices of `x`. 
+ + Returns + ------- + np.ndarray + The modified array with values added at the specified indices. + """ + return _add_at(x, indices_x, y) + + +@primitive +def _straight_through_clip(x: NDArray, a_min: Any, a_max: Any) -> NDArray: + """Passthrough clip can be used to preserve gradients at the endpoints of the clip range where + there is a discontinuity in the derivative. This is useful when values are at the endpoints but may + have a gradient away from the boundary or in cases where numerical precision causes a function that is + typically bounded by the clip bounds to produce a value just outside the bounds. In the forward pass, + this runs the standard clip.""" + return anp.clip(x, a_min=a_min, a_max=a_max) + + +def _straight_through_clip_vjp(ans: Any, x: NDArray, a_min: Any, a_max: Any) -> NDArray: + """Preserve original gradient information in the backward pass up until a tolerance beyond the clip bounds.""" + tolerance = 1e-5 + mask = (x >= a_min - tolerance) & (x <= a_max + tolerance) + return lambda g: g * mask + + +defvjp(_straight_through_clip, _straight_through_clip_vjp) + +__all__ = [ + "add_at", + "interpn", + "trapz", +] diff --git a/tidy3d/_common/components/autograd/types.py b/tidy3d/_common/components/autograd/types.py new file mode 100644 index 0000000000..baea29e1fc --- /dev/null +++ b/tidy3d/_common/components/autograd/types.py @@ -0,0 +1,136 @@ +# type information for autograd + +from __future__ import annotations + +import copy +from typing import TYPE_CHECKING, Annotated, Any, Literal, Union, get_origin + +import autograd.numpy as anp +from autograd.builtins import dict as TracedDict +from autograd.extend import Box, defvjp, primitive +from autograd.numpy.numpy_boxes import ArrayBox +from pydantic import BeforeValidator, PlainSerializer, PositiveFloat, TypeAdapter + +from tidy3d._common.components.autograd.utils import get_static, hasbox +from tidy3d._common.components.types.base import ( + ArrayFloat2D, + ArrayLike, + Complex, + Size1D, + _auto_serializer, +) +from tidy3d._common.components.types.utils import _add_schema + +if TYPE_CHECKING: + from typing import Optional + + from pydantic import SerializationInfo + + from tidy3d._common.compat import TypeAlias + +# add schema to the Box +_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") +_add_schema(ArrayBox, title="AutogradArrayBox", field_type_str="autograd.numpy.ArrayBox") + +# make sure Boxes in tidy3d properly define VJPs for copy operations, for computational graph +_copy = primitive(copy.copy) +_deepcopy = primitive(copy.deepcopy) + +defvjp(_copy, lambda ans, x: lambda g: _copy(g)) +defvjp(_deepcopy, lambda ans, x, memo: lambda g: _deepcopy(g, memo)) + +Box.__copy__ = lambda v: _copy(v) +Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) +Box.__str__ = lambda self: f"{self._value} <{type(self).__name__}>" +Box.__repr__ = Box.__str__ + + +def traced_alias(base_alias: Any, *, name: Optional[str] = None) -> TypeAlias: + base_adapter = TypeAdapter(base_alias, config={"arbitrary_types_allowed": True}) + + def _validate_box_or_container(v: Any) -> Any: + # case 1: v itself is a tracer + # in this case we just validate but leave the tracer untouched + if isinstance(v, Box): + base_adapter.validate_python(get_static(v)) + return v + + # case 2: v is a plain container that contains at least one tracer + # in this case we try to coerce into ArrayBox for one-shot validation, + # but always return the original v, and fall back to a structural walk if needed + if hasbox(v): + # decide whether 
we must return an array + origin = get_origin(base_alias) + is_array_field = base_alias in (ArrayLike, ArrayFloat2D) or origin is None + + if is_array_field: + dense = anp.array(v) + base_adapter.validate_python(get_static(dense)) + return dense + + # otherwise it's a Python container type + # try the fast-path array validation, but return the array so ops work + try: + dense = anp.array(v) + base_adapter.validate_python(get_static(dense)) + return dense + + except Exception: + # ragged/un-coercible -> rebuild container of Boxes + if isinstance(v, tuple): + return tuple(_validate_box_or_container(x) for x in v) + if isinstance(v, list): + return [_validate_box_or_container(x) for x in v] + if isinstance(v, dict): + return {k: _validate_box_or_container(x) for k, x in v.items()} + # fallback: can't handle this structure + raise + + return base_adapter.validate_python(v) + + def _serialize_traced(a: Any, info: SerializationInfo) -> Any: + return _auto_serializer(get_static(a), info) + + return Annotated[ + object, + BeforeValidator(_validate_box_or_container), + PlainSerializer(_serialize_traced, when_used="json"), + ] + + +# "primitive" types that can use traced_alias +TracedArrayLike = traced_alias(ArrayLike) +TracedArrayFloat2D = traced_alias(ArrayFloat2D) +TracedFloat = traced_alias(float) +TracedPositiveFloat = traced_alias(PositiveFloat) +TracedComplex = traced_alias(Complex) +TracedSize1D = traced_alias(Size1D) + +# derived traced types (these mirror the types in `components.types`) +TracedSize = tuple[TracedSize1D, TracedSize1D, TracedSize1D] +TracedCoordinate = tuple[TracedFloat, TracedFloat, TracedFloat] +TracedPoleAndResidue = tuple[TracedComplex, TracedComplex] +TracedPolesAndResidues = tuple[TracedPoleAndResidue, ...] + +# The data type that we pass in and out of the web.run() @autograd.primitive +PathType = tuple[Union[int, str], ...] +AutogradFieldMap = TracedDict[PathType, Box] + +InterpolationType = Literal["nearest", "linear"] + +__all__ = [ + "AutogradFieldMap", + "InterpolationType", + "PathType", + "TracedArrayFloat2D", + "TracedArrayLike", + "TracedComplex", + "TracedCoordinate", + "TracedDict", + "TracedFloat", + "TracedPoleAndResidue", + "TracedPolesAndResidues", + "TracedPositiveFloat", + "TracedSize", + "TracedSize1D", +] diff --git a/tidy3d/_common/components/autograd/utils.py b/tidy3d/_common/components/autograd/utils.py new file mode 100644 index 0000000000..76c13b583f --- /dev/null +++ b/tidy3d/_common/components/autograd/utils.py @@ -0,0 +1,84 @@ +# utilities for working with autograd +from __future__ import annotations + +from collections.abc import Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +from autograd.tracer import getval, isbox + +if TYPE_CHECKING: + from typing import Union + + from autograd.numpy.numpy_boxes import ArrayBox + from numpy.typing import ArrayLike, NDArray + +__all__ = [ + "asarray1d", + "contains", + "get_static", + "hasbox", + "is_tidy_box", + "pack_complex_vec", + "split_list", +] + + +def get_static(item: Any) -> Any: + """ + Get the 'static' (untraced) version of some value by recursively calling getval + on Box instances within a nested structure. 
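+
+    Lists, tuples, and dicts are rebuilt with the same nested structure; any
+    other value is returned unchanged.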
+ """ + if isbox(item): + return getval(item) + elif isinstance(item, list): + return [get_static(x) for x in item] + elif isinstance(item, tuple): + return tuple(get_static(x) for x in item) + elif isinstance(item, dict): + return {k: get_static(v) for k, v in item.items()} + return item + + +def split_list(x: list[Any], index: int) -> tuple[list, list]: + """Split a list at a given index.""" + x = list(x) + return x[:index], x[index:] + + +def is_tidy_box(x: Any) -> bool: + """Check if a value is a tidy box.""" + return getattr(x, "_tidy", False) + + +def contains(target: Any, seq: Iterable[Any]) -> bool: + """Return ``True`` if target occurs anywhere within arbitrarily nested iterables.""" + for x in seq: + if x == target: + return True + if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): + if contains(target, x): + return True + return False + + +def hasbox(obj: Any) -> bool: + """True if any element inside obj is an autograd Box.""" + if isbox(obj): + return True + if isinstance(obj, Mapping): + return any(hasbox(v) for v in obj.values()) + if isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)): + return any(hasbox(i) for i in obj) + return False + + +def pack_complex_vec(z: Union[NDArray, ArrayBox]) -> Union[NDArray, ArrayBox]: + """Ravel [Re(z); Im(z)] into one real vector (autograd-safe).""" + return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))]) + + +def asarray1d(x: Union[ArrayLike, ArrayBox]) -> Union[NDArray, ArrayBox]: + """Autograd-friendly 1D flatten: returns ndarray of shape (-1,).""" + x = anp.array(x) + return x if x.ndim == 1 else anp.ravel(x) diff --git a/tidy3d/_common/components/base.py b/tidy3d/_common/components/base.py new file mode 100644 index 0000000000..c0ad4f769d --- /dev/null +++ b/tidy3d/_common/components/base.py @@ -0,0 +1,1893 @@ +"""global configuration / base class for pydantic models used to make simulation.""" + +from __future__ import annotations + +import hashlib +import io +import json +import math +import os +import tempfile +import typing as _t +from collections import defaultdict +from collections.abc import Mapping, Sequence +from functools import total_ordering, wraps +from math import ceil +from os import PathLike +from pathlib import Path +from types import UnionType +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, get_args, get_origin + +import h5py +import numpy as np +import rich +import xarray as xr +import yaml +from autograd.builtins import dict as TracedDict +from autograd.numpy.numpy_boxes import ArrayBox +from autograd.tracer import isbox +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator, model_validator + +from tidy3d._common.components.autograd.utils import get_static +from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP +from tidy3d._common.components.file_util import compress_file_to_gzip, extract_gzip_file +from tidy3d._common.components.types.base import TYPE_TAG_STR, Undefined +from tidy3d._common.exceptions import FileError +from tidy3d._common.log import log + +if TYPE_CHECKING: + from collections.abc import Iterator + from typing import Callable + + from pydantic.fields import FieldInfo + from pydantic.functional_validators import ModelWrapValidatorHandler + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd.types import AutogradFieldMap + + +INDENT_JSON_FILE = 4 # default indentation of json string in json files +INDENT = None # default indentation of json string used 
internally +JSON_TAG = "JSON_STRING" +# If json string is larger than ``MAX_STRING_LENGTH``, split the string when storing in hdf5 +MAX_STRING_LENGTH = 1_000_000_000 +FORBID_SPECIAL_CHARACTERS = ["/"] +TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__" +TYPE_TO_CLASS_MAP: dict[str, type[Tidy3dBaseModel]] = {} + +_CacheReturn = TypeVar("_CacheReturn") + + +def cache(prop: Callable[[Any], _CacheReturn]) -> Callable[[Any], _CacheReturn]: + """Decorates a property to cache the first computed value and return it on subsequent calls.""" + + # note, we could also just use `prop` as dict key, but hashing property might be slow + prop_name = prop.__name__ + + @wraps(prop) + def cached_property_getter(self: Any) -> _CacheReturn: + """The new property method to be returned by decorator.""" + + stored_value = self._cached_properties.get(prop_name) + + if stored_value is not None: + return stored_value + + computed_value = prop(self) + self._cached_properties[prop_name] = computed_value + return computed_value + + return cached_property_getter + + +def cached_property(cached_property_getter: Callable[[Any], _CacheReturn]) -> property: + """Shortcut for property(cache()) of a getter.""" + + return property(cache(cached_property_getter)) + + +_GuardedReturn = TypeVar("_GuardedReturn") + + +def cached_property_guarded( + key_func: Callable[[Any], Any], +) -> Callable[[Callable[[Any], _GuardedReturn]], property]: + """Like cached_property, but invalidates when the key_func(self) changes.""" + + def _decorator(getter: Callable[[Any], _GuardedReturn]) -> property: + prop_name = getter.__name__ + + @wraps(getter) + def _guarded(self: Any) -> _GuardedReturn: + cache_store = self._cached_properties.get(prop_name) + current_key = key_func(self) + if cache_store is not None: + cached_key, cached_value = cache_store + if cached_key == current_key: + return cached_value + value = getter(self) + self._cached_properties[prop_name] = (current_key, value) + return value + + return property(_guarded) + + return _decorator + + +def make_json_compatible(json_string: str) -> str: + """Makes the string compatible with json standards, notably for infinity.""" + + tmp_string = "<>" + json_string = json_string.replace("-Infinity", tmp_string) + json_string = json_string.replace('""-Infinity""', tmp_string) + json_string = json_string.replace("Infinity", '"Infinity"') + json_string = json_string.replace('""Infinity""', '"Infinity"') + return json_string.replace(tmp_string, '"-Infinity"') + + +def _get_valid_extension(fname: PathLike) -> str: + """Return the file extension from fname, validated to accepted ones.""" + valid_extensions = [".json", ".yaml", ".hdf5", ".h5", ".hdf5.gz"] + path = Path(fname) + extensions = [s.lower() for s in path.suffixes[-2:]] + if len(extensions) == 0: + raise FileError(f"File '{path}' missing extension.") + single_extension = extensions[-1] + if single_extension in valid_extensions: + return single_extension + double_extension = "".join(extensions) + if double_extension in valid_extensions: + return double_extension + raise FileError( + f"File extension must be one of {', '.join(valid_extensions)}; file '{path}' does not " + "match any of those." 
+ ) + + +def _fmt_ann_literal(ann: Any) -> str: + """Spell the annotation exactly as written.""" + if ann is None: + return "Any" + if isinstance(ann, _t._GenericAlias): + return str(ann).replace("typing.", "") + return ann.__name__ if hasattr(ann, "__name__") else str(ann) + + +T = TypeVar("T", bound="Tidy3dBaseModel") + + +def field_allows_scalar(field: FieldInfo) -> bool: + annotation = field.annotation + + def allows_scalar(a: Any) -> bool: + origin = get_origin(a) + if origin in (Union, UnionType): + args = (arg for arg in get_args(a) if arg is not type(None)) + return any(allows_scalar(arg) for arg in args) + if origin is not None: + return False + return isinstance(a, type) and issubclass(a, (float, int, np.generic)) + + return allows_scalar(annotation) + + +@total_ordering +class Tidy3dBaseModel(BaseModel): + """Base pydantic model that all Tidy3d components inherit from. + Defines configuration for handling data structures + as well as methods for importing, exporting, and hashing tidy3d objects. + For more details on pydantic base models, see: + `Pydantic Models `_ + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + validate_default=True, + validate_assignment=True, + populate_by_name=True, + ser_json_inf_nan="strings", + extra="forbid", + frozen=True, + ) + + attrs: dict = Field( + default_factory=dict, + title="Attributes", + description="Dictionary storing arbitrary metadata for a Tidy3D object. " + "This dictionary can be freely used by the user for storing data without affecting the " + "operation of Tidy3D as it is not used internally. " + "Note that, unlike regular Tidy3D fields, ``attrs`` are mutable. " + "For example, the following is allowed for setting an ``attr`` ``obj.attrs['foo'] = bar``. " + "Also note that Tidy3D will raise a ``TypeError`` if ``attrs`` contain objects " + "that can not be serialized. One can check if ``attrs`` are serializable " + "by calling ``obj.model_dump_json()``.", + ) + + _cached_properties: dict = PrivateAttr(default_factory=dict) + _has_tracers: Optional[bool] = PrivateAttr(default=None) + + @field_validator("name", check_fields=False) + @classmethod + def _validate_name_no_special_characters(cls: type[T], name: Optional[str]) -> Optional[str]: + if name is None: + return name + for character in FORBID_SPECIAL_CHARACTERS: + if character in name: + raise ValueError( + f"Special character '{character}' not allowed in component name {name}." + ) + return name + + def __init_subclass__(cls: type[T], **kwargs: Any) -> None: + """Injects a constant discriminator field before Pydantic builds the model. + + Adds + type: Literal[""] = "" + to every concrete subclass so it can participate in a + `Field(discriminator="type")` union without manual boilerplate. + + Must run *before* `super().__init_subclass__()`; that call lets Pydantic + see the injected field during its normal schema/validator generation. 
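+
+        For example, a subclass declared as ``class Medium(Tidy3dBaseModel)``
+        automatically gains ``type: Literal["Medium"] = "Medium"`` and is
+        registered under that tag in ``TYPE_TO_CLASS_MAP``.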
+ See also: https://peps.python.org/pep-0487/ + """ + tag = cls.__name__ + cls.__annotations__[TYPE_TAG_STR] = Literal[tag] + setattr(cls, TYPE_TAG_STR, tag) + TYPE_TO_CLASS_MAP[tag] = cls + + if "__tidy3d_end_capture__" not in cls.__dict__: + + @model_validator(mode="after") + def __tidy3d_end_capture__(self: T) -> T: + if log._capture: + log.end_capture(self) + return self + + cls.__tidy3d_end_capture__ = __tidy3d_end_capture__ + + super().__init_subclass__(**kwargs) + + @classmethod + def __pydantic_init_subclass__(cls: type[T], **kwargs: Any) -> None: + super().__pydantic_init_subclass__(**kwargs) + + # add docstring once pydantic is done constructing the class + cls.__doc__ = cls.generate_docstring() + + @model_validator(mode="wrap") + @classmethod + def _capture_validation_warnings( + cls: type[T], + data: Any, + handler: ModelWrapValidatorHandler[T], + ) -> T: + if not log._capture: + return handler(data) + + log.begin_capture() + try: + return handler(data) + except Exception: + log.abort_capture() + raise + + def __hash__(self) -> int: + """Hash method.""" + return self._recursive_hash(self) + + @staticmethod + def _recursive_hash(value: Any) -> int: + # Handle Autograd ArrayBoxes + if isinstance(value, ArrayBox): + # Unwrap the underlying numpy array and recurse + return Tidy3dBaseModel._recursive_hash(value._value) + if isinstance(value, np.ndarray): + # numpy arrays are not hashable by default, use byte representation + v_hash = hashlib.md5(value.tobytes()).hexdigest() + return hash(v_hash) + if isinstance(value, (xr.DataArray, xr.Dataset)): + # we choose to not hash data arrays as this would require a lot of careful handling of units, metadata. + # technically this is incorrect, but should never lead to bugs in current implementation + return hash(str(value.__class__.__name__)) + if isinstance(value, str): + # this if-case is necessary because length-1 string would lead to infinite recursion in sequence case below + return hash(value) + if isinstance(value, Sequence): + # this assumes all objects in lists are hashable by default and do not require special handling + v_hash = tuple([Tidy3dBaseModel._recursive_hash(vi) for vi in value]) + return hash(v_hash) + if isinstance(value, dict): + to_hash_list = [] + for k, v in value.items(): + v_hash = Tidy3dBaseModel._recursive_hash(v) + to_hash_list.append((k, v_hash)) + return hash(tuple(to_hash_list)) + if isinstance(value, Tidy3dBaseModel): + # This function needs to take special care because of mutable attributes inside of frozen pydantic models + to_hash_list = [] + for k in type(value).model_fields: + if k == "attrs": + continue + v_hash = Tidy3dBaseModel._recursive_hash(getattr(value, k)) + to_hash_list.append((k, v_hash)) + extra = getattr(value, "__pydantic_extra__", None) + if extra: + for k, v in extra.items(): + v_hash = Tidy3dBaseModel._recursive_hash(v) + to_hash_list.append((k, v_hash)) + # attrs is mutable, use serialized output as safe hashing option + if value.attrs: + attrs_str = value._attrs_digest() + attrs_hash = hash(attrs_str) + to_hash_list.append(("attrs", attrs_hash)) + return hash(tuple(to_hash_list)) + return hash(value) + + def _hash_self(self) -> str: + """Hash this component with ``hashlib`` in a way that is the same every session.""" + bf = io.BytesIO() + self.to_hdf5(bf) + return hashlib.md5(bf.getvalue()).hexdigest() + + @model_validator(mode="before") + @classmethod + def coerce_numpy_scalars_for_model(cls, data: Any) -> Any: + """ + coerce numpy scalars / size-1 arrays to native Python + scalars, 
but only for fields whose annotations allow scalars. + """ + if not isinstance(data, dict): + return data + + for name, field in cls.model_fields.items(): + if name not in data or not field_allows_scalar(field): + continue + + v = data[name] + if isinstance(v, np.generic) or (isinstance(v, np.ndarray) and v.size == 1): + data[name] = v.item() + + return data + + @classmethod + def _get_type_value(cls, obj: dict[str, Any]) -> str: + """Return the type tag from a raw dictionary.""" + if not isinstance(obj, dict): + raise TypeError("Input must be a dict") + try: + type_value = obj[TYPE_TAG_STR] + except KeyError as exc: + raise ValueError(f'Missing "{TYPE_TAG_STR}" in data') from exc + if not isinstance(type_value, str) or not type_value: + raise ValueError(f'Invalid "{TYPE_TAG_STR}" value: {type_value!r}') + return type_value + + @classmethod + def _get_registered_class(cls, type_value: str) -> type[Tidy3dBaseModel]: + try: + return TYPE_TO_CLASS_MAP[type_value] + except KeyError as exc: + raise ValueError(f"Unknown type: {type_value}") from exc + + @classmethod + def _should_dispatch_to(cls, target_cls: type[Tidy3dBaseModel]) -> bool: + """Return True if ``cls`` allows auto-dispatch to ``target_cls``.""" + return issubclass(target_cls, cls) + + @classmethod + def _resolve_dispatch_target(cls, obj: dict[str, Any]) -> type[Tidy3dBaseModel]: + """Determine which subclass should receive ``obj``.""" + type_value = cls._get_type_value(obj) + target_cls = cls._get_registered_class(type_value) + if cls._should_dispatch_to(target_cls): + return target_cls + if target_cls is cls: + return cls + raise ValueError( + f'Cannot parse type "{type_value}" using {cls.__name__}; expected subclass of {cls.__name__}.' + ) + + @classmethod + def _target_cls_from_file( + cls, fname: PathLike, group_path: Optional[str] = None + ) -> type[Tidy3dBaseModel]: + """Peek the file metadata to determine the subclass to instantiate.""" + model_dict = cls.dict_from_file( + fname=fname, + group_path=group_path, + load_data_arrays=False, + ) + return cls._resolve_dispatch_target(model_dict) + + @classmethod + def _model_validate(cls, obj: dict[str, Any], **parse_obj_kwargs: Any) -> Tidy3dBaseModel: + """Dispatch ``obj`` to the correct subclass registered in the type map.""" + target_cls = cls._resolve_dispatch_target(obj) + if target_cls is cls: + return super().model_validate(obj, **parse_obj_kwargs) + return target_cls.model_validate(obj, **parse_obj_kwargs) + + @classmethod + def _validate_model_dict( + cls, model_dict: dict[str, Any], **parse_obj_kwargs: Any + ) -> Tidy3dBaseModel: + """Parse ``model_dict`` while optionally auto-dispatching when called on the base class.""" + if cls is Tidy3dBaseModel: + return cls._model_validate(model_dict, **parse_obj_kwargs) + return cls.model_validate(model_dict, **parse_obj_kwargs) + + def _preprocess_update_values(self, update: Mapping[str, Any]) -> dict[str, Any]: + """Preprocess update values to convert lists to tuples where appropriate. + + This helps avoid Pydantic v2 serialization warnings when using `model_copy()` + with list values for tuple fields. + """ + if not update: + return {} + + def get_tuple_element_type(annotation: Any) -> Optional[type]: + """Get the element type of a tuple annotation if it has one consistent type.""" + origin = get_origin(annotation) + if origin is tuple: + args = get_args(annotation) + if args: + # Check if it's a homogeneous tuple like tuple[bool, ...] or tuple[str, ...] 
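+                    # typing encodes a homogeneous tuple[T, ...] as get_args() == (T, Ellipsis)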
+ if len(args) == 2 and args[1] is ...: + return args[0] + # Check if all elements have the same type + if all(arg == args[0] for arg in args): + return args[0] + return None + + def should_convert_to_tuple(annotation: Any) -> tuple[bool, Optional[type[Any]]]: + """Check if the given annotation represents a tuple type and return element type if any.""" + origin = get_origin(annotation) + + if origin is tuple: + return True, get_tuple_element_type(annotation) + + # Union types containing tuple + if origin is Union: + args = get_args(annotation) + for arg in args: + if get_origin(arg) is tuple: + return True, get_tuple_element_type(arg) + + return False, None + + def convert_value(value: Any, field_info: FieldInfo) -> Any: + """Convert value based on field type information.""" + annotation = field_info.annotation + + # Handle list/tuple to tuple conversion with proper element types + is_tuple, element_type = should_convert_to_tuple(annotation) + + # Check if value is a numpy array and needs to be converted to tuple + try: + import numpy as np + + if isinstance(value, np.ndarray) and is_tuple: + # Convert numpy array to list first + value = value.tolist() + except ImportError: + pass + + # Handle autograd SequenceBox - convert to tuple + if ( + is_tuple + and hasattr(value, "__class__") + and value.__class__.__name__ == "SequenceBox" + ): + # SequenceBox is iterable, so convert it to tuple + return tuple(value) + + if isinstance(value, (list, tuple)) and is_tuple: + # Convert elements based on element type + if element_type is bool: + # Convert integers to booleans + value = [bool(item) if isinstance(item, int) else item for item in value] + elif element_type is str: + # Ensure all elements are strings + value = [str(item) if not isinstance(item, str) else item for item in value] + else: + # Check if it's a numpy array or contains numpy types + try: + import numpy as np + + if any(isinstance(item, np.generic) for item in value): + # Convert numpy types to Python types + value = [ + item.item() if isinstance(item, np.generic) else item + for item in value + ] + except ImportError: + pass + return tuple(value) + + # Handle int to bool conversion + if annotation is bool and isinstance(value, int): + return bool(value) + + # Handle dict to Tidy3dBaseModel conversion + if isinstance(value, dict): + # Check if the annotation is a Tidy3dBaseModel subclass + origin = get_origin(annotation) + if origin is None: + # Not a generic type, check if it's a direct subclass + try: + if isinstance(annotation, type) and issubclass(annotation, Tidy3dBaseModel): + return annotation(**value) + except (TypeError, AttributeError): + pass + elif origin is Union: + # For Union types, try to convert to the first matching Tidy3dBaseModel type + args = get_args(annotation) + for arg in args: + try: + if isinstance(arg, type) and issubclass(arg, Tidy3dBaseModel): + return arg(**value) + except (TypeError, AttributeError, ValueError): + continue + + return value + + processed = {} + for field_name, value in update.items(): + if field_name in type(self).model_fields: + field_info = type(self).model_fields[field_name] + processed[field_name] = convert_value(value, field_info) + else: + processed[field_name] = value + + return processed + + def copy( + self, + deep: bool = True, + *, + validate: bool = True, + update: Optional[Mapping[str, Any]] = None, + ) -> Self: + """Return a copy of the model. + + Parameters + ---------- + deep : bool = True + Whether to make a deep copy first (same as v1). 
+ validate : bool = True + If ``True``, run full Pydantic validation on the copied data. + update : Optional[Mapping[str, Any]] = None + Optional mapping of fields to overwrite (passed straight + through to ``model_copy(update=...)``). + """ + if update and self.model_config.get("extra") == "forbid": + invalid = set(update) - set(type(self).model_fields) + if invalid: + raise KeyError(f"'{self.type}' received invalid fields on copy: {invalid}") + + # preprocess update values to convert lists to tuples where appropriate + if update: + update = self._preprocess_update_values(update) + + new_model = self.model_copy(deep=deep, update=update) + + if validate: + return self.__class__.model_validate(new_model.model_dump()) + else: + # make sure cache is always cleared + new_model._cached_properties = {} + + new_model._has_tracers = None + return new_model + + def updated_copy( + self, + path: Optional[str] = None, + *, + deep: bool = True, + validate: bool = True, + **kwargs: Any, + ) -> Self: + """Make copy of a component instance with ``**kwargs`` indicating updated field values. + + Note + ---- + If ``path`` is supplied, applies the updated copy with the update performed on the sub- + component corresponding to the path. For indexing into a tuple or list, use the integer + value. + + Example + ------- + >>> sim = simulation.updated_copy(size=new_size, path=f"structures/{i}/geometry") # doctest: +SKIP + """ + if not path: + return self.copy(deep=deep, validate=validate, update=kwargs) + + path_parts = path.split("/") + field_name, *rest = path_parts + + try: + sub_component = getattr(self, field_name) + except AttributeError as exc: + raise AttributeError( + f"Could not find field '{field_name}' in path '{path}'. " + f"Available top-level fields: {tuple(type(self).model_fields)}." + ) from exc + + if isinstance(sub_component, (list, tuple)): + try: + index = int(rest[0]) + except (IndexError, ValueError): + raise ValueError( + f"Expected integer index into '{field_name}' in path '{path}'." + ) from None + sub_component_list = list(sub_component) + sub_component_list[index] = sub_component_list[index].updated_copy( + path="/".join(rest[1:]), + deep=deep, + validate=validate, + **kwargs, + ) + new_value = type(sub_component)(sub_component_list) + else: + new_value = sub_component.updated_copy( + path="/".join(rest), + deep=deep, + validate=validate, + **kwargs, + ) + + return self.copy(deep=deep, validate=validate, update={field_name: new_value}) + + @staticmethod + def _core_model_traversal( + current_obj: Any, current_path_segments: tuple[str, ...] + ) -> Iterator[tuple[Self, tuple[str, ...]]]: + """ + Recursively traverses a model structure yielding Tidy3dBaseModel instances and their paths. + + This is an internal helper method used by :meth:`find_paths` and :meth:`find_submodels` + to navigate nested :class:`Tidy3dBaseModel` structures. + + Parameters + ---------- + current_obj : Any + The current object in the traversal, which can be a :class:`Tidy3dBaseModel`, + list, tuple, or other type. + current_path_segments : tuple[str, ...] + A tuple of strings representing the path segments from the initial model + to the ``current_obj``. + + Returns + ------- + Iterator[tuple[Self, tuple[str, ...]]] + An iterator yielding tuples, where the first element is a found :class:`Tidy3dBaseModel` instance + and the second is a tuple of strings representing the path to that instance + from the initial object. The path for the top-level model itself will be an empty tuple. 
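+
+            Tuples and lists are traversed with each element's integer index
+            (as a string) appended to the path, and the auto-injected ``type``
+            discriminator field is skipped during traversal.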
+ """ + if isinstance(current_obj, Tidy3dBaseModel): + yield current_obj, current_path_segments + + for field_name in type(current_obj).model_fields: + if ( + field_name == "type" + and getattr(current_obj, field_name, None) == current_obj.__class__.__name__ + ): + continue + + field_value = getattr(current_obj, field_name) + yield from Tidy3dBaseModel._core_model_traversal( + field_value, (*current_path_segments, field_name) + ) + elif isinstance(current_obj, (list, tuple)): + for index, item in enumerate(current_obj): + yield from Tidy3dBaseModel._core_model_traversal( + item, (*current_path_segments, str(index)) + ) + + def find_paths(self, target_field_name: str, target_field_value: Any = Undefined) -> list[str]: + """ + Finds paths to nested model instances that have a specific field, optionally matching a value. + + The paths are string representations like ``"structures/0/geometry"``, designed for direct + use with the :meth:`updated_copy` method to modify specific parts of this model. + An empty string ``""`` in the returned list indicates that this model instance + itself (the one ``find_paths`` is called on) matches the criteria. + + Parameters + ---------- + target_field_name : str + The name of the attribute (field) to search for within nested + :class:`Tidy3dBaseModel` instances. For example, ``"name"`` or ``"permittivity"``. + target_field_value : Any, optional + If provided, only paths to model instances where ``target_field_name`` also has this + specific value will be returned. If omitted, paths are returned if the + ``target_field_name`` exists, regardless of its value. + + Returns + ------- + list[str] + A sorted list of unique string paths. Each path points to a + :class:`Tidy3dBaseModel` instance that possesses the ``target_field_name`` + (and optionally matches ``target_field_value``). + + Example + ------- + >>> # Assume 'sim' is a Tidy3D simulation object + >>> # Find all geometries named "waveguide" + >>> paths = sim.find_paths(target_field_name="name", target_field_value="waveguide") # doctest: +SKIP + >>> # paths might be ['structures/0', 'structures/3'] + >>> # Update the size of the first found "waveguide" + >>> new_sim = sim.updated_copy(path=paths[0], size=(1.0, 0.5, 0.22)) # doctest: +SKIP + """ + found_paths_set = set() + + for sub_model_instance, path_segments_to_sub_model in Tidy3dBaseModel._core_model_traversal( + self, () + ): + if target_field_name in type(sub_model_instance).model_fields: + passes_value_filter = True + if target_field_value is not Undefined: + actual_value = getattr(sub_model_instance, target_field_name) + if actual_value != target_field_value: + passes_value_filter = False + + if passes_value_filter: + path_str = "/".join(path_segments_to_sub_model) + found_paths_set.add(path_str) + + return sorted(found_paths_set) + + def find_submodels(self, target_type: Self) -> list[Self]: + """ + Finds all unique nested instances of a specific Tidy3D model type within this model. + + This method traverses the model structure and collects all instances that are of + the ``target_type`` (e.g., :class:`~tidy3d.Structure`, :class:`~tidy3d.Medium`, + :class:`~tidy3d.Box`). + Uniqueness is determined by the model's content. The order of models + in the returned list corresponds to their first encounter during a depth-first traversal. + + Parameters + ---------- + target_type : Tidy3dBaseModel + The specific Tidy3D class (e.g., ``Structure``, ``Medium``, ``Box``) to search for. + This class must be a subclass of :class:`Tidy3dBaseModel`. 
+
+        Returns
+        -------
+        list[Tidy3dBaseModel]
+            A list of unique instances found within this model that are of the
+            provided ``target_type``.
+
+        Example
+        -------
+        >>> # Assume 'sim' is a Tidy3D Simulation object
+        >>> # Find all Structure instances within the simulation
+        >>> all_structures = sim.find_submodels(td.Structure) # doctest: +SKIP
+        >>> for struct in all_structures:
+        ...     print(f"Structure: {struct.name}, medium: {struct.medium}") # doctest: +SKIP
+
+        >>> # Find all Box geometries within the simulation
+        >>> all_boxes = sim.find_submodels(td.Box) # doctest: +SKIP
+        >>> for box in all_boxes:
+        ...     print(f"Found Box with size: {box.size}") # doctest: +SKIP
+
+        >>> # Find all Medium instances (useful for checking materials)
+        >>> all_media = sim.find_submodels(td.Medium) # doctest: +SKIP
+        >>> # Note: this finds td.Medium instances, but not td.PECMedium or td.PoleResidue,
+        >>> # since those do not inherit from td.Medium. To match all medium types, search
+        >>> # for a common base such as td.AbstractMedium.
+        """
+        found_models_dict = {}
+
+        for sub_model_candidate, _ in Tidy3dBaseModel._core_model_traversal(self, ()):
+            if isinstance(sub_model_candidate, target_type):
+                if sub_model_candidate not in found_models_dict:
+                    found_models_dict[sub_model_candidate] = True
+
+        return list(found_models_dict.keys())
+
+    def help(self, methods: bool = False) -> None:
+        """Prints a message describing the fields and methods of a :class:`Tidy3dBaseModel`.
+
+        Parameters
+        ----------
+        methods : bool = False
+            Whether to also print out information about the object's methods.
+
+        Example
+        -------
+        >>> simulation.help(methods=True) # doctest: +SKIP
+        """
+        rich.inspect(type(self), methods=methods)
+
+    @classmethod
+    def from_file(
+        cls,
+        fname: PathLike,
+        group_path: Optional[str] = None,
+        lazy: bool = False,
+        on_load: Optional[Callable[[Any], None]] = None,
+        **parse_obj_kwargs: Any,
+    ) -> Self:
+        """Loads a :class:`Tidy3dBaseModel` from .yaml, .json, .hdf5, or .hdf5.gz file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the file to load the :class:`Tidy3dBaseModel` from.
+        group_path : Optional[str] = None
+            Path to a group inside the file to use as the base level. Only for hdf5 files.
+            Starting `/` is optional.
+        lazy : bool = False
+            Whether to load the actual data (``lazy=False``) or return a proxy that loads
+            the data when accessed (``lazy=True``).
+        on_load : Optional[Callable[[Any], None]] = None
+            Callback function executed once the model is fully materialized.
+            With ``lazy=True`` the callback is forwarded to the proxy; otherwise
+            it is called immediately after loading. The callback is invoked with
+            the loaded instance as its sole argument, enabling post-processing
+            such as validation, logging, or warnings checks.
+        **parse_obj_kwargs : Any
+            Keyword arguments passed to pydantic's ``model_validate`` method when
+            loading the model.
+
+        Returns
+        -------
+        Self
+            An instance of the component class calling ``from_file``.
+
+        Example
+        -------
+        >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP
+        """
+        if lazy:
+            target_cls = cls._target_cls_from_file(fname=fname, group_path=group_path)
+            Proxy = _make_lazy_proxy(target_cls, on_load=on_load)
+            return Proxy(fname, group_path, parse_obj_kwargs)
+        model_dict = cls.dict_from_file(fname=fname, group_path=group_path)
+        obj = cls._validate_model_dict(model_dict, **parse_obj_kwargs)
+        if on_load is not None:
+            on_load(obj)
+        return obj
+
+    @classmethod
+    def dict_from_file(
+        cls: type[T],
+        fname: PathLike,
+        group_path: Optional[str] = None,
+        *,
+        load_data_arrays: bool = True,
+    ) -> dict:
+        """Loads a dictionary containing the model from a .yaml, .json, .hdf5, or .hdf5.gz file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the file to load the :class:`Tidy3dBaseModel` from.
+        group_path : str, optional
+            Path to a group inside the file to use as the base level.
+        load_data_arrays : bool = True
+            Whether to load data arrays immediately (hdf5 files only).
+
+        Returns
+        -------
+        dict
+            A dictionary containing the model.
+
+        Example
+        -------
+        >>> sim_dict = Simulation.dict_from_file(fname='folder/sim.json') # doctest: +SKIP
+        """
+        fname_path = Path(fname)
+        extension = _get_valid_extension(fname_path)
+        kwargs = {"fname": fname_path}
+
+        if group_path is not None:
+            if extension in {".hdf5", ".hdf5.gz", ".h5"}:
+                kwargs["group_path"] = group_path
+            else:
+                log.warning("'group_path' provided, but this feature only works with hdf5 files.")
+
+        if extension in {".hdf5", ".hdf5.gz", ".h5"}:
+            kwargs["load_data_arrays"] = load_data_arrays
+
+        converter = {
+            ".json": cls.dict_from_json,
+            ".yaml": cls.dict_from_yaml,
+            ".hdf5": cls.dict_from_hdf5,
+            ".hdf5.gz": cls.dict_from_hdf5_gz,
+            ".h5": cls.dict_from_hdf5,
+        }[extension]
+        return converter(**kwargs)
+
+    def to_file(self, fname: PathLike) -> None:
+        """Exports :class:`Tidy3dBaseModel` instance to .yaml, .json, .hdf5, or .hdf5.gz file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the .yaml, .json, .hdf5, or .hdf5.gz file to save the
+            :class:`Tidy3dBaseModel` to.
+
+        Example
+        -------
+        >>> simulation.to_file(fname='folder/sim.json') # doctest: +SKIP
+        """
+        extension = _get_valid_extension(fname)
+        converter = {
+            ".json": self.to_json,
+            ".yaml": self.to_yaml,
+            ".hdf5": self.to_hdf5,
+            ".hdf5.gz": self.to_hdf5_gz,
+        }[extension]
+        return converter(fname=fname)
+
+    @classmethod
+    def from_json(cls: type[T], fname: PathLike, **model_validate_kwargs: Any) -> Self:
+        """Load a :class:`Tidy3dBaseModel` from .json file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the .json file to load the :class:`Tidy3dBaseModel` from.
+        **model_validate_kwargs
+            Keyword arguments passed to pydantic's ``model_validate`` method.
+
+        Returns
+        -------
+        Self
+            An instance of the component class calling ``from_json``.
+
+        Example
+        -------
+        >>> simulation = Simulation.from_json(fname='folder/sim.json') # doctest: +SKIP
+        """
+        model_dict = cls.dict_from_json(fname=fname)
+        return cls._validate_model_dict(model_dict, **model_validate_kwargs)
+
+    @classmethod
+    def dict_from_json(cls: type[T], fname: PathLike) -> dict:
+        """Load dictionary of the model from a .json file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the .json file to load the :class:`Tidy3dBaseModel` from.
+
+        Returns
+        -------
+        dict
+            A dictionary containing the model.
+ + Example + ------- + >>> sim_dict = Simulation.dict_from_json(fname='folder/sim.json') # doctest: +SKIP + """ + with open(fname, encoding="utf-8") as json_fhandle: + model_dict = json.load(json_fhandle) + return model_dict + + def to_json(self, fname: PathLike) -> None: + """Exports :class:`Tidy3dBaseModel` instance to .json file + + Parameters + ---------- + fname : PathLike + Full path to the .json file to save the :class:`Tidy3dBaseModel` to. + + Example + ------- + >>> simulation.to_json(fname='folder/sim.json') # doctest: +SKIP + """ + export_model = self.to_static() + json_string = export_model.model_dump_json(indent=INDENT_JSON_FILE) + self._warn_if_contains_data(json_string) + path = Path(fname) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as file_handle: + file_handle.write(json_string) + + @classmethod + def from_yaml(cls: type[T], fname: PathLike, **model_validate_kwargs: Any) -> Self: + """Loads :class:`Tidy3dBaseModel` from .yaml file. + + Parameters + ---------- + fname : PathLike + Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. + **model_validate_kwargs + Keyword arguments passed to pydantic's ``model_validate`` method. + + Returns + ------- + Self + An instance of the component class calling `from_yaml`. + + Example + ------- + >>> simulation = Simulation.from_yaml(fname='folder/sim.yaml') # doctest: +SKIP + """ + model_dict = cls.dict_from_yaml(fname=fname) + return cls._validate_model_dict(model_dict, **model_validate_kwargs) + + @classmethod + def dict_from_yaml(cls: type[T], fname: PathLike) -> dict: + """Load dictionary of the model from a .yaml file. + + Parameters + ---------- + fname : PathLike + Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. + + Returns + ------- + dict + A dictionary containing the model. + + Example + ------- + >>> sim_dict = Simulation.dict_from_yaml(fname='folder/sim.yaml') # doctest: +SKIP + """ + with open(fname, encoding="utf-8") as yaml_in: + model_dict = yaml.safe_load(yaml_in) + return model_dict + + def to_yaml(self, fname: PathLike) -> None: + """Exports :class:`Tidy3dBaseModel` instance to .yaml file. + + Parameters + ---------- + fname : PathLike + Full path to the .yaml file to save the :class:`Tidy3dBaseModel` to. + + Example + ------- + >>> simulation.to_yaml(fname='folder/sim.yaml') # doctest: +SKIP + """ + export_model = self.to_static() + json_string = export_model.model_dump_json() + self._warn_if_contains_data(json_string) + model_dict = json.loads(json_string) + path = Path(fname) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w+", encoding="utf-8") as file_handle: + yaml.dump(model_dict, file_handle, indent=INDENT_JSON_FILE) + + @staticmethod + def _warn_if_contains_data(json_str: str) -> None: + """Log a warning if the json string contains data, used in '.json' and '.yaml' file.""" + if any((key in json_str for key, _ in DATA_ARRAY_MAP.items())): + log.warning( + "Data contents found in the model to be written to file. " + "Note that this data will not be included in '.json' or '.yaml' formats. " + "As a result, it will not be possible to load the file back to the original model. " + "Instead, use '.hdf5' extension in filename passed to 'to_file()'." 
+ ) + + @staticmethod + def _construct_group_path(group_path: str) -> str: + """Construct a group path with the leading forward slash if not supplied.""" + + # empty string or None + if not group_path: + return "/" + + # missing leading forward slash + if group_path[0] != "/": + return f"/{group_path}" + + return group_path + + @staticmethod + def get_tuple_group_name(index: int) -> str: + """Get the group name of a tuple element.""" + return str(int(index)) + + @staticmethod + def get_tuple_index(key_name: str) -> int: + """Get the index into the tuple based on its group name.""" + return int(str(key_name)) + + @classmethod + def tuple_to_dict(cls: type[T], tuple_values: tuple) -> dict: + """How we generate a dictionary mapping new keys to tuple values for hdf5.""" + return {cls.get_tuple_group_name(index=i): val for i, val in enumerate(tuple_values)} + + @classmethod + def get_sub_model( + cls: type[T], group_path: str, model_dict: Union[dict[str, Any], list[Any]] + ) -> dict: + """Get the sub model for a given group path.""" + + for key in group_path.split("/"): + if key: + if isinstance(model_dict, list): + tuple_index = cls.get_tuple_index(key_name=key) + model_dict = model_dict[tuple_index] + else: + model_dict = model_dict[key] + return model_dict + + @staticmethod + def _json_string_key(index: int) -> str: + """Get json string key for string chunk number ``index``.""" + if index: + return f"{JSON_TAG}_{index}" + return JSON_TAG + + @classmethod + def _json_string_from_hdf5(cls: type[T], fname: PathLike) -> str: + """Load the model json string from an hdf5 file.""" + with h5py.File(fname, "r") as f_handle: + num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) + json_string = b"" + for ind in range(num_string_parts): + json_string += f_handle[cls._json_string_key(ind)][()] + return json_string + + @classmethod + def dict_from_hdf5( + cls: type[T], + fname: PathLike, + group_path: str = "", + custom_decoders: Optional[list[Callable]] = None, + load_data_arrays: bool = True, + ) -> dict: + """Loads a dictionary containing the model contents from a .hdf5 file. + + Parameters + ---------- + fname : PathLike + Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from. + group_path : str, optional + Path to a group inside the file to selectively load a sub-element of the model only. + custom_decoders : List[Callable] + List of functions accepting + (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the + value in the model dict after a custom decoding. + + Returns + ------- + dict + Dictionary containing the model. 
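+
+        Notes
+        -----
+        When ``load_data_arrays`` is ``True``, string placeholders matching a
+        known ``DataArray`` tag are replaced by the corresponding arrays loaded
+        from the file; otherwise the placeholders are left in the dictionary.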
+
+        Example
+        -------
+        >>> sim_dict = Simulation.dict_from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP
+        """
+
+        def is_data_array(value: Any) -> bool:
+            """Whether a value is supposed to be a data array based on the contents."""
+            return isinstance(value, str) and value in DATA_ARRAY_MAP
+
+        fname_path = Path(fname)
+
+        def load_data_from_file(model_dict: dict, group_path: str = "") -> None:
+            """For every DataArray item in dictionary, load path of hdf5 group as value."""
+
+            for key, value in model_dict.items():
+                subpath = f"{group_path}/{key}"
+
+                # apply custom decoding to the key value pair and modify model_dict
+                if custom_decoders:
+                    for custom_decoder in custom_decoders:
+                        custom_decoder(
+                            fname=str(fname_path),
+                            group_path=subpath,
+                            model_dict=model_dict,
+                            key=key,
+                            value=value,
+                        )
+
+                # write the path to the element of the json dict where the data_array should be
+                if is_data_array(value):
+                    data_array_type = DATA_ARRAY_MAP[value]
+                    model_dict[key] = data_array_type.from_hdf5(
+                        fname=fname_path, group_path=subpath
+                    )
+                    continue
+
+                # if a list, assign each element a unique key, recurse
+                if isinstance(value, (list, tuple)):
+                    value_dict = cls.tuple_to_dict(tuple_values=value)
+                    load_data_from_file(model_dict=value_dict, group_path=subpath)
+
+                    # handle case of nested list of DataArray elements
+                    val_tuple = list(value_dict.values())
+                    for ind, (model_item, value_item) in enumerate(zip(model_dict[key], val_tuple)):
+                        if is_data_array(model_item):
+                            model_dict[key][ind] = value_item
+
+                # if a dict, recurse
+                elif isinstance(value, dict):
+                    load_data_from_file(model_dict=value, group_path=subpath)
+
+        model_dict = json.loads(cls._json_string_from_hdf5(fname=fname_path))
+        group_path = cls._construct_group_path(group_path)
+        model_dict = cls.get_sub_model(group_path=group_path, model_dict=model_dict)
+        if load_data_arrays:
+            load_data_from_file(model_dict=model_dict, group_path=group_path)
+        return model_dict
+
+    @classmethod
+    def from_hdf5(
+        cls: type[T],
+        fname: PathLike,
+        group_path: str = "",
+        custom_decoders: Optional[list[Callable]] = None,
+        **model_validate_kwargs: Any,
+    ) -> Self:
+        """Loads :class:`Tidy3dBaseModel` instance from .hdf5 file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from.
+        group_path : str, optional
+            Path to a group inside the file to selectively load a sub-element of the model only.
+            Starting `/` is optional.
+        custom_decoders : List[Callable]
+            List of functions accepting
+            (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the
+            value in the model dict after a custom decoding.
+        **model_validate_kwargs
+            Keyword arguments passed to pydantic's ``model_validate`` method.
+
+        Example
+        -------
+        >>> simulation = Simulation.from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP
+        """
+
+        group_path = cls._construct_group_path(group_path)
+        model_dict = cls.dict_from_hdf5(
+            fname=fname,
+            group_path=group_path,
+            custom_decoders=custom_decoders,
+        )
+        return cls._validate_model_dict(model_dict, **model_validate_kwargs)
+
+    def to_hdf5(
+        self,
+        fname: Union[PathLike, io.BytesIO],
+        custom_encoders: Optional[list[Callable]] = None,
+    ) -> None:
+        """Exports :class:`Tidy3dBaseModel` instance to .hdf5 file.
+
+        Parameters
+        ----------
+        fname : Union[PathLike, BytesIO]
+            Full path to the .hdf5 file or buffer to save the :class:`Tidy3dBaseModel` to.
+ custom_encoders : List[Callable] + List of functions accepting (fname: str, group_path: str, value: Any) that take + the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. + + Example + ------- + >>> simulation.to_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP + """ + + export_model = self.to_static() + traced_keys_payload = export_model.attrs.get(TRACED_FIELD_KEYS_ATTR) + + if traced_keys_payload is None: + traced_keys_payload = self.attrs.get(TRACED_FIELD_KEYS_ATTR) + if traced_keys_payload is None: + traced_keys_payload = self._serialized_traced_field_keys() + path = Path(fname) if isinstance(fname, PathLike) else fname + with h5py.File(path, "w") as f_handle: + json_str = export_model.model_dump_json() + for ind in range(ceil(len(json_str) / MAX_STRING_LENGTH)): + ind_start = int(ind * MAX_STRING_LENGTH) + ind_stop = min(int(ind + 1) * MAX_STRING_LENGTH, len(json_str)) + f_handle[self._json_string_key(ind)] = json_str[ind_start:ind_stop] + + def add_data_to_file(data_dict: dict, group_path: str = "") -> None: + """For every DataArray item in dictionary, write path of hdf5 group as value.""" + + for key, value in data_dict.items(): + # append the key to the path + subpath = f"{group_path}/{key}" + + if custom_encoders: + for custom_encoder in custom_encoders: + custom_encoder(fname=f_handle, group_path=subpath, value=value) + + # write the path to the element of the json dict where the data_array should be + if isinstance(value, xr.DataArray): + value.to_hdf5(fname=f_handle, group_path=subpath) + + # if a tuple, assign each element a unique key + if isinstance(value, (list, tuple)): + value_dict = export_model.tuple_to_dict(tuple_values=value) + add_data_to_file(data_dict=value_dict, group_path=subpath) + + # if a dict, recurse + elif isinstance(value, dict): + add_data_to_file(data_dict=value, group_path=subpath) + + add_data_to_file(data_dict=export_model.model_dump()) + if traced_keys_payload: + f_handle.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload + + @classmethod + def dict_from_hdf5_gz( + cls: type[T], + fname: PathLike, + group_path: str = "", + custom_decoders: Optional[list[Callable]] = None, + load_data_arrays: bool = True, + ) -> dict: + """Loads a dictionary containing the model contents from a .hdf5.gz file. + + Parameters + ---------- + fname : PathLike + Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from. + group_path : str, optional + Path to a group inside the file to selectively load a sub-element of the model only. + custom_decoders : List[Callable] + List of functions accepting + (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the + value in the model dict after a custom decoding. + + Returns + ------- + dict + Dictionary containing the model. 
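+
+        Note
+        ----
+        The archive is first decompressed to a temporary '.hdf5' file, which is deleted
+        after reading; this method is otherwise equivalent to ``dict_from_hdf5``.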
+
+        Example
+        -------
+        >>> sim_dict = Simulation.dict_from_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP
+        """
+        file_descriptor, extracted = tempfile.mkstemp(".hdf5")
+        os.close(file_descriptor)
+        extracted_path = Path(extracted)
+        try:
+            extract_gzip_file(fname, extracted_path)
+            result = cls.dict_from_hdf5(
+                extracted_path,
+                group_path=group_path,
+                custom_decoders=custom_decoders,
+                load_data_arrays=load_data_arrays,
+            )
+        finally:
+            extracted_path.unlink(missing_ok=True)
+
+        return result
+
+    @classmethod
+    def from_hdf5_gz(
+        cls: type[T],
+        fname: PathLike,
+        group_path: str = "",
+        custom_decoders: Optional[list[Callable]] = None,
+        **model_validate_kwargs: Any,
+    ) -> Self:
+        """Loads :class:`Tidy3dBaseModel` instance from .hdf5.gz file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from.
+        group_path : str, optional
+            Path to a group inside the file to selectively load a sub-element of the model only.
+            Starting `/` is optional.
+        custom_decoders : List[Callable]
+            List of functions accepting
+            (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the
+            value in the model dict after a custom decoding.
+        **model_validate_kwargs
+            Keyword arguments passed to pydantic's ``model_validate`` method.
+
+        Example
+        -------
+        >>> simulation = Simulation.from_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP
+        """
+
+        group_path = cls._construct_group_path(group_path)
+        model_dict = cls.dict_from_hdf5_gz(
+            fname=fname,
+            group_path=group_path,
+            custom_decoders=custom_decoders,
+        )
+        return cls._validate_model_dict(model_dict, **model_validate_kwargs)
+
+    def to_hdf5_gz(
+        self,
+        fname: Union[PathLike, io.BytesIO],
+        custom_encoders: Optional[list[Callable]] = None,
+    ) -> None:
+        """Exports :class:`Tidy3dBaseModel` instance to .hdf5.gz file.
+
+        Parameters
+        ----------
+        fname : Union[PathLike, BytesIO]
+            Full path to the .hdf5.gz file or buffer to save the :class:`Tidy3dBaseModel` to.
+        custom_encoders : List[Callable]
+            List of functions accepting (fname: str, group_path: str, value: Any) that take
+            the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``.
+ + Example + ------- + >>> simulation.to_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP + """ + file, decompressed = tempfile.mkstemp(".hdf5") + os.close(file) + try: + self.to_hdf5(decompressed, custom_encoders=custom_encoders) + compress_file_to_gzip(decompressed, fname) + finally: + os.unlink(decompressed) + + def __lt__(self, other: object) -> bool: + """define < for getting unique indices based on hash.""" + return hash(self) < hash(other) + + def __eq__(self, other: object) -> bool: + """Two models are equal when origins match and every public or extra field matches.""" + if not isinstance(other, BaseModel): + return NotImplemented + + self_origin = ( + getattr(self, "__pydantic_generic_metadata__", {}).get("origin") or self.__class__ + ) + other_origin = ( + getattr(other, "__pydantic_generic_metadata__", {}).get("origin") or other.__class__ + ) + if self_origin is not other_origin: + return False + + if getattr(self, "__pydantic_extra__", None) != getattr(other, "__pydantic_extra__", None): + return False + + def _fields_equal(a: Any, b: Any) -> bool: + a = get_static(a) + b = get_static(b) + + if a is b: + return True + if type(a) is not type(b): + if not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))): + return False + if isinstance(a, np.ndarray): + return np.array_equal(a, b) + if isinstance(a, (xr.DataArray, xr.Dataset)): + return a.equals(b) + if isinstance(a, Mapping): + if a.keys() != b.keys(): + return False + return all(_fields_equal(a[k], b[k]) for k in a) + if isinstance(a, Sequence) and not isinstance(a, (str, bytes)): + if len(a) != len(b): + return False + return all(_fields_equal(x, y) for i, (x, y) in enumerate(zip(a, b))) + if isinstance(a, float) and isinstance(b, float) and np.isnan(a) and np.isnan(b): + return True + return a == b + + for name in type(self).model_fields: + if not _fields_equal(getattr(self, name), getattr(other, name)): + return False + + return True + + def _attrs_digest(self) -> str: + """Stable digest of `attrs` using the same JSON encoding rules as pydantic .json().""" + # encoders = getattr(self.__config__, "json_encoders", {}) or {} + + # def _default(o): + # return custom_pydantic_encoder(encoders, o) + + json_str = json.dumps( + self.attrs, + # default=_default, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ) + json_str = make_json_compatible(json_str) + + return hashlib.sha256(json_str.encode("utf-8")).hexdigest() + + @cached_property_guarded(lambda self: self._attrs_digest()) + def _json_string(self) -> str: + """Returns string representation of a :class:`Tidy3dBaseModel`. + + Returns + ------- + str + Json-formatted string holding :class:`Tidy3dBaseModel` data. + """ + return self.model_dump_json(indent=INDENT, exclude_unset=False) + + def _strip_traced_fields( + self, starting_path: tuple[str, ...] = (), include_untraced_data_arrays: bool = False + ) -> AutogradFieldMap: + """Extract a dictionary mapping paths in the model to the data traced by ``autograd``. + + Parameters + ---------- + starting_path : tuple[str, ...] = () + If provided, starts recursing in self.model_dump() from this path of field names + include_untraced_data_arrays : bool = False + Whether to include ``DataArray`` objects without tracers. + We need to include these when returning data, but are unnecessary for structures. 
+
+        Returns
+        -------
+        dict
+            Mapping of traced fields used by ``autograd``.
+        """
+
+        path = tuple(starting_path)
+        if self._has_tracers is False and not include_untraced_data_arrays:
+            return TracedDict()
+
+        field_mapping = {}
+
+        def handle_value(x: Any, path: tuple[str, ...]) -> None:
+            """Recursively update ``field_mapping`` with path to the autograd data."""
+
+            # this is a leaf node that we want to trace, add this path and data to the mapping
+            if isbox(x):
+                field_mapping[path] = x
+
+            # for data arrays, need to be more careful as their tracers are stored in .data
+            elif isinstance(x, xr.DataArray):
+                data = x.data
+                if isbox(data) or any(isbox(el) for el in np.asarray(data).ravel()):
+                    field_mapping[path] = x.data
+                elif include_untraced_data_arrays:
+                    field_mapping[path] = x.data
+
+            # for sequences, add (i,) to the path and handle each value individually
+            elif isinstance(x, (list, tuple)):
+                for i, val in enumerate(x):
+                    handle_value(val, path=(*path, i))
+
+            # for dictionaries, add the (key,) to the path and handle each value individually
+            elif isinstance(x, dict):
+                for key, val in x.items():
+                    handle_value(val, path=(*path, key))
+
+        # recursively parse the dictionary of this object
+        self_dict = self.model_dump(round_trip=True)
+
+        # if a starting path was provided, only look at that subset of the dict
+        if path:
+            for key in path:
+                self_dict = self_dict[key]
+
+        handle_value(self_dict, path=path)
+
+        if field_mapping:
+            if not include_untraced_data_arrays:
+                self._has_tracers = True
+            return TracedDict(field_mapping)
+
+        if not include_untraced_data_arrays and not path:
+            self._has_tracers = False
+        return TracedDict()
+
+    def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self:
+        """Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
+        self_dict = self.model_dump(round_trip=True)
+
+        def insert_value(x: Any, path: tuple[str, ...], sub_dict: dict[str, Any]) -> None:
+            """Insert a value at the given path inside a dictionary."""
+            current_dict = sub_dict
+            for key in path[:-1]:
+                if isinstance(current_dict[key], tuple):
+                    current_dict[key] = list(current_dict[key])
+                current_dict = current_dict[key]
+
+            final_key = path[-1]
+            if isinstance(current_dict[final_key], tuple):
+                current_dict[final_key] = list(current_dict[final_key])
+
+            sub_element = current_dict[final_key]
+            if isinstance(sub_element, xr.DataArray):
+                current_dict[final_key] = sub_element.copy(deep=False, data=x)
+            else:
+                current_dict[final_key] = x
+
+        for path, value in field_mapping.items():
+            insert_value(value, path=path, sub_dict=self_dict)
+
+        return self.__class__.model_validate(self_dict)
+
+    def _serialized_traced_field_keys(
+        self, field_mapping: Optional[AutogradFieldMap] = None
+    ) -> Optional[str]:
+        """Return a serialized, order-independent representation of traced field paths."""
+
+        if field_mapping is None:
+            field_mapping = self._strip_traced_fields()
+        if not field_mapping:
+            return None
+
+        # TODO: remove this deferred import once TracerKeys is decoupled from Tidy3dBaseModel.
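+        # Deferring the import keeps 'base.py' importable without loading the autograd
+        # field-map machinery at module import time; 'TracerKeys' is only used here to
+        # give the traced paths a stable, order-independent JSON form.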
+ from tidy3d._common.components.autograd.field_map import TracerKeys + + tracer_keys = TracerKeys.from_field_mapping(field_mapping) + return tracer_keys.model_dump_json() + + def to_static(self) -> Self: + """Version of object with all autograd-traced fields removed.""" + + if self._has_tracers is False: + return self + + # get dictionary of all traced fields + field_mapping = self._strip_traced_fields() + + # shortcut to just return self if no tracers found, for performance + if not field_mapping: + self._has_tracers = False + return self + + # convert all fields to static values + field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()} + + # insert the static values into a copy of self + static_self = self._insert_traced_fields(field_mapping_static) + static_self._has_tracers = False + return static_self + + @classmethod + def generate_docstring(cls) -> str: + """Generates a docstring for a Tidy3D model.""" + + doc = "" + + # keep any pre-existing class description + original_docstrings = [] + if cls.__doc__: + original_docstrings = cls.__doc__.split("\n\n") + doc += original_docstrings.pop(0) + original_docstrings = "\n\n".join(original_docstrings) + + # parameters + doc += "\n\n Parameters\n ----------\n" + for field_name, field in cls.model_fields.items(): # v2 + if field_name == TYPE_TAG_STR: + continue + + # type + ann = getattr(field, "annotation", None) + data_type = _fmt_ann_literal(ann) + + # default / default_factory + default_val = ( + f"{field.default_factory.__name__}()" + if field.default_factory is not None + else field.get_default(call_default_factory=False) + ) + + if isinstance(default_val, BaseModel) or ( + "=" in str(default_val) if default_val is not None else False + ): + default_val = ", ".join( + str(f"{default_val.__class__.__name__}({default_val})").split(" ") + ) + + default_str = "" if field.is_required() else f" = {default_val}" + doc += f" {field_name} : {data_type}{default_str}\n" + + parts = [] + + # units + units = None + extra = getattr(field, "json_schema_extra", None) + if isinstance(extra, dict): + units = extra.get("units") + if units is None and hasattr(field, "metadata"): + for meta in field.metadata: + if isinstance(meta, dict) and "units" in meta: + units = meta["units"] + break + if units is not None: + unitstr = ( + f"({', '.join(str(u) for u in units)})" + if isinstance(units, (list, tuple)) + else str(units) + ) + parts.append(f"[units = {unitstr}].") + + # description + desc = getattr(field, "description", None) + if desc: + parts.append(desc) + + if parts: + doc += " " + " ".join(parts) + "\n" + + if original_docstrings: + doc += "\n" + original_docstrings + doc += "\n" + + return doc + + def get_submodels_by_hash(self) -> dict[int, list[Union[str, tuple[str, int]]]]: + """ + Return a mapping ``{hash(submodel): [field_path, ...]}`` for every + nested ``Tidy3dBaseModel`` inside this model. 
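+
+        Example
+        -------
+        >>> # Illustrative sketch ('medium'/'structures' are hypothetical field names):
+        >>> # a shared submodel maps its hash to every path where it appears, e.g.
+        >>> # {hash: ['medium', ('structures', 0)]}.
+        >>> sim.get_submodels_by_hash()  # doctest: +SKIP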
+ """ + out = defaultdict(list) + + for name in type(self).model_fields: + value = getattr(self, name) + + if isinstance(value, Tidy3dBaseModel): + out[hash(value)].append(name) + continue + + if isinstance(value, (list, tuple)): + for idx, item in enumerate(value): + if isinstance(item, Tidy3dBaseModel): + out[hash(item)].append((name, idx)) + + elif isinstance(value, np.ndarray): + for idx, item in enumerate(value.flat): + if isinstance(item, Tidy3dBaseModel): + out[hash(item)].append((name, idx)) + + elif isinstance(value, dict): + for k, item in value.items(): + if isinstance(item, Tidy3dBaseModel): + out[hash(item)].append((name, k)) + + return dict(out) + + @staticmethod + def _scientific_notation( + min_val: float, max_val: float, min_digits: int = 4 + ) -> tuple[str, str]: + """ + Convert numbers to scientific notation, displaying only digits up to the point of difference, + with a minimum number of significant digits specified by `min_digits`. + """ + + def to_sci(value: float, exponent: int, precision: int) -> str: + normalized_value = value / (10**exponent) + return f"{normalized_value:.{precision}f}e{exponent}" + + if min_val == 0 or max_val == 0: + return f"{min_val:.0e}", f"{max_val:.0e}" + + exponent_min = math.floor(math.log10(abs(min_val))) + exponent_max = math.floor(math.log10(abs(max_val))) + + common_exponent = min(exponent_min, exponent_max) + normalized_min = min_val / (10**common_exponent) + normalized_max = max_val / (10**common_exponent) + + if normalized_min == normalized_max: + precision = min_digits + else: + precision = 0 + while round(normalized_min, precision) == round(normalized_max, precision): + precision += 1 + + precision = max(precision, min_digits) + + sci_min = to_sci(min_val, common_exponent, precision) + sci_max = to_sci(max_val, common_exponent, precision) + + return sci_min, sci_max + + def __rich_repr__(self) -> rich.repr.Result: + """How to pretty-print instances of ``Tidy3dBaseModel``.""" + for name in type(self).model_fields: + value = getattr(self, name) + + # don't print the type field we add to the models + if name == "type": + continue + + # skip `attrs` if it's an empty dictionary + if name == "attrs" and isinstance(value, dict) and not value: + continue + + yield name, value + + def __str__(self) -> str: + """Return a pretty-printed string representation of the model.""" + from io import StringIO + + from rich.console import Console + + sio = StringIO() + console = Console(file=sio) + console.print(self) + output = sio.getvalue() + return output.rstrip("\n") + + +def _make_lazy_proxy( + target_cls: type[Tidy3dBaseModel], + on_load: Optional[Callable[[Any], None]] = None, +) -> type[Tidy3dBaseModel]: + """ + Return a lazy-loading proxy subclass of ``target_cls``. + + Parameters + ---------- + target_cls : type + Must implement ``dict_from_file`` and ``model_validate``. + on_load : Optional[Callable[[Any], None]] = None + A function to call with the fully loaded instance once loaded. + + Returns + ------- + type + A class named ``Proxy`` with init args: + ``(fname, group_path, parse_obj_kwargs)``. 
+ """ + + proxy_name = f"{target_cls.__name__}Proxy" + + class _LazyProxy(target_cls): # type: ignore[misc] + def __init__( + self, + fname: PathLike, + group_path: Optional[str], + parse_obj_kwargs: Any, + ) -> None: + # store lazy context only in __dict__ + object.__setattr__(self, "_lazy_fname", Path(fname)) + object.__setattr__(self, "_lazy_group_path", group_path) + object.__setattr__(self, "_lazy_parse_obj_kwargs", dict(parse_obj_kwargs or {})) + + def copy(self, **kwargs: Any) -> Self: + """Return another lazy proxy instead of materializing.""" + return _LazyProxy( + object.__getattribute__(self, "_lazy_fname"), + object.__getattribute__(self, "_lazy_group_path"), + { + **object.__getattribute__(self, "_lazy_parse_obj_kwargs"), + **kwargs, + }, + ) + + def __getattribute__(self, name: str) -> Any: + # Attributes that must *not* trigger materialization + if name.startswith("_lazy_") or name in { + "__class__", + "__dict__", + "__weakref__", + "__post_root_validators__", + "__pydantic_decorators__", + "copy", # don't materialize just for .copy() + }: + return object.__getattribute__(self, name) + + d = object.__getattribute__(self, "__dict__") + + if "_lazy_fname" in d: + fname = d["_lazy_fname"] + group_path = d["_lazy_group_path"] + kwargs = d["_lazy_parse_obj_kwargs"] + + # Build the real instance + model_dict = target_cls.dict_from_file(fname=fname, group_path=group_path) + target = target_cls._validate_model_dict(model_dict, **kwargs) + + d.clear() + d.update(target.__dict__) + + object.__setattr__(self, "__class__", target.__class__) + fields_set = getattr(target, "__pydantic_fields_set__", None) + if fields_set is not None: + object.__setattr__(self, "__pydantic_fields_set__", set(fields_set)) + + pvt = getattr(target, "__pydantic_private__", None) + if pvt is not None: + object.__setattr__(self, "__pydantic_private__", pvt) + + if on_load is not None: + on_load(self) + + return object.__getattribute__(self, name) + + _LazyProxy.__name__ = proxy_name + return _LazyProxy diff --git a/tidy3d/_common/components/base_sim/__init__.py b/tidy3d/_common/components/base_sim/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/base_sim/source.py b/tidy3d/_common/components/base_sim/source.py new file mode 100644 index 0000000000..eb6f51deca --- /dev/null +++ b/tidy3d/_common/components/base_sim/source.py @@ -0,0 +1,30 @@ +"""Abstract base for classes that define simulation sources.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional + +from pydantic import Field + +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.components.validators import validate_name_str + +if TYPE_CHECKING: + from tidy3d._common.components.viz import PlotParams + + +class AbstractSource(Tidy3dBaseModel, ABC): + """Abstract base class for all sources.""" + + name: Optional[str] = Field( + None, + title="Name", + description="Optional name for the source.", + ) + + @abstractmethod + def plot_params(self) -> PlotParams: + """Default parameters for plotting a Source object.""" + + _name_validator = validate_name_str() diff --git a/tidy3d/_common/components/data/__init__.py b/tidy3d/_common/components/data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/data/data_array.py b/tidy3d/_common/components/data/data_array.py new file mode 100644 index 0000000000..40b628efcb --- /dev/null +++ 
b/tidy3d/_common/components/data/data_array.py
@@ -0,0 +1,840 @@
+"""Storing tidy3d data at its most fundamental level as xr.DataArray objects."""
+
+from __future__ import annotations
+
+import pathlib
+from abc import ABC
+from typing import TYPE_CHECKING, Any
+
+import autograd.numpy as anp
+import h5py
+import numpy as np
+import xarray as xr
+from autograd.tracer import isbox
+from pydantic_core import core_schema
+from xarray.core import missing
+from xarray.core.indexes import PandasIndex
+from xarray.core.indexing import _outer_to_numpy_indexer
+from xarray.core.utils import OrderedSet, either_dict_or_kwargs
+from xarray.core.variable import as_variable
+
+from tidy3d._common.compat import alignment
+from tidy3d._common.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box
+from tidy3d._common.components.geometry.bound_ops import bounds_contains
+from tidy3d._common.constants import (
+    HERTZ,
+    MICROMETER,
+    RADIAN,
+    SECOND,
+)
+from tidy3d._common.exceptions import DataError, FileError
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+    from os import PathLike
+    from typing import Optional, Union
+
+    from numpy.typing import NDArray
+    from pydantic.annotated_handlers import GetCoreSchemaHandler
+    from pydantic.json_schema import GetJsonSchemaHandler, JsonSchemaValue
+    from xarray.core.types import InterpOptions, Self
+
+    from tidy3d._common.components.autograd import InterpolationType
+    from tidy3d._common.components.types.base import Axis, Bound
+
+# maps the dimension names to their attributes
+DIM_ATTRS = {
+    "x": {"units": MICROMETER, "long_name": "x position"},
+    "y": {"units": MICROMETER, "long_name": "y position"},
+    "z": {"units": MICROMETER, "long_name": "z position"},
+    "f": {"units": HERTZ, "long_name": "frequency"},
+    "t": {"units": SECOND, "long_name": "time"},
+    "direction": {"long_name": "propagation direction"},
+    "mode_index": {"long_name": "mode index"},
+    "eme_port_index": {"long_name": "EME port index"},
+    "eme_cell_index": {"long_name": "EME cell index"},
+    "mode_index_in": {"long_name": "mode index in"},
+    "mode_index_out": {"long_name": "mode index out"},
+    "sweep_index": {"long_name": "sweep index"},
+    "theta": {"units": RADIAN, "long_name": "elevation angle"},
+    "phi": {"units": RADIAN, "long_name": "azimuth angle"},
+    "ux": {"long_name": "normalized kx"},
+    "uy": {"long_name": "normalized ky"},
+    "orders_x": {"long_name": "diffraction order"},
+    "orders_y": {"long_name": "diffraction order"},
+    "face_index": {"long_name": "face index"},
+    "vertex_index": {"long_name": "vertex index"},
+    "axis": {"long_name": "axis"},
+}
+
+
+# name of the DataArray.values in the hdf5 file (xarray's default name too)
+DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__"
+
+DATA_ARRAY_MAP: dict[str, type[DataArray]] = {}
+DATA_ARRAY_TYPES: list[type[DataArray]] = []
+
+
+class DataArray(xr.DataArray):
+    """Subclass of ``xr.DataArray`` that requires _dims to match the keys of the coords."""
+
+    # Always set __slots__ = () to avoid xarray warnings
+    __slots__ = ()
+    # stores an ordered tuple of strings corresponding to the data dimensions
+    _dims = ()
+    # stores a dictionary of attributes corresponding to the data values
+    _data_attrs: dict[str, str] = {}
+
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        super().__init_subclass__(**kwargs)
+        if cls is DataArray:
+            return
+        DATA_ARRAY_MAP[cls.__name__] = cls
+        DATA_ARRAY_TYPES.append(cls)
+
+    def __init__(self, data: Any, *args: Any, **kwargs: Any) -> None:
+        # if data is a vanilla
autograd box, convert to our box + if isbox(data) and not is_tidy_box(data): + data = TidyArrayBox.from_arraybox(data) + # do the same for xr.Variable or xr.DataArray type + elif isinstance(data, (xr.Variable, xr.DataArray)): + if isbox(data.data) and not is_tidy_box(data.data): + data.data = TidyArrayBox.from_arraybox(data.data) + super().__init__(data, *args, **kwargs) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + """Core schema definition for validation & serialization.""" + + def _initial_parser(value: Any) -> Self: + if isinstance(value, cls): + return value + + if isinstance(value, str) and value == cls.__name__: + raise DataError( + f"Trying to load '{cls.__name__}' from string placeholder '{value}' " + "but the actual data is missing. DataArrays are not typically stored " + "in JSON. Load from HDF5 or ensure the DataArray object is provided." + ) + + try: + instance = cls(value) + if not isinstance(instance, cls): + raise TypeError( + f"Constructor for {cls.__name__} returned unexpected type {type(instance)}" + ) + return instance + except Exception as e: + raise ValueError( + f"Could not construct '{cls.__name__}' from input of type '{type(value)}'. " + f"Ensure input is compatible with xarray.DataArray constructor. Original error: {e}" + ) from e + + validation_schema = core_schema.no_info_plain_validator_function(_initial_parser) + validation_schema = core_schema.no_info_after_validator_function( + cls._validate_dims, validation_schema + ) + validation_schema = core_schema.no_info_after_validator_function( + cls._assign_data_attrs, validation_schema + ) + validation_schema = core_schema.no_info_after_validator_function( + cls._assign_coord_attrs, validation_schema + ) + + def _serialize_to_name(instance: Self) -> str: + return type(instance).__name__ + + # serialization behavior: + # - for JSON ('json' mode), use the _serialize_to_name function. + # - for Python ('python' mode), use Pydantic's default for the object type + serialization_schema = core_schema.plain_serializer_function_ser_schema( + _serialize_to_name, + return_schema=core_schema.str_schema(), + when_used="json", + ) + + return core_schema.json_or_python_schema( + python_schema=validation_schema, + json_schema=validation_schema, # Use same validation rules for JSON input + serialization=serialization_schema, + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema_obj: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + """JSON schema definition (defines how it LOOKS in a schema, not the data).""" + return { + "type": "string", + "title": cls.__name__, + "description": ( + f"Placeholder for a '{cls.__name__}' object. Actual data is typically " + "serialized separately (e.g., via HDF5) and not embedded in JSON." 
+ ), + } + + @classmethod + def _validate_dims(cls, val: Self) -> Self: + """Make sure the dims are the same as ``_dims``, then put them in the correct order.""" + if set(val.dims) != set(cls._dims): + raise ValueError( + f"Wrong dims for {cls.__name__}, expected '{cls._dims}', got '{val.dims}'" + ) + if val.dims != cls._dims: + val = val.transpose(*cls._dims) + return val + + @classmethod + def _assign_data_attrs(cls, val: Self) -> Self: + """Assign the correct data attributes to the :class:`.DataArray`.""" + for attr_name, attr_val in cls._data_attrs.items(): + val.attrs[attr_name] = attr_val + return val + + @classmethod + def _assign_coord_attrs(cls, val: Self) -> Self: + """Assign the correct coordinate attributes to the :class:`.DataArray`.""" + target_dims = set(val.dims) & set(cls._dims) & set(val.coords) + for dim in target_dims: + template = DIM_ATTRS.get(dim) + if not template: + continue + + coord_attrs = val.coords[dim].attrs + missing = {k: v for k, v in template.items() if coord_attrs.get(k) != v} + coord_attrs.update(missing) + return val + + def _interp_validator(self, field_name: Optional[str] = None) -> None: + """Ensure the data can be interpolated or selected by checking for duplicate coordinates. + + NOTE + ---- + This does not check every 'DataArray' by default. Instead, when required, this check can be + called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'. + """ + if field_name is None: + field_name = self.__class__.__name__ + + for dim, coord in self.coords.items(): + if coord.to_index().duplicated().any(): + raise DataError( + f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. " + "Duplicates can be removed by running " + f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'." + ) + + def __eq__(self, other: Any) -> bool: + """Whether two data array objects are equal.""" + + if not isinstance(other, xr.DataArray): + return False + + if not self.data.shape == other.data.shape or not np.all(self.data == other.data): + return False + for key, val in self.coords.items(): + if not np.all(np.array(val) == np.array(other.coords[key])): + return False + return True + + @property + def values(self) -> NDArray: + """ + The array's data converted to a numpy.ndarray. + + Returns + ------- + np.ndarray + The values of the DataArray. 
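+
+        Example
+        -------
+        >>> # Illustrative: for autograd-traced data this returns the tracer box
+        >>> # itself instead of converting to a numpy array.
+        >>> fd = FreqDataArray([1.0, 2.0], coords=dict(f=[2e14, 3e14]))  # doctest: +SKIP
+        >>> fd.values  # doctest: +SKIP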
+ """ + return self.data if isbox(self.data) else super().values + + @values.setter + def values(self, value: Any) -> None: + self.variable.values = value + + def to_numpy(self) -> np.ndarray: + """Return `.data` when traced to avoid `dtype=object` NumPy conversion.""" + return self.data if isbox(self.data) else super().to_numpy() + + @property + def abs(self) -> Self: + """Absolute value of data array.""" + return abs(self) + + @property + def angle(self) -> Self: + """Angle or phase value of data array.""" + values = np.angle(self.values) + return type(self)(values, coords=self.coords) + + @property + def is_uniform(self) -> bool: + """Whether each element is of equal value in the data array""" + raw_data = self.data.ravel() + return np.allclose(raw_data, raw_data[0]) + + def to_hdf5(self, fname: Union[PathLike, h5py.File], group_path: str) -> None: + """Save an ``xr.DataArray`` to the hdf5 file or file handle with a given path to the group.""" + if isinstance(fname, (str, pathlib.Path)): + path = pathlib.Path(fname) + path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(path, "w") as f_handle: + self.to_hdf5_handle(f_handle=f_handle, group_path=group_path) + else: + self.to_hdf5_handle(f_handle=fname, group_path=group_path) + + def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None: + """Save an ``xr.DataArray`` to the hdf5 file handle with a given path to the group.""" + sub_group = f_handle.create_group(group_path) + sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data) + for key, val in self.coords.items(): + if val.dtype == " Self: + """Load a DataArray from an hdf5 file with a given path to the group.""" + path = pathlib.Path(fname) + with h5py.File(path, "r") as f: + sub_group = f[group_path] + values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) + coords = {dim: np.array(sub_group[dim]) for dim in cls._dims if dim in sub_group} + for key, val in coords.items(): + if val.dtype == "O": + coords[key] = [byte_string.decode() for byte_string in val.tolist()] + return cls(values, coords=coords, dims=cls._dims) + + @classmethod + def from_file(cls, fname: PathLike, group_path: str) -> Self: + """Load a DataArray from an hdf5 file with a given path to the group.""" + path = pathlib.Path(fname) + if not any(suffix.lower() == ".hdf5" for suffix in path.suffixes): + raise FileError( + f"'DataArray' objects must be written to '.hdf5' format. Given filename of {path}." 
+ ) + return cls.from_hdf5(fname=path, group_path=group_path) + + def __hash__(self) -> int: + """Generate hash value for a :class:`.DataArray` instance, needed for custom components.""" + import dask + + token_str = dask.base.tokenize(self) + return hash(token_str) + + def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: + """Multiply self by value at indices.""" + if isbox(self.data) or isbox(value): + return self._ag_multiply_at(value, coord_name, indices) + + self_mult = self.copy() + self_mult[{coord_name: indices}] *= value + return self_mult + + def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: + """Autograd multiply_at override when tracing.""" + key = {coord_name: indices} + _, index_tuple, _ = self.variable._broadcast_indexes(key) + idx = _outer_to_numpy_indexer(index_tuple, self.data.shape) + mask = np.zeros(self.data.shape, dtype="?") + mask[idx] = True + return self.copy(deep=False, data=anp.where(mask, self.data * value, self.data)) + + def interp( + self, + coords: Mapping[Any, Any] | None = None, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Mapping[str, Any] | None = None, + **coords_kwargs: Any, + ) -> Self: + """Interpolate this DataArray to new coordinate values. + + Parameters + ---------- + coords : Union[Mapping[Any, Any], None] = None + A mapping from dimension names to new coordinate labels. + method : InterpOptions = "linear" + The interpolation method to use. + assume_sorted : bool = False + If True, skip sorting of coordinates. + kwargs : Union[Mapping[str, Any], None] = None + Additional keyword arguments to pass to the interpolation function. + **coords_kwargs : Any + The keyword arguments form of coords. + + Returns + ------- + DataArray + A new DataArray with interpolated values. + + Raises + ------ + KeyError + If any of the specified coordinates are not in the DataArray. + """ + if isbox(self.data): + return self._ag_interp(coords, method, assume_sorted, kwargs, **coords_kwargs) + + return super().interp(coords, method, assume_sorted, kwargs, **coords_kwargs) + + def _ag_interp( + self, + coords: Union[Mapping[Any, Any], None] = None, + method: InterpOptions = "linear", + assume_sorted: bool = False, + kwargs: Union[Mapping[str, Any], None] = None, + **coords_kwargs: Any, + ) -> Self: + """Autograd interp override when tracing over self.data. + + This implementation closely follows the interp implementation of xarray + to match its behavior as closely as possible while supporting autograd. 
+ + See: + - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html + - https://docs.xarray.dev/en/latest/generated/xarray.Dataset.interp.html + """ + if kwargs is None: + kwargs = {} + + ds = self._to_temp_dataset() + + coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") + indexers = dict(ds._validate_interp_indexers(coords)) + + if coords: + # Find shared dimensions between the dataset and the indexers + sdims = ( + set(ds.dims) + .intersection(*[set(nx.dims) for nx in indexers.values()]) + .difference(coords.keys()) + ) + indexers.update({d: ds.variables[d] for d in sdims}) + + obj = ds if assume_sorted else ds.sortby(list(coords)) + + # workaround to get a variable for a dimension without a coordinate + validated_indexers = { + k: (obj._variables.get(k, as_variable((k, range(obj.sizes[k])))), v) + for k, v in indexers.items() + } + + for k, v in validated_indexers.items(): + obj, newidx = missing._localize(obj, {k: v}) + validated_indexers[k] = newidx[k] + + variables = {} + reindex = False + for name, var in obj._variables.items(): + if name in indexers: + continue + dtype_kind = var.dtype.kind + if dtype_kind in "uifc": + # Interpolation for numeric types + var_indexers = {k: v for k, v in validated_indexers.items() if k in var.dims} + variables[name] = self._ag_interp_func(var, var_indexers, method, **kwargs) + elif dtype_kind in "ObU" and (validated_indexers.keys() & var.dims): + # Stepwise interpolation for non-numeric types + reindex = True + elif all(d not in indexers for d in var.dims): + # Keep variables not dependent on interpolated coords + variables[name] = var + + if reindex: + # Reindex for non-numeric types + reindex_indexers = {k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)} + reindexed = alignment.reindex( + obj, + indexers=reindex_indexers, + method="nearest", + exclude_vars=variables.keys(), + ) + indexes = dict(reindexed._indexes) + variables.update(reindexed.variables) + else: + # Get the indexes that are not being interpolated along + indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} + + # Get the coords that also exist in the variables + coord_names = obj._coord_names & variables.keys() + selected = ds._replace_with_new_dims(variables.copy(), coord_names, indexes=indexes) + + # Attach indexer as coordinate + for k, v in indexers.items(): + if v.dims == (k,): + index = PandasIndex(v, k, coord_dtype=v.dtype) + index_vars = index.create_variables({k: v}) + indexes[k] = index + variables.update(index_vars) + else: + variables[k] = v + + # Extract coordinates from indexers + coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) + variables.update(coord_vars) + indexes.update(new_indexes) + + coord_names = obj._coord_names & variables.keys() | coord_vars.keys() + ds = ds._replace_with_new_dims(variables, coord_names, indexes=indexes) + return self._from_temp_dataset(ds) + + @staticmethod + def _ag_interp_func( + var: xr.Variable, + indexes_coords: dict[str, tuple[xr.Variable, xr.Variable]], + method: InterpolationType, + **kwargs: Any, + ) -> xr.Variable: + """ + Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`. + + The implementation follows xarray's interp implementation in xarray.core.missing, + but replaces some of the pre-processing as well as the actual interpolation + function with an autograd-compatible approach. + + + Parameters + ---------- + var : xr.Variable + The variable to be interpolated. 
+ indexes_coords : dict + A dictionary mapping dimension names to coordinate values for interpolation. + method : Literal["nearest", "linear"] + The interpolation method to use. + **kwargs : dict + Additional keyword arguments to pass to the interpolation function. + + Returns + ------- + xr.Variable + The interpolated variable. + """ + if not indexes_coords: + return var.copy() + result = var + for indep_indexes_coords in missing.decompose_interp(indexes_coords): + var = result + + # target dimensions + dims = list(indep_indexes_coords) + x, new_x = zip(*[indep_indexes_coords[d] for d in dims]) + destination = missing.broadcast_variables(*new_x) + + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + + x, new_x = missing._floatize_x(x, new_x) + + permutation = [var.dims.index(dim) for dim in original_dims] + combined_permutation = permutation[-len(x) :] + permutation[: -len(x)] + data = anp.transpose(var.data, combined_permutation) + xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1) + + result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs) + + result = anp.moveaxis(result, 0, -1) + result = anp.reshape(result, result.shape[:-1] + new_x[0].shape) + + result = xr.Variable(new_dims, result, attrs=var.attrs, fastpath=True) + + out_dims: OrderedSet = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indep_indexes_coords[d][1].dims) + else: + out_dims.add(d) + if len(out_dims) > 1: + result = result.transpose(*out_dims) + return result + + def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray: + """Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible + + Constraints / Edge cases: + - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays + - `data` will be reshaped to try to match `self.shape` except where `coords` present + """ + + # make mask + mask = xr.zeros_like(self, dtype=bool) + mask.loc[coords] = True + + # reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis + old_data = self.data + new_shape = list(old_data.shape) + for i, dim in enumerate(self.dims): + if dim in coords: + new_shape[i] = 1 + try: + new_data = data.reshape(new_shape) + except ValueError as e: + raise ValueError( + "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was " + f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this " + "error please raise an issue on the tidy3d github repository with the context." + ) from e + + # broadcast data to repeat data along the selected dimensions to match mask + new_data = new_data + np.zeros_like(old_data) + + new_data = np.where(mask, new_data, old_data) + + return self.copy(deep=True, data=new_data) + + +class FreqDataArray(DataArray): + """Frequency-domain array. 
+
+    Example
+    -------
+    >>> f = [2e14, 3e14]
+    >>> fd = FreqDataArray((1+1j) * np.random.random((2,)), coords=dict(f=f))
+    """
+
+    __slots__ = ()
+    _dims = ("f",)
+
+
+class AbstractSpatialDataArray(DataArray, ABC):
+    """Spatial distribution."""
+
+    __slots__ = ()
+    _dims = ("x", "y", "z")
+    _data_attrs = {"long_name": "field value"}
+
+    @property
+    def _spatially_sorted(self) -> Self:
+        """Check whether sorted and sort if not."""
+        needs_sorting = []
+        for axis in "xyz":
+            axis_coords = self.coords[axis].values
+            if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]):
+                needs_sorting.append(axis)
+
+        if len(needs_sorting) > 0:
+            return self.sortby(needs_sorting)
+
+        return self
+
+    def sel_inside(self, bounds: Bound) -> Self:
+        """Return a new SpatialDataArray that contains the minimal amount of data necessary to
+        cover a spatial region defined by ``bounds``. Note that the returned data is sorted with
+        respect to spatial coordinates.
+
+        Parameters
+        ----------
+        bounds : Tuple[float, float, float], Tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+        Returns
+        -------
+        SpatialDataArray
+            Extracted spatial data array.
+        """
+        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
+            raise DataError(
+                "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'."
+            )
+
+        # make sure data is sorted with respect to coordinates
+        sorted_self = self._spatially_sorted
+
+        inds_list = []
+
+        coords = (sorted_self.x, sorted_self.y, sorted_self.z)
+
+        for coord, smin, smax in zip(coords, bounds[0], bounds[1]):
+            length = len(coord)
+
+            # one point along direction, assume invariance
+            if length == 1:
+                comp_inds = [0]
+            else:
+                # if data does not cover structure at all take the closest index
+                if smax < coord[0]:  # structure is completely on the left side
+                    # take 2 if possible, so that linear interpolation is possible
+                    comp_inds = np.arange(0, max(2, length))
+                elif smin > coord[-1]:  # structure is completely on the right side
+                    # take 2 if possible, so that linear interpolation is possible
+                    comp_inds = np.arange(min(0, length - 2), length)
+                else:
+                    if smin < coord[0]:
+                        ind_min = 0
+                    else:
+                        ind_min = max(0, (coord >= smin).argmax().data - 1)
+
+                    if smax > coord[-1]:
+                        ind_max = length - 1
+                    else:
+                        ind_max = (coord >= smax).argmax().data
+
+                    comp_inds = np.arange(ind_min, ind_max + 1)
+
+            inds_list.append(comp_inds)
+
+        return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2])
+
+    def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool:
+        """Check whether data fully covers the spatial region specified by ``bounds``. If data
+        contains only one point along a given direction, then it is assumed the data is constant
+        along that direction and coverage is not checked.
+
+        Parameters
+        ----------
+        bounds : Tuple[float, float, float], Tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+        rtol : float = 0.0
+            Relative tolerance for comparing bounds.
+        atol : float = 0.0
+            Absolute tolerance for comparing bounds.
+
+        Returns
+        -------
+        bool
+            Full cover check outcome.
+        """
+        if any(bmin > bmax for bmin, bmax in zip(*bounds)):
+            raise DataError(
+                "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'."
+            )
+        xyz = [self.x, self.y, self.z]
+        self_min = [0] * 3
+        self_max = [0] * 3
+        for dim in range(3):
+            coords = xyz[dim]
+            if len(coords) == 1:
+                self_min[dim] = bounds[0][dim]
+                self_max[dim] = bounds[1][dim]
+            else:
+                self_min[dim] = np.min(coords)
+                self_max[dim] = np.max(coords)
+        self_bounds = (tuple(self_min), tuple(self_max))
+        return bounds_contains(self_bounds, bounds, rtol=rtol, atol=atol)
+
+
+class ScalarFieldDataArray(AbstractSpatialDataArray):
+    """Spatial distribution in the frequency-domain.
+
+    Example
+    -------
+    >>> x = [1,2]
+    >>> y = [2,3,4]
+    >>> z = [3,4,5,6]
+    >>> f = [2e14, 3e14]
+    >>> coords = dict(x=x, y=y, z=z, f=f)
+    >>> fd = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords)
+    """
+
+    __slots__ = ()
+    _dims = ("x", "y", "z", "f")
+
+
+class TriangleMeshDataArray(DataArray):
+    """Data of the triangles of a surface mesh as in the STL file format."""
+
+    __slots__ = ()
+    _dims = ("face_index", "vertex_index", "axis")
+    _data_attrs = {"long_name": "surface mesh triangles"}
+
+
+class TimeDataArray(DataArray):
+    """Time-domain array.
+
+    Example
+    -------
+    >>> t = [0, 1e-12, 2e-12]
+    >>> td = TimeDataArray((1+1j) * np.random.random((3,)), coords=dict(t=t))
+    """
+
+    __slots__ = ()
+    _dims = ("t",)
+
+
+class SpatialDataArray(AbstractSpatialDataArray):
+    """Spatial distribution.
+
+    Example
+    -------
+    >>> x = [1,2]
+    >>> y = [2,3,4]
+    >>> z = [3,4,5,6]
+    >>> coords = dict(x=x, y=y, z=z)
+    >>> fd = SpatialDataArray((1+1j) * np.random.random((2,3,4)), coords=coords)
+    """
+
+    __slots__ = ()
+
+    def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> Self:
+        """Reflect data across the plane defined by parameters ``axis`` and ``center`` from right
+        to left. Note that the returned data is sorted with respect to spatial coordinates.
+
+        Parameters
+        ----------
+        axis : Literal[0, 1, 2]
+            Normal direction of the reflection plane.
+        center : float
+            Location of the reflection plane along its normal direction.
+        reflection_only : bool = False
+            Return only reflected data.
+
+        Returns
+        -------
+        SpatialDataArray
+            Data after reflection is performed.
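+
+        Example
+        -------
+        >>> # Illustrative sketch: mirror data across the plane x = 0; the center must
+        >>> # lie on or to the left of the data region.
+        >>> arr = SpatialDataArray(np.random.random((2, 3, 4)), coords=dict(x=[1, 2], y=[2, 3, 4], z=[3, 4, 5, 6]))  # doctest: +SKIP
+        >>> mirrored = arr.reflect(axis=0, center=0)  # doctest: +SKIP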
+ """ + + sorted_self = self._spatially_sorted + + coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values] + data = np.array(sorted_self.data) + + data_left_bound = coords[axis][0] + + if np.isclose(center, data_left_bound): + num_duplicates = 1 + elif center > data_left_bound: + raise DataError("Reflection center must be outside and to the left of the data region.") + else: + num_duplicates = 0 + + if reflection_only: + coords[axis] = 2 * center - coords[axis] + coords_dict = dict(zip("xyz", coords)) + + tmp_arr = SpatialDataArray(sorted_self.data, coords=coords_dict) + + return tmp_arr.sortby("xyz"[axis]) + + shape = np.array(np.shape(data)) + old_len = shape[axis] + shape[axis] = 2 * old_len - num_duplicates + + ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])] + ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])] + + ind_left[axis] = slice(old_len - 1, None, -1) + ind_right[axis] = slice(old_len - num_duplicates, None) + + new_data = np.zeros(shape) + + new_data[ind_left[0], ind_left[1], ind_left[2]] = data + new_data[ind_right[0], ind_right[1], ind_right[2]] = data + + new_coords = np.zeros(shape[axis]) + new_coords[old_len - num_duplicates :] = coords[axis] + new_coords[old_len - 1 :: -1] = 2 * center - coords[axis] + + coords[axis] = new_coords + coords_dict = dict(zip("xyz", coords)) + + return SpatialDataArray(new_data, coords=coords_dict) diff --git a/tidy3d/_common/components/data/dataset.py b/tidy3d/_common/components/data/dataset.py new file mode 100644 index 0000000000..21ee22b6b8 --- /dev/null +++ b/tidy3d/_common/components/data/dataset.py @@ -0,0 +1,207 @@ +"""Collections of DataArrays.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +import numpy as np +import xarray as xr +from pydantic import Field + +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.components.data.data_array import ( + DataArray, + ScalarFieldDataArray, + TimeDataArray, + TriangleMeshDataArray, +) +from tidy3d._common.exceptions import DataError +from tidy3d._common.log import log + +if TYPE_CHECKING: + from typing import Callable + + from tidy3d._common.components.types.base import ArrayLike, Axis + +DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 +DEFAULT_MAX_CELLS_PER_STEP = 10_000 +DEFAULT_TOLERANCE_CELL_FINDING = 1e-6 + + +class Dataset(Tidy3dBaseModel, ABC): + """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" + + @property + def data_arrs(self) -> dict: + """Returns a dictionary of all `:class:`.DataArray`s in the dataset.""" + data_arrs = {} + for key in self.__class__.model_fields.keys(): + data = getattr(self, key) + if isinstance(data, DataArray): + data_arrs[key] = data + return data_arrs + + +class TriangleMeshDataset(Dataset): + """Dataset for storing triangular surface data.""" + + surface_mesh: TriangleMeshDataArray = Field( + title="Surface mesh data", + description="Dataset containing the surface triangles and corresponding face indices " + "for a surface mesh.", + ) + + +class AbstractFieldDataset(Dataset, ABC): + """Collection of scalar fields with some symmetry properties.""" + + @property + @abstractmethod + def field_components(self) -> dict[str, DataArray]: + """Maps the field components to their associated data.""" + + def apply_phase(self, phase: float) -> AbstractFieldDataset: + """Create a copy where all elements are phase-shifted by a value (in radians).""" + if phase == 0.0: + return self + phasor = 
+        field_components_shifted = {}
+        for fld_name, fld_cmp in self.field_components.items():
+            fld_cmp_shifted = phasor * fld_cmp
+            field_components_shifted[fld_name] = fld_cmp_shifted
+        return self.updated_copy(**field_components_shifted)
+
+    @property
+    @abstractmethod
+    def grid_locations(self) -> dict[str, str]:
+        """Maps field components to the string key of their grid locations on the yee lattice."""
+
+    @property
+    @abstractmethod
+    def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]:
+        """Maps field components to their (positive) symmetry eigenvalues."""
+
+    def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArray]) -> Any:
+        """How to package the dictionary of fields computed via self.colocate()."""
+        return xr.Dataset(centered_fields)
+
+    def colocate(self, x: ArrayLike = None, y: ArrayLike = None, z: ArrayLike = None) -> xr.Dataset:
+        """Colocate all of the data at a set of x, y, z coordinates.
+
+        Parameters
+        ----------
+        x : Optional[array-like] = None
+            x coordinates of locations.
+            If not supplied, does not try to colocate on this dimension.
+        y : Optional[array-like] = None
+            y coordinates of locations.
+            If not supplied, does not try to colocate on this dimension.
+        z : Optional[array-like] = None
+            z coordinates of locations.
+            If not supplied, does not try to colocate on this dimension.
+
+        Returns
+        -------
+        xr.Dataset
+            Dataset containing all fields at the same spatial locations.
+            For more details refer to `xarray's Documentation <https://docs.xarray.dev/en/stable/generated/xarray.Dataset.html>`_.
+
+        Note
+        ----
+        For many operations (such as flux calculations and plotting),
+        it is important that the fields are colocated at the same spatial locations.
+        Be sure to apply this method to your field data in those cases.
+        """
+
+        if hasattr(self, "monitor") and self.monitor.colocate:
+            with log as consolidated_logger:
+                consolidated_logger.warning(
+                    "Colocating data that has already been colocated during the solver "
+                    "run. For most accurate results when colocating to custom coordinates set "
+                    "'Monitor.colocate' to 'False' to use the raw data on the Yee grid "
+                    "and avoid double interpolation. Note: the default value was changed to 'True' "
+                    "in Tidy3D version 2.4.0."
+                )
+
+        # convert supplied coordinates to array and assign string mapping to them
+        supplied_coord_map = {k: np.array(v) for k, v in zip("xyz", (x, y, z)) if v is not None}
+
+        # dict of data arrays to combine in dataset and return
+        centered_fields = {}
+
+        # loop through field components
+        for field_name, field_data in self.field_components.items():
+            # loop through x, y, z dimensions and raise an error if only one element along dim
+            for coord_name, coords_supplied in supplied_coord_map.items():
+                coord_data = np.array(field_data.coords[coord_name])
+                if coord_data.size == 1:
+                    raise DataError(
+                        f"colocate given {coord_name}={coords_supplied}, but "
+                        f"data only has one coordinate at {coord_name}={coord_data[0]}. "
+                        "Therefore, can't colocate along this dimension. "
+                        f"Supply {coord_name}=None to skip it."
+ ) + + centered_fields[field_name] = field_data.interp( + **supplied_coord_map, kwargs={"bounds_error": True} + ) + + # combine all centered fields in a dataset + return self.package_colocate_results(centered_fields) + + +class TimeDataset(Dataset): + """Dataset for storing a function of time.""" + + values: TimeDataArray = Field( + title="Values", + description="Values as a function of time.", + ) + + +class AbstractMediumPropertyDataset(AbstractFieldDataset, ABC): + """Dataset storing medium property.""" + + eps_xx: ScalarFieldDataArray = Field( + title="Epsilon xx", + description="Spatial distribution of the xx-component of the relative permittivity.", + ) + eps_yy: ScalarFieldDataArray = Field( + title="Epsilon yy", + description="Spatial distribution of the yy-component of the relative permittivity.", + ) + eps_zz: ScalarFieldDataArray = Field( + title="Epsilon zz", + description="Spatial distribution of the zz-component of the relative permittivity.", + ) + + +class PermittivityDataset(AbstractMediumPropertyDataset): + """Dataset storing the diagonal components of the permittivity tensor. + + Example + ------- + >>> x = [-1,1] + >>> y = [-2,0,2] + >>> z = [-3,-1,1,3] + >>> f = [2e14, 3e14] + >>> coords = dict(x=x, y=y, z=z, f=f) + >>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) + >>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld) + """ + + @property + def field_components(self) -> dict[str, ScalarFieldDataArray]: + """Maps the field components to their associated data.""" + return {"eps_xx": self.eps_xx, "eps_yy": self.eps_yy, "eps_zz": self.eps_zz} + + @property + def grid_locations(self) -> dict[str, str]: + """Maps field components to the string key of their grid locations on the yee lattice.""" + return {"eps_xx": "Ex", "eps_yy": "Ey", "eps_zz": "Ez"} + + @property + def symmetry_eigenvalues(self) -> dict[str, None]: + """Maps field components to their (positive) symmetry eigenvalues.""" + return {"eps_xx": None, "eps_yy": None, "eps_zz": None} diff --git a/tidy3d/_common/components/data/validators.py b/tidy3d/_common/components/data/validators.py new file mode 100644 index 0000000000..fd7ae3d2bf --- /dev/null +++ b/tidy3d/_common/components/data/validators.py @@ -0,0 +1,85 @@ +# special validators for Datasets +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np +from pydantic import field_validator + +from tidy3d._common.components.data.data_array import DataArray, ScalarFieldDataArray +from tidy3d._common.components.data.dataset import AbstractFieldDataset +from tidy3d._common.exceptions import ValidationError + +if TYPE_CHECKING: + from typing import Callable, Optional + + from pydantic_core.core_schema import ValidationInfo + + +# this can't go in validators.py because that file imports dataset.py +def validate_no_nans(*field_names: str) -> Callable[[Any, ValidationInfo], Any]: + """Raise validation error if nans found in Dataset, or other data-containing item.""" + + @field_validator(*field_names) + def no_nans(val: Any, info: ValidationInfo) -> Any: + """Raise validation error if nans found in Dataset, or other data-containing item.""" + + if val is None: + return val + + def error_if_has_nans(value: Any, identifier: Optional[str] = None) -> None: + """Recursively check if value (or iterable) has nans and error if so.""" + + def has_nans(values: Any) -> bool: + """Base case: do these values contain NaN?""" + try: + return np.any(np.isnan(values)) + # if 
this fails for some reason (fails in adjoint, for example), don't check it. + except Exception: + return False + + if isinstance(value, (tuple, list)): + for i, _value in enumerate(value): + error_if_has_nans(_value, identifier=f"[{i}]") + + elif isinstance(value, AbstractFieldDataset): + for key, val in value.field_components.items(): + error_if_has_nans(val, identifier=f".{key}") + + elif isinstance(value, DataArray): + error_if_has_nans(value.values) + + else: + if has_nans(value): + # the identifier is used to make the message more clear by appending some more info + field_name_display = info.field_name + if identifier: + field_name_display += identifier + + raise ValidationError( + f"Found 'NaN' values in '{field_name_display}'. " + "If they were not intended, please double check your construction. " + "If intended, to replace these data points with a value 'x', " + "call 'values = np.nan_to_num(values, nan=x)'." + ) + + error_if_has_nans(val) + return val + + return no_nans + + +def validate_can_interpolate( + *field_names: str, +) -> Callable[[AbstractFieldDataset], AbstractFieldDataset]: + """Make sure the data in ``field_name`` can be interpolated.""" + + @field_validator(*field_names) + def check_fields_interpolate(val: AbstractFieldDataset) -> AbstractFieldDataset: + if isinstance(val, AbstractFieldDataset): + for name, data in val.field_components.items(): + if isinstance(data, ScalarFieldDataArray): + data._interp_validator(name) + return val + + return check_fields_interpolate diff --git a/tidy3d/_common/components/data/zbf.py b/tidy3d/_common/components/data/zbf.py new file mode 100644 index 0000000000..5fa0e5a1a1 --- /dev/null +++ b/tidy3d/_common/components/data/zbf.py @@ -0,0 +1,156 @@ +"""ZBF utilities""" + +from __future__ import annotations + +from struct import unpack + +import numpy as np +from pydantic import Field + +from tidy3d._common.components.base import Tidy3dBaseModel + + +class ZBFData(Tidy3dBaseModel): + """ + Contains data read in from a ``.zbf`` file + """ + + version: int = Field(title="Version", description="File format version number.") + nx: int = Field(title="Samples in X", description="Number of samples in the x direction.") + ny: int = Field(title="Samples in Y", description="Number of samples in the y direction.") + ispol: bool = Field( + title="Is Polarized", + description="``True`` if the beam is polarized, ``False`` otherwise.", + ) + unit: str = Field( + title="Spatial Units", description="Spatial units, either 'mm', 'cm', 'in', or 'm'." + ) + dx: float = Field(title="Grid Spacing, X", description="Grid spacing in x.") + dy: float = Field(title="Grid Spacing, Y", description="Grid spacing in y.") + zposition_x: float = Field( + title="Z Position, X Direction", + description="The pilot beam z position with respect to the pilot beam waist, x direction.", + ) + zposition_y: float = Field( + title="Z Position, Y Direction", + description="The pilot beam z position with respect to the pilot beam waist, y direction.", + ) + rayleigh_x: float = Field( + title="Rayleigh Distance, X Direction", + description="The pilot beam Rayleigh distance in the x direction.", + ) + rayleigh_y: float = Field( + title="Rayleigh Distance, Y Direction", + description="The pilot beam Rayleigh distance in the y direction.", + ) + waist_x: float = Field( + title="Beam Waist, X", description="The pilot beam waist in the x direction." + ) + waist_y: float = Field( + title="Beam Waist, Y", description="The pilot beam waist in the y direction." 
+ ) + wavelength: float = Field(title="Wavelength", description="The wavelength of the beam.") + background_refractive_index: float = Field( + title="Background Refractive Index", + description="The index of refraction in the current medium.", + ) + receiver_eff: float = Field( + title="Receiver Efficiency", + description="The receiver efficiency. Zero if fiber coupling is not computed.", + ) + system_eff: float = Field( + title="System Efficiency", + description="The system efficiency. Zero if fiber coupling is not computed.", + ) + Ex: np.ndarray = Field( + title="Electric Field, X Component", + description="Complex-valued electric field, x component.", + ) + Ey: np.ndarray = Field( + title="Electric Field, Y Component", + description="Complex-valued electric field, y component.", + ) + + def read_zbf(filename: str) -> ZBFData: + """Reads a Zemax Beam File (``.zbf``) + + Parameters + ---------- + filename : str + The file name of the ``.zbf`` file to read. + + Returns + ------- + :class:`.ZBFData` + """ + + # Read the zbf file + with open(filename, "rb") as f: + # Load the header + version, nx, ny, ispol, units = unpack("<5I", f.read(20)) + f.read(16) # unused values + ( + dx, + dy, + zposition_x, + rayleigh_x, + waist_x, + zposition_y, + rayleigh_y, + waist_y, + wavelength, + background_refractive_index, + receiver_eff, + system_eff, + ) = unpack("<12d", f.read(96)) + f.read(64) # unused values + + # read E field + nsamps = 2 * nx * ny + rawx = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) + if ispol: + rawy = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) + + # convert unit key to unit string + map_units = {0: "mm", 1: "cm", 2: "in", 3: "m"} + try: + unit = map_units[units] + except KeyError: + raise KeyError( + f"Invalid units specified in the zbf file (expected '0', '1', '2', or '3', got '{units}')." + ) from None + + # load E field + Ex_real = np.asarray(rawx[0::2]).reshape(nx, ny, order="F") + Ex_imag = np.asarray(rawx[1::2]).reshape(nx, ny, order="F") + if ispol: + Ey_real = np.asarray(rawy[0::2]).reshape(nx, ny, order="F") + Ey_imag = np.asarray(rawy[1::2]).reshape(nx, ny, order="F") + else: + Ey_real = np.zeros((nx, ny)) + Ey_imag = np.zeros((nx, ny)) + + Ex = Ex_real + 1j * Ex_imag + Ey = Ey_real + 1j * Ey_imag + + return ZBFData( + version=version, + nx=nx, + ny=ny, + ispol=ispol, + unit=unit, + dx=dx, + dy=dy, + zposition_x=zposition_x, + zposition_y=zposition_y, + rayleigh_x=rayleigh_x, + rayleigh_y=rayleigh_y, + waist_x=waist_x, + waist_y=waist_y, + wavelength=wavelength, + background_refractive_index=background_refractive_index, + receiver_eff=receiver_eff, + system_eff=system_eff, + Ex=Ex, + Ey=Ey, + ) diff --git a/tidy3d/_common/components/file_util.py b/tidy3d/_common/components/file_util.py new file mode 100644 index 0000000000..51e13f586d --- /dev/null +++ b/tidy3d/_common/components/file_util.py @@ -0,0 +1,83 @@ +"""File compression utilities""" + +from __future__ import annotations + +import gzip +import pathlib +import shutil +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + from io import BytesIO + from os import PathLike + + +def compress_file_to_gzip(input_file: PathLike, output_gz_file: PathLike | BytesIO) -> None: + """ + Compress a file using gzip. + + Parameters + ---------- + input_file : PathLike + The path to the input file. + output_gz_file : PathLike | BytesIO + The path to the output gzip file or an in-memory buffer. 
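+
+    Example
+    -------
+    Minimal usage sketch; the file names below are illustrative placeholders.
+
+    >>> compress_file_to_gzip("simulation.hdf5", "simulation.hdf5.gz")  # doctest: +SKIP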
+ """ + input_file = pathlib.Path(input_file) + with input_file.open("rb") as file_in: + with gzip.open(output_gz_file, "wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def extract_gzip_file(input_gz_file: PathLike, output_file: PathLike) -> None: + """ + Extract a gzip-compressed file. + + Parameters + ---------- + input_gz_file : PathLike + The path to the gzip-compressed input file. + output_file : PathLike + The path to the extracted output file. + """ + input_path = pathlib.Path(input_gz_file) + output_path = pathlib.Path(output_file) + with gzip.open(input_path, "rb") as file_in: + with output_path.open("wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def replace_values(values: Any, search_value: Any, replace_value: Any) -> Any: + """ + Create a copy of ``values`` where any elements equal to ``search_value`` are replaced by ``replace_value``. + + Parameters + ---------- + values : Any + The input object to iterate through. + search_value : Any + An object to match for in ``values``. + replace_value : Any + A replacement object for the matched value in ``values``. + + Returns + ------- + Any + values type object with ``search_value`` terms replaced by ``replace_value``. + """ + # np.all allows for arrays to be evaluated + if np.all(values == search_value): + return replace_value + if isinstance(values, dict): + return { + key: replace_values(val, search_value, replace_value) for key, val in values.items() + } + if isinstance( + values, (tuple, list) + ): # Parts of the nested dict structure include tuples with more dicts + return type(values)(replace_values(val, search_value, replace_value) for val in values) + + # Used to maintain values that are not search_value or containers + return values diff --git a/tidy3d/_common/components/geometry/__init__.py b/tidy3d/_common/components/geometry/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/geometry/base.py b/tidy3d/_common/components/geometry/base.py new file mode 100644 index 0000000000..6a0bc17e85 --- /dev/null +++ b/tidy3d/_common/components/geometry/base.py @@ -0,0 +1,3716 @@ +"""Abstract base classes for geometry.""" + +from __future__ import annotations + +import functools +import pathlib +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +import autograd.numpy as np +import shapely +from pydantic import Field, NonNegativeFloat, field_validator, model_validator + +from tidy3d._common.compat import _package_is_older_than +from tidy3d._common.components.autograd import TracedCoordinate, TracedFloat, TracedSize, get_static +from tidy3d._common.components.base import Tidy3dBaseModel, cached_property +from tidy3d._common.components.geometry.bound_ops import bounds_intersection, bounds_union +from tidy3d._common.components.geometry.float_utils import increment_float +from tidy3d._common.components.transformation import ReflectionFromPlane, RotationAroundAxis +from tidy3d._common.components.types.base import ( + Axis, + ClipOperationType, + MatrixReal4x4, + PlanePosition, + discriminated_union, +) +from tidy3d._common.components.viz import ( + ARROW_LENGTH, + PLOT_BUFFER, + add_ax_if_none, + arrow_style, + equal_aspect, + plot_params_geometry, + polygon_patch, + set_default_labels_and_title, +) +from tidy3d._common.constants import LARGE_NUMBER, MICROMETER, RADIAN, fp_eps, inf +from tidy3d._common.exceptions import ( + SetupError, + Tidy3dError, + Tidy3dImportError, + Tidy3dKeyError, + ValidationError, +) +from 
tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from os import PathLike + from typing import Callable, Union + + import pydantic + from gdstk import Cell + from matplotlib.backend_bases import Event + from matplotlib.patches import FancyArrowPatch + from numpy.typing import ArrayLike, NDArray + from pydantic import NonNegativeInt, PositiveFloat + from typing_extensions import Self + + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.types.base import ( + ArrayFloat2D, + ArrayFloat3D, + Ax, + Bound, + Coordinate, + Coordinate2D, + LengthUnit, + Shapely, + Size, + ) + from tidy3d._common.components.viz import PlotParams, VisualizationSpec + +try: + from matplotlib import patches +except ImportError: + pass + +POLY_GRID_SIZE = 1e-12 +POLY_TOLERANCE_RATIO = 1e-12 +POLY_DISTANCE_TOLERANCE = 8e-12 + + +_shapely_operations = { + "union": shapely.union, + "intersection": shapely.intersection, + "difference": shapely.difference, + "symmetric_difference": shapely.symmetric_difference, +} + +_bit_operations = { + "union": lambda a, b: a | b, + "intersection": lambda a, b: a & b, + "difference": lambda a, b: a & ~b, + "symmetric_difference": lambda a, b: a != b, +} + + +class Geometry(Tidy3dBaseModel, ABC): + """Abstract base class, defines where something exists in space.""" + + @cached_property + def plot_params(self) -> PlotParams: + """Default parameters for plotting a Geometry object.""" + return plot_params_geometry + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. 
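+
+        Example
+        -------
+        Illustrative sketch (not part of the original docstring); ``Box`` is the
+        concrete subclass defined later in this module.
+
+        >>> box = Box(center=(0, 0, 0), size=(2, 2, 2))  # doctest: +SKIP
+        >>> box.inside(np.array([0.0, 5.0]), np.zeros(2), np.zeros(2))  # doctest: +SKIP
+        array([ True, False])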
+ """ + + def point_inside(x: float, y: float, z: float) -> bool: + """Returns ``True`` if a single point ``(x, y, z)`` is inside.""" + shapes_intersect = self.intersections_plane(z=z) + loc = self.make_shapely_point(x, y) + return any(shape.contains(loc) for shape in shapes_intersect) + + arrays = tuple(map(np.array, (x, y, z))) + self._ensure_equal_shape(*arrays) + inside = np.zeros((arrays[0].size,), dtype=bool) + arrays_flat = map(np.ravel, arrays) + for ipt, args in enumerate(zip(*arrays_flat)): + inside[ipt] = point_inside(*args) + return inside.reshape(arrays[0].shape) + + @staticmethod + def _ensure_equal_shape(*arrays: Any) -> None: + """Ensure all input arrays have the same shape.""" + shapes = {np.array(arr).shape for arr in arrays} + if len(shapes) > 1: + raise ValueError("All coordinate inputs (x, y, z) must have the same shape.") + + @staticmethod + def make_shapely_box(minx: float, miny: float, maxx: float, maxy: float) -> shapely.box: + """Make a shapely box ensuring everything untraced.""" + + minx = get_static(minx) + miny = get_static(miny) + maxx = get_static(maxx) + maxy = get_static(maxy) + + return shapely.box(minx, miny, maxx, maxy) + + @staticmethod + def make_shapely_point(minx: float, miny: float) -> shapely.Point: + """Make a shapely Point ensuring everything untraced.""" + + minx = get_static(minx) + miny = get_static(miny) + + return shapely.Point(minx, miny) + + def _inds_inside_bounds( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> tuple[slice, slice, slice]: + """Return slices into the sorted input arrays that are inside the geometry bounds. + + Parameters + ---------- + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + tuple[slice, slice, slice] + Slices into each of the three arrays that are inside the geometry bounds. + """ + bounds = self.bounds + inds_in = [] + for dim, coords in enumerate([x, y, z]): + inds = np.nonzero((bounds[0][dim] <= coords) * (coords <= bounds[1][dim]))[0] + inds_in.append(slice(0, 0) if inds.size == 0 else slice(inds[0], inds[-1] + 1)) + + return tuple(inds_in) + + def inside_meshgrid( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> NDArray[bool]: + """Perform ``self.inside`` on a set of sorted 1D coordinates. Applies meshgrid to the + supplied coordinates before checking inside. + + Parameters + ---------- + + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every + point that is inside the geometry. 
+ """ + + arrays = tuple(map(np.array, (x, y, z))) + if any(arr.ndim != 1 for arr in arrays): + raise ValueError("Each of the supplied coordinates (x, y, z) must be 1D.") + shape = tuple(arr.size for arr in arrays) + is_inside = np.zeros(shape, dtype=bool) + inds_inside = self._inds_inside_bounds(*arrays) + coords_inside = tuple(arr[ind] for ind, arr in zip(inds_inside, arrays)) + coords_3d = np.meshgrid(*coords_inside, indexing="ij") + is_inside[inds_inside] = self.inside(*coords_3d) + return is_inside + + @abstractmethod + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + origin = self.unpop_axis(position, (0, 0), axis=axis) + normal = self.unpop_axis(1, (0, 0), axis=axis) + to_2D = np.eye(4) + if axis != 2: + last, indices = self.pop_axis((0, 1, 2), axis) + to_2D = to_2D[[*list(indices), last, 3]] + return self.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + + def intersections_2dbox(self, plane: Box) -> list[Shapely]: + """Returns list of shapely geometries representing the intersections of the geometry with + a 2D box. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. For more details refer to + `Shapely's Documentation `_. + """ + log.warning( + "'intersections_2dbox()' is deprecated and will be removed in the future. " + "Use 'plane.intersections_with(...)' for the same functionality." 
+        )
+        return plane.intersections_with(self)
+
+    def intersects(
+        self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False]
+    ) -> bool:
+        """Returns ``True`` if two :class:`Geometry` have intersecting `.bounds`.
+
+        Parameters
+        ----------
+        other : :class:`Geometry`
+            Geometry to check intersection with.
+        strict_inequality : tuple[bool, bool, bool] = [False, False, False]
+            For each dimension, defines whether to include equality in the boundaries comparison.
+            If ``False``, equality is included, and two geometries that only intersect at their
+            boundaries will evaluate as ``True``. If ``True``, such geometries will evaluate as
+            ``False``.
+
+        Returns
+        -------
+        bool
+            Whether the rectangular bounding boxes of the two geometries intersect.
+        """
+
+        self_bmin, self_bmax = self.bounds
+        other_bmin, other_bmax = other.bounds
+
+        for smin, omin, smax, omax, strict in zip(
+            self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality
+        ):
+            # are all of other's minimum coordinates less than self's maximum coordinate?
+            in_minus = omin < smax if strict else omin <= smax
+            # are all of other's maximum coordinates greater than self's minimum coordinate?
+            in_plus = omax > smin if strict else omax >= smin
+
+            # if either failed, return False
+            if not all((in_minus, in_plus)):
+                return False
+
+        return True
+
+    def contains(
+        self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False]
+    ) -> bool:
+        """Returns ``True`` if the `.bounds` of ``other`` are contained within the
+        `.bounds` of ``self``.
+
+        Parameters
+        ----------
+        other : :class:`Geometry`
+            Geometry to check containment with.
+        strict_inequality : tuple[bool, bool, bool] = [False, False, False]
+            For each dimension, defines whether to include equality in the boundaries comparison.
+            If ``False``, equality will be considered as contained. If ``True``, ``other``'s
+            bounds must be strictly within the bounds of ``self``.
+
+        Returns
+        -------
+        bool
+            Whether the rectangular bounding box of ``other`` is contained within the bounding
+            box of ``self``.
+        """
+
+        self_bmin, self_bmax = self.bounds
+        other_bmin, other_bmax = other.bounds
+
+        for smin, omin, smax, omax, strict in zip(
+            self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality
+        ):
+            # are all of other's minimum coordinates greater than self's minimum coordinate?
+            in_minus = omin > smin if strict else omin >= smin
+            # are all of other's maximum coordinates less than self's maximum coordinate?
+            in_plus = omax < smax if strict else omax <= smax
+
+            # if either failed, return False
+            if not all((in_minus, in_plus)):
+                return False
+
+        return True
+
+    def intersects_plane(
+        self, x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None
+    ) -> bool:
+        """Whether self intersects plane specified by one non-None value of x,y,z.
+
+        Parameters
+        ----------
+        x : float = None
+            Position of plane in x direction, only one of x,y,z can be specified to define plane.
+        y : float = None
+            Position of plane in y direction, only one of x,y,z can be specified to define plane.
+        z : float = None
+            Position of plane in z direction, only one of x,y,z can be specified to define plane.
+
+        Returns
+        -------
+        bool
+            Whether this geometry intersects the plane.
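+
+        Example
+        -------
+        Illustrative sketch (not part of the original docstring); the box below has
+        z bounds ``(-1, 1)``.
+
+        >>> box = Box(center=(0, 0, 0), size=(2, 2, 2))  # doctest: +SKIP
+        >>> box.intersects_plane(z=0.5)  # doctest: +SKIP
+        True
+        >>> box.intersects_plane(z=2.0)  # doctest: +SKIP
+        False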
+ """ + + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + return self.intersects_axis_position(axis, position) + + def intersects_axis_position(self, axis: int, position: float) -> bool: + """Whether self intersects plane specified by a given position along a normal axis. + + Parameters + ---------- + axis : int = None + Axis normal to the plane. + position : float = None + Position of plane along the normal axis. + + Returns + ------- + bool + Whether this geometry intersects the plane. + """ + return self.bounds[0][axis] <= position <= self.bounds[1][axis] + + @cached_property + @abstractmethod + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + + @staticmethod + def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the intersection of two bounds.""" + return bounds_intersection(bounds1, bounds2) + + @staticmethod + def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the union of two bounds.""" + return bounds_union(bounds1, bounds2) + + @cached_property + def bounding_box(self) -> Box: + """Returns :class:`Box` representation of the bounding box of a :class:`Geometry`. + + Returns + ------- + :class:`Box` + Geometric object representing bounding box. + """ + return Box.from_bounds(*self.bounds) + + @cached_property + def zero_dims(self) -> list[Axis]: + """A list of axes along which the :class:`Geometry` is zero-sized based on its bounds.""" + zero_dims = [] + for dim in range(3): + if self.bounds[1][dim] == self.bounds[0][dim]: + zero_dims.append(dim) + return zero_dims + + def _pop_bounds(self, axis: Axis) -> tuple[Coordinate2D, tuple[Coordinate2D, Coordinate2D]]: + """Returns min and max bounds in plane normal to and tangential to ``axis``. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + tuple[float, float], tuple[tuple[float, float], tuple[float, float]] + Bounds along axis and a tuple of bounds in the ordered planar coordinates. + Packed as ``(zmin, zmax), ((xmin, ymin), (xmax, ymax))``. + """ + b_min, b_max = self.bounds + zmin, (xmin, ymin) = self.pop_axis(b_min, axis=axis) + zmax, (xmax, ymax) = self.pop_axis(b_max, axis=axis) + return (zmin, zmax), ((xmin, ymin), (xmax, ymax)) + + @staticmethod + def _get_center(pt_min: float, pt_max: float) -> float: + """Returns center point based on bounds along dimension.""" + if np.isneginf(pt_min) and np.isposinf(pt_max): + return 0.0 + if np.isneginf(pt_min) or np.isposinf(pt_max): + raise SetupError( + f"Bounds of ({pt_min}, {pt_max}) supplied along one dimension. " + "We currently don't support a single ``inf`` value in bounds for ``Box``. " + "To construct a semi-infinite ``Box``, " + "please supply a large enough number instead of ``inf``. " + "For example, a location extending outside of the " + "Simulation domain (including PML)." 
+ ) + return (pt_min + pt_max) / 2.0 + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + raise ValidationError("'Medium2D' is not compatible with this geometry class.") + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Geometry: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + raise NotImplementedError( + "'_update_from_bounds' is not compatible with this geometry class." + ) + + @equal_aspect + @add_ax_if_none + def plot( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + ax: Ax = None, + plot_length_units: LengthUnit = None, + viz_spec: VisualizationSpec = None, + **patch_kwargs: Any, + ) -> Ax: + """Plot geometry cross section at single (x,y,z) coordinate. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + plot_length_units : LengthUnit = None + Specify units to use for axis labels, tick labels, and the title. + viz_spec : VisualizationSpec = None + Plotting parameters associated with a medium to use instead of defaults. + **patch_kwargs + Optional keyword arguments passed to the matplotlib patch plotting of structure. + For details on accepted values, refer to + `Matplotlib's documentation `_. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. 
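+
+        Example
+        -------
+        Illustrative sketch (not part of the original docstring); ``facecolor`` is an
+        ordinary matplotlib patch keyword forwarded via ``patch_kwargs``.
+
+        >>> box = Box(center=(0, 0, 0), size=(2, 2, 2))  # doctest: +SKIP
+        >>> ax = box.plot(z=0.0, facecolor="gray")  # doctest: +SKIP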
+ """ + + # find shapes that intersect self at plane + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + shapes_intersect = self.intersections_plane(x=x, y=y, z=z) + + plot_params = self.plot_params + if viz_spec is not None: + plot_params = plot_params.override_with_viz_spec(viz_spec) + plot_params = plot_params.include_kwargs(**patch_kwargs) + + # for each intersection, plot the shape + for shape in shapes_intersect: + ax = self.plot_shape(shape, plot_params=plot_params, ax=ax) + + # clean up the axis display + ax = self.add_ax_lims(axis=axis, ax=ax) + ax.set_aspect("equal") + # Add the default axis labels, tick labels, and title + ax = Box.add_ax_labels_and_title(ax=ax, x=x, y=y, z=z, plot_length_units=plot_length_units) + return ax + + def plot_shape(self, shape: Shapely, plot_params: PlotParams, ax: Ax) -> Ax: + """Defines how a shape is plotted on a matplotlib axes.""" + if shape.geom_type in ( + "MultiPoint", + "MultiLineString", + "MultiPolygon", + "GeometryCollection", + ): + for sub_shape in shape.geoms: + ax = self.plot_shape(shape=sub_shape, plot_params=plot_params, ax=ax) + + return ax + + _shape = Geometry.evaluate_inf_shape(shape) + + if _shape.geom_type == "LineString": + xs, ys = zip(*_shape.coords) + ax.plot(xs, ys, color=plot_params.facecolor, linewidth=plot_params.linewidth) + elif _shape.geom_type == "Point": + ax.scatter(shape.x, shape.y, color=plot_params.facecolor) + else: + patch = polygon_patch(_shape, **plot_params.to_kwargs()) + ax.add_artist(patch) + return ax + + @staticmethod + def _do_not_intersect( + bounds_a: float, bounds_b: float, shape_a: Shapely, shape_b: Shapely + ) -> bool: + """Check whether two shapes intersect.""" + + # do a bounding box check to see if any intersection to do anything about + if ( + bounds_a[0] > bounds_b[2] + or bounds_b[0] > bounds_a[2] + or bounds_a[1] > bounds_b[3] + or bounds_b[1] > bounds_a[3] + ): + return True + + # look more closely to see if intersected. + if shape_b.is_empty or not shape_a.intersects(shape_b): + return True + + return False + + @staticmethod + def _get_plot_labels(axis: Axis) -> tuple[str, str]: + """Returns planar coordinate x and y axis labels for cross section plots. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + str, str + Labels of plot, packaged as ``(xlabel, ylabel)``. + """ + _, (xlabel, ylabel) = Geometry.pop_axis("xyz", axis=axis) + return xlabel, ylabel + + def _get_plot_limits( + self, axis: Axis, buffer: float = PLOT_BUFFER + ) -> tuple[Coordinate2D, Coordinate2D]: + """Gets planar coordinate limits for cross section plots. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + buffer : float = 0.3 + Amount of space to add around the limits on the + and - sides. + + Returns + ------- + tuple[float, float], tuple[float, float] + The x and y plot limits, packed as ``(xmin, xmax), (ymin, ymax)``. + """ + _, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis) + return (xmin - buffer, xmax + buffer), (ymin - buffer, ymax + buffer) + + def add_ax_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) -> Ax: + """Sets the x,y limits based on ``self.bounds``. + + Parameters + ---------- + axis : int + Integer index into 'xyz' (0,1,2). + ax : matplotlib.axes._subplots.Axes + Matplotlib axes to add labels and limits on. + buffer : float = 0.3 + Amount of space to place around the limits on the + and - sides. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. 
+ """ + (xmin, xmax), (ymin, ymax) = self._get_plot_limits(axis=axis, buffer=buffer) + + # note: axes limits dont like inf values, so we need to evaluate them first if present + xmin, xmax, ymin, ymax = self._evaluate_inf((xmin, xmax, ymin, ymax)) + + ax.set_xlim(xmin, xmax) + ax.set_ylim(ymin, ymax) + return ax + + @staticmethod + def add_ax_labels_and_title( + ax: Ax, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + plot_length_units: LengthUnit = None, + ) -> Ax: + """Sets the axis labels, tick labels, and title based on ``axis`` + and an optional ``plot_length_units`` argument. + + Parameters + ---------- + ax : matplotlib.axes._subplots.Axes + Matplotlib axes to add labels and limits on. + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + plot_length_units : LengthUnit = None + When set to a supported ``LengthUnit``, plots will be produced with annotated axes + and title with the proper units. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied matplotlib axes. + """ + axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z) + axis_labels = Box._get_plot_labels(axis) + ax = set_default_labels_and_title( + axis_labels=axis_labels, + axis=axis, + position=position, + ax=ax, + plot_length_units=plot_length_units, + ) + return ax + + @staticmethod + def _evaluate_inf(array: ArrayLike) -> NDArray[np.floating]: + """Processes values and evaluates any infs into large (signed) numbers.""" + array = get_static(np.array(array)) + return np.where(np.isinf(array), np.sign(array) * LARGE_NUMBER, array) + + @staticmethod + def evaluate_inf_shape(shape: Shapely) -> Shapely: + """Returns a copy of shape with inf vertices replaced by large numbers if polygon.""" + if not any(np.isinf(b) for b in shape.bounds): + return shape + + def _processed_coords(coords: Sequence[tuple[Any, ...]]) -> list[tuple[float, ...]]: + evaluated = Geometry._evaluate_inf(np.array(coords)) + return [tuple(point) for point in evaluated.tolist()] + + if shape.geom_type == "Polygon": + shell = _processed_coords(shape.exterior.coords) + holes = [_processed_coords(g.coords) for g in shape.interiors] + return shapely.Polygon(shell, holes) + if shape.geom_type in {"Point", "LineString", "LinearRing"}: + return shape.__class__(Geometry._evaluate_inf(np.array(shape.coords))) + if shape.geom_type in { + "MultiPoint", + "MultiLineString", + "MultiPolygon", + "GeometryCollection", + }: + return shape.__class__([Geometry.evaluate_inf_shape(g) for g in shape.geoms]) + return shape + + @staticmethod + def pop_axis(coord: tuple[Any, Any, Any], axis: int) -> tuple[Any, tuple[Any, Any]]: + """Separates coordinate at ``axis`` index from coordinates on the plane tangent to ``axis``. + + Parameters + ---------- + coord : tuple[Any, Any, Any] + Tuple of three values in original coordinate system. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + Any, tuple[Any, Any] + The input coordinates are separated into the one along the axis provided + and the two on the planar coordinates, + like ``axis_coord, (planar_coord1, planar_coord2)``. 
+ """ + plane_vals = list(coord) + axis_val = plane_vals.pop(axis) + return axis_val, tuple(plane_vals) + + @staticmethod + def unpop_axis(ax_coord: Any, plane_coords: tuple[Any, Any], axis: int) -> tuple[Any, Any, Any]: + """Combine coordinate along axis with coordinates on the plane tangent to the axis. + + Parameters + ---------- + ax_coord : Any + Value along axis direction. + plane_coords : tuple[Any, Any] + Values along ordered planar directions. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + tuple[Any, Any, Any] + The three values in the xyz coordinate system. + """ + coords = list(plane_coords) + coords.insert(axis, ax_coord) + return tuple(coords) + + @staticmethod + def parse_xyz_kwargs(**xyz: Any) -> tuple[Axis, float]: + """Turns x,y,z kwargs into index of the normal axis and position along that axis. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + + Returns + ------- + int, float + Index into xyz axis (0,1,2) and position along that axis. + """ + xyz_filtered = {k: v for k, v in xyz.items() if v is not None} + if len(xyz_filtered) != 1: + raise ValueError("exactly one kwarg in [x,y,z] must be specified.") + axis_label, position = list(xyz_filtered.items())[0] + axis = "xyz".index(axis_label) + return axis, position + + @staticmethod + def parse_two_xyz_kwargs(**xyz: Any) -> list[tuple[Axis, float]]: + """Turns x,y,z kwargs into indices of axes and the position along each axis. + + Parameters + ---------- + x : float = None + Position in x direction, only two of x,y,z can be specified to define line. + y : float = None + Position in y direction, only two of x,y,z can be specified to define line. + z : float = None + Position in z direction, only two of x,y,z can be specified to define line. + + Returns + ------- + [(int, float), (int, float)] + Index into xyz axis (0,1,2) and position along that axis. + """ + xyz_filtered = {k: v for k, v in xyz.items() if v is not None} + assert len(xyz_filtered) == 2, "exactly two kwarg in [x,y,z] must be specified." + xyz_list = list(xyz_filtered.items()) + return [("xyz".index(axis_label), position) for axis_label, position in xyz_list] + + @staticmethod + def rotate_points(points: ArrayFloat3D, axis: Coordinate, angle: float) -> ArrayFloat3D: + """Rotate a set of points in 3D. + + Parameters + ---------- + points : ArrayLike[float] + Array of shape ``(3, ...)``. + axis : Coordinate + Axis of rotation + angle : float + Angle of rotation counter-clockwise around the axis (rad). + """ + rotation = RotationAroundAxis(axis=axis, angle=angle) + return rotation.rotate_vector(points) + + def reflect_points( + self, + points: ArrayFloat3D, + polar_axis: Axis, + angle_theta: float, + angle_phi: float, + ) -> ArrayFloat3D: + """Reflect a set of points in 3D at a plane passing through the coordinate origin defined + and normal to a given axis defined in polar coordinates (theta, phi) w.r.t. the + ``polar_axis`` which can be 0, 1, or 2. + + Parameters + ---------- + points : ArrayLike[float] + Array of shape ``(3, ...)``. + polar_axis : Axis + Cartesian axis w.r.t. which the normal axis angles are defined. + angle_theta : float + Polar angle w.r.t. the polar axis. + angle_phi : float + Azimuth angle around the polar axis. 
+ """ + + # Rotate such that the plane normal is along the polar_axis + axis_theta, axis_phi = [0, 0, 0], [0, 0, 0] + axis_phi[polar_axis] = 1 + plane_axes = [0, 1, 2] + plane_axes.pop(polar_axis) + axis_theta[plane_axes[1]] = 1 + points_new = self.rotate_points(points, axis_phi, -angle_phi) + points_new = self.rotate_points(points_new, axis_theta, -angle_theta) + + # Flip the ``polar_axis`` coordinate of the points, which is now normal to the plane + points_new[polar_axis, :] *= -1 + + # Rotate back + points_new = self.rotate_points(points_new, axis_theta, angle_theta) + points_new = self.rotate_points(points_new, axis_phi, angle_phi) + + return points_new + + def volume(self, bounds: Bound = None) -> float: + """Returns object's volume with optional bounds. + + Parameters + ---------- + bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + float + Volume in um^3. + """ + + if not bounds: + bounds = self.bounds + + return self._volume(bounds) + + @abstractmethod + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + def surface_area(self, bounds: Bound = None) -> float: + """Returns object's surface area with optional bounds. + + Parameters + ---------- + bounds : tuple[tuple[float, float, float], tuple[float, float, float]] = None + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + float + Surface area in um^2. + """ + + if not bounds: + bounds = self.bounds + + return self._surface_area(bounds) + + @abstractmethod + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + def translated(self, x: float, y: float, z: float) -> Geometry: + """Return a translated copy of this geometry. + + Parameters + ---------- + x : float + Translation along x. + y : float + Translation along y. + z : float + Translation along z. + + Returns + ------- + :class:`Geometry` + Translated copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.translation(x, y, z)) + + def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> Geometry: + """Return a scaled copy of this geometry. + + Parameters + ---------- + x : float = 1.0 + Scaling factor along x. + y : float = 1.0 + Scaling factor along y. + z : float = 1.0 + Scaling factor along z. + + Returns + ------- + :class:`Geometry` + Scaled copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.scaling(x, y, z)) + + def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> Geometry: + """Return a rotated copy of this geometry. + + Parameters + ---------- + angle : float + Rotation angle (in radians). + axis : Union[int, tuple[float, float, float]] + Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. + + Returns + ------- + :class:`Geometry` + Rotated copy of this geometry. + """ + return Transformed(geometry=self, transform=Transformed.rotation(angle, axis)) + + def reflected(self, normal: Coordinate) -> Geometry: + """Return a reflected copy of this geometry. + + Parameters + ---------- + normal : tuple[float, float, float] + The 3D normal vector of the plane of reflection. The plane is assumed + to pass through the origin (0,0,0). + + Returns + ------- + :class:`Geometry` + Reflected copy of this geometry. 
+ """ + return Transformed(geometry=self, transform=Transformed.reflection(normal)) + + """ Field and coordinate transformations """ + + @staticmethod + def car_2_sph(x: float, y: float, z: float) -> tuple[float, float, float]: + """Convert Cartesian to spherical coordinates. + + Parameters + ---------- + x : float + x coordinate relative to ``local_origin``. + y : float + y coordinate relative to ``local_origin``. + z : float + z coordinate relative to ``local_origin``. + + Returns + ------- + tuple[float, float, float] + r, theta, and phi coordinates relative to ``local_origin``. + """ + r = np.sqrt(x**2 + y**2 + z**2) + theta = np.arccos(z / r) + phi = np.arctan2(y, x) + return r, theta, phi + + @staticmethod + def sph_2_car(r: float, theta: float, phi: float) -> tuple[float, float, float]: + """Convert spherical to Cartesian coordinates. + + Parameters + ---------- + r : float + radius. + theta : float + polar angle (rad) downward from x=y=0 line. + phi : float + azimuthal (rad) angle from y=z=0 line. + + Returns + ------- + tuple[float, float, float] + x, y, and z coordinates relative to ``local_origin``. + """ + r_sin_theta = r * np.sin(theta) + x = r_sin_theta * np.cos(phi) + y = r_sin_theta * np.sin(phi) + z = r * np.cos(theta) + return x, y, z + + @staticmethod + def sph_2_car_field( + f_r: float, f_theta: float, f_phi: float, theta: float, phi: float + ) -> tuple[complex, complex, complex]: + """Convert vector field components in spherical coordinates to cartesian. + + Parameters + ---------- + f_r : float + radial component of the vector field. + f_theta : float + polar angle component of the vector fielf. + f_phi : float + azimuthal angle component of the vector field. + theta : float + polar angle (rad) of location of the vector field. + phi : float + azimuthal angle (rad) of location of the vector field. + + Returns + ------- + tuple[float, float, float] + x, y, and z components of the vector field in cartesian coordinates. + """ + sin_theta = np.sin(theta) + cos_theta = np.cos(theta) + sin_phi = np.sin(phi) + cos_phi = np.cos(phi) + f_x = f_r * sin_theta * cos_phi + f_theta * cos_theta * cos_phi - f_phi * sin_phi + f_y = f_r * sin_theta * sin_phi + f_theta * cos_theta * sin_phi + f_phi * cos_phi + f_z = f_r * cos_theta - f_theta * sin_theta + return f_x, f_y, f_z + + @staticmethod + def car_2_sph_field( + f_x: float, f_y: float, f_z: float, theta: float, phi: float + ) -> tuple[complex, complex, complex]: + """Convert vector field components in cartesian coordinates to spherical. + + Parameters + ---------- + f_x : float + x component of the vector field. + f_y : float + y component of the vector fielf. + f_z : float + z component of the vector field. + theta : float + polar angle (rad) of location of the vector field. + phi : float + azimuthal angle (rad) of location of the vector field. + + Returns + ------- + tuple[float, float, float] + radial (s), elevation (theta), and azimuthal (phi) components + of the vector field in spherical coordinates. + """ + sin_theta = np.sin(theta) + cos_theta = np.cos(theta) + sin_phi = np.sin(phi) + cos_phi = np.cos(phi) + f_r = f_x * sin_theta * cos_phi + f_y * sin_theta * sin_phi + f_z * cos_theta + f_theta = f_x * cos_theta * cos_phi + f_y * cos_theta * sin_phi - f_z * sin_theta + f_phi = -f_x * sin_phi + f_y * cos_phi + return f_r, f_theta, f_phi + + @staticmethod + def kspace_2_sph(ux: float, uy: float, axis: Axis) -> tuple[float, float]: + """Convert normalized k-space coordinates to angles. 
+ + Parameters + ---------- + ux : float + normalized kx coordinate. + uy : float + normalized ky coordinate. + axis : int + axis along which the observation plane is oriented. + + Returns + ------- + tuple[float, float] + theta and phi coordinates relative to ``local_origin``. + """ + phi_local = np.arctan2(uy, ux) + with np.errstate(invalid="ignore"): + theta_local = np.arcsin(np.sqrt(ux**2 + uy**2)) + # Spherical coordinates rotation matrix reference: + # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation + if axis == 2: + return theta_local, phi_local + + x = np.cos(theta_local) + y = np.sin(theta_local) * np.cos(phi_local) + z = np.sin(theta_local) * np.sin(phi_local) + + if axis == 1: + x, y, z = y, x, z + + theta = np.arccos(z) + phi = np.arctan2(y, x) + return theta, phi + + @staticmethod + @verify_packages_import(["gdstk"]) + def load_gds_vertices_gdstk( + gds_cell: Cell, + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + ) -> list[ArrayFloat2D]: + """Load polygon vertices from a ``gdstk.Cell``. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into + the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of micrometer. For example, if gds file uses + nanometers, set ``gds_scale=1e-3``. Must be positive. + + Returns + ------- + list[ArrayFloat2D] + List of polygon vertices + """ + + # apply desired scaling and load the polygon vertices + if gds_dtype is not None: + # if both layer and datatype are specified, let gdstk do the filtering for better + # performance on large layouts + all_vertices = [ + polygon.scale(gds_scale).points + for polygon in gds_cell.get_polygons(layer=gds_layer, datatype=gds_dtype) + ] + else: + all_vertices = [ + polygon.scale(gds_scale).points + for polygon in gds_cell.get_polygons() + if polygon.layer == gds_layer + ] + # make sure something got loaded, otherwise error + if not all_vertices: + raise Tidy3dKeyError( + f"Couldn't load gds_cell, no vertices found at gds_layer={gds_layer} " + f"with specified gds_dtype={gds_dtype}." + ) + + return all_vertices + + @staticmethod + @verify_packages_import(["gdstk"]) + def from_gds( + gds_cell: Cell, + axis: Axis, + slab_bounds: tuple[float, float], + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> Geometry: + """Import a ``gdstk.Cell`` and extrude it into a GeometryGroup. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + axis : int + Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). + slab_bounds: tuple[float, float] + Minimal and maximal positions of the extruded slab along ``axis``. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into + the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of micrometer. For example, if gds file uses + nanometers, set ``gds_scale=1e-3``. Must be positive. + dilation : float = 0.0 + Dilation (positive) or erosion (negative) amount to be applied to the original polygons. 
+ sidewall_angle : float = 0 + Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive + (negative) values result in slabs larger (smaller) at the base than at the top. + reference_plane : PlanePosition = "middle" + Reference position of the (dilated/eroded) polygons along the slab axis. One of + ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` + (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value + has no effect if ``sidewall_angle == 0``. + + Returns + ------- + :class:`Geometry` + Geometries created from the 2D data. + """ + import gdstk + + if not isinstance(gds_cell, gdstk.Cell): + # Check if it might be a gdstk cell but gdstk is not found (should be caught by decorator) + # or if it's an entirely different type. + if "gdstk" in gds_cell.__class__.__name__.lower(): + raise Tidy3dImportError( + "Module 'gdstk' not found. It is required to import gdstk cells." + ) + raise Tidy3dImportError("Argument 'gds_cell' must be an instance of 'gdstk.Cell'.") + + gds_loader_fn = Geometry.load_gds_vertices_gdstk + geometries = [] + with log as consolidated_logger: + for vertices in gds_loader_fn(gds_cell, gds_layer, gds_dtype, gds_scale): + # buffer(0) is necessary to merge self-intersections + shape = shapely.set_precision(shapely.Polygon(vertices).buffer(0), POLY_GRID_SIZE) + try: + geometries.append( + from_shapely( + shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane + ) + ) + except ValidationError as error: + consolidated_logger.warning(str(error)) + except Tidy3dError as error: + consolidated_logger.warning(str(error)) + return geometries[0] if len(geometries) == 1 else GeometryGroup(geometries=geometries) + + @staticmethod + def from_shapely( + shape: Shapely, + axis: Axis, + slab_bounds: tuple[float, float], + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> Geometry: + """Convert a shapely primitive into a geometry instance by extrusion. + + Parameters + ---------- + shape : shapely.geometry.base.BaseGeometry + Shapely primitive to be converted. It must be a linear ring, a polygon or a collection + of any of those. + axis : int + Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). + slab_bounds: tuple[float, float] + Minimal and maximal positions of the extruded slab along ``axis``. + dilation : float + Dilation of the polygon in the base by shifting each edge along its normal outwards + direction by a distance; a negative value corresponds to erosion. + sidewall_angle : float = 0 + Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive + (negative) values result in slabs larger (smaller) at the base than at the top. + reference_plane : PlanePosition = "middle" + Reference position of the (dilated/eroded) polygons along the slab axis. One of + ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` + (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value + has no effect if ``sidewall_angle == 0``. + + Returns + ------- + :class:`Geometry` + Geometry extruded from the 2D data. 
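+
+        Example
+        -------
+        Illustrative sketch (not part of the original docstring), extruding a unit
+        square along z.
+
+        >>> import shapely  # doctest: +SKIP
+        >>> square = shapely.Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])  # doctest: +SKIP
+        >>> slab = Geometry.from_shapely(square, axis=2, slab_bounds=(-0.5, 0.5))  # doctest: +SKIP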
+ """ + return from_shapely(shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane) + + @verify_packages_import(["gdstk"]) + def to_gdstk( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + ) -> list: + """Convert a Geometry object's planar slice to a .gds type polygon. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + gds_layer : int = 0 + Layer index to use for the shapes stored in the .gds file. + gds_dtype : int = 0 + Data-type index to use for the shapes stored in the .gds file. + + Return + ------ + List + List of `gdstk.Polygon`. + """ + import gdstk + + shapes = self.intersections_plane(x=x, y=y, z=z) + polygons = [] + for shape in shapes: + for vertices in vertices_from_shapely(shape): + if len(vertices) == 1: + polygons.append(gdstk.Polygon(vertices[0], gds_layer, gds_dtype)) + else: + polygons.extend( + gdstk.boolean( + vertices[:1], + vertices[1:], + "not", + layer=gds_layer, + datatype=gds_dtype, + ) + ) + return polygons + + @verify_packages_import(["gdstk"]) + def to_gds( + self, + cell: Cell, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + ) -> None: + """Append a Geometry object's planar slice to a .gds cell. + + Parameters + ---------- + cell : ``gdstk.Cell`` + Cell object to which the generated polygons are added. + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + gds_layer : int = 0 + Layer index to use for the shapes stored in the .gds file. + gds_dtype : int = 0 + Data-type index to use for the shapes stored in the .gds file. + """ + import gdstk + + if not isinstance(cell, gdstk.Cell): + if "gdstk" in cell.__class__.__name__.lower(): + raise Tidy3dImportError( + "Module 'gdstk' not found. It is required to export shapes to gdstk cells." + ) + raise Tidy3dImportError("Argument 'cell' must be an instance of 'gdstk.Cell'.") + + polygons = self.to_gdstk(x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) + if polygons: + cell.add(*polygons) + + @verify_packages_import(["gdstk"]) + def to_gds_file( + self, + fname: PathLike, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, + gds_cell_name: str = "MAIN", + ) -> None: + """Export a Geometry object's planar slice to a .gds file. + + Parameters + ---------- + fname : PathLike + Full path to the .gds file to save the :class:`Geometry` slice to. + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. 
+ gds_layer : int = 0 + Layer index to use for the shapes stored in the .gds file. + gds_dtype : int = 0 + Data-type index to use for the shapes stored in the .gds file. + gds_cell_name : str = 'MAIN' + Name of the cell created in the .gds file to store the geometry. + """ + try: + import gdstk + except ImportError as e: + raise Tidy3dImportError( + "Python module 'gdstk' not found. To export geometries to .gds " + "files, please install it." + ) from e + + library = gdstk.Library() + cell = library.new_cell(gds_cell_name) + self.to_gds(cell, x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) + fname = pathlib.Path(fname) + fname.parent.mkdir(parents=True, exist_ok=True) + library.write_gds(fname) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.") + + def _as_union(self) -> list[Geometry]: + """Return a list of geometries that, united, make up the given geometry.""" + if isinstance(self, GeometryGroup): + return self.geometries + + if isinstance(self, ClipOperation) and self.operation == "union": + return (self.geometry_a, self.geometry_b) + return (self,) + + def __add__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]: + """Union of geometries""" + # This allows the user to write sum(geometries...) with the default start=0 + if isinstance(other, int): + return self + if not isinstance(other, Geometry): + return NotImplemented # type: ignore[return-value] + return GeometryGroup(geometries=self._as_union() + other._as_union()) + + def __radd__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]: + """Union of geometries""" + # This allows the user to write sum(geometries...) 
+
+
+""" Abstract subclasses """
+
+
+class Centered(Geometry, ABC):
+    """Geometry with a well defined center."""
+
+    center: Optional[TracedCoordinate] = Field(
+        None,
+        title="Center",
+        description="Center of object in x, y, and z.",
+        units=MICROMETER,
+    )
+
+    @field_validator("center", mode="before")
+    @classmethod
+    def _center_default(cls, val: Any) -> Any:
+        """Set the default center when it is not provided."""
+        if val is None:
+            val = (0.0, 0.0, 0.0)
+        return val
+
+    @field_validator("center")
+    @classmethod
+    def _center_not_inf(cls, val: tuple[float, float, float]) -> tuple[float, float, float]:
+        """Make sure center is not infinity."""
+        if any(np.isinf(v) for v in val):
+            raise ValidationError("center cannot contain td.inf terms.")
+        return val
+
+
+class SimplePlaneIntersection(Geometry, ABC):
+    """A geometry where intersections with an axis aligned plane may be computed efficiently."""
+
+    def intersections_tilted_plane(
+        self,
+        normal: Coordinate,
+        origin: Coordinate,
+        to_2D: MatrixReal4x4,
+        cleanup: bool = True,
+        quad_segs: Optional[int] = None,
+    ) -> list[Shapely]:
+        """Return a list of shapely geometries at the plane specified by normal and origin.
+        Checks special cases before relying on the complete computation.
+
+        Parameters
+        ----------
+        normal : Coordinate
+            Vector defining the normal direction to the plane.
+        origin : Coordinate
+            Vector defining the plane origin.
+        to_2D : MatrixReal4x4
+            Transformation matrix to apply to resulting shapes.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            high-quality visualization settings.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+
+        # Check if normal is a special case, where the normal is aligned with an axis.
+        if np.sum(np.isclose(normal, 0.0)) == 2:
+            axis = np.argmax(np.abs(normal)).item()
+            coord = "xyz"[axis]
+            kwargs = {coord: origin[axis]}
+            section = self.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **kwargs)
+            # Apply transformation in the plane by removing row and column
+            to_2D_in_plane = np.delete(np.delete(to_2D, 2, 0), axis, 1)
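+            # Worked example (illustrative, assumed values): with axis=2 (normal
+            # along z) and 'to_2D' a pure translation by (tx, ty), the reduced
+            # matrix 'to_2D_in_plane' is the 2D homogeneous translation
+            #
+            #     [[1, 0, tx],
+            #      [0, 1, ty],
+            #      [0, 0, 1]]
+            #
+            # so the 'transform' helper below shifts every polygon vertex by
+            # (tx, ty).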
+
+            def transform(p_array: NDArray) -> NDArray:
+                return np.dot(
+                    np.hstack((p_array, np.ones((p_array.shape[0], 1)))), to_2D_in_plane.T
+                )[:, :2]
+
+            transformed_section = shapely.transform(section, transformation=transform)
+            return transformed_section
+        # Otherwise compute the arbitrary intersection
+        return self._do_intersections_tilted_plane(
+            normal=normal, origin=origin, to_2D=to_2D, quad_segs=quad_segs
+        )
+
+    @abstractmethod
+    def _do_intersections_tilted_plane(
+        self,
+        normal: Coordinate,
+        origin: Coordinate,
+        to_2D: MatrixReal4x4,
+        quad_segs: Optional[int] = None,
+    ) -> list[Shapely]:
+        """Return a list of shapely geometries at the plane specified by normal and origin.
+
+        Parameters
+        ----------
+        normal : Coordinate
+            Vector defining the normal direction to the plane.
+        origin : Coordinate
+            Vector defining the plane origin.
+        to_2D : MatrixReal4x4
+            Transformation matrix to apply to resulting shapes.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+
+
+class Planar(SimplePlaneIntersection, Geometry, ABC):
+    """Geometry with one ``axis`` that is slab-like with thickness ``height``."""
+
+    axis: Axis = Field(
+        2,
+        title="Axis",
+        description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z).",
+    )
+
+    sidewall_angle: TracedFloat = Field(
+        0.0,
+        title="Sidewall angle",
+        description="Angle of the sidewall. "
+        "``sidewall_angle=0`` (default) specifies a vertical wall; "
+        "``0 < sidewall_angle < np.pi/2`` specifies a shrinking cross section "
+        "along the ``axis`` direction; "
+        "and ``-np.pi/2 < sidewall_angle < 0`` specifies an expanding cross section "
+        "along the ``axis`` direction.",
+        units=RADIAN,
+    )
+
+    reference_plane: PlanePosition = Field(
+        "middle",
+        title="Reference plane for cross section",
+        description="The position of the plane where the supplied cross section are "
+        "defined. The plane is perpendicular to the ``axis``. The plane is located at "
+        "the ``bottom``, ``middle``, or ``top`` of the geometry with respect to the "
+        "axis. E.g., if ``axis=1``, ``bottom`` refers to ``y=ymin``, and ``top`` "
+        "refers to ``y=ymax``.",
+    )
+
+    @field_validator("sidewall_angle")
+    @classmethod
+    def _sidewall_angle_in_range(cls, val: float) -> float:
+        lower_bound = -np.pi / 2
+        upper_bound = np.pi / 2
+        if (val <= lower_bound) or (val >= upper_bound):
+            # \u03C0 is unicode for pi
+            raise ValidationError(f"Sidewall angle ({val}) must be between -π/2 and π/2 rad.")
+        return val
+
+    @property
+    @abstractmethod
+    def center_axis(self) -> float:
+        """Gets the position of the center of the geometry in the out of plane dimension."""
+
+    @property
+    @abstractmethod
+    def length_axis(self) -> float:
+        """Gets the length of the geometry along the out of plane dimension."""
+
+    @property
+    def finite_length_axis(self) -> float:
+        """Gets the length of the geometry along the out of plane dimension.
+        If the length is td.inf, return ``LARGE_NUMBER``
+        """
+        return min(self.length_axis, LARGE_NUMBER)
+
+    @property
+    def reference_axis_pos(self) -> float:
+        """Coordinate along the slab axis at the reference plane.
+
+        Returns the axis coordinate corresponding to the selected
+        reference_plane:
+        - "bottom": lower bound of slab_bounds
+        - "middle": center_axis
+        - "top": upper bound of slab_bounds
+        """
+        if self.reference_plane == "bottom":
+            return self.slab_bounds[0]
+        if self.reference_plane == "top":
+            return self.slab_bounds[1]
+        # default to middle
+        return self.center_axis
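+
+    # Illustrative example (hypothetical slab): for slab_bounds=(0.0, 2.0) and
+    # center_axis=1.0, 'reference_axis_pos' is 0.0 when reference_plane="bottom",
+    # 1.0 for "middle", and 2.0 for "top".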
+
+    def intersections_plane(
+        self,
+        x: Optional[float] = None,
+        y: Optional[float] = None,
+        z: Optional[float] = None,
+        cleanup: bool = True,
+        quad_segs: Optional[int] = None,
+    ) -> list[Shapely]:
+        """Returns shapely geometry at plane specified by one non-None value of x,y,z.
+
+        Parameters
+        ----------
+        x : float
+            Position of plane in x direction, only one of x,y,z can be specified to define plane.
+        y : float
+            Position of plane in y direction, only one of x,y,z can be specified to define plane.
+        z : float
+            Position of plane in z direction, only one of x,y,z can be specified to define plane.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            high-quality visualization settings.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+        axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
+        if not self.intersects_axis_position(axis, position):
+            return []
+        if axis == self.axis:
+            return self._intersections_normal(position, quad_segs=quad_segs)
+        return self._intersections_side(position, axis)
+
+    @abstractmethod
+    def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list:
+        """Find shapely geometries intersecting planar geometry with axis normal to slab.
+
+        Parameters
+        ----------
+        z : float
+            Position along the axis normal to slab
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            high-quality visualization settings.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+
+    @abstractmethod
+    def _intersections_side(self, position: float, axis: Axis) -> list[Shapely]:
+        """Find shapely geometries intersecting planar geometry with axis orthogonal to plane.
+
+        Parameters
+        ----------
+        position : float
+            Position along axis.
+        axis : int
+            Integer index into 'xyz' (0,1,2).
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+
+    def _order_axis(self, axis: int) -> int:
+        """Order the axis as if self.axis is along z-direction.
+
+        Parameters
+        ----------
+        axis : int
+            Integer index into the structure's planar axis.
+
+        Returns
+        -------
+        int
+            New index of axis.
+        """
+        axis_index = [0, 1]
+        axis_index.insert(self.axis, 2)
+        return axis_index[axis]
+
+    def _order_by_axis(self, plane_val: Any, axis_val: Any, axis: int) -> tuple[Any, Any]:
+        """Orders a value in the plane and value along axis in correct (x,y) order for plotting.
+        Note: sometimes if axis=1 and we compute cross section values orthogonal to axis,
+        they can either be x or y in the plots.
+        This function allows one to figure out the ordering.
+
+        Parameters
+        ----------
+        plane_val : Any
+            The value in the planar coordinate.
+        axis_val : Any
+            The value in the ``axis`` coordinate.
+        axis : int
+            Integer index into the structure's planar axis.
+
+        Returns
+        -------
+        ``(Any, Any)``
+            The two planar coordinates in this new coordinate system.
+        """
+        vals = 3 * [plane_val]
+        vals[self.axis] = axis_val
+        _, (val_x, val_y) = self.pop_axis(vals, axis=axis)
+        return val_x, val_y
+
+    @cached_property
+    def _tanq(self) -> float:
+        """Value of ``tan(sidewall_angle)``.
+
+        The (possibly infinite) geometry offset is given by ``_tanq * length_axis``.
+        """
+        return np.tan(self.sidewall_angle)
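+
+    # Worked example (assumed values): sidewall_angle=np.pi/4 gives _tanq=1.0,
+    # so a slab with finite_length_axis=2.0 acquires a cross-section offset of
+    # 2.0 * 1.0 = 2.0 between its two faces (sign conventions handled by
+    # subclasses).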
+
+
+class Circular(Geometry):
+    """Geometry with circular characteristics (specified by a radius)."""
+
+    radius: NonNegativeFloat = Field(
+        title="Radius",
+        description="Radius of geometry.",
+        units=MICROMETER,
+    )
+
+    @field_validator("radius")
+    @classmethod
+    def _radius_not_inf(cls, val: float) -> float:
+        """Make sure radius is not infinity."""
+        if np.isinf(val):
+            raise ValidationError("radius cannot be 'td.inf'.")
+        return val
+
+    def _intersect_dist(self, position: float, z0: float) -> float:
+        """Distance between points on circle at z=position where center of circle at z=z0.
+
+        Parameters
+        ----------
+        position : float
+            position along z.
+        z0 : float
+            center of circle in z.
+
+        Returns
+        -------
+        float
+            Distance between points on the circle intersecting z=position; if no points, ``None``.
+        """
+        dz = np.abs(z0 - position)
+        if dz > self.radius:
+            return None
+        return 2 * np.sqrt(self.radius**2 - dz**2)
+
+
+"""Primitive classes"""
+
+
+class Box(SimplePlaneIntersection, Centered):
+    """Rectangular prism.
+    Also base class for :class:`.Simulation`, :class:`.Monitor`, and :class:`.Source`.
+
+    Example
+    -------
+    >>> b = Box(center=(1,2,3), size=(2,2,2))
+    """
+
+    size: TracedSize = Field(
+        title="Size",
+        description="Size in x, y, and z directions.",
+        units=MICROMETER,
+    )
+
+    @classmethod
+    def from_bounds(cls, rmin: Coordinate, rmax: Coordinate, **kwargs: Any) -> Self:
+        """Constructs a :class:`Box` from minimum and maximum coordinate bounds.
+
+        Parameters
+        ----------
+        rmin : tuple[float, float, float]
+            (x, y, z) coordinate of the minimum values.
+        rmax : tuple[float, float, float]
+            (x, y, z) coordinate of the maximum values.
+
+        Example
+        -------
+        >>> b = Box.from_bounds(rmin=(-1, -2, -3), rmax=(3, 2, 1))
+        """
+
+        center = tuple(cls._get_center(pt_min, pt_max) for pt_min, pt_max in zip(rmin, rmax))
+        size = tuple((pt_max - pt_min) for pt_min, pt_max in zip(rmin, rmax))
+        return cls(center=center, size=size, **kwargs)
+
+    @cached_property
+    def _normal_axis(self) -> Axis:
+        """Axis normal to the Box. Errors if box is not planar."""
+        if self.size.count(0.0) != 1:
+            raise ValidationError(
+                f"Tried to get 'normal_axis' of 'Box' that is not planar. Given 'size={self.size}'."
+            )
+        return self.size.index(0.0)
+
+    @classmethod
+    def surfaces(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]:
+        """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume.
+        The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z
+        denote which axis is perpendicular to that surface, while "-" and "+" denote the direction
+        of the normal vector of that surface. If a name is provided, each output surface's name
+        will be that of the provided name appended with the above symbols. E.g., if the provided
+        name is "box", the x+ surface's name will be "box_x+".
+
+        Parameters
+        ----------
+        size : tuple[float, float, float]
+            Size of object in x, y, and z directions.
+        center : tuple[float, float, float]
+            Center of object in x, y, and z.
+
+        Example
+        -------
+        >>> b = Box.surfaces(size=(1, 2, 3), center=(3, 2, 1))
+        """
+
+        if any(s == 0.0 for s in size):
+            raise SetupError(
+                "Can't generate surfaces for the given object because it has zero volume."
+            )
+
+        bounds = Box(center=center, size=size).bounds
+
+        # Set up geometry data and names for each surface:
+        centers = [list(center) for _ in range(6)]
+        sizes = [list(size) for _ in range(6)]
+
+        surface_index = 0
+        for dim_index in range(3):
+            for min_max_index in range(2):
+                new_center = centers[surface_index]
+                new_size = sizes[surface_index]
+
+                new_center[dim_index] = bounds[min_max_index][dim_index]
+                new_size[dim_index] = 0.0
+
+                centers[surface_index] = new_center
+                sizes[surface_index] = new_size
+
+                surface_index += 1
+
+        name_base = kwargs.pop("name", "")
+        kwargs.pop("normal_dir", None)
+
+        names = []
+        normal_dirs = []
+
+        for coord in "xyz":
+            for direction in "-+":
+                surface_name = name_base + "_" + coord + direction
+                names.append(surface_name)
+                normal_dirs.append(direction)
+
+        # ignore surfaces that are infinitely far away
+        del_idx = []
+        for idx, _size in enumerate(size):
+            if _size == inf:
+                del_idx.append(idx)
+        del_idx = [[2 * i, 2 * i + 1] for i in del_idx]
+        del_idx = [item for sublist in del_idx for item in sublist]
+
+        def del_items(items: Iterable, indices: list[int]) -> list:
+            """Delete list items at indices."""
+            return [i for j, i in enumerate(items) if j not in indices]
+
+        centers = del_items(centers, del_idx)
+        sizes = del_items(sizes, del_idx)
+        names = del_items(names, del_idx)
+        normal_dirs = del_items(normal_dirs, del_idx)
+
+        surfaces = []
+        for _cent, _size, _name, _normal_dir in zip(centers, sizes, names, normal_dirs):
+            if "normal_dir" in cls.model_fields:
+                kwargs["normal_dir"] = _normal_dir
+
+            if "name" in cls.model_fields:
+                kwargs["name"] = _name
+
+            surface = cls(center=_cent, size=_size, **kwargs)
+            surfaces.append(surface)
+
+        return surfaces
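+
+    # Illustrative usage (hypothetical values): the six planar boxes come back
+    # ordered [x-, x+, y-, y+, z-, z+], which is convenient for placing one
+    # monitor per face of a volume.
+    #
+    #     faces = Box.surfaces(size=(1, 2, 3), center=(0, 0, 0))
+    #     len(faces)       # 6
+    #     faces[0].size    # (0.0, 2, 3): the x- face is planar along x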
+
+    @classmethod
+    def surfaces_with_exclusion(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]:
+        """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume.
+        The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z
+        denote which axis is perpendicular to that surface, while "-" and "+" denote the direction
+        of the normal vector of that surface. If a name is provided, each output surface's name
+        will be that of the provided name appended with the above symbols. E.g., if the provided
+        name is "box", the x+ surface's name will be "box_x+". If ``kwargs`` contains an
+        ``exclude_surfaces`` parameter, the returned list of surfaces will not include the excluded
+        surfaces. Otherwise, the behavior is identical to that of ``surfaces()``.
+
+        Parameters
+        ----------
+        size : tuple[float, float, float]
+            Size of object in x, y, and z directions.
+        center : tuple[float, float, float]
+            Center of object in x, y, and z.
+
+        Example
+        -------
+        >>> b = Box.surfaces_with_exclusion(
+        ...     size=(1, 2, 3), center=(3, 2, 1), exclude_surfaces=["x-"]
+        ...
) + """ + exclude_surfaces = kwargs.pop("exclude_surfaces", None) + surfaces = cls.surfaces(size=size, center=center, **kwargs) + if "name" in cls.model_fields and exclude_surfaces: + surfaces = [surf for surf in surfaces if surf.name[-2:] not in exclude_surfaces] + return surfaces + + @verify_packages_import(["trimesh"]) + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for Box geometry. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + import trimesh + + (x0, y0, z0), (x1, y1, z1) = self.bounds + vertices = [ + (x0, y0, z0), # 0 + (x0, y0, z1), # 1 + (x0, y1, z0), # 2 + (x0, y1, z1), # 3 + (x1, y0, z0), # 4 + (x1, y0, z1), # 5 + (x1, y1, z0), # 6 + (x1, y1, z1), # 7 + ] + faces = [ + (0, 1, 3, 2), # -x + (4, 6, 7, 5), # +x + (0, 4, 5, 1), # -y + (2, 3, 7, 6), # +y + (0, 2, 6, 4), # -z + (1, 5, 7, 3), # +z + ] + mesh = trimesh.Trimesh(vertices, faces) + + section = mesh.section(plane_origin=origin, plane_normal=normal) + if section is None: + return [] + path, _ = section.to_2D(to_2D=to_2D) + return path.polygons_full + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns shapely geometry at plane specified by one non None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for Box geometry. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
+ """ + axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) + if not self.intersects_axis_position(axis, position): + return [] + z0, (x0, y0) = self.pop_axis(self.center, axis=axis) + Lz, (Lx, Ly) = self.pop_axis(self.size, axis=axis) + dz = np.abs(z0 - position) + if dz > Lz / 2 + fp_eps: + return [] + + minx = x0 - Lx / 2 + miny = y0 - Ly / 2 + maxx = x0 + Lx / 2 + maxy = y0 + Ly / 2 + + # handle case where the box vertices are identical + if np.isclose(minx, maxx) and np.isclose(miny, maxy): + return [self.make_shapely_point(minx, miny)] + + return [self.make_shapely_box(minx, miny, maxx, maxy)] + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + self._ensure_equal_shape(x, y, z) + x0, y0, z0 = self.center + Lx, Ly, Lz = self.size + dist_x = np.abs(x - x0) + dist_y = np.abs(y - y0) + dist_z = np.abs(z - z0) + return (dist_x <= Lx / 2) * (dist_y <= Ly / 2) * (dist_z <= Lz / 2) + + def intersections_with( + self, other: Shapely, cleanup: bool = True, quad_segs: Optional[int] = None + ) -> list[Shapely]: + """Returns list of shapely geometries representing the intersections of the geometry with + this 2D box. + + Parameters + ---------- + other : Shapely + Geometry to intersect with. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect this 2D box. + For more details refer to + `Shapely's Documentation `_. + """ + + # Verify 2D + if self.size.count(0.0) != 1: + raise ValidationError( + "Intersections with other geometry are only calculated from a 2D box." 
+
+    def intersections_with(
+        self, other: Geometry, cleanup: bool = True, quad_segs: Optional[int] = None
+    ) -> list[Shapely]:
+        """Returns list of shapely geometries representing the intersections of the geometry with
+        this 2D box.
+
+        Parameters
+        ----------
+        other : Geometry
+            Geometry to intersect with.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            high-quality visualization settings.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect this 2D box.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+
+        # Verify 2D
+        if self.size.count(0.0) != 1:
+            raise ValidationError(
+                "Intersections with other geometry are only calculated from a 2D box."
+            )
+
+        # don't bother if the other geometry doesn't intersect this box at all
+        if not other.intersects(self):
+            return []
+
+        # get list of Shapely shapes that intersect at this box's plane
+        normal_ind = self.size.index(0.0)
+        dim = "xyz"[normal_ind]
+        pos = self.center[normal_ind]
+        xyz_kwargs = {dim: pos}
+        shapes_plane = other.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **xyz_kwargs)
+
+        # intersect all shapes with this box
+        bs_min, bs_max = (self.pop_axis(bounds, axis=normal_ind)[1] for bounds in self.bounds)
+
+        shapely_box = self.make_shapely_box(bs_min[0], bs_min[1], bs_max[0], bs_max[1])
+        shapely_box = Geometry.evaluate_inf_shape(shapely_box)
+        return [Geometry.evaluate_inf_shape(shape) & shapely_box for shape in shapes_plane]
+
+    def slightly_enlarged_copy(self) -> Box:
+        """Box size slightly enlarged around machine precision."""
+        size = [increment_float(orig_length, 1) for orig_length in self.size]
+        return self.updated_copy(size=size)
+
+    def padded_copy(
+        self,
+        x: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
+        y: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
+        z: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
+    ) -> Box:
+        """Create a padded copy of a :class:`Box` instance.
+
+        Parameters
+        ----------
+        x : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
+            Padding sizes at the left and right boundaries of the box along x-axis.
+        y : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
+            Padding sizes at the left and right boundaries of the box along y-axis.
+        z : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
+            Padding sizes at the left and right boundaries of the box along z-axis.
+
+        Returns
+        -------
+        Box
+            Padded instance of :class:`Box`.
+        """
+
+        # Validate that padding values are non-negative
+        for axis_name, axis_padding in zip(("x", "y", "z"), (x, y, z)):
+            if axis_padding is not None:
+                if not isinstance(axis_padding, (tuple, list)) or len(axis_padding) != 2:
+                    raise ValueError(f"Padding for {axis_name}-axis must be a tuple of two values.")
+                if any(p < 0 for p in axis_padding):
+                    raise ValueError(
+                        f"Padding values for {axis_name}-axis must be non-negative. Got {axis_padding}."
+                    )
+
+        rmin, rmax = self.bounds
+
+        def bound_array(arrs: ArrayLike, idx: int) -> NDArray:
+            return np.array([(a[idx] if a is not None else 0) for a in arrs])
+
+        # parse padding sizes for simulation
+        drmin = bound_array((x, y, z), 0)
+        drmax = bound_array((x, y, z), 1)
+
+        rmin = np.array(rmin) - drmin
+        rmax = np.array(rmax) + drmax
+
+        return Box.from_bounds(rmin=rmin, rmax=rmax)
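+
+    # Illustrative sketch (hypothetical numbers): each requested side grows
+    # independently.
+    #
+    #     box = Box(center=(0, 0, 0), size=(2, 2, 2))
+    #     padded = box.padded_copy(x=(0.5, 1.5))
+    #     padded.bounds    # ((-1.5, -1.0, -1.0), (2.5, 1.0, 1.0))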
+ """ + return Box(center=self.center, size=self.size) + + @cached_property + def zero_dims(self) -> list[Axis]: + """A list of axes along which the :class:`Box` is zero-sized.""" + return [dim for dim, size in enumerate(self.size) if size == 0] + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + if np.count_nonzero(self.size) != 2: + raise ValidationError( + "'Medium2D' requires exactly one of the 'Box' dimensions to have size zero." + ) + return self.size.index(0) + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Box: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + new_center = list(self.center) + new_center[axis] = (bounds[0] + bounds[1]) / 2 + new_size = list(self.size) + new_size[axis] = bounds[1] - bounds[0] + return self.updated_copy(center=tuple(new_center), size=tuple(new_size)) + + def _plot_arrow( + self, + direction: tuple[float, float, float], + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + color: Optional[str] = None, + alpha: Optional[float] = None, + bend_radius: Optional[float] = None, + bend_axis: Axis = None, + both_dirs: bool = False, + ax: Ax = None, + arrow_base: Coordinate = None, + ) -> Ax: + """Adds an arrow to the axis if with options if certain conditions met. + + Parameters + ---------- + direction: tuple[float, float, float] + Normalized vector describing the arrow direction. + x : float = None + Position of plotting plane in x direction. + y : float = None + Position of plotting plane in y direction. + z : float = None + Position of plotting plane in z direction. + color : str = None + Color of the arrow. + alpha : float = None + Opacity of the arrow (0, 1) + bend_radius : float = None + Radius of curvature for this arrow. + bend_axis : Axis = None + Axis of curvature of ``bend_radius``. + both_dirs : bool = False + If True, plots an arrow pointing in direction and one in -direction. + arrow_base : :class:`.Coordinate` = None + Custom base of the arrow. Uses the geometry's center if not provided. + + Returns + ------- + matplotlib.axes._subplots.Axes + The matplotlib axes with the arrow added. + """ + + plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) + _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) + + # conditions to check to determine whether to plot arrow, taking into account the + # possibility of a custom arrow base + arrow_intersecting_plane = len(self.intersections_plane(x=x, y=y, z=z)) > 0 + center = self.center + if arrow_base: + arrow_intersecting_plane = arrow_intersecting_plane and any( + a == b for a, b in zip(arrow_base, [x, y, z]) + ) + center = arrow_base + + _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) + components_in_plane = any(not np.isclose(component, 0) for component in (dx, dy)) + + # plot if arrow in plotting plane and some non-zero component can be displayed. + if arrow_intersecting_plane and components_in_plane: + _, (x0, y0) = self.pop_axis(center, axis=plot_axis) + + # Reasonable value for temporary arrow size. The correct size and direction + # have to be calculated after all transforms have been set. That is why we + # use a callback to do these calculations only at the drawing phase. 
+ xmin, xmax = ax.get_xlim() + ymin, ymax = ax.get_ylim() + v_x = (xmax - xmin) / 10 + v_y = (ymax - ymin) / 10 + + directions = (1.0, -1.0) if both_dirs else (1.0,) + for sign in directions: + arrow = patches.FancyArrowPatch( + (x0, y0), + (x0 + v_x, y0 + v_y), + arrowstyle=arrow_style, + color=color, + alpha=alpha, + zorder=np.inf, + ) + # Don't draw this arrow until it's been reshaped + arrow.set_visible(False) + + callback = self._arrow_shape_cb( + arrow, (x0, y0), (dx, dy), sign, bend_radius if bend_axis == plot_axis else None + ) + callback_id = ax.figure.canvas.mpl_connect("draw_event", callback) + + # Store a reference to the callback because mpl_connect does not. + arrow.set_shape_cb = (callback_id, callback) + + ax.add_patch(arrow) + + return ax + + @staticmethod + def _arrow_shape_cb( + arrow: FancyArrowPatch, + pos: tuple[float, float], + direction: ArrayLike, + sign: float, + bend_radius: float | None, + ) -> Callable[[Event], None]: + def _cb(event: Event) -> None: + # We only want to set the shape once, so we disconnect ourselves + event.canvas.mpl_disconnect(arrow.set_shape_cb[0]) + + transform = arrow.axes.transData.transform + scale_x = transform((1, 0))[0] - transform((0, 0))[0] + scale_y = transform((0, 1))[1] - transform((0, 0))[1] + scale = max(scale_x, scale_y) # <-- Hack: This is a somewhat arbitrary choice. + arrow_length = ARROW_LENGTH * event.canvas.figure.get_dpi() / scale + + if bend_radius: + v_norm = (direction[0] ** 2 + direction[1] ** 2) ** 0.5 + vx_norm = direction[0] / v_norm + vy_norm = direction[1] / v_norm + bend_angle = -sign * arrow_length / bend_radius + t_x = 1 - np.cos(bend_angle) + t_y = np.sin(bend_angle) + v_x = -bend_radius * (vx_norm * t_y - vy_norm * t_x) + v_y = -bend_radius * (vx_norm * t_x + vy_norm * t_y) + tangent_angle = np.arctan2(direction[1], direction[0]) + arrow.set_connectionstyle( + patches.ConnectionStyle.Angle3( + angleA=180 / np.pi * tangent_angle, + angleB=180 / np.pi * (tangent_angle + bend_angle), + ) + ) + + else: + v_x = sign * arrow_length * direction[0] + v_y = sign * arrow_length * direction[1] + + arrow.set_positions(pos, (pos[0] + v_x, pos[1] + v_y)) + arrow.set_visible(True) + arrow.draw(event.renderer) + + return _cb + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + volume = 1 + + for axis in range(3): + min_bound = max(self.bounds[0][axis], bounds[0][axis]) + max_bound = min(self.bounds[1][axis], bounds[1][axis]) + + volume *= max_bound - min_bound + + return volume + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + min_bounds = list(self.bounds[0]) + max_bounds = list(self.bounds[1]) + + in_bounds_factor = [2, 2, 2] + length = [0, 0, 0] + + for axis in (0, 1, 2): + if min_bounds[axis] < bounds[0][axis]: + min_bounds[axis] = bounds[0][axis] + in_bounds_factor[axis] -= 1 + + if max_bounds[axis] > bounds[1][axis]: + max_bounds[axis] = bounds[1][axis] + in_bounds_factor[axis] -= 1 + + length[axis] = max_bounds[axis] - min_bounds[axis] + + return ( + length[0] * length[1] * in_bounds_factor[2] + + length[1] * length[2] * in_bounds_factor[0] + + length[2] * length[0] * in_bounds_factor[1] + ) + + """ Autograd code """ + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + # get gradients w.r.t. 
each of the 6 faces (in normal direction) + vjps_faces = self._derivative_faces(derivative_info=derivative_info) + + # post-process these values to give the gradients w.r.t. center and size + vjps_center_size = self._derivatives_center_size(vjps_faces=vjps_faces) + + # store only the gradients asked for in 'field_paths' + derivative_map = {} + for field_path in derivative_info.paths: + field_name, *index = field_path + + if field_name in vjps_center_size: + # if the vjp calls for a specific index into the tuple + if index and len(index) == 1: + index = int(index[0]) + if field_path not in derivative_map: + derivative_map[field_path] = vjps_center_size[field_name][index] + + # otherwise, just grab the whole array + else: + derivative_map[field_path] = vjps_center_size[field_name] + + return derivative_map + + @staticmethod + def _derivatives_center_size(vjps_faces: Bound) -> dict[str, Coordinate]: + """Derivatives with respect to the ``center`` and ``size`` fields in the ``Box``.""" + + vjps_faces_min, vjps_faces_max = np.array(vjps_faces) + + # post-process min and max face gradients into center and size + vjp_center = vjps_faces_max - vjps_faces_min + vjp_size = (vjps_faces_min + vjps_faces_max) / 2.0 + + return { + "center": tuple(vjp_center.tolist()), + "size": tuple(vjp_size.tolist()), + } + + def _derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: + """Derivative with respect to normal position of 6 faces of ``Box``.""" + + axes_to_compute = (0, 1, 2) + if len(derivative_info.paths[0]) > 1: + axes_to_compute = tuple(info[1] for info in derivative_info.paths) + + # change in permittivity between inside and outside + vjp_faces = np.zeros((2, 3)) + + for min_max_index, _ in enumerate((0, -1)): + for axis in axes_to_compute: + vjp_face = self._derivative_face( + min_max_index=min_max_index, + axis_normal=axis, + derivative_info=derivative_info, + ) + + # record vjp for this face + vjp_faces[min_max_index, axis] = vjp_face + + return vjp_faces + + def _derivative_face( + self, + min_max_index: int, + axis_normal: Axis, + derivative_info: DerivativeInfo, + ) -> float: + """Compute the derivative w.r.t. shifting a face in the normal direction.""" + + interpolators = derivative_info.interpolators or derivative_info.create_interpolators() + _, axis_perp = self.pop_axis((0, 1, 2), axis=axis_normal) + + # First, check if the face is outside the simulation domain in which case set the + # face gradient to 0. 
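+        # Worked example (assumed values): for axis_normal=0 and bounds
+        # ((x0, y0, z0), (x1, y1, z1)), 'pop_axis' below yields
+        # bounds_normal=(x0, x1) and bounds_perp=((y0, y1), (z0, z1));
+        # min_max_index then selects which x-face is being shifted.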
+ bounds_normal, bounds_perp = self.pop_axis( + np.array(derivative_info.bounds).T, axis=axis_normal + ) + coord_normal_face = bounds_normal[min_max_index] + + if min_max_index == 0: + if coord_normal_face < derivative_info.simulation_bounds[0][axis_normal]: + return 0.0 + else: + if coord_normal_face > derivative_info.simulation_bounds[1][axis_normal]: + return 0.0 + + intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) + extents = intersect_max - intersect_min + _, intersect_min_perp = self.pop_axis(np.array(intersect_min), axis=axis_normal) + _, intersect_max_perp = self.pop_axis(np.array(intersect_max), axis=axis_normal) + + is_2d_map = [] + for axis_idx in range(3): + if axis_idx == axis_normal: + continue + is_2d_map.append(np.isclose(extents[axis_idx], 0.0)) + + if np.all(is_2d_map): + return 0.0 + + is_2d = np.any(is_2d_map) + + sim_bounds_normal, sim_bounds_perp = self.pop_axis( + np.array(derivative_info.simulation_bounds).T, axis=axis_normal + ) + + # Build point grid + adaptive_spacing = derivative_info.adaptive_vjp_spacing() + + def spacing_to_grid_points( + spacing: float, min_coord: float, max_coord: float + ) -> NDArray[float]: + N = np.maximum(3, 1 + int((max_coord - min_coord) / spacing)) + + points = np.linspace(min_coord, max_coord, N) + centers = 0.5 * (points[0:-1] + points[1:]) + + return centers + + def verify_integration_interval(bound: tuple[float, float]) -> bool: + # assume the bounds should not be equal or else this integration interval + # would be the flat dimension of a 2D geometry. + return bound[1] > bound[0] + + def compute_integration_weight(grid_points: NDArray[float]) -> float: + grid_spacing = grid_points[1] - grid_points[0] + if grid_spacing == 0.0: + integration_weight = 1.0 / len(grid_points) + else: + integration_weight = grid_points[1] - grid_points[0] + + return integration_weight + + if is_2d: + # build 1D grid for sampling points along the face, which is an edge in the 2D case + zero_dim = np.where(is_2d_map)[0][0] + # zero dim is one of the perpendicular directions, so the other perpendicular direction + # is the nonzero dimension + nonzero_dim = 1 - zero_dim + + # clip at simulation bounds for integration dimension + integration_bounds_perp = ( + intersect_min_perp[nonzero_dim], + intersect_max_perp[nonzero_dim], + ) + + if not verify_integration_interval(integration_bounds_perp): + return 0.0 + + grid_points_linear = spacing_to_grid_points( + adaptive_spacing, integration_bounds_perp[0], integration_bounds_perp[1] + ) + integration_weight = compute_integration_weight(grid_points_linear) + + grid_points = np.repeat(np.expand_dims(grid_points_linear.copy(), 1), 3, axis=1) + + # set up grid points to pass into evaluate_gradient_at_points + grid_points[:, axis_perp[nonzero_dim]] = grid_points_linear + grid_points[:, axis_perp[zero_dim]] = intersect_min_perp[zero_dim] + grid_points[:, axis_normal] = coord_normal_face + else: + # build 3D grid for sampling points along the face + + # clip at simulation bounds for each integration dimension + integration_bounds_perp = ( + (intersect_min_perp[0], intersect_max_perp[0]), + (intersect_min_perp[1], intersect_max_perp[1]), + ) + + if not np.all([verify_integration_interval(b) for b in integration_bounds_perp]): + return 0.0 + + grid_points_perp_1 = spacing_to_grid_points( + adaptive_spacing, integration_bounds_perp[0][0], integration_bounds_perp[0][1] + ) + grid_points_perp_2 = spacing_to_grid_points( + adaptive_spacing, integration_bounds_perp[1][0], 
integration_bounds_perp[1][1]
+            )
+            integration_weight = compute_integration_weight(
+                grid_points_perp_1
+            ) * compute_integration_weight(grid_points_perp_2)
+
+            mesh_perp1, mesh_perp2 = np.meshgrid(grid_points_perp_1, grid_points_perp_2)
+
+            zip_perp_coords = np.array(list(zip(mesh_perp1.flatten(), mesh_perp2.flatten())))
+
+            grid_points = np.pad(zip_perp_coords.copy(), ((0, 0), (1, 0)), mode="constant")
+
+            # set up grid points to pass into evaluate_gradient_at_points
+            grid_points[:, axis_perp[0]] = zip_perp_coords[:, 0]
+            grid_points[:, axis_perp[1]] = zip_perp_coords[:, 1]
+            grid_points[:, axis_normal] = coord_normal_face
+
+        normals = np.zeros_like(grid_points)
+        perps1 = np.zeros_like(grid_points)
+        perps2 = np.zeros_like(grid_points)
+
+        normals[:, axis_normal] = -1 if (min_max_index == 0) else 1
+        perps1[:, axis_perp[0]] = 1
+        perps2[:, axis_perp[1]] = 1
+
+        gradient_at_points = derivative_info.evaluate_gradient_at_points(
+            spatial_coords=grid_points,
+            normals=normals,
+            perps1=perps1,
+            perps2=perps2,
+            interpolators=interpolators,
+        )
+
+        vjp_value = np.sum(integration_weight * np.real(gradient_at_points))
+        return vjp_value
+
+
+"""Compound subclasses"""
+
+
+class Transformed(Geometry):
+    """Class representing a transformed geometry."""
+
+    geometry: discriminated_union(GeometryType) = Field(
+        title="Geometry",
+        description="Base geometry to be transformed.",
+    )
+
+    transform: MatrixReal4x4 = Field(
+        default_factory=lambda: np.eye(4).tolist(),
+        title="Transform",
+        description="Transform matrix applied to the base geometry.",
+    )
+
+    @field_validator("transform")
+    @classmethod
+    def _transform_is_invertible(cls, val: MatrixReal4x4) -> MatrixReal4x4:
+        # If the transform is not invertible, this will raise an error
+        _ = np.linalg.inv(val)
+        return val
+
+    @field_validator("geometry")
+    @classmethod
+    def _geometry_is_finite(cls, val: GeometryType) -> GeometryType:
+        if not np.isfinite(val.bounds).all():
+            raise ValidationError(
+                "Transformations are only supported on geometries with finite dimensions. "
+                "Try using a large value instead of 'inf' when creating geometries that undergo "
+                "transformations."
+            )
+        return val
+
+    @model_validator(mode="after")
+    def _apply_transforms(self) -> Self:
+        """Flatten nested transforms into a single transform matrix."""
+        while isinstance(self.geometry, Transformed):
+            inner = self.geometry
+            object.__setattr__(self, "geometry", inner.geometry)
+            object.__setattr__(self, "transform", np.dot(self.transform, inner.transform))
+        return self
+
+    @cached_property
+    def inverse(self) -> MatrixReal4x4:
+        """Inverse of this transform."""
+        return np.linalg.inv(self.transform)
+
+    @staticmethod
+    def _vertices_from_bounds(bounds: Bound) -> ArrayFloat2D:
+        """Return the 8 vertices derived from bounds.
+
+        The vertices are returned as homogeneous coordinates (with 4 components).
+
+        Parameters
+        ----------
+        bounds : Bound
+            Bounds from which to derive the vertices.
+
+        Returns
+        -------
+        ArrayFloat2D
+            Array with shape (4, 8) with all vertices from ``bounds``.
+        """
+        (x0, y0, z0), (x1, y1, z1) = bounds
+        return np.array(
+            (
+                (x0, x0, x0, x0, x1, x1, x1, x1),
+                (y0, y0, y1, y1, y0, y0, y1, y1),
+                (z0, z1, z0, z1, z0, z1, z0, z1),
+                (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0),
+            )
+        )
+
+    @cached_property
+    def bounds(self) -> Bound:
+        """Returns bounding box min and max coordinates.
+
+        Returns
+        -------
+        tuple[float, float, float], tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+ """ + # NOTE (Lucas): The bounds are overestimated because we don't want to calculate + # precise TriangleMesh representations for GeometryGroup or ClipOperation. + vertices = np.dot(self.transform, self._vertices_from_bounds(self.geometry.bounds))[:3] + return (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + return self.geometry.intersections_tilted_plane( + tuple(np.dot((normal[0], normal[1], normal[2], 0.0), self.transform)[:3]), + tuple(np.dot(self.inverse, (origin[0], origin[1], origin[2], 1.0))[:3]), + np.dot(to_2D, self.transform), + cleanup=cleanup, + quad_segs=quad_segs, + ) + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + x = np.array(x) + y = np.array(y) + z = np.array(z) + xyz = np.dot(self.inverse, np.vstack((x.flat, y.flat, z.flat, np.ones(x.size)))) + if xyz.shape[1] == 1: + # TODO: This "fix" is required because of a bug in PolySlab.inside (with non-zero sidewall angle) + return self.geometry.inside(xyz[0][0], xyz[1][0], xyz[2][0]).reshape(x.shape) + return self.geometry.inside(xyz[0], xyz[1], xyz[2]).reshape(x.shape) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + # NOTE (Lucas): Bounds are overestimated. + vertices = np.dot(self.inverse, self._vertices_from_bounds(bounds))[:3] + inverse_bounds = (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) + return abs(np.linalg.det(self.transform)) * self.geometry.volume(inverse_bounds) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + log.warning("Surface area of transformed elements cannot be calculated.") + return None + + @staticmethod + def translation(x: float, y: float, z: float) -> MatrixReal4x4: + """Return a translation matrix. + + Parameters + ---------- + x : float + Translation along x. + y : float + Translation along y. + z : float + Translation along z. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). 
+ """ + return np.array( + [ + (1.0, 0.0, 0.0, x), + (0.0, 1.0, 0.0, y), + (0.0, 0.0, 1.0, z), + (0.0, 0.0, 0.0, 1.0), + ], + dtype=float, + ) + + @staticmethod + def scaling(x: float = 1.0, y: float = 1.0, z: float = 1.0) -> MatrixReal4x4: + """Return a scaling matrix. + + Parameters + ---------- + x : float = 1.0 + Scaling factor along x. + y : float = 1.0 + Scaling factor along y. + z : float = 1.0 + Scaling factor along z. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). + """ + if np.isclose((x, y, z), 0.0).any(): + raise Tidy3dError("Scaling factors cannot be zero in any dimensions.") + return np.array( + [ + (x, 0.0, 0.0, 0.0), + (0.0, y, 0.0, 0.0), + (0.0, 0.0, z, 0.0), + (0.0, 0.0, 0.0, 1.0), + ], + dtype=float, + ) + + @staticmethod + def rotation(angle: float, axis: Union[Axis, Coordinate]) -> MatrixReal4x4: + """Return a rotation matrix. + + Parameters + ---------- + angle : float + Rotation angle (in radians). + axis : Union[int, tuple[float, float, float]] + Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). + """ + transform = np.eye(4) + transform[:3, :3] = RotationAroundAxis(angle=angle, axis=axis).matrix + return transform + + @staticmethod + def reflection(normal: Coordinate) -> MatrixReal4x4: + """Return a reflection matrix. + + Parameters + ---------- + normal : tuple[float, float, float] + Normal of the plane of reflection. + + Returns + ------- + numpy.ndarray + Transform matrix with shape (4, 4). + """ + + transform = np.eye(4) + transform[:3, :3] = ReflectionFromPlane(normal=normal).matrix + return transform + + @staticmethod + def preserves_axis(transform: MatrixReal4x4, axis: Axis) -> bool: + """Indicate if the transform preserves the orientation of a given axis. + + Parameters: + transform: MatrixReal4x4 + Transform matrix to check. + axis : int + Axis to check. Values 0, 1, or 2, to check x, y, or z, respectively. + + Returns + ------- + bool + ``True`` if the transformation preserves the axis orientation, ``False`` otherwise. + """ + i = (axis + 1) % 3 + j = (axis + 2) % 3 + return np.isclose(transform[i, axis], 0) and np.isclose(transform[j, axis], 0) + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + normal = self.geometry._normal_2dmaterial + preserves_axis = Transformed.preserves_axis(self.transform, normal) + + if not preserves_axis: + raise ValidationError( + "'Medium2D' requires geometries of type 'Transformed' to " + "perserve the axis normal to the 'Medium2D'." 
+ ) + + return normal + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Transformed: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + min_bound = np.array([0, 0, 0, 1.0]) + min_bound[axis] = bounds[0] + max_bound = np.array([0, 0, 0, 1.0]) + max_bound[axis] = bounds[1] + new_bounds = [] + new_bounds.append(np.dot(self.inverse, min_bound)[axis]) + new_bounds.append(np.dot(self.inverse, max_bound)[axis]) + new_geometry = self.geometry._update_from_bounds(bounds=new_bounds, axis=axis) + return self.updated_copy(geometry=new_geometry) + + +class ClipOperation(Geometry): + """Class representing the result of a set operation between geometries.""" + + operation: ClipOperationType = Field( + title="Operation Type", + description="Operation to be performed between geometries.", + ) + + geometry_a: discriminated_union(GeometryType) = Field( + title="Geometry A", + description="First operand for the set operation. It can be any geometry type, including " + ":class:`GeometryGroup`.", + ) + + geometry_b: discriminated_union(GeometryType) = Field( + title="Geometry B", + description="Second operand for the set operation. It can also be any geometry type.", + ) + + @field_validator("geometry_a", "geometry_b") + @classmethod + def _geometries_untraced(cls, val: GeometryType) -> GeometryType: + """Make sure that ``ClipOperation`` geometries do not contain tracers.""" + traced = val._strip_traced_fields() + if traced: + raise ValidationError( + f"{val.type} contains traced fields {list(traced.keys())}. Note that " + "'ClipOperation' does not currently support automatic differentiation." + ) + return val + + @staticmethod + def to_polygon_list(base_geometry: Shapely, cleanup: bool = False) -> list[Shapely]: + """Return a list of valid polygons from a shapely geometry, discarding points, lines, and + empty polygons, and empty triangles within polygons. + + Parameters + ---------- + base_geometry : shapely.geometry.base.BaseGeometry + Base geometry for inspection. + cleanup: bool = False + If True, removes extremely small features from each polygon's boundary. + This is useful for removing artifacts from 2D plots displayed to the user. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + Valid polygons retrieved from ``base geometry``. + """ + unfiltered_geoms = [] + if base_geometry.geom_type == "GeometryCollection": + unfiltered_geoms = [ + p + for geom in base_geometry.geoms + for p in ClipOperation.to_polygon_list(geom, cleanup) + ] + if base_geometry.geom_type == "MultiPolygon": + unfiltered_geoms = [p for p in base_geometry.geoms if not p.is_empty] + if base_geometry.geom_type == "Polygon" and not base_geometry.is_empty: + unfiltered_geoms = [base_geometry] + geoms = [] + if cleanup: + # Optional: "clean" each of the polygons (by removing extremely small or thin features). 
+ for geom in unfiltered_geoms: + geom_clean = cleanup_shapely_object(geom) + if geom_clean.geom_type == "Polygon": + geoms.append(geom_clean) + if geom_clean.geom_type == "MultiPolygon": + geoms += [p for p in geom_clean.geoms if not p.is_empty] + # Ignore other types of shapely objects (points and lines) + else: + geoms = unfiltered_geoms + return geoms + + @property + def _shapely_operation(self) -> Callable[[Shapely, Shapely], Shapely]: + """Return a Shapely function equivalent to this operation.""" + result = _shapely_operations.get(self.operation, None) + if not result: + raise ValueError( + "'operation' must be one of 'union', 'intersection', 'difference', or " + "'symmetric_difference'." + ) + return result + + @property + def _bit_operation(self) -> Callable[[Any, Any], Any]: + """Return a function equivalent to this operation using bit operators.""" + result = _bit_operations.get(self.operation, None) + if not result: + raise ValueError( + "'operation' must be one of 'union', 'intersection', 'difference', or " + "'symmetric_difference'." + ) + return result + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + a = self.geometry_a.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + b = self.geometry_b.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a]) + geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b]) + return ClipOperation.to_polygon_list( + self._shapely_operation(geom_a, geom_b), + cleanup=cleanup, + ) + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. 
For more details refer to
+            `Shapely's Documentation `_.
+        """
+        a = self.geometry_a.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs)
+        b = self.geometry_b.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs)
+        geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a])
+        geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b])
+        return ClipOperation.to_polygon_list(
+            self._shapely_operation(geom_a, geom_b),
+            cleanup=cleanup,
+        )
+
+    @cached_property
+    def bounds(self) -> Bound:
+        """Returns bounding box min and max coordinates.
+
+        Returns
+        -------
+        tuple[float, float, float], tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+        """
+        # Overestimates
+        if self.operation == "difference":
+            result = self.geometry_a.bounds
+        elif self.operation == "intersection":
+            bounds = (self.geometry_a.bounds, self.geometry_b.bounds)
+            result = (
+                tuple(max(b[i] for b, _ in bounds) for i in range(3)),
+                tuple(min(b[i] for _, b in bounds) for i in range(3)),
+            )
+            if any(result[0][i] > result[1][i] for i in range(3)):
+                result = ((0, 0, 0), (0, 0, 0))
+        else:
+            bounds = (self.geometry_a.bounds, self.geometry_b.bounds)
+            result = (
+                tuple(min(b[i] for b, _ in bounds) for i in range(3)),
+                tuple(max(b[i] for _, b in bounds) for i in range(3)),
+            )
+        return result
+
+    def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]:
+        """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array
+        with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the
+        volume of the :class:`Geometry`, and ``False`` otherwise.
+
+        Parameters
+        ----------
+        x : np.ndarray[float]
+            Array of point positions in x direction.
+        y : np.ndarray[float]
+            Array of point positions in y direction.
+        z : np.ndarray[float]
+            Array of point positions in z direction.
+
+        Returns
+        -------
+        np.ndarray[bool]
+            ``True`` for every point that is inside the geometry.
+        """
+        inside_a = self.geometry_a.inside(x, y, z)
+        inside_b = self.geometry_b.inside(x, y, z)
+        return self._bit_operation(inside_a, inside_b)
+
+    def inside_meshgrid(
+        self, x: NDArray[float], y: NDArray[float], z: NDArray[float]
+    ) -> NDArray[bool]:
+        """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted.
+
+        Parameters
+        ----------
+        x : np.ndarray[float]
+            1D array of point positions in x direction.
+        y : np.ndarray[float]
+            1D array of point positions in y direction.
+        z : np.ndarray[float]
+            1D array of point positions in z direction.
+
+        Returns
+        -------
+        np.ndarray[bool]
+            Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every
+            point that is inside the geometry.
+ """ + inside_a = self.geometry_a.inside_meshgrid(x, y, z) + inside_b = self.geometry_b.inside_meshgrid(x, y, z) + return self._bit_operation(inside_a, inside_b) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + # Overestimates + if self.operation == "intersection": + return min(self.geometry_a.volume(bounds), self.geometry_b.volume(bounds)) + if self.operation == "difference": + return self.geometry_a.volume(bounds) + return self.geometry_a.volume(bounds) + self.geometry_b.volume(bounds) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + # Overestimates + return self.geometry_a.surface_area(bounds) + self.geometry_b.surface_area(bounds) + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + normal_a = self.geometry_a._normal_2dmaterial + normal_b = self.geometry_b._normal_2dmaterial + + if normal_a != normal_b: + raise ValidationError( + "'Medium2D' requires both geometries in the 'ClipOperation' to " + "have exactly one dimension with zero size in common." + ) + + plane_position_a = self.geometry_a.bounds[0][normal_a] + plane_position_b = self.geometry_b.bounds[0][normal_b] + + if plane_position_a != plane_position_b: + raise ValidationError( + "'Medium2D' requires both geometries in the 'ClipOperation' to be co-planar." + ) + return normal_a + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> ClipOperation: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + new_geom_a = self.geometry_a._update_from_bounds(bounds=bounds, axis=axis) + new_geom_b = self.geometry_b._update_from_bounds(bounds=bounds, axis=axis) + return self.updated_copy(geometry_a=new_geom_a, geometry_b=new_geom_b) + + +class GeometryGroup(Geometry): + """A collection of Geometry objects that can be called as a single geometry object.""" + + geometries: tuple[discriminated_union(GeometryType), ...] = Field( + title="Geometries", + description="Tuple of geometries in a single grouping. " + "Can provide significant performance enhancement in ``Structure`` when all geometries are " + "assigned the same medium.", + ) + + @field_validator("geometries") + @classmethod + def _geometries_not_empty(cls, val: tuple[GeometryType, ...]) -> tuple[GeometryType, ...]: + """make sure geometries are not empty.""" + if not len(val) > 0: + raise ValidationError("GeometryGroup.geometries must not be empty.") + return val + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float, float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + + bounds = tuple(geometry.bounds for geometry in self.geometries) + return ( + tuple(min(b[i] for b, _ in bounds) for i in range(3)), + tuple(max(b[i] for _, b in bounds) for i in range(3)), + ) + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. 
+ to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + return [ + intersection + for geometry in self.geometries + for intersection in geometry.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + ] + + def intersections_plane( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + if not self.intersects_plane(x, y, z): + return [] + return [ + intersection + for geometry in self.geometries + for intersection in geometry.intersections_plane( + x=x, y=y, z=z, cleanup=cleanup, quad_segs=quad_segs + ) + ] + + def intersects_axis_position(self, axis: float, position: float) -> bool: + """Whether self intersects plane specified by a given position along a normal axis. + + Parameters + ---------- + axis : int = None + Axis normal to the plane. + position : float = None + Position of plane along the normal axis. + + Returns + ------- + bool + Whether this geometry intersects the plane. + """ + return any(geom.intersects_axis_position(axis, position) for geom in self.geometries) + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. + """ + individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries) + return functools.reduce(lambda a, b: a | b, individual_insides) + + def inside_meshgrid( + self, x: NDArray[float], y: NDArray[float], z: NDArray[float] + ) -> NDArray[bool]: + """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted. 
+ + Parameters + ---------- + x : np.ndarray[float] + 1D array of point positions in x direction. + y : np.ndarray[float] + 1D array of point positions in y direction. + z : np.ndarray[float] + 1D array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every + point that is inside the geometry. + """ + individual_insides = (geom.inside_meshgrid(x, y, z) for geom in self.geometries) + return functools.reduce(lambda a, b: a | b, individual_insides) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + return sum(geometry.volume(bounds) for geometry in self.geometries) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + return sum(geometry.surface_area(bounds) for geometry in self.geometries) + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + + normals = {geom._normal_2dmaterial for geom in self.geometries} + + if len(normals) != 1: + raise ValidationError( + "'Medium2D' requires all geometries in the 'GeometryGroup' to " + "share exactly one dimension with zero size." + ) + normal = list(normals)[0] + positions = {geom.bounds[0][normal] for geom in self.geometries} + if len(positions) != 1: + raise ValidationError( + "'Medium2D' requires all geometries in the 'GeometryGroup' to be co-planar." + ) + return normal + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> GeometryGroup: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + new_geometries = tuple( + geometry._update_from_bounds(bounds=bounds, axis=axis) for geometry in self.geometries + ) + return self.updated_copy(geometries=new_geometries) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + grad_vjps = {} + + # create interpolators once for all geometries to avoid redundant field data conversions + interpolators = derivative_info.interpolators or derivative_info.create_interpolators() + + for field_path in derivative_info.paths: + _, index, *geo_path = field_path + geo = self.geometries[index] + # pass pre-computed interpolators if available + geo_info = derivative_info.updated_copy( + paths=[tuple(geo_path)], + bounds=geo.bounds, + bounds_intersect=self.bounds_intersection( + geo.bounds, derivative_info.simulation_bounds + ), + deep=False, + interpolators=interpolators, + ) + + vjp_dict_geo = geo._compute_derivatives(geo_info) + + if len(vjp_dict_geo) != 1: + raise AssertionError("Got multiple gradients for single geometry field.") + + grad_vjps[field_path] = vjp_dict_geo.popitem()[1] + + return grad_vjps + + +def cleanup_shapely_object(obj: Shapely, tolerance_ratio: float = POLY_TOLERANCE_RATIO) -> Shapely: + """Remove small geometric features from the boundaries of a shapely object including + inward and outward spikes, thin holes, and thin connections between larger regions. + + Parameters + ---------- + obj : shapely + a shapely object (typically a ``Polygon`` or a ``MultiPolygon``) + tolerance_ratio : float = ``POLY_TOLERANCE_RATIO`` + Features on the boundaries of polygons will be discarded if they are smaller + or narrower than ``tolerance_ratio`` multiplied by the size of the object. 
+
+    Returns
+    -------
+    Shapely
+        A new shapely object whose small features (e.g. thin spikes or holes) are removed.
+
+    Notes
+    -----
+    This function does not attempt to delete overlapping, nearby, or collinear vertices.
+    To solve that problem, use ``shapely.simplify()`` afterwards.
+    """
+    if _package_is_older_than("shapely", "2.1"):
+        log.warning("Versions of shapely prior to v2.1 may cause plot errors.", log_once=True)
+        return obj
+    if obj.is_empty:
+        return obj
+    centroid = obj.centroid
+    object_size = min(obj.bounds[2] - obj.bounds[0], obj.bounds[3] - obj.bounds[1])
+    if object_size == 0.0:
+        return shapely.Polygon([])
+
+    # To prevent numerical over- or underflow errors, subtract the centroid and rescale
+    normalized_obj = shapely.affinity.affine_transform(
+        obj,
+        matrix=[
+            1 / object_size,
+            0.0,
+            0.0,
+            1 / object_size,
+            -centroid.x / object_size,
+            -centroid.y / object_size,
+        ],
+    )
+    # Important: Remove any self-intersections beforehand using `shapely.make_valid()`.
+    valid_obj = shapely.make_valid(normalized_obj, method="structure", keep_collapsed=False)
+
+    # To get rid of small thin features, erode (shrink), dilate (expand), and erode again.
+    eroded_obj = shapely.buffer(
+        valid_obj,
+        distance=-tolerance_ratio,
+        cap_style="square",
+        quad_segs=3,
+    )
+    dilated_obj = shapely.buffer(
+        eroded_obj,
+        distance=2 * tolerance_ratio,
+        cap_style="square",
+        quad_segs=3,
+    )
+    cleaned_obj = dilated_obj
+
+    # Optional: Now shrink the polygon back to the original size.
+    cleaned_obj = shapely.buffer(
+        cleaned_obj,
+        distance=-tolerance_ratio,
+        cap_style="square",
+        quad_segs=3,
+    )
+    # Clean up vertices that ended up very close together during the erosion/dilation process.
+    # The distance value is heuristic.
+    cleaned_obj = cleaned_obj.simplify(POLY_DISTANCE_TOLERANCE, preserve_topology=True)
+    # Revert to the original scale and position.
+ rescaled_clean_obj = shapely.affinity.affine_transform( + cleaned_obj, + matrix=[ + object_size, + 0.0, + 0.0, + object_size, + centroid.x, + centroid.y, + ], + ) + return rescaled_clean_obj + + +from tidy3d._common.components.geometry.utils import ( # noqa: E402 + GeometryType, + from_shapely, + vertices_from_shapely, +) diff --git a/tidy3d/_common/components/geometry/bound_ops.py b/tidy3d/_common/components/geometry/bound_ops.py new file mode 100644 index 0000000000..b21308b3c7 --- /dev/null +++ b/tidy3d/_common/components/geometry/bound_ops.py @@ -0,0 +1,71 @@ +"""Geometry operations for bounding box type with minimal imports.""" + +from __future__ import annotations + +from math import isclose +from typing import TYPE_CHECKING + +from tidy3d._common.constants import fp_eps + +if TYPE_CHECKING: + from tidy3d._common.components.types.base import Bound + + +def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the intersection of two bounds.""" + rmin1, rmax1 = bounds1 + rmin2, rmax2 = bounds2 + rmin = tuple(max(v1, v2) for v1, v2 in zip(rmin1, rmin2)) + rmax = tuple(min(v1, v2) for v1, v2 in zip(rmax1, rmax2)) + return (rmin, rmax) + + +def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: + """Return the bounds that are the union of two bounds.""" + rmin1, rmax1 = bounds1 + rmin2, rmax2 = bounds2 + rmin = tuple(min(v1, v2) for v1, v2 in zip(rmin1, rmin2)) + rmax = tuple(max(v1, v2) for v1, v2 in zip(rmax1, rmax2)) + return (rmin, rmax) + + +def bounds_contains( + outer_bounds: Bound, inner_bounds: Bound, rtol: float = fp_eps, atol: float = 0.0 +) -> bool: + """Checks whether ``inner_bounds`` is contained within ``outer_bounds`` within specified tolerances. + + Parameters + ---------- + outer_bounds : Bound + The outer bounds to check containment against + inner_bounds : Bound + The inner bounds to check if contained + rtol : float = fp_eps + Relative tolerance for comparing bounds + atol : float = 0.0 + Absolute tolerance for comparing bounds + + Returns + ------- + bool + True if ``inner_bounds`` is contained within ``outer_bounds`` within tolerances + """ + outer_min, outer_max = outer_bounds + inner_min, inner_max = inner_bounds + for dim in range(3): + outer_min_dim = outer_min[dim] + outer_max_dim = outer_max[dim] + inner_min_dim = inner_min[dim] + inner_max_dim = inner_max[dim] + within_min = ( + isclose(outer_min_dim, inner_min_dim, rel_tol=rtol, abs_tol=atol) + or outer_min_dim <= inner_min_dim + ) + within_max = ( + isclose(outer_max_dim, inner_max_dim, rel_tol=rtol, abs_tol=atol) + or outer_max_dim >= inner_max_dim + ) + + if not within_min or not within_max: + return False + return True diff --git a/tidy3d/_common/components/geometry/float_utils.py b/tidy3d/_common/components/geometry/float_utils.py new file mode 100644 index 0000000000..2b5848666d --- /dev/null +++ b/tidy3d/_common/components/geometry/float_utils.py @@ -0,0 +1,31 @@ +"""Utilities for float manipulation.""" + +from __future__ import annotations + +import numpy as np + +from tidy3d._common.constants import inf + + +def increment_float(val: float, sign: int) -> float: + """Applies a small positive or negative shift as though `val` is a 32bit float + using numpy.nextafter, but additionally handles some corner cases. 
+ """ + # Infinity is left unchanged + if val == inf or val == -inf: + return val + + if sign >= 0: + sign = 1 + else: + sign = -1 + + # Avoid small increments within subnormal values + if np.abs(val) <= np.finfo(np.float32).tiny: + return val + sign * np.finfo(np.float32).tiny + + # Numpy seems to skip over the increment from -0.0 and +0.0 + # which is different from c++ + val_inc = np.nextafter(val, sign * inf, dtype=np.float32) + + return np.float32(val_inc) diff --git a/tidy3d/_common/components/geometry/mesh.py b/tidy3d/_common/components/geometry/mesh.py new file mode 100644 index 0000000000..416b9eaaf1 --- /dev/null +++ b/tidy3d/_common/components/geometry/mesh.py @@ -0,0 +1,1285 @@ +"""Mesh-defined geometry.""" + +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +from autograd import numpy as anp +from numpy.typing import NDArray +from pydantic import Field, PrivateAttr, field_validator, model_validator + +from tidy3d._common.components.autograd import get_static +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP, TriangleMeshDataArray +from tidy3d._common.components.data.dataset import TriangleMeshDataset +from tidy3d._common.components.data.validators import validate_no_nans +from tidy3d._common.components.geometry import base +from tidy3d._common.components.viz import add_ax_if_none, equal_aspect +from tidy3d._common.config import config +from tidy3d._common.constants import fp_eps, inf +from tidy3d._common.exceptions import DataError, ValidationError +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Literal, Union + + from trimesh import Trimesh + + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.types.base import Ax, Bound, Coordinate, MatrixReal4x4, Shapely + +AREA_SIZE_THRESHOLD = 1e-36 + + +class TriangleMesh(base.Geometry, ABC): + """Custom surface geometry given by a triangle mesh, as in the STL file format. 
+
+    Example
+    -------
+    >>> vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]])
+    >>> faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]])
+    >>> stl_geom = TriangleMesh.from_vertices_faces(vertices, faces)
+    """
+
+    mesh_dataset: Optional[TriangleMeshDataset] = Field(
+        None,
+        title="Surface mesh data",
+        description="Surface mesh data.",
+    )
+
+    _no_nans_mesh = validate_no_nans("mesh_dataset")
+    _barycentric_samples: dict[int, NDArray] = PrivateAttr(default_factory=dict)
+
+    @verify_packages_import(["trimesh"])
+    @model_validator(mode="before")
+    @classmethod
+    def _validate_trimesh_library(cls, data: dict[str, Any]) -> dict[str, Any]:
+        """Check if the trimesh package is imported as a validator."""
+        return data
+
+    @field_validator("mesh_dataset", mode="before")
+    @classmethod
+    def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset:
+        """Warn if the Dataset fails to load."""
+        if isinstance(val, dict):
+            if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))):
+                log.warning("Loading 'mesh_dataset' without data.")
+                return None
+        return val
+
+    @field_validator("mesh_dataset")
+    @classmethod
+    def _check_mesh(cls, val: TriangleMeshDataset) -> TriangleMeshDataset:
+        """Check that the mesh is valid."""
+        if val is None:
+            return None
+
+        import trimesh
+
+        surface_mesh = val.surface_mesh
+        triangles = get_static(surface_mesh.data)
+        mesh = cls._triangles_to_trimesh(triangles)
+        if not all(np.array(mesh.area_faces) > AREA_SIZE_THRESHOLD):
+            old_tol = trimesh.tol.merge
+            trimesh.tol.merge = np.sqrt(2 * AREA_SIZE_THRESHOLD)
+            new_mesh = mesh.process(validate=True)
+            trimesh.tol.merge = old_tol
+            val = TriangleMesh.from_trimesh(new_mesh).mesh_dataset
+            log.warning(
+                f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. "
+                "Triangles which have one edge of their 2D oriented bounding box shorter than "
+                f"'sqrt(2*{AREA_SIZE_THRESHOLD})' are being automatically removed."
+            )
+            if not all(np.array(new_mesh.area_faces) > AREA_SIZE_THRESHOLD):
+                raise ValidationError(
+                    f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. "
+                    "The automatic removal of these triangles has failed. You can try "
+                    "using numpy-stl's 'from_file' import with 'remove_empty_areas' set "
+                    "to True and a suitable 'AREA_SIZE_THRESHOLD' to remove them."
+                )
+        if not mesh.is_watertight:
+            log.warning(
+                "The provided mesh is not watertight. "
+                "This can lead to incorrect permittivity distributions, "
+                "and can also cause problems with plotting and mesh validation. "
+                "You can try 'TriangleMesh.fill_holes', which attempts to repair the mesh. "
+                "Otherwise, the mesh may require manual repair. You can use a "
+                "'PermittivityMonitor' to check if the permittivity distribution is correct. "
+                "You can see which faces are broken using 'trimesh.repair.broken_faces'."
+            )
+        if not mesh.is_winding_consistent:
+            log.warning(
+                "The provided mesh does not have consistent winding (face orientations). "
+                "This can lead to incorrect permittivity distributions, "
+                "and can also cause problems with plotting and mesh validation. "
+                "You can try 'TriangleMesh.fix_winding', which attempts to repair the mesh. "
+                "Otherwise, the mesh may require manual repair. You can use a "
+                "'PermittivityMonitor' to check if the permittivity distribution is correct."
+            )
+        if not mesh.is_volume:
+            log.warning(
+                "The provided mesh does not represent a valid volume, possibly due to "
+                "incorrect normal vector orientation. 
" + "This can lead to incorrect permittivity distributions, " + "and can also cause problems with plotting and mesh validation. " + "You can try 'TriangleMesh.fix_normals', " + "which attempts to fix the normals to be consistent and outward-facing. " + "Otherwise, the mesh may require manual repair. You can use a " + "'PermittivityMonitor' to check if the permittivity distribution is correct." + ) + + return val + + @verify_packages_import(["trimesh"]) + def fix_winding(self) -> TriangleMesh: + """Try to fix winding in the mesh.""" + import trimesh + + mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) + trimesh.repair.fix_winding(mesh) + return TriangleMesh.from_trimesh(mesh) + + @verify_packages_import(["trimesh"]) + def fill_holes(self) -> TriangleMesh: + """Try to fill holes in the mesh. Can be used to repair non-watertight meshes.""" + import trimesh + + mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) + trimesh.repair.fill_holes(mesh) + return TriangleMesh.from_trimesh(mesh) + + @verify_packages_import(["trimesh"]) + def fix_normals(self) -> TriangleMesh: + """Try to fix normals to be consistent and outward-facing.""" + import trimesh + + mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) + trimesh.repair.fix_normals(mesh) + return TriangleMesh.from_trimesh(mesh) + + @classmethod + @verify_packages_import(["trimesh"]) + def from_stl( + cls, + filename: str, + scale: float = 1.0, + origin: tuple[float, float, float] = (0, 0, 0), + solid_index: Optional[int] = None, + **kwargs: Any, + ) -> Union[TriangleMesh, base.GeometryGroup]: + """Load a :class:`.TriangleMesh` directly from an STL file. + The ``solid_index`` parameter can be used to select a single solid from the file. + Otherwise, if the file contains a single solid, it will be loaded as a + :class:`.TriangleMesh`; if the file contains multiple solids, + they will all be loaded as a :class:`.GeometryGroup`. + + Parameters + ---------- + filename : str + The name of the STL file containing the surface geometry mesh data. + scale : float = 1.0 + The length scale for the loaded geometry (um). + For example, a scale of 10.0 means that a vertex (1, 0, 0) will be placed at + x = 10 um. + origin : tuple[float, float, float] = (0, 0, 0) + The origin of the loaded geometry, in units of ``scale``. + Translates from (0, 0, 0) to this point after applying the scaling. + solid_index : int = None + If set, read a single solid with this index from the file. + + Returns + ------- + Union[:class:`.TriangleMesh`, :class:`.GeometryGroup`] + The geometry or geometry group from the file. + """ + import trimesh + + from tidy3d._common.components.types.third_party import TrimeshType + + def process_single(mesh: TrimeshType) -> TriangleMesh: + """Process a single 'trimesh.Trimesh' using scale and origin.""" + mesh.apply_scale(scale) + mesh.apply_translation(origin) + return cls.from_trimesh(mesh) + + scene = trimesh.load(filename, **kwargs) + meshes = [] + if isinstance(scene, trimesh.Trimesh): + meshes = [scene] + elif isinstance(scene, trimesh.Scene): + meshes = scene.dump() + else: + raise ValidationError( + "Invalid trimesh type in file. Supported types are 'trimesh.Trimesh' " + "and 'trimesh.Scene'." 
+ ) + + if solid_index is None: + if isinstance(scene, trimesh.Trimesh): + return process_single(scene) + if isinstance(scene, trimesh.Scene): + geoms = [process_single(mesh) for mesh in meshes] + return base.GeometryGroup(geometries=geoms) + + if solid_index < len(meshes): + return process_single(meshes[solid_index]) + raise ValidationError("No solid found at 'solid_index' in the stl file.") + + @verify_packages_import(["trimesh"]) + def to_stl( + self, + filename: PathLike, + *, + binary: bool = True, + ) -> None: + """Export this TriangleMesh to an STL file. + + Parameters + ---------- + filename : str + Output STL filename. + binary : bool = True + Whether to write binary STL. Set False for ASCII STL. + """ + triangles = get_static(self.mesh_dataset.surface_mesh.data) + mesh = self._triangles_to_trimesh(triangles) + + file_type = "stl" if binary else "stl_ascii" + mesh.export(file_obj=filename, file_type=file_type) + + @classmethod + @verify_packages_import(["trimesh"]) + def from_trimesh(cls, mesh: Trimesh) -> TriangleMesh: + """Create a :class:`.TriangleMesh` from a ``trimesh.Trimesh`` object. + + Parameters + ---------- + trimesh : ``trimesh.Trimesh`` + The Trimesh object containing the surface geometry mesh data. + + Returns + ------- + :class:`.TriangleMesh` + The custom surface mesh geometry given by the ``trimesh.Trimesh`` provided. + """ + return cls.from_vertices_faces(mesh.vertices, mesh.faces) + + @classmethod + def from_triangles(cls, triangles: NDArray) -> TriangleMesh: + """Create a :class:`.TriangleMesh` from a numpy array + containing the triangles of a surface mesh. + + Parameters + ---------- + triangles : ``np.ndarray`` + A numpy array of shape (N, 3, 3) storing the triangles of the surface mesh. + The first index labels the triangle, the second index labels the vertex + within a given triangle, and the third index is the coordinate (x, y, or z). + + Returns + ------- + :class:`.TriangleMesh` + The custom surface mesh geometry given by the triangles provided. + + """ + triangles = anp.array(triangles) + if len(triangles.shape) != 3 or triangles.shape[1] != 3 or triangles.shape[2] != 3: + raise ValidationError( + f"Provided 'triangles' must be an N x 3 x 3 array, given {triangles.shape}." + ) + num_faces = len(triangles) + coords = { + "face_index": np.arange(num_faces), + "vertex_index": np.arange(3), + "axis": np.arange(3), + } + vertices = TriangleMeshDataArray(triangles, coords=coords) + mesh_dataset = TriangleMeshDataset(surface_mesh=vertices) + return TriangleMesh(mesh_dataset=mesh_dataset) + + @classmethod + @verify_packages_import(["trimesh"]) + def from_vertices_faces(cls, vertices: NDArray, faces: NDArray) -> TriangleMesh: + """Create a :class:`.TriangleMesh` from numpy arrays containing the data + of a surface mesh. The first array contains the vertices, and the second array contains + faces formed from triples of the vertices. + + Parameters + ---------- + vertices: ``np.ndarray`` + A numpy array of shape (N, 3) storing the vertices of the surface mesh. + The first index labels the vertex, and the second index is the coordinate + (x, y, or z). + faces : ``np.ndarray`` + A numpy array of shape (M, 3) storing the indices of the vertices of each face + in the surface mesh. The first index labels the face, and the second index + labels the vertex index within the ``vertices`` array. + + Returns + ------- + :class:`.TriangleMesh` + The custom surface mesh geometry given by the vertices and faces provided. 
+ + """ + import trimesh + + vertices = np.array(vertices) + faces = np.array(faces) + if len(vertices.shape) != 2 or vertices.shape[1] != 3: + raise ValidationError( + f"Provided 'vertices' must be an N x 3 array, given {vertices.shape}." + ) + if len(faces.shape) != 2 or faces.shape[1] != 3: + raise ValidationError(f"Provided 'faces' must be an M x 3 array, given {faces.shape}.") + return cls.from_triangles(trimesh.Trimesh(vertices, faces).triangles) + + @classmethod + @verify_packages_import(["trimesh"]) + def _triangles_to_trimesh( + cls, triangles: NDArray + ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) + """Convert an (N, 3, 3) numpy array of triangles to a ``trimesh.Trimesh``.""" + import trimesh + + # ``triangles`` may contain autograd ``ArrayBox`` entries when differentiating + # geometry parameters. ``trimesh`` expects plain ``float`` values, so strip any + # tracing information before constructing the mesh. + triangles = get_static(anp.array(triangles)) + return trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles)) + + @classmethod + def from_height_grid( + cls, + axis: Ax, + direction: Literal["-", "+"], + base: float, + grid: tuple[np.ndarray, np.ndarray], + height: NDArray, + ) -> TriangleMesh: + """Construct a TriangleMesh object from grid based height information. + + Parameters + ---------- + axis : Ax + Axis of extrusion. + direction : Literal["-", "+"] + Direction of extrusion. + base : float + Coordinate of the base surface along the geometry's axis. + grid : Tuple[np.ndarray, np.ndarray] + Tuple of two one-dimensional arrays representing the sampling grid (XY, YZ, or ZX + corresponding to values of axis) + height : np.ndarray + Height values sampled on the given grid. Can be 1D (raveled) or 2D (matching grid mesh). + + Returns + ------- + TriangleMesh + The resulting TriangleMesh geometry object. + """ + + x_coords = grid[0] + y_coords = grid[1] + + nx = len(x_coords) + ny = len(y_coords) + nt = nx * ny + + x_mesh, y_mesh = np.meshgrid(x_coords, y_coords, indexing="ij") + + sign = 1 + if direction == "-": + sign = -1 + + flat_height = np.ravel(height) + if flat_height.shape[0] != nt: + raise ValueError( + f"Shape of flattened height array {flat_height.shape} does not match " + f"the number of grid points {nt}." 
+ ) + + if np.any(flat_height < 0): + raise ValueError("All height values must be non-negative.") + + max_h = np.max(flat_height) + min_h_clip = fp_eps * max_h + flat_height = np.clip(flat_height, min_h_clip, inf) + + vertices_raw_list = [ + [np.ravel(x_mesh), np.ravel(y_mesh), base + sign * flat_height], # Alpha surface + [np.ravel(x_mesh), np.ravel(y_mesh), base * np.ones(nt)], + ] + + if direction == "-": + vertices_raw_list = vertices_raw_list[::-1] + + vertices = np.hstack(vertices_raw_list).T + vertices = np.roll(vertices, shift=axis - 2, axis=1) + + q0 = (np.arange(nx - 1)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() + q1 = (np.arange(1, nx)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() + q2 = (np.arange(1, nx)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() + q3 = (np.arange(nx - 1)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() + + q0_b = nt + q0 + q1_b = nt + q1 + q2_b = nt + q2 + q3_b = nt + q3 + + top_quads = np.stack((q0, q1, q2, q3), axis=-1) + bottom_quads = np.stack((q0_b, q3_b, q2_b, q1_b), axis=-1) + + s1_q0 = (0 * ny + np.arange(ny - 1)).ravel() + s1_q1 = (0 * ny + np.arange(1, ny)).ravel() + s1_q2 = (nt + 0 * ny + np.arange(1, ny)).ravel() + s1_q3 = (nt + 0 * ny + np.arange(ny - 1)).ravel() + side1_quads = np.stack((s1_q0, s1_q1, s1_q2, s1_q3), axis=-1) + + s2_q0 = ((nx - 1) * ny + np.arange(ny - 1)).ravel() + s2_q1 = (nt + (nx - 1) * ny + np.arange(ny - 1)).ravel() + s2_q2 = (nt + (nx - 1) * ny + np.arange(1, ny)).ravel() + s2_q3 = ((nx - 1) * ny + np.arange(1, ny)).ravel() + side2_quads = np.stack((s2_q0, s2_q1, s2_q2, s2_q3), axis=-1) + + s3_q0 = (np.arange(nx - 1) * ny + 0).ravel() + s3_q1 = (nt + np.arange(nx - 1) * ny + 0).ravel() + s3_q2 = (nt + np.arange(1, nx) * ny + 0).ravel() + s3_q3 = (np.arange(1, nx) * ny + 0).ravel() + side3_quads = np.stack((s3_q0, s3_q1, s3_q2, s3_q3), axis=-1) + + s4_q0 = (np.arange(nx - 1) * ny + ny - 1).ravel() + s4_q1 = (np.arange(1, nx) * ny + ny - 1).ravel() + s4_q2 = (nt + np.arange(1, nx) * ny + ny - 1).ravel() + s4_q3 = (nt + np.arange(nx - 1) * ny + ny - 1).ravel() + side4_quads = np.stack((s4_q0, s4_q1, s4_q2, s4_q3), axis=-1) + + all_quads = np.vstack( + (top_quads, bottom_quads, side1_quads, side2_quads, side3_quads, side4_quads) + ) + + triangles_list = [ + np.stack((all_quads[:, 0], all_quads[:, 1], all_quads[:, 3]), axis=-1), + np.stack((all_quads[:, 3], all_quads[:, 1], all_quads[:, 2]), axis=-1), + ] + tri_faces = np.vstack(triangles_list) + + return cls.from_vertices_faces(vertices=vertices, faces=tri_faces) + + @classmethod + def from_height_function( + cls, + axis: Ax, + direction: Literal["-", "+"], + base: float, + center: tuple[float, float], + size: tuple[float, float], + grid_size: tuple[int, int], + height_func: Callable[[np.ndarray, np.ndarray], np.ndarray], + ) -> TriangleMesh: + """Construct a TriangleMesh object from analytical expression of height function. + The height function should be vectorized to accept 2D meshgrid arrays. + + Parameters + ---------- + axis : Ax + Axis of extrusion. + direction : Literal["-", "+"] + Direction of extrusion. + base : float + Coordinate of the base rectangle along the geometry's axis. + center : Tuple[float, float] + Center of the base rectangle in the plane perpendicular to the extrusion axis + (XY, YZ, or ZX corresponding to values of axis). + size : Tuple[float, float] + Size of the base rectangle in the plane perpendicular to the extrusion axis + (XY, YZ, or ZX corresponding to values of axis). 
+ grid_size : Tuple[int, int] + Number of grid points for discretization of the base rectangle + (XY, YZ, or ZX corresponding to values of axis). + height_func : Callable[[np.ndarray, np.ndarray], np.ndarray] + Vectorized function to compute height values from 2D meshgrid coordinate arrays. + It should take two ndarrays (x_mesh, y_mesh) and return an ndarray of heights. + + Returns + ------- + TriangleMesh + The resulting TriangleMesh geometry object. + """ + x_lin = np.linspace(center[0] - 0.5 * size[0], center[0] + 0.5 * size[0], grid_size[0]) + y_lin = np.linspace(center[1] - 0.5 * size[1], center[1] + 0.5 * size[1], grid_size[1]) + + x_mesh, y_mesh = np.meshgrid(x_lin, y_lin, indexing="ij") + + height_values = height_func(x_mesh, y_mesh) + + if not (isinstance(height_values, np.ndarray) and height_values.shape == x_mesh.shape): + raise ValueError( + f"The 'height_func' must return a NumPy array with shape {x_mesh.shape}, " + f"but got shape {getattr(height_values, 'shape', type(height_values))}." + ) + + return cls.from_height_grid( + axis=axis, + direction=direction, + base=base, + grid=(x_lin, y_lin), + height=height_values, + ) + + @cached_property + @verify_packages_import(["trimesh"]) + def trimesh( + self, + ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) + """A ``trimesh.Trimesh`` object representing the custom surface mesh geometry.""" + return self._triangles_to_trimesh(self.triangles) + + @cached_property + def triangles(self) -> np.ndarray: + """The triangles of the surface mesh as an ``np.ndarray``.""" + if self.mesh_dataset is None: + raise DataError("Can't get triangles as 'mesh_dataset' is None.") + return np.asarray(get_static(self.mesh_dataset.surface_mesh.data)) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + # currently ignores bounds + return self.trimesh.area + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + # currently ignores bounds + return self.trimesh.volume + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. + + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + """ + if self.mesh_dataset is None: + return ((-inf, -inf, -inf), (inf, inf, inf)) + return self.trimesh.bounds + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for TriangleMesh. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
+        """
+        section = self.trimesh.section(plane_origin=origin, plane_normal=normal)
+        if section is None:
+            return []
+        path, _ = section.to_2D(to_2D=to_2D)
+        return path.polygons_full
+
+    def intersections_plane(
+        self,
+        x: Optional[float] = None,
+        y: Optional[float] = None,
+        z: Optional[float] = None,
+        cleanup: bool = True,
+        quad_segs: Optional[int] = None,
+    ) -> list[Shapely]:
+        """Returns list of shapely geometries at plane specified by one non-None value of x,y,z.
+
+        Parameters
+        ----------
+        x : float = None
+            Position of plane in x direction, only one of x,y,z can be specified to define plane.
+        y : float = None
+            Position of plane in y direction, only one of x,y,z can be specified to define plane.
+        z : float = None
+            Position of plane in z direction, only one of x,y,z can be specified to define plane.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. Not used for TriangleMesh.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+
+        if self.mesh_dataset is None:
+            return []
+
+        axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
+
+        origin = self.unpop_axis(position, (0, 0), axis=axis)
+        normal = self.unpop_axis(1, (0, 0), axis=axis)
+
+        mesh = self.trimesh
+
+        try:
+            section = mesh.section(plane_origin=origin, plane_normal=normal)
+
+            if section is None:
+                return []
+
+            # homogeneous transformation matrix to map to xy plane
+            mapping = np.eye(4)
+
+            # translate to origin
+            mapping[3, :3] = -np.array(origin)
+
+            # permute so normal is aligned with z axis
+            # and (y, z), (x, z), resp. (x, y) are aligned with (x, y)
+            identity = np.eye(3)
+            permutation = self.unpop_axis(identity[2], identity[0:2], axis=axis)
+            mapping[:3, :3] = np.array(permutation).T
+
+            section2d, _ = section.to_2D(to_2D=mapping)
+            return list(section2d.polygons_full)
+
+        except ValueError as e:
+            if not mesh.is_watertight:
+                log.warning(
+                    "Unable to compute 'TriangleMesh.intersections_plane' "
+                    "because the mesh was not watertight. Using bounding box instead. "
+                    "This may be overly strict; consider using 'TriangleMesh.fill_holes' "
+                    "to repair the non-watertight mesh."
+                )
+            else:
+                log.warning(
+                    "Unable to compute 'TriangleMesh.intersections_plane'. "
+                    "Using bounding box instead."
+                )
+            log.warning(f"Error encountered: {e}")
+            return self.bounding_box.intersections_plane(x=x, y=y, z=z, cleanup=cleanup)
+
+    def inside(self, x: NDArray, y: NDArray, z: NDArray) -> np.ndarray[bool]:
+        """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array
+        with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the
+        volume of the :class:`Geometry`, and ``False`` otherwise.
+
+        Parameters
+        ----------
+        x : np.ndarray[float]
+            Array of point positions in x direction.
+        y : np.ndarray[float]
+            Array of point positions in y direction.
+        z : np.ndarray[float]
+            Array of point positions in z direction.
+
+        Returns
+        -------
+        np.ndarray[bool]
+            ``True`` for every point that is inside the geometry.
+ """ + + arrays = tuple(map(np.array, (x, y, z))) + self._ensure_equal_shape(*arrays) + arrays_flat = map(np.ravel, arrays) + arrays_stacked = np.stack(tuple(arrays_flat), axis=-1) + inside = self.trimesh.contains(arrays_stacked) + return inside.reshape(arrays[0].shape) + + @equal_aspect + @add_ax_if_none + def plot( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + ax: Ax = None, + **patch_kwargs: Any, + ) -> Ax: + """Plot geometry cross section at single (x,y,z) coordinate. + + Parameters + ---------- + x : float = None + Position of plane in x direction, only one of x,y,z can be specified to define plane. + y : float = None + Position of plane in y direction, only one of x,y,z can be specified to define plane. + z : float = None + Position of plane in z direction, only one of x,y,z can be specified to define plane. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + **patch_kwargs + Optional keyword arguments passed to the matplotlib patch plotting of structure. + For details on accepted values, refer to + `Matplotlib's documentation `_. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + log.warning( + "Plotting a 'TriangleMesh' may give inconsistent results " + "if the mesh is not unionized. We recommend unionizing all meshes before import. " + "A 'PermittivityMonitor' can be used to check that the mesh is loaded correctly." + ) + + return base.Geometry.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute adjoint derivatives for a ``TriangleMesh`` geometry.""" + vjps: AutogradFieldMap = {} + + if not self.mesh_dataset: + raise DataError("Can't compute derivatives without mesh data.") + + valid_paths = {("mesh_dataset", "surface_mesh")} + for path in derivative_info.paths: + if path not in valid_paths: + raise ValueError(f"No derivative defined w.r.t. 
'TriangleMesh' field '{path}'.") + + if ("mesh_dataset", "surface_mesh") not in derivative_info.paths: + return vjps + + triangles = np.asarray(self.triangles, dtype=config.adjoint.gradient_dtype_float) + + # early exit if geometry is completely outside simulation bounds + sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) + mesh_min, mesh_max = map(np.asarray, self.bounds) + if np.any(mesh_max < sim_min) or np.any(mesh_min > sim_max): + log.warning( + "'TriangleMesh' lies completely outside the simulation domain.", + log_once=True, + ) + zeros = np.zeros_like(triangles) + vjps[("mesh_dataset", "surface_mesh")] = zeros + return vjps + + # gather surface samples within the simulation bounds + dx = derivative_info.adaptive_vjp_spacing() + samples = self._collect_surface_samples( + triangles=triangles, + spacing=dx, + sim_min=sim_min, + sim_max=sim_max, + ) + + if samples["points"].shape[0] == 0: + zeros = np.zeros_like(triangles) + vjps[("mesh_dataset", "surface_mesh")] = zeros + return vjps + + interpolators = derivative_info.interpolators + if interpolators is None: + interpolators = derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + g = derivative_info.evaluate_gradient_at_points( + samples["points"], + samples["normals"], + samples["perps1"], + samples["perps2"], + interpolators, + ) + + # accumulate per-vertex contributions using barycentric weights + weights = (samples["weights"] * g).real + normals = samples["normals"] + faces = samples["faces"] + bary = samples["barycentric"] + + contrib_vec = weights[:, None] * normals + + triangle_grads = np.zeros_like(triangles, dtype=config.adjoint.gradient_dtype_float) + for vertex_idx in range(3): + scaled = contrib_vec * bary[:, vertex_idx][:, None] + np.add.at(triangle_grads[:, vertex_idx, :], faces, scaled) + + vjps[("mesh_dataset", "surface_mesh")] = triangle_grads + return vjps + + def _collect_surface_samples( + self, + triangles: NDArray, + spacing: float, + sim_min: NDArray, + sim_max: NDArray, + ) -> dict[str, np.ndarray]: + """Deterministic per-triangle sampling used historically.""" + + dtype = config.adjoint.gradient_dtype_float + tol = config.adjoint.edge_clip_tolerance + + sim_min = np.asarray(sim_min, dtype=dtype) + sim_max = np.asarray(sim_max, dtype=dtype) + + points_list: list[np.ndarray] = [] + normals_list: list[np.ndarray] = [] + perps1_list: list[np.ndarray] = [] + perps2_list: list[np.ndarray] = [] + weights_list: list[np.ndarray] = [] + faces_list: list[np.ndarray] = [] + bary_list: list[np.ndarray] = [] + + spacing = max(float(spacing), np.finfo(float).eps) + triangles_arr = np.asarray(triangles, dtype=dtype) + + sim_extents = sim_max - sim_min + valid_axes = np.abs(sim_extents) > tol + collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol)) + collapsed_axis: Optional[int] = None + plane_value: Optional[float] = None + if collapsed_indices.size == 1: + collapsed_axis = int(collapsed_indices[0]) + plane_value = float(sim_min[collapsed_axis]) + + warned = False + warning_msg = "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." 
+ for face_index, tri in enumerate(triangles_arr): + area, normal = self._triangle_area_and_normal(tri) + if area <= AREA_SIZE_THRESHOLD: + continue + + perps = self._triangle_tangent_basis(tri, normal) + if perps is None: + continue + perp1, perp2 = perps + + if collapsed_axis is not None and plane_value is not None: + samples, outside_bounds = self._collect_surface_samples_2d( + triangle=tri, + face_index=face_index, + normal=normal, + perp1=perp1, + perp2=perp2, + spacing=spacing, + collapsed_axis=collapsed_axis, + plane_value=plane_value, + sim_min=sim_min, + sim_max=sim_max, + valid_axes=valid_axes, + tol=tol, + dtype=dtype, + ) + else: + samples, outside_bounds = self._collect_surface_samples_3d( + triangle=tri, + face_index=face_index, + normal=normal, + perp1=perp1, + perp2=perp2, + area=area, + spacing=spacing, + sim_min=sim_min, + sim_max=sim_max, + valid_axes=valid_axes, + tol=tol, + dtype=dtype, + ) + + if outside_bounds and not warned: + log.warning(warning_msg) + warned = True + + if samples is None: + continue + + points_list.append(samples["points"]) + normals_list.append(samples["normals"]) + perps1_list.append(samples["perps1"]) + perps2_list.append(samples["perps2"]) + weights_list.append(samples["weights"]) + faces_list.append(samples["faces"]) + bary_list.append(samples["barycentric"]) + + if not points_list: + return { + "points": np.zeros((0, 3), dtype=dtype), + "normals": np.zeros((0, 3), dtype=dtype), + "perps1": np.zeros((0, 3), dtype=dtype), + "perps2": np.zeros((0, 3), dtype=dtype), + "weights": np.zeros((0,), dtype=dtype), + "faces": np.zeros((0,), dtype=int), + "barycentric": np.zeros((0, 3), dtype=dtype), + } + + return { + "points": np.concatenate(points_list, axis=0), + "normals": np.concatenate(normals_list, axis=0), + "perps1": np.concatenate(perps1_list, axis=0), + "perps2": np.concatenate(perps2_list, axis=0), + "weights": np.concatenate(weights_list, axis=0), + "faces": np.concatenate(faces_list, axis=0), + "barycentric": np.concatenate(bary_list, axis=0), + } + + def _collect_surface_samples_2d( + self, + triangle: NDArray, + face_index: int, + normal: np.ndarray, + perp1: np.ndarray, + perp2: np.ndarray, + spacing: float, + collapsed_axis: int, + plane_value: float, + sim_min: np.ndarray, + sim_max: np.ndarray, + valid_axes: np.ndarray, + tol: float, + dtype: np.dtype, + ) -> tuple[Optional[dict[str, np.ndarray]], bool]: + """Collect samples when the simulation bounds collapse onto a 2D plane.""" + + segments = self._triangle_plane_segments( + triangle=triangle, axis=collapsed_axis, plane_value=plane_value, tol=tol + ) + + points: list[np.ndarray] = [] + normals: list[np.ndarray] = [] + perps1_list: list[np.ndarray] = [] + perps2_list: list[np.ndarray] = [] + weights: list[np.ndarray] = [] + faces: list[np.ndarray] = [] + barycentric: list[np.ndarray] = [] + outside_bounds = False + + for start, end in segments: + vec = end - start + length = float(np.linalg.norm(vec)) + if length <= tol: + continue + + subdivisions = max(1, int(np.ceil(length / spacing))) + t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions + sample_points = start[None, :] + t_vals[:, None] * vec[None, :] + bary = self._barycentric_coordinates(triangle, sample_points, tol) + + inside_mask = np.ones(sample_points.shape[0], dtype=bool) + if np.any(valid_axes): + min_bound = (sim_min - tol)[valid_axes] + max_bound = (sim_max + tol)[valid_axes] + coords = sample_points[:, valid_axes] + inside_mask = np.all(coords >= min_bound, axis=1) & np.all( + coords <= max_bound, 
axis=1 + ) + + outside_bounds = outside_bounds or (not np.all(inside_mask)) + if not np.any(inside_mask): + continue + + sample_points = sample_points[inside_mask] + bary_inside = bary[inside_mask] + n_inside = sample_points.shape[0] + + normal_tile = np.repeat(normal[None, :], n_inside, axis=0) + perp1_tile = np.repeat(perp1[None, :], n_inside, axis=0) + perp2_tile = np.repeat(perp2[None, :], n_inside, axis=0) + weights_tile = np.full(n_inside, length / subdivisions, dtype=dtype) + faces_tile = np.full(n_inside, face_index, dtype=int) + + points.append(sample_points) + normals.append(normal_tile) + perps1_list.append(perp1_tile) + perps2_list.append(perp2_tile) + weights.append(weights_tile) + faces.append(faces_tile) + barycentric.append(bary_inside) + + if not points: + return None, outside_bounds + + samples = { + "points": np.concatenate(points, axis=0), + "normals": np.concatenate(normals, axis=0), + "perps1": np.concatenate(perps1_list, axis=0), + "perps2": np.concatenate(perps2_list, axis=0), + "weights": np.concatenate(weights, axis=0), + "faces": np.concatenate(faces, axis=0), + "barycentric": np.concatenate(barycentric, axis=0), + } + return samples, outside_bounds + + def _collect_surface_samples_3d( + self, + triangle: NDArray, + face_index: int, + normal: np.ndarray, + perp1: np.ndarray, + perp2: np.ndarray, + area: float, + spacing: float, + sim_min: np.ndarray, + sim_max: np.ndarray, + valid_axes: np.ndarray, + tol: float, + dtype: np.dtype, + ) -> tuple[Optional[dict[str, np.ndarray]], bool]: + """Collect samples when the simulation bounds represent a full 3D region.""" + + edge_lengths = ( + np.linalg.norm(triangle[1] - triangle[0]), + np.linalg.norm(triangle[2] - triangle[1]), + np.linalg.norm(triangle[0] - triangle[2]), + ) + subdivisions = self._subdivision_count(area, spacing, edge_lengths) + barycentric = self._get_barycentric_samples(subdivisions, dtype) + num_samples = barycentric.shape[0] + base_weight = area / num_samples + + sample_points = barycentric @ triangle + + inside_mask = np.all( + sample_points[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1 + ) & np.all(sample_points[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1) + outside_bounds = not np.all(inside_mask) + if not np.any(inside_mask): + return None, outside_bounds + + sample_points = sample_points[inside_mask] + bary_inside = barycentric[inside_mask] + n_samples_inside = sample_points.shape[0] + + normal_tile = np.repeat(normal[None, :], n_samples_inside, axis=0) + perp1_tile = np.repeat(perp1[None, :], n_samples_inside, axis=0) + perp2_tile = np.repeat(perp2[None, :], n_samples_inside, axis=0) + weights_tile = np.full(n_samples_inside, base_weight, dtype=dtype) + faces_tile = np.full(n_samples_inside, face_index, dtype=int) + + samples = { + "points": sample_points, + "normals": normal_tile, + "perps1": perp1_tile, + "perps2": perp2_tile, + "weights": weights_tile, + "faces": faces_tile, + "barycentric": bary_inside, + } + return samples, outside_bounds + + @staticmethod + def _triangle_area_and_normal(triangle: NDArray) -> tuple[float, np.ndarray]: + """Return area and outward normal of the provided triangle.""" + + edge01 = triangle[1] - triangle[0] + edge02 = triangle[2] - triangle[0] + cross = np.cross(edge01, edge02) + norm = np.linalg.norm(cross) + if norm <= 0.0: + return 0.0, np.zeros(3, dtype=triangle.dtype) + normal = (cross / norm).astype(triangle.dtype, copy=False) + area = 0.5 * norm + return area, normal + + @staticmethod + def _triangle_plane_segments( + triangle: 
NDArray, axis: int, plane_value: float, tol: float + ) -> list[tuple[np.ndarray, np.ndarray]]: + """Return intersection segments between a triangle and an axis-aligned plane.""" + + vertices = np.asarray(triangle) + distances = vertices[:, axis] - plane_value + edges = ((0, 1), (1, 2), (2, 0)) + + segments: list[tuple[np.ndarray, np.ndarray]] = [] + points: list[np.ndarray] = [] + + def add_point(pt: np.ndarray) -> None: + for existing in points: + if np.linalg.norm(existing - pt) <= tol: + return + points.append(pt.copy()) + + for i, j in edges: + di = distances[i] + dj = distances[j] + vi = vertices[i] + vj = vertices[j] + + if abs(di) <= tol and abs(dj) <= tol: + segments.append((vi.copy(), vj.copy())) + continue + + if di * dj > 0.0: + continue + + if abs(di) <= tol: + add_point(vi) + continue + + if abs(dj) <= tol: + add_point(vj) + continue + + denom = di - dj + if abs(denom) <= tol: + continue + t = di / denom + if t < 0.0 or t > 1.0: + continue + point = vi + t * (vj - vi) + add_point(point) + + if segments: + return segments + + if len(points) >= 2: + return [(points[0], points[1])] + + return [] + + @staticmethod + def _barycentric_coordinates(triangle: NDArray, points: np.ndarray, tol: float) -> np.ndarray: + """Compute barycentric coordinates of ``points`` with respect to ``triangle``.""" + + pts = np.asarray(points, dtype=triangle.dtype) + v0 = triangle[0] + v1 = triangle[1] + v2 = triangle[2] + v0v1 = v1 - v0 + v0v2 = v2 - v0 + + d00 = float(np.dot(v0v1, v0v1)) + d01 = float(np.dot(v0v1, v0v2)) + d11 = float(np.dot(v0v2, v0v2)) + denom = d00 * d11 - d01 * d01 + if abs(denom) <= tol: + return np.tile( + np.array([1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], dtype=triangle.dtype), (pts.shape[0], 1) + ) + + v0p = pts - v0 + d20 = v0p @ v0v1 + d21 = v0p @ v0v2 + v = (d11 * d20 - d01 * d21) / denom + w = (d00 * d21 - d01 * d20) / denom + u = 1.0 - v - w + bary = np.stack((u, v, w), axis=1) + return bary.astype(triangle.dtype, copy=False) + + @classmethod + def _subdivision_count( + cls, + area: float, + spacing: float, + edge_lengths: Optional[tuple[float, float, float]] = None, + ) -> int: + """Determine the number of subdivisions needed for the given area and spacing.""" + + spacing = max(float(spacing), np.finfo(float).eps) + + target = np.sqrt(max(area, 0.0)) + area_based = np.ceil(np.sqrt(2.0) * target / spacing) + + edge_based = 0.0 + if edge_lengths: + max_edge = max(edge_lengths) + if max_edge > 0.0: + edge_based = np.ceil(max_edge / spacing) + + subdivisions = max(1, int(max(area_based, edge_based))) + return subdivisions + + def _get_barycentric_samples(self, subdivisions: int, dtype: np.dtype) -> np.ndarray: + """Return barycentric sample coordinates for a subdivision level.""" + + cache = self._barycentric_samples + if subdivisions not in cache: + cache[subdivisions] = self._build_barycentric_samples(subdivisions) + return cache[subdivisions].astype(dtype, copy=False) + + @staticmethod + def _build_barycentric_samples(subdivisions: int) -> np.ndarray: + """Construct barycentric sampling points for a given subdivision level.""" + + if subdivisions <= 1: + return np.array([[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]]) + + bary = [] + for i in range(subdivisions): + for j in range(subdivisions - i): + l1 = (i + 1.0 / 3.0) / subdivisions + l2 = (j + 1.0 / 3.0) / subdivisions + l0 = 1.0 - l1 - l2 + bary.append((l0, l1, l2)) + return np.asarray(bary, dtype=float) + + @staticmethod + def subdivide_faces(vertices: NDArray, faces: NDArray) -> tuple[np.ndarray, np.ndarray]: + """Uniformly subdivide 
each triangular face by inserting edge midpoints.""" + + midpoint_cache: dict[tuple[int, int], int] = {} + verts_list = [np.asarray(v, dtype=float) for v in vertices] + + def midpoint(i: int, j: int) -> int: + key = (i, j) if i < j else (j, i) + if key in midpoint_cache: + return midpoint_cache[key] + vm = 0.5 * (verts_list[i] + verts_list[j]) + verts_list.append(vm) + idx = len(verts_list) - 1 + midpoint_cache[key] = idx + return idx + + new_faces: list[tuple[int, int, int]] = [] + for tri in faces: + a = midpoint(tri[0], tri[1]) + b = midpoint(tri[1], tri[2]) + c = midpoint(tri[2], tri[0]) + new_faces.extend(((tri[0], a, c), (tri[1], b, a), (tri[2], c, b), (a, b, c))) + + verts_arr = np.asarray(verts_list, dtype=float) + return verts_arr, np.asarray(new_faces, dtype=int) + + @staticmethod + def _triangle_tangent_basis( + triangle: NDArray, normal: NDArray + ) -> Optional[tuple[np.ndarray, np.ndarray]]: + """Compute orthonormal tangential vectors for a triangle.""" + + tol = np.finfo(triangle.dtype).eps + edges = [triangle[1] - triangle[0], triangle[2] - triangle[0], triangle[2] - triangle[1]] + + edge = None + for candidate in edges: + length = np.linalg.norm(candidate) + if length > tol: + edge = (candidate / length).astype(triangle.dtype, copy=False) + break + + if edge is None: + return None + + perp1 = edge + perp2 = np.cross(normal, perp1) + perp2_norm = np.linalg.norm(perp2) + if perp2_norm <= tol: + return None + perp2 = (perp2 / perp2_norm).astype(triangle.dtype, copy=False) + return perp1, perp2 diff --git a/tidy3d/_common/components/geometry/polyslab.py b/tidy3d/_common/components/geometry/polyslab.py new file mode 100644 index 0000000000..220200f495 --- /dev/null +++ b/tidy3d/_common/components/geometry/polyslab.py @@ -0,0 +1,2774 @@ +"""Geometry extruded from polygonal shapes.""" + +from __future__ import annotations + +import math +from copy import copy +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +import autograd.numpy as np +import shapely +from autograd.tracer import getval +from numpy.polynomial.legendre import leggauss as _leggauss +from pydantic import Field, field_validator, model_validator + +from tidy3d._common.components.autograd import TracedArrayFloat2D, get_static +from tidy3d._common.components.autograd.types import TracedFloat +from tidy3d._common.components.autograd.utils import hasbox +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.geometry import base, triangulation +from tidy3d._common.components.transformation import ReflectionFromPlane, RotationAroundAxis +from tidy3d._common.config import config +from tidy3d._common.constants import LARGE_NUMBER, MICROMETER, fp_eps +from tidy3d._common.exceptions import SetupError, Tidy3dImportError, ValidationError +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from typing import Optional, Union + + from gdstk import Cell + from numpy.typing import NDArray + from pydantic import PositiveFloat + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.types.base import ( + ArrayFloat1D, + ArrayFloat2D, + ArrayLike, + Axis, + Bound, + Coordinate, + MatrixReal4x4, + PlanePosition, + Shapely, + ) + +# sampling polygon along dilation for validating polygon to be +# non self-intersecting during the entire dilation 
process +_N_SAMPLE_POLYGON_INTERSECT = 5 + +_IS_CLOSE_RTOL = np.finfo(float).eps + +# Warn for too many divided polyslabs +_COMPLEX_POLYSLAB_DIVISIONS_WARN = 100 + +# Warn before triangulating large polyslabs due to inefficiency +_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION = 500 + +_MIN_POLYGON_AREA = fp_eps + + +@lru_cache(maxsize=128) +def leggauss(n: int) -> tuple[NDArray, NDArray]: + """Cached version of leggauss with dtype conversions.""" + g, w = _leggauss(n) + return g.astype(config.adjoint.gradient_dtype_float, copy=False), w.astype( + config.adjoint.gradient_dtype_float, copy=False + ) + + +class PolySlab(base.Planar): + """Polygon extruded with optional sidewall angle along axis direction. + + Example + ------- + >>> vertices = np.array([(0,0), (1,0), (1,1)]) + >>> p = PolySlab(vertices=vertices, axis=2, slab_bounds=(-1, 1)) + """ + + slab_bounds: tuple[TracedFloat, TracedFloat] = Field( + title="Slab Bounds", + description="Minimum and maximum positions of the slab along axis dimension.", + units=MICROMETER, + ) + + dilation: float = Field( + 0.0, + title="Dilation", + description="Dilation of the supplied polygon by shifting each edge along its " + "normal outwards direction by a distance; a negative value corresponds to erosion.", + units=MICROMETER, + ) + + vertices: TracedArrayFloat2D = Field( + title="Vertices", + description="List of (d1, d2) defining the 2 dimensional positions of the polygon " + "face vertices at the ``reference_plane``. " + "The index of dimension should be in the ascending order: e.g. if " + "the slab normal axis is ``axis=y``, the coordinate of the vertices will be in (x, z)", + units=MICROMETER, + ) + + @staticmethod + def make_shapely_polygon(vertices: ArrayLike) -> shapely.Polygon: + """Make a shapely polygon from some vertices, first ensures they are untraced.""" + vertices = get_static(vertices) + return shapely.Polygon(vertices) + + @field_validator("slab_bounds") + @classmethod + def slab_bounds_order(cls, val: tuple[float, float]) -> tuple[float, float]: + """Maximum position of the slab should be no smaller than its minimal position.""" + if val[1] < val[0]: + raise SetupError( + "Polyslab.slab_bounds must be specified in the order of " + "minimum and maximum positions of the slab along the axis. " + f"But now the maximum {val[1]} is smaller than the minimum {val[0]}." + ) + return val + + @field_validator("vertices") + @classmethod + def correct_shape(cls, val: ArrayFloat2D) -> ArrayFloat2D: + """Makes sure vertices size is correct. Make sure no intersecting edges.""" + # overall shape of vertices + if val.shape[1] != 2: + raise SetupError( + "PolySlab.vertices must be a 2 dimensional array shaped (N, 2). " + f"Given array with shape of {val.shape}." + ) + # make sure no polygon splitting, islands, 0 area + poly_heal = shapely.make_valid(cls.make_shapely_polygon(val)) + if poly_heal.area < _MIN_POLYGON_AREA: + raise SetupError("The polygon almost collapses to a 1D curve.") + + if not poly_heal.geom_type == "Polygon" or len(poly_heal.interiors) > 0: + raise SetupError( + "Polygon is self-intersecting, resulting in " + "polygon splitting or generation of holes/islands. " + "A general treatment to self-intersecting polygon will be available " + "in future releases." + ) + return val + + @model_validator(mode="after") + def no_complex_self_intersecting_polygon_at_reference_plane(self: Self) -> Self: + """At the reference plane, check if the polygon is self-intersecting. 
+
+        There are two types of self-intersection that can occur during dilation:
+        1) the one that creates holes/islands, splits polygons, or removes everything;
+        2) the one that does not.
+
+        For 1), we issue an error since it is yet to be supported;
+        for 2), we heal the polygon and warn that it has been cleaned up.
+        """
+        val = self.vertices
+        # no need to validate anything here
+        if math.isclose(self.dilation, 0):
+            return self
+
+        val_np = PolySlab._proper_vertices(val)
+        dist = self.dilation
+
+        # 0) fully eroded
+        if dist < 0 and dist < -PolySlab._maximal_erosion(val_np):
+            raise SetupError("Erosion value is too large. The polygon is fully eroded.")
+
+        # no edge events
+        if not PolySlab._edge_events_detection(val_np, dist, ignore_at_dist=False):
+            return self
+
+        poly_offset = PolySlab._shift_vertices(val_np, dist)[0]
+        if PolySlab._area(poly_offset) < fp_eps**2:
+            raise SetupError("Erosion value is too large. The polygon is fully eroded.")
+
+        # edge events
+        poly_offset = shapely.make_valid(self.make_shapely_polygon(poly_offset))
+        # 1) polygon split or create holes/islands
+        if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0:
+            raise SetupError(
+                "Dilation/Erosion value is too large, resulting in "
+                "polygon splitting or generation of holes/islands. "
+                "A general treatment to self-intersecting polygon will be available "
+                "in future releases."
+            )
+
+        # case 2
+        log.warning(
+            "The dilation/erosion value is too large, resulting in a "
+            "self-intersecting polygon. "
+            "The vertices have been modified to make a valid polygon."
+        )
+        return self
+
+    @model_validator(mode="after")
+    def no_self_intersecting_polygon_during_extrusion(self: Self) -> Self:
+        """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that
+        any normal cross section of the PolySlab cannot be self-intersecting. This part checks
+        whether any self-intersection will occur during extrusion with a non-zero sidewall angle.
+
+        There are two types of self-intersection, known as edge events,
+        that can occur during dilation:
+        1) neighboring vertex-vertex crossing. This type of edge event can be treated with
+        ``ComplexPolySlab``, which divides the polyslab into a list of simple polyslabs.
+
+        2) other types of edge events that can create holes/islands or split polygons.
+        To detect these, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections and check for
+        any creation of polygons/holes or any change in the number of vertices.
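+
+        Example
+        -------
+        An illustrative failure mode (values chosen for demonstration only):
+
+        >>> verts = np.array([(0, 0), (10, 0), (10, 1), (0, 1)])
+        >>> _ = PolySlab(vertices=verts, axis=2, slab_bounds=(0, 0.1),
+        ...     sidewall_angle=np.pi / 4)  # thin slab: passes validation
+        >>> PolySlab(vertices=verts, axis=2, slab_bounds=(0, 2.0),
+        ...     sidewall_angle=np.pi / 4)  # doctest: +SKIP
+        Traceback (most recent call last):
+        SetupError: ...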
+ """ + val = self.vertices + + # no need to validate anything here + # sidewall_angle may be autograd-traced; use static value for this check only + if math.isclose(getval(self.sidewall_angle), 0): + return self + + # apply dilation + poly_ref = PolySlab._proper_vertices(val) + if not math.isclose(self.dilation, 0): + poly_ref = PolySlab._shift_vertices(poly_ref, self.dilation)[0] + poly_ref = PolySlab._heal_polygon(poly_ref) + + slab_bounds = get_static(self.slab_bounds) + slab_min, slab_max = slab_bounds + + # first, check vertex-vertex crossing at any point during extrusion + length = slab_bounds[1] - slab_bounds[0] + dist = [-length * np.tan(self.sidewall_angle)] + # reverse the dilation value if it's defined on the top + if self.reference_plane == "top": + dist = [-dist[0]] + # for middle, both direction needs to be examined + elif self.reference_plane == "middle": + dist = [dist[0] / 2, -dist[0] / 2] + + # capture vertex crossing events + max_thick = [] + for dist_val in dist: + max_dist = PolySlab._neighbor_vertices_crossing_detection(poly_ref, dist_val) + + if max_dist is not None: + max_thick.append(max_dist / abs(dist_val) * length) + + if len(max_thick) > 0: + max_thick = min(max_thick) + raise SetupError( + "Sidewall angle or structure thickness is so large that the polygon " + "is self-intersecting during extrusion. " + f"Please either reduce structure thickness to be < {max_thick:.3e}, " + "or use our plugin 'ComplexPolySlab' to divide the complex polyslab " + "into a list of simple polyslabs." + ) + + # vertex-edge crossing event. + for dist_val in dist: + if PolySlab._edge_events_detection(poly_ref, dist_val): + raise SetupError( + "Sidewall angle or structure thickness is too large, " + "resulting in polygon splitting or generation of holes/islands. " + "A general treatment to self-intersecting polygon will be available " + "in future releases." + ) + return self + + @classmethod + def from_gds( + cls, + gds_cell: Cell, + axis: Axis, + slab_bounds: tuple[float, float], + gds_layer: int, + gds_dtype: Optional[int] = None, + gds_scale: PositiveFloat = 1.0, + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", + ) -> list[PolySlab]: + """Import :class:`PolySlab` from a ``gdstk.Cell``. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + axis : int + Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). + slab_bounds: tuple[float, float] + Minimum and maximum positions of the slab along ``axis``. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. + If ``None``, imports all data for this layer into the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of MICROMETER. + For example, if gds file uses nanometers, set ``gds_scale=1e-3``. + Must be positive. + dilation : float = 0.0 + Dilation of the polygon in the base by shifting each edge along its + normal outwards direction by a distance; + a negative value corresponds to erosion. + sidewall_angle : float = 0 + Angle of the sidewall. + ``sidewall_angle=0`` (default) specifies vertical wall, + while ``0 list[ArrayFloat2D]: + """Import :class:`PolySlab` from a ``gdstk.Cell``. + + Parameters + ---------- + gds_cell : gdstk.Cell + ``gdstk.Cell`` containing 2D geometric data. + gds_layer : int + Layer index in the ``gds_cell``. + gds_dtype : int = None + Data-type index in the ``gds_cell``. 
+ If ``None``, imports all data for this layer into the returned list. + gds_scale : float = 1.0 + Length scale used in GDS file in units of MICROMETER. + For example, if gds file uses nanometers, set ``gds_scale=1e-3``. + Must be positive. + + Returns + ------- + list[ArrayFloat2D] + List of :class:`.ArrayFloat2D` + """ + import gdstk + + gds_cell_class_name = str(gds_cell.__class__) + if not isinstance(gds_cell, gdstk.Cell): + if ( + "gdstk" in gds_cell_class_name + ): # Check if it might be a gdstk cell but gdstk is not found + raise Tidy3dImportError( + "Module 'gdstk' not found. It is required to import gdstk cells." + ) + raise ValueError( + f"validate 'gds_cell' of type '{gds_cell_class_name}' " + "does not seem to be associated with 'gdstk' package " + "and therefore can't be loaded by Tidy3D." + ) + + all_vertices = base.Geometry.load_gds_vertices_gdstk( + gds_cell=gds_cell, + gds_layer=gds_layer, + gds_dtype=gds_dtype, + gds_scale=gds_scale, + ) + + # convert vertices into polyslabs + polygons = [PolySlab.make_shapely_polygon(vertices).buffer(0) for vertices in all_vertices] + polys_union = shapely.unary_union(polygons, grid_size=base.POLY_GRID_SIZE) + + if polys_union.geom_type == "Polygon": + all_vertices = [np.array(polys_union.exterior.coords)] + elif polys_union.geom_type == "MultiPolygon": + all_vertices = [np.array(polygon.exterior.coords) for polygon in polys_union.geoms] + return all_vertices + + @property + def center_axis(self) -> float: + """Gets the position of the center of the geometry in the out of plane dimension.""" + zmin, zmax = self.slab_bounds + if np.isneginf(zmin) and np.isposinf(zmax): + return 0.0 + zmin = max(zmin, -LARGE_NUMBER) + zmax = min(zmax, LARGE_NUMBER) + return (zmax + zmin) / 2.0 + + @property + def length_axis(self) -> float: + """Gets the length of the geometry along the out of plane dimension.""" + zmin, zmax = self.slab_bounds + return zmax - zmin + + @property + def finite_length_axis(self) -> float: + """Gets the length of the PolySlab along the out of plane dimension. + First clips the slab bounds to LARGE_NUMBER and then returns difference. + """ + zmin, zmax = self.slab_bounds + zmin = max(zmin, -LARGE_NUMBER) + zmax = min(zmax, LARGE_NUMBER) + return zmax - zmin + + @cached_property + def reference_polygon(self) -> NDArray: + """The polygon at the reference plane. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the reference plane. + """ + vertices = self._proper_vertices(self.vertices) + if math.isclose(self.dilation, 0): + return vertices + offset_vertices = self._shift_vertices(vertices, self.dilation)[0] + return self._heal_polygon(offset_vertices) + + @cached_property + def middle_polygon(self) -> NDArray: + """The polygon at the middle. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the middle. + """ + + dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) + if self.reference_plane == "bottom": + return self._shift_vertices(self.reference_polygon, dist)[0] + if self.reference_plane == "top": + return self._shift_vertices(self.reference_polygon, -dist)[0] + # middle case + return self.reference_polygon + + @cached_property + def base_polygon(self) -> NDArray: + """The polygon at the base, derived from the ``middle_polygon``. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the base. 
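+
+        Example
+        -------
+        A minimal sketch: with ``reference_plane="bottom"`` the base polygon is
+        the reference polygon itself:
+
+        >>> ps = PolySlab(vertices=np.array([(0, 0), (1, 0), (1, 1)]), axis=2,
+        ...     slab_bounds=(0, 1), reference_plane="bottom")
+        >>> bool(np.allclose(ps.base_polygon, ps.reference_polygon))
+        True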
+ """ + if self.reference_plane == "bottom": + return self.reference_polygon + dist = self._extrusion_length_to_offset_distance(-self.finite_length_axis / 2) + return self._shift_vertices(self.middle_polygon, dist)[0] + + @cached_property + def top_polygon(self) -> NDArray: + """The polygon at the top, derived from the ``middle_polygon``. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon at the top. + """ + if self.reference_plane == "top": + return self.reference_polygon + dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) + return self._shift_vertices(self.middle_polygon, dist)[0] + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + if self.slab_bounds[0] != self.slab_bounds[1]: + raise ValidationError("'Medium2D' requires the 'PolySlab' bounds to be equal.") + return self.axis + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> PolySlab: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + if axis != self.axis: + raise ValueError( + f"'_update_from_bounds' may only be applied along axis '{self.axis}', " + f"but was given axis '{axis}'." + ) + return self.updated_copy(slab_bounds=tuple(bounds)) + + @cached_property + def is_ccw(self) -> bool: + """Is this ``PolySlab`` CCW-oriented?""" + return PolySlab._area(self.vertices) > 0 + + def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: + """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array + with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the + volume of the :class:`Geometry`, and ``False`` otherwise. + + Note + ---- + For slanted sidewalls, this function only works if x, y, and z are arrays produced by a + ``meshgrid call``, i.e. 3D arrays and each is constant along one axis. + + Parameters + ---------- + x : np.ndarray[float] + Array of point positions in x direction. + y : np.ndarray[float] + Array of point positions in y direction. + z : np.ndarray[float] + Array of point positions in z direction. + + Returns + ------- + np.ndarray[bool] + ``True`` for every point that is inside the geometry. 
+ """ + self._ensure_equal_shape(x, y, z) + + z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) + + z0 = self.center_axis + dist_z = np.abs(z - z0) + inside_height = dist_z <= (self.finite_length_axis / 2) + + # avoid going into face checking if no points are inside slab bounds + if not np.any(inside_height): + return inside_height + + # check what points are inside polygon cross section (face) + z_local = z - z0 # distance to the middle + dist = -z_local * self._tanq + + if isinstance(x, np.ndarray): + inside_polygon = np.zeros_like(inside_height) + xs_slab = x[inside_height] + ys_slab = y[inside_height] + + # vertical sidewall + if math.isclose(self.sidewall_angle, 0): + face_polygon = shapely.Polygon(self.reference_polygon).buffer(fp_eps) + shapely.prepare(face_polygon) + inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs_slab, y=ys_slab) + inside_polygon[inside_height] = inside_polygon_slab + # slanted sidewall, offsetting vertices at each z + else: + # a helper function for moving axis + def _move_axis(arr: NDArray) -> NDArray: + return np.moveaxis(arr, source=self.axis, destination=-1) + + def _move_axis_reverse(arr: NDArray) -> NDArray: + return np.moveaxis(arr, source=-1, destination=self.axis) + + inside_polygon_axis = _move_axis(inside_polygon) + x_axis = _move_axis(x) + y_axis = _move_axis(y) + + for z_i in range(z.shape[self.axis]): + if not _move_axis(inside_height)[0, 0, z_i]: + continue + vertices_z = self._shift_vertices( + self.middle_polygon, _move_axis(dist)[0, 0, z_i] + )[0] + face_polygon = shapely.Polygon(vertices_z).buffer(fp_eps) + shapely.prepare(face_polygon) + xs = x_axis[:, :, 0].flatten() + ys = y_axis[:, :, 0].flatten() + inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs, y=ys) + inside_polygon_axis[:, :, z_i] = inside_polygon_slab.reshape(x_axis.shape[:2]) + inside_polygon = _move_axis_reverse(inside_polygon_axis) + else: + vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] + face_polygon = self.make_shapely_polygon(vertices_z).buffer(fp_eps) + point = shapely.Point(x, y) + inside_polygon = face_polygon.covers(point) + return inside_height * inside_polygon + + @verify_packages_import(["trimesh"]) + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for PolySlab geometry. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
+ """ + import trimesh + + if len(self.base_polygon) > _MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION: + log.warning( + f"Processing PolySlabs with over {_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION} vertices can be slow.", + log_once=True, + ) + base_triangles = triangulation.triangulate(self.base_polygon) + top_triangles = ( + base_triangles + if math.isclose(self.sidewall_angle, 0) + else triangulation.triangulate(self.top_polygon) + ) + + n = len(self.base_polygon) + faces = ( + [[a, b, c] for c, b, a in base_triangles] + + [[n + a, n + b, n + c] for a, b, c in top_triangles] + + [(i, (i + 1) % n, n + i) for i in range(n)] + + [((i + 1) % n, n + ((i + 1) % n), n + i) for i in range(n)] + ) + + x = np.hstack((self.base_polygon[:, 0], self.top_polygon[:, 0])) + y = np.hstack((self.base_polygon[:, 1], self.top_polygon[:, 1])) + z = np.hstack((np.full(n, self.slab_bounds[0]), np.full(n, self.slab_bounds[1]))) + vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T + mesh = trimesh.Trimesh(vertices, faces) + + section = mesh.section(plane_origin=origin, plane_normal=normal) + if section is None: + return [] + path, _ = section.to_2D(to_2D=to_2D) + return path.polygons_full + + def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list[Shapely]: + """Find shapely geometries intersecting planar geometry with axis normal to slab. + + Parameters + ---------- + z : float + Position along the axis normal to slab. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + if math.isclose(self.sidewall_angle, 0): + return [self.make_shapely_polygon(self.reference_polygon)] + + z0 = self.center_axis + z_local = z - z0 # distance to the middle + dist = -z_local * self._tanq + vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] + return [self.make_shapely_polygon(vertices_z)] + + def _intersections_side(self, position: float, axis: int) -> list[Shapely]: + """Find shapely geometries intersecting planar geometry with axis orthogonal to slab. + + For slanted polyslab, the procedure is as follows, + 1) Find out all z-coordinates where the plane will intersect directly with a vertex. + Denote the coordinates as (z_0, z_1, z_2, ... ) + 2) Find out all polygons that can be formed between z_i and z_{i+1}. There are two + types of polygons: + a) formed by the plane intersecting the edges + b) formed by the plane intersecting the vertices. + For either type, one needs to compute: + i) intersecting position + ii) angle between the plane and the intersecting edge + For a), both are straightforward to compute; while for b), one needs to compute + which edge the plane will slide into. + 3) Looping through z_i, and merge all polygons. The partition by z_i is because once + the plane intersects the vertex, it can intersect with other edges during + the extrusion. + + Parameters + ---------- + position : float + Position along ``axis``. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
+ """ + + # find out all z_i where the plane will intersect the vertex + z0 = self.center_axis + z_base = z0 - self.finite_length_axis / 2 + + axis_ordered = self._order_axis(axis) + height_list = self._find_intersecting_height(position, axis_ordered) + polys = [] + + # looping through z_i to assemble the polygons + height_list = np.append(height_list, self.finite_length_axis) + h_base = 0.0 + for h_top in height_list: + # length within between top and bottom + h_length = h_top - h_base + + # coordinate of each subsection + z_min = z_base + h_base + z_max = np.inf if np.isposinf(h_top) else z_base + h_top + + # for vertical sidewall, no need for complications + if math.isclose(self.sidewall_angle, 0): + ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( + self.reference_polygon, position, axis_ordered + ) + else: + # for slanted sidewall, move up by `fp_eps` in case vertices are degenerate at the base. + dist = -(h_base - self.finite_length_axis / 2 + fp_eps) * self._tanq + vertices = self._shift_vertices(self.middle_polygon, dist)[0] + ints_y, ints_angle = self._find_intersecting_ys_angle_slant( + vertices, position, axis_ordered + ) + + # make polygon with intersections and z axis information + for y_index in range(len(ints_y) // 2): + y_min = ints_y[2 * y_index] + y_max = ints_y[2 * y_index + 1] + minx, miny = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) + maxx, maxy = self._order_by_axis(plane_val=y_max, axis_val=z_max, axis=axis) + + if math.isclose(self.sidewall_angle, 0): + polys.append(self.make_shapely_box(minx, miny, maxx, maxy)) + else: + angle_min = ints_angle[2 * y_index] + angle_max = ints_angle[2 * y_index + 1] + + angle_min = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_min)) + angle_max = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_max)) + + dy_min = h_length * np.tan(angle_min) + dy_max = h_length * np.tan(angle_max) + + x1, y1 = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) + x2, y2 = self._order_by_axis(plane_val=y_max, axis_val=z_min, axis=axis) + x3, y3 = self._order_by_axis( + plane_val=y_max - dy_max, axis_val=z_max, axis=axis + ) + x4, y4 = self._order_by_axis( + plane_val=y_min + dy_min, axis_val=z_max, axis=axis + ) + vertices = ((x1, y1), (x2, y2), (x3, y3), (x4, y4)) + polys.append(self.make_shapely_polygon(vertices).buffer(0)) + # update the base coordinate for the next subsection + h_base = h_top + + # merge touching polygons + polys_union = shapely.unary_union(polys, grid_size=base.POLY_GRID_SIZE) + if polys_union.geom_type == "Polygon": + return [polys_union] + if polys_union.geom_type == "MultiPolygon": + return polys_union.geoms + # in other cases, just return the original unmerged polygons + return polys + + def _find_intersecting_height(self, position: float, axis: int) -> NDArray: + """Found a list of height where the plane will intersect with the vertices; + For vertical sidewall, just return np.array([]). + Assumes axis is handles so this function works on xy plane. + + Parameters + ---------- + position : float + position along axis. + axis : int + Integer index into 'xyz' (0,1,2). + + Returns + np.ndarray + Height (relative to the base) where the plane will intersect with vertices. 
+ """ + if math.isclose(self.sidewall_angle, 0): + return np.array([]) + + # shift rate + dist = 1.0 + shift_x, shift_y = PolySlab._shift_vertices(self.middle_polygon, dist)[2] + shift_val = shift_x if axis == 0 else shift_y + shift_val[np.isclose(shift_val, 0, rtol=_IS_CLOSE_RTOL)] = np.inf # for static vertices + + # distance to the plane in the direction of vertex shifting + distance = self.middle_polygon[:, axis] - position + height = distance / self._tanq / shift_val + self.finite_length_axis / 2 + height = np.unique(height) + # further filter very close ones + is_not_too_close = np.insert((np.diff(height) > fp_eps), 0, True) + height = height[is_not_too_close] + + height = height[height > fp_eps] + height = height[height < self.finite_length_axis - fp_eps] + return height + + def _find_intersecting_ys_angle_vertical( + self, + vertices: NDArray, + position: float, + axis: int, + exclude_on_vertices: bool = False, + ) -> tuple[NDArray, NDArray, NDArray]: + """Finds pairs of forward and backwards vertices where polygon intersects position at axis, + Find intersection point (in y) assuming straight line,and intersecting angle between plane + and edges. (For unslanted polyslab). + Assumes axis is handles so this function works on xy plane. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + position : float + position along axis. + axis : int + Integer index into 'xyz' (0,1,2). + exclude_on_vertices : bool = False + Whether to exclude those intersecting directly with the vertices. + + Returns + ------- + Union[np.ndarray, np.ndarray] + List of intersection points along y direction. + List of angles between plane and edges. + """ + + vertices_axis = vertices + + # flip vertices x,y for axis = y + if axis == 1: + vertices_axis = np.roll(vertices_axis, shift=1, axis=1) + + # get the forward vertices + vertices_f = np.roll(vertices_axis, shift=-1, axis=0) + + # x coordinate of the two sets of vertices + x_vertices_f, _ = vertices_f.T + x_vertices_axis, _ = vertices_axis.T + + # Find which segments intersect: + # 1. Strictly crossing: one endpoint strictly left, one strictly right + # 2. Touching: exactly one endpoint on the plane (xor), which excludes + # edges lying entirely on the plane (both endpoints at position). 
+        orig_on_plane = np.isclose(x_vertices_axis, position, rtol=_IS_CLOSE_RTOL)
+        f_on_plane = np.roll(orig_on_plane, shift=-1)
+        crosses_b = (x_vertices_axis > position) & (x_vertices_f < position)
+        crosses_f = (x_vertices_axis < position) & (x_vertices_f > position)
+
+        if exclude_on_vertices:
+            # exclude vertices at the position
+            not_touching = np.logical_not(orig_on_plane | f_on_plane)
+            intersects_segment = (crosses_b | crosses_f) & not_touching
+        else:
+            single_touch = np.logical_xor(orig_on_plane, f_on_plane)
+            intersects_segment = crosses_b | crosses_f | single_touch
+
+        iverts_b = vertices_axis[intersects_segment]
+        iverts_f = vertices_f[intersects_segment]
+
+        # intersecting positions and angles
+        ints_y = []
+        ints_angle = []
+        for vertices_b_local, vertices_f_local in zip(iverts_b, iverts_f):
+            x1, y1 = vertices_b_local
+            x2, y2 = vertices_f_local
+            slope = (y2 - y1) / (x2 - x1)
+            y = y1 + slope * (position - x1)
+            ints_y.append(y)
+            ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope)))
+
+        ints_y = np.array(ints_y)
+        ints_angle = np.array(ints_angle)
+
+        # Get rid of duplicate intersection points (vertices counted twice if directly on position)
+        ints_y_sort, sort_index = np.unique(ints_y, return_index=True)
+        ints_angle_sort = ints_angle[sort_index]
+
+        # For tangent touches (vertex on plane, both neighbors on same side),
+        # add y-value back to form a degenerate pair
+        if not exclude_on_vertices:
+            n = len(vertices_axis)
+            for idx in np.where(orig_on_plane)[0]:
+                prev_on = orig_on_plane[(idx - 1) % n]
+                next_on = orig_on_plane[(idx + 1) % n]
+                if not prev_on and not next_on:
+                    prev_side = x_vertices_axis[(idx - 1) % n] > position
+                    next_side = x_vertices_axis[(idx + 1) % n] > position
+                    if prev_side == next_side:
+                        ints_y_sort = np.append(ints_y_sort, vertices_axis[idx, 1])
+                        ints_angle_sort = np.append(ints_angle_sort, 0)
+
+        sort_index = np.argsort(ints_y_sort)
+        ints_y_sort = ints_y_sort[sort_index]
+        ints_angle_sort = ints_angle_sort[sort_index]
+        return ints_y_sort, ints_angle_sort
+
+    def _find_intersecting_ys_angle_slant(
+        self, vertices: NDArray, position: float, axis: int
+    ) -> tuple[NDArray, NDArray]:
+        """Find pairs of backward and forward vertices where the polygon intersects
+        ``position`` along ``axis``, the intersection points (in y) assuming straight
+        edges, and the intersecting angles between the plane and the edges.
+        (For slanted polyslabs.)
+        Assumes the axis has already been handled, so this function works in the xy-plane.
+
+        Parameters
+        ----------
+        vertices : np.ndarray
+            Shape (N, 2) defining the polygon vertices in the xy-plane.
+        position : float
+            Position along axis.
+        axis : int
+            Integer index into 'xyz' (0,1,2).
+
+        Returns
+        -------
+        tuple[np.ndarray, np.ndarray]
+            List of intersection points along y direction.
+            List of angles between plane and edges.
+ """ + + vertices_axis = vertices.copy() + # flip vertices x,y for axis = y + if axis == 1: + vertices_axis = np.roll(vertices_axis, shift=1, axis=1) + + # get the forward vertices + vertices_f = np.roll(vertices_axis, shift=-1, axis=0) + # get the backward vertices + vertices_b = np.roll(vertices_axis, shift=1, axis=0) + + ## First part, plane intersects with edges, same as vertical + ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( + vertices, position, axis, exclude_on_vertices=True + ) + ints_y = ints_y.tolist() + ints_angle = ints_angle.tolist() + + ## Second part, plane intersects directly with vertices + # vertices on the intersection + intersects_on = np.isclose(vertices_axis[:, 0], position, rtol=_IS_CLOSE_RTOL) + iverts_on = vertices_axis[intersects_on] + # position of the neighbouring vertices + iverts_b = vertices_b[intersects_on] + iverts_f = vertices_f[intersects_on] + # shift rate + dist = -np.sign(self.sidewall_angle) + shift_x, shift_y = self._shift_vertices(self.middle_polygon, dist)[2] + shift_val = shift_x if axis == 0 else shift_y + shift_val = shift_val[intersects_on] + + for vertices_f_local, vertices_b_local, vertices_on_local, shift_local in zip( + iverts_f, iverts_b, iverts_on, shift_val + ): + x_on, y_on = vertices_on_local + x_f, y_f = vertices_f_local + x_b, y_b = vertices_b_local + + num_added = 0 # keep track the number of added vertices + slope = [] # list of slopes for added vertices + # case 1, shifting velocity is 0 + if np.isclose(shift_local, 0, rtol=_IS_CLOSE_RTOL): + ints_y.append(y_on) + # Slope w.r.t. forward and backward should equal, + # just pick one of them. + slope.append((y_on - y_b) / (x_on - x_b)) + ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0]))) + continue + + # case 2, shifting towards backward direction + if (x_b - position) * shift_local < 0: + ints_y.append(y_on) + slope.append((y_on - y_b) / (x_on - x_b)) + num_added += 1 + + # case 3, shifting towards forward direction + if (x_f - position) * shift_local < 0: + ints_y.append(y_on) + slope.append((y_on - y_f) / (x_on - x_f)) + num_added += 1 + + # in case 2, and case 3, if just num_added = 1 + if num_added == 1: + ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0]))) + # if num_added = 2, the order of the two new vertices needs to handled correctly; + # it should be sorted according to the -slope * moving direction + elif num_added == 2: + dressed_slope = [-s_i * shift_local for s_i in slope] + sort_index = np.argsort(np.array(dressed_slope)) + sorted_slope = np.array(slope)[sort_index] + + ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[0]))) + ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[1]))) + + ints_y = np.array(ints_y) + ints_angle = np.array(ints_angle) + + sort_index = np.argsort(ints_y) + ints_y_sort = ints_y[sort_index] + ints_angle_sort = ints_angle[sort_index] + + return ints_y_sort, ints_angle_sort + + @cached_property + def bounds(self) -> Bound: + """Returns bounding box min and max coordinates. The dilation and slant angle are not + taken into account exactly for speed. Instead, the polygon may be slightly smaller than + the returned bounds, but it should always be fully contained. + + Returns + ------- + tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
+ """ + + # check for the maximum possible contribution from dilation/slant on each side + max_offset = self.dilation + # sidewall_angle may be autograd-traced; unbox for this check + if not math.isclose(getval(self.sidewall_angle), 0): + if self.reference_plane == "bottom": + max_offset += max(0, -self._tanq * self.finite_length_axis) + elif self.reference_plane == "top": + max_offset += max(0, self._tanq * self.finite_length_axis) + elif self.reference_plane == "middle": + max_offset += max(0, abs(self._tanq) * self.finite_length_axis / 2) + + # special care when dilated + if max_offset > 0: + dilated_vertices = self._shift_vertices( + self._proper_vertices(self.vertices), max_offset + )[0] + xmin, ymin = np.amin(dilated_vertices, axis=0) + xmax, ymax = np.amax(dilated_vertices, axis=0) + else: + # otherwise, bounds are directly based on the supplied vertices + xmin, ymin = np.amin(self.vertices, axis=0) + xmax, ymax = np.amax(self.vertices, axis=0) + + # get bounds in (local) z + zmin, zmax = self.slab_bounds + + # rearrange axes + coords_min = self.unpop_axis(zmin, (xmin, ymin), axis=self.axis) + coords_max = self.unpop_axis(zmax, (xmax, ymax), axis=self.axis) + return (tuple(coords_min), tuple(coords_max)) + + def _extrusion_length_to_offset_distance(self, extrusion: float) -> float: + """Convert extrusion length to offset distance.""" + if math.isclose(self.sidewall_angle, 0): + return 0 + return -extrusion * self._tanq + + @staticmethod + def _area(vertices: NDArray) -> float: + """Compute the signed polygon area (positive for CCW orientation). + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + float + Signed polygon area (positive for CCW orientation). + """ + vert_shift = np.roll(vertices, axis=0, shift=-1) + + xs, ys = vertices.T + xs_shift, ys_shift = vert_shift.T + + term1 = xs * ys_shift + term2 = ys * xs_shift + return np.sum(term1 - term2) * 0.5 + + @staticmethod + def _perimeter(vertices: NDArray) -> float: + """Compute the polygon perimeter. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + float + Polygon perimeter. + """ + + vert_shift = np.roll(vertices, axis=0, shift=-1) + squared_diffs = (vertices - vert_shift) ** 2 + + # distance along each edge + dists = np.sqrt(squared_diffs.sum(axis=-1)) + + # total distance along all edges + return np.sum(dists) + + @staticmethod + def _orient(vertices: NDArray) -> NDArray: + """Return a CCW-oriented polygon. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + np.ndarray + Vertices of a CCW-oriented polygon. + """ + return vertices if PolySlab._area(vertices) > 0 else vertices[::-1, :] + + @staticmethod + def _remove_duplicate_vertices(vertices: NDArray) -> NDArray: + """Remove redundant/identical nearest neighbour vertices. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + np.ndarray + Vertices of polygon. 
+ """ + + vertices_f = np.roll(vertices, shift=-1, axis=0) + vertices_diff = np.linalg.norm(vertices - vertices_f, axis=1) + return vertices[~np.isclose(vertices_diff, 0, rtol=_IS_CLOSE_RTOL)] + + @staticmethod + def _proper_vertices(vertices: ArrayFloat2D) -> NDArray: + """convert vertices to np.array format, + removing duplicate neighbouring vertices, + and oriented in CCW direction. + + Returns + ------- + ArrayLike[float, float] + The vertices of the polygon for internal use. + """ + vertices_np = np.array(vertices) + return PolySlab._orient(PolySlab._remove_duplicate_vertices(vertices_np)) + + @staticmethod + def _edge_events_detection( + proper_vertices: NDArray, dilation: float, ignore_at_dist: bool = True + ) -> bool: + """Detect any edge events within the offset distance ``dilation``. + If ``ignore_at_dist=True``, the edge event at ``dist`` is ignored. + """ + + # ignore the event that occurs right at the offset distance + if ignore_at_dist: + dilation -= fp_eps * dilation / abs(dilation) + # number of vertices before offsetting + num_vertices = proper_vertices.shape[0] + + # 0) fully eroded? + if dilation < 0 and dilation < -PolySlab._maximal_erosion(proper_vertices): + return True + + # sample at a few dilation values + dist_list = ( + dilation + * np.linspace( + 0, 1, 1 + _N_SAMPLE_POLYGON_INTERSECT, dtype=config.adjoint.gradient_dtype_float + )[1:] + ) + for dist in dist_list: + # offset: we offset the vertices first, and then use shapely to make it proper + # in principle, one can offset with shapely.buffer directly, but shapely somehow + # automatically removes some vertices even though no change of topology. + poly_offset = PolySlab._shift_vertices(proper_vertices, dist)[0] + # flipped winding number + if PolySlab._area(poly_offset) < fp_eps**2: + return True + + poly_offset = shapely.make_valid(PolySlab.make_shapely_polygon(poly_offset)) + # 1) polygon split or create holes/islands + if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: + return True + + # 2) reduction in vertex number + offset_vertices = PolySlab._proper_vertices(poly_offset.exterior.coords) + if offset_vertices.shape[0] != num_vertices: + return True + + # 3) some split polygon might fully disappear after the offset, but they + # can be detected if we offset back. + poly_offset_back = shapely.make_valid( + PolySlab.make_shapely_polygon(PolySlab._shift_vertices(offset_vertices, -dist)[0]) + ) + if poly_offset_back.geom_type == "MultiPolygon" or len(poly_offset_back.interiors) > 0: + return True + offset_back_vertices = poly_offset_back.exterior.coords + if PolySlab._proper_vertices(offset_back_vertices).shape[0] != num_vertices: + return True + + return False + + @staticmethod + def _neighbor_vertices_crossing_detection( + vertices: NDArray, dist: float, ignore_at_dist: bool = True + ) -> float: + """Detect if neighboring vertices will cross after a dilation distance dist. + + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + dist : float + Distance to offset. + ignore_at_dist : bool, optional + whether to ignore the event right at ``dist`. + + Returns + ------- + float + the absolute value of the maximal allowed dilation + if there are any crossing, otherwise return ``None``. 
+ """ + # ignore the event that occurs right at the offset distance + if ignore_at_dist: + dist -= fp_eps * dist / abs(dist) + + edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices) + length_remaining = edge_length - edge_reduction * dist + + if np.any(length_remaining < 0): + index_oversized = length_remaining < 0 + max_dist = np.min( + np.abs(edge_length[index_oversized] / edge_reduction[index_oversized]) + ) + return max_dist + return None + + @staticmethod + def array_to_vertices(arr_vertices: NDArray) -> ArrayFloat2D: + """Converts a numpy array of vertices to a list of tuples.""" + return list(arr_vertices) + + @staticmethod + def vertices_to_array(vertices_tuple: ArrayFloat2D) -> NDArray: + """Converts a list of tuples (vertices) to a numpy array.""" + return np.array(vertices_tuple) + + @cached_property + def interior_angle(self) -> ArrayFloat1D: + """Angle formed inside polygon by two adjacent edges.""" + + def normalize(v: NDArray) -> NDArray: + return v / np.linalg.norm(v, axis=0) + + vs_orig = self.reference_polygon.T + vs_next = np.roll(vs_orig, axis=-1, shift=-1) + vs_previous = np.roll(vs_orig, axis=-1, shift=+1) + + asp = normalize(vs_next - vs_orig) + asm = normalize(vs_previous - vs_orig) + + cos_angle = asp[0] * asm[0] + asp[1] * asm[1] + sin_angle = asp[0] * asm[1] - asp[1] * asm[0] + + angle = np.arccos(cos_angle) + # concave angles + angle[sin_angle < 0] = 2 * np.pi - angle[sin_angle < 0] + return angle + + @staticmethod + def _shift_vertices( + vertices: NDArray, dist: float + ) -> tuple[NDArray, NDArray, tuple[NDArray, NDArray]]: + """Shifts the vertices of a polygon outward uniformly by distances + `dists`. + + Parameters + ---------- + np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + dist : float + Distance to offset. + + Returns + ------- + tuple[np.ndarray, np.narray,tuple[np.ndarray,np.ndarray]] + New polygon vertices; + and the shift of vertices in direction parallel to the edges. + Shift along x and y direction. + """ + + # 'dist' may be autograd-traced; unbox for the zero-check only + if math.isclose(getval(dist), 0): + return vertices, np.zeros(vertices.shape[0], dtype=float), None + + def rot90(v: tuple[NDArray, NDArray]) -> NDArray: + """90 degree rotation of 2d vector + vx -> vy + vy -> -vx + """ + vxs, vys = v + return np.stack((-vys, vxs), axis=0) + + def cross(u: NDArray, v: NDArray) -> Any: + return u[0] * v[1] - u[1] * v[0] + + def normalize(v: NDArray) -> NDArray: + return v / np.linalg.norm(v, axis=0) + + vs_orig = copy(vertices.T) + vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1) + vs_previous = np.roll(copy(vs_orig), axis=-1, shift=+1) + + asp = normalize(vs_next - vs_orig) + asm = normalize(vs_orig - vs_previous) + + # the vertex shift is decomposed into parallel and perpendicular directions + perpendicular_shift = -dist + det = cross(asm, asp) + + tan_half_angle = np.where( + np.isclose(det, 0, rtol=_IS_CLOSE_RTOL), + 0.0, + cross(asm, rot90(asm - asp)) / (det + np.isclose(det, 0, rtol=_IS_CLOSE_RTOL)), + ) + parallel_shift = dist * tan_half_angle + + shift_total = perpendicular_shift * rot90(asm) + parallel_shift * asm + shift_x = shift_total[0, :] + shift_y = shift_total[1, :] + + return ( + np.swapaxes(vs_orig + shift_total, -2, -1), + parallel_shift, + (shift_x, shift_y), + ) + + @staticmethod + def _edge_length_and_reduction_rate( + vertices: NDArray, + ) -> tuple[NDArray, NDArray]: + """Edge length of reduction rate of each edge with unit offset length. 
+ + Parameters + ---------- + vertices : np.ndarray + Shape (N, 2) defining the polygon vertices in the xy-plane. + + Returns + ------- + tuple[np.ndarray, np.narray] + edge length, and reduction rate + """ + + # edge length + vs_orig = copy(vertices.T) + vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1) + edge_length = np.linalg.norm(vs_next - vs_orig, axis=0) + + # edge length remaining + dist = 1 + parallel_shift = PolySlab._shift_vertices(vertices, dist)[1] + parallel_shift_p = np.roll(copy(parallel_shift), shift=-1) + edge_reduction = -(parallel_shift + parallel_shift_p) + return edge_length, edge_reduction + + @staticmethod + def _maximal_erosion(vertices: NDArray) -> float: + """The erosion value that reduces the length of + all edges to be non-positive. + """ + edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices) + ind_nonzero = abs(edge_reduction) > fp_eps + return -np.min(edge_length[ind_nonzero] / edge_reduction[ind_nonzero]) + + @staticmethod + def _heal_polygon(vertices: NDArray) -> NDArray: + """heal a self-intersecting polygon.""" + shapely_poly = PolySlab.make_shapely_polygon(vertices) + if shapely_poly.is_valid: + return vertices + elif hasbox(vertices): + raise NotImplementedError( + "The dilation caused damage to the polygon. " + "Automatically healing this is currently not supported when " + "differentiating w.r.t. the vertices. Try increasing the spacing " + "between vertices or reduce the amount of dilation." + ) + # perform healing + poly_heal = shapely.make_valid(shapely_poly) + return PolySlab._proper_vertices(list(poly_heal.exterior.coords)) + + def _volume(self, bounds: Bound) -> float: + """Returns object's volume within given bounds.""" + + z_min, z_max = self.slab_bounds + + z_min = max(z_min, bounds[0][self.axis]) + z_max = min(z_max, bounds[1][self.axis]) + + length = z_max - z_min + + top_area = abs(self._area(self.top_polygon)) + base_area = abs(self._area(self.base_polygon)) + + # https://mathworld.wolfram.com/PyramidalFrustum.html + return 1.0 / 3.0 * length * (top_area + base_area + np.sqrt(top_area * base_area)) + + def _surface_area(self, bounds: Bound) -> float: + """Returns object's surface area within given bounds.""" + + area = 0 + + top_polygon = self.top_polygon + base_polygon = self.base_polygon + + top_area = abs(self._area(top_polygon)) + base_area = abs(self._area(base_polygon)) + + top_perim = self._perimeter(top_polygon) + base_perim = self._perimeter(base_polygon) + + z_min, z_max = self.slab_bounds + + if z_min < bounds[0][self.axis]: + z_min = bounds[0][self.axis] + else: + area += base_area + + if z_max > bounds[1][self.axis]: + z_max = bounds[1][self.axis] + else: + area += top_area + + length = z_max - z_min + + area += 0.5 * (top_perim + base_perim) * length + + return area + + """ Autograd code """ + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """ + Return VJPs while handling several edge-cases: + + - If the slab volume does not overlap the simulation, all grads are zero + (one warning is issued). + - Faces that lie completely outside the simulation give zero ``slab_bounds`` + gradients; this includes the +/- inf cases. 
+ - A 2d simulation collapses the surface integral to a line integral + """ + vjps: AutogradFieldMap = {} + + intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) + sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) + + extents = intersect_max - intersect_min + is_2d = np.isclose(extents[self.axis], 0.0) + + # early return if polyslab is not in simulation domain + slab_min, slab_max = self.slab_bounds + if (slab_max < sim_min[self.axis]) or (slab_min > sim_max[self.axis]): + log.warning( + "'PolySlab' lies completely outside the simulation domain.", + log_once=True, + ) + for p in derivative_info.paths: + vjps[p] = np.zeros_like(self.vertices) if p == ("vertices",) else 0.0 + return vjps + + # create interpolators once for ALL derivative computations + # use provided interpolators if available to avoid redundant field data conversions + interpolators = derivative_info.interpolators or derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + for path in derivative_info.paths: + if path == ("vertices",): + vjps[path] = self._compute_derivative_vertices( + derivative_info, sim_min, sim_max, is_2d, interpolators + ) + + elif path == ("sidewall_angle",): + vjps[path] = self._compute_derivative_sidewall_angle( + derivative_info, sim_min, sim_max, is_2d, interpolators + ) + elif path[0] == "slab_bounds": + idx = path[1] + face_coord = self.slab_bounds[idx] + + # face entirely outside -> gradient 0 + if ( + np.isinf(face_coord) + or face_coord < sim_min[self.axis] + or face_coord > sim_max[self.axis] + or is_2d + ): + vjps[path] = 0.0 + continue + + v = self._compute_derivative_slab_bounds(derivative_info, idx, interpolators) + # outward-normal convention + if idx == 0: + v *= -1 + vjps[path] = v + else: + raise ValueError(f"No derivative defined w.r.t. 'PolySlab' field '{path}'.") + + return vjps + + # ---- Shared helpers for VJP surface integrations ---- + def _z_slices( + self, sim_min: NDArray, sim_max: NDArray, is_2d: bool, dx: float + ) -> tuple[NDArray, float, float, float]: + """Compute z-slice centers and spacing within bounds. + + Returns (z_centers, dz, z0, z1). For 2D, returns single center and dz=1. + """ + if is_2d: + midpoint_z = np.maximum( + np.minimum(self.center_axis, sim_max[self.axis]), + sim_min[self.axis], + ) + zc = np.array([midpoint_z], dtype=config.adjoint.gradient_dtype_float) + return zc, 1.0, self.center_axis, self.center_axis + + z0 = max(self.slab_bounds[0], sim_min[self.axis]) + z1 = min(self.slab_bounds[1], sim_max[self.axis]) + if z1 <= z0: + return np.array([], dtype=config.adjoint.gradient_dtype_float), 0.0, z0, z1 + + n_z = max(1, int(np.ceil((z1 - z0) / dx))) + dz = (z1 - z0) / n_z + z_centers = np.linspace( + z0 + dz / 2, z1 - dz / 2, n_z, dtype=config.adjoint.gradient_dtype_float + ) + return z_centers, dz, z0, z1 + + @staticmethod + def _clip_edges_to_bounds_batch( + segment_starts: NDArray, + segment_ends: NDArray, + sim_min: NDArray, + sim_max: NDArray, + *, + _edge_clip_tol: Optional[float] = None, + _dtype: Optional[type] = None, + ) -> tuple[NDArray, NDArray, NDArray]: + """ + Compute parametric bounds for multiple segments clipped to simulation bounds. + + Parameters + ---------- + segment_starts : NDArray + (N, 3) array of segment start coordinates. + segment_ends : NDArray + (N, 3) array of segment end coordinates. + sim_min : NDArray + (3,) array of simulation minimum bounds. + sim_max : NDArray + (3,) array of simulation maximum bounds. 
+ + Returns + ------- + is_within_bounds : NDArray + (N,) boolean array indicating if the segment intersects the bounds. + t_starts : NDArray + (N,) array of parametric start values (0.0 to 1.0). + t_ends : NDArray + (N,) array of parametric end values (0.0 to 1.0). + """ + n = segment_starts.shape[0] + if _edge_clip_tol is None: + _edge_clip_tol = config.adjoint.edge_clip_tolerance + if _dtype is None: + _dtype = config.adjoint.gradient_dtype_float + + t_starts = np.zeros(n, dtype=_dtype) + t_ends = np.ones(n, dtype=_dtype) + is_within_bounds = np.ones(n, dtype=bool) + + for dim in range(3): + start_coords = segment_starts[:, dim] + end_coords = segment_ends[:, dim] + bound_min = sim_min[dim] + bound_max = sim_max[dim] + + # check for parallel edges (faster than isclose) + parallel = np.abs(start_coords - end_coords) < 1e-12 + + # parallel edges: check if outside bounds + outside = parallel & ( + (start_coords < (bound_min - _edge_clip_tol)) + | (start_coords > (bound_max + _edge_clip_tol)) + ) + is_within_bounds &= ~outside + + # non-parallel edges: compute t_min, t_max + not_parallel = ~parallel & is_within_bounds + if np.any(not_parallel): + denom = np.where(not_parallel, end_coords - start_coords, 1.0) # avoid div by zero + t_min = (bound_min - start_coords) / denom + t_max = (bound_max - start_coords) / denom + + # swap if needed + swap = t_min > t_max + t_min_new = np.where(swap, t_max, t_min) + t_max_new = np.where(swap, t_min, t_max) + + # update t_starts and t_ends for valid non-parallel edges + t_starts = np.where(not_parallel, np.maximum(t_starts, t_min_new), t_starts) + t_ends = np.where(not_parallel, np.minimum(t_ends, t_max_new), t_ends) + + # still valid? + is_within_bounds &= ~not_parallel | (t_starts < t_ends) + + is_within_bounds &= t_ends > t_starts + _edge_clip_tol + + return is_within_bounds, t_starts, t_ends + + @staticmethod + def _adaptive_edge_samples( + L: float, + dx: float, + t_start: float = 0.0, + t_end: float = 1.0, + *, + _sample_fraction: Optional[float] = None, + _gauss_order: Optional[int] = None, + _dtype: Optional[type] = None, + ) -> tuple[NDArray, NDArray]: + """ + Compute Gauss samples and weights along [t_start, t_end] with adaptive count. + + Parameters + ---------- + L : float + Physical length of the full edge. + dx : float + Target discretization step size. + t_start : float, optional + Start parameter, by default 0.0. + t_end : float, optional + End parameter, by default 1.0. + + Returns + ------- + tuple[NDArray, NDArray] + Tuple of (samples, weights) for the integration. 
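+
+        Example
+        -------
+        A minimal sketch: the returned weights integrate dt over the clipped
+        interval, so for a full edge they sum to 1:
+
+        >>> s, w = PolySlab._adaptive_edge_samples(L=1.0, dx=0.5)
+        >>> round(float(np.sum(w)), 12)
+        1.0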
+ """ + if _sample_fraction is None: + _sample_fraction = config.adjoint.quadrature_sample_fraction + if _gauss_order is None: + _gauss_order = config.adjoint.gauss_quadrature_order + if _dtype is None: + _dtype = config.adjoint.gradient_dtype_float + + L_eff = L * max(0.0, t_end - t_start) + n_uniform = max(1, int(np.ceil(L_eff / dx))) + n_gauss = n_uniform if n_uniform <= 3 else max(2, int(n_uniform * _sample_fraction)) + if n_gauss <= _gauss_order: + g, w = leggauss(n_gauss) + half_range = 0.5 * (t_end - t_start) + s = (half_range * g + 0.5 * (t_end + t_start)).astype(_dtype, copy=False) + wt = (w * half_range).astype(_dtype, copy=False) + return s, wt + + # composite Gauss with fixed local order + g_loc, w_loc = leggauss(_gauss_order) + segs = n_uniform + edges_t = np.linspace(t_start, t_end, segs + 1, dtype=_dtype) + + # compute all segments at once + a = edges_t[:-1] # (segs,) + b = edges_t[1:] # (segs,) + half_width = 0.5 * (b - a) # (segs,) + mid = 0.5 * (b + a) # (segs,) + + # (segs, 1) * (order,) + (segs, 1) -> (segs, order) + S = (half_width[:, None] * g_loc + mid[:, None]).astype(_dtype, copy=False) + W = (half_width[:, None] * w_loc).astype(_dtype, copy=False) + return S.ravel(), W.ravel() + + def _collect_sidewall_patches( + self, + vertices: NDArray, + next_v: NDArray, + edges: NDArray, + basis: dict, + sim_min: NDArray, + sim_max: NDArray, + is_2d: bool, + dx: float, + ) -> dict: + """ + Collect sidewall patch geometry for batched VJP evaluation. + + Parameters + ---------- + vertices : NDArray + Array of polygon vertices. + next_v : NDArray + Array of next vertices (forming edges). + edges : NDArray + Edge vectors. + basis : dict + Basis vectors dictionary. + sim_min : NDArray + Simulation minimum bounds. + sim_max : NDArray + Simulation maximum bounds. + is_2d : bool + Whether the simulation is 2D. + dx : float + Discretization step. + + Returns + ------- + dict + Dictionary containing: + - centers: (N, 3) array of patch centers. + - normals: (N, 3) array of patch normals. + - perps1: (N, 3) array of first tangent vectors. + - perps2: (N, 3) array of second tangent vectors. + - Ls: (N,) array of edge lengths. + - s_vals: (N,) array of parametric coordinates along the edge. + - s_weights: (N,) array of quadrature weights. + - zc_vals: (N,) array of z-coordinates. + - dz: float, slice thickness. + - edge_indices: (N,) array of original edge indices. 
+ """ + # cache config values to avoid repeated lookups (overhead not insignificant here) + _dtype = config.adjoint.gradient_dtype_float + _edge_clip_tol = config.adjoint.edge_clip_tolerance + _sample_fraction = config.adjoint.quadrature_sample_fraction + _gauss_order = config.adjoint.gauss_quadrature_order + + theta = get_static(self.sidewall_angle) + z_ref = self.reference_axis_pos + + cos_th = np.cos(theta) + cos_th = np.clip(cos_th, 1e-12, 1.0) + tan_th = np.tan(theta) + dprime = -tan_th # dd/dz + + # axis unit vector in 3D + axis_vec = np.zeros(3, dtype=_dtype) + axis_vec[self.axis] = 1.0 + + # densify along axis as |theta| grows, dz scales with cos(theta) + z_centers, dz, z0, z1 = self._z_slices(sim_min, sim_max, is_2d=is_2d, dx=dx * cos_th) + + # early exit: no slices + if (not is_2d) and len(z_centers) == 0: + return { + "centers": np.empty((0, 3), dtype=_dtype), + "normals": np.empty((0, 3), dtype=_dtype), + "perps1": np.empty((0, 3), dtype=_dtype), + "perps2": np.empty((0, 3), dtype=_dtype), + "Ls": np.empty((0,), dtype=_dtype), + "s_vals": np.empty((0,), dtype=_dtype), + "s_weights": np.empty((0,), dtype=_dtype), + "zc_vals": np.empty((0,), dtype=_dtype), + "dz": dz, + "edge_indices": np.empty((0,), dtype=int), + } + + # estimate patches for pre-allocation + n_edges = len(vertices) + estimated_patches = 0 + denom_edge = max(dx * cos_th, 1e-12) + for ei in range(n_edges): + v0, v1 = vertices[ei], next_v[ei] + L = np.linalg.norm(v1 - v0) + if not np.isclose(L, 0.0): + # prealloc guided by actual step; ds_phys scales with cos(theta) + n_samples = max(1, int(np.ceil(L / denom_edge) * 0.6)) + estimated_patches += n_samples * max(1, len(z_centers)) + estimated_patches = int(max(1, estimated_patches) * 1.2) + + # pre-allocate arrays + centers = np.empty((estimated_patches, 3), dtype=_dtype) + normals = np.empty((estimated_patches, 3), dtype=_dtype) + perps1 = np.empty((estimated_patches, 3), dtype=_dtype) + perps2 = np.empty((estimated_patches, 3), dtype=_dtype) + Ls = np.empty((estimated_patches,), dtype=_dtype) + s_vals = np.empty((estimated_patches,), dtype=_dtype) + s_weights = np.empty((estimated_patches,), dtype=_dtype) + zc_vals = np.empty((estimated_patches,), dtype=_dtype) + edge_indices = np.empty((estimated_patches,), dtype=int) + + patch_idx = 0 + + # if the simulation is effectively 2D (one tangential dimension collapsed), + # slightly expand degenerate bounds to enable finite-length clipping of edges. 
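+        # e.g. (illustrative) an xz-planar simulation collapsed along y has
+        # sim_min[1] == sim_max[1]; padding that bound by +/- 0.5 * dx lets the
+        # batched clipper below keep edges in a thin-but-finite window instead
+        # of rejecting them all as out of bounds.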
+ sim_min_eff = np.array(sim_min, dtype=_dtype) + sim_max_eff = np.array(sim_max, dtype=_dtype) + for dim in range(3): + if dim == self.axis: + continue + if np.isclose(sim_max_eff[dim] - sim_min_eff[dim], 0.0): + sim_min_eff[dim] -= 0.5 * dx + sim_max_eff[dim] += 0.5 * dx + + # pre-compute values that are constant across z slices + n_z = len(z_centers) + z_centers_arr = np.asarray(z_centers, dtype=_dtype) + + # slanted local basis (constant across z for non-slanted case) + # for slanted: rz = axis_vec + dprime * n2d, but dprime is constant + for ei, (v0, v1) in enumerate(zip(vertices, next_v)): + edge_vec = v1 - v0 + L = np.sqrt(np.dot(edge_vec, edge_vec)) + if L < 1e-12: + continue + + # constant along edge: unit tangent in 3D (no axis component) + t_edge = basis["perp1"][ei] + + # outward in-plane normal from canonical basis normal + n2d = basis["norm"][ei].copy() + n2d[self.axis] = 0.0 + nrm = np.linalg.norm(n2d) + if not np.isclose(nrm, 0.0): + n2d = n2d / nrm + else: + # fallback to right-handed construction if degenerate + tmp = np.cross(axis_vec, t_edge) + n2d = tmp / (np.linalg.norm(tmp) + 1e-20) + + # compute basis vectors once per edge + rz = axis_vec + dprime * n2d + T1_vec = t_edge + N_vec = np.cross(T1_vec, rz) + N_norm = np.linalg.norm(N_vec) + if not np.isclose(N_norm, 0.0): + N_vec = N_vec / N_norm + + # align N with outward edge normal + if float(np.dot(N_vec, basis["norm"][ei])) < 0.0: + N_vec = -N_vec + + T2_vec = np.cross(N_vec, T1_vec) + T2_norm = np.linalg.norm(T2_vec) + if not np.isclose(T2_norm, 0.0): + T2_vec = T2_vec / T2_norm + + # batch compute offsets for all z slices at once + d_all = -(z_centers_arr - z_ref) * tan_th # (n_z,) + offsets_3d = d_all[:, None] * n2d # (n_z, 3) - faster than np.outer + + # batch compute segment starts and ends for all z slices + segment_starts = np.empty((n_z, 3), dtype=_dtype) + segment_ends = np.empty((n_z, 3), dtype=_dtype) + plane_axes = [i for i in range(3) if i != self.axis] + segment_starts[:, self.axis] = z_centers_arr + segment_starts[:, plane_axes] = v0 + segment_starts += offsets_3d + segment_ends[:, self.axis] = z_centers_arr + segment_ends[:, plane_axes] = v1 + segment_ends += offsets_3d + + # batch clip all z slices at once + is_within_bounds, t_starts, t_ends = self._clip_edges_to_bounds_batch( + segment_starts, + segment_ends, + sim_min_eff, + sim_max_eff, + _edge_clip_tol=_edge_clip_tol, + _dtype=_dtype, + ) + + # process only valid z slices (sampling has variable output sizes) + valid_indices = np.nonzero(is_within_bounds)[0] + if len(valid_indices) == 0: + continue + + # group z slices by unique (t0, t1) pairs to avoid redundant quadrature calculations. + # since most z-slices will have identical clipping bounds (0.0, 1.0), + # we can compute the Gauss samples once and reuse them for almost all slices. + # rounding ensures we get cache hits despite tiny floating point differences. 
+ t0_valid = np.round(t_starts[valid_indices], 10) + t1_valid = np.round(t_ends[valid_indices], 10) + + # simple cache for sampling results: (t0, t1) -> (s_list, w_list) + sample_cache = {} + + # process each z slice + for zi, t0, t1 in zip(valid_indices, t0_valid, t1_valid): + if (t0, t1) not in sample_cache: + sample_cache[(t0, t1)] = self._adaptive_edge_samples( + L, + denom_edge, + t0, + t1, + _sample_fraction=_sample_fraction, + _gauss_order=_gauss_order, + _dtype=_dtype, + ) + + s_list, w_list = sample_cache[(t0, t1)] + if len(s_list) == 0: + continue + + zc = z_centers_arr[zi] + offset3d = offsets_3d[zi] + + pts2d = v0 + s_list[:, None] * edge_vec # faster than np.outer + + # inline unpop_axis_vect for xyz computation + n_pts = len(s_list) + xyz = np.empty((n_pts, 3), dtype=_dtype) + xyz[:, self.axis] = zc + xyz[:, plane_axes] = pts2d + xyz += offset3d + + n_patches = n_pts + new_size_needed = patch_idx + n_patches + if new_size_needed > centers.shape[0]: + # grow arrays by 1.5x to avoid frequent reallocations + new_size = int(new_size_needed * 1.5) + centers.resize((new_size, 3), refcheck=False) + normals.resize((new_size, 3), refcheck=False) + perps1.resize((new_size, 3), refcheck=False) + perps2.resize((new_size, 3), refcheck=False) + Ls.resize((new_size,), refcheck=False) + s_vals.resize((new_size,), refcheck=False) + s_weights.resize((new_size,), refcheck=False) + zc_vals.resize((new_size,), refcheck=False) + edge_indices.resize((new_size,), refcheck=False) + + sl = slice(patch_idx, patch_idx + n_patches) + centers[sl] = xyz + normals[sl] = N_vec + perps1[sl] = T1_vec + perps2[sl] = T2_vec + Ls[sl] = L + s_vals[sl] = s_list + s_weights[sl] = w_list + zc_vals[sl] = zc + edge_indices[sl] = ei + + patch_idx += n_patches + + # trim arrays to final size + centers = centers[:patch_idx] + normals = normals[:patch_idx] + perps1 = perps1[:patch_idx] + perps2 = perps2[:patch_idx] + Ls = Ls[:patch_idx] + s_vals = s_vals[:patch_idx] + s_weights = s_weights[:patch_idx] + zc_vals = zc_vals[:patch_idx] + edge_indices = edge_indices[:patch_idx] + + return { + "centers": centers, + "normals": normals, + "perps1": perps1, + "perps2": perps2, + "Ls": Ls, + "s_vals": s_vals, + "s_weights": s_weights, + "zc_vals": zc_vals, + "dz": dz, + "edge_indices": edge_indices, + } + + def _compute_derivative_sidewall_angle( + self, + derivative_info: DerivativeInfo, + sim_min: NDArray, + sim_max: NDArray, + is_2d: bool = False, + interpolators: Optional[dict] = None, + ) -> float: + """VJP for dJ/dtheta where theta = sidewall_angle. + + Use dJ/dtheta = integral_S g(x) * V_n(x; theta) * dA, with g(x) from + `evaluate_gradient_at_points`. For a ruled sidewall built by + offsetting the mid-plane polygon by d(z) = -(z - z_ref) * tan(theta), + the normal velocity is V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) + and the area element is dA = (dz/cos(theta)) * d_ell. + Therefore each patch weight is w = L * dz * (-(z - z_ref)) / cos(theta)^2. 
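+
+        As a sanity check, in the vertical-wall limit ``theta -> 0`` this reduces to
+        ``V_n = -(z - z_ref)`` and ``dA = dz * d_ell``, so ``w = -L * dz * (z - z_ref)``:
+        increasing the angle moves the wall inward above the reference plane and
+        outward below it.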
+ """ + if interpolators is None: + interpolators = derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + # 2D sim => no dependence on theta (z_local=0) + if is_2d: + return 0.0 + + vertices, next_v, edges, basis = self._edge_geometry_arrays() + + dx = derivative_info.adaptive_vjp_spacing() + + # collect patches once + patch = self._collect_sidewall_patches( + vertices=vertices, + next_v=next_v, + edges=edges, + basis=basis, + sim_min=sim_min, + sim_max=sim_max, + is_2d=False, + dx=dx, + ) + if patch["centers"].shape[0] == 0: + return 0.0 + + # Shape-derivative factors: + # - Offset: d(z) = -(z - z_ref) * tan(theta) + # - Tangential rate: dd/dtheta = -(z - z_ref) * sec(theta)^2 + # - Normal velocity (project to surface normal): V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) + # - Area element of slanted strip: dA = (dz/cos(theta)) * d_ell + # => Patch weight scales as: V_n * dA = -(z - z_ref) * dz * d_ell / cos(theta)^2 + cos_theta = np.cos(get_static(self.sidewall_angle)) + inv_cos2 = 1.0 / (cos_theta * cos_theta) + z_ref = self.reference_axis_pos + + g = derivative_info.evaluate_gradient_at_points( + patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators + ) + z_local = patch["zc_vals"] - z_ref + weights = patch["Ls"] * patch["s_weights"] * patch["dz"] * (-z_local) * inv_cos2 + return float(np.real(np.sum(g * weights))) + + def _compute_derivative_slab_bounds( + self, derivative_info: DerivativeInfo, min_max_index: int, interpolators: dict + ) -> TracedArrayFloat2D: + """VJP for one of the two horizontal faces of a ``PolySlab``. + + The face is discretized into a Cartesian grid of small planar patches. + The adjoint surface integral is evaluated on every retained patch; the + resulting derivative is split equally between the two vertices that bound + the edge segment. + """ + # rmin/rmax over the geometry and simulation box + if np.isclose(self.slab_bounds[1] - self.slab_bounds[0], 0.0): + log.warning( + "Computing slab face derivatives for flat structures is not fully supported and " + "may give zero for the derivative. Try using a structure with a small, but nonzero " + "thickness for slab bound derivatives." 
+ ) + rmin, rmax = derivative_info.bounds_intersect + _, (r1_min, r2_min) = self.pop_axis(rmin, axis=self.axis) + _, (r1_max, r2_max) = self.pop_axis(rmax, axis=self.axis) + ax_val = self.slab_bounds[min_max_index] + + # planar grid resolution, clipped to polygon bounding box + face_verts = self.base_polygon if min_max_index == 0 else self.top_polygon + face_poly = shapely.Polygon(face_verts).buffer(fp_eps) + + # limit the patch grid to the face that lives inside the simulation box + poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds + r1_min = max(r1_min, poly_min_r1) + r1_max = min(r1_max, poly_max_r1) + r2_min = max(r2_min, poly_min_r2) + r2_max = min(r2_max, poly_max_r2) + + # intersect the polygon with the simulation bounds + face_poly = face_poly.intersection(shapely.box(r1_min, r2_min, r1_max, r2_max)) + + if (r1_max <= r1_min) and (r2_max <= r2_min): + # the polygon does not intersect the current simulation slice + return 0.0 + + # re-compute the extents after clipping to the polygon bounds + extents = np.array([r1_max - r1_min, r2_max - r2_min]) + + # choose surface or line integral + integral_fun = ( + self.compute_derivative_slab_bounds_line + if np.isclose(extents, 0).any() + else self.compute_derivative_slab_bounds_surface + ) + return integral_fun( + derivative_info, + extents, + r1_min, + r1_max, + r2_min, + r2_max, + ax_val, + face_poly, + min_max_index, + interpolators, + ) + + def compute_derivative_slab_bounds_line( + self, + derivative_info: DerivativeInfo, + extents: NDArray, + r1_min: float, + r1_max: float, + r2_min: float, + r2_max: float, + ax_val: float, + face_poly: shapely.Polygon, + min_max_index: int, + interpolators: dict, + ) -> float: + """Handle degenerate line cross-section case""" + line_dim = 1 if np.isclose(extents[0], 0) else 0 + + poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds + if line_dim == 0: # x varies, y is fixed + l_min = max(r1_min, poly_min_r1) + l_max = min(r1_max, poly_max_r1) + else: # y varies, x is fixed + l_min = max(r2_min, poly_min_r2) + l_max = min(r2_max, poly_max_r2) + + length = l_max - l_min + if np.isclose(length, 0): + return 0.0 + + dx = derivative_info.adaptive_vjp_spacing() + n_seg = max(1, int(np.ceil(length / dx))) + coords = np.linspace( + l_min, l_max, 2 * n_seg + 1, dtype=config.adjoint.gradient_dtype_float + )[1::2] + + # build XY coordinates and in-plane direction vectors + if line_dim == 0: + xy = np.column_stack((coords, np.full_like(coords, r2_min))) + dir_vec_plane = np.column_stack((np.ones_like(coords), np.zeros_like(coords))) + else: + xy = np.column_stack((np.full_like(coords, r1_min), coords)) + dir_vec_plane = np.column_stack((np.zeros_like(coords), np.ones_like(coords))) + + inside = shapely.contains_xy(face_poly, xy[:, 0], xy[:, 1]) + if not inside.any(): + return 0.0 + + xy = xy[inside] + dir_vec_plane = dir_vec_plane[inside] + n_pts = len(xy) + + centers_xyz = self.unpop_axis_vect(np.full(n_pts, ax_val), xy) + areas = np.full(n_pts, length / n_seg) # patch length + + normals_xyz = self.unpop_axis_vect( + np.full( + n_pts, -1 if min_max_index == 0 else 1, dtype=config.adjoint.gradient_dtype_float + ), + np.zeros_like(xy, dtype=config.adjoint.gradient_dtype_float), + ) + perps1_xyz = self.unpop_axis_vect(np.zeros(n_pts), dir_vec_plane) + perps2_xyz = self.unpop_axis_vect(np.zeros(n_pts), np.zeros_like(dir_vec_plane)) + + vjps = derivative_info.evaluate_gradient_at_points( + centers_xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators + ) + return 
np.real(np.sum(vjps * areas)).item() + + def compute_derivative_slab_bounds_surface( + self, + derivative_info: DerivativeInfo, + extents: NDArray, + r1_min: float, + r1_max: float, + r2_min: float, + r2_max: float, + ax_val: float, + face_poly: shapely.Polygon, + min_max_index: int, + interpolators: dict, + ) -> float: + """2d surface integral on a Gauss quadrature grid""" + dx = derivative_info.adaptive_vjp_spacing() + + # uniform grid would use n1 x n2 points + n1_uniform, n2_uniform = np.maximum(1, np.ceil(extents / dx).astype(int)) + + # use ~1/2 Gauss points in each direction for similar accuracy + n1 = max(2, n1_uniform // 2) + n2 = max(2, n2_uniform // 2) + + g1, w1 = leggauss(n1) + g2, w2 = leggauss(n2) + + coords1 = (0.5 * (r1_max - r1_min) * g1 + 0.5 * (r1_max + r1_min)).astype( + config.adjoint.gradient_dtype_float, copy=False + ) + coords2 = (0.5 * (r2_max - r2_min) * g2 + 0.5 * (r2_max + r2_min)).astype( + config.adjoint.gradient_dtype_float, copy=False + ) + + r1_grid, r2_grid = np.meshgrid(coords1, coords2, indexing="ij") + r1_flat = r1_grid.flatten() + r2_flat = r2_grid.flatten() + pts = np.column_stack((r1_flat, r2_flat)) + + in_face = shapely.contains_xy(face_poly, pts[:, 0], pts[:, 1]) + if not in_face.any(): + return 0.0 + + xyz = self.unpop_axis_vect( + np.full(in_face.sum(), ax_val, dtype=config.adjoint.gradient_dtype_float), pts[in_face] + ) + n_patches = xyz.shape[0] + + normals_xyz = self.unpop_axis_vect( + np.full( + n_patches, + -1 if min_max_index == 0 else 1, + dtype=config.adjoint.gradient_dtype_float, + ), + np.zeros((n_patches, 2), dtype=config.adjoint.gradient_dtype_float), + ) + perps1_xyz = self.unpop_axis_vect( + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.column_stack( + ( + np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + ) + ), + ) + perps2_xyz = self.unpop_axis_vect( + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.column_stack( + ( + np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), + np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), + ) + ), + ) + + w1_grid, w2_grid = np.meshgrid(w1, w2, indexing="ij") + weights_flat = (w1_grid * w2_grid).flatten()[in_face] + jacobian = 0.25 * (r1_max - r1_min) * (r2_max - r2_min) + + # area-based correction for non-rectangular domains (e.g. concave polygon) + # for constant integrand, integral should equal polygon area + sum_weights = np.sum(weights_flat) + if sum_weights > 0: + area_correction = face_poly.area / (sum_weights * jacobian) + weights_flat = weights_flat * area_correction + + vjps = derivative_info.evaluate_gradient_at_points( + xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators + ) + return np.real(np.sum(vjps * weights_flat * jacobian)).item() + + def _compute_derivative_vertices( + self, + derivative_info: DerivativeInfo, + sim_min: NDArray, + sim_max: NDArray, + is_2d: bool = False, + interpolators: Optional[dict] = None, + ) -> NDArray: + """VJP for the vertices of a ``PolySlab``. + + Uses shared sidewall patch collection and batched field evaluation. 
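+
+        Each patch at parametric position ``s`` along edge ``i`` splits its weighted
+        contribution linearly between the edge endpoints: a factor ``1 - s`` is
+        accumulated on vertex ``i`` and ``s`` on vertex ``i + 1``, both along the
+        edge's outward in-plane normal.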
+ """ + vertices, next_v, edges, basis = self._edge_geometry_arrays() + dx = derivative_info.adaptive_vjp_spacing() + + # collect patches once + patch = self._collect_sidewall_patches( + vertices=vertices, + next_v=next_v, + edges=edges, + basis=basis, + sim_min=sim_min, + sim_max=sim_max, + is_2d=is_2d, + dx=dx, + ) + + # early return if no patches + if patch["centers"].shape[0] == 0: + return np.zeros_like(vertices) + + dz = patch["dz"] + dz_surf = 1.0 if is_2d else dz / np.cos(self.sidewall_angle) + + # use provided interpolators or create them if not provided + if interpolators is None: + interpolators = derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + + # evaluate integrand + g = derivative_info.evaluate_gradient_at_points( + patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators + ) + + # compute area-based weights and weighted vjps + areas = patch["Ls"] * patch["s_weights"] * dz_surf + patch_vjps = (g * areas).real + + # distribute to vertices using vectorized accumulation + normals_2d = np.delete(basis["norm"], self.axis, axis=1) + edge_idx = patch["edge_indices"] + s = patch["s_vals"] + w0 = (1.0 - s) * patch_vjps + w1 = s * patch_vjps + edge_norms = normals_2d[edge_idx] + + # Accumulate per-vertex contributions using bincount (O(N_patches)) + num_vertices = vertices.shape[0] + contrib0 = w0[:, None] * edge_norms # (n_patches, 2) + contrib1 = w1[:, None] * edge_norms # (n_patches, 2) + + idx0 = edge_idx + idx1 = (edge_idx + 1) % num_vertices + + v0x = np.bincount(idx0, weights=contrib0[:, 0], minlength=num_vertices) + v0y = np.bincount(idx0, weights=contrib0[:, 1], minlength=num_vertices) + v1x = np.bincount(idx1, weights=contrib1[:, 0], minlength=num_vertices) + v1y = np.bincount(idx1, weights=contrib1[:, 1], minlength=num_vertices) + + vjp_per_vertex = np.stack((v0x + v1x, v0y + v1y), axis=1) + return vjp_per_vertex + + def _edge_geometry_arrays( + self, dtype: np.dtype = config.adjoint.gradient_dtype_float + ) -> tuple[NDArray, NDArray, NDArray, dict[str, NDArray]]: + """Return (vertices, next_v, edges, basis) arrays for sidewall edge geometry.""" + vertices = np.asarray(self.vertices, dtype=dtype) + next_v = np.roll(vertices, -1, axis=0) + edges = next_v - vertices + basis = self.edge_basis_vectors(edges) + return vertices, next_v, edges, basis + + def edge_basis_vectors( + self, + edges: NDArray, # (N, 2) + ) -> dict[str, NDArray]: # (N, 3) + """Normalized basis vectors for ``normal`` direction, ``slab`` tangent direction and ``edge``.""" + + # ensure edges have consistent dtype + edges = edges.astype(config.adjoint.gradient_dtype_float, copy=False) + + num_vertices, _ = edges.shape + zeros = np.zeros(num_vertices, dtype=config.adjoint.gradient_dtype_float) + ones = np.ones(num_vertices, dtype=config.adjoint.gradient_dtype_float) + + # normalized vectors along edges + edges_norm_in_plane = self.normalize_vect(edges) + edges_norm_xyz = self.unpop_axis_vect(zeros, edges_norm_in_plane) + + # normalized vectors from base of edges to tops of edges + cos_angle = np.cos(self.sidewall_angle) + sin_angle = np.sin(self.sidewall_angle) + slabs_axis_components = cos_angle * ones + + # create axis_norm as array directly to avoid tuple->array conversion in np.cross + axis_norm = np.zeros(3, dtype=config.adjoint.gradient_dtype_float) + axis_norm[self.axis] = 1.0 + slab_normal_xyz = -sin_angle * np.cross(edges_norm_xyz, axis_norm) + _, slab_normal_in_plane = self.pop_axis_vect(slab_normal_xyz) + slabs_norm_xyz = 
self.unpop_axis_vect(slabs_axis_components, slab_normal_in_plane)
+
+        # normalized vectors pointing in normal direction of edge
+        # cross yields inward normal when the extrusion axis is y, so negate once for axis==1
+        sign = (-1 if self.axis == 1 else 1) * (-1 if not self.is_ccw else 1)
+        normals_norm_xyz = sign * np.cross(edges_norm_xyz, slabs_norm_xyz)
+
+        return {
+            "norm": normals_norm_xyz,
+            "perp1": edges_norm_xyz,
+            "perp2": slabs_norm_xyz,
+        }
+
+    def unpop_axis_vect(self, ax_coords: NDArray, plane_coords: NDArray) -> NDArray:
+        """Combine coordinate along axis with coordinates on the plane tangent to the axis.
+
+        ax_coords.shape == [N]
+        plane_coords.shape == [N, 2]
+        return shape == [N, 3]
+        """
+        n_pts = ax_coords.shape[0]
+        arr_xyz = np.zeros((n_pts, 3), dtype=ax_coords.dtype)
+
+        plane_axes = [i for i in range(3) if i != self.axis]
+
+        arr_xyz[:, self.axis] = ax_coords
+        arr_xyz[:, plane_axes] = plane_coords
+
+        return arr_xyz
+
+    def pop_axis_vect(self, coord: NDArray) -> tuple[NDArray, tuple[NDArray, NDArray]]:
+        """Split coordinates into the coordinate along the axis and the coordinates on the plane tangent to the axis.
+
+        coord.shape == [N, 3]
+        return shape == ([N], [N, 2])
+        """
+
+        arr_axis, arrs_plane = self.pop_axis(coord.T, axis=self.axis)
+        arrs_plane = np.array(arrs_plane).T
+
+        return arr_axis, arrs_plane
+
+    @staticmethod
+    def normalize_vect(arr: NDArray) -> NDArray:
+        """Normalize an array shaped (N, d) along the ``d`` axis and return the (N, d) result."""
+        norm = np.linalg.norm(arr, axis=-1, keepdims=True)
+        norm = np.where(norm == 0, 1, norm)
+        return arr / norm
+
+    def translated(self, x: float, y: float, z: float) -> PolySlab:
+        """Return a translated copy of this geometry.
+
+        Parameters
+        ----------
+        x : float
+            Translation along x.
+        y : float
+            Translation along y.
+        z : float
+            Translation along z.
+
+        Returns
+        -------
+        :class:`PolySlab`
+            Translated copy of this ``PolySlab``.
+        """
+
+        t_normal, t_plane = self.pop_axis((x, y, z), axis=self.axis)
+        translated_vertices = np.array(self.vertices) + np.array(t_plane)[None, :]
+        translated_slab_bounds = (self.slab_bounds[0] + t_normal, self.slab_bounds[1] + t_normal)
+        return self.updated_copy(vertices=translated_vertices, slab_bounds=translated_slab_bounds)
+
+    def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> PolySlab:
+        """Return a scaled copy of this geometry.
+
+        Parameters
+        ----------
+        x : float = 1.0
+            Scaling factor along x.
+        y : float = 1.0
+            Scaling factor along y.
+        z : float = 1.0
+            Scaling factor along z.
+
+        Returns
+        -------
+        :class:`PolySlab`
+            Scaled copy of this ``PolySlab``.
+        """
+        scale_normal, scale_in_plane = self.pop_axis((x, y, z), axis=self.axis)
+        scaled_vertices = self.vertices * np.array(scale_in_plane)
+        scaled_slab_bounds = tuple(scale_normal * bound for bound in self.slab_bounds)
+        return self.updated_copy(vertices=scaled_vertices, slab_bounds=scaled_slab_bounds)
+
+    def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> PolySlab:
+        """Return a rotated copy of this geometry.
+
+        Parameters
+        ----------
+        angle : float
+            Rotation angle (in radians).
+        axis : Union[int, tuple[float, float, float]]
+            Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector.
+
+        Returns
+        -------
+        :class:`PolySlab`
+            Rotated copy of this ``PolySlab``.
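+
+        Example
+        -------
+        Illustrative values; a rotation about the slab axis keeps the result a ``PolySlab``:
+
+        >>> ps = PolySlab(vertices=[(0, 0), (2, 0), (2, 1)], axis=2, slab_bounds=(0, 1))
+        >>> ps_rot = ps.rotated(angle=0.5, axis=2)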
+        """
+        _, plane_axs = self.pop_axis([0, 1, 2], self.axis)
+        if (isinstance(axis, int) and axis == self.axis) or (
+            isinstance(axis, tuple) and all(axis[ax] == 0 for ax in plane_axs)
+        ):
+            verts_3d = np.zeros((3, self.vertices.shape[0]))
+            verts_3d[plane_axs[0], :] = self.vertices[:, 0]
+            verts_3d[plane_axs[1], :] = self.vertices[:, 1]
+            rotation = RotationAroundAxis(angle=angle, axis=axis)
+            rotated_vertices = rotation.rotate_vector(verts_3d)
+            rotated_vertices = rotated_vertices[plane_axs, :].T
+            return self.updated_copy(vertices=rotated_vertices)
+
+        return super().rotated(angle=angle, axis=axis)
+
+    def reflected(self, normal: Coordinate) -> PolySlab:
+        """Return a reflected copy of this geometry.
+
+        Parameters
+        ----------
+        normal : tuple[float, float, float]
+            The 3D normal vector of the plane of reflection. The plane is assumed
+            to pass through the origin (0,0,0).
+
+        Returns
+        -------
+        :class:`PolySlab`
+            Reflected copy of this ``PolySlab``.
+        """
+        if math.isclose(normal[self.axis], 0):
+            _, plane_axs = self.pop_axis((0, 1, 2), self.axis)
+            verts_3d = np.zeros((3, self.vertices.shape[0]))
+            verts_3d[plane_axs[0], :] = self.vertices[:, 0]
+            verts_3d[plane_axs[1], :] = self.vertices[:, 1]
+            reflection = ReflectionFromPlane(normal=normal)
+            reflected_vertices = reflection.reflect_vector(verts_3d)
+            reflected_vertices = reflected_vertices[plane_axs, :].T
+            return self.updated_copy(vertices=reflected_vertices)
+
+        return super().reflected(normal=normal)
+
+
+class ComplexPolySlabBase(PolySlab):
+    """Interface for dividing a complex polyslab where a self-intersecting polygon can
+    occur during extrusion. This class should not be used directly. Use instead
+    :class:`plugins.polyslab.ComplexPolySlab`."""
+
+    @model_validator(mode="after")
+    def no_self_intersecting_polygon_during_extrusion(self: Self) -> Self:
+        """Turn off the validation for this class."""
+        return self
+
+    @classmethod
+    def from_gds(
+        cls,
+        gds_cell: Cell,
+        axis: Axis,
+        slab_bounds: tuple[float, float],
+        gds_layer: int,
+        gds_dtype: Optional[int] = None,
+        gds_scale: PositiveFloat = 1.0,
+        dilation: float = 0.0,
+        sidewall_angle: float = 0,
+        reference_plane: PlanePosition = "middle",
+    ) -> list[PolySlab]:
+        """Import :class:`.PolySlab` from a ``gdstk.Cell``.
+
+        Parameters
+        ----------
+        gds_cell : gdstk.Cell
+            ``gdstk.Cell`` containing 2D geometric data.
+        axis : int
+            Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z).
+        slab_bounds : tuple[float, float]
+            Minimum and maximum positions of the slab along ``axis``.
+        gds_layer : int
+            Layer index in the ``gds_cell``.
+        gds_dtype : int = None
+            Data-type index in the ``gds_cell``.
+            If ``None``, imports all data for this layer into the returned list.
+        gds_scale : float = 1.0
+            Length scale used in GDS file in units of MICROMETER.
+            For example, if gds file uses nanometers, set ``gds_scale=1e-3``.
+            Must be positive.
+        dilation : float = 0.0
+            Dilation of the polygon in the base by shifting each edge along its
+            normal outwards direction by a distance;
+            a negative value corresponds to erosion.
+        sidewall_angle : float = 0
+            Angle of the sidewall.
+            ``sidewall_angle=0`` (default) specifies a vertical wall, while
+            ``0 < sidewall_angle < np.pi / 2`` makes the base larger than the top.
+        """
+
+    @property
+    def geometry_group(self) -> base.GeometryGroup:
+        """Divide a complex polyslab into a list of simple polyslabs, which
+        are assembled into a :class:`.GeometryGroup`.
+
+        Returns
+        -------
+        :class:`.GeometryGroup`
+            GeometryGroup for a list of simple polyslabs divided from the complex
+            polyslab.
+ """ + return base.GeometryGroup(geometries=self.sub_polyslabs) + + @property + def sub_polyslabs(self) -> list[PolySlab]: + """Divide a complex polyslab into a list of simple polyslabs. + Only neighboring vertex-vertex crossing events are treated in this + version. + + Returns + ------- + list[PolySlab] + A list of simple polyslabs. + """ + sub_polyslab_list = [] + num_division_count = 0 + # initialize sub-polyslab parameters + sub_polyslab_dict = self.model_dump(exclude={"type"}).copy() + if math.isclose(self.sidewall_angle, 0): + return [PolySlab.model_validate(sub_polyslab_dict)] + + sub_polyslab_dict.update({"dilation": 0}) # dilation accounted in setup + # initialize offset distance + offset_distance = 0 + + for dist_val in self._dilation_length: + dist_now = 0.0 + vertices_now = self.reference_polygon + + # constructing sub-polyslabs until reaching the base/top + while not math.isclose(dist_now, dist_val): + # bounds for sub-polyslabs assuming no self-intersection + slab_bounds = [ + self._dilation_value_at_reference_to_coord(dist_now), + self._dilation_value_at_reference_to_coord(dist_val), + ] + # 1) find out any vertices touching events between the current + # position to the base/top + max_dist = PolySlab._neighbor_vertices_crossing_detection( + vertices_now, dist_val - dist_now + ) + + # vertices touching events captured, update bounds for sub-polyslab + if max_dist is not None: + # max_dist doesn't have sign, so construct signed offset distance + offset_distance = max_dist * dist_val / abs(dist_val) + slab_bounds[1] = self._dilation_value_at_reference_to_coord( + dist_now + offset_distance + ) + + # 2) construct sub-polyslab + slab_bounds.sort() # for reference_plane=top/bottom, bounds need to be ordered + # direction of marching + reference_plane = "bottom" if dist_val / self._tanq < 0 else "top" + sub_polyslab_dict.update( + { + "slab_bounds": tuple(slab_bounds), + "vertices": vertices_now, + "reference_plane": reference_plane, + } + ) + sub_polyslab_list.append(PolySlab.model_validate(sub_polyslab_dict)) + + # Now Step 3 + if max_dist is None: + break + dist_now += offset_distance + # new polygon vertices where collapsing vertices are removed but keep one + vertices_now = PolySlab._shift_vertices(vertices_now, offset_distance)[0] + vertices_now = PolySlab._remove_duplicate_vertices(vertices_now) + # all vertices collapse + if len(vertices_now) < 3: + break + # polygon collapse into 1D + if self.make_shapely_polygon(vertices_now).buffer(0).area < fp_eps: + break + vertices_now = PolySlab._orient(vertices_now) + num_division_count += 1 + + if num_division_count > _COMPLEX_POLYSLAB_DIVISIONS_WARN: + log.warning( + f"Too many self-intersecting events: the polyslab has been divided into " + f"{num_division_count} polyslabs; more than {_COMPLEX_POLYSLAB_DIVISIONS_WARN} may " + f"slow down the simulation." 
+ ) + + return sub_polyslab_list + + @property + def _dilation_length(self) -> list[float]: + """dilation length from reference plane to the top/bottom of the polyslab.""" + + # for "bottom", only needs to compute the offset length to the top + dist = [self._extrusion_length_to_offset_distance(self.finite_length_axis)] + # reverse the dilation value if the reference plane is on the top + if self.reference_plane == "top": + dist = [-dist[0]] + # for middle, both directions + elif self.reference_plane == "middle": + dist = [dist[0] / 2, -dist[0] / 2] + return dist + + def _dilation_value_at_reference_to_coord(self, dilation: float) -> float: + """Compute the coordinate based on the dilation value to the reference plane.""" + + z_coord = -dilation / self._tanq + self.slab_bounds[0] + if self.reference_plane == "middle": + return z_coord + self.finite_length_axis / 2 + if self.reference_plane == "top": + return z_coord + self.finite_length_axis + # bottom case + return z_coord + + def intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + cleanup: bool = True, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. Not used for PolySlab. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
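+
+        The sections of all sub-polyslabs are merged into a single geometry via
+        ``shapely.unary_union``.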
+ """ + return [ + shapely.unary_union( + [ + base.Geometry.evaluate_inf_shape(shape) + for polyslab in self.sub_polyslabs + for shape in polyslab.intersections_tilted_plane( + normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs + ) + ] + ) + ] diff --git a/tidy3d/_common/components/geometry/primitives.py b/tidy3d/_common/components/geometry/primitives.py new file mode 100644 index 0000000000..741d606661 --- /dev/null +++ b/tidy3d/_common/components/geometry/primitives.py @@ -0,0 +1,1016 @@ +"""Concrete primitive geometrical objects.""" + +from __future__ import annotations + +from math import isclose +from typing import TYPE_CHECKING, Any + +import autograd.numpy as anp +import numpy as np +import shapely +from pydantic import Field, PrivateAttr, model_validator + +from tidy3d._common.components.autograd import TracedSize1D +from tidy3d._common.components.base import cached_property +from tidy3d._common.components.geometry import base +from tidy3d._common.components.geometry.mesh import TriangleMesh +from tidy3d._common.components.geometry.polyslab import PolySlab +from tidy3d._common.config import config +from tidy3d._common.constants import LARGE_NUMBER, MICROMETER +from tidy3d._common.exceptions import SetupError, ValidationError +from tidy3d._common.log import log +from tidy3d._common.packaging import verify_packages_import + +if TYPE_CHECKING: + from typing import Optional + + from shapely.geometry.base import BaseGeometry + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd import AutogradFieldMap + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.types.base import Axis, Bound, Coordinate, MatrixReal4x4, Shapely + +# for sampling conical frustum in visualization +_N_SAMPLE_CURVE_SHAPELY = 40 + +# for shapely circular shapes discretization in visualization +_N_SHAPELY_QUAD_SEGS_VISUALIZATION = 200 + +# Default number of points to discretize polyslab in `Cylinder.to_polyslab()` +_N_PTS_CYLINDER_POLYSLAB = 51 +_MAX_ICOSPHERE_SUBDIVISIONS = 7 # this would have 164K vertices and 328K faces +_DEFAULT_EDGE_FRACTION = 0.25 + + +def _base_icosahedron() -> tuple[np.ndarray, np.ndarray]: + """Return vertices and faces of a unit icosahedron.""" + + phi = (1.0 + np.sqrt(5.0)) / 2.0 + vertices = np.array( + [ + (-1, phi, 0), + (1, phi, 0), + (-1, -phi, 0), + (1, -phi, 0), + (0, -1, phi), + (0, 1, phi), + (0, -1, -phi), + (0, 1, -phi), + (phi, 0, -1), + (phi, 0, 1), + (-phi, 0, -1), + (-phi, 0, 1), + ], + dtype=float, + ) + vertices /= np.linalg.norm(vertices, axis=1)[:, None] + faces = np.array( + [ + (0, 11, 5), + (0, 5, 1), + (0, 1, 7), + (0, 7, 10), + (0, 10, 11), + (1, 5, 9), + (5, 11, 4), + (11, 10, 2), + (10, 7, 6), + (7, 1, 8), + (3, 9, 4), + (3, 4, 2), + (3, 2, 6), + (3, 6, 8), + (3, 8, 9), + (4, 9, 5), + (2, 4, 11), + (6, 2, 10), + (8, 6, 7), + (9, 8, 1), + ], + dtype=int, + ) + return vertices, faces + + +_ICOSAHEDRON_VERTS, _ICOSAHEDRON_FACES = _base_icosahedron() + + +class Sphere(base.Centered, base.Circular): + """Spherical geometry. 
+
+    Example
+    -------
+    >>> b = Sphere(center=(1,2,3), radius=2)
+    """
+
+    _icosphere_cache: dict[int, tuple[np.ndarray, float]] = PrivateAttr(default_factory=dict)
+
+    def inside(
+        self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float]
+    ) -> np.ndarray[bool]:
+        """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array
+        with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the
+        volume of the :class:`Geometry`, and ``False`` otherwise.
+
+        Parameters
+        ----------
+        x : np.ndarray[float]
+            Array of point positions in x direction.
+        y : np.ndarray[float]
+            Array of point positions in y direction.
+        z : np.ndarray[float]
+            Array of point positions in z direction.
+
+        Returns
+        -------
+        np.ndarray[bool]
+            ``True`` for every point that is inside the geometry.
+        """
+        self._ensure_equal_shape(x, y, z)
+        x0, y0, z0 = self.center
+        dist_x = np.abs(x - x0)
+        dist_y = np.abs(y - y0)
+        dist_z = np.abs(z - z0)
+        return (dist_x**2 + dist_y**2 + dist_z**2) <= (self.radius**2)
+
+    def intersections_tilted_plane(
+        self,
+        normal: Coordinate,
+        origin: Coordinate,
+        to_2D: MatrixReal4x4,
+        cleanup: bool = True,
+        quad_segs: Optional[int] = None,
+    ) -> list[Shapely]:
+        """Return a list of shapely geometries at the plane specified by normal and origin.
+
+        Parameters
+        ----------
+        normal : Coordinate
+            Vector defining the normal direction to the plane.
+        origin : Coordinate
+            Vector defining the plane origin.
+        to_2D : MatrixReal4x4
+            Transformation matrix to apply to resulting shapes.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+        if quad_segs is None:
+            quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION
+
+        normal = np.array(normal)
+        unit_normal = normal / (np.sum(normal**2) ** 0.5)
+        projection = np.dot(np.array(origin) - np.array(self.center), unit_normal)
+        if abs(projection) >= self.radius:
+            return []
+
+        radius = (self.radius**2 - projection**2) ** 0.5
+        center = np.array(self.center) + projection * unit_normal
+
+        v = np.zeros(3)
+        v[np.argmin(np.abs(unit_normal))] = 1
+        u = np.cross(unit_normal, v)
+        u /= np.sum(u**2) ** 0.5
+        v = np.cross(unit_normal, u)
+
+        angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1)[:-1]
+        circ = center + np.outer(np.cos(angles), radius * u) + np.outer(np.sin(angles), radius * v)
+        vertices = np.dot(np.hstack((circ, np.ones((angles.size, 1)))), to_2D.T)
+        return [shapely.Polygon(vertices[:, :2])]
+
+    def intersections_plane(
+        self,
+        x: Optional[float] = None,
+        y: Optional[float] = None,
+        z: Optional[float] = None,
+        cleanup: bool = True,
+        quad_segs: Optional[int] = None,
+    ) -> list[BaseGeometry]:
+        """Returns shapely geometry at plane specified by one non None value of x,y,z.
+
+        Parameters
+        ----------
+        x : float = None
+            Position of plane in x direction, only one of x,y,z can be specified to define plane.
+        y : float = None
+            Position of plane in y direction, only one of x,y,z can be specified to define plane.
+        z : float = None
+            Position of plane in z direction, only one of x,y,z can be specified to define plane.
+        cleanup : bool = True
+            If True, removes extremely small features from each polygon's boundary.
+        quad_segs : Optional[int] = None
+            Number of segments used to discretize circular shapes. If ``None``, uses
+            ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization.
+
+        Returns
+        -------
+        list[shapely.geometry.base.BaseGeometry]
+            List of 2D shapes that intersect plane.
+            For more details refer to
+            `Shapely's Documentation `_.
+        """
+        if quad_segs is None:
+            quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION
+
+        axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z)
+        if not self.intersects_axis_position(axis, position):
+            return []
+        z0, (x0, y0) = self.pop_axis(self.center, axis=axis)
+        intersect_dist = self._intersect_dist(position, z0)
+        if not intersect_dist:
+            return []
+        return [shapely.Point(x0, y0).buffer(0.5 * intersect_dist, quad_segs=quad_segs)]
+
+    @cached_property
+    def bounds(self) -> Bound:
+        """Returns bounding box min and max coordinates.
+
+        Returns
+        -------
+        Tuple[float, float, float], Tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+        """
+        coord_min = tuple(c - self.radius for c in self.center)
+        coord_max = tuple(c + self.radius for c in self.center)
+        return (coord_min, coord_max)
+
+    def _volume(self, bounds: Bound) -> float:
+        """Returns object's volume within given bounds."""
+
+        volume = 4.0 / 3.0 * np.pi * self.radius**3
+
+        # a very loose upper bound on how much of sphere is in bounds
+        for axis in range(3):
+            if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:
+                volume *= 0.5
+
+        return volume
+
+    def _surface_area(self, bounds: Bound) -> float:
+        """Returns object's surface area within given bounds."""
+
+        area = 4.0 * np.pi * self.radius**2
+
+        # a very loose upper bound on how much of sphere is in bounds
+        for axis in range(3):
+            if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:
+                area *= 0.5
+
+        return area
+
+    @classmethod
+    def unit_sphere_triangles(
+        cls,
+        *,
+        target_edge_length: Optional[float] = None,
+        subdivisions: Optional[int] = None,
+    ) -> np.ndarray:
+        """Return unit sphere triangles discretized via an icosphere."""
+
+        unit_tris = UNIT_SPHERE._unit_sphere_triangles(
+            target_edge_length=target_edge_length,
+            subdivisions=subdivisions,
+            copy_result=True,
+        )
+        return unit_tris
+
+    def _unit_sphere_triangles(
+        self,
+        *,
+        target_edge_length: Optional[float] = None,
+        subdivisions: Optional[int] = None,
+        copy_result: bool = True,
+    ) -> np.ndarray:
+        """Return cached unit-sphere triangles with optional copying."""
+        if target_edge_length is not None and subdivisions is not None:
+            raise ValueError("Specify either target_edge_length OR subdivisions, not both.")
+
+        if subdivisions is None:
+            subdivisions = self._subdivisions_for_edge(target_edge_length)
+
+        triangles, _ = self._icosphere_data(subdivisions)
+        return np.array(triangles, copy=copy_result)
+
+    def _subdivisions_for_edge(self, target_edge_length: Optional[float]) -> int:
+        if target_edge_length is None or target_edge_length <= 0.0:
+            return 0
+
+        for subdiv in range(_MAX_ICOSPHERE_SUBDIVISIONS + 1):
+            _, max_edge = self._icosphere_data(subdiv)
+            if max_edge <= target_edge_length:
+                return subdiv
+
+        log.warning(
+            f"Requested sphere mesh edge length {target_edge_length:.3e} μm requires more than "
+            f"{_MAX_ICOSPHERE_SUBDIVISIONS} subdivisions. "
+            "Clipping to the finest available mesh.",
+            log_once=True,
+        )
+        return _MAX_ICOSPHERE_SUBDIVISIONS
+
+    def _icosphere_data(self, subdivisions: int) -> tuple[np.ndarray, float]:
+        cache = self._icosphere_cache
+        if subdivisions in cache:
+            return cache[subdivisions]
+
+        vertices = np.asarray(_ICOSAHEDRON_VERTS, dtype=float)
+        faces = np.asarray(_ICOSAHEDRON_FACES, dtype=int)
+        if subdivisions > 0:
+            vertices = vertices.copy()
+            faces = faces.copy()
+            for _ in range(subdivisions):
+                vertices, faces = TriangleMesh.subdivide_faces(vertices, faces)
+
+        norms = np.linalg.norm(vertices, axis=1, keepdims=True)
+        norms = np.where(norms == 0.0, 1.0, norms)
+        vertices = vertices / norms
+
+        triangles = vertices[faces]
+        max_edge = self._max_edge_length(triangles)
+        cache[subdivisions] = (triangles, max_edge)
+        return triangles, max_edge
+
+    @staticmethod
+    def _max_edge_length(triangles: np.ndarray) -> float:
+        v = triangles
+        edges = np.stack(
+            [
+                v[:, 1] - v[:, 0],
+                v[:, 2] - v[:, 1],
+                v[:, 0] - v[:, 2],
+            ],
+            axis=1,
+        )
+        return float(np.linalg.norm(edges, axis=2).max())
+
+
+UNIT_SPHERE = Sphere(center=(0.0, 0.0, 0.0), radius=1.0)
+
+
+class Cylinder(base.Centered, base.Circular, base.Planar):
+    """Cylindrical geometry with optional sidewall angle along axis
+    direction. When ``sidewall_angle`` is nonzero, the shape is a
+    conical frustum or a cone.
+
+    Example
+    -------
+    >>> c = Cylinder(center=(1,2,3), radius=2, length=5, axis=2)
+
+    See Also
+    --------
+
+    **Notebooks**
+
+    * `THz integrated demultiplexer/filter based on a ring resonator <../../../notebooks/THzDemultiplexerFilter.html>`_
+    * `Photonic crystal waveguide polarization filter <../../../notebooks/PhotonicCrystalWaveguidePolarizationFilter.html>`_
+    """
+
+    # Provide more explanation on where the radius is defined
+    radius: TracedSize1D = Field(
+        title="Radius",
+        description="Radius of geometry at the ``reference_plane``.",
+        units=MICROMETER,
+    )
+
+    length: TracedSize1D = Field(
+        title="Length",
+        description="Defines thickness of cylinder along axis dimension.",
+        units=MICROMETER,
+    )
+
+    @model_validator(mode="after")
+    def _only_middle_for_infinite_length_slanted_cylinder(self: Self) -> Self:
+        """For a slanted cylinder of infinite length, ``reference_plane`` can only
+        be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0.
+        """
+        if isclose(self.sidewall_angle, 0) or not np.isinf(self.length):
+            return self
+        if self.reference_plane != "middle":
+            raise SetupError(
+                "For a slanted cylinder of infinite length, "
+                "defining the reference_plane other than 'middle' "
+                "leads to undefined cylinder behaviors near 'center'."
+            )
+        return self
+
+    def to_polyslab(
+        self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB, **kwargs: Any
+    ) -> PolySlab:
+        """Convert instance of ``Cylinder`` into a discretized version using ``PolySlab``.
+
+        Parameters
+        ----------
+        num_pts_circumference : int = 51
+            Number of points in the circumference of the discretized polyslab.
+        **kwargs:
+            Extra keyword arguments passed to ``PolySlab()``, such as ``dilation``.
+
+        Returns
+        -------
+        PolySlab
+            Extruded polygon representing a discretized version of the cylinder.
+ """ + + center_axis = self.center_axis + length_axis = self.length_axis + slab_bounds = (center_axis - length_axis / 2.0, center_axis + length_axis / 2.0) + + if num_pts_circumference < 3: + raise ValueError("'PolySlab' from 'Cylinder' must have 3 or more radius points.") + + _, (x0, y0) = self.pop_axis(self.center, axis=self.axis) + + xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) + + xs = x0 + self.radius * xs_ + ys = y0 + self.radius * ys_ + + vertices = anp.stack((xs, ys), axis=-1) + + return PolySlab( + vertices=vertices, + axis=self.axis, + slab_bounds=slab_bounds, + sidewall_angle=self.sidewall_angle, + reference_plane=self.reference_plane, + **kwargs, + ) + + def _points_unit_circle( + self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB + ) -> np.ndarray: + """Set of x and y points for the unit circle when discretizing cylinder as a polyslab.""" + angles = np.linspace(0, 2 * np.pi, num_pts_circumference, endpoint=False) + xs = np.cos(angles) + ys = np.sin(angles) + return np.stack((xs, ys), axis=0) + + def _discretization_wavelength(self, derivative_info: DerivativeInfo) -> float: + """Choose a reference wavelength for discretizing the cylinder into a `PolySlab`.""" + wvl0_min = derivative_info.wavelength_min + wvl_mat = wvl0_min / np.max([1.0, np.max(np.sqrt(abs(derivative_info.eps_in)))]) + + grid_cfg = config.adjoint + + min_wvl_mat = grid_cfg.min_wvl_fraction * wvl0_min + if wvl_mat < min_wvl_mat: + log.warning( + f"The minimum wavelength inside the cylinder material is {wvl_mat:.3e} μm, which would " + f"create a large number of discretization points for computing the gradient. " + f"To prevent performance degradation, the discretization wavelength has " + f"been clipped to {min_wvl_mat:.3e} μm.", + log_once=True, + ) + wvl_mat = max(wvl_mat, min_wvl_mat) + + return wvl_mat + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + # compute circumference discretization + wvl_mat = self._discretization_wavelength(derivative_info=derivative_info) + + circumference = 2 * np.pi * self.radius + wvls_in_circumference = circumference / wvl_mat + + grid_cfg = config.adjoint + num_pts_circumference = int(np.ceil(grid_cfg.points_per_wavelength * wvls_in_circumference)) + num_pts_circumference = max(3, num_pts_circumference) + + # construct equivalent polyslab and compute the derivatives + polyslab = self.to_polyslab(num_pts_circumference=num_pts_circumference) + + # build PolySlab derivative paths based on requested Cylinder paths + ps_paths = set() + for path in derivative_info.paths: + if path == ("length",): + ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) + elif path == ("radius",): + ps_paths.add(("vertices",)) + elif "center" in path: + _, center_index = path + _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) + if center_index in (index_x, index_y): + ps_paths.add(("vertices",)) + else: + ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) + elif path == ("sidewall_angle",): + ps_paths.add(("sidewall_angle",)) + + # pass interpolators to PolySlab if available to avoid redundant conversions + update_kwargs = { + "paths": list(ps_paths), + "deep": False, + } + if derivative_info.interpolators is not None: + update_kwargs["interpolators"] = derivative_info.interpolators + + derivative_info_polyslab = derivative_info.updated_copy(**update_kwargs) + vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab) 
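+
+        # (sketch of the mapping below) each vertex sits at center + radius * (cos, sin),
+        # so d(vertices)/d(radius) = (xs_, ys_) and the chain rule gives
+        # dJ/d(radius) = sum(xs_ * vjp_x + ys_ * vjp_y); translating the in-plane
+        # center shifts every vertex equally, so those vertex vjps simply sum.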
+ + vjps = {} + for path in derivative_info.paths: + if path == ("length",): + vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) + vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) + vjps[path] = vjp_top - vjp_bot + + elif path == ("radius",): + # transform polyslab vertices derivatives into radius derivative + xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) + if ("vertices",) not in vjps_polyslab: + vjps[path] = 0.0 + else: + vjps_vertices_xs, vjps_vertices_ys = vjps_polyslab[("vertices",)].T + vjp_xs = np.sum(xs_ * vjps_vertices_xs) + vjp_ys = np.sum(ys_ * vjps_vertices_ys) + vjps[path] = vjp_xs + vjp_ys + + elif "center" in path: + _, center_index = path + _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) + if center_index == index_x: + if ("vertices",) not in vjps_polyslab: + vjps[path] = 0.0 + else: + vjps_vertices_xs = vjps_polyslab[("vertices",)][:, 0] + vjps[path] = np.sum(vjps_vertices_xs) + elif center_index == index_y: + if ("vertices",) not in vjps_polyslab: + vjps[path] = 0.0 + else: + vjps_vertices_ys = vjps_polyslab[("vertices",)][:, 1] + vjps[path] = np.sum(vjps_vertices_ys) + else: + vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) + vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) + vjps[path] = vjp_top + vjp_bot + + elif path == ("sidewall_angle",): + # direct mapping: cylinder angle equals polyslab angle + vjps[path] = vjps_polyslab.get(("sidewall_angle",), 0.0) + + else: + raise NotImplementedError( + f"Differentiation with respect to 'Cylinder' '{path}' field not supported. " + "If you would like this feature added, please feel free to raise " + "an issue on the tidy3d front end repository." + ) + + return vjps + + @property + def center_axis(self) -> Any: + """Gets the position of the center of the geometry in the out of plane dimension.""" + z0, _ = self.pop_axis(self.center, axis=self.axis) + return z0 + + @property + def length_axis(self) -> float: + """Gets the length of the geometry along the out of plane dimension.""" + return self.length + + @cached_property + def _normal_2dmaterial(self) -> Axis: + """Get the normal to the given geometry, checking that it is a 2D geometry.""" + if self.length != 0: + raise ValidationError("'Medium2D' requires the 'Cylinder' length to be zero.") + return self.axis + + def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Cylinder: + """Returns an updated geometry which has been transformed to fit within ``bounds`` + along the ``axis`` direction.""" + if axis != self.axis: + raise ValueError( + f"'_update_from_bounds' may only be applied along axis '{self.axis}', " + f"but was given axis '{axis}'." + ) + new_center = list(self.center) + new_center[axis] = (bounds[0] + bounds[1]) / 2 + new_length = bounds[1] - bounds[0] + return self.updated_copy(center=tuple(new_center), length=new_length) + + @verify_packages_import(["trimesh"]) + def _do_intersections_tilted_plane( + self, + normal: Coordinate, + origin: Coordinate, + to_2D: MatrixReal4x4, + quad_segs: Optional[int] = None, + ) -> list[Shapely]: + """Return a list of shapely geometries at the plane specified by normal and origin. + + Parameters + ---------- + normal : Coordinate + Vector defining the normal direction to the plane. + origin : Coordinate + Vector defining the plane origin. + to_2D : MatrixReal4x4 + Transformation matrix to apply to resulting shapes. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. 
If ``None``, uses + ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + import trimesh + + if quad_segs is None: + quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION + + z0, (x0, y0) = self.pop_axis(self.center, self.axis) + half_length = self.finite_length_axis / 2 + + z_top = z0 + half_length + z_bot = z0 - half_length + + if np.isclose(self.sidewall_angle, 0): + r_top = self.radius + r_bot = self.radius + else: + r_top = self.radius_top + r_bot = self.radius_bottom + if r_top < 0 or np.isclose(r_top, 0): + r_top = 0 + z_top = z0 + self._radius_z(z0) / self._tanq + elif r_bot < 0 or np.isclose(r_bot, 0): + r_bot = 0 + z_bot = z0 + self._radius_z(z0) / self._tanq + + angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1) + + if r_bot > 0: + x_bot = x0 + r_bot * np.cos(angles) + y_bot = y0 + r_bot * np.sin(angles) + x_bot[-1] = x0 + y_bot[-1] = y0 + else: + x_bot = np.array([x0]) + y_bot = np.array([y0]) + + if r_top > 0: + x_top = x0 + r_top * np.cos(angles) + y_top = y0 + r_top * np.sin(angles) + x_top[-1] = x0 + y_top[-1] = y0 + else: + x_top = np.array([x0]) + y_top = np.array([y0]) + + x = np.hstack((x_bot, x_top)) + y = np.hstack((y_bot, y_top)) + z = np.hstack((np.full_like(x_bot, z_bot), np.full_like(x_top, z_top))) + vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T + + if x_bot.shape[0] == 1: + m = 1 + n = x_top.shape[0] - 1 + faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)] + faces_side = [(m + (i + 1) % n, m + i, 0) for i in range(n)] + faces = faces_top + faces_side + elif x_top.shape[0] == 1: + m = x_bot.shape[0] + n = m - 1 + faces_bot = [(n, (i + 1) % n, i) for i in range(n)] + faces_side = [(i, (i + 1) % n, m) for i in range(n)] + faces = faces_bot + faces_side + else: + m = x_bot.shape[0] + n = m - 1 + faces_bot = [(n, (i + 1) % n, i) for i in range(n)] + faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)] + faces_side_bot = [(i, (i + 1) % n, m + (i + 1) % n) for i in range(n)] + faces_side_top = [(m + (i + 1) % n, m + i, i) for i in range(n)] + faces = faces_bot + faces_top + faces_side_bot + faces_side_top + + mesh = trimesh.Trimesh(vertices, faces) + + section = mesh.section(plane_origin=origin, plane_normal=normal) + if section is None: + return [] + path, _ = section.to_2D(to_2D=to_2D) + return path.polygons_full + + def _intersections_normal( + self, z: float, quad_segs: Optional[int] = None + ) -> list[BaseGeometry]: + """Find shapely geometries intersecting cylindrical geometry with axis normal to slab. + + Parameters + ---------- + z : float + Position along the axis normal to slab + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. 
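+
+        The returned circle is a shapely buffer polygon approximated by
+        ``4 * quad_segs`` boundary segments.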
+ """ + if quad_segs is None: + quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION + + static_self = self.to_static() + + # radius at z + radius_offset = static_self._radius_z(z) + + if radius_offset <= 0: + return [] + + _, (x0, y0) = self.pop_axis(static_self.center, axis=self.axis) + return [shapely.Point(x0, y0).buffer(radius_offset, quad_segs=quad_segs)] + + def _intersections_side(self, position: float, axis: int) -> list[BaseGeometry]: + """Find shapely geometries intersecting cylindrical geometry with axis orthogonal to length. + When ``sidewall_angle`` is nonzero, so that it's in fact a conical frustum or cone, the + cross section can contain hyperbolic curves. This is currently approximated by a polygon + of many vertices. + + Parameters + ---------- + position : float + Position along axis direction. + axis : int + Integer index into 'xyz' (0, 1, 2). + + Returns + ------- + list[shapely.geometry.base.BaseGeometry] + List of 2D shapes that intersect plane. + For more details refer to + `Shapely's Documentation `_. + """ + # position in the local coordinate of the cylinder + position_local = position - self.center[axis] + + # no intersection + if abs(position_local) >= self.radius_max: + return [] + + # half of intersection length at the top and bottom + intersect_half_length_max = np.sqrt(self.radius_max**2 - position_local**2) + intersect_half_length_min = -LARGE_NUMBER + if abs(position_local) < self.radius_min: + intersect_half_length_min = np.sqrt(self.radius_min**2 - position_local**2) + + # the vertices on the max side of top/bottom + # The two vertices are present in all scenarios. + vertices_max = [ + self._local_to_global_side_cross_section([-intersect_half_length_max, 0], axis), + self._local_to_global_side_cross_section([intersect_half_length_max, 0], axis), + ] + + # Extending to a cone, the maximal height of the cone + h_cone = ( + LARGE_NUMBER if isclose(self.sidewall_angle, 0) else self.radius_max / abs(self._tanq) + ) + # The maximal height of the cross section + height_max = min( + (1 - abs(position_local) / self.radius_max) * h_cone, self.finite_length_axis + ) + + # more vertices to add for conical frustum shape + vertices_frustum_right = [] + vertices_frustum_left = [] + if not (isclose(position, self.center[axis]) or isclose(self.sidewall_angle, 0)): + # The y-coordinate for the additional vertices + y_list = height_max * np.linspace(0, 1, _N_SAMPLE_CURVE_SHAPELY) + # `abs()` to make sure np.sqrt(0-fp_eps) goes through + x_list = np.sqrt( + np.abs(self.radius_max**2 * (1 - y_list / h_cone) ** 2 - position_local**2) + ) + for i in range(_N_SAMPLE_CURVE_SHAPELY): + vertices_frustum_right.append( + self._local_to_global_side_cross_section([x_list[i], y_list[i]], axis) + ) + vertices_frustum_left.append( + self._local_to_global_side_cross_section( + [ + -x_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], + y_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], + ], + axis, + ) + ) + + # the vertices on the min side of top/bottom + vertices_min = [] + + ## termination at the top/bottom + if intersect_half_length_min > 0: + vertices_min.append( + self._local_to_global_side_cross_section( + [intersect_half_length_min, self.finite_length_axis], axis + ) + ) + vertices_min.append( + self._local_to_global_side_cross_section( + [-intersect_half_length_min, self.finite_length_axis], axis + ) + ) + ## early termination + else: + vertices_min.append(self._local_to_global_side_cross_section([0, height_max], axis)) + + return [ + shapely.Polygon( + vertices_max + vertices_frustum_right + vertices_min + 
vertices_frustum_left
+            )
+        ]
+
+    def inside(
+        self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float]
+    ) -> np.ndarray[bool]:
+        """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array
+        with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the
+        volume of the :class:`Geometry`, and ``False`` otherwise.
+
+        Parameters
+        ----------
+        x : np.ndarray[float]
+            Array of point positions in x direction.
+        y : np.ndarray[float]
+            Array of point positions in y direction.
+        z : np.ndarray[float]
+            Array of point positions in z direction.
+
+        Returns
+        -------
+        np.ndarray[bool]
+            ``True`` for every point that is inside the geometry.
+        """
+        # radius at z
+        self._ensure_equal_shape(x, y, z)
+        z0, (x0, y0) = self.pop_axis(self.center, axis=self.axis)
+        z, (x, y) = self.pop_axis((x, y, z), axis=self.axis)
+        radius_offset = self._radius_z(z)
+        positive_radius = radius_offset > 0
+
+        dist_x = np.abs(x - x0)
+        dist_y = np.abs(y - y0)
+        dist_z = np.abs(z - z0)
+        inside_radius = (dist_x**2 + dist_y**2) <= (radius_offset**2)
+        inside_height = dist_z <= (self.finite_length_axis / 2)
+        return positive_radius * inside_radius * inside_height
+
+    @cached_property
+    def bounds(self) -> Bound:
+        """Returns bounding box min and max coordinates.
+
+        Returns
+        -------
+        Tuple[float, float, float], Tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+        """
+        coord_min = [c - self.radius_max for c in self.center]
+        coord_max = [c + self.radius_max for c in self.center]
+        coord_min[self.axis] = self.center[self.axis] - self.length_axis / 2.0
+        coord_max[self.axis] = self.center[self.axis] + self.length_axis / 2.0
+        return (tuple(coord_min), tuple(coord_max))
+
+    def _volume(self, bounds: Bound) -> float:
+        """Returns object's volume within given bounds."""
+
+        coord_min = max(self.bounds[0][self.axis], bounds[0][self.axis])
+        coord_max = min(self.bounds[1][self.axis], bounds[1][self.axis])
+
+        length = coord_max - coord_min
+
+        volume = np.pi * self.radius_max**2 * length
+
+        # a very loose upper bound on how much of the cylinder is in bounds
+        for axis in range(3):
+            if axis != self.axis:
+                if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:
+                    volume *= 0.5
+
+        return volume
+
+    def _surface_area(self, bounds: Bound) -> float:
+        """Returns object's surface area within given bounds."""
+
+        area = 0
+
+        coord_min = self.bounds[0][self.axis]
+        coord_max = self.bounds[1][self.axis]
+
+        if coord_min < bounds[0][self.axis]:
+            coord_min = bounds[0][self.axis]
+        else:
+            area += np.pi * self.radius_max**2
+
+        if coord_max > bounds[1][self.axis]:
+            coord_max = bounds[1][self.axis]
+        else:
+            area += np.pi * self.radius_max**2
+
+        length = coord_max - coord_min
+
+        area += 2.0 * np.pi * self.radius_max * length
+
+        # a very loose upper bound on how much of the cylinder is in bounds
+        for axis in range(3):
+            if axis != self.axis:
+                if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]:
+                    area *= 0.5
+
+        return area
+
+    @cached_property
+    def radius_bottom(self) -> float:
+        """radius of bottom"""
+        return self._radius_z(self.center_axis - self.finite_length_axis / 2)
+
+    @cached_property
+    def radius_top(self) -> float:
+        """radius of top"""
+        return self._radius_z(self.center_axis + self.finite_length_axis / 2)
+
+    @cached_property
+    def radius_max(self) -> float:
+        """max(radius of top, radius of bottom)"""
+        return max(self.radius_bottom, self.radius_top)
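
The containment test above is fully vectorized: after popping the axis coordinate, it reduces to one radial and one axial comparison. A minimal standalone sketch of the same idea for a z-aligned straight cylinder (hypothetical helper, not tidy3d API):

    import numpy as np

    def inside_cylinder(x, y, z, center=(0.0, 0.0, 0.0), radius=0.5, length=1.0):
        """True for points inside a z-aligned cylinder (illustration only)."""
        x0, y0, z0 = center
        in_radius = (x - x0) ** 2 + (y - y0) ** 2 <= radius**2
        in_height = np.abs(z - z0) <= length / 2
        return in_radius & in_height

    pts = np.linspace(-1.0, 1.0, 5)
    mask = inside_cylinder(pts, pts, pts)  # broadcasts over identically shaped arrays
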
+
+    @cached_property
+    def radius_min(self) -> float:
+        """min(radius of top, radius of bottom). It can be negative for a large
+        sidewall angle.
+        """
+        return min(self.radius_bottom, self.radius_top)
+
+    def _radius_z(self, z: float) -> float:
+        """Compute the radius of the cross section at the position z.
+
+        Parameters
+        ----------
+        z : float
+            Position along the axis normal to slab
+        """
+        if isclose(self.sidewall_angle, 0):
+            return self.radius
+
+        radius_middle = self.radius
+        if self.reference_plane == "top":
+            radius_middle += self.finite_length_axis / 2 * self._tanq
+        elif self.reference_plane == "bottom":
+            radius_middle -= self.finite_length_axis / 2 * self._tanq
+
+        return radius_middle - (z - self.center_axis) * self._tanq
+
+    def _local_to_global_side_cross_section(self, coords: list[float], axis: int) -> list[float]:
+        """Map a point (x,y) from the local to the global coordinate system in the
+        side cross section.
+
+        The definition of the local system: y=0 lies at the base if ``sidewall_angle>=0``,
+        and at the top if ``sidewall_angle<0``; x=0 aligns with the corresponding
+        ``self.center``. In both cases, the y-axis points towards the narrowing
+        direction of the cylinder.
+
+        Parameters
+        ----------
+        axis : int
+            Integer index into 'xyz' (0, 1, 2).
+        coords : list[float, float]
+            The value in the planar coordinate.
+
+        Returns
+        -------
+        list[float]
+            The point in the global coordinate for plotting ``_intersections_side``.
+
+        """
+
+        # For a negative sidewall angle, quantities along the axis direction usually need a flipped sign
+        axis_sign = 1
+        if self.sidewall_angle < 0:
+            axis_sign = -1
+
+        lx_offset, ly_offset = self._order_by_axis(
+            plane_val=coords[0],
+            axis_val=axis_sign * (-self.finite_length_axis / 2 + coords[1]),
+            axis=axis,
+        )
+        _, (x_center, y_center) = self.pop_axis(self.center, axis=axis)
+        return [x_center + lx_offset, y_center + ly_offset]
diff --git a/tidy3d/_common/components/geometry/triangulation.py b/tidy3d/_common/components/geometry/triangulation.py
new file mode 100644
index 0000000000..79d457ff8f
--- /dev/null
+++ b/tidy3d/_common/components/geometry/triangulation.py
@@ -0,0 +1,185 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+import numpy as np
+import shapely
+
+from tidy3d._common.components.types.base import ArrayFloat1D
+from tidy3d._common.exceptions import Tidy3dError
+
+if TYPE_CHECKING:
+    from tidy3d._common.components.types.base import ArrayFloat2D
+
+
+@dataclass
+class Vertex:
+    """Simple data class to hold triangulation data structures.
+
+    Parameters
+    ----------
+    coordinate : ArrayFloat1D
+        Vertex coordinate.
+    index : int
+        Vertex index in the original polygon.
+    convexity : float = 0.0
+        Value representing the convexity (> 0) or concavity (< 0) of the vertex in the polygon.
+    is_ear : bool = False
+        Flag indicating whether this is an ear of the polygon.
+    """
+
+    coordinate: ArrayFloat1D
+
+    index: int
+
+    convexity: float = 0.0
+
+    is_ear: bool = False
+
+
+def update_convexity(vertices: list[Vertex], i: int) -> int:
+    """Update the convexity of a vertex in a polygon.
+
+    Parameters
+    ----------
+    vertices : list[Vertex]
+        Vertices of the polygon.
+    i : int
+        Index of the vertex to be updated.
+
+    Returns
+    -------
+    int
+        Value indicating vertex convexity change w.r.t. 0. See note below.
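
The ``convexity`` stored on each vertex is the 2D cross product of the incoming and outgoing edge vectors, computed here as a 2x2 determinant; for a CCW polygon its sign classifies the vertex. A standalone sketch of that test (hypothetical names):

    import numpy as np

    def turn(prev_pt, pt, next_pt):
        """> 0 convex (left turn), < 0 reflex, == 0 collinear, for a CCW polygon."""
        e_in = np.asarray(pt) - np.asarray(prev_pt)
        e_out = np.asarray(next_pt) - np.asarray(pt)
        return np.linalg.det([e_in, e_out])

    assert turn((0, 0), (1, 0), (1, 1)) > 0   # left turn -> convex
    assert turn((0, 0), (1, 0), (2, 0)) == 0  # collinear
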
+ + Note + ---- + Besides updating the vertex, this function returns a value indicating whether the updated vertex + convexity changed to or from 0 (0 convexity means the vertex is collinear with its neighbors). + If the convexity changes from zero to non-zero, return -1. If it changes from non-zero to zero, + return +1. Return 0 in any other case. This allows the main triangulation loop to keep track of + the total number of collinear vertices in the polygon. + + """ + result = -1 if vertices[i].convexity == 0.0 else 0 + j = (i + 1) % len(vertices) + vertices[i].convexity = np.linalg.det( + [ + vertices[i].coordinate - vertices[i - 1].coordinate, + vertices[j].coordinate - vertices[i].coordinate, + ] + ) + if vertices[i].convexity == 0.0: + result += 1 + return result + + +def is_inside( + vertex: ArrayFloat1D, triangle: tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] +) -> bool: + """Check if a vertex is inside a triangle. + + Parameters + ---------- + vertex : ArrayFloat1D + Vertex coordinates. + triangle : tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] + Vertices of the triangle in CCW order. + + Returns + ------- + bool: + Flag indicating if the vertex is inside the triangle. + """ + return all( + np.linalg.det([triangle[i] - triangle[i - 1], vertex - triangle[i - 1]]) > 0 + for i in range(3) + ) + + +def update_ear_flag(vertices: list[Vertex], i: int) -> None: + """Update the ear flag of a vertex in a polygon. + + Parameters + ---------- + vertices : list[Vertex] + Vertices of the polygon. + i : int + Index of the vertex to be updated. + """ + h = (i - 1) % len(vertices) + j = (i + 1) % len(vertices) + triangle = (vertices[h].coordinate, vertices[i].coordinate, vertices[j].coordinate) + vertices[i].is_ear = vertices[i].convexity > 0 and not any( + is_inside(v.coordinate, triangle) + for k, v in enumerate(vertices) + if not (v.convexity > 0 or k == h or k == i or k == j) + ) + + +# TODO: This is an inefficient algorithm that runs in O(n^2). We should use something +# better, and probably as a compiled extension. +def triangulate(vertices: ArrayFloat2D) -> list[tuple[int, int, int]]: + """Triangulate a simple polygon. + + Parameters + ---------- + vertices : ArrayFloat2D + Vertices of the polygon. + + Returns + ------- + list[tuple[int, int, int]] + List of indices of the vertices of the triangles. + """ + is_ccw = shapely.LinearRing(vertices).is_ccw + + # Initialize vertices as non-collinear because we will update the actual value below and count + # the number of collinear vertices. + vertices = [Vertex(v, i, -1.0, False) for i, v in enumerate(vertices)] + if not is_ccw: + vertices.reverse() + + collinears = 0 + for i in range(len(vertices)): + collinears += update_convexity(vertices, i) + + for i in range(len(vertices)): + update_ear_flag(vertices, i) + + triangles = [] + + ear_found = True + while len(vertices) > 3: + if not ear_found: + raise Tidy3dError( + "Impossible to triangulate polygon. Verify that the polygon is valid." 
+ ) + ear_found = False + i = 0 + while i < len(vertices): + if vertices[i].is_ear: + removed = vertices.pop(i) + h = (i - 1) % len(vertices) + j = i % len(vertices) + collinears += update_convexity(vertices, h) + collinears += update_convexity(vertices, j) + if collinears == len(vertices): + # Undo removal because only collinear vertices remain + vertices.insert(i, removed) + collinears += update_convexity(vertices, (i - 1) % len(vertices)) + collinears += update_convexity(vertices, (i + 1) % len(vertices)) + i += 1 + else: + ear_found = True + triangles.append((vertices[h].index, removed.index, vertices[j].index)) + update_ear_flag(vertices, h) + update_ear_flag(vertices, j) + if len(vertices) == 3: + break + else: + i += 1 + + triangles.append(tuple(v.index for v in vertices)) + return triangles diff --git a/tidy3d/_common/components/geometry/utils.py b/tidy3d/_common/components/geometry/utils.py new file mode 100644 index 0000000000..989ed39398 --- /dev/null +++ b/tidy3d/_common/components/geometry/utils.py @@ -0,0 +1,481 @@ +"""Utilities for geometry manipulation.""" + +from __future__ import annotations + +from collections import defaultdict +from enum import Enum +from math import isclose +from typing import TYPE_CHECKING, Any, Optional, Union + +import numpy as np +import shapely +from pydantic import Field, NonNegativeInt +from shapely.geometry import ( + Polygon, +) +from shapely.geometry.base import ( + BaseMultipartGeometry, +) + +from tidy3d._common.components.autograd.utils import get_static +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.components.geometry import base, mesh, polyslab, primitives +from tidy3d._common.components.geometry.base import Box +from tidy3d._common.components.types.base import Shapely +from tidy3d._common.exceptions import SetupError, Tidy3dError + +if TYPE_CHECKING: + from collections.abc import Iterable + + from numpy.typing import ArrayLike + + from tidy3d._common.components.types.base import ( + ArrayFloat2D, + Axis, + MatrixReal4x4, + PlanePosition, + ) + +GeometryType = Union[ + base.Box, + base.Transformed, + base.ClipOperation, + base.GeometryGroup, + primitives.Sphere, + primitives.Cylinder, + polyslab.PolySlab, + polyslab.ComplexPolySlabBase, + mesh.TriangleMesh, +] + + +def flatten_shapely_geometries( + geoms: Union[Shapely, Iterable[Shapely]], keep_types: tuple[type, ...] = (Polygon,) +) -> list[Shapely]: + """ + Flatten nested geometries into a flat list, while only keeping the specified types. + + Recursively extracts and returns non-empty geometries of the given types from input geometries, + expanding any GeometryCollections or Multi* types. + + Parameters + ---------- + geoms : Union[Shapely, Iterable[Shapely]] + Input geometries to flatten. + + keep_types : tuple[type, ...] + Geometry types to keep (e.g., (Polygon, LineString)). Default is + (Polygon). + + Returns + ------- + list[Shapely] + Flat list of non-empty geometries matching the specified types. 
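
A usage sketch for ``flatten_shapely_geometries``, assuming the import path that this diff introduces; the collection below flattens to three polygons, and the line string is dropped because it is not in ``keep_types``:

    import shapely

    # import path assumed from the file added in this diff
    from tidy3d._common.components.geometry.utils import flatten_shapely_geometries

    coll = shapely.GeometryCollection(
        [
            shapely.box(0, 0, 1, 1),
            shapely.MultiPolygon([shapely.box(2, 0, 3, 1), shapely.box(4, 0, 5, 1)]),
            shapely.LineString([(0, 0), (1, 1)]),  # dropped: not a kept type
        ]
    )
    polys = flatten_shapely_geometries(coll)
    assert len(polys) == 3  # one box plus the two members of the MultiPolygon
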
+ """ + # Handle single Shapely object by wrapping it in a list + if isinstance(geoms, Shapely): + geoms = [geoms] + + flat = [] + for geom in geoms: + if geom.is_empty: + continue + if isinstance(geom, keep_types): + flat.append(geom) + elif isinstance(geom, BaseMultipartGeometry): + flat.extend(flatten_shapely_geometries(geom.geoms, keep_types)) + return flat + + +def merging_geometries_on_plane( + geometries: list[GeometryType], + plane: Box, + property_list: list[Any], + interior_disjoint_geometries: bool = False, + cleanup: bool = True, + quad_segs: Optional[int] = None, +) -> list[tuple[Any, Shapely]]: + """Compute list of shapes on plane. Overlaps are removed or merged depending on + provided property_list. + + Parameters + ---------- + geometries : list[GeometryType] + List of structures to filter on the plane. + plane : Box + Plane specification. + property_list : List = None + Property value for each structure. + interior_disjoint_geometries: bool = False + If ``True``, geometries of different properties on the plane must not be overlapping. + cleanup : bool = True + If True, removes extremely small features from each polygon's boundary. + quad_segs : Optional[int] = None + Number of segments used to discretize circular shapes. If ``None``, uses + high-quality visualization settings. + + Returns + ------- + list[tuple[Any, Shapely]] + List of shapes and their property value on the plane after merging. + """ + + if len(geometries) != len(property_list): + raise SetupError( + "Number of provided property values is not equal to the number of geometries." + ) + + shapes = [] + for geo, prop in zip(geometries, property_list): + # get list of Shapely shapes that intersect at the plane + shapes_plane = plane.intersections_with(geo, cleanup=cleanup, quad_segs=quad_segs) + + # Append each of them and their property information to the list of shapes + for shape in shapes_plane: + shapes.append((prop, shape, shape.bounds)) + + if interior_disjoint_geometries: + # No need to consider overlapping. We simply group shapes by property, and union_all + # shapes of the same property. + shapes_by_prop = defaultdict(list) + for prop, shape, _ in shapes: + shapes_by_prop[prop].append(shape) + # union shapes of same property + results = [] + for prop, shapes in shapes_by_prop.items(): + unionized = shapely.union_all(shapes).buffer(0).normalize() + if not unionized.is_empty: + results.append((prop, unionized)) + return results + + background_shapes = [] + for prop, shape, bounds in shapes: + minx, miny, maxx, maxy = bounds + + # loop through background_shapes (note: all background are non-intersecting or merged) + for index, (_prop, _shape, _bounds) in enumerate(background_shapes): + _minx, _miny, _maxx, _maxy = _bounds + + # do a bounding box check to see if any intersection to do anything about + if minx > _maxx or _minx > maxx or miny > _maxy or _miny > maxy: + continue + + # look more closely to see if intersected. 
+            if shape.disjoint(_shape):
+                continue
+
+            # different prop, remove intersection from background shape
+            if prop != _prop:
+                diff_shape = (_shape - shape).buffer(0).normalize()
+                # mark background shape for removal if nothing left
+                if diff_shape.is_empty or len(diff_shape.bounds) == 0:
+                    background_shapes[index] = None
+                else:
+                    background_shapes[index] = (_prop, diff_shape, diff_shape.bounds)
+            # same prop, unionize shapes and mark background shape for removal
+            else:
+                shape = (shape | _shape).buffer(0).normalize()
+                background_shapes[index] = None
+
+        # after doing this with all background shapes, add this shape to the background
+        background_shapes.append((prop, shape, shape.bounds))
+
+    # remove any existing background shapes that have been marked as 'None'
+    background_shapes = [b for b in background_shapes if b is not None]
+
+    # filter out any remaining None or empty shapes (shapes with area completely removed)
+    return [(prop, shape) for (prop, shape, _) in background_shapes if shape]
+
+
+def flatten_groups(
+    *geometries: GeometryType,
+    flatten_nonunion_type: bool = False,
+    flatten_transformed: bool = False,
+    transform: Optional[MatrixReal4x4] = None,
+) -> GeometryType:
+    """Iterates over all geometries, flattening groups and unions.
+
+    Parameters
+    ----------
+    *geometries : GeometryType
+        Geometries to flatten.
+    flatten_nonunion_type : bool = False
+        If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten
+        all clip operations.
+    flatten_transformed : bool = False
+        If ``True``, ``Transformed`` groups are flattened into individual transformed geometries.
+    transform : Optional[MatrixReal4x4]
+        Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``.
+
+    Yields
+    ------
+    GeometryType
+        Geometries after flattening groups and unions.
+    """
+    for geometry in geometries:
+        if isinstance(geometry, base.GeometryGroup):
+            yield from flatten_groups(
+                *geometry.geometries,
+                flatten_nonunion_type=flatten_nonunion_type,
+                flatten_transformed=flatten_transformed,
+                transform=transform,
+            )
+        elif isinstance(geometry, base.ClipOperation) and (
+            flatten_nonunion_type or geometry.operation == "union"
+        ):
+            yield from flatten_groups(
+                geometry.geometry_a,
+                geometry.geometry_b,
+                flatten_nonunion_type=flatten_nonunion_type,
+                flatten_transformed=flatten_transformed,
+                transform=transform,
+            )
+        elif flatten_transformed and isinstance(geometry, base.Transformed):
+            new_transform = geometry.transform
+            if transform is not None:
+                new_transform = np.matmul(transform, new_transform)
+            yield from flatten_groups(
+                geometry.geometry,
+                flatten_nonunion_type=flatten_nonunion_type,
+                flatten_transformed=flatten_transformed,
+                transform=new_transform,
+            )
+        elif flatten_transformed and transform is not None:
+            yield base.Transformed(geometry=geometry, transform=transform)
+        else:
+            yield geometry
+
+
+def traverse_geometries(geometry: GeometryType) -> GeometryType:
+    """Iterator over all geometries within the given geometry.
+
+    Iterates over groups and clip operations within the given geometry, yielding each one.
+
+    Parameters
+    ----------
+    geometry : GeometryType
+        Base geometry to start iteration.
+
+    Returns
+    -------
+    :class:`Geometry`
+        Geometries within the base geometry.
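
``flatten_groups`` and ``traverse_geometries`` are both instances of the same recursive-generator pattern: containers recurse with ``yield from`` while leaves are yielded directly. A stripped-down sketch with plain lists standing in for geometry groups:

    def flatten(items):
        """Yield leaves from arbitrarily nested lists (stand-in for geometry groups)."""
        for item in items:
            if isinstance(item, list):
                yield from flatten(item)
            else:
                yield item

    assert list(flatten([1, [2, [3, 4]], 5])) == [1, 2, 3, 4, 5]
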
+ """ + if isinstance(geometry, base.GeometryGroup): + for g in geometry.geometries: + yield from traverse_geometries(g) + elif isinstance(geometry, base.ClipOperation): + yield from traverse_geometries(geometry.geometry_a) + yield from traverse_geometries(geometry.geometry_b) + yield geometry + + +def from_shapely( + shape: Shapely, + axis: Axis, + slab_bounds: tuple[float, float], + dilation: float = 0.0, + sidewall_angle: float = 0, + reference_plane: PlanePosition = "middle", +) -> base.Geometry: + """Convert a shapely primitive into a geometry instance by extrusion. + + Parameters + ---------- + shape : shapely.geometry.base.BaseGeometry + Shapely primitive to be converted. It must be a linear ring, a polygon or a collection + of any of those. + axis : int + Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). + slab_bounds: tuple[float, float] + Minimal and maximal positions of the extruded slab along ``axis``. + dilation : float + Dilation of the polygon in the base by shifting each edge along its normal outwards + direction by a distance; a negative value corresponds to erosion. + sidewall_angle : float = 0 + Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive + (negative) values result in slabs larger (smaller) at the base than at the top. + reference_plane : PlanePosition = "middle" + Reference position of the (dilated/eroded) polygons along the slab axis. One of + ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` + (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value + has no effect if ``sidewall_angle == 0``. + + Returns + ------- + :class:`Geometry` + Geometry extruded from the 2D data. + """ + if shape.geom_type == "LinearRing": + if sidewall_angle == 0: + return polyslab.PolySlab( + vertices=shape.coords[:-1], + axis=axis, + slab_bounds=slab_bounds, + dilation=dilation, + reference_plane=reference_plane, + ) + group = polyslab.ComplexPolySlabBase( + vertices=shape.coords[:-1], + axis=axis, + slab_bounds=slab_bounds, + dilation=dilation, + sidewall_angle=sidewall_angle, + reference_plane=reference_plane, + ).geometry_group + return group.geometries[0] if len(group.geometries) == 1 else group + + if shape.geom_type == "Polygon": + exterior = from_shapely( + shape.exterior, axis, slab_bounds, dilation, sidewall_angle, reference_plane + ) + interior = [ + from_shapely(hole, axis, slab_bounds, -dilation, -sidewall_angle, reference_plane) + for hole in shape.interiors + ] + if len(interior) == 0: + return exterior + interior = interior[0] if len(interior) == 1 else base.GeometryGroup(geometries=interior) + return base.ClipOperation(operation="difference", geometry_a=exterior, geometry_b=interior) + + if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: + return base.GeometryGroup( + geometries=[ + from_shapely(geo, axis, slab_bounds, dilation, sidewall_angle, reference_plane) + for geo in shape.geoms + ] + ) + + raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") + + +def vertices_from_shapely(shape: Shapely) -> ArrayFloat2D: + """Iterate over the polygons of a shapely geometry returning the vertices. + + Parameters + ---------- + shape : shapely.geometry.base.BaseGeometry + Shapely primitive to have its vertices extracted. It must be a linear ring, a polygon or a + collection of any of those. + + Returns + ------- + list[tuple[ArrayFloat2D]] + List of tuples ``(exterior, *interiors)``. 
+ """ + if shape.geom_type == "LinearRing": + return [(shape.coords[:-1],)] + if shape.geom_type == "Polygon": + return [(shape.exterior.coords[:-1], *tuple(hole.coords[:-1] for hole in shape.interiors))] + if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: + return sum(vertices_from_shapely(geo) for geo in shape.geoms) + + raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") + + +def validate_no_transformed_polyslabs( + geometry: GeometryType, transform: MatrixReal4x4 = None +) -> None: + """Prevents the creation of slanted polyslabs rotated out of plane.""" + if transform is None: + transform = np.eye(4) + if isinstance(geometry, polyslab.PolySlab): + # sidewall_angle may be autograd-traced; unbox for the check only + if not ( + isclose(get_static(geometry.sidewall_angle), 0) + or base.Transformed.preserves_axis(transform, geometry.axis) + ): + raise Tidy3dError( + "Slanted PolySlabs are not allowed to be rotated out of the slab plane." + ) + elif isinstance(geometry, base.Transformed): + transform = np.dot(transform, geometry.transform) + validate_no_transformed_polyslabs(geometry.geometry, transform) + elif isinstance(geometry, base.GeometryGroup): + for geo in geometry.geometries: + validate_no_transformed_polyslabs(geo, transform) + elif isinstance(geometry, base.ClipOperation): + validate_no_transformed_polyslabs(geometry.geometry_a, transform) + validate_no_transformed_polyslabs(geometry.geometry_b, transform) + + +class SnapLocation(Enum): + """Describes different methods for defining the snapping locations.""" + + Boundary = 1 + """ + Choose the boundaries of Yee cells. + """ + Center = 2 + """ + Choose the center of Yee cells. + """ + + +class SnapBehavior(Enum): + """Describes different methods for snapping intervals, which are defined by two endpoints.""" + + Closest = 1 + """ + Snaps the interval's endpoints to the closest grid point. + """ + Expand = 2 + """ + Snaps the interval's endpoints to the closest grid points, + while guaranteeing that the snapping location will never move endpoints inwards. + """ + Contract = 3 + """ + Snaps the interval's endpoints to the closest grid points, + while guaranteeing that the snapping location will never move endpoints outwards. + """ + StrictExpand = 4 + """ + Same as Expand, but will always move endpoints outwards, even if already coincident with grid. + """ + StrictContract = 5 + """ + Same as Contract, but will always move endpoints inwards, even if already coincident with grid. + """ + Off = 6 + """ + Do not use snapping. + """ + + +class SnappingSpec(Tidy3dBaseModel): + """Specifies how to apply grid snapping along each dimension.""" + + location: tuple[SnapLocation, SnapLocation, SnapLocation] = Field( + title="Location", + description="Describes which positions in the grid will be considered for snapping.", + ) + + behavior: tuple[SnapBehavior, SnapBehavior, SnapBehavior] = Field( + title="Behavior", + description="Describes how snapping positions will be chosen.", + ) + + margin: Optional[tuple[NonNegativeInt, NonNegativeInt, NonNegativeInt]] = Field( + (0, 0, 0), + title="Margin", + description="Number of additional grid points to consider when expanding or contracting " + "during snapping. Only applies when ``SnapBehavior`` is ``Expand`` or ``Contract``.", + ) + + +def get_closest_value(test: float, coords: ArrayLike, upper_bound_idx: int) -> float: + """Helper to choose the closest value in an array to a given test value, + using the index of the upper bound. 
The ``upper_bound_idx`` corresponds to the first value in + the ``coords`` array which is greater than or equal to the test value. + """ + # Handle corner cases first + if upper_bound_idx == 0: + return coords[upper_bound_idx] + if upper_bound_idx == len(coords): + return coords[upper_bound_idx - 1] + # General case + lower_bound = coords[upper_bound_idx - 1] + upper_bound = coords[upper_bound_idx] + dlower = abs(test - lower_bound) + dupper = abs(test - upper_bound) + return lower_bound if dlower < dupper else upper_bound diff --git a/tidy3d/_common/components/medium.py b/tidy3d/_common/components/medium.py new file mode 100644 index 0000000000..2d5574b081 --- /dev/null +++ b/tidy3d/_common/components/medium.py @@ -0,0 +1,6460 @@ +"""Defines properties of the medium / materials""" + +from __future__ import annotations + +import functools +from abc import ABC, abstractmethod +from collections.abc import Sequence +from math import isclose +from typing import TYPE_CHECKING, Any, Callable, Never, Optional, TypeVar, Union, get_args + +import autograd.numpy as np +import numpy as npo +from autograd.differential_operators import tensor_jacobian_product +from numpy.typing import NDArray +from pydantic import Field, NonNegativeFloat, PositiveFloat, field_validator, model_validator + +from tidy3d._common.components.autograd.derivative_utils import integrate_within_bounds +from tidy3d._common.components.autograd.types import ( + TracedFloat, + TracedPolesAndResidues, + TracedPositiveFloat, +) +from tidy3d._common.components.autograd.utils import pack_complex_vec +from tidy3d._common.components.base import Tidy3dBaseModel, cached_property +from tidy3d._common.components.data.data_array import ( + DATA_ARRAY_MAP, + ScalarFieldDataArray, + SpatialDataArray, +) +from tidy3d._common.components.data.dataset import PermittivityDataset +from tidy3d._common.components.data.validators import validate_no_nans +from tidy3d._common.components.types.base import TYPE_TAG_STR, FreqBound, InterpMethod, TensorReal +from tidy3d._common.components.validators import validate_name_str +from tidy3d._common.components.viz import VisualizationSpec, add_ax_if_none +from tidy3d._common.constants import ( + C_0, + CONDUCTIVITY, + EPSILON_0, + HBAR, + HERTZ, + LARGEST_FP_NUMBER, + MICROMETER, + PERMITTIVITY, + RADPERSEC, + SECOND, + fp_eps, + pec_val, +) +from tidy3d._common.exceptions import SetupError, ValidationError +from tidy3d._common.log import log +from tidy3d.components.data.unstructured.base import UnstructuredGridDataset +from tidy3d.components.data.utils import ( + CustomSpatialDataType, + CustomSpatialDataTypeAnnotated, + _check_same_coordinates, + _get_numpy_array, + _ones_like, + _zeros_like, +) +from tidy3d.components.dispersion_fitter import ( + LOSS_CHECK_MAX, + LOSS_CHECK_MIN, + LOSS_CHECK_NUM, + imag_resp_extrema_locs, +) +from tidy3d.components.grid.grid import Coords, Grid +from tidy3d.components.material.tcad.heat import ThermalSpecType +from tidy3d.components.nonlinear import ( # noqa: F401 + KerrNonlinearity, + NonlinearModel, + NonlinearSpec, + NonlinearSusceptibility, + TwoPhotonAbsorption, +) +from tidy3d.components.time_modulation import ModulationSpec + +if TYPE_CHECKING: + import xarray as xr + from autograd.numpy.numpy_boxes import ArrayBox + from pydantic import FieldValidationInfo, PositiveInt + + from tidy3d._common.compat import Self + from tidy3d._common.components.autograd.derivative_utils import DerivativeInfo + from tidy3d._common.components.autograd.types import AutogradFieldMap + from 
tidy3d._common.components.data.dataset import ElectromagneticFieldDataset + from tidy3d._common.components.transformation import RotationType + from tidy3d._common.components.types import ( + ArrayComplex3D, + ArrayFloat1D, + Ax, + Axis, + Bound, + Complex, + PermittivityComponent, + ) + from tidy3d._common.components.types.base import PolesAndResidues + +T = TypeVar("T") + +ArrayFloat = NDArray[npo.floating] +ArrayComplex = NDArray[np.complexfloating] +ArrayGeneric = NDArray[Any] +FrequencyArray = Union[Sequence[float], ArrayFloat] +WeightFunction = Callable[[float], ArrayComplex] +ComplexArrayOrScalar = Union[complex, ArrayGeneric] + +# evaluate frequency as this number (Hz) if inf +FREQ_EVAL_INF = 1e50 + +# extrapolation option in custom medium +FILL_VALUE = "extrapolate" + +# Lossy metal +LOSSY_METAL_DEFAULT_SAMPLING_FREQUENCY = 20 +LOSSY_METAL_SCALED_REAL_PART = 10.0 +LOSSY_METAL_DEFAULT_MAX_POLES = 5 +LOSSY_METAL_DEFAULT_TOLERANCE_RMS = 1e-3 + +ALLOWED_INTERP_METHODS = get_args(InterpMethod) + + +def ensure_freq_in_range( + eps_model: Callable[[AbstractMedium, float], complex], +) -> Callable[[AbstractMedium, float], complex]: + """Decorate ``eps_model`` to log warning if frequency supplied is out of bounds.""" + + @functools.wraps(eps_model) + def _eps_model(self: AbstractMedium, frequency: float) -> complex: + """New eps_model function.""" + # evaluate infs and None as FREQ_EVAL_INF + is_inf_scalar = isinstance(frequency, float) and np.isinf(frequency) + if frequency is None or is_inf_scalar: + frequency = FREQ_EVAL_INF + + if isinstance(frequency, np.ndarray): + frequency = frequency.astype(float) + frequency[np.where(np.isinf(frequency))] = FREQ_EVAL_INF + + # if frequency range not present just return original function + if self.frequency_range is None: + return eps_model(self, frequency) + + fmin, fmax = self.frequency_range + # don't warn for evaluating infinite frequency + if is_inf_scalar: + return eps_model(self, frequency) + + outside_lower = np.zeros_like(frequency, dtype=bool) + outside_upper = np.zeros_like(frequency, dtype=bool) + + if fmin > 0: + outside_lower = frequency / fmin < 1 - fp_eps + elif fmin == 0: + outside_lower = frequency < 0 + + if fmax > 0: + outside_upper = frequency / fmax > 1 + fp_eps + + if np.any(outside_lower | outside_upper): + log.warning( + "frequency passed to 'Medium.eps_model()'" + f"is outside of 'Medium.frequency_range' = {self.frequency_range}", + capture=False, + ) + return eps_model(self, frequency) + + return _eps_model + + +""" Medium Definitions """ + + +class AbstractMedium(ABC, Tidy3dBaseModel): + """A medium within which electromagnetic waves propagate.""" + + name: Optional[str] = Field(None, title="Name", description="Optional unique name for medium.") + + frequency_range: Optional[FreqBound] = Field( + None, + title="Frequency Range", + description="Optional range of validity for the medium.", + units=(HERTZ, HERTZ), + ) + + allow_gain: bool = Field( + False, + title="Allow gain medium", + description="Allow the medium to be active. Caution: " + "simulations with a gain medium are unstable, and are likely to diverge." + "Simulations where ``allow_gain`` is set to ``True`` will still be charged even if " + "diverged. 
Monitor data up to the divergence point will still be returned and can be " + "useful in some cases.", + ) + + nonlinear_spec: Optional[Union[NonlinearSpec, NonlinearSusceptibility]] = Field( + None, + title="Nonlinear Spec", + description="Nonlinear spec applied on top of the base medium properties.", + ) + + modulation_spec: Optional[ModulationSpec] = Field( + None, + title="Modulation Spec", + description="Modulation spec applied on top of the base medium properties.", + ) + + viz_spec: Optional[VisualizationSpec] = Field( + None, + title="Visualization Specification", + description="Plotting specification for visualizing medium.", + ) + + heat_spec: Optional[ThermalSpecType] = Field( + None, + title="Heat Specification", + description="DEPRECATED: Use :class:`MultiPhysicsMedium`. Specification of the medium heat properties. They are " + "used for solving the heat equation via the :class:`HeatSimulation` interface. Such simulations can be" + "used for investigating the influence of heat propagation on the properties of optical systems. " + "Once the temperature distribution in the system is found using :class:`HeatSimulation` object, " + "``Simulation.perturbed_mediums_copy()`` can be used to convert mediums with perturbation " + "models defined into spatially dependent custom mediums. " + "Otherwise, the ``heat_spec`` does not directly affect the running of an optical " + "``Simulation``.", + discriminator=TYPE_TAG_STR, + ) + + @model_validator(mode="after") + def _validate_nonlinear_spec(self) -> Self: + """Check compatibility with nonlinear_spec.""" + if self.__class__.__name__ == "AnisotropicMedium" and any( + comp.nonlinear_spec is not None for comp in [self.xx, self.yy, self.zz] + ): + raise ValidationError( + "Nonlinearities are not currently supported for the components " + "of an anisotropic medium." + ) + if self.__class__.__name__ == "Medium2D" and any( + comp.nonlinear_spec is not None for comp in [self.ss, self.tt] + ): + raise ValidationError( + "Nonlinearities are not currently supported for the components of a 2D medium." + ) + + if self.nonlinear_spec is None: + return self + if isinstance(self.nonlinear_spec, NonlinearModel): + log.warning( + "The API for 'nonlinear_spec' has changed. " + "The old usage 'nonlinear_spec=model' is deprecated and will be removed " + "in a future release. The new usage is " + r"'nonlinear_spec=NonlinearSpec(models=\[model])'." + ) + for model in self._nonlinear_models: + model._validate_medium_type(self) + model._validate_medium(self) + if ( + isinstance(self.nonlinear_spec, NonlinearSpec) + and isinstance(model, NonlinearSusceptibility) + and model.numiters is not None + ): + raise ValidationError( + "'NonlinearSusceptibility.numiters' is deprecated. " + "Please use 'NonlinearSpec.num_iters' instead." + ) + return self + + @model_validator(mode="after") + def _check_either_modulation_or_nonlinear_spec(self) -> Self: + """Check compatibility with modulation_spec.""" + val = self.modulation_spec + nonlinear_spec = self.nonlinear_spec + if val is not None and nonlinear_spec is not None: + raise ValidationError( + f"For medium class {self.type}, 'modulation_spec' of class {type(val).__name__} and " + f"'nonlinear_spec' of class {type(nonlinear_spec).__name__} are " + "not simultaneously supported." 
+ ) + return self + + _name_validator = validate_name_str() + + @model_validator(mode="after") + def _validate_modulation_spec_after(self) -> Self: + """Check compatibility with nonlinear_spec.""" + if self.__class__.__name__ == "Medium2D" and any( + comp.modulation_spec is not None for comp in [self.ss, self.tt] + ): + raise ValidationError( + "Time modulation is not currently supported for the components of a 2D medium." + ) + return self + + @property + def charge(self) -> None: + return None + + @property + def electrical(self) -> None: + return None + + @property + def heat(self) -> Optional[ThermalSpecType]: + return self.heat_spec + + @property + def optical(self) -> None: + return None + + @cached_property + def _nonlinear_models(self) -> list: + """The nonlinear models in the nonlinear_spec.""" + if self.nonlinear_spec is None: + return [] + if isinstance(self.nonlinear_spec, NonlinearModel): + return [self.nonlinear_spec] + if self.nonlinear_spec.models is None: + return [] + return list(self.nonlinear_spec.models) + + @cached_property + def _nonlinear_num_iters(self) -> PositiveInt: + """The num_iters of the nonlinear_spec.""" + if self.nonlinear_spec is None: + return 0 + if isinstance(self.nonlinear_spec, NonlinearModel): + if self.nonlinear_spec.numiters is None: + return 1 # old default value for backwards compatibility + return self.nonlinear_spec.numiters + return self.nonlinear_spec.num_iters + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + return True + + @cached_property + def is_time_modulated(self) -> bool: + """Whether any component of the medium is time modulated.""" + return self.modulation_spec is not None and self.modulation_spec.applied_modulation + + @cached_property + def is_nonlinear(self) -> bool: + """Whether the medium is nonlinear.""" + return self.nonlinear_spec is not None + + @cached_property + def is_custom(self) -> bool: + """Whether the medium is custom.""" + return isinstance(self, AbstractCustomMedium) + + @cached_property + def is_fully_anisotropic(self) -> bool: + """Whether the medium is fully anisotropic.""" + return isinstance(self, FullyAnisotropicMedium) + + @cached_property + def _incompatible_material_types(self) -> list[str]: + """A list of material properties present which may lead to incompatibilities.""" + properties = [ + self.is_time_modulated, + self.is_nonlinear, + self.is_custom, + self.is_fully_anisotropic, + ] + names = ["time modulated", "nonlinear", "custom", "fully anisotropic"] + types = [name for name, prop in zip(names, properties) if prop] + return types + + @cached_property + def _has_incompatibilities(self) -> bool: + """Whether the medium has incompatibilities. 
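
The pairwise rule implemented by ``_compatible_with`` below boils down to a small forbidden-pairs table: custom and fully anisotropic media may not intersect nonlinear or time-modulated ones, and nonlinear may not intersect time-modulated. A standalone sketch of the same logic (hypothetical names):

    # pairs of property labels that must not intersect (symmetric)
    FORBIDDEN = {
        ("custom", "nonlinear"),
        ("custom", "time_modulated"),
        ("fully_anisotropic", "nonlinear"),
        ("fully_anisotropic", "time_modulated"),
        ("nonlinear", "time_modulated"),
    }

    def compatible(props_a: set[str], props_b: set[str]) -> bool:
        return not any(
            (p, q) in FORBIDDEN or (q, p) in FORBIDDEN
            for p in props_a
            for q in props_b
        )

    assert compatible({"custom"}, {"fully_anisotropic"})   # allowed
    assert not compatible({"custom"}, {"time_modulated"})  # forbidden
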
Certain medium types are incompatible + with certain others, and such pairs are not allowed to intersect in a simulation.""" + return len(self._incompatible_material_types) > 0 + + def _compatible_with(self, other: AbstractMedium) -> bool: + """Whether these two media are compatible if in structures that intersect.""" + if not (self._has_incompatibilities and other._has_incompatibilities): + return True + for med1, med2 in [(self, other), (other, self)]: + if med1.is_custom: + # custom and fully_anisotropic is OK + if med2.is_nonlinear or med2.is_time_modulated: + return False + if med1.is_fully_anisotropic: + if med2.is_nonlinear or med2.is_time_modulated: + return False + if med1.is_nonlinear: + if med2.is_time_modulated: + return False + return True + + @abstractmethod + def eps_model(self, frequency: float) -> complex: + # TODO this should be moved out of here into FDTD Simulation Mediums? + """Complex-valued permittivity as a function of frequency. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + complex + Complex-valued relative permittivity evaluated at ``frequency``. + """ + + def nk_model(self, frequency: float) -> tuple[float, float]: + """Real and imaginary parts of the refactive index as a function of frequency. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[float, float] + Real part (n) and imaginary part (k) of refractive index of medium. + """ + eps_complex = self.eps_model(frequency=frequency) + return self.eps_complex_to_nk(eps_complex) + + def loss_tangent_model(self, frequency: float) -> tuple[float, float]: + """Permittivity and loss tangent as a function of frequency. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[float, float] + Real part of permittivity and loss tangent. + """ + eps_complex = self.eps_model(frequency=frequency) + return self.eps_complex_to_eps_loss_tangent(eps_complex) + + @ensure_freq_in_range + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: + """Main diagonal of the complex-valued permittivity tensor as a function of frequency. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[complex, complex, complex] + The diagonal elements of the relative permittivity tensor evaluated at ``frequency``. + """ + + # This only needs to be overwritten for anisotropic materials + eps = self.eps_model(frequency) + return (eps, eps, eps) + + def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]: + """Main diagonal of the complex-valued permittivity tensor for numerical considerations + such as meshing and runtime estimation. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[complex, complex, complex] + The diagonal elements of relative permittivity tensor relevant for numerical + considerations evaluated at ``frequency``. + """ + + if self.is_pec: + # also 1 for lossy metal and Medium2D, but let's handle them in the subclass. + return (1.0 + 0j,) * 3 + + return self.eps_diagonal(frequency) + + def eps_comp(self, row: Axis, col: Axis, frequency: float) -> complex: + """Single component of the complex-valued permittivity tensor as a function of frequency. 
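
For any isotropic medium these tensor accessors reduce to ``eps_model`` on the diagonal and zero off-diagonal, i.e. the full tensor is the scalar permittivity times the identity. A trivial standalone sketch:

    import numpy as np

    def eps_tensor(eps_scalar: complex) -> np.ndarray:
        """Isotropic permittivity tensor: scalar times the 3x3 identity."""
        return eps_scalar * np.eye(3, dtype=complex)

    eps = eps_tensor(2.25 + 0.01j)
    assert eps[0, 0] == 2.25 + 0.01j and eps[0, 1] == 0
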
+ + Parameters + ---------- + row : int + Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). + col : int + Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + complex + Element of the relative permittivity tensor evaluated at ``frequency``. + """ + + # This only needs to be overwritten for anisotropic materials + if row == col: + return self.eps_model(frequency) + return 0j + + def _eps_plot( + self, frequency: float, eps_component: Optional[PermittivityComponent] = None + ) -> float: + """Returns real part of epsilon for plotting. A specific component of the epsilon tensor can + be selected for anisotropic medium. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at. + eps_component : PermittivityComponent + Component of the permittivity tensor to plot + e.g. ``"xx"``, ``"yy"``, ``"zz"``, ``"xy"``, ``"yz"``, ... + Defaults to ``None``, which returns the average of the diagonal values. + + Returns + ------- + float + Element ``eps_component`` of the relative permittivity tensor evaluated at ``frequency``. + """ + # Assumes the material is isotropic + # Will need to be overridden for anisotropic materials + return self.eps_model(frequency).real + + @cached_property + @abstractmethod + def n_cfl(self) -> float: + # TODO this should be moved out of here into FDTD Simulation Mediums? + """To ensure a stable FDTD simulation, it is essential to select an appropriate + time step size in accordance with the CFL condition. The maximal time step + size is inversely proportional to the speed of light in the medium, and thus + proportional to the index of refraction. However, for dispersive medium, + anisotropic medium, and other more complicated media, there are complications in + deciding on the choice of the index of refraction. + + This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + """ + + @add_ax_if_none + def plot(self, freqs: float, ax: Ax = None) -> Ax: + """Plot n, k of a :class:`.Medium` as a function of frequency. + + Parameters + ---------- + freqs: float + Frequencies (Hz) to evaluate the medium properties at. + ax : matplotlib.axes._subplots.Axes = None + Matplotlib axes to plot on, if not specified, one is created. + + Returns + ------- + matplotlib.axes._subplots.Axes + The supplied or created matplotlib axes. + """ + + freqs = np.array(freqs) + eps_complex = np.array([self.eps_model(freq) for freq in freqs]) + n, k = AbstractMedium.eps_complex_to_nk(eps_complex) + + freqs_thz = freqs / 1e12 + ax.plot(freqs_thz, n, label="n") + ax.plot(freqs_thz, k, label="k") + ax.set_xlabel("frequency (THz)") + ax.set_title("medium dispersion") + ax.legend() + ax.set_aspect("auto") + return ax + + """ Conversion helper functions """ + + @staticmethod + def nk_to_eps_complex(n: float, k: float = 0.0) -> complex: + """Convert n, k to complex permittivity. + + Parameters + ---------- + n : float + Real part of refractive index. + k : float = 0.0 + Imaginary part of refrative index. + + Returns + ------- + complex + Complex-valued relative permittivity. 
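
Concretely, eps = (n + ik)^2 = (n^2 - k^2) + 2nk*i, and the principal square root recovers (n, k). A quick standalone numerical check:

    n, k = 2.0, 0.1
    eps = complex(n**2 - k**2, 2 * n * k)  # 3.99 + 0.4j
    assert abs(eps - (3.99 + 0.4j)) < 1e-12
    assert abs(eps**0.5 - complex(n, k)) < 1e-12  # principal sqrt recovers n + ik
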
+ """ + eps_real = n**2 - k**2 + eps_imag = 2 * n * k + return eps_real + 1j * eps_imag + + @staticmethod + def eps_complex_to_nk(eps_c: complex) -> tuple[float, float]: + """Convert complex permittivity to n, k values. + + Parameters + ---------- + eps_c : complex + Complex-valued relative permittivity. + + Returns + ------- + tuple[float, float] + Real and imaginary parts of refractive index (n & k). + """ + eps_c = np.array(eps_c) + ref_index = np.sqrt(eps_c) + return np.real(ref_index), np.imag(ref_index) + + @staticmethod + def nk_to_eps_sigma(n: float, k: float, freq: float) -> tuple[float, float]: + """Convert ``n``, ``k`` at frequency ``freq`` to permittivity and conductivity values. + + Parameters + ---------- + n : float + Real part of refractive index. + k : float = 0.0 + Imaginary part of refrative index. + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[float, float] + Real part of relative permittivity & electric conductivity. + """ + eps_complex = AbstractMedium.nk_to_eps_complex(n, k) + eps_real, eps_imag = eps_complex.real, eps_complex.imag + omega = 2 * np.pi * freq + sigma = omega * eps_imag * EPSILON_0 + return eps_real, sigma + + @staticmethod + def eps_sigma_to_eps_complex(eps_real: float, sigma: float, freq: float) -> complex: + """convert permittivity and conductivity to complex permittivity at freq + + Parameters + ---------- + eps_real : float + Real-valued relative permittivity. + sigma : float + Conductivity. + freq : float + Frequency to evaluate permittivity at (Hz). + If not supplied, returns real part of permittivity (limit as frequency -> infinity.) + + Returns + ------- + complex + Complex-valued relative permittivity. + """ + if freq is None: + return eps_real + omega = 2 * np.pi * freq + + return eps_real + 1j * sigma / omega / EPSILON_0 + + @staticmethod + def eps_complex_to_eps_sigma(eps_complex: complex, freq: float) -> tuple[float, float]: + """Convert complex permittivity at frequency ``freq`` + to permittivity and conductivity values. + + Parameters + ---------- + eps_complex : complex + Complex-valued relative permittivity. + freq : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[float, float] + Real part of relative permittivity & electric conductivity. + """ + eps_real, eps_imag = eps_complex.real, eps_complex.imag + omega = 2 * np.pi * freq + sigma = omega * eps_imag * EPSILON_0 + return eps_real, sigma + + @staticmethod + def eps_complex_to_eps_loss_tangent(eps_complex: complex) -> tuple[float, float]: + """Convert complex permittivity to permittivity and loss tangent. + + Parameters + ---------- + eps_complex : complex + Complex-valued relative permittivity. + + Returns + ------- + tuple[float, float] + Real part of relative permittivity & loss tangent + """ + eps_real, eps_imag = eps_complex.real, eps_complex.imag + return eps_real, eps_imag / eps_real + + @staticmethod + def eps_loss_tangent_to_eps_complex(eps_real: float, loss_tangent: float) -> complex: + """Convert permittivity and loss tangent to complex permittivity. + + Parameters + ---------- + eps_real : float + Real part of relative permittivity + loss_tangent : float + Loss tangent + + Returns + ------- + eps_complex : complex + Complex-valued relative permittivity. + """ + return eps_real * (1 + 1j * loss_tangent) + + @staticmethod + def eV_to_angular_freq(f_eV: float) -> float: + """Convert frequency in unit of eV to rad/s. 
+ + Parameters + ---------- + f_eV : float + Frequency in unit of eV + """ + return f_eV / HBAR + + @staticmethod + def angular_freq_to_eV(f_rad: float) -> float: + """Convert frequency in unit of rad/s to eV. + + Parameters + ---------- + f_rad : float + Frequency in unit of rad/s + """ + return f_rad * HBAR + + @staticmethod + def angular_freq_to_Hz(f_rad: float) -> float: + """Convert frequency in unit of rad/s to Hz. + + Parameters + ---------- + f_rad : float + Frequency in unit of rad/s + """ + return f_rad / 2 / np.pi + + @staticmethod + def Hz_to_angular_freq(f_hz: float) -> float: + """Convert frequency in unit of Hz to rad/s. + + Parameters + ---------- + f_hz : float + Frequency in unit of Hz + """ + return f_hz * 2 * np.pi + + @ensure_freq_in_range + def sigma_model(self, freq: float) -> complex: + """Complex-valued conductivity as a function of frequency. + + Parameters + ---------- + freq: float + Frequency to evaluate conductivity at (Hz). + + Returns + ------- + complex + Complex conductivity at this frequency. + """ + omega = freq * 2 * np.pi + eps_complex = self.eps_model(freq) + eps_inf = self.eps_model(np.inf) + sigma = (eps_inf - eps_complex) * 1j * omega * EPSILON_0 + return sigma + + @cached_property + def is_pec(self) -> bool: + """Whether the medium is a PEC.""" + return False + + @cached_property + def is_pmc(self) -> bool: + """Whether the medium is a PMC.""" + return False + + def sel_inside(self, bounds: Bound) -> AbstractMedium: + """Return a new medium that contains the minimal amount data necessary to cover + a spatial region defined by ``bounds``. + + + Parameters + ---------- + bounds : tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + AbstractMedium + Medium with reduced data. + """ + + if self.modulation_spec is not None: + modulation_reduced = self.modulation_spec.sel_inside(bounds) + return self.updated_copy(modulation_spec=modulation_reduced) + + return self + + """ Autograd code """ + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + raise NotImplementedError(f"Can't compute derivative for 'Medium': '{type(self)}'.") + + def _derivative_eps_sigma_volume( + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound + ) -> dict[str, xr.DataArray]: + """Get the derivative w.r.t permittivity and conductivity in the volume.""" + + vjp_eps_complex = self._derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds) + + values = vjp_eps_complex.values + + # compute directly with frequency dimension + freqs = vjp_eps_complex.coords["f"].values + omegas = 2 * np.pi * freqs + eps_vjp = np.real(values) + sigma_vjp = -np.imag(values) / omegas / EPSILON_0 + + eps_vjp = np.sum(eps_vjp) + sigma_vjp = np.sum(sigma_vjp) + + return {"permittivity": eps_vjp, "conductivity": sigma_vjp} + + def _derivative_eps_complex_volume( + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound + ) -> xr.DataArray: + """Get the derivative w.r.t complex-valued permittivity in the volume.""" + vjp_value = None + for field_name in ("Ex", "Ey", "Ez"): + fld = E_der_map[field_name] + vjp_value_fld = integrate_within_bounds( + arr=fld, + dims=("x", "y", "z"), + bounds=bounds, + ) + if vjp_value is None: + vjp_value = vjp_value_fld + else: + vjp_value += vjp_value_fld + + return vjp_value + + def __repr__(self) -> str: + """If the medium has a name, use it as the representation. 
Otherwise, use the default representation.""" + if self.name: + return self.name + return super().__repr__() + + +_PERTURBATION_MEDIUM_EXTRA_TYPES: list[type[Any]] = [] + + +def _build_perturbation_medium_type() -> object: + pertubation_medium_types = tuple(_PERTURBATION_MEDIUM_EXTRA_TYPES) + if len(pertubation_medium_types) == 0: + return Never + return Union[pertubation_medium_types] + + +PerturbationMediumType = _build_perturbation_medium_type() + + +def extend_perturbation_medium_type(*extra_types: type[Any]) -> None: + """Extend ``PerturbationMediumType`` and rebuild dependent models.""" + for extra_type in extra_types: + if extra_type not in _PERTURBATION_MEDIUM_EXTRA_TYPES: + _PERTURBATION_MEDIUM_EXTRA_TYPES.append(extra_type) + + global PerturbationMediumType + PerturbationMediumType = _build_perturbation_medium_type() + + # Rebuild dependent models if already defined + if "AbstractCustomMedium" in globals(): + AbstractCustomMedium.model_rebuild( + force=True, _types_namespace={"PerturbationMediumType": PerturbationMediumType} + ) + if "CustomPoleResidue" in globals(): + CustomPoleResidue.model_rebuild( + force=True, _types_namespace={"PerturbationMediumType": PerturbationMediumType} + ) + + +class AbstractCustomMedium(AbstractMedium, ABC): + """A spatially varying medium.""" + + interp_method: InterpMethod = Field( + "nearest", + title="Interpolation method", + description="Interpolation method to obtain permittivity values " + "that are not supplied at the Yee grids; For grids outside the range " + "of the supplied data, extrapolation will be applied. When the extrapolated " + "value is smaller (greater) than the minimal (maximal) of the supplied data, " + "the extrapolated value will take the minimal (maximal) of the supplied data.", + ) + + subpixel: bool = Field( + False, + title="Subpixel averaging", + description="If ``True``, apply the subpixel averaging method specified by " + "``Simulation``'s field ``subpixel`` for this type of material on the " + "interface of the structure, including exterior boundary and " + "intersection interfaces with other structures.", + ) + + derived_from: Optional[PerturbationMediumType] = Field( + None, + title="Parent Medium", + description="If not ``None``, it records the parent medium from which this medium was derived.", + ) + + @cached_property + @abstractmethod + def is_isotropic(self) -> bool: + """The medium is isotropic or anisotropic.""" + + def _interp_method(self, comp: Axis) -> InterpMethod: + """Interpolation method applied to comp.""" + return self.interp_method + + @abstractmethod + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[ + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset` + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset` + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset` + ], + ] + The permittivity evaluated at ``frequency``. 
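
The ``extend_perturbation_medium_type`` machinery above is a plug-in registry: optional packages append their medium types and the module-level ``Union`` annotation is rebuilt, followed by ``model_rebuild`` on the dependent pydantic models. A stripped-down sketch of the registry part, without pydantic (hypothetical names):

    from typing import Optional, Union

    _EXTRA_TYPES: list[type] = []

    def build_union() -> Optional[object]:
        """Rebuild the Union annotation from the registry (None when empty)."""
        return Union[tuple(_EXTRA_TYPES)] if _EXTRA_TYPES else None

    def register(*types: type) -> None:
        for t in types:
            if t not in _EXTRA_TYPES:
                _EXTRA_TYPES.append(t)

    register(int, float)
    assert build_union() == Union[int, float]
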
+ """ + + def eps_diagonal_on_grid( + self, + frequency: float, + coords: Coords, + ) -> tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: + """Spatial profile of main diagonal of the complex-valued permittivity + at ``frequency`` interpolated at the supplied coordinates. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + coords : :class:`.Coords` + The grid point coordinates over which interpolation is performed. + + Returns + ------- + tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] + The complex-valued permittivity tensor at ``frequency`` interpolated + at the supplied coordinate. + """ + eps_spatial = self.eps_dataarray_freq(frequency) + if self.is_isotropic: + eps_interp = _get_numpy_array( + coords.spatial_interp(eps_spatial[0], self._interp_method(0)) + ) + return (eps_interp, eps_interp, eps_interp) + return tuple( + _get_numpy_array(coords.spatial_interp(eps_comp, self._interp_method(comp))) + for comp, eps_comp in enumerate(eps_spatial) + ) + + def eps_comp_on_grid( + self, + row: Axis, + col: Axis, + frequency: float, + coords: Coords, + ) -> ArrayComplex3D: + """Spatial profile of a single component of the complex-valued permittivity tensor at + ``frequency`` interpolated at the supplied coordinates. + + Parameters + ---------- + row : int + Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). + col : int + Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). + frequency : float + Frequency to evaluate permittivity at (Hz). + coords : :class:`.Coords` + The grid point coordinates over which interpolation is performed. + + Returns + ------- + ArrayComplex3D + Single component of the complex-valued permittivity tensor at ``frequency`` interpolated + at the supplied coordinates. + """ + + if row == col: + return self.eps_diagonal_on_grid(frequency, coords)[row] + return 0j + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued spatially averaged permittivity as a function of frequency.""" + if self.is_isotropic: + return np.mean(_get_numpy_array(self.eps_dataarray_freq(frequency)[0])) + return np.mean( + [np.mean(_get_numpy_array(eps_comp)) for eps_comp in self.eps_dataarray_freq(frequency)] + ) + + @ensure_freq_in_range + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: + """Main diagonal of the complex-valued permittivity tensor + at ``frequency``. Spatially, we take max{||eps||}, so that autoMesh generation + works appropriately. + """ + eps_spatial = self.eps_dataarray_freq(frequency) + if self.is_isotropic: + eps_comp = _get_numpy_array(eps_spatial[0]).ravel() + eps = eps_comp[np.argmax(np.abs(eps_comp))] + return (eps, eps, eps) + eps_spatial_array = (_get_numpy_array(eps_comp).ravel() for eps_comp in eps_spatial) + return tuple(eps_comp[np.argmax(np.abs(eps_comp))] for eps_comp in eps_spatial_array) + + def _get_real_vals(self, x: ArrayGeneric) -> ArrayFloat: + """Grab the real part of the values in array. + Used for _eps_bounds() + """ + return _get_numpy_array(np.real(x)).ravel() + + def _eps_bounds( + self, + frequency: Optional[float] = None, + eps_component: Optional[PermittivityComponent] = None, + ) -> tuple[float, float]: + """Returns permittivity bounds for setting the color bounds when plotting. + + Parameters + ---------- + frequency : float = None + Frequency to evaluate the relative permittivity of all mediums. + If not specified, evaluates at infinite frequency. 
+        eps_component : Optional[PermittivityComponent] = None
+            Component of the permittivity tensor to plot for anisotropic materials,
+            e.g. ``"xx"``, ``"yy"``, ``"zz"``, ``"xy"``, ``"yz"``, ...
+            Defaults to ``None``, which returns the average of the diagonal values.
+
+        Returns
+        -------
+        tuple[float, float]
+            The min and max values of the permittivity for the selected component and evaluated at ``frequency``.
+        """
+        eps_dataarray = self.eps_dataarray_freq(frequency)
+        all_eps = np.concatenate([self._get_real_vals(eps_comp) for eps_comp in eps_dataarray])
+        return (np.min(all_eps), np.max(all_eps))
+
+    @staticmethod
+    def _validate_isreal_dataarray(dataarray: CustomSpatialDataType) -> bool:
+        """Validate that the dataarray is real."""
+        return np.all(np.isreal(_get_numpy_array(dataarray)))
+
+    @staticmethod
+    def _validate_isreal_dataarray_tuple(
+        dataarray_tuple: tuple[CustomSpatialDataType, ...],
+    ) -> bool:
+        """Validate that each dataarray in the tuple is real."""
+        return np.all([AbstractCustomMedium._validate_isreal_dataarray(f) for f in dataarray_tuple])
+
+    @abstractmethod
+    def _sel_custom_data_inside(self, bounds: Bound) -> Self:
+        """Return a new medium that contains the minimal amount of custom data necessary to cover
+        a spatial region defined by ``bounds``."""
+
+    def sel_inside(self, bounds: Bound) -> AbstractCustomMedium:
+        """Return a new medium that contains the minimal amount of data necessary to cover
+        a spatial region defined by ``bounds``.
+
+        Parameters
+        ----------
+        bounds : tuple[float, float, float], tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+        Returns
+        -------
+        AbstractCustomMedium
+            Medium with reduced data.
+        """
+
+        self_mod_data_reduced = super().sel_inside(bounds)
+
+        return self_mod_data_reduced._sel_custom_data_inside(bounds)
+
+    @staticmethod
+    def _not_loaded(field: Any) -> bool:
+        """Check whether data was not loaded."""
+        if isinstance(field, str) and field in DATA_ARRAY_MAP:
+            return True
+        # attempting to construct an UnstructuredGridDataset from a dict
+        if isinstance(field, dict) and field.get("type") in (
+            "TriangularGridDataset",
+            "TetrahedralGridDataset",
+        ):
+            return any(
+                isinstance(subfield, str) and subfield in DATA_ARRAY_MAP
+                for subfield in [field["points"], field["cells"], field["values"]]
+            )
+        # attempting to pass an UnstructuredGridDataset with zero points
+        if isinstance(field, UnstructuredGridDataset):
+            return any(len(subfield) == 0 for subfield in [field.points, field.cells, field.values])
+        return False
+
+    def _derivative_field_cmp(
+        self,
+        E_der_map: ElectromagneticFieldDataset,
+        spatial_data: PermittivityDataset,
+        dim: str,
+    ) -> ArrayGeneric:
+        if isinstance(spatial_data, UnstructuredGridDataset):
+            raise NotImplementedError(
+                "Adjoint derivatives for unstructured custom media are not supported."
+ ) + coords_interp = {key: val for key, val in spatial_data.coords.items() if len(val) > 1} + dims_sum = {dim for dim in spatial_data.coords.keys() if dim not in coords_interp} + + eps_coordinate_shape = [ + len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz" + ] + + # compute sizes along each of the interpolation dimensions + sizes_list = [] + for _, coords in coords_interp.items(): + num_coords = len(coords) + coords = np.array(coords) + + # compute distances between midpoints for all internal coords + mid_points = (coords[1:] + coords[:-1]) / 2.0 + dists = np.diff(mid_points) + sizes = np.zeros(num_coords) + sizes[1:-1] = dists + + # estimate the sizes on the edges using 2 x the midpoint distance + sizes[0] = 2 * abs(mid_points[0] - coords[0]) + sizes[-1] = 2 * abs(coords[-1] - mid_points[-1]) + + sizes_list.append(sizes) + + # turn this into a volume element, should be re-sizeable to the gradient shape + if sizes_list: + d_vol = functools.reduce(np.outer, sizes_list) + else: + # if sizes_list is empty, then reduce() fails + d_vol = np.array(1.0) + + # TODO: probably this could be more robust. eg if the DataArray has weird edge cases + E_der_dim = E_der_map[f"E{dim}"] + E_der_dim_interp = ( + E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f") + ) + vjp_array = np.array(E_der_dim_interp.values).astype(complex) + vjp_array = vjp_array.reshape(eps_coordinate_shape) + + # multiply by volume elements (if possible, being defensive here..) + try: + vjp_array *= d_vol.reshape(vjp_array.shape) + except ValueError: + log.warning( + "Skipping volume element normalization of 'CustomMedium' gradients. " + f"Could not reshape the volume elements of shape {d_vol.shape} " + f"to the shape of the gradient {vjp_array.shape}. " + "If you encounter this warning, gradient direction will be accurate but the norm " + "will be inaccurate. Please raise an issue on the tidy3d front end with this " + "message and some information about your simulation setup and we will investigate. 
" + ) + return vjp_array + + def _derivative_field_cmp_custom( + self, + E_der_map: ElectromagneticFieldDataset, + spatial_data: SpatialDataArray, + dim: str, + freqs: NDArray, + bounds: Optional[Bound] = None, + component: str = "real", + interp_method: Optional[InterpMethod] = None, + ) -> NDArray: + """Compute the derivative with respect to a material property component.""" + param_coords = {axis: np.asarray(spatial_data.coords[axis]) for axis in "xyz"} + eps_shape = [len(param_coords[axis]) for axis in "xyz"] + dtype_out = complex if component == "complex" else float + + E_der_dim = E_der_map.get(f"E{dim}") + if E_der_dim is None or np.all(E_der_dim.values == 0): + return np.zeros(eps_shape, dtype=dtype_out) + + field_coords = {axis: np.asarray(E_der_dim.coords[axis]) for axis in "xyz"} + values = E_der_dim.values + + def _bounds_slice(axis: NDArray, vmin: float, vmax: float, *, name: str) -> slice: + n = axis.size + i0 = int(np.searchsorted(axis, vmin, side="left")) + i1 = int(np.searchsorted(axis, vmax, side="right")) + if i1 <= i0 and n: + old = (i0, i1) + if i1 < n: + i1 = i0 + 1 # expand right + elif i0 > 0: + i0 = i1 - 1 # expand left + log.warning( + f"Empty bounds crop on '{name}' while computing CustomMedium parameter gradients " + f"(adjoint field grid -> medium grid): bounds=[{vmin!r}, {vmax!r}], " + f"grid=[{axis[0]!r}, {axis[-1]!r}] -> indices {old}; using ({i0}, {i1}).", + log_once=True, + ) + return slice(i0, i1) + + # usage + if bounds is not None: + (xmin, ymin, zmin), (xmax, ymax, zmax) = bounds + + sx = _bounds_slice(field_coords["x"], xmin, xmax, name="x") + sy = _bounds_slice(field_coords["y"], ymin, ymax, name="y") + sz = _bounds_slice(field_coords["z"], zmin, zmax, name="z") + + field_coords = {k: field_coords[k][s] for k, s in (("x", sx), ("y", sy), ("z", sz))} + values = values[sx, sy, sz, :] + + def _axis_sizes(coords: NDArray) -> NDArray: + if coords.size <= 1: + return np.array([1.0]) + mid_points = (coords[1:] + coords[:-1]) / 2.0 + dists = np.diff(mid_points) + sizes = np.zeros(coords.size) + sizes[1:-1] = dists + sizes[0] = 2 * abs(mid_points[0] - coords[0]) + sizes[-1] = 2 * abs(coords[-1] - mid_points[-1]) + return sizes + + size_x = _axis_sizes(field_coords["x"]) + size_y = _axis_sizes(field_coords["y"]) + size_z = _axis_sizes(field_coords["z"]) + scale = ( + size_x[:, None, None, None] * size_y[None, :, None, None] * size_z[None, None, :, None] + ) + np.multiply(values, scale, out=values) + + method = interp_method if interp_method is not None else self.interp_method + + def _transpose_interp_axis( + field_values: NDArray, field_coords_1d: NDArray, param_coords_1d: NDArray + ) -> NDArray: + """ + Transpose (adjoint) of 1D interpolation along one axis. + + Parameters + ---------- + field_values : np.ndarray + Array of values sampled on the field grid along this axis. + Shape: (n_field, ...rest...). + Notes: + - The first axis corresponds to `field_coords_1d`. + - The remaining axes (...rest...) are treated as batch dimensions and are + carried through unchanged. + + field_coords_1d : np.ndarray + 1D coordinates of the field grid along this axis. + Shape: (n_field,). + + param_coords_1d : np.ndarray + 1D coordinates of the parameter grid along this axis. + Shape: (n_param,). Must be sorted ascending for the searchsorted-based logic. + + Returns + ------- + param_values : np.ndarray + Field contributions accumulated onto the parameter grid along this axis. + Shape: (n_param, ...rest...). 
+ + Implementation note + ------------------- + For efficient accumulation, we flatten the trailing dimensions (...rest...) into a single + dimension so we can run a vectorized `np.add.at` on a 2D buffer of shape (n_param, n_rest), + then reshape back to (n_param, ...rest...). + """ + # Single-point parameter grid: every field sample maps to the only parameter entry, + if param_coords_1d.size == 1: + return field_values.sum(axis=0, keepdims=True) + + # Ensure parameter coordinates are sorted for searchsorted-based binning. + if np.any(param_coords_1d[1:] < param_coords_1d[:-1]): + raise ValueError("Spatial coordinates must be sorted before computing derivatives.") + param_coords_sorted = param_coords_1d + + n_param = param_coords_sorted.size + if method not in ALLOWED_INTERP_METHODS: + raise ValueError( + f"Unsupported interpolation method: {method!r}. " + f"Choose one of: {', '.join(ALLOWED_INTERP_METHODS)}." + ) + + # Flatten trailing dimensions into a single "rest" dimension for vectorized accumulation. + n_field = field_values.shape[0] + field_values_2d = field_values.reshape(n_field, -1) + + if method == "nearest": + # Midpoints define bin edges between adjacent parameter coordinates. + param_midpoints = (param_coords_sorted[1:] + param_coords_sorted[:-1]) / 2.0 + # Map each field coordinate to a nearest parameter-bin index. + param_index_nearest = np.searchsorted(param_midpoints, field_coords_1d) + + # Accumulate all field samples into their assigned parameter bins. + param_values_2d = npo.zeros( + (n_param, field_values_2d.shape[1]), dtype=field_values.dtype + ) + npo.add.at(param_values_2d, param_index_nearest, field_values_2d) + + param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:]) + return param_values + + # linear + # Find bracketing parameter indices for each field coordinate. + param_index_upper = np.searchsorted(param_coords_sorted, field_coords_1d, side="right") + param_index_upper = np.clip(param_index_upper, 1, n_param - 1) + param_index_lower = param_index_upper - 1 + + # Compute interpolation fraction within the bracketing segment. + segment_width = ( + param_coords_sorted[param_index_upper] - param_coords_sorted[param_index_lower] + ) + segment_width = np.where(segment_width == 0, 1.0, segment_width) + frac_upper = (field_coords_1d - param_coords_sorted[param_index_lower]) / segment_width + frac_upper = np.clip(frac_upper, 0.0, 1.0) + + # Weights per field sample (broadcast across the flattened trailing dimensions). + w_lower = (1.0 - frac_upper)[:, None] + w_upper = frac_upper[:, None] + + # Accumulate contributions into both bracketing parameter indices. + param_values_2d = npo.zeros( + (n_param, field_values_2d.shape[1]), dtype=field_values.dtype + ) + npo.add.at(param_values_2d, param_index_lower, field_values_2d * w_lower) + npo.add.at(param_values_2d, param_index_upper, field_values_2d * w_upper) + + param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:]) + return param_values + + def _interp_axis( + arr: NDArray, axis: int, field_axis: NDArray, param_axis: NDArray + ) -> NDArray: + """Accumulate values from the field grid onto the parameter grid along one axis. + + Moves ``axis`` to the front, applies ``_transpose_interp_axis`` (adjoint of 1D interpolation) + to map from ``field_axis`` (n_field) to ``param_axis`` (n_param), then moves the axis back. 
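+
+            A dense-matrix sanity check of the same adjoint relation (a hypothetical
+            standalone sketch: ``W`` is the 1D linear-interpolation matrix, so the
+            scatter-adds in ``_transpose_interp_axis`` accumulate ``W.T @ y``):
+
+            >>> p = np.array([0.0, 1.0, 2.0])
+            >>> f = np.array([0.25, 1.5])
+            >>> iu = np.clip(np.searchsorted(p, f, side="right"), 1, p.size - 1)
+            >>> w = (f - p[iu - 1]) / (p[iu] - p[iu - 1])
+            >>> W = np.zeros((f.size, p.size))
+            >>> W[np.arange(f.size), iu - 1] = 1 - w
+            >>> W[np.arange(f.size), iu] = w
+            >>> W.T @ np.ones(f.size)
+            array([0.75, 0.75, 0.5 ])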
+ """ + moved = np.moveaxis(arr, axis, 0) + moved = _transpose_interp_axis(moved, field_axis, param_axis) + return np.moveaxis(moved, 0, axis) + + values = _interp_axis(values, 0, field_coords["x"], param_coords["x"]) + values = _interp_axis(values, 1, field_coords["y"], param_coords["y"]) + values = _interp_axis(values, 2, field_coords["z"], param_coords["z"]) + + freqs_da = np.asarray(E_der_dim.coords["f"]) + if component == "sigma": + values = values.imag * (-1.0 / (2.0 * np.pi * freqs_da * EPSILON_0)) + elif component == "imag": + values = values.imag + elif component == "real": + values = values.real + + vjp_array = values.sum(axis=-1).reshape(eps_shape) + + # match derivative dtype to the underlying dataset + target_array = getattr(spatial_data, "values", None) + if target_array is None and hasattr(spatial_data, "data"): + target_array = spatial_data.data + if target_array is not None: + target_dtype = np.asarray(target_array).dtype + if not np.issubdtype(target_dtype, np.complexfloating): + vjp_array = np.real(vjp_array).astype(target_dtype, copy=False) + + return vjp_array + + +""" Dispersionless Medium """ + + +# PEC keyword +class PECMedium(AbstractMedium): + """Perfect electrical conductor class. + + Note + ---- + + To avoid confusion from duplicate PECs, must import ``tidy3d.PEC`` instance directly. + + + + """ + + @field_validator("modulation_spec") + @classmethod + def _validate_modulation_spec(cls, val: Optional[ModulationSpec]) -> Optional[ModulationSpec]: + """Check compatibility with modulation_spec.""" + if val is not None: + raise ValidationError( + f"A 'modulation_spec' of class {type(val).__name__} is not " + f"currently supported for medium class {cls.__name__}." + ) + return val + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + # return something like frequency with value of pec_val + 0j + return 0j * frequency + pec_val + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + """ + return 1.0 + + @cached_property + def is_pec(self) -> bool: + """Whether the medium is a PEC.""" + return True + + +# PEC builtin instance +PEC = PECMedium(name="PEC") + + +# PMC keyword +class PMCMedium(AbstractMedium): + """Perfect magnetic conductor class. + + Note + ---- + + To avoid confusion from duplicate PMCs, must import ``tidy3d.PMC`` instance directly. + + + + """ + + @field_validator("modulation_spec") + @classmethod + def _validate_modulation_spec(cls, val: Optional[ModulationSpec]) -> Optional[ModulationSpec]: + """Check compatibility with modulation_spec.""" + if val is not None: + raise ValidationError( + f"A 'modulation_spec' of class {type(val).__name__} is not " + f"currently supported for medium class {cls.__name__}." + ) + return val + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + # permittivity of a PMC. + return 1.0 + 0j + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. 
+ """ + return 1.0 + + @cached_property + def is_pmc(self) -> bool: + """Whether the medium is a PMC.""" + return True + + +# PEC builtin instance +PMC = PMCMedium(name="PMC") + + +class Medium(AbstractMedium): + """Dispersionless medium. Mediums define the optical properties of the materials within the simulation. + + Notes + ----- + + In a dispersion-less medium, the displacement field :math:`D(t)` reacts instantaneously to the applied + electric field :math:`E(t)`. + + .. math:: + + D(t) = \\epsilon E(t) + + Example + ------- + >>> dielectric = Medium(permittivity=4.0, name='my_medium') + >>> eps = dielectric.eps_model(200e12) + + See Also + -------- + + **Notebooks** + * `Introduction on Tidy3D working principles <../../notebooks/Primer.html#Mediums>`_ + * `Index <../../notebooks/docs/features/medium.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + + **GUI** + * `Mediums `_ + + """ + + permittivity: TracedFloat = Field( + 1.0, ge=1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY + ) + + conductivity: TracedFloat = Field( + 0.0, + title="Conductivity", + description="Electric conductivity. Defined such that the imaginary part of the complex " + "permittivity at angular frequency omega is given by conductivity/omega.", + units=CONDUCTIVITY, + ) + + @model_validator(mode="after") + def _passivity_validation(self) -> Self: + """Assert passive medium if ``allow_gain`` is False.""" + val = self.conductivity + if not self.allow_gain and val < 0: + raise ValidationError( + "For passive medium, 'conductivity' must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, and are likely to diverge." + ) + return self + + @model_validator(mode="after") + def _permittivity_modulation_validation(self) -> Self: + """Assert modulated permittivity cannot be <= 0.""" + val = self.permittivity + modulation = self.modulation_spec + if modulation is None or modulation.permittivity is None: + return self + + min_eps_inf = np.min(_get_numpy_array(val)) + if min_eps_inf - modulation.permittivity.max_modulation <= 0: + raise ValidationError( + "The minimum permittivity value with modulation applied was found to be negative." + ) + return self + + @model_validator(mode="after") + def _passivity_modulation_validation(self) -> Self: + """Assert passive medium if ``allow_gain`` is False.""" + val = self.conductivity + modulation = self.modulation_spec + if modulation is None or modulation.conductivity is None: + return self + + min_sigma = np.min(_get_numpy_array(val)) + if not self.allow_gain and min_sigma - modulation.conductivity.max_modulation < 0: + raise ValidationError( + "For passive medium, 'conductivity' must be non-negative at any time." + "With conductivity modulation, this medium can sometimes be active. " + "Please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + return self + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + + For dispersiveless medium, it equals ``sqrt(permittivity)``. 
+ """ + permittivity = self.permittivity + if self.modulation_spec is not None and self.modulation_spec.permittivity is not None: + permittivity -= self.modulation_spec.permittivity.max_modulation + n, _ = self.eps_complex_to_nk(permittivity) + return n + + @staticmethod + def _eps_model(permittivity: float, conductivity: float, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + + return AbstractMedium.eps_sigma_to_eps_complex(permittivity, conductivity, frequency) + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + + return self._eps_model(self.permittivity, self.conductivity, frequency) + + @classmethod + def from_nk(cls, n: float, k: float, freq: float, **kwargs: Any) -> Self: + """Convert ``n`` and ``k`` values at frequency ``freq`` to :class:`.Medium`. + + Parameters + ---------- + n : float + Real part of refractive index. + k : float = 0 + Imaginary part of refrative index. + freq : float + Frequency to evaluate permittivity at (Hz). + kwargs: dict + Keyword arguments passed to the medium construction. + + Returns + ------- + :class:`.Medium` + medium containing the corresponding ``permittivity`` and ``conductivity``. + """ + eps, sigma = AbstractMedium.nk_to_eps_sigma(n, k, freq) + if eps < 1: + raise ValidationError( + "Dispersiveless medium must have 'permittivity>=1`. " + "Please use 'Lorentz.from_nk()' to covert to a Lorentz medium, or the utility " + "function 'td.medium_from_nk()' to automatically return the proper medium type." + ) + return cls(permittivity=eps, conductivity=sigma, **kwargs) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + + # get vjps w.r.t. permittivity and conductivity of the bulk + vjps_volume = self._derivative_eps_sigma_volume( + E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds + ) + + # store the fields asked for by ``field_paths`` + derivative_map = {} + for field_path in derivative_info.paths: + field_name, *_ = field_path + if field_name in vjps_volume: + derivative_map[field_path] = vjps_volume[field_name] + + return derivative_map + + def _derivative_eps_sigma_volume( + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound + ) -> dict[str, xr.DataArray]: + """Get the derivative w.r.t permittivity and conductivity in the volume.""" + + vjp_eps_complex = self._derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds) + + values = vjp_eps_complex.values + + # vjp of eps_complex_to_eps_sigma + omegas = 2 * np.pi * vjp_eps_complex.coords["f"].values + eps_vjp = np.real(values) + sigma_vjp = -np.imag(values) / omegas / EPSILON_0 + + eps_vjp = np.sum(eps_vjp) + sigma_vjp = np.sum(sigma_vjp) + + return {"permittivity": eps_vjp, "conductivity": sigma_vjp} + + def _derivative_eps_complex_volume( + self, E_der_map: ElectromagneticFieldDataset, bounds: Bound + ) -> xr.DataArray: + """Get the derivative w.r.t complex-valued permittivity in the volume.""" + + vjp_value = None + for field_name in ("Ex", "Ey", "Ez"): + fld = E_der_map[field_name] + vjp_value_fld = integrate_within_bounds( + arr=fld, + dims=("x", "y", "z"), + bounds=bounds, + ) + if vjp_value is None: + vjp_value = vjp_value_fld + else: + vjp_value += vjp_value_fld + + return vjp_value + + +class CustomIsotropicMedium(AbstractCustomMedium, Medium): + """:class:`.Medium` with user-supplied permittivity distribution. 
+ (This class is for internal use in v2.0; it will be renamed as `CustomMedium` in v3.0.) + + Example + ------- + >>> Nx, Ny, Nz = 10, 9, 8 + >>> X = np.linspace(-1, 1, Nx) + >>> Y = np.linspace(-1, 1, Ny) + >>> Z = np.linspace(-1, 1, Nz) + >>> coords = dict(x=X, y=Y, z=Z) + >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> dielectric = CustomIsotropicMedium(permittivity=permittivity, conductivity=conductivity) + >>> eps = dielectric.eps_model(200e12) + """ + + permittivity: CustomSpatialDataTypeAnnotated = Field( + title="Permittivity", + description="Relative permittivity.", + units=PERMITTIVITY, + ) + + conductivity: Optional[CustomSpatialDataTypeAnnotated] = Field( + None, + title="Conductivity", + description="Electric conductivity. Defined such that the imaginary part of the complex " + "permittivity at angular frequency omega is given by conductivity/omega.", + units=CONDUCTIVITY, + ) + + _no_nans = validate_no_nans("permittivity", "conductivity") + + @field_validator("permittivity") + @classmethod + def _eps_inf_greater_no_less_than_one( + cls, val: Optional[CustomSpatialDataTypeAnnotated] + ) -> Optional[CustomSpatialDataTypeAnnotated]: + """Assert any eps_inf must be >=1""" + + if not CustomIsotropicMedium._validate_isreal_dataarray(val): + raise SetupError("'permittivity' must be real.") + + if np.any(_get_numpy_array(val) < 1): + raise SetupError("'permittivity' must be no less than one.") + + return val + + @model_validator(mode="after") + def _conductivity_real_and_correct_shape(self) -> Self: + """Assert conductivity is real and of right shape.""" + val = self.conductivity + + if val is None: + return self + + if not CustomIsotropicMedium._validate_isreal_dataarray(val): + raise SetupError("'conductivity' must be real.") + + if not _check_same_coordinates(self.permittivity, val): + raise SetupError("'permittivity' and 'conductivity' must have the same coordinates.") + return self + + @model_validator(mode="after") + def _passivity_validation(self) -> Self: + """Assert passive medium if ``allow_gain`` is False.""" + val = self.conductivity + if val is None: + return self + if not self.allow_gain and np.any(_get_numpy_array(val) < 0): + raise ValidationError( + "For passive medium, 'conductivity' must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, and are likely to diverge." + ) + return self + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + if self.conductivity is None: + return self.permittivity.is_uniform + return self.permittivity.is_uniform and self.conductivity.is_uniform + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + + For dispersiveless medium, it equals ``sqrt(permittivity)``. 
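+
+        Example
+        -------
+        A minimal sketch reusing ``dielectric`` from the class example above
+        (uniform unit permittivity, so the spatial minimum is 1):
+
+        >>> dielectric.n_cfl  # doctest: +SKIP
+        1.0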
+ """ + permittivity = np.min(_get_numpy_array(self.permittivity)) + if self.modulation_spec is not None and self.modulation_spec.permittivity is not None: + permittivity -= self.modulation_spec.permittivity.max_modulation + n, _ = self.eps_complex_to_nk(permittivity) + return n + + @cached_property + def is_isotropic(self) -> bool: + """Whether the medium is isotropic.""" + return True + + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[ + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset` + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset` + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset` + ], + ] + The permittivity evaluated at ``frequency``. + """ + conductivity = self.conductivity + if conductivity is None: + conductivity = _zeros_like(self.permittivity) + eps = self.eps_sigma_to_eps_complex(self.permittivity, conductivity, frequency) + return (eps, eps, eps) + + def _sel_custom_data_inside(self, bounds: Bound) -> Self: + """Return a new custom medium that contains the minimal amount data necessary to cover + a spatial region defined by ``bounds``. + + + Parameters + ---------- + bounds : tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + CustomMedium + CustomMedium with reduced data. + """ + if not self.permittivity.does_cover(bounds=bounds): + log.warning( + "Permittivity spatial data array does not fully cover the requested region." + ) + perm_reduced = self.permittivity.sel_inside(bounds=bounds) + cond_reduced = None + if self.conductivity is not None: + if not self.conductivity.does_cover(bounds=bounds): + log.warning( + "Conductivity spatial data array does not fully cover the requested region." + ) + cond_reduced = self.conductivity.sel_inside(bounds=bounds) + + return self.updated_copy( + permittivity=perm_reduced, + conductivity=cond_reduced, + ) + + +class CustomMedium(AbstractCustomMedium): + """:class:`.Medium` with user-supplied permittivity distribution. + + Example + ------- + >>> Nx, Ny, Nz = 10, 9, 8 + >>> X = np.linspace(-1, 1, Nx) + >>> Y = np.linspace(-1, 1, Ny) + >>> Z = np.linspace(-1, 1, Nz) + >>> coords = dict(x=X, y=Y, z=Z) + >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> dielectric = CustomMedium(permittivity=permittivity, conductivity=conductivity) + >>> eps = dielectric.eps_model(200e12) + """ + + eps_dataset: Optional[PermittivityDataset] = Field( + None, + title="Permittivity Dataset", + description="[To be deprecated] User-supplied dataset containing complex-valued " + "permittivity as a function of space. 
Permittivity distribution over the Yee-grid " + "will be interpolated based on ``interp_method``.", + ) + + permittivity: Optional[CustomSpatialDataTypeAnnotated] = Field( + None, + title="Permittivity", + description="Spatial profile of relative permittivity.", + units=PERMITTIVITY, + ) + + conductivity: Optional[CustomSpatialDataTypeAnnotated] = Field( + None, + title="Conductivity", + description="Spatial profile Electric conductivity. Defined such " + "that the imaginary part of the complex permittivity at angular " + "frequency omega is given by conductivity/omega.", + units=CONDUCTIVITY, + ) + + _no_nans = validate_no_nans("eps_dataset", "permittivity", "conductivity") + + @model_validator(mode="before") + @classmethod + def _warn_if_none(cls, data: dict) -> dict: + """Warn if the data array fails to load, and return a vacuum medium.""" + fail_load = False + if cls._not_loaded(data.get("permittivity")): + log.warning( + "Loading 'permittivity' without data; constructing a vacuum medium instead." + ) + fail_load = True + if cls._not_loaded(data.get("conductivity")): + log.warning( + "Loading 'conductivity' without data; constructing a vacuum medium instead." + ) + fail_load = True + eps_ds = data.get("eps_dataset") + if isinstance(eps_ds, dict): + if any(isinstance(v, str) and v in DATA_ARRAY_MAP for v in eps_ds.values()): + log.warning( + "Loading 'eps_dataset' without data; constructing a vacuum medium instead." + ) + fail_load = True + if fail_load: + data["permittivity"] = SpatialDataArray( + np.ones((1, 1, 1)), coords={"x": [0], "y": [0], "z": [0]} + ) + return data + + @model_validator(mode="after") + def _deprecation_dataset(self) -> Self: + """Raise deprecation warning if dataset supplied and convert to dataset.""" + + eps_dataset = self.eps_dataset + permittivity = self.permittivity + conductivity = self.conductivity + + # Incomplete custom medium definition. + if eps_dataset is None and permittivity is None and conductivity is None: + raise SetupError("Missing spatial profiles of 'permittivity' or 'eps_dataset'.") + if eps_dataset is None and permittivity is None: + raise SetupError("Missing spatial profiles of 'permittivity'.") + + # Definition racing + if eps_dataset is not None and (permittivity is not None or conductivity is not None): + raise SetupError( + "Please either define 'permittivity' and 'conductivity', or 'eps_dataset', " + "but not both simultaneously." + ) + + if eps_dataset is None: + return self + + # TODO: sometime before 3.0, uncomment these lines to warn users to start using new API + # if isinstance(eps_dataset, dict): + # eps_components = [eps_dataset[f"eps_{dim}{dim}"] for dim in "xyz"] + # else: + # eps_components = [eps_dataset.eps_xx, eps_dataset.eps_yy, eps_dataset.eps_zz] + + # is_isotropic = eps_components[0] == eps_components[1] == eps_components[2] + + # if is_isotropic: + # # deprecation warning for isotropic custom medium + # log.warning( + # "For spatially varying isotropic medium, the 'eps_dataset' field " + # "is being replaced by 'permittivity' and 'conductivity' in v3.0. " + # "We recommend you change your scripts to be compatible with the new API." + # ) + # else: + # # deprecation warning for anisotropic custom medium + # log.warning( + # "For spatially varying anisotropic medium, this class is being replaced " + # "by 'CustomAnisotropicMedium' in v3.0. " + # "We recommend you change your scripts to be compatible with the new API." 
+ # ) + + return self + + @field_validator("eps_dataset") + @classmethod + def _eps_dataset_single_frequency( + cls, val: Optional[PermittivityDataset] + ) -> Optional[PermittivityDataset]: + """Assert only one frequency supplied.""" + if val is None: + return val + + for name, eps_dataset_component in val.field_components.items(): + freqs = eps_dataset_component.f + if len(freqs) != 1: + raise SetupError( + f"'eps_dataset.{name}' must have a single frequency, " + f"but it contains {len(freqs)} frequencies." + ) + return val + + @model_validator(mode="after") + def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(self) -> Self: + """Assert any eps_inf must be >=1""" + val = self.eps_dataset + if val is None: + return self + modulation = self.modulation_spec + + for comp in ["eps_xx", "eps_yy", "eps_zz"]: + eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma( + val.field_components[comp], val.field_components[comp].f + ) + if np.any(_get_numpy_array(eps_real) < 1): + raise SetupError( + "Permittivity at infinite frequency at any spatial point " + "must be no less than one." + ) + + if modulation is not None and modulation.permittivity is not None: + if np.any(_get_numpy_array(eps_real) - modulation.permittivity.max_modulation <= 0): + raise ValidationError( + "The minimum permittivity value with modulation applied " + "was found to be negative." + ) + + if not self.allow_gain and np.any(_get_numpy_array(sigma) < 0): + raise ValidationError( + "For passive medium, imaginary part of permittivity must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + + if ( + not self.allow_gain + and modulation is not None + and modulation.conductivity is not None + and np.any(_get_numpy_array(sigma) - modulation.conductivity.max_modulation <= 0) + ): + raise ValidationError( + "For passive medium, imaginary part of permittivity must be non-negative " + "at any time. " + "With conductivity modulation, this medium can sometimes be active. " + "Please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + return self + + @model_validator(mode="after") + def _eps_inf_greater_no_less_than_one(self) -> Self: + """Assert any eps_inf must be >=1""" + val = self.permittivity + if val is None: + return self + + if not CustomMedium._validate_isreal_dataarray(val): + raise SetupError("'permittivity' must be real.") + + if np.any(_get_numpy_array(val) < 1): + raise SetupError("'permittivity' must be no less than one.") + + modulation = self.modulation_spec + if modulation is None or modulation.permittivity is None: + return self + + if np.any(_get_numpy_array(val) - modulation.permittivity.max_modulation <= 0): + raise ValidationError( + "The minimum permittivity value with modulation applied was found to be negative." + ) + + return self + + @model_validator(mode="after") + def _conductivity_non_negative_correct_shape(self) -> Self: + """Assert conductivity>=0""" + val = self.conductivity + + if val is None: + return self + + if not CustomMedium._validate_isreal_dataarray(val): + raise SetupError("'conductivity' must be real.") + + if not self.allow_gain and np.any(_get_numpy_array(val) < 0): + raise ValidationError( + "For passive medium, 'conductivity' must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. 
" + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + + if not _check_same_coordinates(self.permittivity, val): + raise SetupError("'permittivity' and 'conductivity' must have the same coordinates.") + + return self + + @model_validator(mode="after") + def _passivity_modulation_validation(self) -> Self: + """Assert passive medium at any time during modulation if ``allow_gain`` is False.""" + val = self.conductivity + + # validated already when the data is supplied through `eps_dataset` + if self.eps_dataset: + return self + + # permittivity defined with ``permittivity`` and ``conductivity`` + modulation = self.modulation_spec + if self.allow_gain or modulation is None or modulation.conductivity is None: + return self + if val is None or np.any( + _get_numpy_array(val) - modulation.conductivity.max_modulation < 0 + ): + raise ValidationError( + "For passive medium, 'conductivity' must be non-negative at any time. " + "With conductivity modulation, this medium can sometimes be active. " + "Please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + return self + + @field_validator("permittivity", "conductivity") + @classmethod + def _check_permittivity_conductivity_interpolate( + cls, val: Optional[CustomSpatialDataType], info: FieldValidationInfo + ) -> Optional[CustomSpatialDataType]: + """Check that the custom medium 'SpatialDataArrays' can be interpolated.""" + + if isinstance(val, SpatialDataArray): + val._interp_validator(info.field_name) + + return val + + @cached_property + def is_isotropic(self) -> bool: + """Check if the medium is isotropic or anisotropic.""" + if self.eps_dataset is None: + return True + if self.eps_dataset.eps_xx == self.eps_dataset.eps_yy == self.eps_dataset.eps_zz: + return True + return False + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + return self._medium.is_spatially_uniform + + @cached_property + def _permittivity_sorted(self) -> SpatialDataArray | None: + """Cached copy of permittivity sorted along spatial axes.""" + if self.permittivity is None: + return None + return self.permittivity._spatially_sorted + + @cached_property + def _conductivity_sorted(self) -> SpatialDataArray | None: + """Cached copy of conductivity sorted along spatial axes.""" + if self.conductivity is None: + return None + return self.conductivity._spatially_sorted + + @cached_property + def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]: + """Cached copies of dataset components sorted along spatial axes.""" + if self.eps_dataset is None: + return {} + return { + key: comp._spatially_sorted for key, comp in self.eps_dataset.field_components.items() + } + + @cached_property + def freqs(self) -> ArrayFloat: + """float array of frequencies. + This field is to be deprecated in v3.0. + """ + # return dummy values in this case + if self.eps_dataset is None: + return np.array([0, 0, 0]) + return np.array( + [ + self.eps_dataset.eps_xx.coords["f"], + self.eps_dataset.eps_yy.coords["f"], + self.eps_dataset.eps_zz.coords["f"], + ] + ) + + @cached_property + def _medium(self) -> CustomAnisotropicMedium: + """Internal representation in the form of + either `CustomIsotropicMedium` or `CustomAnisotropicMedium`. 
+ """ + self_dict = self.model_dump(exclude={"type", "eps_dataset"}) + # isotropic + if self.eps_dataset is None: + self_dict.update({"permittivity": self.permittivity, "conductivity": self.conductivity}) + return CustomIsotropicMedium.model_validate(self_dict) + + def get_eps_sigma(eps_complex: SpatialDataArray, freq: float) -> tuple: + """Convert a complex permittivity to real permittivity and conductivity.""" + eps_values = np.array(eps_complex.values) + + eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma(eps_values, freq) + coords = eps_complex.coords + + eps_real = ScalarFieldDataArray(eps_real, coords=coords) + sigma = ScalarFieldDataArray(sigma, coords=coords) + + eps_real = SpatialDataArray(eps_real.squeeze(dim="f", drop=True)) + sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True)) + + return eps_real, sigma + + # isotropic, but with `eps_dataset` + if self.is_isotropic: + eps_complex = self.eps_dataset.eps_xx + eps_real, sigma = get_eps_sigma(eps_complex, freq=self.freqs[0]) + + self_dict.update({"permittivity": eps_real, "conductivity": sigma}) + return CustomIsotropicMedium.model_validate(self_dict) + + # anisotropic + mat_comp = {"interp_method": self.interp_method} + for freq, comp in zip(self.freqs, ["xx", "yy", "zz"]): + eps_complex = self.eps_dataset.field_components["eps_" + comp] + eps_real, sigma = get_eps_sigma(eps_complex, freq=freq) + + comp_dict = self_dict.copy() + comp_dict.update({"permittivity": eps_real, "conductivity": sigma}) + mat_comp.update({comp: CustomIsotropicMedium.model_validate(comp_dict)}) + return CustomAnisotropicMediumInternal(**mat_comp) + + def _interp_method(self, comp: Axis) -> InterpMethod: + """Interpolation method applied to comp.""" + return self._medium._interp_method(comp) + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl```. + + For dispersiveless custom medium, it equals ``min[sqrt(eps_inf)]``, where ``min`` + is performed over all components and spatial points. + """ + return self._medium.n_cfl + + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. () + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[ + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + ] + The permittivity evaluated at ``frequency``. + """ + return self._medium.eps_dataarray_freq(frequency) + + def eps_diagonal_on_grid( + self, + frequency: float, + coords: Coords, + ) -> tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: + """Spatial profile of main diagonal of the complex-valued permittivity + at ``frequency`` interpolated at the supplied coordinates. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + coords : :class:`.Coords` + The grid point coordinates over which interpolation is performed. 
+
+        Returns
+        -------
+        tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]
+            The complex-valued permittivity tensor at ``frequency`` interpolated
+            at the supplied coordinates.
+        """
+        return self._medium.eps_diagonal_on_grid(frequency, coords)
+
+    @ensure_freq_in_range
+    def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]:
+        """Main diagonal of the complex-valued permittivity tensor
+        at ``frequency``. Spatially, we take max{|eps|}, so that autoMesh generation
+        works appropriately.
+        """
+        return self._medium.eps_diagonal(frequency)
+
+    @ensure_freq_in_range
+    def eps_model(self, frequency: float) -> complex:
+        """Spatial and polarization average of complex-valued permittivity
+        as a function of frequency.
+        """
+        return self._medium.eps_model(frequency)
+
+    @classmethod
+    def from_eps_raw(
+        cls,
+        eps: Union[ScalarFieldDataArray, CustomSpatialDataType],
+        freq: Optional[float] = None,
+        interp_method: InterpMethod = "nearest",
+        **kwargs: Any,
+    ) -> Self:
+        """Construct a :class:`.CustomMedium` from datasets containing raw permittivity values.
+
+        Parameters
+        ----------
+        eps : Union[
+                :class:`.SpatialDataArray`,
+                :class:`.ScalarFieldDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ]
+            Dataset containing complex-valued permittivity as a function of space.
+        freq : float, optional
+            Frequency at which ``eps`` is defined.
+        interp_method : :class:`.InterpMethod`, optional
+            Interpolation method to obtain permittivity values that are not supplied
+            at the Yee grids.
+        kwargs: dict
+            Keyword arguments passed to the medium construction.
+
+        Notes
+        -----
+
+        For lossy medium that has a complex-valued ``eps``, if ``eps`` is supplied through
+        :class:`.SpatialDataArray`, which doesn't contain frequency information,
+        the ``freq`` kwarg will be used to evaluate the permittivity and conductivity.
+        Alternatively, ``eps`` can be supplied through :class:`.ScalarFieldDataArray`,
+        which contains a frequency coordinate.
+        In this case, leave ``freq`` kwarg as the default of ``None``.
+
+        Returns
+        -------
+        :class:`.CustomMedium`
+            Medium containing the spatially varying permittivity data.
+        """
+        if isinstance(eps, CustomSpatialDataType.__args__):
+            # purely real, no need to know `freq`
+            if CustomMedium._validate_isreal_dataarray(eps):
+                return cls(permittivity=eps, interp_method=interp_method, **kwargs)
+            # complex permittivity, needs to know `freq`
+            if freq is None:
+                raise SetupError(
+                    "For a complex 'eps', 'freq' at which 'eps' is defined must be supplied",
+                )
+            eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma(eps, freq)
+            return cls(
+                permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs
+            )
+
+        # eps is ScalarFieldDataArray
+        # contradictory definition of frequency
+        freq_data = eps.coords["f"].data[0]
+        if freq is not None and not isclose(freq, freq_data):
+            raise SetupError(
+                "'freq' value is inconsistent with the coordinate 'f' "
+                "in 'eps' DataArray. It's unclear at which frequency 'eps' "
+                "is defined. Please leave 'freq=None' to use the frequency "
+                "value in the DataArray."
+            )
+        eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma(eps, freq_data)
+        eps_real = SpatialDataArray(eps_real.squeeze(dim="f", drop=True))
+        sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True))
+        return cls(permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs)
+
+    @classmethod
+    def from_nk(
+        cls,
+        n: Union[ScalarFieldDataArray, CustomSpatialDataType],
+        k: Optional[Union[ScalarFieldDataArray, CustomSpatialDataType]] = None,
+        freq: Optional[float] = None,
+        interp_method: InterpMethod = "nearest",
+        **kwargs: Any,
+    ) -> Self:
+        """Construct a :class:`.CustomMedium` from datasets containing n and k values.
+
+        Parameters
+        ----------
+        n : Union[
+                :class:`.SpatialDataArray`,
+                :class:`.ScalarFieldDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ]
+            Real part of refractive index.
+        k : Union[
+                :class:`.SpatialDataArray`,
+                :class:`.ScalarFieldDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ], optional
+            Imaginary part of refractive index for lossy medium.
+        freq : float, optional
+            Frequency at which ``n`` and ``k`` are defined.
+        interp_method : :class:`.InterpMethod`, optional
+            Interpolation method to obtain permittivity values that are not supplied
+            at the Yee grids.
+        kwargs: dict
+            Keyword arguments passed to the medium construction.
+
+        Note
+        ----
+        For lossy medium, if both ``n`` and ``k`` are supplied through
+        :class:`.SpatialDataArray`, which doesn't contain frequency information,
+        the ``freq`` kwarg will be used to evaluate the permittivity and conductivity.
+        Alternatively, ``n`` and ``k`` can be supplied through :class:`.ScalarFieldDataArray`,
+        which contains a frequency coordinate.
+        In this case, leave ``freq`` kwarg as the default of ``None``.
+
+        Returns
+        -------
+        :class:`.CustomMedium`
+            Medium containing the spatially varying permittivity data.
+        """
+        # lossless
+        if k is None:
+            if isinstance(n, ScalarFieldDataArray):
+                n = SpatialDataArray(n.squeeze(dim="f", drop=True))
+            freq = 0  # dummy value
+            eps_real, _ = CustomMedium.nk_to_eps_sigma(n, 0 * n, freq)
+            return cls(permittivity=eps_real, interp_method=interp_method, **kwargs)
+
+        # lossy case
+        if not _check_same_coordinates(n, k):
+            raise SetupError("'n' and 'k' must be of the same type and must have the same coordinates.")
+
+        # k is a SpatialDataArray
+        if isinstance(k, CustomSpatialDataType.__args__):
+            if freq is None:
+                raise SetupError(
+                    "For a lossy medium, must supply 'freq' at which to convert 'n' "
+                    "and 'k' to a complex valued permittivity."
+                )
+            eps_real, sigma = CustomMedium.nk_to_eps_sigma(n, k, freq)
+            return cls(
+                permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs
+            )
+
+        # k is a ScalarFieldDataArray
+        freq_data = k.coords["f"].data[0]
+        if freq is not None and not isclose(freq, freq_data):
+            raise SetupError(
+                "'freq' value is inconsistent with the coordinate 'f' "
+                "in 'k' DataArray. It's unclear at which frequency 'k' "
+                "is defined. Please leave 'freq=None' to use the frequency "
+                "value in the DataArray."
+            )
+
+        eps_real, sigma = CustomMedium.nk_to_eps_sigma(n, k, freq_data)
+        eps_real = SpatialDataArray(eps_real.squeeze(dim="f", drop=True))
+        sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True))
+        return cls(permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs)
+
+    def grids(self, bounds: Bound) -> dict[str, Grid]:
+        """Make a :class:`.Grid` corresponding to the data in each ``eps_ii`` component.
+        The min and max coordinates along each dimension are bounded by ``bounds``."""
+
+        rmin, rmax = bounds
+        pt_mins = dict(zip("xyz", rmin))
+        pt_maxs = dict(zip("xyz", rmax))
+
+        def make_grid(scalar_field: Union[ScalarFieldDataArray, SpatialDataArray]) -> Grid:
+            """Make a grid for a single dataset."""
+
+            def make_bound_coords(coords: ArrayFloat, pt_min: float, pt_max: float) -> list[float]:
+                """Convert user supplied coords into boundary coords to use in :class:`.Grid`."""
+
+                # get coordinates of the boundaries halfway between user-supplied data
+                coord_bounds = (coords[1:] + coords[:-1]) / 2.0
+
+                # re-set coord boundaries that lie outside geometry bounds to the boundary (0 vol.)
+                coord_bounds[coord_bounds <= pt_min] = pt_min
+                coord_bounds[coord_bounds >= pt_max] = pt_max
+
+                # add the geometry bounds in explicitly
+                return [pt_min, *coord_bounds.tolist(), pt_max]
+
+            # grab user supplied data along this dimension
+            coords = {key: np.array(val) for key, val in scalar_field.coords.items()}
+            spatial_coords = {key: coords[key] for key in "xyz"}
+
+            # convert each spatial coord to boundary coords
+            bound_coords = {}
+            for key, coords in spatial_coords.items():
+                pt_min = pt_mins[key]
+                pt_max = pt_maxs[key]
+                bound_coords[key] = make_bound_coords(coords=coords, pt_min=pt_min, pt_max=pt_max)
+
+            # construct grid
+            boundaries = Coords(**bound_coords)
+            return Grid(boundaries=boundaries)
+
+        grids = {}
+        for field_name in ("eps_xx", "eps_yy", "eps_zz"):
+            # grab user supplied data along this dimension
+            scalar_field = self.eps_dataset.field_components[field_name]
+
+            # feed it to make_grid
+            grids[field_name] = make_grid(scalar_field)
+
+        return grids
+
+    def _sel_custom_data_inside(self, bounds: Bound) -> Self:
+        """Return a new custom medium that contains the minimal amount of data necessary to cover
+        a spatial region defined by ``bounds``.
+
+        Parameters
+        ----------
+        bounds : tuple[float, float, float], tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+        Returns
+        -------
+        CustomMedium
+            CustomMedium with reduced data.
+        """
+
+        perm_reduced = None
+        if self.permittivity is not None:
+            if not self.permittivity.does_cover(bounds=bounds):
+                log.warning(
+                    "Permittivity spatial data array does not fully cover the requested region."
+                )
+            perm_reduced = self.permittivity.sel_inside(bounds=bounds)
+
+        cond_reduced = None
+        if self.conductivity is not None:
+            if not self.conductivity.does_cover(bounds=bounds):
+                log.warning(
+                    "Conductivity spatial data array does not fully cover the requested region."
+                )
+            cond_reduced = self.conductivity.sel_inside(bounds=bounds)
+
+        eps_reduced = None
+        if self.eps_dataset is not None:
+            eps_reduced_dict = {}
+            for key, comp in self.eps_dataset.field_components.items():
+                if not comp.does_cover(bounds=bounds):
+                    log.warning(
+                        f"{key} spatial data array does not fully cover the requested region."
+                    )
+                eps_reduced_dict[key] = comp.sel_inside(bounds=bounds)
+            eps_reduced = PermittivityDataset(**eps_reduced_dict)
+
+        return self.updated_copy(
+            permittivity=perm_reduced,
+            conductivity=cond_reduced,
+            eps_dataset=eps_reduced,
+        )
+
+    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+        """Compute the adjoint derivatives for this object."""
+
+        vjps = {}
+
+        for field_path in derivative_info.paths:
+            if field_path[0] == "permittivity":
+                spatial_data = self._permittivity_sorted
+                if spatial_data is None:
+                    continue
+                vjp_array = 0.0
+                for dim in "xyz":
+                    vjp_array += self._derivative_field_cmp_custom(
+                        E_der_map=derivative_info.E_der_map,
+                        spatial_data=spatial_data,
+                        dim=dim,
+                        freqs=derivative_info.frequencies,
+                        bounds=derivative_info.bounds_intersect,
+                        component="real",
+                    )
+                vjps[field_path] = vjp_array
+
+            elif field_path[0] == "conductivity":
+                spatial_data = self._conductivity_sorted
+                if spatial_data is None:
+                    continue
+                vjp_array = 0.0
+                for dim in "xyz":
+                    vjp_array += self._derivative_field_cmp_custom(
+                        E_der_map=derivative_info.E_der_map,
+                        spatial_data=spatial_data,
+                        dim=dim,
+                        freqs=derivative_info.frequencies,
+                        bounds=derivative_info.bounds_intersect,
+                        component="sigma",
+                    )
+                vjps[field_path] = vjp_array
+
+            elif field_path[0] == "eps_dataset":
+                key = field_path[1]
+                spatial_data = self._eps_components_sorted.get(key)
+                if spatial_data is None:
+                    continue
+                dim = key[-1]
+                vjps[field_path] = self._derivative_field_cmp_custom(
+                    E_der_map=derivative_info.E_der_map,
+                    spatial_data=spatial_data,
+                    dim=dim,
+                    freqs=derivative_info.frequencies,
+                    bounds=derivative_info.bounds_intersect,
+                    component="complex",
+                )
+            else:
+                raise NotImplementedError(
+                    f"No derivative defined for 'CustomMedium' field: {field_path}."
+                )
+
+        return vjps
+
+
+""" Dispersive Media """
+
+
+class DispersiveMedium(AbstractMedium, ABC):
+    """
+    A medium with dispersion: field propagation characteristics depend on frequency.
+
+    Notes
+    -----
+
+    In dispersive mediums, the displacement field :math:`D(t)` depends on the electric field
+    at earlier times :math:`E(t')`, weighted by the time-dependent permittivity :math:`\\epsilon`.
+
+    .. math::
+
+        D(t) = \\int \\epsilon(t - t') E(t') \\, dt'
+
+    Dispersive mediums can be defined in three ways:
+
+    - Imported from our `material library <../material_library.html>`_.
+    - Defined directly by specifying the parameters in the `various supplied dispersive models <../mediums.html>`_.
+    - Fitted to optical n-k data using the `dispersion fitting tool plugin <../plugins/dispersion.html>`_.
+
+    It is important to keep in mind that dispersive materials are inevitably slower to simulate than their
+    dispersion-less counterparts, with complexity increasing with the number of poles included in the dispersion
+    model. For simulations with a narrow range of frequencies of interest, it may sometimes be faster to define
+    the material through its real and imaginary refractive index at the center frequency.
+
+    See Also
+    --------
+
+    :class:`CustomPoleResidue`:
+        A spatially varying dispersive medium described by the pole-residue pair model.
+ + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + @staticmethod + def _permittivity_modulation_validation() -> Callable[[T], T]: + """Assert modulated permittivity cannot be <= 0 at any time.""" + + @model_validator(mode="after") + def _validate_permittivity_modulation(self: T) -> T: + """Assert modulated permittivity cannot be <= 0.""" + val = self.eps_inf + modulation = self.modulation_spec + if modulation is None or modulation.permittivity is None: + return self + + min_eps_inf = np.min(_get_numpy_array(val)) + if min_eps_inf - modulation.permittivity.max_modulation <= 0: + raise ValidationError( + "The minimum permittivity value with modulation applied was found to be negative." + ) + return self + + return _validate_permittivity_modulation + + @staticmethod + def _conductivity_modulation_validation() -> Callable[[T], T]: + """Assert passive medium at any time if not ``allow_gain``.""" + + @model_validator(mode="after") + def _validate_conductivity_modulation(self: T) -> T: + """With conductivity modulation, the medium can exhibit gain during the cycle. + So `allow_gain` must be True when the conductivity is modulated. + """ + val = self.modulation_spec + if val is None or val.conductivity is None: + return self + + if not self.allow_gain: + raise ValidationError( + "For passive medium, 'conductivity' must be non-negative at any time. " + "With conductivity modulation, this medium can sometimes be active. " + "Please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, and are likely to diverge." + ) + return self + + return _validate_conductivity_modulation + + @abstractmethod + def _pole_residue_dict(self) -> dict: + """Dict representation of Medium as a pole-residue model.""" + + @cached_property + def pole_residue(self) -> PoleResidue: + """Representation of Medium as a pole-residue model.""" + return PoleResidue(**self._pole_residue_dict(), allow_gain=self.allow_gain) + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + + For PoleResidue model, it equals ``sqrt(eps_inf)`` + [https://ieeexplore.ieee.org/document/9082879]. 
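+
+        Example
+        -------
+        A minimal sketch (``sqrt(eps_inf)`` per the note above; the pole values are
+        hypothetical):
+
+        >>> pr = PoleResidue(eps_inf=2.0, poles=[((-1+2j), (3+4j))])
+        >>> pr.n_cfl  # doctest: +SKIP
+        1.4142135623730951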
+ """ + permittivity = self.pole_residue.eps_inf + if self.modulation_spec is not None and self.modulation_spec.permittivity is not None: + permittivity -= self.modulation_spec.permittivity.max_modulation + n, _ = self.eps_complex_to_nk(permittivity) + return n + + @staticmethod + def tuple_to_complex(value: tuple[float, float]) -> complex: + """Convert a tuple of real and imaginary parts to complex number.""" + + val_r, val_i = value + return val_r + 1j * val_i + + @staticmethod + def complex_to_tuple(value: complex) -> tuple[float, float]: + """Convert a complex number to a tuple of real and imaginary parts.""" + + return (value.real, value.imag) + + # --- shared autograd helpers for dispersive models --- + def _tjp_inputs( + self, derivative_info: DerivativeInfo + ) -> tuple[NDArray, Union[ArrayFloat, ArrayBox]]: + """Prepare shared inputs for TJP: frequencies and packed adjoint vector.""" + dJ = self._derivative_eps_complex_volume( + E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds + ) + freqs = np.asarray(derivative_info.frequencies, float) + dJv = np.asarray(getattr(dJ, "values", dJ)) + return freqs, pack_complex_vec(dJv) + + @staticmethod + def _tjp_grad( + theta0: ArrayFloat, + eps_vec_fn: Callable[[ArrayFloat], Union[ArrayComplex, ArrayBox]], + vec: Union[ArrayComplex, ArrayBox], + ) -> ArrayFloat: + """Run a tensor-Jacobian-product to get J^T @ vec.""" + return tensor_jacobian_product(eps_vec_fn)(theta0, vec) + + @staticmethod + def _map_grad_real( + g: TracedFloat, + paths: set[tuple], + mapping: Sequence[tuple[tuple, int]], + ) -> AutogradFieldMap: + """Map flat gradient to model paths, taking the real part.""" + out = {} + for k, idx in mapping: + if k in paths: + out[k] = np.real(g[idx]) + return out + + +class CustomDispersiveMedium(AbstractCustomMedium, DispersiveMedium, ABC): + """A spatially varying dispersive medium.""" + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + + For PoleResidue model, it equals ``sqrt(eps_inf)`` + [https://ieeexplore.ieee.org/document/9082879]. + """ + permittivity = np.min(_get_numpy_array(self.pole_residue.eps_inf)) + if self.modulation_spec is not None and self.modulation_spec.permittivity is not None: + permittivity -= self.modulation_spec.permittivity.max_modulation + n, _ = self.eps_complex_to_nk(permittivity) + return n + + @cached_property + def is_isotropic(self) -> bool: + """Whether the medium is isotropic.""" + return True + + @cached_property + def pole_residue(self) -> CustomPoleResidue: + """Representation of Medium as a pole-residue model.""" + return CustomPoleResidue( + **self._pole_residue_dict(), + interp_method=self.interp_method, + allow_gain=self.allow_gain, + subpixel=self.subpixel, + ) + + @staticmethod + def _warn_if_data_none( + nested_tuple_field: str, + ) -> Callable[[type[AbstractMedium], dict[str, Any]], dict[str, Any]]: + """Warn if any of `eps_inf` and nested_tuple_field are not loaded, + and return a vacuum with eps_inf = 1. 
+ """ + + @model_validator(mode="before") + @classmethod + def _warn_if_none(cls: type[AbstractMedium], data: dict[str, Any]) -> dict[str, Any]: + is_not_loaded = AbstractCustomMedium._not_loaded + + eps_inf = data.get("eps_inf") + coeffs = data.get(nested_tuple_field, ()) + + eps_bad = is_not_loaded(eps_inf) + coeff_bad = any(is_not_loaded(c) for coeff in coeffs for c in coeff) + + if not (eps_bad or coeff_bad): + return data + + if eps_bad: + log.warning("Loading 'eps_inf' without data; constructing a vacuum medium instead.") + if coeff_bad: + log.warning( + f"Loading '{nested_tuple_field}' without data; constructing a vacuum medium instead." + ) + + data[nested_tuple_field] = () + if eps_inf is not None: + data["eps_inf"] = SpatialDataArray( + np.ones((1, 1, 1)), coords={"x": [0], "y": [0], "z": [0]} + ) + + return data + + return _warn_if_none + + # --- helpers for custom dispersive adjoints --- + def _sum_complex_eps_sensitivity( + self, + derivative_info: DerivativeInfo, + spatial_ref: PermittivityDataset, + ) -> ArrayComplex: + """Sum complex permittivity sensitivities over xyz on the given spatial grid. + + Parameters + ---------- + derivative_info : DerivativeInfo + Info bundle carrying field maps and frequencies. + spatial_ref : PermittivityDataset + Spatial dataset to define the grid/coords for interpolation and summation. + + Returns + ------- + np.ndarray + Complex-valued aggregated dJ array with the same spatial shape as ``spatial_ref``. + """ + dJ = 0.0 + 0.0j + for dim in "xyz": + dJ += self._derivative_field_cmp( + E_der_map=derivative_info.E_der_map, + spatial_data=spatial_ref, + dim=dim, + ) + return dJ + + @staticmethod + def _accum_real_inner(dJ: ArrayComplex, weight: ArrayComplex) -> ArrayFloat: + """Compute Re(dJ * conj(weight)) with proper broadcasting.""" + return np.real(dJ * np.conj(weight)) + + def _sum_over_freqs( + self, freqs: FrequencyArray, dJ: ArrayComplex, weight_fn: WeightFunction + ) -> ArrayFloat: + """Accumulate gradient contributions over frequencies using provided weight function. + + Parameters + ---------- + freqs : array-like + Frequencies to accumulate over. + dJ : np.ndarray + Complex dataset sensitivity with spatial shape. + weight_fn : Callable[[float], np.ndarray] + Function mapping frequency to weight array broadcastable to dJ. + + Returns + ------- + np.ndarray + Real-valued gradient array matching dJ's broadcasted shape. + """ + g = 0.0 + for f in freqs: + g = g + self._accum_real_inner(dJ, weight_fn(f)) + return g + + +class PoleResidue(DispersiveMedium): + """A dispersive medium described by the pole-residue pair model. + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. math:: + + \\epsilon(\\omega) = \\epsilon_\\infty - \\sum_i + \\left[\\frac{c_i}{j \\omega + a_i} + + \\frac{c_i^*}{j \\omega + a_i^*}\\right] + + Example + ------- + >>> pole_res = PoleResidue(eps_inf=2.0, poles=[((-1+2j), (3+4j)), ((-5+6j), (7+8j))]) + >>> eps = pole_res.eps_model(200e12) + + See Also + -------- + + :class:`CustomPoleResidue`: + A spatially varying dispersive medium described by the pole-residue pair model. 
+ + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: TracedPositiveFloat = Field( + 1.0, + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + poles: TracedPolesAndResidues = Field( + (), + title="Poles", + description="Tuple of complex-valued (:math:`a_i, c_i`) poles for the model.", + units=(RADPERSEC, RADPERSEC), + ) + + @field_validator("poles") + @classmethod + def _causality_validation(cls, val: TracedPolesAndResidues) -> TracedPolesAndResidues: + """Assert causal medium.""" + for a, _ in val: + if np.any(np.real(_get_numpy_array(a)) > 0): + raise SetupError("For stable medium, 'Re(a_i)' must be non-positive.") + return val + + @field_validator("poles") + @classmethod + def _poles_largest_value(cls, val: TracedPolesAndResidues) -> TracedPolesAndResidues: + """Assert pole parameters are not too large.""" + for a, c in val: + if np.any(abs(_get_numpy_array(a)) > LARGEST_FP_NUMBER): + raise ValidationError( + "The value of some 'a_i' is too large. They are unlikely to contribute to material dispersion." + ) + if np.any(abs(_get_numpy_array(c)) > LARGEST_FP_NUMBER): + raise ValidationError("The value of some 'c_i' is too large.") + return val + + _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() + _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() + + @staticmethod + def _eps_model(eps_inf: PositiveFloat, poles: PolesAndResidues, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + + omega = 2 * np.pi * frequency + eps = eps_inf + 0 * frequency + 0.0j + for a, c in poles: + a_cc = np.conj(a) + c_cc = np.conj(c) + eps = eps - c / (1j * omega + a) + eps = eps - c_cc / (1j * omega + a_cc) + return eps + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + return self._eps_model(eps_inf=self.eps_inf, poles=self.poles, frequency=frequency) + + def _pole_residue_dict(self) -> dict: + """Dict representation of Medium as a pole-residue model.""" + + return { + "eps_inf": self.eps_inf, + "poles": self.poles, + "frequency_range": self.frequency_range, + "name": self.name, + } + + def __str__(self) -> str: + """string representation""" + return ( + f"td.PoleResidue(" + f"\n\teps_inf={self.eps_inf}, " + f"\n\tpoles={self.poles}, " + f"\n\tfrequency_range={self.frequency_range})" + ) + + @classmethod + def from_medium(cls, medium: Medium) -> Self: + """Convert a :class:`.Medium` to a pole residue model. + + Parameters + ---------- + medium: :class:`.Medium` + The medium with permittivity and conductivity to convert. + + Returns + ------- + :class:`.PoleResidue` + The pole residue equivalent. + """ + poles = [(0, medium.conductivity / (2 * EPSILON_0))] + return PoleResidue( + eps_inf=medium.permittivity, poles=poles, frequency_range=medium.frequency_range + ) + + def to_medium(self) -> Medium: + """Convert to a :class:`.Medium`. + Requires the pole residue model to only have a pole at 0 frequency, + corresponding to a constant conductivity term. + + Returns + ------- + :class:`.Medium` + The non-dispersive equivalent with constant permittivity and conductivity. 
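+
+        Example
+        -------
+        A round trip through the zero-frequency pole (a sketch; values are
+        illustrative only):
+
+        >>> med = Medium(permittivity=2.0, conductivity=0.1)
+        >>> pole_res = PoleResidue.from_medium(med)
+        >>> med_back = pole_res.to_medium()  # recovers the original Medium values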
+ """ + res = 0 + for a, c in self.poles: + if abs(a) > fp_eps: + raise ValidationError("Cannot convert dispersive 'PoleResidue' to 'Medium'.") + res = res + (c + np.conj(c)) / 2 + sigma = res * 2 * EPSILON_0 + return Medium( + permittivity=self.eps_inf, + conductivity=np.real(sigma), + frequency_range=self.frequency_range, + ) + + @staticmethod + def lo_to_eps_model( + poles: tuple[tuple[float, float, float, float], ...], + eps_inf: PositiveFloat, + frequency: float, + ) -> complex: + """Complex permittivity as a function of frequency for a given set of LO-TO coefficients. + See ``from_lo_to`` in :class:`.PoleResidue` for the detailed form of the model + and a reference paper. + + Parameters + ---------- + poles : tuple[tuple[float, float, float, float], ...] + The LO-TO poles, given as list of tuples of the form + (omega_LO, gamma_LO, omega_TO, gamma_TO). + eps_inf: PositiveFloat + The relative permittivity at infinite frequency. + frequency: float + Frequency at which to evaluate the permittivity. + + Returns + ------- + complex + The complex permittivity of the given LO-TO model at the given frequency. + """ + omega = 2 * np.pi * frequency + eps = eps_inf + for omega_lo, gamma_lo, omega_to, gamma_to in poles: + eps *= omega_lo**2 - omega**2 - 1j * omega * gamma_lo + eps /= omega_to**2 - omega**2 - 1j * omega * gamma_to + return eps + + @classmethod + def from_lo_to( + cls, poles: tuple[tuple[float, float, float, float], ...], eps_inf: PositiveFloat = 1 + ) -> Self: + """Construct a pole residue model from the LO-TO form + (longitudinal and transverse optical modes). + The LO-TO form is :math:`\\epsilon_\\infty \\prod_{i=1}^l \\frac{\\omega_{LO, i}^2 - \\omega^2 - i \\omega \\gamma_{LO, i}}{\\omega_{TO, i}^2 - \\omega^2 - i \\omega \\gamma_{TO, i}}` as given in the paper: + + M. Schubert, T. E. Tiwald, and C. M. Herzinger, + "Infrared dielectric anisotropy and phonon modes of sapphire," + Phys. Rev. B 61, 8187 (2000). + + Parameters + ---------- + poles : tuple[tuple[float, float, float, float], ...] + The LO-TO poles, given as list of tuples of the form + (omega_LO, gamma_LO, omega_TO, gamma_TO). + eps_inf: PositiveFloat + The relative permittivity at infinite frequency. + + Returns + ------- + :class:`.PoleResidue` + The pole residue equivalent of the LO-TO form provided. + """ + + omegas_lo, gammas_lo, omegas_to, gammas_to = map(np.array, zip(*poles)) + + # discriminants of quadratic factors of denominator + discs = 2 * npo.emath.sqrt((gammas_to / 2) ** 2 - omegas_to**2) + + # require nondegenerate TO poles + if len({(omega_to, gamma_to) for (_, _, omega_to, gamma_to) in poles}) != len(poles) or any( + disc == 0 for disc in discs + ): + raise ValidationError( + "Unable to construct a pole residue model " + "from an LO-TO form with degenerate TO poles. Consider adding a " + "perturbation to split the poles, or using " + "'PoleResidue.lo_to_eps_model' and fitting with the 'FastDispersionFitter'." 
+ ) + + # roots of denominator, in pairs + roots = [] + for gamma_to, disc in zip(gammas_to, discs): + roots.append(-gamma_to / 2 + disc / 2) + roots.append(-gamma_to / 2 - disc / 2) + + # interpolants + interpolants = eps_inf * np.ones(len(roots), dtype=complex) + for i, a in enumerate(roots): + for omega_lo, gamma_lo in zip(omegas_lo, gammas_lo): + interpolants[i] *= omega_lo**2 + a**2 + a * gamma_lo + for j, a2 in enumerate(roots): + if j != i: + interpolants[i] /= a - a2 + + a_coeffs = [] + c_coeffs = [] + + for i in range(0, len(roots), 2): + if not np.isreal(roots[i]): + a_coeffs.append(roots[i]) + c_coeffs.append(interpolants[i]) + else: + a_coeffs.append(roots[i]) + a_coeffs.append(roots[i + 1]) + # factor of two from adding conjugate pole of real pole + c_coeffs.append(interpolants[i] / 2) + c_coeffs.append(interpolants[i + 1] / 2) + + return PoleResidue(eps_inf=eps_inf, poles=list(zip(a_coeffs, c_coeffs))) + + @staticmethod + def imag_ep_extrema(poles: PolesAndResidues) -> ArrayFloat1D: + """Extrema of Im[eps] in the same unit as poles. + + Parameters + ---------- + poles: PolesAndResidues + Tuple of complex-valued (``a_i, c_i``) poles for the model. + """ + + poles_a = [a for (a, _) in poles] + poles_c = [c for (_, c) in poles] + return imag_resp_extrema_locs(poles=poles_a, residues=poles_c) + + def _imag_ep_extrema_with_samples(self) -> ArrayFloat1D: + """Provide a list of frequencies (in unit of rad/s) to probe the possible lower and + upper bound of Im[eps] within the ``frequency_range``. If ``frequency_range`` is None, + it checks the entire frequency range. The returned frequencies include not only extrema, + but also a list of sampled frequencies. + """ + + # extrema frequencies: in the intermediate stage, convert to the unit eV for + # better numerical handling, since those quantities will be ~ 1 in photonics + extrema_freq = self.imag_ep_extrema(self.angular_freq_to_eV(np.array(self.poles))) + extrema_freq = self.eV_to_angular_freq(extrema_freq) + + # let's check a big range in addition to the imag_extrema + if self.frequency_range is None: + range_ev = np.logspace(LOSS_CHECK_MIN, LOSS_CHECK_MAX, LOSS_CHECK_NUM) + range_omega = self.eV_to_angular_freq(range_ev) + else: + fmin, fmax = self.frequency_range + fmin = max(fmin, fp_eps) + range_freq = np.logspace(np.log10(fmin), np.log10(fmax), LOSS_CHECK_NUM) + range_omega = self.Hz_to_angular_freq(range_freq) + + extrema_freq = extrema_freq[ + np.logical_and(extrema_freq > range_omega[0], extrema_freq < range_omega[-1]) + ] + return np.concatenate((range_omega, extrema_freq)) + + @cached_property + def loss_upper_bound(self) -> float: + """Upper bound of Im[eps] in `frequency_range`""" + freq_list = self.angular_freq_to_Hz(self._imag_ep_extrema_with_samples()) + ep = self.eps_model(freq_list) + # filter `NAN` in case some of freq_list are exactly at the pole frequency + # of Sellmeier-type poles. + ep = ep[~np.isnan(ep)] + return max(ep.imag) + + @staticmethod + def _get_vjps_from_params( + dJ_deps_complex: ComplexArrayOrScalar, + poles_vals: list[tuple[ComplexArrayOrScalar, ComplexArrayOrScalar]], + omega: float, + requested_paths: list[tuple], + project_real: bool = False, + ) -> AutogradFieldMap: + """ + Static helper to compute VJPs from parameters using the analytical chain rule. + + Parameters + - dJ_deps_complex: Complex adjoint sensitivity w.r.t. epsilon at a single frequency. + - poles_vals: Sequence of (a_i, c_i) pole parameters to differentiate with respect to. + - omega: Angular frequency for this VJP evaluation. 
+        - requested_paths: Paths requested by the caller; used to filter outputs.
+        - project_real: If True, project pole-parameter VJPs to their real part.
+          Use True for uniform PoleResidue to match real-valued objectives; use False for
+          CustomPoleResidue where parameters are complex and complex VJPs are required.
+        """
+        jw = 1j * omega
+        vjps = {}
+
+        if ("eps_inf",) in requested_paths:
+            vjps[("eps_inf",)] = np.real(dJ_deps_complex)
+
+        for i, (a_val, c_val) in enumerate(poles_vals):
+            if any(path[1] == i for path in requested_paths if path[0] == "poles"):
+                if ("poles", i, 0) in requested_paths:
+                    deps_da = c_val / (jw + a_val) ** 2
+                    dJ_da = dJ_deps_complex * deps_da
+                    vjps[("poles", i, 0)] = np.real(dJ_da) if project_real else dJ_da
+                if ("poles", i, 1) in requested_paths:
+                    deps_dc = -1 / (jw + a_val)
+                    dJ_dc = dJ_deps_complex * deps_dc
+                    vjps[("poles", i, 1)] = np.real(dJ_dc) if project_real else dJ_dc
+
+        return vjps
+
+    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+        """Compute adjoint derivatives by preparing scalar data and calling the static helper."""
+
+        dJ_deps_complex = self._derivative_eps_complex_volume(
+            E_der_map=derivative_info.E_der_map,
+            bounds=derivative_info.bounds,
+        )
+
+        poles_vals = [(complex(a), complex(c)) for a, c in self.poles]
+
+        freqs = dJ_deps_complex.coords["f"].values
+        vjps_total = {}
+
+        for freq in freqs:
+            dJ_deps_complex_f = dJ_deps_complex.sel(f=freq)
+            vjps_f = self._get_vjps_from_params(
+                dJ_deps_complex=complex(dJ_deps_complex_f.item()),
+                poles_vals=poles_vals,
+                omega=2 * np.pi * freq,
+                requested_paths=derivative_info.paths,
+                project_real=True,
+            )
+            for path, vjp in vjps_f.items():
+                if path not in vjps_total:
+                    vjps_total[path] = vjp
+                else:
+                    vjps_total[path] += vjp
+
+        return vjps_total
+
+    @classmethod
+    def _real_partial_fraction_decomposition(
+        cls, a: ArrayFloat, b: ArrayFloat, tol: PositiveFloat = 1e-2
+    ) -> tuple[list[tuple[Complex, Complex]], ArrayFloat]:
+        """Computes the complex conjugate pole residue pairs given a rational expression with
+        real coefficients.
+
+        Parameters
+        ----------
+
+        a : np.ndarray
+            Coefficients of the numerator polynomial in increasing monomial order.
+        b : np.ndarray
+            Coefficients of the denominator polynomial in increasing monomial order.
+        tol : PositiveFloat
+            Tolerance for pole finding. Two poles are considered equal if their spacing is less
+            than ``tol``.
+
+        Returns
+        -------
+        tuple[list[tuple[Complex, Complex]], np.ndarray]
+            The list of complex conjugate poles and their associated residues. The second element of the
+            ``tuple`` is an array of coefficients representing any direct polynomial term.
+
+        """
+        from scipy import signal
+
+        if a.ndim != 1 or np.any(np.iscomplex(a)):
+            raise ValidationError(
+                "Numerator coefficients must be a one-dimensional array of real numbers."
+            )
+        if b.ndim != 1 or np.any(np.iscomplex(b)):
+            raise ValidationError(
+                "Denominator coefficients must be a one-dimensional array of real numbers."
+            )
+
+        # Compute residues and poles using scipy
+        (r, p, k) = signal.residue(np.flip(a), np.flip(b), tol=tol, rtype="avg")
+
+        # Assuming real coefficients for the polynomials, the poles should be real or come as
+        # complex conjugate pairs
+        r_filtered = []
+        p_filtered = []
+        for res, (idx, pole) in zip(list(r), enumerate(list(p))):
+            # A residue equal to zero means the rational expression was not
+            # in simplest form, so skip this pole.
+            if res == 0:
+                continue
+            # Causal and stability check
+            if np.real(pole) > 0:
+                raise ValidationError("Transfer function is invalid. It is non-causal.")
+            # Check for higher-order poles, which come in consecutive order
+            if idx > 0 and p[idx - 1] == pole:
+                raise ValidationError(
+                    "Transfer function is invalid. A higher order pole was detected. Try reducing ``tol``, "
+                    "or ensure that the rational expression does not have repeated poles. "
+                )
+            if np.imag(pole) == 0:
+                r_filtered.append(res / 2)
+                p_filtered.append(pole)
+            else:
+                pair_found = len(np.argwhere(np.array(p) == np.conj(pole))) == 1
+                if not pair_found:
+                    raise ValueError(
+                        "Failed to find complex-conjugate of pole in poles computed by SciPy."
+                    )
+                previously_added = len(np.argwhere(np.array(p_filtered) == np.conj(pole))) == 1
+                if not previously_added:
+                    r_filtered.append(res)
+                    p_filtered.append(pole)
+
+        poles_residues = tuple(zip(p_filtered, r_filtered))
+        k_increasing_order = np.flip(k)
+        return (poles_residues, k_increasing_order)
+
+    @classmethod
+    def from_admittance_coeffs(
+        cls,
+        a: ArrayFloat,
+        b: ArrayFloat,
+        eps_inf: PositiveFloat = 1,
+        pole_tol: PositiveFloat = 1e-2,
+    ) -> Self:
+        """Construct a :class:`.PoleResidue` model from an admittance function defining the
+        relationship between the electric field and the polarization current density in the
+        Laplace domain.
+
+        Parameters
+        ----------
+        a : np.ndarray
+            Coefficients of the numerator polynomial in increasing monomial order.
+        b : np.ndarray
+            Coefficients of the denominator polynomial in increasing monomial order.
+        eps_inf: PositiveFloat
+            The relative permittivity at infinite frequency.
+        pole_tol: PositiveFloat
+            Tolerance for the pole finding algorithm in Hertz. Two poles are considered equal if their
+            spacing is closer than ``pole_tol``.
+
+        Returns
+        -------
+        :class:`.PoleResidue`
+            The pole residue equivalent.
+
+        Notes
+        -----
+
+        The supplied admittance function relates the electric field to the polarization current density
+        in the Laplace domain and is equivalent to a frequency-dependent complex conductivity
+        :math:`\\sigma(\\omega)`.
+
+        .. math::
+            J_p(s) = Y(s)E(s)
+
+        .. math::
+            Y(s) = \\frac{a_0 + a_1 s + \\dots + a_M s^M}{b_0 + b_1 s + \\dots + b_N s^N}
+
+        An equivalent :class:`.PoleResidue` medium is constructed using an equivalent frequency-dependent
+        complex permittivity defined as
+
+        .. math::
+            \\epsilon(s) = \\epsilon_\\infty - \\frac{1}{\\epsilon_0 s}
+            \\frac{a_0 + a_1 s + \\dots + a_M s^M}{b_0 + b_1 s + \\dots + b_N s^N}.
+        """
+
+        if a.ndim != 1 or np.any(np.logical_or(np.iscomplex(a), a < 0)):
+            raise ValidationError(
+                "Numerator coefficients must be a one-dimensional array of non-negative real numbers."
+            )
+        if b.ndim != 1 or np.any(np.logical_or(np.iscomplex(b), b < 0)):
+            raise ValidationError(
+                "Denominator coefficients must be a one-dimensional array of non-negative real numbers."
+            )
+
+        # Trim any trailing zeros, so that length corresponds with polynomial order
+        a = np.trim_zeros(a, "b")
+        b = np.trim_zeros(b, "b")
+
+        # Validate that the transfer function remains proper once converted to the
+        # complex permittivity form.
+        # Let q equal the order of the numerator polynomial, and p equal the order
+        # of the denominator polynomial. Then, q < p is a strictly proper rational transfer function (RTF),
+        # q <= p is a proper RTF, and q > p is an improper RTF.
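+        # For instance, a Drude-like admittance Y(s) = sigma_0 / (1 + s * tau)
+        # (hypothetical inputs a=[sigma_0], b=[1, tau]) gives q = 0 and p = 1,
+        # i.e. a strictly proper RTF that passes the check below.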
+ q = len(a) - 1 + p = len(b) - 1 + + if q > p + 1: + raise ValidationError( + "Transfer function is improper, the order of the numerator polynomial must be at most " + "one greater than the order of the denominator polynomial." + ) + + # Modify the transfer function defining a complex conductivity to match the complex + # frequency-dependent portion of the pole residue model + # Meaning divide by -j*omega*epsilon (s*epsilon) + b = np.concatenate(([0], b * EPSILON_0)) + + poles_and_residues, k = cls._real_partial_fraction_decomposition( + a=a, b=b, tol=pole_tol * 2 * np.pi + ) + + # A direct polynomial term of zeroth order is interpreted as an additional contribution to eps_inf. + # So we only handle that special case. + if len(k) == 1: + if np.iscomplex(k[0]) or k[0] < 0: + raise ValidationError( + "Transfer function is invalid. Direct polynomial term must be real and positive for " + "conversion to an equivalent 'PoleResidue' medium." + ) + # A pure capacitance will translate to an increased permittivity at infinite frequency. + eps_inf = eps_inf + k[0] + + pole_residue_from_transfer = PoleResidue(eps_inf=eps_inf, poles=poles_and_residues) + + # Check passivity + ang_freqs = PoleResidue._imag_ep_extrema_with_samples(pole_residue_from_transfer) + freq_list = PoleResidue.angular_freq_to_Hz(ang_freqs) + ep = pole_residue_from_transfer.eps_model(freq_list) + # filter `NAN` in case some of freq_list are exactly at the pole frequency + ep = ep[~np.isnan(ep)] + + if np.any(np.imag(ep) < -fp_eps): + log.warning( + "Generated 'PoleResidue' medium is not passive. Please raise an issue on the " + "Tidy3d frontend with this message and some information about your " + "simulation setup and we will investigate." + ) + + return pole_residue_from_transfer + + +class CustomPoleResidue(CustomDispersiveMedium, PoleResidue): + """A spatially varying dispersive medium described by the pole-residue pair model. + + Notes + ----- + + In this method, the frequency-dependent permittivity :math:`\\epsilon(\\omega)` is expressed as a sum of + resonant material poles _`[1]`. + + .. math:: + + \\epsilon(\\omega) = \\epsilon_\\infty - \\sum_i + \\left[\\frac{c_i}{j \\omega + a_i} + + \\frac{c_i^*}{j \\omega + a_i^*}\\right] + + For each of these resonant poles identified by the index :math:`i`, an auxiliary differential equation is + used to relate the auxiliary current :math:`J_i(t)` to the applied electric field :math:`E(t)`. + The sum of all these auxiliary current contributions describes the total dielectric response of the material. + + .. math:: + + \\frac{d}{dt} J_i (t) - a_i J_i (t) = \\epsilon_0 c_i \\frac{d}{dt} E (t) + + Hence, the computational cost increases with the number of poles. + + **References** + + .. [1] M. Han, R.W. Dutton and S. Fan, IEEE Microwave and Wireless Component Letters, 16, 119 (2006). + + .. TODO add links to notebooks using this. 
+ + Example + ------- + >>> x = np.linspace(-1, 1, 5) + >>> y = np.linspace(-1, 1, 6) + >>> z = np.linspace(-1, 1, 7) + >>> coords = dict(x=x, y=y, z=z) + >>> eps_inf = SpatialDataArray(np.ones((5, 6, 7)), coords=coords) + >>> a1 = SpatialDataArray(-np.random.random((5, 6, 7)), coords=coords) + >>> c1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> a2 = SpatialDataArray(-np.random.random((5, 6, 7)), coords=coords) + >>> c2 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> pole_res = CustomPoleResidue(eps_inf=eps_inf, poles=[(a1, c1), (a2, c2)]) + >>> eps = pole_res.eps_model(200e12) + + See Also + -------- + + **Notebooks** + + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: CustomSpatialDataTypeAnnotated = Field( + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + poles: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( + Field( + (), + title="Poles", + description="Tuple of complex-valued (:math:`a_i, c_i`) poles for the model.", + units=(RADPERSEC, RADPERSEC), + ) + ) + + _no_nans = validate_no_nans("eps_inf", "poles") + _warn_if_none = CustomDispersiveMedium._warn_if_data_none("poles") + + @field_validator("eps_inf") + @classmethod + def _eps_inf_positive(cls, val: CustomSpatialDataType) -> CustomSpatialDataType: + """eps_inf must be positive""" + if not CustomDispersiveMedium._validate_isreal_dataarray(val): + raise SetupError("'eps_inf' must be real.") + if np.any(_get_numpy_array(val) < 0): + raise SetupError("'eps_inf' must be positive.") + return val + + @model_validator(mode="after") + def _poles_correct_shape(self) -> Self: + """poles must have the same shape.""" + val = self.poles + + for coeffs in val: + for coeff in coeffs: + if not _check_same_coordinates(coeff, self.eps_inf): + raise SetupError( + "All pole coefficients 'a' and 'c' must have the same coordinates; " + "The coordinates must also be consistent with 'eps_inf'." + ) + return self + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + if not self.eps_inf.is_uniform: + return False + + for coeffs in self.poles: + for coeff in coeffs: + if not coeff.is_uniform: + return False + return True + + @staticmethod + def _sorted_spatial_data( + data: CustomSpatialDataTypeAnnotated, + ) -> CustomSpatialDataTypeAnnotated: + """Return spatial data sorted along its coordinates if applicable.""" + if isinstance(data, SpatialDataArray): + return data._spatially_sorted + return data + + @cached_property + def _eps_inf_sorted(self) -> CustomSpatialDataTypeAnnotated: + """Cached sorted copy of eps_inf when structured data is provided.""" + return self._sorted_spatial_data(self.eps_inf) + + @cached_property + def _poles_sorted( + self, + ) -> tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...]: + """Cached sorted copies of pole coefficients when structured data is provided.""" + return tuple( + (self._sorted_spatial_data(a), self._sorted_spatial_data(c)) for a, c in self.poles + ) + + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). 
+
+        Returns
+        -------
+        tuple[
+            Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ],
+            Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ],
+            Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ],
+        ]
+            The permittivity evaluated at ``frequency``.
+        """
+        eps = PoleResidue.eps_model(self, frequency)
+        return (eps, eps, eps)
+
+    def poles_on_grid(self, coords: Coords) -> tuple[tuple[ArrayComplex3D, ArrayComplex3D], ...]:
+        """Spatial profile of poles interpolated at the supplied coordinates.
+
+        Parameters
+        ----------
+        coords : :class:`.Coords`
+            The grid point coordinates over which interpolation is performed.
+
+        Returns
+        -------
+        tuple[tuple[ArrayComplex3D, ArrayComplex3D], ...]
+            The poles interpolated at the supplied coordinate.
+        """
+
+        def fun_interp(input_data: SpatialDataArray) -> ArrayComplex3D:
+            return _get_numpy_array(coords.spatial_interp(input_data, self.interp_method))
+
+        return tuple((fun_interp(a), fun_interp(c)) for (a, c) in self.poles)
+
+    @classmethod
+    def from_medium(cls, medium: CustomMedium) -> Self:
+        """Convert a :class:`.CustomMedium` to a pole residue model.
+
+        Parameters
+        ----------
+        medium: :class:`.CustomMedium`
+            The medium with permittivity and conductivity to convert.
+
+        Returns
+        -------
+        :class:`.CustomPoleResidue`
+            The pole residue equivalent.
+        """
+        poles = [(_zeros_like(medium.conductivity), medium.conductivity / (2 * EPSILON_0))]
+        medium_dict = medium.model_dump(
+            exclude={"type", "eps_dataset", "permittivity", "conductivity"}
+        )
+        medium_dict.update({"eps_inf": medium.permittivity, "poles": poles})
+        return CustomPoleResidue.model_validate(medium_dict)
+
+    def to_medium(self) -> CustomMedium:
+        """Convert to a :class:`.CustomMedium`.
+        Requires the pole residue model to only have a pole at 0 frequency,
+        corresponding to a constant conductivity term.
+
+        Returns
+        -------
+        :class:`.CustomMedium`
+            The non-dispersive equivalent with constant permittivity and conductivity.
+        """
+        res = 0
+        for a, c in self.poles:
+            if np.any(abs(_get_numpy_array(a)) > fp_eps):
+                raise ValidationError(
+                    "Cannot convert dispersive 'CustomPoleResidue' to 'CustomMedium'."
+                )
+            res = res + (c + np.conj(c)) / 2
+        sigma = res * 2 * EPSILON_0
+
+        self_dict = self.model_dump(exclude={"type", "eps_inf", "poles"})
+        self_dict.update({"permittivity": self.eps_inf, "conductivity": np.real(sigma)})
+        return CustomMedium.model_validate(self_dict)
+
+    @cached_property
+    def loss_upper_bound(self) -> float:
+        """Not implemented yet."""
+        raise SetupError("To be implemented.")
+
+    def _sel_custom_data_inside(self, bounds: Bound) -> Self:
+        """Return a new custom medium that contains the minimal amount of data necessary to cover
+        a spatial region defined by ``bounds``.
+
+
+        Parameters
+        ----------
+        bounds : tuple[float, float, float], tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+        Returns
+        -------
+        CustomPoleResidue
+            CustomPoleResidue with reduced data.
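+
+        Example
+        -------
+        A sketch, reusing ``pole_res`` from the class example (its data covers the box):
+
+        >>> reduced = pole_res._sel_custom_data_inside(bounds=((-1, -1, -1), (1, 1, 1)))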
+ """ + if not self.eps_inf.does_cover(bounds=bounds): + log.warning("eps_inf spatial data array does not fully cover the requested region.") + eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds) + poles_reduced = [] + for pole, residue in self.poles: + if not pole.does_cover(bounds=bounds): + log.warning("Pole spatial data array does not fully cover the requested region.") + + if not residue.does_cover(bounds=bounds): + log.warning("Residue spatial data array does not fully cover the requested region.") + + poles_reduced.append((pole.sel_inside(bounds), residue.sel_inside(bounds))) + + return self.updated_copy(eps_inf=eps_inf_reduced, poles=tuple(poles_reduced)) + + def _derivative_field_cmp( + self, + E_der_map: ElectromagneticFieldDataset, + spatial_data: CustomSpatialDataTypeAnnotated, + dim: str, + freqs: Optional[ArrayFloat] = None, + component: str = "complex", + ) -> ArrayGeneric: + """Compatibility wrapper for derivative computation. + + Accepts the extended signature used by other custom media ( + e.g., `CustomMedium._derivative_field_cmp`) while delegating the actual + computation to the base implementation that only depends on + `E_der_map`, `spatial_data`, and `dim`. + + Parameters `freqs` and `component` are ignored for this model since the + derivative is taken with respect to the complex permittivity directly. + """ + return super()._derivative_field_cmp( + E_der_map=E_der_map, spatial_data=spatial_data, dim=dim + ) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Compute adjoint derivatives by preparing array data and calling the static helper.""" + + eps_inf_sorted = self._eps_inf_sorted + use_custom_derivative = isinstance(eps_inf_sorted, SpatialDataArray) + + dJ_deps_complex = 0.0 + 0.0j + for dim in "xyz": + if use_custom_derivative: + dJ_deps_complex += self._derivative_field_cmp_custom( + E_der_map=derivative_info.E_der_map, + spatial_data=eps_inf_sorted, + dim=dim, + freqs=derivative_info.frequencies, + bounds=derivative_info.bounds_intersect, + component="complex", + ) + else: + dJ_deps_complex += self._derivative_field_cmp( + E_der_map=derivative_info.E_der_map, + spatial_data=eps_inf_sorted, + dim=dim, + ) + + poles_vals = [ + (np.array(a_sorted.values, dtype=complex), np.array(c_sorted.values, dtype=complex)) + for a_sorted, c_sorted in self._poles_sorted + ] + + vjps_total = {} + for freq in derivative_info.frequencies: + vjps_f = PoleResidue._get_vjps_from_params( + dJ_deps_complex=dJ_deps_complex, + poles_vals=poles_vals, + omega=2 * np.pi * freq, + requested_paths=derivative_info.paths, + project_real=False, + ) + for path, vjp in vjps_f.items(): + if path not in vjps_total: + vjps_total[path] = vjp + else: + vjps_total[path] += vjp + return vjps_total + + +class Sellmeier(DispersiveMedium): + """A dispersive medium described by the Sellmeier model. + + Notes + ----- + + The frequency-dependence of the refractive index is described by: + + .. math:: + + n(\\lambda)^2 = 1 + \\sum_i \\frac{B_i \\lambda^2}{\\lambda^2 - C_i} + + For lossless, weakly dispersive materials, the best way to incorporate the dispersion without doing + complicated fits and without slowing the simulation down significantly is to provide the value of the + refractive index dispersion :math:`\\frac{dn}{d\\lambda}` in :meth:`tidy3d.Sellmeier.from_dispersion`. The + value is assumed to be at the central frequency or wavelength (whichever is provided), and a one-pole model + for the material is generated. 
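+
+    For instance, a weakly dispersive one-pole medium (a sketch; the values are
+    illustrative only):
+
+    >>> weakly_dispersive = Sellmeier.from_dispersion(n=1.5, freq=C_0 / 0.8, dn_dwvl=-0.01)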
+ + Example + ------- + >>> sellmeier_medium = Sellmeier(coeffs=[(1,2), (3,4)]) + >>> eps = sellmeier_medium.eps_model(200e12) + + See Also + -------- + + :class:`CustomSellmeier` + A spatially varying dispersive medium described by the Sellmeier model. + + **Notebooks** + + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + + * `Modeling dispersive material in FDTD `_ + """ + + coeffs: tuple[tuple[float, PositiveFloat], ...] = Field( + title="Coefficients", + description="List of Sellmeier (:math:`B_i, C_i`) coefficients.", + units=(None, MICROMETER + "^2"), + ) + + @model_validator(mode="after") + def _passivity_validation(self) -> Self: + """Assert passive medium if `allow_gain` is False.""" + val = self.coeffs + if self.allow_gain: + return self + for B, _ in val: + if B < 0: + raise ValidationError( + "For passive medium, 'B_i' must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + return self + + @field_validator("modulation_spec") + @classmethod + def _validate_permittivity_modulation( + cls, val: Optional[ModulationSpec] + ) -> Optional[ModulationSpec]: + """Assert modulated permittivity cannot be <= 0.""" + + if val is None or val.permittivity is None: + return val + + min_eps_inf = 1.0 + if min_eps_inf - val.permittivity.max_modulation <= 0: + raise ValidationError( + "The minimum permittivity value with modulation applied was found to be negative." + ) + return val + + _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() + + def _n_model(self, frequency: float) -> complex: + """Complex-valued refractive index as a function of frequency.""" + + wvl = C_0 / np.array(frequency) + wvl2 = wvl**2 + n_squared = 1.0 + for B, C in self.coeffs: + n_squared = n_squared + B * wvl2 / (wvl2 - C) + return np.sqrt(n_squared + 0j) + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + + n = self._n_model(frequency) + return AbstractMedium.nk_to_eps_complex(n) + + def _pole_residue_dict(self) -> dict: + """Dict representation of Medium as a pole-residue model""" + poles = [] + eps_inf = _ones_like(self.coeffs[0][0]) + for B, C in self.coeffs: + # for small C, it's equivalent to modifying eps_inf + if np.any(np.isclose(_get_numpy_array(C), 0)): + eps_inf += B + else: + beta = 2 * np.pi * C_0 / np.sqrt(C) + alpha = -0.5 * beta * B + a = 1j * beta + c = 1j * alpha + poles.append((a, c)) + return { + "eps_inf": eps_inf, + "poles": poles, + "frequency_range": self.frequency_range, + "name": self.name, + } + + @staticmethod + def _from_dispersion_to_coeffs( + n: float, freq: ArrayFloat, dn_dwvl: float + ) -> list[tuple[float, float]]: + """Compute Sellmeier coefficients from dispersion.""" + wvl = C_0 / np.array(freq) + nsqm1 = n**2 - 1 + c_coeff = -(wvl**3) * n * dn_dwvl / (nsqm1 - wvl * n * dn_dwvl) + b_coeff = (wvl**2 - c_coeff) / wvl**2 * nsqm1 + return [(b_coeff, c_coeff)] + + @classmethod + def from_dispersion(cls, n: float, freq: float, dn_dwvl: float = 0, **kwargs: Any) -> Self: + """Convert ``n`` and wavelength dispersion ``dn_dwvl`` values at frequency ``freq`` to + a single-pole :class:`Sellmeier` medium. + + Parameters + ---------- + n : float + Real part of refractive index. Must be larger than or equal to one. + dn_dwvl : float = 0 + Derivative of the refractive index with wavelength (1/um). 
Must be negative.
+        freq : float
+            Frequency at which ``n`` and ``dn_dwvl`` are sampled.
+
+        Returns
+        -------
+        :class:`Sellmeier`
+            Single-pole Sellmeier medium with the provided refractive index and index dispersion
+            values at the provided frequency.
+        """
+
+        if dn_dwvl >= 0:
+            raise ValidationError("Dispersion ``dn_dwvl`` must be smaller than zero.")
+        if n < 1:
+            raise ValidationError("Refractive index ``n`` cannot be smaller than one.")
+        return cls(coeffs=cls._from_dispersion_to_coeffs(n, freq, dn_dwvl), **kwargs)
+
+    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+        """Adjoint derivatives for Sellmeier params via TJP through eps_model()."""
+
+        freqs, vec = self._tjp_inputs(derivative_info)
+        N = len(self.coeffs)
+        if N == 0:
+            return {}
+
+        # pack parameters into flat vector [B..., C...]
+        B0 = np.array([float(b) for (b, _c) in self.coeffs])
+        C0 = np.array([float(c) for (_b, c) in self.coeffs])
+        theta0 = np.concatenate([B0, C0])
+
+        def _eps_vec(theta: Sequence[PositiveFloat]) -> Union[NDArray, ArrayBox]:
+            B = theta[:N]
+            C = theta[N : 2 * N]
+            coeffs = tuple((B[i], C[i]) for i in range(N))
+            eps = self.updated_copy(coeffs=coeffs, validate=False).eps_model(freqs)
+            return pack_complex_vec(eps)
+
+        g = self._tjp_grad(theta0, _eps_vec, vec)
+
+        mapping = []
+        mapping += [(("coeffs", i, 0), i) for i in range(N)]
+        mapping += [(("coeffs", i, 1), N + i) for i in range(N)]
+        return self._map_grad_real(g, derivative_info.paths, mapping)
+
+    @staticmethod
+    def _lam2(
+        freq: Union[float, ArrayFloat],
+    ) -> Union[float, ArrayFloat]:
+        return (C_0 / freq) ** 2
+
+    @staticmethod
+    def _sellmeier_den(
+        lam2: Union[float, ArrayFloat],
+        C: Union[float, ArrayFloat],
+    ) -> Union[float, ArrayFloat]:
+        return lam2 - C
+
+    # frequency weights for custom Sellmeier
+    @staticmethod
+    def _w_B(
+        freq: Union[float, ArrayFloat],
+        C: Union[float, ArrayFloat],
+    ) -> Union[float, ArrayFloat]:
+        lam2 = Sellmeier._lam2(freq)
+        return lam2 / Sellmeier._sellmeier_den(lam2, C)
+
+    @staticmethod
+    def _w_C(
+        freq: Union[float, ArrayFloat],
+        B: Union[float, ArrayFloat],
+        C: Union[float, ArrayFloat],
+    ) -> Union[float, ArrayFloat]:
+        lam2 = Sellmeier._lam2(freq)
+        den = Sellmeier._sellmeier_den(lam2, C)
+        return B * lam2 / (den**2)
+
+
+class CustomSellmeier(CustomDispersiveMedium, Sellmeier):
+    """A spatially varying dispersive medium described by the Sellmeier model.
+
+    Notes
+    -----
+
+    The frequency-dependence of the refractive index is described by:
+
+    .. math::
+
+        n(\\lambda)^2 = 1 + \\sum_i \\frac{B_i \\lambda^2}{\\lambda^2 - C_i}
+
+    Example
+    -------
+    >>> x = np.linspace(-1, 1, 5)
+    >>> y = np.linspace(-1, 1, 6)
+    >>> z = np.linspace(-1, 1, 7)
+    >>> coords = dict(x=x, y=y, z=z)
+    >>> b1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords)
+    >>> c1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords)
+    >>> sellmeier_medium = CustomSellmeier(coeffs=[(b1,c1),])
+    >>> eps = sellmeier_medium.eps_model(200e12)
+
+    See Also
+    --------
+
+    :class:`Sellmeier`
+        A dispersive medium described by the Sellmeier model.
+
+    **Notebooks**
+    * `Fitting dispersive material models <../../notebooks/Fitting.html>`_
+
+    **Lectures**
+    * `Modeling dispersive material in FDTD `_
+    """
+
+    coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = (
+        Field(
+            title="Coefficients",
+            description="List of Sellmeier (:math:`B_i, C_i`) coefficients.",
+            units=(None, MICROMETER + "^2"),
+        )
+    )
+
+    _no_nans = validate_no_nans("coeffs")
+    _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs")
+
+    @field_validator("coeffs")
+    @classmethod
+    def _correct_shape_and_sign(
+        cls, val: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]
+    ) -> tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]:
+        """every term in coeffs must have the same shape, and B>=0 and C>0."""
+        if len(val) == 0:
+            return val
+        for B, C in val:
+            if not _check_same_coordinates(B, val[0][0]) or not _check_same_coordinates(
+                C, val[0][0]
+            ):
+                raise SetupError("Every term in 'coeffs' must have the same coordinates.")
+            if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((B, C)):
+                raise SetupError("'B' and 'C' must be real.")
+            if np.any(_get_numpy_array(C) <= 0):
+                raise SetupError("'C' must be positive.")
+        return val
+
+    @model_validator(mode="after")
+    def _passivity_validation(self) -> Self:
+        """Assert passive medium if `allow_gain` is False."""
+        val = self.coeffs
+        if self.allow_gain:
+            return self
+        for B, _ in val:
+            if np.any(_get_numpy_array(B) < 0):
+                raise ValidationError(
+                    "For passive medium, 'B_i' must be non-negative. "
+                    "To simulate a gain medium, please set 'allow_gain=True'. "
+                    "Caution: simulations with a gain medium are unstable, "
+                    "and are likely to diverge."
+                )
+        return self
+
+    @field_validator("coeffs")
+    @classmethod
+    def _coeffs_C_all_near_zero_or_much_greater(
+        cls, val: tuple[tuple[float, PositiveFloat], ...]
+    ) -> tuple[tuple[float, PositiveFloat], ...]:
+        """We restrict 'C_i' to be either all near zero or all much greater than zero."""
+        for _, C in val:
+            c_array_near_zero = np.isclose(_get_numpy_array(C), 0)
+            if np.any(c_array_near_zero) and not np.all(c_array_near_zero):
+                raise SetupError(
+                    "Coefficients 'C_i' are restricted to be "
+                    "either all near zero or much greater than 0."
+                )
+        return val
+
+    @cached_property
+    def is_spatially_uniform(self) -> bool:
+        """Whether the medium is spatially uniform."""
+        for coeffs in self.coeffs:
+            for coeff in coeffs:
+                if not coeff.is_uniform:
+                    return False
+        return True
+
+    def _pole_residue_dict(self) -> dict:
+        """Dict representation of Medium as a pole-residue model."""
+        poles_dict = Sellmeier._pole_residue_dict(self)
+        if len(self.coeffs) > 0:
+            poles_dict.update({"eps_inf": _ones_like(self.coeffs[0][0])})
+        return poles_dict
+
+    def eps_dataarray_freq(
+        self, frequency: float
+    ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]:
+        """Permittivity array at ``frequency``.
+
+        Parameters
+        ----------
+        frequency : float
+            Frequency to evaluate permittivity at (Hz).
+
+        Returns
+        -------
+        tuple[
+            Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ],
+            Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ],
+            Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ],
+        ]
+            The permittivity evaluated at ``frequency``.
+        """
+        eps = Sellmeier.eps_model(self, frequency)
+        # if `eps` is simply a float, convert it to a SpatialDataArray; this is possible when
+        # `coeffs` is empty.
+        if isinstance(eps, (int, float, complex)):
+            eps = SpatialDataArray(eps * np.ones((1, 1, 1)), coords={"x": [0], "y": [0], "z": [0]})
+        return (eps, eps, eps)
+
+    @classmethod
+    def from_dispersion(
+        cls,
+        n: CustomSpatialDataType,
+        freq: float,
+        dn_dwvl: CustomSpatialDataType,
+        interp_method: InterpMethod = "nearest",
+        **kwargs: Any,
+    ) -> Self:
+        """Convert ``n`` and wavelength dispersion ``dn_dwvl`` values at frequency ``freq`` to
+        a single-pole :class:`CustomSellmeier` medium.
+
+        Parameters
+        ----------
+        n : Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ]
+            Real part of refractive index. Must be larger than or equal to one.
+        dn_dwvl : Union[
+                :class:`.SpatialDataArray`,
+                :class:`.TriangularGridDataset`,
+                :class:`.TetrahedralGridDataset`,
+            ]
+            Derivative of the refractive index with wavelength (1/um). Must be negative.
+        freq : float
+            Frequency at which ``n`` and ``dn_dwvl`` are sampled.
+        interp_method : :class:`.InterpMethod`, optional
+            Interpolation method to obtain permittivity values that are not supplied
+            at the Yee grids.
+
+        Returns
+        -------
+        :class:`.CustomSellmeier`
+            Single-pole Sellmeier medium with the provided refractive index and index dispersion
+            values at the provided frequency.
+        """
+
+        if not _check_same_coordinates(n, dn_dwvl):
+            raise ValidationError("'n' and 'dn_dwvl' must have the same coordinates.")
+        if np.any(_get_numpy_array(dn_dwvl) >= 0):
+            raise ValidationError("Dispersion ``dn_dwvl`` must be smaller than zero.")
+        if np.any(_get_numpy_array(n) < 1):
+            raise ValidationError("Refractive index ``n`` cannot be smaller than one.")
+        return cls(
+            coeffs=cls._from_dispersion_to_coeffs(n, freq, dn_dwvl),
+            interp_method=interp_method,
+            **kwargs,
+        )
+
+    def _sel_custom_data_inside(self, bounds: Bound) -> Self:
+        """Return a new custom medium that contains the minimal amount of data necessary to cover
+        a spatial region defined by ``bounds``.
+
+
+        Parameters
+        ----------
+        bounds : tuple[float, float, float], tuple[float, float, float]
+            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+        Returns
+        -------
+        CustomSellmeier
+            CustomSellmeier with reduced data.
+        """
+        coeffs_reduced = []
+        for b_coeff, c_coeff in self.coeffs:
+            if not b_coeff.does_cover(bounds=bounds):
+                log.warning(
+                    "Sellmeier B coeff spatial data array does not fully cover the requested region."
+                )
+
+            if not c_coeff.does_cover(bounds=bounds):
+                log.warning(
+                    "Sellmeier C coeff spatial data array does not fully cover the requested region."
+                )
+
+            coeffs_reduced.append((b_coeff.sel_inside(bounds), c_coeff.sel_inside(bounds)))
+
+        return self.updated_copy(coeffs=tuple(coeffs_reduced))
+
+    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+        """Adjoint derivatives for CustomSellmeier via analytic chain rule.
+
+        Uses the complex permittivity derivative aggregated over spatial dims and
+        applies frequency-dependent weights per Sellmeier term.
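+
+        The per-term weights used below follow from
+        :math:`\\epsilon = 1 + \\sum_i B_i \\lambda^2 / (\\lambda^2 - C_i)` (a sketch of
+        the chain rule, consistent with ``Sellmeier._w_B`` and ``Sellmeier._w_C``):
+
+        .. math::
+
+            \\frac{\\partial \\epsilon}{\\partial B_i} = \\frac{\\lambda^2}{\\lambda^2 - C_i},
+            \\qquad
+            \\frac{\\partial \\epsilon}{\\partial C_i} = \\frac{B_i \\lambda^2}{(\\lambda^2 - C_i)^2}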
+ """ + + if len(self.coeffs) == 0: + return {} + + # accumulate complex-valued sensitivity across xyz using B's grid as reference + ref = self.coeffs[0][0] + dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=ref) + + # prepare gradients map + grads: AutogradFieldMap = {} + + # iterate coefficients and requested paths + for i, (B_da, C_da) in enumerate(self.coeffs): + need_B = ("coeffs", i, 0) in derivative_info.paths + need_C = ("coeffs", i, 1) in derivative_info.paths + if not (need_B or need_C): + continue + + Bv = np.array(B_da.values, dtype=float) + Cv = np.array(C_da.values, dtype=float) + + gB = 0.0 if not need_B else np.zeros_like(Bv, dtype=float) + gC = 0.0 if not need_C else np.zeros_like(Cv, dtype=float) + + if need_B: + gB = gB + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, Cv=Cv: Sellmeier._w_B(f, Cv), + ) + if need_C: + gC = gC + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, Bv=Bv, Cv=Cv: Sellmeier._w_C(f, Bv, Cv), + ) + + if need_B: + grads[("coeffs", i, 0)] = gB + if need_C: + grads[("coeffs", i, 1)] = gC + + return grads + + +class Lorentz(DispersiveMedium): + """A dispersive medium described by the Lorentz model. + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. math:: + + \\epsilon(f) = \\epsilon_\\infty + \\sum_i + \\frac{\\Delta\\epsilon_i f_i^2}{f_i^2 - 2jf\\delta_i - f^2} + + Example + ------- + >>> lorentz_medium = Lorentz(eps_inf=2.0, coeffs=[(1,2,3), (4,5,6)]) + >>> eps = lorentz_medium.eps_model(200e12) + + See Also + -------- + + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: PositiveFloat = Field( + 1.0, + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + coeffs: tuple[tuple[float, float, NonNegativeFloat], ...] = Field( + title="Coefficients", + description="List of (:math:`\\Delta\\epsilon_i, f_i, \\delta_i`) values for model.", + units=(PERMITTIVITY, HERTZ, HERTZ), + ) + + @field_validator("coeffs") + @classmethod + def _coeffs_unequal_f_delta( + cls, val: tuple[tuple[float, float, NonNegativeFloat], ...] + ) -> tuple[tuple[float, float, NonNegativeFloat], ...]: + """f**2 and delta**2 cannot be exactly the same.""" + for _, f, delta in val: + if f**2 == delta**2: + raise SetupError("'f' and 'delta' cannot take equal values.") + return val + + @model_validator(mode="after") + def _passivity_validation(self) -> Self: + """Assert passive medium if ``allow_gain`` is False.""" + val = self.coeffs + if self.allow_gain: + return self + for del_ep, _, _ in val: + if del_ep < 0: + raise ValidationError( + "For passive medium, 'Delta epsilon_i' must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." 
+                )
+        return self
+
+    _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation()
+    _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation()
+
+    @ensure_freq_in_range
+    def eps_model(self, frequency: float) -> complex:
+        """Complex-valued permittivity as a function of frequency."""
+
+        eps = self.eps_inf + 0.0j
+        for de, f, delta in self.coeffs:
+            eps = eps + (de * f**2) / (f**2 - 2j * frequency * delta - frequency**2)
+        return eps
+
+    def _pole_residue_dict(self) -> dict:
+        """Dict representation of Medium as a pole-residue model."""
+
+        poles = []
+        for de, f, delta in self.coeffs:
+            w = 2 * np.pi * f
+            d = 2 * np.pi * delta
+
+            if self._all_larger(d**2, w**2):
+                r = np.sqrt(d * d - w * w) + 0j
+                a0 = -d + r
+                c0 = de * w**2 / 4 / r
+                a1 = -d - r
+                c1 = -c0
+                poles.extend(((a0, c0), (a1, c1)))
+            else:
+                r = np.sqrt(w * w - d * d)
+                a = -d - 1j * r
+                c = 1j * de * w**2 / 2 / r
+                poles.append((a, c))
+
+        return {
+            "eps_inf": self.eps_inf,
+            "poles": poles,
+            "frequency_range": self.frequency_range,
+            "name": self.name,
+        }
+
+    @staticmethod
+    def _all_larger(
+        coeff_a: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...],
+        coeff_b: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...],
+    ) -> bool:
+        """``coeff_a`` and ``coeff_b`` can be either float or SpatialDataArray."""
+        if isinstance(coeff_a, CustomSpatialDataType.__args__):
+            return np.all(_get_numpy_array(coeff_a) > _get_numpy_array(coeff_b))
+        return coeff_a > coeff_b
+
+    @classmethod
+    def from_nk(cls, n: float, k: float, freq: float, **kwargs: Any) -> Self:
+        """Convert ``n`` and ``k`` values at frequency ``freq`` to a single-pole Lorentz
+        medium.
+
+        Parameters
+        ----------
+        n : float
+            Real part of refractive index.
+        k : float
+            Imaginary part of refractive index.
+        freq : float
+            Frequency to evaluate permittivity at (Hz).
+        kwargs: dict
+            Keyword arguments passed to the medium construction.
+
+        Returns
+        -------
+        :class:`Lorentz`
+            Lorentz medium having refractive index n+ik at frequency ``freq``.
+        """
+        eps_complex = AbstractMedium.nk_to_eps_complex(n, k)
+        eps_r, eps_i = eps_complex.real, eps_complex.imag
+        if eps_r >= 1:
+            log.warning(
+                "For 'permittivity>=1', it is more computationally efficient to "
+                "use a dispersiveless medium constructed from 'Medium.from_nk()'."
+            )
+        # first, lossless medium
+        if isclose(eps_i, 0):
+            if eps_r < 1:
+                fp = np.sqrt((eps_r - 1) / (eps_r - 2)) * freq
+                return cls(
+                    eps_inf=1,
+                    coeffs=[
+                        (1, fp, 0),
+                    ],
+                    **kwargs,
+                )
+            return cls(
+                eps_inf=1,
+                coeffs=[
+                    ((eps_r - 1) / 2, np.sqrt(2) * freq, 0),
+                ],
+                **kwargs,
+            )
+        # lossy medium
+        alpha = (eps_r - 1) / eps_i
+        delta_p = freq / 2 / (alpha**2 - alpha + 1)
+        fp = np.sqrt((alpha**2 + 1) / (alpha**2 - alpha + 1)) * freq
+        return cls(
+            eps_inf=1,
+            coeffs=[
+                (eps_i, fp, delta_p),
+            ],
+            **kwargs,
+        )
+
+    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+        """Adjoint derivatives for Lorentz params via TJP through eps_model()."""
+
+        f, vec = self._tjp_inputs(derivative_info)
+
+        N = len(self.coeffs)
+        if N == 0 and ("eps_inf",) not in derivative_info.paths:
+            return {}
+
+        # pack into flat [eps_inf, de..., f0..., delta...]
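+        # e.g. for N = 2: theta0 = [eps_inf, de_0, de_1, f_0, f_1, delta_0, delta_1]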
+ eps_inf0 = float(self.eps_inf) + de0 = np.array([float(de) for (de, _f, _d) in self.coeffs]) if N else np.array([]) + f0 = np.array([float(fi) for (_de, fi, _d) in self.coeffs]) if N else np.array([]) + d0 = np.array([float(dd) for (_de, _f, dd) in self.coeffs]) if N else np.array([]) + theta0 = np.concatenate([np.array([eps_inf0]), de0, f0, d0]) + + def _eps_vec(theta: Sequence[PositiveFloat]) -> Union[NDArray, ArrayBox]: + eps_inf = theta[0] + de = theta[1 : 1 + N] + fi = theta[1 + N : 1 + 2 * N] + dd = theta[1 + 2 * N : 1 + 3 * N] + coeffs = tuple((de[i], fi[i], dd[i]) for i in range(N)) + eps = self.updated_copy(eps_inf=eps_inf, coeffs=coeffs, validate=False).eps_model(f) + return pack_complex_vec(eps) + + g = self._tjp_grad(theta0, _eps_vec, vec) + + mapping = [(("eps_inf",), 0)] + base = 1 + mapping += [(("coeffs", i, 0), base + i) for i in range(N)] + mapping += [(("coeffs", i, 1), base + N + i) for i in range(N)] + mapping += [(("coeffs", i, 2), base + 2 * N + i) for i in range(N)] + return self._map_grad_real(g, derivative_info.paths, mapping) + + @staticmethod + def _den( + freq: Union[float, ArrayFloat], + f0: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + return (f0**2) - 2j * (freq * delta) - (freq**2) + + # frequency weights for custom Lorentz + @staticmethod + def _w_de( + freq: Union[float, ArrayFloat], + f0: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + return (f0**2) / Lorentz._den(freq, f0, delta) + + @staticmethod + def _w_f0( + freq: Union[float, ArrayFloat], + de: Union[float, ArrayFloat], + f0: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + den = Lorentz._den(freq, f0, delta) + return (2.0 * de * f0 * (den - f0**2)) / (den**2) + + @staticmethod + def _w_delta( + freq: Union[float, ArrayFloat], + de: Union[float, ArrayFloat], + f0: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + den = Lorentz._den(freq, f0, delta) + return (2j * freq * de * (f0**2)) / (den**2) + + +class CustomLorentz(CustomDispersiveMedium, Lorentz): + """A spatially varying dispersive medium described by the Lorentz model. + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. math:: + + \\epsilon(f) = \\epsilon_\\infty + \\sum_i + \\frac{\\Delta\\epsilon_i f_i^2}{f_i^2 - 2jf\\delta_i - f^2} + + Example + ------- + >>> x = np.linspace(-1, 1, 5) + >>> y = np.linspace(-1, 1, 6) + >>> z = np.linspace(-1, 1, 7) + >>> coords = dict(x=x, y=y, z=z) + >>> eps_inf = SpatialDataArray(np.ones((5, 6, 7)), coords=coords) + >>> d_epsilon = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> f = SpatialDataArray(1+np.random.random((5, 6, 7)), coords=coords) + >>> delta = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> lorentz_medium = CustomLorentz(eps_inf=eps_inf, coeffs=[(d_epsilon,f,delta),]) + >>> eps = lorentz_medium.eps_model(200e12) + + See Also + -------- + + :class:`CustomPoleResidue`: + A spatially varying dispersive medium described by the pole-residue pair model. 
+
+ **Notebooks**
+ * `Fitting dispersive material models <../../notebooks/Fitting.html>`_
+
+ **Lectures**
+ * `Modeling dispersive material in FDTD `_
+ """
+
+ eps_inf: CustomSpatialDataTypeAnnotated = Field(
+ title="Epsilon at Infinity",
+ description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).",
+ units=PERMITTIVITY,
+ )
+
+ coeffs: tuple[
+ tuple[
+ CustomSpatialDataTypeAnnotated,
+ CustomSpatialDataTypeAnnotated,
+ CustomSpatialDataTypeAnnotated,
+ ],
+ ...,
+ ] = Field(
+ title="Coefficients",
+ description="List of (:math:`\\Delta\\epsilon_i, f_i, \\delta_i`) values for model.",
+ units=(PERMITTIVITY, HERTZ, HERTZ),
+ )
+
+ _no_nans = validate_no_nans("eps_inf", "coeffs")
+ _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs")
+
+ @field_validator("eps_inf")
+ @classmethod
+ def _eps_inf_positive(cls, val: CustomSpatialDataType) -> CustomSpatialDataType:
+ """eps_inf must be positive"""
+ if not CustomDispersiveMedium._validate_isreal_dataarray(val):
+ raise SetupError("'eps_inf' must be real.")
+ if np.any(_get_numpy_array(val) < 0):
+ raise SetupError("'eps_inf' must be positive.")
+ return val
+
+ @field_validator("coeffs")
+ @classmethod
+ def _coeffs_unequal_f_delta(
+ cls, val: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]
+ ) -> tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]:
+ """f and delta cannot be exactly the same.
+ Not needed for now because we have a stricter
+ validator `_coeffs_delta_all_smaller_or_larger_than_fi`.
+ """
+ return val
+
+ @model_validator(mode="after")
+ def _coeffs_correct_shape(self) -> Self:
+ """coeffs must have consistent shape."""
+ val = self.coeffs
+ for de, f, delta in val:
+ if (
+ not _check_same_coordinates(de, self.eps_inf)
+ or not _check_same_coordinates(f, self.eps_inf)
+ or not _check_same_coordinates(delta, self.eps_inf)
+ ):
+ raise SetupError(
+ "All terms in 'coeffs' must have the same coordinates; "
+ "The coordinates must also be consistent with 'eps_inf'."
+ )
+ if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((de, f, delta)):
+ raise SetupError("All terms in 'coeffs' must be real.")
+ return self
+
+ @field_validator("coeffs")
+ @classmethod
+ def _coeffs_delta_all_smaller_or_larger_than_fi(
+ cls, val: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]
+ ) -> tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]:
+ """We restrict either all 'f**2 > delta**2' or all 'f**2 < delta**2'."""
+ for _, f, delta in val:
+ if not (
+ Lorentz._all_larger(f**2, delta**2) or Lorentz._all_larger(delta**2, f**2)
+ ):
+ raise SetupError(
+ "Coefficients 'delta_i' must be either all smaller or all larger "
+ "than 'f_i' in magnitude, i.e. either all 'delta**2 < f**2' or "
+ "all 'delta**2 > f**2'."
+ )
+ return val
+
+ @model_validator(mode="after")
+ def _passivity_validation(self) -> Self:
+ """Assert passive medium if ``allow_gain`` is False."""
+ val = self.coeffs
+ allow_gain = self.allow_gain
+ for del_ep, _, delta in val:
+ if np.any(_get_numpy_array(delta) < 0):
+ raise ValidationError("For stable medium, 'delta_i' must be non-negative.")
+ if not allow_gain and np.any(_get_numpy_array(del_ep) < 0):
+ raise ValidationError(
+ "For passive medium, 'Delta epsilon_i' must be non-negative. "
+ "To simulate a gain medium, please set 'allow_gain=True'. "
+ "Caution: simulations with a gain medium are unstable, "
+ "and are likely to diverge."
+ )
+ return self
+
+ @cached_property
+ def is_spatially_uniform(self) -> bool:
+ """Whether the medium is spatially uniform."""
+ if not self.eps_inf.is_uniform:
+ return False
+ for coeffs in self.coeffs:
+ for coeff in coeffs:
+ if not coeff.is_uniform:
+ return False
+ return True
+
+ def eps_dataarray_freq(
+ self, frequency: float
+ ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]:
+ """Permittivity array at ``frequency``.
+
+ Parameters
+ ----------
+ frequency : float
+ Frequency to evaluate permittivity at (Hz).
+
+ Returns
+ -------
+ tuple[
+ Union[
+ :class:`.SpatialDataArray`,
+ :class:`.TriangularGridDataset`,
+ :class:`.TetrahedralGridDataset`,
+ ],
+ Union[
+ :class:`.SpatialDataArray`,
+ :class:`.TriangularGridDataset`,
+ :class:`.TetrahedralGridDataset`,
+ ],
+ Union[
+ :class:`.SpatialDataArray`,
+ :class:`.TriangularGridDataset`,
+ :class:`.TetrahedralGridDataset`,
+ ],
+ ]
+ The permittivity evaluated at ``frequency``.
+ """
+ eps = Lorentz.eps_model(self, frequency)
+ return (eps, eps, eps)
+
+ def _sel_custom_data_inside(self, bounds: Bound) -> Self:
+ """Return a new custom medium that contains the minimal amount of data necessary
+ to cover a spatial region defined by ``bounds``.
+
+ Parameters
+ ----------
+ bounds : tuple[float, float, float], tuple[float, float, float]
+ Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+ Returns
+ -------
+ CustomLorentz
+ CustomLorentz with reduced data.
+ """
+ if not self.eps_inf.does_cover(bounds=bounds):
+ log.warning("Eps inf spatial data array does not fully cover the requested region.")
+ eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds)
+ coeffs_reduced = []
+ for de, f, delta in self.coeffs:
+ if not de.does_cover(bounds=bounds):
+ log.warning(
+ "Lorentz 'de' spatial data array does not fully cover the requested region."
+ )
+
+ if not f.does_cover(bounds=bounds):
+ log.warning(
+ "Lorentz 'f' spatial data array does not fully cover the requested region."
+ )
+
+ if not delta.does_cover(bounds=bounds):
+ log.warning(
+ "Lorentz 'delta' spatial data array does not fully cover the requested region."
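+ # NOTE: these coverage warnings are advisory; the sel_inside() calls
+ # below still reduce each array to the overlapping region.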
+ ) + + coeffs_reduced.append( + (de.sel_inside(bounds), f.sel_inside(bounds), delta.sel_inside(bounds)) + ) + + return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=tuple(coeffs_reduced)) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Adjoint derivatives for CustomLorentz via analytic chain rule.""" + + # complex epsilon sensitivity over xyz aligned to eps_inf grid + dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=self.eps_inf) + + grads: AutogradFieldMap = {} + + # eps_inf path + if ("eps_inf",) in derivative_info.paths: + grads[("eps_inf",)] = np.real(dJ) + + # per-coefficient contributions + for i, (de_da, f0_da, dl_da) in enumerate(self.coeffs): + need_de = ("coeffs", i, 0) in derivative_info.paths + need_f0 = ("coeffs", i, 1) in derivative_info.paths + need_dl = ("coeffs", i, 2) in derivative_info.paths + if not (need_de or need_f0 or need_dl): + continue + + de = np.array(de_da.values, dtype=float) + f0 = np.array(f0_da.values, dtype=float) + dl = np.array(dl_da.values, dtype=float) + + g_de = 0.0 if not need_de else np.zeros_like(de, dtype=float) + g_f0 = 0.0 if not need_f0 else np.zeros_like(f0, dtype=float) + g_dl = 0.0 if not need_dl else np.zeros_like(dl, dtype=float) + + if need_de: + g_de = g_de + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, f0=f0, dl=dl: Lorentz._w_de(f, f0, dl), + ) + if need_f0: + # d/d f0 of (de f0^2 / den) = (2 de f0 (den - f0^2)) / den^2 + g_f0 = g_f0 + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, de=de, f0=f0, dl=dl: Lorentz._w_f0(f, de, f0, dl), + ) + if need_dl: + # d/d delta of (de f0^2 / den) = (2 j f de f0^2) / den^2 + g_dl = g_dl + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, de=de, f0=f0, dl=dl: Lorentz._w_delta(f, de, f0, dl), + ) + + if need_de: + grads[("coeffs", i, 0)] = g_de + if need_f0: + grads[("coeffs", i, 1)] = g_f0 + if need_dl: + grads[("coeffs", i, 2)] = g_dl + + return grads + + +class Drude(DispersiveMedium): + """A dispersive medium described by the Drude model. + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. math:: + + \\epsilon(f) = \\epsilon_\\infty - \\sum_i + \\frac{ f_i^2}{f^2 + jf\\delta_i} + + Example + ------- + >>> drude_medium = Drude(eps_inf=2.0, coeffs=[(1,2), (3,4)]) + >>> eps = drude_medium.eps_model(200e12) + + See Also + -------- + + :class:`CustomDrude`: + A spatially varying dispersive medium described by the Drude model. + + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: PositiveFloat = Field( + 1.0, + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + coeffs: tuple[tuple[float, PositiveFloat], ...] 
= Field( + title="Coefficients", + description="List of (:math:`f_i, \\delta_i`) values for model.", + units=(HERTZ, HERTZ), + ) + + _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() + _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + + eps = self.eps_inf + 0.0j + for f, delta in self.coeffs: + eps = eps - (f**2) / (frequency**2 + 1j * frequency * delta) + return eps + + # --- unified helpers for autograd + tests --- + + def _pole_residue_dict(self) -> dict: + """Dict representation of Medium as a pole-residue model.""" + + poles = [] + + for f, delta in self.coeffs: + w = 2 * np.pi * f + d = 2 * np.pi * delta + + c0 = (w**2) / 2 / d + 0j + c1 = -c0 + a1 = -d + 0j + + if isinstance(c0, complex): + a0 = 0j + else: + a0 = 0 * c0 + + poles.extend(((a0, c0), (a1, c1))) + + return { + "eps_inf": self.eps_inf, + "poles": poles, + "frequency_range": self.frequency_range, + "name": self.name, + } + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Adjoint derivatives for Drude params via TJP through eps_model().""" + + f, vec = self._tjp_inputs(derivative_info) + + N = len(self.coeffs) + if N == 0 and ("eps_inf",) not in derivative_info.paths: + return {} + + # pack into flat [eps_inf, fp..., delta...] + eps_inf0 = float(self.eps_inf) + fp0 = np.array([float(fp) for (fp, _d) in self.coeffs]) if N else np.array([]) + d0 = np.array([float(dd) for (_fp, dd) in self.coeffs]) if N else np.array([]) + theta0 = np.concatenate([np.array([eps_inf0]), fp0, d0]) + + def _eps_vec(theta: Sequence[PositiveFloat]) -> Union[NDArray, ArrayBox]: + eps_inf = theta[0] + fp = theta[1 : 1 + N] + dd = theta[1 + N : 1 + 2 * N] + coeffs = tuple((fp[i], dd[i]) for i in range(N)) + eps = self.updated_copy(eps_inf=eps_inf, coeffs=coeffs, validate=False).eps_model(f) + return pack_complex_vec(eps) + + g = self._tjp_grad(theta0, _eps_vec, vec) + + mapping = [(("eps_inf",), 0)] + base = 1 + mapping += [(("coeffs", i, 0), base + i) for i in range(N)] + mapping += [(("coeffs", i, 1), base + N + i) for i in range(N)] + return self._map_grad_real(g, derivative_info.paths, mapping) + + @staticmethod + def _den( + freq: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + return (freq**2) + 1j * (freq * delta) + + # frequency weights for custom Drude + @staticmethod + def _w_fp( + freq: Union[float, ArrayFloat], + fp: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + return -(2.0 * fp) / Drude._den(freq, delta) + + @staticmethod + def _w_delta( + freq: Union[float, ArrayFloat], + fp: Union[float, ArrayFloat], + delta: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + den = Drude._den(freq, delta) + return (1j * freq * (fp**2)) / (den**2) + + +class CustomDrude(CustomDispersiveMedium, Drude): + """A spatially varying dispersive medium described by the Drude model. + + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. 
math:: + + \\epsilon(f) = \\epsilon_\\infty - \\sum_i + \\frac{ f_i^2}{f^2 + jf\\delta_i} + + Example + ------- + >>> x = np.linspace(-1, 1, 5) + >>> y = np.linspace(-1, 1, 6) + >>> z = np.linspace(-1, 1, 7) + >>> coords = dict(x=x, y=y, z=z) + >>> eps_inf = SpatialDataArray(np.ones((5, 6, 7)), coords=coords) + >>> f1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> delta1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> drude_medium = CustomDrude(eps_inf=eps_inf, coeffs=[(f1,delta1),]) + >>> eps = drude_medium.eps_model(200e12) + + See Also + -------- + + :class:`Drude`: + A dispersive medium described by the Drude model. + + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: CustomSpatialDataTypeAnnotated = Field( + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( + Field( + title="Coefficients", + description="List of (:math:`f_i, \\delta_i`) values for model.", + units=(HERTZ, HERTZ), + ) + ) + + _no_nans = validate_no_nans("eps_inf", "coeffs") + _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") + + @field_validator("eps_inf") + @classmethod + def _eps_inf_positive(cls, val: TracedPositiveFloat) -> TracedPositiveFloat: + """eps_inf must be positive""" + if not CustomDispersiveMedium._validate_isreal_dataarray(val): + raise SetupError("'eps_inf' must be real.") + if np.any(_get_numpy_array(val) < 0): + raise SetupError("'eps_inf' must be positive.") + return val + + @model_validator(mode="after") + def _coeffs_correct_shape_and_sign(self) -> Self: + """coeffs must have consistent shape and sign.""" + val = self.coeffs + for f, delta in val: + if not _check_same_coordinates(f, self.eps_inf) or not _check_same_coordinates( + delta, self.eps_inf + ): + raise SetupError( + "All terms in 'coeffs' must have the same coordinates; " + "The coordinates must also be consistent with 'eps_inf'." + ) + if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((f, delta)): + raise SetupError("All terms in 'coeffs' must be real.") + if np.any(_get_numpy_array(delta) <= 0): + raise SetupError("For stable medium, 'delta' must be positive.") + return self + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + if not self.eps_inf.is_uniform: + return False + for coeffs in self.coeffs: + for coeff in coeffs: + if not coeff.is_uniform: + return False + return True + + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[ + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + ] + The permittivity evaluated at ``frequency``. 
+ """ + eps = Drude.eps_model(self, frequency) + return (eps, eps, eps) + + def _sel_custom_data_inside(self, bounds: Bound) -> Self: + """Return a new custom medium that contains the minimal amount data necessary to cover + a spatial region defined by ``bounds``. + + + Parameters + ---------- + bounds : tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + CustomDrude + CustomDrude with reduced data. + """ + if not self.eps_inf.does_cover(bounds=bounds): + log.warning("Eps inf spatial data array does not fully cover the requested region.") + eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds) + coeffs_reduced = [] + for f, delta in self.coeffs: + if not f.does_cover(bounds=bounds): + log.warning( + "Drude 'f' spatial data array does not fully cover the requested region." + ) + + if not delta.does_cover(bounds=bounds): + log.warning( + "Drude 'delta' spatial data array does not fully cover the requested region." + ) + + coeffs_reduced.append((f.sel_inside(bounds), delta.sel_inside(bounds))) + + return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=tuple(coeffs_reduced)) + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Adjoint derivatives for CustomDrude via analytic chain rule.""" + + dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=self.eps_inf) + + grads: AutogradFieldMap = {} + if ("eps_inf",) in derivative_info.paths: + grads[("eps_inf",)] = np.real(dJ) + + for i, (fp_da, dl_da) in enumerate(self.coeffs): + need_fp = ("coeffs", i, 0) in derivative_info.paths + need_dl = ("coeffs", i, 1) in derivative_info.paths + if not (need_fp or need_dl): + continue + + fp = np.array(fp_da.values, dtype=float) + dl = np.array(dl_da.values, dtype=float) + + g_fp = 0.0 if not need_fp else np.zeros_like(fp, dtype=float) + g_dl = 0.0 if not need_dl else np.zeros_like(dl, dtype=float) + + if need_fp: + g_fp = g_fp + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, fp=fp, dl=dl: Drude._w_fp(f, fp, dl), + ) + if need_dl: + g_dl = g_dl + self._sum_over_freqs( + derivative_info.frequencies, + dJ, + weight_fn=lambda f, fp=fp, dl=dl: Drude._w_delta(f, fp, dl), + ) + + if need_fp: + grads[("coeffs", i, 0)] = g_fp + if need_dl: + grads[("coeffs", i, 1)] = g_dl + + return grads + + +class Debye(DispersiveMedium): + """A dispersive medium described by the Debye model. + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. math:: + + \\epsilon(f) = \\epsilon_\\infty + \\sum_i + \\frac{\\Delta\\epsilon_i}{1 - jf\\tau_i} + + Example + ------- + >>> debye_medium = Debye(eps_inf=2.0, coeffs=[(1,2),(3,4)]) + >>> eps = debye_medium.eps_model(200e12) + + See Also + -------- + + :class:`CustomDebye` + A spatially varying dispersive medium described by the Debye model. + + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: PositiveFloat = Field( + 1.0, + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + coeffs: tuple[tuple[float, PositiveFloat], ...] 
= Field( + title="Coefficients", + description="List of (:math:`\\Delta\\epsilon_i, \\tau_i`) values for model.", + units=(PERMITTIVITY, SECOND), + ) + + @model_validator(mode="after") + def _passivity_validation(self) -> Self: + """Assert passive medium if `allow_gain` is False.""" + val = self.coeffs + if self.allow_gain: + return self + for del_ep, _ in val: + if del_ep < 0: + raise ValidationError( + "For passive medium, 'Delta epsilon_i' must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + return self + + _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() + _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + + eps = self.eps_inf + 0.0j + for de, tau in self.coeffs: + eps = eps + de / (1 - 1j * frequency * tau) + return eps + + # --- unified helpers for autograd + tests --- + + def _pole_residue_dict( + self, + ) -> dict[ + str, Union[PositiveFloat, list[tuple[complex, complex]], Optional[FreqBound], Optional[str]] + ]: + """Dict representation of Medium as a pole-residue model.""" + + poles = [] + eps_inf = self.eps_inf + for de, tau in self.coeffs: + # for |tau| close to 0, it's equivalent to modifying eps_inf + if np.any(abs(_get_numpy_array(tau)) < 1 / 2 / np.pi / LARGEST_FP_NUMBER): + eps_inf = eps_inf + de + else: + a = -2 * np.pi / tau + 0j + c = -0.5 * de * a + + poles.append((a, c)) + + return { + "eps_inf": eps_inf, + "poles": poles, + "frequency_range": self.frequency_range, + "name": self.name, + } + + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Adjoint derivatives for Debye params via TJP through eps_model().""" + + f, vec = self._tjp_inputs(derivative_info) + + N = len(self.coeffs) + if N == 0 and ("eps_inf",) not in derivative_info.paths: + return {} + + # pack into flat [eps_inf, de..., tau...] 
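+ # layout: theta[0] -> eps_inf, theta[1 : 1 + N] -> de_i,
+ # theta[1 + N : 1 + 2N] -> tau_i; _eps_vec below relies on this ordering.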
+ eps_inf0 = float(self.eps_inf) + de0 = np.array([float(de) for (de, _t) in self.coeffs]) if N else np.array([]) + tau0 = np.array([float(t) for (_de, t) in self.coeffs]) if N else np.array([]) + theta0 = np.concatenate([np.array([eps_inf0]), de0, tau0]) + + def _eps_vec(theta: Sequence[PositiveFloat]) -> Union[NDArray, ArrayBox]: + eps_inf = theta[0] + de = theta[1 : 1 + N] + tau = theta[1 + N : 1 + 2 * N] + coeffs = tuple((de[i], tau[i]) for i in range(N)) + eps = self.updated_copy(eps_inf=eps_inf, coeffs=coeffs, validate=False).eps_model(f) + return pack_complex_vec(eps) + + g = self._tjp_grad(theta0, _eps_vec, vec) + + mapping = [(("eps_inf",), 0)] + base = 1 + mapping += [(("coeffs", i, 0), base + i) for i in range(N)] + mapping += [(("coeffs", i, 1), base + N + i) for i in range(N)] + return self._map_grad_real(g, derivative_info.paths, mapping) + + @staticmethod + def _den( + freq: Union[float, ArrayFloat], + tau: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + return 1 - 1j * (freq * tau) + + # frequency weights for custom Debye + @staticmethod + def _w_de( + freq: Union[float, ArrayFloat], + tau: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + return 1.0 / Debye._den(freq, tau) + + @staticmethod + def _w_tau( + freq: Union[float, ArrayFloat], + de: Union[float, ArrayFloat], + tau: Union[float, ArrayFloat], + ) -> Union[complex, ArrayComplex]: + den = Debye._den(freq, tau) + return (1j * freq * de) / (den**2) + + +class CustomDebye(CustomDispersiveMedium, Debye): + """A spatially varying dispersive medium described by the Debye model. + + Notes + ----- + + The frequency-dependence of the complex-valued permittivity is described by: + + .. math:: + + \\epsilon(f) = \\epsilon_\\infty + \\sum_i + \\frac{\\Delta\\epsilon_i}{1 - jf\\tau_i} + + Example + ------- + >>> x = np.linspace(-1, 1, 5) + >>> y = np.linspace(-1, 1, 6) + >>> z = np.linspace(-1, 1, 7) + >>> coords = dict(x=x, y=y, z=z) + >>> eps_inf = SpatialDataArray(1+np.random.random((5, 6, 7)), coords=coords) + >>> eps1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> tau1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) + >>> debye_medium = CustomDebye(eps_inf=eps_inf, coeffs=[(eps1,tau1),]) + >>> eps = debye_medium.eps_model(200e12) + + See Also + -------- + + :class:`Debye` + A dispersive medium described by the Debye model. + + **Notebooks** + * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ + + **Lectures** + * `Modeling dispersive material in FDTD `_ + """ + + eps_inf: CustomSpatialDataTypeAnnotated = Field( + title="Epsilon at Infinity", + description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", + units=PERMITTIVITY, + ) + + coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] 
= (
+ Field(
+ title="Coefficients",
+ description="List of (:math:`\\Delta\\epsilon_i, \\tau_i`) values for model.",
+ units=(PERMITTIVITY, SECOND),
+ )
+ )
+
+ _no_nans = validate_no_nans("eps_inf", "coeffs")
+ _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs")
+
+ @field_validator("eps_inf")
+ @classmethod
+ def _eps_inf_positive(cls, val: CustomSpatialDataType) -> CustomSpatialDataType:
+ """eps_inf must be positive"""
+ if not CustomDispersiveMedium._validate_isreal_dataarray(val):
+ raise SetupError("'eps_inf' must be real.")
+ if np.any(_get_numpy_array(val) < 0):
+ raise SetupError("'eps_inf' must be positive.")
+ return val
+
+ @model_validator(mode="after")
+ def _coeffs_correct_shape(self) -> Self:
+ """coeffs must have consistent shape."""
+ val = self.coeffs
+ for de, tau in val:
+ if not _check_same_coordinates(de, self.eps_inf) or not _check_same_coordinates(
+ tau, self.eps_inf
+ ):
+ raise SetupError(
+ "All terms in 'coeffs' must have the same coordinates; "
+ "The coordinates must also be consistent with 'eps_inf'."
+ )
+ if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((de, tau)):
+ raise SetupError("All terms in 'coeffs' must be real.")
+ return self
+
+ @field_validator("coeffs")
+ @classmethod
+ def _coeffs_tau_all_sufficient_positive(
+ cls, val: tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]
+ ) -> tuple[tuple[CustomSpatialDataType, CustomSpatialDataType], ...]:
+ """We restrict all 'tau' to be sufficiently greater than 0."""
+ for _, tau in val:
+ if np.any(_get_numpy_array(tau) < 1 / 2 / np.pi / LARGEST_FP_NUMBER):
+ raise SetupError(
+ "Coefficients 'tau_i' are restricted to be sufficiently greater than 0."
+ )
+ return val
+
+ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+ """Adjoint derivatives for CustomDebye via analytic chain rule."""
+
+ dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=self.eps_inf)
+
+ grads: AutogradFieldMap = {}
+ if ("eps_inf",) in derivative_info.paths:
+ grads[("eps_inf",)] = np.real(dJ)
+
+ for i, (de_da, tau_da) in enumerate(self.coeffs):
+ need_de = ("coeffs", i, 0) in derivative_info.paths
+ need_tau = ("coeffs", i, 1) in derivative_info.paths
+ if not (need_de or need_tau):
+ continue
+
+ de = np.array(de_da.values, dtype=float)
+ tau = np.array(tau_da.values, dtype=float)
+
+ g_de = 0.0 if not need_de else np.zeros_like(de, dtype=float)
+ g_tau = 0.0 if not need_tau else np.zeros_like(tau, dtype=float)
+
+ if need_de:
+ g_de = g_de + self._sum_over_freqs(
+ derivative_info.frequencies,
+ dJ,
+ weight_fn=lambda f, tau=tau: Debye._w_de(f, tau),
+ )
+ if need_tau:
+ g_tau = g_tau + self._sum_over_freqs(
+ derivative_info.frequencies,
+ dJ,
+ weight_fn=lambda f, de=de, tau=tau: Debye._w_tau(f, de, tau),
+ )
+
+ if need_de:
+ grads[("coeffs", i, 0)] = g_de
+ if need_tau:
+ grads[("coeffs", i, 1)] = g_tau
+
+ return grads
+
+ @model_validator(mode="after")
+ def _passivity_validation(self) -> Self:
+ """Assert passive medium if ``allow_gain`` is False."""
+ val = self.coeffs
+ allow_gain = self.allow_gain
+ for del_ep, tau in val:
+ if np.any(_get_numpy_array(tau) <= 0):
+ raise SetupError("For stable medium, 'tau_i' must be positive.")
+ if not allow_gain and np.any(_get_numpy_array(del_ep) < 0):
+ raise ValidationError(
+ "For passive medium, 'Delta epsilon_i' must be non-negative. "
+ "To simulate a gain medium, please set 'allow_gain=True'. "
" + "Caution: simulations with a gain medium are unstable, " + "and are likely to diverge." + ) + return self + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + if not self.eps_inf.is_uniform: + return False + for coeffs in self.coeffs: + for coeff in coeffs: + if not coeff.is_uniform: + return False + return True + + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[ + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + ] + The permittivity evaluated at ``frequency``. + """ + eps = Debye.eps_model(self, frequency) + return (eps, eps, eps) + + def _sel_custom_data_inside(self, bounds: Bound) -> Self: + """Return a new custom medium that contains the minimal amount data necessary to cover + a spatial region defined by ``bounds``. + + + Parameters + ---------- + bounds : tuple[float, float, float], tuple[float, float float] + Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. + + Returns + ------- + CustomDebye + CustomDebye with reduced data. + """ + if not self.eps_inf.does_cover(bounds=bounds): + log.warning("Eps inf spatial data array does not fully cover the requested region.") + eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds) + coeffs_reduced = [] + for de, tau in self.coeffs: + if not de.does_cover(bounds=bounds): + log.warning( + "Debye 'f' spatial data array does not fully cover the requested region." + ) + + if not tau.does_cover(bounds=bounds): + log.warning( + "Debye 'tau' spatial data array does not fully cover the requested region." 
+ ) + + coeffs_reduced.append((de.sel_inside(bounds), tau.sel_inside(bounds))) + + return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=tuple(coeffs_reduced)) + + +_ISOTROPIC_UNIFORM_MEDIUM_CORE_TYPES = ( + Medium, + PoleResidue, + Sellmeier, + Lorentz, + Debye, + Drude, + PECMedium, + PMCMedium, +) + +_ISOTROPIC_UNIFORM_MEDIUM_EXTRA_TYPES: list[type[Any]] = [] + + +def _build_isotropic_uniform_medium_type() -> object: + return Union[ + _ISOTROPIC_UNIFORM_MEDIUM_CORE_TYPES + tuple(_ISOTROPIC_UNIFORM_MEDIUM_EXTRA_TYPES) + ] + + +IsotropicUniformMediumType = _build_isotropic_uniform_medium_type() + + +def extend_isotropic_uniform_medium_type(*extra_types: type[Any]) -> None: + """Extend ``IsotropicUniformMediumType`` and rebuild dependent models.""" + for extra_type in extra_types: + if extra_type not in _ISOTROPIC_UNIFORM_MEDIUM_EXTRA_TYPES: + _ISOTROPIC_UNIFORM_MEDIUM_EXTRA_TYPES.append(extra_type) + global IsotropicUniformMediumType + IsotropicUniformMediumType = _build_isotropic_uniform_medium_type() + if "AnisotropicMedium" in globals(): + AnisotropicMedium.model_rebuild( + force=True, _types_namespace={"IsotropicUniformMediumType": IsotropicUniformMediumType} + ) + + +IsotropicCustomMediumType = Union[ + CustomPoleResidue, + CustomSellmeier, + CustomLorentz, + CustomDebye, + CustomDrude, +] + +IsotropicCustomMediumInternalType = Union[IsotropicCustomMediumType, CustomIsotropicMedium] + + +class AnisotropicMedium(AbstractMedium): + """Diagonally anisotropic medium. + + Notes + ----- + + Only diagonal anisotropy is currently supported. + + Example + ------- + >>> medium_xx = Medium(permittivity=4.0) + >>> medium_yy = Medium(permittivity=4.1) + >>> medium_zz = Medium(permittivity=3.9) + >>> anisotropic_dielectric = AnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz) + + See Also + -------- + + :class:`CustomAnisotropicMedium` + Diagonally anisotropic medium with spatially varying permittivity in each component. + + :class:`FullyAnisotropicMedium` + Fully anisotropic medium including all 9 components of the permittivity and conductivity tensors. + + **Notebooks** + * `Broadband polarizer assisted by anisotropic metamaterial <../../notebooks/SWGBroadbandPolarizer.html>`_ + * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_ + """ + + xx: IsotropicUniformMediumType = Field( + title="XX Component", + description="Medium describing the xx-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + yy: IsotropicUniformMediumType = Field( + title="YY Component", + description="Medium describing the yy-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + zz: IsotropicUniformMediumType = Field( + title="ZZ Component", + description="Medium describing the zz-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + allow_gain: Optional[bool] = Field( + None, + title="Allow gain medium", + description="This field is ignored. Please set ``allow_gain`` in each component", + ) + + @field_validator("modulation_spec") + @classmethod + def _validate_modulation_spec(cls, val: Optional[ModulationSpec]) -> Optional[ModulationSpec]: + """Check compatibility with modulation_spec.""" + if val is not None: + raise ValidationError( + f"A 'modulation_spec' of class {type(val)} is not " + f"currently supported for medium class {cls.__name__}. " + "Please add modulation to each component." 
+ )
+ return val
+
+ @model_validator(mode="after")
+ def _ignored_fields(self) -> Self:
+ """Warn if the ignored 'allow_gain' field is set."""
+ if self.xx is not None and self.allow_gain is not None:
+ log.warning(
+ "The field 'allow_gain' is ignored. Please set 'allow_gain' in each component."
+ )
+ return self
+
+ @cached_property
+ def components(self) -> dict[str, Medium]:
+ """Dictionary of diagonal medium components."""
+ return {"xx": self.xx, "yy": self.yy, "zz": self.zz}
+
+ @cached_property
+ def is_time_modulated(self) -> bool:
+ """Whether any component of the medium is time modulated."""
+ return any(mat.is_time_modulated for mat in self.components.values())
+
+ @cached_property
+ def n_cfl(self) -> float:
+ """This property computes the index of refraction related to CFL condition, so that
+ the FDTD with this medium is stable when the time step size that doesn't take
+ material factor into account is multiplied by ``n_cfl``.
+
+ For this medium, it takes the minimum of ``n_cfl`` over all components.
+ """
+ return min(mat_component.n_cfl for mat_component in self.components.values())
+
+ @ensure_freq_in_range
+ def eps_model(self, frequency: float) -> complex:
+ """Complex-valued permittivity as a function of frequency."""
+
+ return np.mean(self.eps_diagonal(frequency), axis=0)
+
+ @ensure_freq_in_range
+ def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]:
+ """Main diagonal of the complex-valued permittivity tensor as a function of frequency."""
+
+ eps_xx = self.xx.eps_model(frequency)
+ eps_yy = self.yy.eps_model(frequency)
+ eps_zz = self.zz.eps_model(frequency)
+ return (eps_xx, eps_yy, eps_zz)
+
+ def eps_comp(self, row: Axis, col: Axis, frequency: float) -> complex:
+ """Single component of the complex-valued permittivity tensor as a function of frequency.
+
+ Parameters
+ ----------
+ row : int
+ Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively).
+ col : int
+ Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively).
+ frequency : float
+ Frequency to evaluate permittivity at (Hz).
+
+ Returns
+ -------
+ complex
+ Element of the relative permittivity tensor evaluated at ``frequency``.
+ """
+
+ if row != col:
+ return 0j
+ cmp = "xyz"[row]
+ field_name = cmp + cmp
+ return self.components[field_name].eps_model(frequency)
+
+ def _eps_plot(
+ self, frequency: float, eps_component: Optional[PermittivityComponent] = None
+ ) -> float:
+ """Returns real part of epsilon for plotting. A specific component of the epsilon tensor
+ can be selected for an anisotropic medium.
+
+ Parameters
+ ----------
+ frequency : float
+ eps_component : PermittivityComponent
+
+ Returns
+ -------
+ float
+ Element ``eps_component`` of the relative permittivity tensor evaluated at ``frequency``.
+ """
+ if eps_component is None:
+ # return the average of the diag
+ return self.eps_model(frequency).real
+ if eps_component in ["xx", "yy", "zz"]:
+ # return the requested diagonal component
+ comp2indx = {"x": 0, "y": 1, "z": 2}
+ return self.eps_comp(
+ row=comp2indx[eps_component[0]],
+ col=comp2indx[eps_component[1]],
+ frequency=frequency,
+ ).real
+ raise ValueError(
+ f"Plotting component '{eps_component}' of a diagonally-anisotropic permittivity tensor is not supported."
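+ # off-diagonal entries like 'xy' are identically zero for this class
+ # (see eps_comp), so only 'xx', 'yy', 'zz', or None are accepted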
+ )
+
+ @add_ax_if_none
+ def plot(self, freqs: float, ax: Ax = None) -> Ax:
+ """Plot n, k of a :class:`.Medium` as a function of frequency."""
+
+ freqs = np.array(freqs)
+ freqs_thz = freqs / 1e12
+
+ for label, medium_component in self.elements.items():
+ eps_complex = medium_component.eps_model(freqs)
+ n, k = AbstractMedium.eps_complex_to_nk(eps_complex)
+ ax.plot(freqs_thz, n, label=f"n, eps_{label}")
+ ax.plot(freqs_thz, k, label=f"k, eps_{label}")
+
+ ax.set_xlabel("frequency (THz)")
+ ax.set_title("medium dispersion")
+ ax.legend()
+ ax.set_aspect("auto")
+ return ax
+
+ @property
+ def elements(self) -> dict[str, IsotropicUniformMediumType]:
+ """The diagonal elements of the medium as a dictionary."""
+ return {"xx": self.xx, "yy": self.yy, "zz": self.zz}
+
+ @cached_property
+ def is_pec(self) -> bool:
+ """Whether the medium is a PEC."""
+ return any(self.is_comp_pec(i) for i in range(3))
+
+ @cached_property
+ def is_pmc(self) -> bool:
+ """Whether the medium is a PMC."""
+ return any(self.is_comp_pmc(i) for i in range(3))
+
+ def is_comp_pec(self, comp: Axis) -> bool:
+ """Whether the ``comp`` component of the medium is a PEC."""
+ return isinstance(self.components[["xx", "yy", "zz"][comp]], PECMedium)
+
+ def is_comp_pmc(self, comp: Axis) -> bool:
+ """Whether the ``comp`` component of the medium is a PMC."""
+ return isinstance(self.components[["xx", "yy", "zz"][comp]], PMCMedium)
+
+ def sel_inside(self, bounds: Bound) -> Self:
+ """Return a new medium that contains the minimal amount of data necessary to cover
+ a spatial region defined by ``bounds``.
+
+ Parameters
+ ----------
+ bounds : tuple[float, float, float], tuple[float, float, float]
+ Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
+
+ Returns
+ -------
+ AnisotropicMedium
+ AnisotropicMedium with reduced data.
+ """
+
+ new_comps = [comp.sel_inside(bounds) for comp in [self.xx, self.yy, self.zz]]
+
+ return self.updated_copy(**dict(zip(["xx", "yy", "zz"], new_comps)))
+
+ # --- shared autograd helpers ---
+ @staticmethod
+ def _component_derivative_info(
+ derivative_info: DerivativeInfo, component: str
+ ) -> DerivativeInfo | None:
+ """Build ``DerivativeInfo`` filtered to a single anisotropic component."""
+
+ component_paths = [
+ tuple(path[1:]) for path in derivative_info.paths if path and path[0] == component
+ ]
+ if not component_paths:
+ return None
+
+ axis = component[0]  # e.g. "xx" -> "x"
+ projected_E = derivative_info.project_der_map_to_axis(axis, "E")
+ projected_D = derivative_info.project_der_map_to_axis(axis, "D")
+ return derivative_info.updated_copy(
+ paths=component_paths, E_der_map=projected_E, D_der_map=projected_D
+ )
+
+ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+ """Delegate derivatives for each diagonal component of an anisotropic medium."""
+
+ components = self.components
+ for field_path in derivative_info.paths:
+ if len(field_path) < 2 or field_path[0] not in components:
+ raise NotImplementedError(
+ f"No derivative defined for '{type(self).__name__}' field: {field_path}."
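+ # valid paths look like ('xx', <field>, ...); anything else cannot be
+ # routed to a diagonal component and is rejected up front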
+ ) + + vjps: AutogradFieldMap = {} + for comp_name, component in components.items(): + comp_info = self._component_derivative_info( + derivative_info=derivative_info, component=comp_name + ) + if comp_info is None: + continue + comp_vjps = component._compute_derivatives(comp_info) + for sub_path, value in comp_vjps.items(): + vjps[(comp_name, *sub_path)] = value + + return vjps + + +class CustomAnisotropicMedium(AbstractCustomMedium, AnisotropicMedium): + """Diagonally anisotropic medium with spatially varying permittivity in each component. + + Note + ---- + Only diagonal anisotropy is currently supported. + + Example + ------- + >>> Nx, Ny, Nz = 10, 9, 8 + >>> x = np.linspace(-1, 1, Nx) + >>> y = np.linspace(-1, 1, Ny) + >>> z = np.linspace(-1, 1, Nz) + >>> coords = dict(x=x, y=y, z=z) + >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> medium_xx = CustomMedium(permittivity=permittivity, conductivity=conductivity) + >>> medium_yy = CustomMedium(permittivity=permittivity, conductivity=conductivity) + >>> d_epsilon = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords) + >>> f = SpatialDataArray(1+np.random.random((Nx, Ny, Nz)), coords=coords) + >>> delta = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords) + >>> medium_zz = CustomLorentz(eps_inf=permittivity, coeffs=[(d_epsilon,f,delta),]) + >>> anisotropic_dielectric = CustomAnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz) + + See Also + -------- + + :class:`AnisotropicMedium` + Diagonally anisotropic medium. + + **Notebooks** + * `Broadband polarizer assisted by anisotropic metamaterial <../../notebooks/SWGBroadbandPolarizer.html>`_ + * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_ + * `Defining fully anisotropic materials <../../notebooks/FullyAnisotropic.html>`_ + """ + + xx: Union[IsotropicCustomMediumType, CustomMedium] = Field( + title="XX Component", + description="Medium describing the xx-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + yy: Union[IsotropicCustomMediumType, CustomMedium] = Field( + title="YY Component", + description="Medium describing the yy-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + zz: Union[IsotropicCustomMediumType, CustomMedium] = Field( + title="ZZ Component", + description="Medium describing the zz-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + interp_method: Optional[InterpMethod] = Field( + None, + title="Interpolation method", + description="When the value is ``None`` each component will follow its own " + "interpolation method. When the value is other than ``None`` the interpolation " + "method specified by this field will override the one in each component.", + ) + + allow_gain: Optional[bool] = Field( + None, + title="Allow gain medium", + description="This field is ignored. Please set ``allow_gain`` in each component", + ) + + subpixel: Optional[bool] = Field( + None, + title="Subpixel averaging", + description="This field is ignored. 
Please set ``subpixel`` in each component", + ) + + @field_validator("xx", "yy", "zz") + @classmethod + def _isotropic_xx( + cls, val: Union[IsotropicCustomMediumType, CustomMedium], info: FieldValidationInfo + ) -> Union[IsotropicCustomMediumType, CustomMedium]: + """If it's `CustomMedium`, make sure it's isotropic.""" + if isinstance(val, CustomMedium) and not val.is_isotropic: + raise SetupError(f"The {info.field_name}-component medium type is not isotropic.") + return val + + @model_validator(mode="after") + def _ignored_fields(self) -> Self: + """The field is ignored.""" + if self.xx is not None: + if self.allow_gain is not None: + log.warning( + "The field 'allow_gain' is ignored. Please set 'allow_gain' in each component." + ) + if self.subpixel is not None: + log.warning( + "The field 'subpixel' is ignored. Please set 'subpixel' in each component." + ) + return self + + @cached_property + def is_spatially_uniform(self) -> bool: + """Whether the medium is spatially uniform.""" + return any(comp.is_spatially_uniform for comp in self.components.values()) + + @cached_property + def n_cfl(self) -> float: + """This property computes the index of refraction related to CFL condition, so that + the FDTD with this medium is stable when the time step size that doesn't take + material factor into account is multiplied by ``n_cfl``. + + For this medium, it takes the minimal of ``n_clf`` in all components. + """ + return min(mat_component.n_cfl for mat_component in self.components.values()) + + @cached_property + def is_isotropic(self) -> bool: + """Whether the medium is isotropic.""" + return False + + def _interp_method(self, comp: Axis) -> InterpMethod: + """Interpolation method applied to comp.""" + # override `interp_method` in components if self.interp_method is not None + if self.interp_method is not None: + return self.interp_method + # use component's interp_method + comp_map = ["xx", "yy", "zz"] + return self.components[comp_map[comp]].interp_method + + def eps_dataarray_freq( + self, frequency: float + ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: + """Permittivity array at ``frequency``. + + Parameters + ---------- + frequency : float + Frequency to evaluate permittivity at (Hz). + + Returns + ------- + tuple[ + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + Union[ + :class:`.SpatialDataArray`, + :class:`.TriangularGridDataset`, + :class:`.TetrahedralGridDataset`, + ], + ] + The permittivity evaluated at ``frequency``. + """ + return tuple( + mat_component.eps_dataarray_freq(frequency)[ind] + for ind, mat_component in enumerate(self.components.values()) + ) + + def _eps_bounds( + self, + frequency: Optional[float] = None, + eps_component: Optional[PermittivityComponent] = None, + ) -> tuple[float, float]: + """Returns permittivity bounds for setting the color bounds when plotting. + + Parameters + ---------- + frequency : float = None + Frequency to evaluate the relative permittivity of all mediums. + If not specified, evaluates at infinite frequency. + eps_component : Optional[PermittivityComponent] = None + Component of the permittivity tensor to plot for anisotropic materials, + e.g. ``"xx"``, ``"yy"``, ``"zz"``, ``"xy"``, ``"yz"``, ... + Defaults to ``None``, which returns the average of the diagonal values. 
+ + Returns + ------- + tuple[float, float] + The min and max values of the permittivity for the selected component and evaluated at ``frequency``. + """ + comps = ["xx", "yy", "zz"] + if eps_component in comps: + # Return the bounds of a specific component + eps_dataarray = self.eps_dataarray_freq(frequency) + eps = self._get_real_vals(eps_dataarray[comps.index(eps_component)]) + return (np.min(eps), np.max(eps)) + if eps_component is None: + # Returns the bounds across all components + return super()._eps_bounds(frequency=frequency) + raise ValueError( + f"Plotting component '{eps_component}' of a diagonally-anisotropic permittivity tensor is not supported." + ) + + def _sel_custom_data_inside(self, bounds: Bound) -> Self: + return self + + +class CustomAnisotropicMediumInternal(CustomAnisotropicMedium): + """Diagonally anisotropic medium with spatially varying permittivity in each component. + + Notes + ----- + + Only diagonal anisotropy is currently supported. + + Example + ------- + >>> Nx, Ny, Nz = 10, 9, 8 + >>> X = np.linspace(-1, 1, Nx) + >>> Y = np.linspace(-1, 1, Ny) + >>> Z = np.linspace(-1, 1, Nz) + >>> coords = dict(x=X, y=Y, z=Z) + >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) + >>> medium_xx = CustomMedium(permittivity=permittivity, conductivity=conductivity) + >>> medium_yy = CustomMedium(permittivity=permittivity, conductivity=conductivity) + >>> d_epsilon = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords) + >>> f = SpatialDataArray(1+np.random.random((Nx, Ny, Nz)), coords=coords) + >>> delta = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords) + >>> medium_zz = CustomLorentz(eps_inf=permittivity, coeffs=[(d_epsilon,f,delta),]) + >>> anisotropic_dielectric = CustomAnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz) + """ + + xx: Union[IsotropicCustomMediumInternalType, CustomMedium] = Field( + title="XX Component", + description="Medium describing the xx-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + yy: Union[IsotropicCustomMediumInternalType, CustomMedium] = Field( + title="YY Component", + description="Medium describing the yy-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + zz: Union[IsotropicCustomMediumInternalType, CustomMedium] = Field( + title="ZZ Component", + description="Medium describing the zz-component of the diagonal permittivity tensor.", + discriminator=TYPE_TAG_STR, + ) + + +class FullyAnisotropicMedium(AbstractMedium): + """Fully anisotropic medium including all 9 components of the permittivity and conductivity + tensors. + + Notes + ----- + + Provided permittivity tensor and the symmetric part of the conductivity tensor must + have coinciding main directions. A non-symmetric conductivity tensor can be used to model + magneto-optic effects. Note that dispersive properties and subpixel averaging are currently not + supported for fully anisotropic materials. + + Note + ---- + + Simulations involving fully anisotropic materials are computationally more intensive, thus, + they take longer time to complete. This increase strongly depends on the filling fraction of + the simulation domain by fully anisotropic materials, varying approximately in the range from + 1.5 to 5. The cost of running a simulation is adjusted correspondingly. 
+ + Example + ------- + >>> perm = [[2, 0, 0], [0, 1, 0], [0, 0, 3]] + >>> cond = [[0.1, 0, 0], [0, 0, 0], [0, 0, 0]] + >>> anisotropic_dielectric = FullyAnisotropicMedium(permittivity=perm, conductivity=cond) + + See Also + -------- + + :class:`CustomAnisotropicMedium` + Diagonally anisotropic medium with spatially varying permittivity in each component. + + :class:`AnisotropicMedium` + Diagonally anisotropic medium. + + **Notebooks** + * `Broadband polarizer assisted by anisotropic metamaterial <../../notebooks/SWGBroadbandPolarizer.html>`_ + * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_ + * `Defining fully anisotropic materials <../../notebooks/FullyAnisotropic.html>`_ + """ + + permittivity: TensorReal = Field( + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + title="Permittivity", + description="Relative permittivity tensor.", + units=PERMITTIVITY, + ) + + conductivity: TensorReal = Field( + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + title="Conductivity", + description="Electric conductivity tensor. Defined such that the imaginary part " + "of the complex permittivity at angular frequency omega is given by conductivity/omega.", + units=CONDUCTIVITY, + ) + + @field_validator("modulation_spec") + @classmethod + def _validate_modulation_spec(cls, val: Optional[ModulationSpec]) -> Optional[ModulationSpec]: + """Check compatibility with modulation_spec.""" + if val is not None: + raise ValidationError( + f"A 'modulation_spec' of class {type(val)} is not " + f"currently supported for medium class {cls.__name__}." + ) + return val + + @field_validator("permittivity") + @classmethod + def permittivity_spd_and_ge_one(cls, val: TracedFloat) -> TracedFloat: + """Check that provided permittivity tensor is symmetric positive definite + with eigenvalues >= 1. + """ + + if not np.allclose(val, np.transpose(val), atol=fp_eps): + raise ValidationError("Provided permittivity tensor is not symmetric.") + + if np.any(np.linalg.eigvals(val) < 1 - fp_eps): + raise ValidationError("Main diagonal of provided permittivity tensor is not >= 1.") + + return val + + @model_validator(mode="after") + def conductivity_commutes(self) -> Self: + """Check that the symmetric part of conductivity tensor commutes with permittivity tensor + (that is, simultaneously diagonalizable). + """ + + val = self.conductivity + perm = self.permittivity + cond_sym = 0.5 * (val + val.T) + comm_diff = np.abs(np.matmul(perm, cond_sym) - np.matmul(cond_sym, perm)) + + if not np.allclose(comm_diff, 0, atol=fp_eps): + raise ValidationError( + "Main directions of conductivity and permittivity tensor do not coincide." + ) + + return self + + @model_validator(mode="after") + def _passivity_validation(self) -> Self: + """Assert passive medium if ``allow_gain`` is False.""" + val = self.conductivity + if self.allow_gain: + return self + + cond_sym = 0.5 * (val + val.T) + if np.any(np.linalg.eigvals(cond_sym) < -fp_eps): + raise ValidationError( + "For passive medium, main diagonal of provided conductivity tensor " + "must be non-negative. " + "To simulate a gain medium, please set 'allow_gain=True'. " + "Caution: simulations with a gain medium are unstable, and are likely to diverge." + ) + return self + + @classmethod + def from_diagonal(cls, xx: Medium, yy: Medium, zz: Medium, rotation: RotationType) -> Self: + """Construct a fully anisotropic medium by rotating a diagonally anisotropic medium. 
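+
+ The diagonal permittivity and conductivity tensors assembled from
+ ``xx``, ``yy``, and ``zz`` are rotated by ``rotation`` to produce the
+ full tensors.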
+ + Parameters + ---------- + xx : :class:`.Medium` + Medium describing the xx-component of the diagonal permittivity tensor. + yy : :class:`.Medium` + Medium describing the yy-component of the diagonal permittivity tensor. + zz : :class:`.Medium` + Medium describing the zz-component of the diagonal permittivity tensor. + rotation : Union[:class:`.RotationAroundAxis`] + Rotation applied to diagonal permittivity tensor. + + Returns + ------- + :class:`FullyAnisotropicMedium` + Resulting fully anisotropic medium. + """ + + if any(comp.nonlinear_spec is not None for comp in [xx, yy, zz]): + raise ValidationError( + "Nonlinearities are not currently supported for the components " + "of a fully anisotropic medium." + ) + + if any(comp.modulation_spec is not None for comp in [xx, yy, zz]): + raise ValidationError( + "Modulation is not currently supported for the components " + "of a fully anisotropic medium." + ) + + permittivity_diag = np.diag([comp.permittivity for comp in [xx, yy, zz]]).tolist() + conductivity_diag = np.diag([comp.conductivity for comp in [xx, yy, zz]]).tolist() + + permittivity = rotation.rotate_tensor(permittivity_diag) + conductivity = rotation.rotate_tensor(conductivity_diag) + + return cls(permittivity=permittivity, conductivity=conductivity) + + @cached_property + def _to_diagonal(self) -> AnisotropicMedium: + """Construct a diagonally anisotropic medium from main components. + + Returns + ------- + :class:`AnisotropicMedium` + Resulting diagonally anisotropic medium. + """ + + perm, cond, _ = self.eps_sigma_diag + + return AnisotropicMedium( + xx=Medium(permittivity=perm[0], conductivity=cond[0]), + yy=Medium(permittivity=perm[1], conductivity=cond[1]), + zz=Medium(permittivity=perm[2], conductivity=cond[2]), + ) + + @cached_property + def eps_sigma_diag( + self, + ) -> tuple[tuple[float, float, float], tuple[float, float, float], TensorReal]: + """Main components of permittivity and conductivity tensors and their directions.""" + + perm_diag, vecs = np.linalg.eig(self.permittivity) + cond_diag = np.diag(np.matmul(np.transpose(vecs), np.matmul(self.conductivity, vecs))) + + return (perm_diag, cond_diag, vecs) + + @ensure_freq_in_range + def eps_model(self, frequency: float) -> complex: + """Complex-valued permittivity as a function of frequency.""" + perm_diag, cond_diag, _ = self.eps_sigma_diag + + if not np.isscalar(frequency): + perm_diag = perm_diag[:, None] + cond_diag = cond_diag[:, None] + eps_diag = AbstractMedium.eps_sigma_to_eps_complex(perm_diag, cond_diag, frequency) + return np.mean(eps_diag) + + @ensure_freq_in_range + def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: + """Main diagonal of the complex-valued permittivity tensor as a function of frequency.""" + + perm_diag, cond_diag, _ = self.eps_sigma_diag + + if not np.isscalar(frequency): + perm_diag = perm_diag[:, None] + cond_diag = cond_diag[:, None] + return AbstractMedium.eps_sigma_to_eps_complex(perm_diag, cond_diag, frequency) + + def eps_comp(self, row: Axis, col: Axis, frequency: float) -> complex: + """Single component the complex-valued permittivity tensor as a function of frequency. + + Parameters + ---------- + row : int + Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). + col : int + Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). + frequency : float + Frequency to evaluate permittivity at (Hz). 
+
+        Returns
+        -------
+        complex
+            Element of the relative permittivity tensor evaluated at ``frequency``.
+        """
+
+        eps = self.permittivity[row][col]
+        sig = self.conductivity[row][col]
+        return AbstractMedium.eps_sigma_to_eps_complex(eps, sig, frequency)
+
+    def _eps_plot(
+        self, frequency: float, eps_component: Optional[PermittivityComponent] = None
+    ) -> float:
+        """Returns the real part of epsilon for plotting. A specific component of the epsilon
+        tensor can be selected for an anisotropic medium.
+
+        Parameters
+        ----------
+        frequency : float
+            Frequency to evaluate permittivity at (Hz).
+        eps_component : Optional[PermittivityComponent] = None
+            Component of the permittivity tensor to plot, e.g. ``'xx'``.
+            If ``None``, the average of the main diagonal is used.
+
+        Returns
+        -------
+        float
+            Element ``eps_component`` of the relative permittivity tensor evaluated at ``frequency``.
+        """
+        if eps_component is None:
+            # return the average of the diag
+            return self.eps_model(frequency).real
+
+        # return the requested component
+        comp2indx = {"x": 0, "y": 1, "z": 2}
+        return self.eps_comp(
+            row=comp2indx[eps_component[0]], col=comp2indx[eps_component[1]], frequency=frequency
+        ).real
+
+    @cached_property
+    def n_cfl(self) -> float:
+        """This property computes the index of refraction related to the CFL condition, such that
+        the FDTD simulation with this medium is stable when the time step size that doesn't take
+        the material factor into account is multiplied by ``n_cfl``.
+
+        For this medium, it takes the minimum of ``sqrt(permittivity)`` over the main directions.
+        """
+
+        perm_diag, _, _ = self.eps_sigma_diag
+        return min(np.sqrt(perm_diag))
+
+    @add_ax_if_none
+    def plot(self, freqs: float, ax: Ax = None) -> Ax:
+        """Plot n, k of a :class:`FullyAnisotropicMedium` as a function of frequency."""
+
+        diagonal_medium = self._to_diagonal
+        ax = diagonal_medium.plot(freqs=freqs, ax=ax)
+        _, _, directions = self.eps_sigma_diag
+
+        # rename components from xx, yy, zz to 1, 2, 3 to avoid misleading labels,
+        # and add their directions
+        for label, n_line, k_line, direction in zip(
+            ("1", "2", "3"), ax.lines[-6::2], ax.lines[-5::2], directions.T
+        ):
+            direction_str = f"({direction[0]:.2f}, {direction[1]:.2f}, {direction[2]:.2f})"
+            k_line.set_label(f"k, eps_{label} {direction_str}")
+            n_line.set_label(f"n, eps_{label} {direction_str}")
+
+        ax.legend()
+        return ax
diff --git a/tidy3d/_common/components/source/__init__.py b/tidy3d/_common/components/source/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tidy3d/_common/components/source/base.py b/tidy3d/_common/components/source/base.py
new file mode 100644
index 0000000000..fb5fb1bd4d
--- /dev/null
+++ b/tidy3d/_common/components/source/base.py
@@ -0,0 +1,134 @@
+"""Defines an abstract base for electromagnetic sources."""
+
+from __future__ import annotations
+
+from abc import ABC
+from typing import TYPE_CHECKING, Any
+
+from pydantic import Field, field_validator
+
+from tidy3d._common.components.base import cached_property
+from tidy3d._common.components.base_sim.source import AbstractSource
+from tidy3d._common.components.geometry.base import Box
+from tidy3d._common.components.source.time import SourceTimeType
+from tidy3d._common.components.types.base import TYPE_TAG_STR
+from tidy3d._common.components.validators import _assert_min_freq, _warn_unsupported_traced_argument
+from tidy3d._common.components.viz import (
+    ARROW_ALPHA,
+    ARROW_COLOR_POLARIZATION,
+    ARROW_COLOR_SOURCE,
+    plot_params_source,
+)
+
+if TYPE_CHECKING:
+    from typing import Optional
+
+    from tidy3d._common.components.types.base import Ax
+    from tidy3d._common.components.viz import PlotParams
+
+
+class Source(Box, AbstractSource, ABC):
+    """Abstract base class for all sources."""
+
+    source_time: SourceTimeType = Field(
+        title="Source Time",
+        description="Specification of the source time-dependence.",
+        discriminator=TYPE_TAG_STR,
+    )
+
+    @cached_property
+    def plot_params(self) -> PlotParams:
+        """Default parameters for plotting a Source object."""
+        return plot_params_source
+
+    @cached_property
+    def geometry(self) -> Box:
+        """:class:`Box` representation of source."""
+
+        return Box(center=self.center, size=self.size)
+
+    @cached_property
+    def _injection_axis(self) -> None:
+        """Injection axis of the source."""
+        return
+
+    @cached_property
+    def _dir_vector(self) -> None:
+        """Returns a vector indicating the source direction for arrow plotting, if not None."""
+        return None
+
+    @cached_property
+    def _pol_vector(self) -> None:
+        """Returns a vector indicating the source polarization for arrow plotting, if not None."""
+        return None
+
+    _warn_traced_center = _warn_unsupported_traced_argument("center")
+    _warn_traced_size = _warn_unsupported_traced_argument("size")
+
+    @field_validator("source_time")
+    @classmethod
+    def _freqs_lower_bound(cls, val: SourceTimeType) -> SourceTimeType:
+        """Raise validation error if central frequency is too low."""
+        _assert_min_freq(val._freq0_sigma_centroid, msg_start="'source_time.freq0'")
+        return val
+
+    def plot(
+        self,
+        x: Optional[float] = None,
+        y: Optional[float] = None,
+        z: Optional[float] = None,
+        ax: Ax = None,
+        **patch_kwargs: Any,
+    ) -> Ax:
+        """Plot this source."""
+
+        kwargs_arrow_base = patch_kwargs.pop("arrow_base", None)
+
+        # call the `Box.plot()` function first.
+        ax = Box.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs)
+
+        kwargs_alpha = patch_kwargs.get("alpha")
+        arrow_alpha = ARROW_ALPHA if kwargs_alpha is None else kwargs_alpha
+
+        # then add the arrow based on the propagation direction
+        if self._dir_vector is not None:
+            bend_radius = None
+            bend_axis = None
+            if hasattr(self, "mode_spec") and self.mode_spec.bend_radius is not None:
+                bend_radius = self.mode_spec.bend_radius
+                bend_axis = self._bend_axis
+                sign = 1 if self.direction == "+" else -1
+                # Curvature has to be reversed because of plotting coordinates
+                if (self.size.index(0), bend_axis) in [(1, 2), (2, 0), (2, 1)]:
+                    bend_radius *= -sign
+                else:
+                    bend_radius *= sign
+
+            ax = self._plot_arrow(
+                x=x,
+                y=y,
+                z=z,
+                ax=ax,
+                direction=self._dir_vector,
+                bend_radius=bend_radius,
+                bend_axis=bend_axis,
+                color=ARROW_COLOR_SOURCE,
+                alpha=arrow_alpha,
+                both_dirs=False,
+                arrow_base=kwargs_arrow_base,
+            )
+
+        if self._pol_vector is not None:
+            ax = self._plot_arrow(
+                x=x,
+                y=y,
+                z=z,
+                ax=ax,
+                direction=self._pol_vector,
+                color=ARROW_COLOR_POLARIZATION,
+                alpha=arrow_alpha,
+                both_dirs=False,
+                arrow_base=kwargs_arrow_base,
+            )
+
+        return ax
diff --git a/tidy3d/_common/components/source/time.py b/tidy3d/_common/components/source/time.py
new file mode 100644
index 0000000000..ac4355621a
--- /dev/null
+++ b/tidy3d/_common/components/source/time.py
@@ -0,0 +1,691 @@
+"""Defines time dependencies of injected electromagnetic sources."""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import numpy as np
+from pydantic import Field, PositiveFloat, field_validator, model_validator
+from pyroots import Brentq
+
+from tidy3d._common.components.base import cached_property
+from tidy3d._common.components.data.data_array import TimeDataArray
+from tidy3d._common.components.data.dataset import TimeDataset
+from tidy3d._common.components.data.validators import validate_no_nans
+from tidy3d._common.components.time import AbstractTimeDependence
+from tidy3d._common.components.types.base import FreqBound
+from tidy3d._common.components.validators import warn_if_dataset_none
+from tidy3d._common.components.viz import add_ax_if_none
+from tidy3d._common.constants import HERTZ
+from tidy3d._common.exceptions import ValidationError
+from tidy3d._common.log import log
+from tidy3d._common.packaging import check_tidy3d_extras_licensed_feature, tidy3d_extras
+
+if TYPE_CHECKING:
+    from tidy3d._common.components.types.base import ArrayComplex1D, ArrayFloat1D, Ax, PlotVal
+
+# how many units of ``twidth`` from the ``offset`` until a gaussian pulse is considered "off"
+END_TIME_FACTOR_GAUSSIAN = 10
+
+# warn if source amplitude is too small at the endpoints of frequency range
+WARN_SOURCE_AMPLITUDE = 0.1
+# used in Brentq
+_ROOTS_TOL = 1e-10
+# Default sigma value in frequency_range
+DEFAULT_SIGMA = 4.0
+# Offset in fwidth in finding frequency_range_sigma[1] to ensure the interval brackets the root
+OFFSET_FWIDTH_FMAX = 100
+
+
+class SourceTime(AbstractTimeDependence):
+    """Base class describing the time dependence of a source."""
+
+    @add_ax_if_none
+    def plot_spectrum(
+        self,
+        times: ArrayFloat1D,
+        num_freqs: int = 101,
+        val: PlotVal = "real",
+        ax: Ax = None,
+    ) -> Ax:
+        """Plot the complex-valued spectrum of the source time-dependence.
+        Note: Only the real part of the time signal is used.
+
+        Parameters
+        ----------
+        times : np.ndarray
+            Array of evenly-spaced times (seconds) to evaluate source time-dependence at.
+            The spectrum is computed from this value and the source time frequency content.
+            To see source spectrum for a specific :class:`.Simulation`,
+            pass ``simulation.tmesh``.
+        num_freqs : int = 101
+            Number of frequencies to plot within the SourceTime.frequency_range.
+        val : Literal['real', 'imag', 'abs'] = 'real'
+            Which part of the spectrum to plot.
+        ax : matplotlib.axes._subplots.Axes = None
+            Matplotlib axes to plot on, if not specified, one is created.
+
+        Returns
+        -------
+        matplotlib.axes._subplots.Axes
+            The supplied or created matplotlib axes.
+        """
+
+        fmin, fmax = self.frequency_range_sigma()
+        return self.plot_spectrum_in_frequency_range(
+            times, fmin, fmax, num_freqs=num_freqs, val=val, ax=ax
+        )
+
+    @abstractmethod
+    def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound:
+        """Frequency range within plus/minus ``num_fwidth * fwidth`` of the central frequency."""
+
+    def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound:
+        """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude."""
+        return self.frequency_range(num_fwidth=sigma)
+
+    @cached_property
+    def _frequency_range_sigma_cached(self) -> FreqBound:
+        """Cached `frequency_range_sigma` for the default sigma value."""
+        return self.frequency_range_sigma(sigma=DEFAULT_SIGMA)
+
+    @abstractmethod
+    def end_time(self) -> Optional[float]:
+        """Time after which the source is effectively turned off / close to zero amplitude."""
+
+    @cached_property
+    def _freq0(self) -> float:
+        """Central frequency. If not present in input parameters, returns `_freq0_sigma_centroid`."""
+        return self._freq0_sigma_centroid
+
+    @cached_property
+    def _freq0_sigma_centroid(self) -> float:
+        """Center of the frequency range at a 1-sigma drop from the peak amplitude."""
+        return np.mean(self.frequency_range_sigma(sigma=1))
+
+
+class Pulse(SourceTime, ABC):
+    """A source time that ramps up with some ``fwidth`` and oscillates at ``freq0``."""
+
+    freq0: PositiveFloat = Field(
+        title="Central Frequency",
+        description="Central frequency of the pulse.",
+        units=HERTZ,
+    )
+    fwidth: PositiveFloat = Field(
+        title="Fwidth",
+        description="Standard deviation of the frequency content of the pulse.",
+        units=HERTZ,
+    )
+
+    offset: float = Field(
+        5.0,
+        title="Offset",
+        description="Time delay of the maximum value of the "
+        "pulse in units of 1 / (``2pi * fwidth``).",
+        ge=2.5,
+    )
+
+    @cached_property
+    def _freq0(self) -> float:
+        """Central frequency."""
+        return self.freq0
+
+    @property
+    def offset_time(self) -> float:
+        """Offset time in seconds."""
+        return self.offset * self.twidth
+
+    @property
+    def twidth(self) -> float:
+        """Width of pulse in seconds."""
+        return 1.0 / (2 * np.pi * self.fwidth)
+
+    def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound:
+        """Frequency range within ``num_fwidth`` standard deviations of the central frequency.
+
+        Parameters
+        ----------
+        num_fwidth : float = 4.
+            Frequency range defined as plus/minus ``num_fwidth * self.fwidth``.
+
+        Returns
+        -------
+        Tuple[float, float]
+            Minimum and maximum frequencies of the :class:`GaussianPulse` or :class:`ContinuousWave`
+            power.
+        """
+
+        freq_width_range = num_fwidth * self.fwidth
+        freq_min = max(0, self.freq0 - freq_width_range)
+        freq_max = self.freq0 + freq_width_range
+        return (freq_min, freq_max)
+
+
+class GaussianPulse(Pulse):
+    """Source time dependence that describes a Gaussian pulse.
+
+    Example
+    -------
+    >>> pulse = GaussianPulse(freq0=200e12, fwidth=20e12)
+    """
+
+    remove_dc_component: bool = Field(
+        True,
+        title="Remove DC Component",
+        description="Whether to remove the DC component in the Gaussian pulse spectrum. "
+        "If ``True``, the Gaussian pulse is modified at low frequencies to zero out the "
+        "DC component, which is usually desirable so that the fields will decay. However, "
+        "for broadband simulations, it may be better to have non-vanishing source power "
+        "near zero frequency. Setting this to ``False`` results in an unmodified Gaussian "
+        "pulse spectrum which can have a nonzero DC component.",
+    )
+
+    @property
+    def peak_time(self) -> float:
+        """Peak time in seconds, defined by ``offset``."""
+        return self.offset * self.twidth
+
+    @property
+    def _peak_time_shift(self) -> float:
+        """In the case of DC removal, correction to offset_time so that ``offset`` indeed defines time delay
+        of pulse peak.
+        """
+        if self.remove_dc_component and self.fwidth > self.freq0:
+            return self.twidth * np.sqrt(1 - self.freq0**2 / self.fwidth**2)
+        return 0
+
+    @property
+    def offset_time(self) -> float:
+        """Offset time in seconds. Note that in the case of DC removal, the peak of the pulse can be shifted."""
+        return self.peak_time + self._peak_time_shift
+
+    def amp_time(self, time: float) -> complex:
+        """Complex-valued source amplitude as a function of time."""
+
+        omega0 = 2 * np.pi * self.freq0
+        time_shifted = time - self.offset_time
+
+        offset = np.exp(1j * self.phase)
+        oscillation = np.exp(-1j * omega0 * time)
+        amp = np.exp(-(time_shifted**2) / 2 / self.twidth**2) * self.amplitude
+
+        pulse_amp = offset * oscillation * amp
+
+        # subtract out DC component
+        if self.remove_dc_component:
+            pulse_amp = pulse_amp * (1j * omega0 + time_shifted / self.twidth**2)
+            # normalize by the peak frequency instead of omega0, since omega0 can approach zero
+            # while the peak frequency stays finite
+            pulse_amp /= 2 * np.pi * self.peak_frequency
+        else:
+            # 1j to make it agree in large omega0 limit
+            pulse_amp = pulse_amp * 1j
+
+        return pulse_amp
+
+    def end_time(self) -> Optional[float]:
+        """Time after which the source is effectively turned off / close to zero amplitude."""
+
+        # TODO: decide if we should continue to return an end_time if the DC component remains
+        # if not self.remove_dc_component:
+        #     return None
+
+        end_time = self.offset_time + END_TIME_FACTOR_GAUSSIAN * self.twidth
+
+        # for derivative Gaussian that contains two peaks, add time interval between them
+        if self.remove_dc_component and self.fwidth > self.freq0:
+            end_time += 2 * self._peak_time_shift
+        return end_time
+
+    def amp_freq(self, freq: float) -> complex:
+        """Complex-valued source spectrum in frequency domain."""
+        phase = np.exp(1j * self.phase + 1j * 2 * np.pi * (freq - self.freq0) * self.offset_time)
+        envelope = np.exp(-((freq - self.freq0) ** 2) / 2 / self.fwidth**2)
+        amp = 1j * self.amplitude / self.fwidth * phase * envelope
+        if not self.remove_dc_component:
+            return amp
+
+        # derivative of Gaussian when DC is removed
+        return freq * amp / (2 * np.pi * self.peak_frequency)
+
+    def _rel_amp_freq(self, freq: float) -> complex:
+        """Complex-valued source spectrum in frequency domain normalized by peak amplitude."""
+        return self.amp_freq(freq) / self._peak_freq_amp
+
+    @property
+    def peak_frequency(self) -> float:
+        """Frequency at which the source time dependence has its peak amplitude in the frequency domain."""
+        if not self.remove_dc_component:
+            return self.freq0
+        return 0.5 * (self.freq0 + np.sqrt(self.freq0**2 + 4 * self.fwidth**2))
+
+    @property
+    def _peak_freq_amp(self) -> complex:
+        """Peak amplitude in the frequency domain."""
+        return self.amp_freq(self.peak_frequency)
+
+    @property
+    def _peak_time_amp(self) -> complex:
+        """Peak amplitude in the time domain."""
+        return self.amp_time(self.peak_time)
+
+    def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound:
+        """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude."""
+        if not self.remove_dc_component:
+            return self.frequency_range(num_fwidth=sigma)
+
+        # With the DC component removed, we need to solve a transcendental equation to find the frequency range
+        def equation_for_sigma_frequency(freq: float) -> float:
+            """Computes ``|A / A_p| - exp(-sigma**2 / 2)``."""
+            return np.abs(self._rel_amp_freq(freq)) - np.exp(-(sigma**2) / 2)
+
+        logger = logging.getLogger("pyroots")
+        logger.setLevel(logging.CRITICAL)
+        root_scalar = Brentq(raise_on_fail=False, epsilon=_ROOTS_TOL)
+        fmin_data = root_scalar(equation_for_sigma_frequency, xa=0, xb=self.peak_frequency)
+        fmax_data = root_scalar(
+            equation_for_sigma_frequency,
+            xa=self.peak_frequency,
+            xb=self.peak_frequency
+            + self.fwidth
+            * (
+                OFFSET_FWIDTH_FMAX + 2 * sigma**2
+            ),  # offset slightly to make sure that it flips sign
+        )
+        fmin, fmax = fmin_data.x0, fmax_data.x0
+
+        # if unconverged, fall back to `frequency_range`
+        if not (fmin_data.converged and fmax_data.converged and fmax > fmin):
+            return self.frequency_range(num_fwidth=sigma)
+
+        # converged
+        return fmin.item(), fmax.item()
+
+    @property
+    def amp_complex(self) -> complex:
+        """Grab the complex amplitude from a ``GaussianPulse``."""
+        phase = np.exp(1j * self.phase)
+        return self.amplitude * phase
+
+    @classmethod
+    def from_amp_complex(cls, amp: complex, **kwargs: Any) -> GaussianPulse:
+        """Set the complex amplitude of a ``GaussianPulse``.
+
+        Parameters
+        ----------
+        amp : complex
+            Complex-valued amplitude to set in the returned ``GaussianPulse``.
+        kwargs : dict
+            Keyword arguments passed to ``GaussianPulse()``, excluding ``amplitude`` & ``phase``.
+        """
+        amplitude = abs(amp)
+        phase = np.angle(amp)
+        return cls(amplitude=amplitude, phase=phase, **kwargs)
+
+    @staticmethod
+    def _minimum_source_bandwidth(
+        fmin: float, fmax: float, minimum_source_bandwidth: float
+    ) -> tuple[float, float]:
+        """Define a source bandwidth based on fmin and fmax, but enforce a minimum bandwidth."""
+        if minimum_source_bandwidth <= 0:
+            raise ValidationError("'minimum_source_bandwidth' must be positive.")
+        if minimum_source_bandwidth >= 1:
+            raise ValidationError("'minimum_source_bandwidth' must be less than 1.")
+
+        f_difference = fmax - fmin
+        f_middle = 0.5 * (fmin + fmax)
+
+        full_width = minimum_source_bandwidth * f_middle
+        if f_difference < full_width:
+            half_width = 0.5 * full_width
+            fmin = f_middle - half_width
+            fmax = f_middle + half_width
+
+        return fmin, fmax
+
+    @classmethod
+    def from_frequency_range(
+        cls,
+        fmin: PositiveFloat,
+        fmax: PositiveFloat,
+        minimum_source_bandwidth: Optional[PositiveFloat] = None,
+        **kwargs: Any,
+    ) -> GaussianPulse:
+        """Create a ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax].
+
+        Parameters
+        ----------
+        fmin : float
+            Lower bound of frequency of interest.
+        fmax : float
+            Upper bound of frequency of interest.
+        minimum_source_bandwidth : Optional[float] = None
+            If specified, enforce a minimum fractional bandwidth ``(fmax - fmin) / f_middle``
+            before constructing the pulse.
+        kwargs : dict
+            Keyword arguments passed to ``GaussianPulse()``, excluding ``freq0`` & ``fwidth``.
+
+        Returns
+        -------
+        GaussianPulse
+            A ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax].
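+
+        Example
+        -------
+        A minimal illustration; the frequency bounds here are arbitrary:
+
+        >>> pulse = GaussianPulse.from_frequency_range(fmin=100e12, fmax=300e12)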
+        """
+        # validate that fmin and fmax are positive, and that fmax > fmin
+        if fmin <= 0:
+            raise ValidationError("'fmin' must be positive.")
+        if fmax <= fmin:
+            raise ValidationError("'fmax' must be greater than 'fmin'.")
+
+        if minimum_source_bandwidth is not None:
+            fmin, fmax = cls._minimum_source_bandwidth(fmin, fmax, minimum_source_bandwidth)
+
+        # frequency range and center
+        freq_range = fmax - fmin
+        freq_center = (fmax + fmin) / 2.0
+
+        # If remove_dc_component=False, simply return the standard GaussianPulse parameters
+        if kwargs.get("remove_dc_component", True) is False:
+            return cls(freq0=freq_center, fwidth=freq_range / 2.0, **kwargs)
+
+        # If remove_dc_component=True, the Gaussian pulse is distorted
+        kwargs.update({"remove_dc_component": True})
+        log_ratio = np.log(fmax / fmin)
+        coeff = ((1 + log_ratio**2) ** 0.5 - 1) / 2.0
+        freq0 = freq_center - coeff / log_ratio * freq_range
+        fwidth = freq_range / log_ratio * coeff**0.5
+        pulse = cls(freq0=freq0, fwidth=fwidth, **kwargs)
+        if np.abs(pulse._rel_amp_freq(fmin)) < WARN_SOURCE_AMPLITUDE:
+            log.warning(
+                "Source amplitude is not sufficiently large throughout the specified frequency range, "
+                "which can result in inaccurate simulation results. Please decrease the frequency range.",
+            )
+        return pulse
+
+
+class ContinuousWave(Pulse):
+    """Source time dependence that ramps up to continuous oscillation
+    and holds until end of simulation.
+
+    Note
+    ----
+    Field decay will not occur, so the simulation will run for the full ``run_time``.
+    Also, source normalization of frequency-domain monitors is not meaningful.
+
+    Example
+    -------
+    >>> cw = ContinuousWave(freq0=200e12, fwidth=20e12)
+    """
+
+    def amp_time(self, time: float) -> complex:
+        """Complex-valued source amplitude as a function of time."""
+
+        twidth = 1.0 / (2 * np.pi * self.fwidth)
+        omega0 = 2 * np.pi * self.freq0
+        time_shifted = time - self.offset_time
+
+        const = 1.0
+        offset = np.exp(1j * self.phase)
+        oscillation = np.exp(-1j * omega0 * time)
+        amp = 1 / (1 + np.exp(-time_shifted / twidth)) * self.amplitude
+
+        return const * offset * oscillation * amp
+
+    def end_time(self) -> Optional[float]:
+        """Time after which the source is effectively turned off / close to zero amplitude."""
+        return None
+
+
+class CustomSourceTime(Pulse):
+    """Custom source time dependence consisting of a real or complex envelope
+    modulated at a central frequency, as shown below.
+
+    Note
+    ----
+    .. math::
+
+        amp\\_time(t) = amplitude \\cdot \\
+                e^{i \\cdot phase - 2 \\pi i \\cdot freq0 \\cdot t} \\cdot \\
+                envelope(t - offset / (2 \\pi \\cdot fwidth))
+
+    Note
+    ----
+    Depending on the envelope, field decay may not occur.
+    If field decay does not occur, then the simulation will run for the full ``run_time``.
+    Also, if field decay does not occur, then source normalization of frequency-domain
+    monitors is not meaningful.
+
+    Note
+    ----
+    The source time dependence is linearly interpolated to the simulation time steps.
+    The sampling rate should be sufficiently fast that this interpolation does not
+    introduce artifacts. The source time dependence should also start at zero and ramp up smoothly.
+    The first and last values of the envelope will be used for times that are out of range
+    of the provided data.
+
+    Example
+    -------
+    >>> cst = CustomSourceTime.from_values(freq0=1, fwidth=0.1,
+    ...     values=np.linspace(0, 9, 10), dt=0.1)
+
+    """
+
+    offset: float = Field(
+        0.0,
+        title="Offset",
+        description="Time delay of the envelope in units of 1 / (``2pi * fwidth``).",
+    )
+
+    source_time_dataset: Optional[TimeDataset] = Field(
+        None,
+        title="Source time dataset",
+        description="Dataset for storing the envelope of the custom source time. "
+        "This envelope will be modulated by a complex exponential at frequency ``freq0``.",
+    )
+
+    _no_nans_dataset = validate_no_nans("source_time_dataset")
+    _source_time_dataset_none_warning = warn_if_dataset_none("source_time_dataset")
+
+    @field_validator("source_time_dataset")
+    @classmethod
+    def _more_than_one_time(cls, val: Optional[TimeDataset]) -> Optional[TimeDataset]:
+        """Must have more than one time to interpolate."""
+        if val is None:
+            return val
+        if val.values.size <= 1:
+            raise ValidationError("'CustomSourceTime' must have more than one time coordinate.")
+        return val
+
+    @classmethod
+    def from_values(
+        cls, freq0: float, fwidth: float, values: ArrayComplex1D, dt: float
+    ) -> CustomSourceTime:
+        """Create a :class:`.CustomSourceTime` from a numpy array.
+
+        Parameters
+        ----------
+        freq0 : float
+            Central frequency of the source. The envelope provided will be modulated
+            by a complex exponential at this frequency.
+        fwidth : float
+            Estimated frequency width of the source.
+        values: ArrayComplex1D
+            Complex values of the source envelope.
+        dt: float
+            Time step for the ``values`` array. This value should be sufficiently small
+            that the interpolation to simulation time steps does not introduce artifacts.
+
+        Returns
+        -------
+        CustomSourceTime
+            :class:`.CustomSourceTime` with envelope given by ``values``, modulated by a complex
+            exponential at frequency ``freq0``. The time coordinates are evenly spaced
+            between ``0`` and ``dt * (N-1)`` with a step size of ``dt``, where ``N`` is the length of
+            the values array.
+        """
+
+        times = np.arange(len(values)) * dt
+        source_time_dataarray = TimeDataArray(values, coords={"t": times})
+        source_time_dataset = TimeDataset(values=source_time_dataarray)
+        return CustomSourceTime(
+            freq0=freq0,
+            fwidth=fwidth,
+            source_time_dataset=source_time_dataset,
+        )
+
+    @property
+    def data_times(self) -> ArrayFloat1D:
+        """Times of envelope definition."""
+        if self.source_time_dataset is None:
+            return []
+        data_times = self.source_time_dataset.values.coords["t"].values.squeeze()
+        return data_times
+
+    def _all_outside_range(self, run_time: float) -> bool:
+        """Whether all times are outside range of definition."""
+
+        # can't validate if data isn't loaded
+        if self.source_time_dataset is None:
+            return False
+
+        # times at which the envelope is defined
+        data_times = self.data_times
+
+        # shift time
+        max_time_shifted = run_time - self.offset_time
+        min_time_shifted = -self.offset_time
+
+        return (max_time_shifted < min(data_times)) | (min_time_shifted > max(data_times))
+
+    def amp_time(self, time: float) -> complex:
+        """Complex-valued source amplitude as a function of time.
+
+        Parameters
+        ----------
+        time : float
+            Time in seconds.
+
+        Returns
+        -------
+        complex
+            Complex-valued source amplitude at that time.
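+
+        Example
+        -------
+        An illustrative evaluation; the envelope values here are arbitrary:
+
+        >>> cst = CustomSourceTime.from_values(
+        ...     freq0=1.0, fwidth=0.1, values=np.linspace(0, 9, 10), dt=0.1
+        ... )
+        >>> amps = cst.amp_time(np.array([0.0, 0.1, 0.2]))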
+        """
+
+        if self.source_time_dataset is None:
+            return None
+
+        # make time a numpy array for uniform handling
+        times = np.array([time] if isinstance(time, (int, float)) else time)
+        data_times = self.data_times
+
+        # shift time
+        twidth = 1.0 / (2 * np.pi * self.fwidth)
+        time_shifted = times - self.offset * twidth
+
+        # mask times that are out of range
+        mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times))
+
+        # get envelope
+        envelope = np.zeros(len(time_shifted), dtype=complex)
+        values = self.source_time_dataset.values
+        envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy()
+        if not all(mask):
+            envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy()
+
+        # modulation, phase, amplitude
+        omega0 = 2 * np.pi * self.freq0
+        offset = np.exp(1j * self.phase)
+        oscillation = np.exp(-1j * omega0 * times)
+        amp = self.amplitude
+
+        return offset * oscillation * amp * envelope
+
+    def end_time(self) -> Optional[float]:
+        """Time after which the source is effectively turned off / close to zero amplitude."""
+
+        if self.source_time_dataset is None:
+            return None
+
+        data_array = self.source_time_dataset.values
+
+        t_coords = data_array.coords["t"]
+        source_is_non_zero = ~np.isclose(abs(data_array), 0)
+        t_non_zero = t_coords[source_is_non_zero]
+
+        return np.max(t_non_zero)
+
+
+class BroadbandPulse(SourceTime):
+    """A source time injecting significant energy in the entire custom frequency range."""
+
+    freq_range: FreqBound = Field(
+        title="Frequency Range",
+        description="Frequency range where the pulse should have significant energy.",
+        units=HERTZ,
+    )
+    minimum_amplitude: float = Field(
+        0.3,
+        title="Minimum Amplitude",
+        description="Minimum amplitude of the pulse relative to the peak amplitude in the frequency range.",
+        gt=0.05,
+        lt=0.5,
+    )
+    offset: float = Field(
+        0.0,
+        title="Offset",
+        description="An automatic time delay of the peak value of the pulse is applied under the hood "
+        "to ensure smooth ramping up of the pulse at time = 0. This offset is added on top of the automatic time delay "
+        "in units of 1 / [``2pi * (freq_range[1] - freq_range[0])``].",
+    )
+
+    @field_validator("freq_range")
+    @classmethod
+    def _validate_freq_range(cls, val: FreqBound) -> FreqBound:
+        """Validate that freq_range is positive and properly ordered."""
+        if val[0] <= 0 or val[1] <= 0:
+            raise ValidationError("Both elements of 'freq_range' must be positive.")
+        if val[1] <= val[0]:
+            raise ValidationError(
+                f"'freq_range[1]' ({val[1]}) must be greater than 'freq_range[0]' ({val[0]})."
+            )
+        return val
+
+    @model_validator(mode="before")
+    @classmethod
+    def _check_broadband_pulse_available(cls, values: dict[str, Any]) -> dict[str, Any]:
+        """Check if BroadbandPulse is available."""
+        check_tidy3d_extras_licensed_feature("BroadbandPulse")
+        return values
+
+    @cached_property
+    def _source(self) -> Any:
+        """Implementation of broadband pulse."""
+        return tidy3d_extras["mod"].extension.BroadbandPulse(
+            fmin=self.freq_range[0],
+            fmax=self.freq_range[1],
+            minRelAmp=self.minimum_amplitude,
+            amp=self.amplitude,
+            phase=self.phase,
+            offset=self.offset,
+        )
+
+    def end_time(self) -> float:
+        """Time after which the source is effectively turned off / close to zero amplitude."""
+        return self._source.end_time(END_TIME_FACTOR_GAUSSIAN)
+
+    def amp_time(self, time: float) -> complex:
+        """Complex-valued source amplitude as a function of time."""
+        return self._source.amp_time(time)
+
+    def amp_freq(self, freq: float) -> complex:
+        """Complex-valued source amplitude as a function of frequency."""
+        return self._source.amp_freq(freq)
+
+    def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound:
+        """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude."""
+        return self._source.frequency_range(sigma)
+
+    def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound:
+        """Delegates to `frequency_range_sigma(sigma=num_fwidth)` for computing the frequency range where the source amplitude
+        is within ``exp(-num_fwidth**2/2)`` of the peak amplitude.
+        """
+        return self.frequency_range_sigma(num_fwidth)
+
+
+SourceTimeType = Union[GaussianPulse, ContinuousWave, CustomSourceTime, BroadbandPulse]
diff --git a/tidy3d/_common/components/time.py b/tidy3d/_common/components/time.py
new file mode 100644
index 0000000000..a61959cd3b
--- /dev/null
+++ b/tidy3d/_common/components/time.py
@@ -0,0 +1,204 @@
+"""Defines time dependence."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+import numpy as np
+from pydantic import Field, NonNegativeFloat
+
+from tidy3d._common.components.base import Tidy3dBaseModel
+from tidy3d._common.components.viz import add_ax_if_none
+from tidy3d._common.constants import RADIAN
+from tidy3d._common.exceptions import SetupError
+
+if TYPE_CHECKING:
+    from tidy3d._common.components.types import ArrayFloat1D, Ax, PlotVal
+
+# in spectrum computation, discard amplitudes with relative magnitude smaller than cutoff
+DFT_CUTOFF = 1e-8
+
+
+class AbstractTimeDependence(ABC, Tidy3dBaseModel):
+    """Base class describing time dependence."""
+
+    amplitude: NonNegativeFloat = Field(
+        1.0, title="Amplitude", description="Real-valued maximum amplitude of the time dependence."
+    )
+
+    phase: float = Field(
+        0.0, title="Phase", description="Phase shift of the time dependence.", units=RADIAN
+    )
+
+    @abstractmethod
+    def amp_time(self, time: float) -> complex:
+        """Complex-valued amplitude as a function of time.
+
+        Parameters
+        ----------
+        time : float
+            Time in seconds.
+
+        Returns
+        -------
+        complex
+            Complex-valued amplitude at that time.
+        """
+
+    def spectrum(
+        self,
+        times: ArrayFloat1D,
+        freqs: ArrayFloat1D,
+        dt: float,
+    ) -> complex:
+        """Complex-valued spectrum as a function of frequency.
+        Note: Only the real part of the time signal is used.
+
+        Parameters
+        ----------
+        times : np.ndarray
+            Times to use to evaluate the spectrum's Fourier transform.
+            (Typically the simulation time mesh).
+        freqs : np.ndarray
+            Frequencies in Hz to evaluate spectrum at.
+        dt : float or np.ndarray
+            Time step to weight FT integral with.
+            If an array, it is used to weight each of the time intervals in ``times``.
+
+        Returns
+        -------
+        np.ndarray
+            Complex-valued array (of ``len(freqs)``) containing the spectrum at those frequencies.
+        """
+
+        times = np.array(times)
+        freqs = np.array(freqs)
+        time_amps = np.real(self.amp_time(times))
+
+        # if all time amplitudes are zero, just return (complex-valued) zeros for spectrum
+        if np.all(np.equal(time_amps, 0.0)):
+            return (0.0 + 0.0j) * np.zeros_like(freqs)
+
+        # Cut to only relevant times
+        relevant_time_inds = np.where(np.abs(time_amps) / np.amax(np.abs(time_amps)) > DFT_CUTOFF)
+        # find first and last index where the filter is True
+        start_ind = relevant_time_inds[0][0]
+        stop_ind = relevant_time_inds[0][-1] + 1
+        time_amps = time_amps[start_ind:stop_ind]
+        times_cut = times[start_ind:stop_ind]
+        if times_cut.size == 0:
+            return (0.0 + 0.0j) * np.zeros_like(freqs)
+
+        # only need to compute DTFT kernel for distinct dts
+        # usually, there is only one dt, if times is simulation time mesh
+        dts = np.diff(times_cut)
+        dts_unique, kernel_indices = np.unique(dts, return_inverse=True)
+
+        dft_kernels = [np.exp(2j * np.pi * freqs * curr_dt) for curr_dt in dts_unique]
+        running_kernel = np.exp(2j * np.pi * freqs * times_cut[0])
+        dft = np.zeros(len(freqs), dtype=complex)
+        for amp, kernel_index in zip(time_amps, kernel_indices):
+            dft += running_kernel * amp
+            running_kernel *= dft_kernels[kernel_index]
+
+        # kernel_indices was one index shorter than time_amps
+        dft += running_kernel * time_amps[-1]
+
+        return dt * dft / np.sqrt(2 * np.pi)
+
+    @add_ax_if_none
+    def plot_spectrum_in_frequency_range(
+        self,
+        times: ArrayFloat1D,
+        fmin: float,
+        fmax: float,
+        num_freqs: int = 101,
+        val: PlotVal = "real",
+        ax: Ax = None,
+    ) -> Ax:
+        """Plot the complex-valued spectrum of the time-dependence.
+        Note: Only the real part of the time signal is used.
+
+        Parameters
+        ----------
+        times : np.ndarray
+            Array of evenly-spaced times (seconds) to evaluate time-dependence at.
+            The spectrum is computed from this value and the time frequency content.
+            To see the spectrum for a specific :class:`.Simulation`,
+            pass ``simulation.tmesh``.
+        fmin : float
+            Lower bound of frequency for the spectrum plot.
+        fmax : float
+            Upper bound of frequency for the spectrum plot.
+        num_freqs : int = 101
+            Number of frequencies to plot within ``[fmin, fmax]``.
+        val : Literal['real', 'imag', 'abs'] = 'real'
+            Which part of the spectrum to plot.
+        ax : matplotlib.axes._subplots.Axes = None
+            Matplotlib axes to plot on, if not specified, one is created.
+
+        Returns
+        -------
+        matplotlib.axes._subplots.Axes
+            The supplied or created matplotlib axes.
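+
+        Example
+        -------
+        A sketch of typical usage, assuming ``source_time`` is an instance of a
+        concrete time dependence (e.g., a Gaussian pulse)::
+
+            times = np.linspace(0, 1e-12, 1001)
+            ax = source_time.plot_spectrum_in_frequency_range(
+                times, fmin=1e14, fmax=3e14, val="abs"
+            )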
+        """
+        times = np.array(times)
+
+        dts = np.diff(times)
+        if not np.allclose(dts, dts[0] * np.ones_like(dts), atol=1e-17):
+            raise SetupError("Supplied times not evenly spaced.")
+
+        dt = np.mean(dts)
+        freqs = np.linspace(fmin, fmax, num_freqs)
+
+        spectrum = self.spectrum(times=times, dt=dt, freqs=freqs)
+
+        if val == "real":
+            ax.plot(freqs, spectrum.real, color="blueviolet", label="real")
+        elif val == "imag":
+            ax.plot(freqs, spectrum.imag, color="crimson", label="imag")
+        elif val == "abs":
+            ax.plot(freqs, np.abs(spectrum), color="k", label="abs")
+        else:
+            raise ValueError(f"Plot 'val' option of '{val}' not recognized.")
+        ax.set_xlabel("frequency (Hz)")
+        ax.set_title("source spectrum")
+        ax.legend()
+        ax.set_aspect("auto")
+        return ax
+
+    @add_ax_if_none
+    def plot(self, times: ArrayFloat1D, val: PlotVal = "real", ax: Ax = None) -> Ax:
+        """Plot the complex-valued amplitude of the time-dependence.
+
+        Parameters
+        ----------
+        times : np.ndarray
+            Array of times (seconds) to plot source at.
+            To see source time amplitude for a specific :class:`.Simulation`,
+            pass ``simulation.tmesh``.
+        val : Literal['real', 'imag', 'abs'] = 'real'
+            Which part of the spectrum to plot.
+        ax : matplotlib.axes._subplots.Axes = None
+            Matplotlib axes to plot on, if not specified, one is created.
+
+        Returns
+        -------
+        matplotlib.axes._subplots.Axes
+            The supplied or created matplotlib axes.
+        """
+        times = np.array(times)
+        amp_complex = self.amp_time(times)
+
+        if val == "real":
+            ax.plot(times, amp_complex.real, color="blueviolet", label="real")
+        elif val == "imag":
+            ax.plot(times, amp_complex.imag, color="crimson", label="imag")
+        elif val == "abs":
+            ax.plot(times, np.abs(amp_complex), color="k", label="abs")
+        else:
+            raise ValueError(f"Plot 'val' option of '{val}' not recognized.")
+        ax.set_xlabel("time (s)")
+        ax.set_title("source amplitude")
+        ax.legend()
+        ax.set_aspect("auto")
+        return ax
diff --git a/tidy3d/_common/components/transformation.py b/tidy3d/_common/components/transformation.py
new file mode 100644
index 0000000000..04774c885e
--- /dev/null
+++ b/tidy3d/_common/components/transformation.py
@@ -0,0 +1,210 @@
+"""Defines geometric transformation classes."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Union
+
+import numpy as np
+from pydantic import Field, field_validator
+
+from tidy3d._common.components.autograd import TracedFloat
+from tidy3d._common.components.base import Tidy3dBaseModel, cached_property
+from tidy3d._common.components.types.base import Axis, Coordinate
+from tidy3d._common.constants import RADIAN
+from tidy3d._common.exceptions import ValidationError
+
+if TYPE_CHECKING:
+    from tidy3d._common.components.types.base import ArrayFloat2D, TensorReal
+
+
+class AbstractRotation(ABC, Tidy3dBaseModel):
+    """Abstract rotation of vectors and tensors."""
+
+    @cached_property
+    @abstractmethod
+    def matrix(self) -> TensorReal:
+        """Rotation matrix."""
+
+    @cached_property
+    @abstractmethod
+    def isidentity(self) -> bool:
+        """Check whether rotation is identity."""
+
+    def rotate_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D:
+        """Rotate a vector/point or a list of vectors/points.
+
+        Parameters
+        ----------
+        vector : ArrayLike[float]
+            Array of shape ``(3, ...)``.
+
+        Returns
+        -------
+        Coordinate
+            Rotated vector.
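+
+        Example
+        -------
+        An illustrative 90-degree rotation of the x unit vector around z, using the
+        :class:`RotationAroundAxis` subclass defined below:
+
+        >>> rot = RotationAroundAxis(axis=2, angle=np.pi / 2)
+        >>> rotated = rot.rotate_vector(np.array([1.0, 0.0, 0.0]))  # approximately (0, 1, 0)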
+        """
+
+        if self.isidentity:
+            return vector
+
+        if len(vector.shape) == 1:
+            return self.matrix @ vector
+
+        return np.tensordot(self.matrix, vector, axes=1)
+
+    def rotate_tensor(self, tensor: TensorReal) -> TensorReal:
+        """Rotate a tensor.
+
+        Parameters
+        ----------
+        tensor : ArrayLike[float]
+            Array of shape ``(3, 3)``.
+
+        Returns
+        -------
+        TensorReal
+            Rotated tensor.
+        """
+
+        if self.isidentity:
+            return tensor
+
+        return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T))
+
+
+class RotationAroundAxis(AbstractRotation):
+    """Rotation of vectors and tensors around a given vector."""
+
+    axis: Union[Axis, Coordinate] = Field(
+        0,
+        title="Axis of Rotation",
+        description="A vector that specifies the axis of rotation, or a single int: 0, 1, or 2, "
+        "indicating x, y, or z.",
+    )
+
+    angle: TracedFloat = Field(
+        0.0,
+        title="Angle of Rotation",
+        description="Angle of rotation in radians.",
+        units=RADIAN,
+    )
+
+    @field_validator("axis")
+    @classmethod
+    def _validate_axis_vector(cls, val: Union[Axis, Coordinate]) -> Coordinate:
+        if not isinstance(val, tuple):
+            axis = [0.0, 0.0, 0.0]
+            axis[val] = 1.0
+            val = tuple(axis)
+        return val
+
+    @field_validator("axis")
+    @classmethod
+    def _validate_axis_nonzero_norm(cls, val: Coordinate) -> Coordinate:
+        norm = np.linalg.norm(val)
+        if np.isclose(norm, 0):
+            raise ValidationError(
+                "The norm of vector 'axis' cannot be zero. Please provide a proper rotation axis."
+            )
+        return val
+
+    @cached_property
+    def isidentity(self) -> bool:
+        """Check whether rotation is identity."""
+
+        return np.isclose(self.angle % (2 * np.pi), 0)
+
+    @cached_property
+    def matrix(self) -> TensorReal:
+        """Rotation matrix."""
+
+        if self.isidentity:
+            return np.eye(3)
+
+        norm = np.linalg.norm(self.axis)
+        n = self.axis / norm
+        c = np.cos(self.angle)
+        s = np.sin(self.angle)
+        K = np.array([[0, -n[2], n[1]], [n[2], 0, -n[0]], [-n[1], n[0], 0]])
+        R = np.eye(3) + s * K + (1 - c) * K @ K
+
+        return R
+
+
+class AbstractReflection(ABC, Tidy3dBaseModel):
+    """Abstract reflection of vectors and tensors."""
+
+    @cached_property
+    @abstractmethod
+    def matrix(self) -> TensorReal:
+        """Reflection matrix."""
+
+    def reflect_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D:
+        """Reflect a vector/point or a list of vectors/points.
+
+        Parameters
+        ----------
+        vector : ArrayLike[float]
+            Array of shape ``(3, ...)``.
+
+        Returns
+        -------
+        Coordinate
+            Reflected vector.
+        """
+
+        if len(vector.shape) == 1:
+            return self.matrix @ vector
+
+        return np.tensordot(self.matrix, vector, axes=1)
+
+    def reflect_tensor(self, tensor: TensorReal) -> TensorReal:
+        """Reflect a tensor.
+
+        Parameters
+        ----------
+        tensor : ArrayLike[float]
+            Array of shape ``(3, 3)``.
+
+        Returns
+        -------
+        TensorReal
+            Reflected tensor.
+        """
+
+        return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T))
+
+
+class ReflectionFromPlane(AbstractReflection):
+    """Reflection of vectors and tensors in a given plane."""
+
+    normal: Coordinate = Field(
+        (1, 0, 0),
+        title="Normal of the reflecting plane",
+        description="A vector that specifies the normal of the plane of reflection.",
+    )
+
+    @field_validator("normal")
+    @classmethod
+    def _validate_normal_nonzero_norm(cls, val: Coordinate) -> Coordinate:
+        norm = np.linalg.norm(val)
+        if np.isclose(norm, 0):
+            raise ValidationError(
+                "The norm of vector 'normal' cannot be zero. Please provide a proper normal vector."
+ ) + return val + + @cached_property + def matrix(self) -> TensorReal: + """Reflection matrix.""" + + norm = np.linalg.norm(self.normal) + n = self.normal / norm + R = np.eye(3) - 2 * np.outer(n, n) + + return R + + +RotationType = Union[RotationAroundAxis] +ReflectionType = Union[ReflectionFromPlane] diff --git a/tidy3d/_common/components/types/__init__.py b/tidy3d/_common/components/types/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/components/types/base.py b/tidy3d/_common/components/types/base.py new file mode 100644 index 0000000000..ea408643fd --- /dev/null +++ b/tidy3d/_common/components/types/base.py @@ -0,0 +1,320 @@ +"""Defines 'types' that various fields can be""" + +from __future__ import annotations + +import numbers +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union + +import numpy as np +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + NonNegativeFloat, + PlainValidator, + PositiveFloat, +) +from pydantic.functional_serializers import PlainSerializer +from pydantic.json_schema import WithJsonSchema + +if TYPE_CHECKING: + from numpy.typing import NDArray + +try: + from matplotlib.axes import Axes +except ImportError: + Axes = None + +from shapely.geometry.base import BaseGeometry + +# type tag default name +TYPE_TAG_STR = "type" + + +def discriminated_union(union: type, discriminator: str = TYPE_TAG_STR) -> type: + return Annotated[union, Field(discriminator=discriminator)] + + +""" Numpy Arrays """ + + +def _dtype2python(value: Any) -> Any: + """Converts numpy scalar types to their python equivalents.""" + if isinstance(value, np.integer): + return int(value) + if isinstance(value, np.floating): + return float(value) + if isinstance(value, np.complexfloating): + return complex(value) + if isinstance(value, np.bool_): + return bool(value) + return value + + +def _from_complex_dict(v: Any) -> Any: + if isinstance(v, dict) and "real" in v and "imag" in v: + return np.asarray(v["real"]) + 1j * np.asarray(v["imag"]) + return v + + +def _auto_serializer(a: Any, _: Any) -> Any: + """Serializes numpy arrays and scalars for JSON.""" + if isinstance(a, complex) or ( + hasattr(np, "complexfloating") and isinstance(a, np.complexfloating) + ): + return {"real": float(a.real), "imag": float(a.imag)} + if isinstance(a, np.ndarray): + if np.iscomplexobj(a): + return {"real": a.real.tolist(), "imag": a.imag.tolist()} + else: + return a.tolist() + if isinstance(a, float) or (hasattr(np, "floating") and isinstance(a, np.floating)): + return float(a) # Ensure basic Python float + if isinstance(a, int) or (hasattr(np, "integer") and isinstance(a, np.integer)): + return int(a) # Ensure basic Python int + if hasattr(np, "number") and isinstance(a, np.number): + return a.item() + return a + + +DTypeLike = Annotated[np.dtype, PlainValidator(np.dtype), WithJsonSchema({"type": "np.dtype"})] + + +class ArrayConstraints(BaseModel): + """Container for array constraints.""" + + model_config = ConfigDict(frozen=True) + + dtype: Optional[DTypeLike] = None + ndim: Optional[int] = None + shape: Optional[tuple[int, ...]] = None + forbid_nan: bool = True + scalar_to_1d: bool = False + strict: bool = False + + +def _coerce(v: Any, *, constraints: ArrayConstraints) -> NDArray: + """Convert input to a NumPy array with constraints. + + Raises + ------ + ValueError + - If conversion to an array fails. + - If the array ends up with dtype=object (unsupported element type). 
+ - If the number of dimensions or shape does not match the expectations. + - If ``forbid_nan`` is ``True`` and the array contains NaN values. + """ + if constraints.strict and np.isscalar(v): + raise ValueError( + f"strict mode: scalar value {type(v).__name__!r} cannot be coerced to a NumPy array. " + ) + + try: + # constraints.dtype is already an np.dtype object or None + arr = np.asarray(v) if constraints.dtype is None else np.asarray(v, dtype=constraints.dtype) + except Exception as e: + raise ValueError(f"cannot convert {type(v).__name__!r} to a NumPy array") from e + + if arr.dtype == np.dtype("object"): + raise ValueError(f"unsupported element type {type(v).__name__!r} for array coercion") + + if ( + arr.ndim == 0 + and (constraints.ndim == 1 or constraints.ndim is None) + and constraints.scalar_to_1d + ): + arr = arr.reshape(1) + if constraints.ndim is not None and arr.ndim != constraints.ndim: + raise ValueError(f"expected {constraints.ndim}-D, got {arr.ndim}-D") + if constraints.shape is not None and tuple(arr.shape) != constraints.shape: + raise ValueError(f"expected shape {constraints.shape}, got {tuple(arr.shape)}") + if constraints.forbid_nan and np.any(np.isnan(arr)): + raise ValueError("array contains NaN") + + # enforce immutability of our Pydantic models + arr.flags.writeable = False + + return arr + + +def array_alias( + *, + dtype: Optional[Any] = None, + ndim: Optional[int] = None, + shape: Optional[tuple[int, ...]] = None, + forbid_nan: bool = True, + scalar_to_1d: bool = False, + strict: bool = False, +) -> Any: + constraints = ArrayConstraints( + dtype=dtype, + ndim=ndim, + shape=shape, + forbid_nan=forbid_nan, + scalar_to_1d=scalar_to_1d, + strict=strict, + ) + serializer = PlainSerializer(_auto_serializer, when_used="json") + + base_schema = { + "type": "ArrayLike", + "x-array-dtype": getattr(constraints.dtype, "str", None), + "x-array-ndim": constraints.ndim, + "x-array-shape": constraints.shape, + "x-array-forbid_nan": constraints.forbid_nan, + "x-array-scalar_to_1d": constraints.scalar_to_1d, + "x-array-strict": constraints.strict, + } + + return Annotated[ + np.ndarray, + BeforeValidator(_from_complex_dict), + BeforeValidator(lambda v: _coerce(v, constraints=constraints)), + serializer, + WithJsonSchema(base_schema), + ] + + +ArrayLike = array_alias() +ArrayLikeStrict = array_alias(strict=True) + +ArrayInt1D = array_alias(dtype=int, ndim=1, scalar_to_1d=True) + +ArrayFloat = array_alias(dtype=float) +ArrayFloat1D = array_alias(dtype=float, ndim=1, scalar_to_1d=True) +ArrayFloat2D = array_alias(dtype=float, ndim=2) +ArrayFloat3D = array_alias(dtype=float, ndim=3) +ArrayFloat4D = array_alias(dtype=float, ndim=4) + +ArrayComplex = array_alias(dtype=complex) +ArrayComplex1D = array_alias(dtype=complex, ndim=1, scalar_to_1d=True) +ArrayComplex2D = array_alias(dtype=complex, ndim=2) +ArrayComplex3D = array_alias(dtype=complex, ndim=3) +ArrayComplex4D = array_alias(dtype=complex, ndim=4) + +TensorReal = array_alias(dtype=float, ndim=2, shape=(3, 3)) +MatrixReal4x4 = array_alias(dtype=float, ndim=2, shape=(4, 4)) + +""" Complex Values """ + + +def _parse_complex(v: Any) -> complex: + if isinstance(v, complex): + return v + + if isinstance(v, dict) and "real" in v and "imag" in v: + return complex(v["real"], v["imag"]) + + if isinstance(v, numbers.Number): + return complex(v) + + if hasattr(v, "__complex__"): + try: + return complex(v.__complex__()) + except Exception: + pass + + if isinstance(v, (list, tuple)) and len(v) == 2: + return complex(v[0], v[1]) + + return v 
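+
+# Illustrative conversions handled by ``_parse_complex`` (not exhaustive):
+#   _parse_complex({"real": 1.0, "imag": 2.0})  -> (1+2j)
+#   _parse_complex((1.0, 2.0))                  -> (1+2j)
+#   _parse_complex(3.0)                         -> (3+0j)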
+
+
+Complex = Annotated[
+    complex,
+    BeforeValidator(_parse_complex),
+    PlainSerializer(
+        lambda z, _: {"real": z.real, "imag": z.imag},
+        when_used="json",
+        return_type=dict,
+    ),
+]
+
+""" symmetry """
+
+Symmetry = Annotated[Literal[0, -1, 1], BeforeValidator(_dtype2python)]
+ScalarSymmetry = Annotated[Literal[0, 1], BeforeValidator(_dtype2python)]
+
+""" geometric """
+
+Size1D = NonNegativeFloat
+Size = tuple[Size1D, Size1D, Size1D]
+Coordinate = tuple[float, float, float]
+CoordinateOptional = tuple[Optional[float], Optional[float], Optional[float]]
+Coordinate2D = tuple[float, float]
+Bound = tuple[Coordinate, Coordinate]
+GridSize = Union[PositiveFloat, tuple[PositiveFloat, ...]]
+Axis = Annotated[Literal[0, 1, 2], BeforeValidator(_dtype2python)]
+Axis2D = Annotated[Literal[0, 1], BeforeValidator(_dtype2python)]
+Shapely = BaseGeometry
+PlanePosition = Literal["bottom", "middle", "top"]
+ClipOperationType = Literal["union", "intersection", "difference", "symmetric_difference"]
+BoxSurface = Literal["x-", "x+", "y-", "y+", "z-", "z+"]
+LengthUnit = Literal["nm", "μm", "um", "mm", "cm", "m", "mil", "in"]
+PriorityMode = Literal["equal", "conductor"]
+
+""" medium """
+
+# custom medium
+InterpMethod = Literal["nearest", "linear"]
+
+PoleAndResidue = tuple[Complex, Complex]
+PolesAndResidues = tuple[PoleAndResidue, ...]
+FreqBoundMax = float
+FreqBoundMin = float
+FreqBound = tuple[FreqBoundMin, FreqBoundMax]
+
+PermittivityComponent = Literal["xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz"]
+
+""" sources """
+
+Polarization = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
+Direction = Literal["+", "-"]
+
+""" monitors """
+
+
+def _list_to_tuple(v: Any) -> Any:
+    if isinstance(v, list):
+        return tuple(v)
+    return v
+
+
+EMField = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
+FieldType = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
+FreqArray = ArrayFloat1D
+ObsGridArray = FreqArray
+PolarizationBasis = Literal["linear", "circular"]
+AuxField = Literal["Nfx", "Nfy", "Nfz"]
+
+""" plotting """
+
+Ax = Axes
+PlotVal = Literal["real", "imag", "abs"]
+FieldVal = Literal["real", "imag", "abs", "abs^2", "phase"]
+RealFieldVal = Literal["real", "abs", "abs^2"]
+PlotScale = Literal["lin", "dB", "log", "symlog"]
+ColormapType = Literal["divergent", "sequential", "cyclic"]
+
+""" mode solver """
+
+ModeSolverType = Literal["tensorial", "diagonal"]
+EpsSpecType = Literal["diagonal", "tensorial_real", "tensorial_complex"]
+ModeClassification = Literal["TEM", "quasi-TEM", "TE", "TM", "Hybrid"]
+
+""" mode tracking """
+
+TrackFreq = Literal["central", "lowest", "highest"]
+
+""" lumped elements"""
+
+LumpDistType = Literal["off", "laterally_only", "on"]
+
+""" dataset """
+
+xyz = Literal["x", "y", "z"]
+UnitsZBF = Literal["mm", "cm", "in", "m"]
+
+""" sentinel """
+Undefined = object()
diff --git a/tidy3d/_common/components/types/third_party.py b/tidy3d/_common/components/types/third_party.py
new file mode 100644
index 0000000000..1530d2f088
--- /dev/null
+++ b/tidy3d/_common/components/types/third_party.py
@@ -0,0 +1,14 @@
+from __future__ import annotations
+
+from typing import Any
+
+from tidy3d._common.packaging import check_import
+
+# TODO: This is complicated, as trimesh should be a core package unless the implementation
+# types are decoupled into a more functional location. We need to restructure.
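+# Note: if trimesh is importable, ``TrimeshType`` resolves to the concrete
+# ``trimesh.Trimesh`` class; otherwise it degrades to ``typing.Any`` so that
+# annotations using it remain valid without the optional dependency.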
+if check_import("trimesh"):
+    import trimesh  # Won't add much overhead if already imported
+
+    TrimeshType = trimesh.Trimesh
+else:
+    TrimeshType = Any
diff --git a/tidy3d/_common/components/types/utils.py b/tidy3d/_common/components/types/utils.py
new file mode 100644
index 0000000000..333cdb807e
--- /dev/null
+++ b/tidy3d/_common/components/types/utils.py
@@ -0,0 +1,33 @@
+"""Utilities for type & schema creation."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from pydantic_core import core_schema
+
+if TYPE_CHECKING:
+    from pydantic import GetCoreSchemaHandler
+
+
+def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None:
+    """Adds a schema to the ``arbitrary_type`` class without subclassing."""
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls: type, _source_type: type, _handler: GetCoreSchemaHandler
+    ) -> core_schema.CoreSchema:
+        def _serialize(value: Any, info: core_schema.SerializationInfo) -> Any:
+            from tidy3d._common.components.autograd.utils import get_static
+            from tidy3d._common.components.types.base import _auto_serializer
+
+            return _auto_serializer(get_static(value), info)
+
+        return core_schema.any_schema(
+            metadata={"title": title, "type": field_type_str},
+            serialization=core_schema.plain_serializer_function_ser_schema(
+                _serialize, info_arg=True
+            ),
+        )
+
+    arbitrary_type.__get_pydantic_core_schema__ = __get_pydantic_core_schema__
diff --git a/tidy3d/_common/components/validators.py b/tidy3d/_common/components/validators.py
new file mode 100644
index 0000000000..ded4659f02
--- /dev/null
+++ b/tidy3d/_common/components/validators.py
@@ -0,0 +1,122 @@
+"""Defines various validation functions that get used to ensure inputs are legit"""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Any, TypeVar, Union
+
+import numpy as np
+from numpy.typing import NDArray
+from pydantic import field_validator
+
+from tidy3d._common.components.autograd.utils import get_static, hasbox
+from tidy3d._common.components.data.data_array import DATA_ARRAY_MAP
+from tidy3d._common.exceptions import ValidationError
+from tidy3d._common.log import log
+
+if TYPE_CHECKING:
+    from typing import Callable, Optional
+
+    from pydantic import FieldValidationInfo
+
+T = TypeVar("T")
+
+""" Explanation of pydantic validators:
+
+    Validators are class methods that are added to the models to validate their fields (kwargs).
+    The functions on this page return validators based on config arguments
+    and are generally used in multiple components of tidy3d.
+    The inner functions (validators) are decorated with @pydantic.validator, which is configured.
+    Its first argument is the string name of the field being validated in the model.
+    ``allow_reuse`` lets us use the validator in more than one model.
+    ``always`` makes sure if the model is changed, the validator gets called again.
+
+    The function being decorated by @pydantic.validator generally takes
+    ``cls`` the class that the validator is added to.
+    ``val`` the value of the field being validated.
+    ``values`` a dictionary containing all of the other fields of the model.
+    It is important to note that the validator only has access to fields that are defined
+    before the field being validated.
+    Fields defined under the validated field will not be in ``values``.
+
+    All validators generally should throw an exception if the validation fails
+    and return val if it passes.
+    Sometimes, we can use validators to change ``val`` or ``values``,
+    but this should be done with caution as it can be hard to reason about.
+
+    To add a validator from this file to the pydantic model,
+    put it in the model's main body and assign it to a variable (class method).
+    For example ``_plane_validator = assert_plane()``.
+    Note that if the assigned name ``_plane_validator`` is later reused for another validator,
+    the original validator will be overwritten, so be aware of this.
+
+    For more details: `Pydantic Validators `_
+"""
+
+# Lowest frequency supported (Hz)
+MIN_FREQUENCY = 1e5
+
+FloatArray = Union[Sequence[float], NDArray]
+
+
+def _assert_min_freq(freqs: FloatArray, msg_start: str) -> None:
+    """Check if all ``freqs`` are above the minimum frequency."""
+    if np.min(freqs) < MIN_FREQUENCY:
+        raise ValidationError(
+            f"{msg_start} must be no lower than {MIN_FREQUENCY:.0e} Hz. "
+            "Note that the unit of frequency is 'Hz'."
+        )
+
+
+def _warn_unsupported_traced_argument(
+    *names: str,
+) -> Callable[[type, Any, FieldValidationInfo], Any]:
+    @field_validator(*names)
+    @classmethod
+    def _warn_traced_arg(cls: type, val: Any, info: FieldValidationInfo) -> Any:
+        if hasbox(val):
+            log.warning(
+                f"Field '{info.field_name}' of '{cls.__name__}' received an autograd tracer "
+                f"(i.e., a value being tracked for automatic differentiation). "
+                f"Automatic differentiation through this field is unsupported, "
+                f"so the tracer has been converted to its static value. "
+                f"If you want to avoid this warning, you can manually unbox the value "
+                f"using the 'autograd.tracer.getval' function before passing it to Tidy3D."
+            )
+            return get_static(val)
+        return val
+
+    return _warn_traced_arg
+
+
+def warn_if_dataset_none(
+    field_name: str,
+) -> Callable[[type, Optional[dict[str, Any]]], Optional[dict[str, Any]]]:
+    """Warn if a Dataset field has None in its dictionary."""
+
+    @field_validator(field_name, mode="before")
+    @classmethod
+    def _warn_if_none(cls: type, val: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
+        """Warn if the DataArrays fail to load."""
+        if isinstance(val, dict):
+            if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))):
+                log.warning(f"Loading {field_name} without data.", custom_loc=[field_name])
+                return None
+        return val
+
+    return _warn_if_none
+
+
+# FIXME: this validator doesn't do anything
+def validate_name_str() -> Callable[[type, Optional[str]], Optional[str]]:
+    """make sure the name does not include [, ] (used for default names)"""
+
+    @field_validator("name")
+    @classmethod
+    def field_has_unique_names(cls: type, val: Optional[str]) -> Optional[str]:
+        """raise exception if '[' or ']' in name"""
+        # if val and ('[' in val or ']' in val):
+        #     raise SetupError(f"'[' or ']' not allowed in name: {val} (used for defaults)")
+        return val
+
+    return field_has_unique_names
diff --git a/tidy3d/_common/components/viz/__init__.py b/tidy3d/_common/components/viz/__init__.py
new file mode 100644
index 0000000000..649c58d3a4
--- /dev/null
+++ b/tidy3d/_common/components/viz/__init__.py
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from tidy3d._common.components.viz.axes_utils import (
+    add_ax_if_none,
+    equal_aspect,
+    make_ax,
+    set_default_labels_and_title,
+)
+from tidy3d._common.components.viz.descartes import Polygon, polygon_patch, polygon_path
+from tidy3d._common.components.viz.flex_style import (
+    apply_tidy3d_params,
+    restore_matplotlib_rcparams,
+)
+from tidy3d._common.components.viz.plot_params import (
+    AbstractPlotParams,
+    PathPlotParams,
+    PlotParams,
+    plot_params_abc,
+    plot_params_absorber,
+    plot_params_bloch,
+    plot_params_fluid,
+    plot_params_geometry,
+    plot_params_grid,
+    plot_params_lumped_element,
+    plot_params_monitor,
+    plot_params_override_structures,
+    plot_params_pec,
+    plot_params_pmc,
+    plot_params_pml,
+    plot_params_source,
+    plot_params_structure,
+    plot_params_symmetry,
+)
+from tidy3d._common.components.viz.plot_sim_3d import plot_scene_3d, plot_sim_3d
+from tidy3d._common.components.viz.styles import (
+    ARROW_ALPHA,
+    ARROW_COLOR_ABSORBER,
+    ARROW_COLOR_MONITOR,
+    ARROW_COLOR_POLARIZATION,
+    ARROW_COLOR_SOURCE,
+    ARROW_LENGTH,
+    FLEXCOMPUTE_COLORS,
+    MEDIUM_CMAP,
+    PLOT_BUFFER,
+    STRUCTURE_EPS_CMAP,
+    STRUCTURE_EPS_CMAP_R,
+    STRUCTURE_HEAT_COND_CMAP,
+    arrow_style,
+)
+from tidy3d._common.components.viz.visualization_spec import MATPLOTLIB_IMPORTED, VisualizationSpec
+
+apply_tidy3d_params()
+
+__all__ = [
+    "ARROW_ALPHA",
+    "ARROW_COLOR_ABSORBER",
+    "ARROW_COLOR_MONITOR",
+    "ARROW_COLOR_POLARIZATION",
+    "ARROW_COLOR_SOURCE",
+    "ARROW_LENGTH",
+    "FLEXCOMPUTE_COLORS",
+    "MATPLOTLIB_IMPORTED",
+    "MEDIUM_CMAP",
+    "PLOT_BUFFER",
+    "STRUCTURE_EPS_CMAP",
+    "STRUCTURE_EPS_CMAP_R",
+    "STRUCTURE_HEAT_COND_CMAP",
+    "AbstractPlotParams",
+    "PathPlotParams",
+    "PlotParams",
+    "Polygon",
+    "VisualizationSpec",
+    "add_ax_if_none",
+    "arrow_style",
+    "equal_aspect",
+    "make_ax",
+    "plot_params_abc",
+    "plot_params_absorber",
+    "plot_params_bloch",
+    "plot_params_fluid",
+    "plot_params_geometry",
+    "plot_params_grid",
+    "plot_params_lumped_element",
+    "plot_params_monitor",
+    "plot_params_override_structures",
+    "plot_params_pec",
+    "plot_params_pmc",
+    "plot_params_pml",
+    "plot_params_source",
+    "plot_params_structure",
+    "plot_params_symmetry",
+    "plot_scene_3d",
+    "plot_sim_3d",
+    "polygon_patch",
+    "polygon_path",
+    "restore_matplotlib_rcparams",
+    "set_default_labels_and_title",
+]
diff --git a/tidy3d/_common/components/viz/axes_utils.py b/tidy3d/_common/components/viz/axes_utils.py
new file mode 100644
index 0000000000..4a3e342a7b
--- /dev/null
+++ b/tidy3d/_common/components/viz/axes_utils.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+from functools import wraps
+from typing import TYPE_CHECKING
+
+from tidy3d._common.components.types.base import LengthUnit
+from tidy3d._common.constants import UnitScaling
+from tidy3d._common.exceptions import Tidy3dKeyError
+
+if TYPE_CHECKING:
+    from typing import Callable, ParamSpec, TypeVar
+
+    import matplotlib.ticker as ticker
+    from matplotlib.axes import Axes
+
+    P = ParamSpec("P")
+    T = TypeVar("T", bound=Callable[..., Axes])
+    from typing import Optional
+
+    from tidy3d._common.components.types.base import Ax, Axis
+
+
+def _create_unit_aware_locator() -> ticker.Locator:
+    """Create UnitAwareLocator lazily due to matplotlib import restrictions."""
+    import matplotlib.ticker as ticker
+
+    class UnitAwareLocator(ticker.Locator):
+        """Custom tick locator that places ticks at nice positions in the target unit."""
+
+        def __init__(self, scale_factor: float) -> None:
+            """
+            Parameters
+            ----------
+            scale_factor : float
+                Factor to convert from micrometers to the target unit.
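+                For example, a target unit of ``mm`` corresponds to a scale
+                factor of 1e-3, since lengths are stored internally in μm.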
+ """ + super().__init__() + self.scale_factor = scale_factor + + def __call__(self) -> list[float]: + vmin, vmax = self.axis.get_view_interval() + return self.tick_values(vmin, vmax) + + def view_limits(self, vmin: float, vmax: float) -> tuple[float, float]: + """Override to prevent matplotlib from adjusting our limits.""" + return vmin, vmax + + def tick_values(self, vmin: float, vmax: float) -> list[float]: + # convert the view range to the target unit + vmin_unit = vmin * self.scale_factor + vmax_unit = vmax * self.scale_factor + + # tolerance for floating point comparisons in target unit + unit_range = vmax_unit - vmin_unit + unit_tol = unit_range * 1e-8 + + locator = ticker.MaxNLocator(nbins=11, prune=None, min_n_ticks=2) + + ticks_unit = locator.tick_values(vmin_unit, vmax_unit) + + # ensure we have ticks that cover the full range + if len(ticks_unit) > 0: + if ticks_unit[0] > vmin_unit + unit_tol or ticks_unit[-1] < vmax_unit - unit_tol: + # try with fewer bins to get better coverage + for n in [10, 9, 8, 7, 6, 5]: + locator = ticker.MaxNLocator(nbins=n, prune=None, min_n_ticks=2) + ticks_unit = locator.tick_values(vmin_unit, vmax_unit) + if ( + len(ticks_unit) >= 3 + and ticks_unit[0] <= vmin_unit + unit_tol + and ticks_unit[-1] >= vmax_unit - unit_tol + ): + break + + # if still no good coverage, manually ensure edge coverage + if len(ticks_unit) > 0: + if ( + ticks_unit[0] > vmin_unit + unit_tol + or ticks_unit[-1] < vmax_unit - unit_tol + ): + # find a reasonable step size from existing ticks + if len(ticks_unit) > 1: + step = ticks_unit[1] - ticks_unit[0] + else: + step = unit_range / 5 + + # extend the range to ensure coverage + extended_min = vmin_unit - step + extended_max = vmax_unit + step + + # try one more time with extended range + locator = ticker.MaxNLocator(nbins=8, prune=None, min_n_ticks=2) + ticks_unit = locator.tick_values(extended_min, extended_max) + + # filter to reasonable bounds around the original range + ticks_unit = [ + t + for t in ticks_unit + if t >= vmin_unit - step / 2 and t <= vmax_unit + step / 2 + ] + + # convert the nice ticks back to the original data unit (micrometers) + ticks_um = ticks_unit / self.scale_factor + + # filter to ensure ticks are within bounds (with small tolerance) + eps = (vmax - vmin) * 1e-8 + return [tick for tick in ticks_um if vmin - eps <= tick <= vmax + eps] + + return UnitAwareLocator + + +def make_ax() -> Ax: + """makes an empty ``ax``.""" + import matplotlib.pyplot as plt + + _, ax = plt.subplots(1, 1, tight_layout=True) + return ax + + +def add_ax_if_none(plot: T) -> T: + """Decorates ``plot(*args, **kwargs, ax=None)`` function. + if ax=None in the function call, creates an ax and feeds it to rest of function. + """ + + @wraps(plot) + def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: + """New plot function using a generated ax if None.""" + if kwargs.get("ax") is None: + ax = make_ax() + kwargs["ax"] = ax + return plot(*args, **kwargs) + + return _plot + + +def equal_aspect(plot: T) -> T: + """Decorates a plotting function returning a matplotlib axes. + Ensures the aspect ratio of the returned axes is set to equal. 
+ Useful for 2D plots, like sim.plot() or sim_data.plot_fields() + """ + + @wraps(plot) + def _plot(*args: P.args, **kwargs: P.kwargs) -> Axes: + """New plot function with equal aspect ratio axes returned.""" + ax = plot(*args, **kwargs) + ax.set_aspect("equal") + return ax + + return _plot + + +def set_default_labels_and_title( + axis_labels: tuple[str, str], + axis: Axis, + position: float, + ax: Ax, + plot_length_units: Optional[LengthUnit] = None, +) -> Ax: + """Adds axis labels and title to plots involving spatial dimensions. + When the ``plot_length_units`` are specified, the plot axes are scaled, and + the title and axis labels include the desired units. + """ + + import matplotlib.ticker as ticker + + xlabel = axis_labels[0] + ylabel = axis_labels[1] + if plot_length_units is not None: + if plot_length_units not in UnitScaling: + raise Tidy3dKeyError( + f"Provided units '{plot_length_units}' are not supported. " + f"Please choose one of '{LengthUnit}'." + ) + ax.set_xlabel(f"{xlabel} ({plot_length_units})") + ax.set_ylabel(f"{ylabel} ({plot_length_units})") + + scale_factor = UnitScaling[plot_length_units] + + # for imperial units, use custom tick locator for nice tick positions + if plot_length_units in ["mil", "in"]: + UnitAwareLocator = _create_unit_aware_locator() + x_locator = UnitAwareLocator(scale_factor) + y_locator = UnitAwareLocator(scale_factor) + ax.xaxis.set_major_locator(x_locator) + ax.yaxis.set_major_locator(y_locator) + + formatter = ticker.FuncFormatter(lambda y, _: f"{y * scale_factor:.2f}") + + ax.xaxis.set_major_formatter(formatter) + ax.yaxis.set_major_formatter(formatter) + + position_scaled = position * scale_factor + ax.set_title(f"cross section at {'xyz'[axis]}={position_scaled:.2f} ({plot_length_units})") + else: + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}") + return ax diff --git a/tidy3d/_common/components/viz/descartes.py b/tidy3d/_common/components/viz/descartes.py new file mode 100644 index 0000000000..572dfc44ba --- /dev/null +++ b/tidy3d/_common/components/viz/descartes.py @@ -0,0 +1,113 @@ +"""================================================================================================= +Descartes modified from https://pypi.org/project/descartes/ for Shapely >= 1.8.0 + +Copyright Flexcompute 2022 + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER +IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from numpy.typing import NDArray + from shapely.geometry.base import BaseGeometry + +try: + from matplotlib.patches import PathPatch + from matplotlib.path import Path +except ImportError: + pass +from numpy import array, concatenate, ones + + +class Polygon: + """Adapt Shapely polygons to a common interface""" + + def __init__(self, context: dict[str, Any]) -> None: + if isinstance(context, dict): + self.context = context["coordinates"] + else: + self.context = context + + @property + def exterior(self) -> Any: + """Get polygon exterior.""" + value = getattr(self.context, "exterior", None) + if value is None: + value = self.context[0] + return value + + @property + def interiors(self) -> Any: + """Get polygon interiors.""" + value = getattr(self.context, "interiors", None) + if value is None: + value = self.context[1:] + return value + + +def polygon_path(polygon: BaseGeometry) -> Path: + """Constructs a compound matplotlib path from a Shapely or GeoJSON-like + geometric object""" + + def coding(obj: Any) -> NDArray: + # The codes will be all "LINETO" commands, except for "MOVETO"s at the + # beginning of each subpath + crds = getattr(obj, "coords", None) + if crds is None: + crds = obj + n = len(crds) + vals = ones(n, dtype=Path.code_type) * Path.LINETO + if len(vals) > 0: + vals[0] = Path.MOVETO + return vals + + ptype = polygon.geom_type + if ptype == "Polygon": + polygon = [Polygon(polygon)] + elif ptype == "MultiPolygon": + polygon = [Polygon(p) for p in polygon.geoms] + + vertices = concatenate( + [ + concatenate( + [array(t.exterior.coords)[:, :2]] + [array(r.coords)[:, :2] for r in t.interiors] + ) + for t in polygon + ] + ) + codes = concatenate( + [concatenate([coding(t.exterior)] + [coding(r) for r in t.interiors]) for t in polygon] + ) + + return Path(vertices, codes) + + +def polygon_patch(polygon: BaseGeometry, **kwargs: Any) -> PathPatch: + """Constructs a matplotlib patch from a geometric object + + The ``polygon`` may be a Shapely or GeoJSON-like object with or without holes. + The ``kwargs`` are those supported by the matplotlib.patches.Polygon class + constructor. Returns an instance of matplotlib.patches.PathPatch. 
+ + Example + ------- + >>> b = Point(0, 0).buffer(1.0) # doctest: +SKIP + >>> patch = PolygonPatch(b, fc='blue', ec='blue', alpha=0.5) # doctest: +SKIP + >>> axis.add_patch(patch) # doctest: +SKIP + + """ + return PathPatch(polygon_path(polygon), **kwargs) + + +"""End descartes modification +=================================================================================================""" diff --git a/tidy3d/_common/components/viz/flex_color_palettes.py b/tidy3d/_common/components/viz/flex_color_palettes.py new file mode 100644 index 0000000000..7fc1454a0b --- /dev/null +++ b/tidy3d/_common/components/viz/flex_color_palettes.py @@ -0,0 +1,3306 @@ +from __future__ import annotations + +SEQUENTIAL_PALETTES_HEX = { + "flex_turquoise_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfc", + "#fbfbfb", + "#fafbfa", + "#f9fafa", + "#f8f9f9", + "#f7f8f8", + "#f6f7f7", + "#f5f6f6", + "#f3f5f5", + "#f2f4f4", + "#f1f3f3", + "#f0f3f2", + "#eff2f1", + "#eef1f1", + "#edf0f0", + "#ecefef", + "#ebeeee", + "#eaeded", + "#e9edec", + "#e8eceb", + "#e7ebeb", + "#e6eaea", + "#e5e9e9", + "#e4e8e8", + "#e3e7e7", + "#e2e7e6", + "#e1e6e5", + "#e0e5e5", + "#dfe4e4", + "#dee3e3", + "#dde2e2", + "#dce2e1", + "#dbe1e0", + "#dae0df", + "#d9dfdf", + "#d8dede", + "#d7dedd", + "#d6dddc", + "#d5dcdb", + "#d4dbdb", + "#d3dada", + "#d2dad9", + "#d1d9d8", + "#d1d8d7", + "#d0d7d6", + "#cfd6d6", + "#ced6d5", + "#cdd5d4", + "#ccd4d3", + "#cbd3d2", + "#cad2d2", + "#c9d2d1", + "#c8d1d0", + "#c7d0cf", + "#c6cfce", + "#c5cece", + "#c4cecd", + "#c3cdcc", + "#c2cccb", + "#c1cbca", + "#c0cbca", + "#bfcac9", + "#bec9c8", + "#bec8c7", + "#bdc8c7", + "#bcc7c6", + "#bbc6c5", + "#bac5c4", + "#b9c5c3", + "#b8c4c3", + "#b7c3c2", + "#b6c2c1", + "#b5c2c0", + "#b4c1c0", + "#b3c0bf", + "#b2bfbe", + "#b2bfbd", + "#b1bebd", + "#b0bdbc", + "#afbcbb", + "#aebcba", + "#adbbba", + "#acbab9", + "#abbab8", + "#aab9b7", + "#a9b8b7", + "#a9b7b6", + "#a8b7b5", + "#a7b6b4", + "#a6b5b4", + "#a5b4b3", + "#a4b4b2", + "#a3b3b2", + "#a2b2b1", + "#a1b2b0", + "#a1b1af", + "#a0b0af", + "#9fb0ae", + "#9eafad", + "#9daeac", + "#9cadac", + "#9badab", + "#9aacaa", + "#99abaa", + "#99aba9", + "#98aaa8", + "#97a9a7", + "#96a9a7", + "#95a8a6", + "#94a7a5", + "#93a6a5", + "#92a6a4", + "#92a5a3", + "#91a4a2", + "#90a4a2", + "#8fa3a1", + "#8ea2a0", + "#8da2a0", + "#8ca19f", + "#8ca09e", + "#8ba09e", + "#8a9f9d", + "#899e9c", + "#889e9c", + "#879d9b", + "#869c9a", + "#869c9a", + "#859b99", + "#849a98", + "#839a97", + "#829997", + "#819896", + "#809895", + "#809795", + "#7f9694", + "#7e9693", + "#7d9593", + "#7c9492", + "#7b9491", + "#7a9391", + "#7a9290", + "#79928f", + "#78918f", + "#77908e", + "#76908d", + "#758f8d", + "#758f8c", + "#748e8b", + "#738d8b", + "#728d8a", + "#718c89", + "#708b89", + "#708b88", + "#6f8a87", + "#6e8987", + "#6d8986", + "#6c8885", + "#6b8885", + "#6a8784", + "#6a8684", + "#698683", + "#688582", + "#678482", + "#668481", + "#658380", + "#658280", + "#64827f", + "#63817e", + "#62817e", + "#61807d", + "#607f7c", + "#607f7c", + "#5f7e7b", + "#5e7d7b", + "#5d7d7a", + "#5c7c79", + "#5b7c79", + "#5b7b78", + "#5a7a77", + "#597a77", + "#587976", + "#577975", + "#567875", + "#567774", + "#557774", + "#547673", + "#537572", + "#527572", + "#517471", + "#507470", + "#507370", + "#4f726f", + "#4e726f", + "#4d716e", + "#4c716d", + "#4b706d", + "#4b6f6c", + "#4a6f6b", + "#496e6b", + "#486e6a", + "#476d6a", + "#466c69", + "#456c68", + "#446b68", + "#446b67", + "#436a67", + "#426966", + "#416965", + "#406865", + "#3f6864", + "#3e6763", + "#3e6663", + 
"#3d6662", + "#3c6562", + "#3b6561", + "#3a6460", + "#396360", + "#38635f", + "#37625f", + "#36625e", + "#35615d", + "#35605d", + "#34605c", + "#335f5c", + "#325f5b", + "#315e5a", + "#305d5a", + "#2f5d59", + "#2e5c58", + "#2d5c58", + "#2c5b57", + "#2b5a57", + "#2a5a56", + "#295955", + "#285955", + "#275854", + "#265754", + "#255753", + "#245652", + "#235652", + "#225551", + "#215551", + "#205450", + "#1e534f", + "#1d534f", + "#1c524e", + "#1b524e", + "#1a514d", + "#18504c", + "#17504c", + "#164f4b", + "#144f4b", + "#134e4a", + ], + "flex_green_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfc", + "#fbfbfb", + "#f9fafa", + "#f8f9f9", + "#f7f8f8", + "#f6f7f7", + "#f5f6f6", + "#f4f5f5", + "#f3f5f3", + "#f2f4f2", + "#f1f3f1", + "#f0f2f0", + "#eff1ef", + "#eef0ee", + "#ecefed", + "#ebeeec", + "#eaedeb", + "#e9ecea", + "#e8ebe9", + "#e7eae8", + "#e6eae7", + "#e5e9e6", + "#e4e8e5", + "#e3e7e4", + "#e2e6e3", + "#e1e5e2", + "#e0e4e1", + "#dfe3e0", + "#dee3df", + "#dde2de", + "#dce1dd", + "#dbe0dc", + "#dadfdb", + "#d8deda", + "#d7ddd9", + "#d6dcd8", + "#d5dcd7", + "#d4dbd6", + "#d3dad5", + "#d2d9d4", + "#d1d8d3", + "#d0d7d2", + "#cfd6d1", + "#ced6d0", + "#cdd5cf", + "#ccd4ce", + "#cbd3ce", + "#cad2cd", + "#c9d1cc", + "#c8d1cb", + "#c7d0ca", + "#c6cfc9", + "#c5cec8", + "#c4cdc7", + "#c3cdc6", + "#c2ccc5", + "#c1cbc4", + "#c0cac3", + "#bfc9c2", + "#bec9c1", + "#bdc8c0", + "#bcc7bf", + "#bbc6be", + "#bac5bd", + "#b9c5bd", + "#b9c4bc", + "#b8c3bb", + "#b7c2ba", + "#b6c1b9", + "#b5c1b8", + "#b4c0b7", + "#b3bfb6", + "#b2beb5", + "#b1bdb4", + "#b0bdb3", + "#afbcb2", + "#aebbb1", + "#adbab1", + "#acbab0", + "#abb9af", + "#aab8ae", + "#a9b7ad", + "#a8b7ac", + "#a7b6ab", + "#a6b5aa", + "#a5b4a9", + "#a5b4a8", + "#a4b3a8", + "#a3b2a7", + "#a2b1a6", + "#a1b1a5", + "#a0b0a4", + "#9fafa3", + "#9eaea2", + "#9daea1", + "#9cada0", + "#9baca0", + "#9aab9f", + "#99ab9e", + "#99aa9d", + "#98a99c", + "#97a89b", + "#96a89a", + "#95a799", + "#94a699", + "#93a598", + "#92a597", + "#91a496", + "#90a395", + "#90a394", + "#8fa293", + "#8ea193", + "#8da092", + "#8ca091", + "#8b9f90", + "#8a9e8f", + "#899e8e", + "#889d8d", + "#879c8d", + "#879b8c", + "#869b8b", + "#859a8a", + "#849989", + "#839988", + "#829888", + "#819787", + "#809786", + "#809685", + "#7f9584", + "#7e9483", + "#7d9483", + "#7c9382", + "#7b9281", + "#7a9280", + "#79917f", + "#79907e", + "#78907e", + "#778f7d", + "#768e7c", + "#758e7b", + "#748d7a", + "#738c79", + "#728c79", + "#728b78", + "#718a77", + "#708a76", + "#6f8975", + "#6e8875", + "#6d8774", + "#6c8773", + "#6c8672", + "#6b8571", + "#6a8571", + "#698470", + "#68836f", + "#67836e", + "#66826d", + "#66816d", + "#65816c", + "#64806b", + "#637f6a", + "#627f69", + "#617e69", + "#607d68", + "#607d67", + "#5f7c66", + "#5e7c65", + "#5d7b65", + "#5c7a64", + "#5b7a63", + "#5a7962", + "#5a7861", + "#597861", + "#587760", + "#57765f", + "#56765e", + "#55755d", + "#55745d", + "#54745c", + "#53735b", + "#52725a", + "#51725a", + "#507159", + "#4f7058", + "#4f7057", + "#4e6f56", + "#4d6e56", + "#4c6e55", + "#4b6d54", + "#4a6d53", + "#4a6c53", + "#496b52", + "#486b51", + "#476a50", + "#466950", + "#45694f", + "#44684e", + "#44674d", + "#43674c", + "#42664c", + "#41654b", + "#40654a", + "#3f6449", + "#3e6449", + "#3e6348", + "#3d6247", + "#3c6246", + "#3b6146", + "#3a6045", + "#396044", + "#385f43", + "#385e43", + "#375e42", + "#365d41", + "#355c40", + "#345c40", + "#335b3f", + "#325b3e", + "#315a3d", + "#30593d", + "#30593c", + "#2f583b", + "#2e573a", + "#2d573a", + "#2c5639", + "#2b5538", + "#2a5537", + 
"#295437", + "#285436", + "#275335", + "#265234", + "#265234", + "#255133", + "#245032", + "#235031", + "#224f31", + "#214e30", + "#204e2f", + "#1f4d2e", + "#1e4c2e", + "#1d4c2d", + "#1c4b2c", + "#1b4b2b", + "#1a4a2b", + "#18492a", + "#174929", + "#164828", + "#154728", + "#144727", + "#134626", + "#124525", + "#104525", + "#0f4424", + ], + "flex_blue_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fbfcfc", + "#fafbfb", + "#f9fafa", + "#f8f9f9", + "#f7f7f8", + "#f6f6f8", + "#f4f5f7", + "#f3f4f6", + "#f2f3f5", + "#f1f2f4", + "#f0f1f3", + "#eff0f2", + "#eeeff1", + "#eceef0", + "#ebedf0", + "#eaecef", + "#e9ebee", + "#e8eaed", + "#e7e9ec", + "#e6e8eb", + "#e4e7ea", + "#e3e6ea", + "#e2e5e9", + "#e1e4e8", + "#e0e3e7", + "#dfe2e6", + "#dee1e5", + "#dde0e5", + "#dcdfe4", + "#dadee3", + "#d9dde2", + "#d8dce1", + "#d7dbe0", + "#d6dae0", + "#d5d9df", + "#d4d8de", + "#d3d7dd", + "#d2d6dc", + "#d1d5dc", + "#d0d4db", + "#cfd3da", + "#ced2d9", + "#ccd1d9", + "#cbd0d8", + "#cacfd7", + "#c9ced6", + "#c8cdd5", + "#c7ccd5", + "#c6ccd4", + "#c5cbd3", + "#c4cad2", + "#c3c9d2", + "#c2c8d1", + "#c1c7d0", + "#c0c6cf", + "#bfc5cf", + "#bec4ce", + "#bdc3cd", + "#bcc2cd", + "#bbc1cc", + "#bac0cb", + "#b9c0ca", + "#b8bfca", + "#b7bec9", + "#b6bdc8", + "#b5bcc7", + "#b4bbc7", + "#b3bac6", + "#b2b9c5", + "#b1b8c5", + "#b0b7c4", + "#afb7c3", + "#aeb6c3", + "#adb5c2", + "#acb4c1", + "#abb3c1", + "#aab2c0", + "#a9b1bf", + "#a8b0be", + "#a7b0be", + "#a6afbd", + "#a5aebc", + "#a4adbc", + "#a3acbb", + "#a2abba", + "#a1aaba", + "#a0aab9", + "#9fa9b8", + "#9ea8b8", + "#9da7b7", + "#9ca6b7", + "#9ba5b6", + "#9aa4b5", + "#99a4b5", + "#98a3b4", + "#97a2b3", + "#96a1b3", + "#95a0b2", + "#949fb1", + "#939fb1", + "#929eb0", + "#919db0", + "#909caf", + "#8f9bae", + "#8f9aae", + "#8e9aad", + "#8d99ac", + "#8c98ac", + "#8b97ab", + "#8a96ab", + "#8996aa", + "#8895a9", + "#8794a9", + "#8693a8", + "#8592a8", + "#8492a7", + "#8391a6", + "#8290a6", + "#818fa5", + "#818ea5", + "#808ea4", + "#7f8da3", + "#7e8ca3", + "#7d8ba2", + "#7c8aa2", + "#7b8aa1", + "#7a89a1", + "#7988a0", + "#78879f", + "#77869f", + "#76869e", + "#76859e", + "#75849d", + "#74839d", + "#73829c", + "#72829b", + "#71819b", + "#70809a", + "#6f7f9a", + "#6e7f99", + "#6d7e99", + "#6c7d98", + "#6c7c98", + "#6b7c97", + "#6a7b97", + "#697a96", + "#687995", + "#677895", + "#667894", + "#657794", + "#647693", + "#637593", + "#637592", + "#627492", + "#617391", + "#607291", + "#5f7290", + "#5e7190", + "#5d708f", + "#5c6f8f", + "#5b6f8e", + "#5b6e8e", + "#5a6d8d", + "#596c8c", + "#586b8c", + "#576b8b", + "#566a8b", + "#55698a", + "#54688a", + "#536889", + "#536789", + "#526688", + "#516588", + "#506587", + "#4f6487", + "#4e6386", + "#4d6286", + "#4c6285", + "#4b6185", + "#4a6084", + "#4a5f84", + "#495f83", + "#485e83", + "#475d83", + "#465d82", + "#455c82", + "#445b81", + "#435a81", + "#425a80", + "#425980", + "#41587f", + "#40577f", + "#3f577e", + "#3e567e", + "#3d557d", + "#3c547d", + "#3b547c", + "#3a537c", + "#39527b", + "#39517b", + "#38517b", + "#37507a", + "#364f7a", + "#354e79", + "#344e79", + "#334d78", + "#324c78", + "#314b77", + "#304b77", + "#2f4a76", + "#2e4976", + "#2d4876", + "#2c4875", + "#2c4775", + "#2b4674", + "#2a4574", + "#294473", + "#284473", + "#274373", + "#264272", + "#254172", + "#244171", + "#234071", + "#223f70", + "#213e70", + "#203e70", + "#1f3d6f", + "#1e3c6f", + "#1d3b6e", + "#1c3a6e", + "#1b3a6e", + "#1a396d", + "#19386d", + "#17376c", + "#16366c", + "#15366c", + "#14356b", + "#13346b", + "#12336b", + "#10326a", + "#0f326a", + "#0e316a", + 
"#0d3069", + "#0b2f69", + "#0a2e68", + "#082d68", + "#072c68", + "#062c68", + "#042b67", + "#032a67", + "#022967", + "#012866", + "#002766", + ], + "flex_orange_seq": [ + "#ffffff", + "#fefefe", + "#fefdfd", + "#fdfdfc", + "#fdfcfb", + "#fcfbfa", + "#fbfafa", + "#fbf9f9", + "#faf9f8", + "#faf8f7", + "#f9f7f6", + "#f8f6f5", + "#f8f6f4", + "#f7f5f3", + "#f7f4f2", + "#f6f3f1", + "#f5f2f1", + "#f5f2f0", + "#f4f1ef", + "#f3f0ee", + "#f3efed", + "#f2efec", + "#f2eeeb", + "#f1edea", + "#f1ece9", + "#f0ece8", + "#f0ebe7", + "#efeae6", + "#efe9e5", + "#eee9e4", + "#eee8e3", + "#ede7e2", + "#ede6e1", + "#ece5e0", + "#ece5df", + "#ebe4de", + "#ebe3dd", + "#eae2dc", + "#eae2db", + "#e9e1da", + "#e9e0d9", + "#e9dfd8", + "#e8dfd7", + "#e8ded6", + "#e7ddd5", + "#e7dcd4", + "#e6dbd3", + "#e6dbd2", + "#e6dad1", + "#e5d9d0", + "#e5d8cf", + "#e4d8ce", + "#e4d7cd", + "#e3d6cc", + "#e3d5cb", + "#e3d5ca", + "#e2d4c9", + "#e2d3c8", + "#e1d2c7", + "#e1d2c6", + "#e0d1c5", + "#e0d0c4", + "#e0cfc3", + "#dfcfc2", + "#dfcec1", + "#decdc0", + "#deccbf", + "#deccbe", + "#ddcbbd", + "#ddcabc", + "#dcc9bb", + "#dcc9ba", + "#dcc8b9", + "#dbc7b8", + "#dbc6b8", + "#dbc6b7", + "#dac5b6", + "#dac4b5", + "#d9c4b4", + "#d9c3b3", + "#d9c2b2", + "#d8c1b1", + "#d8c1b0", + "#d7c0af", + "#d7bfae", + "#d7bead", + "#d6beac", + "#d6bdab", + "#d6bcaa", + "#d5bba9", + "#d5bba8", + "#d4baa7", + "#d4b9a6", + "#d4b9a5", + "#d3b8a4", + "#d3b7a3", + "#d3b6a2", + "#d2b6a1", + "#d2b5a0", + "#d2b49f", + "#d1b49e", + "#d1b39d", + "#d0b29c", + "#d0b19b", + "#d0b19a", + "#cfb099", + "#cfaf99", + "#cfaf98", + "#ceae97", + "#cead96", + "#ceac95", + "#cdac94", + "#cdab93", + "#cdaa92", + "#ccaa91", + "#cca990", + "#cca88f", + "#cba78e", + "#cba78d", + "#caa68c", + "#caa58b", + "#caa58a", + "#c9a489", + "#c9a388", + "#c9a387", + "#c8a286", + "#c8a185", + "#c8a085", + "#c7a084", + "#c79f83", + "#c79e82", + "#c69e81", + "#c69d80", + "#c69c7f", + "#c59c7e", + "#c59b7d", + "#c59a7c", + "#c4997b", + "#c4997a", + "#c49879", + "#c39778", + "#c39777", + "#c29676", + "#c29575", + "#c29575", + "#c19474", + "#c19373", + "#c19372", + "#c09271", + "#c09170", + "#c0906f", + "#bf906e", + "#bf8f6d", + "#bf8e6c", + "#be8e6b", + "#be8d6a", + "#be8c69", + "#bd8c68", + "#bd8b67", + "#bd8a67", + "#bc8a66", + "#bc8965", + "#bc8864", + "#bb8863", + "#bb8762", + "#bb8661", + "#ba8660", + "#ba855f", + "#ba845e", + "#b9835d", + "#b9835c", + "#b8825b", + "#b8815b", + "#b8815a", + "#b78059", + "#b77f58", + "#b77f57", + "#b67e56", + "#b67d55", + "#b67d54", + "#b57c53", + "#b57b52", + "#b57b51", + "#b47a50", + "#b4794f", + "#b4794f", + "#b3784e", + "#b3774d", + "#b3774c", + "#b2764b", + "#b2754a", + "#b17549", + "#b17448", + "#b17347", + "#b07346", + "#b07245", + "#b07144", + "#af7144", + "#af7043", + "#af6f42", + "#ae6f41", + "#ae6e40", + "#ae6d3f", + "#ad6d3e", + "#ad6c3d", + "#ac6b3c", + "#ac6b3b", + "#ac6a3a", + "#ab6939", + "#ab6939", + "#ab6838", + "#aa6737", + "#aa6736", + "#aa6635", + "#a96534", + "#a96533", + "#a86432", + "#a86331", + "#a86330", + "#a7622f", + "#a7612e", + "#a7612d", + "#a6602c", + "#a65f2b", + "#a55f2a", + "#a55e2a", + "#a55d29", + "#a45d28", + "#a45c27", + "#a35b26", + "#a35b25", + "#a35a24", + "#a25923", + "#a25922", + "#a25821", + "#a15720", + "#a1571f", + "#a0561e", + "#a0551d", + "#a0551c", + "#9f541b", + "#9f531a", + "#9e5318", + "#9e5217", + "#9e5116", + "#9d5115", + "#9d5014", + "#9c4f13", + "#9c4f12", + "#9b4e10", + "#9b4d0f", + "#9b4d0e", + "#9a4c0c", + "#9a4b0b", + "#994b09", + "#994a08", + ], + "flex_red_seq": [ + "#ffffff", + "#fefefe", + 
"#fefdfd", + "#fdfcfc", + "#fcfbfb", + "#fcfafa", + "#fbf9f9", + "#faf8f8", + "#faf7f7", + "#f9f6f6", + "#f8f5f5", + "#f8f4f5", + "#f7f3f4", + "#f6f2f3", + "#f5f2f2", + "#f5f1f1", + "#f4f0f0", + "#f3efef", + "#f3eeee", + "#f2eded", + "#f1ecec", + "#f1ebec", + "#f0eaeb", + "#efe9ea", + "#efe8e9", + "#eee7e8", + "#eee6e7", + "#ede5e6", + "#ece4e6", + "#ece3e5", + "#ebe2e4", + "#ebe1e3", + "#eae0e2", + "#eae0e1", + "#e9dfe0", + "#e9dedf", + "#e8dddf", + "#e8dcde", + "#e7dbdd", + "#e7dadc", + "#e6d9db", + "#e6d8da", + "#e5d7d9", + "#e5d6d8", + "#e4d5d7", + "#e4d4d7", + "#e3d3d6", + "#e3d2d5", + "#e2d1d4", + "#e2d0d3", + "#e1d0d2", + "#e1cfd1", + "#e0ced1", + "#e0cdd0", + "#dfcccf", + "#dfcbce", + "#decacd", + "#dec9cc", + "#ddc8cb", + "#ddc7cb", + "#dcc6ca", + "#dcc5c9", + "#dbc4c8", + "#dbc4c7", + "#dbc3c6", + "#dac2c5", + "#dac1c5", + "#d9c0c4", + "#d9bfc3", + "#d8bec2", + "#d8bdc1", + "#d7bcc0", + "#d7bbc0", + "#d7babf", + "#d6babe", + "#d6b9bd", + "#d5b8bc", + "#d5b7bb", + "#d4b6bb", + "#d4b5ba", + "#d4b4b9", + "#d3b3b8", + "#d3b2b7", + "#d2b1b6", + "#d2b0b6", + "#d1b0b5", + "#d1afb4", + "#d1aeb3", + "#d0adb2", + "#d0acb1", + "#cfabb1", + "#cfaab0", + "#cfa9af", + "#cea8ae", + "#cea8ad", + "#cda7ad", + "#cda6ac", + "#cca5ab", + "#cca4aa", + "#cca3a9", + "#cba2a9", + "#cba1a8", + "#caa0a7", + "#caa0a6", + "#ca9fa5", + "#c99ea5", + "#c99da4", + "#c89ca3", + "#c89ba2", + "#c89aa1", + "#c799a1", + "#c799a0", + "#c6989f", + "#c6979e", + "#c6969d", + "#c5959d", + "#c5949c", + "#c4939b", + "#c4929a", + "#c49299", + "#c39199", + "#c39098", + "#c28f97", + "#c28e96", + "#c28d96", + "#c18c95", + "#c18c94", + "#c18b93", + "#c08a92", + "#c08992", + "#bf8891", + "#bf8790", + "#bf868f", + "#be858f", + "#be858e", + "#bd848d", + "#bd838c", + "#bd828c", + "#bc818b", + "#bc808a", + "#bb7f89", + "#bb7f88", + "#bb7e88", + "#ba7d87", + "#ba7c86", + "#b97b85", + "#b97a85", + "#b97984", + "#b87983", + "#b87882", + "#b87782", + "#b77681", + "#b77580", + "#b6747f", + "#b6737f", + "#b6737e", + "#b5727d", + "#b5717c", + "#b4707c", + "#b46f7b", + "#b46e7a", + "#b36d79", + "#b36d79", + "#b26c78", + "#b26b77", + "#b26a76", + "#b16976", + "#b16875", + "#b06774", + "#b06773", + "#b06673", + "#af6572", + "#af6471", + "#ae6371", + "#ae6270", + "#ae616f", + "#ad616e", + "#ad606e", + "#ac5f6d", + "#ac5e6c", + "#ab5d6b", + "#ab5c6b", + "#ab5b6a", + "#aa5b69", + "#aa5a69", + "#a95968", + "#a95867", + "#a95766", + "#a85666", + "#a85565", + "#a75464", + "#a75463", + "#a65363", + "#a65262", + "#a65161", + "#a55061", + "#a54f60", + "#a44e5f", + "#a44d5e", + "#a34d5e", + "#a34c5d", + "#a34b5c", + "#a24a5c", + "#a2495b", + "#a1485a", + "#a1475a", + "#a04659", + "#a04558", + "#9f4557", + "#9f4457", + "#9f4356", + "#9e4255", + "#9e4155", + "#9d4054", + "#9d3f53", + "#9c3e52", + "#9c3d52", + "#9b3c51", + "#9b3b50", + "#9a3b50", + "#9a3a4f", + "#99394e", + "#99384e", + "#98374d", + "#98364c", + "#98354b", + "#97344b", + "#97334a", + "#963249", + "#963149", + "#953048", + "#952f47", + "#942e47", + "#942d46", + "#932c45", + "#932b45", + "#922a44", + "#922943", + "#912843", + "#912742", + "#902641", + "#902540", + "#8f2440", + "#8e223f", + "#8e213e", + "#8d203e", + "#8d1f3d", + "#8c1e3c", + "#8c1d3c", + "#8b1b3b", + "#8b1a3a", + "#8a193a", + "#8a1739", + "#891638", + "#891438", + "#881337", + ], + "flex_purple_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfd", + "#fbfbfc", + "#fafafb", + "#f9f9fa", + "#f8f8f9", + "#f7f7f9", + "#f6f6f8", + "#f5f5f7", + "#f4f4f6", + "#f3f3f6", + "#f2f2f5", + "#f1f1f4", + "#f0f0f3", + "#efeff3", + 
"#eeeef2", + "#ededf1", + "#ececf0", + "#ebebf0", + "#eaeaef", + "#e9e9ee", + "#e8e8ed", + "#e8e8ed", + "#e7e7ec", + "#e6e6eb", + "#e5e5eb", + "#e4e4ea", + "#e3e3e9", + "#e2e2e8", + "#e1e1e8", + "#e0e0e7", + "#dfdfe6", + "#dedee6", + "#dddde5", + "#dcdce4", + "#dbdbe4", + "#dadae3", + "#d9dae2", + "#d9d9e2", + "#d8d8e1", + "#d7d7e0", + "#d6d6e0", + "#d5d5df", + "#d4d4df", + "#d3d3de", + "#d2d2dd", + "#d1d1dd", + "#d0d1dc", + "#d0d0db", + "#cfcfdb", + "#ceceda", + "#cdcdd9", + "#ccccd9", + "#cbcbd8", + "#cacad8", + "#c9c9d7", + "#c8c9d6", + "#c8c8d6", + "#c7c7d5", + "#c6c6d5", + "#c5c5d4", + "#c4c4d3", + "#c3c3d3", + "#c2c2d2", + "#c2c2d2", + "#c1c1d1", + "#c0c0d1", + "#bfbfd0", + "#bebecf", + "#bdbdcf", + "#bcbcce", + "#bcbcce", + "#bbbbcd", + "#babacd", + "#b9b9cc", + "#b8b8cc", + "#b7b7cb", + "#b7b7ca", + "#b6b6ca", + "#b5b5c9", + "#b4b4c9", + "#b3b3c8", + "#b3b2c8", + "#b2b2c7", + "#b1b1c7", + "#b0b0c6", + "#afafc6", + "#aeaec5", + "#aeadc5", + "#adadc4", + "#acacc4", + "#ababc3", + "#aaaac3", + "#aaa9c2", + "#a9a8c2", + "#a8a8c1", + "#a7a7c1", + "#a6a6c0", + "#a6a5c0", + "#a5a4bf", + "#a4a4bf", + "#a3a3be", + "#a3a2be", + "#a2a1bd", + "#a1a0bd", + "#a0a0bc", + "#9f9fbc", + "#9f9ebb", + "#9e9dbb", + "#9d9cba", + "#9c9cba", + "#9c9bb9", + "#9b9ab9", + "#9a99b8", + "#9998b8", + "#9998b8", + "#9897b7", + "#9796b7", + "#9695b6", + "#9694b6", + "#9594b5", + "#9493b5", + "#9392b4", + "#9391b4", + "#9291b4", + "#9190b3", + "#908fb3", + "#908eb2", + "#8f8db2", + "#8e8db1", + "#8d8cb1", + "#8d8bb1", + "#8c8ab0", + "#8b8ab0", + "#8a89af", + "#8a88af", + "#8987ae", + "#8886ae", + "#8886ae", + "#8785ad", + "#8684ad", + "#8583ac", + "#8583ac", + "#8482ac", + "#8381ab", + "#8280ab", + "#8280ab", + "#817faa", + "#807eaa", + "#807da9", + "#7f7ca9", + "#7e7ca9", + "#7e7ba8", + "#7d7aa8", + "#7c79a8", + "#7b79a7", + "#7b78a7", + "#7a77a6", + "#7976a6", + "#7976a6", + "#7875a5", + "#7774a5", + "#7773a5", + "#7673a4", + "#7572a4", + "#7571a4", + "#7470a3", + "#736fa3", + "#736fa3", + "#726ea2", + "#716da2", + "#716ca2", + "#706ca1", + "#6f6ba1", + "#6f6aa1", + "#6e69a0", + "#6d69a0", + "#6d68a0", + "#6c679f", + "#6b669f", + "#6b669f", + "#6a659e", + "#69649e", + "#69639e", + "#68629d", + "#67629d", + "#67619d", + "#66609d", + "#655f9c", + "#655f9c", + "#645e9c", + "#635d9b", + "#635c9b", + "#625c9b", + "#625b9b", + "#615a9a", + "#60599a", + "#60589a", + "#5f589a", + "#5e5799", + "#5e5699", + "#5d5599", + "#5d5498", + "#5c5498", + "#5b5398", + "#5b5298", + "#5a5198", + "#595097", + "#595097", + "#584f97", + "#584e97", + "#574d96", + "#564c96", + "#564c96", + "#554b96", + "#554a95", + "#544995", + "#544895", + "#534795", + "#524795", + "#524694", + "#514594", + "#514494", + "#504394", + "#504294", + "#4f4294", + "#4e4193", + "#4e4093", + "#4d3f93", + "#4d3e93", + "#4c3d93", + "#4c3c93", + "#4b3b93", + "#4b3a92", + "#4a3992", + "#4a3892", + "#493892", + "#493792", + "#483692", + "#483592", + "#473492", + "#473391", + "#463291", + "#463191", + "#452f91", + "#452e91", + "#442d91", + "#442c91", + "#432b91", + "#432a91", + "#422991", + "#422891", + "#412691", + "#412591", + ], + "flex_grey_seq": [ + "#ffffff", + "#fefefe", + "#fdfdfd", + "#fcfcfc", + "#fbfbfc", + "#fafafb", + "#f9f9fa", + "#f8f9f9", + "#f8f8f8", + "#f7f7f7", + "#f6f6f6", + "#f5f5f6", + "#f4f4f5", + "#f3f3f4", + "#f2f2f3", + "#f1f1f2", + "#f0f0f1", + "#eff0f1", + "#eeeff0", + "#eeeeef", + "#ededee", + "#ececed", + "#ebebec", + "#eaeaec", + "#e9e9eb", + "#e8e9ea", + "#e7e8e9", + "#e6e7e8", + "#e6e6e8", + "#e5e5e7", + "#e4e4e6", + "#e3e3e5", + 
"#e2e3e4", + "#e1e2e4", + "#e0e1e3", + "#dfe0e2", + "#dfdfe1", + "#dedee0", + "#dddde0", + "#dcdddf", + "#dbdcde", + "#dadbdd", + "#d9dadd", + "#d9d9dc", + "#d8d8db", + "#d7d8da", + "#d6d7da", + "#d5d6d9", + "#d4d5d8", + "#d4d4d7", + "#d3d4d6", + "#d2d3d6", + "#d1d2d5", + "#d0d1d4", + "#cfd0d3", + "#cfcfd3", + "#cecfd2", + "#cdced1", + "#cccdd0", + "#cbccd0", + "#cacbcf", + "#cacbce", + "#c9cace", + "#c8c9cd", + "#c7c8cc", + "#c6c8cb", + "#c6c7cb", + "#c5c6ca", + "#c4c5c9", + "#c3c4c8", + "#c2c4c8", + "#c2c3c7", + "#c1c2c6", + "#c0c1c6", + "#bfc0c5", + "#bec0c4", + "#bebfc3", + "#bdbec3", + "#bcbdc2", + "#bbbdc1", + "#babcc1", + "#babbc0", + "#b9babf", + "#b8babe", + "#b7b9be", + "#b7b8bd", + "#b6b7bc", + "#b5b7bc", + "#b4b6bb", + "#b3b5ba", + "#b3b4ba", + "#b2b4b9", + "#b1b3b8", + "#b0b2b8", + "#b0b1b7", + "#afb1b6", + "#aeb0b6", + "#adafb5", + "#adaeb4", + "#acaeb3", + "#abadb3", + "#aaacb2", + "#aaabb1", + "#a9abb1", + "#a8aab0", + "#a7a9af", + "#a7a8af", + "#a6a8ae", + "#a5a7ad", + "#a4a6ad", + "#a4a6ac", + "#a3a5ab", + "#a2a4ab", + "#a1a3aa", + "#a1a3aa", + "#a0a2a9", + "#9fa1a8", + "#9ea1a8", + "#9ea0a7", + "#9d9fa6", + "#9c9ea6", + "#9c9ea5", + "#9b9da4", + "#9a9ca4", + "#999ca3", + "#999ba2", + "#989aa2", + "#979aa1", + "#9799a0", + "#9698a0", + "#95979f", + "#94979f", + "#94969e", + "#93959d", + "#92959d", + "#92949c", + "#91939b", + "#90939b", + "#8f929a", + "#8f919a", + "#8e9199", + "#8d9098", + "#8d8f98", + "#8c8e97", + "#8b8e96", + "#8b8d96", + "#8a8c95", + "#898c95", + "#888b94", + "#888a93", + "#878a93", + "#868992", + "#868891", + "#858891", + "#848790", + "#848690", + "#83868f", + "#82858e", + "#82848e", + "#81848d", + "#80838d", + "#80828c", + "#7f828b", + "#7e818b", + "#7e808a", + "#7d808a", + "#7c7f89", + "#7b7e88", + "#7b7e88", + "#7a7d87", + "#797c87", + "#797c86", + "#787b85", + "#777a85", + "#777a84", + "#767984", + "#757983", + "#757882", + "#747782", + "#737781", + "#737681", + "#727580", + "#71757f", + "#71747f", + "#70737e", + "#70737e", + "#6f727d", + "#6e717d", + "#6e717c", + "#6d707b", + "#6c707b", + "#6c6f7a", + "#6b6e7a", + "#6a6e79", + "#6a6d79", + "#696c78", + "#686c77", + "#686b77", + "#676a76", + "#666a76", + "#666975", + "#656974", + "#646874", + "#646773", + "#636773", + "#636672", + "#626572", + "#616571", + "#616470", + "#606370", + "#5f636f", + "#5f626f", + "#5e626e", + "#5d616e", + "#5d606d", + "#5c606c", + "#5b5f6c", + "#5b5e6b", + "#5a5e6b", + "#5a5d6a", + "#595d6a", + "#585c69", + "#585b69", + "#575b68", + "#565a67", + "#565967", + "#555966", + "#555866", + "#545865", + "#535765", + "#535664", + "#525663", + "#515563", + "#515562", + "#505462", + "#4f5361", + "#4f5361", + "#4e5260", + "#4e5160", + "#4d515f", + "#4c505e", + "#4c505e", + "#4b4f5d", + "#4a4e5d", + "#4a4e5c", + "#494d5c", + "#494d5b", + "#484c5a", + "#474b5a", + "#474b59", + "#464a59", + "#454958", + "#454958", + "#444857", + "#444857", + "#434756", + ], +} +CATEGORICAL_PALETTES_HEX = { + "flex_distinct": [ + "#176737", + "#FF7B0D", + "#979BAA", + "#F44E6A", + "#0062FF", + "#26AB5B", + "#6D3EF2", + "#F59E0B", + ] +} +DIVERGING_PALETTES_HEX = { + "flex_BuRd": [ + "#002766", + "#022967", + "#052b67", + "#072d68", + "#0a2e69", + "#0d3069", + "#10326a", + "#12346b", + "#15356c", + "#17376c", + "#1a396d", + "#1c3a6e", + "#1e3c6f", + "#203e70", + "#223f71", + "#244171", + "#264372", + "#284473", + "#2a4674", + "#2c4775", + "#2e4976", + "#304a77", + "#324c78", + "#344e79", + "#364f7a", + "#38517b", + "#3a527c", + "#3c547d", + "#3e557e", + "#3f577f", + "#415980", + "#435a80", + "#455c81", 
+ "#475d83", + "#495f84", + "#4b6085", + "#4c6286", + "#4e6387", + "#506588", + "#526789", + "#54688a", + "#566a8b", + "#586b8c", + "#5a6d8d", + "#5b6e8e", + "#5d708f", + "#5f7290", + "#617391", + "#637592", + "#657694", + "#677895", + "#687a96", + "#6a7b97", + "#6c7d98", + "#6e7e99", + "#70809a", + "#72829b", + "#74839d", + "#76859e", + "#78879f", + "#7a88a0", + "#7b8aa1", + "#7d8ca3", + "#7f8da4", + "#818fa5", + "#8391a6", + "#8592a8", + "#8794a9", + "#8996aa", + "#8b97ab", + "#8d99ad", + "#8f9bae", + "#919daf", + "#939eb1", + "#95a0b2", + "#97a2b3", + "#99a4b4", + "#9ba5b6", + "#9da7b7", + "#9fa9b9", + "#a1abba", + "#a3acbb", + "#a5aebd", + "#a7b0be", + "#a9b2c0", + "#abb4c1", + "#adb6c2", + "#afb7c4", + "#b1b9c5", + "#b4bbc7", + "#b6bdc8", + "#b8bfca", + "#bac1cb", + "#bcc3cd", + "#bec5ce", + "#c0c6d0", + "#c3c8d1", + "#c5cad3", + "#c7ccd5", + "#c9ced6", + "#cbd0d8", + "#ced2d9", + "#d0d4db", + "#d2d6dd", + "#d4d8de", + "#d7dae0", + "#d9dce2", + "#dbdee3", + "#dee0e5", + "#e0e3e7", + "#e2e5e9", + "#e4e7ea", + "#e7e9ec", + "#e9ebee", + "#ecedf0", + "#eeeff2", + "#f0f2f3", + "#f3f4f5", + "#f5f6f7", + "#f8f8f9", + "#fafafb", + "#fdfdfd", + "#FFFFFF", + "#fefdfd", + "#fcfbfb", + "#fbf9f9", + "#f9f7f7", + "#f8f5f5", + "#f6f3f3", + "#f5f1f1", + "#f4efef", + "#f2edee", + "#f1ebec", + "#efe9ea", + "#eee7e8", + "#ede5e6", + "#ece3e4", + "#ebe1e3", + "#e9dfe1", + "#e8dddf", + "#e7dbdd", + "#e6d9db", + "#e5d7d9", + "#e4d5d8", + "#e3d3d6", + "#e2d1d4", + "#e1cfd2", + "#e0cdd0", + "#dfcccf", + "#decacd", + "#ddc8cb", + "#dcc6c9", + "#dbc4c7", + "#dac2c6", + "#d9c0c4", + "#d8bec2", + "#d7bcc0", + "#d6babf", + "#d6b8bd", + "#d5b6bb", + "#d4b5b9", + "#d3b3b8", + "#d2b1b6", + "#d1afb4", + "#d0adb2", + "#cfabb1", + "#cfa9af", + "#cea8ad", + "#cda6ac", + "#cca4aa", + "#cba2a8", + "#caa0a7", + "#c99ea5", + "#c99ca3", + "#c89ba2", + "#c799a0", + "#c6979e", + "#c5959d", + "#c4939b", + "#c49199", + "#c39098", + "#c28e96", + "#c18c94", + "#c08a93", + "#c08891", + "#bf8790", + "#be858e", + "#bd838c", + "#bc818b", + "#bb7f89", + "#bb7e88", + "#ba7c86", + "#b97a84", + "#b87883", + "#b77681", + "#b67580", + "#b6737e", + "#b5717d", + "#b46f7b", + "#b36d79", + "#b26c78", + "#b26a76", + "#b16875", + "#b06673", + "#af6572", + "#ae6370", + "#ad616f", + "#ac5f6d", + "#ac5d6c", + "#ab5c6a", + "#aa5a69", + "#a95867", + "#a85666", + "#a75464", + "#a65263", + "#a55161", + "#a44f60", + "#a44d5e", + "#a34b5d", + "#a2495b", + "#a1475a", + "#a04658", + "#9f4457", + "#9e4255", + "#9d4054", + "#9c3e52", + "#9b3c51", + "#9a3a4f", + "#99384e", + "#98364c", + "#97344b", + "#96324a", + "#953048", + "#942e47", + "#932c45", + "#922a44", + "#912842", + "#902541", + "#8f233f", + "#8e213e", + "#8d1e3d", + "#8b1c3b", + "#8a193a", + "#891638", + "#881337", + ], + "flex_RdBu": [ + "#881337", + "#891638", + "#8a193a", + "#8b1c3b", + "#8d1e3d", + "#8e213e", + "#8f233f", + "#902541", + "#912842", + "#922a44", + "#932c45", + "#942e47", + "#953048", + "#96324a", + "#97344b", + "#98364c", + "#99384e", + "#9a3a4f", + "#9b3c51", + "#9c3e52", + "#9d4054", + "#9e4255", + "#9f4457", + "#a04658", + "#a1475a", + "#a2495b", + "#a34b5d", + "#a44d5e", + "#a44f60", + "#a55161", + "#a65263", + "#a75464", + "#a85666", + "#a95867", + "#aa5a69", + "#ab5c6a", + "#ac5d6c", + "#ac5f6d", + "#ad616f", + "#ae6370", + "#af6572", + "#b06673", + "#b16875", + "#b26a76", + "#b26c78", + "#b36d79", + "#b46f7b", + "#b5717d", + "#b6737e", + "#b67580", + "#b77681", + "#b87883", + "#b97a84", + "#ba7c86", + "#bb7e88", + "#bb7f89", + "#bc818b", + "#bd838c", + "#be858e", + 
"#bf8790", + "#c08891", + "#c08a93", + "#c18c94", + "#c28e96", + "#c39098", + "#c49199", + "#c4939b", + "#c5959d", + "#c6979e", + "#c799a0", + "#c89ba2", + "#c99ca3", + "#c99ea5", + "#caa0a7", + "#cba2a8", + "#cca4aa", + "#cda6ac", + "#cea8ad", + "#cfa9af", + "#cfabb1", + "#d0adb2", + "#d1afb4", + "#d2b1b6", + "#d3b3b8", + "#d4b5b9", + "#d5b6bb", + "#d6b8bd", + "#d6babf", + "#d7bcc0", + "#d8bec2", + "#d9c0c4", + "#dac2c6", + "#dbc4c7", + "#dcc6c9", + "#ddc8cb", + "#decacd", + "#dfcccf", + "#e0cdd0", + "#e1cfd2", + "#e2d1d4", + "#e3d3d6", + "#e4d5d8", + "#e5d7d9", + "#e6d9db", + "#e7dbdd", + "#e8dddf", + "#e9dfe1", + "#ebe1e3", + "#ece3e4", + "#ede5e6", + "#eee7e8", + "#efe9ea", + "#f1ebec", + "#f2edee", + "#f4efef", + "#f5f1f1", + "#f6f3f3", + "#f8f5f5", + "#f9f7f7", + "#fbf9f9", + "#fcfbfb", + "#fefdfd", + "#FFFFFF", + "#fdfdfd", + "#fafafb", + "#f8f8f9", + "#f5f6f7", + "#f3f4f5", + "#f0f2f3", + "#eeeff2", + "#ecedf0", + "#e9ebee", + "#e7e9ec", + "#e4e7ea", + "#e2e5e9", + "#e0e3e7", + "#dee0e5", + "#dbdee3", + "#d9dce2", + "#d7dae0", + "#d4d8de", + "#d2d6dd", + "#d0d4db", + "#ced2d9", + "#cbd0d8", + "#c9ced6", + "#c7ccd5", + "#c5cad3", + "#c3c8d1", + "#c0c6d0", + "#bec5ce", + "#bcc3cd", + "#bac1cb", + "#b8bfca", + "#b6bdc8", + "#b4bbc7", + "#b1b9c5", + "#afb7c4", + "#adb6c2", + "#abb4c1", + "#a9b2c0", + "#a7b0be", + "#a5aebd", + "#a3acbb", + "#a1abba", + "#9fa9b9", + "#9da7b7", + "#9ba5b6", + "#99a4b4", + "#97a2b3", + "#95a0b2", + "#939eb1", + "#919daf", + "#8f9bae", + "#8d99ad", + "#8b97ab", + "#8996aa", + "#8794a9", + "#8592a8", + "#8391a6", + "#818fa5", + "#7f8da4", + "#7d8ca3", + "#7b8aa1", + "#7a88a0", + "#78879f", + "#76859e", + "#74839d", + "#72829b", + "#70809a", + "#6e7e99", + "#6c7d98", + "#6a7b97", + "#687a96", + "#677895", + "#657694", + "#637592", + "#617391", + "#5f7290", + "#5d708f", + "#5b6e8e", + "#5a6d8d", + "#586b8c", + "#566a8b", + "#54688a", + "#526789", + "#506588", + "#4e6387", + "#4c6286", + "#4b6085", + "#495f84", + "#475d83", + "#455c81", + "#435a80", + "#415980", + "#3f577f", + "#3e557e", + "#3c547d", + "#3a527c", + "#38517b", + "#364f7a", + "#344e79", + "#324c78", + "#304a77", + "#2e4976", + "#2c4775", + "#2a4674", + "#284473", + "#264372", + "#244171", + "#223f71", + "#203e70", + "#1e3c6f", + "#1c3a6e", + "#1a396d", + "#17376c", + "#15356c", + "#12346b", + "#10326a", + "#0d3069", + "#0a2e69", + "#072d68", + "#052b67", + "#022967", + "#002766", + ], + "flex_GrPu": [ + "#0f4424", + "#124526", + "#144727", + "#174829", + "#19492a", + "#1b4b2c", + "#1d4c2d", + "#1f4d2f", + "#214f30", + "#235032", + "#255234", + "#275335", + "#295437", + "#2b5638", + "#2d573a", + "#2f583b", + "#315a3d", + "#335b3e", + "#355c40", + "#365e42", + "#385f43", + "#3a6045", + "#3c6246", + "#3e6348", + "#3f644a", + "#41664b", + "#43674d", + "#45684e", + "#476a50", + "#486b52", + "#4a6c53", + "#4c6e55", + "#4e6f56", + "#4f7058", + "#51725a", + "#53735b", + "#55745d", + "#57765f", + "#587760", + "#5a7962", + "#5c7a63", + "#5e7b65", + "#5f7d67", + "#617e68", + "#637f6a", + "#65816c", + "#67826d", + "#68846f", + "#6a8571", + "#6c8672", + "#6e8874", + "#6f8976", + "#718b78", + "#738c79", + "#758e7b", + "#778f7d", + "#79907e", + "#7a9280", + "#7c9382", + "#7e9584", + "#809685", + "#829887", + "#849989", + "#859b8b", + "#879c8c", + "#899d8e", + "#8b9f90", + "#8da092", + "#8fa294", + "#91a395", + "#92a597", + "#94a699", + "#96a89b", + "#98aa9d", + "#9aab9e", + "#9cada0", + "#9eaea2", + "#a0b0a4", + "#a2b1a6", + "#a4b3a8", + "#a6b4aa", + "#a8b6ab", + "#aab8ad", + "#acb9af", + "#aebbb1", + 
"#b0bcb3", + "#b2beb5", + "#b4c0b7", + "#b6c1b9", + "#b8c3bb", + "#bac5bd", + "#bcc6bf", + "#bec8c1", + "#c0cac2", + "#c2cbc4", + "#c4cdc6", + "#c6cfc8", + "#c8d0ca", + "#cad2cc", + "#ccd4ce", + "#ced6d0", + "#d0d7d2", + "#d3d9d5", + "#d5dbd7", + "#d7ddd9", + "#d9dfdb", + "#dbe0dd", + "#dde2df", + "#dfe4e1", + "#e2e6e3", + "#e4e8e5", + "#e6eae7", + "#e8ebe9", + "#ebedeb", + "#edefee", + "#eff1f0", + "#f1f3f2", + "#f4f5f4", + "#f6f7f6", + "#f8f9f8", + "#fafbfb", + "#fdfdfd", + "#FFFFFF", + "#fdfdfd", + "#fbfbfc", + "#f9f9fa", + "#f7f7f8", + "#f5f5f7", + "#f3f3f5", + "#f1f0f4", + "#efeef2", + "#ececf0", + "#eaeaef", + "#e8e8ed", + "#e6e6ec", + "#e4e5ea", + "#e3e3e9", + "#e1e1e8", + "#dfdfe6", + "#dddde5", + "#dbdbe3", + "#d9d9e2", + "#d7d7e1", + "#d5d5df", + "#d3d3de", + "#d1d1dd", + "#cfd0db", + "#ceceda", + "#ccccd9", + "#cacad7", + "#c8c8d6", + "#c6c6d5", + "#c4c4d4", + "#c3c3d2", + "#c1c1d1", + "#bfbfd0", + "#bdbdcf", + "#bcbcce", + "#babacd", + "#b8b8cb", + "#b6b6ca", + "#b5b4c9", + "#b3b3c8", + "#b1b1c7", + "#afafc6", + "#aeaec5", + "#acacc4", + "#aaaac3", + "#a9a8c1", + "#a7a7c0", + "#a5a5bf", + "#a4a3be", + "#a2a2bd", + "#a1a0bc", + "#9f9ebb", + "#9d9dba", + "#9c9bb9", + "#9a99b8", + "#9898b8", + "#9796b7", + "#9594b6", + "#9493b5", + "#9291b4", + "#918fb3", + "#8f8eb2", + "#8e8cb1", + "#8c8ab0", + "#8b89af", + "#8987af", + "#8786ae", + "#8684ad", + "#8482ac", + "#8381ab", + "#817faa", + "#807eaa", + "#7f7ca9", + "#7d7aa8", + "#7c79a7", + "#7a77a6", + "#7976a6", + "#7774a5", + "#7672a4", + "#7471a4", + "#736fa3", + "#726ea2", + "#706ca1", + "#6f6aa1", + "#6d69a0", + "#6c679f", + "#6b669f", + "#69649e", + "#68629d", + "#67619d", + "#655f9c", + "#645e9c", + "#635c9b", + "#615a9a", + "#60599a", + "#5f5799", + "#5d5599", + "#5c5498", + "#5b5298", + "#595097", + "#584f97", + "#574d96", + "#564b96", + "#554a95", + "#534895", + "#524695", + "#514494", + "#504394", + "#4f4193", + "#4d3f93", + "#4c3d93", + "#4b3b92", + "#4a3992", + "#493792", + "#483592", + "#473392", + "#463191", + "#452f91", + "#442d91", + "#432a91", + "#422891", + "#412591", + ], + "flex_PuGr": [ + "#412591", + "#422891", + "#432a91", + "#442d91", + "#452f91", + "#463191", + "#473392", + "#483592", + "#493792", + "#4a3992", + "#4b3b92", + "#4c3d93", + "#4d3f93", + "#4f4193", + "#504394", + "#514494", + "#524695", + "#534895", + "#554a95", + "#564b96", + "#574d96", + "#584f97", + "#595097", + "#5b5298", + "#5c5498", + "#5d5599", + "#5f5799", + "#60599a", + "#615a9a", + "#635c9b", + "#645e9c", + "#655f9c", + "#67619d", + "#68629d", + "#69649e", + "#6b669f", + "#6c679f", + "#6d69a0", + "#6f6aa1", + "#706ca1", + "#726ea2", + "#736fa3", + "#7471a4", + "#7672a4", + "#7774a5", + "#7976a6", + "#7a77a6", + "#7c79a7", + "#7d7aa8", + "#7f7ca9", + "#807eaa", + "#817faa", + "#8381ab", + "#8482ac", + "#8684ad", + "#8786ae", + "#8987af", + "#8b89af", + "#8c8ab0", + "#8e8cb1", + "#8f8eb2", + "#918fb3", + "#9291b4", + "#9493b5", + "#9594b6", + "#9796b7", + "#9898b8", + "#9a99b8", + "#9c9bb9", + "#9d9dba", + "#9f9ebb", + "#a1a0bc", + "#a2a2bd", + "#a4a3be", + "#a5a5bf", + "#a7a7c0", + "#a9a8c1", + "#aaaac3", + "#acacc4", + "#aeaec5", + "#afafc6", + "#b1b1c7", + "#b3b3c8", + "#b5b4c9", + "#b6b6ca", + "#b8b8cb", + "#babacd", + "#bcbcce", + "#bdbdcf", + "#bfbfd0", + "#c1c1d1", + "#c3c3d2", + "#c4c4d4", + "#c6c6d5", + "#c8c8d6", + "#cacad7", + "#ccccd9", + "#ceceda", + "#cfd0db", + "#d1d1dd", + "#d3d3de", + "#d5d5df", + "#d7d7e1", + "#d9d9e2", + "#dbdbe3", + "#dddde5", + "#dfdfe6", + "#e1e1e8", + "#e3e3e9", + "#e4e5ea", + "#e6e6ec", + 
"#e8e8ed", + "#eaeaef", + "#ececf0", + "#efeef2", + "#f1f0f4", + "#f3f3f5", + "#f5f5f7", + "#f7f7f8", + "#f9f9fa", + "#fbfbfc", + "#fdfdfd", + "#FFFFFF", + "#fdfdfd", + "#fafbfb", + "#f8f9f8", + "#f6f7f6", + "#f4f5f4", + "#f1f3f2", + "#eff1f0", + "#edefee", + "#ebedeb", + "#e8ebe9", + "#e6eae7", + "#e4e8e5", + "#e2e6e3", + "#dfe4e1", + "#dde2df", + "#dbe0dd", + "#d9dfdb", + "#d7ddd9", + "#d5dbd7", + "#d3d9d5", + "#d0d7d2", + "#ced6d0", + "#ccd4ce", + "#cad2cc", + "#c8d0ca", + "#c6cfc8", + "#c4cdc6", + "#c2cbc4", + "#c0cac2", + "#bec8c1", + "#bcc6bf", + "#bac5bd", + "#b8c3bb", + "#b6c1b9", + "#b4c0b7", + "#b2beb5", + "#b0bcb3", + "#aebbb1", + "#acb9af", + "#aab8ad", + "#a8b6ab", + "#a6b4aa", + "#a4b3a8", + "#a2b1a6", + "#a0b0a4", + "#9eaea2", + "#9cada0", + "#9aab9e", + "#98aa9d", + "#96a89b", + "#94a699", + "#92a597", + "#91a395", + "#8fa294", + "#8da092", + "#8b9f90", + "#899d8e", + "#879c8c", + "#859b8b", + "#849989", + "#829887", + "#809685", + "#7e9584", + "#7c9382", + "#7a9280", + "#79907e", + "#778f7d", + "#758e7b", + "#738c79", + "#718b78", + "#6f8976", + "#6e8874", + "#6c8672", + "#6a8571", + "#68846f", + "#67826d", + "#65816c", + "#637f6a", + "#617e68", + "#5f7d67", + "#5e7b65", + "#5c7a63", + "#5a7962", + "#587760", + "#57765f", + "#55745d", + "#53735b", + "#51725a", + "#4f7058", + "#4e6f56", + "#4c6e55", + "#4a6c53", + "#486b52", + "#476a50", + "#45684e", + "#43674d", + "#41664b", + "#3f644a", + "#3e6348", + "#3c6246", + "#3a6045", + "#385f43", + "#365e42", + "#355c40", + "#335b3e", + "#315a3d", + "#2f583b", + "#2d573a", + "#2b5638", + "#295437", + "#275335", + "#255234", + "#235032", + "#214f30", + "#1f4d2f", + "#1d4c2d", + "#1b4b2c", + "#19492a", + "#174829", + "#144727", + "#124526", + "#0f4424", + ], + "flex_TuOr": [ + "#134e4a", + "#164f4b", + "#19504d", + "#1b524e", + "#1e534f", + "#205450", + "#225552", + "#255753", + "#275854", + "#295955", + "#2b5a57", + "#2d5c58", + "#2f5d59", + "#315e5a", + "#335f5c", + "#35615d", + "#37625e", + "#39635f", + "#3b6461", + "#3c6662", + "#3e6763", + "#406865", + "#426966", + "#446b67", + "#456c68", + "#476d6a", + "#496e6b", + "#4b706c", + "#4d716e", + "#4e726f", + "#507370", + "#527572", + "#547673", + "#557774", + "#577975", + "#597a77", + "#5b7b78", + "#5d7c79", + "#5e7e7b", + "#607f7c", + "#62807d", + "#64827f", + "#658380", + "#678482", + "#698683", + "#6b8784", + "#6c8886", + "#6e8a87", + "#708b88", + "#728c8a", + "#738e8b", + "#758f8c", + "#77908e", + "#79928f", + "#7a9391", + "#7c9492", + "#7e9693", + "#809795", + "#819896", + "#839a98", + "#859b99", + "#879d9b", + "#899e9c", + "#8a9f9d", + "#8ca19f", + "#8ea2a0", + "#90a4a2", + "#92a5a3", + "#93a7a5", + "#95a8a6", + "#97a9a8", + "#99aba9", + "#9bacab", + "#9daeac", + "#9eafae", + "#a0b1af", + "#a2b2b1", + "#a4b4b2", + "#a6b5b4", + "#a8b7b5", + "#aab8b7", + "#acbab8", + "#adbbba", + "#afbdbb", + "#b1bebd", + "#b3c0bf", + "#b5c1c0", + "#b7c3c2", + "#b9c5c3", + "#bbc6c5", + "#bdc8c7", + "#bfc9c8", + "#c1cbca", + "#c3cccc", + "#c5cecd", + "#c7d0cf", + "#c9d1d1", + "#cbd3d2", + "#cdd5d4", + "#cfd6d6", + "#d1d8d7", + "#d3dad9", + "#d5dbdb", + "#d7dddc", + "#d9dfde", + "#dbe0e0", + "#dde2e2", + "#dfe4e3", + "#e1e6e5", + "#e3e7e7", + "#e5e9e9", + "#e7ebeb", + "#e9edec", + "#ebeeee", + "#eef0f0", + "#f0f2f2", + "#f2f4f4", + "#f4f6f6", + "#f6f8f7", + "#f8f9f9", + "#fbfbfb", + "#fdfdfd", + "#FFFFFF", + "#fefdfd", + "#fcfcfb", + "#fbfaf9", + "#faf8f7", + "#f9f7f6", + "#f7f5f4", + "#f6f4f2", + "#f5f2f0", + "#f4f0ee", + "#f2efec", + "#f1edea", + "#f0ece8", + "#efeae6", + "#eee8e4", + 
"#ede7e2", + "#ece5df", + "#ebe4dd", + "#eae2db", + "#e9e0d9", + "#e8dfd7", + "#e7ddd5", + "#e6dcd3", + "#e5dad1", + "#e5d8cf", + "#e4d7cd", + "#e3d5cb", + "#e2d4c9", + "#e1d2c7", + "#e0d0c5", + "#dfcfc3", + "#dfcdc1", + "#deccbe", + "#ddcabc", + "#dcc9ba", + "#dbc7b8", + "#dac6b6", + "#dac4b4", + "#d9c2b2", + "#d8c1b0", + "#d7bfae", + "#d6beac", + "#d6bcaa", + "#d5bba8", + "#d4b9a6", + "#d3b8a4", + "#d3b6a2", + "#d2b5a0", + "#d1b39e", + "#d0b29c", + "#d0b09a", + "#cfaf98", + "#cead96", + "#cdac94", + "#cdaa92", + "#cca990", + "#cba78e", + "#caa68c", + "#caa48a", + "#c9a388", + "#c8a286", + "#c7a084", + "#c79f82", + "#c69d80", + "#c59c7e", + "#c59a7c", + "#c4997a", + "#c39778", + "#c29676", + "#c29474", + "#c19372", + "#c09270", + "#c0906e", + "#bf8f6c", + "#be8d6b", + "#bd8c69", + "#bd8a67", + "#bc8965", + "#bb8863", + "#bb8661", + "#ba855f", + "#b9835d", + "#b8825b", + "#b88059", + "#b77f57", + "#b67e55", + "#b57c53", + "#b57b51", + "#b47950", + "#b3784e", + "#b3774c", + "#b2754a", + "#b17448", + "#b07246", + "#b07144", + "#af7042", + "#ae6e40", + "#ad6d3e", + "#ad6b3c", + "#ac6a3a", + "#ab6938", + "#aa6737", + "#a96635", + "#a96433", + "#a86331", + "#a7622f", + "#a6602d", + "#a65f2b", + "#a55d29", + "#a45c27", + "#a35b25", + "#a25923", + "#a15821", + "#a1571f", + "#a0551c", + "#9f541a", + "#9e5218", + "#9d5116", + "#9c5013", + "#9c4e11", + "#9b4d0e", + "#9a4b0b", + "#994a08", + ], + "flex_OrTu": [ + "#994a08", + "#9a4b0b", + "#9b4d0e", + "#9c4e11", + "#9c5013", + "#9d5116", + "#9e5218", + "#9f541a", + "#a0551c", + "#a1571f", + "#a15821", + "#a25923", + "#a35b25", + "#a45c27", + "#a55d29", + "#a65f2b", + "#a6602d", + "#a7622f", + "#a86331", + "#a96433", + "#a96635", + "#aa6737", + "#ab6938", + "#ac6a3a", + "#ad6b3c", + "#ad6d3e", + "#ae6e40", + "#af7042", + "#b07144", + "#b07246", + "#b17448", + "#b2754a", + "#b3774c", + "#b3784e", + "#b47950", + "#b57b51", + "#b57c53", + "#b67e55", + "#b77f57", + "#b88059", + "#b8825b", + "#b9835d", + "#ba855f", + "#bb8661", + "#bb8863", + "#bc8965", + "#bd8a67", + "#bd8c69", + "#be8d6b", + "#bf8f6c", + "#c0906e", + "#c09270", + "#c19372", + "#c29474", + "#c29676", + "#c39778", + "#c4997a", + "#c59a7c", + "#c59c7e", + "#c69d80", + "#c79f82", + "#c7a084", + "#c8a286", + "#c9a388", + "#caa48a", + "#caa68c", + "#cba78e", + "#cca990", + "#cdaa92", + "#cdac94", + "#cead96", + "#cfaf98", + "#d0b09a", + "#d0b29c", + "#d1b39e", + "#d2b5a0", + "#d3b6a2", + "#d3b8a4", + "#d4b9a6", + "#d5bba8", + "#d6bcaa", + "#d6beac", + "#d7bfae", + "#d8c1b0", + "#d9c2b2", + "#dac4b4", + "#dac6b6", + "#dbc7b8", + "#dcc9ba", + "#ddcabc", + "#deccbe", + "#dfcdc1", + "#dfcfc3", + "#e0d0c5", + "#e1d2c7", + "#e2d4c9", + "#e3d5cb", + "#e4d7cd", + "#e5d8cf", + "#e5dad1", + "#e6dcd3", + "#e7ddd5", + "#e8dfd7", + "#e9e0d9", + "#eae2db", + "#ebe4dd", + "#ece5df", + "#ede7e2", + "#eee8e4", + "#efeae6", + "#f0ece8", + "#f1edea", + "#f2efec", + "#f4f0ee", + "#f5f2f0", + "#f6f4f2", + "#f7f5f4", + "#f9f7f6", + "#faf8f7", + "#fbfaf9", + "#fcfcfb", + "#fefdfd", + "#FFFFFF", + "#fdfdfd", + "#fbfbfb", + "#f8f9f9", + "#f6f8f7", + "#f4f6f6", + "#f2f4f4", + "#f0f2f2", + "#eef0f0", + "#ebeeee", + "#e9edec", + "#e7ebeb", + "#e5e9e9", + "#e3e7e7", + "#e1e6e5", + "#dfe4e3", + "#dde2e2", + "#dbe0e0", + "#d9dfde", + "#d7dddc", + "#d5dbdb", + "#d3dad9", + "#d1d8d7", + "#cfd6d6", + "#cdd5d4", + "#cbd3d2", + "#c9d1d1", + "#c7d0cf", + "#c5cecd", + "#c3cccc", + "#c1cbca", + "#bfc9c8", + "#bdc8c7", + "#bbc6c5", + "#b9c5c3", + "#b7c3c2", + "#b5c1c0", + "#b3c0bf", + "#b1bebd", + "#afbdbb", + "#adbbba", + 
"#acbab8", + "#aab8b7", + "#a8b7b5", + "#a6b5b4", + "#a4b4b2", + "#a2b2b1", + "#a0b1af", + "#9eafae", + "#9daeac", + "#9bacab", + "#99aba9", + "#97a9a8", + "#95a8a6", + "#93a7a5", + "#92a5a3", + "#90a4a2", + "#8ea2a0", + "#8ca19f", + "#8a9f9d", + "#899e9c", + "#879d9b", + "#859b99", + "#839a98", + "#819896", + "#809795", + "#7e9693", + "#7c9492", + "#7a9391", + "#79928f", + "#77908e", + "#758f8c", + "#738e8b", + "#728c8a", + "#708b88", + "#6e8a87", + "#6c8886", + "#6b8784", + "#698683", + "#678482", + "#658380", + "#64827f", + "#62807d", + "#607f7c", + "#5e7e7b", + "#5d7c79", + "#5b7b78", + "#597a77", + "#577975", + "#557774", + "#547673", + "#527572", + "#507370", + "#4e726f", + "#4d716e", + "#4b706c", + "#496e6b", + "#476d6a", + "#456c68", + "#446b67", + "#426966", + "#406865", + "#3e6763", + "#3c6662", + "#3b6461", + "#39635f", + "#37625e", + "#35615d", + "#335f5c", + "#315e5a", + "#2f5d59", + "#2d5c58", + "#2b5a57", + "#295955", + "#275854", + "#255753", + "#225552", + "#205450", + "#1e534f", + "#1b524e", + "#19504d", + "#164f4b", + "#134e4a", + ], +} diff --git a/tidy3d/_common/components/viz/flex_style.py b/tidy3d/_common/components/viz/flex_style.py new file mode 100644 index 0000000000..babcd3e9f1 --- /dev/null +++ b/tidy3d/_common/components/viz/flex_style.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from tidy3d._common.log import log + +_ORIGINAL_PARAMS = None + + +def apply_tidy3d_params() -> None: + """ + Applies a set of defaults to the matplotlib params that are following the tidy3d color palettes and design. + """ + global _ORIGINAL_PARAMS + try: + import matplotlib as mpl + import matplotlib.pyplot as plt + + _ORIGINAL_PARAMS = mpl.rcParams.copy() + + try: + plt.style.use("tidy3d.style") + except Exception as e: + log.error(f"Failed to apply Tidy3D plotting style on import. Error: {e}") + _ORIGINAL_PARAMS = {} + except ImportError: + pass + + +def restore_matplotlib_rcparams() -> None: + """ + Resets matplotlib rcParams to the values they had before the Tidy3D + style was automatically applied on import. + """ + global _ORIGINAL_PARAMS + try: + import matplotlib.pyplot as plt + from matplotlib import style + + if not _ORIGINAL_PARAMS: + style.use("default") + return + + plt.rcParams.update(_ORIGINAL_PARAMS) + except ImportError: + log.error("Matplotlib is not installed on your system. Failed to reset to default styles.") + except Exception as e: + log.error(f"Failed to reset previous Matplotlib style. Error: {e}") diff --git a/tidy3d/_common/components/viz/plot_params.py b/tidy3d/_common/components/viz/plot_params.py new file mode 100644 index 0000000000..c8955b602a --- /dev/null +++ b/tidy3d/_common/components/viz/plot_params.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional + +from numpy import inf +from pydantic import Field, NonNegativeFloat + +from tidy3d._common.components.base import Tidy3dBaseModel + +if TYPE_CHECKING: + from tidy3d._common.components.viz.visualization_spec import VisualizationSpec + + +class AbstractPlotParams(Tidy3dBaseModel): + """Abstract class for storing plotting parameters. + Corresponds with select properties of ``matplotlib.artist.Artist``. 
+ """ + + alpha: Any = Field(1.0, title="Opacity") + zorder: Optional[float] = Field(None, title="Display Order") + + def include_kwargs(self, **kwargs: Any) -> AbstractPlotParams: + """Update the plot params with supplied kwargs.""" + update_dict = { + key: value + for key, value in kwargs.items() + if key not in ("type",) and value is not None and key in type(self).model_fields + } + return self.copy(update=update_dict) + + def override_with_viz_spec(self, viz_spec: VisualizationSpec) -> AbstractPlotParams: + """Override plot params with supplied VisualizationSpec.""" + return self.include_kwargs(**dict(viz_spec)) + + def to_kwargs(self) -> dict[str, Any]: + """Export the plot parameters as kwargs dict that can be supplied to plot function.""" + kwarg_dict = self.model_dump() + for ignore_key in ("type", "attrs"): + kwarg_dict.pop(ignore_key) + return kwarg_dict + + +class PathPlotParams(AbstractPlotParams): + """Stores plotting parameters / specifications for a path. + Corresponds with select properties of ``matplotlib.lines.Line2D``. + """ + + color: Optional[Any] = Field(None, title="Color", alias="c") + linewidth: NonNegativeFloat = Field(2, title="Line Width", alias="lw") + linestyle: str = Field("--", title="Line Style", alias="ls") + marker: Any = Field("o", title="Marker Style") + markeredgecolor: Optional[Any] = Field(None, title="Marker Edge Color", alias="mec") + markerfacecolor: Optional[Any] = Field(None, title="Marker Face Color", alias="mfc") + markersize: NonNegativeFloat = Field(10, title="Marker Size", alias="ms") + + +class PlotParams(AbstractPlotParams): + """Stores plotting parameters / specifications for a given model. + Corresponds with select properties of ``matplotlib.patches.Patch``. + """ + + edgecolor: Optional[Any] = Field(None, title="Edge Color", alias="ec") + facecolor: Optional[Any] = Field(None, title="Face Color", alias="fc") + fill: bool = Field(True, title="Is Filled") + hatch: Optional[str] = Field(None, title="Hatch Style") + linewidth: NonNegativeFloat = Field(1, title="Line Width", alias="lw") + + +# defaults for different tidy3d objects +plot_params_geometry = PlotParams() +plot_params_structure = PlotParams() +plot_params_source = PlotParams(alpha=0.4, facecolor="limegreen", edgecolor="limegreen", lw=3) +plot_params_absorber = PlotParams( + alpha=0.4, facecolor="lightskyblue", edgecolor="lightskyblue", lw=3 +) +plot_params_monitor = PlotParams(alpha=0.4, facecolor="orange", edgecolor="orange", lw=3) +plot_params_pml = PlotParams(alpha=0.7, facecolor="gray", edgecolor="gray", hatch="x", zorder=inf) +plot_params_pec = PlotParams(alpha=1.0, facecolor="gold", edgecolor="black", zorder=inf) +plot_params_pmc = PlotParams(alpha=1.0, facecolor="lightsteelblue", edgecolor="black", zorder=inf) +plot_params_bloch = PlotParams(alpha=1.0, facecolor="orchid", edgecolor="black", zorder=inf) +plot_params_abc = PlotParams(alpha=1.0, facecolor="lightskyblue", edgecolor="black", zorder=inf) +plot_params_symmetry = PlotParams(edgecolor="gray", facecolor="gray", alpha=0.6, zorder=inf) +plot_params_override_structures = PlotParams( + linewidth=0.4, edgecolor="black", fill=False, zorder=inf +) +plot_params_fluid = PlotParams(facecolor="white", edgecolor="lightsteelblue", lw=0.4, hatch="xx") +plot_params_grid = PlotParams(edgecolor="black", lw=0.2) +plot_params_lumped_element = PlotParams( + alpha=0.4, facecolor="mediumblue", edgecolor="mediumblue", lw=3 +) diff --git a/tidy3d/_common/components/viz/plot_sim_3d.py b/tidy3d/_common/components/viz/plot_sim_3d.py new file 
mode 100644 index 0000000000..9acc979247 --- /dev/null +++ b/tidy3d/_common/components/viz/plot_sim_3d.py @@ -0,0 +1,228 @@
+from __future__ import annotations
+
+import io
+from html import escape
+from typing import (
+    TYPE_CHECKING,
+    Protocol,
+    runtime_checkable,
+)
+
+from tidy3d._common.exceptions import SetupError
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from os import PathLike
+    from typing import (
+        Callable,
+        Optional,
+        Union,
+    )
+
+    from IPython.core.display_functions import DisplayHandle
+
+
+@runtime_checkable
+class PlotSim3DProtocol(Protocol):
+    def to_hdf5_gz(
+        self,
+        fname: Union[PathLike[str], io.BytesIO],
+        custom_encoders: Optional[Sequence[Callable[..., object]]] = None,
+    ) -> None: ...
+
+
+@runtime_checkable
+class PlotScene3DProtocol(Protocol):
+    # Used by plot_scene_3d to patch JSON_STRING
+    size: Sequence[float]
+    center: Sequence[float]
+
+    def to_hdf5(
+        self,
+        fname: Union[PathLike[str], io.BytesIO],
+        custom_encoders: Optional[Sequence[Callable[..., object]]] = None,
+    ) -> None: ...
+
+
+def plot_scene_3d(scene: PlotScene3DProtocol, width: int = 800, height: int = 800) -> None:
+    """Make 3D display of a scene in an ipython notebook."""
+    import gzip
+    import json
+    from base64 import b64encode
+    from io import BytesIO
+
+    import h5py
+
+    # Serialize scene to HDF5 in-memory
+    buffer = BytesIO()
+    scene.to_hdf5(buffer)
+    buffer.seek(0)
+
+    # Open source HDF5 for reading and prepare modified copy
+    with h5py.File(buffer, "r") as src:
+        buffer2 = BytesIO()
+        with h5py.File(buffer2, "w") as dst:
+
+            def copy_item(name: str, obj: h5py.Group | h5py.Dataset) -> None:
+                if isinstance(obj, h5py.Group):
+                    dst.create_group(name)
+                    for k, v in obj.attrs.items():
+                        dst[name].attrs[k] = v
+                elif isinstance(obj, h5py.Dataset):
+                    data = obj[()]
+                    if name == "JSON_STRING":
+                        # Parse and update JSON string
+                        json_str = (
+                            data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else data
+                        )
+                        json_data = json.loads(json_str)
+                        json_data["size"] = list(scene.size)
+                        json_data["center"] = list(scene.center)
+                        json_data["grid_spec"] = {}
+                        new_str = json.dumps(json_data)
+                        dst.create_dataset(name, data=new_str.encode("utf-8"))
+                    else:
+                        dst.create_dataset(name, data=data)
+                    for k, v in obj.attrs.items():
+                        dst[name].attrs[k] = v
+
+            src.visititems(copy_item)
+        buffer2.seek(0)
+
+    # Gzip the modified HDF5
+    gz_buffer = BytesIO()
+    with gzip.GzipFile(fileobj=gz_buffer, mode="wb") as gz:
+        gz.write(buffer2.read())
+    gz_buffer.seek(0)
+
+    # Base64 encode and display with gzipped flag
+    sim_base64 = b64encode(gz_buffer.read()).decode("utf-8")
+    plot_sim_3d(sim_base64, width=width, height=height, is_gz_base64=True)
+
+
+def plot_sim_3d(
+    sim: Union[PlotSim3DProtocol, str],
+    width: int = 800,
+    height: int = 800,
+    is_gz_base64: bool = False,
+) -> DisplayHandle:
+    """Make 3D display of simulation in ipython notebook."""
+
+    try:
+        from IPython.display import HTML, display
+    except ImportError as e:
+        raise SetupError(
+            "3D plotting requires ipython to be installed "
+            "and the code to be running on a jupyter notebook."
+        ) from e
+
+    from base64 import b64encode
+    from io import BytesIO
+
+    if not is_gz_base64:
+        buffer = BytesIO()
+        sim.to_hdf5_gz(buffer)
+        buffer.seek(0)
+        base64 = b64encode(buffer.read()).decode("utf-8")
+    else:
+        base64 = sim
+
+    js_code = """
+    /**
+     * Simulation Viewer Injector
+     *
+     * Monitors the document for elements being added in the form:
+     *
+     *   <div class="simulation-viewer" data-width="800" data-height="800" data-simulation="..."></div>
+     *
+     * This script will then inject an iframe to the viewer application, and pass it the simulation data
+     * via the postMessage API on request. The script may be safely included multiple times, with only the
+     * configuration of the first started script (e.g. viewer URL) applying.
+     *
+     */
+    (function() {
+        const TARGET_CLASS = "simulation-viewer";
+        const ACTIVE_CLASS = "simulation-viewer-active";
+        const VIEWER_URL = "https://tidy3d.simulation.cloud/simulation-viewer";
+
+        class SimulationViewerInjector {
+            constructor() {
+                for (var node of document.getElementsByClassName(TARGET_CLASS)) {
+                    this.injectViewer(node);
+                }
+
+                // Monitor for newly added nodes to the DOM
+                this.observer = new MutationObserver(this.onMutations.bind(this));
+                this.observer.observe(document.body, {childList: true, subtree: true});
+            }
+
+            onMutations(mutations) {
+                for (var mutation of mutations) {
+                    if (mutation.type === 'childList') {
+                        /**
+                         * We have found that adding the element does not reliably trigger the mutation observer.
+                         * It may be the case that setting content with innerHTML does not trigger.
+                         *
+                         * It seems to be sufficient to re-scan the document for un-activated viewers
+                         * whenever an event occurs, as Jupyter triggers multiple events on cell evaluation.
+                         */
+                        var viewers = document.getElementsByClassName(TARGET_CLASS);
+                        for (var node of viewers) {
+                            this.injectViewer(node);
+                        }
+                    }
+                }
+            }
+
+            injectViewer(node) {
+                // (re-)check that this is a valid simulation container and has not already been injected
+                if (node.classList.contains(TARGET_CLASS) && !node.classList.contains(ACTIVE_CLASS)) {
+                    // Mark node as injected, to prevent re-runs
+                    node.classList.add(ACTIVE_CLASS);
+
+                    var uuid;
+                    if (window.crypto && window.crypto.randomUUID) {
+                        uuid = window.crypto.randomUUID();
+                    } else {
+                        uuid = "" + Math.random();
+                    }
+
+                    var frame = document.createElement("iframe");
+                    frame.width = node.dataset.width || 800;
+                    frame.height = node.dataset.height || 800;
+                    frame.style.cssText = `width:${frame.width}px;height:${frame.height}px;max-width:none;border:0;display:block`;
+                    frame.src = VIEWER_URL + "?uuid=" + uuid;
+
+                    var postMessageToViewer;
+                    postMessageToViewer = event => {
+                        if (event.data.type === 'viewer' && event.data.uuid === uuid) {
+                            frame.contentWindow.postMessage({ type: 'jupyter', uuid, value: node.dataset.simulation, fileType: 'hdf5' }, '*');
+
+                            // Run once only
+                            window.removeEventListener('message', postMessageToViewer);
+                        }
+                    };
+                    window.addEventListener(
+                        'message',
+                        postMessageToViewer,
+                        false
+                    );
+
+                    node.appendChild(frame);
+                }
+            }
+        }
+
+        if (!window.simulationViewerInjector) {
+            window.simulationViewerInjector = new SimulationViewerInjector();
+        }
+    })();
+    """
+    html_code = f"""
+ + """ + + return display(HTML(html_code)) diff --git a/tidy3d/_common/components/viz/styles.py b/tidy3d/_common/components/viz/styles.py new file mode 100644 index 0000000000..067afa9327 --- /dev/null +++ b/tidy3d/_common/components/viz/styles.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +try: + from matplotlib.patches import ArrowStyle + + arrow_style = ArrowStyle.Simple(head_length=11, head_width=9, tail_width=4) +except ImportError: + arrow_style = None + +FLEXCOMPUTE_COLORS = { + "brand_green": "#00643C", + "brand_tan": "#B8A18B", + "brand_blue": "#6DB5DD", + "brand_purple": "#8851AD", + "brand_black": "#000000", + "brand_orange": "#FC7A4C", +} +ARROW_COLOR_SOURCE = FLEXCOMPUTE_COLORS["brand_green"] +ARROW_COLOR_POLARIZATION = FLEXCOMPUTE_COLORS["brand_tan"] +ARROW_COLOR_MONITOR = FLEXCOMPUTE_COLORS["brand_orange"] +ARROW_COLOR_ABSORBER = FLEXCOMPUTE_COLORS["brand_blue"] +PLOT_BUFFER = 0.3 +ARROW_ALPHA = 0.8 +ARROW_LENGTH = 0.3 + +# stores color of simulation.structures for given index in simulation.medium_map +MEDIUM_CMAP = [ + "#689DBC", + "#D0698E", + "#5E6EAD", + "#C6224E", + "#BDB3E2", + "#9EC3E0", + "#616161", + "#877EBC", +] + +# colormap for structure's permittivity in plot_eps +STRUCTURE_EPS_CMAP = "gist_yarg" +STRUCTURE_EPS_CMAP_R = "gist_yarg_r" +STRUCTURE_HEAT_COND_CMAP = "gist_yarg" diff --git a/tidy3d/_common/components/viz/visualization_spec.py b/tidy3d/_common/components/viz/visualization_spec.py new file mode 100644 index 0000000000..9928c678ec --- /dev/null +++ b/tidy3d/_common/components/viz/visualization_spec.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import Field, field_validator + +from tidy3d._common.components.base import Tidy3dBaseModel +from tidy3d._common.log import log + +if TYPE_CHECKING: + from pydantic import ValidationInfo + +MATPLOTLIB_IMPORTED = True +try: + from matplotlib.colors import is_color_like +except ImportError: + is_color_like = None + MATPLOTLIB_IMPORTED = False + + +def is_valid_color(value: str) -> str: + if not MATPLOTLIB_IMPORTED: + log.warning( + "matplotlib was not successfully imported, but is required " + "to validate colors in the VisualizationSpec. The specified colors " + "have not been validated." 
+ ) + else: + if is_color_like is not None and not is_color_like(value): + raise ValueError(f"{value} is not a valid plotting color") + + return value + + +class VisualizationSpec(Tidy3dBaseModel): + """Defines specification for visualization when used with plotting functions.""" + + facecolor: str = Field( + "", + title="Face color", + description="Color applied to the faces in visualization.", + ) + + edgecolor: str = Field( + "", + title="Edge color", + description="Color applied to the edges in visualization.", + ) + + alpha: float = Field( + 1.0, + title="Opacity", + description="Opacity/alpha value in plotting between 0 and 1.", + ge=0, + le=1, + ) + + @field_validator("facecolor") + @classmethod + def _validate_facecolor(cls, value: str) -> str: + return is_valid_color(value) + + @field_validator("edgecolor") + @classmethod + def _ensure_edgecolor(cls, value: str, info: ValidationInfo) -> str: + # if no explicit edgecolor given, fall back to facecolor + if (value == "") and "facecolor" in info.data: + return is_valid_color(info.data["facecolor"]) + return is_valid_color(value) diff --git a/tidy3d/config/README.md b/tidy3d/_common/config/README.md similarity index 100% rename from tidy3d/config/README.md rename to tidy3d/_common/config/README.md diff --git a/tidy3d/_common/config/__init__.py b/tidy3d/_common/config/__init__.py new file mode 100644 index 0000000000..4be3a4ab1a --- /dev/null +++ b/tidy3d/_common/config/__init__.py @@ -0,0 +1,85 @@ +"""Tidy3D configuration system public API.""" + +from __future__ import annotations + +from typing import Any + +from .legacy import ( + LegacyConfigWrapper, + LegacyEnvironment, + LegacyEnvironmentConfig, +) +from .manager import ConfigManager +from .registry import ( + get_handlers, + get_sections, + register_handler, + register_plugin, + register_section, +) + +__all__ = [ + "ConfigManager", + "Env", + "Environment", + "EnvironmentConfig", + "config", + "get_handlers", + "get_sections", + "register_handler", + "register_plugin", + "register_section", +] + + +def _create_manager() -> ConfigManager: + return ConfigManager() + + +_base_manager = _create_manager() +# TODO(FXC-3827): Drop LegacyConfigWrapper once legacy accessors are removed in Tidy3D 2.12. +_config_wrapper = LegacyConfigWrapper(_base_manager) +config = _config_wrapper + +# TODO(FXC-3827): Remove legacy Env exports after deprecation window (planned 2.12). 
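+#
+# Hedged usage sketch (not part of the shipped docs) of the two access patterns
+# this module bridges; 'dev' is one of the builtin profiles and the calls below
+# mirror this module's public API:
+#
+#     from tidy3d._common.config import Env, config
+#
+#     Env.dev.active()              # legacy path, emits a DeprecationWarning
+#     config.switch_profile("dev")  # preferred replacement
+#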
+Environment = LegacyEnvironment
+EnvironmentConfig = LegacyEnvironmentConfig
+# NOTE: 'Env' is intentionally not bound at module level: an eager global would
+# shadow the lazy initialization in '__getattr__' below, so the instance lives
+# in the private '_env' holder instead.
+_env: LegacyEnvironment | None = None
+
+
+def initialize_env() -> None:
+    """Initialize legacy Env after sections register."""
+
+    global _env
+    if _env is None:
+        _env = LegacyEnvironment(_base_manager)
+
+
+def reload_config(*, profile: str | None = None) -> LegacyConfigWrapper:
+    """Recreate the global configuration manager (primarily for tests)."""
+
+    global _base_manager
+    if _base_manager is not None:
+        try:
+            _base_manager.apply_web_env({})
+        except AttributeError:
+            pass
+    _base_manager = ConfigManager(profile=profile)
+    _config_wrapper.reset_manager(_base_manager)
+    initialize_env()
+    _env.reset_manager(_base_manager)
+    return _config_wrapper
+
+
+def get_manager() -> ConfigManager:
+    """Return the underlying configuration manager instance."""
+
+    return _base_manager
+
+
+def __getattr__(name: str) -> Any:
+    if name == "Env":
+        initialize_env()
+        return _env
+    return getattr(config, name)
diff --git a/tidy3d/_common/config/legacy.py b/tidy3d/_common/config/legacy.py new file mode 100644 index 0000000000..015600349e --- /dev/null +++ b/tidy3d/_common/config/legacy.py @@ -0,0 +1,541 @@
+"""Legacy compatibility layer for tidy3d.config.
+
+This module holds most of the compatibility layer for the pre-2.10 tidy3d config
+and is intended to be removed in a future release.
+"""
+
+from __future__ import annotations
+
+import os
+import warnings
+from typing import TYPE_CHECKING, Any
+
+import toml
+
+from tidy3d._common._runtime import WASM_BUILD
+from tidy3d._common.log import log
+
+# TODO(FXC-3827): Remove LegacyConfigWrapper/Environment shims and related helpers in Tidy3D 2.12.
+from .manager import ConfigManager, normalize_profile_name
+from .profiles import BUILTIN_PROFILES
+
+if TYPE_CHECKING:
+    from pathlib import Path
+    from typing import Optional
+
+    from tidy3d._common.log import LogLevel
+
+
+def _warn_env_deprecated() -> None:
+    message = "'tidy3d.config.Env' is deprecated; use 'config.switch_profile(...)' instead."
+    warnings.warn(message, DeprecationWarning, stacklevel=3)
+    log.warning(message, log_once=True)
+
+
+# TODO(FXC-3827): Delete LegacyConfigWrapper once legacy attribute access is dropped.
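+#
+# Hedged sketch of the attribute forwarding implemented by the wrapper below:
+# flat legacy attributes map onto sectioned settings (section and field names
+# assume the registered schemas):
+#
+#     config.logging_level = "DEBUG"   # DeprecationWarning, then forwards to ...
+#     assert config.logging.level == "DEBUG"   # ... the new sectioned accessor
+#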
+class LegacyConfigWrapper: + """Provide attribute-level compatibility with the legacy config module.""" + + def __init__(self, manager: ConfigManager): + self._manager = manager + self._frozen = False # retained for backwards compatibility tests + + @property + def logging_level(self) -> LogLevel: + return self._manager.get_section("logging").level + + @logging_level.setter + def logging_level(self, value: LogLevel) -> None: + from warnings import warn + + warn( + "'config.logging_level' is deprecated; use 'config.logging.level' instead.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("logging", level=value) + + @property + def log_suppression(self) -> bool: + return self._manager.get_section("logging").suppression + + @log_suppression.setter + def log_suppression(self, value: bool) -> None: + from warnings import warn + + warn( + "'config.log_suppression' is deprecated; use 'config.logging.suppression'.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("logging", suppression=value) + + @property + def use_local_subpixel(self) -> Optional[bool]: + return self._manager.get_section("simulation").use_local_subpixel + + @use_local_subpixel.setter + def use_local_subpixel(self, value: Optional[bool]) -> None: + from warnings import warn + + warn( + "'config.use_local_subpixel' is deprecated; use 'config.simulation.use_local_subpixel'.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("simulation", use_local_subpixel=value) + + @property + def suppress_rf_license_warning(self) -> bool: + return self._manager.get_section("microwave").suppress_rf_license_warning + + @suppress_rf_license_warning.setter + def suppress_rf_license_warning(self, value: bool) -> None: + from warnings import warn + + warn( + "'config.suppress_rf_license_warning' is deprecated; " + "use 'config.microwave.suppress_rf_license_warning'.", + DeprecationWarning, + stacklevel=2, + ) + self._manager.update_section("microwave", suppress_rf_license_warning=value) + + @property + def frozen(self) -> bool: + return self._frozen + + @frozen.setter + def frozen(self, value: bool) -> None: + self._frozen = bool(value) + + def save(self, include_defaults: bool = False) -> None: + self._manager.save(include_defaults=include_defaults) + + def reset_manager(self, manager: ConfigManager) -> None: + """Swap the underlying manager instance.""" + + self._manager = manager + + def switch_profile(self, profile: str) -> None: + """Switch active profile and synchronize the legacy environment proxy.""" + + normalized = normalize_profile_name(profile) + self._manager.switch_profile(normalized) + try: + from tidy3d._common.config import Env as _legacy_env + except Exception: + _legacy_env = None + if _legacy_env is not None: + _legacy_env._sync_to_manager(apply_env=True) + + def __getattr__(self, name: str) -> Any: + return getattr(self._manager, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + object.__setattr__(self, name, value) + elif name in { + "logging_level", + "log_suppression", + "use_local_subpixel", + "suppress_rf_license_warning", + "frozen", + }: + prop = getattr(type(self), name) + prop.fset(self, value) + else: + setattr(self._manager, name, value) + + def __str__(self) -> str: + return self._manager.format() + + +# TODO(FXC-3827): Delete LegacyEnvironmentConfig once profile-based Env shim is removed. 
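+#
+# Hedged sketch of the buffering behavior implemented by the wrapper below:
+# constructor keyword overrides are held in '_pending' and flushed into the
+# manager's "web" section once this profile becomes active (the endpoint value
+# is illustrative):
+#
+#     cfg = LegacyEnvironmentConfig(name="nexus", web_api_endpoint="http://127.0.0.1:5000")
+#     cfg.active()   # switches the manager to "nexus" and applies the override
+#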
+class LegacyEnvironmentConfig: + """Backward compatible environment config wrapper that proxies ConfigManager.""" + + def __init__( + self, + manager: Optional[ConfigManager] = None, + name: Optional[str] = None, + *, + web_api_endpoint: Optional[str] = None, + website_endpoint: Optional[str] = None, + s3_region: Optional[str] = None, + ssl_verify: Optional[bool] = None, + enable_caching: Optional[bool] = None, + ssl_version: Optional[str] = None, + env_vars: Optional[dict[str, str]] = None, + environment: Optional[LegacyEnvironment] = None, + ) -> None: + if name is None: + raise ValueError("Environment name is required") + self._manager = manager + self._name = normalize_profile_name(name) + self._environment = environment + self._pending: dict[str, Any] = {} + if web_api_endpoint is not None: + self._pending["api_endpoint"] = web_api_endpoint + if website_endpoint is not None: + self._pending["website_endpoint"] = website_endpoint + if s3_region is not None: + self._pending["s3_region"] = s3_region + if ssl_verify is not None: + self._pending["ssl_verify"] = ssl_verify + if enable_caching is not None: + self._pending["enable_caching"] = enable_caching + if ssl_version is not None: + self._pending["ssl_version"] = ssl_version + if env_vars is not None: + self._pending["env_vars"] = dict(env_vars) + + def reset_manager(self, manager: ConfigManager) -> None: + self._manager = manager + + @property + def manager(self) -> Optional[ConfigManager]: + if self._manager is not None: + return self._manager + if self._environment is not None: + return self._environment._manager + return None + + def active(self) -> None: + _warn_env_deprecated() + environment = self._environment + if environment is None: + from tidy3d._common.config import Env # local import to avoid circular + + environment = Env + + environment.set_current(self) + + @property + def web_api_endpoint(self) -> Optional[str]: + value = self._value("api_endpoint") + return _maybe_str(value) + + @property + def website_endpoint(self) -> Optional[str]: + value = self._value("website_endpoint") + return _maybe_str(value) + + @property + def s3_region(self) -> Optional[str]: + return self._value("s3_region") + + @property + def ssl_verify(self) -> bool: + value = self._value("ssl_verify") + if value is None: + return True + return bool(value) + + @property + def enable_caching(self) -> bool: + value = self._value("enable_caching") + if value is None: + return True + return bool(value) + + @enable_caching.setter + def enable_caching(self, value: Optional[bool]) -> None: + self._set_pending("enable_caching", value) + + @property + def ssl_version(self) -> Optional[str]: + return self._value("ssl_version") + + @ssl_version.setter + def ssl_version(self, value: Optional[str]) -> None: + self._set_pending("ssl_version", value) + + @property + def env_vars(self) -> dict[str, str]: + value = self._value("env_vars") + if value is None: + return {} + return dict(value) + + @env_vars.setter + def env_vars(self, value: dict[str, str]) -> None: + self._set_pending("env_vars", dict(value)) + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, value: str) -> None: + self._name = normalize_profile_name(value) + + def copy_state_from(self, other: LegacyEnvironmentConfig) -> None: + if not isinstance(other, LegacyEnvironmentConfig): + raise TypeError("Expected LegacyEnvironmentConfig instance.") + for key, value in other._pending.items(): + if key == "env_vars" and value is not None: + self._pending[key] = 
dict(value) + else: + self._pending[key] = value + + def get_real_url(self, path: str) -> str: + manager = self.manager + if manager is not None and manager.profile == self._name: + web_section = manager.get_section("web") + if hasattr(web_section, "build_api_url"): + return web_section.build_api_url(path) + + endpoint = self.web_api_endpoint or "" + if not path: + return endpoint + return "/".join([endpoint.rstrip("/"), str(path).lstrip("/")]) + + def apply_pending_overrides(self) -> None: + manager = self.manager + if manager is None or manager.profile != self._name: + return + if not self._pending: + return + updates = dict(self._pending) + manager.update_section("web", **updates) + self._pending.clear() + + def _set_pending(self, key: str, value: Any) -> None: + if key == "env_vars" and value is not None: + self._pending[key] = dict(value) + else: + self._pending[key] = value + self.apply_pending_overrides() + + def _web_section(self) -> dict[str, Any]: + manager = self.manager + if manager is None or WASM_BUILD: + return {} + profile = normalize_profile_name(self._name) + if manager.profile == profile: + section = manager.get_section("web") + return section.model_dump(mode="python", exclude_unset=False) + preview = manager.preview_profile(profile) + source = preview.get("web", {}) + return dict(source) if isinstance(source, dict) else {} + + def _value(self, key: str) -> Any: + if key in self._pending: + return self._pending[key] + return self._web_section().get(key) + + +# TODO(FXC-3827): Delete LegacyEnvironment after deprecating `tidy3d.config.Env`. +class LegacyEnvironment: + """Legacy Env wrapper that maps to profiles.""" + + def __init__(self, manager: ConfigManager): + self._previous_env_vars: dict[str, Optional[str]] = {} + self.env_map: dict[str, LegacyEnvironmentConfig] = {} + self._current: Optional[LegacyEnvironmentConfig] = None + self._manager: Optional[ConfigManager] = None + self._applied_profile: Optional[str] = None + self.reset_manager(manager) + + def reset_manager(self, manager: ConfigManager) -> None: + self._manager = manager + self.env_map = {} + for name in BUILTIN_PROFILES: + key = normalize_profile_name(name) + self.env_map[key] = LegacyEnvironmentConfig(manager, key, environment=self) + self._applied_profile = None + self._current = None + self._sync_to_manager(apply_env=True) + + @property + def current(self) -> LegacyEnvironmentConfig: + self._sync_to_manager() + assert self._current is not None + return self._current + + def set_current(self, env_config: LegacyEnvironmentConfig) -> None: + _warn_env_deprecated() + key = normalize_profile_name(env_config.name) + stored = self._get_config(key) + stored.copy_state_from(env_config) + if self._manager and self._manager.profile != key: + self._manager.switch_profile(key) + self._sync_to_manager(apply_env=True) + + def enable_caching(self, enable_caching: Optional[bool] = True) -> None: + config = self.current + config.enable_caching = enable_caching + self._sync_to_manager() + + def set_ssl_version(self, ssl_version: Optional[str]) -> None: + config = self.current + config.ssl_version = ssl_version + self._sync_to_manager() + + def __getattr__(self, name: str) -> LegacyEnvironmentConfig: + return self._get_config(name) + + def _get_config(self, name: str) -> LegacyEnvironmentConfig: + key = normalize_profile_name(name) + config = self.env_map.get(key) + if config is None: + config = LegacyEnvironmentConfig(self._manager, key, environment=self) + self.env_map[key] = config + else: + manager = self._manager + if 
manager is not None: + config.reset_manager(manager) + config._environment = self + return config + + def _sync_to_manager(self, *, apply_env: bool = False) -> None: + if self._manager is None: + return + active = normalize_profile_name(self._manager.profile) + config = self._get_config(active) + config.apply_pending_overrides() + self._current = config + if apply_env or self._applied_profile != active: + self._apply_env_vars(config) + self._applied_profile = active + + def _apply_env_vars(self, config: LegacyEnvironmentConfig) -> None: + self._restore_env_vars() + env_vars = config.env_vars or {} + self._previous_env_vars = {} + for key, value in env_vars.items(): + self._previous_env_vars[key] = os.environ.get(key) + os.environ[key] = value + + def _restore_env_vars(self) -> None: + for key, previous in self._previous_env_vars.items(): + if previous is None: + os.environ.pop(key, None) + else: + os.environ[key] = previous + self._previous_env_vars = {} + + +def _maybe_str(value: Any) -> Optional[str]: + if value is None: + return None + return str(value) + + +def load_legacy_flat_config(config_dir: Path) -> dict[str, Any]: + """Load legacy flat configuration file (pre-migration format). + + This function now supports both the original flat config format and + Nexus custom deployment settings introduced in later versions. + + Legacy key mappings: + - apikey -> web.apikey + - web_api_endpoint -> web.api_endpoint + - website_endpoint -> web.website_endpoint + - s3_region -> web.s3_region + - s3_endpoint -> web.env_vars.AWS_ENDPOINT_URL_S3 + - ssl_verify -> web.ssl_verify + - enable_caching -> web.enable_caching + """ + + legacy_path = config_dir / "config" + if not legacy_path.exists(): + return {} + + try: + text = legacy_path.read_text(encoding="utf-8") + except Exception as exc: + log.warning(f"Failed to read legacy configuration file '{legacy_path}': {exc}") + return {} + + try: + parsed = toml.loads(text) + except Exception as exc: + log.warning(f"Failed to decode legacy configuration file '{legacy_path}': {exc}") + return {} + + legacy_data: dict[str, Any] = {} + + # Migrate API key (original functionality) + apikey = parsed.get("apikey") + if apikey is not None: + legacy_data.setdefault("web", {})["apikey"] = apikey + + # Migrate Nexus API endpoint + web_api = parsed.get("web_api_endpoint") + if web_api is not None: + legacy_data.setdefault("web", {})["api_endpoint"] = web_api + + # Migrate Nexus website endpoint + website = parsed.get("website_endpoint") + if website is not None: + legacy_data.setdefault("web", {})["website_endpoint"] = website + + # Migrate S3 region + s3_region = parsed.get("s3_region") + if s3_region is not None: + legacy_data.setdefault("web", {})["s3_region"] = s3_region + + # Migrate SSL verification setting + ssl_verify = parsed.get("ssl_verify") + if ssl_verify is not None: + legacy_data.setdefault("web", {})["ssl_verify"] = ssl_verify + + # Migrate caching setting + enable_caching = parsed.get("enable_caching") + if enable_caching is not None: + legacy_data.setdefault("web", {})["enable_caching"] = enable_caching + + # Migrate S3 endpoint to env_vars + s3_endpoint = parsed.get("s3_endpoint") + if s3_endpoint is not None: + env_vars = legacy_data.setdefault("web", {}).setdefault("env_vars", {}) + env_vars["AWS_ENDPOINT_URL_S3"] = s3_endpoint + + return legacy_data + + +__all__ = [ + "LegacyConfigWrapper", + "LegacyEnvironment", + "LegacyEnvironmentConfig", + "finalize_legacy_migration", + "load_legacy_flat_config", +] + + +def 
finalize_legacy_migration(config_dir: Path) -> None: + """Promote a copied legacy configuration tree into the structured format. + + Parameters + ---------- + config_dir : Path + Destination directory (typically the canonical config location). + """ + + legacy_data = load_legacy_flat_config(config_dir) + + from .manager import ConfigManager # local import to avoid circular dependency + + manager = ConfigManager(profile="default", config_dir=config_dir) + config_path = config_dir / "config.toml" + for section, values in legacy_data.items(): + if isinstance(values, dict): + manager.update_section(section, **values) + try: + manager.save(include_defaults=True) + except Exception: + if config_path.exists(): + try: + config_path.unlink() + except Exception: + pass + raise + + legacy_flat_path = config_dir / "config" + if legacy_flat_path.exists(): + try: + legacy_flat_path.unlink() + except Exception as exc: + log.warning(f"Failed to remove legacy configuration file '{legacy_flat_path}': {exc}") diff --git a/tidy3d/_common/config/loader.py b/tidy3d/_common/config/loader.py new file mode 100644 index 0000000000..19492b84d4 --- /dev/null +++ b/tidy3d/_common/config/loader.py @@ -0,0 +1,441 @@ +"""Filesystem helpers and persistence utilities for the configuration system.""" + +from __future__ import annotations + +import os +import shutil +import tempfile +from copy import deepcopy +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import toml +import tomlkit + +from tidy3d._common.log import log + +from .profiles import BUILTIN_PROFILES +from .serializer import build_document, collect_descriptions + +if TYPE_CHECKING: + from typing import Optional + + +class ConfigLoader: + """Handle reading and writing configuration files.""" + + def __init__(self, config_dir: Optional[Path] = None): + self.config_dir = config_dir or resolve_config_directory() + self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) + self._docs: dict[Path, tomlkit.TOMLDocument] = {} + + def load_base(self) -> dict[str, Any]: + """Load base configuration from config.toml. + + If config.toml doesn't exist but the legacy flat config does, + automatically migrate to the new format. + """ + + config_path = self.config_dir / "config.toml" + data = self._read_toml(config_path) + if data: + return data + + # Check for legacy flat config + from .legacy import load_legacy_flat_config + + legacy_path = self.config_dir / "config" + legacy = load_legacy_flat_config(self.config_dir) + + # Auto-migrate if legacy config exists + if legacy and legacy_path.exists(): + log.info( + f"Detected legacy configuration at '{legacy_path}'. " + "Automatically migrating to new format..." + ) + + try: + # Save in new format + self.save_base(legacy) + + # Rename old config to preserve it + backup_path = legacy_path.with_suffix(".migrated") + legacy_path.rename(backup_path) + + log.info( + f"Migration complete. Configuration saved to '{config_path}'. " + f"Legacy config backed up as '{backup_path.name}'." + ) + + # Re-read the newly created config + return self._read_toml(config_path) + except Exception as exc: + log.warning( + f"Failed to auto-migrate legacy configuration: {exc}. " + "Using legacy data without migration." 
+ ) + return legacy + + if legacy: + return legacy + return {} + + def load_user_profile(self, profile: str) -> dict[str, Any]: + """Load user profile overrides (if any).""" + + if profile in ("default", "prod"): + # default and prod share the same baseline; user overrides live in config.toml + return {} + + profile_path = self.profile_path(profile) + return self._read_toml(profile_path) + + def get_builtin_profile(self, profile: str) -> dict[str, Any]: + """Return builtin profile data if available.""" + + return BUILTIN_PROFILES.get(profile, {}) + + def save_base(self, data: dict[str, Any]) -> None: + """Persist base configuration.""" + + config_path = self.config_dir / "config.toml" + self._atomic_write(config_path, data) + + def save_profile(self, profile: str, data: dict[str, Any]) -> None: + """Persist profile overrides (remove file if empty).""" + + profile_path = self.profile_path(profile) + if not data: + if profile_path.exists(): + profile_path.unlink() + self._docs.pop(profile_path, None) + return + profile_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) + self._atomic_write(profile_path, data) + + def profile_path(self, profile: str) -> Path: + """Return on-disk path for a profile.""" + + return self.config_dir / "profiles" / f"{profile}.toml" + + def get_default_profile(self) -> Optional[str]: + """Read the default_profile from config.toml. + + Returns + ------- + Optional[str] + The default profile name if set, None otherwise. + """ + + config_path = self.config_dir / "config.toml" + if not config_path.exists(): + return None + + try: + text = config_path.read_text(encoding="utf-8") + data = toml.loads(text) + return data.get("default_profile") + except Exception as exc: + log.warning(f"Failed to read default_profile from '{config_path}': {exc}") + return None + + def set_default_profile(self, profile: Optional[str]) -> None: + """Set the default_profile in config.toml. + + Parameters + ---------- + profile : Optional[str] + The profile name to set as default, or None to remove the setting. 
+ """ + + config_path = self.config_dir / "config.toml" + data = self._read_toml(config_path) + + if profile is None: + # Remove default_profile if it exists + if "default_profile" in data: + del data["default_profile"] + else: + # Set default_profile as a top-level key + data["default_profile"] = profile + + self._atomic_write(config_path, data) + + def _read_toml(self, path: Path) -> dict[str, Any]: + if not path.exists(): + self._docs.pop(path, None) + return {} + + try: + text = path.read_text(encoding="utf-8") + except Exception as exc: + log.warning(f"Failed to read configuration file '{path}': {exc}") + self._docs.pop(path, None) + return {} + + try: + document = tomlkit.parse(text) + except Exception as exc: + log.warning(f"Failed to parse configuration file '{path}': {exc}") + document = tomlkit.document() + self._docs[path] = document + + try: + return toml.loads(text) + except Exception as exc: + log.warning(f"Failed to decode configuration file '{path}': {exc}") + return {} + + def _atomic_write(self, path: Path, data: dict[str, Any]) -> None: + path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) + tmp_dir = path.parent + + cleaned = _clean_data(deepcopy(data)) + descriptions = collect_descriptions() + + base_document = self._docs.get(path) + document = build_document(cleaned, base_document, descriptions) + toml_text = tomlkit.dumps(document) + + with tempfile.NamedTemporaryFile( + "w", dir=tmp_dir, delete=False, encoding="utf-8" + ) as handle: + tmp_path = Path(handle.name) + handle.write(toml_text) + handle.flush() + os.fsync(handle.fileno()) + + backup_path = path.with_suffix(path.suffix + ".bak") + try: + if path.exists(): + shutil.copy2(path, backup_path) + tmp_path.replace(path) + os.chmod(path, 0o600) + if backup_path.exists(): + backup_path.unlink() + except Exception: + if tmp_path.exists(): + tmp_path.unlink() + if backup_path.exists(): + try: + backup_path.replace(path) + except Exception: + log.warning("Failed to restore configuration backup") + raise + + self._docs[path] = tomlkit.parse(toml_text) + + +def load_environment_overrides() -> dict[str, Any]: + """Parse environment variables into a nested configuration dict.""" + + overrides: dict[str, Any] = {} + for key, value in os.environ.items(): + if key == "SIMCLOUD_APIKEY": + _assign_path(overrides, ("web", "apikey"), value) + continue + if not key.startswith("TIDY3D_"): + continue + rest = key[len("TIDY3D_") :] + if "__" not in rest: + continue + segments = tuple(segment.lower() for segment in rest.split("__") if segment) + if not segments: + continue + if segments[0] == "auth": + segments = ("web",) + segments[1:] + _assign_path(overrides, segments, value) + return overrides + + +def deep_merge(*sources: dict[str, Any]) -> dict[str, Any]: + """Deep merge multiple dictionaries into a new dict.""" + + result: dict[str, Any] = {} + for source in sources: + _merge_into(result, source) + return result + + +def _merge_into(target: dict[str, Any], source: dict[str, Any]) -> None: + for key, value in source.items(): + if isinstance(value, dict): + node = target.setdefault(key, {}) + if isinstance(node, dict): + _merge_into(node, value) + else: + target[key] = deepcopy(value) + else: + target[key] = value + + +def deep_diff(base: dict[str, Any], target: dict[str, Any]) -> dict[str, Any]: + """Return keys from target that differ from base.""" + + diff: dict[str, Any] = {} + keys = set(base.keys()) | set(target.keys()) + for key in keys: + base_value = base.get(key) + target_value = target.get(key) + if 
isinstance(base_value, dict) and isinstance(target_value, dict): + nested = deep_diff(base_value, target_value) + if nested: + diff[key] = nested + elif target_value != base_value: + if isinstance(target_value, dict): + diff[key] = deepcopy(target_value) + else: + diff[key] = target_value + return diff + + +def _assign_path(target: dict[str, Any], path: tuple[str, ...], value: Any) -> None: + node = target + for segment in path[:-1]: + node = node.setdefault(segment, {}) + node[path[-1]] = value + + +def _clean_data(data: Any) -> Any: + if isinstance(data, dict): + cleaned: dict[str, Any] = {} + for key, value in data.items(): + cleaned_value = _clean_data(value) + if cleaned_value is None: + continue + cleaned[key] = cleaned_value + return cleaned + if isinstance(data, list): + cleaned_list = [_clean_data(item) for item in data] + return [item for item in cleaned_list if item is not None] + if data is None: + return None + return data + + +def legacy_config_directory() -> Path: + """Return the legacy configuration directory (~/.tidy3d).""" + + return Path.home() / ".tidy3d" + + +def canonical_config_directory() -> Path: + """Return the platform-dependent canonical configuration directory.""" + + return _xdg_config_home() / "tidy3d" + + +def resolve_config_directory() -> Path: + """Determine the directory used to store tidy3d configuration files.""" + + base_override = os.getenv("TIDY3D_BASE_DIR") + if base_override: + base_path = Path(base_override).expanduser().resolve() + path = base_path / "config" + if _is_writable(path.parent): + return path + log.warning( + "'TIDY3D_BASE_DIR' is not writable; using temporary configuration directory instead." + ) + return _temporary_config_dir() + + canonical_dir = canonical_config_directory() + if _is_writable(canonical_dir.parent): + legacy_dir = legacy_config_directory() + if legacy_dir.exists(): + log.warning( + f"Using canonical configuration directory at '{canonical_dir}'. " + "Found legacy directory at '~/.tidy3d', which will be ignored. " + "Remove it manually or run 'tidy3d config migrate --delete-legacy' to clean up.", + log_once=True, + ) + return canonical_dir + + legacy_dir = legacy_config_directory() + if legacy_dir.exists(): + log.warning( + "Configuration found in legacy location '~/.tidy3d'. Consider running 'tidy3d config migrate'.", + log_once=True, + ) + return legacy_dir + + log.warning(f"Unable to write to '{canonical_dir}'; falling back to temporary directory.") + return _temporary_config_dir() + + +def _xdg_config_home() -> Path: + xdg_home = os.getenv("XDG_CONFIG_HOME") + if xdg_home: + return Path(xdg_home).expanduser() + return Path.home() / ".config" + + +def _temporary_config_dir() -> Path: + base = Path(tempfile.gettempdir()) / "tidy3d" + base.mkdir(mode=0o700, exist_ok=True) + return base / "config" + + +def _is_writable(path: Path) -> bool: + try: + path.mkdir(parents=True, exist_ok=True) + test_file = path / ".tidy3d_write_test" + with open(test_file, "w", encoding="utf-8"): + pass + test_file.unlink() + return True + except Exception: + return False + + +def migrate_legacy_config(*, overwrite: bool = False, remove_legacy: bool = False) -> Path: + """Copy configuration files from the legacy ``~/.tidy3d`` directory to the canonical location. + + Parameters + ---------- + overwrite : bool + If ``True``, existing files in the canonical directory will be replaced. + remove_legacy : bool + If ``True``, the legacy directory is removed after a successful migration. 
+ + Returns + ------- + Path + The path of the canonical configuration directory. + + Raises + ------ + FileNotFoundError + If the legacy directory does not exist. + FileExistsError + If the destination already exists and ``overwrite`` is ``False``. + RuntimeError + If the legacy and canonical directories resolve to the same location. + """ + + legacy_dir = legacy_config_directory() + if not legacy_dir.exists(): + raise FileNotFoundError("Legacy configuration directory '~/.tidy3d' was not found.") + + canonical_dir = canonical_config_directory() + if canonical_dir.resolve() == legacy_dir.resolve(): + raise RuntimeError( + "Legacy and canonical configuration directories are the same path; nothing to migrate." + ) + + if canonical_dir.exists() and not overwrite: + raise FileExistsError( + f"Destination '{canonical_dir}' already exists. Pass overwrite=True to replace existing files." + ) + + canonical_dir.parent.mkdir(parents=True, exist_ok=True) + shutil.copytree(legacy_dir, canonical_dir, dirs_exist_ok=overwrite) + + from .legacy import finalize_legacy_migration # local import to avoid circular dependency + + finalize_legacy_migration(canonical_dir) + + if remove_legacy: + shutil.rmtree(legacy_dir) + + return canonical_dir diff --git a/tidy3d/_common/config/manager.py b/tidy3d/_common/config/manager.py new file mode 100644 index 0000000000..ffd1277913 --- /dev/null +++ b/tidy3d/_common/config/manager.py @@ -0,0 +1,634 @@ +"""Central configuration manager implementation.""" + +from __future__ import annotations + +import os +import shutil +from collections import defaultdict +from copy import deepcopy +from enum import Enum +from io import StringIO +from pathlib import Path +from typing import TYPE_CHECKING, Any, get_args, get_origin + +from pydantic import BaseModel +from rich.console import Console +from rich.panel import Panel +from rich.pretty import Pretty +from rich.text import Text +from rich.tree import Tree + +from tidy3d._common.log import log + +from .loader import ConfigLoader, deep_diff, deep_merge, load_environment_overrides +from .profiles import BUILTIN_PROFILES +from .registry import attach_manager, get_handlers, get_sections + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from typing import Optional + + +def normalize_profile_name(name: str) -> str: + """Return a canonical profile name for builtin profiles.""" + + normalized = name.strip() + lowered = normalized.lower() + if lowered in BUILTIN_PROFILES: + return lowered + return normalized + + +class SectionAccessor: + """Attribute proxy that routes assignments back through the manager.""" + + def __init__(self, manager: ConfigManager, path: str): + self._manager = manager + self._path = path + + def __getattr__(self, name: str) -> Any: + model = self._manager._get_model(self._path) + if model is None: + raise AttributeError(f"Section '{self._path}' is not available") + return getattr(model, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + object.__setattr__(self, name, value) + return + self._manager.update_section(self._path, **{name: value}) + + def __repr__(self) -> str: + model = self._manager._get_model(self._path) + return f"SectionAccessor({self._path}={model!r})" + + def __rich__(self) -> Panel: + model = self._manager._get_model(self._path) + if model is None: + return Panel(Text(f"Section '{self._path}' is unavailable", style="red")) + data = _prepare_for_display(model.model_dump(exclude_unset=False)) + return _build_section_panel(self._path, 
data) + + def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + model = self._manager._get_model(self._path) + if model is None: + return {} + return model.model_dump(*args, **kwargs) + + def __str__(self) -> str: + return self._manager.format_section(self._path) + + +class PluginsAccessor: + """Provides access to registered plugin configurations.""" + + def __init__(self, manager: ConfigManager): + self._manager = manager + + def __getattr__(self, plugin: str) -> SectionAccessor: + if plugin not in self._manager._plugin_models: + raise AttributeError(f"Plugin '{plugin}' is not registered") + return SectionAccessor(self._manager, f"plugins.{plugin}") + + def list(self) -> Iterable[str]: + return sorted(self._manager._plugin_models.keys()) + + +class ProfilesAccessor: + """Read-only profile helper.""" + + def __init__(self, manager: ConfigManager): + self._manager = manager + + def list(self) -> dict[str, list[str]]: + return self._manager.list_profiles() + + def __getattr__(self, profile: str) -> dict[str, Any]: + return self._manager.preview_profile(profile) + + +class ConfigManager: + """High-level orchestrator for tidy3d configuration.""" + + def __init__( + self, + profile: Optional[str] = None, + config_dir: Optional[os.PathLike[str]] = None, + ): + loader_path = None if config_dir is None else Path(config_dir) + self._loader = ConfigLoader(loader_path) + self._runtime_overrides: dict[str, dict[str, Any]] = defaultdict(dict) + self._plugin_models: dict[str, BaseModel] = {} + self._section_models: dict[str, BaseModel] = {} + self._profile = self._resolve_initial_profile(profile) + self._builtin_data: dict[str, Any] = {} + self._base_data: dict[str, Any] = {} + self._profile_data: dict[str, Any] = {} + self._raw_tree: dict[str, Any] = {} + self._effective_tree: dict[str, Any] = {} + self._env_overrides: dict[str, Any] = load_environment_overrides() + self._web_env_previous: dict[str, Optional[str]] = {} + + attach_manager(self) + self._reload() + + # Notify users when using a non-default profile + if self._profile != "default": + log.info(f"Using configuration profile: '{self._profile}'", log_once=True) + + self._apply_handlers() + + @property + def profile(self) -> str: + return self._profile + + @property + def config_dir(self) -> Path: + return self._loader.config_dir + + @property + def plugins(self) -> PluginsAccessor: + return PluginsAccessor(self) + + @property + def profiles(self) -> ProfilesAccessor: + return ProfilesAccessor(self) + + def update_section(self, name: str, **updates: Any) -> None: + if not updates: + return + segments = name.split(".") + overrides = self._runtime_overrides[self._profile] + previous = deepcopy(overrides) + node = overrides + for segment in segments[:-1]: + node = node.setdefault(segment, {}) + section_key = segments[-1] + section_payload = node.setdefault(section_key, {}) + for key, value in updates.items(): + section_payload[key] = _serialize_value(value) + try: + self._reload() + except Exception: + self._runtime_overrides[self._profile] = previous + raise + self._apply_handlers(section=name) + + def switch_profile(self, profile: str) -> None: + if not profile: + raise ValueError("Profile name cannot be empty") + normalized = normalize_profile_name(profile) + if not normalized: + raise ValueError("Profile name cannot be empty") + self._profile = normalized + self._reload() + + # Notify users when switching to a non-default profile + if self._profile != "default": + log.info(f"Switched to configuration profile: '{self._profile}'") + + 
self._apply_handlers() + + def set_default_profile(self, profile: Optional[str]) -> None: + """Set the default profile to be used on startup. + + Parameters + ---------- + profile : Optional[str] + The profile name to use as default, or None to clear the default. + When set, this profile will be automatically loaded unless overridden + by environment variables (TIDY3D_CONFIG_PROFILE, TIDY3D_PROFILE, or TIDY3D_ENV). + + Notes + ----- + This setting is persisted to config.toml and survives across sessions. + Environment variables always take precedence over the default profile. + """ + + if profile is not None: + normalized = normalize_profile_name(profile) + if not normalized: + raise ValueError("Profile name cannot be empty") + self._loader.set_default_profile(normalized) + else: + self._loader.set_default_profile(None) + + def get_default_profile(self) -> Optional[str]: + """Get the currently configured default profile. + + Returns + ------- + Optional[str] + The default profile name if set, None otherwise. + """ + + return self._loader.get_default_profile() + + def save(self, include_defaults: bool = False) -> None: + if self._profile == "default": + # For base config: only save fields marked with persist=True + base_without_env = self._filter_persisted(self._compose_without_env()) + if include_defaults: + defaults = self._filter_persisted(self._default_tree()) + base_without_env = deep_merge(defaults, base_without_env) + self._loader.save_base(base_without_env) + else: + # For profile overrides: save any field that differs from baseline + # (don't filter by persist flag - profiles should save all customizations) + base_without_env = self._compose_without_env() + baseline = deep_merge(self._builtin_data, self._base_data) + diff = deep_diff(baseline, base_without_env) + self._loader.save_profile(self._profile, diff) + # refresh cached base/profile data after saving + self._base_data = self._loader.load_base() + self._profile_data = self._loader.load_user_profile(self._profile) + self._reload() + + def reset_to_defaults(self, *, include_profiles: bool = True) -> None: + """Reset configuration files to their default annotated state.""" + + self._runtime_overrides = defaultdict(dict) + defaults = self._filter_persisted(self._default_tree()) + self._loader.save_base(defaults) + + if include_profiles: + profiles_dir = self._loader.profile_path("_dummy").parent + if profiles_dir.exists(): + shutil.rmtree(profiles_dir) + loader_docs = getattr(self._loader, "_docs", {}) + for path in list(loader_docs.keys()): + try: + path.relative_to(profiles_dir) + except ValueError: + continue + loader_docs.pop(path, None) + self._profile = "default" + + self._reload() + self._apply_handlers() + + def apply_web_env(self, env_vars: Mapping[str, str]) -> None: + """Apply environment variable overrides for the web configuration section.""" + + self._restore_web_env() + for key, value in env_vars.items(): + self._web_env_previous[key] = os.environ.get(key) + os.environ[key] = value + + def _restore_web_env(self) -> None: + """Restore previously overridden environment variables.""" + + for key, previous in self._web_env_previous.items(): + if previous is None: + os.environ.pop(key, None) + else: + os.environ[key] = previous + self._web_env_previous.clear() + + def list_profiles(self) -> dict[str, list[str]]: + profiles_dir = self._loader.config_dir / "profiles" + user_profiles = [] + if profiles_dir.exists(): + for path in profiles_dir.glob("*.toml"): + user_profiles.append(path.stem) + built_in = sorted(name for 
name in BUILTIN_PROFILES.keys()) + return {"built_in": built_in, "user": sorted(user_profiles)} + + def preview_profile(self, profile: str) -> dict[str, Any]: + builtin = self._loader.get_builtin_profile(profile) + base = self._loader.load_base() + overrides = self._loader.load_user_profile(profile) + view = deep_merge(builtin, base, overrides) + return deepcopy(view) + + def get_section(self, name: str) -> BaseModel: + model = self._get_model(name) + if model is None: + raise AttributeError(f"Section '{name}' is not available") + return model + + def as_dict(self, include_env: bool = True) -> dict[str, Any]: + """Return the current configuration tree, including defaults for all sections.""" + + tree = self._compose_without_env() + if include_env: + tree = deep_merge(tree, self._env_overrides) + return deep_merge(self._default_tree(), tree) + + def __rich__(self) -> Panel: + """Return a rich renderable representation of the full configuration.""" + + return _build_config_panel( + title=f"Config (profile='{self._profile}')", + data=_prepare_for_display(self.as_dict(include_env=True)), + ) + + def format(self, *, include_env: bool = True) -> str: + """Return a human-friendly representation of the full configuration.""" + + panel = _build_config_panel( + title=f"Config (profile='{self._profile}')", + data=_prepare_for_display(self.as_dict(include_env=include_env)), + ) + return _render_panel(panel) + + def format_section(self, name: str) -> str: + """Return a string representation for an individual section.""" + + model = self._get_model(name) + if model is None: + raise AttributeError(f"Section '{name}' is not available") + data = _prepare_for_display(model.model_dump(exclude_unset=False)) + panel = _build_section_panel(name, data) + return _render_panel(panel) + + def on_section_registered(self, section: str) -> None: + self._reload() + self._apply_handlers(section=section) + + def on_handler_registered(self, section: str) -> None: + self._apply_handlers(section=section) + + def _resolve_initial_profile(self, profile: Optional[str]) -> str: + if profile: + return normalize_profile_name(str(profile)) + + # Check environment variables first (highest priority) + env_profile = ( + os.getenv("TIDY3D_CONFIG_PROFILE") + or os.getenv("TIDY3D_PROFILE") + or os.getenv("TIDY3D_ENV") + ) + if env_profile: + return normalize_profile_name(env_profile) + + # Check for default_profile in config file + config_default = self._loader.get_default_profile() + if config_default: + return normalize_profile_name(config_default) + + # Fall back to "default" profile + return "default" + + def _reload(self) -> None: + self._env_overrides = load_environment_overrides() + self._builtin_data = deepcopy(self._loader.get_builtin_profile(self._profile)) + self._base_data = deepcopy(self._loader.load_base()) + self._profile_data = deepcopy(self._loader.load_user_profile(self._profile)) + self._raw_tree = deep_merge(self._builtin_data, self._base_data, self._profile_data) + + runtime = deepcopy(self._runtime_overrides.get(self._profile, {})) + effective = deep_merge(self._raw_tree, self._env_overrides, runtime) + self._effective_tree = effective + self._build_models() + + def _build_models(self) -> None: + sections = get_sections() + new_sections: dict[str, BaseModel] = {} + new_plugins: dict[str, BaseModel] = {} + + errors: list[tuple[str, Exception]] = [] + for name, schema in sections.items(): + if name.startswith("plugins."): + plugin_name = name.split(".", 1)[1] + plugin_data = _deep_get(self._effective_tree, 
("plugins", plugin_name)) or {} + try: + new_plugins[plugin_name] = schema(**plugin_data) + except Exception as exc: + log.error(f"Failed to load configuration for plugin '{plugin_name}': {exc}") + errors.append((name, exc)) + continue + if name == "plugins": + continue + section_data = self._effective_tree.get(name, {}) + try: + new_sections[name] = schema(**section_data) + except Exception as exc: + log.error(f"Failed to load configuration for section '{name}': {exc}") + errors.append((name, exc)) + + if errors: + # propagate the first error; others already logged + raise errors[0][1] + + self._section_models = new_sections + self._plugin_models = new_plugins + + def _get_model(self, name: str) -> Optional[BaseModel]: + if name.startswith("plugins."): + plugin = name.split(".", 1)[1] + return self._plugin_models.get(plugin) + return self._section_models.get(name) + + def _apply_handlers(self, section: Optional[str] = None) -> None: + handlers = get_handlers() + targets = [section] if section else handlers.keys() + for target in targets: + handler = handlers.get(target) + if handler is None: + continue + model = self._get_model(target) + if model is None: + continue + try: + handler(model) + except Exception as exc: + log.error(f"Failed to apply configuration handler for '{target}': {exc}") + + def _compose_without_env(self) -> dict[str, Any]: + runtime = self._runtime_overrides.get(self._profile, {}) + return deep_merge(self._raw_tree, runtime) + + def _default_tree(self) -> dict[str, Any]: + defaults: dict[str, Any] = {} + for name, schema in get_sections().items(): + if name.startswith("plugins."): + plugin = name.split(".", 1)[1] + defaults.setdefault("plugins", {})[plugin] = _model_dict(schema()) + elif name == "plugins": + defaults.setdefault("plugins", {}) + else: + defaults[name] = _model_dict(schema()) + return defaults + + def _filter_persisted(self, tree: dict[str, Any]) -> dict[str, Any]: + sections = get_sections() + filtered: dict[str, Any] = {} + plugins_source = tree.get("plugins", {}) + plugin_filtered: dict[str, Any] = {} + + for name, schema in sections.items(): + if name == "plugins": + continue + if name.startswith("plugins."): + plugin_name = name.split(".", 1)[1] + plugin_data = plugins_source.get(plugin_name, {}) + if not isinstance(plugin_data, dict): + continue + persisted_plugin = _extract_persisted(schema, plugin_data) + if persisted_plugin: + plugin_filtered[plugin_name] = persisted_plugin + continue + + section_data = tree.get(name, {}) + if not isinstance(section_data, dict): + continue + persisted_section = _extract_persisted(schema, section_data) + if persisted_section: + filtered[name] = persisted_section + + if plugin_filtered: + filtered["plugins"] = plugin_filtered + return filtered + + def __getattr__(self, name: str) -> Any: + if name in self._section_models: + return SectionAccessor(self, name) + if name == "plugins": + return self.plugins + raise AttributeError(f"Config has no section '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + if name.startswith("_"): + object.__setattr__(self, name, value) + return + if name in self._section_models: + if isinstance(value, BaseModel): + payload = value.model_dump(exclude_unset=False) + else: + payload = value + self.update_section(name, **payload) + return + object.__setattr__(self, name, value) + + def __str__(self) -> str: + return self.format() + + +def _deep_get(tree: dict[str, Any], path: Iterable[str]) -> Optional[dict[str, Any]]: + node: Any = tree + for segment in path: + if not 
isinstance(node, dict): + return None + node = node.get(segment) + if node is None: + return None + return node if isinstance(node, dict) else None + + +def _resolve_model_type(annotation: Any) -> Optional[type[BaseModel]]: + """Return the first BaseModel subclass found in an annotation (if any).""" + + if isinstance(annotation, type) and issubclass(annotation, BaseModel): + return annotation + + origin = get_origin(annotation) + if origin is None: + return None + + for arg in get_args(annotation): + nested = _resolve_model_type(arg) + if nested is not None: + return nested + return None + + +def _serialize_value(value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(exclude_unset=False) + if hasattr(value, "get_secret_value"): + return value.get_secret_value() + return value + + +def _prepare_for_display(value: Any) -> Any: + if isinstance(value, BaseModel): + return { + k: _prepare_for_display(v) for k, v in value.model_dump(exclude_unset=False).items() + } + if isinstance(value, dict): + return {str(k): _prepare_for_display(v) for k, v in value.items()} + if isinstance(value, (list, tuple, set)): + return [_prepare_for_display(v) for v in value] + if isinstance(value, Path): + return str(value) + if isinstance(value, Enum): + return value.value + if hasattr(value, "get_secret_value"): + displayed = getattr(value, "display", None) + if callable(displayed): + return displayed() + return str(value) + return value + + +def _build_config_panel(title: str, data: dict[str, Any]) -> Panel: + tree = Tree(Text(title, style="bold cyan")) + if data: + for key in sorted(data.keys()): + branch = tree.add(Text(key, style="bold magenta")) + branch.add(Pretty(data[key], expand_all=True)) + else: + tree.add(Text("", style="dim")) + return Panel(tree, border_style="cyan", padding=(0, 1)) + + +def _build_section_panel(name: str, data: Any) -> Panel: + tree = Tree(Text(name, style="bold cyan")) + tree.add(Pretty(data, expand_all=True)) + return Panel(tree, border_style="cyan", padding=(0, 1)) + + +def _render_panel(renderable: Panel, *, width: int = 100) -> str: + buffer = StringIO() + console = Console(file=buffer, record=True, force_terminal=True, width=width, color_system=None) + console.print(renderable) + return buffer.getvalue().rstrip() + + +def _model_dict(model: BaseModel) -> dict[str, Any]: + data = model.model_dump(exclude_unset=False) + for key, value in list(data.items()): + if hasattr(value, "get_secret_value"): + data[key] = value.get_secret_value() + return data + + +def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[str, Any]: + persisted: dict[str, Any] = {} + for field_name, field in schema.model_fields.items(): + schema_extra = field.json_schema_extra or {} + annotation = field.annotation + persist = bool(schema_extra.get("persist")) if isinstance(schema_extra, dict) else False + if not persist: + continue + if field_name not in data: + continue + value = data[field_name] + if value is None: + persisted[field_name] = None + continue + + nested_type = _resolve_model_type(annotation) + if nested_type is not None: + nested_source = value if isinstance(value, dict) else {} + nested_persisted = _extract_persisted(nested_type, nested_source) + if nested_persisted: + persisted[field_name] = nested_persisted + continue + + if hasattr(value, "get_secret_value"): + persisted[field_name] = value.get_secret_value() + else: + persisted[field_name] = deepcopy(value) + + return persisted + + +__all__ = [ + "ConfigManager", + "PluginsAccessor", + 
"ProfilesAccessor", + "SectionAccessor", + "normalize_profile_name", +] diff --git a/tidy3d/_common/config/profiles.py b/tidy3d/_common/config/profiles.py new file mode 100644 index 0000000000..f73f1be562 --- /dev/null +++ b/tidy3d/_common/config/profiles.py @@ -0,0 +1,57 @@ +"""Built-in configuration profiles for tidy3d.""" + +from __future__ import annotations + +from typing import Any + +BUILTIN_PROFILES: dict[str, dict[str, Any]] = { + "default": { + "web": { + "api_endpoint": "https://tidy3d-api.simulation.cloud", + "website_endpoint": "https://tidy3d.simulation.cloud", + "s3_region": "us-gov-west-1", + } + }, + "prod": { + "web": { + "api_endpoint": "https://tidy3d-api.simulation.cloud", + "website_endpoint": "https://tidy3d.simulation.cloud", + "s3_region": "us-gov-west-1", + } + }, + "dev": { + "web": { + "api_endpoint": "https://tidy3d-api.dev-simulation.cloud", + "website_endpoint": "https://tidy3d.dev-simulation.cloud", + "s3_region": "us-east-1", + } + }, + "uat": { + "web": { + "api_endpoint": "https://tidy3d-api.uat-simulation.cloud", + "website_endpoint": "https://tidy3d.uat-simulation.cloud", + "s3_region": "us-west-2", + } + }, + "pre": { + "web": { + "api_endpoint": "https://preprod-tidy3d-api.simulation.cloud", + "website_endpoint": "https://preprod-tidy3d.simulation.cloud", + "s3_region": "us-gov-west-1", + } + }, + "nexus": { + "web": { + "api_endpoint": "http://127.0.0.1:5000", + "website_endpoint": "http://127.0.0.1/tidy3d", + "ssl_verify": False, + "enable_caching": False, + "s3_region": "us-east-1", + "env_vars": { + "AWS_ENDPOINT_URL_S3": "http://127.0.0.1:9000", + }, + } + }, +} + +__all__ = ["BUILTIN_PROFILES"] diff --git a/tidy3d/_common/config/registry.py b/tidy3d/_common/config/registry.py new file mode 100644 index 0000000000..7c1b16b7a1 --- /dev/null +++ b/tidy3d/_common/config/registry.py @@ -0,0 +1,83 @@ +"""Registry utilities for tidy3d configuration sections and handlers.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from pydantic import BaseModel + +if TYPE_CHECKING: + from typing import Callable, Optional + +T = TypeVar("T", bound=BaseModel) + +_SECTIONS: dict[str, type[BaseModel]] = {} +_HANDLERS: dict[str, Callable[[BaseModel], None]] = {} +_MANAGER: Optional[ConfigManagerProtocol] = None + + +class ConfigManagerProtocol: + """Protocol-like interface for manager notifications.""" + + def on_section_registered(self, section: str) -> None: + """Called when a new section schema is registered.""" + + def on_handler_registered(self, section: str) -> None: + """Called when a handler is registered.""" + + +def attach_manager(manager: ConfigManagerProtocol) -> None: + """Attach the active configuration manager for registry callbacks.""" + + global _MANAGER + _MANAGER = manager + + +def get_manager() -> Optional[ConfigManagerProtocol]: + """Return the currently attached configuration manager, if any.""" + + return _MANAGER + + +def register_section(name: str) -> Callable[[type[T]], type[T]]: + """Decorator to register a configuration section schema.""" + + def decorator(cls: type[T]) -> type[T]: + _SECTIONS[name] = cls + if _MANAGER is not None: + _MANAGER.on_section_registered(name) + return cls + + return decorator + + +def register_plugin(name: str) -> Callable[[type[T]], type[T]]: + """Decorator to register a plugin configuration schema.""" + + return register_section(f"plugins.{name}") + + +def register_handler( + name: str, +) -> Callable[[Callable[[BaseModel], None]], Callable[[BaseModel], None]]: + 
"""Decorator to register a handler for a configuration section.""" + + def decorator(func: Callable[[BaseModel], None]) -> Callable[[BaseModel], None]: + _HANDLERS[name] = func + if _MANAGER is not None: + _MANAGER.on_handler_registered(name) + return func + + return decorator + + +def get_sections() -> dict[str, type[BaseModel]]: + """Return registered section schemas.""" + + return dict(_SECTIONS) + + +def get_handlers() -> dict[str, Callable[[BaseModel], None]]: + """Return registered configuration handlers.""" + + return dict(_HANDLERS) diff --git a/tidy3d/_common/config/serializer.py b/tidy3d/_common/config/serializer.py new file mode 100644 index 0000000000..5db5dc5d97 --- /dev/null +++ b/tidy3d/_common/config/serializer.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, get_args, get_origin + +import tomlkit +from pydantic import BaseModel +from tomlkit.items import Item, Table + +from .registry import get_sections + +if TYPE_CHECKING: + from collections.abc import Iterable + + from pydantic.fields import FieldInfo + +Path = tuple[str, ...] + + +def collect_descriptions() -> dict[Path, str]: + """Collect description strings for registered configuration fields.""" + + descriptions: dict[Path, str] = {} + for section_name, model in get_sections().items(): + base_path = tuple(segment for segment in section_name.split(".") if segment) + section_doc = (model.__doc__ or "").strip() + if section_doc and base_path: + descriptions[base_path] = descriptions.get( + base_path, section_doc.splitlines()[0].strip() + ) + for field_name, field in model.model_fields.items(): + descriptions.update(_describe_field(field, prefix=(*base_path, field_name))) + return descriptions + + +def _describe_field(field: FieldInfo, prefix: Path) -> dict[Path, str]: + descriptions: dict[Path, str] = {} + description = (field.description or "").strip() + if description: + descriptions[prefix] = description + + nested_models: Iterable[type[BaseModel]] = _iter_model_types(field.annotation) + for model in nested_models: + nested_doc = (model.__doc__ or "").strip() + if nested_doc: + descriptions[prefix] = descriptions.get(prefix, nested_doc.splitlines()[0].strip()) + for sub_name, sub_field in model.model_fields.items(): + descriptions.update(_describe_field(sub_field, prefix=(*prefix, sub_name))) + return descriptions + + +def _iter_model_types(annotation: Any) -> Iterable[type[BaseModel]]: + """Yield BaseModel subclasses referenced by a field annotation (if any).""" + + if annotation is None: + return + + stack = [annotation] + seen: set[type[BaseModel]] = set() + + while stack: + current = stack.pop() + if isinstance(current, type) and issubclass(current, BaseModel): + if current not in seen: + seen.add(current) + yield current + continue + + origin = get_origin(current) + if origin is None: + continue + + stack.extend(get_args(current)) + + +def build_document( + data: dict[str, Any], + existing: tomlkit.TOMLDocument | None, + descriptions: dict[Path, str] | None = None, +) -> tomlkit.TOMLDocument: + """Return a TOML document populated with data and annotated comments.""" + + descriptions = descriptions or collect_descriptions() + document = existing if existing is not None else tomlkit.document() + _prune_missing_keys(document, data.keys()) + for key, value in data.items(): + _apply_value( + container=document, + key=key, + value=value, + path=(key,), + descriptions=descriptions, + is_new=key not in document, + ) + return document + + +def 
_prune_missing_keys(container: Table | tomlkit.TOMLDocument, keys: Iterable[str]) -> None:
+    desired = set(keys)
+    for existing_key in list(container.keys()):
+        if existing_key not in desired:
+            del container[existing_key]
+
+
+def _apply_value(
+    container: Table | tomlkit.TOMLDocument,
+    key: str,
+    value: Any,
+    path: Path,
+    descriptions: dict[Path, str],
+    is_new: bool,
+) -> None:
+    description = descriptions.get(path)
+    if isinstance(value, dict):
+        existing = container.get(key)
+        table = existing if isinstance(existing, Table) else tomlkit.table()
+        _prune_missing_keys(table, value.keys())
+        for sub_key, sub_value in value.items():
+            _apply_value(
+                container=table,
+                key=sub_key,
+                value=sub_value,
+                path=(*path, sub_key),
+                descriptions=descriptions,
+                is_new=not isinstance(existing, Table) or sub_key not in table,
+            )
+        if key in container:
+            container[key] = table
+        else:
+            if isinstance(container, tomlkit.TOMLDocument) and len(container) > 0:
+                container.add(tomlkit.nl())
+            container.add(key, table)
+        return
+
+    if value is None:
+        return
+
+    existing_item = container.get(key)
+    new_item = tomlkit.item(value)
+    if isinstance(existing_item, Item):
+        new_item.trivia.comment = existing_item.trivia.comment
+        new_item.trivia.comment_ws = existing_item.trivia.comment_ws
+    elif description:
+        new_item.comment(description)
+
+    if key in container:
+        container[key] = new_item
+    else:
+        container.add(key, new_item)
diff --git a/tidy3d/_common/constants.py b/tidy3d/_common/constants.py
new file mode 100644
index 0000000000..81b168cad5
--- /dev/null
+++ b/tidy3d/_common/constants.py
@@ -0,0 +1,313 @@
+"""Defines importable constants.
+
+Attributes:
+    inf (float): Tidy3d representation of infinity.
+    C_0 (float): Speed of light in vacuum [um/s]
+    EPSILON_0 (float): Vacuum permittivity [F/um]
+    MU_0 (float): Vacuum permeability [H/um]
+    ETA_0 (float): Vacuum impedance
+    HBAR (float): reduced Planck constant [eV*s]
+    Q_e (float): fundamental charge [C]
+"""
+
+from __future__ import annotations
+
+from types import MappingProxyType
+
+import numpy as np
+
+# fundamental constants (https://physics.nist.gov)
+C_0 = 2.99792458e14
+"""
+Speed of light in vacuum [um/s]
+"""
+
+MU_0 = 1.25663706212e-12
+"""
+Vacuum permeability [H/um]
+"""
+
+EPSILON_0 = 1 / (MU_0 * C_0**2)
+"""
+Vacuum permittivity [F/um]
+"""
+
+#: Free space impedance
+ETA_0 = np.sqrt(MU_0 / EPSILON_0)
+"""
+Vacuum impedance in Ohms
+"""
+
+Q_e = 1.602176634e-19
+"""
+Fundamental charge [C]
+"""
+
+HBAR = 6.582119569e-16
+"""
+Reduced Planck constant [eV*s]
+"""
+
+K_B = 8.617333262e-5
+"""
+Boltzmann constant [eV/K]
+"""
+
+GRAV_ACC = 9.80665 * 1e6
+"""
+Gravitational acceleration (g) [um/s^2].
+"""
+
+M_E_C_SQUARE = 0.51099895069e6
+"""
+Electron rest mass energy (m_e * c^2) [eV]
+"""
+
+M_E_EV = M_E_C_SQUARE / C_0**2
+"""
+Electron mass [eV*s^2/um^2]
+"""
+
+# floating point precisions
+dp_eps = np.finfo(np.float64).eps
+"""
+Double floating point precision.
+"""
+
+fp_eps = np.float64(np.finfo(np.float32).eps)
+"""
+Floating point precision.
+"""
+
+# values of PEC for mode solver
+pec_val = -1e8
+"""
+PEC value for the mode solver.
+"""
+
+# unit labels
+HERTZ = "Hz"
+"""
+One cycle per second.
+"""
+
+TERAHERTZ = "THz"
+"""
+One trillion (10^12) cycles per second.
+"""
+
+SECOND = "sec"
+"""
+SI unit of time.
+"""
+
+PICOSECOND = "ps"
+"""
+One trillionth (10^-12) of a second.
+"""
+
+METER = "m"
+"""
+SI unit of length.
+"""
+
+PERMETER = "1/m"
+"""
+SI unit of inverse length.
+"""
+
+MICROMETER = "um"
+"""
+One millionth (10^-6) of a meter.
+"""
+
+NANOMETER = "nm"
+"""
+One billionth (10^-9) of a meter.
+"""
+
+RADIAN = "rad"
+"""
+SI unit of angle.
+"""
+
+CONDUCTIVITY = "S/um"
+"""
+Siemens per micrometer.
+"""
+
+PERMITTIVITY = "None (relative permittivity)"
+"""
+Relative permittivity.
+"""
+
+PML_SIGMA = "2*EPSILON_0/dt"
+"""
+Two times the vacuum permittivity divided by the time step.
+"""
+
+RADPERSEC = "rad/sec"
+"""
+One radian per second.
+"""
+
+RADPERMETER = "rad/m"
+"""
+One radian per meter.
+"""
+
+NEPERPERMETER = "Np/m"
+"""
+SI unit for attenuation constant.
+"""
+
+
+ELECTRON_VOLT = "eV"
+"""
+Unit of energy.
+"""
+
+KELVIN = "K"
+"""
+SI unit of temperature.
+"""
+
+CMCUBE = "cm^3"
+"""
+Cubic centimeter unit of volume.
+"""
+
+PERCMCUBE = "1/cm^3"
+"""
+Per cubic centimeter.
+"""
+
+WATT = "W"
+"""
+SI unit of power.
+"""
+
+VOLT = "V"
+"""
+SI unit of electric potential.
+"""
+
+PICOSECOND_PER_NANOMETER_PER_KILOMETER = "ps/(nm km)"
+"""
+Picosecond per (nanometer kilometer).
+"""
+
+OHM = "ohm"
+"""
+SI unit of resistance.
+"""
+
+FARAD = "farad"
+"""
+SI unit of capacitance.
+"""
+
+HENRY = "henry"
+"""
+SI unit of inductance.
+"""
+
+AMP = "A"
+"""
+SI unit of electric current.
+"""
+
+THERMAL_CONDUCTIVITY = "W/(um*K)"
+"""
+Watts per (micrometer Kelvin).
+"""
+
+SPECIFIC_HEAT_CAPACITY = "J/(kg*K)"
+"""
+Joules per (kilogram Kelvin).
+"""
+
+DENSITY = "kg/um^3"
+"""
+Kilograms per cubic micrometer.
+"""
+
+HEAT_FLUX = "W/um^2"
+"""
+Watts per square micrometer.
+"""
+
+VOLUMETRIC_HEAT_RATE = "W/um^3"
+"""
+Watts per cubic micrometer.
+"""
+
+HEAT_TRANSFER_COEFF = "W/(um^2*K)"
+"""
+Watts per (square micrometer Kelvin).
+"""
+
+CURRENT_DENSITY = "A/um^2"
+"""
+Amperes per square micrometer.
+"""
+
+DYNAMIC_VISCOSITY = "kg/(um*s)"
+"""
+Kilograms per (micrometer second).
+"""
+
+SPECIFIC_HEAT = "um^2/(s^2*K)"
+"""
+Square micrometers per (square second Kelvin).
+"""
+
+THERMAL_EXPANSIVITY = "1/K"
+"""
+Inverse Kelvin.
+"""
+
+VELOCITY_SI = "m/s"
+"""
+SI unit of velocity.
+"""
+
+ACCELERATION = "um/s^2"
+"""
+Micrometers per square second.
+"""
+
+LARGE_NUMBER = 1e10
+"""
+Large number used in comparisons with infinity.
+"""
+
+LARGEST_FP_NUMBER = 1e38
+"""
+Largest number used for single-precision floating point.
+"""
+
+inf = np.inf
+"""
+Representation of infinity used within tidy3d.
+"""
+
+# if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning
+GLANCING_CUTOFF = 0.1
+"""
+If |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning.
+"""
+
+UnitScaling = MappingProxyType(
+    {
+        "nm": 1e3,
+        "μm": 1e0,
+        "um": 1e0,
+        "mm": 1e-3,
+        "cm": 1e-4,
+        "m": 1e-6,
+        "mil": 1.0 / 25.4,
+        "in": 1.0 / 25400,
+    }
+)
+"""Immutable dictionary for converting microns to another spatial unit, e.g. nm = um * UnitScaling["nm"]."""
nm = um * UnitScaling["nm"].""" diff --git a/tidy3d/_common/exceptions.py b/tidy3d/_common/exceptions.py new file mode 100644 index 0000000000..f92fdf5aaf --- /dev/null +++ b/tidy3d/_common/exceptions.py @@ -0,0 +1,64 @@ +"""Custom Tidy3D exceptions""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tidy3d._common.log import log + +if TYPE_CHECKING: + from typing import Optional + + +class Tidy3dError(ValueError): + """Any error in tidy3d""" + + def __init__(self, message: Optional[str] = None, log_error: bool = True) -> None: + """Log just the error message and then raise the Exception.""" + super().__init__(message) + if log_error: + log.error(message) + + +class ConfigError(Tidy3dError): + """Error when configuring Tidy3d.""" + + +class Tidy3dKeyError(Tidy3dError): + """Could not find a key in a Tidy3d dictionary.""" + + +class ValidationError(Tidy3dError): + """Error when constructing Tidy3d components.""" + + +class SetupError(Tidy3dError): + """Error regarding the setup of the components (outside of domains, etc).""" + + +class FileError(Tidy3dError): + """Error reading or writing to file.""" + + +class WebError(Tidy3dError): + """Error with the webAPI.""" + + +class AuthenticationError(Tidy3dError): + """Error authenticating a user through webapi webAPI.""" + + +class DataError(Tidy3dError): + """Error accessing data.""" + + +class Tidy3dImportError(Tidy3dError): + """Error importing a package needed for tidy3d.""" + + +class Tidy3dNotImplementedError(Tidy3dError): + """Error when a functionality is not (yet) supported.""" + + +class AdjointError(Tidy3dError): + """An error in setting up the adjoint solver.""" diff --git a/tidy3d/_common/log.py b/tidy3d/_common/log.py new file mode 100644 index 0000000000..290ae8b0e5 --- /dev/null +++ b/tidy3d/_common/log.py @@ -0,0 +1,520 @@ +"""Logging Configuration for Tidy3d.""" + +from __future__ import annotations + +import inspect +from contextlib import contextmanager +from datetime import datetime +from typing import TYPE_CHECKING, Any, Literal, Union + +from rich.console import Console +from rich.text import Text + +if TYPE_CHECKING: + from collections.abc import Iterator + from os import PathLike + from types import TracebackType + from typing import Callable, Optional + + from pydantic import BaseModel + from rich.progress import Progress as RichProgress + + from tidy3d._common.compat import Self +# Note: "SUPPORT" and "USER" levels are meant for backend runs only. +# Logging in frontend code should just use the standard debug/info/warning/error/critical. +LogLevel = Literal["DEBUG", "SUPPORT", "USER", "INFO", "WARNING", "ERROR", "CRITICAL"] +LogValue = Union[int, LogLevel] + +# Logging levels compatible with logging module +_level_value = { + "DEBUG": 10, + "SUPPORT": 12, + "USER": 15, + "INFO": 20, + "WARNING": 30, + "ERROR": 40, + "CRITICAL": 50, +} + +_level_name = {v: k for k, v in _level_value.items()} + +DEFAULT_LEVEL = "WARNING" + +DEFAULT_LOG_STYLES = { + "DEBUG": None, + "SUPPORT": None, + "USER": None, + "INFO": None, + "WARNING": "red", + "ERROR": "red bold", + "CRITICAL": "red bold", +} + +# Width of the console used for rich logging (in characters). 
+CONSOLE_WIDTH = 80
+
+
+def _default_log_level_format(level: str, message: str) -> tuple[str, str]:
+    """By default just return unformatted prefix and message."""
+    return level, message
+
+
+def _get_level_int(level: LogValue) -> int:
+    """Get the integer corresponding to the level string."""
+    if isinstance(level, int):
+        return level
+
+    if level not in _level_value:
+        # We don't want to import ConfigError to avoid a circular dependency
+        raise ValueError(
+            f"logging level {level} not supported, must be "
+            "'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', or 'CRITICAL'"
+        )
+    return _level_value[level]
+
+
+class LogHandler:
+    """Handle log messages depending on log level"""
+
+    def __init__(
+        self,
+        console: Console,
+        level: LogValue,
+        log_level_format: Callable = _default_log_level_format,
+        prefix_every_line: bool = False,
+    ) -> None:
+        self.level = _get_level_int(level)
+        self.console = console
+        self.log_level_format = log_level_format
+        self.prefix_every_line = prefix_every_line
+
+    def handle(self, level: int, level_name: str, message: str) -> None:
+        """Output log messages depending on log level"""
+        if level >= self.level:
+            stack = inspect.stack()
+            console = self.console
+            offset = 4
+            if stack[offset - 1].filename.endswith("exceptions.py"):
+                # We want the calling site for exceptions.py
+                offset += 1
+            prefix, msg = self.log_level_format(level_name, message)
+            if self.prefix_every_line:
+                wrapped_text = Text(msg, style="default")
+                msgs = wrapped_text.wrap(console=console, width=console.width - len(prefix) - 2)
+            else:
+                msgs = [msg]
+            for msg in msgs:
+                console.log(
+                    prefix,
+                    msg,
+                    sep=": ",
+                    style=DEFAULT_LOG_STYLES[level_name],
+                    _stack_offset=offset,
+                )
+
+
+class Logger:
+    """Custom logger to avoid the complexities of the logging module.
+
+    Notes
+    -----
+    The logger can be used in a context manager to avoid the emission of multiple messages. In this
+    case, the first message in the context is emitted normally, but any others are discarded. When
+    the context is exited, the number of discarded messages of each level is displayed with the
+    highest level of the captured messages.
+
+    Messages can also be captured for post-processing. That can be enabled through 'set_capture' to
+    record warnings emitted during model validation (and other explicit begin/end capture regions,
+    e.g. validation routines like ``validate_pre_upload``). A structured copy of captured warnings
+    can then be recovered through 'captured_warnings'.
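+
+    Example
+    -------
+    A minimal capture sketch (``model`` stands for any tidy3d component being
+    validated, and the warning text is illustrative)::
+
+        log.set_capture(True)
+        log.begin_capture()
+        log.warning("example warning")
+        log.end_capture(model)
+        warnings = log.captured_warnings()  # -> [{"loc": [...], "msg": "example warning"}]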
+ """ + + _static_cache = set() + + def __init__(self) -> None: + self.handlers = {} + self.suppression = True + self.warn_once = False + self._counts = None + self._stack = None + self._capture = False + self._captured_warnings = [] + + def set_capture(self, capture: bool) -> None: + """Turn on/off tree-like capturing of log messages.""" + self._capture = capture + + def captured_warnings(self) -> list[dict[str, Any]]: + """Get the formatted list of captured log messages.""" + captured_warnings = self._captured_warnings + self._captured_warnings = [] + return captured_warnings + + def __enter__(self) -> Self: + """If suppression is enabled, enter a consolidation context (only a single message is + emitted).""" + if self.suppression and self._counts is None: + self._counts = {} + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Literal[False]: + """Exist a consolidation context (report the number of messages discarded).""" + if self._counts is not None: + total = sum(v for v in self._counts.values()) + if total > 0: + max_level = max(k for k, v in self._counts.items() if v > 0) + counts = [f"{v} {_level_name[k]}" for k, v in self._counts.items() if v > 0] + self._counts = None + if total > 0: + noun = " messages." if total > 1 else " message." + # Temporarily prevent capturing messages to emit consolidated summary + stack = self._stack + self._stack = None + self.log(max_level, "Suppressed " + ", ".join(counts) + noun) + self._stack = stack + return False + + def begin_capture(self) -> None: + """Start capturing log stack for consolidated validation log. + + This method should be called before a validation routine starts. It must be followed by a + corresponding 'end_capture'. + """ + if not self._capture: + return + + stack_item = {"messages": [], "children": {}} + if self._stack: + self._stack.append(stack_item) + else: + self._stack = [stack_item] + + def abort_capture(self) -> None: + """Undo the last ``begin_capture()`` call. + + This is used when validation fails before reaching the corresponding ``end_capture()``. + """ + if not self._stack: + return + + self._stack.pop() + if len(self._stack) == 0: + self._stack = None + + def end_capture(self, model: BaseModel) -> None: + """End capturing log stack for consolidated validation log. + + This method should be called after a validation routine ends. It must follow a + corresponding 'begin_capture'. + """ + if not self._stack: + return + + stack_item = self._stack.pop() + if len(self._stack) == 0: + self._stack = None + + # Check if this stack item contains any messages or children + if len(stack_item["messages"]) > 0 or len(stack_item["children"]) > 0: + stack_item["type"] = model.__class__.__name__ + + # Set the path for each children + model_fields = model.get_submodels_by_hash() + for child_hash, child_dict in stack_item["children"].items(): + child_dict["parent_fields"] = model_fields.get(child_hash, []) + + # Are we at the bottom of the stack? 
+ if self._stack is None: + # Yes, we're root + self._parse_warning_capture(current_loc=[], stack_item=stack_item) + else: + # No, we're someone else's child + hash_ = hash(model) + self._stack[-1]["children"][hash_] = stack_item + + def _parse_warning_capture(self, current_loc: list[Any], stack_item: dict[str, Any]) -> None: + """Process capture tree to compile formatted captured warnings.""" + + if "parent_fields" in stack_item: + for field in stack_item["parent_fields"]: + if isinstance(field, tuple): + # array field + new_loc = current_loc + list(field) + else: + # single field + new_loc = [*current_loc, field] + + # process current level warnings + for level, msg, custom_loc in stack_item["messages"]: + if level == "WARNING": + self._captured_warnings.append({"loc": new_loc + custom_loc, "msg": msg}) + + # initialize processing at children level + for child_stack in stack_item["children"].values(): + self._parse_warning_capture(current_loc=new_loc, stack_item=child_stack) + + else: # for root object + # process current level warnings + for level, msg, custom_loc in stack_item["messages"]: + if level == "WARNING": + self._captured_warnings.append({"loc": current_loc + custom_loc, "msg": msg}) + + # initialize processing at children level + for child_stack in stack_item["children"].values(): + self._parse_warning_capture(current_loc=current_loc, stack_item=child_stack) + + def _log( + self, + level: int, + level_name: str, + message: str, + *args: Any, + log_once: bool = False, + custom_loc: Optional[list] = None, + capture: bool = True, + ) -> None: + """Distribute log messages to all handlers""" + + # Check global cache if requested or if warn_once is enabled for warnings + # (before composing/capturing to avoid duplicates) + should_check_cache = log_once or (self.warn_once and level_name == "WARNING") + if should_check_cache: + # Use the message body before composition as key + if message in self._static_cache: + return + self._static_cache.add(message) + + # Compose message + if len(args) > 0: + try: + composed_message = str(message) % args + + except Exception as e: + composed_message = f"{message} % {args}\n{e}" + else: + composed_message = str(message) + + # Capture all messages (even if suppressed later) + if self._stack and capture: + if custom_loc is None: + custom_loc = [] + self._stack[-1]["messages"].append((level_name, composed_message, custom_loc)) + + # Context-local logger emits a single message and consolidates the rest + if self._counts is not None: + if len(self._counts) > 0: + self._counts[level] = 1 + self._counts.get(level, 0) + return + self._counts[level] = 0 + + # Forward message to handlers + for handler in self.handlers.values(): + handler.handle(level, level_name, composed_message) + + def log(self, level: LogValue, message: str, *args: Any, log_once: bool = False) -> None: + """Log (message) % (args) with given level""" + if isinstance(level, str): + level_name = level + level = _get_level_int(level) + else: + level_name = _level_name.get(level, "unknown") + self._log(level, level_name, message, *args, log_once=log_once) + + def debug(self, message: str, *args: Any, log_once: bool = False) -> None: + """Log (message) % (args) at debug level""" + self._log(_level_value["DEBUG"], "DEBUG", message, *args, log_once=log_once) + + def support(self, message: str, *args: Any, log_once: bool = False) -> None: + """Log (message) % (args) at support level""" + self._log(_level_value["SUPPORT"], "SUPPORT", message, *args, log_once=log_once) + + def user(self, message: 
str, *args: Any, log_once: bool = False) -> None:
+        """Log (message) % (args) at user level"""
+        self._log(_level_value["USER"], "USER", message, *args, log_once=log_once)
+
+    def info(self, message: str, *args: Any, log_once: bool = False) -> None:
+        """Log (message) % (args) at info level"""
+        self._log(_level_value["INFO"], "INFO", message, *args, log_once=log_once)
+
+    def warning(
+        self,
+        message: str,
+        *args: Any,
+        log_once: bool = False,
+        custom_loc: Optional[list] = None,
+        capture: bool = True,
+    ) -> None:
+        """Log (message) % (args) at warning level"""
+        self._log(
+            _level_value["WARNING"],
+            "WARNING",
+            message,
+            *args,
+            log_once=log_once,
+            custom_loc=custom_loc,
+            capture=capture,
+        )
+
+    def error(self, message: str, *args: Any, log_once: bool = False) -> None:
+        """Log (message) % (args) at error level"""
+        self._log(_level_value["ERROR"], "ERROR", message, *args, log_once=log_once)
+
+    def critical(self, message: str, *args: Any, log_once: bool = False) -> None:
+        """Log (message) % (args) at critical level"""
+        self._log(_level_value["CRITICAL"], "CRITICAL", message, *args, log_once=log_once)
+
+
+def set_logging_level(level: LogValue = DEFAULT_LEVEL) -> None:
+    """Set tidy3d console logging level priority.
+
+    Parameters
+    ----------
+    level : str
+        The lowest priority level of logging messages to display. One of ``{'DEBUG', 'SUPPORT',
+        'USER', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}`` (listed in increasing priority).
+    """
+    if "console" in log.handlers:
+        log.handlers["console"].level = _get_level_int(level)
+
+
+def set_log_suppression(value: bool) -> None:
+    """Control log suppression for repeated messages."""
+    log.suppression = value
+
+
+def set_warn_once(value: bool) -> None:
+    """Control whether warnings are only shown once per unique message.
+
+    Parameters
+    ----------
+    value : bool
+        When True, each unique warning message is only shown once per process.
+    """
+    log.warn_once = value
+
+
+def get_aware_datetime() -> datetime:
+    """Get an aware current local datetime (with local timezone info)."""
+    return datetime.now().astimezone()
+
+
+def set_logging_console(stderr: bool = False) -> None:
+    """Set stdout or stderr as console output.
+
+    Parameters
+    ----------
+    stderr : bool
+        If False, logs are directed to stdout, otherwise to stderr.
+    """
+    if "console" in log.handlers:
+        previous_level = log.handlers["console"].level
+    else:
+        previous_level = DEFAULT_LEVEL
+    log.handlers["console"] = LogHandler(
+        Console(
+            stderr=stderr,
+            width=CONSOLE_WIDTH,
+            log_path=False,
+            get_datetime=get_aware_datetime,
+            log_time_format="%X %Z",
+        ),
+        previous_level,
+    )
+
+
+def set_logging_file(
+    fname: PathLike,
+    filemode: str = "w",
+    level: LogValue = DEFAULT_LEVEL,
+    log_path: bool = False,
+) -> None:
+    """Set a file to write log to, independently from the stdout and stderr
+    output chosen using :meth:`set_logging_level`.
+
+    Parameters
+    ----------
+    fname : PathLike
+        Path to file to direct the output to. If empty string, a previously set logging file will
+        be closed, if any, but nothing else happens.
+    filemode : str
+        'w' or 'a', defining if the file should be overwritten or appended.
+    level : str
+        One of ``{'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}``. This is set
+        for the file independently of the console output level set by :meth:`set_logging_level`.
+    log_path : bool = False
+        Whether to log the path to the file that issued the message.
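+
+    Example
+    -------
+    A minimal sketch (the file name is illustrative)::
+
+        set_logging_file("tidy3d.log", filemode="a", level="INFO")
+        set_logging_file("")  # close the file handler again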
+ """ + if filemode not in "wa": + raise ValueError("filemode must be either 'w' or 'a'") + + # Close previous handler, if any + if "file" in log.handlers: + try: + log.handlers["file"].console.file.close() + except Exception: # TODO: catch specific exception + log.warning("Log file could not be closed") + finally: + del log.handlers["file"] + + if str(fname) == "": + # Empty string can be passed to just stop previously opened file handler + return + + try: + file = open(fname, filemode) + except Exception: # TODO: catch specific exception + log.error(f"File {fname} could not be opened") + return + + log.handlers["file"] = LogHandler( + Console(file=file, force_jupyter=False, log_path=log_path), level + ) + + +# Initialize Tidy3d's logger +log = Logger() + +# Set default logging output +set_logging_console() + + +def get_logging_console() -> Console: + """Get console from logging handlers.""" + if "console" not in log.handlers: + set_logging_console() + return log.handlers["console"].console + + +class NoOpProgress: + """Dummy progress manager that doesn't show any output.""" + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + pass + + def add_task(self, *args: Any, **kwargs: Any) -> None: + pass + + def update(self, *args: Any, **kwargs: Any) -> None: + pass + + +@contextmanager +def Progress(console: Console, show_progress: bool) -> Iterator[Union[RichProgress, NoOpProgress]]: + """Progress manager that wraps ``rich.Progress`` if ``show_progress`` is ``True``, + and ``NoOpProgress`` otherwise.""" + if show_progress: + from rich.progress import Progress + + with Progress(console=console) as progress: + yield progress + else: + with NoOpProgress() as progress: + yield progress diff --git a/tidy3d/_common/packaging.py b/tidy3d/_common/packaging.py new file mode 100644 index 0000000000..497e2ae0b0 --- /dev/null +++ b/tidy3d/_common/packaging.py @@ -0,0 +1,284 @@ +""" +This file contains a set of functions relating to packaging tidy3d for distribution. Sections of the codebase should depend on this file, but this file should not depend on any other part of the codebase. + +This section should only depend on the standard core installation in the pyproject.toml, and should not depend on any other part of the codebase optional imports. +""" + +from __future__ import annotations + +import functools +from importlib import import_module +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +import numpy as np + +from tidy3d._common.exceptions import Tidy3dImportError +from tidy3d.version import __version__ # TODO solve + +if TYPE_CHECKING: + from typing import Literal + +F = TypeVar("F", bound=Callable[..., Any]) + +vtk = { + "mod": None, + "id_type": np.int64, + "vtk_to_numpy": None, + "numpy_to_vtkIdTypeArray": None, + "numpy_to_vtk": None, +} + + +def check_import(module_name: str) -> bool: + """ + Check if a module or submodule section has been imported. This is a functional way of loading packages that will still load the corresponding module into the total space. + + Parameters + ---------- + module_name + + Returns + ------- + bool + True if the module has been imported, False otherwise. 
+ + """ + try: + import_module(module_name) + return True + except ImportError: + return False + + +def verify_packages_import( + modules: list[str], required: Literal["any", "all"] = "all" +) -> Callable[[F], F]: + def decorator(func: F) -> F: + """ + When decorating a method, requires that the specified modules are available. It will raise an error if the + module is not available depending on the value of the 'required' parameter which represents the type of + import required. + + There are a few options to choose for the 'required' parameter: + - 'all': All the modules must be available for the operation to continue without raising an error + - 'any': At least one of the modules must be available for the operation to continue without raising an error + + Parameters + ---------- + func + The function to decorate. + + Returns + ------- + checks_modules_import + The decorated function. + + """ + + @functools.wraps(func) + def checks_modules_import(*args: Any, **kwargs: Any) -> Any: + """ + Checks if the modules are available. If they are not available, it will raise an error depending on the value. + """ + available_modules_status = [] + maximum_amount_modules = len(modules) + + module_id_i = 0 + for module in modules: + # Starts counting from one so that it can be compared to len(modules) + module_id_i += 1 + import_available = check_import(module) + available_modules_status.append( + import_available + ) # Stores the status of the module import + + if not import_available: + if required == "all": + raise Tidy3dImportError( + f"The package '{module}' is required for this operation, but it was not found. " + f"Please install the '{module}' dependencies using, for example, " + f"'pip install tidy3d[]" + ) + if required == "any": + # Means we need to verify that at least one of the modules is available + if ( + not any(available_modules_status) + ) and module_id_i == maximum_amount_modules: + # Means that we have reached the last module and none of them were available + raise Tidy3dImportError( + f"The package '{module}' is required for this operation, but it was not found. " + f"Please install the '{module}' dependencies using, for example, " + f"'pip install tidy3d[]" + ) + else: + raise ValueError( + f"The value '{required}' is not a valid value for the 'required' parameter. " + f"Please use any 'all' or 'any'." + ) + else: + # Means that the module is available, so we can just continue with the operation + pass + return func(*args, **kwargs) + + return checks_modules_import + + return decorator + + +def requires_vtk(fn: F) -> F: + """When decorating a method, requires that vtk is available.""" + + @functools.wraps(fn) + def _fn(*args: Any, **kwargs: Any) -> Any: + if vtk["mod"] is None: + try: + import vtk as vtk_mod + from vtk.util.numpy_support import ( + numpy_to_vtk, + numpy_to_vtkIdTypeArray, + vtk_to_numpy, + ) + from vtkmodules.vtkCommonCore import vtkLogger + + vtk["mod"] = vtk_mod + vtk["vtk_to_numpy"] = vtk_to_numpy + vtk["numpy_to_vtkIdTypeArray"] = numpy_to_vtkIdTypeArray + vtk["numpy_to_vtk"] = numpy_to_vtk + + vtkLogger.SetStderrVerbosity(vtkLogger.VERBOSITY_WARNING) + + if vtk["mod"].vtkIdTypeArray().GetDataTypeSize() == 4: + vtk["id_type"] = np.int32 + + except ImportError as exc: + raise Tidy3dImportError( + "The package 'vtk' is required for this operation, but it was not found. " + "Please install the 'vtk' dependencies using, for example, " + "'pip install .[vtk]'." 
+                ) from exc
+
+        return fn(*args, **kwargs)
+
+    return _fn
+
+
+def get_numpy_major_version(module: Any = np) -> int:
+    """
+    Extracts the major version of the installed numpy (or a numpy-like module).
+
+    Parameters
+    ----------
+    module : module
+        The module to extract the version from. Default is numpy.
+
+    Returns
+    -------
+    int
+        The major version of the module.
+    """
+    # Get the version of the module
+    module_version = module.__version__
+
+    # Extract the major version number
+    major_version = int(module_version.split(".")[0])
+
+    return major_version
+
+
+tidy3d_extras = {"mod": None, "use_local_subpixel": None}
+
+
+def _check_tidy3d_extras_available(quiet: bool = False) -> None:
+    """Helper function to check if 'tidy3d-extras' is available and version matched.
+
+    Parameters
+    ----------
+    quiet : bool
+        If True, suppress error logging when raising exceptions.
+
+    Raises
+    ------
+    Tidy3dImportError
+        If tidy3d-extras is not available or not properly initialized.
+    """
+    if tidy3d_extras["mod"] is not None:
+        return
+
+    module_exists = find_spec("tidy3d_extras") is not None
+    if not module_exists:
+        raise Tidy3dImportError(
+            "The package 'tidy3d-extras' is absent. "
+            "Please install the 'tidy3d-extras' package using, for "
+            r"example, 'pip install tidy3d\[extras]'.",
+            log_error=not quiet,
+        )
+
+    try:
+        import tidy3d_extras as tidy3d_extras_mod
+
+    except ImportError as exc:
+        raise Tidy3dImportError(
+            "The package 'tidy3d-extras' did not initialize correctly.",
+            log_error=not quiet,
+        ) from exc
+
+    if not hasattr(tidy3d_extras_mod, "__version__"):
+        raise Tidy3dImportError(
+            "The package 'tidy3d-extras' did not initialize correctly. "
+            "Please install the 'tidy3d-extras' package using, for "
+            r"example, 'pip install tidy3d\[extras]'.",
+            log_error=not quiet,
+        )
+
+    version = tidy3d_extras_mod.__version__
+
+    if version is None:
+        raise Tidy3dImportError(
+            "The package 'tidy3d-extras' did not initialize correctly, "
+            "likely due to an invalid API key.",
+            log_error=not quiet,
+        )
+
+    if version != __version__:
+        raise Tidy3dImportError(
+            f"The version of 'tidy3d-extras' is {version}, but the version of 'tidy3d' is {__version__}. "
+            "They must match. You can install the correct "
+            r"version using 'pip install tidy3d\[extras]'.",
+            log_error=not quiet,
+        )
+
+    tidy3d_extras["mod"] = tidy3d_extras_mod
+
+
+def check_tidy3d_extras_licensed_feature(feature_name: str, quiet: bool = False) -> None:
+    """Helper function to check if a specific feature is licensed in 'tidy3d-extras'.
+
+    Parameters
+    ----------
+    feature_name : str
+        The name of the feature to check for.
+    quiet : bool
+        If True, suppress error logging when raising exceptions.
+
+    Raises
+    ------
+    Tidy3dImportError
+        If the feature is not available with your license.
+    """
+
+    try:
+        _check_tidy3d_extras_available(quiet=quiet)
+    except Tidy3dImportError as exc:
+        raise Tidy3dImportError(
+            f"The package 'tidy3d-extras' is required for this feature '{feature_name}'.",
+            log_error=not quiet,
+        ) from exc
+
+    features = tidy3d_extras["mod"].extension._features()
+    if feature_name not in features:
+        raise Tidy3dImportError(
+            f"The feature '{feature_name}' is not available with your license. 
" + "Please contact Tidy3D support, or upgrade your license.", + log_error=not quiet, + ) diff --git a/tidy3d/_common/web/__init__.py b/tidy3d/_common/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tidy3d/_common/web/cache.py b/tidy3d/_common/web/cache.py new file mode 100644 index 0000000000..de7f07a46f --- /dev/null +++ b/tidy3d/_common/web/cache.py @@ -0,0 +1,884 @@ +"""Local simulation cache manager.""" + +from __future__ import annotations + +import hashlib +import json +import os +import shutil +import tempfile +import threading +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol + +from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt + +from tidy3d._common import config +from tidy3d._common.log import log +from tidy3d._common.web.core.http_util import get_version as _get_protocol_version + +if TYPE_CHECKING: + from collections.abc import Iterator + + from tidy3d._common.web.core.constants import TaskId + + +class CacheableSimulation(Protocol): + """Protocol for simulation objects that can be cached.""" + + def _hash_self(self) -> str: + """Return a stable hash for cache key construction.""" + + +_GetWorkflowType = Callable[[CacheableSimulation], str] +_get_workflow_type_callback: Optional[_GetWorkflowType] = None + + +def register_get_workflow_type(callback: _GetWorkflowType) -> None: + """Register workflow type resolver for cache logging.""" + global _get_workflow_type_callback + _get_workflow_type_callback = callback + + +def _get_workflow_type(simulation: CacheableSimulation) -> str: + if _get_workflow_type_callback is None: + return type(simulation).__name__ + try: + return _get_workflow_type_callback(simulation) + except Exception: + return type(simulation).__name__ + + +CACHE_ARTIFACT_NAME = "simulation_data.hdf5" +CACHE_METADATA_NAME = "metadata.json" +CACHE_STATS_NAME = "stats.json" + +TMP_PREFIX = "tidy3d-cache-" +TMP_BATCH_PREFIX = "tmp_batch" + +_CACHE: Optional[LocalCache] = None + + +def get_cache_entry_dir(root: os.PathLike, key: str) -> Path: + """ + Returns the cache directory for a given key. + A three-character prefix subdirectory is used to avoid hitting filesystem limits on the number of entries per folder. 
+ """ + return Path(root) / key[:3] / key + + +class CacheStats(BaseModel): + """Lightweight summary of cache usage persisted in ``stats.json``.""" + + last_used: dict[str, str] = Field( + default_factory=dict, + description="Mapping from cache entry key to the most recent ISO-8601 access timestamp.", + ) + total_size: NonNegativeInt = Field( + default=0, + description="Aggregate size in bytes across cached artifacts captured in the stats file.", + ) + updated_at: Optional[datetime] = Field( + default=None, + description="UTC timestamp indicating when the statistics were last refreshed.", + ) + + model_config = ConfigDict(extra="allow", validate_assignment=True) + + @property + def total_entries(self) -> int: + return len(self.last_used) + + +class CacheEntryMetadata(BaseModel): + """Schema for cache entry metadata persisted on disk.""" + + cache_key: str + checksum: str + created_at: datetime + last_used: datetime + file_size: int = Field(ge=0) + simulation_hash: str + workflow_type: str + versions: Any + task_id: str + path: str + + model_config = ConfigDict(extra="allow", validate_assignment=True) + + def bump_last_used(self) -> None: + self.last_used = datetime.now(timezone.utc) + + def as_dict(self) -> dict[str, Any]: + return self.model_dump(mode="json") + + def get(self, key: str, default: Any = None) -> Any: + return self.as_dict().get(key, default) + + def __getitem__(self, key: str) -> Any: + data = self.as_dict() + if key not in data: + raise KeyError(key) + return data[key] + + +@dataclass +class CacheEntry: + """Internal representation of a cache entry.""" + + key: str + root: Path + metadata: CacheEntryMetadata + + @property + def path(self) -> Path: + return get_cache_entry_dir(self.root, self.key) + + @property + def artifact_path(self) -> Path: + return self.path / CACHE_ARTIFACT_NAME + + @property + def metadata_path(self) -> Path: + return self.path / CACHE_METADATA_NAME + + def exists(self) -> bool: + return self.path.exists() and self.artifact_path.exists() and self.metadata_path.exists() + + def verify(self) -> bool: + if not self.exists(): + return False + checksum = self.metadata.checksum + if not checksum: + return False + try: + actual_checksum, file_size = _copy_and_hash(self.artifact_path, None) + except FileNotFoundError: + return False + if checksum != actual_checksum: + log.warning( + "Simulation cache checksum mismatch for key '%s'. 
Removing stale entry.", self.key + ) + return False + if self.metadata.file_size != file_size: + self.metadata.file_size = file_size + _write_metadata(self.metadata_path, self.metadata) + return True + + def materialize(self, target: Path) -> Path: + """Copy cached artifact to ``target`` and return the resulting path.""" + target = Path(target) + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(self.artifact_path, target) + return target + + +class LocalCache: + """Manages storing and retrieving cached simulation artifacts.""" + + def __init__(self, directory: os.PathLike, max_size_gb: float, max_entries: int) -> None: + self.max_size_gb = max_size_gb + self.max_entries = max_entries + self._root = Path(directory) + self._lock = threading.RLock() + self._syncing_stats = False + self._sync_pending = False + + @property + def _stats_path(self) -> Path: + return self._root / CACHE_STATS_NAME + + def _schedule_sync(self) -> None: + self._sync_pending = True + + def _run_pending_sync(self) -> None: + if self._sync_pending and not self._syncing_stats: + self._sync_pending = False + self.sync_stats() + + @contextmanager + def _with_lock(self) -> Iterator[None]: + self._run_pending_sync() + with self._lock: + yield + self._run_pending_sync() + + def _write_stats(self, stats: CacheStats) -> CacheStats: + updated = stats.model_copy(update={"updated_at": datetime.now(timezone.utc)}) + payload = updated.model_dump(mode="json") + payload["total_entries"] = updated.total_entries + self._stats_path.parent.mkdir(parents=True, exist_ok=True) + _write_metadata(self._stats_path, payload) + self._sync_pending = False + return updated + + def _load_stats(self, *, rebuild: bool = False) -> CacheStats: + path = self._stats_path + if not path.exists(): + if not self._syncing_stats: + self._schedule_sync() + return CacheStats() + try: + data = json.loads(path.read_text(encoding="utf-8")) + if "last_used" not in data and "entries" in data: + data["last_used"] = data.pop("entries") + stats = CacheStats.model_validate(data) + except Exception: + if rebuild and not self._syncing_stats: + self._schedule_sync() + return CacheStats() + if stats.total_size < 0: + self._schedule_sync() + return CacheStats() + return stats + + def _record_store_stats( + self, + key: str, + *, + last_used: str, + file_size: int, + previous_size: int, + ) -> None: + stats = self._load_stats() + entries = dict(stats.last_used) + entries[key] = last_used + total_size = stats.total_size - previous_size + file_size + if total_size < 0: + total_size = 0 + self._schedule_sync() + updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) + self._write_stats(updated) + + def _record_touch_stats( + self, key: str, last_used: str, *, file_size: Optional[int] = None + ) -> None: + stats = self._load_stats() + entries = dict(stats.last_used) + existed = key in entries + total_size = stats.total_size + if not existed and file_size is not None: + total_size += file_size + if total_size < 0: + total_size = 0 + self._schedule_sync() + entries[key] = last_used + updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) + self._write_stats(updated) + + def _record_remove_stats(self, key: str, file_size: int) -> None: + stats = self._load_stats() + entries = dict(stats.last_used) + entries.pop(key, None) + total_size = stats.total_size - file_size + if total_size < 0: + total_size = 0 + self._schedule_sync() + updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) + 
self._write_stats(updated) + + def _enforce_limits_post_sync(self, entries: list[CacheEntry]) -> None: + if not entries: + return + + entries_map = {entry.key: entry.metadata.last_used.isoformat() for entry in entries} + + if self.max_entries > 0 and len(entries) > self.max_entries: + excess = len(entries) - self.max_entries + self._evict(entries_map, remove_count=excess, exclude_keys=set()) + + max_size_bytes = int(self.max_size_gb * (1024**3)) + if max_size_bytes > 0: + total_size = sum(entry.metadata.file_size for entry in entries) + if total_size > max_size_bytes: + bytes_to_free = total_size - max_size_bytes + self._evict_by_size(entries_map, bytes_to_free, exclude_keys=set()) + + def sync_stats(self) -> CacheStats: + with self._lock: + self._syncing_stats = True + log.debug("Syncing stats.json of local cache") + try: + entries: list[CacheEntry] = [] + last_used_map: dict[str, str] = {} + total_size = 0 + for entry in self._iter_entries(): + entries.append(entry) + total_size += entry.metadata.file_size + last_used_map[entry.key] = entry.metadata.last_used.isoformat() + stats = CacheStats(last_used=last_used_map, total_size=total_size) + written = self._write_stats(stats) + self._enforce_limits_post_sync(entries) + return written + finally: + self._syncing_stats = False + + @property + def root(self) -> Path: + return self._root + + def list(self) -> list[dict[str, Any]]: + """Return metadata for all cache entries.""" + with self._with_lock(): + entries = [entry.metadata.model_dump(mode="json") for entry in self._iter_entries()] + return entries + + def clear(self, hard: bool = False) -> None: + """Remove all cache contents. If set to hard, root directory is removed.""" + with self._with_lock(): + if self._root.exists(): + try: + shutil.rmtree(self._root) + if not hard: + self._root.mkdir(parents=True, exist_ok=True) + except (FileNotFoundError, OSError): + pass + if not hard: + self._write_stats(CacheStats()) + + def _fetch(self, key: str) -> Optional[CacheEntry]: + """Retrieve an entry by key, verifying checksum.""" + with self._with_lock(): + entry = self._load_entry(key) + if not entry or not entry.exists(): + return None + if not entry.verify(): + self._remove_entry(entry) + return None + self._touch(entry) + return entry + + def __len__(self) -> int: + """Return number of valid cache entries.""" + with self._with_lock(): + count = self._load_stats().total_entries + return count + + def _store( + self, key: str, source_path: Path, metadata: CacheEntryMetadata + ) -> Optional[CacheEntry]: + """Store a new cache entry from ``source_path``. + + Parameters + ---------- + key : str + Cache key computed from simulation hash and runtime context. + source_path : Path + Location of the artifact to cache. + metadata : CacheEntryMetadata + Metadata describing the cache entry to be persisted. + + Returns + ------- + CacheEntry + Representation of the stored cache entry. 
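+
+        Notes
+        -----
+        The artifact is first copied into a temporary directory inside the cache
+        root and then swapped into place with ``os.replace``, so readers never
+        observe a partially written entry.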
+ """ + source_path = Path(source_path) + if not source_path.exists(): + raise FileNotFoundError(f"Cannot cache missing artifact: {source_path}") + os.makedirs(self._root, exist_ok=True) + tmp_dir = Path(tempfile.mkdtemp(prefix=TMP_PREFIX, dir=self._root)) + tmp_artifact = tmp_dir / CACHE_ARTIFACT_NAME + tmp_meta = tmp_dir / CACHE_METADATA_NAME + os.makedirs(tmp_dir, exist_ok=True) + + checksum, file_size = _copy_and_hash(source_path, tmp_artifact) + metadata.cache_key = key + metadata.created_at = datetime.now(timezone.utc) + metadata.last_used = metadata.created_at + metadata.checksum = checksum + metadata.file_size = file_size + + _write_metadata(tmp_meta, metadata) + entry: Optional[CacheEntry] = None + try: + with self._with_lock(): + self._root.mkdir(parents=True, exist_ok=True) + existing_entry = self._load_entry(key) + previous_size = ( + existing_entry.metadata.file_size if existing_entry is not None else 0 + ) + self._ensure_limits( + file_size, + incoming_key=key, + replacing_size=previous_size, + ) + final_dir = get_cache_entry_dir(self._root, key) + final_dir.parent.mkdir(parents=True, exist_ok=True) + if final_dir.exists(): + shutil.rmtree(final_dir) + os.replace(tmp_dir, final_dir) + entry = CacheEntry(key=key, root=self._root, metadata=metadata) + + self._record_store_stats( + key, + last_used=metadata.last_used.isoformat(), + file_size=file_size, + previous_size=previous_size, + ) + log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) + finally: + try: + if tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + except FileNotFoundError: + pass + return entry + + def invalidate(self, key: str) -> None: + with self._with_lock(): + entry = self._load_entry(key) + if entry: + self._remove_entry(entry) + + def _ensure_limits( + self, + incoming_size: int, + *, + incoming_key: Optional[str] = None, + replacing_size: int = 0, + ) -> None: + max_entries = self.max_entries + max_size_bytes = int(self.max_size_gb * (1024**3)) + + try: + incoming_size_int = int(incoming_size) + except (TypeError, ValueError): + incoming_size_int = 0 + if incoming_size_int < 0: + incoming_size_int = 0 + + stats = self._load_stats() + entries_info = dict(stats.last_used) + existing_keys = set(entries_info) + projected_entries = stats.total_entries + if not incoming_key or incoming_key not in existing_keys: + projected_entries += 1 + + if projected_entries > max_entries > 0: + excess = projected_entries - max_entries + exclude = {incoming_key} if incoming_key else set() + self._evict(entries_info, remove_count=excess, exclude_keys=exclude) + stats = self._load_stats() + entries_info = dict(stats.last_used) + existing_keys = set(entries_info) + + if max_size_bytes == 0: # no limit + return + + existing_size = stats.total_size + try: + replacing_size_int = int(replacing_size) + except (TypeError, ValueError): + replacing_size_int = 0 + if incoming_key and incoming_key in existing_keys: + projected_size = existing_size - replacing_size_int + incoming_size_int + else: + projected_size = existing_size + incoming_size_int + + if max_size_bytes > 0 and projected_size > max_size_bytes: + bytes_to_free = projected_size - max_size_bytes + exclude = {incoming_key} if incoming_key else set() + self._evict_by_size(entries_info, bytes_to_free, exclude_keys=exclude) + + def _evict(self, entries: dict[str, str], *, remove_count: int, exclude_keys: set[str]) -> None: + if remove_count <= 0: + return + candidates = [(key, entries.get(key, "")) for key in entries if key not in 
exclude_keys]
+        if not candidates:
+            return
+        candidates.sort(key=lambda item: item[1] or "")
+        for key, _ in candidates[:remove_count]:
+            self._remove_entry_by_key(key)
+
+    def _evict_by_size(
+        self, entries: dict[str, str], bytes_to_free: int, *, exclude_keys: set[str]
+    ) -> None:
+        if bytes_to_free <= 0:
+            return
+        candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys]
+        if not candidates:
+            return
+        candidates.sort(key=lambda item: item[1] or "")
+        reclaimed = 0
+        for key, _ in candidates:
+            if reclaimed >= bytes_to_free:
+                break
+            entry = self._load_entry(key)
+            if entry is None:
+                log.debug("Could not find entry for eviction.")
+                self._schedule_sync()
+                break
+            size = entry.metadata.file_size
+            self._remove_entry(entry)
+            reclaimed += size
+            log.info(f"Simulation cache evicted entry '{key}' to reclaim {size} bytes.")
+
+    def _iter_entries(self) -> Iterator[CacheEntry]:
+        """Iterate lazily over all cache entries, including those in prefix subdirectories."""
+        if not self._root.exists():
+            return
+
+        for prefix_dir in self._root.iterdir():
+            if not prefix_dir.is_dir() or prefix_dir.name.startswith(
+                (TMP_PREFIX, TMP_BATCH_PREFIX)
+            ):
+                continue
+
+            # If the cache is flat (no prefix directories), treat this directory itself
+            # as an entry directory; otherwise descend into its subdirectories.
+            subdirs = [prefix_dir]
+            if any(child.is_dir() for child in prefix_dir.iterdir()):
+                subdirs = prefix_dir.iterdir()
+
+            for child in subdirs:
+                if not child.is_dir():
+                    continue
+                if child.name.startswith((TMP_PREFIX, TMP_BATCH_PREFIX)):
+                    continue
+
+                meta_path = child / CACHE_METADATA_NAME
+                if not meta_path.exists():
+                    continue
+
+                try:
+                    metadata = _read_metadata(meta_path, child / CACHE_ARTIFACT_NAME)
+                except Exception:
+                    log.debug(
+                        "Failed to parse metadata for '%s'; scheduling stats sync.", child.name
+                    )
+                    self._schedule_sync()
+                    continue
+
+                yield CacheEntry(key=child.name, root=self._root, metadata=metadata)
+
+    def _load_entry(self, key: str) -> Optional[CacheEntry]:
+        entry = CacheEntry(key=key, root=self._root, metadata={})
+        if not entry.metadata_path.exists() or not entry.artifact_path.exists():
+            return None
+        try:
+            metadata = _read_metadata(entry.metadata_path, entry.artifact_path)
+        except Exception:
+            return None
+        return CacheEntry(key=key, root=self._root, metadata=metadata)
+
+    def _touch(self, entry: CacheEntry) -> None:
+        entry.metadata.bump_last_used()
+        _write_metadata(entry.metadata_path, entry.metadata)
+        self._record_touch_stats(
+            entry.key,
+            entry.metadata.last_used.isoformat(),
+            file_size=entry.metadata.file_size,
+        )
+
+    def _remove_entry_by_key(self, key: str) -> None:
+        entry = self._load_entry(key)
+        if entry is None:
+            path = get_cache_entry_dir(self._root, key)
+            if path.exists():
+                shutil.rmtree(path, ignore_errors=True)
+            else:
+                log.debug("Could not find entry for key '%s' to delete.", key)
+            self._record_remove_stats(key, 0)
+            return
+        self._remove_entry(entry)
+
+    def _remove_entry(self, entry: CacheEntry) -> None:
+        file_size = entry.metadata.file_size
+        if entry.path.exists():
+            shutil.rmtree(entry.path, ignore_errors=True)
+        self._record_remove_stats(entry.key, file_size)
+
+    def try_fetch(
+        self,
+        simulation: CacheableSimulation,
+        verbose: bool = False,
+    ) -> Optional[CacheEntry]:
+        """
+        Attempt to resolve and fetch a cached result entry for the given simulation context.
+        On miss or any cache error, return ``None``; the caller should proceed with upload/run.
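+
+        A minimal usage sketch (``sim`` stands in for any hypothetical
+        :class:`.CacheableSimulation` instance)::
+
+            cache = resolve_local_cache(use_cache=True)
+            entry = cache.try_fetch(sim) if cache is not None else None
+            if entry is None:
+                ...  # cache miss: upload and run as usual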
+        """
+        try:
+            simulation_hash = simulation._hash_self()
+            workflow_type = _get_workflow_type(simulation)
+
+            version = _get_protocol_version()
+
+            cache_key = build_cache_key(
+                simulation_hash=simulation_hash,
+                version=version,
+            )
+
+            entry = self._fetch(cache_key)
+            if not entry:
+                return None
+
+            if verbose:
+                log.info(
+                    f"Simulation cache hit for workflow '{workflow_type}'; using local results."
+                )
+
+            return entry
+        except Exception as e:
+            log.error(f"Failed to fetch cache results: {e}")
+            return None
+
+    def store_result(
+        self,
+        task_id: TaskId,
+        path: str,
+        workflow_type: str,
+        simulation: CacheableSimulation,
+    ) -> bool:
+        """
+        Store completed workflow results in the local cache under a canonical cache key.
+
+        Parameters
+        ----------
+        task_id : str
+            Unique identifier of the finished workflow task.
+        path : str
+            Path to the results file on disk.
+        workflow_type : str
+            Type of workflow associated with the results (e.g., ``"SIMULATION"`` or ``"MODE_SOLVER"``).
+        simulation : :class:`.CacheableSimulation`
+            Simulation object used to compute the cache key.
+
+        Returns
+        -------
+        bool
+            ``True`` if the result was successfully stored in the local cache, ``False`` otherwise.
+
+        Notes
+        -----
+        The cache entry is keyed by the simulation hash and the protocol version,
+        which enables automatic reuse of identical simulation results across future runs.
+        The originating task ID is recorded in the entry metadata to support backward
+        lookup compatibility.
+        """
+        try:
+            simulation_hash = simulation._hash_self()
+            if not simulation_hash:
+                log.debug("Failed storing local cache entry: Could not hash simulation.")
+                return False
+
+            version = _get_protocol_version()
+
+            cache_key = build_cache_key(
+                simulation_hash=simulation_hash,
+                version=version,
+            )
+
+            metadata = build_entry_metadata(
+                simulation_hash=simulation_hash,
+                workflow_type=workflow_type,
+                task_id=task_id,
+                version=version,
+                path=Path(path),
+            )
+
+            self._store(
+                key=cache_key,
+                source_path=Path(path),
+                metadata=metadata,
+            )
+            log.debug("Stored local cache entry for workflow type '%s'.", workflow_type)
+        except Exception as e:
+            log.error(f"Could not store cache entry: {e}")
+            return False
+        return True
+
+
+def _copy_and_hash(
+    source: Path, dest: Optional[Path], existing_hash: Optional[str] = None
+) -> tuple[str, int]:
+    """Copy ``source`` to ``dest`` while computing its SHA256 checksum.
+
+    Parameters
+    ----------
+    source : Path
+        Source file path.
+    dest : Path or None
+        Destination file path. If ``None``, no copy is performed and only the
+        hash and size are computed.
+    existing_hash : str, optional
+        Currently unused; reserved for skipping the copy when ``dest`` already
+        exists with a matching hash.
+
+    Returns
+    -------
+    tuple[str, int]
+        The hexadecimal digest and file size in bytes.
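+
+    Examples
+    --------
+    Hash a file in place without copying it (a minimal sketch)::
+
+        digest, size = _copy_and_hash(Path("artifact.hdf5"), dest=None)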
+ """ + source = Path(source) + if dest is not None: + dest = Path(dest) + sha256 = _Hasher() + size = 0 + with source.open("rb") as src: + if dest is None: + while chunk := src.read(1024 * 1024): + sha256.update(chunk) + size += len(chunk) + else: + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as dst: + while chunk := src.read(1024 * 1024): + dst.write(chunk) + sha256.update(chunk) + size += len(chunk) + return sha256.hexdigest(), size + + +def _write_metadata(path: Path, metadata: CacheEntryMetadata | dict[str, Any]) -> None: + tmp_path = path.with_suffix(".tmp") + payload: dict[str, Any] + if isinstance(metadata, CacheEntryMetadata): + payload = metadata.model_dump(mode="json") + else: + payload = metadata + with tmp_path.open("w", encoding="utf-8") as fh: + json.dump(payload, fh, indent=2, sort_keys=True) + os.replace(tmp_path, path) + + +def _now() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _timestamp_suffix() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f") + + +def _read_metadata(meta_path: Path, artifact_path: Path) -> CacheEntryMetadata: + raw = json.loads(meta_path.read_text(encoding="utf-8")) + if "file_size" not in raw: + try: + raw["file_size"] = artifact_path.stat().st_size + except FileNotFoundError: + raw["file_size"] = 0 + raw.setdefault("created_at", _now()) + raw.setdefault("last_used", raw["created_at"]) + raw.setdefault("cache_key", meta_path.parent.name) + return CacheEntryMetadata.model_validate(raw) + + +class _Hasher: + def __init__(self) -> None: + self._hasher = hashlib.sha256() + + def update(self, data: bytes) -> None: + self._hasher.update(data) + + def hexdigest(self) -> str: + return self._hasher.hexdigest() + + +def clear() -> None: + """Remove all cache entries.""" + cache = resolve_local_cache(use_cache=True) + if cache is not None: + cache.clear() + + +def _canonicalize(value: Any) -> Any: + """Convert value into a JSON-serializable object for hashing/metadata.""" + + if isinstance(value, dict): + return { + str(k): _canonicalize(v) + for k, v in sorted(value.items(), key=lambda item: str(item[0])) + } + if isinstance(value, (list, tuple)): + return [_canonicalize(v) for v in value] + if isinstance(value, set): + return sorted(_canonicalize(v) for v in value) + if isinstance(value, Enum): + return value.value + if isinstance(value, Path): + return str(value) + if isinstance(value, datetime): + return value.isoformat() + if isinstance(value, bytes): + return value.decode("utf-8", errors="ignore") + return value + + +def build_cache_key( + *, + simulation_hash: str, + version: str, +) -> str: + """Construct a deterministic cache key.""" + + payload = { + "simulation_hash": simulation_hash, + "versions": _canonicalize(version), + } + encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + +def build_entry_metadata( + *, + simulation_hash: str, + workflow_type: str, + task_id: str, + version: str, + path: Path, +) -> CacheEntryMetadata: + """Create metadata object for a cache entry.""" + + now = datetime.now(timezone.utc) + return CacheEntryMetadata( + cache_key="", + checksum="", + created_at=now, + last_used=now, + file_size=0, + simulation_hash=simulation_hash, + workflow_type=workflow_type, + versions=_canonicalize(version), + task_id=task_id, + path=str(path), + ) + + +def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache]: + """ + Returns LocalCache instance if enabled. 
+    Returns ``None`` if ``use_cache=False`` or the config's ``enabled`` flag is off.
+    If the configured cache directory changed, the old directory is moved to the new
+    location (and deleted only if the move fails).
+    """
+    global _CACHE
+
+    if use_cache is False or (use_cache is not True and not config.local_cache.enabled):
+        return None
+
+    if _CACHE is not None and _CACHE._root != Path(config.local_cache.directory):
+        old_root = _CACHE._root
+        new_root = Path(config.local_cache.directory)
+        log.debug(f"Moving cache directory from {old_root} → {new_root}")
+        try:
+            new_root.parent.mkdir(parents=True, exist_ok=True)
+            if old_root.exists():
+                shutil.move(old_root, new_root)
+        except Exception as e:
+            log.warning(f"Failed to move cache directory: {e}. Deleting old cache.")
+            shutil.rmtree(old_root, ignore_errors=True)
+
+    try:
+        _CACHE = LocalCache(
+            directory=config.local_cache.directory,
+            max_entries=config.local_cache.max_entries,
+            max_size_gb=config.local_cache.max_size_gb,
+        )
+        return _CACHE
+    except Exception as err:
+        log.debug(f"Simulation cache unavailable: {err}")
+        return None
+
+
+resolve_local_cache()
diff --git a/tidy3d/_common/web/core/__init__.py b/tidy3d/_common/web/core/__init__.py
new file mode 100644
index 0000000000..f1a0e1eaa8
--- /dev/null
+++ b/tidy3d/_common/web/core/__init__.py
@@ -0,0 +1,8 @@
+"""Tidy3d core package imports"""
+
+from __future__ import annotations
+
+# TODO(FXC-3827): Drop this import once the legacy shim is removed in Tidy3D 2.12.
+from tidy3d._common.web.core import environment
+
+__all__ = ["environment"]
diff --git a/tidy3d/_common/web/core/account.py b/tidy3d/_common/web/core/account.py
new file mode 100644
index 0000000000..aefc41bd61
--- /dev/null
+++ b/tidy3d/_common/web/core/account.py
@@ -0,0 +1,66 @@
+"""Tidy3d user account."""
+
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Optional
+
+from pydantic import Field
+
+from tidy3d._common.web.core.http_util import http
+from tidy3d._common.web.core.types import Tidy3DResource
+
+
+class Account(Tidy3DResource, extra="allow"):
+    """Tidy3D User Account."""
+
+    allowance_cycle_type: Optional[str] = Field(
+        None,
+        title="AllowanceCycleType",
+        description="Daily or Monthly",
+        alias="allowanceCycleType",
+    )
+    credit: Optional[float] = Field(
+        0, title="credit", description="Current FlexCredit balance", alias="credit"
+    )
+    credit_expiration: Optional[datetime] = Field(
+        None,
+        title="creditExpiration",
+        description="Expiration date",
+        alias="creditExpiration",
+    )
+    allowance_current_cycle_amount: Optional[float] = Field(
+        0,
+        title="allowanceCurrentCycleAmount",
+        description="Daily/Monthly free simulation balance",
+        alias="allowanceCurrentCycleAmount",
+    )
+    allowance_current_cycle_end_date: Optional[datetime] = Field(
+        None,
+        title="allowanceCurrentCycleEndDate",
+        description="Daily/Monthly free simulation balance expiration date",
+        alias="allowanceCurrentCycleEndDate",
+    )
+    daily_free_simulation_counts: Optional[int] = Field(
+        0,
+        title="dailyFreeSimulationCounts",
+        description="Daily free simulation counts",
+        alias="dailyFreeSimulationCounts",
+    )
+
+    @classmethod
+    def get(cls) -> Optional[Account]:
+        """Get user account information.
+ + Parameters + ---------- + + Returns + ------- + account : Account + """ + resp = http.get("tidy3d/py/account") + if resp: + account = Account(**resp) + return account + return None diff --git a/tidy3d/_common/web/core/cache.py b/tidy3d/_common/web/core/cache.py new file mode 100644 index 0000000000..d83421ca21 --- /dev/null +++ b/tidy3d/_common/web/core/cache.py @@ -0,0 +1,6 @@ +"""Local caches.""" + +from __future__ import annotations + +FOLDER_CACHE = {} +S3_STS_TOKENS = {} diff --git a/tidy3d/_common/web/core/constants.py b/tidy3d/_common/web/core/constants.py new file mode 100644 index 0000000000..623af2bba8 --- /dev/null +++ b/tidy3d/_common/web/core/constants.py @@ -0,0 +1,38 @@ +"""Defines constants for core.""" + +# HTTP Header key and value +from __future__ import annotations + +HEADER_APIKEY = "simcloud-api-key" +HEADER_VERSION = "tidy3d-python-version" +HEADER_SOURCE = "source" +HEADER_SOURCE_VALUE = "Python" +HEADER_USER_AGENT = "User-Agent" +HEADER_APPLICATION = "Application" +HEADER_APPLICATION_VALUE = "TIDY3D" + + +SIMCLOUD_APIKEY = "SIMCLOUD_APIKEY" +KEY_APIKEY = "apikey" +JSON_TAG = "JSON_STRING" +# type of the task_id +TaskId = str +# type of task_name +TaskName = str + + +SIMULATION_JSON = "simulation.json" +SIMULATION_DATA_HDF5 = "output/monitor_data.hdf5" +SIMULATION_DATA_HDF5_GZ = "output/simulation_data.hdf5.gz" +RUNNING_INFO = "output/solver_progress.csv" +SIM_LOG_FILE = "output/tidy3d.log" +SIM_FILE_HDF5 = "simulation.hdf5" +SIM_FILE_HDF5_GZ = "simulation.hdf5.gz" +MODE_FILE_HDF5_GZ = "mode_solver.hdf5.gz" +MODE_DATA_HDF5_GZ = "output/mode_solver_data.hdf5.gz" +SIM_ERROR_FILE = "output/tidy3d_error.json" +SIM_VALIDATION_FILE = "output/tidy3d_validation.json" + +# Component modeler specific artifacts +MODELER_FILE_HDF5_GZ = "modeler.hdf5.gz" +CM_DATA_HDF5_GZ = "output/cm_data.hdf5.gz" diff --git a/tidy3d/_common/web/core/core_config.py b/tidy3d/_common/web/core/core_config.py new file mode 100644 index 0000000000..3a9a517c71 --- /dev/null +++ b/tidy3d/_common/web/core/core_config.py @@ -0,0 +1,50 @@ +"""Tidy3d core log, need init config from Tidy3d api""" + +from __future__ import annotations + +import logging as log +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from rich.console import Console + + from tidy3d._common.log import Logger + +# default setting +config_setting = { + "logger": log, + "logger_console": None, + "version": "", +} + + +def set_config(logger: Logger, logger_console: Console, version: str) -> None: + """Init tidy3d core logger and logger console. + + Parameters + ---------- + logger : :class:`.Logger` + Tidy3d log Logger. + logger_console : :class:`.Console` + Get console from logging handlers. 
+ version : str + tidy3d version + """ + config_setting["logger"] = logger + config_setting["logger_console"] = logger_console + config_setting["version"] = version + + +def get_logger() -> Logger: + """Get logging handlers.""" + return config_setting["logger"] + + +def get_logger_console() -> Console: + """Get console from logging handlers.""" + return config_setting["logger_console"] + + +def get_version() -> str: + """Get version from cache.""" + return config_setting["version"] diff --git a/tidy3d/_common/web/core/environment.py b/tidy3d/_common/web/core/environment.py new file mode 100644 index 0000000000..58bd8ceaef --- /dev/null +++ b/tidy3d/_common/web/core/environment.py @@ -0,0 +1,42 @@ +"""Legacy re-export of configuration environment helpers.""" + +from __future__ import annotations + +# TODO(FXC-3827): Remove this module-level legacy shim in Tidy3D 2.12. +import warnings +from typing import Any + +from tidy3d._common.config import Env, Environment, EnvironmentConfig + +__all__ = [ # noqa: F822 + "Env", + "Environment", + "EnvironmentConfig", + "dev", + "nexus", + "pre", + "prod", + "uat", +] + +_LEGACY_ENV_NAMES = {"dev", "uat", "pre", "prod", "nexus"} +_DEPRECATION_MESSAGE = ( + "'tidy3d.web.core.environment.{name}' is deprecated and will be removed in " + "Tidy3D 2.12. Transition to 'tidy3d.config.Env.{name}' or " + "'tidy3d.config.config.switch_profile(...)'." +) + + +def _get_legacy_env(name: str) -> Any: + warnings.warn(_DEPRECATION_MESSAGE.format(name=name), DeprecationWarning, stacklevel=2) + return getattr(Env, name) + + +def __getattr__(name: str) -> Any: + if name in _LEGACY_ENV_NAMES: + return _get_legacy_env(name) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +def __dir__() -> list[str]: + return sorted(set(__all__)) diff --git a/tidy3d/_common/web/core/exceptions.py b/tidy3d/_common/web/core/exceptions.py new file mode 100644 index 0000000000..4ca929f239 --- /dev/null +++ b/tidy3d/_common/web/core/exceptions.py @@ -0,0 +1,24 @@ +"""Custom Tidy3D exceptions""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from tidy3d._common.web.core.core_config import get_logger + +if TYPE_CHECKING: + from typing import Optional + + +class WebError(Exception): + """Any error in tidy3d""" + + def __init__(self, message: Optional[str] = None) -> None: + """Log just the error message and then raise the Exception.""" + log = get_logger() + super().__init__(message) + log.error(message) + + +class WebNotFoundError(WebError): + """A generic error indicating an HTTP 404 (resource not found).""" diff --git a/tidy3d/_common/web/core/file_util.py b/tidy3d/_common/web/core/file_util.py new file mode 100644 index 0000000000..41a9dac0a2 --- /dev/null +++ b/tidy3d/_common/web/core/file_util.py @@ -0,0 +1,87 @@ +"""File compression utilities""" + +from __future__ import annotations + +import gzip +import os +import shutil +import tempfile + +import h5py + +from tidy3d._common.web.core.constants import JSON_TAG + + +def compress_file_to_gzip(input_file: os.PathLike, output_gz_file: os.PathLike) -> None: + """Compresses a file using gzip. + + Parameters + ---------- + input_file : PathLike + The path of the input file. + output_gz_file : PathLike + The path of the output gzip file. + """ + with open(input_file, "rb") as file_in: + with gzip.open(output_gz_file, "wb") as file_out: + shutil.copyfileobj(file_in, file_out) + + +def extract_gzip_file(input_gz_file: os.PathLike, output_file: os.PathLike) -> None: + """Extract a gzip file. 
+
+    Parameters
+    ----------
+    input_gz_file : PathLike
+        The path of the gzip input file.
+    output_file : PathLike
+        The path of the output file.
+    """
+    with gzip.open(input_gz_file, "rb") as file_in:
+        with open(output_file, "wb") as file_out:
+            shutil.copyfileobj(file_in, file_out)
+
+
+def read_simulation_from_hdf5_gz(file_name: os.PathLike) -> bytes:
+    """Read simulation JSON bytes from an hdf5.gz file."""
+
+    hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5")
+    os.close(hdf5_file)
+    try:
+        extract_gzip_file(file_name, hdf5_file_path)
+        # Pass the uncompressed temporary file path to the reader
+        json_bytes = read_simulation_from_hdf5(hdf5_file_path)
+    finally:
+        os.unlink(hdf5_file_path)
+    return json_bytes
+
+
+"""TODO: _json_string_key and read_simulation_from_hdf5 are duplicated functions that also exist
+as methods in Tidy3dBaseModel. For consistency it would be best if this duplication is avoided."""
+
+
+def _json_string_key(index: int) -> str:
+    """Get json string key for string chunk number ``index``."""
+    if index:
+        return f"{JSON_TAG}_{index}"
+    return JSON_TAG
+
+
+def read_simulation_from_hdf5(file_name: os.PathLike) -> bytes:
+    """Read simulation JSON bytes from an hdf5 file."""
+    with h5py.File(file_name, "r") as f_handle:
+        num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key])
+        json_string = b""
+        for ind in range(num_string_parts):
+            json_string += f_handle[_json_string_key(ind)][()]
+        return json_string
+
+
+"""End TODO"""
+
+
+def read_simulation_from_json(file_name: os.PathLike) -> str:
+    """Read simulation string from a json file."""
+    with open(file_name, encoding="utf-8") as json_file:
+        json_data = json_file.read()
+    return json_data
diff --git a/tidy3d/_common/web/core/http_util.py b/tidy3d/_common/web/core/http_util.py
new file mode 100644
index 0000000000..70668e3c86
--- /dev/null
+++ b/tidy3d/_common/web/core/http_util.py
@@ -0,0 +1,285 @@
+"""Http connection pool and authentication management."""
+
+from __future__ import annotations
+
+import json
+import os
+import ssl
+from enum import Enum
+from functools import wraps
+from typing import TYPE_CHECKING, Any
+
+import requests
+from requests.adapters import HTTPAdapter
+from urllib3.util.ssl_ import create_urllib3_context
+
+from tidy3d._common import log
+from tidy3d._common.config import config
+from tidy3d._common.web.core import core_config
+from tidy3d._common.web.core.constants import (
+    HEADER_APIKEY,
+    HEADER_APPLICATION,
+    HEADER_APPLICATION_VALUE,
+    HEADER_SOURCE,
+    HEADER_SOURCE_VALUE,
+    HEADER_USER_AGENT,
+    HEADER_VERSION,
+    SIMCLOUD_APIKEY,
+)
+from tidy3d._common.web.core.core_config import get_logger
+from tidy3d._common.web.core.exceptions import WebError, WebNotFoundError
+
+if TYPE_CHECKING:
+    from typing import Callable, Optional, TypeAlias
+
+JSONType: TypeAlias = dict[str, Any] | list[Any] | str | int
+
+
+class ResponseCodes(Enum):
+    """HTTP response codes to handle individually."""
+
+    UNAUTHORIZED = 401
+    OK = 200
+    NOT_FOUND = 404
+
+
+def get_version() -> str:
+    """Get the version for the current environment."""
+    return core_config.get_version()
+
+
+def get_user_agent() -> str:
+    """Get the user agent for the current environment."""
+    return os.environ.get("TIDY3D_AGENT", f"Python-Client/{get_version()}")
+
+
+def api_key() -> Optional[str]:
+    """Get the api key for the current environment."""
+
+    if os.environ.get(SIMCLOUD_APIKEY):
+        return os.environ.get(SIMCLOUD_APIKEY)
+
+    try:
+        apikey = config.web.apikey
+    except AttributeError:
+        return None
+
+    if apikey is None:
+        return None
+    if hasattr(apikey, "get_secret_value"):
+        return apikey.get_secret_value()
+    return str(apikey)
+
+
+def api_key_auth(request: requests.PreparedRequest) -> requests.PreparedRequest:
+    """Attach the authentication info to a request.
+
+    Parameters
+    ----------
+    request : requests.PreparedRequest
+        The original request to set authentication for.
+
+    Returns
+    -------
+    requests.PreparedRequest
+        The request with authentication set.
+    """
+    key = api_key()
+    version = get_version()
+    if not key:
+        raise ValueError(
+            "API key not found. To get your API key, sign into 'https://tidy3d.simulation.cloud' "
+            "and copy it from your 'Account' page. Then you can configure tidy3d through the "
+            "command line 'tidy3d configure' and enter your API key when prompted. "
+            "Alternatively, especially on Windows, you can create the configuration file "
+            "manually in your home directory, at '~/.tidy3d/config' (Unix) or "
+            "'.tidy3d/config' (Windows), containing the following line: "
+            "apikey = 'XXX'. Here XXX is your API key copied from your account page, in quotes."
+        )
+    if not version:
+        raise ValueError("version not found.")
+
+    request.headers[HEADER_APIKEY] = key
+    request.headers[HEADER_VERSION] = version
+    request.headers[HEADER_SOURCE] = HEADER_SOURCE_VALUE
+    request.headers[HEADER_USER_AGENT] = get_user_agent()
+    return request
+
+
+def get_headers() -> dict[str, Optional[str]]:
+    """Get headers for an http request.
+
+    Returns
+    -------
+    dict[str, Optional[str]]
+        Dictionary with the API-key, application, and user-agent headers.
+    """
+    return {
+        HEADER_APIKEY: api_key(),
+        HEADER_APPLICATION: HEADER_APPLICATION_VALUE,
+        HEADER_USER_AGENT: get_user_agent(),
+    }
+
+
+def http_interceptor(func: Callable[..., Any]) -> Callable[..., JSONType]:
+    """Intercept the response and raise an exception if the status code is not 200."""
+
+    @wraps(func)
+    def wrapper(*args: Any, **kwargs: Any) -> JSONType:
+        """The wrapper function."""
+        suppress_404 = kwargs.pop("suppress_404", False)
+
+        # Extend some capabilities of func
+        resp = func(*args, **kwargs)
+
+        if resp.status_code != ResponseCodes.OK.value:
+            if resp.status_code == ResponseCodes.NOT_FOUND.value:
+                if suppress_404:
+                    return None
+                raise WebNotFoundError("Resource not found (HTTP 404).")
+            try:
+                json_resp = resp.json()
+            except Exception:
+                json_resp = None
+
+            # Build a helpful error message using available fields
+            err_msg = None
+            if isinstance(json_resp, dict):
+                parts = []
+                for key in ("error", "message", "msg", "detail", "code", "httpStatus", "warning"):
+                    val = json_resp.get(key)
+                    if not val:
+                        continue
+                    if key == "error":
+                        # Always include the raw 'error' payload for debugging. Also try to extract a nested message.
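+                        # For instance, a (hypothetical) payload such as
+                        #     {"error": "{\"message\": \"Quota exceeded\"}"}
+                        # contributes both the nested message and the raw
+                        # error=... string to the final error text.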
+ if isinstance(val, str): + try: + nested = json.loads(val) + if isinstance(nested, dict): + nested_msg = ( + nested.get("message") + or nested.get("error") + or nested.get("msg") + ) + if nested_msg: + parts.append(str(nested_msg)) + except Exception: + pass + parts.append(f"error={val}") + else: + parts.append(f"error={val!s}") + continue + parts.append(str(val)) + if parts: + err_msg = "; ".join(parts) + if not err_msg: + # Fallback to response text or status code + err_msg = resp.text or f"HTTP {resp.status_code}" + + # Append request context to aid debugging + try: + method = getattr(resp.request, "method", "") + url = getattr(resp.request, "url", "") + err_msg = f"{err_msg} [HTTP {resp.status_code} {method} {url}]" + except Exception: + pass + + raise WebError(err_msg) + + if not resp.text: + return None + result = resp.json() + + if isinstance(result, dict): + warning = result.get("warning") + if warning: + log = get_logger() + log.warning(warning) + + if "data" in result: + return result["data"] + + return result + + return wrapper + + +class TLSAdapter(HTTPAdapter): + def init_poolmanager(self, *args: Any, **kwargs: Any) -> None: + try: + ssl_version = ( + ssl.TLSVersion[config.web.ssl_version] + if config.web.ssl_version is not None + else None + ) + except KeyError: + log.warning(f"Invalid SSL/TLS version '{config.web.ssl_version}', using default") + ssl_version = None + context = create_urllib3_context(ssl_version=ssl_version) + kwargs["ssl_context"] = context + return super().init_poolmanager(*args, **kwargs) + + +class HttpSessionManager: + """Http util class.""" + + def __init__(self, session: requests.Session) -> None: + """Initialize the session.""" + self.session = session + self._mounted_ssl_version = None + self._ensure_tls_adapter(config.web.ssl_version) + self.session.verify = config.web.ssl_verify + + def reinit(self) -> None: + """Reinitialize the session.""" + ssl_version = config.web.ssl_version + self._ensure_tls_adapter(ssl_version) + self.session.verify = config.web.ssl_verify + + def _ensure_tls_adapter(self, ssl_version: str) -> None: + if not ssl_version: + self._mounted_ssl_version = None + return + if self._mounted_ssl_version != ssl_version: + self.session.mount("https://", TLSAdapter()) + self._mounted_ssl_version = ssl_version + + @http_interceptor + def get( + self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None + ) -> requests.Response: + """Get the resource.""" + self.reinit() + return self.session.get( + url=config.web.build_api_url(path), auth=api_key_auth, json=json, params=params + ) + + @http_interceptor + def post(self, path: str, json: JSONType = None) -> requests.Response: + """Create the resource.""" + self.reinit() + return self.session.post(config.web.build_api_url(path), json=json, auth=api_key_auth) + + @http_interceptor + def put( + self, path: str, json: JSONType = None, files: Optional[dict[str, Any]] = None + ) -> requests.Response: + """Update the resource.""" + self.reinit() + return self.session.put( + config.web.build_api_url(path), json=json, auth=api_key_auth, files=files + ) + + @http_interceptor + def delete( + self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None + ) -> requests.Response: + """Delete the resource.""" + self.reinit() + return self.session.delete( + config.web.build_api_url(path), auth=api_key_auth, json=json, params=params + ) + + +http = HttpSessionManager(requests.Session()) diff --git a/tidy3d/_common/web/core/s3utils.py 
b/tidy3d/_common/web/core/s3utils.py new file mode 100644 index 0000000000..eb275fea82 --- /dev/null +++ b/tidy3d/_common/web/core/s3utils.py @@ -0,0 +1,471 @@ +"""handles filesystem, storage""" + +from __future__ import annotations + +import os +import tempfile +import urllib +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import boto3 +from boto3.s3.transfer import TransferConfig +from pydantic import BaseModel, Field +from rich.progress import ( + BarColumn, + DownloadColumn, + Progress, + TextColumn, + TimeRemainingColumn, + TransferSpeedColumn, +) + +from tidy3d._common.config import config +from tidy3d._common.web.core.core_config import get_logger_console +from tidy3d._common.web.core.exceptions import WebError +from tidy3d._common.web.core.file_util import extract_gzip_file +from tidy3d._common.web.core.http_util import http + +if TYPE_CHECKING: + from collections.abc import Mapping + from os import PathLike + from typing import Callable, Optional + + import rich + +IN_TRANSIT_SUFFIX = ".tmp" + + +class _UserCredential(BaseModel): + """Stores information about user credentials.""" + + access_key_id: str = Field(alias="accessKeyId") + expiration: datetime + secret_access_key: str = Field(alias="secretAccessKey") + session_token: str = Field(alias="sessionToken") + + +class _S3STSToken(BaseModel): + """Stores information about S3 token.""" + + cloud_path: str = Field(alias="cloudpath") + user_credential: _UserCredential = Field(alias="userCredentials") + + def get_bucket(self) -> str: + """Get the bucket name for this token.""" + + r = urllib.parse.urlparse(self.cloud_path) + return r.netloc + + def get_s3_key(self) -> str: + """Get the s3 key for this token.""" + + r = urllib.parse.urlparse(self.cloud_path) + return r.path[1:] + + def get_client(self) -> boto3.client: + """Get the boto client for this token. + + Automatically configures custom S3 endpoint if specified in web.env_vars. + """ + + client_kwargs = { + "service_name": "s3", + "region_name": config.web.s3_region, + "aws_access_key_id": self.user_credential.access_key_id, + "aws_secret_access_key": self.user_credential.secret_access_key, + "aws_session_token": self.user_credential.session_token, + "verify": config.web.ssl_verify, + } + + # Add custom S3 endpoint if configured (e.g., for Nexus deployments) + if config.web.env_vars and "AWS_ENDPOINT_URL_S3" in config.web.env_vars: + s3_endpoint = config.web.env_vars["AWS_ENDPOINT_URL_S3"] + client_kwargs["endpoint_url"] = s3_endpoint + + return boto3.client(**client_kwargs) + + def is_expired(self) -> bool: + """True if token is expired.""" + + return ( + self.user_credential.expiration + - datetime.now(tz=self.user_credential.expiration.tzinfo) + ).total_seconds() < 300 + + +class UploadProgress: + """Updates progressbar with the upload status. + + Attributes + ---------- + progress : rich.progress.Progress() + Progressbar instance from rich + ul_task : rich.progress.Task + Progressbar task instance. + """ + + def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None: + """initialize with the size of file and rich.progress.Progress() instance. + + Parameters + ---------- + size_bytes: int + Number of total bytes to upload. 
+        progress : rich.progress.Progress()
+            Progressbar instance from rich
+        """
+        self.progress = progress
+        self.ul_task = self.progress.add_task("[red]Uploading...", total=size_bytes)
+
+    def report(self, bytes_in_chunk: Any) -> None:
+        """Update the progressbar with the most recent chunk.
+
+        Parameters
+        ----------
+        bytes_in_chunk : int
+            Number of bytes in the most recent chunk.
+        """
+        self.progress.update(self.ul_task, advance=bytes_in_chunk)
+
+
+class DownloadProgress:
+    """Updates progressbar using the download status.
+
+    Attributes
+    ----------
+    progress : rich.progress.Progress()
+        Progressbar instance from rich
+    dl_task : rich.progress.Task
+        Progressbar task instance.
+    """
+
+    def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None:
+        """Initialize with the size of file and rich.progress.Progress() instance.
+
+        Parameters
+        ----------
+        size_bytes : int
+            Number of total bytes to download.
+        progress : rich.progress.Progress()
+            Progressbar instance from rich
+        """
+        self.progress = progress
+        self.dl_task = self.progress.add_task("[red]Downloading...", total=size_bytes)
+
+    def report(self, bytes_in_chunk: int) -> None:
+        """Update the progressbar with the most recent chunk.
+
+        Parameters
+        ----------
+        bytes_in_chunk : int
+            Number of bytes in the most recent chunk.
+        """
+        self.progress.update(self.dl_task, advance=bytes_in_chunk)
+
+
+class _S3Action(Enum):
+    UPLOADING = "↑"
+    DOWNLOADING = "↓"
+
+
+def _get_progress(action: _S3Action) -> Progress:
+    """Get the progress of an action."""
+
+    col = (
+        TextColumn(f"[bold green]{_S3Action.DOWNLOADING.value}")
+        if action == _S3Action.DOWNLOADING
+        else TextColumn(f"[bold red]{_S3Action.UPLOADING.value}")
+    )
+    return Progress(
+        col,
+        TextColumn("[bold blue]{task.fields[filename]}"),
+        BarColumn(),
+        "[progress.percentage]{task.percentage:>3.1f}%",
+        "•",
+        DownloadColumn(),
+        "•",
+        TransferSpeedColumn(),
+        "•",
+        TimeRemainingColumn(),
+        console=get_logger_console(),
+    )
+
+
+_s3_config = TransferConfig()
+
+_s3_sts_tokens: dict[str, _S3STSToken] = {}
+
+
+def get_s3_sts_token(
+    resource_id: str, file_name: PathLike, extra_arguments: Optional[Mapping[str, str]] = None
+) -> _S3STSToken:
+    """Get the S3 STS token for the given resource id and file name.
+
+    Parameters
+    ----------
+    resource_id : str
+        The resource id, e.g. task id.
+    file_name : PathLike
+        The remote file name on S3.
+    extra_arguments : Mapping[str, str]
+        Additional arguments for the query url.
+
+    Returns
+    -------
+    _S3STSToken
+        The S3 STS token.
+    """
+    file_name = str(Path(file_name).as_posix())
+    cache_key = f"{resource_id}:{file_name}"
+    if cache_key not in _s3_sts_tokens or _s3_sts_tokens[cache_key].is_expired():
+        method = f"tidy3d/py/tasks/{resource_id}/file?filename={file_name}"
+        if extra_arguments is not None:
+            method += "&" + "&".join(f"{k}={v}" for k, v in extra_arguments.items())
+        resp = http.get(method)
+        token = _S3STSToken.model_validate(resp)
+        _s3_sts_tokens[cache_key] = token
+    return _s3_sts_tokens[cache_key]
+
+
+def upload_file(
+    resource_id: str,
+    path: PathLike,
+    remote_filename: PathLike,
+    verbose: bool = True,
+    progress_callback: Optional[Callable[[float], None]] = None,
+    extra_arguments: Optional[Mapping[str, str]] = None,
+) -> None:
+    """Upload a file to S3.
+
+    Parameters
+    ----------
+    resource_id : str
+        The resource id, e.g. task id.
+    path : PathLike
+        Path to the file to upload.
+    remote_filename : PathLike
+        The remote file name on S3 relative to the resource context root path.
+    verbose : bool = True
+        Whether to display a progressbar for the upload.
+    progress_callback : Callable[[float], None] = None
+        User-supplied callback function with ``bytes_in_chunk`` as argument.
+    extra_arguments : Mapping[str, str]
+        Additional arguments used to specify the upload bucket.
+    """
+
+    path = Path(path)
+    token = get_s3_sts_token(resource_id, remote_filename, extra_arguments)
+
+    def _upload(_callback: Callable) -> None:
+        """Perform the upload with a callback function.
+
+        Parameters
+        ----------
+        _callback : Callable[[float], None]
+            Callback function for upload, accepts ``bytes_in_chunk``
+        """
+
+        with path.open("rb") as data:
+            token.get_client().upload_fileobj(
+                data,
+                Bucket=token.get_bucket(),
+                Key=token.get_s3_key(),
+                Callback=_callback,
+                Config=_s3_config,
+                ExtraArgs={"ContentEncoding": "gzip"}
+                if token.get_s3_key().endswith(".gz")
+                else None,
+            )
+
+    if progress_callback is not None:
+        _upload(progress_callback)
+    else:
+        if verbose:
+            with _get_progress(_S3Action.UPLOADING) as progress:
+                total_size = path.stat().st_size
+                task_id = progress.add_task(
+                    "upload", filename=str(remote_filename), total=total_size
+                )
+
+                def _callback(bytes_in_chunk: int) -> None:
+                    progress.update(task_id, advance=bytes_in_chunk)
+
+                _upload(_callback)
+
+                progress.update(task_id, completed=total_size, refresh=True)
+
+        else:
+            _upload(lambda bytes_in_chunk: None)
+
+
+def download_file(
+    resource_id: str,
+    remote_filename: PathLike,
+    to_file: Optional[PathLike] = None,
+    verbose: bool = True,
+    progress_callback: Optional[Callable[[float], None]] = None,
+) -> Path:
+    """Download a file from S3.
+
+    Parameters
+    ----------
+    resource_id : str
+        The resource id, e.g. task id.
+    remote_filename : PathLike
+        Path to the remote file.
+    to_file : PathLike = None
+        Local filename to save to; if not specified, defaults to ``remote_filename`` in a
+        directory named after ``resource_id``.
+    verbose : bool = True
+        Whether to display a progressbar for the download.
+    progress_callback : Callable[[float], None] = None
+        User-supplied callback function with ``bytes_in_chunk`` as argument.
+
+    Returns
+    -------
+    Path
+        Path to the saved local file.
+    """
+
+    token = get_s3_sts_token(resource_id, remote_filename)
+    client = token.get_client()
+    meta_data = client.head_object(Bucket=token.get_bucket(), Key=token.get_s3_key())
+
+    # Get only last part of the remote file name
+    remote_basename = Path(remote_filename).name
+
+    # set to_file if None
+    if to_file is None:
+        to_path = Path(resource_id) / remote_basename
+    else:
+        to_path = Path(to_file)
+
+    # make the leading directories in the 'to_path', if any
+    to_path.parent.mkdir(parents=True, exist_ok=True)
+
+    def _download(_callback: Callable) -> None:
+        """Perform the download with a callback function.
+
+        Parameters
+        ----------
+        _callback : Callable[[float], None]
+            Callback function for download, accepts ``bytes_in_chunk``
+        """
+        # Caller can assume the existence of the file means download succeeded.
+        # So make sure this file does not exist until that assumption is true.
+        to_path.unlink(missing_ok=True)
+        # Download to a temporary file first, so a failed transfer never leaves
+        # a partial file at the destination path.
+        fd, tmp_file_path_str = tempfile.mkstemp(suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent)
+        os.close(fd)  # `tempfile.mkstemp()` creates and opens a randomly named file; close it.
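+        # The temp file is created in the destination directory so that the
+        # rename below stays on one filesystem and is effectively atomic.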
+        to_path_tmp = Path(tmp_file_path_str)
+        try:
+            client.download_file(
+                Bucket=token.get_bucket(),
+                Filename=str(to_path_tmp),
+                Key=token.get_s3_key(),
+                Callback=_callback,
+                Config=_s3_config,
+            )
+            to_path_tmp.rename(to_path)
+        except Exception:
+            to_path_tmp.unlink(missing_ok=True)  # Delete incompletely downloaded file.
+            raise
+
+    if progress_callback is not None:
+        _download(progress_callback)
+    else:
+        if verbose:
+            with _get_progress(_S3Action.DOWNLOADING) as progress:
+                total_size = meta_data.get("ContentLength", 0)
+                progress.start()
+                task_id = progress.add_task("download", filename=remote_basename, total=total_size)
+
+                def _callback(bytes_in_chunk: int) -> None:
+                    progress.update(task_id, advance=bytes_in_chunk)
+
+                _download(_callback)
+
+                progress.update(task_id, completed=total_size, refresh=True)
+
+        else:
+            _download(lambda bytes_in_chunk: None)
+
+    return to_path
+
+
+def download_gz_file(
+    resource_id: str,
+    remote_filename: PathLike,
+    to_file: Optional[PathLike] = None,
+    verbose: bool = True,
+    progress_callback: Optional[Callable[[float], None]] = None,
+) -> Path:
+    """Download a ``.gz`` file and unzip it into ``to_file``, unless ``to_file`` itself
+    ends in ``.gz``.
+
+    Parameters
+    ----------
+    resource_id : str
+        The resource id, e.g. task id.
+    remote_filename : PathLike
+        Path to the remote file.
+    to_file : Optional[PathLike] = None
+        Local filename to save to; if not specified, defaults to ``remote_filename`` with the
+        ``.gz`` suffix removed, in a directory named after ``resource_id``.
+    verbose : bool = True
+        Whether to display a progressbar for the download.
+    progress_callback : Callable[[float], None] = None
+        User-supplied callback function with ``bytes_in_chunk`` as argument.
+    """
+
+    # If to_file already has a gzip extension, just download without extracting
+    if to_file is None:
+        remote_basename = Path(remote_filename).name
+        if remote_basename.endswith(".gz"):
+            remote_basename = remote_basename[:-3]
+        to_path = Path(resource_id) / remote_basename
+    else:
+        to_path = Path(to_file)
+
+    suffixes = "".join(to_path.suffixes).lower()
+    if suffixes.endswith(".gz"):
+        return download_file(
+            resource_id,
+            remote_filename,
+            to_file=to_path,
+            verbose=verbose,
+            progress_callback=progress_callback,
+        )
+
+    # Otherwise, download and unzip
+    # The tempfile is set as ``hdf5.gz`` so that the mock download in the webapi tests works
+    tmp_file, tmp_file_path_str = tempfile.mkstemp(".hdf5.gz")
+    os.close(tmp_file)
+
+    # make the leading directories in the 'to_file', if any
+    to_path.parent.mkdir(parents=True, exist_ok=True)
+    try:
+        download_file(
+            resource_id,
+            remote_filename,
+            to_file=Path(tmp_file_path_str),
+            verbose=verbose,
+            progress_callback=progress_callback,
+        )
+        if not Path(tmp_file_path_str).exists():
+            raise WebError(f"Failed to download and extract '{remote_filename}'.")
+
+        tmp_out_fd, tmp_out_path_str = tempfile.mkstemp(
+            suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent
+        )
+        os.close(tmp_out_fd)
+        tmp_out_path = Path(tmp_out_path_str)
+        try:
+            extract_gzip_file(Path(tmp_file_path_str), tmp_out_path)
+            tmp_out_path.replace(to_path)
+        except Exception as e:
+            tmp_out_path.unlink(missing_ok=True)
+            raise WebError(
+                f"Failed to extract '{remote_filename}' from '{tmp_file_path_str}' to '{to_path}'."
+            ) from e
+    finally:
+        Path(tmp_file_path_str).unlink(missing_ok=True)
+    return to_path
diff --git a/tidy3d/_common/web/core/stub.py b/tidy3d/_common/web/core/stub.py
new file mode 100644
index 0000000000..cebffd9ba0
--- /dev/null
+++ b/tidy3d/_common/web/core/stub.py
@@ -0,0 +1,84 @@
+"""Defines interface that can be subclassed to use with the tidy3d webapi"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from os import PathLike
+
+
+class TaskStubData(ABC):
+    @abstractmethod
+    def from_file(self, file_path: PathLike) -> TaskStubData:
+        """Load a :class:`TaskStubData` from a .yaml, .json, or .hdf5 file.
+
+        Parameters
+        ----------
+        file_path : PathLike
+            Full path to the .yaml, .json, or .hdf5 file to load the :class:`TaskStubData` from.
+
+        Returns
+        -------
+        :class:`TaskStubData`
+            An instance of the component class calling ``load``.
+        """
+
+    @abstractmethod
+    def to_file(self, file_path: PathLike) -> None:
+        """Save this :class:`TaskStubData` to a .yaml, .json, or .hdf5 file.
+
+        Parameters
+        ----------
+        file_path : PathLike
+            Full path to the .yaml, .json, or .hdf5 file to save the :class:`TaskStubData` to.
+        """
+
+
+class TaskStub(ABC):
+    @abstractmethod
+    def from_file(self, file_path: PathLike) -> TaskStub:
+        """Load a :class:`TaskStub` from a .yaml, .json, or .hdf5 file.
+
+        Parameters
+        ----------
+        file_path : PathLike
+            Full path to the .yaml, .json, or .hdf5 file to load the :class:`TaskStub` from.
+
+        Returns
+        -------
+        :class:`TaskStub`
+            An instance of the component class calling ``load``.
+        """
+
+    @abstractmethod
+    def to_file(self, file_path: PathLike) -> None:
+        """Save this :class:`TaskStub` to a .yaml, .json, .hdf5, or .hdf5.gz file.
+
+        Parameters
+        ----------
+        file_path : PathLike
+            Full path to the .yaml, .json, or .hdf5 file to save the :class:`TaskStub` to.
+        """
+
+    @abstractmethod
+    def to_hdf5_gz(self, fname: PathLike) -> None:
+        """Export this :class:`TaskStub` instance to a .hdf5.gz file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the .hdf5.gz file to save the :class:`TaskStub` to.
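+
+        Examples
+        --------
+        A minimal sketch for a concrete subclass (``stub`` is a hypothetical
+        implementation instance)::
+
+            stub.to_hdf5_gz("simulation.hdf5.gz")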
+        """
diff --git a/tidy3d/_common/web/core/task_core.py b/tidy3d/_common/web/core/task_core.py
new file mode 100644
index 0000000000..08ad9a1b28
--- /dev/null
+++ b/tidy3d/_common/web/core/task_core.py
@@ -0,0 +1,996 @@
+"""Tidy3d webapi types."""
+
+from __future__ import annotations
+
+import os
+import pathlib
+import tempfile
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from botocore.exceptions import ClientError
+from pydantic import Field, TypeAdapter
+
+from tidy3d._common.config import config
+from tidy3d._common.exceptions import ValidationError
+from tidy3d._common.log import log
+from tidy3d._common.web.core import http_util
+from tidy3d._common.web.core.cache import FOLDER_CACHE
+from tidy3d._common.web.core.constants import (
+    SIM_ERROR_FILE,
+    SIM_FILE_HDF5_GZ,
+    SIM_LOG_FILE,
+    SIM_VALIDATION_FILE,
+    SIMULATION_DATA_HDF5_GZ,
+)
+from tidy3d._common.web.core.core_config import get_logger_console
+from tidy3d._common.web.core.exceptions import WebError, WebNotFoundError
+from tidy3d._common.web.core.file_util import read_simulation_from_hdf5
+from tidy3d._common.web.core.http_util import get_version as _get_protocol_version
+from tidy3d._common.web.core.http_util import http
+from tidy3d._common.web.core.s3utils import download_file, download_gz_file, upload_file
+from tidy3d._common.web.core.task_info import BatchDetail, TaskInfo
+from tidy3d._common.web.core.types import (
+    PayType,
+    Queryable,
+    ResourceLifecycle,
+    Submittable,
+    Tidy3DResource,
+)
+
+if TYPE_CHECKING:
+    from os import PathLike
+    from typing import Callable, Union
+
+    import requests
+
+    from tidy3d._common.web.core.stub import TaskStub
+
+
+class Folder(Tidy3DResource, Queryable, extra="allow"):
+    """Tidy3D Folder."""
+
+    folder_id: str = Field(
+        title="Folder id",
+        description="folder id",
+        alias="projectId",
+    )
+    folder_name: str = Field(
+        title="Folder name",
+        description="folder name",
+        alias="projectName",
+    )
+
+    @classmethod
+    def list(cls, projects_endpoint: str = "tidy3d/projects") -> Optional[list[Folder]]:
+        """List all folders.
+
+        Returns
+        -------
+        folders : Optional[list[Folder]]
+            List of folders, or ``None`` if the request returned nothing.
+        """
+        resp = http.get(projects_endpoint)
+        return TypeAdapter(list[Folder]).validate_python(resp) if resp else None
+
+    @classmethod
+    def get(
+        cls,
+        folder_name: str,
+        create: bool = False,
+        projects_endpoint: str = "tidy3d/projects",
+        project_endpoint: str = "tidy3d/project",
+    ) -> Folder:
+        """Get folder by name.
+
+        Parameters
+        ----------
+        folder_name : str
+            Name of the folder.
+        create : bool
+            If the folder doesn't exist, create it.
+
+        Returns
+        -------
+        folder : Folder
+        """
+        folder = FOLDER_CACHE.get(folder_name)
+        if not folder:
+            resp = http.get(project_endpoint, params={"projectName": folder_name})
+            if resp:
+                folder = Folder(**resp)
+        if create and not folder:
+            resp = http.post(projects_endpoint, {"projectName": folder_name})
+            if resp:
+                folder = Folder(**resp)
+        FOLDER_CACHE[folder_name] = folder
+        return folder
+
+    @classmethod
+    def create(cls, folder_name: str) -> Folder:
+        """Create a folder; return the existing folder if one with the same name already exists.
+
+        Parameters
+        ----------
+        folder_name : str
+            Name of the folder.
+
+        Returns
+        -------
+        folder : Folder
+        """
+        return Folder.get(folder_name, True)
+
+    def delete(self, projects_endpoint: str = "tidy3d/projects") -> None:
+        """Remove this folder."""
+
+        http.delete(f"{projects_endpoint}/{self.folder_id}")
+
+    def delete_old(self, days_old: int) -> int:
+        """Remove folder contents older than ``days_old`` days."""
+
+        return http.delete(
+            f"tidy3d/tasks/{self.folder_id}/tasks",
+            params={"daysOld": days_old},
+        )
+
+    def list_tasks(self, projects_endpoint: str = "tidy3d/projects") -> list[Tidy3DResource]:
+        """List all tasks in this folder.
+
+        Returns
+        -------
+        tasks : list[:class:`.SimulationTask`]
+            List of tasks in this folder.
+        """
+        resp = http.get(f"{projects_endpoint}/{self.folder_id}/tasks")
+        return TypeAdapter(list[SimulationTask]).validate_python(resp) if resp else None
+
+
+class WebTask(ResourceLifecycle, Submittable, extra="allow"):
+    """Interface for managing the running of a task on the server."""
+
+    task_id: Optional[str] = Field(
+        None,
+        title="task_id",
+        description="Task ID number, set when the task is uploaded, leave as None.",
+        alias="taskId",
+    )
+
+    @classmethod
+    def create(
+        cls,
+        task_type: str,
+        task_name: str,
+        folder_name: str = "default",
+        callback_url: Optional[str] = None,
+        simulation_type: str = "tidy3d",
+        parent_tasks: Optional[list[str]] = None,
+        file_type: str = "Gz",
+        projects_endpoint: str = "tidy3d/projects",
+    ) -> SimulationTask:
+        """Create a new task on the server.
+
+        Parameters
+        ----------
+        task_type : str
+            The type of task.
+        task_name : str
+            The name of the task.
+        folder_name : str
+            The name of the folder to store the task. Default is "default".
+        callback_url : str
+            Http PUT url to receive simulation finish event. The body content is a json file with
+            fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
+        simulation_type : str
+            Type of simulation being uploaded.
+        parent_tasks : list[str]
+            List of related task ids.
+        file_type : str
+            The simulation file type: ``Json``, ``Hdf5``, or ``Gz``.
+
+        Returns
+        -------
+        :class:`SimulationTask`
+            :class:`SimulationTask` object containing info about status, size,
+            credits of task and others.
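+
+        Examples
+        --------
+        A minimal sketch (``"FDTD"`` is only an illustrative task type)::
+
+            task = SimulationTask.create(
+                task_type="FDTD",
+                task_name="my_simulation",
+                folder_name="default",
+            )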
+ """ + + # handle backwards compatibility, "tidy3d" is the default simulation_type + if simulation_type is None: + simulation_type = "tidy3d" + + folder = Folder.get(folder_name, create=True) + if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: + payload = { + "groupName": task_name, + "folderId": folder.folder_id, + "fileType": file_type, + "taskType": task_type, + } + resp = http.post("rf/task", payload) + else: + payload = { + "taskName": task_name, + "taskType": task_type, + "callbackUrl": callback_url, # type: ignore[dict-item] + "simulationType": simulation_type, + "parentTasks": parent_tasks, # type: ignore[dict-item] + "fileType": file_type, + } + resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) + return SimulationTask(**resp, taskType=task_type, folder_name=folder_name) + + def get_url(self) -> str: + base = str(config.web.website_endpoint or "") + if isinstance(self, BatchTask): + return "/".join([base.rstrip("/"), f"rf?taskId={self.task_id}"]) + return "/".join([base.rstrip("/"), f"workbench?taskId={self.task_id}"]) + + def get_folder_url(self) -> Optional[str]: + folder_id = getattr(self, "folder_id", None) + if not folder_id: + return None + base = str(config.web.website_endpoint or "") + return "/".join([base.rstrip("/"), f"folders/{folder_id}"]) + + def get_log( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Get log file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_file( + self.task_id, + SIM_LOG_FILE, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_data_hdf5( + self, + to_file: PathLike, + remote_data_file_gz: PathLike = SIMULATION_DATA_HDF5_GZ, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Download data artifact (simulation or batch) with gz fallback handling. + + Parameters + ---------- + remote_data_file_gz : PathLike + Gzipped remote filename. + to_file : PathLike + Local target path. + verbose : bool + Whether to log progress. + progress_callback : Optional[Callable[[float], None]] + Progress callback. + + Returns + ------- + pathlib.Path + Saved local path. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + target_path = pathlib.Path(to_file) + file = None + try: + file = download_gz_file( + resource_id=self.task_id, + remote_filename=remote_data_file_gz, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except ClientError: + if verbose: + console = get_logger_console() + console.log(f"Unable to download '{remote_data_file_gz}'.") + if not file: + try: + file = download_file( + resource_id=self.task_id, + remote_filename=str(remote_data_file_gz)[:-3], + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except Exception as e: + raise WebError( + "Failed to download the data file from the server. " + "Please confirm that the task completed successfully." 
+ ) from e + return file + + @staticmethod + def is_batch(resource_id: str) -> bool: + """Checks if a given resource ID corresponds to a valid batch task. + + This is a utility function to verify a batch task's existence before + instantiating the class. + + Parameters + ---------- + resource_id : str + The unique identifier for the resource. + + Returns + ------- + bool + ``True`` if the resource is a valid batch task, ``False`` otherwise. + """ + try: + # TODO PROPERLY FIXME + # Disable non critical logs due to check for resourceId, until we have a dedicated API for this + resp = http.get( + f"rf/task/{resource_id}/statistics", + suppress_404=True, + ) + status = bool(resp and isinstance(resp, dict) and "status" in resp) + return status + except Exception: + return False + + def delete(self, versions: bool = False) -> None: + """Delete current task from server. + + Parameters + ---------- + versions : bool = False + If ``True``, delete all versions of the task in the task group. Otherwise, delete only + the version associated with the current task ID. + """ + if not self.task_id: + raise ValueError("Task id not found.") + + task_details = self.detail().model_dump() + + if task_details and "groupId" in task_details: + group_id = task_details["groupId"] + if versions: + http.delete("tidy3d/group", json={"groupIds": [group_id]}) + return + elif "version" in task_details: + version = task_details["version"] + http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) + return + + # Fallback to old method if we can't get the groupId and version + http.delete(f"tidy3d/tasks/{self.task_id}") + + +class SimulationTask(WebTask): + """Interface for managing the running of solver tasks on the server.""" + + folder_id: Optional[str] = Field( + None, + title="folder_id", + description="Folder ID number, set when the task is uploaded, leave as None.", + alias="folderId", + ) + status: Optional[str] = Field(None, title="status", description="Simulation task status.") + + real_flex_unit: Optional[float] = Field( + None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" + ) + + created_at: Optional[datetime] = Field( + None, + title="created_at", + description="Time at which this task was created.", + alias="createdAt", + ) + + task_type: Optional[str] = Field( + None, title="task_type", description="The type of task.", alias="taskType" + ) + + folder_name: Optional[str] = Field( + "default", + title="Folder Name", + description="Name of the folder associated with this task.", + alias="folderName", + ) + + callback_url: Optional[str] = Field( + None, + title="Callback URL", + description="Http PUT url to receive simulation finish event. " + "The body content is a json file with fields " + "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", + ) + + # simulation_type: str = pd.Field( + # None, + # title="Simulation Type", + # description="Type of simulation, used internally only.", + # ) + + # parent_tasks: Tuple[TaskId, ...] = pd.Field( + # None, + # title="Parent Tasks", + # description="List of parent task ids for the simulation, used internally only." + # ) + + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: + """Get task from the server by id. + + Parameters + ---------- + task_id: str + Unique identifier of task on server. + verbose: + If `True`, will print progressbars and status, otherwise, will run silently. 
+    @classmethod
+    def get(cls, task_id: str, verbose: bool = True) -> SimulationTask:
+        """Get task from the server by id.
+
+        Parameters
+        ----------
+        task_id: str
+            Unique identifier of task on server.
+        verbose: bool = True
+            If ``True``, will print progress bars and status; otherwise, will run silently.
+
+        Returns
+        -------
+        :class:`.SimulationTask`
+            :class:`.SimulationTask` object containing info about status,
+            size, credits of task and others.
+        """
+        try:
+            resp = http.get(f"tidy3d/tasks/{task_id}/detail")
+        except WebNotFoundError as e:
+            log.error(f"The requested task ID '{task_id}' does not exist.")
+            raise e
+
+        task = SimulationTask(**resp) if resp else None
+        return task
+
+    @classmethod
+    def get_running_tasks(cls) -> list[SimulationTask]:
+        """Get a list of running tasks from the server.
+
+        Returns
+        -------
+        List[:class:`.SimulationTask`]
+            :class:`.SimulationTask` objects containing info about status,
+            size, credits of each task and others.
+        """
+        resp = http.get("tidy3d/py/tasks")
+        if not resp:
+            return []
+        return TypeAdapter(list[SimulationTask]).validate_python(resp)
+
+    def detail(self) -> TaskInfo:
+        """Fetches the detailed information and status of the task.
+
+        Returns
+        -------
+        TaskInfo
+            An object containing the task's latest data.
+        """
+        resp = http.get(f"tidy3d/tasks/{self.task_id}/detail")
+        return TaskInfo(**{"taskId": self.task_id, "taskType": self.task_type, **resp})  # type: ignore[dict-item]
+
+    def get_simulation_json(self, to_file: PathLike, verbose: bool = True) -> None:
+        """Get json file for a :class:`.Simulation` from server.
+
+        Parameters
+        ----------
+        to_file: PathLike
+            Save file to path.
+        verbose: bool = True
+            Whether to display progress bars.
+        """
+        if not self.task_id:
+            raise WebError("Expected field 'task_id' is unset.")
+
+        to_file = pathlib.Path(to_file)
+
+        hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5")
+        os.close(hdf5_file)
+        try:
+            self.get_simulation_hdf5(hdf5_file_path)
+            if os.path.exists(hdf5_file_path):
+                json_string = read_simulation_from_hdf5(hdf5_file_path)
+                to_file.parent.mkdir(parents=True, exist_ok=True)
+                with to_file.open("w", encoding="utf-8") as file:
+                    # Write the string to the file
+                    file.write(json_string.decode("utf-8"))
+                if verbose:
+                    console = get_logger_console()
+                    console.log(f"Generated {to_file} successfully.")
+            else:
+                raise WebError("Failed to download simulation.json.")
+        finally:
+            os.unlink(hdf5_file_path)
+
+    def upload_simulation(
+        self,
+        stub: TaskStub,
+        verbose: bool = True,
+        progress_callback: Optional[Callable[[float], None]] = None,
+        remote_sim_file: PathLike = SIM_FILE_HDF5_GZ,
+    ) -> None:
+        """Upload :class:`.Simulation` object to Server.
+
+        Parameters
+        ----------
+        stub: :class:`TaskStub`
+            An instance of TaskStub.
+        verbose: bool = True
+            Whether to display progress bars.
+        progress_callback : Callable[[float], None] = None
+            Optional callback function called while uploading the data.
+        remote_sim_file : PathLike = SIM_FILE_HDF5_GZ
+            Remote filename under which the gzipped simulation is stored.
+        """
+        if not self.task_id:
+            raise WebError("Expected field 'task_id' is unset.")
+        if not stub:
+            raise WebError("Expected field 'simulation' is unset.")
+        # Also upload hdf5.gz containing all data.
+        file, file_name = tempfile.mkstemp()
+        os.close(file)
+        try:
+            # upload simulation: compress .hdf5 to .hdf5.gz
+            stub.to_hdf5_gz(file_name)
+            upload_file(
+                self.task_id,
+                file_name,
+                remote_sim_file,
+                verbose=verbose,
+                progress_callback=progress_callback,
+            )
+        finally:
+            os.unlink(file_name)
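+
+    # End-to-end upload sketch (hypothetical names; assumes ``sim_stub``
+    # implements ``TaskStub.to_hdf5_gz`` as used above):
+    #
+    #     task = SimulationTask.get("task-id-123")
+    #     task.upload_simulation(stub=sim_stub, verbose=False)
+    #     task.submit()
+    #
+    # The stub is serialized to a temporary .hdf5.gz file and streamed to the
+    # server; submission only starts the solver once the upload is in place.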
+    def upload_file(
+        self,
+        local_file: PathLike,
+        remote_filename: str,
+        verbose: bool = True,
+        progress_callback: Optional[Callable[[float], None]] = None,
+    ) -> None:
+        """Upload a file to the platform.
+
+        Use this method when the json file is too large to parse as a
+        :class:`.Simulation`.
+
+        Parameters
+        ----------
+        local_file: PathLike
+            Local file path.
+        remote_filename: str
+            File name on the server.
+        verbose: bool = True
+            Whether to display progress bars.
+        progress_callback : Callable[[float], None] = None
+            Optional callback function called while uploading the data.
+        """
+        if not self.task_id:
+            raise WebError("Expected field 'task_id' is unset.")
+
+        upload_file(
+            self.task_id,
+            local_file,
+            remote_filename,
+            verbose=verbose,
+            progress_callback=progress_callback,
+        )
+
+    def submit(
+        self,
+        solver_version: Optional[str] = None,
+        worker_group: Optional[str] = None,
+        pay_type: Union[PayType, str] = PayType.AUTO,
+        priority: Optional[int] = None,
+    ) -> None:
+        """Kick off this task.
+
+        This method assumes that the simulation has already been uploaded,
+        either via ``upload_simulation`` or ``upload_file``, so the task is
+        kicked off directly.
+
+        Parameters
+        ----------
+        solver_version: str = None
+            Target solver version.
+        worker_group: str = None
+            Worker group.
+        pay_type: Union[PayType, str] = PayType.AUTO
+            Payment method for the simulation.
+        priority: int = None
+            Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
+            It affects only simulations from vGPU licenses and does not impact simulations using FlexCredits.
+        """
+        pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type
+
+        if solver_version:
+            protocol_version = None
+        else:
+            protocol_version = http_util.get_version()
+
+        http.post(
+            f"tidy3d/tasks/{self.task_id}/submit",
+            {
+                "solverVersion": solver_version,
+                "workerGroup": worker_group,
+                "protocolVersion": protocol_version,
+                "enableCaching": config.web.enable_caching,
+                "payType": pay_type.value,
+                "priority": priority,
+            },
+        )
+
+    def estimate_cost(self, solver_version: Optional[str] = None) -> float:
+        """Compute the maximum flex unit charge for a given task, assuming the simulation runs for
+        the full ``run_time``. If early shut-off is triggered, the cost is adjusted proportionately.
+
+        Parameters
+        ----------
+        solver_version: str
+            Target solver version.
+
+        Returns
+        -------
+        flex_unit_cost: float
+            Estimated cost in FlexCredits.
+        """
+        if not self.task_id:
+            raise WebError("Expected field 'task_id' is unset.")
+
+        if solver_version:
+            protocol_version = None
+        else:
+            protocol_version = http_util.get_version()
+
+        resp = http.post(
+            f"tidy3d/tasks/{self.task_id}/metadata",
+            {
+                "solverVersion": solver_version,
+                "protocolVersion": protocol_version,
+            },
+        )
+        return resp
+
+    def get_simulation_hdf5(
+        self,
+        to_file: PathLike,
+        verbose: bool = True,
+        progress_callback: Optional[Callable[[float], None]] = None,
+        remote_sim_file: PathLike = SIM_FILE_HDF5_GZ,
+    ) -> pathlib.Path:
+        """Get simulation.hdf5 file from Server.
+
+        Parameters
+        ----------
+        to_file: PathLike
+            Save file to path.
+        verbose: bool = True
+            Whether to display progress bars.
+        progress_callback : Callable[[float], None] = None
+            Optional callback function called while downloading the data.
+        remote_sim_file : PathLike = SIM_FILE_HDF5_GZ
+            Remote filename of the gzipped simulation to download.
+
+        Returns
+        -------
+        path: pathlib.Path
+            Path to saved file.
+        """
+        if not self.task_id:
+            raise WebError("Expected field 'task_id' is unset.")
+
+        target_path = pathlib.Path(to_file)
+
+        return download_gz_file(
+            resource_id=self.task_id,
+            remote_filename=remote_sim_file,
+            to_file=target_path,
+            verbose=verbose,
+            progress_callback=progress_callback,
+        )
+
+    def get_running_info(self) -> tuple[float, float]:
+        """Gets the % done and field_decay for a running task.
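+
+        A minimal polling sketch (illustrative only; assumes the task is
+        currently running and the progress endpoint returns both fields)::
+
+            perc_done, field_decay = task.get_running_info()
+            if perc_done is not None:
+                print(f"{perc_done:.1f}% done, field decay {field_decay:.2e}")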
+ + Returns + ------- + perc_done : float + Percentage of run done (in terms of max number of time steps). + Is ``None`` if run info not available. + field_decay : float + Average field intensity normalized to max value (1.0). + Is ``None`` if run info not available. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + resp = http.get(f"tidy3d/tasks/{self.task_id}/progress") + perc_done = resp.get("perc_done") + field_decay = resp.get("field_decay") + return perc_done, field_decay + + def get_log( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Get log file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_file( + self.task_id, + SIM_LOG_FILE, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_error_json( + self, to_file: PathLike, verbose: bool = True, validation: bool = False + ) -> pathlib.Path: + """Get error json file for a :class:`.Simulation` from server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + validation: bool = False + Whether to get a validation error file or a solver error file. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + target_file = SIM_ERROR_FILE if not validation else SIM_VALIDATION_FILE + + return download_file( + self.task_id, + target_file, + to_file=target_path, + verbose=verbose, + ) + + def abort(self) -> requests.Response: + """Abort the current task on the server.""" + if not self.task_id: + raise ValueError("Task id not found.") + return http.put( + "tidy3d/tasks/abort", json={"taskType": self.task_type, "taskId": self.task_id} + ) + + def validate_post_upload(self, parent_tasks: Optional[list[str]] = None) -> None: + """Perform checks after task is uploaded and metadata is processed.""" + if self.task_type == "HEAT_CHARGE" and parent_tasks: + try: + if len(parent_tasks) > 1: + raise ValueError( + "A single parent 'task_id' corresponding to the task in which the meshing " + "was run must be provided." + ) + try: + # get mesh task info + mesh_task = SimulationTask.get(parent_tasks[0], verbose=False) + assert mesh_task.task_type == "VOLUME_MESH" + assert mesh_task.status == "success" + # get up-to-date task info + task = SimulationTask.get(self.task_id, verbose=False) + if task.fileMd5 != mesh_task.childFileMd5: + raise ValidationError( + "Simulation stored in parent task 'VolumeMesher' does not match the " + "current simulation." + ) + except Exception as e: + raise ValidationError( + "The parent task must be a 'VolumeMesher' task which has been successfully " + "run and is associated to the same 'HeatChargeSimulation' as provided here." 
+                    ) from e
+
+            except Exception as e:
+                raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e
+
+
+class BatchTask(WebTask):
+    """Interface for managing a batch task on the server."""
+
+    task_type: Optional[str] = Field(
+        None, title="task_type", description="The type of task.", alias="taskType"
+    )
+
+    @classmethod
+    def get(cls, task_id: str, verbose: bool = True) -> BatchTask:
+        """Get batch task by id.
+
+        Parameters
+        ----------
+        task_id: str
+            Unique identifier of batch on server.
+        verbose: bool = True
+            If ``True``, will print progress bars and status; otherwise, will run silently.
+
+        Returns
+        -------
+        :class:`.BatchTask` | None
+            BatchTask object if found, otherwise None.
+        """
+        try:
+            resp = http.get(f"rf/task/{task_id}/statistics")
+        except WebNotFoundError as e:
+            log.error(f"The requested batch ID '{task_id}' does not exist.")
+            raise e
+        # Extract taskType from response if available
+        if resp:
+            task_type = resp.get("taskType") if isinstance(resp, dict) else None
+            return BatchTask(taskId=task_id, taskType=task_type)
+        return None
+
+    def detail(self) -> BatchDetail:
+        """Fetches the detailed information and status of the batch.
+
+        Returns
+        -------
+        BatchDetail
+            An object containing the batch's latest data.
+        """
+        resp = http.get(
+            f"rf/task/{self.task_id}/statistics",
+        )
+        # Some backends may return null for collection fields; coerce to sensible defaults
+        if isinstance(resp, dict):
+            if resp.get("tasks") is None:
+                resp["tasks"] = []
+        return BatchDetail(**(resp or {}))
+
+    def check(
+        self,
+        check_task_type: str,
+        solver_version: Optional[str] = None,
+        protocol_version: Optional[str] = None,
+    ) -> requests.Response:
+        """Submits a request to validate the batch configuration on the server.
+
+        Parameters
+        ----------
+        check_task_type : str
+            The task type to validate the batch against.
+        solver_version : Optional[str], default=None
+            The version of the solver to use for validation.
+        protocol_version : Optional[str], default=None
+            The data protocol version. Defaults to the current version.
+
+        Returns
+        -------
+        requests.Response
+            The server's response to the check request.
+        """
+        if protocol_version is None:
+            protocol_version = _get_protocol_version()
+        return http.post(
+            f"rf/task/{self.task_id}/check",
+            {
+                "solverVersion": solver_version,
+                "protocolVersion": protocol_version,
+                "taskType": check_task_type,
+            },
+        )
+
+    def submit(
+        self,
+        solver_version: Optional[str] = None,
+        protocol_version: Optional[str] = None,
+        worker_group: Optional[str] = None,
+        pay_type: Union[PayType, str] = PayType.AUTO,
+        priority: Optional[int] = None,
+    ) -> requests.Response:
+        """Submits the batch for execution on the server.
+
+        Parameters
+        ----------
+        solver_version : Optional[str], default=None
+            The version of the solver to use for execution.
+        protocol_version : Optional[str], default=None
+            The data protocol version. Defaults to the current version.
+        worker_group : Optional[str], default=None
+            Optional identifier for a specific worker group to run on.
+        pay_type : Union[PayType, str], default=PayType.AUTO
+            Payment method; only ``PayType.AUTO`` is currently supported.
+        priority : Optional[int], default=None
+            Not yet supported for batch tasks.
+
+        Returns
+        -------
+        requests.Response
+            The server's response to the submit request.
+        """
+
+        # TODO: add support for pay_type and priority arguments
+        if pay_type != PayType.AUTO:
+            raise NotImplementedError(
+                "The 'pay_type' argument is not yet supported for batch tasks."
+            )
+        if priority is not None:
+            raise NotImplementedError(
+                "The 'priority' argument is not yet supported for batch tasks."
+            )
+
+        if protocol_version is None:
+            protocol_version = _get_protocol_version()
+        return http.post(
+            f"rf/task/{self.task_id}/submit",
+            {
+                "solverVersion": solver_version,
+                "protocolVersion": protocol_version,
+                "workerGroup": worker_group,
+            },
+        )
+
+    def abort(self) -> requests.Response:
+        """Abort the current task on the server."""
+        if not self.task_id:
+            raise ValueError("Batch id not found.")
+        return http.put(f"rf/task/{self.task_id}/abort", {})
+
+
+class TaskFactory:
+    """Factory for obtaining the correct task subclass."""
+
+    _REGISTRY: dict[str, str] = {}
+
+    @classmethod
+    def reset(cls) -> None:
+        """Clear the cached task kind registry (used in tests)."""
+        cls._REGISTRY.clear()
+
+    @classmethod
+    def register(cls, task_id: str, kind: str) -> None:
+        """Cache the resolved kind ("batch" or "simulation") for a task ID."""
+        cls._REGISTRY[task_id] = kind
+
+    @classmethod
+    def get(cls, task_id: str, verbose: bool = True) -> WebTask:
+        """Resolve a task ID to the appropriate task subclass, caching the kind."""
+        kind = cls._REGISTRY.get(task_id)
+        if kind == "batch":
+            return BatchTask.get(task_id, verbose=verbose)
+        if kind == "simulation":
+            task = SimulationTask.get(task_id, verbose=verbose)
+            return task
+        if WebTask.is_batch(task_id):
+            cls.register(task_id, "batch")
+            return BatchTask.get(task_id, verbose=verbose)
+        task = SimulationTask.get(task_id, verbose=verbose)
+        if task:
+            cls.register(task_id, "simulation")
+        return task
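+
+# Resolution sketch for the factory above (hypothetical ID; the registry is a
+# process-local cache, so only the first lookup pays the extra HTTP probe):
+#
+#     task = TaskFactory.get("some-task-id")
+#     # -> BatchTask if the "rf/task/{id}/statistics" probe succeeds,
+#     #    otherwise SimulationTask; the resolved kind is cached for reuse.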
diff --git a/tidy3d/_common/web/core/task_info.py b/tidy3d/_common/web/core/task_info.py
new file mode 100644
index 0000000000..c42ba0f220
--- /dev/null
+++ b/tidy3d/_common/web/core/task_info.py
@@ -0,0 +1,328 @@
+"""Defines information about a task."""
+
+from __future__ import annotations
+
+from abc import ABC
+from datetime import datetime
+from enum import Enum
+from typing import Annotated, Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class TaskBase(BaseModel, ABC):
+    """Base configuration for all task objects."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class ChargeType(str, Enum):
+    """The payment method of the task."""
+
+    FREE = "free"
+    """No payment required."""
+
+    PAID = "paid"
+    """Payment required."""
+
+
+class TaskBlockInfo(TaskBase):
+    """Information about the task's block status.
+
+    Notes
+    -----
+    This includes details about how the task can be blocked by various features
+    such as user limits and insufficient balance.
+ """ + + chargeType: Optional[ChargeType] = None + """The type of charge applicable to the task (free or paid).""" + + maxFreeCount: Optional[int] = None + """The maximum number of free tasks allowed.""" + + maxGridPoints: Optional[int] = None + """The maximum number of grid points permitted.""" + + maxTimeSteps: Optional[int] = None + """The maximum number of time steps allowed.""" + + +class TaskInfo(TaskBase): + """General information about a task.""" + + taskId: str + """Unique identifier for the task.""" + + taskName: Optional[str] = None + """Name of the task.""" + + nodeSize: Optional[int] = None + """Size of the node allocated for the task.""" + + completedAt: Optional[datetime] = None + """Timestamp when the task was completed.""" + + status: Optional[str] = None + """Current status of the task.""" + + realCost: Optional[float] = None + """Actual cost incurred by the task.""" + + timeSteps: Optional[int] = None + """Number of time steps involved in the task.""" + + solverVersion: Optional[str] = None + """Version of the solver used for the task.""" + + createAt: Optional[datetime] = None + """Timestamp when the task was created.""" + + estCostMin: Optional[float] = None + """Estimated minimum cost for the task.""" + + estCostMax: Optional[float] = None + """Estimated maximum cost for the task.""" + + realFlexUnit: Optional[float] = None + """Actual flexible units used by the task.""" + + oriRealFlexUnit: Optional[float] = None + """Original real flexible units.""" + + estFlexUnit: Optional[float] = None + """Estimated flexible units for the task.""" + + estFlexCreditTimeStepping: Optional[float] = None + """Estimated flexible credits for time stepping.""" + + estFlexCreditPostProcess: Optional[float] = None + """Estimated flexible credits for post-processing.""" + + estFlexCreditMode: Optional[float] = None + """Estimated flexible credits based on the mode.""" + + s3Storage: Optional[float] = None + """Amount of S3 storage used by the task.""" + + startSolverTime: Optional[datetime] = None + """Timestamp when the solver started.""" + + finishSolverTime: Optional[datetime] = None + """Timestamp when the solver finished.""" + + totalSolverTime: Optional[int] = None + """Total time taken by the solver.""" + + callbackUrl: Optional[str] = None + """Callback URL for task notifications.""" + + taskType: Optional[str] = None + """Type of the task.""" + + metadataStatus: Optional[str] = None + """Status of the metadata for the task.""" + + taskBlockInfo: Optional[TaskBlockInfo] = None + """Blocking information for the task.""" + + version: Optional[str] = None + """Version of the task.""" + + +class RunInfo(TaskBase): + """Information about the run of a task.""" + + perc_done: Annotated[float, Field(ge=0.0, le=100.0)] + """Percentage of the task that is completed (0 to 100).""" + + field_decay: Annotated[float, Field(ge=0.0, le=1.0)] + """Field decay from the maximum value (0 to 1).""" + + def display(self) -> None: + """Print some info about the task's progress.""" + print(f" - {self.perc_done:.2f} (%) done") + print(f" - {self.field_decay:.2e} field decay from max") + + +# ---------------------- Batch (Modeler) detail schema ---------------------- # + + +class BatchTaskBlockInfo(TaskBlockInfo): + """ + Extends `TaskBlockInfo` with specific details for batch task blocking. + + Attributes: + accountLimit: A usage or cost limit imposed by the user's account. + taskBlockMsg: A human-readable message describing the reason for the block. 
+ taskBlockType: The specific type of block (e.g., 'balance', 'limit'). + blockStatus: The current blocking status for the batch. + taskStatus: The status of the task when it was blocked. + """ + + accountLimit: Optional[float] = None + taskBlockMsg: Optional[str] = None + taskBlockType: Optional[str] = None + blockStatus: Optional[str] = None + taskStatus: Optional[str] = None + + +class BatchMember(TaskBase): + """ + Represents a single task within a larger batch operation. + + Attributes: + refId: A reference identifier for the member task. + folderId: The identifier of the folder containing the task. + sweepId: The identifier for the parameter sweep, if applicable. + taskId: The unique identifier of the task. + linkedTaskId: The identifier of a task linked to this one. + groupId: The identifier of the group this task belongs to. + taskName: The name of the individual task. + status: The current status of this specific task. + sweepData: Data associated with a parameter sweep. + validateInfo: Information related to the task's validation. + replaceData: Data used for replacements or modifications. + protocolVersion: The version of the protocol used. + variable: The variable parameter for this task in a sweep. + createdAt: The timestamp when the member task was created. + updatedAt: The timestamp when the member task was last updated. + denormalizeStatus: The status of the data denormalization process. + summary: A dictionary containing summary information for the task. + """ + + refId: Optional[str] = None + folderId: Optional[str] = None + sweepId: Optional[str] = None + taskId: Optional[str] = None + linkedTaskId: Optional[str] = None + groupId: Optional[str] = None + taskName: Optional[str] = None + status: Optional[str] = None + sweepData: Optional[str] = None + validateInfo: Optional[str] = None + replaceData: Optional[str] = None + protocolVersion: Optional[str] = None + variable: Optional[str] = None + createdAt: Optional[datetime] = None + updatedAt: Optional[datetime] = None + denormalizeStatus: Optional[str] = None + summary: Optional[dict] = None + + +class BatchDetail(TaskBase): + """Provides a detailed, top-level view of a batch of tasks. + + Notes + ----- + This model serves as the main payload for retrieving comprehensive + information about a batch operation. + + Attributes + ---------- + refId + A reference identifier for the entire batch. + optimizationId + Identifier for the optimization process, if any. + groupId + Identifier for the group the batch belongs to. + name + The user-defined name of the batch. + status + The current status of the batch. + totalTask + The total number of tasks in the batch. + preprocessSuccess + The count of tasks that completed preprocessing. + postprocessStatus + The status of the batch's postprocessing stage. + validateSuccess + The count of tasks that passed validation. + runSuccess + The count of tasks that ran successfully. + postprocessSuccess + The count of tasks that completed postprocessing. + taskBlockInfo + Information on what might be blocking the batch. + estFlexUnit + The estimated total flexible compute units for the batch. + totalSeconds + The total time in seconds the batch has taken. + totalCheckMillis + Total time in milliseconds spent on checks. + message + A general message providing information about the batch status. + tasks + A list of `BatchMember` objects, one for each task in the batch. + taskType + The type of tasks contained in the batch. 
+    """
+
+    refId: Optional[str] = None
+    optimizationId: Optional[str] = None
+    groupId: Optional[str] = None
+    name: Optional[str] = None
+    status: Optional[str] = None
+    totalTask: int = 0
+    preprocessSuccess: int = 0
+    postprocessStatus: Optional[str] = None
+    validateSuccess: int = 0
+    runSuccess: int = 0
+    postprocessSuccess: int = 0
+    taskBlockInfo: Optional[BatchTaskBlockInfo] = None
+    estFlexUnit: Optional[float] = None
+    realFlexUnit: Optional[float] = None
+    totalSeconds: Optional[int] = None
+    totalCheckMillis: Optional[int] = None
+    message: Optional[str] = None
+    tasks: list[BatchMember] = []
+    validateErrors: Optional[dict] = None
+    taskType: Optional[str] = None
+    version: Optional[str] = None
+
+
+class AsyncJobDetail(TaskBase):
+    """Provides a detailed view of an asynchronous job and its sub-tasks.
+
+    Notes
+    -----
+    This model represents a long-running operation. The 'result' attribute holds
+    the output of a completed job, which, for orchestration jobs, is often a
+    JSON string mapping sub-task names to their unique IDs.
+
+    Attributes
+    ----------
+    asyncId
+        The unique identifier for the asynchronous job.
+    status
+        The current overall status of the job (e.g., 'RUNNING', 'COMPLETED').
+    progress
+        The completion percentage of the job (from 0.0 to 100.0).
+    createdAt
+        The timestamp when the job was created.
+    completedAt
+        The timestamp when the job finished (successfully or not).
+    tasks
+        A dictionary mapping logical task keys to their unique task IDs.
+        This is often populated by parsing the 'result' of an orchestration task.
+    result
+        The raw string output of the completed job. If the job spawns other
+        tasks, this is expected to be a JSON string detailing those tasks.
+    taskBlockInfo
+        Information on any dependencies blocking the job from running.
+    message
+        A human-readable message about the job's status.
+    """
+
+    asyncId: str
+    status: str
+    progress: Optional[float] = None
+    createdAt: Optional[datetime] = None
+    completedAt: Optional[datetime] = None
+    tasks: Optional[dict[str, str]] = None
+    result: Optional[str] = None
+    taskBlockInfo: Optional[TaskBlockInfo] = None
+    message: Optional[str] = None
+
+
+AsyncJobDetail.model_rebuild()
diff --git a/tidy3d/_common/web/core/types.py b/tidy3d/_common/web/core/types.py
new file mode 100644
index 0000000000..aaac18612a
--- /dev/null
+++ b/tidy3d/_common/web/core/types.py
@@ -0,0 +1,73 @@
+"""Tidy3d abstraction types for the core."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any
+
+from pydantic import BaseModel
+
+
+class Tidy3DResource(BaseModel, ABC):
+    """Abstract base class / template for a webservice that implements resource query."""
+
+    @classmethod
+    @abstractmethod
+    def get(cls, *args: Any, **kwargs: Any) -> Tidy3DResource:
+        """Get a resource from the server."""
+
+
+class ResourceLifecycle(Tidy3DResource, ABC):
+    """Abstract base class for a webservice that implements resource life cycle management."""
+
+    @classmethod
+    @abstractmethod
+    def create(cls, *args: Any, **kwargs: Any) -> Tidy3DResource:
+        """Create a new resource and return it."""
+
+    @abstractmethod
+    def delete(self, *args: Any, **kwargs: Any) -> None:
+        """Delete the resource."""
+
+
+class Submittable(BaseModel, ABC):
+    """Abstract base class / template for a webservice that implements a submit method."""
+
+    @abstractmethod
+    def submit(self, *args: Any, **kwargs: Any) -> None:
+        """Submit the task to the webservice."""
+
+
+class Queryable(BaseModel, ABC):
+    """Abstract base class / template for a webservice that implements a query method."""
+
+    @classmethod
+    @abstractmethod
+    def list(cls, *args: Any, **kwargs: Any) -> list[Queryable]:
+        """List all resources of this type."""
+
+
+class TaskType(str, Enum):
+    """Solver task types recognized by the web API."""
+
+    FDTD = "FDTD"
+    MODE_SOLVER = "MODE_SOLVER"
+    HEAT = "HEAT"
+    HEAT_CHARGE = "HEAT_CHARGE"
+    EME = "EME"
+    MODE = "MODE"
+    VOLUME_MESH = "VOLUME_MESH"
+    MODAL_CM = "MODAL_CM"
+    TERMINAL_CM = "TERMINAL_CM"
+
+
+class PayType(str, Enum):
+    """Payment methods accepted when submitting a task."""
+
+    CREDITS = "FLEX_CREDIT"
+    AUTO = "AUTO"
+
+    @classmethod
+    def _missing_(cls, value: object) -> PayType:
+        if isinstance(value, str):
+            key = value.strip().replace(" ", "_").upper()
+            if key in cls.__members__:
+                return cls.__members__[key]
+        return super()._missing_(value)
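+
+# The _missing_ hook above lets loosely formatted strings resolve to members;
+# a quick sketch of the expected normalization (illustrative only):
+#
+#     assert PayType("auto") is PayType.AUTO
+#     assert PayType("credits") is PayType.CREDITS      # member-name match
+#     assert PayType("FLEX_CREDIT") is PayType.CREDITS  # plain value match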
diff --git a/tidy3d/_runtime.py b/tidy3d/_runtime.py
index 6dbf61accd..068dfd2d92 100644
--- a/tidy3d/_runtime.py
+++ b/tidy3d/_runtime.py
@@ -1,12 +1,10 @@
-"""Runtime environment detection for tidy3d.
-
-This module must have ZERO dependencies on other tidy3d modules to avoid
-circular imports. It is imported very early in the initialization chain.
-""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -import sys - -# Detect WASM/Pyodide environment where web and filesystem features are unavailable -WASM_BUILD = "pyodide" in sys.modules or sys.platform == "emscripten" +from tidy3d._common._runtime import ( + WASM_BUILD, +) diff --git a/tidy3d/compat.py b/tidy3d/compat.py index 2e81d76c26..dd9ff00a6b 100644 --- a/tidy3d/compat.py +++ b/tidy3d/compat.py @@ -1,21 +1,8 @@ -"""Compatibility layer for handling differences between package versions.""" +"""Compatibility shim for :mod:`tidy3d._common.compat`.""" -from __future__ import annotations - -import functools -import importlib - -from packaging.version import parse as parse_version - -try: - from xarray.structure import alignment -except ImportError: - from xarray.core import alignment - - -@functools.lru_cache(maxsize=8) -def _shapely_is_older_than(version: str) -> bool: - return parse_version(importlib.metadata.version("shapely")) < parse_version(version) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -__all__ = ["alignment"] +from tidy3d._common.compat import Self, TypeAlias, _package_is_older_than, alignment, np_trapezoid diff --git a/tidy3d/components/README.md b/tidy3d/components/README.md index 8c8fbc5545..aefa5976a4 100644 --- a/tidy3d/components/README.md +++ b/tidy3d/components/README.md @@ -134,7 +134,7 @@ Other checks may be added in future development. #### JSON Operations -The `Simulation` can be exported as .json-like dictionary with `Simulation.json()` +The `Simulation` can be exported as .json-like dictionary with `Simulation.model_dump_json()` The schema corresponding to `Simulation` can be generated with `Simulation.schema()` ## Medium diff --git a/tidy3d/components/apodization.py b/tidy3d/components/apodization.py index 3595640586..13488cc25c 100644 --- a/tidy3d/components/apodization.py +++ b/tidy3d/components/apodization.py @@ -2,16 +2,22 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Optional + import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, model_validator from tidy3d.constants import SECOND from tidy3d.exceptions import SetupError -from .base import Tidy3dBaseModel, skip_if_fields_missing -from .types import ArrayFloat1D, Ax +from .base import Tidy3dBaseModel from .viz import add_ax_if_none +if TYPE_CHECKING: + from tidy3d.compat import Self + + from .types import ArrayFloat1D, Ax + class ApodizationSpec(Tidy3dBaseModel): """Stores specifications for the apodizaton of frequency-domain monitors. 
@@ -27,45 +33,40 @@ class ApodizationSpec(Tidy3dBaseModel): """ - start: pd.NonNegativeFloat = pd.Field( + start: Optional[NonNegativeFloat] = Field( None, title="Start Interval", description="Defines the time at which the start apodization ends.", units=SECOND, ) - end: pd.NonNegativeFloat = pd.Field( + end: Optional[NonNegativeFloat] = Field( None, title="End Interval", description="Defines the time at which the end apodization begins.", units=SECOND, ) - width: pd.PositiveFloat = pd.Field( + width: Optional[PositiveFloat] = Field( None, title="Apodization Width", description="Characteristic decay length of the apodization function, i.e., the width of the ramping up of the scaling function from 0 to 1.", units=SECOND, ) - @pd.validator("end", always=True, allow_reuse=True) - @skip_if_fields_missing(["start"]) - def end_greater_than_start(cls, val, values): + @model_validator(mode="after") + def end_greater_than_start(self) -> Self: """Ensure end is greater than or equal to start.""" - start = values.get("start") - if val is not None and start is not None and val < start: + if self.end is not None and self.start is not None and self.end < self.start: raise SetupError("End apodization begins before start apodization ends.") - return val + return self - @pd.validator("width", always=True, allow_reuse=True) - @skip_if_fields_missing(["start", "end"]) - def width_provided(cls, val, values): + @model_validator(mode="after") + def width_provided(self) -> Self: """Check that width is provided if either start or end apodization is requested.""" - start = values.get("start") - end = values.get("end") - if (start is not None or end is not None) and val is None: + if (self.start is not None or self.end is not None) and self.width is None: raise SetupError("Apodization width must be set.") - return val + return self @add_ax_if_none def plot(self, times: ArrayFloat1D, ax: Ax = None) -> Ax: diff --git a/tidy3d/components/autograd/__init__.py b/tidy3d/components/autograd/__init__.py index a2e9eea893..d83a2da2c6 100644 --- a/tidy3d/components/autograd/__init__.py +++ b/tidy3d/components/autograd/__init__.py @@ -1,31 +1,28 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.autograd`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -from .boxes import TidyArrayBox -from .functions import interpn -from .types import ( +from tidy3d._common.components.autograd import ( AutogradFieldMap, - AutogradTraced, + InterpolationType, + PathType, + TidyArrayBox, + TracedArrayFloat2D, + TracedArrayLike, + TracedComplex, TracedCoordinate, TracedFloat, + TracedPoleAndResidue, + TracedPolesAndResidues, + TracedPositiveFloat, TracedSize, TracedSize1D, - TracedVertices, + get_static, + hasbox, + interpn, + is_tidy_box, + split_list, ) -from .utils import get_static, is_tidy_box, split_list - -__all__ = [ - "AutogradFieldMap", - "AutogradTraced", - "TidyArrayBox", - "TracedCoordinate", - "TracedFloat", - "TracedSize", - "TracedSize1D", - "TracedVertices", - "add_at", - "get_static", - "interpn", - "is_tidy_box", - "split_list", - "trapz", -] diff --git a/tidy3d/components/autograd/boxes.py b/tidy3d/components/autograd/boxes.py index 78aa52289a..d8765e865c 100644 --- a/tidy3d/components/autograd/boxes.py +++ b/tidy3d/components/autograd/boxes.py @@ -1,160 +1,13 @@ -# Adds some functionality to the autograd arraybox and related autograd patches -# NOTE: we do not subclass ArrayBox since that would break autograd's 
internal checks -from __future__ import annotations - -import importlib -from typing import Any, Callable - -import autograd.numpy as anp -from autograd.extend import VJPNode, defjvp, register_notrace -from autograd.numpy.numpy_boxes import ArrayBox -from autograd.numpy.numpy_wrapper import _astype - -TidyArrayBox = ArrayBox # NOT a subclass +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.boxes`.""" -_autograd_module_cache = {} # cache for imported autograd modules +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -register_notrace(VJPNode, anp.full_like) +# marked as migrated to _common +from __future__ import annotations -defjvp( - _astype, - lambda g, ans, A, dtype, order="K", casting="unsafe", subok=True, copy=True: _astype(g, dtype), +from tidy3d._common.components.autograd.boxes import ( + TidyArrayBox, + _autograd_module_cache, + from_arraybox, + item, ) - -anp.astype = _astype -anp.permute_dims = anp.transpose - - -@classmethod -def from_arraybox(cls, box: ArrayBox) -> TidyArrayBox: - """Construct a TidyArrayBox from an ArrayBox.""" - return cls(box._value, box._trace, box._node) - - -def __array_function__( - self: Any, - func: Callable, - types: list[Any], - args: tuple[Any, ...], - kwargs: dict[str, Any], -) -> Any: - """ - Handle the dispatch of NumPy functions to autograd's numpy implementation. - - Parameters - ---------- - self : Any - The instance of the class. - func : Callable - The NumPy function being called. - types : List[Any] - The types of the arguments that implement __array_function__. - args : Tuple[Any, ...] - The positional arguments to the function. - kwargs : Dict[str, Any] - The keyword arguments to the function. - - Returns - ------- - Any - The result of the function call, or NotImplemented. - - Raises - ------ - NotImplementedError - If the function is not implemented in autograd.numpy. - - See Also - -------- - https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_function__ - """ - if not all(t in TidyArrayBox.type_mappings for t in types): - return NotImplemented - - module_name = func.__module__ - - if module_name.startswith("numpy"): - anp_module_name = "autograd." + module_name - else: - return NotImplemented - - # Use the cached module if available - anp_module = _autograd_module_cache.get(anp_module_name) - if anp_module is None: - try: - anp_module = importlib.import_module(anp_module_name) - _autograd_module_cache[anp_module_name] = anp_module - except ImportError: - return NotImplemented - - f = getattr(anp_module, func.__name__, None) - if f is None: - return NotImplemented - - if f.__name__ == "nanmean": # somehow xarray always dispatches to nanmean - f = anp.mean - kwargs.pop("dtype", None) # autograd mean vjp doesn't support dtype - - return f(*args, **kwargs) - - -def __array_ufunc__( - self: Any, - ufunc: Callable, - method: str, - *inputs: Any, - **kwargs: dict[str, Any], -) -> Any: - """ - Handle the dispatch of NumPy ufuncs to autograd's numpy implementation. - - Parameters - ---------- - self : Any - The instance of the class. - ufunc : Callable - The universal function being called. - method : str - The method of the ufunc being called. - inputs : Any - The input arguments to the ufunc. - kwargs : Dict[str, Any] - The keyword arguments to the ufunc. - - Returns - ------- - Any - The result of the ufunc call, or NotImplemented. 
- - See Also - -------- - https://numpy.org/doc/stable/reference/arrays.classes.html#numpy.class.__array_ufunc__ - """ - if method != "__call__": - return NotImplemented - - ufunc_name = ufunc.__name__ - - anp_ufunc = getattr(anp, ufunc_name, None) - if anp_ufunc is not None: - return anp_ufunc(*inputs, **kwargs) - - return NotImplemented - - -def item(self): - if self.size != 1: - raise ValueError("Can only convert an array of size 1 to a scalar") - return anp.ravel(self)[0] - - -TidyArrayBox._tidy = True -TidyArrayBox.from_arraybox = from_arraybox -TidyArrayBox.__array_namespace__ = lambda self, *, api_version=None: anp -TidyArrayBox.__array_ufunc__ = __array_ufunc__ -TidyArrayBox.__array_function__ = __array_function__ -TidyArrayBox.__repr__ = str -TidyArrayBox.real = property(anp.real) -TidyArrayBox.imag = property(anp.imag) -TidyArrayBox.conj = anp.conj -TidyArrayBox.item = item diff --git a/tidy3d/components/autograd/derivative_utils.py b/tidy3d/components/autograd/derivative_utils.py index 7c36444687..156ef9b21a 100644 --- a/tidy3d/components/autograd/derivative_utils.py +++ b/tidy3d/components/autograd/derivative_utils.py @@ -1,840 +1,17 @@ -"""Utilities for autograd derivative computation and field gradient evaluation.""" +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.derivative_utils`.""" -from __future__ import annotations - -from dataclasses import dataclass, field, replace -from typing import Any, Callable, Optional, Union - -import numpy as np -import xarray as xr - -from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray -from tidy3d.components.types import ArrayLike, Bound, tidycomplex -from tidy3d.config import config -from tidy3d.constants import C_0, EPSILON_0, LARGE_NUMBER, MU_0 -from tidy3d.log import log - -from .types import PathType -from .utils import get_static - -FieldData = dict[str, ScalarFieldDataArray] -PermittivityData = dict[str, ScalarFieldDataArray] -EpsType = Union[tidycomplex, FreqDataArray] - - -class LazyInterpolator: - """Lazy wrapper for interpolators that creates them on first access.""" - - def __init__(self, creator_func: Callable) -> None: - """Initialize with a function that creates the interpolator when called.""" - self.creator_func = creator_func - self._interpolator = None - - def __call__(self, *args: Any, **kwargs: Any): - """Create interpolator on first call and delegate to it.""" - if self._interpolator is None: - self._interpolator = self.creator_func() - return self._interpolator(*args, **kwargs) - - -@dataclass -class DerivativeInfo: - """Stores derivative information passed to the ``._compute_derivatives`` methods. - - This dataclass contains all the field data and parameters needed for computing - gradients with respect to geometry perturbations. - """ - - # Required fields - paths: list[PathType] - """List of paths to the traced fields that need derivatives calculated.""" - - E_der_map: FieldData - """Electric field gradient map. - Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication - of the forward and adjoint electric fields. The tangential components of this - dataset are used when computing adjoint gradients for shifting boundaries. - All components are used when computing volume-based gradients.""" - - D_der_map: FieldData - """Displacement field gradient map. - Dataset where the field components ("Ex", "Ey", "Ez") store the multiplication - of the forward and adjoint displacement fields. 
The normal component of this - dataset is used when computing adjoint gradients for shifting boundaries.""" - - E_fwd: FieldData - """Forward electric fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the forward - electric fields used for computing gradients for a given structure.""" - - E_adj: FieldData - """Adjoint electric fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint - electric fields used for computing gradients for a given structure.""" - - D_fwd: FieldData - """Forward displacement fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the forward - displacement fields used for computing gradients for a given structure.""" - - D_adj: FieldData - """Adjoint displacement fields. - Dataset where the field components ("Ex", "Ey", "Ez") represent the adjoint - displacement fields used for computing gradients for a given structure.""" - - eps_data: PermittivityData - """Permittivity dataset. - Dataset of relative permittivity values along all three dimensions. - Used for automatically computing permittivity inside or outside of a simple geometry.""" - - eps_in: EpsType - """Permittivity inside the Structure. - Typically computed from Structure.medium.eps_model. - Used when it cannot be computed from eps_data or when eps_approx=True.""" - - eps_out: EpsType - """Permittivity outside the Structure. - Typically computed from Simulation.medium.eps_model. - Used when it cannot be computed from eps_data or when eps_approx=True.""" - - bounds: Bound - """Geometry bounds. - Bounds corresponding to the structure, used in Medium calculations.""" - - bounds_intersect: Bound - """Geometry and simulation intersection bounds. - Bounds corresponding to the minimum intersection between the structure - and the simulation it is contained in.""" - - simulation_bounds: Bound - """Simulation bounds. - Bounds corresponding to the simulation domain containing this structure. - Unlike bounds_intersect, this is independent of the structure's bounds and - is purely based on the simulation geometry.""" - - frequencies: ArrayLike - """Frequencies at which the adjoint gradient should be computed.""" - - H_der_map: Optional[FieldData] = None - """Magnetic field gradient map. - Dataset where the field components ("Hx", "Hy", "Hz") store the multiplication - of the forward and adjoint magnetic fields. The tangential component of this - dataset is used when computing adjoint gradients for shifting boundaries of - structures composed of PEC mediums.""" - - H_fwd: Optional[FieldData] = None - """Forward magnetic fields. - Dataset where the field components ("Hx", "Hy", "Hz") represent the forward - magnetic fields used for computing gradients for a given structure.""" - - H_adj: Optional[FieldData] = None - """Adjoint magnetic fields. - Dataset where the field components ("Hx", "Hy", "Hz") represent the adjoint - magnetic fields used for computing gradients for a given structure.""" - - # Optional fields with defaults - eps_background: Optional[EpsType] = None - """Permittivity in background. - Permittivity outside of the Structure as manually specified by - Structure.background_medium.""" - - eps_no_structure: Optional[ScalarFieldDataArray] = None - """Permittivity without structure. - The permittivity of the original simulation without the structure that is - being differentiated with respect to. 
Used to approximate permittivity - outside of the structure for shape optimization.""" - - eps_inf_structure: Optional[ScalarFieldDataArray] = None - """Permittivity with infinite structure. - The permittivity of the original simulation where the structure being - differentiated with respect to is infinitely large. Used to approximate - permittivity inside of the structure for shape optimization.""" - - eps_approx: bool = False - """Use permittivity approximation. - If True, approximates outside permittivity using Simulation.medium and - the inside permittivity using Structure.medium. Only set True for - GeometryGroup handling where it is difficult to automatically evaluate - the inside and outside relative permittivity for each geometry.""" - - is_medium_pec: bool = False - """Indicates if structure material is PEC. - If True, the structure contains a PEC material which changes the gradient - formulation at the boundary compared to the dielectric case.""" - - interpolators: Optional[dict] = None - """Pre-computed interpolators. - Optional pre-computed interpolators for field components and permittivity data. - When provided, avoids redundant interpolator creation for multiple geometries - sharing the same field data. This significantly improves performance for - GeometryGroup processing.""" - - # private cache for interpolators - _interpolators_cache: dict = field(default_factory=dict, init=False, repr=False) - - def updated_copy(self, **kwargs: Any): - """Create a copy with updated fields.""" - kwargs.pop("deep", None) - kwargs.pop("validate", None) - return replace(self, **kwargs) - - @staticmethod - def _nan_to_num_if_needed(coords: np.ndarray) -> np.ndarray: - """Convert NaN and infinite values to finite numbers, optimized for finite inputs.""" - # skip check for small arrays - if coords.size < 1000: - return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) - - if np.isfinite(coords).all(): - return coords - return np.nan_to_num(coords, posinf=LARGE_NUMBER, neginf=-LARGE_NUMBER) - - @staticmethod - def _evaluate_with_interpolators( - interpolators: dict, coords: np.ndarray - ) -> dict[str, np.ndarray]: - """Evaluate field components at coordinates using cached interpolators. - - Parameters - ---------- - interpolators : dict - Dictionary mapping field component names to ``RegularGridInterpolator`` objects. - coords : np.ndarray - Spatial coordinates (N, 3) where fields are evaluated. - - Returns - ------- - dict[str, np.ndarray] - Dictionary mapping component names to field values at coordinates. - """ - auto_cfg = config.adjoint - float_dtype = auto_cfg.gradient_dtype_float - complex_dtype = auto_cfg.gradient_dtype_complex - - coords = DerivativeInfo._nan_to_num_if_needed(coords) - if coords.dtype != float_dtype and coords.dtype != complex_dtype: - coords = coords.astype(float_dtype, copy=False) - return {name: interp(coords) for name, interp in interpolators.items()} - - def create_interpolators(self, dtype: Optional[np.dtype] = None) -> dict: - """Create interpolators for field components and permittivity data. - - Creates and caches ``RegularGridInterpolator`` objects for all field components - (E_fwd, E_adj, D_fwd, D_adj) and permittivity data (eps_inf, eps_no). - This caching strategy significantly improves performance by avoiding - repeated interpolator construction in gradient evaluation loops. - - Parameters - ---------- - dtype : np.dtype, optional - Data type for interpolation coordinates and values. 
Defaults to the - current ``config.adjoint.gradient_dtype_float``. - - Returns - ------- - dict - Nested dictionary structure: - - Field data: {"E_fwd": {"Ex": interpolator, ...}, ...} - - Permittivity: {"eps_inf": interpolator, "eps_no": interpolator} - """ - from scipy.interpolate import RegularGridInterpolator - - auto_cfg = config.adjoint - if dtype is None: - dtype = auto_cfg.gradient_dtype_float - complex_dtype = auto_cfg.gradient_dtype_complex - - cache_key = str(dtype) - if cache_key in self._interpolators_cache: - return self._interpolators_cache[cache_key] - - interpolators = {} - coord_cache = {} - - def _make_lazy_interpolator_group(field_data_dict, group_key, is_field_group=True) -> None: - """Helper to create a group of lazy interpolators.""" - if is_field_group: - interpolators[group_key] = {} - - for component_name, arr in field_data_dict.items(): - # use object ID for caching to handle shared grids - arr_id = id(arr.data) - if arr_id not in coord_cache: - points = tuple(c.data.astype(dtype, copy=False) for c in (arr.x, arr.y, arr.z)) - coord_cache[arr_id] = points - points = coord_cache[arr_id] - - def creator_func(arr=arr, points=points): - data = arr.data.astype( - complex_dtype if np.iscomplexobj(arr.data) else dtype, copy=False - ) - # create interpolator with frequency dimension - if "f" in arr.dims: - freq_coords = arr.coords["f"].data.astype(dtype, copy=False) - # ensure frequency dimension is last - if arr.dims != ("x", "y", "z", "f"): - freq_dim_idx = arr.dims.index("f") - axes = list(range(data.ndim)) - axes.append(axes.pop(freq_dim_idx)) - data = np.transpose(data, axes) - else: - # single frequency case - add singleton dimension - freq_coords = np.array([0.0], dtype=dtype) - data = data[..., np.newaxis] - - points_with_freq = (*points, freq_coords) - # If PEC, use nearest interpolation instead of linear to avoid interpolating - # with field values inside the PEC (which are 0). Instead, we make sure to - # choose interplation points such that their nearest location is outside of - # the PEC surface. 
- method = "nearest" if self.is_medium_pec else "linear" - interpolator_obj = RegularGridInterpolator( - points_with_freq, data, method=method, bounds_error=False, fill_value=None - ) - - def interpolator(coords): - # coords: (N, 3) spatial points - n_points = coords.shape[0] - n_freqs = len(freq_coords) - - # build coordinates with frequency dimension - coords_with_freq = np.empty((n_points * n_freqs, 4), dtype=coords.dtype) - coords_with_freq[:, :3] = np.repeat(coords, n_freqs, axis=0) - coords_with_freq[:, 3] = np.tile(freq_coords, n_points) - - result = interpolator_obj(coords_with_freq) - return result.reshape(n_points, n_freqs) - - return interpolator - - if is_field_group: - interpolators[group_key][component_name] = LazyInterpolator(creator_func) - else: - interpolators[component_name] = LazyInterpolator(creator_func) - - # process field interpolators (nested dictionaries) - interpolator_groups = [ - ("E_fwd", self.E_fwd), - ("E_adj", self.E_adj), - ("D_fwd", self.D_fwd), - ("D_adj", self.D_adj), - ] - if self.is_medium_pec: - interpolator_groups += [("H_fwd", self.H_fwd), ("H_adj", self.H_adj)] - for group_key, data_dict in interpolator_groups: - _make_lazy_interpolator_group(data_dict, group_key, is_field_group=True) - - if self.eps_inf_structure is not None: - _make_lazy_interpolator_group( - {"eps_inf": self.eps_inf_structure}, None, is_field_group=False - ) - if self.eps_no_structure is not None: - _make_lazy_interpolator_group( - {"eps_no": self.eps_no_structure}, None, is_field_group=False - ) - - self._interpolators_cache[cache_key] = interpolators - return interpolators - - def evaluate_gradient_at_points( - self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: Optional[dict] = None, - ) -> np.ndarray: - """Compute adjoint gradients at surface points for shape optimization. - - Implements the surface integral formulation for computing gradients with respect - to geometry perturbations. - - Parameters - ---------- - spatial_coords : np.ndarray - (N, 3) array of surface evaluation points. - normals : np.ndarray - (N, 3) array of outward-pointing normal vectors at each surface point. - perps1 : np.ndarray - (N, 3) array of first tangent vectors perpendicular to normals. - perps2 : np.ndarray - (N, 3) array of second tangent vectors perpendicular to both normals and perps1. - interpolators : dict = None - Pre-computed field interpolators for efficiency. - - Returns - ------- - np.ndarray - (N,) array of gradient values at each surface point. Must be integrated - with appropriate quadrature weights to get total gradient. - """ - if interpolators is None: - raise NotImplementedError( - "Direct field evaluation without interpolators is not implemented. " - "Please create interpolators using 'create_interpolators()' first." 
- ) - - if "eps_no" in interpolators: - eps_out = interpolators["eps_no"](spatial_coords) - else: - # use eps_background if available, otherwise use eps_out - eps_to_prepare = ( - self.eps_background if self.eps_background is not None else self.eps_out - ) - eps_out = self._prepare_epsilon(eps_to_prepare) - - if self.is_medium_pec: - vjps = self._evaluate_pec_gradient_at_points( - spatial_coords, normals, perps1, perps2, interpolators, eps_out - ) - else: - vjps = self._evaluate_dielectric_gradient_at_points( - spatial_coords, normals, perps1, perps2, interpolators, eps_out - ) - - # sum over frequency dimension - vjps = np.sum(vjps, axis=-1) - - return vjps - - def _evaluate_dielectric_gradient_at_points( - self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: dict, - # todo: type - eps_out, - ) -> np.ndarray: - # evaluate all field components at surface points - E_fwd_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["E_fwd"].items() - } - E_adj_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["E_adj"].items() - } - D_fwd_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["D_fwd"].items() - } - D_adj_at_coords = { - name: interp(spatial_coords) for name, interp in interpolators["D_adj"].items() - } - - if "eps_inf" in interpolators: - eps_in = interpolators["eps_inf"](spatial_coords) - else: - eps_in = self._prepare_epsilon(self.eps_in) - - delta_eps_inv = 1.0 / eps_in - 1.0 / eps_out - delta_eps = eps_in - eps_out - - # project fields onto local surface basis (normal + two tangents) - D_fwd_norm = self._project_in_basis(D_fwd_at_coords, basis_vector=normals) - D_adj_norm = self._project_in_basis(D_adj_at_coords, basis_vector=normals) - - E_fwd_perp1 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps1) - E_adj_perp1 = self._project_in_basis(E_adj_at_coords, basis_vector=perps1) - - E_fwd_perp2 = self._project_in_basis(E_fwd_at_coords, basis_vector=perps2) - E_adj_perp2 = self._project_in_basis(E_adj_at_coords, basis_vector=perps2) - - D_der_norm = D_fwd_norm * D_adj_norm - E_der_perp1 = E_fwd_perp1 * E_adj_perp1 - E_der_perp2 = E_fwd_perp2 * E_adj_perp2 - - vjps = -delta_eps_inv * D_der_norm + E_der_perp1 * delta_eps + E_der_perp2 * delta_eps - - return vjps - - def _evaluate_pec_gradient_at_points( - self, - spatial_coords: np.ndarray, - normals: np.ndarray, - perps1: np.ndarray, - perps2: np.ndarray, - interpolators: dict, - # todo: type - eps_out, - ) -> np.ndarray: - def _adjust_spatial_coords_pec(grid_centers: dict[str, np.ndarray]): - """Assuming a nearest interpolation, adjust the interpolation points given the grid - defined by `grid_centers` and using `spatial_coords` as a starting point such that we - select a point outside of the PEC boundary. - - *** (nearest point outside boundary) - ^ - | n (normal direction) - | - _.-~'`-._.-~'`-._ (PEC surface) - * (nearest point) - - Parameters - ---------- - grid_centers: dict[str, np.ndarray] - The grid points for a given field component indexed by dimension. These grid points - are used to find the nearest snapping point and adjust the inerpolation coordinates - to ensure we fall outside of the PEC surface. 
- - Returns - ------- - (np.ndarray, np.ndarray) - (N, 3) array of coordinate centers at which to interpolate such that they line up - with a grid center and are outside the PEC surface - (N,) array of distances from the nearest interpolation points to the desired surface - edge points specified by `spatial_coords` - - """ - grid_ddim = np.zeros_like(normals) - for idx, dim in enumerate("xyz"): - expanded_coords = np.expand_dims(spatial_coords[:, idx], axis=1) - grid_centers_select = grid_centers[dim] - - diff = np.abs(expanded_coords - grid_centers_select) - - nearest_grid = np.argmin(diff, axis=-1) - nearest_grid = np.minimum(np.maximum(nearest_grid, 1), len(grid_centers_select) - 1) - - # compute the local grid spacing near the boundary - grid_ddim[:, idx] = ( - grid_centers_select[nearest_grid] - grid_centers_select[nearest_grid - 1] - ) - - # assuming we move in the normal direction, finds which dimension we need to move the least - # in order to ensure we snap to a point outside the boundary in the worst case (i.e. - the - # nearest point is just inside the surface) - min_movement_index = np.argmin( - np.abs(grid_ddim) / (np.abs(normals) + np.finfo(normals.dtype).min), axis=1 - ) - - selection = (np.arange(normals.shape[0]), min_movement_index) - coords_dn = np.expand_dims(np.abs(grid_ddim[selection]), axis=1) - - # adjust coordinates by half a grid point outside boundary such that nearest interpolation - # point snaps to outside the boundary - adjust_spatial_coords = spatial_coords + normals * 0.5 * coords_dn - - edge_distance = np.zeros_like(adjust_spatial_coords[:, 0]) - for idx, dim in enumerate("xyz"): - expanded_adjusted_coords = np.expand_dims(adjust_spatial_coords[:, idx], axis=1) - grid_centers_select = grid_centers[dim] - - # find nearest grid point from the adjusted coordinates - diff = np.abs(expanded_adjusted_coords - grid_centers_select) - nearest_grid = np.argmin(diff, axis=-1) - - # compute edge distance from the nearest interpolated point to the boundary edge - edge_distance += ( - np.abs(spatial_coords[:, idx] - grid_centers_select[nearest_grid]) ** 2 - ) - - # this edge distance is useful when correcting for edge singularities from the PEC material - # and is used when the PEC PolySlab structure has zero thickness - edge_distance = np.sqrt(edge_distance) - - return adjust_spatial_coords, edge_distance - - def _snap_coordinate_outside(field_components: FieldData): - """Helper function to perform coordinate adjustment and compute edge distance for each - component in `field_components`. - - Parameters - ---------- - field_components: FieldData - The field components (i.e - Ex, Ey, Ez, Hx, Hy, Hz) that we would like to sample just - outside the PEC surface using nearest interpolation. - - Returns - ------- - dict[str, dict[str, np.ndarray]] - Dictionary mapping each field component name to a dictionary of adjusted coordinates - and edge distances for that component. 
-            """
-            adjustment = {}
-            for name in field_components:
-                field_component = field_components[name]
-                field_component_coords = field_component.coords
-
-                adjusted_coords, edge_distance = _adjust_spatial_coords_pec(
-                    {
-                        key: np.array(field_component_coords[key].values)
-                        for key in field_component_coords
-                    }
-                )
-                adjustment[name] = {"coords": adjusted_coords, "edge_distance": edge_distance}
-
-            return adjustment
-
-        def _interpolate_field_components(interp_coords, field_name):
-            return {
-                name: interp(interp_coords[name]["coords"])
-                for name, interp in interpolators[field_name].items()
-            }
-
-        # adjust coordinates for PEC to be outside structure bounds and get edge distance for
-        # singularity correction
-        E_fwd_coords_adjusted = _snap_coordinate_outside(self.E_fwd)
-        E_adj_coords_adjusted = _snap_coordinate_outside(self.E_adj)
-
-        H_fwd_coords_adjusted = _snap_coordinate_outside(self.H_fwd)
-        H_adj_coords_adjusted = _snap_coordinate_outside(self.H_adj)
-
-        # using the adjusted coordinates, evaluate all field components at surface points
-        E_fwd_at_coords = _interpolate_field_components(E_fwd_coords_adjusted, "E_fwd")
-        E_adj_at_coords = _interpolate_field_components(E_adj_coords_adjusted, "E_adj")
-        H_fwd_at_coords = _interpolate_field_components(H_fwd_coords_adjusted, "H_fwd")
-        H_adj_at_coords = _interpolate_field_components(H_adj_coords_adjusted, "H_adj")
-
-        structure_sizes = np.array(
-            [self.bounds[1][idx] - self.bounds[0][idx] for idx in range(len(self.bounds[0]))]
-        )
-
-        is_flat_perp_dim1 = np.isclose(np.abs(np.sum(perps1[0] * structure_sizes)), 0.0)
-        is_flat_perp_dim2 = np.isclose(np.abs(np.sum(perps2[0] * structure_sizes)), 0.0)
-        flat_perp_dims = [is_flat_perp_dim1, is_flat_perp_dim2]
-
-        # check if this integration is happening along an edge, in which case we will eliminate
-        # one of the H field integration components and apply singularity correction
-        pec_line_integration = is_flat_perp_dim1 or is_flat_perp_dim2
-
-        def _compute_singularity_correction(adjustment_: dict[str, dict[str, np.ndarray]]):
-            """
-            Given `adjustment_`, which contains each field component's nearest-interpolation
-            distance from the PEC edge, compute the singularity correction for 2D PEC using the
-            average edge_distance for each component. In the case of 3D PEC gradients, no
-            singularity correction is applied, so an array of ones is returned.
-
-            Parameters
-            ----------
-            adjustment_: dict[str, dict[str, np.ndarray]]
-                Dictionary that maps field component name to a dictionary containing the coordinate
-                adjustment and the distance to the PEC edge for those coordinates. The edge distance
-                is used for 2D PEC singularity correction.
- - Returns - ------- - np.ndarray - Returns the singularity correction which has shape (N,) where there are N points in - `spatial_coords` - """ - return ( - ( - 0.5 - * np.pi - * np.mean([adjustment_[name]["edge_distance"] for name in adjustment_], axis=0) - ) - if pec_line_integration - else np.ones_like(spatial_coords, shape=spatial_coords.shape[0]) - ) - - E_norm_singularity_correction = np.expand_dims( - _compute_singularity_correction(E_fwd_coords_adjusted), axis=1 - ) - H_perp_singularity_correction = np.expand_dims( - _compute_singularity_correction(H_fwd_coords_adjusted), axis=1 - ) - - E_fwd_norm = self._project_in_basis(E_fwd_at_coords, basis_vector=normals) - E_adj_norm = self._project_in_basis(E_adj_at_coords, basis_vector=normals) - - # compute the normal E contribution to the gradient (the tangential E contribution - # is 0 in the case of PEC since this field component is continuous and thus 0 at - # the boundary) - contrib_E = E_norm_singularity_correction * eps_out * E_fwd_norm * E_adj_norm - vjps = contrib_E - - # compute the tangential H contribution to the gradient (the normal H contribution - # is 0 for PEC) - H_fwd_perp1 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps1) - H_adj_perp1 = self._project_in_basis(H_adj_at_coords, basis_vector=perps1) - - H_fwd_perp2 = self._project_in_basis(H_fwd_at_coords, basis_vector=perps2) - H_adj_perp2 = self._project_in_basis(H_adj_at_coords, basis_vector=perps2) - - H_der_perp1 = H_perp_singularity_correction * H_fwd_perp1 * H_adj_perp1 - H_der_perp2 = H_perp_singularity_correction * H_fwd_perp2 * H_adj_perp2 - - H_integration_components = (H_der_perp1, H_der_perp2) - if pec_line_integration: - # if we are integrating along the line, we choose the H component normal to - # the edge which corresponds to a surface current along the edge whereas the other - # tangential component corresponds to a surface current along the flat dimension. - H_integration_components = tuple( - H_comp for idx, H_comp in enumerate(H_integration_components) if flat_perp_dims[idx] - ) - - # for each of the tangential components we are integrating the H fields over, - # adjust weighting to account for pre-weighting of the source by `EPSILON_0` - # and multiply by appropriate `MU_0` factor - for H_perp in H_integration_components: - contrib_H = MU_0 * H_perp / EPSILON_0 - vjps += contrib_H - - return vjps - - @staticmethod - def _prepare_epsilon(eps: EpsType) -> np.ndarray: - """Prepare epsilon values for multi-frequency. - - For FreqDataArray, extracts values and broadcasts to shape (1, n_freqs). - For scalar values, broadcasts to shape (1, 1) for consistency with multi-frequency. - """ - if isinstance(eps, FreqDataArray): - # data is already sliced, just extract values - eps_values = eps.values - # shape: (n_freqs,) - need to broadcast to (1, n_freqs) - return eps_values[np.newaxis, :] - else: - # scalar value - broadcast to (1, 1) - return np.array([[eps]]) - - @staticmethod - def _project_in_basis( - field_components: dict[str, np.ndarray], - basis_vector: np.ndarray, - ) -> np.ndarray: - """Project 3D field components onto a basis vector. - - Parameters - ---------- - field_components : dict[str, np.ndarray] - Dictionary with keys like "Ex", "Ey", "Ez" or "Dx", "Dy", "Dz" containing field values. - Values have shape (N, F) where F is the number of frequencies. - basis_vector : np.ndarray - (N, 3) array of basis vectors, one per evaluation point. - - Returns - ------- - np.ndarray - Projected field values with shape (N, F). 
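For intuition, a tiny numeric sketch of what `_compute_singularity_correction` returns when integrating along a zero-thickness PEC edge (the distances are illustrative):

```python
import numpy as np

# Illustrative only: with nearest-interpolation samples taken at these
# distances from the PEC edge, the 2D singularity correction is
# (pi / 2) * mean(edge_distance) per point; in 3D it is all ones instead.
edge_distance_ex = np.array([0.02, 0.03])  # per-point distances for one component
edge_distance_ey = np.array([0.04, 0.01])  # per-point distances for another

correction = 0.5 * np.pi * np.mean([edge_distance_ex, edge_distance_ey], axis=0)
# -> array([0.04712389, 0.03141593])
```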
- """ - prefix = next(iter(field_components.keys()))[0] - field_matrix = np.stack([field_components[f"{prefix}{dim}"] for dim in "xyz"], axis=0) - - # always expect (3, N, F) shape, transpose to (N, 3, F) - field_matrix = np.transpose(field_matrix, (1, 0, 2)) - return np.einsum("ij...,ij->i...", field_matrix, basis_vector) - - def adaptive_vjp_spacing( - self, - wl_fraction: Optional[float] = None, - min_allowed_spacing_fraction: Optional[float] = None, - ) -> float: - """Compute adaptive spacing for finite-difference gradient evaluation. - - Determines an appropriate spatial resolution based on the material - properties and electromagnetic wavelength/skin depth. - - Parameters - ---------- - wl_fraction : float, optional - Fraction of wavelength/skin depth to use as spacing. Defaults to the configured - ``autograd.default_wavelength_fraction`` when ``None``. - min_allowed_spacing_fraction : float, optional - Minimum allowed spacing fraction of free space wavelength used to - prevent numerical issues. Defaults to ``config.adjoint.minimum_spacing_fraction`` - when not specified. - - Returns - ------- - float - Adaptive spacing value for gradient evaluation. - """ - if wl_fraction is None or min_allowed_spacing_fraction is None: - from tidy3d.config import config - - if wl_fraction is None: - wl_fraction = config.adjoint.default_wavelength_fraction - if min_allowed_spacing_fraction is None: - min_allowed_spacing_fraction = config.adjoint.minimum_spacing_fraction - - # handle FreqDataArray or scalar eps_in - if isinstance(self.eps_in, FreqDataArray): - eps_real = np.asarray(self.eps_in.values, dtype=np.complex128).real - else: - eps_real = np.asarray(self.eps_in, dtype=np.complex128).real - - dx_candidates = [] - max_frequency = np.max(self.frequencies) - - # wavelength-based sampling for dielectrics - if np.any(eps_real > 0): - eps_max = eps_real[eps_real > 0].max() - lambda_min = self.wavelength_min / np.sqrt(eps_max) - dx_candidates.append(wl_fraction * lambda_min) - - # skin depth sampling for metals - if np.any(eps_real <= 0): - omega = 2 * np.pi * max_frequency - eps_neg = eps_real[eps_real <= 0] - delta_min = C_0 / (omega * np.sqrt(np.abs(eps_neg).max())) - dx_candidates.append(wl_fraction * delta_min) - - computed_spacing = min(dx_candidates) - min_allowed_spacing = self.wavelength_min * min_allowed_spacing_fraction - - if computed_spacing < min_allowed_spacing: - log.warning( - f"Based on the material, the adaptive spacing for integrating the polyslab surface " - f"would be {computed_spacing:.3e} μm. The spacing has been clipped to {min_allowed_spacing:.3e} μm " - f"to prevent a performance degradation.", - log_once=True, - ) - return max(computed_spacing, min_allowed_spacing) - - @property - def wavelength_min(self) -> float: - return C_0 / np.max(self.frequencies) - - @property - def wavelength_max(self) -> float: - return C_0 / np.min(self.frequencies) - - -def integrate_within_bounds(arr: xr.DataArray, dims: list[str], bounds: Bound) -> xr.DataArray: - """Integrate a data array within specified spatial bounds. - - Clips the integration domain to the specified bounds and performs - numerical integration using the trapezoidal rule. - - Parameters - ---------- - arr : xr.DataArray - Data array to integrate. - dims : list[str] - Dimensions to integrate over (e.g., ['x', 'y', 'z']). - bounds : Bound - Integration bounds as [[xmin, ymin, zmin], [xmax, ymax, zmax]]. - - Returns - ------- - xr.DataArray - Result of integration with specified dimensions removed. 
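The spacing heuristic above reduces to picking the smallest of a wavelength-based candidate (for dielectrics) and a skin-depth-based candidate (for metals). A self-contained sketch; the default `wl_fraction` and the example numbers are illustrative, and the real defaults come from `config.adjoint`:

```python
import numpy as np

C_0 = 2.99792458e14  # speed of light in um/s (tidy3d micron units)

def adaptive_spacing(eps_real: np.ndarray, freqs: np.ndarray, wl_fraction: float = 0.1) -> float:
    """Smallest of wavelength-based (eps > 0) and skin-depth-based (eps <= 0) spacings."""
    wavelength_min = C_0 / np.max(freqs)
    candidates = []
    if np.any(eps_real > 0):
        # wavelength inside the densest dielectric
        candidates.append(wl_fraction * wavelength_min / np.sqrt(eps_real[eps_real > 0].max()))
    if np.any(eps_real <= 0):
        # skin depth delta = C_0 / (omega * sqrt(|eps|)) for the most metallic eps
        omega = 2 * np.pi * np.max(freqs)
        candidates.append(wl_fraction * C_0 / (omega * np.sqrt(np.abs(eps_real[eps_real <= 0]).max())))
    return min(candidates)

# e.g. a dielectric with eps = 12 at 200 THz -> spacing ~ 0.1 * 1.5 / sqrt(12) um
dx = adaptive_spacing(np.array([12.0]), np.array([2e14]))
```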
- - Notes - ----- - - Coordinates outside bounds are clipped, effectively setting dL=0 - - Only integrates dimensions with more than one coordinate point - - Uses xarray's integrate method (trapezoidal rule) - """ - bounds = np.asarray(bounds).T - all_coords = {} - - for dim, (bmin, bmax) in zip(dims, bounds): - bmin = get_static(bmin) - bmax = get_static(bmax) - - # clip coordinates to bounds (sets dL=0 outside bounds) - coord_values = arr.coords[dim].data - all_coords[dim] = np.clip(coord_values, bmin, bmax) - - _arr = arr.assign_coords(**all_coords) - - # only integrate dimensions with multiple points - dims_integrate = [dim for dim in dims if len(_arr.coords[dim]) > 1] - return _arr.integrate(coord=dims_integrate) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -__all__ = [ - "DerivativeInfo", - "integrate_within_bounds", -] +from tidy3d._common.components.autograd.derivative_utils import ( + ArrayComplex, + ArrayFloat, + DerivativeInfo, + EpsType, + FieldData, + LazyInterpolator, + PermittivityData, + integrate_within_bounds, +) diff --git a/tidy3d/components/autograd/field_map.py b/tidy3d/components/autograd/field_map.py index 101e0b56bd..6690a71649 100644 --- a/tidy3d/components/autograd/field_map.py +++ b/tidy3d/components/autograd/field_map.py @@ -1,76 +1,13 @@ -"""Typed containers for autograd traced field metadata.""" +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.field_map`.""" -from __future__ import annotations - -import json -from typing import Any, Callable - -import pydantic.v1 as pydantic - -from tidy3d.components.autograd.types import AutogradFieldMap, dict_ag -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import ArrayLike, tidycomplex - - -class Tracer(Tidy3dBaseModel): - """Representation of a single traced element within a model.""" - - path: tuple[Any, ...] = pydantic.Field( - ..., - title="Path to the traced object in the model dictionary.", - ) - data: float | tidycomplex | ArrayLike = pydantic.Field(..., title="Tracing data") - - -class FieldMap(Tidy3dBaseModel): - """Collection of traced elements.""" - - tracers: tuple[Tracer, ...] = pydantic.Field( - ..., - title="Collection of Tracers.", - ) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - @property - def to_autograd_field_map(self) -> AutogradFieldMap: - """Convert to ``AutogradFieldMap`` autograd dictionary.""" - return dict_ag({tracer.path: tracer.data for tracer in self.tracers}) - - @classmethod - def from_autograd_field_map(cls, autograd_field_map: AutogradFieldMap) -> FieldMap: - """Initialize from an ``AutogradFieldMap`` autograd dictionary.""" - tracers = [] - for path, data in autograd_field_map.items(): - tracers.append(Tracer(path=path, data=data)) - return cls(tracers=tuple(tracers)) - - -def _encoded_path(path: tuple[Any, ...]) -> str: - """Return a stable JSON representation for a traced path.""" - return json.dumps(list(path), separators=(",", ":"), ensure_ascii=True) - - -class TracerKeys(Tidy3dBaseModel): - """Collection of traced field paths.""" - - keys: tuple[tuple[Any, ...], ...] 
= pydantic.Field( - ..., - title="Collection of tracer keys.", - ) - - def encoded_keys(self) -> list[str]: - """Return the JSON-encoded representation of keys.""" - return [_encoded_path(path) for path in self.keys] - - @classmethod - def from_field_mapping( - cls, - field_mapping: AutogradFieldMap, - *, - sort_key: Callable[[tuple[Any, ...]], str] | None = None, - ) -> TracerKeys: - """Construct keys from an autograd field mapping.""" - if sort_key is None: - sort_key = _encoded_path +# marked as migrated to _common +from __future__ import annotations - sorted_paths = tuple(sorted(field_mapping.keys(), key=sort_key)) - return cls(keys=sorted_paths) +from tidy3d._common.components.autograd.field_map import ( + FieldMap, + Tracer, + TracerKeys, + _encoded_path, +) diff --git a/tidy3d/components/autograd/functions.py b/tidy3d/components/autograd/functions.py index 6f86b05f4b..3e23c18503 100644 --- a/tidy3d/components/autograd/functions.py +++ b/tidy3d/components/autograd/functions.py @@ -1,266 +1,16 @@ -from __future__ import annotations - -import itertools -from typing import Any - -import autograd.numpy as anp -import numpy as np -from autograd.extend import defjvp, defvjp, primitive -from autograd.numpy.numpy_jvps import broadcast -from autograd.numpy.numpy_vjps import unbroadcast_f -from numpy.typing import NDArray - -from .types import InterpolationType - - -def _evaluate_nearest( - indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] -) -> NDArray[np.float64]: - """Perform nearest neighbor interpolation in an n-dimensional space. - - This function determines the nearest neighbor in a grid for a given point - and returns the corresponding value from the input array. - - Parameters - ---------- - indices : np.ndarray[np.int64] - Indices of the lower bounds of the grid cell containing the interpolation point. - norm_distances : np.ndarray[np.float64] - Normalized distances from the lower bounds of the grid cell to the - interpolation point, for each dimension. - values : np.ndarray[np.float64] - The n-dimensional array of values to interpolate from. - - Returns - ------- - np.ndarray[np.float64] - The value of the nearest neighbor to the interpolation point. - """ - idx_res = tuple(anp.where(yi <= 0.5, i, i + 1) for i, yi in zip(indices, norm_distances)) - return values[idx_res] - - -def _evaluate_linear( - indices: NDArray[np.int64], norm_distances: NDArray[np.float64], values: NDArray[np.float64] -) -> NDArray[np.float64]: - """Perform linear interpolation in an n-dimensional space. - - This function calculates the linearly interpolated value at a point in an - n-dimensional grid, given the indices of the surrounding grid points and - the normalized distances to these points. - The multi-linear interpolation is implemented by computing a weighted - average of the values at the vertices of the hypercube surrounding the - interpolation point. - - Parameters - ---------- - indices : np.ndarray[np.int64] - Indices of the lower bounds of the grid cell containing the interpolation point. - norm_distances : np.ndarray[np.float64] - Normalized distances from the lower bounds of the grid cell to the - interpolation point, for each dimension. - values : np.ndarray[np.float64] - The n-dimensional array of values to interpolate from. - - Returns - ------- - np.ndarray[np.float64] - The interpolated value at the desired point. 
- """ - # Create a slice object for broadcasting over trailing dimensions - _slice = (slice(None),) + (None,) * (values.ndim - len(indices)) - - # Prepare iterables for lower and upper bounds of the hypercube - ix = zip(indices, (1 - yi for yi in norm_distances)) - iy = zip((i + 1 for i in indices), norm_distances) - - # Initialize the result - value = anp.zeros(1) - - # Iterate over all vertices of the hypercube - for h in itertools.product(*zip(ix, iy)): - edge_indices, weights = zip(*h) - - # Compute the weight for this vertex - weight = anp.ones(1) - for w in weights: - weight = weight * w - - # Compute the contribution of this vertex and add it to the result - term = values[edge_indices] * weight[_slice] - value = value + term - - return value - - -def interpn( - points: tuple[NDArray[np.float64], ...], - values: NDArray[np.float64], - xi: tuple[NDArray[np.float64], ...], - *, - method: InterpolationType = "linear", - **kwargs: Any, -) -> NDArray[np.float64]: - """Interpolate over a rectilinear grid in arbitrary dimensions. - - This function mirrors the interface of `scipy.interpolate.interpn` but is differentiable with autograd. - - Parameters - ---------- - points : tuple[np.ndarray[np.float64], ...] - The points defining the rectilinear grid in n dimensions. - values : np.ndarray[np.float64] - The data values on the rectilinear grid. - xi : tuple[np.ndarray[np.float64], ...] - The coordinates to sample the gridded data at. - method : InterpolationType = "linear" - The method of interpolation to perform. Supported are "linear" and "nearest". - - Returns - ------- - np.ndarray[np.float64] - The interpolated values. - - Raises - ------ - ValueError - If the interpolation method is not supported. +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.functions`.""" - See Also - -------- - `scipy.interpolate.interpn `_ - """ - from scipy.interpolate import RegularGridInterpolator +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - if method == "nearest": - interp_fn = _evaluate_nearest - elif method == "linear": - interp_fn = _evaluate_linear - else: - raise ValueError(f"Unsupported interpolation method: {method}") - - if kwargs.get("fill_value") == "extrapolate": - itrp = RegularGridInterpolator( - points, values, method=method, fill_value=None, bounds_error=False - ) - else: - itrp = RegularGridInterpolator(points, values, method=method) - - # Prepare the grid for interpolation - # This step reshapes the grid, checks for NaNs and out-of-bounds values - # It returns: - # - reshaped grid - # - original shape - # - number of dimensions - # - boolean array indicating NaN positions - # - (discarded) boolean array for out-of-bounds values - xi, shape, ndim, nans, _ = itrp._prepare_xi(xi) - - # Find the indices of the grid cells containing the interpolation points - # and calculate the normalized distances (ranging from 0 at lower grid point to 1 - # at upper grid point) within these cells - indices, norm_distances = itrp._find_indices(xi.T) - - result = interp_fn(indices, norm_distances, values) - nans = anp.reshape(nans, (-1,) + (1,) * (result.ndim - 1)) - result = anp.where(nans, np.nan, result) - return anp.reshape(result, shape[:-1] + values.shape[ndim:]) - - -def trapz(y: NDArray, x: NDArray = None, dx: float = 1.0, axis: int = -1) -> float: - """ - Integrate along the given axis using the composite trapezoidal rule. - - Parameters - ---------- - y : np.ndarray - Input array to integrate. 
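Because this mirrors `scipy.interpolate.interpn` while staying autograd-differentiable, gradients with respect to the gridded values should come out of `autograd.grad` directly. A sketch, assuming the wrapper follows scipy's `xi` conventions (query points of shape `(n_points, n_dims)`):

```python
import autograd.numpy as anp
import numpy as np
from autograd import grad

# Differentiate a loss through linear interpolation with respect to the
# gridded values; `interpn` here is the wrapper defined above.
xs = np.linspace(0.0, 1.0, 5)
xi = np.array([[0.25], [0.60]])  # query points, shape (n_points, n_dims)

def loss(values):
    out = interpn((xs,), values, xi, method="linear")
    return anp.sum(out**2)

dloss_dvalues = grad(loss)(anp.sin(xs))  # shape (5,); nonzero only at stencil nodes
```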
- x : np.ndarray = None - The sample points corresponding to the y values. If None, the sample points are assumed to be evenly spaced - with spacing `dx`. - dx : float = 1.0 - The spacing between sample points when `x` is None. Default is 1.0. - axis : int = -1 - The axis along which to integrate. Default is the last axis. - - Returns - ------- - float - Definite integral as approximated by the trapezoidal rule. - """ - if x is None: - d = dx - elif x.ndim == 1: - d = np.diff(x) - shape = [1] * y.ndim - shape[axis] = d.shape[0] - d = np.reshape(d, shape) - else: - d = np.diff(x, axis=axis) - - slice1 = [slice(None)] * y.ndim - slice2 = [slice(None)] * y.ndim - slice1[axis] = slice(1, None) - slice2[axis] = slice(None, -1) - - return anp.sum((y[tuple(slice1)] + y[tuple(slice2)]) * d / 2, axis=axis) - - -@primitive -def _add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: - """ - Add values to specified indices of an array. - - Autograd requires that arguments to primitives are passed in positionally. - ``add_at`` is the public-facing wrapper for this function, - which allows keyword arguments in case users pass in kwargs. - """ - out = np.copy(x) # Copy to preserve 'x' for gradient computation - out[tuple(indices_x)] += y - return out - - -defvjp( - _add_at, - lambda ans, x, indices_x, y: unbroadcast_f(x, lambda g: g), - lambda ans, x, indices_x, y: lambda g: g[tuple(indices_x)], - argnums=(0, 2), -) +# marked as migrated to _common +from __future__ import annotations -defjvp( +from tidy3d._common.components.autograd.functions import ( _add_at, - lambda g, ans, x, indices_x, y: broadcast(g, ans), - lambda g, ans, x, indices_x, y: _add_at(anp.zeros_like(ans), indices_x, g), - argnums=(0, 2), + _evaluate_linear, + _evaluate_nearest, + _straight_through_clip, + add_at, + interpn, + trapz, ) - - -def add_at(x: NDArray, indices_x: tuple, y: NDArray) -> NDArray: - """ - Add values to specified indices of an array. - - This function creates a copy of the input array `x`, adds the values from `y` to the specified - indices `indices_x`, and returns the modified array. - - Parameters - ---------- - x : np.ndarray - Input array to which values will be added. - indices_x : tuple - Indices of `x` where values from `y` will be added. - y : np.ndarray - Values to add to the specified indices of `x`. - - Returns - ------- - np.ndarray - The modified array with values added at the specified indices. 
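The `_straight_through_clip` re-exported in the shim above presumably follows the same `primitive`/`defvjp` pattern as `_add_at`; a minimal sketch of that pattern, as one plausible reading, not the actual `tidy3d._common` implementation:

```python
import numpy as np
from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def straight_through_clip(x, lo, hi):
    # clip in the forward pass
    return np.clip(x, lo, hi)

# straight-through estimator: the backward pass ignores the clip entirely
defvjp(straight_through_clip, lambda ans, x, lo, hi: lambda g: g)

g = grad(lambda x: straight_through_clip(x, 0.0, 1.0) ** 2)(1.5)  # -> 2.0
```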
- """ - return _add_at(x, indices_x, y) - - -__all__ = [ - "add_at", - "interpn", - "trapz", -] diff --git a/tidy3d/components/autograd/types.py b/tidy3d/components/autograd/types.py index bb41935695..dcbfd4a274 100644 --- a/tidy3d/components/autograd/types.py +++ b/tidy3d/components/autograd/types.py @@ -1,55 +1,26 @@ -# type information for autograd +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.types`.""" -# utilities for working with autograd -from __future__ import annotations - -import copy -import typing - -import pydantic.v1 as pd -from autograd.builtins import dict as dict_ag -from autograd.extend import Box, defvjp, primitive - -from tidy3d.components.types import ArrayFloat2D, ArrayLike, Complex, Size1D, _add_schema - -# add schema to the Box -_add_schema(Box, title="AutogradBox", field_type_str="autograd.tracer.Box") - -# make sure Boxes in tidy3d properly define VJPs for copy operations, for computational graph -_copy = primitive(copy.copy) -_deepcopy = primitive(copy.deepcopy) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -defvjp(_copy, lambda ans, x: lambda g: _copy(g)) -defvjp(_deepcopy, lambda ans, x, memo: lambda g: _deepcopy(g, memo)) - -Box.__copy__ = lambda v: _copy(v) -Box.__deepcopy__ = lambda v, memo: _deepcopy(v, memo) - -# Types for floats, or collections of floats that can also be autograd tracers -TracedFloat = typing.Union[float, Box] -TracedPositiveFloat = typing.Union[pd.PositiveFloat, Box] -TracedSize1D = typing.Union[Size1D, Box] -TracedSize = typing.Union[tuple[TracedSize1D, TracedSize1D, TracedSize1D], Box] -TracedCoordinate = typing.Union[tuple[TracedFloat, TracedFloat, TracedFloat], Box] -TracedVertices = typing.Union[ArrayFloat2D, Box] - -# poles -TracedComplex = typing.Union[Complex, Box] -TracedPoleAndResidue = tuple[TracedComplex, TracedComplex] - -# The data type that we pass in and out of the web.run() @autograd.primitive -AutogradTraced = typing.Union[Box, ArrayLike] -PathType = tuple[typing.Union[int, str], ...] 
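Routing `Box.__copy__`/`Box.__deepcopy__` through primitives (as above) is what keeps `copy.copy` from detaching a traced value from the computational graph. A toy sketch of the effect:

```python
import copy

import autograd.numpy as anp
from autograd import grad

# With the copy primitives registered (as in the removed code above), copying
# a traced value stays on the computational graph rather than silently
# breaking the gradient chain.
def f(x):
    y = copy.copy(x)
    return anp.sum(y * y)

print(grad(f)(3.0))  # 6.0 -- the gradient flows through the copy
```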
-AutogradFieldMap = dict_ag[PathType, AutogradTraced] - -InterpolationType = typing.Literal["nearest", "linear"] +# marked as migrated to _common +from __future__ import annotations -__all__ = [ - "AutogradFieldMap", - "AutogradTraced", - "TracedCoordinate", - "TracedFloat", - "TracedSize", - "TracedSize1D", - "TracedVertices", -] +from tidy3d._common.components.autograd.types import ( + AutogradFieldMap, + InterpolationType, + PathType, + TracedArrayFloat2D, + TracedArrayLike, + TracedComplex, + TracedCoordinate, + TracedDict, + TracedFloat, + TracedPoleAndResidue, + TracedPolesAndResidues, + TracedPositiveFloat, + TracedSize, + TracedSize1D, + _copy, + _deepcopy, + traced_alias, +) diff --git a/tidy3d/components/autograd/utils.py b/tidy3d/components/autograd/utils.py index a87e18f98b..f8d9f4304e 100644 --- a/tidy3d/components/autograd/utils.py +++ b/tidy3d/components/autograd/utils.py @@ -1,55 +1,16 @@ -# utilities for working with autograd -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any - -import autograd.numpy as anp -from autograd.tracer import getval - -__all__ = [ - "asarray1d", - "contains", - "get_static", - "is_tidy_box", - "pack_complex_vec", - "split_list", -] - - -def get_static(x: Any) -> Any: - """Get the 'static' (untraced) version of some value.""" - return getval(x) - +"""Compatibility shim for :mod:`tidy3d._common.components.autograd.utils`.""" -def split_list(x: list[Any], index: int) -> (list[Any], list[Any]): - """Split a list at a given index.""" - x = list(x) - return x[:index], x[index:] - - -def is_tidy_box(x: Any) -> bool: - """Check if a value is a tidy box.""" - return getattr(x, "_tidy", False) - - -def contains(target: Any, seq: Iterable[Any]) -> bool: - """Return ``True`` if target occurs anywhere within arbitrarily nested iterables.""" - for x in seq: - if x == target: - return True - if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): - if contains(target, x): - return True - return False - - -def pack_complex_vec(z): - """Ravel [Re(z); Im(z)] into one real vector (autograd-safe).""" - return anp.concatenate([anp.ravel(anp.real(z)), anp.ravel(anp.imag(z))]) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def asarray1d(x): - """Autograd-friendly 1D flatten: returns ndarray of shape (-1,).""" - x = anp.array(x) - return x if x.ndim == 1 else anp.ravel(x) +from tidy3d._common.components.autograd.utils import ( + asarray1d, + contains, + get_static, + hasbox, + is_tidy_box, + pack_complex_vec, + split_list, +) diff --git a/tidy3d/components/base.py b/tidy3d/components/base.py index cc7a8d91c5..cfae36803c 100644 --- a/tidy3d/components/base.py +++ b/tidy3d/components/base.py @@ -1,1499 +1,27 @@ -"""global configuration / base class for pydantic models used to make simulation.""" +"""Compatibility shim for :mod:`tidy3d._common.components.base`.""" -from __future__ import annotations - -import hashlib -import io -import json -import math -import os -import tempfile -from functools import wraps -from math import ceil -from os import PathLike -from pathlib import Path -from typing import Any, Callable, Literal, Optional, Union - -import h5py -import numpy as np -import pydantic.v1 as pydantic -import rich -import xarray as xr -import yaml -from autograd.builtins import dict as dict_ag -from autograd.tracer import isbox -from pydantic.v1.fields import ModelField -from pydantic.v1.json import 
custom_pydantic_encoder -from typing_extensions import Self - -from tidy3d.exceptions import FileError -from tidy3d.log import log - -from .autograd.types import AutogradFieldMap, Box -from .autograd.utils import get_static -from .data.data_array import DATA_ARRAY_MAP, DataArray -from .file_util import compress_file_to_gzip, extract_gzip_file -from .types import TYPE_TAG_STR, ComplexNumber - -INDENT_JSON_FILE = 4 # default indentation of json string in json files -INDENT = None # default indentation of json string used internally -JSON_TAG = "JSON_STRING" -# If json string is larger than ``MAX_STRING_LENGTH``, split the string when storing in hdf5 -MAX_STRING_LENGTH = 1_000_000_000 -FORBID_SPECIAL_CHARACTERS = ["/"] -TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__" -TYPE_TO_CLASS_MAP: dict[str, type[Tidy3dBaseModel]] = {} - - -def cache(prop): - """Decorates a property to cache the first computed value and return it on subsequent calls.""" - - # note, we could also just use `prop` as dict key, but hashing property might be slow - prop_name = prop.__name__ - - @wraps(prop) - def cached_property_getter(self): - """The new property method to be returned by decorator.""" - - stored_value = self._cached_properties.get(prop_name) - - if stored_value is not None: - return stored_value - - computed_value = prop(self) - self._cached_properties[prop_name] = computed_value - return computed_value - - return cached_property_getter - - -def cached_property(cached_property_getter): - """Shortcut for property(cache()) of a getter.""" - - return property(cache(cached_property_getter)) - - -def cached_property_guarded(key_func): - """Like cached_property, but invalidates when the key_func(self) changes.""" - - def _decorator(getter): - prop_name = getter.__name__ - - @wraps(getter) - def _guarded(self): - cache_store = self._cached_properties.get(prop_name) - current_key = key_func(self) - if cache_store is not None: - cached_key, cached_value = cache_store - if cached_key == current_key: - return cached_value - value = getter(self) - self._cached_properties[prop_name] = (current_key, value) - return value - - return property(_guarded) - - return _decorator - - -def ndarray_encoder(val): - """How a ``np.ndarray`` gets handled before saving to json.""" - if np.any(np.iscomplex(val)): - return {"real": val.real.tolist(), "imag": val.imag.tolist()} - return val.real.tolist() - - -def make_json_compatible(json_string: str) -> str: - """Makes the string compatible with json standards, notably for infinity.""" - - tmp_string = "<>" - json_string = json_string.replace("-Infinity", tmp_string) - json_string = json_string.replace('""-Infinity""', tmp_string) - json_string = json_string.replace("Infinity", '"Infinity"') - json_string = json_string.replace('""Infinity""', '"Infinity"') - return json_string.replace(tmp_string, '"-Infinity"') - - -def _get_valid_extension(fname: PathLike) -> str: - """Return the file extension from fname, validated to accepted ones.""" - valid_extensions = [".json", ".yaml", ".hdf5", ".h5", ".hdf5.gz"] - path = Path(fname) - extensions = [s.lower() for s in path.suffixes[-2:]] - if len(extensions) == 0: - raise FileError(f"File '{path}' missing extension.") - single_extension = extensions[-1] - if single_extension in valid_extensions: - return single_extension - double_extension = "".join(extensions) - if double_extension in valid_extensions: - return double_extension - raise FileError( - f"File extension must be one of {', '.join(valid_extensions)}; file '{path}' does not " - 
"match any of those." - ) - - -def skip_if_fields_missing(fields: list[str], root=False): - """Decorate ``validator`` to check that other fields have passed validation.""" - - def actual_decorator(validator): - @wraps(validator) - def _validator(cls, *args: Any, **kwargs: Any): - """New validator function.""" - values = kwargs.get("values") - if values is None: - values = args[0] if root else args[1] - for field in fields: - if field not in values: - log.warning( - f"Could not execute validator '{validator.__name__}' because field " - f"'{field}' failed validation." - ) - if root: - return values - return kwargs.get("val") if "val" in kwargs else args[0] - - return validator(cls, *args, **kwargs) - - return _validator - - return actual_decorator - - -class Tidy3dBaseModel(pydantic.BaseModel): - """Base pydantic model that all Tidy3d components inherit from. - Defines configuration for handling data structures - as well as methods for importing, exporting, and hashing tidy3d objects. - For more details on pydantic base models, see: - `Pydantic Models `_ - """ - - def __hash__(self) -> int: - """Hash method.""" - try: - return super().__hash__(self) - except TypeError: - return hash(self.json()) - - def _hash_self(self) -> str: - """Hash this component with ``hashlib`` in a way that is the same every session.""" - bf = io.BytesIO() - self.to_hdf5(bf) - return hashlib.md5(bf.getvalue()).hexdigest() - - def __init__(self, **kwargs: Any) -> None: - """Init method, includes post-init validators.""" - log.begin_capture() - super().__init__(**kwargs) - self._post_init_validators() - log.end_capture(self) - - def _post_init_validators(self) -> None: - """Call validators taking ``self`` that get run after init, implement in subclasses.""" - - def __init_subclass__(cls) -> None: - """Things that are done to each of the models.""" - - cls.add_type_field() - cls.generate_docstring() - type_value = cls.__fields__.get(TYPE_TAG_STR) - if type_value and type_value.default: - TYPE_TO_CLASS_MAP[type_value.default] = cls - - @classmethod - def _get_type_value(cls, obj: dict[str, Any]) -> str: - """Return the type tag from a raw dictionary.""" - if not isinstance(obj, dict): - raise TypeError("Input must be a dict") - try: - type_value = obj[TYPE_TAG_STR] - except KeyError as exc: - raise ValueError(f'Missing "{TYPE_TAG_STR}" in data') from exc - if not isinstance(type_value, str) or not type_value: - raise ValueError(f'Invalid "{TYPE_TAG_STR}" value: {type_value!r}') - return type_value - - @classmethod - def _get_registered_class(cls, type_value: str) -> type[Tidy3dBaseModel]: - try: - return TYPE_TO_CLASS_MAP[type_value] - except KeyError as exc: - raise ValueError(f"Unknown type: {type_value}") from exc - - @classmethod - def _should_dispatch_to(cls, target_cls: type[Tidy3dBaseModel]) -> bool: - """Return True if ``cls`` allows auto-dispatch to ``target_cls``.""" - return issubclass(target_cls, cls) - - @classmethod - def _resolve_dispatch_target(cls, obj: dict[str, Any]) -> type[Tidy3dBaseModel]: - """Determine which subclass should receive ``obj``.""" - type_value = cls._get_type_value(obj) - target_cls = cls._get_registered_class(type_value) - if cls._should_dispatch_to(target_cls): - return target_cls - if target_cls is cls: - return cls - raise ValueError( - f'Cannot parse type "{type_value}" using {cls.__name__}; expected subclass of {cls.__name__}.' 
- ) - - @classmethod - def _target_cls_from_file( - cls, fname: PathLike, group_path: Optional[str] = None - ) -> type[Tidy3dBaseModel]: - """Peek the file metadata to determine the subclass to instantiate.""" - model_dict = cls.dict_from_file( - fname=fname, - group_path=group_path, - load_data_arrays=False, - ) - return cls._resolve_dispatch_target(model_dict) - - @classmethod - def _parse_obj(cls, obj: dict[str, Any], **parse_obj_kwargs: Any) -> Tidy3dBaseModel: - """Dispatch ``obj`` to the correct subclass registered in the type map.""" - target_cls = cls._resolve_dispatch_target(obj) - if target_cls is cls: - return super().parse_obj(obj, **parse_obj_kwargs) - return target_cls.parse_obj(obj, **parse_obj_kwargs) - - @classmethod - def _parse_model_dict( - cls, model_dict: dict[str, Any], **parse_obj_kwargs: Any - ) -> Tidy3dBaseModel: - """Parse ``model_dict`` while optionally auto-dispatching when called on the base class.""" - if cls is Tidy3dBaseModel: - return cls._parse_obj(model_dict, **parse_obj_kwargs) - return cls.parse_obj(model_dict, **parse_obj_kwargs) - - class Config: - """Sets config for all :class:`Tidy3dBaseModel` objects. - - Configuration Options - --------------------- - allow_population_by_field_name : bool = True - Allow properties to stand in for fields(?). - arbitrary_types_allowed : bool = True - Allow types like numpy arrays. - extra : str = 'forbid' - Forbid extra kwargs not specified in model. - json_encoders : Dict[type, Callable] - Defines how to encode type in json file. - validate_all : bool = True - Validate default values just to be safe. - validate_assignment : bool - Re-validate after re-assignment of field in model. - """ - - arbitrary_types_allowed = True - validate_all = True - extra = "forbid" - validate_assignment = True - allow_population_by_field_name = True - json_encoders = { - np.ndarray: ndarray_encoder, - complex: lambda x: ComplexNumber(real=x.real, imag=x.imag), - xr.DataArray: DataArray._json_encoder, - Box: lambda x: x._value, - } - frozen = True - allow_mutation = False - copy_on_model_validation = "none" - - _cached_properties = pydantic.PrivateAttr({}) - _has_tracers: Optional[bool] = pydantic.PrivateAttr(default=None) - - @pydantic.root_validator(skip_on_failure=True) - def _special_characters_not_in_name(cls, values): - name = values.get("name") - if name: - for character in FORBID_SPECIAL_CHARACTERS: - if character in name: - raise ValueError( - f"Special character '{character}' not allowed in component name {name}." - ) - return values - - attrs: dict = pydantic.Field( - {}, - title="Attributes", - description="Dictionary storing arbitrary metadata for a Tidy3D object. " - "This dictionary can be freely used by the user for storing data without affecting the " - "operation of Tidy3D as it is not used internally. " - "Note that, unlike regular Tidy3D fields, ``attrs`` are mutable. " - "For example, the following is allowed for setting an ``attr`` ``obj.attrs['foo'] = bar``. " - "Also note that Tidy3D will raise a ``TypeError`` if ``attrs`` contain objects " - "that can not be serialized. 
One can check if ``attrs`` are serializable "
-        "by calling ``obj.json()``.",
-    )
-
-    def _attrs_digest(self) -> str:
-        """Stable digest of `attrs` using the same JSON encoding rules as pydantic .json()."""
-        encoders = getattr(self.__config__, "json_encoders", {}) or {}
-
-        def _default(o):
-            return custom_pydantic_encoder(encoders, o)
-
-        json_str = json.dumps(
-            self.attrs,
-            default=_default,
-            sort_keys=True,
-            separators=(",", ":"),
-            ensure_ascii=False,
-        )
-        json_str = make_json_compatible(json_str)
-
-        return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
-
-    def copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
-        """Copy a Tidy3dBaseModel. With ``deep=True`` and ``validate=True`` as default."""
-        kwargs.update(deep=deep)
-        new_copy = pydantic.BaseModel.copy(self, **kwargs)
-        if validate:
-            return self.validate(new_copy.dict())
-        # cached property is cleared automatically when validation is on, but it
-        # needs to be manually cleared when validation is off
-        new_copy._cached_properties = {}
-        new_copy._has_tracers = None
-        return new_copy
-
-    def updated_copy(
-        self, path: Optional[str] = None, deep: bool = True, validate: bool = True, **kwargs: Any
-    ) -> Self:
-        """Make copy of a component instance with ``**kwargs`` indicating updated field values.
-
-        Note
-        ----
-        If ``path`` is supplied, applies the updated copy with the update performed on the sub-
-        component corresponding to the path. For indexing into a tuple or list, use the integer
-        value.
-
-        Example
-        -------
-        >>> sim = simulation.updated_copy(size=new_size, path=f"structures/{i}/geometry") # doctest: +SKIP
-        """
-
-        if not path:
-            return self._updated_copy(**kwargs, deep=deep, validate=validate)
-
-        path_components = path.split("/")
-
-        field_name = path_components[0]
-
-        try:
-            sub_component = getattr(self, field_name)
-        except AttributeError as e:
-            raise AttributeError(
-                f"Could not find field '{field_name}' in the sub-component `path`. "
-                f"Found fields of '{tuple(self.__fields__.keys())}'. "
-                "Please double check the `path` passed to `.updated_copy()`."
-            ) from e
-
-        if isinstance(sub_component, (list, tuple)):
-            integer_index_path = path_components[1]
-
-            try:
-                index = int(integer_index_path)
-            except ValueError:
-                raise ValueError(
-                    f"Could not grab integer index from path '{path}'. "
-                    f"Please correct the sub path containing '{integer_index_path}' to be an "
-                    f"integer index into '{field_name}' (containing {len(sub_component)} elements)."
-                ) from None
-
-            sub_component_list = list(sub_component)
-            sub_component = sub_component_list[index]
-            sub_path = "/".join(path_components[2:])
-
-            sub_component_list[index] = sub_component.updated_copy(
-                path=sub_path, deep=deep, validate=validate, **kwargs
-            )
-            new_component = tuple(sub_component_list)
-        else:
-            sub_path = "/".join(path_components[1:])
-            new_component = sub_component.updated_copy(
-                path=sub_path, deep=deep, validate=validate, **kwargs
-            )
-
-        return self._updated_copy(deep=deep, validate=validate, **{field_name: new_component})
-
-    def _updated_copy(self, deep: bool = True, validate: bool = True, **kwargs: Any) -> Self:
-        """Make copy of a component instance with ``**kwargs`` indicating updated field values."""
-        return self.copy(update=kwargs, deep=deep, validate=validate)
-
-    def help(self, methods: bool = False) -> None:
-        """Prints message describing the fields and methods of a :class:`Tidy3dBaseModel`.
- - Parameters - ---------- - methods : bool = False - Whether to also print out information about object's methods. - - Example - ------- - >>> simulation.help(methods=True) # doctest: +SKIP - """ - rich.inspect(self, methods=methods) - - @classmethod - def from_file( - cls, - fname: PathLike, - group_path: Optional[str] = None, - lazy: bool = False, - on_load: Optional[Callable] = None, - **parse_obj_kwargs: Any, - ) -> Self: - """Loads a :class:`Tidy3dBaseModel` from .yaml, .json, .hdf5, or .hdf5.gz file. - - Parameters - ---------- - fname : PathLike - Full path to the file to load the :class:`Tidy3dBaseModel` from. - group_path : str | None = None - Path to a group inside the file to use as the base level. Only for hdf5 files. - Starting `/` is optional. - lazy : bool = False - Whether to load the actual data (``lazy=False``) or return a proxy that loads - the data when accessed (``lazy=True``). - on_load : Callable | None = None - Callback function executed once the model is fully materialized. - Only used if ``lazy=True``. The callback is invoked with the loaded - instance as its sole argument, enabling post-processing such as - validation, logging, or warnings checks. - **parse_obj_kwargs - Keyword arguments passed to either pydantic's ``parse_obj`` function when loading model. - - Returns - ------- - Self - An instance of the component class calling ``load``. - - Example - ------- - >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP - """ - if lazy: - target_cls = cls._target_cls_from_file(fname=fname, group_path=group_path) - Proxy = _make_lazy_proxy(target_cls, on_load=on_load) - return Proxy(fname, group_path, parse_obj_kwargs) - model_dict = cls.dict_from_file(fname=fname, group_path=group_path) - obj = cls._parse_model_dict(model_dict, **parse_obj_kwargs) - if not lazy and on_load is not None: - on_load(obj) - return obj - - @classmethod - def dict_from_file( - cls, fname: PathLike, group_path: Optional[str] = None, *, load_data_arrays: bool = True - ) -> dict: - """Loads a dictionary containing the model from a .yaml, .json, .hdf5, or .hdf5.gz file. - - Parameters - ---------- - fname : PathLike - Full path to the file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to use as the base level. - - Returns - ------- - dict - A dictionary containing the model. - - Example - ------- - >>> simulation = Simulation.from_file(fname='folder/sim.json') # doctest: +SKIP - """ - fname_path = Path(fname) - extension = _get_valid_extension(fname_path) - kwargs = {"fname": fname_path} - - if group_path is not None: - if extension in {".hdf5", ".hdf5.gz", ".h5"}: - kwargs["group_path"] = group_path - else: - log.warning("'group_path' provided, but this feature only works with hdf5 files.") - - if extension in {".hdf5", ".hdf5.gz", ".h5"}: - kwargs["load_data_arrays"] = load_data_arrays - - converter = { - ".json": cls.dict_from_json, - ".yaml": cls.dict_from_yaml, - ".hdf5": cls.dict_from_hdf5, - ".hdf5.gz": cls.dict_from_hdf5_gz, - ".h5": cls.dict_from_hdf5, - }[extension] - return converter(**kwargs) - - def to_file(self, fname: PathLike) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .yaml, .json, or .hdf5 file - - Parameters - ---------- - fname : PathLike - Full path to the .yaml or .json file to save the :class:`Tidy3dBaseModel` to. 
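The subclass dispatch behind `from_file` rests on the `TYPE_TO_CLASS_MAP` registry filled in `__init_subclass__`. The pattern in isolation (a toy sketch with hypothetical classes, not the tidy3d ones):

```python
# Toy sketch of the type-tag registry pattern used above (illustrative names):
REGISTRY: dict[str, type] = {}

class Base:
    def __init_subclass__(cls) -> None:
        # tidy3d keys the registry on the default of the "type" field
        REGISTRY[cls.__name__] = cls

    @classmethod
    def parse(cls, obj: dict):
        target = REGISTRY[obj["type"]]  # look up the concrete class by tag
        if not issubclass(target, cls):
            raise ValueError(f"cannot parse {obj['type']!r} as {cls.__name__}")
        return target(**{k: v for k, v in obj.items() if k != "type"})

class Point(Base):
    def __init__(self, x: float):
        self.x = x

p = Base.parse({"type": "Point", "x": 1.0})  # dispatches to Point
```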
- - Example - ------- - >>> simulation.to_file(fname='folder/sim.json') # doctest: +SKIP - """ - extension = _get_valid_extension(fname) - converter = { - ".json": self.to_json, - ".yaml": self.to_yaml, - ".hdf5": self.to_hdf5, - ".hdf5.gz": self.to_hdf5_gz, - }[extension] - return converter(fname=fname) - - @classmethod - def from_json(cls, fname: PathLike, **parse_obj_kwargs: Any) -> Self: - """Load a :class:`Tidy3dBaseModel` from .json file. - - Parameters - ---------- - fname : PathLike - Full path to the .json file to load the :class:`Tidy3dBaseModel` from. - - Returns - ------- - Self - An instance of the component class calling `load`. - **parse_obj_kwargs - Keyword arguments passed to pydantic's ``parse_obj`` method. - - Example - ------- - >>> simulation = Simulation.from_json(fname='folder/sim.json') # doctest: +SKIP - """ - model_dict = cls.dict_from_json(fname=fname) - return cls._parse_model_dict(model_dict, **parse_obj_kwargs) - - @classmethod - def dict_from_json(cls, fname: PathLike) -> dict: - """Load dictionary of the model from a .json file. - - Parameters - ---------- - fname : PathLike - Full path to the .json file to load the :class:`Tidy3dBaseModel` from. - - Returns - ------- - dict - A dictionary containing the model. - - Example - ------- - >>> sim_dict = Simulation.dict_from_json(fname='folder/sim.json') # doctest: +SKIP - """ - with open(fname, encoding="utf-8") as json_fhandle: - model_dict = json.load(json_fhandle) - return model_dict - - def to_json(self, fname: PathLike) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .json file - - Parameters - ---------- - fname : PathLike - Full path to the .json file to save the :class:`Tidy3dBaseModel` to. - - Example - ------- - >>> simulation.to_json(fname='folder/sim.json') # doctest: +SKIP - """ - export_model = self.to_static() - json_string = export_model._json(indent=INDENT_JSON_FILE) - self._warn_if_contains_data(json_string) - path = Path(fname) - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w", encoding="utf-8") as file_handle: - file_handle.write(json_string) - - @classmethod - def from_yaml(cls, fname: PathLike, **parse_obj_kwargs: Any) -> Self: - """Loads :class:`Tidy3dBaseModel` from .yaml file. - - Parameters - ---------- - fname : PathLike - Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. - **parse_obj_kwargs - Keyword arguments passed to pydantic's ``parse_obj`` method. - - Returns - ------- - Self - An instance of the component class calling `from_yaml`. - - Example - ------- - >>> simulation = Simulation.from_yaml(fname='folder/sim.yaml') # doctest: +SKIP - """ - model_dict = cls.dict_from_yaml(fname=fname) - return cls._parse_model_dict(model_dict, **parse_obj_kwargs) - - @classmethod - def dict_from_yaml(cls, fname: PathLike) -> dict: - """Load dictionary of the model from a .yaml file. - - Parameters - ---------- - fname : PathLike - Full path to the .yaml file to load the :class:`Tidy3dBaseModel` from. - - Returns - ------- - dict - A dictionary containing the model. - - Example - ------- - >>> sim_dict = Simulation.dict_from_yaml(fname='folder/sim.yaml') # doctest: +SKIP - """ - with open(fname, encoding="utf-8") as yaml_in: - model_dict = yaml.safe_load(yaml_in) - return model_dict - - def to_yaml(self, fname: PathLike) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .yaml file. - - Parameters - ---------- - fname : PathLike - Full path to the .yaml file to save the :class:`Tidy3dBaseModel` to. 
- - Example - ------- - >>> simulation.to_yaml(fname='folder/sim.yaml') # doctest: +SKIP - """ - export_model = self.to_static() - json_string = export_model._json() - self._warn_if_contains_data(json_string) - model_dict = json.loads(json_string) - path = Path(fname) - path.parent.mkdir(parents=True, exist_ok=True) - with path.open("w+", encoding="utf-8") as file_handle: - yaml.dump(model_dict, file_handle, indent=INDENT_JSON_FILE) - - @staticmethod - def _warn_if_contains_data(json_str: str) -> None: - """Log a warning if the json string contains data, used in '.json' and '.yaml' file.""" - if any((key in json_str for key, _ in DATA_ARRAY_MAP.items())): - log.warning( - "Data contents found in the model to be written to file. " - "Note that this data will not be included in '.json' or '.yaml' formats. " - "As a result, it will not be possible to load the file back to the original model." - "Instead, use `.hdf5` extension in filename passed to 'to_file()'." - ) - - @staticmethod - def _construct_group_path(group_path: str) -> str: - """Construct a group path with the leading forward slash if not supplied.""" - - # empty string or None - if not group_path: - return "/" - - # missing leading forward slash - if group_path[0] != "/": - return f"/{group_path}" - - return group_path - - @staticmethod - def get_tuple_group_name(index: int) -> str: - """Get the group name of a tuple element.""" - return str(int(index)) - - @staticmethod - def get_tuple_index(key_name: str) -> int: - """Get the index into the tuple based on its group name.""" - return int(str(key_name)) - - @classmethod - def tuple_to_dict(cls, tuple_values: tuple) -> dict: - """How we generate a dictionary mapping new keys to tuple values for hdf5.""" - return {cls.get_tuple_group_name(index=i): val for i, val in enumerate(tuple_values)} - - @classmethod - def get_sub_model(cls, group_path: str, model_dict: dict | list) -> dict: - """Get the sub model for a given group path.""" - - for key in group_path.split("/"): - if key: - if isinstance(model_dict, list): - tuple_index = cls.get_tuple_index(key_name=key) - model_dict = model_dict[tuple_index] - else: - model_dict = model_dict[key] - return model_dict - - @staticmethod - def _json_string_key(index: int) -> str: - """Get json string key for string chunk number ``index``.""" - if index: - return f"{JSON_TAG}_{index}" - return JSON_TAG - - @classmethod - def _json_string_from_hdf5(cls, fname: PathLike) -> str: - """Load the model json string from an hdf5 file.""" - with h5py.File(fname, "r") as f_handle: - num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key]) - json_string = b"" - for ind in range(num_string_parts): - json_string += f_handle[cls._json_string_key(ind)][()] - return json_string - - @classmethod - def dict_from_hdf5( - cls, - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - load_data_arrays: bool = True, - ) -> dict: - """Loads a dictionary containing the model contents from a .hdf5 file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - - Returns - ------- - dict - Dictionary containing the model. 
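How `get_sub_model` walks a group path, reduced to plain dict/list traversal (toy data):

```python
# Toy illustration of the get_sub_model traversal above: string keys index
# dicts, while tuple elements are addressed by their stringified position.
model_dict = {"structures": [{"geometry": {"size": 1.0}}]}

node = model_dict
for key in "/structures/0/geometry".split("/"):
    if key:
        node = node[int(key)] if isinstance(node, list) else node[key]

assert node == {"size": 1.0}
```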
- - Example - ------- - >>> sim_dict = Simulation.dict_from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP - """ - - def is_data_array(value: Any) -> bool: - """Whether a value is supposed to be a data array based on the contents.""" - return isinstance(value, str) and value in DATA_ARRAY_MAP - - fname_path = Path(fname) - - def load_data_from_file(model_dict: dict, group_path: str = "") -> None: - """For every DataArray item in dictionary, load path of hdf5 group as value.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - for key, value in model_dict.items(): - subpath = f"{group_path}/{key}" - - # apply custom validation to the key value pair and modify model_dict - if custom_decoders: - for custom_decoder in custom_decoders: - custom_decoder( - fname=str(fname_path), - group_path=subpath, - model_dict=model_dict, - key=key, - value=value, - ) - - # write the path to the element of the json dict where the data_array should be - if is_data_array(value): - data_array_type = DATA_ARRAY_MAP[value] - model_dict[key] = data_array_type.from_hdf5( - fname=fname_path, group_path=subpath - ) - continue - - # if a list, assign each element a unique key, recurse - if isinstance(value, (list, tuple)): - value_dict = cls.tuple_to_dict(tuple_values=value) - load_data_from_file(model_dict=value_dict, group_path=subpath) - - # handle case of nested list of DataArray elements - val_tuple = list(value_dict.values()) - for ind, (model_item, value_item) in enumerate(zip(model_dict[key], val_tuple)): - if is_data_array(model_item): - model_dict[key][ind] = value_item - - # if a dict, recurse - elif isinstance(value, dict): - load_data_from_file(model_dict=value, group_path=subpath) - - model_dict = json.loads(cls._json_string_from_hdf5(fname=fname_path)) - group_path = cls._construct_group_path(group_path) - model_dict = cls.get_sub_model(group_path=group_path, model_dict=model_dict) - if load_data_arrays: - load_data_from_file(model_dict=model_dict, group_path=group_path) - return model_dict - - @classmethod - def from_hdf5( - cls, - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - **parse_obj_kwargs: Any, - ) -> Self: - """Loads :class:`Tidy3dBaseModel` instance to .hdf5 file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5 file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - Starting `/` is optional. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - **parse_obj_kwargs - Keyword arguments passed to pydantic's ``parse_obj`` method. - - Example - ------- - >>> simulation = Simulation.from_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP - """ - - group_path = cls._construct_group_path(group_path) - model_dict = cls.dict_from_hdf5( - fname=fname, - group_path=group_path, - custom_decoders=custom_decoders, - ) - return cls._parse_model_dict(model_dict, **parse_obj_kwargs) - - def to_hdf5( - self, - fname: PathLike | io.BytesIO, - custom_encoders: Optional[list[Callable]] = None, - ) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .hdf5 file. - - Parameters - ---------- - fname : PathLike | BytesIO - Full path to the .hdf5 file or buffer to save the :class:`Tidy3dBaseModel` to. 
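The placeholder substitution done by `load_data_from_file` above, reduced to its core; the tag literal and `load` callback here are stand-ins for the `DATA_ARRAY_MAP` membership check and `DataArray.from_hdf5`:

```python
# Reduced sketch of the substitution above: leaves whose value matches a known
# data-array tag are replaced by data loaded from the hdf5 group at that path.
def resolve(node, load, path=""):
    if isinstance(node, dict):
        return {k: resolve(v, load, f"{path}/{k}") for k, v in node.items()}
    if isinstance(node, (list, tuple)):
        return [resolve(v, load, f"{path}/{i}") for i, v in enumerate(node)]
    if node == "ScalarFieldDataArray":  # hypothetical tag; real code checks DATA_ARRAY_MAP
        return load(path)  # e.g. DataArray.from_hdf5(fname, group_path=path)
    return node
```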
- custom_encoders : List[Callable] - List of functions accepting (fname: str, group_path: str, value: Any) that take - the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. - - Example - ------- - >>> simulation.to_hdf5(fname='folder/sim.hdf5') # doctest: +SKIP - """ - - export_model = self.to_static() - traced_keys_payload = export_model.attrs.get(TRACED_FIELD_KEYS_ATTR) - - if traced_keys_payload is None: - traced_keys_payload = self.attrs.get(TRACED_FIELD_KEYS_ATTR) - if traced_keys_payload is None: - traced_keys_payload = self._serialized_traced_field_keys() - path = Path(fname) if isinstance(fname, PathLike) else fname - with h5py.File(path, "w") as f_handle: - json_str = export_model._json() - for ind in range(ceil(len(json_str) / MAX_STRING_LENGTH)): - ind_start = int(ind * MAX_STRING_LENGTH) - ind_stop = min(int(ind + 1) * MAX_STRING_LENGTH, len(json_str)) - f_handle[self._json_string_key(ind)] = json_str[ind_start:ind_stop] - - def add_data_to_file(data_dict: dict, group_path: str = "") -> None: - """For every DataArray item in dictionary, write path of hdf5 group as value.""" - - for key, value in data_dict.items(): - # append the key to the path - subpath = f"{group_path}/{key}" - - if custom_encoders: - for custom_encoder in custom_encoders: - custom_encoder(fname=f_handle, group_path=subpath, value=value) - - # write the path to the element of the json dict where the data_array should be - if isinstance(value, xr.DataArray): - value.to_hdf5(fname=f_handle, group_path=subpath) - - # if a tuple, assign each element a unique key - if isinstance(value, (list, tuple)): - value_dict = export_model.tuple_to_dict(tuple_values=value) - add_data_to_file(data_dict=value_dict, group_path=subpath) - - # if a dict, recurse - elif isinstance(value, dict): - add_data_to_file(data_dict=value, group_path=subpath) - - add_data_to_file(data_dict=export_model.dict()) - if traced_keys_payload: - f_handle.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload - - @classmethod - def dict_from_hdf5_gz( - cls, - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - load_data_arrays: bool = True, - ) -> dict: - """Loads a dictionary containing the model contents from a .hdf5.gz file. - - Parameters - ---------- - fname : PathLike - Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - - Returns - ------- - dict - Dictionary containing the model. - - Example - ------- - >>> sim_dict = Simulation.dict_from_hdf5(fname='folder/sim.hdf5.gz') # doctest: +SKIP - """ - file_descriptor, extracted = tempfile.mkstemp(".hdf5") - os.close(file_descriptor) - extracted_path = Path(extracted) - try: - extract_gzip_file(fname, extracted_path) - result = cls.dict_from_hdf5( - extracted_path, - group_path=group_path, - custom_decoders=custom_decoders, - load_data_arrays=load_data_arrays, - ) - finally: - extracted_path.unlink(missing_ok=True) - - return result - - @classmethod - def from_hdf5_gz( - cls, - fname: PathLike, - group_path: str = "", - custom_decoders: Optional[list[Callable]] = None, - **parse_obj_kwargs: Any, - ) -> Self: - """Loads :class:`Tidy3dBaseModel` instance to .hdf5.gz file. 
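The JSON chunking in `to_hdf5` and its inverse in `_json_string_from_hdf5`, round-tripped with a deliberately tiny chunk size (a sketch using a plain dict in place of an h5py file; the real `MAX_STRING_LENGTH` is 1_000_000_000):

```python
from math import ceil

def key(i: int) -> str:
    # mirrors _json_string_key: "JSON_STRING", then "JSON_STRING_1", ...
    return f"JSON_STRING_{i}" if i else "JSON_STRING"

def write_chunks(store: dict, s: str, max_len: int) -> None:
    for i in range(ceil(len(s) / max_len)):
        store[key(i)] = s[i * max_len : (i + 1) * max_len]

def read_chunks(store: dict) -> str:
    n = len([k for k in store if "JSON_STRING" in k])
    return "".join(store[key(i)] for i in range(n))

store: dict = {}
write_chunks(store, '{"type": "Simulation"}', max_len=8)
assert read_chunks(store) == '{"type": "Simulation"}'
```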
- - Parameters - ---------- - fname : PathLike - Full path to the .hdf5.gz file to load the :class:`Tidy3dBaseModel` from. - group_path : str, optional - Path to a group inside the file to selectively load a sub-element of the model only. - Starting `/` is optional. - custom_decoders : List[Callable] - List of functions accepting - (fname: str, group_path: str, model_dict: dict, key: str, value: Any) that store the - value in the model dict after a custom decoding. - **parse_obj_kwargs - Keyword arguments passed to pydantic's ``parse_obj`` method. - - Example - ------- - >>> simulation = Simulation.from_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP - """ - - group_path = cls._construct_group_path(group_path) - model_dict = cls.dict_from_hdf5_gz( - fname=fname, - group_path=group_path, - custom_decoders=custom_decoders, - ) - return cls._parse_model_dict(model_dict, **parse_obj_kwargs) - - def to_hdf5_gz( - self, fname: PathLike | io.BytesIO, custom_encoders: Optional[list[Callable]] = None - ) -> None: - """Exports :class:`Tidy3dBaseModel` instance to .hdf5.gz file. - - Parameters - ---------- - fname : PathLike | BytesIO - Full path to the .hdf5.gz file or buffer to save the :class:`Tidy3dBaseModel` to. - custom_encoders : List[Callable] - List of functions accepting (fname: str, group_path: str, value: Any) that take - the ``value`` supplied and write it to the hdf5 ``fname`` at ``group_path``. - - Example - ------- - >>> simulation.to_hdf5_gz(fname='folder/sim.hdf5.gz') # doctest: +SKIP - """ - file, decompressed = tempfile.mkstemp(".hdf5") - os.close(file) - try: - self.to_hdf5(decompressed, custom_encoders=custom_encoders) - compress_file_to_gzip(decompressed, fname) - finally: - os.unlink(decompressed) - - def __lt__(self, other): - """define < for getting unique indices based on hash.""" - return hash(self) < hash(other) - - def __gt__(self, other): - """define > for getting unique indices based on hash.""" - return hash(self) > hash(other) - - def __le__(self, other): - """define <= for getting unique indices based on hash.""" - return hash(self) <= hash(other) - - def __ge__(self, other): - """define >= for getting unique indices based on hash.""" - return hash(self) >= hash(other) - - def __eq__(self, other): - """Define == for two Tidy3dBaseModels.""" - if other is None: - return False - - def check_equal(dict1: dict, dict2: dict) -> bool: - """Check if two dictionaries are equal, with special handlings.""" - - # if different keys, automatically fail - if not dict1.keys() == dict2.keys(): - return False - - # loop through elements in each dict - for key in dict1: # noqa: PLC0206 - val1 = dict1[key] - val2 = dict2[key] - - val1 = get_static(val1) - val2 = get_static(val2) - - # if one of val1 or val2 is None (exclusive OR) - if (val1 is None) != (val2 is None): - return False - - # convert tuple to dict to use this recursive function - if isinstance(val1, tuple) and isinstance(val2, tuple): - val1 = dict(zip(range(len(val1)), val1)) - val2 = dict(zip(range(len(val2)), val2)) - - # if dictionaries, recurse - if isinstance(val1, dict) or isinstance(val2, dict): - are_equal = check_equal(val1, val2) - if not are_equal: - return False - - # if numpy arrays, use numpy to do equality check - elif isinstance(val1, np.ndarray) or isinstance(val2, np.ndarray): - if not np.array_equal(val1, val2): - return False - - # everything else - else: - # note: this logic is because != is handled differently in DataArrays apparently - if not val1 == val2: - return False - - return True - - 
return check_equal(self.dict(), other.dict()) - - @cached_property_guarded(lambda self: self._attrs_digest()) - def _json_string(self) -> str: - """Returns string representation of a :class:`Tidy3dBaseModel`. - - Returns - ------- - str - Json-formatted string holding :class:`Tidy3dBaseModel` data. - """ - return self._json() - - def _json(self, indent=INDENT, exclude_unset=False, **kwargs: Any) -> str: - """Overwrites the model ``json`` representation with some extra customized handling. - - Parameters - ----------- - **kwargs : kwargs passed to `self.json()` - - Returns - ------- - str - Json-formatted string holding :class:`Tidy3dBaseModel` data. - """ - - json_string = self.json(indent=indent, exclude_unset=exclude_unset, **kwargs) - json_string = make_json_compatible(json_string) - return json_string - - def _strip_traced_fields( - self, starting_path: tuple[str, ...] = (), include_untraced_data_arrays: bool = False - ) -> AutogradFieldMap: - """Extract a dictionary mapping paths in the model to the data traced by ``autograd``. - - Parameters - ---------- - starting_path : tuple[str, ...] = () - If provided, starts recursing in self.dict() from this path of field names - include_untraced_data_arrays : bool = False - Whether to include ``DataArray`` objects without tracers. - We need to include these when returning data, but are unnecessary for structures. - - Returns - ------- - dict - mapping of traced fields used by ``autograd`` - - """ - - path = tuple(starting_path) - if self._has_tracers is False and not include_untraced_data_arrays: - return dict_ag() - - field_mapping = {} - - def handle_value(x: Any, path: tuple[str, ...]) -> None: - """recursively update ``field_mapping`` with path to the autograd data.""" - - # this is a leaf node that we want to trace, add this path and data to the mapping - if isbox(x): - field_mapping[path] = x - - # for data arrays, need to be more careful as their tracers are stored in .data - elif isinstance(x, xr.DataArray) and (isbox(x.data) or include_untraced_data_arrays): - field_mapping[path] = x.data - - # for sequences, add (i,) to the path and handle each value individually - elif isinstance(x, (list, tuple)): - for i, val in enumerate(x): - handle_value(val, path=(*path, i)) - - # for dictionaries, add the (key,) to the path and handle each value individually - elif isinstance(x, dict): - for key, val in x.items(): - handle_value(val, path=(*path, key)) - - # recursively parse the dictionary of this object - self_dict = self.dict() - - # if an include_only string was provided, only look at that subset of the dict - if path: - for key in path: - self_dict = self_dict[key] - - handle_value(self_dict, path=path) - - if field_mapping: - if not include_untraced_data_arrays: - self._has_tracers = True - return dict_ag(field_mapping) - - if not include_untraced_data_arrays and not path: - self._has_tracers = False - return dict_ag() - - def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self: - """Recursively insert a map of paths to autograd-traced fields into a copy of this obj.""" - - self_dict = self.dict() - - def insert_value(x, path: tuple[str, ...], sub_dict: dict) -> None: - """Insert a value into the path into a dictionary.""" - current_dict = sub_dict - for key in path[:-1]: - if isinstance(current_dict[key], tuple): - current_dict[key] = list(current_dict[key]) - current_dict = current_dict[key] - - final_key = path[-1] - if isinstance(current_dict[final_key], tuple): - current_dict[final_key] = 
list(current_dict[final_key]) - - sub_element = current_dict[final_key] - if isinstance(sub_element, xr.DataArray): - current_dict[final_key] = sub_element.copy(deep=False, data=x) - - else: - current_dict[final_key] = x - - for path, value in field_mapping.items(): - insert_value(value, path=path, sub_dict=self_dict) - - return type(self)._parse_model_dict(self_dict) - - def _serialized_traced_field_keys( - self, field_mapping: AutogradFieldMap | None = None - ) -> Optional[str]: - """Return a serialized, order-independent representation of traced field paths.""" - - if field_mapping is None: - field_mapping = self._strip_traced_fields() - if not field_mapping: - return None - - # TODO: remove this deferred import once TracerKeys is decoupled from Tidy3dBaseModel. - from tidy3d.components.autograd.field_map import TracerKeys - - tracer_keys = TracerKeys.from_field_mapping(field_mapping) - return tracer_keys.json(separators=(",", ":"), ensure_ascii=True) - - def to_static(self) -> Self: - """Version of object with all autograd-traced fields removed.""" - - if self._has_tracers is False: - return self - - # get dictionary of all traced fields - field_mapping = self._strip_traced_fields() - - # shortcut to just return self if no tracers found, for performance - if not field_mapping: - self._has_tracers = False - return self - - # convert all fields to static values - field_mapping_static = {key: get_static(val) for key, val in field_mapping.items()} - - # insert the static values into a copy of self - static_self = self._insert_traced_fields(field_mapping_static) - static_self._has_tracers = False - return static_self - - @classmethod - def add_type_field(cls) -> None: - """Automatically place "type" field with model name in the model field dictionary.""" - - value = cls.__name__ - annotation = Literal[value] - - tag_field = ModelField.infer( - name=TYPE_TAG_STR, - value=value, - annotation=annotation, - class_validators=None, - config=cls.__config__, - ) - cls.__fields__[TYPE_TAG_STR] = tag_field - - @classmethod - def generate_docstring(cls) -> str: - """Generates a docstring for a Tidy3D mode and saves it to the __doc__ of the class.""" - - # store the docstring in here - doc = "" - - # if the model already has a docstring, get the first lines and save the rest - original_docstrings = [] - if cls.__doc__: - original_docstrings = cls.__doc__.split("\n\n") - class_description = original_docstrings.pop(0) - doc += class_description - original_docstrings = "\n\n".join(original_docstrings) - - # create the list of parameters (arguments) for the model - doc += "\n\n Parameters\n ----------\n" - for field_name, field in cls.__fields__.items(): - # ignore the type tag - if field_name == TYPE_TAG_STR: - continue - - # get data type - data_type = field._type_display() - - # get default values - default_val = field.get_default() - if "=" in str(default_val): - # handle cases where default values are pydantic models - default_val = f"{default_val.__class__.__name__}({default_val})" - default_val = (", ").join(default_val.split(" ")) - - # make first line: name : type = default - default_str = "" if field.required else f" = {default_val}" - doc += f" {field_name} : {data_type}{default_str}\n" - - # get field metadata - field_info = field.field_info - doc += " " - - # add units (if present) - units = field_info.extra.get("units") - if units is not None: - if isinstance(units, (tuple, list)): - unitstr = "(" - for unit in units: - unitstr += str(unit) - unitstr += ", " - unitstr = unitstr[:-2] - unitstr 
+= ")" - else: - unitstr = units - doc += f"[units = {unitstr}]. " - - # add description - description_str = field_info.description - if description_str is not None: - doc += f"{description_str}\n" - - # add in remaining things in the docs - if original_docstrings: - doc += "\n" - doc += original_docstrings - - doc += "\n" - cls.__doc__ = doc - - def get_submodels_by_hash(self) -> dict[int, list[Union[str, tuple[str, int]]]]: - """Return a dictionary of this object's sub-models indexed by their hash values.""" - fields = {} - for key in self.__fields__: - field = getattr(self, key) - - if isinstance(field, Tidy3dBaseModel): - hash_ = hash(field) - if hash_ not in fields: - fields[hash_] = [] - fields[hash_].append(key) - - # Do we need to consider np.ndarray here? - elif isinstance(field, (list, tuple, np.ndarray)): - for index, sub_field in enumerate(field): - if isinstance(sub_field, Tidy3dBaseModel): - hash_ = hash(sub_field) - if hash_ not in fields: - fields[hash_] = [] - fields[hash_].append((key, index)) - - elif isinstance(field, dict): - for index, sub_field in field.items(): - if isinstance(sub_field, Tidy3dBaseModel): - hash_ = hash(sub_field) - if hash_ not in fields: - fields[hash_] = [] - fields[hash_].append((key, index)) - - return fields - - @staticmethod - def _scientific_notation( - min_val: float, max_val: float, min_digits: int = 4 - ) -> tuple[str, str]: - """ - Convert numbers to scientific notation, displaying only digits up to the point of difference, - with a minimum number of significant digits specified by `min_digits`. - """ - - def to_sci(value: float, exponent: int, precision: int) -> str: - normalized_value = value / (10**exponent) - return f"{normalized_value:.{precision}f}e{exponent}" - - if min_val == 0 or max_val == 0: - return f"{min_val:.0e}", f"{max_val:.0e}" - - exponent_min = math.floor(math.log10(abs(min_val))) - exponent_max = math.floor(math.log10(abs(max_val))) - - common_exponent = min(exponent_min, exponent_max) - normalized_min = min_val / (10**common_exponent) - normalized_max = max_val / (10**common_exponent) - - if normalized_min == normalized_max: - precision = min_digits - else: - precision = 0 - while round(normalized_min, precision) == round(normalized_max, precision): - precision += 1 - - precision = max(precision, min_digits) - - sci_min = to_sci(min_val, common_exponent, precision) - sci_max = to_sci(max_val, common_exponent, precision) - - return sci_min, sci_max - - -def _make_lazy_proxy( - target_cls: type, - on_load: Optional[Callable[[Any], None]] = None, -) -> type: - """ - Return a lazy-loading proxy subclass of ``target_cls``. - - Parameters - ---------- - target_cls : type - Must implement ``dict_from_file`` and ``parse_obj``. - on_load : Callable[[Any], None] | None = None - A function to call with the fully loaded instance once loaded. - - Returns - ------- - type - A class named ``Proxy`` with init args: - ``(fname, group_path, parse_obj_kwargs)``. 
-    """
-
-    proxy_name = f"{target_cls.__name__}Proxy"
-
-    class _LazyProxy(target_cls):
-        def __init__(
-            self,
-            fname: PathLike,
-            group_path: Optional[str],
-            parse_obj_kwargs: Any,
-        ):
-            object.__setattr__(self, "_lazy_fname", Path(fname))
-            object.__setattr__(self, "_lazy_group_path", group_path)
-            object.__setattr__(self, "_lazy_parse_obj_kwargs", dict(parse_obj_kwargs or {}))
-
-        def copy(self, **kwargs: Any):
-            """Return another lazy proxy instead of materializing."""
-            return _LazyProxy(
-                self._lazy_fname,
-                self._lazy_group_path,
-                {**self._lazy_parse_obj_kwargs, **kwargs},
-            )
-
-        def __getattribute__(self, name: str):
-            if name in (
-                "__class__",
-                "__dict__",
-                "__weakref__",
-                "__post_root_validators__",
-                "copy",  # <-- avoid materializing just for copy
-            ) or name.startswith("_lazy_"):
-                return object.__getattribute__(self, name)
-
-            d = object.__getattribute__(self, "__dict__")
-            if "_lazy_fname" in d:  # sentinel: not loaded yet
-                fname = d["_lazy_fname"]
-                group_path = d["_lazy_group_path"]
-                kwargs = d["_lazy_parse_obj_kwargs"]
-
-                model_dict = target_cls.dict_from_file(fname=fname, group_path=group_path)
-                target = target_cls._parse_model_dict(model_dict, **kwargs)
-
-                d.clear()
-                d.update(target.__dict__)
-                object.__setattr__(self, "__class__", target.__class__)
-                object.__setattr__(self, "__fields_set__", set(target.__fields_set__))
-                private_attrs = getattr(target, "__private_attributes__", {}) or {}
-                for attr_name in private_attrs:
-                    object.__setattr__(self, attr_name, getattr(target, attr_name))
-
-                if on_load is not None:
-                    on_load(self)
-
-            return object.__getattribute__(self, name)
-
-    _LazyProxy.__name__ = proxy_name
-    return _LazyProxy
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
+# marked as migrated to _common
+from __future__ import annotations
+
+from tidy3d._common.components.base import (
+    FORBID_SPECIAL_CHARACTERS,
+    INDENT,
+    INDENT_JSON_FILE,
+    JSON_TAG,
+    MAX_STRING_LENGTH,
+    TRACED_FIELD_KEYS_ATTR,
+    TYPE_TO_CLASS_MAP,
+    T,
+    Tidy3dBaseModel,
+    _CacheReturn,
+    _fmt_ann_literal,
+    _get_valid_extension,
+    _GuardedReturn,
+    _make_lazy_proxy,
+    cache,
+    cached_property,
+    cached_property_guarded,
+    make_json_compatible,
+)
diff --git a/tidy3d/components/base_sim/data/monitor_data.py b/tidy3d/components/base_sim/data/monitor_data.py
index 18fb82728a..5feb749c13 100644
--- a/tidy3d/components/base_sim/data/monitor_data.py
+++ b/tidy3d/components/base_sim/data/monitor_data.py
@@ -4,7 +4,7 @@
 
 from abc import ABC
 
-import pydantic.v1 as pd
+from pydantic import Field
 
 from tidy3d.components.base_sim.monitor import AbstractMonitor
 from tidy3d.components.data.dataset import Dataset
@@ -15,8 +15,7 @@ class AbstractMonitorData(Dataset, ABC):
     :class:`AbstractMonitor`.
""" - monitor: AbstractMonitor = pd.Field( - ..., + monitor: AbstractMonitor = Field( title="Monitor", description="Monitor associated with the data.", ) diff --git a/tidy3d/components/base_sim/data/sim_data.py b/tidy3d/components/base_sim/data/sim_data.py index 86e752cc55..c44c52c60d 100644 --- a/tidy3d/components/base_sim/data/sim_data.py +++ b/tidy3d/components/base_sim/data/sim_data.py @@ -2,20 +2,29 @@ from __future__ import annotations +import pathlib from abc import ABC -from typing import Union +from typing import TYPE_CHECKING, Any, Optional import numpy as np -import pydantic.v1 as pd -import xarray as xr +from pydantic import Field, field_validator, model_validator -from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData from tidy3d.components.base_sim.simulation import AbstractSimulation -from tidy3d.components.data.utils import UnstructuredGridDatasetType -from tidy3d.components.monitor import AbstractMonitor -from tidy3d.components.types import FieldVal -from tidy3d.exceptions import DataError, Tidy3dKeyError, ValidationError +from tidy3d.components.file_util import replace_values +from tidy3d.exceptions import DataError, FileError, Tidy3dKeyError, ValidationError + +if TYPE_CHECKING: + from os import PathLike + from typing import Union + + import xarray as xr + + from tidy3d.compat import Self + from tidy3d.components.data.utils import UnstructuredGridDatasetType + from tidy3d.components.monitor import AbstractMonitor + from tidy3d.components.types import FieldVal class AbstractSimulationData(Tidy3dBaseModel, ABC): @@ -23,20 +32,18 @@ class AbstractSimulationData(Tidy3dBaseModel, ABC): a :class:`AbstractSimulation`. """ - simulation: AbstractSimulation = pd.Field( - ..., + simulation: AbstractSimulation = Field( title="Simulation", description="Original :class:`AbstractSimulation` associated with the data.", ) - data: tuple[AbstractMonitorData, ...] = pd.Field( - ..., + data: tuple[AbstractMonitorData, ...] = Field( title="Monitor Data", description="List of :class:`AbstractMonitorData` instances " "associated with the monitors of the original :class:`AbstractSimulation`.", ) - log: str = pd.Field( + log: Optional[str] = Field( None, title="Solver Log", description="A string containing the log information from the simulation run.", @@ -52,15 +59,14 @@ def monitor_data(self) -> dict[str, AbstractMonitorData]: """Dictionary mapping monitor name to its associated :class:`AbstractMonitorData`.""" return {monitor_data.monitor.name: monitor_data for monitor_data in self.data} - @pd.root_validator(skip_on_failure=True) - def data_monitors_match_sim(cls, values): + @model_validator(mode="after") + def data_monitors_match_sim(self) -> Self: """Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in ``.simulation``. """ - sim = values.get("simulation") - data = values.get("data") + sim = self.simulation - for mnt_data in data: + for mnt_data in self.data: try: monitor_name = mnt_data.monitor.name sim.get_monitor_by_name(monitor_name) @@ -69,11 +75,13 @@ def data_monitors_match_sim(cls, values): f"Data with monitor name '{monitor_name}' supplied " f"but not found in the original '{sim.type}'." 
                ) from exc
-        return values
+        return self
 
-    @pd.validator("data", always=True)
-    @skip_if_fields_missing(["simulation"])
-    def validate_no_ambiguity(cls, val, values):
+    @field_validator("data")
+    @classmethod
+    def validate_no_ambiguity(
+        cls, val: tuple[AbstractMonitorData, ...]
+    ) -> tuple[AbstractMonitorData, ...]:
         """Ensure all :class:`AbstractMonitorData` entries in ``.data`` correspond to different
         monitors in ``.simulation``.
         """
@@ -129,6 +137,90 @@ def _field_component_value(
 
         return field_value
 
+    @staticmethod
+    def _apply_log_scale(
+        field_data: xr.DataArray,
+        vmin: Optional[float] = None,
+        db_factor: float = 1.0,
+    ) -> xr.DataArray:
+        """Prepare field data for log-scale plotting by handling zeros.
+
+        Takes absolute value of the data, replaces zeros with a fill value
+        (to prevent log10(0) warnings), and applies log10 scaling.
+
+        Parameters
+        ----------
+        field_data : xr.DataArray
+            The field data to prepare.
+        vmin : float, optional
+            The minimum value for the color scale. If provided, zeros are replaced
+            with ``10 ** (vmin / db_factor)`` instead of NaN.
+        db_factor : float
+            Factor to multiply the log10 result by (e.g., 20 for dB scale of field,
+            10 for dB scale of power). Default is 1 (pure log10 scale).
+
+        Returns
+        -------
+        xr.DataArray
+            The log-scaled field data.
+        """
+        fill_val = np.nan
+        if vmin is not None:
+            fill_val = 10 ** (vmin / db_factor)
+        field_data = np.abs(field_data)
+        field_data = field_data.where((field_data > 0) | np.isnan(field_data), fill_val)
+        return db_factor * np.log10(field_data)
+
     def get_monitor_by_name(self, name: str) -> AbstractMonitor:
         """Return monitor named 'name'."""
         return self.simulation.get_monitor_by_name(name)
+
+    def to_mat_file(self, fname: PathLike, **kwargs: Any) -> None:
+        """Output the simulation data object as a ``.mat`` MATLAB file.
+
+        Parameters
+        ----------
+        fname : PathLike
+            Full path to the output file. Should include ``.mat`` file extension.
+        **kwargs : dict, optional
+            Extra arguments to ``scipy.io.savemat``: see ``scipy`` documentation for more detail.
+
+        Example
+        -------
+        >>> sim_data.to_mat_file('/path/to/file/data.mat') # doctest: +SKIP
+        """
+        # Check that a .mat file extension is given (guard against suffix-less paths,
+        # which would otherwise raise an IndexError before the explicit check)
+        suffixes = pathlib.Path(fname).suffixes
+        if len(suffixes) == 0:
+            raise FileError(f"File '{fname}' missing extension.")
+        extension = suffixes[0].lower()
+        if extension != ".mat":
+            raise FileError(f"File '{fname}' should have a .mat extension.")
+
+        # Handle m_dict in kwargs
+        if "m_dict" in kwargs:
+            raise ValueError(
+                "'m_dict' is automatically determined by 'to_mat_file', can't pass to 'savemat'."
+            )
+
+        # Get SimData object as dictionary
+        sim_dict = self.model_dump()
+
+        # set long field names true by default, otherwise it won't save fields with > 31 characters
+        if "long_field_names" not in kwargs:
+            kwargs["long_field_names"] = True
+
+        # Remove NoneType values from dict
+        # Based on the approach discussed in https://github.com/scipy/scipy/issues/3488
+        modified_sim_dict = replace_values(sim_dict, None, [])
+
+        try:
+            from scipy.io import savemat
+
+            savemat(fname, modified_sim_dict, **kwargs)
+        except Exception as e:
+            raise ValueError(
+                "Could not save supplied simulation data to file. As this is an experimental "
+                "feature, we may not be able to support the contents of your dataset. If you "
+                "receive this error, please feel free to raise an issue on our front end "
+                "repository so we can investigate."
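``to_mat_file`` above swaps out ``None`` values before saving because ``scipy.io.savemat`` cannot serialize them. A rough, self-contained stand-in for the role ``tidy3d.components.file_util.replace_values`` plays here (the real helper's signature and behavior may differ):

```python
from typing import Any


def replace_values(obj: Any, old: Any, new: Any) -> Any:
    """Recursively replace occurrences of `old` (e.g. None) with `new` (e.g. []).

    Identity comparison is used so that None is matched safely even when the
    nested values include numpy arrays (where `==` is elementwise).
    """
    if obj is old:
        return new
    if isinstance(obj, dict):
        return {key: replace_values(val, old, new) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(replace_values(val, old, new) for val in obj)
    return obj
```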
+ ) from e diff --git a/tidy3d/components/base_sim/monitor.py b/tidy3d/components/base_sim/monitor.py index 061e08dc94..85a24caf67 100644 --- a/tidy3d/components/base_sim/monitor.py +++ b/tidy3d/components/base_sim/monitor.py @@ -3,22 +3,25 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import numpy as np -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import cached_property from tidy3d.components.geometry.base import Box -from tidy3d.components.types import ArrayFloat1D, Axis, Numpy from tidy3d.components.validators import _warn_unsupported_traced_argument -from tidy3d.components.viz import PlotParams, plot_params_monitor +from tidy3d.components.viz import plot_params_monitor + +if TYPE_CHECKING: + from tidy3d.components.types import ArrayFloat1D, Axis + from tidy3d.components.viz import PlotParams class AbstractMonitor(Box, ABC): """Abstract base class for steady-state monitors.""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for monitor.", min_length=1, @@ -60,20 +63,20 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: Number of bytes to be stored in monitor. """ - def downsample(self, arr: Numpy, axis: Axis) -> Numpy: + def downsample(self, arr: np.ndarray, axis: Axis) -> np.ndarray: """Downsample a 1D array making sure to keep the first and last entries, based on the spatial interval defined for the ``axis``. Parameters ---------- - arr : Numpy + arr : np.ndarray A 1D array of arbitrary type. axis : Axis Axis for which to select the interval_space defined for the monitor. Returns ------- - Numpy + np.ndarray Downsampled array. """ diff --git a/tidy3d/components/base_sim/simulation.py b/tidy3d/components/base_sim/simulation.py index 7f64c9d545..f141db9894 100644 --- a/tidy3d/components/base_sim/simulation.py +++ b/tidy3d/components/base_sim/simulation.py @@ -3,43 +3,40 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import autograd.numpy as anp -import pydantic.v1 as pd +from pydantic import Field, model_validator -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import cached_property from tidy3d.components.geometry.base import Box from tidy3d.components.medium import Medium, MediumType3D from tidy3d.components.scene import Scene from tidy3d.components.structure import Structure -from tidy3d.components.types import ( - TYPE_TAG_STR, - Ax, - Axis, - Bound, - LengthUnit, - PriorityMode, - Symmetry, -) +from tidy3d.components.types import TYPE_TAG_STR, LengthUnit, PriorityMode, Symmetry from tidy3d.components.validators import ( _warn_unsupported_traced_argument, assert_objects_in_sim_bounds, assert_unique_names, ) -from tidy3d.components.viz import PlotParams, add_ax_if_none, equal_aspect, plot_params_symmetry +from tidy3d.components.viz import add_ax_if_none, equal_aspect, plot_params_symmetry from tidy3d.exceptions import Tidy3dKeyError from tidy3d.log import log from tidy3d.version import __version__ -from .monitor import AbstractMonitor +if TYPE_CHECKING: + from tidy3d.compat import Self + from tidy3d.components.types import Ax, Axis, Bound + from tidy3d.components.viz import PlotParams + + from .monitor import AbstractMonitor class AbstractSimulation(Box, ABC): """Base class for simulation classes of different solvers.""" - medium: MediumType3D = 
pd.Field( - Medium(), + medium: MediumType3D = Field( + default_factory=Medium, title="Background Medium", description="Background medium of simulation, defaults to vacuum if not specified.", discriminator=TYPE_TAG_STR, @@ -48,7 +45,7 @@ class AbstractSimulation(Box, ABC): Background medium of simulation, defaults to vacuum if not specified. """ - structures: tuple[Structure, ...] = pd.Field( + structures: tuple[Structure, ...] = Field( (), title="Structures", description="Tuple of structures present in simulation. " @@ -77,7 +74,7 @@ class AbstractSimulation(Box, ABC): ) """ - symmetry: tuple[Symmetry, Symmetry, Symmetry] = pd.Field( + symmetry: tuple[Symmetry, Symmetry, Symmetry] = Field( (0, 0, 0), title="Symmetries", description="Tuple of integers defining reflection symmetry across a plane " @@ -85,37 +82,37 @@ class AbstractSimulation(Box, ABC): "at the simulation center of each axis, respectively. ", ) - sources: tuple[None, ...] = pd.Field( + sources: tuple[None, ...] = Field( (), title="Sources", description="Sources in the simulation.", ) - boundary_spec: None = pd.Field( + boundary_spec: Literal[None] = Field( None, title="Boundaries", description="Specification of boundary conditions.", ) - monitors: tuple[None, ...] = pd.Field( + monitors: tuple[None, ...] = Field( (), title="Monitors", description="Monitors in the simulation. ", ) - grid_spec: None = pd.Field( + grid_spec: Literal[None] = Field( None, title="Grid Specification", description="Specifications for the simulation grid.", ) - version: str = pd.Field( + version: str = Field( __version__, title="Version", description="String specifying the front end version number.", ) - plot_length_units: Optional[LengthUnit] = pd.Field( + plot_length_units: Optional[LengthUnit] = Field( "μm", title="Plot Units", description="When set to a supported ``LengthUnit``, " @@ -123,7 +120,7 @@ class AbstractSimulation(Box, ABC): "include the desired unit specifier in labels.", ) - structure_priority_mode: PriorityMode = pd.Field( + structure_priority_mode: PriorityMode = Field( "equal", title="Structure Priority Setting", description="This field only affects structures of `priority=None`. 
" @@ -134,17 +131,19 @@ class AbstractSimulation(Box, ABC): """ Validating setup """ - @pd.root_validator(pre=True) - def _update_simulation(cls, values): + @model_validator(mode="before") + @classmethod + def _update_simulation(cls, data: dict[str, Any]) -> dict[str, Any]: """Update the simulation if it is an earlier version.""" - # dummy upgrade of version number # this should be overriden by each simulation class if needed - current_version = values.get("version") + if not hasattr(data, "get"): + return data + current_version = data.get("version") if current_version != __version__ and current_version is not None: log.warning(f"updating {cls.__name__} from {current_version} to {__version__}") - values["version"] = __version__ - return values + data["version"] = __version__ + return data # make sure all names are unique _unique_monitor_names = assert_unique_names("monitors") @@ -157,20 +156,19 @@ def _update_simulation(cls, values): _warn_traced_center = _warn_unsupported_traced_argument("center") _warn_traced_size = _warn_unsupported_traced_argument("size") - @pd.validator("structures", always=True) - @skip_if_fields_missing(["size", "center"]) - def _structures_not_at_edges(cls, val, values): + @model_validator(mode="after") + def _structures_not_at_edges(self) -> Self: """Warn if any structures lie at the simulation boundaries.""" - if val is None: - return val + if self.structures is None: + return self - sim_box = Box(size=values.get("size"), center=values.get("center")) + sim_box = Box(size=self.size, center=self.center) sim_bound_min, sim_bound_max = sim_box.bounds sim_bounds = list(sim_bound_min) + list(sim_bound_max) with log as consolidated_logger: - for istruct, structure in enumerate(val): + for istruct, structure in enumerate(self.structures): struct_bound_min, struct_bound_max = structure.geometry.bounds struct_bounds = list(struct_bound_min) + list(struct_bound_max) @@ -185,13 +183,12 @@ def _structures_not_at_edges(cls, val, values): ) continue - return val - - """ Post-init validators """ + return self - def _post_init_validators(self) -> None: - """Call validators taking z`self` that get run after init.""" + @model_validator(mode="after") + def _validate_scene(self) -> Self: _ = self.scene + return self def validate_pre_upload(self) -> None: """Validate the fully initialized simulation is ok for upload to our servers.""" @@ -269,9 +266,9 @@ def plot( Opacity of the monitors. If ``None``, uses Tidy3d default. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill_structures : bool = True Whether to fill structures with color or just draw outlines. @@ -325,9 +322,9 @@ def plot_sources( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. 
alpha : float = None Opacity of the sources, If ``None`` uses Tidy3d default. @@ -373,9 +370,9 @@ def plot_monitors( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. alpha : float = None Opacity of the sources, If ``None`` uses Tidy3d default. @@ -420,9 +417,9 @@ def plot_symmetries( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. @@ -532,9 +529,9 @@ def plot_structures( position of plane in z direction, only one of x, y, z must be specified to define plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill : bool = True Whether to fill structures with color or just draw outlines. @@ -591,9 +588,9 @@ def plot_structures_eps( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -657,9 +654,9 @@ def plot_structures_heat_conductivity( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -703,7 +700,7 @@ def from_scene(cls, scene: Scene, **kwargs: Any) -> AbstractSimulation: **kwargs, ) - def plot_3d(self, width=800, height=800) -> None: + def plot_3d(self, width: float = 800, height: float = 800) -> None: """Render 3D plot of ``AbstractSimulation`` (in jupyter notebook only). 
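Similarly, the ``_validate_scene`` hook above replaces the old ``_post_init_validators`` mechanism with a plain after-validator that forces evaluation of a derived property, so inconsistencies surface at construction time. A minimal sketch, with a toy property standing in for the real ``self.scene``:

```python
from pydantic import BaseModel, model_validator


class WithScene(BaseModel):
    structures: tuple[str, ...] = ()

    @property
    def scene(self) -> str:
        # stand-in for the real derived Scene object
        return ",".join(self.structures)

    @model_validator(mode="after")
    def _validate_scene(self) -> "WithScene":
        _ = self.scene  # raises during construction if the property is inconsistent
        return self
```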
Parameters ---------- diff --git a/tidy3d/components/base_sim/source.py b/tidy3d/components/base_sim/source.py index f1630c41c0..277c4e6de5 100644 --- a/tidy3d/components/base_sim/source.py +++ b/tidy3d/components/base_sim/source.py @@ -1,23 +1,10 @@ -"""Abstract base for classes that define simulation sources.""" +"""Compatibility shim for :mod:`tidy3d._common.components.base_sim.source`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod - -import pydantic.v1 as pydantic - -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.validators import validate_name_str -from tidy3d.components.viz import PlotParams +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class AbstractSource(Tidy3dBaseModel, ABC): - """Abstract base class for all sources.""" - - name: str = pydantic.Field(None, title="Name", description="Optional name for the source.") - - @abstractmethod - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Source object.""" - - _name_validator = validate_name_str() +from tidy3d._common.components.base_sim.source import ( + AbstractSource, +) diff --git a/tidy3d/components/bc_placement.py b/tidy3d/components/bc_placement.py index 1f48a94fbf..e513bc3646 100644 --- a/tidy3d/components/bc_placement.py +++ b/tidy3d/components/bc_placement.py @@ -5,8 +5,9 @@ from abc import ABC from typing import Union -import pydantic.v1 as pd +from pydantic import Field, field_validator +from tidy3d.components.types.base import discriminated_union from tidy3d.exceptions import SetupError from .base import Tidy3dBaseModel @@ -25,7 +26,7 @@ class StructureBoundary(AbstractBCPlacement): >>> bc_placement = StructureBoundary(structure="box") """ - structure: str = pd.Field( + structure: str = Field( title="Structure Name", description="Name of the structure.", ) @@ -39,13 +40,14 @@ class StructureStructureInterface(AbstractBCPlacement): >>> bc_placement = StructureStructureInterface(structures=["box", "sphere"]) """ - structures: tuple[str, str] = pd.Field( + structures: tuple[str, str] = Field( title="Structures", description="Names of two structures.", ) - @pd.validator("structures", always=True) - def unique_names(cls, val): + @field_validator("structures") + @classmethod + def unique_names(cls, val: tuple[str, str]) -> tuple[str, str]: """Error if the same structure is provided twice""" if val[0] == val[1]: raise SetupError( @@ -62,13 +64,14 @@ class MediumMediumInterface(AbstractBCPlacement): >>> bc_placement = MediumMediumInterface(mediums=["dieletric", "metal"]) """ - mediums: tuple[str, str] = pd.Field( + mediums: tuple[str, str] = Field( title="Mediums", description="Names of two mediums.", ) - @pd.validator("mediums", always=True) - def unique_names(cls, val): + @field_validator("mediums") + @classmethod + def unique_names(cls, val: tuple[str, str]) -> tuple[str, str]: """Error if the same structure is provided twice""" if val[0] == val[1]: raise SetupError("The same medium is provided twice in 'MediumMediumInterface'.") @@ -83,7 +86,7 @@ class SimulationBoundary(AbstractBCPlacement): >>> bc_placement = SimulationBoundary(surfaces=["x-", "x+"]) """ - surfaces: tuple[BoxSurface, ...] = pd.Field( + surfaces: tuple[BoxSurface, ...] 
= Field( ("x-", "x+", "y-", "y+", "z-", "z+"), title="Surfaces", description="Surfaces of simulation domain where to apply boundary conditions.", @@ -98,22 +101,24 @@ class StructureSimulationBoundary(AbstractBCPlacement): >>> bc_placement = StructureSimulationBoundary(structure="box", surfaces=["y-", "y+"]) """ - structure: str = pd.Field( + structure: str = Field( title="Structure Name", description="Name of the structure.", ) - surfaces: tuple[BoxSurface, ...] = pd.Field( + surfaces: tuple[BoxSurface, ...] = Field( ("x-", "x+", "y-", "y+", "z-", "z+"), title="Surfaces", description="Surfaces of simulation domain where to apply boundary conditions.", ) -BCPlacementType = Union[ - StructureBoundary, - StructureStructureInterface, - MediumMediumInterface, - SimulationBoundary, - StructureSimulationBoundary, -] +BCPlacementType = discriminated_union( + Union[ + StructureBoundary, + StructureStructureInterface, + MediumMediumInterface, + SimulationBoundary, + StructureSimulationBoundary, + ] +) diff --git a/tidy3d/components/beam.py b/tidy3d/components/beam.py index ae5ffd9b56..c6349ebd3f 100644 --- a/tidy3d/components/beam.py +++ b/tidy3d/components/beam.py @@ -4,10 +4,10 @@ from __future__ import annotations from abc import abstractmethod -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from tidy3d.constants import C_0, ETA_0, HERTZ, MICROMETER, RADIAN @@ -19,16 +19,21 @@ from .medium import Medium, MediumType from .monitor import FieldMonitor from .source.field import FixedAngleSpec, FixedInPlaneKSpec -from .types import TYPE_TAG_STR, Direction, FreqArray, Numpy +from .types import TYPE_TAG_STR, Direction, FreqArray from .validators import assert_plane, warn_backward_waist_distance +if TYPE_CHECKING: + from typing import Literal + + from numpy.typing import NDArray + DEFAULT_RESOLUTION = 200 class BeamProfile(Box): """Base class for handling analytic beams.""" - resolution: float = pd.Field( + resolution: float = Field( DEFAULT_RESOLUTION, title="Sampling resolution", description="Sampling resolution in the tangential directions of the beam (defines a " @@ -36,27 +41,26 @@ class BeamProfile(Box): units=MICROMETER, ) - freqs: FreqArray = pd.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="List of frequencies at which the beam is sampled.", units=HERTZ, ) - background_medium: MediumType = pd.Field( - Medium(), + background_medium: MediumType = Field( + default_factory=Medium, title="Background Medium", description="Background medium in which the beam is embedded.", ) - angle_theta: float = pd.Field( + angle_theta: float = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the normal axis.", units=RADIAN, ) - angle_phi: float = pd.Field( + angle_phi: float = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -64,7 +68,7 @@ class BeamProfile(Box): units=RADIAN, ) - pol_angle: float = pd.Field( + pol_angle: float = Field( 0.0, title="Polarization Angle", description="Specifies the angle between the electric field polarization of the " @@ -78,7 +82,7 @@ class BeamProfile(Box): units=RADIAN, ) - direction: Direction = pd.Field( + direction: Direction = Field( "+", title="Direction", description="Specifies propagation in the positive or negative direction of the normal " @@ -132,7 +136,9 @@ def field_data(self) -> 
FieldData: return data_raw.updated_copy(**fields_norm) - def _field_data_on_grid(self, grid: Grid, background_n: Numpy, colocate=True) -> dict: + def _field_data_on_grid( + self, grid: Grid, background_n: NDArray, colocate: bool = True + ) -> dict[str, ScalarFieldDataArray]: """Compute the field data for each field component on a grid for the beam. A dictionary of the scalar field data arrays is returned, not yet packaged as ``FieldData``. """ @@ -166,14 +172,14 @@ def _field_data_on_grid(self, grid: Grid, background_n: Numpy, colocate=True) -> return scalar_fields @abstractmethod - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: NDArray, background_n: float) -> NDArray: """Scalar field corresponding to the analytic beam in coordinate system such that the propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is computed on an unstructured array ``points`` of shape ``(3, ...)``.""" def analytic_beam_z_normal( - self, points: Numpy, background_n: float, field: Literal["E", "H"] - ) -> Numpy: + self, points: NDArray, background_n: float, field: Literal["E", "H"] + ) -> NDArray: """Analytic beam with all the beam parameters but assuming ``z`` as the normal axis.""" # Add a frequency dimension to points @@ -213,12 +219,12 @@ def analytic_beam_z_normal( def analytic_beam( self, - x: Numpy, - y: Numpy, - z: Numpy, + x: NDArray, + y: NDArray, + z: NDArray, background_n: float, field: Literal["E", "H"], - ) -> Numpy: + ) -> NDArray: """Sample the analytic beam fields on a cartesian grid of points in x, y, z.""" # Make a meshgrid @@ -242,13 +248,13 @@ def analytic_beam( # Reshape to (3, Nx, Ny, Nz, num_freqs) return np.reshape(field_vals, (3, Nx, Ny, Nz, len(self.freqs))) - def _rotate_points_z(self, points: Numpy, background_n: Numpy) -> Numpy: + def _rotate_points_z(self, points: NDArray, background_n: NDArray) -> NDArray: """Rotate points to new coordinates where z is the propagation axis.""" points_prop_z = self.rotate_points(points, [0, 0, 1], -self.angle_phi) points_prop_z = self.rotate_points(points_prop_z, [0, 1, 0], -self.angle_theta) return points_prop_z - def _inverse_rotate_field_vals_z(self, field_vals: Numpy, background_n: Numpy) -> Numpy: + def _inverse_rotate_field_vals_z(self, field_vals: NDArray, background_n: NDArray) -> NDArray: """Rotate field values from coordinates where z is the propagation axis to angled coordinates.""" field_vals = self.rotate_points(field_vals, [0, 1, 0], self.angle_theta) @@ -260,17 +266,19 @@ class PlaneWaveBeamProfile(BeamProfile): """Component for constructing plane wave beam data. The normal direction is implicitly defined by the ``size`` parameter. - See also :class:`.PlaneWave`. + See Also + -------- + :class:`.PlaneWave` """ - angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = pd.Field( + angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = Field( FixedAngleSpec(), title="Angular Dependence Specification", description="Specification of plane wave propagation direction dependence on wavelength.", discriminator=TYPE_TAG_STR, ) - as_fixed_angle_source: bool = pd.Field( + as_fixed_angle_source: bool = Field( False, title="Fixed Angle Flag", description="Fixed angle flag. 
Only used internally when computing source beams for " @@ -278,7 +286,7 @@ class PlaneWaveBeamProfile(BeamProfile): "switch between waves with fixed angle and fixed in-plane k.", ) - angle_theta_frequency: Optional[float] = pd.Field( + angle_theta_frequency: Optional[float] = Field( None, title="Frequency at Which Angle Theta is Defined", description="Frequency for which ``angle_theta`` is set. This only has an effect for " @@ -287,18 +295,18 @@ class PlaneWaveBeamProfile(BeamProfile): ) @property - def _angle_theta_frequency(self): + def _angle_theta_frequency(self) -> float: if not self.angle_theta_frequency: return np.mean(self.freqs) return self.angle_theta_frequency - def in_plane_k(self, background_n: float): + def in_plane_k(self, background_n: float) -> list[float]: """In-plane wave vector. Only the real part is taken so the beam has no in-plane decay.""" k0 = 2 * np.pi * self._angle_theta_frequency / C_0 * background_n k_in_plane = k0.real * np.sin(self.angle_theta) return [k_in_plane * np.cos(self.angle_phi), k_in_plane * np.sin(self.angle_phi)] - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: NDArray, background_n: float) -> NDArray: """Scalar field for plane wave. Scalar field corresponding to the analytic beam in coordinate system such that the propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is @@ -313,14 +321,14 @@ def scalar_field(self, points: Numpy, background_n: float) -> Numpy: kz *= np.cos(self.angle_theta) return np.exp(1j * points[2] * kz) - def _angle_theta_actual(self, background_n: Numpy) -> Numpy: + def _angle_theta_actual(self, background_n: NDArray) -> NDArray: """Compute the frequency-dependent actual propagation angle theta.""" k0 = 2 * np.pi * np.array(self.freqs) / C_0 * background_n kx, ky = self.in_plane_k(background_n) k_perp = np.sqrt(kx**2 + ky**2) return np.real(np.arcsin(k_perp / k0)) * np.sign(self.angle_theta) - def _rotate_points_z(self, points: Numpy, background_n: Numpy) -> Numpy: + def _rotate_points_z(self, points: NDArray, background_n: NDArray) -> NDArray: """Rotate points to new coordinates where z is the propagation axis.""" if self.as_fixed_angle_source: # For fixed-angle, we do not rotate the points @@ -334,7 +342,7 @@ def _rotate_points_z(self, points: Numpy, background_n: Numpy) -> Numpy: return points return super()._rotate_points_z(points, background_n) - def _inverse_rotate_field_vals_z(self, field_vals: Numpy, background_n: Numpy) -> Numpy: + def _inverse_rotate_field_vals_z(self, field_vals: NDArray, background_n: NDArray) -> NDArray: """Rotate field values from coordinates where z is the propagation axis to angled coordinates. Special handling is needed if fixed in-plane k wave.""" if isinstance(self.angular_spec, FixedInPlaneKSpec): @@ -353,17 +361,19 @@ class GaussianBeamProfile(BeamProfile): """Component for constructing Gaussian beam data. The normal direction is implicitly defined by the ``size`` parameter. - See also :class:`.GaussianBeam`. + See Also + -------- + :class:`.GaussianBeam` """ - waist_radius: pd.PositiveFloat = pd.Field( + waist_radius: PositiveFloat = Field( 1.0, title="Waist Radius", description="Radius of the beam at the waist.", units=MICROMETER, ) - waist_distance: float = pd.Field( + waist_distance: float = Field( 0.0, title="Waist Distance", description="Distance from the beam waist along the propagation direction. 
" @@ -376,14 +386,14 @@ class GaussianBeamProfile(BeamProfile): ) _backward_waist_warning = warn_backward_waist_distance("waist_distance") - def beam_params(self, z: Numpy, k0: Numpy) -> tuple[Numpy, Numpy, Numpy]: + def beam_params(self, z: NDArray, k0: NDArray) -> tuple[NDArray, NDArray, NDArray]: """Compute the parameters needed to evaluate a Gaussian beam at z. Parameters ---------- - z : Numpy + z : np.ndarray Axial distance from the beam center. - k0 : Numpy + k0 : np.ndarray Wave vector magnitude. """ @@ -398,7 +408,7 @@ def beam_params(self, z: Numpy, k0: Numpy) -> tuple[Numpy, Numpy, Numpy]: psi_g = np.arctan((z + z_0) / z_r) - np.arctan(z_0 / z_r) return w_z, inv_r_z, psi_g - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: NDArray, background_n: float) -> NDArray: """Scalar field for Gaussian beam. Scalar field corresponding to the analytic beam in coordinate system such that the propagation direction is z and the ``E``-field is entirely ``x``-polarized. The field is @@ -420,17 +430,19 @@ class AstigmaticGaussianBeamProfile(BeamProfile): """Component for constructing astigmatic Gaussian beam data. The normal direction is implicitly defined by the ``size`` parameter. - See also :class:`.AstigmaticGaussianBeam`. + See Also + -------- + :class:`.AstigmaticGaussianBeam` """ - waist_sizes: tuple[pd.PositiveFloat, pd.PositiveFloat] = pd.Field( + waist_sizes: tuple[PositiveFloat, PositiveFloat] = Field( (1.0, 1.0), title="Waist sizes", description="Size of the beam at the waist in the local x and y directions.", units=MICROMETER, ) - waist_distances: tuple[float, float] = pd.Field( + waist_distances: tuple[float, float] = Field( (0.0, 0.0), title="Waist distances", description="Distance to the beam waist along the propagation direction " @@ -443,14 +455,14 @@ class AstigmaticGaussianBeamProfile(BeamProfile): ) _backward_waist_warning = warn_backward_waist_distance("waist_distances") - def beam_params(self, z: Numpy, k0: Numpy) -> tuple[Numpy, Numpy, Numpy, Numpy]: + def beam_params(self, z: NDArray, k0: NDArray) -> tuple[NDArray, NDArray, NDArray, NDArray]: """Compute the parameters needed to evaluate an astigmatic Gaussian beam at z. Parameters ---------- - z : Numpy + z : np.ndarray Axial distance from the beam center. - k0 : Numpy + k0 : np.ndarray Wave vector magnitude. """ @@ -470,7 +482,7 @@ def beam_params(self, z: Numpy, k0: Numpy) -> tuple[Numpy, Numpy, Numpy, Numpy]: return w_0, w_z, inv_r_z, psi_g - def scalar_field(self, points: Numpy, background_n: float) -> Numpy: + def scalar_field(self, points: NDArray, background_n: float) -> NDArray: """ Scalar field for astigmatic Gaussian beam. 
Scalar field corresponding to the analytic beam in coordinate system such that the diff --git a/tidy3d/components/boundary.py b/tidy3d/components/boundary.py index 2a586278a0..12172c27cf 100644 --- a/tidy3d/components/boundary.py +++ b/tidy3d/components/boundary.py @@ -3,43 +3,58 @@ from __future__ import annotations from abc import ABC -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + field_validator, + model_validator, +) from tidy3d.components.validators import _assert_min_freq, assert_plane -from tidy3d.components.viz import ( - ARROW_ALPHA, - ARROW_COLOR_ABSORBER, - PlotParams, - plot_params_absorber, -) +from tidy3d.components.viz import ARROW_ALPHA, ARROW_COLOR_ABSORBER, plot_params_absorber from tidy3d.constants import C_0, CONDUCTIVITY, EPSILON_0, HERTZ, MU_0, PML_SIGMA from tidy3d.exceptions import DataError, SetupError, ValidationError from tidy3d.log import log -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .geometry.base import Box from .medium import Medium from .mode_spec import ModeSpec -from .monitor import ModeMonitor, ModeSolverMonitor from .source.field import TFSF, GaussianBeam, ModeSource, PlaneWave -from .types import TYPE_TAG_STR, Ax, Axis, Complex, Direction, FreqBound +from .types import TYPE_TAG_STR, Direction, FreqBound from .types.mode_spec import ModeSpecType +if TYPE_CHECKING: + from typing import Callable + + from numpy.typing import NDArray + + from tidy3d.compat import Self + from tidy3d.components.viz import PlotParams + + from .monitor import ModeMonitor, ModeSolverMonitor + from .types import Ax, Axis, Complex + MIN_NUM_PML_LAYERS = 6 MIN_NUM_STABLE_PML_LAYERS = 6 MIN_NUM_ABSORBER_LAYERS = 6 -def warn_num_layers_factory(min_num_layers: int, descr: str): +def warn_num_layers_factory( + min_num_layers: int, descr: str +) -> Callable[[type[AbsorberSpec], int], int]: """Several similar classes defined have a ``num_layers`` data member, and they generate similar warning messages when ``num_layers`` is too small. This function creates a pydantic validator which can be shared with all of these classes to create these warning messages.""" - @pd.validator("num_layers", allow_reuse=True, always=True) - def _warn_num_layers(cls, val): + @field_validator("num_layers") + @classmethod + def _warn_num_layers(cls: type[AbsorberSpec], val: int) -> int: if val < min_num_layers: cls_name = cls.__name__ log.warning( @@ -66,7 +81,11 @@ def _warn_num_layers(cls, val): class BoundaryEdge(ABC, Tidy3dBaseModel): """Electromagnetic boundary condition at a domain edge.""" - name: str = pd.Field(None, title="Name", description="Optional unique name for boundary.") + name: Optional[str] = Field( + None, + title="Name", + description="Optional unique name for boundary.", + ) # PBC keyword @@ -74,7 +93,7 @@ class Periodic(BoundaryEdge): """Periodic boundary condition class.""" @property - def bloch_vec(self): + def bloch_vec(self) -> int: """Periodic boundaries are effectively Bloch boundaries with ``bloch_vec == 0``. In practice, periodic boundaries do not force the use of complex fields, while Bloch boundaries do, even with ``bloch_vec == 0``. Thus, it is more efficient to use periodic. @@ -101,7 +120,7 @@ class ABCBoundary(AbstractABCBoundary): See, for example, John B. 
Schneider, Understanding the Finite-Difference Time-Domain Method, Chapter 6. """ - permittivity: Optional[float] = pd.Field( + permittivity: Optional[float] = Field( None, title="Effective Permittivity", description="Effective permittivity for determining propagation constant. " @@ -110,7 +129,7 @@ class ABCBoundary(AbstractABCBoundary): ge=1.0, ) - conductivity: Optional[pd.NonNegativeFloat] = pd.Field( + conductivity: Optional[NonNegativeFloat] = Field( None, title="Effective Conductivity", description="Effective conductivity for determining propagation constant. " @@ -119,17 +138,15 @@ class ABCBoundary(AbstractABCBoundary): units=CONDUCTIVITY, ) - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def _conductivity_only_with_float_permittivity(cls, val, values): + @model_validator(mode="after") + def _conductivity_only_with_float_permittivity(self) -> Self: """Validate that conductivity can be provided only with float permittivity.""" - perm = values["permittivity"] - if val is not None and perm is None: + if self.conductivity is not None and self.permittivity is None: raise ValidationError( "Field 'conductivity' in 'ABCBoundary' can only be provided " "simultaneously with 'permittivity'." ) - return val + return self class BroadbandModeABCFitterParam(Tidy3dBaseModel): @@ -144,7 +161,7 @@ class BroadbandModeABCFitterParam(Tidy3dBaseModel): >>> fitter_param = BroadbandModeABCFitterParam(max_num_poles=5, tolerance_rms=1e-4, frequency_sampling_points=10) """ - max_num_poles: int = pd.Field( + max_num_poles: int = Field( DEFAULT_BROADBAND_MODE_ABC_NUM_POLES, title="Maximal Number Of Poles", description="Maximal number of poles in complex-conjugate pole residue model for " @@ -153,13 +170,13 @@ class BroadbandModeABCFitterParam(Tidy3dBaseModel): le=MAX_BROADBAND_MODE_ABC_NUM_POLES, ) - tolerance_rms: pd.NonNegativeFloat = pd.Field( + tolerance_rms: NonNegativeFloat = Field( DEFAULT_BROADBAND_MODE_ABC_FITTER_TOLERANCE, title="Fitting Tolerance", description="Tolerance in fitting the mode propagation index.", ) - frequency_sampling_points: int = pd.Field( + frequency_sampling_points: int = Field( DEFAULT_BROADBAND_MODE_ABC_NUM_FREQS, title="Number Of Frequencies", description="Number of sampling frequencies used in fitting the mode propagation index.", @@ -179,21 +196,21 @@ class BroadbandModeABCSpec(Tidy3dBaseModel): >>> broadband_mode_abc_spec = BroadbandModeABCSpec(frequency_range=(100e12, 120e12), fit_param=BroadbandModeABCFitterParam()) """ - frequency_range: FreqBound = pd.Field( - ..., + frequency_range: FreqBound = Field( title="Frequency Range", description="Frequency range for the broadband mode absorption boundary conditions.", units=(HERTZ, HERTZ), ) - fit_param: BroadbandModeABCFitterParam = pd.Field( + fit_param: BroadbandModeABCFitterParam = Field( DEFAULT_BROADBAND_MODE_ABC_FITTER_PARAMS, title="Fitting Parameters For Broadband Mode Absorption Boundary Conditions", description="Parameters for fitting the mode propagation index over the frequency range using pole-residue pair model.", ) - @pd.validator("frequency_range", always=True) - def validate_frequency_range(cls, val, values): + @field_validator("frequency_range", mode="after") + @classmethod + def validate_frequency_range(cls, val: FreqBound) -> FreqBound: """Validate that max frequency is greater than min frequency.""" _assert_min_freq(val[0], "min frequency") if val[1] <= val[0]: @@ -233,7 +250,7 @@ def from_wavelength_range( ) @property - def _frequency_grid(self) -> np.ndarray: 
+ def _frequency_grid(self) -> NDArray: """Frequency grid for the broadband mode absorption boundary conditions. Propagation constant is sampled at these frequencies and fitted using pole-residue pair model. """ @@ -247,14 +264,14 @@ def _frequency_grid(self) -> np.ndarray: class ModeABCBoundary(AbstractABCBoundary): """One-way wave equation absorbing boundary conditions for absorbing a waveguide mode.""" - mode_spec: ModeSpecType = pd.Field( + mode_spec: ModeSpecType = Field( DEFAULT_MODE_SPEC_MODE_ABC, title="Mode Specification", description="Parameters that determine the modes computed by the mode solver.", discriminator=TYPE_TAG_STR, ) - mode_index: pd.NonNegativeInt = pd.Field( + mode_index: NonNegativeInt = Field( 0, title="Mode Index", description="Index into the collection of modes returned by mode solver. " @@ -263,20 +280,21 @@ class ModeABCBoundary(AbstractABCBoundary): "``num_modes`` in the solver will be set to ``mode_index + 1``.", ) - freq_spec: Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = pd.Field( + freq_spec: Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = Field( None, title="Absorption Frequency Specification", description="Specifies the frequency at which field is absorbed. If ``None``, then the central frequency of the source is used. If ``BroadbandModeABCSpec``, then the field is absorbed over the specified frequency range.", ) - plane: Box = pd.Field( + plane: Box = Field( ..., title="Plane", description="Cross-sectional plane in which the absorbed mode will be computed.", ) - @pd.validator("plane", always=True) - def is_plane(cls, val): + @field_validator("plane") + @classmethod + def is_plane(cls, val: Box) -> Box: """Raise validation error if not planar.""" if val.size.count(0.0) != 1: raise ValidationError( @@ -288,15 +306,15 @@ def is_plane(cls, val): def from_source( cls, source: ModeSource, - freq_spec: Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None, - ) -> ModeABCBoundary: + freq_spec: Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None, + ) -> Self: """Instantiate from a ``ModeSource``. Parameters ---------- source : :class:`ModeSource` Mode source. - freq_spec : Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None + freq_spec : Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None Specifies the frequency at which field is absorbed. If ``None``, then the central frequency of the source is used. If ``BroadbandModeABCSpec``, then the field is absorbed over the specified frequency range. Returns @@ -326,18 +344,18 @@ def from_source( def from_monitor( cls, monitor: Union[ModeMonitor, ModeSolverMonitor], - mode_index: pd.NonNegativeInt = 0, - freq_spec: Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None, - ) -> ModeABCBoundary: + mode_index: NonNegativeInt = 0, + freq_spec: Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None, + ) -> Self: """Instantiate from a ``ModeMonitor`` or ``ModeSolverMonitor``. Parameters ---------- monitor : Union[:class:`ModeMonitor`, :class:`ModeSolverMonitor`] Mode monitor. - mode_index : pd.NonNegativeInt = 0 + mode_index : NonNegativeInt = 0 Mode index. - freq_spec : Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None + freq_spec : Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None Specifies the frequency at which field is absorbed. If ``None``, then the central frequency of the source is used. If ``BroadbandModeABCSpec``, then the field is absorbed over the specified frequency range. 
Returns @@ -365,13 +383,13 @@ class InternalAbsorber(Box): Note that internal absorbers are automatically wrapped in a PEC frame with a backing PEC plate on the non-absorbing side. """ - direction: Direction = pd.Field( + direction: Direction = Field( ..., title="Absorption Direction", description="Indicates which direction of traveling waves are absorbed.", ) - grid_shift: int = pd.Field( + grid_shift: int = Field( 0, title="Absorber Shift", description="Displacement of absorber in the normal positive direction in number of cells. " @@ -379,7 +397,7 @@ class InternalAbsorber(Box): "one can use the same `size` and `center` as for the source and simply set `shift` to 1.", ) - boundary_spec: Union[ModeABCBoundary, ABCBoundary] = pd.Field( + boundary_spec: Union[ModeABCBoundary, ABCBoundary] = Field( ..., title="Boundary Specification", description="Boundary specification for defining effective propagation index in the one-way wave equation.", @@ -388,8 +406,11 @@ class InternalAbsorber(Box): _plane_validator = assert_plane() - @pd.validator("boundary_spec", always=True) - def _must_provide_permittivity(cls, val): + @field_validator("boundary_spec") + @classmethod + def _must_provide_permittivity( + cls, val: Union[ModeABCBoundary, ABCBoundary] + ) -> Union[ModeABCBoundary, ABCBoundary]: """Validate that permittivity is provided for ABCBoundary.""" if isinstance(val, ABCBoundary) and val.permittivity is None: raise ValidationError( @@ -415,7 +436,7 @@ def plot( x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - ax: Ax = None, + ax: Optional[Ax] = None, **patch_kwargs: Any, ) -> Ax: """Plot this absorber.""" @@ -467,8 +488,7 @@ class BlochBoundary(BoundaryEdge): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - bloch_vec: float = pd.Field( - ..., + bloch_vec: float = Field( title="Normalized Bloch vector component", description="Normalized component of the Bloch vector in units of " "2 * pi / (size along dimension) in the background medium, " @@ -482,8 +502,12 @@ def bloch_phase(self) -> Complex: @classmethod def from_source( - cls, source: BlochSourceType, domain_size: float, axis: Axis, medium: Medium = None - ) -> BlochBoundary: + cls, + source: BlochSourceType, + domain_size: float, + axis: Axis, + medium: Optional[Medium] = None, + ) -> Self: """Set the Bloch vector component based on a given angled source and its center frequency. Note that if a broadband angled source is used, only the frequency components near the center frequency will exhibit angled incidence at the expect angle. 
In this case, a @@ -571,20 +595,20 @@ class AbsorberParams(Tidy3dBaseModel): >>> params = AbsorberParams(sigma_order=3, sigma_min=0.0, sigma_max=1.5) """ - sigma_order: pd.NonNegativeInt = pd.Field( + sigma_order: NonNegativeInt = Field( 3, title="Sigma Order", description="Order of the polynomial describing the absorber profile (~dist^sigma_order).", ) - sigma_min: pd.NonNegativeFloat = pd.Field( + sigma_min: NonNegativeFloat = Field( 0.0, title="Sigma Minimum", description="Minimum value of the absorber conductivity.", units=PML_SIGMA, ) - sigma_max: pd.NonNegativeFloat = pd.Field( + sigma_max: NonNegativeFloat = Field( 1.5, title="Sigma Maximum", description="Maximum value of the absorber conductivity.", @@ -600,29 +624,29 @@ class PMLParams(AbsorberParams): >>> params = PMLParams(sigma_order=3, sigma_min=0.0, sigma_max=1.5, kappa_min=0.0) """ - kappa_order: pd.NonNegativeInt = pd.Field( + kappa_order: NonNegativeInt = Field( 3, title="Kappa Order", description="Order of the polynomial describing the PML kappa profile " "(kappa~dist^kappa_order).", ) - kappa_min: pd.NonNegativeFloat = pd.Field(0.0, title="Kappa Minimum", description="") + kappa_min: NonNegativeFloat = Field(0.0, title="Kappa Minimum") - kappa_max: pd.NonNegativeFloat = pd.Field(1.5, title="Kappa Maximum", description="") + kappa_max: NonNegativeFloat = Field(1.5, title="Kappa Maximum") - alpha_order: pd.NonNegativeInt = pd.Field( + alpha_order: NonNegativeInt = Field( 3, title="Alpha Order", description="Order of the polynomial describing the PML alpha profile " "(alpha~dist^alpha_order).", ) - alpha_min: pd.NonNegativeFloat = pd.Field( + alpha_min: NonNegativeFloat = Field( 0.0, title="Alpha Minimum", description="Minimum value of the PML alpha.", units=PML_SIGMA ) - alpha_max: pd.NonNegativeFloat = pd.Field( + alpha_max: NonNegativeFloat = Field( 1.5, title="Alpha Maximum", description="Maximum value of the PML alpha.", units=PML_SIGMA ) @@ -660,14 +684,12 @@ class PMLParams(AbsorberParams): class AbsorberSpec(BoundaryEdge): """Specifies the generic absorber properties along a single dimension.""" - num_layers: int = pd.Field( - ..., + num_layers: int = Field( title="Number of Layers", description="Number of layers of standard PML.", ge=1, ) - parameters: AbsorberParams = pd.Field( - ..., + parameters: AbsorberParams = Field( title="Absorber Parameters", description="Parameters to fine tune the absorber profile and properties.", ) @@ -777,14 +799,14 @@ class PML(AbsorberSpec): """ - num_layers: int = pd.Field( + num_layers: int = Field( 12, title="Number of Layers", description="Number of layers of standard PML.", ge=1, ) - parameters: PMLParams = pd.Field( + parameters: PMLParams = Field( DefaultPMLParameters, title="PML Parameters", description="Parameters of the complex frequency-shifted absorption poles.", @@ -819,14 +841,14 @@ class StablePML(AbsorberSpec): * `Introduction to perfectly matched layer (PML) tutorial `__ """ - num_layers: int = pd.Field( + num_layers: int = Field( 40, title="Number of Layers", description="Number of layers of 'stable' PML.", ge=1, ) - parameters: PMLParams = pd.Field( + parameters: PMLParams = Field( DefaultStablePMLParameters, title="Stable PML Parameters", description="'Stable' parameters of the complex frequency-shifted absorption poles.", @@ -876,14 +898,14 @@ class Absorber(AbsorberSpec): * `How to troubleshoot a diverged FDTD simulation <../../notebooks/DivergedFDTDSimulation.html>`_ """ - num_layers: int = pd.Field( + num_layers: int = Field( 40, title="Number of Layers",
description="Number of layers of absorber to add to + and - boundaries.", ge=1, ) - parameters: AbsorberParams = pd.Field( + parameters: AbsorberParams = Field( DefaultAbsorberParameters, title="Absorber Parameters", description="Adiabatic absorber parameters.", @@ -943,70 +965,68 @@ class Boundary(Tidy3dBaseModel): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - plus: BoundaryEdgeType = pd.Field( - PML(), + plus: BoundaryEdgeType = Field( + default_factory=PML, title="Plus BC", description="Boundary condition on the plus side along a dimension.", - discriminator=TYPE_TAG_STR, ) - minus: BoundaryEdgeType = pd.Field( - PML(), + minus: BoundaryEdgeType = Field( + default_factory=PML, title="Minus BC", description="Boundary condition on the minus side along a dimension.", - discriminator=TYPE_TAG_STR, ) - @pd.root_validator(skip_on_failure=True) - def bloch_on_both_sides(cls, values): + @model_validator(mode="after") + def bloch_on_both_sides(self) -> Self: """Error if a Bloch boundary is applied on only one side.""" - plus = values.get("plus") - minus = values.get("minus") - num_bloch = isinstance(plus, BlochBoundary) + isinstance(minus, BlochBoundary) + num_bloch = isinstance(self.plus, BlochBoundary) + isinstance(self.minus, BlochBoundary) if num_bloch == 1: raise SetupError( "Bloch boundaries must be applied either on both sides or on neither side." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def periodic_with_pml(cls, values): + @model_validator(mode="after") + def periodic_with_pml(self) -> Self: """Error if PBC is specified with a PML.""" - plus = values.get("plus") - minus = values.get("minus") - num_pbc = isinstance(plus, Periodic) + isinstance(minus, Periodic) + num_pbc = isinstance(self.plus, Periodic) + isinstance(self.minus, Periodic) num_pml = isinstance( - plus, (PML, StablePML, Absorber, ABCBoundary, ModeABCBoundary) - ) + isinstance(minus, (PML, StablePML, Absorber, ABCBoundary, ModeABCBoundary)) + self.plus, (PML, StablePML, Absorber, ABCBoundary, ModeABCBoundary) + ) + isinstance(self.minus, (PML, StablePML, Absorber, ABCBoundary, ModeABCBoundary)) if num_pbc == 1 and num_pml == 1: - raise SetupError("Cannot have both PML and PBC along the same dimension.") - return values - - @pd.root_validator(skip_on_failure=True) - def periodic_with_pec_pmc(cls, values): - """If a PBC is specified along with PEC or PMC on the other side, manually set the PBC - to PEC or PMC so that no special treatment of halos is required.""" - plus = values.get("plus") - minus = values.get("minus") + raise SetupError("Cannot have both 'PML' and 'Periodic' along the same dimension.") + return self + @model_validator(mode="after") + def periodic_with_pec_pmc(self) -> Self: + """ + If a PBC is specified along with PEC or PMC on the other side, manually set the PBC + to PEC or PMC so that no special treatment of halos is required. + """ + plus, minus = self.plus, self.minus switched = False + if isinstance(minus, (PECBoundary, PMCBoundary)) and isinstance(plus, Periodic): plus = minus switched = True elif isinstance(plus, (PECBoundary, PMCBoundary)) and isinstance(minus, Periodic): minus = plus switched = True + if switched: - values.update({"plus": plus, "minus": minus}) + object.__setattr__(self, "plus", plus) + object.__setattr__(self, "minus", minus) log.warning( "A periodic boundary condition was specified on the opposite side of a perfect " "electric or magnetic conductor boundary. 
This periodic boundary condition will " "be replaced by the perfect electric or magnetic conductor across from it." ) - return values + + return self @classmethod - def periodic(cls): + def periodic(cls) -> Self: """Periodic boundary specification on both sides along a dimension. Example @@ -1018,7 +1038,7 @@ def periodic(cls): return cls(plus=plus, minus=minus) @classmethod - def bloch(cls, bloch_vec: complex): + def bloch(cls, bloch_vec: complex) -> Self: """Bloch boundary specification on both sides along a dimension. Parameters @@ -1037,8 +1057,12 @@ def bloch(cls, bloch_vec: complex): @classmethod def bloch_from_source( - cls, source: BlochSourceType, domain_size: float, axis: Axis, medium: Medium = None - ): + cls, + source: BlochSourceType, + domain_size: float, + axis: Axis, + medium: Optional[Medium] = None, + ) -> Self: """Bloch boundary specification on both sides along a dimension based on a given source. Parameters @@ -1070,7 +1094,7 @@ def bloch_from_source( return cls(plus=plus, minus=minus) @classmethod - def pec(cls): + def pec(cls) -> Self: """PEC boundary specification on both sides along a dimension. Example @@ -1082,7 +1106,7 @@ def pec(cls): return cls(plus=plus, minus=minus) @classmethod - def pmc(cls): + def pmc(cls) -> Self: """PMC boundary specification on both sides along a dimension. Example @@ -1096,9 +1120,9 @@ def pmc(cls): @classmethod def abc( cls, - permittivity: Optional[pd.PositiveFloat] = None, - conductivity: Optional[pd.NonNegativeFloat] = None, - ): + permittivity: Optional[PositiveFloat] = None, + conductivity: Optional[NonNegativeFloat] = None, + ) -> Self: """ABC boundary specification on both sides along a dimension. Example @@ -1120,9 +1144,9 @@ def mode_abc( cls, plane: Box, mode_spec: ModeSpecType = DEFAULT_MODE_SPEC_MODE_ABC, - mode_index: pd.NonNegativeInt = 0, - freq_spec: Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None, - ): + mode_index: NonNegativeInt = 0, + freq_spec: Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None, + ) -> Self: """One-way wave equation mode ABC boundary specification on both sides along a dimension. Parameters @@ -1131,9 +1155,9 @@ def mode_abc( Cross-sectional plane in which the absorbed mode will be computed. mode_spec: ModeSpecType = ModeSpec() Parameters that determine the modes computed by the mode solver. - mode_index : pd.NonNegativeInt = 0 + mode_index : NonNegativeInt = 0 Mode index. - freq_spec : Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None + freq_spec : Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None Specifies the frequency at which field is absorbed. If ``None``, then the central frequency of the source is used. If ``BroadbandModeABCSpec``, then the field is absorbed over the specified frequency range. Example @@ -1161,15 +1185,15 @@ def mode_abc( def mode_abc_from_source( cls, source: ModeSource, - freq_spec: Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None, - ): + freq_spec: Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None, + ) -> Self: """One-way wave equation mode ABC boundary specification on both sides along a dimension constructed from a mode source. Parameters ---------- source : :class:`ModeSource` Mode source. - freq_spec : Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None + freq_spec : Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None Specifies the frequency at which field is absorbed. If ``None``, then the central frequency of the source is used. 
If ``BroadbandModeABCSpec``, then the field is absorbed over the specified frequency range. Example @@ -1187,9 +1211,9 @@ def mode_abc_from_source( def mode_abc_from_monitor( cls, monitor: Union[ModeMonitor, ModeSolverMonitor], - mode_index: pd.NonNegativeInt = 0, - freq_spec: Optional[Union[pd.PositiveFloat, BroadbandModeABCSpec]] = None, - ): + mode_index: NonNegativeInt = 0, + freq_spec: Optional[Union[PositiveFloat, BroadbandModeABCSpec]] = None, + ) -> Self: """One-way wave equation mode ABC boundary specification on both sides along a dimension constructed from a mode monitor. Example @@ -1211,7 +1235,9 @@ def mode_abc_from_monitor( return cls(plus=plus, minus=minus) @classmethod - def pml(cls, num_layers: pd.NonNegativeInt = 12, parameters: PMLParams = DefaultPMLParameters): + def pml( + cls, num_layers: NonNegativeInt = 12, parameters: PMLParams = DefaultPMLParameters + ) -> Self: """PML boundary specification on both sides along a dimension. Parameters @@ -1231,8 +1257,10 @@ def pml(cls, num_layers: pd.NonNegativeInt = 12, parameters: PMLParams = Default @classmethod def stable_pml( - cls, num_layers: pd.NonNegativeInt = 40, parameters: PMLParams = DefaultStablePMLParameters - ): + cls, + num_layers: NonNegativeInt = 40, + parameters: PMLParams = DefaultStablePMLParameters, + ) -> Self: """Stable PML boundary specification on both sides along a dimension. Parameters @@ -1252,8 +1280,10 @@ def stable_pml( @classmethod def absorber( - cls, num_layers: pd.NonNegativeInt = 40, parameters: PMLParams = DefaultAbsorberParameters - ): + cls, + num_layers: NonNegativeInt = 40, + parameters: PMLParams = DefaultAbsorberParameters, + ) -> Self: """Adiabatic absorber boundary specification on both sides along a dimension. Parameters @@ -1307,30 +1337,38 @@ class BoundarySpec(Tidy3dBaseModel): * `Using FDTD to Compute a Transmission Spectrum `__ """ - x: Boundary = pd.Field( - Boundary(), + x: Boundary = Field( + default_factory=Boundary, title="Boundary condition along x.", description="Boundary condition on the plus and minus sides along the x axis. " "If ``None``, periodic boundaries are applied. Default will change to PML in 2.0 " "so explicitly setting the boundaries is recommended.", ) - y: Boundary = pd.Field( - Boundary(), + y: Boundary = Field( + default_factory=Boundary, title="Boundary condition along y.", description="Boundary condition on the plus and minus sides along the y axis. " "If ``None``, periodic boundaries are applied. Default will change to PML in 2.0 " "so explicitly setting the boundaries is recommended.", ) - z: Boundary = pd.Field( - Boundary(), + z: Boundary = Field( + default_factory=Boundary, title="Boundary condition along z.", description="Boundary condition on the plus and minus sides along the z axis. " "If ``None``, periodic boundaries are applied. Default will change to PML in 2.0 " "so explicitly setting the boundaries is recommended.", ) + @field_validator("x", "y", "z", mode="before") + @classmethod + def dict_to_boundary(cls, v: Any) -> Any: + """Convert dict representation to Boundary object if needed.""" + if isinstance(v, dict) and "plus" in v and "minus" in v: + return Boundary(**v) + return v + def __getitem__(self, field_name: str) -> Boundary: """Get the :class:`Boundary` field by name (``boundary_spec[field_name]``). 
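The ``default_factory`` plus ``mode="before"`` validator combination used in ``BoundarySpec`` above is the standard pydantic v2 replacement for mutable defaults like ``pd.Field(Boundary(), ...)``: ``default_factory`` defers building the default to instantiation time, and the before-mode hook coerces raw dict payloads before field validation runs. A minimal, self-contained sketch of the pattern, assuming only pydantic v2 (``Edge``, ``ToyBoundary``, and ``ToyBoundarySpec`` are hypothetical stand-ins, not the real tidy3d classes):

from typing import Any

from pydantic import BaseModel, Field, field_validator


class Edge(BaseModel):
    name: str = "PML"


class ToyBoundary(BaseModel):
    plus: Edge = Field(default_factory=Edge)
    minus: Edge = Field(default_factory=Edge)


class ToyBoundarySpec(BaseModel):
    x: ToyBoundary = Field(default_factory=ToyBoundary)

    @field_validator("x", mode="before")
    @classmethod
    def dict_to_boundary(cls, v: Any) -> Any:
        # Coerce raw dicts that carry both sides, mirroring BoundarySpec.dict_to_boundary.
        if isinstance(v, dict) and "plus" in v and "minus" in v:
            return ToyBoundary(**v)
        return v


spec = ToyBoundarySpec(x={"plus": {"name": "PEC"}, "minus": {"name": "PEC"}})
assert spec.x.plus.name == "PEC"
assert ToyBoundarySpec().x.minus.name == "PML"  # default_factory builds a fresh instance

Because ``default_factory`` constructs the default only when a model is instantiated, nothing is eagerly built at class-definition time, which is why the migration drops the ``PML()`` and ``Boundary()`` literals.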
@@ -1353,7 +1391,7 @@ def __getitem__(self, field_name: str) -> Boundary: raise DataError(f"field_name '{field_name}' not found") @classmethod - def pml(cls, x: bool = False, y: bool = False, z: bool = False): + def pml(cls, x: bool = False, y: bool = False, z: bool = False) -> Self: """PML along specified directions Parameters @@ -1376,7 +1414,7 @@ def pml(cls, x: bool = False, y: bool = False, z: bool = False): ) @classmethod - def pec(cls, x: bool = False, y: bool = False, z: bool = False): + def pec(cls, x: bool = False, y: bool = False, z: bool = False) -> Self: """PEC along specified directions Parameters @@ -1399,7 +1437,7 @@ def pec(cls, x: bool = False, y: bool = False, z: bool = False): ) @classmethod - def pmc(cls, x: bool = False, y: bool = False, z: bool = False): + def pmc(cls, x: bool = False, y: bool = False, z: bool = False) -> Self: """PMC along specified directions Parameters @@ -1422,7 +1460,7 @@ def pmc(cls, x: bool = False, y: bool = False, z: bool = False): ) @classmethod - def all_sides(cls, boundary: BoundaryEdge): + def all_sides(cls, boundary: BoundaryEdge) -> Self: """Set a given boundary condition on all six sides of the domain Parameters @@ -1450,7 +1488,7 @@ def to_list(self) -> list[tuple[BoundaryEdgeType, BoundaryEdgeType]]: ] @cached_property - def flipped_bloch_vecs(self) -> BoundarySpec: + def flipped_bloch_vecs(self) -> Self: """Return a copy of the instance where all Bloch vectors are multiplied by -1.""" bound_dims = {"x": self.x.copy(), "y": self.y.copy(), "z": self.z.copy()} for dim_key, bound_dim in bound_dims.items(): diff --git a/tidy3d/components/data/data_array.py b/tidy3d/components/data/data_array.py index 3afd54f516..d7f362c0fa 100644 --- a/tidy3d/components/data/data_array.py +++ b/tidy3d/components/data/data_array.py @@ -1,551 +1,38 @@ -"""Storing tidy3d data at it's most fundamental level as xr.DataArray objects""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.data_array`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations -import pathlib -from abc import ABC -from collections.abc import Mapping -from os import PathLike -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Union -import autograd.numpy as anp -import h5py import numpy as np -import xarray as xr -from autograd.tracer import isbox -from xarray.core import missing -from xarray.core.indexes import PandasIndex -from xarray.core.indexing import _outer_to_numpy_indexer -from xarray.core.types import InterpOptions, Self -from xarray.core.utils import OrderedSet, either_dict_or_kwargs -from xarray.core.variable import as_variable - -from tidy3d.compat import alignment -from tidy3d.components.autograd import TidyArrayBox, get_static, interpn, is_tidy_box -from tidy3d.components.geometry.bound_ops import bounds_contains -from tidy3d.components.types import Axis, Bound -from tidy3d.constants import ( + +from tidy3d._common.components.data.data_array import ( + DATA_ARRAY_MAP, + DATA_ARRAY_TYPES, + AbstractSpatialDataArray, + DataArray, + FreqDataArray, + ScalarFieldDataArray, + SpatialDataArray, + TimeDataArray, + TriangleMeshDataArray, +) +from tidy3d._common.constants import ( AMP, - HERTZ, - MICROMETER, OHM, PICOSECOND_PER_NANOMETER_PER_KILOMETER, - RADIAN, - SECOND, VOLT, WATT, ) -from tidy3d.exceptions import DataError, FileError - -# maps the dimension names to their attributes -DIM_ATTRS = { - "x": {"units": MICROMETER, "long_name": 
"x position"}, - "y": {"units": MICROMETER, "long_name": "y position"}, - "z": {"units": MICROMETER, "long_name": "z position"}, - "f": {"units": HERTZ, "long_name": "frequency"}, - "t": {"units": SECOND, "long_name": "time"}, - "direction": {"long_name": "propagation direction"}, - "mode_index": {"long_name": "mode index"}, - "eme_port_index": {"long_name": "EME port index"}, - "eme_cell_index": {"long_name": "EME cell index"}, - "mode_index_in": {"long_name": "mode index in"}, - "mode_index_out": {"long_name": "mode index out"}, - "sweep_index": {"long_name": "sweep index"}, - "theta": {"units": RADIAN, "long_name": "elevation angle"}, - "phi": {"units": RADIAN, "long_name": "azimuth angle"}, - "ux": {"long_name": "normalized kx"}, - "uy": {"long_name": "normalized ky"}, - "orders_x": {"long_name": "diffraction order"}, - "orders_y": {"long_name": "diffraction order"}, - "face_index": {"long_name": "face index"}, - "vertex_index": {"long_name": "vertex index"}, - "axis": {"long_name": "axis"}, -} - - -# name of the DataArray.values in the hdf5 file (xarray's default name too) -DATA_ARRAY_VALUE_NAME = "__xarray_dataarray_variable__" - - -class DataArray(xr.DataArray): - """Subclass of ``xr.DataArray`` that requires _dims to match the keys of the coords.""" - - # Always set __slots__ = () to avoid xarray warnings - __slots__ = () - # stores an ordered tuple of strings corresponding to the data dimensions - _dims = () - # stores a dictionary of attributes corresponding to the data values - _data_attrs: dict[str, str] = {} - - def __init__(self, data, *args: Any, **kwargs: Any) -> None: - # if data is a vanilla autograd box, convert to our box - if isbox(data) and not is_tidy_box(data): - data = TidyArrayBox.from_arraybox(data) - # do the same for xr.Variable or xr.DataArray type - elif isinstance(data, (xr.Variable, xr.DataArray)): - if isbox(data.data) and not is_tidy_box(data.data): - data.data = TidyArrayBox.from_arraybox(data.data) - super().__init__(data, *args, **kwargs) - - @classmethod - def __get_validators__(cls): - """Validators that get run when :class:`.DataArray` objects are added to pydantic models.""" - yield cls.check_unloaded_data - yield cls.validate_dims - yield cls.assign_data_attrs - yield cls.assign_coord_attrs - - @classmethod - def check_unloaded_data(cls, val): - """If the data comes in as the raw data array string, raise a custom warning.""" - if isinstance(val, str) and val in DATA_ARRAY_MAP: - raise DataError( - f"Trying to load {cls.__name__} but the data is not present. " - "Note that data will not be saved to .json file. " - "use .hdf5 format instead if data present." - ) - return cls(val) - - @classmethod - def validate_dims(cls, val): - """Make sure the dims are the same as _dims, then put them in the correct order.""" - if set(val.dims) != set(cls._dims): - raise ValueError(f"wrong dims, expected '{cls._dims}', got '{val.dims}'") - return val.transpose(*cls._dims) - - @classmethod - def assign_data_attrs(cls, val): - """Assign the correct data attributes to the :class:`.DataArray`.""" - - for attr_name, attr in cls._data_attrs.items(): - val.attrs[attr_name] = attr - return val - - def _interp_validator(self, field_name: Optional[str] = None) -> None: - """Ensure the data can be interpolated or selected by checking for duplicate coordinates. - - NOTE - ---- - This does not check every 'DataArray' by default. Instead, when required, this check can be - called from a validator, as is the case with 'CustomMedium' and 'CustomFieldSource'. 
- """ - if field_name is None: - field_name = "DataArray" - - for dim, coord in self.coords.items(): - if coord.to_index().duplicated().any(): - raise DataError( - f"Field '{field_name}' contains duplicate coordinates in dimension '{dim}'. " - "Duplicates can be removed by running " - f"'{field_name}={field_name}.drop_duplicates(dim=\"{dim}\")'." - ) - - @classmethod - def assign_coord_attrs(cls, val): - """Assign the correct coordinate attributes to the :class:`.DataArray`.""" - - for dim in cls._dims: - dim_attrs = DIM_ATTRS.get(dim) - if dim_attrs is not None: - for attr_name, attr in dim_attrs.items(): - val.coords[dim].attrs[attr_name] = attr - return val - - @classmethod - def __modify_schema__(cls, field_schema) -> None: - """Sets the schema of DataArray object.""" - - schema = { - "title": "DataArray", - "type": "xr.DataArray", - "properties": { - "_dims": { - "title": "_dims", - "type": "Tuple[str, ...]", - }, - }, - "required": ["_dims"], - } - field_schema.update(schema) - - @classmethod - def _json_encoder(cls, val): - """What function to call when writing a DataArray to json.""" - return type(val).__name__ - - def __eq__(self, other) -> bool: - """Whether two data array objects are equal.""" - - if not isinstance(other, xr.DataArray): - return False - - if not self.data.shape == other.data.shape or not np.all(self.data == other.data): - return False - for key, val in self.coords.items(): - if not np.all(np.array(val) == np.array(other.coords[key])): - return False - return True - - @property - def values(self): - """ - The array's data converted to a numpy.ndarray. - - Returns - ------- - np.ndarray - The values of the DataArray. - """ - return self.data if isbox(self.data) else super().values - - def to_numpy(self) -> np.ndarray: - """Return `.data` when traced to avoid `dtype=object` NumPy conversion.""" - return self.data if isbox(self.data) else super().to_numpy() - - @values.setter - def values(self, value: Any) -> None: - self.variable.values = value - - @property - def abs(self): - """Absolute value of data array.""" - return abs(self) - - @property - def angle(self): - """Angle or phase value of data array.""" - values = np.angle(self.values) - return type(self)(values, coords=self.coords) - - @property - def is_uniform(self): - """Whether each element is of equal value in the data array""" - raw_data = self.data.ravel() - return np.allclose(raw_data, raw_data[0]) - - def to_hdf5(self, fname: Union[PathLike, h5py.File], group_path: str) -> None: - """Save an xr.DataArray to the hdf5 file or file handle with a given path to the group.""" - - # file name passed - if isinstance(fname, (str, pathlib.Path)): - path = pathlib.Path(fname) - path.parent.mkdir(parents=True, exist_ok=True) - with h5py.File(path, "w") as f_handle: - self.to_hdf5_handle(f_handle=f_handle, group_path=group_path) - - # file handle passed - else: - self.to_hdf5_handle(f_handle=fname, group_path=group_path) - - def to_hdf5_handle(self, f_handle: h5py.File, group_path: str) -> None: - """Save an xr.DataArray to the hdf5 file handle with a given path to the group.""" - - sub_group = f_handle.create_group(group_path) - sub_group[DATA_ARRAY_VALUE_NAME] = get_static(self.data) - for key, val in self.coords.items(): - if val.dtype == " Self: - """Load an DataArray from an hdf5 file with a given path to the group.""" - path = pathlib.Path(fname) - with h5py.File(path, "r") as f: - sub_group = f[group_path] - values = np.array(sub_group[DATA_ARRAY_VALUE_NAME]) - coords = {dim: np.array(sub_group[dim]) for dim 
in cls._dims if dim in sub_group} - for key, val in coords.items(): - if val.dtype == "O": - coords[key] = [byte_string.decode() for byte_string in val.tolist()] - return cls(values, coords=coords, dims=cls._dims) - - @classmethod - def from_file(cls, fname: PathLike, group_path: str) -> Self: - """Load an DataArray from an hdf5 file with a given path to the group.""" - path = pathlib.Path(fname) - if not any(suffix.lower() == ".hdf5" for suffix in path.suffixes): - raise FileError( - f"'DataArray' objects must be written to '.hdf5' format. Given filename of {path}." - ) - return cls.from_hdf5(fname=path, group_path=group_path) - - def __hash__(self) -> int: - """Generate hash value for a :class:`.DataArray` instance, needed for custom components.""" - import dask - - token_str = dask.base.tokenize(self) - return hash(token_str) - - def multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: - """Multiply self by value at indices.""" - if isbox(self.data) or isbox(value): - return self._ag_multiply_at(value, coord_name, indices) - - self_mult = self.copy() - self_mult[{coord_name: indices}] *= value - return self_mult - - def _ag_multiply_at(self, value: complex, coord_name: str, indices: list[int]) -> Self: - """Autograd multiply_at override when tracing.""" - key = {coord_name: indices} - _, index_tuple, _ = self.variable._broadcast_indexes(key) - idx = _outer_to_numpy_indexer(index_tuple, self.data.shape) - mask = np.zeros(self.data.shape, dtype="?") - mask[idx] = True - return self.copy(deep=False, data=anp.where(mask, self.data * value, self.data)) - - def interp( - self, - coords: Mapping[Any, Any] | None = None, - method: InterpOptions = "linear", - assume_sorted: bool = False, - kwargs: Mapping[str, Any] | None = None, - **coords_kwargs: Any, - ) -> Self: - """Interpolate this DataArray to new coordinate values. - - Parameters - ---------- - coords : Union[Mapping[Any, Any], None] = None - A mapping from dimension names to new coordinate labels. - method : InterpOptions = "linear" - The interpolation method to use. - assume_sorted : bool = False - If True, skip sorting of coordinates. - kwargs : Union[Mapping[str, Any], None] = None - Additional keyword arguments to pass to the interpolation function. - **coords_kwargs : Any - The keyword arguments form of coords. - - Returns - ------- - DataArray - A new DataArray with interpolated values. - - Raises - ------ - KeyError - If any of the specified coordinates are not in the DataArray. - """ - if isbox(self.data): - return self._ag_interp(coords, method, assume_sorted, kwargs, **coords_kwargs) - - return super().interp(coords, method, assume_sorted, kwargs, **coords_kwargs) - - def _ag_interp( - self, - coords: Union[Mapping[Any, Any], None] = None, - method: InterpOptions = "linear", - assume_sorted: bool = False, - kwargs: Union[Mapping[str, Any], None] = None, - **coords_kwargs: Any, - ) -> Self: - """Autograd interp override when tracing over self.data. - - This implementation closely follows the interp implementation of xarray - to match its behavior as closely as possible while supporting autograd. 
- - See: - - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html - - https://docs.xarray.dev/en/latest/generated/xarray.Dataset.interp.html - """ - if kwargs is None: - kwargs = {} - - ds = self._to_temp_dataset() - - coords = either_dict_or_kwargs(coords, coords_kwargs, "interp") - indexers = dict(ds._validate_interp_indexers(coords)) - - if coords: - # Find shared dimensions between the dataset and the indexers - sdims = ( - set(ds.dims) - .intersection(*[set(nx.dims) for nx in indexers.values()]) - .difference(coords.keys()) - ) - indexers.update({d: ds.variables[d] for d in sdims}) - - obj = ds if assume_sorted else ds.sortby(list(coords)) - - # workaround to get a variable for a dimension without a coordinate - validated_indexers = { - k: (obj._variables.get(k, as_variable((k, range(obj.sizes[k])))), v) - for k, v in indexers.items() - } - - for k, v in validated_indexers.items(): - obj, newidx = missing._localize(obj, {k: v}) - validated_indexers[k] = newidx[k] - - variables = {} - reindex = False - for name, var in obj._variables.items(): - if name in indexers: - continue - dtype_kind = var.dtype.kind - if dtype_kind in "uifc": - # Interpolation for numeric types - var_indexers = {k: v for k, v in validated_indexers.items() if k in var.dims} - variables[name] = self._ag_interp_func(var, var_indexers, method, **kwargs) - elif dtype_kind in "ObU" and (validated_indexers.keys() & var.dims): - # Stepwise interpolation for non-numeric types - reindex = True - elif all(d not in indexers for d in var.dims): - # Keep variables not dependent on interpolated coords - variables[name] = var - - if reindex: - # Reindex for non-numeric types - reindex_indexers = {k: v for k, (_, v) in validated_indexers.items() if v.dims == (k,)} - reindexed = alignment.reindex( - obj, - indexers=reindex_indexers, - method="nearest", - exclude_vars=variables.keys(), - ) - indexes = dict(reindexed._indexes) - variables.update(reindexed.variables) - else: - # Get the indexes that are not being interpolated along - indexes = {k: v for k, v in obj._indexes.items() if k not in indexers} - - # Get the coords that also exist in the variables - coord_names = obj._coord_names & variables.keys() - selected = ds._replace_with_new_dims(variables.copy(), coord_names, indexes=indexes) - - # Attach indexer as coordinate - for k, v in indexers.items(): - if v.dims == (k,): - index = PandasIndex(v, k, coord_dtype=v.dtype) - index_vars = index.create_variables({k: v}) - indexes[k] = index - variables.update(index_vars) - else: - variables[k] = v - - # Extract coordinates from indexers - coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(coords) - variables.update(coord_vars) - indexes.update(new_indexes) - - coord_names = obj._coord_names & variables.keys() | coord_vars.keys() - ds = ds._replace_with_new_dims(variables, coord_names, indexes=indexes) - return self._from_temp_dataset(ds) - - @staticmethod - def _ag_interp_func(var, indexes_coords, method, **kwargs: Any): - """ - Interpolate the variable `var` along the coordinates specified in `indexes_coords` using the given `method`. - - The implementation follows xarray's interp implementation in xarray.core.missing, - but replaces some of the pre-processing as well as the actual interpolation - function with an autograd-compatible approach. - - - Parameters - ---------- - var : xr.Variable - The variable to be interpolated. - indexes_coords : dict - A dictionary mapping dimension names to coordinate values for interpolation. 
- method : str - The interpolation method to use. - **kwargs : dict - Additional keyword arguments to pass to the interpolation function. - - Returns - ------- - xr.Variable - The interpolated variable. - """ - if not indexes_coords: - return var.copy() - result = var - for indep_indexes_coords in missing.decompose_interp(indexes_coords): - var = result - - # target dimensions - dims = list(indep_indexes_coords) - x, new_x = zip(*[indep_indexes_coords[d] for d in dims]) - destination = missing.broadcast_variables(*new_x) - - broadcast_dims = [d for d in var.dims if d not in dims] - original_dims = broadcast_dims + dims - new_dims = broadcast_dims + list(destination[0].dims) - - x, new_x = missing._floatize_x(x, new_x) - - permutation = [var.dims.index(dim) for dim in original_dims] - combined_permutation = permutation[-len(x) :] + permutation[: -len(x)] - data = anp.transpose(var.data, combined_permutation) - xi = anp.stack([anp.ravel(new_xi.data) for new_xi in new_x], axis=-1) - - result = interpn([xn.data for xn in x], data, xi, method=method, **kwargs) - - result = anp.moveaxis(result, 0, -1) - result = anp.reshape(result, result.shape[:-1] + new_x[0].shape) - - result = xr.Variable(new_dims, result, attrs=var.attrs, fastpath=True) - - out_dims: OrderedSet = OrderedSet() - for d in var.dims: - if d in dims: - out_dims.update(indep_indexes_coords[d][1].dims) - else: - out_dims.add(d) - if len(out_dims) > 1: - result = result.transpose(*out_dims) - return result - - def _with_updated_data(self, data: np.ndarray, coords: dict[str, Any]) -> DataArray: - """Make copy of ``DataArray`` with ``data`` at specified ``coords``, autograd compatible - - Constraints / Edge cases: - - `coords` must map to a specific value eg {x: '1'}, does not broadcast to arrays - - `data` will be reshaped to try to match `self.shape` except where `coords` present - """ - - # make mask - mask = xr.zeros_like(self, dtype=bool) - mask.loc[coords] = True - - # reshape `data` to line up with `self.dims`, with shape of 1 along the selected axis - old_data = self.data - new_shape = list(old_data.shape) - for i, dim in enumerate(self.dims): - if dim in coords: - new_shape[i] = 1 - try: - new_data = data.reshape(new_shape) - except ValueError as e: - raise ValueError( - "Couldn't reshape the supplied 'data' to update 'DataArray'. The provided data was " - f"of shape {data.shape} and tried to reshape to {new_shape}. If you encounter this " - "error please raise an issue on the tidy3d github repository with the context." - ) from e - - # broadcast data to repeat data along the selected dimensions to match mask - new_data = new_data + np.zeros_like(old_data) - - new_data = np.where(mask, new_data, old_data) - - return self.copy(deep=True, data=new_data) - - -class FreqDataArray(DataArray): - """Frequency-domain array. +from tidy3d._common.exceptions import DataError, FileError - Example - ------- - >>> f = [2e14, 3e14] - >>> fd = FreqDataArray((1+1j) * np.random.random((2,)), coords=dict(f=f)) - """ +if TYPE_CHECKING: + from xarray.core.types import Self - __slots__ = () - _dims = ("f",) + from tidy3d._common.components.types.base import Axis, Bound class FreqVoltageDataArray(DataArray): @@ -581,19 +68,6 @@ class FreqModeDataArray(DataArray): _dims = ("f", "mode_index") -class TimeDataArray(DataArray): - """Time-domain array. 
- - Example - ------- - >>> t = [0, 1e-12, 2e-12] - >>> td = TimeDataArray((1+1j) * np.random.random((3,)), coords=dict(t=t)) - """ - - __slots__ = () - _dims = "t" - - class MixedModeDataArray(DataArray): """Scalar property associated with mode pairs @@ -610,224 +84,6 @@ class MixedModeDataArray(DataArray): _dims = ("f", "mode_index_0", "mode_index_1") -class AbstractSpatialDataArray(DataArray, ABC): - """Spatial distribution.""" - - __slots__ = () - _dims = ("x", "y", "z") - _data_attrs = {"long_name": "field value"} - - @property - def _spatially_sorted(self) -> SpatialDataArray: - """Check whether sorted and sort if not.""" - needs_sorting = [] - for axis in "xyz": - axis_coords = self.coords[axis].values - if len(axis_coords) > 1 and np.any(axis_coords[1:] < axis_coords[:-1]): - needs_sorting.append(axis) - - if len(needs_sorting) > 0: - return self.sortby(needs_sorting) - - return self - - def sel_inside(self, bounds: Bound) -> SpatialDataArray: - """Return a new SpatialDataArray that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. Note that the returned data is sorted with respect - to spatial coordinates. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - SpatialDataArray - Extracted spatial data array. - """ - if any(bmin > bmax for bmin, bmax in zip(*bounds)): - raise DataError( - "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." - ) - - # make sure data is sorted with respect to coordinates - sorted_self = self._spatially_sorted - - inds_list = [] - - coords = (sorted_self.x, sorted_self.y, sorted_self.z) - - for coord, smin, smax in zip(coords, bounds[0], bounds[1]): - length = len(coord) - - # one point along direction, assume invariance - if length == 1: - comp_inds = [0] - else: - # if data does not cover structure at all take the closest index - if smax < coord[0]: # structure is completely on the left side - # take 2 if possible, so that linear iterpolation is possible - comp_inds = np.arange(0, max(2, length)) - - elif smin > coord[-1]: # structure is completely on the right side - # take 2 if possible, so that linear iterpolation is possible - comp_inds = np.arange(min(0, length - 2), length) - - else: - if smin < coord[0]: - ind_min = 0 - else: - ind_min = max(0, (coord >= smin).argmax().data - 1) - - if smax > coord[-1]: - ind_max = length - 1 - else: - ind_max = (coord >= smax).argmax().data - - comp_inds = np.arange(ind_min, ind_max + 1) - - inds_list.append(comp_inds) - - return sorted_self.isel(x=inds_list[0], y=inds_list[1], z=inds_list[2]) - - def does_cover(self, bounds: Bound, rtol: float = 0.0, atol: float = 0.0) -> bool: - """Check whether data fully covers specified by ``bounds`` spatial region. If data contains - only one point along a given direction, then it is assumed the data is constant along that - direction and coverage is not checked. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - rtol : float = 0.0 - Relative tolerance for comparing bounds - atol : float = 0.0 - Absolute tolerance for comparing bounds - - Returns - ------- - bool - Full cover check outcome. 
- """ - if any(bmin > bmax for bmin, bmax in zip(*bounds)): - raise DataError( - "Min and max bounds must be packaged as '(minx, miny, minz), (maxx, maxy, maxz)'." - ) - xyz = [self.x, self.y, self.z] - self_min = [0] * 3 - self_max = [0] * 3 - for dim in range(3): - coords = xyz[dim] - if len(coords) == 1: - self_min[dim] = bounds[0][dim] - self_max[dim] = bounds[1][dim] - else: - self_min[dim] = np.min(coords) - self_max[dim] = np.max(coords) - self_bounds = (tuple(self_min), tuple(self_max)) - return bounds_contains(self_bounds, bounds, rtol=rtol, atol=atol) - - -class SpatialDataArray(AbstractSpatialDataArray): - """Spatial distribution. - - Example - ------- - >>> x = [1,2] - >>> y = [2,3,4] - >>> z = [3,4,5,6] - >>> coords = dict(x=x, y=y, z=z) - >>> fd = SpatialDataArray((1+1j) * np.random.random((2,3,4)), coords=coords) - """ - - __slots__ = () - - def reflect(self, axis: Axis, center: float, reflection_only: bool = False) -> SpatialDataArray: - """Reflect data across the plane define by parameters ``axis`` and ``center`` from right to - left. Note that the returned data is sorted with respect to spatial coordinates. - - Parameters - ---------- - axis : Literal[0, 1, 2] - Normal direction of the reflection plane. - center : float - Location of the reflection plane along its normal direction. - reflection_only : bool = False - Return only reflected data. - - Returns - ------- - SpatialDataArray - Data after reflection is performed. - """ - - sorted_self = self._spatially_sorted - - coords = [sorted_self.x.values, sorted_self.y.values, sorted_self.z.values] - data = np.array(sorted_self.data) - - data_left_bound = coords[axis][0] - - if np.isclose(center, data_left_bound): - num_duplicates = 1 - elif center > data_left_bound: - raise DataError("Reflection center must be outside and to the left of the data region.") - else: - num_duplicates = 0 - - if reflection_only: - coords[axis] = 2 * center - coords[axis] - coords_dict = dict(zip("xyz", coords)) - - tmp_arr = SpatialDataArray(sorted_self.data, coords=coords_dict) - - return tmp_arr.sortby("xyz"[axis]) - - shape = np.array(np.shape(data)) - old_len = shape[axis] - shape[axis] = 2 * old_len - num_duplicates - - ind_left = [slice(shape[0]), slice(shape[1]), slice(shape[2])] - ind_right = [slice(shape[0]), slice(shape[1]), slice(shape[2])] - - ind_left[axis] = slice(old_len - 1, None, -1) - ind_right[axis] = slice(old_len - num_duplicates, None) - - new_data = np.zeros(shape) - - new_data[ind_left[0], ind_left[1], ind_left[2]] = data - new_data[ind_right[0], ind_right[1], ind_right[2]] = data - - new_coords = np.zeros(shape[axis]) - new_coords[old_len - num_duplicates :] = coords[axis] - new_coords[old_len - 1 :: -1] = 2 * center - coords[axis] - - coords[axis] = new_coords - coords_dict = dict(zip("xyz", coords)) - - return SpatialDataArray(new_data, coords=coords_dict) - - -class ScalarFieldDataArray(AbstractSpatialDataArray): - """Spatial distribution in the frequency-domain. - - Example - ------- - >>> x = [1,2] - >>> y = [2,3,4] - >>> z = [3,4,5,6] - >>> f = [2e14, 3e14] - >>> coords = dict(x=x, y=y, z=z, f=f) - >>> fd = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) - """ - - __slots__ = () - _dims = ("x", "y", "z", "f") - - class ScalarFieldTimeDataArray(AbstractSpatialDataArray): """Spatial distribution in the time-domain. 
@@ -1055,14 +311,6 @@ class DiffractionDataArray(DataArray): _data_attrs = {"long_name": "diffraction amplitude"} -class TriangleMeshDataArray(DataArray): - """Data of the triangles of a surface mesh as in the STL file format.""" - - __slots__ = () - _dims = ("face_index", "vertex_index", "axis") - _data_attrs = {"long_name": "surface mesh triangles"} - - class HeatDataArray(DataArray): """Heat data array. @@ -1073,7 +321,7 @@ class HeatDataArray(DataArray): """ __slots__ = () - _dims = "T" + _dims = ("T",) class EMEScalarModeFieldDataArray(AbstractSpatialDataArray): @@ -1581,7 +829,7 @@ def _make_base_result_data_array(result: DataArray) -> IntegralResultType: cls = TimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = FreqModeDataArray - return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) + return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) def _make_voltage_data_array(result: DataArray) -> VoltageIntegralResultType: @@ -1591,7 +839,7 @@ def _make_voltage_data_array(result: DataArray) -> VoltageIntegralResultType: cls = VoltageTimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = VoltageFreqModeDataArray - return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) + return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) def _make_current_data_array(result: DataArray) -> CurrentIntegralResultType: @@ -1601,7 +849,7 @@ def _make_current_data_array(result: DataArray) -> CurrentIntegralResultType: cls = CurrentTimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = CurrentFreqModeDataArray - return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) + return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) def _make_impedance_data_array(result: DataArray) -> ImpedanceResultType: @@ -1611,60 +859,8 @@ def _make_impedance_data_array(result: DataArray) -> ImpedanceResultType: cls = ImpedanceTimeDataArray if "f" in result.coords and "mode_index" in result.coords: cls = ImpedanceFreqModeDataArray - return cls.assign_data_attrs(cls(data=result.data, coords=result.coords)) - + return cls._assign_data_attrs(cls(data=result.data, coords=result.coords)) -DATA_ARRAY_TYPES = [ - SpatialDataArray, - ScalarFieldDataArray, - ScalarFieldTimeDataArray, - ScalarModeFieldDataArray, - FluxDataArray, - FluxTimeDataArray, - ModeAmpsDataArray, - ModeIndexDataArray, - GroupIndexDataArray, - ModeDispersionDataArray, - FieldProjectionAngleDataArray, - FieldProjectionCartesianDataArray, - FieldProjectionKSpaceDataArray, - DiffractionDataArray, - FreqModeDataArray, - FreqDataArray, - TimeDataArray, - FreqModeDataArray, - FreqVoltageDataArray, - TriangleMeshDataArray, - HeatDataArray, - EMEScalarFieldDataArray, - EMEScalarModeFieldDataArray, - EMESMatrixDataArray, - EMEInterfaceSMatrixDataArray, - EMECoefficientDataArray, - EMEModeIndexDataArray, - EMEFluxDataArray, - EMEFreqModeDataArray, - ChargeDataArray, - SteadyVoltageDataArray, - PointDataArray, - CellDataArray, - IndexedDataArray, - IndexedFieldVoltageDataArray, - IndexedVoltageDataArray, - SpatialVoltageDataArray, - PerturbationCoefficientDataArray, - IndexedTimeDataArray, - VoltageFreqDataArray, - VoltageTimeDataArray, - VoltageFreqModeDataArray, - CurrentFreqDataArray, - CurrentTimeDataArray, - CurrentFreqModeDataArray, - ImpedanceFreqDataArray, - ImpedanceTimeDataArray, - ImpedanceFreqModeDataArray, -] -DATA_ARRAY_MAP = {data_array.__name__: data_array for data_array in 
DATA_ARRAY_TYPES} IndexedDataArrayTypes = Union[ IndexedDataArray, diff --git a/tidy3d/components/data/dataset.py b/tidy3d/components/data/dataset.py index aed00f8aa6..3f3decbd20 100644 --- a/tidy3d/components/data/dataset.py +++ b/tidy3d/components/data/dataset.py @@ -1,22 +1,30 @@ -"""Collections of DataArrays.""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.dataset`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Callable, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Optional, Union, get_args import numpy as np -import pydantic.v1 as pd import xarray as xr - -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import Axis, FreqArray, xyz -from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling -from tidy3d.exceptions import DataError -from tidy3d.log import log - -from .data_array import ( - DataArray, +from pydantic import Field + +from tidy3d._common.components.data.dataset import ( + DEFAULT_MAX_CELLS_PER_STEP, + DEFAULT_MAX_SAMPLES_PER_STEP, + DEFAULT_TOLERANCE_CELL_FINDING, + AbstractFieldDataset, + AbstractMediumPropertyDataset, + Dataset, + PermittivityDataset, + TimeDataset, + TriangleMeshDataset, +) +from tidy3d.components.data.data_array import ( EMEScalarFieldDataArray, EMEScalarModeFieldDataArray, GroupIndexDataArray, @@ -27,27 +35,21 @@ ScalarModeFieldCylindricalDataArray, ScalarModeFieldDataArray, TimeDataArray, - TriangleMeshDataArray, ) -from .zbf import ZBFData - -DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 -DEFAULT_MAX_CELLS_PER_STEP = 10_000 -DEFAULT_TOLERANCE_CELL_FINDING = 1e-6 +from tidy3d.components.data.zbf import ZBFData +from tidy3d.components.types.base import xyz +from tidy3d.constants import C_0, PICOSECOND_PER_NANOMETER_PER_KILOMETER, UnitScaling +from tidy3d.exceptions import DataError +from tidy3d.log import log +if TYPE_CHECKING: + from typing import Callable, Literal -class Dataset(Tidy3dBaseModel, ABC): - """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" + from numpy.typing import ArrayLike - @property - def data_arrs(self) -> dict: - """Returns a dictionary of all `:class:`.DataArray`s in the dataset.""" - data_arrs = {} - for key in self.__fields__.keys(): - data = getattr(self, key) - if isinstance(data, DataArray): - data_arrs[key] = data - return data_arrs + from tidy3d.compat import Self + from tidy3d.components.data.data_array import DataArray + from tidy3d.components.types.base import Axis, FreqArray class FreqDataset(Dataset, ABC): @@ -161,7 +163,7 @@ def _interp_dataarray_in_freq( class ModeFreqDataset(FreqDataset, ABC): """Abstract base class for objects that store collections of `:class:`.DataArray`s.""" - def _apply_mode_reorder(self, sort_inds_2d): + def _apply_mode_reorder(self, sort_inds_2d: np.ndarray) -> Self: """Apply a mode reordering along mode_index for all frequency indices. 
Parameters @@ -187,104 +189,6 @@ def _apply_mode_reorder(self, sort_inds_2d): return self.updated_copy(**modify_data) -class AbstractFieldDataset(Dataset, ABC): - """Collection of scalar fields with some symmetry properties.""" - - @property - @abstractmethod - def field_components(self) -> dict[str, DataArray]: - """Maps the field components to their associated data.""" - - def apply_phase(self, phase: float) -> AbstractFieldDataset: - """Create a copy where all elements are phase-shifted by a value (in radians).""" - if phase == 0.0: - return self - phasor = np.exp(1j * phase) - field_components_shifted = {} - for fld_name, fld_cmp in self.field_components.items(): - fld_cmp_shifted = phasor * fld_cmp - field_components_shifted[fld_name] = fld_cmp_shifted - return self.updated_copy(**field_components_shifted) - - @property - @abstractmethod - def grid_locations(self) -> dict[str, str]: - """Maps field components to the string key of their grid locations on the yee lattice.""" - - @property - @abstractmethod - def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: - """Maps field components to their (positive) symmetry eigenvalues.""" - - def package_colocate_results(self, centered_fields: dict[str, ScalarFieldDataArray]) -> Any: - """How to package the dictionary of fields computed via self.colocate().""" - return xr.Dataset(centered_fields) - - def colocate(self, x=None, y=None, z=None) -> xr.Dataset: - """Colocate all of the data at a set of x, y, z coordinates. - - Parameters - ---------- - x : Optional[array-like] = None - x coordinates of locations. - If not supplied, does not try to colocate on this dimension. - y : Optional[array-like] = None - y coordinates of locations. - If not supplied, does not try to colocate on this dimension. - z : Optional[array-like] = None - z coordinates of locations. - If not supplied, does not try to colocate on this dimension. - - Returns - ------- - xr.Dataset - Dataset containing all fields at the same spatial locations. - For more details refer to `xarray's Documentation `_. - - Note - ---- - For many operations (such as flux calculations and plotting), - it is important that the fields are colocated at the same spatial locations. - Be sure to apply this method to your field data in those cases. - """ - - if hasattr(self, "monitor") and self.monitor.colocate: - with log as consolidated_logger: - consolidated_logger.warning( - "Colocating data that has already been colocated during the solver " - "run. For most accurate results when colocating to custom coordinates set " - "'Monitor.colocate' to 'False' to use the raw data on the Yee grid " - "and avoid double interpolation. Note: the default value was changed to 'True' " - "in Tidy3D version 2.4.0." - ) - - # convert supplied coordinates to array and assign string mapping to them - supplied_coord_map = {k: np.array(v) for k, v in zip("xyz", (x, y, z)) if v is not None} - - # dict of data arrays to combine in dataset and return - centered_fields = {} - - # loop through field components - for field_name, field_data in self.field_components.items(): - # loop through x, y, z dimensions and raise an error if only one element along dim - for coord_name, coords_supplied in supplied_coord_map.items(): - coord_data = np.array(field_data.coords[coord_name]) - if coord_data.size == 1: - raise DataError( - f"colocate given {coord_name}={coords_supplied}, but " - f"data only has one coordinate at {coord_name}={coord_data[0]}. " - "Therefore, can't colocate along this dimension. 
" - f"supply {coord_name}=None to skip it." - ) - - centered_fields[field_name] = field_data.interp( - **supplied_coord_map, kwargs={"bounds_error": True} - ) - - # combine all centered fields in a dataset - return self.package_colocate_results(centered_fields) - - EMScalarFieldType = Union[ ScalarFieldDataArray, ScalarFieldTimeDataArray, @@ -298,32 +202,32 @@ def colocate(self, x=None, y=None, z=None) -> xr.Dataset: class ElectromagneticFieldDataset(AbstractFieldDataset, ABC): """Stores a collection of E and H fields with x, y, z components.""" - Ex: Optional[EMScalarFieldType] = pd.Field( + Ex: Optional[EMScalarFieldType] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[EMScalarFieldType] = pd.Field( + Ey: Optional[EMScalarFieldType] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[EMScalarFieldType] = pd.Field( + Ez: Optional[EMScalarFieldType] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[EMScalarFieldType] = pd.Field( + Hx: Optional[EMScalarFieldType] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[EMScalarFieldType] = pd.Field( + Hy: Optional[EMScalarFieldType] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[EMScalarFieldType] = pd.Field( + Hz: Optional[EMScalarFieldType] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -375,32 +279,32 @@ class FieldDataset(ElectromagneticFieldDataset): >>> data = FieldDataset(Ex=scalar_field, Hz=scalar_field) """ - Ex: Optional[ScalarFieldDataArray] = pd.Field( + Ex: Optional[ScalarFieldDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[ScalarFieldDataArray] = pd.Field( + Ey: Optional[ScalarFieldDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[ScalarFieldDataArray] = pd.Field( + Ez: Optional[ScalarFieldDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[ScalarFieldDataArray] = pd.Field( + Hx: Optional[ScalarFieldDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[ScalarFieldDataArray] = pd.Field( + Hy: Optional[ScalarFieldDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[ScalarFieldDataArray] = pd.Field( + Hz: Optional[ScalarFieldDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -507,32 +411,32 @@ class FieldTimeDataset(ElectromagneticFieldDataset): >>> data = FieldTimeDataset(Ex=scalar_field, Hz=scalar_field) """ - Ex: Optional[ScalarFieldTimeDataArray] = pd.Field( + Ex: Optional[ScalarFieldTimeDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field.", ) - Ey: Optional[ScalarFieldTimeDataArray] = pd.Field( + Ey: Optional[ScalarFieldTimeDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field.", ) - Ez: Optional[ScalarFieldTimeDataArray] = 
pd.Field( + Ez: Optional[ScalarFieldTimeDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field.", ) - Hx: Optional[ScalarFieldTimeDataArray] = pd.Field( + Hx: Optional[ScalarFieldTimeDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field.", ) - Hy: Optional[ScalarFieldTimeDataArray] = pd.Field( + Hy: Optional[ScalarFieldTimeDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field.", ) - Hz: Optional[ScalarFieldTimeDataArray] = pd.Field( + Hz: Optional[ScalarFieldTimeDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field.", @@ -550,19 +454,19 @@ def apply_phase(self, phase: float) -> AbstractFieldDataset: class AuxFieldDataset(AbstractFieldDataset, ABC): """Stores a collection of aux fields with x, y, z components.""" - Nfx: Optional[EMScalarFieldType] = pd.Field( + Nfx: Optional[EMScalarFieldType] = Field( None, title="Nfx", description="Spatial distribution of the free carrier density for " "polarization in the x-direction.", ) - Nfy: Optional[EMScalarFieldType] = pd.Field( + Nfy: Optional[EMScalarFieldType] = Field( None, title="Nfy", description="Spatial distribution of the free carrier density for " "polarization in the y-direction.", ) - Nfz: Optional[EMScalarFieldType] = pd.Field( + Nfz: Optional[EMScalarFieldType] = Field( None, title="Nfz", description="Spatial distribution of the free carrier density for " @@ -609,19 +513,19 @@ class AuxFieldTimeDataset(AuxFieldDataset): >>> data = AuxFieldTimeDataset(Nfx=scalar_field) """ - Nfx: Optional[ScalarFieldTimeDataArray] = pd.Field( + Nfx: Optional[ScalarFieldTimeDataArray] = Field( None, title="Nfx", description="Spatial distribution of the free carrier density for polarization " "in the x-direction.", ) - Nfy: Optional[ScalarFieldTimeDataArray] = pd.Field( + Nfy: Optional[ScalarFieldTimeDataArray] = Field( None, title="Nfy", description="Spatial distribution of the free carrier density for polarization " "in the y-direction.", ) - Nfz: Optional[ScalarFieldTimeDataArray] = pd.Field( + Nfz: Optional[ScalarFieldTimeDataArray] = Field( None, title="Nfz", description="Spatial distribution of the free carrier density for polarization " @@ -655,51 +559,50 @@ class ModeSolverDataset(ElectromagneticFieldDataset, ModeFreqDataset): ... 
) """ - Ex: Optional[ScalarModeFieldDataArray] = pd.Field( + Ex: Optional[ScalarModeFieldDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: Optional[ScalarModeFieldDataArray] = pd.Field( + Ey: Optional[ScalarModeFieldDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: Optional[ScalarModeFieldDataArray] = pd.Field( + Ez: Optional[ScalarModeFieldDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: Optional[ScalarModeFieldDataArray] = pd.Field( + Hx: Optional[ScalarModeFieldDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: Optional[ScalarModeFieldDataArray] = pd.Field( + Hy: Optional[ScalarModeFieldDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: Optional[ScalarModeFieldDataArray] = pd.Field( + Hz: Optional[ScalarModeFieldDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", ) - n_complex: ModeIndexDataArray = pd.Field( - ..., + n_complex: ModeIndexDataArray = Field( title="Propagation Index", description="Complex-valued effective propagation constants associated with the mode.", ) - n_group_raw: Optional[GroupIndexDataArray] = pd.Field( + n_group_raw: Optional[GroupIndexDataArray] = Field( None, alias="n_group", # This is for backwards compatibility only when loading old data title="Group Index", description="Index associated with group velocity of the mode.", ) - dispersion_raw: Optional[ModeDispersionDataArray] = pd.Field( + dispersion_raw: Optional[ModeDispersionDataArray] = Field( None, title="Dispersion", description="Dispersion parameter for the mode.", @@ -765,56 +668,6 @@ def plot_field(self, *args: Any, **kwargs: Any) -> None: ) -class AbstractMediumPropertyDataset(AbstractFieldDataset, ABC): - """Dataset storing medium property.""" - - eps_xx: ScalarFieldDataArray = pd.Field( - ..., - title="Epsilon xx", - description="Spatial distribution of the xx-component of the relative permittivity.", - ) - eps_yy: ScalarFieldDataArray = pd.Field( - ..., - title="Epsilon yy", - description="Spatial distribution of the yy-component of the relative permittivity.", - ) - eps_zz: ScalarFieldDataArray = pd.Field( - ..., - title="Epsilon zz", - description="Spatial distribution of the zz-component of the relative permittivity.", - ) - - -class PermittivityDataset(AbstractMediumPropertyDataset): - """Dataset storing the diagonal components of the permittivity tensor. 
- - Example - ------- - >>> x = [-1,1] - >>> y = [-2,0,2] - >>> z = [-3,-1,1,3] - >>> f = [2e14, 3e14] - >>> coords = dict(x=x, y=y, z=z, f=f) - >>> sclr_fld = ScalarFieldDataArray((1+1j) * np.random.random((2,3,4,2)), coords=coords) - >>> data = PermittivityDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld) - """ - - @property - def field_components(self) -> dict[str, ScalarFieldDataArray]: - """Maps the field components to their associated data.""" - return {"eps_xx": self.eps_xx, "eps_yy": self.eps_yy, "eps_zz": self.eps_zz} - - @property - def grid_locations(self) -> dict[str, str]: - """Maps field components to the string key of their grid locations on the yee lattice.""" - return {"eps_xx": "Ex", "eps_yy": "Ey", "eps_zz": "Ez"} - - @property - def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: - """Maps field components to their (positive) symmetry eigenvalues.""" - return {"eps_xx": None, "eps_yy": None, "eps_zz": None} - - class MediumDataset(AbstractMediumPropertyDataset): """Dataset storing the diagonal components of the permittivity and permeability tensor. @@ -829,18 +682,15 @@ class MediumDataset(AbstractMediumPropertyDataset): >>> data = MediumDataset(eps_xx=sclr_fld, eps_yy=sclr_fld, eps_zz=sclr_fld, mu_xx=sclr_fld, mu_yy=sclr_fld, mu_zz=sclr_fld) """ - mu_xx: ScalarFieldDataArray = pd.Field( - ..., + mu_xx: ScalarFieldDataArray = Field( title="Mu xx", description="Spatial distribution of the xx-component of the relative permeability.", ) - mu_yy: ScalarFieldDataArray = pd.Field( - ..., + mu_yy: ScalarFieldDataArray = Field( title="Mu yy", description="Spatial distribution of the yy-component of the relative permeability.", ) - mu_zz: ScalarFieldDataArray = pd.Field( - ..., + mu_zz: ScalarFieldDataArray = Field( title="Mu zz", description="Spatial distribution of the zz-component of the relative permeability.", ) @@ -870,7 +720,7 @@ def grid_locations(self) -> dict[str, str]: } @property - def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: + def symmetry_eigenvalues(self) -> dict[str, None]: """Maps field components to their (positive) symmetry eigenvalues.""" return { "eps_xx": None, @@ -880,22 +730,3 @@ def symmetry_eigenvalues(self) -> dict[str, Callable[[Axis], float]]: "mu_yy": None, "mu_zz": None, } - - -class TriangleMeshDataset(Dataset): - """Dataset for storing triangular surface data.""" - - surface_mesh: TriangleMeshDataArray = pd.Field( - ..., - title="Surface mesh data", - description="Dataset containing the surface triangles and corresponding face indices " - "for a surface mesh.", - ) - - -class TimeDataset(Dataset): - """Dataset for storing a function of time.""" - - values: TimeDataArray = pd.Field( - ..., title="Values", description="Values as a function of time." - ) diff --git a/tidy3d/components/data/index.py b/tidy3d/components/data/index.py index 9f18390113..0aab90d9d5 100644 --- a/tidy3d/components/data/index.py +++ b/tidy3d/components/data/index.py @@ -7,7 +7,7 @@ from collections.abc import Mapping -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.index import ValueMap from tidy3d.components.types.simulation import SimulationDataType @@ -16,12 +16,14 @@ class SimulationDataMap(ValueMap, Mapping[str, SimulationDataType]): """An immutable dictionary-like container for simulation data. - It provides standard dictionary - behaviors like item access (`data["key"]`), iteration (`for key in data`), and - length checking (`len(data)`). 
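The dataset classes deleted in the hunk above (TriangleMeshDataset, TimeDataset, and the medium-property bases) are not gone: per the new module docstring, tidy3d/components/data/dataset.py now re-imports them from tidy3d._common and keeps typing-only imports behind a TYPE_CHECKING guard. A minimal sketch of that re-export shim pattern follows; "mypkg" and the names under it are stand-ins, not tidy3d's real layout.

# compat_shim.py -- illustrative re-export shim; "mypkg" is hypothetical.
# ruff: noqa: F401  (re-exports look unused to linters)
from __future__ import annotations

from typing import TYPE_CHECKING

# canonical implementations now live in the shared package
from mypkg._common.dataset import Dataset, TimeDataset, TriangleMeshDataset

if TYPE_CHECKING:
    # typing-only names stay out of the runtime import graph
    from mypkg.types import Axis, FreqArray

Old import paths keep working, while the import cost and circular-import risk move to type-check time only.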
+ Notes + ----- + It provides standard dictionary + behaviors like item access (`data["key"]`), iteration (`for key in data`), and + length checking (`len(data)`). - It automatically validates that the `keys` and `values` - tuples have matching lengths upon instantiation. + It automatically validates that the `keys` and `values` + tuples have matching lengths upon instantiation. Attributes ---------- @@ -98,11 +100,11 @@ class SimulationDataMap(ValueMap, Mapping[str, SimulationDataType]): >>> # print(simulation_data_map["data_2"]) """ - keys_tuple: tuple[str, ...] = pd.Field( + keys_tuple: tuple[str, ...] = Field( description="A tuple of unique string identifiers for each simulation data object.", alias="keys", ) - values_tuple: tuple[SimulationDataType, ...] = pd.Field( + values_tuple: tuple[SimulationDataType, ...] = Field( description=( "A tuple of `SimulationDataType` objects, each corresponding to a key at the " "same index." diff --git a/tidy3d/components/data/monitor_data.py b/tidy3d/components/data/monitor_data.py index b2bda4e653..dd79c719e0 100644 --- a/tidy3d/components/data/monitor_data.py +++ b/tidy3d/components/data/monitor_data.py @@ -6,19 +6,17 @@ import warnings from abc import ABC from math import isclose -from os import PathLike -from typing import Any, Callable, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_args import autograd.numpy as np -import pydantic.v1 as pd import xarray as xr -from pandas import DataFrame, Index +from pandas import Index +from pydantic import Field, model_validator -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import cached_property from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import Medium, MediumType -from tidy3d.components.mode_spec import ModeSortSpec, ModeSpec from tidy3d.components.monitor import ( AuxFieldTimeMonitor, DiffractionMonitor, @@ -36,24 +34,15 @@ ModeSolverMonitor, PermittivityMonitor, ) -from tidy3d.components.source.base import Source -from tidy3d.components.source.current import CustomCurrentSource, PointDipole +from tidy3d.components.source.current import CustomCurrentSource from tidy3d.components.source.field import CustomFieldSource, ModeSource, PlaneWave -from tidy3d.components.source.time import GaussianPulse, SourceTimeType +from tidy3d.components.source.time import GaussianPulse from tidy3d.components.types import ( TYPE_TAG_STR, ArrayFloat1D, - ArrayFloat2D, Coordinate, - Direction, - EMField, EpsSpecType, - FreqArray, - Numpy, - PolarizationBasis, - Size, Symmetry, - TrackFreq, UnitsZBF, ) from tidy3d.components.types.monitor import MonitorType @@ -77,18 +66,14 @@ FreqDataArray, FreqModeDataArray, GroupIndexDataArray, - MixedModeDataArray, ModeAmpsDataArray, ModeDispersionDataArray, - ModeIndexDataArray, ScalarFieldDataArray, - ScalarFieldTimeDataArray, TimeDataArray, ) from .dataset import ( AbstractFieldDataset, AuxFieldTimeDataset, - Dataset, ElectromagneticFieldDataset, FieldDataset, FieldTimeDataset, @@ -97,6 +82,31 @@ PermittivityDataset, ) +if TYPE_CHECKING: + from os import PathLike + from typing import Literal, SupportsComplex + + from numpy.typing import NDArray + from pandas import DataFrame + + from tidy3d.compat import Self + from tidy3d.components.mode_spec import ModeSortSpec, ModeSpec + from tidy3d.components.source.base import Source + from 
tidy3d.components.source.current import PointDipole + from tidy3d.components.source.time import SourceTimeType + from tidy3d.components.types import ( + ArrayFloat2D, + Direction, + EMField, + FreqArray, + PolarizationBasis, + Size, + TrackFreq, + ) + + from .data_array import MixedModeDataArray, ModeIndexDataArray, ScalarFieldTimeDataArray + from .dataset import Dataset + Coords1D = ArrayFloat1D # how much to shift the adjoint field source for 0-D axes dimensions @@ -108,14 +118,21 @@ COS_THETA_THRESH = 1e-5 MODE_INTERP_EXTRAPOLATION_TOLERANCE = 1e-2 +GRID_CORRECTION_TYPE = Union[ + float, + FreqDataArray, + TimeDataArray, + FreqModeDataArray, + EMEFreqModeDataArray, +] + class MonitorData(AbstractMonitorData, ABC): """ Abstract base class of objects that store data pertaining to a single :class:`.monitor`. """ - monitor: MonitorType = pd.Field( - ..., + monitor: MonitorType = Field( title="Monitor", description="Monitor associated with the data.", discriminator=TYPE_TAG_STR, @@ -177,7 +194,7 @@ def flip_direction(direction: Union[str, DataArray]) -> str: return "-" if direction == "+" else "+" @staticmethod - def get_amplitude(x) -> complex: + def get_amplitude(x: Union[DataArray, SupportsComplex]) -> complex: """Get the complex amplitude out of some data.""" if isinstance(x, DataArray): @@ -198,19 +215,19 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC): MediumMonitor, ] - symmetry: tuple[Symmetry, Symmetry, Symmetry] = pd.Field( + symmetry: tuple[Symmetry, Symmetry, Symmetry] = Field( (0, 0, 0), title="Symmetry", description="Symmetry eigenvalues of the original simulation in x, y, and z.", ) - symmetry_center: Coordinate = pd.Field( + symmetry_center: Optional[Coordinate] = Field( None, title="Symmetry Center", description="Center of the symmetry planes of the original simulation in x, y, and z. " "Required only if any of the ``symmetry`` field are non-zero.", ) - grid_expanded: Grid = pd.Field( + grid_expanded: Optional[Grid] = Field( None, title="Expanded Grid", description=":class:`.Grid` discretization of the associated monitor in the simulation " @@ -218,27 +235,29 @@ class AbstractFieldData(MonitorData, AbstractFieldDataset, ABC): "well as in order to use some functionalities like getting Poynting vector and flux.", ) - @pd.validator("grid_expanded", always=True) - def warn_missing_grid_expanded(cls, val, values): + @model_validator(mode="after") + def warn_missing_grid_expanded(self) -> Self: """If ``grid_expanded`` not provided and fields data is present, warn that some methods will break.""" field_comps = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] - if val is None and any(values.get(comp) is not None for comp in field_comps): + if self.grid_expanded is None and any( + getattr(self, comp) is not None for comp in field_comps + ): log.warning( "Monitor data requires 'grid_expanded' to be defined to compute values like " "flux, Poynting and dot product with other data." 
) - return val + return self - _require_sym_center = required_if_symmetry_present("symmetry_center") - _require_grid_expanded = required_if_symmetry_present("grid_expanded") + _require_sym_center: Callable[[Any], Any] = required_if_symmetry_present("symmetry_center") + _require_grid_expanded: Callable[[Any], Any] = required_if_symmetry_present("grid_expanded") def _expanded_grid_field_coords(self, field_name: str) -> Coords: """Coordinates in the expanded grid corresponding to a given field component.""" return self.grid_expanded[self.grid_locations[field_name]] @property - def symmetry_expanded(self): + def symmetry_expanded(self) -> Self: """Return the :class:`.AbstractFieldData` with fields expanded based on symmetry. If any symmetry is nonzero (i.e. expanded), the interpolation implicitly creates a copy of the data array. However, if symmetry is not expanded, the returned array contains a view of @@ -256,7 +275,7 @@ def symmetry_expanded(self): return self.updated_copy(**self._symmetry_update_dict, deep=False, validate=False) @property - def symmetry_expanded_copy(self) -> AbstractFieldData: + def symmetry_expanded_copy(self) -> Self: """Create a copy of the :class:`.AbstractFieldData` with fields expanded based on symmetry. Returns @@ -274,7 +293,7 @@ def symmetry_expanded_copy(self) -> AbstractFieldData: def _symmetry_update_dict(self) -> dict: """Dictionary of data fields to create data with expanded symmetry.""" - update_dict = {} + update_dict: dict[str, Optional[tuple[float, float, float], DataArray]] = {} warn_interp = False for field_name, scalar_data in self.field_components.items(): eigenval_fn = self.symmetry_eigenvalues[field_name] @@ -395,13 +414,7 @@ def at_coords(self, coords: Coords) -> xr.Dataset: class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, ABC): """Collection of electromagnetic fields.""" - grid_primal_correction: Union[ - float, - FreqDataArray, - TimeDataArray, - FreqModeDataArray, - EMEFreqModeDataArray, - ] = pd.Field( + grid_primal_correction: GRID_CORRECTION_TYPE = Field( 1.0, title="Field correction factor", description="Correction factor that needs to be applied for data corresponding to a 2D " @@ -409,13 +422,7 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A "which the data was computed. 
The factor is applied to fields defined on the primal grid " "locations along the normal direction.", ) - grid_dual_correction: Union[ - float, - FreqDataArray, - TimeDataArray, - FreqModeDataArray, - EMEFreqModeDataArray, - ] = pd.Field( + grid_dual_correction: GRID_CORRECTION_TYPE = Field( 1.0, title="Field correction factor", description="Correction factor that needs to be applied for data corresponding to a 2D " @@ -424,7 +431,7 @@ class ElectromagneticFieldData(AbstractFieldData, ElectromagneticFieldDataset, A "locations along the normal direction.", ) - def _expanded_grid_field_coords(self, field_name: str): + def _expanded_grid_field_coords(self, field_name: str) -> Coords: """Coordinates in the expanded grid corresponding to a given field component.""" if self.monitor.colocate: bounds_dict = self.grid_expanded.boundaries.to_dict @@ -432,7 +439,7 @@ def _expanded_grid_field_coords(self, field_name: str): return self.grid_expanded[self.grid_locations[field_name]] @property - def _grid_correction_dict(self): + def _grid_correction_dict(self) -> dict[str, GRID_CORRECTION_TYPE]: """Return the primal and dual finite grid correction factors as a dictionary.""" return { "grid_primal_correction": self.grid_primal_correction, @@ -948,7 +955,7 @@ def outer_dot( d_area = self._diff_area.expand_dims(dim={"f": f}, axis=2).to_numpy() # function to apply at each pair of mode indices before integrating - def fn(fields_1, fields_2): + def fn(fields_1: dict[str, NDArray], fields_2: dict[str, NDArray]) -> NDArray: e_self_1 = fields_1[e_1] e_self_2 = fields_1[e_2] h_self_1 = fields_1[h_1] @@ -989,7 +996,7 @@ def _outer_fn_summation( outer_dim_1: str, outer_dim_2: str, sum_dims: list[str], - fn: Callable, + fn: Callable[[dict[str, NDArray], NDArray], NDArray], ) -> DataArray: """ Loop over ``outer_dim_1`` and ``outer_dim_2``, apply ``fn`` to ``fields_1`` and ``fields_2``, and sum over ``sum_dims``. @@ -1167,7 +1174,7 @@ def to_zbf( Returns ------- - Tuple[:class:`.ScalarFieldDataArray`,:class:`.ScalarFieldDataArray`] + tuple[:class:`.ScalarFieldDataArray`,:class:`.ScalarFieldDataArray`] The two E field components being exported to ``.zbf``. """ log.warning( @@ -1349,8 +1356,9 @@ class FieldData(FieldDataset, ElectromagneticFieldData): * `Advanced monitor data manipulation and visualization <../../notebooks/XarrayTutorial.html>`_ """ - monitor: FieldMonitor = pd.Field( - ..., title="Monitor", description="Frequency-domain field monitor associated with the data." + monitor: FieldMonitor = Field( + title="Monitor", + description="Frequency-domain field monitor associated with the data.", ) _contains_monitor_fields = enforce_monitor_fields_present() @@ -1373,9 +1381,9 @@ def to_source( ---------- source_time: :class:`.SourceTime` Specification of the source time-dependence. - center: Tuple[float, float, float] + center: tuple[float, float, float] Source center in x, y and z. - size: Tuple[float, float, float] + size: tuple[float, float, float] Source size in x, y, and z. If not provided, the size of the monitor associated to the data is used. **kwargs @@ -1501,8 +1509,9 @@ class FieldTimeData(FieldTimeDataset, ElectromagneticFieldData): >>> data = FieldTimeData(monitor=monitor, Ex=scalar_field, Hz=scalar_field, grid_expanded=grid) """ - monitor: FieldTimeMonitor = pd.Field( - ..., title="Monitor", description="Time-domain field monitor associated with the data." 
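This monitor field shows the mechanical half of the migration: pydantic v1's pd.Field(..., title=...) becomes a bare Field(title=...), because in pydantic v2 a field with no default is already required and the Ellipsis sentinel is redundant. A minimal sketch of that semantics, assuming pydantic v2 (the model is illustrative, not a tidy3d class):

from pydantic import BaseModel, Field

class MonitorExample(BaseModel):
    monitor: str = Field(title="Monitor")            # no default: still required
    diverged: bool = Field(False, title="Diverged")  # explicit default: optional

MonitorExample(monitor="flux_mnt")  # ok
# MonitorExample() would raise ValidationError: 'monitor' is required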
+ monitor: FieldTimeMonitor = Field( + title="Monitor", + description="Time-domain field monitor associated with the data.", ) _contains_monitor_fields = enforce_monitor_fields_present() @@ -1575,8 +1584,7 @@ class AuxFieldTimeData(AuxFieldTimeDataset, AbstractFieldData): >>> data = AuxFieldTimeData(monitor=monitor, Nfx=scalar_field, grid_expanded=grid) """ - monitor: AuxFieldTimeMonitor = pd.Field( - ..., + monitor: AuxFieldTimeMonitor = Field( title="Monitor", description="Time-domain auxiliary field monitor associated with the data.", ) @@ -1609,8 +1617,9 @@ class PermittivityData(PermittivityDataset, AbstractFieldData): ... ) """ - monitor: PermittivityMonitor = pd.Field( - ..., title="Monitor", description="Permittivity monitor associated with the data." + monitor: PermittivityMonitor = Field( + title="Monitor", + description="Permittivity monitor associated with the data.", ) @@ -1639,8 +1648,8 @@ class MediumData(MediumDataset, AbstractFieldData): ... ) """ - monitor: MediumMonitor = pd.Field( - ..., title="Monitor", description="Medium property monitor associated with the data." + monitor: MediumMonitor = Field( + title="Monitor", description="Medium property monitor associated with the data." ) @@ -1683,34 +1692,38 @@ class ModeData(ModeSolverDataset, ElectromagneticFieldData): >>> data = ModeData(monitor=monitor, amps=amp_data, n_complex=index_data) """ - monitor: ModeMonitor = pd.Field( - ..., title="Monitor", description="Mode monitor associated with the data." + monitor: ModeMonitor = Field( + title="Monitor", + description="Mode monitor associated with the data.", ) - amps: ModeAmpsDataArray = pd.Field( - ..., title="Amplitudes", description="Complex-valued amplitudes associated with the mode." + amps: ModeAmpsDataArray = Field( + title="Amplitudes", + description="Complex-valued amplitudes associated with the mode.", ) - eps_spec: list[EpsSpecType] = pd.Field( + eps_spec: Optional[list[EpsSpecType]] = Field( None, title="Permittivity Specification", description="Characterization of the permittivity profile on the plane where modes are " "computed. Possible values are 'diagonal', 'tensorial_real', 'tensorial_complex'.", ) - @pd.validator("eps_spec", always=True) - @skip_if_fields_missing(["n_complex"]) - def eps_spec_match_mode_spec(cls, val, values): + @model_validator(mode="after") + def eps_spec_match_mode_spec(self) -> Self: """Raise validation error if frequencies in eps_spec does not match frequency list""" + if self.n_complex is None: + return self + val = self.eps_spec if val: - mode_data_freqs = values["n_complex"].coords["f"].values + mode_data_freqs = self.n_complex.coords["f"].values if len(val) != len(mode_data_freqs): raise ValidationError( "eps_spec must be provided at the same frequencies as mode solver data." ) - return val + return self - def normalize(self, source_spectrum_fn) -> ModeData: + def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> Self: """Return copy of self after normalization is applied using source spectrum function.""" source_freq_amps = source_spectrum_fn(self.amps.f)[None, :, None] new_amps = (self.amps / source_freq_amps).astype(self.amps.dtype) @@ -1833,7 +1846,7 @@ def overlap_sort( return data_reordered.updated_copy(monitor=monitor_updated, deep=False, validate=False) - def _isel(self, **isel_kwargs: Any): + def _isel(self, **isel_kwargs: Any) -> Self: """Wraps ``xarray.DataArray.isel`` for all data fields that are defined over frequency and mode index. 
Used in ``overlap_sort`` but not officially supported since for example ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the @@ -1847,12 +1860,11 @@ def _isel(self, **isel_kwargs: Any): } return self.updated_copy(**update_dict, deep=False, validate=False) - def _assign_coords(self, **assign_coords_kwargs: Any): + def _assign_coords(self, **assign_coords_kwargs: Any) -> Self: """Wraps ``xarray.DataArray.assign_coords`` for all data fields that are defined over frequency and mode index. Used in ``overlap_sort`` but not officially supported since for example ``self.monitor.mode_spec`` and ``self.monitor.freqs`` will no longer be matching the newly created data.""" - update_dict = dict(self._grid_correction_dict, **self.field_components) update_dict = { key: field.assign_coords(**assign_coords_kwargs) for key, field in update_dict.items() @@ -1863,7 +1875,7 @@ def _find_ordering_one_freq( self, data_to_sort: ModeData, overlap_thresh: Union[float, np.array], - ) -> tuple[Numpy, Numpy]: + ) -> tuple[np.ndarray, np.ndarray]: """Find new ordering of modes in data_to_sort based on their similarity to own modes.""" num_modes = self.n_complex.sizes["mode_index"] @@ -1899,7 +1911,7 @@ def _find_ordering_one_freq( return pairs, complex_amps @staticmethod - def _find_closest_pairs(arr: Numpy) -> tuple[Numpy, Numpy]: + def _find_closest_pairs(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """Given a complex overlap matrix pair row and column entries.""" n, k = np.shape(arr) @@ -2268,7 +2280,7 @@ def _adjoint_source_amp(self, amp: DataArray, fwidth: float) -> ModeSource: return src_adj - def _apply_mode_reorder(self, sort_inds_2d): + def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self: """Apply a mode reordering along mode_index for all frequency indices. Parameters @@ -2367,7 +2379,7 @@ def sort_modes( sort_inds_2d = np.tile(identity, (num_freqs, 1)) # Helper to compute ordered indices within a subset - def _order_indices(indices, vals_all): + def _order_indices(indices: NDArray, vals_all: DataArray) -> NDArray: if indices.size == 0: return indices vals = vals_all.isel(mode_index=indices) @@ -2475,15 +2487,18 @@ class ModeSolverData(ModeData): ... ) """ - monitor: ModeSolverMonitor = pd.Field( - ..., title="Monitor", description="Mode solver monitor associated with the data." + monitor: ModeSolverMonitor = Field( + title="Monitor", + description="Mode solver monitor associated with the data.", ) - amps: ModeAmpsDataArray = pd.Field( - None, title="Amplitudes", description="Unused for ModeSolverData." 
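The amps rewrite here (old declaration above, new one below) also adds an explicit Optional, and that is not cosmetic: pydantic v1 silently widened a field to Optional when its default was None, whereas v2 requires the annotation to say so. A short sketch under that assumption (illustrative model):

from typing import Optional

from pydantic import BaseModel, Field

class ModeResult(BaseModel):
    # v1 accepted `amps: list[complex] = pd.Field(None, ...)` and inferred Optional;
    # v2 requires the None default to be reflected in the annotation.
    amps: Optional[list[complex]] = Field(None, title="Amplitudes")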
+ amps: Optional[ModeAmpsDataArray] = Field( + None, + title="Amplitudes", + description="Unused for ModeSolverData.", ) - grid_distances_primal: Union[tuple[float], tuple[float, float]] = pd.Field( + grid_distances_primal: Union[tuple[float], tuple[float, float]] = Field( (0.0,), title="Distances to the Primal Grid", description="Relative distances to the primal grid locations along the normal direction in " @@ -2491,7 +2506,7 @@ class ModeSolverData(ModeData): "interpolating in frequency.", ) - grid_distances_dual: Union[tuple[float], tuple[float, float]] = pd.Field( + grid_distances_dual: Union[tuple[float], tuple[float, float]] = Field( (0.0,), title="Distances to the Dual Grid", description="Relative distances to the dual grid locations along the normal direction in " @@ -2499,11 +2514,11 @@ class ModeSolverData(ModeData): "interpolating in frequency.", ) - def normalize(self, source_spectrum_fn: Callable[[float], complex]) -> ModeSolverData: + def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> ModeSolverData: """Return copy of self after normalization is applied using source spectrum function.""" return self.copy() - def _normalize_modes(self): + def _normalize_modes(self) -> None: """Normalize modes. Note: this modifies ``self`` in-place.""" scaling = np.sqrt(np.abs(self.flux)) for field in self.field_components.values(): @@ -2782,12 +2797,14 @@ class FluxData(MonitorData): * `Advanced monitor data manipulation and visualization <../../notebooks/XarrayTutorial.html>`_ """ - monitor: FluxMonitor = pd.Field( - ..., title="Monitor", description="Frequency-domain flux monitor associated with the data." + monitor: FluxMonitor = Field( + title="Monitor", + description="Frequency-domain flux monitor associated with the data.", ) - flux: FluxDataArray = pd.Field( - ..., title="Flux", description="Flux values in the frequency-domain." + flux: FluxDataArray = Field( + title="Flux", + description="Flux values in the frequency-domain.", ) def _make_adjoint_sources( @@ -2808,7 +2825,7 @@ def _make_adjoint_sources( "computation." ) - def normalize(self, source_spectrum_fn) -> FluxData: + def normalize(self, source_spectrum_fn: Callable[[DataArray], NDArray]) -> FluxData: """Return copy of self after normalization is applied using source spectrum function.""" source_freq_amps = source_spectrum_fn(self.flux.f) source_power = abs(source_freq_amps) ** 2 @@ -2836,12 +2853,14 @@ class FluxTimeData(MonitorData): >>> data = FluxTimeData(monitor=monitor, flux=flux_data) """ - monitor: FluxTimeMonitor = pd.Field( - ..., title="Monitor", description="Time-domain flux monitor associated with the data." + monitor: FluxTimeMonitor = Field( + title="Monitor", + description="Time-domain flux monitor associated with the data.", ) - flux: FluxTimeDataArray = pd.Field( - ..., title="Flux", description="Flux values in the time-domain." 
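Field declarations are only half of this file's migration; its validators move too. The eps_spec_match_mode_spec rewrite earlier in the hunk traded @pd.validator("eps_spec", always=True) plus @skip_if_fields_missing for a @model_validator(mode="after") that reads attributes off the validated instance and guards missing fields with an explicit None check. A runnable sketch of that shape (illustrative model, not tidy3d API):

from typing import Optional

from pydantic import BaseModel, model_validator

class ModeExample(BaseModel):
    n_complex: Optional[list[float]] = None
    eps_spec: Optional[list[str]] = None

    @model_validator(mode="after")
    def eps_spec_matches(self) -> "ModeExample":
        # explicit guard replaces v1's @skip_if_fields_missing decorator
        if self.n_complex is None or self.eps_spec is None:
            return self
        if len(self.eps_spec) != len(self.n_complex):
            raise ValueError("eps_spec must match the mode data frequencies")
        return self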
+ flux: FluxTimeDataArray = Field( + title="Flux", + description="Flux values in the time-domain.", ) @@ -2864,52 +2883,45 @@ class FluxTimeData(MonitorData): class AbstractFieldProjectionData(MonitorData): """Collection of projected fields in spherical coordinates in the frequency domain.""" - monitor: ProjMonitorType = pd.Field( - ..., + monitor: ProjMonitorType = Field( title="Projection monitor", description="Field projection monitor.", discriminator=TYPE_TAG_STR, ) - Er: ProjFieldType = pd.Field( - ..., + Er: ProjFieldType = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: ProjFieldType = pd.Field( - ..., + Etheta: ProjFieldType = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: ProjFieldType = pd.Field( - ..., + Ephi: ProjFieldType = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: ProjFieldType = pd.Field( - ..., + Hr: ProjFieldType = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: ProjFieldType = pd.Field( - ..., + Htheta: ProjFieldType = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: ProjFieldType = pd.Field( - ..., + Hphi: ProjFieldType = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) - medium: MediumType = pd.Field( - Medium(), + medium: MediumType = Field( + default_factory=Medium, title="Background Medium", description="Background medium through which to project fields.", discriminator=TYPE_TAG_STR, ) - is_2d_simulation: bool = pd.Field( + is_2d_simulation: bool = Field( False, title="2D Simulation", description="Indicates whether the monitor data is for a 2D simulation.", @@ -3170,45 +3182,37 @@ class FieldProjectionAngleData(AbstractFieldProjectionData): ... ) """ - monitor: FieldProjectionAngleMonitor = pd.Field( - ..., + monitor: FieldProjectionAngleMonitor = Field( title="Projection monitor", description="Field projection monitor with an angle-based projection grid.", ) - projection_surfaces: tuple[FieldProjectionSurface, ...] = pd.Field( - ..., + projection_surfaces: tuple[FieldProjectionSurface, ...] 
= Field( title="Projection surfaces", description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionAngleDataArray = pd.Field( - ..., + Er: FieldProjectionAngleDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionAngleDataArray = pd.Field( - ..., + Etheta: FieldProjectionAngleDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionAngleDataArray = pd.Field( - ..., + Ephi: FieldProjectionAngleDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionAngleDataArray = pd.Field( - ..., + Hr: FieldProjectionAngleDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionAngleDataArray = pd.Field( - ..., + Htheta: FieldProjectionAngleDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionAngleDataArray = pd.Field( - ..., + Hphi: FieldProjectionAngleDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3380,45 +3384,37 @@ class FieldProjectionCartesianData(AbstractFieldProjectionData): ... ) """ - monitor: FieldProjectionCartesianMonitor = pd.Field( - ..., + monitor: FieldProjectionCartesianMonitor = Field( title="Projection monitor", description="Field projection monitor with a Cartesian projection grid.", ) - projection_surfaces: tuple[FieldProjectionSurface, ...] = pd.Field( - ..., + projection_surfaces: tuple[FieldProjectionSurface, ...] = Field( title="Projection surfaces", description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionCartesianDataArray = pd.Field( - ..., + Er: FieldProjectionCartesianDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionCartesianDataArray = pd.Field( - ..., + Etheta: FieldProjectionCartesianDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionCartesianDataArray = pd.Field( - ..., + Ephi: FieldProjectionCartesianDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionCartesianDataArray = pd.Field( - ..., + Hr: FieldProjectionCartesianDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionCartesianDataArray = pd.Field( - ..., + Htheta: FieldProjectionCartesianDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionCartesianDataArray = pd.Field( - ..., + Hphi: FieldProjectionCartesianDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3439,7 +3435,7 @@ def z(self) -> np.ndarray: return self.Etheta.z.values @property - def tangential_dims(self): + def tangential_dims(self) -> list[str]: tangential_dims = ["x", "y", "z"] tangential_dims.pop(self.monitor.proj_axis) return tangential_dims @@ -3533,45 +3529,37 @@ class FieldProjectionKSpaceData(AbstractFieldProjectionData): ... 
) """ - monitor: FieldProjectionKSpaceMonitor = pd.Field( - ..., + monitor: FieldProjectionKSpaceMonitor = Field( title="Projection monitor", description="Field projection monitor with a projection grid defined in k-space.", ) - projection_surfaces: tuple[FieldProjectionSurface, ...] = pd.Field( - ..., + projection_surfaces: tuple[FieldProjectionSurface, ...] = Field( title="Projection surfaces", description="Surfaces of the monitor where near fields were recorded for projection", ) - Er: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Er: FieldProjectionKSpaceDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Etheta: FieldProjectionKSpaceDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Ephi: FieldProjectionKSpaceDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Hr: FieldProjectionKSpaceDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Htheta: FieldProjectionKSpaceDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: FieldProjectionKSpaceDataArray = pd.Field( - ..., + Hphi: FieldProjectionKSpaceDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) @@ -3673,50 +3661,43 @@ class DiffractionData(AbstractFieldProjectionData): ... ) """ - monitor: DiffractionMonitor = pd.Field( - ..., title="Monitor", description="Diffraction monitor associated with the data." 
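One more v2 idiom from this file: AbstractFieldProjectionData.medium above swapped a constructed default, Medium(), for default_factory=Medium, so the default instance is built and validated when each model is created rather than once at class-definition time. Sketch with illustrative models:

from pydantic import BaseModel, Field

class Background(BaseModel):
    name: str = "vacuum"

class ProjectionExample(BaseModel):
    # the factory defers construction until instantiation
    medium: Background = Field(default_factory=Background)

assert ProjectionExample().medium.name == "vacuum"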
+ monitor: DiffractionMonitor = Field( + title="Monitor", + description="Diffraction monitor associated with the data.", ) - Er: DiffractionDataArray = pd.Field( - ..., + Er: DiffractionDataArray = Field( title="Er", description="Spatial distribution of r-component of the electric field.", ) - Etheta: DiffractionDataArray = pd.Field( - ..., + Etheta: DiffractionDataArray = Field( title="Etheta", description="Spatial distribution of the theta-component of the electric field.", ) - Ephi: DiffractionDataArray = pd.Field( - ..., + Ephi: DiffractionDataArray = Field( title="Ephi", description="Spatial distribution of phi-component of the electric field.", ) - Hr: DiffractionDataArray = pd.Field( - ..., + Hr: DiffractionDataArray = Field( title="Hr", description="Spatial distribution of r-component of the magnetic field.", ) - Htheta: DiffractionDataArray = pd.Field( - ..., + Htheta: DiffractionDataArray = Field( title="Htheta", description="Spatial distribution of theta-component of the magnetic field.", ) - Hphi: DiffractionDataArray = pd.Field( - ..., + Hphi: DiffractionDataArray = Field( title="Hphi", description="Spatial distribution of phi-component of the magnetic field.", ) - sim_size: tuple[float, float] = pd.Field( - ..., + sim_size: tuple[float, float] = Field( title="Domain size", description="Size of the near field in the local x and y directions.", units=MICROMETER, ) - bloch_vecs: Union[tuple[float, float], tuple[ArrayFloat1D, ArrayFloat1D]] = pd.Field( - ..., + bloch_vecs: Union[tuple[float, float], tuple[ArrayFloat1D, ArrayFloat1D]] = Field( title="Bloch vectors", description="Bloch vectors along the local x and y directions in units of " "``2 * pi / (simulation size along the respective dimension)``.", @@ -4022,14 +4003,12 @@ class DirectivityData(FieldProjectionAngleData): ... 
Hr=scalar_field, Htheta=scalar_field, Hphi=scalar_field, projection_surfaces=monitor.projection_surfaces) """ - monitor: DirectivityMonitor = pd.Field( - ..., + monitor: DirectivityMonitor = Field( title="Monitor", description="Monitor describing the angle-based projection grid on which to measure directivity data.", ) - flux: FluxDataArray = pd.Field( - ..., + flux: FluxDataArray = Field( title="Flux", description="Flux values that are either computed from fields recorded on the " "projection surfaces or by integrating the projected fields over a spherical surface.", diff --git a/tidy3d/components/data/sim_data.py b/tidy3d/components/data/sim_data.py index fdb0c6b552..af5f7f04b0 100644 --- a/tidy3d/components/data/sim_data.py +++ b/tidy3d/components/data/sim_data.py @@ -7,40 +7,48 @@ import re from abc import ABC from collections import defaultdict -from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Union import h5py import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field from tidy3d.components.autograd.utils import split_list from tidy3d.components.base import JSON_TAG, Tidy3dBaseModel, cached_property from tidy3d.components.base_sim.data.sim_data import AbstractSimulationData -from tidy3d.components.file_util import replace_values -from tidy3d.components.monitor import Monitor from tidy3d.components.simulation import Simulation from tidy3d.components.source.current import CustomCurrentSource from tidy3d.components.source.time import GaussianPulse from tidy3d.components.source.utils import SourceType from tidy3d.components.structure import Structure -from tidy3d.components.types import Ax, Axis, ColormapType, FieldVal, PlotScale, annotate_type +from tidy3d.components.types.base import discriminated_union from tidy3d.components.types.monitor_data import MonitorDataType, MonitorDataTypes from tidy3d.components.viz import add_ax_if_none, equal_aspect -from tidy3d.exceptions import DataError, FileError, SetupError, Tidy3dKeyError +from tidy3d.exceptions import DataError, SetupError, Tidy3dKeyError from tidy3d.log import log from .data_array import FreqDataArray, TimeDataArray from .monitor_data import AbstractFieldData, FieldTimeData if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Optional + from matplotlib.colors import Colormap + from numpy.typing import NDArray + + from tidy3d.components.monitor import Monitor + from tidy3d.components.types import Ax, Axis, ColormapType, FieldVal, PlotScale + + from .data_array import DataArray -DATA_TYPE_MAP = {data.__fields__["monitor"].type_: data for data in MonitorDataTypes} +DATA_TYPE_MAP = {data.model_fields["monitor"].annotation: data for data in MonitorDataTypes} # maps monitor type (string) to the class of the corresponding data -DATA_TYPE_NAME_MAP = {val.__fields__["monitor"].type_.__name__: val for val in MonitorDataTypes} +DATA_TYPE_NAME_MAP = { + val.model_fields["monitor"].annotation.__name__: val for val in MonitorDataTypes +} # residuals below this are considered good fits for broadband adjoint source creation RESIDUAL_CUTOFF_ADJOINT = 1e-6 @@ -55,21 +63,18 @@ class AdjointSourceInfo(Tidy3dBaseModel): """Stores information about the adjoint sources to pass to autograd pipeline.""" - sources: tuple[annotate_type(SourceType), ...] = pd.Field( - ..., + sources: tuple[discriminated_union(SourceType), ...] 
= Field( title="Adjoint Sources", description="Set of processed sources to include in the adjoint simulation.", ) - post_norm: Union[float, FreqDataArray] = pd.Field( - ..., + post_norm: Union[float, FreqDataArray] = Field( title="Post Normalization Values", description="Factor to multiply the adjoint fields by after running " "given the adjoint source pipeline used.", ) - normalize_sim: bool = pd.Field( - ..., + normalize_sim: bool = Field( title="Normalize Adjoint Simulation", description="Whether the adjoint simulation needs to be normalized " "given the adjoint source pipeline used.", @@ -246,7 +251,7 @@ def _get_scalar_field( field_name: str, val: FieldVal, phase: float = 0.0, - ): + ) -> xr.DataArray: """return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers. Parameters @@ -277,7 +282,7 @@ def _get_scalar_field_from_data( field_name: str, val: FieldVal, phase: float = 0.0, - ): + ) -> xr.DataArray: """return ``xarray.DataArray`` of the scalar field of a given monitor at Yee cell centers. Parameters @@ -348,7 +353,6 @@ def _get_scalar_field_from_data( f"'val' of {val} not supported. " "Must be one of 'real', 'imag', 'abs', 'abs^2', or 'phase'." ) - return derived_data raise Tidy3dKeyError( @@ -376,7 +380,7 @@ def get_intensity(self, field_monitor_name: str) -> xr.DataArray: @classmethod def mnt_data_from_file( - cls, fname: PathLike, mnt_name: str, **parse_obj_kwargs: Any + cls, fname: PathLike, mnt_name: str, **model_validate_kwargs: Any ) -> MonitorDataType: """Loads data for a specific monitor from a .hdf5 file with data for a ``SimulationData``. @@ -386,8 +390,8 @@ def mnt_data_from_file( Full path to an hdf5 file containing :class:`.SimulationData` data. mnt_name : str, optional ``.name`` of the monitor to load the data from. - **parse_obj_kwargs - Keyword arguments passed to either pydantic's ``parse_obj`` function when loading model. + **model_validate_kwargs + Keyword arguments passed to pydantic's ``model_validate`` method when loading model. Returns ------- @@ -428,7 +432,7 @@ def mnt_data_from_file( # load the monitor data from the file using the group_path group_path = f"data/{monitor_index_str}" return monitor_data_type.from_file( - fname, group_path=group_path, **parse_obj_kwargs + fname, group_path=group_path, **model_validate_kwargs ) raise ValueError(f"No monitor with name '{mnt_name}' found in data file.") @@ -546,7 +550,7 @@ def plot_field_monitor_data( ("E", "abs^2"): 10, ("H", "abs^2"): 10, }.get((field_name[0], val), 20) - field_data = db_factor * np.log10(np.abs(field_data)) + field_data = self._apply_log_scale(field_data, vmin=vmin, db_factor=db_factor) field_data.name += " (dB)" cmap_type = "sequential" elif scale == "lin": @@ -943,20 +947,18 @@ class SimulationData(AbstractYeeGridSimulationData): """ - simulation: Simulation = pd.Field( - ..., + simulation: Simulation = Field( title="Simulation", description="Original :class:`.Simulation` associated with the data.", ) - data: tuple[annotate_type(MonitorDataType), ...] = pd.Field( - ..., + data: tuple[discriminated_union(MonitorDataType), ...] 
= Field( title="Monitor Data", description="List of :class:`.MonitorData` instances " "associated with the monitors of the original :class:`.Simulation`.", ) - diverged: bool = pd.Field( + diverged: bool = Field( False, title="Diverged", description="A boolean flag denoting whether the simulation run diverged.", @@ -998,7 +1000,7 @@ def source_spectrum(self, source_index: int) -> Callable: dt = self.simulation.dt # plug in mornitor_data frequency domain information - def source_spectrum_fn(freqs): + def source_spectrum_fn(freqs: DataArray) -> NDArray: """Source amplitude as function of frequency.""" spectrum = source_time.spectrum(times, freqs, dt) @@ -1026,14 +1028,14 @@ def renormalize(self, normalize_index: int) -> SimulationData: f"of length {num_sources}" ) - def source_spectrum_fn(freqs): + def source_spectrum_fn(freqs: DataArray) -> NDArray: """Normalization function that also removes previous normalization if needed.""" new_spectrum_fn = self.source_spectrum(normalize_index) old_spectrum_fn = self.source_spectrum(self.simulation.normalize_index) return new_spectrum_fn(freqs) / old_spectrum_fn(freqs) # Make a new monitor_data dictionary with renormalized data - data_normalized = [mnt_data.normalize(source_spectrum_fn) for mnt_data in self.data] + data_normalized = tuple(mnt_data.normalize(source_spectrum_fn) for mnt_data in self.data) simulation = self.simulation.copy(update={"normalize_index": normalize_index}) @@ -1099,6 +1101,9 @@ def _make_adjoint_sims( for src_list in sources_adj_dict.values(): adj_srcs += list(src_list) + if not adj_srcs: + return [] + adjoint_source_infos = self._process_adjoint_sources(adj_srcs=adj_srcs) if not adjoint_source_infos: @@ -1144,7 +1149,7 @@ def _make_adjoint_sims( return adj_sims - def _make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceType]: + def _make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, list[SourceType]]: """Generate all of the non-zero sources for the adjoint simulation given the VJP data.""" # map of index into 'self.data' to the list of datasets we need adjoint sources for @@ -1159,10 +1164,11 @@ def _make_adjoint_sources(self, data_vjp_paths: set[tuple]) -> dict[str, SourceT sources_adj = mnt_data._make_adjoint_sources( dataset_names=dataset_names, fwidth=self._fwidth_adj ) - sources_adj_all[mnt_data.monitor.name] = sources_adj log.info( f"Created {len(sources_adj)} adjoint sources for monitor '{mnt_data.monitor.name}'." ) + if sources_adj: + sources_adj_all[mnt_data.monitor.name] = sources_adj return sources_adj_all @@ -1274,7 +1280,7 @@ def _process_adjoint_sources_broadband( return [src_broadband], post_norm_amps @staticmethod - def _adjoint_src_width_broadband(adj_srcs: list[SourceType]) -> float: + def _adjoint_src_width_broadband(adj_srcs: list[SourceType]) -> tuple[float, float]: """Find the adjoint source fwidth that sufficiently covers all adjoint frequencies.""" adj_srcs_f0 = [adj_src.source_time._freq0 for adj_src in adj_srcs] @@ -1337,50 +1343,3 @@ def _get_adjoint_data(self, structure_index: int, data_type: str) -> MonitorData monitor_name = Structure._get_monitor_name(index=structure_index, data_type=data_type) return self[monitor_name] - - def to_mat_file(self, fname: PathLike, **kwargs: Any) -> None: - """Output the ``SimulationData`` object as ``.mat`` MATLAB file. - - Parameters - ---------- - fname : PathLike - Full path to the output file. Should include ``.mat`` file extension. 
- **kwargs : dict, optional - Extra arguments to ``scipy.io.savemat``: see ``scipy`` documentation for more detail. - - Example - ------- - >>> simData.to_mat_file('/path/to/file/data.mat') # doctest: +SKIP - """ - # Check .mat file extension is given - extension = pathlib.Path(fname).suffixes[0].lower() - if len(extension) == 0: - raise FileError(f"File '{fname}' missing extension.") - if extension != ".mat": - raise FileError(f"File '{fname}' should have a .mat extension.") - - # Handle m_dict in kwargs - if "m_dict" in kwargs: - raise ValueError( - "'m_dict' is automatically determined by 'to_mat_file', can't pass to 'savemat'." - ) - - # Get SimData object as dictionary - sim_dict = self.dict() - - # set long field names true by default, otherwise it wont save fields with > 31 characters - if "long_field_names" not in kwargs: - kwargs["long_field_names"] = True - - # Remove NoneType values from dict - # Built from theory discussed in https://github.com/scipy/scipy/issues/3488 - modified_sim_dict = replace_values(sim_dict, None, []) - - try: - from scipy.io import savemat - - savemat(fname, modified_sim_dict, **kwargs) - except Exception as e: - raise ValueError( - "Could not save supplied 'SimulationData' to file. As this is an experimental feature, we may not be able to support the contents of your dataset. If you receive this error, please feel free to raise an issue on our front end repository so we can investigate." - ) from e diff --git a/tidy3d/components/data/unstructured/base.py b/tidy3d/components/data/unstructured/base.py index 44ba534833..431851eb50 100644 --- a/tidy3d/components/data/unstructured/base.py +++ b/tidy3d/components/data/unstructured/base.py @@ -4,14 +4,14 @@ import numbers from abc import ABC, abstractmethod -from os import PathLike -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any import numpy as np -import pydantic.v1 as pd +from pandas import RangeIndex +from pydantic import Field, field_validator, model_validator from xarray import DataArray as XrDataArray -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import cached_property from tidy3d.components.data.data_array import ( DATA_ARRAY_MAP, CellDataArray, @@ -21,12 +21,30 @@ SpatialDataArray, ) from tidy3d.components.data.dataset import Dataset -from tidy3d.components.types import ArrayLike, Axis, Bound from tidy3d.constants import inf from tidy3d.exceptions import DataError, Tidy3dNotImplementedError, ValidationError from tidy3d.log import log from tidy3d.packaging import requires_vtk, vtk +if TYPE_CHECKING: + from os import PathLike + from typing import Literal, Optional, Union + + from numpy.typing import DTypeLike, NDArray + from pydantic import PositiveInt + from vtkmodules.vtkCommonCore import vtkPoints + from vtkmodules.vtkCommonDataModel import ( + vtkCellArray, + vtkDataSet, + vtkPointData, + vtkPolyData, + vtkUnstructuredGrid, + ) + + from tidy3d.compat import Self + from tidy3d.components.data.data_array import DataArray + from tidy3d.components.types.base import ArrayLike, Axis, Bound + DEFAULT_MAX_SAMPLES_PER_STEP = 10_000 DEFAULT_MAX_CELLS_PER_STEP = 10_000 DEFAULT_TOLERANCE_CELL_FINDING = 1e-6 @@ -35,20 +53,17 @@ class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC): """Abstract base for datasets that store unstructured grid data.""" - points: PointDataArray = pd.Field( - ..., + points: PointDataArray = Field( title="Grid Points", description="Coordinates of points 
composing the unstructured grid.", ) - values: IndexedDataArrayTypes = pd.Field( - ..., + values: IndexedDataArrayTypes = Field( title="Point Values", description="Values stored at the grid points.", ) - cells: CellDataArray = pd.Field( - ..., + cells: CellDataArray = Field( title="Grid Cells", description="Cells composing the unstructured grid specified as connections between grid " "points.", @@ -58,18 +73,19 @@ class UnstructuredGridDataset(Dataset, np.lib.mixins.NDArrayOperatorsMixin, ABC) @classmethod @abstractmethod - def _point_dims(cls) -> pd.PositiveInt: + def _point_dims(cls) -> PositiveInt: """Dimensionality of stored grid point coordinates.""" @classmethod @abstractmethod - def _cell_num_vertices(cls) -> pd.PositiveInt: + def _cell_num_vertices(cls) -> PositiveInt: """Number of vertices in a cell.""" """ Validators """ - @pd.validator("points", always=True) - def points_right_dims(cls, val): + @field_validator("points") + @classmethod + def points_right_dims(cls, val: PointDataArray) -> PointDataArray: """Check that point coordinates have the right dimensionality.""" # currently support only the standard axis ordering, that is 01(2) axis_coords_expected = np.arange(cls._point_dims()) @@ -81,8 +97,9 @@ def points_right_dims(cls, val): ) return val - @pd.validator("points", always=True) - def points_right_indexing(cls, val): + @field_validator("points") + @classmethod + def points_right_indexing(cls, val: PointDataArray) -> PointDataArray: """Check that points are indexed corrrectly.""" indices_expected = np.arange(len(val.data)) indices_given = val.index.data @@ -94,15 +111,17 @@ def points_right_indexing(cls, val): ) return val - @pd.validator("values", always=True) - def first_values_dim_is_index(cls, val): + @field_validator("values") + @classmethod + def first_values_dim_is_index(cls, val: IndexedDataArrayTypes) -> IndexedDataArrayTypes: """Check that the number of data values matches the number of grid points.""" if val.dims[0] != "index": raise ValidationError("First dimension of array 'values' must be 'index'.") return val - @pd.validator("values", always=True) - def values_right_indexing(cls, val): + @field_validator("values") + @classmethod + def values_right_indexing(cls, val: IndexedDataArrayTypes) -> IndexedDataArrayTypes: """Check that data values are indexed correctly.""" # currently support only simple ordered indexing of points, that is, 0, 1, 2, ... indices_expected = np.arange(len(val.index.data)) @@ -115,24 +134,22 @@ def values_right_indexing(cls, val): ) return val - @pd.root_validator(skip_on_failure=True) - def number_of_values_matches_points(cls, values): + @model_validator(mode="after") + def number_of_values_matches_points(self) -> Self: """Check that the number of data values matches the number of grid points.""" - points = values.get("points") - vals = values.get("values") + num_values = len(self.values.index) + num_points = len(self.points) - if points is not None and vals is not None: - num_points = len(points) - num_values = len(vals.index) - if num_points != num_values: - raise ValidationError( - f"The number of data values ({num_values}) does not match the number of grid " - f"points ({num_points})." - ) - return values + if num_points != num_values: + raise ValidationError( + f"The number of data values ({num_values}) does not match the number of grid " + f"points ({num_points})." 
+ ) + return self - @pd.validator("cells", always=True) - def match_cells_to_vtk_type(cls, val): + @field_validator("cells") + @classmethod + def match_cells_to_vtk_type(cls, val: CellDataArray) -> CellDataArray: """Check that cell connections does not have duplicate points.""" if vtk is None: return val @@ -140,8 +157,9 @@ def match_cells_to_vtk_type(cls, val): # using val.astype(np.int32/64) directly causes issues when dataarray are later checked == return CellDataArray(val.data.astype(vtk["id_type"], copy=False), coords=val.coords) - @pd.validator("cells", always=True) - def cells_right_type(cls, val): + @field_validator("cells") + @classmethod + def cells_right_type(cls, val: CellDataArray) -> CellDataArray: """Check that cell are of the right type.""" # only supporting the standard ordering of cell vertices 012(3) vertex_coords_expected = np.arange(cls._cell_num_vertices()) @@ -153,18 +171,19 @@ def cells_right_type(cls, val): ) return val - @pd.validator("cells", always=True) - @skip_if_fields_missing(["points"]) - def check_cell_vertex_range(cls, val, values): + @model_validator(mode="after") + def check_cell_vertex_range(self) -> Self: """Check that cell connections use only defined points.""" + val = getattr(self, "cells", None) + if val is None: + return self all_point_indices_used = val.data.ravel() # skip validation if zero size data if len(all_point_indices_used) > 0: min_index_used = np.min(all_point_indices_used) max_index_used = np.max(all_point_indices_used) - points = values.get("points") - num_points = len(points) + num_points = len(self.points) if max_index_used > num_points - 1 or min_index_used < 0: raise ValidationError( @@ -172,10 +191,11 @@ def check_cell_vertex_range(cls, val, values): f"[{min_index_used}, {max_index_used}]. The valid range of point indices is " f"[0, {num_points - 1}]." ) - return val + return self - @pd.validator("cells", always=True) - def warn_degenerate_cells(cls, val): + @field_validator("cells") + @classmethod + def warn_degenerate_cells(cls, val: CellDataArray) -> CellDataArray: """Check that cell connections does not have duplicate points.""" degenerate_cells = cls._find_degenerate_cells(val) num_degenerate_cells = len(degenerate_cells) @@ -188,34 +208,58 @@ def warn_degenerate_cells(cls, val): ) return val - @pd.root_validator(pre=True, allow_reuse=True) - def _warn_if_none(cls, values): + @model_validator(mode="before") + @classmethod + def _warn_if_none(cls, data: Any) -> Any: """Warn if any of data arrays are not loaded.""" + if not isinstance(data, dict): + return data # already validated + no_data_fields = [] for field_name in ["points", "cells", "values"]: - field = values.get(field_name) + field = data.get(field_name) if isinstance(field, str) and field in DATA_ARRAY_MAP.keys(): no_data_fields.append(field_name) + if len(no_data_fields) > 0: formatted_names = [f"'{fname}'" for fname in no_data_fields] log.warning( f"Loading {', '.join(formatted_names)} without data. Constructing an empty dataset." 
) - values["points"] = PointDataArray( + data["points"] = PointDataArray( np.zeros((0, cls._point_dims())), dims=["index", "axis"] ) - values["cells"] = CellDataArray( + data["cells"] = CellDataArray( np.zeros((0, cls._cell_num_vertices())), dims=["cell_index", "vertex_index"] ) - values["values"] = IndexedDataArray(np.zeros(0), dims=["index"]) - return values + data["values"] = IndexedDataArray(np.zeros(0), dims=["index"]) - @pd.root_validator(skip_on_failure=True, allow_reuse=True) - def _warn_unused_points(cls, values): + return data + + @model_validator(mode="before") + @classmethod + def _add_default_coords(cls, data: dict) -> dict: + def _add_default_coords(da: DataArray) -> DataArray: + """Add 0..N-1 coordinates to any dimension that does not already have one. + Note: We use a pandas `RangeIndex` here for constant memory. + """ + missing = {d: RangeIndex(da.sizes[d]) for d in da.dims if d not in da.coords} + return da.assign_coords(missing) if missing else da + + if "points" in data: + data["points"] = _add_default_coords(data["points"]) + if "cells" in data: + data["cells"] = _add_default_coords(data["cells"]) + if "values" in data: + data["values"] = _add_default_coords(data["values"]) + return data + + @model_validator(mode="after") + def _warn_unused_points(self) -> Self: """Warn if some points are unused.""" - point_indices = set(np.arange(len(values["points"].data))) - used_indices = set(values["cells"].values.ravel()) + point_indices = set(np.arange(len(self.points.data))) + used_indices = set(self.cells.values.ravel()) if not point_indices.issubset(used_indices): log.warning( @@ -223,7 +267,7 @@ def _warn_unused_points(cls, values): "Consider calling 'clean()' to remove them." ) - return values + return self """ Convenience properties """ @@ -243,34 +287,34 @@ def is_complex(self) -> bool: return np.iscomplexobj(self.values) @property - def _double_type(self): + def _double_type(self) -> DTypeLike: """Corresponding double data type.""" return np.complex128 if self.is_complex else np.float64 @property - def is_uniform(self): + def is_uniform(self) -> bool: """Whether each element is of equal value in ``values``.""" return self.values.is_uniform @cached_property - def _values_coords_dict(self): + def _values_coords_dict(self) -> dict[str, Any]: """Non-spatial dimensions are corresponding coordinate values of stored data.""" coord_dict = {dim: self.values.coords[dim].data for dim in self.values.dims} _ = coord_dict.pop("index") return coord_dict @cached_property - def _fields_shape(self): + def _fields_shape(self) -> list[int]: """Shape in which fields are stored.""" return [len(coord) for coord in self._values_coords_dict.values()] @cached_property - def _num_fields(self): + def _num_fields(self) -> int: """Total number of stored fields.""" return 1 if len(self._fields_shape) == 0 else np.prod(self._fields_shape) @cached_property - def _values_type(self): + def _values_type(self) -> type: """Type of array storing values.""" return type(self.values) @@ -287,7 +331,7 @@ def _points_3d_array(self) -> None: """ Grid cleaning """ @classmethod - def _find_degenerate_cells(cls, cells: CellDataArray): + def _find_degenerate_cells(cls, cells: CellDataArray) -> set[int]: """Find explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. 
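For reference, the duplicate-vertex scan implemented just below can be sketched standalone as follows, assuming a plain (n_cells, n_vertices) integer connectivity array in place of a CellDataArray; the function name and data are illustrative only, not part of the patch:

import numpy as np

def find_degenerate_cells(cells: np.ndarray) -> set[int]:
    """Return indices of cells that repeat a point index across their vertices."""
    degenerate: set[int] = set()
    n_vertices = cells.shape[1]
    for i in range(n_vertices - 1):
        for j in range(i + 1, n_vertices):
            degenerate |= {int(k) for k in np.where(cells[:, i] == cells[:, j])[0]}
    return degenerate

# the middle triangle collapses two of its vertices onto point 1
cells = np.array([[0, 1, 2], [1, 1, 3], [2, 3, 4]])
assert find_degenerate_cells(cells) == {1}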
""" @@ -297,14 +341,13 @@ def _find_degenerate_cells(cls, cells: CellDataArray): if len(indices) > 0: for i in range(cls._cell_num_vertices() - 1): for j in range(i + 1, cls._cell_num_vertices()): - degenerate_cell_inds = degenerate_cell_inds.union( - np.where(indices[:, i] == indices[:, j])[0] - ) + new_inds = np.where(indices[:, i] == indices[:, j])[0] + degenerate_cell_inds |= {int(k) for k in new_inds} return degenerate_cell_inds @classmethod - def _remove_degenerate_cells(cls, cells: CellDataArray): + def _remove_degenerate_cells(cls, cells: CellDataArray) -> CellDataArray: """Remove explicitly degenerate cells if any. That is, cells that use the same point indices for their different vertices. """ @@ -320,7 +363,7 @@ def _remove_degenerate_cells(cls, cells: CellDataArray): @classmethod def _remove_unused_points( cls, points: PointDataArray, values: IndexedDataArrayTypes, cells: CellDataArray - ): + ) -> tuple[PointDataArray, IndexedDataArrayTypes, CellDataArray]: """Remove unused points if any. That is, points that are not used in any grid cell. """ @@ -343,7 +386,9 @@ def _remove_unused_points( return points, values, cells - def clean(self, remove_degenerate_cells=True, remove_unused_points=True): + def clean( + self, remove_degenerate_cells: bool = True, remove_unused_points: bool = True + ) -> Self: """Remove degenerate cells and/or unused points.""" if remove_degenerate_cells: cells = self._remove_degenerate_cells(cells=self.cells) @@ -360,7 +405,9 @@ def clean(self, remove_degenerate_cells=True, remove_unused_points=True): """ Arithmetic operations """ - def __array_ufunc__(self, ufunc, method, *inputs: Any, **kwargs: Any): + def __array_ufunc__( + self, ufunc: np.ufunc, method: str, *inputs: Union[Self, numbers.Number], **kwargs: Any + ) -> Optional[Union[Self, tuple[Self, ...]]]: """Override of numpy functions.""" out = kwargs.get("out", ()) @@ -395,7 +442,7 @@ def __array_ufunc__(self, ufunc, method, *inputs: Any, **kwargs: Any): return self.updated_copy(values=result) @property - def real(self) -> UnstructuredGridDataset: + def real(self) -> Self: """Real part of dataset.""" return self.updated_copy(values=self.values.real) @@ -428,7 +475,7 @@ def _vtk_offsets(self) -> ArrayLike: @property @requires_vtk - def _vtk_cells(self): + def _vtk_cells(self) -> vtkCellArray: """VTK cell array to use in the VTK representation.""" cells = vtk["mod"].vtkCellArray() cells.SetData( @@ -439,7 +486,7 @@ def _vtk_cells(self): @property @requires_vtk - def _vtk_points(self): + def _vtk_points(self) -> vtkPoints: """VTK point array to use in the VTK representation.""" pts = vtk["mod"].vtkPoints() pts.SetData(vtk["numpy_to_vtk"](self._points_3d_array)) @@ -447,7 +494,7 @@ def _vtk_points(self): @property @requires_vtk - def _vtk_obj(self): + def _vtk_obj(self) -> vtkUnstructuredGrid: """A VTK representation (vtkUnstructuredGrid) of the grid.""" grid = vtk["mod"].vtkUnstructuredGrid() @@ -475,7 +522,7 @@ def _vtk_obj(self): @staticmethod @requires_vtk - def _read_vtkUnstructuredGrid(fname: PathLike): + def _read_vtkUnstructuredGrid(fname: PathLike) -> vtkUnstructuredGrid: """Load a :class:`vtkUnstructuredGrid` from a file.""" fname = str(fname) reader = vtk["mod"].vtkXMLUnstructuredGridReader() @@ -487,7 +534,7 @@ def _read_vtkUnstructuredGrid(fname: PathLike): @staticmethod @requires_vtk - def _read_vtkLegacyFile(fname: PathLike): + def _read_vtkLegacyFile(fname: PathLike) -> vtkUnstructuredGrid: """Load a grid from a legacy `.vtk` file.""" fname = str(fname) reader = 
vtk["mod"].vtkGenericDataObjectReader() @@ -502,20 +549,20 @@ def _read_vtkLegacyFile(fname: PathLike): @requires_vtk def _from_vtk_obj( cls, - vtk_obj, + vtk_obj: vtkUnstructuredGrid, field: Optional[str] = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, - values_type=IndexedDataArray, - expect_complex=None, - ignore_invalid_cells=False, + values_type: type = IndexedDataArray, + expect_complex: Optional[bool] = None, + ignore_invalid_cells: bool = False, ) -> UnstructuredGridDataset: """Initialize from a vtk object.""" @requires_vtk def _from_vtk_obj_internal( self, - vtk_obj, + vtk_obj: vtkUnstructuredGrid, remove_degenerate_cells: bool = True, remove_unused_points: bool = True, ) -> UnstructuredGridDataset: @@ -628,8 +675,8 @@ def to_vtu(self, fname: PathLike) -> None: @requires_vtk def _cell_to_point_data( cls, - vtk_obj, - ): + vtk_obj: vtkCellArray, + ) -> vtkPointData: """Get point data values from a VTK object.""" cellDataToPointData = vtk["mod"].vtkCellDataToPointData() @@ -642,11 +689,11 @@ def _cell_to_point_data( @requires_vtk def _get_values_from_vtk( cls, - vtk_obj, - num_points: pd.PositiveInt, + vtk_obj: vtkDataSet, + num_points: PositiveInt, field: Optional[str] = None, - values_type=IndexedDataArray, - expect_complex=None, + values_type: type = IndexedDataArray, + expect_complex: Optional[bool] = None, ) -> IndexedDataArray: """Get point data values from a VTK object.""" @@ -682,7 +729,7 @@ def _get_values_from_vtk( f" of grid points ({num_points})." ) - values_numpy = vtk["vtk_to_numpy"](array_vtk) + values_numpy = np.array(vtk["vtk_to_numpy"](array_vtk), copy=True) values_name = array_vtk.GetName() # vtk doesn't support complex numbers @@ -713,7 +760,7 @@ def _get_values_from_vtk( return values - def get_cell_values(self, **kwargs: Any): + def get_cell_values(self, **kwargs: Any) -> NDArray: """This function returns the cell values for the fields stored in the UnstructuredGridDataset. If multiple fields are stored per point, like in an IndexedVoltageDataArray, cell values will be provided for each of the fields unless a selection argument is provided, e.g., voltage=0.2 @@ -738,7 +785,7 @@ def get_cell_volumes(self) -> None: """ Grid operations """ @requires_vtk - def _plane_slice_raw(self, axis: Axis, pos: float): + def _plane_slice_raw(self, axis: Axis, pos: float) -> vtkPolyData: """Slice data with a plane and return the resulting VTK object.""" if pos > self.bounds[1][axis] or pos < self.bounds[0][axis]: @@ -800,7 +847,7 @@ def box_clip(self, bounds: Bound) -> UnstructuredGridDataset: Parameters ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] + bounds : tuple[float, float, float], tuple[float, float float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. Returns @@ -962,7 +1009,12 @@ def interp( return result - def _non_spatial_interp(self, method="linear", fill_value=np.nan, **coords_kwargs: Any): + def _non_spatial_interp( + self, + method: Literal["linear", "nearest"] = "linear", + fill_value: Union[float, Literal["extrapolate"]] = np.nan, + **coords_kwargs: Any, + ) -> Self: """Interpolate data at non-spatial dimensions using xarray's interp() function. 
Parameters
@@ -988,7 +1040,7 @@ def _non_spatial_interp(self, method="linear", fill_value=np.nan, **coords_kwarg
         return self.updated_copy(
             values=self.values.interp(
                 **coords_kwargs_only_lists,
-                method="linear",
+                method=method,
                 kwargs={"fill_value": fill_value},
             )
         )
@@ -1208,12 +1260,17 @@ def _interp_vtk(
         array_id = 0 if self.values.name is None else self.values.name

         # TODO: generalize this
-        values_numpy = vtk["vtk_to_numpy"](interpolated.GetPointData().GetAbstractArray(array_id))
+        values_numpy = np.array(
+            vtk["vtk_to_numpy"](interpolated.GetPointData().GetAbstractArray(array_id)), copy=True
+        )

         # fill points without interpolated values
         if fill_value != 0:
-            mask = vtk["vtk_to_numpy"](
-                interpolated.GetPointData().GetAbstractArray("vtkValidPointMask")
+            mask = np.array(
+                vtk["vtk_to_numpy"](
+                    interpolated.GetPointData().GetAbstractArray("vtkValidPointMask")
+                ),
+                copy=True,
             )
             values_numpy[mask != 1] = fill_value

@@ -1471,7 +1528,7 @@ def _interp_py_chunk(

         Parameters
         ----------
-        xyz_grid : Tuple[ArrayLike[float], ...]
+        xyz_grid : tuple[ArrayLike[float], ...]
            x, y, and z coordiantes defining rectilinear grid.
        cell_inds : ArrayLike[int]
            Indices of cells to perfrom interpolation from.
@@ -1484,7 +1541,7 @@ def _interp_py_chunk(

        Returns
        -------
-        Tuple[Tuple[ArrayLike, ...], ArrayLike]
+        tuple[tuple[ArrayLike, ...], ArrayLike]
            x, y, and z indices of interpolated values and values themselves.
@@ -1732,13 +1789,15 @@ def sel(

     def _non_spatial_sel(
         self,
-        method=None,
+        method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None,
         **sel_kwargs: Any,
     ) -> XrDataArray:
         """Select/interpolate data along one or more non-Cartesian directions.

         Parameters
         ----------
+        method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None
+            Method to use in xarray sel() function.
         **sel_kwargs : dict
             Keyword arguments to pass to the xarray sel() function.

@@ -1792,7 +1851,7 @@ def sel_inside(self, bounds: Bound) -> UnstructuredGridDataset:

         Parameters
         ----------
-        bounds : Tuple[float, float, float], Tuple[float, float float]
+        bounds : tuple[float, float, float], tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.

         Returns
@@ -1859,7 +1918,7 @@ def does_cover(self, bounds: Bound) -> bool:

         Parameters
         ----------
-        bounds : Tuple[float, float, float], Tuple[float, float float]
+        bounds : tuple[float, float, float], tuple[float, float, float]
            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
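The bounds convention shared by box_clip, sel_inside, and does_cover above is a pair of corner tuples. A minimal sketch of a containment check under that convention (the helper name is hypothetical):

Bound = tuple[tuple[float, float, float], tuple[float, float, float]]

def covers(outer: Bound, inner: Bound) -> bool:
    """True if 'outer' contains 'inner' along every axis."""
    (omin, omax), (imin, imax) = outer, inner
    return all(lo <= ilo and ihi <= hi for lo, hi, ilo, ihi in zip(omin, omax, imin, imax))

assert covers(((0, 0, 0), (2, 2, 2)), ((0.5, 0.5, 0.5), (1.5, 1.5, 1.5)))
assert not covers(((0, 0, 0), (2, 2, 2)), ((0.5, 0.5, 0.5), (3.0, 1.5, 1.5)))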
Returns diff --git a/tidy3d/components/data/unstructured/tetrahedral.py b/tidy3d/components/data/unstructured/tetrahedral.py index 9c2cbd582d..02aa6b3fa0 100644 --- a/tidy3d/components/data/unstructured/tetrahedral.py +++ b/tidy3d/components/data/unstructured/tetrahedral.py @@ -2,25 +2,28 @@ from __future__ import annotations -from typing import Any, Union +from typing import TYPE_CHECKING, Any import numpy as np -import pydantic.v1 as pd -from xarray import DataArray as XrDataArray from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import ( - CellDataArray, - IndexedDataArray, - PointDataArray, -) -from tidy3d.components.types import ArrayLike, Axis, Bound, Coordinate +from tidy3d.components.data.data_array import CellDataArray, IndexedDataArray, PointDataArray from tidy3d.exceptions import DataError from tidy3d.packaging import requires_vtk, vtk from .base import UnstructuredGridDataset from .triangular import TriangularGridDataset +if TYPE_CHECKING: + from typing import Literal, Optional, Union + + from pydantic import PositiveInt + from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid + from xarray import DataArray + from xarray import DataArray as XrDataArray + + from tidy3d.components.types.base import ArrayLike, Axis, Bound, Coordinate + class TetrahedralGridDataset(UnstructuredGridDataset): """Dataset for storing tetrahedral grid data. Data values are associated with the nodes of @@ -58,17 +61,17 @@ class TetrahedralGridDataset(UnstructuredGridDataset): """ Fundametal parameters to set up based on grid dimensionality """ @classmethod - def _traingular_dataset_type(cls) -> type: + def _triangular_dataset_type(cls) -> type: """Corresponding class for triangular grid datasets. We need to know this when creating a triangular slice from a tetrahedral grid.""" return TriangularGridDataset @classmethod - def _point_dims(cls) -> pd.PositiveInt: + def _point_dims(cls) -> PositiveInt: """Dimensionality of stored grid point coordinates.""" return 3 @classmethod - def _cell_num_vertices(cls) -> pd.PositiveInt: + def _cell_num_vertices(cls) -> PositiveInt: """Number of vertices in a cell.""" return 4 @@ -83,7 +86,7 @@ def _points_3d_array(self) -> Bound: @classmethod @requires_vtk - def _vtk_cell_type(cls): + def _vtk_cell_type(cls) -> int: """VTK cell type to use in the VTK representation.""" return vtk["mod"].VTK_TETRA @@ -91,29 +94,35 @@ def _vtk_cell_type(cls): @requires_vtk def _from_vtk_obj( cls, - vtk_obj, - field=None, + vtk_obj: vtkUnstructuredGrid, + field: Optional[str] = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, - values_type=IndexedDataArray, + values_type: type = IndexedDataArray, expect_complex: bool = False, ignore_invalid_cells: bool = False, ) -> TetrahedralGridDataset: """Initialize from a vtkUnstructuredGrid instance.""" # read point, cells, and values info from a vtk instance - cells_numpy = vtk["vtk_to_numpy"](vtk_obj.GetCells().GetConnectivityArray()) - points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()) + cells_numpy = np.array( + vtk["vtk_to_numpy"](vtk_obj.GetCells().GetConnectivityArray()), + copy=True, + ) + points_numpy = np.array(vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()), copy=True) values = cls._get_values_from_vtk( vtk_obj, len(points_numpy), field, values_type, expect_complex ) # verify cell_types - cells_types = vtk["vtk_to_numpy"](vtk_obj.GetCellTypesArray()) + cells_types = np.array(vtk["vtk_to_numpy"](vtk_obj.GetCellTypesArray()), copy=True) 
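A note on the np.array(vtk["vtk_to_numpy"](...), copy=True) pattern introduced throughout these hunks: vtk_to_numpy returns a NumPy view backed by memory the VTK object owns, so an explicit copy detaches the result from that object's lifetime. A minimal sketch, assuming the standard vtkmodules helper:

import numpy as np
from vtkmodules.util.numpy_support import vtk_to_numpy  # assumed import path

def to_owned_numpy(vtk_array) -> np.ndarray:
    view = vtk_to_numpy(vtk_array)    # zero-copy view into memory owned by the VTK object
    return np.array(view, copy=True)  # detached copy, safe after the VTK object is released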
invalid_cells = cells_types != cls._vtk_cell_type() if any(invalid_cells): if ignore_invalid_cells: - cell_offsets = vtk["vtk_to_numpy"](vtk_obj.GetCells().GetOffsetsArray()) + cell_offsets = np.array( + vtk["vtk_to_numpy"](vtk_obj.GetCells().GetOffsetsArray()), + copy=True, + ) valid_cell_offsets = cell_offsets[:-1][invalid_cells == 0] cells_numpy = cells_numpy[ np.ravel( @@ -172,7 +181,7 @@ def plane_slice(self, axis: Axis, pos: float) -> TriangularGridDataset: slice_vtk = self._plane_slice_raw(axis=axis, pos=pos) - return self._traingular_dataset_type()._from_vtk_obj( + return self._triangular_dataset_type()._from_vtk_obj( slice_vtk, remove_degenerate_cells=True, remove_unused_points=True, @@ -301,7 +310,7 @@ def sel( x: Union[float, ArrayLike] = None, y: Union[float, ArrayLike] = None, z: Union[float, ArrayLike] = None, - method=None, + method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None, **sel_kwargs: Any, ) -> Union[TriangularGridDataset, XrDataArray]: """Extract/interpolate data along one or more spatial or non-spatial directions. Must provide at least one argument @@ -317,7 +326,7 @@ def sel( y-coordinate of the slice. z : Union[float, ArrayLike] = None z-coordinate of the slice. - method: Literal[None, "nearest", "pad", "ffill", "backfill", "bfill"] = None + method: Optional[Literal["nearest", "pad", "ffill", "backfill", "bfill"]] = None Method to use in xarray sel() function. **sel_kwargs : dict Keyword arguments to pass to the xarray sel() function. @@ -361,7 +370,7 @@ def sel( return self_after_non_spatial_sel - def get_cell_volumes(self): + def get_cell_volumes(self) -> DataArray: """Get the volumes associated to each cell in the grid""" v0 = self.points[self.cells.sel(vertex_index=0)] e01 = self.points[self.cells.sel(vertex_index=1)] - v0 diff --git a/tidy3d/components/data/unstructured/triangular.py b/tidy3d/components/data/unstructured/triangular.py index b187938049..84331882a3 100644 --- a/tidy3d/components/data/unstructured/triangular.py +++ b/tidy3d/components/data/unstructured/triangular.py @@ -2,10 +2,11 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any import numpy as np -import pydantic.v1 as pd +from pydantic import Field +from xarray import DataArray as XrDataArray try: from matplotlib import pyplot as plt @@ -13,8 +14,6 @@ except ImportError: pass -from xarray import DataArray as XrDataArray - from tidy3d.components.base import cached_property from tidy3d.components.data.data_array import ( CellDataArray, @@ -22,7 +21,7 @@ PointDataArray, SpatialDataArray, ) -from tidy3d.components.types import ArrayLike, Ax, Axis, Bound +from tidy3d.components.types.base import Axis from tidy3d.components.viz import add_ax_if_none, equal_aspect, plot_params_grid from tidy3d.constants import inf from tidy3d.exceptions import DataError @@ -36,6 +35,16 @@ UnstructuredGridDataset, ) +if TYPE_CHECKING: + from typing import Literal, Optional, Union + + from pydantic import PositiveInt + from vtkmodules.vtkCommonDataModel import vtkPointSet + from xarray import DataArray + + from tidy3d.compat import Self + from tidy3d.components.types.base import ArrayLike, Ax, Bound + class TriangularGridDataset(UnstructuredGridDataset): """Dataset for storing triangular grid data. Data values are associated with the nodes of @@ -72,14 +81,12 @@ class TriangularGridDataset(UnstructuredGridDataset): ... 
) """ - normal_axis: Axis = pd.Field( - ..., + normal_axis: Axis = Field( title="Grid Axis", description="Orientation of the grid.", ) - normal_pos: float = pd.Field( - ..., + normal_pos: float = Field( title="Position", description="Coordinate of the grid along the normal direction.", ) @@ -87,12 +94,12 @@ class TriangularGridDataset(UnstructuredGridDataset): """ Fundamental parameters to set up based on grid dimensionality """ @classmethod - def _point_dims(cls) -> pd.PositiveInt: + def _point_dims(cls) -> PositiveInt: """Dimensionality of stored grid point coordinates.""" return 2 @classmethod - def _cell_num_vertices(cls) -> pd.PositiveInt: + def _cell_num_vertices(cls) -> PositiveInt: """Number of vertices in a cell.""" return 3 @@ -118,7 +125,7 @@ def _points_3d_array(self) -> ArrayLike: @classmethod @requires_vtk - def _vtk_cell_type(cls): + def _vtk_cell_type(cls) -> int: """VTK cell type to use in the VTK representation.""" return vtk["mod"].VTK_TRIANGLE @@ -126,14 +133,14 @@ def _vtk_cell_type(cls): @requires_vtk def _from_vtk_obj( cls, - vtk_obj, - field=None, + vtk_obj: vtkPointSet, + field: Optional[str] = None, remove_degenerate_cells: bool = False, remove_unused_points: bool = False, - values_type=IndexedDataArray, - expect_complex=None, + values_type: type = IndexedDataArray, + expect_complex: Optional[bool] = None, ignore_invalid_cells: bool = False, - ): + ) -> Self: """Initialize from a vtkUnstructuredGrid instance.""" # get points cells data from vtk object @@ -142,10 +149,13 @@ def _from_vtk_obj( elif isinstance(vtk_obj, vtk["mod"].vtkUnstructuredGrid): cells_vtk = vtk_obj.GetCells() - cells_numpy = vtk["vtk_to_numpy"](cells_vtk.GetConnectivityArray()) + cells_numpy = np.array( + vtk["vtk_to_numpy"](cells_vtk.GetConnectivityArray()), + copy=True, + ) # verify cell_types - cell_offsets = vtk["vtk_to_numpy"](cells_vtk.GetOffsetsArray()) + cell_offsets = np.array(vtk["vtk_to_numpy"](cells_vtk.GetOffsetsArray()), copy=True) invalid_cells = np.diff(cell_offsets) != cls._cell_num_vertices() if np.any(invalid_cells): if ignore_invalid_cells: @@ -159,7 +169,7 @@ def _from_vtk_obj( "'TriangularGridDataset'." ) - points_numpy = vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()) + points_numpy = np.array(vtk["vtk_to_numpy"](vtk_obj.GetPoints().GetData()), copy=True) # data values are read directly into Tidy3D array values = cls._get_values_from_vtk( @@ -175,7 +185,7 @@ def _from_vtk_obj( f"Provided vtk grid does not represent a two dimensional grid. Found zero size dimensions are {zero_dims}." 
) - normal_axis = zero_dims[0] + normal_axis = int(zero_dims[0]) normal_pos = points_numpy[0][normal_axis] tan_dims = [0, 1, 2] tan_dims.remove(normal_axis) @@ -243,7 +253,7 @@ def plane_slice(self, axis: Axis, pos: float) -> XrDataArray: # perform slicing in vtk and get unprocessed points and values slice_vtk = self._plane_slice_raw(axis=axis, pos=pos) - points_numpy = vtk["vtk_to_numpy"](slice_vtk.GetPoints().GetData()) + points_numpy = np.array(vtk["vtk_to_numpy"](slice_vtk.GetPoints().GetData()), copy=True) values = self._get_values_from_vtk( slice_vtk, len(points_numpy), @@ -670,7 +680,7 @@ def plot( ax.set_title(f"{normal_axis_name} = {self.normal_pos}") return ax - def get_cell_volumes(self): + def get_cell_volumes(self) -> DataArray: """Get areas associated to each cell of the grid.""" v0 = self.points[self.cells.sel(vertex_index=0)] e01 = self.points[self.cells.sel(vertex_index=1)] - v0 diff --git a/tidy3d/components/data/utils.py b/tidy3d/components/data/utils.py index e2c96c0031..ada0a4e77a 100644 --- a/tidy3d/components/data/utils.py +++ b/tidy3d/components/data/utils.py @@ -2,22 +2,30 @@ from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np import xarray as xr -from tidy3d.components.types import ArrayLike, annotate_type +from tidy3d.components.types.base import discriminated_union -from .data_array import DataArray, SpatialDataArray +from .data_array import SpatialDataArray from .unstructured.base import UnstructuredGridDataset from .unstructured.tetrahedral import TetrahedralGridDataset from .unstructured.triangular import TriangularGridDataset +if TYPE_CHECKING: + from tidy3d.components.types.base import ArrayLike + + from .data_array import DataArray + UnstructuredGridDatasetType = Union[TriangularGridDataset, TetrahedralGridDataset] CustomSpatialDataType = Union[SpatialDataArray, UnstructuredGridDatasetType] -CustomSpatialDataTypeAnnotated = Union[SpatialDataArray, annotate_type(UnstructuredGridDatasetType)] +CustomSpatialDataTypeAnnotated = Union[ + discriminated_union(UnstructuredGridDatasetType), + SpatialDataArray, +] def _get_numpy_array(data_array: Union[ArrayLike, DataArray, UnstructuredGridDataset]) -> ArrayLike: diff --git a/tidy3d/components/data/validators.py b/tidy3d/components/data/validators.py index 4988c36044..86f6d6043e 100644 --- a/tidy3d/components/data/validators.py +++ b/tidy3d/components/data/validators.py @@ -1,79 +1,11 @@ -# special validators for Datasets -from __future__ import annotations - -from typing import Optional - -import numpy as np -import pydantic.v1 as pd - -from tidy3d.exceptions import ValidationError - -from .data_array import DataArray -from .dataset import AbstractFieldDataset, ScalarFieldDataArray - - -# this can't go in validators.py because that file imports dataset.py -def validate_no_nans(field_name: str): - """Raise validation error if nans found in Dataset, or other data-containing item.""" - - @pd.validator(field_name, always=True, allow_reuse=True) - def no_nans(cls, val): - """Raise validation error if nans found in Dataset, or other data-containing item.""" - - if val is None: - return val - - def error_if_has_nans(value, identifier: Optional[str] = None) -> None: - """Recursively check if value (or iterable) has nans and error if so.""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.validators`.""" - def has_nans(values) -> bool: - """Base case: do these values contain NaN?""" - try: - return np.any(np.isnan(values)) - # if this fails 
for some reason (fails in adjoint, for example), don't check it. - except Exception: - return False +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - if isinstance(value, (tuple, list)): - for i, _value in enumerate(value): - error_if_has_nans(_value, identifier=f"[{i}]") - - elif isinstance(value, AbstractFieldDataset): - for key, val in value.field_components.items(): - error_if_has_nans(val, identifier=f".{key}") - - elif isinstance(value, DataArray): - error_if_has_nans(value.values) - - else: - if has_nans(value): - # the identifier is used to make the message more clear by appending some more info - field_name_display = field_name - if identifier: - field_name_display += identifier - - raise ValidationError( - f"Found NaN values in '{field_name_display}'. " - "If they were not intended, please double check your construction. " - "If intended, to replace these data points with a value 'x'," - " call 'values = np.nan_to_num(values, nan=x)'." - ) - - error_if_has_nans(val) - return val - - return no_nans - - -def validate_can_interpolate(field_name: str): - """Make sure the data in 'field_name' can be interpolated.""" - - @pd.validator(field_name, always=True, allow_reuse=True) - def check_fields_interpolate(cls, val: AbstractFieldDataset) -> AbstractFieldDataset: - if isinstance(val, AbstractFieldDataset): - for name, data in val.field_components.items(): - if isinstance(data, ScalarFieldDataArray): - data._interp_validator(name) - return val +# marked as migrated to _common +from __future__ import annotations - return check_fields_interpolate +from tidy3d._common.components.data.validators import ( + validate_can_interpolate, + validate_no_nans, +) diff --git a/tidy3d/components/data/zbf.py b/tidy3d/components/data/zbf.py index 08d9870216..6827e6a4c6 100644 --- a/tidy3d/components/data/zbf.py +++ b/tidy3d/components/data/zbf.py @@ -1,156 +1,10 @@ -"""ZBF utilities""" +"""Compatibility shim for :mod:`tidy3d._common.components.data.zbf`.""" -from __future__ import annotations - -from struct import unpack - -import numpy as np -import pydantic.v1 as pd - -from tidy3d.components.base import Tidy3dBaseModel - - -class ZBFData(Tidy3dBaseModel): - """ - Contains data read in from a ``.zbf`` file - """ - - version: int = pd.Field(title="Version", description="File format version number.") - nx: int = pd.Field(title="Samples in X", description="Number of samples in the x direction.") - ny: int = pd.Field(title="Samples in Y", description="Number of samples in the y direction.") - ispol: bool = pd.Field( - title="Is Polarized", - description="``True`` if the beam is polarized, ``False`` otherwise.", - ) - unit: str = pd.Field( - title="Spatial Units", description="Spatial units, either 'mm', 'cm', 'in', or 'm'." 
- ) - dx: float = pd.Field(title="Grid Spacing, X", description="Grid spacing in x.") - dy: float = pd.Field(title="Grid Spacing, Y", description="Grid spacing in y.") - zposition_x: float = pd.Field( - title="Z Position, X Direction", - description="The pilot beam z position with respect to the pilot beam waist, x direction.", - ) - zposition_y: float = pd.Field( - title="Z Position, Y Direction", - description="The pilot beam z position with respect to the pilot beam waist, y direction.", - ) - rayleigh_x: float = pd.Field( - title="Rayleigh Distance, X Direction", - description="The pilot beam Rayleigh distance in the x direction.", - ) - rayleigh_y: float = pd.Field( - title="Rayleigh Distance, Y Direction", - description="The pilot beam Rayleigh distance in the y direction.", - ) - waist_x: float = pd.Field( - title="Beam Waist, X", description="The pilot beam waist in the x direction." - ) - waist_y: float = pd.Field( - title="Beam Waist, Y", description="The pilot beam waist in the y direction." - ) - wavelength: float = pd.Field(..., title="Wavelength", description="The wavelength of the beam.") - background_refractive_index: float = pd.Field( - title="Background Refractive Index", - description="The index of refraction in the current medium.", - ) - receiver_eff: float = pd.Field( - title="Receiver Efficiency", - description="The receiver efficiency. Zero if fiber coupling is not computed.", - ) - system_eff: float = pd.Field( - title="System Efficiency", - description="The system efficiency. Zero if fiber coupling is not computed.", - ) - Ex: np.ndarray = pd.Field( - title="Electric Field, X Component", - description="Complex-valued electric field, x component.", - ) - Ey: np.ndarray = pd.Field( - title="Electric Field, Y Component", - description="Complex-valued electric field, y component.", - ) - - def read_zbf(filename: str) -> ZBFData: - """Reads a Zemax Beam File (``.zbf``) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - Parameters - ---------- - filename : str - The file name of the ``.zbf`` file to read. - - Returns - ------- - :class:`.ZBFData` - """ - - # Read the zbf file - with open(filename, "rb") as f: - # Load the header - version, nx, ny, ispol, units = unpack("<5I", f.read(20)) - f.read(16) # unused values - ( - dx, - dy, - zposition_x, - rayleigh_x, - waist_x, - zposition_y, - rayleigh_y, - waist_y, - wavelength, - background_refractive_index, - receiver_eff, - system_eff, - ) = unpack("<12d", f.read(96)) - f.read(64) # unused values - - # read E field - nsamps = 2 * nx * ny - rawx = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) - if ispol: - rawy = list(unpack(f"<{nsamps}d", f.read(8 * nsamps))) - - # convert unit key to unit string - map_units = {0: "mm", 1: "cm", 2: "in", 3: "m"} - try: - unit = map_units[units] - except KeyError: - raise KeyError( - f"Invalid units specified in the zbf file (expected '0', '1', '2', or '3', got '{units}')." 
- ) from None - - # load E field - Ex_real = np.asarray(rawx[0::2]).reshape(nx, ny, order="F") - Ex_imag = np.asarray(rawx[1::2]).reshape(nx, ny, order="F") - if ispol: - Ey_real = np.asarray(rawy[0::2]).reshape(nx, ny, order="F") - Ey_imag = np.asarray(rawy[1::2]).reshape(nx, ny, order="F") - else: - Ey_real = np.zeros((nx, ny)) - Ey_imag = np.zeros((nx, ny)) - - Ex = Ex_real + 1j * Ex_imag - Ey = Ey_real + 1j * Ey_imag +# marked as migrated to _common +from __future__ import annotations - return ZBFData( - version=version, - nx=nx, - ny=ny, - ispol=ispol, - unit=unit, - dx=dx, - dy=dy, - zposition_x=zposition_x, - zposition_y=zposition_y, - rayleigh_x=rayleigh_x, - rayleigh_y=rayleigh_y, - waist_x=waist_x, - waist_y=waist_y, - wavelength=wavelength, - background_refractive_index=background_refractive_index, - receiver_eff=receiver_eff, - system_eff=system_eff, - Ex=Ex, - Ey=Ey, - ) +from tidy3d._common.components.data.zbf import ( + ZBFData, +) diff --git a/tidy3d/components/dispersion_fitter.py b/tidy3d/components/dispersion_fitter.py index 8c1be3a827..25b09d4fdb 100644 --- a/tidy3d/components/dispersion_fitter.py +++ b/tidy3d/components/dispersion_fitter.py @@ -2,17 +2,33 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional import numpy as np -from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator +from pydantic import ( + Field, + NonNegativeFloat, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) from tidy3d.constants import fp_eps from tidy3d.exceptions import ValidationError from tidy3d.log import Progress, get_logging_console, log -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from .types import ArrayComplex1D, ArrayComplex2D, ArrayFloat1D, ArrayFloat2D +from .base import Tidy3dBaseModel, cached_property +from .types import ArrayComplex1D + +if TYPE_CHECKING: + from typing import Union + + from numpy.typing import NDArray + + from tidy3d.compat import Self + + from .types import ArrayComplex2D, ArrayFloat1D, ArrayFloat2D # numerical tolerance for pole relocation for fast fitter TOL = 1e-8 @@ -52,7 +68,12 @@ def imag_resp_extrema_locs(poles: ArrayComplex1D, residues: ArrayComplex1D) -> A Complex-valued residues for the model. """ - def _extrema_loss_freq_finder(areal, aimag, creal, cimag): + def _extrema_loss_freq_finder( + areal: ArrayFloat1D, + aimag: ArrayFloat1D, + creal: ArrayFloat1D, + cimag: ArrayFloat1D, + ) -> ArrayFloat1D: """For each pole, find frequencies for the extrema of Im[eps]""" a_square = areal**2 + aimag**2 @@ -120,7 +141,7 @@ class AdvancedFastFitterParam(Tidy3dBaseModel): "A finite upper bound may be helpful when fitting lossless materials. " "In this case, consider also increasing the weight for fitting the imaginary part.", ) - weights: tuple[NonNegativeFloat, NonNegativeFloat] = Field( + weights: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Weights", description="Weights (real, imag) in objective function for fitting. 
The weights " @@ -186,8 +207,9 @@ class AdvancedFastFitterParam(Tidy3dBaseModel): "There will be a warning if this value is too small.", ) - @validator("loss_bounds", always=True) - def _max_loss_geq_min_loss(cls, val): + @field_validator("loss_bounds") + @classmethod + def _max_loss_geq_min_loss(cls, val: tuple[float, float]) -> tuple[float, float]: """Must have max_loss >= min_loss.""" if val[0] > val[1]: raise ValidationError( @@ -195,8 +217,11 @@ def _max_loss_geq_min_loss(cls, val): ) return val - @validator("weights", always=True) - def _weights_average_to_one(cls, val): + @field_validator("weights") + @classmethod + def _weights_average_to_one( + cls, val: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] + ) -> Optional[tuple[NonNegativeFloat, NonNegativeFloat]]: """Weights must average to one.""" if val is None: return None @@ -209,25 +234,39 @@ class FastFitterData(AdvancedFastFitterParam): """Data class for internal use while running fitter.""" omega: ArrayComplex1D = Field( - ..., title="Angular frequencies in eV", description="Angular frequencies in eV" + title="Angular frequencies in eV", + description="Angular frequencies in eV", + ) + eps: ArrayComplex1D = Field( + title="Permittivity", + description="Permittivity to fit", ) - eps: ArrayComplex1D = Field(..., title="Permittivity", description="Permittivity to fit") - optimize_eps_inf: bool = Field( - None, title="Optimize eps_inf", description="Whether to optimize ``eps_inf``." + optimize_eps_inf: Optional[bool] = Field( + None, + title="Optimize eps_inf", + description="Whether to optimize ``eps_inf``.", ) - num_poles: PositiveInt = Field(None, title="Number of poles", description="Number of poles") - eps_inf: float = Field( + num_poles: Optional[PositiveInt] = Field( + None, + title="Number of poles", + description="Number of poles", + ) + eps_inf: Optional[float] = Field( None, title="eps_inf", description="Value of ``eps_inf``.", ) poles: Optional[ArrayComplex1D] = Field( - None, title="Pole frequencies in eV", description="Pole frequencies in eV" + None, + title="Pole frequencies in eV", + description="Pole frequencies in eV", ) residues: Optional[ArrayComplex1D] = Field( - None, title="Residues in eV", description="Residues in eV" + None, + title="Residues in eV", + description="Residues in eV", ) passivity_optimized: Optional[bool] = Field( @@ -252,37 +291,31 @@ class FastFitterData(AdvancedFastFitterParam): ) scale_factor: PositiveFloat = Field( - ..., title="Scale Factor", description="Factor by which frequencies have been rescaled prior to fitting. " "The ``pole_residue`` model returned will be rescaled by the inverse of this factor " "in order to restore it to the original units.", ) - @validator("eps_inf", always=True) - @skip_if_fields_missing(["optimize_eps_inf"]) - def _eps_inf_geq_one(cls, val, values): + @model_validator(mode="after") + def _eps_inf_geq_one(self) -> Self: """Must have eps_inf >= 1 unless it is being optimized. 
In the latter case, it will be made >= 1 later."""
-        if values["optimize_eps_inf"] is False and val < 1:
+        if self.optimize_eps_inf is False and self.eps_inf is not None and self.eps_inf < 1:
             raise ValidationError("The value of 'eps_inf' must be at least 1.")
-        return val
+        return self

-    @validator("poles", always=True)
-    @skip_if_fields_missing(["logspacing", "smooth", "num_poles", "omega", "num_poles"])
-    def _generate_initial_poles(cls, val, values):
+    @model_validator(mode="after")
+    def _generate_initial_poles(self) -> Self:
         """Generate initial poles."""
+        val = self.poles
         if val is not None:
-            return val
-        if (
-            values.get("logspacing") is None
-            or values.get("smooth") is None
-            or values.get("num_poles") is None
-        ):
-            return None
-        omega = values["omega"]
-        num_poles = values["num_poles"]
-        if values["logspacing"]:
+            return self
+        if self.logspacing is None or self.smooth is None or self.num_poles is None:
+            return self
+        omega = self.omega
+        num_poles = self.num_poles
+        if self.logspacing:
             pole_range = np.logspace(
                 np.log10(min(omega) / SCALE_FACTOR), np.log10(max(omega) * SCALE_FACTOR), num_poles
             )
@@ -290,22 +323,22 @@ def _generate_initial_poles(cls, val, values):
             pole_range = np.linspace(
                 min(omega) / SCALE_FACTOR, max(omega) * SCALE_FACTOR, num_poles
             )
-        if values["smooth"]:
+        if self.smooth:
             poles = -pole_range
         else:
             poles = -pole_range / 100 + 1j * pole_range
-        return poles
+        object.__setattr__(self, "poles", poles)
+        return self

-    @validator("residues", always=True)
-    @skip_if_fields_missing(["poles"])
-    def _generate_initial_residues(cls, val, values):
+    @model_validator(mode="after")
+    def _generate_initial_residues(self) -> Self:
         """Generate initial residues."""
-        if val is not None:
-            return val
-        poles = values.get("poles")
-        if poles is None:
-            return None
-        return np.zeros(len(poles))
+        if self.residues is not None:
+            return self
+        if self.poles is None:
+            return self
+        object.__setattr__(self, "residues", np.zeros(len(self.poles)))
+        return self

     @classmethod
     def initialize(
@@ -670,10 +703,10 @@ def iterate_passivity(self, passivity_omega: ArrayFloat1D) -> tuple[FastFitterDa
         h_matrix = a_matrix_real.T @ a_matrix_real
         f_vector = a_matrix_real.T @ b_vector_real

-        def loss(dx):
-            return dx.T @ h_matrix @ dx / 2 - f_vector.T @ dx
+        def loss(dx: NDArray) -> float:
+            return float(dx.T @ h_matrix @ dx / 2 - f_vector.T @ dx)

-        def jac(dx):
+        def jac(dx: NDArray) -> NDArray:
             return dx.T @ h_matrix - f_vector.T

         cons = {
@@ -844,7 +877,7 @@ def fit(
             )
             log.info(f"Fitting weights=({init_model.weights[0]:.3g}, {init_model.weights[1]:.3g}).")

-        def make_configs():
+        def make_configs() -> list[list[Union[int, bool]]]:
            configs = [[p] for p in range(max(min_num_poles // 2, 1), max_num_poles + 1)]
            for setting in [
                init_model.relaxed,
diff --git a/tidy3d/components/eme/data/dataset.py b/tidy3d/components/eme/data/dataset.py
index 3b3e7a3722..f2b509c409 100644
--- a/tidy3d/components/eme/data/dataset.py
+++ b/tidy3d/components/eme/data/dataset.py
@@ -5,7 +5,7 @@
 from typing import Optional

 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field

 from tidy3d.components.base import cached_property
 from tidy3d.components.data.data_array import (
@@ -24,23 +24,19 @@
 class EMESMatrixDataset(Dataset):
     """Dataset storing S matrix."""

-    S11: EMESMatrixDataArray = pd.Field(
-        ...,
+    S11: EMESMatrixDataArray = Field(
         title="S11 matrix",
         description="S matrix relating output modes at port 1 to input modes at port 1.",
     )
-    S12: EMESMatrixDataArray = pd.Field(
-        ...,
+    S12: EMESMatrixDataArray = Field(
title="S12 matrix", description="S matrix relating output modes at port 1 to input modes at port 2.", ) - S21: EMESMatrixDataArray = pd.Field( - ..., + S21: EMESMatrixDataArray = Field( title="S21 matrix", description="S matrix relating output modes at port 2 to input modes at port 1.", ) - S22: EMESMatrixDataArray = pd.Field( - ..., + S22: EMESMatrixDataArray = Field( title="S22 matrix", description="S matrix relating output modes at port 2 to input modes at port 2.", ) @@ -49,23 +45,19 @@ class EMESMatrixDataset(Dataset): class EMEInterfaceSMatrixDataset(Dataset): """Dataset storing S matrices associated with EME cell interfaces.""" - S11: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + S11: EMEInterfaceSMatrixDataArray = Field( title="S11 matrix", description="S matrix relating output modes at port 1 to input modes at port 1.", ) - S12: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + S12: EMEInterfaceSMatrixDataArray = Field( title="S12 matrix", description="S matrix relating output modes at port 1 to input modes at port 2.", ) - S21: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + S21: EMEInterfaceSMatrixDataArray = Field( title="S21 matrix", description="S matrix relating output modes at port 2 to input modes at port 1.", ) - S22: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + S22: EMEInterfaceSMatrixDataArray = Field( title="S22 matrix", description="S matrix relating output modes at port 2 to input modes at port 2.", ) @@ -74,25 +66,24 @@ class EMEInterfaceSMatrixDataset(Dataset): class EMEOverlapDataset(Dataset): """Dataset storing overlaps between EME modes. - ``Oij`` is the unconjugated overlap computed using the E field of cell ``i`` - and the H field of cell ``j``. + Notes + ----- + ``Oij`` is the unconjugated overlap computed using the E field of cell ``i`` + and the H field of cell ``j``. - For consistency with ``Sij``, ``mode_index_out`` refers to the mode index - in cell ``i``, and ``mode_index_in`` refers to the mode index in cell ``j``. + For consistency with ``Sij``, ``mode_index_out`` refers to the mode index + in cell ``i``, and ``mode_index_in`` refers to the mode index in cell ``j``. """ - O11: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + O11: EMEInterfaceSMatrixDataArray = Field( title="O11 matrix", description="Overlap integral between E field and H field in the same cell.", ) - O12: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + O12: EMEInterfaceSMatrixDataArray = Field( title="O12 matrix", description="Overlap integral between E field on side 1 and H field on side 2.", ) - O21: EMEInterfaceSMatrixDataArray = pd.Field( - ..., + O21: EMEInterfaceSMatrixDataArray = Field( title="O21 matrix", description="Overlap integral between E field on side 2 and H field on side 1.", ) @@ -100,50 +91,53 @@ class EMEOverlapDataset(Dataset): class EMECoefficientDataset(Dataset): """Dataset storing various coefficients related to the EME simulation. - These coefficients can be used for debugging or optimization. - The ``A`` and ``B`` fields store the expansion coefficients for the modes in a cell. - These are defined at the cell centers. + Notes + ----- + These coefficients can be used for debugging or optimization. - The ``n_complex`` and ``flux`` fields respectively store the complex-valued effective - propagation index and the power flux associated with the modes used in the - EME calculation. + The ``A`` and ``B`` fields store the expansion coefficients for the modes in a cell. + These are defined at the cell centers. 
- The ``interface_Sij`` fields store the S matrices associated with the interfaces - between EME cells. + The ``n_complex`` and ``flux`` fields respectively store the complex-valued effective + propagation index and the power flux associated with the modes used in the + EME calculation. + + The ``interface_Sij`` fields store the S matrices associated with the interfaces + between EME cells. """ - A: Optional[EMECoefficientDataArray] = pd.Field( + A: Optional[EMECoefficientDataArray] = Field( None, title="A coefficient", description="Coefficient for forward mode in this cell.", ) - B: Optional[EMECoefficientDataArray] = pd.Field( + B: Optional[EMECoefficientDataArray] = Field( None, title="B coefficient", description="Coefficient for backward mode in this cell.", ) - n_complex: Optional[EMEModeIndexDataArray] = pd.Field( + n_complex: Optional[EMEModeIndexDataArray] = Field( None, title="Propagation Index", description="Complex-valued effective propagation indices associated with the EME modes.", ) - flux: Optional[EMEFluxDataArray] = pd.Field( + flux: Optional[EMEFluxDataArray] = Field( None, title="Flux", description="Power flux of the EME modes.", ) - interface_smatrices: Optional[EMEInterfaceSMatrixDataset] = pd.Field( + interface_smatrices: Optional[EMEInterfaceSMatrixDataset] = Field( None, title="Interface S Matrices", description="S matrices associated with the interfaces between EME cells.", ) - overlaps: Optional[EMEOverlapDataset] = pd.Field( + overlaps: Optional[EMEOverlapDataset] = Field( None, title="Overlaps", description="Overlaps between EME modes." ) @@ -187,32 +181,32 @@ def normalized_copy(self) -> EMECoefficientDataset: class EMEFieldDataset(ElectromagneticFieldDataset): """Dataset storing scalar components of E and H fields as a function of freq, mode_index, and port_index.""" - Ex: EMEScalarFieldDataArray = pd.Field( + Ex: Optional[EMEScalarFieldDataArray] = Field( None, title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: EMEScalarFieldDataArray = pd.Field( + Ey: Optional[EMEScalarFieldDataArray] = Field( None, title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: EMEScalarFieldDataArray = pd.Field( + Ez: Optional[EMEScalarFieldDataArray] = Field( None, title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: EMEScalarFieldDataArray = pd.Field( + Hx: Optional[EMEScalarFieldDataArray] = Field( None, title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: EMEScalarFieldDataArray = pd.Field( + Hy: Optional[EMEScalarFieldDataArray] = Field( None, title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: EMEScalarFieldDataArray = pd.Field( + Hz: Optional[EMEScalarFieldDataArray] = Field( None, title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", @@ -222,39 +216,32 @@ class EMEFieldDataset(ElectromagneticFieldDataset): class EMEModeSolverDataset(ElectromagneticFieldDataset): """Dataset storing EME modes as a function of freq, mode_index, and cell_index.""" - n_complex: EMEModeIndexDataArray = pd.Field( - ..., + n_complex: EMEModeIndexDataArray = Field( title="Propagation Index", description="Complex-valued effective propagation constants associated with the mode.", ) - Ex: EMEScalarModeFieldDataArray = pd.Field( - ..., + Ex: EMEScalarModeFieldDataArray = 
Field( title="Ex", description="Spatial distribution of the x-component of the electric field of the mode.", ) - Ey: EMEScalarModeFieldDataArray = pd.Field( - ..., + Ey: EMEScalarModeFieldDataArray = Field( title="Ey", description="Spatial distribution of the y-component of the electric field of the mode.", ) - Ez: EMEScalarModeFieldDataArray = pd.Field( - ..., + Ez: EMEScalarModeFieldDataArray = Field( title="Ez", description="Spatial distribution of the z-component of the electric field of the mode.", ) - Hx: EMEScalarModeFieldDataArray = pd.Field( - ..., + Hx: EMEScalarModeFieldDataArray = Field( title="Hx", description="Spatial distribution of the x-component of the magnetic field of the mode.", ) - Hy: EMEScalarModeFieldDataArray = pd.Field( - ..., + Hy: EMEScalarModeFieldDataArray = Field( title="Hy", description="Spatial distribution of the y-component of the magnetic field of the mode.", ) - Hz: EMEScalarModeFieldDataArray = pd.Field( - ..., + Hz: EMEScalarModeFieldDataArray = Field( title="Hz", description="Spatial distribution of the z-component of the magnetic field of the mode.", ) diff --git a/tidy3d/components/eme/data/monitor_data.py b/tidy3d/components/eme/data/monitor_data.py index 558023c0e8..84fd02133a 100644 --- a/tidy3d/components/eme/data/monitor_data.py +++ b/tidy3d/components/eme/data/monitor_data.py @@ -4,7 +4,7 @@ from typing import Union -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData from tidy3d.components.data.monitor_data import ( @@ -25,8 +25,7 @@ class EMEModeSolverData(ElectromagneticFieldData, EMEModeSolverDataset): """Data associated with an EME mode solver monitor.""" - monitor: EMEModeSolverMonitor = pd.Field( - ..., + monitor: EMEModeSolverMonitor = Field( title="EME Mode Solver Monitor", description="EME mode solver monitor associated with this data.", ) @@ -35,16 +34,16 @@ class EMEModeSolverData(ElectromagneticFieldData, EMEModeSolverDataset): class EMEFieldData(ElectromagneticFieldData, EMEFieldDataset): """Data associated with an EME field monitor.""" - monitor: EMEFieldMonitor = pd.Field( - ..., title="EME Field Monitor", description="EME field monitor associated with this data." 
+ monitor: EMEFieldMonitor = Field( + title="EME Field Monitor", + description="EME field monitor associated with this data.", ) class EMECoefficientData(AbstractMonitorData, EMECoefficientDataset): """Data associated with an EME coefficient monitor.""" - monitor: EMECoefficientMonitor = pd.Field( - ..., + monitor: EMECoefficientMonitor = Field( title="EME Coefficient Monitor", description="EME coefficient monitor associated with this data.", ) diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py index 750c088f2a..bf94ca445f 100644 --- a/tidy3d/components/eme/data/sim_data.py +++ b/tidy3d/components/eme/data/sim_data.py @@ -2,50 +2,59 @@ from __future__ import annotations -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING, Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import cached_property from tidy3d.components.data.data_array import EMEScalarFieldDataArray, EMESMatrixDataArray -from tidy3d.components.data.monitor_data import FieldData, ModeData, ModeSolverData +from tidy3d.components.data.monitor_data import ModeData, ModeSolverData from tidy3d.components.data.sim_data import AbstractYeeGridSimulationData from tidy3d.components.eme.simulation import EMESimulation from tidy3d.components.geometry.base import Box -from tidy3d.components.types import annotate_type +from tidy3d.components.types.base import discriminated_union from tidy3d.exceptions import SetupError from tidy3d.log import log from .dataset import EMECoefficientDataset, EMESMatrixDataset -from .monitor_data import EMEFieldData, EMEModeSolverData, EMEMonitorDataType +from .monitor_data import EMEModeSolverData, EMEMonitorDataType + +if TYPE_CHECKING: + from typing import Literal, Union + + from tidy3d.components.data.monitor_data import FieldData + + from .monitor_data import EMEFieldData class EMESimulationData(AbstractYeeGridSimulationData): """Data associated with an EME simulation.""" - simulation: EMESimulation = pd.Field( - ..., title="EME simulation", description="EME simulation associated with this data." + simulation: EMESimulation = Field( + title="EME simulation", + description="EME simulation associated with this data.", ) - data: tuple[annotate_type(EMEMonitorDataType), ...] = pd.Field( - ..., + data: tuple[discriminated_union(EMEMonitorDataType), ...] = Field( title="Monitor Data", description="List of EME monitor data " "associated with the monitors of the original :class:`.EMESimulation`.", ) - smatrix: Optional[EMESMatrixDataset] = pd.Field( - None, title="S Matrix", description="Scattering matrix of the EME simulation." + smatrix: Optional[EMESMatrixDataset] = Field( + None, + title="S Matrix", + description="Scattering matrix of the EME simulation.", ) - coeffs: Optional[EMECoefficientDataset] = pd.Field( + coeffs: Optional[EMECoefficientDataset] = Field( None, title="Coefficients", description="Coefficients from the EME simulation. Useful for debugging and optimization.", ) - port_modes_raw: Optional[EMEModeSolverData] = pd.Field( + port_modes_raw: Optional[EMEModeSolverData] = Field( None, title="Port Modes", description="Modes associated with the two ports of the EME device. " @@ -54,7 +63,7 @@ class EMESimulationData(AbstractYeeGridSimulationData): ) @cached_property - def port_modes(self): + def port_modes(self) -> Optional[EMEModeSolverData]: """Modes associated with the two ports of the EME device. The scattering matrix is expressed in this basis. 
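The discriminated_union(...) helper used for the data tuple above presumably builds a pydantic v2 discriminated union keyed on the stored type tag; the generic pydantic pattern it would correspond to looks like this (class names invented):

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

class ExampleModeData(BaseModel):
    type: Literal["ExampleModeData"] = "ExampleModeData"
    n_eff: float

class ExampleFieldData(BaseModel):
    type: Literal["ExampleFieldData"] = "ExampleFieldData"
    amplitude: float

ExampleMonitorData = Annotated[
    Union[ExampleModeData, ExampleFieldData], Field(discriminator="type")
]

# the "type" tag selects the right model directly instead of trying each member in turn
data = TypeAdapter(ExampleMonitorData).validate_python({"type": "ExampleModeData", "n_eff": 2.4})
assert isinstance(data, ExampleModeData)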
diff --git a/tidy3d/components/eme/data/sim_data.py b/tidy3d/components/eme/data/sim_data.py
index 750c088f2a..bf94ca445f 100644
--- a/tidy3d/components/eme/data/sim_data.py
+++ b/tidy3d/components/eme/data/sim_data.py
@@ -2,50 +2,59 @@
 
 from __future__ import annotations
 
-from typing import Literal, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field
 
 from tidy3d.components.base import cached_property
 from tidy3d.components.data.data_array import EMEScalarFieldDataArray, EMESMatrixDataArray
-from tidy3d.components.data.monitor_data import FieldData, ModeData, ModeSolverData
+from tidy3d.components.data.monitor_data import ModeData, ModeSolverData
 from tidy3d.components.data.sim_data import AbstractYeeGridSimulationData
 from tidy3d.components.eme.simulation import EMESimulation
 from tidy3d.components.geometry.base import Box
-from tidy3d.components.types import annotate_type
+from tidy3d.components.types.base import discriminated_union
 from tidy3d.exceptions import SetupError
 from tidy3d.log import log
 
 from .dataset import EMECoefficientDataset, EMESMatrixDataset
-from .monitor_data import EMEFieldData, EMEModeSolverData, EMEMonitorDataType
+from .monitor_data import EMEModeSolverData, EMEMonitorDataType
+
+if TYPE_CHECKING:
+    from typing import Literal, Union
+
+    from tidy3d.components.data.monitor_data import FieldData
+
+    from .monitor_data import EMEFieldData
 
 
 class EMESimulationData(AbstractYeeGridSimulationData):
     """Data associated with an EME simulation."""
 
-    simulation: EMESimulation = pd.Field(
-        ..., title="EME simulation", description="EME simulation associated with this data."
+    simulation: EMESimulation = Field(
+        title="EME simulation",
+        description="EME simulation associated with this data.",
     )
 
-    data: tuple[annotate_type(EMEMonitorDataType), ...] = pd.Field(
-        ...,
+    data: tuple[discriminated_union(EMEMonitorDataType), ...] = Field(
         title="Monitor Data",
         description="List of EME monitor data "
         "associated with the monitors of the original :class:`.EMESimulation`.",
     )
 
-    smatrix: Optional[EMESMatrixDataset] = pd.Field(
-        None, title="S Matrix", description="Scattering matrix of the EME simulation."
+    smatrix: Optional[EMESMatrixDataset] = Field(
+        None,
+        title="S Matrix",
+        description="Scattering matrix of the EME simulation.",
     )
 
-    coeffs: Optional[EMECoefficientDataset] = pd.Field(
+    coeffs: Optional[EMECoefficientDataset] = Field(
         None,
         title="Coefficients",
         description="Coefficients from the EME simulation. Useful for debugging and optimization.",
     )
 
-    port_modes_raw: Optional[EMEModeSolverData] = pd.Field(
+    port_modes_raw: Optional[EMEModeSolverData] = Field(
         None,
         title="Port Modes",
         description="Modes associated with the two ports of the EME device. "
@@ -54,7 +63,7 @@ class EMESimulationData(AbstractYeeGridSimulationData):
     )
 
     @cached_property
-    def port_modes(self):
+    def port_modes(self) -> Optional[EMEModeSolverData]:
        """Modes associated with the two ports of the EME device.
        The scattering matrix is expressed in this basis.
        Note: these modes are symmetry expanded."""
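Editor's note: the hunk above also shows the typing-only import guard applied throughout this diff: imports used only in annotations move under `if TYPE_CHECKING:`, which (together with `from __future__ import annotations`) removes their runtime import cost and breaks import cycles. A self-contained sketch of the pattern, not tied to any tidy3d module:

# Minimal sketch of the typing-only import guard used in these hunks.
from __future__ import annotations  # annotations become strings, evaluated lazily

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never imported at runtime.
    from numpy.typing import NDArray

def norm(v: NDArray) -> float:
    # NDArray is referenced only inside the (string) annotation, so this
    # function runs even though numpy.typing was never imported at runtime.
    return float((v**2).sum() ** 0.5)
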
diff --git a/tidy3d/components/eme/grid.py b/tidy3d/components/eme/grid.py
index 7b1b548784..0e5032d731 100644
--- a/tidy3d/components/eme/grid.py
+++ b/tidy3d/components/eme/grid.py
@@ -3,20 +3,26 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, PositiveInt, field_validator, model_validator
 
-from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing
+from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.geometry.base import Box
 from tidy3d.components.grid.grid import Coords1D
 from tidy3d.components.mode_spec import ModeInterpSpec, ModeSpec
-from tidy3d.components.structure import Structure
-from tidy3d.components.types import ArrayFloat1D, Axis, Coordinate, Size
+from tidy3d.components.types import ArrayFloat1D, Axis
 from tidy3d.constants import RADIAN, fp_eps, inf
 from tidy3d.exceptions import SetupError, ValidationError
 
+if TYPE_CHECKING:
+    from pydantic import NonNegativeFloat, NonNegativeInt
+
+    from tidy3d.compat import Self
+    from tidy3d.components.structure import Structure
+    from tidy3d.components.types import Coordinate, Size
+
 # grid limits
 MAX_NUM_MODES = 100
 MAX_NUM_EME_CELLS = 100
@@ -26,7 +32,7 @@ class EMEModeSpec(ModeSpec):
     """Mode spec for EME cells. Overrides some of the defaults and allowed values."""
 
-    interp_spec: Optional[ModeInterpSpec] = pd.Field(
+    interp_spec: Optional[ModeInterpSpec] = Field(
         ModeInterpSpec.cheb(num_points=3, reduce_data=True),
         title="Mode frequency interpolation specification",
         description="Specification for computing modes at a reduced set of frequencies and "
@@ -36,7 +42,7 @@ class EMEModeSpec(ModeSpec):
         "not be ``None``) to ensure consistent mode ordering across frequencies.",
     )
 
-    angle_theta: Literal[0.0] = pd.Field(
+    angle_theta: Literal[0.0] = Field(
         0.0,
         title="Polar Angle",
         description="Polar angle of the propagation axis from the injection axis. Not currently "
@@ -45,7 +51,7 @@ class EMEModeSpec(ModeSpec):
         units=RADIAN,
     )
 
-    angle_phi: Literal[0.0] = pd.Field(
+    angle_phi: Literal[0.0] = Field(
         0.0,
         title="Azimuth Angle",
         description="Azimuth angle of the propagation axis in the plane orthogonal to the "
@@ -55,7 +61,7 @@ class EMEModeSpec(ModeSpec):
         units=RADIAN,
     )
 
-    precision: Literal["auto", "single", "double"] = pd.Field(
+    precision: Literal["auto", "single", "double"] = Field(
         "auto",
         title="single, double, or automatic precision in mode solver",
         description="The solver will be faster and using less memory under "
@@ -84,9 +90,9 @@ class EMEModeSpec(ModeSpec):
 
     def _to_mode_spec(self) -> ModeSpec:
         """Convert to ordinary :class:`.ModeSpec`."""
-        ms_dict = self.dict()
+        ms_dict = self.model_dump()
         ms_dict.pop("type")
-        return ModeSpec.parse_obj(ms_dict)
+        return ModeSpec.model_validate(ms_dict)
 
 
 class EMEGridSpec(Tidy3dBaseModel, ABC):
     """Specification for an EME grid.
@@ -98,7 +104,7 @@
     in the simulation.
     """
 
-    num_reps: pd.PositiveInt = pd.Field(
+    num_reps: PositiveInt = Field(
         1,
         title="Number of Repetitions",
         description="Number of periodic repetitions of this EME grid. Useful for "
@@ -107,12 +113,15 @@ class EMEGridSpec(Tidy3dBaseModel, ABC):
         "the EME solver to reuse the modes and cell interface scattering matrices.",
     )
 
-    name: Optional[str] = pd.Field(
-        None, title="Name", description="Name of this 'EMEGridSpec'. Used in 'EMEPeriodicitySweep'."
+    name: Optional[str] = Field(
+        None,
+        title="Name",
+        description="Name of this 'EMEGridSpec'. Used in 'EMEPeriodicitySweep'.",
     )
 
-    @pd.validator("num_reps", always=True)
-    def _validate_num_reps(cls, val):
+    @field_validator("num_reps")
+    @classmethod
+    def _validate_num_reps(cls, val: int) -> int:
         """Check num_reps is not too large."""
         if val > MAX_NUM_REPS:
             raise SetupError(
@@ -163,7 +172,7 @@ def num_virtual_cells(self) -> int:
         """Number of virtual cells in this EME grid spec."""
         return len(self.virtual_cell_indices)
 
-    def _updated_copy_num_reps(self, num_reps: dict[str, pd.PositiveInt]) -> EMEGridSpec:
+    def _updated_copy_num_reps(self, num_reps: dict[str, PositiveInt]) -> Self:
         """Update ``num_reps`` of named subgrids."""
         if self.name is not None:
             new_num_reps = num_reps.get(self.name)
@@ -172,7 +181,7 @@ def _updated_copy_num_reps(self, num_reps: dict[str, pd.PositiveInt]) -> EMEGrid
         return self
 
     @property
-    def _cell_index_pairs(self) -> list[pd.NonNegativeInt]:
+    def _cell_index_pairs(self) -> list[NonNegativeInt]:
         """Pairs of adjacent cell indices."""
         cell_indices = self.virtual_cell_indices
         pairs = []
@@ -192,12 +201,14 @@ class EMEUniformGrid(EMEGridSpec):
     >>> eme_grid = EMEUniformGrid(num_cells=10, mode_spec=mode_spec)
     """
 
-    num_cells: pd.PositiveInt = pd.Field(
-        ..., title="Number of cells", description="Number of cells in the uniform EME grid."
+    num_cells: PositiveInt = Field(
+        title="Number of cells",
+        description="Number of cells in the uniform EME grid.",
     )
 
-    mode_spec: EMEModeSpec = pd.Field(
-        ..., title="Mode Specification", description="Mode specification for the uniform EME grid."
+    mode_spec: EMEModeSpec = Field(
+        title="Mode Specification",
+        description="Mode specification for the uniform EME grid.",
     )
 
     def make_grid(self, center: Coordinate, size: Size, axis: Axis) -> EMEGrid:
@@ -246,14 +257,12 @@ class EMEExplicitGrid(EMEGridSpec):
     ...     )
     """
 
-    mode_specs: list[EMEModeSpec] = pd.Field(
-        ...,
+    mode_specs: list[EMEModeSpec] = Field(
         title="Mode Specifications",
         description="Mode specifications for each cell in the explicit EME grid.",
     )
 
-    boundaries: ArrayFloat1D = pd.Field(
-        ...,
+    boundaries: ArrayFloat1D = Field(
         title="Boundaries",
         description="List of coordinates of internal cell boundaries along the propagation axis. "
         "Must contain one fewer item than 'mode_specs', and must be strictly increasing. "
@@ -262,11 +271,11 @@ class EMEExplicitGrid(EMEGridSpec):
         "and the simulation boundary.",
     )
 
-    @pd.validator("boundaries", always=True)
-    @skip_if_fields_missing(["mode_specs"])
-    def _validate_boundaries(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_boundaries(self) -> Self:
         """Check that boundaries is increasing and contains one fewer element than mode_specs."""
-        mode_specs = values["mode_specs"]
+        val = self.boundaries
+        mode_specs = self.mode_specs
         boundaries = val
         if len(mode_specs) - 1 != len(boundaries):
             raise ValidationError(
@@ -278,7 +287,7 @@ def _validate_boundaries(cls, val, values):
             if rmax < rmin:
                 raise ValidationError("The 'boundaries' must be increasing.")
             rmin = rmax
-        return val
+        return self
 
     def make_grid(self, center: Coordinate, size: Size, axis: Axis) -> EMEGrid:
         """Generate EME grid from the EME grid spec.
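Editor's note: the `_validate_boundaries` hunk above is the canonical validator migration in this PR: a v1 `@pd.validator` guarded by `skip_if_fields_missing` becomes a v2 `@model_validator(mode="after")` that reads fields off `self`. A hedged, self-contained sketch of the general pattern with a hypothetical model:

# Hypothetical sketch: v1 cross-field validator -> v2 "after" model validator.
from pydantic import BaseModel, model_validator

class Span(BaseModel):
    lo: float
    hi: float

    @model_validator(mode="after")
    def _check_ordering(self) -> "Span":
        # Runs after all fields are parsed, so no skip_if_fields_missing
        # guard is needed: self.lo and self.hi are guaranteed to exist here.
        if self.hi < self.lo:
            raise ValueError("'hi' must not be less than 'lo'.")
        return self
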
@@ -322,14 +331,14 @@ def make_grid(self, center: Coordinate, size: Size, axis: Axis) -> EMEGrid:
     @classmethod
     def from_structures(
         cls, structures: list[Structure], axis: Axis, mode_spec: EMEModeSpec, **kwargs: Any
-    ) -> EMEExplicitGrid:
+    ) -> Self:
         """Create an explicit EME grid with boundaries aligned with structure bounding boxes.
         Every cell in the resulting grid has the same mode specification.
 
         Parameters
         ----------
-        structures : List[:class:`.Structure`]
+        structures : list[:class:`.Structure`]
             A list of structures to define the :class:`.EMEExplicitGrid`.
             The EME grid boundaries will be placed at the lower and upper bounds
             of the bounding boxes of all the structures in the list.
@@ -398,12 +407,12 @@ class EMECompositeGrid(EMEGridSpec):
     ...     )
     """
 
-    subgrids: list[EMESubgridType] = pd.Field(
-        ..., title="Subgrids", description="Subgrids in the composite grid."
+    subgrids: list[EMESubgridType] = Field(
+        title="Subgrids",
+        description="Subgrids in the composite grid.",
     )
 
-    subgrid_boundaries: ArrayFloat1D = pd.Field(
-        ...,
+    subgrid_boundaries: ArrayFloat1D = Field(
         title="Subgrid Boundaries",
         description="List of coordinates of internal subgrid boundaries along the propagation axis. "
         "Must contain one fewer item than 'subgrids', and must be strictly increasing. "
@@ -412,10 +421,11 @@ class EMECompositeGrid(EMEGridSpec):
         "and the simulation boundary.",
     )
 
-    @pd.validator("subgrid_boundaries", always=True)
-    def _validate_subgrid_boundaries(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_subgrid_boundaries(self) -> Self:
         """Check that subgrid boundaries is increasing and contains one fewer element than subgrids."""
-        subgrids = values["subgrids"]
+        val = self.subgrid_boundaries
+        subgrids = self.subgrids
         subgrid_boundaries = val
         if len(subgrids) - 1 != len(subgrid_boundaries):
             raise ValidationError(
@@ -426,7 +436,7 @@ def _validate_subgrid_boundaries(cls, val, values):
             if rmax < rmin:
                 raise ValidationError("The 'subgrid_boundaries' must be increasing.")
             rmin = rmax
-        return val
+        return self
 
     def subgrid_bounds(
         self, center: Coordinate, size: Size, axis: Axis
@@ -445,7 +455,7 @@ def subgrid_bounds(
 
         Returns
         -------
-        List[Tuple[float, float]]
+        list[tuple[float, float]]
             A list of pairs (rmin, rmax) of the bounds of the subgrids
             along the propagation axis.
         """
@@ -523,7 +533,7 @@ def virtual_cell_indices(self) -> int:
             inds += [ind + start_ind for ind in subgrid.virtual_cell_indices]
         return list(inds) * self.num_reps
 
-    def _updated_copy_num_reps(self, num_reps: dict[str, pd.PositiveInt]) -> EMEGridSpec:
+    def _updated_copy_num_reps(self, num_reps: dict[str, PositiveInt]) -> Self:
         """Update ``num_reps`` of named subgrids."""
         new_self = super()._updated_copy_num_reps(num_reps=num_reps)
         new_subgrids = [
@@ -538,14 +548,14 @@ def from_structure_groups(
         axis: Axis,
         mode_specs: list[EMEModeSpec],
         names: Optional[list[str]] = None,
-        num_reps: Optional[list[pd.PositiveInt]] = None,
+        num_reps: Optional[list[PositiveInt]] = None,
     ) -> EMECompositeGrid:
         """Create a composite EME grid with boundaries aligned with structure bounding boxes.
 
         Parameters
         ----------
-        structure_groups : List[List[:class:`.Structure`]]
+        structure_groups : list[list[:class:`.Structure`]]
             A list of structure groups to define the :class:`.EMECompositeGrid`.
             Each structure group will be used to generate an
             :class:`.EMEExplicitGrid` with boundaries aligned with the
             bounding boxes of the structures
@@ -558,13 +568,13 @@ def from_structure_groups(
             Two adjacent structure groups cannot be empty.
         axis : :class:`.Axis`
             Propagation axis for the EME simulation.
-        mode_specs : List[:class:`.EMEModeSpec`]
+        mode_specs : list[:class:`.EMEModeSpec`]
             Mode specifications for each subgrid.
             Must be the same length as ``structure_groups``.
-        names : List[str] = None
+        names : list[str] = None
             Names for each subgrid. Must be the same length as ``structure_groups``. If ``None``, the subgrids do not receive names.
-        num_reps : List[pd.PositiveInt] = None
+        num_reps : list[PositiveInt] = None
             Number of repetitions for each subgrid.
             Must be the same length as ``structure_groups``.
             If ``None``, the subgrids are not repeated.
@@ -671,20 +681,24 @@ class EMEGrid(Box):
     in the simulation.
     """
 
-    axis: Axis = pd.Field(
-        ..., title="Propagation axis", description="Propagation axis for the EME simulation."
+    axis: Axis = Field(
+        title="Propagation axis",
+        description="Propagation axis for the EME simulation.",
    )
 
-    mode_specs: list[EMEModeSpec] = pd.Field(
-        ..., title="Mode Specifications", description="Mode specifications for the EME cells."
+    mode_specs: list[EMEModeSpec] = Field(
+        title="Mode Specifications",
+        description="Mode specifications for the EME cells.",
    )
 
-    boundaries: Coords1D = pd.Field(
-        ..., title="Cell boundaries", description="Boundary coordinates of the EME cells."
+    boundaries: Coords1D = Field(
+        title="Cell boundaries",
+        description="Boundary coordinates of the EME cells.",
    )
 
-    @pd.validator("mode_specs", always=True)
-    def _validate_size(cls, val):
+    @field_validator("mode_specs")
+    @classmethod
+    def _validate_size(cls, val: list[EMEModeSpec]) -> list[EMEModeSpec]:
         """Check grid size and num modes."""
         num_eme_cells = len(val)
         if num_eme_cells > MAX_NUM_EME_CELLS:
@@ -701,16 +715,15 @@ def _validate_size(cls, val):
             )
         return val
 
-    @pd.validator("boundaries", always=True, pre=False)
-    @skip_if_fields_missing(["mode_specs", "axis", "center", "size"])
-    def _validate_boundaries(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_boundaries(self) -> Self:
         """Check that boundaries is increasing, in simulation domain, and contains one more element than 'mode_specs'."""
-        mode_specs = values["mode_specs"]
-        boundaries = val
-        axis = values["axis"]
-        center = values["center"][axis]
-        size = values["size"][axis]
+        boundaries = self.boundaries
+        mode_specs = self.mode_specs
+        axis = self.axis
+        center = self.center[axis]
+        size = self.size[axis]
         sim_rmin = center - size / 2
         sim_rmax = center + size / 2
         if len(mode_specs) + 1 != len(boundaries):
@@ -729,7 +742,7 @@ def _validate_boundaries(cls, val, values):
             rmin = rmax
         if rmax - sim_rmax > fp_eps:
             raise ValidationError("The last item in 'boundaries' is outside the simulation domain.")
-        return val
+        return self
 
     @property
     def centers(self) -> Coords1D:
@@ -743,7 +756,7 @@ def centers(self) -> Coords1D:
         return centers
 
     @property
-    def lengths(self) -> list[pd.NonNegativeFloat]:
+    def lengths(self) -> list[NonNegativeFloat]:
         """Lengths of the EME cells along the propagation axis."""
         rmin = self.boundaries[0]
         lengths = []
@@ -754,7 +767,7 @@ def lengths(self) -> list[pd.NonNegativeFloat]:
         return lengths
 
     @property
-    def num_cells(self) -> pd.NonNegativeInteger:
+    def num_cells(self) -> NonNegativeInt:
         """The number of cells in the EME grid."""
         return len(self.centers)
 
@@ -797,7 +810,7 @@ def cells(self) -> list[Box]:
             cells.append(Box(center=center, size=size))
         return cells
 
-    def cell_indices_in_box(self, box: Box) -> list[pd.NonNegativeInteger]:
+    def cell_indices_in_box(self, box: Box) -> list[NonNegativeInt]:
        """Indices of cells that overlap with 'box'. Used to determine which data is recorded by a monitor.
@@ -808,7 +821,7 @@ def cell_indices_in_box(self, box: Box) -> list[pd.NonNegativeInteger]:
 
         Returns
         -------
-        List[pd.NonNegativeInteger]
+        list[NonNegativeInt]
             The indices of the cells that intersect the provided box.
         """
         indices = []
diff --git a/tidy3d/components/eme/monitor.py b/tidy3d/components/eme/monitor.py
index 4b57259c06..776ffab0fc 100644
--- a/tidy3d/components/eme/monitor.py
+++ b/tidy3d/components/eme/monitor.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from typing import Literal, Optional, Union
 
-import pydantic.v1 as pd
+from pydantic import Field, NonNegativeInt, PositiveInt
 
 from tidy3d.components.base_sim.monitor import AbstractMonitor
 from tidy3d.components.monitor import (
@@ -22,7 +22,7 @@ class EMEMonitor(AbstractMonitor, ABC):
     """Abstract EME monitor."""
 
-    freqs: Optional[FreqArray] = pd.Field(
+    freqs: Optional[FreqArray] = Field(
         None,
         title="Monitor Frequencies",
         description="Frequencies at which the monitor will record. "
         "A value of 'None' will record at all simulation 'freqs'.",
     )
 
-    num_modes: Optional[pd.NonNegativeInt] = pd.Field(
+    num_modes: Optional[NonNegativeInt] = Field(
         None,
         title="Number of Modes",
         description="Maximum number of modes for the monitor to record. "
@@ -38,7 +38,7 @@ class EMEMonitor(AbstractMonitor, ABC):
         "A value of 'None' will record all modes.",
     )
 
-    num_sweep: Optional[pd.NonNegativeInt] = pd.Field(
+    num_sweep: Optional[NonNegativeInt] = Field(
         1,
         title="Number of Sweep Indices",
         description="Number of sweep indices for the monitor to record. "
@@ -47,7 +47,7 @@ class EMEMonitor(AbstractMonitor, ABC):
         "will be omitted. A value of 'None' will record all sweep indices.",
     )
 
-    interval_space: tuple[Literal[1], Literal[1], Literal[1]] = pd.Field(
+    interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field(
         (1, 1, 1),
         title="Spatial Interval",
         description="Number of grid step intervals between monitor recordings. If equal to 1, "
@@ -56,7 +56,7 @@ class EMEMonitor(AbstractMonitor, ABC):
         "Not all monitors support values different from 1.",
     )
 
-    eme_cell_interval_space: Literal[1] = pd.Field(
+    eme_cell_interval_space: Literal[1] = Field(
         1,
         title="EME Cell Interval",
         description="Number of EME cells between monitor recordings. If equal to 1, "
@@ -65,7 +65,7 @@ class EMEMonitor(AbstractMonitor, ABC):
         "Not all monitors support values different from 1.",
     )
 
-    colocate: Literal[True] = pd.Field(
+    colocate: Literal[True] = Field(
         True,
         title="Colocate Fields",
         description="Defines whether fields are colocated to grid cell boundaries (i.e. to the "
@@ -129,7 +129,7 @@ class EMEModeSolverMonitor(EMEMonitor):
     ...     )
     """
 
-    interval_space: tuple[Literal[1], Literal[1], Literal[1]] = pd.Field(
+    interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field(
         (1, 1, 1),
         title="Spatial Interval",
         description="Note: not yet supported. Number of grid step intervals between monitor recordings. If equal to 1, "
@@ -138,7 +138,7 @@ class EMEModeSolverMonitor(EMEMonitor):
         "in the propagation direction is not used. Note: this is not yet supported.",
     )
 
-    eme_cell_interval_space: pd.PositiveInt = pd.Field(
+    eme_cell_interval_space: PositiveInt = Field(
         1,
         title="EME Cell Interval",
        description="Number of EME cells between monitor recordings. If equal to 1, "
@@ -147,20 +147,20 @@ class EMEModeSolverMonitor(EMEMonitor):
         "Not all monitors support values different from 1.",
     )
 
-    colocate: bool = pd.Field(
+    colocate: bool = Field(
         True,
         title="Colocate Fields",
         description="Toggle whether fields should be colocated to grid cell boundaries (i.e. "
         "primal grid nodes). Default (False) is used internally in EME propagation.",
     )
 
-    normalize: bool = pd.Field(
+    normalize: bool = Field(
         True,
         title="Normalize Modes",
         description="Whether to normalize the EME modes to unity flux.",
     )
 
-    keep_invalid_modes: bool = pd.Field(
+    keep_invalid_modes: bool = Field(
         False,
         title="Keep Invalid Modes",
         description="Whether to store modes containing nan values and modes which are "
@@ -203,7 +203,7 @@ class EMEFieldMonitor(EMEMonitor, AbstractFieldMonitor):
     ...     )
     """
 
-    interval_space: tuple[pd.PositiveInt, pd.PositiveInt, pd.PositiveInt] = pd.Field(
+    interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field(
         (1, 1, 1),
         title="Spatial Interval",
         description="Number of grid step intervals between monitor recordings. If equal to 1, "
@@ -211,7 +211,7 @@ class EMEFieldMonitor(EMEMonitor, AbstractFieldMonitor):
         "first and last point of the monitor grid are always included.",
     )
 
-    eme_cell_interval_space: Literal[1] = pd.Field(
+    eme_cell_interval_space: Literal[1] = Field(
         1,
         title="EME Cell Interval",
         description="Number of EME cells between monitor recordings. If equal to 1, "
@@ -221,14 +221,14 @@ class EMEFieldMonitor(EMEMonitor, AbstractFieldMonitor):
         "EME field monitor.",
     )
 
-    colocate: bool = pd.Field(
+    colocate: bool = Field(
         True,
         title="Colocate Fields",
         description="Toggle whether fields should be colocated to grid cell boundaries (i.e. "
         "primal grid nodes). Default (False) is used internally in EME propagation.",
     )
 
-    num_modes: Optional[pd.NonNegativeInt] = pd.Field(
+    num_modes: Optional[NonNegativeInt] = Field(
         None,
         title="Number of Modes",
         description="Maximum number of modes for the monitor to record. "
@@ -267,7 +267,7 @@ class EMECoefficientMonitor(EMEMonitor):
     ...     )
     """
 
-    interval_space: tuple[Literal[1], Literal[1], Literal[1]] = pd.Field(
+    interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field(
         (1, 1, 1),
         title="Spatial Interval",
         description="Number of grid step intervals between monitor recordings. If equal to 1, "
@@ -277,7 +277,7 @@ class EMECoefficientMonitor(EMEMonitor):
         "for 'EMECoefficientMonitor'.",
     )
 
-    eme_cell_interval_space: pd.PositiveInt = pd.Field(
+    eme_cell_interval_space: PositiveInt = Field(
         1,
         title="EME Cell Interval",
        description="Number of EME cells between monitor recordings. If equal to 1, "
diff --git a/tidy3d/components/eme/simulation.py b/tidy3d/components/eme/simulation.py
index c2d30dd6a5..4abd650249 100644
--- a/tidy3d/components/eme/simulation.py
+++ b/tidy3d/components/eme/simulation.py
@@ -2,37 +2,32 @@
 
 from __future__ import annotations
 
-from typing import Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
-try:
-    import matplotlib as mpl
-except ImportError:
-    pass
 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, NonNegativeFloat, field_validator, model_validator
 
 from tidy3d.components.base import cached_property
 from tidy3d.components.boundary import BoundarySpec, PECBoundary
 from tidy3d.components.geometry.base import Box
-from tidy3d.components.grid.grid import Grid
 from tidy3d.components.grid.grid_spec import GridSpec
 from tidy3d.components.medium import FullyAnisotropicMedium
-from tidy3d.components.monitor import AbstractModeMonitor, ModeSolverMonitor, Monitor
+from tidy3d.components.monitor import AbstractModeMonitor, ModeSolverMonitor
 from tidy3d.components.scene import Scene
 from tidy3d.components.simulation import (
     AbstractYeeGridSimulation,
     Simulation,
     validate_boundaries_for_zero_dims,
 )
-from tidy3d.components.types import Ax, Axis, FreqArray, Symmetry, annotate_type
-from tidy3d.components.types.monitor import MonitorType
+from tidy3d.components.types import Axis, FreqArray
+from tidy3d.components.types.base import discriminated_union
 from tidy3d.components.validators import MIN_FREQUENCY, validate_freqs_min, validate_freqs_not_empty
 from tidy3d.components.viz import add_ax_if_none, equal_aspect
 from tidy3d.constants import C_0, inf
 from tidy3d.exceptions import SetupError, ValidationError
 from tidy3d.log import log
 
-from .grid import EMECompositeGrid, EMEExplicitGrid, EMEGrid, EMEGridSpec, EMEGridSpecType
+from .grid import EMECompositeGrid, EMEExplicitGrid, EMEGridSpecType
 from .monitor import (
     EMECoefficientMonitor,
     EMEFieldMonitor,
@@ -42,6 +37,25 @@
 )
 from .sweep import EMEFreqSweep, EMELengthSweep, EMEModeSweep, EMEPeriodicitySweep, EMESweepSpecType
 
+if TYPE_CHECKING:
+    from typing import Union
+
+    from pydantic import NonNegativeInt, PositiveInt
+
+    from tidy3d.compat import Self
+    from tidy3d.components.grid.grid import Grid
+    from tidy3d.components.monitor import Monitor
+    from tidy3d.components.structure import Structure
+    from tidy3d.components.types import Ax, Symmetry
+    from tidy3d.components.types.monitor import MonitorType
+
+    from .grid import EMEGrid, EMEGridSpec
+
+try:
+    import matplotlib as mpl
+except ImportError:
+    pass
+
 # maximum numbers of simulation parameters
 WARN_MONITOR_DATA_SIZE_GB = 10
 MAX_MONITOR_INTERNAL_DATA_SIZE_GB = 50
@@ -157,8 +171,7 @@ class EMESimulation(AbstractYeeGridSimulation):
     * `EME Solver Demonstration <../../notebooks/docs/features/eme.rst>`_
     """
 
-    freqs: FreqArray = pd.Field(
-        ...,
+    freqs: FreqArray = Field(
         title="Frequencies",
         description="Frequencies for the EME simulation. "
         "The field is propagated independently at each provided frequency, "
@@ -166,14 +179,12 @@ class EMESimulation(AbstractYeeGridSimulation):
         "To change this behavior, you can use 'EMEModeSpec.interp_spec'.",
     )
 
-    axis: Axis = pd.Field(
-        ...,
+    axis: Axis = Field(
         title="Propagation Axis",
         description="Propagation axis (0, 1, or 2) for the EME simulation.",
     )
 
-    eme_grid_spec: EMEGridSpecType = pd.Field(
-        ...,
+    eme_grid_spec: EMEGridSpecType = Field(
         title="EME Grid Specification",
" "The simulation is divided into cells in the propagation direction; " @@ -184,15 +195,15 @@ class EMESimulation(AbstractYeeGridSimulation): "tangential directions, as well as the grid used for field monitors.", ) - monitors: tuple[annotate_type(EMEMonitorType), ...] = pd.Field( + monitors: tuple[discriminated_union(EMEMonitorType), ...] = Field( (), title="Monitors", description="Tuple of monitors in the simulation. " "Note: monitor names are used to access data after simulation is run.", ) - boundary_spec: BoundarySpec = pd.Field( - BoundarySpec.all_sides(PECBoundary()), + boundary_spec: BoundarySpec = Field( + default_factory=lambda: BoundarySpec.all_sides(PECBoundary()), title="Boundaries", description="Specification of boundary conditions along each dimension. " "By default, PEC boundary conditions are applied on all sides. " @@ -202,7 +213,7 @@ class EMESimulation(AbstractYeeGridSimulation): "apply PML layers in the mode solver.", ) - sources: tuple[None, ...] = pd.Field( + sources: tuple[None, ...] = Field( (), title="Sources", description="Sources in the simulation. NOTE: sources are not currently supported " @@ -211,15 +222,15 @@ class EMESimulation(AbstractYeeGridSimulation): "use 'smatrix_in_basis' to use another set of modes or input field.", ) - internal_absorbers: tuple[()] = pd.Field( + internal_absorbers: tuple[()] = Field( (), title="Internal Absorbers", description="Planes with the first order absorbing boundary conditions placed inside the computational domain. " "Note: absorbers are not supported in EME simulations.", ) - grid_spec: GridSpec = pd.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions. " "This is distinct from 'eme_grid_spec', which defines the 1D EME grid in the " @@ -227,35 +238,35 @@ class EMESimulation(AbstractYeeGridSimulation): validate_default=True, ) - store_port_modes: bool = pd.Field( + store_port_modes: bool = Field( True, title="Store Port Modes", description="Whether to store the modes associated with the two ports. " "Required to find scattering matrix in basis besides the computational basis.", ) - store_coeffs: bool = pd.Field( + store_coeffs: bool = Field( False, title="Store Coefficients", description="Whether to store the internal coefficients from the EME simulation. " "The results are stored in 'EMESimulationData.coeffs'.", ) - normalize: bool = pd.Field( + normalize: bool = Field( True, title="Normalize Scattering Matrix", description="Whether to normalize the port modes to unity flux, " "thereby normalizing the scattering matrix and expansion coefficients.", ) - port_offsets: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + port_offsets: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, 0), title="Port Offsets", description="Offsets for the two ports, relative to the simulation bounds " "along the propagation axis.", ) - sweep_spec: Optional[EMESweepSpecType] = pd.Field( + sweep_spec: Optional[EMESweepSpecType] = Field( None, title="EME Sweep Specification", description="Specification for a parameter sweep to be performed during the EME " @@ -263,7 +274,7 @@ class EMESimulation(AbstractYeeGridSimulation): "in 'sim_data.smatrix'. 
@@ -263,7 +274,7 @@ class EMESimulation(AbstractYeeGridSimulation):
         "in 'sim_data.smatrix'. Other simulation monitor data is not included in the sweep.",
     )
 
-    constraint: Optional[Literal["passive", "unitary"]] = pd.Field(
+    constraint: Optional[Literal["passive", "unitary"]] = Field(
         "passive",
         title="EME Constraint",
         description="Constraint for EME propagation, imposed at cell interfaces. "
@@ -277,21 +288,24 @@ class EMESimulation(AbstractYeeGridSimulation):
     _freqs_not_empty = validate_freqs_not_empty()
     _freqs_lower_bound = validate_freqs_min()
 
-    @pd.validator("grid_spec", always=True)
-    def _validate_auto_grid_wavelength(cls, val, values):
+    @field_validator("grid_spec")
+    @classmethod
+    def _validate_auto_grid_wavelength(cls, val: GridSpec) -> GridSpec:
         """Handle the case where grid_spec is auto and wavelength is not provided."""
         # this is handled instead post-init to ensure freqs is defined
         return val
 
-    @pd.validator("freqs", always=True)
-    def _validate_freqs(cls, val):
+    @field_validator("freqs")
+    @classmethod
+    def _validate_freqs(cls, val: FreqArray) -> FreqArray:
         """Freqs cannot contain duplicates."""
         if len(set(val)) != len(val):
             raise SetupError(f"'EMESimulation' 'freqs={val}' cannot contain duplicate frequencies.")
         return val
 
-    @pd.validator("structures", always=True)
-    def _validate_structures(cls, val):
+    @field_validator("structures")
+    @classmethod
+    def _validate_structures(cls, val: tuple[Structure, ...]) -> tuple[Structure, ...]:
         """Validate and warn for certain medium types."""
         for ind, structure in enumerate(val):
             medium = structure.medium
@@ -479,9 +493,9 @@ def plot(
             Opacity of the monitors. If ``None``, uses Tidy3d default.
         ax : matplotlib.axes._subplots.Axes = None
             Matplotlib axes to plot on, if not specified, one is created.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y plane if plotting on xy plane.
 
         Returns
@@ -589,17 +603,25 @@ def port_modes_monitor(self) -> EMEModeSolverMonitor:
             normalize=self.normalize,
         )
 
-    def _post_init_validators(self) -> None:
-        """Call validators taking `self` that get run after init."""
-        self._validate_port_offsets()
+    @model_validator(mode="after")
+    def _validate_grid(self) -> Self:
         _ = self.grid
+        return self
+
+    @model_validator(mode="after")
+    def _validate_eme_grid(self) -> Self:
         _ = self.eme_grid
+        return self
+
+    @model_validator(mode="after")
+    def _validate_mode_solver_monitors(self) -> Self:
+        _ = self.mode_solver_monitors
+        return self
+
+    @model_validator(mode="after")
+    def _validate_cell_index_pairs(self) -> Self:
-        _ = self.mode_solver_monitors
         _ = self._cell_index_pairs
-        self._validate_too_close_to_edges()
-        self._validate_sweep_spec()
-        self._validate_symmetry()
-        self._validate_monitor_setup()
+        return self
 
     def validate_pre_upload(self) -> None:
         """Validate the fully initialized EME simulation is ok for upload to our servers."""
@@ -613,7 +635,8 @@ def validate_pre_upload(self) -> None:
         # self._warn_monitor_interval()
         log.end_capture(self)
 
-    def _validate_too_close_to_edges(self) -> None:
+    @model_validator(mode="after")
+    def _validate_too_close_to_edges(self) -> Self:
         """Can't have mode planes closer to boundary than extreme Yee grid center."""
         cell_centers = self.eme_grid.centers
         yee_centers = list(self.grid.centers.to_dict.values())[self.axis]
@@ -639,6 +662,7 @@ def _validate_too_close_to_edges(self) -> None:
" "Please move the monitor further from the boundary." ) + return self def _validate_constraint(self) -> None: """Constraint can be slow with too many modes. Warn in this case.""" @@ -653,7 +677,8 @@ def _validate_constraint(self) -> None: "reducing the number of modes or setting 'constraint=None'." ) - def _validate_port_offsets(self) -> None: + @model_validator(mode="after") + def _validate_port_offsets(self) -> Self: """Port offsets cannot jointly exceed simulation length.""" total_offset = self.port_offsets[0] + self.port_offsets[1] size = self.size @@ -663,11 +688,14 @@ def _validate_port_offsets(self) -> None: "The sum of the two 'port_offset' fields " "cannot exceed the simulation 'size' in the 'axis' direction." ) + return self - def _validate_symmetry(self) -> None: + @model_validator(mode="after") + def _validate_symmetry(self) -> Self: """Symmetry in propagation direction is not supported.""" if self.symmetry[self.axis] != 0: raise SetupError("Symmetry in the propagation diretion is not currently supported.") + return self # uncomment once interval_space != 1 is supported in any monitors # def _warn_monitor_interval(self): @@ -692,10 +720,11 @@ def _validate_sweep_spec_size(self) -> None: f"which exceeds the maximum allowed '{MAX_NUM_SWEEP}'." ) - def _validate_sweep_spec(self) -> None: + @model_validator(mode="after") + def _validate_sweep_spec(self) -> Self: """Validate sweep spec.""" if self.sweep_spec is None: - return + return self num_sweep = self.sweep_spec.num_sweep if num_sweep == 0: raise SetupError("Simulation 'sweep_spec' has 'num_sweep=0'.") @@ -760,8 +789,10 @@ def _validate_sweep_spec(self) -> None: raise SetupError( "'EMESimulation.store_coeffs' is not compatible with 'EMEPeriodicitySweep'." ) + return self - def _validate_monitor_setup(self) -> None: + @model_validator(mode="after") + def _validate_monitor_setup(self) -> Self: """Check monitor setup.""" for i, monitor in enumerate(self.monitors): if isinstance(monitor, EMEMonitor): @@ -822,6 +853,7 @@ def _validate_monitor_setup(self) -> None: "which is not compatible with periodic repetition " "('num_reps != 1' in any 'EMEGridSpec'.)" ) + return self def _validate_size(self) -> None: """Ensures the simulation is within size limits before simulation is uploaded.""" @@ -1024,7 +1056,7 @@ def _num_sampling_points(self) -> int: return len(freqs) @property - def _num_sweep(self) -> pd.PositiveInt: + def _num_sweep(self) -> PositiveInt: """Number of sweep indices.""" if self.sweep_spec is None: return 1 @@ -1036,7 +1068,7 @@ def _sweep_modes(self) -> bool: return self.sweep_spec is not None and isinstance(self.sweep_spec, EMEFreqSweep) @property - def _num_sweep_modes(self) -> pd.PositiveInt: + def _num_sweep_modes(self) -> PositiveInt: """Number of sweep indices for modes.""" if self._sweep_modes: return self._num_sweep @@ -1050,7 +1082,7 @@ def _sweep_interfaces(self) -> bool: ) @property - def _num_sweep_interfaces(self) -> pd.PositiveInt: + def _num_sweep_interfaces(self) -> PositiveInt: """Number of sweep indices for interfaces.""" if self._sweep_interfaces: return self._num_sweep @@ -1064,13 +1096,13 @@ def _sweep_cells(self) -> bool: ) @property - def _num_sweep_cells(self) -> pd.PositiveInt: + def _num_sweep_cells(self) -> PositiveInt: """Number of sweep indices for cells.""" if self._sweep_cells: return self._num_sweep return 1 - def _monitor_num_sweep(self, monitor: EMEMonitor) -> pd.PositiveInt: + def _monitor_num_sweep(self, monitor: EMEMonitor) -> PositiveInt: """Number of sweep indices for a certain 
monitor.""" if self.sweep_spec is None: return 1 @@ -1081,7 +1113,7 @@ def _monitor_num_sweep(self, monitor: EMEMonitor) -> pd.PositiveInt: return self.sweep_spec.num_sweep return min(self.sweep_spec.num_sweep, monitor.num_sweep) - def _monitor_eme_cell_indices(self, monitor: EMEMonitor) -> list[pd.NonNegativeInt]: + def _monitor_eme_cell_indices(self, monitor: EMEMonitor) -> list[NonNegativeInt]: """EME cell indices inside monitor. Takes into account 'eme_cell_interval_space'.""" cell_indices_full = self.eme_grid.cell_indices_in_box(box=monitor.geometry) if len(cell_indices_full) == 0: @@ -1096,13 +1128,13 @@ def _monitor_num_eme_cells(self, monitor: EMEMonitor) -> int: """Total number of EME cells included in monitor based on simulation grid.""" return len(self._monitor_eme_cell_indices(monitor=monitor)) - def _monitor_freqs(self, monitor: Monitor) -> list[pd.NonNegativeFloat]: + def _monitor_freqs(self, monitor: Monitor) -> list[NonNegativeFloat]: """Monitor frequencies.""" if monitor.freqs is None: return list(self.freqs) return list(monitor.freqs) - def _monitor_mode_freqs(self, monitor: EMEModeSolverMonitor) -> list[pd.NonNegativeFloat]: + def _monitor_mode_freqs(self, monitor: EMEModeSolverMonitor) -> list[NonNegativeFloat]: """Monitor frequencies.""" freqs = set() cell_inds = self._monitor_eme_cell_indices(monitor=monitor) @@ -1201,7 +1233,9 @@ def _to_fdtd_sim(self) -> Simulation: grid_spec = grid_spec.updated_copy(wavelength=min_wvl) # copy over all FDTD monitors too - monitors = [monitor for monitor in self.monitors if not isinstance(monitor, EMEMonitor)] + monitors = tuple( + monitor for monitor in self.monitors if not isinstance(monitor, EMEMonitor) + ) kwargs = {key: getattr(self, key) for key in EME_SIM_YEE_SIM_SHARED_ATTRS} return Simulation( @@ -1240,13 +1274,13 @@ def subsection( simulation. If ``identical``, then the original grid is transferred directly as a :class:`.EMEExplicitGrid`. Noe that in the latter case the region of the new simulation is expanded to contain full EME cells. - symmetry : Tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None + symmetry : tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None New simulation symmetry. If ``None``, then it is inherited from the original simulation. Note that in this case the size and placement of new simulation domain must be commensurate with the original symmetry. warn_symmetry_expansion : bool = True Whether to warn when the subsection is expanded to preserve symmetry. - monitors : Tuple[MonitorType, ...] = None + monitors : tuple[MonitorType, ...] = None New list of monitors. If ``None``, then the monitors intersecting the new simulation domain are inherited from the original simulation. 
        remove_outside_structures : bool = True
@@ -1295,7 +1329,7 @@ def subsection(
         return new_sim
 
     @property
-    def _cell_index_pairs(self) -> list[pd.NonNegativeInt]:
+    def _cell_index_pairs(self) -> list[NonNegativeInt]:
         """All the pairs of adjacent EME cells needed, taken over all sweep indices."""
         pairs = set()
         if isinstance(self.sweep_spec, EMEPeriodicitySweep):
diff --git a/tidy3d/components/eme/sweep.py b/tidy3d/components/eme/sweep.py
index 6ded35c645..ded0e1e357 100644
--- a/tidy3d/components/eme/sweep.py
+++ b/tidy3d/components/eme/sweep.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from typing import Union
 
-import pydantic.v1 as pd
+from pydantic import Field, PositiveInt, field_validator
 
 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.types import ArrayFloat1D, ArrayInt1D, ArrayLike
@@ -19,15 +19,14 @@ class EMESweepSpec(Tidy3dBaseModel, ABC):
 
     @property
     @abstractmethod
-    def num_sweep(self) -> pd.PositiveInt:
+    def num_sweep(self) -> PositiveInt:
         """Number of sweep indices."""
 
 
 class EMELengthSweep(EMESweepSpec):
     """Spec for sweeping EME cell lengths."""
 
-    scale_factors: ArrayLike = pd.Field(
-        ...,
+    scale_factors: ArrayLike = Field(
         title="Length Scale Factor",
         description="Length scale factors to be used in the EME propagation step. "
         "The EME propagation step is repeated after scaling every cell length by this amount. "
@@ -37,7 +36,7 @@ class EMELengthSweep(EMESweepSpec):
     )
 
     @property
-    def num_sweep(self) -> pd.PositiveInt:
+    def num_sweep(self) -> PositiveInt:
         """Number of sweep indices."""
         return len(self.scale_factors)
 
@@ -46,8 +45,7 @@ class EMEModeSweep(EMESweepSpec):
     """Spec for sweeping number of modes in EME propagation step.
     Used for convergence testing."""
 
-    num_modes: ArrayInt1D = pd.Field(
-        ...,
+    num_modes: ArrayInt1D = Field(
         title="Number of Modes",
         description="Max number of modes to use in the EME propagation step. "
         "The EME propagation step is repeated after dropping modes with mode_index "
@@ -57,7 +55,7 @@ class EMEModeSweep(EMESweepSpec):
     )
 
     @property
-    def num_sweep(self) -> pd.PositiveInt:
+    def num_sweep(self) -> PositiveInt:
         """Number of sweep indices."""
         return len(self.num_modes)
 
@@ -68,8 +66,7 @@ class EMEFreqSweep(EMESweepSpec):
     perturbative mode solver relative to the simulation EME modes.
     This can be a faster way to solve at a larger number of frequencies."""
 
-    freq_scale_factors: ArrayFloat1D = pd.Field(
-        ...,
+    freq_scale_factors: ArrayFloat1D = Field(
        title="Frequency Scale Factors",
        description="Scale factors "
        "applied to every frequency in 'EMESimulation.freqs'. After applying the scale factors, "
@@ -79,20 +76,23 @@ class EMEFreqSweep(EMESweepSpec):
     )
 
     @property
-    def num_sweep(self) -> pd.PositiveInt:
+    def num_sweep(self) -> PositiveInt:
         """Number of sweep indices."""
         return len(self.freq_scale_factors)
 
 
 class EMEPeriodicitySweep(EMESweepSpec):
     """Spec for sweeping number of repetitions of EME subgrids.
-    Useful for simulating long periodic structures like Bragg gratings,
-    as it allows the EME solver to reuse the modes and cell interface
-    scattering matrices.
-    Compared to setting ``num_reps`` directly in the ``eme_grid_spec``,
-    this sweep spec allows varying the number of repetitions,
-    effectively simulating multiple structures in a single EME simulation.
 
+    Notes
+    -----
+        Useful for simulating long periodic structures like Bragg gratings,
+        as it allows the EME solver to reuse the modes and cell interface
+        scattering matrices.
+
+        Compared to setting ``num_reps`` directly in the ``eme_grid_spec``,
+        this sweep spec allows varying the number of repetitions,
+        effectively simulating multiple structures in a single EME simulation.
 
     Example
     -------
@@ -100,16 +100,16 @@ class EMEPeriodicitySweep(EMESweepSpec):
     >>> sweep_spec = EMEPeriodicitySweep(num_reps=[{"unit_cell": n} for n in n_list])
     """
 
-    num_reps: list[dict[str, pd.PositiveInt]] = pd.Field(
-        ...,
+    num_reps: list[dict[str, PositiveInt]] = Field(
         title="Number of Repetitions",
         description="Number of periodic repetitions of named subgrids in this EME grid. "
         "At each sweep index, contains a dict mapping the name of a subgrid to the "
         "number of repetitions of that subgrid at that sweep index.",
     )
 
-    @pd.validator("num_reps", always=True)
-    def _validate_num_reps(cls, val):
+    @field_validator("num_reps")
+    @classmethod
+    def _validate_num_reps(cls, val: list[dict[str, PositiveInt]]) -> list[dict[str, PositiveInt]]:
         """Check num_reps is not too large."""
         for num_reps_dict in val:
             for value in num_reps_dict.values():
@@ -121,7 +121,7 @@ def _validate_num_reps(cls, val):
         return val
 
     @property
-    def num_sweep(self) -> pd.PositiveInt:
+    def num_sweep(self) -> PositiveInt:
         """Number of sweep indices."""
         return len(self.num_reps)
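Editor's note: a hedged usage sketch assembled from the docstrings above, tying the named `EMEGridSpec` to an `EMEPeriodicitySweep`. Module paths are the ones shown in this diff; constructor arguments are abbreviated and not exhaustive.

# Hedged sketch: a named subgrid whose repetition count is swept.
from tidy3d.components.eme.grid import EMEModeSpec, EMEUniformGrid
from tidy3d.components.eme.sweep import EMEPeriodicitySweep

mode_spec = EMEModeSpec(num_modes=10)
# Naming the subgrid lets a sweep refer to it by name.
unit_cell = EMEUniformGrid(num_cells=1, mode_spec=mode_spec, name="unit_cell")
# One dict per sweep index, mapping subgrid name -> repetition count.
sweep_spec = EMEPeriodicitySweep(num_reps=[{"unit_cell": n} for n in (1, 2, 4, 8)])
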
diff --git a/tidy3d/components/field_projection.py b/tidy3d/components/field_projection.py
index d23b7c424f..d4c7f02fb5 100644
--- a/tidy3d/components/field_projection.py
+++ b/tidy3d/components/field_projection.py
@@ -4,12 +4,12 @@
 
 from collections.abc import Iterable
 from itertools import product
-from typing import Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import autograd.numpy as anp
 import numpy as np
-import pydantic.v1 as pydantic
 import xarray as xr
+from pydantic import Field, model_validator
 from rich.progress import track
 
 from tidy3d.constants import C_0, EPSILON_0, ETA_0, MICROMETER, MU_0
@@ -17,7 +17,7 @@
 from tidy3d.log import get_logging_console
 
 from .autograd.functions import add_at, trapz
-from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
+from .base import Tidy3dBaseModel, cached_property
 from .data.data_array import (
     FieldProjectionAngleDataArray,
     FieldProjectionCartesianDataArray,
@@ -31,16 +31,21 @@
     FieldProjectionKSpaceData,
 )
 from .data.sim_data import SimulationData
-from .medium import MediumType
 from .monitor import (
-    AbstractFieldProjectionMonitor,
-    FieldMonitor,
     FieldProjectionAngleMonitor,
     FieldProjectionCartesianMonitor,
-    FieldProjectionKSpaceMonitor,
     FieldProjectionSurface,
 )
-from .types import ArrayComplex4D, Coordinate, Direction
+from .types import ArrayComplex4D, Coordinate
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+    from tidy3d.compat import Self
+
+    from .medium import MediumType
+    from .monitor import AbstractFieldProjectionMonitor, FieldMonitor, FieldProjectionKSpaceMonitor
+    from .types import Direction
 
 # Default number of points per wavelength in the background medium to use for resampling fields.
 PTS_PER_WVL = 10
@@ -178,36 +183,34 @@ def _far_field_integral(
 
 
 class FieldProjector(Tidy3dBaseModel):
-    """
-    Projection of near fields to points on a given observation grid.
+    """Projection of near fields to points on a given observation grid.
 
+    Notes
+    -----
     .. TODO make images to illustrate this
 
     See Also
     --------
 
-    :class:`FieldProjectionAngleMonitor
+    :class:`FieldProjectionAngleMonitor`
         :class:`Monitor` that samples electromagnetic near fields in the frequency domain
-        and projects them at given observation angles.`
+        and projects them at given observation angles.
 
     **Notebooks**:
         * `Performing near field to far field projections <../../notebooks/FieldProjections.html>`_
     """
 
-    sim_data: SimulationData = pydantic.Field(
-        ...,
+    sim_data: SimulationData = Field(
         title="Simulation data",
         description="Container for simulation data containing the near field monitors.",
     )
 
-    surfaces: tuple[FieldProjectionSurface, ...] = pydantic.Field(
-        ...,
+    surfaces: tuple[FieldProjectionSurface, ...] = Field(
         title="Surface monitor with direction",
-        description="Tuple of each :class:`.FieldProjectionSurface` to use as source of "
+        description="Tuple of each :class:`.FieldProjectionSurface` to use as source of "
         "near field.",
     )
 
-    pts_per_wavelength: Union[int, type(None)] = pydantic.Field(
+    pts_per_wavelength: Optional[int] = Field(
         PTS_PER_WVL,
         title="Points per wavelength",
         description="Number of points per wavelength in the background medium with which "
@@ -215,7 +218,7 @@ class FieldProjector(Tidy3dBaseModel):
         "will not be resampled, but will still be colocated.",
     )
 
-    origin: Coordinate = pydantic.Field(
+    origin: Optional[Coordinate] = Field(
         None,
         title="Local origin",
         description="Local origin used for defining observation points. If ``None``, uses the "
@@ -223,21 +226,19 @@ class FieldProjector(Tidy3dBaseModel):
         units=MICROMETER,
     )
 
+    @model_validator(mode="after")
+    def _check_origin_set(self) -> Self:
+        """Sets ``.origin`` as the average of centers of all surface monitors if not provided."""
+        if self.origin is None:
+            centers = np.array([surface.monitor.center for surface in self.surfaces])
+            object.__setattr__(self, "origin", tuple(np.mean(centers, axis=0)))
+        return self
+
     @cached_property
     def is_2d_simulation(self) -> bool:
         non_zero_dims = sum(1 for size in self.sim_data.simulation.size if size != 0)
         return non_zero_dims == 2
 
-    @pydantic.validator("origin", always=True)
-    @skip_if_fields_missing(["surfaces"])
-    def set_origin(cls, val, values):
-        """Sets .origin as the average of centers of all surface monitors if not provided."""
-        if val is None:
-            surfaces = values.get("surfaces")
-            val = np.array([surface.monitor.center for surface in surfaces])
-            return tuple(np.mean(val, axis=0))
-        return val
-
     @cached_property
     def medium(self) -> MediumType:
         """Medium into which fields are to be projected."""
@@ -258,17 +259,17 @@ def from_near_field_monitors(
         normal_dirs: list[Direction],
         pts_per_wavelength: int = PTS_PER_WVL,
         origin: Coordinate = None,
-    ):
+    ) -> Self:
         """Constructs :class:`FieldProjection` from a list of surface monitors and their directions.
 
         Parameters
         ----------
         sim_data : :class:`.SimulationData`
             Container for simulation data containing the near field monitors.
-        near_monitors : List[:class:`.FieldMonitor`]
-            Tuple of :class:`.FieldMonitor` objects on which near fields will be sampled.
-        normal_dirs : List[:class:`.Direction`]
-            Tuple containing the :class:`.Direction` of the normal to each surface monitor
+        near_monitors : list[:class:`.FieldMonitor`]
+            Tuple of :class:`.FieldMonitor` objects on which near fields will be sampled.
+        normal_dirs : list[:class:`.Direction`]
+            Tuple containing the :class:`.Direction` of the normal to each surface monitor
             w.r.t. the positive x, y or z unit vectors.
            Must have the same length as monitors.
         pts_per_wavelength : int = 10
             Number of points per wavelength with which to discretize the
@@ -297,7 +298,7 @@ def from_near_field_monitors(
         )
 
     @cached_property
-    def currents(self):
+    def currents(self) -> dict[str, xr.Dataset]:
         """Sets the surface currents."""
         sim_data = self.sim_data
         surfaces = self.surfaces
@@ -398,7 +399,7 @@ def _fields_to_currents(field_data: FieldData, surface: FieldProjectionSurface)
         surface_currents[H2] = field_data.field_components[E1] * signs[0]
         surface_currents[H1] = field_data.field_components[E2] * signs[1]
 
-        new_monitor = surface.monitor.copy(update={"fields": [E1, E2, H1, H2]})
+        new_monitor = surface.monitor.copy(update={"fields": (E1, E2, H1, H2)})
 
         return FieldData(
             monitor=new_monitor,
@@ -463,7 +464,7 @@ def _resample_surface_currents(
                 coord_list[idx][-1],
             )
             if pts_per_wavelength is None:
-                points = sim_data.simulation.grid.boundaries.to_list[idx]
+                points = sim_data.simulation.grid.boundaries.to_list[idx].copy()
                 points[np.argwhere(points < start)] = start
                 points[np.argwhere(points > stop)] = stop
                 colocation_points[idx] = np.unique(points)
@@ -482,10 +483,10 @@
 
     @staticmethod
     def trapezoid(
-        ary: np.ndarray,
-        pts: Union[Iterable[np.ndarray], np.ndarray],
+        ary: NDArray,
+        pts: Union[Iterable[NDArray], NDArray],
         axes: Union[Iterable[int], int] = 0,
-    ):
+    ) -> NDArray:
         """Trapezoidal integration in n dimensions.
 
         Parameters
@@ -521,7 +522,7 @@ def _far_fields_for_surface(
         surface: FieldProjectionSurface,
         currents: xr.Dataset,
         medium: MediumType,
-    ) -> np.ndarray:
+    ) -> NDArray:
         """Compute far fields at an angle in spherical coordinates
         for a given set of surface currents and observation angles.
 
@@ -530,9 +531,9 @@
         frequency : float
             Frequency to select from each :class:`.FieldMonitor` to use for projection.
             Must be a frequency stored in each :class:`FieldMonitor`.
-        theta : Union[float, Tuple[float, ...], np.ndarray]
+        theta : Union[float, tuple[float, ...], np.ndarray]
             Polar angles (rad) downward from x=y=0 line relative to the local origin.
-        phi : Union[float, Tuple[float, ...], np.ndarray]
+        phi : Union[float, tuple[float, ...], np.ndarray]
             Azimuthal (rad) angles from y=z=0 line relative to the local origin.
         surface: :class:`FieldProjectionSurface`
             :class:`FieldProjectionSurface` object to use as source of near field.
@@ -995,7 +996,7 @@ def _fields_for_surface_exact(
         surface: FieldProjectionSurface,
         currents: xr.Dataset,
         medium: MediumType,
-    ) -> np.ndarray:
+    ) -> NDArray:
         """Compute projected fields in spherical coordinates at a given projection point on a
         Cartesian grid for a given set of surface currents using the exact homogeneous medium
         Green's function without geometric approximations.
@@ -1071,7 +1072,7 @@ def _fields_for_surface_exact(
         d2G_dr2 = dG_dr * (ikr - 1.0) / r + G / (r**2)
 
         # operations between unit vectors and currents
-        def r_x_current(current: tuple[np.ndarray, ...]) -> tuple[np.ndarray, ...]:
+        def r_x_current(current: tuple[NDArray, ...]) -> tuple[NDArray, ...]:
             """Cross product between the r unit vector and the current."""
             return [
                 sin_theta * sin_phi * current[2] - cos_theta * current[1],
@@ -1079,7 +1080,7 @@ def r_x_current(current: tuple[np.ndarray, ...]) -> tuple[np.ndarray, ...]:
                 sin_theta * cos_phi * current[1] - sin_theta * sin_phi * current[0],
             ]
 
-        def r_dot_current(current: tuple[np.ndarray, ...]) -> np.ndarray:
+        def r_dot_current(current: tuple[NDArray, ...]) -> NDArray:
             """Dot product between the r unit vector and the current."""
             return (
                 sin_theta * cos_phi * current[0]
@@ -1087,7 +1088,7 @@ def r_dot_current(current: tuple[np.ndarray, ...]) -> np.ndarray:
                 + cos_theta * current[2]
             )
 
-        def r_dot_current_dtheta(current: tuple[np.ndarray, ...]) -> np.ndarray:
+        def r_dot_current_dtheta(current: tuple[NDArray, ...]) -> NDArray:
             """Theta derivative of the dot product between the r unit vector and the current."""
             return (
                 cos_theta * cos_phi * current[0]
@@ -1095,12 +1096,12 @@ def r_dot_current_dtheta(current: tuple[np.ndarray, ...]) -> np.ndarray:
                 - sin_theta * current[2]
             )
 
-        def r_dot_current_dphi_div_sin_theta(current: tuple[np.ndarray, ...]) -> np.ndarray:
+        def r_dot_current_dphi_div_sin_theta(current: tuple[NDArray, ...]) -> NDArray:
             """Phi derivative of the dot product between the r unit vector and the current,
             analytically divided by sin theta."""
             return -sin_phi * current[0] + cos_phi * current[1]
 
-        def grad_Gr_r_dot_current(current: tuple[np.ndarray, ...]) -> tuple[np.ndarray, ...]:
+        def grad_Gr_r_dot_current(current: tuple[NDArray, ...]) -> tuple[NDArray, ...]:
             """Gradient of the product of the gradient of the Green's function and the dot product
             between the r unit vector and the current."""
             temp = [
@@ -1111,7 +1112,9 @@ def grad_Gr_r_dot_current(current: tuple[np.ndarray, ...]) -> tuple[np.ndarray,
             # convert to Cartesian coordinates
             return surface.monitor.sph_2_car_field(temp[0], temp[1], temp[2], theta_obs, phi_obs)
 
-        def potential_terms(current: tuple[np.ndarray, ...], const: complex):
+        def potential_terms(
+            current: tuple[NDArray, ...], const: complex
+        ) -> tuple[list[complex], list[complex], list[complex]]:
             """Assemble vector potential and its derivatives."""
             r_x_c = r_x_current(current)
             pot = [const * item * G for item in current]
diff --git a/tidy3d/components/file_util.py b/tidy3d/components/file_util.py
index bdfe7326a1..54b1ab4c7d 100644
--- a/tidy3d/components/file_util.py
+++ b/tidy3d/components/file_util.py
@@ -1,81 +1,12 @@
-"""File compression utilities"""
-
-from __future__ import annotations
-
-import gzip
-import pathlib
-import shutil
-from io import BytesIO
-from os import PathLike
-from typing import Any
-
-import numpy as np
-
-
-def compress_file_to_gzip(input_file: PathLike, output_gz_file: PathLike | BytesIO) -> None:
-    """
-    Compress a file using gzip.
-
-    Parameters
-    ----------
-    input_file : PathLike
-        The path to the input file.
-    output_gz_file : PathLike | BytesIO
-        The path to the output gzip file or an in-memory buffer.
- """ - input_file = pathlib.Path(input_file) - with input_file.open("rb") as file_in: - with gzip.open(output_gz_file, "wb") as file_out: - shutil.copyfileobj(file_in, file_out) +"""Compatibility shim for :mod:`tidy3d._common.components.file_util`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def extract_gzip_file(input_gz_file: PathLike, output_file: PathLike) -> None: - """ - Extract a gzip-compressed file. - - Parameters - ---------- - input_gz_file : PathLike - The path to the gzip-compressed input file. - output_file : PathLike - The path to the extracted output file. - """ - input_path = pathlib.Path(input_gz_file) - output_path = pathlib.Path(output_file) - with gzip.open(input_path, "rb") as file_in: - with output_path.open("wb") as file_out: - shutil.copyfileobj(file_in, file_out) - - -def replace_values(values: Any, search_value: Any, replace_value: Any) -> Any: - """ - Create a copy of ``values`` where any elements equal to ``search_value`` are replaced by ``replace_value``. - - Parameters - ---------- - values : Any - The input object to iterate through. - search_value : Any - An object to match for in ``values``. - replace_value : Any - A replacement object for the matched value in ``values``. - - Returns - ------- - Any - values type object with ``search_value`` terms replaced by ``replace_value``. - """ - # np.all allows for arrays to be evaluated - if np.all(values == search_value): - return replace_value - if isinstance(values, dict): - return { - key: replace_values(val, search_value, replace_value) for key, val in values.items() - } - if isinstance( - values, (tuple, list) - ): # Parts of the nested dict structure include tuples with more dicts - return type(values)(replace_values(val, search_value, replace_value) for val in values) +# marked as migrated to _common +from __future__ import annotations - # Used to maintain values that are not search_value or containers - return values +from tidy3d._common.components.file_util import ( + compress_file_to_gzip, + extract_gzip_file, + replace_values, +) diff --git a/tidy3d/components/frequencies.py b/tidy3d/components/frequencies.py index 125d94b4f4..5085a7ec04 100644 --- a/tidy3d/components/frequencies.py +++ b/tidy3d/components/frequencies.py @@ -2,17 +2,20 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -import pydantic as pd -import pydantic.v1 as pydantic -from numpy.typing import NDArray +from pydantic import Field, PositiveFloat, model_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.source.time import GaussianPulse from tidy3d.constants import C_0 +if TYPE_CHECKING: + from numpy.typing import NDArray + + from tidy3d.compat import Self + O_BAND = (1.260, 1.360) E_BAND = (1.360, 1.460) S_BAND = (1.460, 1.530) @@ -24,7 +27,7 @@ class FrequencyUtils(Tidy3dBaseModel): """Utilities for classifying frequencies/wavelengths and generating samples for standard optical bands.""" - use_wavelength: bool = pd.Field( + use_wavelength: bool = Field( False, title="Use wavelength", description="Indicate whether to use wavelengths instead of frequencies for the return " @@ -272,30 +275,28 @@ class FreqRange(Tidy3dBaseModel): >>> source = freq_range.to_gaussian_pulse() """ - freq0: pydantic.PositiveFloat = pydantic.Field( + freq0: PositiveFloat = Field( ..., title="Central frequency", description="Real-valued positive central frequency.", units="Hz", ) - fwidth: pydantic.PositiveFloat = 
diff --git a/tidy3d/components/frequencies.py b/tidy3d/components/frequencies.py
index 125d94b4f4..5085a7ec04 100644
--- a/tidy3d/components/frequencies.py
+++ b/tidy3d/components/frequencies.py
@@ -2,17 +2,20 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import numpy as np
-import pydantic as pd
-import pydantic.v1 as pydantic
-from numpy.typing import NDArray
+from pydantic import Field, PositiveFloat, model_validator
 
 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.source.time import GaussianPulse
 from tidy3d.constants import C_0
 
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+    from tidy3d.compat import Self
+
 O_BAND = (1.260, 1.360)
 E_BAND = (1.360, 1.460)
 S_BAND = (1.460, 1.530)
@@ -24,7 +27,7 @@ class FrequencyUtils(Tidy3dBaseModel):
     """Utilities for classifying frequencies/wavelengths and generating
     samples for standard optical bands."""
 
-    use_wavelength: bool = pd.Field(
+    use_wavelength: bool = Field(
         False,
         title="Use wavelength",
         description="Indicate whether to use wavelengths instead of frequencies for the return "
@@ -272,30 +275,28 @@ class FreqRange(Tidy3dBaseModel):
     >>> source = freq_range.to_gaussian_pulse()
     """
 
-    freq0: pydantic.PositiveFloat = pydantic.Field(
+    freq0: PositiveFloat = Field(
         ...,
         title="Central frequency",
         description="Real-valued positive central frequency.",
         units="Hz",
     )
-    fwidth: pydantic.PositiveFloat = pydantic.Field(
+    fwidth: PositiveFloat = Field(
         ...,
         title="Frequency bandwidth",
         description="Real-valued positive width of the frequency range (bandwidth).",
         units="Hz",
     )
 
-    @pydantic.root_validator
-    def check_half_fwidth_less_than_freq0(cls, values):
-        freq0 = values.get("freq0")
-        fwidth = values.get("fwidth")
-        if freq0 is not None and fwidth is not None:
-            if (fwidth / 2) >= freq0:
+    @model_validator(mode="after")
+    def check_half_fwidth_less_than_freq0(self) -> Self:
+        if self.freq0 is not None and self.fwidth is not None:
+            if (self.fwidth / 2) >= self.freq0:
                 raise ValueError(
                     "Frequency bandwidth `fwidth` must be strictly less than `2 * freq0`."
                 )
-        return values
+        return self
 
     @property
     def fmin(self) -> float:
diff --git a/tidy3d/components/frequency_extrapolation.py b/tidy3d/components/frequency_extrapolation.py
index e2768bba36..afb2cfc143 100644
--- a/tidy3d/components/frequency_extrapolation.py
+++ b/tidy3d/components/frequency_extrapolation.py
@@ -2,31 +2,34 @@
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
-import pydantic.v1 as pydantic
+from pydantic import Field, NonNegativeFloat, field_validator, model_validator
 
 from tidy3d.components.base import Tidy3dBaseModel
 
+if TYPE_CHECKING:
+    from tidy3d.compat import Self
+
 
 class AbstractLowFrequencySmoothingSpec(Tidy3dBaseModel):
     """Abstract base class for low frequency smoothing specifications."""
 
-    min_sampling_time: pydantic.NonNegativeFloat = pydantic.Field(
+    min_sampling_time: NonNegativeFloat = Field(
         1.0,
         title="Minimum Sampling Time (periods)",
         description="The minimum simulation time in periods of the corresponding frequency for which frequency domain results will be used to fit the polynomial for the low frequency extrapolation. "
         "Results below this threshold will be completely discarded.",
     )
 
-    max_sampling_time: pydantic.NonNegativeFloat = pydantic.Field(
+    max_sampling_time: NonNegativeFloat = Field(
         5.0,
         title="Maximum Sampling Time (periods)",
         description="The maximum simulation time in periods of the corresponding frequency for which frequency domain results will be used to fit the polynomial for the low frequency extrapolation. "
         "Results above this threshold will not be modified.",
     )
 
-    order: int = pydantic.Field(
+    order: int = Field(
         1,
         title="Extrapolation Order",
         description="The order of the polynomial to use for the low frequency extrapolation.",
@@ -34,22 +37,22 @@ class AbstractLowFrequencySmoothingSpec(Tidy3dBaseModel):
         le=3,
     )
 
-    max_deviation: Optional[float] = pydantic.Field(
+    max_deviation: Optional[float] = Field(
         0.5,
         title="Maximum Deviation",
         description="The maximum deviation (in fraction of the trusted values) to allow for the low frequency smoothing.",
         ge=0,
     )
 
-    @pydantic.root_validator(skip_on_failure=True)
-    def _validate_sampling_times(cls, values):
-        min_sampling_time = values.get("min_sampling_time")
-        max_sampling_time = values.get("max_sampling_time")
+    @model_validator(mode="after")
+    def _validate_sampling_times(self) -> Self:
+        min_sampling_time = self.min_sampling_time
+        max_sampling_time = self.max_sampling_time
         if min_sampling_time >= max_sampling_time:
             raise ValueError(
                 "The minimum sampling time must be less than the maximum sampling time."
             )
-        return values
+        return self
 
 
 class LowFrequencySmoothingSpec(AbstractLowFrequencySmoothingSpec):
@@ -69,14 +72,14 @@ class LowFrequencySmoothingSpec(AbstractLowFrequencySmoothingSpec):
     ...     )
     """
= pydantic.Field( - ..., + monitors: tuple[str, ...] = Field( title="Monitors", description="The names of monitors to which low frequency smoothing will be applied.", ) - @pydantic.validator("monitors", always=True) - def _validate_monitors(cls, val, values): + @field_validator("monitors") + @classmethod + def _validate_monitors(cls, val: tuple[str, ...]) -> tuple[str, ...]: """Validate the monitors list is not empty.""" if not val: raise ValueError("The monitors list must not be empty.") diff --git a/tidy3d/components/geometry/__init__.py b/tidy3d/components/geometry/__init__.py index e69de29bb2..a8ed42d8cc 100644 --- a/tidy3d/components/geometry/__init__.py +++ b/tidy3d/components/geometry/__init__.py @@ -0,0 +1,8 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.geometry`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common +from __future__ import annotations + +import tidy3d._common.components.geometry as _module diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index 9a2fc61401..5bfe288e48 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -1,3706 +1,24 @@ -"""Abstract base classes for geometry.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.base`.""" -from __future__ import annotations - -import functools -import pathlib -from abc import ABC, abstractmethod -from collections.abc import Iterable, Sequence -from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import autograd.numpy as np -import pydantic.v1 as pydantic -import shapely -from numpy._typing import ArrayLike, NDArray -from typing_extensions import Self - -try: - from matplotlib import patches -except ImportError: - pass +# marked as migrated to _common +from __future__ import annotations -from tidy3d.compat import _shapely_is_older_than -from tidy3d.components.autograd import ( - AutogradFieldMap, - TracedCoordinate, - TracedFloat, - TracedSize, - get_static, +from tidy3d._common.components.geometry.base import ( + POLY_DISTANCE_TOLERANCE, + POLY_GRID_SIZE, + POLY_TOLERANCE_RATIO, + Box, + Centered, + Circular, + ClipOperation, + Geometry, + GeometryGroup, + Planar, + SimplePlaneIntersection, + Transformed, + _bit_operations, + _shapely_operations, + cleanup_shapely_object, ) -from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.base import Tidy3dBaseModel, cached_property -from tidy3d.components.geometry.bound_ops import bounds_intersection, bounds_union -from tidy3d.components.geometry.float_utils import increment_float -from tidy3d.components.transformation import ReflectionFromPlane, RotationAroundAxis -from tidy3d.components.types import ( - ArrayFloat2D, - ArrayFloat3D, - Ax, - Axis, - Bound, - ClipOperationType, - Coordinate, - Coordinate2D, - LengthUnit, - MatrixReal4x4, - PlanePosition, - Shapely, - Size, - annotate_type, -) -from tidy3d.components.viz import ( - ARROW_LENGTH, - PLOT_BUFFER, - PlotParams, - VisualizationSpec, - add_ax_if_none, - arrow_style, - equal_aspect, - plot_params_geometry, - polygon_patch, - set_default_labels_and_title, -) -from tidy3d.constants import LARGE_NUMBER, MICROMETER, RADIAN, fp_eps, inf -from tidy3d.exceptions import ( - SetupError, - Tidy3dError, - Tidy3dImportError, - Tidy3dKeyError, - ValidationError, -) -from tidy3d.log import log -from tidy3d.packaging 
import verify_packages_import - -if TYPE_CHECKING: - from gdstk import Cell - from matplotlib.backend_bases import Event - from matplotlib.patches import FancyArrowPatch - -POLY_GRID_SIZE = 1e-12 -POLY_TOLERANCE_RATIO = 1e-12 -POLY_DISTANCE_TOLERANCE = 8e-12 - - -_shapely_operations = { - "union": shapely.union, - "intersection": shapely.intersection, - "difference": shapely.difference, - "symmetric_difference": shapely.symmetric_difference, -} - -_bit_operations = { - "union": lambda a, b: a | b, - "intersection": lambda a, b: a & b, - "difference": lambda a, b: a & ~b, - "symmetric_difference": lambda a, b: a != b, -} - - -class Geometry(Tidy3dBaseModel, ABC): - """Abstract base class, defines where something exists in space.""" - - @cached_property - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Geometry object.""" - return plot_params_geometry - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - - def point_inside(x: float, y: float, z: float) -> bool: - """Returns ``True`` if a single point ``(x, y, z)`` is inside.""" - shapes_intersect = self.intersections_plane(z=z) - loc = self.make_shapely_point(x, y) - return any(shape.contains(loc) for shape in shapes_intersect) - - arrays = tuple(map(np.array, (x, y, z))) - self._ensure_equal_shape(*arrays) - inside = np.zeros((arrays[0].size,), dtype=bool) - arrays_flat = map(np.ravel, arrays) - for ipt, args in enumerate(zip(*arrays_flat)): - inside[ipt] = point_inside(*args) - return inside.reshape(arrays[0].shape) - - @staticmethod - def _ensure_equal_shape(*arrays: Any) -> None: - """Ensure all input arrays have the same shape.""" - shapes = {np.array(arr).shape for arr in arrays} - if len(shapes) > 1: - raise ValueError("All coordinate inputs (x, y, z) must have the same shape.") - - @staticmethod - def make_shapely_box(minx: float, miny: float, maxx: float, maxy: float) -> shapely.box: - """Make a shapely box ensuring everything untraced.""" - - minx = get_static(minx) - miny = get_static(miny) - maxx = get_static(maxx) - maxy = get_static(maxy) - - return shapely.box(minx, miny, maxx, maxy) - - @staticmethod - def make_shapely_point(minx: float, miny: float) -> shapely.Point: - """Make a shapely Point ensuring everything untraced.""" - - minx = get_static(minx) - miny = get_static(miny) - - return shapely.Point(minx, miny) - - def _inds_inside_bounds( - self, x: NDArray[float], y: NDArray[float], z: NDArray[float] - ) -> tuple[slice, slice, slice]: - """Return slices into the sorted input arrays that are inside the geometry bounds. - - Parameters - ---------- - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - Tuple[slice, slice, slice] - Slices into each of the three arrays that are inside the geometry bounds. 
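The pydantic v1 → v2 migration in the frequencies.py and frequency_extrapolation.py hunks above replaces `root_validator` with `model_validator(mode="after")` and `validator` with `field_validator`. A minimal sketch of the "after" pattern, using a plain `pydantic.BaseModel` instead of `Tidy3dBaseModel` and `typing_extensions.Self` instead of the `tidy3d.compat` re-export:

from pydantic import BaseModel, Field, PositiveFloat, model_validator
from typing_extensions import Self

class FreqRangeSketch(BaseModel):
    freq0: PositiveFloat = Field(..., description="Central frequency (Hz).")
    fwidth: PositiveFloat = Field(..., description="Bandwidth (Hz).")

    @model_validator(mode="after")
    def _check_half_fwidth(self) -> Self:
        # mode="after" runs on the constructed instance, so fields are read
        # as attributes instead of from the v1-style `values` dict
        if self.fwidth / 2 >= self.freq0:
            raise ValueError("`fwidth` must be strictly less than `2 * freq0`.")
        return self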
- """ - bounds = self.bounds - inds_in = [] - for dim, coords in enumerate([x, y, z]): - inds = np.nonzero((bounds[0][dim] <= coords) * (coords <= bounds[1][dim]))[0] - inds_in.append(slice(0, 0) if inds.size == 0 else slice(inds[0], inds[-1] + 1)) - - return tuple(inds_in) - - def inside_meshgrid( - self, x: NDArray[float], y: NDArray[float], z: NDArray[float] - ) -> NDArray[bool]: - """Perform ``self.inside`` on a set of sorted 1D coordinates. Applies meshgrid to the - supplied coordinates before checking inside. - - Parameters - ---------- - - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every - point that is inside the geometry. - """ - - arrays = tuple(map(np.array, (x, y, z))) - if any(arr.ndim != 1 for arr in arrays): - raise ValueError("Each of the supplied coordinates (x, y, z) must be 1D.") - shape = tuple(arr.size for arr in arrays) - is_inside = np.zeros(shape, dtype=bool) - inds_inside = self._inds_inside_bounds(*arrays) - coords_inside = tuple(arr[ind] for ind, arr in zip(inds_inside, arrays)) - coords_3d = np.meshgrid(*coords_inside, indexing="ij") - is_inside[inds_inside] = self.inside(*coords_3d) - return is_inside - - @abstractmethod - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
- """ - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - origin = self.unpop_axis(position, (0, 0), axis=axis) - normal = self.unpop_axis(1, (0, 0), axis=axis) - to_2D = np.eye(4) - if axis != 2: - last, indices = self.pop_axis((0, 1, 2), axis) - to_2D = to_2D[[*list(indices), last, 3]] - return self.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - - def intersections_2dbox(self, plane: Box) -> list[Shapely]: - """Returns list of shapely geometries representing the intersections of the geometry with - a 2D box. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. For more details refer to - `Shapely's Documentation `_. - """ - log.warning( - "'intersections_2dbox()' is deprecated and will be removed in the future. " - "Use 'plane.intersections_with(...)' for the same functionality." - ) - return plane.intersections_with(self) - - def intersects( - self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] - ) -> bool: - """Returns ``True`` if two :class:`Geometry` have intersecting `.bounds`. - - Parameters - ---------- - other : :class:`Geometry` - Geometry to check intersection with. - strict_inequality : Tuple[bool, bool, bool] = [False, False, False] - For each dimension, defines whether to include equality in the boundaries comparison. - If ``False``, equality is included, and two geometries that only intersect at their - boundaries will evaluate as ``True``. If ``True``, such geometries will evaluate as - ``False``. - - Returns - ------- - bool - Whether the rectangular bounding boxes of the two geometries intersect. - """ - - self_bmin, self_bmax = self.bounds - other_bmin, other_bmax = other.bounds - - for smin, omin, smax, omax, strict in zip( - self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality - ): - # are all of other's minimum coordinates less than self's maximum coordinate? - in_minus = omin < smax if strict else omin <= smax - # are all of other's maximum coordinates greater than self's minimum coordinate? - in_plus = omax > smin if strict else omax >= smin - - # if either failed, return False - if not all((in_minus, in_plus)): - return False - - return True - - def contains( - self, other: Geometry, strict_inequality: tuple[bool, bool, bool] = [False, False, False] - ) -> bool: - """Returns ``True`` if the `.bounds` of ``other`` are contained within the - `.bounds` of ``self``. - - Parameters - ---------- - other : :class:`Geometry` - Geometry to check containment with. - strict_inequality : Tuple[bool, bool, bool] = [False, False, False] - For each dimension, defines whether to include equality in the boundaries comparison. - If ``False``, equality will be considered as contained. If ``True``, ``other``'s - bounds must be strictly within the bounds of ``self``. - - Returns - ------- - bool - Whether the rectangular bounding box of ``other`` is contained within the bounding - box of ``self``. - """ - - self_bmin, self_bmax = self.bounds - other_bmin, other_bmax = other.bounds - - for smin, omin, smax, omax, strict in zip( - self_bmin, other_bmin, self_bmax, other_bmax, strict_inequality - ): - # are all of other's minimum coordinates greater than self's minimim coordinate? - in_minus = omin > smin if strict else omin >= smin - # are all of other's maximum coordinates less than self's maximum coordinate? 
- in_plus = omax < smax if strict else omax <= smax - - # if either failed, return False - if not all((in_minus, in_plus)): - return False - - return True - - def intersects_plane( - self, x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None - ) -> bool: - """Whether self intersects plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - - Returns - ------- - bool - Whether this geometry intersects the plane. - """ - - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - return self.intersects_axis_position(axis, position) - - def intersects_axis_position(self, axis: int, position: float) -> bool: - """Whether self intersects plane specified by a given position along a normal axis. - - Parameters - ---------- - axis : int = None - Axis normal to the plane. - position : float = None - Position of plane along the normal axis. - - Returns - ------- - bool - Whether this geometry intersects the plane. - """ - return self.bounds[0][axis] <= position <= self.bounds[1][axis] - - @cached_property - @abstractmethod - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - - @staticmethod - def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the intersection of two bounds.""" - return bounds_intersection(bounds1, bounds2) - - @staticmethod - def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the union of two bounds.""" - return bounds_union(bounds1, bounds2) - - @cached_property - def bounding_box(self) -> Box: - """Returns :class:`Box` representation of the bounding box of a :class:`Geometry`. - - Returns - ------- - :class:`Box` - Geometric object representing bounding box. - """ - return Box.from_bounds(*self.bounds) - - @cached_property - def zero_dims(self) -> list[Axis]: - """A list of axes along which the :class:`Geometry` is zero-sized based on its bounds.""" - zero_dims = [] - for dim in range(3): - if self.bounds[1][dim] == self.bounds[0][dim]: - zero_dims.append(dim) - return zero_dims - - def _pop_bounds(self, axis: Axis) -> tuple[Coordinate2D, tuple[Coordinate2D, Coordinate2D]]: - """Returns min and max bounds in plane normal to and tangential to ``axis``. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - Tuple[float, float], Tuple[Tuple[float, float], Tuple[float, float]] - Bounds along axis and a tuple of bounds in the ordered planar coordinates. - Packed as ``(zmin, zmax), ((xmin, ymin), (xmax, ymax))``. 
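`bounds_intersection` and `bounds_union`, now imported from `tidy3d.components.geometry.bound_ops`, reduce to per-axis max/min over the two corner tuples. A sketch of what the delegation presumably amounts to (the real helpers live in `bound_ops` and may differ in detail):

def bounds_union(b1, b2):
    return tuple(map(min, b1[0], b2[0])), tuple(map(max, b1[1], b2[1]))

def bounds_intersection(b1, b2):
    return tuple(map(max, b1[0], b2[0])), tuple(map(min, b1[1], b2[1]))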
- """ - b_min, b_max = self.bounds - zmin, (xmin, ymin) = self.pop_axis(b_min, axis=axis) - zmax, (xmax, ymax) = self.pop_axis(b_max, axis=axis) - return (zmin, zmax), ((xmin, ymin), (xmax, ymax)) - - @staticmethod - def _get_center(pt_min: float, pt_max: float) -> float: - """Returns center point based on bounds along dimension.""" - if np.isneginf(pt_min) and np.isposinf(pt_max): - return 0.0 - if np.isneginf(pt_min) or np.isposinf(pt_max): - raise SetupError( - f"Bounds of ({pt_min}, {pt_max}) supplied along one dimension. " - "We currently don't support a single ``inf`` value in bounds for ``Box``. " - "To construct a semi-infinite ``Box``, " - "please supply a large enough number instead of ``inf``. " - "For example, a location extending outside of the " - "Simulation domain (including PML)." - ) - return (pt_min + pt_max) / 2.0 - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - raise ValidationError("'Medium2D' is not compatible with this geometry class.") - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Geometry: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - raise NotImplementedError( - "'_update_from_bounds' is not compatible with this geometry class." - ) - - @equal_aspect - @add_ax_if_none - def plot( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - ax: Ax = None, - plot_length_units: LengthUnit = None, - viz_spec: VisualizationSpec = None, - **patch_kwargs: Any, - ) -> Ax: - """Plot geometry cross section at single (x,y,z) coordinate. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - plot_length_units : LengthUnit = None - Specify units to use for axis labels, tick labels, and the title. - viz_spec : VisualizationSpec = None - Plotting parameters associated with a medium to use instead of defaults. - **patch_kwargs - Optional keyword arguments passed to the matplotlib patch plotting of structure. - For details on accepted values, refer to - `Matplotlib's documentation `_. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. 
- """ - - # find shapes that intersect self at plane - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - shapes_intersect = self.intersections_plane(x=x, y=y, z=z) - - plot_params = self.plot_params - if viz_spec is not None: - plot_params = plot_params.override_with_viz_spec(viz_spec) - plot_params = plot_params.include_kwargs(**patch_kwargs) - - # for each intersection, plot the shape - for shape in shapes_intersect: - ax = self.plot_shape(shape, plot_params=plot_params, ax=ax) - - # clean up the axis display - ax = self.add_ax_lims(axis=axis, ax=ax) - ax.set_aspect("equal") - # Add the default axis labels, tick labels, and title - ax = Box.add_ax_labels_and_title(ax=ax, x=x, y=y, z=z, plot_length_units=plot_length_units) - return ax - - def plot_shape(self, shape: Shapely, plot_params: PlotParams, ax: Ax) -> Ax: - """Defines how a shape is plotted on a matplotlib axes.""" - if shape.geom_type in ( - "MultiPoint", - "MultiLineString", - "MultiPolygon", - "GeometryCollection", - ): - for sub_shape in shape.geoms: - ax = self.plot_shape(shape=sub_shape, plot_params=plot_params, ax=ax) - - return ax - - _shape = Geometry.evaluate_inf_shape(shape) - - if _shape.geom_type == "LineString": - xs, ys = zip(*_shape.coords) - ax.plot(xs, ys, color=plot_params.facecolor, linewidth=plot_params.linewidth) - elif _shape.geom_type == "Point": - ax.scatter(shape.x, shape.y, color=plot_params.facecolor) - else: - patch = polygon_patch(_shape, **plot_params.to_kwargs()) - ax.add_artist(patch) - return ax - - @staticmethod - def _do_not_intersect( - bounds_a: float, bounds_b: float, shape_a: Shapely, shape_b: Shapely - ) -> bool: - """Check whether two shapes intersect.""" - - # do a bounding box check to see if any intersection to do anything about - if ( - bounds_a[0] > bounds_b[2] - or bounds_b[0] > bounds_a[2] - or bounds_a[1] > bounds_b[3] - or bounds_b[1] > bounds_a[3] - ): - return True - - # look more closely to see if intersected. - if shape_b.is_empty or not shape_a.intersects(shape_b): - return True - - return False - - @staticmethod - def _get_plot_labels(axis: Axis) -> tuple[str, str]: - """Returns planar coordinate x and y axis labels for cross section plots. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - str, str - Labels of plot, packaged as ``(xlabel, ylabel)``. - """ - _, (xlabel, ylabel) = Geometry.pop_axis("xyz", axis=axis) - return xlabel, ylabel - - def _get_plot_limits( - self, axis: Axis, buffer: float = PLOT_BUFFER - ) -> tuple[Coordinate2D, Coordinate2D]: - """Gets planar coordinate limits for cross section plots. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - buffer : float = 0.3 - Amount of space to add around the limits on the + and - sides. - - Returns - ------- - Tuple[float, float], Tuple[float, float] - The x and y plot limits, packed as ``(xmin, xmax), (ymin, ymax)``. - """ - _, ((xmin, ymin), (xmax, ymax)) = self._pop_bounds(axis=axis) - return (xmin - buffer, xmax + buffer), (ymin - buffer, ymax + buffer) - - def add_ax_lims(self, axis: Axis, ax: Ax, buffer: float = PLOT_BUFFER) -> Ax: - """Sets the x,y limits based on ``self.bounds``. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0,1,2). - ax : matplotlib.axes._subplots.Axes - Matplotlib axes to add labels and limits on. - buffer : float = 0.3 - Amount of space to place around the limits on the + and - sides. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. 
- """ - (xmin, xmax), (ymin, ymax) = self._get_plot_limits(axis=axis, buffer=buffer) - - # note: axes limits dont like inf values, so we need to evaluate them first if present - xmin, xmax, ymin, ymax = self._evaluate_inf((xmin, xmax, ymin, ymax)) - - ax.set_xlim(xmin, xmax) - ax.set_ylim(ymin, ymax) - return ax - - @staticmethod - def add_ax_labels_and_title( - ax: Ax, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - plot_length_units: LengthUnit = None, - ) -> Ax: - """Sets the axis labels, tick labels, and title based on ``axis`` - and an optional ``plot_length_units`` argument. - - Parameters - ---------- - ax : matplotlib.axes._subplots.Axes - Matplotlib axes to add labels and limits on. - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - plot_length_units : LengthUnit = None - When set to a supported ``LengthUnit``, plots will be produced with annotated axes - and title with the proper units. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied matplotlib axes. - """ - axis, position = Box.parse_xyz_kwargs(x=x, y=y, z=z) - axis_labels = Box._get_plot_labels(axis) - ax = set_default_labels_and_title( - axis_labels=axis_labels, - axis=axis, - position=position, - ax=ax, - plot_length_units=plot_length_units, - ) - return ax - - @staticmethod - def _evaluate_inf(array: ArrayLike) -> NDArray[np.floating]: - """Processes values and evaluates any infs into large (signed) numbers.""" - array = get_static(np.array(array)) - return np.where(np.isinf(array), np.sign(array) * LARGE_NUMBER, array) - - @staticmethod - def evaluate_inf_shape(shape: Shapely) -> Shapely: - """Returns a copy of shape with inf vertices replaced by large numbers if polygon.""" - if not any(np.isinf(b) for b in shape.bounds): - return shape - - def _processed_coords(coords: Sequence[tuple[Any, ...]]) -> list[tuple[float, ...]]: - evaluated = Geometry._evaluate_inf(np.array(coords)) - return [tuple(point) for point in evaluated.tolist()] - - if shape.geom_type == "Polygon": - shell = _processed_coords(shape.exterior.coords) - holes = [_processed_coords(g.coords) for g in shape.interiors] - return shapely.Polygon(shell, holes) - if shape.geom_type in {"Point", "LineString", "LinearRing"}: - return shape.__class__(Geometry._evaluate_inf(np.array(shape.coords))) - if shape.geom_type in { - "MultiPoint", - "MultiLineString", - "MultiPolygon", - "GeometryCollection", - }: - return shape.__class__([Geometry.evaluate_inf_shape(g) for g in shape.geoms]) - return shape - - @staticmethod - def pop_axis(coord: tuple[Any, Any, Any], axis: int) -> tuple[Any, tuple[Any, Any]]: - """Separates coordinate at ``axis`` index from coordinates on the plane tangent to ``axis``. - - Parameters - ---------- - coord : Tuple[Any, Any, Any] - Tuple of three values in original coordinate system. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - Any, Tuple[Any, Any] - The input coordinates are separated into the one along the axis provided - and the two on the planar coordinates, - like ``axis_coord, (planar_coord1, planar_coord2)``. 
- """ - plane_vals = list(coord) - axis_val = plane_vals.pop(axis) - return axis_val, tuple(plane_vals) - - @staticmethod - def unpop_axis(ax_coord: Any, plane_coords: tuple[Any, Any], axis: int) -> tuple[Any, Any, Any]: - """Combine coordinate along axis with coordinates on the plane tangent to the axis. - - Parameters - ---------- - ax_coord : Any - Value along axis direction. - plane_coords : Tuple[Any, Any] - Values along ordered planar directions. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - Tuple[Any, Any, Any] - The three values in the xyz coordinate system. - """ - coords = list(plane_coords) - coords.insert(axis, ax_coord) - return tuple(coords) - - @staticmethod - def parse_xyz_kwargs(**xyz: Any) -> tuple[Axis, float]: - """Turns x,y,z kwargs into index of the normal axis and position along that axis. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - - Returns - ------- - int, float - Index into xyz axis (0,1,2) and position along that axis. - """ - xyz_filtered = {k: v for k, v in xyz.items() if v is not None} - if len(xyz_filtered) != 1: - raise ValueError("exactly one kwarg in [x,y,z] must be specified.") - axis_label, position = list(xyz_filtered.items())[0] - axis = "xyz".index(axis_label) - return axis, position - - @staticmethod - def parse_two_xyz_kwargs(**xyz: Any) -> list[tuple[Axis, float]]: - """Turns x,y,z kwargs into indices of axes and the position along each axis. - - Parameters - ---------- - x : float = None - Position in x direction, only two of x,y,z can be specified to define line. - y : float = None - Position in y direction, only two of x,y,z can be specified to define line. - z : float = None - Position in z direction, only two of x,y,z can be specified to define line. - - Returns - ------- - [(int, float), (int, float)] - Index into xyz axis (0,1,2) and position along that axis. - """ - xyz_filtered = {k: v for k, v in xyz.items() if v is not None} - assert len(xyz_filtered) == 2, "exactly two kwarg in [x,y,z] must be specified." - xyz_list = list(xyz_filtered.items()) - return [("xyz".index(axis_label), position) for axis_label, position in xyz_list] - - @staticmethod - def rotate_points(points: ArrayFloat3D, axis: Coordinate, angle: float) -> ArrayFloat3D: - """Rotate a set of points in 3D. - - Parameters - ---------- - points : ArrayLike[float] - Array of shape ``(3, ...)``. - axis : Coordinate - Axis of rotation - angle : float - Angle of rotation counter-clockwise around the axis (rad). - """ - rotation = RotationAroundAxis(axis=axis, angle=angle) - return rotation.rotate_vector(points) - - def reflect_points( - self, - points: ArrayFloat3D, - polar_axis: Axis, - angle_theta: float, - angle_phi: float, - ) -> ArrayFloat3D: - """Reflect a set of points in 3D at a plane passing through the coordinate origin defined - and normal to a given axis defined in polar coordinates (theta, phi) w.r.t. the - ``polar_axis`` which can be 0, 1, or 2. - - Parameters - ---------- - points : ArrayLike[float] - Array of shape ``(3, ...)``. - polar_axis : Axis - Cartesian axis w.r.t. which the normal axis angles are defined. - angle_theta : float - Polar angle w.r.t. the polar axis. - angle_phi : float - Azimuth angle around the polar axis. 
- """ - - # Rotate such that the plane normal is along the polar_axis - axis_theta, axis_phi = [0, 0, 0], [0, 0, 0] - axis_phi[polar_axis] = 1 - plane_axes = [0, 1, 2] - plane_axes.pop(polar_axis) - axis_theta[plane_axes[1]] = 1 - points_new = self.rotate_points(points, axis_phi, -angle_phi) - points_new = self.rotate_points(points_new, axis_theta, -angle_theta) - - # Flip the ``polar_axis`` coordinate of the points, which is now normal to the plane - points_new[polar_axis, :] *= -1 - - # Rotate back - points_new = self.rotate_points(points_new, axis_theta, angle_theta) - points_new = self.rotate_points(points_new, axis_phi, angle_phi) - - return points_new - - def volume(self, bounds: Bound = None) -> float: - """Returns object's volume with optional bounds. - - Parameters - ---------- - bounds : Tuple[Tuple[float, float, float], Tuple[float, float, float]] = None - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - float - Volume in um^3. - """ - - if not bounds: - bounds = self.bounds - - return self._volume(bounds) - - @abstractmethod - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - def surface_area(self, bounds: Bound = None) -> float: - """Returns object's surface area with optional bounds. - - Parameters - ---------- - bounds : Tuple[Tuple[float, float, float], Tuple[float, float, float]] = None - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - float - Surface area in um^2. - """ - - if not bounds: - bounds = self.bounds - - return self._surface_area(bounds) - - @abstractmethod - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - def translated(self, x: float, y: float, z: float) -> Geometry: - """Return a translated copy of this geometry. - - Parameters - ---------- - x : float - Translation along x. - y : float - Translation along y. - z : float - Translation along z. - - Returns - ------- - :class:`Geometry` - Translated copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.translation(x, y, z)) - - def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> Geometry: - """Return a scaled copy of this geometry. - - Parameters - ---------- - x : float = 1.0 - Scaling factor along x. - y : float = 1.0 - Scaling factor along y. - z : float = 1.0 - Scaling factor along z. - - Returns - ------- - :class:`Geometry` - Scaled copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.scaling(x, y, z)) - - def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> Geometry: - """Return a rotated copy of this geometry. - - Parameters - ---------- - angle : float - Rotation angle (in radians). - axis : Union[int, Tuple[float, float, float]] - Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. - - Returns - ------- - :class:`Geometry` - Rotated copy of this geometry. - """ - return Transformed(geometry=self, transform=Transformed.rotation(angle, axis)) - - def reflected(self, normal: Coordinate) -> Geometry: - """Return a reflected copy of this geometry. - - Parameters - ---------- - normal : Tuple[float, float, float] - The 3D normal vector of the plane of reflection. The plane is assumed - to pass through the origin (0,0,0). - - Returns - ------- - :class:`Geometry` - Reflected copy of this geometry. 
- """ - return Transformed(geometry=self, transform=Transformed.reflection(normal)) - - """ Field and coordinate transformations """ - - @staticmethod - def car_2_sph(x: float, y: float, z: float) -> tuple[float, float, float]: - """Convert Cartesian to spherical coordinates. - - Parameters - ---------- - x : float - x coordinate relative to ``local_origin``. - y : float - y coordinate relative to ``local_origin``. - z : float - z coordinate relative to ``local_origin``. - - Returns - ------- - Tuple[float, float, float] - r, theta, and phi coordinates relative to ``local_origin``. - """ - r = np.sqrt(x**2 + y**2 + z**2) - theta = np.arccos(z / r) - phi = np.arctan2(y, x) - return r, theta, phi - - @staticmethod - def sph_2_car(r: float, theta: float, phi: float) -> tuple[float, float, float]: - """Convert spherical to Cartesian coordinates. - - Parameters - ---------- - r : float - radius. - theta : float - polar angle (rad) downward from x=y=0 line. - phi : float - azimuthal (rad) angle from y=z=0 line. - - Returns - ------- - Tuple[float, float, float] - x, y, and z coordinates relative to ``local_origin``. - """ - r_sin_theta = r * np.sin(theta) - x = r_sin_theta * np.cos(phi) - y = r_sin_theta * np.sin(phi) - z = r * np.cos(theta) - return x, y, z - - @staticmethod - def sph_2_car_field( - f_r: float, f_theta: float, f_phi: float, theta: float, phi: float - ) -> tuple[complex, complex, complex]: - """Convert vector field components in spherical coordinates to cartesian. - - Parameters - ---------- - f_r : float - radial component of the vector field. - f_theta : float - polar angle component of the vector fielf. - f_phi : float - azimuthal angle component of the vector field. - theta : float - polar angle (rad) of location of the vector field. - phi : float - azimuthal angle (rad) of location of the vector field. - - Returns - ------- - Tuple[float, float, float] - x, y, and z components of the vector field in cartesian coordinates. - """ - sin_theta = np.sin(theta) - cos_theta = np.cos(theta) - sin_phi = np.sin(phi) - cos_phi = np.cos(phi) - f_x = f_r * sin_theta * cos_phi + f_theta * cos_theta * cos_phi - f_phi * sin_phi - f_y = f_r * sin_theta * sin_phi + f_theta * cos_theta * sin_phi + f_phi * cos_phi - f_z = f_r * cos_theta - f_theta * sin_theta - return f_x, f_y, f_z - - @staticmethod - def car_2_sph_field( - f_x: float, f_y: float, f_z: float, theta: float, phi: float - ) -> tuple[complex, complex, complex]: - """Convert vector field components in cartesian coordinates to spherical. - - Parameters - ---------- - f_x : float - x component of the vector field. - f_y : float - y component of the vector fielf. - f_z : float - z component of the vector field. - theta : float - polar angle (rad) of location of the vector field. - phi : float - azimuthal angle (rad) of location of the vector field. - - Returns - ------- - Tuple[float, float, float] - radial (s), elevation (theta), and azimuthal (phi) components - of the vector field in spherical coordinates. - """ - sin_theta = np.sin(theta) - cos_theta = np.cos(theta) - sin_phi = np.sin(phi) - cos_phi = np.cos(phi) - f_r = f_x * sin_theta * cos_phi + f_y * sin_theta * sin_phi + f_z * cos_theta - f_theta = f_x * cos_theta * cos_phi + f_y * cos_theta * sin_phi - f_z * sin_theta - f_phi = -f_x * sin_phi + f_y * cos_phi - return f_r, f_theta, f_phi - - @staticmethod - def kspace_2_sph(ux: float, uy: float, axis: Axis) -> tuple[float, float]: - """Convert normalized k-space coordinates to angles. 
- - Parameters - ---------- - ux : float - normalized kx coordinate. - uy : float - normalized ky coordinate. - axis : int - axis along which the observation plane is oriented. - - Returns - ------- - Tuple[float, float] - theta and phi coordinates relative to ``local_origin``. - """ - phi_local = np.arctan2(uy, ux) - with np.errstate(invalid="ignore"): - theta_local = np.arcsin(np.sqrt(ux**2 + uy**2)) - # Spherical coordinates rotation matrix reference: - # https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula#Matrix_notation - if axis == 2: - return theta_local, phi_local - - x = np.cos(theta_local) - y = np.sin(theta_local) * np.cos(phi_local) - z = np.sin(theta_local) * np.sin(phi_local) - - if axis == 1: - x, y, z = y, x, z - - theta = np.arccos(z) - phi = np.arctan2(y, x) - return theta, phi - - @staticmethod - @verify_packages_import(["gdstk"]) - def load_gds_vertices_gdstk( - gds_cell: Cell, - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: pydantic.PositiveFloat = 1.0, - ) -> list[ArrayFloat2D]: - """Load polygon vertices from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into - the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of micrometer. For example, if gds file uses - nanometers, set ``gds_scale=1e-3``. Must be positive. - - Returns - ------- - List[ArrayFloat2D] - List of polygon vertices - """ - - # apply desired scaling and load the polygon vertices - if gds_dtype is not None: - # if both layer and datatype are specified, let gdstk do the filtering for better - # performance on large layouts - all_vertices = [ - polygon.scale(gds_scale).points - for polygon in gds_cell.get_polygons(layer=gds_layer, datatype=gds_dtype) - ] - else: - all_vertices = [ - polygon.scale(gds_scale).points - for polygon in gds_cell.get_polygons() - if polygon.layer == gds_layer - ] - # make sure something got loaded, otherwise error - if not all_vertices: - raise Tidy3dKeyError( - f"Couldn't load gds_cell, no vertices found at gds_layer={gds_layer} " - f"with specified gds_dtype={gds_dtype}." - ) - - return all_vertices - - @staticmethod - @verify_packages_import(["gdstk"]) - def from_gds( - gds_cell: Cell, - axis: Axis, - slab_bounds: tuple[float, float], - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: pydantic.PositiveFloat = 1.0, - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> Geometry: - """Import a ``gdstk.Cell`` and extrude it into a GeometryGroup. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - axis : int - Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: Tuple[float, float] - Minimal and maximal positions of the extruded slab along ``axis``. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. If ``None``, imports all data for this layer into - the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of micrometer. For example, if gds file uses - nanometers, set ``gds_scale=1e-3``. Must be positive. - dilation : float = 0.0 - Dilation (positive) or erosion (negative) amount to be applied to the original polygons. 
- sidewall_angle : float = 0 - Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive - (negative) values result in slabs larger (smaller) at the base than at the top. - reference_plane : PlanePosition = "middle" - Reference position of the (dilated/eroded) polygons along the slab axis. One of - ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` - (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value - has no effect if ``sidewall_angle == 0``. - - Returns - ------- - :class:`Geometry` - Geometries created from the 2D data. - """ - import gdstk - - if not isinstance(gds_cell, gdstk.Cell): - # Check if it might be a gdstk cell but gdstk is not found (should be caught by decorator) - # or if it's an entirely different type. - if "gdstk" in gds_cell.__class__.__name__.lower(): - raise Tidy3dImportError( - "Module 'gdstk' not found. It is required to import gdstk cells." - ) - raise Tidy3dImportError("Argument 'gds_cell' must be an instance of 'gdstk.Cell'.") - - gds_loader_fn = Geometry.load_gds_vertices_gdstk - geometries = [] - with log as consolidated_logger: - for vertices in gds_loader_fn(gds_cell, gds_layer, gds_dtype, gds_scale): - # buffer(0) is necessary to merge self-intersections - shape = shapely.set_precision(shapely.Polygon(vertices).buffer(0), POLY_GRID_SIZE) - try: - geometries.append( - from_shapely( - shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane - ) - ) - except pydantic.ValidationError as error: - consolidated_logger.warning(str(error)) - except Tidy3dError as error: - consolidated_logger.warning(str(error)) - return geometries[0] if len(geometries) == 1 else GeometryGroup(geometries=geometries) - - @staticmethod - def from_shapely( - shape: Shapely, - axis: Axis, - slab_bounds: tuple[float, float], - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> Geometry: - """Convert a shapely primitive into a geometry instance by extrusion. - - Parameters - ---------- - shape : shapely.geometry.base.BaseGeometry - Shapely primitive to be converted. It must be a linear ring, a polygon or a collection - of any of those. - axis : int - Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: Tuple[float, float] - Minimal and maximal positions of the extruded slab along ``axis``. - dilation : float - Dilation of the polygon in the base by shifting each edge along its normal outwards - direction by a distance; a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive - (negative) values result in slabs larger (smaller) at the base than at the top. - reference_plane : PlanePosition = "middle" - Reference position of the (dilated/eroded) polygons along the slab axis. One of - ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` - (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value - has no effect if ``sidewall_angle == 0``. - - Returns - ------- - :class:`Geometry` - Geometry extruded from the 2D data. 
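`load_gds_vertices_gdstk` above takes two filtering paths: when `gds_dtype` is given it lets gdstk filter by layer and datatype natively (faster on large layouts), otherwise it filters by layer in Python. A sketch mirroring those calls; it requires the optional `gdstk` package and a `gdstk.Cell` input:

def load_vertices(gds_cell, gds_layer, gds_dtype=None, gds_scale=1.0):
    if gds_dtype is not None:
        polygons = gds_cell.get_polygons(layer=gds_layer, datatype=gds_dtype)
    else:
        polygons = [p for p in gds_cell.get_polygons() if p.layer == gds_layer]
    # scale converts the file's length unit to micrometers, e.g. 1e-3 for nm
    return [polygon.scale(gds_scale).points for polygon in polygons]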
- """ - return from_shapely(shape, axis, slab_bounds, dilation, sidewall_angle, reference_plane) - - @verify_packages_import(["gdstk"]) - def to_gdstk( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, - ) -> list: - """Convert a Geometry object's planar slice to a .gds type polygon. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - gds_layer : int = 0 - Layer index to use for the shapes stored in the .gds file. - gds_dtype : int = 0 - Data-type index to use for the shapes stored in the .gds file. - - Return - ------ - List - List of `gdstk.Polygon`. - """ - import gdstk - - shapes = self.intersections_plane(x=x, y=y, z=z) - polygons = [] - for shape in shapes: - for vertices in vertices_from_shapely(shape): - if len(vertices) == 1: - polygons.append(gdstk.Polygon(vertices[0], gds_layer, gds_dtype)) - else: - polygons.extend( - gdstk.boolean( - vertices[:1], - vertices[1:], - "not", - layer=gds_layer, - datatype=gds_dtype, - ) - ) - return polygons - - @verify_packages_import(["gdstk"]) - def to_gds( - self, - cell: Cell, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, - ) -> None: - """Append a Geometry object's planar slice to a .gds cell. - - Parameters - ---------- - cell : ``gdstk.Cell`` - Cell object to which the generated polygons are added. - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - gds_layer : int = 0 - Layer index to use for the shapes stored in the .gds file. - gds_dtype : int = 0 - Data-type index to use for the shapes stored in the .gds file. - """ - import gdstk - - if not isinstance(cell, gdstk.Cell): - if "gdstk" in cell.__class__.__name__.lower(): - raise Tidy3dImportError( - "Module 'gdstk' not found. It is required to export shapes to gdstk cells." - ) - raise Tidy3dImportError("Argument 'cell' must be an instance of 'gdstk.Cell'.") - - polygons = self.to_gdstk(x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) - if polygons: - cell.add(*polygons) - - @verify_packages_import(["gdstk"]) - def to_gds_file( - self, - fname: PathLike, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, - gds_cell_name: str = "MAIN", - ) -> None: - """Export a Geometry object's planar slice to a .gds file. - - Parameters - ---------- - fname : PathLike - Full path to the .gds file to save the :class:`Geometry` slice to. - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. 
- gds_layer : int = 0 - Layer index to use for the shapes stored in the .gds file. - gds_dtype : int = 0 - Data-type index to use for the shapes stored in the .gds file. - gds_cell_name : str = 'MAIN' - Name of the cell created in the .gds file to store the geometry. - """ - try: - import gdstk - except ImportError as e: - raise Tidy3dImportError( - "Python module 'gdstk' not found. To export geometries to .gds " - "files, please install it." - ) from e - - library = gdstk.Library() - cell = library.new_cell(gds_cell_name) - self.to_gds(cell, x=x, y=y, z=z, gds_layer=gds_layer, gds_dtype=gds_dtype) - fname = pathlib.Path(fname) - fname.parent.mkdir(parents=True, exist_ok=True) - library.write_gds(fname) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.") - - def _as_union(self) -> list[Geometry]: - """Return a list of geometries that, united, make up the given geometry.""" - if isinstance(self, GeometryGroup): - return self.geometries - - if isinstance(self, ClipOperation) and self.operation == "union": - return (self.geometry_a, self.geometry_b) - return (self,) - - def __add__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]: - """Union of geometries""" - # This allows the user to write sum(geometries...) with the default start=0 - if isinstance(other, int): - return self - if not isinstance(other, Geometry): - return NotImplemented # type: ignore[return-value] - return GeometryGroup(geometries=self._as_union() + other._as_union()) - - def __radd__(self, other: Union[int, Geometry]) -> Union[Self, GeometryGroup]: - """Union of geometries""" - # This allows the user to write sum(geometries...) 
with the default start=0 - if isinstance(other, int): - return self - if not isinstance(other, Geometry): - return NotImplemented # type: ignore[return-value] - return GeometryGroup(geometries=other._as_union() + self._as_union()) - - def __or__(self, other: Geometry) -> GeometryGroup: - """Union of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return GeometryGroup(geometries=self._as_union() + other._as_union()) - - def __mul__(self, other: Geometry) -> ClipOperation: - """Intersection of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return ClipOperation(operation="intersection", geometry_a=self, geometry_b=other) - - def __and__(self, other: Geometry) -> ClipOperation: - """Intersection of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return ClipOperation(operation="intersection", geometry_a=self, geometry_b=other) - - def __sub__(self, other: Geometry) -> ClipOperation: - """Difference of geometries""" - if not isinstance(other, Geometry): - return NotImplemented # type: ignore[return-value] - return ClipOperation(operation="difference", geometry_a=self, geometry_b=other) - - def __xor__(self, other: Geometry) -> ClipOperation: - """Symmetric difference of geometries""" - if not isinstance(other, Geometry): - return NotImplemented - return ClipOperation(operation="symmetric_difference", geometry_a=self, geometry_b=other) - - def __pos__(self) -> Self: - """No op""" - return self - - def __neg__(self) -> ClipOperation: - """Opposite of a geometry""" - return ClipOperation( - operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self - ) - - def __invert__(self) -> ClipOperation: - """Opposite of a geometry""" - return ClipOperation( - operation="difference", geometry_a=Box(size=(inf, inf, inf)), geometry_b=self - ) - - -""" Abstract subclasses """ - - -class Centered(Geometry, ABC): - """Geometry with a well defined center.""" - - center: TracedCoordinate = pydantic.Field( - (0.0, 0.0, 0.0), - title="Center", - description="Center of object in x, y, and z.", - units=MICROMETER, - ) - - @pydantic.validator("center", always=True) - def _center_not_inf(cls, val: tuple[float, float, float]) -> tuple[float, float, float]: - """Make sure center is not infinitiy.""" - if any(np.isinf(v) for v in val): - raise ValidationError("center can not contain td.inf terms.") - return val - - -class SimplePlaneIntersection(Geometry, ABC): - """A geometry where intersections with an axis aligned plane may be computed efficiently.""" - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - Checks special cases before relying on the complete computation. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. 
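The operator overloads above give geometries a small boolean algebra: `+`/`|` build a `GeometryGroup` union, `*`/`&` an intersection, `-` a difference, `^` a symmetric difference, and unary `-`/`~` the complement against an all-space `Box`. Usage sketch:

import tidy3d as td

a = td.Box(center=(0, 0, 0), size=(2, 2, 2))
b = td.Box(center=(1, 0, 0), size=(2, 2, 2))

union = a | b        # GeometryGroup, equivalent to a + b
overlap = a & b      # ClipOperation(operation="intersection", ...)
cutout = a - b       # ClipOperation(operation="difference", ...)
either = a ^ b       # symmetric difference
outside = ~a         # infinite box minus a
total = sum([a, b])  # works because __radd__ accepts the int start value 0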
- For more details refer to - `Shapely's Documentation `_. - """ - - # Check if normal is a special case, where the normal is aligned with an axis. - if np.sum(np.isclose(normal, 0.0)) == 2: - axis = np.argmax(np.abs(normal)).item() - coord = "xyz"[axis] - kwargs = {coord: origin[axis]} - section = self.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **kwargs) - # Apply transformation in the plane by removing row and column - to_2D_in_plane = np.delete(np.delete(to_2D, 2, 0), axis, 1) - - def transform(p_array: NDArray) -> NDArray: - return np.dot( - np.hstack((p_array, np.ones((p_array.shape[0], 1)))), to_2D_in_plane.T - )[:, :2] - - transformed_section = shapely.transform(section, transformation=transform) - return transformed_section - # Otherwise compute the arbitrary intersection - return self._do_intersections_tilted_plane( - normal=normal, origin=origin, to_2D=to_2D, quad_segs=quad_segs - ) - - @abstractmethod - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - -class Planar(SimplePlaneIntersection, Geometry, ABC): - """Geometry with one ``axis`` that is slab-like with thickness ``height``.""" - - axis: Axis = pydantic.Field( - 2, title="Axis", description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z)." - ) - - sidewall_angle: TracedFloat = pydantic.Field( - 0.0, - title="Sidewall angle", - description="Angle of the sidewall. " - "``sidewall_angle=0`` (default) specifies a vertical wall; " - "``0 float: - lower_bound = -np.pi / 2 - upper_bound = np.pi / 2 - if (value <= lower_bound) or (value >= upper_bound): - # u03C0 is unicode for pi - raise ValidationError(f"Sidewall angle ({value}) must be between -π/2 and π/2 rad.") - - return value - - @property - @abstractmethod - def center_axis(self) -> float: - """Gets the position of the center of the geometry in the out of plane dimension.""" - - @property - @abstractmethod - def length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension.""" - - @property - def finite_length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension. - If the length is td.inf, return ``LARGE_NUMBER`` - """ - return min(self.length_axis, LARGE_NUMBER) - - @property - def reference_axis_pos(self) -> float: - """Coordinate along the slab axis at the reference plane. 
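`SimplePlaneIntersection.intersections_tilted_plane` above short-circuits when the cutting plane is axis-aligned, which it detects by counting vanishing normal components; only genuinely tilted planes pay for the full computation. A sketch of that test:

import numpy as np

def axis_aligned_normal(normal):
    normal = np.asarray(normal, dtype=float)
    if np.sum(np.isclose(normal, 0.0)) == 2:
        return int(np.argmax(np.abs(normal)))  # index of the normal axis
    return None  # tilted plane: take the general path

assert axis_aligned_normal((0, 0, 1)) == 2
assert axis_aligned_normal((0.3, 0, 0.7)) is None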
- - Returns the axis coordinate corresponding to the selected - reference_plane: - - "bottom": lower bound of slab_bounds - - "middle": center_axis - - "top": upper bound of slab_bounds - """ - if self.reference_plane == "bottom": - return self.slab_bounds[0] - if self.reference_plane == "top": - return self.slab_bounds[1] - # default to middle - return self.center_axis - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns shapely geometry at plane specified by one non None value of x,y,z. - - Parameters - ---------- - x : float - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation ``. - """ - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - if not self.intersects_axis_position(axis, position): - return [] - if axis == self.axis: - return self._intersections_normal(position, quad_segs=quad_segs) - return self._intersections_side(position, axis) - - @abstractmethod - def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list: - """Find shapely geometries intersecting planar geometry with axis normal to slab. - - Parameters - ---------- - z : float - Position along the axis normal to slab - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - @abstractmethod - def _intersections_side(self, position: float, axis: Axis) -> list[Shapely]: - """Find shapely geometries intersecting planar geometry with axis orthogonal to plane. - - Parameters - ---------- - position : float - Position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - - def _order_axis(self, axis: int) -> int: - """Order the axis as if self.axis is along z-direction. - - Parameters - ---------- - axis : int - Integer index into the structure's planar axis. - - Returns - ------- - int - New index of axis. - """ - axis_index = [0, 1] - axis_index.insert(self.axis, 2) - return axis_index[axis] - - def _order_by_axis(self, plane_val: Any, axis_val: Any, axis: int) -> tuple[Any, Any]: - """Orders a value in the plane and value along axis in correct (x,y) order for plotting. - Note: sometimes if axis=1 and we compute cross section values orthogonal to axis, - they can either be x or y in the plots. - This function allows one to figure out the ordering. - - Parameters - ---------- - plane_val : Any - The value in the planar coordinate. 
- axis_val : Any
- The value in the ``axis`` coordinate.
- axis : int
- Integer index into the structure's planar axis.
-
- Returns
- -------
- ``(Any, Any)``
- The two planar coordinates in this new coordinate system.
- """
- vals = 3 * [plane_val]
- vals[self.axis] = axis_val
- _, (val_x, val_y) = self.pop_axis(vals, axis=axis)
- return val_x, val_y
-
- @cached_property
- def _tanq(self) -> float:
- """Value of ``tan(sidewall_angle)``.
-
- The (possibly infinite) geometry offset is given by ``_tanq * length_axis``.
- """
- return np.tan(self.sidewall_angle)
-
-
-class Circular(Geometry):
- """Geometry with circular characteristics (specified by a radius)."""
-
- radius: pydantic.NonNegativeFloat = pydantic.Field(
- ..., title="Radius", description="Radius of geometry.", units=MICROMETER
- )
-
- @pydantic.validator("radius", always=True)
- def _radius_not_inf(cls, val: float) -> float:
- """Make sure radius is not infinity."""
- if np.isinf(val):
- raise ValidationError("radius cannot be td.inf.")
- return val
-
- def _intersect_dist(self, position: float, z0: float) -> Optional[float]:
- """Distance between points on circle at z=position where center of circle at z=z0.
-
- Parameters
- ----------
- position : float
- position along z.
- z0 : float
- center of circle in z.
-
- Returns
- -------
- float
- Distance between points on the circle intersecting z=position; if no points, ``None``.
- """
- dz = np.abs(z0 - position)
- if dz > self.radius:
- return None
- return 2 * np.sqrt(self.radius**2 - dz**2)
-
-
-"""Primitive classes"""
-
-
-class Box(SimplePlaneIntersection, Centered):
- """Rectangular prism.
- Also base class for :class:`.Simulation`, :class:`Monitor`, and :class:`Source`.
-
- Example
- -------
- >>> b = Box(center=(1,2,3), size=(2,2,2))
- """
-
- size: TracedSize = pydantic.Field(
- ...,
- title="Size",
- description="Size in x, y, and z directions.",
- units=MICROMETER,
- )
-
- @classmethod
- def from_bounds(cls, rmin: Coordinate, rmax: Coordinate, **kwargs: Any) -> Self:
- """Constructs a :class:`Box` from minimum and maximum coordinate bounds.
-
- Parameters
- ----------
- rmin : Tuple[float, float, float]
- (x, y, z) coordinate of the minimum values.
- rmax : Tuple[float, float, float]
- (x, y, z) coordinate of the maximum values.
-
- Example
- -------
- >>> b = Box.from_bounds(rmin=(-1, -2, -3), rmax=(3, 2, 1))
- """
-
- center = tuple(cls._get_center(pt_min, pt_max) for pt_min, pt_max in zip(rmin, rmax))
- size = tuple((pt_max - pt_min) for pt_min, pt_max in zip(rmin, rmax))
- return cls(center=center, size=size, **kwargs)
-
- @cached_property
- def _normal_axis(self) -> Axis:
- """Axis normal to the Box. Errors if box is not planar."""
- if self.size.count(0.0) != 1:
- raise ValidationError(
- f"Tried to get 'normal_axis' of 'Box' that is not planar. Given 'size={self.size}'."
- )
- return self.size.index(0.0)
-
- @classmethod
- def surfaces(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]:
- """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume.
- The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z
- denote which axis is perpendicular to that surface, while "-" and "+" denote the direction
- of the normal vector of that surface. If a name is provided, each output surface's name
- will be that of the provided name appended with the above symbols. E.g., if the provided
- name is "box", the x+ surface's name will be "box_x+".
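
``_intersect_dist`` in :class:`Circular` above is the standard chord-length formula: a plane at ``z = position`` cuts a circle of radius ``r`` centered at ``z = z0`` in a chord of length ``2 * sqrt(r**2 - dz**2)``. A standalone check mirroring the method's logic:

```python
import numpy as np

def intersect_dist(position: float, z0: float, radius: float):
    # Chord length of the circle's cross section at z = position, or None.
    dz = np.abs(z0 - position)
    if dz > radius:
        return None
    return 2 * np.sqrt(radius**2 - dz**2)

print(intersect_dist(0.0, 0.0, 1.0))  # 2.0: plane through the center, full diameter
print(intersect_dist(1.0, 0.0, 1.0))  # 0.0: plane tangent to the circle
print(intersect_dist(2.0, 0.0, 1.0))  # None: plane misses the circle
```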
-
- Parameters
- ----------
- size : Tuple[float, float, float]
- Size of object in x, y, and z directions.
- center : Tuple[float, float, float]
- Center of object in x, y, and z.
-
- Example
- -------
- >>> b = Box.surfaces(size=(1, 2, 3), center=(3, 2, 1))
- """
-
- if any(s == 0.0 for s in size):
- raise SetupError(
- "Can't generate surfaces for the given object because it has zero volume."
- )
-
- bounds = Box(center=center, size=size).bounds
-
- # Set up geometry data and names for each surface:
- centers = [list(center) for _ in range(6)]
- sizes = [list(size) for _ in range(6)]
-
- surface_index = 0
- for dim_index in range(3):
- for min_max_index in range(2):
- new_center = centers[surface_index]
- new_size = sizes[surface_index]
-
- new_center[dim_index] = bounds[min_max_index][dim_index]
- new_size[dim_index] = 0.0
-
- centers[surface_index] = new_center
- sizes[surface_index] = new_size
-
- surface_index += 1
-
- name_base = kwargs.pop("name", "")
- kwargs.pop("normal_dir", None)
-
- names = []
- normal_dirs = []
-
- for coord in "xyz":
- for direction in "-+":
- surface_name = name_base + "_" + coord + direction
- names.append(surface_name)
- normal_dirs.append(direction)
-
- # ignore surfaces that are infinitely far away
- del_idx = []
- for idx, _size in enumerate(size):
- if _size == inf:
- del_idx.append(idx)
- del_idx = [[2 * i, 2 * i + 1] for i in del_idx]
- del_idx = [item for sublist in del_idx for item in sublist]
-
- def del_items(items: Iterable, indices: list[int]) -> list:
- """Delete list items at indices."""
- return [i for j, i in enumerate(items) if j not in indices]
-
- centers = del_items(centers, del_idx)
- sizes = del_items(sizes, del_idx)
- names = del_items(names, del_idx)
- normal_dirs = del_items(normal_dirs, del_idx)
-
- surfaces = []
- for _cent, _size, _name, _normal_dir in zip(centers, sizes, names, normal_dirs):
- if "normal_dir" in cls.__dict__["__fields__"]:
- kwargs["normal_dir"] = _normal_dir
-
- if "name" in cls.__dict__["__fields__"]:
- kwargs["name"] = _name
-
- surface = cls(center=_cent, size=_size, **kwargs)
- surfaces.append(surface)
-
- return surfaces
-
- @classmethod
- def surfaces_with_exclusion(cls, size: Size, center: Coordinate, **kwargs: Any) -> list[Self]:
- """Returns a list of 6 :class:`Box` instances corresponding to each surface of a 3D volume.
- The output surfaces are stored in the order [x-, x+, y-, y+, z-, z+], where x, y, and z
- denote which axis is perpendicular to that surface, while "-" and "+" denote the direction
- of the normal vector of that surface. If a name is provided, each output surface's name
- will be that of the provided name appended with the above symbols. E.g., if the provided
- name is "box", the x+ surface's name will be "box_x+". If ``kwargs`` contains an
- ``exclude_surfaces`` parameter, the returned list of surfaces will not include the excluded
- surfaces. Otherwise, the behavior is identical to that of ``surfaces()``.
-
- Parameters
- ----------
- size : Tuple[float, float, float]
- Size of object in x, y, and z directions.
- center : Tuple[float, float, float]
- Center of object in x, y, and z.
-
- Example
- -------
- >>> b = Box.surfaces_with_exclusion(
- ... size=(1, 2, 3), center=(3, 2, 1), exclude_surfaces=["x-"]
- ...
)
- """
- exclude_surfaces = kwargs.pop("exclude_surfaces", None)
- surfaces = cls.surfaces(size=size, center=center, **kwargs)
- if "name" in cls.__dict__["__fields__"] and exclude_surfaces:
- surfaces = [surf for surf in surfaces if surf.name[-2:] not in exclude_surfaces]
- return surfaces
-
- @verify_packages_import(["trimesh"])
- def _do_intersections_tilted_plane(
- self,
- normal: Coordinate,
- origin: Coordinate,
- to_2D: MatrixReal4x4,
- quad_segs: Optional[int] = None,
- ) -> list[Shapely]:
- """Return a list of shapely geometries at the plane specified by normal and origin.
-
- Parameters
- ----------
- normal : Coordinate
- Vector defining the normal direction to the plane.
- origin : Coordinate
- Vector defining the plane origin.
- to_2D : MatrixReal4x4
- Transformation matrix to apply to resulting shapes.
- quad_segs : Optional[int] = None
- Number of segments used to discretize circular shapes. Not used for Box geometry.
-
- Returns
- -------
- List[shapely.geometry.base.BaseGeometry]
- List of 2D shapes that intersect plane.
- For more details refer to
- `Shapely's Documentation <https://shapely.readthedocs.io/en/stable/project.html>`_.
- """
- import trimesh
-
- (x0, y0, z0), (x1, y1, z1) = self.bounds
- vertices = [
- (x0, y0, z0), # 0
- (x0, y0, z1), # 1
- (x0, y1, z0), # 2
- (x0, y1, z1), # 3
- (x1, y0, z0), # 4
- (x1, y0, z1), # 5
- (x1, y1, z0), # 6
- (x1, y1, z1), # 7
- ]
- faces = [
- (0, 1, 3, 2), # -x
- (4, 6, 7, 5), # +x
- (0, 4, 5, 1), # -y
- (2, 3, 7, 6), # +y
- (0, 2, 6, 4), # -z
- (1, 5, 7, 3), # +z
- ]
- mesh = trimesh.Trimesh(vertices, faces)
-
- section = mesh.section(plane_origin=origin, plane_normal=normal)
- if section is None:
- return []
- path, _ = section.to_2D(to_2D=to_2D)
- return path.polygons_full
-
- def intersections_plane(
- self,
- x: Optional[float] = None,
- y: Optional[float] = None,
- z: Optional[float] = None,
- cleanup: bool = True,
- quad_segs: Optional[int] = None,
- ) -> list[Shapely]:
- """Returns shapely geometry at plane specified by one non-None value of x,y,z.
-
- Parameters
- ----------
- x : float = None
- Position of plane in x direction, only one of x,y,z can be specified to define plane.
- y : float = None
- Position of plane in y direction, only one of x,y,z can be specified to define plane.
- z : float = None
- Position of plane in z direction, only one of x,y,z can be specified to define plane.
- cleanup : bool = True
- If True, removes extremely small features from each polygon's boundary.
- quad_segs : Optional[int] = None
- Number of segments used to discretize circular shapes. Not used for Box geometry.
-
- Returns
- -------
- List[shapely.geometry.base.BaseGeometry]
- List of 2D shapes that intersect plane.
- For more details refer to
- `Shapely's Documentation <https://shapely.readthedocs.io/en/stable/project.html>`_.
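
A standalone numpy sketch of the face-construction bookkeeping in ``surfaces`` above (not the tidy3d API; it only mimics how each face keeps the parent center/size except along one axis, where the center snaps to a bound and the size collapses to zero):

```python
import numpy as np

center, size = np.array([3.0, 2.0, 1.0]), np.array([1.0, 2.0, 3.0])
rmin, rmax = center - size / 2, center + size / 2

faces = []
for dim in range(3):
    for bound, sign in ((rmin, "-"), (rmax, "+")):
        face_center, face_size = center.copy(), size.copy()
        face_center[dim] = bound[dim]  # snap to the min/max bound on this axis
        face_size[dim] = 0.0           # collapse the face to zero thickness
        faces.append(("xyz"[dim] + sign, face_center, face_size))

for name, c, s in faces:
    print(name, c, s)  # x-, x+, y-, y+, z-, z+ in order
```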
- """ - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - if not self.intersects_axis_position(axis, position): - return [] - z0, (x0, y0) = self.pop_axis(self.center, axis=axis) - Lz, (Lx, Ly) = self.pop_axis(self.size, axis=axis) - dz = np.abs(z0 - position) - if dz > Lz / 2 + fp_eps: - return [] - - minx = x0 - Lx / 2 - miny = y0 - Ly / 2 - maxx = x0 + Lx / 2 - maxy = y0 + Ly / 2 - - # handle case where the box vertices are identical - if np.isclose(minx, maxx) and np.isclose(miny, maxy): - return [self.make_shapely_point(minx, miny)] - - return [self.make_shapely_box(minx, miny, maxx, maxy)] - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - self._ensure_equal_shape(x, y, z) - x0, y0, z0 = self.center - Lx, Ly, Lz = self.size - dist_x = np.abs(x - x0) - dist_y = np.abs(y - y0) - dist_z = np.abs(z - z0) - return (dist_x <= Lx / 2) * (dist_y <= Ly / 2) * (dist_z <= Lz / 2) - - def intersections_with( - self, other: Shapely, cleanup: bool = True, quad_segs: Optional[int] = None - ) -> list[Shapely]: - """Returns list of shapely geometries representing the intersections of the geometry with - this 2D box. - - Parameters - ---------- - other : Shapely - Geometry to intersect with. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect this 2D box. - For more details refer to - `Shapely's Documentation `_. - """ - - # Verify 2D - if self.size.count(0.0) != 1: - raise ValidationError( - "Intersections with other geometry are only calculated from a 2D box." 
- )
-
- # don't bother if the geometry doesn't intersect this box at all
- if not other.intersects(self):
- return []
-
- # get list of Shapely shapes that intersect the plane of this box
- normal_ind = self.size.index(0.0)
- dim = "xyz"[normal_ind]
- pos = self.center[normal_ind]
- xyz_kwargs = {dim: pos}
- shapes_plane = other.intersections_plane(cleanup=cleanup, quad_segs=quad_segs, **xyz_kwargs)
-
- # intersect all shapes with this box's in-plane extent
- bs_min, bs_max = (self.pop_axis(bounds, axis=normal_ind)[1] for bounds in self.bounds)
-
- shapely_box = self.make_shapely_box(bs_min[0], bs_min[1], bs_max[0], bs_max[1])
- shapely_box = Geometry.evaluate_inf_shape(shapely_box)
- return [Geometry.evaluate_inf_shape(shape) & shapely_box for shape in shapes_plane]
-
- def slightly_enlarged_copy(self) -> Box:
- """Box size slightly enlarged around machine precision."""
- size = [increment_float(orig_length, 1) for orig_length in self.size]
- return self.updated_copy(size=size)
-
- def padded_copy(
- self,
- x: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
- y: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
- z: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
- ) -> Box:
- """Create a padded copy of a :class:`Box` instance.
-
- Parameters
- ----------
- x : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
- Padding sizes at the left and right boundaries of the box along x-axis.
- y : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
- Padding sizes at the left and right boundaries of the box along y-axis.
- z : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
- Padding sizes at the left and right boundaries of the box along z-axis.
-
- Returns
- -------
- Box
- Padded instance of :class:`Box`.
- """
-
- # Validate that padding values are non-negative
- for axis_name, axis_padding in zip(("x", "y", "z"), (x, y, z)):
- if axis_padding is not None:
- if not isinstance(axis_padding, (tuple, list)) or len(axis_padding) != 2:
- raise ValueError(f"Padding for {axis_name}-axis must be a tuple of two values.")
- if any(p < 0 for p in axis_padding):
- raise ValueError(
- f"Padding values for {axis_name}-axis must be non-negative. Got {axis_padding}."
- )
-
- rmin, rmax = self.bounds
-
- def bound_array(arrs: ArrayLike, idx: int) -> NDArray:
- return np.array([(a[idx] if a is not None else 0) for a in arrs])
-
- # parse padding sizes for simulation
- drmin = bound_array((x, y, z), 0)
- drmax = bound_array((x, y, z), 1)
-
- rmin = np.array(rmin) - drmin
- rmax = np.array(rmax) + drmax
-
- return Box.from_bounds(rmin=rmin, rmax=rmax)
-
- @cached_property
- def bounds(self) -> Bound:
- """Returns bounding box min and max coordinates.
-
- Returns
- -------
- Tuple[float, float, float], Tuple[float, float, float]
- Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
- """
- size = self.size
- center = self.center
- coord_min = tuple(c - s / 2 for (s, c) in zip(size, center))
- coord_max = tuple(c + s / 2 for (s, c) in zip(size, center))
- return (coord_min, coord_max)
-
- @cached_property
- def geometry(self) -> Box:
- """:class:`Box` representation of self (used for subclasses of Box).
-
- Returns
- -------
- :class:`Box`
- Instance of :class:`Box` representing self's geometry.
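
A sketch of the bounds arithmetic behind ``padded_copy`` above, with hypothetical padding values; per-axis ``(low, high)`` paddings widen ``rmin``/``rmax`` independently and ``None`` means no padding on that axis:

```python
import numpy as np

rmin, rmax = np.array([0.0, 0.0, 0.0]), np.array([2.0, 2.0, 2.0])
paddings = ((0.5, 1.0), None, (0.0, 0.25))  # x, y, z

drmin = np.array([(p[0] if p is not None else 0.0) for p in paddings])
drmax = np.array([(p[1] if p is not None else 0.0) for p in paddings])

print(rmin - drmin)  # [-0.5  0.   0.  ]
print(rmax + drmax)  # [3.    2.    2.25]
```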
- """ - return Box(center=self.center, size=self.size) - - @cached_property - def zero_dims(self) -> list[Axis]: - """A list of axes along which the :class:`Box` is zero-sized.""" - return [dim for dim, size in enumerate(self.size) if size == 0] - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - if np.count_nonzero(self.size) != 2: - raise ValidationError( - "'Medium2D' requires exactly one of the 'Box' dimensions to have size zero." - ) - return self.size.index(0) - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Box: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - new_center = list(self.center) - new_center[axis] = (bounds[0] + bounds[1]) / 2 - new_size = list(self.size) - new_size[axis] = bounds[1] - bounds[0] - return self.updated_copy(center=new_center, size=new_size) - - def _plot_arrow( - self, - direction: tuple[float, float, float], - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - color: Optional[str] = None, - alpha: Optional[float] = None, - bend_radius: Optional[float] = None, - bend_axis: Axis = None, - both_dirs: bool = False, - ax: Ax = None, - arrow_base: Coordinate = None, - ) -> Ax: - """Adds an arrow to the axis if with options if certain conditions met. - - Parameters - ---------- - direction: Tuple[float, float, float] - Normalized vector describing the arrow direction. - x : float = None - Position of plotting plane in x direction. - y : float = None - Position of plotting plane in y direction. - z : float = None - Position of plotting plane in z direction. - color : str = None - Color of the arrow. - alpha : float = None - Opacity of the arrow (0, 1) - bend_radius : float = None - Radius of curvature for this arrow. - bend_axis : Axis = None - Axis of curvature of ``bend_radius``. - both_dirs : bool = False - If True, plots an arrow pointing in direction and one in -direction. - arrow_base : :class:`.Coordinate` = None - Custom base of the arrow. Uses the geometry's center if not provided. - - Returns - ------- - matplotlib.axes._subplots.Axes - The matplotlib axes with the arrow added. - """ - - plot_axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z) - _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) - - # conditions to check to determine whether to plot arrow, taking into account the - # possibility of a custom arrow base - arrow_intersecting_plane = len(self.intersections_plane(x=x, y=y, z=z)) > 0 - center = self.center - if arrow_base: - arrow_intersecting_plane = arrow_intersecting_plane and any( - a == b for a, b in zip(arrow_base, [x, y, z]) - ) - center = arrow_base - - _, (dx, dy) = self.pop_axis(direction, axis=plot_axis) - components_in_plane = any(not np.isclose(component, 0) for component in (dx, dy)) - - # plot if arrow in plotting plane and some non-zero component can be displayed. - if arrow_intersecting_plane and components_in_plane: - _, (x0, y0) = self.pop_axis(center, axis=plot_axis) - - # Reasonable value for temporary arrow size. The correct size and direction - # have to be calculated after all transforms have been set. That is why we - # use a callback to do these calculations only at the drawing phase. 
- xmin, xmax = ax.get_xlim() - ymin, ymax = ax.get_ylim() - v_x = (xmax - xmin) / 10 - v_y = (ymax - ymin) / 10 - - directions = (1.0, -1.0) if both_dirs else (1.0,) - for sign in directions: - arrow = patches.FancyArrowPatch( - (x0, y0), - (x0 + v_x, y0 + v_y), - arrowstyle=arrow_style, - color=color, - alpha=alpha, - zorder=np.inf, - ) - # Don't draw this arrow until it's been reshaped - arrow.set_visible(False) - - callback = self._arrow_shape_cb( - arrow, (x0, y0), (dx, dy), sign, bend_radius if bend_axis == plot_axis else None - ) - callback_id = ax.figure.canvas.mpl_connect("draw_event", callback) - - # Store a reference to the callback because mpl_connect does not. - arrow.set_shape_cb = (callback_id, callback) - - ax.add_patch(arrow) - - return ax - - @staticmethod - def _arrow_shape_cb( - arrow: FancyArrowPatch, - pos: tuple[float, float], - direction: ArrayLike, - sign: float, - bend_radius: float | None, - ) -> Callable[[Event], None]: - def _cb(event: Event) -> None: - # We only want to set the shape once, so we disconnect ourselves - event.canvas.mpl_disconnect(arrow.set_shape_cb[0]) - - transform = arrow.axes.transData.transform - scale_x = transform((1, 0))[0] - transform((0, 0))[0] - scale_y = transform((0, 1))[1] - transform((0, 0))[1] - scale = max(scale_x, scale_y) # <-- Hack: This is a somewhat arbitrary choice. - arrow_length = ARROW_LENGTH * event.canvas.figure.get_dpi() / scale - - if bend_radius: - v_norm = (direction[0] ** 2 + direction[1] ** 2) ** 0.5 - vx_norm = direction[0] / v_norm - vy_norm = direction[1] / v_norm - bend_angle = -sign * arrow_length / bend_radius - t_x = 1 - np.cos(bend_angle) - t_y = np.sin(bend_angle) - v_x = -bend_radius * (vx_norm * t_y - vy_norm * t_x) - v_y = -bend_radius * (vx_norm * t_x + vy_norm * t_y) - tangent_angle = np.arctan2(direction[1], direction[0]) - arrow.set_connectionstyle( - patches.ConnectionStyle.Angle3( - angleA=180 / np.pi * tangent_angle, - angleB=180 / np.pi * (tangent_angle + bend_angle), - ) - ) - - else: - v_x = sign * arrow_length * direction[0] - v_y = sign * arrow_length * direction[1] - - arrow.set_positions(pos, (pos[0] + v_x, pos[1] + v_y)) - arrow.set_visible(True) - arrow.draw(event.renderer) - - return _cb - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - volume = 1 - - for axis in range(3): - min_bound = max(self.bounds[0][axis], bounds[0][axis]) - max_bound = min(self.bounds[1][axis], bounds[1][axis]) - - volume *= max_bound - min_bound - - return volume - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - min_bounds = list(self.bounds[0]) - max_bounds = list(self.bounds[1]) - - in_bounds_factor = [2, 2, 2] - length = [0, 0, 0] - - for axis in (0, 1, 2): - if min_bounds[axis] < bounds[0][axis]: - min_bounds[axis] = bounds[0][axis] - in_bounds_factor[axis] -= 1 - - if max_bounds[axis] > bounds[1][axis]: - max_bounds[axis] = bounds[1][axis] - in_bounds_factor[axis] -= 1 - - length[axis] = max_bounds[axis] - min_bounds[axis] - - return ( - length[0] * length[1] * in_bounds_factor[2] - + length[1] * length[2] * in_bounds_factor[0] - + length[2] * length[0] * in_bounds_factor[1] - ) - - """ Autograd code """ - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - - # get gradients w.r.t. 
each of the 6 faces (in normal direction) - vjps_faces = self._derivative_faces(derivative_info=derivative_info) - - # post-process these values to give the gradients w.r.t. center and size - vjps_center_size = self._derivatives_center_size(vjps_faces=vjps_faces) - - # store only the gradients asked for in 'field_paths' - derivative_map = {} - for field_path in derivative_info.paths: - field_name, *index = field_path - - if field_name in vjps_center_size: - # if the vjp calls for a specific index into the tuple - if index and len(index) == 1: - index = int(index[0]) - if field_path not in derivative_map: - derivative_map[field_path] = vjps_center_size[field_name][index] - - # otherwise, just grab the whole array - else: - derivative_map[field_path] = vjps_center_size[field_name] - - return derivative_map - - @staticmethod - def _derivatives_center_size(vjps_faces: Bound) -> dict[str, Coordinate]: - """Derivatives with respect to the ``center`` and ``size`` fields in the ``Box``.""" - - vjps_faces_min, vjps_faces_max = np.array(vjps_faces) - - # post-process min and max face gradients into center and size - vjp_center = vjps_faces_max - vjps_faces_min - vjp_size = (vjps_faces_min + vjps_faces_max) / 2.0 - - return { - "center": tuple(vjp_center.tolist()), - "size": tuple(vjp_size.tolist()), - } - - def _derivative_faces(self, derivative_info: DerivativeInfo) -> Bound: - """Derivative with respect to normal position of 6 faces of ``Box``.""" - - axes_to_compute = (0, 1, 2) - if len(derivative_info.paths[0]) > 1: - axes_to_compute = tuple(info[1] for info in derivative_info.paths) - - # change in permittivity between inside and outside - vjp_faces = np.zeros((2, 3)) - - for min_max_index, _ in enumerate((0, -1)): - for axis in axes_to_compute: - vjp_face = self._derivative_face( - min_max_index=min_max_index, - axis_normal=axis, - derivative_info=derivative_info, - ) - - # record vjp for this face - vjp_faces[min_max_index, axis] = vjp_face - - return vjp_faces - - def _derivative_face( - self, - min_max_index: int, - axis_normal: Axis, - derivative_info: DerivativeInfo, - ) -> float: - """Compute the derivative w.r.t. shifting a face in the normal direction.""" - - interpolators = derivative_info.interpolators or derivative_info.create_interpolators() - _, axis_perp = self.pop_axis((0, 1, 2), axis=axis_normal) - - # First, check if the face is outside the simulation domain in which case set the - # face gradient to 0. 
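
The mapping in ``_derivatives_center_size`` above follows from the chain rule once the face gradients are taken in the outward-normal convention (a positive shift moves each face outward): ``vjp_center = g_max - g_min`` and ``vjp_size = (g_min + g_max) / 2``. A finite-difference sanity check on a toy scalar objective (not the adjoint machinery):

```python
import numpy as np

def loss_faces(m, M):  # arbitrary smooth function of the two face positions
    return m**2 + 3.0 * M

def loss_cs(c, s):  # same objective parameterized by center and size
    return loss_faces(c - s / 2, c + s / 2)

c, s, h = 1.0, 2.0, 1e-6
m, M = c - s / 2, c + s / 2
dL_dm = (loss_faces(m + h, M) - loss_faces(m, M)) / h
dL_dM = (loss_faces(m, M + h) - loss_faces(m, M)) / h

g_min, g_max = -dL_dm, dL_dM  # outward shift of the min face decreases m

vjp_center = g_max - g_min        # formula used in the code above
vjp_size = (g_min + g_max) / 2.0

fd_center = (loss_cs(c + h, s) - loss_cs(c, s)) / h
fd_size = (loss_cs(c, s + h) - loss_cs(c, s)) / h
print(np.isclose(vjp_center, fd_center, atol=1e-4))  # True
print(np.isclose(vjp_size, fd_size, atol=1e-4))      # True
```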
- bounds_normal, bounds_perp = self.pop_axis( - np.array(derivative_info.bounds).T, axis=axis_normal - ) - coord_normal_face = bounds_normal[min_max_index] - - if min_max_index == 0: - if coord_normal_face < derivative_info.simulation_bounds[0][axis_normal]: - return 0.0 - else: - if coord_normal_face > derivative_info.simulation_bounds[1][axis_normal]: - return 0.0 - - intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) - extents = intersect_max - intersect_min - _, intersect_min_perp = self.pop_axis(np.array(intersect_min), axis=axis_normal) - _, intersect_max_perp = self.pop_axis(np.array(intersect_max), axis=axis_normal) - - is_2d_map = [] - for axis_idx in range(3): - if axis_idx == axis_normal: - continue - is_2d_map.append(np.isclose(extents[axis_idx], 0.0)) - - if np.all(is_2d_map): - return 0.0 - - is_2d = np.any(is_2d_map) - - sim_bounds_normal, sim_bounds_perp = self.pop_axis( - np.array(derivative_info.simulation_bounds).T, axis=axis_normal - ) - - # Build point grid - adaptive_spacing = derivative_info.adaptive_vjp_spacing() - - def spacing_to_grid_points( - spacing: float, min_coord: float, max_coord: float - ) -> NDArray[float]: - N = np.maximum(3, 1 + int((max_coord - min_coord) / spacing)) - - points = np.linspace(min_coord, max_coord, N) - centers = 0.5 * (points[0:-1] + points[1:]) - - return centers - - def verify_integration_interval(bound: tuple[float, float]) -> bool: - # assume the bounds should not be equal or else this integration interval - # would be the flat dimension of a 2D geometry. - return bound[1] > bound[0] - - def compute_integration_weight(grid_points: NDArray[float]) -> float: - grid_spacing = grid_points[1] - grid_points[0] - if grid_spacing == 0.0: - integration_weight = 1.0 / len(grid_points) - else: - integration_weight = grid_points[1] - grid_points[0] - - return integration_weight - - if is_2d: - # build 1D grid for sampling points along the face, which is an edge in the 2D case - zero_dim = np.where(is_2d_map)[0][0] - # zero dim is one of the perpendicular directions, so the other perpendicular direction - # is the nonzero dimension - nonzero_dim = 1 - zero_dim - - # clip at simulation bounds for integration dimension - integration_bounds_perp = ( - intersect_min_perp[nonzero_dim], - intersect_max_perp[nonzero_dim], - ) - - if not verify_integration_interval(integration_bounds_perp): - return 0.0 - - grid_points_linear = spacing_to_grid_points( - adaptive_spacing, integration_bounds_perp[0], integration_bounds_perp[1] - ) - integration_weight = compute_integration_weight(grid_points_linear) - - grid_points = np.repeat(np.expand_dims(grid_points_linear.copy(), 1), 3, axis=1) - - # set up grid points to pass into evaluate_gradient_at_points - grid_points[:, axis_perp[nonzero_dim]] = grid_points_linear - grid_points[:, axis_perp[zero_dim]] = intersect_min_perp[zero_dim] - grid_points[:, axis_normal] = coord_normal_face - else: - # build 3D grid for sampling points along the face - - # clip at simulation bounds for each integration dimension - integration_bounds_perp = ( - (intersect_min_perp[0], intersect_max_perp[0]), - (intersect_min_perp[1], intersect_max_perp[1]), - ) - - if not np.all([verify_integration_interval(b) for b in integration_bounds_perp]): - return 0.0 - - grid_points_perp_1 = spacing_to_grid_points( - adaptive_spacing, integration_bounds_perp[0][0], integration_bounds_perp[0][1] - ) - grid_points_perp_2 = spacing_to_grid_points( - adaptive_spacing, integration_bounds_perp[1][0], 
integration_bounds_perp[1][1] - ) - integration_weight = compute_integration_weight( - grid_points_perp_1 - ) * compute_integration_weight(grid_points_perp_2) - - mesh_perp1, mesh_perp2 = np.meshgrid(grid_points_perp_1, grid_points_perp_2) - - zip_perp_coords = np.array(list(zip(mesh_perp1.flatten(), mesh_perp2.flatten()))) - - grid_points = np.pad(zip_perp_coords.copy(), ((0, 0), (1, 0)), mode="constant") - - # set up grid points to pass into evaluate_gradient_at_points - grid_points[:, axis_perp[0]] = zip_perp_coords[:, 0] - grid_points[:, axis_perp[1]] = zip_perp_coords[:, 1] - grid_points[:, axis_normal] = coord_normal_face - - normals = np.zeros_like(grid_points) - perps1 = np.zeros_like(grid_points) - perps2 = np.zeros_like(grid_points) - - normals[:, axis_normal] = -1 if (min_max_index == 0) else 1 - perps1[:, axis_perp[0]] = 1 - perps2[:, axis_perp[1]] = 1 - - gradient_at_points = derivative_info.evaluate_gradient_at_points( - spatial_coords=grid_points, - normals=normals, - perps1=perps1, - perps2=perps2, - interpolators=interpolators, - ) - - vjp_value = np.sum(integration_weight * np.real(gradient_at_points)) - return vjp_value - - -"""Compound subclasses""" - - -class Transformed(Geometry): - """Class representing a transformed geometry.""" - - geometry: annotate_type(GeometryType) = pydantic.Field( - ..., title="Geometry", description="Base geometry to be transformed." - ) - - transform: MatrixReal4x4 = pydantic.Field( - np.eye(4).tolist(), - title="Transform", - description="Transform matrix applied to the base geometry.", - ) - - @pydantic.validator("transform") - def _transform_is_invertible(cls, val: MatrixReal4x4) -> MatrixReal4x4: - # If the transform is not invertible, this will raise an error - _ = np.linalg.inv(val) - return val - - @pydantic.validator("geometry") - def _geometry_is_finite(cls, val: GeometryType) -> GeometryType: - if not np.isfinite(val.bounds).all(): - raise ValidationError( - "Transformations are only supported on geometries with finite dimensions. " - "Try using a large value instead of 'inf' when creating geometries that undergo " - "transformations." - ) - return val - - @pydantic.root_validator(skip_on_failure=True) - def _apply_transforms(cls, values: dict[str, Any]) -> dict[str, Any]: - while isinstance(values["geometry"], Transformed): - inner = values["geometry"] - values["geometry"] = inner.geometry - values["transform"] = np.dot(values["transform"], inner.transform) - return values - - @cached_property - def inverse(self) -> MatrixReal4x4: - """Inverse of this transform.""" - return np.linalg.inv(self.transform) - - @staticmethod - def _vertices_from_bounds(bounds: Bound) -> ArrayFloat2D: - """Return the 8 vertices derived from bounds. - - The vertices are returned as homogeneous coordinates (with 4 components). - - Parameters - ---------- - bounds : Bound - Bounds from which to derive the vertices. - - Returns - ------- - ArrayFloat2D - Array with shape (4, 8) with all vertices from ``bounds``. - """ - (x0, y0, z0), (x1, y1, z1) = bounds - return np.array( - ( - (x0, x0, x0, x0, x1, x1, x1, x1), - (y0, y0, y1, y1, y0, y0, y1, y1), - (z0, z1, z0, z1, z0, z1, z0, z1), - (1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0), - ) - ) - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. 
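
A sketch of the corner-based bound estimate used by ``Transformed.bounds``: transform the 8 corners of the base bounds in homogeneous coordinates and take per-axis extrema (for rotations this deliberately overestimates, as the implementation's NOTE just below points out). The transform here is a hypothetical translation:

```python
import numpy as np

(x0, y0, z0), (x1, y1, z1) = (0.0, 0.0, 0.0), (1.0, 2.0, 3.0)
corners = np.array(
    [
        (x0, x0, x0, x0, x1, x1, x1, x1),
        (y0, y0, y1, y1, y0, y0, y1, y1),
        (z0, z1, z0, z1, z0, z1, z0, z1),
        (1.0,) * 8,
    ]
)

transform = np.eye(4)
transform[:3, 3] = (10.0, 0.0, 0.0)  # translate +10 in x

moved = np.dot(transform, corners)[:3]
print(tuple(moved.min(axis=1)), tuple(moved.max(axis=1)))
# (10.0, 0.0, 0.0) (11.0, 2.0, 3.0)
```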
- """ - # NOTE (Lucas): The bounds are overestimated because we don't want to calculate - # precise TriangleMesh representations for GeometryGroup or ClipOperation. - vertices = np.dot(self.transform, self._vertices_from_bounds(self.geometry.bounds))[:3] - return (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - return self.geometry.intersections_tilted_plane( - tuple(np.dot((normal[0], normal[1], normal[2], 0.0), self.transform)[:3]), - tuple(np.dot(self.inverse, (origin[0], origin[1], origin[2], 1.0))[:3]), - np.dot(to_2D, self.transform), - cleanup=cleanup, - quad_segs=quad_segs, - ) - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - x = np.array(x) - y = np.array(y) - z = np.array(z) - xyz = np.dot(self.inverse, np.vstack((x.flat, y.flat, z.flat, np.ones(x.size)))) - if xyz.shape[1] == 1: - # TODO: This "fix" is required because of a bug in PolySlab.inside (with non-zero sidewall angle) - return self.geometry.inside(xyz[0][0], xyz[1][0], xyz[2][0]).reshape(x.shape) - return self.geometry.inside(xyz[0], xyz[1], xyz[2]).reshape(x.shape) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - # NOTE (Lucas): Bounds are overestimated. - vertices = np.dot(self.inverse, self._vertices_from_bounds(bounds))[:3] - inverse_bounds = (tuple(vertices.min(axis=1)), tuple(vertices.max(axis=1))) - return abs(np.linalg.det(self.transform)) * self.geometry.volume(inverse_bounds) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - log.warning("Surface area of transformed elements cannot be calculated.") - return None - - @staticmethod - def translation(x: float, y: float, z: float) -> MatrixReal4x4: - """Return a translation matrix. - - Parameters - ---------- - x : float - Translation along x. - y : float - Translation along y. - z : float - Translation along z. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). 
- """ - return np.array( - [ - (1.0, 0.0, 0.0, x), - (0.0, 1.0, 0.0, y), - (0.0, 0.0, 1.0, z), - (0.0, 0.0, 0.0, 1.0), - ], - dtype=float, - ) - - @staticmethod - def scaling(x: float = 1.0, y: float = 1.0, z: float = 1.0) -> MatrixReal4x4: - """Return a scaling matrix. - - Parameters - ---------- - x : float = 1.0 - Scaling factor along x. - y : float = 1.0 - Scaling factor along y. - z : float = 1.0 - Scaling factor along z. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). - """ - if np.isclose((x, y, z), 0.0).any(): - raise Tidy3dError("Scaling factors cannot be zero in any dimensions.") - return np.array( - [ - (x, 0.0, 0.0, 0.0), - (0.0, y, 0.0, 0.0), - (0.0, 0.0, z, 0.0), - (0.0, 0.0, 0.0, 1.0), - ], - dtype=float, - ) - - @staticmethod - def rotation(angle: float, axis: Union[Axis, Coordinate]) -> MatrixReal4x4: - """Return a rotation matrix. - - Parameters - ---------- - angle : float - Rotation angle (in radians). - axis : Union[int, Tuple[float, float, float]] - Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). - """ - transform = np.eye(4) - transform[:3, :3] = RotationAroundAxis(angle=angle, axis=axis).matrix - return transform - - @staticmethod - def reflection(normal: Coordinate) -> MatrixReal4x4: - """Return a reflection matrix. - - Parameters - ---------- - normal : Tuple[float, float, float] - Normal of the plane of reflection. - - Returns - ------- - numpy.ndarray - Transform matrix with shape (4, 4). - """ - - transform = np.eye(4) - transform[:3, :3] = ReflectionFromPlane(normal=normal).matrix - return transform - - @staticmethod - def preserves_axis(transform: MatrixReal4x4, axis: Axis) -> bool: - """Indicate if the transform preserves the orientation of a given axis. - - Parameters: - transform: MatrixReal4x4 - Transform matrix to check. - axis : int - Axis to check. Values 0, 1, or 2, to check x, y, or z, respectively. - - Returns - ------- - bool - ``True`` if the transformation preserves the axis orientation, ``False`` otherwise. - """ - i = (axis + 1) % 3 - j = (axis + 2) % 3 - return np.isclose(transform[i, axis], 0) and np.isclose(transform[j, axis], 0) - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - normal = self.geometry._normal_2dmaterial - preserves_axis = Transformed.preserves_axis(self.transform, normal) - - if not preserves_axis: - raise ValidationError( - "'Medium2D' requires geometries of type 'Transformed' to " - "perserve the axis normal to the 'Medium2D'." 
- ) - - return normal - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Transformed: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - min_bound = np.array([0, 0, 0, 1.0]) - min_bound[axis] = bounds[0] - max_bound = np.array([0, 0, 0, 1.0]) - max_bound[axis] = bounds[1] - new_bounds = [] - new_bounds.append(np.dot(self.inverse, min_bound)[axis]) - new_bounds.append(np.dot(self.inverse, max_bound)[axis]) - new_geometry = self.geometry._update_from_bounds(bounds=new_bounds, axis=axis) - return self.updated_copy(geometry=new_geometry) - - -class ClipOperation(Geometry): - """Class representing the result of a set operation between geometries.""" - - operation: ClipOperationType = pydantic.Field( - ..., - title="Operation Type", - description="Operation to be performed between geometries.", - ) - - geometry_a: annotate_type(GeometryType) = pydantic.Field( - ..., - title="Geometry A", - description="First operand for the set operation. It can be any geometry type, including " - ":class:`GeometryGroup`.", - ) - - geometry_b: annotate_type(GeometryType) = pydantic.Field( - ..., - title="Geometry B", - description="Second operand for the set operation. It can also be any geometry type.", - ) - - @pydantic.validator("geometry_a", "geometry_b", always=True) - def _geometries_untraced(cls, val: GeometryType) -> GeometryType: - """Make sure that ``ClipOperation`` geometries do not contain tracers.""" - traced = val._strip_traced_fields() - if traced: - raise ValidationError( - f"{val.type} contains traced fields {list(traced.keys())}. Note that " - "'ClipOperation' does not currently support automatic differentiation." - ) - return val - - @staticmethod - def to_polygon_list(base_geometry: Shapely, cleanup: bool = False) -> list[Shapely]: - """Return a list of valid polygons from a shapely geometry, discarding points, lines, and - empty polygons, and empty triangles within polygons. - - Parameters - ---------- - base_geometry : shapely.geometry.base.BaseGeometry - Base geometry for inspection. - cleanup: bool = False - If True, removes extremely small features from each polygon's boundary. - This is useful for removing artifacts from 2D plots displayed to the user. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - Valid polygons retrieved from ``base geometry``. - """ - unfiltered_geoms = [] - if base_geometry.geom_type == "GeometryCollection": - unfiltered_geoms = [ - p - for geom in base_geometry.geoms - for p in ClipOperation.to_polygon_list(geom, cleanup) - ] - if base_geometry.geom_type == "MultiPolygon": - unfiltered_geoms = [p for p in base_geometry.geoms if not p.is_empty] - if base_geometry.geom_type == "Polygon" and not base_geometry.is_empty: - unfiltered_geoms = [base_geometry] - geoms = [] - if cleanup: - # Optional: "clean" each of the polygons (by removing extremely small or thin features). 
- for geom in unfiltered_geoms:
- geom_clean = cleanup_shapely_object(geom)
- if geom_clean.geom_type == "Polygon":
- geoms.append(geom_clean)
- if geom_clean.geom_type == "MultiPolygon":
- geoms += [p for p in geom_clean.geoms if not p.is_empty]
- # Ignore other types of shapely objects (points and lines)
- else:
- geoms = unfiltered_geoms
- return geoms
-
- @property
- def _shapely_operation(self) -> Callable[[Shapely, Shapely], Shapely]:
- """Return a Shapely function equivalent to this operation."""
- result = _shapely_operations.get(self.operation, None)
- if not result:
- raise ValueError(
- "'operation' must be one of 'union', 'intersection', 'difference', or "
- "'symmetric_difference'."
- )
- return result
-
- @property
- def _bit_operation(self) -> Callable[[Any, Any], Any]:
- """Return a function equivalent to this operation using bit operators."""
- result = _bit_operations.get(self.operation, None)
- if not result:
- raise ValueError(
- "'operation' must be one of 'union', 'intersection', 'difference', or "
- "'symmetric_difference'."
- )
- return result
-
- def intersections_tilted_plane(
- self,
- normal: Coordinate,
- origin: Coordinate,
- to_2D: MatrixReal4x4,
- cleanup: bool = True,
- quad_segs: Optional[int] = None,
- ) -> list[Shapely]:
- """Return a list of shapely geometries at the plane specified by normal and origin.
-
- Parameters
- ----------
- normal : Coordinate
- Vector defining the normal direction to the plane.
- origin : Coordinate
- Vector defining the plane origin.
- to_2D : MatrixReal4x4
- Transformation matrix to apply to resulting shapes.
- cleanup : bool = True
- If True, removes extremely small features from each polygon's boundary.
- quad_segs : Optional[int] = None
- Number of segments used to discretize circular shapes. If ``None``, uses
- high-quality visualization settings.
-
- Returns
- -------
- List[shapely.geometry.base.BaseGeometry]
- List of 2D shapes that intersect plane.
- For more details refer to
- `Shapely's Documentation <https://shapely.readthedocs.io/en/stable/project.html>`_.
- """
- a = self.geometry_a.intersections_tilted_plane(
- normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs
- )
- b = self.geometry_b.intersections_tilted_plane(
- normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs
- )
- geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a])
- geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b])
- return ClipOperation.to_polygon_list(
- self._shapely_operation(geom_a, geom_b),
- cleanup=cleanup,
- )
-
- def intersections_plane(
- self,
- x: Optional[float] = None,
- y: Optional[float] = None,
- z: Optional[float] = None,
- cleanup: bool = True,
- quad_segs: Optional[int] = None,
- ) -> list[Shapely]:
- """Returns list of shapely geometries at plane specified by one non-None value of x,y,z.
-
- Parameters
- ----------
- x : float = None
- Position of plane in x direction, only one of x,y,z can be specified to define plane.
- y : float = None
- Position of plane in y direction, only one of x,y,z can be specified to define plane.
- z : float = None
- Position of plane in z direction, only one of x,y,z can be specified to define plane.
- cleanup : bool = True
- If True, removes extremely small features from each polygon's boundary.
- quad_segs : Optional[int] = None
- Number of segments used to discretize circular shapes. If ``None``, uses
- high-quality visualization settings.
-
- Returns
- -------
- List[shapely.geometry.base.BaseGeometry]
- List of 2D shapes that intersect plane.
- For more details refer to
- `Shapely's Documentation <https://shapely.readthedocs.io/en/stable/project.html>`_.
- """
- a = self.geometry_a.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs)
- b = self.geometry_b.intersections_plane(x, y, z, cleanup=cleanup, quad_segs=quad_segs)
- geom_a = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in a])
- geom_b = shapely.unary_union([Geometry.evaluate_inf_shape(g) for g in b])
- return ClipOperation.to_polygon_list(
- self._shapely_operation(geom_a, geom_b),
- cleanup=cleanup,
- )
-
- @cached_property
- def bounds(self) -> Bound:
- """Returns bounding box min and max coordinates.
-
- Returns
- -------
- Tuple[float, float, float], Tuple[float, float, float]
- Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
- """
- # Overestimates
- if self.operation == "difference":
- result = self.geometry_a.bounds
- elif self.operation == "intersection":
- bounds = (self.geometry_a.bounds, self.geometry_b.bounds)
- result = (
- tuple(max(b[i] for b, _ in bounds) for i in range(3)),
- tuple(min(b[i] for _, b in bounds) for i in range(3)),
- )
- if any(result[0][i] > result[1][i] for i in range(3)):
- result = ((0, 0, 0), (0, 0, 0))
- else:
- bounds = (self.geometry_a.bounds, self.geometry_b.bounds)
- result = (
- tuple(min(b[i] for b, _ in bounds) for i in range(3)),
- tuple(max(b[i] for _, b in bounds) for i in range(3)),
- )
- return result
-
- def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]:
- """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array
- with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the
- volume of the :class:`Geometry`, and ``False`` otherwise.
-
- Parameters
- ----------
- x : np.ndarray[float]
- Array of point positions in x direction.
- y : np.ndarray[float]
- Array of point positions in y direction.
- z : np.ndarray[float]
- Array of point positions in z direction.
-
- Returns
- -------
- np.ndarray[bool]
- ``True`` for every point that is inside the geometry.
- """
- inside_a = self.geometry_a.inside(x, y, z)
- inside_b = self.geometry_b.inside(x, y, z)
- return self._bit_operation(inside_a, inside_b)
-
- def inside_meshgrid(
- self, x: NDArray[float], y: NDArray[float], z: NDArray[float]
- ) -> NDArray[bool]:
- """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted.
-
- Parameters
- ----------
- x : np.ndarray[float]
- 1D array of point positions in x direction.
- y : np.ndarray[float]
- 1D array of point positions in y direction.
- z : np.ndarray[float]
- 1D array of point positions in z direction.
-
- Returns
- -------
- np.ndarray[bool]
- Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every
- point that is inside the geometry.
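
A minimal shapely sketch of the clip pipeline above: union each operand's in-plane shapes, apply the set operation, then keep only polygonal parts (toy boxes, outside the class):

```python
import shapely

a = shapely.box(0.0, 0.0, 2.0, 1.0)
b = shapely.box(1.0, 0.0, 3.0, 1.0)

result = a.difference(b)  # the "difference" ClipOperationType
# Polygons have no `.geoms`; multi-part results do, hence the fallback.
polys = [g for g in getattr(result, "geoms", [result]) if g.geom_type == "Polygon"]
print([p.bounds for p in polys])  # [(0.0, 0.0, 1.0, 1.0)]
```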
- """ - inside_a = self.geometry_a.inside_meshgrid(x, y, z) - inside_b = self.geometry_b.inside_meshgrid(x, y, z) - return self._bit_operation(inside_a, inside_b) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - # Overestimates - if self.operation == "intersection": - return min(self.geometry_a.volume(bounds), self.geometry_b.volume(bounds)) - if self.operation == "difference": - return self.geometry_a.volume(bounds) - return self.geometry_a.volume(bounds) + self.geometry_b.volume(bounds) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - # Overestimates - return self.geometry_a.surface_area(bounds) + self.geometry_b.surface_area(bounds) - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - normal_a = self.geometry_a._normal_2dmaterial - normal_b = self.geometry_b._normal_2dmaterial - - if normal_a != normal_b: - raise ValidationError( - "'Medium2D' requires both geometries in the 'ClipOperation' to " - "have exactly one dimension with zero size in common." - ) - - plane_position_a = self.geometry_a.bounds[0][normal_a] - plane_position_b = self.geometry_b.bounds[0][normal_b] - - if plane_position_a != plane_position_b: - raise ValidationError( - "'Medium2D' requires both geometries in the 'ClipOperation' to be co-planar." - ) - return normal_a - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> ClipOperation: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - new_geom_a = self.geometry_a._update_from_bounds(bounds=bounds, axis=axis) - new_geom_b = self.geometry_b._update_from_bounds(bounds=bounds, axis=axis) - return self.updated_copy(geometry_a=new_geom_a, geometry_b=new_geom_b) - - -class GeometryGroup(Geometry): - """A collection of Geometry objects that can be called as a single geometry object.""" - - geometries: tuple[annotate_type(GeometryType), ...] = pydantic.Field( - ..., - title="Geometries", - description="Tuple of geometries in a single grouping. " - "Can provide significant performance enhancement in ``Structure`` when all geometries are " - "assigned the same medium.", - ) - - @pydantic.validator("geometries", always=True) - def _geometries_not_empty( - cls, val: tuple[annotate_type(GeometryType), ...] - ) -> tuple[annotate_type(GeometryType), ...]: - """make sure geometries are not empty.""" - if not len(val) > 0: - raise ValidationError("GeometryGroup.geometries must not be empty.") - return val - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - - bounds = tuple(geometry.bounds for geometry in self.geometries) - return ( - tuple(min(b[i] for b, _ in bounds) for i in range(3)), - tuple(max(b[i] for _, b in bounds) for i in range(3)), - ) - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. 
- origin : Coordinate
- Vector defining the plane origin.
- to_2D : MatrixReal4x4
- Transformation matrix to apply to resulting shapes.
- cleanup : bool = True
- If True, removes extremely small features from each polygon's boundary.
- quad_segs : Optional[int] = None
- Number of segments used to discretize circular shapes. If ``None``, uses
- high-quality visualization settings.
-
- Returns
- -------
- List[shapely.geometry.base.BaseGeometry]
- List of 2D shapes that intersect plane.
- For more details refer to
- `Shapely's Documentation <https://shapely.readthedocs.io/en/stable/project.html>`_.
- """
- return [
- intersection
- for geometry in self.geometries
- for intersection in geometry.intersections_tilted_plane(
- normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs
- )
- ]
-
- def intersections_plane(
- self,
- x: Optional[float] = None,
- y: Optional[float] = None,
- z: Optional[float] = None,
- cleanup: bool = True,
- quad_segs: Optional[int] = None,
- ) -> list[Shapely]:
- """Returns list of shapely geometries at plane specified by one non-None value of x,y,z.
-
- Parameters
- ----------
- x : float = None
- Position of plane in x direction, only one of x,y,z can be specified to define plane.
- y : float = None
- Position of plane in y direction, only one of x,y,z can be specified to define plane.
- z : float = None
- Position of plane in z direction, only one of x,y,z can be specified to define plane.
- cleanup : bool = True
- If True, removes extremely small features from each polygon's boundary.
- quad_segs : Optional[int] = None
- Number of segments used to discretize circular shapes. If ``None``, uses
- high-quality visualization settings.
-
- Returns
- -------
- List[shapely.geometry.base.BaseGeometry]
- List of 2D shapes that intersect plane.
- For more details refer to
- `Shapely's Documentation <https://shapely.readthedocs.io/en/stable/project.html>`_.
- """
- if not self.intersects_plane(x, y, z):
- return []
- return [
- intersection
- for geometry in self.geometries
- for intersection in geometry.intersections_plane(
- x=x, y=y, z=z, cleanup=cleanup, quad_segs=quad_segs
- )
- ]
-
- def intersects_axis_position(self, axis: int, position: float) -> bool:
- """Whether self intersects plane specified by a given position along a normal axis.
-
- Parameters
- ----------
- axis : int
- Axis normal to the plane.
- position : float
- Position of plane along the normal axis.
-
- Returns
- -------
- bool
- Whether this geometry intersects the plane.
- """
- return any(geom.intersects_axis_position(axis, position) for geom in self.geometries)
-
- def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]:
- """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array
- with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the
- volume of the :class:`Geometry`, and ``False`` otherwise.
-
- Parameters
- ----------
- x : np.ndarray[float]
- Array of point positions in x direction.
- y : np.ndarray[float]
- Array of point positions in y direction.
- z : np.ndarray[float]
- Array of point positions in z direction.
-
- Returns
- -------
- np.ndarray[bool]
- ``True`` for every point that is inside the geometry.
- """
- individual_insides = (geometry.inside(x, y, z) for geometry in self.geometries)
- return functools.reduce(lambda a, b: a | b, individual_insides)
-
- def inside_meshgrid(
- self, x: NDArray[float], y: NDArray[float], z: NDArray[float]
- ) -> NDArray[bool]:
- """Faster way to check ``self.inside`` on a meshgrid. The input arrays are assumed sorted.
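
The membership reduction used by ``GeometryGroup.inside`` above is an elementwise OR over the member masks; a point belongs to the group if any member contains it:

```python
import functools
import numpy as np

masks = (
    np.array([True, False, False]),
    np.array([False, True, False]),
)
print(functools.reduce(lambda a, b: a | b, masks))  # [ True  True False]
```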
- - Parameters - ---------- - x : np.ndarray[float] - 1D array of point positions in x direction. - y : np.ndarray[float] - 1D array of point positions in y direction. - z : np.ndarray[float] - 1D array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - Array with shape ``(x.size, y.size, z.size)``, which is ``True`` for every - point that is inside the geometry. - """ - individual_insides = (geom.inside_meshgrid(x, y, z) for geom in self.geometries) - return functools.reduce(lambda a, b: a | b, individual_insides) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - individual_volumes = (geometry.volume(bounds) for geometry in self.geometries) - return np.sum(individual_volumes) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - individual_areas = (geometry.surface_area(bounds) for geometry in self.geometries) - return np.sum(individual_areas) - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - - normals = {geom._normal_2dmaterial for geom in self.geometries} - - if len(normals) != 1: - raise ValidationError( - "'Medium2D' requires all geometries in the 'GeometryGroup' to " - "share exactly one dimension with zero size." - ) - normal = list(normals)[0] - positions = {geom.bounds[0][normal] for geom in self.geometries} - if len(positions) != 1: - raise ValidationError( - "'Medium2D' requires all geometries in the 'GeometryGroup' to be co-planar." - ) - return normal - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> GeometryGroup: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - new_geometries = [ - geometry._update_from_bounds(bounds=bounds, axis=axis) for geometry in self.geometries - ] - return self.updated_copy(geometries=new_geometries) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - - grad_vjps = {} - - # create interpolators once for all geometries to avoid redundant field data conversions - interpolators = derivative_info.interpolators or derivative_info.create_interpolators() - - for field_path in derivative_info.paths: - _, index, *geo_path = field_path - geo = self.geometries[index] - # pass pre-computed interpolators if available - geo_info = derivative_info.updated_copy( - paths=[tuple(geo_path)], - bounds=geo.bounds, - bounds_intersect=self.bounds_intersection( - geo.bounds, derivative_info.simulation_bounds - ), - eps_approx=True, - deep=False, - interpolators=interpolators, - ) - - vjp_dict_geo = geo._compute_derivatives(geo_info) - - if len(vjp_dict_geo) != 1: - raise AssertionError("Got multiple gradients for single geometry field.") - - grad_vjps[field_path] = vjp_dict_geo.popitem()[1] - - return grad_vjps - - -def cleanup_shapely_object(obj: Shapely, tolerance_ratio: float = POLY_TOLERANCE_RATIO) -> Shapely: - """Remove small geometric features from the boundaries of a shapely object including - inward and outward spikes, thin holes, and thin connections between larger regions. 
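
A minimal shapely sketch of the erode/dilate idea behind ``cleanup_shapely_object``: a negative buffer removes features narrower than twice the tolerance, and a positive buffer restores the surviving area (toy polygon and tolerance; the real function also rescales, re-validates, and simplifies):

```python
import shapely

# Unit square with a thin spike of width 0.01 sticking out of the top edge.
spiky = shapely.Polygon(
    [(0, 0), (1, 0), (1, 1), (0.5, 1), (0.5, 3), (0.49, 1), (0, 1)]
)
tol = 0.05
eroded = spiky.buffer(-tol, cap_style="square", quad_segs=3)   # spike vanishes
cleaned = eroded.buffer(tol, cap_style="square", quad_segs=3)  # bulk restored
print(round(spiky.area, 3), round(cleaned.area, 3))  # ~1.01 -> ~1.0
```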
- - Parameters - ---------- - obj : shapely - a shapely object (typically a ``Polygon`` or a ``MultiPolygon``) - tolerance_ratio : float = ``POLY_TOLERANCE_RATIO`` - Features on the boundaries of polygons will be discarded if they are smaller - or narrower than ``tolerance_ratio`` multiplied by the size of the object. - - Returns - ------- - Shapely - A new shapely object whose small features (eg. thin spikes or holes) are removed. - - Notes - ----- - This function does not attempt to delete overlapping, nearby, or collinear vertices. - To solve that problem, use ``shapely.simplify()`` afterwards. - """ - if _shapely_is_older_than("2.1"): - log.warning( - "Using old versions of the shapely library (prior to v2.1) may cause " - "plot errors. This can be solved by upgrading to Python 3.10 " - "(or later) and reinstalling Tidy3d.", - log_once=True, - ) - return obj - if obj.is_empty: - return obj - centroid = obj.centroid - object_size = min(obj.bounds[2] - obj.bounds[0], obj.bounds[3] - obj.bounds[1]) - if object_size == 0.0: - return shapely.Polygon([]) - # In order to prevent numerical overflow or underflow errors, we first subtract - # the centroid and divide by (rescale) the size of the object so it is not too big. - normalized_obj = shapely.affinity.affine_transform( - # https://shapely.readthedocs.io/en/stable/manual.html#affine-transformations - obj, - matrix=[ - 1 / object_size, - 0.0, - 0.0, - 1 / object_size, - -centroid.x / object_size, - -centroid.y / object_size, - ], - ) - # Important: Remove any self intersections beforehand using `shapely.make_valid()`. - valid_obj = shapely.make_valid(normalized_obj, method="structure", keep_collapsed=False) - # To get rid of small thin features, erode(shrink), dilate(expand), and erode again. - eroded_obj = shapely.buffer( # This removes outward spikes - valid_obj, - distance=-tolerance_ratio, - cap_style="square", # (optional parameter to reduce computation time) - quad_segs=3, # (optional parameter to reduce computation time) - ) - dilated_obj = shapely.buffer( # This removes inward spikes and tiny holes - eroded_obj, - distance=2 * tolerance_ratio, - cap_style="square", - quad_segs=3, - ) - cleaned_obj = dilated_obj - # Optional: Now shrink the polygon back to the original size. - cleaned_obj = shapely.buffer( - cleaned_obj, - distance=-tolerance_ratio, - cap_style="square", - quad_segs=3, - ) - # Clean vertices of very close distances created during the erosion/dilation process. - # The distance value is heuristic. - cleaned_obj = cleaned_obj.simplify(POLY_DISTANCE_TOLERANCE, preserve_topology=True) - # Revert to the original scale and position. 
- rescaled_clean_obj = shapely.affinity.affine_transform( - cleaned_obj, - matrix=[ - object_size, - 0.0, - 0.0, - object_size, - centroid.x, - centroid.y, - ], - ) - return rescaled_clean_obj - - -from .utils import GeometryType, from_shapely, vertices_from_shapely # noqa: E402 diff --git a/tidy3d/components/geometry/bound_ops.py b/tidy3d/components/geometry/bound_ops.py index 2cd3428c2c..58aa09c5f8 100644 --- a/tidy3d/components/geometry/bound_ops.py +++ b/tidy3d/components/geometry/bound_ops.py @@ -1,68 +1,12 @@ -"""Geometry operations for bounding box type with minimal imports.""" - -from __future__ import annotations - -from math import isclose - -from tidy3d.components.types import Bound -from tidy3d.constants import fp_eps - - -def bounds_intersection(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the intersection of two bounds.""" - rmin1, rmax1 = bounds1 - rmin2, rmax2 = bounds2 - rmin = tuple(max(v1, v2) for v1, v2 in zip(rmin1, rmin2)) - rmax = tuple(min(v1, v2) for v1, v2 in zip(rmax1, rmax2)) - return (rmin, rmax) +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.bound_ops`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def bounds_union(bounds1: Bound, bounds2: Bound) -> Bound: - """Return the bounds that are the union of two bounds.""" - rmin1, rmax1 = bounds1 - rmin2, rmax2 = bounds2 - rmin = tuple(min(v1, v2) for v1, v2 in zip(rmin1, rmin2)) - rmax = tuple(max(v1, v2) for v1, v2 in zip(rmax1, rmax2)) - return (rmin, rmax) - - -def bounds_contains( - outer_bounds: Bound, inner_bounds: Bound, rtol: float = fp_eps, atol: float = 0.0 -) -> bool: - """Checks whether ``inner_bounds`` is contained within ``outer_bounds`` within specified tolerances. - - Parameters - ---------- - outer_bounds : Bound - The outer bounds to check containment against - inner_bounds : Bound - The inner bounds to check if contained - rtol : float = fp_eps - Relative tolerance for comparing bounds - atol : float = 0.0 - Absolute tolerance for comparing bounds - - Returns - ------- - bool - True if ``inner_bounds`` is contained within ``outer_bounds`` within tolerances - """ - outer_min, outer_max = outer_bounds - inner_min, inner_max = inner_bounds - for dim in range(3): - outer_min_dim = outer_min[dim] - outer_max_dim = outer_max[dim] - inner_min_dim = inner_min[dim] - inner_max_dim = inner_max[dim] - within_min = ( - isclose(outer_min_dim, inner_min_dim, rel_tol=rtol, abs_tol=atol) - or outer_min_dim <= inner_min_dim - ) - within_max = ( - isclose(outer_max_dim, inner_max_dim, rel_tol=rtol, abs_tol=atol) - or outer_max_dim >= inner_max_dim - ) +# marked as migrated to _common +from __future__ import annotations - if not within_min or not within_max: - return False - return True +from tidy3d._common.components.geometry.bound_ops import ( + bounds_contains, + bounds_intersection, + bounds_union, +) diff --git a/tidy3d/components/geometry/float_utils.py b/tidy3d/components/geometry/float_utils.py index 5ab7b438be..a45674303e 100644 --- a/tidy3d/components/geometry/float_utils.py +++ b/tidy3d/components/geometry/float_utils.py @@ -1,31 +1,10 @@ -"""Utilities for float manipulation.""" - -from __future__ import annotations - -import numpy as np - -from tidy3d.constants import inf +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.float_utils`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def increment_float(val: float, sign: int) -> float: - """Applies a small positive or negative 
shift as though `val` is a 32bit float - using numpy.nextafter, but additionally handles some corner cases. - """ - # Infinity is left unchanged - if val == inf or val == -inf: - return val - - if sign >= 0: - sign = 1 - else: - sign = -1 - - # Avoid small increments within subnormal values - if np.abs(val) <= np.finfo(np.float32).tiny: - return val + sign * np.finfo(np.float32).tiny - - # Numpy seems to skip over the increment from -0.0 and +0.0 - # which is different from c++ - val_inc = np.nextafter(val, sign * inf, dtype=np.float32) +# marked as migrated to _common +from __future__ import annotations - return np.float32(val_inc) +from tidy3d._common.components.geometry.float_utils import ( + increment_float, +) diff --git a/tidy3d/components/geometry/mesh.py b/tidy3d/components/geometry/mesh.py index 006f49dc0a..c1a305e221 100644 --- a/tidy3d/components/geometry/mesh.py +++ b/tidy3d/components/geometry/mesh.py @@ -1,1280 +1,11 @@ -"""Mesh-defined geometry.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.mesh`.""" -from __future__ import annotations - -from abc import ABC -from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union - -import numpy as np -import pydantic.v1 as pydantic -from autograd import numpy as anp -from numpy.typing import NDArray - -from tidy3d.components.autograd import AutogradFieldMap, get_static -from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import DATA_ARRAY_MAP, TriangleMeshDataArray -from tidy3d.components.data.dataset import TriangleMeshDataset -from tidy3d.components.data.validators import validate_no_nans -from tidy3d.components.types import Ax, Bound, Coordinate, MatrixReal4x4, Shapely -from tidy3d.components.viz import add_ax_if_none, equal_aspect -from tidy3d.config import config -from tidy3d.constants import fp_eps, inf -from tidy3d.exceptions import DataError, ValidationError -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -from . import base - -if TYPE_CHECKING: - from trimesh import Trimesh - -AREA_SIZE_THRESHOLD = 1e-36 - - -class TriangleMesh(base.Geometry, ABC): - """Custom surface geometry given by a triangle mesh, as in the STL file format. 
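The relocated `increment_float` helper above is essentially `numpy.nextafter` evaluated at 32-bit precision, plus guards for infinities and subnormals. A small sketch of the underlying one-ULP step:

```python
import numpy as np

val = np.float32(1.0)

# One representable 32-bit step in each direction, as increment_float takes.
up = np.nextafter(val, np.float32(np.inf), dtype=np.float32)
down = np.nextafter(val, np.float32(-np.inf), dtype=np.float32)

print(up - val)    # ~1.19e-07: one ULP above 1.0
print(val - down)  # ~5.96e-08: one ULP below 1.0 (the exponent drops)
```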
- - Example - ------- - >>> vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]) - >>> faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]]) - >>> stl_geom = TriangleMesh.from_vertices_faces(vertices, faces) - """ - - mesh_dataset: Optional[TriangleMeshDataset] = pydantic.Field( - ..., - title="Surface mesh data", - description="Surface mesh data.", - ) - - _no_nans_mesh = validate_no_nans("mesh_dataset") - _barycentric_samples: dict[int, NDArray] = pydantic.PrivateAttr(default_factory=dict) - - @pydantic.root_validator(pre=True) - @verify_packages_import(["trimesh"]) - def _validate_trimesh_library(cls, values: dict[str, Any]) -> dict[str, Any]: - """Check if the trimesh package is imported as a validator.""" - return values - - @pydantic.validator("mesh_dataset", pre=True, always=True) - def _warn_if_none(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: - """Warn if the Dataset fails to load.""" - if isinstance(val, dict): - if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): - log.warning("Loading 'mesh_dataset' without data.") - return None - return val - - @pydantic.validator("mesh_dataset", always=True) - @verify_packages_import(["trimesh"]) - def _check_mesh(cls, val: TriangleMeshDataset) -> TriangleMeshDataset: - """Check that the mesh is valid.""" - if val is None: - return None - - import trimesh - - surface_mesh = val.surface_mesh - triangles = get_static(surface_mesh.data) - mesh = cls._triangles_to_trimesh(triangles) - if not all(np.array(mesh.area_faces) > AREA_SIZE_THRESHOLD): - old_tol = trimesh.tol.merge - trimesh.tol.merge = np.sqrt(2 * AREA_SIZE_THRESHOLD) - new_mesh = mesh.process(validate=True) - trimesh.tol.merge = old_tol - val = TriangleMesh.from_trimesh(new_mesh).mesh_dataset - log.warning( - f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. " - "Triangles which have one edge of their 2D oriented bounding box shorter than " - f"'sqrt(2*{AREA_SIZE_THRESHOLD}) are being automatically removed.'" - ) - if not all(np.array(new_mesh.area_faces) > AREA_SIZE_THRESHOLD): - raise ValidationError( - f"The provided mesh has triangles with near zero area < {AREA_SIZE_THRESHOLD}. " - "The automatic removal of these triangles has failed. You can try " - "using numpy-stl's 'from_file' import with 'remove_empty_areas' set " - "to True and a suitable 'AREA_SIZE_THRESHOLD' to remove them." - ) - if not mesh.is_watertight: - log.warning( - "The provided mesh is not watertight. " - "This can lead to incorrect permittivity distributions, " - "and can also cause problems with plotting and mesh validation. " - "You can try 'TriangleMesh.fill_holes', which attempts to repair the mesh. " - "Otherwise, the mesh may require manual repair. You can use a " - "'PermittivityMonitor' to check if the permittivity distribution is correct. " - "You can see which faces are broken using 'trimesh.repair.broken_faces'." - ) - if not mesh.is_winding_consistent: - log.warning( - "The provided mesh does not have consistent winding (face orientations). " - "This can lead to incorrect permittivity distributions, " - "and can also cause problems with plotting and mesh validation. " - "You can try 'TriangleMesh.fix_winding', which attempts to repair the mesh. " - "Otherwise, the mesh may require manual repair. You can use a " - "'PermittivityMonitor' to check if the permittivity distribution is correct. 
" - ) - if not mesh.is_volume: - log.warning( - "The provided mesh does not represent a valid volume, possibly due to " - "incorrect normal vector orientation. " - "This can lead to incorrect permittivity distributions, " - "and can also cause problems with plotting and mesh validation. " - "You can try 'TriangleMesh.fix_normals', " - "which attempts to fix the normals to be consistent and outward-facing. " - "Otherwise, the mesh may require manual repair. You can use a " - "'PermittivityMonitor' to check if the permittivity distribution is correct." - ) - - return val - - @verify_packages_import(["trimesh"]) - def fix_winding(self) -> TriangleMesh: - """Try to fix winding in the mesh.""" - import trimesh - - mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) - trimesh.repair.fix_winding(mesh) - return TriangleMesh.from_trimesh(mesh) - - @verify_packages_import(["trimesh"]) - def fill_holes(self) -> TriangleMesh: - """Try to fill holes in the mesh. Can be used to repair non-watertight meshes.""" - import trimesh - - mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) - trimesh.repair.fill_holes(mesh) - return TriangleMesh.from_trimesh(mesh) - - @verify_packages_import(["trimesh"]) - def fix_normals(self) -> TriangleMesh: - """Try to fix normals to be consistent and outward-facing.""" - import trimesh - - mesh = TriangleMesh._triangles_to_trimesh(self.mesh_dataset.surface_mesh) - trimesh.repair.fix_normals(mesh) - return TriangleMesh.from_trimesh(mesh) - - @classmethod - @verify_packages_import(["trimesh"]) - def from_stl( - cls, - filename: str, - scale: float = 1.0, - origin: tuple[float, float, float] = (0, 0, 0), - solid_index: Optional[int] = None, - **kwargs: Any, - ) -> Union[TriangleMesh, base.GeometryGroup]: - """Load a :class:`.TriangleMesh` directly from an STL file. - The ``solid_index`` parameter can be used to select a single solid from the file. - Otherwise, if the file contains a single solid, it will be loaded as a - :class:`.TriangleMesh`; if the file contains multiple solids, - they will all be loaded as a :class:`.GeometryGroup`. - - Parameters - ---------- - filename : str - The name of the STL file containing the surface geometry mesh data. - scale : float = 1.0 - The length scale for the loaded geometry (um). - For example, a scale of 10.0 means that a vertex (1, 0, 0) will be placed at - x = 10 um. - origin : Tuple[float, float, float] = (0, 0, 0) - The origin of the loaded geometry, in units of ``scale``. - Translates from (0, 0, 0) to this point after applying the scaling. - solid_index : int = None - If set, read a single solid with this index from the file. - - Returns - ------- - Union[:class:`.TriangleMesh`, :class:`.GeometryGroup`] - The geometry or geometry group from the file. - """ - import trimesh - - from tidy3d.components.types.third_party import TrimeshType - - def process_single(mesh: TrimeshType) -> TriangleMesh: - """Process a single 'trimesh.Trimesh' using scale and origin.""" - mesh.apply_scale(scale) - mesh.apply_translation(origin) - return cls.from_trimesh(mesh) - - scene = trimesh.load(filename, **kwargs) - meshes = [] - if isinstance(scene, trimesh.Trimesh): - meshes = [scene] - elif isinstance(scene, trimesh.Scene): - meshes = scene.dump() - else: - raise ValidationError( - "Invalid trimesh type in file. Supported types are 'trimesh.Trimesh' " - "and 'trimesh.Scene'." 
- ) - - if solid_index is None: - if isinstance(scene, trimesh.Trimesh): - return process_single(scene) - if isinstance(scene, trimesh.Scene): - geoms = [process_single(mesh) for mesh in meshes] - return base.GeometryGroup(geometries=geoms) - - if solid_index < len(meshes): - return process_single(meshes[solid_index]) - raise ValidationError("No solid found at 'solid_index' in the stl file.") - - @verify_packages_import(["trimesh"]) - def to_stl( - self, - filename: PathLike, - *, - binary: bool = True, - ) -> None: - """Export this TriangleMesh to an STL file. - - Parameters - ---------- - filename : str - Output STL filename. - binary : bool = True - Whether to write binary STL. Set False for ASCII STL. - """ - triangles = get_static(self.mesh_dataset.surface_mesh.data) - mesh = self._triangles_to_trimesh(triangles) - - file_type = "stl" if binary else "stl_ascii" - mesh.export(file_obj=filename, file_type=file_type) - - @classmethod - @verify_packages_import(["trimesh"]) - def from_trimesh(cls, mesh: trimesh.Trimesh) -> TriangleMesh: - """Create a :class:`.TriangleMesh` from a ``trimesh.Trimesh`` object. - - Parameters - ---------- - trimesh : ``trimesh.Trimesh`` - The Trimesh object containing the surface geometry mesh data. - - Returns - ------- - :class:`.TriangleMesh` - The custom surface mesh geometry given by the ``trimesh.Trimesh`` provided. - """ - return cls.from_vertices_faces(mesh.vertices, mesh.faces) - - @classmethod - def from_triangles(cls, triangles: NDArray) -> TriangleMesh: - """Create a :class:`.TriangleMesh` from a numpy array - containing the triangles of a surface mesh. - - Parameters - ---------- - triangles : ``np.ndarray`` - A numpy array of shape (N, 3, 3) storing the triangles of the surface mesh. - The first index labels the triangle, the second index labels the vertex - within a given triangle, and the third index is the coordinate (x, y, or z). - - Returns - ------- - :class:`.TriangleMesh` - The custom surface mesh geometry given by the triangles provided. - - """ - triangles = anp.array(triangles) - if len(triangles.shape) != 3 or triangles.shape[1] != 3 or triangles.shape[2] != 3: - raise ValidationError( - f"Provided 'triangles' must be an N x 3 x 3 array, given {triangles.shape}." - ) - num_faces = len(triangles) - coords = { - "face_index": np.arange(num_faces), - "vertex_index": np.arange(3), - "axis": np.arange(3), - } - vertices = TriangleMeshDataArray(triangles, coords=coords) - mesh_dataset = TriangleMeshDataset(surface_mesh=vertices) - return TriangleMesh(mesh_dataset=mesh_dataset) - - @classmethod - @verify_packages_import(["trimesh"]) - def from_vertices_faces(cls, vertices: NDArray, faces: NDArray) -> TriangleMesh: - """Create a :class:`.TriangleMesh` from numpy arrays containing the data - of a surface mesh. The first array contains the vertices, and the second array contains - faces formed from triples of the vertices. - - Parameters - ---------- - vertices: ``np.ndarray`` - A numpy array of shape (N, 3) storing the vertices of the surface mesh. - The first index labels the vertex, and the second index is the coordinate - (x, y, or z). - faces : ``np.ndarray`` - A numpy array of shape (M, 3) storing the indices of the vertices of each face - in the surface mesh. The first index labels the face, and the second index - labels the vertex index within the ``vertices`` array. - - Returns - ------- - :class:`.TriangleMesh` - The custom surface mesh geometry given by the vertices and faces provided. 
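As the `from_triangles` and `from_vertices_faces` docstrings above spell out, `(N, 3)` vertices plus `(M, 3)` faces are equivalent to an `(M, 3, 3)` triangle array. The gather step is plain fancy indexing (illustration only; the class itself routes through `trimesh.Trimesh(...).triangles`):

```python
import numpy as np

vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)
faces = np.array([[1, 2, 3], [0, 3, 2], [0, 1, 3], [0, 2, 1]])

# Each face row pulls its three vertex rows: one (3, 3) triangle per face.
triangles = vertices[faces]
print(triangles.shape)  # (4, 3, 3)
```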
- - """ - import trimesh - - vertices = np.array(vertices) - faces = np.array(faces) - if len(vertices.shape) != 2 or vertices.shape[1] != 3: - raise ValidationError( - f"Provided 'vertices' must be an N x 3 array, given {vertices.shape}." - ) - if len(faces.shape) != 2 or faces.shape[1] != 3: - raise ValidationError(f"Provided 'faces' must be an M x 3 array, given {faces.shape}.") - return cls.from_triangles(trimesh.Trimesh(vertices, faces).triangles) - - @classmethod - @verify_packages_import(["trimesh"]) - def _triangles_to_trimesh( - cls, triangles: NDArray - ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) - """Convert an (N, 3, 3) numpy array of triangles to a ``trimesh.Trimesh``.""" - import trimesh - - # ``triangles`` may contain autograd ``ArrayBox`` entries when differentiating - # geometry parameters. ``trimesh`` expects plain ``float`` values, so strip any - # tracing information before constructing the mesh. - triangles = get_static(anp.array(triangles)) - return trimesh.Trimesh(**trimesh.triangles.to_kwargs(triangles)) - - @classmethod - def from_height_grid( - cls, - axis: Ax, - direction: Literal["-", "+"], - base: float, - grid: tuple[np.ndarray, np.ndarray], - height: NDArray, - ) -> TriangleMesh: - """Construct a TriangleMesh object from grid based height information. - - Parameters - ---------- - axis : Ax - Axis of extrusion. - direction : Literal["-", "+"] - Direction of extrusion. - base : float - Coordinate of the base surface along the geometry's axis. - grid : Tuple[np.ndarray, np.ndarray] - Tuple of two one-dimensional arrays representing the sampling grid (XY, YZ, or ZX - corresponding to values of axis) - height : np.ndarray - Height values sampled on the given grid. Can be 1D (raveled) or 2D (matching grid mesh). - - Returns - ------- - TriangleMesh - The resulting TriangleMesh geometry object. - """ - - x_coords = grid[0] - y_coords = grid[1] - - nx = len(x_coords) - ny = len(y_coords) - nt = nx * ny - - x_mesh, y_mesh = np.meshgrid(x_coords, y_coords, indexing="ij") - - sign = 1 - if direction == "-": - sign = -1 - - flat_height = np.ravel(height) - if flat_height.shape[0] != nt: - raise ValueError( - f"Shape of flattened height array {flat_height.shape} does not match " - f"the number of grid points {nt}." 
- ) - - if np.any(flat_height < 0): - raise ValueError("All height values must be non-negative.") - - max_h = np.max(flat_height) - min_h_clip = fp_eps * max_h - flat_height = np.clip(flat_height, min_h_clip, inf) - - vertices_raw_list = [ - [np.ravel(x_mesh), np.ravel(y_mesh), base + sign * flat_height], # Alpha surface - [np.ravel(x_mesh), np.ravel(y_mesh), base * np.ones(nt)], - ] - - if direction == "-": - vertices_raw_list = vertices_raw_list[::-1] - - vertices = np.hstack(vertices_raw_list).T - vertices = np.roll(vertices, shift=axis - 2, axis=1) - - q0 = (np.arange(nx - 1)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() - q1 = (np.arange(1, nx)[:, None] * ny + np.arange(ny - 1)[None, :]).ravel() - q2 = (np.arange(1, nx)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() - q3 = (np.arange(nx - 1)[:, None] * ny + np.arange(1, ny)[None, :]).ravel() - - q0_b = nt + q0 - q1_b = nt + q1 - q2_b = nt + q2 - q3_b = nt + q3 - - top_quads = np.stack((q0, q1, q2, q3), axis=-1) - bottom_quads = np.stack((q0_b, q3_b, q2_b, q1_b), axis=-1) - - s1_q0 = (0 * ny + np.arange(ny - 1)).ravel() - s1_q1 = (0 * ny + np.arange(1, ny)).ravel() - s1_q2 = (nt + 0 * ny + np.arange(1, ny)).ravel() - s1_q3 = (nt + 0 * ny + np.arange(ny - 1)).ravel() - side1_quads = np.stack((s1_q0, s1_q1, s1_q2, s1_q3), axis=-1) - - s2_q0 = ((nx - 1) * ny + np.arange(ny - 1)).ravel() - s2_q1 = (nt + (nx - 1) * ny + np.arange(ny - 1)).ravel() - s2_q2 = (nt + (nx - 1) * ny + np.arange(1, ny)).ravel() - s2_q3 = ((nx - 1) * ny + np.arange(1, ny)).ravel() - side2_quads = np.stack((s2_q0, s2_q1, s2_q2, s2_q3), axis=-1) - - s3_q0 = (np.arange(nx - 1) * ny + 0).ravel() - s3_q1 = (nt + np.arange(nx - 1) * ny + 0).ravel() - s3_q2 = (nt + np.arange(1, nx) * ny + 0).ravel() - s3_q3 = (np.arange(1, nx) * ny + 0).ravel() - side3_quads = np.stack((s3_q0, s3_q1, s3_q2, s3_q3), axis=-1) - - s4_q0 = (np.arange(nx - 1) * ny + ny - 1).ravel() - s4_q1 = (np.arange(1, nx) * ny + ny - 1).ravel() - s4_q2 = (nt + np.arange(1, nx) * ny + ny - 1).ravel() - s4_q3 = (nt + np.arange(nx - 1) * ny + ny - 1).ravel() - side4_quads = np.stack((s4_q0, s4_q1, s4_q2, s4_q3), axis=-1) - - all_quads = np.vstack( - (top_quads, bottom_quads, side1_quads, side2_quads, side3_quads, side4_quads) - ) - - triangles_list = [ - np.stack((all_quads[:, 0], all_quads[:, 1], all_quads[:, 3]), axis=-1), - np.stack((all_quads[:, 3], all_quads[:, 1], all_quads[:, 2]), axis=-1), - ] - tri_faces = np.vstack(triangles_list) - - return cls.from_vertices_faces(vertices=vertices, faces=tri_faces) - - @classmethod - def from_height_function( - cls, - axis: Ax, - direction: Literal["-", "+"], - base: float, - center: tuple[float, float], - size: tuple[float, float], - grid_size: tuple[int, int], - height_func: Callable[[np.ndarray, np.ndarray], np.ndarray], - ) -> TriangleMesh: - """Construct a TriangleMesh object from analytical expression of height function. - The height function should be vectorized to accept 2D meshgrid arrays. - - Parameters - ---------- - axis : Ax - Axis of extrusion. - direction : Literal["-", "+"] - Direction of extrusion. - base : float - Coordinate of the base rectangle along the geometry's axis. - center : Tuple[float, float] - Center of the base rectangle in the plane perpendicular to the extrusion axis - (XY, YZ, or ZX corresponding to values of axis). - size : Tuple[float, float] - Size of the base rectangle in the plane perpendicular to the extrusion axis - (XY, YZ, or ZX corresponding to values of axis). 
- grid_size : Tuple[int, int] - Number of grid points for discretization of the base rectangle - (XY, YZ, or ZX corresponding to values of axis). - height_func : Callable[[np.ndarray, np.ndarray], np.ndarray] - Vectorized function to compute height values from 2D meshgrid coordinate arrays. - It should take two ndarrays (x_mesh, y_mesh) and return an ndarray of heights. - - Returns - ------- - TriangleMesh - The resulting TriangleMesh geometry object. - """ - x_lin = np.linspace(center[0] - 0.5 * size[0], center[0] + 0.5 * size[0], grid_size[0]) - y_lin = np.linspace(center[1] - 0.5 * size[1], center[1] + 0.5 * size[1], grid_size[1]) - - x_mesh, y_mesh = np.meshgrid(x_lin, y_lin, indexing="ij") - - height_values = height_func(x_mesh, y_mesh) - - if not (isinstance(height_values, np.ndarray) and height_values.shape == x_mesh.shape): - raise ValueError( - f"The 'height_func' must return a NumPy array with shape {x_mesh.shape}, " - f"but got shape {getattr(height_values, 'shape', type(height_values))}." - ) - - return cls.from_height_grid( - axis=axis, - direction=direction, - base=base, - grid=(x_lin, y_lin), - height=height_values, - ) - - @cached_property - @verify_packages_import(["trimesh"]) - def trimesh( - self, - ) -> Trimesh: # -> We need to get this out of the classes and into functional methods operating on a class (maybe still referenced to the class) - """A ``trimesh.Trimesh`` object representing the custom surface mesh geometry.""" - return self._triangles_to_trimesh(self.triangles) - - @cached_property - def triangles(self) -> np.ndarray: - """The triangles of the surface mesh as an ``np.ndarray``.""" - if self.mesh_dataset is None: - raise DataError("Can't get triangles as 'mesh_dataset' is None.") - return np.asarray(get_static(self.mesh_dataset.surface_mesh.data)) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - # currently ignores bounds - return self.trimesh.area - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - # currently ignores bounds - return self.trimesh.volume - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - if self.mesh_dataset is None: - return ((-inf, -inf, -inf), (inf, inf, inf)) - return self.trimesh.bounds - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for TriangleMesh. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
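A usage sketch for the height-based constructors documented above (assumes a tidy3d install with `trimesh` available; the cosine profile is made up, and it stays positive as `from_height_grid` requires):

```python
import numpy as np
import tidy3d as td

# Extrude a gently corrugated surface upward (+z) from the base plane z = 0.
geom = td.TriangleMesh.from_height_function(
    axis=2,
    direction="+",
    base=0.0,
    center=(0.0, 0.0),
    size=(2.0, 2.0),
    grid_size=(32, 32),
    height_func=lambda x, y: 0.5 + 0.1 * np.cos(np.pi * x) * np.cos(np.pi * y),
)
print(geom.bounds)  # roughly ((-1, -1, 0), (1, 1, 0.6))
```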
- """ - section = self.trimesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Returns list of shapely geometries at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for TriangleMesh. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentaton `_. - """ - - if self.mesh_dataset is None: - return [] - - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - - origin = self.unpop_axis(position, (0, 0), axis=axis) - normal = self.unpop_axis(1, (0, 0), axis=axis) - - mesh = self.trimesh - - try: - section = mesh.section(plane_origin=origin, plane_normal=normal) - - if section is None: - return [] - - # homogeneous transformation matrix to map to xy plane - mapping = np.eye(4) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - # translate to origin - mapping[3, :3] = -np.array(origin) - - # permute so normal is aligned with z axis - # and (y, z), (x, z), resp. (x, y) are aligned with (x, y) - identity = np.eye(3) - permutation = self.unpop_axis(identity[2], identity[0:2], axis=axis) - mapping[:3, :3] = np.array(permutation).T - - section2d, _ = section.to_2D(to_2D=mapping) - return list(section2d.polygons_full) - - except ValueError as e: - if not mesh.is_watertight: - log.warning( - "Unable to compute 'TriangleMesh.intersections_plane' " - "because the mesh was not watertight. Using bounding box instead. " - "This may be overly strict; consider using 'TriangleMesh.fill_holes' " - "to repair the non-watertight mesh." - ) - else: - log.warning( - "Unable to compute 'TriangleMesh.intersections_plane'. " - "Using bounding box instead." - ) - log.warning(f"Error encountered: {e}") - return self.bounding_box.intersections_plane(x=x, y=y, z=z, cleanup=cleanup) - - def inside(self, x: NDArray, y: NDArray, z: NDArray) -> np.ndarray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
- """ - - arrays = tuple(map(np.array, (x, y, z))) - self._ensure_equal_shape(*arrays) - arrays_flat = map(np.ravel, arrays) - arrays_stacked = np.stack(tuple(arrays_flat), axis=-1) - inside = self.trimesh.contains(arrays_stacked) - return inside.reshape(arrays[0].shape) - - @equal_aspect - @add_ax_if_none - def plot( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - ax: Ax = None, - **patch_kwargs: Any, - ) -> Ax: - """Plot geometry cross section at single (x,y,z) coordinate. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - **patch_kwargs - Optional keyword arguments passed to the matplotlib patch plotting of structure. - For details on accepted values, refer to - `Matplotlib's documentation `_. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - - log.warning( - "Plotting a 'TriangleMesh' may give inconsistent results " - "if the mesh is not unionized. We recommend unionizing all meshes before import. " - "A 'PermittivityMonitor' can be used to check that the mesh is loaded correctly." - ) - - return base.Geometry.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint derivatives for a ``TriangleMesh`` geometry.""" - vjps: AutogradFieldMap = {} - - if not self.mesh_dataset: - raise DataError("Can't compute derivatives without mesh data.") - - valid_paths = {("mesh_dataset", "surface_mesh")} - for path in derivative_info.paths: - if path not in valid_paths: - raise ValueError(f"No derivative defined w.r.t. 
'TriangleMesh' field '{path}'.") - - if ("mesh_dataset", "surface_mesh") not in derivative_info.paths: - return vjps - - triangles = np.asarray(self.triangles, dtype=config.adjoint.gradient_dtype_float) - - # early exit if geometry is completely outside simulation bounds - sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) - mesh_min, mesh_max = map(np.asarray, self.bounds) - if np.any(mesh_max < sim_min) or np.any(mesh_min > sim_max): - log.warning( - "'TriangleMesh' lies completely outside the simulation domain.", - log_once=True, - ) - zeros = np.zeros_like(triangles) - vjps[("mesh_dataset", "surface_mesh")] = zeros - return vjps - - # gather surface samples within the simulation bounds - dx = derivative_info.adaptive_vjp_spacing() - samples = self._collect_surface_samples( - triangles=triangles, - spacing=dx, - sim_min=sim_min, - sim_max=sim_max, - ) - - if samples["points"].shape[0] == 0: - zeros = np.zeros_like(triangles) - vjps[("mesh_dataset", "surface_mesh")] = zeros - return vjps - - interpolators = derivative_info.interpolators - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - g = derivative_info.evaluate_gradient_at_points( - samples["points"], - samples["normals"], - samples["perps1"], - samples["perps2"], - interpolators, - ) - - # accumulate per-vertex contributions using barycentric weights - weights = (samples["weights"] * g).real - normals = samples["normals"] - faces = samples["faces"] - bary = samples["barycentric"] - - contrib_vec = weights[:, None] * normals - - triangle_grads = np.zeros_like(triangles, dtype=config.adjoint.gradient_dtype_float) - for vertex_idx in range(3): - scaled = contrib_vec * bary[:, vertex_idx][:, None] - np.add.at(triangle_grads[:, vertex_idx, :], faces, scaled) - - vjps[("mesh_dataset", "surface_mesh")] = triangle_grads - return vjps - - def _collect_surface_samples( - self, - triangles: NDArray, - spacing: float, - sim_min: NDArray, - sim_max: NDArray, - ) -> dict[str, np.ndarray]: - """Deterministic per-triangle sampling used historically.""" - - dtype = config.adjoint.gradient_dtype_float - tol = config.adjoint.edge_clip_tolerance - - sim_min = np.asarray(sim_min, dtype=dtype) - sim_max = np.asarray(sim_max, dtype=dtype) - - points_list: list[np.ndarray] = [] - normals_list: list[np.ndarray] = [] - perps1_list: list[np.ndarray] = [] - perps2_list: list[np.ndarray] = [] - weights_list: list[np.ndarray] = [] - faces_list: list[np.ndarray] = [] - bary_list: list[np.ndarray] = [] - - spacing = max(float(spacing), np.finfo(float).eps) - triangles_arr = np.asarray(triangles, dtype=dtype) - - sim_extents = sim_max - sim_min - valid_axes = np.abs(sim_extents) > tol - collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol)) - collapsed_axis: Optional[int] = None - plane_value: Optional[float] = None - if collapsed_indices.size == 1: - collapsed_axis = int(collapsed_indices[0]) - plane_value = float(sim_min[collapsed_axis]) - - warned = False - warning_msg = "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." 
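The early exit above is a standard axis-aligned bounding-box disjointness test: if the mesh box lies entirely outside the simulation box along any axis, every vertex gradient is zero. A standalone sketch of the same comparison:

```python
import numpy as np

def boxes_disjoint(a_min, a_max, b_min, b_max) -> bool:
    """True if axis-aligned boxes A and B have no overlap along some axis."""
    a_min, a_max, b_min, b_max = map(np.asarray, (a_min, a_max, b_min, b_max))
    return bool(np.any(a_max < b_min) or np.any(a_min > b_max))

print(boxes_disjoint((0, 0, 0), (1, 1, 1), (2, 0, 0), (3, 1, 1)))    # True
print(boxes_disjoint((0, 0, 0), (1, 1, 1), (0.5, 0, 0), (3, 1, 1)))  # False
```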
- for face_index, tri in enumerate(triangles_arr): - area, normal = self._triangle_area_and_normal(tri) - if area <= AREA_SIZE_THRESHOLD: - continue - - perps = self._triangle_tangent_basis(tri, normal) - if perps is None: - continue - perp1, perp2 = perps - - if collapsed_axis is not None and plane_value is not None: - samples, outside_bounds = self._collect_surface_samples_2d( - triangle=tri, - face_index=face_index, - normal=normal, - perp1=perp1, - perp2=perp2, - spacing=spacing, - collapsed_axis=collapsed_axis, - plane_value=plane_value, - sim_min=sim_min, - sim_max=sim_max, - valid_axes=valid_axes, - tol=tol, - dtype=dtype, - ) - else: - samples, outside_bounds = self._collect_surface_samples_3d( - triangle=tri, - face_index=face_index, - normal=normal, - perp1=perp1, - perp2=perp2, - area=area, - spacing=spacing, - sim_min=sim_min, - sim_max=sim_max, - valid_axes=valid_axes, - tol=tol, - dtype=dtype, - ) - - if outside_bounds and not warned: - log.warning(warning_msg) - warned = True - - if samples is None: - continue - - points_list.append(samples["points"]) - normals_list.append(samples["normals"]) - perps1_list.append(samples["perps1"]) - perps2_list.append(samples["perps2"]) - weights_list.append(samples["weights"]) - faces_list.append(samples["faces"]) - bary_list.append(samples["barycentric"]) - - if not points_list: - return { - "points": np.zeros((0, 3), dtype=dtype), - "normals": np.zeros((0, 3), dtype=dtype), - "perps1": np.zeros((0, 3), dtype=dtype), - "perps2": np.zeros((0, 3), dtype=dtype), - "weights": np.zeros((0,), dtype=dtype), - "faces": np.zeros((0,), dtype=int), - "barycentric": np.zeros((0, 3), dtype=dtype), - } - - return { - "points": np.concatenate(points_list, axis=0), - "normals": np.concatenate(normals_list, axis=0), - "perps1": np.concatenate(perps1_list, axis=0), - "perps2": np.concatenate(perps2_list, axis=0), - "weights": np.concatenate(weights_list, axis=0), - "faces": np.concatenate(faces_list, axis=0), - "barycentric": np.concatenate(bary_list, axis=0), - } - - def _collect_surface_samples_2d( - self, - triangle: NDArray, - face_index: int, - normal: np.ndarray, - perp1: np.ndarray, - perp2: np.ndarray, - spacing: float, - collapsed_axis: int, - plane_value: float, - sim_min: np.ndarray, - sim_max: np.ndarray, - valid_axes: np.ndarray, - tol: float, - dtype: np.dtype, - ) -> tuple[Optional[dict[str, np.ndarray]], bool]: - """Collect samples when the simulation bounds collapse onto a 2D plane.""" - - segments = self._triangle_plane_segments( - triangle=triangle, axis=collapsed_axis, plane_value=plane_value, tol=tol - ) - - points: list[np.ndarray] = [] - normals: list[np.ndarray] = [] - perps1_list: list[np.ndarray] = [] - perps2_list: list[np.ndarray] = [] - weights: list[np.ndarray] = [] - faces: list[np.ndarray] = [] - barycentric: list[np.ndarray] = [] - outside_bounds = False - - for start, end in segments: - vec = end - start - length = float(np.linalg.norm(vec)) - if length <= tol: - continue - - subdivisions = max(1, int(np.ceil(length / spacing))) - t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions - sample_points = start[None, :] + t_vals[:, None] * vec[None, :] - bary = self._barycentric_coordinates(triangle, sample_points, tol) - - inside_mask = np.ones(sample_points.shape[0], dtype=bool) - if np.any(valid_axes): - min_bound = (sim_min - tol)[valid_axes] - max_bound = (sim_max + tol)[valid_axes] - coords = sample_points[:, valid_axes] - inside_mask = np.all(coords >= min_bound, axis=1) & np.all( - coords <= max_bound, 
axis=1 - ) - - outside_bounds = outside_bounds or (not np.all(inside_mask)) - if not np.any(inside_mask): - continue - - sample_points = sample_points[inside_mask] - bary_inside = bary[inside_mask] - n_inside = sample_points.shape[0] - - normal_tile = np.repeat(normal[None, :], n_inside, axis=0) - perp1_tile = np.repeat(perp1[None, :], n_inside, axis=0) - perp2_tile = np.repeat(perp2[None, :], n_inside, axis=0) - weights_tile = np.full(n_inside, length / subdivisions, dtype=dtype) - faces_tile = np.full(n_inside, face_index, dtype=int) - - points.append(sample_points) - normals.append(normal_tile) - perps1_list.append(perp1_tile) - perps2_list.append(perp2_tile) - weights.append(weights_tile) - faces.append(faces_tile) - barycentric.append(bary_inside) - - if not points: - return None, outside_bounds - - samples = { - "points": np.concatenate(points, axis=0), - "normals": np.concatenate(normals, axis=0), - "perps1": np.concatenate(perps1_list, axis=0), - "perps2": np.concatenate(perps2_list, axis=0), - "weights": np.concatenate(weights, axis=0), - "faces": np.concatenate(faces, axis=0), - "barycentric": np.concatenate(barycentric, axis=0), - } - return samples, outside_bounds - - def _collect_surface_samples_3d( - self, - triangle: NDArray, - face_index: int, - normal: np.ndarray, - perp1: np.ndarray, - perp2: np.ndarray, - area: float, - spacing: float, - sim_min: np.ndarray, - sim_max: np.ndarray, - valid_axes: np.ndarray, - tol: float, - dtype: np.dtype, - ) -> tuple[Optional[dict[str, np.ndarray]], bool]: - """Collect samples when the simulation bounds represent a full 3D region.""" - - edge_lengths = ( - np.linalg.norm(triangle[1] - triangle[0]), - np.linalg.norm(triangle[2] - triangle[1]), - np.linalg.norm(triangle[0] - triangle[2]), - ) - subdivisions = self._subdivision_count(area, spacing, edge_lengths) - barycentric = self._get_barycentric_samples(subdivisions, dtype) - num_samples = barycentric.shape[0] - base_weight = area / num_samples - - sample_points = barycentric @ triangle - - inside_mask = np.all( - sample_points[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1 - ) & np.all(sample_points[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1) - outside_bounds = not np.all(inside_mask) - if not np.any(inside_mask): - return None, outside_bounds - - sample_points = sample_points[inside_mask] - bary_inside = barycentric[inside_mask] - n_samples_inside = sample_points.shape[0] - - normal_tile = np.repeat(normal[None, :], n_samples_inside, axis=0) - perp1_tile = np.repeat(perp1[None, :], n_samples_inside, axis=0) - perp2_tile = np.repeat(perp2[None, :], n_samples_inside, axis=0) - weights_tile = np.full(n_samples_inside, base_weight, dtype=dtype) - faces_tile = np.full(n_samples_inside, face_index, dtype=int) - - samples = { - "points": sample_points, - "normals": normal_tile, - "perps1": perp1_tile, - "perps2": perp2_tile, - "weights": weights_tile, - "faces": faces_tile, - "barycentric": bary_inside, - } - return samples, outside_bounds - - @staticmethod - def _triangle_area_and_normal(triangle: NDArray) -> tuple[float, np.ndarray]: - """Return area and outward normal of the provided triangle.""" - - edge01 = triangle[1] - triangle[0] - edge02 = triangle[2] - triangle[0] - cross = np.cross(edge01, edge02) - norm = np.linalg.norm(cross) - if norm <= 0.0: - return 0.0, np.zeros(3, dtype=triangle.dtype) - normal = (cross / norm).astype(triangle.dtype, copy=False) - area = 0.5 * norm - return area, normal - - @staticmethod - def _triangle_plane_segments( - triangle: 
NDArray, axis: int, plane_value: float, tol: float - ) -> list[tuple[np.ndarray, np.ndarray]]: - """Return intersection segments between a triangle and an axis-aligned plane.""" - - vertices = np.asarray(triangle) - distances = vertices[:, axis] - plane_value - edges = ((0, 1), (1, 2), (2, 0)) - - segments: list[tuple[np.ndarray, np.ndarray]] = [] - points: list[np.ndarray] = [] - - def add_point(pt: np.ndarray) -> None: - for existing in points: - if np.linalg.norm(existing - pt) <= tol: - return - points.append(pt.copy()) - - for i, j in edges: - di = distances[i] - dj = distances[j] - vi = vertices[i] - vj = vertices[j] - - if abs(di) <= tol and abs(dj) <= tol: - segments.append((vi.copy(), vj.copy())) - continue - - if di * dj > 0.0: - continue - - if abs(di) <= tol: - add_point(vi) - continue - - if abs(dj) <= tol: - add_point(vj) - continue - - denom = di - dj - if abs(denom) <= tol: - continue - t = di / denom - if t < 0.0 or t > 1.0: - continue - point = vi + t * (vj - vi) - add_point(point) - - if segments: - return segments - - if len(points) >= 2: - return [(points[0], points[1])] - - return [] - - @staticmethod - def _barycentric_coordinates(triangle: NDArray, points: np.ndarray, tol: float) -> np.ndarray: - """Compute barycentric coordinates of ``points`` with respect to ``triangle``.""" - - pts = np.asarray(points, dtype=triangle.dtype) - v0 = triangle[0] - v1 = triangle[1] - v2 = triangle[2] - v0v1 = v1 - v0 - v0v2 = v2 - v0 - - d00 = float(np.dot(v0v1, v0v1)) - d01 = float(np.dot(v0v1, v0v2)) - d11 = float(np.dot(v0v2, v0v2)) - denom = d00 * d11 - d01 * d01 - if abs(denom) <= tol: - return np.tile( - np.array([1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], dtype=triangle.dtype), (pts.shape[0], 1) - ) - - v0p = pts - v0 - d20 = v0p @ v0v1 - d21 = v0p @ v0v2 - v = (d11 * d20 - d01 * d21) / denom - w = (d00 * d21 - d01 * d20) / denom - u = 1.0 - v - w - bary = np.stack((u, v, w), axis=1) - return bary.astype(triangle.dtype, copy=False) - - @classmethod - def _subdivision_count( - cls, - area: float, - spacing: float, - edge_lengths: Optional[tuple[float, float, float]] = None, - ) -> int: - """Determine the number of subdivisions needed for the given area and spacing.""" - - spacing = max(float(spacing), np.finfo(float).eps) - - target = np.sqrt(max(area, 0.0)) - area_based = np.ceil(np.sqrt(2.0) * target / spacing) - - edge_based = 0.0 - if edge_lengths: - max_edge = max(edge_lengths) - if max_edge > 0.0: - edge_based = np.ceil(max_edge / spacing) - - subdivisions = max(1, int(max(area_based, edge_based))) - return subdivisions - - def _get_barycentric_samples(self, subdivisions: int, dtype: np.dtype) -> np.ndarray: - """Return barycentric sample coordinates for a subdivision level.""" - - cache = self._barycentric_samples - if subdivisions not in cache: - cache[subdivisions] = self._build_barycentric_samples(subdivisions) - return cache[subdivisions].astype(dtype, copy=False) - - @staticmethod - def _build_barycentric_samples(subdivisions: int) -> np.ndarray: - """Construct barycentric sampling points for a given subdivision level.""" - - if subdivisions <= 1: - return np.array([[1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]]) - - bary = [] - for i in range(subdivisions): - for j in range(subdivisions - i): - l1 = (i + 1.0 / 3.0) / subdivisions - l2 = (j + 1.0 / 3.0) / subdivisions - l0 = 1.0 - l1 - l2 - bary.append((l0, l1, l2)) - return np.asarray(bary, dtype=float) - - @staticmethod - def subdivide_faces(vertices: NDArray, faces: NDArray) -> tuple[np.ndarray, np.ndarray]: - """Uniformly subdivide 
each triangular face by inserting edge midpoints.""" - - midpoint_cache: dict[tuple[int, int], int] = {} - verts_list = [np.asarray(v, dtype=float) for v in vertices] - - def midpoint(i: int, j: int) -> int: - key = (i, j) if i < j else (j, i) - if key in midpoint_cache: - return midpoint_cache[key] - vm = 0.5 * (verts_list[i] + verts_list[j]) - verts_list.append(vm) - idx = len(verts_list) - 1 - midpoint_cache[key] = idx - return idx - - new_faces: list[tuple[int, int, int]] = [] - for tri in faces: - a = midpoint(tri[0], tri[1]) - b = midpoint(tri[1], tri[2]) - c = midpoint(tri[2], tri[0]) - new_faces.extend(((tri[0], a, c), (tri[1], b, a), (tri[2], c, b), (a, b, c))) - - verts_arr = np.asarray(verts_list, dtype=float) - return verts_arr, np.asarray(new_faces, dtype=int) - - @staticmethod - def _triangle_tangent_basis( - triangle: NDArray, normal: NDArray - ) -> Optional[tuple[np.ndarray, np.ndarray]]: - """Compute orthonormal tangential vectors for a triangle.""" - - tol = np.finfo(triangle.dtype).eps - edges = [triangle[1] - triangle[0], triangle[2] - triangle[0], triangle[2] - triangle[1]] - - edge = None - for candidate in edges: - length = np.linalg.norm(candidate) - if length > tol: - edge = (candidate / length).astype(triangle.dtype, copy=False) - break - - if edge is None: - return None +# marked as migrated to _common +from __future__ import annotations - perp1 = edge - perp2 = np.cross(normal, perp1) - perp2_norm = np.linalg.norm(perp2) - if perp2_norm <= tol: - return None - perp2 = (perp2 / perp2_norm).astype(triangle.dtype, copy=False) - return perp1, perp2 +from tidy3d._common.components.geometry.mesh import ( + AREA_SIZE_THRESHOLD, + TriangleMesh, +) diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index a8e633663d..f433226fd0 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -1,2760 +1,17 @@ -"""Geometry extruded from polygonal shapes.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.polyslab`.""" -from __future__ import annotations - -import math -from copy import copy -from functools import lru_cache -from typing import TYPE_CHECKING, Any, Optional, Union +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import autograd.numpy as np -import pydantic.v1 as pydantic -import shapely -from autograd.tracer import getval, isbox -from numpy._typing import NDArray -from numpy.polynomial.legendre import leggauss as _leggauss +# marked as migrated to _common +from __future__ import annotations -from tidy3d.components.autograd import AutogradFieldMap, TracedVertices, get_static -from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.autograd.types import TracedFloat -from tidy3d.components.base import cached_property, skip_if_fields_missing -from tidy3d.components.transformation import ReflectionFromPlane, RotationAroundAxis -from tidy3d.components.types import ( - ArrayFloat1D, - ArrayFloat2D, - ArrayLike, - Axis, - Bound, - Coordinate, - MatrixReal4x4, - PlanePosition, - Shapely, +from tidy3d._common.components.geometry.polyslab import ( + _COMPLEX_POLYSLAB_DIVISIONS_WARN, + _IS_CLOSE_RTOL, + _MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION, + _MIN_POLYGON_AREA, + _N_SAMPLE_POLYGON_INTERSECT, + ComplexPolySlabBase, + PolySlab, + leggauss, ) -from tidy3d.config import config -from tidy3d.constants import LARGE_NUMBER, MICROMETER, fp_eps -from tidy3d.exceptions import SetupError, Tidy3dImportError, 
ValidationError -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -from . import base, triangulation - -if TYPE_CHECKING: - from gdstk import Cell - -# sampling polygon along dilation for validating polygon to be -# non self-intersecting during the entire dilation process -_N_SAMPLE_POLYGON_INTERSECT = 5 - -_IS_CLOSE_RTOL = np.finfo(float).eps - -# Warn for too many divided polyslabs -_COMPLEX_POLYSLAB_DIVISIONS_WARN = 100 - -# Warn before triangulating large polyslabs due to inefficiency -_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION = 500 - -_MIN_POLYGON_AREA = fp_eps - - -@lru_cache(maxsize=128) -def leggauss(n: int) -> tuple[NDArray, NDArray]: - """Cached version of leggauss with dtype conversions.""" - g, w = _leggauss(n) - return g.astype(config.adjoint.gradient_dtype_float, copy=False), w.astype( - config.adjoint.gradient_dtype_float, copy=False - ) - - -class PolySlab(base.Planar): - """Polygon extruded with optional sidewall angle along axis direction. - - Example - ------- - >>> vertices = np.array([(0,0), (1,0), (1,1)]) - >>> p = PolySlab(vertices=vertices, axis=2, slab_bounds=(-1, 1)) - """ - - slab_bounds: tuple[TracedFloat, TracedFloat] = pydantic.Field( - ..., - title="Slab Bounds", - description="Minimum and maximum positions of the slab along axis dimension.", - units=MICROMETER, - ) - - dilation: float = pydantic.Field( - 0.0, - title="Dilation", - description="Dilation of the supplied polygon by shifting each edge along its " - "normal outwards direction by a distance; a negative value corresponds to erosion.", - units=MICROMETER, - ) - - vertices: TracedVertices = pydantic.Field( - ..., - title="Vertices", - description="List of (d1, d2) defining the 2 dimensional positions of the polygon " - "face vertices at the ``reference_plane``. " - "The index of dimension should be in the ascending order: e.g. if " - "the slab normal axis is ``axis=y``, the coordinate of the vertices will be in (x, z)", - units=MICROMETER, - ) - - @staticmethod - def make_shapely_polygon(vertices: ArrayLike) -> shapely.Polygon: - """Make a shapely polygon from some vertices, first ensures they are untraced.""" - vertices = get_static(vertices) - return shapely.Polygon(vertices) - - @pydantic.validator("slab_bounds", always=True) - def slab_bounds_order(cls, val: tuple[float, float]) -> tuple[float, float]: - """Maximum position of the slab should be no smaller than its minimal position.""" - if val[1] < val[0]: - raise SetupError( - "Polyslab.slab_bounds must be specified in the order of " - "minimum and maximum positions of the slab along the axis. " - f"But now the maximum {val[1]} is smaller than the minimum {val[0]}." - ) - return val - - @pydantic.validator("vertices", always=True) - def correct_shape(cls, val: ArrayFloat2D) -> ArrayFloat2D: - """Makes sure vertices size is correct. - Make sure no intersecting edges. - """ - # overall shape of vertices - if val.shape[1] != 2: - raise SetupError( - "PolySlab.vertices must be a 2 dimensional array shaped (N, 2). " - f"Given array with shape of {val.shape}." - ) - - # make sure no polygon splitting, islands, 0 area - poly_heal = shapely.make_valid(cls.make_shapely_polygon(val)) - if poly_heal.area < _MIN_POLYGON_AREA: - raise SetupError("The polygon almost collapses to a 1D curve.") - - if not poly_heal.geom_type == "Polygon" or len(poly_heal.interiors) > 0: - raise SetupError( - "Polygon is self-intersecting, resulting in " - "polygon splitting or generation of holes/islands. 
" - "A general treatment to self-intersecting polygon will be available " - "in future releases." - ) - return val - - @pydantic.validator("vertices", always=True) - @skip_if_fields_missing(["dilation"]) - def no_complex_self_intersecting_polygon_at_reference_plane( - cls, val: ArrayFloat2D, values: dict[str, Any] - ) -> ArrayFloat2D: - """At the reference plane, check if the polygon is self-intersecting. - - There are two types of self-intersection that can occur during dilation: - 1) the one that creates holes/islands, or splits polygons, or removes everything; - 2) the one that does not. - - For 1), we issue an error since it is yet to be supported; - For 2), we heal the polygon, and warn that the polygon has been cleaned up. - """ - # no need to validate anything here - if math.isclose(values["dilation"], 0): - return val - - val_np = PolySlab._proper_vertices(val) - dist = values["dilation"] - - # 0) fully eroded - if dist < 0 and dist < -PolySlab._maximal_erosion(val_np): - raise SetupError("Erosion value is too large. The polygon is fully eroded.") - - # no edge events - if not PolySlab._edge_events_detection(val_np, dist, ignore_at_dist=False): - return val - - poly_offset = PolySlab._shift_vertices(val_np, dist)[0] - if PolySlab._area(poly_offset) < fp_eps**2: - raise SetupError("Erosion value is too large. The polygon is fully eroded.") - - # edge events - poly_offset = shapely.make_valid(cls.make_shapely_polygon(poly_offset)) - # 1) polygon split or create holes/islands - if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: - raise SetupError( - "Dilation/Erosion value is too large, resulting in " - "polygon splitting or generation of holes/islands. " - "A general treatment to self-intersecting polygon will be available " - "in future releases." - ) - - # case 2 - log.warning( - "The dilation/erosion value is too large. resulting in a " - "self-intersecting polygon. " - "The vertices have been modified to make a valid polygon." - ) - return val - - @pydantic.validator("vertices", always=True) - @skip_if_fields_missing(["sidewall_angle", "dilation", "slab_bounds", "reference_plane"]) - def no_self_intersecting_polygon_during_extrusion( - cls, val: ArrayFloat2D, values: dict[str, Any] - ) -> ArrayFloat2D: - """In this simple polyslab, we don't support self-intersecting polygons yet, meaning that - any normal cross section of the PolySlab cannot be self-intersecting. This part checks - if any self-interction will occur during extrusion with non-zero sidewall angle. - - There are two types of self-intersection, known as edge events, - that can occur during dilation: - 1) neighboring vertex-vertex crossing. This type of edge event can be treated with - ``ComplexPolySlab`` which divides the polyslab into a list of simple polyslabs. - - 2) other types of edge events that can create holes/islands or split polygons. - To detect this, we sample _N_SAMPLE_POLYGON_INTERSECT cross sections to see if any creation - of polygons/holes, and changes in vertices number. 
- """ - - # no need to validate anything here - # sidewall_angle may be autograd-traced; use static value for this check only - if math.isclose(getval(values["sidewall_angle"]), 0): - return val - - # apply dilation - poly_ref = PolySlab._proper_vertices(val) - if not math.isclose(values["dilation"], 0): - poly_ref = PolySlab._shift_vertices(poly_ref, values["dilation"])[0] - poly_ref = PolySlab._heal_polygon(poly_ref) - - slab_min, slab_max = values["slab_bounds"] - slab_bounds = [getval(slab_min), getval(slab_max)] - - # first, check vertex-vertex crossing at any point during extrusion - length = slab_bounds[1] - slab_bounds[0] - dist = [-length * np.tan(values["sidewall_angle"])] - # reverse the dilation value if it's defined on the top - if values["reference_plane"] == "top": - dist = [-dist[0]] - # for middle, both direction needs to be examined - elif values["reference_plane"] == "middle": - dist = [dist[0] / 2, -dist[0] / 2] - - # capture vertex crossing events - max_thick = [] - for dist_val in dist: - max_dist = PolySlab._neighbor_vertices_crossing_detection(poly_ref, dist_val) - - if max_dist is not None: - max_thick.append(max_dist / abs(dist_val) * length) - - if len(max_thick) > 0: - max_thick = min(max_thick) - raise SetupError( - "Sidewall angle or structure thickness is so large that the polygon " - "is self-intersecting during extrusion. " - f"Please either reduce structure thickness to be < {max_thick:.3e}, " - "or use our plugin 'ComplexPolySlab' to divide the complex polyslab " - "into a list of simple polyslabs." - ) - - # vertex-edge crossing event. - for dist_val in dist: - if PolySlab._edge_events_detection(poly_ref, dist_val): - raise SetupError( - "Sidewall angle or structure thickness is too large, " - "resulting in polygon splitting or generation of holes/islands. " - "A general treatment to self-intersecting polygon will be available " - "in future releases." - ) - return val - - @classmethod - def from_gds( - cls, - gds_cell: Cell, - axis: Axis, - slab_bounds: tuple[float, float], - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: pydantic.PositiveFloat = 1.0, - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> list[PolySlab]: - """Import :class:`PolySlab` from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - axis : int - Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). - slab_bounds: Tuple[float, float] - Minimum and maximum positions of the slab along ``axis``. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. - If ``None``, imports all data for this layer into the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of MICROMETER. - For example, if gds file uses nanometers, set ``gds_scale=1e-3``. - Must be positive. - dilation : float = 0.0 - Dilation of the polygon in the base by shifting each edge along its - normal outwards direction by a distance; - a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the sidewall. - ``sidewall_angle=0`` (default) specifies vertical wall, - while ``0 list[ArrayFloat2D]: - """Import :class:`PolySlab` from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. 
- If ``None``, imports all data for this layer into the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of MICROMETER. - For example, if gds file uses nanometers, set ``gds_scale=1e-3``. - Must be positive. - - Returns - ------- - List[ArrayFloat2D] - List of :class:`.ArrayFloat2D` - """ - import gdstk - - gds_cell_class_name = str(gds_cell.__class__) - if not isinstance(gds_cell, gdstk.Cell): - if ( - "gdstk" in gds_cell_class_name - ): # Check if it might be a gdstk cell but gdstk is not found - raise Tidy3dImportError( - "Module 'gdstk' not found. It is required to import gdstk cells." - ) - raise ValueError( - f"validate 'gds_cell' of type '{gds_cell_class_name}' " - "does not seem to be associated with 'gdstk' package " - "and therefore can't be loaded by Tidy3D." - ) - - all_vertices = base.Geometry.load_gds_vertices_gdstk( - gds_cell=gds_cell, - gds_layer=gds_layer, - gds_dtype=gds_dtype, - gds_scale=gds_scale, - ) - - # convert vertices into polyslabs - polygons = [PolySlab.make_shapely_polygon(vertices).buffer(0) for vertices in all_vertices] - polys_union = shapely.unary_union(polygons, grid_size=base.POLY_GRID_SIZE) - - if polys_union.geom_type == "Polygon": - all_vertices = [np.array(polys_union.exterior.coords)] - elif polys_union.geom_type == "MultiPolygon": - all_vertices = [np.array(polygon.exterior.coords) for polygon in polys_union.geoms] - return all_vertices - - @property - def center_axis(self) -> float: - """Gets the position of the center of the geometry in the out of plane dimension.""" - zmin, zmax = self.slab_bounds - if np.isneginf(zmin) and np.isposinf(zmax): - return 0.0 - zmin = max(zmin, -LARGE_NUMBER) - zmax = min(zmax, LARGE_NUMBER) - return (zmax + zmin) / 2.0 - - @property - def length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension.""" - zmin, zmax = self.slab_bounds - return zmax - zmin - - @property - def finite_length_axis(self) -> float: - """Gets the length of the PolySlab along the out of plane dimension. - First clips the slab bounds to LARGE_NUMBER and then returns difference. - """ - zmin, zmax = self.slab_bounds - zmin = max(zmin, -LARGE_NUMBER) - zmax = min(zmax, LARGE_NUMBER) - return zmax - zmin - - @cached_property - def reference_polygon(self) -> NDArray: - """The polygon at the reference plane. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the reference plane. - """ - vertices = self._proper_vertices(self.vertices) - if math.isclose(self.dilation, 0): - return vertices - offset_vertices = self._shift_vertices(vertices, self.dilation)[0] - return self._heal_polygon(offset_vertices) - - @cached_property - def middle_polygon(self) -> NDArray: - """The polygon at the middle. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the middle. - """ - - dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) - if self.reference_plane == "bottom": - return self._shift_vertices(self.reference_polygon, dist)[0] - if self.reference_plane == "top": - return self._shift_vertices(self.reference_polygon, -dist)[0] - # middle case - return self.reference_polygon - - @cached_property - def base_polygon(self) -> NDArray: - """The polygon at the base, derived from the ``middle_polygon``. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the base. 
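As a quick consistency check of how these derived polygons relate, a minimal sketch (for zero dilation and vertical sidewalls all four polygons coincide; assumes tidy3d exports PolySlab at the top level):

    import numpy as np
    from tidy3d import PolySlab

    p = PolySlab(vertices=np.array([(0, 0), (1, 0), (1, 1)]), axis=2, slab_bounds=(-1, 1))
    assert np.allclose(p.reference_polygon, p.middle_polygon)
    assert np.allclose(p.base_polygon, p.top_polygon)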
- """ - if self.reference_plane == "bottom": - return self.reference_polygon - dist = self._extrusion_length_to_offset_distance(-self.finite_length_axis / 2) - return self._shift_vertices(self.middle_polygon, dist)[0] - - @cached_property - def top_polygon(self) -> NDArray: - """The polygon at the top, derived from the ``middle_polygon``. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon at the top. - """ - if self.reference_plane == "top": - return self.reference_polygon - dist = self._extrusion_length_to_offset_distance(self.finite_length_axis / 2) - return self._shift_vertices(self.middle_polygon, dist)[0] - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - if self.slab_bounds[0] != self.slab_bounds[1]: - raise ValidationError("'Medium2D' requires the 'PolySlab' bounds to be equal.") - return self.axis - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> PolySlab: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - if axis != self.axis: - raise ValueError( - f"'_update_from_bounds' may only be applied along axis '{self.axis}', " - f"but was given axis '{axis}'." - ) - return self.updated_copy(slab_bounds=bounds) - - @cached_property - def is_ccw(self) -> bool: - """Is this ``PolySlab`` CCW-oriented?""" - return PolySlab._area(self.vertices) > 0 - - def inside(self, x: NDArray[float], y: NDArray[float], z: NDArray[float]) -> NDArray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Note - ---- - For slanted sidewalls, this function only works if x, y, and z are arrays produced by a - ``meshgrid call``, i.e. 3D arrays and each is constant along one axis. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. 
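The meshgrid note above matters in practice for slanted sidewalls; a usage sketch with a hypothetical sample grid, reusing `p` from the sketch further up:

    import numpy as np

    xs, ys, zs = np.meshgrid(
        np.linspace(-0.5, 1.5, 16),   # x samples
        np.linspace(-0.5, 1.5, 16),   # y samples
        np.linspace(-2.0, 2.0, 8),    # z samples (slab axis)
        indexing="ij",
    )
    mask = p.inside(xs, ys, zs)       # boolean array, same shape as xs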
- """ - self._ensure_equal_shape(x, y, z) - - z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) - - z0 = self.center_axis - dist_z = np.abs(z - z0) - inside_height = dist_z <= (self.finite_length_axis / 2) - - # avoid going into face checking if no points are inside slab bounds - if not np.any(inside_height): - return inside_height - - # check what points are inside polygon cross section (face) - z_local = z - z0 # distance to the middle - dist = -z_local * self._tanq - - if isinstance(x, np.ndarray): - inside_polygon = np.zeros_like(inside_height) - xs_slab = x[inside_height] - ys_slab = y[inside_height] - - # vertical sidewall - if math.isclose(self.sidewall_angle, 0): - face_polygon = shapely.Polygon(self.reference_polygon).buffer(fp_eps) - shapely.prepare(face_polygon) - inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs_slab, y=ys_slab) - inside_polygon[inside_height] = inside_polygon_slab - # slanted sidewall, offsetting vertices at each z - else: - # a helper function for moving axis - def _move_axis(arr: NDArray) -> NDArray: - return np.moveaxis(arr, source=self.axis, destination=-1) - - def _move_axis_reverse(arr: NDArray) -> NDArray: - return np.moveaxis(arr, source=-1, destination=self.axis) - - inside_polygon_axis = _move_axis(inside_polygon) - x_axis = _move_axis(x) - y_axis = _move_axis(y) - - for z_i in range(z.shape[self.axis]): - if not _move_axis(inside_height)[0, 0, z_i]: - continue - vertices_z = self._shift_vertices( - self.middle_polygon, _move_axis(dist)[0, 0, z_i] - )[0] - face_polygon = shapely.Polygon(vertices_z).buffer(fp_eps) - shapely.prepare(face_polygon) - xs = x_axis[:, :, 0].flatten() - ys = y_axis[:, :, 0].flatten() - inside_polygon_slab = shapely.contains_xy(face_polygon, x=xs, y=ys) - inside_polygon_axis[:, :, z_i] = inside_polygon_slab.reshape(x_axis.shape[:2]) - inside_polygon = _move_axis_reverse(inside_polygon_axis) - else: - vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] - face_polygon = self.make_shapely_polygon(vertices_z).buffer(fp_eps) - point = shapely.Point(x, y) - inside_polygon = face_polygon.covers(point) - return inside_height * inside_polygon - - @verify_packages_import(["trimesh"]) - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for PolySlab geometry. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
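The tilted-plane intersection implemented below boils down to a generic trimesh primitive: build a surface mesh, section it with a plane, and flatten the section to 2D. A standalone sketch on a box mesh (not a PolySlab):

    import trimesh

    mesh = trimesh.creation.box(extents=(2.0, 2.0, 2.0))
    section = mesh.section(plane_origin=(0, 0, 0), plane_normal=(1, 0, 1))
    if section is not None:
        path2d, _ = section.to_2D()
        polygons = path2d.polygons_full   # shapely polygons of the cross section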
- """ - import trimesh - - if len(self.base_polygon) > _MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION: - log.warning( - f"Processing PolySlabs with over {_MAX_POLYSLAB_VERTICES_FOR_TRIANGULATION} vertices can be slow.", - log_once=True, - ) - base_triangles = triangulation.triangulate(self.base_polygon) - top_triangles = ( - base_triangles - if math.isclose(self.sidewall_angle, 0) - else triangulation.triangulate(self.top_polygon) - ) - - n = len(self.base_polygon) - faces = ( - [[a, b, c] for c, b, a in base_triangles] - + [[n + a, n + b, n + c] for a, b, c in top_triangles] - + [(i, (i + 1) % n, n + i) for i in range(n)] - + [((i + 1) % n, n + ((i + 1) % n), n + i) for i in range(n)] - ) - - x = np.hstack((self.base_polygon[:, 0], self.top_polygon[:, 0])) - y = np.hstack((self.base_polygon[:, 1], self.top_polygon[:, 1])) - z = np.hstack((np.full(n, self.slab_bounds[0]), np.full(n, self.slab_bounds[1]))) - vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T - mesh = trimesh.Trimesh(vertices, faces) - - section = mesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def _intersections_normal(self, z: float, quad_segs: Optional[int] = None) -> list[Shapely]: - """Find shapely geometries intersecting planar geometry with axis normal to slab. - - Parameters - ---------- - z : float - Position along the axis normal to slab. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - if math.isclose(self.sidewall_angle, 0): - return [self.make_shapely_polygon(self.reference_polygon)] - - z0 = self.center_axis - z_local = z - z0 # distance to the middle - dist = -z_local * self._tanq - vertices_z = self._shift_vertices(self.middle_polygon, dist)[0] - return [self.make_shapely_polygon(vertices_z)] - - def _intersections_side(self, position: float, axis: int) -> list[Shapely]: - """Find shapely geometries intersecting planar geometry with axis orthogonal to slab. - - For slanted polyslab, the procedure is as follows, - 1) Find out all z-coordinates where the plane will intersect directly with a vertex. - Denote the coordinates as (z_0, z_1, z_2, ... ) - 2) Find out all polygons that can be formed between z_i and z_{i+1}. There are two - types of polygons: - a) formed by the plane intersecting the edges - b) formed by the plane intersecting the vertices. - For either type, one needs to compute: - i) intersecting position - ii) angle between the plane and the intersecting edge - For a), both are straightforward to compute; while for b), one needs to compute - which edge the plane will slide into. - 3) Looping through z_i, and merge all polygons. The partition by z_i is because once - the plane intersects the vertex, it can intersect with other edges during - the extrusion. - - Parameters - ---------- - position : float - Position along ``axis``. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
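For vertical sidewalls only type-a polygons from step 2 occur and each subsection is a rectangle. Callers normally reach this routine through the public plane-intersection helper; a usage sketch reusing `p` from above (assumes tidy3d's `intersections_plane` geometry API):

    shapes = p.intersections_plane(x=0.5)   # shapely geometries in the two remaining axes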
- """ - - # find out all z_i where the plane will intersect the vertex - z0 = self.center_axis - z_base = z0 - self.finite_length_axis / 2 - - axis_ordered = self._order_axis(axis) - height_list = self._find_intersecting_height(position, axis_ordered) - polys = [] - - # looping through z_i to assemble the polygons - height_list = np.append(height_list, self.finite_length_axis) - h_base = 0.0 - for h_top in height_list: - # length within between top and bottom - h_length = h_top - h_base - - # coordinate of each subsection - z_min = z_base + h_base - z_max = np.inf if np.isposinf(h_top) else z_base + h_top - - # for vertical sidewall, no need for complications - if math.isclose(self.sidewall_angle, 0): - ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( - self.reference_polygon, position, axis_ordered - ) - else: - # for slanted sidewall, move up by `fp_eps` in case vertices are degenerate at the base. - dist = -(h_base - self.finite_length_axis / 2 + fp_eps) * self._tanq - vertices = self._shift_vertices(self.middle_polygon, dist)[0] - ints_y, ints_angle = self._find_intersecting_ys_angle_slant( - vertices, position, axis_ordered - ) - - # make polygon with intersections and z axis information - for y_index in range(len(ints_y) // 2): - y_min = ints_y[2 * y_index] - y_max = ints_y[2 * y_index + 1] - minx, miny = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) - maxx, maxy = self._order_by_axis(plane_val=y_max, axis_val=z_max, axis=axis) - - if math.isclose(self.sidewall_angle, 0): - polys.append(self.make_shapely_box(minx, miny, maxx, maxy)) - else: - angle_min = ints_angle[2 * y_index] - angle_max = ints_angle[2 * y_index + 1] - - angle_min = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_min)) - angle_max = np.arctan(np.tan(self.sidewall_angle) / np.sin(angle_max)) - - dy_min = h_length * np.tan(angle_min) - dy_max = h_length * np.tan(angle_max) - - x1, y1 = self._order_by_axis(plane_val=y_min, axis_val=z_min, axis=axis) - x2, y2 = self._order_by_axis(plane_val=y_max, axis_val=z_min, axis=axis) - x3, y3 = self._order_by_axis( - plane_val=y_max - dy_max, axis_val=z_max, axis=axis - ) - x4, y4 = self._order_by_axis( - plane_val=y_min + dy_min, axis_val=z_max, axis=axis - ) - vertices = ((x1, y1), (x2, y2), (x3, y3), (x4, y4)) - polys.append(self.make_shapely_polygon(vertices).buffer(0)) - # update the base coordinate for the next subsection - h_base = h_top - - # merge touching polygons - polys_union = shapely.unary_union(polys, grid_size=base.POLY_GRID_SIZE) - if polys_union.geom_type == "Polygon": - return [polys_union] - if polys_union.geom_type == "MultiPolygon": - return polys_union.geoms - # in other cases, just return the original unmerged polygons - return polys - - def _find_intersecting_height(self, position: float, axis: int) -> NDArray: - """Found a list of height where the plane will intersect with the vertices; - For vertical sidewall, just return np.array([]). - Assumes axis is handles so this function works on xy plane. - - Parameters - ---------- - position : float - position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - np.ndarray - Height (relative to the base) where the plane will intersect with vertices. 
- """ - if math.isclose(self.sidewall_angle, 0): - return np.array([]) - - # shift rate - dist = 1.0 - shift_x, shift_y = PolySlab._shift_vertices(self.middle_polygon, dist)[2] - shift_val = shift_x if axis == 0 else shift_y - shift_val[np.isclose(shift_val, 0, rtol=_IS_CLOSE_RTOL)] = np.inf # for static vertices - - # distance to the plane in the direction of vertex shifting - distance = self.middle_polygon[:, axis] - position - height = distance / self._tanq / shift_val + self.finite_length_axis / 2 - height = np.unique(height) - # further filter very close ones - is_not_too_close = np.insert((np.diff(height) > fp_eps), 0, True) - height = height[is_not_too_close] - - height = height[height > fp_eps] - height = height[height < self.finite_length_axis - fp_eps] - return height - - def _find_intersecting_ys_angle_vertical( - self, - vertices: NDArray, - position: float, - axis: int, - exclude_on_vertices: bool = False, - ) -> tuple[NDArray, NDArray, NDArray]: - """Finds pairs of forward and backwards vertices where polygon intersects position at axis, - Find intersection point (in y) assuming straight line,and intersecting angle between plane - and edges. (For unslanted polyslab). - Assumes axis is handles so this function works on xy plane. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - position : float - position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - exclude_on_vertices : bool = False - Whether to exclude those intersecting directly with the vertices. - - Returns - ------- - Union[np.ndarray, np.ndarray] - List of intersection points along y direction. - List of angles between plane and edges. - """ - - vertices_axis = vertices - - # flip vertices x,y for axis = y - if axis == 1: - vertices_axis = np.roll(vertices_axis, shift=1, axis=1) - - # get the forward vertices - vertices_f = np.roll(vertices_axis, shift=-1, axis=0) - - # x coordinate of the two sets of vertices - x_vertices_f, _ = vertices_f.T - x_vertices_axis, _ = vertices_axis.T - - # find which segments intersect - f_left_to_intersect = x_vertices_f <= position - orig_right_to_intersect = x_vertices_axis > position - intersects_b = np.logical_and(f_left_to_intersect, orig_right_to_intersect) - - f_right_to_intersect = x_vertices_f > position - orig_left_to_intersect = x_vertices_axis <= position - intersects_f = np.logical_and(f_right_to_intersect, orig_left_to_intersect) - - # exclude vertices at the position if exclude_on_vertices is True - if exclude_on_vertices: - intersects_on = np.isclose(x_vertices_axis, position, rtol=_IS_CLOSE_RTOL) - intersects_f_on = np.isclose(x_vertices_f, position, rtol=_IS_CLOSE_RTOL) - intersects_both_off = np.logical_not(np.logical_or(intersects_on, intersects_f_on)) - intersects_f &= intersects_both_off - intersects_b &= intersects_both_off - intersects_segment = np.logical_or(intersects_b, intersects_f) - - iverts_b = vertices_axis[intersects_segment] - iverts_f = vertices_f[intersects_segment] - - # intersecting positions and angles - ints_y = [] - ints_angle = [] - for vertices_f_local, vertices_b_local in zip(iverts_b, iverts_f): - x1, y1 = vertices_f_local - x2, y2 = vertices_b_local - slope = (y2 - y1) / (x2 - x1) - y = y1 + slope * (position - x1) - ints_y.append(y) - ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope))) - - ints_y = np.array(ints_y) - ints_angle = np.array(ints_angle) - - sort_index = np.argsort(ints_y) - ints_y_sort = ints_y[sort_index] - ints_angle_sort = 
ints_angle[sort_index] - - return ints_y_sort, ints_angle_sort - - def _find_intersecting_ys_angle_slant( - self, vertices: NDArray, position: float, axis: int - ) -> tuple[NDArray, NDArray, NDArray]: - """Finds pairs of forward and backwards vertices where polygon intersects position at axis, - Find intersection point (in y) assuming straight line,and intersecting angle between plane - and edges. (For slanted polyslab) - Assumes axis is handles so this function works on xy plane. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - position : float - position along axis. - axis : int - Integer index into 'xyz' (0,1,2). - - Returns - ------- - Union[np.ndarray, np.ndarray] - List of intersection points along y direction. - List of angles between plane and edges. - """ - - vertices_axis = vertices.copy() - # flip vertices x,y for axis = y - if axis == 1: - vertices_axis = np.roll(vertices_axis, shift=1, axis=1) - - # get the forward vertices - vertices_f = np.roll(vertices_axis, shift=-1, axis=0) - # get the backward vertices - vertices_b = np.roll(vertices_axis, shift=1, axis=0) - - ## First part, plane intersects with edges, same as vertical - ints_y, ints_angle = self._find_intersecting_ys_angle_vertical( - vertices, position, axis, exclude_on_vertices=True - ) - ints_y = ints_y.tolist() - ints_angle = ints_angle.tolist() - - ## Second part, plane intersects directly with vertices - # vertices on the intersection - intersects_on = np.isclose(vertices_axis[:, 0], position, rtol=_IS_CLOSE_RTOL) - iverts_on = vertices_axis[intersects_on] - # position of the neighbouring vertices - iverts_b = vertices_b[intersects_on] - iverts_f = vertices_f[intersects_on] - # shift rate - dist = -np.sign(self.sidewall_angle) - shift_x, shift_y = self._shift_vertices(self.middle_polygon, dist)[2] - shift_val = shift_x if axis == 0 else shift_y - shift_val = shift_val[intersects_on] - - for vertices_f_local, vertices_b_local, vertices_on_local, shift_local in zip( - iverts_f, iverts_b, iverts_on, shift_val - ): - x_on, y_on = vertices_on_local - x_f, y_f = vertices_f_local - x_b, y_b = vertices_b_local - - num_added = 0 # keep track the number of added vertices - slope = [] # list of slopes for added vertices - # case 1, shifting velocity is 0 - if np.isclose(shift_local, 0, rtol=_IS_CLOSE_RTOL): - ints_y.append(y_on) - # Slope w.r.t. forward and backward should equal, - # just pick one of them. 
- slope.append((y_on - y_b) / (x_on - x_b)) - ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0]))) - continue - - # case 2, shifting towards backward direction - if (x_b - position) * shift_local < 0: - ints_y.append(y_on) - slope.append((y_on - y_b) / (x_on - x_b)) - num_added += 1 - - # case 3, shifting towards forward direction - if (x_f - position) * shift_local < 0: - ints_y.append(y_on) - slope.append((y_on - y_f) / (x_on - x_f)) - num_added += 1 - - # in case 2, and case 3, if just num_added = 1 - if num_added == 1: - ints_angle.append(np.pi / 2 - np.arctan(np.abs(slope[0]))) - # if num_added = 2, the order of the two new vertices needs to handled correctly; - # it should be sorted according to the -slope * moving direction - elif num_added == 2: - dressed_slope = [-s_i * shift_local for s_i in slope] - sort_index = np.argsort(np.array(dressed_slope)) - sorted_slope = np.array(slope)[sort_index] - - ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[0]))) - ints_angle.append(np.pi / 2 - np.arctan(np.abs(sorted_slope[1]))) - - ints_y = np.array(ints_y) - ints_angle = np.array(ints_angle) - - sort_index = np.argsort(ints_y) - ints_y_sort = ints_y[sort_index] - ints_angle_sort = ints_angle[sort_index] - - return ints_y_sort, ints_angle_sort - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. The dilation and slant angle are not - taken into account exactly for speed. Instead, the polygon may be slightly smaller than - the returned bounds, but it should always be fully contained. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - - # check for the maximum possible contribution from dilation/slant on each side - max_offset = self.dilation - # sidewall_angle may be autograd-traced; unbox for this check - if not math.isclose(getval(self.sidewall_angle), 0): - if self.reference_plane == "bottom": - max_offset += max(0, -self._tanq * self.finite_length_axis) - elif self.reference_plane == "top": - max_offset += max(0, self._tanq * self.finite_length_axis) - elif self.reference_plane == "middle": - max_offset += max(0, abs(self._tanq) * self.finite_length_axis / 2) - - # special care when dilated - if max_offset > 0: - dilated_vertices = self._shift_vertices( - self._proper_vertices(self.vertices), max_offset - )[0] - xmin, ymin = np.amin(dilated_vertices, axis=0) - xmax, ymax = np.amax(dilated_vertices, axis=0) - else: - # otherwise, bounds are directly based on the supplied vertices - xmin, ymin = np.amin(self.vertices, axis=0) - xmax, ymax = np.amax(self.vertices, axis=0) - - # get bounds in (local) z - zmin, zmax = self.slab_bounds - - # rearrange axes - coords_min = self.unpop_axis(zmin, (xmin, ymin), axis=self.axis) - coords_max = self.unpop_axis(zmax, (xmax, ymax), axis=self.axis) - return (tuple(coords_min), tuple(coords_max)) - - def _extrusion_length_to_offset_distance(self, extrusion: float) -> float: - """Convert extrusion length to offset distance.""" - if math.isclose(self.sidewall_angle, 0): - return 0 - return -extrusion * self._tanq - - @staticmethod - def _area(vertices: NDArray) -> float: - """Compute the signed polygon area (positive for CCW orientation). - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - float - Signed polygon area (positive for CCW orientation). 
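The signed area computed below is the standard shoelace formula; a quick numeric check on the CCW unit square (expected +1.0):

    import numpy as np

    v = np.array([(0, 0), (1, 0), (1, 1), (0, 1)], dtype=float)
    xs, ys = v.T
    xs_next, ys_next = np.roll(v, shift=-1, axis=0).T
    print(0.5 * np.sum(xs * ys_next - ys * xs_next))   # 1.0; the CW vertex order gives -1.0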
- """ - vert_shift = np.roll(vertices, axis=0, shift=-1) - - xs, ys = vertices.T - xs_shift, ys_shift = vert_shift.T - - term1 = xs * ys_shift - term2 = ys * xs_shift - return np.sum(term1 - term2) * 0.5 - - @staticmethod - def _perimeter(vertices: NDArray) -> float: - """Compute the polygon perimeter. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - float - Polygon perimeter. - """ - - vert_shift = np.roll(vertices, axis=0, shift=-1) - squared_diffs = (vertices - vert_shift) ** 2 - - # distance along each edge - dists = np.sqrt(squared_diffs.sum(axis=-1)) - - # total distance along all edges - return np.sum(dists) - - @staticmethod - def _orient(vertices: NDArray) -> NDArray: - """Return a CCW-oriented polygon. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - np.ndarray - Vertices of a CCW-oriented polygon. - """ - return vertices if PolySlab._area(vertices) > 0 else vertices[::-1, :] - - @staticmethod - def _remove_duplicate_vertices(vertices: NDArray) -> NDArray: - """Remove redundant/identical nearest neighbour vertices. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - np.ndarray - Vertices of polygon. - """ - - vertices_f = np.roll(vertices, shift=-1, axis=0) - vertices_diff = np.linalg.norm(vertices - vertices_f, axis=1) - return vertices[~np.isclose(vertices_diff, 0, rtol=_IS_CLOSE_RTOL)] - - @staticmethod - def _proper_vertices(vertices: ArrayFloat2D) -> NDArray: - """convert vertices to np.array format, - removing duplicate neighbouring vertices, - and oriented in CCW direction. - - Returns - ------- - ArrayLike[float, float] - The vertices of the polygon for internal use. - """ - - vertices_np = PolySlab.vertices_to_array(vertices) - return PolySlab._orient(PolySlab._remove_duplicate_vertices(vertices_np)) - - @staticmethod - def _edge_events_detection( - proper_vertices: NDArray, dilation: float, ignore_at_dist: bool = True - ) -> bool: - """Detect any edge events within the offset distance ``dilation``. - If ``ignore_at_dist=True``, the edge event at ``dist`` is ignored. - """ - - # ignore the event that occurs right at the offset distance - if ignore_at_dist: - dilation -= fp_eps * dilation / abs(dilation) - # number of vertices before offsetting - num_vertices = proper_vertices.shape[0] - - # 0) fully eroded? - if dilation < 0 and dilation < -PolySlab._maximal_erosion(proper_vertices): - return True - - # sample at a few dilation values - dist_list = ( - dilation - * np.linspace( - 0, 1, 1 + _N_SAMPLE_POLYGON_INTERSECT, dtype=config.adjoint.gradient_dtype_float - )[1:] - ) - for dist in dist_list: - # offset: we offset the vertices first, and then use shapely to make it proper - # in principle, one can offset with shapely.buffer directly, but shapely somehow - # automatically removes some vertices even though no change of topology. 
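The "make it proper" repair mentioned in the comment above, shown standalone on the canonical bowtie polygon:

    import shapely

    bowtie = shapely.Polygon([(0, 0), (1, 1), (1, 0), (0, 1)])
    print(bowtie.is_valid)                        # False: the two diagonal edges cross
    print(shapely.make_valid(bowtie).geom_type)   # MultiPolygon: resolved into two triangles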
- poly_offset = PolySlab._shift_vertices(proper_vertices, dist)[0] - # flipped winding number - if PolySlab._area(poly_offset) < fp_eps**2: - return True - - poly_offset = shapely.make_valid(PolySlab.make_shapely_polygon(poly_offset)) - # 1) polygon split or create holes/islands - if not poly_offset.geom_type == "Polygon" or len(poly_offset.interiors) > 0: - return True - - # 2) reduction in vertex number - offset_vertices = PolySlab._proper_vertices(list(poly_offset.exterior.coords)) - if offset_vertices.shape[0] != num_vertices: - return True - - # 3) some split polygon might fully disappear after the offset, but they - # can be detected if we offset back. - poly_offset_back = shapely.make_valid( - PolySlab.make_shapely_polygon(PolySlab._shift_vertices(offset_vertices, -dist)[0]) - ) - if poly_offset_back.geom_type == "MultiPolygon" or len(poly_offset_back.interiors) > 0: - return True - offset_back_vertices = list(poly_offset_back.exterior.coords) - if PolySlab._proper_vertices(offset_back_vertices).shape[0] != num_vertices: - return True - - return False - - @staticmethod - def _neighbor_vertices_crossing_detection( - vertices: NDArray, dist: float, ignore_at_dist: bool = True - ) -> float: - """Detect if neighboring vertices will cross after a dilation distance dist. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - dist : float - Distance to offset. - ignore_at_dist : bool, optional - whether to ignore the event right at ``dist`. - - Returns - ------- - float - the absolute value of the maximal allowed dilation - if there are any crossing, otherwise return ``None``. - """ - # ignore the event that occurs right at the offset distance - if ignore_at_dist: - dist -= fp_eps * dist / abs(dist) - - edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices) - length_remaining = edge_length - edge_reduction * dist - - if np.any(length_remaining < 0): - index_oversized = length_remaining < 0 - max_dist = np.min( - np.abs(edge_length[index_oversized] / edge_reduction[index_oversized]) - ) - return max_dist - return None - - @staticmethod - def array_to_vertices(arr_vertices: NDArray) -> ArrayFloat2D: - """Converts a numpy array of vertices to a list of tuples.""" - return list(arr_vertices) - - @staticmethod - def vertices_to_array(vertices_tuple: ArrayFloat2D) -> NDArray: - """Converts a list of tuples (vertices) to a numpy array.""" - return np.array(vertices_tuple) - - @cached_property - def interior_angle(self) -> ArrayFloat1D: - """Angle formed inside polygon by two adjacent edges.""" - - def normalize(v: NDArray) -> NDArray: - return v / np.linalg.norm(v, axis=0) - - vs_orig = self.reference_polygon.T - vs_next = np.roll(vs_orig, axis=-1, shift=-1) - vs_previous = np.roll(vs_orig, axis=-1, shift=+1) - - asp = normalize(vs_next - vs_orig) - asm = normalize(vs_previous - vs_orig) - - cos_angle = asp[0] * asm[0] + asp[1] * asm[1] - sin_angle = asp[0] * asm[1] - asp[1] * asm[0] - - angle = np.arccos(cos_angle) - # concave angles - angle[sin_angle < 0] = 2 * np.pi - angle[sin_angle < 0] - return angle - - @staticmethod - def _shift_vertices( - vertices: NDArray, dist: float - ) -> tuple[NDArray, NDArray, tuple[NDArray, NDArray]]: - """Shifts the vertices of a polygon outward uniformly by distances - `dists`. - - Parameters - ---------- - np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - dist : float - Distance to offset. 
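A numeric check of the offset decomposition implemented below (illustrative use of a private helper; assumes tidy3d exports PolySlab): dilating the CCW unit square by 0.1 pushes every corner outward along both axes:

    import numpy as np
    from tidy3d import PolySlab

    square = np.array([(0, 0), (1, 0), (1, 1), (0, 1)], dtype=float)
    print(PolySlab._shift_vertices(square, 0.1)[0])
    # [[-0.1 -0.1]
    #  [ 1.1 -0.1]
    #  [ 1.1  1.1]
    #  [-0.1  1.1]]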
- - Returns - ------- - Tuple[np.ndarray, np.narray,Tuple[np.ndarray,np.ndarray]] - New polygon vertices; - and the shift of vertices in direction parallel to the edges. - Shift along x and y direction. - """ - - # 'dist' may be autograd-traced; unbox for the zero-check only - if math.isclose(getval(dist), 0): - return vertices, np.zeros(vertices.shape[0], dtype=float), None - - def rot90(v: tuple[NDArray, NDArray]) -> NDArray: - """90 degree rotation of 2d vector - vx -> vy - vy -> -vx - """ - vxs, vys = v - return np.stack((-vys, vxs), axis=0) - - def cross(u: NDArray, v: NDArray) -> Any: - return u[0] * v[1] - u[1] * v[0] - - def normalize(v: NDArray) -> NDArray: - return v / np.linalg.norm(v, axis=0) - - vs_orig = copy(vertices.T) - vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1) - vs_previous = np.roll(copy(vs_orig), axis=-1, shift=+1) - - asp = normalize(vs_next - vs_orig) - asm = normalize(vs_orig - vs_previous) - - # the vertex shift is decomposed into parallel and perpendicular directions - perpendicular_shift = -dist - det = cross(asm, asp) - - tan_half_angle = np.where( - np.isclose(det, 0, rtol=_IS_CLOSE_RTOL), - 0.0, - cross(asm, rot90(asm - asp)) / (det + np.isclose(det, 0, rtol=_IS_CLOSE_RTOL)), - ) - parallel_shift = dist * tan_half_angle - - shift_total = perpendicular_shift * rot90(asm) + parallel_shift * asm - shift_x = shift_total[0, :] - shift_y = shift_total[1, :] - - return ( - np.swapaxes(vs_orig + shift_total, -2, -1), - parallel_shift, - (shift_x, shift_y), - ) - - @staticmethod - def _edge_length_and_reduction_rate( - vertices: NDArray, - ) -> tuple[NDArray, NDArray]: - """Edge length of reduction rate of each edge with unit offset length. - - Parameters - ---------- - vertices : np.ndarray - Shape (N, 2) defining the polygon vertices in the xy-plane. - - Returns - ------- - Tuple[np.ndarray, np.narray] - edge length, and reduction rate - """ - - # edge length - vs_orig = copy(vertices.T) - vs_next = np.roll(copy(vs_orig), axis=-1, shift=-1) - edge_length = np.linalg.norm(vs_next - vs_orig, axis=0) - - # edge length remaining - dist = 1 - parallel_shift = PolySlab._shift_vertices(vertices, dist)[1] - parallel_shift_p = np.roll(copy(parallel_shift), shift=-1) - edge_reduction = -(parallel_shift + parallel_shift_p) - return edge_length, edge_reduction - - @staticmethod - def _maximal_erosion(vertices: NDArray) -> float: - """The erosion value that reduces the length of - all edges to be non-positive. - """ - edge_length, edge_reduction = PolySlab._edge_length_and_reduction_rate(vertices) - ind_nonzero = abs(edge_reduction) > fp_eps - return -np.min(edge_length[ind_nonzero] / edge_reduction[ind_nonzero]) - - @staticmethod - def _heal_polygon(vertices: NDArray) -> NDArray: - """heal a self-intersecting polygon.""" - shapely_poly = PolySlab.make_shapely_polygon(vertices) - if shapely_poly.is_valid: - return vertices - if isbox(vertices): - raise NotImplementedError( - "The dilation caused damage to the polygon. " - "Automatically healing this is currently not supported when " - "differentiating w.r.t. the vertices. Try increasing the spacing " - "between vertices or reduce the amount of dilation." 
- ) - # perform healing - poly_heal = shapely.make_valid(shapely_poly) - return PolySlab._proper_vertices(list(poly_heal.exterior.coords)) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - z_min, z_max = self.slab_bounds - - z_min = max(z_min, bounds[0][self.axis]) - z_max = min(z_max, bounds[1][self.axis]) - - length = z_max - z_min - - top_area = abs(self._area(self.top_polygon)) - base_area = abs(self._area(self.base_polygon)) - - # https://mathworld.wolfram.com/PyramidalFrustum.html - return 1.0 / 3.0 * length * (top_area + base_area + np.sqrt(top_area * base_area)) - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - area = 0 - - top_polygon = self.top_polygon - base_polygon = self.base_polygon - - top_area = abs(self._area(top_polygon)) - base_area = abs(self._area(base_polygon)) - - top_perim = self._perimeter(top_polygon) - base_perim = self._perimeter(base_polygon) - - z_min, z_max = self.slab_bounds - - if z_min < bounds[0][self.axis]: - z_min = bounds[0][self.axis] - else: - area += base_area - - if z_max > bounds[1][self.axis]: - z_max = bounds[1][self.axis] - else: - area += top_area - - length = z_max - z_min - - area += 0.5 * (top_perim + base_perim) * length - - return area - - """ Autograd code """ - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """ - Return VJPs while handling several edge-cases: - - - If the slab volume does not overlap the simulation, all grads are zero - (one warning is issued). - - Faces that lie completely outside the simulation give zero ``slab_bounds`` - gradients; this includes the +/- inf cases. - - A 2d simulation collapses the surface integral to a line integral - """ - vjps: AutogradFieldMap = {} - - intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect) - sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds) - - extents = intersect_max - intersect_min - is_2d = np.isclose(extents[self.axis], 0.0) - - # early return if polyslab is not in simulation domain - slab_min, slab_max = self.slab_bounds - if (slab_max < sim_min[self.axis]) or (slab_min > sim_max[self.axis]): - log.warning( - "'PolySlab' lies completely outside the simulation domain.", - log_once=True, - ) - for p in derivative_info.paths: - vjps[p] = np.zeros_like(self.vertices) if p == ("vertices",) else 0.0 - return vjps - - # create interpolators once for ALL derivative computations - # use provided interpolators if available to avoid redundant field data conversions - interpolators = derivative_info.interpolators or derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - for path in derivative_info.paths: - if path == ("vertices",): - vjps[path] = self._compute_derivative_vertices( - derivative_info, sim_min, sim_max, is_2d, interpolators - ) - - elif path == ("sidewall_angle",): - vjps[path] = self._compute_derivative_sidewall_angle( - derivative_info, sim_min, sim_max, is_2d, interpolators - ) - elif path[0] == "slab_bounds": - idx = path[1] - face_coord = self.slab_bounds[idx] - - # face entirely outside -> gradient 0 - if ( - np.isinf(face_coord) - or face_coord < sim_min[self.axis] - or face_coord > sim_max[self.axis] - or is_2d - ): - vjps[path] = 0.0 - continue - - v = self._compute_derivative_slab_bounds(derivative_info, idx, interpolators) - # outward-normal convention - if idx == 0: - v *= -1 - vjps[path] = v - else: - raise ValueError(f"No 
derivative defined w.r.t. 'PolySlab' field '{path}'.") - - return vjps - - # ---- Shared helpers for VJP surface integrations ---- - def _z_slices( - self, sim_min: NDArray, sim_max: NDArray, is_2d: bool, dx: float - ) -> tuple[NDArray, float, float, float]: - """Compute z-slice centers and spacing within bounds. - - Returns (z_centers, dz, z0, z1). For 2D, returns single center and dz=1. - """ - if is_2d: - midpoint_z = np.maximum( - np.minimum(self.center_axis, sim_max[self.axis]), - sim_min[self.axis], - ) - zc = np.array([midpoint_z], dtype=config.adjoint.gradient_dtype_float) - return zc, 1.0, self.center_axis, self.center_axis - - z0 = max(self.slab_bounds[0], sim_min[self.axis]) - z1 = min(self.slab_bounds[1], sim_max[self.axis]) - if z1 <= z0: - return np.array([], dtype=config.adjoint.gradient_dtype_float), 0.0, z0, z1 - - n_z = max(1, int(np.ceil((z1 - z0) / dx))) - dz = (z1 - z0) / n_z - z_centers = np.linspace( - z0 + dz / 2, z1 - dz / 2, n_z, dtype=config.adjoint.gradient_dtype_float - ) - return z_centers, dz, z0, z1 - - @staticmethod - def _clip_edges_to_bounds_batch( - segment_starts: NDArray, - segment_ends: NDArray, - sim_min: NDArray, - sim_max: NDArray, - *, - _edge_clip_tol: Optional[float] = None, - _dtype: Optional[type] = None, - ) -> tuple[NDArray, NDArray, NDArray]: - """ - Compute parametric bounds for multiple segments clipped to simulation bounds. - - Parameters - ---------- - segment_starts : NDArray - (N, 3) array of segment start coordinates. - segment_ends : NDArray - (N, 3) array of segment end coordinates. - sim_min : NDArray - (3,) array of simulation minimum bounds. - sim_max : NDArray - (3,) array of simulation maximum bounds. - - Returns - ------- - is_within_bounds : NDArray - (N,) boolean array indicating if the segment intersects the bounds. - t_starts : NDArray - (N,) array of parametric start values (0.0 to 1.0). - t_ends : NDArray - (N,) array of parametric end values (0.0 to 1.0). - """ - n = segment_starts.shape[0] - if _edge_clip_tol is None: - _edge_clip_tol = config.adjoint.edge_clip_tolerance - if _dtype is None: - _dtype = config.adjoint.gradient_dtype_float - - t_starts = np.zeros(n, dtype=_dtype) - t_ends = np.ones(n, dtype=_dtype) - is_within_bounds = np.ones(n, dtype=bool) - - for dim in range(3): - start_coords = segment_starts[:, dim] - end_coords = segment_ends[:, dim] - bound_min = sim_min[dim] - bound_max = sim_max[dim] - - # check for parallel edges (faster than isclose) - parallel = np.abs(start_coords - end_coords) < 1e-12 - - # parallel edges: check if outside bounds - outside = parallel & ( - (start_coords < (bound_min - _edge_clip_tol)) - | (start_coords > (bound_max + _edge_clip_tol)) - ) - is_within_bounds &= ~outside - - # non-parallel edges: compute t_min, t_max - not_parallel = ~parallel & is_within_bounds - if np.any(not_parallel): - denom = np.where(not_parallel, end_coords - start_coords, 1.0) # avoid div by zero - t_min = (bound_min - start_coords) / denom - t_max = (bound_max - start_coords) / denom - - # swap if needed - swap = t_min > t_max - t_min_new = np.where(swap, t_max, t_min) - t_max_new = np.where(swap, t_min, t_max) - - # update t_starts and t_ends for valid non-parallel edges - t_starts = np.where(not_parallel, np.maximum(t_starts, t_min_new), t_starts) - t_ends = np.where(not_parallel, np.minimum(t_ends, t_max_new), t_ends) - - # still valid? 
- is_within_bounds &= ~not_parallel | (t_starts < t_ends) - - is_within_bounds &= t_ends > t_starts + _edge_clip_tol - - return is_within_bounds, t_starts, t_ends - - @staticmethod - def _adaptive_edge_samples( - L: float, - dx: float, - t_start: float = 0.0, - t_end: float = 1.0, - *, - _sample_fraction: Optional[float] = None, - _gauss_order: Optional[int] = None, - _dtype: Optional[type] = None, - ) -> tuple[NDArray, NDArray]: - """ - Compute Gauss samples and weights along [t_start, t_end] with adaptive count. - - Parameters - ---------- - L : float - Physical length of the full edge. - dx : float - Target discretization step size. - t_start : float, optional - Start parameter, by default 0.0. - t_end : float, optional - End parameter, by default 1.0. - - Returns - ------- - tuple[NDArray, NDArray] - Tuple of (samples, weights) for the integration. - """ - if _sample_fraction is None: - _sample_fraction = config.adjoint.quadrature_sample_fraction - if _gauss_order is None: - _gauss_order = config.adjoint.gauss_quadrature_order - if _dtype is None: - _dtype = config.adjoint.gradient_dtype_float - - L_eff = L * max(0.0, t_end - t_start) - n_uniform = max(1, int(np.ceil(L_eff / dx))) - n_gauss = n_uniform if n_uniform <= 3 else max(2, int(n_uniform * _sample_fraction)) - if n_gauss <= _gauss_order: - g, w = leggauss(n_gauss) - half_range = 0.5 * (t_end - t_start) - s = (half_range * g + 0.5 * (t_end + t_start)).astype(_dtype, copy=False) - wt = (w * half_range).astype(_dtype, copy=False) - return s, wt - - # composite Gauss with fixed local order - g_loc, w_loc = leggauss(_gauss_order) - segs = n_uniform - edges_t = np.linspace(t_start, t_end, segs + 1, dtype=_dtype) - - # compute all segments at once - a = edges_t[:-1] # (segs,) - b = edges_t[1:] # (segs,) - half_width = 0.5 * (b - a) # (segs,) - mid = 0.5 * (b + a) # (segs,) - - # (segs, 1) * (order,) + (segs, 1) -> (segs, order) - S = (half_width[:, None] * g_loc + mid[:, None]).astype(_dtype, copy=False) - W = (half_width[:, None] * w_loc).astype(_dtype, copy=False) - return S.ravel(), W.ravel() - - def _collect_sidewall_patches( - self, - vertices: NDArray, - next_v: NDArray, - edges: NDArray, - basis: dict, - sim_min: NDArray, - sim_max: NDArray, - is_2d: bool, - dx: float, - ) -> dict: - """ - Collect sidewall patch geometry for batched VJP evaluation. - - Parameters - ---------- - vertices : NDArray - Array of polygon vertices. - next_v : NDArray - Array of next vertices (forming edges). - edges : NDArray - Edge vectors. - basis : dict - Basis vectors dictionary. - sim_min : NDArray - Simulation minimum bounds. - sim_max : NDArray - Simulation maximum bounds. - is_2d : bool - Whether the simulation is 2D. - dx : float - Discretization step. - - Returns - ------- - dict - Dictionary containing: - - centers: (N, 3) array of patch centers. - - normals: (N, 3) array of patch normals. - - perps1: (N, 3) array of first tangent vectors. - - perps2: (N, 3) array of second tangent vectors. - - Ls: (N,) array of edge lengths. - - s_vals: (N,) array of parametric coordinates along the edge. - - s_weights: (N,) array of quadrature weights. - - zc_vals: (N,) array of z-coordinates. - - dz: float, slice thickness. - - edge_indices: (N,) array of original edge indices. 
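A quick check of the Gauss-Legendre interval mapping used by `_adaptive_edge_samples` above: two nodes mapped from [-1, 1] onto [0, 1] integrate t**2 exactly:

    import numpy as np

    g, w = np.polynomial.legendre.leggauss(2)
    t = 0.5 * (g + 1.0)        # nodes mapped to [0, 1]
    wt = 0.5 * w               # weights scaled by the half-width
    print(np.sum(wt * t**2))   # 0.3333..., exact for polynomials of degree <= 3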
- """ - # cache config values to avoid repeated lookups (overhead not insignificant here) - _dtype = config.adjoint.gradient_dtype_float - _edge_clip_tol = config.adjoint.edge_clip_tolerance - _sample_fraction = config.adjoint.quadrature_sample_fraction - _gauss_order = config.adjoint.gauss_quadrature_order - - theta = get_static(self.sidewall_angle) - z_ref = self.reference_axis_pos - - cos_th = np.cos(theta) - cos_th = np.clip(cos_th, 1e-12, 1.0) - tan_th = np.tan(theta) - dprime = -tan_th # dd/dz - - # axis unit vector in 3D - axis_vec = np.zeros(3, dtype=_dtype) - axis_vec[self.axis] = 1.0 - - # densify along axis as |theta| grows, dz scales with cos(theta) - z_centers, dz, z0, z1 = self._z_slices(sim_min, sim_max, is_2d=is_2d, dx=dx * cos_th) - - # early exit: no slices - if (not is_2d) and len(z_centers) == 0: - return { - "centers": np.empty((0, 3), dtype=_dtype), - "normals": np.empty((0, 3), dtype=_dtype), - "perps1": np.empty((0, 3), dtype=_dtype), - "perps2": np.empty((0, 3), dtype=_dtype), - "Ls": np.empty((0,), dtype=_dtype), - "s_vals": np.empty((0,), dtype=_dtype), - "s_weights": np.empty((0,), dtype=_dtype), - "zc_vals": np.empty((0,), dtype=_dtype), - "dz": dz, - "edge_indices": np.empty((0,), dtype=int), - } - - # estimate patches for pre-allocation - n_edges = len(vertices) - estimated_patches = 0 - denom_edge = max(dx * cos_th, 1e-12) - for ei in range(n_edges): - v0, v1 = vertices[ei], next_v[ei] - L = np.linalg.norm(v1 - v0) - if not np.isclose(L, 0.0): - # prealloc guided by actual step; ds_phys scales with cos(theta) - n_samples = max(1, int(np.ceil(L / denom_edge) * 0.6)) - estimated_patches += n_samples * max(1, len(z_centers)) - estimated_patches = int(max(1, estimated_patches) * 1.2) - - # pre-allocate arrays - centers = np.empty((estimated_patches, 3), dtype=_dtype) - normals = np.empty((estimated_patches, 3), dtype=_dtype) - perps1 = np.empty((estimated_patches, 3), dtype=_dtype) - perps2 = np.empty((estimated_patches, 3), dtype=_dtype) - Ls = np.empty((estimated_patches,), dtype=_dtype) - s_vals = np.empty((estimated_patches,), dtype=_dtype) - s_weights = np.empty((estimated_patches,), dtype=_dtype) - zc_vals = np.empty((estimated_patches,), dtype=_dtype) - edge_indices = np.empty((estimated_patches,), dtype=int) - - patch_idx = 0 - - # if the simulation is effectively 2D (one tangential dimension collapsed), - # slightly expand degenerate bounds to enable finite-length clipping of edges. 
- sim_min_eff = np.array(sim_min, dtype=_dtype) - sim_max_eff = np.array(sim_max, dtype=_dtype) - for dim in range(3): - if dim == self.axis: - continue - if np.isclose(sim_max_eff[dim] - sim_min_eff[dim], 0.0): - sim_min_eff[dim] -= 0.5 * dx - sim_max_eff[dim] += 0.5 * dx - - # pre-compute values that are constant across z slices - n_z = len(z_centers) - z_centers_arr = np.asarray(z_centers, dtype=_dtype) - - # slanted local basis (constant across z for non-slanted case) - # for slanted: rz = axis_vec + dprime * n2d, but dprime is constant - for ei, (v0, v1) in enumerate(zip(vertices, next_v)): - edge_vec = v1 - v0 - L = np.sqrt(np.dot(edge_vec, edge_vec)) - if L < 1e-12: - continue - - # constant along edge: unit tangent in 3D (no axis component) - t_edge = basis["perp1"][ei] - - # outward in-plane normal from canonical basis normal - n2d = basis["norm"][ei].copy() - n2d[self.axis] = 0.0 - nrm = np.linalg.norm(n2d) - if not np.isclose(nrm, 0.0): - n2d = n2d / nrm - else: - # fallback to right-handed construction if degenerate - tmp = np.cross(axis_vec, t_edge) - n2d = tmp / (np.linalg.norm(tmp) + 1e-20) - - # compute basis vectors once per edge - rz = axis_vec + dprime * n2d - T1_vec = t_edge - N_vec = np.cross(T1_vec, rz) - N_norm = np.linalg.norm(N_vec) - if not np.isclose(N_norm, 0.0): - N_vec = N_vec / N_norm - - # align N with outward edge normal - if float(np.dot(N_vec, basis["norm"][ei])) < 0.0: - N_vec = -N_vec - - T2_vec = np.cross(N_vec, T1_vec) - T2_norm = np.linalg.norm(T2_vec) - if not np.isclose(T2_norm, 0.0): - T2_vec = T2_vec / T2_norm - - # batch compute offsets for all z slices at once - d_all = -(z_centers_arr - z_ref) * tan_th # (n_z,) - offsets_3d = d_all[:, None] * n2d # (n_z, 3) - faster than np.outer - - # batch compute segment starts and ends for all z slices - segment_starts = np.empty((n_z, 3), dtype=_dtype) - segment_ends = np.empty((n_z, 3), dtype=_dtype) - plane_axes = [i for i in range(3) if i != self.axis] - segment_starts[:, self.axis] = z_centers_arr - segment_starts[:, plane_axes] = v0 - segment_starts += offsets_3d - segment_ends[:, self.axis] = z_centers_arr - segment_ends[:, plane_axes] = v1 - segment_ends += offsets_3d - - # batch clip all z slices at once - is_within_bounds, t_starts, t_ends = self._clip_edges_to_bounds_batch( - segment_starts, - segment_ends, - sim_min_eff, - sim_max_eff, - _edge_clip_tol=_edge_clip_tol, - _dtype=_dtype, - ) - - # process only valid z slices (sampling has variable output sizes) - valid_indices = np.nonzero(is_within_bounds)[0] - if len(valid_indices) == 0: - continue - - # group z slices by unique (t0, t1) pairs to avoid redundant quadrature calculations. - # since most z-slices will have identical clipping bounds (0.0, 1.0), - # we can compute the Gauss samples once and reuse them for almost all slices. - # rounding ensures we get cache hits despite tiny floating point differences. 
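The caching idea described above, in isolation (`expensive_samples` is a hypothetical stand-in for the quadrature call):

    def expensive_samples(t0: float, t1: float):
        ...   # e.g. build Gauss nodes/weights on [t0, t1]

    cache = {}

    def samples_for(t0: float, t1: float):
        # collapse float noise so near-identical clip bounds hit the same entry
        key = (round(t0, 10), round(t1, 10))
        if key not in cache:
            cache[key] = expensive_samples(*key)
        return cache[key]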
- t0_valid = np.round(t_starts[valid_indices], 10) - t1_valid = np.round(t_ends[valid_indices], 10) - - # simple cache for sampling results: (t0, t1) -> (s_list, w_list) - sample_cache = {} - - # process each z slice - for zi, t0, t1 in zip(valid_indices, t0_valid, t1_valid): - if (t0, t1) not in sample_cache: - sample_cache[(t0, t1)] = self._adaptive_edge_samples( - L, - denom_edge, - t0, - t1, - _sample_fraction=_sample_fraction, - _gauss_order=_gauss_order, - _dtype=_dtype, - ) - - s_list, w_list = sample_cache[(t0, t1)] - if len(s_list) == 0: - continue - - zc = z_centers_arr[zi] - offset3d = offsets_3d[zi] - - pts2d = v0 + s_list[:, None] * edge_vec # faster than np.outer - - # inline unpop_axis_vect for xyz computation - n_pts = len(s_list) - xyz = np.empty((n_pts, 3), dtype=_dtype) - xyz[:, self.axis] = zc - xyz[:, plane_axes] = pts2d - xyz += offset3d - - n_patches = n_pts - new_size_needed = patch_idx + n_patches - if new_size_needed > centers.shape[0]: - # grow arrays by 1.5x to avoid frequent reallocations - new_size = int(new_size_needed * 1.5) - centers.resize((new_size, 3), refcheck=False) - normals.resize((new_size, 3), refcheck=False) - perps1.resize((new_size, 3), refcheck=False) - perps2.resize((new_size, 3), refcheck=False) - Ls.resize((new_size,), refcheck=False) - s_vals.resize((new_size,), refcheck=False) - s_weights.resize((new_size,), refcheck=False) - zc_vals.resize((new_size,), refcheck=False) - edge_indices.resize((new_size,), refcheck=False) - - sl = slice(patch_idx, patch_idx + n_patches) - centers[sl] = xyz - normals[sl] = N_vec - perps1[sl] = T1_vec - perps2[sl] = T2_vec - Ls[sl] = L - s_vals[sl] = s_list - s_weights[sl] = w_list - zc_vals[sl] = zc - edge_indices[sl] = ei - - patch_idx += n_patches - - # trim arrays to final size - centers = centers[:patch_idx] - normals = normals[:patch_idx] - perps1 = perps1[:patch_idx] - perps2 = perps2[:patch_idx] - Ls = Ls[:patch_idx] - s_vals = s_vals[:patch_idx] - s_weights = s_weights[:patch_idx] - zc_vals = zc_vals[:patch_idx] - edge_indices = edge_indices[:patch_idx] - - return { - "centers": centers, - "normals": normals, - "perps1": perps1, - "perps2": perps2, - "Ls": Ls, - "s_vals": s_vals, - "s_weights": s_weights, - "zc_vals": zc_vals, - "dz": dz, - "edge_indices": edge_indices, - } - - def _compute_derivative_sidewall_angle( - self, - derivative_info: DerivativeInfo, - sim_min: NDArray, - sim_max: NDArray, - is_2d: bool = False, - interpolators: Optional[dict] = None, - ) -> float: - """VJP for dJ/dtheta where theta = sidewall_angle. - - Use dJ/dtheta = integral_S g(x) * V_n(x; theta) * dA, with g(x) from - `evaluate_gradient_at_points`. For a ruled sidewall built by - offsetting the mid-plane polygon by d(z) = -(z - z_ref) * tan(theta), - the normal velocity is V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) - and the area element is dA = (dz/cos(theta)) * d_ell. - Therefore each patch weight is w = L * dz * (-(z - z_ref)) / cos(theta)^2. 
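A symbolic sanity check of the patch-weight factor quoted above, with d(z) = -(z - z_ref)*tan(theta):

    import sympy as sp

    z, z_ref, theta = sp.symbols("z z_ref theta", real=True)
    d = -(z - z_ref) * sp.tan(theta)
    V_n = sp.diff(d, theta) * sp.cos(theta)   # normal velocity of the ruled sidewall
    dA_factor = 1 / sp.cos(theta)             # area element per (dz * d_ell) on the slanted strip
    print(sp.simplify(V_n * dA_factor))       # (z_ref - z)/cos(theta)**2 == -(z - z_ref)/cos(theta)**2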
- """ - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - # 2D sim => no dependence on theta (z_local=0) - if is_2d: - return 0.0 - - vertices, next_v, edges, basis = self._edge_geometry_arrays() - - dx = derivative_info.adaptive_vjp_spacing() - - # collect patches once - patch = self._collect_sidewall_patches( - vertices=vertices, - next_v=next_v, - edges=edges, - basis=basis, - sim_min=sim_min, - sim_max=sim_max, - is_2d=False, - dx=dx, - ) - if patch["centers"].shape[0] == 0: - return 0.0 - - # Shape-derivative factors: - # - Offset: d(z) = -(z - z_ref) * tan(theta) - # - Tangential rate: dd/dtheta = -(z - z_ref) * sec(theta)^2 - # - Normal velocity (project to surface normal): V_n = (dd/dtheta) * cos(theta) = -(z - z_ref)/cos(theta) - # - Area element of slanted strip: dA = (dz/cos(theta)) * d_ell - # => Patch weight scales as: V_n * dA = -(z - z_ref) * dz * d_ell / cos(theta)^2 - cos_theta = np.cos(get_static(self.sidewall_angle)) - inv_cos2 = 1.0 / (cos_theta * cos_theta) - z_ref = self.reference_axis_pos - - g = derivative_info.evaluate_gradient_at_points( - patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators - ) - z_local = patch["zc_vals"] - z_ref - weights = patch["Ls"] * patch["s_weights"] * patch["dz"] * (-z_local) * inv_cos2 - return float(np.real(np.sum(g * weights))) - - def _compute_derivative_slab_bounds( - self, derivative_info: DerivativeInfo, min_max_index: int, interpolators: dict - ) -> float: - """VJP for one of the two horizontal faces of a ``PolySlab``. - - The face is discretized into a Cartesian grid of small planar patches - whose linear size does not exceed ``_VJP_SAMPLE_SPACING``. The adjoint surface - integral is evaluated on every retained patch; the resulting derivative - is split equally between the two vertices that bound the edge segment. - """ - # rmin/rmax over the geometry and simulation box - if np.isclose(self.slab_bounds[1] - self.slab_bounds[0], 0.0): - log.warning( - "Computing slab face derivatives for flat structures is not fully supported and " - "may give zero for the derivative. Try using a structure with a small, but nonzero " - "thickness for slab bound derivatives." 
- ) - rmin, rmax = derivative_info.bounds_intersect - _, (r1_min, r2_min) = self.pop_axis(rmin, axis=self.axis) - _, (r1_max, r2_max) = self.pop_axis(rmax, axis=self.axis) - ax_val = self.slab_bounds[min_max_index] - - # planar grid resolution, clipped to polygon bounding box - face_verts = self.base_polygon if min_max_index == 0 else self.top_polygon - face_poly = shapely.Polygon(face_verts).buffer(fp_eps) - - # limit the patch grid to the face that lives inside the simulation box - poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds - r1_min = max(r1_min, poly_min_r1) - r1_max = min(r1_max, poly_max_r1) - r2_min = max(r2_min, poly_min_r2) - r2_max = min(r2_max, poly_max_r2) - - # intersect the polygon with the simulation bounds - face_poly = face_poly.intersection(shapely.box(r1_min, r2_min, r1_max, r2_max)) - - if (r1_max <= r1_min) and (r2_max <= r2_min): - # the polygon does not intersect the current simulation slice - return 0.0 - - # re-compute the extents after clipping to the polygon bounds - extents = np.array([r1_max - r1_min, r2_max - r2_min]) - - # choose surface or line integral - integral_fun = ( - self.compute_derivative_slab_bounds_line - if np.isclose(extents, 0).any() - else self.compute_derivative_slab_bounds_surface - ) - return integral_fun( - derivative_info, - extents, - r1_min, - r1_max, - r2_min, - r2_max, - ax_val, - face_poly, - min_max_index, - interpolators, - ) - - def compute_derivative_slab_bounds_line( - self, - derivative_info: DerivativeInfo, - extents: NDArray, - r1_min: float, - r1_max: float, - r2_min: float, - r2_max: float, - ax_val: float, - face_poly: shapely.Polygon, - min_max_index: int, - interpolators: dict, - ) -> float: - """Handle degenerate line cross-section case""" - line_dim = 1 if np.isclose(extents[0], 0) else 0 - - poly_min_r1, poly_min_r2, poly_max_r1, poly_max_r2 = face_poly.bounds - if line_dim == 0: # x varies, y is fixed - l_min = max(r1_min, poly_min_r1) - l_max = min(r1_max, poly_max_r1) - else: # y varies, x is fixed - l_min = max(r2_min, poly_min_r2) - l_max = min(r2_max, poly_max_r2) - - length = l_max - l_min - if np.isclose(length, 0): - return 0.0 - - dx = derivative_info.adaptive_vjp_spacing() - n_seg = max(1, int(np.ceil(length / dx))) - coords = np.linspace( - l_min, l_max, 2 * n_seg + 1, dtype=config.adjoint.gradient_dtype_float - )[1::2] - - # build XY coordinates and in-plane direction vectors - if line_dim == 0: - xy = np.column_stack((coords, np.full_like(coords, r2_min))) - dir_vec_plane = np.column_stack((np.ones_like(coords), np.zeros_like(coords))) - else: - xy = np.column_stack((np.full_like(coords, r1_min), coords)) - dir_vec_plane = np.column_stack((np.zeros_like(coords), np.ones_like(coords))) - - inside = shapely.contains_xy(face_poly, xy[:, 0], xy[:, 1]) - if not inside.any(): - return 0.0 - - xy = xy[inside] - dir_vec_plane = dir_vec_plane[inside] - n_pts = len(xy) - - centers_xyz = self.unpop_axis_vect(np.full(n_pts, ax_val), xy) - areas = np.full(n_pts, length / n_seg) # patch length - - normals_xyz = self.unpop_axis_vect( - np.full(n_pts, -1 if min_max_index == 0 else 1), - np.zeros_like(xy), - ) - perps1_xyz = self.unpop_axis_vect(np.zeros(n_pts), dir_vec_plane) - perps2_xyz = self.unpop_axis_vect(np.zeros(n_pts), np.zeros_like(dir_vec_plane)) - - vjps = derivative_info.evaluate_gradient_at_points( - centers_xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators - ) - return np.real(np.sum(vjps * areas)).item() - - def compute_derivative_slab_bounds_surface( - self, - 
derivative_info: DerivativeInfo, - extents: NDArray, - r1_min: float, - r1_max: float, - r2_min: float, - r2_max: float, - ax_val: float, - face_poly: shapely.Polygon, - min_max_index: int, - interpolators: dict, - ) -> float: - """2d surface integral on a Gauss quadrature grid""" - dx = derivative_info.adaptive_vjp_spacing() - - # uniform grid would use n1 x n2 points - n1_uniform, n2_uniform = np.maximum(1, np.ceil(extents / dx).astype(int)) - - # use ~1/2 Gauss points in each direction for similar accuracy - n1 = max(2, n1_uniform // 2) - n2 = max(2, n2_uniform // 2) - - g1, w1 = leggauss(n1) - g2, w2 = leggauss(n2) - - coords1 = (0.5 * (r1_max - r1_min) * g1 + 0.5 * (r1_max + r1_min)).astype( - config.adjoint.gradient_dtype_float, copy=False - ) - coords2 = (0.5 * (r2_max - r2_min) * g2 + 0.5 * (r2_max + r2_min)).astype( - config.adjoint.gradient_dtype_float, copy=False - ) - - r1_grid, r2_grid = np.meshgrid(coords1, coords2, indexing="ij") - r1_flat = r1_grid.flatten() - r2_flat = r2_grid.flatten() - pts = np.column_stack((r1_flat, r2_flat)) - - in_face = shapely.contains_xy(face_poly, pts[:, 0], pts[:, 1]) - if not in_face.any(): - return 0.0 - - xyz = self.unpop_axis_vect( - np.full(in_face.sum(), ax_val, dtype=config.adjoint.gradient_dtype_float), pts[in_face] - ) - n_patches = xyz.shape[0] - - normals_xyz = self.unpop_axis_vect( - np.full( - n_patches, - -1 if min_max_index == 0 else 1, - dtype=config.adjoint.gradient_dtype_float, - ), - np.zeros((n_patches, 2), dtype=config.adjoint.gradient_dtype_float), - ) - perps1_xyz = self.unpop_axis_vect( - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.column_stack( - ( - np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - ) - ), - ) - perps2_xyz = self.unpop_axis_vect( - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.column_stack( - ( - np.zeros(n_patches, dtype=config.adjoint.gradient_dtype_float), - np.ones(n_patches, dtype=config.adjoint.gradient_dtype_float), - ) - ), - ) - - w1_grid, w2_grid = np.meshgrid(w1, w2, indexing="ij") - weights_flat = (w1_grid * w2_grid).flatten()[in_face] - jacobian = 0.25 * (r1_max - r1_min) * (r2_max - r2_min) - - # area-based correction for non-rectangular domains (e.g. concave polygon) - # for constant integrand, integral should equal polygon area - sum_weights = np.sum(weights_flat) - if sum_weights > 0: - area_correction = face_poly.area / (sum_weights * jacobian) - weights_flat = weights_flat * area_correction - - vjps = derivative_info.evaluate_gradient_at_points( - xyz, normals_xyz, perps1_xyz, perps2_xyz, interpolators - ) - return np.real(np.sum(vjps * weights_flat * jacobian)).item() - - def _compute_derivative_vertices( - self, - derivative_info: DerivativeInfo, - sim_min: NDArray, - sim_max: NDArray, - is_2d: bool = False, - interpolators: Optional[dict] = None, - ) -> NDArray: - """VJP for the vertices of a ``PolySlab``. - - Uses shared sidewall patch collection and batched field evaluation. 
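Stepping back to `compute_derivative_slab_bounds_surface` above: the area-corrected Gauss quadrature is easy to exercise in isolation. Below is a minimal sketch of the same idea, assuming a vectorized integrand; the helper name `gauss_area_corrected_integral` is ours, not part of the codebase. Points falling outside the polygon are dropped, and the surviving weights are rescaled so that a constant integrand integrates to exactly the polygon area:

    import numpy as np
    import shapely
    from numpy.polynomial.legendre import leggauss

    def gauss_area_corrected_integral(f, poly: shapely.Polygon, n1: int, n2: int) -> float:
        """Integrate vectorized f(x, y) over poly on a tensor Gauss-Legendre grid."""
        x0, y0, x1, y1 = poly.bounds
        gx, wx = leggauss(n1)
        gy, wy = leggauss(n2)
        # map Gauss nodes from [-1, 1] onto the polygon's bounding box
        xs = 0.5 * (x1 - x0) * gx + 0.5 * (x1 + x0)
        ys = 0.5 * (y1 - y0) * gy + 0.5 * (y1 + y0)
        X, Y = np.meshgrid(xs, ys, indexing="ij")
        jac = 0.25 * (x1 - x0) * (y1 - y0)  # Jacobian of the [-1, 1]^2 -> bounding box map
        inside = shapely.contains_xy(poly, X.ravel(), Y.ravel())
        if not inside.any():
            return 0.0
        w = np.outer(wx, wy).ravel()[inside]
        w *= poly.area / (w.sum() * jac)  # constant integrand -> exact polygon area
        return float(np.sum(f(X.ravel()[inside], Y.ravel()[inside]) * w * jac))

For instance, integrating f(x, y) = 1 over shapely.Point(0, 0).buffer(1.0) recovers the buffer's polygonal area even though the Gauss grid spans the full bounding box.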
- """ - vertices, next_v, edges, basis = self._edge_geometry_arrays() - dx = derivative_info.adaptive_vjp_spacing() - - # collect patches once - patch = self._collect_sidewall_patches( - vertices=vertices, - next_v=next_v, - edges=edges, - basis=basis, - sim_min=sim_min, - sim_max=sim_max, - is_2d=is_2d, - dx=dx, - ) - - # early return if no patches - if patch["centers"].shape[0] == 0: - return np.zeros_like(vertices) - - dz = patch["dz"] - dz_surf = 1.0 if is_2d else dz / np.cos(self.sidewall_angle) - - # use provided interpolators or create them if not provided - if interpolators is None: - interpolators = derivative_info.create_interpolators( - dtype=config.adjoint.gradient_dtype_float - ) - - # evaluate integrand - g = derivative_info.evaluate_gradient_at_points( - patch["centers"], patch["normals"], patch["perps1"], patch["perps2"], interpolators - ) - - # compute area-based weights and weighted vjps - areas = patch["Ls"] * patch["s_weights"] * dz_surf - patch_vjps = (g * areas).real - - # distribute to vertices using vectorized accumulation - normals_2d = np.delete(basis["norm"], self.axis, axis=1) - edge_idx = patch["edge_indices"] - s = patch["s_vals"] - w0 = (1.0 - s) * patch_vjps - w1 = s * patch_vjps - edge_norms = normals_2d[edge_idx] - - # Accumulate per-vertex contributions using bincount (O(N_patches)) - num_vertices = vertices.shape[0] - contrib0 = w0[:, None] * edge_norms # (n_patches, 2) - contrib1 = w1[:, None] * edge_norms # (n_patches, 2) - - idx0 = edge_idx - idx1 = (edge_idx + 1) % num_vertices - - v0x = np.bincount(idx0, weights=contrib0[:, 0], minlength=num_vertices) - v0y = np.bincount(idx0, weights=contrib0[:, 1], minlength=num_vertices) - v1x = np.bincount(idx1, weights=contrib1[:, 0], minlength=num_vertices) - v1y = np.bincount(idx1, weights=contrib1[:, 1], minlength=num_vertices) - - vjp_per_vertex = np.stack((v0x + v1x, v0y + v1y), axis=1) - return vjp_per_vertex - - def _edge_geometry_arrays( - self, dtype: np.dtype = config.adjoint.gradient_dtype_float - ) -> tuple[NDArray, NDArray, NDArray, dict[str, NDArray]]: - """Return (vertices, next_v, edges, basis) arrays for sidewall edge geometry.""" - vertices = np.asarray(self.vertices, dtype=dtype) - next_v = np.roll(vertices, -1, axis=0) - edges = next_v - vertices - basis = self.edge_basis_vectors(edges) - return vertices, next_v, edges, basis - - def edge_basis_vectors( - self, - edges: NDArray, # (N, 2) - ) -> dict[str, NDArray]: # (N, 3) - """Normalized basis vectors for ``normal`` direction, ``slab`` tangent direction and ``edge``.""" - - # ensure edges have consistent dtype - edges = edges.astype(config.adjoint.gradient_dtype_float, copy=False) - - num_vertices, _ = edges.shape - zeros = np.zeros(num_vertices, dtype=config.adjoint.gradient_dtype_float) - ones = np.ones(num_vertices, dtype=config.adjoint.gradient_dtype_float) - - # normalized vectors along edges - edges_norm_in_plane = self.normalize_vect(edges) - edges_norm_xyz = self.unpop_axis_vect(zeros, edges_norm_in_plane) - - # normalized vectors from base of edges to tops of edges - cos_angle = np.cos(self.sidewall_angle) - sin_angle = np.sin(self.sidewall_angle) - slabs_axis_components = cos_angle * ones - - # create axis_norm as array directly to avoid tuple->array conversion in np.cross - axis_norm = np.zeros(3, dtype=config.adjoint.gradient_dtype_float) - axis_norm[self.axis] = 1.0 - slab_normal_xyz = -sin_angle * np.cross(edges_norm_xyz, axis_norm) - _, slab_normal_in_plane = self.pop_axis_vect(slab_normal_xyz) - slabs_norm_xyz = 
self.unpop_axis_vect(slabs_axis_components, slab_normal_in_plane) - - # normalized vectors pointing in normal direction of edge - # cross yields inward normal when the extrusion axis is y, so negate once for axis==1 - sign = (-1 if self.axis == 1 else 1) * (-1 if not self.is_ccw else 1) - normals_norm_xyz = sign * np.cross(edges_norm_xyz, slabs_norm_xyz) - - return { - "norm": normals_norm_xyz, - "perp1": edges_norm_xyz, - "perp2": slabs_norm_xyz, - } - - def unpop_axis_vect(self, ax_coords: NDArray, plane_coords: NDArray) -> NDArray: - """Combine coordinate along axis with coordinates on the plane tangent to the axis. - - ax_coords.shape == [N] - plane_coords.shape == [N, 2] - return shape == [N, 3] - """ - n_pts = ax_coords.shape[0] - arr_xyz = np.zeros((n_pts, 3), dtype=ax_coords.dtype) - - plane_axes = [i for i in range(3) if i != self.axis] - - arr_xyz[:, self.axis] = ax_coords - arr_xyz[:, plane_axes] = plane_coords - - return arr_xyz - - def pop_axis_vect(self, coord: NDArray) -> tuple[NDArray, tuple[NDArray, NDArray]]: - """Split coordinates into the coordinate along the axis and the coordinates on the plane tangent to the axis. - - coord.shape == [N, 3] - return shape == ([N], [N, 2]) - """ - - arr_axis, arrs_plane = self.pop_axis(coord.T, axis=self.axis) - arrs_plane = np.array(arrs_plane).T - - return arr_axis, arrs_plane - - @staticmethod - def normalize_vect(arr: NDArray) -> NDArray: - """Normalize an array shaped (N, d) along the `d` axis and return the (N, d) result.""" - norm = np.linalg.norm(arr, axis=-1, keepdims=True) - norm = np.where(norm == 0, 1, norm) - return arr / norm - - def translated(self, x: float, y: float, z: float) -> PolySlab: - """Return a translated copy of this geometry. - - Parameters - ---------- - x : float - Translation along x. - y : float - Translation along y. - z : float - Translation along z. - - Returns - ------- - :class:`PolySlab` - Translated copy of this ``PolySlab``. - """ - - t_normal, t_plane = self.pop_axis((x, y, z), axis=self.axis) - translated_vertices = np.array(self.vertices) + np.array(t_plane)[None, :] - translated_slab_bounds = (self.slab_bounds[0] + t_normal, self.slab_bounds[1] + t_normal) - return self.updated_copy(vertices=translated_vertices, slab_bounds=translated_slab_bounds) - - def scaled(self, x: float = 1.0, y: float = 1.0, z: float = 1.0) -> PolySlab: - """Return a scaled copy of this geometry. - - Parameters - ---------- - x : float = 1.0 - Scaling factor along x. - y : float = 1.0 - Scaling factor along y. - z : float = 1.0 - Scaling factor along z. - - Returns - ------- - :class:`PolySlab` - Scaled copy of this ``PolySlab``. - """ - scale_normal, scale_in_plane = self.pop_axis((x, y, z), axis=self.axis) - scaled_vertices = self.vertices * np.array(scale_in_plane) - scaled_slab_bounds = tuple(scale_normal * bound for bound in self.slab_bounds) - return self.updated_copy(vertices=scaled_vertices, slab_bounds=scaled_slab_bounds) - - def rotated(self, angle: float, axis: Union[Axis, Coordinate]) -> PolySlab: - """Return a rotated copy of this geometry. - - Parameters - ---------- - angle : float - Rotation angle (in radians). - axis : Union[int, Tuple[float, float, float]] - Axis of rotation: 0, 1, or 2 for x, y, and z, respectively, or a 3D vector. - - Returns - ------- - :class:`PolySlab` - Rotated copy of this ``PolySlab``.
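Together, `translated`, `scaled`, and `rotated` give affine control of a `PolySlab` without leaving the class. A short usage sketch (the constructor values are illustrative):

    import numpy as np
    import tidy3d as td

    # a unit square extruded along z
    slab = td.PolySlab(vertices=[(0, 0), (1, 0), (1, 1), (0, 1)], slab_bounds=(-0.5, 0.5), axis=2)

    moved = slab.translated(x=1.0, y=0.0, z=0.25)  # shifts vertices in-plane and slab_bounds along z
    bigger = slab.scaled(x=2.0, y=2.0, z=1.0)      # in-plane scaling; z=1.0 leaves slab_bounds unchanged
    spun = slab.rotated(angle=np.pi / 4, axis=2)   # rotation about the slab axis stays a PolySlab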
- """ - _, plane_axs = self.pop_axis([0, 1, 2], self.axis) - if (isinstance(axis, int) and axis == self.axis) or ( - isinstance(axis, tuple) and all(axis[ax] == 0 for ax in plane_axs) - ): - verts_3d = np.zeros((3, self.vertices.shape[0])) - verts_3d[plane_axs[0], :] = self.vertices[:, 0] - verts_3d[plane_axs[1], :] = self.vertices[:, 1] - rotation = RotationAroundAxis(angle=angle, axis=axis) - rotated_vertices = rotation.rotate_vector(verts_3d) - rotated_vertices = rotated_vertices[plane_axs, :].T - return self.updated_copy(vertices=rotated_vertices) - - return super().rotated(angle=angle, axis=axis) - - def reflected(self, normal: Coordinate) -> PolySlab: - """Return a reflected copy of this geometry. - - Parameters - ---------- - normal : Tuple[float, float, float] - The 3D normal vector of the plane of reflection. The plane is assumed - to pass through the origin (0,0,0). - - Returns - ------- - ------- - :class:`PolySlab` - Reflected copy of this ``PolySlab``. - """ - if math.isclose(normal[self.axis], 0): - _, plane_axs = self.pop_axis((0, 1, 2), self.axis) - verts_3d = np.zeros((3, self.vertices.shape[0])) - verts_3d[plane_axs[0], :] = self.vertices[:, 0] - verts_3d[plane_axs[1], :] = self.vertices[:, 1] - reflection = ReflectionFromPlane(normal=normal) - reflected_vertices = reflection.reflect_vector(verts_3d) - reflected_vertices = reflected_vertices[plane_axs, :].T - return self.updated_copy(vertices=reflected_vertices) - - return super().reflected(normal=normal) - - -class ComplexPolySlabBase(PolySlab): - """Interface for dividing a complex polyslab where self-intersecting polygon can - occur during extrusion. This class should not be used directly. Use instead - :class:`plugins.polyslab.ComplexPolySlab`.""" - - @pydantic.validator("vertices", always=True) - def no_self_intersecting_polygon_during_extrusion( - cls, val: ArrayFloat2D, values: dict[str, Any] - ) -> ArrayFloat2D: - """Turn off the validation for this class.""" - return val - - @classmethod - def from_gds( - cls, - gds_cell: Cell, - axis: Axis, - slab_bounds: tuple[float, float], - gds_layer: int, - gds_dtype: Optional[int] = None, - gds_scale: pydantic.PositiveFloat = 1.0, - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", - ) -> list[PolySlab]: - """Import :class:`.PolySlab` from a ``gdstk.Cell``. - - Parameters - ---------- - gds_cell : gdstk.Cell - ``gdstk.Cell`` containing 2D geometric data. - axis : int - Integer index into the polygon's slab axis. (0,1,2) -> (x,y,z). - slab_bounds: Tuple[float, float] - Minimum and maximum positions of the slab along ``axis``. - gds_layer : int - Layer index in the ``gds_cell``. - gds_dtype : int = None - Data-type index in the ``gds_cell``. - If ``None``, imports all data for this layer into the returned list. - gds_scale : float = 1.0 - Length scale used in GDS file in units of MICROMETER. - For example, if gds file uses nanometers, set ``gds_scale=1e-3``. - Must be positive. - dilation : float = 0.0 - Dilation of the polygon in the base by shifting each edge along its - normal outwards direction by a distance; - a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the sidewall. - ``sidewall_angle=0`` (default) specifies vertical wall, - while ``0 base.GeometryGroup: - """Divide a complex polyslab into a list of simple polyslabs, which - are assembled into a :class:`.GeometryGroup`. 
- - Returns - ------- - :class:`.GeometryGroup` - GeometryGroup for a list of simple polyslabs divided from the complex - polyslab. - """ - return base.GeometryGroup(geometries=self.sub_polyslabs) - - @property - def sub_polyslabs(self) -> list[PolySlab]: - """Divide a complex polyslab into a list of simple polyslabs. - Only neighboring vertex-vertex crossing events are treated in this - version. - - Returns - ------- - List[PolySlab] - A list of simple polyslabs. - """ - sub_polyslab_list = [] - num_division_count = 0 - # initialize sub-polyslab parameters - sub_polyslab_dict = self.dict(exclude={"type"}).copy() - if math.isclose(self.sidewall_angle, 0): - return [PolySlab.parse_obj(sub_polyslab_dict)] - - sub_polyslab_dict.update({"dilation": 0}) # dilation accounted for in setup - # initialize offset distance - offset_distance = 0 - - for dist_val in self._dilation_length: - dist_now = 0.0 - vertices_now = self.reference_polygon - - # construct sub-polyslabs until reaching the base/top - while not math.isclose(dist_now, dist_val): - # bounds for sub-polyslabs assuming no self-intersection - slab_bounds = [ - self._dilation_value_at_reference_to_coord(dist_now), - self._dilation_value_at_reference_to_coord(dist_val), - ] - # 1) find any vertex-touching events between the current - # position and the base/top - max_dist = PolySlab._neighbor_vertices_crossing_detection( - vertices_now, dist_val - dist_now - ) - - # vertex-touching events captured, update bounds for sub-polyslab - if max_dist is not None: - # max_dist doesn't have a sign, so construct the signed offset distance - offset_distance = max_dist * dist_val / abs(dist_val) - slab_bounds[1] = self._dilation_value_at_reference_to_coord( - dist_now + offset_distance - ) - - # 2) construct sub-polyslab - slab_bounds.sort() # for reference_plane=top/bottom, bounds need to be ordered - # direction of marching - reference_plane = "bottom" if dist_val / self._tanq < 0 else "top" - sub_polyslab_dict.update( - { - "slab_bounds": tuple(slab_bounds), - "vertices": vertices_now, - "reference_plane": reference_plane, - } - ) - sub_polyslab_list.append(PolySlab.parse_obj(sub_polyslab_dict)) - - # 3) stop if no touching event was found; otherwise march to the event - if max_dist is None: - break - dist_now += offset_distance - # new polygon vertices where collapsing vertices are removed, keeping one - vertices_now = PolySlab._shift_vertices(vertices_now, offset_distance)[0] - vertices_now = PolySlab._remove_duplicate_vertices(vertices_now) - # all vertices collapse - if len(vertices_now) < 3: - break - # polygon collapses into 1D - if self.make_shapely_polygon(vertices_now).buffer(0).area < fp_eps: - break - vertices_now = PolySlab._orient(vertices_now) - num_division_count += 1 - - if num_division_count > _COMPLEX_POLYSLAB_DIVISIONS_WARN: - log.warning( - f"Too many self-intersecting events: the polyslab has been divided into " - f"{num_division_count} polyslabs; more than {_COMPLEX_POLYSLAB_DIVISIONS_WARN} may " - f"slow down the simulation."
- ) - - return sub_polyslab_list - - @property - def _dilation_length(self) -> list[float]: - """dilation length from reference plane to the top/bottom of the polyslab.""" - - # for "bottom", only needs to compute the offset length to the top - dist = [self._extrusion_length_to_offset_distance(self.finite_length_axis)] - # reverse the dilation value if the reference plane is on the top - if self.reference_plane == "top": - dist = [-dist[0]] - # for middle, both directions - elif self.reference_plane == "middle": - dist = [dist[0] / 2, -dist[0] / 2] - return dist - - def _dilation_value_at_reference_to_coord(self, dilation: float) -> float: - """Compute the coordinate based on the dilation value to the reference plane.""" - - z_coord = -dilation / self._tanq + self.slab_bounds[0] - if self.reference_plane == "middle": - return z_coord + self.finite_length_axis / 2 - if self.reference_plane == "top": - return z_coord + self.finite_length_axis - # bottom case - return z_coord - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. Not used for PolySlab. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
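In practice the division machinery above is driven through the public plugin class rather than `ComplexPolySlabBase` directly. A hedged sketch (the vertices are illustrative; whether a given polygon actually divides depends on where its offset edges first touch):

    from tidy3d.plugins.polyslab import ComplexPolySlab

    # a notched polygon whose slanted extrusion may self-intersect
    complex_slab = ComplexPolySlab(
        vertices=[(0, 0), (2, 0), (2, 1), (1, 0.1), (0, 1)],
        slab_bounds=(0, 1),
        axis=2,
        sidewall_angle=0.3,
        reference_plane="bottom",
    )

    parts = complex_slab.sub_polyslabs   # list of simple PolySlab objects
    group = complex_slab.geometry_group  # the same division packaged as a GeometryGroup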
- """ - return [ - shapely.unary_union( - [ - base.Geometry.evaluate_inf_shape(shape) - for polyslab in self.sub_polyslabs - for shape in polyslab.intersections_tilted_plane( - normal, origin, to_2D, cleanup=cleanup, quad_segs=quad_segs - ) - ] - ) - ] diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py index f6916b535d..8144045782 100644 --- a/tidy3d/components/geometry/primitives.py +++ b/tidy3d/components/geometry/primitives.py @@ -1,1015 +1,18 @@ -"""Concrete primitive geometrical objects.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.primitives`.""" -from __future__ import annotations - -from math import isclose -from typing import Any, Optional - -import autograd.numpy as anp -import numpy as np -import pydantic.v1 as pydantic -import shapely -from pydantic.v1 import PrivateAttr -from shapely.geometry.base import BaseGeometry - -from tidy3d.components.autograd import AutogradFieldMap, TracedSize1D -from tidy3d.components.autograd.derivative_utils import DerivativeInfo -from tidy3d.components.base import cached_property, skip_if_fields_missing -from tidy3d.components.geometry import base -from tidy3d.components.geometry.mesh import TriangleMesh -from tidy3d.components.geometry.polyslab import PolySlab -from tidy3d.components.types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely -from tidy3d.config import config -from tidy3d.constants import LARGE_NUMBER, MICROMETER -from tidy3d.exceptions import SetupError, ValidationError -from tidy3d.log import log -from tidy3d.packaging import verify_packages_import - -# for sampling conical frustum in visualization -_N_SAMPLE_CURVE_SHAPELY = 40 - -# for shapely circular shapes discretization in visualization -_N_SHAPELY_QUAD_SEGS_VISUALIZATION = 200 - -# Default number of points to discretize polyslab in `Cylinder.to_polyslab()` -_N_PTS_CYLINDER_POLYSLAB = 51 -_MAX_ICOSPHERE_SUBDIVISIONS = 7 # this would have 164K vertices and 328K faces -_DEFAULT_EDGE_FRACTION = 0.25 - - -def _base_icosahedron() -> tuple[np.ndarray, np.ndarray]: - """Return vertices and faces of a unit icosahedron.""" - - phi = (1.0 + np.sqrt(5.0)) / 2.0 - vertices = np.array( - [ - (-1, phi, 0), - (1, phi, 0), - (-1, -phi, 0), - (1, -phi, 0), - (0, -1, phi), - (0, 1, phi), - (0, -1, -phi), - (0, 1, -phi), - (phi, 0, -1), - (phi, 0, 1), - (-phi, 0, -1), - (-phi, 0, 1), - ], - dtype=float, - ) - vertices /= np.linalg.norm(vertices, axis=1)[:, None] - faces = np.array( - [ - (0, 11, 5), - (0, 5, 1), - (0, 1, 7), - (0, 7, 10), - (0, 10, 11), - (1, 5, 9), - (5, 11, 4), - (11, 10, 2), - (10, 7, 6), - (7, 1, 8), - (3, 9, 4), - (3, 4, 2), - (3, 2, 6), - (3, 6, 8), - (3, 8, 9), - (4, 9, 5), - (2, 4, 11), - (6, 2, 10), - (8, 6, 7), - (9, 8, 1), - ], - dtype=int, - ) - return vertices, faces - - -_ICOSAHEDRON_VERTS, _ICOSAHEDRON_FACES = _base_icosahedron() - - -class Sphere(base.Centered, base.Circular): - """Spherical geometry. - - Example - ------- - >>> b = Sphere(center=(1,2,3), radius=2) - """ - - _icosphere_cache: dict[int, tuple[np.ndarray, float]] = PrivateAttr(default_factory=dict) - - def inside( - self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] - ) -> np.ndarray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. 
- - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - self._ensure_equal_shape(x, y, z) - x0, y0, z0 = self.center - dist_x = np.abs(x - x0) - dist_y = np.abs(y - y0) - dist_z = np.abs(z - z0) - return (dist_x**2 + dist_y**2 + dist_z**2) <= (self.radius**2) - - def intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - normal = np.array(normal) - unit_normal = normal / (np.sum(normal**2) ** 0.5) - projection = np.dot(np.array(origin) - np.array(self.center), unit_normal) - if abs(projection) >= self.radius: - return [] - - radius = (self.radius**2 - projection**2) ** 0.5 - center = np.array(self.center) + projection * unit_normal - - v = np.zeros(3) - v[np.argmin(np.abs(unit_normal))] = 1 - u = np.cross(unit_normal, v) - u /= np.sum(u**2) ** 0.5 - v = np.cross(unit_normal, u) - - angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1)[:-1] - circ = center + np.outer(np.cos(angles), radius * u) + np.outer(np.sin(angles), radius * v) - vertices = np.dot(np.hstack((circ, np.ones((angles.size, 1)))), to_2D.T) - return [shapely.Polygon(vertices[:, :2])] - - def intersections_plane( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - cleanup: bool = True, - quad_segs: Optional[int] = None, - ) -> list[BaseGeometry]: - """Returns shapely geometry at plane specified by one non-None value of x,y,z. - - Parameters - ---------- - x : float = None - Position of plane in x direction, only one of x,y,z can be specified to define plane. - y : float = None - Position of plane in y direction, only one of x,y,z can be specified to define plane. - z : float = None - Position of plane in z direction, only one of x,y,z can be specified to define plane. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_.
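Both intersection routines reduce to one elementary fact: a plane at signed distance p = n̂ · (origin − center) from the center of a sphere of radius R cuts it in a circle of radius

    r = \sqrt{R^2 - p^2}, \qquad |p| < R.

The tilted-plane branch computes this directly via `projection`; in the axis-aligned branch, `_intersect_dist` returns the chord diameter, hence the `0.5 *` factor passed to `buffer`.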
- """ - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - axis, position = self.parse_xyz_kwargs(x=x, y=y, z=z) - if not self.intersects_axis_position(axis, position): - return [] - z0, (x0, y0) = self.pop_axis(self.center, axis=axis) - intersect_dist = self._intersect_dist(position, z0) - if not intersect_dist: - return [] - return [shapely.Point(x0, y0).buffer(0.5 * intersect_dist, quad_segs=quad_segs)] - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - coord_min = tuple(c - self.radius for c in self.center) - coord_max = tuple(c + self.radius for c in self.center) - return (coord_min, coord_max) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - volume = 4.0 / 3.0 * np.pi * self.radius**3 - - # a very loose upper bound on how much of sphere is in bounds - for axis in range(3): - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - volume *= 0.5 - - return volume - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - area = 4.0 * np.pi * self.radius**2 - - # a very loose upper bound on how much of sphere is in bounds - for axis in range(3): - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - area *= 0.5 - - return area - - @classmethod - def unit_sphere_triangles( - cls, - *, - target_edge_length: Optional[float] = None, - subdivisions: Optional[int] = None, - ) -> np.ndarray: - """Return unit sphere triangles discretized via an icosphere.""" - - unit_tris = UNIT_SPHERE._unit_sphere_triangles( - target_edge_length=target_edge_length, - subdivisions=subdivisions, - copy_result=True, - ) - return unit_tris - - def _unit_sphere_triangles( - self, - *, - target_edge_length: Optional[float] = None, - subdivisions: Optional[int] = None, - copy_result: bool = True, - ) -> np.ndarray: - """Return cached unit-sphere triangles with optional copying.""" - if target_edge_length is not None and subdivisions is not None: - raise ValueError("Specify either target_edge_length OR subdivisions, not both.") - - if subdivisions is None: - subdivisions = self._subdivisions_for_edge(target_edge_length) - - triangles, _ = self._icosphere_data(subdivisions) - return np.array(triangles, copy=copy_result) - - def _subdivisions_for_edge(self, target_edge_length: Optional[float]) -> int: - if target_edge_length is None or target_edge_length <= 0.0: - return 0 - - for subdiv in range(_MAX_ICOSPHERE_SUBDIVISIONS + 1): - _, max_edge = self._icosphere_data(subdiv) - if max_edge <= target_edge_length: - return subdiv - - log.warning( - f"Requested sphere mesh edge length {target_edge_length:.3e} μm requires more than " - f"{_MAX_ICOSPHERE_SUBDIVISIONS} subdivisions. 
" - "Clipping to the finest available mesh.", - log_once=True, - ) - return _MAX_ICOSPHERE_SUBDIVISIONS - - def _icosphere_data(self, subdivisions: int) -> tuple[np.ndarray, float]: - cache = self._icosphere_cache - if subdivisions in cache: - return cache[subdivisions] - - vertices = np.asarray(_ICOSAHEDRON_VERTS, dtype=float) - faces = np.asarray(_ICOSAHEDRON_FACES, dtype=int) - if subdivisions > 0: - vertices = vertices.copy() - faces = faces.copy() - for _ in range(subdivisions): - vertices, faces = TriangleMesh.subdivide_faces(vertices, faces) - - norms = np.linalg.norm(vertices, axis=1, keepdims=True) - norms = np.where(norms == 0.0, 1.0, norms) - vertices = vertices / norms - - triangles = vertices[faces] - max_edge = self._max_edge_length(triangles) - cache[subdivisions] = (triangles, max_edge) - return triangles, max_edge - - @staticmethod - def _max_edge_length(triangles: np.ndarray) -> float: - v = triangles - edges = np.stack( - [ - v[:, 1] - v[:, 0], - v[:, 2] - v[:, 1], - v[:, 0] - v[:, 2], - ], - axis=1, - ) - return float(np.linalg.norm(edges, axis=2).max()) - - -UNIT_SPHERE = Sphere(center=(0.0, 0.0, 0.0), radius=1.0) - - -class Cylinder(base.Centered, base.Circular, base.Planar): - """Cylindrical geometry with optional sidewall angle along axis - direction. When ``sidewall_angle`` is nonzero, the shape is a - conical frustum or a cone. - - Example - ------- - >>> c = Cylinder(center=(1,2,3), radius=2, length=5, axis=2) - - See Also - -------- - - **Notebooks** - - * `THz integrated demultiplexer/filter based on a ring resonator <../../../notebooks/THzDemultiplexerFilter.html>`_ - * `Photonic crystal waveguide polarization filter <../../../notebooks/PhotonicCrystalWaveguidePolarizationFilter.html>`_ - """ - - # Provide more explanations on where radius is defined - radius: TracedSize1D = pydantic.Field( - ..., - title="Radius", - description="Radius of geometry at the ``reference_plane``.", - units=MICROMETER, - ) - - length: TracedSize1D = pydantic.Field( - ..., - title="Length", - description="Defines thickness of cylinder along axis dimension.", - units=MICROMETER, - ) - - @pydantic.validator("length", always=True) - @skip_if_fields_missing(["sidewall_angle", "reference_plane"]) - def _only_middle_for_infinite_length_slanted_cylinder( - cls, val: float, values: dict[str, Any] - ) -> float: - """For a slanted cylinder of infinite length, ``reference_plane`` can only - be ``middle``; otherwise, the radius at ``center`` is either td.inf or 0. - """ - if isclose(values["sidewall_angle"], 0) or not np.isinf(val): - return val - if values["reference_plane"] != "middle": - raise SetupError( - "For a slanted cylinder here is of infinite length, " - "defining the reference_plane other than 'middle' " - "leads to undefined cylinder behaviors near 'center'." - ) - return val - - def to_polyslab( - self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB, **kwargs: Any - ) -> PolySlab: - """Convert instance of ``Cylinder`` into a discretized version using ``PolySlab``. - - Parameters - ---------- - num_pts_circumference : int = 51 - Number of points in the circumference of the discretized polyslab. - **kwargs: - Extra keyword arguments passed to ``PolySlab()``, such as ``dilation``. - - Returns - ------- - PolySlab - Extruded polygon representing a discretized version of the cylinder. 
- """ - - center_axis = self.center_axis - length_axis = self.length_axis - slab_bounds = (center_axis - length_axis / 2.0, center_axis + length_axis / 2.0) - - if num_pts_circumference < 3: - raise ValueError("'PolySlab' from 'Cylinder' must have 3 or more radius points.") - - _, (x0, y0) = self.pop_axis(self.center, axis=self.axis) - - xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) - - xs = x0 + self.radius * xs_ - ys = y0 + self.radius * ys_ - - vertices = anp.stack((xs, ys), axis=-1) - - return PolySlab( - vertices=vertices, - axis=self.axis, - slab_bounds=slab_bounds, - sidewall_angle=self.sidewall_angle, - reference_plane=self.reference_plane, - **kwargs, - ) - - def _points_unit_circle( - self, num_pts_circumference: int = _N_PTS_CYLINDER_POLYSLAB - ) -> np.ndarray: - """Set of x and y points for the unit circle when discretizing cylinder as a polyslab.""" - angles = np.linspace(0, 2 * np.pi, num_pts_circumference, endpoint=False) - xs = np.cos(angles) - ys = np.sin(angles) - return np.stack((xs, ys), axis=0) - - def _discretization_wavelength(self, derivative_info: DerivativeInfo) -> float: - """Choose a reference wavelength for discretizing the cylinder into a `PolySlab`.""" - wvl0_min = derivative_info.wavelength_min - wvl_mat = wvl0_min / np.max([1.0, np.max(np.sqrt(abs(derivative_info.eps_in)))]) - - grid_cfg = config.adjoint - - min_wvl_mat = grid_cfg.min_wvl_fraction * wvl0_min - if wvl_mat < min_wvl_mat: - log.warning( - f"The minimum wavelength inside the cylinder material is {wvl_mat:.3e} μm, which would " - f"create a large number of discretization points for computing the gradient. " - f"To prevent performance degradation, the discretization wavelength has " - f"been clipped to {min_wvl_mat:.3e} μm.", - log_once=True, - ) - wvl_mat = max(wvl_mat, min_wvl_mat) - - return wvl_mat - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - - # compute circumference discretization - wvl_mat = self._discretization_wavelength(derivative_info=derivative_info) - - circumference = 2 * np.pi * self.radius - wvls_in_circumference = circumference / wvl_mat +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - grid_cfg = config.adjoint - num_pts_circumference = int(np.ceil(grid_cfg.points_per_wavelength * wvls_in_circumference)) - num_pts_circumference = max(3, num_pts_circumference) - - # construct equivalent polyslab and compute the derivatives - polyslab = self.to_polyslab(num_pts_circumference=num_pts_circumference) - - # build PolySlab derivative paths based on requested Cylinder paths - ps_paths = set() - for path in derivative_info.paths: - if path == ("length",): - ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) - elif path == ("radius",): - ps_paths.add(("vertices",)) - elif "center" in path: - _, center_index = path - _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) - if center_index in (index_x, index_y): - ps_paths.add(("vertices",)) - else: - ps_paths.update({("slab_bounds", 0), ("slab_bounds", 1)}) - elif path == ("sidewall_angle",): - ps_paths.add(("sidewall_angle",)) - - # pass interpolators to PolySlab if available to avoid redundant conversions - update_kwargs = { - "paths": list(ps_paths), - "deep": False, - } - if derivative_info.interpolators is not None: - update_kwargs["interpolators"] = derivative_info.interpolators - - derivative_info_polyslab = derivative_info.updated_copy(**update_kwargs) - 
vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab) - - vjps = {} - for path in derivative_info.paths: - if path == ("length",): - vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) - vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) - vjps[path] = vjp_top - vjp_bot - - elif path == ("radius",): - # transform polyslab vertices derivatives into radius derivative - xs_, ys_ = self._points_unit_circle(num_pts_circumference=num_pts_circumference) - if ("vertices",) not in vjps_polyslab: - vjps[path] = 0.0 - else: - vjps_vertices_xs, vjps_vertices_ys = vjps_polyslab[("vertices",)].T - vjp_xs = np.sum(xs_ * vjps_vertices_xs) - vjp_ys = np.sum(ys_ * vjps_vertices_ys) - vjps[path] = vjp_xs + vjp_ys - - elif "center" in path: - _, center_index = path - _, (index_x, index_y) = self.pop_axis((0, 1, 2), axis=self.axis) - if center_index == index_x: - if ("vertices",) not in vjps_polyslab: - vjps[path] = 0.0 - else: - vjps_vertices_xs = vjps_polyslab[("vertices",)][:, 0] - vjps[path] = np.sum(vjps_vertices_xs) - elif center_index == index_y: - if ("vertices",) not in vjps_polyslab: - vjps[path] = 0.0 - else: - vjps_vertices_ys = vjps_polyslab[("vertices",)][:, 1] - vjps[path] = np.sum(vjps_vertices_ys) - else: - vjp_top = vjps_polyslab.get(("slab_bounds", 0), 0.0) - vjp_bot = vjps_polyslab.get(("slab_bounds", 1), 0.0) - vjps[path] = vjp_top + vjp_bot - - elif path == ("sidewall_angle",): - # direct mapping: cylinder angle equals polyslab angle - vjps[path] = vjps_polyslab.get(("sidewall_angle",), 0.0) - - else: - raise NotImplementedError( - f"Differentiation with respect to 'Cylinder' '{path}' field not supported. " - "If you would like this feature added, please feel free to raise " - "an issue on the tidy3d front end repository." - ) - - return vjps - - @property - def center_axis(self) -> Any: - """Gets the position of the center of the geometry in the out of plane dimension.""" - z0, _ = self.pop_axis(self.center, axis=self.axis) - return z0 - - @property - def length_axis(self) -> float: - """Gets the length of the geometry along the out of plane dimension.""" - return self.length - - @cached_property - def _normal_2dmaterial(self) -> Axis: - """Get the normal to the given geometry, checking that it is a 2D geometry.""" - if self.length != 0: - raise ValidationError("'Medium2D' requires the 'Cylinder' length to be zero.") - return self.axis - - def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Cylinder: - """Returns an updated geometry which has been transformed to fit within ``bounds`` - along the ``axis`` direction.""" - if axis != self.axis: - raise ValueError( - f"'_update_from_bounds' may only be applied along axis '{self.axis}', " - f"but was given axis '{axis}'." - ) - new_center = list(self.center) - new_center[axis] = (bounds[0] + bounds[1]) / 2 - new_length = bounds[1] - bounds[0] - return self.updated_copy(center=new_center, length=new_length) - - @verify_packages_import(["trimesh"]) - def _do_intersections_tilted_plane( - self, - normal: Coordinate, - origin: Coordinate, - to_2D: MatrixReal4x4, - quad_segs: Optional[int] = None, - ) -> list[Shapely]: - """Return a list of shapely geometries at the plane specified by normal and origin. - - Parameters - ---------- - normal : Coordinate - Vector defining the normal direction to the plane. - origin : Coordinate - Vector defining the plane origin. - to_2D : MatrixReal4x4 - Transformation matrix to apply to resulting shapes. 
- quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - import trimesh - - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - z0, (x0, y0) = self.pop_axis(self.center, self.axis) - half_length = self.finite_length_axis / 2 - - z_top = z0 + half_length - z_bot = z0 - half_length - - if np.isclose(self.sidewall_angle, 0): - r_top = self.radius - r_bot = self.radius - else: - r_top = self.radius_top - r_bot = self.radius_bottom - if r_top < 0 or np.isclose(r_top, 0): - r_top = 0 - z_top = z0 + self._radius_z(z0) / self._tanq - elif r_bot < 0 or np.isclose(r_bot, 0): - r_bot = 0 - z_bot = z0 + self._radius_z(z0) / self._tanq - - angles = np.linspace(0, 2 * np.pi, quad_segs * 4 + 1) - - if r_bot > 0: - x_bot = x0 + r_bot * np.cos(angles) - y_bot = y0 + r_bot * np.sin(angles) - x_bot[-1] = x0 - y_bot[-1] = y0 - else: - x_bot = np.array([x0]) - y_bot = np.array([y0]) - - if r_top > 0: - x_top = x0 + r_top * np.cos(angles) - y_top = y0 + r_top * np.sin(angles) - x_top[-1] = x0 - y_top[-1] = y0 - else: - x_top = np.array([x0]) - y_top = np.array([y0]) - - x = np.hstack((x_bot, x_top)) - y = np.hstack((y_bot, y_top)) - z = np.hstack((np.full_like(x_bot, z_bot), np.full_like(x_top, z_top))) - vertices = np.vstack(self.unpop_axis(z, (x, y), self.axis)).T - - if x_bot.shape[0] == 1: - m = 1 - n = x_top.shape[0] - 1 - faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)] - faces_side = [(m + (i + 1) % n, m + i, 0) for i in range(n)] - faces = faces_top + faces_side - elif x_top.shape[0] == 1: - m = x_bot.shape[0] - n = m - 1 - faces_bot = [(n, (i + 1) % n, i) for i in range(n)] - faces_side = [(i, (i + 1) % n, m) for i in range(n)] - faces = faces_bot + faces_side - else: - m = x_bot.shape[0] - n = m - 1 - faces_bot = [(n, (i + 1) % n, i) for i in range(n)] - faces_top = [(m + n, m + i, m + (i + 1) % n) for i in range(n)] - faces_side_bot = [(i, (i + 1) % n, m + (i + 1) % n) for i in range(n)] - faces_side_top = [(m + (i + 1) % n, m + i, i) for i in range(n)] - faces = faces_bot + faces_top + faces_side_bot + faces_side_top - - mesh = trimesh.Trimesh(vertices, faces) - - section = mesh.section(plane_origin=origin, plane_normal=normal) - if section is None: - return [] - path, _ = section.to_2D(to_2D=to_2D) - return path.polygons_full - - def _intersections_normal( - self, z: float, quad_segs: Optional[int] = None - ) -> list[BaseGeometry]: - """Find shapely geometries intersecting cylindrical geometry with axis normal to slab. - - Parameters - ---------- - z : float - Position along the axis normal to slab - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - ``_N_SHAPELY_QUAD_SEGS_VISUALIZATION`` for high-quality visualization. - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. 
- """ - if quad_segs is None: - quad_segs = _N_SHAPELY_QUAD_SEGS_VISUALIZATION - - static_self = self.to_static() - - # radius at z - radius_offset = static_self._radius_z(z) - - if radius_offset <= 0: - return [] - - _, (x0, y0) = self.pop_axis(static_self.center, axis=self.axis) - return [shapely.Point(x0, y0).buffer(radius_offset, quad_segs=quad_segs)] - - def _intersections_side(self, position: float, axis: int) -> list[BaseGeometry]: - """Find shapely geometries intersecting cylindrical geometry with axis orthogonal to length. - When ``sidewall_angle`` is nonzero, so that it's in fact a conical frustum or cone, the - cross section can contain hyperbolic curves. This is currently approximated by a polygon - of many vertices. - - Parameters - ---------- - position : float - Position along axis direction. - axis : int - Integer index into 'xyz' (0, 1, 2). - - Returns - ------- - List[shapely.geometry.base.BaseGeometry] - List of 2D shapes that intersect plane. - For more details refer to - `Shapely's Documentation `_. - """ - # position in the local coordinate of the cylinder - position_local = position - self.center[axis] - - # no intersection - if abs(position_local) >= self.radius_max: - return [] - - # half of intersection length at the top and bottom - intersect_half_length_max = np.sqrt(self.radius_max**2 - position_local**2) - intersect_half_length_min = -LARGE_NUMBER - if abs(position_local) < self.radius_min: - intersect_half_length_min = np.sqrt(self.radius_min**2 - position_local**2) - - # the vertices on the max side of top/bottom - # The two vertices are present in all scenarios. - vertices_max = [ - self._local_to_global_side_cross_section([-intersect_half_length_max, 0], axis), - self._local_to_global_side_cross_section([intersect_half_length_max, 0], axis), - ] - - # Extending to a cone, the maximal height of the cone - h_cone = ( - LARGE_NUMBER if isclose(self.sidewall_angle, 0) else self.radius_max / abs(self._tanq) - ) - # The maximal height of the cross section - height_max = min( - (1 - abs(position_local) / self.radius_max) * h_cone, self.finite_length_axis - ) - - # more vertices to add for conical frustum shape - vertices_frustum_right = [] - vertices_frustum_left = [] - if not (isclose(position, self.center[axis]) or isclose(self.sidewall_angle, 0)): - # The y-coordinate for the additional vertices - y_list = height_max * np.linspace(0, 1, _N_SAMPLE_CURVE_SHAPELY) - # `abs()` to make sure np.sqrt(0-fp_eps) goes through - x_list = np.sqrt( - np.abs(self.radius_max**2 * (1 - y_list / h_cone) ** 2 - position_local**2) - ) - for i in range(_N_SAMPLE_CURVE_SHAPELY): - vertices_frustum_right.append( - self._local_to_global_side_cross_section([x_list[i], y_list[i]], axis) - ) - vertices_frustum_left.append( - self._local_to_global_side_cross_section( - [ - -x_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], - y_list[_N_SAMPLE_CURVE_SHAPELY - i - 1], - ], - axis, - ) - ) - - # the vertices on the min side of top/bottom - vertices_min = [] - - ## termination at the top/bottom - if intersect_half_length_min > 0: - vertices_min.append( - self._local_to_global_side_cross_section( - [intersect_half_length_min, self.finite_length_axis], axis - ) - ) - vertices_min.append( - self._local_to_global_side_cross_section( - [-intersect_half_length_min, self.finite_length_axis], axis - ) - ) - ## early termination - else: - vertices_min.append(self._local_to_global_side_cross_section([0, height_max], axis)) - - return [ - shapely.Polygon( - vertices_max + vertices_frustum_right + vertices_min + 
vertices_frustum_left - ) - ] - - def inside( - self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float] - ) -> np.ndarray[bool]: - """For input arrays ``x``, ``y``, ``z`` of arbitrary but identical shape, return an array - with the same shape which is ``True`` for every point in zip(x, y, z) that is inside the - volume of the :class:`Geometry`, and ``False`` otherwise. - - Parameters - ---------- - x : np.ndarray[float] - Array of point positions in x direction. - y : np.ndarray[float] - Array of point positions in y direction. - z : np.ndarray[float] - Array of point positions in z direction. - - Returns - ------- - np.ndarray[bool] - ``True`` for every point that is inside the geometry. - """ - # radius at z - self._ensure_equal_shape(x, y, z) - z0, (x0, y0) = self.pop_axis(self.center, axis=self.axis) - z, (x, y) = self.pop_axis((x, y, z), axis=self.axis) - radius_offset = self._radius_z(z) - positive_radius = radius_offset > 0 - - dist_x = np.abs(x - x0) - dist_y = np.abs(y - y0) - dist_z = np.abs(z - z0) - inside_radius = (dist_x**2 + dist_y**2) <= (radius_offset**2) - inside_height = dist_z <= (self.finite_length_axis / 2) - return positive_radius * inside_radius * inside_height - - @cached_property - def bounds(self) -> Bound: - """Returns bounding box min and max coordinates. - - Returns - ------- - Tuple[float, float, float], Tuple[float, float, float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - """ - coord_min = [c - self.radius_max for c in self.center] - coord_max = [c + self.radius_max for c in self.center] - coord_min[self.axis] = self.center[self.axis] - self.length_axis / 2.0 - coord_max[self.axis] = self.center[self.axis] + self.length_axis / 2.0 - return (tuple(coord_min), tuple(coord_max)) - - def _volume(self, bounds: Bound) -> float: - """Returns object's volume within given bounds.""" - - coord_min = max(self.bounds[0][self.axis], bounds[0][self.axis]) - coord_max = min(self.bounds[1][self.axis], bounds[1][self.axis]) - - length = coord_max - coord_min - - volume = np.pi * self.radius_max**2 * length - - # a very loose upper bound on how much of the cylinder is in bounds - for axis in range(3): - if axis != self.axis: - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - volume *= 0.5 - - return volume - - def _surface_area(self, bounds: Bound) -> float: - """Returns object's surface area within given bounds.""" - - area = 0 - - coord_min = self.bounds[0][self.axis] - coord_max = self.bounds[1][self.axis] - - if coord_min < bounds[0][self.axis]: - coord_min = bounds[0][self.axis] - else: - area += np.pi * self.radius_max**2 - - if coord_max > bounds[1][self.axis]: - coord_max = bounds[1][self.axis] - else: - area += np.pi * self.radius_max**2 - - length = coord_max - coord_min - - area += 2.0 * np.pi * self.radius_max * length - - # a very loose upper bound on how much of the cylinder is in bounds - for axis in range(3): - if axis != self.axis: - if self.center[axis] <= bounds[0][axis] or self.center[axis] >= bounds[1][axis]: - area *= 0.5 - - return area - - @cached_property - def radius_bottom(self) -> float: - """radius of bottom""" - return self._radius_z(self.center_axis - self.finite_length_axis / 2) - - @cached_property - def radius_top(self) -> float: - """radius of top""" - return self._radius_z(self.center_axis + self.finite_length_axis / 2) - - @cached_property - def radius_max(self) -> float: - """max(radius of top, radius of bottom)""" - return max(self.radius_bottom,
self.radius_top) - - @cached_property - def radius_min(self) -> float: - """min(radius of top, radius of bottom). It can be negative for a large - sidewall angle. - """ - return min(self.radius_bottom, self.radius_top) - - def _radius_z(self, z: float) -> float: - """Compute the radius of the cross section at the position z. - - Parameters - ---------- - z : float - Position along the axis normal to slab - """ - if isclose(self.sidewall_angle, 0): - return self.radius - - radius_middle = self.radius - if self.reference_plane == "top": - radius_middle += self.finite_length_axis / 2 * self._tanq - elif self.reference_plane == "bottom": - radius_middle -= self.finite_length_axis / 2 * self._tanq - - return radius_middle - (z - self.center_axis) * self._tanq - - def _local_to_global_side_cross_section(self, coords: list[float], axis: int) -> list[float]: - """Map a point (x,y) from local to global coordinate system in the - side cross section. - - The definition of the local: y=0 lies at the base if ``sidewall_angle>=0``, - and at the top if ``sidewall_angle<0``; x=0 aligns with the corresponding - ``self.center``. In both cases, y-axis is pointing towards the narrowing - direction of cylinder. - - Parameters - ---------- - axis : int - Integer index into 'xyz' (0, 1, 2). - coords : List[float, float] - The value in the planar coordinate. - - Returns - ------- - Tuple[float, float] - The point in the global coordinate for plotting `_intersection_side`. - - """ - - # For negative sidewall angle, quantities along axis direction usually needs a flipped sign - axis_sign = 1 - if self.sidewall_angle < 0: - axis_sign = -1 +# marked as migrated to _common +from __future__ import annotations - lx_offset, ly_offset = self._order_by_axis( - plane_val=coords[0], - axis_val=axis_sign * (-self.finite_length_axis / 2 + coords[1]), - axis=axis, - ) - _, (x_center, y_center) = self.pop_axis(self.center, axis=axis) - return [x_center + lx_offset, y_center + ly_offset] +from tidy3d._common.components.geometry.primitives import ( + _DEFAULT_EDGE_FRACTION, + _MAX_ICOSPHERE_SUBDIVISIONS, + _N_PTS_CYLINDER_POLYSLAB, + _N_SAMPLE_CURVE_SHAPELY, + _N_SHAPELY_QUAD_SEGS_VISUALIZATION, + UNIT_SPHERE, + Cylinder, + Sphere, + _base_icosahedron, +) diff --git a/tidy3d/components/geometry/triangulation.py b/tidy3d/components/geometry/triangulation.py index c3329c6c20..96debe35cf 100644 --- a/tidy3d/components/geometry/triangulation.py +++ b/tidy3d/components/geometry/triangulation.py @@ -1,181 +1,14 @@ -from __future__ import annotations - -from dataclasses import dataclass - -import numpy as np -import shapely - -from tidy3d.components.types import ArrayFloat1D, ArrayFloat2D -from tidy3d.exceptions import Tidy3dError - - -@dataclass -class Vertex: - """Simple data class to hold triangulation data structures. - - Parameters - ---------- - coordinate: ArrayFloat1D - Vertex coordinate. - index : int - Vertex index in the original polygon. - convexity : float = 0.0 - Value representing the convexity (> 0) or concavity (< 0) of the vertex in the polygon. - is_ear : bool = False - Flag indicating whether this is an ear of the polygon. - """ - - coordinate: ArrayFloat1D - - index: int - - convexity: float - - is_ear: bool - - -def update_convexity(vertices: list[Vertex], i: int) -> int: - """Update the convexity of a vertex in a polygon. - - Parameters - ---------- - vertices : List[Vertex] - Vertices of the polygon. - i : int - Index of the vertex to be updated. 
- - Returns - ------- - int - Value indicating vertex convexity change w.r.t. 0. See note below. - - Note - ---- - Besides updating the vertex, this function returns a value indicating whether the updated vertex - convexity changed to or from 0 (0 convexity means the vertex is collinear with its neighbors). - If the convexity changes from zero to non-zero, return -1. If it changes from non-zero to zero, - return +1. Return 0 in any other case. This allows the main triangulation loop to keep track of - the total number of collinear vertices in the polygon. +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.triangulation`.""" - """ - result = -1 if vertices[i].convexity == 0.0 else 0 - j = (i + 1) % len(vertices) - vertices[i].convexity = np.linalg.det( - [ - vertices[i].coordinate - vertices[i - 1].coordinate, - vertices[j].coordinate - vertices[i].coordinate, - ] - ) - if vertices[i].convexity == 0.0: - result += 1 - return result +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def is_inside( - vertex: ArrayFloat1D, triangle: tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] -) -> bool: - """Check if a vertex is inside a triangle. - - Parameters - ---------- - vertex : ArrayFloat1D - Vertex coordinates. - triangle : Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] - Vertices of the triangle in CCW order. - - Returns - ------- - bool: - Flag indicating if the vertex is inside the triangle. - """ - return all( - np.linalg.det([triangle[i] - triangle[i - 1], vertex - triangle[i - 1]]) > 0 - for i in range(3) - ) - - -def update_ear_flag(vertices: list[Vertex], i: int) -> None: - """Update the ear flag of a vertex in a polygon. - - Parameters - ---------- - vertices : List[Vertex] - Vertices of the polygon. - i : int - Index of the vertex to be updated. - """ - h = (i - 1) % len(vertices) - j = (i + 1) % len(vertices) - triangle = (vertices[h].coordinate, vertices[i].coordinate, vertices[j].coordinate) - vertices[i].is_ear = vertices[i].convexity > 0 and not any( - is_inside(v.coordinate, triangle) - for k, v in enumerate(vertices) - if not (v.convexity > 0 or k == h or k == i or k == j) - ) - - -# TODO: This is an inefficient algorithm that runs in O(n^2). We should use something -# better, and probably as a compiled extension. -def triangulate(vertices: ArrayFloat2D) -> list[tuple[int, int, int]]: - """Triangulate a simple polygon. - - Parameters - ---------- - vertices : ArrayFloat2D - Vertices of the polygon. - - Returns - ------- - List[Tuple[int, int, int]] - List of indices of the vertices of the triangles. - """ - is_ccw = shapely.LinearRing(vertices).is_ccw - - # Initialize vertices as non-collinear because we will update the actual value below and count - # the number of collinear vertices. - vertices = [Vertex(v, i, -1.0, False) for i, v in enumerate(vertices)] - if not is_ccw: - vertices.reverse() - - collinears = 0 - for i in range(len(vertices)): - collinears += update_convexity(vertices, i) - - for i in range(len(vertices)): - update_ear_flag(vertices, i) - - triangles = [] - - ear_found = True - while len(vertices) > 3: - if not ear_found: - raise Tidy3dError( - "Impossible to triangulate polygon. Verify that the polygon is valid." 
- ) - ear_found = False - i = 0 - while i < len(vertices): - if vertices[i].is_ear: - removed = vertices.pop(i) - h = (i - 1) % len(vertices) - j = i % len(vertices) - collinears += update_convexity(vertices, h) - collinears += update_convexity(vertices, j) - if collinears == len(vertices): - # Undo removal because only collinear vertices remain - vertices.insert(i, removed) - collinears += update_convexity(vertices, (i - 1) % len(vertices)) - collinears += update_convexity(vertices, (i + 1) % len(vertices)) - i += 1 - else: - ear_found = True - triangles.append((vertices[h].index, removed.index, vertices[j].index)) - update_ear_flag(vertices, h) - update_ear_flag(vertices, j) - if len(vertices) == 3: - break - else: - i += 1 - - triangles.append(tuple(v.index for v in vertices)) - return triangles +from tidy3d._common.components.geometry.triangulation import ( + Vertex, + is_inside, + triangulate, + update_convexity, + update_ear_flag, +) diff --git a/tidy3d/components/geometry/utils.py b/tidy3d/components/geometry/utils.py index 9a483278ec..95dce22a99 100644 --- a/tidy3d/components/geometry/utils.py +++ b/tidy3d/components/geometry/utils.py @@ -1,491 +1,48 @@ -"""Utilities for geometry manipulation.""" +"""Compatibility shim for :mod:`tidy3d._common.components.geometry.utils`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations -from collections import defaultdict -from collections.abc import Iterable -from enum import Enum from math import isclose -from typing import Any, Optional, Union +from typing import TYPE_CHECKING import numpy as np -import pydantic.v1 as pydantic -import shapely -from shapely.geometry import ( - Polygon, -) -from shapely.geometry.base import ( - BaseMultipartGeometry, -) -from tidy3d.components.autograd.utils import get_static -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.geometry.base import Box -from tidy3d.components.grid.grid import Grid -from tidy3d.components.types import ( - ArrayFloat2D, - Axis, - Bound, - Coordinate, - Direction, - MatrixReal4x4, - PlanePosition, - Shapely, +from tidy3d._common.components.geometry.utils import ( + GeometryType, + SnapBehavior, + SnapLocation, + SnappingSpec, # noqa: TC + flatten_groups, + flatten_shapely_geometries, + from_shapely, + get_closest_value, + merging_geometries_on_plane, + traverse_geometries, + validate_no_transformed_polyslabs, + vertices_from_shapely, ) +from tidy3d.components.geometry.base import Box from tidy3d.constants import fp_eps -from tidy3d.exceptions import SetupError, Tidy3dError - -from . import base, mesh, polyslab, primitives - -GeometryType = Union[ - base.Box, - base.Transformed, - base.ClipOperation, - base.GeometryGroup, - primitives.Sphere, - primitives.Cylinder, - polyslab.PolySlab, - polyslab.ComplexPolySlabBase, - mesh.TriangleMesh, -] - - -def flatten_shapely_geometries( - geoms: Union[Shapely, Iterable[Shapely]], keep_types: tuple[type, ...] = (Polygon,) -) -> list[Shapely]: - """ - Flatten nested geometries into a flat list, while only keeping the specified types. - - Recursively extracts and returns non-empty geometries of the given types from input geometries, - expanding any GeometryCollections or Multi* types. - - Parameters - ---------- - geoms : Union[Shapely, Iterable[Shapely]] - Input geometries to flatten. - - keep_types : tuple[type, ...] - Geometry types to keep (e.g., (Polygon, LineString)). Default is - (Polygon). 
- - Returns - ------- - list[Shapely] - Flat list of non-empty geometries matching the specified types. - """ - # Handle single Shapely object by wrapping it in a list - if isinstance(geoms, Shapely): - geoms = [geoms] - - flat = [] - for geom in geoms: - if geom.is_empty: - continue - if isinstance(geom, keep_types): - flat.append(geom) - elif isinstance(geom, BaseMultipartGeometry): - flat.extend(flatten_shapely_geometries(geom.geoms, keep_types)) - return flat - - -def merging_geometries_on_plane( - geometries: list[GeometryType], - plane: Box, - property_list: list[Any], - interior_disjoint_geometries: bool = False, - cleanup: bool = True, - quad_segs: Optional[int] = None, -) -> list[tuple[Any, Shapely]]: - """Compute list of shapes on plane. Overlaps are removed or merged depending on - provided property_list. - - Parameters - ---------- - geometries : List[GeometryType] - List of structures to filter on the plane. - plane : Box - Plane specification. - property_list : List = None - Property value for each structure. - interior_disjoint_geometries: bool = False - If ``True``, geometries of different properties on the plane must not be overlapping. - cleanup : bool = True - If True, removes extremely small features from each polygon's boundary. - quad_segs : Optional[int] = None - Number of segments used to discretize circular shapes. If ``None``, uses - high-quality visualization settings. - - Returns - ------- - List[Tuple[Any, Shapely]] - List of shapes and their property value on the plane after merging. - """ - - if len(geometries) != len(property_list): - raise SetupError( - "Number of provided property values is not equal to the number of geometries." - ) - - shapes = [] - for geo, prop in zip(geometries, property_list): - # get list of Shapely shapes that intersect at the plane - shapes_plane = plane.intersections_with(geo, cleanup=cleanup, quad_segs=quad_segs) - - # Append each of them and their property information to the list of shapes - for shape in shapes_plane: - shapes.append((prop, shape, shape.bounds)) - - if interior_disjoint_geometries: - # No need to consider overlapping. We simply group shapes by property, and union_all - # shapes of the same property. - shapes_by_prop = defaultdict(list) - for prop, shape, _ in shapes: - shapes_by_prop[prop].append(shape) - # union shapes of same property - results = [] - for prop, shapes in shapes_by_prop.items(): - unionized = shapely.union_all(shapes).buffer(0).normalize() - if not unionized.is_empty: - results.append((prop, unionized)) - return results - - background_shapes = [] - for prop, shape, bounds in shapes: - minx, miny, maxx, maxy = bounds - - # loop through background_shapes (note: all background are non-intersecting or merged) - for index, (_prop, _shape, _bounds) in enumerate(background_shapes): - _minx, _miny, _maxx, _maxy = _bounds - - # do a bounding box check to see if any intersection to do anything about - if minx > _maxx or _minx > maxx or miny > _maxy or _miny > maxy: - continue - - # look more closely to see if intersected. 
- if shape.disjoint(_shape): - continue - - # different prop, remove intersection from background shape - if prop != _prop: - diff_shape = (_shape - shape).buffer(0).normalize() - # mark background shape for removal if nothing left - if diff_shape.is_empty or len(diff_shape.bounds) == 0: - background_shapes[index] = None - background_shapes[index] = (_prop, diff_shape, diff_shape.bounds) - # same prop, unionize shapes and mark background shape for removal - else: - shape = (shape | _shape).buffer(0).normalize() - background_shapes[index] = None - - # after doing this with all background shapes, add this shape to the background - background_shapes.append((prop, shape, shape.bounds)) - - # remove any existing background shapes that have been marked as 'None' - background_shapes = [b for b in background_shapes if b is not None] - - # filter out any remaining None or empty shapes (shapes with area completely removed) - return [(prop, shape) for (prop, shape, _) in background_shapes if shape] - - -def flatten_groups( - *geometries: GeometryType, - flatten_nonunion_type: bool = False, - flatten_transformed: bool = False, - transform: Optional[MatrixReal4x4] = None, -) -> GeometryType: - """Iterates over all geometries, flattening groups and unions. - - Parameters - ---------- - *geometries : GeometryType - Geometries to flatten. - flatten_nonunion_type : bool = False - If ``False``, only flatten geometry unions (and ``GeometryGroup``). If ``True``, flatten - all clip operations. - flatten_transformed : bool = False - If ``True``, ``Transformed`` groups are flattened into individual transformed geometries. - transform : Optional[MatrixReal4x4] - Accumulated transform from parents. Only used when ``flatten_transformed`` is ``True``. - - Yields - ------ - GeometryType - Geometries after flattening groups and unions. - """ - for geometry in geometries: - if isinstance(geometry, base.GeometryGroup): - yield from flatten_groups( - *geometry.geometries, - flatten_nonunion_type=flatten_nonunion_type, - flatten_transformed=flatten_transformed, - transform=transform, - ) - elif isinstance(geometry, base.ClipOperation) and ( - flatten_nonunion_type or geometry.operation == "union" - ): - yield from flatten_groups( - geometry.geometry_a, - geometry.geometry_b, - flatten_nonunion_type=flatten_nonunion_type, - flatten_transformed=flatten_transformed, - transform=transform, - ) - elif flatten_transformed and isinstance(geometry, base.Transformed): - new_transform = geometry.transform - if transform is not None: - new_transform = np.matmul(transform, new_transform) - yield from flatten_groups( - geometry.geometry, - flatten_nonunion_type=flatten_nonunion_type, - flatten_transformed=flatten_transformed, - transform=new_transform, - ) - elif flatten_transformed and transform is not None: - yield base.Transformed(geometry=geometry, transform=transform) - else: - yield geometry - - -def traverse_geometries(geometry: GeometryType) -> GeometryType: - """Iterator over all geometries within the given geometry. - - Iterates over groups and clip operations within the given geometry, yielding each one. - - Parameters - ---------- - geometry: GeometryType - Base geometry to start iteration. - - Returns - ------- - :class:`Geometry` - Geometries within the base geometry. 
- """ - if isinstance(geometry, base.GeometryGroup): - for g in geometry.geometries: - yield from traverse_geometries(g) - elif isinstance(geometry, base.ClipOperation): - yield from traverse_geometries(geometry.geometry_a) - yield from traverse_geometries(geometry.geometry_b) - yield geometry - - -def from_shapely( - shape: Shapely, - axis: Axis, - slab_bounds: tuple[float, float], - dilation: float = 0.0, - sidewall_angle: float = 0, - reference_plane: PlanePosition = "middle", -) -> base.Geometry: - """Convert a shapely primitive into a geometry instance by extrusion. - - Parameters - ---------- - shape : shapely.geometry.base.BaseGeometry - Shapely primitive to be converted. It must be a linear ring, a polygon or a collection - of any of those. - axis : int - Integer index defining the extrusion axis: 0 (x), 1 (y), or 2 (z). - slab_bounds: Tuple[float, float] - Minimal and maximal positions of the extruded slab along ``axis``. - dilation : float - Dilation of the polygon in the base by shifting each edge along its normal outwards - direction by a distance; a negative value corresponds to erosion. - sidewall_angle : float = 0 - Angle of the extrusion sidewalls, away from the vertical direction, in radians. Positive - (negative) values result in slabs larger (smaller) at the base than at the top. - reference_plane : PlanePosition = "middle" - Reference position of the (dilated/eroded) polygons along the slab axis. One of - ``"middle"`` (polygons correspond to the center of the slab bounds), ``"bottom"`` - (minimal slab bound position), or ``"top"`` (maximal slab bound position). This value - has no effect if ``sidewall_angle == 0``. - - Returns - ------- - :class:`Geometry` - Geometry extruded from the 2D data. - """ - if shape.geom_type == "LinearRing": - if sidewall_angle == 0: - return polyslab.PolySlab( - vertices=shape.coords[:-1], - axis=axis, - slab_bounds=slab_bounds, - dilation=dilation, - reference_plane=reference_plane, - ) - group = polyslab.ComplexPolySlabBase( - vertices=shape.coords[:-1], - axis=axis, - slab_bounds=slab_bounds, - dilation=dilation, - sidewall_angle=sidewall_angle, - reference_plane=reference_plane, - ).geometry_group - return group.geometries[0] if len(group.geometries) == 1 else group - - if shape.geom_type == "Polygon": - exterior = from_shapely( - shape.exterior, axis, slab_bounds, dilation, sidewall_angle, reference_plane - ) - interior = [ - from_shapely(hole, axis, slab_bounds, -dilation, -sidewall_angle, reference_plane) - for hole in shape.interiors - ] - if len(interior) == 0: - return exterior - interior = interior[0] if len(interior) == 1 else base.GeometryGroup(geometries=interior) - return base.ClipOperation(operation="difference", geometry_a=exterior, geometry_b=interior) - - if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: - return base.GeometryGroup( - geometries=[ - from_shapely(geo, axis, slab_bounds, dilation, sidewall_angle, reference_plane) - for geo in shape.geoms - ] - ) - - raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") - +from tidy3d.exceptions import SetupError -def vertices_from_shapely(shape: Shapely) -> ArrayFloat2D: - """Iterate over the polygons of a shapely geometry returning the vertices. +if TYPE_CHECKING: + from typing import Optional - Parameters - ---------- - shape : shapely.geometry.base.BaseGeometry - Shapely primitive to have its vertices extracted. It must be a linear ring, a polygon or a - collection of any of those. 
+ from numpy.typing import ArrayLike + from pydantic import NonNegativeInt - Returns - ------- - List[Tuple[ArrayFloat2D]] - List of tuples ``(exterior, *interiors)``. - """ - if shape.geom_type == "LinearRing": - return [(shape.coords[:-1],)] - if shape.geom_type == "Polygon": - return [(shape.exterior.coords[:-1], *tuple(hole.coords[:-1] for hole in shape.interiors))] - if shape.geom_type in {"MultiPolygon", "GeometryCollection"}: - return sum(vertices_from_shapely(geo) for geo in shape.geoms) - - raise Tidy3dError(f"Shape {shape} cannot be converted to Geometry.") - - -def validate_no_transformed_polyslabs( - geometry: GeometryType, transform: MatrixReal4x4 = None -) -> None: - """Prevents the creation of slanted polyslabs rotated out of plane.""" - if transform is None: - transform = np.eye(4) - if isinstance(geometry, polyslab.PolySlab): - # sidewall_angle may be autograd-traced; unbox for the check only - if not ( - isclose(get_static(geometry.sidewall_angle), 0) - or base.Transformed.preserves_axis(transform, geometry.axis) - ): - raise Tidy3dError( - "Slanted PolySlabs are not allowed to be rotated out of the slab plane." - ) - elif isinstance(geometry, base.Transformed): - transform = np.dot(transform, geometry.transform) - validate_no_transformed_polyslabs(geometry.geometry, transform) - elif isinstance(geometry, base.GeometryGroup): - for geo in geometry.geometries: - validate_no_transformed_polyslabs(geo, transform) - elif isinstance(geometry, base.ClipOperation): - validate_no_transformed_polyslabs(geometry.geometry_a, transform) - validate_no_transformed_polyslabs(geometry.geometry_b, transform) - - -class SnapLocation(Enum): - """Describes different methods for defining the snapping locations.""" - - Boundary = 1 - """ - Choose the boundaries of Yee cells. - """ - Center = 2 - """ - Choose the center of Yee cells. - """ - - -class SnapBehavior(Enum): - """Describes different methods for snapping intervals, which are defined by two endpoints.""" - - Closest = 1 - """ - Snaps the interval's endpoints to the closest grid point. - """ - Expand = 2 - """ - Snaps the interval's endpoints to the closest grid points, - while guaranteeing that the snapping location will never move endpoints inwards. - """ - Contract = 3 - """ - Snaps the interval's endpoints to the closest grid points, - while guaranteeing that the snapping location will never move endpoints outwards. - """ - StrictExpand = 4 - """ - Same as Expand, but will always move endpoints outwards, even if already coincident with grid. - """ - StrictContract = 5 - """ - Same as Contract, but will always move endpoints inwards, even if already coincident with grid. - """ - Off = 6 - """ - Do not use snapping. - """ - - -class SnappingSpec(Tidy3dBaseModel): - """Specifies how to apply grid snapping along each dimension.""" - - location: tuple[SnapLocation, SnapLocation, SnapLocation] = pydantic.Field( - ..., - title="Location", - description="Describes which positions in the grid will be considered for snapping.", - ) - - behavior: tuple[SnapBehavior, SnapBehavior, SnapBehavior] = pydantic.Field( - ..., - title="Behavior", - description="Describes how snapping positions will be chosen.", - ) - - margin: Optional[ - tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt, pydantic.NonNegativeInt] - ] = pydantic.Field( - (0, 0, 0), - title="Margin", - description="Number of additional grid points to consider when expanding or contracting " - "during snapping. 
Only applies when ``SnapBehavior`` is ``Expand`` or ``Contract``.", + from tidy3d.components.grid.grid import Grid + from tidy3d.components.types.base import ( + Bound, + Coordinate, + Direction, ) -def get_closest_value(test: float, coords: np.ArrayLike, upper_bound_idx: int) -> float: - """Helper to choose the closest value in an array to a given test value, - using the index of the upper bound. The ``upper_bound_idx`` corresponds to the first value in - the ``coords`` array which is greater than or equal to the test value. - """ - # Handle corner cases first - if upper_bound_idx == 0: - return coords[upper_bound_idx] - if upper_bound_idx == len(coords): - return coords[upper_bound_idx - 1] - # General case - lower_bound = coords[upper_bound_idx - 1] - upper_bound = coords[upper_bound_idx] - dlower = abs(test - lower_bound) - dupper = abs(test - upper_bound) - return lower_bound if dlower < dupper else upper_bound - - def snap_box_to_grid(grid: Grid, box: Box, snap_spec: SnappingSpec, rtol: float = fp_eps) -> Box: """Snaps a :class:`.Box` to the grid, so that the boundaries of the box are aligned with grid centers or boundaries. The way in which each dimension of the `box` is snapped to the grid is controlled by ``snap_spec``. @@ -496,7 +53,7 @@ def _clamp_index(idx: int, length: int) -> int: def get_lower_bound( test: float, - coords: np.ArrayLike, + coords: ArrayLike, upper_bound_idx: int, rel_tol: float, strict_bounds: bool, @@ -513,7 +70,7 @@ def get_lower_bound( ---------- test : float The value to snap. - coords : np.ArrayLike + coords : ArrayLike Sorted array of coordinate values to snap to. upper_bound_idx : int Index from ``np.searchsorted(coords, test, side="left")`` - the first index where @@ -553,7 +110,7 @@ def get_lower_bound( def get_upper_bound( test: float, - coords: np.ArrayLike, + coords: ArrayLike, upper_bound_idx: int, rel_tol: float, strict_bounds: bool, @@ -570,7 +127,7 @@ def get_upper_bound( ---------- test : float The value to snap. - coords : np.ArrayLike + coords : ArrayLike Sorted array of coordinate values to snap to. upper_bound_idx : int Index from ``np.searchsorted(coords, test, side="left")`` - the first index where @@ -614,7 +171,7 @@ def find_snapping_locations( interval_max: float, coords: np.ndarray, snap_type: SnapBehavior, - snap_margin: pydantic.NonNegativeInt, + snap_margin: NonNegativeInt, ) -> tuple[float, float]: """Helper that snaps a supplied interval [interval_min, interval_max] to a sorted array representing coordinate values. 
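
The snapping helpers in the hunks above (`get_closest_value`, now migrated to `_common`, plus `get_lower_bound`/`get_upper_bound`) all key off the index returned by `np.searchsorted(coords, test, side="left")`: the first position whose coordinate is greater than or equal to the test value, with the left-hand candidate sitting one index lower. For reference, a minimal standalone sketch of that selection logic; `closest_value` is an illustrative name, not the tidy3d API (the real helper now lives in `tidy3d._common.components.geometry.utils`):

import numpy as np


def closest_value(test: float, coords: np.ndarray) -> float:
    """Pick the entry of the sorted array ``coords`` closest to ``test``."""
    upper_idx = int(np.searchsorted(coords, test, side="left"))
    if upper_idx == 0:
        # test lies at or below the first coordinate
        return float(coords[0])
    if upper_idx == len(coords):
        # test lies above the last coordinate
        return float(coords[-1])
    lower, upper = coords[upper_idx - 1], coords[upper_idx]
    # ties go to the upper bound, mirroring the strict '<' in the original helper
    return float(lower if test - lower < upper - test else upper)


# usage: closest_value(0.26, np.array([0.0, 0.25, 0.5])) returns 0.25
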
@@ -736,7 +293,7 @@ def _shift_value_signed( # get the index of the grid cell where the obj lies obj_position = obj.center[normal_axis] - obj_pos_gt_grid_bounds = np.argwhere(obj_position > grid_boundaries) + obj_pos_gt_grid_bounds = np.flatnonzero(obj_position > grid_boundaries) # no obj index can be determined if len(obj_pos_gt_grid_bounds) == 0 or obj_position > grid_boundaries[-1]: diff --git a/tidy3d/components/geometry/utils_2d.py b/tidy3d/components/geometry/utils_2d.py index f7399d775a..c285ea27fc 100644 --- a/tidy3d/components/geometry/utils_2d.py +++ b/tidy3d/components/geometry/utils_2d.py @@ -3,6 +3,7 @@ from __future__ import annotations from math import isclose +from typing import TYPE_CHECKING import numpy as np import shapely @@ -10,12 +11,14 @@ from tidy3d.components.geometry.base import Box, ClipOperation, Geometry, GeometryGroup from tidy3d.components.geometry.float_utils import increment_float from tidy3d.components.geometry.polyslab import _MIN_POLYGON_AREA, PolySlab -from tidy3d.components.grid.grid import Grid from tidy3d.components.scene import Scene -from tidy3d.components.structure import Structure -from tidy3d.components.types import Axis, Shapely from tidy3d.constants import fp_eps +if TYPE_CHECKING: + from tidy3d.components.grid.grid import Grid + from tidy3d.components.structure import Structure + from tidy3d.components.types import Axis, Shapely + def snap_coordinate_to_grid(grid: Grid, center: float, axis: Axis) -> float: """2D materials are snapped to grid along their normal axis""" @@ -88,12 +91,12 @@ def subdivide( ---------- geom : Geometry A 2D geometry associated with the :class:`.Medium2D`. - structures : List[Structure] + structures : list[Structure] List of structures that are checked for intersection with ``geom``. Returns ------- - List[Tuple[Geometry, Structure, Structure]] + list[tuple[Geometry, Structure, Structure]] List of the created partitions. Each element of the list represents a partition of the 2D geometry, which includes the newly created structures below and above. diff --git a/tidy3d/components/grid/corner_finder.py b/tidy3d/components/grid/corner_finder.py index 6cf8b34339..bc89fba3a3 100644 --- a/tidy3d/components/grid/corner_finder.py +++ b/tidy3d/components/grid/corner_finder.py @@ -2,19 +2,23 @@ from __future__ import annotations -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, PositiveInt from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.geometry.base import Box, ClipOperation from tidy3d.components.geometry.utils import merging_geometries_on_plane from tidy3d.components.medium import PEC, LossyMetalMedium -from tidy3d.components.structure import Structure -from tidy3d.components.types import ArrayFloat1D, ArrayFloat2D, Axis, Shapely from tidy3d.constants import inf +if TYPE_CHECKING: + from numpy.typing import NDArray + + from tidy3d.components.structure import Structure + from tidy3d.components.types import ArrayFloat1D, ArrayFloat2D, Axis, Shapely + CORNER_ANGLE_THRESOLD = 0.25 * np.pi # For shapely circular shapes discretization. 
N_SHAPELY_QUAD_SEGS = 8 @@ -25,7 +29,7 @@ class CornerFinderSpec(Tidy3dBaseModel): """Specification for corner detection on a 2D plane.""" - medium: Literal["metal", "dielectric", "all"] = pd.Field( + medium: Literal["metal", "dielectric", "all"] = Field( "metal", title="Material Type For Corner Identification", description="Find corners of structures made of :class:`.Medium`, " @@ -33,7 +37,7 @@ class CornerFinderSpec(Tidy3dBaseModel): "for non-metallic materials, and ``all`` for all materials.", ) - angle_threshold: float = pd.Field( + angle_threshold: float = Field( CORNER_ANGLE_THRESOLD, title="Angle Threshold In Corner Identification", description="A vertex is qualified as a corner if the angle spanned by its two edges " @@ -43,28 +47,28 @@ class CornerFinderSpec(Tidy3dBaseModel): lt=np.pi, ) - distance_threshold: Optional[pd.PositiveFloat] = pd.Field( + distance_threshold: Optional[PositiveFloat] = Field( None, title="Distance Threshold In Corner Identification", description="If not ``None`` and the distance of the vertex to its neighboring vertices " "is below the threshold value based on Douglas-Peucker algorithm, the vertex is disqualified as a corner.", ) - concave_resolution: Optional[pd.PositiveInt] = pd.Field( + concave_resolution: Optional[PositiveInt] = Field( None, title="Concave Region Resolution.", description="Specifies number of steps to use for determining `dl_min` based on concave featues." "If set to ``None``, then the corresponding `dl_min` reduction is not applied.", ) - convex_resolution: Optional[pd.PositiveInt] = pd.Field( + convex_resolution: Optional[PositiveInt] = Field( None, title="Convex Region Resolution.", description="Specifies number of steps to use for determining `dl_min` based on convex featues." "If set to ``None``, then the corresponding `dl_min` reduction is not applied.", ) - mixed_resolution: Optional[pd.PositiveInt] = pd.Field( + mixed_resolution: Optional[PositiveInt] = Field( None, title="Mixed Region Resolution.", description="Specifies number of steps to use for determining `dl_min` based on mixed featues." @@ -72,7 +76,7 @@ class CornerFinderSpec(Tidy3dBaseModel): ) @cached_property - def _no_min_dl_override(self): + def _no_min_dl_override(self) -> bool: return all( ( self.concave_resolution is None, @@ -193,7 +197,7 @@ def _corners_and_convexity( return self._ravel_corners_and_convexity(ravel, corner_list, convexity_list) def _ravel_corners_and_convexity( - self, ravel: bool, corner_list, convexity_list + self, ravel: bool, corner_list: list[ArrayFloat2D], convexity_list: list[ArrayFloat1D] ) -> tuple[ArrayFloat2D, ArrayFloat1D]: """Whether to put the resulting corners in a single list or per polygon.""" if ravel and len(corner_list) > 0: @@ -269,7 +273,7 @@ def _filter_collinear_vertices( Convexity of corners: True for outer corners, False for inner corners. """ - def normalize(v): + def normalize(v: NDArray) -> NDArray: return v / np.linalg.norm(v, axis=-1)[:, np.newaxis] # drop the last vertex, which is identical to the 1st one. 
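
Throughout this diff, imports needed only for annotations move under an `if TYPE_CHECKING:` guard, as in the `corner_finder.py` header above: the guarded block is evaluated only by static type checkers, so the module avoids the runtime import cost and any import cycles, while `from __future__ import annotations` keeps annotations as unevaluated strings at runtime. A minimal sketch of the pattern, modeled on the `normalize` helper above:

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Seen by type checkers only; never imported when the module actually runs.
    from numpy.typing import NDArray


def normalize(v: NDArray) -> NDArray:
    """Row-normalize; ``NDArray`` need not resolve at runtime thanks to lazy annotations."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)
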
diff --git a/tidy3d/components/grid/grid.py b/tidy3d/components/grid/grid.py index 73f8485a9b..5a63f339bd 100644 --- a/tidy3d/components/grid/grid.py +++ b/tidy3d/components/grid/grid.py @@ -2,18 +2,27 @@ from __future__ import annotations -from typing import Literal, Union +from typing import TYPE_CHECKING, Any import numpy as np -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.data.data_array import DataArray, ScalarFieldDataArray, SpatialDataArray -from tidy3d.components.data.utils import UnstructuredGridDataset, UnstructuredGridDatasetType -from tidy3d.components.geometry.base import Box, Geometry -from tidy3d.components.types import ArrayFloat1D, ArrayLike, Axis, Coordinate, InterpMethod +from tidy3d.components.data.utils import UnstructuredGridDataset +from tidy3d.components.types import ArrayFloat1D from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from typing import Literal, Union + + from numpy.typing import NDArray + + from tidy3d.compat import Self + from tidy3d.components.data.utils import UnstructuredGridDatasetType + from tidy3d.components.geometry.base import Box, Geometry + from tidy3d.components.types import ArrayLike, Axis, Coordinate, InterpMethod + # data type of one dimensional coordinate array. Coords1D = ArrayFloat1D @@ -29,25 +38,28 @@ class Coords(Tidy3dBaseModel): >>> coords = Coords(x=x, y=y, z=z) """ - x: Coords1D = pd.Field( - ..., title="X Coordinates", description="1-dimensional array of x coordinates." + x: Coords1D = Field( + title="X Coordinates", + description="1-dimensional array of x coordinates.", ) - y: Coords1D = pd.Field( - ..., title="Y Coordinates", description="1-dimensional array of y coordinates." + y: Coords1D = Field( + title="Y Coordinates", + description="1-dimensional array of y coordinates.", ) - z: Coords1D = pd.Field( - ..., title="Z Coordinates", description="1-dimensional array of z coordinates." + z: Coords1D = Field( + title="Z Coordinates", + description="1-dimensional array of z coordinates.", ) @property - def to_dict(self): + def to_dict(self) -> dict[str, Any]: """Return a dict of the three Coord1D objects as numpy arrays.""" - return {key: self.dict()[key] for key in "xyz"} + return {key: self.model_dump()[key] for key in "xyz"} @property - def to_list(self): + def to_list(self) -> list[NDArray]: """Return a list of the three Coord1D objects as numpy arrays.""" return list(self.to_dict.values()) @@ -72,7 +84,7 @@ def cell_sizes(self) -> SpatialDataArray: return cell_sizes @cached_property - def cell_size_meshgrid(self): + def cell_size_meshgrid(self) -> NDArray: """Returns an N-dimensional grid where N is the number of coordinate arrays that have more than one element. Each grid element corresponds to the size of the mesh cell in N-dimensions and 1 for N=0.""" coord_dict = self.to_dict @@ -205,6 +217,45 @@ def _interp_from_unstructured( return interp_array + def get_bounding_indices( + self, coordinate: Coordinate, side: Literal["left", "right"], buffer: int = 0 + ) -> tuple[int, int, int]: + """Find the bounding indices up to a buffer corresponding to the supplied coordinate. For x, y, z + values supplied in coordinate, look for index into the x, y, and z coordinate arrays such that the + value at that index bounds the supplied coordinate entry on either the 'right' or 'left' side specified by + the side parameter. An optional buffer of number of indices can be specified with the default 0. 
All indices
+        are clamped between 0 and the last valid index of each coordinate array, so they can be used directly to
+        index into the coordinate arrays without going out of bounds."""
+
+        if side not in ("left", "right"):
+            raise ValueError(f"Side should be 'left' or 'right', but got side={side}.")
+
+        coords = self.to_dict
+        coord_indices = []
+        for idx, key in enumerate("xyz"):
+            coords_for_axis = coords[key]
+            index = np.searchsorted(coords_for_axis, coordinate[idx], side=side)
+
+            if side == "left":
+                index -= 1 + buffer
+            else:
+                index += buffer
+
+            coord_indices.append(np.clip(index, 0, len(coords_for_axis) - 1))
+
+        return tuple(coord_indices)
+
+    def get_bounding_values(
+        self, coordinate: Coordinate, side: Literal["left", "right"], buffer: int = 0
+    ) -> Coordinate:
+        """Find the bounding values corresponding to the supplied coordinate, i.e. the entries of the
+        coordinate arrays at the indices returned by `get_bounding_indices`."""
+
+        bounding_indices = self.get_bounding_indices(coordinate, side, buffer)
+
+        coords = self.to_dict
+        return tuple(coords[key][bounding_indices[idx]] for idx, key in enumerate("xyz"))
+
     def spatial_interp(
         self,
         array: Union[SpatialDataArray, ScalarFieldDataArray, UnstructuredGridDatasetType],
@@ -279,20 +330,17 @@ class FieldGrid(Tidy3dBaseModel):
     >>> field_grid = FieldGrid(x=coords, y=coords, z=coords)
     """
 
-    x: Coords = pd.Field(
-        ...,
+    x: Coords = Field(
         title="X Positions",
         description="x,y,z coordinates of the locations of the x-component of a vector field.",
     )
 
-    y: Coords = pd.Field(
-        ...,
+    y: Coords = Field(
         title="Y Positions",
         description="x,y,z coordinates of the locations of the y-component of a vector field.",
     )
 
-    z: Coords = pd.Field(
-        ...,
+    z: Coords = Field(
         title="Z Positions",
         description="x,y,z coordinates of the locations of the z-component of a vector field.",
     )
@@ -312,20 +360,18 @@ class YeeGrid(Tidy3dBaseModel):
     >>> Ex_coords = yee_grid.E.x
     """
 
-    E: FieldGrid = pd.Field(
-        ...,
+    E: FieldGrid = Field(
         title="Electric Field Grid",
         description="Coordinates of the locations of all three components of the electric field.",
     )
 
-    H: FieldGrid = pd.Field(
-        ...,
+    H: FieldGrid = Field(
         title="Electric Field Grid",
         description="Coordinates of the locations of all three components of the magnetic field.",
     )
 
     @property
-    def grid_dict(self):
+    def grid_dict(self) -> dict[str, Coords]:
         """The Yee grid coordinates associated to various field components as a dictionary."""
         return {
             "Ex": self.E.x,
@@ -352,19 +398,18 @@ class Grid(Tidy3dBaseModel):
     >>> yee_grid = grid.yee
     """
 
-    boundaries: Coords = pd.Field(
-        ...,
+    boundaries: Coords = Field(
         title="Boundary Coordinates",
         description="x,y,z coordinates of the boundaries between cells, defining the FDTD grid.",
     )
 
     @staticmethod
-    def _avg(coords1d: Coords1D):
+    def _avg(coords1d: Coords1D) -> Coords1D:
         """Return average positions of an array of 1D coordinates."""
         return (coords1d[1:] + coords1d[:-1]) / 2.0
 
     @staticmethod
-    def _min(coords1d: Coords1D):
+    def _min(coords1d: Coords1D) -> Coords1D:
         """Return minus positions of 1D coordinates."""
         return coords1d[:-1]
 
@@ -426,7 +471,7 @@ def num_cells(self) -> tuple[int, int, int]:
     >>> grid = Grid(boundaries=coords)
     >>> Nx, Ny, Nz = grid.num_cells
     """
-        return [len(self.boundaries.dict()[dim]) - 1 for dim in "xyz"]
+        return [len(self.boundaries.model_dump()[dim]) - 1 for dim in "xyz"]
 
     @property
     def min_size(self) -> float:
@@ -487,7 +532,7 @@ def _dual_steps(self) -> Coords:
         applied.
""" - primal_steps = {dim: self._primal_steps.dict()[dim] for dim in "xyz"} + primal_steps = {dim: self._primal_steps.model_dump()[dim] for dim in "xyz"} dsteps = {key: (psteps + np.roll(psteps, 1)) / 2 for (key, psteps) in primal_steps.items()} return Coords(**dsteps) @@ -538,7 +583,7 @@ def __getitem__(self, coord_key: str) -> Coords: return coord_dict.get(coord_key) - def _yee_e(self, axis: Axis): + def _yee_e(self, axis: Axis) -> Coords: """E field yee lattice sites for axis.""" boundary_coords = self.boundaries.to_dict @@ -552,7 +597,7 @@ def _yee_e(self, axis: Axis): return Coords(**yee_coords) - def _yee_h(self, axis: Axis): + def _yee_h(self, axis: Axis) -> Coords: """H field yee lattice sites for axis.""" boundary_coords = self.boundaries.to_dict @@ -580,7 +625,7 @@ def discretize_inds(self, box: Box, extend: bool = False) -> list[tuple[int, int Returns ------- - List[Tuple[int, int]] + list[tuple[int, int]] The (start, stop) indexes of the cells that intersect with ``box`` in each of the three dimensions. """ @@ -654,6 +699,8 @@ def extended_subspace( reverse = True while ind_beg < 0: + if num_cells == 0: + break if periodic or not reverse: offset = padded_coords[0] - coords[-1] padded_coords = np.concatenate([coords[:-1] + offset, padded_coords]) @@ -667,6 +714,8 @@ def extended_subspace( reverse = True while ind_end >= padded_coords.size: + if num_cells == 0: + break if periodic or not reverse: offset = padded_coords[-1] - coords[0] padded_coords = np.concatenate([padded_coords, coords[1:] + offset]) @@ -678,7 +727,7 @@ def extended_subspace( return padded_coords[ind_beg:ind_end] - def snap_to_box_zero_dim(self, box: Box): + def snap_to_box_zero_dim(self, box: Box) -> Self: """Snap a grid to an exact box position for dimensions for which the box is size 0. If the box location is outside of the grid, an error is raised. @@ -694,6 +743,7 @@ def snap_to_box_zero_dim(self, box: Box): """ boundary_dict = self.boundaries.to_dict.copy() + for dim, center, size in zip("xyz", box.center, box.size): # Overwrite grid boundaries with box center if box is size 0 along dimension if size == 0: @@ -712,7 +762,9 @@ def _translated_copy(self, vector: Coordinate) -> Grid: ) return self.updated_copy(boundaries=boundaries) - def _get_geo_inds(self, geo: Geometry, span_inds: ArrayLike = None, expand_inds: int = 2): + def _get_geo_inds( + self, geo: Geometry, span_inds: ArrayLike = None, expand_inds: int = 2 + ) -> NDArray: """ Get ``geo_inds`` based on a geometry's bounding box, enlarged by ``expand_inds``. If ``span_inds`` is supplied, take the intersection of ``span_inds`` and ``geo``'s bounding @@ -729,7 +781,7 @@ def _get_geo_inds(self, geo: Geometry, span_inds: ArrayLike = None, expand_inds: Returns ------- - List[Tuple[int, int]] + np.ndarray The (start, stop) indexes of the cells for interpolation. 
""" # only interpolate inside the bounding box diff --git a/tidy3d/components/grid/grid_spec.py b/tidy3d/components/grid/grid_spec.py index a5e3f76404..a7810bad77 100644 --- a/tidy3d/components/grid/grid_spec.py +++ b/tidy3d/components/grid/grid_spec.py @@ -3,29 +3,31 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) -from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.geometry.base import Box, ClipOperation -from tidy3d.components.lumped_element import LumpedElementType -from tidy3d.components.source.utils import SourceType from tidy3d.components.structure import MeshOverrideStructure, Structure, StructureType from tidy3d.components.types import ( TYPE_TAG_STR, ArrayFloat1D, ArrayFloat2D, Axis, - Coordinate, CoordinateOptional, - PriorityMode, - Shapely, - Symmetry, Undefined, - annotate_type, ) +from tidy3d.components.types.base import discriminated_union from tidy3d.constants import C_0, MICROMETER, fp_eps, inf from tidy3d.exceptions import SetupError from tidy3d.log import log @@ -34,6 +36,12 @@ from .grid import Coords, Coords1D, Grid from .mesher import GradedMesher, MesherType +if TYPE_CHECKING: + from tidy3d.compat import Self + from tidy3d.components.lumped_element import LumpedElementType + from tidy3d.components.source.utils import SourceType + from tidy3d.components.types import Coordinate, PriorityMode, Shapely, Symmetry + # Scaling factor applied to internally generated lower bound of grid size that is computed from # estimated minimal grid size MIN_STEP_BOUND_SCALE = 0.5 @@ -44,6 +52,8 @@ # Tolerance for distinguishing pec/grid intersections GAP_MESHING_TOL = 1e-3 +CornersAndConvexity = tuple[list[ArrayFloat2D], list[ArrayFloat1D]] + class GridSpec1d(Tidy3dBaseModel, ABC): """Abstract base class, defines 1D grid generation specifications.""" @@ -54,8 +64,8 @@ def make_coords( structures: list[StructureType], symmetry: tuple[Symmetry, Symmetry, Symmetry], periodic: bool, - wavelength: pd.PositiveFloat, - num_pml_layers: tuple[pd.NonNegativeInt, pd.NonNegativeInt], + wavelength: PositiveFloat, + num_pml_layers: tuple[NonNegativeInt, NonNegativeInt], snapping_points: tuple[CoordinateOptional, ...], parse_structures_interval_coords: np.ndarray = None, parse_structures_max_dl_list: np.ndarray = None, @@ -67,9 +77,9 @@ def make_coords( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. periodic : bool @@ -77,9 +87,9 @@ def make_coords( Only relevant for autogrids. wavelength : float Free-space wavelength. - num_pml_layers : Tuple[int, int] + num_pml_layers : tuple[int, int] number of layers in the absorber + and - direction along one dimension. - snapping_points : Tuple[CoordinateOptional, ...] + snapping_points : tuple[CoordinateOptional, ...] 
A set of points that enforce grid boundaries to pass through them. parse_structures_interval_coords : np.ndarray, optional If not None, pre-computed interval coordinates from parsing structures. @@ -135,7 +145,7 @@ def _make_coords_initial( Parameters ---------- - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. **kwargs Other arguments @@ -153,7 +163,7 @@ def _add_pml_to_bounds(num_layers: tuple[int, int], bounds: Coords1D) -> Coords1 Parameters ---------- - num_layers : Tuple[int, int] + num_layers : tuple[int, int] number of layers in the absorber + and - direction along one dimension. bound_coords : np.ndarray coordinates specifying boundaries between cells along one dimension. @@ -186,7 +196,7 @@ def _postprocess_unaligned_grid( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. machine_error_relaxation : bool When operations such as translation are applied to the 1d grids, fix the bounds @@ -254,9 +264,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. Returns @@ -287,15 +297,15 @@ class UniformGrid(GridSpec1d): * `Using automatic nonuniform meshing <../../notebooks/AutoGrid.html>`_ """ - dl: pd.PositiveFloat = pd.Field( - ..., + dl: PositiveFloat = Field( title="Grid Size", description="Grid size for uniform grid generation.", units=MICROMETER, ) - @pd.validator("dl", always=True) - def _validate_dl(cls, val): + @field_validator("dl") + @classmethod + def _validate_dl(cls, val: PositiveFloat) -> PositiveFloat: """ Ensure 'dl' is not too small. """ @@ -319,7 +329,7 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. Returns @@ -350,9 +360,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. Returns @@ -372,8 +382,7 @@ class CustomGridBoundaries(GridSpec1d): >>> grid_1d = CustomGridBoundaries(coords=[-0.2, 0.0, 0.2, 0.4, 0.5, 0.6, 0.7]) """ - coords: Coords1D = pd.Field( - ..., + coords: Coords1D = Field( title="Grid Boundary Coordinates", description="An array of grid boundary coordinates.", units=MICROMETER, @@ -391,7 +400,7 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. Returns @@ -416,9 +425,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. 
Returns @@ -429,8 +438,9 @@ def estimated_min_dl( return min(np.diff(self.coords)) - @pd.validator("coords", always=True) - def _validate_coords(cls, val): + @field_validator("coords") + @classmethod + def _validate_coords(cls, val: Coords1D) -> Coords1D: """ Ensure 'coords' is sorted and has at least 2 entries. """ @@ -455,8 +465,7 @@ class CustomGrid(GridSpec1d): >>> grid_1d = CustomGrid(dl=[0.2, 0.2, 0.1, 0.1, 0.1, 0.2, 0.2]) """ - dl: tuple[pd.PositiveFloat, ...] = pd.Field( - ..., + dl: tuple[PositiveFloat, ...] = Field( title="Customized grid sizes.", description="An array of custom nonuniform grid sizes. The resulting grid is centered on " "the simulation center such that it spans the region " @@ -466,7 +475,7 @@ class CustomGrid(GridSpec1d): units=MICROMETER, ) - custom_offset: float = pd.Field( + custom_offset: Optional[float] = Field( None, title="Customized grid offset.", description="The starting coordinate of the grid which defines the simulation center. " @@ -487,7 +496,7 @@ def _make_coords_initial( ---------- axis : Axis Axis of this direction. - structures : List[StructureType] + structures : list[StructureType] List of structures present in simulation, the first one being the simulation domain. Returns @@ -525,9 +534,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. Returns @@ -541,7 +550,7 @@ def estimated_min_dl( class AbstractAutoGrid(GridSpec1d): """Specification for non-uniform or quasi-uniform grid along a given dimension.""" - max_scale: float = pd.Field( + max_scale: float = Field( 1.4, title="Maximum Grid Size Scaling", description="Sets the maximum ratio between any two consecutive grid steps.", @@ -549,13 +558,13 @@ class AbstractAutoGrid(GridSpec1d): lt=2.0, ) - mesher: MesherType = pd.Field( - GradedMesher(), + mesher: MesherType = Field( + default_factory=GradedMesher, title="Grid Construction Tool", description="The type of mesher to use to generate the grid automatically.", ) - dl_min: pd.NonNegativeFloat = pd.Field( + dl_min: Optional[NonNegativeFloat] = Field( None, title="Lower Bound of Grid Size", description="Lower bound of the grid size along this dimension regardless of " @@ -810,8 +819,7 @@ class QuasiUniformGrid(AbstractAutoGrid): * `Using automatic nonuniform meshing <../../notebooks/AutoGrid.html>`_ """ - dl: pd.PositiveFloat = pd.Field( - ..., + dl: PositiveFloat = Field( title="Grid Size", description="Grid size for quasi-uniform grid generation. Grid size at some locations can be " "slightly smaller.", @@ -863,9 +871,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. 
Returns @@ -901,14 +909,14 @@ class AutoGrid(AbstractAutoGrid): * `Numerical dispersion in FDTD `_ """ - min_steps_per_wvl: float = pd.Field( + min_steps_per_wvl: float = Field( 10.0, title="Minimal Number of Steps Per Wavelength", description="Minimal number of steps per wavelength in each medium.", ge=6.0, ) - min_steps_per_sim_size: float = pd.Field( + min_steps_per_sim_size: float = Field( 10.0, title="Minimal Number of Steps Per Simulation Domain Size", description="Minimal number of steps per longest edge length of simulation domain " @@ -955,9 +963,9 @@ def estimated_min_dl( ---------- wavelength : float Wavelength to use for the step size and for dispersive media epsilon. - structure_list : List[Structure] + structure_list : list[Structure] List of structures present in the simulation. - sim_size : Tuple[float, 3] + sim_size : tuple[float, 3] Simulation domain size. Returns @@ -994,27 +1002,27 @@ class GridRefinement(Tidy3dBaseModel): """ - refinement_factor: Optional[pd.PositiveFloat] = pd.Field( + refinement_factor: Optional[PositiveFloat] = Field( None, title="Mesh Refinement Factor", description="Refine grid step size in vacuum by this factor.", ) - dl: Optional[pd.PositiveFloat] = pd.Field( + dl: Optional[PositiveFloat] = Field( None, title="Grid Size", description="Grid step size in the refined region.", units=MICROMETER, ) - num_cells: pd.PositiveInt = pd.Field( + num_cells: PositiveInt = Field( 3, title="Number of Refined Grid Cells", description="Number of grid cells in the refinement region.", ) @property - def _refinement_factor(self) -> pd.PositiveFloat: + def _refinement_factor(self) -> PositiveFloat: """Refinement factor applied internally.""" if self.refinement_factor is None and self.dl is None: return DEFAULT_REFINEMENT_FACTOR @@ -1101,20 +1109,19 @@ class LayerRefinementSpec(Box): """ - axis: Axis = pd.Field( - ..., + axis: Axis = Field( title="Axis", description="Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z).", ) - min_steps_along_axis: Optional[pd.PositiveFloat] = pd.Field( + min_steps_along_axis: Optional[PositiveFloat] = Field( None, title="Minimal Number Of Steps Along Axis", description="If not ``None`` and the thickness of the layer is nonzero, set minimal " "number of steps discretizing the layer thickness.", ) - bounds_refinement: Optional[GridRefinement] = pd.Field( + bounds_refinement: Optional[GridRefinement] = Field( None, title="Mesh Refinement Factor Around Layer Bounds", description="If not ``None``, refine mesh around minimum and maximum positions " @@ -1122,35 +1129,35 @@ class LayerRefinementSpec(Box): "refinement here is only applied if it sets a smaller grid size.", ) - bounds_snapping: Optional[Literal["bounds", "lower", "upper", "center"]] = pd.Field( + bounds_snapping: Optional[Literal["bounds", "lower", "upper", "center"]] = Field( "lower", title="Placing Grid Snapping Point Along Axis", description="If not ``None``, enforcing grid boundaries to pass through ``lower``, " "``center``, or ``upper`` position of the layer; or both ``lower`` and ``upper`` with ``bounds``.", ) - corner_finder: Optional[CornerFinderSpec] = pd.Field( - CornerFinderSpec(), + corner_finder: Optional[CornerFinderSpec] = Field( + default_factory=CornerFinderSpec, title="Inplane Corner Detection Specification", description="Specification for inplane corner detection. 
Inplane mesh refinement "
         "is based on the coordinates of those corners.",
     )
 
-    corner_snapping: bool = pd.Field(
+    corner_snapping: bool = Field(
         True,
         title="Placing Grid Snapping Point At Corners",
         description="If ``True`` and ``corner_finder`` is not ``None``, enforcing inplane "
         "grid boundaries to pass through corners of geometries specified by ``corner_finder``.",
     )
 
-    corner_refinement: Optional[GridRefinement] = pd.Field(
-        GridRefinement(),
+    corner_refinement: Optional[GridRefinement] = Field(
+        default_factory=GridRefinement,
         title="Inplane Mesh Refinement Factor Around Corners",
         description="If not ``None`` and ``corner_finder`` is not ``None``, refine mesh around "
         "corners of geometries specified by ``corner_finder``. ",
     )
 
-    refinement_inside_sim_only: bool = pd.Field(
+    refinement_inside_sim_only: bool = Field(
         True,
         title="Apply Refinement Only To Features Inside Simulation Domain",
         description="If ``True``, only apply mesh refinement to features such as corners inside "
@@ -1159,21 +1166,21 @@
         "and the projection of the simulation domain overlaps.",
     )
 
-    gap_meshing_iters: pd.NonNegativeInt = pd.Field(
+    gap_meshing_iters: NonNegativeInt = Field(
         1,
         title="Gap Meshing Iterations",
         description="Number of recursive iterations for resolving thin gaps. "
         "The underlying algorithm detects gaps contained in a single cell and places a snapping plane at the gaps's centers.",
     )
 
-    dl_min_from_gap_width: bool = pd.Field(
+    dl_min_from_gap_width: bool = Field(
         True,
         title="Set ``dl_min`` from Estimated Gap Width",
         description="Take into account autodetected minimal PEC gap width when determining ``dl_min``. "
         "This only applies if ``dl_min`` in ``AutoGrid`` specification is not set.",
     )
 
-    interior_disjoint_geometries: bool = pd.Field(
+    interior_disjoint_geometries: bool = Field(
         True,
         title="Geometries Are Interior-Disjoint",
         description="If ``True``, geometries made of different materials on the plane must not be overlapping. "
@@ -1181,30 +1188,31 @@
         "of corner finder when there are many structures crossing the plane.",
     )
 
-    @pd.validator("axis", always=True)
-    @skip_if_fields_missing(["size"])
-    def _finite_size_along_axis(cls, val, values):
-        """size must be finite along axis."""
-        if np.isinf(values["size"][val]):
+    @model_validator(mode="after")
+    def _finite_size_along_axis(self) -> Self:
+        """size must be finite along axis."""
+        if self.size is None:
+            return self
+        if np.isinf(self.size[self.axis]):
             raise SetupError("'size' must take finite values along 'axis' dimension.")
-        return val
+        return self
 
     @classmethod
     def from_layer_bounds(
         cls,
         axis: Axis,
         bounds: tuple[float, float],
-        min_steps_along_axis: np.PositiveFloat = None,
+        min_steps_along_axis: PositiveFloat = None,
         bounds_refinement: GridRefinement = None,
         bounds_snapping: Literal["bounds", "lower", "upper", "center"] = "lower",
        corner_finder: Union[CornerFinderSpec, None, object] = Undefined,
        corner_snapping: bool = True,
        corner_refinement: Union[GridRefinement, None, object] = Undefined,
        refinement_inside_sim_only: bool = True,
-        gap_meshing_iters: pd.NonNegativeInt = 1,
+        gap_meshing_iters: NonNegativeInt = 1,
        dl_min_from_gap_width: bool = True,
        **kwargs: Any,
-    ):
+    ) -> Self:
        """Constructs a :class:`LayerRefinementSpec` that is unbounded in inplane dimensions
        from bounds along layer thickness dimension.

@@ -1212,9 +1220,9 @@ def from_layer_bounds(
        ----------
        axis : Axis
            Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z).
- bounds : Tuple[float, float] + bounds : tuple[float, float] Minimum and maximum positions of the layer along axis dimension. - min_steps_along_axis : np.PositiveFloat = None + min_steps_along_axis : PositiveFloat = None Minimal number of steps along axis. bounds_refinement : GridRefinement = None Mesh refinement factor around layer bounds. @@ -1270,29 +1278,29 @@ def from_bounds( rmin: Coordinate, rmax: Coordinate, axis: Axis = None, - min_steps_along_axis: np.PositiveFloat = None, + min_steps_along_axis: PositiveFloat = None, bounds_refinement: GridRefinement = None, bounds_snapping: Literal["bounds", "lower", "upper", "center"] = "lower", corner_finder: CornerFinderSpec = Undefined, corner_snapping: bool = True, corner_refinement: GridRefinement = Undefined, refinement_inside_sim_only: bool = True, - gap_meshing_iters: pd.NonNegativeInt = 1, + gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, **kwargs: Any, - ): + ) -> Self: """Constructs a :class:`LayerRefinementSpec` from minimum and maximum coordinate bounds. Parameters ---------- - rmin : Tuple[float, float, float] + rmin : tuple[float, float, float] (x, y, z) coordinate of the minimum values. - rmax : Tuple[float, float, float] + rmax : tuple[float, float, float] (x, y, z) coordinate of the maximum values. axis : Axis Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z). If ``None``, apply the dimension along which the layer thas smallest thickness. - min_steps_along_axis : np.PositiveFloat = None + min_steps_along_axis : PositiveFloat = None Minimal number of steps along axis. bounds_refinement : GridRefinement = None Mesh refinement factor around layer bounds. @@ -1347,27 +1355,27 @@ def from_structures( cls, structures: list[Structure], axis: Axis = None, - min_steps_along_axis: np.PositiveFloat = None, + min_steps_along_axis: PositiveFloat = None, bounds_refinement: GridRefinement = None, bounds_snapping: Literal["bounds", "lower", "upper", "center"] = "lower", corner_finder: CornerFinderSpec = Undefined, corner_snapping: bool = True, corner_refinement: GridRefinement = Undefined, refinement_inside_sim_only: bool = True, - gap_meshing_iters: pd.NonNegativeInt = 1, + gap_meshing_iters: NonNegativeInt = 1, dl_min_from_gap_width: bool = True, **kwargs: Any, - ): + ) -> Self: """Constructs a :class:`LayerRefinementSpec` from the bounding box of a list of structures. Parameters ---------- - structures : List[Structure] + structures : list[Structure] A list of structures whose overall bounding box is used to define mesh refinement axis : Axis Specifies dimension of the layer normal axis (0,1,2) -> (x,y,z). If ``None``, apply the dimension along which the bounding box of the structures thas smallest thickness. - min_steps_along_axis : np.PositiveFloat = None + min_steps_along_axis : PositiveFloat = None Minimal number of steps along axis. bounds_refinement : GridRefinement = None Mesh refinement factor around layer bounds. @@ -1455,8 +1463,8 @@ def suggested_dl_min( structures: list[Structure], sim_bounds: tuple, boundary_type: tuple, - cached_merged_geos=None, - cached_corners_and_convexity=None, + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, ) -> float: """Suggested lower bound of grid step size for this layer. 
@@ -1518,8 +1526,8 @@ def generate_snapping_points( structure_list: list[Structure], sim_bounds: tuple, boundary_type: tuple, - cached_corners_and_convexity=None, - cached_merged_geos=None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, ) -> list[CoordinateOptional]: """generate snapping points for mesh refinement.""" snapping_points = self._snapping_points_along_axis @@ -1539,8 +1547,8 @@ def generate_override_structures( structure_list: list[Structure], sim_bounds: tuple, boundary_type: tuple, - cached_corners_and_convexity=None, - cached_merged_geos=None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, ) -> list[MeshOverrideStructure]: """Generate mesh override structures for mesh refinement.""" return self._override_structures_along_axis( @@ -1715,9 +1723,9 @@ def _dl_min_from_smallest_feature( structure_list: list[Structure], sim_bounds: tuple, boundary_type: tuple, - cached_merged_geos=None, - cached_corners_and_convexity=None, - ): + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, + ) -> float: """Calculate `dl_min` suggestion based on smallest feature size.""" if cached_corners_and_convexity is None: @@ -1771,8 +1779,8 @@ def _corners( structure_list: list[Structure], sim_bounds: tuple, boundary_type: tuple, - cached_corners_and_convexity=None, - cached_merged_geos=None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, ) -> list[CoordinateOptional]: """Inplane corners in 3D coordinate.""" if self.corner_finder is None: @@ -1832,8 +1840,8 @@ def _override_structures_inplane( grid_size_in_vacuum: float, sim_bounds: tuple, boundary_type: tuple, - cached_corners_and_convexity=None, - cached_merged_geos=None, + cached_corners_and_convexity: Optional[CornersAndConvexity] = None, + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, ) -> list[MeshOverrideStructure]: """Inplane mesh override structures for refining mesh around corners.""" if self.corner_refinement is None: @@ -1909,8 +1917,12 @@ def _override_structures_along_axis( return override_structures def _find_vertical_intersections( - self, grid_x_coords, grid_y_coords, poly_vertices, boundary - ) -> tuple[list[tuple[int, int]], list[float]]: + self, + grid_x_coords: ArrayFloat1D, + grid_y_coords: ArrayFloat1D, + poly_vertices: ArrayFloat2D, + boundary: tuple[Optional[str], Optional[str]], + ) -> tuple[np.typing.NDArray[np.int_], np.typing.NDArray[np.float64]]: """Detect intersection points of single polygon and vertical grid lines.""" # indices of cells that contain intersection with grid lines (left edge of a cell) @@ -2075,12 +2087,24 @@ def _find_vertical_intersections( np.zeros(len(cells_ij_one_side)), ] ) + else: + cells_ij = np.empty((0, 2), dtype=int) + cells_dy = np.empty(0, dtype=float) return cells_ij, cells_dy def _process_poly( - self, grid_x_coords, grid_y_coords, poly_vertices, boundaries - ) -> tuple[list[tuple[int, int]], list[float], list[tuple[int, int]], list[float]]: + self, + grid_x_coords: ArrayFloat1D, + grid_y_coords: ArrayFloat1D, + poly_vertices: ArrayFloat2D, + boundaries: tuple[tuple[Optional[str], Optional[str]], tuple[Optional[str], Optional[str]]], + ) -> tuple[ + np.typing.NDArray[np.int_], + np.typing.NDArray[np.float64], + 
np.typing.NDArray[np.int_],
+        np.typing.NDArray[np.float64],
+    ]:
        """Detect intersection points of single polygon and grid lines."""

        # find cells that contain intersections of vertical grid lines
@@ -2102,8 +2126,17 @@ def _process_poly(
        return v_cells_ij, v_cells_dy, h_cells_ij, h_cells_dx

    def _process_slice(
-        self, x, y, merged_geos, boundaries
-    ) -> tuple[list[tuple[int, int]], list[float], list[tuple[int, int]], list[float]]:
+        self,
+        x: ArrayFloat1D,
+        y: ArrayFloat1D,
+        merged_geos: list[tuple[Any, Shapely]],
+        boundaries: list[list[Optional[str]]],
+    ) -> tuple[
+        np.typing.NDArray[np.int_],
+        np.typing.NDArray[np.float64],
+        np.typing.NDArray[np.int_],
+        np.typing.NDArray[np.float64],
+    ]:
        """Detect intersection points of geometries boundaries and grid lines."""

        # cells that contain intersections of vertical grid lines
@@ -2177,16 +2210,25 @@ def _process_slice(
        if len(v_cells_ij) > 0:
            v_cells_ij = np.concatenate(v_cells_ij)
            v_cells_dy = np.concatenate(v_cells_dy)
+        else:
+            v_cells_ij = np.empty((0, 2), dtype=int)
+            v_cells_dy = np.empty(0, dtype=float)

        if len(h_cells_ij) > 0:
            h_cells_ij = np.concatenate(h_cells_ij)
            h_cells_dx = np.concatenate(h_cells_dx)
+        else:
+            h_cells_ij = np.empty((0, 2), dtype=int)
+            h_cells_dx = np.empty(0, dtype=float)

        return v_cells_ij, v_cells_dy, h_cells_ij, h_cells_dx

    def _generate_horizontal_snapping_lines(
-        self, grid_y_coords, intersected_cells_ij, relative_vert_disp
-    ) -> tuple[list[CoordinateOptional], float]:
+        self,
+        grid_y_coords: ArrayFloat1D,
+        intersected_cells_ij: np.typing.NDArray[np.int_],
+        relative_vert_disp: np.typing.NDArray[np.float64],
+    ) -> tuple[list[float], float]:
        """Convert a list of intersections of vertical grid lines, given as coordinates of cells
        and relative vertical displacement inside each cell, into locations of snapping lines that
        resolve thin gaps and strips.
@@ -2263,8 +2305,15 @@ def _generate_horizontal_snapping_lines(
        return snapping_lines_y, min_gap_width

    def _resolve_gaps(
-        self, grid: Grid, merged_geos: list[tuple[Any, Shapely]], boundary_type: tuple
-    ) -> tuple[list[CoordinateOptional], float]:
+        self,
+        grid: Grid,
+        merged_geos: list[tuple[Any, Shapely]],
+        boundary_type: tuple[
+            tuple[Optional[str], Optional[str]],
+            tuple[Optional[str], Optional[str]],
+            tuple[Optional[str], Optional[str]],
+        ],
+    ) -> tuple[tuple[CoordinateOptional, ...], float]:
        """
        Detect underresolved gaps and place snapping lines in them. Also return the detected
        minimal gap width.
@@ -2280,7 +2329,7 @@ def _resolve_gaps(
        Returns
        -------
-        tuple[list[CoordinateOptional], float]
+        tuple[tuple[CoordinateOptional, ...], float]
            List of snapping lines and the detected minimal gap width.
""" @@ -2391,28 +2440,28 @@ class GridSpec(Tidy3dBaseModel): * `Numerical dispersion in FDTD `_ """ - grid_x: GridType = pd.Field( - AutoGrid(), + grid_x: GridType = Field( + default_factory=AutoGrid, title="Grid specification along x-axis", description="Grid specification along x-axis", discriminator=TYPE_TAG_STR, ) - grid_y: GridType = pd.Field( - AutoGrid(), + grid_y: GridType = Field( + default_factory=AutoGrid, title="Grid specification along y-axis", description="Grid specification along y-axis", discriminator=TYPE_TAG_STR, ) - grid_z: GridType = pd.Field( - AutoGrid(), + grid_z: GridType = Field( + default_factory=AutoGrid, title="Grid specification along z-axis", description="Grid specification along z-axis", discriminator=TYPE_TAG_STR, ) - wavelength: float = pd.Field( + wavelength: Optional[PositiveFloat] = Field( None, title="Free-space wavelength", description="Free-space wavelength for automatic nonuniform grid. It can be ``None`` " @@ -2423,7 +2472,7 @@ class GridSpec(Tidy3dBaseModel): units=MICROMETER, ) - override_structures: tuple[annotate_type(StructureType), ...] = pd.Field( + override_structures: tuple[discriminated_union(StructureType), ...] = Field( (), title="Grid specification override structures", description="A set of structures that is added on top of the simulation structures in " @@ -2433,7 +2482,7 @@ class GridSpec(Tidy3dBaseModel): "uses :class:`.AutoGrid` or :class:`.QuasiUniformGrid`.", ) - snapping_points: tuple[CoordinateOptional, ...] = pd.Field( + snapping_points: tuple[CoordinateOptional, ...] = Field( (), title="Grid specification snapping_points", description="A set of points that enforce grid boundaries to pass through them. " @@ -2444,7 +2493,7 @@ class GridSpec(Tidy3dBaseModel): "uses :class:`.AutoGrid` or :class:`.QuasiUniformGrid`.", ) - layer_refinement_specs: tuple[LayerRefinementSpec, ...] = pd.Field( + layer_refinement_specs: tuple[LayerRefinementSpec, ...] = Field( (), title="Mesh Refinement In Layered Structures", description="Automatic mesh refinement according to layer specifications. The material " @@ -2474,7 +2523,7 @@ def custom_grid_used(self) -> bool: return np.any([isinstance(mesh, (CustomGrid, CustomGridBoundaries)) for mesh in grid_list]) @staticmethod - def wavelength_from_sources(sources: list[SourceType]) -> pd.PositiveFloat: + def wavelength_from_sources(sources: list[SourceType]) -> PositiveFloat: """Define a wavelength based on supplied sources. Called if auto mesh is used and ``self.wavelength is None``.""" @@ -2502,26 +2551,23 @@ def layer_refinement_used(self) -> bool: return len(self.layer_refinement_specs) > 0 @property - def snapping_points_used(self) -> list[bool, bool, bool]: + def snapping_points_used(self) -> tuple[bool, bool, bool]: """Along each axis, ``True`` if any snapping point is used. However, it is still ``False`` if all snapping points take value ``None`` along the axis. 
""" # empty list - if len(self.snapping_points) == 0: - return [False] * 3 - - snapping_used = [False] * 3 - for point in self.snapping_points: - for ind_coord, coord in enumerate(point): - if snapping_used[ind_coord]: - continue - if coord is not None: - snapping_used[ind_coord] = True - return snapping_used + if not self.snapping_points: + return False, False, False + + x_used = any(p[0] is not None for p in self.snapping_points) + y_used = any(p[1] is not None for p in self.snapping_points) + z_used = any(p[2] is not None for p in self.snapping_points) + + return x_used, y_used, z_used @property - def override_structures_used(self) -> list[bool, bool, bool]: + def override_structures_used(self) -> tuple[bool, bool, bool]: """Along each axis, ``True`` if any override structure is used. However, it is still ``False`` if only :class:`.MeshOverrideStructure` is supplied, and their ``dl[axis]`` all take the ``None`` value. @@ -2529,18 +2575,18 @@ def override_structures_used(self) -> list[bool, bool, bool]: # empty override_structure list if len(self.override_structures) == 0: - return [False] * 3 + return False, False, False override_used = [False] * 3 for structure in self.override_structures: # override used in all axes if any `Structure` is present if isinstance(structure, Structure): - return [True] * 3 + return True, True, True for dl_axis, dl in enumerate(structure.dl): if (not override_used[dl_axis]) and (dl is not None): override_used[dl_axis] = True - return override_used + return tuple(override_used) def internal_snapping_points( self, @@ -2548,21 +2594,21 @@ def internal_snapping_points( lumped_elements: list[LumpedElementType], boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]], sim_bounds: tuple, - cached_corners_and_convexity=None, - cached_merged_geos=None, + cached_corners_and_convexity: Optional[list[CornersAndConvexity]] = None, + cached_merged_geos: Optional[list[tuple[Any, Shapely]]] = None, ) -> list[CoordinateOptional]: """Internal snapping points. So far, internal snapping points are generated by `layer_refinement_specs` and lumped element. Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of physical structures. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]] Boundary type of the simulation domain. - cached_corners_and_convexity : Optional[list[CachedCornersAndConvexity]] + cached_corners_and_convexity : Optional[list[CornersAndConvexity]] Cached corners and convexity data. cached_merged_geos : Optional[list[list[tuple[Any, Shapely]]]] Cached merged geometries for each layer. If None, will be computed. @@ -2571,7 +2617,7 @@ def internal_snapping_points( Returns ------- - List[CoordinateOptional] + list[CoordinateOptional] List of snapping points coordinates. """ @@ -2616,18 +2662,18 @@ def all_snapping_points( Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of physical structures. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]] Boundary type of the simulation domain. - internal_snapping_points : List[CoordinateOptional] + internal_snapping_points : list[CoordinateOptional] If `None`, recomputes internal snapping points. 
        Returns
        -------
-        List[CoordinateOptional]
+        list[CoordinateOptional]
            List of snapping points coordinates.
        """
@@ -2648,27 +2694,27 @@ def external_override_structures(self) -> list[StructureType]:
    def internal_override_structures(
        self,
        structures: list[Structure],
-        wavelength: pd.PositiveFloat,
+        wavelength: PositiveFloat,
        sim_bounds: tuple,
        lumped_elements: list[LumpedElementType],
        boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]],
-        cached_corners_and_convexity=None,
-        cached_merged_geos=None,
+        cached_corners_and_convexity: Optional[list[CornersAndConvexity]] = None,
+        cached_merged_geos: Optional[list[list[tuple[Any, Shapely]]]] = None,
    ) -> list[StructureType]:
        """Internal mesh override structures. So far, internal override structures are generated
        by `layer_refinement_specs` and lumped elements.

        Parameters
        ----------
-        structures : List[Structure]
+        structures : list[Structure]
            List of structures, with the simulation structure being the first item.
-        wavelength : pd.PositiveFloat
+        wavelength : PositiveFloat
            Wavelength to use for minimal step size in vacuum.
        lumped_elements : List[LumpedElementType]
            List of lumped elements.
        boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]]
            Boundary type of the simulation domain.
-        cached_corners_and_convexity : Optional[list[CachedCornersAndConvexity]]
+        cached_corners_and_convexity : Optional[list[CornersAndConvexity]]
            Cached corners and convexity data.
        cached_merged_geos : Optional[list[list[tuple[Any, Shapely]]]]
            Cached merged geometries for each layer. If None, will be computed.
@@ -2677,7 +2723,7 @@

        Returns
        -------
-        List[StructureType]
+        list[StructureType]
            List of override structures.
        """
@@ -2715,7 +2761,7 @@ def internal_override_structures(
    def all_override_structures(
        self,
        structures: list[Structure],
-        wavelength: pd.PositiveFloat,
+        wavelength: PositiveFloat,
        lumped_elements: list[LumpedElementType],
        boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]],
        sim_bounds: tuple,
@@ -2727,9 +2773,9 @@

        Parameters
        ----------
-        structures : List[Structure]
+        structures : list[Structure]
            List of structures, with the simulation structure being the first item.
-        wavelength : pd.PositiveFloat
+        wavelength : PositiveFloat
            Wavelength to use for minimal step size in vacuum.
        lumped_elements : List[LumpedElementType]
            List of lumped elements.
@@ -2739,11 +2785,11 @@
            Boundary type of the simulation domain.
        structure_priority_mode : PriorityMode
            Structure priority setting.
-        internal_override_structures : List[MeshOverrideStructure]
+        internal_override_structures : list[MeshOverrideStructure]
            If `None`, recomputes internal override structures.

        Returns
        -------
-        List[StructureType]
+        list[StructureType]
            List of sorted override structures.
""" @@ -2761,7 +2807,7 @@ def all_override_structures( def _get_all_structures_affecting_grid( self, structures: list[Structure], - wavelength: pd.PositiveFloat, + wavelength: PositiveFloat, lumped_elements: list[LumpedElementType], boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]], sim_bounds: tuple, @@ -2881,7 +2927,7 @@ def make_grid( symmetry: tuple[Symmetry, Symmetry, Symmetry], periodic: tuple[bool, bool, bool], sources: list[SourceType], - num_pml_layers: list[tuple[pd.NonNegativeInt, pd.NonNegativeInt]], + num_pml_layers: list[tuple[NonNegativeInt, NonNegativeInt]], lumped_elements: list[LumpedElementType] = (), internal_override_structures: Optional[list[MeshOverrideStructure]] = None, internal_snapping_points: Optional[list[CoordinateOptional]] = None, @@ -2896,26 +2942,26 @@ def make_grid( Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures present in the simulation. The first structure must be the simulation geometry with the simulation background medium. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. - periodic: Tuple[bool, bool, bool] + periodic: tuple[bool, bool, bool] Apply periodic boundary condition or not along each of the dimensions. Only relevant for autogrids. - sources : List[SourceType] + sources : list[SourceType] List of sources. - num_pml_layers : List[Tuple[float, float]] + num_pml_layers : list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. - internal_override_structures : List[MeshOverrideStructure] - If `None`, recomputes internal override structures. - internal_snapping_points : List[CoordinateOptional] - If `None`, recomputes internal snapping points. - boundary_types : Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]] = [[None, None], [None, None], [None, None]] + internal_override_structures : list[MeshOverrideStructure] + If ``None``, recomputes internal override structures. + internal_snapping_points : list[CoordinateOptional] + If ``None``, recomputes internal snapping points. + boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]] = [[None, None], [None, None], [None, None]] Type of boundary conditions along each dimension: "pec/pmc", "periodic", or None for any other. This is relevant only for gap meshing. structure_priority_mode : PriorityMode @@ -2947,7 +2993,7 @@ def _make_grid_and_snapping_lines( symmetry: tuple[Symmetry, Symmetry, Symmetry], periodic: tuple[bool, bool, bool], sources: list[SourceType], - num_pml_layers: list[tuple[pd.NonNegativeInt, pd.NonNegativeInt]], + num_pml_layers: list[tuple[NonNegativeInt, NonNegativeInt]], lumped_elements: list[LumpedElementType] = (), internal_override_structures: Optional[list[MeshOverrideStructure]] = None, internal_snapping_points: Optional[list[CoordinateOptional]] = None, @@ -2964,26 +3010,26 @@ def _make_grid_and_snapping_lines( Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures present in the simulation. The first structure must be the simulation geometry with the simulation background medium. 
- symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. - periodic: Tuple[bool, bool, bool] + periodic: tuple[bool, bool, bool] Apply periodic boundary condition or not along each of the dimensions. Only relevant for autogrids. - sources : List[SourceType] + sources : list[SourceType] List of sources. - num_pml_layers : List[Tuple[float, float]] + num_pml_layers : list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. - internal_override_structures : List[MeshOverrideStructure] + internal_override_structures : list[MeshOverrideStructure] If `None`, recomputes internal override structures. - internal_snapping_points : List[CoordinateOptional] + internal_snapping_points : list[CoordinateOptional] If `None`, recomputes internal snapping points. - boundary_types : Tuple[Tuple[str, str], Tuple[str, str], Tuple[str, str]] = [[None, None], [None, None], [None, None]] + boundary_types : tuple[tuple[str, str], tuple[str, str], tuple[str, str]] = [[None, None], [None, None], [None, None]] Type of boundary conditions along each dimension: "pec/pmc", "periodic", or None for any other. This is relevant only for gap meshing. structure_priority_mode : PriorityMode @@ -2993,7 +3039,7 @@ def _make_grid_and_snapping_lines( Returns ------- - Tuple[Grid, List[CoordinateOptional]]: + tuple[Grid, list[CoordinateOptional]]: Entire simulation grid and snapping points generated during iterative gap meshing. """ @@ -3127,12 +3173,12 @@ def _make_grid_one_iteration( symmetry: tuple[Symmetry, Symmetry, Symmetry], periodic: tuple[bool, bool, bool], sources: list[SourceType], - num_pml_layers: list[tuple[pd.NonNegativeInt, pd.NonNegativeInt]], + num_pml_layers: list[tuple[NonNegativeInt, NonNegativeInt]], boundary_types: tuple[tuple[str, str], tuple[str, str], tuple[str, str]], lumped_elements: list[LumpedElementType] = (), internal_override_structures: Optional[list[MeshOverrideStructure]] = None, internal_snapping_points: Optional[list[CoordinateOptional]] = None, - dl_min_from_gaps: pd.PositiveFloat = inf, + dl_min_from_gaps: PositiveFloat = inf, structure_priority_mode: PriorityMode = "equal", parse_structures_interval_coords: Optional[list[np.ndarray]] = None, parse_structures_max_dl_list: Optional[list[np.ndarray]] = None, @@ -3141,26 +3187,26 @@ def _make_grid_one_iteration( Parameters ---------- - structures : List[Structure] + structures : list[Structure] List of structures present in the simulation. The first structure must be the simulation geometry with the simulation background medium. - symmetry : Tuple[Symmetry, Symmetry, Symmetry] + symmetry : tuple[Symmetry, Symmetry, Symmetry] Reflection symmetry across a plane bisecting the simulation domain normal to each of the three axes. - periodic: Tuple[bool, bool, bool] + periodic: tuple[bool, bool, bool] Apply periodic boundary condition or not along each of the dimensions. Only relevant for autogrids. - sources : List[SourceType] + sources : list[SourceType] List of sources. - num_pml_layers : List[Tuple[float, float]] + num_pml_layers : list[tuple[float, float]] List containing the number of absorber layers in - and + boundaries. - lumped_elements : List[LumpedElementType] + lumped_elements : list[LumpedElementType] List of lumped elements. 
-        internal_override_structures : List[MeshOverrideStructure]
+        internal_override_structures : list[MeshOverrideStructure]
            If `None`, recomputes internal override structures.
-        internal_snapping_points : List[CoordinateOptional]
+        internal_snapping_points : list[CoordinateOptional]
            If `None`, recomputes internal snapping points.
-        dl_min_from_gaps : pd.PositiveFloat
+        dl_min_from_gaps : PositiveFloat
            Minimal grid size computed based on autodetected gaps.
        structure_priority_mode : PriorityMode
            Structure priority setting.
@@ -3298,39 +3344,39 @@ def from_grid(cls, grid: Grid) -> GridSpec:

    @classmethod
    def auto(
        cls,
-        wavelength: pd.PositiveFloat = None,
-        min_steps_per_wvl: pd.PositiveFloat = 10.0,
-        max_scale: pd.PositiveFloat = 1.4,
+        wavelength: Optional[PositiveFloat] = None,
+        min_steps_per_wvl: PositiveFloat = 10.0,
+        max_scale: PositiveFloat = 1.4,
        override_structures: list[StructureType] = (),
        snapping_points: tuple[CoordinateOptional, ...] = (),
        layer_refinement_specs: list[LayerRefinementSpec] = (),
-        dl_min: pd.NonNegativeFloat = 0.0,
-        min_steps_per_sim_size: pd.PositiveFloat = 10.0,
+        dl_min: NonNegativeFloat = 0.0,
+        min_steps_per_sim_size: PositiveFloat = 10.0,
        mesher: MesherType = Undefined,
    ) -> GridSpec:
        """Use the same :class:`.AutoGrid` along each of the three directions.

        Parameters
        ----------
-        wavelength : pd.PositiveFloat, optional
+        wavelength : PositiveFloat, optional
            Free-space wavelength for automatic nonuniform grid. It can be 'None'
            if there is at least one source in the simulation, in which case it is defined by
            the source central frequency.
-        min_steps_per_wvl : pd.PositiveFloat, optional
+        min_steps_per_wvl : PositiveFloat, optional
            Minimal number of steps per wavelength in each medium.
-        max_scale : pd.PositiveFloat, optional
+        max_scale : PositiveFloat, optional
            Sets the maximum ratio between any two consecutive grid steps.
-        override_structures : List[StructureType]
+        override_structures : list[StructureType]
            A list of structures that is added on top of the simulation structures in
            the process of generating the grid. This can be used to refine the grid or make it
            coarser depending on the expected need for higher/lower resolution regions.
-        snapping_points : Tuple[CoordinateOptional, ...]
+        snapping_points : tuple[CoordinateOptional, ...]
            A set of points that enforce grid boundaries to pass through them.
-        layer_refinement_specs: List[LayerRefinementSpec]
+        layer_refinement_specs: list[LayerRefinementSpec]
            Mesh refinement according to layer specifications.
-        dl_min: pd.NonNegativeFloat
+        dl_min: NonNegativeFloat
            Lower bound of grid size.
-        min_steps_per_sim_size : pd.PositiveFloat, optional
+        min_steps_per_sim_size : PositiveFloat, optional
            Minimal number of steps per longest edge length of simulation domain.
        mesher : MesherType = GradedMesher()
            The type of mesher to use to generate the grid automatically.
@@ -3382,7 +3428,7 @@ def uniform(cls, dl: float) -> GridSpec:
    def quasiuniform(
        cls,
        dl: float,
-        max_scale: pd.PositiveFloat = 1.4,
+        max_scale: PositiveFloat = 1.4,
        override_structures: list[StructureType] = (),
        snapping_points: tuple[CoordinateOptional, ...] = (),
        mesher: MesherType = Undefined,
@@ -3393,13 +3439,13 @@ def quasiuniform(
        ----------
        dl : float
            Grid size for quasi-uniform grid generation.
-        max_scale : pd.PositiveFloat, optional
+        max_scale : PositiveFloat, optional
            Sets the maximum ratio between any two consecutive grid steps.
- override_structures : List[StructureType] + override_structures : list[StructureType] A list of structures that is added on top of the simulation structures in the process of generating the grid. This can be used to snap grid points to the bounding box boundary. - snapping_points : Tuple[CoordinateOptional, ...] + snapping_points : tuple[CoordinateOptional, ...] A set of points that enforce grid boundaries to pass through them. mesher : MesherType = GradedMesher() The type of mesher to use to generate the grid automatically. diff --git a/tidy3d/components/grid/mesher.py b/tidy3d/components/grid/mesher.py index 589e4a82a8..e4b3be63b2 100644 --- a/tidy3d/components/grid/mesher.py +++ b/tidy3d/components/grid/mesher.py @@ -8,22 +8,26 @@ from abc import ABC, abstractmethod from itertools import compress from math import isclose -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np -import pydantic.v1 as pd from pyroots import Brentq from shapely.errors import ShapelyDeprecationWarning from shapely.geometry import box as shapely_box from shapely.strtree import STRtree from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.structure import MeshOverrideStructure, Structure, StructureType -from tidy3d.components.types import ArrayFloat1D, Axis, Bound, CoordinateOptional +from tidy3d.components.structure import MeshOverrideStructure, Structure from tidy3d.constants import C_0, fp_eps from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log +if TYPE_CHECKING: + from pydantic import NonNegativeFloat, NonNegativeInt, PositiveFloat + + from tidy3d.components.structure import StructureType + from tidy3d.components.types import ArrayFloat1D, Axis, Bound, CoordinateOptional + _ROOTS_TOL = 1e-10 # Shrink min_step a little so that if e.g. a structure has target dl = 0.1 and a width of 0.1, @@ -40,17 +44,17 @@ def parse_structures( self, axis: Axis, structures: list[StructureType], - wavelength: pd.PositiveFloat, - min_steps_per_wvl: pd.NonNegativeInt, - dl_min: pd.NonNegativeFloat, - dl_max: pd.NonNegativeFloat, + wavelength: PositiveFloat, + min_steps_per_wvl: NonNegativeInt, + dl_min: NonNegativeFloat, + dl_max: NonNegativeFloat, ) -> tuple[ArrayFloat1D, ArrayFloat1D]: """Calculate the positions of all bounding box interfaces along a given axis.""" @abstractmethod def insert_snapping_points( self, - dl_min: pd.NonNegativeFloat, + dl_min: NonNegativeFloat, axis: Axis, interval_coords: ArrayFloat1D, max_dl_list: ArrayFloat1D, @@ -100,7 +104,7 @@ class GradedMesher(Mesher): def insert_snapping_points( self, - dl_min: pd.NonNegativeFloat, + dl_min: NonNegativeFloat, axis: Axis, interval_coords: ArrayFloat1D, max_dl_list: ArrayFloat1D, @@ -110,7 +114,7 @@ def insert_snapping_points( Parameters ---------- - dl_min: pd.NonNegativeFloat + dl_min: NonNegativeFloat Lower bound of grid size. axis : Axis Axis index along which to operate. @@ -118,7 +122,7 @@ def insert_snapping_points( Coordinate of interval boundaries. max_dl_list : ArrayFloat1D Maximal allowed step size of each interval generated from `parse_structures`. - snapping_points : List[CoordinateOptional] + snapping_points : list[CoordinateOptional] A set of points that enforce grid boundaries to pass through them. 
        Returns
@@ -185,10 +189,10 @@ def parse_structures(
        self,
        axis: Axis,
        structures: list[StructureType],
-        wavelength: pd.PositiveFloat,
-        min_steps_per_wvl: pd.NonNegativeInt,
-        dl_min: pd.NonNegativeFloat,
-        dl_max: pd.NonNegativeFloat,
+        wavelength: PositiveFloat,
+        min_steps_per_wvl: NonNegativeInt,
+        dl_min: NonNegativeFloat,
+        dl_max: NonNegativeFloat,
    ) -> tuple[ArrayFloat1D, ArrayFloat1D]:
        """Calculate the positions of all bounding box interfaces along a given axis.
        In this implementation, in most cases the complexity should be O(len(structures)**2),
@@ -199,15 +203,15 @@
        ----------
        axis : Axis
            Axis index along which to operate.
-        structures : List[StructureType]
+        structures : list[StructureType]
            List of structures, with the simulation structure being the first item.
-        wavelength : pd.PositiveFloat
+        wavelength : PositiveFloat
            Wavelength to use for the step size and for dispersive media epsilon.
-        min_steps_per_wvl : pd.NonNegativeInt
+        min_steps_per_wvl : NonNegativeInt
            Minimum requested steps per wavelength.
-        dl_min: pd.NonNegativeFloat
+        dl_min: NonNegativeFloat
            Lower bound of grid size.
-        dl_max: pd.NonNegativeFloat
+        dl_max: NonNegativeFloat
            Upper bound of grid size.

        Returns
@@ -412,14 +416,14 @@ def insert_bbox(

        Parameters
        ----------
-        intervals : Dict[str, List]
+        intervals : dict[str, list]
            Dictionary containing the coordinates of the interval boundaries, and a list
            of lists of structures contained in each interval.
        str_ind : int
            Index of the current structure.
        str_bbox : ArrayFloat1D
            Bounding box of the current structure.
-        bbox_contained_2d : List[ArrayFloat1D]
+        bbox_contained_2d : list[ArrayFloat1D]
            List of 3D bounding boxes that contain the current structure in 2D.
        min_step : float
            Absolute minimum interval size to impose.
@@ -549,12 +553,12 @@ def reorder_structures(

        Parameters
        ----------
-        structures : List[StructureType]
+        structures : list[StructureType]
            List of structures, with the simulation structure being the first item.

        Returns
        -------
-        Tuple[int, List[StructureType]]
+        tuple[int, list[StructureType]]
            The number of unenforced structures and the reordered structure list.
        """
@@ -604,14 +608,14 @@ def filter_structures_effective_dl(

        Parameters
        ----------
-        structures : List[StructureType]
+        structures : list[StructureType]
            List of structures, with the simulation structure being the first item.
        axis : Axis
            Axis index to place last.

        Returns
        -------
-        List[StructureType]
+        list[StructureType]
            A list of filtered structures whose ``dl`` along this axis is not ``None``.
        """
@@ -657,8 +661,8 @@ def structure_steps(
        structures: list[StructureType],
        wavelength: float,
        min_steps_per_wvl: float,
-        dl_min: pd.NonNegativeFloat,
-        dl_max: pd.NonNegativeFloat,
+        dl_min: NonNegativeFloat,
+        dl_max: NonNegativeFloat,
        axis: Axis,
    ) -> ArrayFloat1D:
        """Get the minimum mesh required in each structure. Special media are set to index of 1,
@@ -667,15 +671,15 @@

        Parameters
        ----------
-        structures : List[Structure]
+        structures : list[Structure]
            List of structures, with the simulation structure being the first item.
        wavelength : float
            Wavelength to use for the step size and for dispersive media epsilon.
        min_steps_per_wvl : float
            Minimum requested steps per wavelength.
-        dl_min: pd.NonNegativeFloat
+        dl_min: NonNegativeFloat
            Lower bound of grid size.
-        dl_max: pd.NonNegativeFloat
+        dl_max: NonNegativeFloat
            Upper bound of grid size.
        axis : Axis
            Axis index along which to operate.
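The `if TYPE_CHECKING:` blocks introduced in `mesher.py` above only import names for static analysis; at runtime those imports never execute. This is safe only because the file defers annotation evaluation. A minimal sketch of the pattern (`heavy_module` and `HeavyType` are hypothetical placeholders, not tidy3d names):

    from __future__ import annotations  # annotations are stored as strings

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; never imported at runtime.
        from heavy_module import HeavyType

    def consume(value: HeavyType) -> HeavyType:
        # The annotation above is never evaluated at runtime,
        # so the guarded import suffices.
        return value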
@@ -698,14 +702,14 @@ def rotate_structure_bounds(structures: list[StructureType], axis: Axis) -> list

        Parameters
        ----------
-        structures : List[StructureType]
+        structures : list[StructureType]
            List of structures, with the simulation structure being the first item.
        axis : Axis
            Axis index to place last.

        Returns
        -------
-        List[ArrayFloat1D]
+        list[ArrayFloat1D]
            A list of the bounding boxes of shape ``(2, 3)`` for each structure, with the bounds
            along ``axis`` being ``(:, 2)``.
        """
@@ -720,7 +724,7 @@ def rotate_structure_bounds(structures: list[StructureType], axis: Axis) -> list
        return struct_bbox

    @staticmethod
-    def bounds_2d_tree(struct_bbox: list[ArrayFloat1D]):
+    def bounds_2d_tree(struct_bbox: list[ArrayFloat1D]) -> STRtree:
        """Make a shapely Rtree for the 2D bounding boxes of all structures
        in the plane perpendicular to the meshing axis."""
@@ -867,7 +871,7 @@ def make_grid_multiple_intervals(

        Returns
        -------
-        List[ArrayFloat1D]
+        list[ArrayFloat1D]
            A list of step sizes in each interval.
        """
@@ -962,7 +966,7 @@ def grid_multiple_interval_analy_refinement(

        Returns
        -------
-        Tuple[ArrayFloat1D, ArrayFloat1D]
+        tuple[ArrayFloat1D, ArrayFloat1D]
            left and right step sizes of each interval.
        """
@@ -1401,7 +1405,7 @@ def grid_grow_in_interval(

        if len_mismatch_even > small_dl:

-            def fun_scale(new_scale):
+            def fun_scale(new_scale: float) -> float:
                if isclose(new_scale, 1.0):
                    return len_interval - small_dl * (1 + num_step)
                return (
diff --git a/tidy3d/components/index.py b/tidy3d/components/index.py
index dba7b9d1ea..a3a8ecabb8 100644
--- a/tidy3d/components/index.py
+++ b/tidy3d/components/index.py
@@ -5,25 +5,32 @@

 from __future__ import annotations

-from collections.abc import Iterator, Mapping
-from typing import Any
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any

-import pydantic.v1 as pd
+from pydantic import Field, model_validator

 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.types.simulation import SimulationType

+if TYPE_CHECKING:
+    from collections.abc import Iterator
+
+    from tidy3d.compat import Self
+

 class ValueMap(Tidy3dBaseModel, Mapping[str, Any]):
     """An immutable dictionary-like container for objects.

-    This class maps unique string keys to corresponding value objects.
-    By inheriting from `collections.abc.Mapping`, it provides standard dictionary
-    behaviors like item access (`my_dict["my_key"]`), iteration (`for name in my_dict`), and
-    length checking (`len(my_dict)`).
+    Notes
+    -----
+    This class maps unique string keys to corresponding value objects.
+    By inheriting from `collections.abc.Mapping`, it provides standard dictionary
+    behaviors like item access (`my_dict["my_key"]`), iteration (`for name in my_dict`), and
+    length checking (`len(my_dict)`).

-    It automatically validates that the `keys` and `values`
-    tuples have matching lengths upon instantiation.
+    It automatically validates that the `keys` and `values`
+    tuples have matching lengths upon instantiation.

     Attributes
     ----------
@@ -34,18 +41,18 @@ class ValueMap(Tidy3dBaseModel, Mapping[str, Any]):
         same index. Should be overwritten by the subclass instantiation
     """

-    keys_tuple: tuple[str, ...] = pd.Field(
+    keys_tuple: tuple[str, ...] = Field(
        description="A tuple of unique string identifiers for each simulation.", alias="keys"
    )
-    values_tuple: tuple[Any, ...] = pd.Field(
+    values_tuple: tuple[Any, ...] = Field(
        description=(
            "A tuple of `Simulation` objects, each corresponding to a key at the same index."
        ),
        alias="values",
    )

-    @pd.root_validator(skip_on_failure=True)
-    def _validate_lengths_match(cls, data: dict) -> dict:
+    @model_validator(mode="after")
+    def _validate_lengths_match(self) -> Self:
        """Pydantic model validator to ensure 'keys' and 'values' have the same length.

        Parameters
@@ -63,10 +70,14 @@
        ValueError
            If the lengths of the 'keys' and 'values' tuples are not equal.
        """
-        keys, values = data.get("keys"), data.get("values")
+        # Access the fields by their attribute names; `self.keys` would resolve to
+        # the `Mapping.keys` method, not the aliased tuple field.
+        keys, values = self.keys_tuple, self.values_tuple
        if keys is not None and values is not None and len(keys) != len(values):
            raise ValueError("Length of 'keys' and 'values' must be the same.")
-        return data
+        return self

    def __getitem__(self, key: str) -> Any:
        """Retrieves a `Simulation` object by its corresponding key.
@@ -121,13 +132,15 @@ class SimulationMap(ValueMap, Mapping[str, SimulationType]):
    """An immutable dictionary-like container for simulations.

-    This class maps unique string keys to corresponding `Simulation` objects.
-    By inheriting from `collections.abc.Mapping`, it provides standard dictionary
-    behaviors like item access (`sims["sim_A"]`), iteration (`for name in sims`), and
-    length checking (`len(sims)`).
+    Notes
+    -----
+    This class maps unique string keys to corresponding `Simulation` objects.
+    By inheriting from `collections.abc.Mapping`, it provides standard dictionary
+    behaviors like item access (`sims["sim_A"]`), iteration (`for name in sims`), and
+    length checking (`len(sims)`).

-    It automatically validates that the `keys` and `values`
-    tuples have matching lengths upon instantiation.
+    It automatically validates that the `keys` and `values`
+    tuples have matching lengths upon instantiation.

    Attributes
    ----------
@@ -193,10 +206,10 @@
    >>> # print(simulation_map["sim_1"])
    """

-    keys_tuple: tuple[str, ...] = pd.Field(
+    keys_tuple: tuple[str, ...] = Field(
        description="A tuple of unique string identifiers for each simulation.", alias="keys"
    )
-    values_tuple: tuple[SimulationType, ...] = pd.Field(
+    values_tuple: tuple[SimulationType, ...] = Field(
        description=(
            "A tuple of `Simulation` objects, each corresponding to a key at the same index."
), diff --git a/tidy3d/components/lumped_element.py b/tidy3d/components/lumped_element.py index bf8a71b039..163e14da0e 100644 --- a/tidy3d/components/lumped_element.py +++ b/tidy3d/components/lumped_element.py @@ -4,20 +4,23 @@ from abc import ABC, abstractmethod from math import isclose -from typing import Annotated, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeFloat, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) -from tidy3d.components.grid.grid import Grid -from tidy3d.components.medium import PEC2D, Debye, Drude, Lorentz, Medium, Medium2D, PoleResidue -from tidy3d.components.monitor import FieldMonitor -from tidy3d.components.structure import MeshOverrideStructure, Structure -from tidy3d.components.validators import assert_line_or_plane, assert_plane, validate_name_str +from tidy3d.components.types.base import discriminated_union from tidy3d.constants import EPSILON_0, FARAD, HENRY, MICROMETER, OHM, fp_eps from tidy3d.exceptions import ValidationError -from .base import cached_property, skip_if_fields_missing +from .base import cached_property from .geometry.base import Box, ClipOperation, Geometry, GeometryGroup from .geometry.primitives import Cylinder from .geometry.utils import ( @@ -28,6 +31,7 @@ snap_point_to_grid, ) from .geometry.utils_2d import increment_float +from .medium import PEC2D, Debye, Drude, Lorentz, Medium, Medium2D, PoleResidue from .microwave.base import MicrowaveBaseModel from .microwave.formulas.circuit_parameters import ( capacitance_colinear_cylindrical_wire_segments, @@ -35,16 +39,18 @@ inductance_straight_rectangular_wire, total_inductance_colinear_rectangular_wire_segments, ) -from .types import ( - TYPE_TAG_STR, - Axis, - Axis2D, - Coordinate, - CoordinateOptional, - FreqArray, - LumpDistType, -) -from .viz import PlotParams, plot_params_lumped_element +from .monitor import FieldMonitor +from .structure import MeshOverrideStructure, Structure +from .types import Axis, Coordinate, LumpDistType +from .validators import assert_line_or_plane, assert_plane, validate_name_str +from .viz import plot_params_lumped_element + +if TYPE_CHECKING: + from tidy3d.compat import Self + + from .grid.grid import Grid + from .types import Axis2D, CoordinateOptional, FreqArray + from .viz import PlotParams DEFAULT_LUMPED_ELEMENT_NUM_CELLS = 1 LOSS_FACTOR_INDUCTOR = 1e6 @@ -53,14 +59,13 @@ class LumpedElement(MicrowaveBaseModel, ABC): """Base class describing the interface all lumped elements obey.""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for the lumped element.", min_length=1, ) - num_grid_cells: Optional[pd.PositiveInt] = pd.Field( + num_grid_cells: Optional[PositiveInt] = Field( DEFAULT_LUMPED_ELEMENT_NUM_CELLS, title="Lumped element grid cells", description="Number of mesh grid cells associated with the lumped element along each direction. 
" @@ -68,7 +73,7 @@ class LumpedElement(MicrowaveBaseModel, ABC): "A value of ``None`` will turn off mesh refinement suggestions.", ) - enable_snapping_points: bool = pd.Field( + enable_snapping_points: bool = Field( True, title="Snap Grid To Lumped Element", description="When enabled, snapping points are automatically generated to snap grids to key " @@ -95,11 +100,11 @@ def to_geometry(self) -> Geometry: """Converts the :class:`.LumpedElement` object to a :class:`.Geometry`.""" @abstractmethod - def to_structure(self, grid: Grid = None) -> Structure: + def to_structure(self, grid: Optional[Grid] = None) -> Structure: """Converts the network portion of the :class:`.LumpedElement` object to a :class:`.Structure`.""" - def to_structures(self, grid: Grid = None) -> list[Structure]: + def to_structures(self, grid: Optional[Grid] = None) -> list[Structure]: """Converts the :class:`.LumpedElement` object to a list of :class:`.Structure` which are ready to be added to the :class:`.Simulation`""" return [self.to_structure(grid)] @@ -110,14 +115,13 @@ class RectangularLumpedElement(LumpedElement, Box): is appended to the list of structures in the simulation as a :class:`.Medium2D` with the appropriate material properties given their size, voltage axis, and the network they represent.""" - voltage_axis: Axis = pd.Field( - ..., + voltage_axis: Axis = Field( title="Voltage Drop Axis", description="Specifies the axis along which the component is oriented and along which the " "associated voltage drop will occur. Must be in the plane of the element.", ) - snap_perimeter_to_grid: bool = pd.Field( + snap_perimeter_to_grid: bool = Field( True, title="Snap Perimeter to Grid", description="When enabled, the perimeter of the lumped element is snapped to the simulation grid, " @@ -130,12 +134,12 @@ class RectangularLumpedElement(LumpedElement, Box): _line_plane_validator = assert_line_or_plane() @cached_property - def normal_axis(self): + def normal_axis(self) -> Axis: """Normal axis of the lumped element, which is the axis where the element has zero size.""" return self.size.index(0.0) @cached_property - def lateral_axis(self): + def lateral_axis(self) -> Axis: """Lateral axis of the lumped element.""" return 3 - self.voltage_axis - self.normal_axis @@ -167,7 +171,7 @@ def _snapping_spec(self) -> SnappingSpec: snap_behavior = [SnapBehavior.Closest] * 3 snap_location[self.lateral_axis] = SnapLocation.Center snap_behavior[self.lateral_axis] = SnapBehavior.Expand - return SnappingSpec(location=snap_location, behavior=snap_behavior) + return SnappingSpec(location=tuple(snap_location), behavior=tuple(snap_behavior)) def to_mesh_overrides(self) -> list[MeshOverrideStructure]: """Creates a suggested :class:`.MeshOverrideStructure` list for mesh refinement both on the @@ -210,14 +214,14 @@ def to_snapping_points(self) -> list[CoordinateOptional]: ) return snapping_points - def to_geometry(self, grid: Grid = None) -> Box: + def to_geometry(self, grid: Optional[Grid] = None) -> Box: """Converts the :class:`RectangularLumpedElement` object to a :class:`.Box`.""" box = Box(size=self.size, center=self.center) if grid and self.snap_perimeter_to_grid: return snap_box_to_grid(grid, box, self._snapping_spec) return box - def _admittance_transfer_function_scaling(self, box: Box = None) -> float: + def _admittance_transfer_function_scaling(self, box: Optional[Box] = None) -> float: """The admittance transfer function of the network needs to be scaled depending on the dimensions of the lumped element. 
        The scaling emulates adding networks with equal admittances in series and parallel,
        and is needed when distributing the network over a finite volume.
@@ -269,21 +273,21 @@ def to_monitor(self, freqs: FreqArray) -> FieldMonitor:
        )

    @cached_property
-    def monitor_name(self):
+    def monitor_name(self) -> str:
        return f"{self.name}_monitor"

-    @pd.validator("voltage_axis", always=True)
-    @skip_if_fields_missing(["name", "size"])
-    def _voltage_axis_in_plane(cls, val, values):
+    @model_validator(mode="after")
+    def _voltage_axis_in_plane(self) -> Self:
        """Ensure voltage drop axis is in the plane of the lumped element."""
-        name = values.get("name")
-        size = values.get("size")
+        val = self.voltage_axis
+        name = self.name
+        size = self.size
        if size.count(0.0) == 1 and size.index(0.0) == val:
            # if not planar, then a separate validator should be triggered, not this one
            raise ValidationError(
                f"'voltage_axis' must be in the plane of lumped element '{name}'."
            )
-        return val
+        return self


class LumpedResistor(RectangularLumpedElement):
@@ -291,18 +295,17 @@
    of structures in the simulation as :class:`Medium2D` with the appropriate conductivity given
    their size and voltage axis."""

-    resistance: pd.PositiveFloat = pd.Field(
-        ...,
+    resistance: PositiveFloat = Field(
        title="Resistance",
        description="Resistance value in ohms.",
        unit=OHM,
    )

-    def _sheet_conductance(self, box: Box = None):
+    def _sheet_conductance(self, box: Optional[Box] = None) -> float:
        """Effective sheet conductance."""
        return self._admittance_transfer_function_scaling(box) / self.resistance

-    def to_structure(self, grid: Grid = None) -> Structure:
+    def to_structure(self, grid: Optional[Grid] = None) -> Structure:
        """Converts the :class:`LumpedResistor` object to a :class:`.Structure`
        ready to be added to the :class:`.Simulation`"""
        box = self.to_geometry(grid=grid)
@@ -327,36 +330,32 @@ class CoaxialLumpedResistor(LumpedElement):
    structures in the simulation as :class:`Medium2D` with the appropriate conductivity given
    their size and geometry."""

-    resistance: pd.PositiveFloat = pd.Field(
-        ...,
+    resistance: PositiveFloat = Field(
        title="Resistance",
        description="Resistance value in ohms.",
        unit=OHM,
    )

-    center: Coordinate = pd.Field(
+    center: Coordinate = Field(
        (0.0, 0.0, 0.0),
        title="Center",
        description="Center of object in x, y, and z.",
        units=MICROMETER,
    )

-    outer_diameter: pd.PositiveFloat = pd.Field(
-        ...,
+    outer_diameter: PositiveFloat = Field(
        title="Outer Diameter",
        description="Diameter of the outer concentric circle.",
        units=MICROMETER,
    )

-    inner_diameter: pd.PositiveFloat = pd.Field(
-        ...,
+    inner_diameter: PositiveFloat = Field(
        title="Inner Diameter",
        description="Diameter of the inner concentric circle.",
        units=MICROMETER,
    )

-    normal_axis: Axis = pd.Field(
-        ...,
+    normal_axis: Axis = Field(
        title="Normal Axis",
        description="Specifies the normal axis, which defines "
        "the orientation of the circles making up the coaxial lumped element.",
@@ -396,32 +395,33 @@ def to_mesh_overrides(self) -> list[MeshOverrideStructure]:
            )
        ]

-    @pd.validator("center", always=True)
-    def _center_not_inf(cls, val):
+    @field_validator("center")
+    @classmethod
+    def _center_not_inf(cls, val: Coordinate) -> Coordinate:
        """Make sure center is not infinity."""
        if any(np.isinf(v) for v in val):
            raise ValidationError("'center' can not contain 'td.inf' terms.")
        return val

-    @pd.validator("inner_diameter", always=True)
-    @skip_if_fields_missing(["outer_diameter"])
-    def _ensure_inner_diameter_is_smaller(cls, val, values):
+    @model_validator(mode="after")
+    def _ensure_inner_diameter_is_smaller(self) -> Self:
        """Ensures that the inner diameter is smaller than the outer diameter, so that the final
        shape is an annulus."""
-        outer_diameter = values.get("outer_diameter")
+        val = self.inner_diameter
+        outer_diameter = self.outer_diameter
        if val >= outer_diameter:
            raise ValidationError(
                f"The 'inner_diameter' {val} of a coaxial lumped element must be less than its 'outer_diameter' {outer_diameter}."
            )
-        return val
+        return self

    @cached_property
-    def _sheet_conductance(self):
+    def _sheet_conductance(self) -> float:
        """Effective sheet conductance for a coaxial resistor."""
        rout = self.outer_diameter / 2
        rin = self.inner_diameter / 2
        return 1 / (2 * np.pi * self.resistance) * (np.log(rout / rin))

-    def to_structure(self, grid: Grid = None) -> Structure:
+    def to_structure(self, grid: Optional[Grid] = None) -> Structure:
        """Converts the :class:`CoaxialLumpedResistor` object to a :class:`.Structure`
        ready to be added to the :class:`.Simulation`"""
        conductivity = self._sheet_conductance
@@ -434,7 +434,7 @@ def to_structure(self, grid: Optional[Grid] = None) -> Structure:
            medium=Medium2D(**medium_dict),
        )

-    def to_geometry(self, grid: Grid = None) -> ClipOperation:
+    def to_geometry(self, grid: Optional[Grid] = None) -> ClipOperation:
        """Converts the :class:`CoaxialLumpedResistor` object to a :class:`Geometry`."""
        rout = self.outer_diameter / 2
        rin = self.inner_diameter / 2
@@ -573,35 +573,35 @@ class RLCNetwork(MicrowaveBaseModel):
    """

-    resistance: Optional[pd.PositiveFloat] = pd.Field(
+    resistance: Optional[PositiveFloat] = Field(
        None,
        title="Resistance",
        description="Resistance value in ohms.",
        unit=OHM,
    )

-    capacitance: Optional[pd.PositiveFloat] = pd.Field(
+    capacitance: Optional[PositiveFloat] = Field(
        None,
        title="Capacitance",
        description="Capacitance value in farads.",
        unit=FARAD,
    )

-    inductance: Optional[pd.PositiveFloat] = pd.Field(
+    inductance: Optional[PositiveFloat] = Field(
        None,
        title="Inductance",
        description="Inductance value in henrys.",
        unit=HENRY,
    )

-    network_topology: Literal["series", "parallel"] = pd.Field(
+    network_topology: Literal["series", "parallel"] = Field(
        "series",
        title="Network Topology",
        description="Describes whether network elements are connected in ``series`` or ``parallel``.",
    )

    @cached_property
-    def _number_network_elements(self) -> pd.PositiveInt:
+    def _number_network_elements(self) -> PositiveInt:
        num_elements = 0
        if self.resistance:
            num_elements += 1
@@ -711,11 +711,8 @@ def _to_medium(self, scaling_factor: float) -> PoleResidue:
        elif self.network_topology == "series":
            result_medium = RLCNetwork._series_network_to_equivalent_medium(scaling_factor, R, L, C)
            return result_medium
-        else:
-            result_medium = RLCNetwork._parallel_network_to_equivalent_medium(
-                scaling_factor, R, L, C
-            )
-            return result_medium
+        result_medium = RLCNetwork._parallel_network_to_equivalent_medium(scaling_factor, R, L, C)
+        return result_medium

    @staticmethod
    def _series_network_to_equivalent_medium(
@@ -754,7 +751,9 @@
    ) -> PoleResidue:
        """Converts the RLC parallel network directly to an equivalent medium."""

-        def combine_equivalent_medium_in_parallel(first: PoleResidue, second: PoleResidue):
+        def combine_equivalent_medium_in_parallel(
+            first: PoleResidue, second: PoleResidue
+        ) -> PoleResidue:
            """Helper for combining equivalent media when the network elements are in
            the 'parallel' configuration.
            A similar operation cannot be done for the 'series' topology."""

            eps_inf = 1.0 + (first.eps_inf - 1) + (second.eps_inf - 1)
@@ -787,16 +786,16 @@ def combine_equivalent_medium_in_parallel(first: PoleResidue, second: PoleResidu
                result_medium = combine_equivalent_medium_in_parallel(med, result_medium)
        return result_medium

-    @pd.validator("inductance", always=True)
-    @skip_if_fields_missing(["resistance", "capacitance"])
-    def _validate_single_element(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_single_element(self) -> Self:
        """At least one element should be defined."""
-        resistance = values.get("resistance")
-        capacitance = values.get("capacitance")
+        val = self.inductance
+        resistance = self.resistance
+        capacitance = self.capacitance
        all_items_are_none = all(item is None for item in [resistance, capacitance, val])
        if all_items_are_none:
            raise ValueError("At least one element must be defined in the 'RLCNetwork'.")
-        return val
+        return self


class AdmittanceNetwork(MicrowaveBaseModel):
@@ -847,15 +846,13 @@
    """

-    a: tuple[pd.NonNegativeFloat, ...] = pd.Field(
-        ...,
+    a: tuple[NonNegativeFloat, ...] = Field(
        title="Numerator Coefficients",
        description="A ``tuple`` of floats describing the coefficients of the numerator polynomial. "
        "The length of the ``tuple`` is equal to the order of the network.",
    )

-    b: tuple[pd.NonNegativeFloat, ...] = pd.Field(
-        ...,
+    b: tuple[NonNegativeFloat, ...] = Field(
        title="Denominator Coefficients",
        description="A ``tuple`` of floats describing the coefficients of the denominator polynomial. "
        "The length of the ``tuple`` is equal to the order of the network.",
@@ -876,6 +873,9 @@ def _as_admittance_function(self) -> tuple[tuple[float, ...], tuple[float, ...]]
        return (self.a, self.b)


+NetworkType = discriminated_union(Union[RLCNetwork, AdmittanceNetwork])
+
+
class LinearLumpedElement(RectangularLumpedElement):
    """Lumped element representing a network consisting of resistors, capacitors, and inductors.
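`NetworkType` above is built with the `discriminated_union` helper rather than the pydantic v1 pattern of an `Annotated` union with a `TYPE_TAG_STR` discriminator field. A rough pydantic v2 sketch of what such a helper produces (simplified stand-in models; the real helper lives in `tidy3d.components.types.base` and may differ in detail):

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field

    class RLC(BaseModel):
        type: Literal["RLC"] = "RLC"
        resistance: float = 50.0

    class Admittance(BaseModel):
        type: Literal["Admittance"] = "Admittance"
        order: int = 1

    # The discriminator tag selects the union member directly instead of
    # trying each member in sequence, giving faster and clearer validation errors.
    NetworkLike = Annotated[Union[RLC, Admittance], Field(discriminator="type")]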
@@ -914,15 +914,13 @@ class LinearLumpedElement(RectangularLumpedElement): * `Using lumped elements in Tidy3D simulations <../../notebooks/LinearLumpedElements.html>`_ """ - network: Union[RLCNetwork, AdmittanceNetwork] = pd.Field( - ..., + network: NetworkType = Field( title="Network", description="The linear element produces an equivalent medium that emulates the " "voltage-current relationship described by the ``network`` field.", - discriminator=TYPE_TAG_STR, ) - dist_type: LumpDistType = pd.Field( + dist_type: LumpDistType = Field( "on", title="Distribute Type", description="Switches between the different methods for distributing the lumped element over " @@ -975,7 +973,7 @@ def _create_box_for_network(self, grid: Grid) -> Box: if size[self.voltage_axis] == 0: behavior = list(snap_spec.behavior) behavior[self.voltage_axis] = SnapBehavior.Expand - snap_spec = snap_spec.updated_copy(behavior=behavior) + snap_spec = snap_spec.updated_copy(behavior=tuple(behavior)) return snap_box_to_grid(grid, cell_box, snap_spec=snap_spec) @@ -1012,7 +1010,7 @@ def _create_connection_boxes( bottom_box = None return (bottom_box, top_box) - def to_structure(self, grid) -> Structure: + def to_structure(self, grid: Grid) -> Structure: """Converts the :class:`LinearLumpedElement` object to a :class:`.Structure`, which enforces the desired voltage-current relationship across one or more grid cells.""" @@ -1031,7 +1029,7 @@ def to_structure(self, grid) -> Structure: medium=Medium2D(**medium_dict), ) - def to_PEC_connection(self, grid) -> Optional[Structure]: + def to_PEC_connection(self, grid: Grid) -> Optional[Structure]: """Converts the :class:`LinearLumpedElement` object to a :class:`.Structure`, representing any PEC connections. """ @@ -1063,7 +1061,7 @@ def to_structures(self, grid: Grid) -> list[Structure]: structures.append(self.to_structure(grid)) return structures - def estimate_parasitic_elements(self, grid: Grid) -> tuple[float, float]: + def estimate_parasitic_elements(self, grid: Grid) -> Optional[tuple[float, float]]: """Provides an estimate for the parasitic inductance and capacitance associated with the connections. These wire or sheet connections are used when the lumped element is not distributed over the voltage axis. @@ -1140,6 +1138,7 @@ def estimate_parasitic_elements(self, grid: Grid) -> tuple[float, float]: # but there will be a contribution to inductance from the single connection L = inductance_straight_rectangular_wire(common_size, v_axis) return (L, 0) + return None def admittance(self, freqs: np.ndarray) -> np.ndarray: """Returns the admittance of this lumped element at the frequencies specified by ``freqs``. 
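`estimate_parasitic_elements` now advertises `Optional[tuple[float, float]]` and ends with an explicit `return None`, so the fall-through path is visible to strict type checkers instead of being an implicit `None`. A small sketch of the same pattern (hypothetical function, not tidy3d API):

    from typing import Optional

    def min_max(values: list[float]) -> Optional[tuple[float, float]]:
        if values:
            return (min(values), max(values))
        # Explicit None keeps strict type checking satisfied and makes the
        # "no estimate available" case obvious to callers.
        return None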
@@ -1168,11 +1167,10 @@ def impedance(self, freqs: np.ndarray) -> np.ndarray: # lumped elements allowed in Simulation.lumped_elements -LumpedElementType = Annotated[ +LumpedElementType = discriminated_union( Union[ LumpedResistor, CoaxialLumpedResistor, LinearLumpedElement, - ], - pd.Field(discriminator=TYPE_TAG_STR), -] + ] +) diff --git a/tidy3d/components/material/multi_physics.py b/tidy3d/components/material/multi_physics.py index 7e26a476b9..412e0cf78d 100644 --- a/tidy3d/components/material/multi_physics.py +++ b/tidy3d/components/material/multi_physics.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Optional +from typing import Any, Optional -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.material.solver_types import ( @@ -81,36 +81,36 @@ class MultiPhysicsMedium(Tidy3dBaseModel): ... ) """ - name: Optional[str] = pd.Field(None, title="Name", description="Medium name") + name: Optional[str] = Field(None, title="Name", description="Medium name") - optical: Optional[OpticalMediumType] = pd.Field( + optical: Optional[OpticalMediumType] = Field( None, title="Optical properties", description="Specifies optical properties.", discriminator=TYPE_TAG_STR, ) - # electrical: Optional[ElectricalMediumType] = pd.Field( + # electrical: Optional[ElectricalMediumType] = Field( # None, # title="Electrical properties", # description="Specifies electrical properties for RF simulations. This is currently not in use.", # ) - heat: Optional[HeatMediumType] = pd.Field( + heat: Optional[HeatMediumType] = Field( None, title="Heat properties", description="Specifies properties for Heat simulations.", discriminator=TYPE_TAG_STR, ) - charge: Optional[ChargeMediumType] = pd.Field( + charge: Optional[ChargeMediumType] = Field( None, title="Charge properties", description="Specifies properties for Charge simulations.", discriminator=TYPE_TAG_STR, ) - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Any: """ Delegate attribute lookup to inner media or fail fast. 
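The `__getattr__` hook on `MultiPhysicsMedium` above delegates unknown attribute lookups to the inner media. A minimal standalone sketch of that delegation shape (simplified stand-in classes, not the tidy3d implementation):

    from typing import Any, Optional

    class HeatProps:
        capacity: float = 1.0

    class CompositeMedium:
        def __init__(self, heat: Optional[HeatProps] = None) -> None:
            self.heat = heat

        def __getattr__(self, name: str) -> Any:
            # Only called when normal lookup fails; go through __dict__
            # to avoid recursing back into __getattr__.
            heat = self.__dict__.get("heat")
            if heat is not None and hasattr(heat, name):
                return getattr(heat, name)
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )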
@@ -191,7 +191,7 @@ def __getattr__(self, name: str): ) @property - def heat_spec(self): + def heat_spec(self) -> Optional[HeatMediumType]: if self.heat is not None: return self.heat diff --git a/tidy3d/components/material/tcad/charge.py b/tidy3d/components/material/tcad/charge.py index 94d8cdb778..b2580abe70 100644 --- a/tidy3d/components/material/tcad/charge.py +++ b/tidy3d/components/material/tcad/charge.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING, Optional, Union -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, field_validator from tidy3d.components.data.data_array import SpatialDataArray from tidy3d.components.medium import AbstractMedium @@ -21,17 +21,24 @@ from tidy3d.constants import CONDUCTIVITY, ELECTRON_VOLT, PERCMCUBE, PERMITTIVITY from tidy3d.log import log +if TYPE_CHECKING: + from tidy3d.compat import Self + class AbstractChargeMedium(AbstractMedium): """Abstract class for Charge specifications Currently, permittivity is treated as a constant.""" - permittivity: float = pd.Field( - 1.0, ge=1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY + permittivity: float = Field( + 1.0, + ge=1.0, + title="Permittivity", + description="Relative permittivity.", + units=PERMITTIVITY, ) @property - def charge(self): + def charge(self) -> Self: """ This means that a charge medium has been defined inherently within this solver medium. This provides interconnection with the :class:`MultiPhysicsMedium` higher-dimensional classes. @@ -75,8 +82,7 @@ class ChargeConductorMedium(AbstractChargeMedium): A relative permittivity will be assumed 1 if no value is specified. """ - conductivity: pd.PositiveFloat = pd.Field( - ..., + conductivity: PositiveFloat = Field( title="Electric conductivity", description="Electric conductivity of material.", units=CONDUCTIVITY, @@ -256,73 +262,79 @@ class SemiconductorMedium(AbstractChargeMedium): """ - N_c: Union[EffectiveDOSModelType, pd.PositiveFloat] = pd.Field( - ..., + N_c: Union[EffectiveDOSModelType, PositiveFloat] = Field( title="Effective density of electron states", description=":math:`N_c` Effective density of states in the conduction band.", units=PERCMCUBE, ) - N_v: Union[EffectiveDOSModelType, pd.PositiveFloat] = pd.Field( - ..., + N_v: Union[EffectiveDOSModelType, PositiveFloat] = Field( title="Effective density of hole states", description=":math:`N_v` Effective density of states in the valence band.", units=PERCMCUBE, ) - E_g: Union[EnergyBandGapModelType, pd.PositiveFloat] = pd.Field( - ..., + E_g: Union[EnergyBandGapModelType, PositiveFloat] = Field( title="Band-gap energy", description=":math:`E_g` Band-gap energy", units=ELECTRON_VOLT, ) - mobility_n: MobilityModelType = pd.Field( - ..., + mobility_n: MobilityModelType = Field( title="Mobility model for electrons", description="Mobility model for electrons", ) - mobility_p: MobilityModelType = pd.Field( - ..., + mobility_p: MobilityModelType = Field( title="Mobility model for holes", description="Mobility model for holes", ) - R: tuple[RecombinationModelType, ...] = pd.Field( - [], + R: tuple[RecombinationModelType, ...] 
= Field( + (), title="Generation-Recombination models", description="Array containing the R models to be applied to the material.", ) - delta_E_g: BandGapNarrowingModelType = pd.Field( + delta_E_g: Optional[BandGapNarrowingModelType] = Field( None, title="Bandgap narrowing model.", description=":math:`\\Delta E_g` Bandgap narrowing model.", units=ELECTRON_VOLT, ) - N_a: Union[pd.NonNegativeFloat, SpatialDataArray, tuple[DopingBoxType, ...]] = pd.Field( + N_a: Union[ + tuple[DopingBoxType, ...], + list[DopingBoxType], + SpatialDataArray, + NonNegativeFloat, + ] = Field( (), title="Doping: Acceptor concentration", description="Concentration of acceptor impurities, which create mobile holes, resulting in p-type material. " "Can be specified as a single float for uniform doping, a :class:`SpatialDataArray` for a custom profile, " - "or a tuple of geometric shapes to define specific doped regions.", + "or a tuple/list of geometric shapes to define specific doped regions.", units=PERCMCUBE, ) - N_d: Union[pd.NonNegativeFloat, SpatialDataArray, tuple[DopingBoxType, ...]] = pd.Field( + N_d: Union[ + tuple[DopingBoxType, ...], + list[DopingBoxType], + SpatialDataArray, + NonNegativeFloat, + ] = Field( (), title="Doping: Donor concentration", description="Concentration of donor impurities, which create mobile electrons, resulting in n-type material. " "Can be specified as a single float for uniform doping, a :class:`SpatialDataArray` for a custom profile, " - "or a tuple of geometric shapes to define specific doped regions.", + "or a tuple/list of geometric shapes to define specific doped regions.", units=PERCMCUBE, ) # DEPRECATION VALIDATORS - @pd.validator("N_c", always=True) - def check_nc_uses_model(cls, val, values): + @field_validator("N_c") + @classmethod + def check_nc_uses_model(cls, val: Union[EffectiveDOSModelType, float]) -> EffectiveDOSModelType: """Issue deprecation warning if float is provided""" if isinstance(val, (float, int)): log.warning( @@ -332,8 +344,9 @@ def check_nc_uses_model(cls, val, values): return ConstantEffectiveDOS(N=val) return val - @pd.validator("N_v", always=True) - def check_nv_uses_model(cls, val, values): + @field_validator("N_v") + @classmethod + def check_nv_uses_model(cls, val: Union[EffectiveDOSModelType, float]) -> EffectiveDOSModelType: """Issue deprecation warning if float is provided""" if isinstance(val, (float, int)): log.warning( @@ -343,8 +356,11 @@ def check_nv_uses_model(cls, val, values): return ConstantEffectiveDOS(N=val) return val - @pd.validator("E_g", always=True) - def check_eg_uses_model(cls, val, values): + @field_validator("E_g") + @classmethod + def check_eg_uses_model( + cls, val: Union[EnergyBandGapModelType, float] + ) -> EnergyBandGapModelType: """Issue deprecation warning if float is provided""" if isinstance(val, (float, int)): log.warning( @@ -354,9 +370,17 @@ def check_eg_uses_model(cls, val, values): return ConstantEnergyBandGap(eg=val) return val - @pd.validator("N_d", always=True) - def check_nd_uses_model(cls, val, values): + @field_validator("N_d") + @classmethod + def check_nd_uses_model( + cls, + val: Union[ + NonNegativeFloat, SpatialDataArray, tuple[DopingBoxType, ...], list[DopingBoxType] + ], + ) -> Union[SpatialDataArray, tuple[DopingBoxType, ...]]: """Issue deprecation warning if float is provided""" + if isinstance(val, list): + return tuple(val) if isinstance(val, (float, int)): log.warning( "Passing a float to 'N_d' is deprecated and will be removed in future versions. 
" @@ -365,9 +389,17 @@ def check_nd_uses_model(cls, val, values): return (ConstantDoping(concentration=val),) return val - @pd.validator("N_a", always=True) - def check_na_uses_model(cls, val, values): + @field_validator("N_a") + @classmethod + def check_na_uses_model( + cls, + val: Union[ + NonNegativeFloat, SpatialDataArray, tuple[DopingBoxType, ...], list[DopingBoxType] + ], + ) -> Union[SpatialDataArray, tuple[DopingBoxType, ...]]: """Issue deprecation warning if float is provided""" + if isinstance(val, list): + return tuple(val) if isinstance(val, (float, int)): log.warning( "Passing a float to 'N_a' is deprecated and will be removed in future versions. " diff --git a/tidy3d/components/material/tcad/heat.py b/tidy3d/components/material/tcad/heat.py index 8e800177de..91d3cc6cfd 100644 --- a/tidy3d/components/material/tcad/heat.py +++ b/tidy3d/components/material/tcad/heat.py @@ -3,9 +3,9 @@ from __future__ import annotations from abc import ABC -from typing import Union +from typing import TYPE_CHECKING, Optional, Union -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import ( @@ -17,15 +17,18 @@ THERMAL_EXPANSIVITY, ) +if TYPE_CHECKING: + from tidy3d.compat import Self + # Liquid class class AbstractHeatMedium(ABC, Tidy3dBaseModel): """Abstract heat material specification.""" - name: str = pd.Field(None, title="Name", description="Optional unique name for medium.") + name: Optional[str] = Field(None, title="Name", description="Optional unique name for medium.") @property - def heat(self): + def heat(self) -> Self: """ This means that a heat medium has been defined inherently within this solver medium. This provides interconnection with the `MultiPhysicsMedium` higher-dimensional classes. @@ -86,51 +89,53 @@ class FluidMedium(AbstractHeatMedium): ... 
) """ - thermal_conductivity: pd.NonNegativeFloat = pd.Field( + thermal_conductivity: Optional[NonNegativeFloat] = Field( default=None, title="Fluid Thermal Conductivity", description="Thermal conductivity (k) of the fluid.", units=THERMAL_CONDUCTIVITY, ) - viscosity: pd.NonNegativeFloat = pd.Field( + viscosity: Optional[NonNegativeFloat] = Field( default=None, title="Fluid Dynamic Viscosity", description="Dynamic viscosity (μ) of the fluid.", units=DYNAMIC_VISCOSITY, ) - specific_heat: pd.NonNegativeFloat = pd.Field( + specific_heat: Optional[NonNegativeFloat] = Field( default=None, title="Fluid Specific Heat", description="Specific heat of the fluid at constant pressure.", units=SPECIFIC_HEAT, ) - density: pd.NonNegativeFloat = pd.Field( + density: Optional[NonNegativeFloat] = Field( default=None, title="Fluid Density", description="Density (ρ) of the fluid.", units=DENSITY, ) - expansivity: pd.NonNegativeFloat = pd.Field( + expansivity: Optional[NonNegativeFloat] = Field( default=None, title="Fluid Thermal Expansivity", description="Thermal expansion coefficient (β) of the fluid.", units=THERMAL_EXPANSIVITY, ) + @classmethod def from_si_units( - thermal_conductivity: pd.NonNegativeFloat, - viscosity: pd.NonNegativeFloat, - specific_heat: pd.NonNegativeFloat, - density: pd.NonNegativeFloat, - expansivity: pd.NonNegativeFloat, - ): + cls, + thermal_conductivity: NonNegativeFloat, + viscosity: NonNegativeFloat, + specific_heat: NonNegativeFloat, + density: NonNegativeFloat, + expansivity: NonNegativeFloat, + ) -> Self: thermal_conductivity_tidy = thermal_conductivity / 1e6 # W/(m*K) -> W/(um*K) viscosity_tidy = viscosity / 1e6 # Pa*s -> kg/(um*s) specific_heat_tidy = specific_heat * 1e12 # J/(kg*K) -> um**2/(s**2*K) density_tidy = density / 1e18 # kg/m**3 -> kg/um**3 expansivity_tidy = expansivity # 1/K -> 1/K (no change) - return FluidMedium( + return cls( thermal_conductivity=thermal_conductivity_tidy, viscosity=viscosity_tidy, specific_heat=specific_heat_tidy, @@ -154,31 +159,33 @@ class SolidMedium(AbstractHeatMedium): ... 
) """ - capacity: pd.PositiveFloat = pd.Field( + capacity: Optional[PositiveFloat] = Field( None, title="Heat capacity", description=f"Specific heat capacity in unit of {SPECIFIC_HEAT_CAPACITY}.", units=SPECIFIC_HEAT_CAPACITY, ) - conductivity: pd.PositiveFloat = pd.Field( + conductivity: PositiveFloat = Field( title="Thermal conductivity", description=f"Thermal conductivity of material in units of {THERMAL_CONDUCTIVITY}.", units=THERMAL_CONDUCTIVITY, ) - density: pd.PositiveFloat = pd.Field( + density: Optional[PositiveFloat] = Field( None, title="Density", description=f"Mass density of material in units of {DENSITY}.", units=DENSITY, ) + @classmethod def from_si_units( - conductivity: pd.PositiveFloat, - capacity: pd.PositiveFloat = None, - density: pd.PositiveFloat = None, - ): + cls, + conductivity: PositiveFloat, + capacity: Optional[PositiveFloat] = None, + density: Optional[PositiveFloat] = None, + ) -> Self: """Create a SolidMedium using SI units""" new_conductivity = conductivity * 1e-6 # Convert from W/(m*K) to W/(um*K) new_capacity = capacity @@ -187,7 +194,7 @@ def from_si_units( if density is not None: new_density = density * 1e-18 - return SolidMedium( + return cls( capacity=new_capacity, conductivity=new_conductivity, density=new_density, diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 0808ec9568..aaf44b9437 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -1,6419 +1,543 @@ -"""Defines properties of the medium / materials""" +"""Compatibility shim for :mod:`tidy3d._common.components.medium`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as partially migrated to _common from __future__ import annotations -import functools from abc import ABC, abstractmethod -from math import isclose -from typing import Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union import autograd.numpy as np - -# TODO: it's hard to figure out which functions need this, for now all get it -import numpy as npo -import pydantic.v1 as pd -import xarray as xr -from autograd.differential_operators import tensor_jacobian_product - -from tidy3d.components.autograd.utils import pack_complex_vec -from tidy3d.components.material.tcad.heat import ThermalSpecType -from tidy3d.constants import ( - C_0, - CONDUCTIVITY, - EPSILON_0, - ETA_0, - HBAR, - HERTZ, - LARGEST_FP_NUMBER, - MICROMETER, - MU_0, - PERMITTIVITY, - RADPERSEC, - SECOND, - fp_eps, - pec_val, +from pydantic import ( + Field, + NonNegativeFloat, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, ) -from tidy3d.exceptions import SetupError, ValidationError -from tidy3d.log import log -from .autograd.derivative_utils import DerivativeInfo, integrate_within_bounds -from .autograd.types import AutogradFieldMap, TracedFloat, TracedPoleAndResidue, TracedPositiveFloat -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing -from .data.data_array import DATA_ARRAY_MAP, ScalarFieldDataArray, SpatialDataArray -from .data.dataset import ( - ElectromagneticFieldDataset, - PermittivityDataset, -) -from .data.unstructured.base import UnstructuredGridDataset -from .data.utils import ( - CustomSpatialDataType, - CustomSpatialDataTypeAnnotated, - _check_same_coordinates, - _get_numpy_array, - _ones_like, - _zeros_like, +from tidy3d._common.components import medium as common_medium +from tidy3d._common.components.medium import ( + ALLOWED_INTERP_METHODS, + FILL_VALUE, + 
FREQ_EVAL_INF, + LOSSY_METAL_DEFAULT_MAX_POLES, + LOSSY_METAL_DEFAULT_SAMPLING_FREQUENCY, + LOSSY_METAL_DEFAULT_TOLERANCE_RMS, + LOSSY_METAL_SCALED_REAL_PART, + PEC, + PMC, + AbstractCustomMedium, + AbstractMedium, + # IsotropicCustomMediumType, + # IsotropicCustomMediumInternalType, + # IsotropicMediumType, + AnisotropicMedium, + ArrayFloat, + ArrayGeneric, + ComplexArrayOrScalar, + CustomAnisotropicMedium, + CustomAnisotropicMediumInternal, + CustomDebye, + CustomDispersiveMedium, + CustomDrude, + CustomIsotropicMedium, + CustomLorentz, + CustomMedium, + CustomPoleResidue, + CustomSellmeier, + Debye, + DispersiveMedium, + Drude, + FrequencyArray, + # AnisotropicMediumFromMedium2D, + FullyAnisotropicMedium, + IsotropicCustomMediumType, + # SurfaceImpedanceFitterParam, + # AbstractSurfaceRoughness, + # HammerstadSurfaceRoughness, + # HuraySurfaceRoughness, + # SurfaceRoughnessType, + # LossyMetalMedium, + # IsotropicUniformMediumFor2DType, + IsotropicUniformMediumType, + Lorentz, + Medium, + PECMedium, + PMCMedium, + PoleResidue, + Sellmeier, + WeightFunction, + ensure_freq_in_range, + # AbstractPerturbationMedium, + # PerturbationMedium, + # PerturbationPoleResidue, + # PerturbationMediumType, + # T, + # _get_all_subclasses, + # MediumType3D, + # Medium2D, + # PEC2D, + # MediumType, + # medium_from_nk, + extend_isotropic_uniform_medium_type, + extend_perturbation_medium_type, ) -from .data.validators import validate_no_nans -from .dispersion_fitter import ( - LOSS_CHECK_MAX, - LOSS_CHECK_MIN, - LOSS_CHECK_NUM, +from tidy3d.components.autograd.types import TracedFloat +from tidy3d.components.base import Tidy3dBaseModel, cached_property +from tidy3d.components.data.data_array import DATA_ARRAY_MAP, ScalarFieldDataArray, SpatialDataArray +from tidy3d.components.dispersion_fitter import ( fit, - imag_resp_extrema_locs, ) -from .geometry.base import Geometry -from .grid.grid import Coords, Grid -from .nonlinear import ( # noqa: F401 +from tidy3d.components.geometry.base import Geometry +from tidy3d.components.nonlinear import ( KerrNonlinearity, NonlinearModel, NonlinearSpec, NonlinearSusceptibility, TwoPhotonAbsorption, ) -from .parameter_perturbation import ( +from tidy3d.components.parameter_perturbation import ( IndexPerturbation, ParameterPerturbation, PermittivityPerturbation, ) -from .time_modulation import ModulationSpec -from .transformation import RotationType -from .types import ( - TYPE_TAG_STR, - ArrayComplex1D, - ArrayComplex3D, - ArrayFloat1D, - Ax, - Axis, - Bound, - Complex, - FreqBound, - InterpMethod, - PermittivityComponent, - PoleAndResidue, - TensorReal, - annotate_type, +from tidy3d.components.types.base import TYPE_TAG_STR, FreqBound, TensorReal +from tidy3d.components.validators import validate_parameter_perturbation +from tidy3d.components.viz import VisualizationSpec, add_ax_if_none +from tidy3d.constants import ( + C_0, + CONDUCTIVITY, + EPSILON_0, + ETA_0, + HERTZ, + MICROMETER, + MU_0, + PERMITTIVITY, + RADPERSEC, + fp_eps, ) -from .validators import validate_name_str, validate_parameter_perturbation -from .viz import VisualizationSpec, add_ax_if_none - -# evaluate frequency as this number (Hz) if inf -FREQ_EVAL_INF = 1e50 - -# extrapolation option in custom medium -FILL_VALUE = "extrapolate" - -# Lossy metal -LOSSY_METAL_DEFAULT_SAMPLING_FREQUENCY = 20 -LOSSY_METAL_SCALED_REAL_PART = 10.0 -LOSSY_METAL_DEFAULT_MAX_POLES = 5 -LOSSY_METAL_DEFAULT_TOLERANCE_RMS = 1e-3 - - -def ensure_freq_in_range(eps_model: Callable[[float], complex]) -> Callable[[float], complex]: 
- """Decorate ``eps_model`` to log warning if frequency supplied is out of bounds.""" - - @functools.wraps(eps_model) - def _eps_model(self, frequency: float) -> complex: - """New eps_model function.""" - # evaluate infs and None as FREQ_EVAL_INF - is_inf_scalar = isinstance(frequency, float) and np.isinf(frequency) - if frequency is None or is_inf_scalar: - frequency = FREQ_EVAL_INF - - if isinstance(frequency, np.ndarray): - frequency = frequency.astype(float) - frequency[np.where(np.isinf(frequency))] = FREQ_EVAL_INF - - # if frequency range not present just return original function - if self.frequency_range is None: - return eps_model(self, frequency) - - fmin, fmax = self.frequency_range - # don't warn for evaluating infinite frequency - if is_inf_scalar: - return eps_model(self, frequency) - - outside_lower = np.zeros_like(frequency, dtype=bool) - outside_upper = np.zeros_like(frequency, dtype=bool) - - if fmin > 0: - outside_lower = frequency / fmin < 1 - fp_eps - elif fmin == 0: - outside_lower = frequency < 0 - - if fmax > 0: - outside_upper = frequency / fmax > 1 + fp_eps - - if np.any(outside_lower | outside_upper): - log.warning( - "frequency passed to 'Medium.eps_model()'" - f"is outside of 'Medium.frequency_range' = {self.frequency_range}", - capture=False, - ) - return eps_model(self, frequency) - - return _eps_model - - -""" Medium Definitions """ - +from tidy3d.exceptions import SetupError, ValidationError +from tidy3d.log import log -class AbstractMedium(ABC, Tidy3dBaseModel): - """A medium within which electromagnetic waves propagate.""" +if TYPE_CHECKING: + import xarray as xr + from autograd.numpy.numpy_boxes import ArrayBox + from pydantic import FieldValidationInfo + + from tidy3d._common.components.medium import ArrayComplex + from tidy3d.compat import Self + from tidy3d.components.autograd.derivative_utils import DerivativeInfo + from tidy3d.components.autograd.types import AutogradFieldMap + from tidy3d.components.data.dataset import ElectromagneticFieldDataset + from tidy3d.components.data.utils import CustomSpatialDataType + from tidy3d.components.time_modulation import ModulationSpec + from tidy3d.components.transformation import RotationType + from tidy3d.components.types.base import ( + ArrayComplex1D, + ArrayComplex3D, + ArrayFloat1D, + Ax, + Axis, + Bound, + Complex, + InterpMethod, + PermittivityComponent, + PolesAndResidues, + ) - name: str = pd.Field(None, title="Name", description="Optional unique name for medium.") - frequency_range: FreqBound = pd.Field( - None, - title="Frequency Range", - description="Optional range of validity for the medium.", - units=(HERTZ, HERTZ), - ) +class SurfaceImpedanceFitterParam(Tidy3dBaseModel): + """Advanced parameters for fitting surface impedance of a :class:`.LossyMetalMedium`. + Internally, the quantity to be fitted is surface impedance divided by ``-1j * \\omega``. + """ - allow_gain: bool = pd.Field( - False, - title="Allow gain medium", - description="Allow the medium to be active. Caution: " - "simulations with a gain medium are unstable, and are likely to diverge." - "Simulations where ``allow_gain`` is set to ``True`` will still be charged even if " - "diverged. 
Monitor data up to the divergence point will still be returned and can be " - "useful in some cases.", + max_num_poles: PositiveInt = Field( + LOSSY_METAL_DEFAULT_MAX_POLES, + title="Maximal Number Of Poles", + description="Maximal number of poles in complex-conjugate pole residue model for " + "fitting surface impedance.", ) - nonlinear_spec: Union[NonlinearSpec, NonlinearSusceptibility] = pd.Field( - None, - title="Nonlinear Spec", - description="Nonlinear spec applied on top of the base medium properties.", + tolerance_rms: NonNegativeFloat = Field( + LOSSY_METAL_DEFAULT_TOLERANCE_RMS, + title="Tolerance In Fitting", + description="Tolerance in fitting.", ) - modulation_spec: ModulationSpec = pd.Field( - None, - title="Modulation Spec", - description="Modulation spec applied on top of the base medium properties.", + frequency_sampling_points: PositiveInt = Field( + LOSSY_METAL_DEFAULT_SAMPLING_FREQUENCY, + title="Number Of Sampling Frequencies", + description="Number of sampling frequencies used in fitting.", ) - viz_spec: Optional[VisualizationSpec] = pd.Field( - None, - title="Visualization Specification", - description="Plotting specification for visualizing medium.", + log_sampling: bool = Field( + True, + title="Frequencies Sampling In Log Scale", + description="Whether to sample frequencies logarithmically (``True``), " + "or linearly (``False``).", ) - @cached_property - def _nonlinear_models(self) -> list: - """The nonlinear models in the nonlinear_spec.""" - if self.nonlinear_spec is None: - return [] - if isinstance(self.nonlinear_spec, NonlinearModel): - return [self.nonlinear_spec] - if self.nonlinear_spec.models is None: - return [] - return list(self.nonlinear_spec.models) - - @cached_property - def _nonlinear_num_iters(self) -> pd.PositiveInt: - """The num_iters of the nonlinear_spec.""" - if self.nonlinear_spec is None: - return 0 - if isinstance(self.nonlinear_spec, NonlinearModel): - if self.nonlinear_spec.numiters is None: - return 1 # old default value for backwards compatibility - return self.nonlinear_spec.numiters - return self.nonlinear_spec.num_iters - - def _post_init_validators(self) -> None: - """Call validators taking ``self`` that get run after init.""" - self._validate_nonlinear_spec() - self._validate_modulation_spec_post_init() - - def _validate_nonlinear_spec(self) -> None: - """Check compatibility with nonlinear_spec.""" - if self.__class__.__name__ == "AnisotropicMedium" and any( - comp.nonlinear_spec is not None for comp in [self.xx, self.yy, self.zz] - ): - raise ValidationError( - "Nonlinearities are not currently supported for the components " - "of an anisotropic medium." - ) - if self.__class__.__name__ == "Medium2D" and any( - comp.nonlinear_spec is not None for comp in [self.ss, self.tt] - ): - raise ValidationError( - "Nonlinearities are not currently supported for the components of a 2D medium." - ) - - if self.nonlinear_spec is None: - return - if isinstance(self.nonlinear_spec, NonlinearModel): - log.warning( - "The API for 'nonlinear_spec' has changed. " - "The old usage 'nonlinear_spec=model' is deprecated and will be removed " - "in a future release. The new usage is " - r"'nonlinear_spec=NonlinearSpec(models=\[model])'." 
- ) - for model in self._nonlinear_models: - model._validate_medium_type(self) - model._validate_medium(self) - if ( - isinstance(self.nonlinear_spec, NonlinearSpec) - and isinstance(model, NonlinearSusceptibility) - and model.numiters is not None - ): - raise ValidationError( - "'NonlinearSusceptibility.numiters' is deprecated. " - "Please use 'NonlinearSpec.num_iters' instead." - ) - def _validate_modulation_spec_post_init(self) -> None: - """Check compatibility with nonlinear_spec.""" - if self.__class__.__name__ == "Medium2D" and any( - comp.modulation_spec is not None for comp in [self.ss, self.tt] - ): - raise ValidationError( - "Time modulation is not currently supported for the components of a 2D medium." - ) +class AbstractSurfaceRoughness(Tidy3dBaseModel): + """Abstract class for modeling surface roughness of lossy metal.""" - heat_spec: Optional[ThermalSpecType] = pd.Field( - None, - title="Heat Specification", - description="DEPRECATED: Use :class:`MultiPhysicsMedium`. Specification of the medium heat properties. They are " - "used for solving the heat equation via the :class:`HeatSimulation` interface. Such simulations can be" - "used for investigating the influence of heat propagation on the properties of optical systems. " - "Once the temperature distribution in the system is found using :class:`HeatSimulation` object, " - "``Simulation.perturbed_mediums_copy()`` can be used to convert mediums with perturbation " - "models defined into spatially dependent custom mediums. " - "Otherwise, the ``heat_spec`` does not directly affect the running of an optical " - "``Simulation``.", - discriminator=TYPE_TAG_STR, - ) + @abstractmethod + def roughness_correction_factor( + self, frequency: ArrayFloat1D, skin_depths: ArrayFloat1D + ) -> ArrayComplex1D: + """Complex-valued roughness correction factor applied to surface impedance. - @property - def charge(self) -> None: - return None + Notes + ----- + The roughness correction factor should be causal. It is multiplied to the + surface impedance of the lossy metal to account for the effects of surface roughness. - @property - def electrical(self) -> None: - return None + Parameters + ---------- + frequency : ArrayFloat1D + Frequency to evaluate roughness correction factor at (Hz). + skin_depths : ArrayFloat1D + Skin depths of the lossy metal that is frequency-dependent. - @property - def heat(self): - return self.heat_spec + Returns + ------- + ArrayComplex1D + The causal roughness correction factor evaluated at ``frequency``. + """ - @property - def optical(self) -> None: - return None - @pd.validator("modulation_spec", always=True) - @skip_if_fields_missing(["nonlinear_spec"]) - def _validate_modulation_spec(cls, val, values): - """Check compatibility with modulation_spec.""" - nonlinear_spec = values.get("nonlinear_spec") - if val is not None and nonlinear_spec is not None: - raise ValidationError( - f"For medium class {cls.__name__}, 'modulation_spec' of class {type(val)} and " - f"'nonlinear_spec' of class {type(nonlinear_spec)} are " - "not simultaneously supported." - ) - return val +class HammerstadSurfaceRoughness(AbstractSurfaceRoughness): + """Modified Hammerstad surface roughness model. It's a popular model that works well + under 5 GHz for surface roughness below 2 micrometer RMS. 
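
To make the abstract `roughness_correction_factor` contract concrete, here is a minimal sketch of a hypothetical `AbstractSurfaceRoughness` subclass (illustrative only, not one of the shipped models):

    import numpy as np
    from pydantic import Field

    class UniformRoughness(AbstractSurfaceRoughness):
        # Hypothetical model: a constant, frequency-independent correction.
        factor: float = Field(1.5, gt=1.0)

        def roughness_correction_factor(self, frequency, skin_depths):
            # constant real factor, same shape as the inputs; trivially causal
            return self.factor * np.ones_like(np.asarray(frequency), dtype=complex)
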
- _name_validator = validate_name_str() + Note + ---- - @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - return True + The power loss compared to smooth surface is described by: - @cached_property - def is_time_modulated(self) -> bool: - """Whether any component of the medium is time modulated.""" - return self.modulation_spec is not None and self.modulation_spec.applied_modulation + .. math:: - @cached_property - def is_nonlinear(self) -> bool: - """Whether the medium is nonlinear.""" - return self.nonlinear_spec is not None + 1 + (RF-1) \\frac{2}{\\pi}\\arctan(1.4\\frac{R_q^2}{\\delta^2}) - @cached_property - def is_custom(self) -> bool: - """Whether the medium is custom.""" - return isinstance(self, AbstractCustomMedium) + where :math:`\\delta` is skin depth, :math:`R_q` the RMS peak-to-valley height, and RF + roughness factor. - @cached_property - def is_fully_anisotropic(self) -> bool: - """Whether the medium is fully anisotropic.""" - return isinstance(self, FullyAnisotropicMedium) + Note + ---- + This model is based on: - @cached_property - def _incompatible_material_types(self) -> list[str]: - """A list of material properties present which may lead to incompatibilities.""" - properties = [ - self.is_time_modulated, - self.is_nonlinear, - self.is_custom, - self.is_fully_anisotropic, - ] - names = ["time modulated", "nonlinear", "custom", "fully anisotropic"] - types = [name for name, prop in zip(names, properties) if prop] - return types + Y. Shlepnev, C. Nwachukwu, "Roughness characterization for interconnect analysis", + 2011 IEEE International Symposium on Electromagnetic Compatibility, + (DOI: 10.1109/ISEMC.2011.6038367), 2011. - @cached_property - def _has_incompatibilities(self) -> bool: - """Whether the medium has incompatibilities. Certain medium types are incompatible - with certain others, and such pairs are not allowed to intersect in a simulation.""" - return len(self._incompatible_material_types) > 0 - - def _compatible_with(self, other: AbstractMedium) -> bool: - """Whether these two media are compatible if in structures that intersect.""" - if not (self._has_incompatibilities and other._has_incompatibilities): - return True - for med1, med2 in [(self, other), (other, self)]: - if med1.is_custom: - # custom and fully_anisotropic is OK - if med2.is_nonlinear or med2.is_time_modulated: - return False - if med1.is_fully_anisotropic: - if med2.is_nonlinear or med2.is_time_modulated: - return False - if med1.is_nonlinear: - if med2.is_time_modulated: - return False - return True + V. Dmitriev-Zdorov, B. Simonovich, I. Kochikov, "A Causal Conductor Roughness Model + and its Effect on Transmission Line Characteristics", Signal Integrity Journal, 2018. """ - @abstractmethod - def eps_model(self, frequency: float) -> complex: - # TODO this should be moved out of here into FDTD Simulation Mediums? - """Complex-valued permittivity as a function of frequency. + rq: PositiveFloat = Field( + title="RMS Peak-to-Valley Height", + description="RMS peak-to-valley height (Rq) of the surface roughness.", + units=MICROMETER, + ) - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). + roughness_factor: float = Field( + 2.0, + title="Roughness Factor", + description="Expected maximal increase in conductor losses due to roughness effect. 
" + "Value 2 gives the classic Hammerstad equation.", + gt=1.0, + ) - Returns - ------- - complex - Complex-valued relative permittivity evaluated at ``frequency``. - """ + def roughness_correction_factor( + self, frequency: ArrayFloat1D, skin_depths: ArrayFloat1D + ) -> ArrayComplex1D: + """Complex-valued roughness correction factor applied to surface impedance. - def nk_model(self, frequency: float) -> tuple[float, float]: - """Real and imaginary parts of the refactive index as a function of frequency. + Notes + ----- + The roughness correction factor should be causal. It is multiplied to the + surface impedance of the lossy metal to account for the effects of surface roughness. Parameters ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). + frequency : ArrayFloat1D + Frequency to evaluate roughness correction factor at (Hz). + skin_depths : ArrayFloat1D + Skin depths of the lossy metal that is frequency-dependent. Returns ------- - Tuple[float, float] - Real part (n) and imaginary part (k) of refractive index of medium. + ArrayComplex1D + The causal roughness correction factor evaluated at ``frequency``. """ - eps_complex = self.eps_model(frequency=frequency) - return self.eps_complex_to_nk(eps_complex) - - def loss_tangent_model(self, frequency: float) -> tuple[float, float]: - """Permittivity and loss tangent as a function of frequency. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). + normalized_laplace = -1.4j * (self.rq / skin_depths) ** 2 + sqrt_normalized_laplace = np.sqrt(normalized_laplace) + causal_response = np.log( + 1 + 2 * sqrt_normalized_laplace / (1 + normalized_laplace) + ) + 2 * np.arctan(sqrt_normalized_laplace) + return 1 + (self.roughness_factor - 1) / np.pi * causal_response - Returns - ------- - Tuple[float, float] - Real part of permittivity and loss tangent. - """ - eps_complex = self.eps_model(frequency=frequency) - return self.eps_complex_to_eps_loss_tangent(eps_complex) - @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: - """Main diagonal of the complex-valued permittivity tensor as a function of frequency. +class HuraySurfaceRoughness(AbstractSurfaceRoughness): + """Huray surface roughness model. - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). + Note + ---- - Returns - ------- - Tuple[complex, complex, complex] - The diagonal elements of the relative permittivity tensor evaluated at ``frequency``. - """ + The power loss compared to smooth surface is described by: - # This only needs to be overwritten for anisotropic materials - eps = self.eps_model(frequency) - return (eps, eps, eps) + .. math:: - def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]: - """Main diagonal of the complex-valued permittivity tensor for numerical considerations - such as meshing and runtime estimation. + \\frac{A_{matte}}{A_{flat}} + \\frac{3}{2}\\sum_i f_i/[1+\\frac{\\delta}{r_i}+\\frac{\\delta^2}{2r_i^2}] - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). + where :math:`\\delta` is skin depth, :math:`r_i` the radius of sphere, + :math:`\\frac{A_{matte}}{A_{flat}}` the relative area of the matte compared to flat surface, + and :math:`f_i=N_i4\\pi r_i^2/A_{flat}` the ratio of total sphere + surface area (number of spheres :math:`N_i` times the individual sphere surface area) + to the flat surface area. 
- Returns - ------- - Tuple[complex, complex, complex] - The diagonal elements of relative permittivity tensor relevant for numerical - considerations evaluated at ``frequency``. - """ + Note + ---- + This model is based on: - if self.is_pec: - # also 1 for lossy metal and Medium2D, but let's handle them in the subclass. - return (1.0 + 0j,) * 3 + J. Eric Bracken, "A Causal Huray Model for Surface Roughness", DesignCon, 2012. + """ - return self.eps_diagonal(frequency) + relative_area: PositiveFloat = Field( + 1, + title="Relative Area", + description="Relative area of the matte base compared to a flat surface", + ) - def eps_comp(self, row: Axis, col: Axis, frequency: float) -> complex: - """Single component of the complex-valued permittivity tensor as a function of frequency. - - Parameters - ---------- - row : int - Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - col : int - Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - complex - Element of the relative permittivity tensor evaluated at ``frequency``. - """ - - # This only needs to be overwritten for anisotropic materials - if row == col: - return self.eps_model(frequency) - return 0j - - def _eps_plot( - self, frequency: float, eps_component: Optional[PermittivityComponent] = None - ) -> float: - """Returns real part of epsilon for plotting. A specific component of the epsilon tensor can - be selected for anisotropic medium. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at. - eps_component : PermittivityComponent - Component of the permittivity tensor to plot - e.g. ``"xx"``, ``"yy"``, ``"zz"``, ``"xy"``, ``"yz"``, ... - Defaults to ``None``, which returns the average of the diagonal values. - - Returns - ------- - float - Element ``eps_component`` of the relative permittivity tensor evaluated at ``frequency``. - """ - # Assumes the material is isotropic - # Will need to be overridden for anisotropic materials - return self.eps_model(frequency).real - - @cached_property - @abstractmethod - def n_cfl(self) -> None: - # TODO this should be moved out of here into FDTD Simulation Mediums? - """To ensure a stable FDTD simulation, it is essential to select an appropriate - time step size in accordance with the CFL condition. The maximal time step - size is inversely proportional to the speed of light in the medium, and thus - proportional to the index of refraction. However, for dispersive medium, - anisotropic medium, and other more complicated media, there are complications in - deciding on the choice of the index of refraction. - - This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - """ - - @add_ax_if_none - def plot(self, freqs: float, ax: Ax = None) -> Ax: - """Plot n, k of a :class:`.Medium` as a function of frequency. - - Parameters - ---------- - freqs: float - Frequencies (Hz) to evaluate the medium properties at. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. 
- """ - - freqs = np.array(freqs) - eps_complex = np.array([self.eps_model(freq) for freq in freqs]) - n, k = AbstractMedium.eps_complex_to_nk(eps_complex) - - freqs_thz = freqs / 1e12 - ax.plot(freqs_thz, n, label="n") - ax.plot(freqs_thz, k, label="k") - ax.set_xlabel("frequency (THz)") - ax.set_title("medium dispersion") - ax.legend() - ax.set_aspect("auto") - return ax - - """ Conversion helper functions """ - - @staticmethod - def nk_to_eps_complex(n: float, k: float = 0.0) -> complex: - """Convert n, k to complex permittivity. - - Parameters - ---------- - n : float - Real part of refractive index. - k : float = 0.0 - Imaginary part of refrative index. - - Returns - ------- - complex - Complex-valued relative permittivity. - """ - eps_real = n**2 - k**2 - eps_imag = 2 * n * k - return eps_real + 1j * eps_imag - - @staticmethod - def eps_complex_to_nk(eps_c: complex) -> tuple[float, float]: - """Convert complex permittivity to n, k values. - - Parameters - ---------- - eps_c : complex - Complex-valued relative permittivity. - - Returns - ------- - Tuple[float, float] - Real and imaginary parts of refractive index (n & k). - """ - eps_c = np.array(eps_c) - ref_index = np.sqrt(eps_c) - return np.real(ref_index), np.imag(ref_index) - - @staticmethod - def nk_to_eps_sigma(n: float, k: float, freq: float) -> tuple[float, float]: - """Convert ``n``, ``k`` at frequency ``freq`` to permittivity and conductivity values. - - Parameters - ---------- - n : float - Real part of refractive index. - k : float = 0.0 - Imaginary part of refrative index. - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[float, float] - Real part of relative permittivity & electric conductivity. - """ - eps_complex = AbstractMedium.nk_to_eps_complex(n, k) - eps_real, eps_imag = eps_complex.real, eps_complex.imag - omega = 2 * np.pi * freq - sigma = omega * eps_imag * EPSILON_0 - return eps_real, sigma - - @staticmethod - def eps_sigma_to_eps_complex(eps_real: float, sigma: float, freq: float) -> complex: - """convert permittivity and conductivity to complex permittivity at freq - - Parameters - ---------- - eps_real : float - Real-valued relative permittivity. - sigma : float - Conductivity. - freq : float - Frequency to evaluate permittivity at (Hz). - If not supplied, returns real part of permittivity (limit as frequency -> infinity.) - - Returns - ------- - complex - Complex-valued relative permittivity. - """ - if freq is None: - return eps_real - omega = 2 * np.pi * freq - - return eps_real + 1j * sigma / omega / EPSILON_0 - - @staticmethod - def eps_complex_to_eps_sigma(eps_complex: complex, freq: float) -> tuple[float, float]: - """Convert complex permittivity at frequency ``freq`` - to permittivity and conductivity values. - - Parameters - ---------- - eps_complex : complex - Complex-valued relative permittivity. - freq : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[float, float] - Real part of relative permittivity & electric conductivity. - """ - eps_real, eps_imag = eps_complex.real, eps_complex.imag - omega = 2 * np.pi * freq - sigma = omega * eps_imag * EPSILON_0 - return eps_real, sigma - - @staticmethod - def eps_complex_to_eps_loss_tangent(eps_complex: complex) -> tuple[float, float]: - """Convert complex permittivity to permittivity and loss tangent. - - Parameters - ---------- - eps_complex : complex - Complex-valued relative permittivity. 
- - Returns - ------- - Tuple[float, float] - Real part of relative permittivity & loss tangent - """ - eps_real, eps_imag = eps_complex.real, eps_complex.imag - return eps_real, eps_imag / eps_real - - @staticmethod - def eps_loss_tangent_to_eps_complex(eps_real: float, loss_tangent: float) -> complex: - """Convert permittivity and loss tangent to complex permittivity. - - Parameters - ---------- - eps_real : float - Real part of relative permittivity - loss_tangent : float - Loss tangent - - Returns - ------- - eps_complex : complex - Complex-valued relative permittivity. - """ - return eps_real * (1 + 1j * loss_tangent) - - @staticmethod - def eV_to_angular_freq(f_eV: float): - """Convert frequency in unit of eV to rad/s. - - Parameters - ---------- - f_eV : float - Frequency in unit of eV - """ - return f_eV / HBAR - - @staticmethod - def angular_freq_to_eV(f_rad: float): - """Convert frequency in unit of rad/s to eV. - - Parameters - ---------- - f_rad : float - Frequency in unit of rad/s - """ - return f_rad * HBAR - - @staticmethod - def angular_freq_to_Hz(f_rad: float): - """Convert frequency in unit of rad/s to Hz. - - Parameters - ---------- - f_rad : float - Frequency in unit of rad/s - """ - return f_rad / 2 / np.pi - - @staticmethod - def Hz_to_angular_freq(f_hz: float): - """Convert frequency in unit of Hz to rad/s. - - Parameters - ---------- - f_hz : float - Frequency in unit of Hz - """ - return f_hz * 2 * np.pi - - @ensure_freq_in_range - def sigma_model(self, freq: float) -> complex: - """Complex-valued conductivity as a function of frequency. - - Parameters - ---------- - freq: float - Frequency to evaluate conductivity at (Hz). - - Returns - ------- - complex - Complex conductivity at this frequency. - """ - omega = freq * 2 * np.pi - eps_complex = self.eps_model(freq) - eps_inf = self.eps_model(np.inf) - sigma = (eps_inf - eps_complex) * 1j * omega * EPSILON_0 - return sigma - - @cached_property - def is_pec(self): - """Whether the medium is a PEC.""" - return False - - @cached_property - def is_pmc(self): - """Whether the medium is a PMC.""" - return False - - def sel_inside(self, bounds: Bound) -> AbstractMedium: - """Return a new medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - AbstractMedium - Medium with reduced data. 
- """ - - if self.modulation_spec is not None: - modulation_reduced = self.modulation_spec.sel_inside(bounds) - return self.updated_copy(modulation_spec=modulation_reduced) - - return self - - """ Autograd code """ - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute the adjoint derivatives for this object.""" - raise NotImplementedError(f"Can't compute derivative for 'Medium': '{type(self)}'.") - - def _derivative_eps_sigma_volume( - self, E_der_map: ElectromagneticFieldDataset, bounds: Bound - ) -> dict[str, xr.DataArray]: - """Get the derivative w.r.t permittivity and conductivity in the volume.""" - - vjp_eps_complex = self._derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds) - - values = vjp_eps_complex.values - - # compute directly with frequency dimension - freqs = vjp_eps_complex.coords["f"].values - omegas = 2 * np.pi * freqs - eps_vjp = np.real(values) - sigma_vjp = -np.imag(values) / omegas / EPSILON_0 - - eps_vjp = np.sum(eps_vjp) - sigma_vjp = np.sum(sigma_vjp) - - return {"permittivity": eps_vjp, "conductivity": sigma_vjp} - - def _derivative_eps_complex_volume( - self, E_der_map: ElectromagneticFieldDataset, bounds: Bound - ) -> xr.DataArray: - """Get the derivative w.r.t complex-valued permittivity in the volume.""" - vjp_value = None - for field_name in ("Ex", "Ey", "Ez"): - fld = E_der_map[field_name] - vjp_value_fld = integrate_within_bounds( - arr=fld, - dims=("x", "y", "z"), - bounds=bounds, - ) - if vjp_value is None: - vjp_value = vjp_value_fld - else: - vjp_value += vjp_value_fld - - return vjp_value - - def __repr__(self): - """If the medium has a name, use it as the representation. Otherwise, use the default representation.""" - if self.name: - return self.name - return super().__repr__() - - -class AbstractCustomMedium(AbstractMedium, ABC): - """A spatially varying medium.""" - - interp_method: InterpMethod = pd.Field( - "nearest", - title="Interpolation method", - description="Interpolation method to obtain permittivity values " - "that are not supplied at the Yee grids; For grids outside the range " - "of the supplied data, extrapolation will be applied. When the extrapolated " - "value is smaller (greater) than the minimal (maximal) of the supplied data, " - "the extrapolated value will take the minimal (maximal) of the supplied data.", - ) - - subpixel: bool = pd.Field( - False, - title="Subpixel averaging", - description="If ``True``, apply the subpixel averaging method specified by " - "``Simulation``'s field ``subpixel`` for this type of material on the " - "interface of the structure, including exterior boundary and " - "intersection interfaces with other structures.", - ) - - derived_from: Optional[annotate_type(PerturbationMediumType)] = pd.Field( - None, - title="Parent Medium", - description="If not ``None``, it records the parent medium from which this medium was derived.", - ) - - @cached_property - @abstractmethod - def is_isotropic(self) -> bool: - """The medium is isotropic or anisotropic.""" - - def _interp_method(self, comp: Axis) -> InterpMethod: - """Interpolation method applied to comp.""" - return self.interp_method - - @abstractmethod - def eps_dataarray_freq( - self, frequency: float - ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: - """Permittivity array at ``frequency``. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). 
- - Returns - ------- - Tuple[ - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset` - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset` - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset` - ], - ] - The permittivity evaluated at ``frequency``. - """ - - def eps_diagonal_on_grid( - self, - frequency: float, - coords: Coords, - ) -> tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]: - """Spatial profile of main diagonal of the complex-valued permittivity - at ``frequency`` interpolated at the supplied coordinates. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). - coords : :class:`.Coords` - The grid point coordinates over which interpolation is performed. - - Returns - ------- - Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D] - The complex-valued permittivity tensor at ``frequency`` interpolated - at the supplied coordinate. - """ - eps_spatial = self.eps_dataarray_freq(frequency) - if self.is_isotropic: - eps_interp = _get_numpy_array( - coords.spatial_interp(eps_spatial[0], self._interp_method(0)) - ) - return (eps_interp, eps_interp, eps_interp) - return tuple( - _get_numpy_array(coords.spatial_interp(eps_comp, self._interp_method(comp))) - for comp, eps_comp in enumerate(eps_spatial) - ) - - def eps_comp_on_grid( - self, - row: Axis, - col: Axis, - frequency: float, - coords: Coords, - ) -> ArrayComplex3D: - """Spatial profile of a single component of the complex-valued permittivity tensor at - ``frequency`` interpolated at the supplied coordinates. - - Parameters - ---------- - row : int - Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - col : int - Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - frequency : float - Frequency to evaluate permittivity at (Hz). - coords : :class:`.Coords` - The grid point coordinates over which interpolation is performed. - - Returns - ------- - ArrayComplex3D - Single component of the complex-valued permittivity tensor at ``frequency`` interpolated - at the supplied coordinates. - """ - - if row == col: - return self.eps_diagonal_on_grid(frequency, coords)[row] - return 0j - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued spatially averaged permittivity as a function of frequency.""" - if self.is_isotropic: - return np.mean(_get_numpy_array(self.eps_dataarray_freq(frequency)[0])) - return np.mean( - [np.mean(_get_numpy_array(eps_comp)) for eps_comp in self.eps_dataarray_freq(frequency)] - ) - - @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: - """Main diagonal of the complex-valued permittivity tensor - at ``frequency``. Spatially, we take max{||eps||}, so that autoMesh generation - works appropriately. - """ - eps_spatial = self.eps_dataarray_freq(frequency) - if self.is_isotropic: - eps_comp = _get_numpy_array(eps_spatial[0]).ravel() - eps = eps_comp[np.argmax(np.abs(eps_comp))] - return (eps, eps, eps) - eps_spatial_array = (_get_numpy_array(eps_comp).ravel() for eps_comp in eps_spatial) - return tuple(eps_comp[np.argmax(np.abs(eps_comp))] for eps_comp in eps_spatial_array) - - def _get_real_vals(self, x: np.ndarray) -> np.ndarray: - """Grab the real part of the values in array. 
- Used for _eps_bounds() - """ - return _get_numpy_array(np.real(x)).ravel() - - def _eps_bounds( - self, - frequency: Optional[float] = None, - eps_component: Optional[PermittivityComponent] = None, - ) -> tuple[float, float]: - """Returns permittivity bounds for setting the color bounds when plotting. - - Parameters - ---------- - frequency : float = None - Frequency to evaluate the relative permittivity of all mediums. - If not specified, evaluates at infinite frequency. - eps_component : Optional[PermittivityComponent] = None - Component of the permittivity tensor to plot for anisotropic materials, - e.g. ``"xx"``, ``"yy"``, ``"zz"``, ``"xy"``, ``"yz"``, ... - Defaults to ``None``, which returns the average of the diagonal values. - - Returns - ------- - Tuple[float, float] - The min and max values of the permittivity for the selected component and evaluated at ``frequency``. - """ - eps_dataarray = self.eps_dataarray_freq(frequency) - all_eps = np.concatenate(self._get_real_vals(eps_comp) for eps_comp in eps_dataarray) - return (np.min(all_eps), np.max(all_eps)) - - @staticmethod - def _validate_isreal_dataarray(dataarray: CustomSpatialDataType) -> bool: - """Validate that the dataarray is real""" - return np.all(np.isreal(_get_numpy_array(dataarray))) - - @staticmethod - def _validate_isreal_dataarray_tuple( - dataarray_tuple: tuple[CustomSpatialDataType, ...], - ) -> bool: - """Validate that the dataarray is real""" - return np.all([AbstractCustomMedium._validate_isreal_dataarray(f) for f in dataarray_tuple]) - - @abstractmethod - def _sel_custom_data_inside(self, bounds: Bound) -> None: - """Return a new medium that contains the minimal amount custom data necessary to cover - a spatial region defined by ``bounds``.""" - - def sel_inside(self, bounds: Bound) -> AbstractCustomMedium: - """Return a new medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - AbstractMedium - Medium with reduced data. 
- """ - - self_mod_data_reduced = super().sel_inside(bounds) - - return self_mod_data_reduced._sel_custom_data_inside(bounds) - - @staticmethod - def _not_loaded(field): - """Check whether data was not loaded.""" - if isinstance(field, str) and field in DATA_ARRAY_MAP: - return True - # attempting to construct an UnstructuredGridDataset from a dict - if isinstance(field, dict) and field.get("type") in ( - "TriangularGridDataset", - "TetrahedralGridDataset", - ): - return any( - isinstance(subfield, str) and subfield in DATA_ARRAY_MAP - for subfield in [field["points"], field["cells"], field["values"]] - ) - # attempting to pass an UnstructuredGridDataset with zero points - if isinstance(field, UnstructuredGridDataset): - return any(len(subfield) == 0 for subfield in [field.points, field.cells, field.values]) - - def _derivative_field_cmp( - self, - E_der_map: ElectromagneticFieldDataset, - spatial_data: PermittivityDataset, - dim: str, - ) -> np.ndarray: - coords_interp = {key: val for key, val in spatial_data.coords.items() if len(val) > 1} - dims_sum = {dim for dim in spatial_data.coords.keys() if dim not in coords_interp} - - eps_coordinate_shape = [ - len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz" - ] - - # compute sizes along each of the interpolation dimensions - sizes_list = [] - for _, coords in coords_interp.items(): - num_coords = len(coords) - coords = np.array(coords) - - # compute distances between midpoints for all internal coords - mid_points = (coords[1:] + coords[:-1]) / 2.0 - dists = np.diff(mid_points) - sizes = np.zeros(num_coords) - sizes[1:-1] = dists - - # estimate the sizes on the edges using 2 x the midpoint distance - sizes[0] = 2 * abs(mid_points[0] - coords[0]) - sizes[-1] = 2 * abs(coords[-1] - mid_points[-1]) - - sizes_list.append(sizes) - - # turn this into a volume element, should be re-sizeable to the gradient shape - if sizes_list: - d_vol = functools.reduce(np.outer, sizes_list) - else: - # if sizes_list is empty, then reduce() fails - d_vol = np.array(1.0) - - # TODO: probably this could be more robust. eg if the DataArray has weird edge cases - E_der_dim = E_der_map[f"E{dim}"] - E_der_dim_interp = ( - E_der_dim.interp(**coords_interp, assume_sorted=True).fillna(0.0).sum(dims_sum).sum("f") - ) - vjp_array = np.array(E_der_dim_interp.values).astype(complex) - vjp_array = vjp_array.reshape(eps_coordinate_shape) - - # multiply by volume elements (if possible, being defensive here..) - try: - vjp_array *= d_vol.reshape(vjp_array.shape) - except ValueError: - log.warning( - "Skipping volume element normalization of 'CustomMedium' gradients. " - f"Could not reshape the volume elements of shape {d_vol.shape} " - f"to the shape of the gradient {vjp_array.shape}. " - "If you encounter this warning, gradient direction will be accurate but the norm " - "will be inaccurate. Please raise an issue on the tidy3d front end with this " - "message and some information about your simulation setup and we will investigate. " - ) - return vjp_array - - -""" Dispersionless Medium """ - - -# PEC keyword -class PECMedium(AbstractMedium): - """Perfect electrical conductor class. - - Note - ---- - - To avoid confusion from duplicate PECs, must import ``tidy3d.PEC`` instance directly. 
- - - - """ - - @pd.validator("modulation_spec", always=True) - def _validate_modulation_spec(cls, val): - """Check compatibility with modulation_spec.""" - if val is not None: - raise ValidationError( - f"A 'modulation_spec' of class {type(val)} is not " - f"currently supported for medium class {cls.__name__}." - ) - return val - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - # return something like frequency with value of pec_val + 0j - return 0j * frequency + pec_val - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - """ - return 1.0 - - @cached_property - def is_pec(self): - """Whether the medium is a PEC.""" - return True - - -# PEC builtin instance -PEC = PECMedium(name="PEC") - - -# PMC keyword -class PMCMedium(AbstractMedium): - """Perfect magnetic conductor class. - - Note - ---- - - To avoid confusion from duplicate PMCs, must import ``tidy3d.PMC`` instance directly. - - - - """ - - @pd.validator("modulation_spec", always=True) - def _validate_modulation_spec(cls, val): - """Check compatibility with modulation_spec.""" - if val is not None: - raise ValidationError( - f"A 'modulation_spec' of class {type(val)} is not " - f"currently supported for medium class {cls.__name__}." - ) - return val - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - # permittivity of a PMC. - return 1.0 + 0j - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - """ - return 1.0 - - @cached_property - def is_pmc(self): - """Whether the medium is a PMC.""" - return True - - -# PEC builtin instance -PMC = PMCMedium(name="PMC") - - -class Medium(AbstractMedium): - """Dispersionless medium. Mediums define the optical properties of the materials within the simulation. - - Notes - ----- - - In a dispersion-less medium, the displacement field :math:`D(t)` reacts instantaneously to the applied - electric field :math:`E(t)`. - - .. math:: - - D(t) = \\epsilon E(t) - - Example - ------- - >>> dielectric = Medium(permittivity=4.0, name='my_medium') - >>> eps = dielectric.eps_model(200e12) - - See Also - -------- - - **Notebooks** - * `Introduction on Tidy3D working principles <../../notebooks/Primer.html#Mediums>`_ - * `Index <../../notebooks/docs/features/medium.html>`_ - - **Lectures** - * `Modeling dispersive material in FDTD `_ - - **GUI** - * `Mediums `_ - - """ - - permittivity: TracedFloat = pd.Field( - 1.0, ge=1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY - ) - - conductivity: TracedFloat = pd.Field( - 0.0, - title="Conductivity", - description="Electric conductivity. Defined such that the imaginary part of the complex " - "permittivity at angular frequency omega is given by conductivity/omega.", - units=CONDUCTIVITY, - ) - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if ``allow_gain`` is False.""" - if not values.get("allow_gain") and val < 0: - raise ValidationError( - "For passive medium, 'conductivity' must be non-negative. 
" - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, and are likely to diverge." - ) - return val - - @pd.validator("permittivity", always=True) - @skip_if_fields_missing(["modulation_spec"]) - def _permittivity_modulation_validation(cls, val, values): - """Assert modulated permittivity cannot be <= 0.""" - modulation = values.get("modulation_spec") - if modulation is None or modulation.permittivity is None: - return val - - min_eps_inf = np.min(_get_numpy_array(val)) - if min_eps_inf - modulation.permittivity.max_modulation <= 0: - raise ValidationError( - "The minimum permittivity value with modulation applied was found to be negative." - ) - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["modulation_spec", "allow_gain"]) - def _passivity_modulation_validation(cls, val, values): - """Assert passive medium if ``allow_gain`` is False.""" - modulation = values.get("modulation_spec") - if modulation is None or modulation.conductivity is None: - return val - - min_sigma = np.min(_get_numpy_array(val)) - if not values.get("allow_gain") and min_sigma - modulation.conductivity.max_modulation < 0: - raise ValidationError( - "For passive medium, 'conductivity' must be non-negative at any time." - "With conductivity modulation, this medium can sometimes be active. " - "Please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - return val - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - - For dispersiveless medium, it equals ``sqrt(permittivity)``. - """ - permittivity = self.permittivity - if self.modulation_spec is not None and self.modulation_spec.permittivity is not None: - permittivity -= self.modulation_spec.permittivity.max_modulation - n, _ = self.eps_complex_to_nk(permittivity) - return n - - @staticmethod - def _eps_model(permittivity: float, conductivity: float, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - - return AbstractMedium.eps_sigma_to_eps_complex(permittivity, conductivity, frequency) - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - - return self._eps_model(self.permittivity, self.conductivity, frequency) - - @classmethod - def from_nk(cls, n: float, k: float, freq: float, **kwargs: Any): - """Convert ``n`` and ``k`` values at frequency ``freq`` to :class:`.Medium`. - - Parameters - ---------- - n : float - Real part of refractive index. - k : float = 0 - Imaginary part of refrative index. - freq : float - Frequency to evaluate permittivity at (Hz). - kwargs: dict - Keyword arguments passed to the medium construction. - - Returns - ------- - :class:`.Medium` - medium containing the corresponding ``permittivity`` and ``conductivity``. - """ - eps, sigma = AbstractMedium.nk_to_eps_sigma(n, k, freq) - if eps < 1: - raise ValidationError( - "Dispersiveless medium must have 'permittivity>=1`. " - "Please use 'Lorentz.from_nk()' to covert to a Lorentz medium, or the utility " - "function 'td.medium_from_nk()' to automatically return the proper medium type." 
-
-    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
-        """Compute the adjoint derivatives for this object."""
-
-        # get vjps w.r.t. permittivity and conductivity of the bulk
-        vjps_volume = self._derivative_eps_sigma_volume(
-            E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds
-        )
-
-        # store the fields asked for by ``field_paths``
-        derivative_map = {}
-        for field_path in derivative_info.paths:
-            field_name, *_ = field_path
-            if field_name in vjps_volume:
-                derivative_map[field_path] = vjps_volume[field_name]
-
-        return derivative_map
-
-    def _derivative_eps_sigma_volume(
-        self, E_der_map: ElectromagneticFieldDataset, bounds: Bound
-    ) -> dict[str, xr.DataArray]:
-        """Get the derivative w.r.t permittivity and conductivity in the volume."""
-
-        vjp_eps_complex = self._derivative_eps_complex_volume(E_der_map=E_der_map, bounds=bounds)
-
-        values = vjp_eps_complex.values
-
-        # vjp of eps_complex_to_eps_sigma
-        omegas = 2 * np.pi * vjp_eps_complex.coords["f"].values
-        eps_vjp = np.real(values)
-        sigma_vjp = -np.imag(values) / omegas / EPSILON_0
-
-        eps_vjp = np.sum(eps_vjp)
-        sigma_vjp = np.sum(sigma_vjp)
-
-        return {"permittivity": eps_vjp, "conductivity": sigma_vjp}
-
-    def _derivative_eps_complex_volume(
-        self, E_der_map: ElectromagneticFieldDataset, bounds: Bound
-    ) -> xr.DataArray:
-        """Get the derivative w.r.t complex-valued permittivity in the volume."""
-
-        vjp_value = None
-        for field_name in ("Ex", "Ey", "Ez"):
-            fld = E_der_map[field_name]
-            vjp_value_fld = integrate_within_bounds(
-                arr=fld,
-                dims=("x", "y", "z"),
-                bounds=bounds,
-            )
-            if vjp_value is None:
-                vjp_value = vjp_value_fld
-            else:
-                vjp_value += vjp_value_fld
-
-        return vjp_value
-
-
-class CustomIsotropicMedium(AbstractCustomMedium, Medium):
-    """:class:`.Medium` with user-supplied permittivity distribution.
-    (This class is for internal use in v2.0; it will be renamed as `CustomMedium` in v3.0.)
-
-    Example
-    -------
-    >>> Nx, Ny, Nz = 10, 9, 8
-    >>> X = np.linspace(-1, 1, Nx)
-    >>> Y = np.linspace(-1, 1, Ny)
-    >>> Z = np.linspace(-1, 1, Nz)
-    >>> coords = dict(x=X, y=Y, z=Z)
-    >>> permittivity = SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords)
-    >>> conductivity = SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords)
-    >>> dielectric = CustomIsotropicMedium(permittivity=permittivity, conductivity=conductivity)
-    >>> eps = dielectric.eps_model(200e12)
-    """
-
-    permittivity: CustomSpatialDataTypeAnnotated = pd.Field(
-        ...,
-        title="Permittivity",
-        description="Relative permittivity.",
-        units=PERMITTIVITY,
-    )
-
-    conductivity: Optional[CustomSpatialDataTypeAnnotated] = pd.Field(
-        None,
-        title="Conductivity",
-        description="Electric conductivity. 
Defined such that the imaginary part of the complex " - "permittivity at angular frequency omega is given by conductivity/omega.", - units=CONDUCTIVITY, - ) - - _no_nans_eps = validate_no_nans("permittivity") - _no_nans_sigma = validate_no_nans("conductivity") - - @pd.validator("permittivity", always=True) - def _eps_inf_greater_no_less_than_one(cls, val): - """Assert any eps_inf must be >=1""" - - if not CustomIsotropicMedium._validate_isreal_dataarray(val): - raise SetupError("'permittivity' must be real.") - - if np.any(_get_numpy_array(val) < 1): - raise SetupError("'permittivity' must be no less than one.") - - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def _conductivity_real_and_correct_shape(cls, val, values): - """Assert conductivity is real and of right shape.""" - - if val is None: - return val - - if not CustomIsotropicMedium._validate_isreal_dataarray(val): - raise SetupError("'conductivity' must be real.") - - if not _check_same_coordinates(values["permittivity"], val): - raise SetupError("'permittivity' and 'conductivity' must have the same coordinates.") - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if ``allow_gain`` is False.""" - if val is None: - return val - if not values.get("allow_gain") and np.any(_get_numpy_array(val) < 0): - raise ValidationError( - "For passive medium, 'conductivity' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, and are likely to diverge." - ) - return val - - @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - if self.conductivity is None: - return self.permittivity.is_uniform - return self.permittivity.is_uniform and self.conductivity.is_uniform - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - - For dispersiveless medium, it equals ``sqrt(permittivity)``. - """ - permittivity = np.min(_get_numpy_array(self.permittivity)) - if self.modulation_spec is not None and self.modulation_spec.permittivity is not None: - permittivity -= self.modulation_spec.permittivity.max_modulation - n, _ = self.eps_complex_to_nk(permittivity) - return n - - @cached_property - def is_isotropic(self): - """Whether the medium is isotropic.""" - return True - - def eps_dataarray_freq( - self, frequency: float - ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: - """Permittivity array at ``frequency``. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[ - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset` - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset` - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset` - ], - ] - The permittivity evaluated at ``frequency``. 
- """ - conductivity = self.conductivity - if conductivity is None: - conductivity = _zeros_like(self.permittivity) - eps = self.eps_sigma_to_eps_complex(self.permittivity, conductivity, frequency) - return (eps, eps, eps) - - def _sel_custom_data_inside(self, bounds: Bound): - """Return a new custom medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - CustomMedium - CustomMedium with reduced data. - """ - if not self.permittivity.does_cover(bounds=bounds): - log.warning( - "Permittivity spatial data array does not fully cover the requested region." - ) - perm_reduced = self.permittivity.sel_inside(bounds=bounds) - cond_reduced = None - if self.conductivity is not None: - if not self.conductivity.does_cover(bounds=bounds): - log.warning( - "Conductivity spatial data array does not fully cover the requested region." - ) - cond_reduced = self.conductivity.sel_inside(bounds=bounds) - - return self.updated_copy( - permittivity=perm_reduced, - conductivity=cond_reduced, - ) - - -class CustomMedium(AbstractCustomMedium): - """:class:`.Medium` with user-supplied permittivity distribution. - - Example - ------- - >>> Nx, Ny, Nz = 10, 9, 8 - >>> X = np.linspace(-1, 1, Nx) - >>> Y = np.linspace(-1, 1, Ny) - >>> Z = np.linspace(-1, 1, Nz) - >>> coords = dict(x=X, y=Y, z=Z) - >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) - >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) - >>> dielectric = CustomMedium(permittivity=permittivity, conductivity=conductivity) - >>> eps = dielectric.eps_model(200e12) - """ - - eps_dataset: Optional[PermittivityDataset] = pd.Field( - None, - title="Permittivity Dataset", - description="[To be deprecated] User-supplied dataset containing complex-valued " - "permittivity as a function of space. Permittivity distribution over the Yee-grid " - "will be interpolated based on ``interp_method``.", - ) - - permittivity: Optional[CustomSpatialDataTypeAnnotated] = pd.Field( - None, - title="Permittivity", - description="Spatial profile of relative permittivity.", - units=PERMITTIVITY, - ) - - conductivity: Optional[CustomSpatialDataTypeAnnotated] = pd.Field( - None, - title="Conductivity", - description="Spatial profile Electric conductivity. Defined such " - "that the imaginary part of the complex permittivity at angular " - "frequency omega is given by conductivity/omega.", - units=CONDUCTIVITY, - ) - - _no_nans_eps_dataset = validate_no_nans("eps_dataset") - _no_nans_permittivity = validate_no_nans("permittivity") - _no_nans_sigma = validate_no_nans("conductivity") - - @pd.root_validator(pre=True) - def _warn_if_none(cls, values): - """Warn if the data array fails to load, and return a vacuum medium.""" - eps_dataset = values.get("eps_dataset") - permittivity = values.get("permittivity") - conductivity = values.get("conductivity") - fail_load = False - if cls._not_loaded(permittivity): - log.warning( - "Loading 'permittivity' without data; constructing a vacuum medium instead." - ) - fail_load = True - if cls._not_loaded(conductivity): - log.warning( - "Loading 'conductivity' without data; constructing a vacuum medium instead." 
- ) - fail_load = True - if isinstance(eps_dataset, dict): - if any((v in DATA_ARRAY_MAP for _, v in eps_dataset.items() if isinstance(v, str))): - log.warning( - "Loading 'eps_dataset' without data; constructing a vacuum medium instead." - ) - fail_load = True - if fail_load: - eps_real = SpatialDataArray(np.ones((1, 1, 1)), coords={"x": [0], "y": [0], "z": [0]}) - return {"permittivity": eps_real} - return values - - @pd.root_validator(pre=True) - def _deprecation_dataset(cls, values): - """Raise deprecation warning if dataset supplied and convert to dataset.""" - - eps_dataset = values.get("eps_dataset") - permittivity = values.get("permittivity") - conductivity = values.get("conductivity") - - # Incomplete custom medium definition. - if eps_dataset is None and permittivity is None and conductivity is None: - raise SetupError("Missing spatial profiles of 'permittivity' or 'eps_dataset'.") - if eps_dataset is None and permittivity is None: - raise SetupError("Missing spatial profiles of 'permittivity'.") - - # Definition racing - if eps_dataset is not None and (permittivity is not None or conductivity is not None): - raise SetupError( - "Please either define 'permittivity' and 'conductivity', or 'eps_dataset', " - "but not both simultaneously." - ) - - if eps_dataset is None: - return values - - # TODO: sometime before 3.0, uncomment these lines to warn users to start using new API - # if isinstance(eps_dataset, dict): - # eps_components = [eps_dataset[f"eps_{dim}{dim}"] for dim in "xyz"] - # else: - # eps_components = [eps_dataset.eps_xx, eps_dataset.eps_yy, eps_dataset.eps_zz] - - # is_isotropic = eps_components[0] == eps_components[1] == eps_components[2] - - # if is_isotropic: - # # deprecation warning for isotropic custom medium - # log.warning( - # "For spatially varying isotropic medium, the 'eps_dataset' field " - # "is being replaced by 'permittivity' and 'conductivity' in v3.0. " - # "We recommend you change your scripts to be compatible with the new API." - # ) - # else: - # # deprecation warning for anisotropic custom medium - # log.warning( - # "For spatially varying anisotropic medium, this class is being replaced " - # "by 'CustomAnisotropicMedium' in v3.0. " - # "We recommend you change your scripts to be compatible with the new API." - # ) - - return values - - @pd.validator("eps_dataset", always=True) - def _eps_dataset_single_frequency(cls, val): - """Assert only one frequency supplied.""" - if val is None: - return val - - for name, eps_dataset_component in val.field_components.items(): - freqs = eps_dataset_component.f - if len(freqs) != 1: - raise SetupError( - f"'eps_dataset.{name}' must have a single frequency, " - f"but it contains {len(freqs)} frequencies." - ) - return val - - @pd.validator("eps_dataset", always=True) - @skip_if_fields_missing(["modulation_spec", "allow_gain"]) - def _eps_dataset_eps_inf_greater_no_less_than_one_sigma_positive(cls, val, values): - """Assert any eps_inf must be >=1""" - if val is None: - return val - modulation = values.get("modulation_spec") - - for comp in ["eps_xx", "eps_yy", "eps_zz"]: - eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma( - val.field_components[comp], val.field_components[comp].f - ) - if np.any(_get_numpy_array(eps_real) < 1): - raise SetupError( - "Permittivity at infinite frequency at any spatial point " - "must be no less than one." 
- ) - - if modulation is not None and modulation.permittivity is not None: - if np.any(_get_numpy_array(eps_real) - modulation.permittivity.max_modulation <= 0): - raise ValidationError( - "The minimum permittivity value with modulation applied " - "was found to be negative." - ) - - if not values.get("allow_gain") and np.any(_get_numpy_array(sigma) < 0): - raise ValidationError( - "For passive medium, imaginary part of permittivity must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - - if ( - not values.get("allow_gain") - and modulation is not None - and modulation.conductivity is not None - and np.any(_get_numpy_array(sigma) - modulation.conductivity.max_modulation <= 0) - ): - raise ValidationError( - "For passive medium, imaginary part of permittivity must be non-negative " - "at any time. " - "With conductivity modulation, this medium can sometimes be active. " - "Please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - return val - - @pd.validator("permittivity", always=True) - @skip_if_fields_missing(["modulation_spec"]) - def _eps_inf_greater_no_less_than_one(cls, val, values): - """Assert any eps_inf must be >=1""" - if val is None: - return val - - if not CustomMedium._validate_isreal_dataarray(val): - raise SetupError("'permittivity' must be real.") - - if np.any(_get_numpy_array(val) < 1): - raise SetupError("'permittivity' must be no less than one.") - - modulation = values.get("modulation_spec") - if modulation is None or modulation.permittivity is None: - return val - - if np.any(_get_numpy_array(val) - modulation.permittivity.max_modulation <= 0): - raise ValidationError( - "The minimum permittivity value with modulation applied was found to be negative." - ) - - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity", "allow_gain"]) - def _conductivity_non_negative_correct_shape(cls, val, values): - """Assert conductivity>=0""" - - if val is None: - return val - - if not CustomMedium._validate_isreal_dataarray(val): - raise SetupError("'conductivity' must be real.") - - if not values.get("allow_gain") and np.any(_get_numpy_array(val) < 0): - raise ValidationError( - "For passive medium, 'conductivity' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - - if not _check_same_coordinates(values["permittivity"], val): - raise SetupError("'permittivity' and 'conductivity' must have the same coordinates.") - - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["eps_dataset", "modulation_spec", "allow_gain"]) - def _passivity_modulation_validation(cls, val, values): - """Assert passive medium at any time during modulation if ``allow_gain`` is False.""" - - # validated already when the data is supplied through `eps_dataset` - if values.get("eps_dataset"): - return val - - # permittivity defined with ``permittivity`` and ``conductivity`` - modulation = values.get("modulation_spec") - if values.get("allow_gain") or modulation is None or modulation.conductivity is None: - return val - if val is None or np.any( - _get_numpy_array(val) - modulation.conductivity.max_modulation < 0 - ): - raise ValidationError( - "For passive medium, 'conductivity' must be non-negative at any time. 
" - "With conductivity modulation, this medium can sometimes be active. " - "Please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - return val - - @pd.validator("permittivity", "conductivity", always=True) - def _check_permittivity_conductivity_interpolate(cls, val, values, field): - """Check that the custom medium 'SpatialDataArrays' can be interpolated.""" - - if isinstance(val, SpatialDataArray): - val._interp_validator(field.name) - - return val - - @cached_property - def is_isotropic(self) -> bool: - """Check if the medium is isotropic or anisotropic.""" - if self.eps_dataset is None: - return True - if self.eps_dataset.eps_xx == self.eps_dataset.eps_yy == self.eps_dataset.eps_zz: - return True - return False - - @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - return self._medium.is_spatially_uniform - - @cached_property - def freqs(self) -> np.ndarray: - """float array of frequencies. - This field is to be deprecated in v3.0. - """ - # return dummy values in this case - if self.eps_dataset is None: - return np.array([0, 0, 0]) - return np.array( - [ - self.eps_dataset.eps_xx.coords["f"], - self.eps_dataset.eps_yy.coords["f"], - self.eps_dataset.eps_zz.coords["f"], - ] - ) - - @cached_property - def _medium(self): - """Internal representation in the form of - either `CustomIsotropicMedium` or `CustomAnisotropicMedium`. - """ - self_dict = self.dict(exclude={"type", "eps_dataset"}) - # isotropic - if self.eps_dataset is None: - self_dict.update({"permittivity": self.permittivity, "conductivity": self.conductivity}) - return CustomIsotropicMedium.parse_obj(self_dict) - - def get_eps_sigma(eps_complex: SpatialDataArray, freq: float) -> tuple: - """Convert a complex permittivity to real permittivity and conductivity.""" - eps_values = np.array(eps_complex.values) - - eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma(eps_values, freq) - coords = eps_complex.coords - - eps_real = ScalarFieldDataArray(eps_real, coords=coords) - sigma = ScalarFieldDataArray(sigma, coords=coords) - - eps_real = SpatialDataArray(eps_real.squeeze(dim="f", drop=True)) - sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True)) - - return eps_real, sigma - - # isotropic, but with `eps_dataset` - if self.is_isotropic: - eps_complex = self.eps_dataset.eps_xx - eps_real, sigma = get_eps_sigma(eps_complex, freq=self.freqs[0]) - - self_dict.update({"permittivity": eps_real, "conductivity": sigma}) - return CustomIsotropicMedium.parse_obj(self_dict) - - # anisotropic - mat_comp = {"interp_method": self.interp_method} - for freq, comp in zip(self.freqs, ["xx", "yy", "zz"]): - eps_complex = self.eps_dataset.field_components["eps_" + comp] - eps_real, sigma = get_eps_sigma(eps_complex, freq=freq) - - comp_dict = self_dict.copy() - comp_dict.update({"permittivity": eps_real, "conductivity": sigma}) - mat_comp.update({comp: CustomIsotropicMedium.parse_obj(comp_dict)}) - return CustomAnisotropicMediumInternal(**mat_comp) - - def _interp_method(self, comp: Axis) -> InterpMethod: - """Interpolation method applied to comp.""" - return self._medium._interp_method(comp) - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl```. 
-    def _interp_method(self, comp: Axis) -> InterpMethod:
-        """Interpolation method applied to comp."""
-        return self._medium._interp_method(comp)
-
-    @cached_property
-    def n_cfl(self):
-        """This property computes the index of refraction related to CFL condition, so that
-        the FDTD with this medium is stable when the time step size that doesn't take
-        material factor into account is multiplied by ``n_cfl``.
-
-        For dispersionless custom medium, it equals ``min[sqrt(eps_inf)]``, where ``min``
-        is performed over all components and spatial points.
-        """
-        return self._medium.n_cfl
-
-    def eps_dataarray_freq(
-        self, frequency: float
-    ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]:
-        """Permittivity array at ``frequency``.
-
-        Parameters
-        ----------
-        frequency : float
-            Frequency to evaluate permittivity at (Hz).
-
-        Returns
-        -------
-        Tuple[
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-        ]
-            The permittivity evaluated at ``frequency``.
-        """
-        return self._medium.eps_dataarray_freq(frequency)
-
-    def eps_diagonal_on_grid(
-        self,
-        frequency: float,
-        coords: Coords,
-    ) -> tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]:
-        """Spatial profile of main diagonal of the complex-valued permittivity
-        at ``frequency`` interpolated at the supplied coordinates.
-
-        Parameters
-        ----------
-        frequency : float
-            Frequency to evaluate permittivity at (Hz).
-        coords : :class:`.Coords`
-            The grid point coordinates over which interpolation is performed.
-
-        Returns
-        -------
-        Tuple[ArrayComplex3D, ArrayComplex3D, ArrayComplex3D]
-            The complex-valued permittivity tensor at ``frequency`` interpolated
-            at the supplied coordinate.
-        """
-        return self._medium.eps_diagonal_on_grid(frequency, coords)
-
-    @ensure_freq_in_range
-    def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]:
-        """Main diagonal of the complex-valued permittivity tensor
-        at ``frequency``. Spatially, we take max{|eps|}, so that autoMesh generation
-        works appropriately.
-        """
-        return self._medium.eps_diagonal(frequency)
-
-    @ensure_freq_in_range
-    def eps_model(self, frequency: float) -> complex:
-        """Spatial and polarization average of complex-valued permittivity
-        as a function of frequency.
-        """
-        return self._medium.eps_model(frequency)
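
# Usage sketch, reusing `dielectric` from the class docstring example above;
# `td.Coords` is assumed to be the tidy3d grid-coordinate container:
import numpy as np
import tidy3d as td

pts = td.Coords(x=np.linspace(-1, 1, 3), y=np.linspace(-1, 1, 3), z=np.linspace(-1, 1, 3))
eps_xx, eps_yy, eps_zz = dielectric.eps_diagonal_on_grid(frequency=200e12, coords=pts)
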
- """ - if isinstance(eps, CustomSpatialDataType.__args__): - # purely real, not need to know `freq` - if CustomMedium._validate_isreal_dataarray(eps): - return cls(permittivity=eps, interp_method=interp_method, **kwargs) - # complex permittivity, needs to know `freq` - if freq is None: - raise SetupError( - "For a complex 'eps', 'freq' at which 'eps' is defined must be supplied", - ) - eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma(eps, freq) - return cls( - permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs - ) - - # eps is ScalarFieldDataArray - # contradictory definition of frequency - freq_data = eps.coords["f"].data[0] - if freq is not None and not isclose(freq, freq_data): - raise SetupError( - "'freq' value is inconsistent with the coordinate 'f'" - "in 'eps' DataArray. It's unclear at which frequency 'eps' " - "is defined. Please leave 'freq=None' to use the frequency " - "value in the DataArray." - ) - eps_real, sigma = CustomMedium.eps_complex_to_eps_sigma(eps, freq_data) - eps_real = SpatialDataArray(eps_real.squeeze(dim="f", drop=True)) - sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True)) - return cls(permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs) - - @classmethod - def from_nk( - cls, - n: Union[ScalarFieldDataArray, CustomSpatialDataType], - k: Optional[Union[ScalarFieldDataArray, CustomSpatialDataType]] = None, - freq: Optional[float] = None, - interp_method: InterpMethod = "nearest", - **kwargs: Any, - ) -> CustomMedium: - """Construct a :class:`.CustomMedium` from datasets containing n and k values. - - Parameters - ---------- - n : Union[ - :class:`.SpatialDataArray`, - :class:`.ScalarFieldDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ] - Real part of refractive index. - k : Union[ - :class:`.SpatialDataArray`, - :class:`.ScalarFieldDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], optional - Imaginary part of refrative index for lossy medium. - freq : float, optional - Frequency at which ``n`` and ``k`` are defined. - interp_method : :class:`.InterpMethod`, optional - Interpolation method to obtain permittivity values that are not supplied - at the Yee grids. - kwargs: dict - Keyword arguments passed to the medium construction. - - Note - ---- - For lossy medium, if both ``n`` and ``k`` are supplied through - :class:`.SpatialDataArray`, which doesn't contain frequency information, - the ``freq`` kwarg will be used to evaluate the permittivity and conductivity. - Alternatively, ``n`` and ``k`` can be supplied through :class:`.ScalarFieldDataArray`, - which contains a frequency coordinate. - In this case, leave ``freq`` kwarg as the default of ``None``. - - Returns - ------- - :class:`.CustomMedium` - Medium containing the spatially varying permittivity data. 
- """ - # lossless - if k is None: - if isinstance(n, ScalarFieldDataArray): - n = SpatialDataArray(n.squeeze(dim="f", drop=True)) - freq = 0 # dummy value - eps_real, _ = CustomMedium.nk_to_eps_sigma(n, 0 * n, freq) - return cls(permittivity=eps_real, interp_method=interp_method, **kwargs) - - # lossy case - if not _check_same_coordinates(n, k): - raise SetupError("'n' and 'k' must be of the same type and must have same coordinates.") - - # k is a SpatialDataArray - if isinstance(k, CustomSpatialDataType.__args__): - if freq is None: - raise SetupError( - "For a lossy medium, must supply 'freq' at which to convert 'n' " - "and 'k' to a complex valued permittivity." - ) - eps_real, sigma = CustomMedium.nk_to_eps_sigma(n, k, freq) - return cls( - permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs - ) - - # k is a ScalarFieldDataArray - freq_data = k.coords["f"].data[0] - if freq is not None and not isclose(freq, freq_data): - raise SetupError( - "'freq' value is inconsistent with the coordinate 'f'" - "in 'k' DataArray. It's unclear at which frequency 'k' " - "is defined. Please leave 'freq=None' to use the frequency " - "value in the DataArray." - ) - - eps_real, sigma = CustomMedium.nk_to_eps_sigma(n, k, freq_data) - eps_real = SpatialDataArray(eps_real.squeeze(dim="f", drop=True)) - sigma = SpatialDataArray(sigma.squeeze(dim="f", drop=True)) - return cls(permittivity=eps_real, conductivity=sigma, interp_method=interp_method, **kwargs) - - def grids(self, bounds: Bound) -> dict[str, Grid]: - """Make a :class:`.Grid` corresponding to the data in each ``eps_ii`` component. - The min and max coordinates along each dimension are bounded by ``bounds``.""" - - rmin, rmax = bounds - pt_mins = dict(zip("xyz", rmin)) - pt_maxs = dict(zip("xyz", rmax)) - - def make_grid(scalar_field: Union[ScalarFieldDataArray, SpatialDataArray]) -> Grid: - """Make a grid for a single dataset.""" - - def make_bound_coords(coords: np.ndarray, pt_min: float, pt_max: float) -> list[float]: - """Convert user supplied coords into boundary coords to use in :class:`.Grid`.""" - - # get coordinates of the bondaries halfway between user-supplied data - coord_bounds = (coords[1:] + coords[:-1]) / 2.0 - - # res-set coord boundaries that lie outside geometry bounds to the boundary (0 vol.) - coord_bounds[coord_bounds <= pt_min] = pt_min - coord_bounds[coord_bounds >= pt_max] = pt_max - - # add the geometry bounds in explicitly - return [pt_min, *coord_bounds.tolist(), pt_max] - - # grab user supplied data long this dimension - coords = {key: np.array(val) for key, val in scalar_field.coords.items()} - spatial_coords = {key: coords[key] for key in "xyz"} - - # convert each spatial coord to boundary coords - bound_coords = {} - for key, coords in spatial_coords.items(): - pt_min = pt_mins[key] - pt_max = pt_maxs[key] - bound_coords[key] = make_bound_coords(coords=coords, pt_min=pt_min, pt_max=pt_max) - - # construct grid - boundaries = Coords(**bound_coords) - return Grid(boundaries=boundaries) - - grids = {} - for field_name in ("eps_xx", "eps_yy", "eps_zz"): - # grab user supplied data long this dimension - scalar_field = self.eps_dataset.field_components[field_name] - - # feed it to make_grid - grids[field_name] = make_grid(scalar_field) - - return grids - - def _sel_custom_data_inside(self, bounds: Bound): - """Return a new custom medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. 
-
-    def _sel_custom_data_inside(self, bounds: Bound):
-        """Return a new custom medium that contains the minimal amount of data necessary to cover
-        a spatial region defined by ``bounds``.
-
-        Parameters
-        ----------
-        bounds : Tuple[float, float, float], Tuple[float, float, float]
-            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
-
-        Returns
-        -------
-        CustomMedium
-            CustomMedium with reduced data.
-        """
-
-        perm_reduced = None
-        if self.permittivity is not None:
-            if not self.permittivity.does_cover(bounds=bounds):
-                log.warning(
-                    "Permittivity spatial data array does not fully cover the requested region."
-                )
-            perm_reduced = self.permittivity.sel_inside(bounds=bounds)
-
-        cond_reduced = None
-        if self.conductivity is not None:
-            if not self.conductivity.does_cover(bounds=bounds):
-                log.warning(
-                    "Conductivity spatial data array does not fully cover the requested region."
-                )
-            cond_reduced = self.conductivity.sel_inside(bounds=bounds)
-
-        eps_reduced = None
-        if self.eps_dataset is not None:
-            eps_reduced_dict = {}
-            for key, comp in self.eps_dataset.field_components.items():
-                if not comp.does_cover(bounds=bounds):
-                    log.warning(
-                        f"{key} spatial data array does not fully cover the requested region."
-                    )
-                eps_reduced_dict[key] = comp.sel_inside(bounds=bounds)
-            eps_reduced = PermittivityDataset(**eps_reduced_dict)
-
-        return self.updated_copy(
-            permittivity=perm_reduced,
-            conductivity=cond_reduced,
-            eps_dataset=eps_reduced,
-        )
-
-    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
-        """Compute the adjoint derivatives for this object."""
-
-        vjps = {}
-
-        for field_path in derivative_info.paths:
-            if field_path[0] == "permittivity":
-                vjp_array = 0.0
-                for dim in "xyz":
-                    vjp_array += self._derivative_field_cmp(
-                        E_der_map=derivative_info.E_der_map,
-                        spatial_data=self.permittivity,
-                        dim=dim,
-                        freqs=derivative_info.frequencies,
-                        component="real",
-                    )
-                vjps[field_path] = vjp_array
-
-            elif field_path[0] == "conductivity":
-                vjp_array = 0.0
-                for dim in "xyz":
-                    vjp_array += self._derivative_field_cmp(
-                        E_der_map=derivative_info.E_der_map,
-                        spatial_data=self.conductivity,
-                        dim=dim,
-                        freqs=derivative_info.frequencies,
-                        component="sigma",
-                    )
-                vjps[field_path] = vjp_array
-
-            elif field_path[0] == "eps_dataset":
-                key = field_path[1]
-                dim = key[-1]
-                vjps[field_path] = self._derivative_field_cmp(
-                    E_der_map=derivative_info.E_der_map,
-                    spatial_data=self.eps_dataset.field_components[key],
-                    dim=dim,
-                    freqs=derivative_info.frequencies,
-                    component="complex",
-                )
-            else:
-                raise NotImplementedError(
-                    f"No derivative defined for 'CustomMedium' field: {field_path}."
- ) - - return vjps - - def _derivative_field_cmp( - self, - E_der_map: ElectromagneticFieldDataset, - spatial_data: CustomSpatialDataTypeAnnotated, - dim: str, - freqs: np.ndarray, - component: str = "real", - ) -> np.ndarray: - """Compute the derivative with respect to a material property component.""" - coords_interp = {key: spatial_data.coords[key] for key in "xyz"} - coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} - - eps_coordinate_shape = [ - len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz" - ] - - E_der_dim_interp = E_der_map[f"E{dim}"] - - for dim_ in "xyz": - if dim_ not in coords_interp: - bound_max = np.max(E_der_dim_interp.coords[dim_]) - bound_min = np.min(E_der_dim_interp.coords[dim_]) - dimension_size = bound_max - bound_min - - if dimension_size > 0.0: - E_der_dim_interp = E_der_dim_interp.integrate(dim_) - - # compute sizes along each of the interpolation dimensions - sizes_list = [] - for _, coords in coords_interp.items(): - num_coords = len(coords) - coords = np.array(coords) - - # compute distances between midpoints for all internal coords - mid_points = (coords[1:] + coords[:-1]) / 2.0 - dists = np.diff(mid_points) - sizes = np.zeros(num_coords) - sizes[1:-1] = dists - - # estimate the sizes on the edges using 2 x the midpoint distance - sizes[0] = 2 * abs(mid_points[0] - coords[0]) - sizes[-1] = 2 * abs(coords[-1] - mid_points[-1]) - - sizes_list.append(sizes) - - # turn this into a volume element, should be re-sizeable to the gradient shape - if sizes_list: - d_vol = functools.reduce(np.outer, sizes_list) - else: - # if sizes_list is empty, then reduce() fails - d_vol = np.array(1.0) - - E_der_dim_interp_complex = E_der_dim_interp.interp( - **coords_interp, assume_sorted=True - ).fillna(0.0) - - if component == "sigma": - # compute conductivity gradient from imaginary-permittivity gradient - # apply per-frequency scaling before summing over frequencies - # d eps_imag / d sigma = 1 / (2 * pi * f * EPSILON_0) - E_der_dim_interp = E_der_dim_interp_complex.imag - freqs_da = E_der_dim_interp_complex.coords["f"] - scale = -1.0 / (2.0 * np.pi * freqs_da * EPSILON_0) - E_der_dim_interp *= scale - elif component == "complex": - # for complex permittivity in eps_dataset, return the full complex derivative - E_der_dim_interp = E_der_dim_interp_complex - elif component == "imag": - # pure imaginary component (no conductivity conversion) - E_der_dim_interp = E_der_dim_interp_complex.imag - else: - E_der_dim_interp = E_der_dim_interp_complex.real - - E_der_dim_interp = E_der_dim_interp.sum("f") - - try: - E_der_dim_interp = E_der_dim_interp * d_vol.reshape(E_der_dim_interp.shape) - except ValueError: - log.warning( - "Skipping volume element normalization of 'CustomMedium' gradients. " - f"Could not reshape the volume elements of shape {d_vol.shape} " - f"to the shape of the fields {E_der_dim_interp.shape}. " - "If you encounter this warning, gradient direction will be accurate but the norm " - "will be inaccurate. Please raise an issue on the tidy3d front end with this " - "message and some information about your simulation setup and we will investigate. " - ) - vjp_array = E_der_dim_interp.values - vjp_array = vjp_array.reshape(eps_coordinate_shape) - - return vjp_array - - -""" Dispersive Media """ - - -class DispersiveMedium(AbstractMedium, ABC): - """ - A Medium with dispersion: field propagation characteristics depend on frequency. 
-
-    Notes
-    -----
-
-        In dispersive mediums, the displacement field :math:`D(t)` depends on the electric field
-        at earlier times :math:`E(t')` through the time-dependent permittivity :math:`\\epsilon`:
-
-        .. math::
-
-            D(t) = \\int \\epsilon(t - t') E(t') \\, dt'
-
-        Dispersive mediums can be defined in three ways:
-
-        - Imported from our `material library <../material_library.html>`_.
-        - Defined directly by specifying the parameters in the `various supplied dispersive models <../mediums.html>`_.
-        - Fitted to optical n-k data using the `dispersion fitting tool plugin <../plugins/dispersion.html>`_.
-
-        It is important to keep in mind that dispersive materials are inevitably slower to simulate than their
-        dispersion-less counterparts, with complexity increasing with the number of poles included in the dispersion
-        model. For simulations with a narrow range of frequencies of interest, it may sometimes be faster to define
-        the material through its real and imaginary refractive index at the center frequency.
-
-
-    See Also
-    --------
-
-    :class:`CustomPoleResidue`:
-        A spatially varying dispersive medium described by the pole-residue pair model.
-
-    **Notebooks**
-        * `Fitting dispersive material models <../../notebooks/Fitting.html>`_
-
-    **Lectures**
-        * `Modeling dispersive material in FDTD <https://www.flexcompute.com/fdtd101/Lecture-5-Modeling-dispersive-material-in-FDTD/>`_
-    """
-
-    @staticmethod
-    def _permittivity_modulation_validation():
-        """Assert modulated permittivity cannot be <= 0 at any time."""
-
-        @pd.validator("eps_inf", allow_reuse=True, always=True)
-        @skip_if_fields_missing(["modulation_spec"])
-        def _validate_permittivity_modulation(cls, val, values):
-            """Assert modulated permittivity cannot be <= 0."""
-            modulation = values.get("modulation_spec")
-            if modulation is None or modulation.permittivity is None:
-                return val
-
-            min_eps_inf = np.min(_get_numpy_array(val))
-            if min_eps_inf - modulation.permittivity.max_modulation <= 0:
-                raise ValidationError(
-                    "The minimum permittivity value with modulation applied was found to be negative."
-                )
-            return val
-
-        return _validate_permittivity_modulation
-
-    @staticmethod
-    def _conductivity_modulation_validation():
-        """Assert passive medium at any time if not ``allow_gain``."""
-
-        @pd.validator("modulation_spec", allow_reuse=True, always=True)
-        @skip_if_fields_missing(["allow_gain"])
-        def _validate_conductivity_modulation(cls, val, values):
-            """With conductivity modulation, the medium can exhibit gain during the cycle.
-            So `allow_gain` must be True when the conductivity is modulated.
-            """
-            if val is None or val.conductivity is None:
-                return val
-
-            if not values.get("allow_gain"):
-                raise ValidationError(
-                    "For passive medium, 'conductivity' must be non-negative at any time. "
-                    "With conductivity modulation, this medium can sometimes be active. "
-                    "Please set 'allow_gain=True'. "
-                    "Caution: simulations with a gain medium are unstable, and are likely to diverge."
-                )
-            return val
-
-        return _validate_conductivity_modulation
-
-    @abstractmethod
-    def _pole_residue_dict(self) -> dict:
-        """Dict representation of Medium as a pole-residue model."""
-
-    @cached_property
-    def pole_residue(self):
-        """Representation of Medium as a pole-residue model."""
-        return PoleResidue(**self._pole_residue_dict(), allow_gain=self.allow_gain)
-
-    @cached_property
-    def n_cfl(self):
-        """This property computes the index of refraction related to CFL condition, so that
-        the FDTD with this medium is stable when the time step size that doesn't take
-        material factor into account is multiplied by ``n_cfl``.
-
-        For PoleResidue model, it equals ``sqrt(eps_inf)``
-        [https://ieeexplore.ieee.org/document/9082879].
-        """
-        permittivity = self.pole_residue.eps_inf
-        if self.modulation_spec is not None and self.modulation_spec.permittivity is not None:
-            permittivity -= self.modulation_spec.permittivity.max_modulation
-        n, _ = self.eps_complex_to_nk(permittivity)
-        return n
-
-    @staticmethod
-    def tuple_to_complex(value: tuple[float, float]) -> complex:
-        """Convert a tuple of real and imaginary parts to complex number."""
-
-        val_r, val_i = value
-        return val_r + 1j * val_i
-
-    @staticmethod
-    def complex_to_tuple(value: complex) -> tuple[float, float]:
-        """Convert a complex number to a tuple of real and imaginary parts."""
-
-        return (value.real, value.imag)
-
-    # --- shared autograd helpers for dispersive models ---
-    def _tjp_inputs(self, derivative_info):
-        """Prepare shared inputs for TJP: frequencies and packed adjoint vector."""
-        dJ = self._derivative_eps_complex_volume(
-            E_der_map=derivative_info.E_der_map, bounds=derivative_info.bounds
-        )
-        freqs = np.asarray(derivative_info.frequencies, float)
-        dJv = np.asarray(getattr(dJ, "values", dJ))
-        return freqs, pack_complex_vec(dJv)
-
-    @staticmethod
-    def _tjp_grad(theta0, eps_vec_fn, vec):
-        """Run a tensor-Jacobian-product to get J^T @ vec."""
-        return tensor_jacobian_product(eps_vec_fn)(theta0, vec)
-
-    @staticmethod
-    def _map_grad_real(g, paths, mapping):
-        """Map flat gradient to model paths, taking the real part."""
-        out = {}
-        for k, idx in mapping:
-            if k in paths:
-                out[k] = np.real(g[idx])
-        return out
-
-
-class CustomDispersiveMedium(AbstractCustomMedium, DispersiveMedium, ABC):
-    """A spatially varying dispersive medium."""
-
-    @cached_property
-    def n_cfl(self):
-        """This property computes the index of refraction related to CFL condition, so that
-        the FDTD with this medium is stable when the time step size that doesn't take
-        material factor into account is multiplied by ``n_cfl``.
-
-        For PoleResidue model, it equals ``sqrt(eps_inf)``
-        [https://ieeexplore.ieee.org/document/9082879].
-        """
-        permittivity = np.min(_get_numpy_array(self.pole_residue.eps_inf))
-        if self.modulation_spec is not None and self.modulation_spec.permittivity is not None:
-            permittivity -= self.modulation_spec.permittivity.max_modulation
-        n, _ = self.eps_complex_to_nk(permittivity)
-        return n
-
-    @cached_property
-    def is_isotropic(self):
-        """Whether the medium is isotropic."""
-        return True
-
-    @cached_property
-    def pole_residue(self):
-        """Representation of Medium as a pole-residue model."""
-        return CustomPoleResidue(
-            **self._pole_residue_dict(),
-            interp_method=self.interp_method,
-            allow_gain=self.allow_gain,
-            subpixel=self.subpixel,
-        )
-
-    @staticmethod
-    def _warn_if_data_none(nested_tuple_field: str):
-        """Warn if any of ``eps_inf`` and ``nested_tuple_field`` are not loaded,
-        and return a vacuum with eps_inf = 1.
-        """
-
-        @pd.root_validator(pre=True, allow_reuse=True)
-        def _warn_if_none(cls, values):
-            """Warn if any of ``eps_inf`` and ``nested_tuple_field`` are not loaded."""
-            eps_inf = values.get("eps_inf")
-            coeffs = values.get(nested_tuple_field)
-            fail_load = False
-
-            if AbstractCustomMedium._not_loaded(eps_inf):
-                log.warning("Loading 'eps_inf' without data; constructing a vacuum medium instead.")
-                fail_load = True
-            for coeff in coeffs:
-                if fail_load:
-                    break
-                for coeff_i in coeff:
-                    if AbstractCustomMedium._not_loaded(coeff_i):
-                        log.warning(
-                            f"Loading '{nested_tuple_field}' without data; "
-                            "constructing a vacuum medium instead."
-                        )
-                        fail_load = True
-                        break
-
-            if fail_load and eps_inf is None:
-                return {nested_tuple_field: ()}
-            if fail_load:
-                eps_inf = SpatialDataArray(
-                    np.ones((1, 1, 1)), coords={"x": [0], "y": [0], "z": [0]}
-                )
-                return {"eps_inf": eps_inf, nested_tuple_field: ()}
-            return values
-
-        return _warn_if_none
-
-    # --- helpers for custom dispersive adjoints ---
-    def _sum_complex_eps_sensitivity(
-        self,
-        derivative_info: DerivativeInfo,
-        spatial_ref: PermittivityDataset,
-    ) -> np.ndarray:
-        """Sum complex permittivity sensitivities over xyz on the given spatial grid.
-
-        Parameters
-        ----------
-        derivative_info : DerivativeInfo
-            Info bundle carrying field maps and frequencies.
-        spatial_ref : PermittivityDataset
-            Spatial dataset to define the grid/coords for interpolation and summation.
-
-        Returns
-        -------
-        np.ndarray
-            Complex-valued aggregated dJ array with the same spatial shape as ``spatial_ref``.
-        """
-        dJ = 0.0 + 0.0j
-        for dim in "xyz":
-            dJ += self._derivative_field_cmp(
-                E_der_map=derivative_info.E_der_map,
-                spatial_data=spatial_ref,
-                dim=dim,
-            )
-        return dJ
-
-    @staticmethod
-    def _accum_real_inner(dJ: np.ndarray, weight: np.ndarray) -> np.ndarray:
-        """Compute Re(dJ * conj(weight)) with proper broadcasting."""
-        return np.real(dJ * np.conj(weight))
-
-    def _sum_over_freqs(
-        self, freqs: list[float] | np.ndarray, dJ: np.ndarray, weight_fn
-    ) -> np.ndarray:
-        """Accumulate gradient contributions over frequencies using provided weight function.
-
-        Parameters
-        ----------
-        freqs : array-like
-            Frequencies to accumulate over.
-        dJ : np.ndarray
-            Complex dataset sensitivity with spatial shape.
-        weight_fn : Callable[[float], np.ndarray]
-            Function mapping frequency to weight array broadcastable to dJ.
-
-        Returns
-        -------
-        np.ndarray
-            Real-valued gradient array matching dJ's broadcasted shape.
-        """
-        g = 0.0
-        for f in freqs:
-            g = g + self._accum_real_inner(dJ, weight_fn(f))
-        return g
-
-
-class PoleResidue(DispersiveMedium):
-    """A dispersive medium described by the pole-residue pair model.
-
-    Notes
-    -----
-
-        The frequency-dependence of the complex-valued permittivity is described by:
-
-        .. math::
-
-            \\epsilon(\\omega) = \\epsilon_\\infty - \\sum_i
-            \\left[\\frac{c_i}{j \\omega + a_i} +
-            \\frac{c_i^*}{j \\omega + a_i^*}\\right]
-
-    Example
-    -------
-    >>> pole_res = PoleResidue(eps_inf=2.0, poles=[((-1+2j), (3+4j)), ((-5+6j), (7+8j))])
-    >>> eps = pole_res.eps_model(200e12)
-
-    See Also
-    --------
-
-    :class:`CustomPoleResidue`:
-        A spatially varying dispersive medium described by the pole-residue pair model.
-
-    **Notebooks**
-        * `Fitting dispersive material models <../../notebooks/Fitting.html>`_
-
-    **Lectures**
-        * `Modeling dispersive material in FDTD <https://www.flexcompute.com/fdtd101/Lecture-5-Modeling-dispersive-material-in-FDTD/>`_
-    """
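
# A standalone numeric sketch of the pole-residue sum in the docstring above
# (mirrors `_eps_model` below; pure numpy, no tidy3d dependency):
import numpy as np

def eps_pole_residue(eps_inf: float, poles, frequency: float) -> complex:
    """eps(w) = eps_inf - sum_i [c_i/(jw + a_i) + c_i*/(jw + a_i*)]."""
    omega = 2 * np.pi * frequency
    eps = eps_inf + 0j
    for a, c in poles:
        eps -= c / (1j * omega + a) + np.conj(c) / (1j * omega + np.conj(a))
    return eps

# e.g., with the docstring example's parameters:
print(eps_pole_residue(2.0, [(-1 + 2j, 3 + 4j), (-5 + 6j, 7 + 8j)], 200e12))
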
-    eps_inf: TracedPositiveFloat = pd.Field(
-        1.0,
-        title="Epsilon at Infinity",
-        description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).",
-        units=PERMITTIVITY,
-    )
-
-    poles: tuple[TracedPoleAndResidue, ...] = pd.Field(
-        (),
-        title="Poles",
-        description="Tuple of complex-valued (:math:`a_i, c_i`) poles for the model.",
-        units=(RADPERSEC, RADPERSEC),
-    )
-
-    @pd.validator("poles", always=True)
-    def _causality_validation(cls, val):
-        """Assert causal medium."""
-        for a, _ in val:
-            if np.any(np.real(_get_numpy_array(a)) > 0):
-                raise SetupError("For stable medium, 'Re(a_i)' must be non-positive.")
-        return val
-
-    @pd.validator("poles", always=True)
-    def _poles_largest_value(cls, val):
-        """Assert pole parameters are not too large."""
-        for a, c in val:
-            if np.any(abs(_get_numpy_array(a)) > LARGEST_FP_NUMBER):
-                raise ValidationError(
-                    "The value of some 'a_i' is too large. They are unlikely to contribute to material dispersion."
-                )
-            if np.any(abs(_get_numpy_array(c)) > LARGEST_FP_NUMBER):
-                raise ValidationError("The value of some 'c_i' is too large.")
-        return val
-
-    _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation()
-    _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation()
-
-    @staticmethod
-    def _eps_model(
-        eps_inf: pd.PositiveFloat, poles: tuple[PoleAndResidue, ...], frequency: float
-    ) -> complex:
-        """Complex-valued permittivity as a function of frequency."""
-
-        omega = 2 * np.pi * frequency
-        eps = eps_inf + 0 * frequency + 0.0j
-        for a, c in poles:
-            a_cc = np.conj(a)
-            c_cc = np.conj(c)
-            eps = eps - c / (1j * omega + a)
-            eps = eps - c_cc / (1j * omega + a_cc)
-        return eps
-
-    @ensure_freq_in_range
-    def eps_model(self, frequency: float) -> complex:
-        """Complex-valued permittivity as a function of frequency."""
-        return self._eps_model(eps_inf=self.eps_inf, poles=self.poles, frequency=frequency)
-
-    def _pole_residue_dict(self) -> dict:
-        """Dict representation of Medium as a pole-residue model."""
-
-        return {
-            "eps_inf": self.eps_inf,
-            "poles": self.poles,
-            "frequency_range": self.frequency_range,
-            "name": self.name,
-        }
-
-    def __str__(self):
-        """string representation"""
-        return (
-            f"td.PoleResidue("
-            f"\n\teps_inf={self.eps_inf}, "
-            f"\n\tpoles={self.poles}, "
-            f"\n\tfrequency_range={self.frequency_range})"
-        )
-
-    @classmethod
-    def from_medium(cls, medium: Medium) -> PoleResidue:
-        """Convert a :class:`.Medium` to a pole residue model.
-
-        Parameters
-        ----------
-        medium: :class:`.Medium`
-            The medium with permittivity and conductivity to convert.
-
-        Returns
-        -------
-        :class:`.PoleResidue`
-            The pole residue equivalent.
-        """
-        poles = [(0, medium.conductivity / (2 * EPSILON_0))]
-        return PoleResidue(
-            eps_inf=medium.permittivity, poles=poles, frequency_range=medium.frequency_range
-        )
-
-    def to_medium(self) -> Medium:
-        """Convert to a :class:`.Medium`.
-        Requires the pole residue model to only have a pole at 0 frequency,
-        corresponding to a constant conductivity term.
-
-        Returns
-        -------
-        :class:`.Medium`
-            The non-dispersive equivalent with constant permittivity and conductivity.
-        """
-        res = 0
-        for a, c in self.poles:
-            if abs(a) > fp_eps:
-                raise ValidationError("Cannot convert dispersive 'PoleResidue' to 'Medium'.")
-            res = res + (c + np.conj(c)) / 2
-        sigma = res * 2 * EPSILON_0
-        return Medium(
-            permittivity=self.eps_inf,
-            conductivity=np.real(sigma),
-            frequency_range=self.frequency_range,
-        )
-
-    @staticmethod
-    def lo_to_eps_model(
-        poles: tuple[tuple[float, float, float, float], ...],
-        eps_inf: pd.PositiveFloat,
-        frequency: float,
-    ) -> complex:
-        """Complex permittivity as a function of frequency for a given set of LO-TO coefficients.
- See ``from_lo_to`` in :class:`.PoleResidue` for the detailed form of the model - and a reference paper. - - Parameters - ---------- - poles : Tuple[Tuple[float, float, float, float], ...] - The LO-TO poles, given as list of tuples of the form - (omega_LO, gamma_LO, omega_TO, gamma_TO). - eps_inf: pd.PositiveFloat - The relative permittivity at infinite frequency. - frequency: float - Frequency at which to evaluate the permittivity. - - Returns - ------- - complex - The complex permittivity of the given LO-TO model at the given frequency. - """ - omega = 2 * np.pi * frequency - eps = eps_inf - for omega_lo, gamma_lo, omega_to, gamma_to in poles: - eps *= omega_lo**2 - omega**2 - 1j * omega * gamma_lo - eps /= omega_to**2 - omega**2 - 1j * omega * gamma_to - return eps - - @classmethod - def from_lo_to( - cls, poles: tuple[tuple[float, float, float, float], ...], eps_inf: pd.PositiveFloat = 1 - ) -> PoleResidue: - """Construct a pole residue model from the LO-TO form - (longitudinal and transverse optical modes). - The LO-TO form is :math:`\\epsilon_\\infty \\prod_{i=1}^l \\frac{\\omega_{LO, i}^2 - \\omega^2 - i \\omega \\gamma_{LO, i}}{\\omega_{TO, i}^2 - \\omega^2 - i \\omega \\gamma_{TO, i}}` as given in the paper: - - M. Schubert, T. E. Tiwald, and C. M. Herzinger, - "Infrared dielectric anisotropy and phonon modes of sapphire," - Phys. Rev. B 61, 8187 (2000). - - Parameters - ---------- - poles : Tuple[Tuple[float, float, float, float], ...] - The LO-TO poles, given as list of tuples of the form - (omega_LO, gamma_LO, omega_TO, gamma_TO). - eps_inf: pd.PositiveFloat - The relative permittivity at infinite frequency. - - Returns - ------- - :class:`.PoleResidue` - The pole residue equivalent of the LO-TO form provided. - """ - - omegas_lo, gammas_lo, omegas_to, gammas_to = map(np.array, zip(*poles)) - - # discriminants of quadratic factors of denominator - discs = 2 * npo.emath.sqrt((gammas_to / 2) ** 2 - omegas_to**2) - - # require nondegenerate TO poles - if len({(omega_to, gamma_to) for (_, _, omega_to, gamma_to) in poles}) != len(poles) or any( - disc == 0 for disc in discs - ): - raise ValidationError( - "Unable to construct a pole residue model " - "from an LO-TO form with degenerate TO poles. Consider adding a " - "perturbation to split the poles, or using " - "'PoleResidue.lo_to_eps_model' and fitting with the 'FastDispersionFitter'." - ) - - # roots of denominator, in pairs - roots = [] - for gamma_to, disc in zip(gammas_to, discs): - roots.append(-gamma_to / 2 + disc / 2) - roots.append(-gamma_to / 2 - disc / 2) - - # interpolants - interpolants = eps_inf * np.ones(len(roots), dtype=complex) - for i, a in enumerate(roots): - for omega_lo, gamma_lo in zip(omegas_lo, gammas_lo): - interpolants[i] *= omega_lo**2 + a**2 + a * gamma_lo - for j, a2 in enumerate(roots): - if j != i: - interpolants[i] /= a - a2 - - a_coeffs = [] - c_coeffs = [] - - for i in range(0, len(roots), 2): - if not np.isreal(roots[i]): - a_coeffs.append(roots[i]) - c_coeffs.append(interpolants[i]) - else: - a_coeffs.append(roots[i]) - a_coeffs.append(roots[i + 1]) - # factor of two from adding conjugate pole of real pole - c_coeffs.append(interpolants[i] / 2) - c_coeffs.append(interpolants[i + 1] / 2) - - return PoleResidue(eps_inf=eps_inf, poles=list(zip(a_coeffs, c_coeffs))) - - @staticmethod - def imag_ep_extrema(poles: tuple[PoleAndResidue, ...]) -> ArrayFloat1D: - """Extrema of Im[eps] in the same unit as poles. - - Parameters - ---------- - poles: Tuple[PoleAndResidue, ...] 
- Tuple of complex-valued (``a_i, c_i``) poles for the model. - """ - - poles_a = [a for (a, _) in poles] - poles_c = [c for (_, c) in poles] - return imag_resp_extrema_locs(poles=poles_a, residues=poles_c) - - def _imag_ep_extrema_with_samples(self) -> ArrayFloat1D: - """Provide a list of frequencies (in unit of rad/s) to probe the possible lower and - upper bound of Im[eps] within the ``frequency_range``. If ``frequency_range`` is None, - it checks the entire frequency range. The returned frequencies include not only extrema, - but also a list of sampled frequencies. - """ - - # extrema frequencies: in the intermediate stage, convert to the unit eV for - # better numerical handling, since those quantities will be ~ 1 in photonics - extrema_freq = self.imag_ep_extrema(self.angular_freq_to_eV(np.array(self.poles))) - extrema_freq = self.eV_to_angular_freq(extrema_freq) - - # let's check a big range in addition to the imag_extrema - if self.frequency_range is None: - range_ev = np.logspace(LOSS_CHECK_MIN, LOSS_CHECK_MAX, LOSS_CHECK_NUM) - range_omega = self.eV_to_angular_freq(range_ev) - else: - fmin, fmax = self.frequency_range - fmin = max(fmin, fp_eps) - range_freq = np.logspace(np.log10(fmin), np.log10(fmax), LOSS_CHECK_NUM) - range_omega = self.Hz_to_angular_freq(range_freq) - - extrema_freq = extrema_freq[ - np.logical_and(extrema_freq > range_omega[0], extrema_freq < range_omega[-1]) - ] - return np.concatenate((range_omega, extrema_freq)) - - @cached_property - def loss_upper_bound(self) -> float: - """Upper bound of Im[eps] in `frequency_range`""" - freq_list = self.angular_freq_to_Hz(self._imag_ep_extrema_with_samples()) - ep = self.eps_model(freq_list) - # filter `NAN` in case some of freq_list are exactly at the pole frequency - # of Sellmeier-type poles. - ep = ep[~np.isnan(ep)] - return max(ep.imag) - - @staticmethod - def _get_vjps_from_params( - dJ_deps_complex: Union[complex, np.ndarray], - poles_vals: list[tuple[Union[complex, np.ndarray], Union[complex, np.ndarray]]], - omega: float, - requested_paths: list[tuple], - project_real: bool = False, - ) -> AutogradFieldMap: - """ - Static helper to compute VJPs from parameters using the analytical chain rule. - - Parameters - - dJ_deps_complex: Complex adjoint sensitivity w.r.t. epsilon at a single frequency. - - poles_vals: Sequence of (a_i, c_i) pole parameters to differentiate with respect to. - - omega: Angular frequency for this VJP evaluation. - - requested_paths: Paths requested by the caller; used to filter outputs. - - project_real: If True, project pole-parameter VJPs to their real part. - Use True for uniform PoleResidue to match real-valued objectives; use False for - CustomPoleResidue where parameters are complex and complex VJPs are required. 
- """ - jw = 1j * omega - vjps = {} - - if ("eps_inf",) in requested_paths: - vjps[("eps_inf",)] = np.real(dJ_deps_complex) - - for i, (a_val, c_val) in enumerate(poles_vals): - if any(path[1] == i for path in requested_paths if path[0] == "poles"): - if ("poles", i, 0) in requested_paths: - deps_da = c_val / (jw + a_val) ** 2 - dJ_da = dJ_deps_complex * deps_da - vjps[("poles", i, 0)] = np.real(dJ_da) if project_real else dJ_da - if ("poles", i, 1) in requested_paths: - deps_dc = -1 / (jw + a_val) - dJ_dc = dJ_deps_complex * deps_dc - vjps[("poles", i, 1)] = np.real(dJ_dc) if project_real else dJ_dc - - return vjps - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint derivatives by preparing scalar data and calling the static helper.""" - - dJ_deps_complex = self._derivative_eps_complex_volume( - E_der_map=derivative_info.E_der_map, - bounds=derivative_info.bounds, - ) - - poles_vals = [(complex(a), complex(c)) for a, c in self.poles] - - freqs = dJ_deps_complex.coords["f"].values - vjps_total = {} - - for freq in freqs: - dJ_deps_complex_f = dJ_deps_complex.sel(f=freq) - vjps_f = self._get_vjps_from_params( - dJ_deps_complex=complex(dJ_deps_complex_f), - poles_vals=poles_vals, - omega=2 * np.pi * freq, - requested_paths=derivative_info.paths, - project_real=True, - ) - for path, vjp in vjps_f.items(): - if path not in vjps_total: - vjps_total[path] = vjp - else: - vjps_total[path] += vjp - - return vjps_total - - @classmethod - def _real_partial_fraction_decomposition( - cls, a: np.ndarray, b: np.ndarray, tol: pd.PositiveFloat = 1e-2 - ) -> tuple[list[tuple[Complex, Complex]], np.ndarray]: - """Computes the complex conjugate pole residue pairs given a rational expression with - real coefficients. - - Parameters - ---------- - - a : np.ndarray - Coefficients of the numerator polynomial in increasing monomial order. - b : np.ndarray - Coefficients of the denominator polynomial in increasing monomial order. - tol : pd.PositiveFloat - Tolerance for pole finding. Two poles are considered equal, if their spacing is less - than ``tol``. - - Returns - ------- - tuple[list[tuple[Complex, Complex]], np.ndarray] - The list of complex conjugate poles and their associated residues. The second element of the - ``tuple`` is an array of coefficients representing any direct polynomial term. - - """ - from scipy import signal - - if a.ndim != 1 or np.any(np.iscomplex(a)): - raise ValidationError( - "Numerator coefficients must be a one-dimensional array of real numbers." - ) - if b.ndim != 1 or np.any(np.iscomplex(b)): - raise ValidationError( - "Denominator coefficients must be a one-dimensional array of real numbers." - ) - - # Compute residues and poles using scipy - (r, p, k) = signal.residue(np.flip(a), np.flip(b), tol=tol, rtype="avg") - - # Assuming real coefficients for the polynomials, the poles should be real or come as - # complex conjugate pairs - r_filtered = [] - p_filtered = [] - for res, (idx, pole) in zip(list(r), enumerate(list(p))): - # Residue equal to zero interpreted as rational expression was not - # in simplest form. So skip this pole. - if res == 0: - continue - # Causal and stability check - if np.real(pole) > 0: - raise ValidationError("Transfer function is invalid. It is non-causal.") - # Check for higher order pole, which come in consecutive order - if idx > 0 and p[idx - 1] == pole: - raise ValidationError( - "Transfer function is invalid. A higher order pole was detected. 
- "or ensure that the rational expression does not have repeated poles. "
- )
- if np.imag(pole) == 0:
- r_filtered.append(res / 2)
- p_filtered.append(pole)
- else:
- pair_found = len(np.argwhere(np.array(p) == np.conj(pole))) == 1
- if not pair_found:
- raise ValueError(
- "Failed to find complex-conjugate of pole in poles computed by SciPy."
- )
- previously_added = len(np.argwhere(np.array(p_filtered) == np.conj(pole))) == 1
- if not previously_added:
- r_filtered.append(res)
- p_filtered.append(pole)
-
- poles_residues = list(zip(p_filtered, r_filtered))
- k_increasing_order = np.flip(k)
- return (poles_residues, k_increasing_order)
-
- @classmethod
- def from_admittance_coeffs(
- cls,
- a: np.ndarray,
- b: np.ndarray,
- eps_inf: pd.PositiveFloat = 1,
- pole_tol: pd.PositiveFloat = 1e-2,
- ) -> PoleResidue:
- """Construct a :class:`.PoleResidue` model from an admittance function defining the
- relationship between the electric field and the polarization current density in the
- Laplace domain.
-
- Parameters
- ----------
- a : np.ndarray
- Coefficients of the numerator polynomial in increasing monomial order.
- b : np.ndarray
- Coefficients of the denominator polynomial in increasing monomial order.
- eps_inf: pd.PositiveFloat
- The relative permittivity at infinite frequency.
- pole_tol: pd.PositiveFloat
- Tolerance for the pole finding algorithm in Hertz. Two poles are considered equal if their
- spacing is closer than ``pole_tol``.
-
- Returns
- -------
- :class:`.PoleResidue`
- The pole residue equivalent.
-
- Notes
- -----
-
- The supplied admittance function relates the electric field to the polarization current density
- in the Laplace domain and is equivalent to a frequency-dependent complex conductivity
- :math:`\\sigma(\\omega)`.
-
- .. math::
- J_p(s) = Y(s)E(s)
-
- .. math::
- Y(s) = \\frac{a_0 + a_1 s + \\dots + a_M s^M}{b_0 + b_1 s + \\dots + b_N s^N}
-
- An equivalent :class:`.PoleResidue` medium is constructed using an equivalent frequency-dependent
- complex permittivity defined as
-
- .. math::
- \\epsilon(s) = \\epsilon_\\infty - \\frac{1}{\\epsilon_0 s}
- \\frac{a_0 + a_1 s + \\dots + a_M s^M}{b_0 + b_1 s + \\dots + b_N s^N}.
- """
-
- if a.ndim != 1 or np.any(np.logical_or(np.iscomplex(a), a < 0)):
- raise ValidationError(
- "Numerator coefficients must be a one-dimensional array of non-negative real numbers."
- )
- if b.ndim != 1 or np.any(np.logical_or(np.iscomplex(b), b < 0)):
- raise ValidationError(
- "Denominator coefficients must be a one-dimensional array of non-negative real numbers."
- )
-
- # Trim any trailing zeros, so that length corresponds with polynomial order
- a = np.trim_zeros(a, "b")
- b = np.trim_zeros(b, "b")
-
- # Validate that the transfer function will remain proper once converted to
- # the complex permittivity version.
- # Let q equal the order of the numerator polynomial, and p equal the order
- # of the denominator polynomial. Then q < p is a strictly proper rational transfer function (RTF),
- # q <= p is a proper RTF, and q > p is an improper RTF.
- q = len(a) - 1
- p = len(b) - 1
-
- if q > p + 1:
- raise ValidationError(
- "Transfer function is improper: the order of the numerator polynomial must be at most "
- "one greater than the order of the denominator polynomial."
- ) - - # Modify the transfer function defining a complex conductivity to match the complex - # frequency-dependent portion of the pole residue model - # Meaning divide by -j*omega*epsilon (s*epsilon) - b = np.concatenate(([0], b * EPSILON_0)) - - poles_and_residues, k = cls._real_partial_fraction_decomposition( - a=a, b=b, tol=pole_tol * 2 * np.pi - ) - - # A direct polynomial term of zeroth order is interpreted as an additional contribution to eps_inf. - # So we only handle that special case. - if len(k) == 1: - if np.iscomplex(k[0]) or k[0] < 0: - raise ValidationError( - "Transfer function is invalid. Direct polynomial term must be real and positive for " - "conversion to an equivalent 'PoleResidue' medium." - ) - # A pure capacitance will translate to an increased permittivity at infinite frequency. - eps_inf = eps_inf + k[0] - - pole_residue_from_transfer = PoleResidue(eps_inf=eps_inf, poles=poles_and_residues) - - # Check passivity - ang_freqs = PoleResidue._imag_ep_extrema_with_samples(pole_residue_from_transfer) - freq_list = PoleResidue.angular_freq_to_Hz(ang_freqs) - ep = pole_residue_from_transfer.eps_model(freq_list) - # filter `NAN` in case some of freq_list are exactly at the pole frequency - ep = ep[~np.isnan(ep)] - - if np.any(np.imag(ep) < -fp_eps): - log.warning( - "Generated 'PoleResidue' medium is not passive. Please raise an issue on the " - "Tidy3d frontend with this message and some information about your " - "simulation setup and we will investigate." - ) - - return pole_residue_from_transfer - - -class CustomPoleResidue(CustomDispersiveMedium, PoleResidue): - """A spatially varying dispersive medium described by the pole-residue pair model. - - Notes - ----- - - In this method, the frequency-dependent permittivity :math:`\\epsilon(\\omega)` is expressed as a sum of - resonant material poles _`[1]`. - - .. math:: - - \\epsilon(\\omega) = \\epsilon_\\infty - \\sum_i - \\left[\\frac{c_i}{j \\omega + a_i} + - \\frac{c_i^*}{j \\omega + a_i^*}\\right] - - For each of these resonant poles identified by the index :math:`i`, an auxiliary differential equation is - used to relate the auxiliary current :math:`J_i(t)` to the applied electric field :math:`E(t)`. - The sum of all these auxiliary current contributions describes the total dielectric response of the material. - - .. math:: - - \\frac{d}{dt} J_i (t) - a_i J_i (t) = \\epsilon_0 c_i \\frac{d}{dt} E (t) - - Hence, the computational cost increases with the number of poles. - - **References** - - .. [1] M. Han, R.W. Dutton and S. Fan, IEEE Microwave and Wireless Component Letters, 16, 119 (2006). - - .. TODO add links to notebooks using this. 
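As a concrete check of the sum above, here is a short sketch (illustrative pole values, using the uniform :class:`PoleResidue` for simplicity):

    import numpy as np
    import tidy3d as td

    poles = [(-1e14 + 2e15j, 3e13 + 1e15j)]  # illustrative (a, c) pair in rad/s
    medium = td.PoleResidue(eps_inf=2.0, poles=poles)

    freq = 2e14  # Hz
    jw = 2j * np.pi * freq
    eps = 2.0 - sum(c / (jw + a) + np.conj(c) / (jw + np.conj(a)) for a, c in poles)
    assert np.allclose(eps, medium.eps_model(freq))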
- - Example - ------- - >>> x = np.linspace(-1, 1, 5) - >>> y = np.linspace(-1, 1, 6) - >>> z = np.linspace(-1, 1, 7) - >>> coords = dict(x=x, y=y, z=z) - >>> eps_inf = SpatialDataArray(np.ones((5, 6, 7)), coords=coords) - >>> a1 = SpatialDataArray(-np.random.random((5, 6, 7)), coords=coords) - >>> c1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> a2 = SpatialDataArray(-np.random.random((5, 6, 7)), coords=coords) - >>> c2 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> pole_res = CustomPoleResidue(eps_inf=eps_inf, poles=[(a1, c1), (a2, c2)]) - >>> eps = pole_res.eps_model(200e12) - - See Also - -------- - - **Notebooks** - - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - - * `Modeling dispersive material in FDTD `_ - """ - - eps_inf: CustomSpatialDataTypeAnnotated = pd.Field( - ..., - title="Epsilon at Infinity", - description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", - units=PERMITTIVITY, - ) - - poles: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( - (), - title="Poles", - description="Tuple of complex-valued (:math:`a_i, c_i`) poles for the model.", - units=(RADPERSEC, RADPERSEC), - ) - ) - - _no_nans_eps_inf = validate_no_nans("eps_inf") - _no_nans_poles = validate_no_nans("poles") - _warn_if_none = CustomDispersiveMedium._warn_if_data_none("poles") - - @pd.validator("eps_inf", always=True) - def _eps_inf_positive(cls, val): - """eps_inf must be positive""" - if not CustomDispersiveMedium._validate_isreal_dataarray(val): - raise SetupError("'eps_inf' must be real.") - if np.any(_get_numpy_array(val) < 0): - raise SetupError("'eps_inf' must be positive.") - return val - - @pd.validator("poles", always=True) - @skip_if_fields_missing(["eps_inf"]) - def _poles_correct_shape(cls, val, values): - """poles must have the same shape.""" - - for coeffs in val: - for coeff in coeffs: - if not _check_same_coordinates(coeff, values["eps_inf"]): - raise SetupError( - "All pole coefficients 'a' and 'c' must have the same coordinates; " - "The coordinates must also be consistent with 'eps_inf'." - ) - return val - - @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - if not self.eps_inf.is_uniform: - return False - - for coeffs in self.poles: - for coeff in coeffs: - if not coeff.is_uniform: - return False - return True - - def eps_dataarray_freq( - self, frequency: float - ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: - """Permittivity array at ``frequency``. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[ - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - ] - The permittivity evaluated at ``frequency``. - """ - eps = PoleResidue.eps_model(self, frequency) - return (eps, eps, eps) - - def poles_on_grid(self, coords: Coords) -> tuple[tuple[ArrayComplex3D, ArrayComplex3D], ...]: - """Spatial profile of poles interpolated at the supplied coordinates. - - Parameters - ---------- - coords : :class:`.Coords` - The grid point coordinates over which interpolation is performed. 
-
- Returns
- -------
- Tuple[Tuple[ArrayComplex3D, ArrayComplex3D], ...]
- The poles interpolated at the supplied coordinates.
- """
-
- def fun_interp(input_data: SpatialDataArray) -> ArrayComplex3D:
- return _get_numpy_array(coords.spatial_interp(input_data, self.interp_method))
-
- return tuple((fun_interp(a), fun_interp(c)) for (a, c) in self.poles)
-
- @classmethod
- def from_medium(cls, medium: CustomMedium) -> CustomPoleResidue:
- """Convert a :class:`.CustomMedium` to a pole residue model.
-
- Parameters
- ----------
- medium: :class:`.CustomMedium`
- The medium with permittivity and conductivity to convert.
-
- Returns
- -------
- :class:`.CustomPoleResidue`
- The pole residue equivalent.
- """
- poles = [(_zeros_like(medium.conductivity), medium.conductivity / (2 * EPSILON_0))]
- medium_dict = medium.dict(exclude={"type", "eps_dataset", "permittivity", "conductivity"})
- medium_dict.update({"eps_inf": medium.permittivity, "poles": poles})
- return CustomPoleResidue.parse_obj(medium_dict)
-
- def to_medium(self) -> CustomMedium:
- """Convert to a :class:`.CustomMedium`.
- Requires the pole residue model to have only a pole at 0 frequency,
- corresponding to a constant conductivity term.
-
- Returns
- -------
- :class:`.CustomMedium`
- The non-dispersive equivalent with constant permittivity and conductivity.
- """
- res = 0
- for a, c in self.poles:
- if np.any(abs(_get_numpy_array(a)) > fp_eps):
- raise ValidationError(
- "Cannot convert dispersive 'CustomPoleResidue' to 'CustomMedium'."
- )
- res = res + (c + np.conj(c)) / 2
- sigma = res * 2 * EPSILON_0
-
- self_dict = self.dict(exclude={"type", "eps_inf", "poles"})
- self_dict.update({"permittivity": self.eps_inf, "conductivity": np.real(sigma)})
- return CustomMedium.parse_obj(self_dict)
-
- @cached_property
- def loss_upper_bound(self) -> float:
- """Not implemented yet."""
- raise SetupError("To be implemented.")
-
- def _sel_custom_data_inside(self, bounds: Bound):
- """Return a new custom medium that contains the minimal amount of data necessary to cover
- a spatial region defined by ``bounds``.
-
- Parameters
- ----------
- bounds : Tuple[float, float, float], Tuple[float, float, float]
- Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
-
- Returns
- -------
- CustomPoleResidue
- CustomPoleResidue with reduced data.
- """
- if not self.eps_inf.does_cover(bounds=bounds):
- log.warning("eps_inf spatial data array does not fully cover the requested region.")
- eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds)
- poles_reduced = []
- for pole, residue in self.poles:
- if not pole.does_cover(bounds=bounds):
- log.warning("Pole spatial data array does not fully cover the requested region.")
-
- if not residue.does_cover(bounds=bounds):
- log.warning("Residue spatial data array does not fully cover the requested region.")
-
- poles_reduced.append((pole.sel_inside(bounds), residue.sel_inside(bounds)))
-
- return self.updated_copy(eps_inf=eps_inf_reduced, poles=poles_reduced)
-
- def _derivative_field_cmp(
- self,
- E_der_map: ElectromagneticFieldDataset,
- spatial_data: CustomSpatialDataTypeAnnotated,
- dim: str,
- freqs=None,
- component: str = "complex",
- ) -> np.ndarray:
- """Compatibility wrapper for derivative computation.
-
- Accepts the extended signature used by other custom media
- (e.g., `CustomMedium._derivative_field_cmp`) while delegating the actual
- computation to the base implementation that only depends on
- `E_der_map`, `spatial_data`, and `dim`.
- - Parameters `freqs` and `component` are ignored for this model since the - derivative is taken with respect to the complex permittivity directly. - """ - return super()._derivative_field_cmp( - E_der_map=E_der_map, spatial_data=spatial_data, dim=dim - ) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Compute adjoint derivatives by preparing array data and calling the static helper.""" - - # accumulate complex-valued derivatives across xyz; start as complex to avoid casting issues - dJ_deps_complex = 0.0 + 0.0j - for dim in "xyz": - dJ_deps_complex += self._derivative_field_cmp( - E_der_map=derivative_info.E_der_map, - spatial_data=self.eps_inf, - dim=dim, - freqs=derivative_info.frequencies, - component="complex", - ) - - poles_vals = [ - (np.array(a.values, dtype=complex), np.array(c.values, dtype=complex)) - for a, c in self.poles - ] - - vjps_total = {} - for freq in derivative_info.frequencies: - vjps_f = PoleResidue._get_vjps_from_params( - dJ_deps_complex=dJ_deps_complex, - poles_vals=poles_vals, - omega=2 * np.pi * freq, - requested_paths=derivative_info.paths, - project_real=False, - ) - for path, vjp in vjps_f.items(): - if path not in vjps_total: - vjps_total[path] = vjp - else: - vjps_total[path] += vjp - return vjps_total - - -class Sellmeier(DispersiveMedium): - """A dispersive medium described by the Sellmeier model. - - Notes - ----- - - The frequency-dependence of the refractive index is described by: - - .. math:: - - n(\\lambda)^2 = 1 + \\sum_i \\frac{B_i \\lambda^2}{\\lambda^2 - C_i} - - For lossless, weakly dispersive materials, the best way to incorporate the dispersion without doing - complicated fits and without slowing the simulation down significantly is to provide the value of the - refractive index dispersion :math:`\\frac{dn}{d\\lambda}` in :meth:`tidy3d.Sellmeier.from_dispersion`. The - value is assumed to be at the central frequency or wavelength (whichever is provided), and a one-pole model - for the material is generated. - - Example - ------- - >>> sellmeier_medium = Sellmeier(coeffs=[(1,2), (3,4)]) - >>> eps = sellmeier_medium.eps_model(200e12) - - See Also - -------- - - :class:`CustomSellmeier` - A spatially varying dispersive medium described by the Sellmeier model. - - **Notebooks** - - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - - * `Modeling dispersive material in FDTD `_ - """ - - coeffs: tuple[tuple[float, pd.PositiveFloat], ...] = pd.Field( - title="Coefficients", - description="List of Sellmeier (:math:`B_i, C_i`) coefficients.", - units=(None, MICROMETER + "^2"), - ) - - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if `allow_gain` is False.""" - if values.get("allow_gain"): - return val - for B, _ in val: - if B < 0: - raise ValidationError( - "For passive medium, 'B_i' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." 
- )
- return val
-
- @pd.validator("modulation_spec", always=True)
- def _validate_permittivity_modulation(cls, val):
- """Assert modulated permittivity cannot be <= 0."""
-
- if val is None or val.permittivity is None:
- return val
-
- min_eps_inf = 1.0
- if min_eps_inf - val.permittivity.max_modulation <= 0:
- raise ValidationError(
- "The minimum permittivity value with modulation applied was found to be negative."
- )
- return val
-
- _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation()
-
- def _n_model(self, frequency: float) -> complex:
- """Complex-valued refractive index as a function of frequency."""
-
- wvl = C_0 / np.array(frequency)
- wvl2 = wvl**2
- n_squared = 1.0
- for B, C in self.coeffs:
- n_squared = n_squared + B * wvl2 / (wvl2 - C)
- return np.sqrt(n_squared + 0j)
-
- @ensure_freq_in_range
- def eps_model(self, frequency: float) -> complex:
- """Complex-valued permittivity as a function of frequency."""
-
- n = self._n_model(frequency)
- return AbstractMedium.nk_to_eps_complex(n)
-
- def _pole_residue_dict(self) -> dict:
- """Dict representation of Medium as a pole-residue model."""
- poles = []
- eps_inf = _ones_like(self.coeffs[0][0])
- for B, C in self.coeffs:
- # for small C, it's equivalent to modifying eps_inf
- if np.any(np.isclose(_get_numpy_array(C), 0)):
- eps_inf += B
- else:
- beta = 2 * np.pi * C_0 / np.sqrt(C)
- alpha = -0.5 * beta * B
- a = 1j * beta
- c = 1j * alpha
- poles.append((a, c))
- return {
- "eps_inf": eps_inf,
- "poles": poles,
- "frequency_range": self.frequency_range,
- "name": self.name,
- }
-
- @staticmethod
- def _from_dispersion_to_coeffs(n: float, freq: float, dn_dwvl: float):
- """Compute Sellmeier coefficients from dispersion."""
- wvl = C_0 / np.array(freq)
- nsqm1 = n**2 - 1
- c_coeff = -(wvl**3) * n * dn_dwvl / (nsqm1 - wvl * n * dn_dwvl)
- b_coeff = (wvl**2 - c_coeff) / wvl**2 * nsqm1
- return [(b_coeff, c_coeff)]
-
- @classmethod
- def from_dispersion(cls, n: float, freq: float, dn_dwvl: float = 0, **kwargs: Any):
- """Convert ``n`` and wavelength dispersion ``dn_dwvl`` values at frequency ``freq`` to
- a single-pole :class:`Sellmeier` medium.
-
- Parameters
- ----------
- n : float
- Real part of refractive index. Must be larger than or equal to one.
- freq : float
- Frequency at which ``n`` and ``dn_dwvl`` are sampled.
- dn_dwvl : float = 0
- Derivative of the refractive index with wavelength (1/um). Must be negative.
-
- Returns
- -------
- :class:`Sellmeier`
- Single-pole Sellmeier medium with the provided refractive index and index dispersion
- values at the provided frequency.
- """
-
- if dn_dwvl >= 0:
- raise ValidationError("Dispersion ``dn_dwvl`` must be smaller than zero.")
- if n < 1:
- raise ValidationError("Refractive index ``n`` cannot be smaller than one.")
- return cls(coeffs=cls._from_dispersion_to_coeffs(n, freq, dn_dwvl), **kwargs)
-
- def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
- """Adjoint derivatives for Sellmeier params via TJP through eps_model()."""
-
- freqs, vec = self._tjp_inputs(derivative_info)
- N = len(self.coeffs)
- if N == 0:
- return {}
-
- # pack parameters into flat vector [B..., C...]
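    # (editorial illustration, not part of the source) with coeffs = ((B1, C1), (B2, C2)), N = 2:
    #   theta0 = [B1, B2, C1, C2]
    # _eps_vec below unpacks B = theta[:N] and C = theta[N : 2 * N], so the gradient
    # mapping pairs path ("coeffs", i, 0) with flat index i (B_i) and
    # ("coeffs", i, 1) with flat index N + i (C_i).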
- B0 = np.array([float(b) for (b, _c) in self.coeffs]) - C0 = np.array([float(c) for (_b, c) in self.coeffs]) - theta0 = np.concatenate([B0, C0]) - - def _eps_vec(theta): - B = theta[:N] - C = theta[N : 2 * N] - coeffs = tuple((B[i], C[i]) for i in range(N)) - eps = self.updated_copy(coeffs=coeffs, validate=False).eps_model(freqs) - return pack_complex_vec(eps) - - g = self._tjp_grad(theta0, _eps_vec, vec) - - mapping = [] - mapping += [(("coeffs", i, 0), i) for i in range(N)] - mapping += [(("coeffs", i, 1), N + i) for i in range(N)] - return self._map_grad_real(g, derivative_info.paths, mapping) - - @staticmethod - def _lam2(freq): - return (C_0 / freq) ** 2 - - @staticmethod - def _sellmeier_den(lam2, C): - return lam2 - C - - # frequency weights for custom Sellmeier - @staticmethod - def _w_B(freq, C): - lam2 = Sellmeier._lam2(freq) - return lam2 / Sellmeier._sellmeier_den(lam2, C) - - @staticmethod - def _w_C(freq, B, C): - lam2 = Sellmeier._lam2(freq) - den = Sellmeier._sellmeier_den(lam2, C) - return B * lam2 / (den**2) - - -class CustomSellmeier(CustomDispersiveMedium, Sellmeier): - """A spatially varying dispersive medium described by the Sellmeier model. - - Notes - ----- - - The frequency-dependence of the refractive index is described by: - - .. math:: - - n(\\lambda)^2 = 1 + \\sum_i \\frac{B_i \\lambda^2}{\\lambda^2 - C_i} - - Example - ------- - >>> x = np.linspace(-1, 1, 5) - >>> y = np.linspace(-1, 1, 6) - >>> z = np.linspace(-1, 1, 7) - >>> coords = dict(x=x, y=y, z=z) - >>> b1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> c1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> sellmeier_medium = CustomSellmeier(coeffs=[(b1,c1),]) - >>> eps = sellmeier_medium.eps_model(200e12) - - See Also - -------- - - :class:`Sellmeier` - A dispersive medium described by the Sellmeier model. - - **Notebooks** - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - * `Modeling dispersive material in FDTD `_ - """ - - coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( - ..., - title="Coefficients", - description="List of Sellmeier (:math:`B_i, C_i`) coefficients.", - units=(None, MICROMETER + "^2"), - ) - ) - - _no_nans = validate_no_nans("coeffs") - - _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") - - @pd.validator("coeffs", always=True) - def _correct_shape_and_sign(cls, val): - """every term in coeffs must have the same shape, and B>=0 and C>0.""" - if len(val) == 0: - return val - for B, C in val: - if not _check_same_coordinates(B, val[0][0]) or not _check_same_coordinates( - C, val[0][0] - ): - raise SetupError("Every term in 'coeffs' must have the same coordinates.") - if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((B, C)): - raise SetupError("'B' and 'C' must be real.") - if np.any(_get_numpy_array(C) <= 0): - raise SetupError("'C' must be positive.") - return val - - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if `allow_gain` is False.""" - if values.get("allow_gain"): - return val - for B, _ in val: - if np.any(_get_numpy_array(B) < 0): - raise ValidationError( - "For passive medium, 'B_i' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." 
- )
- return val
-
- @pd.validator("coeffs", always=True)
- def _coeffs_C_all_near_zero_or_much_greater(cls, val):
- """We require that either all C are near zero, or all are far from zero."""
- for _, C in val:
- c_array_near_zero = np.isclose(_get_numpy_array(C), 0)
- if np.any(c_array_near_zero) and not np.all(c_array_near_zero):
- raise SetupError(
- "Coefficients 'C_i' are restricted to be "
- "either all near zero or much greater than 0."
- )
- return val
-
- @cached_property
- def is_spatially_uniform(self) -> bool:
- """Whether the medium is spatially uniform."""
- for coeffs in self.coeffs:
- for coeff in coeffs:
- if not coeff.is_uniform:
- return False
- return True
-
- def _pole_residue_dict(self) -> dict:
- """Dict representation of Medium as a pole-residue model."""
- poles_dict = Sellmeier._pole_residue_dict(self)
- if len(self.coeffs) > 0:
- poles_dict.update({"eps_inf": _ones_like(self.coeffs[0][0])})
- return poles_dict
-
- def eps_dataarray_freq(
- self, frequency: float
- ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]:
- """Permittivity array at ``frequency``.
-
- Parameters
- ----------
- frequency : float
- Frequency to evaluate permittivity at (Hz).
-
- Returns
- -------
- Tuple[
- Union[
- :class:`.SpatialDataArray`,
- :class:`.TriangularGridDataset`,
- :class:`.TetrahedralGridDataset`,
- ],
- Union[
- :class:`.SpatialDataArray`,
- :class:`.TriangularGridDataset`,
- :class:`.TetrahedralGridDataset`,
- ],
- Union[
- :class:`.SpatialDataArray`,
- :class:`.TriangularGridDataset`,
- :class:`.TetrahedralGridDataset`,
- ],
- ]
- The permittivity evaluated at ``frequency``.
- """
- eps = Sellmeier.eps_model(self, frequency)
- # if `eps` is simply a float, convert it to a SpatialDataArray; this is possible when
- # `coeffs` is empty.
- if isinstance(eps, (int, float, complex)):
- eps = SpatialDataArray(eps * np.ones((1, 1, 1)), coords={"x": [0], "y": [0], "z": [0]})
- return (eps, eps, eps)
-
- @classmethod
- def from_dispersion(
- cls,
- n: CustomSpatialDataType,
- freq: float,
- dn_dwvl: CustomSpatialDataType,
- interp_method="nearest",
- **kwargs: Any,
- ):
- """Convert ``n`` and wavelength dispersion ``dn_dwvl`` values at frequency ``freq`` to
- a single-pole :class:`CustomSellmeier` medium.
-
- Parameters
- ----------
- n : Union[
- :class:`.SpatialDataArray`,
- :class:`.TriangularGridDataset`,
- :class:`.TetrahedralGridDataset`,
- ]
- Real part of refractive index. Must be larger than or equal to one.
- dn_dwvl : Union[
- :class:`.SpatialDataArray`,
- :class:`.TriangularGridDataset`,
- :class:`.TetrahedralGridDataset`,
- ]
- Derivative of the refractive index with wavelength (1/um). Must be negative.
- freq : float
- Frequency at which ``n`` and ``dn_dwvl`` are sampled.
- interp_method : :class:`.InterpMethod`, optional
- Interpolation method to obtain permittivity values that are not supplied
- at the Yee grids.
-
- Returns
- -------
- :class:`.CustomSellmeier`
- Single-pole Sellmeier medium with the provided refractive index and index dispersion
- values at the provided frequency.
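A usage sketch for this constructor (placeholder values, not a fitted material; the spatially uniform :class:`Sellmeier` variant is used for brevity):

    import tidy3d as td

    freq = td.C_0 / 1.55  # frequency (Hz) corresponding to a 1.55 um wavelength
    medium = td.Sellmeier.from_dispersion(n=1.45, freq=freq, dn_dwvl=-0.01)
    # medium.coeffs holds the single (B, C) pair reproducing n and dn/dwvl at freq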
- """ - - if not _check_same_coordinates(n, dn_dwvl): - raise ValidationError("'n' and'dn_dwvl' must have the same dimension.") - if np.any(_get_numpy_array(dn_dwvl) >= 0): - raise ValidationError("Dispersion ``dn_dwvl`` must be smaller than zero.") - if np.any(_get_numpy_array(n) < 1): - raise ValidationError("Refractive index ``n`` cannot be smaller than one.") - return cls( - coeffs=cls._from_dispersion_to_coeffs(n, freq, dn_dwvl), - interp_method=interp_method, - **kwargs, - ) - - def _sel_custom_data_inside(self, bounds: Bound): - """Return a new custom medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - CustomSellmeier - CustomSellmeier with reduced data. - """ - coeffs_reduced = [] - for b_coeff, c_coeff in self.coeffs: - if not b_coeff.does_cover(bounds=bounds): - log.warning( - "Sellmeier B coeff spatial data array does not fully cover the requested region." - ) - - if not c_coeff.does_cover(bounds=bounds): - log.warning( - "Sellmeier C coeff spatial data array does not fully cover the requested region." - ) - - coeffs_reduced.append((b_coeff.sel_inside(bounds), c_coeff.sel_inside(bounds))) - - return self.updated_copy(coeffs=coeffs_reduced) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Adjoint derivatives for CustomSellmeier via analytic chain rule. - - Uses the complex permittivity derivative aggregated over spatial dims and - applies frequency-dependent weights per Sellmeier term. - """ - - if len(self.coeffs) == 0: - return {} - - # accumulate complex-valued sensitivity across xyz using B's grid as reference - ref = self.coeffs[0][0] - dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=ref) - - # prepare gradients map - grads: AutogradFieldMap = {} - - # iterate coefficients and requested paths - for i, (B_da, C_da) in enumerate(self.coeffs): - need_B = ("coeffs", i, 0) in derivative_info.paths - need_C = ("coeffs", i, 1) in derivative_info.paths - if not (need_B or need_C): - continue - - Bv = np.array(B_da.values, dtype=float) - Cv = np.array(C_da.values, dtype=float) - - gB = 0.0 if not need_B else np.zeros_like(Bv, dtype=float) - gC = 0.0 if not need_C else np.zeros_like(Cv, dtype=float) - - if need_B: - gB = gB + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, Cv=Cv: Sellmeier._w_B(f, Cv), - ) - if need_C: - gC = gC + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, Bv=Bv, Cv=Cv: Sellmeier._w_C(f, Bv, Cv), - ) - - if need_B: - grads[("coeffs", i, 0)] = gB - if need_C: - grads[("coeffs", i, 1)] = gC - - return grads - - -class Lorentz(DispersiveMedium): - """A dispersive medium described by the Lorentz model. - - Notes - ----- - - The frequency-dependence of the complex-valued permittivity is described by: - - .. 
math:: - - \\epsilon(f) = \\epsilon_\\infty + \\sum_i - \\frac{\\Delta\\epsilon_i f_i^2}{f_i^2 - 2jf\\delta_i - f^2} - - Example - ------- - >>> lorentz_medium = Lorentz(eps_inf=2.0, coeffs=[(1,2,3), (4,5,6)]) - >>> eps = lorentz_medium.eps_model(200e12) - - See Also - -------- - - **Notebooks** - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - * `Modeling dispersive material in FDTD `_ - """ - - eps_inf: pd.PositiveFloat = pd.Field( - 1.0, - title="Epsilon at Infinity", - description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", - units=PERMITTIVITY, - ) - - coeffs: tuple[tuple[float, float, pd.NonNegativeFloat], ...] = pd.Field( - ..., - title="Coefficients", - description="List of (:math:`\\Delta\\epsilon_i, f_i, \\delta_i`) values for model.", - units=(PERMITTIVITY, HERTZ, HERTZ), - ) - - @pd.validator("coeffs", always=True) - def _coeffs_unequal_f_delta(cls, val): - """f**2 and delta**2 cannot be exactly the same.""" - for _, f, delta in val: - if f**2 == delta**2: - raise SetupError("'f' and 'delta' cannot take equal values.") - return val - - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if ``allow_gain`` is False.""" - if values.get("allow_gain"): - return val - for del_ep, _, _ in val: - if del_ep < 0: - raise ValidationError( - "For passive medium, 'Delta epsilon_i' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - return val - - _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() - _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - - eps = self.eps_inf + 0.0j - for de, f, delta in self.coeffs: - eps = eps + (de * f**2) / (f**2 - 2j * frequency * delta - frequency**2) - return eps - - def _pole_residue_dict(self) -> dict: - """Dict representation of Medium as a pole-residue model.""" - - poles = [] - for de, f, delta in self.coeffs: - w = 2 * np.pi * f - d = 2 * np.pi * delta - - if self._all_larger(d**2, w**2): - r = np.sqrt(d * d - w * w) + 0j - a0 = -d + r - c0 = de * w**2 / 4 / r - a1 = -d - r - c1 = -c0 - poles.extend(((a0, c0), (a1, c1))) - else: - r = np.sqrt(w * w - d * d) - a = -d - 1j * r - c = 1j * de * w**2 / 2 / r - poles.append((a, c)) - - return { - "eps_inf": self.eps_inf, - "poles": poles, - "frequency_range": self.frequency_range, - "name": self.name, - } - - @staticmethod - def _all_larger(coeff_a, coeff_b) -> bool: - """``coeff_a`` and ``coeff_b`` can be either float or SpatialDataArray.""" - if isinstance(coeff_a, CustomSpatialDataType.__args__): - return np.all(_get_numpy_array(coeff_a) > _get_numpy_array(coeff_b)) - return coeff_a > coeff_b - - @classmethod - def from_nk(cls, n: float, k: float, freq: float, **kwargs: Any): - """Convert ``n`` and ``k`` values at frequency ``freq`` to a single-pole Lorentz - medium. - - Parameters - ---------- - n : float - Real part of refractive index. - k : float = 0 - Imaginary part of refrative index. - freq : float - Frequency to evaluate permittivity at (Hz). - kwargs: dict - Keyword arguments passed to the medium construction. 
- - Returns - ------- - :class:`Lorentz` - Lorentz medium having refractive index n+ik at frequency ``freq``. - """ - eps_complex = AbstractMedium.nk_to_eps_complex(n, k) - eps_r, eps_i = eps_complex.real, eps_complex.imag - if eps_r >= 1: - log.warning( - "For 'permittivity>=1', it is more computationally efficient to " - "use a dispersiveless medium constructed from 'Medium.from_nk()'." - ) - # first, lossless medium - if isclose(eps_i, 0): - if eps_r < 1: - fp = np.sqrt((eps_r - 1) / (eps_r - 2)) * freq - return cls( - eps_inf=1, - coeffs=[ - (1, fp, 0), - ], - ) - return cls( - eps_inf=1, - coeffs=[ - ((eps_r - 1) / 2, np.sqrt(2) * freq, 0), - ], - ) - # lossy medium - alpha = (eps_r - 1) / eps_i - delta_p = freq / 2 / (alpha**2 - alpha + 1) - fp = np.sqrt((alpha**2 + 1) / (alpha**2 - alpha + 1)) * freq - return cls( - eps_inf=1, - coeffs=[ - (eps_i, fp, delta_p), - ], - **kwargs, - ) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Adjoint derivatives for Lorentz params via TJP through eps_model().""" - - f, vec = self._tjp_inputs(derivative_info) - - N = len(self.coeffs) - if N == 0 and ("eps_inf",) not in derivative_info.paths: - return {} - - # pack into flat [eps_inf, de..., f0..., delta...] - eps_inf0 = float(self.eps_inf) - de0 = np.array([float(de) for (de, _f, _d) in self.coeffs]) if N else np.array([]) - f0 = np.array([float(fi) for (_de, fi, _d) in self.coeffs]) if N else np.array([]) - d0 = np.array([float(dd) for (_de, _f, dd) in self.coeffs]) if N else np.array([]) - theta0 = np.concatenate([np.array([eps_inf0]), de0, f0, d0]) - - def _eps_vec(theta): - eps_inf = theta[0] - de = theta[1 : 1 + N] - fi = theta[1 + N : 1 + 2 * N] - dd = theta[1 + 2 * N : 1 + 3 * N] - coeffs = tuple((de[i], fi[i], dd[i]) for i in range(N)) - eps = self.updated_copy(eps_inf=eps_inf, coeffs=coeffs, validate=False).eps_model(f) - return pack_complex_vec(eps) - - g = self._tjp_grad(theta0, _eps_vec, vec) - - mapping = [(("eps_inf",), 0)] - base = 1 - mapping += [(("coeffs", i, 0), base + i) for i in range(N)] - mapping += [(("coeffs", i, 1), base + N + i) for i in range(N)] - mapping += [(("coeffs", i, 2), base + 2 * N + i) for i in range(N)] - return self._map_grad_real(g, derivative_info.paths, mapping) - - @staticmethod - def _den(freq, f0, delta): - return (f0**2) - 2j * (freq * delta) - (freq**2) - - # frequency weights for custom Lorentz - @staticmethod - def _w_de(freq, f0, delta): - return (f0**2) / Lorentz._den(freq, f0, delta) - - @staticmethod - def _w_f0(freq, de, f0, delta): - den = Lorentz._den(freq, f0, delta) - return (2.0 * de * f0 * (den - f0**2)) / (den**2) - - @staticmethod - def _w_delta(freq, de, f0, delta): - den = Lorentz._den(freq, f0, delta) - return (2j * freq * de * (f0**2)) / (den**2) - - -class CustomLorentz(CustomDispersiveMedium, Lorentz): - """A spatially varying dispersive medium described by the Lorentz model. - - Notes - ----- - - The frequency-dependence of the complex-valued permittivity is described by: - - .. 
math::
-
- \\epsilon(f) = \\epsilon_\\infty + \\sum_i
- \\frac{\\Delta\\epsilon_i f_i^2}{f_i^2 - 2jf\\delta_i - f^2}
-
- Example
- -------
- >>> x = np.linspace(-1, 1, 5)
- >>> y = np.linspace(-1, 1, 6)
- >>> z = np.linspace(-1, 1, 7)
- >>> coords = dict(x=x, y=y, z=z)
- >>> eps_inf = SpatialDataArray(np.ones((5, 6, 7)), coords=coords)
- >>> d_epsilon = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords)
- >>> f = SpatialDataArray(1+np.random.random((5, 6, 7)), coords=coords)
- >>> delta = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords)
- >>> lorentz_medium = CustomLorentz(eps_inf=eps_inf, coeffs=[(d_epsilon,f,delta),])
- >>> eps = lorentz_medium.eps_model(200e12)
-
- See Also
- --------
-
- :class:`CustomPoleResidue`:
- A spatially varying dispersive medium described by the pole-residue pair model.
-
- **Notebooks**
- * `Fitting dispersive material models <../../notebooks/Fitting.html>`_
-
- **Lectures**
- * `Modeling dispersive material in FDTD `_
- """
-
- eps_inf: CustomSpatialDataTypeAnnotated = pd.Field(
- ...,
- title="Epsilon at Infinity",
- description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).",
- units=PERMITTIVITY,
- )
-
- coeffs: tuple[
- tuple[
- CustomSpatialDataTypeAnnotated,
- CustomSpatialDataTypeAnnotated,
- CustomSpatialDataTypeAnnotated,
- ],
- ...,
- ] = pd.Field(
- ...,
- title="Coefficients",
- description="List of (:math:`\\Delta\\epsilon_i, f_i, \\delta_i`) values for model.",
- units=(PERMITTIVITY, HERTZ, HERTZ),
- )
-
- _no_nans_eps_inf = validate_no_nans("eps_inf")
- _no_nans_coeffs = validate_no_nans("coeffs")
-
- _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs")
-
- @pd.validator("eps_inf", always=True)
- def _eps_inf_positive(cls, val):
- """eps_inf must be positive"""
- if not CustomDispersiveMedium._validate_isreal_dataarray(val):
- raise SetupError("'eps_inf' must be real.")
- if np.any(_get_numpy_array(val) < 0):
- raise SetupError("'eps_inf' must be positive.")
- return val
-
- @pd.validator("coeffs", always=True)
- def _coeffs_unequal_f_delta(cls, val):
- """f and delta cannot be exactly the same.
- Not needed for now because we have a stricter
- validator `_coeffs_delta_all_smaller_or_larger_than_fi`.
- """
- return val
-
- @pd.validator("coeffs", always=True)
- @skip_if_fields_missing(["eps_inf"])
- def _coeffs_correct_shape(cls, val, values):
- """coeffs must have consistent shape."""
- for de, f, delta in val:
- if (
- not _check_same_coordinates(de, values["eps_inf"])
- or not _check_same_coordinates(f, values["eps_inf"])
- or not _check_same_coordinates(delta, values["eps_inf"])
- ):
- raise SetupError(
- "All terms in 'coeffs' must have the same coordinates; "
- "The coordinates must also be consistent with 'eps_inf'."
- )
- if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((de, f, delta)):
- raise SetupError("All terms in 'coeffs' must be real.")
- return val
-
- @pd.validator("coeffs", always=True)
- def _coeffs_delta_all_smaller_or_larger_than_fi(cls, val):
- """We restrict either all f**2>delta**2 or all f**2<delta**2."""
- for _, f, delta in val:
- f_sq = _get_numpy_array(f) ** 2
- delta_sq = _get_numpy_array(delta) ** 2
- if not (np.all(f_sq > delta_sq) or np.all(f_sq < delta_sq)):
- raise SetupError(
- "Coefficients are restricted to have either all 'delta**2'<'f**2', "
- "or all 'delta**2'>'f**2'."
- ) - return val - - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if ``allow_gain`` is False.""" - allow_gain = values.get("allow_gain") - for del_ep, _, delta in val: - if np.any(_get_numpy_array(delta) < 0): - raise ValidationError("For stable medium, 'delta_i' must be non-negative.") - if not allow_gain and np.any(_get_numpy_array(del_ep) < 0): - raise ValidationError( - "For passive medium, 'Delta epsilon_i' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." - ) - return val - - @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - if not self.eps_inf.is_uniform: - return False - for coeffs in self.coeffs: - for coeff in coeffs: - if not coeff.is_uniform: - return False - return True - - def eps_dataarray_freq( - self, frequency: float - ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: - """Permittivity array at ``frequency``. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[ - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - ] - The permittivity evaluated at ``frequency``. - """ - eps = Lorentz.eps_model(self, frequency) - return (eps, eps, eps) - - def _sel_custom_data_inside(self, bounds: Bound): - """Return a new custom medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - CustomLorentz - CustomLorentz with reduced data. - """ - if not self.eps_inf.does_cover(bounds=bounds): - log.warning("Eps inf spatial data array does not fully cover the requested region.") - eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds) - coeffs_reduced = [] - for de, f, delta in self.coeffs: - if not de.does_cover(bounds=bounds): - log.warning( - "Lorentz 'de' spatial data array does not fully cover the requested region." - ) - - if not f.does_cover(bounds=bounds): - log.warning( - "Lorentz 'f' spatial data array does not fully cover the requested region." - ) - - if not delta.does_cover(bounds=bounds): - log.warning( - "Lorentz 'delta' spatial data array does not fully cover the requested region." 
- ) - - coeffs_reduced.append( - (de.sel_inside(bounds), f.sel_inside(bounds), delta.sel_inside(bounds)) - ) - - return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=coeffs_reduced) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Adjoint derivatives for CustomLorentz via analytic chain rule.""" - - # complex epsilon sensitivity over xyz aligned to eps_inf grid - dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=self.eps_inf) - - grads: AutogradFieldMap = {} - - # eps_inf path - if ("eps_inf",) in derivative_info.paths: - grads[("eps_inf",)] = np.real(dJ) - - # per-coefficient contributions - for i, (de_da, f0_da, dl_da) in enumerate(self.coeffs): - need_de = ("coeffs", i, 0) in derivative_info.paths - need_f0 = ("coeffs", i, 1) in derivative_info.paths - need_dl = ("coeffs", i, 2) in derivative_info.paths - if not (need_de or need_f0 or need_dl): - continue - - de = np.array(de_da.values, dtype=float) - f0 = np.array(f0_da.values, dtype=float) - dl = np.array(dl_da.values, dtype=float) - - g_de = 0.0 if not need_de else np.zeros_like(de, dtype=float) - g_f0 = 0.0 if not need_f0 else np.zeros_like(f0, dtype=float) - g_dl = 0.0 if not need_dl else np.zeros_like(dl, dtype=float) - - def _den(f, f0=f0, dl=dl): - return Lorentz._den(f, f0, dl) - - if need_de: - g_de = g_de + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, f0=f0, dl=dl: Lorentz._w_de(f, f0, dl), - ) - if need_f0: - # d/d f0 of (de f0^2 / den) = (2 de f0 (den - f0^2)) / den^2 - g_f0 = g_f0 + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, de=de, f0=f0, dl=dl: Lorentz._w_f0(f, de, f0, dl), - ) - if need_dl: - # d/d delta of (de f0^2 / den) = (2 j f de f0^2) / den^2 - g_dl = g_dl + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, de=de, f0=f0, dl=dl: Lorentz._w_delta(f, de, f0, dl), - ) - - if need_de: - grads[("coeffs", i, 0)] = g_de - if need_f0: - grads[("coeffs", i, 1)] = g_f0 - if need_dl: - grads[("coeffs", i, 2)] = g_dl - - return grads - - -class Drude(DispersiveMedium): - """A dispersive medium described by the Drude model. - - Notes - ----- - - The frequency-dependence of the complex-valued permittivity is described by: - - .. math:: - - \\epsilon(f) = \\epsilon_\\infty - \\sum_i - \\frac{ f_i^2}{f^2 + jf\\delta_i} - - Example - ------- - >>> drude_medium = Drude(eps_inf=2.0, coeffs=[(1,2), (3,4)]) - >>> eps = drude_medium.eps_model(200e12) - - See Also - -------- - - :class:`CustomDrude`: - A spatially varying dispersive medium described by the Drude model. - - **Notebooks** - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - * `Modeling dispersive material in FDTD `_ - """ - - eps_inf: pd.PositiveFloat = pd.Field( - 1.0, - title="Epsilon at Infinity", - description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", - units=PERMITTIVITY, - ) - - coeffs: tuple[tuple[float, pd.PositiveFloat], ...] 
= pd.Field( - ..., - title="Coefficients", - description="List of (:math:`f_i, \\delta_i`) values for model.", - units=(HERTZ, HERTZ), - ) - - _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() - _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - - eps = self.eps_inf + 0.0j - for f, delta in self.coeffs: - eps = eps - (f**2) / (frequency**2 + 1j * frequency * delta) - return eps - - # --- unified helpers for autograd + tests --- - - def _pole_residue_dict(self) -> dict: - """Dict representation of Medium as a pole-residue model.""" - - poles = [] - - for f, delta in self.coeffs: - w = 2 * np.pi * f - d = 2 * np.pi * delta - - c0 = (w**2) / 2 / d + 0j - c1 = -c0 - a1 = -d + 0j - - if isinstance(c0, complex): - a0 = 0j - else: - a0 = 0 * c0 - - poles.extend(((a0, c0), (a1, c1))) - - return { - "eps_inf": self.eps_inf, - "poles": poles, - "frequency_range": self.frequency_range, - "name": self.name, - } - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Adjoint derivatives for Drude params via TJP through eps_model().""" - - f, vec = self._tjp_inputs(derivative_info) - - N = len(self.coeffs) - if N == 0 and ("eps_inf",) not in derivative_info.paths: - return {} - - # pack into flat [eps_inf, fp..., delta...] - eps_inf0 = float(self.eps_inf) - fp0 = np.array([float(fp) for (fp, _d) in self.coeffs]) if N else np.array([]) - d0 = np.array([float(dd) for (_fp, dd) in self.coeffs]) if N else np.array([]) - theta0 = np.concatenate([np.array([eps_inf0]), fp0, d0]) - - def _eps_vec(theta): - eps_inf = theta[0] - fp = theta[1 : 1 + N] - dd = theta[1 + N : 1 + 2 * N] - coeffs = tuple((fp[i], dd[i]) for i in range(N)) - eps = self.updated_copy(eps_inf=eps_inf, coeffs=coeffs, validate=False).eps_model(f) - return pack_complex_vec(eps) - - g = self._tjp_grad(theta0, _eps_vec, vec) - - mapping = [(("eps_inf",), 0)] - base = 1 - mapping += [(("coeffs", i, 0), base + i) for i in range(N)] - mapping += [(("coeffs", i, 1), base + N + i) for i in range(N)] - return self._map_grad_real(g, derivative_info.paths, mapping) - - @staticmethod - def _den(freq, delta): - return (freq**2) + 1j * (freq * delta) - - # frequency weights for custom Drude - @staticmethod - def _w_fp(freq, fp, delta): - return -(2.0 * fp) / Drude._den(freq, delta) - - @staticmethod - def _w_delta(freq, fp, delta): - den = Drude._den(freq, delta) - return (1j * freq * (fp**2)) / (den**2) - - -class CustomDrude(CustomDispersiveMedium, Drude): - """A spatially varying dispersive medium described by the Drude model. - - - Notes - ----- - - The frequency-dependence of the complex-valued permittivity is described by: - - .. math:: - - \\epsilon(f) = \\epsilon_\\infty - \\sum_i - \\frac{ f_i^2}{f^2 + jf\\delta_i} - - Example - ------- - >>> x = np.linspace(-1, 1, 5) - >>> y = np.linspace(-1, 1, 6) - >>> z = np.linspace(-1, 1, 7) - >>> coords = dict(x=x, y=y, z=z) - >>> eps_inf = SpatialDataArray(np.ones((5, 6, 7)), coords=coords) - >>> f1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> delta1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> drude_medium = CustomDrude(eps_inf=eps_inf, coeffs=[(f1,delta1),]) - >>> eps = drude_medium.eps_model(200e12) - - See Also - -------- - - :class:`Drude`: - A dispersive medium described by the Drude model. 
- - **Notebooks** - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - * `Modeling dispersive material in FDTD `_ - """ - - eps_inf: CustomSpatialDataTypeAnnotated = pd.Field( - ..., - title="Epsilon at Infinity", - description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", - units=PERMITTIVITY, - ) - - coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = ( - pd.Field( - ..., - title="Coefficients", - description="List of (:math:`f_i, \\delta_i`) values for model.", - units=(HERTZ, HERTZ), - ) - ) - - _no_nans_eps_inf = validate_no_nans("eps_inf") - _no_nans_coeffs = validate_no_nans("coeffs") - - _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs") - - @pd.validator("eps_inf", always=True) - def _eps_inf_positive(cls, val): - """eps_inf must be positive""" - if not CustomDispersiveMedium._validate_isreal_dataarray(val): - raise SetupError("'eps_inf' must be real.") - if np.any(_get_numpy_array(val) < 0): - raise SetupError("'eps_inf' must be positive.") - return val - - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["eps_inf"]) - def _coeffs_correct_shape_and_sign(cls, val, values): - """coeffs must have consistent shape and sign.""" - for f, delta in val: - if not _check_same_coordinates(f, values["eps_inf"]) or not _check_same_coordinates( - delta, values["eps_inf"] - ): - raise SetupError( - "All terms in 'coeffs' must have the same coordinates; " - "The coordinates must also be consistent with 'eps_inf'." - ) - if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((f, delta)): - raise SetupError("All terms in 'coeffs' must be real.") - if np.any(_get_numpy_array(delta) <= 0): - raise SetupError("For stable medium, 'delta' must be positive.") - return val - - @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - if not self.eps_inf.is_uniform: - return False - for coeffs in self.coeffs: - for coeff in coeffs: - if not coeff.is_uniform: - return False - return True - - def eps_dataarray_freq( - self, frequency: float - ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: - """Permittivity array at ``frequency``. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[ - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - Union[ - :class:`.SpatialDataArray`, - :class:`.TriangularGridDataset`, - :class:`.TetrahedralGridDataset`, - ], - ] - The permittivity evaluated at ``frequency``. - """ - eps = Drude.eps_model(self, frequency) - return (eps, eps, eps) - - def _sel_custom_data_inside(self, bounds: Bound): - """Return a new custom medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - CustomDrude - CustomDrude with reduced data. 
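As an aside, the Drude closed form above is easy to verify at a single frequency; a sketch with illustrative values:

    import numpy as np
    import tidy3d as td

    f1, delta1 = 3e14, 1e13  # illustrative (f_i, delta_i) in Hz
    medium = td.Drude(eps_inf=2.0, coeffs=[(f1, delta1)])
    freq = 2e14
    eps = 2.0 - f1**2 / (freq**2 + 1j * freq * delta1)
    assert np.allclose(eps, medium.eps_model(freq))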
- """ - if not self.eps_inf.does_cover(bounds=bounds): - log.warning("Eps inf spatial data array does not fully cover the requested region.") - eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds) - coeffs_reduced = [] - for f, delta in self.coeffs: - if not f.does_cover(bounds=bounds): - log.warning( - "Drude 'f' spatial data array does not fully cover the requested region." - ) - - if not delta.does_cover(bounds=bounds): - log.warning( - "Drude 'delta' spatial data array does not fully cover the requested region." - ) - - coeffs_reduced.append((f.sel_inside(bounds), delta.sel_inside(bounds))) - - return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=coeffs_reduced) - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Adjoint derivatives for CustomDrude via analytic chain rule.""" - - dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=self.eps_inf) - - grads: AutogradFieldMap = {} - if ("eps_inf",) in derivative_info.paths: - grads[("eps_inf",)] = np.real(dJ) - - for i, (fp_da, dl_da) in enumerate(self.coeffs): - need_fp = ("coeffs", i, 0) in derivative_info.paths - need_dl = ("coeffs", i, 1) in derivative_info.paths - if not (need_fp or need_dl): - continue - - fp = np.array(fp_da.values, dtype=float) - dl = np.array(dl_da.values, dtype=float) - - g_fp = 0.0 if not need_fp else np.zeros_like(fp, dtype=float) - g_dl = 0.0 if not need_dl else np.zeros_like(dl, dtype=float) - - if need_fp: - g_fp = g_fp + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, fp=fp, dl=dl: Drude._w_fp(f, fp, dl), - ) - if need_dl: - g_dl = g_dl + self._sum_over_freqs( - derivative_info.frequencies, - dJ, - weight_fn=lambda f, fp=fp, dl=dl: Drude._w_delta(f, fp, dl), - ) - - if need_fp: - grads[("coeffs", i, 0)] = g_fp - if need_dl: - grads[("coeffs", i, 1)] = g_dl - - return grads - - -class Debye(DispersiveMedium): - """A dispersive medium described by the Debye model. - - Notes - ----- - - The frequency-dependence of the complex-valued permittivity is described by: - - .. math:: - - \\epsilon(f) = \\epsilon_\\infty + \\sum_i - \\frac{\\Delta\\epsilon_i}{1 - jf\\tau_i} - - Example - ------- - >>> debye_medium = Debye(eps_inf=2.0, coeffs=[(1,2),(3,4)]) - >>> eps = debye_medium.eps_model(200e12) - - See Also - -------- - - :class:`CustomDebye` - A spatially varying dispersive medium described by the Debye model. - - **Notebooks** - * `Fitting dispersive material models <../../notebooks/Fitting.html>`_ - - **Lectures** - * `Modeling dispersive material in FDTD `_ - """ - - eps_inf: pd.PositiveFloat = pd.Field( - 1.0, - title="Epsilon at Infinity", - description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).", - units=PERMITTIVITY, - ) - - coeffs: tuple[tuple[float, pd.PositiveFloat], ...] = pd.Field( - ..., - title="Coefficients", - description="List of (:math:`\\Delta\\epsilon_i, \\tau_i`) values for model.", - units=(PERMITTIVITY, SECOND), - ) - - @pd.validator("coeffs", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if `allow_gain` is False.""" - if values.get("allow_gain"): - return val - for del_ep, _ in val: - if del_ep < 0: - raise ValidationError( - "For passive medium, 'Delta epsilon_i' must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, " - "and are likely to diverge." 
- ) - return val - - _validate_permittivity_modulation = DispersiveMedium._permittivity_modulation_validation() - _validate_conductivity_modulation = DispersiveMedium._conductivity_modulation_validation() - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - - eps = self.eps_inf + 0.0j - for de, tau in self.coeffs: - eps = eps + de / (1 - 1j * frequency * tau) - return eps - - # --- unified helpers for autograd + tests --- - - def _pole_residue_dict(self): - """Dict representation of Medium as a pole-residue model.""" - - poles = [] - eps_inf = self.eps_inf - for de, tau in self.coeffs: - # for |tau| close to 0, it's equivalent to modifying eps_inf - if np.any(abs(_get_numpy_array(tau)) < 1 / 2 / np.pi / LARGEST_FP_NUMBER): - eps_inf = eps_inf + de - else: - a = -2 * np.pi / tau + 0j - c = -0.5 * de * a - - poles.append((a, c)) - - return { - "eps_inf": eps_inf, - "poles": poles, - "frequency_range": self.frequency_range, - "name": self.name, - } - - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: - """Adjoint derivatives for Debye params via TJP through eps_model().""" - - f, vec = self._tjp_inputs(derivative_info) - - N = len(self.coeffs) - if N == 0 and ("eps_inf",) not in derivative_info.paths: - return {} - - # pack into flat [eps_inf, de..., tau...] - eps_inf0 = float(self.eps_inf) - de0 = np.array([float(de) for (de, _t) in self.coeffs]) if N else np.array([]) - tau0 = np.array([float(t) for (_de, t) in self.coeffs]) if N else np.array([]) - theta0 = np.concatenate([np.array([eps_inf0]), de0, tau0]) - - def _eps_vec(theta): - eps_inf = theta[0] - de = theta[1 : 1 + N] - tau = theta[1 + N : 1 + 2 * N] - coeffs = tuple((de[i], tau[i]) for i in range(N)) - eps = self.updated_copy(eps_inf=eps_inf, coeffs=coeffs, validate=False).eps_model(f) - return pack_complex_vec(eps) - - g = self._tjp_grad(theta0, _eps_vec, vec) - - mapping = [(("eps_inf",), 0)] - base = 1 - mapping += [(("coeffs", i, 0), base + i) for i in range(N)] - mapping += [(("coeffs", i, 1), base + N + i) for i in range(N)] - return self._map_grad_real(g, derivative_info.paths, mapping) - - @staticmethod - def _den(freq, tau): - return 1 - 1j * (freq * tau) - - # frequency weights for custom Debye - @staticmethod - def _w_de(freq, tau): - return 1.0 / Debye._den(freq, tau) - - @staticmethod - def _w_tau(freq, de, tau): - den = Debye._den(freq, tau) - return (1j * freq * de) / (den**2) - - -class CustomDebye(CustomDispersiveMedium, Debye): - """A spatially varying dispersive medium described by the Debye model. - - Notes - ----- - - The frequency-dependence of the complex-valued permittivity is described by: - - .. math:: - - \\epsilon(f) = \\epsilon_\\infty + \\sum_i - \\frac{\\Delta\\epsilon_i}{1 - jf\\tau_i} - - Example - ------- - >>> x = np.linspace(-1, 1, 5) - >>> y = np.linspace(-1, 1, 6) - >>> z = np.linspace(-1, 1, 7) - >>> coords = dict(x=x, y=y, z=z) - >>> eps_inf = SpatialDataArray(1+np.random.random((5, 6, 7)), coords=coords) - >>> eps1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> tau1 = SpatialDataArray(np.random.random((5, 6, 7)), coords=coords) - >>> debye_medium = CustomDebye(eps_inf=eps_inf, coeffs=[(eps1,tau1),]) - >>> eps = debye_medium.eps_model(200e12) - - See Also - -------- - - :class:`Debye` - A dispersive medium described by the Debye model. 
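The analogous single-term Debye check (illustrative values, not library code):

    import numpy as np
    import tidy3d as td

    de1, tau1 = 1.5, 1e-14  # illustrative (delta_eps_i, tau_i)
    medium = td.Debye(eps_inf=2.0, coeffs=[(de1, tau1)])
    freq = 2e14
    eps = 2.0 + de1 / (1 - 1j * freq * tau1)
    assert np.allclose(eps, medium.eps_model(freq))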
-
-    **Notebooks**
-        * `Fitting dispersive material models <../../notebooks/Fitting.html>`_
-
-    **Lectures**
-        * `Modeling dispersive material in FDTD `_
-    """
-
-    eps_inf: CustomSpatialDataTypeAnnotated = pd.Field(
-        ...,
-        title="Epsilon at Infinity",
-        description="Relative permittivity at infinite frequency (:math:`\\epsilon_\\infty`).",
-        units=PERMITTIVITY,
-    )
-
-    coeffs: tuple[tuple[CustomSpatialDataTypeAnnotated, CustomSpatialDataTypeAnnotated], ...] = (
-        pd.Field(
-            ...,
-            title="Coefficients",
-            description="List of (:math:`\\Delta\\epsilon_i, \\tau_i`) values for model.",
-            units=(PERMITTIVITY, SECOND),
-        )
-    )
-
-    _no_nans_eps_inf = validate_no_nans("eps_inf")
-    _no_nans_coeffs = validate_no_nans("coeffs")
-
-    _warn_if_none = CustomDispersiveMedium._warn_if_data_none("coeffs")
-
-    @pd.validator("eps_inf", always=True)
-    def _eps_inf_positive(cls, val):
-        """eps_inf must be positive."""
-        if not CustomDispersiveMedium._validate_isreal_dataarray(val):
-            raise SetupError("'eps_inf' must be real.")
-        if np.any(_get_numpy_array(val) < 0):
-            raise SetupError("'eps_inf' must be positive.")
-        return val
-
-    @pd.validator("coeffs", always=True)
-    @skip_if_fields_missing(["eps_inf"])
-    def _coeffs_correct_shape(cls, val, values):
-        """coeffs must have consistent shape."""
-        for de, tau in val:
-            if not _check_same_coordinates(de, values["eps_inf"]) or not _check_same_coordinates(
-                tau, values["eps_inf"]
-            ):
-                raise SetupError(
-                    "All terms in 'coeffs' must have the same coordinates; "
-                    "the coordinates must also be consistent with 'eps_inf'."
-                )
-            if not CustomDispersiveMedium._validate_isreal_dataarray_tuple((de, tau)):
-                raise SetupError("All terms in 'coeffs' must be real.")
-        return val
-
-    @pd.validator("coeffs", always=True)
-    def _coeffs_tau_all_sufficient_positive(cls, val):
-        """Require that every 'tau_i' is sufficiently greater than 0."""
-        for _, tau in val:
-            if np.any(_get_numpy_array(tau) < 1 / 2 / np.pi / LARGEST_FP_NUMBER):
-                raise SetupError(
-                    "Coefficients 'tau_i' are restricted to be sufficiently greater than 0."
-                )
-        return val
-
-    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
-        """Adjoint derivatives for CustomDebye via analytic chain rule."""
-
-        dJ = self._sum_complex_eps_sensitivity(derivative_info, spatial_ref=self.eps_inf)
-
-        grads: AutogradFieldMap = {}
-        if ("eps_inf",) in derivative_info.paths:
-            grads[("eps_inf",)] = np.real(dJ)
-
-        for i, (de_da, tau_da) in enumerate(self.coeffs):
-            need_de = ("coeffs", i, 0) in derivative_info.paths
-            need_tau = ("coeffs", i, 1) in derivative_info.paths
-            if not (need_de or need_tau):
-                continue
-
-            de = np.array(de_da.values, dtype=float)
-            tau = np.array(tau_da.values, dtype=float)
-
-            g_de = 0.0 if not need_de else np.zeros_like(de, dtype=float)
-            g_tau = 0.0 if not need_tau else np.zeros_like(tau, dtype=float)
-
-            if need_de:
-                g_de = g_de + self._sum_over_freqs(
-                    derivative_info.frequencies,
-                    dJ,
-                    weight_fn=lambda f, tau=tau: Debye._w_de(f, tau),
-                )
-            if need_tau:
-                g_tau = g_tau + self._sum_over_freqs(
-                    derivative_info.frequencies,
-                    dJ,
-                    weight_fn=lambda f, de=de, tau=tau: Debye._w_tau(f, de, tau),
-                )
-
-            if need_de:
-                grads[("coeffs", i, 0)] = g_de
-            if need_tau:
-                grads[("coeffs", i, 1)] = g_tau
-
-        return grads
-
-    @pd.validator("coeffs", always=True)
-    @skip_if_fields_missing(["allow_gain"])
-    def _passivity_validation(cls, val, values):
-        """Assert passive medium if ``allow_gain`` is False."""
-        allow_gain = values.get("allow_gain")
-        for del_ep, tau in val:
-            if np.any(_get_numpy_array(tau) <= 0):
-                raise SetupError("For stable medium, 'tau_i' must be positive.")
-            if not allow_gain and np.any(_get_numpy_array(del_ep) < 0):
-                raise ValidationError(
-                    "For passive medium, 'Delta epsilon_i' must be non-negative. "
-                    "To simulate a gain medium, please set 'allow_gain=True'. "
-                    "Caution: simulations with a gain medium are unstable, "
-                    "and are likely to diverge."
-                )
-        return val
-
-    @cached_property
-    def is_spatially_uniform(self) -> bool:
-        """Whether the medium is spatially uniform."""
-        if not self.eps_inf.is_uniform:
-            return False
-        for coeffs in self.coeffs:
-            for coeff in coeffs:
-                if not coeff.is_uniform:
-                    return False
-        return True
-
-    def eps_dataarray_freq(
-        self, frequency: float
-    ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]:
-        """Permittivity array at ``frequency``.
-
-        Parameters
-        ----------
-        frequency : float
-            Frequency to evaluate permittivity at (Hz).
-
-        Returns
-        -------
-        Tuple[
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-        ]
-            The permittivity evaluated at ``frequency``.
-        """
-        eps = Debye.eps_model(self, frequency)
-        return (eps, eps, eps)
-
-    def _sel_custom_data_inside(self, bounds: Bound):
-        """Return a new custom medium that contains the minimal amount of data necessary to cover
-        a spatial region defined by ``bounds``.
-
-        Parameters
-        ----------
-        bounds : Tuple[float, float, float], Tuple[float, float, float]
-            Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``.
-
-        Returns
-        -------
-        CustomDebye
-            CustomDebye with reduced data.
- """ - if not self.eps_inf.does_cover(bounds=bounds): - log.warning("Eps inf spatial data array does not fully cover the requested region.") - eps_inf_reduced = self.eps_inf.sel_inside(bounds=bounds) - coeffs_reduced = [] - for de, tau in self.coeffs: - if not de.does_cover(bounds=bounds): - log.warning( - "Debye 'f' spatial data array does not fully cover the requested region." - ) - - if not tau.does_cover(bounds=bounds): - log.warning( - "Debye 'tau' spatial data array does not fully cover the requested region." - ) - - coeffs_reduced.append((de.sel_inside(bounds), tau.sel_inside(bounds))) - - return self.updated_copy(eps_inf=eps_inf_reduced, coeffs=coeffs_reduced) - - -class SurfaceImpedanceFitterParam(Tidy3dBaseModel): - """Advanced parameters for fitting surface impedance of a :class:`.LossyMetalMedium`. - Internally, the quantity to be fitted is surface impedance divided by ``-1j * \\omega``. - """ - - max_num_poles: pd.PositiveInt = pd.Field( - LOSSY_METAL_DEFAULT_MAX_POLES, - title="Maximal Number Of Poles", - description="Maximal number of poles in complex-conjugate pole residue model for " - "fitting surface impedance.", - ) - - tolerance_rms: pd.NonNegativeFloat = pd.Field( - LOSSY_METAL_DEFAULT_TOLERANCE_RMS, - title="Tolerance In Fitting", - description="Tolerance in fitting.", - ) - - frequency_sampling_points: pd.PositiveInt = pd.Field( - LOSSY_METAL_DEFAULT_SAMPLING_FREQUENCY, - title="Number Of Sampling Frequencies", - description="Number of sampling frequencies used in fitting.", - ) - - log_sampling: bool = pd.Field( - True, - title="Frequencies Sampling In Log Scale", - description="Whether to sample frequencies logarithmically (``True``), " - "or linearly (``False``).", - ) - - -class AbstractSurfaceRoughness(Tidy3dBaseModel): - """Abstract class for modeling surface roughness of lossy metal.""" - - @abstractmethod - def roughness_correction_factor( - self, frequency: ArrayFloat1D, skin_depths: ArrayFloat1D - ) -> ArrayComplex1D: - """Complex-valued roughness correction factor applied to surface impedance. - - Notes - ----- - The roughness correction factor should be causal. It is multiplied to the - surface impedance of the lossy metal to account for the effects of surface roughness. - - Parameters - ---------- - frequency : ArrayFloat1D - Frequency to evaluate roughness correction factor at (Hz). - skin_depths : ArrayFloat1D - Skin depths of the lossy metal that is frequency-dependent. - - Returns - ------- - ArrayComplex1D - The causal roughness correction factor evaluated at ``frequency``. - """ - - -class HammerstadSurfaceRoughness(AbstractSurfaceRoughness): - """Modified Hammerstad surface roughness model. It's a popular model that works well - under 5 GHz for surface roughness below 2 micrometer RMS. - - Note - ---- - - The power loss compared to smooth surface is described by: - - .. math:: - - 1 + (RF-1) \\frac{2}{\\pi}\\arctan(1.4\\frac{R_q^2}{\\delta^2}) - - where :math:`\\delta` is skin depth, :math:`R_q` the RMS peak-to-vally height, and RF - roughness factor. - - Note - ---- - This model is based on: - - Y. Shlepnev, C. Nwachukwu, "Roughness characterization for interconnect analysis", - 2011 IEEE International Symposium on Electromagnetic Compatibility, - (DOI: 10.1109/ISEMC.2011.6038367), 2011. - - V. Dmitriev-Zdorov, B. Simonovich, I. Kochikov, "A Causal Conductor Roughness Model - and its Effect on Transmission Line Characteristics", Signal Integrity Journal, 2018. 
- """ - - rq: pd.PositiveFloat = pd.Field( - ..., - title="RMS Peak-to-Valley Height", - description="RMS peak-to-valley height (Rq) of the surface roughness.", - units=MICROMETER, - ) - - roughness_factor: float = pd.Field( - 2.0, - title="Roughness Factor", - description="Expected maximal increase in conductor losses due to roughness effect. " - "Value 2 gives the classic Hammerstad equation.", - gt=1.0, - ) - - def roughness_correction_factor( - self, frequency: ArrayFloat1D, skin_depths: ArrayFloat1D - ) -> ArrayComplex1D: - """Complex-valued roughness correction factor applied to surface impedance. - - Notes - ----- - The roughness correction factor should be causal. It is multiplied to the - surface impedance of the lossy metal to account for the effects of surface roughness. - - Parameters - ---------- - frequency : ArrayFloat1D - Frequency to evaluate roughness correction factor at (Hz). - skin_depths : ArrayFloat1D - Skin depths of the lossy metal that is frequency-dependent. - - Returns - ------- - ArrayComplex1D - The causal roughness correction factor evaluated at ``frequency``. - """ - normalized_laplace = -1.4j * (self.rq / skin_depths) ** 2 - sqrt_normalized_laplace = np.sqrt(normalized_laplace) - causal_response = np.log( - 1 + 2 * sqrt_normalized_laplace / (1 + normalized_laplace) - ) + 2 * np.arctan(sqrt_normalized_laplace) - return 1 + (self.roughness_factor - 1) / np.pi * causal_response - - -class HuraySurfaceRoughness(AbstractSurfaceRoughness): - """Huray surface roughness model. - - Note - ---- - - The power loss compared to smooth surface is described by: - - .. math:: - - \\frac{A_{matte}}{A_{flat}} + \\frac{3}{2}\\sum_i f_i/[1+\\frac{\\delta}{r_i}+\\frac{\\delta^2}{2r_i^2}] - - where :math:`\\delta` is skin depth, :math:`r_i` the radius of sphere, - :math:`\\frac{A_{matte}}{A_{flat}}` the relative area of the matte compared to flat surface, - and :math:`f_i=N_i4\\pi r_i^2/A_{flat}` the ratio of total sphere - surface area (number of spheres :math:`N_i` times the individual sphere surface area) - to the flat surface area. - - Note - ---- - This model is based on: - - J. Eric Bracken, "A Causal Huray Model for Surface Roughness", DesignCon, 2012. - """ - - relative_area: pd.PositiveFloat = pd.Field( - 1, - title="Relative Area", - description="Relative area of the matte base compared to a flat surface", - ) - - coeffs: tuple[tuple[pd.PositiveFloat, pd.PositiveFloat], ...] = pd.Field( - ..., - title="Coefficients for surface ratio and sphere radius", - description="List of (:math:`f_i, r_i`) values for model, where :math:`f_i` is " - "the ratio of total sphere surface area to the flat surface area, and :math:`r_i` " - "the radius of the sphere.", - units=(None, MICROMETER), - ) - - @classmethod - def from_cannonball_huray(cls, radius: float) -> HuraySurfaceRoughness: - """Construct a Cannonball-Huray model. - - Note - ---- - - The power loss compared to smooth surface is described by: - - .. math:: - - 1 + \\frac{7\\pi}{3} \\frac{1}{1+\\frac{\\delta}{r}+\\frac{\\delta^2}{2r^2}} - - Parameters - ---------- - radius : float - Radius of the sphere. - - Returns - ------- - HuraySurfaceRoughness - The Huray surface roughness model. - """ - return cls(relative_area=1, coeffs=[(14.0 / 9 * np.pi, radius)]) - - def roughness_correction_factor( - self, frequency: ArrayFloat1D, skin_depths: ArrayFloat1D - ) -> ArrayComplex1D: - """Complex-valued roughness correction factor applied to surface impedance. - - Notes - ----- - The roughness correction factor should be causal. 
It is multiplied to the - surface impedance of the lossy metal to account for the effects of surface roughness. - - Parameters - ---------- - frequency : ArrayFloat1D - Frequency to evaluate roughness correction factor at (Hz). - skin_depths : ArrayFloat1D - Skin depths of the lossy metal that is frequency-dependent. - - Returns - ------- - ArrayComplex1D - The causal roughness correction factor evaluated at ``frequency``. - """ - - correction = self.relative_area - for f, r in self.coeffs: - normalized_laplace = -2j * (r / skin_depths) ** 2 - sqrt_normalized_laplace = np.sqrt(normalized_laplace) - correction += 1.5 * f / (1 + 1 / sqrt_normalized_laplace) - return correction - - -SurfaceRoughnessType = Union[HammerstadSurfaceRoughness, HuraySurfaceRoughness] - - -class LossyMetalMedium(Medium): - """Lossy metal that can be modeled with a surface impedance boundary condition (SIBC). - - Notes - ----- - - SIBC is most accurate when the skin depth is much smaller than the structure feature size. - If not the case, please use a regular medium instead, or set ``simulation.subpixel.lossy_metal`` - to ``td.VolumetricAveraging()`` or ``td.Staircasing()``. - - Example - ------- - >>> lossy_metal = LossyMetalMedium(conductivity=10, frequency_range=(9e9, 10e9)) - - """ - - allow_gain: Literal[False] = pd.Field( - False, - title="Allow gain medium", - description="Allow the medium to be active. Caution: " - "simulations with a gain medium are unstable, and are likely to diverge." - "Simulations where ``allow_gain`` is set to ``True`` will still be charged even if " - "diverged. Monitor data up to the divergence point will still be returned and can be " - "useful in some cases.", - ) - - permittivity: Literal[1] = pd.Field( - 1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY - ) - - roughness: SurfaceRoughnessType = pd.Field( - None, - title="Surface Roughness Model", - description="Surface roughness model that applies a frequency-dependent scaling " - "factor to surface impedance.", - discriminator=TYPE_TAG_STR, - ) - - thickness: pd.PositiveFloat = pd.Field( - None, - title="Conductor Thickness", - description="When the thickness of the conductor is not much greater than skin depth, " - "1D transmission line model is applied to compute the surface impedance of the thin conductor.", - units=MICROMETER, - ) - - frequency_range: FreqBound = pd.Field( - ..., - title="Frequency Range", - description="Frequency range of validity for the medium.", - units=(HERTZ, HERTZ), - ) - - fit_param: SurfaceImpedanceFitterParam = pd.Field( - SurfaceImpedanceFitterParam(), - title="Fitting Parameters For Surface Impedance", - description="Parameters for fitting surface impedance divided by (-1j * omega) over " - "the frequency range using pole-residue pair model.", - ) - - @pd.validator("frequency_range") - def _validate_frequency_range(cls, val): - """Validate that frequency range is finite and non-zero.""" - for freq in val: - if not np.isfinite(freq): - raise ValidationError("Values in 'frequency_range' must be finite.") - if freq <= 0: - raise ValidationError("Values in 'frequency_range' must be positive.") - return val - - @pd.validator("conductivity", always=True) - def _positive_conductivity(cls, val): - """Assert conductivity>0.""" - if val <= 0: - raise ValidationError("For lossy metal, 'conductivity' must be positive. 
") - return val - - @cached_property - def _fitting_result(self) -> tuple[PoleResidue, float]: - """Fitted scaled surface impedance and residue.""" - - omega_data = self.Hz_to_angular_freq(self.sampling_frequencies) - surface_impedance = self.surface_impedance(self.sampling_frequencies) - scaled_impedance = surface_impedance / (-1j * omega_data) - - # let's use scaled quantity in fitting: minimal real part equals ``SCALED_REAL_PART`` - min_real = np.min(scaled_impedance.real) - if min_real <= 0: - raise SetupError( - "The real part of scaled surface impedance must be positive. " - "Please create a github issue so that the problem can be investigated. " - "In the meantime, make sure the material is passive." - ) - - scaling_factor = LOSSY_METAL_SCALED_REAL_PART / min_real - scaled_impedance *= scaling_factor - - (res_inf, poles, residues), error = fit( - omega_data=omega_data, - resp_data=scaled_impedance, - min_num_poles=0, - max_num_poles=self.fit_param.max_num_poles, - resp_inf=None, - tolerance_rms=self.fit_param.tolerance_rms, - scale_factor=1.0 / np.max(omega_data), - ) - - res_inf /= scaling_factor - residues /= scaling_factor - return PoleResidue(eps_inf=res_inf, poles=list(zip(poles, residues))), error - - @cached_property - def scaled_surface_impedance_model(self) -> PoleResidue: - """Fitted surface impedance divided by (-j \\omega) using pole-residue pair model within ``frequency_range``.""" - return self._fitting_result[0] - - @cached_property - def num_poles(self) -> int: - """Number of poles in the fitted model.""" - return len(self.scaled_surface_impedance_model.poles) - - def surface_impedance(self, frequencies: ArrayFloat1D): - """Computing surface impedance including surface roughness effects.""" - # compute complex-valued skin depth - n, k = self.nk_model(frequencies) - - # with surface roughness effects - correction = 1.0 - if self.roughness is not None: - skin_depths = 1 / np.sqrt(np.pi * frequencies * MU_0 * self.conductivity) - correction = self.roughness.roughness_correction_factor(frequencies, skin_depths) - - if self.thickness is not None: - k_wave = self.Hz_to_angular_freq(frequencies) / C_0 * (n + 1j * k) - correction /= -np.tanh(1j * k_wave * self.thickness) - - return correction * ETA_0 / (n + 1j * k) - - @cached_property - def sampling_frequencies(self) -> ArrayFloat1D: - """Sampling frequencies used in fitting.""" - if self.fit_param.frequency_sampling_points < 2: - return np.array([np.mean(self.frequency_range)]) - - if self.fit_param.log_sampling: - return np.logspace( - np.log10(self.frequency_range[0]), - np.log10(self.frequency_range[1]), - self.fit_param.frequency_sampling_points, - ) - return np.linspace( - self.frequency_range[0], - self.frequency_range[1], - self.fit_param.frequency_sampling_points, - ) - - def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]: - """Main diagonal of the complex-valued permittivity tensor for numerical considerations - such as meshing and runtime estimation. - - Parameters - ---------- - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - Tuple[complex, complex, complex] - The diagonal elements of relative permittivity tensor relevant for numerical - considerations evaluated at ``frequency``. - """ - return (1.0 + 0j,) * 3 - - @add_ax_if_none - def plot( - self, - ax: Ax = None, - ) -> Ax: - """Make plot of complex-valued surface imepdance model vs fitted model, at sampling frequencies. 
- Parameters - ---------- - ax : matplotlib.axes._subplots.Axes = None - Axes to plot the data on, if None, a new one is created. - Returns - ------- - matplotlib.axis.Axes - Matplotlib axis corresponding to plot. - """ - frequencies = self.sampling_frequencies - surface_impedance = self.surface_impedance(frequencies) - - ax.plot(frequencies, surface_impedance.real, "x", label="Real") - ax.plot(frequencies, surface_impedance.imag, "+", label="Imag") - - surface_impedance_model = ( - -1j - * self.Hz_to_angular_freq(frequencies) - * self.scaled_surface_impedance_model.eps_model(frequencies) - ) - ax.plot(frequencies, surface_impedance_model.real, label="Real (model)") - ax.plot(frequencies, surface_impedance_model.imag, label="Imag (model)") - - ax.set_ylabel(r"Surface impedance ($\Omega$)") - ax.set_xlabel("Frequency (Hz)") - ax.legend() - - return ax - - -IsotropicUniformMediumFor2DType = Union[ - Medium, LossyMetalMedium, PoleResidue, Sellmeier, Lorentz, Debye, Drude, PECMedium -] -IsotropicUniformMediumType = Union[IsotropicUniformMediumFor2DType, PMCMedium] -IsotropicCustomMediumType = Union[ - CustomPoleResidue, - CustomSellmeier, - CustomLorentz, - CustomDebye, - CustomDrude, -] -IsotropicCustomMediumInternalType = Union[IsotropicCustomMediumType, CustomIsotropicMedium] -IsotropicMediumType = Union[IsotropicCustomMediumType, IsotropicUniformMediumType] - - -class AnisotropicMedium(AbstractMedium): - """Diagonally anisotropic medium. - - Notes - ----- - - Only diagonal anisotropy is currently supported. - - Example - ------- - >>> medium_xx = Medium(permittivity=4.0) - >>> medium_yy = Medium(permittivity=4.1) - >>> medium_zz = Medium(permittivity=3.9) - >>> anisotropic_dielectric = AnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz) - - See Also - -------- - - :class:`CustomAnisotropicMedium` - Diagonally anisotropic medium with spatially varying permittivity in each component. - - :class:`FullyAnisotropicMedium` - Fully anisotropic medium including all 9 components of the permittivity and conductivity tensors. - - **Notebooks** - * `Broadband polarizer assisted by anisotropic metamaterial <../../notebooks/SWGBroadbandPolarizer.html>`_ - * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_ - """ - - xx: IsotropicUniformMediumType = pd.Field( - ..., - title="XX Component", - description="Medium describing the xx-component of the diagonal permittivity tensor.", - discriminator=TYPE_TAG_STR, - ) - - yy: IsotropicUniformMediumType = pd.Field( - ..., - title="YY Component", - description="Medium describing the yy-component of the diagonal permittivity tensor.", - discriminator=TYPE_TAG_STR, - ) - - zz: IsotropicUniformMediumType = pd.Field( - ..., - title="ZZ Component", - description="Medium describing the zz-component of the diagonal permittivity tensor.", - discriminator=TYPE_TAG_STR, - ) - - allow_gain: bool = pd.Field( - None, - title="Allow gain medium", - description="This field is ignored. Please set ``allow_gain`` in each component", - ) - - @pd.validator("modulation_spec", always=True) - def _validate_modulation_spec(cls, val): - """Check compatibility with modulation_spec.""" - if val is not None: - raise ValidationError( - f"A 'modulation_spec' of class {type(val)} is not " - f"currently supported for medium class {cls.__name__}. " - "Please add modulation to each component." 
- ) - return val - - @pd.root_validator(pre=True) - def _ignored_fields(cls, values): - """The field is ignored.""" - if values.get("xx") is not None and values.get("allow_gain") is not None: - log.warning( - "The field 'allow_gain' is ignored. Please set 'allow_gain' in each component." - ) - return values - - @cached_property - def components(self) -> dict[str, Medium]: - """Dictionary of diagonal medium components.""" - return {"xx": self.xx, "yy": self.yy, "zz": self.zz} - - @cached_property - def is_time_modulated(self) -> bool: - """Whether any component of the medium is time modulated.""" - return any(mat.is_time_modulated for mat in self.components.values()) - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - - For this medium, it takes the minimal of ``n_clf`` in all components. - """ - return min(mat_component.n_cfl for mat_component in self.components.values()) - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - - return np.mean(self.eps_diagonal(frequency), axis=0) - - @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: - """Main diagonal of the complex-valued permittivity tensor as a function of frequency.""" - - eps_xx = self.xx.eps_model(frequency) - eps_yy = self.yy.eps_model(frequency) - eps_zz = self.zz.eps_model(frequency) - return (eps_xx, eps_yy, eps_zz) - - def eps_comp(self, row: Axis, col: Axis, frequency: float) -> complex: - """Single component the complex-valued permittivity tensor as a function of frequency. - - Parameters - ---------- - row : int - Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - col : int - Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - frequency : float - Frequency to evaluate permittivity at (Hz). - - Returns - ------- - complex - Element of the relative permittivity tensor evaluated at ``frequency``. - """ - - if row != col: - return 0j - cmp = "xyz"[row] - field_name = cmp + cmp - return self.components[field_name].eps_model(frequency) - - def _eps_plot( - self, frequency: float, eps_component: Optional[PermittivityComponent] = None - ) -> float: - """Returns real part of epsilon for plotting. A specific component of the epsilon tensor can - be selected for anisotropic medium. - - Parameters - ---------- - frequency : float - eps_component : PermittivityComponent - - Returns - ------- - float - Element ``eps_component`` of the relative permittivity tensor evaluated at ``frequency``. - """ - if eps_component is None: - # return the average of the diag - return self.eps_model(frequency).real - if eps_component in ["xx", "yy", "zz"]: - # return the requested diagonal component - comp2indx = {"x": 0, "y": 1, "z": 2} - return self.eps_comp( - row=comp2indx[eps_component[0]], - col=comp2indx[eps_component[1]], - frequency=frequency, - ).real - raise ValueError( - f"Plotting component '{eps_component}' of a diagonally-anisotropic permittivity tensor is not supported." 
- ) - - @add_ax_if_none - def plot(self, freqs: float, ax: Ax = None) -> Ax: - """Plot n, k of a :class:`.Medium` as a function of frequency.""" - - freqs = np.array(freqs) - freqs_thz = freqs / 1e12 - - for label, medium_component in self.elements.items(): - eps_complex = medium_component.eps_model(freqs) - n, k = AbstractMedium.eps_complex_to_nk(eps_complex) - ax.plot(freqs_thz, n, label=f"n, eps_{label}") - ax.plot(freqs_thz, k, label=f"k, eps_{label}") - - ax.set_xlabel("frequency (THz)") - ax.set_title("medium dispersion") - ax.legend() - ax.set_aspect("auto") - return ax - - @property - def elements(self) -> dict[str, IsotropicUniformMediumType]: - """The diagonal elements of the medium as a dictionary.""" - return {"xx": self.xx, "yy": self.yy, "zz": self.zz} - - @cached_property - def is_pec(self): - """Whether the medium is a PEC.""" - return any(self.is_comp_pec(i) for i in range(3)) - - @cached_property - def is_pmc(self): - """Whether the medium is a PMC.""" - return any(self.is_comp_pmc(i) for i in range(3)) - - def is_comp_pec(self, comp: Axis): - """Whether the medium is a PEC.""" - return isinstance(self.components[["xx", "yy", "zz"][comp]], PECMedium) - - def is_comp_pmc(self, comp: Axis): - """Whether the medium is a PMC.""" - return isinstance(self.components[["xx", "yy", "zz"][comp]], PMCMedium) - - def sel_inside(self, bounds: Bound): - """Return a new medium that contains the minimal amount data necessary to cover - a spatial region defined by ``bounds``. - - - Parameters - ---------- - bounds : Tuple[float, float, float], Tuple[float, float float] - Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. - - Returns - ------- - AnisotropicMedium - AnisotropicMedium with reduced data. - """ - - new_comps = [comp.sel_inside(bounds) for comp in [self.xx, self.yy, self.zz]] - - return self.updated_copy(**dict(zip(["xx", "yy", "zz"], new_comps))) - - -class AnisotropicMediumFromMedium2D(AnisotropicMedium): - """The same as ``AnisotropicMedium``, but converted from Medium2D. - (This class is for internal use only) - """ - - -class FullyAnisotropicMedium(AbstractMedium): - """Fully anisotropic medium including all 9 components of the permittivity and conductivity - tensors. - - Notes - ----- - - Provided permittivity tensor and the symmetric part of the conductivity tensor must - have coinciding main directions. A non-symmetric conductivity tensor can be used to model - magneto-optic effects. Note that dispersive properties and subpixel averaging are currently not - supported for fully anisotropic materials. - - Note - ---- - - Simulations involving fully anisotropic materials are computationally more intensive, thus, - they take longer time to complete. This increase strongly depends on the filling fraction of - the simulation domain by fully anisotropic materials, varying approximately in the range from - 1.5 to 5. The cost of running a simulation is adjusted correspondingly. - - Example - ------- - >>> perm = [[2, 0, 0], [0, 1, 0], [0, 0, 3]] - >>> cond = [[0.1, 0, 0], [0, 0, 0], [0, 0, 0]] - >>> anisotropic_dielectric = FullyAnisotropicMedium(permittivity=perm, conductivity=cond) - - See Also - -------- - - :class:`CustomAnisotropicMedium` - Diagonally anisotropic medium with spatially varying permittivity in each component. - - :class:`AnisotropicMedium` - Diagonally anisotropic medium. 
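A short sketch reusing the permittivity and conductivity tensors from the docstring example above (assuming tidy3d is importable as ``td``):

import tidy3d as td

perm = [[2, 0, 0], [0, 1, 0], [0, 0, 3]]
cond = [[0.1, 0, 0], [0, 0, 0], [0, 0, 0]]
medium = td.FullyAnisotropicMedium(permittivity=perm, conductivity=cond)

# Principal permittivity/conductivity values and the directions they apply along.
perm_diag, cond_diag, directions = medium.eps_sigma_diag
eps_avg = medium.eps_model(200e12)  # scalar (averaged) permittivity at 200 THz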
- - **Notebooks** - * `Broadband polarizer assisted by anisotropic metamaterial <../../notebooks/SWGBroadbandPolarizer.html>`_ - * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_ - * `Defining fully anisotropic materials <../../notebooks/FullyAnisotropic.html>`_ - """ - - permittivity: TensorReal = pd.Field( - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - title="Permittivity", - description="Relative permittivity tensor.", - units=PERMITTIVITY, - ) - - conductivity: TensorReal = pd.Field( - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - title="Conductivity", - description="Electric conductivity tensor. Defined such that the imaginary part " - "of the complex permittivity at angular frequency omega is given by conductivity/omega.", - units=CONDUCTIVITY, - ) - - @pd.validator("modulation_spec", always=True) - def _validate_modulation_spec(cls, val): - """Check compatibility with modulation_spec.""" - if val is not None: - raise ValidationError( - f"A 'modulation_spec' of class {type(val)} is not " - f"currently supported for medium class {cls.__name__}." - ) - return val - - @pd.validator("permittivity", always=True) - def permittivity_spd_and_ge_one(cls, val): - """Check that provided permittivity tensor is symmetric positive definite - with eigenvalues >= 1. - """ - - if not np.allclose(val, np.transpose(val), atol=fp_eps): - raise ValidationError("Provided permittivity tensor is not symmetric.") - - if np.any(np.linalg.eigvals(val) < 1 - fp_eps): - raise ValidationError("Main diagonal of provided permittivity tensor is not >= 1.") - - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def conductivity_commutes(cls, val, values): - """Check that the symmetric part of conductivity tensor commutes with permittivity tensor - (that is, simultaneously diagonalizable). - """ - - perm = values.get("permittivity") - cond_sym = 0.5 * (val + val.T) - comm_diff = np.abs(np.matmul(perm, cond_sym) - np.matmul(cond_sym, perm)) - - if not np.allclose(comm_diff, 0, atol=fp_eps): - raise ValidationError( - "Main directions of conductivity and permittivity tensor do not coincide." - ) - - return val - - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["allow_gain"]) - def _passivity_validation(cls, val, values): - """Assert passive medium if ``allow_gain`` is False.""" - if values.get("allow_gain"): - return val - - cond_sym = 0.5 * (val + val.T) - if np.any(np.linalg.eigvals(cond_sym) < -fp_eps): - raise ValidationError( - "For passive medium, main diagonal of provided conductivity tensor " - "must be non-negative. " - "To simulate a gain medium, please set 'allow_gain=True'. " - "Caution: simulations with a gain medium are unstable, and are likely to diverge." - ) - return val + coeffs: tuple[tuple[PositiveFloat, PositiveFloat], ...] = Field( + title="Coefficients for surface ratio and sphere radius", + description="List of (:math:`f_i, r_i`) values for model, where :math:`f_i` is " + "the ratio of total sphere surface area to the flat surface area, and :math:`r_i` " + "the radius of the sphere.", + units=(None, MICROMETER), + ) @classmethod - def from_diagonal(cls, xx: Medium, yy: Medium, zz: Medium, rotation: RotationType): - """Construct a fully anisotropic medium by rotating a diagonally anisotropic medium. - - Parameters - ---------- - xx : :class:`.Medium` - Medium describing the xx-component of the diagonal permittivity tensor. 
- yy : :class:`.Medium` - Medium describing the yy-component of the diagonal permittivity tensor. - zz : :class:`.Medium` - Medium describing the zz-component of the diagonal permittivity tensor. - rotation : Union[:class:`.RotationAroundAxis`] - Rotation applied to diagonal permittivity tensor. - - Returns - ------- - :class:`FullyAnisotropicMedium` - Resulting fully anisotropic medium. - """ - - if any(comp.nonlinear_spec is not None for comp in [xx, yy, zz]): - raise ValidationError( - "Nonlinearities are not currently supported for the components " - "of a fully anisotropic medium." - ) - - if any(comp.modulation_spec is not None for comp in [xx, yy, zz]): - raise ValidationError( - "Modulation is not currently supported for the components " - "of a fully anisotropic medium." - ) - - permittivity_diag = np.diag([comp.permittivity for comp in [xx, yy, zz]]).tolist() - conductivity_diag = np.diag([comp.conductivity for comp in [xx, yy, zz]]).tolist() - - permittivity = rotation.rotate_tensor(permittivity_diag) - conductivity = rotation.rotate_tensor(conductivity_diag) - - return cls(permittivity=permittivity, conductivity=conductivity) - - @cached_property - def _to_diagonal(self) -> AnisotropicMedium: - """Construct a diagonally anisotropic medium from main components. - - Returns - ------- - :class:`AnisotropicMedium` - Resulting diagonally anisotropic medium. - """ - - perm, cond, _ = self.eps_sigma_diag - - return AnisotropicMedium( - xx=Medium(permittivity=perm[0], conductivity=cond[0]), - yy=Medium(permittivity=perm[1], conductivity=cond[1]), - zz=Medium(permittivity=perm[2], conductivity=cond[2]), - ) - - @cached_property - def eps_sigma_diag( - self, - ) -> tuple[tuple[float, float, float], tuple[float, float, float], TensorReal]: - """Main components of permittivity and conductivity tensors and their directions.""" - - perm_diag, vecs = np.linalg.eig(self.permittivity) - cond_diag = np.diag(np.matmul(np.transpose(vecs), np.matmul(self.conductivity, vecs))) - - return (perm_diag, cond_diag, vecs) - - @ensure_freq_in_range - def eps_model(self, frequency: float) -> complex: - """Complex-valued permittivity as a function of frequency.""" - perm_diag, cond_diag, _ = self.eps_sigma_diag - - if not np.isscalar(frequency): - perm_diag = perm_diag[:, None] - cond_diag = cond_diag[:, None] - eps_diag = AbstractMedium.eps_sigma_to_eps_complex(perm_diag, cond_diag, frequency) - return np.mean(eps_diag) + def from_cannonball_huray(cls, radius: float) -> Self: + """Construct a Cannonball-Huray model. - @ensure_freq_in_range - def eps_diagonal(self, frequency: float) -> tuple[complex, complex, complex]: - """Main diagonal of the complex-valued permittivity tensor as a function of frequency.""" + Note + ---- - perm_diag, cond_diag, _ = self.eps_sigma_diag + The power loss compared to smooth surface is described by: - if not np.isscalar(frequency): - perm_diag = perm_diag[:, None] - cond_diag = cond_diag[:, None] - return AbstractMedium.eps_sigma_to_eps_complex(perm_diag, cond_diag, frequency) + .. math:: - def eps_comp(self, row: Axis, col: Axis, frequency: float) -> complex: - """Single component the complex-valued permittivity tensor as a function of frequency. + 1 + \\frac{7\\pi}{3} \\frac{1}{1+\\frac{\\delta}{r}+\\frac{\\delta^2}{2r^2}} Parameters ---------- - row : int - Component's row in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). - col : int - Component's column in the permittivity tensor (0, 1, or 2 for x, y, or z respectively). 
- frequency : float - Frequency to evaluate permittivity at (Hz). + radius : float + Radius of the sphere. Returns ------- - complex - Element of the relative permittivity tensor evaluated at ``frequency``. + HuraySurfaceRoughness + The Huray surface roughness model. """ + return cls(relative_area=1, coeffs=[(14.0 / 9 * np.pi, radius)]) - eps = self.permittivity[row][col] - sig = self.conductivity[row][col] - return AbstractMedium.eps_sigma_to_eps_complex(eps, sig, frequency) + def roughness_correction_factor( + self, frequency: ArrayFloat1D, skin_depths: ArrayFloat1D + ) -> ArrayComplex1D: + """Complex-valued roughness correction factor applied to surface impedance. - def _eps_plot( - self, frequency: float, eps_component: Optional[PermittivityComponent] = None - ) -> float: - """Returns real part of epsilon for plotting. A specific component of the epsilon tensor can - be selected for anisotropic medium. + Notes + ----- + The roughness correction factor should be causal. It is multiplied to the + surface impedance of the lossy metal to account for the effects of surface roughness. Parameters ---------- - frequency : float - eps_component : PermittivityComponent + frequency : ArrayFloat1D + Frequency to evaluate roughness correction factor at (Hz). + skin_depths : ArrayFloat1D + Skin depths of the lossy metal that is frequency-dependent. Returns ------- - float - Element ``eps_component`` of the relative permittivity tensor evaluated at ``frequency``. - """ - if eps_component is None: - # return the average of the diag - return self.eps_model(frequency).real - - # return the requested component - comp2indx = {"x": 0, "y": 1, "z": 2} - return self.eps_comp( - row=comp2indx[eps_component[0]], col=comp2indx[eps_component[1]], frequency=frequency - ).real - - @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. - - For this medium, it take the minimal of ``sqrt(permittivity)`` for main directions. + ArrayComplex1D + The causal roughness correction factor evaluated at ``frequency``. """ - perm_diag, _, _ = self.eps_sigma_diag - return min(np.sqrt(perm_diag)) - - @add_ax_if_none - def plot(self, freqs: float, ax: Ax = None) -> Ax: - """Plot n, k of a :class:`FullyAnisotropicMedium` as a function of frequency.""" + correction = self.relative_area + for f, r in self.coeffs: + normalized_laplace = -2j * (r / skin_depths) ** 2 + sqrt_normalized_laplace = np.sqrt(normalized_laplace) + correction += 1.5 * f / (1 + 1 / sqrt_normalized_laplace) + return correction - diagonal_medium = self._to_diagonal - ax = diagonal_medium.plot(freqs=freqs, ax=ax) - _, _, directions = self.eps_sigma_diag - # rename components from xx, yy, zz to 1, 2, 3 to avoid misleading - # and add their directions - for label, n_line, k_line, direction in zip( - ("1", "2", "3"), ax.lines[-6::2], ax.lines[-5::2], directions.T - ): - direction_str = f"({direction[0]:.2f}, {direction[1]:.2f}, {direction[2]:.2f})" - k_line.set_label(f"k, eps_{label} {direction_str}") - n_line.set_label(f"n, eps_{label} {direction_str}") +SurfaceRoughnessType = Union[HammerstadSurfaceRoughness, HuraySurfaceRoughness] - ax.legend() - return ax +class LossyMetalMedium(Medium): + """Lossy metal that can be modeled with a surface impedance boundary condition (SIBC). 
-class CustomAnisotropicMedium(AbstractCustomMedium, AnisotropicMedium):
-    """Diagonally anisotropic medium with spatially varying permittivity in each component.
+    Notes
+    -----

-    Note
-    ----

-    Only diagonal anisotropy is currently supported.
+    SIBC is most accurate when the skin depth is much smaller than the structure feature size.
+    If that is not the case, please use a regular medium instead, or set ``simulation.subpixel.lossy_metal``
+    to ``td.VolumetricAveraging()`` or ``td.Staircasing()``.

     Example
     -------
-    >>> Nx, Ny, Nz = 10, 9, 8
-    >>> x = np.linspace(-1, 1, Nx)
-    >>> y = np.linspace(-1, 1, Ny)
-    >>> z = np.linspace(-1, 1, Nz)
-    >>> coords = dict(x=x, y=y, z=z)
-    >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords)
-    >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords)
-    >>> medium_xx = CustomMedium(permittivity=permittivity, conductivity=conductivity)
-    >>> medium_yy = CustomMedium(permittivity=permittivity, conductivity=conductivity)
-    >>> d_epsilon = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords)
-    >>> f = SpatialDataArray(1+np.random.random((Nx, Ny, Nz)), coords=coords)
-    >>> delta = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords)
-    >>> medium_zz = CustomLorentz(eps_inf=permittivity, coeffs=[(d_epsilon,f,delta),])
-    >>> anisotropic_dielectric = CustomAnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz)
-
-    See Also
-    --------
-
-    :class:`AnisotropicMedium`
-        Diagonally anisotropic medium.
-
-    **Notebooks**
-        * `Broadband polarizer assisted by anisotropic metamaterial <../../notebooks/SWGBroadbandPolarizer.html>`_
-        * `Thin film lithium niobate adiabatic waveguide coupler <../../notebooks/AdiabaticCouplerLN.html>`_
-        * `Defining fully anisotropic materials <../../notebooks/FullyAnisotropic.html>`_
+    >>> lossy_metal = LossyMetalMedium(conductivity=10, frequency_range=(9e9, 10e9))
+
     """

-    xx: Union[IsotropicCustomMediumType, CustomMedium] = pd.Field(
-        ...,
-        title="XX Component",
-        description="Medium describing the xx-component of the diagonal permittivity tensor.",
-        discriminator=TYPE_TAG_STR,
+    allow_gain: Literal[False] = Field(
+        False,
+        title="Allow gain medium",
+        description="Allow the medium to be active. Caution: "
+        "simulations with a gain medium are unstable, and are likely to diverge. "
+        "Simulations where ``allow_gain`` is set to ``True`` will still be charged even if "
+        "diverged. Monitor data up to the divergence point will still be returned and can be "
+        "useful in some cases.",
     )

-    yy: Union[IsotropicCustomMediumType, CustomMedium] = pd.Field(
-        ...,
-        title="YY Component",
-        description="Medium describing the yy-component of the diagonal permittivity tensor.",
-        discriminator=TYPE_TAG_STR,
+    permittivity: Literal[1.0] = Field(
+        1.0, title="Permittivity", description="Relative permittivity.", units=PERMITTIVITY
     )

-    zz: Union[IsotropicCustomMediumType, CustomMedium] = pd.Field(
-        ...,
-        title="ZZ Component",
-        description="Medium describing the zz-component of the diagonal permittivity tensor.",
-        discriminator=TYPE_TAG_STR,
+    conductivity: PositiveFloat = Field(
+        title="Conductivity",
+        description="Electric conductivity.
Defined such that the imaginary part of the complex " + "permittivity at angular frequency omega is given by conductivity/omega.", + units=CONDUCTIVITY, ) - interp_method: Optional[InterpMethod] = pd.Field( + roughness: Optional[SurfaceRoughnessType] = Field( None, - title="Interpolation method", - description="When the value is ``None`` each component will follow its own " - "interpolation method. When the value is other than ``None`` the interpolation " - "method specified by this field will override the one in each component.", + title="Surface Roughness Model", + description="Surface roughness model that applies a frequency-dependent scaling " + "factor to surface impedance.", + discriminator=TYPE_TAG_STR, ) - allow_gain: bool = pd.Field( + thickness: Optional[PositiveFloat] = Field( None, - title="Allow gain medium", - description="This field is ignored. Please set ``allow_gain`` in each component", + title="Conductor Thickness", + description="When the thickness of the conductor is not much greater than skin depth, " + "1D transmission line model is applied to compute the surface impedance of the thin conductor.", + units=MICROMETER, ) - subpixel: bool = pd.Field( - None, - title="Subpixel averaging", - description="This field is ignored. Please set ``subpixel`` in each component", + frequency_range: FreqBound = Field( + title="Frequency Range", + description="Frequency range of validity for the medium.", + units=(HERTZ, HERTZ), ) - @pd.validator("xx", always=True) - def _isotropic_xx(cls, val): - """If it's `CustomMedium`, make sure it's isotropic.""" - if isinstance(val, CustomMedium) and not val.is_isotropic: - raise SetupError("The xx-component medium type is not isotropic.") - return val + fit_param: SurfaceImpedanceFitterParam = Field( + default_factory=SurfaceImpedanceFitterParam, + title="Fitting Parameters For Surface Impedance", + description="Parameters for fitting surface impedance divided by (-1j * omega) over " + "the frequency range using pole-residue pair model.", + ) - @pd.validator("yy", always=True) - def _isotropic_yy(cls, val): - """If it's `CustomMedium`, make sure it's isotropic.""" - if isinstance(val, CustomMedium) and not val.is_isotropic: - raise SetupError("The yy-component medium type is not isotropic.") + @field_validator("frequency_range") + @classmethod + def _validate_frequency_range(cls, val: FreqBound) -> FreqBound: + """Validate that frequency range is finite and non-zero.""" + for freq in val: + if not np.isfinite(freq): + raise ValidationError("Values in 'frequency_range' must be finite.") + if freq <= 0: + raise ValidationError("Values in 'frequency_range' must be positive.") return val - @pd.validator("zz", always=True) - def _isotropic_zz(cls, val): - """If it's `CustomMedium`, make sure it's isotropic.""" - if isinstance(val, CustomMedium) and not val.is_isotropic: - raise SetupError("The zz-component medium type is not isotropic.") - return val + @cached_property + def _fitting_result(self) -> tuple[PoleResidue, float]: + """Fitted scaled surface impedance and residue.""" - @pd.root_validator(pre=True) - def _ignored_fields(cls, values): - """The field is ignored.""" - if values.get("xx") is not None: - if values.get("allow_gain") is not None: - log.warning( - "The field 'allow_gain' is ignored. Please set 'allow_gain' in each component." - ) - if values.get("subpixel") is not None: - log.warning( - "The field 'subpixel' is ignored. Please set 'subpixel' in each component." 
- ) - return values + omega_data = self.Hz_to_angular_freq(self.sampling_frequencies) + surface_impedance = self.surface_impedance(self.sampling_frequencies) + scaled_impedance = surface_impedance / (-1j * omega_data) + + # let's use scaled quantity in fitting: minimal real part equals ``SCALED_REAL_PART`` + min_real = np.min(scaled_impedance.real) + if min_real <= 0: + raise SetupError( + "The real part of scaled surface impedance must be positive. " + "Please create a github issue so that the problem can be investigated. " + "In the meantime, make sure the material is passive." + ) + + scaling_factor = LOSSY_METAL_SCALED_REAL_PART / min_real + scaled_impedance *= scaling_factor + + (res_inf, poles, residues), error = fit( + omega_data=omega_data, + resp_data=scaled_impedance, + min_num_poles=0, + max_num_poles=self.fit_param.max_num_poles, + resp_inf=None, + tolerance_rms=self.fit_param.tolerance_rms, + scale_factor=1.0 / np.max(omega_data), + ) + + res_inf /= scaling_factor + residues /= scaling_factor + return PoleResidue(eps_inf=res_inf, poles=list(zip(poles, residues))), error @cached_property - def is_spatially_uniform(self) -> bool: - """Whether the medium is spatially uniform.""" - return any(comp.is_spatially_uniform for comp in self.components.values()) + def scaled_surface_impedance_model(self) -> PoleResidue: + """Fitted surface impedance divided by (-j \\omega) using pole-residue pair model within ``frequency_range``.""" + return self._fitting_result[0] @cached_property - def n_cfl(self): - """This property computes the index of refraction related to CFL condition, so that - the FDTD with this medium is stable when the time step size that doesn't take - material factor into account is multiplied by ``n_cfl``. + def num_poles(self) -> int: + """Number of poles in the fitted model.""" + return len(self.scaled_surface_impedance_model.poles) - For this medium, it takes the minimal of ``n_clf`` in all components. - """ - return min(mat_component.n_cfl for mat_component in self.components.values()) + def surface_impedance(self, frequencies: ArrayFloat1D) -> ArrayComplex: + """Computing surface impedance including surface roughness effects.""" + # compute complex-valued skin depth + n, k = self.nk_model(frequencies) + + # with surface roughness effects + correction = 1.0 + if self.roughness is not None: + skin_depths = 1 / np.sqrt(np.pi * frequencies * MU_0 * self.conductivity) + correction = self.roughness.roughness_correction_factor(frequencies, skin_depths) + + if self.thickness is not None: + k_wave = self.Hz_to_angular_freq(frequencies) / C_0 * (n + 1j * k) + correction /= -np.tanh(1j * k_wave * self.thickness) + + return correction * ETA_0 / (n + 1j * k) @cached_property - def is_isotropic(self): - """Whether the medium is isotropic.""" - return False - - def _interp_method(self, comp: Axis) -> InterpMethod: - """Interpolation method applied to comp.""" - # override `interp_method` in components if self.interp_method is not None - if self.interp_method is not None: - return self.interp_method - # use component's interp_method - comp_map = ["xx", "yy", "zz"] - return self.components[comp_map[comp]].interp_method - - def eps_dataarray_freq( - self, frequency: float - ) -> tuple[CustomSpatialDataType, CustomSpatialDataType, CustomSpatialDataType]: - """Permittivity array at ``frequency``. 
+    def sampling_frequencies(self) -> ArrayFloat1D:
+        """Sampling frequencies used in fitting."""
+        if self.fit_param.frequency_sampling_points < 2:
+            return np.array([np.mean(self.frequency_range)])
+
+        if self.fit_param.log_sampling:
+            return np.logspace(
+                np.log10(self.frequency_range[0]),
+                np.log10(self.frequency_range[1]),
+                self.fit_param.frequency_sampling_points,
+            )
+        return np.linspace(
+            self.frequency_range[0],
+            self.frequency_range[1],
+            self.fit_param.frequency_sampling_points,
+        )
+
+    def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, complex]:
+        """Main diagonal of the complex-valued permittivity tensor for numerical considerations
+        such as meshing and runtime estimation.

         Parameters
         ----------
@@ -6422,115 +546,62 @@ def eps_dataarray_freq(

         Returns
         -------
-        Tuple[
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-            Union[
-                :class:`.SpatialDataArray`,
-                :class:`.TriangularGridDataset`,
-                :class:`.TetrahedralGridDataset`,
-            ],
-        ]
-            The permittivity evaluated at ``frequency``.
+        tuple[complex, complex, complex]
+            The diagonal elements of the relative permittivity tensor relevant for numerical
+            considerations evaluated at ``frequency``.
         """
-        return tuple(
-            mat_component.eps_dataarray_freq(frequency)[ind]
-            for ind, mat_component in enumerate(self.components.values())
-        )
+        return (1.0 + 0j,) * 3

-    def _eps_bounds(
+    @add_ax_if_none
+    def plot(
         self,
-        frequency: Optional[float] = None,
-        eps_component: Optional[PermittivityComponent] = None,
-    ) -> tuple[float, float]:
-        """Returns permittivity bounds for setting the color bounds when plotting.
-
+        ax: Ax = None,
+    ) -> Ax:
+        """Make plot of complex-valued surface impedance model vs fitted model, at sampling frequencies.

         Parameters
         ----------
-        frequency : float = None
-            Frequency to evaluate the relative permittivity of all mediums.
-            If not specified, evaluates at infinite frequency.
-        eps_component : Optional[PermittivityComponent] = None
-            Component of the permittivity tensor to plot for anisotropic materials,
-            e.g. ``"xx"``, ``"yy"``, ``"zz"``, ``"xy"``, ``"yz"``, ...
-            Defaults to ``None``, which returns the average of the diagonal values.
-
+        ax : matplotlib.axes._subplots.Axes = None
+            Axes to plot the data on, if None, a new one is created.

         Returns
         -------
-        Tuple[float, float]
-            The min and max values of the permittivity for the selected component and evaluated at ``frequency``.
+        matplotlib.axis.Axes
+            Matplotlib axis corresponding to plot.
         """
-        comps = ["xx", "yy", "zz"]
-        if eps_component in comps:
-            # Return the bounds of a specific component
-            eps_dataarray = self.eps_dataarray_freq(frequency)
-            eps = self._get_real_vals(eps_dataarray[comps.index(eps_component)])
-            return (np.min(eps), np.max(eps))
-        if eps_component is None:
-            # Returns the bounds across all components
-            return super()._eps_bounds(frequency=frequency)
-        raise ValueError(
-            f"Plotting component '{eps_component}' of a diagonally-anisotropic permittivity tensor is not supported.
+ frequencies = self.sampling_frequencies + surface_impedance = self.surface_impedance(frequencies) + + ax.plot(frequencies, surface_impedance.real, "x", label="Real") + ax.plot(frequencies, surface_impedance.imag, "+", label="Imag") + + surface_impedance_model = ( + -1j + * self.Hz_to_angular_freq(frequencies) + * self.scaled_surface_impedance_model.eps_model(frequencies) ) + ax.plot(frequencies, surface_impedance_model.real, label="Real (model)") + ax.plot(frequencies, surface_impedance_model.imag, label="Imag (model)") - def _sel_custom_data_inside(self, bounds: Bound): - return self + ax.set_ylabel(r"Surface impedance ($\Omega$)") + ax.set_xlabel("Frequency (Hz)") + ax.legend() + return ax -class CustomAnisotropicMediumInternal(CustomAnisotropicMedium): - """Diagonally anisotropic medium with spatially varying permittivity in each component. - Notes - ----- +extend_isotropic_uniform_medium_type(LossyMetalMedium) - Only diagonal anisotropy is currently supported. +IsotropicUniformMediumFor2DType = Union[ + Medium, LossyMetalMedium, PoleResidue, Sellmeier, Lorentz, Debye, Drude, PECMedium +] - Example - ------- - >>> Nx, Ny, Nz = 10, 9, 8 - >>> X = np.linspace(-1, 1, Nx) - >>> Y = np.linspace(-1, 1, Ny) - >>> Z = np.linspace(-1, 1, Nz) - >>> coords = dict(x=X, y=Y, z=Z) - >>> permittivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) - >>> conductivity= SpatialDataArray(np.ones((Nx, Ny, Nz)), coords=coords) - >>> medium_xx = CustomMedium(permittivity=permittivity, conductivity=conductivity) - >>> medium_yy = CustomMedium(permittivity=permittivity, conductivity=conductivity) - >>> d_epsilon = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords) - >>> f = SpatialDataArray(1+np.random.random((Nx, Ny, Nz)), coords=coords) - >>> delta = SpatialDataArray(np.random.random((Nx, Ny, Nz)), coords=coords) - >>> medium_zz = CustomLorentz(eps_inf=permittivity, coeffs=[(d_epsilon,f,delta),]) - >>> anisotropic_dielectric = CustomAnisotropicMedium(xx=medium_xx, yy=medium_yy, zz=medium_zz) - """ - xx: Union[IsotropicCustomMediumInternalType, CustomMedium] = pd.Field( - ..., - title="XX Component", - description="Medium describing the xx-component of the diagonal permittivity tensor.", - discriminator=TYPE_TAG_STR, - ) +IsotropicMediumType = Union[IsotropicCustomMediumType, IsotropicUniformMediumType] - yy: Union[IsotropicCustomMediumInternalType, CustomMedium] = pd.Field( - ..., - title="YY Component", - description="Medium describing the yy-component of the diagonal permittivity tensor.", - discriminator=TYPE_TAG_STR, - ) - zz: Union[IsotropicCustomMediumInternalType, CustomMedium] = pd.Field( - ..., - title="ZZ Component", - description="Medium describing the zz-component of the diagonal permittivity tensor.", - discriminator=TYPE_TAG_STR, - ) +class AnisotropicMediumFromMedium2D(AnisotropicMedium): + """The same as ``AnisotropicMedium``, but converted from Medium2D. + (This class is for internal use only) + """ """ Medium perturbation classes """ @@ -6539,7 +610,7 @@ class CustomAnisotropicMediumInternal(CustomAnisotropicMedium): class AbstractPerturbationMedium(ABC, Tidy3dBaseModel): """Abstract class for medium perturbation.""" - subpixel: bool = pd.Field( + subpixel: bool = Field( True, title="Subpixel averaging", description="This value will be transferred to the resulting custom medium. 
That is, " @@ -6549,7 +620,7 @@ class AbstractPerturbationMedium(ABC, Tidy3dBaseModel): "have an effect.", ) - perturbation_spec: Optional[Union[PermittivityPerturbation, IndexPerturbation]] = pd.Field( + perturbation_spec: Optional[Union[PermittivityPerturbation, IndexPerturbation]] = Field( None, title="Perturbation Spec", description="Specification of medium perturbation as one of predefined types.", @@ -6606,7 +677,7 @@ def from_unperturbed( subpixel: bool = True, perturbation_spec: Union[PermittivityPerturbation, IndexPerturbation] = None, **kwargs: Any, - ) -> AbstractPerturbationMedium: + ) -> Self: """Construct a medium with pertubation models from an unpertubed one. Parameters @@ -6630,18 +701,14 @@ def from_unperturbed( Resulting medium with perturbation model. """ - new_dict = medium.dict( - exclude={ - "type", - } - ) + new_dict = medium.model_dump(exclude={"type"}) new_dict["perturbation_spec"] = perturbation_spec new_dict["subpixel"] = subpixel new_dict.update(kwargs) - return cls.parse_obj(new_dict) + return cls.model_validate(new_dict) class PerturbationMedium(Medium, AbstractPerturbationMedium): @@ -6662,14 +729,14 @@ class PerturbationMedium(Medium, AbstractPerturbationMedium): ... ) """ - permittivity_perturbation: Optional[ParameterPerturbation] = pd.Field( + permittivity_perturbation: Optional[ParameterPerturbation] = Field( None, title="Permittivity Perturbation", description="List of heat and/or charge perturbations to permittivity.", units=PERMITTIVITY, ) - conductivity_perturbation: Optional[ParameterPerturbation] = pd.Field( + conductivity_perturbation: Optional[ParameterPerturbation] = Field( None, title="Permittivity Perturbation", description="List of heat and/or charge perturbations to permittivity.", @@ -6688,15 +755,15 @@ class PerturbationMedium(Medium, AbstractPerturbationMedium): allowed_complex=False, ) - @pd.root_validator(pre=True) - def _check_overdefining(cls, values): + @model_validator(mode="after") + def _check_overdefining(self) -> Self: """Check that perturbation model is provided either directly or through ``perturbation_spec``, but not both. """ - perm_p = values.get("permittivity_perturbation") is not None - cond_p = values.get("conductivity_perturbation") is not None - p_spec = values.get("perturbation_spec") is not None + perm_p = self.permittivity_perturbation is not None + cond_p = self.conductivity_perturbation is not None + p_spec = self.perturbation_spec is not None if p_spec and (perm_p or cond_p): raise SetupError( @@ -6705,7 +772,7 @@ def _check_overdefining(cls, values): "but not in both ways simultaneously." ) - return values + return self def perturbed_copy( self, @@ -6753,7 +820,7 @@ def perturbed_copy( if all(x is None for x in [temperature, electron_density, hole_density]): return self - new_dict = self.dict( + new_dict = self.model_dump( exclude={ "permittivity_perturbation", "conductivity_perturbation", @@ -6803,7 +870,7 @@ def perturbed_copy( new_dict["interp_method"] = interp_method new_dict["derived_from"] = self - return CustomMedium.parse_obj(new_dict) + return CustomMedium.model_validate(new_dict) class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): @@ -6837,7 +904,7 @@ class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): ... 
) """ - eps_inf_perturbation: Optional[ParameterPerturbation] = pd.Field( + eps_inf_perturbation: Optional[ParameterPerturbation] = Field( None, title="Perturbation of Epsilon at Infinity", description="Perturbations to relative permittivity at infinite frequency " @@ -6847,7 +914,7 @@ class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): poles_perturbation: Optional[ tuple[tuple[Optional[ParameterPerturbation], Optional[ParameterPerturbation]], ...] - ] = pd.Field( + ] = Field( None, title="Perturbations of Poles", description="Perturbations to poles of the model.", @@ -6865,15 +932,15 @@ class PerturbationPoleResidue(PoleResidue, AbstractPerturbationMedium): "poles", ) - @pd.root_validator(pre=True) - def _check_overdefining(cls, values): + @model_validator(mode="after") + def _check_overdefining(self) -> Self: """Check that perturbation model is provided either directly or through ``perturbation_spec``, but not both. """ - eps_i_p = values.get("eps_inf_perturbation") is not None - poles_p = values.get("poles_perturbation") is not None - p_spec = values.get("perturbation_spec") is not None + eps_i_p = self.eps_inf_perturbation is not None + poles_p = self.poles_perturbation is not None + p_spec = self.perturbation_spec is not None if p_spec and (eps_i_p or poles_p): raise SetupError( @@ -6882,7 +949,7 @@ def _check_overdefining(cls, values): "but not in both ways simultaneously." ) - return values + return self def perturbed_copy( self, @@ -6930,10 +997,14 @@ def perturbed_copy( if all(x is None for x in [temperature, electron_density, hole_density]): return self - new_dict = self.dict( + new_dict = self.model_dump( exclude={"eps_inf_perturbation", "poles_perturbation", "perturbation_spec", "type"} ) + if all(x is None for x in [temperature, electron_density, hole_density]): + new_dict.pop("subpixel") + return PoleResidue.model_validate(new_dict) + zeros = ParameterPerturbation._zeros_like(temperature, electron_density, hole_density) eps_inf_field = self.eps_inf + zeros @@ -6983,16 +1054,17 @@ def perturbed_copy( new_dict["interp_method"] = interp_method new_dict["derived_from"] = self - return CustomPoleResidue.parse_obj(new_dict) + return CustomPoleResidue.model_validate(new_dict) # types of mediums that can be used in Simulation and Structures +extend_perturbation_medium_type(PerturbationMedium, PerturbationPoleResidue) -PerturbationMediumType = Union[PerturbationMedium, PerturbationPoleResidue] +T = TypeVar("T") # Update forward references for all Custom medium classes that inherit from AbstractCustomMedium -def _get_all_subclasses(cls): +def _get_all_subclasses(cls: T) -> list[type[T]]: """Recursively get all subclasses of a class.""" all_subclasses = [] for subclass in cls.__subclasses__(): @@ -7002,7 +1074,7 @@ def _get_all_subclasses(cls): for _custom_medium_cls in _get_all_subclasses(AbstractCustomMedium): - _custom_medium_cls.update_forward_refs() + _custom_medium_cls.model_rebuild() MediumType3D = Union[ Medium, @@ -7043,8 +1115,7 @@ class Medium2D(AbstractMedium): """ - ss: IsotropicUniformMediumFor2DType = pd.Field( - ..., + ss: IsotropicUniformMediumFor2DType = Field( title="SS Component", description="Medium describing the ss-component of the diagonal permittivity tensor. 
" "The ss-component refers to the in-plane dimension of the medium that is the first " @@ -7054,8 +1125,7 @@ class Medium2D(AbstractMedium): discriminator=TYPE_TAG_STR, ) - tt: IsotropicUniformMediumFor2DType = pd.Field( - ..., + tt: IsotropicUniformMediumFor2DType = Field( title="TT Component", description="Medium describing the tt-component of the diagonal permittivity tensor. " "The tt-component refers to the in-plane dimension of the medium that is the second " @@ -7065,8 +1135,9 @@ class Medium2D(AbstractMedium): discriminator=TYPE_TAG_STR, ) - @pd.validator("modulation_spec", always=True) - def _validate_modulation_spec(cls, val): + @field_validator("modulation_spec") + @classmethod + def _validate_modulation_spec(cls, val: Optional[ModulationSpec]) -> Optional[ModulationSpec]: """Check compatibility with modulation_spec.""" if val is not None: raise ValidationError( @@ -7075,16 +1146,16 @@ def _validate_modulation_spec(cls, val): ) return val - @pd.validator("tt", always=True) - @skip_if_fields_missing(["ss"]) - def _validate_inplane_pec(cls, val, values): + @model_validator(mode="after") + def _validate_inplane_pec(self) -> Self: """ss/tt components must be both PEC or non-PEC.""" - if isinstance(val, PECMedium) != isinstance(values["ss"], PECMedium): + val = self.tt + if isinstance(val, PECMedium) != isinstance(self.ss, PECMedium): raise ValidationError( "Materials describing ss- and tt-components must be " "either both 'PECMedium', or non-'PECMedium'." ) - return val + return self @classmethod def _weighted_avg( @@ -7128,11 +1199,11 @@ def volumetric_equivalent( axis : Axis Index (0, 1, or 2 for x, y, or z respectively) of the normal direction to the 2D material. - adjacent_media : Tuple[MediumType3D, MediumType3D] + adjacent_media : tuple[MediumType3D, MediumType3D] The neighboring media on either side of the 2D material. The first element is directly on the - side of the 2D material in the supplied axis, and the second element is directly on the + side. - adjacent_dls : Tuple[float, float] + adjacent_dls : tuple[float, float] Each dl represents twice the thickness of the desired volumetric model on the respective side of the 2D material. @@ -7240,7 +1311,7 @@ def to_medium(self, thickness: float) -> Medium: return self.to_pole_residue(thickness=thickness).to_medium() @classmethod - def from_medium(cls, medium: Medium, thickness: float) -> Medium2D: + def from_medium(cls, medium: Medium, thickness: float) -> Self: """Generate a :class:`.Medium2D` equivalent of a :class:`.Medium` with a given thickness. @@ -7260,7 +1331,7 @@ def from_medium(cls, medium: Medium, thickness: float) -> Medium2D: return Medium2D(ss=med, tt=med, frequency_range=medium.frequency_range) @classmethod - def from_dispersive_medium(cls, medium: DispersiveMedium, thickness: float) -> Medium2D: + def from_dispersive_medium(cls, medium: DispersiveMedium, thickness: float) -> Self: """Generate a :class:`.Medium2D` equivalent of a :class:`.DispersiveMedium` with a given thickness. @@ -7282,7 +1353,7 @@ def from_dispersive_medium(cls, medium: DispersiveMedium, thickness: float) -> M @classmethod def from_anisotropic_medium( cls, medium: AnisotropicMedium, axis: Axis, thickness: float - ) -> Medium2D: + ) -> Self: """Generate a :class:`.Medium2D` equivalent of a :class:`.AnisotropicMedium` with given normal axis and thickness. 
The ``ss`` and ``tt`` components of the resulting 2D medium correspond to the first of the ``xx``, ``yy``, and ``zz`` components of @@ -7339,7 +1410,7 @@ def eps_diagonal_numerical(self, frequency: float) -> tuple[complex, complex, co Returns ------- - Tuple[complex, complex, complex] + tuple[complex, complex, complex] The diagonal elements of relative permittivity tensor relevant for numerical considerations evaluated at ``frequency``. """ @@ -7409,7 +1480,7 @@ def elements(self) -> dict[str, IsotropicUniformMediumFor2DType]: return {"ss": self.ss, "tt": self.tt} @cached_property - def n_cfl(self): + def n_cfl(self) -> float: """This property computes the index of refraction related to CFL condition, so that the FDTD with this medium is stable when the time step size that doesn't take material factor into account is multiplied by ``n_cfl``. @@ -7417,11 +1488,11 @@ def n_cfl(self): return 1.0 @cached_property - def is_pec(self): + def is_pec(self) -> bool: """Whether the medium is a PEC.""" return any(isinstance(comp, PECMedium) for comp in self.elements.values()) - def is_comp_pec_2d(self, comp: Axis, axis: Axis): + def is_comp_pec_2d(self, comp: Axis, axis: Axis) -> bool: """Whether the medium is a PEC.""" elements_3d = Geometry.unpop_axis( ax_coord=Medium(), plane_coords=self.elements.values(), axis=axis diff --git a/tidy3d/components/microwave/base.py b/tidy3d/components/microwave/base.py index 417adeb4e0..186f24e9e5 100644 --- a/tidy3d/components/microwave/base.py +++ b/tidy3d/components/microwave/base.py @@ -2,15 +2,20 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from tidy3d.components.base import Tidy3dBaseModel from tidy3d.config import config +if TYPE_CHECKING: + from tidy3d.compat import Self + class MicrowaveBaseModel(Tidy3dBaseModel): """Base model that all RF and microwave specific components inherit from.""" @classmethod - def _default_without_license_warning(cls) -> MicrowaveBaseModel: + def _default_without_license_warning(cls) -> Self: """Internal helper factory function for classes inheriting from ``MicrowaveBaseModel``.""" if config.microwave.suppress_rf_license_warning is True: return cls() diff --git a/tidy3d/components/microwave/data/data_array.py b/tidy3d/components/microwave/data/data_array.py index 76a7cb4824..d836eeae8c 100644 --- a/tidy3d/components/microwave/data/data_array.py +++ b/tidy3d/components/microwave/data/data_array.py @@ -7,8 +7,10 @@ class PropagationConstantArray(FreqModeDataArray): """Data array for the complex propagation constant :math:`\\gamma = -\\alpha + j\\beta` with units of 1/m. - In the physics convention where time-harmonic fields evolve with :math:`e^{-j\\omega t}`, a wave - propagating in the +z direction varies as :math:`E(z) = E_0 e^{\\gamma z} = E_0 e^{-\\alpha z} e^{j\\beta z}`. + Notes + ----- + In the physics convention where time-harmonic fields evolve with :math:`e^{-j\\omega t}`, a wave + propagating in the +z direction varies as :math:`E(z) = E_0 e^{\\gamma z} = E_0 e^{-\\alpha z} e^{j\\beta z}`. 
""" __slots__ = () diff --git a/tidy3d/components/microwave/data/dataset.py b/tidy3d/components/microwave/data/dataset.py index 669e32b004..b7a59bc500 100644 --- a/tidy3d/components/microwave/data/dataset.py +++ b/tidy3d/components/microwave/data/dataset.py @@ -2,7 +2,7 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.data.data_array import ( CurrentFreqModeDataArray, @@ -23,22 +23,19 @@ class TransmissionLineDataset(ModeFreqDataset): or :class:`ModeSimulation`. """ - Z0: ImpedanceFreqModeDataArray = pd.Field( - ..., + Z0: ImpedanceFreqModeDataArray = Field( title="Characteristic Impedance", description="The characteristic impedance of the transmission line.", ) - voltage_coeffs: VoltageFreqModeDataArray = pd.Field( - ..., + voltage_coeffs: VoltageFreqModeDataArray = Field( title="Mode Voltage Coefficients", description="Quantity calculated for transmission lines, which associates " "a voltage-like quantity with each mode profile that scales linearly with the " "complex-valued mode amplitude.", ) - current_coeffs: CurrentFreqModeDataArray = pd.Field( - ..., + current_coeffs: CurrentFreqModeDataArray = Field( title="Mode Current Coefficients", description="Quantity calculated for transmission lines, which associates " "a current-like quantity with each mode profile that scales linearly with the " diff --git a/tidy3d/components/microwave/data/monitor_data.py b/tidy3d/components/microwave/data/monitor_data.py index cfccc9bcb9..5a64b4687e 100644 --- a/tidy3d/components/microwave/data/monitor_data.py +++ b/tidy3d/components/microwave/data/monitor_data.py @@ -4,15 +4,12 @@ from __future__ import annotations -from typing import Literal, Optional +from typing import TYPE_CHECKING, Optional import numpy as np -import pydantic.v1 as pd -import xarray as xr -from typing_extensions import Self +from pydantic import Field from tidy3d.components.data.data_array import ( - FieldProjectionAngleDataArray, FreqDataArray, FreqModeDataArray, ImpedanceFreqModeDataArray, @@ -20,18 +17,29 @@ from tidy3d.components.data.monitor_data import DirectivityData, ModeData, ModeSolverData from tidy3d.components.microwave.base import MicrowaveBaseModel from tidy3d.components.microwave.data.data_array import ( - AttenuationConstantArray, GroupVelocityArray, - PhaseConstantArray, PhaseVelocityArray, PropagationConstantArray, ) from tidy3d.components.microwave.data.dataset import TransmissionLineDataset from tidy3d.components.microwave.monitor import MicrowaveModeMonitor, MicrowaveModeSolverMonitor -from tidy3d.components.types import FreqArray, ModeClassification, PolarizationBasis from tidy3d.constants import C_0 from tidy3d.log import log +if TYPE_CHECKING: + from typing import Literal + + import xarray as xr + from numpy.typing import NDArray + from typing_extensions import Self + + from tidy3d.components.data.data_array import FieldProjectionAngleDataArray + from tidy3d.components.microwave.data.data_array import ( + AttenuationConstantArray, + PhaseConstantArray, + ) + from tidy3d.components.types import FreqArray, ModeClassification, PolarizationBasis + class AntennaMetricsData(DirectivityData, MicrowaveBaseModel): """Data representing the main parameters and figures of merit for antennas. @@ -82,14 +90,12 @@ class AntennaMetricsData(DirectivityData, MicrowaveBaseModel): John Wiley & Sons, Chapter 2.9 (2016). 
""" - power_incident: FreqDataArray = pd.Field( - ..., + power_incident: FreqDataArray = Field( title="Power incident", description="Array of values representing the incident power to an antenna.", ) - power_reflected: FreqDataArray = pd.Field( - ..., + power_reflected: FreqDataArray = Field( title="Power reflected", description="Array of values representing power reflected due to an impedance mismatch with the antenna.", ) @@ -115,7 +121,7 @@ def from_directivity_data( New instance combining directivity data with incident and reflected power measurements. """ antenna_params_dict = { - **dir_data.dict(), + **dir_data.model_dump(), "power_incident": power_inc, "power_reflected": power_refl, } @@ -238,7 +244,7 @@ class MicrowaveModeDataBase(MicrowaveBaseModel): are used. """ - transmission_line_data: Optional[TransmissionLineDataset] = pd.Field( + transmission_line_data: Optional[TransmissionLineDataset] = Field( None, title="Transmission Line Data", description="Additional data relevant to transmission lines in RF and microwave applications, " @@ -441,7 +447,7 @@ def _group_index_post_process(self, frequency_step: float) -> Self: super_data = super_data.updated_copy(**update_dict, path="transmission_line_data") return super_data - def _apply_mode_reorder(self, sort_inds_2d): + def _apply_mode_reorder(self, sort_inds_2d: NDArray) -> Self: """Apply a mode reordering along mode_index for all frequency indices. Parameters @@ -522,8 +528,8 @@ class MicrowaveModeData(MicrowaveModeDataBase, ModeData): ... ) """ - monitor: MicrowaveModeMonitor = pd.Field( - ..., title="Monitor", description="Mode monitor associated with the data." + monitor: MicrowaveModeMonitor = Field( + title="Monitor", description="Mode monitor associated with the data." ) @@ -596,8 +602,8 @@ class MicrowaveModeSolverData(MicrowaveModeDataBase, ModeSolverData): ... ) """ - monitor: MicrowaveModeSolverMonitor = pd.Field( - ..., title="Monitor", description="Mode monitor associated with the data." + monitor: MicrowaveModeSolverMonitor = Field( + title="Monitor", description="Mode monitor associated with the data." 
) def interp_in_freq( diff --git a/tidy3d/components/microwave/formulas/circuit_parameters.py b/tidy3d/components/microwave/formulas/circuit_parameters.py index 4514d6e993..d253b99747 100644 --- a/tidy3d/components/microwave/formulas/circuit_parameters.py +++ b/tidy3d/components/microwave/formulas/circuit_parameters.py @@ -13,12 +13,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np from tidy3d.components.geometry.base import Geometry -from tidy3d.components.types import Axis from tidy3d.constants import EPSILON_0 +if TYPE_CHECKING: + from tidy3d.components.types import Axis + def inductance_straight_rectangular_wire( size: tuple[float, float, float], current_axis: Axis diff --git a/tidy3d/components/microwave/impedance_calculator.py b/tidy3d/components/microwave/impedance_calculator.py index 8f80c57320..14b0c8fca8 100644 --- a/tidy3d/components/microwave/impedance_calculator.py +++ b/tidy3d/components/microwave/impedance_calculator.py @@ -2,25 +2,19 @@ from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.data.data_array import ( - CurrentIntegralResultType, - ImpedanceResultType, - VoltageIntegralResultType, _make_current_data_array, _make_impedance_data_array, _make_voltage_data_array, ) from tidy3d.components.data.monitor_data import FieldTimeData from tidy3d.components.microwave.base import MicrowaveBaseModel -from tidy3d.components.microwave.path_integrals.integrals.base import ( - AxisAlignedPathIntegral, - IntegrableMonitorDataType, -) +from tidy3d.components.microwave.path_integrals.integrals.base import AxisAlignedPathIntegral from tidy3d.components.microwave.path_integrals.integrals.current import ( AxisAlignedCurrentIntegral, CompositeCurrentIntegral, @@ -33,6 +27,15 @@ from tidy3d.components.monitor import ModeMonitor, ModeSolverMonitor from tidy3d.exceptions import ValidationError +if TYPE_CHECKING: + from tidy3d.compat import Self + from tidy3d.components.data.data_array import ( + CurrentIntegralResultType, + ImpedanceResultType, + VoltageIntegralResultType, + ) + from tidy3d.components.microwave.path_integrals.integrals.base import IntegrableMonitorDataType + VoltageIntegralType = Union[AxisAlignedVoltageIntegral, Custom2DVoltageIntegral] CurrentIntegralType = Union[ AxisAlignedCurrentIntegral, Custom2DCurrentIntegral, CompositeCurrentIntegral @@ -67,20 +70,22 @@ class ImpedanceCalculator(MicrowaveBaseModel): >>> _ = ImpedanceCalculator(voltage_integral=v_int) """ - voltage_integral: Optional[VoltageIntegralType] = pd.Field( + voltage_integral: Optional[VoltageIntegralType] = Field( None, title="Voltage Integral", description="Definition of path integral for computing voltage.", ) - current_integral: Optional[CurrentIntegralType] = pd.Field( + current_integral: Optional[CurrentIntegralType] = Field( None, title="Current Integral", description="Definition of contour integral for computing current.", ) def compute_impedance( - self, em_field: IntegrableMonitorDataType, return_voltage_and_current=False + self, + em_field: IntegrableMonitorDataType, + return_voltage_and_current: bool = False, ) -> Union[ ImpedanceResultType, tuple[ImpedanceResultType, VoltageIntegralResultType, CurrentIntegralResultType], @@ -94,7 +99,7 @@ def compute_impedance( em_field : :class:`.IntegrableMonitorDataType` The electromagnetic field data that will be used for computing 
the characteristic impedance. - return_voltage_and_current: bool + return_voltage_and_current: bool = False When ``True``, returns additional :class:`.IntegralResultType` that represent the voltage and current associated with the supplied fields. @@ -156,12 +161,13 @@ def compute_impedance( return (impedance, voltage, current) return impedance - @pd.validator("current_integral", always=True) - def check_voltage_or_current(cls, val, values): + @model_validator(mode="after") + def check_voltage_or_current(self) -> Self: """Raise validation error if both ``voltage_integral`` and ``current_integral`` are not provided.""" - if not values.get("voltage_integral") and not val: + val = self.current_integral + if not self.voltage_integral and not val: raise ValidationError( "At least one of 'voltage_integral' or 'current_integral' must be provided." ) - return val + return self diff --git a/tidy3d/components/microwave/mode_spec.py b/tidy3d/components/microwave/mode_spec.py index 579f080500..c6ba7af323 100644 --- a/tidy3d/components/microwave/mode_spec.py +++ b/tidy3d/components/microwave/mode_spec.py @@ -2,13 +2,12 @@ from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.base import cached_property -from tidy3d.components.geometry.base import Box from tidy3d.components.geometry.bound_ops import bounds_contains from tidy3d.components.microwave.base import MicrowaveBaseModel from tidy3d.components.microwave.path_integrals.specs.impedance import ( @@ -16,19 +15,25 @@ ImpedanceSpecType, ) from tidy3d.components.mode_spec import AbstractModeSpec -from tidy3d.components.types import annotate_type from tidy3d.constants import fp_eps from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from tidy3d.compat import Self + from tidy3d.components.geometry.base import Box + TEM_POLARIZATION_THRESHOLD = 0.995 QTEM_POLARIZATION_THRESHOLD = 0.95 class MicrowaveModeSpec(AbstractModeSpec, MicrowaveBaseModel): - """ - The :class:`.MicrowaveModeSpec` class specifies how quantities related to transmission line - modes and microwave waveguides are computed. For example, it defines the paths for line integrals, which are used to - compute voltage, current, and characteristic impedance of the transmission line. + """Specification for transmission line modes and microwave waveguides. + + Notes + ----- + The :class:`.MicrowaveModeSpec` class specifies how quantities related to transmission line + modes and microwave waveguides are computed. For example, it defines the paths for line integrals, which are used to + compute voltage, current, and characteristic impedance of the transmission line. Example ------- @@ -55,9 +60,9 @@ class MicrowaveModeSpec(AbstractModeSpec, MicrowaveBaseModel): """ impedance_specs: Union[ - annotate_type(ImpedanceSpecType), - tuple[Optional[annotate_type(ImpedanceSpecType)], ...], - ] = pd.Field( + ImpedanceSpecType, + tuple[Optional[ImpedanceSpecType], ...], + ] = Field( default_factory=AutoImpedanceSpec._default_without_license_warning, title="Impedance Specifications", description="Field controls how the impedance is calculated for each mode calculated by the mode solver. 
" @@ -67,7 +72,7 @@ class MicrowaveModeSpec(AbstractModeSpec, MicrowaveBaseModel): "ignored for the associated mode.", ) - tem_polarization_threshold: float = pd.Field( + tem_polarization_threshold: float = Field( TEM_POLARIZATION_THRESHOLD, gt=0.0, le=1.0, @@ -78,7 +83,7 @@ class MicrowaveModeSpec(AbstractModeSpec, MicrowaveBaseModel): "(or TM) fraction is greater than or equal to this threshold.", ) - qtem_polarization_threshold: float = pd.Field( + qtem_polarization_threshold: float = Field( QTEM_POLARIZATION_THRESHOLD, gt=0.0, le=1.0, @@ -89,9 +94,9 @@ class MicrowaveModeSpec(AbstractModeSpec, MicrowaveBaseModel): ) @cached_property - def _impedance_specs_as_tuple(self) -> tuple[Optional[ImpedanceSpecType]]: + def _impedance_specs_as_tuple(self) -> tuple[Optional[ImpedanceSpecType], ...]: """Gets the impedance_specs field converted to a tuple.""" - if isinstance(self.impedance_specs, Union[tuple, list]): + if isinstance(self.impedance_specs, (tuple, list)): return tuple(self.impedance_specs) return (self.impedance_specs,) @@ -103,15 +108,16 @@ def _using_auto_current_spec(self) -> bool: for impedance_spec in self._impedance_specs_as_tuple ) - @pd.validator("impedance_specs", always=True) - def check_impedance_specs_consistent_with_num_modes(cls, val, values): + @model_validator(mode="after") + def check_impedance_specs_consistent_with_num_modes(self) -> Self: """Check that the number of impedance specifications is equal to the number of modes. A single impedance spec is also permitted.""" - num_modes = values.get("num_modes") - if isinstance(val, Union[tuple, list]): + val = self.impedance_specs + num_modes = self.num_modes + if isinstance(val, (tuple, list)): num_impedance_specs = len(val) else: - return val + return self # Otherwise, check that the count matches if num_impedance_specs != num_modes: @@ -122,9 +128,9 @@ def check_impedance_specs_consistent_with_num_modes(cls, val, values): "a single specification to apply to all modes." ) - return val + return self - def _check_path_integrals_within_box(self, box: Box): + def _check_path_integrals_within_box(self, box: Box) -> None: """Raise SetupError if a ``CustomImpedanceSpec`` includes a path specification defined outside a candidate box. """ diff --git a/tidy3d/components/microwave/monitor.py b/tidy3d/components/microwave/monitor.py index 575ac2651a..9a6bf631d2 100644 --- a/tidy3d/components/microwave/monitor.py +++ b/tidy3d/components/microwave/monitor.py @@ -2,7 +2,7 @@ from __future__ import annotations -import pydantic.v1 as pydantic +from pydantic import Field from tidy3d.components.microwave.base import MicrowaveBaseModel from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec @@ -22,7 +22,7 @@ class MicrowaveModeMonitorBase(MicrowaveBaseModel): precedence over the base :class:`.ModeSpec` field from :class:`.AbstractModeMonitor`. 
""" - mode_spec: MicrowaveModeSpec = pydantic.Field( + mode_spec: MicrowaveModeSpec = Field( default_factory=MicrowaveModeSpec._default_without_license_warning, title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", diff --git a/tidy3d/components/microwave/path_integrals/factory.py b/tidy3d/components/microwave/path_integrals/factory.py index 9100dcc206..e91d358c56 100644 --- a/tidy3d/components/microwave/path_integrals/factory.py +++ b/tidy3d/components/microwave/path_integrals/factory.py @@ -2,13 +2,8 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING -from tidy3d.components.microwave.impedance_calculator import ( - CurrentIntegralType, - VoltageIntegralType, -) -from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.microwave.path_integrals.integrals.current import ( AxisAlignedCurrentIntegral, CompositeCurrentIntegral, @@ -23,20 +18,27 @@ CompositeCurrentIntegralSpec, Custom2DCurrentIntegralSpec, ) -from tidy3d.components.microwave.path_integrals.specs.impedance import ( - AutoImpedanceSpec, - CustomImpedanceSpec, -) +from tidy3d.components.microwave.path_integrals.specs.impedance import AutoImpedanceSpec from tidy3d.components.microwave.path_integrals.specs.voltage import ( AxisAlignedVoltageIntegralSpec, Custom2DVoltageIntegralSpec, ) -from tidy3d.components.microwave.path_integrals.types import ( - CurrentPathSpecType, - VoltagePathSpecType, -) from tidy3d.exceptions import SetupError, ValidationError +if TYPE_CHECKING: + from typing import Optional + + from tidy3d.components.microwave.impedance_calculator import ( + CurrentIntegralType, + VoltageIntegralType, + ) + from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec + from tidy3d.components.microwave.path_integrals.specs.impedance import CustomImpedanceSpec + from tidy3d.components.microwave.path_integrals.types import ( + CurrentPathSpecType, + VoltagePathSpecType, + ) + def make_voltage_integral(path_spec: VoltagePathSpecType) -> VoltageIntegralType: """Create a voltage path integral from a path specification. 
@@ -54,9 +56,9 @@ def make_voltage_integral(path_spec: VoltagePathSpecType) -> VoltageIntegralType """ v_integral = None if isinstance(path_spec, AxisAlignedVoltageIntegralSpec): - v_integral = AxisAlignedVoltageIntegral(**path_spec.dict(exclude={"type"})) + v_integral = AxisAlignedVoltageIntegral(**path_spec.model_dump(exclude={"type"})) elif isinstance(path_spec, Custom2DVoltageIntegralSpec): - v_integral = Custom2DVoltageIntegral(**path_spec.dict(exclude={"type"})) + v_integral = Custom2DVoltageIntegral(**path_spec.model_dump(exclude={"type"})) else: raise ValidationError(f"Unsupported voltage path specification type: {type(path_spec)}") return v_integral @@ -78,11 +80,11 @@ def make_current_integral(path_spec: CurrentPathSpecType) -> CurrentIntegralType """ i_integral = None if isinstance(path_spec, AxisAlignedCurrentIntegralSpec): - i_integral = AxisAlignedCurrentIntegral(**path_spec.dict(exclude={"type"})) + i_integral = AxisAlignedCurrentIntegral(**path_spec.model_dump(exclude={"type"})) elif isinstance(path_spec, Custom2DCurrentIntegralSpec): - i_integral = Custom2DCurrentIntegral(**path_spec.dict(exclude={"type"})) + i_integral = Custom2DCurrentIntegral(**path_spec.model_dump(exclude={"type"})) elif isinstance(path_spec, CompositeCurrentIntegralSpec): - i_integral = CompositeCurrentIntegral(**path_spec.dict(exclude={"type"})) + i_integral = CompositeCurrentIntegral(**path_spec.model_dump(exclude={"type"})) else: raise ValidationError(f"Unsupported current path specification type: {type(path_spec)}") return i_integral diff --git a/tidy3d/components/microwave/path_integrals/integrals/auto.py b/tidy3d/components/microwave/path_integrals/integrals/auto.py index 2af7caefaa..ac75be92bb 100644 --- a/tidy3d/components/microwave/path_integrals/integrals/auto.py +++ b/tidy3d/components/microwave/path_integrals/integrals/auto.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from tidy3d.components.geometry.base import Box from tidy3d.components.geometry.utils import ( SnapBehavior, @@ -9,15 +11,17 @@ SnappingSpec, snap_box_to_grid, ) -from tidy3d.components.grid.grid import Grid -from tidy3d.components.lumped_element import LinearLumpedElement from tidy3d.components.microwave.path_integrals.integrals.current import ( AxisAlignedCurrentIntegral, ) from tidy3d.components.microwave.path_integrals.integrals.voltage import ( AxisAlignedVoltageIntegral, ) -from tidy3d.components.types import Direction + +if TYPE_CHECKING: + from tidy3d.components.grid.grid import Grid + from tidy3d.components.lumped_element import LinearLumpedElement + from tidy3d.components.types import Direction def path_integrals_from_lumped_element( diff --git a/tidy3d/components/microwave/path_integrals/integrals/base.py b/tidy3d/components/microwave/path_integrals/integrals/base.py index 34527cce14..e49469fe01 100644 --- a/tidy3d/components/microwave/path_integrals/integrals/base.py +++ b/tidy3d/components/microwave/path_integrals/integrals/base.py @@ -2,13 +2,12 @@ from __future__ import annotations -from typing import Literal, Union +from typing import TYPE_CHECKING, Literal, Union import numpy as np import xarray as xr from tidy3d.components.data.data_array import ( - IntegralResultType, ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray, @@ -22,6 +21,9 @@ from tidy3d.constants import fp_eps from tidy3d.exceptions import DataError +if TYPE_CHECKING: + from tidy3d.components.data.data_array import IntegralResultType + IntegrableMonitorDataType = Union[FieldData, 
FieldTimeData, ModeData, ModeSolverData] EMScalarFieldType = Union[ScalarFieldDataArray, ScalarFieldTimeDataArray, ScalarModeFieldDataArray] FieldParameter = Literal["E", "H"] @@ -175,7 +177,7 @@ def compute_integral( v_field_name = f"{field}{dim2}" # Validate that fields are present - em_field._check_fields_stored([h_field_name, v_field_name]) + em_field._check_fields_stored([h_field_name, v_field_name]) # type: ignore[list-item] # Select fields lying on the plane plane_indexer = {dim3: self.position} diff --git a/tidy3d/components/microwave/path_integrals/integrals/current.py b/tidy3d/components/microwave/path_integrals/integrals/current.py index d0b2105da0..1cdea6b457 100644 --- a/tidy3d/components/microwave/path_integrals/integrals/current.py +++ b/tidy3d/components/microwave/path_integrals/integrals/current.py @@ -2,24 +2,17 @@ from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING import numpy as np import xarray as xr from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import ( - CurrentIntegralResultType, - FreqDataArray, - FreqModeDataArray, - IntegralResultType, - _make_current_data_array, -) +from tidy3d.components.data.data_array import FreqModeDataArray, _make_current_data_array from tidy3d.components.data.monitor_data import FieldTimeData from tidy3d.components.microwave.path_integrals.integrals.base import ( AxisAlignedPathIntegral, Custom2DPathIntegral, - IntegrableMonitorDataType, ) from tidy3d.components.microwave.path_integrals.specs.current import ( AxisAlignedCurrentIntegralSpec, @@ -29,6 +22,17 @@ from tidy3d.exceptions import DataError from tidy3d.log import log +if TYPE_CHECKING: + from typing import Optional, Union + + from tidy3d.components.data.data_array import ( + CurrentIntegralResultType, + DataArray, + FreqDataArray, + IntegralResultType, + ) + from tidy3d.components.microwave.path_integrals.integrals.base import IntegrableMonitorDataType + class AxisAlignedCurrentIntegral(AxisAlignedCurrentIntegralSpec): """Class for computing conduction current via Ampère's circuital law on an axis-aligned loop. 
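Behind `compute_current` is Ampère's circuital law: the conduction current threading the loop equals the circulation :math:`\oint \mathbf{H} \cdot d\mathbf{l}` around it, assembled edge by edge from axis-aligned path integrals. As a standalone sanity check of the sign convention, the numpy sketch below integrates the analytic H of an infinite wire (1 A along +z) around a square contour; it is an illustration of the physics, not the class's implementation:

```python
import numpy as np

I_TRUE = 1.0  # A, test current along +z through the loop


def h_field(x, y):
    """Analytic H of an infinite wire on the z-axis: |H| = I/(2*pi*r), phi-directed."""
    r2 = x * x + y * y
    return -I_TRUE * y / (2 * np.pi * r2), I_TRUE * x / (2 * np.pi * r2)


def trapezoid(y, x):
    """Plain trapezoidal rule (avoids numpy version differences around np.trapz)."""
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0)


t = np.linspace(-1.0, 1.0, 4001)
hx_bot, _ = h_field(t, np.full_like(t, -1.0))   # bottom edge, traversed in +x
_, hy_right = h_field(np.full_like(t, 1.0), t)  # right edge, traversed in +y
hx_top, _ = h_field(t, np.full_like(t, 1.0))    # top edge, traversed in -x
_, hy_left = h_field(np.full_like(t, -1.0), t)  # left edge, traversed in -y

# Counterclockwise circulation: bottom and right add, top and left subtract.
current = (
    trapezoid(hx_bot, t)
    + trapezoid(hy_right, t)
    - trapezoid(hx_top, t)
    - trapezoid(hy_left, t)
)
assert abs(current - I_TRUE) < 1e-3  # circulation recovers the enclosed current
```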
@@ -55,7 +59,7 @@ def compute_current(self, em_field: IntegrableMonitorDataType) -> CurrentIntegra h_field_name = f"H{h_component}" v_field_name = f"H{v_component}" # Validate that fields are present - em_field._check_fields_stored([h_field_name, v_field_name]) + em_field._check_fields_stored([h_field_name, v_field_name]) # type: ignore[list-item] h_horizontal = em_field.field_components[h_field_name] h_vertical = em_field.field_components[v_field_name] @@ -74,13 +78,16 @@ def compute_current(self, em_field: IntegrableMonitorDataType) -> CurrentIntegra return _make_current_data_array(current) def _to_path_integrals( - self, h_horizontal=None, h_vertical=None + self, + h_horizontal: Optional[DataArray] = None, + h_vertical: Optional[DataArray] = None, ) -> tuple[AxisAlignedPathIntegral, ...]: """Returns four ``AxisAlignedPathIntegral`` instances, which represent a contour integral around the surface defined by ``self.size``.""" path_specs = self._to_path_integral_specs(h_horizontal=h_horizontal, h_vertical=h_vertical) path_integrals = tuple( - AxisAlignedPathIntegral(**path_spec.dict(exclude={"type"})) for path_spec in path_specs + AxisAlignedPathIntegral(**path_spec.model_dump(exclude={"type"})) + for path_spec in path_specs ) return path_integrals diff --git a/tidy3d/components/microwave/path_integrals/integrals/voltage.py b/tidy3d/components/microwave/path_integrals/integrals/voltage.py index 3d803a5a49..28f70b94ab 100644 --- a/tidy3d/components/microwave/path_integrals/integrals/voltage.py +++ b/tidy3d/components/microwave/path_integrals/integrals/voltage.py @@ -2,20 +2,22 @@ from __future__ import annotations -from tidy3d.components.data.data_array import ( - VoltageIntegralResultType, - _make_voltage_data_array, -) +from typing import TYPE_CHECKING + +from tidy3d.components.data.data_array import _make_voltage_data_array from tidy3d.components.microwave.path_integrals.integrals.base import ( AxisAlignedPathIntegral, Custom2DPathIntegral, - IntegrableMonitorDataType, ) from tidy3d.components.microwave.path_integrals.specs.voltage import ( AxisAlignedVoltageIntegralSpec, Custom2DVoltageIntegralSpec, ) +if TYPE_CHECKING: + from tidy3d.components.data.data_array import VoltageIntegralResultType + from tidy3d.components.microwave.path_integrals.integrals.base import IntegrableMonitorDataType + class AxisAlignedVoltageIntegral(AxisAlignedPathIntegral, AxisAlignedVoltageIntegralSpec): """Class for computing the voltage between two points defined by an axis-aligned line. 
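The factory functions and `_to_path_integrals` above all promote a passive spec into its executable integral counterpart by round-tripping the field data with `model_dump(exclude={"type"})`, dropping the discriminator so the target class stamps its own. A generic sketch of that round-trip (class names hypothetical):

```python
from pydantic import BaseModel


class PathSpec(BaseModel):
    """Hypothetical passive specification."""

    type: str = "PathSpec"
    start: float = 0.0
    stop: float = 1.0


class PathIntegral(PathSpec):
    """Hypothetical executable counterpart with extra behavior."""

    type: str = "PathIntegral"

    def length(self) -> float:
        return self.stop - self.start


spec = PathSpec(start=0.5, stop=2.0)
# Same move as `AxisAlignedVoltageIntegral(**path_spec.model_dump(exclude={"type"}))`:
integral = PathIntegral(**spec.model_dump(exclude={"type"}))
assert integral.length() == 1.5
```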
@@ -38,7 +40,7 @@ def compute_voltage(self, em_field: IntegrableMonitorDataType) -> VoltageIntegra e_component = "xyz"[self.main_axis] field_name = f"E{e_component}" # Validate that fields are present - em_field._check_fields_stored([field_name]) + em_field._check_fields_stored([field_name]) # type: ignore[list-item] e_field = em_field.field_components[field_name] voltage = self.compute_integral(e_field) diff --git a/tidy3d/components/microwave/path_integrals/mode_plane_analyzer.py b/tidy3d/components/microwave/path_integrals/mode_plane_analyzer.py index 81d2188fbd..c5ab6c7e7c 100644 --- a/tidy3d/components/microwave/path_integrals/mode_plane_analyzer.py +++ b/tidy3d/components/microwave/path_integrals/mode_plane_analyzer.py @@ -4,9 +4,10 @@ from itertools import chain from math import isclose +from typing import TYPE_CHECKING -import pydantic.v1 as pd import shapely +from pydantic import Field from shapely.geometry import LineString, Polygon from tidy3d.components.base import cached_property @@ -19,23 +20,28 @@ merging_geometries_on_plane, snap_box_to_grid, ) -from tidy3d.components.grid.grid import Grid -from tidy3d.components.medium import LossyMetalMedium, Medium -from tidy3d.components.structure import Structure -from tidy3d.components.types import Axis, Bound, Coordinate, Shapely, Symmetry +from tidy3d.components.medium import LossyMetalMedium from tidy3d.components.validators import assert_plane from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from tidy3d.components.grid.grid import Grid + from tidy3d.components.medium import Medium + from tidy3d.components.structure import Structure + from tidy3d.components.types import Axis, Bound, Coordinate, Shapely, Symmetry + class ModePlaneAnalyzer(Box): """Analyzes conductor geometry intersecting a mode plane. - This class analyzes the geometry of conductors in a simulation cross-section and is for internal use. + Notes + ----- + This class analyzes the geometry of conductors in a simulation cross-section and is for internal use. """ _plane_validator = assert_plane() - field_data_colocated: bool = pd.Field( + field_data_colocated: bool = Field( False, title="Field Data Colocated", description="Whether field data is colocated with grid points. 
When 'True', bounding boxes " diff --git a/tidy3d/components/microwave/path_integrals/specs/base.py b/tidy3d/components/microwave/path_integrals/specs/base.py index d012360b23..2c9b054092 100644 --- a/tidy3d/components/microwave/path_integrals/specs/base.py +++ b/tidy3d/components/microwave/path_integrals/specs/base.py @@ -3,22 +3,31 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import numpy as np -import pydantic.v1 as pd import shapely -import xarray as xr -from typing_extensions import Self +from pydantic import Field, field_validator from tidy3d.components.base import cached_property from tidy3d.components.geometry.base import Box, Geometry from tidy3d.components.microwave.base import MicrowaveBaseModel -from tidy3d.components.types import ArrayFloat2D, Bound, Coordinate, Coordinate2D -from tidy3d.components.types.base import Axis, Direction +from tidy3d.components.types import ArrayFloat2D +from tidy3d.components.types.base import Axis from tidy3d.components.validators import assert_line from tidy3d.constants import MICROMETER, fp_eps from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from typing import Optional + + import xarray as xr + from numpy.typing import NDArray + from typing_extensions import Self + + from tidy3d.components.types import Bound, Coordinate, Coordinate2D + from tidy3d.components.types.base import Direction + class AbstractAxesRH(MicrowaveBaseModel, ABC): """Represents an axis-aligned right-handed coordinate system with one axis preferred. @@ -68,7 +77,7 @@ class AxisAlignedPathIntegralSpec(AbstractAxesRH, Box): _line_validator = assert_line() - extrapolate_to_endpoints: bool = pd.Field( + extrapolate_to_endpoints: bool = Field( False, title="Extrapolate to Endpoints", description="If the endpoints of the path integral terminate at or near a material interface, " @@ -76,7 +85,7 @@ class AxisAlignedPathIntegralSpec(AbstractAxesRH, Box): "of the integral are ignored. Should be enabled when computing voltage between two conductors.", ) - snap_path_to_grid: bool = pd.Field( + snap_path_to_grid: bool = Field( False, title="Snap Path to Grid", description="It might be desirable to integrate exactly along the Yee grid associated with " @@ -84,11 +93,12 @@ class AxisAlignedPathIntegralSpec(AbstractAxesRH, Box): ) @cached_property - def main_axis(self) -> Axis: + def main_axis(self) -> Optional[Axis]: """Axis for performing integration.""" for index, value in enumerate(self.size): if value != 0: return index + return None def _vertices_2D(self, axis: Axis) -> tuple[Coordinate2D, Coordinate2D]: """Returns the two vertices of this path in the plane defined by ``axis``.""" @@ -126,18 +136,16 @@ class Custom2DPathIntegralSpec(AbstractAxesRH): If the path is not closed, forward and backward differences are used at the endpoints. """ - axis: Axis = pd.Field( - ..., title="Axis", description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z)." + axis: Axis = Field( + title="Axis", description="Specifies dimension of the planar axis (0,1,2) -> (x,y,z)." ) - position: float = pd.Field( - ..., + position: float = Field( title="Position", description="Position of the plane along the ``axis``.", ) - vertices: ArrayFloat2D = pd.Field( - ..., + vertices: ArrayFloat2D = Field( title="Vertices", description="List of (d1, d2) defining the 2 dimensional positions of the path. 
" "The index of dimension should be in the ascending order, which means " @@ -148,7 +156,7 @@ class Custom2DPathIntegralSpec(AbstractAxesRH): ) @staticmethod - def _compute_dl_component(coord_array: xr.DataArray, closed_contour=False) -> np.ndarray: + def _compute_dl_component(coord_array: xr.DataArray, closed_contour: bool = False) -> NDArray: """Computes the differential length element along the integration path.""" dl = np.gradient(coord_array) if closed_contour: @@ -185,7 +193,9 @@ def from_circular_path( A path integral defined on a circular path. """ - def generate_circle_coordinates(radius: float, num_points: int, clockwise: bool): + def generate_circle_coordinates( + radius: float, num_points: int, clockwise: bool + ) -> tuple[np.ndarray, np.ndarray]: """Helper for generating x,y vertices around a circle in the local coordinate frame.""" sign = 1.0 if clockwise: @@ -225,8 +235,9 @@ def main_axis(self) -> Axis: """Axis for performing integration.""" return self.axis - @pd.validator("vertices", always=True) - def _correct_shape(cls, val): + @field_validator("vertices") + @classmethod + def _correct_shape(cls, val: ArrayFloat2D) -> ArrayFloat2D: """Makes sure vertices size is correct.""" # overall shape of vertices if val.shape[1] != 2: diff --git a/tidy3d/components/microwave/path_integrals/specs/current.py b/tidy3d/components/microwave/path_integrals/specs/current.py index 20b742e153..aeffae7384 100644 --- a/tidy3d/components/microwave/path_integrals/specs/current.py +++ b/tidy3d/components/microwave/path_integrals/specs/current.py @@ -2,10 +2,10 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, field_validator from tidy3d.components.base import cached_property from tidy3d.components.geometry.base import Box, Geometry @@ -17,13 +17,19 @@ Custom2DPathIntegralSpec, ) from tidy3d.components.microwave.path_integrals.viz import ARROW_CURRENT, plot_params_current_path -from tidy3d.components.types import Ax, Bound -from tidy3d.components.types.base import Axis, Direction +from tidy3d.components.types.base import Direction from tidy3d.components.validators import assert_plane from tidy3d.components.viz import add_ax_if_none from tidy3d.constants import fp_eps from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from typing import Optional + + from tidy3d.components.data.data_array import DataArray + from tidy3d.components.types import Ax, Bound + from tidy3d.components.types.base import Axis + class AxisAlignedCurrentIntegralSpec(AbstractAxesRH, Box): """Class for specifying the computation of conduction current via Ampère's circuital law on an axis-aligned loop. 
@@ -40,19 +46,18 @@ class AxisAlignedCurrentIntegralSpec(AbstractAxesRH, Box): _plane_validator = assert_plane() - sign: Direction = pd.Field( - ..., + sign: Direction = Field( title="Direction of Contour Integral", description="Positive indicates current flowing in the positive normal axis direction.", ) - extrapolate_to_endpoints: bool = pd.Field( + extrapolate_to_endpoints: bool = Field( False, title="Extrapolate to Endpoints", description="This parameter is passed to :class:`AxisAlignedPathIntegral` objects when computing the contour integral.", ) - snap_contour_to_grid: bool = pd.Field( + snap_contour_to_grid: bool = Field( False, title="Snap Contour to Grid", description="This parameter is passed to :class:`AxisAlignedPathIntegral` objects when computing the contour integral.", @@ -64,9 +69,12 @@ def main_axis(self) -> Axis: for index, value in enumerate(self.size): if value == 0: return index + raise SetupError("AxisAlignedCurrentIntegralSpec requires a zero-sized dimension.") def _to_path_integral_specs( - self, h_horizontal=None, h_vertical=None + self, + h_horizontal: Optional[DataArray] = None, + h_vertical: Optional[DataArray] = None, ) -> tuple[AxisAlignedPathIntegralSpec, ...]: """Returns four ``AxisAlignedPathIntegralSpec`` instances, which represent a contour integral around the surface defined by ``self.size``.""" @@ -294,8 +302,10 @@ def plot( class CompositeCurrentIntegralSpec(MicrowaveBaseModel): """Specification for a composite current integral. - This class is used to set up a ``CompositeCurrentIntegral``, which combines - multiple current integrals. It does not perform any integration itself. + Notes + ----- + This class is used to set up a ``CompositeCurrentIntegral``, which combines + multiple current integrals. It does not perform any integration itself. Example ------- @@ -312,15 +322,13 @@ class CompositeCurrentIntegralSpec(MicrowaveBaseModel): """ path_specs: tuple[Union[AxisAlignedCurrentIntegralSpec, Custom2DCurrentIntegralSpec], ...] = ( - pd.Field( - ..., + Field( title="Path Specifications", description="Definition of the disjoint path specifications for each isolated contour integral.", ) ) - sum_spec: Literal["sum", "split"] = pd.Field( - ..., + sum_spec: Literal["sum", "split"] = Field( title="Sum Specification", description="Determines the method used to combine the currents calculated by the different " "current integrals defined by ``path_specs``. ``sum`` simply adds all currents, while ``split`` " @@ -364,8 +372,11 @@ def plot( ax = path_spec.plot(x=x, y=y, z=z, ax=ax, **path_kwargs) return ax - @pd.validator("path_specs", always=True) - def _path_specs_not_empty(cls, val): + @field_validator("path_specs") + @classmethod + def _path_specs_not_empty( + cls, val: tuple[Union[AxisAlignedCurrentIntegralSpec, Custom2DCurrentIntegralSpec], ...] 
+ ) -> tuple[Union[AxisAlignedCurrentIntegralSpec, Custom2DCurrentIntegralSpec], ...]: """Makes sure at least one path spec has been supplied""" # overall shape of vertices if len(val) < 1: diff --git a/tidy3d/components/microwave/path_integrals/specs/impedance.py b/tidy3d/components/microwave/path_integrals/specs/impedance.py index cdffbdffba..7ae0b898bc 100644 --- a/tidy3d/components/microwave/path_integrals/specs/impedance.py +++ b/tidy3d/components/microwave/path_integrals/specs/impedance.py @@ -2,11 +2,10 @@ from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union -import pydantic.v1 as pd +from pydantic import Field, model_validator -from tidy3d.components.base import skip_if_fields_missing from tidy3d.components.microwave.base import MicrowaveBaseModel from tidy3d.components.microwave.path_integrals.types import ( CurrentPathSpecType, @@ -14,25 +13,32 @@ ) from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from tidy3d.compat import Self + class AutoImpedanceSpec(MicrowaveBaseModel): """Specification for fully automatic transmission line impedance computation. - This specification automatically calculates impedance by current - paths based on the simulation geometry and conductors that intersect the mode plane. - No user-defined path specifications are required. + Notes + ----- + Automatically calculates impedance using paths based on simulation geometry + and conductors that intersect the mode plane. No user-defined path + specifications are required. """ class CustomImpedanceSpec(MicrowaveBaseModel): """Specification for custom transmission line voltages and currents in mode solvers. - The :class:`.CustomImpedanceSpec` class specifies how quantities related to transmission line - modes are computed. It defines the paths for line integrals, which are used to - compute voltage, current, and characteristic impedance of the transmission line. + Notes + ----- + The :class:`.CustomImpedanceSpec` class specifies how quantities related to transmission line + modes are computed. It defines the paths for line integrals, which are used to + compute voltage, current, and characteristic impedance of the transmission line. - Users must supply at least one of voltage or current path specifications to control where these integrals - are evaluated. Both voltage_spec and current_spec cannot be ``None`` simultaneously. + Users must supply at least one of voltage or current path specifications to control where these integrals + are evaluated. Both voltage_spec and current_spec cannot be ``None`` simultaneously. Example ------- @@ -50,33 +56,32 @@ class CustomImpedanceSpec(MicrowaveBaseModel): ... ) """ - voltage_spec: Optional[VoltagePathSpecType] = pd.Field( + voltage_spec: Optional[VoltagePathSpecType] = Field( None, title="Voltage Integration Path", description="Path specification for computing the voltage associated with a mode profile.", ) - current_spec: Optional[CurrentPathSpecType] = pd.Field( + current_spec: Optional[CurrentPathSpecType] = Field( None, title="Current Integration Path", description="Path specification for computing the current associated with a mode profile.", ) - @pd.validator("current_spec", always=True) - @skip_if_fields_missing(["voltage_spec"]) - def check_path_spec_combinations(cls, val, values): + @model_validator(mode="after") + def check_path_spec_combinations(self) -> Self: """Validate that at least one of voltage_spec or current_spec is provided. 
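Single-field checks like `_path_specs_not_empty` above use v2's `@field_validator`, which, unlike v1's `@pd.validator`, must be stacked with an explicit `@classmethod` and receives only its own field's value. A minimal sketch:

```python
from pydantic import BaseModel, field_validator


class Composite(BaseModel):
    """Hypothetical container requiring a non-empty tuple of labels."""

    path_specs: tuple[str, ...]

    @field_validator("path_specs")
    @classmethod
    def _not_empty(cls, val: tuple[str, ...]) -> tuple[str, ...]:
        if len(val) < 1:
            raise ValueError("At least one path spec must be supplied.")
        return val


Composite(path_specs=("loop_a",))  # ok
# Composite(path_specs=())  # would raise a pydantic ValidationError
```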
In order to define voltage/current/impedance, either a voltage or current path specification must be provided. Both cannot be ``None`` simultaneously. """ - - voltage_spec = values["voltage_spec"] + val = self.current_spec + voltage_spec = self.voltage_spec if val is None and voltage_spec is None: raise SetupError( "Not a valid 'CustomImpedanceSpec', the 'voltage_spec' and 'current_spec' cannot both be 'None'." ) - return val + return self ImpedanceSpecType = Union[AutoImpedanceSpec, CustomImpedanceSpec] diff --git a/tidy3d/components/microwave/path_integrals/specs/voltage.py b/tidy3d/components/microwave/path_integrals/specs/voltage.py index e564502d85..c7cf3c80f0 100644 --- a/tidy3d/components/microwave/path_integrals/specs/voltage.py +++ b/tidy3d/components/microwave/path_integrals/specs/voltage.py @@ -2,11 +2,10 @@ from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Any import numpy as np -import pydantic.v1 as pd -from typing_extensions import Self +from pydantic import Field from tidy3d.components.geometry.base import Geometry from tidy3d.components.microwave.path_integrals.specs.base import ( @@ -18,17 +17,22 @@ plot_params_voltage_path, plot_params_voltage_plus, ) -from tidy3d.components.types import Ax from tidy3d.components.types.base import Direction from tidy3d.components.viz import add_ax_if_none from tidy3d.constants import fp_eps +if TYPE_CHECKING: + from typing import Optional + + from typing_extensions import Self + + from tidy3d.components.types import Ax + class AxisAlignedVoltageIntegralSpec(AxisAlignedPathIntegralSpec): """Class for specifying the voltage calculation between two points defined by an axis-aligned line.""" - sign: Direction = pd.Field( - ..., + sign: Direction = Field( title="Direction of Path Integral", description="Positive indicates V=Vb-Va where position b has a larger coordinate along the axis of integration.", ) diff --git a/tidy3d/components/mode/data/sim_data.py b/tidy3d/components/mode/data/sim_data.py index 0aba271bf2..9f412446b7 100644 --- a/tidy3d/components/mode/data/sim_data.py +++ b/tidy3d/components/mode/data/sim_data.py @@ -2,39 +2,43 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Union -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import cached_property from tidy3d.components.data.monitor_data import MediumData, PermittivityData from tidy3d.components.data.sim_data import AbstractYeeGridSimulationData from tidy3d.components.mode.simulation import ModeSimulation -from tidy3d.components.mode_spec import ModeSortSpec -from tidy3d.components.types import TYPE_TAG_STR, Ax, PlotScale +from tidy3d.components.types import TYPE_TAG_STR from tidy3d.components.types.monitor_data import ModeSolverDataType ModeSimulationMonitorDataType = Union[PermittivityData, MediumData] if TYPE_CHECKING: + from typing import Literal, Optional + from matplotlib.colors import Colormap + from tidy3d.components.mode_spec import ModeSortSpec + from tidy3d.components.types import Ax, PlotScale + class ModeSimulationData(AbstractYeeGridSimulationData): """Data associated with a mode solver simulation.""" - simulation: ModeSimulation = pd.Field( - ..., title="Mode simulation", description="Mode simulation associated with this data." 
+ simulation: ModeSimulation = Field( + title="Mode simulation", + description="Mode simulation associated with this data.", ) - modes_raw: ModeSolverDataType = pd.Field( - ..., + modes_raw: ModeSolverDataType = Field( title="Raw Modes", description=":class:`.ModeSolverDataType` containing the field and effective index on unexpanded grid.", discriminator=TYPE_TAG_STR, ) - data: tuple[ModeSimulationMonitorDataType, ...] = pd.Field( + data: tuple[ModeSimulationMonitorDataType, ...] = Field( (), title="Monitor Data", description="List of monitor data " diff --git a/tidy3d/components/mode/derivatives.py b/tidy3d/components/mode/derivatives.py index 19f0f4b04f..5bd4e43010 100644 --- a/tidy3d/components/mode/derivatives.py +++ b/tidy3d/components/mode/derivatives.py @@ -2,12 +2,24 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np +from numpy.typing import NDArray from tidy3d.constants import EPSILON_0, ETA_0 +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal + + from scipy import sparse as sp + +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] + -def make_dxf(dls, shape, pmc): +def make_dxf(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Forward derivative in x.""" import scipy.sparse as sp @@ -22,7 +34,7 @@ def make_dxf(dls, shape, pmc): return dxf -def make_dxb(dls, shape, pmc): +def make_dxb(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Backward derivative in x.""" import scipy.sparse as sp @@ -39,7 +51,7 @@ def make_dxb(dls, shape, pmc): return dxb -def make_dyf(dls, shape, pmc): +def make_dyf(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Forward derivative in y.""" import scipy.sparse as sp @@ -54,7 +66,7 @@ def make_dyf(dls, shape, pmc): return dyf -def make_dyb(dls, shape, pmc): +def make_dyb(dls: ArrayFloat, shape: tuple[int, int], pmc: bool) -> sp.csr_matrix: """Backward derivative in y.""" import scipy.sparse as sp @@ -71,7 +83,11 @@ def make_dyb(dls, shape, pmc): return dyb -def create_d_matrices(shape, dls, dmin_pmc=(False, False)): +def create_d_matrices( + shape: tuple[int, int], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + dmin_pmc: tuple[bool, bool] = (False, False), +) -> tuple[sp.csr_matrix, sp.csr_matrix, sp.csr_matrix, sp.csr_matrix]: """Make the derivative matrices without PML. If dmin_pmc is True, the 'backward' derivative in that dimension will be set to implement PMC boundary, otherwise it will be set to PEC.""" @@ -85,7 +101,15 @@ def create_d_matrices(shape, dls, dmin_pmc=(False, False)): return (dxf, dxb, dyf, dyb) -def create_s_matrices(omega, shape, npml, dls, eps_tensor, mu_tensor, dmin_pml=(True, True)): +def create_s_matrices( + omega: float, + shape: tuple[int, int], + npml: tuple[int, int], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + eps_tensor: ArrayComplex, + mu_tensor: ArrayComplex, + dmin_pml: tuple[bool, bool] = (True, True), +) -> tuple[sp.csr_matrix, sp.csr_matrix, sp.csr_matrix, sp.csr_matrix]: """Makes the 'S-matrices'. When dotted with derivative matrices, they add PML. 
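In `derivatives.py` above, `scipy.sparse` is imported only under `TYPE_CHECKING`, yet the signatures still annotate returns as `sp.csr_matrix`; this is sound because `from __future__ import annotations` keeps every annotation as a string while each function performs its own lazy runtime import. A condensed sketch of the same arrangement (the function below is illustrative, not from the module):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from scipy import sparse as sp  # type-only; scipy may be absent at import time


def identity_operator(n: int) -> sp.csr_matrix:
    """Build an n x n identity in CSR form, importing scipy only when called."""
    import scipy.sparse as sp

    return sp.eye(n, format="csr")
```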
     If dmin_pml is set to False, PML will not be applied on the "bottom" side of the domain."""
@@ -136,17 +160,23 @@ def create_s_matrices(omega, shape, npml, dls, eps_tensor, mu_tensor, dmin_pml=(
     return sx_f, sx_b, sy_f, sy_b


-def average_relative_speed(Nx, Ny, npml, eps_tensor, mu_tensor):
+def average_relative_speed(
+    Nx: int,
+    Ny: int,
+    npml: tuple[int, int],
+    eps_tensor: ArrayComplex,
+    mu_tensor: ArrayComplex,
+) -> ArrayFloat:
     """Compute the relative speed of light in the four pml regions by averaging the diagonal
     elements of the relative epsilon and mu within the pml region."""

-    def relative_mean(tensor):
+    def relative_mean(tensor: ArrayComplex) -> float:
         """Mean for relative parameters. If an empty array just return 1."""
         if tensor.size == 0:
             return 1.0
         return np.mean(tensor)

-    def pml_average_allsides(tensor):
+    def pml_average_allsides(tensor: ArrayComplex) -> ArrayFloat:
         """Average ``tensor`` in the PML regions on all four sides.

         Returns the average values in order (xminus, xplus, yminus, yplus)."""
@@ -165,7 +195,15 @@ def pml_average_allsides(tensor):
     return 1 / np.sqrt(eps_avg * mu_avg)


-def create_sfactor(direction, omega, dls, N, n_pml, dmin_pml, avg_speed):
+def create_sfactor(
+    direction: Literal["f", "b"],
+    omega: float,
+    dls: ArrayFloat,
+    N: int,
+    n_pml: int,
+    dmin_pml: bool,
+    avg_speed: Sequence[float],
+) -> ArrayComplex:
     """Creates the S-factor cross section needed in the S-matrices"""

     # For no PML, this should just be identity matrix.
@@ -181,7 +219,14 @@ def create_sfactor(direction, omega, dls, N, n_pml, dmin_pml, avg_speed):
         raise ValueError(f"Direction value {direction} not recognized")


-def create_sfactor_f(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)):
+def create_sfactor_f(
+    omega: float,
+    dls: ArrayFloat,
+    N: int,
+    n_pml: int,
+    dmin_pml: bool,
+    avg_speed: Sequence[float] = (1, 1),
+) -> ArrayComplex:
     """S-factor profile applied after forward derivative matrix, i.e. applied to H-field
     locations."""
     sfactor_array = np.ones(N, dtype=np.complex128)
@@ -195,7 +240,14 @@ def create_sfactor_f(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)):
     return sfactor_array


-def create_sfactor_b(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)):
+def create_sfactor_b(
+    omega: float,
+    dls: ArrayFloat,
+    N: int,
+    n_pml: int,
+    dmin_pml: bool,
+    avg_speed: Sequence[float] = (1, 1),
+) -> ArrayComplex:
     """S-factor profile applied after backward derivative matrix, i.e. applied to E-field
     locations."""
     sfactor_array = np.ones(N, dtype=np.complex128)
@@ -209,14 +261,14 @@ def create_sfactor_b(omega, dls, N, n_pml, dmin_pml, avg_speed=(1, 1)):

 def s_value(
     dl: float,
-    step: int,
+    step: float,
     omega: float,
     avg_speed: float,
     sigma_max: float = 2,
     kappa_min: float = 1,
     kappa_max: float = 3,
     order: int = 3,
-):
+) -> complex:
     """S-value to use in the S-matrices.
We use coordinate stretching formulation such that s(x) = kappa(x) + 1j * sigma(x) / (omega * EPSILON_0) diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 5acc5f4e64..426306a118 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -6,20 +6,18 @@ from functools import wraps from math import isclose -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union, get_args import numpy as np -import pydantic.v1 as pydantic import xarray as xr +from pydantic import Field, field_validator, model_validator from tidy3d.components.base import ( Tidy3dBaseModel, cached_property, - skip_if_fields_missing, ) from tidy3d.components.boundary import PML, Absorber, Boundary, BoundarySpec, PECBoundary, StablePML from tidy3d.components.data.data_array import ( - FreqModeDataArray, ModeIndexDataArray, ScalarModeFieldCylindricalDataArray, ScalarModeFieldDataArray, @@ -32,7 +30,6 @@ from tidy3d.components.eme.data.sim_data import EMESimulationData from tidy3d.components.eme.simulation import EMESimulation from tidy3d.components.geometry.base import Box -from tidy3d.components.grid.grid import Coords, Grid from tidy3d.components.medium import ( FullyAnisotropicMedium, IsotropicUniformMediumType, @@ -40,40 +37,18 @@ ) from tidy3d.components.microwave.data.dataset import TransmissionLineDataset from tidy3d.components.microwave.data.monitor_data import MicrowaveModeSolverData -from tidy3d.components.microwave.impedance_calculator import ( - CurrentIntegralType, - ImpedanceCalculator, - VoltageIntegralType, -) +from tidy3d.components.microwave.impedance_calculator import ImpedanceCalculator from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.microwave.monitor import MicrowaveModeMonitor, MicrowaveModeSolverMonitor from tidy3d.components.microwave.path_integrals.factory import make_path_integrals -from tidy3d.components.mode_spec import ModeSpec from tidy3d.components.monitor import ModeMonitor, ModeSolverMonitor from tidy3d.components.scene import Scene from tidy3d.components.simulation import Simulation from tidy3d.components.source.field import ModeSource -from tidy3d.components.source.time import SourceTime -from tidy3d.components.structure import Structure from tidy3d.components.subpixel_spec import SurfaceImpedance -from tidy3d.components.types import ( - TYPE_TAG_STR, - ArrayComplex3D, - ArrayComplex4D, - ArrayFloat1D, - ArrayFloat2D, - Ax, - Axis, - Axis2D, - Direction, - EMField, - EpsSpecType, - FreqArray, - PlotScale, - Symmetry, -) +from tidy3d.components.types import ArrayComplex3D, Direction, EMField, FreqArray +from tidy3d.components.types.base import TYPE_TAG_STR, discriminated_union from tidy3d.components.types.mode_spec import ModeSpecType -from tidy3d.components.types.monitor_data import ModeSolverDataType from tidy3d.components.validators import ( validate_freqs_min, validate_freqs_not_empty, @@ -84,7 +59,33 @@ from tidy3d.log import log if TYPE_CHECKING: + from typing import Callable, Literal, Optional + from matplotlib.colors import Colormap + from pydantic import NonNegativeFloat, NonNegativeInt, PositiveInt + + from tidy3d.compat import Self + from tidy3d.components.data.data_array import FreqModeDataArray + from tidy3d.components.grid.grid import Coords, Grid + from tidy3d.components.microwave.impedance_calculator import ( + CurrentIntegralType, + VoltageIntegralType, + ) + from 
tidy3d.components.mode_spec import ModeSpec + from tidy3d.components.source.time import SourceTime + from tidy3d.components.structure import Structure + from tidy3d.components.types import ( + ArrayComplex4D, + ArrayFloat1D, + ArrayFloat2D, + Ax, + Axis, + Axis2D, + EpsSpecType, + PlotScale, + Symmetry, + ) + from tidy3d.components.types.monitor_data import ModeSolverDataType from tidy3d.packaging import supports_local_subpixel, tidy3d_extras # Importing the local solver may not work if e.g. scipy is not installed @@ -108,9 +109,9 @@ # Maximum allowed size of the field data produced by the mode solver MAX_MODES_DATA_SIZE_GB = 20 -MODE_SIMULATION_TYPE = Union[Simulation, EMESimulation] -MODE_SIMULATION_DATA_TYPE = Union[SimulationData, EMESimulationData] -MODE_PLANE_TYPE = Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor] +MODE_SIMULATION_TYPE = discriminated_union(Union[Simulation, EMESimulation]) +MODE_SIMULATION_DATA_TYPE = discriminated_union(Union[SimulationData, EMESimulationData]) +MODE_PLANE_TYPE = discriminated_union(Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor]) # When using ``angle_rotation`` without a bend, use a very large effective radius EFFECTIVE_RADIUS_FACTOR = 10_000 @@ -118,19 +119,23 @@ # Log a warning when the PML covers more than this portion of the mode plane in any axis WARN_THICK_PML_PERCENT = 50 +P = ParamSpec("P") +R = TypeVar("R") + -def require_fdtd_simulation(fn): +def require_fdtd_simulation(fn: Callable[P, R]) -> Callable[P, R]: """Decorate a function to check that ``simulation`` is an FDTD ``Simulation``.""" @wraps(fn) - def _fn(self, **kwargs: Any): + def _fn(*args: P.args, **kwargs: P.kwargs) -> R: """New decorated function.""" + self = args[0] if not isinstance(self.simulation, Simulation): raise SetupError( f"The function '{fn.__name__}' is only supported " "for 'simulation' of type FDTD 'Simulation'." ) - return fn(self, **kwargs) + return fn(*args, **kwargs) return _fn @@ -154,52 +159,49 @@ class ModeSolver(Tidy3dBaseModel): * `Prelude to Integrated Photonics Simulation: Mode Injection `_ """ - simulation: MODE_SIMULATION_TYPE = pydantic.Field( - ..., + simulation: MODE_SIMULATION_TYPE = Field( title="Simulation", description="Simulation or EMESimulation defining all structures and mediums.", discriminator="type", ) - plane: MODE_PLANE_TYPE = pydantic.Field( - ..., + plane: MODE_PLANE_TYPE = Field( title="Plane", description="Cross-sectional plane in which the mode will be computed.", - discriminator=TYPE_TAG_STR, ) - mode_spec: ModeSpecType = pydantic.Field( - ..., + mode_spec: ModeSpecType = Field( title="Mode specification", description="Container with specifications about the modes to be solved for.", discriminator=TYPE_TAG_STR, ) - freqs: FreqArray = pydantic.Field( - ..., title="Frequencies", description="A list of frequencies at which to solve." + freqs: FreqArray = Field( + title="Frequencies", + description="A list of frequencies at which to solve.", ) - direction: Direction = pydantic.Field( + direction: Direction = Field( "+", title="Propagation direction", description="Direction of waveguide mode propagation along the axis defined by its normal " "dimension.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes). 
Default is ``True``.", ) - conjugated_dot_product: bool = pydantic.Field( + conjugated_dot_product: bool = Field( True, title="Conjugated Dot Product", description="Use conjugated or non-conjugated dot product for mode decomposition.", ) - fields: tuple[EMField, ...] = pydantic.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor. Note that some " @@ -207,8 +209,9 @@ class ModeSolver(Tidy3dBaseModel): "like ``mode_area`` require all E-field components.", ) - @pydantic.validator("simulation", pre=True, always=True) - def _convert_to_simulation(cls, val): + @field_validator("simulation") + @classmethod + def _convert_to_simulation(cls, val: MODE_SIMULATION_TYPE) -> MODE_SIMULATION_TYPE: """Convert to regular Simulation if e.g. JaxSimulation given.""" if hasattr(val, "to_simulation"): val = val.to_simulation()[0] @@ -218,8 +221,9 @@ def _convert_to_simulation(cls, val): ) return val - @pydantic.validator("plane", always=True) - def is_plane(cls, val): + @field_validator("plane") + @classmethod + def is_plane(cls, val: MODE_PLANE_TYPE) -> MODE_PLANE_TYPE: """Raise validation error if not planar.""" if val.size.count(0.0) != 1: raise ValidationError(f"ModeSolver plane must be planar, given size={val}") @@ -228,53 +232,66 @@ def is_plane(cls, val): _freqs_not_empty = validate_freqs_not_empty() _freqs_lower_bound = validate_freqs_min() - @pydantic.validator("plane", always=True) - @skip_if_fields_missing(["simulation"]) - def plane_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def plane_in_sim_bounds(self) -> Self: """Check that the plane is at least partially inside the simulation bounds.""" - sim_center = values.get("simulation").center - sim_size = values.get("simulation").size - sim_box = Box(size=sim_size, center=sim_center) - - if not sim_box.intersects(val): + sim_box = Box(size=self.simulation.size, center=self.simulation.center) + if not sim_box.intersects(self.plane): raise SetupError("'ModeSolver.plane' must intersect 'ModeSolver.simulation'.") - return val + return self - @pydantic.validator("plane", always=True) - @skip_if_fields_missing(["simulation"]) - def _warn_plane_crosses_symmetry(cls, val, values): + @model_validator(mode="after") + def _warn_plane_crosses_symmetry(self) -> Self: """Warn if the mode plane crosses the symmetry plane of the underlying simulation but the centers do not match.""" - simulation = values.get("simulation") - bounds = val.bounds - # now check in each dimension whether we cross symmetry plane for dim in range(3): - if simulation.symmetry[dim] != 0: + if self.simulation.symmetry[dim] != 0: crosses_symmetry = ( - bounds[0][dim] < simulation.center[dim] - and bounds[1][dim] > simulation.center[dim] + self.plane.bounds[0][dim] < self.simulation.center[dim] + and self.plane.bounds[1][dim] > self.simulation.center[dim] ) if crosses_symmetry: - if not isclose(val.center[dim], simulation.center[dim]): + if not isclose(self.plane.center[dim], self.simulation.center[dim]): log.warning( f"The original simulation is symmetric along {'xyz'[dim]} direction. " "The mode simulation region does cross the symmetry plane but is " "not symmetric with respect to it. To preserve correct symmetry, " "the requested simulation region will be expanded by the solver." 
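[Editor's note] An aside on the P = ParamSpec("P") / R = TypeVar("R") pair introduced for require_fdtd_simulation earlier in this file: it lets the decorator preserve the wrapped function's exact signature for type checkers instead of collapsing it to (Any) -> Any. Illustration only; the decorated function below is a placeholder, not tidy3d API:

    from functools import wraps
    from typing import Callable, ParamSpec, TypeVar

    P = ParamSpec("P")
    R = TypeVar("R")

    def logged(fn: Callable[P, R]) -> Callable[P, R]:
        @wraps(fn)
        def _fn(*args: P.args, **kwargs: P.kwargs) -> R:
            print(f"calling {fn.__name__}")  # runs before the wrapped call
            return fn(*args, **kwargs)
        return _fn

    @logged
    def add(a: int, b: int) -> int:
        return a + b

    print(add(2, 3))  # prints "calling add", then 5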
) - return val + return self - def _post_init_validators(self) -> None: - self._validate_mode_plane_radius( - mode_spec=self.mode_spec, - plane=self.plane, - sim_geom=self.simulation.geometry, - ) + @model_validator(mode="after") + def _validate_warn_thick_pml(self) -> Self: + """Warn if the pml covers a significant portion of the mode plane.""" self._warn_thick_pml(simulation=self.simulation, plane=self.plane, mode_spec=self.mode_spec) self._validate_rotate_structures() - self._validate_num_grid_points() - if self._has_microwave_mode_spec: - self._validate_microwave_mode_spec(mode_spec=self.mode_spec, plane=self.plane) + return self + + @model_validator(mode="after") + def _validate_bend_radius(self) -> Self: + """Validate that the bend radius is not too small.""" + sim_box = Box(size=self.simulation.size, center=self.simulation.center) + self._validate_mode_plane_radius(self.mode_spec, self.plane, sim_box) + return self + + @model_validator(mode="after") + def _validate_rotate_structures_after(self) -> Self: + self._validate_rotate_structures() + return self + + @model_validator(mode="after") + def _validate_num_grid_points(self) -> Self: + """Upper bound of the product of the number of grid points and the number of modes. The bound is very loose: subspace + size times the size of eigenvector can be indexed by a 32bit integer. + """ + num_cells, _, num_modes = self._num_cells_freqs_modes + relaxation_factor = 2 + if num_cells * (20 + 2 * num_modes) * relaxation_factor > 2**32 - 1: + raise SetupError( + "Too many grid points on the modal plane. Please reduce the modal plane size, apply a coarser grid, " + "or reduce the number of modes." + ) + return self @classmethod def _warn_thick_pml( @@ -285,10 +302,7 @@ def _warn_thick_pml( msg_prefix: str = "'ModeSolver'", ) -> None: """Warn if the pml covers a significant portion of the mode plane.""" - coord_0, coord_1 = cls._plane_grid( - simulation=simulation, - plane=plane, - ) + coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane) num_cells = [len(coord_0), len(coord_1)] effective_num_pml = cls._effective_num_pml( simulation=simulation, plane=plane, mode_spec=mode_spec @@ -319,7 +333,6 @@ def _validate_mode_plane_radius(cls, mode_spec: ModeSpec, plane: Box, sim_geom: return mode_plane = cls._mode_plane(plane=plane, sim_geom=sim_geom) - # radial axis is the plane axis that is not the bend axis _, plane_axs = mode_plane.pop_axis([0, 1, 2], mode_plane.size.index(0.0)) radial_ax = plane_axs[(mode_spec.bend_axis + 1) % 2] @@ -335,17 +348,34 @@ def _validate_rotate_structures(self) -> None: if np.abs(self.mode_spec.angle_theta) > 0 and self.mode_spec.angle_rotation: _ = self._rotate_structures - def _validate_num_grid_points(self) -> None: - """Upper bound of the product of the number of grid points and the number of modes. The bound is very loose: subspace - size times the size of eigenvector can be indexed by a 32bit integer. - """ - num_cells, _, num_modes = self._num_cells_freqs_modes - relaxation_factor = 2 - if num_cells * (20 + 2 * num_modes) * relaxation_factor > 2**32 - 1: + @staticmethod + def _make_rotated_structures( + structures: list[Structure], translate_kwargs: dict, rotate_kwargs: dict + ) -> list[Structure]: + try: + rotated_structures = [] + for structure in structures: + if not isinstance(structure.medium, get_args(IsotropicUniformMediumType)): + raise NotImplementedError( + "Mode solver plane intersects an unsupported medium. " + "Only uniform isotropic media are supported for the plane rotation." 
+ ) + + # Rotate and apply translations + geometry = structure.geometry + geometry = ( + geometry.translated(**{key: -val for key, val in translate_kwargs.items()}) + .rotated(**rotate_kwargs) + .translated(**translate_kwargs) + ) + + rotated_structures.append(structure.updated_copy(geometry=geometry)) + + return rotated_structures + except Exception as e: raise SetupError( - "Too many grid points on the modal plane. Please reduce the modal plane size, apply a coarser grid, " - "or reduce the number of modes." - ) + f"'angle_rotation' set to True but could not rotate structures: {e!s}" + ) from e @classmethod def _validate_microwave_mode_spec(cls, mode_spec: MicrowaveModeSpec, plane: Box) -> None: @@ -358,7 +388,7 @@ def normal_axis(self) -> Axis: return self.plane.size.index(0.0) @staticmethod - def plane_center_tangential(plane) -> tuple[float, float]: + def plane_center_tangential(plane: MODE_PLANE_TYPE) -> tuple[float, float]: """Mode plane center in the tangential axes.""" _, plane_center = plane.pop_axis(plane.center, plane.size.index(0.0)) return plane_center @@ -379,7 +409,7 @@ def _solver_symmetry(simulation: Simulation, plane: Box) -> tuple[Symmetry, Symm if not isclose(simulation.center[dim], plane.center[dim]): mode_symmetry[dim] = 0 _, solver_sym = plane.pop_axis(mode_symmetry, axis=normal_axis) - return solver_sym + return tuple(solver_sym) @cached_property def solver_symmetry(self) -> tuple[Symmetry, Symmetry]: @@ -523,7 +553,9 @@ def data_raw(self) -> ModeSolverDataType: # Compute data on the Yee grid mode_solver_data = self._data_on_yee_grid() if self._has_microwave_mode_spec: - mode_solver_data = MicrowaveModeSolverData(**mode_solver_data.dict(exclude={"type"})) + mode_solver_data = MicrowaveModeSolverData( + **mode_solver_data.model_dump(exclude={"type"}) + ) # Colocate to grid boundaries if requested if self.colocate: @@ -651,7 +683,7 @@ def rotated_mode_solver_data(self) -> ModeSolverData: return rotated_mode_data @cached_property - def rotated_structures_copy(self): + def rotated_structures_copy(self) -> ModeSolver: """Create a copy of the original ModeSolver with rotated structures to the simulation and updates the ModeSpec to disable bend correction and reset angles to normal.""" @@ -696,35 +728,6 @@ def _rotate_structures(self) -> list[Structure]: structs_in = Scene.intersecting_structures(self.plane, self.simulation.structures) return self._make_rotated_structures(structs_in, translate_kwargs, rotate_kwargs) - @staticmethod - def _make_rotated_structures( - structures: list[Structure], translate_kwargs: dict, rotate_kwargs: dict - ): - try: - rotated_structures = [] - for structure in structures: - if not isinstance(structure.medium, get_args(IsotropicUniformMediumType)): - raise NotImplementedError( - "Mode solver plane intersects an unsupported medium. " - "Only uniform isotropic media are supported for the plane rotation."
- ) - - # Rotate and apply translations - geometry = structure.geometry - geometry = ( - geometry.translated(**{key: -val for key, val in translate_kwargs.items()}) - .rotated(**rotate_kwargs) - .translated(**translate_kwargs) - ) - - rotated_structures.append(structure.updated_copy(geometry=geometry)) - - return rotated_structures - except Exception as e: - raise SetupError( - f"'angle_rotation' set to True but could not rotate structures: {e!s}" - ) from e - @cached_property def rotated_bend_center(self) -> list: """Calculate the center at the rotated bend such that the modal plane is normal @@ -738,7 +741,7 @@ def rotated_bend_center(self) -> list: # # Leaving for future reference if needed # def _ref_data_straight( # self, mode_solver_data: ModeSolverData - # ) -> Dict[Union[ScalarModeFieldDataArray, ModeIndexDataArray]]: + # ) -> dict[Union[ScalarModeFieldDataArray, ModeIndexDataArray]]: # """Convert reference data to be centered at the monitor center.""" # # Reference solution stored @@ -762,7 +765,7 @@ def rotated_bend_center(self) -> list: def _car_2_cyn( self, mode_solver_data: ModeSolverData - ) -> dict[Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray]]: + ) -> dict[str, Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray]]: """Convert cartesian fields to cylindrical fields centered at the rotated bend center.""" @@ -854,7 +857,7 @@ def _car_2_cyn( # # Leaving for future reference if needed # def _mode_rotation_straight( # self, - # solver_ref_data: Dict[Union[ModeSolverData]], + # solver_ref_data: dict[Union[ModeSolverData]], # solver: ModeSolver, # ) -> ModeSolverData: # """Rotate the mode solver solution from the reference plane @@ -963,7 +966,7 @@ def _car_2_cyn( def _mode_rotation( self, solver_ref_data_cylindrical: dict[ - Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray] + str, Union[ScalarModeFieldCylindricalDataArray, ModeIndexDataArray] ], solver: ModeSolver, ) -> ModeSolverData: @@ -1105,7 +1108,7 @@ def theta_reference(self) -> float: return theta_ref @cached_property - def _bend_radius(self): + def _bend_radius(self) -> float: """A bend_radius to use when ``angle_rotation`` is on. When there is no bend defined, we use an effectively very large radius, much larger than the mode plane. This is only used for the rotation of the fields - the reference modes are still computed without any @@ -1117,7 +1120,7 @@ def _bend_radius(self): return EFFECTIVE_RADIUS_FACTOR * largest_dim @cached_property - def bend_center(self) -> list: + def bend_center(self) -> list[float]: """Computes the bend center based on plane center, angle_theta and angle_phi.""" _, id_bend_uv = self.plane.pop_axis((0, 1, 2), axis=self.bend_axis_3d) @@ -1361,7 +1364,7 @@ def _normalize_modes(self, mode_solver_data: ModeSolverData) -> None: """Normalize modes. 
Note: this modifies ``mode_solver_data`` in-place.""" mode_solver_data._normalize_modes() - def _filter_components(self, mode_solver_data: ModeSolverData): + def _filter_components(self, mode_solver_data: ModeSolverData) -> ModeSolverData: skip_components = { comp: None for comp in mode_solver_data.field_components.keys() @@ -1369,7 +1372,7 @@ def _filter_components(self, mode_solver_data: ModeSolverData): } return mode_solver_data.updated_copy(**skip_components, validate=False) - def _filter_polarization(self, mode_solver_data: ModeSolverData): + def _filter_polarization(self, mode_solver_data: ModeSolverData) -> ModeSolverData: """Filter polarization.""" filter_pol = self.mode_spec.filter_pol if filter_pol is None: @@ -1480,7 +1483,7 @@ def sim_data(self) -> MODE_SIMULATION_DATA_TYPE: :class:`.SimulationData` object containing the effective index and mode fields. """ monitor_data = self.data - new_monitors = [*list(self.simulation.monitors), monitor_data.monitor] + new_monitors = (*self.simulation.monitors, monitor_data.monitor) new_simulation = self.simulation.copy(update={"monitors": new_monitors}) if isinstance(new_simulation, Simulation): return SimulationData(simulation=new_simulation, data=(monitor_data,)) @@ -1613,7 +1616,13 @@ def _solve_all_freqs_relative( return n_complex, fields, eps_spec @staticmethod - def _postprocess_solver_fields(solver_fields, normal_axis, plane, mode_spec, coords): + def _postprocess_solver_fields( + solver_fields: ArrayComplex4D, + normal_axis: Axis, + plane: MODE_PLANE_TYPE, + mode_spec: ModeSpec, + coords: tuple[ArrayFloat1D, ArrayFloat1D], + ) -> dict[str, ArrayComplex4D]: """Postprocess `solver_fields` from `compute_modes` to proper coordinate""" fields = {key: [] for key in ("Ex", "Ey", "Ez", "Hx", "Hy", "Hz")} diff_coords = (np.diff(coords[0]), np.diff(coords[1])) @@ -1673,7 +1682,9 @@ def _rotate_field_coords_inverse( return np.stack(plane.unpop_axis(f_n, f_ts, axis=2), axis=0) @classmethod - def _postprocess_solver_fields_inverse(cls, fields, normal_axis: Axis, plane: MODE_PLANE_TYPE): + def _postprocess_solver_fields_inverse( + cls, fields: dict[str, ArrayComplex4D], normal_axis: Axis, plane: MODE_PLANE_TYPE + ) -> ArrayComplex4D: """Convert ``fields`` to ``solver_fields``. Doesn't change gauge.""" E = [fields[key] for key in ("Ex", "Ey", "Ez")] H = [fields[key] for key in ("Hx", "Hy", "Hz")] @@ -1791,7 +1802,7 @@ def _inverted_gauge(e_field: FIELD, diff_coords: tuple[ArrayFloat1D, ArrayFloat1 @staticmethod def _process_fields( mode_fields: ArrayComplex4D, - mode_index: pydantic.NonNegativeInt, + mode_index: NonNegativeInt, normal_axis: Axis, plane: MODE_PLANE_TYPE, diff_coords: tuple[ArrayFloat1D, ArrayFloat1D], @@ -2019,8 +2030,8 @@ def to_source( self, source_time: SourceTime, direction: Direction = None, - mode_index: pydantic.NonNegativeInt = 0, - num_freqs: pydantic.PositiveInt = 1, + mode_index: NonNegativeInt = 0, + num_freqs: PositiveInt = 1, **kwargs: Any, ) -> ModeSource: """Creates :class:`.ModeSource` from a :class:`.ModeSolver` instance plus additional @@ -2065,7 +2076,7 @@ def to_monitor( Parameters ---------- - freqs : List[float] + freqs : list[float] Frequencies to include in Monitor (Hz). If not specified, passes ``self.freqs``. name : str @@ -2158,7 +2169,7 @@ def sim_with_source( self, source_time: SourceTime, direction: Direction = None, - mode_index: pydantic.NonNegativeInt = 0, + mode_index: NonNegativeInt = 0, ) -> Simulation: """Creates :class:`.Simulation` from a :class:`.ModeSolver`. 
Creates a copy of the ModeSolver's original simulation with a ModeSource added corresponding to @@ -2200,7 +2211,7 @@ def sim_with_monitor( Parameters ---------- - freqs : List[float] = None + freqs : list[float] = None Frequencies to include in Monitor (Hz). If not specified, uses the frequencies from the mode solver. name : str @@ -2245,7 +2256,7 @@ def sim_with_mode_solver_monitor( def plot_field( self, field_name: str, - val: Literal["real", "imag", "abs"] = "real", + val: Literal["real", "imag", "abs"] = "real", scale: PlotScale = "lin", eps_alpha: float = 0.2, robust: bool = True, @@ -2531,7 +2542,7 @@ def _plane_grid(cls, simulation: Simulation, plane: Box) -> tuple[Coords, Coords @classmethod def _effective_num_pml( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec - ) -> tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]: + ) -> tuple[NonNegativeFloat, NonNegativeFloat]: """Number of cells of the mode solver pml.""" coord_0, coord_1 = cls._plane_grid(simulation=simulation, plane=plane) @@ -2546,8 +2557,8 @@ def _effective_num_pml( def _pml_thickness( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec ) -> tuple[ - tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat], - tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat], + tuple[NonNegativeFloat, NonNegativeFloat], + tuple[NonNegativeFloat, NonNegativeFloat], ]: """Thickness of the mode solver pml in the form ((plus0, minus0), (plus1, minus1)) @@ -2585,7 +2596,7 @@ def _pml_thickness( @classmethod def _mode_plane_size( cls, simulation: Simulation, plane: Box - ) -> tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]: + ) -> tuple[NonNegativeFloat, NonNegativeFloat]: """The size of the mode plane intersected with the simulation.""" _, h_lim, v_lim, _ = cls._center_and_lims(simulation=simulation, plane=plane) return h_lim[1] - h_lim[0], v_lim[1] - v_lim[0] @@ -2593,7 +2604,7 @@ def _mode_plane_size( @classmethod def _mode_plane_size_no_pml( cls, simulation: Simulation, plane: Box, mode_spec: ModeSpec - ) -> tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]: + ) -> tuple[NonNegativeFloat, NonNegativeFloat]: """The size of the remaining portion of the mode plane, after the pml has been removed.""" size = cls._mode_plane_size(simulation=simulation, plane=plane) @@ -2726,7 +2737,7 @@ def validate_pre_upload(self) -> None: self._validate_modes_size() @cached_property - def reduced_simulation_copy(self): + def reduced_simulation_copy(self) -> Self: """Strip objects not used by the mode solver from simulation object. This might significantly reduce upload time in the presence of custom mediums.
""" @@ -2772,9 +2783,9 @@ def reduced_simulation_copy(self): # extract sub-simulation removing everything irrelevant new_sim = self.simulation.subsection( region=new_sim_box, - monitors=[], - sources=[], - internal_absorbers=[], + monitors=(), + sources=(), + internal_absorbers=(), warn_symmetry_expansion=False, # we already warn upon mode solver creation grid_spec="identical", boundary_spec=new_bspec, @@ -2787,10 +2798,10 @@ def reduced_simulation_copy(self): ) # Let's only validate mode solver where geometry validation is skipped: geometry replaced by its bounding # box - structures = [ + structures = tuple( strc.updated_copy(geometry=strc.geometry.bounding_box, deep=False) for strc in new_sim.structures - ] + ) # skip validation as it's validated already in subsection aux_new_sim = new_sim.updated_copy(structures=structures, deep=False, validate=False) # validate mode solver here where geometry is replaced by its bounding box @@ -2839,7 +2850,7 @@ def _patch_data(self, data: ModeSolverData) -> None: self._cached_properties.pop("data", None) self._cached_properties.pop("sim_data", None) - def plot_3d(self, width=800, height=800) -> None: + def plot_3d(self, width: int = 800, height: int = 800) -> None: """Render 3D plot of ``ModeSolver`` (in jupyter notebook only). Parameters ---------- diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index e506706822..1923e7ad21 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -2,15 +2,14 @@ from __future__ import annotations -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, field_validator, model_validator from tidy3d.components.base import cached_property from tidy3d.components.boundary import BoundarySpec from tidy3d.components.geometry.base import Box -from tidy3d.components.grid.grid import Grid from tidy3d.components.grid.grid_spec import GridSpec from tidy3d.components.monitor import ( MediumMonitor, @@ -24,7 +23,8 @@ validate_boundaries_for_zero_dims, ) from tidy3d.components.source.field import ModeSource -from tidy3d.components.types import TYPE_TAG_STR, Ax, Direction, EMField, FreqArray +from tidy3d.components.types import Direction, EMField, FreqArray +from tidy3d.components.types.base import TYPE_TAG_STR, discriminated_union from tidy3d.components.types.mode_spec import ModeSpecType from tidy3d.constants import C_0 from tidy3d.exceptions import SetupError, ValidationError @@ -33,13 +33,21 @@ from .mode_solver import ModeSolver +if TYPE_CHECKING: + from pydantic import PositiveFloat + + from tidy3d.compat import Self + from tidy3d.components.grid.grid import Grid + from tidy3d.components.mode.data.sim_data import ModeSimulationData + from tidy3d.components.types import Ax + ModeSimulationMonitorType = Union[PermittivityMonitor, MediumMonitor] # dummy run time for conversion to FDTD sim # should be very small -- otherwise, generating tmesh will fail or take a long time RUN_TIME = 1e-30 -MODE_PLANE_TYPE = Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor] +MODE_PLANE_TYPE = discriminated_union(Union[Box, ModeSource, ModeMonitor, ModeSolverMonitor]) # attributes shared between ModeSimulation class and ModeSolver class @@ -119,38 +127,41 @@ class ModeSimulation(AbstractYeeGridSimulation): * `Prelude to Integrated Photonics Simulation: Mode Injection `_ """ - mode_spec: ModeSpecType = pd.Field( - ..., + # This validator needs to run 
before others that might access _mode_solver + _boundaries_for_zero_dims = validate_boundaries_for_zero_dims(warn_on_change=False) + + mode_spec: ModeSpecType = Field( title="Mode specification", description="Container with specifications about the modes to be solved for.", discriminator=TYPE_TAG_STR, ) - freqs: FreqArray = pd.Field( - ..., title="Frequencies", description="A list of frequencies at which to solve." + freqs: FreqArray = Field( + title="Frequencies", + description="A list of frequencies at which to solve.", ) - direction: Direction = pd.Field( + direction: Direction = Field( "+", title="Propagation direction", description="Direction of waveguide mode propagation along the axis defined by its normal " "dimension.", ) - colocate: bool = pd.Field( + colocate: bool = Field( True, title="Colocate fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes). Default is ``True``.", ) - conjugated_dot_product: bool = pd.Field( + conjugated_dot_product: bool = Field( True, title="Conjugated Dot Product", description="Use conjugated or non-conjugated dot product for mode decomposition.", ) - fields: tuple[EMField, ...] = pd.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor. Note that some " @@ -158,8 +169,8 @@ class ModeSimulation(AbstractYeeGridSimulation): "like ``mode_area`` require all E-field components.", ) - boundary_spec: BoundarySpec = pd.Field( - BoundarySpec(), + boundary_spec: BoundarySpec = Field( + default_factory=BoundarySpec, title="Boundaries", description="Specification of boundary conditions along each dimension. If ``None``, " "PML boundary conditions are applied on all sides. This behavior is for " @@ -168,34 +179,34 @@ class ModeSimulation(AbstractYeeGridSimulation): "apply PML layers in the mode solver.", ) - monitors: tuple[ModeSimulationMonitorType, ...] = pd.Field( + monitors: tuple[ModeSimulationMonitorType, ...] = Field( (), title="Monitors", description="Tuple of monitors in the simulation. " "Note: monitor names are used to access data after simulation is run.", ) - sources: tuple[()] = pd.Field( + sources: tuple[()] = Field( (), title="Sources", description="Sources in the simulation. Note: sources are not supported in mode " "simulations.", ) - internal_absorbers: tuple[()] = pd.Field( + internal_absorbers: tuple[()] = Field( (), title="Internal Absorbers", description="Planes with the first order absorbing boundary conditions placed inside the computational domain. " "Note: absorbers are not supported in mode simulations.", ) - grid_spec: GridSpec = pd.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions.", ) - plane: MODE_PLANE_TYPE = pd.Field( + plane: Optional[MODE_PLANE_TYPE] = Field( None, title="Plane", description="Cross-sectional plane in which the mode will be computed. " @@ -203,47 +214,52 @@ class ModeSimulation(AbstractYeeGridSimulation): "the provided ``plane`` and the simulation geometry. 
" "If ``None``, the simulation must be 2D, and the plane will be the entire " "simulation geometry.", - discriminator=TYPE_TAG_STR, ) - @pd.validator("plane", always=True) - def is_plane(cls, val, values): + @field_validator("grid_spec") + @classmethod + def _validate_auto_grid_wavelength(cls, val: GridSpec) -> GridSpec: + # abstract override, logic is handled in post-init to ensure freqs is defined + return val + + @field_validator("plane") + @classmethod + def _validate_planar(cls, val: Optional[MODE_PLANE_TYPE]) -> Optional[MODE_PLANE_TYPE]: + if val.size.count(0.0) != 1: + raise ValidationError(f"'ModeSimulation.plane' must be planar, given 'size={val.size}'") + return val + + @model_validator(mode="before") + @classmethod + def is_plane(cls, data: dict[str, Any]) -> dict[str, Any]: """Raise validation error if not planar.""" - if val is None: - sim_center = values.get("center") - sim_size = values.get("size") - val = Box(size=sim_size, center=sim_center) + if hasattr(data, "get") and data.get("plane") is None: + val = Box(size=data.get("size"), center=data.get("center")) if val.size.count(0.0) != 1: raise ValidationError( "If the 'ModeSimulation' geometry is not planar, " "then 'plane' must be specified." ) - return val - if val.size.count(0.0) != 1: - raise ValidationError(f"'ModeSimulation.plane' must be planar, given 'size={val}'") - return val + data["plane"] = val + return data - @pd.validator("plane", always=True) - def plane_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def plane_in_sim_bounds(self) -> Self: """Check that the plane is at least partially inside the simulation bounds.""" - sim_center = values.get("center") - sim_size = values.get("size") - sim_box = Box(size=sim_size, center=sim_center) - - if not sim_box.intersects(val): + sim_box = Box(size=self.size, center=self.center) + if not sim_box.intersects(self.plane): raise SetupError("'ModeSimulation.plane' must intersect 'ModeSimulation.geometry.") - return val + return self - def _post_init_validators(self) -> None: - """Call validators taking `self` that get run after init.""" + @model_validator(mode="after") + def _validate_mode_solver(self) -> Self: _ = self._mode_solver - _ = self.grid + return self - @pd.validator("grid_spec", always=True) - def _validate_auto_grid_wavelength(cls, val, values): - """Handle the case where grid_spec is auto and wavelength is not provided.""" - # this is handled instead post-init to ensure freqs is defined - return val + @model_validator(mode="after") + def _validate_grid(self) -> Self: + _ = self.grid + return self @cached_property def _mode_solver(self) -> ModeSolver: @@ -252,7 +268,7 @@ def _mode_solver(self) -> ModeSolver: return ModeSolver(simulation=self._as_fdtd_sim, **kwargs) @supports_local_subpixel - def run_local(self): + def run_local(self) -> ModeSimulationData: """Run locally.""" if tidy3d_extras["use_local_subpixel"]: @@ -308,18 +324,25 @@ def _as_fdtd_sim(self) -> Simulation: grid_spec = grid_spec.updated_copy(wavelength=min_wvl) kwargs = {key: getattr(self, key) for key in MODE_SIM_YEE_SIM_SHARED_ATTRS} + + # For ModeSimulation with zero-size dimensions, boundary_spec might have been + # automatically updated to use periodic boundaries. The Simulation validator + # would log a warning about this, but we don't want that since ModeSimulation + # handles it silently. 
+ + # Create the simulation - it will run its own validators return Simulation( **kwargs, run_time=RUN_TIME, grid_spec=grid_spec, - monitors=[], + monitors=(), ) @classmethod def from_simulation( cls, simulation: AbstractYeeGridSimulation, - wavelength: Optional[pd.PositiveFloat] = None, + wavelength: Optional[PositiveFloat] = None, **kwargs: Any, ) -> ModeSimulation: """Creates :class:`.ModeSimulation` from a :class:`.AbstractYeeGridSimulation`. @@ -328,7 +351,7 @@ def from_simulation( ---------- simulation: :class:`.AbstractYeeGridSimulation` Starting simulation defining structures, grid, etc. - wavelength: Optional[pd.PositiveFloat] + wavelength: Optional[PositiveFloat] Wavelength used for automatic grid generation. Required if auto grid is used in ``grid_spec``. **kwargs @@ -376,7 +399,7 @@ def reduced_simulation_copy(self) -> ModeSimulation: @classmethod def from_mode_solver( - cls, mode_solver: ModeSolver, wavelength: Optional[pd.PositiveFloat] = None + cls, mode_solver: ModeSolver, wavelength: Optional[PositiveFloat] = None ) -> ModeSimulation: """Creates :class:`.ModeSimulation` from a :class:`.ModeSolver`. @@ -384,7 +407,7 @@ def from_mode_solver( ---------- simulation: :class:`.AbstractYeeGridSimulation` Starting simulation defining structures, grid, etc. - wavelength: Optional[pd.PositiveFloat] + wavelength: Optional[PositiveFloat] Wavelength used for automatic grid generation. Required if auto grid is used in ``grid_spec``. @@ -615,5 +638,3 @@ def plot_pml_mode_plane( def validate_pre_upload(self) -> None: super().validate_pre_upload() self._mode_solver.validate_pre_upload() - - _boundaries_for_zero_dims = validate_boundaries_for_zero_dims(warn_on_change=False) diff --git a/tidy3d/components/mode/solver.py b/tidy3d/components/mode/solver.py index e533c3d084..ee6ac0eae7 100644 --- a/tidy3d/components/mode/solver.py +++ b/tidy3d/components/mode/solver.py @@ -2,18 +2,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any import numpy as np +from numpy.typing import NDArray from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import EpsSpecType, ModeSolverType, Numpy from tidy3d.constants import C_0, ETA_0, fp_eps, pec_val from .derivatives import create_d_matrices as d_mats from .derivatives import create_s_matrices as s_mats from .transforms import angled_transform, radial_transform +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Literal, Optional, Union + + from scipy import sparse as sp + + from tidy3d.components.types import EpsSpecType, ModeSolverType + # Consider vec to be complex if norm(vec.imag)/norm(vec) > TOL_COMPLEX TOL_COMPLEX = 1e-10 # Tolerance for eigs @@ -32,12 +40,13 @@ # double precision. This value is very heuristic. GOOD_CONDUCTOR_CUT_OFF = 1e70 -if TYPE_CHECKING: - from scipy import sparse as sp # Consider a material to be good conductor if |ep| (or |mu|) > GOOD_CONDUCTOR_THRESHOLD * |pec_val| GOOD_CONDUCTOR_THRESHOLD = 0.9 +ArrayFloat = NDArray[np.floating] +ArrayComplex = NDArray[np.complexfloating] + class EigSolver(Tidy3dBaseModel): """Interface for computing eigenvalues given permittivity and mode spec. 
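[Editor's note] A pattern recurring across the derivatives.py, mode_solver.py, and solver.py hunks is moving typing-only imports under an `if TYPE_CHECKING:` guard. Illustration only (placeholder names): because these modules use `from __future__ import annotations`, annotations stay unevaluated strings at runtime, so the guarded imports add no import cost and cannot introduce cycles while remaining visible to type checkers:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # resolved by mypy/pyright only; never imported at runtime
        from collections.abc import Sequence

    def total(values: Sequence[float]) -> float:
        # "Sequence[float]" is just a string here at runtime
        return float(sum(values))

    print(total([1.0, 2.5]))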
@@ -47,18 +56,18 @@ class EigSolver(Tidy3dBaseModel): @classmethod def compute_modes( cls, - eps_cross, - coords, - freq, - mode_spec, - precision, - mu_cross=None, - split_curl_scaling=None, - symmetry=(0, 0), - direction="+", - solver_basis_fields=None, + eps_cross: Union[ArrayComplex, tuple[ArrayComplex, ...]], + coords: Sequence[ArrayFloat], + freq: float, + mode_spec: ModeSolverType, + precision: Literal["single", "double"], + mu_cross: Optional[Union[ArrayComplex, tuple[ArrayComplex, ...]]] = None, + split_curl_scaling: Optional[ArrayFloat] = None, + symmetry: tuple[int, int] = (0, 0), + direction: Literal["+", "-"] = "+", + solver_basis_fields: Optional[ArrayComplex] = None, plane_center: Optional[tuple[float, float]] = None, - ) -> tuple[Numpy, Numpy, EpsSpecType]: + ) -> tuple[ArrayComplex, ArrayComplex, EpsSpecType]: """ Solve for the modes of a waveguide cross-section. @@ -68,7 +77,7 @@ def compute_modes( Either a single 2D array defining the relative permittivity in the cross-section, or nine 2D arrays defining the permittivity at the Ex, Ey, and Ez locations of the Yee grid in the order xx, xy, xz, yx, yy, yz, zx, zy, zz. - coords : List[Numpy] + coords : List[np.ndarray] Two 1D arrays, each with size one larger than the corresponding axis of ``eps_cross``. Defines a (potentially non-uniform) Cartesian grid on which the modes are computed. @@ -100,7 +109,7 @@ def compute_modes( Returns ------- - Tuple[Numpy, Numpy, str] + tuple[np.ndarray, np.ndarray, str] The first array gives the E and H fields for all modes, the second one gives the complex effective index. The last variable describes permittivity characterization on the mode solver's plane ("diagonal", "tensorial_real", or "tensorial_complex"). @@ -128,7 +137,7 @@ def compute_modes( if len(coords[0]) != Nx + 1 or len(coords[1]) != Ny + 1: raise ValueError("Mismatch between 'coords' and 'eps_cross' shapes.") - new_coords = [np.copy(c) for c in coords] + new_coords = (np.copy(coords[0]), np.copy(coords[1])) """We work with full tensorial epsilon and mu to handle the most general cases that can be introduced by coordinate transformations. In the solver, we distinguish the case when @@ -304,20 +313,20 @@ def compute_modes( @classmethod def solver_em( cls, - Nx, - Ny, - eps_tensor, - mu_tensor, - der_mats, - num_modes, - neff_guess, - mat_precision, - direction, - enable_incidence_matrices, - basis_E, - dls, - dmin_pmc=None, - ): + Nx: int, + Ny: int, + eps_tensor: ArrayComplex, + mu_tensor: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + num_modes: int, + neff_guess: float, + mat_precision: Literal["single", "double"], + direction: Literal["+", "-"], + enable_incidence_matrices: bool, + basis_E: Optional[ArrayComplex], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + dmin_pmc: Optional[Sequence[bool]] = None, + ) -> tuple[ArrayComplex, ArrayComplex, ArrayFloat, ArrayFloat, EpsSpecType]: """Solve for the electromagnetic modes of a system defined by in-plane permittivity and permeability and assuming translational invariance in the normal direction.
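[Editor's note] solver_em feeds into solver_eigs, whose docstring below says it finds the ``num_modes`` eigenmodes of ``mat`` closest to ``guess_value``. Illustration only, using an arbitrary diagonal test matrix rather than the mode-solver operator: SciPy's ARPACK wrapper performs this kind of targeted search via shift-invert, since eigenvalues of (A - sigma*I)^-1 are largest in magnitude for eigenvalues of A nearest sigma:

    import numpy as np
    import scipy.sparse as sp
    import scipy.sparse.linalg as spl

    n, guess = 200, 1.0
    mat = sp.diags(np.linspace(0.0, 2.0, n)).tocsc()  # known spectrum on [0, 2]
    v0 = np.random.default_rng(0).random(n)           # initial vector, like vec_init
    vals, vecs = spl.eigs(mat, k=4, sigma=guess, v0=v0)
    print(np.sort(vals.real))  # the four eigenvalues nearest 1.0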
@@ -371,8 +380,8 @@ def solver_em( # use a high-conductivity model for locations associated with a good conductor def conductivity_model_for_good_conductor( - eps, threshold=GOOD_CONDUCTOR_THRESHOLD * pec_val - ): + eps: ArrayComplex, threshold: complex = GOOD_CONDUCTOR_THRESHOLD * pec_val + ) -> ArrayComplex: """Entries associated with 'eps' are converted to a high-conductivity model.""" eps = eps.astype(complex) eps[np.abs(eps) >= abs(threshold)] = 1 + 1j * pec_scaled_val @@ -475,16 +484,16 @@ def conductivity_model_for_good_conductor( @classmethod def solver_diagonal( cls, - eps, - mu, - der_mats, - num_modes, - neff_guess, - vec_init, - mat_precision, - enable_incidence_matrices, - basis_E, - ): + eps: ArrayComplex, + mu: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + num_modes: int, + neff_guess: float, + vec_init: ArrayComplex, + mat_precision: Literal["single", "double"], + enable_incidence_matrices: bool, + basis_E: Optional[ArrayComplex], + ) -> tuple[ArrayComplex, ArrayComplex, ArrayFloat, ArrayFloat]: """EM eigenmode solver assuming ``eps`` and ``mu`` are diagonal everywhere.""" import scipy.sparse as sp import scipy.sparse.linalg as spl @@ -494,7 +503,9 @@ def solver_diagonal( analyze_conditioning = False _threshold = 0.9 * np.abs(pec_val) - def incidence_matrix_for_pec(eps_vec, threshold=_threshold): + def incidence_matrix_for_pec( + eps_vec: ArrayComplex, threshold: float = _threshold + ) -> sp.csr_matrix: """Incidence matrix indicating non-PEC entries associated with 'eps_vec'.""" nnz = eps_vec[np.abs(eps_vec) < threshold] eps_nz = eps_vec.copy() @@ -583,7 +594,9 @@ def incidence_matrix_for_pec(eps_vec, threshold=_threshold): elif PRECONDITIONER == "Material": - def conditional_inverted_vec(eps_vec, threshold=1): + def conditional_inverted_vec( + eps_vec: ArrayComplex, threshold: float = 1 + ) -> sp.csr_matrix: """Returns a diagonal sparse matrix whose i-th element in the diagonal is |eps_i|^-1 if |eps_i|>threshold, and |eps_i| otherwise. 
""" @@ -701,7 +714,14 @@ def conditional_inverted_vec(eps_vec, threshold=1): return E, H, neff, keff @classmethod - def matrix_data_type(cls, eps, mu, der_mats, mat_precision, is_tensorial): + def matrix_data_type( + cls, + eps: ArrayComplex, + mu: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + mat_precision: Literal["single", "double"], + is_tensorial: bool, + ) -> np.dtype[Any]: """Determine data type that should be used for the matrix for diagonalization.""" mat_dtype = np.float32 # In tensorial case, even though the matrix can be real, the @@ -745,18 +765,18 @@ def _check_reciprocity(material_tensor: np.ndarray, tol: float) -> bool: @classmethod def solver_tensorial( cls, - eps, - mu, - der_mats, - num_modes, - neff_guess, - vec_init, - mat_precision, - direction, - dls, - Nxy=None, - dmin_pmc=None, - ): + eps: ArrayComplex, + mu: ArrayComplex, + der_mats: Sequence[sp.csr_matrix], + num_modes: int, + neff_guess: float, + vec_init: ArrayComplex, + mat_precision: Literal["single", "double"], + direction: Literal["+", "-"], + dls: tuple[Sequence[ArrayFloat], Sequence[ArrayFloat]], + Nxy: Optional[tuple[int, int]] = None, + dmin_pmc: Optional[Sequence[bool]] = None, + ) -> tuple[ArrayComplex, ArrayComplex, ArrayFloat, ArrayFloat]: """EM eigenmode solver assuming ``eps`` or ``mu`` have off-diagonal elements.""" import scipy.sparse as sp @@ -887,13 +907,13 @@ def solver_tensorial( @classmethod def solver_eigs( cls, - mat, - num_modes, - vec_init, - guess_value=1.0, - M=None, + mat: sp.csr_matrix, + num_modes: int, + vec_init: ArrayComplex, + guess_value: float = 1.0, + M: Optional[sp.csr_matrix] = None, **kwargs: Any, - ): + ) -> tuple[ArrayComplex, ArrayComplex]: """Find ``num_modes`` eigenmodes of ``mat`` cloest to ``guess_value``. Parameters @@ -925,14 +945,14 @@ def solver_eigs( @classmethod def solver_eigs_relative( cls, - mat, - num_modes, - vec_init, - guess_value=1.0, - M=None, - basis_vecs=None, + mat: sp.csr_matrix, + num_modes: int, + vec_init: ArrayComplex, + guess_value: float = 1.0, + M: Optional[sp.csr_matrix] = None, + basis_vecs: Optional[ArrayComplex] = None, **kwargs: Any, - ): + ) -> tuple[ArrayComplex, ArrayComplex]: """Find ``num_modes`` eigenmodes of ``mat`` cloest to ``guess_value``. Parameters @@ -953,7 +973,9 @@ def solver_eigs_relative( return values, vectors @classmethod - def isinstance_complex(cls, vec_or_mat, tol=TOL_COMPLEX): + def isinstance_complex( + cls, vec_or_mat: Union[ArrayComplex, sp.csr_matrix], tol: float = TOL_COMPLEX + ) -> bool: """Check if a numpy array or scipy.sparse.csr_matrix has complex component by looking at norm(x.imag)/norm(x)>TOL_COMPLEX @@ -975,7 +997,9 @@ def isinstance_complex(cls, vec_or_mat, tol=TOL_COMPLEX): ) @classmethod - def type_conversion(cls, vec_or_mat, new_dtype): + def type_conversion( + cls, vec_or_mat: Union[ArrayComplex, sp.csr_matrix], new_dtype: np.dtype[Any] + ) -> Union[ArrayComplex, sp.csr_matrix]: """Convert vec_or_mat to new_type. Parameters @@ -999,7 +1023,7 @@ def type_conversion(cls, vec_or_mat, new_dtype): raise RuntimeError("Unsupported new_type.") @classmethod - def set_initial_vec(cls, Nx, Ny, is_tensorial=False): + def set_initial_vec(cls, Nx: int, Ny: int, is_tensorial: bool = False) -> ArrayComplex: """Set initial vector for eigs: 1) The field at x=0 and y=0 boundaries are set to 0. 
This should be the case for PEC boundaries, but wouldn't hurt for non-PEC boundary; @@ -1037,19 +1061,21 @@ return vec_init.flatten("F") @classmethod - def eigs_to_effective_index(cls, eig_list: Numpy, mode_solver_type: ModeSolverType): + def eigs_to_effective_index( + cls, eig_list: ArrayComplex, mode_solver_type: ModeSolverType + ) -> tuple[ArrayFloat, ArrayFloat]: """Convert obtained eigenvalues to n_eff and k_eff. Parameters ---------- - eig_list : Numpy + eig_list : np.ndarray Array of eigenvalues mode_solver_type : ModeSolverType The type of mode solver problems Returns ------- - Tuple[Numpy, Numpy] + tuple[np.ndarray, np.ndarray] n_eff and k_eff """ if eig_list.size == 0: @@ -1067,21 +1093,23 @@ def eigs_to_effective_index(cls, eig_list: Numpy, mode_solver_type: ModeSolverTy raise RuntimeError(f"Unidentified 'mode_solver_type={mode_solver_type}'.") @staticmethod - def format_medium_data(mat_data): + def format_medium_data( + mat_data: Union[ArrayComplex, Sequence[ArrayComplex]], + ) -> tuple[ArrayComplex, ...]: """ mat_data can be either permittivity or permeability. It's either a single 2D array defining the relative property in the cross-section, or nine 2D arrays defining the property at the E(H)x, E(H)y, and E(H)z locations of the Yee grid in the order xx, xy, xz, yx, yy, yz, zx, zy, zz. """ - if isinstance(mat_data, Numpy): - return (mat_data[i, :, :] for i in range(9)) + if isinstance(mat_data, np.ndarray): + return tuple(mat_data[i, :, :] for i in range(9)) if len(mat_data) == 9: - return (np.copy(e) for e in mat_data) + return tuple(np.copy(e) for e in mat_data) raise ValueError("Wrong input to mode solver permittivity/permeability!") @staticmethod - def split_curl_field_postprocess(split_curl, E): + def split_curl_field_postprocess(split_curl: ArrayFloat, E: ArrayComplex) -> ArrayComplex: """E has the shape (3, N, num_modes)""" _, Nx, Ny = split_curl.shape field_shape = E.shape @@ -1099,7 +1127,9 @@ def split_curl_field_postprocess(split_curl, E): return E @staticmethod - def make_pml_invariant(Nxy, tensor, num_pml): + def make_pml_invariant( + Nxy: tuple[int, int], tensor: ArrayComplex, num_pml: tuple[int, int] + ) -> ArrayComplex: """For a given epsilon or mu tensor of shape ``(3, 3, Nx, Ny)``, and ``num_pml`` pml layers along ``x`` and ``y``, make all the tensor values in the PML equal by replicating the first pixel into the PML.""" @@ -1113,12 +1143,16 @@ def make_pml_invariant(Nxy, tensor, num_pml): return new_ten.reshape((3, 3, -1)) @staticmethod - def split_curl_field_postprocess_inverse(split_curl, E) -> None: + def split_curl_field_postprocess_inverse( + split_curl: ArrayFloat, E: ArrayComplex + ) -> ArrayComplex: """E has the shape (3, N, num_modes)""" raise RuntimeError("Split curl not yet implemented for relative mode solver.") @staticmethod - def mode_plane_contain_good_conductor(material_response) -> bool: + def mode_plane_contain_good_conductor( + material_response: Optional[ArrayComplex], + ) -> bool: """Find out if epsilon on the modal plane contains good conductors whose permittivity or permeability value is very large.
""" @@ -1467,6 +1501,6 @@ def _make_orthogonal_basis_for_degenerate_modes( return E_vec, H_vec -def compute_modes(*args: Any, **kwargs: Any) -> tuple[Numpy, Numpy, str]: +def compute_modes(*args: Any, **kwargs: Any) -> tuple[ArrayComplex, ArrayComplex, EpsSpecType]: """A wrapper around ``EigSolver.compute_modes``, which is used in :class:`.ModeSolver`.""" return EigSolver.compute_modes(*args, **kwargs) diff --git a/tidy3d/components/mode/transforms.py b/tidy3d/components/mode/transforms.py index c498be1af9..2b01ff1bde 100644 --- a/tidy3d/components/mode/transforms.py +++ b/tidy3d/components/mode/transforms.py @@ -10,10 +10,24 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np +from numpy.typing import NDArray + +if TYPE_CHECKING: + from collections.abc import Sequence + +ArrayFloat = NDArray[np.floating] +CoordsTuple = tuple[ArrayFloat, ArrayFloat] -def radial_transform(coords, radius, bend_axis, plane_center): +def radial_transform( + coords: CoordsTuple, + radius: float, + bend_axis: int, + plane_center: Sequence[float], +) -> tuple[CoordsTuple, ArrayFloat, ArrayFloat]: """Compute the new coordinates and the Jacobian of a polar coordinate transformation. After offsetting the plane such that its center is a distance of ``radius`` away from the center of curvature, we have, e.g. for ``bend_axis=='y'``: @@ -73,7 +87,11 @@ def radial_transform(coords, radius, bend_axis, plane_center): return new_coords, jac_e, jac_h -def angled_transform(coords, angle_theta, angle_phi): +def angled_transform( + coords: CoordsTuple, + angle_theta: float, + angle_phi: float, +) -> tuple[CoordsTuple, ArrayFloat, ArrayFloat]: """Compute the new coordinates and the Jacobian for a transformation that "straightens" an angled waveguide such that it is translationally invariant in w. 
The transformation is u = x - tan(angle) * z @@ -100,7 +118,7 @@ def angled_transform(coords, angle_theta, angle_phi): Nx, Ny = coords[0].size - 1, coords[1].size - 1 # The new coordinates are exactly the same at z = 0 - new_coords = [np.copy(c) for c in coords] + new_coords = tuple(np.copy(c) for c in coords) # The only nontrivial derivatives are dudz, dvdz and they are constant everywhere jac = np.zeros((3, 3, Nx * Ny)) diff --git a/tidy3d/components/mode_spec.py b/tidy3d/components/mode_spec.py index 3d372fdbd2..633e1748eb 100644 --- a/tidy3d/components/mode_spec.py +++ b/tidy3d/components/mode_spec.py @@ -4,18 +4,28 @@ from abc import ABC, abstractmethod from math import isclose -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import ( + Field, + NonNegativeInt, + PositiveFloat, + PositiveInt, + field_validator, + model_validator, +) from tidy3d.constants import GLANCING_CUTOFF, MICROMETER, RADIAN, fp_eps from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log -from .base import Tidy3dBaseModel, skip_if_fields_missing +from .base import Tidy3dBaseModel from .types import Axis2D, FreqArray, TrackFreq +if TYPE_CHECKING: + from tidy3d.compat import Self + GROUP_INDEX_STEP = 0.005 MODE_DATA_KEYS = Literal[ "n_eff", @@ -43,44 +53,44 @@ class ModeSortSpec(Tidy3dBaseModel): """ # Filtering stage - filter_key: Optional[MODE_DATA_KEYS] = pd.Field( + filter_key: Optional[MODE_DATA_KEYS] = Field( None, title="Filtering key", description="Quantity used to filter modes into two groups before sorting.", ) - filter_reference: float = pd.Field( + filter_reference: float = Field( 0.0, title="Filtering reference", description="Reference value used in the filtering stage.", ) - filter_order: Literal["over", "under"] = pd.Field( + filter_order: Literal["over", "under"] = Field( "over", title="Filtering order", description="Select whether the first group contains values over or under the reference.", ) # Sorting stage - sort_key: Optional[MODE_DATA_KEYS] = pd.Field( + sort_key: Optional[MODE_DATA_KEYS] = Field( None, title="Sorting key", description="Quantity used to sort modes within each filtered group. If ``None``, " "sorting is by descending effective index.", ) - sort_reference: Optional[float] = pd.Field( + sort_reference: Optional[float] = Field( None, title="Sorting reference", description=( "If provided, sorting is based on the absolute difference to this reference value." ), ) - sort_order: Literal["ascending", "descending"] = pd.Field( + sort_order: Literal["ascending", "descending"] = Field( "ascending", title="Sorting direction", description="Sort order for the selected key or difference to reference value.", ) # Frequency tracking - applied after sorting and filtering - track_freq: Optional[TrackFreq] = pd.Field( + track_freq: Optional[TrackFreq] = Field( "central", title="Tracking base frequency", description="If provided, enables cross-frequency mode tracking. 
Can be 'lowest', " @@ -117,8 +127,7 @@ def _num_points(self) -> int: class UniformSampling(FrequencySamplingSpec): """Uniform frequency sampling specification.""" - num_points: int = pd.Field( - ..., + num_points: int = Field( title="Number of Points", description="Number of uniformly spaced frequency sampling points.", ge=2, @@ -150,8 +159,7 @@ def sampling_points(self, freqs: FreqArray) -> FreqArray: class ChebSampling(FrequencySamplingSpec): """Chebyshev node frequency sampling specification.""" - num_points: int = pd.Field( - ..., + num_points: int = Field( title="Number of Points", description="Number of Chebyshev nodes for frequency sampling.", ge=3, @@ -191,14 +199,14 @@ def sampling_points(self, freqs: FreqArray) -> FreqArray: class CustomSampling(FrequencySamplingSpec): """Custom frequency sampling specification.""" - freqs: FreqArray = pd.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="Custom array of frequency sampling points.", ) - @pd.validator("freqs", always=True) - def _validate_freqs(cls, val): + @field_validator("freqs") + @classmethod + def _validate_freqs(cls, val: FreqArray) -> FreqArray: """Validate custom frequencies.""" freqs_array = np.asarray(val) if freqs_array.size < 2: @@ -266,14 +274,13 @@ class ModeInterpSpec(Tidy3dBaseModel): Monitor that can use this specification to reduce mode computation cost. """ - sampling_spec: Union[UniformSampling, ChebSampling, CustomSampling] = pd.Field( - ..., + sampling_spec: Union[UniformSampling, ChebSampling, CustomSampling] = Field( title="Sampling Specification", description="Specification for frequency sampling points.", discriminator="type", ) - method: Literal["linear", "cubic", "poly"] = pd.Field( + method: Literal["linear", "cubic", "poly"] = Field( "linear", title="Interpolation Method", description="Method for interpolating mode data between computed frequencies. " @@ -284,7 +291,7 @@ class ModeInterpSpec(Tidy3dBaseModel): "For complex-valued data, real and imaginary parts are interpolated independently.", ) - reduce_data: bool = pd.Field( + reduce_data: bool = Field( False, title="Reduce Data", description="Applies only to :class:`ModeSolverData`. If ``True``, fields and quantities " @@ -294,13 +301,13 @@ class ModeInterpSpec(Tidy3dBaseModel): "Does not apply if the number of sampling points is greater than the number of monitor frequencies.", ) - @pd.validator("method", always=True) - @skip_if_fields_missing(["sampling_spec"]) - def _validate_method_needs_points(cls, val, values): + @model_validator(mode="after") + def _validate_method_needs_points(self) -> Self: """Validate that the method has enough points.""" - sampling_spec = values.get("sampling_spec") + val = self.method + sampling_spec = self.sampling_spec if sampling_spec is None: - return val + return self num_points = sampling_spec._num_points if val == "cubic" and num_points < 4: @@ -315,7 +322,7 @@ def _validate_method_needs_points(cls, val, values): f"Got {num_points} points. " "Use method='linear' or increase num_points." ) - return val + return self @classmethod def uniform( @@ -448,21 +455,25 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): Abstract base for mode specification data. """ - num_modes: pd.PositiveInt = pd.Field( - 1, title="Number of modes", description="Number of modes returned by mode solver." 
+ num_modes: PositiveInt = Field( + 1, + title="Number of modes", + description="Number of modes returned by mode solver.", ) - target_neff: pd.PositiveFloat = pd.Field( - None, title="Target effective index", description="Guess for effective index of the mode." + target_neff: Optional[PositiveFloat] = Field( + None, + title="Target effective index", + description="Guess for effective index of the mode.", ) - num_pml: tuple[pd.NonNegativeInt, pd.NonNegativeInt] = pd.Field( + num_pml: tuple[NonNegativeInt, NonNegativeInt] = Field( (0, 0), title="Number of PML layers", description="Number of standard pml layers to add in the two tangential axes.", ) - filter_pol: Literal["te", "tm"] = pd.Field( + filter_pol: Optional[Literal["te", "tm"]] = Field( None, title="Polarization filtering", description="The solver always computes the ``num_modes`` modes closest to the given " @@ -478,14 +489,14 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): "``tm``-fraction uses the E field component parallel to the second plane axis.", ) - angle_theta: float = pd.Field( + angle_theta: float = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the injection axis.", units=RADIAN, ) - angle_phi: float = pd.Field( + angle_phi: float = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -493,7 +504,7 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): units=RADIAN, ) - precision: Literal["auto", "single", "double"] = pd.Field( + precision: Literal["auto", "single", "double"] = Field( "double", title="single, double, or automatic precision in mode solver", description="The solver will be faster and using less memory under " @@ -502,7 +513,7 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): "conductor, single precision otherwise.", ) - bend_radius: float = pd.Field( + bend_radius: Optional[float] = Field( None, title="Bend radius", description="A curvature radius for simulation of waveguide bends. Can be negative, in " @@ -511,7 +522,7 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): units=MICROMETER, ) - bend_axis: Axis2D = pd.Field( + bend_axis: Optional[Axis2D] = Field( None, title="Bend axis", description="Index into the two tangential axes defining the normal to the " @@ -520,26 +531,26 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): "yz plane, the ``bend_axis`` is always 1 (the global z axis).", ) - angle_rotation: bool = pd.Field( + angle_rotation: bool = Field( False, - title="Use fields rotation when angle_theta is not zero", - description="Defines how modes are computed when angle_theta is not zero. " - "If 'False', a coordinate transformation is applied through the permittivity and permeability tensors." - "If 'True', the structures in the simulation are first rotated to compute a mode solution at " + title="Use fields rotation when ``angle_theta`` is not zero", + description="Defines how modes are computed when ``angle_theta`` is not zero. " + "If ``False``, a coordinate transformation is applied through the permittivity and permeability tensors." + "If ``True``, the structures in the simulation are first rotated to compute a mode solution at " "a reference plane normal to the structure's azimuthal direction. Then, the fields are rotated " - "to align with the mode plane, using the 'n_eff' calculated at the reference plane. The second option can " + "to align with the mode plane, using the ``n_eff`` calculated at the reference plane. 
The second option can " "produce more accurate results, but more care must be taken, for example, in ensuring that the " "original mode plane intersects the correct geometries in the simulation with rotated structures. " - "Note: currently only supported when 'angle_phi' is a multiple of 'np.pi'.", + "Note: currently only supported when ``angle_phi`` is a multiple of ``np.pi``.", ) - track_freq: Optional[TrackFreq] = pd.Field( + track_freq: Optional[TrackFreq] = Field( None, title="Mode Tracking Frequency (deprecated)", description="Deprecated. Use 'sort_spec.track_freq' instead.", ) - group_index_step: Union[pd.PositiveFloat, bool] = pd.Field( + group_index_step: Union[PositiveFloat, bool] = Field( False, title="Frequency step for group index computation", description="Control the computation of the group index alongside the effective index. If " @@ -548,15 +559,15 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): f"default of {GROUP_INDEX_STEP} is used.", ) - sort_spec: ModeSortSpec = pd.Field( - ModeSortSpec(), + sort_spec: ModeSortSpec = Field( + default_factory=ModeSortSpec, title="Mode filtering and sorting specification", description="Defines how to filter and sort modes within each frequency. If ``track_freq`` " "is not ``None``, the sorting is only exact at the specified frequency, while at other " "frequencies it can change depending on the mode tracking.", ) - interp_spec: Optional[ModeInterpSpec] = pd.Field( + interp_spec: Optional[ModeInterpSpec] = Field( None, title="Mode frequency interpolation specification", description="Specification for computing modes at a reduced set of frequencies and " @@ -566,58 +577,68 @@ class AbstractModeSpec(Tidy3dBaseModel, ABC): "not be ``None``) to ensure consistent mode ordering across frequencies.", ) - @pd.validator("bend_axis", always=True) - @skip_if_fields_missing(["bend_radius"]) - def bend_axis_given(cls, val, values): - """Check that ``bend_axis`` is provided if ``bend_radius`` is not ``None``""" - if val is None and values.get("bend_radius") is not None: - raise SetupError("'bend_axis' must also be defined if 'bend_radius' is defined.") + @field_validator("group_index_step", mode="before") + @classmethod + def _validate_group_index_step_default( + cls, val: Union[bool, PositiveFloat] + ) -> Union[bool, PositiveFloat]: + """If ``True``, replace with default fractional step.""" + if val is True: + return GROUP_INDEX_STEP return val - @pd.validator("bend_radius", always=True) - def bend_radius_not_zero(cls, val, values): - """Check that ``bend_raidus`` magnitude is not close to zero.`""" - if val is not None and isclose(val, 0): - raise SetupError("The magnitude of 'bend_radius' must be larger than 0.") + @field_validator("group_index_step") + @classmethod + def _validate_group_index_step_size( + cls, val: Union[bool, PositiveFloat] + ) -> Union[bool, PositiveFloat]: + """Ensure group-index step is < 1.""" + if val is not False and val >= 1: + raise ValidationError( + "Parameter 'group_index_step' must be a fractional value less than 1." 
+ ) return val - @pd.validator("angle_theta", allow_reuse=True, always=True) - def glancing_incidence(cls, val): - """Warn if close to glancing incidence.""" - if np.abs(np.pi / 2 - val) < GLANCING_CUTOFF: + @field_validator("bend_radius") + @classmethod + def _validate_bend_radius_not_zero(cls, v: Optional[float]) -> Optional[float]: + """`bend_radius` magnitude must be non-zero.""" + if v is not None and isclose(v, 0): + raise SetupError("The magnitude of 'bend_radius' must be larger than 0.") + return v + + @field_validator("angle_theta") + @classmethod + def _validate_angle_theta_glancing(cls, val: float) -> float: + """Disallow incidence too close to glancing.""" + if abs(np.pi / 2 - val) < GLANCING_CUTOFF: raise SetupError( "Mode propagation axis too close to glancing angle for accurate injection. " "For best results, switch the injection axis." ) return val - # Must be executed before type validation by pydantic, otherwise True is converted to 1.0 - @pd.validator("group_index_step", pre=True) - def assign_default_on_true(cls, val): - """Assign the default fractional frequency step value if not provided.""" - if val is True: - return GROUP_INDEX_STEP - return val + @model_validator(mode="after") + def _check_bend_axis_given(self) -> Self: + """``bend_axis`` must be provided when ``bend_radius`` is set.""" + if self.bend_radius is not None and self.bend_axis is None: + raise SetupError("'bend_axis' must also be defined if 'bend_radius' is defined.") + return self - @pd.validator("group_index_step") - def check_group_step_size(cls, val): - """Ensure a reasonable group index step is used.""" - if val >= 1: + @model_validator(mode="after") + def _check_angle_rotation_with_phi(self) -> Self: + """``angle_rotation`` requires ``angle_phi`` % (π/2) == 0.""" + if self.angle_rotation and not isclose(self.angle_phi % (np.pi / 2), 0): raise ValidationError( - "Parameter 'group_index_step' is a fractional value. It must be less than 1." + "'angle_phi' must be a multiple of 'π/2' when 'angle_rotation' is enabled." ) - return val + return self - @pd.root_validator(skip_on_failure=True) - def check_precision(cls, values): + @model_validator(mode="after") + def check_precision(self) -> Self: """Verify critical ModeSpec settings for group index calculation.""" - if values["group_index_step"] > 0: - # prefer explicit track_freq on ModeSpec, else fall back to sort_spec.track_freq - # TODO: can be replaced with self._track_freq in pydantic v2 - tf = values.get("track_freq") - if tf is None: - sort_spec = values.get("sort_spec") - tf = None if sort_spec is None else sort_spec.track_freq + if self.group_index_step > 0: + tf = self._track_freq if tf is None: log.warning( "Group index calculation without mode tracking can lead to incorrect results " @@ -625,7 +646,7 @@ def check_precision(cls, values): ) # multiply by 5 to be safe - if values["group_index_step"] < 5 * fp_eps and values["precision"] != "double": + if self.group_index_step < 5 * fp_eps and self.precision != "double": log.warning( "Group index step is too small! " "The results might be fully corrupted by numerical errors. " @@ -633,32 +654,23 @@ def check_precision(cls, values): "or increasing the value of 'group_index_step'." 
) - return values + return self - @pd.validator("angle_rotation") - def angle_rotation_with_phi(cls, val, values): - """Currently ``angle_rotation`` is only supported with ``angle_phi % (np.pi / 2) == 0``.""" - if val and not isclose(values["angle_phi"] % (np.pi / 2), 0): - raise ValidationError( - "Parameter 'angle_phi' must be a multiple of 'np.pi / 2' when 'angle_rotation' is " - "enabled." - ) - return val - - @pd.root_validator(skip_on_failure=True) - def _filter_pol_and_sort_spec_exclusive(cls, values): + @model_validator(mode="after") + def _filter_pol_and_sort_spec_exclusive(self) -> Self: """Ensure that 'filter_pol' and 'sort_spec' are not used together.""" - sort_spec = values.get("sort_spec") + sort_spec = self.sort_spec sort_or_filter = sort_spec.filter_key is not None or sort_spec.sort_key is not None - if values.get("filter_pol") is not None and sort_or_filter: + if self.filter_pol is not None and sort_or_filter: raise SetupError( "'filter_pol' cannot be used simultaneously with sorting or filtering " "defined in 'sort_spec'. Define the filtering in 'sort_spec' exclusively." ) - return values + return self - @pd.validator("filter_pol", always=True) - def _filter_pol_deprecated(cls, val): + @field_validator("filter_pol") + @classmethod + def _filter_pol_deprecated(cls, val: Optional[str]) -> Optional[str]: """Warn that 'filter_pol' is deprecated in favor of 'sort_spec'.""" if val is not None: log.warning( @@ -667,8 +679,9 @@ def _filter_pol_deprecated(cls, val): ) return val - @pd.validator("track_freq", always=True) - def _track_freq_deprecated(cls, val): + @field_validator("track_freq") + @classmethod + def _track_freq_deprecated(cls, val: Optional[TrackFreq]) -> Optional[TrackFreq]: """Warn that 'track_freq' on ModeSpec is deprecated in favor of 'sort_spec.track_freq'.""" if val is not None: log.warning( @@ -689,23 +702,23 @@ def _track_freq_from_specs( return sort_spec.track_freq return None - @pd.validator("interp_spec", always=True) - @skip_if_fields_missing(["sort_spec", "track_freq"]) - def _interp_spec_needs_tracking(cls, val, values): + @model_validator(mode="after") + def _interp_spec_needs_tracking(self) -> Self: """Ensure frequency tracking is enabled when using interpolation.""" + val = self.interp_spec if val is None: - return val + return self # Check if track_freq is enabled (prefer ModeSpec.track_freq, else sort_spec.track_freq) - track_freq = values.get("track_freq") - sort_spec = values.get("sort_spec") - if cls._track_freq_from_specs(track_freq, sort_spec) is None: + track_freq = self.track_freq + sort_spec = self.sort_spec + if self.__class__._track_freq_from_specs(track_freq, sort_spec) is None: raise ValidationError( "Mode frequency interpolation requires frequency tracking to be enabled. " "Please set 'sort_spec.track_freq' to 'central', 'lowest', or 'highest'." 
) - return val + return self @property def _track_freq(self) -> Optional[TrackFreq]: diff --git a/tidy3d/components/monitor.py b/tidy3d/components/monitor.py index df4ec33049..a698ad1609 100644 --- a/tidy3d/components/monitor.py +++ b/tidy3d/components/monitor.py @@ -3,35 +3,30 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional import numpy as np -import pydantic.v1 as pydantic +from pydantic import Field, NonNegativeFloat, PositiveInt, field_validator, model_validator from tidy3d.constants import HERTZ, MICROMETER, RADIAN, SECOND, inf from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log from .apodization import ApodizationSpec -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .base_sim.monitor import AbstractMonitor from .medium import MediumType from .microwave.base import MicrowaveBaseModel from .mode_spec import ModeSpec from .types import ( - ArrayFloat1D, AuxField, - Ax, Axis, - Bound, BoxSurface, Coordinate, Direction, EMField, FreqArray, - FreqBound, ObsGridArray, - Size, ) from .validators import ( assert_plane, @@ -40,6 +35,13 @@ ) from .viz import ARROW_ALPHA, ARROW_COLOR_MONITOR +if TYPE_CHECKING: + from pydantic import FieldValidationInfo + + from tidy3d.compat import Self + + from .types import ArrayFloat1D, Ax, Bound, FreqBound, Size + BYTES_REAL = 4 BYTES_COMPLEX = 8 WARN_NUM_FREQS = 2000 @@ -55,7 +57,7 @@ class Monitor(AbstractMonitor): """Abstract base class for monitors.""" - interval_space: tuple[Literal[1], Literal[1], Literal[1]] = pydantic.Field( + interval_space: tuple[Literal[1], Literal[1], Literal[1]] = Field( (1, 1, 1), title="Spatial Interval", description="Number of grid step intervals between monitor recordings. If equal to 1, " @@ -64,7 +66,7 @@ class Monitor(AbstractMonitor): "Not all monitors support values different from 1.", ) - colocate: Literal[True] = pydantic.Field( + colocate: Literal[True] = Field( True, title="Colocate Fields", description="Defines whether fields are colocated to grid cell boundaries (i.e. to the " @@ -73,7 +75,7 @@ class Monitor(AbstractMonitor): ) @property - def _to_solver_monitor(self): + def _to_solver_monitor(self) -> Self: """Monitor definition that will be used to define the field recording during the time stepping.""" return self @@ -90,15 +92,14 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class FreqMonitor(Monitor, ABC): """:class:`Monitor` that records data in the frequency-domain.""" - freqs: FreqArray = pydantic.Field( - ..., + freqs: FreqArray = Field( title="Frequencies", description="Array or list of frequencies stored by the field monitor.", units=HERTZ, ) - apodization: ApodizationSpec = pydantic.Field( - ApodizationSpec(), + apodization: ApodizationSpec = Field( + default_factory=ApodizationSpec, title="Apodization Specification", description="Sets parameters of (optional) apodization. 
Apodization applies a windowing " "function to the Fourier transform of the time-domain fields into frequency-domain ones, " @@ -110,13 +111,16 @@ class FreqMonitor(Monitor, ABC): _freqs_not_empty = validate_freqs_not_empty() _freqs_lower_bound = validate_freqs_min() - @pydantic.validator("freqs", always=True) - def _warn_num_freqs(cls, val, values): + @field_validator("freqs") + @classmethod + def _warn_num_freqs( + cls: type[FreqMonitor], val: FreqArray, info: FieldValidationInfo + ) -> FreqArray: """Warn if number of frequencies is too large.""" if len(val) > WARN_NUM_FREQS: log.warning( f"A large number ({len(val)}) of frequencies detected in monitor " - f"'{values['name']}'. This can lead to solver slow-down and increased cost. " + f"'{info.data.get('name')}'. This can lead to solver slow-down and increased cost. " "Consider decreasing the number of frequencies in the monitor. This may become a " "hard limit in future Tidy3D versions.", custom_loc=["freqs"], @@ -129,7 +133,7 @@ def frequency_range(self) -> FreqBound: Returns ------- - Tuple[float, float] + tuple[float, float] Minimum and maximum frequencies of the frequency array. """ return (min(self.freqs), max(self.freqs)) @@ -138,14 +142,14 @@ class TimeMonitor(Monitor, ABC): """:class:`Monitor` that records data in the time-domain.""" - start: pydantic.NonNegativeFloat = pydantic.Field( + start: NonNegativeFloat = Field( 0.0, title="Start Time", description="Time at which to start monitor recording.", units=SECOND, ) - stop: pydantic.NonNegativeFloat = pydantic.Field( + stop: Optional[NonNegativeFloat] = Field( None, title="Stop Time", description="Time at which to stop monitor recording. " "If not specified, record until end of simulation.", units=SECOND, ) - interval: pydantic.PositiveInt = pydantic.Field( + interval: Optional[PositiveInt] = Field( None, title="Time Interval", description="Sampling rate of the monitor: number of time steps between each measurement. 
" @@ -162,14 +166,14 @@ class TimeMonitor(Monitor, ABC): "This can be useful for reducing data storage as needed by the application.", ) - @pydantic.validator("interval", always=True) - @skip_if_fields_missing(["start", "stop"]) - def _warn_interval_default(cls, val, values): + @model_validator(mode="after") + def _warn_interval_default(self) -> Self: """If all defaults used for time sampler, warn and set ``interval=1`` internally.""" + val = self.interval if val is None: - start = values.get("start") - stop = values.get("stop") + start = self.start + stop = self.stop if start == 0.0 and stop is None: log.warning( "The monitor 'interval' field was left as its default value, " @@ -185,18 +189,18 @@ def _warn_interval_default(cls, val, values): ) # set 'interval = 1' for backwards compatibility - val = 1 + object.__setattr__(self, "interval", 1) - return val + return self - @pydantic.validator("stop", always=True, allow_reuse=True) - @skip_if_fields_missing(["start"]) - def stop_greater_than_start(cls, val, values): + @model_validator(mode="after") + def stop_greater_than_start(self) -> Self: """Ensure sure stop is greater than or equal to start.""" - start = values.get("start") - if val and val < start: + stop = self.stop + start = self.start + if stop and stop < start: raise SetupError("Monitor start time is greater than stop time.") - return val + return self def time_inds(self, tmesh: ArrayFloat1D) -> tuple[int, int]: """Compute the starting and stopping index of the monitor in a given discrete time mesh.""" @@ -237,23 +241,21 @@ def num_steps(self, tmesh: ArrayFloat1D) -> int: class AbstractFieldMonitor(Monitor, ABC): """:class:`Monitor` that records electromagnetic field data as a function of x,y,z.""" - fields: tuple[EMField, ...] = pydantic.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor.", ) - interval_space: tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals between monitor recordings. If equal to 1, " - "there will be no downsampling. If greater than 1, the step will be applied, but the " - "first and last point of the monitor grid are always included.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " @@ -281,28 +283,29 @@ def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AbstractAuxFieldMonitor(Monitor, ABC): """:class:`.Monitor` that records auxiliary fields as a function of x,y,z. - Auxiliary fields are used in certain nonlinear material models. - :class:`.TwoPhotonAbsorption` uses `Nfx`, `Nfy`, and `Nfz` for the - free-carrier density.""" + Notes + ----- + Auxiliary fields are used in certain nonlinear material models. + :class:`.TwoPhotonAbsorption` uses `Nfx`, `Nfy`, and `Nfz` for the + free-carrier density. + """ - fields: tuple[AuxField, ...] = pydantic.Field( + fields: tuple[AuxField, ...] 
= Field( (), title="Aux Field Components", description="Collection of auxiliary field components to store in the monitor. " "Auxiliary fields which are not present in the simulation will be zero.", ) - interval_space: tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals between monitor recordings. If equal to 1, " - "there will be no downsampling. If greater than 1, the step will be applied, but the " - "first and last point of the monitor grid are always included.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " @@ -337,26 +340,26 @@ def normal_axis(self) -> Axis: class AbstractModeMonitor(PlanarMonitor, FreqMonitor): """:class:`Monitor` that records mode-related data.""" - mode_spec: ModeSpec = pydantic.Field( - ModeSpec(), + mode_spec: ModeSpec = Field( + default_factory=ModeSpec, title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", ) - store_fields_direction: Direction = pydantic.Field( + store_fields_direction: Optional[Direction] = Field( None, title="Store Fields", description="Propagation direction for the mode field profiles stored from mode solving.", ) - colocate: bool = pydantic.Field( + colocate: bool = Field( True, title="Colocate Fields", description="Toggle whether fields should be colocated to grid cell boundaries (i.e. " "primal grid nodes).", ) - conjugated_dot_product: bool = pydantic.Field( + conjugated_dot_product: bool = Field( True, title="Conjugated Dot Product", description="Use conjugated or non-conjugated dot product for mode decomposition.", @@ -414,13 +417,16 @@ def _bend_axis(self) -> Axis: direction = self.unpop_axis(0, in_plane, axis=self.normal_axis) return direction.index(1) - @pydantic.validator("mode_spec", always=True) - def _warn_num_modes(cls, val, values): + @field_validator("mode_spec") + @classmethod + def _warn_num_modes( + cls: type[AbstractModeMonitor], val: ModeSpec, info: FieldValidationInfo + ) -> ModeSpec: """Warn if number of modes is too large.""" if val.num_modes > WARN_NUM_MODES: log.warning( f"A large number ({val.num_modes}) of modes requested in monitor " - f"'{values['name']}'. This can lead to solver slow-down and increased cost. " + f"'{info.data.get('name')}'. This can lead to solver slow-down and increased cost. " "Consider decreasing the number of modes and using 'ModeSpec.target_neff' " "to target the modes of interest. This may become a hard limit in future " "Tidy3D versions.",
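A brief note on the `default_factory=ModeSpec` change above: a plain `ModeSpec()` default is constructed once at import time, while `default_factory` defers construction until each instantiation, the idiom pydantic v2 prefers for model-typed defaults. A rough sketch of the resulting behavior, using stand-in models rather than the tidy3d classes:

    from pydantic import BaseModel, Field

    class SortSpecSketch(BaseModel):
        order: str = "ascending"

    class SpecSketch(BaseModel):
        # the factory runs once per SpecSketch() that omits the field,
        # so every instance gets its own default object
        sort_spec: SortSpecSketch = Field(default_factory=SortSpecSketch)

    a, b = SpecSketch(), SpecSketch()
    assert a.sort_spec == b.sort_spec and a.sort_spec is not b.sort_spec

@@ -531,9 +537,11 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AuxFieldTimeMonitor(AbstractAuxFieldMonitor, TimeMonitor): """:class:`.Monitor` that records auxiliary fields in the time domain. - Auxiliary fields are used in certain nonlinear material models. - :class:`.TwoPhotonAbsorption` uses `Nfx`, `Nfy`, and `Nfz` for the - free-carrier density. + Notes + ----- + Auxiliary fields are used in certain nonlinear material models. 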
+ :class:`.TwoPhotonAbsorption` uses `Nfx`, `Nfy`, and `Nfz` for the + free-carrier density. Example ------- @@ -558,25 +566,23 @@ def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: class AbstractMediumPropertyMonitor(FreqMonitor, ABC): """:class:`Monitor` that records material properties in the frequency domain.""" - colocate: Literal[False] = pydantic.Field( + colocate: Literal[False] = Field( False, title="Colocate Fields", description="Colocation turned off, since colocated medium property values do not have a " "physical meaning - they do not correspond to the subpixel-averaged ones.", ) - interval_space: tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals between monitor recordings. If equal to 1, " - "there will be no downsampling. If greater than 1, the step will be applied, but the " - "first and last point of the monitor grid are always included.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals between monitor recordings. If equal to 1, " + "there will be no downsampling. If greater than 1, the step will be applied, but the " + "first and last point of the monitor grid are always included.", ) - apodization: ApodizationSpec = pydantic.Field( - ApodizationSpec(), + apodization: ApodizationSpec = Field( + default_factory=ApodizationSpec, title="Apodization Specification", description="This field is ignored in this monitor.", ) @@ -644,7 +650,7 @@ class SurfaceIntegrationMonitor(Monitor, ABC): """Abstract class for monitors that perform surface integrals during the solver run, as in flux and near to far transformations.""" - normal_dir: Direction = pydantic.Field( + normal_dir: Optional[Direction] = Field( None, title="Normal Vector Orientation", description="Direction of the surface monitor's normal vector w.r.t. " @@ -652,51 +658,48 @@ class SurfaceIntegrationMonitor(Monitor, ABC): "Applies to surface monitors only, and defaults to ``'+'`` if not provided.", ) - exclude_surfaces: tuple[BoxSurface, ...] = pydantic.Field( + exclude_surfaces: Optional[tuple[BoxSurface, ...]] = Field( None, title="Excluded Surfaces", description="Surfaces to exclude in the integration, if a volume monitor.", ) @property - def integration_surfaces(self): + def integration_surfaces(self) -> list[SurfaceIntegrationMonitor]: """Surfaces of the monitor where fields will be recorded for subsequent integration.""" if self.size.count(0.0) == 0: - return self.surfaces_with_exclusion(**self.dict()) + return self.surfaces_with_exclusion(**self.model_dump()) return [self] - @pydantic.root_validator(skip_on_failure=True) - def normal_dir_exists_for_surface(cls, values): + @model_validator(mode="after") + def normal_dir_exists_for_surface(self) -> Self: """If the monitor is a surface, set default ``normal_dir`` if not provided. If the monitor is a box, warn that ``normal_dir`` is relevant only for surfaces.""" - normal_dir = values.get("normal_dir") - name = values.get("name") - size = values.get("size") - if size.count(0.0) != 1: - if normal_dir is not None: + if self.size.count(0.0) != 1: + if self.normal_dir is not None: log.warning( "The ``normal_dir`` field is relevant only for surface monitors " - f"and will be ignored for monitor {name}, which is a box." + f"and will be ignored for monitor {self.name}, which is a box." 
) else: - if normal_dir is None: - values["normal_dir"] = "+" - return values + if self.normal_dir is None: + object.__setattr__(self, "normal_dir", "+") + return self - @pydantic.root_validator(skip_on_failure=True) - def check_excluded_surfaces(cls, values): + @model_validator(mode="after") + def check_excluded_surfaces(self) -> Self: """Error if ``exclude_surfaces`` is provided for a surface monitor.""" - exclude_surfaces = values.get("exclude_surfaces") + exclude_surfaces = self.exclude_surfaces if exclude_surfaces is None: - return values - name = values.get("name") - size = values.get("size") + return self + name = self.name + size = self.size if size.count(0.0) > 0: raise SetupError( f"Can't specify ``exclude_surfaces`` for surface monitor {name}; " "valid for box monitors only." ) - return values + return self def _storage_size_solver(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of intermediate data recorded by the monitor during a solver run.""" @@ -811,7 +814,7 @@ class ModeMonitor(AbstractModeMonitor): """ @property - def _to_solver_monitor(self): + def _to_solver_monitor(self) -> Self: """Monitor definition that will be used to define the field recording during the time stepping.""" return self.updated_copy(colocate=False) @@ -850,14 +853,14 @@ class ModeSolverMonitor(AbstractModeMonitor): ... name='mode_monitor') """ - direction: Direction = pydantic.Field( + direction: Direction = Field( "+", title="Propagation Direction", description="Direction of waveguide mode propagation along the axis defined by its normal " "dimension.", ) - fields: tuple[EMField, ...] = pydantic.Field( + fields: tuple[EMField, ...] = Field( ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"], title="Field Components", description="Collection of field components to store in the monitor. Note that some " @@ -870,19 +873,19 @@ def _stored_freqs(self) -> list[float]: """Return actually stored frequencies of the data.""" return self.mode_spec._sampling_freqs_mode_solver_data(freqs=self.freqs) - @pydantic.root_validator(skip_on_failure=True) - def set_store_fields(cls, values): + @model_validator(mode="after") + def set_store_fields(self) -> Self: """Ensure 'store_fields_direction' is compatible with 'direction'.""" - store_fields_direction = values["store_fields_direction"] - direction = values["direction"] + store_fields_direction = self.store_fields_direction + direction = self.direction if store_fields_direction is None: - values["store_fields_direction"] = direction + object.__setattr__(self, "store_fields_direction", direction) elif store_fields_direction != direction: raise ValidationError( f"The values of 'direction' ({direction}) and 'store_fields_direction' " f"({store_fields_direction}) must be equal." ) - return values + return self def storage_size(self, num_cells: int, tmesh: int) -> int: """Size of monitor storage given the number of points after discretization.""" @@ -895,27 +898,24 @@ def storage_size(self, num_cells: int, tmesh: int) -> int: class FieldProjectionSurface(Tidy3dBaseModel): - """ - Data structure to store surface monitors where near fields are recorded for - field projections. + """Data structure to store surface monitors where near fields are recorded for field projections. + Notes + ----- .. TODO add example and derivation, and more relevant links. 
See Also -------- - **Notebooks**: * `Performing near field to far field projections <../../notebooks/FieldProjections.html>`_ """ - monitor: FieldMonitor = pydantic.Field( - ..., + monitor: FieldMonitor = Field( title="Field Monitor", description=":class:`.FieldMonitor` on which near fields will be sampled and integrated.", ) - normal_dir: Direction = pydantic.Field( - ..., + normal_dir: Direction = Field( title="Normal Vector Orientation", description=":class:`.Direction` of the surface monitor's normal vector w.r.t.\ the positive x, y or z unit vectors. Must be one of '+' or '-'.", @@ -927,8 +927,9 @@ def axis(self) -> Axis: # assume that the monitor's axis is in the direction where the monitor is thinnest return self.monitor.size.index(0.0) - @pydantic.validator("monitor", always=True) - def is_plane(cls, val): + @field_validator("monitor") + @classmethod + def is_plane(cls, val: FieldMonitor) -> FieldMonitor: """Ensures that the monitor is a plane, i.e., its ``size`` attribute has exactly 1 zero""" size = val.size if size.count(0.0) != 1: @@ -941,7 +942,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): and projects them to a given set of observation points. """ - custom_origin: Optional[Coordinate] = pydantic.Field( + custom_origin: Optional[Coordinate] = Field( None, title="Local Origin", description="Local origin used for defining observation points. If ``None``, uses the " @@ -949,7 +950,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): units=MICROMETER, ) - far_field_approx: bool = pydantic.Field( + far_field_approx: bool = Field( True, title="Far Field Approximation", description="Whether to enable the far field approximation when projecting fields. " @@ -959,22 +960,20 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): "in the far field of the device.", ) - interval_space: tuple[pydantic.PositiveInt, pydantic.PositiveInt, pydantic.PositiveInt] = ( - pydantic.Field( - (1, 1, 1), - title="Spatial Interval", - description="Number of grid step intervals at which near fields are recorded for " - "projection to the far field, along each direction. If equal to 1, there will be no " - "downsampling. If greater than 1, the step will be applied, but the first and last " - "point of the monitor grid are always included. Using values greater than 1 can " - "help speed up server-side far field projections with minimal accuracy loss, " - "especially in cases where it is necessary for the grid resolution to be high for " - "the FDTD simulation, but such a high resolution is unnecessary for the purpose of " - "projecting the recorded near fields to the far field.", - ) + interval_space: tuple[PositiveInt, PositiveInt, PositiveInt] = Field( + (1, 1, 1), + title="Spatial Interval", + description="Number of grid step intervals at which near fields are recorded for " + "projection to the far field, along each direction. If equal to 1, there will be no " + "downsampling. If greater than 1, the step will be applied, but the first and last " + "point of the monitor grid are always included. 
Using values greater than 1 can " + "help speed up server-side far field projections with minimal accuracy loss, " + "especially in cases where it is necessary for the grid resolution to be high for " + "the FDTD simulation, but such a high resolution is unnecessary for the purpose of " + "projecting the recorded near fields to the far field.", ) - window_size: tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat] = pydantic.Field( + window_size: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, 0), title="Spatial filtering window size", description="Size of the transition region of the windowing function used to ensure that " @@ -991,7 +990,7 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): "and otherwise must remain (0, 0).", ) - medium: MediumType = pydantic.Field( + medium: Optional[MediumType] = Field( None, title="Projection medium", description="Medium through which to project fields. Generally, the fields should be " @@ -1001,12 +1000,12 @@ class AbstractFieldProjectionMonitor(SurfaceIntegrationMonitor, FreqMonitor): "non-default ``medium``.", ) - @pydantic.validator("window_size", always=True) - @skip_if_fields_missing(["size", "name"]) - def window_size_for_surface(cls, val, values): + @model_validator(mode="after") + def window_size_for_surface(self) -> Self: """Ensures that windowing is applied for surface monitors only.""" - size = values.get("size") - name = values.get("name") + val = self.window_size + size = self.size + name = self.name if size.count(0.0) != 1: if val != (0, 0): @@ -1014,16 +1013,19 @@ def window_size_for_surface(cls, val, values): f"A non-zero 'window_size' cannot be used for projection monitor '{name}'. " "Windowing can be applied only for surface projection monitors." ) - return val + return self - @pydantic.validator("window_size", always=True) - @skip_if_fields_missing(["name"]) - def window_size_leq_one(cls, val, values): + @field_validator("window_size") + @classmethod + def window_size_leq_one( + cls: type[AbstractFieldProjectionMonitor], + val: tuple[float, float], + info: FieldValidationInfo, + ) -> tuple[float, float]: """Ensures that each component of the window size is less than or equal to 1.""" - name = values.get("name") if val[0] > 1 or val[1] > 1: raise ValidationError( - f"Each component of 'window_size' for monitor '{name}' " + f"Each component of 'window_size' for monitor '{info.data.get('name')}' " "must be less than or equal to 1." ) return val @@ -1215,23 +1217,21 @@ class FieldProjectionAngleMonitor(AbstractFieldProjectionMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_: For far field projections in the context of periodic boundary conditions. """ - proj_distance: float = pydantic.Field( + proj_distance: float = Field( 1e6, title="Projection Distance", description="Radial distance of the projection points from ``local_origin``.", units=MICROMETER, ) - theta: ObsGridArray = pydantic.Field( - ..., + theta: ObsGridArray = Field( title="Polar Angles", description="Polar angles with respect to the global z axis, relative to the location of " "``local_origin``, at which to project fields.", units=RADIAN, ) - phi: ObsGridArray = pydantic.Field( - ..., + phi: ObsGridArray = Field( title="Azimuth Angles", description="Azimuth angles with respect to the global z axis, relative to the location of " "``local_origin``, at which to project fields.",
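An aside on `window_size_leq_one` above (the same point applies to `_warn_num_freqs` and `_warn_num_modes` in this diff): in pydantic v2 a `field_validator` no longer receives the v1 `values` dict; `info.field_name` is only the name of the field currently being validated, while sibling fields validated earlier, such as the monitor's `name`, live in `info.data`. A minimal, self-contained sketch of that lookup; `MonitorSketch` and its fields are illustrative stand-ins, not tidy3d classes:

    from pydantic import BaseModel, ValidationInfo, field_validator

    class MonitorSketch(BaseModel):
        name: str  # declared (and validated) before 'window_size', so present in info.data
        window_size: tuple[float, float]

        @field_validator("window_size")
        @classmethod
        def _leq_one(cls, val: tuple[float, float], info: ValidationInfo) -> tuple[float, float]:
            # info.field_name == "window_size"; the owning monitor's name
            # comes from the already-validated sibling data instead
            if val[0] > 1 or val[1] > 1:
                raise ValueError(f"'window_size' of monitor '{info.data.get('name')}' exceeds 1")
            return val

    MonitorSketch(name="proj", window_size=(0.5, 0.5))  # ok; (2.0, 0.0) would raise

@@ -1277,7 +1277,7 @@ class DirectivityMonitor(MicrowaveBaseModel, FieldProjectionAngleMonitor, FluxMo ... 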
) """ - far_field_approx: Literal[True] = pydantic.Field( + far_field_approx: Literal[True] = Field( True, title="Far Field Approximation", description="Directivity calculations require the far field approximation. " @@ -1398,13 +1398,12 @@ class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - proj_axis: Axis = pydantic.Field( - ..., + proj_axis: Axis = Field( title="Projection Plane Axis", description="Axis along which the observation plane is oriented.", ) - proj_distance: float = pydantic.Field( + proj_distance: float = Field( 1e6, title="Projection Distance", description="Signed distance of the projection plane along ``proj_axis``. " @@ -1412,8 +1411,7 @@ class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): units=MICROMETER, ) - x: ObsGridArray = pydantic.Field( - ..., + x: ObsGridArray = Field( title="Local x Observation Coordinates", description="Local x observation coordinates w.r.t. ``local_origin`` and ``proj_axis``. " "When ``proj_axis`` is 0, this corresponds to the global y axis. " @@ -1422,8 +1420,7 @@ class FieldProjectionCartesianMonitor(AbstractFieldProjectionMonitor): units=MICROMETER, ) - y: ObsGridArray = pydantic.Field( - ..., + y: ObsGridArray = Field( title="Local y Observation Coordinates", description="Local y observation coordinates w.r.t. ``local_origin`` and ``proj_axis``. " "When ``proj_axis`` is 0, this corresponds to the global z axis. " @@ -1507,21 +1504,19 @@ class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - proj_axis: Axis = pydantic.Field( - ..., + proj_axis: Axis = Field( title="Projection Plane Axis", description="Axis along which the observation plane is oriented.", ) - proj_distance: float = pydantic.Field( + proj_distance: float = Field( 1e6, title="Projection Distance", description="Radial distance of the projection points from ``local_origin``.", units=MICROMETER, ) - ux: ObsGridArray = pydantic.Field( - ..., + ux: ObsGridArray = Field( title="Normalized kx", description="Local x component of wave vectors on the observation plane, " "relative to ``local_origin`` and oriented with respect to ``proj_axis``, " @@ -1529,8 +1524,7 @@ class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): "associated with the background medium. Must be in the range [-1, 1].", ) - uy: ObsGridArray = pydantic.Field( - ..., + uy: ObsGridArray = Field( title="Normalized ky", description="Local y component of wave vectors on the observation plane, " "relative to ``local_origin`` and oriented with respect to ``proj_axis``, " @@ -1538,17 +1532,17 @@ class FieldProjectionKSpaceMonitor(AbstractFieldProjectionMonitor): "associated with the background medium. 
Must be in the range [-1, 1].", ) - @pydantic.root_validator() - def reciprocal_vector_range(cls, values): + @model_validator(mode="after") + def reciprocal_vector_range(self) -> Self: """Ensure that ux, uy are in [-1, 1].""" - maxabs_ux = max(list(values.get("ux")), key=abs) - maxabs_uy = max(list(values.get("uy")), key=abs) - name = values.get("name") + maxabs_ux = max(list(self.ux), key=abs) + maxabs_uy = max(list(self.uy), key=abs) + name = self.name if maxabs_ux > 1: raise SetupError(f"Entries of 'ux' must lie in the range [-1, 1] for monitor {name}.") if maxabs_uy > 1: raise SetupError(f"Entries of 'uy' must lie in the range [-1, 1] for monitor {name}.") - return values + return self def storage_size(self, num_cells: int, tmesh: ArrayFloat1D) -> int: """Size of monitor storage given the number of points after discretization.""" @@ -1597,7 +1591,7 @@ class DiffractionMonitor(PlanarMonitor, FreqMonitor): * `Multilevel blazed diffraction grating <../../notebooks/GratingEfficiency.html>`_ """ - normal_dir: Direction = pydantic.Field( + normal_dir: Direction = Field( "+", title="Normal Vector Orientation", description="Direction of the surface monitor's normal vector w.r.t. " @@ -1605,7 +1599,7 @@ class DiffractionMonitor(PlanarMonitor, FreqMonitor): "Defaults to ``'+'`` if not provided.", ) - colocate: Literal[False] = pydantic.Field( + colocate: Literal[False] = Field( False, title="Colocate Fields", description="Defines whether fields are colocated to grid cell boundaries (i.e. to the " @@ -1613,8 +1607,9 @@ class DiffractionMonitor(PlanarMonitor, FreqMonitor): "monitors depending on their specific function.", ) - @pydantic.validator("size", always=True) - def diffraction_monitor_size(cls, val): + @field_validator("size") + @classmethod + def diffraction_monitor_size(cls: type[DiffractionMonitor], val: Size) -> Size: """Ensure that the monitor is infinite in the transverse direction.""" if val.count(inf) != 2: raise SetupError( diff --git a/tidy3d/components/nonlinear.py b/tidy3d/components/nonlinear.py index 084a2c1bb9..b178711dcb 100644 --- a/tidy3d/components/nonlinear.py +++ b/tidy3d/components/nonlinear.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt, field_validator from tidy3d.constants import MICROMETER, SECOND, VOLT, WATT from tidy3d.exceptions import SetupError, ValidationError @@ -54,7 +54,7 @@ def _validate_medium_type(self, medium: AbstractMedium) -> None: def _validate_medium(self, medium: AbstractMedium) -> None: """Any additional validation that depends on the medium""" - def _validate_medium_freqs(self, medium: AbstractMedium, freqs: list[pd.PositiveFloat]) -> None: + def _validate_medium_freqs(self, medium: AbstractMedium, freqs: list[PositiveFloat]) -> None: """Any additional validation that depends on the central frequencies of the sources.""" @property @@ -103,14 +103,14 @@ class NonlinearSusceptibility(NonlinearModel): >>> nonlinear_susceptibility = NonlinearSusceptibility(chi3=1) """ - chi3: float = pd.Field( + chi3: float = Field( 0, title="Chi3", description=":math:`\\chi_3` nonlinear susceptibility.", units=f"{MICROMETER}^2 / {VOLT}^2", ) - numiters: pd.PositiveInt = pd.Field( + numiters: Optional[PositiveInt] = Field( None, title="Number of iterations", description="Deprecated. 
The old usage ``nonlinear_spec=model`` with ``model.numiters`` " @@ -119,8 +119,9 @@ class NonlinearSusceptibility(NonlinearModel): "usage, this parameter is ignored, and ``NonlinearSpec.num_iters`` is used instead.", ) - @pd.validator("numiters", always=True) - def _validate_numiters(cls, val): + @field_validator("numiters") + @classmethod + def _validate_numiters(cls, val: Optional[PositiveInt]) -> Optional[PositiveInt]: """Check that numiters is not too large.""" if val is None: return val @@ -183,51 +184,51 @@ class TwoPhotonAbsorption(NonlinearModel): >>> tpa_model = TwoPhotonAbsorption(beta=1) """ - beta: float = pd.Field( + beta: float = Field( 0, title="TPA coefficient", description="Coefficient for two-photon absorption (TPA).", units=f"{MICROMETER} / {WATT}", ) - tau: pd.NonNegativeFloat = pd.Field( + tau: NonNegativeFloat = Field( 0, title="Carrier lifetime", description="Lifetime for the free carriers created by two-photon absorption (TPA).", units=f"{SECOND}", ) - sigma: pd.NonNegativeFloat = pd.Field( + sigma: NonNegativeFloat = Field( 0, title="FCA cross section", description="Total cross section for free-carrier absorption (FCA). " "Contains contributions from electrons and from holes.", units=f"{MICROMETER}^2", ) - e_e: pd.NonNegativeFloat = pd.Field( + e_e: NonNegativeFloat = Field( 1, title="Electron exponent", description="Exponent for the free electron refractive index shift in the free-carrier plasma dispersion (FCPD).", ) - e_h: pd.NonNegativeFloat = pd.Field( + e_h: NonNegativeFloat = Field( 1, title="Hole exponent", description="Exponent for the free hole refractive index shift in the free-carrier plasma dispersion (FCPD).", ) - c_e: float = pd.Field( + c_e: float = Field( 0, title="Electron coefficient", description="Coefficient for the free electron refractive index shift in the free-carrier plasma dispersion (FCPD).", units=f"{MICROMETER}^(3 e_e)", ) - c_h: float = pd.Field( + c_h: float = Field( 0, title="Hole coefficient", description="Coefficient for the free hole refractive index shift in the free-carrier plasma dispersion (FCPD).", units=f"{MICROMETER}^(3 e_h)", ) - n0: Optional[float] = pd.Field( + n0: Optional[float] = Field( None, title="Linear refractive index", description="Real linear refractive index of the medium, computed for instance using " @@ -235,7 +236,7 @@ class TwoPhotonAbsorption(NonlinearModel): "frequencies of the simulation sources (as long as these are all equal).", ) - freq0: Optional[pd.PositiveFloat] = pd.Field( + freq0: Optional[PositiveFloat] = Field( None, title="Central frequency", description="Central frequency, used to calculate the energy of the free-carriers " @@ -311,14 +312,14 @@ class KerrNonlinearity(NonlinearModel): >>> kerr_model = KerrNonlinearity(n2=1) """ - n2: float = pd.Field( + n2: float = Field( 0, title="Nonlinear refractive index", description="Nonlinear refractive index in the Kerr nonlinearity.", units=f"{MICROMETER}^2 / {WATT}", ) - n0: Optional[float] = pd.Field( + n0: Optional[float] = Field( None, title="Complex linear refractive index", description="Complex linear refractive index of the medium, computed for instance using " @@ -346,7 +347,7 @@ class NonlinearSpec(ABC, Tidy3dBaseModel): >>> medium = Medium(permittivity=2, nonlinear_spec=nonlinear_spec) """ - models: tuple[NonlinearModelType, ...] = pd.Field( + models: tuple[NonlinearModelType, ...] = Field( (), title="Nonlinear models", description="The nonlinear models present in this nonlinear spec. 
" @@ -354,14 +355,17 @@ class NonlinearSpec(ABC, Tidy3dBaseModel): "Multiple nonlinear models of the same type are not allowed.", ) - num_iters: pd.PositiveInt = pd.Field( + num_iters: PositiveInt = Field( NONLINEAR_DEFAULT_NUM_ITERS, title="Number of iterations", description="Number of iterations for solving nonlinear constitutive relation.", ) - @pd.validator("models", always=True) - def _no_duplicate_models(cls, val): + @field_validator("models") + @classmethod + def _no_duplicate_models( + cls, val: Optional[tuple[NonlinearModelType, ...]] + ) -> Optional[tuple[NonlinearModelType, ...]]: """Ensure each type of model appears at most once.""" if val is None: return val @@ -375,8 +379,9 @@ def _no_duplicate_models(cls, val): ) return val - @pd.validator("num_iters", always=True) - def _validate_num_iters(cls, val, values): + @field_validator("num_iters") + @classmethod + def _validate_num_iters(cls, val: PositiveInt) -> PositiveInt: """Check that num_iters is not too large.""" if val > NONLINEAR_MAX_NUM_ITERS: raise ValidationError( @@ -393,8 +398,11 @@ def aux_fields(self) -> list[str]: fields += model.aux_fields return fields - @pd.validator("models", always=True) - def _consistent_models(cls, val): + @field_validator("models") + @classmethod + def _consistent_models( + cls, val: Optional[tuple[NonlinearModelType, ...]] + ) -> Optional[tuple[NonlinearModelType, ...]]: """Ensure that parameters shared between models are consistent.""" if val is None: return val diff --git a/tidy3d/components/parameter_perturbation.py b/tidy3d/components/parameter_perturbation.py index 8b4cf51392..6e2579aa63 100644 --- a/tidy3d/components/parameter_perturbation.py +++ b/tidy3d/components/parameter_perturbation.py @@ -4,15 +4,12 @@ import functools from abc import ABC, abstractmethod -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, TypeVar, Union import numpy as np -import pydantic.v1 as pd -import xarray as xr +from pydantic import Field, NonNegativeFloat, model_validator -from tidy3d.components.data.validators import validate_no_nans -from tidy3d.components.types import TYPE_TAG_STR, ArrayLike, Ax, Complex, FieldVal, InterpMethod -from tidy3d.components.viz import add_ax_if_none +from tidy3d.components.types.base import ArrayComplex, ArrayFloat, discriminated_union from tidy3d.constants import C_0, CMCUBE, EPSILON_0, HERTZ, KELVIN, PERCMCUBE, inf from tidy3d.exceptions import DataError from tidy3d.log import log @@ -32,6 +29,18 @@ _get_numpy_array, _zeros_like, ) +from .data.validators import validate_no_nans +from .types import Complex, InterpMethod +from .viz import add_ax_if_none + +if TYPE_CHECKING: + from typing import Callable + + import xarray as xr + + from tidy3d.compat import Self + + from .types import Ax, FieldVal """ Generic perturbation classes """ @@ -50,7 +59,9 @@ def is_complex(self) -> bool: """Whether perturbation is complex valued.""" @staticmethod - def _linear_range(interval: tuple[float, float], ref: float, coeff: Union[float, Complex]): + def _linear_range( + interval: tuple[float, float], ref: float, coeff: Union[float, Complex] + ) -> Union[np.ndarray, tuple[float, float]]: """Find value range for a linear perturbation.""" if coeff in (0, 0j): # to avoid 0*inf return np.array([0, 0]) @@ -58,8 +69,8 @@ def _linear_range(interval: tuple[float, float], ref: float, coeff: Union[float, @staticmethod def _get_val( - field: Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], val: FieldVal - ) -> Union[ArrayLike[float], 
ArrayLike[Complex], CustomSpatialDataType]: + field: Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], val: FieldVal + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Get specified value from a field.""" if val == "real": @@ -86,21 +97,21 @@ def _get_val( """ Elementary heat perturbation classes """ +HeatPerturbationT = TypeVar("HeatPerturbationT", bound="HeatPerturbation") +HeatSampleReturn = Union[ArrayFloat, ArrayComplex, CustomSpatialDataType] + + def ensure_temp_in_range( sample: Callable[ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + [HeatPerturbationT, Union[ArrayFloat, CustomSpatialDataType]], HeatSampleReturn ], -) -> Callable[ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], -]: +) -> Callable[[HeatPerturbationT, Union[ArrayFloat, CustomSpatialDataType]], HeatSampleReturn]: """Decorate ``sample`` to log warning if temperature supplied is out of bounds.""" @functools.wraps(sample) def _sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self: HeatPerturbationT, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> HeatSampleReturn: """New sample function.""" if np.iscomplexobj(temperature): @@ -121,7 +132,7 @@ def _sample( class HeatPerturbation(AbstractPerturbation): """Abstract class for heat perturbation.""" - temperature_range: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + temperature_range: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, inf), title="Temperature range", description="Temperature range in which perturbation model is valid.", @@ -130,14 +141,14 @@ class HeatPerturbation(AbstractPerturbation): @abstractmethod def sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation. Parameters ---------- temperature : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Temperature sample point(s). Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] """ @add_ax_if_none def plot( self, - temperature: ArrayLike[float], + temperature: ArrayFloat, val: FieldVal = "real", - ax: Ax = None, + ax: Optional[Ax] = None, ) -> Ax: """Plot perturbation using provided temperature sample points. Parameters ---------- - temperature : ArrayLike[float] + temperature : ArrayFloat Array of temperature sample points. val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real' Which part of the field to plot. - ax : matplotlib.axes._subplots.Axes = None + ax : Optional[matplotlib.axes._subplots.Axes] = None Matplotlib axes to plot on, if not specified, one is created. Returns
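For context on the `HeatPerturbationT` type variable above: it needs a name distinct from `HeatPerturbationType`, which this module later rebinds to the discriminated union of the concrete perturbation models, and binding the decorator's signature to a TypeVar keeps the decorated `sample` generic over the concrete subclass. A reduced sketch of the same decorator shape, with stand-in names:

    import functools
    from typing import Callable, TypeVar

    class PerturbationSketch:  # stand-in for HeatPerturbation
        def sample(self, temperature: float) -> float:
            return temperature

    PerturbationT = TypeVar("PerturbationT", bound=PerturbationSketch)

    def ensure_in_range(
        sample: Callable[[PerturbationT, float], float],
    ) -> Callable[[PerturbationT, float], float]:
        """Wrap ``sample`` to report out-of-range inputs, preserving the subclass type."""

        @functools.wraps(sample)
        def _sample(self: PerturbationT, temperature: float) -> float:
            if temperature < 0:
                print("temperature outside validity range")
            return sample(self, temperature)

        return _sample

@@ -224,15 +235,13 @@ class LinearHeatPerturbation(HeatPerturbation): ... 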
) """ - temperature_ref: pd.NonNegativeFloat = pd.Field( - ..., + temperature_ref: NonNegativeFloat = Field( title="Reference temperature", description="Temperature at which perturbation is zero.", units=KELVIN, ) - coeff: Union[float, Complex] = pd.Field( - ..., + coeff: Union[float, Complex] = Field( title="Thermo-optic Coefficient", description="Sensitivity (derivative) of perturbation with respect to temperature.", units=f"1/{KELVIN}", @@ -245,14 +254,14 @@ def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Comple @ensure_temp_in_range def sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at temperature points. Parameters ---------- temperature : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -262,8 +271,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -312,13 +321,12 @@ class CustomHeatPerturbation(HeatPerturbation): ... ) """ - perturbation_values: HeatDataArray = pd.Field( - ..., + perturbation_values: HeatDataArray = Field( title="Perturbation Values", description="Sampled perturbation values.", ) - temperature_range: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + temperature_range: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Temperature range", description="Temperature range in which perturbation model is valid. For " @@ -327,7 +335,7 @@ class CustomHeatPerturbation(HeatPerturbation): units=KELVIN, ) - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "linear", title="Interpolation method", description="Interpolation method to obtain perturbation values between sample points.", @@ -340,11 +348,11 @@ def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Comple """Range of possible parameter perturbation values.""" return np.min(self.perturbation_values).item(), np.max(self.perturbation_values).item() - @pd.root_validator(skip_on_failure=True) - def compute_temperature_range(cls, values): + @model_validator(mode="after") + def compute_temperature_range(self) -> Self: """Compute and set temperature range based on provided ``perturbation_values``.""" - perturbation_values = values["perturbation_values"] + perturbation_values = self.perturbation_values # .item() to convert to a scalar temperature_range = ( @@ -352,30 +360,27 @@ def compute_temperature_range(cls, values): np.max(perturbation_values.coords["T"]).item(), ) - if ( - values["temperature_range"] is not None - and values["temperature_range"] != temperature_range - ): + if self.temperature_range is not None and self.temperature_range != temperature_range: log.warning( "Temperature range for 'CustomHeatPerturbation' is calculated automatically " "based on provided 'perturbation_values'. Provided 'temperature_range' will be " "overwritten." 
) - values.update({"temperature_range": temperature_range}) + object.__setattr__(self, "temperature_range", temperature_range) - return values + return self @ensure_temp_in_range def sample( - self, temperature: Union[ArrayLike[float], CustomSpatialDataType] - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self, temperature: Union[ArrayFloat, CustomSpatialDataType] + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at provided temperature points. Parameters ---------- temperature : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -385,8 +390,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -418,7 +423,7 @@ def is_complex(self) -> bool: return np.iscomplexobj(self.perturbation_values) -HeatPerturbationType = Union[LinearHeatPerturbation, CustomHeatPerturbation] +HeatPerturbationType = discriminated_union(Union[LinearHeatPerturbation, CustomHeatPerturbation]) """ Elementary charge perturbation classes """ @@ -427,26 +432,28 @@ def is_complex(self) -> bool: def ensure_charge_in_range( sample: Callable[ [ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], CustomSpatialDataType], + ChargePerturbation, + Union[ArrayFloat, CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], ], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], ], ) -> Callable[ [ - Union[ArrayLike[float], CustomSpatialDataType], - Union[ArrayLike[float], CustomSpatialDataType], + ChargePerturbation, + Union[ArrayFloat, CustomSpatialDataType], + Union[ArrayFloat, CustomSpatialDataType], ], - Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType], + Union[ArrayFloat, ArrayComplex, CustomSpatialDataType], ]: """Decorate ``sample`` to log warning if charge supplied is out of bounds.""" @functools.wraps(sample) def _sample( - self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + self: ChargePerturbation, + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """New sample function.""" # disable complex input @@ -483,13 +490,13 @@ def _sample( class ChargePerturbation(AbstractPerturbation): """Abstract class for charge perturbation.""" - electron_range: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + electron_range: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, inf), title="Electron Density Range", description="Range of electron densities in which perturbation model is valid.", ) - hole_range: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + hole_range: tuple[NonNegativeFloat, NonNegativeFloat] = Field( (0, inf), title="Hole Density Range", description="Range of hole densities in which perturbation model is valid.", @@ -498,22 +505,22 @@ class ChargePerturbation(AbstractPerturbation): @abstractmethod def sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float],
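``discriminated_union`` above is a repo-local helper whose definition is not part of this diff; the standard pydantic v2 building block it presumably wraps is an ``Annotated`` union carrying a ``discriminator``, sketched here with illustrative models:

    from typing import Annotated, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter


    class Linear(BaseModel):
        type: Literal["Linear"] = "Linear"
        coeff: float


    class Custom(BaseModel):
        type: Literal["Custom"] = "Custom"
        values: tuple[float, ...]


    # the "type" tag selects the member model during validation
    PerturbationLike = Annotated[Union[Linear, Custom], Field(discriminator="type")]

    obj = TypeAdapter(PerturbationLike).validate_python({"type": "Linear", "coeff": 2.0})
    assert isinstance(obj, Linear)

The tag lets pydantic dispatch directly to the right member instead of trying each one in turn, which also produces clearer validation errors.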
ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation. Parameters ---------- electron_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Electron density sample point(s). hole_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -528,8 +535,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -540,22 +547,22 @@ def sample( @add_ax_if_none def plot( self, - electron_density: ArrayLike[float], - hole_density: ArrayLike[float], + electron_density: ArrayFloat, + hole_density: ArrayFloat, val: FieldVal = "real", - ax: Ax = None, + ax: Optional[Ax] = None, ) -> Ax: """Plot perturbation using provided electron and hole density sample points. Parameters ---------- - electron_density : Union[ArrayLike[float], CustomSpatialDataType] + electron_density : Union[ArrayFloat, CustomSpatialDataType] Array of electron density sample points. - hole_density : Union[ArrayLike[float], CustomSpatialDataType] + hole_density : Union[ArrayFloat, CustomSpatialDataType] Array of hole density sample points. val : Literal['real', 'imag', 'abs', 'abs^2', 'phase'] = 'real' Which part of the field to plot. - ax : matplotlib.axes._subplots.Axes = None + ax : Optional[matplotlib.axes._subplots.Axes] = None Matplotlib axes to plot on, if not specified, one is created. Returns @@ -631,30 +638,26 @@ class LinearChargePerturbation(ChargePerturbation): ... ) """ - electron_ref: pd.NonNegativeFloat = pd.Field( - ..., + electron_ref: NonNegativeFloat = Field( title="Reference Electron Density", description="Electron density value at which there is no perturbation due to electrons's " "presence.", units=PERCMCUBE, ) - hole_ref: pd.NonNegativeFloat = pd.Field( - ..., + hole_ref: NonNegativeFloat = Field( title="Reference Hole Density", description="Hole density value at which there is no perturbation due to holes' presence.", units=PERCMCUBE, ) - electron_coeff: float = pd.Field( - ..., + electron_coeff: float = Field( title="Sensitivity to Electron Density", description="Sensitivity (derivative) of perturbation with respect to electron density.", units=CMCUBE, ) - hole_coeff: float = pd.Field( - ..., + hole_coeff: float = Field( title="Sensitivity to Hole Density", description="Sensitivity (derivative) of perturbation with respect to hole density.", units=CMCUBE, @@ -676,22 +679,22 @@ def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Comple @ensure_charge_in_range def sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at electron and hole density points. 
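Reading the field definitions above (reference densities at which the perturbation vanishes, plus sensitivities with respect to each carrier density), the linear model presumably evaluates to ``delta = electron_coeff * (n - electron_ref) + hole_coeff * (p - hole_ref)``. A small numeric sketch under that assumption, with made-up coefficient values:

    import numpy as np

    # illustrative values only -- not tidy3d defaults
    electron_ref, hole_ref = 0.0, 0.0                # densities with zero perturbation (1/cm^3)
    electron_coeff, hole_coeff = -8.8e-22, -8.5e-18  # sensitivities (cm^3)

    n = np.array([1e17, 1e18])  # electron densities to sample
    p = np.array([1e17, 1e18])  # hole densities to sample

    # assumed linear form implied by the field descriptions above
    delta = electron_coeff * (n - electron_ref) + hole_coeff * (p - hole_ref)
    print(delta)  # perturbation at each (n, p) pair
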
Parameters ---------- electron_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Electron density sample point(s). hole_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -707,8 +710,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -791,13 +794,12 @@ class CustomChargePerturbation(ChargePerturbation): ... ) """ - perturbation_values: ChargeDataArray = pd.Field( - ..., + perturbation_values: ChargeDataArray = Field( title="Perturbation Values", description="2D array (vs electron and hole densities) of sampled perturbation values.", ) - electron_range: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + electron_range: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Electron Density Range", description="Range of electron densities in which perturbation model is valid. For " @@ -805,7 +807,7 @@ class CustomChargePerturbation(ChargePerturbation): "provided ``perturbation_values``", ) - hole_range: tuple[pd.NonNegativeFloat, pd.NonNegativeFloat] = pd.Field( + hole_range: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = Field( None, title="Hole Density Range", description="Range of hole densities in which perturbation model is valid. For " @@ -813,7 +815,7 @@ class CustomChargePerturbation(ChargePerturbation): "provided ``perturbation_values``", ) - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "linear", title="Interpolation method", description="Interpolation method to obtain perturbation values between sample points.", @@ -826,13 +828,13 @@ def perturbation_range(self) -> Union[tuple[float, float], tuple[complex, comple """Range of possible parameter perturbation values.""" return np.min(self.perturbation_values).item(), np.max(self.perturbation_values).item() - @pd.root_validator(skip_on_failure=True) - def compute_eh_ranges(cls, values): + @model_validator(mode="after") + def compute_eh_ranges(self) -> Self: """Compute and set electron and hole density ranges based on provided ``perturbation_values``. """ - perturbation_values = values["perturbation_values"] + perturbation_values = self.perturbation_values electron_range = ( np.min(perturbation_values.coords["n"]).item(), @@ -844,43 +846,44 @@ def compute_eh_ranges(cls, values): np.max(perturbation_values.coords["p"]).item(), ) - if values["electron_range"] is not None and electron_range != values["electron_range"]: + if self.electron_range is not None and electron_range != self.electron_range: log.warning( "Electron density range for 'CustomChargePerturbation' is calculated automatically " "based on provided 'perturbation_values'. Provided 'electron_range' will be " "overwritten." ) - if values["hole_range"] is not None and hole_range != values["hole_range"]: + if self.hole_range is not None and hole_range != self.hole_range: log.warning( "Hole density range for 'CustomChargePerturbation' is calculated automatically " "based on provided 'perturbation_values'. Provided 'hole_range' will be " "overwritten."
) - values.update({"electron_range": electron_range, "hole_range": hole_range}) + object.__setattr__(self, "electron_range", electron_range) + object.__setattr__(self, "hole_range", hole_range) - return values + return self @ensure_charge_in_range def sample( self, - electron_density: Union[ArrayLike[float], CustomSpatialDataType], - hole_density: Union[ArrayLike[float], CustomSpatialDataType], - ) -> Union[ArrayLike[float], ArrayLike[Complex], CustomSpatialDataType]: + electron_density: Union[ArrayFloat, CustomSpatialDataType], + hole_density: Union[ArrayFloat, CustomSpatialDataType], + ) -> Union[ArrayFloat, ArrayComplex, CustomSpatialDataType]: """Sample perturbation at electron and hole density points. Parameters ---------- electron_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, ] Electron density sample point(s). hole_density : Union[ - ArrayLike[float], + ArrayFloat, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -896,8 +899,8 @@ def sample( Returns ------- Union[ - ArrayLike[float], - ArrayLike[complex], + ArrayFloat, + ArrayComplex, :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, @@ -926,10 +929,10 @@ def sample( # clip to allowed values # (this also implicitly converts python arrays into numpy) - e_vals = np.core.umath.clip( + e_vals = np._core.umath.clip( electron_density, self.electron_range[0], self.electron_range[1] ) - h_vals = np.core.umath.clip(hole_density, self.hole_range[0], self.hole_range[1]) + h_vals = np._core.umath.clip(hole_density, self.hole_range[0], self.hole_range[1]) # we cannot pass UnstructuredGridDataset directly into xarray interp # thus we need to explicitly grab the underlying xarray @@ -970,9 +973,10 @@ def is_complex(self) -> bool: return np.iscomplexobj(self.perturbation_values) -ChargePerturbationType = Union[LinearChargePerturbation, CustomChargePerturbation] - -PerturbationType = Union[HeatPerturbationType, ChargePerturbationType] +ChargePerturbationType = discriminated_union( + Union[LinearChargePerturbation, CustomChargePerturbation] +) +PerturbationType = discriminated_union(Union[HeatPerturbationType, ChargePerturbationType]) class ParameterPerturbation(Tidy3dBaseModel): @@ -998,26 +1002,24 @@ class ParameterPerturbation(Tidy3dBaseModel): >>> param_perturb = ParameterPerturbation(heat=heat_perturb, charge=charge_perturb) """ - heat: HeatPerturbationType = pd.Field( + heat: Optional[HeatPerturbationType] = Field( None, title="Heat Perturbation", description="Heat perturbation to apply.", - discriminator=TYPE_TAG_STR, ) - charge: ChargePerturbationType = pd.Field( + charge: Optional[ChargePerturbationType] = Field( None, title="Charge Perturbation", description="Charge perturbation to apply.", - discriminator=TYPE_TAG_STR, ) - @pd.root_validator(skip_on_failure=True) - def _check_not_empty(cls, values): + @model_validator(mode="after") + def _check_not_empty(self) -> Self: """Check that perturbation model is not empty.""" - heat = values.get("heat") - charge = values.get("charge") + heat = self.heat + charge = self.charge if heat is None and charge is None: raise DataError( @@ -1025,7 +1027,7 @@ def _check_not_empty(cls, values): "simultaneously 'None'."
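The ``object.__setattr__`` calls above are the standard workaround for mutating a frozen pydantic v2 model inside an after-validator, where plain attribute assignment would raise. A minimal sketch of the pattern on an illustrative frozen model:

    from typing import Optional

    from pydantic import BaseModel, ConfigDict, model_validator
    from typing_extensions import Self


    class Ranged(BaseModel):
        model_config = ConfigDict(frozen=True)

        values: tuple[float, ...]
        value_range: Optional[tuple[float, float]] = None

        @model_validator(mode="after")
        def _compute_range(self) -> Self:
            computed = (min(self.values), max(self.values))
            # frozen models reject normal assignment, so bypass it deliberately
            object.__setattr__(self, "value_range", computed)
            return self


    assert Ranged(values=(3.0, 1.0, 2.0)).value_range == (1.0, 3.0)
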
) - return values + return self @cached_property def perturbation_list(self) -> list[PerturbationType]: @@ -1048,10 +1050,10 @@ def perturbation_range(self) -> Union[tuple[float, float], tuple[Complex, Comple @staticmethod def _zeros_like( - T: CustomSpatialDataType = None, - n: CustomSpatialDataType = None, - p: CustomSpatialDataType = None, - ): + T: Optional[CustomSpatialDataType] = None, + n: Optional[CustomSpatialDataType] = None, + p: Optional[CustomSpatialDataType] = None, + ) -> CustomSpatialDataType: """Check that fields have the same coordinates and return an array field with zeros.""" template = None for field in [T, n, p]: @@ -1073,9 +1075,9 @@ def _zeros_like( def apply_data( self, - temperature: CustomSpatialDataType = None, - electron_density: CustomSpatialDataType = None, - hole_density: CustomSpatialDataType = None, + temperature: Optional[CustomSpatialDataType] = None, + electron_density: Optional[CustomSpatialDataType] = None, + hole_density: Optional[CustomSpatialDataType] = None, ) -> CustomSpatialDataType: """Sample perturbations on provided heat and/or charge data. At least one of ``temperature``, ``electron_density``, and ``hole_density`` must be not ``None``. @@ -1083,23 +1085,23 @@ def apply_data( Parameters ---------- - temperature : Union[ + temperature : Optional[Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, - ] = None + ]] = None Temperature field data. - electron_density : Union[ + electron_density : Optional[Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, - ] = None + ]] = None Electron density field data. - hole_density : Union[ + hole_density : Optional[Union[ :class:`.SpatialDataArray`, :class:`.TriangularGridDataset`, :class:`.TetrahedralGridDataset`, - ] = None + ]] = None Hole density field data. Returns @@ -1158,24 +1160,24 @@ class PermittivityPerturbation(Tidy3dBaseModel): >>> permittivity_pb = PermittivityPerturbation(delta_eps=delta_eps, delta_sigma=delta_sigma) """ - delta_eps: Optional[ParameterPerturbation] = pd.Field( + delta_eps: Optional[ParameterPerturbation] = Field( None, title="Permittivity Perturbation", description="Perturbation model for permittivity.", ) - delta_sigma: Optional[ParameterPerturbation] = pd.Field( + delta_sigma: Optional[ParameterPerturbation] = Field( None, title="Conductivity Perturbation", description="Perturbation model for conductivity.", ) - @pd.root_validator(skip_on_failure=True) - def _check_not_complex(cls, values): + @model_validator(mode="after") + def _check_not_complex(self) -> Self: """Check that perturbation values are not complex.""" - delta_eps = values.get("delta_eps") - delta_sigma = values.get("delta_sigma") + delta_eps = self.delta_eps + delta_sigma = self.delta_sigma delta_eps_complex = False if delta_eps is None else delta_eps.is_complex delta_sigma_complex = False if delta_sigma is None else delta_sigma.is_complex @@ -1186,14 +1188,14 @@ def _check_not_complex(cls, values): "complex-valued." 
) - return values + return self - @pd.root_validator(skip_on_failure=True) - def _check_not_empty(cls, values): + @model_validator(mode="after") + def _check_not_empty(self) -> Self: """Check that perturbation model is not empty.""" - delta_eps = values.get("delta_eps") - delta_sigma = values.get("delta_sigma") + delta_eps = self.delta_eps + delta_sigma = self.delta_sigma if delta_eps is None and delta_sigma is None: raise DataError( @@ -1201,9 +1203,14 @@ def _check_not_empty(cls, values): "simultaneously 'None'." ) - return values + return self - def _delta_eps_delta_sigma_ranges(self): + def _delta_eps_delta_sigma_ranges( + self, + ) -> tuple[ + Union[tuple[float, float], tuple[Complex, Complex]], + Union[tuple[float, float], tuple[Complex, Complex]], + ]: """Perturbation range of permittivity.""" delta_eps_range = (0, 0) if self.delta_eps is None else self.delta_eps.perturbation_range @@ -1214,9 +1221,9 @@ def _delta_eps_delta_sigma_ranges(self): def _sample_delta_eps_delta_sigma( self, - temperature: CustomSpatialDataType = None, - electron_density: CustomSpatialDataType = None, - hole_density: CustomSpatialDataType = None, + temperature: Optional[CustomSpatialDataType] = None, + electron_density: Optional[CustomSpatialDataType] = None, + hole_density: Optional[CustomSpatialDataType] = None, ) -> CustomSpatialDataType: """Compute effective perturbation to eps and sigma.""" @@ -1243,7 +1250,7 @@ def delta_k(self) -> ChargePerturbationType: """Return the perturbation range of the model.""" @abstractmethod - def delta_n(self) -> HeatPerturbationType: + def delta_n(self) -> HeatPerturbationType: # type: ignore[type-var] """Return the perturbation range of the model.""" @@ -1251,16 +1258,16 @@ class NedeljkovicSorefMashanovich(AbstractDeltaModel): """Nedeljkovic-Soref-Mashanovich model for the perturbation of the refractive index and extinction coefficient due to free carriers. - M. Nedeljkovic, R. Soref and G. Z. Mashanovich, "Free-Carrier Electrorefraction and Electroabsorption - Modulation Predictions for Silicon Over the 1–14- μm Infrared Wavelength Range," in IEEE Photonics - Journal, vol. 3, no. 6, pp. 1171-1180, Dec. 2011, doi: 10.1109/JPHOT.2011.2171930 + References + ---------- + .. [1] M. Nedeljkovic, R. Soref and G. Z. Mashanovich, "Free-Carrier Electrorefraction and Electroabsorption + Modulation Predictions for Silicon Over the 1–14- μm Infrared Wavelength Range," in IEEE Photonics + Journal, vol. 3, no. 6, pp. 1171-1180, Dec. 2011, doi: 10.1109/JPHOT.2011.2171930 - Example - ------- """ - perturb_coeffs: PerturbationCoefficientDataArray = pd.Field( - default=PerturbationCoefficientDataArray( + perturb_coeffs: PerturbationCoefficientDataArray = Field( + default_factory=lambda: PerturbationCoefficientDataArray( np.column_stack( [ [ @@ -1506,40 +1513,40 @@ class NedeljkovicSorefMashanovich(AbstractDeltaModel): ) ) - ref_freq: pd.NonNegativeFloat = pd.Field( + ref_freq: NonNegativeFloat = Field( title="Reference Frequency", description="Reference frequency to evaluate perturbation at (Hz).", units=HERTZ, ) - electrons_grid: ArrayLike = pd.Field( - default=np.concatenate(([0], np.logspace(-6, 22, num=200))), + electrons_grid: ArrayFloat = Field( + default_factory=lambda: np.concatenate(([0], np.logspace(-6, 22, num=200))), title="Electron concentration grid.", - descriptio="The model will be evaluated at these concentration values. Since " + description="The model will be evaluated at these concentration values. Since " "the data at these locations will later be interpolated to determine perturbations " "one should provide representative values. Usually, it is convenient to provide " "evenly spaced values in logarithmic scale to cover the whole range of concentrations, " "i.e., `np.concatenate(([0], np.logspace(-6, 22, num=200)))`.", ) - holes_grid: ArrayLike = pd.Field( - default=np.concatenate(([0], np.logspace(-6, 22, num=200))), + holes_grid: ArrayFloat = Field( + default_factory=lambda: np.concatenate(([0], np.logspace(-6, 22, num=200))), title="Hole concentration grid.", - descriptio="The model will be evaluated at these concentration values. Since " + description="The model will be evaluated at these concentration values. Since " "the data at these locations will later be interpolated to determine perturbations " "one should provide representative values. Usually, it is convenient to provide " "evenly spaced values in logarithmic scale to cover the whole range of concentrations, " "i.e., `np.concatenate(([0], np.logspace(-6, 22, num=200)))`.", ) - @pd.root_validator(skip_on_failure=True) - def _check_freq_in_range(cls, values): + @model_validator(mode="after") + def _check_freq_in_range(self) -> Self: """Check that the given frequency is within the validity range. If not, issue a warning. """ - freq = values.get("ref_freq") - wavelengths = list(values.get("perturb_coeffs").coords["wvl"]) + freq = self.ref_freq + wavelengths = list(self.perturb_coeffs.coords["wvl"]) freq_range = (C_0 / np.max(wavelengths), C_0 / np.min(wavelengths)) @@ -1549,7 +1556,7 @@ def _check_freq_in_range(cls, values): f"{freq_range[1]} Hz) of the Nedeljkovic-Soref-Mashanovich model." ) - return values + return self @cached_property def ref_wavelength(self) -> float: @@ -1647,31 +1654,30 @@ class IndexPerturbation(Tidy3dBaseModel): >>> index_pb = IndexPerturbation(delta_n=dn_pb, delta_k=dk_pb, freq=C_0) """ - delta_n: Optional[ParameterPerturbation] = pd.Field( + delta_n: Optional[ParameterPerturbation] = Field( None, title="Refractive Index Perturbation", description="Perturbation of the real part of refractive index.", ) - delta_k: Optional[ParameterPerturbation] = pd.Field( + delta_k: Optional[ParameterPerturbation] = Field( None, title="Extinction Coefficient Perturbation", description="Perturbation of the imaginary part of refractive index.", ) - freq: pd.NonNegativeFloat = pd.Field( - ..., + freq: NonNegativeFloat = Field( title="Frequency", description="Frequency to evaluate permittivity at (Hz).", units=HERTZ, ) - @pd.root_validator(skip_on_failure=True) - def _check_not_complex(cls, values): + @model_validator(mode="after") + def _check_not_complex(self) -> Self: """Check that perturbation values are not complex.""" - dn = values.get("delta_n") - dk = values.get("delta_k") + dn = self.delta_n + dk = self.delta_k dn_complex = False if dn is None else dn.is_complex dk_complex = False if dk is None else dk.is_complex @@ -1681,14 +1687,14 @@ def _check_not_complex(cls, values): "Perturbation models 'dn' and 'dk' in 'IndexPerturbation' cannot be complex-valued." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def _check_not_empty(cls, values): + @model_validator(mode="after") + def _check_not_empty(self) -> Self: """Check that perturbation model is not empty.""" - dn = values.get("delta_n") - dk = values.get("delta_k") + dn = self.delta_n + dk = self.delta_k if dn is None and dk is None: raise DataError( @@ -1696,9 +1702,11 @@ def _check_not_empty(cls, values): "simultaneously 'None'."
) - return values + return self - def _delta_eps_delta_sigma_ranges(self, n: float, k: float): + def _delta_eps_delta_sigma_ranges( + self, n: float, k: float + ) -> tuple[tuple[float, float], tuple[float, float]]: """Perturbation range of permittivity.""" omega0 = 2 * np.pi * self.freq @@ -1734,9 +1742,9 @@ def _sample_delta_eps_delta_sigma( self, n: float, k: float, - temperature: CustomSpatialDataType = None, - electron_density: CustomSpatialDataType = None, - hole_density: CustomSpatialDataType = None, + temperature: Optional[CustomSpatialDataType] = None, + electron_density: Optional[CustomSpatialDataType] = None, + hole_density: Optional[CustomSpatialDataType] = None, ) -> CustomSpatialDataType: """Compute effective perturbation to eps and sigma.""" @@ -1777,6 +1785,6 @@ def _sample_delta_eps_delta_sigma( return delta_eps, delta_sigma - def from_perturbation_delta_model(cls, deltas_model: AbstractDeltaModel) -> IndexPerturbation: + def from_perturbation_delta_model(cls, deltas_model: AbstractDeltaModel) -> Self: """Create an IndexPerturbation from a DeltaPerturbationModel.""" return IndexPerturbation(delta_n=deltas_model.delta_n, delta_k=deltas_model.delta_k) diff --git a/tidy3d/components/run_time_spec.py b/tidy3d/components/run_time_spec.py index 669ff12e4d..dad94a4b97 100644 --- a/tidy3d/components/run_time_spec.py +++ b/tidy3d/components/run_time_spec.py @@ -1,7 +1,7 @@ # Defines specifications for how long to run a simulation from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from .base import Tidy3dBaseModel @@ -26,15 +26,14 @@ class RunTimeSpec(Tidy3dBaseModel): """ - quality_factor: pd.PositiveFloat = pd.Field( - ..., + quality_factor: PositiveFloat = Field( title="Quality Factor", description="Quality factor expected in the device.
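For readers tracing the ``(n, k)`` helpers above: with ``eps = (n + 1j*k)**2``, small index changes map to permittivity and conductivity perturbations roughly as below. The ``sigma = omega * EPSILON_0 * Im(eps)`` convention is an assumption stated for illustration, not something this diff defines:

    import numpy as np

    EPSILON_0 = 8.8541878128e-12  # vacuum permittivity (F/m)


    def delta_eps_sigma(n: float, k: float, dn: float, dk: float, freq: float):
        """First-order change of eps = (n + 1j*k)**2 under (dn, dk)."""
        omega = 2 * np.pi * freq
        d_eps_real = 2 * (n * dn - k * dk)  # d Re[(n + 1j*k)**2]
        d_eps_imag = 2 * (k * dn + n * dk)  # d Im[(n + 1j*k)**2]
        d_sigma = omega * EPSILON_0 * d_eps_imag  # assumed convention
        return d_eps_real, d_sigma


    print(delta_eps_sigma(n=3.48, k=0.0, dn=1e-4, dk=1e-5, freq=1.934e14))
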
This determines how long the " "simulation will run as it assumes a field decay time that scales proportionally to " "this value.", ) - source_factor: pd.PositiveFloat = pd.Field( + source_factor: PositiveFloat = Field( 3, title="Source Factor", description="The contribution to the ``run_time`` from the longest source is computed from " diff --git a/tidy3d/components/scene.py b/tidy3d/components/scene.py index 8a236c9b82..267a84e9be 100644 --- a/tidy3d/components/scene.py +++ b/tidy3d/components/scene.py @@ -2,24 +2,46 @@ from __future__ import annotations -from typing import Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import autograd.numpy as np +if TYPE_CHECKING: + from typing import Literal, Union + + from pydantic import NonNegativeInt + + from tidy3d.compat import Self + from tidy3d.components.material.types import StructureMediumType + + from .data.utils import CustomSpatialDataType + from .grid.grid import Grid + from .types import ( + Ax, + Bound, + Coordinate, + InterpMethod, + PermittivityComponent, + PlotScale, + Shapely, + Size, + ) + from .viz import PlotParams + try: import matplotlib as mpl import matplotlib.pylab as plt from mpl_toolkits.axes_grid1 import make_axes_locatable except ImportError: mpl = None -import pydantic.v1 as pd +from pydantic import Field, field_validator from tidy3d.components.material.tcad.charge import ( ChargeConductorMedium, SemiconductorMedium, ) from tidy3d.components.material.tcad.heat import SolidMedium, SolidSpec -from tidy3d.components.material.types import MultiPhysicsMediumType3D, StructureMediumType +from tidy3d.components.material.types import MultiPhysicsMediumType3D from tidy3d.components.tcad.doping import ( ConstantDoping, CustomDoping, @@ -33,7 +55,6 @@ from .base import Tidy3dBaseModel, cached_property from .data.utils import ( - CustomSpatialDataType, SpatialDataArray, TetrahedralGridDataset, TriangularGridDataset, @@ -41,7 +62,7 @@ ) from .geometry.base import Box from .geometry.utils import merging_geometries_on_plane -from .grid.grid import Coords, Grid +from .grid.grid import Coords from .material.multi_physics import MultiPhysicsMedium from .medium import ( AbstractCustomMedium, @@ -51,26 +72,13 @@ Medium2D, ) from .structure import Structure -from .types import ( - TYPE_TAG_STR, - Ax, - Bound, - Coordinate, - InterpMethod, - LengthUnit, - PermittivityComponent, - PlotScale, - PriorityMode, - Shapely, - Size, -) +from .types import TYPE_TAG_STR, LengthUnit, PriorityMode from .validators import assert_unique_names from .viz import ( MEDIUM_CMAP, STRUCTURE_EPS_CMAP, STRUCTURE_EPS_CMAP_R, STRUCTURE_HEAT_COND_CMAP, - PlotParams, add_ax_if_none, equal_aspect, plot_params_fluid, @@ -79,6 +87,13 @@ polygon_path, ) +try: + import matplotlib as mpl + import matplotlib.pylab as plt + from mpl_toolkits.axes_grid1 import make_axes_locatable +except ImportError: + pass + # maximum number of mediums supported MAX_NUM_MEDIUMS = 65530 @@ -90,7 +105,7 @@ MAX_STRUCTURES_PER_MEDIUM = 1_000 -def _get_colormap(reverse: bool = False): +def _get_colormap(reverse: bool = False) -> str: return STRUCTURE_EPS_CMAP_R if reverse else STRUCTURE_EPS_CMAP @@ -111,14 +126,14 @@ class Scene(Tidy3dBaseModel): ... ) """ - medium: MultiPhysicsMediumType3D = pd.Field( - Medium(), + medium: MultiPhysicsMediumType3D = Field( + default_factory=Medium, title="Background Medium", description="Background medium of scene, defaults to vacuum if not specified.", discriminator=TYPE_TAG_STR, ) - structures: tuple[Structure, ...] 
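The new ``if TYPE_CHECKING:`` block in ``scene.py`` above is the standard guard for names used only in annotations; combined with ``from __future__ import annotations`` it removes runtime import cost and breaks import cycles. A minimal sketch (module names illustrative):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen by type checkers only; never imported at runtime
        from matplotlib.axes import Axes


    def draw(ax: Axes) -> Axes:  # annotation stays a string at runtime
        return ax
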
= pd.Field( + structures: Optional[tuple[Structure, ...]] = Field( (), title="Structures", description="Tuple of structures present in scene. " @@ -129,7 +144,7 @@ class Scene(Tidy3dBaseModel): "the structure added later to the structure list takes precedence.", ) - structure_priority_mode: PriorityMode = pd.Field( + structure_priority_mode: PriorityMode = Field( "equal", title="Structure Priority Setting", description="This field only affects structures of `priority=None`. " @@ -138,7 +153,7 @@ class Scene(Tidy3dBaseModel): "`PECMedium` to 100, and others to 0.", ) - plot_length_units: Optional[LengthUnit] = pd.Field( + plot_length_units: Optional[LengthUnit] = Field( "μm", title="Plot Units", description="When set to a supported ``LengthUnit``, " @@ -148,11 +163,13 @@ class Scene(Tidy3dBaseModel): """ Validating setup """ - # make sure all names are unique _unique_structure_names = assert_unique_names("structures") - @pd.validator("structures", always=True) - def _validate_mediums(cls, val): + @field_validator("structures") + @classmethod + def _validate_mediums( + cls, val: Optional[tuple[Structure, ...]] + ) -> Optional[tuple[Structure, ...]]: """Error if too many mediums present. Warn if different mediums have the same name.""" if val is None: @@ -175,7 +192,8 @@ def _validate_mediums(cls, val): return val - # @pd.validator("structures", always=True) + # @field_validator("structures") + # @classmethod # def _validate_num_geometries(cls, val): # """Error if too many geometries in a single structure.""" @@ -198,8 +216,11 @@ def _validate_mediums(cls, val): # return val - @pd.validator("structures", always=True) - def _validate_structures_per_medium(cls, val): + @field_validator("structures") + @classmethod + def _validate_structures_per_medium( + cls, val: Optional[tuple[Structure, ...]] + ) -> Optional[tuple[Structure, ...]]: """Error if too many structures share the same medium; suggest using GeometryGroup.""" if val is None: return val @@ -246,7 +267,7 @@ def bounds(self) -> Bound: Returns ------- - Tuple[float, float, float], Tuple[float, float, float] + tuple[float, float, float], tuple[float, float, float] Min and max bounds packaged as ``(minx, miny, minz), (maxx, maxy, maxz)``. """ @@ -262,7 +283,7 @@ def size(self) -> Size: Returns ------- - Tuple[float, float, float] + tuple[float, float, float] Scene's size. """ @@ -274,7 +295,7 @@ def center(self) -> Coordinate: Returns ------- - Tuple[float, float, float] + tuple[float, float, float] Scene's center. """ @@ -298,7 +319,7 @@ def mediums(self) -> set[StructureMediumType]: Returns ------- - List[:class:`.AbstractMedium`] + list[:class:`.AbstractMedium`] Set of distinct mediums in the scene. """ medium_dict = {self.medium: None} @@ -306,13 +327,13 @@ def mediums(self) -> set[StructureMediumType]: return list(medium_dict.keys()) @cached_property - def medium_map(self) -> dict[StructureMediumType, pd.NonNegativeInt]: + def medium_map(self) -> dict[StructureMediumType, NonNegativeInt]: """Returns dict mapping medium to index in material. ``medium_map[medium]`` returns unique global index of :class:`.AbstractMedium` in scene. Returns ------- - Dict[:class:`.AbstractMedium`, int] + dict[:class:`.AbstractMedium`, int] Mapping between distinct mediums to index in scene. """ @@ -352,12 +373,12 @@ def intersecting_media( ------- test_object : :class:`.Box` Object for which intersecting media are to be detected. 
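The validator above shows the recurring v1-to-v2 migration shape: ``@pd.validator(..., always=True)`` becomes ``@field_validator(...)`` stacked over ``@classmethod``, with explicitly typed input and return values. A compact sketch with an illustrative model and limit:

    from typing import Optional

    from pydantic import BaseModel, field_validator

    MAX_ITEMS = 3


    class Holder(BaseModel):
        items: Optional[tuple[str, ...]] = ()

        @field_validator("items")
        @classmethod
        def _check_count(cls, val: Optional[tuple[str, ...]]) -> Optional[tuple[str, ...]]:
            if val is not None and len(val) > MAX_ITEMS:
                raise ValueError(f"at most {MAX_ITEMS} items supported, got {len(val)}")
            return val
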
- structures : List[:class:`.AbstractMedium`] + structures : list[:class:`.AbstractMedium`] List of structures whose media will be tested. Returns ------- - List[:class:`.AbstractMedium`] + list[:class:`.AbstractMedium`] Set of distinct mediums that intersect with the given planar object. """ structures = [s.to_static() for s in structures] @@ -368,7 +389,7 @@ def intersecting_media( return mediums # if the test object is a volume, test each surface recursively - surfaces = test_object.surfaces_with_exclusion(**test_object.dict()) + surfaces = test_object.surfaces_with_exclusion(**test_object.model_dump()) mediums = set() for surface in surfaces: _mediums = Scene.intersecting_media(surface, structures) @@ -386,12 +407,12 @@ def intersecting_structures( ------- test_object : :class:`.Box` Object for which intersecting media are to be detected. - structures : List[:class:`.AbstractMedium`] + structures : list[:class:`.AbstractMedium`] List of structures whose media will be tested. Returns ------- - List[:class:`.Structure`] + list[:class:`.Structure`] Set of distinct structures that intersect with the given surface, or with the surfaces of the given volume. """ @@ -410,7 +431,7 @@ def intersecting_structures( return structures_merged # if the test object is a volume, test each surface recursively - surfaces = test_object.surfaces_with_exclusion(**test_object.dict()) + surfaces = test_object.surfaces_with_exclusion(**test_object.model_dump()) structures_merged = [] for surface in surfaces: structures_merged += Scene.intersecting_structures(surface, structures) @@ -477,9 +498,9 @@ def plot( position of plane in z direction, only one of x, y, z must be specified to define plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill_structures : bool = True Whether to fill structures with color or just draw outlines. @@ -520,9 +541,9 @@ def plot_structures( position of plane in z direction, only one of x, y, z must be specified to define plane. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. fill : bool = True Whether to fill structures with color or just draw outlines. @@ -675,9 +696,9 @@ def _set_plot_bounds( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns ------- @@ -703,7 +724,7 @@ def _get_structures_2dbox( Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] list of structures to filter on the plane. 
x : float = None position of plane in x direction, only one of x, y, z must be specified to define plane. @@ -711,14 +732,14 @@ def _get_structures_2dbox( position of plane in y direction, only one of x, y, z must be specified to define plane. z : float = None position of plane in z direction, only one of x, y, z must be specified to define plane. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and mediums on the plane. """ # if no hlim and/or vlim given, the bounds will then be the usual pml bounds @@ -760,14 +781,14 @@ def _filter_structures_plane_medium( Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] List of structures to filter on the plane. plane : Box Plane specification. Returns ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and mediums on the plane after merging. """ @@ -780,14 +801,14 @@ def _filter_structures_plane_medium( def _filter_structures_plane( structures: list[Structure], plane: Box, - property_list: list, + property_list: list[Any], ) -> list[tuple[Medium, Shapely]]: """Compute list of shapes to plot on plane. Overlaps are removed or merged depending on provided property_list. Parameters ---------- - structures : List[:class:`.Structure`] + structures : list[:class:`.Structure`] List of structures to filter on the plane. plane : Box Plane specification. @@ -796,7 +817,7 @@ def _filter_structures_plane( Returns ------- - List[Tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] + list[tuple[:class:`.AbstractMedium`, shapely.geometry.base.BaseGeometry]] List of shapes and their property value on the plane after merging. """ return merging_geometries_on_plane( @@ -839,9 +860,9 @@ def plot_eps( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. eps_lim : Tuple[float, float] = None Custom limits for eps coloring. @@ -913,15 +934,15 @@ def plot_structures_eps( alpha : float = None Opacity of the structures being plotted. Defaults to the structure default alpha. - eps_lim : Tuple[float, float] = None + eps_lim : tuple[float, float] = None Custom limits for eps coloring. scale : PlotScale = "lin" Scale for the plot. Either 'lin' for linear, 'log' for log10, 'symlog' for symmetric logarithmic (linear near zero, logarithmic elsewhere), or 'dB' for decibel scale. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. 
- vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. eps_component : Optional[PermittivityComponent] = None Component of the permittivity tensor to plot for anisotropic materials, @@ -994,16 +1015,16 @@ def plot_structures_property( alpha : float = None Opacity of the structures being plotted. Defaults to the structure default alpha. - limits : Tuple[float, float] = None + limits : tuple[float, float] = None Custom coloring limits for the property to plot. scale : PlotScale = "lin" Scale for the plot. Either 'lin' for linear, 'log' for log10, or 'dB' for decibel scale. For log scale with negative values, the absolute value is taken before log transformation. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. property: Literal["eps", "doping", "N_a", "N_d"] = "eps" Indicates the property to plot for the structures. Currently supported properties @@ -1246,7 +1267,7 @@ def eps_bounds( Returns ------- - Tuple[float, float] + tuple[float, float] Minimal and maximal values of relative permittivity in scene. """ @@ -1532,9 +1553,9 @@ def plot_heat_charge_property( ["heat_conductivity", "electric_conductivity"] ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1586,9 +1607,9 @@ def plot_structures_heat_conductivity( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1652,9 +1673,9 @@ def plot_structures_heat_charge_property( Defaults to the structure default alpha. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1721,12 +1742,12 @@ def plot_structures_heat_charge_property( ) return ax - def heat_charge_property_bounds(self, property) -> tuple[float, float]: + def heat_charge_property_bounds(self, property: str) -> tuple[float, float]: """Compute range of the heat-charge simulation property present in the scene. Returns ------- - Tuple[float, float] + tuple[float, float] Minimal and maximal values of thermal conductivity in scene. 
""" @@ -1755,7 +1776,7 @@ def heat_conductivity_bounds(self) -> tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] Minimal and maximal values of thermal conductivity in scene. """ log.warning( @@ -1847,7 +1868,7 @@ def plot_heat_conductivity( ax: Ax = None, hlim: Optional[tuple[float, float]] = None, vlim: Optional[tuple[float, float]] = None, - ): + ) -> Ax: """Plot each of scebe's components on a plane defined by one nonzero x,y,z coordinate. The thermal conductivity is plotted in grayscale based on its value. @@ -1866,9 +1887,9 @@ def plot_heat_conductivity( Whether to plot a colorbar for the thermal conductivity. ax : matplotlib.axes._subplots.Axes = None Matplotlib axes to plot on, if not specified, one is created. - hlim : Tuple[float, float] = None + hlim : tuple[float, float] = None The x range if plotting on xy or xz planes, y range if plotting on yz plane. - vlim : Tuple[float, float] = None + vlim : tuple[float, float] = None The z range if plotting on xz or yz planes, y plane if plotting on xy plane. Returns @@ -1899,11 +1920,11 @@ def plot_heat_conductivity( def perturbed_mediums_copy( self, - temperature: CustomSpatialDataType = None, - electron_density: CustomSpatialDataType = None, - hole_density: CustomSpatialDataType = None, + temperature: Optional[CustomSpatialDataType] = None, + electron_density: Optional[CustomSpatialDataType] = None, + hole_density: Optional[CustomSpatialDataType] = None, interp_method: InterpMethod = "linear", - ) -> Scene: + ) -> Self: """Return a copy of the scene with heat and/or charge data applied to all mediums that have perturbation models specified. That is, such mediums will be replaced with spatially dependent custom mediums that reflect perturbation effects. Any of temperature, @@ -1940,7 +1961,7 @@ def perturbed_mediums_copy( Simulation after application of heat and/or charge data. """ - scene_dict = self.dict() + scene_dict = self.model_dump() structures = self.sorted_structures array_dict = { "temperature": temperature, @@ -1984,9 +2005,9 @@ def perturbed_mediums_copy( if isinstance(med, AbstractPerturbationMedium): scene_dict["medium"] = med.perturbed_copy(**array_dict, interp_method=interp_method) - return Scene.parse_obj(scene_dict) + return Scene.model_validate(scene_dict) - def doping_bounds(self): + def doping_bounds(self) -> tuple[list[float], list[float]]: """Get the maximum and minimum of the doping""" acceptors_lims = [np.inf, -np.inf] @@ -2041,7 +2062,7 @@ def doping_bounds(self): donors_lims[1] = 0 return acceptors_lims, donors_lims - def doping_absolute_minimum(self): + def doping_absolute_minimum(self) -> tuple[float, float]: """Get the absolute minimum values of the doping concentrations. Returns @@ -2069,7 +2090,9 @@ def doping_absolute_minimum(self): return acceptors_abs_min, donors_abs_min - def _get_absolute_minimum_from_doping(self, doping): + def _get_absolute_minimum_from_doping( + self, doping: Union[float, SpatialDataArray, tuple[DopingBoxType, ...]] + ) -> float: """Helper method to get absolute minimum from a single doping specification. Parameters @@ -2114,7 +2137,7 @@ def _pcolormesh_shape_doping_box( shape: Shapely, ax: Ax, plt_type: str = "doping", - norm: mpl.colors.Normalize = None, + norm: Optional[mpl.colors.Normalize] = None, ) -> None: """ Plot shape made of structure defined with doping. 
@@ -2158,7 +2181,11 @@ def _pcolormesh_shape_doping_box( if not data_is_2d: selector = {"xyz"[normal_axis_ind]: normal_position} data_2D = doping.sel(**selector) - contrib = data_2D.interp(**struct_coords, method="nearest") + contrib = data_2D.interp( + **struct_coords, + method="nearest", + kwargs={"bounds_error": False, "fill_value": 0}, + ) struct_doping[n] = struct_doping[n] + contrib # Handle doping boxes if isinstance(doping, tuple): @@ -2188,7 +2215,7 @@ def _pcolormesh_shape_doping_box( norm=norm, ) - def plot_3d(self, width=800, height=800) -> None: + def plot_3d(self, width: int = 800, height: int = 800) -> None: """Render 3D plot of ``Scene`` (in jupyter notebook only). Parameters ---------- diff --git a/tidy3d/components/simulation.py b/tidy3d/components/simulation.py index 2b2be561d0..6c50886fc0 100644 --- a/tidy3d/components/simulation.py +++ b/tidy3d/components/simulation.py @@ -6,31 +6,34 @@ import pathlib from abc import ABC, abstractmethod from collections import defaultdict -from os import PathLike -from typing import Any, Literal, Optional, Union, get_args +from typing import TYPE_CHECKING, Any, Literal, Optional, Union, get_args import autograd.numpy as np - -from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec - -from .types.monitor import MonitorType - -try: - import matplotlib as mpl -except ImportError: - pass - - -import pydantic.v1 as pydantic import xarray as xr +from pydantic import ( + Field, + NonNegativeFloat, + NonNegativeInt, + PositiveFloat, + field_validator, + model_validator, +) +from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec +from tidy3d.components.types.base import discriminated_union from tidy3d.constants import C_0, SECOND, fp_eps, inf -from tidy3d.exceptions import SetupError, Tidy3dError, Tidy3dImportError, ValidationError +from tidy3d.exceptions import ( + AdjointError, + SetupError, + Tidy3dError, + Tidy3dImportError, + ValidationError, +) from tidy3d.log import log from tidy3d.packaging import disable_local_subpixel, supports_local_subpixel, tidy3d_extras from tidy3d.updater import Updater -from .base import cached_property, skip_if_fields_missing +from .base import cached_property from .base_sim.simulation import AbstractSimulation from .boundary import ( PML, @@ -47,20 +50,15 @@ PMCBoundary, StablePML, ) -from .data.data_array import ( - FreqDataArray, - IndexedDataArray, -) -from .data.dataset import Dataset +from .data.data_array import FreqDataArray, IndexedDataArray from .data.unstructured.tetrahedral import TetrahedralGridDataset from .data.unstructured.triangular import TriangularGridDataset -from .data.utils import CustomSpatialDataType from .frequency_extrapolation import LowFrequencySmoothingSpec from .geometry.base import Box, Geometry, GeometryGroup from .geometry.mesh import TriangleMesh from .geometry.utils import _shift_object, flatten_groups, traverse_geometries from .geometry.utils_2d import get_bounds, get_thickened_geom, snap_coordinate_to_grid, subdivide -from .grid.grid import Coords, Coords1D, Grid +from .grid.grid import Coords, Grid from .grid.grid_spec import AutoGrid, GridSpec, UniformGrid from .lumped_element import LumpedElementType from .medium import ( @@ -72,7 +70,6 @@ LossyMetalMedium, Medium, Medium2D, - MediumType, MediumType3D, PECMedium, ) @@ -93,7 +90,6 @@ FreqMonitor, MediumMonitor, ModeMonitor, - Monitor, PermittivityMonitor, SurfaceIntegrationMonitor, TimeMonitor, @@ -114,22 +110,10 @@ from .source.frame import PECFrame from .source.time import ContinuousWave, 
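The behavioral change in the ``interp`` call above is the added ``kwargs``: xarray forwards them to the underlying scipy interpolator, so points outside the data's coordinate range fill with 0 rather than the default NaN, which would otherwise poison the running sum the result is added into. A small demo:

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, 2.0, 3.0], dims="x", coords={"x": [0.0, 1.0, 2.0]})

    # x=3.0 lies outside the coordinate range; without the kwargs it would be NaN
    out = da.interp(
        x=[0.4, 3.0], method="nearest", kwargs={"bounds_error": False, "fill_value": 0}
    )
    print(out.values)  # [1. 0.]
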
CustomSourceTime from .source.utils import SourceType -from .structure import MeshOverrideStructure, Structure +from .structure import Structure from .subpixel_spec import SubpixelSpec -from .types import ( - TYPE_TAG_STR, - ArrayFloat1D, - ArrayFloat2D, - Ax, - Axis, - CoordinateOptional, - FreqBound, - InterpMethod, - PermittivityComponent, - Shapely, - Symmetry, - annotate_type, -) +from .types import TYPE_TAG_STR, PermittivityComponent, Symmetry +from .types.monitor import MonitorType from .validators import ( assert_objects_contained_in_sim_bounds, assert_objects_in_sim_bounds, @@ -149,6 +133,38 @@ plot_sim_3d, ) +if TYPE_CHECKING: + from os import PathLike + from typing import Callable + + from numpy.typing import NDArray + + from tidy3d.compat import Self + + from .autograd.types import AutogradFieldMap + from .boundary import BoundaryEdgeType + from .data.dataset import Dataset + from .data.utils import CustomSpatialDataType + from .grid.grid import Coords1D + from .medium import MediumType + from .monitor import Monitor + from .structure import MeshOverrideStructure + from .types import ( + ArrayFloat1D, + ArrayFloat2D, + Ax, + Axis, + CoordinateOptional, + FreqBound, + InterpMethod, + Shapely, + ) + +try: + import matplotlib as mpl +except ImportError: + pass + try: gdstk_available = True import gdstk @@ -191,16 +207,18 @@ RF_FREQ_WARNING = 300e9 -def validate_boundaries_for_zero_dims(warn_on_change: bool = True): +def validate_boundaries_for_zero_dims( + warn_on_change: bool = True, +) -> Callable[[AbstractYeeGridSimulation], AbstractYeeGridSimulation]: """Error if absorbing boundaries, bloch boundaries, unmatching pec/pmc, or symmetry is used along a zero dimension.""" - @pydantic.validator("boundary_spec", allow_reuse=True, always=True) - @skip_if_fields_missing(["size", "symmetry"]) - def boundaries_for_zero_dims(cls, val, values): + @model_validator(mode="after") + def boundaries_for_zero_dims(self: AbstractYeeGridSimulation) -> AbstractYeeGridSimulation: """Error if absorbing boundaries, bloch boundaries, unmatching pec/pmc, or symmetry is used along a zero dimension.""" + val = self.boundary_spec boundaries = val.to_list - size = values.get("size") - symmetry = values.get("symmetry") + size = self.size + symmetry = self.symmetry axis_names = "xyz" for dim, (boundary, symmetry_dim, size_dim) in enumerate(zip(boundaries, symmetry, size)): @@ -246,7 +264,11 @@ def boundaries_for_zero_dims(cls, val, values): "minus must be the same." ) - return val + # Update boundary_spec if it was modified + if val != self.boundary_spec: + object.__setattr__(self, "boundary_spec", val) + + return self return boundaries_for_zero_dims @@ -256,7 +278,7 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): Abstract class for a simulation involving electromagnetic fields defined on a Yee grid. """ - lumped_elements: tuple[LumpedElementType, ...] = pydantic.Field( + lumped_elements: tuple[LumpedElementType, ...] = Field( (), title="Lumped Elements", description="Tuple of lumped elements in the simulation. " @@ -266,8 +288,8 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): Tuple of lumped elements in the simulation. 
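``validate_boundaries_for_zero_dims`` above is a validator factory: it builds and returns a ``model_validator`` so several simulation classes can attach the same cross-field check by plain assignment in their class bodies. A sketch of the pattern with an illustrative model and check:

    from pydantic import BaseModel, model_validator


    def validate_nonnegative_size():
        """Build a reusable after-validator that classes attach by assignment."""

        @model_validator(mode="after")
        def _check(self: "Shape") -> "Shape":
            if any(s < 0 for s in self.size):
                raise ValueError("size components must be non-negative")
            return self

        return _check


    class Shape(BaseModel):
        size: tuple[float, float, float]

        _check_size = validate_nonnegative_size()
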
""" - grid_spec: GridSpec = pydantic.Field( - GridSpec(), + grid_spec: GridSpec = Field( + default_factory=GridSpec, title="Grid Specification", description="Specifications for the simulation grid along each of the three directions.", ) @@ -306,8 +328,8 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): * `Using automatic nonuniform meshing <../../notebooks/AutoGrid.html>`_ """ - subpixel: Union[bool, SubpixelSpec] = pydantic.Field( - SubpixelSpec(), + subpixel: Union[bool, SubpixelSpec] = Field( + default_factory=SubpixelSpec, title="Subpixel Averaging", description="Apply subpixel averaging methods of the permittivity on structure interfaces " "to result in much higher accuracy for a given grid size. Supply a :class:`.SubpixelSpec` " @@ -316,6 +338,7 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): "``True`` to apply the default subpixel averaging methods corresponding to ``SubpixelSpec()`` " ", or ``False`` to apply staircasing.", ) + """ Supply :class:`.SubpixelSpec` to select subpixel averaging methods separately for dielectric, metal, and PEC material interfaces. Alternatively, supply ``True`` to use default subpixel averaging methods, @@ -360,44 +383,42 @@ class AbstractYeeGridSimulation(AbstractSimulation, ABC): * `Dielectric constant assignment on Yee grids `_ """ - simulation_type: Optional[Literal["autograd_fwd", "autograd_bwd", "tidy3d", None]] = ( - pydantic.Field( - "tidy3d", - title="Simulation Type", - description="Tag used internally to distinguish types of simulations for " - "``autograd`` gradient processing.", - ) + simulation_type: Optional[Literal["autograd_fwd", "autograd_bwd", "tidy3d"]] = Field( + "tidy3d", + title="Simulation Type", + description="Tag used internally to distinguish types of simulations for " + "``autograd`` gradient processing.", ) - post_norm: Union[float, FreqDataArray] = pydantic.Field( + post_norm: Union[float, FreqDataArray] = Field( 1.0, title="Post Normalization Values", description="Factor to multiply the fields by after running, " "given the adjoint source pipeline used. Note: this is used internally only.", ) - internal_absorbers: tuple[InternalAbsorber, ...] = pydantic.Field( + internal_absorbers: tuple[InternalAbsorber, ...] = Field( (), title="Internal Absorbers", description="Planes with the first order absorbing boundary conditions placed inside the computational domain. 
" "Note that internal absorbers are automatically wrapped in a PEC frame with a backing PEC plate on the non-absorbing side.", ) - @pydantic.validator("simulation_type", always=True) - def _validate_simulation_type_tidy3d(cls, val): + @field_validator("simulation_type") + @classmethod + def _validate_simulation_type_tidy3d( + cls, val: Optional[Literal["autograd_fwd", "autograd_bwd", "tidy3d"]] + ) -> Literal["autograd_fwd", "autograd_bwd", "tidy3d"]: """Enforce the simulation_type is 'tidy3d' if passed as None for bkwrds compatibility.""" - if val is None: - return "tidy3d" - return val + return "tidy3d" if val is None else val - @pydantic.validator("lumped_elements", always=True) - @skip_if_fields_missing(["structures"]) - def _validate_num_lumped_elements(cls, val, values): + @model_validator(mode="after") + def _validate_num_lumped_elements(self) -> Self: """Error if too many lumped elements present.""" - + val = self.lumped_elements if val is None: - return val - structures = values.get("structures") + return self + structures = self.structures mediums = {structure.medium for structure in structures} total_num_mediums = len(val) + len(mediums) if total_num_mediums > MAX_NUM_MEDIUMS: @@ -406,22 +427,21 @@ def _validate_num_lumped_elements(cls, val, values): f"{total_num_mediums} were supplied." ) - return val + return self - @pydantic.validator("lumped_elements") - @skip_if_fields_missing(["size"]) - def _check_3d_simulation_with_lumped_elements(cls, val, values): + @model_validator(mode="after") + def _check_3d_simulation_with_lumped_elements(self) -> Self: """Error if Simulation contained lumped elements and is not a 3D simulation""" - size = values.get("size") + val = self.lumped_elements + size = self.size if val and size.count(0.0) > 0: raise ValidationError( - f"'{cls.__name__}' must be a 3D simulation when a 'LumpedElement' is present." + f"'{self.__class__.__name__}' must be a 3D simulation when a 'LumpedElement' is present." ) - return val + return self - @pydantic.validator("grid_spec", always=True) @abstractmethod - def _validate_auto_grid_wavelength(cls, val, values) -> None: + def _validate_auto_grid_wavelength(val) -> None: """Check that wavelength can be defined if there is auto grid spec.""" def _monitor_num_cells(self, monitor: Monitor) -> int: @@ -442,8 +462,8 @@ def num_cells_in_monitor(monitor: Monitor) -> int: return sum(num_cells_in_monitor(mnt) for mnt in monitor.integration_surfaces) return num_cells_in_monitor(monitor) - @pydantic.validator("boundary_spec") - def _validate_boundary_spec_symmetry(cls, val, values): + @model_validator(mode="after") + def _validate_boundary_spec_symmetry(self) -> Self: """Error if symmetry is imposed along an axis but the boundary conditions are not the same on both sides.""" @@ -454,14 +474,15 @@ def equivalent(plus: BoundarySpec, minus: BoundarySpec) -> bool: minus_cpy = minus.updated_copy(name="") return plus_cpy == minus_cpy - boundaries = [val.x, val.y, val.z] - for ax, symmetry, ax_bounds in zip("xyz", values.get("symmetry"), boundaries): + bs = self.boundary_spec + boundaries = [bs.x, bs.y, bs.z] + for ax, symmetry, ax_bounds in zip("xyz", self.symmetry, boundaries): if symmetry != 0 and not equivalent(ax_bounds.plus, ax_bounds.minus): raise ValidationError( f"Symmetry '{symmetry}' along axis {ax} requires the same boundary " f"condition on both sides of the axis." 
                 )

-        return val
+        return self

     @cached_property
     def _subpixel(self) -> SubpixelSpec:
@@ -583,9 +604,9 @@ def plot(
             Use the exact placement of port absorbers which take into account their ``shift`` values.
         ax : matplotlib.axes._subplots.Axes = None
             Matplotlib axes to plot on, if not specified, one is created.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.

         Returns
@@ -689,9 +710,9 @@ def plot_eps(
             Use the exact placement of port absorbers which take into account their ``shift`` values.
         ax : matplotlib.axes._subplots.Axes = None
             Matplotlib axes to plot on, if not specified, one is created.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
         eps_component : Optional[PermittivityComponent] = None
             Component of the permittivity tensor to plot for anisotropic materials,
@@ -804,9 +825,9 @@ def plot_structures_eps(
             Defaults to the structure default alpha.
         ax : matplotlib.axes._subplots.Axes = None
             Matplotlib axes to plot on, if not specified, one is created.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
         eps_component : Optional[PermittivityComponent] = None
             Component of the permittivity tensor to plot for anisotropic materials,
@@ -874,9 +895,9 @@ def plot_pml(
             position of plane in y direction, only one of x, y, z must be specified to define plane.
         z : float = None
             position of plane in z direction, only one of x, y, z must be specified to define plane.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
         ax : matplotlib.axes._subplots.Axes = None
             Matplotlib axes to plot on, if not specified, one is created.
@@ -947,7 +968,7 @@ def _make_pml_box(self, pml_axis: Axis, pml_height: float, sign: int) -> Box:
         for dim_index, sim_size in enumerate(self.size):
             if sim_size == 0.0:
                 new_size[dim_index] = PML_HEIGHT_FOR_0_DIMS
-        pml_box = pml_box.updated_copy(size=new_size)
+        pml_box = pml_box.updated_copy(size=tuple(new_size))

         return pml_box

@@ -980,7 +1001,7 @@ def pml_thicknesses(self) -> list[tuple[float, float]]:
         return pml_thicknesses

     @cached_property
-    def _internal_layerrefinement_boundary_types(self):
+    def _internal_layerrefinement_boundary_types(self) -> list[list[Optional[str]]]:
         """Boundary types for layer refinement."""
         boundary_types = [[None, None], [None, None], [None, None]]
         for dim, boundary in enumerate(self.boundary_spec.to_list):
@@ -1086,9 +1107,9 @@ def plot_lumped_elements(
             position of plane in y direction, only one of x, y, z must be specified to define plane.
         z : float = None
             position of plane in z direction, only one of x, y, z must be specified to define plane.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
         alpha : float = None
             Opacity of the lumped element. If ``None`` uses Tidy3d default.
@@ -1132,9 +1153,9 @@ def plot_grid(
             position of plane in y direction, only one of x, y, z must be specified to define plane.
         z : float = None
             position of plane in z direction, only one of x, y, z must be specified to define plane.
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
         override_structures_alpha : float = 1
             Opacity of the override structures.
@@ -1161,8 +1182,8 @@ def plot_grid(
         cell_boundaries = self.grid.boundaries
         axis, _ = self.parse_xyz_kwargs(x=x, y=y, z=z)
         _, (axis_x, axis_y) = self.pop_axis([0, 1, 2], axis=axis)
-        boundaries_x = cell_boundaries.dict()["xyz"[axis_x]]
-        boundaries_y = cell_boundaries.dict()["xyz"[axis_y]]
+        boundaries_x = cell_boundaries.model_dump()["xyz"[axis_x]]
+        boundaries_y = cell_boundaries.model_dump()["xyz"[axis_y]]

         if self.size[axis_x] > 0:
             for b in boundaries_x:
@@ -1290,7 +1311,12 @@ def plot_boundaries(
             The supplied or created matplotlib axes.
         """

-        def set_plot_params(boundary_edge, lim, side, thickness):
+        def set_plot_params(
+            boundary_edge: Union[ABCBoundary, ModeABCBoundary, BoundaryEdgeType],
+            lim: float,
+            side: Literal[-1, 1],
+            thickness: float,
+        ) -> tuple[PlotParams, float]:
             """Return the line plot properties such as color and opacity based on the boundary"""
             if isinstance(boundary_edge, PECBoundary):
                 plot_params = plot_params_pec.copy(deep=True)
@@ -1397,7 +1423,7 @@ def _grid_and_snapping_lines(self) -> tuple[Grid, list[CoordinateOptional]]:

         Returns
         -------
-        Tuple[:class:`.Grid`, list[CoordinateOptional]]
+        tuple[:class:`.Grid`, list[CoordinateOptional]]
             :class:`.Grid` storing the spatial locations relevant to the simulation and the list
             of snapping points generated during iterative gap meshing.
         """
@@ -1486,7 +1512,7 @@ def grid_info(self) -> dict:
         """Dictionary collecting various properties of the grids in the simulation."""
         return self.grid.info

-    def _subgrid(self, span_inds: np.ndarray, grid: Grid = None):
+    def _subgrid(self, span_inds: np.ndarray, grid: Optional[Grid] = None) -> Grid:
         """Take a subgrid of the simulation grid with cell span defined by ``span_inds`` along the
         three dimensions. Optionally, a grid different from the simulation grid can be provided.
         The ``span_inds`` can also extend beyond the grid, in which case the grid is padded based
@@ -1520,7 +1546,7 @@ def num_pml_layers(self) -> list[tuple[float, float]]:

         Returns
         -------
-        list[Tuple[float, float]]
+        list[tuple[float, float]]
             List containing the number of absorber layers in - and + boundaries.
         """
         num_layers = [[0, 0], [0, 0], [0, 0]]
@@ -1532,7 +1558,7 @@ def num_pml_layers(self) -> list[tuple[float, float]]:

         return num_layers

-    def _snap_zero_dim(self, grid: Grid, skip_axis: Axis = None):
+    def _snap_zero_dim(self, grid: Grid, skip_axis: Optional[Axis] = None) -> Grid:
         """Snap a grid to the simulation center along any dimension along which simulation is
         effectively 0D, defined as having a single pixel.
This is more general than just checking size = 0.""" @@ -1558,7 +1584,7 @@ def _discretize_grid(self, box: Box, grid: Grid, extend: bool = False) -> Grid: def _discretize_inds_monitor( self, monitor: Union[Monitor, Box], colocate: Optional[bool] = None - ): + ) -> NDArray: """Start and stopping indexes for the cells where data needs to be recorded to fully cover a ``monitor``. This is used during the solver run. The final grid on which a monitor data lives is computed in ``discretize_monitor``, with the difference being that 0-sized @@ -1703,7 +1729,7 @@ def epsilon_on_grid( subpixel_sim = tidy3d_extras["mod"].SubpixelSimulation.from_simulation(self) return subpixel_sim.epsilon_on_grid(grid=grid, coord_key=coord_key, freq=freq) - def get_eps(structure: Structure, frequency: float, coords: Coords): + def get_eps(structure: Structure, frequency: float, coords: Coords) -> complex: """Select the correct epsilon component if field locations are requested.""" if coord_key[0] != "E": return np.mean(structure.eps_diagonal(frequency, coords), axis=0) @@ -1714,7 +1740,7 @@ def get_eps(structure: Structure, frequency: float, coords: Coords): col = ["x", "y", "z"].index(coord_key[2]) return structure.eps_comp(row, col, frequency, coords) - def make_eps_data(coords: Coords): + def make_eps_data(coords: Coords) -> xr.DataArray: """returns epsilon data on grid of points defined by coords""" arrays = (np.array(coords.x), np.array(coords.y), np.array(coords.z)) eps_background = get_eps( @@ -1925,16 +1951,16 @@ def subsection( simulation. If ``identical``, then the original grid is transferred directly as a :class:`.CustomGrid`. Note that in the latter case the region of the new simulation is snapped to the original grid lines. - symmetry : Tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None + symmetry : tuple[Literal[0, -1, 1], Literal[0, -1, 1], Literal[0, -1, 1]] = None New simulation symmetry. If ``None``, then it is inherited from the original simulation. Note that in this case the size and placement of new simulation domain must be commensurate with the original symmetry. warn_symmetry_expansion : bool = True Whether to warn when the subsection is expanded to preserve symmetry. - sources : Tuple[SourceType, ...] = None + sources : tuple[SourceType, ...] = None New list of sources. If ``None``, then the sources intersecting the new simulation domain are inherited from the original simulation. - monitors : Tuple[MonitorType, ...] = None + monitors : tuple[MonitorType, ...] = None New list of monitors. If ``None``, then the monitors intersecting the new simulation domain are inherited from the original simulation. 
         remove_outside_structures : bool = True
@@ -2086,10 +2112,10 @@ def subsection(
             size=new_box.size,
             grid_spec=grid_spec,
             boundary_spec=boundary_spec,
-            monitors=[],
-            sources=sources,  # need wavelength in case of auto grid
-            symmetry=symmetry,
-            structures=aux_new_structures,
+            monitors=(),
+            sources=tuple(sources),  # need wavelength in case of auto grid
+            symmetry=tuple(symmetry),
+            structures=tuple(aux_new_structures),
             deep=deep_copy,
         )

@@ -2147,12 +2173,12 @@ def subsection(
             medium=new_sim_medium,
             grid_spec=grid_spec,
             boundary_spec=boundary_spec,
-            monitors=monitors,
-            sources=sources,
-            symmetry=symmetry,
-            structures=aux_new_structures,
-            lumped_elements=new_lumped_elements,
-            internal_absorbers=internal_absorbers,
+            monitors=tuple(monitors),
+            sources=tuple(sources),
+            symmetry=tuple(symmetry),
+            structures=tuple(aux_new_structures),
+            lumped_elements=tuple(new_lumped_elements),
+            internal_absorbers=tuple(internal_absorbers),
             **kwargs,
         )

@@ -2161,7 +2187,9 @@ def subsection(
         # 1) Perform validators not directly related to geometries
         new_sim = self.updated_copy(**new_sim_dict, deep=deep_copy, validate=True)
         # 2) Assemble the full simulation without validation
-        return new_sim.updated_copy(structures=new_structures, deep=deep_copy, validate=False)
+        return new_sim.updated_copy(
+            structures=tuple(new_structures), deep=deep_copy, validate=False
+        )

     def _invalidate_solver_cache(self) -> None:
         """Clear cached attributes that become stale when subpixel changes."""
@@ -2274,7 +2302,7 @@ def _finalized_volumetric_structures(self) -> list[Structure]:
         return list(self.volumetric_structures) + modal_frames

     @cached_property
-    def _finalized_optical_medium_map(self) -> dict[MediumType, pydantic.NonNegativeInt]:
+    def _finalized_optical_medium_map(self) -> dict[MediumType, NonNegativeInt]:
         """Returns dict mapping medium to index in material in finalized simulation.

         Returns
@@ -2402,8 +2430,8 @@ class Simulation(AbstractYeeGridSimulation):
    * `FDTD Walkthrough `_
    """

-    boundary_spec: BoundarySpec = pydantic.Field(
-        BoundarySpec(),
+    boundary_spec: BoundarySpec = Field(
+        default_factory=BoundarySpec,
         title="Boundaries",
         description="Specification of boundary conditions along each dimension. If ``None``, "
         "PML boundary conditions are applied on all sides.",
@@ -2446,7 +2474,7 @@ class Simulation(AbstractYeeGridSimulation):
    * `Using FDTD to Compute a Transmission Spectrum `__
    """

-    courant: float = pydantic.Field(
+    courant: float = Field(
         0.99,
         title="Normalized Courant Factor",
         description="Normalized Courant stability factor that is no larger than 1 when CFL "
@@ -2456,6 +2484,7 @@ class Simulation(AbstractYeeGridSimulation):
         gt=0.0,
         le=1.0,
     )
+
    """The Courant-Friedrichs-Lewy (CFL) stability factor :math:`C` controls the time-step to spatial-step ratio.
    A physical wave has to propagate slower than the numerical information propagation in a Yee-cell grid.
    This is because in this spatially-discrete grid, information propagates over 1 spatial step :math:`\\Delta x`
@@ -2528,7 +2557,7 @@ class Simulation(AbstractYeeGridSimulation):
    * `Numerical dispersion in FDTD `_
    """

-    precision: Literal["hybrid", "double"] = pydantic.Field(
+    precision: Literal["hybrid", "double"] = Field(
         "hybrid",
         title="Floating-point Precision",
         description="Floating point precision to use in the computations.",
@@ -2546,7 +2575,7 @@ class Simulation(AbstractYeeGridSimulation):
    ``ModeSpec.precision`` argument, which only affects the eigenvalue solver.
    """

-    lumped_elements: tuple[LumpedElementType, ...] = pydantic.Field(
+    lumped_elements: tuple[LumpedElementType, ...] = Field(
         (),
         title="Lumped Elements",
         description="Tuple of lumped elements in the simulation. ",
@@ -2584,8 +2613,8 @@ class Simulation(AbstractYeeGridSimulation):
    * `Using lumped elements in Tidy3D simulations <../../notebooks/LinearLumpedElements.html>`_
    """

-    grid_spec: GridSpec = pydantic.Field(
-        GridSpec(),
+    grid_spec: GridSpec = Field(
+        default_factory=GridSpec,
         title="Grid Specification",
         description="Specifications for the simulation grid along each of the three directions.",
     )
@@ -2731,8 +2760,8 @@ class Simulation(AbstractYeeGridSimulation):
    * `Numerical dispersion in FDTD `_
    """

-    medium: MediumType3D = pydantic.Field(
-        Medium(),
+    medium: MediumType3D = Field(
+        default_factory=Medium,
         title="Background Medium",
         description="Background medium of simulation, defaults to vacuum if not specified.",
         discriminator=TYPE_TAG_STR,
@@ -2763,7 +2792,7 @@ class Simulation(AbstractYeeGridSimulation):

    """

-    normalize_index: Union[pydantic.NonNegativeInt, None] = pydantic.Field(
+    normalize_index: Optional[NonNegativeInt] = Field(
         0,
         title="Normalization index",
         description="Index of the source in the tuple of sources whose spectrum will be used to "
@@ -2775,7 +2804,7 @@ class Simulation(AbstractYeeGridSimulation):
    data. If ``None``, the raw field data is returned unnormalized.
    """

-    monitors: tuple[annotate_type(MonitorType), ...] = pydantic.Field(
+    monitors: tuple[discriminated_union(MonitorType), ...] = Field(
         (),
         title="Monitors",
         description="Tuple of monitors in the simulation. "
@@ -2791,7 +2820,7 @@ class Simulation(AbstractYeeGridSimulation):
    All the monitor implementations.
    """

-    sources: tuple[annotate_type(SourceType), ...] = pydantic.Field(
+    sources: tuple[discriminated_union(SourceType), ...] = Field(
         (),
         title="Sources",
         description="Tuple of electric current sources injecting fields into the simulation.",
@@ -2828,7 +2857,7 @@ class Simulation(AbstractYeeGridSimulation):
    Frequency and time domain source models.
    """

-    shutoff: pydantic.NonNegativeFloat = pydantic.Field(
+    shutoff: NonNegativeFloat = Field(
         1e-5,
         title="Shutoff Condition",
         description="Ratio of the instantaneous integrated E-field intensity to the maximum value "
@@ -2843,7 +2872,7 @@ class Simulation(AbstractYeeGridSimulation):
    Set to ``0`` to disable this feature.
    """

-    structures: tuple[Structure, ...] = pydantic.Field(
+    structures: tuple[Structure, ...] = Field(
         (),
         title="Structures",
         description="Tuple of structures present in simulation. "
@@ -2908,7 +2937,7 @@ class Simulation(AbstractYeeGridSimulation):
    * `Structures `_
    """

-    symmetry: tuple[Symmetry, Symmetry, Symmetry] = pydantic.Field(
+    symmetry: tuple[Symmetry, Symmetry, Symmetry] = Field(
         (0, 0, 0),
         title="Symmetries",
         description="Tuple of integers defining reflection symmetry across a plane "
@@ -2938,8 +2967,7 @@ class Simulation(AbstractYeeGridSimulation):
    """

     # TODO: at a later time (once well tested) we could consider making default of RunTimeSpec()
-    run_time: Union[pydantic.PositiveFloat, RunTimeSpec] = pydantic.Field(
-        ...,
+    run_time: Union[PositiveFloat, RunTimeSpec] = Field(
         title="Run Time",
         description="Total electromagnetic evolution time in seconds. "
         "Note: If simulation 'shutoff' is specified, "
@@ -2999,7 +3027,7 @@ class Simulation(AbstractYeeGridSimulation):

    """

-    low_freq_smoothing: Optional[LowFrequencySmoothingSpec] = pydantic.Field(
+    low_freq_smoothing: Optional[LowFrequencySmoothingSpec] = Field(
         None,
         title="Low Frequency Smoothing",
         description="The low frequency smoothing parameters for the simulation.",
@@ -3007,25 +3035,26 @@ class Simulation(AbstractYeeGridSimulation):

     """ Validating setup """

-    @pydantic.root_validator(pre=True)
-    def _update_simulation(cls, values):
+    @model_validator(mode="before")
+    @classmethod
+    def _update_simulation(cls, data: dict[str, Any]) -> dict[str, Any]:
         """Update the simulation if it is an earlier version."""

         # if no version, assume it's already updated
-        if "version" not in values:
-            return values
+        if "version" not in data:
+            return data

         # otherwise, call the updater to update the values dictionary
-        updater = Updater(sim_dict=values)
+        updater = Updater(sim_dict=data)
         return updater.update_to_current()

-    @pydantic.validator("grid_spec", always=True)
-    @skip_if_fields_missing(["sources"])
-    def _validate_auto_grid_wavelength(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_auto_grid_wavelength(self) -> Self:
         """Check that wavelength can be defined if there is auto grid spec."""
+        val = self.grid_spec
         if val.wavelength is None and val.auto_grid_used:
-            _ = val.wavelength_from_sources(sources=values.get("sources"))
-        return val
+            _ = val.wavelength_from_sources(sources=self.sources)
+        return self

     _sources_in_bounds = assert_objects_in_sim_bounds("sources", strict_inequality=True)
     _lumped_elements_in_bounds = assert_objects_contained_in_sim_bounds(
@@ -3041,34 +3070,33 @@ def _validate_auto_grid_wavelength(cls, val, values):
     #     _resolution_fine_enough = validate_resolution()
     #     _plane_waves_in_homo = validate_plane_wave_intersections()

-    @pydantic.validator("boundary_spec", always=True)
-    @skip_if_fields_missing(["symmetry"])
-    def bloch_with_symmetry(cls, val, values):
+    @model_validator(mode="after")
+    def bloch_with_symmetry(self) -> Self:
         """Error if a Bloch boundary is applied with symmetry"""
+        val = self.boundary_spec
         boundaries = val.to_list
-        symmetry = values.get("symmetry")
+        symmetry = self.symmetry
         for dim, boundary in enumerate(boundaries):
             num_bloch = sum(isinstance(bnd, BlochBoundary) for bnd in boundary)
             if num_bloch > 0 and symmetry[dim] != 0:
                 raise SetupError(
                     f"Bloch boundaries cannot be used with a symmetry along dimension {dim}."
) - return val + return self - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["medium", "size", "structures", "sources"]) - def plane_wave_boundaries(cls, val, values): + @model_validator(mode="after") + def plane_wave_boundaries(self) -> Self: """Error if there are plane wave sources incompatible with boundary conditions.""" - boundaries = val.to_list - sources = values.get("sources") - size = values.get("size") - sim_medium = values.get("medium") - structures = values.get("structures") + boundaries = self.boundary_spec.to_list + sources = self.sources + size = self.size + sim_medium = self.medium + structures = self.structures for source_ind, source in enumerate(sources): if not isinstance(source, PlaneWave): continue - _, tan_dirs = cls.pop_axis([0, 1, 2], axis=source.injection_axis) + _, tan_dirs = self.pop_axis([0, 1, 2], axis=source.injection_axis) medium_set = Scene.intersecting_media(source, structures) medium = medium_set.pop() if medium_set else sim_medium @@ -3099,7 +3127,7 @@ def plane_wave_boundaries(cls, val, values): else: num_bloch = sum(isinstance(bnd, (Periodic, BlochBoundary)) for bnd in boundary) if num_bloch > 0: - cls._check_bloch_vec( + self._check_bloch_vec( source=source, source_ind=source_ind, bloch_vec=boundary[0].bloch_vec, @@ -3107,28 +3135,27 @@ def plane_wave_boundaries(cls, val, values): medium=medium, domain_size=size[tan_dir], ) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["boundary_spec", "medium", "size", "structures", "sources"]) - def bloch_boundaries_diff_mnt(cls, val, values): + @model_validator(mode="after") + def bloch_boundaries_diff_mnt(self) -> Self: """Error if there are diffraction monitors incompatible with boundary conditions.""" - monitors = val + monitors = self.monitors - if not val or not any(isinstance(mnt, DiffractionMonitor) for mnt in monitors): - return val + if not monitors or not any(isinstance(mnt, DiffractionMonitor) for mnt in monitors): + return self - boundaries = values.get("boundary_spec").to_list - sources = values.get("sources") - size = values.get("size") - sim_medium = values.get("medium") - structures = values.get("structures") + boundaries = self.boundary_spec.to_list + sources = self.sources + size = self.size + sim_medium = self.medium + structures = self.structures for source_ind, source in enumerate(sources): if not isinstance(source, PlaneWave): continue - _, tan_dirs = cls.pop_axis([0, 1, 2], axis=source.injection_axis) + _, tan_dirs = self.pop_axis([0, 1, 2], axis=source.injection_axis) medium_set = Scene.intersecting_media(source, structures) medium = medium_set.pop() if medium_set else sim_medium @@ -3138,7 +3165,7 @@ def bloch_boundaries_diff_mnt(cls, val, values): # check the Bloch boundary + angled plane wave case num_bloch = sum(isinstance(bnd, (Periodic, BlochBoundary)) for bnd in boundary) if num_bloch > 0: - cls._check_bloch_vec( + self._check_bloch_vec( source=source, source_ind=source_ind, bloch_vec=boundary[0].bloch_vec, @@ -3147,18 +3174,17 @@ def bloch_boundaries_diff_mnt(cls, val, values): domain_size=size[tan_dir], has_diff_mnt=True, ) - return val + return self - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures", "sources"]) - def tfsf_boundaries(cls, val, values): + @model_validator(mode="after") + def tfsf_boundaries(self) -> Self: """Error if the boundary conditions are incompatible with TFSF sources, if any.""" - boundaries = val.to_list - 
sources = values.get("sources") - size = values.get("size") - center = values.get("center") - sim_medium = values.get("medium") - structures = values.get("structures") + boundaries = self.boundary_spec.to_list + sources = self.sources + size = self.size + center = self.center + sim_medium = self.medium + structures = self.structures sim_bounds = [ [c - s / 2.0 for c, s in zip(center, size)], [c + s / 2.0 for c, s in zip(center, size)], @@ -3167,7 +3193,7 @@ def tfsf_boundaries(cls, val, values): if not isinstance(source, TFSF): continue - norm_dir, tan_dirs = cls.pop_axis([0, 1, 2], axis=source.injection_axis) + norm_dir, tan_dirs = self.pop_axis([0, 1, 2], axis=source.injection_axis) src_bounds = source.bounds # make a dummy source that represents the injection surface to get the intersecting @@ -3204,7 +3230,7 @@ def tfsf_boundaries(cls, val, values): # Bloch vector has been correctly set, similar to the check for plane waves num_bloch = sum(isinstance(bnd, (Periodic, BlochBoundary)) for bnd in boundary) if num_bloch == 2: - cls._check_bloch_vec( + self._check_bloch_vec( source=source, source_ind=src_idx, bloch_vec=boundary[0].bloch_vec, @@ -3221,17 +3247,15 @@ def tfsf_boundaries(cls, val, values): "unless that boundary is 'Periodic' or 'BlochBoundary'." ) - return val + return self - @pydantic.validator("sources", always=True) - @skip_if_fields_missing(["symmetry"]) - def tfsf_with_symmetry(cls, val, values): + @model_validator(mode="after") + def tfsf_with_symmetry(self) -> Self: """Error if a TFSF source is applied with symmetry""" - symmetry = values.get("symmetry") - for source in val: - if isinstance(source, TFSF) and not all(sym == 0 for sym in symmetry): + for source in self.sources: + if isinstance(source, TFSF) and not all(sym == 0 for sym in self.symmetry): raise SetupError("TFSF sources cannot be used with symmetries.") - return val + return self @staticmethod def _get_fixed_angle_sources(sources: tuple[SourceType, ...]) -> tuple[SourceType, ...]: @@ -3241,15 +3265,12 @@ def _get_fixed_angle_sources(sources: tuple[SourceType, ...]) -> tuple[SourceTyp source for source in sources if isinstance(source, PlaneWave) and source._is_fixed_angle ] - @pydantic.root_validator() - @skip_if_fields_missing( - ["sources", "structures", "medium", "monitors", "internal_absorbers"], root=True - ) - def check_fixed_angle_components(cls, values): + @model_validator(mode="after") + def check_fixed_angle_components(self) -> Self: """Error if a fixed-angle plane wave is combined with other sources or fully anisotropic mediums or gain mediums.""" - fixed_angle_sources = cls._get_fixed_angle_sources(values["sources"]) + fixed_angle_sources = self._get_fixed_angle_sources(self.sources) if len(fixed_angle_sources) > 0: if len(fixed_angle_sources) > 1: @@ -3257,9 +3278,9 @@ def check_fixed_angle_components(cls, values): "A fixed-angle plane wave source cannot be combined with other sources." ) - structures = values.get("structures") + structures = self.structures structures = structures or [] - medium_bg = values.get("medium") + medium_bg = self.medium mediums = [medium_bg] + [structure.medium for structure in structures] if any(med.is_fully_anisotropic for med in mediums): @@ -3282,22 +3303,23 @@ def check_fixed_angle_components(cls, values): "Fixed-angle plane wave sources cannot be used in the presence of gain materials." 
                 )

-        if any(isinstance(mnt, TimeMonitor) for mnt in values["monitors"]):
+        if any(isinstance(mnt, TimeMonitor) for mnt in self.monitors):
             raise SetupError("Time monitors cannot be used in fixed-angle simulations.")

-        if len(values.get("internal_absorbers")) > 0:
+        if len(self.internal_absorbers) > 0:
             raise SetupError(
                 "Fixed-angle plane wave sources cannot be used in the presence of internal absorbers."
             )

-        return values
+        return self

-    @pydantic.root_validator()
-    @skip_if_fields_missing(["sources", "boundary_spec", "internal_absorbers"], root=True)
-    def _validate_frequency_mode_abc(cls, values):
+    @model_validator(mode="after")
+    def _validate_frequency_mode_abc(self) -> Self:
         """Warn if ModeABCBoundary expects a frequency from a source, but there are multiple
         sources with different central frequencies."""

-        def boundary_needs_freq(boundary):
+        def boundary_needs_freq(
+            boundary: Union[ModeABCBoundary, ABCBoundary, BoundaryEdgeType],
+        ) -> bool:
             return (isinstance(boundary, ModeABCBoundary) and boundary.freq_spec is None) or (
                 isinstance(boundary, ABCBoundary)
                 and (
@@ -3307,16 +3329,16 @@ def boundary_needs_freq(boundary):
             )

         # check domain boundaries
-        boundaries = values["boundary_spec"].to_list
+        boundaries = self.boundary_spec.to_list
         need_wavelength = any(boundary_needs_freq(edge) for edge in np.ravel(boundaries))

         # check internal absorbers
         need_wavelength = need_wavelength or any(
-            boundary_needs_freq(abc.boundary_spec) for abc in values["internal_absorbers"]
+            boundary_needs_freq(abc.boundary_spec) for abc in self.internal_absorbers
         )

         if need_wavelength:
-            sources = values.get("sources")
+            sources = self.sources

             if len(sources) == 0:
                 raise SetupError(
@@ -3332,27 +3354,29 @@ def boundary_needs_freq(boundary):
                 capture=False,
             )

-        return values
+        return self

-    @pydantic.validator("internal_absorbers", always=True)
-    @skip_if_fields_missing(["size"])
-    def _validate_absorber_in_zero_dims(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_absorber_in_zero_dims(self) -> Self:
         """Error if internal absorber is oriented along zero size dim."""
-
+        val = self.internal_absorbers
         if val is None:
-            return val
+            return self

-        sim_size = values["size"]
+        sim_size = self.size

         for abc in val:
             if sim_size[abc._normal_axis] == 0:
                 raise SetupError(
                     "Port absorbers are not allowed to be oriented along simulation zero size dimensions."
                 )

-        return val
+        return self

-    @pydantic.validator("sources", always=True)
-    def _validate_num_sources(cls, val):
+    @field_validator("sources")
+    @classmethod
+    def _validate_num_sources(
+        cls, val: Optional[tuple[SourceType, ...]]
+    ) -> Optional[tuple[SourceType, ...]]:
         """Error if too many sources are present."""

         if val is None:
@@ -3367,8 +3391,11 @@ def _validate_num_sources(cls, val):

         return val

-    @pydantic.validator("structures", always=True)
-    def _validate_2d_geometry_has_2d_medium(cls, val, values):
+    @field_validator("structures")
+    @classmethod
+    def _validate_2d_geometry_has_2d_medium(
+        cls, val: tuple[Structure, ...]
+    ) -> tuple[Structure, ...]:
         """Warn if a geometry bounding box has zero size in a certain dimension."""

         if val is None:
@@ -3392,8 +3419,11 @@ def _validate_2d_geometry_has_2d_medium(cls, val, values):

         return val

-    @pydantic.validator("structures", always=True)
-    def _validate_incompatible_material_intersections(cls, val, values):
+    @field_validator("structures")
+    @classmethod
+    def _validate_incompatible_material_intersections(
+        cls, val: tuple[Structure, ...]
+ ) -> tuple[Structure, ...]: """Check for intersections of incompatible materials.""" structures = val incompatible_indices = [] @@ -3422,24 +3452,24 @@ def _validate_incompatible_material_intersections(cls, val, values): ) return val - @pydantic.validator("boundary_spec", always=True) - @skip_if_fields_missing(["sources", "center", "size", "structures"]) - def _structures_not_close_pml(cls, val, values): + @model_validator(mode="after") + def _structures_not_close_pml(self) -> Self: """Warn if any structures lie at the simulation boundaries.""" + val = self.boundary_spec - sim_box = Box(size=values.get("size"), center=values.get("center")) + sim_box = Box(size=self.size, center=self.center) sim_bound_min, sim_bound_max = sim_box.bounds boundaries = val.to_list - structures = values.get("structures") - sources = values.get("sources") + structures = self.structures + sources = self.sources if (not structures) or (not sources): - return val + return self with log as consolidated_logger: - def warn(structure, istruct, side) -> None: + def warn(structure: Structure, istruct: int, side: str) -> None: """Warning message for a structure too close to PML.""" obj_descr = named_obj_descr(structure, "structures", istruct) consolidated_logger.warning( @@ -3480,19 +3510,18 @@ def warn(structure, istruct, side) -> None: ): warn(structure, istruct, axis + "-max") - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["medium", "structures"]) - def _warn_monitor_mediums_frequency_range(cls, val, values): + @model_validator(mode="after") + def _warn_monitor_mediums_frequency_range(self) -> Self: """Warn user if any DFT monitors have frequencies outside of medium frequency range.""" + val = self.monitors if val is None: - return val + return self - structures = values.get("structures") - structures = structures or [] - medium_bg = values.get("medium") + structures = self.structures or [] + medium_bg = self.medium mediums = [medium_bg] + [structure.medium for structure in structures] with log as consolidated_logger: @@ -3510,7 +3539,7 @@ def _warn_monitor_mediums_frequency_range(cls, val, values): # make sure medium frequency range includes all monitor frequencies fmin_med, fmax_med = medium.frequency_range - sci_fmin_med, sci_fmax_med = cls._scientific_notation(fmin_med, fmax_med) + sci_fmin_med, sci_fmax_med = self._scientific_notation(fmin_med, fmax_med) if fmin_mon < fmin_med or fmax_mon > fmax_med: if medium_index == 0: @@ -3521,7 +3550,7 @@ def _warn_monitor_mediums_frequency_range(cls, val, values): medium_str = f"The medium associated with {medium_descr}" custom_loc = [ "structures", - medium_index - 1, + str(medium_index - 1), "medium", "frequency_range", ] @@ -3534,29 +3563,28 @@ def _warn_monitor_mediums_frequency_range(cls, val, values): "This can cause inaccuracies in the recorded results.", custom_loc=custom_loc, ) + return self - return val - - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["sources"]) - def _warn_monitor_simulation_frequency_range(cls, val, values): + @model_validator(mode="after") + def _warn_monitor_simulation_frequency_range(self) -> Self: """Warn if any DFT monitors have frequencies outside of the simulation frequency range.""" + val = self.monitors if val is None: - return val + return self source_ranges = [ - source.source_time._frequency_range_sigma_cached for source in values["sources"] + source.source_time._frequency_range_sigma_cached for source in self.sources ] if not source_ranges: # 
Commented out to eliminate this message from Mode real time log in GUI
             # TODO: Bring it back when it doesn't interfere with mode solver
             # log.info("No sources in simulation.")
-            return val
+            return self

         freq_min = min((freq_range[0] for freq_range in source_ranges), default=0.0)
         freq_max = max((freq_range[1] for freq_range in source_ranges), default=0.0)

-        sci_fmin, sci_fmax = cls._scientific_notation(freq_min, freq_max)
+        sci_fmin, sci_fmax = self._scientific_notation(freq_min, freq_max)

         with log as consolidated_logger:
             for monitor_index, monitor in enumerate(val):
@@ -3571,15 +3599,14 @@ def _warn_monitor_simulation_frequency_range(cls, val, values):
                         "(Hz) as defined by the sources.",
                         custom_loc=["monitors", monitor_index, "freqs"],
                     )
-        return val
+        return self

-    @pydantic.validator("monitors", always=True)
-    @skip_if_fields_missing(["boundary_spec"])
-    def diffraction_monitor_boundaries(cls, val, values):
+    @model_validator(mode="after")
+    def diffraction_monitor_boundaries(self) -> Self:
         """If any :class:`.DiffractionMonitor` exists, ensure boundary conditions in the
         transverse directions are periodic or Bloch."""
-        monitors = val
-        boundary_spec = values.get("boundary_spec")
+        monitors = self.monitors
+        boundary_spec = self.boundary_spec
         for monitor in monitors:
             if isinstance(monitor, DiffractionMonitor):
                 _, (n_x, n_y) = monitor.pop_axis(["x", "y", "z"], axis=monitor.normal_axis)
@@ -3596,26 +3623,26 @@ def diffraction_monitor_boundaries(cls, val, values):
                         f"The 'DiffractionMonitor' {monitor.name} requires periodic "
                         f"or Bloch boundaries along dimensions {n_x} and {n_y}."
                     )
-        return val
+        return self

-    @pydantic.validator("monitors", always=True)
-    @skip_if_fields_missing(["medium", "center", "size", "structures"])
-    def _projection_monitors_homogeneous(cls, val, values):
+    @model_validator(mode="after")
+    def _projection_monitors_homogeneous(self) -> Self:
         """Error if any field projection monitor is not in a homogeneous region."""
+        val = self.monitors

         if val is None:
-            return val
+            return self

         # list of structures including background as a Box()
         structure_bg = Structure(
             geometry=Box(
-                size=values.get("size"),
-                center=values.get("center"),
+                size=self.size,
+                center=self.center,
             ),
-            medium=values.get("medium"),
+            medium=self.medium,
         )

-        structures = values.get("structures") or []
+        structures = self.structures or []
         total_structures = [structure_bg, *list(structures)]

         with log as consolidated_logger:
@@ -3643,7 +3670,7 @@ def _projection_monitors_homogeneous(cls, val, values):
                         custom_loc=["monitors", monitor_ind],
                     )

-        return val
+        return self

     @classmethod
     def _get_mediums_on_abc(
@@ -3682,23 +3709,22 @@ def _get_mediums_on_abc(

         return mediums

-    @pydantic.validator("boundary_spec", always=True)
-    @skip_if_fields_missing(["medium", "center", "size", "structures"])
-    def _abc_boundaries_homogeneous(cls, val, values):
+    @model_validator(mode="after")
+    def _abc_boundaries_homogeneous(self) -> Self:
         """Error if abc boundaries intersect multiple mediums or anisotropic mediums."""
-
+        val = self.boundary_spec
         if val is None:
-            return val
+            return self

         sim_structure = Structure(
-            geometry=Box(size=values.get("size"), center=values.get("center")),
-            medium=values.get("medium"),
+            geometry=Box(size=self.size, center=self.center),
+            medium=self.medium,
         )

-        mediums_all_sides = cls._get_mediums_on_abc(
+        mediums_all_sides = self._get_mediums_on_abc(
             boundary_spec=val,
             sim_structure=sim_structure,
-            structures=values.get("structures") or [],
+            structures=self.structures or [],
         )

         with log as consolidated_logger:
@@ -3731,15 +3757,15 @@ def _abc_boundaries_homogeneous(cls, val, values):
                 "Boundary medium must be homogeneous and isotropic."
             )

-        return val
+        return self

-    @pydantic.validator("monitors", always=True)
-    def _projection_direction(cls, val, values):
+    @field_validator("monitors")
+    @classmethod
+    def _projection_direction(cls, val: tuple[MonitorType, ...]) -> tuple[MonitorType, ...]:
         """Warn if field projection observation points are behind surface projection monitors."""
         # This validator is in simulation.py rather than monitor.py because volume monitors are
         # eventually converted to their bounding surface projection monitors, in which case we
         # do not want this validator to be triggered.
-
         if val is None:
             return val

@@ -3796,15 +3822,16 @@ def _projection_direction(cls, val, values):

         return val

-    @pydantic.validator("monitors", always=True)
-    @skip_if_fields_missing(["size"])
-    def proj_distance_for_approx(cls, val, values):
+    @model_validator(mode="after")
+    def proj_distance_for_approx(self) -> Self:
         """Warn if projection distance for projection monitors is not large compared to monitor
         or simulation size, yet far_field_approx is True."""
+        val = self.monitors

         if val is None:
-            return val
+            return self

-        sim_size = values.get("size")
+        sim_size = self.size

         with log as consolidated_logger:
             for monitor_ind, monitor in enumerate(val):
@@ -3823,18 +3850,18 @@ def proj_distance_for_approx(cls, val, values):
                         "size of the monitor that records near fields.",
                         custom_loc=["monitors", monitor_ind],
                     )
-        return val
+        return self

-    @pydantic.validator("monitors", always=True)
-    @skip_if_fields_missing(["center", "size"])
-    def _integration_surfaces_in_bounds(cls, val, values):
+    @model_validator(mode="after")
+    def _integration_surfaces_in_bounds(self) -> Self:
         """Error if all of the integration surfaces are outside of the simulation domain."""
+        val = self.monitors

         if val is None:
-            return val
+            return self

-        sim_center = values.get("center")
-        sim_size = values.get("size")
+        sim_center = self.center
+        sim_size = self.size
         sim_box = Box(size=sim_size, center=sim_center)

         for mnt in (mnt for mnt in val if isinstance(mnt, SurfaceIntegrationMonitor)):
@@ -3844,17 +3871,17 @@ def _integration_surfaces_in_bounds(cls, val, values):
                     "simulation bounds."
                 )

-        return val
+        return self

-    @pydantic.validator("monitors", always=True)
-    @skip_if_fields_missing(["size"])
-    def _projection_monitors_distance(cls, val, values):
+    @model_validator(mode="after")
+    def _projection_monitors_distance(self) -> Self:
         """Warn if the projection distance is large for exact projections."""
+        val = self.monitors

         if val is None:
-            return val
+            return self

-        sim_size = values.get("size")
+        sim_size = self.size

         with log as consolidated_logger:
             for idx, monitor in enumerate(val):
@@ -3877,11 +3904,10 @@ def _projection_monitors_distance(cls, val, values):
                         "available.",
                         custom_loc=["monitors", idx, "proj_distance"],
                     )
-        return val
+        return self

-    @pydantic.validator("monitors", always=True)
-    @skip_if_fields_missing(["size"])
-    def _projection_mnts_2d(cls, val, values):
+    @model_validator(mode="after")
+    def _projection_mnts_2d(self) -> Self:
         """
         Validate if the field projection monitor is set up for a 2D simulation and ensure the
         observation parameters are configured correctly.
@@ -3893,16 +3919,17 @@ def _projection_mnts_2d(cls, val, values):

         Note: Exact far field projection is not available yet. Currently, only
         ``far_field_approx = True`` is supported.
""" + val = self.monitors if val is None: - return val + return self - sim_size = values.get("size") + sim_size = self.size # Validation if is 3D simulation non_zero_dims = sum(1 for size in sim_size if size != 0) if non_zero_dims == 3: - return val + return self if sim_size[0] == 0: plane = "y-z" @@ -3982,15 +4009,14 @@ def _projection_mnts_2d(cls, val, values): f"'{monitor.name}' should be set to '[0]'." ) - return val + return self - @pydantic.validator("monitors", always=True) - @skip_if_fields_missing(["medium", "structures"]) - def diffraction_and_directivity_monitor_medium(cls, val, values): + @model_validator(mode="after") + def diffraction_and_directivity_monitor_medium(self) -> Self: """If any :class:`.DiffractionMonitor` or :class:`.DirectivityMonitor` exists, ensure it does not lie in a lossy medium.""" - monitors = val - structures = values.get("structures") - medium = values.get("medium") + monitors = self.monitors + structures = self.structures + medium = self.medium for monitor in monitors: if isinstance(monitor, (DiffractionMonitor, DirectivityMonitor)): medium_set = Scene.intersecting_media(monitor, structures) @@ -4001,23 +4027,23 @@ def diffraction_and_directivity_monitor_medium(cls, val, values): _, index_k = medium.nk_model(frequency=freqs) if not np.all(index_k == 0): raise SetupError(f"'{monitor.type}' must not lie in a lossy medium.") - return val + return self - @pydantic.validator("grid_spec", always=True) - @skip_if_fields_missing(["medium", "sources", "structures"]) - def _warn_grid_size_too_small(cls, val, values): + @model_validator(mode="after") + def _warn_grid_size_too_small(self) -> Self: """Warn user if any grid size is too large compared to minimum wavelength in material.""" + val = self.grid_spec if val is None: - return val + return self - structures = values.get("structures") + structures = self.structures structures = structures or [] - medium_bg = values.get("medium") + medium_bg = self.medium mediums = [medium_bg] + [structure.to_static().medium for structure in structures] with log as consolidated_logger: - for source_index, source in enumerate(values.get("sources")): + for source_index, source in enumerate(self.sources): freq0 = source.source_time._freq0 for medium_index, medium in enumerate(mediums): @@ -4064,28 +4090,28 @@ def _warn_grid_size_too_small(cls, val, values): ) # TODO: warn about custom grid spec - return val + return self - @pydantic.validator("sources", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures"]) - def _source_homogeneous_isotropic(cls, val, values): + @model_validator(mode="after") + def _source_homogeneous_isotropic(self) -> Self: """Error if a plane wave or gaussian beam source is not in a homogeneous and isotropic region. """ + val = self.sources if val is None: - return val + return self # list of structures including background as a Box() structure_bg = Structure( geometry=Box( - size=values.get("size"), - center=values.get("center"), + size=self.size, + center=self.center, ), - medium=values.get("medium"), + medium=self.medium, ) - structures = values.get("structures") or [] + structures = self.structures or [] total_structures = [structure_bg, *list(structures)] # for each plane wave in the sources list @@ -4132,6 +4158,7 @@ def _source_homogeneous_isotropic(cls, val, values): "A fixed angle plane wave can only be injected into a homogeneous isotropic" "dispersionless medium." 
) + # check if broadband angled gaussian beam frequency variation is too fast if ( isinstance(source, (GaussianBeam, AstigmaticGaussianBeam)) @@ -4139,7 +4166,7 @@ def _source_homogeneous_isotropic(cls, val, values): and source.num_freqs > 1 ): - def radius(waist_radius, waist_distance, k0): + def radius(waist_radius: float, waist_distance: float, k0: float) -> float: """Gaussian beam radius at a given waist distance and k0.""" z_r = waist_radius**2 * k0 / 2 return waist_radius * np.sqrt(1 + (waist_distance / z_r) ** 2) @@ -4179,18 +4206,18 @@ def radius(waist_radius, waist_distance, k0): "source injection in an empty simulation.", ) - return val + return self - @pydantic.validator("normalize_index", always=True) - @skip_if_fields_missing(["sources"]) - def _check_normalize_index(cls, val, values): + @model_validator(mode="after") + def _check_normalize_index(self) -> Self: """Check validity of normalize index in context of simulation.sources.""" + val = self.normalize_index # not normalizing if val is None: - return val + return self - sources = values.get("sources") + sources = self.sources num_sources = len(sources) if num_sources > 0: # No check if no sources, but it should be irrelevant anyway @@ -4218,15 +4245,16 @@ def _check_normalize_index(cls, val, values): "source is only meaningful if field decay occurs." ) - return val + return self - @pydantic.validator("low_freq_smoothing", always=True) - def _validate_low_freq_smoothing(cls, val, values): + @model_validator(mode="after") + def _validate_low_freq_smoothing(self) -> Self: """Validate the low frequency smoothing parameters.""" # check that all monitors are present and they are mode monitors + val = self.low_freq_smoothing if val is None: - return val - monitors = values.get("monitors") + return self + monitors = self.monitors present_mode_monitor_names = [ monitor.name for monitor in monitors if isinstance(monitor, ModeMonitor) ] @@ -4235,23 +4263,10 @@ def _validate_low_freq_smoothing(cls, val, values): raise SetupError( f"Low frequency smoothing specification refers to monitor '{monitor}' which either does not exist or is not a mode monitor." ) - return val + return self - """ Post-init validators """ - - def _post_init_validators(self) -> None: - """Call validators taking z`self` that get run after init.""" - _ = self.scene - self._validate_no_structures_pml() - self._validate_tfsf_nonuniform_grid() - self._validate_tfsf_aux_sources() - self._validate_nonlinear_specs() - self._validate_custom_source_time() - self._validate_mode_objects() - self._warn_rf_license() - self._validate_internal_abc_no_fully_anisotropic() - - def _warn_rf_license(self) -> None: + @model_validator(mode="after") + def _warn_rf_license(self) -> Self: """ Warn about new licensing requirements for RF simulations. This function details all the conditions in which a simulation is categorised as RF simulation at the backend. 
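[Editor's note: the hunks above and below all apply the same pydantic v1 → v2 conversion. Field-level checks become `@field_validator` classmethods that receive only the field value, while cross-field checks (formerly `@validator(..., always=True)` with a `values` dict, or `@root_validator`) become `@model_validator(mode="after")` methods that read other fields as attributes of the constructed instance. A minimal, self-contained sketch of the pattern follows; the `Widget` model and its fields are illustrative stand-ins, not tidy3d code.]

```python
from typing import Optional

from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self


class Widget(BaseModel):
    size: tuple[float, float, float] = (1.0, 1.0, 1.0)
    label: Optional[str] = Field(None, title="Label")

    # v2 replacement for `@validator("label", always=True)`: a classmethod
    # that sees only the field value and returns the (possibly coerced) value.
    @field_validator("label")
    @classmethod
    def _default_label(cls, val: Optional[str]) -> str:
        return "widget" if val is None else val

    # v2 replacement for validators that reached into `values`: runs on the
    # built instance, reads other fields as attributes, and returns `self`.
    @model_validator(mode="after")
    def _check_size(self) -> Self:
        if any(s < 0 for s in self.size):
            raise ValueError("'size' components must be non-negative.")
        return self
```

The detail most easily missed in this migration is that every `mode="after"` validator, including its early exits, must hand back `self` rather than the inspected field value, which is why `return val` → `return self` changes recur throughout this file.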
@@ -4285,7 +4300,10 @@ def _warn_rf_license(self) -> None: msg += rf_component_breakdown_msg log.warning(msg, log_once=True) - def _validate_mode_objects(self) -> None: + return self + + @model_validator(mode="after") + def _validate_mode_objects(self) -> Self: """Create a ModeSolver for each mode object in order to validate.""" from .mode.mode_solver import ModeSolver @@ -4337,7 +4355,10 @@ def validate_mode_object( except Exception as e: raise SetupError(f"Source at 'sources[{isrc}]' failed validation: {e!s}") from e - def _validate_custom_source_time(self) -> None: + return self + + @model_validator(mode="after") + def _validate_custom_source_time(self) -> Self: """Warn if all simulation times are outside CustomSourceTime definition range.""" run_time = self._run_time for idx, source in enumerate(self.sources): @@ -4354,8 +4375,10 @@ def _validate_custom_source_time(self) -> None: "from the first or last value in the 'CustomSourceTime', which may not " "be the desired outcome." ) + return self - def _validate_no_structures_pml(self) -> None: + @model_validator(mode="after") + def _validate_no_structures_pml(self) -> Self: """Ensure no structures terminate / have bounds inside of PML.""" pml_thicks = np.array(self.pml_thicknesses).T @@ -4363,7 +4386,7 @@ def _validate_no_structures_pml(self) -> None: bound_spec = self.boundary_spec.to_list with log as consolidated_logger: - for i, structure in enumerate(self.structures): + for i, structure in enumerate(self.static_structures): geo_bounds = structure.geometry.bounds warn = False # will only warn once per structure for sim_bound, geo_bound, pml_thick, bound_dim, pm_val in zip( @@ -4388,13 +4411,16 @@ def _validate_no_structures_pml(self) -> None: custom_loc=["structures", i], ) - def _validate_tfsf_nonuniform_grid(self) -> None: + return self + + @model_validator(mode="after") + def _validate_tfsf_nonuniform_grid(self) -> Self: """Warn if the grid is nonuniform along the directions tangential to the injection plane, inside the TFSF box. """ # if the grid is uniform in all directions, there's no need to proceed if not (self.grid_spec.snapped_grid_used or self.grid_spec.custom_grid_used): - return + return self with log as consolidated_logger: for source_ind, source in enumerate(self.sources): @@ -4428,6 +4454,7 @@ def _validate_tfsf_nonuniform_grid(self) -> None: f"axis, '{'xyz'[source.injection_axis]}'.", custom_loc=["sources", source_ind], ) + return self def _aux_tfsf_source(self, source: TFSF) -> PlaneWave: """Create the auxiliary plane wave source for a give TFSF source.""" @@ -4471,13 +4498,16 @@ def _aux_tfsf_source(self, source: TFSF) -> PlaneWave: num_freqs=source.num_freqs, ) - def _validate_tfsf_aux_sources(self) -> None: + @model_validator(mode="after") + def _validate_tfsf_aux_sources(self) -> Self: """Validate that PlaneWave sources auxiliary to TFSF sources can be successfully created.""" for source in self.sources: if isinstance(source, TFSF): _ = self._aux_tfsf_source(source) + return self - def _validate_nonlinear_specs(self) -> None: + @model_validator(mode="after") + def _validate_nonlinear_specs(self) -> Self: """Run :class:`.NonlinearSpec` validators that depend on knowing the central frequencies of the sources. Also print some warnings only once per unique medium.""" freqs = np.array([source.source_time._freq0 for source in self.sources]) @@ -4498,6 +4528,8 @@ def _validate_nonlinear_specs(self) -> None: "will be zero." 
                 )

+        return self
+
     @cached_property
     def aux_fields(self) -> list[str]:
         """All aux fields available in the simulation."""
@@ -4507,16 +4539,18 @@ def aux_fields(self) -> list[str]:
             fields += medium.nonlinear_spec.aux_fields
         return fields

-    def _validate_internal_abc_no_fully_anisotropic(self) -> None:
+    @model_validator(mode="after")
+    def _validate_internal_abc_no_fully_anisotropic(self) -> Self:
         """Error if an internal absorber intersects fully anisotropic mediums."""

         total_structures = [self.scene.background_structure, *list(self.structures)]

         for abc in self._shifted_internal_absorbers:
-            mediums = Scene.intersecting_media(abc, total_structures)
+            mediums = Scene.intersecting_media(abc, tuple(total_structures))
             if any(isinstance(med, FullyAnisotropicMedium) for med in mediums):
                 raise SetupError("An 'InternalAbsorber' cannot cross a 'FullyAnisotropicMedium'.")

+        return self
+
     """ Pre submit validation (before web.upload()) """

@@ -4884,6 +4918,7 @@ def _make_adjoint_monitors(self, sim_fields_keys: list) -> tuple[list, list]:
             index_to_keys[index].append(fields)

         freqs = self._freqs_adjoint
+        sim_plane = self if self.size.count(0.0) == 1 else None

         adjoint_monitors_fld = []
         adjoint_monitors_eps = []
@@ -4893,7 +4928,7 @@ def _make_adjoint_monitors(self, sim_fields_keys: list) -> tuple[list, list]:
             structure = self.structures[i]

             mnt_fld, mnt_eps = structure._make_adjoint_monitors(
-                freqs=freqs, index=i, field_keys=field_keys
+                freqs=freqs, index=i, field_keys=field_keys, plane=sim_plane
             )

             adjoint_monitors_fld.append(mnt_fld)
@@ -4901,6 +4936,22 @@ def _make_adjoint_monitors(self, sim_fields_keys: list) -> tuple[list, list]:

         return adjoint_monitors_fld, adjoint_monitors_eps

+    def _check_custom_medium_geometry_overlap(self, sim_fields_keys: AutogradFieldMap) -> None:
+        """Error if a traced structure combines a custom medium with traced geometry."""
+        index_to_keys = defaultdict(list)
+
+        for _, index, *fields in sim_fields_keys:
+            index_to_keys[index].append(fields)
+
+        for structure_index, gradient_paths in index_to_keys.items():
+            if self.structures[structure_index].medium.is_custom:
+                gradient_type_tags = [path[0] for path in gradient_paths]
+                if "geometry" in gradient_type_tags:
+                    raise AdjointError(
+                        f"Detected structure at index {structure_index} containing a CustomMedium type "
+                        "and traced geometry attributes. Combined shape and medium derivatives like this "
+                        "are not currently supported."
+                    )
+
     @property
     def _freqs_adjoint(self) -> list[float]:
         """Unique list of all frequencies. For now should be only one."""
@@ -4962,7 +5013,7 @@ def mediums(self) -> set[MediumType]:

         Returns
         -------
         set[:class:`.AbstractMedium`]
             Set of distinct mediums in the simulation.
         """
         log.warning(
@@ -4973,14 +5024,14 @@ def mediums(self) -> set[MediumType]:

     # candidate for removal in 3.0
     @cached_property
-    def medium_map(self) -> dict[MediumType, pydantic.NonNegativeInt]:
+    def medium_map(self) -> dict[MediumType, NonNegativeInt]:
         """Returns dict mapping medium to index in material.
         ``medium_map[medium]`` returns unique global index of :class:`.AbstractMedium`
         in simulation.

         Returns
         -------
-        Dict[:class:`.AbstractMedium`, int]
+        dict[:class:`.AbstractMedium`, int]
             Mapping between distinct mediums to index in simulation.
         """

@@ -5024,7 +5075,7 @@ def intersecting_media(
         -------
         test_object : :class:`.Box`
             Object for which intersecting media are to be detected.
-        structures : tuple[:class:`.AbstractMedium`]
+        structures : list[:class:`.AbstractMedium`]
             List of structures whose media will be tested.
Returns @@ -5067,7 +5118,7 @@ def intersecting_structures( ) return Scene.intersecting_structures(test_object=test_object, structures=structures) - def monitor_medium(self, monitor: MonitorType): + def monitor_medium(self, monitor: MonitorType) -> AbstractMedium: """Return the medium in which the given monitor resides. Parameters @@ -5140,10 +5191,10 @@ def to_gdstk( x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, gds_layer_dtype_map: Optional[ - dict[AbstractMedium, tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt]] + dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] ] = None, pixel_exact: bool = False, ) -> list: @@ -5211,14 +5262,14 @@ def to_gdstk( def to_gds( self, - cell, + cell: gdstk.Cell, x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, gds_layer_dtype_map: Optional[ - dict[AbstractMedium, tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt]] + dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] ] = None, pixel_exact: bool = False, ) -> None: @@ -5273,10 +5324,10 @@ def to_gds_file( x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, gds_layer_dtype_map: Optional[ - dict[AbstractMedium, tuple[pydantic.NonNegativeInt, pydantic.NonNegativeInt]] + dict[AbstractMedium, tuple[NonNegativeInt, NonNegativeInt]] ] = None, gds_cell_name: str = "MAIN", pixel_exact: bool = False, @@ -5351,7 +5402,7 @@ def frequency_range(self) -> FreqBound: Returns ------- - Tuple[float, float] + tuple[float, float] Minimum and maximum frequencies of the power spectrum of the sources. """ source_ranges = [ @@ -5362,7 +5413,7 @@ def frequency_range(self) -> FreqBound: return (freq_min, freq_max) - def plot_3d(self, width=800, height=800) -> None: + def plot_3d(self, width: int = 800, height: int = 800) -> None: """Render 3D plot of ``Simulation`` (in jupyter notebook only). Parameters ---------- @@ -5523,7 +5574,7 @@ def num_cells(self) -> int: return int(np.prod([float(nc) for nc in self.grid.num_cells])) @property - def _num_computational_grid_points_dim(self): + def _num_computational_grid_points_dim(self) -> list[int]: """Number of cells in the computational domain for this simulation along each dimension.""" num_cells = self.grid.num_cells num_cells_comp_domain = [] @@ -5538,7 +5589,7 @@ def _num_computational_grid_points_dim(self): return num_cells_comp_domain @property - def num_computational_grid_points(self): + def num_computational_grid_points(self) -> int: """Number of cells in the computational domain for this simulation. This is usually different from ``num_cells`` due to the boundary conditions. 
Specifically, all boundary conditions apart from :class:`Periodic` require an extra pixel
         at the end of the simulation
@@ -5778,7 +5829,7 @@ def perturbed_mediums_copy(
                     normal_axis=data.normal_axis,
                 )

-        sim_dict = self.dict()
+        sim_dict = self.model_dump()
         structures = self.structures
         sim_bounds = self.simulation_bounds
         array_dict = {
@@ -5850,7 +5901,7 @@ def perturbed_mediums_copy(
             **restricted_arrays, interp_method=interp_method
         )

-        return Simulation.parse_obj(sim_dict)
+        return Simulation.model_validate(sim_dict)

     @classmethod
     def from_scene(cls, scene: Scene, **kwargs: Any) -> Simulation:
@@ -5893,19 +5944,19 @@ def from_scene(cls, scene: Scene, **kwargs: Any) -> Simulation:

     def padded_copy(
         self,
-        x: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
-        y: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
-        z: Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None,
+        x: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = None,
+        y: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = None,
+        z: Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = None,
     ) -> Simulation:
         """Create a copy of the simulation with a padded simulation domain.

         Parameters
         ----------
-        x : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
+        x : Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = None
             Padding sizes at the left and right boundaries of the simulation along x-axis.
-        y : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
+        y : Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = None
             Padding sizes at the left and right boundaries of the simulation along y-axis.
-        z : Optional[tuple[pydantic.NonNegativeFloat, pydantic.NonNegativeFloat]] = None
+        z : Optional[tuple[NonNegativeFloat, NonNegativeFloat]] = None
             Padding sizes at the left and right boundaries of the simulation along z-axis.

         Returns
@@ -5919,12 +5970,12 @@ def padded_copy(

         return self.updated_copy(size=padded_box.size, center=padded_box.center)

-    def uniformly_padded_copy(self, padding: pydantic.NonNegativeFloat) -> Simulation:
+    def uniformly_padded_copy(self, padding: NonNegativeFloat) -> Simulation:
         """Create a copy of the simulation with a uniformly padded simulation domain.

         Parameters
         ----------
-        padding : pydantic.NonNegativeFloat
+        padding : NonNegativeFloat
             Padding size applied uniformly at all simulation boundaries.
Returns diff --git a/tidy3d/components/source/base.py b/tidy3d/components/source/base.py index 0b76239d08..f080f49ad7 100644 --- a/tidy3d/components/source/base.py +++ b/tidy3d/components/source/base.py @@ -1,130 +1,10 @@ -"""Defines an abstract base for electromagnetic sources.""" +"""Compatibility shim for :mod:`tidy3d._common.components.source.base`.""" -from __future__ import annotations - -from abc import ABC -from typing import Any, Optional +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -import pydantic.v1 as pydantic +# marked as migrated to _common +from __future__ import annotations -from tidy3d.components.base import cached_property -from tidy3d.components.base_sim.source import AbstractSource -from tidy3d.components.geometry.base import Box -from tidy3d.components.types import TYPE_TAG_STR, Ax -from tidy3d.components.validators import _assert_min_freq, _warn_unsupported_traced_argument -from tidy3d.components.viz import ( - ARROW_ALPHA, - ARROW_COLOR_POLARIZATION, - ARROW_COLOR_SOURCE, - PlotParams, - plot_params_source, +from tidy3d._common.components.source.base import ( + Source, ) - -from .time import SourceTimeType - - -class Source(Box, AbstractSource, ABC): - """Abstract base class for all sources.""" - - source_time: SourceTimeType = pydantic.Field( - ..., - title="Source Time", - description="Specification of the source time-dependence.", - discriminator=TYPE_TAG_STR, - ) - - @cached_property - def plot_params(self) -> PlotParams: - """Default parameters for plotting a Source object.""" - return plot_params_source - - @cached_property - def geometry(self) -> Box: - """:class:`Box` representation of source.""" - - return Box(center=self.center, size=self.size) - - @cached_property - def _injection_axis(self) -> None: - """Injection axis of the source.""" - return - - @cached_property - def _dir_vector(self) -> tuple[float, float, float]: - """Returns a vector indicating the source direction for arrow plotting, if not None.""" - return None - - @cached_property - def _pol_vector(self) -> tuple[float, float, float]: - """Returns a vector indicating the source polarization for arrow plotting, if not None.""" - return None - - _warn_traced_center = _warn_unsupported_traced_argument("center") - _warn_traced_size = _warn_unsupported_traced_argument("size") - - @pydantic.validator("source_time", always=True) - def _freqs_lower_bound(cls, val): - """Raise validation error if central frequency is too low.""" - _assert_min_freq(val._freq0_sigma_centroid, msg_start="'source_time.freq0'") - return val - - def plot( - self, - x: Optional[float] = None, - y: Optional[float] = None, - z: Optional[float] = None, - ax: Ax = None, - **patch_kwargs: Any, - ) -> Ax: - """Plot this source.""" - - kwargs_arrow_base = patch_kwargs.pop("arrow_base", None) - - # call the `Source.plot()` function first. 
- ax = Box.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) - - kwargs_alpha = patch_kwargs.get("alpha") - arrow_alpha = ARROW_ALPHA if kwargs_alpha is None else kwargs_alpha - - # then add the arrow based on the propagation direction - if self._dir_vector is not None: - bend_radius = None - bend_axis = None - if hasattr(self, "mode_spec") and self.mode_spec.bend_radius is not None: - bend_radius = self.mode_spec.bend_radius - bend_axis = self._bend_axis - sign = 1 if self.direction == "+" else -1 - # Curvature has to be reversed because of ploting coordinates - if (self.size.index(0), bend_axis) in [(1, 2), (2, 0), (2, 1)]: - bend_radius *= -sign - else: - bend_radius *= sign - - ax = self._plot_arrow( - x=x, - y=y, - z=z, - ax=ax, - direction=self._dir_vector, - bend_radius=bend_radius, - bend_axis=bend_axis, - color=ARROW_COLOR_SOURCE, - alpha=arrow_alpha, - both_dirs=False, - arrow_base=kwargs_arrow_base, - ) - - if self._pol_vector is not None: - ax = self._plot_arrow( - x=x, - y=y, - z=z, - ax=ax, - direction=self._pol_vector, - color=ARROW_COLOR_POLARIZATION, - alpha=arrow_alpha, - both_dirs=False, - arrow_base=kwargs_arrow_base, - ) - - return ax diff --git a/tidy3d/components/source/current.py b/tidy3d/components/source/current.py index 994ef1eff3..a7a53770a3 100644 --- a/tidy3d/components/source/current.py +++ b/tidy3d/components/source/current.py @@ -4,10 +4,9 @@ from abc import ABC from math import cos, isclose, sin -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional -import pydantic.v1 as pydantic -from typing_extensions import Literal +from pydantic import Field from tidy3d.components.base import cached_property from tidy3d.components.data.dataset import FieldDataset @@ -17,14 +16,15 @@ from tidy3d.constants import MICROMETER from .base import Source -from .time import SourceTimeType + +if TYPE_CHECKING: + from .time import SourceTimeType class CurrentSource(Source, ABC): """Source implements a current distribution directly.""" - polarization: Polarization = pydantic.Field( - ..., + polarization: Polarization = Field( title="Polarization", description="Specifies the direction and type of current component.", ) @@ -36,13 +36,13 @@ def _pol_vector(self) -> tuple[float, float, float]: pol_axis = "xyz".index(component) pol_vec = [0, 0, 0] pol_vec[pol_axis] = 1 - return pol_vec + return tuple(pol_vec) class ReverseInterpolatedSource(Source): """Abstract source that allows reverse-interpolation along zero-sized dimensions.""" - interpolate: bool = pydantic.Field( + interpolate: bool = Field( True, title="Enable Interpolation", description="Handles reverse-interpolation of zero-size dimensions of the source. 
" @@ -51,7 +51,7 @@ class ReverseInterpolatedSource(Source): "placement at the specified location using linear interpolation.", ) - confine_to_bounds: bool = pydantic.Field( + confine_to_bounds: bool = Field( False, title="Confine to Analytical Bounds", description="If ``True``, any source amplitudes which, after discretization, fall beyond " @@ -102,7 +102,7 @@ class PointDipole(CurrentSource, ReverseInterpolatedSource): * `Adjoint optimization of quantum emitter light extraction to an integrated waveguide <../../notebooks/AdjointPlugin12LightExtractor.html>`_ """ - size: tuple[Literal[0], Literal[0], Literal[0]] = pydantic.Field( + size: tuple[Literal[0], Literal[0], Literal[0]] = Field( (0, 0, 0), title="Size", description="Size in x, y, and z directions, constrained to ``(0, 0, 0)``.", @@ -208,8 +208,7 @@ class CustomCurrentSource(ReverseInterpolatedSource): * `Defining spatially-varying sources <../../notebooks/CustomFieldSource.html>`_ """ - current_dataset: Optional[FieldDataset] = pydantic.Field( - ..., + current_dataset: Optional[FieldDataset] = Field( title="Current Dataset", description=":class:`.FieldDataset` containing the desired frequency-domain " "electric and magnetic current patterns to inject.", diff --git a/tidy3d/components/source/field.py b/tidy3d/components/source/field.py index 4255561698..0c94af0ea9 100644 --- a/tidy3d/components/source/field.py +++ b/tidy3d/components/source/field.py @@ -3,17 +3,17 @@ from __future__ import annotations from abc import ABC -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np -import pydantic.v1 as pydantic +from pydantic import Field, NonNegativeInt, PositiveFloat, field_validator, model_validator -from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.data.dataset import FieldDataset from tidy3d.components.data.validators import validate_can_interpolate, validate_no_nans from tidy3d.components.mode_spec import ModeSpec from tidy3d.components.source.frame import PECFrame -from tidy3d.components.types import TYPE_TAG_STR, Ax, Axis, Coordinate, Direction +from tidy3d.components.types import TYPE_TAG_STR, Axis, Direction from tidy3d.components.types.mode_spec import ModeSpecType from tidy3d.components.validators import ( assert_plane, @@ -28,6 +28,12 @@ from .base import Source +if TYPE_CHECKING: + from numpy.typing import NDArray + + from tidy3d.compat import Self + from tidy3d.components.types import Ax, Coordinate + # width of Chebyshev grid used for broadband sources (in units of pulse width) CHEB_GRID_WIDTH = 1.5 # For broadband plane waves with constan in-plane k, the Chebyshev grid is truncated at @@ -49,12 +55,12 @@ class PlanarSource(Source, ABC): _plane_validator = assert_plane() @cached_property - def injection_axis(self): + def injection_axis(self) -> Axis: """Injection axis of the source.""" return self._injection_axis @cached_property - def _injection_axis(self): + def _injection_axis(self) -> Axis: """Injection axis of the source.""" return self.size.index(0.0) @@ -71,27 +77,26 @@ class VolumeSource(Source, ABC): class DirectionalSource(FieldSource, ABC): """A Field source that propagates in a given direction.""" - direction: Direction = pydantic.Field( - ..., + direction: Direction = Field( title="Direction", description="Specifies propagation in the positive or negative direction of the injection " "axis.", ) @cached_property - 
def _dir_vector(self) -> tuple[float, float, float]: + def _dir_vector(self) -> Optional[tuple[float, float, float]]: """Returns a vector indicating the source direction for arrow plotting, if not None.""" if self._injection_axis is None: return None dir_vec = [0, 0, 0] dir_vec[int(self._injection_axis)] = 1 if self.direction == "+" else -1 - return dir_vec + return tuple(dir_vec) class BroadbandSource(Source, ABC): """A source with frequency dependent field distributions.""" - num_freqs: int = pydantic.Field( + num_freqs: int = Field( 1, title="Number of Frequency Points", description="Number of points used to approximate the frequency dependence of the injected " @@ -104,14 +109,14 @@ class BroadbandSource(Source, ABC): ) @cached_property - def frequency_grid(self) -> np.ndarray: + def frequency_grid(self) -> NDArray: """A Chebyshev grid used to approximate frequency dependence.""" if self.num_freqs == 1: return np.array([self.source_time._freq0]) freq_min, freq_max = self.source_time.frequency_range_sigma(sigma=CHEB_GRID_WIDTH) return self._chebyshev_freq_grid(freq_min, freq_max) - def _chebyshev_freq_grid(self, freq_min, freq_max): + def _chebyshev_freq_grid(self, freq_min: float, freq_max: float) -> NDArray: """A Chebyshev grid based on a minimum and maximum frequency.""" freq_avg = 0.5 * (freq_min + freq_max) freq_diff = 0.5 * (freq_max - freq_min) @@ -216,8 +221,8 @@ class CustomFieldSource(FieldSource, PlanarSource): * `Defining spatially-varying sources <../../notebooks/CustomFieldSource.html>`_ """ - field_dataset: Optional[FieldDataset] = pydantic.Field( - ..., + field_dataset: Optional[FieldDataset] = Field( + None, title="Field Dataset", description=":class:`.FieldDataset` containing the desired frequency-domain " "fields patterns to inject. 
At least one tangential field component must be specified.", @@ -228,20 +233,19 @@ class CustomFieldSource(FieldSource, PlanarSource): _field_dataset_single_freq = assert_single_freq_in_range("field_dataset") _can_interpolate = validate_can_interpolate("field_dataset") - @pydantic.validator("field_dataset", always=True) - @skip_if_fields_missing(["size"]) - def _tangential_component_defined(cls, val: FieldDataset, values: dict) -> FieldDataset: + @model_validator(mode="after") + def _tangential_component_defined(self) -> FieldDataset: """Assert that at least one tangential field component is provided.""" + val = self.field_dataset if val is None: - return val - size = values.get("size") - normal_axis = size.index(0.0) - _, (cmp1, cmp2) = cls.pop_axis("xyz", axis=normal_axis) + return self + normal_axis = self.size.index(0.0) + _, (cmp1, cmp2) = self.pop_axis("xyz", axis=normal_axis) for field in "EH": for cmp_name in (cmp1, cmp2): tangential_field = field + cmp_name if tangential_field in val.field_components: - return val + return self raise SetupError("No tangential field found in the suppled 'field_dataset'.") @@ -262,14 +266,14 @@ class AngledFieldSource(DirectionalSource, ABC): """ - angle_theta: float = pydantic.Field( + angle_theta: float = Field( 0.0, title="Polar Angle", description="Polar angle of the propagation axis from the injection axis.", units=RADIAN, ) - angle_phi: float = pydantic.Field( + angle_phi: float = Field( 0.0, title="Azimuth Angle", description="Azimuth angle of the propagation axis in the plane orthogonal to the " @@ -277,7 +281,7 @@ class AngledFieldSource(DirectionalSource, ABC): units=RADIAN, ) - pol_angle: float = pydantic.Field( + pol_angle: float = Field( 0, title="Polarization Angle", description="Specifies the angle between the electric field polarization of the " @@ -291,8 +295,9 @@ class AngledFieldSource(DirectionalSource, ABC): units=RADIAN, ) - @pydantic.validator("angle_theta", allow_reuse=True, always=True) - def glancing_incidence(cls, val): + @field_validator("angle_theta") + @classmethod + def glancing_incidence(cls, val: float) -> float: """Warn if close to glancing incidence.""" if np.abs(np.pi / 2 - val) < GLANCING_CUTOFF: log.warning( @@ -393,14 +398,14 @@ class ModeSource(DirectionalSource, PlanarSource, BroadbandSource): * `Prelude to Integrated Photonics Simulation: Mode Injection `_ """ - mode_spec: ModeSpecType = pydantic.Field( - ModeSpec(), + mode_spec: ModeSpecType = Field( + default_factory=ModeSpec, title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", discriminator=TYPE_TAG_STR, ) - mode_index: pydantic.NonNegativeInt = pydantic.Field( + mode_index: NonNegativeInt = Field( 0, title="Mode Index", description="Index into the collection of modes returned by mode solver. 
" @@ -409,7 +414,7 @@ class ModeSource(DirectionalSource, PlanarSource, BroadbandSource): "``num_modes`` in the solver will be set to ``mode_index + 1``.", ) - frame: Optional[PECFrame] = pydantic.Field( + frame: Optional[PECFrame] = Field( None, title="Source Frame", description="Add a thin frame around the source during the FDTD run to improve " @@ -418,12 +423,12 @@ class ModeSource(DirectionalSource, PlanarSource, BroadbandSource): ) @cached_property - def angle_theta(self): + def angle_theta(self) -> float: """Polar angle of propagation.""" return self.mode_spec.angle_theta @cached_property - def angle_phi(self): + def angle_phi(self) -> float: """Azimuth angle of propagation.""" return self.mode_spec.angle_phi @@ -437,7 +442,7 @@ def _dir_vector(self) -> tuple[float, float, float]: return self.unpop_axis(dz, (dx, dy), axis=self._injection_axis) @cached_property - def _bend_axis(self) -> Axis: + def _bend_axis(self) -> Optional[Axis]: if self.mode_spec.bend_radius is None: return None in_plane = [0, 0] @@ -493,14 +498,14 @@ class PlaneWave(AngledFieldSource, PlanarSource, BroadbandSource): * `Using FDTD to Compute a Transmission Spectrum `__ """ - angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = pydantic.Field( - FixedInPlaneKSpec(), + angular_spec: Union[FixedInPlaneKSpec, FixedAngleSpec] = Field( + default_factory=FixedInPlaneKSpec, title="Angular Dependence Specification", description="Specification of plane wave propagation direction dependence on wavelength.", discriminator=TYPE_TAG_STR, ) - num_freqs: int = pydantic.Field( + num_freqs: int = Field( 3, title="Number of Frequency Points", description="Number of points used to approximate the frequency dependence of the injected " @@ -517,7 +522,7 @@ def _is_fixed_angle(self) -> bool: return isinstance(self.angular_spec, FixedAngleSpec) and self.angle_theta != 0.0 @cached_property - def frequency_grid(self) -> np.ndarray: + def frequency_grid(self) -> NDArray: """A Chebyshev grid used to approximate frequency dependence.""" if self.num_freqs == 1: return np.array([self.source_time._freq0]) @@ -529,11 +534,12 @@ def frequency_grid(self) -> np.ndarray: freq_min = max(freq_min, f_crit * CRITICAL_FREQUENCY_FACTOR) return self._chebyshev_freq_grid(freq_min, freq_max) - def _post_init_validators(self) -> None: + @model_validator(mode="after") + def _validate_source_frequency_range(self) -> Self: """Error if a broadband plane wave with constant in-plane k is defined such that the source frequency range is entirely below ``f_crit * CRITICAL_FREQUENCY_FACTOR.""" if self._is_fixed_angle or self.num_freqs == 1: - return + return self freq_min, freq_max = self.source_time.frequency_range_sigma(sigma=CHEB_GRID_WIDTH) f_crit = self.source_time._freq0 * np.sin(self.angle_theta) if f_crit * CRITICAL_FREQUENCY_FACTOR > freq_max: @@ -542,6 +548,7 @@ def _post_init_validators(self) -> None: "frequency of oblique incidence. Increase the source bandwidth, or disable the " "broadband handling by setting 'num_freqs' to 1." 
) + return self class GaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): @@ -573,14 +580,14 @@ class GaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): * `Inverse taper edge coupler <../../notebooks/EdgeCoupler.html>`_ """ - waist_radius: pydantic.PositiveFloat = pydantic.Field( + waist_radius: PositiveFloat = Field( 1.0, title="Waist Radius", description="Radius of the beam at the waist.", units=MICROMETER, ) - waist_distance: float = pydantic.Field( + waist_distance: float = Field( 0.0, title="Waist Distance", description="Distance from the beam waist along the propagation direction. " @@ -592,7 +599,7 @@ class GaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): units=MICROMETER, ) - num_freqs: int = pydantic.Field( + num_freqs: int = Field( 1, title="Number of Frequency Points", description="Number of points used to approximate the frequency dependence of the injected " @@ -634,14 +641,14 @@ class AstigmaticGaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): ... waist_distances = (3.0, 4.0)) """ - waist_sizes: tuple[pydantic.PositiveFloat, pydantic.PositiveFloat] = pydantic.Field( + waist_sizes: tuple[PositiveFloat, PositiveFloat] = Field( (1.0, 1.0), title="Waist sizes", description="Size of the beam at the waist in the local x and y directions.", units=MICROMETER, ) - waist_distances: tuple[float, float] = pydantic.Field( + waist_distances: tuple[float, float] = Field( (0.0, 0.0), title="Waist distances", description="Distance to the beam waist along the propagation direction " @@ -653,7 +660,7 @@ class AstigmaticGaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): units=MICROMETER, ) - num_freqs: int = pydantic.Field( + num_freqs: int = Field( 1, title="Number of Frequency Points", description="Number of points used to approximate the frequency dependence of the injected " @@ -664,7 +671,7 @@ class AstigmaticGaussianBeam(AngledFieldSource, PlanarSource, BroadbandSource): ge=1, le=20, ) - _backward_waist_warning = warn_backward_waist_distance("waist_distances") + backward_waist_warning = warn_backward_waist_distance("waist_distances") class TFSF(AngledFieldSource, VolumeSource, BroadbandSource): @@ -700,8 +707,7 @@ class TFSF(AngledFieldSource, VolumeSource, BroadbandSource): * `Nanoparticle Scattering <../../notebooks/PlasmonicNanoparticle.html>`_: To force a uniform grid in the TFSF region and avoid the warnings, a mesh override structure can be used as illustrated here. """ - injection_axis: Axis = pydantic.Field( - ..., + injection_axis: Axis = Field( title="Injection Axis", description="Specifies the injection axis. The plane of incidence is defined via this " "``injection_axis`` and the ``direction``. 
The propagation axis is defined with respect "
@@ -709,7 +715,7 @@
     )

     @cached_property
-    def _injection_axis(self):
+    def _injection_axis(self) -> Axis:
         """Injection axis of the source."""
         return self.injection_axis

diff --git a/tidy3d/components/source/frame.py b/tidy3d/components/source/frame.py
index a16e9708ef..00e3847efd 100644
--- a/tidy3d/components/source/frame.py
+++ b/tidy3d/components/source/frame.py
@@ -4,7 +4,7 @@

 from abc import ABC

-import pydantic.v1 as pydantic
+from pydantic import Field

 from tidy3d.components.base import Tidy3dBaseModel

@@ -12,7 +12,7 @@ class AbstractSourceFrame(Tidy3dBaseModel, ABC):
     """Abstract base class for all source frames."""

-    length: int = pydantic.Field(
+    length: int = Field(
         2,
         title="Length",
         description="The length of the frame, specified as the number of cells along the source "
diff --git a/tidy3d/components/source/freq_range.py b/tidy3d/components/source/freq_range.py
new file mode 100644
index 0000000000..df5d0e041d
--- /dev/null
+++ b/tidy3d/components/source/freq_range.py
@@ -0,0 +1,230 @@
+"""Utility class ``FreqRange`` for frequency and wavelength handling."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+from pydantic import Field, PositiveFloat
+
+from tidy3d import constants as td_const
+from tidy3d.components.base import Tidy3dBaseModel
+from tidy3d.components.source.time import GaussianPulse
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+class FreqRange(Tidy3dBaseModel):
+    """
+    Convenience class for handling frequency/wavelength conversion; it simplifies specification
+    of frequency ranges and sample points for sources and monitors.
+
+    Notes
+    -----
+    Depending on the context, the user can define the desired frequency range by specifying:
+    - central frequency ``freq0`` and frequency bandwidth ``fwidth``;
+    - frequency interval [``fmin``,``fmax``];
+    - central wavelength ``wvl0`` and wavelength range ``wvl_width``;
+    - wavelength interval [``wvl_min``, ``wvl_max``].
+
+    Example
+    -------
+    >>> import tidy3d as td
+    >>> freq0 = 1e12
+    >>> fwidth = 1e11
+    >>> freq_range = td.FreqRange(freq0=freq0, fwidth=fwidth)
+    >>> central_freq = freq_range.freqs(num_points=1)
+    >>> freqs = freq_range.freqs(num_points=11)
+    >>> source = freq_range.to_gaussian_pulse()
+    """
+
+    freq0: PositiveFloat = Field(
+        title="Central frequency",
+        description="Real-valued positive central frequency.",
+        units="Hz",
+    )
+
+    fwidth: PositiveFloat = Field(
+        title="Frequency bandwidth",
+        description="Real-valued positive width of the frequency range (bandwidth).",
+        units="Hz",
+    )
+
+    @property
+    def fmin(self) -> float:
+        """Infer lowest frequency ``fmin`` from central frequency ``freq0`` and bandwidth ``fwidth``."""
+        return self.freq0 - self.fwidth
+
+    @property
+    def fmax(self) -> float:
+        """Infer highest frequency ``fmax`` from central frequency ``freq0`` and bandwidth ``fwidth``."""
+        return self.freq0 + self.fwidth
+
+    @property
+    def lda0(self) -> float:
+        """Get central wavelength from central frequency and bandwidth."""
+        lmin = td_const.C_0 / (self.freq0 + self.fwidth)
+        lmax = td_const.C_0 / (self.freq0 - self.fwidth)
+        return 0.5 * (lmin + lmax)
+
+    @classmethod
+    def from_freq_interval(cls, fmin: float, fmax: float) -> FreqRange:
+        """
+        method ``from_freq_interval()`` creates an instance of class ``FreqRange`` from the
+        frequency interval defined by arguments ``fmin`` and ``fmax``.
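To make the interval arithmetic concrete before the note that follows, a small numeric sketch (illustrative values only, not part of the diff; ``FreqRange`` and ``C_0`` as defined in this module and in ``tidy3d``):

```python
# Illustrative: the frequency midpoint and the wavelength midpoint differ.
import numpy as np
import tidy3d as td

fr = td.FreqRange.from_freq_interval(fmin=1.5e14, fmax=2.5e14)
print(fr.freq0, fr.fwidth)  # ~2.0e14 Hz mid-frequency, ~5.0e13 Hz half-width

# lda0 is the midpoint of the wavelength interval, not C_0 / freq0:
lda_mid = 0.5 * (td.C_0 / fr.fmax + td.C_0 / fr.fmin)
assert np.isclose(fr.lda0, lda_mid)
print(fr.lda0 - td.C_0 / fr.freq0)  # nonzero: the two midpoints differ
```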
+
+        NB: the central frequency never corresponds to the central wavelength:
+        ``freq0 = (fmin + fmax) / 2`` implies that ``lda0 != (lda_min + lda_max) / 2`` and vice versa.
+
+        Parameters
+        ----------
+        fmin : float
+            Lower bound of frequency of interest.
+        fmax : float
+            Upper bound of frequency of interest.
+
+        Returns
+        -------
+        FreqRange
+            An instance of ``FreqRange`` defined by the frequency interval [``fmin``, ``fmax``].
+        """
+
+        # extract frequency-related info
+        freq0 = 0.5 * (fmax + fmin)  # extract central freq
+        fwidth = 0.5 * (fmax - fmin)  # extract bandwidth
+        return cls(freq0=freq0, fwidth=fwidth)
+
+    @classmethod
+    def from_wavelength(cls, wvl0: float, wvl_width: float) -> FreqRange:
+        """
+        method ``from_wavelength()`` creates an instance of class ``FreqRange`` from a central
+        wavelength ``wvl0`` and a wavelength range ``wvl_width``.
+
+        NB: the central frequency never corresponds to the central wavelength:
+        ``lda0 = (lda_min + lda_max) / 2`` implies that ``freq0 != (fmin + fmax) / 2`` and vice versa.
+
+        Parameters
+        ----------
+        wvl0 : float
+            Real-valued central wavelength.
+        wvl_width : float
+            Real-valued wavelength range.
+
+        Returns
+        -------
+        FreqRange
+            An instance of ``FreqRange`` defined by central wavelength ``wvl0`` and wavelength range ``wvl_width``.
+        """
+
+        # calculate lowest and highest frequencies
+        fmin = td_const.C_0 / (wvl0 + wvl_width)
+        fmax = td_const.C_0 / (wvl0 - wvl_width)
+
+        return cls.from_freq_interval(fmin=fmin, fmax=fmax)
+
+    @classmethod
+    def from_wvl_interval(cls, wvl_min: float, wvl_max: float) -> FreqRange:
+        """
+        method ``from_wvl_interval()`` creates an instance of class ``FreqRange`` from the
+        wavelength interval [``wvl_min``, ``wvl_max``].
+
+        NB: the central frequency never corresponds to the central wavelength:
+        ``lda0 = (lda_min + lda_max) / 2`` implies that ``freq0 != (fmin + fmax) / 2``.
+
+        Parameters
+        ----------
+        wvl_min : float
+            The shortest wavelength of interest.
+        wvl_max : float
+            The longest wavelength of interest.
+
+        Returns
+        -------
+        FreqRange
+            An instance of ``FreqRange`` defined by the wavelength interval [``wvl_min``, ``wvl_max``].
+        """
+
+        # convert wavelength interval to frequency interval
+        fmax = td_const.C_0 / wvl_min
+        fmin = td_const.C_0 / wvl_max
+
+        return cls.from_freq_interval(fmin=fmin, fmax=fmax)
+
+    def freqs(self, num_points: int) -> NDArray[np.float64]:
+        """
+        method ``freqs()`` returns a numpy array of ``num_points`` frequencies uniformly
+        sampled from the specified frequency range;
+        if ``num_points == 1``, the method returns the central frequency ``freq0``.
+
+        Parameters
+        ----------
+        num_points : int
+            Number of frequency points in a frequency range of interest.
+
+        Returns
+        -------
+        np.ndarray
+            A numpy array of uniformly distributed frequency samples in the frequency range of interest.
+        """
+
+        if num_points == 1:  # return central frequency
+            return np.array([self.freq0])
+        else:
+            # calculate uniformly spaced frequency points
+            return np.linspace(self.fmin, self.fmax, num_points)
+
+    def ldas(self, num_points: int) -> NDArray[np.float64]:
+        """
+        method ``ldas()`` returns a numpy array of ``num_points`` wavelengths uniformly
+        sampled from the wavelength range;
+        if ``num_points == 1``, the method returns the central wavelength ``lda0``.
+
+        Parameters
+        ----------
+        num_points : int
+            Number of wavelength points in a range of wavelengths of interest.
+
+        Returns
+        -------
+        np.ndarray
+            A numpy array of uniformly distributed wavelength samples in ascending order.
+ """ + if num_points == 1: # return central wavelength + return np.array([self.lda0]) + else: + # define shortest and longest wavelengths + lmin = td_const.C_0 / self.fmax + lmax = td_const.C_0 / self.fmin + + # generate array of wavelengths (in ascending order) + return np.linspace(lmin, lmax, num_points) + + def to_gaussian_pulse(self, **kwargs: Any) -> GaussianPulse: + """ + method ``to_gaussian_pulse()`` returns instance of class ``GaussianPulse`` + with frequency-specific parameters defined in ``FreqRange``. + + Parameters + ---------- + kwargs : dict + Keyword arguments passed to ``GaussianPulse()``, excluding ``freq0`` & ``fwidth``. + + Returns + ------- + GaussianPulse + A ``GaussianPulse`` that maximizes its amplitude in the frequency range [``fmin``, ``fmax``]. + """ + + duplicate_keys = {"fmin", "fmax"} & kwargs.keys() + if duplicate_keys: + is_plural = len(duplicate_keys) > 1 + keys_str = ", ".join(f"'{key}'" for key in sorted(duplicate_keys, reverse=True)) + raise ValueError( + f"Keyword argument{'s' if is_plural else ''} {keys_str} " + f"conflict{'' if is_plural else 's'} with values already set in the 'FreqRange' object. " + f"Please exclude {'them' if is_plural else 'it'} from the 'to_gaussian_pulse()' call." + ) + + # create an instance of GaussianPulse class with defined frequency params + return GaussianPulse.from_frequency_range(fmin=self.fmin, fmax=self.fmax, **kwargs) diff --git a/tidy3d/components/source/time.py b/tidy3d/components/source/time.py index 503a4490ce..628cd35b8b 100644 --- a/tidy3d/components/source/time.py +++ b/tidy3d/components/source/time.py @@ -1,685 +1,21 @@ -"""Defines time dependencies of injected electromagnetic sources.""" +"""Compatibility shim for :mod:`tidy3d._common.components.source.time`.""" -from __future__ import annotations - -import logging -from abc import ABC, abstractmethod -from typing import Any, Optional, Union - -import numpy as np -import pydantic.v1 as pydantic -from pyroots import Brentq - -from tidy3d.components.base import cached_property -from tidy3d.components.data.data_array import TimeDataArray -from tidy3d.components.data.dataset import TimeDataset -from tidy3d.components.data.validators import validate_no_nans -from tidy3d.components.time import AbstractTimeDependence -from tidy3d.components.types import ArrayComplex1D, ArrayFloat1D, Ax, FreqBound, PlotVal -from tidy3d.components.validators import warn_if_dataset_none -from tidy3d.components.viz import add_ax_if_none -from tidy3d.constants import HERTZ -from tidy3d.exceptions import ValidationError -from tidy3d.log import log -from tidy3d.packaging import check_tidy3d_extras_licensed_feature, tidy3d_extras - -# how many units of ``twidth`` from the ``offset`` until a gaussian pulse is considered "off" -END_TIME_FACTOR_GAUSSIAN = 10 - -# warn if source amplitude is too small at the endpoints of frequency range -WARN_SOURCE_AMPLITUDE = 0.1 -# used in Brentq -_ROOTS_TOL = 1e-10 -# Default sigma value in frequency_range -DEFAULT_SIGMA = 4.0 -# Offset in fwidth in finding frequency_range_sigma[1] to ensure the interval brackets the root -OFFSET_FWIDTH_FMAX = 100 - - -class SourceTime(AbstractTimeDependence): - """Base class describing the time dependence of a source.""" - - @add_ax_if_none - def plot_spectrum( - self, - times: ArrayFloat1D, - num_freqs: int = 101, - val: PlotVal = "real", - ax: Ax = None, - ) -> Ax: - """Plot the complex-valued amplitude of the source time-dependence. - Note: Only the real part of the time signal is used. 
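Since the implementation below is deleted in favor of the ``tidy3d._common`` re-export above, the documented API is unchanged; a hedged sketch of the spectrum plot described by the removed docstring (the time grid is illustrative, pass ``simulation.tmesh`` for a real simulation):

```python
# Sketch of the spectrum-plotting API documented in the removed block below.
import numpy as np
import tidy3d as td

pulse = td.GaussianPulse(freq0=200e12, fwidth=20e12)
times = np.linspace(0, 2e-13, 2001)  # evenly spaced times in seconds
ax = pulse.plot_spectrum(times=times, num_freqs=101, val="real")
```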
- - Parameters - ---------- - times : np.ndarray - Array of evenly-spaced times (seconds) to evaluate source time-dependence at. - The spectrum is computed from this value and the source time frequency content. - To see source spectrum for a specific :class:`.Simulation`, - pass ``simulation.tmesh``. - num_freqs : int = 101 - Number of frequencies to plot within the SourceTime.frequency_range. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - - fmin, fmax = self.frequency_range_sigma() - return self.plot_spectrum_in_frequency_range( - times, fmin, fmax, num_freqs=num_freqs, val=val, ax=ax - ) - - @abstractmethod - def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range within plus/minus ``num_fwidth * fwidth`` of the central frequency.""" - - def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" - return self.frequency_range(num_fwidth=sigma) - - @cached_property - def _frequency_range_sigma_cached(self) -> FreqBound: - """Cached `frequency_range_sigma` for the default sigma value.""" - return self.frequency_range_sigma(sigma=DEFAULT_SIGMA) - - @abstractmethod - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - - @cached_property - def _freq0(self) -> float: - """Central frequency. If not present in input parameters, returns `_freq0_sigma_centroid`.""" - return self._freq0_sigma_centroid - - @cached_property - def _freq0_sigma_centroid(self) -> float: - """Central of frequency range at 1-sigma drop from the peak amplitude.""" - return np.mean(self.frequency_range_sigma(sigma=1)) - - -class Pulse(SourceTime, ABC): - """A source time that ramps up with some ``fwidth`` and oscillates at ``freq0``.""" - - freq0: pydantic.PositiveFloat = pydantic.Field( - ..., title="Central Frequency", description="Central frequency of the pulse.", units=HERTZ - ) - fwidth: pydantic.PositiveFloat = pydantic.Field( - ..., - title="", - description="Standard deviation of the frequency content of the pulse.", - units=HERTZ, - ) - - offset: float = pydantic.Field( - 5.0, - title="Offset", - description="Time delay of the maximum value of the " - "pulse in units of 1 / (``2pi * fwidth``).", - ge=2.5, - ) - - @cached_property - def _freq0(self) -> float: - """Central frequency.""" - return self.freq0 - - @property - def offset_time(self) -> float: - """Offset time in seconds.""" - return self.offset * self.twidth - - @property - def twidth(self) -> float: - """Width of pulse in seconds.""" - return 1.0 / (2 * np.pi * self.fwidth) - - def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range within 5 standard deviations of the central frequency. - - Parameters - ---------- - num_fwidth : float = 4. - Frequency range defined as plus/minus ``num_fwidth * self.fwdith``. - - Returns - ------- - Tuple[float, float] - Minimum and maximum frequencies of the :class:`GaussianPulse` or :class:`ContinuousWave` - power. - """ - - freq_width_range = num_fwidth * self.fwidth - freq_min = max(0, self.freq0 - freq_width_range) - freq_max = self.freq0 + freq_width_range - return (freq_min, freq_max) - - -class GaussianPulse(Pulse): - """Source time dependence that describes a Gaussian pulse. 
- - Example - ------- - >>> pulse = GaussianPulse(freq0=200e12, fwidth=20e12) - """ - - remove_dc_component: bool = pydantic.Field( - True, - title="Remove DC Component", - description="Whether to remove the DC component in the Gaussian pulse spectrum. " - "If ``True``, the Gaussian pulse is modified at low frequencies to zero out the " - "DC component, which is usually desirable so that the fields will decay. However, " - "for broadband simulations, it may be better to have non-vanishing source power " - "near zero frequency. Setting this to ``False`` results in an unmodified Gaussian " - "pulse spectrum which can have a nonzero DC component.", - ) - - @property - def peak_time(self) -> float: - """Peak time in seconds, defined by ``offset``.""" - return self.offset * self.twidth - - @property - def _peak_time_shift(self) -> float: - """In the case of DC removal, correction to offset_time so that ``offset`` indeed defines time delay - of pulse peak. - """ - if self.remove_dc_component and self.fwidth > self.freq0: - return self.twidth * np.sqrt(1 - self.freq0**2 / self.fwidth**2) - return 0 - - @property - def offset_time(self) -> float: - """Offset time in seconds. Note that in the case of DC removal, the maximal value of pulse can be shifted.""" - return self.peak_time + self._peak_time_shift - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time.""" - - omega0 = 2 * np.pi * self.freq0 - time_shifted = time - self.offset_time - - offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * time) - amp = np.exp(-(time_shifted**2) / 2 / self.twidth**2) * self.amplitude - - pulse_amp = offset * oscillation * amp - - # subtract out DC component - if self.remove_dc_component: - pulse_amp = pulse_amp * (1j * omega0 + time_shifted / self.twidth**2) - # normalize by peak frequency instead of omega0, as for small omega0, omega0 approaches 0 faster - pulse_amp /= 2 * np.pi * self.peak_frequency - else: - # 1j to make it agree in large omega0 limit - pulse_amp = pulse_amp * 1j - - return pulse_amp - - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - - # TODO: decide if we should continue to return an end_time if the DC component remains - # if not self.remove_dc_component: - # return None - - end_time = self.offset_time + END_TIME_FACTOR_GAUSSIAN * self.twidth - - # for derivative Gaussian that contains two peaks, add time interval between them - if self.remove_dc_component and self.fwidth > self.freq0: - end_time += 2 * self._peak_time_shift - return end_time - - def amp_freq(self, freq: float) -> complex: - """Complex-valued source spectrum in frequency domain.""" - phase = np.exp(1j * self.phase + 1j * 2 * np.pi * (freq - self.freq0) * self.offset_time) - envelope = np.exp(-((freq - self.freq0) ** 2) / 2 / self.fwidth**2) - amp = 1j * self.amplitude / self.fwidth * phase * envelope - if not self.remove_dc_component: - return amp - - # derivative of Gaussian when DC is removed - return freq * amp / (2 * np.pi * self.peak_frequency) - - def _rel_amp_freq(self, freq: float) -> complex: - """Complex-valued source spectrum in frequency domain normalized by peak amplitude.""" - return self.amp_freq(freq) / self._peak_freq_amp - - @property - def peak_frequency(self) -> float: - """Frequency at which the source time dependence has its peak amplitude in the frequency domain.""" - if not self.remove_dc_component: - return self.freq0 - return 0.5 * (self.freq0 + 
np.sqrt(self.freq0**2 + 4 * self.fwidth**2)) - - @property - def _peak_freq_amp(self) -> complex: - """Peak amplitude in frequency domain""" - return self.amp_freq(self.peak_frequency) - - @property - def _peak_time_amp(self) -> complex: - """Peak amplitude in time domain""" - return self.amp_time(self.peak_time) - - def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" - if not self.remove_dc_component: - return self.frequency_range(num_fwidth=sigma) - - # With dc removed, we'll need to solve for the transcendental equation to find the frequency range - def equation_for_sigma_frequency(freq): - """computes A / A_p - exp(-sigma)""" - return np.abs(self._rel_amp_freq(freq)) - np.exp(-(sigma**2) / 2) - - logger = logging.getLogger("pyroots") - logger.setLevel(logging.CRITICAL) - root_scalar = Brentq(raise_on_fail=False, epsilon=_ROOTS_TOL) - fmin_data = root_scalar(equation_for_sigma_frequency, xa=0, xb=self.peak_frequency) - fmax_data = root_scalar( - equation_for_sigma_frequency, - xa=self.peak_frequency, - xb=self.peak_frequency - + self.fwidth - * ( - OFFSET_FWIDTH_FMAX + 2 * sigma**2 - ), # offset slightly to make sure that it flips sign - ) - fmin, fmax = fmin_data.x0, fmax_data.x0 - - # if unconverged, fall back to `frequency_range` - if not (fmin_data.converged and fmax_data.converged and fmax > fmin): - return self.frequency_range(num_fwidth=sigma) - - # converged - return fmin.item(), fmax.item() - - @property - def amp_complex(self) -> complex: - """Grab the complex amplitude from a ``GaussianPulse``.""" - phase = np.exp(1j * self.phase) - return self.amplitude * phase - - @classmethod - def from_amp_complex(cls, amp: complex, **kwargs: Any) -> GaussianPulse: - """Set the complex amplitude of a ``GaussianPulse``. - - Parameters - ---------- - amp : complex - Complex-valued amplitude to set in the returned ``GaussianPulse``. - kwargs : dict - Keyword arguments passed to ``GaussianPulse()``, excluding ``amplitude`` & ``phase``. - """ - amplitude = abs(amp) - phase = np.angle(amp) - return cls(amplitude=amplitude, phase=phase, **kwargs) - - @staticmethod - def _minimum_source_bandwidth( - fmin: float, fmax: float, minimum_source_bandwidth: float - ) -> tuple[float, float]: - """Define a source bandwidth based on fmin and fmax, but enforce a minimum bandwidth.""" - if minimum_source_bandwidth <= 0: - raise ValidationError("'minimum_source_bandwidth' must be positive") - if minimum_source_bandwidth >= 1: - raise ValidationError("'minimum_source_bandwidth' must less than or equal to 1") - - f_difference = fmax - fmin - f_middle = 0.5 * (fmin + fmax) - - full_width = minimum_source_bandwidth * f_middle - if f_difference < full_width: - half_width = 0.5 * full_width - fmin = f_middle - half_width - fmax = f_middle + half_width - - return fmin, fmax - - @classmethod - def from_frequency_range( - cls, - fmin: pydantic.PositiveFloat, - fmax: pydantic.PositiveFloat, - minimum_source_bandwidth: pydantic.PositiveFloat = None, - **kwargs: Any, - ) -> GaussianPulse: - """Create a ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. - - Parameters - ---------- - fmin : float - Lower bound of frequency of interest. - fmax : float - Upper bound of frequency of interest. - kwargs : dict - Keyword arguments passed to ``GaussianPulse()``, excluding ``freq0`` & ``fwidth``. 
- - Returns - ------- - GaussianPulse - A ``GaussianPulse`` that maximizes its amplitude in the frequency range [fmin, fmax]. - """ - # validate that fmin and fmax must positive, and fmax > fmin - if fmin <= 0: - raise ValidationError("'fmin' must be positive.") - if fmax <= fmin: - raise ValidationError("'fmax' must be greater than 'fmin'.") - - if minimum_source_bandwidth is not None: - fmin, fmax = cls._minimum_source_bandwidth(fmin, fmax, minimum_source_bandwidth) - - # frequency range and center - freq_range = fmax - fmin - freq_center = (fmax + fmin) / 2.0 - - # If remove_dc_component=False, simply return the standard GaussianPulse parameters - if kwargs.get("remove_dc_component", True) is False: - return cls(freq0=freq_center, fwidth=freq_range / 2.0, **kwargs) - - # If remove_dc_component=True, the Gaussian pulse is distorted - kwargs.update({"remove_dc_component": True}) - log_ratio = np.log(fmax / fmin) - coeff = ((1 + log_ratio**2) ** 0.5 - 1) / 2.0 - freq0 = freq_center - coeff / log_ratio * freq_range - fwidth = freq_range / log_ratio * coeff**0.5 - pulse = cls(freq0=freq0, fwidth=fwidth, **kwargs) - if np.abs(pulse._rel_amp_freq(fmin)) < WARN_SOURCE_AMPLITUDE: - log.warning( - "Source amplitude is not sufficiently large throughout the specified frequency range, " - "which can result in inaccurate simulation results. Please decrease the frequency range.", - ) - return pulse - - -class ContinuousWave(Pulse): - """Source time dependence that ramps up to continuous oscillation - and holds until end of simulation. - - Note - ---- - Field decay will not occur, so the simulation will run for the full ``run_time``. - Also, source normalization of frequency-domain monitors is not meaningful. - - Example - ------- - >>> cw = ContinuousWave(freq0=200e12, fwidth=20e12) - """ - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time.""" - - twidth = 1.0 / (2 * np.pi * self.fwidth) - omega0 = 2 * np.pi * self.freq0 - time_shifted = time - self.offset_time - - const = 1.0 - offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * time) - amp = 1 / (1 + np.exp(-time_shifted / twidth)) * self.amplitude - - return const * offset * oscillation * amp - - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - return None - - -class CustomSourceTime(Pulse): - """Custom source time dependence consisting of a real or complex envelope - modulated at a central frequency, as shown below. - - Note - ---- - .. math:: - - amp\\_time(t) = amplitude \\cdot \\ - e^{i \\cdot phase - 2 \\pi i \\cdot freq0 \\cdot t} \\cdot \\ - envelope(t - offset / (2 \\pi \\cdot fwidth)) - - Note - ---- - Depending on the envelope, field decay may not occur. - If field decay does not occur, then the simulation will run for the full ``run_time``. - Also, if field decay does not occur, then source normalization of frequency-domain - monitors is not meaningful. - - Note - ---- - The source time dependence is linearly interpolated to the simulation time steps. - The sampling rate should be sufficiently fast that this interpolation does not - introduce artifacts. The source time dependence should also start at zero and ramp up smoothly. - The first and last values of the envelope will be used for times that are out of range - of the provided data. - - Example - ------- - >>> cst = CustomSourceTime.from_values(freq0=1, fwidth=0.1, - ... 
values=np.linspace(0, 9, 10), dt=0.1) - - """ - - offset: float = pydantic.Field( - 0.0, - title="Offset", - description="Time delay of the envelope in units of 1 / (``2pi * fwidth``).", - ) - - source_time_dataset: Optional[TimeDataset] = pydantic.Field( - ..., - title="Source time dataset", - description="Dataset for storing the envelope of the custom source time. " - "This envelope will be modulated by a complex exponential at frequency ``freq0``.", - ) - - _no_nans_dataset = validate_no_nans("source_time_dataset") - _source_time_dataset_none_warning = warn_if_dataset_none("source_time_dataset") - - @pydantic.validator("source_time_dataset", always=True) - def _more_than_one_time(cls, val): - """Must have more than one time to interpolate.""" - if val is None: - return val - if val.values.size <= 1: - raise ValidationError("'CustomSourceTime' must have more than one time coordinate.") - return val - - @classmethod - def from_values( - cls, freq0: float, fwidth: float, values: ArrayComplex1D, dt: float - ) -> CustomSourceTime: - """Create a :class:`.CustomSourceTime` from a numpy array. - - Parameters - ---------- - freq0 : float - Central frequency of the source. The envelope provided will be modulated - by a complex exponential at this frequency. - fwidth : float - Estimated frequency width of the source. - values: ArrayComplex1D - Complex values of the source envelope. - dt: float - Time step for the ``values`` array. This value should be sufficiently small - that the interpolation to simulation time steps does not introduce artifacts. - - Returns - ------- - CustomSourceTime - :class:`.CustomSourceTime` with envelope given by ``values``, modulated by a complex - exponential at frequency ``freq0``. The time coordinates are evenly spaced - between ``0`` and ``dt * (N-1)`` with a step size of ``dt``, where ``N`` is the length of - the values array. - """ - - times = np.arange(len(values)) * dt - source_time_dataarray = TimeDataArray(values, coords={"t": times}) - source_time_dataset = TimeDataset(values=source_time_dataarray) - return CustomSourceTime( - freq0=freq0, - fwidth=fwidth, - source_time_dataset=source_time_dataset, - ) - - @property - def data_times(self) -> ArrayFloat1D: - """Times of envelope definition.""" - if self.source_time_dataset is None: - return [] - data_times = self.source_time_dataset.values.coords["t"].values.squeeze() - return data_times - - def _all_outside_range(self, run_time: float) -> bool: - """Whether all times are outside range of definition.""" - - # can't validate if data isn't loaded - if self.source_time_dataset is None: - return False - - # make time a numpy array for uniform handling - data_times = self.data_times - - # shift time - max_time_shifted = run_time - self.offset_time - min_time_shifted = -self.offset_time - - return (max_time_shifted < min(data_times)) | (min_time_shifted > max(data_times)) - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time. - - Parameters - ---------- - time : float - Time in seconds. - - Returns - ------- - complex - Complex-valued source amplitude at that time. 
- """ - - if self.source_time_dataset is None: - return None - - # make time a numpy array for uniform handling - times = np.array([time] if isinstance(time, (int, float)) else time) - data_times = self.data_times - - # shift time - twidth = 1.0 / (2 * np.pi * self.fwidth) - time_shifted = times - self.offset * twidth - - # mask times that are out of range - mask = (time_shifted < min(data_times)) | (time_shifted > max(data_times)) - - # get envelope - envelope = np.zeros(len(time_shifted), dtype=complex) - values = self.source_time_dataset.values - envelope[mask] = values.sel(t=time_shifted[mask], method="nearest").to_numpy() - if not all(mask): - envelope[~mask] = values.interp(t=time_shifted[~mask]).to_numpy() - - # modulation, phase, amplitude - omega0 = 2 * np.pi * self.freq0 - offset = np.exp(1j * self.phase) - oscillation = np.exp(-1j * omega0 * times) - amp = self.amplitude - - return offset * oscillation * amp * envelope - - def end_time(self) -> Optional[float]: - """Time after which the source is effectively turned off / close to zero amplitude.""" - - if self.source_time_dataset is None: - return None - - data_array = self.source_time_dataset.values - - t_coords = data_array.coords["t"] - source_is_non_zero = ~np.isclose(abs(data_array), 0) - t_non_zero = t_coords[source_is_non_zero] - - return np.max(t_non_zero) - - -class BroadbandPulse(SourceTime): - """A source time injecting significant energy in the entire custom frequency range.""" - - freq_range: FreqBound = pydantic.Field( - ..., - title="Frequency Range", - description="Frequency range where the pulse should have significant energy.", - units=HERTZ, - ) - minimum_amplitude: float = pydantic.Field( - 0.3, - title="Minimum Amplitude", - description="Minimum amplitude of the pulse relative to the peak amplitude in the frequency range.", - gt=0.05, - lt=0.5, - ) - offset: float = pydantic.Field( - 0.0, - title="Offset", - description="An automatic time delay of the peak value of the pulse has been applied under the hood " - "to ensure smooth ramping up of the pulse at time = 0. This offfset is added on top of the automatic time delay " - "in units of 1 / [``2pi * (freq_range[1] - freq_range[0])``].", - ) - - @pydantic.validator("freq_range", always=True) - def _validate_freq_range(cls, val): - """Validate that freq_range is positive and properly ordered.""" - if val[0] <= 0 or val[1] <= 0: - raise ValidationError("Both elements of 'freq_range' must be positive.") - if val[1] <= val[0]: - raise ValidationError( - f"'freq_range[1]' ({val[1]}) must be greater than 'freq_range[0]' ({val[0]})." 
- ) - return val - - @pydantic.root_validator() - def _check_broadband_pulse_available(cls, values): - """Check if BroadbandPulse is available.""" - check_tidy3d_extras_licensed_feature("BroadbandPulse") - return values - - @cached_property - def _source(self): - """Implementation of broadband pulse.""" - return tidy3d_extras["mod"].extension.BroadbandPulse( - fmin=self.freq_range[0], - fmax=self.freq_range[1], - minRelAmp=self.minimum_amplitude, - amp=self.amplitude, - phase=self.phase, - offset=self.offset, - ) - - def end_time(self) -> float: - """Time after which the source is effectively turned off / close to zero amplitude.""" - return self._source.end_time(END_TIME_FACTOR_GAUSSIAN) - - def amp_time(self, time: float) -> complex: - """Complex-valued source amplitude as a function of time.""" - return self._source.amp_time(time) - - def amp_freq(self, freq: float) -> complex: - """Complex-valued source amplitude as a function of frequency.""" - return self._source.amp_freq(freq) - - def frequency_range_sigma(self, sigma: float = DEFAULT_SIGMA) -> FreqBound: - """Frequency range where the source amplitude is within ``exp(-sigma**2/2)`` of the peak amplitude.""" - return self._source.frequency_range(sigma) - - def frequency_range(self, num_fwidth: float = DEFAULT_SIGMA) -> FreqBound: - """Delegated to `frequency_range_sigma(sigma=num_fwidth)` for computing the frequency range where the source amplitude - is within ``exp(-num_fwidth**2/2)`` of the peak amplitude. - """ - return self.frequency_range_sigma(num_fwidth) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -SourceTimeType = Union[GaussianPulse, ContinuousWave, CustomSourceTime, BroadbandPulse] +from tidy3d._common.components.source.time import ( + _ROOTS_TOL, + DEFAULT_SIGMA, + END_TIME_FACTOR_GAUSSIAN, + OFFSET_FWIDTH_FMAX, + WARN_SOURCE_AMPLITUDE, + BroadbandPulse, + ContinuousWave, + CustomSourceTime, + GaussianPulse, + Pulse, + SourceTime, + SourceTimeType, +) diff --git a/tidy3d/components/spice/analysis/ac.py b/tidy3d/components/spice/analysis/ac.py index 2dd510284c..83087800ad 100644 --- a/tidy3d/components/spice/analysis/ac.py +++ b/tidy3d/components/spice/analysis/ac.py @@ -2,7 +2,7 @@ from abc import ABC -import pydantic.v1 as pd +from pydantic import Field, field_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.spice.analysis.dc import ( @@ -35,16 +35,16 @@ class AbstractSSACAnalysis(Tidy3dBaseModel, ABC): >>> ssac_spec = td.SSACAnalysis(freqs=sweep_freqs) """ - freqs: ArrayFloat1D = pd.Field( - ..., + freqs: ArrayFloat1D = Field( title="Small Signal AC Frequencies", description="List of frequencies for small signal AC analysis. 
" "At least one :class:`.SSACVoltageSource` must be present in the boundary conditions.", units=HERTZ, ) - @pd.validator("freqs") - def validate_freqs(cls, val): + @field_validator("freqs") + @classmethod + def validate_freqs(cls, val: ArrayFloat1D) -> ArrayFloat1D: if len(val) == 0: raise ValueError("'freqs' cannot be empty (size 0).") else: diff --git a/tidy3d/components/spice/analysis/dc.py b/tidy3d/components/spice/analysis/dc.py index ec3afc44c9..30f26b2f8e 100644 --- a/tidy3d/components/spice/analysis/dc.py +++ b/tidy3d/components/spice/analysis/dc.py @@ -4,7 +4,7 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, PositiveInt from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import KELVIN @@ -20,19 +20,19 @@ class ChargeToleranceSpec(Tidy3dBaseModel): >>> charge_settings = td.ChargeToleranceSpec(abs_tol=1e8, rel_tol=1e-10, max_iters=30) """ - abs_tol: pd.PositiveFloat = pd.Field( + abs_tol: PositiveFloat = Field( default=1e10, title="Absolute tolerance.", description="Absolute tolerance used as stop criteria when converging towards a solution.", ) - rel_tol: pd.PositiveFloat = pd.Field( + rel_tol: PositiveFloat = Field( default=1e-10, title="Relative tolerance.", description="Relative tolerance used as stop criteria when converging towards a solution.", ) - max_iters: pd.PositiveInt = pd.Field( + max_iters: PositiveInt = Field( default=30, title="Maximum number of iterations.", description="Indicates the maximum number of iterations to be run. " @@ -40,7 +40,7 @@ class ChargeToleranceSpec(Tidy3dBaseModel): "or when the tolerance criteria has been met.", ) - ramp_up_iters: pd.PositiveInt = pd.Field( + ramp_up_iters: PositiveInt = Field( default=1, title="Ramp-up iterations.", description="In order to help in start up, quantities such as doping " @@ -54,13 +54,13 @@ class SteadyChargeDCAnalysis(Tidy3dBaseModel): Configures relevant steady-state DC simulation parameters for a charge simulation. """ - tolerance_settings: ChargeToleranceSpec = pd.Field( + tolerance_settings: ChargeToleranceSpec = Field( default=ChargeToleranceSpec(), title="Tolerance settings", description="Charge tolerance parameters relevant to multiple simulation analysis types.", ) - convergence_dv: pd.PositiveFloat = pd.Field( + convergence_dv: PositiveFloat = Field( default=1.0, title="Bias step.", description="By default, a solution is computed at 0 bias. If a bias different than " @@ -69,7 +69,7 @@ class SteadyChargeDCAnalysis(Tidy3dBaseModel): "convergence parameter in DC computations.", ) - fermi_dirac: bool = pd.Field( + fermi_dirac: bool = Field( False, title="Fermi-Dirac statistics", description="Determines whether Fermi-Dirac statistics are used. When ``False``, " @@ -84,7 +84,7 @@ class IsothermalSteadyChargeDCAnalysis(SteadyChargeDCAnalysis): Configures relevant Isothermal steady-state DC simulation parameters for a charge simulation. """ - temperature: pd.PositiveFloat = pd.Field( + temperature: PositiveFloat = Field( 300, title="Temperature", description="Lattice temperature. Assumed constant throughout the device. 
" diff --git a/tidy3d/components/spice/sources/ac.py b/tidy3d/components/spice/sources/ac.py index 17af5fd9ff..a82cf95d31 100644 --- a/tidy3d/components/spice/sources/ac.py +++ b/tidy3d/components/spice/sources/ac.py @@ -2,7 +2,7 @@ from typing import Optional -import pydantic.v1 as pd +from pydantic import Field, FiniteFloat, field_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.types import ArrayFloat1D @@ -33,36 +33,37 @@ class SSACVoltageSource(Tidy3dBaseModel): ... ) """ - name: Optional[str] = pd.Field( + name: Optional[str] = Field( None, title="Name", description="Unique name for the SSAC voltage source.", min_length=1, ) - voltage: ArrayFloat1D = pd.Field( - ..., + voltage: ArrayFloat1D = Field( title="DC Bias Voltages", description="List of DC operating point voltages (above ground) used with :class:`VoltageBC`.", units=VOLT, ) - amplitude: pd.FiniteFloat = pd.Field( + amplitude: FiniteFloat = Field( default=1.0, title="Small Signal Amplitude", description="Amplitude of the small-signal perturbation for SSAC analysis.", units=VOLT, ) - @pd.validator("voltage") - def validate_voltage(cls, val): + @field_validator("voltage") + @classmethod + def validate_voltage(cls, val: ArrayFloat1D) -> ArrayFloat1D: for v in val: if v == td_inf: raise ValueError(f"Voltages must be finite. Current voltage={val}.") return val - @pd.validator("amplitude") - def validate_amplitude(cls, val): + @field_validator("amplitude") + @classmethod + def validate_amplitude(cls, val: FiniteFloat) -> FiniteFloat: if val == td_inf: raise ValueError(f"Signal amplitude must be finite. Current amplitude={val}.") return val diff --git a/tidy3d/components/spice/sources/dc.py b/tidy3d/components/spice/sources/dc.py index dcc1eb0e44..5906f31b78 100644 --- a/tidy3d/components/spice/sources/dc.py +++ b/tidy3d/components/spice/sources/dc.py @@ -21,14 +21,18 @@ from __future__ import annotations -from typing import Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional -import pydantic.v1 as pd +import numpy as np +from pydantic import Field, FiniteFloat, field_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.types import ArrayFloat1D -from tidy3d.constants import AMP, VOLT -from tidy3d.constants import inf as td_inf +from tidy3d.constants import AMP, VOLT, inf +from tidy3d.log import log + +if TYPE_CHECKING: + from numpy.typing import NDArray class DCVoltageSource(Tidy3dBaseModel): @@ -48,15 +52,14 @@ class DCVoltageSource(Tidy3dBaseModel): >>> voltage_source = td.DCVoltageSource(voltage=voltages) """ - name: Optional[str] = pd.Field( + name: Optional[str] = Field( None, title="Name", description="Unique name for the DC voltage source", min_length=1, ) - voltage: ArrayFloat1D = pd.Field( - ..., + voltage: ArrayFloat1D = Field( title="Voltage", description="DC voltage usually used as source in :class:`VoltageBC` boundary conditions.", units=VOLT, @@ -66,13 +69,50 @@ class DCVoltageSource(Tidy3dBaseModel): # standalone field. Keeping for compatibility, remove in 3.0. units: Literal[VOLT] = VOLT - @pd.validator("voltage") - def check_voltage(cls, val): + @field_validator("voltage") + @classmethod + def check_voltage(cls, val: ArrayFloat1D) -> ArrayFloat1D: for v in val: - if v == td_inf: + if v == inf: raise ValueError(f"Voltages must be finite. 
Currently voltage={val}.") return val + @staticmethod + def _count_unique_with_tolerance(arr: NDArray, rtol: float = 1e-9, atol: float = 1e-12) -> int: + """Count unique values treating values within tolerance as duplicates. + + Uses sorted comparison to group values that are practically equal + due to floating-point representation differences (e.g., single vs double precision). + """ + if len(arr) == 0: + return 0 + sorted_arr = np.sort(arr) + # Count values that are "different enough" from their predecessor + unique_count = 1 + for i in range(1, len(sorted_arr)): + if not np.isclose(sorted_arr[i], sorted_arr[i - 1], rtol=rtol, atol=atol): + unique_count += 1 + return unique_count + + @field_validator("voltage") + @classmethod + def check_repeated_voltage(cls, val: ArrayFloat1D) -> ArrayFloat1D: + """Warn if repeated voltage values are present, treating 0 and -0 as the same value. + + Uses tolerance-based comparison to handle floating-point representation + differences (e.g., values from single vs double precision sources). + """ + # Normalize all zero values (both 0.0 and -0.0) to 0.0 so they are treated as duplicates + normalized = np.where(np.isclose(val, 0, atol=1e-10), 0.0, val) + unique_count = cls._count_unique_with_tolerance(normalized) + if unique_count < len(val): + log.warning( + "Duplicate voltage values detected in 'voltage' array. " + f"Found {len(val)} values but only {unique_count} are unique. " + "Note: values within floating-point tolerance are considered duplicates." + ) + return val + class GroundVoltage(Tidy3dBaseModel): """ @@ -108,14 +148,14 @@ class DCCurrentSource(Tidy3dBaseModel): >>> current_source = td.DCCurrentSource(current=0.4) """ - name: Optional[str] = pd.Field( + name: Optional[str] = Field( None, title="Name", description="Unique name for the DC current source", min_length=1, ) - current: pd.FiniteFloat = pd.Field( + current: FiniteFloat = Field( title="Current", description="DC current usually used as source in :class:`CurrentBC` boundary conditions.", units=AMP, diff --git a/tidy3d/components/structure.py b/tidy3d/components/structure.py index 2d99c55559..d100e80f33 100644 --- a/tidy3d/components/structure.py +++ b/tidy3d/components/structure.py @@ -5,23 +5,20 @@ import pathlib from collections import defaultdict from functools import cmp_to_key -from os import PathLike -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import autograd.numpy as anp import numpy as np -import pydantic.v1 as pydantic +from autograd.extend import Box as AutogradBox +from pydantic import Field, PositiveFloat, field_validator, model_validator from tidy3d.config import config from tidy3d.constants import MICROMETER from tidy3d.exceptions import SetupError, Tidy3dImportError from tidy3d.log import log -from .autograd.derivative_utils import DerivativeInfo -from .autograd.types import AutogradFieldMap -from .autograd.types import Box as AutogradBox from .autograd.utils import contains, get_static -from .base import Tidy3dBaseModel, skip_if_fields_missing +from .base import Tidy3dBaseModel from .data.data_array import ScalarFieldDataArray from .geometry.base import Box, Geometry from .geometry.utils import GeometryType, validate_no_transformed_polyslabs @@ -30,10 +27,23 @@ from .material.types import StructureMediumType from .medium import AbstractCustomMedium, CustomMedium, LossyMetalMedium, Medium, Medium2D from .monitor import FieldMonitor, PermittivityMonitor -from .types import TYPE_TAG_STR, Ax, Axis, PriorityMode +from .types 
import TYPE_TAG_STR from .validators import validate_name_str from .viz import add_ax_if_none, equal_aspect +if TYPE_CHECKING: + from os import PathLike + + import gdstk + from pydantic import NonNegativeFloat, NonNegativeInt + + from tidy3d import VisualizationSpec + from tidy3d.compat import Self + + from .autograd.derivative_utils import DerivativeInfo + from .autograd.types import AutogradFieldMap + from .types import Ax, Axis, PriorityMode + try: gdstk_available = True import gdstk @@ -46,16 +56,15 @@ class AbstractStructure(Tidy3dBaseModel): A basic structure object. """ - geometry: GeometryType = pydantic.Field( - ..., + geometry: GeometryType = Field( title="Geometry", description="Defines geometric properties of the structure.", discriminator=TYPE_TAG_STR, ) - name: str = pydantic.Field(None, title="Name", description="Optional name for the structure.") + name: Optional[str] = Field(None, title="Name", description="Optional name for the structure.") - background_permittivity: float = pydantic.Field( + background_permittivity: Optional[float] = Field( None, ge=1.0, title="Background Permittivity", @@ -64,7 +73,7 @@ class AbstractStructure(Tidy3dBaseModel): "when performing shape optimization with autograd.", ) - background_medium: StructureMediumType = pydantic.Field( + background_medium: Optional[StructureMediumType] = Field( None, title="Background Medium", description="Medium used for the background of this structure " @@ -73,7 +82,7 @@ class AbstractStructure(Tidy3dBaseModel): "``Simulation`` by default to compute the shape derivatives.", ) - priority: int = pydantic.Field( + priority: Optional[int] = Field( None, title="Priority", description="Priority of the structure applied in structure overlapping region. " @@ -83,12 +92,12 @@ class AbstractStructure(Tidy3dBaseModel): "the value is automatically assigned based on `structure_priority_mode` in the `Simulation`.", ) - @pydantic.root_validator(skip_on_failure=True) - def _handle_background_mediums(cls, values): + @model_validator(mode="after") + def _handle_background_mediums(self) -> Self: """Handle background medium combinations, including deprecation.""" - background_permittivity = values.get("background_permittivity") - background_medium = values.get("background_medium") + background_permittivity = self.background_permittivity + background_medium = self.background_medium # old case, only permittivity supplied, warn and set the Medium automatically if background_medium is None and background_permittivity is not None: @@ -97,7 +106,9 @@ def _handle_background_mediums(cls, values): "set the 'Structure.background_medium' directly using a 'Medium'. " "Handling automatically using the supplied relative permittivity." ) - values["background_medium"] = Medium(permittivity=background_permittivity) + object.__setattr__( + self, "background_medium", Medium(permittivity=background_permittivity) + ) # both present, just make sure they are consistent, error if not if background_medium is not None and background_permittivity is not None: @@ -108,12 +119,13 @@ def _handle_background_mediums(cls, values): "Use 'background_medium' only as 'background_permittivity' is deprecated." 
) - return values + return self _name_validator = validate_name_str() - @pydantic.validator("geometry") - def _transformed_slanted_polyslabs_not_allowed(cls, val): + @field_validator("geometry") + @classmethod + def _transformed_slanted_polyslabs_not_allowed(cls, val: GeometryType) -> GeometryType: """Prevents the creation of slanted polyslabs rotated out of plane.""" validate_no_transformed_polyslabs(val) return val @@ -132,7 +144,7 @@ def _sort_structures( ) -> list[StructureType]: """Sort structure lists based on their priority values in ascending order.""" - def structure_comparator(struct1, struct2): + def structure_comparator(struct1: StructureType, struct2: StructureType) -> int: return struct1._priority(structure_priority_mode) - struct2._priority( structure_priority_mode ) @@ -216,8 +228,7 @@ class Structure(AbstractStructure): * `Structures `_ """ - medium: StructureMediumType = pydantic.Field( - ..., + medium: StructureMediumType = Field( title="Medium", description="Defines the electromagnetic properties of the structure's medium.", discriminator=TYPE_TAG_STR, @@ -238,7 +249,7 @@ def _priority(self, priority_mode: PriorityMode) -> int: return 0 @property - def viz_spec(self): + def viz_spec(self) -> Optional[VisualizationSpec]: return self.medium.viz_spec def eps_diagonal(self, frequency: float, coords: Coords) -> tuple[complex, complex, complex]: @@ -259,7 +270,7 @@ def eps_diagonal(self, frequency: float, coords: Coords) -> tuple[complex, compl return self.medium.eps_diagonal(frequency=frequency) @staticmethod - def _get_optical_medium(medium): + def _get_optical_medium(medium: MultiPhysicsMedium) -> Optional[StructureMediumType]: """Get optical medium.""" return medium.optical if isinstance(medium, MultiPhysicsMedium) else medium @@ -268,11 +279,11 @@ def _optical_medium(self) -> StructureMediumType: """Optical medium of the structure.""" return self._get_optical_medium(self.medium) - @pydantic.validator("medium", always=True) - @skip_if_fields_missing(["geometry"]) - def _check_2d_geometry(cls, val, values): + @model_validator(mode="after") + def _check_2d_geometry(self) -> Self: """Medium2D is only consistent with certain geometry types""" - geom = values.get("geometry") + val = self.medium + geom = self.geometry if isinstance(val, Medium2D): # the geometry needs to be supported by 2d materials @@ -286,7 +297,7 @@ def _check_2d_geometry(cls, val, values): # if the geometry is not supported / not 2d _ = geom._normal_2dmaterial - return val + return self def _compatible_with(self, other: Structure) -> bool: """Whether these two structures are compatible.""" @@ -314,12 +325,47 @@ def _get_monitor_name(index: int, data_type: str) -> str: return monitor_name_map[data_type] def _make_adjoint_monitors( - self, freqs: list[float], index: int, field_keys: list[str] - ) -> (FieldMonitor, PermittivityMonitor): + self, + freqs: list[float], + index: int, + field_keys: list[str], + plane: Optional[Box] = None, + ) -> tuple[FieldMonitor, PermittivityMonitor]: """Generate the field and permittivity monitor for this structure.""" geometry = self.geometry - box = geometry.bounding_box + geom_box = geometry.bounding_box + + def _box_from_plane_intersection() -> Box: + plane_axis = plane._normal_axis + plane_position = plane.center[plane_axis] + axis_char = "xyz"[plane_axis] + + intersections = geometry.intersections_plane(**{axis_char: plane_position}) + bounds = [shape.bounds for shape in intersections if not shape.is_empty] + if len(bounds) == 0: + intersections = 
geom_box.intersections_plane(**{axis_char: plane_position}) + bounds = [shape.bounds for shape in intersections if not shape.is_empty] + if len(bounds) == 0: # fallback + return geom_box + + min_plane = (min(b[0] for b in bounds), min(b[1] for b in bounds)) + max_plane = (max(b[2] for b in bounds), max(b[3] for b in bounds)) + + rmin = [plane_position, plane_position, plane_position] + rmax = [plane_position, plane_position, plane_position] + + _, plane_axes = Geometry.pop_axis((0, 1, 2), axis=plane_axis) + for ind, ax in enumerate(plane_axes): + rmin[ax] = min_plane[ind] + rmax[ax] = max_plane[ind] + + return Box.from_bounds(tuple(rmin), tuple(rmax)) + + if plane is not None: + box = _box_from_plane_intersection() + else: + box = geom_box # we dont want these fields getting traced by autograd, otherwise it messes stuff up size = [get_static(x) for x in box.size] @@ -333,7 +379,12 @@ def _make_adjoint_monitors( interval_space = monitor_cfg.monitor_interval_poly field_components_for_adjoint = [f"E{dim}" for dim in "xyz"] - if self.medium.is_pec: + + background_medium_pec = ( + self.background_medium is not None + ) and self.background_medium.is_pec + if self.medium.is_pec or background_medium_pec: + # record H-fields when the structure medium or its background is marked PEC field_components_for_adjoint += [f"H{dim}" for dim in "xyz"] mnt_fld = FieldMonitor( @@ -420,12 +471,12 @@ def to_gdstk( x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, pixel_exact: bool = False, - ) -> None: + ) -> list[Any]: """Convert a structure's planar slice to a .gds type polygon. Parameters @@ -532,14 +583,14 @@ def to_gdstk( def to_gds( self, - cell, + cell: gdstk.Cell, x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, pixel_exact: bool = False, ) -> None: """Append a structure's planar slice to a .gds cell. 
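The `_box_from_plane_intersection` helper above restricts a structure's adjoint-monitor box to the monitoring plane: it collects the 2D bounds of every shape cut by the plane, unions them, and rebuilds 3D bounds that are flat along the plane's normal axis. A standalone sketch of that bounds arithmetic, assuming shapely-style 2D bounds tuples (minx, miny, maxx, maxy); the function name is illustrative, not tidy3d API:

def plane_restricted_bounds(shape_bounds_2d, axis, position):
    # Union of the 2D bounds of all plane intersections.
    min_plane = (min(b[0] for b in shape_bounds_2d), min(b[1] for b in shape_bounds_2d))
    max_plane = (max(b[2] for b in shape_bounds_2d), max(b[3] for b in shape_bounds_2d))
    # Start from a degenerate box pinned to the plane, then stretch the two
    # in-plane axes (mirrors the Geometry.pop_axis loop in the helper above).
    rmin, rmax = [position] * 3, [position] * 3
    for ind, ax in enumerate(a for a in (0, 1, 2) if a != axis):
        rmin[ax] = min_plane[ind]
        rmax[ax] = max_plane[ind]
    return tuple(rmin), tuple(rmax)

For example, two shapes cut by a z-normal plane at z=0.5: plane_restricted_bounds([(0, 0, 1, 1), (2, -1, 3, 0)], axis=2, position=0.5) returns ((0, -1, 0.5), (3, 1, 0.5)), i.e. a zero-thickness box in z spanning both shapes in x and y.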
@@ -592,10 +643,10 @@ def to_gds_file( x: Optional[float] = None, y: Optional[float] = None, z: Optional[float] = None, - permittivity_threshold: pydantic.NonNegativeFloat = 1, - frequency: pydantic.PositiveFloat = 0, - gds_layer: pydantic.NonNegativeInt = 0, - gds_dtype: pydantic.NonNegativeInt = 0, + permittivity_threshold: NonNegativeFloat = 1, + frequency: PositiveFloat = 0, + gds_layer: NonNegativeInt = 0, + gds_dtype: NonNegativeInt = 0, gds_cell_name: str = "MAIN", pixel_exact: bool = False, ) -> None: @@ -729,24 +780,23 @@ class MeshOverrideStructure(AbstractStructure): """ dl: tuple[ - Optional[pydantic.PositiveFloat], - Optional[pydantic.PositiveFloat], - Optional[pydantic.PositiveFloat], - ] = pydantic.Field( - ..., + Optional[PositiveFloat], + Optional[PositiveFloat], + Optional[PositiveFloat], + ] = Field( title="Grid Size", description="Grid size along x, y, z directions.", units=MICROMETER, ) - priority: int = pydantic.Field( + priority: int = Field( 0, title="Priority", description="Priority of the structure applied in mesh override structure overlapping region. " "The priority of internal override structures is ``-1``.", ) - enforce: bool = pydantic.Field( + enforce: bool = Field( False, title="Enforce Grid Size", description="If ``True``, enforce the grid size setup inside the structure " @@ -755,7 +805,7 @@ class MeshOverrideStructure(AbstractStructure): "the last added structure of ``enforce=True``.", ) - shadow: bool = pydantic.Field( + shadow: bool = Field( True, title="Grid Size Choice In Structure Overlapping Region", description="In structure intersection region, grid size is decided by the latter added " @@ -764,7 +814,7 @@ class MeshOverrideStructure(AbstractStructure): "the bounding box of the structure is disabled.", ) - drop_outside_sim: bool = pydantic.Field( + drop_outside_sim: bool = Field( True, title="Drop Structure Outside Simulation Domain", description="If ``True``, structure outside the simulation domain is dropped; if ``False``, " @@ -772,8 +822,9 @@ class MeshOverrideStructure(AbstractStructure): "and that of the simulation domain overlap.", ) - @pydantic.validator("geometry") - def _box_only(cls, val): + @field_validator("geometry") + @classmethod + def _box_only(cls, val: GeometryType) -> GeometryType: """Ensure this is a box.""" if isinstance(val, Geometry): if not isinstance(val, Box): @@ -784,12 +835,12 @@ def _box_only(cls, val): return val.bounding_box return val - @pydantic.validator("shadow") - def _unshadowed_cannot_be_enforced(cls, val, values): + @model_validator(mode="after") + def _unshadowed_cannot_be_enforced(self) -> Self: """Unshadowed structure cannot be enforced.""" - if not val and values["enforce"]: + if not self.shadow and self.enforce: raise SetupError("A structure cannot be simultaneously enforced and unshadowed.") - return val + return self StructureType = Union[Structure, MeshOverrideStructure] diff --git a/tidy3d/components/subpixel_spec.py b/tidy3d/components/subpixel_spec.py index 89d0ceba8c..bcda62eb5c 100644 --- a/tidy3d/components/subpixel_spec.py +++ b/tidy3d/components/subpixel_spec.py @@ -1,12 +1,15 @@ # Defines specifications for subpixel averaging from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING, Union -import pydantic.v1 as pd +from pydantic import Field from .base import Tidy3dBaseModel, cached_property -from .types import TYPE_TAG_STR +from .types.base import discriminated_union + +if TYPE_CHECKING: + from tidy3d.compat import Self # Default Courant number 
reduction rate in PEC conformal's scheme DEFAULT_COURANT_REDUCTION_PEC_CONFORMAL = 0.3 @@ -65,7 +68,9 @@ class ContourPathAveraging(AbstractSubpixelAveragingMethod): """ -DielectricSubpixelType = Union[Staircasing, PolarizedAveraging, ContourPathAveraging] +DielectricSubpixelType = discriminated_union( + Union[Staircasing, PolarizedAveraging, ContourPathAveraging] +) class VolumetricAveraging(AbstractSubpixelAveragingMethod): @@ -73,7 +78,7 @@ class VolumetricAveraging(AbstractSubpixelAveragingMethod): The material property is averaged in the volume surrounding the Yee grid. """ - staircase_normal_component: bool = pd.Field( + staircase_normal_component: bool = Field( True, title="Staircasing For Field Components Substantially Normal To Interface", description="Volumetric averaging works accurately if the electric field component " @@ -83,7 +88,7 @@ class VolumetricAveraging(AbstractSubpixelAveragingMethod): ) -MetalSubpixelType = Union[Staircasing, VolumetricAveraging] +MetalSubpixelType = discriminated_union(Union[Staircasing, VolumetricAveraging]) class HeuristicPECStaircasing(AbstractSubpixelAveragingMethod): @@ -110,7 +115,7 @@ class PECConformal(AbstractSubpixelAveragingMethod): IEEE Transactions on Antennas and Propagation, 54(6), 1843 (2006). """ - timestep_reduction: float = pd.Field( + timestep_reduction: float = Field( DEFAULT_COURANT_REDUCTION_PEC_CONFORMAL, title="Time Step Size Reduction Rate", description="Reduction factor between 0 and 1 such that the simulation's time step size " @@ -120,7 +125,7 @@ class PECConformal(AbstractSubpixelAveragingMethod): ge=0, ) - edge_singularity_correction: bool = pd.Field( + edge_singularity_correction: bool = Field( True, title="Apply Singularity Model At Metal Edges", description="Apply field correction model at metallic edges where field singularity occurs. " @@ -136,8 +141,8 @@ def courant_ratio(self) -> float: return 1 - self.timestep_reduction -PECSubpixelType = Union[Staircasing, HeuristicPECStaircasing, PECConformal] -PMCSubpixelType = Union[Staircasing, HeuristicPECStaircasing] +PECSubpixelType = discriminated_union(Union[Staircasing, HeuristicPECStaircasing, PECConformal]) +PMCSubpixelType = discriminated_union(Union[Staircasing, HeuristicPECStaircasing]) class SurfaceImpedance(PECConformal): @@ -145,7 +150,7 @@ class SurfaceImpedance(PECConformal): structure made of :class:`.LossyMetalMedium`. 
""" - timestep_reduction: float = pd.Field( + timestep_reduction: float = Field( DEFAULT_COURANT_REDUCTION_SIBC_CONFORMAL, title="Time Step Size Reduction Rate", description="Reduction factor between 0 and 1 such that the simulation's time step size " @@ -156,51 +161,48 @@ class SurfaceImpedance(PECConformal): ) -LossyMetalSubpixelType = Union[Staircasing, VolumetricAveraging, SurfaceImpedance] +LossyMetalSubpixelType = discriminated_union( + Union[Staircasing, VolumetricAveraging, SurfaceImpedance] +) class SubpixelSpec(Tidy3dBaseModel): """Defines specification for subpixel averaging schemes when added to ``Simulation.subpixel``.""" - dielectric: DielectricSubpixelType = pd.Field( - PolarizedAveraging(), + dielectric: DielectricSubpixelType = Field( + default_factory=PolarizedAveraging, title="Subpixel Averaging Method For Dielectric Interfaces", description="Subpixel averaging method applied to dielectric material interfaces.", - discriminator=TYPE_TAG_STR, ) - metal: MetalSubpixelType = pd.Field( - Staircasing(), + metal: MetalSubpixelType = Field( + default_factory=Staircasing, title="Subpixel Averaging Method For Metallic Interfaces", description="Subpixel averaging method applied to metallic structure interfaces. " "A material is considered as metallic if its real part of relative permittivity " "is less than 1 at the central frequency.", - discriminator=TYPE_TAG_STR, ) - pec: PECSubpixelType = pd.Field( - PECConformal(), + pec: PECSubpixelType = Field( + default_factory=PECConformal, title="Subpixel Averaging Method For PEC Interfaces", description="Subpixel averaging method applied to PEC structure interfaces.", - discriminator=TYPE_TAG_STR, ) - pmc: PMCSubpixelType = pd.Field( - Staircasing(), + pmc: PMCSubpixelType = Field( + default_factory=Staircasing, title="Subpixel Averaging Method For PMC Interfaces", description="Subpixel averaging method applied to PMC structure interfaces.", - discriminator=TYPE_TAG_STR, ) - lossy_metal: LossyMetalSubpixelType = pd.Field( - SurfaceImpedance(), + lossy_metal: LossyMetalSubpixelType = Field( + default_factory=SurfaceImpedance, title="Subpixel Averaging Method for Lossy Metal Interfaces", description="Subpixel averaging method applied to ``td.LossyMetalMedium`` material interfaces.", - discriminator=TYPE_TAG_STR, ) @classmethod - def staircasing(cls) -> SubpixelSpec: + def staircasing(cls) -> Self: """Apply staircasing on all material boundaries.""" return cls( dielectric=Staircasing(), diff --git a/tidy3d/components/tcad/analysis/heat_simulation_type.py b/tidy3d/components/tcad/analysis/heat_simulation_type.py index 154f10dce2..701cee8534 100644 --- a/tidy3d/components/tcad/analysis/heat_simulation_type.py +++ b/tidy3d/components/tcad/analysis/heat_simulation_type.py @@ -2,7 +2,7 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, PositiveInt from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import KELVIN, SECOND @@ -20,14 +20,14 @@ class UnsteadySpec(Tidy3dBaseModel): ... ) """ - time_step: pd.PositiveFloat = pd.Field( + time_step: PositiveFloat = Field( ..., title="Time-step", description="Time step taken for each iteration of the time integration loop.", units=SECOND, ) - total_time_steps: pd.PositiveInt = pd.Field( + total_time_steps: PositiveInt = Field( ..., title="Total time steps", description="Specifies the total number of time steps run during the simulation.", @@ -50,14 +50,14 @@ class UnsteadyHeatAnalysis(Tidy3dBaseModel): ... 
) """ - initial_temperature: pd.PositiveFloat = pd.Field( + initial_temperature: PositiveFloat = Field( ..., title="Initial temperature.", description="Initial value for the temperature field.", units=KELVIN, ) - unsteady_spec: UnsteadySpec = pd.Field( + unsteady_spec: UnsteadySpec = Field( ..., title="Unsteady specification", description="Time step and total time steps for the unsteady simulation.", diff --git a/tidy3d/components/tcad/bandgap.py b/tidy3d/components/tcad/bandgap.py index 7b889f174c..c6d5fbdacd 100644 --- a/tidy3d/components/tcad/bandgap.py +++ b/tidy3d/components/tcad/bandgap.py @@ -1,6 +1,6 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import PERCMCUBE, VOLT @@ -37,27 +37,24 @@ class SlotboomBandGapNarrowing(Tidy3dBaseModel): .. [1] 'UNIFIED APPARENT BANDGAP NARROWING IN n- AND p-TYPE SILICON' Solid-State Electronics Vol. 35, No. 2, pp. 125-129, 1992""" - v1: pd.PositiveFloat = pd.Field( - ..., + v1: PositiveFloat = Field( title=":math:`V_{1,bgn}` parameter", description=":math:`V_{1,bgn}` parameter", units=VOLT, ) - n2: pd.PositiveFloat = pd.Field( - ..., + n2: PositiveFloat = Field( title=":math:`N_{2,bgn}` parameter", description=":math:`N_{2,bgn}` parameter", units=PERCMCUBE, ) - c2: float = pd.Field( + c2: float = Field( title=":math:`C_{2,bgn}` parameter", description=":math:`C_{2,bgn}` parameter", ) - min_N: pd.NonNegativeFloat = pd.Field( - ..., + min_N: NonNegativeFloat = Field( title="Minimum total doping", description="Bandgap narrowing is applied at location where total doping " "is higher than ``min_N``.", diff --git a/tidy3d/components/tcad/bandgap_energy.py b/tidy3d/components/tcad/bandgap_energy.py index 450966dc08..eea311df7a 100644 --- a/tidy3d/components/tcad/bandgap_energy.py +++ b/tidy3d/components/tcad/bandgap_energy.py @@ -1,6 +1,6 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import ELECTRON_VOLT @@ -9,7 +9,7 @@ class ConstantEnergyBandGap(Tidy3dBaseModel): """Constant Energy band gap""" - eg: pd.PositiveFloat = pd.Field( + eg: PositiveFloat = Field( title="Band Gap", description="Energy band gap", units=ELECTRON_VOLT, @@ -45,22 +45,19 @@ class VarshniEnergyBandGap(Tidy3dBaseModel): """ - eg_0: pd.PositiveFloat = pd.Field( - ..., + eg_0: PositiveFloat = Field( title="Band Gap at 0 K", description="Energy band gap at absolute zero (0 Kelvin).", units=ELECTRON_VOLT, ) - alpha: pd.PositiveFloat = pd.Field( - ..., + alpha: PositiveFloat = Field( title="Varshni Alpha Coefficient", description="Empirical Varshni coefficient (α).", units="eV/K", ) - beta: pd.PositiveFloat = pd.Field( - ..., + beta: PositiveFloat = Field( title="Varshni Beta Coefficient", description="Empirical Varshni coefficient (β), related to the Debye temperature.", units="K", diff --git a/tidy3d/components/tcad/boundary/charge.py b/tidy3d/components/tcad/boundary/charge.py index 3ce5cf7228..fba6e13c0b 100644 --- a/tidy3d/components/tcad/boundary/charge.py +++ b/tidy3d/components/tcad/boundary/charge.py @@ -2,7 +2,7 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.spice.sources.types import CurrentSourceType, VoltageSourceType from tidy3d.components.tcad.boundary.abstract import HeatChargeBC @@ -28,8 +28,7 @@ class VoltageBC(HeatChargeBC): 
>>> voltage_bc = td.VoltageBC(source=voltage_source) """ - source: VoltageSourceType = pd.Field( - ..., + source: VoltageSourceType = Field( title="Voltage", description="Electric potential to be applied at the specified boundary.", units=VOLT, @@ -47,8 +46,7 @@ class CurrentBC(HeatChargeBC): >>> current_bc = CurrentBC(source=current_source) """ - source: CurrentSourceType = pd.Field( - ..., + source: CurrentSourceType = Field( title="Current Source", description="A current source", units=CURRENT_DENSITY, diff --git a/tidy3d/components/tcad/boundary/heat.py b/tidy3d/components/tcad/boundary/heat.py index 7430cc22f2..f6065a3cf4 100644 --- a/tidy3d/components/tcad/boundary/heat.py +++ b/tidy3d/components/tcad/boundary/heat.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING, Optional, Union -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.material.tcad.heat import FluidMedium @@ -18,6 +18,9 @@ MICROMETER, ) +if TYPE_CHECKING: + from tidy3d.compat import Self + class TemperatureBC(HeatChargeBC): """Constant temperature thermal boundary conditions. @@ -28,7 +31,7 @@ class TemperatureBC(HeatChargeBC): >>> bc = td.TemperatureBC(temperature=300) """ - temperature: pd.PositiveFloat = pd.Field( + temperature: PositiveFloat = Field( title="Temperature", description="Temperature value.", units=KELVIN, @@ -44,7 +47,7 @@ class HeatFluxBC(HeatChargeBC): >>> bc = td.HeatFluxBC(flux=1) """ - flux: float = pd.Field( + flux: float = Field( title="Heat Flux", description="Heat flux value.", units=HEAT_FLUX, @@ -68,7 +71,7 @@ class VerticalNaturalConvectionCoeffModel(Tidy3dBaseModel): """ - medium: FluidMedium = pd.Field( + medium: Optional[FluidMedium] = Field( default=None, title="Interface medium", description=( @@ -78,24 +81,26 @@ class VerticalNaturalConvectionCoeffModel(Tidy3dBaseModel): ), ) - plate_length: pd.NonNegativeFloat = pd.Field( + plate_length: NonNegativeFloat = Field( title="Plate Characteristic Length", description="Characteristic length (L), defined as the height of the vertical plate.", units=MICROMETER, ) - gravity: pd.NonNegativeFloat = pd.Field( + gravity: NonNegativeFloat = Field( default=GRAV_ACC, title="Gravitational Acceleration", description="Gravitational acceleration (g).", units=ACCELERATION, ) + @classmethod def from_si_units( - plate_length: pd.NonNegativeFloat, + cls, + plate_length: NonNegativeFloat, medium: FluidMedium = None, - gravity: pd.NonNegativeFloat = GRAV_ACC * 1e-6, - ): + gravity: NonNegativeFloat = GRAV_ACC * 1e-6, + ) -> Self: """ Create an instance from standard SI units. @@ -113,7 +118,7 @@ def from_si_units( plate_length_tidy = plate_length * 1e6 # m -> um g_tidy = gravity * 1e6 # m/s**2 -> um/s**2 - return VerticalNaturalConvectionCoeffModel( + return cls( medium=medium, plate_length=plate_length_tidy, gravity=g_tidy, @@ -156,13 +161,13 @@ class ConvectionBC(HeatChargeBC): ... 
) """ - ambient_temperature: pd.PositiveFloat = pd.Field( + ambient_temperature: PositiveFloat = Field( title="Ambient Temperature", description="Ambient temperature.", units=KELVIN, ) - transfer_coeff: Union[pd.NonNegativeFloat, VerticalNaturalConvectionCoeffModel] = pd.Field( + transfer_coeff: Union[NonNegativeFloat, VerticalNaturalConvectionCoeffModel] = Field( title="Heat Transfer Coefficient", description="Heat transfer coefficient value.", units=HEAT_TRANSFER_COEFF, diff --git a/tidy3d/components/tcad/boundary/specification.py b/tidy3d/components/tcad/boundary/specification.py index 08289fa19f..193fb3b546 100644 --- a/tidy3d/components/tcad/boundary/specification.py +++ b/tidy3d/components/tcad/boundary/specification.py @@ -2,7 +2,7 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.bc_placement import BCPlacementType @@ -22,13 +22,13 @@ class HeatChargeBoundarySpec(Tidy3dBaseModel): ... ) """ - placement: BCPlacementType = pd.Field( + placement: BCPlacementType = Field( title="Boundary Conditions Placement", description="Location to apply boundary conditions.", discriminator=TYPE_TAG_STR, ) - condition: HeatChargeBCType = pd.Field( + condition: HeatChargeBCType = Field( title="Boundary Conditions", description="Boundary conditions to apply at the selected location.", discriminator=TYPE_TAG_STR, diff --git a/tidy3d/components/tcad/data/monitor_data/abstract.py b/tidy3d/components/tcad/data/monitor_data/abstract.py index 629c97ac74..7228da5adf 100644 --- a/tidy3d/components/tcad/data/monitor_data/abstract.py +++ b/tidy3d/components/tcad/data/monitor_data/abstract.py @@ -7,22 +7,19 @@ from typing import Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base_sim.data.monitor_data import AbstractMonitorData -from tidy3d.components.data.data_array import ( - SpatialDataArray, -) +from tidy3d.components.data.data_array import SpatialDataArray from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset -from tidy3d.components.tcad.types import ( - HeatChargeMonitorType, -) -from tidy3d.components.types import Coordinate, ScalarSymmetry, annotate_type +from tidy3d.components.tcad.types import HeatChargeMonitorType +from tidy3d.components.types import Coordinate, ScalarSymmetry +from tidy3d.components.types.base import discriminated_union from tidy3d.constants import MICROMETER from tidy3d.log import log FieldDataset = Union[ - SpatialDataArray, annotate_type(Union[TriangularGridDataset, TetrahedralGridDataset]) + SpatialDataArray, discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]) ] UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset] @@ -30,19 +27,18 @@ class HeatChargeMonitorData(AbstractMonitorData, ABC): """Abstract base class of objects that store data pertaining to a single :class:`HeatChargeMonitor`.""" - monitor: HeatChargeMonitorType = pd.Field( - ..., + monitor: HeatChargeMonitorType = Field( title="Monitor", description="Monitor associated with the data.", ) - symmetry: tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = pd.Field( + symmetry: tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = Field( (0, 0, 0), title="Symmetry", description="Symmetry of the original simulation in x, y, and z.", ) - symmetry_center: Coordinate = pd.Field( + symmetry_center: Coordinate = Field( (0, 0, 0), title="Symmetry Center", description="Symmetry 
center of the original simulation in x, y, and z.", diff --git a/tidy3d/components/tcad/data/monitor_data/charge.py b/tidy3d/components/tcad/data/monitor_data/charge.py index f185934421..f3dc5b95b4 100644 --- a/tidy3d/components/tcad/data/monitor_data/charge.py +++ b/tidy3d/components/tcad/data/monitor_data/charge.py @@ -2,13 +2,12 @@ from __future__ import annotations -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.data.data_array import ( - DataArray, IndexedFieldVoltageDataArray, IndexedVoltageDataArray, PointDataArray, @@ -25,34 +24,39 @@ SteadyFreeCarrierMonitor, SteadyPotentialMonitor, ) -from tidy3d.components.types import TYPE_TAG_STR, Ax, annotate_type +from tidy3d.components.types import TYPE_TAG_STR +from tidy3d.components.types.base import discriminated_union from tidy3d.components.viz import add_ax_if_none from tidy3d.exceptions import DataError +if TYPE_CHECKING: + from tidy3d.compat import Self + from tidy3d.components.types import Ax + FieldDataset = Union[ - SpatialDataArray, annotate_type(Union[TriangularGridDataset, TetrahedralGridDataset]) + discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]), + SpatialDataArray, ] -UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset] +UnstructuredFieldType = discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]) class SteadyPotentialData(HeatChargeMonitorData): """Stores electric potential :math:`\\psi` from a charge simulation.""" - monitor: SteadyPotentialMonitor = pd.Field( - ..., + monitor: SteadyPotentialMonitor = Field( title="Electric potential monitor", description="Electric potential monitor associated with a `charge` simulation.", ) - potential: FieldDataset = pd.Field( + potential: Optional[FieldDataset] = Field( None, title="Electric potential series", description="Contains the electric potential series.", ) @property - def field_components(self) -> dict[str, DataArray]: + def field_components(self) -> dict[str, Optional[FieldDataset]]: """Maps the field components to their associated data.""" return {"potential": self.potential} @@ -68,49 +72,42 @@ class SteadyFreeCarrierData(HeatChargeMonitorData): ``monitor``. 
""" - monitor: SteadyFreeCarrierMonitor = pd.Field( - ..., + monitor: SteadyFreeCarrierMonitor = Field( title="Free carrier monitor", description="Free carrier data associated with a Charge simulation.", ) - electrons: UnstructuredFieldType = pd.Field( + electrons: Optional[UnstructuredFieldType] = Field( None, title="Electrons series", description=r"Contains the computed electrons concentration :math:`n`.", - discriminator=TYPE_TAG_STR, ) # n = electrons - holes: UnstructuredFieldType = pd.Field( + holes: Optional[UnstructuredFieldType] = Field( None, title="Holes series", description=r"Contains the computed holes concentration :math:`p`.", - discriminator=TYPE_TAG_STR, ) # p = holes @property - def field_components(self) -> dict[str, DataArray]: + def field_components(self) -> dict[str, Optional[UnstructuredFieldType]]: """Maps the field components to their associated data.""" return {"electrons": self.electrons, "holes": self.holes} - @pd.root_validator(skip_on_failure=True) - def check_correct_data_type(cls, values): + @model_validator(mode="after") + def check_correct_data_type(self) -> Self: """Issue error if incorrect data type is used""" - - mnt = values.get("monitor") - field_data = {field: values.get(field) for field in ["electrons", "holes"]} - + field_data = {field: getattr(self, field) for field in ["electrons", "holes"]} for field, data in field_data.items(): if isinstance(data, TetrahedralGridDataset) or isinstance(data, TriangularGridDataset): if not isinstance(data.values, IndexedVoltageDataArray): raise ValueError( - f"In the data associated with monitor {mnt}, the field {field} does not contain " - "data associated to any voltage value." + f"In the data associated with monitor {self.monitor}, the " + f"field {field} does not contain data associated to any voltage value." ) - - return values + return self class SteadyEnergyBandData(HeatChargeMonitorData): @@ -148,68 +145,61 @@ class SteadyEnergyBandData(HeatChargeMonitorData): as defined in the ``monitor``. 
""" - monitor: SteadyEnergyBandMonitor = pd.Field( - ..., + monitor: SteadyEnergyBandMonitor = Field( title="Energy band monitor", description="Energy bands data associated with a Charge simulation.", ) - Ec: UnstructuredFieldType = pd.Field( + Ec: Optional[UnstructuredFieldType] = Field( None, title="Conduction band series", description="Contains the computed energy of the bottom of the conduction band :math:`E_c`.", - discriminator=TYPE_TAG_STR, ) - Ev: UnstructuredFieldType = pd.Field( + Ev: Optional[UnstructuredFieldType] = Field( None, title="Valence band series", description="Contains the computed energy of the top of the valence band :math:`E_v`.", - discriminator=TYPE_TAG_STR, ) - Ei: UnstructuredFieldType = pd.Field( + Ei: Optional[UnstructuredFieldType] = Field( None, title="Intrinsic Fermi level series", description="Contains the computed intrinsic Fermi level for the material :math:`E_i`.", - discriminator=TYPE_TAG_STR, ) - Efn: UnstructuredFieldType = pd.Field( + Efn: Optional[UnstructuredFieldType] = Field( None, title="Electron's quasi-Fermi level series", description="Contains the computed quasi-Fermi level for electrons :math:`E_{fn}`.", - discriminator=TYPE_TAG_STR, ) - Efp: UnstructuredFieldType = pd.Field( + Efp: Optional[UnstructuredFieldType] = Field( None, title="Hole's quasi-Fermi level series", description="Contains the computed quasi-Fermi level for holes :math:`E_{fp}`.", - discriminator=TYPE_TAG_STR, ) @property - def field_components(self) -> dict[str, DataArray]: + def field_components(self) -> dict[str, Optional[UnstructuredFieldType]]: """Maps the field components to their associated data.""" return {"Ec": self.Ec, "Ev": self.Ev, "Ei": self.Ei, "Efn": self.Efn, "Efp": self.Efp} - @pd.root_validator(skip_on_failure=True) - def check_correct_data_type(cls, values): + @model_validator(mode="after") + def check_correct_data_type(self) -> Self: """Issue error if incorrect data type is used""" - mnt = values.get("monitor") - field_data = {field: values.get(field) for field in ["Ec", "Ev", "Ei", "Efn", "Efp"]} + field_data = {field: getattr(self, field) for field in ["Ec", "Ev", "Ei", "Efn", "Efp"]} for field, data in field_data.items(): if isinstance(data, TetrahedralGridDataset) or isinstance(data, TriangularGridDataset): if not isinstance(data.values, IndexedVoltageDataArray): raise ValueError( - f"In the data associated with monitor {mnt}, the field {field} does not contain " - "data associated to any voltage value." + f"In the data associated with monitor {self.monitor}, the " + f"field {field} does not contain data associated to any voltage value." ) - return values + return self @add_ax_if_none def plot(self, ax: Ax = None, **sel_kwargs: Any) -> Ax: @@ -307,20 +297,19 @@ class SteadyCapacitanceData(HeatChargeMonitorData): This is only computed when a voltage source with more than two sources is included within the simulation and determines the :math:`\\Delta V`. 
""" - monitor: SteadyCapacitanceMonitor = pd.Field( - ..., + monitor: SteadyCapacitanceMonitor = Field( title="Capacitance monitor", description="Capacitance data associated with a Charge simulation.", ) - hole_capacitance: SteadyVoltageDataArray = pd.Field( + hole_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Hole capacitance", description="Small signal capacitance :math:`(\\frac{dQ_p}{dV})` associated to the monitor.", ) # C_p = hole_capacitance - electron_capacitance: SteadyVoltageDataArray = pd.Field( + electron_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Electron capacitance", description="Small signal capacitance :math:`(\\frac{dQn}{dV})` associated to the monitor.", @@ -371,17 +360,15 @@ class SteadyElectricFieldData(HeatChargeMonitorData): It is given in units of :math:`V/\\mu m` (Volts per micrometer). """ - monitor: SteadyElectricFieldMonitor = pd.Field( - ..., + monitor: SteadyElectricFieldMonitor = Field( title="Electric field monitor", description="Electric field data associated with a Charge/Conduction simulation.", ) - E: UnstructuredFieldType = pd.Field( + E: Optional[UnstructuredFieldType] = Field( None, title="Electric field", description="Contains the computed electric field.", - discriminator=TYPE_TAG_STR, units=":math:`V/\\mu m`", ) @@ -390,22 +377,18 @@ def field_components(self) -> dict[str, UnstructuredFieldType]: """Maps the field components to their associated data.""" return {"E": self.E} - @pd.root_validator(skip_on_failure=True) - def check_correct_data_type(cls, values): + @model_validator(mode="after") + def check_correct_data_type(self) -> Self: """Issue error if incorrect data type is used""" - mnt = values.get("monitor") - E = values.get("E") - - if isinstance(E, TetrahedralGridDataset) or isinstance(E, TriangularGridDataset): - AcceptedTypes = (IndexedFieldVoltageDataArray, PointDataArray) - if not isinstance(E.values, AcceptedTypes): + if isinstance(self.E, TetrahedralGridDataset) or isinstance(self.E, TriangularGridDataset): + if not isinstance(self.E.values, (IndexedFieldVoltageDataArray, PointDataArray)): raise ValueError( - f"In the data associated with monitor {mnt}, must contain a field. This can be " - "defined with IndexedFieldVoltageDataArray or PointDataArray." + f"The data associated with monitor {self.monitor.name} must contain a field. This can be " + "defined with 'IndexedFieldVoltageDataArray' or 'PointDataArray'." 
                )

-        return values
+        return self


 class SteadyCurrentDensityData(HeatChargeMonitorData):
@@ -414,13 +397,12 @@ class SteadyCurrentDensityData(HeatChargeMonitorData):
     units of :math:`A/\\mu m^2`
     """

-    monitor: SteadyCurrentDensityMonitor = pd.Field(
-        ...,
+    monitor: SteadyCurrentDensityMonitor = Field(
         title="Current density monitor",
         description="Current density data associated with a Charge/Conduction simulation.",
     )

-    J: UnstructuredFieldType = pd.Field(
+    J: Optional[UnstructuredFieldType] = Field(
         None,
         title="Current density",
         description="Contains the computed current density.",
@@ -433,12 +415,12 @@ def field_components(self) -> dict[str, UnstructuredFieldType]:
         """Maps the field components to their associated data."""
         return {"J": self.J}

-    @pd.root_validator(skip_on_failure=True)
-    def check_correct_data_type(cls, values):
+    @model_validator(mode="after")
+    def check_correct_data_type(self) -> Self:
         """Issue error if incorrect data type is used"""

-        mnt = values.get("monitor")
-        J = values.get("J")
+        mnt = self.monitor
+        J = self.J

         if isinstance(J, TetrahedralGridDataset) or isinstance(J, TriangularGridDataset):
             AcceptedTypes = (IndexedFieldVoltageDataArray, PointDataArray)
@@ -448,4 +430,4 @@ def check_correct_data_type(cls, values):
                 "defined with IndexedFieldVoltageDataArray or PointDataArray."
             )
-        return values
+        return self
diff --git a/tidy3d/components/tcad/data/monitor_data/heat.py b/tidy3d/components/tcad/data/monitor_data/heat.py
index 9734735ec9..5f105fd91c 100644
--- a/tidy3d/components/tcad/data/monitor_data/heat.py
+++ b/tidy3d/components/tcad/data/monitor_data/heat.py
@@ -4,25 +4,22 @@

 from typing import Optional, Union

-import pydantic.v1 as pd
+from pydantic import Field

 from tidy3d.components.data.data_array import (
-    DataArray,
     ScalarFieldTimeDataArray,
     SpatialDataArray,
 )
 from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset
 from tidy3d.components.tcad.data.monitor_data.abstract import HeatChargeMonitorData
-from tidy3d.components.tcad.monitors.heat import (
-    TemperatureMonitor,
-)
-from tidy3d.components.types import annotate_type
+from tidy3d.components.tcad.monitors.heat import TemperatureMonitor
+from tidy3d.components.types.base import discriminated_union
 from tidy3d.constants import KELVIN

 FieldDataset = Union[
+    discriminated_union(Union[TriangularGridDataset, TetrahedralGridDataset]),
     SpatialDataArray,
     ScalarFieldTimeDataArray,
-    annotate_type(Union[TriangularGridDataset, TetrahedralGridDataset]),
 ]

 UnstructuredFieldType = Union[TriangularGridDataset, TetrahedralGridDataset]
@@ -44,18 +42,19 @@ class TemperatureData(HeatChargeMonitorData):

     >>> temp_mnt_data_expanded = temp_mnt_data.symmetry_expanded_copy
     """

-    monitor: TemperatureMonitor = pd.Field(
-        ..., title="Monitor", description="Temperature monitor associated with the data."
+ monitor: TemperatureMonitor = Field( + title="Monitor", + description="Temperature monitor associated with the data.", ) - temperature: Optional[FieldDataset] = pd.Field( - ..., + temperature: Optional[FieldDataset] = Field( + None, title="Temperature", description="Spatial temperature field.", units=KELVIN, ) @property - def field_components(self) -> dict[str, DataArray]: + def field_components(self) -> dict[str, Optional[FieldDataset]]: """Maps the field components to their associated data.""" return {"temperature": self.temperature} diff --git a/tidy3d/components/tcad/data/monitor_data/mesh.py b/tidy3d/components/tcad/data/monitor_data/mesh.py index af3a4ec3f3..afa8d2fe9c 100644 --- a/tidy3d/components/tcad/data/monitor_data/mesh.py +++ b/tidy3d/components/tcad/data/monitor_data/mesh.py @@ -4,7 +4,7 @@ from typing import Union -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.data.utils import TetrahedralGridDataset, TriangularGridDataset from tidy3d.components.tcad.data.monitor_data.abstract import HeatChargeMonitorData @@ -43,12 +43,12 @@ class VolumeMeshData(HeatChargeMonitorData): >>> mesh_mnt_data = td.VolumeMeshData(monitor=mesh_mnt, mesh=tet_grid) # doctest: +SKIP """ - monitor: VolumeMeshMonitor = pd.Field( - ..., title="Monitor", description="Mesh monitor associated with the data." + monitor: VolumeMeshMonitor = Field( + title="Monitor", + description="Mesh monitor associated with the data.", ) - mesh: UnstructuredFieldType = pd.Field( - ..., + mesh: UnstructuredFieldType = Field( title="Mesh", description="Dataset storing the mesh.", ) diff --git a/tidy3d/components/tcad/data/sim_data.py b/tidy3d/components/tcad/data/sim_data.py index d72d25a931..94d427147c 100644 --- a/tidy3d/components/tcad/data/sim_data.py +++ b/tidy3d/components/tcad/data/sim_data.py @@ -3,10 +3,10 @@ from __future__ import annotations from abc import ABC -from typing import TYPE_CHECKING, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.base_sim.data.sim_data import AbstractSimulationData @@ -30,14 +30,20 @@ from tidy3d.components.tcad.monitors.mesh import VolumeMeshMonitor from tidy3d.components.tcad.simulation.heat import HeatSimulation from tidy3d.components.tcad.simulation.heat_charge import HeatChargeSimulation -from tidy3d.components.types import Ax, RealFieldVal, annotate_type +from tidy3d.components.types.base import discriminated_union from tidy3d.components.viz import add_ax_if_none, equal_aspect from tidy3d.exceptions import DataError, Tidy3dKeyError from tidy3d.log import log if TYPE_CHECKING: + from typing import Literal, Union + from matplotlib.colors import Colormap + from tidy3d.compat import Self + from tidy3d.components.data.data_array import DataArray + from tidy3d.components.types import Ax, RealFieldVal + class DeviceCharacteristics(Tidy3dBaseModel): """Stores device characteristics. For example, in steady-state it stores @@ -60,27 +66,27 @@ class DeviceCharacteristics(Tidy3dBaseModel): """ - steady_dc_hole_capacitance: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_hole_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Steady DC hole capacitance", description="Device steady DC capacitance data based on holes. 
If the simulation " "has converged, these result should be close to that of electrons.", ) - steady_dc_electron_capacitance: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_electron_capacitance: Optional[SteadyVoltageDataArray] = Field( None, title="Steady DC electron capacitance", description="Device steady DC capacitance data based on electrons. If the simulation " "has converged, these result should be close to that of holes.", ) - steady_dc_current_voltage: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_current_voltage: Optional[SteadyVoltageDataArray] = Field( None, title="Steady DC current-voltage", description="Device steady DC current-voltage relation for the device.", ) - steady_dc_resistance_voltage: Optional[SteadyVoltageDataArray] = pd.Field( + steady_dc_resistance_voltage: Optional[SteadyVoltageDataArray] = Field( None, title="Small signal resistance", description="Steady DC computation of the small signal resistance. This is computed " @@ -88,7 +94,7 @@ class DeviceCharacteristics(Tidy3dBaseModel): "is given in Ohms. Note that in 2D the resistance is given in :math:`\\Omega \\mu`.", ) - ac_current_voltage: Optional[FreqVoltageDataArray] = pd.Field( + ac_current_voltage: Optional[FreqVoltageDataArray] = Field( None, title="Small-signal AC current-voltage", description="Small-signal AC current as a function of DC bias voltage and frequency. " @@ -101,13 +107,15 @@ class DeviceCharacteristics(Tidy3dBaseModel): class AbstractHeatChargeSimulationData(AbstractSimulationData, ABC): """Abstract class for HeatChargeSimulation results, or VolumeMesher results.""" - simulation: HeatChargeSimulation = pd.Field( + simulation: HeatChargeSimulation = Field( title="Heat-Charge Simulation", description="Original :class:`.HeatChargeSimulation` associated with the data.", ) @staticmethod - def _get_field_by_name(monitor_data: TCADMonitorDataType, field_name: Optional[str] = None): + def _get_field_by_name( + monitor_data: TCADMonitorDataType, field_name: Optional[str] = None + ) -> DataArray: """Return a field data based on a monitor dataset and a specified field name.""" if field_name is None: if len(monitor_data.field_components) > 1: @@ -258,14 +266,13 @@ class HeatChargeSimulationData(AbstractHeatChargeSimulationData): ... ) """ - data: tuple[annotate_type(TCADMonitorDataType), ...] = pd.Field( - ..., + data: tuple[discriminated_union(TCADMonitorDataType), ...] = Field( title="Monitor Data", description="List of :class:`.MonitorData` instances " "associated with the monitors of the original :class:`.Simulation`.", ) - device_characteristics: Optional[DeviceCharacteristics] = pd.Field( + device_characteristics: Optional[DeviceCharacteristics] = Field( None, title="Device characteristics", description="Data characterizing the device :class:`DeviceCharacteristics`.", @@ -483,19 +490,20 @@ class HeatSimulationData(HeatChargeSimulationData): Consider using :class:`HeatChargeSimulationData` instead. """ - simulation: HeatSimulation = pd.Field( + simulation: HeatSimulation = Field( title="Heat Simulation", description="Original :class:`HeatSimulation` associated with the data.", ) - @pd.root_validator(skip_on_failure=True) - def issue_warning_deprecated(cls, values): + @model_validator(mode="before") + @classmethod + def issue_warning_deprecated(cls, data: dict[str, Any]) -> dict[str, Any]: """Issue warning for 'HeatSimulations'.""" log.warning( - "'HeatSimulationData' is deprecated and will be discontinued. 
You can use " + "'HeatSimulationData' is deprecated and will be discontinued. Use " "'HeatChargeSimulationData' instead" ) - return values + return data class VolumeMesherData(AbstractHeatChargeSimulationData): @@ -554,14 +562,12 @@ class VolumeMesherData(AbstractHeatChargeSimulationData): >>> mesh_data = td.VolumeMesherData(simulation=heat_sim, data=[mesh_mnt_data], monitors=[mesh_mnt]) # doctest: +SKIP """ - monitors: tuple[VolumeMeshMonitor, ...] = pd.Field( - ..., + monitors: tuple[VolumeMeshMonitor, ...] = Field( title="Monitors", description="List of monitors to be used for the mesher.", ) - data: tuple[VolumeMeshData, ...] = pd.Field( - ..., + data: tuple[VolumeMeshData, ...] = Field( title="Monitor Data", description="List of :class:`.MonitorData` instances " "associated with the monitors of the original :class:`.VolumeMesher`.", @@ -575,23 +581,21 @@ def mesher(self) -> VolumeMesher: monitors=self.monitors, ) - @pd.root_validator(skip_on_failure=True) - def data_monitors_match_sim(cls, values): + @model_validator(mode="after") + def data_monitors_match_sim(self) -> Self: """Ensure each :class:`AbstractMonitorData` in ``.data`` corresponds to a monitor in ``.simulation``. """ - monitors = values.get("monitors") - data = values.get("data") - mnt_names = {mnt.name for mnt in monitors} + mnt_names = {mnt.name for mnt in self.monitors} - for mnt_data in data: + for mnt_data in self.data: monitor_name = mnt_data.monitor.name if monitor_name not in mnt_names: raise DataError( f"Data with monitor name '{monitor_name}' supplied " f"but not found in the list of monitors." ) - return values + return self def get_monitor_by_name(self, name: str) -> VolumeMeshMonitor: """Return monitor named 'name'.""" diff --git a/tidy3d/components/tcad/doping.py b/tidy3d/components/tcad/doping.py index 1d82639891..83d284bd8f 100644 --- a/tidy3d/components/tcad/doping.py +++ b/tidy3d/components/tcad/doping.py @@ -2,11 +2,11 @@ from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field, NonNegativeFloat, PositiveFloat, model_validator from tidy3d.components.autograd import TracedSize from tidy3d.components.base import cached_property @@ -15,19 +15,26 @@ from tidy3d.constants import MICROMETER, PERCMCUBE, inf from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + from tidy3d.compat import Self + class AbstractDopingBox(Box): """Derived class from Box to deal with dopings""" # Override size so that we can set default values - size: TracedSize = pd.Field( + size: TracedSize = Field( (inf, inf, inf), title="Size", description="Size in x, y, and z directions.", units=MICROMETER, ) - def _get_indices_in_box(self, coords: dict, meshgrid: bool = True): + def _get_indices_in_box( + self, coords: dict[str, ArrayLike], meshgrid: bool = True + ) -> tuple[NDArray[np.bool_], NDArray[np.floating], NDArray[np.floating], NDArray[np.floating]]: """Returns locations inside box""" # work out whether x,y, and z are present @@ -59,12 +66,14 @@ def _get_indices_in_box(self, coords: dict, meshgrid: bool = True): return indices_in_box, X, Y, Z - def _post_init_validators(self) -> None: + @model_validator(mode="after") + def _post_init_validators(self) -> Self: # check the doping box is 3D if len(self.zero_dims) > 0: raise SetupError( "The doping box must be 3D. 
If you want a 2D doping box, please set one of the dimensions to a large or infinite size." ) + return self class ConstantDoping(AbstractDopingBox): @@ -83,14 +92,14 @@ class ConstantDoping(AbstractDopingBox): >>> constant_box2 = td.ConstantDoping.from_bounds(rmin=box_coords[0], rmax=box_coords[1], concentration=1e18) """ - concentration: pd.NonNegativeFloat = pd.Field( + concentration: NonNegativeFloat = Field( default=0, title="Doping concentration density.", description="Doping concentration density.", units=PERCMCUBE, ) - def _get_contrib(self, coords: dict, meshgrid: bool = True): + def _get_contrib(self, coords: dict, meshgrid: bool = True) -> NDArray: """Returns the contribution to the doping a the locations specified in coords""" indices_in_box, X, _, _ = self._get_indices_in_box(coords=coords, meshgrid=meshgrid) @@ -150,20 +159,20 @@ class GaussianDoping(AbstractDopingBox): ... ) """ - ref_con: pd.PositiveFloat = pd.Field( + ref_con: PositiveFloat = Field( title="Reference concentration.", description="Reference concentration. This is the minimum concentration in the box " "and it is attained at the edges/faces of the box.", units=PERCMCUBE, ) - concentration: pd.PositiveFloat = pd.Field( + concentration: PositiveFloat = Field( title="Concentration", description="The concentration at the center of the box.", units=PERCMCUBE, ) - width: pd.PositiveFloat = pd.Field( + width: PositiveFloat = Field( title="Width of the gaussian.", description="Width of the gaussian. The concentration will transition from " "``concentration`` at the center of the box to ``ref_con`` at the edge/face " @@ -171,7 +180,7 @@ class GaussianDoping(AbstractDopingBox): units=MICROMETER, ) - source: str = pd.Field( + source: str = Field( "xmin", title="Source face", description="Specifies the side of the box acting as the source, i.e., " @@ -181,11 +190,11 @@ class GaussianDoping(AbstractDopingBox): ) @cached_property - def sigma(self): + def sigma(self) -> float: """The sigma parameter of the pseudo-gaussian""" return np.sqrt(-self.width * self.width / 2 / np.log(self.ref_con / self.concentration)) - def _get_contrib(self, coords: dict, meshgrid: bool = True): + def _get_contrib(self, coords: dict[str, ArrayLike], meshgrid: bool = True) -> NDArray: """Returns the contribution to the doping a the locations specified in coords""" indices_in_box, X, Y, Z = self._get_indices_in_box(coords=coords, meshgrid=meshgrid) @@ -304,14 +313,13 @@ class CustomDoping(AbstractDopingBox): ... 
) """ - concentration: SpatialDataArray = pd.Field( - ..., + concentration: SpatialDataArray = Field( title="Doping concentration data array.", description="Doping concentration data array.", units=PERCMCUBE, ) - def _get_contrib(self, coords: dict, meshgrid: bool = True): + def _get_contrib(self, coords: dict, meshgrid: bool = True) -> NDArray: """Returns the contribution to the doping a the locations specified in coords""" indices_in_box, X, Y, Z = self._get_indices_in_box(coords=coords, meshgrid=meshgrid) diff --git a/tidy3d/components/tcad/effective_DOS.py b/tidy3d/components/tcad/effective_DOS.py index a60396a7df..f5ed752bfe 100644 --- a/tidy3d/components/tcad/effective_DOS.py +++ b/tidy3d/components/tcad/effective_DOS.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import HBAR, K_B, M_E_EV, PERCMCUBE @@ -20,7 +20,7 @@ class EffectiveDOS(Tidy3dBaseModel, ABC): def calc_eff_dos(self, T: float) -> None: """Abstract method to calculate the effective density of states.""" - def get_effective_DOS(self, T: float): + def get_effective_DOS(self, T: float) -> float: if T <= 0: raise DataError( f"Incorrect temperature value ({T}) for the effective density of states calculation." @@ -32,11 +32,11 @@ def get_effective_DOS(self, T: float): class ConstantEffectiveDOS(EffectiveDOS): """Constant effective density of states model.""" - N: pd.PositiveFloat = pd.Field( - ..., title="Effective DOS", description="Effective density of states", units=PERCMCUBE + N: PositiveFloat = Field( + title="Effective DOS", description="Effective density of states", units=PERCMCUBE ) - def calc_eff_dos(self, T: float): + def calc_eff_dos(self, T: float) -> float: return self.N @@ -53,13 +53,13 @@ class IsotropicEffectiveDOS(EffectiveDOS): """ - m_eff: pd.PositiveFloat = pd.Field( + m_eff: PositiveFloat = Field( ..., title="Effective mass", description="Effective mass of the carriers relative to the electron mass at rest", ) - def calc_eff_dos(self, T: float): + def calc_eff_dos(self, T: float) -> float: return np.power(self.m_eff * T, 1.5) * DOS_aux_const @@ -76,23 +76,21 @@ class MultiValleyEffectiveDOS(EffectiveDOS): """ - m_eff_long: pd.PositiveFloat = pd.Field( - ..., + m_eff_long: PositiveFloat = Field( title="Longitudinal effective mass", description="Relative effective mass of the carriers in the longitudinal direction. This is a relative value compared to the electron mass at rest.", ) - m_eff_trans: pd.PositiveFloat = pd.Field( - ..., + m_eff_trans: PositiveFloat = Field( title="Transverse effective mass", description="Relative effective mass of the carriers in the transverse direction. This is a relative value compared to the electron mass at rest.", ) - N_valley: pd.PositiveFloat = pd.Field( - ..., title="Number of valleys", description="Number of effective valleys" + N_valley: PositiveFloat = Field( + title="Number of valleys", description="Number of effective valleys" ) - def calc_eff_dos(self, T: float): + def calc_eff_dos(self, T: float) -> float: return ( self.N_valley * np.power(self.m_eff_long * self.m_eff_trans * self.m_eff_trans, 0.5) @@ -114,17 +112,17 @@ class DualValleyEffectiveDOS(EffectiveDOS): """ - m_eff_lh: pd.PositiveFloat = pd.Field( + m_eff_lh: PositiveFloat = Field( ..., title="Light hole effective mass", description="Relative effective mass of the light holes. 
This is a relative value compared to the electron mass at rest.", ) - m_eff_hh: pd.PositiveFloat = pd.Field( + m_eff_hh: PositiveFloat = Field( ..., title="Heavy hole effective mass", description="Relative effective mass of the heavy holes. This is a relative value compared to the electron mass at rest.", ) - def calc_eff_dos(self, T: float): + def calc_eff_dos(self, T: float) -> float: return (np.power(self.m_eff_lh * T, 1.5) + np.power(self.m_eff_hh * T, 1.5)) * DOS_aux_const diff --git a/tidy3d/components/tcad/generation_recombination.py b/tidy3d/components/tcad/generation_recombination.py index 03c6dad786..5e2016f84d 100644 --- a/tidy3d/components/tcad/generation_recombination.py +++ b/tidy3d/components/tcad/generation_recombination.py @@ -1,14 +1,17 @@ from __future__ import annotations -from typing import Literal, Union +from typing import TYPE_CHECKING, Literal, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, model_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.data_array import SpatialDataArray from tidy3d.constants import PERCMCUBE, SECOND +if TYPE_CHECKING: + from tidy3d.compat import Self + class FossumCarrierLifetime(Tidy3dBaseModel): """ @@ -43,25 +46,42 @@ class FossumCarrierLifetime(Tidy3dBaseModel): """ - tau_300: pd.PositiveFloat = pd.Field( - ..., title="Tau at 300K", description="Carrier lifetime at 300K", units=SECOND + tau_300: PositiveFloat = Field( + title="Tau at 300K", + description="Carrier lifetime at 300K", + units=SECOND, ) - alpha_T: float = pd.Field( - ..., title="Exponent for thermal dependence", description="Exponent for thermal dependence" + alpha_T: float = Field( + title="Exponent for thermal dependence", + description="Exponent for thermal dependence", ) - N0: pd.PositiveFloat = pd.Field( - ..., title="Reference concentration", description="Reference concentration", units=PERCMCUBE + N0: PositiveFloat = Field( + title="Reference concentration", + description="Reference concentration", + units=PERCMCUBE, ) - A: float = pd.Field(..., title="Constant A", description="Constant A") + A: float = Field( + title="Constant A", + description="Constant A", + ) - B: float = pd.Field(..., title="Constant B", description="Constant B") + B: float = Field( + title="Constant B", + description="Constant B", + ) - C: float = pd.Field(..., title="Constant C", description="Constant C") + C: float = Field( + title="Constant C", + description="Constant C", + ) - alpha: float = pd.Field(..., title="Exponent constant", description="Exponent constant") + alpha: float = Field( + title="Exponent constant", + description="Exponent constant", + ) CarrierLifetimeType = Union[FossumCarrierLifetime] @@ -90,15 +110,13 @@ class AugerRecombination(Tidy3dBaseModel): ... ) """ - c_n: pd.PositiveFloat = pd.Field( - ..., + c_n: PositiveFloat = Field( title="Constant for electrons", description="Constant for electrons.", units="cm^6/s", ) - c_p: pd.PositiveFloat = pd.Field( - ..., + c_p: PositiveFloat = Field( title="Constant for holes", description="Constant for holes.", units="cm^6/s", @@ -126,8 +144,7 @@ class RadiativeRecombination(Tidy3dBaseModel): ... ) """ - r_const: float = pd.Field( - ..., + r_const: float = Field( title="Radiation constant", description="Radiation constant of the radiative recombination model.", units="cm^3/s", @@ -168,12 +185,16 @@ class ShockleyReedHallRecombination(Tidy3dBaseModel): - This model represents mid-gap traps Shockley-Reed-Hall recombination. 
""" - tau_n: Union[pd.PositiveFloat, CarrierLifetimeType] = pd.Field( - ..., title="Electron lifetime", description="Electron lifetime", units=SECOND + tau_n: Union[PositiveFloat, CarrierLifetimeType] = Field( + title="Electron lifetime", + description="Electron lifetime", + units=SECOND, ) - tau_p: Union[pd.PositiveFloat, CarrierLifetimeType] = pd.Field( - ..., title="Hole lifetime", description="Hole lifetime", units=SECOND + tau_p: Union[PositiveFloat, CarrierLifetimeType] = Field( + title="Hole lifetime", + description="Hole lifetime", + units=SECOND, ) @@ -197,8 +218,7 @@ class DistributedGeneration(Tidy3dBaseModel): >>> dist_g = td.DistributedGeneration(rate=fd) """ - rate: SpatialDataArray = pd.Field( - ..., + rate: SpatialDataArray = Field( title="Generation rate", description="Spatially varying generation rate.", units="1/(cm^3 s^1)", @@ -211,18 +231,18 @@ def from_rate_um3(cls, gen_um3: SpatialDataArray) -> DistributedGeneration: new_gen = SpatialDataArray(gen_cm3, coords=gen_um3.coords) return cls(rate=new_gen) - @pd.root_validator(skip_on_failure=True) - def check_spatialdataarray_dimensions(cls, values): + @model_validator(mode="after") + def check_spatialdataarray_dimensions(self) -> Self: """Check that the SpatialDataArray is at least 2D:""" - rate = values.get("rate") + rate = self.rate zero_dims = [d for d in ["x", "y", "z"] if len(rate.coords[d]) <= 1] if len(zero_dims) > 1: raise ValueError("SpatialDataArray must be at least 2D.") - return values + return self class HurkxDirectBandToBandTunneling(Tidy3dBaseModel): @@ -259,25 +279,25 @@ class HurkxDirectBandToBandTunneling(Tidy3dBaseModel): .. [1] Palankovski, Vassil, and Rüdiger Quay. Analysis and simulation of heterostructure devices. Springer Science & Business Media, 2004. """ - A: pd.PositiveFloat = pd.Field( + A: PositiveFloat = Field( 4e14, title="Parameter :math:`A`", description="Parameter :math:`A` in the direct BTBT Hurkx model.", units="1/(cm^3 s)", ) - B: float = pd.Field( + B: float = Field( 1.9e6, title="Parameter :math:`B`", description="Parameter :math:`B` in the direct BTBT Hurkx model.", units="V/cm", ) - E_0: pd.PositiveFloat = pd.Field( + E_0: PositiveFloat = Field( 1, title="Reference electric field :math:`E_0`", description="Reference electric field :math:`E_0` in the direct BTBT Hurkx model.", units="V/cm", ) - sigma: float = pd.Field( + sigma: float = Field( 2.5, title="Exponent parameter", description="Exponent :math:`\\sigma` in the direct BTBT Hurkx model. For direct " @@ -321,42 +341,37 @@ class SelberherrImpactIonization(Tidy3dBaseModel): .. [2] Vassil Palankovski and Rüdiger Quay. Analysis and simulation of heterostructure devices. Springer Science & Business Media, 2004. 
""" - alpha_n_inf: pd.PositiveFloat = pd.Field( - ..., + alpha_n_inf: PositiveFloat = Field( title="Electron ionization coefficient at infinite field", description="Electron ionization coefficient at infinite field.", units="1/cm", ) - alpha_p_inf: pd.PositiveFloat = pd.Field( - ..., + alpha_p_inf: PositiveFloat = Field( title="Hole ionization coefficient at infinite field", description="Hole ionization coefficient at infinite field.", units="1/cm", ) - E_n_crit: pd.PositiveFloat = pd.Field( - ..., + E_n_crit: PositiveFloat = Field( title="Critical electric field for electrons", description="Critical electric field for electrons.", units="V/cm", ) - E_p_crit: pd.PositiveFloat = pd.Field( + E_p_crit: PositiveFloat = Field( ..., title="Critical electric field for holes", description="Critical electric field for holes.", units="V/cm", ) - beta_n: pd.PositiveFloat = pd.Field( - ..., + beta_n: PositiveFloat = Field( title="Exponent for electrons", description="Exponent for electrons.", ) - beta_p: pd.PositiveFloat = pd.Field( - ..., + beta_p: PositiveFloat = Field( title="Exponent for holes", description="Exponent for holes.", ) - formulation: Literal["Selberherr", "PQ"] = pd.Field( + formulation: Literal["Selberherr", "PQ"] = Field( "PQ", title="Formulation", description="Formulation used for impact ionization. Options are 'Selberherr' " diff --git a/tidy3d/components/tcad/grid.py b/tidy3d/components/tcad/grid.py index fcf0cf6f27..167708d91e 100644 --- a/tidy3d/components/tcad/grid.py +++ b/tidy3d/components/tcad/grid.py @@ -3,22 +3,26 @@ from __future__ import annotations from abc import ABC -from typing import Union +from typing import TYPE_CHECKING, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat, field_validator, model_validator -from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.geometry.base import Box -from tidy3d.components.types import Coordinate, annotate_type +from tidy3d.components.types import Coordinate +from tidy3d.components.types.base import discriminated_union from tidy3d.constants import MICROMETER from tidy3d.exceptions import ValidationError +if TYPE_CHECKING: + from tidy3d.compat import Self + class UnstructuredGrid(Tidy3dBaseModel, ABC): """Abstract unstructured grid.""" - relative_min_dl: pd.NonNegativeFloat = pd.Field( + relative_min_dl: NonNegativeFloat = Field( 1e-3, title="Relative Mesh Size Limit", description="The minimal allowed mesh size relative to the largest dimension of the simulation domain." @@ -34,14 +38,13 @@ class UniformUnstructuredGrid(UnstructuredGrid): >>> heat_grid = UniformUnstructuredGrid(dl=0.1) """ - dl: pd.PositiveFloat = pd.Field( - ..., + dl: PositiveFloat = Field( title="Grid Size", description="Grid size for uniform grid generation.", units=MICROMETER, ) - min_edges_per_circumference: pd.PositiveFloat = pd.Field( + min_edges_per_circumference: PositiveFloat = Field( 15, title="Minimum Edges per Circumference", description="Enforced minimum number of mesh segments per circumference of an object. " @@ -49,13 +52,13 @@ class UniformUnstructuredGrid(UnstructuredGrid): "is taken as 2 * pi * radius.", ) - min_edges_per_side: pd.PositiveFloat = pd.Field( + min_edges_per_side: PositiveFloat = Field( 2, title="Minimum Edges per Side", description="Enforced minimum number of mesh segments per any side of an object.", ) - non_refined_structures: tuple[str, ...] 
= pd.Field( + non_refined_structures: tuple[str, ...] = Field( (), title="Structures Without Refinement", description="List of structures for which ``min_edges_per_circumference`` and " @@ -67,15 +70,13 @@ class GridRefinementRegion(Box): """Refinement region for the unstructured mesh. The cell size is enforced to be constant inside the region. The cell size outside of the region depends on the distance from the region.""" - dl_internal: pd.PositiveFloat = pd.Field( - ..., + dl_internal: PositiveFloat = Field( title="Internal mesh cell size", description="Mesh cell size inside the refinement region", units=MICROMETER, ) - transition_thickness: pd.NonNegativeFloat = pd.Field( - ..., + transition_thickness: NonNegativeFloat = Field( title="Interface Distance", description="Thickness of a transition layer outside the box where the mesh cell size changes from the " "internal size to the external one.", units=MICROMETER, ) @@ -86,51 +87,40 @@ class GridRefinementLine(Tidy3dBaseModel, ABC): """Refinement line for the unstructured mesh. The cell size depends on the distance from the line.""" - r1: Coordinate = pd.Field( - ..., + r1: Coordinate = Field( title="Start point of the line", description="Start point of the line in x, y, and z.", units=MICROMETER, ) - r2: Coordinate = pd.Field( - ..., + r2: Coordinate = Field( title="End point of the line", description="End point of the line in x, y, and z.", units=MICROMETER, ) - @pd.validator("r1", always=True) - def _r1_not_inf(cls, val): - """Make sure the point is not infinitiy.""" - if any(np.isinf(v) for v in val): - raise ValidationError("Point can not contain td.inf terms.") - return val - - @pd.validator("r2", always=True) - def _r2_not_inf(cls, val): + @field_validator("r1", "r2") + @classmethod + def _not_inf(cls, val: Coordinate) -> Coordinate: """Make sure the point is not infinity.""" if any(np.isinf(v) for v in val): - raise ValidationError("Point can not contain td.inf terms.") + raise ValidationError("Point cannot contain 'td.inf' terms.") return val - dl_near: pd.PositiveFloat = pd.Field( - ..., + dl_near: PositiveFloat = Field( title="Mesh cell size near the line", description="Mesh cell size near the line", units=MICROMETER, ) - distance_near: pd.NonNegativeFloat = pd.Field( - ..., + distance_near: NonNegativeFloat = Field( title="Near distance", description="Distance from the line within which ``dl_near`` is enforced. " "Typically the same as ``dl_near`` or its multiple.", units=MICROMETER, ) - distance_bulk: pd.NonNegativeFloat = pd.Field( - ..., + distance_bulk: NonNegativeFloat = Field( title="Bulk distance", description="Distance from the line outside of which ``dl_bulk`` is enforced. " "Typically twice ``dl_bulk`` or its multiple. Use larger values for a smoother " @@ -138,15 +128,13 @@ def _r2_not_inf(cls, val): units=MICROMETER, ) - @pd.validator("distance_bulk", always=True) - @skip_if_fields_missing(["distance_near"]) - def names_exist_bcs(cls, val, values): + @model_validator(mode="after") + def names_exist_bcs(self) -> Self: """Error if distance_bulk is less than distance_near""" - distance_near = values.get("distance_near") - if distance_near > val: + if self.distance_near > self.distance_bulk: raise ValidationError("'distance_bulk' cannot be smaller than 'distance_near'.") - return val + return self class DistanceUnstructuredGrid(UnstructuredGrid): @@ -163,30 +151,26 @@ class DistanceUnstructuredGrid(UnstructuredGrid): ... ) """ - dl_interface: pd.PositiveFloat = pd.Field( - ..., + dl_interface: PositiveFloat = Field( title="Interface Grid Size", description="Grid size near material interfaces.", units=MICROMETER, ) - dl_bulk: pd.PositiveFloat = pd.Field( - ..., + dl_bulk: PositiveFloat = Field( title="Bulk Grid Size", description="Grid size away from material interfaces.", units=MICROMETER, ) - distance_interface: pd.NonNegativeFloat = pd.Field( - ..., + distance_interface: NonNegativeFloat = Field( title="Interface Distance", description="Distance from interface within which ``dl_interface`` is enforced. " "Typically the same as ``dl_interface`` or its multiple.", units=MICROMETER, ) - distance_bulk: pd.NonNegativeFloat = pd.Field( - ..., + distance_bulk: NonNegativeFloat = Field( title="Bulk Distance", description="Distance from interface outside of which ``dl_bulk`` is enforced. " "Typically twice ``dl_bulk`` or its multiple. Use larger values for a smoother " @@ -194,44 +178,42 @@ class DistanceUnstructuredGrid(UnstructuredGrid): units=MICROMETER, ) - sampling: pd.PositiveFloat = pd.Field( + sampling: PositiveFloat = Field( 100, title="Surface Sampling", description="An internal advanced parameter that defines the number of sampling points per " "surface when computing distance values.", ) - uniform_grid_mediums: tuple[str, ...] = pd.Field( + uniform_grid_mediums: tuple[str, ...] = Field( (), title="Mediums With Uniform Refinement", description="List of mediums for which ``dl_interface`` will be enforced everywhere " "in the volume.", ) - non_refined_structures: tuple[str, ...] = pd.Field( + non_refined_structures: tuple[str, ...] = Field( (), title="Structures Without Refinement", description="List of structures for which ``dl_interface`` will not be enforced. " "``dl_bulk`` is used instead.", ) - mesh_refinements: tuple[annotate_type(Union[GridRefinementRegion, GridRefinementLine]), ...] = ( - pd.Field( - (), - title="Mesh refinement structures", - description="List of regions/lines for which the mesh refinement will be applied", - ) + mesh_refinements: tuple[ + discriminated_union(Union[GridRefinementRegion, GridRefinementLine]), ...
+ ] = Field( + (), + title="Mesh refinement structures", + description="List of regions/lines for which the mesh refinement will be applied", ) - @pd.validator("distance_bulk", always=True) - @skip_if_fields_missing(["distance_interface"]) - def names_exist_bcs(cls, val, values): + @model_validator(mode="after") + def names_exist_bcs(self) -> Self: """Error if distance_bulk is less than distance_interface""" - distance_interface = values.get("distance_interface") - if distance_interface > val: + if self.distance_interface > self.distance_bulk: raise ValidationError("'distance_bulk' cannot be smaller than 'distance_interface'.") - return val + return self UnstructuredGridType = Union[UniformUnstructuredGrid, DistanceUnstructuredGrid] diff --git a/tidy3d/components/tcad/mesher.py b/tidy3d/components/tcad/mesher.py index 82af0de6f3..57e58acbcd 100644 --- a/tidy3d/components/tcad/mesher.py +++ b/tidy3d/components/tcad/mesher.py @@ -1,22 +1,26 @@ from __future__ import annotations -import pydantic.v1 as pd +from typing import TYPE_CHECKING + +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.tcad.monitors.mesh import VolumeMeshMonitor from tidy3d.components.tcad.simulation.heat_charge import HeatChargeSimulation, TCADAnalysisTypes +if TYPE_CHECKING: + from tidy3d.compat import Self + class VolumeMesher(Tidy3dBaseModel): """Specification for a standalone volume mesher.""" - simulation: HeatChargeSimulation = pd.Field( - ..., + simulation: HeatChargeSimulation = Field( title="Simulation", description="HeatCharge simulation instance for the mesh specification.", ) - monitors: tuple[VolumeMeshMonitor, ...] = pd.Field( + monitors: tuple[VolumeMeshMonitor, ...] = Field( (), title="Monitors", description="List of monitors to be used for the mesher.", @@ -25,7 +29,7 @@ class VolumeMesher(Tidy3dBaseModel): def _get_simulation_types(self) -> list[TCADAnalysisTypes]: return [TCADAnalysisTypes.MESH] - def validate_pre_upload(self): + def validate_pre_upload(self: Self) -> None: """Validate the VolumeMesher before uploading to the cloud. Currently no validation but method is required when calling ``web.upload``. 
""" diff --git a/tidy3d/components/tcad/mobility.py b/tidy3d/components/tcad/mobility.py index 428e0ffacd..b65e6db16e 100644 --- a/tidy3d/components/tcad/mobility.py +++ b/tidy3d/components/tcad/mobility.py @@ -1,6 +1,6 @@ from __future__ import annotations -import pydantic.v1 as pd +from pydantic import Field, NonNegativeFloat, PositiveFloat from tidy3d.components.base import Tidy3dBaseModel from tidy3d.constants import PERCMCUBE @@ -15,8 +15,10 @@ class ConstantMobilityModel(Tidy3dBaseModel): >>> mobility_model = td.ConstantMobilityModel(mu=1500) """ - mu: pd.NonNegativeFloat = pd.Field( - ..., title="Mobility", description="Mobility", units="cm²/V-s" + mu: NonNegativeFloat = Field( + title="Mobility", + description="Mobility", + units="cm²/V-s", ) @@ -115,55 +117,48 @@ class CaugheyThomasMobility(Tidy3dBaseModel): """ # mobilities - mu_min: pd.PositiveFloat = pd.Field( - ..., + mu_min: PositiveFloat = Field( title="Minimum electron mobility", description="Minimum electron mobility :math:`\\mu_{\\text{min}}` at reference temperature (300K).", units="cm^2/V-s", ) - mu: pd.PositiveFloat = pd.Field( - ..., + mu: PositiveFloat = Field( title="Reference mobility", description="Reference mobility at reference temperature (300K).", units="cm^2/V-s", ) # thermal exponent for reference mobility - exp_2: float = pd.Field( - ..., title="Exponent for temperature dependent behavior of reference mobility" + exp_2: float = Field( + title="Exponent for temperature dependent behavior of reference mobility", ) # doping exponent - exp_N: pd.PositiveFloat = pd.Field( - ..., + exp_N: PositiveFloat = Field( title="Exponent for doping dependence of mobility.", description="Exponent for doping dependence of mobility at reference temperature (300K).", ) # reference doping - ref_N: pd.PositiveFloat = pd.Field( - ..., + ref_N: PositiveFloat = Field( title="Reference doping", description="Reference doping at reference temperature (300K).", units=PERCMCUBE, ) # temperature exponent - exp_1: float = pd.Field( - ..., + exp_1: float = Field( title="Exponent of thermal dependence of minimum mobility.", description="Exponent of thermal dependence of minimum mobility.", ) - exp_3: float = pd.Field( - ..., + exp_3: float = Field( title="Exponent of thermal dependence of reference doping.", description="Exponent of thermal dependence of reference doping.", ) - exp_4: float = pd.Field( - ..., + exp_4: float = Field( title="Exponent of thermal dependence of the doping exponent effect.", description="Exponent of thermal dependence of the doping exponent effect.", ) diff --git a/tidy3d/components/tcad/monitors/abstract.py b/tidy3d/components/tcad/monitors/abstract.py index 8c3fa91bbe..434685a1f8 100644 --- a/tidy3d/components/tcad/monitors/abstract.py +++ b/tidy3d/components/tcad/monitors/abstract.py @@ -3,11 +3,14 @@ from __future__ import annotations from abc import ABC +from typing import TYPE_CHECKING -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base_sim.monitor import AbstractMonitor -from tidy3d.components.types import ArrayFloat1D + +if TYPE_CHECKING: + from tidy3d.components.types import ArrayFloat1D BYTES_REAL = 4 @@ -15,13 +18,13 @@ class HeatChargeMonitor(AbstractMonitor, ABC): """Abstract base class for heat-charge monitors.""" - unstructured: bool = pd.Field( + unstructured: bool = Field( False, title="Unstructured Grid", description="Return data on the original unstructured grid.", ) - conformal: bool = pd.Field( + conformal: bool = Field( False, title="Conformal Monitor Meshing", 
description="If ``True`` the simulation mesh will conform to the monitor's geometry. " diff --git a/tidy3d/components/tcad/monitors/charge.py b/tidy3d/components/tcad/monitors/charge.py index 63838cc341..6c822bf984 100644 --- a/tidy3d/components/tcad/monitors/charge.py +++ b/tidy3d/components/tcad/monitors/charge.py @@ -4,7 +4,7 @@ from typing import Literal -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.tcad.monitors.abstract import HeatChargeMonitor @@ -35,7 +35,7 @@ class SteadyFreeCarrierMonitor(HeatChargeMonitor): """ # NOTE: for the time being supporting unstructured - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", @@ -55,7 +55,7 @@ class SteadyEnergyBandMonitor(HeatChargeMonitor): """ # NOTE: for the time being supporting unstructured - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", @@ -75,7 +75,7 @@ class SteadyCapacitanceMonitor(HeatChargeMonitor): """ # NOTE: for the time being supporting unstructured - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", @@ -94,7 +94,7 @@ class SteadyElectricFieldMonitor(HeatChargeMonitor): ... ) """ - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", @@ -113,7 +113,7 @@ class SteadyCurrentDensityMonitor(HeatChargeMonitor): ... ) """ - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return data on the original unstructured grid.", diff --git a/tidy3d/components/tcad/monitors/heat.py b/tidy3d/components/tcad/monitors/heat.py index 3d8fff722b..42cb113d39 100644 --- a/tidy3d/components/tcad/monitors/heat.py +++ b/tidy3d/components/tcad/monitors/heat.py @@ -2,7 +2,7 @@ from __future__ import annotations -from pydantic.v1 import Field, PositiveInt +from pydantic import Field, PositiveInt from tidy3d.components.tcad.monitors.abstract import HeatChargeMonitor diff --git a/tidy3d/components/tcad/monitors/mesh.py b/tidy3d/components/tcad/monitors/mesh.py index 87106068d1..0b737cbe5a 100644 --- a/tidy3d/components/tcad/monitors/mesh.py +++ b/tidy3d/components/tcad/monitors/mesh.py @@ -3,26 +3,30 @@ from __future__ import annotations from math import isclose -from typing import Literal +from typing import TYPE_CHECKING, Literal -import pydantic.v1 as pd +from pydantic import Field, field_validator from tidy3d.components.tcad.monitors.abstract import HeatChargeMonitor +if TYPE_CHECKING: + from tidy3d.components.autograd import TracedSize + class VolumeMeshMonitor(HeatChargeMonitor): """Monitor recording the volume mesh. The monitor size must be either 2D or 3D. 
If a 2D monitor is used in a 3D simulation, the sliced volumetric mesh on the plane of the monitor will be stored as a ``TriangularGridDataset``.""" - unstructured: Literal[True] = pd.Field( + unstructured: Literal[True] = Field( True, title="Unstructured Grid", description="Return the original unstructured grid.", ) - @pd.validator("size", always=True) - def _at_least_2d(cls, val): + @field_validator("size") + @classmethod + def _at_least_2d(cls, val: TracedSize) -> TracedSize: """Validate that the monitor has at least two non-zero dimensions.""" if len([d for d in val if isclose(d, 0)]) > 1: raise ValueError("'VolumeMeshMonitor' must have at least two nonzero dimensions.") diff --git a/tidy3d/components/tcad/simulation/heat.py b/tidy3d/components/tcad/simulation/heat.py index ace1d8a855..c10a3ea830 100644 --- a/tidy3d/components/tcad/simulation/heat.py +++ b/tidy3d/components/tcad/simulation/heat.py @@ -3,15 +3,19 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Any -import pydantic.v1 as pd +from pydantic import model_validator from tidy3d.components.tcad.simulation.heat_charge import HeatChargeSimulation -from tidy3d.components.types import Ax from tidy3d.components.viz import add_ax_if_none, equal_aspect from tidy3d.log import log +if TYPE_CHECKING: + from typing import Optional + + from tidy3d.components.types import Ax + class HeatSimulation(HeatChargeSimulation): """ @@ -47,14 +51,15 @@ class HeatSimulation(HeatChargeSimulation): ... ) """ - @pd.root_validator(skip_on_failure=True) - def issue_warning_deprecated(cls, values): + @model_validator(mode="before") + @classmethod + def issue_warning_deprecated(cls, data: dict[str, Any]) -> dict[str, Any]: """Issue warning for 'HeatSimulations'.""" log.warning( "Setting up deprecated 'HeatSimulation'. " "Consider defining 'HeatChargeSimulation' instead." 
) - return values + return data @equal_aspect @add_ax_if_none diff --git a/tidy3d/components/tcad/simulation/heat_charge.py b/tidy3d/components/tcad/simulation/heat_charge.py index 59b08ab1d7..b76cc6cbd1 100644 --- a/tidy3d/components/tcad/simulation/heat_charge.py +++ b/tidy3d/components/tcad/simulation/heat_charge.py @@ -1,20 +1,13 @@ -# ruff: noqa: W293, W291 """Defines heat simulation class""" from __future__ import annotations from enum import Enum -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, field_validator, model_validator -try: - from matplotlib import colormaps -except ImportError: - pass - -from tidy3d.components.base import skip_if_fields_missing from tidy3d.components.base_sim.simulation import AbstractSimulation from tidy3d.components.bc_placement import ( MediumMediumInterface, @@ -47,10 +40,7 @@ from tidy3d.components.structure import Structure from tidy3d.components.tcad.analysis.heat_simulation_type import UnsteadyHeatAnalysis from tidy3d.components.tcad.boundary.heat import VerticalNaturalConvectionCoeffModel -from tidy3d.components.tcad.boundary.specification import ( - HeatBoundarySpec, - HeatChargeBoundarySpec, -) +from tidy3d.components.tcad.boundary.specification import HeatBoundarySpec, HeatChargeBoundarySpec from tidy3d.components.tcad.grid import ( DistanceUnstructuredGrid, UniformUnstructuredGrid, @@ -62,12 +52,8 @@ SteadyFreeCarrierMonitor, SteadyPotentialMonitor, ) -from tidy3d.components.tcad.monitors.heat import ( - TemperatureMonitor, -) -from tidy3d.components.tcad.source.abstract import ( - GlobalHeatChargeSource, -) +from tidy3d.components.tcad.monitors.heat import TemperatureMonitor +from tidy3d.components.tcad.source.abstract import GlobalHeatChargeSource from tidy3d.components.tcad.types import ( ConvectionBC, CurrentBC, @@ -90,19 +76,29 @@ plot_params_heat_bc, plot_params_heat_source, ) -from tidy3d.components.types import ( - TYPE_TAG_STR, - Ax, - Bound, - ScalarSymmetry, - Shapely, - annotate_type, -) -from tidy3d.components.viz import PlotParams, add_ax_if_none, equal_aspect +from tidy3d.components.types import TYPE_TAG_STR, ScalarSymmetry +from tidy3d.components.types.base import discriminated_union +from tidy3d.components.viz import add_ax_if_none, equal_aspect from tidy3d.constants import VOLUMETRIC_HEAT_RATE, inf from tidy3d.exceptions import SetupError from tidy3d.log import log +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Literal + + from pydantic import FiniteFloat + + from tidy3d.compat import Self + from tidy3d.components.types import Ax, Bound, Shapely + from tidy3d.components.types.base import ArrayFloat1D + from tidy3d.components.viz import PlotParams + +try: + from matplotlib import colormaps +except ImportError: + pass + HEAT_CHARGE_BACK_STRUCTURE_STR = "<<>>" HeatBCTypes = (TemperatureBC, HeatFluxBC, ConvectionBC) @@ -292,8 +288,8 @@ class HeatChargeSimulation(AbstractSimulation): top of the coupling heat source. """ - medium: StructureMediumType = pd.Field( - Medium(), + medium: StructureMediumType = Field( + default_factory=Medium, title="Background Medium", description="Background medium of simulation, defaults to a standard dispersion-less :class:`.Medium` if not " "specified.", @@ -303,34 +299,34 @@ class HeatChargeSimulation(AbstractSimulation): Background medium of simulation, defaults to a standard dispersion-less :class:`.Medium` if not specified. 
""" - sources: tuple[annotate_type(HeatChargeSourceType), ...] = pd.Field( + sources: tuple[discriminated_union(HeatChargeSourceType), ...] = Field( (), title="Heat and Charge sources", description="List of heat and/or charge sources.", ) - monitors: tuple[annotate_type(HeatChargeMonitorType), ...] = pd.Field( + monitors: tuple[discriminated_union(HeatChargeMonitorType), ...] = Field( (), title="Monitors", description="Monitors in the simulation.", ) - boundary_spec: tuple[annotate_type(Union[HeatChargeBoundarySpec, HeatBoundarySpec]), ...] = ( - pd.Field( - (), - title="Boundary Condition Specifications", - description="List of boundary condition specifications.", - ) + boundary_spec: tuple[ + discriminated_union(Union[HeatChargeBoundarySpec, HeatBoundarySpec]), ... + ] = Field( + (), + title="Boundary Condition Specifications", + description="List of boundary condition specifications.", ) # NOTE: creating a union with HeatBoundarySpec for backwards compatibility - grid_spec: UnstructuredGridType = pd.Field( + grid_spec: UnstructuredGridType = Field( title="Grid Specification", description="Grid specification for heat-charge simulation.", discriminator=TYPE_TAG_STR, ) - symmetry: tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = pd.Field( + symmetry: tuple[ScalarSymmetry, ScalarSymmetry, ScalarSymmetry] = Field( (0, 0, 0), title="Symmetries", description="Tuple of integers defining reflection symmetry across a plane " @@ -339,21 +335,16 @@ class HeatChargeSimulation(AbstractSimulation): "Each element can be ``0`` (symmetry off) or ``1`` (symmetry on).", ) - analysis_spec: AnalysisSpecType = pd.Field( + analysis_spec: Optional[AnalysisSpecType] = Field( None, title="Analysis specification.", description="The `analysis_spec` is used to specify the type of simulation. Currently, it is used to " "specify Charge simulations or transient Heat simulations.", ) - def _post_init_validators(self) -> None: - """Call validators taking ``self`` that get run after init.""" - - # Charge mesh size validator - self._estimate_charge_mesh_size() - - @pd.validator("structures", always=True) - def check_unsupported_geometries(cls, val): + @field_validator("structures") + @classmethod + def check_unsupported_geometries(cls, val: tuple[Structure, ...]) -> tuple[Structure, ...]: """Error if structures contain unsupported yet geometries.""" for ind, structure in enumerate(val): bbox = structure.geometry.bounding_box @@ -363,8 +354,7 @@ def check_unsupported_geometries(cls, val): ) return val - @staticmethod - def _check_cross_solids(objs: tuple[Box, ...], values: dict) -> tuple[int, ...]: + def _check_cross_solids(self, objs: tuple[Box, ...]) -> tuple[int, ...]: """Given model dictionary ``values``, check whether objects in list ``objs`` cross a ``SolidSpec`` medium. """ @@ -373,29 +363,16 @@ def _check_cross_solids(objs: tuple[Box, ...], values: dict) -> tuple[int, ...]: # will be accepted valid_electric_medium = (SemiconductorMedium, ChargeConductorMedium) - try: - size = values["size"] - center = values["center"] - medium = values["medium"] - structures = values["structures"] - except KeyError: - raise SetupError( - "Function '_check_cross_solids' assumes dictionary 'values' contains well-defined " - "'size', 'center', 'medium', and 'structures'. Thus, it should only be used in " - "validators with @skip_if_fields_missing(['medium', 'center', 'size', 'structures']) " - "or root validators with option 'skip_on_failure=True'." 
- ) from None - # list of structures including background as a Box() structure_bg = Structure( geometry=Box( - size=size, - center=center, + size=self.size, + center=self.center, ), - medium=medium, + medium=self.medium, ) - total_structures = [structure_bg, *list(structures)] + total_structures = [structure_bg, *list(self.structures)] obj_do_not_cross_solid_idx = [] obj_do_not_cross_cond_idx = [] @@ -430,15 +407,12 @@ def _check_cross_solids(objs: tuple[Box, ...], values: dict) -> tuple[int, ...]: return obj_do_not_cross_solid_idx, obj_do_not_cross_cond_idx - @pd.validator("monitors", always=True) - @skip_if_fields_missing(["medium", "center", "size", "structures"]) - def _monitors_cross_solids(cls, val, values): + @model_validator(mode="after") + def _monitors_cross_solids(self) -> Self: """Error if monitors do not cross any solid medium.""" + val = self.monitors - # if val is None: - # return val - - failed_solid_idx, failed_elect_idx = cls._check_cross_solids(val, values) + failed_solid_idx, failed_elect_idx = self._check_cross_solids(val) temp_monitors = [idx for idx, mnt in enumerate(val) if isinstance(mnt, TemperatureMonitor)] volt_monitors = [ @@ -464,19 +438,19 @@ def _monitors_cross_solids(cls, val, values): "materials. Thus, no information will be recorded in these monitors." ) - return val + return self - @pd.root_validator(skip_on_failure=True) - def check_voltage_array_if_capacitance(cls, values): + @model_validator(mode="after") + def check_voltage_array_if_capacitance(self) -> Self: """Make sure an array of voltages has been defined if a 'SteadyCapacitanceMonitor' has been defined.""" - bounday_spec = values["boundary_spec"] - monitors = values["monitors"] + boundary_spec = self.boundary_spec + monitors = self.monitors is_capacitance_mnt = any(isinstance(mnt, SteadyCapacitanceMonitor) for mnt in monitors) voltage_array_present = False if is_capacitance_mnt: - for bc in bounday_spec: + for bc in boundary_spec: if isinstance(bc.condition, VoltageBC): if isinstance(bc.condition.source, DCVoltageSource): if len(bc.condition.source.voltage) > 1: @@ -491,11 +465,13 @@ def check_voltage_array_if_capacitance(cls, values): "Voltage arrays can be included in a source in this manner: " "'VoltageBC(source=DCVoltageSource(voltage=yourArray))'" ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def check_single_ssac(cls, values): - boundary_spec = values["boundary_spec"] + @field_validator("boundary_spec") + @classmethod + def check_single_ssac( + cls, boundary_spec: tuple[Union[HeatChargeBoundarySpec, HeatBoundarySpec], ...] + ) -> tuple[Union[HeatChargeBoundarySpec, HeatBoundarySpec], ...]: + """Error if more than one SSAC voltage source is provided.""" ssac_present = False for bc in boundary_spec: if isinstance(bc.condition, VoltageBC): @@ -506,18 +482,18 @@ def check_single_ssac(cls, values): ) else: ssac_present = True - return values + return boundary_spec - @pd.root_validator(skip_on_failure=True) - def check_natural_convection_bc(cls, values): + @model_validator(mode="after") + def check_natural_convection_bc(self) -> Self: """Make sure that natural convection BCs are defined correctly.""" - boundary_spec = values.get("boundary_spec") + boundary_spec = self.boundary_spec if not boundary_spec: - return values + return self - structures = values["structures"] - boundary_spec = values["boundary_spec"] - bg_medium = values["medium"] + structures = self.structures + bg_medium = self.medium # Create mappings for easy lookup of media and structures by name.
media = {s.medium.name: s.medium for s in structures if s.medium.name} @@ -525,7 +501,7 @@ def check_natural_convection_bc(cls, values): media[bg_medium.name] = bg_medium structures_map = {s.name: s for s in structures if s.name} - def check_fluid_medium_attr(fluid_medium) -> None: + def check_fluid_medium_attr(fluid_medium: FluidMedium) -> None: if ( (fluid_medium.thermal_conductivity is None) or (fluid_medium.viscosity is None) @@ -582,10 +558,11 @@ def check_fluid_medium_attr(fluid_medium) -> None: # Case 2: The fluid medium IS specified directly in the convection model. else: check_fluid_medium_attr(natural_conv_model.medium) - return values + return self - @pd.validator("size", always=True) - def check_zero_dim_domain(cls, val, values): + @field_validator("size") + @classmethod + def check_zero_dim_domain(cls, val: Any) -> Any: """Error if heat domain has zero dimensions.""" dim_names = ["x", "y", "z"] @@ -605,17 +582,15 @@ def check_zero_dim_domain(cls, val, values): return val - @pd.validator("boundary_spec", always=True) - @skip_if_fields_missing(["structures", "medium"]) - def names_exist_bcs(cls, val, values): + @model_validator(mode="after") + def names_exist_bcs(self) -> Self: """Error if boundary conditions point to non-existing structures/media.""" - - structures = values.get("structures") + structures = self.structures structures_names = {s.name for s in structures} mediums_names = {s.medium.name for s in structures} - mediums_names.add(values.get("medium").name) + mediums_names.add(self.medium.name) - for bc_ind, bc_spec in enumerate(val): + for bc_ind, bc_spec in enumerate(self.boundary_spec): bc_place = bc_spec.placement if isinstance(bc_place, (StructureBoundary, StructureSimulationBoundary)): if bc_place.structure not in structures_names: @@ -640,14 +615,14 @@ def names_exist_bcs(cls, val, values): f"'boundary_spec[{bc_ind}].placement' (type '{bc_place.type}') " "is not found among simulation mediums." ) - return val + return self - @pd.validator("boundary_spec", always=True) - def check_only_one_voltage_array_provided(cls, val, values): + @field_validator("boundary_spec") + @classmethod + def check_only_one_voltage_array_provided(cls, val: Any) -> Any: """Issue error if more than one voltage array is provided. Currently we only allow sweeping over one voltage array. """ - array_already_provided = False for bc in val: @@ -667,15 +642,15 @@ def check_only_one_voltage_array_provided(cls, val, values): ) return val - @pd.root_validator(skip_on_failure=True) - def check_freqs_requires_ac_source(cls, values): + @model_validator(mode="after") + def check_freqs_requires_ac_source(self) -> Self: """Ensure that if freqs is provided, at least one ACVoltageSource is present.""" - analysis_spec = values.get("analysis_spec") + analysis_spec = self.analysis_spec if ( isinstance(analysis_spec, (SSACAnalysis, IsothermalSSACAnalysis)) and len(analysis_spec.freqs) > 0 ): - bcs = values.get("boundary_spec") + bcs = self.boundary_spec has_ac_source = False for bc in bcs: if isinstance(bc.condition, VoltageBC): @@ -689,19 +664,18 @@ def check_freqs_requires_ac_source(cls, values): "'SSACVoltageSource' must be present in the boundary conditions."
) - return values + return self - @pd.root_validator(skip_on_failure=True) - def check_charge_simulation(cls, values): + @model_validator(mode="after") + def check_charge_simulation(self) -> Self: """Makes sure that Charge simulations are set correctly.""" - simulation_types = cls._check_simulation_types(values=values) + simulation_types = self._check_simulation_types() if TCADAnalysisTypes.CHARGE in simulation_types: # check that we have at least 2 'VoltageBC's - boundary_spec = values["boundary_spec"] voltage_bcs = 0 - for bc in boundary_spec: + for bc in self.boundary_spec: if isinstance(bc.condition, VoltageBC): voltage_bcs = voltage_bcs + 1 if voltage_bcs < 2: @@ -711,8 +685,7 @@ def check_charge_simulation(cls, values): ) # check that we have at least one charge monitor - monitors = values["monitors"] - if not any(isinstance(mnt, ChargeMonitorTypes) for mnt in monitors): + if not any(isinstance(mnt, ChargeMonitorTypes) for mnt in self.monitors): raise SetupError( "Charge simulations require the definition of, at least, one of these monitors: " "'[SteadyPotentialMonitor, SteadyFreeCarrierMonitor, SteadyCapacitanceMonitor, SteadyCurrentDensityMonitor]' " @@ -721,7 +694,7 @@ def check_charge_simulation(cls, values): # NOTE: in Charge we're only supporting unstructured monitors. # only Temperature and Potential monitors can be structured. - for mnt in monitors: + for mnt in self.monitors: if isinstance(mnt, SteadyPotentialMonitor) or isinstance(mnt, TemperatureMonitor): if not mnt.unstructured: log.warning( @@ -729,29 +702,28 @@ def check_charge_simulation(cls, values): f"monitor '{mnt.name}' to 'unstructured = True'." ) # check that we have at least one semiconductor medium - structures = values["structures"] + structures = self.structures sc_present = HeatChargeSimulation._check_if_semiconductor_present(structures=structures) if not sc_present: raise SetupError( f"{TCADAnalysisTypes.CHARGE} simulations require the definition of at least one semiconductor medium." 
) - return values + return self - @pd.root_validator(skip_on_failure=True) - def not_all_neumann(cls, values): + @model_validator(mode="after") + def not_all_neumann(self) -> Self: """Make sure not all BCs are of Neumann type""" NeumannBCsHeat = (HeatFluxBC,) NeumannBCsCharge = (CurrentBC, InsulatingBC) - simulation_types = cls._check_simulation_types(values=values) - bounday_conditions = values["boundary_spec"] + simulation_types = self._check_simulation_types() raise_error = False for sim_type in simulation_types: if sim_type == TCADAnalysisTypes.HEAT: type_bcs = [ - bc for bc in bounday_conditions if isinstance(bc.condition, HeatBCTypes) + bc for bc in self.boundary_spec if isinstance(bc.condition, HeatBCTypes) ] if len(type_bcs) == 0 or all( isinstance(bc.condition, NeumannBCsHeat) for bc in type_bcs @@ -759,7 +731,7 @@ def not_all_neumann(cls, values): raise_error = True elif sim_type == TCADAnalysisTypes.CONDUCTION: type_bcs = [ - bc for bc in bounday_conditions if isinstance(bc.condition, ElectricBCTypes) + bc for bc in self.boundary_spec if isinstance(bc.condition, ElectricBCTypes) ] if len(type_bcs) == 0 or all( isinstance(bc.condition, NeumannBCsCharge) for bc in type_bcs @@ -775,30 +747,25 @@ def not_all_neumann(cls, values): f"Current Neumann BCs are {names_neumann_Bcs}" ) - return values + return self - @pd.validator("grid_spec", always=True) - @skip_if_fields_missing(["structures"]) - def names_exist_grid_spec(cls, val, values): + @model_validator(mode="after") + def names_exist_grid_spec(self) -> Self: """Warn if 'UniformUnstructuredGrid' points at a non-existing structure.""" - - structures = values.get("structures") - structures_names = {s.name for s in structures} - - for structure_name in val.non_refined_structures: + structures_names = {s.name for s in self.structures} + for structure_name in self.grid_spec.non_refined_structures: if structure_name not in structures_names: log.warning( f"Structure '{structure_name}' listed as a non-refined structure in " "'HeatChargeSimulation.grid_spec' is not present in 'HeatChargeSimulation.structures'" ) + return self - return val - - @pd.validator("grid_spec", always=True) - def warn_if_minimal_mesh_size_override(cls, val, values): + @model_validator(mode="after") + def warn_if_minimal_mesh_size_override(self) -> Self: """Warn if minimal mesh size limit overrides desired mesh size.""" - - max_size = np.max(values.get("size")) + val = self.grid_spec + max_size = np.max(self.size) min_dl = val.relative_min_dl * max_size if isinstance(val, UniformUnstructuredGrid): @@ -812,16 +779,14 @@ def warn_if_minimal_mesh_size_override(cls, val, values): "Consider lowering parameter 'relative_min_dl' if a finer grid is required." ) - return val + return self - @pd.validator("sources", always=True) - @skip_if_fields_missing(["structures"]) - def names_exist_sources(cls, val, values): + @model_validator(mode="after") + def names_exist_sources(self) -> Self: """Error if a heat-charge source point to non-existing structures.""" - structures = values.get("structures") - structures_names = {s.name for s in structures} + structures_names = {s.name for s in self.structures} - sources = [s for s in val if not isinstance(s, HeatFromElectricSource)] + sources = [s for s in self.sources if not isinstance(s, HeatFromElectricSource)] for source in sources: for name in source.structures: @@ -830,22 +795,17 @@ def names_exist_sources(cls, val, values): f"Structure '{name}' provided in a '{source.type}' " "is not found among simulation structures." 
) - return val + return self - @pd.root_validator(skip_on_failure=True) - def check_medium_specs(cls, values): + @model_validator(mode="after") + def check_medium_specs(self) -> Self: """Error if no appropriate specs.""" - sim_box = ( - Box( - size=values.get("size"), - center=values.get("center"), - ), - ) + sim_box = (Box(size=self.size, center=self.center),) - failed_solid_idx, failed_elect_idx = cls._check_cross_solids(sim_box, values) + failed_solid_idx, failed_elect_idx = self._check_cross_solids(sim_box) - simulation_types = cls._check_simulation_types(values=values) + simulation_types = self._check_simulation_types() for sim_type in simulation_types: if sim_type == TCADAnalysisTypes.HEAT: @@ -859,10 +819,10 @@ def check_medium_specs(cls, values): "No conducting materials ('ChargeConductorMedium') are detected in conduction simulation. Solution domain is empty." ) - return values + return self @staticmethod - def _check_if_semiconductor_present(structures) -> bool: + def _check_if_semiconductor_present(structures: Iterable[Structure]) -> bool: """Checks whether the simulation object can run a Charge simulation.""" charge_sim = False @@ -877,32 +837,27 @@ def _check_if_semiconductor_present(structures) -> bool: charge_sim = True return charge_sim - @staticmethod def _check_simulation_types( - values: dict, - HeatBCTypes=HeatBCTypes, - ElectricBCTypes=ElectricBCTypes, - HeatSourceTypes=HeatSourceTypes, + self, + HeatBCTypes: tuple[type, ...] = HeatBCTypes, + ElectricBCTypes: tuple[type, ...] = ElectricBCTypes, + HeatSourceTypes: tuple[type, ...] = HeatSourceTypes, - ) -> list[TCADAnalysisTypes]: + ) -> set[TCADAnalysisTypes]: - """Given model dictionary ``values``, check the type of simulations to be run based on BCs and sources. """ + """Check the type of simulations to be run based on BCs and sources.""" simulation_types = [] - boundaries = list(values["boundary_spec"]) - sources = list(values["sources"]) - analysis_spec = values["analysis_spec"] - - structures = list(values["structures"]) + analysis_spec = self.analysis_spec if isinstance(analysis_spec, ChargeTypes): simulation_types.append(TCADAnalysisTypes.CHARGE) semiconductor_present = HeatChargeSimulation._check_if_semiconductor_present( - structures=structures + structures=self.structures ) - for boundary in boundaries: + for boundary in self.boundary_spec: if isinstance(boundary.condition, HeatBCTypes): simulation_types.append(TCADAnalysisTypes.HEAT) if isinstance(boundary.condition, ElectricBCTypes): @@ -910,26 +865,22 @@ def _check_simulation_types( if not semiconductor_present: simulation_types.append(TCADAnalysisTypes.CONDUCTION) - for source in sources: + for source in self.sources: if isinstance(source, HeatSourceTypes): simulation_types.append(TCADAnalysisTypes.HEAT) return set(simulation_types) - @pd.root_validator(skip_on_failure=True) - def check_coupling_source_can_be_applied(cls, values): + @model_validator(mode="after") + def check_coupling_source_can_be_applied(self) -> Self: """Error if material doesn't have the right specifications""" HeatSourceTypes_noCoupling = (UniformHeatSource, HeatSource) - simulation_types = cls._check_simulation_types( - values, HeatSourceTypes=HeatSourceTypes_noCoupling - ) + simulation_types = self._check_simulation_types(HeatSourceTypes=HeatSourceTypes_noCoupling) simulation_types = list(simulation_types) - sources = list(values["sources"]) - - for source in sources: + for source in self.sources: if isinstance(source, HeatFromElectricSource) and len(simulation_types) < 2: raise SetupError( f"Using 'HeatFromElectricSource' requires the definition of both " f"heat and electric conditions/sources. " @@ -937,35 +888,32 @@ def
check_coupling_source_can_be_applied(cls, values): f"The current simulation setup contains only conditions of type {simulation_types[0].name}" ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def check_heat_sim(cls, values): + @model_validator(mode="after") + def check_heat_sim(self) -> Self: """Make sure that heat simulations have at least one monitor defined.""" - simulation_types = cls._check_simulation_types(values=values) + simulation_types = self._check_simulation_types() if TCADAnalysisTypes.HEAT in simulation_types: - monitors = values.get("monitors") - if not any(isinstance(mnt, TemperatureMonitor) for mnt in monitors): + if not any(isinstance(mnt, TemperatureMonitor) for mnt in self.monitors): raise SetupError( "Heat simulations require the definition of, at least, one " "'TemperatureMonitor' but none have been defined." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def check_conduction_sim(cls, values): + @model_validator(mode="after") + def check_conduction_sim(self) -> Self: """Make sure that conduction simulations have at least one monitor defined.""" - simulation_types = cls._check_simulation_types(values=values) - sources = values.get("sources") + simulation_types = self._check_simulation_types() if TCADAnalysisTypes.CONDUCTION in simulation_types: - monitors = values.get("monitors") - if not any(isinstance(mnt, SteadyPotentialMonitor) for mnt in monitors): - if any(isinstance(s, HeatFromElectricSource) for s in sources): + if not any(isinstance(mnt, SteadyPotentialMonitor) for mnt in self.monitors): + if any(isinstance(s, HeatFromElectricSource) for s in self.sources): log.warning( "A Conduction simulation has been defined but no " "SteadyPotentialMonitor has been defined. " @@ -977,7 +925,7 @@ def check_conduction_sim(cls, values): ) # now make sure we only have one voltage per VoltageBC - for bc in values.get("boundary_spec", []): + for bc in self.boundary_spec: if isinstance(bc.condition, VoltageBC): if isinstance(bc.condition.source, DCVoltageSource): if len(bc.condition.source.voltage) > 1: @@ -987,33 +935,31 @@ def check_conduction_sim(cls, values): ) # make sure that at least one structure has appropriate charge medium - ValidConductionMediums = ChargeConductorMedium - structures = values.get("structures") - if all(isinstance(s.medium, Medium) for s in structures): + if all(isinstance(s.medium, Medium) for s in self.structures): raise SetupError( "Conduction simulations must be defined using 'MultiPhysicsMedium' but none have been defined." ) - if not any(isinstance(s.medium.charge, ValidConductionMediums) for s in structures): + if not any(isinstance(s.medium.charge, ChargeConductorMedium) for s in self.structures): raise SetupError( "Conduction simulations require at least one structure with a 'ChargeConductorMedium' " "but none have been defined." ) - return values + return self - def _estimate_charge_mesh_size(self) -> None: + @model_validator(mode="after") + def estimate_charge_mesh_size(self) -> Self: """Make an estimate of the mesh size and raise a warning if too big. NOTE: this is a very rough estimate. 
The back-end will actually stop execution based on actual node-count.""" if TCADAnalysisTypes.CHARGE not in self._get_simulation_types(): - return + return self # let's raise a warning if the estimate is larger than 2M nodes max_nodes = 2e6 nodes_estimate = 0 - structures = self.structures grid_spec = self.grid_spec non_refined_structures = grid_spec.non_refined_structures @@ -1028,7 +974,7 @@ def _estimate_charge_mesh_size(self) -> None: dl_min = grid_spec.dl_interface dl_max = grid_spec.dl_bulk - for struct in structures: + for struct in self.structures: name = struct.name bounds = np.array(struct.geometry.bounds) for dim in range(3): @@ -1057,14 +1003,15 @@ def _estimate_charge_mesh_size(self) -> None: "the pipeline will be stopped. If this happens the grid specification " "may need to be modified." ) + return self - @pd.root_validator(skip_on_failure=True) - def check_transient_heat(cls, values): + @model_validator(mode="after") + def check_transient_heat(self) -> Self: """Make sure transient heat simulations can run.""" - analysis_type = values.get("analysis_spec") + analysis_type = self.analysis_spec if isinstance(analysis_type, UnsteadyHeatAnalysis): - monitors = values.get("monitors") + monitors = self.monitors for mnt in monitors: if isinstance(mnt, TemperatureMonitor): if not mnt.unstructured: @@ -1075,7 +1022,7 @@ def check_transient_heat(cls, values): capacities = [] densities = [] conductivities = [] - structures = values.get("structures") + structures = self.structures for structure in structures: heat_properties = None if isinstance(structure.medium, MultiPhysicsMedium): @@ -1105,7 +1052,7 @@ def check_transient_heat(cls, values): ) # check simulation time - domain_length = np.max([d for d in values.get("size") if d != np.inf]) + domain_length = np.max([d for d in self.size if d != np.inf]) characteristic_time = ( domain_length**2 * np.mean(capacities) @@ -1122,20 +1069,20 @@ def check_transient_heat(cls, values): "This may lead to unnecessary long simulation times. " "Consider reducing the simulation time or the time step size." ) - return values + return self - @pd.root_validator(skip_on_failure=True) - def check_non_isothermal_is_possible(cls, values): + @model_validator(mode="after") + def check_non_isothermal_is_possible(self) -> Self: """Make sure that when a non-isothermal case is defined the structures have both electrical and thermal properties.""" - analysis_spec = values.get("analysis_spec") + analysis_spec = self.analysis_spec if isinstance(analysis_spec, SteadyChargeDCAnalysis) and not isinstance( analysis_spec, IsothermalSteadyChargeDCAnalysis ): has_heat = False has_elec = False - structures = values.get("structures") + structures = self.structures for struct in structures: if isinstance(struct.medium, MultiPhysicsMedium): if struct.medium.heat is not None: @@ -1160,7 +1107,77 @@ def check_non_isothermal_is_possible(cls, values): "The current simulation is defined as non-isothermal but no " "solid or semiconductor materials have been defined. " ) - return values + return self + + @equal_aspect + @add_ax_if_none + def plot( + self, + x: Optional[float] = None, + y: Optional[float] = None, + z: Optional[float] = None, + ax: Ax = None, + source_alpha: Optional[float] = None, + monitor_alpha: Optional[float] = None, + hlim: Optional[tuple[float, float]] = None, + vlim: Optional[tuple[float, float]] = None, + fill_structures: bool = True, + **patch_kwargs: Any, + ) -> Ax: + """Plot each of simulation's components on a plane defined by one nonzero x,y,z coordinate. 
+
+        Parameters
+        ----------
+        x : float = None
+            position of plane in x direction, only one of x, y, z must be specified to define plane.
+        y : float = None
+            position of plane in y direction, only one of x, y, z must be specified to define plane.
+        z : float = None
+            position of plane in z direction, only one of x, y, z must be specified to define plane.
+        ax : matplotlib.axes._subplots.Axes = None
+            Matplotlib axes to plot on, if not specified, one is created.
+        source_alpha : float = None
+            Opacity of the sources. If ``None``, uses Tidy3d default.
+        monitor_alpha : float = None
+            Opacity of the monitors. If ``None``, uses Tidy3d default.
+        hlim : tuple[float, float] = None
+            The x range if plotting on xy or xz planes, y range if plotting on yz plane.
+        vlim : tuple[float, float] = None
+            The z range if plotting on xz or yz planes, y range if plotting on xy plane.
+        fill_structures : bool = True
+            Whether to fill structures with color or just draw outlines.
+
+        Returns
+        -------
+        matplotlib.axes._subplots.Axes
+            The supplied or created matplotlib axes.
+        """
+
+        # Call the parent's plot method
+        ax = super().plot(
+            x=x,
+            y=y,
+            z=z,
+            ax=ax,
+            source_alpha=source_alpha,
+            monitor_alpha=monitor_alpha,
+            hlim=hlim,
+            vlim=vlim,
+            fill_structures=fill_structures,
+            **patch_kwargs,
+        )
+
+        # Add boundaries based on simulation type
+        # NOTE: there's no need to add heat boundaries since
+        # they are already added in the parent 'plot' method.
+        simulation_types = self._get_simulation_types()
+        if (
+            TCADAnalysisTypes.CHARGE in simulation_types
+            or TCADAnalysisTypes.CONDUCTION in simulation_types
+        ):
+            ax = self.plot_boundaries(ax=ax, x=x, y=y, z=z, property="electric_conductivity")
+
+        return ax
 
     @equal_aspect
     @add_ax_if_none
@@ -1173,7 +1190,9 @@ def plot_property(
         alpha: Optional[float] = None,
         source_alpha: Optional[float] = None,
         monitor_alpha: Optional[float] = None,
-        property: str = "heat_conductivity",
+        property: Literal[
+            "heat_conductivity", "electric_conductivity", "source"
+        ] = "heat_conductivity",
         hlim: Optional[tuple[float, float]] = None,
         vlim: Optional[tuple[float, float]] = None,
     ) -> Ax:
@@ -1199,9 +1218,9 @@ def plot_property(
         property : str = "heat_conductivity"
            Specifies the type of simulation for which the plot will be tailored. Options are
            ["heat_conductivity", "electric_conductivity", "source"]
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
 
         Returns
@@ -1301,9 +1320,9 @@ def plot_heat_conductivity(
         colorbar: str = "conductivity"
            Display colorbar for thermal conductivity ("conductivity") or heat source rate
            ("source").
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
 
         Returns
         -------
         matplotlib.axes._subplots.Axes
             The supplied or created matplotlib axes.
         """
         log.warning(
-            """This function `plot_heat_conductivity` is
-            deprecated and will be discontinued. In its place you can use
-            `plot_property(property="heat_conductivity")`"""
+            "The function 'plot_heat_conductivity' is "
+            "deprecated and will be discontinued. In its place you can use "
+            "'plot_property(property=\"heat_conductivity\")'"
         )
 
         plot_type = "heat_conductivity"
@@ -1664,16 +1683,16 @@ def _construct_heat_charge_boundaries(
 
         Parameters
         ----------
-        structures : List[:class:`.Structure`]
+        structures : list[:class:`.Structure`]
             list of structures to filter on the plane.
         plane : :class:`.Box`
             target plane.
-        boundary_spec : List[HeatBoundarySpec]
+        boundary_spec : list[HeatBoundarySpec]
             list of boundary conditions associated with structures.
 
         Returns
         -------
-        List[Tuple[:class:`.HeatBoundarySpec`, shapely.geometry.base.BaseGeometry]]
+        list[tuple[:class:`.HeatBoundarySpec`, shapely.geometry.base.BaseGeometry]]
             List of boundary lines and boundary conditions on the plane after merging.
         """
@@ -1745,9 +1764,9 @@ def plot_sources(
         property : str = None
            Specifies the type of simulation for which the plot will be tailored. Options are
            ["heat_conductivity", "electric_conductivity"]
-        hlim : Tuple[float, float] = None
+        hlim : tuple[float, float] = None
             The x range if plotting on xy or xz planes, y range if plotting on yz plane.
-        vlim : Tuple[float, float] = None
+        vlim : tuple[float, float] = None
             The z range if plotting on xz or yz planes, y range if plotting on xy plane.
         alpha : float = None
             Opacity of the sources. If ``None``, uses Tidy3d default.
@@ -1828,7 +1847,6 @@ def _add_source_cbar(self, ax: Ax, property: str = "heat_conductivity") -> None:
 
     def source_bounds(self, property: str = "heat_conductivity") -> tuple[float, float]:
         """Compute range of heat sources present in the simulation."""
-
         if property == "heat_conductivity" or property == "source":
             rate_list = [
                 np.mean(source.rate) for source in self.sources if isinstance(source, HeatSource)
@@ -1981,18 +1999,17 @@ def _get_simulation_types(self) -> list[TCADAnalysisTypes]:
 
         return simulation_types
 
-    def _useHeatSourceFromConductionSim(self):
+    def _useHeatSourceFromConductionSim(self) -> bool:
         """Returns True if 'HeatFromElectricSource' has been defined."""
-
         return any(isinstance(source, HeatFromElectricSource) for source in self.sources)
 
-    def _get_charge_type(self):
+    def _get_charge_type(self) -> Literal["ac", "dc"]:
         if isinstance(self.analysis_spec, (SSACAnalysis, IsothermalSSACAnalysis)):
             return "ac"
         else:
             return "dc"
 
-    def _get_ssac_frequency_and_amplitude(self):
+    def _get_ssac_frequency_and_amplitude(self) -> tuple[ArrayFloat1D, FiniteFloat]:
         if not isinstance(self.analysis_spec, (SSACAnalysis, IsothermalSSACAnalysis)):
             raise SetupError(
                 "Invalid analysis type for Small-Signal AC (SSAC). "
" @@ -2002,7 +2019,8 @@ def _get_ssac_frequency_and_amplitude(self): amplitude = None for bc in self.boundary_spec: - if isinstance(bc.condition, VoltageBC): - if isinstance(bc.condition.source, SSACVoltageSource): - amplitude = bc.condition.source.amplitude + if isinstance(bc.condition, VoltageBC) and isinstance( + bc.condition.source, SSACVoltageSource + ): + amplitude = bc.condition.source.amplitude return (self.analysis_spec.freqs, amplitude) diff --git a/tidy3d/components/tcad/source/abstract.py b/tidy3d/components/tcad/source/abstract.py index d5891207e8..955955d531 100644 --- a/tidy3d/components/tcad/source/abstract.py +++ b/tidy3d/components/tcad/source/abstract.py @@ -3,15 +3,18 @@ from __future__ import annotations from abc import ABC +from typing import TYPE_CHECKING -import pydantic.v1 as pd +from pydantic import Field, field_validator from tidy3d.components.base import cached_property from tidy3d.components.base_sim.source import AbstractSource from tidy3d.components.tcad.viz import plot_params_heat_source -from tidy3d.components.viz import PlotParams from tidy3d.exceptions import SetupError +if TYPE_CHECKING: + from tidy3d.components.viz import PlotParams + class AbstractHeatChargeSource(AbstractSource, ABC): """Abstract source for heat-charge simulations. All source types @@ -27,13 +30,14 @@ class StructureBasedHeatChargeSource(AbstractHeatChargeSource): """Abstract class associated with structures. Sources associated to structures must derive from this class""" - structures: tuple[str, ...] = pd.Field( + structures: tuple[str, ...] = Field( title="Target Structures", description="Names of structures where to apply heat source.", ) - @pd.validator("structures", always=True) - def check_non_empty_structures(cls, val): + @field_validator("structures") + @classmethod + def check_non_empty_structures(cls, val: tuple[str, ...]) -> tuple[str, ...]: """Error if source doesn't point at any structures.""" if len(val) == 0: raise SetupError("List of structures for heat source is empty.") diff --git a/tidy3d/components/tcad/source/heat.py b/tidy3d/components/tcad/source/heat.py index e47de2963d..1fbbde2007 100644 --- a/tidy3d/components/tcad/source/heat.py +++ b/tidy3d/components/tcad/source/heat.py @@ -2,9 +2,9 @@ from __future__ import annotations -from typing import Union +from typing import Any, Union -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.data.data_array import SpatialDataArray from tidy3d.components.tcad.source.abstract import StructureBasedHeatChargeSource @@ -21,7 +21,7 @@ class HeatSource(StructureBasedHeatChargeSource): >>> heat_source = HeatSource(rate=1, structures=["box"]) """ - rate: Union[float, SpatialDataArray] = pd.Field( + rate: Union[float, SpatialDataArray] = Field( title="Volumetric Heat Rate", description="Volumetric rate of heating or cooling (if negative).", units=VOLUMETRIC_HEAT_RATE, @@ -39,11 +39,12 @@ class UniformHeatSource(HeatSource): # NOTE: wrapper for backwards compatibility. - @pd.root_validator(skip_on_failure=True) - def issue_warning_deprecated(cls, values): + @model_validator(mode="before") + @classmethod + def issue_warning_deprecated(cls, data: dict[str, Any]) -> dict[str, Any]: """Issue warning for 'UniformHeatSource'.""" log.warning( "'UniformHeatSource' is deprecated and will be discontinued. You can use " "'HeatSource' instead." 
) - return values + return data diff --git a/tidy3d/components/time.py b/tidy3d/components/time.py index 448ce28935..c14f051d8b 100644 --- a/tidy3d/components/time.py +++ b/tidy3d/components/time.py @@ -1,202 +1,11 @@ -"""Defines time dependence""" +"""Compatibility shim for :mod:`tidy3d._common.components.time`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod - -import numpy as np -import pydantic.v1 as pydantic - -from tidy3d.constants import RADIAN -from tidy3d.exceptions import SetupError - -from .base import Tidy3dBaseModel -from .types import ArrayFloat1D, Ax, PlotVal -from .viz import add_ax_if_none - -# in spectrum computation, discard amplitudes with relative magnitude smaller than cutoff -DFT_CUTOFF = 1e-8 - - -class AbstractTimeDependence(ABC, Tidy3dBaseModel): - """Base class describing time dependence.""" - - amplitude: pydantic.NonNegativeFloat = pydantic.Field( - 1.0, title="Amplitude", description="Real-valued maximum amplitude of the time dependence." - ) - - phase: float = pydantic.Field( - 0.0, title="Phase", description="Phase shift of the time dependence.", units=RADIAN - ) - - @abstractmethod - def amp_time(self, time: float) -> complex: - """Complex-valued amplitude as a function of time. - - Parameters - ---------- - time : float - Time in seconds. - - Returns - ------- - complex - Complex-valued amplitude at that time. - """ - - def spectrum( - self, - times: ArrayFloat1D, - freqs: ArrayFloat1D, - dt: float, - ) -> complex: - """Complex-valued spectrum as a function of frequency. - Note: Only the real part of the time signal is used. - - Parameters - ---------- - times : np.ndarray - Times to use to evaluate spectrum Fourier transform. - (Typically the simulation time mesh). - freqs : np.ndarray - Frequencies in Hz to evaluate spectrum at. - dt : float or np.ndarray - Time step to weight FT integral with. - If array, use to weigh each of the time intervals in ``times``. - - Returns - ------- - np.ndarray - Complex-valued array (of len(freqs)) containing spectrum at those frequencies. 
- """ - - times = np.array(times) - freqs = np.array(freqs) - time_amps = np.real(self.amp_time(times)) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - # if all time amplitudes are zero, just return (complex-valued) zeros for spectrum - if np.all(np.equal(time_amps, 0.0)): - return (0.0 + 0.0j) * np.zeros_like(freqs) - - # Cut to only relevant times - relevant_time_inds = np.where(np.abs(time_amps) / np.amax(np.abs(time_amps)) > DFT_CUTOFF) - # find first and last index where the filter is True - start_ind = relevant_time_inds[0][0] - stop_ind = relevant_time_inds[0][-1] + 1 - time_amps = time_amps[start_ind:stop_ind] - times_cut = times[start_ind:stop_ind] - if times_cut.size == 0: - return (0.0 + 0.0j) * np.zeros_like(freqs) - - # only need to compute DTFT kernel for distinct dts - # usually, there is only one dt, if times is simulation time mesh - dts = np.diff(times_cut) - dts_unique, kernel_indices = np.unique(dts, return_inverse=True) - - dft_kernels = [np.exp(2j * np.pi * freqs * curr_dt) for curr_dt in dts_unique] - running_kernel = np.exp(2j * np.pi * freqs * times_cut[0]) - dft = np.zeros(len(freqs), dtype=complex) - for amp, kernel_index in zip(time_amps, kernel_indices): - dft += running_kernel * amp - running_kernel *= dft_kernels[kernel_index] - - # kernel_indices was one index shorter than time_amps - dft += running_kernel * time_amps[-1] - - return dt * dft / np.sqrt(2 * np.pi) - - @add_ax_if_none - def plot_spectrum_in_frequency_range( - self, - times: ArrayFloat1D, - fmin: float, - fmax: float, - num_freqs: int = 101, - val: PlotVal = "real", - ax: Ax = None, - ) -> Ax: - """Plot the complex-valued amplitude of the time-dependence. - Note: Only the real part of the time signal is used. - - Parameters - ---------- - times : np.ndarray - Array of evenly-spaced times (seconds) to evaluate time-dependence at. - The spectrum is computed from this value and the time frequency content. - To see spectrum for a specific :class:`.Simulation`, - pass ``simulation.tmesh``. - fmin : float - Lower bound of frequency for the spectrum plot. - fmax : float - Upper bound of frequency for the spectrum plot. - num_freqs : int = 101 - Number of frequencies to plot within the [fmin, fmax]. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - times = np.array(times) - - dts = np.diff(times) - if not np.allclose(dts, dts[0] * np.ones_like(dts), atol=1e-17): - raise SetupError("Supplied times not evenly spaced.") - - dt = np.mean(dts) - freqs = np.linspace(fmin, fmax, num_freqs) - - spectrum = self.spectrum(times=times, dt=dt, freqs=freqs) - - if val == "real": - ax.plot(freqs, spectrum.real, color="blueviolet", label="real") - elif val == "imag": - ax.plot(freqs, spectrum.imag, color="crimson", label="imag") - elif val == "abs": - ax.plot(freqs, np.abs(spectrum), color="k", label="abs") - else: - raise ValueError(f"Plot 'val' option of '{val}' not recognized.") - ax.set_xlabel("frequency (Hz)") - ax.set_title("source spectrum") - ax.legend() - ax.set_aspect("auto") - return ax - - @add_ax_if_none - def plot(self, times: ArrayFloat1D, val: PlotVal = "real", ax: Ax = None) -> Ax: - """Plot the complex-valued amplitude of the time-dependence. - - Parameters - ---------- - times : np.ndarray - Array of times (seconds) to plot source at. 
- To see source time amplitude for a specific :class:`.Simulation`, - pass ``simulation.tmesh``. - val : Literal['real', 'imag', 'abs'] = 'real' - Which part of the spectrum to plot. - ax : matplotlib.axes._subplots.Axes = None - Matplotlib axes to plot on, if not specified, one is created. - - Returns - ------- - matplotlib.axes._subplots.Axes - The supplied or created matplotlib axes. - """ - times = np.array(times) - amp_complex = self.amp_time(times) +# marked as migrated to _common +from __future__ import annotations - if val == "real": - ax.plot(times, amp_complex.real, color="blueviolet", label="real") - elif val == "imag": - ax.plot(times, amp_complex.imag, color="crimson", label="imag") - elif val == "abs": - ax.plot(times, np.abs(amp_complex), color="k", label="abs") - else: - raise ValueError(f"Plot 'val' option of '{val}' not recognized.") - ax.set_xlabel("time (s)") - ax.set_title("source amplitude") - ax.legend() - ax.set_aspect("auto") - return ax +from tidy3d._common.components.time import ( + DFT_CUTOFF, + AbstractTimeDependence, +) diff --git a/tidy3d/components/time_modulation.py b/tidy3d/components/time_modulation.py index a4d82d7384..17094c7462 100644 --- a/tidy3d/components/time_modulation.py +++ b/tidy3d/components/time_modulation.py @@ -4,19 +4,26 @@ from abc import ABC, abstractmethod from math import isclose -from typing import Union +from typing import TYPE_CHECKING, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveFloat, field_validator, model_validator from tidy3d.constants import HERTZ, RADIAN from tidy3d.exceptions import ValidationError -from .base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from .base import Tidy3dBaseModel, cached_property from .data.data_array import SpatialDataArray from .data.validators import validate_no_nans from .time import AbstractTimeDependence -from .types import Bound, InterpMethod +from .types import InterpMethod + +if TYPE_CHECKING: + from pydantic import FieldValidationInfo + + from tidy3d.compat import Self + + from .types import Bound class AbstractTimeModulation(AbstractTimeDependence, ABC): @@ -63,8 +70,8 @@ class ContinuousWaveTimeModulation(AbstractTimeDependence): >>> cw = ContinuousWaveTimeModulation(freq0=200e12, amplitude=1, phase=0) """ - freq0: pd.PositiveFloat = pd.Field( - ..., title="Modulation Frequency", description="Modulation frequency.", units=HERTZ + freq0: PositiveFloat = Field( + title="Modulation Frequency", description="Modulation frequency.", units=HERTZ ) def amp_time(self, time: float) -> complex: @@ -128,41 +135,36 @@ class SpaceModulation(AbstractSpaceModulation): >>> space = SpaceModulation(amplitude=amp, phase=phase) """ - amplitude: Union[float, SpatialDataArray] = pd.Field( + amplitude: Union[float, SpatialDataArray] = Field( 1, title="Amplitude of modulation in space", description="Amplitude of modulation that can vary spatially. 
" "It takes the unit of whatever is being modulated.", ) - phase: Union[float, SpatialDataArray] = pd.Field( + phase: Union[float, SpatialDataArray] = Field( 0, title="Phase of modulation in space", description="Phase of modulation that can vary spatially.", units=RADIAN, ) - interp_method: InterpMethod = pd.Field( + interp_method: InterpMethod = Field( "nearest", title="Interpolation method", description="Method of interpolation to use to obtain values at spatial locations on the Yee grids.", ) - _no_nans_amplitude = validate_no_nans("amplitude") - _no_nans_phase = validate_no_nans("phase") + _no_nans = validate_no_nans("amplitude", "phase") - @pd.validator("amplitude", always=True) - def _real_amplitude(cls, val): + @field_validator("amplitude", "phase") + @classmethod + def _validate_fields_real( + cls, val: Union[float, SpatialDataArray], info: FieldValidationInfo + ) -> Union[float, SpatialDataArray]: """Assert that the amplitude is real.""" if np.iscomplexobj(val): - raise ValidationError("'amplitude' must be real.") - return val - - @pd.validator("phase", always=True) - def _real_phase(cls, val): - """Assert that the phase is real.""" - if np.iscomplexobj(val): - raise ValidationError("'phase' must be real.") + raise ValidationError(f"'{info.field_name}' must be real.") return val @cached_property @@ -170,7 +172,7 @@ def max_modulation(self) -> float: """Estimated maximum modulation amplitude.""" return np.max(abs(np.array(self.amplitude))) - def sel_inside(self, bounds: Bound) -> SpaceModulation: + def sel_inside(self, bounds: Bound) -> Self: """Return a new space modulation that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. @@ -217,18 +219,15 @@ class SpaceTimeModulation(Tidy3dBaseModel): \\delta \\epsilon(r, t) = \\Re[amp\\_time(t) \\cdot amp\\_space(r)] """ - space_modulation: SpaceModulationType = pd.Field( - SpaceModulation(), + space_modulation: SpaceModulationType = Field( + default_factory=SpaceModulation, title="Space modulation", description="Space modulation part from the separable SpaceTimeModulation.", - # discriminator=TYPE_TAG_STR, ) - time_modulation: TimeModulationType = pd.Field( - ..., + time_modulation: TimeModulationType = Field( title="Time modulation", description="Time modulation part from the separable SpaceTimeModulation.", - # discriminator=TYPE_TAG_STR, ) @cached_property @@ -244,7 +243,7 @@ def negligible_modulation(self) -> bool: return True return False - def sel_inside(self, bounds: Bound) -> SpaceTimeModulation: + def sel_inside(self, bounds: Bound) -> Self: """Return a new space-time modulation that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. @@ -268,38 +267,36 @@ class ModulationSpec(Tidy3dBaseModel): including relative permittivity at infinite frequency and electric conductivity. 
""" - permittivity: SpaceTimeModulation = pd.Field( + permittivity: Optional[SpaceTimeModulation] = Field( None, title="Space-time modulation of relative permittivity", description="Space-time modulation of relative permittivity at infinite frequency " "applied on top of the base permittivity at infinite frequency.", ) - conductivity: SpaceTimeModulation = pd.Field( + conductivity: Optional[SpaceTimeModulation] = Field( None, title="Space-time modulation of conductivity", description="Space-time modulation of electric conductivity " "applied on top of the base conductivity.", ) - @pd.validator("conductivity", always=True) - @skip_if_fields_missing(["permittivity"]) - def _same_modulation_frequency(cls, val, values): + @model_validator(mode="after") + def _check_same_modulation_frequency(self) -> Self: """Assert same time-modulation applied to permittivity and conductivity.""" - permittivity = values.get("permittivity") - if val is not None and permittivity is not None: - if val.time_modulation != permittivity.time_modulation: + if self.conductivity is not None and self.permittivity is not None: + if self.conductivity.time_modulation != self.permittivity.time_modulation: raise ValidationError( "'permittivity' and 'conductivity' should have the same time modulation." ) - return val + return self @cached_property def applied_modulation(self) -> bool: """Check if any modulation has been applied to ``permittivity`` or ``conductivity``.""" return self.permittivity is not None or self.conductivity is not None - def sel_inside(self, bounds: Bound) -> ModulationSpec: + def sel_inside(self, bounds: Bound) -> Self: """Return a new modulation specficiation that contains the minimal amount data necessary to cover a spatial region defined by ``bounds``. diff --git a/tidy3d/components/transformation.py b/tidy3d/components/transformation.py index 4e2643a9ae..3add2c4b41 100644 --- a/tidy3d/components/transformation.py +++ b/tidy3d/components/transformation.py @@ -1,205 +1,15 @@ -"""Defines geometric transformation classes""" +"""Compatibility shim for :mod:`tidy3d._common.components.transformation`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Union - -import numpy as np -import pydantic.v1 as pd - -from tidy3d.constants import RADIAN -from tidy3d.exceptions import ValidationError - -from .autograd import TracedFloat -from .base import Tidy3dBaseModel, cached_property -from .types import ArrayFloat2D, Axis, Coordinate, TensorReal - - -class AbstractRotation(ABC, Tidy3dBaseModel): - """Abstract rotation of vectors and tensors.""" - - @cached_property - @abstractmethod - def matrix(self) -> TensorReal: - """Rotation matrix.""" - - @cached_property - @abstractmethod - def isidentity(self) -> bool: - """Check whether rotation is identity.""" - - def rotate_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D: - """Rotate a vector/point or a list of vectors/points. - - Parameters - ---------- - points : ArrayLike[float] - Array of shape ``(3, ...)``. - - Returns - ------- - Coordinate - Rotated vector. - """ - - if self.isidentity: - return vector - - if len(vector.shape) == 1: - return self.matrix @ vector - - return np.tensordot(self.matrix, vector, axes=1) - - def rotate_tensor(self, tensor: TensorReal) -> TensorReal: - """Rotate a tensor. - - Parameters - ---------- - tensor : ArrayLike[float] - Array of shape ``(3, 3)``. - - Returns - ------- - TensorReal - Rotated tensor. 
- """ - - if self.isidentity: - return tensor - - return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T)) - - -class RotationAroundAxis(AbstractRotation): - """Rotation of vectors and tensors around a given vector.""" - - axis: Union[Axis, Coordinate] = pd.Field( - 0, - title="Axis of Rotation", - description="A vector that specifies the axis of rotation, or a single int: 0, 1, or 2, " - "indicating x, y, or z.", - ) - - angle: TracedFloat = pd.Field( - 0.0, - title="Angle of Rotation", - description="Angle of rotation in radians.", - units=RADIAN, - ) - - @pd.validator("axis", always=True) - def _convert_axis_index_to_vector(cls, val): - if not isinstance(val, tuple): - axis = [0.0, 0.0, 0.0] - axis[val] = 1.0 - val = tuple(axis) - return val - - @pd.validator("axis") - def _guarantee_nonzero_axis(cls, val): - norm = np.linalg.norm(val) - if np.isclose(norm, 0): - raise ValidationError( - "The norm of vector 'axis' cannot be zero. Please provide a proper rotation axis." - ) - return val - - @cached_property - def isidentity(self) -> bool: - """Check whether rotation is identity.""" - - return np.isclose(self.angle % (2 * np.pi), 0) - - @cached_property - def matrix(self) -> TensorReal: - """Rotation matrix.""" - - if self.isidentity: - return np.eye(3) - - norm = np.linalg.norm(self.axis) - n = self.axis / norm - c = np.cos(self.angle) - s = np.sin(self.angle) - K = np.array([[0, -n[2], n[1]], [n[2], 0, -n[0]], [-n[1], n[0], 0]]) - R = np.eye(3) + s * K + (1 - c) * K @ K - - return R - - -class AbstractReflection(ABC, Tidy3dBaseModel): - """Abstract reflection of vectors and tensors.""" - - @cached_property - @abstractmethod - def matrix(self) -> TensorReal: - """Reflection matrix.""" - - def reflect_vector(self, vector: ArrayFloat2D) -> ArrayFloat2D: - """Reflect a vector/point or a list of vectors/points. - - Parameters - ---------- - vector : ArrayLike[float] - Array of shape ``(3, ...)``. - - Returns - ------- - Coordinate - Reflected vector. - """ - - if len(vector.shape) == 1: - return self.matrix @ vector - - return np.tensordot(self.matrix, vector, axes=1) - - def reflect_tensor(self, tensor: TensorReal) -> TensorReal: - """Reflect a tensor. - - Parameters - ---------- - tensor : ArrayLike[float] - Array of shape ``(3, 3)``. - - Returns - ------- - TensorReal - Reflected tensor. - """ - - return np.matmul(self.matrix, np.matmul(tensor, self.matrix.T)) - - -class ReflectionFromPlane(AbstractReflection): - """Reflection of vectors and tensors around a given vector.""" - - normal: Coordinate = pd.Field( - (1, 0, 0), - title="Normal of the reflecting plane", - description="A vector that specifies the normal of the plane of reflection", - ) - - @pd.validator("normal") - def _guarantee_nonzero_normal(cls, val): - norm = np.linalg.norm(val) - if np.isclose(norm, 0): - raise ValidationError( - "The norm of vector 'normal' cannot be zero. Please provide a proper normal vector." 
- ) - return val - - @cached_property - def matrix(self) -> TensorReal: - """Reflection matrix.""" - - norm = np.linalg.norm(self.normal) - n = self.normal / norm - R = np.eye(3) - 2 * np.outer(n, n) - - return R +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -RotationType = Union[RotationAroundAxis] -ReflectionType = Union[ReflectionFromPlane] +from tidy3d._common.components.transformation import ( + AbstractReflection, + AbstractRotation, + ReflectionFromPlane, + ReflectionType, + RotationAroundAxis, + RotationType, +) diff --git a/tidy3d/components/types/__init__.py b/tidy3d/components/types/__init__.py index 0d7c70e94c..4c4ddd1b22 100644 --- a/tidy3d/components/types/__init__.py +++ b/tidy3d/components/types/__init__.py @@ -23,7 +23,6 @@ ClipOperationType, ColormapType, Complex, - ComplexNumber, Coordinate, Coordinate2D, CoordinateOptional, @@ -43,7 +42,6 @@ MatrixReal4x4, ModeClassification, ModeSolverType, - Numpy, ObsGridArray, PermittivityComponent, PlanePosition, @@ -63,9 +61,6 @@ TrackFreq, Undefined, UnitsZBF, - annotate_type, - constrained_array, - tidycomplex, xyz, ) from tidy3d.components.types.third_party import TrimeshType @@ -92,7 +87,6 @@ "ClipOperationType", "ColormapType", "Complex", - "ComplexNumber", "Coordinate", "Coordinate2D", "CoordinateOptional", @@ -112,7 +106,6 @@ "MatrixReal4x4", "ModeClassification", "ModeSolverType", - "Numpy", "ObsGridArray", "PermittivityComponent", "PlanePosition", @@ -134,8 +127,5 @@ "Undefined", "UnitsZBF", "_add_schema", - "annotate_type", - "constrained_array", - "tidycomplex", "xyz", ] diff --git a/tidy3d/components/types/base.py b/tidy3d/components/types/base.py index b49e5d7086..39695d994e 100644 --- a/tidy3d/components/types/base.py +++ b/tidy3d/components/types/base.py @@ -1,264 +1,82 @@ -"""Defines 'types' that various fields can be""" +"""Compatibility shim for :mod:`tidy3d._common.components.types.base`.""" -from __future__ import annotations - -from typing import Literal, Optional, Union - -import autograd.numpy as np -import pydantic.v1 as pydantic - -try: - from matplotlib.axes import Axes -except ImportError: - Axes = None -from typing import Annotated - -from shapely.geometry.base import BaseGeometry - -from tidy3d.exceptions import ValidationError - -# type tag default name -TYPE_TAG_STR = "type" - - -def annotate_type(UnionType): - """Annotated union type using TYPE_TAG_STR as discriminator.""" - return Annotated[UnionType, pydantic.Field(discriminator=TYPE_TAG_STR)] - - -""" Numpy Arrays """ - - -def _totuple(arr: np.ndarray) -> tuple: - """Convert a numpy array to a nested tuple.""" - if arr.ndim > 1: - return tuple(_totuple(val) for val in arr) - return tuple(arr) - - -# generic numpy array -Numpy = np.ndarray - - -class ArrayLike: - """Type that stores a numpy array.""" - - ndim = None - dtype = None - shape = None - - @classmethod - def __get_validators__(cls): - yield cls.load_complex - yield cls.convert_to_numpy - yield cls.check_dims - yield cls.check_shape - yield cls.assert_non_null - - @classmethod - def load_complex(cls, val): - """Special handling to load a complex-valued np.ndarray saved to file.""" - if not isinstance(val, dict): - return val - if "real" not in val or "imag" not in val: - raise ValueError("ArrayLike real and imaginary parts not stored properly.") - arr_real = np.array(val["real"]) - arr_imag = np.array(val["imag"]) - return arr_real + 1j * arr_imag - - @classmethod - def convert_to_numpy(cls, 
val): - """Convert the value to np.ndarray and provide some casting.""" - arr_numpy = np.array(val, ndmin=1, dtype=cls.dtype, copy=True) - return arr_numpy - - @classmethod - def check_dims(cls, val): - """Make sure the number of dimensions is correct.""" - if cls.ndim and val.ndim != cls.ndim: - raise ValidationError(f"Expected {cls.ndim} dimensions for ArrayLike, got {val.ndim}.") - return val - - @classmethod - def check_shape(cls, val): - """Make sure the shape is correct.""" - if cls.shape and val.shape != cls.shape: - raise ValidationError(f"Expected shape {cls.shape} for ArrayLike, got {val.shape}.") - return val - - @classmethod - def assert_non_null(cls, val): - """Make sure array is not None.""" - if np.any(np.isnan(val)): - raise ValidationError("'ArrayLike' field contained None or nan values.") - return val - - @classmethod - def __modify_schema__(cls, field_schema) -> None: - """Sets the schema of DataArray object.""" - - schema = { - "type": "ArrayLike", - } - field_schema.update(schema) - - -def constrained_array( - dtype: Optional[type] = None, - ndim: Optional[int] = None, - shape: Optional[tuple[pydantic.NonNegativeInt, ...]] = None, -) -> type: - """Generate an ArrayLike sub-type with constraints built in.""" - - # note, a unique name is required for each subclass of ArrayLike with constraints - type_name = "ArrayLike" - - meta_args = [] - if dtype is not None: - meta_args.append(f"dtype={dtype.__name__}") - if ndim is not None: - meta_args.append(f"ndim={ndim}") - if shape is not None: - meta_args.append(f"shape={shape}") - type_name += "[" + ", ".join(meta_args) + "]" - - return type(type_name, (ArrayLike,), {"dtype": dtype, "ndim": ndim, "shape": shape}) - - -# pre-define a set of commonly used array like instances for import and use in type hints -ArrayInt1D = constrained_array(dtype=int, ndim=1) -ArrayFloat1D = constrained_array(dtype=float, ndim=1) -ArrayFloat2D = constrained_array(dtype=float, ndim=2) -ArrayFloat3D = constrained_array(dtype=float, ndim=3) -ArrayFloat4D = constrained_array(dtype=float, ndim=4) -ArrayComplex1D = constrained_array(dtype=complex, ndim=1) -ArrayComplex2D = constrained_array(dtype=complex, ndim=2) -ArrayComplex3D = constrained_array(dtype=complex, ndim=3) -ArrayComplex4D = constrained_array(dtype=complex, ndim=4) - -TensorReal = constrained_array(dtype=float, ndim=2, shape=(3, 3)) -MatrixReal4x4 = constrained_array(dtype=float, ndim=2, shape=(4, 4)) - -""" Complex Values """ +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class ComplexNumber(pydantic.BaseModel): - """Complex number with a well defined schema.""" - - real: float - imag: float - - @property - def as_complex(self): - """return complex representation of ComplexNumber.""" - return self.real + 1j * self.imag - - -class tidycomplex(complex): - """complex type that we can use in our models.""" - - @classmethod - def __get_validators__(cls): - """Defines which validator function to use for ComplexNumber.""" - yield cls.validate - - @classmethod - def validate(cls, value): - """What gets called when you construct a tidycomplex.""" - - if isinstance(value, ComplexNumber): - return value.as_complex - if isinstance(value, dict): - c = ComplexNumber(**value) - return c.as_complex - return cls(value) - - @classmethod - def __modify_schema__(cls, field_schema) -> None: - """Sets the schema of ComplexNumber.""" - field_schema.update(ComplexNumber.schema()) - - -""" symmetry """ - -Symmetry = 
Literal[0, -1, 1] -ScalarSymmetry = Literal[0, 1] - -""" geometric """ - -Size1D = pydantic.NonNegativeFloat -Size = tuple[Size1D, Size1D, Size1D] -Coordinate = tuple[float, float, float] -CoordinateOptional = tuple[Optional[float], Optional[float], Optional[float]] -Coordinate2D = tuple[float, float] -Bound = tuple[Coordinate, Coordinate] -GridSize = Union[pydantic.PositiveFloat, tuple[pydantic.PositiveFloat, ...]] -Axis = Literal[0, 1, 2] -Axis2D = Literal[0, 1] -Shapely = BaseGeometry -PlanePosition = Literal["bottom", "middle", "top"] -ClipOperationType = Literal["union", "intersection", "difference", "symmetric_difference"] -BoxSurface = Literal["x-", "x+", "y-", "y+", "z-", "z+"] -LengthUnit = Literal["nm", "μm", "um", "mm", "cm", "m", "mil", "in"] -PriorityMode = Literal["equal", "conductor"] - -""" medium """ - -# custom medium -InterpMethod = Literal["nearest", "linear"] - -# Complex = Union[complex, ComplexNumber] -Complex = Union[tidycomplex, ComplexNumber] -PoleAndResidue = tuple[Complex, Complex] - -# PoleAndResidue = Tuple[Tuple[float, float], Tuple[float, float]] -FreqBoundMax = float -FreqBoundMin = float -FreqBound = tuple[FreqBoundMin, FreqBoundMax] - -PermittivityComponent = Literal["xx", "xy", "xz", "yx", "yy", "yz", "zx", "zy", "zz"] - -""" sources """ - -Polarization = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -Direction = Literal["+", "-"] - -""" monitors """ - -EMField = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -FieldType = Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"] -FreqArray = Union[tuple[float, ...], ArrayFloat1D] -ObsGridArray = Union[tuple[float, ...], ArrayFloat1D] -PolarizationBasis = Literal["linear", "circular"] -AuxField = Literal["Nfx", "Nfy", "Nfz"] - -""" plotting """ - -Ax = Axes -PlotVal = Literal["real", "imag", "abs"] -FieldVal = Literal["real", "imag", "abs", "abs^2", "phase"] -RealFieldVal = Literal["real", "abs", "abs^2"] -PlotScale = Literal["lin", "dB", "log", "symlog"] -ColormapType = Literal["divergent", "sequential", "cyclic"] - -""" mode solver """ - -ModeSolverType = Literal["tensorial", "diagonal"] -EpsSpecType = Literal["diagonal", "tensorial_real", "tensorial_complex"] -ModeClassification = Literal["TEM", "quasi-TEM", "TE", "TM", "Hybrid"] - -""" mode tracking """ - -TrackFreq = Literal["central", "lowest", "highest"] - -""" lumped elements""" - -LumpDistType = Literal["off", "laterally_only", "on"] - -""" dataset """ - -xyz = Literal["x", "y", "z"] -UnitsZBF = Literal["mm", "cm", "in", "m"] - -""" sentinel """ -Undefined = object() +from tidy3d._common.components.types.base import ( + TYPE_TAG_STR, + ArrayComplex, + ArrayComplex1D, + ArrayComplex2D, + ArrayComplex3D, + ArrayComplex4D, + ArrayConstraints, + ArrayFloat, + ArrayFloat1D, + ArrayFloat2D, + ArrayFloat3D, + ArrayFloat4D, + ArrayInt1D, + ArrayLike, + ArrayLikeStrict, + AuxField, + Ax, + Axis, + Axis2D, + Bound, + BoxSurface, + ClipOperationType, + ColormapType, + Complex, + Coordinate, + Coordinate2D, + CoordinateOptional, + Direction, + DTypeLike, + EMField, + EpsSpecType, + FieldType, + FieldVal, + FreqArray, + FreqBound, + FreqBoundMax, + FreqBoundMin, + GridSize, + InterpMethod, + LengthUnit, + LumpDistType, + MatrixReal4x4, + ModeClassification, + ModeSolverType, + ObsGridArray, + PermittivityComponent, + PlanePosition, + PlotScale, + PlotVal, + Polarization, + PolarizationBasis, + PoleAndResidue, + PolesAndResidues, + PriorityMode, + RealFieldVal, + ScalarSymmetry, + Shapely, + Size, + Size1D, + Symmetry, + TensorReal, + TrackFreq, + Undefined, + UnitsZBF, + 
_auto_serializer, + _coerce, + _dtype2python, + _from_complex_dict, + _list_to_tuple, + _parse_complex, + array_alias, + discriminated_union, + xyz, +) diff --git a/tidy3d/components/types/third_party.py b/tidy3d/components/types/third_party.py index 4fe305ce68..7e2eda6240 100644 --- a/tidy3d/components/types/third_party.py +++ b/tidy3d/components/types/third_party.py @@ -1,14 +1,8 @@ -from __future__ import annotations - -from typing import Any +"""Compatibility shim for :mod:`tidy3d._common.components.types.third_party`.""" -from tidy3d.packaging import check_import +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -# TODO Complicated as trimesh should be a core package unless decoupled implementation types in functional location. -# We need to restructure. -if check_import("trimesh"): - import trimesh # Won't add much overhead if already imported +# marked as migrated to _common +from __future__ import annotations - TrimeshType = trimesh.Trimesh -else: - TrimeshType = Any +from tidy3d._common.components.types.third_party import TrimeshType diff --git a/tidy3d/components/types/utils.py b/tidy3d/components/types/utils.py index 983f623829..39cb1b0f5c 100644 --- a/tidy3d/components/types/utils.py +++ b/tidy3d/components/types/utils.py @@ -1,14 +1,10 @@ -"""Utilities for type & schema creation.""" - -from __future__ import annotations +"""Compatibility shim for :mod:`tidy3d._common.components.types.utils`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def _add_schema(arbitrary_type: type, title: str, field_type_str: str) -> None: - """Adds a schema to the ``arbitrary_type`` class without subclassing.""" - - @classmethod - def mod_schema_fn(cls, field_schema: dict) -> None: - """Function that gets set to ``arbitrary_type.__modify_schema__``.""" - field_schema.update({"title": title, "type": field_type_str}) +# marked as migrated to _common +from __future__ import annotations - arbitrary_type.__modify_schema__ = mod_schema_fn +from tidy3d._common.components.types.utils import ( + _add_schema, +) diff --git a/tidy3d/components/validators.py b/tidy3d/components/validators.py index cb20bd84c1..2ed5f84fe3 100644 --- a/tidy3d/components/validators.py +++ b/tidy3d/components/validators.py @@ -1,55 +1,42 @@ -"""Defines various validation functions that get used to ensure inputs are legit""" +"""Compatibility shim for :mod:`tidy3d._common.components.validators`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any, TypeVar, Union import numpy as np -import pydantic.v1 as pydantic -from autograd.tracer import isbox - +from numpy.typing import NDArray +from pydantic import field_validator, model_validator + +from tidy3d._common.components.validators import ( + MIN_FREQUENCY, + FloatArray, + _assert_min_freq, + _warn_unsupported_traced_argument, + validate_name_str, + warn_if_dataset_none, +) +from tidy3d.components.data.data_array import DATA_ARRAY_MAP +from tidy3d.components.geometry.base import Box from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log -from .autograd.utils import get_static -from .base import DATA_ARRAY_MAP, skip_if_fields_missing -from .data.dataset import Dataset, FieldDataset -from .geometry.base import Box - -""" Explanation of pydantic validators: - - Validators are class methods that are added to the models to validate their fields 
(kwargs). - The functions on this page return validators based on config arguments - and are generally in multiple components of tidy3d. - The inner functions (validators) are decorated with @pydantic.validator, which is configured. - First argument is the string of the field being validated in the model. - ``allow_reuse`` lets us use the validator in more than one model. - ``always`` makes sure if the model is changed, the validator gets called again. - - The function being decorated by @pydantic.validator generally takes - ``cls`` the class that the validator is added to. - ``val`` the value of the field being validated. - ``values`` a dictionary containing all of the other fields of the model. - It is important to note that the validator only has access to fields that are defined - before the field being validated. - Fields defined under the validated field will not be in ``values``. - - All validators generally should throw an exception if the validation fails - and return val if it passes. - Sometimes, we can use validators to change ``val`` or ``values``, - but this should be done with caution as it can be hard to reason about. - - To add a validator from this file to the pydantic model, - put it in the model's main body and assign it to a variable (class method). - For example ``_plane_validator = assert_plane()``. - Note, if the assigned name ``_plane_validator`` is used later on for another validator, say, - the original validator will be overwritten so be aware of this. +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Callable, Optional - For more details: `Pydantic Validators `_ -""" + from pydantic import FieldValidationInfo -# Lowest frequency supported (Hz) -MIN_FREQUENCY = 1e5 + from tidy3d import Simulation + from tidy3d._common.components.validators import T + from tidy3d.components.base_sim.simulation import AbstractSimulation + from tidy3d.components.data.monitor_data import AbstractFieldData + from tidy3d.components.types import FreqArray + from tidy3d.plugins.smatrix import AbstractComponentModeler def named_obj_descr(obj: Any, field_name: str, position_index: int) -> str: @@ -60,11 +47,12 @@ def named_obj_descr(obj: Any, field_name: str, position_index: int) -> str: return descr -def assert_line(): +def assert_line() -> Callable[[type, tuple[float, ...]], tuple[float, ...]]: """makes sure a field's ``size`` attribute has exactly 2 zeros""" - @pydantic.validator("size", allow_reuse=True, always=True) - def is_line(cls, val): + @field_validator("size") + @classmethod + def is_line(cls: type, val: tuple[float, ...]) -> tuple[float, ...]: """Raise validation error if not 1 dimensional.""" if val.count(0.0) != 2: raise ValidationError(f"'{cls.__name__}' object must be a line, given size={val}") @@ -73,11 +61,12 @@ def is_line(cls, val): return is_line -def assert_plane(): +def assert_plane() -> Callable[[type, tuple[float, ...]], tuple[float, ...]]: """makes sure a field's ``size`` attribute has exactly 1 zero""" - @pydantic.validator("size", allow_reuse=True, always=True) - def is_plane(cls, val): + @field_validator("size") + @classmethod + def is_plane(cls: type, val: tuple[float, ...]) -> tuple[float, ...]: """Raise validation error if not planar.""" if val.count(0.0) != 1: raise ValidationError(f"'{cls.__name__}' object must be planar, given size={val}") @@ -86,11 +75,12 @@ def is_plane(cls, val): return is_plane -def assert_line_or_plane(): +def assert_line_or_plane() -> Callable[[type, tuple[float, ...]], tuple[float, ...]]: """makes 
sure a field's ``size`` attribute has either 1 or 2 zeros""" - @pydantic.validator("size", allow_reuse=True, always=True) - def is_line_or_plane(cls, val): + @field_validator("size") + @classmethod + def is_line_or_plane(cls: type, val: tuple[float, ...]) -> tuple[float, ...]: """Raise validation error if not a line or plane.""" if val.count(0.0) == 0 or val.count(0.0) == 3: raise ValidationError( @@ -101,11 +91,12 @@ def is_line_or_plane(cls, val): return is_line_or_plane -def assert_volumetric(): +def assert_volumetric() -> Callable[[type, tuple[float, ...]], tuple[float, ...]]: """makes sure a field's ``size`` attribute has no zero entry""" - @pydantic.validator("size", allow_reuse=True, always=True) - def is_volumetric(cls, val): + @field_validator("size") + @classmethod + def is_volumetric(cls: type, val: tuple[float, ...]) -> tuple[float, ...]: """Raise validation error if volume is 0.""" if val.count(0.0) > 0: raise ValidationError( @@ -118,47 +109,39 @@ def is_volumetric(cls, val): return is_volumetric -def validate_name_str(): - """make sure the name does not include [, ] (used for default names)""" - - @pydantic.validator("name", allow_reuse=True, always=True, pre=True) - def field_has_unique_names(cls, val): - """raise exception if '[' or ']' in name""" - # if val and ('[' in val or ']' in val): - # raise SetupError(f"'[' or ']' not allowed in name: {val} (used for defaults)") - return val - - return field_has_unique_names - - -def validate_unique(field_name: str): +def validate_unique( + *field_names: str, +) -> Callable[[type, Sequence[Any], FieldValidationInfo], Sequence[Any]]: """Make sure the given field has unique entries.""" - @pydantic.validator(field_name, always=True, allow_reuse=True) - def field_has_unique_entries(cls, val): + @field_validator(*field_names) + @classmethod + def field_has_unique_entries( + cls: type, val: Sequence[Any], info: FieldValidationInfo + ) -> Sequence[Any]: """Check if the field has unique entries.""" if len(set(val)) != len(val): - raise SetupError(f"Entries of '{field_name}' must be unique.") + raise SetupError(f"Entries of '{info.field_name}' must be unique.") return val return field_has_unique_entries -def validate_mode_objects_symmetry(field_name: str): +def validate_mode_objects_symmetry(field_name: str) -> Callable[[T], T]: """If a Mode object, this checks that the object is fully in the main quadrant in the presence of symmetry along a given axis, or else centered on the symmetry center.""" obj_type = "ModeSource" if field_name == "sources" else "ModeMonitor" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["center", "symmetry"]) - def check_symmetry(cls, val, values): + @model_validator(mode="after") + def check_symmetry(self: T) -> T: """check for intersection of each structure with simulation bounds.""" - sim_center = values.get("center") + val: Sequence[Any] = getattr(self, field_name) + sim_center = self.center for position_index, geometric_object in enumerate(val): if geometric_object.type == obj_type: bounds_min, _ = geometric_object.bounds - for dim, sym in enumerate(values.get("symmetry")): + for dim, sym in enumerate(self.symmetry): if ( sym != 0 and bounds_min[dim] < sim_center[dim] @@ -170,21 +153,26 @@ def check_symmetry(cls, val, values): "quadrant, or centered on the symmetry axis." 
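# A minimal sketch of the validator-factory pattern used above, with a
# hypothetical 'tags' field: in pydantic v2, 'info.field_name' replaces the v1
# habit of baking the field name into the closure, so one factory can guard
# several fields at once.
#
#     from pydantic import BaseModel, FieldValidationInfo, field_validator
#
#     def validate_unique(*field_names: str):
#         @field_validator(*field_names)
#         @classmethod
#         def _unique(cls, val, info: FieldValidationInfo):
#             if len(set(val)) != len(val):
#                 raise ValueError(f"Entries of '{info.field_name}' must be unique.")
#             return val
#
#         return _unique
#
#     class Model(BaseModel):
#         tags: tuple[str, ...] = ()
#
#         _tags_unique = validate_unique("tags")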
) - return val + return self return check_symmetry -def assert_unique_names(field_name: str): +def assert_unique_names( + *field_names: str, +) -> Callable[[type, Sequence[Any], FieldValidationInfo], Sequence[Any]]: """makes sure all elements of a field have unique .name values""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - def field_has_unique_names(cls, val, values): + @field_validator(*field_names) + @classmethod + def field_has_unique_names( + cls: type, val: Sequence[Any], info: FieldValidationInfo + ) -> Sequence[Any]: """make sure each element of val has a unique name (if specified).""" field_names = [field.name for field in val if field.name] unique_names = set(field_names) if len(unique_names) != len(field_names): - raise SetupError(f"'{field_name}' names are not unique, given {field_names}.") + raise SetupError(f"'{info.field_name}' names are not unique, given {field_names}.") return val return field_has_unique_names @@ -192,19 +180,19 @@ def field_has_unique_names(cls, val, values): def assert_objects_in_sim_bounds( field_name: str, error: bool = True, strict_inequality: bool = False -): +) -> Callable[[AbstractSimulation], AbstractSimulation]: """Makes sure all objects in field are at least partially inside of simulation bounds.""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["center", "size"]) - def objects_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def objects_in_sim_bounds(self: AbstractSimulation) -> AbstractSimulation: """check for intersection of each structure with simulation bounds.""" - sim_center = values.get("center") - sim_size = values.get("size") + val: Sequence[Any] = getattr(self, field_name) + sim_center = self.center + sim_size = self.size sim_box = Box(size=sim_size, center=sim_center) # Do a strict check, unless simulation is 0D along a dimension - strict_ineq = [size != 0 and strict_inequality for size in sim_size] + strict_ineq: list[bool] = [size != 0 and strict_inequality for size in sim_size] with log as consolidated_logger: for position_index, geometric_object in enumerate(val): @@ -216,7 +204,7 @@ def objects_in_sim_bounds(cls, val, values): raise SetupError(message) consolidated_logger.warning(message, custom_loc=custom_loc) - return val + return self return objects_in_sim_bounds @@ -226,19 +214,19 @@ def assert_objects_contained_in_sim_bounds( error: bool = True, strict_inequality: bool = False, strict_for_zero_size_dim: bool = False, -): +) -> Callable[[Simulation], Simulation]: """Makes sure all objects in field are completely inside the simulation bounds.""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["center", "size"]) - def objects_contained_in_sim_bounds(cls, val, values): + @model_validator(mode="after") + def objects_contained_in_sim_bounds(self: Simulation) -> Simulation: """check for containment of each structure with simulation bounds.""" - sim_center = values.get("center") - sim_size = values.get("size") + val: Sequence[Any] = getattr(self, field_name) + sim_center = self.center + sim_size = self.size sim_box = Box(size=sim_size, center=sim_center) # Do a strict check, unless simulation is 0D along a dimension - strict_ineq = [size != 0 and strict_inequality for size in sim_size] + strict_ineq: list[bool] = [size != 0 and strict_inequality for size in sim_size] with log as consolidated_logger: for position_index, geometric_object in enumerate(val): geo_strict_ineq = list(strict_ineq) @@ -257,90 
+245,75 @@ def objects_contained_in_sim_bounds(cls, val, values): raise SetupError(message) consolidated_logger.warning(message, custom_loc=custom_loc) - return val + return self return objects_contained_in_sim_bounds -def enforce_monitor_fields_present(): +def enforce_monitor_fields_present() -> Callable[[AbstractFieldData], AbstractFieldData]: """Make sure all of the fields in the monitor are present in the corresponding data.""" - @pydantic.root_validator(skip_on_failure=True, allow_reuse=True) - def _contains_fields(cls, values): + @model_validator(mode="after") + def _contains_fields(self: AbstractFieldData) -> AbstractFieldData: """Make sure the initially specified fields are here.""" - for field_name in values.get("monitor").fields: - if values.get(field_name) is None: + for field_name in self.monitor.fields: + if getattr(self, field_name) is None: raise SetupError(f"missing field {field_name}") - return values + return self return _contains_fields -def required_if_symmetry_present(field_name: str): +def required_if_symmetry_present(field_name: str) -> Callable[[T], T]: """Make a field required (not None) if any non-zero symmetry eigenvalue is present.""" - @pydantic.validator(field_name, allow_reuse=True, always=True) - @skip_if_fields_missing(["symmetry"]) - def _make_required(cls, val, values): + @model_validator(mode="after") + def _make_required(self: T) -> T: """Ensure val is not None if the symmetry is non-zero along any dimension.""" - symmetry = values.get("symmetry") + val = getattr(self, field_name) + symmetry = self.symmetry if any(sym_val != 0 for sym_val in symmetry) and val is None: raise SetupError(f"'{field_name}' must be provided if symmetry present.") - return val + return self return _make_required -def warn_if_dataset_none(field_name: str): - """Warn if a Dataset field has None in its dictionary.""" - - @pydantic.validator(field_name, pre=True, always=True, allow_reuse=True) - def _warn_if_none(cls, val: Dataset) -> Dataset: - """Warn if the DataArrays fail to load.""" - if isinstance(val, dict): - if any((v in DATA_ARRAY_MAP for _, v in val.items() if isinstance(v, str))): - log.warning(f"Loading {field_name} without data.", custom_loc=[field_name]) - return None - return val - - return _warn_if_none - - -def warn_backward_waist_distance(field_name: str): +def warn_backward_waist_distance(field_name: str) -> Callable[[T], T]: """Warn if a backward-propagating beam uses a non-zero waist distance.""" - @pydantic.root_validator(allow_reuse=True) - def _warn_backward_nonzero(cls, values): + @model_validator(mode="after") + def _warn_backward_nonzero(self: T) -> T: """Emit deprecation warning for backward propagation with non-zero waist.""" - direction = values.get("direction") + direction = self.direction if direction != "-": - return values - waist_value = values.get(field_name) + return self + waist_value = getattr(self, field_name) waist_array = np.atleast_1d(waist_value) if not np.all(np.isclose(waist_array, 0.0)): log.warning( - f"Behavior of {cls.__name__} with direction '-' and non-zero '{field_name}' will " + f"Behavior of {self.__class__.__name__} with direction '-' and non-zero '{field_name}' will " "change in version 2.11 to be consistent with upcoming beam overlap monitors and " "ports. Currently, the waist distance is interpreted w.r.t. the directed " "propagation axis, so switching 'direction' also switches the position of the " "waist in the global reference frame. 
In the future, the waist position will be " "defined such that it is the same for backward- and forward-propagating beams.", ) - return values + return self return _warn_backward_nonzero -def assert_single_freq_in_range(field_name: str): +def assert_single_freq_in_range(field_name: str) -> Callable[[T], T]: """Assert only one frequency supplied in source and it's in source time range.""" - @pydantic.validator(field_name, always=True, allow_reuse=True) - @skip_if_fields_missing(["source_time"]) - def _single_frequency_in_range(cls, val: FieldDataset, values: dict) -> FieldDataset: + @model_validator(mode="after") + def _single_frequency_in_range(self: T) -> T: """Assert only one frequency supplied and it's in source time range.""" + val = getattr(self, field_name, None) if val is None: - return val - source_time = values.get("source_time") + return self + source_time = self.source_time fmin, fmax = source_time.frequency_range() for name, scalar_field in val.field_components.items(): freqs = scalar_field.f @@ -355,7 +328,7 @@ def _single_frequency_in_range(cls, val: FieldDataset, values: dict) -> FieldDat f"'{field_name}.{name}' contains frequency: {freq:.2e} Hz, which is outside " f"of the 'source_time' frequency range [{fmin:.2e}-{fmax:.2e}] Hz." ) - return val + return self return _single_frequency_in_range @@ -364,16 +337,17 @@ def validate_parameter_perturbation( field_name: str, base_field_name: str, allowed_complex: bool = True, -): +) -> Callable[[type, Any, FieldValidationInfo], Any]: """Assert perturbations have a valid shape and data type.""" - @pydantic.validator(field_name, always=True, allow_reuse=True) - def _check_perturbed_val(cls, val, values): + @field_validator(field_name) + @classmethod + def _check_perturbed_val(cls: type, val: Any, info: FieldValidationInfo) -> Any: """Assert perturbations have a valid shape and data type.""" if val is not None: # get base values - base_values = values[base_field_name] + base_values = info.data[base_field_name] # check that shapes of base parameter and perturbations coincide if np.shape(base_values) != np.shape(val): @@ -396,20 +370,12 @@ def _check_perturbed_val(cls, val, values): return _check_perturbed_val -def _assert_min_freq(freqs, msg_start: str) -> None: - """Check if all ``freqs`` are above the minimum frequency.""" - if np.min(freqs) < MIN_FREQUENCY: - raise ValidationError( - f"{msg_start} must be no lower than {MIN_FREQUENCY:.0e} Hz. " - "Note that the unit of frequency is 'Hz'." 
 
 
-def _assert_min_freq(freqs, msg_start: str) -> None:
-    """Check if all ``freqs`` are above the minimum frequency."""
-    if np.min(freqs) < MIN_FREQUENCY:
-        raise ValidationError(
-            f"{msg_start} must be no lower than {MIN_FREQUENCY:.0e} Hz. "
-            "Note that the unit of frequency is 'Hz'."
-        )
-
-
-def validate_freqs_min():
+def validate_freqs_min() -> Callable[[type, FreqArray], FreqArray]:
     """Validate lower bound for monitor, and mode solver frequencies."""
 
-    @pydantic.validator("freqs", always=True, allow_reuse=True)
-    def freqs_lower_bound(cls, val):
+    @field_validator("freqs")
+    @classmethod
+    def freqs_lower_bound(cls: type, val: FreqArray) -> FreqArray:
         """Raise validation error if any of ``freqs`` is lower than ``MIN_FREQUENCY``."""
         _assert_min_freq(val, msg_start=f"All of '{cls.__name__}.freqs'")
         return val
@@ -417,11 +383,12 @@ def freqs_lower_bound(cls, val):
     return freqs_lower_bound
 
 
-def validate_freqs_not_empty():
+def validate_freqs_not_empty() -> Callable[[type, FreqArray], FreqArray]:
     """Validate that the array of frequencies is not empty."""
 
-    @pydantic.validator("freqs", always=True, allow_reuse=True)
-    def freqs_not_empty(cls, val):
+    @field_validator("freqs")
+    @classmethod
+    def freqs_not_empty(cls: type, val: FreqArray) -> FreqArray:
         """Raise validation error if ``freqs`` is an empty Tuple."""
         if len(val) == 0:
             raise ValidationError(f"'{cls.__name__}.freqs' cannot be empty (size 0).")
@@ -430,32 +397,15 @@ def freqs_not_empty(cls, val):
     return freqs_not_empty
 
 
-def validate_freqs_unique():
+def validate_freqs_unique() -> Callable[[type, FreqArray], FreqArray]:
     """Validate that the array of frequencies does not have duplicate entries."""
 
-    @pydantic.validator("freqs", always=True, allow_reuse=True)
-    def freqs_unique(cls, val):
+    @field_validator("freqs")
+    @classmethod
+    def freqs_unique(cls: type, val: FreqArray) -> FreqArray:
         """Raise validation error if ``freqs`` has duplicate entries."""
         if len(set(val)) != len(val):
             raise ValidationError(f"'{cls.__name__}.freqs' must not contain duplicate entries.")
         return val
 
     return freqs_unique
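These factories return an already-decorated validator precisely so the same check can be attached to many models by assignment in the class body, which pydantic v2 still supports for reused validators. A short sketch of that usage under the same assumptions (the `Monitor` model below is invented, and the `FreqArray` type is simplified to a tuple):

    from pydantic import BaseModel, field_validator


    def validate_not_empty():
        """Validator factory in the style used above."""

        @field_validator("freqs")
        @classmethod
        def freqs_not_empty(cls: type, val: tuple) -> tuple:
            if len(val) == 0:
                raise ValueError(f"'{cls.__name__}.freqs' cannot be empty (size 0).")
            return val

        return freqs_not_empty


    class Monitor(BaseModel):
        freqs: tuple[float, ...]

        # assigning the factory result inside the class body registers the validator
        _freqs_not_empty = validate_not_empty()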
-
-
-def _warn_unsupported_traced_argument(name: str):
-    @pydantic.validator(name, always=True, allow_reuse=True)
-    def _warn_traced_arg(cls, val, values):
-        if isbox(val):
-            log.warning(
-                f"Field '{name}' of '{cls.__name__}' received an autograd tracer "
-                f"(i.e., a value being tracked for automatic differentiation). "
-                f"Automatic differentiation through this field is unsupported, "
-                f"so the tracer has been converted to its static value. "
-                f"If you want to avoid this warning, you manually unbox the value "
-                f"using the 'autograd.tracer.getval' function before passing it to Tidy3D."
-            )
-            return get_static(val)
-        return val
-
-    return _warn_traced_arg
diff --git a/tidy3d/components/viz/__init__.py b/tidy3d/components/viz/__init__.py
index bbc533144f..3e677c6058 100644
--- a/tidy3d/components/viz/__init__.py
+++ b/tidy3d/components/viz/__init__.py
@@ -1,12 +1,33 @@
+"""Compatibility shim for :mod:`tidy3d._common.components.viz`."""
+
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
+
+# marked as migrated to _common
 from __future__ import annotations
 
-from .axes_utils import add_ax_if_none, equal_aspect, make_ax, set_default_labels_and_title
-from .descartes import Polygon, polygon_patch, polygon_path
-from .flex_style import apply_tidy3d_params, restore_matplotlib_rcparams
-from .plot_params import (
+from tidy3d._common.components.viz import (
+    ARROW_ALPHA,
+    ARROW_COLOR_ABSORBER,
+    ARROW_COLOR_MONITOR,
+    ARROW_COLOR_POLARIZATION,
+    ARROW_COLOR_SOURCE,
+    ARROW_LENGTH,
+    FLEXCOMPUTE_COLORS,
+    MATPLOTLIB_IMPORTED,
+    MEDIUM_CMAP,
+    PLOT_BUFFER,
+    STRUCTURE_EPS_CMAP,
+    STRUCTURE_EPS_CMAP_R,
+    STRUCTURE_HEAT_COND_CMAP,
     AbstractPlotParams,
     PathPlotParams,
     PlotParams,
+    Polygon,
+    VisualizationSpec,
+    add_ax_if_none,
+    arrow_style,
+    equal_aspect,
+    make_ax,
     plot_params_abc,
     plot_params_absorber,
     plot_params_bloch,
@@ -22,69 +43,10 @@
     plot_params_source,
     plot_params_structure,
     plot_params_symmetry,
+    plot_scene_3d,
+    plot_sim_3d,
+    polygon_patch,
+    polygon_path,
+    restore_matplotlib_rcparams,
+    set_default_labels_and_title,
 )
-from .plot_sim_3d import plot_scene_3d, plot_sim_3d
-from .styles import (
-    ARROW_ALPHA,
-    ARROW_COLOR_ABSORBER,
-    ARROW_COLOR_MONITOR,
-    ARROW_COLOR_POLARIZATION,
-    ARROW_COLOR_SOURCE,
-    ARROW_LENGTH,
-    FLEXCOMPUTE_COLORS,
-    MEDIUM_CMAP,
-    PLOT_BUFFER,
-    STRUCTURE_EPS_CMAP,
-    STRUCTURE_EPS_CMAP_R,
-    STRUCTURE_HEAT_COND_CMAP,
-    arrow_style,
-)
-from .visualization_spec import MATPLOTLIB_IMPORTED, VisualizationSpec
-
-apply_tidy3d_params()
-
-__all__ = [
-    "ARROW_ALPHA",
-    "ARROW_COLOR_ABSORBER",
-    "ARROW_COLOR_MONITOR",
-    "ARROW_COLOR_POLARIZATION",
-    "ARROW_COLOR_SOURCE",
-    "ARROW_LENGTH",
-    "FLEXCOMPUTE_COLORS",
-    "MATPLOTLIB_IMPORTED",
-    "MEDIUM_CMAP",
-    "PLOT_BUFFER",
-    "STRUCTURE_EPS_CMAP",
-    "STRUCTURE_EPS_CMAP_R",
-    "STRUCTURE_HEAT_COND_CMAP",
-    "AbstractPlotParams",
-    "PathPlotParams",
-    "PlotParams",
-    "Polygon",
-    "VisualizationSpec",
-    "add_ax_if_none",
-    "arrow_style",
-    "equal_aspect",
-    "make_ax",
-    "plot_params_abc",
-    "plot_params_absorber",
-    "plot_params_bloch",
-    "plot_params_fluid",
-    "plot_params_geometry",
-    "plot_params_grid",
-    "plot_params_lumped_element",
-    "plot_params_monitor",
-    "plot_params_override_structures",
-    "plot_params_pec",
-    "plot_params_pmc",
-    "plot_params_pml",
-    "plot_params_source",
-    "plot_params_structure",
-    "plot_params_symmetry",
-    "plot_scene_3d",
-    "plot_sim_3d",
-    "polygon_patch",
-    "polygon_path",
-    "restore_matplotlib_rcparams",
-    "set_default_labels_and_title",
-]
diff --git a/tidy3d/components/viz/axes_utils.py b/tidy3d/components/viz/axes_utils.py
index 8fb78f699c..cdc5454a95 100644
--- a/tidy3d/components/viz/axes_utils.py
+++ b/tidy3d/components/viz/axes_utils.py
@@ -1,186 +1,14 @@
-from __future__ import annotations
-
-from functools import wraps
-from typing import Any, Optional
-
-from tidy3d.components.types import Ax, Axis, LengthUnit
-from tidy3d.constants import UnitScaling
-from tidy3d.exceptions import Tidy3dKeyError
-
-
-def _create_unit_aware_locator():
-    """Create UnitAwareLocator lazily due to matplotlib import restrictions."""
-    import
matplotlib.ticker as ticker - - class UnitAwareLocator(ticker.Locator): - """Custom tick locator that places ticks at nice positions in the target unit.""" - - def __init__(self, scale_factor: float) -> None: - """ - Parameters - ---------- - scale_factor : float - Factor to convert from micrometers to the target unit. - """ - super().__init__() - self.scale_factor = scale_factor - - def __call__(self): - vmin, vmax = self.axis.get_view_interval() - return self.tick_values(vmin, vmax) - - def view_limits(self, vmin, vmax): - """Override to prevent matplotlib from adjusting our limits.""" - return vmin, vmax - - def tick_values(self, vmin, vmax): - # convert the view range to the target unit - vmin_unit = vmin * self.scale_factor - vmax_unit = vmax * self.scale_factor - - # tolerance for floating point comparisons in target unit - unit_range = vmax_unit - vmin_unit - unit_tol = unit_range * 1e-8 - - locator = ticker.MaxNLocator(nbins=11, prune=None, min_n_ticks=2) - - ticks_unit = locator.tick_values(vmin_unit, vmax_unit) - - # ensure we have ticks that cover the full range - if len(ticks_unit) > 0: - if ticks_unit[0] > vmin_unit + unit_tol or ticks_unit[-1] < vmax_unit - unit_tol: - # try with fewer bins to get better coverage - for n in [10, 9, 8, 7, 6, 5]: - locator = ticker.MaxNLocator(nbins=n, prune=None, min_n_ticks=2) - ticks_unit = locator.tick_values(vmin_unit, vmax_unit) - if ( - len(ticks_unit) >= 3 - and ticks_unit[0] <= vmin_unit + unit_tol - and ticks_unit[-1] >= vmax_unit - unit_tol - ): - break - - # if still no good coverage, manually ensure edge coverage - if len(ticks_unit) > 0: - if ( - ticks_unit[0] > vmin_unit + unit_tol - or ticks_unit[-1] < vmax_unit - unit_tol - ): - # find a reasonable step size from existing ticks - if len(ticks_unit) > 1: - step = ticks_unit[1] - ticks_unit[0] - else: - step = unit_range / 5 - - # extend the range to ensure coverage - extended_min = vmin_unit - step - extended_max = vmax_unit + step - - # try one more time with extended range - locator = ticker.MaxNLocator(nbins=8, prune=None, min_n_ticks=2) - ticks_unit = locator.tick_values(extended_min, extended_max) - - # filter to reasonable bounds around the original range - ticks_unit = [ - t - for t in ticks_unit - if t >= vmin_unit - step / 2 and t <= vmax_unit + step / 2 - ] - - # convert the nice ticks back to the original data unit (micrometers) - ticks_um = ticks_unit / self.scale_factor - - # filter to ensure ticks are within bounds (with small tolerance) - eps = (vmax - vmin) * 1e-8 - return [tick for tick in ticks_um if vmin - eps <= tick <= vmax + eps] +"""Compatibility shim for :mod:`tidy3d._common.components.viz.axes_utils`.""" - return UnitAwareLocator +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -def make_ax() -> Ax: - """makes an empty ``ax``.""" - import matplotlib.pyplot as plt - - _, ax = plt.subplots(1, 1, tight_layout=True) - return ax - - -def add_ax_if_none(plot): - """Decorates ``plot(*args, **kwargs, ax=None)`` function. - if ax=None in the function call, creates an ax and feeds it to rest of function. - """ - - @wraps(plot) - def _plot(*args: Any, **kwargs: Any) -> Ax: - """New plot function using a generated ax if None.""" - if kwargs.get("ax") is None: - ax = make_ax() - kwargs["ax"] = ax - return plot(*args, **kwargs) - - return _plot - - -def equal_aspect(plot): - """Decorates a plotting function returning a matplotlib axes. 
- Ensures the aspect ratio of the returned axes is set to equal. - Useful for 2D plots, like sim.plot() or sim_data.plot_fields() - """ - - @wraps(plot) - def _plot(*args: Any, **kwargs: Any) -> Ax: - """New plot function with equal aspect ratio axes returned.""" - ax = plot(*args, **kwargs) - ax.set_aspect("equal") - return ax - - return _plot - - -def set_default_labels_and_title( - axis_labels: tuple[str, str], - axis: Axis, - position: float, - ax: Ax, - plot_length_units: Optional[LengthUnit] = None, -) -> Ax: - """Adds axis labels and title to plots involving spatial dimensions. - When the ``plot_length_units`` are specified, the plot axes are scaled, and - the title and axis labels include the desired units. - """ - - import matplotlib.ticker as ticker - - xlabel = axis_labels[0] - ylabel = axis_labels[1] - if plot_length_units is not None: - if plot_length_units not in UnitScaling: - raise Tidy3dKeyError( - f"Provided units '{plot_length_units}' are not supported. " - f"Please choose one of '{LengthUnit}'." - ) - ax.set_xlabel(f"{xlabel} ({plot_length_units})") - ax.set_ylabel(f"{ylabel} ({plot_length_units})") - - scale_factor = UnitScaling[plot_length_units] - - # for imperial units, use custom tick locator for nice tick positions - if plot_length_units in ["mil", "in"]: - UnitAwareLocator = _create_unit_aware_locator() - x_locator = UnitAwareLocator(scale_factor) - y_locator = UnitAwareLocator(scale_factor) - ax.xaxis.set_major_locator(x_locator) - ax.yaxis.set_major_locator(y_locator) - - formatter = ticker.FuncFormatter(lambda y, _: f"{y * scale_factor:.2f}") - - ax.xaxis.set_major_formatter(formatter) - ax.yaxis.set_major_formatter(formatter) - - position_scaled = position * scale_factor - ax.set_title(f"cross section at {'xyz'[axis]}={position_scaled:.2f} ({plot_length_units})") - else: - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(f"cross section at {'xyz'[axis]}={position:.2f}") - return ax +from tidy3d._common.components.viz.axes_utils import ( + _create_unit_aware_locator, + add_ax_if_none, + equal_aspect, + make_ax, + set_default_labels_and_title, +) diff --git a/tidy3d/components/viz/descartes.py b/tidy3d/components/viz/descartes.py index b743839585..a1b2f54fc2 100644 --- a/tidy3d/components/viz/descartes.py +++ b/tidy3d/components/viz/descartes.py @@ -1,109 +1,12 @@ -"""================================================================================================= -Descartes modified from https://pypi.org/project/descartes/ for Shapely >= 1.8.0 +"""Compatibility shim for :mod:`tidy3d._common.components.viz.descartes`.""" -Copyright Flexcompute 2022 - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR -CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER -IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT -OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -from typing import Any - -try: - from matplotlib.patches import PathPatch - from matplotlib.path import Path -except ImportError: - pass -from numpy import array, concatenate, ones - - -class Polygon: - """Adapt Shapely polygons to a common interface""" - - def __init__(self, context) -> None: - if isinstance(context, dict): - self.context = context["coordinates"] - else: - self.context = context - - @property - def exterior(self): - """Get polygon exterior.""" - value = getattr(self.context, "exterior", None) - if value is None: - value = self.context[0] - return value - - @property - def interiors(self): - """Get polygon interiors.""" - value = getattr(self.context, "interiors", None) - if value is None: - value = self.context[1:] - return value - - -def polygon_path(polygon): - """Constructs a compound matplotlib path from a Shapely or GeoJSON-like - geometric object""" - - def coding(obj): - # The codes will be all "LINETO" commands, except for "MOVETO"s at the - # beginning of each subpath - crds = getattr(obj, "coords", None) - if crds is None: - crds = obj - n = len(crds) - vals = ones(n, dtype=Path.code_type) * Path.LINETO - if len(vals) > 0: - vals[0] = Path.MOVETO - return vals - - ptype = polygon.geom_type - if ptype == "Polygon": - polygon = [Polygon(polygon)] - elif ptype == "MultiPolygon": - polygon = [Polygon(p) for p in polygon.geoms] - - vertices = concatenate( - [ - concatenate( - [array(t.exterior.coords)[:, :2]] + [array(r.coords)[:, :2] for r in t.interiors] - ) - for t in polygon - ] - ) - codes = concatenate( - [concatenate([coding(t.exterior)] + [coding(r) for r in t.interiors]) for t in polygon] - ) - - return Path(vertices, codes) - - -def polygon_patch(polygon, **kwargs: Any): - """Constructs a matplotlib patch from a geometric object - - The ``polygon`` may be a Shapely or GeoJSON-like object with or without holes. - The ``kwargs`` are those supported by the matplotlib.patches.Polygon class - constructor. Returns an instance of matplotlib.patches.PathPatch. 
- - Example - ------- - >>> b = Point(0, 0).buffer(1.0) # doctest: +SKIP - >>> patch = PolygonPatch(b, fc='blue', ec='blue', alpha=0.5) # doctest: +SKIP - >>> axis.add_patch(patch) # doctest: +SKIP - - """ - return PathPatch(polygon_path(polygon), **kwargs) - - -"""End descartes modification -=================================================================================================""" +from tidy3d._common.components.viz.descartes import ( + Polygon, + polygon_patch, + polygon_path, +) diff --git a/tidy3d/components/viz/flex_color_palettes.py b/tidy3d/components/viz/flex_color_palettes.py index 7fc1454a0b..0b80b28bef 100644 --- a/tidy3d/components/viz/flex_color_palettes.py +++ b/tidy3d/components/viz/flex_color_palettes.py @@ -1,3306 +1,12 @@ +"""Compatibility shim for :mod:`tidy3d._common.components.viz.flex_color_palettes`.""" + +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility + +# marked as migrated to _common from __future__ import annotations -SEQUENTIAL_PALETTES_HEX = { - "flex_turquoise_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfc", - "#fbfbfb", - "#fafbfa", - "#f9fafa", - "#f8f9f9", - "#f7f8f8", - "#f6f7f7", - "#f5f6f6", - "#f3f5f5", - "#f2f4f4", - "#f1f3f3", - "#f0f3f2", - "#eff2f1", - "#eef1f1", - "#edf0f0", - "#ecefef", - "#ebeeee", - "#eaeded", - "#e9edec", - "#e8eceb", - "#e7ebeb", - "#e6eaea", - "#e5e9e9", - "#e4e8e8", - "#e3e7e7", - "#e2e7e6", - "#e1e6e5", - "#e0e5e5", - "#dfe4e4", - "#dee3e3", - "#dde2e2", - "#dce2e1", - "#dbe1e0", - "#dae0df", - "#d9dfdf", - "#d8dede", - "#d7dedd", - "#d6dddc", - "#d5dcdb", - "#d4dbdb", - "#d3dada", - "#d2dad9", - "#d1d9d8", - "#d1d8d7", - "#d0d7d6", - "#cfd6d6", - "#ced6d5", - "#cdd5d4", - "#ccd4d3", - "#cbd3d2", - "#cad2d2", - "#c9d2d1", - "#c8d1d0", - "#c7d0cf", - "#c6cfce", - "#c5cece", - "#c4cecd", - "#c3cdcc", - "#c2cccb", - "#c1cbca", - "#c0cbca", - "#bfcac9", - "#bec9c8", - "#bec8c7", - "#bdc8c7", - "#bcc7c6", - "#bbc6c5", - "#bac5c4", - "#b9c5c3", - "#b8c4c3", - "#b7c3c2", - "#b6c2c1", - "#b5c2c0", - "#b4c1c0", - "#b3c0bf", - "#b2bfbe", - "#b2bfbd", - "#b1bebd", - "#b0bdbc", - "#afbcbb", - "#aebcba", - "#adbbba", - "#acbab9", - "#abbab8", - "#aab9b7", - "#a9b8b7", - "#a9b7b6", - "#a8b7b5", - "#a7b6b4", - "#a6b5b4", - "#a5b4b3", - "#a4b4b2", - "#a3b3b2", - "#a2b2b1", - "#a1b2b0", - "#a1b1af", - "#a0b0af", - "#9fb0ae", - "#9eafad", - "#9daeac", - "#9cadac", - "#9badab", - "#9aacaa", - "#99abaa", - "#99aba9", - "#98aaa8", - "#97a9a7", - "#96a9a7", - "#95a8a6", - "#94a7a5", - "#93a6a5", - "#92a6a4", - "#92a5a3", - "#91a4a2", - "#90a4a2", - "#8fa3a1", - "#8ea2a0", - "#8da2a0", - "#8ca19f", - "#8ca09e", - "#8ba09e", - "#8a9f9d", - "#899e9c", - "#889e9c", - "#879d9b", - "#869c9a", - "#869c9a", - "#859b99", - "#849a98", - "#839a97", - "#829997", - "#819896", - "#809895", - "#809795", - "#7f9694", - "#7e9693", - "#7d9593", - "#7c9492", - "#7b9491", - "#7a9391", - "#7a9290", - "#79928f", - "#78918f", - "#77908e", - "#76908d", - "#758f8d", - "#758f8c", - "#748e8b", - "#738d8b", - "#728d8a", - "#718c89", - "#708b89", - "#708b88", - "#6f8a87", - "#6e8987", - "#6d8986", - "#6c8885", - "#6b8885", - "#6a8784", - "#6a8684", - "#698683", - "#688582", - "#678482", - "#668481", - "#658380", - "#658280", - "#64827f", - "#63817e", - "#62817e", - "#61807d", - "#607f7c", - "#607f7c", - "#5f7e7b", - "#5e7d7b", - "#5d7d7a", - "#5c7c79", - "#5b7c79", - "#5b7b78", - "#5a7a77", - "#597a77", - "#587976", - "#577975", - "#567875", - "#567774", - "#557774", - "#547673", - "#537572", - "#527572", - "#517471", - 
"#507470", - "#507370", - "#4f726f", - "#4e726f", - "#4d716e", - "#4c716d", - "#4b706d", - "#4b6f6c", - "#4a6f6b", - "#496e6b", - "#486e6a", - "#476d6a", - "#466c69", - "#456c68", - "#446b68", - "#446b67", - "#436a67", - "#426966", - "#416965", - "#406865", - "#3f6864", - "#3e6763", - "#3e6663", - "#3d6662", - "#3c6562", - "#3b6561", - "#3a6460", - "#396360", - "#38635f", - "#37625f", - "#36625e", - "#35615d", - "#35605d", - "#34605c", - "#335f5c", - "#325f5b", - "#315e5a", - "#305d5a", - "#2f5d59", - "#2e5c58", - "#2d5c58", - "#2c5b57", - "#2b5a57", - "#2a5a56", - "#295955", - "#285955", - "#275854", - "#265754", - "#255753", - "#245652", - "#235652", - "#225551", - "#215551", - "#205450", - "#1e534f", - "#1d534f", - "#1c524e", - "#1b524e", - "#1a514d", - "#18504c", - "#17504c", - "#164f4b", - "#144f4b", - "#134e4a", - ], - "flex_green_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfc", - "#fbfbfb", - "#f9fafa", - "#f8f9f9", - "#f7f8f8", - "#f6f7f7", - "#f5f6f6", - "#f4f5f5", - "#f3f5f3", - "#f2f4f2", - "#f1f3f1", - "#f0f2f0", - "#eff1ef", - "#eef0ee", - "#ecefed", - "#ebeeec", - "#eaedeb", - "#e9ecea", - "#e8ebe9", - "#e7eae8", - "#e6eae7", - "#e5e9e6", - "#e4e8e5", - "#e3e7e4", - "#e2e6e3", - "#e1e5e2", - "#e0e4e1", - "#dfe3e0", - "#dee3df", - "#dde2de", - "#dce1dd", - "#dbe0dc", - "#dadfdb", - "#d8deda", - "#d7ddd9", - "#d6dcd8", - "#d5dcd7", - "#d4dbd6", - "#d3dad5", - "#d2d9d4", - "#d1d8d3", - "#d0d7d2", - "#cfd6d1", - "#ced6d0", - "#cdd5cf", - "#ccd4ce", - "#cbd3ce", - "#cad2cd", - "#c9d1cc", - "#c8d1cb", - "#c7d0ca", - "#c6cfc9", - "#c5cec8", - "#c4cdc7", - "#c3cdc6", - "#c2ccc5", - "#c1cbc4", - "#c0cac3", - "#bfc9c2", - "#bec9c1", - "#bdc8c0", - "#bcc7bf", - "#bbc6be", - "#bac5bd", - "#b9c5bd", - "#b9c4bc", - "#b8c3bb", - "#b7c2ba", - "#b6c1b9", - "#b5c1b8", - "#b4c0b7", - "#b3bfb6", - "#b2beb5", - "#b1bdb4", - "#b0bdb3", - "#afbcb2", - "#aebbb1", - "#adbab1", - "#acbab0", - "#abb9af", - "#aab8ae", - "#a9b7ad", - "#a8b7ac", - "#a7b6ab", - "#a6b5aa", - "#a5b4a9", - "#a5b4a8", - "#a4b3a8", - "#a3b2a7", - "#a2b1a6", - "#a1b1a5", - "#a0b0a4", - "#9fafa3", - "#9eaea2", - "#9daea1", - "#9cada0", - "#9baca0", - "#9aab9f", - "#99ab9e", - "#99aa9d", - "#98a99c", - "#97a89b", - "#96a89a", - "#95a799", - "#94a699", - "#93a598", - "#92a597", - "#91a496", - "#90a395", - "#90a394", - "#8fa293", - "#8ea193", - "#8da092", - "#8ca091", - "#8b9f90", - "#8a9e8f", - "#899e8e", - "#889d8d", - "#879c8d", - "#879b8c", - "#869b8b", - "#859a8a", - "#849989", - "#839988", - "#829888", - "#819787", - "#809786", - "#809685", - "#7f9584", - "#7e9483", - "#7d9483", - "#7c9382", - "#7b9281", - "#7a9280", - "#79917f", - "#79907e", - "#78907e", - "#778f7d", - "#768e7c", - "#758e7b", - "#748d7a", - "#738c79", - "#728c79", - "#728b78", - "#718a77", - "#708a76", - "#6f8975", - "#6e8875", - "#6d8774", - "#6c8773", - "#6c8672", - "#6b8571", - "#6a8571", - "#698470", - "#68836f", - "#67836e", - "#66826d", - "#66816d", - "#65816c", - "#64806b", - "#637f6a", - "#627f69", - "#617e69", - "#607d68", - "#607d67", - "#5f7c66", - "#5e7c65", - "#5d7b65", - "#5c7a64", - "#5b7a63", - "#5a7962", - "#5a7861", - "#597861", - "#587760", - "#57765f", - "#56765e", - "#55755d", - "#55745d", - "#54745c", - "#53735b", - "#52725a", - "#51725a", - "#507159", - "#4f7058", - "#4f7057", - "#4e6f56", - "#4d6e56", - "#4c6e55", - "#4b6d54", - "#4a6d53", - "#4a6c53", - "#496b52", - "#486b51", - "#476a50", - "#466950", - "#45694f", - "#44684e", - "#44674d", - "#43674c", - "#42664c", - "#41654b", - "#40654a", - "#3f6449", - "#3e6449", - 
"#3e6348", - "#3d6247", - "#3c6246", - "#3b6146", - "#3a6045", - "#396044", - "#385f43", - "#385e43", - "#375e42", - "#365d41", - "#355c40", - "#345c40", - "#335b3f", - "#325b3e", - "#315a3d", - "#30593d", - "#30593c", - "#2f583b", - "#2e573a", - "#2d573a", - "#2c5639", - "#2b5538", - "#2a5537", - "#295437", - "#285436", - "#275335", - "#265234", - "#265234", - "#255133", - "#245032", - "#235031", - "#224f31", - "#214e30", - "#204e2f", - "#1f4d2e", - "#1e4c2e", - "#1d4c2d", - "#1c4b2c", - "#1b4b2b", - "#1a4a2b", - "#18492a", - "#174929", - "#164828", - "#154728", - "#144727", - "#134626", - "#124525", - "#104525", - "#0f4424", - ], - "flex_blue_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fbfcfc", - "#fafbfb", - "#f9fafa", - "#f8f9f9", - "#f7f7f8", - "#f6f6f8", - "#f4f5f7", - "#f3f4f6", - "#f2f3f5", - "#f1f2f4", - "#f0f1f3", - "#eff0f2", - "#eeeff1", - "#eceef0", - "#ebedf0", - "#eaecef", - "#e9ebee", - "#e8eaed", - "#e7e9ec", - "#e6e8eb", - "#e4e7ea", - "#e3e6ea", - "#e2e5e9", - "#e1e4e8", - "#e0e3e7", - "#dfe2e6", - "#dee1e5", - "#dde0e5", - "#dcdfe4", - "#dadee3", - "#d9dde2", - "#d8dce1", - "#d7dbe0", - "#d6dae0", - "#d5d9df", - "#d4d8de", - "#d3d7dd", - "#d2d6dc", - "#d1d5dc", - "#d0d4db", - "#cfd3da", - "#ced2d9", - "#ccd1d9", - "#cbd0d8", - "#cacfd7", - "#c9ced6", - "#c8cdd5", - "#c7ccd5", - "#c6ccd4", - "#c5cbd3", - "#c4cad2", - "#c3c9d2", - "#c2c8d1", - "#c1c7d0", - "#c0c6cf", - "#bfc5cf", - "#bec4ce", - "#bdc3cd", - "#bcc2cd", - "#bbc1cc", - "#bac0cb", - "#b9c0ca", - "#b8bfca", - "#b7bec9", - "#b6bdc8", - "#b5bcc7", - "#b4bbc7", - "#b3bac6", - "#b2b9c5", - "#b1b8c5", - "#b0b7c4", - "#afb7c3", - "#aeb6c3", - "#adb5c2", - "#acb4c1", - "#abb3c1", - "#aab2c0", - "#a9b1bf", - "#a8b0be", - "#a7b0be", - "#a6afbd", - "#a5aebc", - "#a4adbc", - "#a3acbb", - "#a2abba", - "#a1aaba", - "#a0aab9", - "#9fa9b8", - "#9ea8b8", - "#9da7b7", - "#9ca6b7", - "#9ba5b6", - "#9aa4b5", - "#99a4b5", - "#98a3b4", - "#97a2b3", - "#96a1b3", - "#95a0b2", - "#949fb1", - "#939fb1", - "#929eb0", - "#919db0", - "#909caf", - "#8f9bae", - "#8f9aae", - "#8e9aad", - "#8d99ac", - "#8c98ac", - "#8b97ab", - "#8a96ab", - "#8996aa", - "#8895a9", - "#8794a9", - "#8693a8", - "#8592a8", - "#8492a7", - "#8391a6", - "#8290a6", - "#818fa5", - "#818ea5", - "#808ea4", - "#7f8da3", - "#7e8ca3", - "#7d8ba2", - "#7c8aa2", - "#7b8aa1", - "#7a89a1", - "#7988a0", - "#78879f", - "#77869f", - "#76869e", - "#76859e", - "#75849d", - "#74839d", - "#73829c", - "#72829b", - "#71819b", - "#70809a", - "#6f7f9a", - "#6e7f99", - "#6d7e99", - "#6c7d98", - "#6c7c98", - "#6b7c97", - "#6a7b97", - "#697a96", - "#687995", - "#677895", - "#667894", - "#657794", - "#647693", - "#637593", - "#637592", - "#627492", - "#617391", - "#607291", - "#5f7290", - "#5e7190", - "#5d708f", - "#5c6f8f", - "#5b6f8e", - "#5b6e8e", - "#5a6d8d", - "#596c8c", - "#586b8c", - "#576b8b", - "#566a8b", - "#55698a", - "#54688a", - "#536889", - "#536789", - "#526688", - "#516588", - "#506587", - "#4f6487", - "#4e6386", - "#4d6286", - "#4c6285", - "#4b6185", - "#4a6084", - "#4a5f84", - "#495f83", - "#485e83", - "#475d83", - "#465d82", - "#455c82", - "#445b81", - "#435a81", - "#425a80", - "#425980", - "#41587f", - "#40577f", - "#3f577e", - "#3e567e", - "#3d557d", - "#3c547d", - "#3b547c", - "#3a537c", - "#39527b", - "#39517b", - "#38517b", - "#37507a", - "#364f7a", - "#354e79", - "#344e79", - "#334d78", - "#324c78", - "#314b77", - "#304b77", - "#2f4a76", - "#2e4976", - "#2d4876", - "#2c4875", - "#2c4775", - "#2b4674", - "#2a4574", - "#294473", - "#284473", - "#274373", - 
"#264272", - "#254172", - "#244171", - "#234071", - "#223f70", - "#213e70", - "#203e70", - "#1f3d6f", - "#1e3c6f", - "#1d3b6e", - "#1c3a6e", - "#1b3a6e", - "#1a396d", - "#19386d", - "#17376c", - "#16366c", - "#15366c", - "#14356b", - "#13346b", - "#12336b", - "#10326a", - "#0f326a", - "#0e316a", - "#0d3069", - "#0b2f69", - "#0a2e68", - "#082d68", - "#072c68", - "#062c68", - "#042b67", - "#032a67", - "#022967", - "#012866", - "#002766", - ], - "flex_orange_seq": [ - "#ffffff", - "#fefefe", - "#fefdfd", - "#fdfdfc", - "#fdfcfb", - "#fcfbfa", - "#fbfafa", - "#fbf9f9", - "#faf9f8", - "#faf8f7", - "#f9f7f6", - "#f8f6f5", - "#f8f6f4", - "#f7f5f3", - "#f7f4f2", - "#f6f3f1", - "#f5f2f1", - "#f5f2f0", - "#f4f1ef", - "#f3f0ee", - "#f3efed", - "#f2efec", - "#f2eeeb", - "#f1edea", - "#f1ece9", - "#f0ece8", - "#f0ebe7", - "#efeae6", - "#efe9e5", - "#eee9e4", - "#eee8e3", - "#ede7e2", - "#ede6e1", - "#ece5e0", - "#ece5df", - "#ebe4de", - "#ebe3dd", - "#eae2dc", - "#eae2db", - "#e9e1da", - "#e9e0d9", - "#e9dfd8", - "#e8dfd7", - "#e8ded6", - "#e7ddd5", - "#e7dcd4", - "#e6dbd3", - "#e6dbd2", - "#e6dad1", - "#e5d9d0", - "#e5d8cf", - "#e4d8ce", - "#e4d7cd", - "#e3d6cc", - "#e3d5cb", - "#e3d5ca", - "#e2d4c9", - "#e2d3c8", - "#e1d2c7", - "#e1d2c6", - "#e0d1c5", - "#e0d0c4", - "#e0cfc3", - "#dfcfc2", - "#dfcec1", - "#decdc0", - "#deccbf", - "#deccbe", - "#ddcbbd", - "#ddcabc", - "#dcc9bb", - "#dcc9ba", - "#dcc8b9", - "#dbc7b8", - "#dbc6b8", - "#dbc6b7", - "#dac5b6", - "#dac4b5", - "#d9c4b4", - "#d9c3b3", - "#d9c2b2", - "#d8c1b1", - "#d8c1b0", - "#d7c0af", - "#d7bfae", - "#d7bead", - "#d6beac", - "#d6bdab", - "#d6bcaa", - "#d5bba9", - "#d5bba8", - "#d4baa7", - "#d4b9a6", - "#d4b9a5", - "#d3b8a4", - "#d3b7a3", - "#d3b6a2", - "#d2b6a1", - "#d2b5a0", - "#d2b49f", - "#d1b49e", - "#d1b39d", - "#d0b29c", - "#d0b19b", - "#d0b19a", - "#cfb099", - "#cfaf99", - "#cfaf98", - "#ceae97", - "#cead96", - "#ceac95", - "#cdac94", - "#cdab93", - "#cdaa92", - "#ccaa91", - "#cca990", - "#cca88f", - "#cba78e", - "#cba78d", - "#caa68c", - "#caa58b", - "#caa58a", - "#c9a489", - "#c9a388", - "#c9a387", - "#c8a286", - "#c8a185", - "#c8a085", - "#c7a084", - "#c79f83", - "#c79e82", - "#c69e81", - "#c69d80", - "#c69c7f", - "#c59c7e", - "#c59b7d", - "#c59a7c", - "#c4997b", - "#c4997a", - "#c49879", - "#c39778", - "#c39777", - "#c29676", - "#c29575", - "#c29575", - "#c19474", - "#c19373", - "#c19372", - "#c09271", - "#c09170", - "#c0906f", - "#bf906e", - "#bf8f6d", - "#bf8e6c", - "#be8e6b", - "#be8d6a", - "#be8c69", - "#bd8c68", - "#bd8b67", - "#bd8a67", - "#bc8a66", - "#bc8965", - "#bc8864", - "#bb8863", - "#bb8762", - "#bb8661", - "#ba8660", - "#ba855f", - "#ba845e", - "#b9835d", - "#b9835c", - "#b8825b", - "#b8815b", - "#b8815a", - "#b78059", - "#b77f58", - "#b77f57", - "#b67e56", - "#b67d55", - "#b67d54", - "#b57c53", - "#b57b52", - "#b57b51", - "#b47a50", - "#b4794f", - "#b4794f", - "#b3784e", - "#b3774d", - "#b3774c", - "#b2764b", - "#b2754a", - "#b17549", - "#b17448", - "#b17347", - "#b07346", - "#b07245", - "#b07144", - "#af7144", - "#af7043", - "#af6f42", - "#ae6f41", - "#ae6e40", - "#ae6d3f", - "#ad6d3e", - "#ad6c3d", - "#ac6b3c", - "#ac6b3b", - "#ac6a3a", - "#ab6939", - "#ab6939", - "#ab6838", - "#aa6737", - "#aa6736", - "#aa6635", - "#a96534", - "#a96533", - "#a86432", - "#a86331", - "#a86330", - "#a7622f", - "#a7612e", - "#a7612d", - "#a6602c", - "#a65f2b", - "#a55f2a", - "#a55e2a", - "#a55d29", - "#a45d28", - "#a45c27", - "#a35b26", - "#a35b25", - "#a35a24", - "#a25923", - "#a25922", - "#a25821", - "#a15720", - "#a1571f", - 
"#a0561e", - "#a0551d", - "#a0551c", - "#9f541b", - "#9f531a", - "#9e5318", - "#9e5217", - "#9e5116", - "#9d5115", - "#9d5014", - "#9c4f13", - "#9c4f12", - "#9b4e10", - "#9b4d0f", - "#9b4d0e", - "#9a4c0c", - "#9a4b0b", - "#994b09", - "#994a08", - ], - "flex_red_seq": [ - "#ffffff", - "#fefefe", - "#fefdfd", - "#fdfcfc", - "#fcfbfb", - "#fcfafa", - "#fbf9f9", - "#faf8f8", - "#faf7f7", - "#f9f6f6", - "#f8f5f5", - "#f8f4f5", - "#f7f3f4", - "#f6f2f3", - "#f5f2f2", - "#f5f1f1", - "#f4f0f0", - "#f3efef", - "#f3eeee", - "#f2eded", - "#f1ecec", - "#f1ebec", - "#f0eaeb", - "#efe9ea", - "#efe8e9", - "#eee7e8", - "#eee6e7", - "#ede5e6", - "#ece4e6", - "#ece3e5", - "#ebe2e4", - "#ebe1e3", - "#eae0e2", - "#eae0e1", - "#e9dfe0", - "#e9dedf", - "#e8dddf", - "#e8dcde", - "#e7dbdd", - "#e7dadc", - "#e6d9db", - "#e6d8da", - "#e5d7d9", - "#e5d6d8", - "#e4d5d7", - "#e4d4d7", - "#e3d3d6", - "#e3d2d5", - "#e2d1d4", - "#e2d0d3", - "#e1d0d2", - "#e1cfd1", - "#e0ced1", - "#e0cdd0", - "#dfcccf", - "#dfcbce", - "#decacd", - "#dec9cc", - "#ddc8cb", - "#ddc7cb", - "#dcc6ca", - "#dcc5c9", - "#dbc4c8", - "#dbc4c7", - "#dbc3c6", - "#dac2c5", - "#dac1c5", - "#d9c0c4", - "#d9bfc3", - "#d8bec2", - "#d8bdc1", - "#d7bcc0", - "#d7bbc0", - "#d7babf", - "#d6babe", - "#d6b9bd", - "#d5b8bc", - "#d5b7bb", - "#d4b6bb", - "#d4b5ba", - "#d4b4b9", - "#d3b3b8", - "#d3b2b7", - "#d2b1b6", - "#d2b0b6", - "#d1b0b5", - "#d1afb4", - "#d1aeb3", - "#d0adb2", - "#d0acb1", - "#cfabb1", - "#cfaab0", - "#cfa9af", - "#cea8ae", - "#cea8ad", - "#cda7ad", - "#cda6ac", - "#cca5ab", - "#cca4aa", - "#cca3a9", - "#cba2a9", - "#cba1a8", - "#caa0a7", - "#caa0a6", - "#ca9fa5", - "#c99ea5", - "#c99da4", - "#c89ca3", - "#c89ba2", - "#c89aa1", - "#c799a1", - "#c799a0", - "#c6989f", - "#c6979e", - "#c6969d", - "#c5959d", - "#c5949c", - "#c4939b", - "#c4929a", - "#c49299", - "#c39199", - "#c39098", - "#c28f97", - "#c28e96", - "#c28d96", - "#c18c95", - "#c18c94", - "#c18b93", - "#c08a92", - "#c08992", - "#bf8891", - "#bf8790", - "#bf868f", - "#be858f", - "#be858e", - "#bd848d", - "#bd838c", - "#bd828c", - "#bc818b", - "#bc808a", - "#bb7f89", - "#bb7f88", - "#bb7e88", - "#ba7d87", - "#ba7c86", - "#b97b85", - "#b97a85", - "#b97984", - "#b87983", - "#b87882", - "#b87782", - "#b77681", - "#b77580", - "#b6747f", - "#b6737f", - "#b6737e", - "#b5727d", - "#b5717c", - "#b4707c", - "#b46f7b", - "#b46e7a", - "#b36d79", - "#b36d79", - "#b26c78", - "#b26b77", - "#b26a76", - "#b16976", - "#b16875", - "#b06774", - "#b06773", - "#b06673", - "#af6572", - "#af6471", - "#ae6371", - "#ae6270", - "#ae616f", - "#ad616e", - "#ad606e", - "#ac5f6d", - "#ac5e6c", - "#ab5d6b", - "#ab5c6b", - "#ab5b6a", - "#aa5b69", - "#aa5a69", - "#a95968", - "#a95867", - "#a95766", - "#a85666", - "#a85565", - "#a75464", - "#a75463", - "#a65363", - "#a65262", - "#a65161", - "#a55061", - "#a54f60", - "#a44e5f", - "#a44d5e", - "#a34d5e", - "#a34c5d", - "#a34b5c", - "#a24a5c", - "#a2495b", - "#a1485a", - "#a1475a", - "#a04659", - "#a04558", - "#9f4557", - "#9f4457", - "#9f4356", - "#9e4255", - "#9e4155", - "#9d4054", - "#9d3f53", - "#9c3e52", - "#9c3d52", - "#9b3c51", - "#9b3b50", - "#9a3b50", - "#9a3a4f", - "#99394e", - "#99384e", - "#98374d", - "#98364c", - "#98354b", - "#97344b", - "#97334a", - "#963249", - "#963149", - "#953048", - "#952f47", - "#942e47", - "#942d46", - "#932c45", - "#932b45", - "#922a44", - "#922943", - "#912843", - "#912742", - "#902641", - "#902540", - "#8f2440", - "#8e223f", - "#8e213e", - "#8d203e", - "#8d1f3d", - "#8c1e3c", - "#8c1d3c", - "#8b1b3b", - "#8b1a3a", - "#8a193a", - 
"#8a1739", - "#891638", - "#891438", - "#881337", - ], - "flex_purple_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfd", - "#fbfbfc", - "#fafafb", - "#f9f9fa", - "#f8f8f9", - "#f7f7f9", - "#f6f6f8", - "#f5f5f7", - "#f4f4f6", - "#f3f3f6", - "#f2f2f5", - "#f1f1f4", - "#f0f0f3", - "#efeff3", - "#eeeef2", - "#ededf1", - "#ececf0", - "#ebebf0", - "#eaeaef", - "#e9e9ee", - "#e8e8ed", - "#e8e8ed", - "#e7e7ec", - "#e6e6eb", - "#e5e5eb", - "#e4e4ea", - "#e3e3e9", - "#e2e2e8", - "#e1e1e8", - "#e0e0e7", - "#dfdfe6", - "#dedee6", - "#dddde5", - "#dcdce4", - "#dbdbe4", - "#dadae3", - "#d9dae2", - "#d9d9e2", - "#d8d8e1", - "#d7d7e0", - "#d6d6e0", - "#d5d5df", - "#d4d4df", - "#d3d3de", - "#d2d2dd", - "#d1d1dd", - "#d0d1dc", - "#d0d0db", - "#cfcfdb", - "#ceceda", - "#cdcdd9", - "#ccccd9", - "#cbcbd8", - "#cacad8", - "#c9c9d7", - "#c8c9d6", - "#c8c8d6", - "#c7c7d5", - "#c6c6d5", - "#c5c5d4", - "#c4c4d3", - "#c3c3d3", - "#c2c2d2", - "#c2c2d2", - "#c1c1d1", - "#c0c0d1", - "#bfbfd0", - "#bebecf", - "#bdbdcf", - "#bcbcce", - "#bcbcce", - "#bbbbcd", - "#babacd", - "#b9b9cc", - "#b8b8cc", - "#b7b7cb", - "#b7b7ca", - "#b6b6ca", - "#b5b5c9", - "#b4b4c9", - "#b3b3c8", - "#b3b2c8", - "#b2b2c7", - "#b1b1c7", - "#b0b0c6", - "#afafc6", - "#aeaec5", - "#aeadc5", - "#adadc4", - "#acacc4", - "#ababc3", - "#aaaac3", - "#aaa9c2", - "#a9a8c2", - "#a8a8c1", - "#a7a7c1", - "#a6a6c0", - "#a6a5c0", - "#a5a4bf", - "#a4a4bf", - "#a3a3be", - "#a3a2be", - "#a2a1bd", - "#a1a0bd", - "#a0a0bc", - "#9f9fbc", - "#9f9ebb", - "#9e9dbb", - "#9d9cba", - "#9c9cba", - "#9c9bb9", - "#9b9ab9", - "#9a99b8", - "#9998b8", - "#9998b8", - "#9897b7", - "#9796b7", - "#9695b6", - "#9694b6", - "#9594b5", - "#9493b5", - "#9392b4", - "#9391b4", - "#9291b4", - "#9190b3", - "#908fb3", - "#908eb2", - "#8f8db2", - "#8e8db1", - "#8d8cb1", - "#8d8bb1", - "#8c8ab0", - "#8b8ab0", - "#8a89af", - "#8a88af", - "#8987ae", - "#8886ae", - "#8886ae", - "#8785ad", - "#8684ad", - "#8583ac", - "#8583ac", - "#8482ac", - "#8381ab", - "#8280ab", - "#8280ab", - "#817faa", - "#807eaa", - "#807da9", - "#7f7ca9", - "#7e7ca9", - "#7e7ba8", - "#7d7aa8", - "#7c79a8", - "#7b79a7", - "#7b78a7", - "#7a77a6", - "#7976a6", - "#7976a6", - "#7875a5", - "#7774a5", - "#7773a5", - "#7673a4", - "#7572a4", - "#7571a4", - "#7470a3", - "#736fa3", - "#736fa3", - "#726ea2", - "#716da2", - "#716ca2", - "#706ca1", - "#6f6ba1", - "#6f6aa1", - "#6e69a0", - "#6d69a0", - "#6d68a0", - "#6c679f", - "#6b669f", - "#6b669f", - "#6a659e", - "#69649e", - "#69639e", - "#68629d", - "#67629d", - "#67619d", - "#66609d", - "#655f9c", - "#655f9c", - "#645e9c", - "#635d9b", - "#635c9b", - "#625c9b", - "#625b9b", - "#615a9a", - "#60599a", - "#60589a", - "#5f589a", - "#5e5799", - "#5e5699", - "#5d5599", - "#5d5498", - "#5c5498", - "#5b5398", - "#5b5298", - "#5a5198", - "#595097", - "#595097", - "#584f97", - "#584e97", - "#574d96", - "#564c96", - "#564c96", - "#554b96", - "#554a95", - "#544995", - "#544895", - "#534795", - "#524795", - "#524694", - "#514594", - "#514494", - "#504394", - "#504294", - "#4f4294", - "#4e4193", - "#4e4093", - "#4d3f93", - "#4d3e93", - "#4c3d93", - "#4c3c93", - "#4b3b93", - "#4b3a92", - "#4a3992", - "#4a3892", - "#493892", - "#493792", - "#483692", - "#483592", - "#473492", - "#473391", - "#463291", - "#463191", - "#452f91", - "#452e91", - "#442d91", - "#442c91", - "#432b91", - "#432a91", - "#422991", - "#422891", - "#412691", - "#412591", - ], - "flex_grey_seq": [ - "#ffffff", - "#fefefe", - "#fdfdfd", - "#fcfcfc", - "#fbfbfc", - "#fafafb", - "#f9f9fa", - "#f8f9f9", - "#f8f8f8", - 
"#f7f7f7", - "#f6f6f6", - "#f5f5f6", - "#f4f4f5", - "#f3f3f4", - "#f2f2f3", - "#f1f1f2", - "#f0f0f1", - "#eff0f1", - "#eeeff0", - "#eeeeef", - "#ededee", - "#ececed", - "#ebebec", - "#eaeaec", - "#e9e9eb", - "#e8e9ea", - "#e7e8e9", - "#e6e7e8", - "#e6e6e8", - "#e5e5e7", - "#e4e4e6", - "#e3e3e5", - "#e2e3e4", - "#e1e2e4", - "#e0e1e3", - "#dfe0e2", - "#dfdfe1", - "#dedee0", - "#dddde0", - "#dcdddf", - "#dbdcde", - "#dadbdd", - "#d9dadd", - "#d9d9dc", - "#d8d8db", - "#d7d8da", - "#d6d7da", - "#d5d6d9", - "#d4d5d8", - "#d4d4d7", - "#d3d4d6", - "#d2d3d6", - "#d1d2d5", - "#d0d1d4", - "#cfd0d3", - "#cfcfd3", - "#cecfd2", - "#cdced1", - "#cccdd0", - "#cbccd0", - "#cacbcf", - "#cacbce", - "#c9cace", - "#c8c9cd", - "#c7c8cc", - "#c6c8cb", - "#c6c7cb", - "#c5c6ca", - "#c4c5c9", - "#c3c4c8", - "#c2c4c8", - "#c2c3c7", - "#c1c2c6", - "#c0c1c6", - "#bfc0c5", - "#bec0c4", - "#bebfc3", - "#bdbec3", - "#bcbdc2", - "#bbbdc1", - "#babcc1", - "#babbc0", - "#b9babf", - "#b8babe", - "#b7b9be", - "#b7b8bd", - "#b6b7bc", - "#b5b7bc", - "#b4b6bb", - "#b3b5ba", - "#b3b4ba", - "#b2b4b9", - "#b1b3b8", - "#b0b2b8", - "#b0b1b7", - "#afb1b6", - "#aeb0b6", - "#adafb5", - "#adaeb4", - "#acaeb3", - "#abadb3", - "#aaacb2", - "#aaabb1", - "#a9abb1", - "#a8aab0", - "#a7a9af", - "#a7a8af", - "#a6a8ae", - "#a5a7ad", - "#a4a6ad", - "#a4a6ac", - "#a3a5ab", - "#a2a4ab", - "#a1a3aa", - "#a1a3aa", - "#a0a2a9", - "#9fa1a8", - "#9ea1a8", - "#9ea0a7", - "#9d9fa6", - "#9c9ea6", - "#9c9ea5", - "#9b9da4", - "#9a9ca4", - "#999ca3", - "#999ba2", - "#989aa2", - "#979aa1", - "#9799a0", - "#9698a0", - "#95979f", - "#94979f", - "#94969e", - "#93959d", - "#92959d", - "#92949c", - "#91939b", - "#90939b", - "#8f929a", - "#8f919a", - "#8e9199", - "#8d9098", - "#8d8f98", - "#8c8e97", - "#8b8e96", - "#8b8d96", - "#8a8c95", - "#898c95", - "#888b94", - "#888a93", - "#878a93", - "#868992", - "#868891", - "#858891", - "#848790", - "#848690", - "#83868f", - "#82858e", - "#82848e", - "#81848d", - "#80838d", - "#80828c", - "#7f828b", - "#7e818b", - "#7e808a", - "#7d808a", - "#7c7f89", - "#7b7e88", - "#7b7e88", - "#7a7d87", - "#797c87", - "#797c86", - "#787b85", - "#777a85", - "#777a84", - "#767984", - "#757983", - "#757882", - "#747782", - "#737781", - "#737681", - "#727580", - "#71757f", - "#71747f", - "#70737e", - "#70737e", - "#6f727d", - "#6e717d", - "#6e717c", - "#6d707b", - "#6c707b", - "#6c6f7a", - "#6b6e7a", - "#6a6e79", - "#6a6d79", - "#696c78", - "#686c77", - "#686b77", - "#676a76", - "#666a76", - "#666975", - "#656974", - "#646874", - "#646773", - "#636773", - "#636672", - "#626572", - "#616571", - "#616470", - "#606370", - "#5f636f", - "#5f626f", - "#5e626e", - "#5d616e", - "#5d606d", - "#5c606c", - "#5b5f6c", - "#5b5e6b", - "#5a5e6b", - "#5a5d6a", - "#595d6a", - "#585c69", - "#585b69", - "#575b68", - "#565a67", - "#565967", - "#555966", - "#555866", - "#545865", - "#535765", - "#535664", - "#525663", - "#515563", - "#515562", - "#505462", - "#4f5361", - "#4f5361", - "#4e5260", - "#4e5160", - "#4d515f", - "#4c505e", - "#4c505e", - "#4b4f5d", - "#4a4e5d", - "#4a4e5c", - "#494d5c", - "#494d5b", - "#484c5a", - "#474b5a", - "#474b59", - "#464a59", - "#454958", - "#454958", - "#444857", - "#444857", - "#434756", - ], -} -CATEGORICAL_PALETTES_HEX = { - "flex_distinct": [ - "#176737", - "#FF7B0D", - "#979BAA", - "#F44E6A", - "#0062FF", - "#26AB5B", - "#6D3EF2", - "#F59E0B", - ] -} -DIVERGING_PALETTES_HEX = { - "flex_BuRd": [ - "#002766", - "#022967", - "#052b67", - "#072d68", - "#0a2e69", - "#0d3069", - "#10326a", - "#12346b", - "#15356c", - "#17376c", 
- "#1a396d", - "#1c3a6e", - "#1e3c6f", - "#203e70", - "#223f71", - "#244171", - "#264372", - "#284473", - "#2a4674", - "#2c4775", - "#2e4976", - "#304a77", - "#324c78", - "#344e79", - "#364f7a", - "#38517b", - "#3a527c", - "#3c547d", - "#3e557e", - "#3f577f", - "#415980", - "#435a80", - "#455c81", - "#475d83", - "#495f84", - "#4b6085", - "#4c6286", - "#4e6387", - "#506588", - "#526789", - "#54688a", - "#566a8b", - "#586b8c", - "#5a6d8d", - "#5b6e8e", - "#5d708f", - "#5f7290", - "#617391", - "#637592", - "#657694", - "#677895", - "#687a96", - "#6a7b97", - "#6c7d98", - "#6e7e99", - "#70809a", - "#72829b", - "#74839d", - "#76859e", - "#78879f", - "#7a88a0", - "#7b8aa1", - "#7d8ca3", - "#7f8da4", - "#818fa5", - "#8391a6", - "#8592a8", - "#8794a9", - "#8996aa", - "#8b97ab", - "#8d99ad", - "#8f9bae", - "#919daf", - "#939eb1", - "#95a0b2", - "#97a2b3", - "#99a4b4", - "#9ba5b6", - "#9da7b7", - "#9fa9b9", - "#a1abba", - "#a3acbb", - "#a5aebd", - "#a7b0be", - "#a9b2c0", - "#abb4c1", - "#adb6c2", - "#afb7c4", - "#b1b9c5", - "#b4bbc7", - "#b6bdc8", - "#b8bfca", - "#bac1cb", - "#bcc3cd", - "#bec5ce", - "#c0c6d0", - "#c3c8d1", - "#c5cad3", - "#c7ccd5", - "#c9ced6", - "#cbd0d8", - "#ced2d9", - "#d0d4db", - "#d2d6dd", - "#d4d8de", - "#d7dae0", - "#d9dce2", - "#dbdee3", - "#dee0e5", - "#e0e3e7", - "#e2e5e9", - "#e4e7ea", - "#e7e9ec", - "#e9ebee", - "#ecedf0", - "#eeeff2", - "#f0f2f3", - "#f3f4f5", - "#f5f6f7", - "#f8f8f9", - "#fafafb", - "#fdfdfd", - "#FFFFFF", - "#fefdfd", - "#fcfbfb", - "#fbf9f9", - "#f9f7f7", - "#f8f5f5", - "#f6f3f3", - "#f5f1f1", - "#f4efef", - "#f2edee", - "#f1ebec", - "#efe9ea", - "#eee7e8", - "#ede5e6", - "#ece3e4", - "#ebe1e3", - "#e9dfe1", - "#e8dddf", - "#e7dbdd", - "#e6d9db", - "#e5d7d9", - "#e4d5d8", - "#e3d3d6", - "#e2d1d4", - "#e1cfd2", - "#e0cdd0", - "#dfcccf", - "#decacd", - "#ddc8cb", - "#dcc6c9", - "#dbc4c7", - "#dac2c6", - "#d9c0c4", - "#d8bec2", - "#d7bcc0", - "#d6babf", - "#d6b8bd", - "#d5b6bb", - "#d4b5b9", - "#d3b3b8", - "#d2b1b6", - "#d1afb4", - "#d0adb2", - "#cfabb1", - "#cfa9af", - "#cea8ad", - "#cda6ac", - "#cca4aa", - "#cba2a8", - "#caa0a7", - "#c99ea5", - "#c99ca3", - "#c89ba2", - "#c799a0", - "#c6979e", - "#c5959d", - "#c4939b", - "#c49199", - "#c39098", - "#c28e96", - "#c18c94", - "#c08a93", - "#c08891", - "#bf8790", - "#be858e", - "#bd838c", - "#bc818b", - "#bb7f89", - "#bb7e88", - "#ba7c86", - "#b97a84", - "#b87883", - "#b77681", - "#b67580", - "#b6737e", - "#b5717d", - "#b46f7b", - "#b36d79", - "#b26c78", - "#b26a76", - "#b16875", - "#b06673", - "#af6572", - "#ae6370", - "#ad616f", - "#ac5f6d", - "#ac5d6c", - "#ab5c6a", - "#aa5a69", - "#a95867", - "#a85666", - "#a75464", - "#a65263", - "#a55161", - "#a44f60", - "#a44d5e", - "#a34b5d", - "#a2495b", - "#a1475a", - "#a04658", - "#9f4457", - "#9e4255", - "#9d4054", - "#9c3e52", - "#9b3c51", - "#9a3a4f", - "#99384e", - "#98364c", - "#97344b", - "#96324a", - "#953048", - "#942e47", - "#932c45", - "#922a44", - "#912842", - "#902541", - "#8f233f", - "#8e213e", - "#8d1e3d", - "#8b1c3b", - "#8a193a", - "#891638", - "#881337", - ], - "flex_RdBu": [ - "#881337", - "#891638", - "#8a193a", - "#8b1c3b", - "#8d1e3d", - "#8e213e", - "#8f233f", - "#902541", - "#912842", - "#922a44", - "#932c45", - "#942e47", - "#953048", - "#96324a", - "#97344b", - "#98364c", - "#99384e", - "#9a3a4f", - "#9b3c51", - "#9c3e52", - "#9d4054", - "#9e4255", - "#9f4457", - "#a04658", - "#a1475a", - "#a2495b", - "#a34b5d", - "#a44d5e", - "#a44f60", - "#a55161", - "#a65263", - "#a75464", - "#a85666", - "#a95867", - "#aa5a69", - "#ab5c6a", - 
"#ac5d6c", - "#ac5f6d", - "#ad616f", - "#ae6370", - "#af6572", - "#b06673", - "#b16875", - "#b26a76", - "#b26c78", - "#b36d79", - "#b46f7b", - "#b5717d", - "#b6737e", - "#b67580", - "#b77681", - "#b87883", - "#b97a84", - "#ba7c86", - "#bb7e88", - "#bb7f89", - "#bc818b", - "#bd838c", - "#be858e", - "#bf8790", - "#c08891", - "#c08a93", - "#c18c94", - "#c28e96", - "#c39098", - "#c49199", - "#c4939b", - "#c5959d", - "#c6979e", - "#c799a0", - "#c89ba2", - "#c99ca3", - "#c99ea5", - "#caa0a7", - "#cba2a8", - "#cca4aa", - "#cda6ac", - "#cea8ad", - "#cfa9af", - "#cfabb1", - "#d0adb2", - "#d1afb4", - "#d2b1b6", - "#d3b3b8", - "#d4b5b9", - "#d5b6bb", - "#d6b8bd", - "#d6babf", - "#d7bcc0", - "#d8bec2", - "#d9c0c4", - "#dac2c6", - "#dbc4c7", - "#dcc6c9", - "#ddc8cb", - "#decacd", - "#dfcccf", - "#e0cdd0", - "#e1cfd2", - "#e2d1d4", - "#e3d3d6", - "#e4d5d8", - "#e5d7d9", - "#e6d9db", - "#e7dbdd", - "#e8dddf", - "#e9dfe1", - "#ebe1e3", - "#ece3e4", - "#ede5e6", - "#eee7e8", - "#efe9ea", - "#f1ebec", - "#f2edee", - "#f4efef", - "#f5f1f1", - "#f6f3f3", - "#f8f5f5", - "#f9f7f7", - "#fbf9f9", - "#fcfbfb", - "#fefdfd", - "#FFFFFF", - "#fdfdfd", - "#fafafb", - "#f8f8f9", - "#f5f6f7", - "#f3f4f5", - "#f0f2f3", - "#eeeff2", - "#ecedf0", - "#e9ebee", - "#e7e9ec", - "#e4e7ea", - "#e2e5e9", - "#e0e3e7", - "#dee0e5", - "#dbdee3", - "#d9dce2", - "#d7dae0", - "#d4d8de", - "#d2d6dd", - "#d0d4db", - "#ced2d9", - "#cbd0d8", - "#c9ced6", - "#c7ccd5", - "#c5cad3", - "#c3c8d1", - "#c0c6d0", - "#bec5ce", - "#bcc3cd", - "#bac1cb", - "#b8bfca", - "#b6bdc8", - "#b4bbc7", - "#b1b9c5", - "#afb7c4", - "#adb6c2", - "#abb4c1", - "#a9b2c0", - "#a7b0be", - "#a5aebd", - "#a3acbb", - "#a1abba", - "#9fa9b9", - "#9da7b7", - "#9ba5b6", - "#99a4b4", - "#97a2b3", - "#95a0b2", - "#939eb1", - "#919daf", - "#8f9bae", - "#8d99ad", - "#8b97ab", - "#8996aa", - "#8794a9", - "#8592a8", - "#8391a6", - "#818fa5", - "#7f8da4", - "#7d8ca3", - "#7b8aa1", - "#7a88a0", - "#78879f", - "#76859e", - "#74839d", - "#72829b", - "#70809a", - "#6e7e99", - "#6c7d98", - "#6a7b97", - "#687a96", - "#677895", - "#657694", - "#637592", - "#617391", - "#5f7290", - "#5d708f", - "#5b6e8e", - "#5a6d8d", - "#586b8c", - "#566a8b", - "#54688a", - "#526789", - "#506588", - "#4e6387", - "#4c6286", - "#4b6085", - "#495f84", - "#475d83", - "#455c81", - "#435a80", - "#415980", - "#3f577f", - "#3e557e", - "#3c547d", - "#3a527c", - "#38517b", - "#364f7a", - "#344e79", - "#324c78", - "#304a77", - "#2e4976", - "#2c4775", - "#2a4674", - "#284473", - "#264372", - "#244171", - "#223f71", - "#203e70", - "#1e3c6f", - "#1c3a6e", - "#1a396d", - "#17376c", - "#15356c", - "#12346b", - "#10326a", - "#0d3069", - "#0a2e69", - "#072d68", - "#052b67", - "#022967", - "#002766", - ], - "flex_GrPu": [ - "#0f4424", - "#124526", - "#144727", - "#174829", - "#19492a", - "#1b4b2c", - "#1d4c2d", - "#1f4d2f", - "#214f30", - "#235032", - "#255234", - "#275335", - "#295437", - "#2b5638", - "#2d573a", - "#2f583b", - "#315a3d", - "#335b3e", - "#355c40", - "#365e42", - "#385f43", - "#3a6045", - "#3c6246", - "#3e6348", - "#3f644a", - "#41664b", - "#43674d", - "#45684e", - "#476a50", - "#486b52", - "#4a6c53", - "#4c6e55", - "#4e6f56", - "#4f7058", - "#51725a", - "#53735b", - "#55745d", - "#57765f", - "#587760", - "#5a7962", - "#5c7a63", - "#5e7b65", - "#5f7d67", - "#617e68", - "#637f6a", - "#65816c", - "#67826d", - "#68846f", - "#6a8571", - "#6c8672", - "#6e8874", - "#6f8976", - "#718b78", - "#738c79", - "#758e7b", - "#778f7d", - "#79907e", - "#7a9280", - "#7c9382", - "#7e9584", - "#809685", - "#829887", - 
"#849989", - "#859b8b", - "#879c8c", - "#899d8e", - "#8b9f90", - "#8da092", - "#8fa294", - "#91a395", - "#92a597", - "#94a699", - "#96a89b", - "#98aa9d", - "#9aab9e", - "#9cada0", - "#9eaea2", - "#a0b0a4", - "#a2b1a6", - "#a4b3a8", - "#a6b4aa", - "#a8b6ab", - "#aab8ad", - "#acb9af", - "#aebbb1", - "#b0bcb3", - "#b2beb5", - "#b4c0b7", - "#b6c1b9", - "#b8c3bb", - "#bac5bd", - "#bcc6bf", - "#bec8c1", - "#c0cac2", - "#c2cbc4", - "#c4cdc6", - "#c6cfc8", - "#c8d0ca", - "#cad2cc", - "#ccd4ce", - "#ced6d0", - "#d0d7d2", - "#d3d9d5", - "#d5dbd7", - "#d7ddd9", - "#d9dfdb", - "#dbe0dd", - "#dde2df", - "#dfe4e1", - "#e2e6e3", - "#e4e8e5", - "#e6eae7", - "#e8ebe9", - "#ebedeb", - "#edefee", - "#eff1f0", - "#f1f3f2", - "#f4f5f4", - "#f6f7f6", - "#f8f9f8", - "#fafbfb", - "#fdfdfd", - "#FFFFFF", - "#fdfdfd", - "#fbfbfc", - "#f9f9fa", - "#f7f7f8", - "#f5f5f7", - "#f3f3f5", - "#f1f0f4", - "#efeef2", - "#ececf0", - "#eaeaef", - "#e8e8ed", - "#e6e6ec", - "#e4e5ea", - "#e3e3e9", - "#e1e1e8", - "#dfdfe6", - "#dddde5", - "#dbdbe3", - "#d9d9e2", - "#d7d7e1", - "#d5d5df", - "#d3d3de", - "#d1d1dd", - "#cfd0db", - "#ceceda", - "#ccccd9", - "#cacad7", - "#c8c8d6", - "#c6c6d5", - "#c4c4d4", - "#c3c3d2", - "#c1c1d1", - "#bfbfd0", - "#bdbdcf", - "#bcbcce", - "#babacd", - "#b8b8cb", - "#b6b6ca", - "#b5b4c9", - "#b3b3c8", - "#b1b1c7", - "#afafc6", - "#aeaec5", - "#acacc4", - "#aaaac3", - "#a9a8c1", - "#a7a7c0", - "#a5a5bf", - "#a4a3be", - "#a2a2bd", - "#a1a0bc", - "#9f9ebb", - "#9d9dba", - "#9c9bb9", - "#9a99b8", - "#9898b8", - "#9796b7", - "#9594b6", - "#9493b5", - "#9291b4", - "#918fb3", - "#8f8eb2", - "#8e8cb1", - "#8c8ab0", - "#8b89af", - "#8987af", - "#8786ae", - "#8684ad", - "#8482ac", - "#8381ab", - "#817faa", - "#807eaa", - "#7f7ca9", - "#7d7aa8", - "#7c79a7", - "#7a77a6", - "#7976a6", - "#7774a5", - "#7672a4", - "#7471a4", - "#736fa3", - "#726ea2", - "#706ca1", - "#6f6aa1", - "#6d69a0", - "#6c679f", - "#6b669f", - "#69649e", - "#68629d", - "#67619d", - "#655f9c", - "#645e9c", - "#635c9b", - "#615a9a", - "#60599a", - "#5f5799", - "#5d5599", - "#5c5498", - "#5b5298", - "#595097", - "#584f97", - "#574d96", - "#564b96", - "#554a95", - "#534895", - "#524695", - "#514494", - "#504394", - "#4f4193", - "#4d3f93", - "#4c3d93", - "#4b3b92", - "#4a3992", - "#493792", - "#483592", - "#473392", - "#463191", - "#452f91", - "#442d91", - "#432a91", - "#422891", - "#412591", - ], - "flex_PuGr": [ - "#412591", - "#422891", - "#432a91", - "#442d91", - "#452f91", - "#463191", - "#473392", - "#483592", - "#493792", - "#4a3992", - "#4b3b92", - "#4c3d93", - "#4d3f93", - "#4f4193", - "#504394", - "#514494", - "#524695", - "#534895", - "#554a95", - "#564b96", - "#574d96", - "#584f97", - "#595097", - "#5b5298", - "#5c5498", - "#5d5599", - "#5f5799", - "#60599a", - "#615a9a", - "#635c9b", - "#645e9c", - "#655f9c", - "#67619d", - "#68629d", - "#69649e", - "#6b669f", - "#6c679f", - "#6d69a0", - "#6f6aa1", - "#706ca1", - "#726ea2", - "#736fa3", - "#7471a4", - "#7672a4", - "#7774a5", - "#7976a6", - "#7a77a6", - "#7c79a7", - "#7d7aa8", - "#7f7ca9", - "#807eaa", - "#817faa", - "#8381ab", - "#8482ac", - "#8684ad", - "#8786ae", - "#8987af", - "#8b89af", - "#8c8ab0", - "#8e8cb1", - "#8f8eb2", - "#918fb3", - "#9291b4", - "#9493b5", - "#9594b6", - "#9796b7", - "#9898b8", - "#9a99b8", - "#9c9bb9", - "#9d9dba", - "#9f9ebb", - "#a1a0bc", - "#a2a2bd", - "#a4a3be", - "#a5a5bf", - "#a7a7c0", - "#a9a8c1", - "#aaaac3", - "#acacc4", - "#aeaec5", - "#afafc6", - "#b1b1c7", - "#b3b3c8", - "#b5b4c9", - "#b6b6ca", - "#b8b8cb", - "#babacd", - "#bcbcce", - 
"#bdbdcf", - "#bfbfd0", - "#c1c1d1", - "#c3c3d2", - "#c4c4d4", - "#c6c6d5", - "#c8c8d6", - "#cacad7", - "#ccccd9", - "#ceceda", - "#cfd0db", - "#d1d1dd", - "#d3d3de", - "#d5d5df", - "#d7d7e1", - "#d9d9e2", - "#dbdbe3", - "#dddde5", - "#dfdfe6", - "#e1e1e8", - "#e3e3e9", - "#e4e5ea", - "#e6e6ec", - "#e8e8ed", - "#eaeaef", - "#ececf0", - "#efeef2", - "#f1f0f4", - "#f3f3f5", - "#f5f5f7", - "#f7f7f8", - "#f9f9fa", - "#fbfbfc", - "#fdfdfd", - "#FFFFFF", - "#fdfdfd", - "#fafbfb", - "#f8f9f8", - "#f6f7f6", - "#f4f5f4", - "#f1f3f2", - "#eff1f0", - "#edefee", - "#ebedeb", - "#e8ebe9", - "#e6eae7", - "#e4e8e5", - "#e2e6e3", - "#dfe4e1", - "#dde2df", - "#dbe0dd", - "#d9dfdb", - "#d7ddd9", - "#d5dbd7", - "#d3d9d5", - "#d0d7d2", - "#ced6d0", - "#ccd4ce", - "#cad2cc", - "#c8d0ca", - "#c6cfc8", - "#c4cdc6", - "#c2cbc4", - "#c0cac2", - "#bec8c1", - "#bcc6bf", - "#bac5bd", - "#b8c3bb", - "#b6c1b9", - "#b4c0b7", - "#b2beb5", - "#b0bcb3", - "#aebbb1", - "#acb9af", - "#aab8ad", - "#a8b6ab", - "#a6b4aa", - "#a4b3a8", - "#a2b1a6", - "#a0b0a4", - "#9eaea2", - "#9cada0", - "#9aab9e", - "#98aa9d", - "#96a89b", - "#94a699", - "#92a597", - "#91a395", - "#8fa294", - "#8da092", - "#8b9f90", - "#899d8e", - "#879c8c", - "#859b8b", - "#849989", - "#829887", - "#809685", - "#7e9584", - "#7c9382", - "#7a9280", - "#79907e", - "#778f7d", - "#758e7b", - "#738c79", - "#718b78", - "#6f8976", - "#6e8874", - "#6c8672", - "#6a8571", - "#68846f", - "#67826d", - "#65816c", - "#637f6a", - "#617e68", - "#5f7d67", - "#5e7b65", - "#5c7a63", - "#5a7962", - "#587760", - "#57765f", - "#55745d", - "#53735b", - "#51725a", - "#4f7058", - "#4e6f56", - "#4c6e55", - "#4a6c53", - "#486b52", - "#476a50", - "#45684e", - "#43674d", - "#41664b", - "#3f644a", - "#3e6348", - "#3c6246", - "#3a6045", - "#385f43", - "#365e42", - "#355c40", - "#335b3e", - "#315a3d", - "#2f583b", - "#2d573a", - "#2b5638", - "#295437", - "#275335", - "#255234", - "#235032", - "#214f30", - "#1f4d2f", - "#1d4c2d", - "#1b4b2c", - "#19492a", - "#174829", - "#144727", - "#124526", - "#0f4424", - ], - "flex_TuOr": [ - "#134e4a", - "#164f4b", - "#19504d", - "#1b524e", - "#1e534f", - "#205450", - "#225552", - "#255753", - "#275854", - "#295955", - "#2b5a57", - "#2d5c58", - "#2f5d59", - "#315e5a", - "#335f5c", - "#35615d", - "#37625e", - "#39635f", - "#3b6461", - "#3c6662", - "#3e6763", - "#406865", - "#426966", - "#446b67", - "#456c68", - "#476d6a", - "#496e6b", - "#4b706c", - "#4d716e", - "#4e726f", - "#507370", - "#527572", - "#547673", - "#557774", - "#577975", - "#597a77", - "#5b7b78", - "#5d7c79", - "#5e7e7b", - "#607f7c", - "#62807d", - "#64827f", - "#658380", - "#678482", - "#698683", - "#6b8784", - "#6c8886", - "#6e8a87", - "#708b88", - "#728c8a", - "#738e8b", - "#758f8c", - "#77908e", - "#79928f", - "#7a9391", - "#7c9492", - "#7e9693", - "#809795", - "#819896", - "#839a98", - "#859b99", - "#879d9b", - "#899e9c", - "#8a9f9d", - "#8ca19f", - "#8ea2a0", - "#90a4a2", - "#92a5a3", - "#93a7a5", - "#95a8a6", - "#97a9a8", - "#99aba9", - "#9bacab", - "#9daeac", - "#9eafae", - "#a0b1af", - "#a2b2b1", - "#a4b4b2", - "#a6b5b4", - "#a8b7b5", - "#aab8b7", - "#acbab8", - "#adbbba", - "#afbdbb", - "#b1bebd", - "#b3c0bf", - "#b5c1c0", - "#b7c3c2", - "#b9c5c3", - "#bbc6c5", - "#bdc8c7", - "#bfc9c8", - "#c1cbca", - "#c3cccc", - "#c5cecd", - "#c7d0cf", - "#c9d1d1", - "#cbd3d2", - "#cdd5d4", - "#cfd6d6", - "#d1d8d7", - "#d3dad9", - "#d5dbdb", - "#d7dddc", - "#d9dfde", - "#dbe0e0", - "#dde2e2", - "#dfe4e3", - "#e1e6e5", - "#e3e7e7", - "#e5e9e9", - "#e7ebeb", - "#e9edec", - "#ebeeee", - 
"#eef0f0", - "#f0f2f2", - "#f2f4f4", - "#f4f6f6", - "#f6f8f7", - "#f8f9f9", - "#fbfbfb", - "#fdfdfd", - "#FFFFFF", - "#fefdfd", - "#fcfcfb", - "#fbfaf9", - "#faf8f7", - "#f9f7f6", - "#f7f5f4", - "#f6f4f2", - "#f5f2f0", - "#f4f0ee", - "#f2efec", - "#f1edea", - "#f0ece8", - "#efeae6", - "#eee8e4", - "#ede7e2", - "#ece5df", - "#ebe4dd", - "#eae2db", - "#e9e0d9", - "#e8dfd7", - "#e7ddd5", - "#e6dcd3", - "#e5dad1", - "#e5d8cf", - "#e4d7cd", - "#e3d5cb", - "#e2d4c9", - "#e1d2c7", - "#e0d0c5", - "#dfcfc3", - "#dfcdc1", - "#deccbe", - "#ddcabc", - "#dcc9ba", - "#dbc7b8", - "#dac6b6", - "#dac4b4", - "#d9c2b2", - "#d8c1b0", - "#d7bfae", - "#d6beac", - "#d6bcaa", - "#d5bba8", - "#d4b9a6", - "#d3b8a4", - "#d3b6a2", - "#d2b5a0", - "#d1b39e", - "#d0b29c", - "#d0b09a", - "#cfaf98", - "#cead96", - "#cdac94", - "#cdaa92", - "#cca990", - "#cba78e", - "#caa68c", - "#caa48a", - "#c9a388", - "#c8a286", - "#c7a084", - "#c79f82", - "#c69d80", - "#c59c7e", - "#c59a7c", - "#c4997a", - "#c39778", - "#c29676", - "#c29474", - "#c19372", - "#c09270", - "#c0906e", - "#bf8f6c", - "#be8d6b", - "#bd8c69", - "#bd8a67", - "#bc8965", - "#bb8863", - "#bb8661", - "#ba855f", - "#b9835d", - "#b8825b", - "#b88059", - "#b77f57", - "#b67e55", - "#b57c53", - "#b57b51", - "#b47950", - "#b3784e", - "#b3774c", - "#b2754a", - "#b17448", - "#b07246", - "#b07144", - "#af7042", - "#ae6e40", - "#ad6d3e", - "#ad6b3c", - "#ac6a3a", - "#ab6938", - "#aa6737", - "#a96635", - "#a96433", - "#a86331", - "#a7622f", - "#a6602d", - "#a65f2b", - "#a55d29", - "#a45c27", - "#a35b25", - "#a25923", - "#a15821", - "#a1571f", - "#a0551c", - "#9f541a", - "#9e5218", - "#9d5116", - "#9c5013", - "#9c4e11", - "#9b4d0e", - "#9a4b0b", - "#994a08", - ], - "flex_OrTu": [ - "#994a08", - "#9a4b0b", - "#9b4d0e", - "#9c4e11", - "#9c5013", - "#9d5116", - "#9e5218", - "#9f541a", - "#a0551c", - "#a1571f", - "#a15821", - "#a25923", - "#a35b25", - "#a45c27", - "#a55d29", - "#a65f2b", - "#a6602d", - "#a7622f", - "#a86331", - "#a96433", - "#a96635", - "#aa6737", - "#ab6938", - "#ac6a3a", - "#ad6b3c", - "#ad6d3e", - "#ae6e40", - "#af7042", - "#b07144", - "#b07246", - "#b17448", - "#b2754a", - "#b3774c", - "#b3784e", - "#b47950", - "#b57b51", - "#b57c53", - "#b67e55", - "#b77f57", - "#b88059", - "#b8825b", - "#b9835d", - "#ba855f", - "#bb8661", - "#bb8863", - "#bc8965", - "#bd8a67", - "#bd8c69", - "#be8d6b", - "#bf8f6c", - "#c0906e", - "#c09270", - "#c19372", - "#c29474", - "#c29676", - "#c39778", - "#c4997a", - "#c59a7c", - "#c59c7e", - "#c69d80", - "#c79f82", - "#c7a084", - "#c8a286", - "#c9a388", - "#caa48a", - "#caa68c", - "#cba78e", - "#cca990", - "#cdaa92", - "#cdac94", - "#cead96", - "#cfaf98", - "#d0b09a", - "#d0b29c", - "#d1b39e", - "#d2b5a0", - "#d3b6a2", - "#d3b8a4", - "#d4b9a6", - "#d5bba8", - "#d6bcaa", - "#d6beac", - "#d7bfae", - "#d8c1b0", - "#d9c2b2", - "#dac4b4", - "#dac6b6", - "#dbc7b8", - "#dcc9ba", - "#ddcabc", - "#deccbe", - "#dfcdc1", - "#dfcfc3", - "#e0d0c5", - "#e1d2c7", - "#e2d4c9", - "#e3d5cb", - "#e4d7cd", - "#e5d8cf", - "#e5dad1", - "#e6dcd3", - "#e7ddd5", - "#e8dfd7", - "#e9e0d9", - "#eae2db", - "#ebe4dd", - "#ece5df", - "#ede7e2", - "#eee8e4", - "#efeae6", - "#f0ece8", - "#f1edea", - "#f2efec", - "#f4f0ee", - "#f5f2f0", - "#f6f4f2", - "#f7f5f4", - "#f9f7f6", - "#faf8f7", - "#fbfaf9", - "#fcfcfb", - "#fefdfd", - "#FFFFFF", - "#fdfdfd", - "#fbfbfb", - "#f8f9f9", - "#f6f8f7", - "#f4f6f6", - "#f2f4f4", - "#f0f2f2", - "#eef0f0", - "#ebeeee", - "#e9edec", - "#e7ebeb", - "#e5e9e9", - "#e3e7e7", - "#e1e6e5", - "#dfe4e3", - "#dde2e2", - "#dbe0e0", - 
"#d9dfde", - "#d7dddc", - "#d5dbdb", - "#d3dad9", - "#d1d8d7", - "#cfd6d6", - "#cdd5d4", - "#cbd3d2", - "#c9d1d1", - "#c7d0cf", - "#c5cecd", - "#c3cccc", - "#c1cbca", - "#bfc9c8", - "#bdc8c7", - "#bbc6c5", - "#b9c5c3", - "#b7c3c2", - "#b5c1c0", - "#b3c0bf", - "#b1bebd", - "#afbdbb", - "#adbbba", - "#acbab8", - "#aab8b7", - "#a8b7b5", - "#a6b5b4", - "#a4b4b2", - "#a2b2b1", - "#a0b1af", - "#9eafae", - "#9daeac", - "#9bacab", - "#99aba9", - "#97a9a8", - "#95a8a6", - "#93a7a5", - "#92a5a3", - "#90a4a2", - "#8ea2a0", - "#8ca19f", - "#8a9f9d", - "#899e9c", - "#879d9b", - "#859b99", - "#839a98", - "#819896", - "#809795", - "#7e9693", - "#7c9492", - "#7a9391", - "#79928f", - "#77908e", - "#758f8c", - "#738e8b", - "#728c8a", - "#708b88", - "#6e8a87", - "#6c8886", - "#6b8784", - "#698683", - "#678482", - "#658380", - "#64827f", - "#62807d", - "#607f7c", - "#5e7e7b", - "#5d7c79", - "#5b7b78", - "#597a77", - "#577975", - "#557774", - "#547673", - "#527572", - "#507370", - "#4e726f", - "#4d716e", - "#4b706c", - "#496e6b", - "#476d6a", - "#456c68", - "#446b67", - "#426966", - "#406865", - "#3e6763", - "#3c6662", - "#3b6461", - "#39635f", - "#37625e", - "#35615d", - "#335f5c", - "#315e5a", - "#2f5d59", - "#2d5c58", - "#2b5a57", - "#295955", - "#275854", - "#255753", - "#225552", - "#205450", - "#1e534f", - "#1b524e", - "#19504d", - "#164f4b", - "#134e4a", - ], -} +from tidy3d._common.components.viz.flex_color_palettes import ( + CATEGORICAL_PALETTES_HEX, + DIVERGING_PALETTES_HEX, + SEQUENTIAL_PALETTES_HEX, +) diff --git a/tidy3d/components/viz/flex_style.py b/tidy3d/components/viz/flex_style.py index 0706826fca..c26686d494 100644 --- a/tidy3d/components/viz/flex_style.py +++ b/tidy3d/components/viz/flex_style.py @@ -1,46 +1,12 @@ -from __future__ import annotations - -from tidy3d.log import log - -_ORIGINAL_PARAMS = None - - -def apply_tidy3d_params() -> None: - """ - Applies a set of defaults to the matplotlib params that are following the tidy3d color palettes and design. - """ - global _ORIGINAL_PARAMS - try: - import matplotlib as mpl - import matplotlib.pyplot as plt +"""Compatibility shim for :mod:`tidy3d._common.components.viz.flex_style`.""" - _ORIGINAL_PARAMS = mpl.rcParams.copy() +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - try: - plt.style.use("tidy3d.style") - except Exception as e: - log.error(f"Failed to apply Tidy3D plotting style on import. Error: {e}") - _ORIGINAL_PARAMS = {} - except ImportError: - pass - - -def restore_matplotlib_rcparams() -> None: - """ - Resets matplotlib rcParams to the values they had before the Tidy3D - style was automatically applied on import. - """ - global _ORIGINAL_PARAMS - try: - import matplotlib.pyplot as plt - from matplotlib import style - - if not _ORIGINAL_PARAMS: - style.use("default") - return +# marked as migrated to _common +from __future__ import annotations - plt.rcParams.update(_ORIGINAL_PARAMS) - except ImportError: - log.error("Matplotlib is not installed on your system. Failed to reset to default styles.") - except Exception as e: - log.error(f"Failed to reset previous Matplotlib style. 
Error: {e}") +from tidy3d._common.components.viz.flex_style import ( + _ORIGINAL_PARAMS, + apply_tidy3d_params, + restore_matplotlib_rcparams, +) diff --git a/tidy3d/components/viz/plot_params.py b/tidy3d/components/viz/plot_params.py index f01920667e..e6bd14d668 100644 --- a/tidy3d/components/viz/plot_params.py +++ b/tidy3d/components/viz/plot_params.py @@ -1,87 +1,27 @@ -from __future__ import annotations - -from typing import Any - -import pydantic.v1 as pd -from numpy import inf - -from tidy3d.components.base import Tidy3dBaseModel - - -class AbstractPlotParams(Tidy3dBaseModel): - """Abstract class for storing plotting parameters. - Corresponds with select properties of ``matplotlib.artist.Artist``. - """ - - alpha: Any = pd.Field(1.0, title="Opacity") - zorder: float = pd.Field(None, title="Display Order") - - def include_kwargs(self, **kwargs: Any) -> AbstractPlotParams: - """Update the plot params with supplied kwargs.""" - update_dict = { - key: value - for key, value in kwargs.items() - if key not in ("type",) and value is not None and key in self.__fields__ - } - return self.copy(update=update_dict) - - def override_with_viz_spec(self, viz_spec) -> AbstractPlotParams: - """Override plot params with supplied VisualizationSpec.""" - return self.include_kwargs(**dict(viz_spec)) +"""Compatibility shim for :mod:`tidy3d._common.components.viz.plot_params`.""" - def to_kwargs(self) -> dict: - """Export the plot parameters as kwargs dict that can be supplied to plot function.""" - kwarg_dict = self.dict() - for ignore_key in ("type", "attrs"): - kwarg_dict.pop(ignore_key) - return kwarg_dict - - -class PathPlotParams(AbstractPlotParams): - """Stores plotting parameters / specifications for a path. - Corresponds with select properties of ``matplotlib.lines.Line2D``. - """ - - color: Any = pd.Field(None, title="Color", alias="c") - linewidth: pd.NonNegativeFloat = pd.Field(2, title="Line Width", alias="lw") - linestyle: str = pd.Field("--", title="Line Style", alias="ls") - marker: Any = pd.Field("o", title="Marker Style") - markeredgecolor: Any = pd.Field(None, title="Marker Edge Color", alias="mec") - markerfacecolor: Any = pd.Field(None, title="Marker Face Color", alias="mfc") - markersize: pd.NonNegativeFloat = pd.Field(10, title="Marker Size", alias="ms") - - -class PlotParams(AbstractPlotParams): - """Stores plotting parameters / specifications for a given model. - Corresponds with select properties of ``matplotlib.patches.Patch``. 
- """ - - edgecolor: Any = pd.Field(None, title="Edge Color", alias="ec") - facecolor: Any = pd.Field(None, title="Face Color", alias="fc") - fill: bool = pd.Field(True, title="Is Filled") - hatch: str = pd.Field(None, title="Hatch Style") - linewidth: pd.NonNegativeFloat = pd.Field(1, title="Line Width", alias="lw") +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -# defaults for different tidy3d objects -plot_params_geometry = PlotParams() -plot_params_structure = PlotParams() -plot_params_source = PlotParams(alpha=0.4, facecolor="limegreen", edgecolor="limegreen", lw=3) -plot_params_absorber = PlotParams( - alpha=0.4, facecolor="lightskyblue", edgecolor="lightskyblue", lw=3 -) -plot_params_monitor = PlotParams(alpha=0.4, facecolor="orange", edgecolor="orange", lw=3) -plot_params_pml = PlotParams(alpha=0.7, facecolor="gray", edgecolor="gray", hatch="x", zorder=inf) -plot_params_pec = PlotParams(alpha=1.0, facecolor="gold", edgecolor="black", zorder=inf) -plot_params_pmc = PlotParams(alpha=1.0, facecolor="lightsteelblue", edgecolor="black", zorder=inf) -plot_params_bloch = PlotParams(alpha=1.0, facecolor="orchid", edgecolor="black", zorder=inf) -plot_params_abc = PlotParams(alpha=1.0, facecolor="lightskyblue", edgecolor="black", zorder=inf) -plot_params_symmetry = PlotParams(edgecolor="gray", facecolor="gray", alpha=0.6, zorder=inf) -plot_params_override_structures = PlotParams( - linewidth=0.4, edgecolor="black", fill=False, zorder=inf -) -plot_params_fluid = PlotParams(facecolor="white", edgecolor="lightsteelblue", lw=0.4, hatch="xx") -plot_params_grid = PlotParams(edgecolor="black", lw=0.2) -plot_params_lumped_element = PlotParams( - alpha=0.4, facecolor="mediumblue", edgecolor="mediumblue", lw=3 +from tidy3d._common.components.viz.plot_params import ( + AbstractPlotParams, + PathPlotParams, + PlotParams, + plot_params_abc, + plot_params_absorber, + plot_params_bloch, + plot_params_fluid, + plot_params_geometry, + plot_params_grid, + plot_params_lumped_element, + plot_params_monitor, + plot_params_override_structures, + plot_params_pec, + plot_params_pmc, + plot_params_pml, + plot_params_source, + plot_params_structure, + plot_params_symmetry, ) diff --git a/tidy3d/components/viz/plot_sim_3d.py b/tidy3d/components/viz/plot_sim_3d.py index 7111309446..e6de969fbd 100644 --- a/tidy3d/components/viz/plot_sim_3d.py +++ b/tidy3d/components/viz/plot_sim_3d.py @@ -1,183 +1,11 @@ -from __future__ import annotations - -from html import escape - -from tidy3d.exceptions import SetupError - - -def plot_scene_3d(scene, width=800, height=800) -> None: - import gzip - import json - from base64 import b64encode - from io import BytesIO - - import h5py - - # Serialize scene to HDF5 in-memory - buffer = BytesIO() - scene.to_hdf5(buffer) - buffer.seek(0) - - # Open source HDF5 for reading and prepare modified copy - with h5py.File(buffer, "r") as src: - buffer2 = BytesIO() - with h5py.File(buffer2, "w") as dst: - - def copy_item(name, obj) -> None: - if isinstance(obj, h5py.Group): - dst.create_group(name) - for k, v in obj.attrs.items(): - dst[name].attrs[k] = v - elif isinstance(obj, h5py.Dataset): - data = obj[()] - if name == "JSON_STRING": - # Parse and update JSON string - json_str = ( - data.decode("utf-8") if isinstance(data, (bytes, bytearray)) else data - ) - json_data = json.loads(json_str) - json_data["size"] = list(scene.size) - json_data["center"] = list(scene.center) - json_data["grid_spec"] = {} - 
new_str = json.dumps(json_data) - dst.create_dataset(name, data=new_str.encode("utf-8")) - else: - dst.create_dataset(name, data=data) - for k, v in obj.attrs.items(): - dst[name].attrs[k] = v - - src.visititems(copy_item) - buffer2.seek(0) - - # Gzip the modified HDF5 - gz_buffer = BytesIO() - with gzip.GzipFile(fileobj=gz_buffer, mode="wb") as gz: - gz.write(buffer2.read()) - gz_buffer.seek(0) - - # Base64 encode and display with gzipped flag - sim_base64 = b64encode(gz_buffer.read()).decode("utf-8") - plot_sim_3d(sim_base64, width=width, height=height, is_gz_base64=True) - - -def plot_sim_3d(sim, width=800, height=800, is_gz_base64=False) -> None: - """Make 3D display of simulation in ipython notebook.""" +"""Compatibility shim for :mod:`tidy3d._common.components.viz.plot_sim_3d`.""" - try: - from IPython.display import HTML, display - except ImportError as e: - raise SetupError( - "3D plotting requires ipython to be installed " - "and the code to be running on a jupyter notebook." - ) from e +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - from base64 import b64encode - from io import BytesIO - - if not is_gz_base64: - buffer = BytesIO() - sim.to_hdf5_gz(buffer) - buffer.seek(0) - base64 = b64encode(buffer.read()).decode("utf-8") - else: - base64 = sim - - js_code = """ - /** - * Simulation Viewer Injector - * - * Monitors the document for elements being added in the form: - * - *
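
Before the viewer markup is built, the patched HDF5 bytes are gzipped and base64-encoded. A self-contained round trip of that encoding stage (a sketch; the payload here is a stand-in, not real HDF5):

import gzip
from base64 import b64decode, b64encode
from io import BytesIO

def gz_b64(payload: bytes) -> str:
    # Same shape as plot_scene_3d: gzip into a BytesIO, then base64-encode.
    buffer = BytesIO()
    with gzip.GzipFile(fileobj=buffer, mode="wb") as gz:
        gz.write(payload)
    return b64encode(buffer.getvalue()).decode("utf-8")

blob = b"stand-in for the patched HDF5 bytes"
assert gzip.decompress(b64decode(gz_b64(blob))) == blob
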
- * - * This script will then inject an iframe to the viewer application, and pass it the simulation data - * via the postMessage API on request. The script may be safely included multiple times, with only the - * configuration of the first started script (e.g. viewer URL) applying. - * - */ - (function() { - const TARGET_CLASS = "simulation-viewer"; - const ACTIVE_CLASS = "simulation-viewer-active"; - const VIEWER_URL = "https://tidy3d.simulation.cloud/simulation-viewer"; - - class SimulationViewerInjector { - constructor() { - for (var node of document.getElementsByClassName(TARGET_CLASS)) { - this.injectViewer(node); - } - - // Monitor for newly added nodes to the DOM - this.observer = new MutationObserver(this.onMutations.bind(this)); - this.observer.observe(document.body, {childList: true, subtree: true}); - } - - onMutations(mutations) { - for (var mutation of mutations) { - if (mutation.type === 'childList') { - /** - * Have found that adding the element does not reliably trigger the mutation observer. - * It may be the case that setting content with innerHTML does not trigger. - * - * It seems to be sufficient to re-scan the document for un-activated viewers - * whenever an event occurs, as Jupyter triggers multiple events on cell evaluation. - */ - var viewers = document.getElementsByClassName(TARGET_CLASS); - for (var node of viewers) { - this.injectViewer(node); - } - } - } - } - - injectViewer(node) { - // (re-)check that this is a valid simulation container and has not already been injected - if (node.classList.contains(TARGET_CLASS) && !node.classList.contains(ACTIVE_CLASS)) { - // Mark node as injected, to prevent re-runs - node.classList.add(ACTIVE_CLASS); - - var uuid; - if (window.crypto && window.crypto.randomUUID) { - uuid = window.crypto.randomUUID(); - } else { - uuid = "" + Math.random(); - } - - var frame = document.createElement("iframe"); - frame.width = node.dataset.width || 800; - frame.height = node.dataset.height || 800; - frame.style.cssText = `width:${frame.width}px;height:${frame.height}px;max-width:none;border:0;display:block` - frame.src = VIEWER_URL + "?uuid=" + uuid; - - var postMessageToViewer; - postMessageToViewer = event => { - if(event.data.type === 'viewer' && event.data.uuid===uuid){ - frame.contentWindow.postMessage({ type: 'jupyter', uuid, value: node.dataset.simulation, fileType: 'hdf5'}, '*'); - - // Run once only - window.removeEventListener('message', postMessageToViewer); - } - }; - window.addEventListener( - 'message', - postMessageToViewer, - false - ); - - node.appendChild(frame); - } - } - } - - if (!window.simulationViewerInjector) { - window.simulationViewerInjector = new SimulationViewerInjector(); - } - })(); - """ - html_code = f""" -
- - """ +# marked as migrated to _common +from __future__ import annotations - return display(HTML(html_code)) +from tidy3d._common.components.viz.plot_sim_3d import ( + plot_scene_3d, + plot_sim_3d, +) diff --git a/tidy3d/components/viz/styles.py b/tidy3d/components/viz/styles.py index 067afa9327..77f0a87390 100644 --- a/tidy3d/components/viz/styles.py +++ b/tidy3d/components/viz/styles.py @@ -1,41 +1,21 @@ -from __future__ import annotations - -try: - from matplotlib.patches import ArrowStyle +"""Compatibility shim for :mod:`tidy3d._common.components.viz.styles`.""" - arrow_style = ArrowStyle.Simple(head_length=11, head_width=9, tail_width=4) -except ImportError: - arrow_style = None +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -FLEXCOMPUTE_COLORS = { - "brand_green": "#00643C", - "brand_tan": "#B8A18B", - "brand_blue": "#6DB5DD", - "brand_purple": "#8851AD", - "brand_black": "#000000", - "brand_orange": "#FC7A4C", -} -ARROW_COLOR_SOURCE = FLEXCOMPUTE_COLORS["brand_green"] -ARROW_COLOR_POLARIZATION = FLEXCOMPUTE_COLORS["brand_tan"] -ARROW_COLOR_MONITOR = FLEXCOMPUTE_COLORS["brand_orange"] -ARROW_COLOR_ABSORBER = FLEXCOMPUTE_COLORS["brand_blue"] -PLOT_BUFFER = 0.3 -ARROW_ALPHA = 0.8 -ARROW_LENGTH = 0.3 - -# stores color of simulation.structures for given index in simulation.medium_map -MEDIUM_CMAP = [ - "#689DBC", - "#D0698E", - "#5E6EAD", - "#C6224E", - "#BDB3E2", - "#9EC3E0", - "#616161", - "#877EBC", -] +# marked as migrated to _common +from __future__ import annotations -# colormap for structure's permittivity in plot_eps -STRUCTURE_EPS_CMAP = "gist_yarg" -STRUCTURE_EPS_CMAP_R = "gist_yarg_r" -STRUCTURE_HEAT_COND_CMAP = "gist_yarg" +from tidy3d._common.components.viz.styles import ( + ARROW_ALPHA, + ARROW_COLOR_ABSORBER, + ARROW_COLOR_MONITOR, + ARROW_COLOR_POLARIZATION, + ARROW_COLOR_SOURCE, + ARROW_LENGTH, + FLEXCOMPUTE_COLORS, + MEDIUM_CMAP, + PLOT_BUFFER, + STRUCTURE_EPS_CMAP, + STRUCTURE_EPS_CMAP_R, + STRUCTURE_HEAT_COND_CMAP, +) diff --git a/tidy3d/components/viz/visualization_spec.py b/tidy3d/components/viz/visualization_spec.py index abe22ed7cf..58070983c7 100644 --- a/tidy3d/components/viz/visualization_spec.py +++ b/tidy3d/components/viz/visualization_spec.py @@ -1,62 +1,12 @@ -from __future__ import annotations - -from typing import Any, Optional - -import pydantic.v1 as pd - -from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.log import log - -MATPLOTLIB_IMPORTED = True -try: - from matplotlib.colors import is_color_like -except ImportError: - is_color_like = None - MATPLOTLIB_IMPORTED = False - - -def is_valid_color(value: str) -> str: - if not MATPLOTLIB_IMPORTED: - log.warning( - "matplotlib was not successfully imported, but is required " - "to validate colors in the VisualizationSpec. The specified colors " - "have not been validated." 
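
The removed validator treats matplotlib as optional: validate when available, warn and pass through when not. The guard pattern in isolation (a sketch, not the migrated implementation):

try:
    from matplotlib.colors import is_color_like
except ImportError:
    is_color_like = None  # matplotlib absent: validation is skipped

def check_color(value: str) -> str:
    if is_color_like is None:
        return value  # cannot validate without matplotlib; accept as-is
    if not is_color_like(value):
        raise ValueError(f"{value} is not a valid plotting color")
    return value

check_color("#00643C")  # a brand color from styles.py; passes validation
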
- ) - else: - if is_color_like is not None and not is_color_like(value): - raise ValueError(f"{value} is not a valid plotting color") +"""Compatibility shim for :mod:`tidy3d._common.components.viz.visualization_spec`.""" - return value +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class VisualizationSpec(Tidy3dBaseModel): - """Defines specification for visualization when used with plotting functions.""" - - facecolor: str = pd.Field( - "", - title="Face color", - description="Color applied to the faces in visualization.", - ) - - edgecolor: Optional[str] = pd.Field( - "", - title="Edge color", - description="Color applied to the edges in visualization.", - ) - - alpha: Optional[pd.confloat(ge=0.0, le=1.0)] = pd.Field( - 1.0, - title="Opacity", - description="Opacity/alpha value in plotting between 0 and 1.", - ) - - @pd.validator("facecolor", always=True) - def validate_color(value: str) -> str: - return is_valid_color(value) - - @pd.validator("edgecolor", always=True) - def validate_and_copy_color(value: str, values: dict[str, Any]) -> str: - if (value == "") and "facecolor" in values: - return is_valid_color(values["facecolor"]) - - return is_valid_color(value) +from tidy3d._common.components.viz.visualization_spec import ( + MATPLOTLIB_IMPORTED, + VisualizationSpec, + is_valid_color, +) diff --git a/tidy3d/config/__init__.py b/tidy3d/config/__init__.py index 8865c1ec95..90a9963354 100644 --- a/tidy3d/config/__init__.py +++ b/tidy3d/config/__init__.py @@ -1,69 +1,32 @@ -"""Tidy3D configuration system public API.""" +"""Compatibility shim for :mod:`tidy3d._common.config`.""" -from __future__ import annotations +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -from typing import Any +# marked as migrated to _common +from __future__ import annotations -from . import sections # noqa: F401 - ensure builtin sections register -from .legacy import LegacyConfigWrapper, LegacyEnvironment, LegacyEnvironmentConfig -from .manager import ConfigManager -from .registry import ( +import tidy3d._common.config as _common_config +from tidy3d.config import sections + +_common_config.initialize_env() + +from tidy3d._common.config import ( # noqa: E402 - import after Env setup + ConfigManager, + Env, + Environment, + EnvironmentConfig, + LegacyConfigWrapper, + LegacyEnvironment, + LegacyEnvironmentConfig, + _base_manager, + _config_wrapper, + _create_manager, + config, get_handlers, + get_manager, get_sections, register_handler, register_plugin, register_section, + reload_config, ) - -__all__ = [ - "ConfigManager", - "Env", - "Environment", - "EnvironmentConfig", - "config", - "get_handlers", - "get_sections", - "register_handler", - "register_plugin", - "register_section", -] - - -def _create_manager() -> ConfigManager: - return ConfigManager() - - -_base_manager = _create_manager() -# TODO(FXC-3827): Drop LegacyConfigWrapper once legacy accessors are removed in Tidy3D 2.12. -_config_wrapper = LegacyConfigWrapper(_base_manager) -config = _config_wrapper - -# TODO(FXC-3827): Remove legacy Env exports after deprecation window (planned 2.12). 
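
The `edgecolor` validator above defaults an empty edge color to the face color. Its behavior, minus the pydantic machinery (a sketch):

def resolve_edgecolor(edgecolor: str, facecolor: str) -> str:
    # An unset ("") edge color inherits the face color, as in validate_and_copy_color.
    return facecolor if edgecolor == "" else edgecolor

assert resolve_edgecolor("", "orange") == "orange"
assert resolve_edgecolor("black", "orange") == "black"
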
-Environment = LegacyEnvironment -EnvironmentConfig = LegacyEnvironmentConfig -Env = LegacyEnvironment(_base_manager) - - -def reload_config(*, profile: str | None = None) -> LegacyConfigWrapper: - """Recreate the global configuration manager (primarily for tests).""" - - global _base_manager, Env - if _base_manager is not None: - try: - _base_manager.apply_web_env({}) - except AttributeError: - pass - _base_manager = ConfigManager(profile=profile) - _config_wrapper.reset_manager(_base_manager) - Env.reset_manager(_base_manager) - return _config_wrapper - - -def get_manager() -> ConfigManager: - """Return the underlying configuration manager instance.""" - - return _base_manager - - -def __getattr__(name: str) -> Any: - return getattr(config, name) diff --git a/tidy3d/config/legacy.py b/tidy3d/config/legacy.py index 75d94d20f1..1356bafe49 100644 --- a/tidy3d/config/legacy.py +++ b/tidy3d/config/legacy.py @@ -1,536 +1,16 @@ -"""Legacy compatibility layer for tidy3d.config. +"""Compatibility shim for :mod:`tidy3d._common.config.legacy`.""" -This module holds (most) of the compatibility layer to the pre-2.10 tidy3d config -and is intended to be removed in a future release. -""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -import os -import warnings -from pathlib import Path -from typing import Any, Optional - -import toml - -from tidy3d._runtime import WASM_BUILD -from tidy3d.log import LogLevel, log - -# TODO(FXC-3827): Remove LegacyConfigWrapper/Environment shims and related helpers in Tidy3D 2.12. -from .manager import ConfigManager, normalize_profile_name -from .profiles import BUILTIN_PROFILES - - -def _warn_env_deprecated() -> None: - message = "'tidy3d.config.Env' is deprecated; use 'config.switch_profile(...)' instead." - warnings.warn(message, DeprecationWarning, stacklevel=3) - log.warning(message, log_once=True) - - -# TODO(FXC-3827): Delete LegacyConfigWrapper once legacy attribute access is dropped. 
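
The re-exported helpers keep the old test workflow intact. Hypothetical usage of the two entry points removed above (now imported from `tidy3d._common.config`):

from tidy3d.config import get_manager, reload_config

# Rebuild the global manager (primarily for tests), then inspect it.
wrapper = reload_config(profile="default")
manager = get_manager()
assert manager.profile == "default"
print(wrapper.logging_level)  # legacy flat accessor still proxies the manager
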
-class LegacyConfigWrapper: - """Provide attribute-level compatibility with the legacy config module.""" - - def __init__(self, manager: ConfigManager): - self._manager = manager - self._frozen = False # retained for backwards compatibility tests - - @property - def logging_level(self) -> LogLevel: - return self._manager.get_section("logging").level - - @logging_level.setter - def logging_level(self, value: LogLevel) -> None: - from warnings import warn - - warn( - "'config.logging_level' is deprecated; use 'config.logging.level' instead.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("logging", level=value) - - @property - def log_suppression(self) -> bool: - return self._manager.get_section("logging").suppression - - @log_suppression.setter - def log_suppression(self, value: bool) -> None: - from warnings import warn - - warn( - "'config.log_suppression' is deprecated; use 'config.logging.suppression'.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("logging", suppression=value) - - @property - def use_local_subpixel(self) -> Optional[bool]: - return self._manager.get_section("simulation").use_local_subpixel - - @use_local_subpixel.setter - def use_local_subpixel(self, value: Optional[bool]) -> None: - from warnings import warn - - warn( - "'config.use_local_subpixel' is deprecated; use 'config.simulation.use_local_subpixel'.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("simulation", use_local_subpixel=value) - - @property - def suppress_rf_license_warning(self) -> bool: - return self._manager.get_section("microwave").suppress_rf_license_warning - - @suppress_rf_license_warning.setter - def suppress_rf_license_warning(self, value: bool) -> None: - from warnings import warn - - warn( - "'config.suppress_rf_license_warning' is deprecated; " - "use 'config.microwave.suppress_rf_license_warning'.", - DeprecationWarning, - stacklevel=2, - ) - self._manager.update_section("microwave", suppress_rf_license_warning=value) - - @property - def frozen(self) -> bool: - return self._frozen - - @frozen.setter - def frozen(self, value: bool) -> None: - self._frozen = bool(value) - - def save(self, include_defaults: bool = False) -> None: - self._manager.save(include_defaults=include_defaults) - - def reset_manager(self, manager: ConfigManager) -> None: - """Swap the underlying manager instance.""" - - self._manager = manager - - def switch_profile(self, profile: str) -> None: - """Switch active profile and synchronize the legacy environment proxy.""" - - normalized = normalize_profile_name(profile) - self._manager.switch_profile(normalized) - try: - from tidy3d.config import Env as _legacy_env - except Exception: - _legacy_env = None - if _legacy_env is not None: - _legacy_env._sync_to_manager(apply_env=True) - - def __getattr__(self, name: str) -> Any: - return getattr(self._manager, name) - - def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_"): - object.__setattr__(self, name, value) - elif name in { - "logging_level", - "log_suppression", - "use_local_subpixel", - "suppress_rf_license_warning", - "frozen", - }: - prop = getattr(type(self), name) - prop.fset(self, value) - else: - setattr(self._manager, name, value) - - def __str__(self) -> str: - return self._manager.format() - - -# TODO(FXC-3827): Delete LegacyEnvironmentConfig once profile-based Env shim is removed. 
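
Each legacy accessor in the wrapper follows one pattern: read from the new section, warn on write, forward the write. The pattern distilled (a sketch, not the pydantic-backed original):

import warnings

class DeprecatedFlatAccessor:
    def __init__(self):
        self._sections = {"logging": {"level": "WARNING"}}

    @property
    def logging_level(self):
        # Reads silently proxy the new sectioned config.
        return self._sections["logging"]["level"]

    @logging_level.setter
    def logging_level(self, value):
        # Writes emit a DeprecationWarning, then forward to the section.
        warnings.warn(
            "'logging_level' is deprecated; use 'logging.level' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self._sections["logging"]["level"] = value
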
-class LegacyEnvironmentConfig: - """Backward compatible environment config wrapper that proxies ConfigManager.""" - - def __init__( - self, - manager: Optional[ConfigManager] = None, - name: Optional[str] = None, - *, - web_api_endpoint: Optional[str] = None, - website_endpoint: Optional[str] = None, - s3_region: Optional[str] = None, - ssl_verify: Optional[bool] = None, - enable_caching: Optional[bool] = None, - ssl_version: Optional[str] = None, - env_vars: Optional[dict[str, str]] = None, - environment: Optional[LegacyEnvironment] = None, - ) -> None: - if name is None: - raise ValueError("Environment name is required") - self._manager = manager - self._name = normalize_profile_name(name) - self._environment = environment - self._pending: dict[str, Any] = {} - if web_api_endpoint is not None: - self._pending["api_endpoint"] = web_api_endpoint - if website_endpoint is not None: - self._pending["website_endpoint"] = website_endpoint - if s3_region is not None: - self._pending["s3_region"] = s3_region - if ssl_verify is not None: - self._pending["ssl_verify"] = ssl_verify - if enable_caching is not None: - self._pending["enable_caching"] = enable_caching - if ssl_version is not None: - self._pending["ssl_version"] = ssl_version - if env_vars is not None: - self._pending["env_vars"] = dict(env_vars) - - def reset_manager(self, manager: ConfigManager) -> None: - self._manager = manager - - @property - def manager(self) -> Optional[ConfigManager]: - if self._manager is not None: - return self._manager - if self._environment is not None: - return self._environment._manager - return None - - def active(self) -> None: - _warn_env_deprecated() - environment = self._environment - if environment is None: - from tidy3d.config import Env # local import to avoid circular - - environment = Env - - environment.set_current(self) - - @property - def web_api_endpoint(self) -> Optional[str]: - value = self._value("api_endpoint") - return _maybe_str(value) - - @property - def website_endpoint(self) -> Optional[str]: - value = self._value("website_endpoint") - return _maybe_str(value) - - @property - def s3_region(self) -> Optional[str]: - return self._value("s3_region") - - @property - def ssl_verify(self) -> bool: - value = self._value("ssl_verify") - if value is None: - return True - return bool(value) - - @property - def enable_caching(self) -> bool: - value = self._value("enable_caching") - if value is None: - return True - return bool(value) - - @enable_caching.setter - def enable_caching(self, value: Optional[bool]) -> None: - self._set_pending("enable_caching", value) - - @property - def ssl_version(self) -> Optional[str]: - return self._value("ssl_version") - - @ssl_version.setter - def ssl_version(self, value: Optional[str]) -> None: - self._set_pending("ssl_version", value) - - @property - def env_vars(self) -> dict[str, str]: - value = self._value("env_vars") - if value is None: - return {} - return dict(value) - - @env_vars.setter - def env_vars(self, value: dict[str, str]) -> None: - self._set_pending("env_vars", dict(value)) - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, value: str) -> None: - self._name = normalize_profile_name(value) - - def copy_state_from(self, other: LegacyEnvironmentConfig) -> None: - if not isinstance(other, LegacyEnvironmentConfig): - raise TypeError("Expected LegacyEnvironmentConfig instance.") - for key, value in other._pending.items(): - if key == "env_vars" and value is not None: - self._pending[key] = dict(value) - 
else: - self._pending[key] = value - - def get_real_url(self, path: str) -> str: - manager = self.manager - if manager is not None and manager.profile == self._name: - web_section = manager.get_section("web") - if hasattr(web_section, "build_api_url"): - return web_section.build_api_url(path) - - endpoint = self.web_api_endpoint or "" - if not path: - return endpoint - return "/".join([endpoint.rstrip("/"), str(path).lstrip("/")]) - - def apply_pending_overrides(self) -> None: - manager = self.manager - if manager is None or manager.profile != self._name: - return - if not self._pending: - return - updates = dict(self._pending) - manager.update_section("web", **updates) - self._pending.clear() - - def _set_pending(self, key: str, value: Any) -> None: - if key == "env_vars" and value is not None: - self._pending[key] = dict(value) - else: - self._pending[key] = value - self.apply_pending_overrides() - - def _web_section(self) -> dict[str, Any]: - manager = self.manager - if manager is None or WASM_BUILD: - return {} - profile = normalize_profile_name(self._name) - if manager.profile == profile: - section = manager.get_section("web") - return section.model_dump(mode="python", exclude_unset=False) - preview = manager.preview_profile(profile) - source = preview.get("web", {}) - return dict(source) if isinstance(source, dict) else {} - - def _value(self, key: str) -> Any: - if key in self._pending: - return self._pending[key] - return self._web_section().get(key) - - -# TODO(FXC-3827): Delete LegacyEnvironment after deprecating `tidy3d.config.Env`. -class LegacyEnvironment: - """Legacy Env wrapper that maps to profiles.""" - - def __init__(self, manager: ConfigManager): - self._previous_env_vars: dict[str, Optional[str]] = {} - self.env_map: dict[str, LegacyEnvironmentConfig] = {} - self._current: Optional[LegacyEnvironmentConfig] = None - self._manager: Optional[ConfigManager] = None - self._applied_profile: Optional[str] = None - self.reset_manager(manager) - - def reset_manager(self, manager: ConfigManager) -> None: - self._manager = manager - self.env_map = {} - for name in BUILTIN_PROFILES: - key = normalize_profile_name(name) - self.env_map[key] = LegacyEnvironmentConfig(manager, key, environment=self) - self._applied_profile = None - self._current = None - self._sync_to_manager(apply_env=True) - - @property - def current(self) -> LegacyEnvironmentConfig: - self._sync_to_manager() - assert self._current is not None - return self._current - - def set_current(self, env_config: LegacyEnvironmentConfig) -> None: - _warn_env_deprecated() - key = normalize_profile_name(env_config.name) - stored = self._get_config(key) - stored.copy_state_from(env_config) - if self._manager and self._manager.profile != key: - self._manager.switch_profile(key) - self._sync_to_manager(apply_env=True) - - def enable_caching(self, enable_caching: Optional[bool] = True) -> None: - config = self.current - config.enable_caching = enable_caching - self._sync_to_manager() - - def set_ssl_version(self, ssl_version: Optional[str]) -> None: - config = self.current - config.ssl_version = ssl_version - self._sync_to_manager() - - def __getattr__(self, name: str) -> LegacyEnvironmentConfig: - return self._get_config(name) - - def _get_config(self, name: str) -> LegacyEnvironmentConfig: - key = normalize_profile_name(name) - config = self.env_map.get(key) - if config is None: - config = LegacyEnvironmentConfig(self._manager, key, environment=self) - self.env_map[key] = config - else: - manager = self._manager - if manager is 
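
`get_real_url` above falls back to plain string joining when the manager is on a different profile. That fallback in isolation (a sketch; the endpoint is a placeholder):

def join_url(endpoint: str, path: str) -> str:
    # Mirrors the fallback branch: strip one slash from each side, join with one.
    if not path:
        return endpoint
    return "/".join([endpoint.rstrip("/"), str(path).lstrip("/")])

assert join_url("https://api.example.com/", "/tasks/abc") == "https://api.example.com/tasks/abc"
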
not None: - config.reset_manager(manager) - config._environment = self - return config - - def _sync_to_manager(self, *, apply_env: bool = False) -> None: - if self._manager is None: - return - active = normalize_profile_name(self._manager.profile) - config = self._get_config(active) - config.apply_pending_overrides() - self._current = config - if apply_env or self._applied_profile != active: - self._apply_env_vars(config) - self._applied_profile = active - - def _apply_env_vars(self, config: LegacyEnvironmentConfig) -> None: - self._restore_env_vars() - env_vars = config.env_vars or {} - self._previous_env_vars = {} - for key, value in env_vars.items(): - self._previous_env_vars[key] = os.environ.get(key) - os.environ[key] = value - - def _restore_env_vars(self) -> None: - for key, previous in self._previous_env_vars.items(): - if previous is None: - os.environ.pop(key, None) - else: - os.environ[key] = previous - self._previous_env_vars = {} - - -def _maybe_str(value: Any) -> Optional[str]: - if value is None: - return None - return str(value) - - -def load_legacy_flat_config(config_dir: Path) -> dict[str, Any]: - """Load legacy flat configuration file (pre-migration format). - - This function now supports both the original flat config format and - Nexus custom deployment settings introduced in later versions. - - Legacy key mappings: - - apikey -> web.apikey - - web_api_endpoint -> web.api_endpoint - - website_endpoint -> web.website_endpoint - - s3_region -> web.s3_region - - s3_endpoint -> web.env_vars.AWS_ENDPOINT_URL_S3 - - ssl_verify -> web.ssl_verify - - enable_caching -> web.enable_caching - """ - - legacy_path = config_dir / "config" - if not legacy_path.exists(): - return {} - - try: - text = legacy_path.read_text(encoding="utf-8") - except Exception as exc: - log.warning(f"Failed to read legacy configuration file '{legacy_path}': {exc}") - return {} - - try: - parsed = toml.loads(text) - except Exception as exc: - log.warning(f"Failed to decode legacy configuration file '{legacy_path}': {exc}") - return {} - - legacy_data: dict[str, Any] = {} - - # Migrate API key (original functionality) - apikey = parsed.get("apikey") - if apikey is not None: - legacy_data.setdefault("web", {})["apikey"] = apikey - - # Migrate Nexus API endpoint - web_api = parsed.get("web_api_endpoint") - if web_api is not None: - legacy_data.setdefault("web", {})["api_endpoint"] = web_api - - # Migrate Nexus website endpoint - website = parsed.get("website_endpoint") - if website is not None: - legacy_data.setdefault("web", {})["website_endpoint"] = website - - # Migrate S3 region - s3_region = parsed.get("s3_region") - if s3_region is not None: - legacy_data.setdefault("web", {})["s3_region"] = s3_region - - # Migrate SSL verification setting - ssl_verify = parsed.get("ssl_verify") - if ssl_verify is not None: - legacy_data.setdefault("web", {})["ssl_verify"] = ssl_verify - - # Migrate caching setting - enable_caching = parsed.get("enable_caching") - if enable_caching is not None: - legacy_data.setdefault("web", {})["enable_caching"] = enable_caching - - # Migrate S3 endpoint to env_vars - s3_endpoint = parsed.get("s3_endpoint") - if s3_endpoint is not None: - env_vars = legacy_data.setdefault("web", {}).setdefault("env_vars", {}) - env_vars["AWS_ENDPOINT_URL_S3"] = s3_endpoint - - return legacy_data - - -__all__ = [ - "LegacyConfigWrapper", - "LegacyEnvironment", - "LegacyEnvironmentConfig", - "finalize_legacy_migration", - "load_legacy_flat_config", -] - - -def finalize_legacy_migration(config_dir: 
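
The save-and-restore dance around `os.environ` used by `_apply_env_vars`/`_restore_env_vars`, as a standalone pair (a sketch; the endpoint value is a placeholder):

import os

def apply_env_vars(env_vars: dict[str, str]) -> dict[str, str | None]:
    previous = {}
    for key, value in env_vars.items():
        previous[key] = os.environ.get(key)  # remember prior value (None if unset)
        os.environ[key] = value
    return previous

def restore_env_vars(previous: dict[str, str | None]) -> None:
    for key, old in previous.items():
        if old is None:
            os.environ.pop(key, None)  # was unset before: remove again
        else:
            os.environ[key] = old

before = os.environ.get("AWS_ENDPOINT_URL_S3")
saved = apply_env_vars({"AWS_ENDPOINT_URL_S3": "http://localhost:9000"})
restore_env_vars(saved)
assert os.environ.get("AWS_ENDPOINT_URL_S3") == before
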
Path) -> None: - """Promote a copied legacy configuration tree into the structured format. - - Parameters - ---------- - config_dir : Path - Destination directory (typically the canonical config location). - """ - - legacy_data = load_legacy_flat_config(config_dir) - - from .manager import ConfigManager # local import to avoid circular dependency - - manager = ConfigManager(profile="default", config_dir=config_dir) - config_path = config_dir / "config.toml" - for section, values in legacy_data.items(): - if isinstance(values, dict): - manager.update_section(section, **values) - try: - manager.save(include_defaults=True) - except Exception: - if config_path.exists(): - try: - config_path.unlink() - except Exception: - pass - raise - - legacy_flat_path = config_dir / "config" - if legacy_flat_path.exists(): - try: - legacy_flat_path.unlink() - except Exception as exc: - log.warning(f"Failed to remove legacy configuration file '{legacy_flat_path}': {exc}") +from tidy3d._common.config.legacy import ( + LegacyConfigWrapper, + LegacyEnvironment, + LegacyEnvironmentConfig, + _maybe_str, + _warn_env_deprecated, + finalize_legacy_migration, + load_legacy_flat_config, +) diff --git a/tidy3d/config/loader.py b/tidy3d/config/loader.py index 21782244c4..b614d63401 100644 --- a/tidy3d/config/loader.py +++ b/tidy3d/config/loader.py @@ -1,438 +1,23 @@ -"""Filesystem helpers and persistence utilities for the configuration system.""" +"""Compatibility shim for :mod:`tidy3d._common.config.loader`.""" -from __future__ import annotations - -import os -import shutil -import tempfile -from copy import deepcopy -from pathlib import Path -from typing import Any, Optional - -import toml -import tomlkit - -from tidy3d.log import log - -from .profiles import BUILTIN_PROFILES -from .serializer import build_document, collect_descriptions - - -class ConfigLoader: - """Handle reading and writing configuration files.""" - - def __init__(self, config_dir: Optional[Path] = None): - self.config_dir = config_dir or resolve_config_directory() - self.config_dir.mkdir(mode=0o700, parents=True, exist_ok=True) - self._docs: dict[Path, tomlkit.TOMLDocument] = {} - - def load_base(self) -> dict[str, Any]: - """Load base configuration from config.toml. - - If config.toml doesn't exist but the legacy flat config does, - automatically migrate to the new format. - """ - - config_path = self.config_dir / "config.toml" - data = self._read_toml(config_path) - if data: - return data - - # Check for legacy flat config - from .legacy import load_legacy_flat_config - - legacy_path = self.config_dir / "config" - legacy = load_legacy_flat_config(self.config_dir) - - # Auto-migrate if legacy config exists - if legacy and legacy_path.exists(): - log.info( - f"Detected legacy configuration at '{legacy_path}'. " - "Automatically migrating to new format..." - ) - - try: - # Save in new format - self.save_base(legacy) - - # Rename old config to preserve it - backup_path = legacy_path.with_suffix(".migrated") - legacy_path.rename(backup_path) - - log.info( - f"Migration complete. Configuration saved to '{config_path}'. " - f"Legacy config backed up as '{backup_path.name}'." - ) - - # Re-read the newly created config - return self._read_toml(config_path) - except Exception as exc: - log.warning( - f"Failed to auto-migrate legacy configuration: {exc}. " - "Using legacy data without migration." 
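
The key mappings documented in `load_legacy_flat_config` boil down to a small table. A condensed version of the migration (a sketch omitting the `s3_endpoint` -> env-var special case):

import toml

LEGACY_KEY_MAP = {
    "apikey": "apikey",
    "web_api_endpoint": "api_endpoint",
    "website_endpoint": "website_endpoint",
    "s3_region": "s3_region",
    "ssl_verify": "ssl_verify",
    "enable_caching": "enable_caching",
}

def migrate_flat_config(text: str) -> dict:
    # Every recognized flat key lands under the "web" section with its new name.
    parsed = toml.loads(text)
    web = {new: parsed[old] for old, new in LEGACY_KEY_MAP.items() if old in parsed}
    return {"web": web} if web else {}

assert migrate_flat_config('apikey = "abc"\nssl_verify = false') == {
    "web": {"apikey": "abc", "ssl_verify": False}
}
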
- ) - return legacy - - if legacy: - return legacy - return {} - - def load_user_profile(self, profile: str) -> dict[str, Any]: - """Load user profile overrides (if any).""" - - if profile in ("default", "prod"): - # default and prod share the same baseline; user overrides live in config.toml - return {} - - profile_path = self.profile_path(profile) - return self._read_toml(profile_path) - - def get_builtin_profile(self, profile: str) -> dict[str, Any]: - """Return builtin profile data if available.""" - - return BUILTIN_PROFILES.get(profile, {}) - - def save_base(self, data: dict[str, Any]) -> None: - """Persist base configuration.""" - - config_path = self.config_dir / "config.toml" - self._atomic_write(config_path, data) - - def save_profile(self, profile: str, data: dict[str, Any]) -> None: - """Persist profile overrides (remove file if empty).""" - - profile_path = self.profile_path(profile) - if not data: - if profile_path.exists(): - profile_path.unlink() - self._docs.pop(profile_path, None) - return - profile_path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) - self._atomic_write(profile_path, data) - - def profile_path(self, profile: str) -> Path: - """Return on-disk path for a profile.""" - - return self.config_dir / "profiles" / f"{profile}.toml" - - def get_default_profile(self) -> Optional[str]: - """Read the default_profile from config.toml. - - Returns - ------- - Optional[str] - The default profile name if set, None otherwise. - """ - - config_path = self.config_dir / "config.toml" - if not config_path.exists(): - return None - - try: - text = config_path.read_text(encoding="utf-8") - data = toml.loads(text) - return data.get("default_profile") - except Exception as exc: - log.warning(f"Failed to read default_profile from '{config_path}': {exc}") - return None - - def set_default_profile(self, profile: Optional[str]) -> None: - """Set the default_profile in config.toml. - - Parameters - ---------- - profile : Optional[str] - The profile name to set as default, or None to remove the setting. 
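
`get_default_profile`/`set_default_profile` treat `default_profile` as a plain top-level key in `config.toml`, next to the sections. Reading it is one line (a sketch; the profile name is made up):

import toml

def read_default_profile(text: str):
    # default_profile sits at the top level of config.toml, outside any table.
    return toml.loads(text).get("default_profile")

assert read_default_profile('default_profile = "dev"\n[web]\nssl_verify = true') == "dev"
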
- """ - - config_path = self.config_dir / "config.toml" - data = self._read_toml(config_path) - - if profile is None: - # Remove default_profile if it exists - if "default_profile" in data: - del data["default_profile"] - else: - # Set default_profile as a top-level key - data["default_profile"] = profile - - self._atomic_write(config_path, data) - - def _read_toml(self, path: Path) -> dict[str, Any]: - if not path.exists(): - self._docs.pop(path, None) - return {} - - try: - text = path.read_text(encoding="utf-8") - except Exception as exc: - log.warning(f"Failed to read configuration file '{path}': {exc}") - self._docs.pop(path, None) - return {} - - try: - document = tomlkit.parse(text) - except Exception as exc: - log.warning(f"Failed to parse configuration file '{path}': {exc}") - document = tomlkit.document() - self._docs[path] = document - - try: - return toml.loads(text) - except Exception as exc: - log.warning(f"Failed to decode configuration file '{path}': {exc}") - return {} - - def _atomic_write(self, path: Path, data: dict[str, Any]) -> None: - path.parent.mkdir(mode=0o700, parents=True, exist_ok=True) - tmp_dir = path.parent - - cleaned = _clean_data(deepcopy(data)) - descriptions = collect_descriptions() - - base_document = self._docs.get(path) - document = build_document(cleaned, base_document, descriptions) - toml_text = tomlkit.dumps(document) - - with tempfile.NamedTemporaryFile( - "w", dir=tmp_dir, delete=False, encoding="utf-8" - ) as handle: - tmp_path = Path(handle.name) - handle.write(toml_text) - handle.flush() - os.fsync(handle.fileno()) +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - backup_path = path.with_suffix(path.suffix + ".bak") - try: - if path.exists(): - shutil.copy2(path, backup_path) - tmp_path.replace(path) - os.chmod(path, 0o600) - if backup_path.exists(): - backup_path.unlink() - except Exception: - if tmp_path.exists(): - tmp_path.unlink() - if backup_path.exists(): - try: - backup_path.replace(path) - except Exception: - log.warning("Failed to restore configuration backup") - raise - - self._docs[path] = tomlkit.parse(toml_text) - - -def load_environment_overrides() -> dict[str, Any]: - """Parse environment variables into a nested configuration dict.""" - - overrides: dict[str, Any] = {} - for key, value in os.environ.items(): - if key == "SIMCLOUD_APIKEY": - _assign_path(overrides, ("web", "apikey"), value) - continue - if not key.startswith("TIDY3D_"): - continue - rest = key[len("TIDY3D_") :] - if "__" not in rest: - continue - segments = tuple(segment.lower() for segment in rest.split("__") if segment) - if not segments: - continue - if segments[0] == "auth": - segments = ("web",) + segments[1:] - _assign_path(overrides, segments, value) - return overrides - - -def deep_merge(*sources: dict[str, Any]) -> dict[str, Any]: - """Deep merge multiple dictionaries into a new dict.""" - - result: dict[str, Any] = {} - for source in sources: - _merge_into(result, source) - return result - - -def _merge_into(target: dict[str, Any], source: dict[str, Any]) -> None: - for key, value in source.items(): - if isinstance(value, dict): - node = target.setdefault(key, {}) - if isinstance(node, dict): - _merge_into(node, value) - else: - target[key] = deepcopy(value) - else: - target[key] = value - - -def deep_diff(base: dict[str, Any], target: dict[str, Any]) -> dict[str, Any]: - """Return keys from target that differ from base.""" - - diff: dict[str, Any] = {} - keys = set(base.keys()) | set(target.keys()) - for key in keys: - 
base_value = base.get(key) - target_value = target.get(key) - if isinstance(base_value, dict) and isinstance(target_value, dict): - nested = deep_diff(base_value, target_value) - if nested: - diff[key] = nested - elif target_value != base_value: - if isinstance(target_value, dict): - diff[key] = deepcopy(target_value) - else: - diff[key] = target_value - return diff - - -def _assign_path(target: dict[str, Any], path: tuple[str, ...], value: Any) -> None: - node = target - for segment in path[:-1]: - node = node.setdefault(segment, {}) - node[path[-1]] = value - - -def _clean_data(data: Any) -> Any: - if isinstance(data, dict): - cleaned: dict[str, Any] = {} - for key, value in data.items(): - cleaned_value = _clean_data(value) - if cleaned_value is None: - continue - cleaned[key] = cleaned_value - return cleaned - if isinstance(data, list): - cleaned_list = [_clean_data(item) for item in data] - return [item for item in cleaned_list if item is not None] - if data is None: - return None - return data - - -def legacy_config_directory() -> Path: - """Return the legacy configuration directory (~/.tidy3d).""" - - return Path.home() / ".tidy3d" - - -def canonical_config_directory() -> Path: - """Return the platform-dependent canonical configuration directory.""" - - return _xdg_config_home() / "tidy3d" - - -def resolve_config_directory() -> Path: - """Determine the directory used to store tidy3d configuration files.""" - - base_override = os.getenv("TIDY3D_BASE_DIR") - if base_override: - base_path = Path(base_override).expanduser().resolve() - path = base_path / "config" - if _is_writable(path.parent): - return path - log.warning( - "'TIDY3D_BASE_DIR' is not writable; using temporary configuration directory instead." - ) - return _temporary_config_dir() - - canonical_dir = canonical_config_directory() - if _is_writable(canonical_dir.parent): - legacy_dir = legacy_config_directory() - if legacy_dir.exists(): - log.warning( - f"Using canonical configuration directory at '{canonical_dir}'. " - "Found legacy directory at '~/.tidy3d', which will be ignored. " - "Remove it manually or run 'tidy3d config migrate --delete-legacy' to clean up.", - log_once=True, - ) - return canonical_dir - - legacy_dir = legacy_config_directory() - if legacy_dir.exists(): - log.warning( - "Configuration found in legacy location '~/.tidy3d'. Consider running 'tidy3d config migrate'.", - log_once=True, - ) - return legacy_dir - - log.warning(f"Unable to write to '{canonical_dir}'; falling back to temporary directory.") - return _temporary_config_dir() - - -def _xdg_config_home() -> Path: - xdg_home = os.getenv("XDG_CONFIG_HOME") - if xdg_home: - return Path(xdg_home).expanduser() - return Path.home() / ".config" - - -def _temporary_config_dir() -> Path: - base = Path(tempfile.gettempdir()) / "tidy3d" - base.mkdir(mode=0o700, exist_ok=True) - return base / "config" - - -def _is_writable(path: Path) -> bool: - try: - path.mkdir(parents=True, exist_ok=True) - test_file = path / ".tidy3d_write_test" - with open(test_file, "w", encoding="utf-8"): - pass - test_file.unlink() - return True - except Exception: - return False - - -def migrate_legacy_config(*, overwrite: bool = False, remove_legacy: bool = False) -> Path: - """Copy configuration files from the legacy ``~/.tidy3d`` directory to the canonical location. - - Parameters - ---------- - overwrite : bool - If ``True``, existing files in the canonical directory will be replaced. 
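
`resolve_config_directory` prefers `TIDY3D_BASE_DIR`, then the canonical XDG location, then the legacy `~/.tidy3d`, then a temporary directory. The XDG helper at the bottom of that chain (a sketch):

import os
from pathlib import Path

def xdg_config_home() -> Path:
    # XDG_CONFIG_HOME when set, else the conventional ~/.config fallback.
    xdg = os.getenv("XDG_CONFIG_HOME")
    return Path(xdg).expanduser() if xdg else Path.home() / ".config"

print(xdg_config_home() / "tidy3d")  # the canonical configuration directory
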
- remove_legacy : bool - If ``True``, the legacy directory is removed after a successful migration. - - Returns - ------- - Path - The path of the canonical configuration directory. - - Raises - ------ - FileNotFoundError - If the legacy directory does not exist. - FileExistsError - If the destination already exists and ``overwrite`` is ``False``. - RuntimeError - If the legacy and canonical directories resolve to the same location. - """ - - legacy_dir = legacy_config_directory() - if not legacy_dir.exists(): - raise FileNotFoundError("Legacy configuration directory '~/.tidy3d' was not found.") - - canonical_dir = canonical_config_directory() - if canonical_dir.resolve() == legacy_dir.resolve(): - raise RuntimeError( - "Legacy and canonical configuration directories are the same path; nothing to migrate." - ) - - if canonical_dir.exists() and not overwrite: - raise FileExistsError( - f"Destination '{canonical_dir}' already exists. Pass overwrite=True to replace existing files." - ) - - canonical_dir.parent.mkdir(parents=True, exist_ok=True) - shutil.copytree(legacy_dir, canonical_dir, dirs_exist_ok=overwrite) - - from .legacy import finalize_legacy_migration # local import to avoid circular dependency - - finalize_legacy_migration(canonical_dir) - - if remove_legacy: - shutil.rmtree(legacy_dir) +# marked as migrated to _common +from __future__ import annotations - return canonical_dir +from tidy3d._common.config.loader import ( + ConfigLoader, + _assign_path, + _clean_data, + _is_writable, + _merge_into, + _temporary_config_dir, + _xdg_config_home, + canonical_config_directory, + deep_diff, + deep_merge, + legacy_config_directory, + load_environment_overrides, + migrate_legacy_config, + resolve_config_directory, +) diff --git a/tidy3d/config/manager.py b/tidy3d/config/manager.py index cbb08a4fdc..19dd6975b4 100644 --- a/tidy3d/config/manager.py +++ b/tidy3d/config/manager.py @@ -1,631 +1,24 @@ -"""Central configuration manager implementation.""" +"""Compatibility shim for :mod:`tidy3d._common.config.manager`.""" -from __future__ import annotations - -import os -import shutil -from collections import defaultdict -from collections.abc import Iterable, Mapping -from copy import deepcopy -from enum import Enum -from io import StringIO -from pathlib import Path -from typing import Any, Optional, get_args, get_origin - -from pydantic import BaseModel -from rich.console import Console -from rich.panel import Panel -from rich.pretty import Pretty -from rich.text import Text -from rich.tree import Tree - -from tidy3d.log import log - -from .loader import ConfigLoader, deep_diff, deep_merge, load_environment_overrides -from .profiles import BUILTIN_PROFILES -from .registry import attach_manager, get_handlers, get_sections - - -def normalize_profile_name(name: str) -> str: - """Return a canonical profile name for builtin profiles.""" - - normalized = name.strip() - lowered = normalized.lower() - if lowered in BUILTIN_PROFILES: - return lowered - return normalized - - -class SectionAccessor: - """Attribute proxy that routes assignments back through the manager.""" - - def __init__(self, manager: ConfigManager, path: str): - self._manager = manager - self._path = path - - def __getattr__(self, name: str) -> Any: - model = self._manager._get_model(self._path) - if model is None: - raise AttributeError(f"Section '{self._path}' is not available") - return getattr(model, name) - - def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_"): - object.__setattr__(self, name, value) - 
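
# SectionAccessor's split personality in miniature (hypothetical stand-in):
# reads hit the validated model, while writes are routed through the manager
# so validation and registered handlers run on every assignment.
class SectionProxy:
    def __init__(self, manager, path):
        object.__setattr__(self, "_manager", manager)  # bypass our own __setattr__
        object.__setattr__(self, "_path", path)

    def __getattr__(self, name):
        return getattr(self._manager.get_section(self._path), name)

    def __setattr__(self, name, value):
        self._manager.update_section(self._path, **{name: value})
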
return - self._manager.update_section(self._path, **{name: value}) - - def __repr__(self) -> str: - model = self._manager._get_model(self._path) - return f"SectionAccessor({self._path}={model!r})" - - def __rich__(self) -> Panel: - model = self._manager._get_model(self._path) - if model is None: - return Panel(Text(f"Section '{self._path}' is unavailable", style="red")) - data = _prepare_for_display(model.model_dump(exclude_unset=False)) - return _build_section_panel(self._path, data) - - def dict(self, *args: Any, **kwargs: Any) -> dict[str, Any]: - model = self._manager._get_model(self._path) - if model is None: - return {} - return model.model_dump(*args, **kwargs) - - def __str__(self) -> str: - return self._manager.format_section(self._path) - - -class PluginsAccessor: - """Provides access to registered plugin configurations.""" - - def __init__(self, manager: ConfigManager): - self._manager = manager - - def __getattr__(self, plugin: str) -> SectionAccessor: - if plugin not in self._manager._plugin_models: - raise AttributeError(f"Plugin '{plugin}' is not registered") - return SectionAccessor(self._manager, f"plugins.{plugin}") - - def list(self) -> Iterable[str]: - return sorted(self._manager._plugin_models.keys()) - - -class ProfilesAccessor: - """Read-only profile helper.""" - - def __init__(self, manager: ConfigManager): - self._manager = manager - - def list(self) -> dict[str, list[str]]: - return self._manager.list_profiles() - - def __getattr__(self, profile: str) -> dict[str, Any]: - return self._manager.preview_profile(profile) - - -class ConfigManager: - """High-level orchestrator for tidy3d configuration.""" - - def __init__( - self, - profile: Optional[str] = None, - config_dir: Optional[os.PathLike[str]] = None, - ): - loader_path = None if config_dir is None else Path(config_dir) - self._loader = ConfigLoader(loader_path) - self._runtime_overrides: dict[str, dict[str, Any]] = defaultdict(dict) - self._plugin_models: dict[str, BaseModel] = {} - self._section_models: dict[str, BaseModel] = {} - self._profile = self._resolve_initial_profile(profile) - self._builtin_data: dict[str, Any] = {} - self._base_data: dict[str, Any] = {} - self._profile_data: dict[str, Any] = {} - self._raw_tree: dict[str, Any] = {} - self._effective_tree: dict[str, Any] = {} - self._env_overrides: dict[str, Any] = load_environment_overrides() - self._web_env_previous: dict[str, Optional[str]] = {} - - attach_manager(self) - self._reload() - - # Notify users when using a non-default profile - if self._profile != "default": - log.info(f"Using configuration profile: '{self._profile}'", log_once=True) - - self._apply_handlers() - - @property - def profile(self) -> str: - return self._profile - - @property - def config_dir(self) -> Path: - return self._loader.config_dir - - @property - def plugins(self) -> PluginsAccessor: - return PluginsAccessor(self) - - @property - def profiles(self) -> ProfilesAccessor: - return ProfilesAccessor(self) - - def update_section(self, name: str, **updates: Any) -> None: - if not updates: - return - segments = name.split(".") - overrides = self._runtime_overrides[self._profile] - previous = deepcopy(overrides) - node = overrides - for segment in segments[:-1]: - node = node.setdefault(segment, {}) - section_key = segments[-1] - section_payload = node.setdefault(section_key, {}) - for key, value in updates.items(): - section_payload[key] = _serialize_value(value) - try: - self._reload() - except Exception: - self._runtime_overrides[self._profile] = previous - raise - 
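
# The try/except above makes updates transactional: snapshot the overrides,
# attempt a full reload, and restore the snapshot if validation fails.
# The same pattern in miniature (hypothetical validate() stand-in):
from copy import deepcopy

def transactional_update(state: dict, updates: dict, validate) -> None:
    previous = deepcopy(state)
    state.update(updates)
    try:
        validate(state)
    except Exception:
        state.clear()
        state.update(previous)  # roll back to the pre-update snapshot
        raise
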
self._apply_handlers(section=name) - - def switch_profile(self, profile: str) -> None: - if not profile: - raise ValueError("Profile name cannot be empty") - normalized = normalize_profile_name(profile) - if not normalized: - raise ValueError("Profile name cannot be empty") - self._profile = normalized - self._reload() - - # Notify users when switching to a non-default profile - if self._profile != "default": - log.info(f"Switched to configuration profile: '{self._profile}'") - - self._apply_handlers() - - def set_default_profile(self, profile: Optional[str]) -> None: - """Set the default profile to be used on startup. - - Parameters - ---------- - profile : Optional[str] - The profile name to use as default, or None to clear the default. - When set, this profile will be automatically loaded unless overridden - by environment variables (TIDY3D_CONFIG_PROFILE, TIDY3D_PROFILE, or TIDY3D_ENV). - - Notes - ----- - This setting is persisted to config.toml and survives across sessions. - Environment variables always take precedence over the default profile. - """ - - if profile is not None: - normalized = normalize_profile_name(profile) - if not normalized: - raise ValueError("Profile name cannot be empty") - self._loader.set_default_profile(normalized) - else: - self._loader.set_default_profile(None) - - def get_default_profile(self) -> Optional[str]: - """Get the currently configured default profile. - - Returns - ------- - Optional[str] - The default profile name if set, None otherwise. - """ - - return self._loader.get_default_profile() - - def save(self, include_defaults: bool = False) -> None: - if self._profile == "default": - # For base config: only save fields marked with persist=True - base_without_env = self._filter_persisted(self._compose_without_env()) - if include_defaults: - defaults = self._filter_persisted(self._default_tree()) - base_without_env = deep_merge(defaults, base_without_env) - self._loader.save_base(base_without_env) - else: - # For profile overrides: save any field that differs from baseline - # (don't filter by persist flag - profiles should save all customizations) - base_without_env = self._compose_without_env() - baseline = deep_merge(self._builtin_data, self._base_data) - diff = deep_diff(baseline, base_without_env) - self._loader.save_profile(self._profile, diff) - # refresh cached base/profile data after saving - self._base_data = self._loader.load_base() - self._profile_data = self._loader.load_user_profile(self._profile) - self._reload() - - def reset_to_defaults(self, *, include_profiles: bool = True) -> None: - """Reset configuration files to their default annotated state.""" - - self._runtime_overrides = defaultdict(dict) - defaults = self._filter_persisted(self._default_tree()) - self._loader.save_base(defaults) - - if include_profiles: - profiles_dir = self._loader.profile_path("_dummy").parent - if profiles_dir.exists(): - shutil.rmtree(profiles_dir) - loader_docs = getattr(self._loader, "_docs", {}) - for path in list(loader_docs.keys()): - try: - path.relative_to(profiles_dir) - except ValueError: - continue - loader_docs.pop(path, None) - self._profile = "default" - - self._reload() - self._apply_handlers() - - def apply_web_env(self, env_vars: Mapping[str, str]) -> None: - """Apply environment variable overrides for the web configuration section.""" - - self._restore_web_env() - for key, value in env_vars.items(): - self._web_env_previous[key] = os.environ.get(key) - os.environ[key] = value - - def _restore_web_env(self) -> None: - """Restore 
previously overridden environment variables.""" - - for key, previous in self._web_env_previous.items(): - if previous is None: - os.environ.pop(key, None) - else: - os.environ[key] = previous - self._web_env_previous.clear() - - def list_profiles(self) -> dict[str, list[str]]: - profiles_dir = self._loader.config_dir / "profiles" - user_profiles = [] - if profiles_dir.exists(): - for path in profiles_dir.glob("*.toml"): - user_profiles.append(path.stem) - built_in = sorted(name for name in BUILTIN_PROFILES.keys()) - return {"built_in": built_in, "user": sorted(user_profiles)} - - def preview_profile(self, profile: str) -> dict[str, Any]: - builtin = self._loader.get_builtin_profile(profile) - base = self._loader.load_base() - overrides = self._loader.load_user_profile(profile) - view = deep_merge(builtin, base, overrides) - return deepcopy(view) - - def get_section(self, name: str) -> BaseModel: - model = self._get_model(name) - if model is None: - raise AttributeError(f"Section '{name}' is not available") - return model - - def as_dict(self, include_env: bool = True) -> dict[str, Any]: - """Return the current configuration tree, including defaults for all sections.""" - - tree = self._compose_without_env() - if include_env: - tree = deep_merge(tree, self._env_overrides) - return deep_merge(self._default_tree(), tree) - - def __rich__(self) -> Panel: - """Return a rich renderable representation of the full configuration.""" - - return _build_config_panel( - title=f"Config (profile='{self._profile}')", - data=_prepare_for_display(self.as_dict(include_env=True)), - ) - - def format(self, *, include_env: bool = True) -> str: - """Return a human-friendly representation of the full configuration.""" - - panel = _build_config_panel( - title=f"Config (profile='{self._profile}')", - data=_prepare_for_display(self.as_dict(include_env=include_env)), - ) - return _render_panel(panel) - - def format_section(self, name: str) -> str: - """Return a string representation for an individual section.""" - - model = self._get_model(name) - if model is None: - raise AttributeError(f"Section '{name}' is not available") - data = _prepare_for_display(model.model_dump(exclude_unset=False)) - panel = _build_section_panel(name, data) - return _render_panel(panel) - - def on_section_registered(self, section: str) -> None: - self._reload() - self._apply_handlers(section=section) - - def on_handler_registered(self, section: str) -> None: - self._apply_handlers(section=section) - - def _resolve_initial_profile(self, profile: Optional[str]) -> str: - if profile: - return normalize_profile_name(str(profile)) - - # Check environment variables first (highest priority) - env_profile = ( - os.getenv("TIDY3D_CONFIG_PROFILE") - or os.getenv("TIDY3D_PROFILE") - or os.getenv("TIDY3D_ENV") - ) - if env_profile: - return normalize_profile_name(env_profile) - - # Check for default_profile in config file - config_default = self._loader.get_default_profile() - if config_default: - return normalize_profile_name(config_default) - - # Fall back to "default" profile - return "default" - - def _reload(self) -> None: - self._env_overrides = load_environment_overrides() - self._builtin_data = deepcopy(self._loader.get_builtin_profile(self._profile)) - self._base_data = deepcopy(self._loader.load_base()) - self._profile_data = deepcopy(self._loader.load_user_profile(self._profile)) - self._raw_tree = deep_merge(self._builtin_data, self._base_data, self._profile_data) - - runtime = deepcopy(self._runtime_overrides.get(self._profile, {})) - 
effective = deep_merge(self._raw_tree, self._env_overrides, runtime) - self._effective_tree = effective - self._build_models() - - def _build_models(self) -> None: - sections = get_sections() - new_sections: dict[str, BaseModel] = {} - new_plugins: dict[str, BaseModel] = {} - - errors: list[tuple[str, Exception]] = [] - for name, schema in sections.items(): - if name.startswith("plugins."): - plugin_name = name.split(".", 1)[1] - plugin_data = _deep_get(self._effective_tree, ("plugins", plugin_name)) or {} - try: - new_plugins[plugin_name] = schema(**plugin_data) - except Exception as exc: - log.error(f"Failed to load configuration for plugin '{plugin_name}': {exc}") - errors.append((name, exc)) - continue - if name == "plugins": - continue - section_data = self._effective_tree.get(name, {}) - try: - new_sections[name] = schema(**section_data) - except Exception as exc: - log.error(f"Failed to load configuration for section '{name}': {exc}") - errors.append((name, exc)) - - if errors: - # propagate the first error; others already logged - raise errors[0][1] - - self._section_models = new_sections - self._plugin_models = new_plugins - - def _get_model(self, name: str) -> Optional[BaseModel]: - if name.startswith("plugins."): - plugin = name.split(".", 1)[1] - return self._plugin_models.get(plugin) - return self._section_models.get(name) - - def _apply_handlers(self, section: Optional[str] = None) -> None: - handlers = get_handlers() - targets = [section] if section else handlers.keys() - for target in targets: - handler = handlers.get(target) - if handler is None: - continue - model = self._get_model(target) - if model is None: - continue - try: - handler(model) - except Exception as exc: - log.error(f"Failed to apply configuration handler for '{target}': {exc}") - - def _compose_without_env(self) -> dict[str, Any]: - runtime = self._runtime_overrides.get(self._profile, {}) - return deep_merge(self._raw_tree, runtime) - - def _default_tree(self) -> dict[str, Any]: - defaults: dict[str, Any] = {} - for name, schema in get_sections().items(): - if name.startswith("plugins."): - plugin = name.split(".", 1)[1] - defaults.setdefault("plugins", {})[plugin] = _model_dict(schema()) - elif name == "plugins": - defaults.setdefault("plugins", {}) - else: - defaults[name] = _model_dict(schema()) - return defaults - - def _filter_persisted(self, tree: dict[str, Any]) -> dict[str, Any]: - sections = get_sections() - filtered: dict[str, Any] = {} - plugins_source = tree.get("plugins", {}) - plugin_filtered: dict[str, Any] = {} - - for name, schema in sections.items(): - if name == "plugins": - continue - if name.startswith("plugins."): - plugin_name = name.split(".", 1)[1] - plugin_data = plugins_source.get(plugin_name, {}) - if not isinstance(plugin_data, dict): - continue - persisted_plugin = _extract_persisted(schema, plugin_data) - if persisted_plugin: - plugin_filtered[plugin_name] = persisted_plugin - continue - - section_data = tree.get(name, {}) - if not isinstance(section_data, dict): - continue - persisted_section = _extract_persisted(schema, section_data) - if persisted_section: - filtered[name] = persisted_section - - if plugin_filtered: - filtered["plugins"] = plugin_filtered - return filtered - - def __getattr__(self, name: str) -> Any: - if name in self._section_models: - return SectionAccessor(self, name) - if name == "plugins": - return self.plugins - raise AttributeError(f"Config has no section '{name}'") - - def __setattr__(self, name: str, value: Any) -> None: - if 
name.startswith("_"): - object.__setattr__(self, name, value) - return - if name in self._section_models: - if isinstance(value, BaseModel): - payload = value.model_dump(exclude_unset=False) - else: - payload = value - self.update_section(name, **payload) - return - object.__setattr__(self, name, value) - - def __str__(self) -> str: - return self.format() - - -def _deep_get(tree: dict[str, Any], path: Iterable[str]) -> Optional[dict[str, Any]]: - node: Any = tree - for segment in path: - if not isinstance(node, dict): - return None - node = node.get(segment) - if node is None: - return None - return node if isinstance(node, dict) else None - - -def _resolve_model_type(annotation: Any) -> Optional[type[BaseModel]]: - """Return the first BaseModel subclass found in an annotation (if any).""" - - if isinstance(annotation, type) and issubclass(annotation, BaseModel): - return annotation - - origin = get_origin(annotation) - if origin is None: - return None - - for arg in get_args(annotation): - nested = _resolve_model_type(arg) - if nested is not None: - return nested - return None - - -def _serialize_value(value: Any) -> Any: - if isinstance(value, BaseModel): - return value.model_dump(exclude_unset=False) - if hasattr(value, "get_secret_value"): - return value.get_secret_value() - return value - - -def _prepare_for_display(value: Any) -> Any: - if isinstance(value, BaseModel): - return { - k: _prepare_for_display(v) for k, v in value.model_dump(exclude_unset=False).items() - } - if isinstance(value, dict): - return {str(k): _prepare_for_display(v) for k, v in value.items()} - if isinstance(value, (list, tuple, set)): - return [_prepare_for_display(v) for v in value] - if isinstance(value, Path): - return str(value) - if isinstance(value, Enum): - return value.value - if hasattr(value, "get_secret_value"): - displayed = getattr(value, "display", None) - if callable(displayed): - return displayed() - return str(value) - return value - - -def _build_config_panel(title: str, data: dict[str, Any]) -> Panel: - tree = Tree(Text(title, style="bold cyan")) - if data: - for key in sorted(data.keys()): - branch = tree.add(Text(key, style="bold magenta")) - branch.add(Pretty(data[key], expand_all=True)) - else: - tree.add(Text("", style="dim")) - return Panel(tree, border_style="cyan", padding=(0, 1)) - - -def _build_section_panel(name: str, data: Any) -> Panel: - tree = Tree(Text(name, style="bold cyan")) - tree.add(Pretty(data, expand_all=True)) - return Panel(tree, border_style="cyan", padding=(0, 1)) - - -def _render_panel(renderable: Panel, *, width: int = 100) -> str: - buffer = StringIO() - console = Console(file=buffer, record=True, force_terminal=True, width=width, color_system=None) - console.print(renderable) - return buffer.getvalue().rstrip() - - -def _model_dict(model: BaseModel) -> dict[str, Any]: - data = model.model_dump(exclude_unset=False) - for key, value in list(data.items()): - if hasattr(value, "get_secret_value"): - data[key] = value.get_secret_value() - return data - - -def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[str, Any]: - persisted: dict[str, Any] = {} - for field_name, field in schema.model_fields.items(): - schema_extra = field.json_schema_extra or {} - annotation = field.annotation - persist = bool(schema_extra.get("persist")) if isinstance(schema_extra, dict) else False - if not persist: - continue - if field_name not in data: - continue - value = data[field_name] - if value is None: - persisted[field_name] = None - continue - - 
nested_type = _resolve_model_type(annotation) - if nested_type is not None: - nested_source = value if isinstance(value, dict) else {} - nested_persisted = _extract_persisted(nested_type, nested_source) - if nested_persisted: - persisted[field_name] = nested_persisted - continue - - if hasattr(value, "get_secret_value"): - persisted[field_name] = value.get_secret_value() - else: - persisted[field_name] = deepcopy(value) - - return persisted +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -__all__ = [ - "ConfigManager", - "PluginsAccessor", - "ProfilesAccessor", - "SectionAccessor", - "normalize_profile_name", -] +from tidy3d._common.config.manager import ( + BUILTIN_PROFILES, + ConfigManager, + PluginsAccessor, + ProfilesAccessor, + SectionAccessor, + _build_config_panel, + _build_section_panel, + _deep_get, + _extract_persisted, + _model_dict, + _prepare_for_display, + _render_panel, + _resolve_model_type, + _serialize_value, + normalize_profile_name, +) diff --git a/tidy3d/config/profiles.py b/tidy3d/config/profiles.py index f73f1be562..a7870a6f1b 100644 --- a/tidy3d/config/profiles.py +++ b/tidy3d/config/profiles.py @@ -1,57 +1,10 @@ -"""Built-in configuration profiles for tidy3d.""" +"""Compatibility shim for :mod:`tidy3d._common.config.profiles`.""" -from __future__ import annotations - -from typing import Any +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -BUILTIN_PROFILES: dict[str, dict[str, Any]] = { - "default": { - "web": { - "api_endpoint": "https://tidy3d-api.simulation.cloud", - "website_endpoint": "https://tidy3d.simulation.cloud", - "s3_region": "us-gov-west-1", - } - }, - "prod": { - "web": { - "api_endpoint": "https://tidy3d-api.simulation.cloud", - "website_endpoint": "https://tidy3d.simulation.cloud", - "s3_region": "us-gov-west-1", - } - }, - "dev": { - "web": { - "api_endpoint": "https://tidy3d-api.dev-simulation.cloud", - "website_endpoint": "https://tidy3d.dev-simulation.cloud", - "s3_region": "us-east-1", - } - }, - "uat": { - "web": { - "api_endpoint": "https://tidy3d-api.uat-simulation.cloud", - "website_endpoint": "https://tidy3d.uat-simulation.cloud", - "s3_region": "us-west-2", - } - }, - "pre": { - "web": { - "api_endpoint": "https://preprod-tidy3d-api.simulation.cloud", - "website_endpoint": "https://preprod-tidy3d.simulation.cloud", - "s3_region": "us-gov-west-1", - } - }, - "nexus": { - "web": { - "api_endpoint": "http://127.0.0.1:5000", - "website_endpoint": "http://127.0.0.1/tidy3d", - "ssl_verify": False, - "enable_caching": False, - "s3_region": "us-east-1", - "env_vars": { - "AWS_ENDPOINT_URL_S3": "http://127.0.0.1:9000", - }, - } - }, -} +# marked as migrated to _common +from __future__ import annotations -__all__ = ["BUILTIN_PROFILES"] +from tidy3d._common.config.profiles import ( + BUILTIN_PROFILES, +) diff --git a/tidy3d/config/registry.py b/tidy3d/config/registry.py index 8ee8e216a5..ad4a9bddee 100644 --- a/tidy3d/config/registry.py +++ b/tidy3d/config/registry.py @@ -1,80 +1,21 @@ -"""Registry utilities for tidy3d configuration sections and handlers.""" - -from __future__ import annotations - -from typing import Callable, Optional, TypeVar - -from pydantic import BaseModel - -T = TypeVar("T", bound=BaseModel) - -_SECTIONS: dict[str, type[BaseModel]] = {} -_HANDLERS: dict[str, Callable[[BaseModel], None]] = {} -_MANAGER: Optional[ConfigManagerProtocol] = None - - -class ConfigManagerProtocol: - """Protocol-like interface for 
manager notifications.""" - - def on_section_registered(self, section: str) -> None: - """Called when a new section schema is registered.""" - - def on_handler_registered(self, section: str) -> None: - """Called when a handler is registered.""" - - -def attach_manager(manager: ConfigManagerProtocol) -> None: - """Attach the active configuration manager for registry callbacks.""" - - global _MANAGER - _MANAGER = manager - - -def get_manager() -> Optional[ConfigManagerProtocol]: - """Return the currently attached configuration manager, if any.""" - - return _MANAGER +"""Compatibility shim for :mod:`tidy3d._common.config.registry`.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -def register_section(name: str) -> Callable[[type[T]], type[T]]: - """Decorator to register a configuration section schema.""" - - def decorator(cls: type[T]) -> type[T]: - _SECTIONS[name] = cls - if _MANAGER is not None: - _MANAGER.on_section_registered(name) - return cls - - return decorator - - -def register_plugin(name: str) -> Callable[[type[T]], type[T]]: - """Decorator to register a plugin configuration schema.""" - - return register_section(f"plugins.{name}") - - -def register_handler( - name: str, -) -> Callable[[Callable[[BaseModel], None]], Callable[[BaseModel], None]]: - """Decorator to register a handler for a configuration section.""" - - def decorator(func: Callable[[BaseModel], None]) -> Callable[[BaseModel], None]: - _HANDLERS[name] = func - if _MANAGER is not None: - _MANAGER.on_handler_registered(name) - return func - - return decorator - - -def get_sections() -> dict[str, type[BaseModel]]: - """Return registered section schemas.""" - - return dict(_SECTIONS) - - -def get_handlers() -> dict[str, Callable[[BaseModel], None]]: - """Return registered configuration handlers.""" +# marked as migrated to _common +from __future__ import annotations - return dict(_HANDLERS) +from tidy3d._common.config.registry import ( + _HANDLERS, + _MANAGER, + _SECTIONS, + ConfigManagerProtocol, + T, + attach_manager, + get_handlers, + get_manager, + get_sections, + register_handler, + register_plugin, + register_section, +) diff --git a/tidy3d/config/sections.py b/tidy3d/config/sections.py index ae70ce8278..a5f912f219 100644 --- a/tidy3d/config/sections.py +++ b/tidy3d/config/sections.py @@ -3,9 +3,8 @@ from __future__ import annotations import os -from os import PathLike from pathlib import Path -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Literal, Optional from urllib.parse import urlparse import numpy as np @@ -16,6 +15,7 @@ Field, NonNegativeFloat, NonNegativeInt, + NonPositiveFloat, PositiveInt, SecretStr, field_serializer, @@ -23,11 +23,21 @@ ) from tidy3d._runtime import WASM_BUILD -from tidy3d.log import DEFAULT_LEVEL, LogLevel, log, set_log_suppression, set_logging_level +from tidy3d.log import ( + DEFAULT_LEVEL, + LogLevel, + log, + set_log_suppression, + set_logging_level, + set_warn_once, +) from .registry import get_manager as _get_attached_manager from .registry import register_handler, register_section +if TYPE_CHECKING: + from os import PathLike + TLS_VERSION_CHOICES = {"TLSv1", "TLSv1_1", "TLSv1_2", "TLSv1_3"} @@ -69,6 +79,12 @@ class LoggingConfig(ConfigSection): description="Suppress repeated log messages when True.", ) + warn_once: bool = Field( + False, + title="Warn once", + description="When True, each unique warning message is only shown once per process.", + ) + @register_handler("logging") def apply_logging(config: 
LoggingConfig) -> None: @@ -76,6 +92,7 @@ def apply_logging(config: LoggingConfig) -> None: set_logging_level(config.level) set_log_suppression(config.suppression) + set_warn_once(config.warn_once) @register_section("simulation") @@ -145,6 +162,25 @@ class AdjointConfig(ConfigSection): ge=0.0, ) + boundary_snapping_fraction: float = Field( + 0.65, + title="Boundary snapping fraction", + description=( + "Fraction of minimum local grid size to use for snapping coordinates outside of " + "a boundary when computing shape gradients. Should be at least 0.5." + ), + ge=0.5, + ) + + pec_detection_threshold: NonPositiveFloat = Field( + -100.0, + title="PEC detection threshold", + description=( + "Value the real permittivity should be below to consider it a PEC material in " + "the shape gradient boundary integration." + ), + ) + local_gradient: bool = Field( False, title="Enable local gradients", @@ -266,7 +302,7 @@ def apply_adjoint(config: AdjointConfig) -> None: defaults = AdjointConfig() overridden = [ name - for name in config.model_fields + for name in type(config).model_fields if name != "local_gradient" and getattr(config, name) != getattr(defaults, name) ] if not overridden: @@ -468,6 +504,7 @@ class LocalCacheConfig(ConfigSection): ) @field_validator("directory", mode="before") + @classmethod def _ensure_directory_exists(cls, v: PathLike) -> Path: """Expand ~, resolve path, and create directory if missing before DirectoryPath validation.""" p = Path(v).expanduser().resolve() diff --git a/tidy3d/config/serializer.py b/tidy3d/config/serializer.py index 78b829c64a..e664881565 100644 --- a/tidy3d/config/serializer.py +++ b/tidy3d/config/serializer.py @@ -1,145 +1,16 @@ -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any, get_args, get_origin - -import tomlkit -from pydantic import BaseModel -from pydantic.fields import FieldInfo -from tomlkit.items import Item, Table - -from .registry import get_sections - -Path = tuple[str, ...] 
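
The `sections.py` hunk above builds on the decorator registry that is being relocated to `_common`: section schemas register under a name, and handlers receive the validated model to apply side effects. A minimal sketch of that flow, with a hypothetical `demo` section (names are illustrative):

```python
# Minimal sketch of the register_section / register_handler pattern from the
# config registry; the "demo" section and its handler are hypothetical.
from typing import Callable

from pydantic import BaseModel, Field

_SECTIONS: dict[str, type[BaseModel]] = {}
_HANDLERS: dict[str, Callable[[BaseModel], None]] = {}


def register_section(name: str):
    def deco(cls):
        _SECTIONS[name] = cls
        return cls

    return deco


def register_handler(name: str):
    def deco(func):
        _HANDLERS[name] = func
        return func

    return deco


@register_section("demo")
class DemoConfig(BaseModel):
    verbosity: int = Field(1, ge=0)


@register_handler("demo")
def apply_demo(config: DemoConfig) -> None:
    # A handler receives the validated section model and applies side effects.
    print(f"verbosity set to {config.verbosity}")


# A manager instantiates each registered schema from the merged config tree
# and passes it to the matching handler:
_HANDLERS["demo"](_SECTIONS["demo"](verbosity=2))
```
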
- - -def collect_descriptions() -> dict[Path, str]: - """Collect description strings for registered configuration fields.""" - - descriptions: dict[Path, str] = {} - for section_name, model in get_sections().items(): - base_path = tuple(segment for segment in section_name.split(".") if segment) - section_doc = (model.__doc__ or "").strip() - if section_doc and base_path: - descriptions[base_path] = descriptions.get( - base_path, section_doc.splitlines()[0].strip() - ) - for field_name, field in model.model_fields.items(): - descriptions.update(_describe_field(field, prefix=(*base_path, field_name))) - return descriptions - - -def _describe_field(field: FieldInfo, prefix: Path) -> dict[Path, str]: - descriptions: dict[Path, str] = {} - description = (field.description or "").strip() - if description: - descriptions[prefix] = description - - nested_models: Iterable[type[BaseModel]] = _iter_model_types(field.annotation) - for model in nested_models: - nested_doc = (model.__doc__ or "").strip() - if nested_doc: - descriptions[prefix] = descriptions.get(prefix, nested_doc.splitlines()[0].strip()) - for sub_name, sub_field in model.model_fields.items(): - descriptions.update(_describe_field(sub_field, prefix=(*prefix, sub_name))) - return descriptions - - -def _iter_model_types(annotation: Any) -> Iterable[type[BaseModel]]: - """Yield BaseModel subclasses referenced by a field annotation (if any).""" +"""Compatibility shim for :mod:`tidy3d._common.config.serializer`.""" - if annotation is None: - return +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - stack = [annotation] - seen: set[type[BaseModel]] = set() - - while stack: - current = stack.pop() - if isinstance(current, type) and issubclass(current, BaseModel): - if current not in seen: - seen.add(current) - yield current - continue - - origin = get_origin(current) - if origin is None: - continue - - stack.extend(get_args(current)) - - -def build_document( - data: dict[str, Any], - existing: tomlkit.TOMLDocument | None, - descriptions: dict[Path, str] | None = None, -) -> tomlkit.TOMLDocument: - """Return a TOML document populated with data and annotated comments.""" - - descriptions = descriptions or collect_descriptions() - document = existing if existing is not None else tomlkit.document() - _prune_missing_keys(document, data.keys()) - for key, value in data.items(): - _apply_value( - container=document, - key=key, - value=value, - path=(key,), - descriptions=descriptions, - is_new=key not in document, - ) - return document - - -def _prune_missing_keys(container: Table | tomlkit.TOMLDocument, keys: Iterable[str]) -> None: - desired = set(keys) - for existing_key in list(container.keys()): - if existing_key not in desired: - del container[existing_key] - - -def _apply_value( - container: Table | tomlkit.TOMLDocument, - key: str, - value: Any, - path: Path, - descriptions: dict[Path, str], - is_new: bool, -) -> None: - description = descriptions.get(path) - if isinstance(value, dict): - existing = container.get(key) - table = existing if isinstance(existing, Table) else tomlkit.table() - _prune_missing_keys(table, value.keys()) - for sub_key, sub_value in value.items(): - _apply_value( - container=table, - key=sub_key, - value=sub_value, - path=(*path, sub_key), - descriptions=descriptions, - is_new=not isinstance(existing, Table) or sub_key not in table, - ) - if key in container: - container[key] = table - else: - if isinstance(container, tomlkit.TOMLDocument) and len(container) > 0: - container.add(tomlkit.nl()) 
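
`build_document`, removed above, writes each registered field's description as a TOML comment via tomlkit so the saved config file is self-documenting. A compact sketch of that annotate-on-write idea; the section data and descriptions below are illustrative only:

```python
# Sketch of the tomlkit-based annotated serialization the removed module
# implemented: each scalar gets its field description as an inline comment.
import tomlkit

data = {"logging": {"level": "WARNING", "suppression": True}}
descriptions = {
    ("logging", "level"): "Lowest-priority level shown on the console.",
    ("logging", "suppression"): "Suppress repeated log messages when True.",
}

doc = tomlkit.document()
for section, values in data.items():
    table = tomlkit.table()
    for key, value in values.items():
        item = tomlkit.item(value)
        desc = descriptions.get((section, key))
        if desc:
            item.comment(desc)  # renders as: key = value  # description
        table.add(key, item)
    doc.add(section, table)

print(tomlkit.dumps(doc))
```
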
- container.add(key, table) - return - - if value is None: - return - - existing_item = container.get(key) - new_item = tomlkit.item(value) - if isinstance(existing_item, Item): - new_item.trivia.comment = existing_item.trivia.comment - new_item.trivia.comment_ws = existing_item.trivia.comment_ws - elif description: - new_item.comment(description) +# marked as migrated to _common +from __future__ import annotations - if key in container: - container[key] = new_item - else: - container.add(key, new_item) +from tidy3d._common.config.serializer import ( + Path, + _apply_value, + _describe_field, + _iter_model_types, + _prune_missing_keys, + build_document, + collect_descriptions, +) diff --git a/tidy3d/constants.py b/tidy3d/constants.py index 81b168cad5..15810fcca5 100644 --- a/tidy3d/constants.py +++ b/tidy3d/constants.py @@ -1,313 +1,65 @@ -"""Defines importable constants. +"""Compatibility shim for :mod:`tidy3d._common.constants`.""" -Attributes: - inf (float): Tidy3d representation of infinity. - C_0 (float): Speed of light in vacuum [um/s] - EPSILON_0 (float): Vacuum permittivity [F/um] - MU_0 (float): Vacuum permeability [H/um] - ETA_0 (float): Vacuum impedance - HBAR (float): reduced Planck constant [eV*s] - Q_e (float): funamental charge [C] -""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common from __future__ import annotations -from types import MappingProxyType - -import numpy as np - -# fundamental constants (https://physics.nist.gov) -C_0 = 2.99792458e14 -""" -Speed of light in vacuum [um/s] -""" - -MU_0 = 1.25663706212e-12 -""" -Vacuum permeability [H/um] -""" - -EPSILON_0 = 1 / (MU_0 * C_0**2) -""" -Vacuum permittivity [F/um] -""" - -#: Free space impedance -ETA_0 = np.sqrt(MU_0 / EPSILON_0) -""" -Vacuum impedance in Ohms -""" - -Q_e = 1.602176634e-19 -""" -Fundamental charge [C] -""" - -HBAR = 6.582119569e-16 -""" -Reduced Planck constant [eV*s] -""" - -K_B = 8.617333262e-5 -""" -Boltzmann constant [eV/K] -""" - -GRAV_ACC = 9.80665 * 1e6 -""" -Gravitational acceleration (g) [um/s^2].", -""" - -M_E_C_SQUARE = 0.51099895069e6 -""" -Electron rest mass energy (m_e * c^2) [eV] -""" - -M_E_EV = M_E_C_SQUARE / C_0**2 -""" -Electron mass [eV*s^2/um^2] -""" - -# floating point precisions -dp_eps = np.finfo(np.float64).eps -""" -Double floating point precision. -""" - -fp_eps = np.float64(np.finfo(np.float32).eps) -""" -Floating point precision. -""" - -# values of PEC for mode solver -pec_val = -1e8 -""" -PEC values for mode solver -""" - -# unit labels -HERTZ = "Hz" -""" -One cycle per second. -""" - -TERAHERTZ = "THz" -""" -One trillion (10^12) cycles per second. -""" - -SECOND = "sec" -""" -SI unit of time. -""" - -PICOSECOND = "ps" -""" -One trillionth (10^-12) of a second. -""" - -METER = "m" -""" -SI unit of length. -""" - -PERMETER = "1/m" -""" -SI unit of inverse length. -""" - -MICROMETER = "um" -""" -One millionth (10^-6) of a meter. -""" - -NANOMETER = "nm" -""" -One billionth (10^-9) of a meter. -""" - -RADIAN = "rad" -""" -SI unit of angle. -""" - -CONDUCTIVITY = "S/um" -""" -Siemens per micrometer. -""" - -PERMITTIVITY = "None (relative permittivity)" -""" -Relative permittivity. -""" - -PML_SIGMA = "2*EPSILON_0/dt" -""" -2 times vacuum permittivity over time differential step. -""" - -RADPERSEC = "rad/sec" -""" -One radian per second. -""" - -RADPERMETER = "rad/m" -""" -One radian per meter. -""" - -NEPERPERMETER = "Np/m" -""" -SI unit for attenuation constant. -""" - - -ELECTRON_VOLT = "eV" -""" -Unit of energy. 
-""" - -KELVIN = "K" -""" -SI unit of temperature. -""" - -CMCUBE = "cm^3" -""" -Cubic centimeter unit of volume. -""" - -PERCMCUBE = "1/cm^3" -""" -Unit per centimeter cube. -""" - -WATT = "W" -""" -SI unit of power. -""" - -VOLT = "V" -""" -SI unit of electric potential. -""" - -PICOSECOND_PER_NANOMETER_PER_KILOMETER = "ps/(nm km)" -""" -Picosecond per (nanometer kilometer). -""" - -OHM = "ohm" -""" -SI unit of resistance. -""" - -FARAD = "farad" -""" -SI unit of capacitance. -""" - -HENRY = "henry" -""" -SI unit of inductance. -""" - -AMP = "A" -""" -SI unit of electric current. -""" - -THERMAL_CONDUCTIVITY = "W/(um*K)" -""" -Watts per (micrometer Kelvin). -""" - -SPECIFIC_HEAT_CAPACITY = "J/(kg*K)" -""" -Joules per (kilogram Kelvin). -""" - -DENSITY = "kg/um^3" -""" -Kilograms per cubic micrometer. -""" - -HEAT_FLUX = "W/um^2" -""" -Watts per square micrometer. -""" - -VOLUMETRIC_HEAT_RATE = "W/um^3" -""" -Watts per cube micrometer. -""" - -HEAT_TRANSFER_COEFF = "W/(um^2*K)" -""" -Watts per (square micrometer Kelvin). -""" - -CURRENT_DENSITY = "A/um^2" -""" -Amperes per square micrometer -""" - -DYNAMIC_VISCOSITY = "kg/(um*s)" -""" -Kilograms per (micrometer second) -""" - -SPECIFIC_HEAT = "um^2/(s^2*K)" -""" -Square micrometers per (square second Kelvin). -""" - -THERMAL_EXPANSIVITY = "1/K" -""" -Inverse Kelvin. -""" - -VELOCITY_SI = "m/s" -""" -SI unit of velocity -""" - -ACCELERATION = "um/s^2" -""" -Acceleration unit. -""" - -LARGE_NUMBER = 1e10 -""" -Large number used for comparing infinity. -""" - -LARGEST_FP_NUMBER = 1e38 -""" -Largest number used for single precision floating point number. -""" - -inf = np.inf -""" -Representation of infinity used within tidy3d. -""" - -# if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning -GLANCING_CUTOFF = 0.1 -""" -if |np.pi/2 - angle_theta| < GLANCING_CUTOFF in an angled source or in mode spec, raise warning. -""" - -UnitScaling = MappingProxyType( - { - "nm": 1e3, - "μm": 1e0, - "um": 1e0, - "mm": 1e-3, - "cm": 1e-4, - "m": 1e-6, - "mil": 1.0 / 25.4, - "in": 1.0 / 25400, - } +from tidy3d._common.constants import ( + ACCELERATION, + AMP, + C_0, + CMCUBE, + CONDUCTIVITY, + CURRENT_DENSITY, + DENSITY, + DYNAMIC_VISCOSITY, + ELECTRON_VOLT, + EPSILON_0, + ETA_0, + FARAD, + GLANCING_CUTOFF, + GRAV_ACC, + HBAR, + HEAT_FLUX, + HEAT_TRANSFER_COEFF, + HENRY, + HERTZ, + K_B, + KELVIN, + LARGE_NUMBER, + LARGEST_FP_NUMBER, + M_E_C_SQUARE, + M_E_EV, + METER, + MICROMETER, + MU_0, + NANOMETER, + NEPERPERMETER, + OHM, + PERCMCUBE, + PERMETER, + PERMITTIVITY, + PICOSECOND, + PICOSECOND_PER_NANOMETER_PER_KILOMETER, + PML_SIGMA, + RADIAN, + RADPERMETER, + RADPERSEC, + SECOND, + SPECIFIC_HEAT, + SPECIFIC_HEAT_CAPACITY, + TERAHERTZ, + THERMAL_CONDUCTIVITY, + THERMAL_EXPANSIVITY, + VELOCITY_SI, + VOLT, + VOLUMETRIC_HEAT_RATE, + WATT, + Q_e, + UnitScaling, + dp_eps, + fp_eps, + inf, + pec_val, ) -"""Immutable dictionary for converting microns to another spatial unit, eg. 
nm = um * UnitScaling["nm"].""" diff --git a/tidy3d/exceptions.py b/tidy3d/exceptions.py index 901b803a1a..4a6ff4822c 100644 --- a/tidy3d/exceptions.py +++ b/tidy3d/exceptions.py @@ -1,61 +1,21 @@ -"""Custom Tidy3D exceptions""" +"""Compatibility shim for :mod:`tidy3d._common.exceptions`.""" -from __future__ import annotations - -from typing import Optional - -from .log import log - - -class Tidy3dError(ValueError): - """Any error in tidy3d""" - - def __init__(self, message: Optional[str] = None, log_error: bool = True) -> None: - """Log just the error message and then raise the Exception.""" - super().__init__(message) - if log_error: - log.error(message) - - -class ConfigError(Tidy3dError): - """Error when configuring Tidy3d.""" - - -class Tidy3dKeyError(Tidy3dError): - """Could not find a key in a Tidy3d dictionary.""" - - -class ValidationError(Tidy3dError): - """Error when constructing Tidy3d components.""" - - -class SetupError(Tidy3dError): - """Error regarding the setup of the components (outside of domains, etc).""" - - -class FileError(Tidy3dError): - """Error reading or writing to file.""" - - -class WebError(Tidy3dError): - """Error with the webAPI.""" - - -class AuthenticationError(Tidy3dError): - """Error authenticating a user through webapi webAPI.""" - - -class DataError(Tidy3dError): - """Error accessing data.""" - - -class Tidy3dImportError(Tidy3dError): - """Error importing a package needed for tidy3d.""" - - -class Tidy3dNotImplementedError(Tidy3dError): - """Error when a functionality is not (yet) supported.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -class AdjointError(Tidy3dError): - """An error in setting up the adjoint solver.""" +from tidy3d._common.exceptions import ( + AdjointError, + AuthenticationError, + ConfigError, + DataError, + FileError, + SetupError, + Tidy3dError, + Tidy3dImportError, + Tidy3dKeyError, + Tidy3dNotImplementedError, + ValidationError, + WebError, +) diff --git a/tidy3d/log.py b/tidy3d/log.py index 20f08682f1..ca4a776d3c 100644 --- a/tidy3d/log.py +++ b/tidy3d/log.py @@ -1,478 +1,30 @@ -"""Logging Configuration for Tidy3d.""" +"""Compatibility shim for :mod:`tidy3d._common.log`.""" -from __future__ import annotations - -import inspect -from contextlib import contextmanager -from datetime import datetime -from os import PathLike -from typing import Any, Callable, Optional, Union - -from rich.console import Console -from rich.text import Text -from typing_extensions import Literal - -# Note: "SUPPORT" and "USER" levels are meant for backend runs only. -# Logging in frontend code should just use the standard debug/info/warning/error/critical. -LogLevel = Literal["DEBUG", "SUPPORT", "USER", "INFO", "WARNING", "ERROR", "CRITICAL"] -LogValue = Union[int, LogLevel] - -# Logging levels compatible with logging module -_level_value = { - "DEBUG": 10, - "SUPPORT": 12, - "USER": 15, - "INFO": 20, - "WARNING": 30, - "ERROR": 40, - "CRITICAL": 50, -} - -_level_name = {v: k for k, v in _level_value.items()} - -DEFAULT_LEVEL = "WARNING" - -DEFAULT_LOG_STYLES = { - "DEBUG": None, - "SUPPORT": None, - "USER": None, - "INFO": None, - "WARNING": "red", - "ERROR": "red bold", - "CRITICAL": "red bold", -} - -# Width of the console used for rich logging (in characters). 
-CONSOLE_WIDTH = 80 - - -def _default_log_level_format(level: str, message: str) -> tuple[str, str]: - """By default just return unformatted prefix and message.""" - return level, message - - -def _get_level_int(level: LogValue) -> int: - """Get the integer corresponding to the level string.""" - if isinstance(level, int): - return level - - if level not in _level_value: - # We don't want to import ConfigError to avoid a circular dependency - raise ValueError( - f"logging level {level} not supported, must be " - "'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', or 'CRITICAL'" - ) - return _level_value[level] - - -class LogHandler: - """Handle log messages depending on log level""" - - def __init__( - self, - console: Console, - level: LogValue, - log_level_format: Callable = _default_log_level_format, - prefix_every_line: bool = False, - ) -> None: - self.level = _get_level_int(level) - self.console = console - self.log_level_format = log_level_format - self.prefix_every_line = prefix_every_line - - def handle(self, level, level_name, message) -> None: - """Output log messages depending on log level""" - if level >= self.level: - stack = inspect.stack() - console = self.console - offset = 4 - if stack[offset - 1].filename.endswith("exceptions.py"): - # We want the calling site for exceptions.py - offset += 1 - prefix, msg = self.log_level_format(level_name, message) - if self.prefix_every_line: - wrapped_text = Text(msg, style="default") - msgs = wrapped_text.wrap(console=console, width=console.width - len(prefix) - 2) - else: - msgs = [msg] - for msg in msgs: - console.log( - prefix, - msg, - sep=": ", - style=DEFAULT_LOG_STYLES[level_name], - _stack_offset=offset, - ) - - -class Logger: - """Custom logger to avoid the complexities of the logging module - - The logger can be used in a context manager to avoid the emission of multiple messages. In this - case, the first message in the context is emitted normally, but any others are discarded. When - the context is exited, the number of discarded messages of each level is displayed with the - highest level of the captures messages. - - Messages can also be captured for post-processing. That can be enabled through 'set_capture' to - record all warnings emitted during model validation. A structured copy of all validation - messages can then be recovered through 'captured_warnings'. 
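
The docstring above describes the consolidation context: the first message inside the context is emitted normally, the rest are counted and summarized on exit. A simplified stand-in showing that emit-first, count-the-rest behaviour (not the actual `Logger` class being removed here):

```python
# Simplified stand-in for the consolidation context described above:
# first message is emitted, later ones are counted, then summarized on exit.
class ConsolidatingLog:
    def __init__(self):
        self._counts = None

    def __enter__(self):
        self._counts = {}
        return self

    def __exit__(self, *exc):
        suppressed = sum(self._counts.values())
        self._counts = None
        if suppressed:
            print(f"Suppressed {suppressed} message(s).")
        return False

    def warning(self, msg: str) -> None:
        if self._counts is not None:
            if self._counts:  # a message was already emitted in this context
                self._counts["WARNING"] = self._counts.get("WARNING", 0) + 1
                return
            self._counts["WARNING"] = 0
        print(f"WARNING: {msg}")


log = ConsolidatingLog()
with log:
    for i in range(3):
        log.warning(f"repeated message {i}")
# -> emits the first warning, then "Suppressed 2 message(s)."
```
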
- """ - - _static_cache = set() - - def __init__(self) -> None: - self.handlers = {} - self.suppression = True - self._counts = None - self._stack = None - self._capture = False - self._captured_warnings = [] - - def set_capture(self, capture: bool) -> None: - """Turn on/off tree-like capturing of log messages.""" - self._capture = capture - - def captured_warnings(self): - """Get the formatted list of captured log messages.""" - captured_warnings = self._captured_warnings - self._captured_warnings = [] - return captured_warnings - - def __enter__(self): - """If suppression is enabled, enter a consolidation context (only a single message is - emitted).""" - if self.suppression and self._counts is None: - self._counts = {} - return self - - def __exit__(self, exc_type, exc_value, traceback): - """Exist a consolidation context (report the number of messages discarded).""" - if self._counts is not None: - total = sum(v for v in self._counts.values()) - if total > 0: - max_level = max(k for k, v in self._counts.items() if v > 0) - counts = [f"{v} {_level_name[k]}" for k, v in self._counts.items() if v > 0] - self._counts = None - if total > 0: - noun = " messages." if total > 1 else " message." - # Temporarily prevent capturing messages to emit consolidated summary - stack = self._stack - self._stack = None - self.log(max_level, "Suppressed " + ", ".join(counts) + noun) - self._stack = stack - return False - - def begin_capture(self) -> None: - """Start capturing log stack for consolidated validation log. - - This method is used before any model validation starts and is included in the initialization - of 'BaseModel'. It must be followed by a corresponding 'end_capture'. - """ - if not self._capture: - return - - stack_item = {"messages": [], "children": {}} - if self._stack: - self._stack.append(stack_item) - else: - self._stack = [stack_item] - - def end_capture(self, model) -> None: - """End capturing log stack for consolidated validation log. - - This method is used after all model validations and is included in the initialization of - 'BaseModel'. It must follow a corresponding 'begin_capture'. - """ - if not self._stack: - return - - stack_item = self._stack.pop() - if len(self._stack) == 0: - self._stack = None - - # Check if this stack item contains any messages or children - if len(stack_item["messages"]) > 0 or len(stack_item["children"]) > 0: - stack_item["type"] = model.__class__.__name__ - - # Set the path for each children - model_fields = model.get_submodels_by_hash() - for child_hash, child_dict in stack_item["children"].items(): - child_dict["parent_fields"] = model_fields.get(child_hash, []) - - # Are we at the bottom of the stack? 
- if self._stack is None: - # Yes, we're root - self._parse_warning_capture(current_loc=[], stack_item=stack_item) - else: - # No, we're someone else's child - hash_ = hash(model) - self._stack[-1]["children"][hash_] = stack_item - - def _parse_warning_capture(self, current_loc, stack_item) -> None: - """Process capture tree to compile formatted captured warnings.""" - - if "parent_fields" in stack_item: - for field in stack_item["parent_fields"]: - if isinstance(field, tuple): - # array field - new_loc = current_loc + list(field) - else: - # single field - new_loc = [*current_loc, field] - - # process current level warnings - for level, msg, custom_loc in stack_item["messages"]: - if level == "WARNING": - self._captured_warnings.append({"loc": new_loc + custom_loc, "msg": msg}) - - # initialize processing at children level - for child_stack in stack_item["children"].values(): - self._parse_warning_capture(current_loc=new_loc, stack_item=child_stack) - - else: # for root object - # process current level warnings - for level, msg, custom_loc in stack_item["messages"]: - if level == "WARNING": - self._captured_warnings.append({"loc": current_loc + custom_loc, "msg": msg}) - - # initialize processing at children level - for child_stack in stack_item["children"].values(): - self._parse_warning_capture(current_loc=current_loc, stack_item=child_stack) - - def _log( - self, - level: int, - level_name: str, - message: str, - *args: Any, - log_once: bool = False, - custom_loc: Optional[list] = None, - capture: bool = True, - ) -> None: - """Distribute log messages to all handlers""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility - # Check global cache if requested (before composing/capturing to avoid duplicates) - if log_once: - # Use the message body before composition as key - if message in self._static_cache: - return - self._static_cache.add(message) - - # Compose message - if len(args) > 0: - try: - composed_message = str(message) % args - - except Exception as e: - composed_message = f"{message} % {args}\n{e}" - else: - composed_message = str(message) - - # Capture all messages (even if suppressed later) - if self._stack and capture: - if custom_loc is None: - custom_loc = [] - self._stack[-1]["messages"].append((level_name, composed_message, custom_loc)) - - # Context-local logger emits a single message and consolidates the rest - if self._counts is not None: - if len(self._counts) > 0: - self._counts[level] = 1 + self._counts.get(level, 0) - return - self._counts[level] = 0 - - # Forward message to handlers - for handler in self.handlers.values(): - handler.handle(level, level_name, composed_message) - - def log(self, level: LogValue, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) with given level""" - if isinstance(level, str): - level_name = level - level = _get_level_int(level) - else: - level_name = _level_name.get(level, "unknown") - self._log(level, level_name, message, *args, log_once=log_once) - - def debug(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at debug level""" - self._log(_level_value["DEBUG"], "DEBUG", message, *args, log_once=log_once) - - def support(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at support level""" - self._log(_level_value["SUPPORT"], "SUPPORT", message, *args, log_once=log_once) - - def user(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at user 
level""" - self._log(_level_value["USER"], "USER", message, *args, log_once=log_once) - - def info(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at info level""" - self._log(_level_value["INFO"], "INFO", message, *args, log_once=log_once) - - def warning( - self, - message: str, - *args: Any, - log_once: bool = False, - custom_loc: Optional[list] = None, - capture: bool = True, - ) -> None: - """Log (message) % (args) at warning level""" - self._log( - _level_value["WARNING"], - "WARNING", - message, - *args, - log_once=log_once, - custom_loc=custom_loc, - capture=capture, - ) - - def error(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at error level""" - self._log(_level_value["ERROR"], "ERROR", message, *args, log_once=log_once) - - def critical(self, message: str, *args: Any, log_once: bool = False) -> None: - """Log (message) % (args) at critical level""" - self._log(_level_value["CRITICAL"], "CRITICAL", message, *args, log_once=log_once) - - -def set_logging_level(level: LogValue = DEFAULT_LEVEL) -> None: - """Set tidy3d console logging level priority. - - Parameters - ---------- - level : str - The lowest priority level of logging messages to display. One of ``{'DEBUG', 'SUPPORT', - 'USER', INFO', 'WARNING', 'ERROR', 'CRITICAL'}`` (listed in increasing priority). - """ - if "console" in log.handlers: - log.handlers["console"].level = _get_level_int(level) - - -def set_log_suppression(value: bool) -> None: - """Control log suppression for repeated messages.""" - log.suppression = value - - -def get_aware_datetime() -> datetime: - """Get an aware current local datetime(with local timezone info)""" - return datetime.now().astimezone() - - -def set_logging_console(stderr: bool = False) -> None: - """Set stdout or stderr as console output - - Parameters - ---------- - stderr : bool - If False, logs are directed to stdout, otherwise to stderr. - """ - if "console" in log.handlers: - previous_level = log.handlers["console"].level - else: - previous_level = DEFAULT_LEVEL - log.handlers["console"] = LogHandler( - Console( - stderr=stderr, - width=CONSOLE_WIDTH, - log_path=False, - get_datetime=get_aware_datetime, - log_time_format="%X %Z", - ), - previous_level, - ) - - -def set_logging_file( - fname: PathLike, - filemode: str = "w", - level: LogValue = DEFAULT_LEVEL, - log_path: bool = False, -) -> None: - """Set a file to write log to, independently from the stdout and stderr - output chosen using :meth:`set_logging_level`. - - Parameters - ---------- - fname : PathLike - Path to file to direct the output to. If empty string, a previously set logging file will - be closed, if any, but nothing else happens. - filemode : str - 'w' or 'a', defining if the file should be overwritten or appended. - level : str - One of ``{'DEBUG', 'SUPPORT', 'USER', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}``. This is set - for the file independently of the console output level set by :meth:`set_logging_level`. - log_path : bool = False - Whether to log the path to the file that issued the message. 
- """ - if filemode not in "wa": - raise ValueError("filemode must be either 'w' or 'a'") - - # Close previous handler, if any - if "file" in log.handlers: - try: - log.handlers["file"].console.file.close() - except Exception: # TODO: catch specific exception - log.warning("Log file could not be closed") - finally: - del log.handlers["file"] - - if str(fname) == "": - # Empty string can be passed to just stop previously opened file handler - return - - try: - file = open(fname, filemode) - except Exception: # TODO: catch specific exception - log.error(f"File {fname} could not be opened") - return - - log.handlers["file"] = LogHandler( - Console(file=file, force_jupyter=False, log_path=log_path), level - ) - - -# Initialize Tidy3d's logger -log = Logger() - -# Set default logging output -set_logging_console() - - -def get_logging_console() -> Console: - """Get console from logging handlers.""" - if "console" not in log.handlers: - set_logging_console() - return log.handlers["console"].console - - -class NoOpProgress: - """Dummy progress manager that doesn't show any output.""" - - def __enter__(self): - return self - - def __exit__(self, *args: Any, **kwargs: Any) -> None: - pass - - def add_task(self, *args: Any, **kwargs: Any) -> None: - pass - - def update(self, *args: Any, **kwargs: Any) -> None: - pass - - -@contextmanager -def Progress(console, show_progress): - """Progress manager that wraps ``rich.Progress`` if ``show_progress`` is ``True``, - and ``NoOpProgress`` otherwise.""" - if show_progress: - from rich.progress import Progress +# marked as migrated to _common +from __future__ import annotations - with Progress(console=console) as progress: - yield progress - else: - with NoOpProgress() as progress: - yield progress +from tidy3d._common.log import ( + CONSOLE_WIDTH, + DEFAULT_LEVEL, + DEFAULT_LOG_STYLES, + Logger, + LogHandler, + LogLevel, + LogValue, + NoOpProgress, + Progress, + _default_log_level_format, + _get_level_int, + _level_name, + _level_value, + get_aware_datetime, + get_logging_console, + log, + set_log_suppression, + set_logging_console, + set_logging_file, + set_logging_level, + set_warn_once, +) diff --git a/tidy3d/material_library/material_library.py b/tidy3d/material_library/material_library.py index 638538e95b..73481e811a 100644 --- a/tidy3d/material_library/material_library.py +++ b/tidy3d/material_library/material_library.py @@ -3,17 +3,19 @@ from __future__ import annotations import json -from os import PathLike -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Optional, Union -import pydantic.v1 as pd -from rich.panel import Panel -from rich.table import Table +from pydantic import Field, model_validator from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.material.multi_physics import MultiPhysicsMedium from tidy3d.components.material.tcad.charge import SemiconductorMedium -from tidy3d.components.medium import AnisotropicMedium, Medium2D, PoleResidue, Sellmeier +from tidy3d.components.medium import ( + AnisotropicMedium, + Medium2D, + PoleResidue, + Sellmeier, +) from tidy3d.components.tcad.bandgap_energy import ConstantEnergyBandGap from tidy3d.components.tcad.types import ( AugerRecombination, @@ -23,7 +25,6 @@ ShockleyReedHallRecombination, SlotboomBandGapNarrowing, ) -from tidy3d.components.types import Axis from tidy3d.exceptions import SetupError from tidy3d.log import log @@ -40,7 +41,14 @@ ) if TYPE_CHECKING: + from os import PathLike + from IPython.lib.pretty import RepresentationPrinter + 
from rich.panel import Panel + from rich.table import Table + + from tidy3d.compat import Self + from tidy3d.components.types import Axis def export_matlib_to_file(fname: PathLike = "matlib.json") -> None: @@ -75,13 +83,13 @@ def export_matlib_to_file(fname: PathLike = "matlib.json") -> None: class AbstractVariantItem(Tidy3dBaseModel): """Reference, and data_source for a variant of a material.""" - reference: list[ReferenceData] = pd.Field( + reference: Optional[list[ReferenceData]] = Field( None, title="Reference information", description="A list of references related to this variant model.", ) - data_url: str = pd.Field( + data_url: Optional[str] = Field( None, title="Dispersion data URL", description="The URL to access the dispersion data upon which the material " @@ -105,8 +113,7 @@ def _repr_pretty_(self, p: RepresentationPrinter, cycle: bool) -> None: class VariantItem(AbstractVariantItem): """Reference, data_source, and material model for a variant of a material.""" - medium: Union[PoleResidue, MultiPhysicsMedium] = pd.Field( - ..., + medium: Union[PoleResidue, MultiPhysicsMedium] = Field( title="Material dispersion model", description="A dispersive medium described by the pole-residue pair model.", ) @@ -119,26 +126,23 @@ def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysi class MaterialItem(Tidy3dBaseModel): """A material that includes several variants.""" - name: str = pd.Field(..., title="Name", description="Unique name for the medium.") - variants: dict[str, VariantItem] = pd.Field( - ..., + name: str = Field(title="Name", description="Unique name for the medium.") + variants: dict[str, VariantItem] = Field( title="Dictionary of available variants for this material", description="A dictionary of available variants for this material " "that maps from a key to the variant model.", ) - default: str = pd.Field( - ..., title="default variant", description="The default type of variant." - ) + default: str = Field(title="default variant", description="The default type of variant.") - @pd.validator("default", always=True) - def _default_in_variants(cls, val: str, values: dict[str, Any]) -> Any: + @model_validator(mode="after") + def _default_in_variants(self: Self) -> Self: """Make sure the default variant is already included in the ``variants``.""" - if val not in values["variants"]: + if self.default not in self.variants: raise SetupError( - f"The data of the default variant '{val}' is not supplied; " + f"The data of the default variant '{self.default}' is not supplied; " "please include it in the 'variants'." 
) - return val + return self def __getitem__(self, variant_name: str) -> Union[PoleResidue, MultiPhysicsMedium]: """Helper function to easily access the medium of a variant""" @@ -167,8 +171,7 @@ def _repr_pretty_(self, p: RepresentationPrinter, cycle: bool) -> None: class VariantItem2D(AbstractVariantItem): """Reference, data_source, and material model for a variant of a 2D material.""" - medium: Medium2D = pd.Field( - ..., + medium: Medium2D = Field( title="Material dispersion model", description="A dispersive 2D medium described by a surface conductivity model, " "which is handled as an anisotropic medium with pole-residue pair models " @@ -183,8 +186,7 @@ def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysi class MaterialItem2D(MaterialItem): """A 2D material that includes several variants.""" - variants: dict[str, VariantItem2D] = pd.Field( - ..., + variants: dict[str, VariantItem2D] = Field( title="Dictionary of available variants for this material", description="A dictionary of available variants for this material " "that maps from a key to the variant model.", @@ -194,14 +196,12 @@ class MaterialItem2D(MaterialItem): class VariantItemUniaxial(AbstractVariantItem): """Reference, data_source, and material model for a variant of an uniaxial material.""" - ordinary: PoleResidue = pd.Field( - ..., + ordinary: PoleResidue = Field( title="Ordinary Component", description="Medium describing the ordinary component.", ) - extraordinary: PoleResidue = pd.Field( - ..., + extraordinary: PoleResidue = Field( title="Extraordinary Component", description="Medium describing the extraordinary component.", ) @@ -224,7 +224,7 @@ def medium(self, optical_axis: Axis) -> AnisotropicMedium: components = ["xx", "yy", "zz"] mat_dict = dict.fromkeys(components, self.ordinary) mat_dict.update({components[optical_axis]: self.extraordinary}) - return AnisotropicMedium.parse_obj(mat_dict) + return AnisotropicMedium.model_validate(mat_dict) @property def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysicsMedium]]: @@ -234,8 +234,7 @@ def summarize_mediums(self) -> dict[str, Union[PoleResidue, Medium2D, MultiPhysi class MaterialItemUniaxial(MaterialItem): """A material that includes several variants.""" - variants: dict[str, VariantItemUniaxial] = pd.Field( - ..., + variants: dict[str, VariantItemUniaxial] = Field( title="Dictionary of available variants for this material", description="A dictionary of available variants for this material " "that maps from a key to the variant model.", diff --git a/tidy3d/material_library/material_reference.py b/tidy3d/material_library/material_reference.py index 53cfa89eca..985433e329 100644 --- a/tidy3d/material_library/material_reference.py +++ b/tidy3d/material_library/material_reference.py @@ -2,7 +2,9 @@ from __future__ import annotations -import pydantic.v1 as pd +from typing import Optional + +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel @@ -10,23 +12,23 @@ class ReferenceData(Tidy3dBaseModel): """Reference data.""" - doi: str = pd.Field(None, title="DOI", description="DOI of the reference.") - journal: str = pd.Field( + doi: Optional[str] = Field(None, title="DOI", description="DOI of the reference.") + journal: Optional[str] = Field( None, title="Journal publication info", description="Publication info in the order of author, title, journal volume, and year.", ) - url: str = pd.Field( + url: Optional[str] = Field( None, title="URL link", description="Some reference can be accessed 
through a url link to its pdf etc.", ) - manufacturer: str = pd.Field( + manufacturer: Optional[str] = Field( None, title="Manufacturer", description="Name of the manufacturer, e.g., Rogers, Arlon.", ) - datasheet_title: str = pd.Field( + datasheet_title: Optional[str] = Field( None, title="Datasheet Title", description="Title of the datasheet.", diff --git a/tidy3d/material_library/parametric_materials.py b/tidy3d/material_library/parametric_materials.py index 179ec7a660..630c74717f 100644 --- a/tidy3d/material_library/parametric_materials.py +++ b/tidy3d/material_library/parametric_materials.py @@ -4,9 +4,10 @@ import warnings from abc import ABC, abstractmethod +from typing import Optional import numpy as np -import pydantic.v1 as pd +from pydantic import Field, NonNegativeInt from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.medium import Drude, Medium2D, PoleResidue @@ -68,35 +69,35 @@ class Graphene(ParametricVariantItem2D): """ - mu_c: float = pd.Field( + mu_c: float = Field( GRAPHENE_DEF_MU_C, title="Chemical potential in eV", description="Chemical potential in eV.", units=ELECTRON_VOLT, ) - temp: float = pd.Field( + temp: float = Field( GRAPHENE_DEF_TEMP, title="Temperature in K", description="Temperature in K.", units=KELVIN ) - gamma: float = pd.Field( + gamma: float = Field( GRAPHENE_DEF_GAMMA, title="Scattering rate in eV", description="Scattering rate in eV. Must be small compared to the optical frequency.", units=ELECTRON_VOLT, ) - scaling: float = pd.Field( + scaling: float = Field( 1, title="Scaling factor", description="Scaling factor used to model multiple layers of graphene.", ) - include_interband: bool = pd.Field( + include_interband: bool = Field( True, title="Include interband terms", description="Include interband terms, relevant at high frequency (IR). " "Otherwise, the intraband terms only give a simpler Drude-type model relevant " "only at low frequency (THz).", ) - interband_fit_freq_nodes: list[tuple[float, float]] = pd.Field( + interband_fit_freq_nodes: Optional[list[tuple[float, float]]] = Field( None, title="Interband fitting frequency nodes", description="Frequency nodes for fitting interband term. " @@ -108,7 +109,7 @@ class Graphene(ParametricVariantItem2D): "of frequencies; consider changing the nodes to obtain a better fit for a " "narrow-band simulation.", ) - interband_fit_num_iters: pd.NonNegativeInt = pd.Field( + interband_fit_num_iters: NonNegativeInt = Field( GRAPHENE_FIT_NUM_ITERS, title="Interband fitting number of iterations", description="Number of iterations for optimizing each Pade approximant when " @@ -201,12 +202,12 @@ def numerical_conductivity(self, freqs: list[float]) -> list[complex]: Parameters ---------- - freqs : List[float] + freqs : list[float] The list of frequencies. Returns ------- - List[complex] + list[complex] The list of corresponding conductivities, in S. """ intra_sigma = self.intraband_drude.sigma_model(freqs) @@ -218,20 +219,22 @@ def interband_conductivity(self, freqs: list[float]) -> list[complex]: Parameters ---------- - freqs : List[float] + freqs : list[float] The list of frequencies. Returns ------- - List[complex] + list[complex] The list of corresponding interband conductivities, in S. """ try: from scipy import integrate - - INTEGRATE_AVAILABLE = True except ImportError: - INTEGRATE_AVAILABLE = False + raise ImportError( + "The package 'scipy' was not found. Please install the 'core' " + "dependencies to calculate the interband term of graphene. 
For example: " + "pip install tidy3d" + ) from None def fermi(E: float) -> float: """Fermi distribution.""" @@ -248,13 +251,6 @@ def integrand(E: float, omega: float) -> float: """Integrand for interband term.""" return (fermi_g(E * HBAR) - fermi_g(HBAR * omega / 2)) / (omega**2 - 4 * E**2) - if not INTEGRATE_AVAILABLE: - raise ImportError( - "The package 'scipy' was not found. Please install the 'core' " - "dependencies to calculate the interband term of graphene. For example: " - "pip install tidy3d" - ) - omegas = 2 * np.pi * np.array(freqs) sigma = np.zeros(len(omegas), dtype=complex) integration_min = GRAPHENE_INT_MIN @@ -283,11 +279,11 @@ def _fit_interband_conductivity( Parameters ---------- - freqs : List[float] + freqs : list[float] The input frequencies. - sigma : List[complex] + sigma : list[complex] The interband conductivity to fit. - indslist : List[Tuple[int, int]] + indslist : list[tuple[int, int]] The indices at which to sample the data for fitting. The length of this list determines the number of Pade terms used. Returns diff --git a/tidy3d/material_library/util.py b/tidy3d/material_library/util.py index 46083676c8..2da7a43ef9 100644 --- a/tidy3d/material_library/util.py +++ b/tidy3d/material_library/util.py @@ -1,7 +1,7 @@ from __future__ import annotations from io import StringIO -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from rich.console import Console from rich.panel import Panel @@ -9,12 +9,14 @@ from rich.text import Text from rich.tree import Tree -from tidy3d import Medium2D, MultiPhysicsMedium, PoleResidue from tidy3d.components.viz import FLEXCOMPUTE_COLORS if TYPE_CHECKING: + from typing import Union + from IPython.lib.pretty import RepresentationPrinter + from tidy3d import Medium2D, MultiPhysicsMedium, PoleResidue from tidy3d.material_library.material_library import ( AbstractVariantItem, MaterialItem, diff --git a/tidy3d/packaging.py b/tidy3d/packaging.py index 370b1caa1d..5ef16c3c1e 100644 --- a/tidy3d/packaging.py +++ b/tidy3d/packaging.py @@ -1,289 +1,36 @@ -""" -This file contains a set of functions relating to packaging tidy3d for distribution. Sections of the codebase should depend on this file, but this file should not depend on any other part of the codebase. +"""Compatibility shim for :mod:`tidy3d._common.packaging`.""" -This section should only depend on the standard core installation in the pyproject.toml, and should not depend on any other part of the codebase optional imports. -""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as partially migrated to _common from __future__ import annotations import functools -from importlib import import_module -from importlib.util import find_spec -from typing import Any, Literal - -import numpy as np - -from tidy3d.config import config - -from .exceptions import Tidy3dImportError -from .version import __version__ - -vtk = { - "mod": None, - "id_type": np.int64, - "vtk_to_numpy": None, - "numpy_to_vtkIdTypeArray": None, - "numpy_to_vtk": None, -} - -tidy3d_extras = {"mod": None, "use_local_subpixel": None} - - -def check_import(module_name: str) -> bool: - """ - Check if a module or submodule section has been imported. This is a functional way of loading packages that will still load the corresponding module into the total space. - - Parameters - ---------- - module_name - - Returns - ------- - bool - True if the module has been imported, False otherwise. 
- - """ - try: - import_module(module_name) - return True - except ImportError: - return False - - -def verify_packages_import(modules: list, required: Literal["any", "all"] = "all"): - def decorator(func): - """ - When decorating a method, requires that the specified modules are available. It will raise an error if the - module is not available depending on the value of the 'required' parameter which represents the type of - import required. - - There are a few options to choose for the 'required' parameter: - - 'all': All the modules must be available for the operation to continue without raising an error - - 'any': At least one of the modules must be available for the operation to continue without raising an error - - Parameters - ---------- - func - The function to decorate. - - Returns - ------- - checks_modules_import - The decorated function. - - """ - - @functools.wraps(func) - def checks_modules_import(*args: Any, **kwargs: Any): - """ - Checks if the modules are available. If they are not available, it will raise an error depending on the value. - """ - available_modules_status = [] - maximum_amount_modules = len(modules) - - module_id_i = 0 - for module in modules: - # Starts counting from one so that it can be compared to len(modules) - module_id_i += 1 - import_available = check_import(module) - available_modules_status.append( - import_available - ) # Stores the status of the module import - - if not import_available: - if required == "all": - raise Tidy3dImportError( - f"The package '{module}' is required for this operation, but it was not found. " - f"Please install the '{module}' dependencies using, for example, " - f"'pip install tidy3d[]" - ) - if required == "any": - # Means we need to verify that at least one of the modules is available - if ( - not any(available_modules_status) - ) and module_id_i == maximum_amount_modules: - # Means that we have reached the last module and none of them were available - raise Tidy3dImportError( - f"The package '{module}' is required for this operation, but it was not found. " - f"Please install the '{module}' dependencies using, for example, " - f"'pip install tidy3d[]" - ) - else: - raise ValueError( - f"The value '{required}' is not a valid value for the 'required' parameter. " - f"Please use any 'all' or 'any'." - ) - else: - # Means that the module is available, so we can just continue with the operation - pass - return func(*args, **kwargs) - - return checks_modules_import - - return decorator - - -def requires_vtk(fn): - """When decorating a method, requires that vtk is available.""" - - @functools.wraps(fn) - def _fn(*args: Any, **kwargs: Any): - if vtk["mod"] is None: - try: - import vtk as vtk_mod - from vtk.util.numpy_support import ( - numpy_to_vtk, - numpy_to_vtkIdTypeArray, - vtk_to_numpy, - ) - from vtkmodules.vtkCommonCore import vtkLogger - - vtk["mod"] = vtk_mod - vtk["vtk_to_numpy"] = vtk_to_numpy - vtk["numpy_to_vtkIdTypeArray"] = numpy_to_vtkIdTypeArray - vtk["numpy_to_vtk"] = numpy_to_vtk - - vtkLogger.SetStderrVerbosity(vtkLogger.VERBOSITY_WARNING) - - if vtk["mod"].vtkIdTypeArray().GetDataTypeSize() == 4: - vtk["id_type"] = np.int32 - - except ImportError as exc: - raise Tidy3dImportError( - "The package 'vtk' is required for this operation, but it was not found. " - "Please install the 'vtk' dependencies using, for example, " - "'pip install .[vtk]'." 
- ) from exc - - return fn(*args, **kwargs) - - return _fn - - -def get_numpy_major_version(module=np): - """ - Extracts the major version of the installed numpy accordingly. - - Parameters - ---------- - module : module - The module to extract the version from. Default is numpy. - - Returns - ------- - int - The major version of the module. - """ - # Get the version of the module - module_version = module.__version__ - - # Extract the major version number - major_version = int(module_version.split(".")[0]) - - return major_version - - -def _check_tidy3d_extras_available(quiet: bool = False): - """Helper function to check if 'tidy3d-extras' is available and version matched. - - Parameters - ---------- - quiet : bool - If True, suppress error logging when raising exceptions. - - Raises - ------ - Tidy3dImportError - If tidy3d-extras is not available or not properly initialized. - """ - if tidy3d_extras["mod"] is not None: - return - - module_exists = find_spec("tidy3d_extras") is not None - if not module_exists: - raise Tidy3dImportError( - "The package 'tidy3d-extras' is absent. " - "Please install the 'tidy3d-extras' package using, for " - r"example, 'pip install tidy3d\[extras]'.", - log_error=not quiet, - ) - - try: - import tidy3d_extras as tidy3d_extras_mod - - except ImportError as exc: - raise Tidy3dImportError( - "The package 'tidy3d-extras' did not initialize correctly.", - log_error=not quiet, - ) from exc - - if not hasattr(tidy3d_extras_mod, "__version__"): - raise Tidy3dImportError( - "The package 'tidy3d-extras' did not initialize correctly. " - "Please install the 'tidy3d-extras' package using, for " - r"example, 'pip install tidy3d\[extras]'.", - log_error=not quiet, - ) - - version = tidy3d_extras_mod.__version__ - - if version is None: - raise Tidy3dImportError( - "The package 'tidy3d-extras' did not initialize correctly, " - "likely due to an invalid API key.", - log_error=not quiet, - ) - - if version != __version__: - raise Tidy3dImportError( - f"The version of 'tidy3d-extras' is {version}, but the version of 'tidy3d' is {__version__}. " - "They must match. You can install the correct " - r"version using 'pip install tidy3d\[extras]'.", - log_error=not quiet, - ) - - tidy3d_extras["mod"] = tidy3d_extras_mod - - -def check_tidy3d_extras_licensed_feature(feature_name: str, quiet: bool = False): - """Helper function to check if a specific feature is licensed in 'tidy3d-extras'. - - Parameters - ---------- - feature_name : str - The name of the feature to check for. - quiet : bool - If True, suppress error logging when raising exceptions. - - Raises - ------ - Tidy3dImportError - If the feature is not available with your license. - """ - - try: - _check_tidy3d_extras_available(quiet=quiet) - except Tidy3dImportError as exc: - raise Tidy3dImportError( - f"The package 'tidy3d-extras' is required for this feature '{feature_name}'.", - log_error=not quiet, - ) from exc - - features = tidy3d_extras["mod"].extension._features() - if feature_name not in features: - raise Tidy3dImportError( - f"The feature '{feature_name}' is not available with your license. 
" - "Please contact Tidy3D support, or upgrade your license.", - log_error=not quiet, - ) - - -def supports_local_subpixel(fn): +from typing import TYPE_CHECKING, Any + +from tidy3d._common.config import config +from tidy3d._common.exceptions import Tidy3dImportError +from tidy3d._common.packaging import ( + _check_tidy3d_extras_available, + check_import, + check_tidy3d_extras_licensed_feature, + get_numpy_major_version, + requires_vtk, + tidy3d_extras, + verify_packages_import, + vtk, +) + +if TYPE_CHECKING: + from tidy3d._common.packaging import F + + +def supports_local_subpixel(fn: F) -> F: """When decorating a method, checks that 'tidy3d-extras' is available, conditioned on 'config.simulation.use_local_subpixel'.""" @functools.wraps(fn) - def _fn(*args: Any, **kwargs: Any): + def _fn(*args: Any, **kwargs: Any) -> Any: preference = config.simulation.use_local_subpixel if preference is False: @@ -309,11 +56,11 @@ def _fn(*args: Any, **kwargs: Any): return _fn -def disable_local_subpixel(fn): +def disable_local_subpixel(fn: F) -> F: """When decorating a method, temporarily disables local subpixel.""" @functools.wraps(fn) - def _fn(*args: Any, **kwargs: Any): + def _fn(*args: Any, **kwargs: Any) -> Any: simulation = config.simulation previous = simulation.use_local_subpixel diff --git a/tidy3d/plugins/autograd/differential_operators.py b/tidy3d/plugins/autograd/differential_operators.py index 3bd92356eb..4638e52929 100644 --- a/tidy3d/plugins/autograd/differential_operators.py +++ b/tidy3d/plugins/autograd/differential_operators.py @@ -1,15 +1,19 @@ from __future__ import annotations -from typing import Callable +from typing import TYPE_CHECKING from autograd.builtins import tuple as atuple from autograd.core import make_vjp from autograd.extend import vspace from autograd.wrap_util import unary_to_nary -from numpy.typing import ArrayLike from .utilities import scalar_objective +if TYPE_CHECKING: + from typing import Callable + + from numpy.typing import ArrayLike + __all__ = [ "grad", "value_and_grad", diff --git a/tidy3d/plugins/autograd/functions.py b/tidy3d/plugins/autograd/functions.py index 578a1efc79..2becf0f5f0 100644 --- a/tidy3d/plugins/autograd/functions.py +++ b/tidy3d/plugins/autograd/functions.py @@ -1,7 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import Callable, Literal, SupportsInt, Union +from typing import TYPE_CHECKING import autograd.numpy as np import numpy as onp @@ -12,12 +11,20 @@ from autograd.tracer import getval from numpy.fft import irfftn, rfftn from numpy.lib.stride_tricks import sliding_window_view -from numpy.typing import NDArray from scipy.fft import next_fast_len from tidy3d.components.autograd.functions import add_at, interpn, trapz -from .types import PaddingType +if TYPE_CHECKING: + from collections.abc import Iterable + from types import ModuleType + from typing import Callable, Literal, Optional, SupportsInt, Union + + from numpy.typing import NDArray + + from tidy3d.components.autograd import TracedArrayLike + + from .types import PaddingType __all__ = [ "add_at", @@ -178,7 +185,7 @@ def _get_pad_indices( pad_width: tuple[int, int], *, mode: PaddingType, - numpy_module, + numpy_module: ModuleType, ) -> NDArray: """Compute the indices to pad an array along a single axis based on the padding mode. 
@@ -335,7 +342,11 @@ def convolve( return _fft_convolve_general(working_array, kernel, axes_array, axes_kernel, effective_mode) -def _get_footprint(size, structure, maxval): +def _get_footprint( + size: Union[int, tuple[int, int], None], + structure: Optional[NDArray], + maxval: float, +) -> NDArray: """Helper to generate the morphological footprint from size or structure.""" if size is None and structure is None: raise ValueError("Either size or structure must be provided.") @@ -404,7 +415,15 @@ def grey_dilation( return onp.max(dilated_windows, axis=(-2, -1)) -def _vjp_maker_dilation(ans, array, size=None, structure=None, *, mode="reflect", maxval=1e4): +def _vjp_maker_dilation( + ans: NDArray, + array: NDArray, + size: Union[int, tuple[int, int], None] = None, + structure: Optional[NDArray] = None, + *, + mode: PaddingType = "reflect", + maxval: float = 1e4, +) -> Callable[[TracedArrayLike], TracedArrayLike]: """VJP for the custom grey_dilation primitive.""" nb = _get_footprint(size, structure, maxval) h, w = nb.shape @@ -429,7 +448,7 @@ def _vjp_maker_dilation(ans, array, size=None, structure=None, *, mode="reflect" multiplicity = onp.sum(is_max_mask, axis=(-2, -1), keepdims=True) is_max_mask /= onp.maximum(multiplicity, 1) - def vjp(g): + def vjp(g: TracedArrayLike) -> TracedArrayLike: g_reshaped = g[..., None, None] grad_windows = g_reshaped * is_max_mask diff --git a/tidy3d/plugins/autograd/invdes/filters.py b/tidy3d/plugins/autograd/invdes/filters.py index 10da127178..b20d6adc40 100644 --- a/tidy3d/plugins/autograd/invdes/filters.py +++ b/tidy3d/plugins/autograd/invdes/filters.py @@ -1,22 +1,28 @@ from __future__ import annotations import abc -from collections.abc import Iterable from functools import lru_cache, partial -from typing import Annotated, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any, Union import numpy as np -import pydantic.v1 as pd -from numpy.typing import NDArray +from pydantic import Field, PositiveInt import tidy3d as td from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.types import TYPE_TAG_STR from tidy3d.plugins.autograd.functions import convolve from tidy3d.plugins.autograd.primitives import gaussian_filter as autograd_gaussian_filter -from tidy3d.plugins.autograd.types import KernelType, PaddingType +from tidy3d.plugins.autograd.types import PaddingType from tidy3d.plugins.autograd.utilities import get_kernel_size_px, make_kernel +if TYPE_CHECKING: + from collections.abc import Iterable + from typing import Callable, Optional + + from numpy.typing import NDArray + + from tidy3d.plugins.autograd.types import KernelType + _GAUSSIAN_SIGMA_SCALE = 0.445 # empirically matches conic kernel response in 1D/2D tests _GAUSSIAN_PADDING_MAP = { "constant": "constant", @@ -30,14 +36,19 @@ class AbstractFilter(Tidy3dBaseModel, abc.ABC): """An abstract class for creating and applying convolution filters.""" - kernel_size: Union[pd.PositiveInt, tuple[pd.PositiveInt, ...]] = pd.Field( - ..., title="Kernel Size", description="Size of the kernel in pixels for each dimension." + kernel_size: Union[PositiveInt, tuple[PositiveInt, ...]] = Field( + title="Kernel Size", + description="Size of the kernel in pixels for each dimension.", ) - normalize: bool = pd.Field( - True, title="Normalize", description="Whether to normalize the kernel so that it sums to 1." 
+ normalize: bool = Field( + True, + title="Normalize", + description="Whether to normalize the kernel so that it sums to 1.", ) - padding: PaddingType = pd.Field( - "reflect", title="Padding", description="The padding mode to use." + padding: PaddingType = Field( + "reflect", + title="Padding", + description="The padding mode to use.", ) @classmethod @@ -51,9 +62,9 @@ def from_radius_dl( Parameters ---------- - radius : Union[float, Tuple[float, ...]] + radius : Union[float, tuple[float, ...]] The radius of the kernel. Can be a scalar or a tuple. - dl : Union[float, Tuple[float, ...]] + dl : Union[float, tuple[float, ...]] The grid spacing. Can be a scalar or a tuple. **kwargs Additional keyword arguments to pass to the filter constructor. @@ -154,13 +165,13 @@ class GaussianFilter(AbstractFilter): a unit-sum kernel; setting it to ``False`` has no effect. """ - sigma_scale: float = pd.Field( + sigma_scale: float = Field( _GAUSSIAN_SIGMA_SCALE, title="Sigma Scale", description="Scale factor mapping radius in pixels to Gaussian sigma.", ge=0.0, ) - truncate: float = pd.Field( + truncate: float = Field( 2.0, title="Truncate", description="Truncation radius in multiples of sigma passed to ``gaussian_filter``.", @@ -204,16 +215,16 @@ def _get_kernel_size( Parameters ---------- - radius : Union[float, Tuple[float, ...]] + radius : Union[float, tuple[float, ...]] The radius of the kernel. Can be a scalar or a tuple. - dl : Union[float, Tuple[float, ...]] + dl : Union[float, tuple[float, ...]] The grid spacing. Can be a scalar or a tuple. - size_px : Union[int, Tuple[int, ...]] + size_px : Union[int, tuple[int, ...]] The size of the kernel in pixels for each dimension. Can be a scalar or a tuple. Returns ------- - Tuple[int, ...] + tuple[int, ...] The size of the kernel in pixels for each dimension. Raises @@ -246,11 +257,11 @@ def make_filter( Parameters ---------- - radius : Union[float, Tuple[float, ...]] = None + radius : Union[float, tuple[float, ...]] = None The radius of the kernel. Can be a scalar or a tuple. - dl : Union[float, Tuple[float, ...]] = None + dl : Union[float, tuple[float, ...]] = None The grid spacing. Can be a scalar or a tuple. - size_px : Union[int, Tuple[int, ...]] = None + size_px : Union[int, tuple[int, ...]] = None The size of the kernel in pixels for each dimension. Can be a scalar or a tuple. normalize : bool = True Whether to normalize the kernel so that it sums to 1. 
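
A recurring change in these hunks is dropping the explicit `...` (Ellipsis) first argument when converting `pd.Field` to the pydantic v2 `Field`: in v2, a field declared with no default is already required, so the sentinel is redundant. A small sketch of the migration, assuming a plain `BaseModel` (class and field names are illustrative):

```python
from pydantic import BaseModel, Field


class KernelSpec(BaseModel):
    # v1 style (pydantic.v1): kernel_size: int = pd.Field(..., title="Kernel Size")
    # v2 style: omitting the default keeps the field required.
    kernel_size: int = Field(title="Kernel Size")
    # Fields with defaults carry over unchanged apart from the import.
    normalize: bool = Field(True, title="Normalize")


KernelSpec(kernel_size=3)  # ok; omitting kernel_size raises a ValidationError
```
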
@@ -307,5 +318,5 @@ def make_filter( """ FilterType = Annotated[ - Union[ConicFilter, CircularFilter, GaussianFilter], pd.Field(discriminator=TYPE_TAG_STR) + Union[ConicFilter, CircularFilter, GaussianFilter], Field(discriminator=TYPE_TAG_STR) ] diff --git a/tidy3d/plugins/autograd/invdes/misc.py b/tidy3d/plugins/autograd/invdes/misc.py index 0221d35b3e..1e7cf87cbd 100644 --- a/tidy3d/plugins/autograd/invdes/misc.py +++ b/tidy3d/plugins/autograd/invdes/misc.py @@ -1,7 +1,11 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import autograd.numpy as np -from numpy.typing import NDArray + +if TYPE_CHECKING: + from numpy.typing import NDArray def grey_indicator(array: NDArray) -> float: diff --git a/tidy3d/plugins/autograd/invdes/parametrizations.py b/tidy3d/plugins/autograd/invdes/parametrizations.py index c6435f25a9..5a2e721594 100644 --- a/tidy3d/plugins/autograd/invdes/parametrizations.py +++ b/tidy3d/plugins/autograd/invdes/parametrizations.py @@ -1,15 +1,15 @@ from __future__ import annotations from collections import deque -from typing import Any, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd from autograd import value_and_grad -from numpy.typing import NDArray +from pydantic import Field, NonNegativeFloat from scipy.optimize import minimize import tidy3d as td +from tidy3d.components.autograd.functions import _straight_through_clip from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.grid.grid import Coords from tidy3d.plugins.autograd.constants import BETA_DEFAULT, ETA_DEFAULT @@ -18,30 +18,47 @@ from .filters import make_filter from .projections import tanh_projection +if TYPE_CHECKING: + from typing import Callable, Literal + + from numpy.typing import NDArray + class FilterAndProject(Tidy3dBaseModel): """A class that combines filtering and projection operations.""" - radius: Union[float, tuple[float, ...]] = pd.Field( - ..., title="Radius", description="The radius of the kernel." + radius: Union[float, tuple[float, ...]] = Field( + title="Radius", + description="The radius of the kernel.", ) - dl: Union[float, tuple[float, ...]] = pd.Field( - ..., title="Grid Spacing", description="The grid spacing." + dl: Union[float, tuple[float, ...]] = Field( + title="Grid Spacing", + description="The grid spacing.", ) - size_px: Union[int, tuple[int, ...]] = pd.Field( - None, title="Size in Pixels", description="The size of the kernel in pixels." + size_px: Optional[Union[int, tuple[int, ...]]] = Field( + None, + title="Size in Pixels", + description="The size of the kernel in pixels.", ) - beta: pd.NonNegativeFloat = pd.Field( - BETA_DEFAULT, title="Beta", description="The beta parameter for the tanh projection." + beta: NonNegativeFloat = Field( + BETA_DEFAULT, + title="Beta", + description="The beta parameter for the tanh projection.", ) - eta: pd.NonNegativeFloat = pd.Field( - ETA_DEFAULT, title="Eta", description="The eta parameter for the tanh projection." + eta: NonNegativeFloat = Field( + ETA_DEFAULT, + title="Eta", + description="The eta parameter for the tanh projection.", ) - filter_type: KernelType = pd.Field( - "conic", title="Filter Type", description="The type of filter to create." + filter_type: KernelType = Field( + "conic", + title="Filter Type", + description="The type of filter to create.", ) - padding: PaddingType = pd.Field( - "reflect", title="Padding", description="The padding mode to use." 
+ padding: PaddingType = Field( + "reflect", + title="Padding", + description="The padding mode to use.", ) def __call__( @@ -74,7 +91,9 @@ def __call__( beta = beta if beta is not None else self.beta eta = eta if eta is not None else self.eta projected = tanh_projection(filtered, beta, eta) - return projected + clip_projected = _straight_through_clip(projected, a_min=0.0, a_max=1.0) + + return clip_projected def make_filter_and_project( diff --git a/tidy3d/plugins/autograd/invdes/penalties.py b/tidy3d/plugins/autograd/invdes/penalties.py index 92eaac6e79..aa3cbc1818 100644 --- a/tidy3d/plugins/autograd/invdes/penalties.py +++ b/tidy3d/plugins/autograd/invdes/penalties.py @@ -1,43 +1,60 @@ from __future__ import annotations -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd -from numpy.typing import NDArray +from pydantic import Field, NonNegativeFloat from tidy3d.components.base import Tidy3dBaseModel -from tidy3d.components.types import ArrayFloat2D from tidy3d.plugins.autograd.types import PaddingType from .parametrizations import FilterAndProject +if TYPE_CHECKING: + from typing import Callable + + from numpy.typing import NDArray + + from tidy3d.components.types import ArrayFloat2D + class ErosionDilationPenalty(Tidy3dBaseModel): """A class that computes a penalty for erosion/dilation of a parameter map not being unity.""" - radius: Union[float, tuple[float, ...]] = pd.Field( - ..., title="Radius", description="The radius of the kernel." + radius: Union[float, tuple[float, ...]] = Field( + title="Radius", + description="The radius of the kernel.", ) - dl: Union[float, tuple[float, ...]] = pd.Field( - ..., title="Grid Spacing", description="The grid spacing." + dl: Union[float, tuple[float, ...]] = Field( + title="Grid Spacing", + description="The grid spacing.", ) - size_px: Union[int, tuple[int, ...]] = pd.Field( - None, title="Size in Pixels", description="The size of the kernel in pixels." + size_px: Optional[Union[int, tuple[int, ...]]] = Field( + None, + title="Size in Pixels", + description="The size of the kernel in pixels.", ) - beta: pd.NonNegativeFloat = pd.Field( - 20.0, title="Beta", description="The beta parameter for the tanh projection." + beta: NonNegativeFloat = Field( + 20.0, + title="Beta", + description="The beta parameter for the tanh projection.", ) - eta: pd.NonNegativeFloat = pd.Field( - 0.5, title="Eta", description="The eta parameter for the tanh projection." + eta: NonNegativeFloat = Field( + 0.5, + title="Eta", + description="The eta parameter for the tanh projection.", ) - filter_type: str = pd.Field( - "conic", title="Filter Type", description="The type of filter to create." + filter_type: str = Field( + "conic", + title="Filter Type", + description="The type of filter to create.", ) - padding: PaddingType = pd.Field( - "reflect", title="Padding", description="The padding mode to use." 
+ padding: PaddingType = Field( + "reflect", + title="Padding", + description="The padding mode to use.", ) - delta_eta: float = pd.Field( + delta_eta: float = Field( 0.01, title="Delta Eta", description="The binarization threshold for erosion and dilation operations.", @@ -69,16 +86,16 @@ def __call__(self, array: NDArray) -> float: eta_dilate = 0.0 + self.delta_eta eta_eroded = 1.0 - self.delta_eta - def _dilate(arr: NDArray): + def _dilate(arr: NDArray) -> NDArray: return filtproj(arr, eta=eta_dilate) - def _erode(arr: NDArray): + def _erode(arr: NDArray) -> NDArray: return filtproj(arr, eta=eta_eroded) - def _open(arr: NDArray): + def _open(arr: NDArray) -> NDArray: return _dilate(_erode(arr)) - def _close(arr: NDArray): + def _close(arr: NDArray) -> NDArray: return _erode(_dilate(arr)) diff = _close(array) - _open(array) diff --git a/tidy3d/plugins/autograd/invdes/projections.py b/tidy3d/plugins/autograd/invdes/projections.py index ee14f38aa2..753d7e1fed 100644 --- a/tidy3d/plugins/autograd/invdes/projections.py +++ b/tidy3d/plugins/autograd/invdes/projections.py @@ -1,10 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import autograd.numpy as np -from numpy.typing import NDArray from tidy3d.plugins.autograd.constants import BETA_DEFAULT, ETA_DEFAULT +if TYPE_CHECKING: + from numpy.typing import NDArray + def ramp_projection(array: NDArray, width: float = 0.1, center: float = 0.5) -> NDArray: """Apply a piecewise linear ramp projection to an array. @@ -78,7 +82,7 @@ def smoothed_projection( array: NDArray, beta: float = BETA_DEFAULT, eta: float = ETA_DEFAULT, - scaling_factor=1.0, + scaling_factor: float = 1.0, ) -> NDArray: """ Apply a subpixel-smoothed projection method. diff --git a/tidy3d/plugins/autograd/invdes/symmetries.py b/tidy3d/plugins/autograd/invdes/symmetries.py index 09af15486e..1fc0624999 100644 --- a/tidy3d/plugins/autograd/invdes/symmetries.py +++ b/tidy3d/plugins/autograd/invdes/symmetries.py @@ -1,8 +1,10 @@ from __future__ import annotations from collections.abc import Sequence +from typing import TYPE_CHECKING -from numpy.typing import NDArray +if TYPE_CHECKING: + from numpy.typing import NDArray def symmetrize_mirror(array: NDArray, axis: int | tuple[int, int]) -> NDArray: @@ -52,7 +54,7 @@ def symmetrize_mirror(array: NDArray, axis: int | tuple[int, int]) -> NDArray: # Helper function to flip along a specific axis using slicing # Autograd supports slicing (e.g. 
::-1) but lacks VJP for np.flip - def flip_axis(arr, ax): + def flip_axis(arr: NDArray, ax: int) -> NDArray: if ax == 0: return arr[::-1, :] elif ax == 1: diff --git a/tidy3d/plugins/autograd/primitives/interpolate.py b/tidy3d/plugins/autograd/primitives/interpolate.py index ac5b6c2d4c..d0981907b5 100644 --- a/tidy3d/plugins/autograd/primitives/interpolate.py +++ b/tidy3d/plugins/autograd/primitives/interpolate.py @@ -1,13 +1,17 @@ from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING import numpy as np from autograd.extend import defvjp, primitive -from numpy.typing import NDArray from tidy3d.log import log +if TYPE_CHECKING: + from typing import Callable, Optional + + from numpy.typing import NDArray + def _assert_strictly_monotonic(x: NDArray) -> None: """Raise if ``x`` is not strictly monotonic (all increasing or all decreasing).""" @@ -586,7 +590,7 @@ def compute_spline_coeffs( y_points: NDArray, endpoint_derivatives: tuple[Optional[float], Optional[float]] = (None, None), order: int = 3, -) -> tuple: +) -> tuple[NDArray, ...]: """Compute spline coefficients for the given order. Parameters @@ -617,7 +621,7 @@ def compute_spline_coeffs( raise NotImplementedError(f"Spline order '{order}' not implemented.") -def evaluate_spline(x_points: NDArray, coeffs: tuple, x_eval: NDArray) -> NDArray: +def evaluate_spline(x_points: NDArray, coeffs: tuple[NDArray, ...], x_eval: NDArray) -> NDArray: """Evaluate a spline at the specified points. Parameters @@ -650,7 +654,7 @@ def get_spline_derivatives_wrt_y( x_points: NDArray, y_points: NDArray, endpoint_derivatives: tuple[Optional[float], Optional[float]] = (None, None), -): +) -> tuple[NDArray, ...]: """Returns a tuple of derivative arrays for the given spline order. 
Parameters @@ -731,10 +735,17 @@ def _interpolate_spline( return x_eval, y_eval -def interpolate_spline_y_vjp(ans, x_points, y_points, num_points, order, endpoint_derivatives): +def interpolate_spline_y_vjp( + ans: tuple[NDArray, NDArray], + x_points: NDArray, + y_points: NDArray, + num_points: int, + order: int, + endpoint_derivatives: tuple[Optional[float], Optional[float]], +) -> Callable[[tuple[NDArray, NDArray] | NDArray], NDArray]: """VJP for interpolate_spline wrt y_points.""" - def vjp(g): + def vjp(g: tuple[NDArray, NDArray] | NDArray) -> NDArray: reversed_order = x_points[0] > x_points[-1] if reversed_order: x_proc = x_points[::-1] diff --git a/tidy3d/plugins/autograd/primitives/misc.py b/tidy3d/plugins/autograd/primitives/misc.py index 71b144aa4a..4b5d686ec8 100644 --- a/tidy3d/plugins/autograd/primitives/misc.py +++ b/tidy3d/plugins/autograd/primitives/misc.py @@ -1,13 +1,20 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from functools import cache +from typing import TYPE_CHECKING, Any import autograd.numpy as anp import numpy as np import scipy.ndimage from autograd.extend import defjvp, defvjp, primitive +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Callable + + from numpy.typing import NDArray + def _normalize_sequence(value: float | Sequence[float], ndim: int) -> tuple[float, ...]: """Convert a scalar or sequence into a tuple of length ``ndim``.""" @@ -56,15 +63,15 @@ def _gaussian_weight_matrix( @primitive def gaussian_filter( - array, - sigma, + array: NDArray, + sigma: float | Sequence[float], *, - order=0, - mode="reflect", - cval=0.0, - truncate=4.0, - **kwargs, -): + order: float | Sequence[float] = 0, + mode: str | Sequence[str] = "reflect", + cval: float = 0.0, + truncate: float = 4.0, + **kwargs: Any, +) -> NDArray: return scipy.ndimage.gaussian_filter( array, sigma, @@ -77,8 +84,16 @@ def gaussian_filter( def _gaussian_filter_vjp( - ans, array, sigma, *, order=0, mode="reflect", cval=0.0, truncate=4.0, **kwargs -): + ans: NDArray, + array: NDArray, + sigma: float | Sequence[float], + *, + order: float | Sequence[float] = 0, + mode: str | Sequence[str] = "reflect", + cval: float | Sequence[float] = 0.0, + truncate: float | Sequence[float] = 4.0, + **kwargs: Any, +) -> Callable[[NDArray], NDArray]: ndim = array.ndim sigma_seq = _normalize_sequence(sigma, ndim) order_seq = _normalize_sequence(order, ndim) @@ -93,7 +108,7 @@ def _gaussian_filter_vjp( f"gaussian_filter VJP does not support additional keyword arguments: {tuple(kwargs)}" ) - def vjp(g): + def vjp(g: NDArray) -> NDArray: grad = np.asarray(g) for axis in reversed(range(ndim)): sigma_axis = float(sigma_seq[axis]) diff --git a/tidy3d/plugins/autograd/utilities.py b/tidy3d/plugins/autograd/utilities.py index 7a7b5f83a8..a260ad8e53 100644 --- a/tidy3d/plugins/autograd/utilities.py +++ b/tidy3d/plugins/autograd/utilities.py @@ -2,16 +2,23 @@ from collections.abc import Iterable from functools import reduce, wraps -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, overload import autograd.numpy as anp import numpy as np import xarray as xr -from numpy.typing import NDArray from tidy3d.exceptions import Tidy3dError -from .types import KernelType +if TYPE_CHECKING: + from typing import Callable, Optional, Union + + from numpy.typing import NDArray + + from .types import KernelType + +P = ParamSpec("P") +R = TypeVar("R") def _kernel_circular(size: 
Iterable[int]) -> NDArray: @@ -100,7 +107,7 @@ def get_kernel_size_px( Returns ------- - Union[int, List[int]] + Union[int, list[int]] The size of the kernel in pixels for each dimension. Returns an integer if the radius is scalar, otherwise a list of integers. Raises @@ -112,9 +119,9 @@ def get_kernel_size_px( raise ValueError("Either 'size_px' or both 'radius' and 'dl' must be provided.") if np.isscalar(radius): - radius = [radius] * len(dl) if isinstance(dl, Iterable) else [radius] + radius = [radius] * len(dl) if isinstance(dl, Iterable) else [radius] # type: ignore[list-item] if np.isscalar(dl): - dl = [dl] * len(radius) + dl = [dl] * len(radius) # type: ignore[list-item] radius_px = [np.ceil(r / g) for r, g in zip(radius, dl)] return ( @@ -124,7 +131,7 @@ def get_kernel_size_px( ) -def chain(*funcs: Union[Callable, Iterable[Callable]]): +def chain(*funcs: Union[Callable, Iterable[Callable]]) -> Callable[[NDArray], NDArray]: """Chain multiple functions together to apply them sequentially to an array. Parameters @@ -162,13 +169,25 @@ def chain(*funcs: Union[Callable, Iterable[Callable]]): if not all(callable(f) for f in funcs): raise TypeError("All elements in funcs must be callable.") - def chained(array: NDArray): + def chained(array: NDArray) -> NDArray: return reduce(lambda x, y: y(x), funcs, array) return chained -def scalar_objective(func: Optional[Callable] = None, *, has_aux: bool = False) -> Callable: +@overload +def scalar_objective( + func: None = None, *, has_aux: bool = False +) -> Callable[[Callable[P, Any]], Callable[P, Any]]: ... + + +@overload +def scalar_objective(func: Callable[P, Any], *, has_aux: bool = False) -> Callable[P, Any]: ... + + +def scalar_objective( + func: Optional[Callable[P, Any]] = None, *, has_aux: bool = False +) -> Callable[..., Any]: """Decorator to ensure the objective function returns a real scalar value. This decorator wraps an objective function to ensure that its return value is a real scalar. @@ -194,51 +213,45 @@ def scalar_objective(func: Optional[Callable] = None, *, has_aux: bool = False) Tidy3dError If the return value is not a real scalar, or if `has_aux` is True and the function does not return a tuple of length 2. """ - if func is None: - return lambda f: scalar_objective(f, has_aux=has_aux) - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - result = func(*args, **kwargs) - aux_data = None - - # Unpack auxiliary data if present - if has_aux: - if not isinstance(result, tuple) or len(result) != 2: - raise Tidy3dError( - "If 'has_aux' is True, the objective function must return a tuple of length 2." - ) - result, aux_data = result - - # Extract data from xarray.DataArray - if isinstance(result, xr.DataArray): - result = result.data - - # Squeeze to remove singleton dimensions - result = anp.squeeze(result) - - # Attempt to extract scalar value - try: - result = result.item() - except AttributeError: - # If result is already a scalar, pass - if not isinstance(result, (float, int)): + def decorator(f: Callable[P, Any]) -> Callable[P, Any]: + @wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any: + result = f(*args, **kwargs) + aux_data = None + + if has_aux: + if not isinstance(result, tuple) or len(result) != 2: + raise Tidy3dError( + "If 'has_aux' is True, the objective function must return a tuple of length 2." 
+ ) + result, aux_data = result + + if isinstance(result, xr.DataArray): + result = result.data + + result = anp.squeeze(result) + + try: + result = result.item() + except AttributeError: + if not isinstance(result, (float, int)): + raise Tidy3dError( + "An objective function's return value must be a scalar, " + "a Python float/int, or an array containing a single element." + ) from None + except ValueError as e: raise Tidy3dError( - "An objective function's return value must be a scalar, " - "a Python float/int, or an array containing a single element." - ) from None - except ValueError as e: - # Result contains more than one element - raise Tidy3dError( - "An objective function's return value must be a scalar " - "but got an array with shape " - f"{getattr(result, 'shape', 'N/A')}." - ) from e - - # Ensure the result is real - if not anp.isreal(result): - raise Tidy3dError("An objective function's return value must be real.") - - return (result, aux_data) if aux_data is not None else result - - return wrapper + "An objective function's return value must be a scalar " + "but got an array with shape " + f"{getattr(result, 'shape', 'N/A')}." + ) from e + + if not anp.isreal(result): + raise Tidy3dError("An objective function's return value must be real.") + + return (result, aux_data) if aux_data is not None else result + + return wrapper + + return decorator(func) if func is not None else decorator diff --git a/tidy3d/plugins/design/design.py b/tidy3d/plugins/design/design.py index 2e9fd86d4d..ddbda70490 100644 --- a/tidy3d/plugins/design/design.py +++ b/tidy3d/plugins/design/design.py @@ -3,15 +3,15 @@ from __future__ import annotations import inspect -from typing import Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional -import pydantic.v1 as pd +from pydantic import Field from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.data.sim_data import SimulationData from tidy3d.components.simulation import Simulation from tidy3d.components.types import TYPE_TAG_STR -from tidy3d.log import Console, get_logging_console, log +from tidy3d.log import get_logging_console, log from tidy3d.web.api.container import Batch, BatchData, Job from .method import ( @@ -24,6 +24,11 @@ from .parameter import ParameterAny, ParameterInt, ParameterType from .result import Result +if TYPE_CHECKING: + from typing import Callable, Union + + from rich.console import Console + class DesignSpace(Tidy3dBaseModel): """Manages all exploration of a parameter space within specified parameters using a supplied search method. @@ -66,20 +71,19 @@ class DesignSpace(Tidy3dBaseModel): """ - parameters: tuple[ParameterType, ...] = pd.Field( + parameters: tuple[ParameterType, ...] 
= Field( (), title="Parameters", description="Set of parameters defining the dimensions and allowed values for the design space.", ) - method: MethodType = pd.Field( - ..., + method: MethodType = Field( title="Search Type", description="Specifications for the procedure used to explore the parameter space.", discriminator=TYPE_TAG_STR, # Stops pydantic trying to validate every method whilst checking MethodType ) - task_name: str = pd.Field( + task_name: str = Field( "", title="Task Name", description="Task name assigned to tasks along with a simulation counter in the form of {task_name}_{sim_index}_{counter} where ``sim_index`` is " @@ -88,15 +92,19 @@ class DesignSpace(Tidy3dBaseModel): "Only used when pre-post functions are supplied.", ) - name: str = pd.Field(None, title="Name", description="Optional name for the design space.") + name: Optional[str] = Field( + None, + title="Name", + description="Optional name for the design space.", + ) - path_dir: str = pd.Field( + path_dir: str = Field( ".", title="Path Directory", description="Directory where simulation data files will be locally saved to. Only used when pre and post functions are supplied.", ) - folder_name: str = pd.Field( + folder_name: str = Field( "default", title="Folder Name", description="Folder path where the simulation will be uploaded in the Tidy3D Workspace. Will use 'default' if no path is set.", @@ -270,7 +278,7 @@ def run( opt_output=opt_output, ) - def run_single(self, fn: Callable, console: Console) -> tuple(list[dict], list, list[Any]): + def run_single(self, fn: Callable, console: Console) -> tuple[list[dict], list, list[Any]]: """Run a single function of parameter inputs.""" evaluate_fn = self._get_evaluate_fn_single(fn=fn) return self.method._run(run_fn=evaluate_fn, parameters=self.parameters, console=console) @@ -281,7 +289,7 @@ def run_pre_post( fn_post: Callable, console: Console, priority: Optional[int] = None, - ) -> tuple(list[dict], list[dict], list[Any]): + ) -> tuple[list[dict], list[dict], list[Any]]: """Run a function with Tidy3D implicitly called in between.""" handler = self._get_evaluate_fn_pre_post( fn_pre=fn_pre, @@ -313,11 +321,11 @@ def _get_evaluate_fn_pre_post( fn_mid: Callable, console: Console, priority: Optional[int], - ): + ) -> Any: """Get function that tries to use batch processing on a set of arguments.""" class Pre_Post_Handler: - def __init__(self, console, priority) -> None: + def __init__(self, console: Console, priority: Optional[int]) -> None: self.sim_counter = 0 self.sim_names = [] self.sim_paths = [] @@ -566,7 +574,7 @@ def estimate_cost(self, fn_pre: Callable) -> float: # Compute fn_pre pre_out = fn_pre(**arg_dict) - def _estimate_sim_cost(sim): + def _estimate_sim_cost(sim: Simulation) -> float: job = Job(simulation=sim, task_name="estimate_cost") estimate = job.estimate_cost() @@ -649,8 +657,8 @@ def summarize(self, fn_pre: Optional[Callable] = None, verbose: bool = True) -> # If check stops it printing standard attributes arg_values = [ f"{field}: {getattr(self.method, field)}\n" - for field in self.method.__fields__ - if field not in MethodOptimize.__fields__ + for field in type(self.method).model_fields + if field not in MethodOptimize.model_fields ] param_values = [] diff --git a/tidy3d/plugins/design/method.py b/tidy3d/plugins/design/method.py index 777995079d..e1f0f1de23 100644 --- a/tidy3d/plugins/design/method.py +++ b/tidy3d/plugins/design/method.py @@ -3,34 +3,52 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import 
TYPE_CHECKING, Any, Callable, Literal, Optional, Union
+from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, Optional, Union, overload
 
 import numpy as np
-import pydantic.v1 as pd
+import scipy.stats.qmc as qmc
+from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt
 
 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.constants import inf
 
-from .parameter import ParameterAny, ParameterFloat, ParameterInt, ParameterType
+from .parameter import ParameterAny, ParameterFloat, ParameterInt
 
 if TYPE_CHECKING:
+    import pygad
+    from numpy.typing import NDArray
+    from rich.console import Console
     from scipy.stats import qmc as qmc_type
 
+    from .parameter import ParameterType
+
+
+ArgsList = list[dict[str, Any]]
+RunFunction = Callable[[ArgsList], list[Any]]
+RunResult = tuple[ArgsList, list[Any], list[Any] | None, Any | None]
+
 
 class Method(Tidy3dBaseModel, ABC):
     """Spec for a sweep algorithm, with a method to run it."""
 
-    name: str = pd.Field(None, title="Name", description="Optional name for the sweep method.")
+    name: Optional[str] = Field(
+        None, title="Name", description="Optional name for the sweep method."
+    )
 
     @abstractmethod
-    def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable) -> tuple[Any]:
+    def _run(
+        self,
+        parameters: tuple[ParameterType, ...],
+        run_fn: RunFunction,
+        console: Optional[Console],
+    ) -> RunResult:
         """Defines the search algorithm."""
 
     @abstractmethod
-    def _get_run_count(self, parameters: Optional[list] = None) -> int:
+    def _get_run_count(self, parameters: Optional[list[ParameterType]] = None) -> int:
         """Return the maximum number of runs for the method based on current method arguments."""
 
-    def _force_int(self, next_point: dict, parameters: list) -> None:
+    def _force_int(self, next_point: dict[str, Any], parameters: tuple[ParameterType, ...]) -> None:
         """Convert a float assigned to an int parameter to be an int. Update dict in place."""
 
         for param in parameters:
@@ -38,8 +56,17 @@ def _force_int(self, next_point: dict, parameters: list) -> None:
             # Using int(round()) instead of just int as int always rounds down making upper bound value impossible
             next_point[param.name] = int(round(next_point[param.name], 0))
 
-    @staticmethod
-    def _extract_output(output: list, sampler: bool = False) -> tuple:
+    @overload
+    def _extract_output(self, output: list[Any], sampler: Literal[True]) -> list[Any]: ...
+
+    @overload
+    def _extract_output(
+        self, output: list[Any], sampler: Literal[False] = False
+    ) -> tuple[list[float], list[Any]]: ...
+
+    def _extract_output(
+        self, output: list[Any], sampler: bool = False
+    ) -> list[Any] | tuple[list[float], list[Any]]:
         """Format the user function output for further optimization and result storage."""
 
         # Light check if all the outputs are the same type
@@ -58,7 +85,7 @@ def _extract_output(output: list, sampler: bool = False) -> tuple:
         if all(isinstance(val, (float, int)) for val in output):
             # No aux_out
             none_aux = [None for _ in range(len(output))]
-            return (output, none_aux)
+            return output, none_aux
 
         if all(isinstance(val, (list, tuple)) for val in output):
             if all(isinstance(val[0], (float, int)) for val in output):
@@ -74,7 +101,7 @@ def _extract_output(output: list, sampler: bool = False) -> tuple:
                     aux_out.append(val[1])
 
                 # Float with aux_out
-                return (float_out, aux_out)
+                return float_out, aux_out
 
         raise ValueError(
             "Unrecognized output from supplied post function. The first element in the iterable object should be a 'float'."
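
The `@overload` pair introduced above for `_extract_output` lets a type checker narrow the return type from the literal value of the `sampler` flag, while a single runtime implementation keeps the permissive `bool` signature. A standalone sketch of the technique (names are illustrative, not the library's API):

```python
from typing import Any, Literal, overload


class Extractor:
    @overload
    def extract(self, output: list[Any], sampler: Literal[True]) -> list[Any]: ...

    @overload
    def extract(
        self, output: list[Any], sampler: Literal[False] = False
    ) -> tuple[list[float], list[Any]]: ...

    def extract(self, output: list[Any], sampler: bool = False) -> Any:
        # One runtime body; the overloads above are erased at runtime and
        # exist purely so checkers pick a return type from the flag.
        if sampler:
            return output
        return [float(v) for v in output], [None] * len(output)
```
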
@@ -85,7 +112,9 @@ def _extract_output(output: list, sampler: bool = False) -> tuple:
         )
 
     @staticmethod
-    def _flatten_and_append(list_of_lists: list[list], append_target: list) -> None:
+    def _flatten_and_append(
+        list_of_lists: Optional[list[list[Any]]], append_target: list[Any]
+    ) -> None:
         """Flatten a list of lists and append the sublist to a new list."""
         if list_of_lists is not None:
             for sub_list in list_of_lists:
@@ -96,13 +125,13 @@ class MethodSample(Method, ABC):
     """A sweep method where all points are independently computed in one iteration."""
 
     @abstractmethod
-    def sample(self, parameters: tuple[ParameterType, ...], **kwargs: Any) -> dict[str, Any]:
+    def sample(self, parameters: tuple[ParameterType, ...], **kwargs: Any) -> ArgsList:
         """Defines how the design parameters are sampled."""
 
     def _assemble_args(
         self,
         parameters: tuple[ParameterType, ...],
-    ) -> tuple[dict, int]:
+    ) -> ArgsList:
         """Sample design parameters, check the args are hashable and compute number of points."""
 
         fn_args = self.sample(parameters)
@@ -110,7 +139,12 @@ def _assemble_args(
             self._force_int(arg_dict, parameters)
         return fn_args
 
-    def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]:
+    def _run(
+        self,
+        parameters: tuple[ParameterType, ...],
+        run_fn: RunFunction,
+        console: Optional[Console],
+    ) -> RunResult:
         """Defines the search algorithm."""
 
         # get all function inputs
@@ -135,16 +169,16 @@ class MethodGrid(MethodSample):
     >>> method = tdd.MethodGrid()
     """
 
-    def _get_run_count(self, parameters: list) -> int:
+    def _get_run_count(self, parameters: list[ParameterType]) -> int:
         """Return the maximum number of runs for the method based on current method arguments."""
         return len(self.sample(parameters))
 
     @staticmethod
-    def sample(parameters: tuple[ParameterType, ...]) -> dict[str, Any]:
+    def sample(parameters: tuple[ParameterType, ...]) -> ArgsList:
         """Defines how the design parameters are sampled on the grid."""
 
         # sample each dimension individually
-        vals_each_dim = {}
+        vals_each_dim: dict[str, list[Any]] = {}
         for param in parameters:
             vals = param.sample_grid()
             vals_each_dim[param.name] = vals
@@ -153,7 +187,9 @@ def sample(parameters: tuple[ParameterType, ...]) -> dict[str, Any]:
         vals_grid = np.meshgrid(*vals_each_dim.values())
         vals_grid = (np.ravel(x).tolist() for x in vals_grid)
         vals_dict = dict(zip(vals_each_dim.keys(), vals_grid))
-        t_vals_dict = [dict(zip(vals_dict.keys(), values)) for values in zip(*vals_dict.values())]
+        t_vals_dict: ArgsList = [
+            dict(zip(vals_dict.keys(), values)) for values in zip(*vals_dict.values())
+        ]
 
         return t_vals_dict
 
@@ -162,20 +198,23 @@ class MethodOptimize(Method, ABC):
     """A method for handling design searches that optimize the design."""
 
     # NOTE: We could move this to the Method base class but it's not relevant to MethodGrid
-    seed: pd.PositiveInt = pd.Field(
-        default=None,
+    seed: Optional[PositiveInt] = Field(
+        None,
         title="Seed for random number generation",
         description="Set the seed used by the optimizers to ensure consistent random number generation.",
     )
 
-    def any_to_int_param(self, parameter: ParameterAny) -> dict:
+    def any_to_int_param(self, parameter: ParameterAny) -> dict[int, Any]:
        """Convert ParameterAny object to integers and provide a conversion dict to return"""
        return dict(enumerate(parameter.allowed_values))

    def sol_array_to_dict(
-        self, solution: np.array, keys: list, param_converter: dict
-    ) -> list[dict]:
+        self,
+        solution: NDArray[np.floating],
+        keys: list[str],
+        param_converter: dict[str, dict[int, Any]],
+    ) -> ArgsList:
         """Convert an array of solutions to a list of dicts for function input"""
 
         sol_dict_list = [dict(zip(keys, sol)) for sol in solution]
@@ -183,7 +222,9 @@ def sol_array_to_dict(
 
         return sol_dict_list
 
-    def _handle_param_convert(self, param_converter: dict, sol_dict_list: list[dict]) -> None:
+    def _handle_param_convert(
+        self, param_converter: dict[str, dict[int, Any]], sol_dict_list: ArgsList
+    ) -> None:
         for param, convert in param_converter.items():
             for sol in sol_dict_list:
                 if isinstance(sol[param], float):
@@ -201,41 +242,44 @@ class MethodBayOpt(MethodOptimize, ABC):
 
     >>> method = tdd.MethodBayOpt(initial_iter=4, n_iter=10)
     """
 
-    initial_iter: pd.PositiveInt = pd.Field(
-        ...,
+    initial_iter: PositiveInt = Field(
         title="Number of Initial Random Search Iterations",
         description="The number of search runs to be done initially with parameter values picked randomly. This provides a starting point for the Gaussian processor to optimize from. These solutions can be computed as a single ``Batch`` if the pre function generates ``Simulation`` objects.",
     )
 
-    n_iter: pd.PositiveInt = pd.Field(
-        ...,
+    n_iter: PositiveInt = Field(
         title="Number of Bayesian Optimization Iterations",
         description="Following the initial search, this is the number of iterations for which the Gaussian processor is sequentially called to suggest parameter values and register the results.",
     )
 
-    acq_func: Literal["ucb", "ei", "poi"] = pd.Field(
+    acq_func: Literal["ucb", "ei", "poi"] = Field(
         default="ucb",
         title="Type of Acquisition Function",
         description="The type of acquisition function that should be used to suggest parameter values. More detail available in the `package docs `_.",
     )
 
-    kappa: pd.PositiveFloat = pd.Field(
+    kappa: PositiveFloat = Field(
         default=2.5,
         title="Kappa",
         description="The kappa coefficient used by the ``ucb`` acquisition function. More detail available in the `package docs `_.",
     )
 
-    xi: pd.NonNegativeFloat = pd.Field(
+    xi: NonNegativeFloat = Field(
         default=0.0,
         title="Xi",
         description="The Xi coefficient used by the ``ei`` and ``poi`` acquisition functions. More detail available in the `package docs `_.",
    )
 
-    def _get_run_count(self, parameters: Optional[list] = None) -> int:
+    def _get_run_count(self, parameters: Optional[list[ParameterType]] = None) -> int:
         """Return the maximum number of runs for the method based on current method arguments."""
         return self.initial_iter + self.n_iter
 
-    def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]:
+    def _run(
+        self,
+        parameters: tuple[ParameterType, ...],
+        run_fn: RunFunction,
+        console: Optional[Console],
+    ) -> RunResult:
         """Defines the Bayesian optimization search algorithm for the method.
 
         Uses the ``bayes_opt`` package to carry out a Bayesian optimization.
         Utilizes the ``.suggest`` and ``.register`` methods instead of
@@ -251,8 +295,8 @@ def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console)
             ) from None
 
         # Identify non-numeric params and define boundaries for Bay-opt
-        param_converter = {}
-        boundary_dict = {}
+        param_converter: dict[str, dict[int, Any]] = {}
+        boundary_dict: dict[str, tuple[float, float]] = {}
         for param in parameters:
             if isinstance(param, ParameterAny):
                 param_converter[param.name] = self.any_to_int_param(param)
@@ -345,98 +389,97 @@ class MethodGenAlg(MethodOptimize, ABC):
 
     >>> method = tdd.MethodGenAlg(solutions_per_pop=2, n_generations=1, n_parents_mating=2)
     """
 
-    # Args for the user
-    solutions_per_pop: pd.PositiveInt = pd.Field(
-        ...,
+    solutions_per_pop: PositiveInt = Field(
         title="Solutions per Population",
         description="The number of solutions to be generated for each population.",
     )
 
-    n_generations: pd.PositiveInt = pd.Field(
-        ...,
+    n_generations: PositiveInt = Field(
         title="Number of Generations",
         description="The maximum number of generations to run the genetic algorithm.",
     )
 
-    n_parents_mating: pd.PositiveInt = pd.Field(
-        ...,
+    n_parents_mating: PositiveInt = Field(
         title="Number of Parents Mating",
         description="The number of solutions to be selected as parents for the next generation. Crossovers of these parents will produce the next population.",
     )
 
-    stop_criteria_type: Literal["reach", "saturate"] = pd.Field(
+    stop_criteria_type: Optional[Literal["reach", "saturate"]] = Field(
         default=None,
         title="Early Stopping Criteria Type",
         description="Define the early stopping criteria. Supported words are 'reach' or 'saturate'. 'reach' stops at a desired fitness, 'saturate' stops when the fitness stops improving. Must set ``stop_criteria_number``. See the `PyGAD docs `_ for more details.",
     )
 
-    stop_criteria_number: pd.PositiveFloat = pd.Field(
+    stop_criteria_number: Optional[PositiveFloat] = Field(
         default=None,
         title="Early Stopping Criteria Number",
         description="Must set ``stop_criteria_type``. If type is 'reach' the number is the acceptable fitness value to stop the optimization. If type is 'saturate' the number is the number of generations where the fitness doesn't improve before optimization is stopped. See the `PyGAD docs `_ for more details.",
     )
 
-    parent_selection_type: Literal["sss", "rws", "sus", "rank", "random", "tournament"] = pd.Field(
+    parent_selection_type: Literal["sss", "rws", "sus", "rank", "random", "tournament"] = Field(
         default="sss",
         title="Parent Selection Type",
         description="The style of parent selector. See the `PyGAD docs `_ for more details.",
    )
 
-    keep_parents: Union[pd.PositiveInt, Literal[-1, 0]] = pd.Field(
+    keep_parents: Union[PositiveInt, Literal[-1, 0]] = Field(
         default=-1,
         title="Keep Parents",
         description="The number of parents to keep unaltered in the population of the next generation. Default value of -1 keeps all current parents for the next generation. This value is overwritten if ``keep_elitism`` is > 0. See the `PyGAD docs `_ for more details.",
     )
 
-    keep_elitism: Union[pd.PositiveInt, Literal[0]] = pd.Field(
+    keep_elitism: Union[PositiveInt, Literal[0]] = Field(
         default=1,
         title="Keep Elitism",
         description="The number of top solutions to be included in the population of the next generation. Overwrites ``keep_parents`` if value is > 0. See the `PyGAD docs `_ for more details.",
     )
 
-    crossover_type: Union[None, Literal["single_point", "two_points", "uniform", "scattered"]] = (
-        pd.Field(
-            default="single_point",
-            title="Crossover Type",
-            description="The style of crossover operation. See the `PyGAD docs `_ for more details.",
-        )
+    crossover_type: Optional[Literal["single_point", "two_points", "uniform", "scattered"]] = Field(
+        default="single_point",
+        title="Crossover Type",
+        description="The style of crossover operation. See the `PyGAD docs `_ for more details.",
    )
 
-    crossover_prob: pd.confloat(ge=0, le=1) = pd.Field(
+    crossover_prob: float = Field(
        default=0.8,
        title="Crossover Probability",
        description="The probability of performing a crossover between two parents.",
+        ge=0,
+        le=1,
    )
 
-    mutation_type: Union[None, Literal["random", "swap", "inversion", "scramble", "adaptive"]] = (
-        pd.Field(
-            default="random",
-            title="Mutation Type",
-            description="The style of gene mutation. See the `PyGAD docs `_ for more details.",
-        )
+    mutation_type: Optional[Literal["random", "swap", "inversion", "scramble", "adaptive"]] = Field(
+        default="random",
+        title="Mutation Type",
+        description="The style of gene mutation. See the `PyGAD docs `_ for more details.",
    )
 
-    mutation_prob: Union[pd.confloat(ge=0, le=1), Literal[None]] = pd.Field(
+    mutation_prob: Optional[float] = Field(
        default=0.2,
        title="Mutation Probability",
        description="The probability of mutating a gene.",
+        ge=0,
+        le=1,
    )
 
-    save_solution: pd.StrictBool = pd.Field(
+    save_solution: bool = Field(
        default=False,
        title="Save Solutions",
        description="Save all solutions from all generations within a numpy array. Can be accessed from the optimizer object stored in the Result. May cause memory issues with large populations or many generations. See the `PyGAD docs `_ for more details.",
    )
 
-    # TODO: See if anyone is interested in having the full suite of PyGAD options - there's a lot!
-
-    def _get_run_count(self, parameters: Optional[list] = None) -> int:
+    def _get_run_count(self, parameters: Optional[list[ParameterType]] = None) -> int:
         """Return the maximum number of runs for the method based on current method arguments."""
 
         # +1 to generations as pygad creates an initial population which is effectively "Generation 0"
         run_count = self.solutions_per_pop * (self.n_generations + 1)
 
         return run_count
 
-    def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]:
+    def _run(
+        self,
+        parameters: tuple[ParameterType, ...],
+        run_fn: RunFunction,
+        console: Optional[Console],
+    ) -> RunResult:
         """Defines the genetic algorithm for the method.
 
         Uses the ``pygad`` package to carry out a genetic algorithm optimization.
         Additional development has ensured that
@@ -454,15 +497,15 @@ def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console)
         param_keys = [param.name for param in parameters]
 
         # Store parameters and fitness
-        store_parameters = []
-        store_fitness = []
-        store_aux = []
-        previous_solutions = {}
+        store_parameters: ArgsList = []
+        store_fitness: list[NDArray[np.floating]] = []
+        store_aux: list[Any] = []
+        previous_solutions: dict[str, tuple[float, Any]] = {}
 
         # Set gene_spaces to keep GA within ranges
-        param_converter = {}
-        gene_spaces = []
-        gene_types = []
+        param_converter: dict[str, dict[int, Any]] = {}
+        gene_spaces: list[Any] = []
+        gene_types: list[type[Any]] = []
         for param in parameters:
             if isinstance(param, ParameterFloat):
                 gene_spaces.append({"low": param.span[0], "high": param.span[1]})
@@ -479,9 +522,9 @@ def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console)
                 gene_spaces.append(range(len(param.allowed_values)))
                 gene_types.append(int)
 
-        def capture_aux(sol_dict_list: list[dict]) -> None:
+        def capture_aux(sol_dict_list: ArgsList) -> None:
             """Store the aux data by pulling from previous_solutions."""
-            aux_out = []
+            aux_out: list[Any] = []
             for sol in sol_dict_list:
                 composite_key = str(sol.keys()) + str(sol.values())
                 _, aux_data = previous_solutions[composite_key]
@@ -490,16 +533,20 @@ def capture_aux(sol_dict_list: list[dict]) -> None:
             self._flatten_and_append(aux_out, store_aux)
 
         # Create fitness function combining pre and post fn with the tidy3d call
-        def fitness_function(ga_instance: pygad.GA, solution: np.array, solution_idx) -> dict:
+        def fitness_function(
+            ga_instance: pygad.GA,
+            solution: NDArray[np.floating],
+            solution_idx: int,
+        ) -> list[float]:
             """Fitness function for GA. Format of inputs cannot be changed."""
             # Break solution down to list of dict
             sol_dict_list = self.sol_array_to_dict(solution, param_keys, param_converter)
 
             # Check if solution already exists
             # Have to update the solutions as need to pass to run_fn together to be batched
-            known_sol = {}
-            unknown_sol = []
-            unknown_keys = []
+            known_sol: dict[int, str] = {}
+            unknown_sol: ArgsList = []
+            unknown_keys: list[str] = []
 
             for sol_idx, sol in enumerate(sol_dict_list):
                 composite_key = str(sol.keys()) + str(sol.values())
@@ -565,9 +612,11 @@ def on_generation(ga_instance: pygad.GA) -> None:
         num_genes = len(parameters)
 
         # PyGAD doesn't store the initial population fitness - this captures parameters, fitness and aux data
-        init_state = []
+        init_state: list[str] = []
 
-        def capture_init_pop_fitness(ga_instance: pygad.GA, population_fitness) -> None:
+        def capture_init_pop_fitness(
+            ga_instance: pygad.GA, population_fitness: NDArray[np.floating]
+        ) -> None:
             """Store the initial population fitness which PyGAD otherwise ignores
 
             Has to be run ``on_fitness`` but contains a check so that it only runs on the first pass
@@ -633,59 +682,62 @@ class MethodParticleSwarm(MethodOptimize, ABC):
 
     >>> method = tdd.MethodParticleSwarm(n_particles=5, n_iter=3)
     """
 
-    n_particles: pd.PositiveInt = pd.Field(
-        ...,
+    n_particles: PositiveInt = Field(
         title="Number of Particles",
         description="The number of particles to be used in the swarm for the optimization.",
     )
 
-    n_iter: pd.PositiveInt = pd.Field(
-        ...,
+    n_iter: PositiveInt = Field(
         title="Number of Iterations",
         description="The maximum number of iterations to run the optimization.",
    )
 
-    cognitive_coeff: pd.PositiveFloat = pd.Field(
+    cognitive_coeff: PositiveFloat = Field(
        default=1.5,
        title="Cognitive Coefficient",
        
description="The cognitive parameter decides how attracted the particle is to its previous best position.", ) - social_coeff: pd.PositiveFloat = pd.Field( + social_coeff: PositiveFloat = Field( default=1.5, title="Social Coefficient", description="The social parameter decides how attracted the particle is to the global best position found by the swarm.", ) - weight: pd.PositiveFloat = pd.Field( + weight: PositiveFloat = Field( default=0.9, title="Weight", description="The weight or inertia of particles in the optimization.", ) - ftol: Union[pd.confloat(ge=0, le=1), Literal[-inf]] = pd.Field( + ftol: Union[Annotated[float, Field(ge=0, le=1)], Literal[-inf]] = Field( default=-inf, title="Relative Error for Convergence", description="Relative error in ``objective_func(best_solution)`` acceptable for convergence. See the `PySwarms docs `_ for details. Off by default.", ) - ftol_iter: pd.PositiveInt = pd.Field( + ftol_iter: PositiveInt = Field( default=1, title="Number of Iterations Before Convergence", description="Number of iterations over which the relative error in the objective_func is acceptable for convergence.", ) - init_pos: np.ndarray = pd.Field( + init_pos: Optional[np.ndarray] = Field( default=None, title="Initial Swarm Positions", description="Set the initial positions of the swarm using a numpy array of appropriate size.", ) - def _get_run_count(self, parameters: Optional[list] = None) -> int: + def _get_run_count(self, parameters: Optional[list[ParameterType]] = None) -> int: """Return the maximum number of runs for the method based on current method arguments.""" return self.n_particles * self.n_iter - def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) -> tuple[Any]: + def _run( + self, + parameters: tuple[ParameterType, ...], + run_fn: RunFunction, + console: Optional[Console], + ) -> RunResult: """Defines the particle search optimization algorithm for the method. Uses the ``pyswarms`` package to carry out a particle search optimization. @@ -705,14 +757,14 @@ def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) # Variable assignment here so it is available to the fitness function param_keys = [param.name for param in parameters] - store_parameters = [] - store_fitness = [] - store_aux = [] + store_parameters: list[ArgsList] = [] + store_fitness: list[list[float]] = [] + store_aux: list[Any] = [] # Build bounds and conversion dict for ParameterAny inputs - param_converter = {} - min_bound = [] - max_bound = [] + param_converter: dict[str, dict[int, Any]] = {} + min_bound: list[float] = [] + max_bound: list[float] = [] for param in parameters: if isinstance(param, ParameterAny): param_converter[param.name] = self.any_to_int_param(param) @@ -722,9 +774,9 @@ def _run(self, parameters: tuple[ParameterType, ...], run_fn: Callable, console) min_bound.append(param.span[0]) max_bound.append(param.span[1]) - bounds = (min_bound, max_bound) + bounds: tuple[list[float], list[float]] = (min_bound, max_bound) - def fitness_function(solution: np.array) -> np.array: + def fitness_function(solution: NDArray[np.floating]) -> NDArray[np.floating]: """Fitness function for PSO. 
Input format cannot be changed""" # Correct solutions that should be ints sol_dict_list = self.sol_array_to_dict(solution, param_keys, param_converter) @@ -764,8 +816,8 @@ def fitness_function(solution: np.array) -> np.array: ) # Collapse stores into fn_args and results lists - fn_args = [val for sublist in store_parameters for val in sublist] - results = [val for sublist in store_fitness for val in sublist] + fn_args: ArgsList = [val for sublist in store_parameters for val in sublist] + results: list[float] = [val for sublist in store_fitness for val in sublist] return fn_args, results, store_aux, optimizer @@ -773,13 +825,12 @@ def fitness_function(solution: np.array) -> np.array: class AbstractMethodRandom(MethodSample, ABC): """Select parameters with an object with a ``random`` method.""" - num_points: pd.PositiveInt = pd.Field( - ..., + num_points: PositiveInt = Field( title="Number of Sampling Points", description="The number of points to be generated for sampling.", ) - seed: pd.PositiveInt = pd.Field( + seed: Optional[PositiveInt] = Field( default=None, title="Seed", description="Sets the seed used by the optimizers to set constant random number generation.", @@ -789,11 +840,11 @@ class AbstractMethodRandom(MethodSample, ABC): def _get_sampler(self, parameters: tuple[ParameterType, ...]) -> qmc_type.QMCEngine: """Sampler for this ``Method`` class. If ``None``, sets a default.""" - def _get_run_count(self, parameters: Optional[list] = None) -> int: + def _get_run_count(self, parameters: Optional[list[ParameterType]] = None) -> int: """Return the maximum number of runs for the method based on current method arguments.""" return self.num_points - def sample(self, parameters: tuple[ParameterType, ...], **kwargs: Any) -> list[dict[str, Any]]: + def sample(self, parameters: tuple[ParameterType, ...], **kwargs: Any) -> ArgsList: """Defines how the design parameters are sampled on grid.""" sampler = self._get_sampler(parameters) @@ -808,7 +859,7 @@ def sample(self, parameters: tuple[ParameterType, ...], **kwargs: Any) -> list[d # Get output list of kwargs for pre_fn keys = [param.name for param in parameters] - result = [{keys[j]: row[j] for j in range(len(keys))} for row in args_by_sample] + result: ArgsList = [{keys[j]: row[j] for j in range(len(keys))} for row in args_by_sample] return result @@ -825,7 +876,6 @@ class MethodMonteCarlo(AbstractMethodRandom): def _get_sampler(self, parameters: tuple[ParameterType, ...]) -> qmc_type.QMCEngine: """Sampler for this ``Method`` class.""" - from scipy.stats import qmc d = len(parameters) return qmc.LatinHypercube(d=d, seed=self.seed) diff --git a/tidy3d/plugins/design/parameter.py b/tidy3d/plugins/design/parameter.py index 4d96646c2e..d901517ecb 100644 --- a/tidy3d/plugins/design/parameter.py +++ b/tidy3d/plugins/design/parameter.py @@ -3,31 +3,34 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np -import pydantic.v1 as pd +from pydantic import Field, PositiveInt, field_validator from tidy3d.components.base import Tidy3dBaseModel +if TYPE_CHECKING: + from numpy.typing import NDArray + class Parameter(Tidy3dBaseModel, ABC): """Specification for a single variable / dimension in a design problem.""" - name: str = pd.Field( - ..., + name: str = Field( title="Name", description="Unique name for the variable. Used as a key into the parameter sweep results.", ) - values: tuple[Any, ...] 
= pd.Field( + values: Optional[tuple[Any, ...]] = Field( None, title="Custom Values", description="If specified, the parameter scan uses these values for grid search methods.", ) - @pd.validator("values", always=True) - def _values_unique(cls, val): + @field_validator("values") + @classmethod + def _values_unique(cls, val: Optional[tuple[Any, ...]]) -> Optional[tuple[Any, ...]]: """Supplied unique values.""" if (val is not None) and (len(set(val)) != len(val)): raise ValueError("Supplied 'values' were not unique.") @@ -36,7 +39,7 @@ def _values_unique(cls, val): def sample_grid(self) -> list[Any]: """Sample design variable on grid, checking for custom values.""" if self.values is not None: - return self.values + return list(self.values) return self._sample_grid() @abstractmethod @@ -48,7 +51,7 @@ def _sample_grid(self) -> list[Any]: """Sample this design variable on a grid.""" @abstractmethod - def select_from_01(self, pts_01: np.ndarray) -> list[Any]: + def select_from_01(self, pts_01: NDArray[np.floating]) -> list[Any]: """Select values given a set of points between 0, 1.""" @abstractmethod @@ -59,14 +62,16 @@ def sample_first(self) -> Any: class ParameterNumeric(Parameter, ABC): """A variable with numeric values.""" - span: tuple[Union[float, int], Union[float, int]] = pd.Field( - ..., + span: tuple[Union[float, int], Union[float, int]] = Field( title="Span", description="(min, max) range within which are allowed values for the variable. Is inclusive of max value.", ) - @pd.validator("span", always=True) - def _span_valid(cls, val): + @field_validator("span") + @classmethod + def _span_valid( + cls, val: tuple[Union[float, int], Union[float, int]] + ) -> tuple[Union[float, int], Union[float, int]]: """Span min <= span max.""" span_min, span_max = val if span_min > span_max: @@ -76,13 +81,13 @@ def _span_valid(cls, val): return val @property - def span_size(self): + def span_size(self) -> float: """Size of the span of this numeric variable.""" span_min = min(self.span) span_max = max(self.span) return span_max - span_min - def sample_first(self) -> tuple: + def sample_first(self) -> Union[float, int]: """Output the first allowed sample.""" return self.span[0] @@ -96,15 +101,18 @@ class ParameterFloat(ParameterNumeric): >>> var = tdd.ParameterFloat(name="x", num_points=10, span=(1, 2.5)) """ - num_points: pd.PositiveInt = pd.Field( + num_points: Optional[PositiveInt] = Field( None, title="Number of Points", description="Number of uniform sampling points for this variable. " "Only used for 'MethodGrid'. ", ) - @pd.validator("span", always=True) - def _span_is_float(cls, val): + @field_validator("span") + @classmethod + def _span_is_float( + cls, val: tuple[Union[float, int], Union[float, int]] + ) -> tuple[float, float]: """Make sure the span contains floats.""" low, high = val return float(low), float(high) @@ -121,7 +129,7 @@ def _sample_grid(self) -> list[float]: low, high = self.span return np.linspace(low, high, self.num_points).tolist() - def select_from_01(self, pts_01: np.ndarray) -> list[Any]: + def select_from_01(self, pts_01: NDArray[np.floating]) -> list[float]: """Select values given a set of points between 0, 1.""" return (min(self.span) + pts_01 * self.span_size).tolist() @@ -136,16 +144,16 @@ class ParameterInt(ParameterNumeric): >>> var = tdd.ParameterInt(name="x", span=(1, 4)) """ - span: tuple[int, int] = pd.Field( - ..., + span: tuple[int, int] = Field( title="Span", description="``(min, max)`` range within which are allowed values for the variable. 
" "The ``min`` value is inclusive and the ``max`` value is exclusive. In other words, " "a grid search over this variable will iterate over ``np.arange(min, max)``.", ) - @pd.validator("span", always=True) - def _span_is_int(cls, val): + @field_validator("span") + @classmethod + def _span_is_int(cls, val: tuple[Union[float, int], Union[float, int]]) -> tuple[int, int]: """Make sure the span contains ints.""" low, high = val return int(low), int(high) @@ -160,7 +168,7 @@ def _sample_grid(self) -> list[float]: low, high = self.span return np.arange(low, high).tolist() - def select_from_01(self, pts_01: np.ndarray) -> list[Any]: + def select_from_01(self, pts_01: NDArray[np.floating]) -> list[int]: """Select values given a set of points between 0, 1.""" pts_continuous = min(self.span) + pts_01 * self.span_size return np.floor(pts_continuous).astype(int).tolist() @@ -175,21 +183,22 @@ class ParameterAny(Parameter): >>> var = tdd.ParameterAny(name="x", allowed_values=("a", "b", "c")) """ - allowed_values: tuple[Any, ...] = pd.Field( - ..., + allowed_values: tuple[Any, ...] = Field( title="Allowed Values", description="The discrete set of values that this variable can take on.", ) - @pd.validator("allowed_values", always=True) - def _given_any_allowed_values(cls, val): + @field_validator("allowed_values") + @classmethod + def _given_any_allowed_values(cls, val: tuple[Any, ...]) -> tuple[Any, ...]: """Need at least one allowed value.""" if not len(val): raise ValueError("Given empty tuple of allowed values. Must have at least one.") return val - @pd.validator("allowed_values", always=True) - def _no_duplicate_allowed_values(cls, val): + @field_validator("allowed_values") + @classmethod + def _no_duplicate_allowed_values(cls, val: tuple[Any, ...]) -> tuple[Any, ...]: """No duplicates in allowed_values.""" if len(val) != len(set(val)): raise ValueError("'allowed_values' has duplicate entries, must be unique.") @@ -203,7 +212,7 @@ def _sample_grid(self) -> list[Any]: """Sample this design variable uniformly, ie just take all allowed values.""" return list(self.allowed_values) - def select_from_01(self, pts_01: np.ndarray) -> list[Any]: + def select_from_01(self, pts_01: NDArray[np.floating]) -> list[Any]: """Select values given a set of points between 0, 1.""" pts_continuous = pts_01 * len(self.allowed_values) indices = np.floor(pts_continuous).astype(int) diff --git a/tidy3d/plugins/design/result.py b/tidy3d/plugins/design/result.py index 1d69022627..87624c323c 100644 --- a/tidy3d/plugins/design/result.py +++ b/tidy3d/plugins/design/result.py @@ -2,14 +2,19 @@ from __future__ import annotations -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import numpy as np import pandas -import pydantic.v1 as pd +from pydantic import Field, model_validator from tidy3d.components.base import Tidy3dBaseModel, cached_property +if TYPE_CHECKING: + from collections.abc import Iterator + + from tidy3d.compat import Self + # NOTE: Coords are args_dict from method and design. This may be changed in future to unify naming @@ -30,39 +35,39 @@ class Result(Tidy3dBaseModel): >>> # df.head() # print out first 5 elements of data """ - dims: tuple[str, ...] = pd.Field( + dims: tuple[str, ...] = Field( (), title="Dimensions", description="The dimensions of the design variables (indexed by 'name').", ) - values: tuple[Any, ...] = pd.Field( + values: tuple[Any, ...] 
= Field(
         (),
         title="Values",
         description="The return values from the design problem function.",
     )
 
-    coords: tuple[tuple[Any, ...], ...] = pd.Field(
+    coords: tuple[tuple[Any, ...], ...] = Field(
         (),
         title="Coordinates",
         description="The values of the coordinates corresponding to each of the dims."
         "Note: shaped (D, N) where D is the ``len(dims)`` and N is the ``len(values)``",
     )
 
-    output_names: tuple[str, ...] = pd.Field(
+    output_names: Optional[tuple[str, ...]] = Field(
         None,
         title="Output Names",
         description="Names for each of the outputs stored in ``values``. If not specified, default "
         "values are assigned.",
     )
 
-    fn_source: str = pd.Field(
+    fn_source: Optional[str] = Field(
         None,
         title="Function Source Code",
         description="Source code for the function evaluated in the parameter sweep.",
     )
 
-    task_names: list = pd.Field(
+    task_names: Optional[list] = Field(
         None,
         title="Task Names",
         description="Task name of every simulation run during ``DesignSpace.run``. Only available if "
@@ -70,7 +75,7 @@ class Result(Tidy3dBaseModel):
         "Stored in the same format as the output of fn_pre i.e. if pre outputs a dict, this output is a dict with the keys preserved.",
     )
 
-    task_paths: list = pd.Field(
+    task_paths: Optional[list] = Field(
         None,
         title="Task Paths",
         description="Task paths of every simulation run during ``DesignSpace.run``. Useful for loading downloaded ``SimulationData`` hdf5 files."
@@ -78,50 +83,48 @@ class Result(Tidy3dBaseModel):
         "Stored in the same format as the output of fn_pre i.e. if pre outputs a dict, this output is a dict with the keys preserved.",
     )
 
-    aux_values: tuple[Any, ...] = pd.Field(
+    aux_values: Optional[tuple[Any, ...]] = Field(
         None,
         title="Auxiliary values output from the user function",
         description="The auxiliary return values from the design problem function. This is the collection of objects returned "
         "alongside the float value used for the optimization. These weren't used to inform the optimizer, if one was used.",
     )
 
-    optimizer: Any = pd.Field(
+    optimizer: Any = Field(
         None,
         title="Optimizer object",
         description="The optimizer returned at the end of an optimizer run. Can be used to analyze and plot how the optimization progressed. "
         "Attributes depend on the optimizer used; a full explanation of the optimizer can be found on associated library doc pages. Will be ``None`` for sampling based methods.",
     )
 
-    @pd.validator("coords", always=True)
-    def _coords_and_dims_shape(cls, val, values):
+    @model_validator(mode="after")
+    def _coords_and_dims_shape(self) -> Self:
         """Make sure coords and dims have same size."""
-        dims = values.get("dims")
-
-        if val is None or dims is None:
-            return None
+        if self.coords is None or self.dims is None:
+            return self
 
-        num_dims = len(dims)
-        for i, _val in enumerate(val):
+        num_dims = len(self.dims)
+        for i, _val in enumerate(self.coords):
             if len(_val) != num_dims:
                 raise ValueError(
                     f"Number of 'coords' at index '{i}' ({len(_val)}) "
                     f"doesn't match the number of 'dims' ({num_dims})."
) - return val + return self - @pd.validator("coords", always=True) - def _coords_and_values_shape(cls, val, values): + @model_validator(mode="after") + def _coords_and_values_shape(self) -> Self: """Make sure coords and values have same length.""" - _values = values.get("values") + _values = self.values - if val is None or _values is None: - return None + if self.coords is None or _values is None: + return self num_values = len(_values) - num_coords = len(val) + num_coords = len(self.coords) if num_values != num_coords: raise ValueError( @@ -129,9 +132,9 @@ def _coords_and_values_shape(cls, val, values): f"Have {num_coords} and {num_values} elements, respectively." ) - return val + return self - def value_as_dict(self, value) -> dict[str, Any]: + def value_as_dict(self, value: Any) -> dict[str, Any]: """How to convert an output function value as a dictionary.""" if isinstance(value, dict): return value @@ -141,7 +144,7 @@ def value_as_dict(self, value) -> dict[str, Any]: return dict(zip(keys, value)) @staticmethod - def default_value_keys(value) -> tuple[str, ...]: + def default_value_keys(value: Any) -> tuple[str, ...]: """The default keys for a given value.""" # if a dict already, just use the existing keys as labels @@ -155,7 +158,7 @@ def default_value_keys(value) -> tuple[str, ...]: # if simply single value (float, int, bool, etc) just label "output" return ("output",) - def items(self) -> tuple[dict, Any]: + def items(self) -> Iterator[tuple[dict[str, Any], Any]]: """Iterate through coordinates (args) and values (outputs) one by one.""" for coord_tuple, val in zip(self.coords, self.values): @@ -163,7 +166,7 @@ def items(self) -> tuple[dict, Any]: yield coord_dict, val @cached_property - def data(self) -> dict[tuple, Any]: + def data(self) -> dict[tuple[Any, ...], Any]: """Dict mapping tuple of fn args to their value.""" result = {} @@ -173,7 +176,7 @@ def data(self) -> dict[tuple, Any]: return result - def get_value(self, coords: tuple) -> Any: + def get_value(self, coords: tuple[Any, ...]) -> Any: """Get a data element indexing by function arg tuple.""" return self.data[coords] @@ -253,7 +256,7 @@ def from_dataframe(cls, df: pandas.DataFrame, dims: Optional[list[str]] = None) ---------- df : ``pandas.DataFrame`` ```DataFrame`` object to load into a :class:`.Result`. - dims : List[str] = None + dims : list[str] = None Set of dimensions corresponding to the function arguments. Not required if this dataframe was generated directly from a :class:`.Result` without modification. In that case, it contains the dims in its ``.attrs`` metadata. @@ -321,7 +324,7 @@ def combine(self, other: Result) -> Result: if self.dims != other.dims: raise ValueError("Can't combine results, dimensions don't match.") - def combine_tuples(tuple1: tuple, tuple2: tuple): + def combine_tuples(tuple1: tuple[Any, ...], tuple2: tuple[Any, ...]) -> list[Any] | None: """Combine two tuples together if not None.""" if tuple1 is None and tuple2 is None: return None @@ -342,7 +345,7 @@ def combine_tuples(tuple1: tuple, tuple2: tuple): task_names=task_names, ) - def __add__(self, other): + def __add__(self, other: Result) -> Result: """Special syntax for design_result1 + design_result2.""" return self.combine(other) @@ -358,7 +361,7 @@ def delete(self, fn_args: dict[str, float]) -> Result: Parameters ---------- - fn_args : Dict[str, float] + fn_args : dict[str, float] ``dict`` containing the function arguments one wishes to delete. 
Returns @@ -397,7 +400,7 @@ def add(self, fn_args: dict[str, float], value: Any) -> Result: Parameters ---------- - fn_args : Dict[str, float] + fn_args : dict[str, float] ``dict`` containing the function arguments one wishes to add. value : Any Data point value corresponding to these arguments. @@ -428,12 +431,12 @@ def add(self, fn_args: dict[str, float], value: Any) -> Result: return self.updated_copy(values=new_values, coords=new_coords) - def __len__(self): + def __len__(self) -> int: """Implement len function to return the number of items in the result.""" return len(self.coords) - def __getitem__(self, data_index): + def __getitem__(self, data_index: int | slice) -> tuple[np.ndarray, np.ndarray]: """Implement the accessor function to index into the coordinates and values of the result.""" features = self.coords[data_index] diff --git a/tidy3d/plugins/dispersion/fit.py b/tidy3d/plugins/dispersion/fit.py index 56a5aaef78..1b3c15877a 100644 --- a/tidy3d/plugins/dispersion/fit.py +++ b/tidy3d/plugins/dispersion/fit.py @@ -4,42 +4,49 @@ import codecs import csv -from os import PathLike -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import numpy as np import requests -from pydantic.v1 import Field, validator +import scipy.optimize as opt +from pydantic import Field, field_validator, model_validator from rich.progress import Progress -from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.medium import AbstractMedium, PoleResidue -from tidy3d.components.types import ArrayFloat1D, Ax +from tidy3d.components.types import ArrayFloat1D from tidy3d.components.viz import add_ax_if_none from tidy3d.config import config from tidy3d.constants import C_0, HBAR, MICROMETER from tidy3d.exceptions import SetupError, ValidationError, WebError from tidy3d.log import get_logging_console, log +if TYPE_CHECKING: + from collections.abc import Sequence + from os import PathLike + + from numpy.typing import NDArray + + from tidy3d.compat import Self + from tidy3d.components.types import Ax + class DispersionFitter(Tidy3dBaseModel): """Tool for fitting refractive index data to get a dispersive medium described by :class:`.PoleResidue` model.""" wvl_um: ArrayFloat1D = Field( - ..., title="Wavelength data", description="Wavelength data in micrometers.", units=MICROMETER, ) n_data: ArrayFloat1D = Field( - ..., title="Index of refraction data", description="Real part of the complex index of refraction.", ) - k_data: ArrayFloat1D = Field( + k_data: Optional[ArrayFloat1D] = Field( None, title="Extinction coefficient data", description="Imaginary part of the complex index of refraction.", @@ -53,32 +60,29 @@ class DispersionFitter(Tidy3dBaseModel): units=MICROMETER, ) - @validator("wvl_um", always=True) - def _setup_wvl(cls, val): + @field_validator("wvl_um") + @classmethod + def _setup_wvl(cls, val: ArrayFloat1D) -> ArrayFloat1D: """Convert wvl_um to a numpy array.""" if val.size == 0: raise ValidationError("Wavelength data cannot be empty.") return val - @validator("n_data", always=True) - @skip_if_fields_missing(["wvl_um"]) - def _ndata_length_match_wvl(cls, val, values): + @model_validator(mode="after") + def _ndata_length_match_wvl(self) -> Self: """Validate n_data""" - - if val.shape != values["wvl_um"].shape: + if self.n_data.shape != self.wvl_um.shape: raise ValidationError("The length of 'n_data' doesn't match 'wvl_um'.") - return val + return 
self - @validator("k_data", always=True) - @skip_if_fields_missing(["wvl_um"]) - def _kdata_setup_and_length_match(cls, val, values): + @model_validator(mode="after") + def _kdata_setup_and_length_match(self) -> Self: """Validate the length of k_data, or setup k if it's None.""" - - if val is None: - return np.zeros_like(values["wvl_um"]) - if val.shape != values["wvl_um"].shape: + if self.k_data is None: + object.__setattr__(self, "k_data", np.zeros_like(self.wvl_um)) + if self.k_data.shape != self.wvl_um.shape: raise ValidationError("The length of 'k_data' doesn't match 'wvl_um'.") - return val + return self @cached_property def data_in_range(self) -> tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D]: @@ -86,7 +90,7 @@ def data_in_range(self) -> tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D]: Returns ------- - Tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] + tuple[ArrayFloat1D, ArrayFloat1D, ArrayFloat1D] Filtered wvl_um, n_data, k_data """ @@ -133,7 +137,7 @@ def freqs(self) -> tuple[float, ...]: Returns ------- - Tuple[float, ...] + tuple[float, ...] Frequency array converted from filtered input wavelength data """ @@ -146,14 +150,16 @@ def frequency_range(self) -> tuple[float, float]: Returns ------- - Tuple[float, float] + tuple[float, float] The minimal frequency and the maximal frequency """ return self.freqs.min(), self.freqs.max() @staticmethod - def _unpack_coeffs(coeffs): + def _unpack_coeffs( + coeffs: NDArray[np.floating], + ) -> tuple[NDArray[np.complexfloating], NDArray[np.complexfloating]]: """Unpack coefficient vector into complex pole parameters. Parameters @@ -163,7 +169,7 @@ def _unpack_coeffs(coeffs): Returns ------- - Tuple[np.ndarray[complex], np.ndarray[complex]] + tuple[np.ndarray[complex], np.ndarray[complex]] "a" and "c" poles for the PoleResidue model. """ if len(coeffs) % 4 != 0: @@ -179,7 +185,9 @@ def _unpack_coeffs(coeffs): return poles_a, poles_c @staticmethod - def _pack_coeffs(pole_a, pole_c): + def _pack_coeffs( + pole_a: NDArray[np.complexfloating], pole_c: NDArray[np.complexfloating] + ) -> NDArray[np.floating]: """Pack complex a and c pole parameters into coefficient array. Parameters @@ -198,7 +206,7 @@ def _pack_coeffs(pole_a, pole_c): return stacked_coeffs.flatten() @staticmethod - def _coeffs_to_poles(coeffs): + def _coeffs_to_poles(coeffs: NDArray[np.floating]) -> list[tuple[complex, complex]]: """Convert model coefficients to poles. Parameters @@ -208,7 +216,7 @@ def _coeffs_to_poles(coeffs): Returns ------- - List[Tuple[complex, complex]] + list[tuple[complex, complex]] List of complex poles (a, c) """ coeffs_scaled = coeffs / HBAR @@ -216,12 +224,12 @@ def _coeffs_to_poles(coeffs): return list(zip(poles_a, poles_c)) @staticmethod - def _poles_to_coeffs(poles): + def _poles_to_coeffs(poles: Sequence[tuple[complex, complex]]) -> NDArray[np.floating]: """Convert poles to model coefficients. Parameters ---------- - poles : List[Tuple[complex, complex]] + poles : list[tuple[complex, complex]] List of complex poles (a, c) Returns @@ -234,7 +242,7 @@ def _poles_to_coeffs(poles): return coeffs * HBAR @staticmethod - def _eV_to_Hz(f_eV: float): + def _eV_to_Hz(f_eV: float) -> float: """Convert frequency in unit of eV to Hz. Parameters @@ -245,7 +253,7 @@ def _eV_to_Hz(f_eV: float): return f_eV / (HBAR * 2 * np.pi) @staticmethod - def _Hz_to_eV(f_Hz: float): + def _Hz_to_eV(f_Hz: float) -> float: """Convert frequency in unit of Hz to eV. 
Parameters
@@ -278,7 +286,7 @@ def fit(
 
         Returns
         -------
-        Tuple[:class:`.PoleResidue`, float]
+        tuple[:class:`.PoleResidue`, float]
             Best results of multiple fits: (dispersive medium, RMS error).
         """
 
@@ -326,7 +334,7 @@ def fit(
         log.info("Returning best fit with RMS error %.3g", best_rms)
         return best_medium, best_rms
 
-    def _make_medium(self, coeffs):
+    def _make_medium(self, coeffs: NDArray[np.floating]) -> PoleResidue:
         """Return medium from coeffs from optimizer.
 
         Parameters
@@ -358,13 +366,14 @@ def _fit_single(
 
         Returns
         -------
-        Tuple[:class:`.PoleResidue`, float]
+        tuple[:class:`.PoleResidue`, float]
             Results of single fit: (dispersive medium, RMS error).
         """
-        import scipy.optimize as opt  # NOTE: Not used
 
-        def constraint(coeffs, _grad=None):
+        def constraint(
+            coeffs: NDArray[np.floating], _grad: NDArray[np.floating] | None = None
+        ) -> float:
             """Evaluate the nonlinear stability criterion of
             Hongjin Choi, Jae-Woo Baek, and Kyung-Young Jung,
             "Comprehensive Study on Numerical Aspects of Modified
             Lorentz Model Based Dispersive FDTD Formulations," IEEE TAP 2019.
@@ -391,7 +400,9 @@ def constraint(coeffs, _grad=None):
             res[res >= 0] = 0
             return np.sum(res)
 
-        def objective(coeffs, _grad=None):
+        def objective(
+            coeffs: NDArray[np.floating], _grad: NDArray[np.floating] | None = None
+        ) -> float:
             """Objective function for fit
 
             Parameters
@@ -765,7 +776,7 @@ def from_complex_permittivity(
             Real parts of relative permittivity data
         eps_imag : Optional[ArrayFloat1D]
             Imaginary parts of relative permittivity data; `None` for lossless medium.
-        wvg_range : Tuple[Optional[float], Optional[float]]
+        wvg_range : tuple[Optional[float], Optional[float]]
            Wavelength range [wvl_min,wvl_max] for fitting.
 
         Returns
@@ -797,7 +808,7 @@ def from_loss_tangent(
             Real parts of relative permittivity data
         loss_tangent : Optional[ArrayFloat1D]
             Loss tangent data, defined as the ratio of imaginary and real parts of permittivity.
-        wvl_range : Tuple[Optional[float], Optional[float]]
+        wvl_range : tuple[Optional[float], Optional[float]]
            Wavelength range [wvl_min,wvl_max] for fitting.
 
         Returns
diff --git a/tidy3d/plugins/dispersion/fit_fast.py b/tidy3d/plugins/dispersion/fit_fast.py
index 356dbde0a9..7ee9e1b986 100644
--- a/tidy3d/plugins/dispersion/fit_fast.py
+++ b/tidy3d/plugins/dispersion/fit_fast.py
@@ -2,10 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import TYPE_CHECKING
 
 import numpy as np
-from pydantic.v1 import NonNegativeFloat, PositiveInt
 
 from tidy3d.components.dispersion_fitter import (
     AdvancedFastFitterParam,
@@ -17,6 +16,11 @@
 
 from .fit import DispersionFitter
 
+if TYPE_CHECKING:
+    from typing import Optional
+
+    from pydantic import NonNegativeFloat, PositiveInt
+
 # numerical tolerance for pole relocation for fast fitter
 TOL = 1e-8
 # numerical cutoff for passivity testing
@@ -92,7 +96,7 @@ def fit(
 
         Returns
         -------
-        Tuple[:class:`.PoleResidue`, float]
+        tuple[:class:`.PoleResidue`, float]
             Best fitting result: (dispersive medium, weighted RMS error).
         """
 
@@ -136,7 +140,7 @@ def constant_loss_tangent_model(
             Real part of permittivity
         loss_tangent : float
             Loss tangent.
-        frequency_range : Tuple[float, float]
+        frequency_range : tuple[float, float]
             Frequency range for the material to exhibit constant loss tangent response.
         max_num_poles : PositiveInt, optional
             Maximum number of poles in the model.
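Note on the ``TYPE_CHECKING`` block introduced in ``fit_fast.py`` above: because the module starts with ``from __future__ import annotations``, annotations are never evaluated at runtime, so names that appear only in annotations (here ``Optional``, ``NonNegativeFloat``, ``PositiveInt``) can be imported under the guard with no runtime import cost. A minimal, self-contained sketch of the pattern; the module and function below are illustrative only, not part of tidy3d:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers only; skipped entirely at runtime.
        import numpy as np
        from numpy.typing import NDArray


    def rms(values: NDArray[np.floating]) -> float:
        # PEP 563 keeps this annotation as a string, so NDArray and np need
        # not exist when the module is imported; the body only uses methods
        # of the argument itself.
        return float((values**2).mean() ** 0.5)

This only works for names used purely in annotations. Types that pydantic has to resolve while building a model (field types, for example) still need real runtime imports, which is presumably why ``Field`` and the validator decorators stay at module level throughout this diff.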
diff --git a/tidy3d/plugins/dispersion/web.py b/tidy3d/plugins/dispersion/web.py
index 69a2666c06..0aded833b4 100644
--- a/tidy3d/plugins/dispersion/web.py
+++ b/tidy3d/plugins/dispersion/web.py
@@ -4,13 +4,12 @@
 
 import ssl
 from enum import Enum
-from typing import Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
-import pydantic.v1 as pydantic
 import requests
-from pydantic.v1 import Field, NonNegativeFloat, PositiveFloat, PositiveInt, validator
+from pydantic import Field, NonNegativeFloat, PositiveFloat, PositiveInt, model_validator
 
-from tidy3d.components.base import Tidy3dBaseModel, skip_if_fields_missing
+from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.medium import PoleResidue
 from tidy3d.components.types import Undefined
 from tidy3d.config import config
@@ -21,6 +20,9 @@
 
 from .fit import DispersionFitter
 
+if TYPE_CHECKING:
+    from tidy3d.compat import Self
+
 BOUND_MAX_FACTOR = 10
 
 URL_ENV = {
@@ -40,7 +42,7 @@ class ExceptionCodes(Enum):
 class AdvancedFitterParam(Tidy3dBaseModel):
     """Advanced fitter parameters"""
 
-    bound_amp: NonNegativeFloat = Field(
+    bound_amp: Optional[NonNegativeFloat] = Field(
         None,
         title="Upper bound of oscillator strength",
         description="Upper bound of real and imaginary part of oscillator "
@@ -48,7 +50,7 @@ class AdvancedFitterParam(Tidy3dBaseModel):
         "automatic setup based on the frequency range of interest).",
         units=HERTZ,
     )
-    bound_f: NonNegativeFloat = Field(
+    bound_f: Optional[NonNegativeFloat] = Field(
         None,
         title="Upper bound of pole frequency",
         description="Upper bound of real and imaginary part of ``a`` that corresponds to pole "
@@ -96,38 +98,37 @@ class AdvancedFitterParam(Tidy3dBaseModel):
         lt=2**32,
     )
 
-    @validator("bound_f_lower", always=True)
-    @skip_if_fields_missing(["bound_f"])
-    def _validate_lower_frequency_bound(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_lower_frequency_bound(self) -> Self:
         """bound_f_lower cannot be larger than bound_f."""
-        if values["bound_f"] is not None and val > values["bound_f"]:
+        if self.bound_f is not None and self.bound_f_lower > self.bound_f:
             raise SetupError(
                 "The upper bound 'bound_f' cannot be smaller than the lower bound 'bound_f_lower'."
             )
-        return val
+        return self
 
 
 class FitterData(AdvancedFitterParam):
     """Data class for request body of Fitter where dispersion data is input through tuple."""
 
     wvl_um: tuple[float, ...] = Field(
-        ...,
         title="Wavelengths",
         description="A set of wavelengths for dispersion data.",
         units=MICROMETER,
     )
     n_data: tuple[float, ...] = Field(
-        ...,
         title="Index of refraction",
         description="Real part of the complex index of refraction at each wavelength.",
     )
-    k_data: tuple[float, ...] = Field(
+    k_data: Optional[tuple[float, ...]] = Field(
         None,
         title="Extinction coefficient",
         description="Imaginary part of the complex index of refraction at each wavelength.",
     )
     num_poles: PositiveInt = Field(
-        1, title="Number of poles", description="Number of poles in model."
+ 1, + title="Number of poles", + description="Number of poles in model.", ) num_tries: PositiveInt = Field( 50, @@ -217,7 +218,7 @@ def create( return task @staticmethod - def _set_url(config_env: Literal["default", "dev", "prod", "local"] = "default"): + def _set_url(config_env: Literal["default", "dev", "prod", "local"] = "default") -> str: """Set the url of python web service Parameters @@ -262,7 +263,7 @@ def run(self) -> tuple[PoleResidue, float]: Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Best results of multiple fits: (dispersive medium, RMS error). """ @@ -272,7 +273,7 @@ def run(self) -> tuple[PoleResidue, float]: resp = requests.post( f"{url_server}/dispersion/fit", headers=headers, - data=self.json(), + data=self.model_dump_json(), verify=ssl_verify, ) @@ -299,7 +300,7 @@ def run(self) -> tuple[PoleResidue, float]: ) from e run_result = resp.json() - best_medium = PoleResidue.parse_raw(run_result["message"]) + best_medium = PoleResidue.model_validate_json(run_result["message"]) best_rms = float(run_result["rms"]) if best_rms < self.tolerance_rms: @@ -337,7 +338,7 @@ def run( Returns ------- - Tuple[:class:`.PoleResidue`, float] + tuple[:class:`.PoleResidue`, float] Best results of multiple fits: (dispersive medium, RMS error). """ if advanced_param is Undefined: @@ -349,13 +350,14 @@ def run( class StableDispersionFitter(DispersionFitter): """Deprecated.""" - @pydantic.root_validator() - def _deprecate_stable_fitter(cls, values): + @model_validator(mode="before") + @classmethod + def _deprecate_stable_fitter(cls, data: dict[str, Any]) -> dict[str, Any]: log.warning( "'StableDispersionFitter' has been deprecated. Use 'DispersionFitter' with " "'tidy3d.plugins.dispersion.web.run' to access the stable fitter from the web server." 
) - return values + return data def fit( self, diff --git a/tidy3d/plugins/expressions/__init__.py b/tidy3d/plugins/expressions/__init__.py index 616aac08e8..d9ebcaf329 100644 --- a/tidy3d/plugins/expressions/__init__.py +++ b/tidy3d/plugins/expressions/__init__.py @@ -3,19 +3,41 @@ from .base import Expression from .functions import Cos, Exp, Log, Log10, Sin, Sqrt, Tan from .metrics import ModeAmp, ModePower, generate_validation_data +from .operators import ( + Abs, + Add, + Divide, + FloorDivide, + MatMul, + Modulus, + Multiply, + Negate, + Power, + Subtract, +) from .variables import Constant, Variable __all__ = [ + "Abs", + "Add", "Constant", "Cos", + "Divide", "Exp", "Expression", + "FloorDivide", "Log", "Log10", + "MatMul", "ModeAmp", "ModePower", + "Modulus", + "Multiply", + "Negate", + "Power", "Sin", "Sqrt", + "Subtract", "Tan", "Variable", "generate_validation_data", @@ -43,4 +65,4 @@ _local_vars[name] = obj for cls in _model_classes: - cls.update_forward_refs(**_local_vars) + cls.model_rebuild(force=True) diff --git a/tidy3d/plugins/expressions/base.py b/tidy3d/plugins/expressions/base.py index a3c1ed4b3d..3817a5aa57 100644 --- a/tidy3d/plugins/expressions/base.py +++ b/tidy3d/plugins/expressions/base.py @@ -1,14 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Generator -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from tidy3d.components.base import Tidy3dBaseModel -from .types import ExpressionType, NumberOrExpression, NumberType - if TYPE_CHECKING: + from collections.abc import Generator + from typing import Optional + from .operators import ( Abs, Add, @@ -21,19 +21,18 @@ Power, Subtract, ) + from .types import ExpressionType, NumberOrExpression, NumberType class Expression(Tidy3dBaseModel, ABC): - """ - Base class for all expressions in the metrics module. + """Base class for all expressions in the metrics module. - This class serves as the foundation for all other components in the metrics module. - It provides common functionality and operator overloading for derived classes. + Notes + ----- + This class serves as the foundation for all other components in the metrics module. + It provides common functionality and operator overloading for derived classes. """ - class Config: - smart_union = True - @abstractmethod def evaluate(self, *args: Any, **kwargs: Any) -> NumberType: pass @@ -42,8 +41,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> NumberType: return self.evaluate(*args, **kwargs) @classmethod - def parse_obj(cls, obj: dict[str, Any]) -> ExpressionType: - return super()._parse_obj(obj) + def model_validate(cls, obj: dict[str, Any]) -> ExpressionType: + return super()._model_validate(obj) def filter( self, target_type: type[Expression], target_field: Optional[str] = None @@ -64,7 +63,7 @@ def filter( Instances of the specified type or field found in the expression. 
""" - def _find_instances(expr: Expression): + def _find_instances(expr: Expression) -> Generator[Any, None, None]: if isinstance(expr, target_type): if target_field: value = getattr(expr, target_field, None) @@ -72,8 +71,8 @@ def _find_instances(expr: Expression): yield value else: yield expr - for field in expr.__fields__.values(): - value = getattr(expr, field.name) + for name in type(expr).model_fields: + value = getattr(expr, name) if isinstance(value, Expression): yield from _find_instances(value) elif isinstance(value, list): @@ -92,7 +91,7 @@ def _to_expression(other: NumberOrExpression | dict[str, Any]) -> ExpressionType if isinstance(other, Expression): return other if isinstance(other, dict): - return Expression.parse_obj(other) + return Expression.model_validate(other) from .variables import Constant return Constant(other) diff --git a/tidy3d/plugins/expressions/functions.py b/tidy3d/plugins/expressions/functions.py index e5460fe028..4a571149ef 100644 --- a/tidy3d/plugins/expressions/functions.py +++ b/tidy3d/plugins/expressions/functions.py @@ -1,12 +1,15 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import autograd.numpy as anp -import pydantic.v1 as pd +from pydantic import Field, field_validator from .base import Expression -from .types import NumberOrExpression, NumberType +from .types import NumberOrExpression + +if TYPE_CHECKING: + from .types import ExpressionType, NumberType class Function(Expression): @@ -14,16 +17,16 @@ class Function(Expression): Base class for mathematical functions in expressions. """ - operand: NumberOrExpression = pd.Field( - ..., + operand: NumberOrExpression = Field( title="Operand", description="The operand for the function.", ) _format: str = "{func}({operand})" - @pd.validator("operand", pre=True, always=True) - def validate_operand(cls, v): + @field_validator("operand") + @classmethod + def validate_operand(cls, v: NumberOrExpression) -> ExpressionType: """ Validate and convert operand to an expression. """ @@ -42,7 +45,7 @@ def __init__(self, operand: NumberOrExpression, **kwargs: dict[str, Any]) -> Non """ super().__init__(operand=operand, **kwargs) - def __repr__(self): + def __repr__(self) -> str: """ Return a string representation of the function. """ diff --git a/tidy3d/plugins/expressions/metrics.py b/tidy3d/plugins/expressions/metrics.py index 277085bfe7..be5a2ca9dd 100644 --- a/tidy3d/plugins/expressions/metrics.py +++ b/tidy3d/plugins/expressions/metrics.py @@ -1,19 +1,23 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import autograd.numpy as np -import pydantic.v1 as pd import xarray as xr +from pydantic import Field, NonNegativeInt -from tidy3d.components.monitor import ModeMonitor from tidy3d.components.types import Direction, FreqArray -from .base import Expression -from .types import NumberType from .variables import Variable +if TYPE_CHECKING: + from tidy3d.compat import Self + from tidy3d.components.monitor import ModeMonitor + + from .base import Expression + from .types import NumberType + def generate_validation_data(expr: Expression) -> dict[str, xr.Dataset]: """Generate combined dummy simulation data for all metrics in the expression. @@ -28,17 +32,18 @@ def generate_validation_data(expr: Expression) -> dict[str, xr.Dataset]: dict[str, xr.Dataset] The combined validation data. 
""" - metrics = set(expr.filter(target_type=Metric)) + metrics = set(expr.filter(target_type=Metric)) # type: ignore[type-abstract] combined_data = {k: v for metric in metrics for k, v in metric._validation_data.items()} return combined_data class Metric(Variable, ABC): - """ - Base class for all metrics. + """Base class for all metrics. - To subclass Metric, you must implement an evaluate() method that takes a SimulationData - object and returns a scalar value. + Notes + ----- + To subclass Metric, you must implement an evaluate() method that takes a SimulationData + object and returns a scalar value. """ @property @@ -64,23 +69,22 @@ class ModeAmp(Metric): (abs(ModeAmp("monitor1")) ** 2) """ - monitor_name: str = pd.Field( - ..., + monitor_name: str = Field( title="Monitor Name", description="The name of the mode monitor. This needs to match the name of the monitor in the simulation.", ) - f: Optional[Union[float, FreqArray]] = pd.Field( # type: ignore + f: Optional[Union[float, FreqArray]] = Field( None, title="Frequency Array", description="The frequency array. If None, all frequencies in the monitor will be used.", alias="freqs", ) - direction: Direction = pd.Field( + direction: Direction = Field( "+", title="Direction", description="The direction of propagation of the mode.", ) - mode_index: pd.NonNegativeInt = pd.Field( + mode_index: NonNegativeInt = Field( 0, title="Mode Index", description="The index of the mode.", @@ -89,7 +93,7 @@ class ModeAmp(Metric): @classmethod def from_mode_monitor( cls, monitor: ModeMonitor, mode_index: int = 0, direction: Direction = "+" - ): + ) -> Self: return cls( monitor_name=monitor.name, f=monitor.freqs, mode_index=mode_index, direction=direction ) diff --git a/tidy3d/plugins/expressions/operators.py b/tidy3d/plugins/expressions/operators.py index e25004180e..3fbc44d83e 100644 --- a/tidy3d/plugins/expressions/operators.py +++ b/tidy3d/plugins/expressions/operators.py @@ -1,23 +1,26 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any -import pydantic.v1 as pd +from pydantic import Field, field_validator from .base import Expression -from .types import NumberOrExpression, NumberType +from .types import NumberOrExpression + +if TYPE_CHECKING: + from .types import ExpressionType, NumberType class UnaryOperator(Expression): - """ - Base class for unary operators in the metrics module. + """Base class for unary operators in the metrics module. - This class represents an operation with a single operand. - Subclasses should implement the evaluate method to define the specific operation. + Notes + ----- + This class represents an operation with a single operand. + Subclasses should implement the evaluate method to define the specific operation. """ - operand: NumberOrExpression = pd.Field( - ..., + operand: NumberOrExpression = Field( title="Operand", description="The operand for the unary operator.", ) @@ -25,8 +28,9 @@ class UnaryOperator(Expression): _symbol: str _format: str = "({symbol}{operand})" - @pd.validator("operand", pre=True, always=True) - def validate_operand(cls, v): + @field_validator("operand") + @classmethod + def validate_operand(cls, v: NumberOrExpression) -> ExpressionType: return cls._to_expression(v) def __repr__(self) -> str: @@ -34,20 +38,19 @@ def __repr__(self) -> str: class BinaryOperator(Expression): - """ - Base class for binary operators in the metrics module. + """Base class for binary operators in the metrics module. - This class represents an operation with two operands. 
- Subclasses should implement the evaluate method to define the specific operation. + Notes + ----- + This class represents an operation with two operands. + Subclasses should implement the evaluate method to define the specific operation. """ - left: NumberOrExpression = pd.Field( - ..., + left: NumberOrExpression = Field( title="Left", description="The left operand for the binary operator.", ) - right: NumberOrExpression = pd.Field( - ..., + right: NumberOrExpression = Field( title="Right", description="The right operand for the binary operator.", ) @@ -55,8 +58,9 @@ class BinaryOperator(Expression): _symbol: str _format: str = "({left} {symbol} {right})" - @pd.validator("left", "right", pre=True, always=True) - def validate_operands(cls, v): + @field_validator("left", "right") + @classmethod + def validate_operands(cls, v: NumberOrExpression) -> ExpressionType: return cls._to_expression(v) def __repr__(self) -> str: diff --git a/tidy3d/plugins/expressions/types.py b/tidy3d/plugins/expressions/types.py index 1f6a12e9fe..ff234c32c0 100644 --- a/tidy3d/plugins/expressions/types.py +++ b/tidy3d/plugins/expressions/types.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Union +from typing import TYPE_CHECKING, Union -from pydantic.v1 import Field - -from tidy3d.components.types import TYPE_TAG_STR, ArrayLike, Complex +from tidy3d.components.types import Complex +from tidy3d.components.types.base import ArrayLikeStrict, discriminated_union if TYPE_CHECKING: from .functions import Cos, Exp, Log, Log10, Sin, Sqrt, Tan @@ -23,9 +22,9 @@ ) from .variables import Constant, Variable -NumberType = Union[int, float, Complex, ArrayLike] +NumberType = Union[int, float, Complex, ArrayLikeStrict] -OperatorType = Annotated[ +OperatorType = discriminated_union( Union[ "Add", "Subtract", @@ -37,11 +36,10 @@ "MatMul", "Negate", "Abs", - ], - Field(discriminator=TYPE_TAG_STR), -] + ] +) -FunctionType = Annotated[ +FunctionType = discriminated_union( Union[ "Sin", "Cos", @@ -50,19 +48,17 @@ "Log", "Log10", "Sqrt", - ], - Field(discriminator=TYPE_TAG_STR), -] + ] +) -MetricType = Annotated[ +MetricType = discriminated_union( Union[ "Constant", "Variable", "ModeAmp", "ModePower", - ], - Field(discriminator=TYPE_TAG_STR), -] + ] +) ExpressionType = Union[ OperatorType, @@ -70,4 +66,4 @@ MetricType, ] -NumberOrExpression = Union[NumberType, ExpressionType] +NumberOrExpression = Union[ExpressionType, NumberType] diff --git a/tidy3d/plugins/expressions/variables.py b/tidy3d/plugins/expressions/variables.py index 13cd5534b8..501bb9da76 100644 --- a/tidy3d/plugins/expressions/variables.py +++ b/tidy3d/plugins/expressions/variables.py @@ -2,7 +2,7 @@ from typing import Any, Optional -import pydantic.v1 as pd +from pydantic import Field from .base import Expression from .types import NumberType @@ -39,7 +39,7 @@ class Variable(Expression): 10 """ - name: Optional[str] = pd.Field( + name: Optional[str] = Field( None, title="Name", description="The name of the variable used for lookup during evaluation.", @@ -81,8 +81,7 @@ class Constant(Variable): 5 """ - value: NumberType = pd.Field( - ..., + value: NumberType = Field( title="Value", description="The fixed value of the constant.", ) diff --git a/tidy3d/plugins/invdes/__init__.py b/tidy3d/plugins/invdes/__init__.py index 86f819248e..ce189f40ad 100644 --- a/tidy3d/plugins/invdes/__init__.py +++ b/tidy3d/plugins/invdes/__init__.py @@ -1,29 +1,50 @@ -# imports from tidy3d.plugins.invdes as tdi from __future__ import 
annotations +import tidy3d.plugins.expressions + from . import utils -from .design import InverseDesign, InverseDesignMulti +from .base import InvdesBaseModel +from .design import AbstractInverseDesign, InverseDesign, InverseDesignMulti, InverseDesignType from .initialization import ( + AbstractInitializationSpec, CustomInitializationSpec, + InitializationSpecType, RandomInitializationSpec, UniformInitializationSpec, ) -from .optimizer import AdamOptimizer -from .penalty import ErosionDilationPenalty -from .region import TopologyDesignRegion +from .optimizer import AbstractOptimizer, AdamOptimizer +from .penalty import AbstractPenalty, ErosionDilationPenalty, PenaltyType +from .region import DesignRegion, DesignRegionType, TopologyDesignRegion from .result import InverseDesignResult -from .transformation import FilterProject +from .transformation import AbstractTransformation, FilterProject, TransformationType + +rebuild_context_namespace = tidy3d.plugins.expressions._local_vars.copy() +AbstractInverseDesign.model_rebuild(_types_namespace=rebuild_context_namespace) +InverseDesign.model_rebuild(_types_namespace=rebuild_context_namespace) +InverseDesignMulti.model_rebuild(_types_namespace=rebuild_context_namespace) __all__ = ( + "AbstractInitializationSpec", + "AbstractInverseDesign", + "AbstractOptimizer", + "AbstractPenalty", + "AbstractTransformation", "AdamOptimizer", "CustomInitializationSpec", + "DesignRegion", + "DesignRegionType", "ErosionDilationPenalty", "FilterProject", + "InitializationSpecType", + "InvdesBaseModel", "InverseDesign", "InverseDesignMulti", "InverseDesignResult", + "InverseDesignType", + "PenaltyType", "RandomInitializationSpec", "TopologyDesignRegion", + "TransformationType", "UniformInitializationSpec", "utils", ) diff --git a/tidy3d/plugins/invdes/base.py b/tidy3d/plugins/invdes/base.py index bde96efcf0..56853e0255 100644 --- a/tidy3d/plugins/invdes/base.py +++ b/tidy3d/plugins/invdes/base.py @@ -1,10 +1,10 @@ # base class for all of the invdes fields from __future__ import annotations -import abc +from abc import ABC -import tidy3d as td +from tidy3d.components.base import Tidy3dBaseModel -class InvdesBaseModel(td.components.base.Tidy3dBaseModel, abc.ABC): +class InvdesBaseModel(Tidy3dBaseModel, ABC): """Base class for ``invdes`` components, in case we need it.""" diff --git a/tidy3d/plugins/invdes/design.py b/tidy3d/plugins/invdes/design.py index 22c8a88c68..0b403b51a7 100644 --- a/tidy3d/plugins/invdes/design.py +++ b/tidy3d/plugins/invdes/design.py @@ -3,12 +3,10 @@ from __future__ import annotations import abc -import typing -from typing import Any +from typing import TYPE_CHECKING, Any, Callable, Optional, Union -import autograd.numpy as anp import numpy as np -import pydantic.v1 as pd +from pydantic import Field, field_validator, model_validator import tidy3d as td from tidy3d.components.autograd import get_static @@ -20,39 +18,42 @@ from .region import DesignRegionType from .validators import check_pixel_size -PostProcessFnType = typing.Callable[[td.SimulationData], float] +if TYPE_CHECKING: + import autograd.numpy as anp + + from tidy3d.compat import Self + +PostProcessFnType = Callable[[td.SimulationData], float] class AbstractInverseDesign(InvdesBaseModel, abc.ABC): """Container for an inverse design problem.""" - design_region: DesignRegionType = pd.Field( - ..., + design_region: DesignRegionType = Field( title="Design Region", description="Region within which we will optimize the simulation.", ) - task_name: str = pd.Field( - ..., + task_name: str = 
Field( title="Task Name", description="Task name to use in the objective function when running the ``JaxSimulation``.", ) - verbose: bool = pd.Field( + verbose: bool = Field( False, title="Task Verbosity", description="If ``True``, will print the regular output from ``web`` functions.", ) - metric: typing.Optional[ExpressionType] = pd.Field( + metric: Optional[ExpressionType] = Field( None, title="Objective Metric", description="Serializable expression defining the objective function.", ) def make_objective_fn( - self, post_process_fn: typing.Optional[typing.Callable] = None, maximize: bool = True - ) -> typing.Callable[[anp.ndarray], tuple[float, dict]]: + self, post_process_fn: Optional[Callable] = None, maximize: bool = True + ) -> Callable[[anp.ndarray], tuple[float, dict]]: """Construct the objective function for this InverseDesign object.""" if (post_process_fn is None) and (self.metric is None): @@ -63,7 +64,7 @@ def make_objective_fn( direction_multiplier = 1 if maximize else -1 - def objective_fn(params: anp.ndarray, aux_data: typing.Optional[dict] = None) -> float: + def objective_fn(params: anp.ndarray, aux_data: Optional[dict] = None) -> float: """Full objective function.""" data = self.to_simulation_data(params=params) @@ -100,7 +101,7 @@ def initial_simulation(self) -> td.Simulation: initial_params = self.design_region.initial_parameters return self.to_simulation(initial_params) - def run(self, simulation, **kwargs: Any) -> td.SimulationData: + def run(self, simulation: td.Simulation, **kwargs: Any) -> td.SimulationData: """Run a single tidy3d simulation.""" from tidy3d.web import run @@ -108,7 +109,7 @@ def run(self, simulation, **kwargs: Any) -> td.SimulationData: kwargs.setdefault("task_name", self.task_name) return run(simulation, **kwargs) - def run_async(self, simulations, **kwargs: Any) -> web.BatchData: # noqa: F821 + def run_async(self, simulations: dict[str, td.Simulation], **kwargs: Any) -> td.web.BatchData: """Run a batch of tidy3d simulations.""" from tidy3d.web import run_async @@ -119,13 +120,12 @@ def run_async(self, simulations, **kwargs: Any) -> web.BatchData: # noqa: F821 class InverseDesign(AbstractInverseDesign): """Container for an inverse design problem.""" - simulation: td.Simulation = pd.Field( - ..., + simulation: td.Simulation = Field( title="Base Simulation", description="Simulation without the design regions or monitors used in the objective fn.", ) - output_monitor_names: tuple[str, ...] = pd.Field( + output_monitor_names: Optional[tuple[str, ...]] = Field( None, title="Output Monitor Names", description="Optional names of monitors whose data the differentiable output depends on." 
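The hunk below applies the same root-validator migration seen throughout this diff: pydantic v1's ``root_validator`` receiving a ``values`` dict becomes a v2 ``model_validator(mode="after")`` that receives the fully constructed instance and must return it. A minimal sketch of the pattern; ``Window`` and its fields are hypothetical, not tidy3d classes:

    from pydantic import BaseModel, model_validator


    class Window(BaseModel):
        t_min: float
        t_max: float

        # v1 equivalent: @root_validator with (cls, values) and a dict return.
        @model_validator(mode="after")
        def _check_ordering(self) -> "Window":
            # Runs after per-field validation, on the built instance.
            if self.t_min > self.t_max:
                raise ValueError("'t_min' may not exceed 't_max'.")
            return self

One practical consequence, visible in the converted validators here: cross-field checks no longer need ``values.get(...)`` or ``skip_if_fields_missing``, since every field has already been validated and is available on ``self``.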
@@ -136,23 +136,16 @@ class InverseDesign(AbstractInverseDesign):
     _check_sim_pixel_size = check_pixel_size("simulation")

-    @pd.root_validator(pre=False)
-    def _validate_model(cls, values: dict) -> dict:
-        cls._validate_metric(values)
-        return values
-
-    @staticmethod
-    def _validate_metric(values: dict) -> dict:
-        metric_expr = values.get("metric")
-        if not metric_expr:
-            return values
-        simulation = values.get("simulation")
-        for metric in metric_expr.filter(Metric):
-            InverseDesign._validate_metric_monitor_name(metric, simulation)
-            InverseDesign._validate_metric_mode_index(metric, simulation)
-            InverseDesign._validate_metric_f(metric, simulation)
-        InverseDesign._validate_metric_data(metric_expr, simulation)
-        return values
+    @model_validator(mode="after")
+    def _validate_model(self) -> Self:
+        if not self.metric:
+            return self
+        for metric in self.metric.filter(Metric):
+            InverseDesign._validate_metric_monitor_name(metric, self.simulation)
+            InverseDesign._validate_metric_mode_index(metric, self.simulation)
+            InverseDesign._validate_metric_f(metric, self.simulation)
+        InverseDesign._validate_metric_data(self.metric, self.simulation)
+        return self

     @staticmethod
     def _validate_metric_monitor_name(metric: Metric, simulation: td.Simulation) -> None:
@@ -260,13 +253,12 @@ def to_simulation_data(self, params: anp.ndarray, **kwargs: Any) -> td.Simulatio
 class InverseDesignMulti(AbstractInverseDesign):
     """``InverseDesign`` with multiple simulations and corresponding postprocess functions."""

-    simulations: tuple[td.Simulation, ...] = pd.Field(
-        ...,
+    simulations: tuple[td.Simulation, ...] = Field(
         title="Base Simulations",
         description="Set of simulation without the design regions or monitors used in the objective fn.",
     )

-    output_monitor_names: tuple[typing.Union[tuple[str, ...], None], ...] = pd.Field(
+    output_monitor_names: Optional[tuple[Union[tuple[str, ...], None], ...]] = Field(
         None,
         title="Output Monitor Names",
         description="Optional names of monitors whose data the differentiable output depends on."
@@ -275,14 +267,32 @@ class InverseDesignMulti(AbstractInverseDesign):
         "not fully supported, for example ``FieldMonitor`` instances with ``.colocate != False``.",
     )

+    @field_validator("output_monitor_names", mode="before")
+    @classmethod
+    def _convert_list_to_tuple(cls, v: Any) -> Any:
+        """Convert lists to tuples for output_monitor_names."""
+        if v is None:
+            return v
+        if isinstance(v, list):
+            # Convert outer list to tuple
+            result = []
+            for item in v:
+                if isinstance(item, list):
+                    # Convert inner list to tuple
+                    result.append(tuple(item))
+                else:
+                    result.append(item)
+            return tuple(result)
+        return v
+
     _check_sim_pixel_size = check_pixel_size("simulations")

-    @pd.root_validator()
-    def _check_lengths(cls, values):
+    @model_validator(mode="after")
+    def _check_lengths(self) -> Self:
         """Check the lengths of all of the multi fields."""
-        keys = ("simulations", "post_process_fns", "output_monitor_names", "override_structure_dl")
-        multi_dict = {key: values.get(key) for key in keys}
+        keys = ("simulations", "output_monitor_names")
+        multi_dict = {key: getattr(self, key) for key in keys}
         sizes = {key: len(val) for key, val in multi_dict.items() if val is not None}
         if len(set(sizes.values())) != 1:
@@ -292,7 +302,7 @@ def _check_lengths(cls, values):
                 "corresponding sizes of '{sizes}'."
             )
-        return values
+        return self

     @property
     def task_names(self) -> list[str]:
@@ -323,10 +333,10 @@ def to_simulation(self, params: anp.ndarray) -> dict[str, td.Simulation]:
         simulation_list = [design.to_simulation(params) for design in self.designs]
         return dict(zip(self.task_names, simulation_list))

-    def to_simulation_data(self, params: anp.ndarray, **kwargs: Any) -> web.BatchData:  # noqa: F821
+    def to_simulation_data(self, params: anp.ndarray, **kwargs: Any) -> td.web.BatchData:
         """Convert the ``InverseDesignMulti`` to a set of ``td.Simulation``s and run async."""
         simulations = self.to_simulation(params)
         return self.run_async(simulations, **kwargs)


-InverseDesignType = typing.Union[InverseDesign, InverseDesignMulti]
+InverseDesignType = Union[InverseDesign, InverseDesignMulti]
diff --git a/tidy3d/plugins/invdes/initialization.py b/tidy3d/plugins/invdes/initialization.py
index 3acd297d25..fd3b82b282 100644
--- a/tidy3d/plugins/invdes/initialization.py
+++ b/tidy3d/plugins/invdes/initialization.py
@@ -3,17 +3,21 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional, Union

 import numpy as np
-import pydantic.v1 as pd
-from numpy.typing import NDArray
+from pydantic import Field, NonNegativeInt, field_validator, model_validator

 import tidy3d as td
 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.types import ArrayLike
 from tidy3d.exceptions import ValidationError

+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+    from tidy3d.compat import Self
+

 class AbstractInitializationSpec(Tidy3dBaseModel, ABC):
     """Abstract base class for initialization specifications."""
@@ -26,37 +30,38 @@ def create_parameters(self, shape: tuple[int, ...]) -> NDArray:
 class RandomInitializationSpec(AbstractInitializationSpec):
     """Specification for random initial parameters.

-    When a seed is provided, a call to `create_parameters` will always return the same array.
+    Notes
+    -----
+    When a seed is provided, a call to `create_parameters` will always return the same array.
     """

-    min_value: float = pd.Field(
+    min_value: float = Field(
         0.0,
         ge=0.0,
         le=1.0,
         title="Minimum Value",
         description="Minimum value for the random parameters (inclusive).",
     )

-    max_value: float = pd.Field(
+    max_value: float = Field(
         1.0,
         ge=0.0,
         le=1.0,
         title="Maximum Value",
         description="Maximum value for the random parameters (exclusive).",
     )

-    seed: Optional[pd.NonNegativeInt] = pd.Field(
-        None, description="Seed for the random number generator."
+    seed: Optional[NonNegativeInt] = Field(
+        None,
+        description="Seed for the random number generator.",
     )

-    @pd.root_validator(pre=False)
-    def _validate_max_ge_min(cls, values):
+    @model_validator(mode="after")
+    def _validate_max_ge_min(self) -> Self:
         """Ensure that max_value is greater than or equal to min_value."""
-        minval = values.get("min_value")
-        maxval = values.get("max_value")
-        if minval > maxval:
+        if self.min_value > self.max_value:
             raise ValidationError(
-                f"'max_value' ({maxval}) must be greater or equal than 'min_value' ({minval})"
+                f"'max_value' ({self.max_value}) must be greater than or equal to 'min_value' ({self.min_value})"
             )
-        return values
+        return self

     def create_parameters(self, shape: tuple[int, ...]) -> NDArray:
         """Generate the parameter array based on the specification."""
@@ -67,7 +72,7 @@ def create_parameters(self, shape: tuple[int, ...]) -> NDArray:
 class UniformInitializationSpec(AbstractInitializationSpec):
     """Specification for uniform initial parameters."""

-    value: float = pd.Field(
+    value: float = Field(
         0.5,
         ge=0.0,
         le=1.0,
@@ -83,37 +88,40 @@ def create_parameters(self, shape: tuple[int, ...]) -> NDArray:
 class CustomInitializationSpec(AbstractInitializationSpec):
     """Specification for custom initial parameters provided by the user."""

-    params: ArrayLike = pd.Field(
+    params: ArrayLike = Field(
         ...,
         title="Parameters",
         description="Custom parameters provided by the user.",
     )

-    @pd.validator("params")
-    def _validate_params_range(cls, value, values):
+    @field_validator("params")
+    @classmethod
+    def _validate_params_range(cls, val: NDArray) -> NDArray:
         """Ensure that all parameter values are between 0 and 1."""
-        if np.any((value < 0) | (value > 1)):
+        if np.any((val < 0) | (val > 1)):
             raise ValidationError("'params' need to be between 0 and 1.")
-        return value
+        return val

-    @pd.validator("params")
-    def _validate_params_dtype(cls, value, values):
+    @field_validator("params")
+    @classmethod
+    def _validate_params_dtype(cls, val: NDArray) -> NDArray:
         """Ensure that params is real-valued."""
-        if np.issubdtype(value.dtype, np.bool_):
+        if np.issubdtype(val.dtype, np.bool_):
             td.log.warning(
                 "Got a boolean array for 'params'. This will be treated as a floating point array."
             )
-            value = value.astype(float)
-        elif not np.issubdtype(value.dtype, np.floating):
-            raise ValidationError(f"'params' need to be real-valued, but got '{value.dtype}'.")
-        return value
-
-    @pd.validator("params")
-    def _validate_params_3d(cls, value, values):
+            val = val.astype(float)
+        elif not np.issubdtype(val.dtype, np.floating):
+            raise ValidationError(f"'params' need to be real-valued, but got '{val.dtype}'.")
+        return val
+
+    @field_validator("params")
+    @classmethod
+    def _validate_params_3d(cls, val: NDArray) -> NDArray:
         """Ensure that params is a 3D array."""
-        if value.ndim != 3:
-            raise ValidationError(f"'params' must be 3D, but got {value.ndim}D.")
-        return value
+        if val.ndim != 3:
+            raise ValidationError(f"'params' must be 3D, but got {val.ndim}D.")
+        return val

     def create_parameters(self, shape: tuple[int, ...]) -> NDArray:
         """Return the custom parameters provided by the user."""
diff --git a/tidy3d/plugins/invdes/optimizer.py b/tidy3d/plugins/invdes/optimizer.py
index a43ba47e23..a74777f67e 100644
--- a/tidy3d/plugins/invdes/optimizer.py
+++ b/tidy3d/plugins/invdes/optimizer.py
@@ -2,51 +2,52 @@
 from __future__ import annotations

 import abc
-import typing
 from copy import deepcopy
+from typing import TYPE_CHECKING, Optional

 import autograd as ag
 import autograd.numpy as anp
 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, PositiveFloat, PositiveInt

 import tidy3d as td
 from tidy3d.components.types import TYPE_TAG_STR
+from tidy3d.exceptions import SetupError

 from .base import InvdesBaseModel
 from .design import InverseDesignType
 from .result import InverseDesignResult

+if TYPE_CHECKING:
+    from typing import Callable
+

 class AbstractOptimizer(InvdesBaseModel, abc.ABC):
     """Specification for an optimization."""

-    design: InverseDesignType = pd.Field(
-        ...,
+    design: InverseDesignType = Field(
         title="Inverse Design Specification",
         description="Specification describing the inverse design problem we wish to optimize.",
         discriminator=TYPE_TAG_STR,
     )

-    learning_rate: pd.PositiveFloat = pd.Field(
-        ...,
+    learning_rate: PositiveFloat = Field(
         title="Learning Rate",
         description="Step size for the gradient descent optimizer.",
     )

-    maximize: bool = pd.Field(
+    maximize: bool = Field(
         True,
         title="Direction of Optimization",
         description="If ``True``, the optimizer will maximize the objective function. If ``False``, the optimizer will minimize the objective function.",
     )

-    num_steps: pd.PositiveInt = pd.Field(
-        ...,
+    num_steps: PositiveInt = Field(
         title="Number of Steps",
         description="Number of steps in the gradient descent optimizer.",
     )

-    results_cache_fname: str = pd.Field(
+    results_cache_fname: Optional[str] = Field(
         None,
         title="History Storage File",
         description="If specified, will save the optimization state to a local ``.pkl`` file "
@@ -58,7 +59,7 @@ class AbstractOptimizer(InvdesBaseModel, abc.ABC):
         "``optimizer.continue_run(result)``.",
     )

-    store_full_results: bool = pd.Field(
+    store_full_results: bool = Field(
         True,
         title="Store Full Results",
         description="If ``True``, stores the full history for the vector fields, specifically "
@@ -84,9 +85,7 @@ def display_fn(self, result: InverseDesignResult, step_index: int) -> None:
         print(f"\tpost_process_val = {result.post_process_val[-1]:.3e}")
         print(f"\tpenalty = {result.penalty[-1]:.3e}")

-    def initialize_result(
-        self, params0: typing.Optional[anp.ndarray] = None
-    ) -> InverseDesignResult:
+    def initialize_result(self, params0: Optional[anp.ndarray] = None) -> InverseDesignResult:
         """
         Create an initially empty `InverseDesignResult` from the starting parameters.
@@ -111,8 +110,8 @@ def initialize_result(
     def run(
         self,
-        post_process_fn: typing.Optional[typing.Callable] = None,
-        callback: typing.Optional[typing.Callable] = None,
+        post_process_fn: Optional[Callable] = None,
+        callback: Optional[Callable] = None,
         params0: anp.ndarray = None,
     ) -> InverseDesignResult:
         """Run this inverse design problem from an optional initial set of parameters.
@@ -140,9 +139,9 @@ def run(
     def continue_run(
         self,
         result: InverseDesignResult,
-        num_steps: typing.Optional[int] = None,
-        post_process_fn: typing.Optional[typing.Callable] = None,
-        callback: typing.Optional[typing.Callable] = None,
+        num_steps: Optional[int] = None,
+        post_process_fn: Optional[Callable] = None,
+        callback: Optional[Callable] = None,
     ) -> InverseDesignResult:
         """Run optimizer for a series of steps with an initialized state.
@@ -180,9 +179,9 @@ def continue_run(
             aux_data = {}
             val, grad = val_and_grad_fn(params, aux_data=aux_data)

-            if anp.allclose(grad, 0.0):
-                td.log.warning(
-                    "All elements of the gradient are almost zero. This likely indicates "
+            if np.count_nonzero(grad) == 0:
+                raise SetupError(
+                    "All elements of the gradient are exactly zero. This likely indicates "
                     "a problem with the optimization set up. This can occur if the symmetry of the "
                     "simulation and design region are preventing any data to be recorded in the "
                     "'output_monitors'. In this case, we recommend initializing with a "
@@ -230,9 +229,9 @@ def continue_run(
     def continue_run_from_file(
         self,
         fname: str,
-        num_steps: typing.Optional[int] = None,
-        post_process_fn: typing.Optional[typing.Callable] = None,
-        callback: typing.Optional[typing.Callable] = None,
+        num_steps: Optional[int] = None,
+        post_process_fn: Optional[Callable] = None,
+        callback: Optional[Callable] = None,
     ) -> InverseDesignResult:
         """Continue the optimization run from a ``.pkl`` file with an ``InverseDesignResult``."""
         result = InverseDesignResult.from_file(fname)
@@ -245,9 +244,9 @@ def continue_run_from_file(
     def continue_run_from_history(
         self,
-        num_steps: typing.Optional[int] = None,
-        post_process_fn: typing.Optional[typing.Callable] = None,
-        callback: typing.Optional[typing.Callable] = None,
+        num_steps: Optional[int] = None,
+        post_process_fn: Optional[Callable] = None,
+        callback: Optional[Callable] = None,
     ) -> InverseDesignResult:
         """Continue the optimization run from a ``.pkl`` file with an ``InverseDesignResult``."""
         return self.continue_run_from_file(
@@ -261,7 +260,7 @@ def continue_run_from_history(
 class AdamOptimizer(AbstractOptimizer):
     """Specification for an optimization."""

-    beta1: float = pd.Field(
+    beta1: float = Field(
         0.9,
         ge=0.0,
         le=1.0,
@@ -269,7 +268,7 @@ class AdamOptimizer(AbstractOptimizer):
         description="Beta 1 parameter in the Adam optimization method.",
     )

-    beta2: float = pd.Field(
+    beta2: float = Field(
         0.999,
         ge=0.0,
         le=1.0,
@@ -277,7 +276,7 @@ class AdamOptimizer(AbstractOptimizer):
         description="Beta 2 parameter in the Adam optimization method.",
     )

-    eps: pd.PositiveFloat = pd.Field(
+    eps: PositiveFloat = Field(
         1e-8,
         title="Epsilon",
         description="Epsilon parameter in the Adam optimization method.",
@@ -289,7 +288,7 @@ def initial_state(self, parameters: np.ndarray) -> dict:
         return {"m": zeros, "v": zeros, "t": 0}

     def update(
-        self, parameters: np.ndarray, gradient: np.ndarray, state: typing.Optional[dict] = None
+        self, parameters: np.ndarray, gradient: np.ndarray, state: Optional[dict] = None
     ) -> tuple[np.ndarray, dict]:
         if state is None:
             state = self.initial_state(parameters)
@@ -311,6 +310,6 @@ def update(
         v_ = v / (1 - self.beta2**t)

         # update parameters and state
-        parameters -= self.learning_rate * m_ / (np.sqrt(v_) + self.eps)
+        parameters = parameters - self.learning_rate * m_ / (np.sqrt(v_) + self.eps)
         state = {"m": m, "v": v, "t": t}
         return parameters, state
diff --git a/tidy3d/plugins/invdes/penalty.py b/tidy3d/plugins/invdes/penalty.py
index 37d31b4978..e06b6fe84d 100644
--- a/tidy3d/plugins/invdes/penalty.py
+++ b/tidy3d/plugins/invdes/penalty.py
@@ -2,22 +2,23 @@
 from __future__ import annotations

 import abc
-import typing
-from typing import Any
+from typing import TYPE_CHECKING, Any, Union

-import autograd.numpy as anp
-import pydantic.v1 as pd
+from pydantic import Field, NonNegativeFloat, PositiveFloat

 from tidy3d.constants import MICROMETER
 from tidy3d.plugins.autograd.invdes import make_erosion_dilation_penalty

 from .base import InvdesBaseModel

+if TYPE_CHECKING:
+    import autograd.numpy as anp
+

 class AbstractPenalty(InvdesBaseModel, abc.ABC):
     """Base class for penalties added to ``invdes.DesignRegion`` objects."""

-    weight: pd.NonNegativeFloat = pd.Field(
+    weight: NonNegativeFloat = Field(
         1.0,
         title="Weight",
         description="When this penalty is evaluated, it will be weighted by this "
@@ -50,8 +51,7 @@ class ErosionDilationPenalty(AbstractPenalty):
     """

-    length_scale: pd.PositiveFloat = pd.Field(
-        ...,
+    length_scale: PositiveFloat = Field(
         title="Length Scale",
         description="Length scale of erosion and dilation. "
         "Corresponds to ``radius`` in the :class:`ConicFilter` used for filtering. "
@@ -60,7 +60,7 @@ class ErosionDilationPenalty(AbstractPenalty):
         units=MICROMETER,
     )

-    beta: float = pd.Field(
+    beta: float = Field(
         100.0,
         ge=1.0,
         title="Projection Beta",
@@ -69,7 +69,7 @@ class ErosionDilationPenalty(AbstractPenalty):
         "Higher values correspond to stronger discretization.",
     )

-    eta0: float = pd.Field(
+    eta0: float = Field(
         0.5,
         ge=0.0,
         le=1.0,
@@ -79,7 +79,7 @@ class ErosionDilationPenalty(AbstractPenalty):
         "Corresponds to ``eta`` in the :class:`BinaryProjector`.",
     )

-    delta_eta: float = pd.Field(
+    delta_eta: float = Field(
         0.01,
         ge=0.0,
         le=1.0,
@@ -99,4 +99,4 @@ def evaluate(self, x: anp.ndarray, pixel_size: float) -> float:
         return self.weight * penalty_unweighted


-PenaltyType = typing.Union[ErosionDilationPenalty]
+PenaltyType = Union[ErosionDilationPenalty]
diff --git a/tidy3d/plugins/invdes/region.py b/tidy3d/plugins/invdes/region.py
index 640bafc2ab..7edae94c19 100644
--- a/tidy3d/plugins/invdes/region.py
+++ b/tidy3d/plugins/invdes/region.py
@@ -2,14 +2,13 @@
 from __future__ import annotations

 import abc
-import typing
 import warnings
-from typing import Any
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union

 import autograd.numpy as anp
 import numpy as np
-import pydantic.v1 as pd
 from autograd import elementwise_grad, grad
+from pydantic import Field, PositiveFloat, field_validator, model_validator

 import tidy3d as td
 from tidy3d.components.types import TYPE_TAG_STR, Coordinate, Size
@@ -20,34 +19,35 @@
 from .penalty import PenaltyType
 from .transformation import TransformationType

+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+    from tidy3d.compat import Self
+
 # TODO: support auto handling of symmetry in parameters


 class DesignRegion(InvdesBaseModel, abc.ABC):
     """Base class for design regions in the ``invdes`` plugin."""

-    size: Size = pd.Field(
-        ...,
+    size: Size = Field(
         title="Size",
         description="Size in x, y, and z directions.",
         units=td.constants.MICROMETER,
     )

-    center: Coordinate = pd.Field(
-        ...,
+    center: Coordinate = Field(
         title="Center",
         description="Center of object in x, y, and z.",
         units=td.constants.MICROMETER,
     )

-    eps_bounds: tuple[float, float] = pd.Field(
-        ...,
-        ge=1.0,
+    eps_bounds: tuple[float, float] = Field(
         title="Relative Permittivity Bounds",
         description="Minimum and maximum relative permittivity expressed to the design region.",
     )

-    transformations: tuple[TransformationType, ...] = pd.Field(
+    transformations: tuple[TransformationType, ...] = Field(
         (),
         title="Transformations",
         description="Transformations that get applied from first to last on the parameter array."
@@ -57,7 +57,7 @@ class DesignRegion(InvdesBaseModel, abc.ABC):
         "Specific permittivity values given the density array are determined by ``eps_bounds``.",
     )

-    penalties: tuple[PenaltyType, ...] = pd.Field(
+    penalties: tuple[PenaltyType, ...] = Field(
         (),
         title="Penalties",
         description="Set of penalties that get evaluated on the material density. Note that the "
@@ -65,25 +65,28 @@ class DesignRegion(InvdesBaseModel, abc.ABC):
         "inside of the penalties directly through the ``.weight`` field.",
     )

-    initialization_spec: InitializationSpecType = pd.Field(
-        UniformInitializationSpec(value=0.5),
+    initialization_spec: InitializationSpecType = Field(
+        default_factory=lambda: UniformInitializationSpec(value=0.5),
         title="Initialization Specification",
         description="Specification of how to initialize the parameters in the design region.",
         discriminator=TYPE_TAG_STR,
     )

-    def _post_init_validators(self) -> None:
-        """Automatically call any `_validate_XXX` method."""
-        for attr_name in dir(self):
-            if attr_name.startswith("_validate") and callable(getattr(self, attr_name)):
-                getattr(self, attr_name)()
+    @field_validator("eps_bounds")
+    @classmethod
+    def _validate_ge_one(cls, v: tuple[float, float]) -> tuple[float, float]:
+        if any(vi < 1 for vi in v):
+            raise ValueError("Each value in 'eps_bounds' must be '>=1.0'.")
+        return v

-    def _validate_eps_bounds(self) -> None:
+    @model_validator(mode="after")
+    def _validate_eps_bounds(self) -> Self:
         if self.eps_bounds[1] < self.eps_bounds[0]:
             raise ValidationError(
                 f"Maximum relative permittivity ({self.eps_bounds[1]}) must be "
                 f"greater than minimum relative permittivity ({self.eps_bounds[0]})."
             )
+        return self

     @property
     def geometry(self) -> td.Box:
@@ -133,8 +136,7 @@ def initial_parameters(self) -> np.ndarray:
 class TopologyDesignRegion(DesignRegion):
     """Design region as a pixellated permittivity grid."""

-    pixel_size: pd.PositiveFloat = pd.Field(
-        ...,
+    pixel_size: PositiveFloat = Field(
         title="Pixel Size",
         description="Pixel size of the design region in x, y, z. For now, we only support the same "
         "pixel size in all 3 dimensions. If ``TopologyDesignRegion.override_structure_dl`` is left "
@@ -144,14 +146,14 @@ class TopologyDesignRegion(DesignRegion):
         "a value on the same order as the grid size.",
     )

-    uniform: tuple[bool, bool, bool] = pd.Field(
+    uniform: tuple[bool, bool, bool] = Field(
         (False, False, True),
         title="Uniform",
         description="Axes along which the design should be uniform. By default, the structure "
         "is assumed to be uniform, i.e. invariant, in the z direction.",
     )

-    transformations: tuple[TransformationType, ...] = pd.Field(
+    transformations: tuple[TransformationType, ...] = Field(
         (),
         title="Transformations",
         description="Transformations that get applied from first to last on the parameter array."
@@ -160,7 +162,7 @@ class TopologyDesignRegion(DesignRegion):
         "permittivity and 1 corresponds to the maximum relative permittivity. "
         "Specific permittivity values given the density array are determined by ``eps_bounds``.",
     )

-    penalties: tuple[PenaltyType, ...] = pd.Field(
+    penalties: tuple[PenaltyType, ...] = Field(
         (),
         title="Penalties",
         description="Set of penalties that get evaluated on the material density. Note that the "
@@ -168,7 +170,7 @@ class TopologyDesignRegion(DesignRegion):
         "inside of the penalties directly through the ``.weight`` field.",
     )

-    override_structure_dl: typing.Union[pd.PositiveFloat, typing.Literal[False]] = pd.Field(
+    override_structure_dl: Optional[Union[PositiveFloat, Literal[False]]] = Field(
         None,
         title="Design Region Override Structure",
         description="Defines grid size when adding an ``override_structure`` to the "
@@ -179,6 +181,16 @@ class TopologyDesignRegion(DesignRegion):
         "Supplying ``False`` will completely leave out the override structure.",
     )

+    priority: Optional[int] = Field(
+        None,
+        title="Priority",
+        description="Priority of the structure applied in structure overlapping region. "
+        "The material property in the overlapping region is dictated by the structure "
+        "of higher priority. For structures of equal priority, "
+        "the structure added later to the structure list takes precedence. When `priority` is None, "
+        "the value is automatically assigned based on `structure_priority_mode` in the `Simulation`.",
+    )
+
     def _validate_eps_values(self) -> None:
         """Validate the epsilon values by evaluating the transformations."""
         try:
@@ -261,31 +273,31 @@ def _warn_deprecate_params(self) -> None:
             "'initialization_spec' instead."
         )

-    def params_uniform(self, value: float) -> np.ndarray:
+    def params_uniform(self, value: float) -> NDArray[np.floating]:
         """Make an array of parameters with all the same value."""
         self._warn_deprecate_params()
         return value * np.ones(self.params_shape)

     @property
-    def params_random(self) -> np.ndarray:
+    def params_random(self) -> NDArray[np.floating]:
         """Convenience for generating random parameters between (0,1) with correct shape."""
         self._warn_deprecate_params()
         return np.random.random(self.params_shape)

     @property
-    def params_zeros(self):
+    def params_zeros(self) -> NDArray[np.floating]:
         """Convenience for generating random parameters of all 0 values with correct shape."""
         self._warn_deprecate_params()
         return self.params_uniform(0.0)

     @property
-    def params_half(self):
+    def params_half(self) -> NDArray[np.floating]:
         """Convenience for generating random parameters of all 0.5 values with correct shape."""
         self._warn_deprecate_params()
         return self.params_uniform(0.5)

     @property
-    def params_ones(self):
+    def params_ones(self) -> NDArray[np.floating]:
         """Convenience for generating random parameters of all 1 values with correct shape."""
         self._warn_deprecate_params()
         return self.params_uniform(1.0)
@@ -329,7 +341,7 @@ def to_structure(self, params: anp.ndarray) -> td.Structure:
         eps_values = self.eps_values(params)
         eps_data_array = td.SpatialDataArray(eps_values, coords=coords)
         medium = td.CustomMedium(permittivity=eps_data_array)
-        return td.Structure(geometry=self.geometry, medium=medium)
+        return td.Structure(geometry=self.geometry, medium=medium, priority=self.priority)

     @property
     def _override_structure_dl(self) -> float:
@@ -367,4 +379,4 @@ def evaluate_penalty(self, penalty: PenaltyType, material_density: anp.ndarray)
         return penalty.evaluate(x=material_density, pixel_size=self.pixel_size)


-DesignRegionType = typing.Union[TopologyDesignRegion]
+DesignRegionType = Union[TopologyDesignRegion]
diff --git a/tidy3d/plugins/invdes/result.py b/tidy3d/plugins/invdes/result.py
index ae844fb5ad..8380b28b6e 100644
--- a/tidy3d/plugins/invdes/result.py
+++ b/tidy3d/plugins/invdes/result.py
@@ -1,12 +1,11 @@
 # convenient container for the output of the inverse design (specifically the history)
 from __future__ import annotations

-import typing
-from typing import Any
+from typing import Any, Union

 import matplotlib.pyplot as plt
 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, field_validator

 import tidy3d as td
 from tidy3d.components.types import ArrayLike
@@ -20,56 +19,58 @@
 class InverseDesignResult(InvdesBaseModel):
     """Container for the result of an ``InverseDesign.run()`` call."""

-    design: InverseDesignType = pd.Field(
-        ...,
+    design: InverseDesignType = Field(
         title="Inverse Design Specification",
         description="Specification describing the inverse design problem we wish to optimize.",
     )

-    params: tuple[ArrayLike, ...] = pd.Field(
+    params: tuple[ArrayLike, ...] = Field(
         (),
         title="Parameter History",
         description="History of parameter arrays throughout the optimization.",
     )

-    objective_fn_val: tuple[float, ...] = pd.Field(
+    objective_fn_val: tuple[float, ...] = Field(
         (),
         title="Objective Function History",
         description="History of objective function values throughout the optimization.",
     )

-    grad: tuple[ArrayLike, ...] = pd.Field(
+    grad: tuple[ArrayLike, ...] = Field(
         (),
         title="Gradient History",
         description="History of objective function gradient arrays throughout the optimization.",
     )

-    penalty: tuple[float, ...] = pd.Field(
+    penalty: tuple[float, ...] = Field(
         (),
         title="Penalty History",
         description="History of weighted sum of penalties throughout the optimization.",
     )

-    post_process_val: tuple[float, ...] = pd.Field(
+    post_process_val: tuple[float, ...] = Field(
         (),
         title="Post-Process Function History",
         description="History of return values from ``post_process_fn`` throughout the optimization.",
     )

-    simulation: tuple[td.Simulation, ...] = pd.Field(
+    simulation: tuple[td.Simulation, ...] = Field(
         (),
         title="Simulation History",
         description="History of ``td.Simulation`` instances throughout the optimization.",
     )

-    opt_state: tuple[dict, ...] = pd.Field(
+    opt_state: tuple[dict[str, Union[int, ArrayLike]], ...] = Field(
         (),
         title="Optimizer State History",
         description="History of optimizer states throughout the optimization.",
     )

-    @pd.validator("params", pre=False, allow_reuse=True)
-    def _validate_and_clip_params(cls, params_tuple):
+    @field_validator("params")
+    @classmethod
+    def _validate_and_clip_params(
+        cls, params_tuple: tuple[ArrayLike, ...]
+    ) -> tuple[ArrayLike, ...]:
         """Ensure all parameters in history are within [0,1] bounds, clipping if necessary."""
         if not params_tuple:
             return params_tuple
@@ -121,11 +122,11 @@ def keys(self) -> list[str]:
         return list(self.history.keys())

     @property
-    def last(self) -> dict[str, typing.Any]:
+    def last(self) -> dict[str, Any]:
         """Dictionary of last values in ``self.history``."""
         return {key: value[-1] for key, value in self.history.items()}

-    def get(self, key: str, index: int = -1) -> typing.Any:
+    def get(self, key: str, index: int = -1) -> Any:
         """Get the value from the history at a certain index (-1 means last)."""
         if key not in self.keys:
             raise KeyError(f"'{key}' not present in 'Result.history' dict with: {self.keys}.")
@@ -134,24 +135,24 @@ def get(self, key: str, index: int = -1) -> Any:
             raise ValueError(f"Can't get the last value of '{key}' as there is no history present.")
         return values[index]

-    def get_last(self, key: str) -> typing.Any:
+    def get_last(self, key: str) -> Any:
         """Get the last value from the history."""
         return self.get(key=key, index=-1)

-    def get_sim(self, index: int = -1) -> typing.Union[td.Simulation, list[td.Simulation]]:
+    def get_sim(self, index: int = -1) -> Union[td.Simulation, list[td.Simulation]]:
         """Get the simulation at a specific index in the history (list of sims if multi)."""
         params = np.array(self.get(key="params", index=index))
         return self.design.to_simulation(params=params)

     def get_sim_data(
         self, index: int = -1, **kwargs: Any
-    ) -> typing.Union[td.SimulationData, list[td.SimulationData]]:
+    ) -> Union[td.SimulationData, list[td.SimulationData]]:
         """Get the simulation data at a specific index in the history (list of simdata if multi)."""
         params = np.array(self.get(key="params", index=index))
         return self.design.to_simulation_data(params=params, **kwargs)

     @property
-    def sim_last(self) -> typing.Union[td.Simulation, list[td.Simulation]]:
+    def sim_last(self) -> Union[td.Simulation, list[td.Simulation]]:
         """The last simulation."""
         return self.get_sim(index=-1)
diff --git a/tidy3d/plugins/invdes/transformation.py b/tidy3d/plugins/invdes/transformation.py
index 233531dc0a..230a213404 100644
--- a/tidy3d/plugins/invdes/transformation.py
+++ b/tidy3d/plugins/invdes/transformation.py
@@ -2,11 +2,9 @@
 from __future__ import annotations

 import abc
-import typing
-from typing import Any
+from typing import TYPE_CHECKING, Any, Union

-import autograd.numpy as anp
-import pydantic.v1 as pd
+from pydantic import Field, PositiveFloat

 import tidy3d as td
 from tidy3d.plugins.autograd.functions import threshold
@@ -14,6 +12,9 @@
 from .base import InvdesBaseModel

+if TYPE_CHECKING:
+    import autograd.numpy as anp
+

 class AbstractTransformation(InvdesBaseModel, abc.ABC):
     """Base class for transformations."""
@@ -36,8 +37,7 @@ class FilterProject(InvdesBaseModel):
     """

-    radius: pd.PositiveFloat = pd.Field(
-        ...,
+    radius: PositiveFloat = Field(
         title="Filter Radius",
         description="Radius of the filter to convolve with supplied spatial data. "
         "Note: the corresponding feature size expressed in the device is typically "
@@ -50,7 +50,7 @@ class FilterProject(InvdesBaseModel):
         units=td.constants.MICROMETER,
     )

-    beta: float = pd.Field(
+    beta: float = Field(
         1.0,
         ge=1.0,
         title="Beta",
@@ -59,11 +59,15 @@ class FilterProject(InvdesBaseModel):
         "at the expense of gradient accuracy and ease of optimization. ",
     )

-    eta: float = pd.Field(
-        0.5, ge=0.0, le=1.0, title="Eta", description="Halfway point in projection function."
+    eta: float = Field(
+        0.5,
+        ge=0.0,
+        le=1.0,
+        title="Eta",
+        description="Halfway point in projection function.",
     )

-    strict_binarize: bool = pd.Field(
+    strict_binarize: bool = Field(
         False,
         title="Binarize strictly",
         description="If ``False``, the binarization is still continuous between min and max. "
@@ -83,4 +87,4 @@ def evaluate(self, spatial_data: anp.ndarray, design_region_dl: float) -> anp.nd
         return data_projected


-TransformationType = typing.Union[FilterProject]
+TransformationType = Union[FilterProject]
diff --git a/tidy3d/plugins/invdes/utils.py b/tidy3d/plugins/invdes/utils.py
index b29f4e703d..556bdec73c 100644
--- a/tidy3d/plugins/invdes/utils.py
+++ b/tidy3d/plugins/invdes/utils.py
@@ -1,9 +1,7 @@
 """Functional utilities that help define postprocessing functions more simply in ``invdes``."""

-# TODO: improve these?
 from __future__ import annotations

-import typing
 from typing import Any

 import autograd.numpy as anp
@@ -12,7 +10,7 @@
 import tidy3d as td


-def make_array(arr: typing.Any) -> anp.ndarray:
+def make_array(arr: Any) -> anp.ndarray:
     """Turn something into a ``anp.ndarray``."""
     if isinstance(arr, xr.DataArray):
         return anp.array(arr.values)
diff --git a/tidy3d/plugins/invdes/validators.py b/tidy3d/plugins/invdes/validators.py
index 2fdfc2df62..7b75eb501b 100644
--- a/tidy3d/plugins/invdes/validators.py
+++ b/tidy3d/plugins/invdes/validators.py
@@ -1,22 +1,26 @@
 # validator utilities for invdes plugin
 from __future__ import annotations

-import typing
+from typing import TYPE_CHECKING, Any

-import pydantic.v1 as pd
+from pydantic import field_validator, model_validator

 import tidy3d as td
-from tidy3d.components.base import skip_if_fields_missing
+
+if TYPE_CHECKING:
+    from typing import Callable, Optional
+
+    from pydantic._internal._decorators import ModelValidatorDecoratorInfo, PydanticDescriptorProxy

 # warn if pixel size is > PIXEL_SIZE_WARNING_THRESHOLD * (minimum wavelength in material)
 PIXEL_SIZE_WARNING_THRESHOLD = 0.1


-def ignore_inherited_field(field_name: str) -> typing.Callable:
+def ignore_inherited_field(field_name: str) -> Callable:
     """Create validator that ignores a field inherited but not set by user."""

-    @pd.validator(field_name, always=True)
-    def _ignore_field(cls, val) -> None:
+    @field_validator(field_name)
+    def _ignore_field(val: Any) -> None:
         """Ignore supplied field value and warn."""
         if val is not None:
             td.log.warning(
@@ -29,11 +33,11 @@ def _ignore_field(cls, val) -> None:
     return _ignore_field


-def check_pixel_size(sim_field_name: str):
+def check_pixel_size(sim_field_name: str) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
     """make validator to check the pixel size of sim or list of sims in an ``InverseDesign``."""

     def check_pixel_size_sim(
-        sim: td.Simulation, pixel_size: float, index: typing.Optional[int] = None
+        sim: td.Simulation, pixel_size: float, index: Optional[int] = None
     ) -> None:
         """Check a pixel size compared to the simulation min wvl in material."""
         if not sim.sources:
@@ -55,16 +59,15 @@ def check_pixel_size_sim(
             "array resolution, one can set 'DesignRegion.override_structure_dl'."
         )

-    @pd.root_validator(allow_reuse=True)
-    @skip_if_fields_missing(["design_region"], root=True)
-    def _check_pixel_size(cls, values):
+    @model_validator(mode="after")
+    def _check_pixel_size(self: Any) -> Any:
         """Make sure region pixel_size isn't too large compared to sim's wavelength in material."""
-        sim = values.get(sim_field_name)
-        region = values.get("design_region")
+        sim = getattr(self, sim_field_name)
+        region = self.design_region
         pixel_size = region.pixel_size

         if not sim and region:
-            return values
+            return self

         if isinstance(sim, (list, tuple)):
             for i, s in enumerate(sim):
@@ -72,6 +75,6 @@ def _check_pixel_size(cls, values):
         else:
             check_pixel_size_sim(sim=sim, pixel_size=pixel_size)

-        return values
+        return self

     return _check_pixel_size
diff --git a/tidy3d/plugins/klayout/drc/drc.py b/tidy3d/plugins/klayout/drc/drc.py
index 010081e46b..16631adb86 100644
--- a/tidy3d/plugins/klayout/drc/drc.py
+++ b/tidy3d/plugins/klayout/drc/drc.py
@@ -6,10 +6,9 @@
 from collections.abc import Mapping
 from pathlib import Path
 from subprocess import run
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any

-import pydantic.v1 as pd
-from pydantic.v1 import validator
+from pydantic import Field, FilePath, field_validator

 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.geometry.base import Geometry
@@ -25,43 +24,48 @@
 from tidy3d.plugins.klayout.drc.results import DRCResults
 from tidy3d.plugins.klayout.util import check_installation

+if TYPE_CHECKING:
+    from typing import Optional, Union
+
 SUPPORTED_DRC_SUFFIXES: frozenset[str] = frozenset({".drc", ".lydrc"})


 class DRCConfig(Tidy3dBaseModel):
     """Configuration for KLayout DRC."""

-    gdsfile: pd.FilePath = pd.Field(
+    gdsfile: FilePath = Field(
         title="GDS File",
         description="The path to the GDS file to write the Tidy3D object to.",
     )

-    drc_runset: pd.FilePath = pd.Field(
+    drc_runset: FilePath = Field(
         title="DRC Runset file",
         description="Path to the KLayout DRC runset file.",
     )

-    resultsfile: Path = pd.Field(
+    resultsfile: Path = Field(
         title="DRC Results File",
         description="Path to the KLayout DRC results file.",
     )

-    verbose: bool = pd.Field(
+    verbose: bool = Field(
         title="Verbose",
         description="Whether to print logging.",
     )

-    drc_args: dict[str, str] = pd.Field(
+    drc_args: dict[str, str] = Field(
         default_factory=dict,
         title="DRC File Arguments",
         description="Optional key/value pairs forwarded to KLayout as -rd <name>=<value> definitions.",
     )

-    @validator("gdsfile")
-    def _validate_gdsfile_filetype(cls, v: pd.FilePath) -> pd.FilePath:
+    @field_validator("gdsfile")
+    @classmethod
+    def _validate_gdsfile_filetype(cls, v: FilePath) -> FilePath:
         """Check GDS filetype is ``.gds``."""
         if v.suffix != ".gds":
             raise ValidationError(f"GDS file '{v}' must end with '.gds'.")
         return v

-    @validator("drc_runset")
-    def _validate_drc_runset_filetype(cls, v: pd.FilePath) -> pd.FilePath:
+    @field_validator("drc_runset")
+    @classmethod
+    def _validate_drc_runset_filetype(cls, v: FilePath) -> FilePath:
         """Check DRC runset filetype is ``.drc`` or ``.lydrc``."""
         if v.suffix not in SUPPORTED_DRC_SUFFIXES:
             raise ValidationError(
@@ -69,8 +73,9 @@ def _validate_drc_runset_filetype(cls, v: pd.FilePath) -> pd.FilePath:
             )
         return v

-    @validator("drc_runset")
-    def _validate_drc_runset_format(cls, v: pd.FilePath) -> pd.FilePath:
+    @field_validator("drc_runset")
+    @classmethod
+    def _validate_drc_runset_format(cls, v: FilePath) -> FilePath:
         """Check if the DRC runset file is formatted correctly.

         The checks are:
         1. The GDS source must be loaded with 'source($gdsfile)'.
@@ -88,7 +93,8 @@ def _validate_drc_runset_format(cls, v: pd.FilePath) -> pd.FilePath:
             )
         return v

-    @validator("drc_args", pre=True)
+    @field_validator("drc_args", mode="before")
+    @classmethod
     def _validate_drc_args_stringable(cls, v: Any) -> dict[str, str]:
         """Coerce all keys and values in drc_args to strings."""
         if v is None:
@@ -101,7 +107,8 @@ def _validate_drc_args_stringable(cls, v: Any) -> dict[str, str]:
             raise ValidationError("Could not coerce keys and values of drc_args to strings.") from e
         return v

-    @validator("drc_args")
+    @field_validator("drc_args")
+    @classmethod
     def _validate_drc_args_reserved(cls, v: dict[str, str]) -> dict[str, str]:
         """Ensure user arguments do not override the reserved keys."""
@@ -144,11 +151,11 @@ class DRCRunner(Tidy3dBaseModel):
     >>> print(results)  # doctest: +SKIP
     """

-    drc_runset: pd.FilePath = pd.Field(
+    drc_runset: FilePath = Field(
         title="DRC Runset file",
         description="Path to the KLayout DRC runset file.",
     )
-    verbose: bool = pd.Field(
+    verbose: bool = Field(
         default=DEFAULT_VERBOSE,
         title="Verbose",
         description="Whether to print logging.",
@@ -273,7 +280,8 @@ def run_drc_on_gds(config: DRCConfig, max_results: Optional[int] = None) -> DRCR
     output = run(cmd, capture_output=True)

     if output.returncode != 0:
-        raise RuntimeError(f"KLayout DRC failed with error message: '{output.stderr}'.")
+        msg = output.stderr.decode(errors="replace")
+        raise RuntimeError(f"KLayout DRC failed with error message: '{msg}'.")

     if config.verbose:
         console.log("KLayout DRC completed successfully.")
diff --git a/tidy3d/plugins/klayout/drc/results.py b/tidy3d/plugins/klayout/drc/results.py
index 5093524b50..c19ea581e2 100644
--- a/tidy3d/plugins/klayout/drc/results.py
+++ b/tidy3d/plugins/klayout/drc/results.py
@@ -4,16 +4,19 @@
 import re
 import xml.etree.ElementTree as ET
-from pathlib import Path
-from typing import Optional, Union
+from typing import TYPE_CHECKING

-import pydantic.v1 as pd
+from pydantic import Field

 from tidy3d.components.base import Tidy3dBaseModel, cached_property
 from tidy3d.components.types import Coordinate2D
 from tidy3d.exceptions import FileError
 from tidy3d.log import log

+if TYPE_CHECKING:
+    from pathlib import Path
+    from typing import Optional, Union
+
 # Types for DRC markers
 DRCEdge = tuple[Coordinate2D, Coordinate2D]
 DRCEdgePair = tuple[DRCEdge, DRCEdge]
@@ -199,13 +202,13 @@ def parse_violation_value(value: str, *, cell: str) -> DRCMarker:
 class DRCMarker(Tidy3dBaseModel):
     """Base marker storing the cell in which the violation was detected."""

-    cell: str = pd.Field(title="Cell", description="Cell name where the violation occurred.")
+    cell: str = Field(title="Cell", description="Cell name where the violation occurred.")


 class EdgeMarker(DRCMarker):
     """A class for storing KLayout DRC edge marker results."""

-    edge: DRCEdge = pd.Field(
+    edge: DRCEdge = Field(
         title="DRC Edge Marker",
         description="The edge marker of the DRC violation. The format is ((x1, y1), (x2, y2)).",
     )
@@ -214,7 +217,7 @@ class EdgeMarker(DRCMarker):
 class EdgePairMarker(DRCMarker):
     """A class for storing KLayout DRC edge pair marker results."""

-    edge_pair: DRCEdgePair = pd.Field(
+    edge_pair: DRCEdgePair = Field(
         title="DRC Edge Pair Marker",
         description="The edge pair marker of the DRC violation. The format is (edge1, edge2), where an edge has format ((x1, y1), (x2, y2)).",
     )
@@ -223,7 +226,7 @@ class EdgePairMarker(DRCMarker):
 class MultiPolygonMarker(DRCMarker):
     """A class for storing KLayout DRC multi-polygon marker results."""

-    polygons: DRCMultiPolygon = pd.Field(
+    polygons: DRCMultiPolygon = Field(
         title="DRC Multi-Polygon Marker",
         description="The multi-polygon marker of the DRC violation. The format is (polygon1, polygon2, ...), where each polygon has format ((x1, y1), (x2, y2), ...).",
     )
@@ -232,10 +235,10 @@
 class DRCViolation(Tidy3dBaseModel):
     """A class for storing KLayout DRC violation results for a single category."""

-    category: str = pd.Field(
+    category: str = Field(
         title="DRC Violation Category", description="The category of the DRC violation."
     )
-    markers: tuple[DRCMarker, ...] = pd.Field(
+    markers: tuple[DRCMarker, ...] = Field(
         title="DRC Markers", description="Tuple of DRC markers in this category."
     )
@@ -272,7 +275,7 @@ def __str__(self) -> str:
 class DRCResults(Tidy3dBaseModel):
     """A class for loading and storing KLayout DRC results."""

-    violations_by_category: dict[str, DRCViolation] = pd.Field(
+    violations_by_category: dict[str, DRCViolation] = Field(
         title="DRC Violations", description="Dictionary of DRC violations by category."
     )
diff --git a/tidy3d/plugins/klayout/util.py b/tidy3d/plugins/klayout/util.py
index 35a7c9a06a..20ebc8b8b6 100644
--- a/tidy3d/plugins/klayout/util.py
+++ b/tidy3d/plugins/klayout/util.py
@@ -4,10 +4,13 @@
 import platform
 from pathlib import Path
 from shutil import which
-from typing import Union
+from typing import TYPE_CHECKING

 import tidy3d as td

+if TYPE_CHECKING:
+    from typing import Union
+

 def check_installation(raise_error: bool = False) -> Union[str, None]:
     """Return the path to the KLayout executable if it is installed.
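The other recurring change in these files is moving typing-only imports under an `if TYPE_CHECKING:` block, which is safe here because every touched module starts with `from __future__ import annotations`, making all annotations lazily-evaluated strings. A minimal sketch of the guard (illustrative only; the function and use of `pathlib` are made up for this example):

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # seen only by static type checkers such as mypy; skipped at runtime,
        # which trims import cost and breaks potential circular imports
        from pathlib import Path

    def read_config(path: Path) -> str:
        # `Path` is just a string annotation at runtime, so no import is needed
        return path.read_text()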
diff --git a/tidy3d/plugins/microwave/array_factor.py b/tidy3d/plugins/microwave/array_factor.py
index 446c8fa292..373d7d96b8 100644
--- a/tidy3d/plugins/microwave/array_factor.py
+++ b/tidy3d/plugins/microwave/array_factor.py
@@ -3,37 +3,53 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
-import pydantic.v1 as pd
-from pydantic.v1 import NonNegativeFloat, PositiveInt, conint
+from pydantic import (
+    Field,
+    NonNegativeFloat,
+    PositiveFloat,
+    PositiveInt,
+    conint,
+    field_validator,
+    model_validator,
+)
 from scipy.signal.windows import blackman, blackmanharris, chebwin, hamming, hann, kaiser, taylor
 from scipy.special import j0, jn_zeros

-from tidy3d.components.base import skip_if_fields_missing
-from tidy3d.components.data.monitor_data import AbstractFieldProjectionData, DirectivityData
+from tidy3d.components.data.monitor_data import DirectivityData
 from tidy3d.components.data.sim_data import SimulationData
-from tidy3d.components.geometry.base import Box, Geometry
-from tidy3d.components.grid.grid_spec import GridSpec, LayerRefinementSpec
+from tidy3d.components.geometry.base import Box
+from tidy3d.components.grid.grid_spec import LayerRefinementSpec
 from tidy3d.components.lumped_element import LumpedElement
-from tidy3d.components.medium import Medium, MediumType3D
+from tidy3d.components.medium import Medium
 from tidy3d.components.microwave.base import MicrowaveBaseModel
 from tidy3d.components.monitor import AbstractFieldProjectionMonitor
-from tidy3d.components.simulation import Simulation
-from tidy3d.components.source.utils import SourceType
 from tidy3d.components.structure import MeshOverrideStructure, Structure
-from tidy3d.components.types import TYPE_TAG_STR, ArrayLike, Axis, Bound, Undefined
-from tidy3d.components.types.monitor import MonitorType
+from tidy3d.components.types import TYPE_TAG_STR, ArrayLike, Undefined
 from tidy3d.constants import C_0, inf
 from tidy3d.exceptions import Tidy3dNotImplementedError
 from tidy3d.log import log

+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+    from tidy3d.compat import Self
+    from tidy3d.components.data.monitor_data import AbstractFieldProjectionData
+    from tidy3d.components.geometry.base import Geometry
+    from tidy3d.components.grid.grid_spec import GridSpec
+    from tidy3d.components.medium import MediumType3D
+    from tidy3d.components.simulation import Simulation
+    from tidy3d.components.source.utils import SourceType
+    from tidy3d.components.types import Axis, Bound
+    from tidy3d.components.types.monitor import MonitorType
+

 class AbstractAntennaArrayCalculator(MicrowaveBaseModel, ABC):
     """Abstract base for phased array calculators."""

-    taper: Union[RectangularTaper, RadialTaper] = pd.Field(
+    taper: Optional[Union[RectangularTaper, RadialTaper]] = Field(
         None,
         discriminator=TYPE_TAG_STR,
         title="Antenna Array Taper",
@@ -76,7 +92,7 @@ def _antenna_nominal_center(self) -> ArrayLike:

         return 0.5 * (rmax + rmin)

-    def _detect_antenna_bounds(self, simulation: Simulation):
+    def _detect_antenna_bounds(self, simulation: Simulation) -> Bound:
         """Detect the bounds of the antenna in the simulation."""
         # directions in which we will need to tile simulation
         extend_dims = self._extend_dims
@@ -129,7 +145,7 @@ def _detect_antenna_bounds(self, simulation: Simulation):

     def _try_to_expand_geometry(
         self, geometry: Geometry, old_sim_bounds: Bound, new_sim_bounds: Bound
-    ):
+    ) -> Geometry:
         """Try to expand geometry to cover the entire simulation domain."""

         can_expand = isinstance(geometry, Box) and all(
@@ -177,7 +193,7 @@ def _duplicate_or_expand_list_of_objects(
         ],
         old_sim_bounds: Bound,
         new_sim_bounds: Bound,
-    ):
+    ) -> list[Union[Structure, MeshOverrideStructure, LayerRefinementSpec, LumpedElement]]:
         """Duplicate or expand a list of objects."""

         locations = self._antenna_locations
@@ -245,7 +261,7 @@ def _expand_monitors(
         antenna_bounds: Bound,
         new_sim_bounds: Bound,
         old_sim_bounds: Bound,
-    ):
+    ) -> list[MonitorType]:
         """Expand monitors."""

         extend_dims = self._extend_dims
@@ -314,7 +330,7 @@ def _expand_monitors(

     def _duplicate_structures(
         self, structures: tuple[Structure, ...], new_sim_bounds: Bound, old_sim_bounds: Bound
-    ):
+    ) -> list[Structure]:
         """Duplicate structures."""

         return self._duplicate_or_expand_list_of_objects(
@@ -327,7 +343,7 @@ def _duplicate_sources(
         lumped_elements: tuple[LumpedElement, ...],
         old_sim_bounds: Bound,
         new_sim_bounds: Bound,
-    ):
+    ) -> tuple[list[SourceType], list[LumpedElement]]:
         """Duplicate sources and lumped elements."""
         array_lumped_elements = self._duplicate_or_expand_list_of_objects(
             objects=lumped_elements, old_sim_bounds=old_sim_bounds, new_sim_bounds=new_sim_bounds
@@ -353,7 +369,7 @@ def _duplicate_sources(

     def _duplicate_grid_specs(
         self, grid_spec: GridSpec, old_sim_bounds: Bound, new_sim_bounds: Bound
-    ):
+    ) -> GridSpec:
         """Duplicate grid specs."""

         array_overrides = self._duplicate_or_expand_list_of_objects(
@@ -370,12 +386,12 @@ def _duplicate_grid_specs(
         array_snapping_points = []
         for translation_vector in self._antenna_locations:
             for snapping_point in grid_spec.snapping_points:
-                new_snapping_point = [
+                new_snapping_point = tuple(
                     snapping_point[dim] + translation_vector[dim]
                     if snapping_point[dim] is not None
                     else None
                     for dim in range(3)
-                ]
+                )
                 array_snapping_points.append(new_snapping_point)

         return grid_spec.updated_copy(
@@ -731,35 +747,49 @@ class RectangularAntennaArrayCalculator(AbstractAntennaArrayCalculator):
     ...     )
     """

-    array_size: tuple[PositiveInt, PositiveInt, PositiveInt] = pd.Field(
+    array_size: tuple[PositiveInt, PositiveInt, PositiveInt] = Field(
         title="Array Size",
         description="Number of antennas along x, y, and z directions.",
     )

-    spacings: tuple[NonNegativeFloat, NonNegativeFloat, NonNegativeFloat] = pd.Field(
+    spacings: tuple[NonNegativeFloat, NonNegativeFloat, NonNegativeFloat] = Field(
         title="Antenna Spacings",
         description="Center-to-center spacings between antennas along x, y, and z directions.",
     )

-    phase_shifts: tuple[float, float, float] = pd.Field(
+    phase_shifts: tuple[float, float, float] = Field(
         (0, 0, 0),
         title="Phase Shifts",
         description="Phase-shifts between antennas along x, y, and z directions.",
     )

-    amp_multipliers: tuple[Optional[ArrayLike], Optional[ArrayLike], Optional[ArrayLike]] = (
-        pd.Field(
-            (None, None, None),
-            title="Amplitude Multipliers",
-            description="Amplitude multipliers spatially distributed along x, y, and z directions.",
-        )
+    amp_multipliers: tuple[Optional[ArrayLike], Optional[ArrayLike], Optional[ArrayLike]] = Field(
+        (None, None, None),
+        title="Amplitude Multipliers",
+        description="Amplitude multipliers spatially distributed along x, y, and z directions.",
     )

-    @pd.validator("amp_multipliers", pre=True, always=True)
-    @skip_if_fields_missing(["array_size"])
-    def _check_amp_multipliers(cls, val, values):
+    @field_validator("array_size", "spacings", "phase_shifts", "amp_multipliers", mode="before")
+    @classmethod
+    def _convert_list_to_tuple(cls, v: Any) -> Any:
+        """Convert lists to tuples for tuple fields."""
+        if isinstance(v, list):
+            return tuple(v)
+        # Handle numpy arrays
+        try:
+            import numpy as np
+
+            if isinstance(v, np.ndarray):
+                return tuple(v.tolist())
+        except ImportError:
+            pass
+        return v
+
+    @model_validator(mode="after")
+    def _check_amp_multipliers(self) -> Self:
         """Check that the length of the amplitude multipliers is equal to the array size along each dimension."""
-        array_size = values.get("array_size")
+        val = self.amp_multipliers
+        array_size = self.array_size
         if len(val) != 3:
             raise ValueError("'amp_multipliers' must have 3 elements.")
         if val[0] is not None and len(val[0]) != array_size[0]:
@@ -774,7 +804,7 @@ def _check_amp_multipliers(cls, val, values):
             raise ValueError(
                 f"'amp_multipliers' has length of {len(val[2])} along the z direction, but the array size is {array_size[2]}."
             )
-        return val
+        return self

     @property
     def _antenna_locations(self) -> ArrayLike:
@@ -999,7 +1029,7 @@ def _get_weights_discrete(self, N: int) -> ArrayLike:
 class ChebWindow(AbstractWindow):
     """Standard Chebyshev window for tapering with configurable sidelobe attenuation."""

-    attenuation: pd.PositiveFloat = pd.Field(
+    attenuation: PositiveFloat = Field(
         default=30,
         title="Attenuation",
         description="Desired attenuation level of sidelobes.",
@@ -1026,7 +1056,7 @@ def _get_weights_discrete(self, N: int) -> ArrayLike:
 class KaiserWindow(AbstractWindow):
     """Class for Kaiser window."""

-    beta: pd.NonNegativeFloat = pd.Field(
+    beta: NonNegativeFloat = Field(
         ...,
         title="Shape Parameter",
         description="Shape parameter, determines trade-off between main-lobe width and side lobe level.",
@@ -1052,20 +1082,20 @@ def _get_weights_discrete(self, N: int) -> ArrayLike:
 class TaylorWindow(AbstractWindow):
     """Taylor window with configurable sidelobe suppression and sidelobe count."""

-    sll: pd.PositiveFloat = pd.Field(
+    sll: PositiveFloat = Field(
         default=30,
         title="Sidelobe Suppression Level",
         description="Desired suppression of sidelobe level relative to the DC gain.",
         units="dB",
     )

-    nbar: conint(gt=0, le=10) = pd.Field(
+    nbar: conint(gt=0, le=10) = Field(
         default=4,
         title="Number of Nearly Constant Sidelobes",
         description="Number of nearly constant level sidelobes adjacent to the mainlobe.",
     )

-    def _get_weights_discrete(self, N):
+    def _get_weights_discrete(self, N: int) -> NDArray:
         """
         Generate a 1D Taylor window of length N.
@@ -1081,7 +1111,7 @@ def _get_weights_discrete(self, N):
         """
         return taylor(N, self.nbar, self.sll)

-    def _get_exp_weights(self, mus: np.ndarray):
+    def _get_exp_weights(self, mus: NDArray) -> NDArray:
         """
         Compute expansion coefficients B_l for the circular Taylor taper.
@@ -1182,21 +1212,21 @@ def amp_multipliers(
 class RectangularTaper(AbstractTaper):
     """Class for rectangular taper."""

-    window_x: Optional[RectangularWindowType] = pd.Field(
+    window_x: Optional[RectangularWindowType] = Field(
         None,
         title="X Axis Window",
         description="Window type used to taper array antenna along x axis.",
         discriminator=TYPE_TAG_STR,
     )

-    window_y: Optional[RectangularWindowType] = pd.Field(
+    window_y: Optional[RectangularWindowType] = Field(
         None,
         title="Y Axis Window",
         description="Window type used to taper array antenna along y axis.",
         discriminator=TYPE_TAG_STR,
     )

-    window_z: Optional[RectangularWindowType] = pd.Field(
+    window_z: Optional[RectangularWindowType] = Field(
         None,
         title="Z Axis Window",
         description="Window type used to taper array antenna along z axis.",
@@ -1220,11 +1250,11 @@ def from_isotropic_window(cls, window: RectangularWindowType) -> RectangularTape
         """
         return cls(window_x=window, window_y=window, window_z=window)

-    @pd.root_validator
-    def check_at_least_one_window(cls, values):
-        if not any([values.get("window_x"), values.get("window_y"), values.get("window_z")]):
+    @model_validator(mode="after")
+    def check_at_least_one_window(self) -> Self:
+        if not any([self.window_x, self.window_y, self.window_z]):
             raise ValueError("At least one window (x, y, or z) must be provided.")
-        return values
+        return self

     def amp_multipliers(
         self, array_size: tuple[PositiveInt, PositiveInt, PositiveInt]
@@ -1258,7 +1288,7 @@ def amp_multipliers(
 class RadialTaper(AbstractTaper):
     """Class for Radial Taper."""

-    window: TaylorWindow = pd.Field(
+    window: TaylorWindow = Field(
         ..., title="Window Object", description="Window type used to taper array antenna."
) @@ -1296,4 +1326,4 @@ def amp_multipliers( return (amps,) -RectangularAntennaArrayCalculator.update_forward_refs() +RectangularAntennaArrayCalculator.model_rebuild() diff --git a/tidy3d/plugins/microwave/lobe_measurer.py b/tidy3d/plugins/microwave/lobe_measurer.py index 18985ac7c9..bdf7c0c2e6 100644 --- a/tidy3d/plugins/microwave/lobe_measurer.py +++ b/tidy3d/plugins/microwave/lobe_measurer.py @@ -3,20 +3,26 @@ from __future__ import annotations from math import isclose, isnan -from typing import Optional +from typing import TYPE_CHECKING import numpy as np -import pydantic.v1 as pd from pandas import DataFrame +from pydantic import Field, field_validator, model_validator -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import cached_property from tidy3d.components.microwave.base import MicrowaveBaseModel -from tidy3d.components.types import ArrayFloat1D, ArrayLike, Ax +from tidy3d.components.types import ArrayFloat1D from tidy3d.constants import fp_eps from tidy3d.exceptions import ValidationError from .viz import plot_params_lobe_FNBW, plot_params_lobe_peak, plot_params_lobe_width +if TYPE_CHECKING: + from typing import Optional + + from tidy3d.compat import Self + from tidy3d.components.types import ArrayLike, Ax + # The minimum plateau size for peak finding, which is set to 0 to ensure that all peaks are found. # A value must be provided to retrieve additional information from `find_peaks`. MIN_PLATEAU_SIZE = 0 @@ -39,21 +45,19 @@ class LobeMeasurer(MicrowaveBaseModel): >>> lobe_measures = lobe_measurer.lobe_measures """ - angle: ArrayFloat1D = pd.Field( - ..., + angle: ArrayFloat1D = Field( title="Angle", description="A 1-dimensional array of angles in radians. The angles should be " "in the range [0, 2π] and should be sorted in ascending order.", ) - radiation_pattern: ArrayFloat1D = pd.Field( - ..., + radiation_pattern: ArrayFloat1D = Field( title="Radiation Pattern", description="A 1-dimensional array of real values representing the radiation pattern " "of the antenna measured on a linear scale.", ) - apply_cyclic_extension: bool = pd.Field( + apply_cyclic_extension: bool = Field( True, title="Apply Cyclic Extension", description="To enable accurate peak finding near boundaries of the ``angle`` array, " @@ -61,7 +65,7 @@ class LobeMeasurer(MicrowaveBaseModel): "of interest, this can be set to ``False``.", ) - width_measure: float = pd.Field( + width_measure: float = Field( 0.5, gt=0.0, le=1.0, @@ -70,7 +74,7 @@ class LobeMeasurer(MicrowaveBaseModel): "Default value of ``0.5`` corresponds with the half-power beamwidth.", ) - min_lobe_height: float = pd.Field( + min_lobe_height: float = Field( DEFAULT_MIN_LOBE_REL_HEIGHT, gt=0.0, le=1.0, @@ -79,7 +83,7 @@ class LobeMeasurer(MicrowaveBaseModel): "Lobe heights are measured relative to the maximum value in ``radiation_pattern``.", ) - null_threshold: float = pd.Field( + null_threshold: float = Field( DEFAULT_NULL_THRESHOLD, gt=0.0, le=1.0, @@ -88,30 +92,31 @@ class LobeMeasurer(MicrowaveBaseModel): "which is relative to the maximum value in the ``radiation_pattern``.", ) - @pd.validator("angle", always=True) - def _sorted_angle(cls, val): + @field_validator("angle") + @classmethod + def _sorted_angle(cls, val: ArrayFloat1D) -> ArrayFloat1D: """Ensure the angle array is sorted.""" if not np.all(np.diff(val) >= 0): raise ValidationError("The angle array must be sorted in ascending order.") return val - @pd.validator("radiation_pattern", always=True) - def 
_nonnegative_radiation_pattern(cls, val): + @field_validator("radiation_pattern") + @classmethod + def _nonnegative_radiation_pattern(cls, val: ArrayFloat1D) -> ArrayFloat1D: """Ensure the radiation pattern is nonnegative.""" if not np.all(val >= 0): raise ValidationError("Radiation pattern must be nonnegative.") return val - @pd.validator("apply_cyclic_extension", always=True) - @skip_if_fields_missing(["angle"]) - def _cyclic_extension_valid(cls, val, values): - if val: - angle = values.get("angle") + @model_validator(mode="after") + def _cyclic_extension_valid(self) -> Self: + if self.apply_cyclic_extension: + angle = self.angle if np.any(angle < 0) or np.any(angle > 2 * np.pi): raise ValidationError( "When using cyclic extension, the angle array must be in the range [0, 2π]." ) - return val + return self @cached_property def lobe_measures(self) -> DataFrame: diff --git a/tidy3d/plugins/polyslab/polyslab.py b/tidy3d/plugins/polyslab/polyslab.py index a5dfa3cefd..c7b34a95ed 100644 --- a/tidy3d/plugins/polyslab/polyslab.py +++ b/tidy3d/plugins/polyslab/polyslab.py @@ -2,10 +2,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from tidy3d.components.geometry.polyslab import ComplexPolySlabBase -from tidy3d.components.medium import MediumType from tidy3d.components.structure import Structure +if TYPE_CHECKING: + from tidy3d.components.medium import MediumType + class ComplexPolySlab(ComplexPolySlabBase): """Interface for dividing a complex polyslab where self-intersecting polygon can diff --git a/tidy3d/plugins/pytorch/wrapper.py b/tidy3d/plugins/pytorch/wrapper.py index d6716dca3a..682445bc65 100644 --- a/tidy3d/plugins/pytorch/wrapper.py +++ b/tidy3d/plugins/pytorch/wrapper.py @@ -1,13 +1,18 @@ from __future__ import annotations import inspect -from typing import Any +from typing import TYPE_CHECKING, Any import torch from autograd import make_vjp +if TYPE_CHECKING: + from typing import Callable -def to_torch(fun): + from torch.autograd.function import FunctionCtx + + +def to_torch(fun: Callable[..., Any]) -> Callable[..., torch.Tensor]: """ Converts an autograd function to a PyTorch function. 
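The wrapper edited here turns an autograd-differentiable function into a torch.autograd.Function, so gradients computed by autograd's VJP machinery can flow through a PyTorch graph. A minimal usage sketch, not taken from the patch itself; it assumes `to_torch` is imported from tidy3d.plugins.pytorch.wrapper (the module this hunk edits), and `f` is an arbitrary user-defined function rather than tidy3d API:

    import autograd.numpy as anp
    import torch

    from tidy3d.plugins.pytorch.wrapper import to_torch

    @to_torch
    def f(x):
        # any computation that autograd can differentiate
        return anp.sum(anp.sin(x) ** 2)

    x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
    f(x).backward()  # backward() routes through the VJP stored in forward()
    print(x.grad)    # elementwise 2 * sin(x) * cos(x)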
@@ -49,7 +54,7 @@ class _Wrapper(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, *args: Any):
+    def forward(ctx: FunctionCtx, *args: Any) -> torch.Tensor:
         numpy_args = []
         grad_argnums = []
@@ -78,7 +83,9 @@ def forward(ctx, *args: Any):
         return torch.as_tensor(ans, device=device)

     @staticmethod
-    def backward(ctx, grad_output):
+    def backward(
+        ctx: FunctionCtx, grad_output: torch.Tensor
+    ) -> tuple[torch.Tensor | None, ...]:
         numpy_grad_output = grad_output.detach().cpu().numpy()
         _grads = ctx.vjp(numpy_grad_output)
         grads = [None] * ctx.num_args
@@ -86,7 +93,7 @@ def forward(ctx, *args: Any):
             grads[idx] = torch.as_tensor(grad, device=ctx.device)
         return tuple(grads)

-    def apply(*args: Any, **kwargs: Any):
+    def apply(*args: Any, **kwargs: Any) -> torch.Tensor:
         # we bind the full function signature including defaults so that we can pass
         # all values as positional since torch.autograd.Function.apply only accepts
         # positional arguments
diff --git a/tidy3d/plugins/resonance/resonance.py b/tidy3d/plugins/resonance/resonance.py
index 0f49c0f0c9..fa415ad705 100644
--- a/tidy3d/plugins/resonance/resonance.py
+++ b/tidy3d/plugins/resonance/resonance.py
@@ -3,20 +3,25 @@
 from __future__ import annotations

 from functools import partial
-from typing import Union
+from typing import TYPE_CHECKING, Optional

 import numpy as np
 import xarray as xr
-from pydantic.v1 import Field, NonNegativeFloat, PositiveInt, validator
+from pydantic import Field, NonNegativeFloat, PositiveInt, field_validator

 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.components.data.data_array import ScalarFieldTimeDataArray
 from tidy3d.components.data.monitor_data import FieldTimeData
-from tidy3d.components.types import ArrayComplex1D, ArrayComplex2D, ArrayComplex3D, ArrayFloat1D
+from tidy3d.components.types import ArrayComplex1D, ArrayFloat1D
 from tidy3d.constants import HERTZ
 from tidy3d.exceptions import SetupError, ValidationError
 from tidy3d.log import log

+if TYPE_CHECKING:
+    from typing import Literal, Union
+
+    from tidy3d.components.types import ArrayComplex2D, ArrayComplex3D
+
 INIT_NUM_FREQS = 200

 TIME_STEP_RTOL = 1e-5
@@ -28,12 +33,19 @@ class ResonanceData(Tidy3dBaseModel):
     """Data class for storing objects computed while running the resonance finder."""

-    eigvals: ArrayComplex1D = Field(..., title="Eigenvalues", description="Resonance eigenvalues.")
-    complex_amplitudes: ArrayComplex1D = Field(
-        None, title="Complex amplitudes", description="Complex resonance amplitudes"
+    eigvals: ArrayComplex1D = Field(
+        title="Eigenvalues",
+        description="Resonance eigenvalues.",
     )
-    errors: ArrayFloat1D = Field(
-        None, title="Errors", description="Rough eigenvalue error estimate."
+    complex_amplitudes: Optional[ArrayComplex1D] = Field(
+        None,
+        title="Complex amplitudes",
+        description="Complex resonance amplitudes",
+    )
+    errors: Optional[ArrayFloat1D] = Field(
+        None,
+        title="Errors",
+        description="Rough eigenvalue error estimate.",
     )
@@ -71,7 +83,6 @@ class ResonanceFinder(Tidy3dBaseModel):
     """

     freq_window: tuple[float, float] = Field(
-        ...,
         title="Window ``[fmin, fmax]``",
         description="Window ``[fmin, fmax]`` for the initial frequencies. "
         "The resonance finder is initialized with an even grid of frequencies between "
@@ -102,8 +113,9 @@
         "Making this closer to zero will typically return more resonances.",
     )

-    @validator("freq_window", always=True)
-    def _check_freq_window(cls, val):
+    @field_validator("freq_window")
+    @classmethod
+    def _check_freq_window(cls, val: tuple[float, float]) -> tuple[float, float]:
         """Validate ``freq_window``"""
         if val[1] < val[0]:
             raise ValidationError(
@@ -165,7 +177,7 @@ def run_raw_signal(self, signal: list[complex], time_step: float) -> xr.Dataset:

         Parameters
         ----------
-        signal : List[complex]
+        signal : list[complex]
             One-dimensional array holding the complex-valued time series data to search for resonances.
         time_step : float
@@ -228,7 +240,9 @@
         return np.squeeze(signal.data), dt

     def _aggregate_field_time_comps(
-        self, signals: tuple[FieldTimeData, ...], comps
+        self,
+        signals: tuple[FieldTimeData, ...],
+        comps: list[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]],
     ) -> ScalarFieldTimeDataArray:
         """Aggregates the given components from several :class:`.FieldTimeData`."""
         total_signal = None
diff --git a/tidy3d/plugins/smatrix/analysis/antenna.py b/tidy3d/plugins/smatrix/analysis/antenna.py
index 1a60f37521..58259df92f 100644
--- a/tidy3d/plugins/smatrix/analysis/antenna.py
+++ b/tidy3d/plugins/smatrix/analysis/antenna.py
@@ -1,13 +1,17 @@
 from __future__ import annotations

-from typing import Optional
+from typing import TYPE_CHECKING

 import numpy as np

 from tidy3d.components.microwave.data.monitor_data import AntennaMetricsData
 from tidy3d.plugins.smatrix.data.data_array import PortDataArray
-from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData
-from tidy3d.plugins.smatrix.types import NetworkIndex
+
+if TYPE_CHECKING:
+    from typing import Optional
+
+    from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData
+    from tidy3d.plugins.smatrix.types import NetworkIndex


 def get_antenna_metrics_data(
@@ -52,7 +56,7 @@ def get_antenna_metrics_data(
     # Use the first port as default if none specified
     if port_amplitudes is None:
         first_port_index = terminal_component_modeler_data.modeler.matrix_indices_source[0]
-        port_amplitudes = {first_port_index: None}
+        port_amplitudes = {first_port_index: None}  # type: ignore[dict-item]
     # Get the radiation monitor, use first as default
     # if none specified
     if monitor_name is None:
diff --git a/tidy3d/plugins/smatrix/analysis/modal.py b/tidy3d/plugins/smatrix/analysis/modal.py
index 147b472b74..03c2da7a92 100644
--- a/tidy3d/plugins/smatrix/analysis/modal.py
+++ b/tidy3d/plugins/smatrix/analysis/modal.py
@@ -4,10 +4,14 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 import numpy as np

 from tidy3d.plugins.smatrix.data.data_array import ModalPortDataArray
-from tidy3d.plugins.smatrix.data.modal import ModalComponentModelerData
+
+if TYPE_CHECKING:
+    from tidy3d.plugins.smatrix.data.modal import ModalComponentModelerData


 def modal_construct_smatrix(modeler_data: ModalComponentModelerData) -> ModalPortDataArray:
diff --git a/tidy3d/plugins/smatrix/analysis/terminal.py b/tidy3d/plugins/smatrix/analysis/terminal.py
index ea387d5f66..684e5aaf70 100644
--- a/tidy3d/plugins/smatrix/analysis/terminal.py
+++ b/tidy3d/plugins/smatrix/analysis/terminal.py
@@ -14,14 +14,12 @@
 from __future__ import annotations

+from typing import TYPE_CHECKING
+
 import numpy as np

-from tidy3d.components.data.sim_data import SimulationData
-from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
-from tidy3d.plugins.smatrix.data.data_array import PortDataArray, TerminalPortDataArray
-from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData
+from tidy3d.plugins.smatrix.data.data_array import PortDataArray
 from tidy3d.plugins.smatrix.ports.wave import WavePort
-from tidy3d.plugins.smatrix.types import SParamDef
 from tidy3d.plugins.smatrix.utils import (
     ab_to_s,
     check_port_impedance_sign,
@@ -29,6 +27,13 @@
     compute_port_VI,
 )

+if TYPE_CHECKING:
+    from tidy3d.components.data.sim_data import SimulationData
+    from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
+    from tidy3d.plugins.smatrix.data.data_array import TerminalPortDataArray
+    from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData
+    from tidy3d.plugins.smatrix.types import SParamDef
+

 def terminal_construct_smatrix(
     modeler_data: TerminalComponentModelerData,
diff --git a/tidy3d/plugins/smatrix/component_modelers/base.py b/tidy3d/plugins/smatrix/component_modelers/base.py
index 49f04ca377..959705b0e3 100644
--- a/tidy3d/plugins/smatrix/component_modelers/base.py
+++ b/tidy3d/plugins/smatrix/component_modelers/base.py
@@ -5,9 +5,9 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Literal, Optional, Union

-import pydantic.v1 as pd
+from pydantic import Field, field_validator, model_validator

-from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing
+from tidy3d.components.base import Tidy3dBaseModel, cached_property
 from tidy3d.components.geometry.utils import _shift_value_signed
 from tidy3d.components.simulation import Simulation
 from tidy3d.components.source.time import SourceTimeType
@@ -23,13 +23,18 @@
 from tidy3d.exceptions import SetupError, Tidy3dKeyError
 from tidy3d.log import log
 from tidy3d.plugins.smatrix.ports.modal import Port
-from tidy3d.plugins.smatrix.ports.types import LumpedPortType, PortType, TerminalPortType
+from tidy3d.plugins.smatrix.ports.types import LumpedPortType, TerminalPortType
 from tidy3d.plugins.smatrix.ports.wave import WavePort
 from tidy3d.plugins.smatrix.types import Element, MatrixIndex, NetworkElement, NetworkIndex

 if TYPE_CHECKING:
-    from tidy3d.web.core.types import PayType
+    from pydantic import ValidationInfo
+
+    from tidy3d.compat import Self
+    from tidy3d.plugins.smatrix import MicrowaveSMatrixData
+    from tidy3d.plugins.smatrix.ports.modal import ModalPortDataArray
+    from tidy3d.plugins.smatrix.ports.types import PortType
+    from tidy3d.web.core.types import PayType

 # fwidth of gaussian pulse in units of central frequency
 FWIDTH_FRAC = 1.0 / 10
 DEFAULT_DATA_DIR = "."
@@ -42,32 +47,30 @@
 class AbstractComponentModeler(ABC, Tidy3dBaseModel):
     """Tool for modeling devices and computing port parameters."""

-    name: str = pd.Field(
+    name: str = Field(
         "",
         title="Name",
     )

-    simulation: Simulation = pd.Field(
-        ...,
+    simulation: Simulation = Field(
         title="Simulation",
         description="Simulation describing the device without any sources present.",
     )

-    ports: tuple[Union[Port, TerminalPortType], ...] = pd.Field(
+    ports: tuple[Union[Port, TerminalPortType], ...] = Field(
         (),
         title="Ports",
         description="Collection of ports describing the scattering matrix elements. "
         "For each input mode, one simulation will be run with a modal source.",
     )

-    freqs: FreqArray = pd.Field(
-        ...,
+    freqs: FreqArray = Field(
         title="Frequencies",
         description="Array or list of frequencies at which to compute port parameters.",
         units=HERTZ,
     )

-    remove_dc_component: bool = pd.Field(
+    remove_dc_component: bool = Field(
         True,
         title="Remove DC Component",
         description="Whether to remove the DC component in the Gaussian pulse spectrum. "
@@ -78,7 +81,7 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel):
         "pulse spectrum which can have a nonzero DC component.",
     )

-    run_only: Optional[tuple[IndexType, ...]] = pd.Field(
+    run_only: Optional[tuple[IndexType, ...]] = Field(
         None,
         title="Run Only",
         description="Set of matrix indices that define the simulations to run. "
@@ -86,7 +89,7 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel):
         "If a tuple is given, simulations will be run only for the given matrix indices.",
     )

-    element_mappings: tuple[tuple[ElementType, ElementType, Complex], ...] = pd.Field(
+    element_mappings: tuple[tuple[ElementType, ElementType, Complex], ...] = Field(
         (),
         title="Element Mappings",
         description="Tuple of S matrix element mappings, each described by a tuple of "
@@ -95,34 +98,41 @@ class AbstractComponentModeler(ABC, Tidy3dBaseModel):
         "matrix element. If all elements of a given column of the scattering matrix are defined "
         "by ``element_mappings``, the simulation corresponding to this column is skipped automatically.",
     )

-    custom_source_time: Optional[SourceTimeType] = pd.Field(
+    custom_source_time: Optional[SourceTimeType] = Field(
         None,
         title="Custom Source Time",
         description="If provided, this will be used as specification of the source time-dependence in simulations. "
         "Otherwise, a default source time will be constructed.",
     )

-    @pd.root_validator(pre=False)
-    def _warn_refactor_2_10(cls, values):
+    @model_validator(mode="before")
+    @classmethod
+    def _warn_refactor_2_10(cls, data: dict) -> dict:
         log.warning(
             f"'{cls.__name__}' was refactored (tidy3d 'v2.10.0'). Existing functionality is available differently. Please consult the migration documentation: https://docs.flexcompute.com/projects/tidy3d/en/latest/api/microwave/microwave_migration.html",
             log_once=True,
         )
-        return values
+        return data

-    @pd.validator("simulation", always=True)
-    def _sim_has_no_sources(cls, val):
+    @field_validator("simulation")
+    @classmethod
+    def _sim_has_no_sources(cls, val: Simulation) -> Simulation:
         """Make sure simulation has no sources as they interfere with tool."""
         if len(val.sources) > 0:
             raise SetupError(f"'{cls.__name__}.simulation' must not have any sources.")
         return val

-    @pd.validator("element_mappings", always=True)
-    def _validate_element_mappings(cls, element_mappings, values):
+    @field_validator("element_mappings")
+    @classmethod
+    def _validate_element_mappings(
+        cls,
+        element_mappings: tuple[tuple[ElementType, ElementType, Complex], ...],
+        info: ValidationInfo,
+    ) -> tuple[tuple[ElementType, ElementType, Complex], ...]:
         """
         Validate that each source index referenced in element_mappings is included in run_only.
         """
-        run_only = values.get("run_only")
+        run_only = info.data.get("run_only")
         if run_only is None:
             return element_mappings
@@ -141,12 +151,12 @@ def _validate_element_mappings(cls, element_mappings, values):
             )
         return element_mappings

-    @pd.validator("run_only", always=True)
-    @skip_if_fields_missing(["ports"])
-    def _validate_run_only(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_run_only(self) -> Self:
         """Validate that run_only entries are unique and exist in matrix_indices_monitor."""
+        val = self.run_only
         if val is None:
-            return val
+            return self

         # Check uniqueness
         if len(val) != len(set(val)):
@@ -157,9 +167,9 @@ def _validate_run_only(cls, val, values):
             )

         # Check membership - use the helper method to get valid indices
-        ports = values["ports"]
+        ports = self.ports

-        valid_indices = set(cls._construct_matrix_indices_monitor(ports))
+        valid_indices = set(self._construct_matrix_indices_monitor(ports))
         invalid_indices = [idx for idx in val if idx not in valid_indices]

         if invalid_indices:
@@ -168,26 +178,26 @@ def _validate_run_only(cls, val, values):
                 f"'matrix_indices_monitor'. Valid indices are: {sorted(valid_indices)}"
             )

-        return val
+        return self

     _freqs_not_empty = validate_freqs_not_empty()
     _freqs_lower_bound = validate_freqs_min()
     _freqs_unique = validate_freqs_unique()

-    @pd.validator("custom_source_time", always=True)
-    @skip_if_fields_missing(["freqs"])
-    def _freqs_in_custom_source_time(cls, val, values):
+    @model_validator(mode="after")
+    def _freqs_in_custom_source_time(self) -> Self:
         """Make sure freqs is in the range of the custom source time."""
+        val = self.custom_source_time
         if val is None:
-            return val
+            return self

         freq_range = val._frequency_range_sigma_cached
-        freqs = values["freqs"]
+        freqs = self.freqs
         if freq_range[0] > min(freqs) or max(freqs) > freq_range[1]:
             log.warning(
                 "Custom source time does not cover all 'freqs'.",
             )
-        return val
+        return self

     @staticmethod
     def get_task_name(port: PortType, mode_index: Optional[int] = None) -> str:
@@ -325,7 +335,7 @@ def run(
         priority: Optional[int] = None,
         local_gradient: bool = False,
         max_num_adjoint_per_fwd: Optional[int] = None,
-    ):
+    ) -> Union[ModalPortDataArray, MicrowaveSMatrixData]:
         log.warning(
             "'ComponentModeler.run()' is deprecated and will be removed in a future release. "
             "Use web.run(modeler) instead. 'web.run' returns a 'ComponentModelerData' object; "
@@ -351,9 +361,9 @@ def run(
         )
         return data.smatrix()

-    def validate_pre_upload(self):
+    def validate_pre_upload(self: Self) -> None:
         """Validate the modeler before upload."""
         self.base_sim.validate_pre_upload(source_required=False)

-AbstractComponentModeler.update_forward_refs()
+AbstractComponentModeler.model_rebuild()
diff --git a/tidy3d/plugins/smatrix/component_modelers/modal.py b/tidy3d/plugins/smatrix/component_modelers/modal.py
index f75c98bc81..ea0d8a8124 100644
--- a/tidy3d/plugins/smatrix/component_modelers/modal.py
+++ b/tidy3d/plugins/smatrix/component_modelers/modal.py
@@ -2,31 +2,37 @@
 from __future__ import annotations

-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import autograd.numpy as np
-import pydantic.v1 as pd
+from pydantic import Field

 from tidy3d.components.base import cached_property
-from tidy3d.components.data.sim_data import SimulationData
 from tidy3d.components.index import SimulationMap
 from tidy3d.components.monitor import ModeMonitor
 from tidy3d.components.source.field import ModeSource
 from tidy3d.components.source.time import GaussianPulse
-from tidy3d.components.types import Ax, Complex
+from tidy3d.components.types import Complex
 from tidy3d.components.viz import add_ax_if_none, equal_aspect
 from tidy3d.plugins.smatrix.ports.modal import Port
 from tidy3d.plugins.smatrix.types import Element, MatrixIndex

 from .base import FWIDTH_FRAC, AbstractComponentModeler

+if TYPE_CHECKING:
+    from tidy3d.components.data.sim_data import SimulationData
+    from tidy3d.components.simulation import Simulation
+    from tidy3d.components.types import Ax
+

 class ModalComponentModeler(AbstractComponentModeler):
     """A tool for modeling devices and computing scattering matrix elements.

-    This class orchestrates the process of running multiple simulations to
-    derive the scattering matrix (S-matrix) of a component. It uses modal
-    sources and monitors defined by a set of ports.
+    Notes
+    -----
+    This class orchestrates the process of running multiple simulations to
+    derive the scattering matrix (S-matrix) of a component. It uses modal
+    sources and monitors defined by a set of ports.

     See Also
     --------
@@ -34,14 +40,14 @@ class ModalComponentModeler(AbstractComponentModeler):
     * `Computing the scattering matrix of a device <../../notebooks/SMatrix.html>`_
     """

-    ports: tuple[Port, ...] = pd.Field(
+    ports: tuple[Port, ...] = Field(
         (),
         title="Ports",
         description="Collection of ports describing the scattering matrix elements. "
         "For each input mode, one simulation will be run with a modal source.",
     )

-    run_only: Optional[tuple[MatrixIndex, ...]] = pd.Field(
+    run_only: Optional[tuple[MatrixIndex, ...]] = Field(
         None,
         title="Run Only",
         description="Set of matrix indices that define the simulations to run. "
@@ -49,7 +55,7 @@ class ModalComponentModeler(AbstractComponentModeler):
         "If a tuple is given, simulations will be run only for the given matrix indices.",
     )

-    element_mappings: tuple[tuple[Element, Element, Complex], ...] = pd.Field(
+    element_mappings: tuple[tuple[Element, Element, Complex], ...] = Field(
         (),
         title="Element Mappings",
         description="Tuple of S matrix element mappings, each described by a tuple of "
@@ -60,7 +66,7 @@ class ModalComponentModeler(AbstractComponentModeler):
     )

     @property
-    def base_sim(self):
+    def base_sim(self) -> Simulation:
         """The base simulation."""
         return self.simulation
@@ -186,7 +192,7 @@ def to_monitor(self, port: Port) -> ModeMonitor:

     def to_source(
         self, port: Port, mode_index: int, num_freqs: int = 1, **kwargs: Any
-    ) -> list[ModeSource]:
+    ) -> ModeSource:
         """Creates a mode source from a given port.

         This source is used to excite a specific mode at the port.
@@ -281,7 +287,7 @@ def plot_sim(
         for port_source in self.ports:
             mode_source_0 = self.to_source(port=port_source, mode_index=0)
             plot_sources.append(mode_source_0)
-        sim_plot = self.simulation.copy(update={"sources": plot_sources})
+        sim_plot = self.simulation.copy(update={"sources": tuple(plot_sources)})
         return sim_plot.plot(x=x, y=y, z=z, ax=ax)

     @equal_aspect
@@ -322,7 +328,7 @@ def plot_sim_eps(
         for port_source in self.ports:
             mode_source_0 = self.to_source(port=port_source, mode_index=0)
             plot_sources.append(mode_source_0)
-        sim_plot = self.simulation.copy(update={"sources": plot_sources})
+        sim_plot = self.simulation.copy(update={"sources": tuple(plot_sources)})
         return sim_plot.plot_eps(x=x, y=y, z=z, ax=ax, **kwargs)

     def _normalization_factor(self, port_source: Port, sim_data: SimulationData) -> complex:
diff --git a/tidy3d/plugins/smatrix/component_modelers/terminal.py b/tidy3d/plugins/smatrix/component_modelers/terminal.py
index e9be3e1596..f64ed56269 100644
--- a/tidy3d/plugins/smatrix/component_modelers/terminal.py
+++ b/tidy3d/plugins/smatrix/component_modelers/terminal.py
@@ -2,13 +2,13 @@
 from __future__ import annotations

-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, NonNegativeInt, field_validator, model_validator

 from tidy3d import ClipOperation, GeometryGroup, GridSpec, PolySlab
-from tidy3d.components.base import cached_property, skip_if_fields_missing
+from tidy3d.components.base import cached_property
 from tidy3d.components.boundary import BroadbandModeABCSpec
 from tidy3d.components.frequency_extrapolation import (
     AbstractLowFrequencySmoothingSpec,
@@ -21,10 +21,9 @@
 from tidy3d.components.index import SimulationMap
 from tidy3d.components.microwave.base import MicrowaveBaseModel
 from tidy3d.components.monitor import DirectivityMonitor, ModeMonitor
-from tidy3d.components.simulation import Simulation
 from tidy3d.components.source.time import GaussianPulse
-from tidy3d.components.types import Ax, Complex, Coordinate
-from tidy3d.components.types.base import annotate_type
+from tidy3d.components.types import Complex, Coordinate
+from tidy3d.components.types.base import discriminated_union
 from tidy3d.components.viz import add_ax_if_none, equal_aspect
 from tidy3d.constants import C_0, MICROMETER, OHM, fp_eps, inf
 from tidy3d.exceptions import SetupError, Tidy3dKeyError, ValidationError
@@ -33,14 +32,19 @@
     FWIDTH_FRAC,
     AbstractComponentModeler,
 )
-from tidy3d.plugins.smatrix.data.data_array import PortDataArray
 from tidy3d.plugins.smatrix.ports.base_lumped import AbstractLumpedPort
-from tidy3d.plugins.smatrix.ports.coaxial_lumped import CoaxialLumpedPort
-from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort
 from tidy3d.plugins.smatrix.ports.types import TerminalPortType
 from tidy3d.plugins.smatrix.ports.wave import WavePort
 from tidy3d.plugins.smatrix.types import NetworkElement, NetworkIndex, SParamDef

+if TYPE_CHECKING:
+    from tidy3d.compat import Self
+    from tidy3d.components.simulation import Simulation
+    from tidy3d.components.types import Ax
+    from tidy3d.plugins.smatrix.data.data_array import PortDataArray
+    from tidy3d.plugins.smatrix.ports.coaxial_lumped import CoaxialLumpedPort
+    from tidy3d.plugins.smatrix.ports.rectangular_lumped import LumpedPort
+
 AUTO_RADIATION_MONITOR_NAME = "radiation"
 AUTO_RADIATION_MONITOR_BUFFER = 2
 AUTO_RADIATION_MONITOR_NUM_POINTS_THETA = 100
@@ -48,18 +52,17 @@

 class DirectivityMonitorSpec(MicrowaveBaseModel):
-    """
-    Specification for automatically generating a :class:`.DirectivityMonitor`.
+    """Specification for automatically generating a :class:`.DirectivityMonitor`.

-    When included in the :attr:`.TerminalComponentModeler.radiation_monitors` tuple,
-    a :class:`.DirectivityMonitor` will be automatically generated with the specified
-    parameters. This allows users to mix manual :class:`.DirectivityMonitor` objects
-    with automatically generated ones, each with customizable parameters.
+    Notes
+    -----
+    When included in the :attr:`.TerminalComponentModeler.radiation_monitors` tuple,
+    a :class:`.DirectivityMonitor` will be automatically generated with the specified
+    parameters. This allows users to mix manual :class:`.DirectivityMonitor` objects
+    with automatically generated ones, each with customizable parameters.

-    Note
-    ----
-    The default origin (`custom_origin`) for defining observation points in the automatically
-    generated monitor is set to (0, 0, 0) in the global coordinate system.
+    The default origin (`custom_origin`) for defining observation points in the automatically
+    generated monitor is set to (0, 0, 0) in the global coordinate system.

     Example
     -------
@@ -71,42 +74,42 @@ class DirectivityMonitorSpec(MicrowaveBaseModel):
     ...     )
     """

-    name: Optional[str] = pd.Field(
+    name: Optional[str] = Field(
         None,
         title="Monitor Name",
         description=f"Optional name for the auto-generated monitor. "
         f"If not provided, defaults to '{AUTO_RADIATION_MONITOR_NAME}_' + index of the monitor in the list of radiation monitors.",
     )

-    freqs: Optional[tuple[pd.NonNegativeInt, ...]] = pd.Field(
+    freqs: Optional[tuple[NonNegativeInt, ...]] = Field(
         None,
         title="Frequencies",
         description="Frequencies to obtain fields at. If not provided, uses all frequencies "
         "from the :class:`.TerminalComponentModeler`. Must be a subset of modeler frequencies if provided.",
     )

-    buffer: pd.NonNegativeInt = pd.Field(
+    buffer: NonNegativeInt = Field(
         AUTO_RADIATION_MONITOR_BUFFER,
         title="Buffer Distance",
         description="Number of grid cells to maintain between monitor and PML/domain boundaries. "
         f"Default: {AUTO_RADIATION_MONITOR_BUFFER} cells.",
     )

-    num_theta_points: pd.NonNegativeInt = pd.Field(
+    num_theta_points: NonNegativeInt = Field(
         AUTO_RADIATION_MONITOR_NUM_POINTS_THETA,
         title="Elevation Angle Points",
         description="Number of elevation angle (theta) sample points from 0 to π. "
         f"Default: {AUTO_RADIATION_MONITOR_NUM_POINTS_THETA}.",
     )

-    num_phi_points: pd.NonNegativeInt = pd.Field(
+    num_phi_points: NonNegativeInt = Field(
         AUTO_RADIATION_MONITOR_NUM_POINTS_PHI,
         title="Azimuthal Angle Points",
         description="Number of azimuthal angle (phi) sample points from -π to π. "
         f"Default: {AUTO_RADIATION_MONITOR_NUM_POINTS_PHI}.",
     )

-    custom_origin: Optional[Coordinate] = pd.Field(
+    custom_origin: Optional[Coordinate] = Field(
         (0, 0, 0),
         title="Local Origin",
         description="Local origin used for defining observation points. If ``None``, uses the "
@@ -151,14 +154,14 @@ class TerminalComponentModeler(AbstractComponentModeler, MicrowaveBaseModel):
     John Wiley & Sons, 2012.
     """

-    ports: tuple[TerminalPortType, ...] = pd.Field(
+    ports: tuple[TerminalPortType, ...] = Field(
         (),
         title="Terminal Ports",
         description="Collection of lumped and wave ports associated with the network. "
         "For each port, one simulation will be run with a source that is associated with the port.",
     )

-    run_only: Optional[tuple[NetworkIndex, ...]] = pd.Field(
+    run_only: Optional[tuple[NetworkIndex, ...]] = Field(
         None,
         title="Run Only",
         description="Set of matrix indices that define the simulations to run. "
@@ -166,7 +169,7 @@ class TerminalComponentModeler(AbstractComponentModeler, MicrowaveBaseModel):
         "If a tuple is given, simulations will be run only for the given matrix indices.",
     )

-    element_mappings: tuple[tuple[NetworkElement, NetworkElement, Complex], ...] = pd.Field(
+    element_mappings: tuple[tuple[NetworkElement, NetworkElement, Complex], ...] = Field(
         (),
         title="Element Mappings",
         description="Tuple of S matrix element mappings, each described by a tuple of "
@@ -177,8 +180,8 @@ class TerminalComponentModeler(AbstractComponentModeler, MicrowaveBaseModel):
     )

     radiation_monitors: tuple[
-        annotate_type(Union[DirectivityMonitor, DirectivityMonitorSpec]), ...
-    ] = pd.Field(
+        discriminated_union(Union[DirectivityMonitor, DirectivityMonitorSpec]), ...
+    ] = Field(
         (),
         title="Radiation Monitors",
         description="Facilitates the calculation of figures-of-merit for antennas. "
@@ -187,7 +190,7 @@ class TerminalComponentModeler(AbstractComponentModeler, MicrowaveBaseModel):
         "objects for automatic generation.",
     )

-    assume_ideal_excitation: bool = pd.Field(
+    assume_ideal_excitation: bool = Field(
         False,
         title="Assume Ideal Excitation",
         description="If ``True``, only the excited port is assumed to have a nonzero incident wave "
@@ -198,13 +201,13 @@ class TerminalComponentModeler(AbstractComponentModeler, MicrowaveBaseModel):
         "reflections from simulation boundaries. ",
     )

-    s_param_def: SParamDef = pd.Field(
+    s_param_def: SParamDef = Field(
         "pseudo",
         title="Scattering Parameter Definition",
         description="Whether to compute scattering parameters using the 'pseudo' or 'power' wave definitions.",
     )

-    low_freq_smoothing: Optional[ModelerLowFrequencySmoothingSpec] = pd.Field(
+    low_freq_smoothing: Optional[ModelerLowFrequencySmoothingSpec] = Field(
         None,
         title="Low Frequency Smoothing",
         description="The low frequency smoothing parameters for the terminal component simulation.",
@@ -718,7 +721,7 @@ def _add_source_to_sim(self, source_index: NetworkIndex) -> tuple[str, Simulatio
     )

     @cached_property
-    def _source_time(self):
+    def _source_time(self) -> GaussianPulse:
         """Helper to create a time domain pulse for the frequency range of interest."""
         if self.custom_source_time is not None:
             return self.custom_source_time
@@ -735,8 +738,9 @@ def _source_time(self):
             minimum_source_bandwidth=FWIDTH_FRAC,
         )

-    @pd.validator("simulation")
-    def _validate_3d_simulation(cls, val):
+    @field_validator("simulation")
+    @classmethod
+    def _validate_3d_simulation(cls, val: Simulation) -> Simulation:
         """Error if :class:`.Simulation` is not a 3D simulation"""

         if val.size.count(0.0) > 0:
@@ -745,17 +749,17 @@
             )
         return val

-    @pd.validator("ports")
-    @skip_if_fields_missing(["simulation"])
-    def _validate_port_refinement_usage(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_port_refinement_usage(self) -> Self:
         """Warn if port refinement options are enabled, but the supplied simulation does not
         contain a grid type that will make use of them."""
+        val = self.ports

-        sim: Simulation = values.get("simulation")
+        sim: Simulation = self.simulation
         # If grid spec is using AutoGrid
         # then set up is acceptable
         if sim.grid_spec.auto_grid_used:
-            return val
+            return self

         for port in val:
             if port._is_using_mesh_refinement:
@@ -768,18 +772,20 @@
                     "the 'enable_snapping_points=False' and 'num_grid_cells=None' for lumped ports."
                 )

-        return val
+        return self

-    @pd.validator("radiation_monitors")
-    @skip_if_fields_missing(["freqs"])
-    def _validate_radiation_monitors(cls, val, values):
+    @model_validator(mode="after")
+    def _validate_radiation_monitors(self) -> Self:
         """Validate radiation monitors configuration.

         Validates that:
         - DirectivityMonitor frequencies are a subset of modeler frequencies
         - DirectivityMonitorSpec frequencies (if provided) are a subset of modeler frequencies
         """
-        modeler_freqs = set(values.get("freqs", []))
+        val = self.radiation_monitors
+        if self.freqs is None:
+            return self
+        modeler_freqs = set(self.freqs)

         for index, rad_mon in enumerate(val):
             # Only validate freqs if explicitly provided
@@ -793,10 +799,10 @@
                 mon_name = rad_mon.name or f"{AUTO_RADIATION_MONITOR_NAME}_{index}"
                 raise ValidationError(
                     f"The frequencies in the radiation monitor '{mon_name}' "
-                    f"must be equal to or a subset of the frequencies in the '{cls.__name__}'."
+                    f"must be equal to or a subset of the frequencies in the '{self.__class__.__name__}'."
                 )

-        return val
+        return self

     @staticmethod
     def _check_grid_size_at_ports(
@@ -998,4 +1004,4 @@ def _extrude_port_structures(self, sim: Simulation) -> Simulation:
     return sim

-TerminalComponentModeler.update_forward_refs()
+TerminalComponentModeler.model_rebuild()
diff --git a/tidy3d/plugins/smatrix/data/base.py b/tidy3d/plugins/smatrix/data/base.py
index a560019d65..88e64ed599 100644
--- a/tidy3d/plugins/smatrix/data/base.py
+++ b/tidy3d/plugins/smatrix/data/base.py
@@ -3,38 +3,42 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Optional

-import pydantic.v1 as pd
+from pydantic import Field, model_validator

 from tidy3d.components.base import Tidy3dBaseModel
-from tidy3d.components.data.data_array import DataArray
 from tidy3d.components.data.index import SimulationDataMap
 from tidy3d.plugins.smatrix.component_modelers.base import AbstractComponentModeler

+if TYPE_CHECKING:
+    from tidy3d.compat import Self
+    from tidy3d.components.data.data_array import DataArray
+

 class AbstractComponentModelerData(ABC, Tidy3dBaseModel):
     """A data container for the results of a :class:`.AbstractComponentModeler` run.

-    This class stores the original modeler and the simulation data obtained
-    from running the simulations it defines. It also provides a method to
-    compute the S-matrix from the simulation data.
+    Notes
+    -----
+    This class stores the original modeler and the simulation data obtained
+    from running the simulations it defines. It also provides a method to
+    compute the S-matrix from the simulation data.
     """

-    modeler: AbstractComponentModeler = pd.Field(
-        ...,
+    modeler: AbstractComponentModeler = Field(
         title="Component modeler",
         description="The original :class:`AbstractComponentModeler` object that defines the "
         "simulation setup and from which this data was generated.",
     )

-    data: SimulationDataMap = pd.Field(
-        ...,
+    data: SimulationDataMap = Field(
         title="SimulationDataMap",
         description="A mapping from task names to :class:`.SimulationData` objects, "
         "containing the results of each simulation run.",
     )

-    log: str = pd.Field(
+    log: Optional[str] = Field(
         None,
         title="Modeler Post-process Log",
         description="A string containing the log information from the modeler post-processing run.",
@@ -44,21 +48,21 @@ class AbstractComponentModelerData(ABC, Tidy3dBaseModel):
     def smatrix(self) -> DataArray:
         """Computes and returns the scattering matrix (S-matrix)."""

-    @pd.validator("data")
-    def keys_match_modeler(cls, val, values):
+    @model_validator(mode="after")
+    def keys_match_modeler(self) -> Self:
         """
         Validates that the keys of the 'data' dictionary match the keys of the
         'modeler.sim_dict' dictionary, irrespective of order.
         """
-        modeler = values.get("modeler")
+        modeler = self.modeler

         # It's good practice to handle cases where 'modeler' might not be present
         if not modeler or not hasattr(modeler, "sim_dict"):
-            return val
+            return self

         # Use sets for an order-insensitive comparison
         modeler_keys = set(modeler.sim_dict.keys())
-        data_keys = set(val.keys())
+        data_keys = set(self.data.keys())

         if modeler_keys != data_keys:
             # Provide a more helpful error by showing the exact differences
@@ -73,4 +77,4 @@
             raise ValueError(f"Key mismatch between modeler and data. {'; '.join(error_parts)}")

-        return val
+        return self
diff --git a/tidy3d/plugins/smatrix/data/modal.py b/tidy3d/plugins/smatrix/data/modal.py
index 343354279f..e1fda38969 100644
--- a/tidy3d/plugins/smatrix/data/modal.py
+++ b/tidy3d/plugins/smatrix/data/modal.py
@@ -2,23 +2,28 @@
 from __future__ import annotations

-import pydantic.v1 as pd
+from typing import TYPE_CHECKING
+
+from pydantic import Field

 from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler
 from tidy3d.plugins.smatrix.data.base import AbstractComponentModelerData
-from tidy3d.plugins.smatrix.data.data_array import ModalPortDataArray
+
+if TYPE_CHECKING:
+    from tidy3d.plugins.smatrix.data.data_array import ModalPortDataArray


 class ModalComponentModelerData(AbstractComponentModelerData):
     """A data container for the results of a :class:`.ModalComponentModeler` run.

-    This class stores the original modeler and the simulation data obtained
-    from running the simulations it defines. It also provides a method to
-    compute the S-matrix from the simulation data.
+    Notes
+    -----
+    This class stores the original modeler and the simulation data obtained
+    from running the simulations it defines. It also provides a method to
+    compute the S-matrix from the simulation data.
     """

-    modeler: ModalComponentModeler = pd.Field(
-        ...,
+    modeler: ModalComponentModeler = Field(
         title="ModalComponentModeler",
         description="The original :class:`ModalComponentModeler` object that defines the simulation setup "
         "and from which this data was generated.",
diff --git a/tidy3d/plugins/smatrix/data/terminal.py b/tidy3d/plugins/smatrix/data/terminal.py
index c00a889efa..c32b17b40f 100644
--- a/tidy3d/plugins/smatrix/data/terminal.py
+++ b/tidy3d/plugins/smatrix/data/terminal.py
@@ -2,27 +2,20 @@
 from __future__ import annotations

-from typing import Optional, Union
+from typing import TYPE_CHECKING, Optional

 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field

 from tidy3d.components.base import cached_property
 from tidy3d.components.data.data_array import FreqDataArray
-from tidy3d.components.data.monitor_data import MonitorData
-from tidy3d.components.data.sim_data import SimulationData
 from tidy3d.components.microwave.base import MicrowaveBaseModel
-from tidy3d.components.microwave.data.monitor_data import AntennaMetricsData
 from tidy3d.constants import C_0
 from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler
 from tidy3d.plugins.smatrix.data.base import AbstractComponentModelerData
-from tidy3d.plugins.smatrix.data.data_array import (
-    PortDataArray,
-    PortNameDataArray,
-    TerminalPortDataArray,
-)
+from tidy3d.plugins.smatrix.data.data_array import PortDataArray, TerminalPortDataArray
 from tidy3d.plugins.smatrix.ports.types import LumpedPortType
-from tidy3d.plugins.smatrix.types import NetworkIndex, SParamDef
+from tidy3d.plugins.smatrix.types import SParamDef
 from tidy3d.plugins.smatrix.utils import (
     ab_to_s,
     check_port_impedance_sign,
@@ -33,23 +26,31 @@
     s_to_z,
 )

+if TYPE_CHECKING:
+    from typing import Union
+
+    from tidy3d.components.data.monitor_data import MonitorData
+    from tidy3d.components.data.sim_data import SimulationData
+    from tidy3d.components.microwave.data.monitor_data import AntennaMetricsData
+    from tidy3d.plugins.smatrix.data.data_array import PortNameDataArray
+    from tidy3d.plugins.smatrix.types import NetworkIndex
+
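The keys_match_modeler conversion above is the pattern this patch applies everywhere: a pydantic v1 validator that peeked at the partially-built `values` dict becomes a v2 `model_validator(mode="after")` that reads fully-constructed fields off `self`, which also removes the need for `skip_if_fields_missing`. A stand-alone sketch of the pattern in plain pydantic (the model and field names are invented, not tidy3d code):

    from pydantic import BaseModel, model_validator

    class FreqWindow(BaseModel):
        fmin: float
        fmax: float

        @model_validator(mode="after")
        def _check_order(self) -> "FreqWindow":
            # Runs after all fields are parsed, so cross-field checks are safe here.
            if self.fmax < self.fmin:
                raise ValueError("'fmax' must not be smaller than 'fmin'.")
            return self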
 class MicrowaveSMatrixData(MicrowaveBaseModel):
     """Stores the computed S-matrix and reference impedances for the terminal ports."""

-    port_reference_impedances: Optional[PortDataArray] = pd.Field(
+    port_reference_impedances: Optional[PortDataArray] = Field(
         None,
         title="Port Reference Impedances",
         description="Reference impedance for each port used in the S-parameter calculation. This is optional and may not be present if not specified or computed.",
     )

-    data: TerminalPortDataArray = pd.Field(
-        ...,
+    data: TerminalPortDataArray = Field(
         title="S-Matrix Data",
         description="An array containing the computed S-matrix of the device. The data is organized by terminal ports, representing the scattering parameters between them.",
     )

-    s_param_def: SParamDef = pd.Field(
+    s_param_def: SParamDef = Field(
         "pseudo",
         title="Scattering Parameter Definition",
         description="Whether scattering parameters are defined using the 'pseudo' or 'power' wave definitions.",
@@ -77,7 +78,7 @@ class TerminalComponentModelerData(AbstractComponentModelerData, MicrowaveBaseMo
     John Wiley & Sons, 2012.
     """

-    modeler: TerminalComponentModeler = pd.Field(
+    modeler: TerminalComponentModeler = Field(
         ...,
         title="TerminalComponentModeler",
         description="The original :class:`.TerminalComponentModeler` object that defines the simulation setup "
diff --git a/tidy3d/plugins/smatrix/ports/base.py b/tidy3d/plugins/smatrix/ports/base.py
index 1d1e3200dc..eb84fbb602 100644
--- a/tidy3d/plugins/smatrix/ports/base.py
+++ b/tidy3d/plugins/smatrix/ports/base.py
@@ -4,7 +4,7 @@
 from abc import ABC

-import pydantic.v1 as pd
+from pydantic import Field, field_validator

 from tidy3d.components.base import Tidy3dBaseModel
 from tidy3d.exceptions import SetupError
@@ -13,15 +13,15 @@
 class AbstractBasePort(Tidy3dBaseModel, ABC):
     """Abstract base class representing a port excitation of a component."""

-    name: str = pd.Field(
-        ...,
+    name: str = Field(
         title="Name",
         description="Unique name for the port.",
         min_length=1,
     )

-    @pd.validator("name")
-    def _valid_port_name(cls, val):
+    @field_validator("name")
+    @classmethod
+    def _valid_port_name(cls, val: str) -> str:
         """Make sure port name does not include the '@' symbol, so that task names will always be unique."""
         if "@" in val:
             raise SetupError(f"Port names must not include the '@' symbol. Name given was '{val}'.")
diff --git a/tidy3d/plugins/smatrix/ports/base_lumped.py b/tidy3d/plugins/smatrix/ports/base_lumped.py
index 4582773128..6e21c21642 100644
--- a/tidy3d/plugins/smatrix/ports/base_lumped.py
+++ b/tidy3d/plugins/smatrix/ports/base_lumped.py
@@ -3,20 +3,23 @@
 from __future__ import annotations

 from abc import abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

-import pydantic.v1 as pd
+from pydantic import Field, PositiveInt

 from tidy3d.components.base import cached_property
 from tidy3d.components.geometry.utils_2d import snap_coordinate_to_grid
-from tidy3d.components.grid.grid import Grid, YeeGrid
-from tidy3d.components.lumped_element import LumpedElementType
-from tidy3d.components.monitor import FieldMonitor
-from tidy3d.components.types import Complex, Coordinate, FreqArray
+from tidy3d.components.types import Complex
 from tidy3d.constants import OHM

 from .base_terminal import AbstractTerminalPort

+if TYPE_CHECKING:
+    from tidy3d.components.grid.grid import Grid, YeeGrid
+    from tidy3d.components.lumped_element import LumpedElementType
+    from tidy3d.components.monitor import FieldMonitor
+    from tidy3d.components.types import Coordinate, FreqArray
+
 DEFAULT_PORT_NUM_CELLS = 3
 DEFAULT_REFERENCE_IMPEDANCE = 50
@@ -24,14 +27,14 @@
 class AbstractLumpedPort(AbstractTerminalPort):
     """Class representing a single lumped port."""

-    impedance: Complex = pd.Field(
+    impedance: Complex = Field(
         DEFAULT_REFERENCE_IMPEDANCE,
         title="Reference impedance",
         description="Reference port impedance for scattering parameter computation.",
         units=OHM,
     )

-    num_grid_cells: Optional[pd.PositiveInt] = pd.Field(
+    num_grid_cells: Optional[PositiveInt] = Field(
         DEFAULT_PORT_NUM_CELLS,
         title="Port grid cells",
         description="Number of mesh grid cells associated with the port along each direction, "
@@ -39,7 +42,7 @@ class AbstractLumpedPort(AbstractTerminalPort):
         "A value of ``None`` will turn off automatic mesh refinement.",
     )

-    enable_snapping_points: bool = pd.Field(
+    enable_snapping_points: bool = Field(
         True,
         title="Snap Grid To Lumped Port",
         description="When enabled, snapping points are automatically generated to snap grids to key "
diff --git a/tidy3d/plugins/smatrix/ports/base_terminal.py b/tidy3d/plugins/smatrix/ports/base_terminal.py
index 8fdc10b8c2..36376c796b 100644
--- a/tidy3d/plugins/smatrix/ports/base_terminal.py
+++ b/tidy3d/plugins/smatrix/ports/base_terminal.py
@@ -3,20 +3,24 @@
 from __future__ import annotations

 from abc import abstractmethod
-from typing import Optional, Union
+from typing import TYPE_CHECKING

 from tidy3d.components.base import cached_property
-from tidy3d.components.data.data_array import FreqDataArray
-from tidy3d.components.data.sim_data import SimulationData
-from tidy3d.components.grid.grid import Grid
 from tidy3d.components.microwave.base import MicrowaveBaseModel
-from tidy3d.components.monitor import FieldMonitor, ModeMonitor
-from tidy3d.components.source.base import Source
-from tidy3d.components.source.time import GaussianPulse
-from tidy3d.components.types import FreqArray
 from tidy3d.log import log
 from tidy3d.plugins.smatrix.ports.base import AbstractBasePort

+if TYPE_CHECKING:
+    from typing import Optional, Union
+
+    from tidy3d.components.data.data_array import FreqDataArray
+    from tidy3d.components.data.sim_data import SimulationData
+    from tidy3d.components.grid.grid import Grid
+    from tidy3d.components.monitor import FieldMonitor, ModeMonitor
+    from tidy3d.components.source.base import Source
+    from tidy3d.components.source.time import GaussianPulse
+    from tidy3d.components.types import FreqArray
+

 class AbstractTerminalPort(AbstractBasePort, MicrowaveBaseModel):
     """Class representing a single terminal-based port. All terminal ports must provide methods
diff --git a/tidy3d/plugins/smatrix/ports/coaxial_lumped.py b/tidy3d/plugins/smatrix/ports/coaxial_lumped.py
index 6ceb518c87..c15637701b 100644
--- a/tidy3d/plugins/smatrix/ports/coaxial_lumped.py
+++ b/tidy3d/plugins/smatrix/ports/coaxial_lumped.py
@@ -2,32 +2,40 @@
 from __future__ import annotations

-from typing import Optional
+from typing import TYPE_CHECKING

 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, PositiveFloat, field_validator, model_validator

 from tidy3d.components.base import cached_property
-from tidy3d.components.data.data_array import FreqDataArray, ScalarFieldDataArray
+from tidy3d.components.data.data_array import ScalarFieldDataArray
 from tidy3d.components.data.dataset import FieldDataset
-from tidy3d.components.data.sim_data import SimulationData
 from tidy3d.components.geometry.base import Box, Geometry
 from tidy3d.components.geometry.utils_2d import increment_float
-from tidy3d.components.grid.grid import Grid, YeeGrid
 from tidy3d.components.lumped_element import CoaxialLumpedResistor
 from tidy3d.components.microwave.path_integrals.integrals.current import Custom2DCurrentIntegral
 from tidy3d.components.microwave.path_integrals.integrals.voltage import AxisAlignedVoltageIntegral
 from tidy3d.components.microwave.path_integrals.specs.base import AbstractAxesRH
 from tidy3d.components.monitor import FieldMonitor
 from tidy3d.components.source.current import CustomCurrentSource
-from tidy3d.components.source.time import GaussianPulse
-from tidy3d.components.types import Axis, Coordinate, Direction, FreqArray, Size
-from tidy3d.components.validators import skip_if_fields_missing
+from tidy3d.components.types import Axis, Coordinate, Direction
 from tidy3d.constants import MICROMETER
 from tidy3d.exceptions import SetupError, ValidationError

 from .base_lumped import AbstractLumpedPort

+if TYPE_CHECKING:
+    from typing import Optional
+
+    from numpy.typing import NDArray
+
+    from tidy3d.compat import Self
+    from tidy3d.components.data.data_array import FreqDataArray
+    from tidy3d.components.data.sim_data import SimulationData
+    from tidy3d.components.grid.grid import Grid, YeeGrid
+    from tidy3d.components.source.time import GaussianPulse
+    from tidy3d.components.types import FreqArray, Size
+
 DEFAULT_COAX_SOURCE_NUM_POINTS = 11
@@ -46,35 +54,31 @@
     ...     )
     """

-    center: Coordinate = pd.Field(
+    center: Coordinate = Field(
         (0.0, 0.0, 0.0),
         title="Center",
         description="Center of object in x, y, and z.",
         units=MICROMETER,
     )

-    outer_diameter: pd.PositiveFloat = pd.Field(
-        ...,
+    outer_diameter: PositiveFloat = Field(
         title="Outer Diameter",
         description="Diameter of the outer coaxial circle.",
         units=MICROMETER,
     )

-    inner_diameter: pd.PositiveFloat = pd.Field(
-        ...,
+    inner_diameter: PositiveFloat = Field(
         title="Inner Diameter",
         description="Diameter of the inner coaxial circle.",
         units=MICROMETER,
     )

-    normal_axis: Axis = pd.Field(
-        ...,
+    normal_axis: Axis = Field(
         title="Normal Axis",
         description="Specifies the axis which is normal to the concentric circles.",
     )

-    direction: Direction = pd.Field(
-        ...,
+    direction: Direction = Field(
         title="Direction",
         description="The direction of the signal travelling in the transmission line. "
         "This is needed in order to position the path integral, which is used for computing "
     )

     @cached_property
-    def main_axis(self):
+    def main_axis(self) -> int:
         """Required for inheriting from AbstractAxesRH."""
         return self.normal_axis

     @cached_property
-    def injection_axis(self):
+    def injection_axis(self) -> int:
         """Required for inheriting from AbstractTerminalPort."""
         return self.normal_axis

-    @pd.validator("center", always=True)
-    def _center_not_inf(cls, val):
+    @field_validator("center")
+    @classmethod
+    def _center_not_inf(cls, val: Coordinate) -> Coordinate:
         """Make sure center is not infinity."""
         if any(np.isinf(v) for v in val):
             raise ValidationError("'center' can not contain 'td.inf' terms.")
         return val

-    @pd.validator("inner_diameter", always=True)
-    @skip_if_fields_missing(["outer_diameter"])
-    def _ensure_inner_diameter_is_smaller(cls, val, values):
+    @model_validator(mode="after")
+    def _ensure_inner_diameter_is_smaller(self) -> Self:
         """Ensures that the inner diameter is smaller than the outer diameter, so that the final shape
         is an annulus."""
-        outer_diameter = values.get("outer_diameter")
-        if val >= outer_diameter:
+        if self.inner_diameter >= self.outer_diameter:
             raise ValidationError(
-                f"The 'inner_diameter' {val} of a coaxial lumped element must be less than its "
-                f"'outer_diameter' {outer_diameter}."
+                f"The 'inner_diameter' {self.inner_diameter} of a coaxial lumped element "
+                f"must be less than its 'outer_diameter' {self.outer_diameter}."
             )
-        return val
+        return self

     def to_source(
         self, source_time: GaussianPulse, snap_center: Optional[float] = None, grid: Grid = None
@@ -142,7 +145,9 @@ def to_source(
         # Get a normalized current density that is flowing radially from inner circle to outer circle
         # Total current is normalized to 1
-        def compute_coax_current(rin, rout, x, y):
+        def compute_coax_current(
+            rin: float, rout: float, x: NDArray, y: NDArray
+        ) -> tuple[NDArray, NDArray]:
             # Radial distance
             r = np.sqrt(x**2 + y**2)
             # Remove division by 0
@@ -238,7 +243,7 @@ def to_voltage_monitor(
             center=self._voltage_path_center(center),
             size=self._voltage_path_size,
             freqs=freqs,
-            fields=[E1, E2],
+            fields=(E1, E2),
             name=self._voltage_monitor_name,
             colocate=False,
         )
@@ -266,10 +271,10 @@ def to_current_monitor(
         # Create a current monitor
         return FieldMonitor(
-            center=center,
+            center=tuple(center),
             size=current_mon_size,
             freqs=freqs,
-            fields=[H1, H2],
+            fields=(H1, H2),
             name=self._current_monitor_name,
             colocate=False,
         )
diff --git a/tidy3d/plugins/smatrix/ports/modal.py b/tidy3d/plugins/smatrix/ports/modal.py
index 4920696af1..e656d8150f 100644
--- a/tidy3d/plugins/smatrix/ports/modal.py
+++ b/tidy3d/plugins/smatrix/ports/modal.py
@@ -2,7 +2,7 @@
 from __future__ import annotations

-import pydantic.v1 as pd
+from pydantic import Field

 from tidy3d.components.data.data_array import DataArray
 from tidy3d.components.geometry.base import Box
@@ -40,17 +40,18 @@ class ModalPortDataArray(DataArray):
 class Port(AbstractBasePort, Box):
     """Specifies a port for S-matrix calculation.

-    A port defines a location and a set of modes for which the S-matrix
-    is calculated.
+    Notes
+    -----
+    A port defines a location and a set of modes for which the S-matrix
+    is calculated.
     """

-    direction: Direction = pd.Field(
-        ...,
+    direction: Direction = Field(
         title="Direction",
         description="'+' or '-', defining which direction is considered 'input'.",
     )

-    mode_spec: ModeSpec = pd.Field(
-        ModeSpec(),
+    mode_spec: ModeSpec = Field(
+        default_factory=ModeSpec,
         title="Mode Specification",
         description="Specifies how the mode solver will solve for the modes of the port.",
     )
diff --git a/tidy3d/plugins/smatrix/ports/rectangular_lumped.py b/tidy3d/plugins/smatrix/ports/rectangular_lumped.py
index 9e477c7b2b..1f0e4ff3b7 100644
--- a/tidy3d/plugins/smatrix/ports/rectangular_lumped.py
+++ b/tidy3d/plugins/smatrix/ports/rectangular_lumped.py
@@ -2,16 +2,14 @@
 from __future__ import annotations

-from typing import Optional
+from typing import TYPE_CHECKING, Any

 import numpy as np
-import pydantic.v1 as pd
+from pydantic import Field, model_validator
 from shapely import union_all
 from shapely.geometry.base import BaseMultipartGeometry

 from tidy3d.components.base import cached_property
-from tidy3d.components.data.data_array import FreqDataArray
-from tidy3d.components.data.sim_data import SimulationData
 from tidy3d.components.geometry.base import Box, Geometry
 from tidy3d.components.geometry.utils import (
     SnapBehavior,
@@ -20,21 +18,31 @@
     snap_box_to_grid,
 )
 from tidy3d.components.geometry.utils_2d import increment_float
-from tidy3d.components.grid.grid import Grid, YeeGrid
-from tidy3d.components.lumped_element import LinearLumpedElement, LumpedResistor, RLCNetwork
+from tidy3d.components.lumped_element import LinearLumpedElement, RLCNetwork
 from tidy3d.components.medium import LossyMetalMedium, PECMedium
 from tidy3d.components.microwave.path_integrals.integrals.current import AxisAlignedCurrentIntegral
 from tidy3d.components.microwave.path_integrals.integrals.voltage import AxisAlignedVoltageIntegral
 from tidy3d.components.monitor import FieldMonitor
 from tidy3d.components.source.current import UniformCurrentSource
-from tidy3d.components.source.time import GaussianPulse
-from tidy3d.components.structure import Structure
-from tidy3d.components.types import Axis, FreqArray, LumpDistType
+from tidy3d.components.types import Axis, LumpDistType
 from tidy3d.components.validators import assert_line_or_plane
 from tidy3d.exceptions import SetupError, ValidationError

 from .base_lumped import AbstractLumpedPort

+if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from typing import Optional
+
+    from tidy3d.compat import Self
+    from tidy3d.components.data.data_array import FreqDataArray
+    from tidy3d.components.data.sim_data import SimulationData
+    from tidy3d.components.grid.grid import Grid, YeeGrid
+    from tidy3d.components.lumped_element import LumpedResistor
+    from tidy3d.components.source.time import GaussianPulse
+    from tidy3d.components.structure import Structure
+    from tidy3d.components.types import FreqArray
+

 class LumpedPort(AbstractLumpedPort, Box):
     """Class representing a single rectangular lumped port.
@@ -54,14 +62,13 @@ class LumpedPort(AbstractLumpedPort, Box):
         The lumped element representing the load of the port.
     """

-    voltage_axis: Axis = pd.Field(
-        ...,
+    voltage_axis: Axis = Field(
         title="Voltage Integration Axis",
         description="Specifies the axis along which the E-field line integral is performed when "
         "computing the port voltage. The integration axis must lie in the plane of the port.",
     )

-    snap_perimeter_to_grid: bool = pd.Field(
+    snap_perimeter_to_grid: bool = Field(
         True,
         title="Snap Perimeter to Grid",
         description="When enabled, the perimeter of the port is snapped to the simulation grid, "
@@ -69,7 +76,7 @@ class LumpedPort(AbstractLumpedPort, Box):
         "is always snapped to the grid along its injection axis.",
     )

-    dist_type: LumpDistType = pd.Field(
+    dist_type: LumpDistType = Field(
         "on",
         title="Distribute Type",
         description="Optional field that is passed directly to the :class:`.LinearLumpedElement` used to model the port's load. "
@@ -84,17 +91,16 @@ class LumpedPort(AbstractLumpedPort, Box):
     _line_plane_validator = assert_line_or_plane()

     @cached_property
-    def injection_axis(self):
+    def injection_axis(self) -> int:
         """Injection axis of the port."""
         return self.size.index(0.0)

-    @pd.validator("voltage_axis", always=True)
-    def _voltage_axis_in_plane(cls, val, values):
+    @model_validator(mode="after")
+    def _voltage_axis_in_plane(self) -> Self:
         """Ensure voltage integration axis is in the port's plane."""
-        size = values.get("size")
-        if val == size.index(0.0):
+        if self.voltage_axis == self.size.index(0.0):
             raise ValidationError("'voltage_axis' must lie in the port's plane.")
-        return val
+        return self

     @cached_property
     def current_axis(self) -> Axis:
@@ -170,10 +176,10 @@ def to_voltage_monitor(
         e_component = "xyz"[self.voltage_axis]
         # Create a voltage monitor
         return FieldMonitor(
-            center=center,
-            size=size,
+            center=tuple(center),
+            size=tuple(size),
             freqs=freqs,
-            fields=[f"E{e_component}"],
+            fields=(f"E{e_component}",),
             name=self._voltage_monitor_name,
             colocate=False,
         )
@@ -204,10 +210,10 @@ def to_current_monitor(
         h_cap_component = "xyz"[self.injection_axis]
         # Create a current monitor
         return FieldMonitor(
-            center=center,
-            size=size,
+            center=tuple(center),
+            size=tuple(size),
             freqs=freqs,
-            fields=[f"H{h_component}", f"H{h_cap_component}"],
+            fields=(f"H{h_component}", f"H{h_cap_component}"),
             name=self._current_monitor_name,
             colocate=False,
         )
@@ -312,7 +318,7 @@ def from_structures(
         voltage_axis: Axis = None,
         lateral_coord: Optional[float] = None,
         port_width: Optional[float] = None,
-        **kwargs,
+        **kwargs: Any,
     ) -> LumpedPort:
         """
         Auto-generate lumped port based on provided structures and plane coordinates.
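The next hunk only adds type hints to the `intervals_overlap` helper; its body is the standard axis-aligned interval test. For reference, a self-contained version with the usual completion of the comparison (the return expression lies outside the hunk's context, so it is inferred here rather than quoted from the source):

    from collections.abc import Sequence

    def intervals_overlap(a: Sequence, b: Sequence) -> bool:
        """Return True if [a_min, a_max] and [b_min, b_max] overlap."""
        a_min, a_max = a
        b_min, b_max = b
        # Two closed intervals overlap iff each one starts before the other ends.
        return a_min <= b_max and b_min <= a_max

    assert intervals_overlap((0, 2), (1, 3))      # partial overlap
    assert intervals_overlap((0, 1), (1, 2))      # a shared endpoint counts
    assert not intervals_overlap((0, 1), (2, 3))  # disjoint intervals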
@@ -467,7 +473,7 @@ def from_structures( ground_bounds = np.array(ground_2d.bounds).reshape(2, 2).T signal_bounds = np.array(signal_2d.bounds).reshape(2, 2).T - def intervals_overlap(a, b): + def intervals_overlap(a: Sequence, b: Sequence) -> bool: """Return True if [a_min, a_max] and [b_min, b_max] overlap.""" a_min, a_max = a b_min, b_max = b diff --git a/tidy3d/plugins/smatrix/ports/wave.py b/tidy3d/plugins/smatrix/ports/wave.py index 8b160c3bdc..5ad62c3501 100644 --- a/tidy3d/plugins/smatrix/ports/wave.py +++ b/tidy3d/plugins/smatrix/ports/wave.py @@ -2,30 +2,36 @@ from __future__ import annotations -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union -import pydantic.v1 as pd +from pydantic import Field, NonNegativeInt, field_validator, model_validator -from tidy3d.components.base import cached_property, skip_if_fields_missing +from tidy3d.components.base import cached_property from tidy3d.components.boundary import ABCBoundary, InternalAbsorber, ModeABCBoundary -from tidy3d.components.data.data_array import FreqDataArray, FreqModeDataArray from tidy3d.components.data.sim_data import SimulationData from tidy3d.components.geometry.base import Box -from tidy3d.components.grid.grid import Grid -from tidy3d.components.microwave.data.monitor_data import MicrowaveModeData from tidy3d.components.microwave.mode_spec import MicrowaveModeSpec from tidy3d.components.microwave.monitor import MicrowaveModeMonitor -from tidy3d.components.simulation import Simulation from tidy3d.components.source.field import ModeSource from tidy3d.components.source.frame import PECFrame -from tidy3d.components.source.time import GaussianPulse from tidy3d.components.structure import MeshOverrideStructure -from tidy3d.components.types import Axis, Direction, FreqArray +from tidy3d.components.types import Direction from tidy3d.exceptions import SetupError, ValidationError from tidy3d.log import log from tidy3d.plugins.mode import ModeSolver from tidy3d.plugins.smatrix.ports.base_terminal import AbstractTerminalPort +if TYPE_CHECKING: + from pydantic import NonNegativeFloat + + from tidy3d.compat import Self + from tidy3d.components.data.data_array import FreqDataArray, FreqModeDataArray + from tidy3d.components.grid.grid import Grid + from tidy3d.components.microwave.data.monitor_data import MicrowaveModeData + from tidy3d.components.simulation import Simulation + from tidy3d.components.source.time import GaussianPulse + from tidy3d.components.types import Axis, FreqArray + DEFAULT_WAVE_PORT_NUM_CELLS = 5 MIN_WAVE_PORT_NUM_CELLS = 3 DEFAULT_WAVE_PORT_FRAME = PECFrame() @@ -34,20 +40,19 @@ class WavePort(AbstractTerminalPort, Box): """Class representing a single wave port""" - direction: Direction = pd.Field( - ..., + direction: Direction = Field( title="Direction", description="'+' or '-', defining which direction is considered 'input'.", ) - mode_spec: MicrowaveModeSpec = pd.Field( + mode_spec: MicrowaveModeSpec = Field( default_factory=MicrowaveModeSpec._default_without_license_warning, title="Mode Specification", description="Parameters to feed to mode solver which determine modes and how transmission line " "quantities, e.g., characteristic impedance, are computed.", ) - num_grid_cells: Optional[int] = pd.Field( + num_grid_cells: Optional[int] = Field( DEFAULT_WAVE_PORT_NUM_CELLS, ge=MIN_WAVE_PORT_NUM_CELLS, title="Number of Grid Cells", description="Number of grid cells in the transverse dimensions of the port. " "Must be greater than or equal to 3. 
When set to `None`, no grid refinement is performed.", ) - conjugated_dot_product: bool = pd.Field( + conjugated_dot_product: bool = Field( False, title="Conjugated Dot Product", description="Use conjugated or non-conjugated dot product for mode decomposition.", ) - frame: Optional[PECFrame] = pd.Field( + frame: Optional[PECFrame] = Field( DEFAULT_WAVE_PORT_FRAME, title="Source Frame", description="Add a thin frame around the source during FDTD run for an improved injection.", ) - absorber: Union[bool, ABCBoundary, ModeABCBoundary] = pd.Field( + absorber: Union[bool, ABCBoundary, ModeABCBoundary] = Field( True, title="Absorber", description="Place a mode absorber in the port. If ``True``, an automatically generated mode absorber is placed in the port. " "If :class:`.ABCBoundary` or :class:`.ModeABCBoundary`, a mode absorber is placed in the port with the specified boundary conditions.", ) - extrude_structures: bool = pd.Field( + extrude_structures: bool = Field( False, title="Extrude Structures", description="Extrudes structures that intersect the wave port plane by a few grid cells when ``True``, improving mode injection accuracy.", ) - mode_index: Optional[pd.NonNegativeInt] = pd.Field( + mode_index: Optional[NonNegativeInt] = Field( None, title="Mode Index (deprecated)", description="Index into the collection of modes returned by mode solver. " @@ -89,7 +94,7 @@ "Deprecated. Use the 'mode_selection' field instead.", ) - mode_selection: Optional[tuple[int, ...]] = pd.Field( + mode_selection: Optional[tuple[int, ...]] = Field( None, title="Mode Selection", description="Selects specific mode(s) to use from the mode solver. " @@ -136,7 +141,7 @@ def to_source( if snap_center: center[self.injection_axis] = snap_center return ModeSource( - center=center, + center=tuple(center), size=self.size, source_time=source_time, mode_spec=self.mode_spec, @@ -179,7 +184,7 @@ def to_mode_solver(self, simulation: Simulation, freqs: FreqArray) -> ModeSolver return mode_solver def to_absorber( - self, snap_center: Optional[float] = None, freq_spec: Optional[pd.NonNegativeFloat] = None + self, snap_center: Optional[float] = None, freq_spec: Optional[NonNegativeFloat] = None ) -> InternalAbsorber: """Create an internal absorber from the wave port.""" center = list(self.center) @@ -281,26 +286,29 @@ def to_mesh_overrides(self) -> list[MeshOverrideStructure]: ) ] - @pd.validator("mode_spec", always=True) - def _validate_path_integrals_within_port(cls, val, values): + @model_validator(mode="after") + def _validate_path_integrals_within_port(self) -> Self: """Validate that the microwave mode spec contains path specs all within the port bounds.""" - center = values["center"] - size = values["size"] + val = self.mode_spec + center = self.center + size = self.size self_plane = Box(size=size, center=center) try: val._check_path_integrals_within_box(self_plane) except SetupError as e: raise SetupError( - f"Failed to setup '{cls.__name__}' with the suppled 'MicrowaveModeSpec'. {e!s}" + f"Failed to setup '{self.__class__.__name__}' with the supplied 'MicrowaveModeSpec'. 
{e!s}" ) from e - return val + return self - @skip_if_fields_missing(["mode_spec"]) - @pd.validator("mode_selection", always=True) - def _validate_mode_selection(cls, val, values): + @model_validator(mode="after") + def _validate_mode_selection(self) -> Self: """Validate that mode_selection contains valid, unique indices within range.""" + if self.mode_spec is None: + return self + val = self.mode_selection if val is None: - return val + return self indices = val @@ -319,7 +327,7 @@ def _validate_mode_selection(cls, val, values): ) # Check that indices are within range of num_modes - mode_spec = values["mode_spec"] + mode_spec = self.mode_spec num_modes = mode_spec.num_modes invalid_indices = [idx for idx in indices if idx >= num_modes] if invalid_indices: @@ -328,22 +336,23 @@ def _validate_mode_selection(cls, val, values): f"'mode_spec.num_modes' ({num_modes}). Valid range is 0 to {num_modes - 1}." ) - return val + return self - @pd.root_validator(pre=False) - def _check_absorber_if_extruding_structures(cls, values): + @model_validator(mode="after") + def _check_absorber_if_extruding_structures(self) -> Self: """Raise validation error when ``extrude_structures`` is set to ``True`` while ``absorber`` is set to ``False``.""" - if values.get("extrude_structures") and not values.get("absorber"): + if self.extrude_structures and not self.absorber: raise ValidationError( "Structure extrusion for a waveport requires an internal absorber. Set `absorber=True` to enable it." ) - return values + return self - @pd.validator("mode_index", always=True) - def _mode_index_deprecated(cls, val): + @field_validator("mode_index") + @classmethod + def _mode_index_deprecated(cls, val: Optional[int]) -> Optional[int]: """Warn that 'mode_index' is deprecated in favor of 'mode_selection'.""" if val is not None: log.warning( @@ -352,19 +361,21 @@ def _mode_index_deprecated(cls, val): ) return val - @skip_if_fields_missing(["mode_spec"]) - @pd.validator("mode_index", always=True) - def _validate_mode_index(cls, val, values): + @model_validator(mode="after") + def _validate_mode_index(self) -> Self: """Validate that mode_selection contains valid, unique indices within range.""" + val = self.mode_index if val is None: - return val - num_modes = values["mode_spec"].num_modes + return self + if self.mode_spec is None: + return self + num_modes = self.mode_spec.num_modes if val >= num_modes: raise ValidationError( f"'mode_index' is >= " f"'mode_spec.num_modes' ({num_modes}). Valid range is 0 to {num_modes - 1}." 
) - return val + return self @property def _is_using_mesh_refinement(self) -> bool: diff --git a/tidy3d/plugins/smatrix/run.py b/tidy3d/plugins/smatrix/run.py index ff4b9d39f1..8e15979dd8 100644 --- a/tidy3d/plugins/smatrix/run.py +++ b/tidy3d/plugins/smatrix/run.py @@ -1,16 +1,19 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from tidy3d.components.data.index import SimulationDataMap from tidy3d.log import log from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler -from tidy3d.plugins.smatrix.component_modelers.types import ComponentModelerType from tidy3d.plugins.smatrix.data.modal import ModalComponentModelerData from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData -from tidy3d.plugins.smatrix.data.types import ComponentModelerDataType -from tidy3d.web import Batch, BatchData +from tidy3d.web import Batch + +if TYPE_CHECKING: + from tidy3d.plugins.smatrix.component_modelers.types import ComponentModelerType + from tidy3d.plugins.smatrix.data.types import ComponentModelerDataType + from tidy3d.web import BatchData DEFAULT_DATA_DIR = "." diff --git a/tidy3d/plugins/smatrix/types.py b/tidy3d/plugins/smatrix/types.py index d2fd649117..f31249a006 100644 --- a/tidy3d/plugins/smatrix/types.py +++ b/tidy3d/plugins/smatrix/types.py @@ -2,10 +2,10 @@ from typing import Literal -import pydantic.v1 as pd +from pydantic import NonNegativeInt # S matrix indices and entries for the ModalComponentModeler -MatrixIndex = tuple[str, pd.NonNegativeInt] # the 'i' in S_ij +MatrixIndex = tuple[str, NonNegativeInt] # the 'i' in S_ij Element = tuple[MatrixIndex, MatrixIndex] # the 'ij' in S_ij # S matrix indices and entries for the TerminalComponentModeler NetworkIndex = str # the 'i' in S_ij diff --git a/tidy3d/plugins/smatrix/utils.py b/tidy3d/plugins/smatrix/utils.py index 43c822dc00..02dd588bdc 100644 --- a/tidy3d/plugins/smatrix/utils.py +++ b/tidy3d/plugins/smatrix/utils.py @@ -7,28 +7,31 @@ from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING import numpy as np -from tidy3d.components.data.data_array import ( - DataArray, - FreqDataArray, -) -from tidy3d.components.data.sim_data import SimulationData -from tidy3d.components.types import ArrayFloat1D from tidy3d.exceptions import Tidy3dError from tidy3d.plugins.smatrix.data.data_array import PortDataArray, TerminalPortDataArray -from tidy3d.plugins.smatrix.ports.types import ( - LumpedPortType, - PortCurrentType, - PortVoltageType, - TerminalPortType, -) -from tidy3d.plugins.smatrix.types import SParamDef +if TYPE_CHECKING: + from typing import Union -def port_array_inv(matrix: DataArray): + from numpy.typing import NDArray + + from tidy3d.components.data.data_array import DataArray, FreqDataArray + from tidy3d.components.data.sim_data import SimulationData + from tidy3d.components.types import ArrayFloat1D + from tidy3d.plugins.smatrix.ports.types import ( + LumpedPortType, + PortCurrentType, + PortVoltageType, + TerminalPortType, + ) + from tidy3d.plugins.smatrix.types import SParamDef + + +def port_array_inv(matrix: DataArray) -> NDArray: """Helper to invert a port matrix. 
Parameters @@ -105,7 +108,7 @@ def check_port_impedance_sign(Z_numpy: np.ndarray) -> None: ) -def compute_F(Z_numpy: ArrayFloat1D, s_param_def: SParamDef = "pseudo"): +def compute_F(Z_numpy: ArrayFloat1D, s_param_def: SParamDef = "pseudo") -> float: r"""Helper to convert port impedance matrix to F, which is used for computing scattering parameters diff --git a/tidy3d/plugins/waveguide/rectangular_dielectric.py b/tidy3d/plugins/waveguide/rectangular_dielectric.py index 22b01f0d3d..cbbf9fac71 100644 --- a/tidy3d/plugins/waveguide/rectangular_dielectric.py +++ b/tidy3d/plugins/waveguide/rectangular_dielectric.py @@ -2,15 +2,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Union +from typing import TYPE_CHECKING, Annotated, Any, Optional, Union import numpy -import pydantic.v1 as pydantic from matplotlib import pyplot +from pydantic import Field, field_validator, model_validator -from tidy3d.components.base import Tidy3dBaseModel, cached_property, skip_if_fields_missing +from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.boundary import BoundarySpec, Periodic -from tidy3d.components.data.data_array import FreqModeDataArray, ModeIndexDataArray from tidy3d.components.geometry.base import Box from tidy3d.components.geometry.polyslab import PolySlab from tidy3d.components.grid.grid_spec import GridSpec @@ -20,7 +19,7 @@ from tidy3d.components.source.field import ModeSource from tidy3d.components.source.time import GaussianPulse from tidy3d.components.structure import Structure -from tidy3d.components.types import TYPE_TAG_STR, ArrayFloat1D, Ax, Axis, Coordinate, Size1D +from tidy3d.components.types import TYPE_TAG_STR, ArrayFloat1D, Axis, Coordinate, Size1D from tidy3d.components.viz import add_ax_if_none from tidy3d.constants import C_0, MICROMETER, RADIAN, inf from tidy3d.exceptions import Tidy3dError, ValidationError @@ -28,76 +27,81 @@ from tidy3d.plugins.mode.mode_solver import ModeSolver if TYPE_CHECKING: + from typing import Literal + from matplotlib.colors import Colormap + from pydantic import ValidationInfo + + from tidy3d.compat import Self + from tidy3d.components.data.data_array import FreqModeDataArray, ModeIndexDataArray + from tidy3d.components.types import Ax -AnnotatedMedium = Annotated[MediumType, pydantic.Field(discriminator=TYPE_TAG_STR)] +AnnotatedMedium = Annotated[MediumType, Field(discriminator=TYPE_TAG_STR)] EVANESCENT_TAIL = 1.5 class RectangularDielectric(Tidy3dBaseModel): - """General rectangular dielectric waveguide - - Supports: - - Strip and rib geometries - - Angled sidewalls - - Modes in waveguide bends - - Surface and sidewall loss models - - Coupled waveguides + """General rectangular dielectric waveguide. + + Notes + ----- + Supports: + + - Strip and rib geometries + - Angled sidewalls + - Modes in waveguide bends + - Surface and sidewall loss models + - Coupled waveguides """ - wavelength: Union[float, ArrayFloat1D] = pydantic.Field( - ..., + wavelength: Union[float, ArrayFloat1D] = Field( title="Wavelength", description="Wavelength(s) at which to calculate modes (in μm).", units=MICROMETER, ) - core_width: Union[Size1D, ArrayFloat1D] = pydantic.Field( - ..., + core_width: Union[Size1D, ArrayFloat1D] = Field( title="Core width", description="Core width at the top of the waveguide. 
If set to an array, defines " "the widths of adjacent waveguides.", units=MICROMETER, ) - core_thickness: Size1D = pydantic.Field( - ..., + core_thickness: Size1D = Field( title="Core Thickness", description="Thickness of the core layer.", units=MICROMETER, ) - core_medium: MediumType = pydantic.Field( - ..., + core_medium: MediumType = Field( title="Core Medium", description="Medium associated with the core layer.", discriminator=TYPE_TAG_STR, ) - clad_medium: Union[AnnotatedMedium, tuple[AnnotatedMedium, ...]] = pydantic.Field( - ..., + clad_medium: Union[AnnotatedMedium, tuple[AnnotatedMedium, ...]] = Field( title="Clad Medium", description="Medium associated with the upper cladding layer. A sequence of mediums can " "be used to create a layered clad.", ) - box_medium: Union[AnnotatedMedium, tuple[AnnotatedMedium, ...]] = pydantic.Field( + box_medium: Optional[Union[AnnotatedMedium, tuple[AnnotatedMedium, ...]]] = Field( None, title="Box Medium", description="Medium associated with the lower cladding layer. A sequence of mediums can " "be used to create a layered substrate. If not set, the first clad medium is used.", ) - slab_thickness: Size1D = pydantic.Field( + slab_thickness: Size1D = Field( 0.0, title="Slab Thickness", description="Thickness of the slab for rib geometry.", units=MICROMETER, ) - clad_thickness: Union[Size1D, ArrayFloat1D] = pydantic.Field( + clad_thickness: Optional[Union[Size1D, ArrayFloat1D]] = Field( None, title="Clad Thickness", description="Domain size above the core layer. An array can be used to define a layered " @@ -105,7 +109,7 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - box_thickness: Union[Size1D, ArrayFloat1D] = pydantic.Field( + box_thickness: Optional[Union[Size1D, ArrayFloat1D]] = Field( None, title="Box Thickness", description="Domain size below the core layer. An array can be used to define a layered " @@ -113,14 +117,14 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - side_margin: Size1D = pydantic.Field( + side_margin: Optional[Size1D] = Field( None, title="Side Margin", description="Domain size to the sides of the waveguide core.", units=MICROMETER, ) - sidewall_angle: float = pydantic.Field( + sidewall_angle: float = Field( 0.0, title="Sidewall Angle", description="Angle of the core sidewalls measured from the vertical direction (in " @@ -129,7 +133,7 @@ class RectangularDielectric(Tidy3dBaseModel): units=RADIAN, ) - gap: Union[float, ArrayFloat1D] = pydantic.Field( + gap: Union[float, ArrayFloat1D] = Field( 0.0, title="Gap", description="Distance between adjacent waveguides, measured at the top core edges. 
" @@ -137,21 +141,21 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - sidewall_thickness: Size1D = pydantic.Field( + sidewall_thickness: Size1D = Field( 0.0, title="Sidewall Thickness", description="Sidewall layer thickness (within core).", units=MICROMETER, ) - sidewall_medium: MediumType = pydantic.Field( + sidewall_medium: Optional[MediumType] = Field( None, title="Sidewall medium", description="Medium associated with the sidewall layer to model sidewall losses.", discriminator=TYPE_TAG_STR, ) - surface_thickness: Size1D = pydantic.Field( + surface_thickness: Size1D = Field( 0.0, title="Surface Thickness", description="Thickness of the surface layers defined on the top of the waveguide and " @@ -159,14 +163,14 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - surface_medium: MediumType = pydantic.Field( + surface_medium: Optional[MediumType] = Field( None, title="Surface Medium", description="Medium associated with the surface layer to model surface losses.", discriminator=TYPE_TAG_STR, ) - origin: Coordinate = pydantic.Field( + origin: Coordinate = Field( (0, 0, 0), title="Origin", description="Center of the waveguide geometry. This coordinate represents the base " @@ -175,57 +179,61 @@ class RectangularDielectric(Tidy3dBaseModel): units=MICROMETER, ) - length: Size1D = pydantic.Field( + length: Size1D = Field( 1e30, title="Length", description="Length of the waveguides in the propagation direction", units=MICROMETER, ) - propagation_axis: Axis = pydantic.Field( + propagation_axis: Axis = Field( 0, title="Propagation Axis", description="Axis of propagation of the waveguide", ) - normal_axis: Axis = pydantic.Field( + normal_axis: Axis = Field( 2, title="Normal Axis", description="Axis normal to the substrate surface", ) - mode_spec: ModeSpec = pydantic.Field( - ModeSpec(num_modes=2), + mode_spec: ModeSpec = Field( + default_factory=lambda: ModeSpec(num_modes=2), title="Mode Specification", description=":class:`ModeSpec` defining waveguide mode properties.", ) - grid_resolution: int = pydantic.Field( + grid_resolution: int = Field( 15, title="Grid Resolution", description="Solver grid resolution per wavelength.", ) - max_grid_scaling: float = pydantic.Field( + max_grid_scaling: float = Field( 1.2, title="Maximal Grid Scaling", description="Maximal size increase between adjacent grid boundaries.", ) - @pydantic.validator("wavelength", "core_width", "gap", always=True) - def _set_non_negative_array(cls, val): + @field_validator("wavelength", "core_width", "gap") + @classmethod + def _set_non_negative_array(cls, val: Union[float, ArrayFloat1D]) -> Union[float, ArrayFloat1D]: """Ensure values are not negative and convert to numpy arrays.""" val = numpy.array(val, ndmin=1) if any(val < 0): raise ValidationError("Values may not be negative.") return val - @pydantic.validator("core_medium", "clad_medium", "box_medium") - def _check_non_metallic(cls, val, values): + @field_validator("core_medium", "clad_medium", "box_medium") + @classmethod + def _check_non_metallic( + cls, val: Union[MediumType, tuple[MediumType, ...]], info: ValidationInfo + ) -> Union[MediumType, tuple[MediumType, ...]]: if val is None: return val media = val if isinstance(val, tuple) else (val,) - freqs = C_0 / values["wavelength"] + freqs = C_0 / info.data["wavelength"] if any(medium.eps_model(f).real < 1 for medium in media for f in freqs): raise ValidationError( "'RectangularDielectric' can only be used with dielectric media. 
" @@ -233,60 +241,60 @@ def _check_non_metallic(cls, val, values): ) return val - @pydantic.validator("gap", always=True) - @skip_if_fields_missing(["core_width"]) - def _validate_gaps(cls, val, values): + @model_validator(mode="after") + def _validate_gaps(self) -> Self: """Ensure the number of gaps is compatible with the number of cores supplied.""" - if val.size == 1 and values["core_width"].size != 2: + if self.gap.size == 1 and self.core_width.size != 2: # If a single value is defined, use it for all gaps - return numpy.array([val[0]] * (values["core_width"].size - 1)) - if val.size != values["core_width"].size - 1: + object.__setattr__(self, "gap", numpy.array([self.gap[0]] * (self.core_width.size - 1))) + return self + if self.gap.size != self.core_width.size - 1: raise ValidationError("Number of gaps must be 1 less than number of core widths.") - return val + return self - @pydantic.root_validator - def _set_box_medium(cls, values): + @model_validator(mode="after") + def _set_box_medium(self) -> Self: """Set BOX medium same as cladding as default value.""" - box_medium = values.get("box_medium") - if box_medium is None: - clad_medium = values.get("clad_medium") - if clad_medium is None: - return values - if isinstance(clad_medium, tuple): - clad_medium = clad_medium[0] - values["box_medium"] = clad_medium - return values - - @pydantic.root_validator - def _set_clad_thickness(cls, values): + if self.box_medium is None: + if self.clad_medium is None: + return self + if isinstance(self.clad_medium, tuple): + object.__setattr__(self, "box_medium", self.clad_medium[0]) + else: + object.__setattr__(self, "box_medium", self.clad_medium) + return self + + @model_validator(mode="after") + def _set_clad_thickness(self) -> Self: """Set default clad/BOX thickness based on the max wavelength in the medium.""" for side in ("clad", "box"): - val = values.get(side + "_thickness") + val = getattr(self, side + "_thickness") if val is None: - wavelength = values.get("wavelength") - medium = values.get(side + "_medium") - if wavelength is None or medium is None: - return values + medium = getattr(self, side + "_medium") + if self.wavelength is None or medium is None: + return self if isinstance(medium, tuple): medium = medium[0] - n = numpy.array([medium.nk_model(f)[0] for f in C_0 / wavelength]) - lda = wavelength / n - values[side + "_thickness"] = EVANESCENT_TAIL * lda.max() + n = numpy.array([medium.nk_model(f)[0] for f in C_0 / self.wavelength]) + lda = self.wavelength / n + object.__setattr__(self, side + "_thickness", EVANESCENT_TAIL * lda.max()) elif isinstance(val, float): if val < 0: raise ValidationError("Thickness may not be negative.") else: - values[side + "_thickness"] = cls._set_non_negative_array(val) - return values + object.__setattr__( + self, side + "_thickness", type(self)._set_non_negative_array(val) + ) + return self - @pydantic.root_validator - def _validate_layers(cls, values): + @model_validator(mode="after") + def _validate_layers(self) -> Self: """Ensure the number of clad media is compatible with the number of layers supplied.""" for side in ("clad", "box"): - thickness = values.get(side + "_thickness") - medium = values.get(side + "_medium") + thickness = getattr(self, side + "_thickness") + medium = getattr(self, side + "_medium") if thickness is None or medium is None: - return values + return self num_layers = 1 if isinstance(thickness, float) else thickness.size num_media = 1 if not isinstance(medium, tuple) else len(medium) if num_layers != num_media: @@ -294,53 
+302,46 @@ def _validate_layers(cls, values): f"Number of '{side}_thickness' values ({num_layers}) must be equal to that of " f"'{side}_medium' ({num_media})." ) - return values + return self - @pydantic.root_validator - def _set_side_margin(cls, values): + @model_validator(mode="after") + def _set_side_margin(self) -> Self: """Set default side margin based on BOX and cladding thicknesses.""" - clad_thickness = values.get("clad_thickness") - box_thickness = values.get("box_thickness") + clad_thickness = self.clad_thickness + box_thickness = self.box_thickness if clad_thickness is None or box_thickness is None: - return values - if values["side_margin"] is None: + return self + if self.side_margin is None: if not isinstance(clad_thickness, float): clad_thickness = clad_thickness.sum() if not isinstance(box_thickness, float): box_thickness = box_thickness.sum() - values["side_margin"] = max(clad_thickness, box_thickness) - return values + object.__setattr__(self, "side_margin", max(clad_thickness, box_thickness)) + return self - @pydantic.root_validator - def _ensure_consistency(cls, values): + @model_validator(mode="after") + def _ensure_consistency(self) -> Self: """Ensure consistency in setting surface/sidewall models and propagation/normal axes.""" - sidewall_thickness = values["sidewall_thickness"] - sidewall_medium = values["sidewall_medium"] - surface_thickness = values["surface_thickness"] - surface_medium = values["surface_medium"] - propagation_axis = values["propagation_axis"] - normal_axis = values["normal_axis"] - - if sidewall_thickness > 0 and sidewall_medium is None: + if self.sidewall_thickness > 0 and self.sidewall_medium is None: raise ValidationError( "Sidewall medium must be provided when sidewall thickness is greater than 0." ) - if sidewall_thickness == 0 and sidewall_medium is not None: + if self.sidewall_thickness == 0 and self.sidewall_medium is not None: log.warning("Sidewall medium not used because sidewall thickness is zero.") - if surface_thickness > 0 and surface_medium is None: + if self.surface_thickness > 0 and self.surface_medium is None: raise ValidationError( "Surface medium must be provided when surface thickness is greater than 0." 
) - if surface_thickness == 0 and surface_medium is not None: + if self.surface_thickness == 0 and self.surface_medium is not None: log.warning("Surface medium not used because surface thickness is zero.") - if propagation_axis == normal_axis: + if self.propagation_axis == self.normal_axis: raise ValidationError("Propagation and normal axes must be different.") - return values + return self @property def _clad_medium(self) -> tuple[MediumType, ...]: @@ -604,7 +605,9 @@ def structures(self) -> list[Structure]: if self.mode_spec.bend_radius is None or self.mode_spec.bend_radius == 0.0: half_length = 0.5 * self.length - def polyslab_vertices(x, w): + def polyslab_vertices( + x: float, w: float + ) -> tuple[list[float], list[float], list[float], list[float]]: return ( self._transform_in_plane(x, -half_length), self._transform_in_plane(x + w, -half_length), @@ -630,7 +633,7 @@ def polyslab_vertices(x, w): sin = numpy.sin(angles) cos = numpy.cos(angles) - def polyslab_vertices(x, w): + def polyslab_vertices(x: float, w: float) -> list[list[float]]: r_in = bend_radius + x v_in = numpy.vstack((-bend_radius + r_in * cos, r_in * sin)).T r_out = r_in + w @@ -1095,7 +1098,7 @@ def plot_geometry_edges( def plot_field( self, field_name: str, - val: Literal["real", "imag", "abs"] = "real", + val: Literal["real", "imag", "abs"] = "real", eps_alpha: float = 0.2, robust: bool = True, vmin: Optional[float] = None, diff --git a/tidy3d/updater.py b/tidy3d/updater.py index 867dd14cc5..cea3c23731 100644 --- a/tidy3d/updater.py +++ b/tidy3d/updater.py @@ -4,22 +4,24 @@ import functools import json -from os import PathLike from pathlib import Path -from typing import Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional -import pydantic.v1 as pd import yaml +from pydantic import BaseModel from .components.base import Tidy3dBaseModel from .exceptions import FileError, SetupError from .log import log from .version import __version__ +if TYPE_CHECKING: + from os import PathLike + """Storing version numbers.""" -class Version(pd.BaseModel): +class Version(BaseModel): """Stores a version number (excluding patch).""" major: int @@ -39,23 +41,25 @@ def from_string(cls, string: Optional[str] = None) -> Version: return version @property - def as_tuple(self): + def as_tuple(self) -> tuple[int, int]: """version as a tuple, leave out patch for now.""" return (self.major, self.minor) - def __hash__(self): + def __hash__(self) -> int: """define a hash.""" return hash(self.as_tuple) - def __str__(self): + def __str__(self) -> str: """Convert back to string.""" return f"{self.major}.{self.minor}" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """versions equal.""" + if not isinstance(other, Version): + return False return (self.major == other.major) and (self.minor == other.minor) - def __lt__(self, other): + def __lt__(self, other: Version) -> bool: """self < other.""" if self.major < other.major: return True @@ -63,7 +67,7 @@ def __lt__(self, other): return self.minor < other.minor return False - def __gt__(self, other): + def __gt__(self, other: Version) -> bool: """self > other.""" if self.major > other.major: return True @@ -71,11 +75,11 @@ def __gt__(self, other): return self.minor > other.minor return False - def __le__(self, other): + def __le__(self, other: Version) -> bool: """self <= other.""" return (self < other) or (self == other) - def __ge__(self, other): + def __ge__(self, other: Version) -> bool: """self >= other.""" return (self > other) or (self == other) @@ -85,10 
+89,10 @@ def __ge__(self, other): """Class for updating simulation objects.""" -class Updater(pd.BaseModel): +class Updater(BaseModel): """Converts a tidy3d simulation.json file to an up-to-date Simulation instance.""" - sim_dict: dict + sim_dict: dict[str, Any] @classmethod def from_file(cls, fname: PathLike) -> Updater: @@ -96,7 +100,7 @@ def from_file(cls, fname: PathLike) -> Updater: path = Path(fname) # TODO: fix this, it broke if path.suffix in {".hdf5", ".gz"}: - sim_dict = Tidy3dBaseModel.from_file(fname=str(path)).dict() + sim_dict = Tidy3dBaseModel.from_file(fname=str(path)).model_dump() else: with path.open(encoding="utf-8") as f: if path.suffix == ".json": @@ -121,7 +125,7 @@ def version(self) -> Version: raise SetupError("Could not find a version in the supplied json.") return Version.from_string(version_string) - def get_update_function(self): + def get_update_function(self) -> Callable[[dict[str, Any]], dict[str, Any]]: """Get the highest update version <= self.version.""" leq_versions = [v for v in UPDATE_MAP if v <= self.version] if not leq_versions: @@ -134,9 +138,9 @@ def get_next_version(self) -> Version: gt_versions = [v for v in UPDATE_MAP if v > self.version] if not gt_versions: return CurrentVersion - return str(min(gt_versions)) + return min(gt_versions) - def update_to_current(self) -> dict: + def update_to_current(self) -> dict[str, Any]: """Update supplied simulation dictionary to current version.""" if self.version == CurrentVersion: self.sim_dict["version"] = __version__ @@ -149,28 +153,33 @@ def update_to_current(self) -> dict: self.sim_dict["version"] = __version__ return self.sim_dict - def __eq__(self, other: Updater) -> bool: + def __eq__(self, other: object) -> bool: """Is Updater equal to another one?""" + if not isinstance(other, Updater): + return False return self.sim_dict == other.sim_dict """Update conversion functions.""" # versions will be dynamically mapped in this table when the update functions are initialized. -UPDATE_MAP = {} +UpdateFn = Callable[[dict[str, Any]], dict[str, Any]] +TransformFn = Callable[[dict[str, Any]], Optional[dict[str, Any]]] + +UPDATE_MAP: dict[Version, UpdateFn] = {} -def updates_from_version(version_from_string: str): +def updates_from_version(version_from_string: str) -> Callable[[UpdateFn], UpdateFn]: """Decorates a sim_dict update function to change the version.""" # make sure the version strings are legit from_version = Version.from_string(version_from_string) - def decorator(update_fn): + def decorator(update_fn: UpdateFn) -> UpdateFn: """The actual decorator that gets returned by `updates_to_version('x.y.z')`""" @functools.wraps(update_fn) - def new_update_function(sim_dict: dict) -> dict: + def new_update_function(sim_dict: dict[str, Any]) -> dict[str, Any]: """Update function that automatically adds version string.""" return update_fn(sim_dict) @@ -182,7 +191,7 @@ def new_update_function(sim_dict: dict) -> dict: return decorator -def iterate_update_dict(update_dict: dict, update_types: dict[str, Callable]) -> None: +def iterate_update_dict(update_dict: Any, update_types: dict[str, TransformFn]) -> None: """Recursively iterate nested ``update_dict``. For any nested ``nested_dict`` found, apply an update function if its ``nested_dict["type"]`` is in the keys of the ``update_types`` dictionary. Also iterates lists and tuples. 
@@ -201,10 +210,10 @@ def iterate_update_dict(update_dict: dict, update_types: dict[str, Callable]) -> @updates_from_version("1.8") -def update_1_8(sim_dict: dict) -> dict: +def update_1_8(sim_dict: dict[str, Any]) -> dict[str, Any]: """Updates version 1.8.""" - def fix_missing_scalar_field(mnt_dict: dict) -> dict: + def fix_missing_scalar_field(mnt_dict: dict[str, Any]) -> dict[str, Any]: for key, val in mnt_dict["field_dataset"].items(): if isinstance(val, str) and val == "XR.DATAARRAY": mnt_dict["field_dataset"][key] = "ScalarFieldDataArray" @@ -220,17 +229,17 @@ def fix_missing_scalar_field(mnt_dict: dict) -> dict: @updates_from_version("1.7") -def update_1_7(sim_dict: dict) -> dict: +def update_1_7(sim_dict: dict[str, Any]) -> dict[str, Any]: """Updates version 1.7.""" - def fix_angle_info(mnt_dict: dict) -> dict: + def fix_angle_info(mnt_dict: dict[str, Any]) -> dict[str, Any]: mnt_dict["type"] = "FieldProjectionAngleMonitor" mnt_dict.pop("fields") mnt_dict.pop("medium") mnt_dict["proj_distance"] = 1e6 return mnt_dict - def fix_cartesian_info(mnt_dict: dict) -> dict: + def fix_cartesian_info(mnt_dict: dict[str, Any]) -> dict[str, Any]: mnt_dict["type"] = "FieldProjectionCartesianMonitor" mnt_dict.pop("fields") mnt_dict.pop("medium") @@ -240,7 +249,7 @@ def fix_cartesian_info(mnt_dict: dict) -> dict: mnt_dict["proj_axis"] = axis return mnt_dict - def fix_kspace_info(mnt_dict: dict) -> dict: + def fix_kspace_info(mnt_dict: dict[str, Any]) -> dict[str, Any]: mnt_dict["type"] = "FieldProjectionKSpaceMonitor" mnt_dict.pop("fields") mnt_dict.pop("medium") @@ -249,13 +258,13 @@ def fix_kspace_info(mnt_dict: dict) -> dict: mnt_dict["proj_axis"] = axis return mnt_dict - def fix_diffraction_info(mnt_dict: dict) -> dict: + def fix_diffraction_info(mnt_dict: dict[str, Any]) -> dict[str, Any]: mnt_dict.pop("medium", None) mnt_dict.pop("orders_x", None) mnt_dict.pop("orders_y", None) return mnt_dict - def fix_bloch_vec(mnt_dict: dict) -> dict: + def fix_bloch_vec(mnt_dict: dict[str, Any]) -> dict[str, Any]: mnt_dict["bloch_vec"] = mnt_dict["bloch_vec"]["real"] return mnt_dict @@ -273,7 +282,7 @@ def fix_bloch_vec(mnt_dict: dict) -> dict: @updates_from_version("1.6") -def update_1_6(sim_dict: dict) -> dict: +def update_1_6(sim_dict: dict[str, Any]) -> dict[str, Any]: """Updates version 1.6.""" if "grid_size" in sim_dict: sim_dict.pop("grid_size") @@ -281,10 +290,10 @@ def update_1_6(sim_dict: dict) -> dict: @updates_from_version("1.5") -def update_1_5(sim_dict: dict) -> dict: +def update_1_5(sim_dict: dict[str, Any]) -> dict[str, Any]: """Updates version 1.5.""" - def fix_mode_field_mnt(mnt_dict: dict) -> dict: + def fix_mode_field_mnt(mnt_dict: dict[str, Any]) -> dict[str, Any]: mnt_dict["type"] = "ModeSolverMonitor" return mnt_dict @@ -293,15 +302,15 @@ def fix_mode_field_mnt(mnt_dict: dict) -> dict: @updates_from_version("1.4") -def update_1_4(sim_dict: dict) -> dict: +def update_1_4(sim_dict: dict[str, Any]) -> dict[str, Any]: """Updates version 1.4.""" - def fix_polyslab(geo_dict) -> None: + def fix_polyslab(geo_dict: dict[str, Any]) -> None: """Fix a PolySlab dictionary.""" geo_dict.pop("length", None) geo_dict.pop("center", None) - def fix_modespec(ms_dict) -> None: + def fix_modespec(ms_dict: dict[str, Any]) -> None: """Fix a ModeSpec dictionary.""" sort_by = ms_dict.pop("sort_by", None) if sort_by and sort_by != "largest_neff": @@ -310,7 +319,7 @@ def fix_modespec(ms_dict) -> None: "largest effective index. Use ModeSpec.filter_pol to select polarization instead." 
) - def fix_geometry_group(geo_dict) -> None: + def fix_geometry_group(geo_dict: dict[str, Any]) -> None: """Fix a GeometryGroup dictionary.""" geo_dict.pop("center", None) @@ -326,7 +335,7 @@ def fix_geometry_group(geo_dict) -> None: @updates_from_version("1.3") -def update_1_3(sim_dict: dict) -> dict: +def update_1_3(sim_dict: dict[str, Any]) -> dict[str, Any]: """Updates version 1.3.""" sim_dict["boundary_spec"] = {"x": {}, "y": {}, "z": {}} diff --git a/tidy3d/version.py b/tidy3d/version.py index 53cf26a3a3..cc406c5611 100644 --- a/tidy3d/version.py +++ b/tidy3d/version.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "2.10.0" +__version__ = "2.10.2" diff --git a/tidy3d/web/api/asynchronous.py b/tidy3d/web/api/asynchronous.py index 3dc63f3491..fbc7c64401 100644 --- a/tidy3d/web/api/asynchronous.py +++ b/tidy3d/web/api/asynchronous.py @@ -2,14 +2,20 @@ from __future__ import annotations -from os import PathLike -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING -from tidy3d.components.types.workflow import WorkflowType from tidy3d.log import log from tidy3d.web.core.types import PayType -from .container import DEFAULT_DATA_DIR, Batch, BatchData +from .container import DEFAULT_DATA_DIR, Batch + +if TYPE_CHECKING: + from os import PathLike + from typing import Literal, Optional, Union + + from tidy3d.components.types.workflow import WorkflowType + + from .container import BatchData def run_async( @@ -34,7 +40,7 @@ def run_async( Parameters ---------- - simulations : Union[Dict[str, Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]], tuple[Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]], list[Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]]] + simulations : Union[dict[str, Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]], tuple[Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]], list[Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`]]] Mapping of task name to simulation or list of simulations. folder_name : str = "default" Name of folder to store each task on web UI. 
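The hunks on either side of this point repeat two mechanical patterns worth spelling out once: imports used only in annotations move under an ``if TYPE_CHECKING:`` guard, which is safe because ``from __future__ import annotations`` defers annotation evaluation, and pydantic v1 validators that read ``values`` become v2 "after" model validators that operate on ``self``. Below is a minimal, self-contained sketch of both patterns; ``ExamplePort``, its fields, and ``transverse_sizes`` are illustrative stand-ins rather than tidy3d names, and ``Self`` is taken from ``typing_extensions`` here instead of ``tidy3d.compat``.

from __future__ import annotations  # annotations are stored as strings, not evaluated at runtime

from typing import TYPE_CHECKING

from pydantic import BaseModel, Field, model_validator

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime, which avoids
    # import cycles and keeps module import time down.
    from collections.abc import Sequence

    from typing_extensions import Self


class ExamplePort(BaseModel):
    """Illustrative model, not a tidy3d class."""

    size: tuple[float, float, float] = Field(title="Size")
    voltage_axis: int = Field(title="Voltage Integration Axis")

    # pydantic v1 wrote cross-field checks as
    #     @pd.validator("voltage_axis", always=True)
    #     def _check(cls, val, values): ...
    # pydantic v2 replaces them with "after" validators that receive the
    # fully constructed instance and must return it (or raise).
    @model_validator(mode="after")
    def _voltage_axis_in_plane(self) -> Self:
        if 0.0 in self.size and self.voltage_axis == self.size.index(0.0):
            raise ValueError("'voltage_axis' must lie in the port's plane.")
        return self

    def transverse_sizes(self, axes: Sequence[int]) -> list[float]:
        # 'Sequence' resolves lazily thanks to the future import above
        return [self.size[axis] for axis in axes]

Only annotation-only names can move behind the guard: anything executed at import or validation time (``BaseModel``, ``Field``, ``model_validator``, default factories) has to stay a real import, which is why the runtime import lists in these hunks still carry them.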
diff --git a/tidy3d/web/api/autograd/autograd.py b/tidy3d/web/api/autograd/autograd.py index c2e4eb965c..c664f49605 100644 --- a/tidy3d/web/api/autograd/autograd.py +++ b/tidy3d/web/api/autograd/autograd.py @@ -1,26 +1,20 @@ # autograd wrapper for web functions from __future__ import annotations -import typing -from os import PathLike from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any, get_args -from autograd.builtins import dict as dict_ag from autograd.extend import defvjp, primitive import tidy3d as td -from tidy3d.components.autograd import AutogradFieldMap +from tidy3d.components.autograd.types import TracedDict from tidy3d.components.base import TRACED_FIELD_KEYS_ATTR -from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType from tidy3d.config import config from tidy3d.exceptions import AdjointError +from tidy3d.web.api import asynchronous as asynchronous_webapi +from tidy3d.web.api import webapi from tidy3d.web.api.asynchronous import DEFAULT_DATA_DIR -from tidy3d.web.api.asynchronous import run_async as run_async_webapi -from tidy3d.web.api.container import BatchData from tidy3d.web.api.tidy3d_stub import Tidy3dStub -from tidy3d.web.api.webapi import load, restore_simulation_if_cached -from tidy3d.web.api.webapi import run as run_webapi from tidy3d.web.core.types import PayType from .backward import postprocess_adj as _postprocess_adj_impl @@ -51,8 +45,16 @@ upload_sim_fields_keys as _upload_sim_fields_keys_impl, ) +if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Literal, Optional, Union -def _resolve_local_gradient(value: typing.Optional[bool]) -> bool: + from tidy3d.components.autograd import AutogradFieldMap + from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType + from tidy3d.web.api.container import BatchData + + +def _resolve_local_gradient(value: Optional[bool]) -> bool: if value is not None: return bool(value) @@ -102,23 +104,23 @@ def is_valid_for_autograd_async(simulations: dict[str, td.Simulation]) -> bool: def run( simulation: WorkflowType, - task_name: typing.Optional[str] = None, + task_name: Optional[str] = None, folder_name: str = "default", path: PathLike = "simulation_data.hdf5", - callback_url: typing.Optional[str] = None, + callback_url: Optional[str] = None, verbose: bool = True, - progress_callback_upload: typing.Optional[typing.Callable[[float], None]] = None, - progress_callback_download: typing.Optional[typing.Callable[[float], None]] = None, - solver_version: typing.Optional[str] = None, - worker_group: typing.Optional[str] = None, + progress_callback_upload: Optional[Callable[[float], None]] = None, + progress_callback_download: Optional[Callable[[float], None]] = None, + solver_version: Optional[str] = None, + worker_group: Optional[str] = None, simulation_type: str = "tidy3d", - parent_tasks: typing.Optional[list[str]] = None, - local_gradient: typing.Optional[bool] = None, - max_num_adjoint_per_fwd: typing.Optional[int] = None, - reduce_simulation: typing.Literal["auto", True, False] = "auto", - pay_type: typing.Union[PayType, str] = PayType.AUTO, - priority: typing.Optional[int] = None, - lazy: typing.Optional[bool] = None, + parent_tasks: Optional[list[str]] = None, + local_gradient: Optional[bool] = None, + max_num_adjoint_per_fwd: Optional[int] = None, + reduce_simulation: Literal["auto", True, False] = "auto", + pay_type: Union[PayType, str] = PayType.AUTO, + priority: Optional[int] = None, + lazy: Optional[bool] = None, ) -> WorkflowDataType: 
""" Submits a :class:`.Simulation` to server, starts running, monitors progress, downloads, @@ -155,11 +157,11 @@ def run( but apply the configuration overrides defined in ``config.adjoint``; remote gradients ignore those overrides and enforce backend defaults. more stable with experimental features. - max_num_adjoint_per_fwd: typing.Optional[int] = None + max_num_adjoint_per_fwd: Optional[int] = None Maximum number of adjoint simulations allowed to run automatically. Uses the autograd configuration when None. reduce_simulation: Literal["auto", True, False] = "auto" Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver. - pay_type: typing.Union[PayType, str] = PayType.AUTO + pay_type: Union[PayType, str] = PayType.AUTO Which method to pay for the simulation. priority: int = None Task priority for vGPU queue (1=lowest, 10=highest). @@ -229,7 +231,7 @@ def run( path = Path(path) - if isinstance(simulation, typing.get_args(ComponentModelerType)): + if isinstance(simulation, get_args(ComponentModelerType)): if any(is_valid_for_autograd(s) for s in simulation.sim_dict.values()): from tidy3d.plugins.smatrix import run as smatrix_run @@ -268,7 +270,7 @@ def run( lazy=lazy, ) - return run_webapi( + return webapi.run( simulation=simulation, task_name=task_name, folder_name=folder_name, @@ -289,21 +291,21 @@ def run( def run_async( - simulations: typing.Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], + simulations: Union[dict[str, td.Simulation], tuple[td.Simulation], list[td.Simulation]], folder_name: str = "default", path_dir: PathLike = DEFAULT_DATA_DIR, - callback_url: typing.Optional[str] = None, - num_workers: typing.Optional[int] = None, + callback_url: Optional[str] = None, + num_workers: Optional[int] = None, verbose: bool = True, simulation_type: str = "tidy3d", - solver_version: typing.Optional[str] = None, - parent_tasks: typing.Optional[dict[str, list[str]]] = None, - local_gradient: typing.Optional[bool] = None, - max_num_adjoint_per_fwd: typing.Optional[int] = None, - reduce_simulation: typing.Literal["auto", True, False] = "auto", - pay_type: typing.Union[PayType, str] = PayType.AUTO, - priority: typing.Optional[int] = None, - lazy: typing.Optional[bool] = None, + solver_version: Optional[str] = None, + parent_tasks: Optional[dict[str, list[str]]] = None, + local_gradient: Optional[bool] = None, + max_num_adjoint_per_fwd: Optional[int] = None, + reduce_simulation: Literal["auto", True, False] = "auto", + pay_type: Union[PayType, str] = PayType.AUTO, + priority: Optional[int] = None, + lazy: Optional[bool] = None, ) -> BatchData: """Submits a set of Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] objects to server, starts running, monitors progress, downloads, and loads results as a :class:`.BatchData` object. @@ -333,13 +335,13 @@ def run_async( Whether to perform gradient calculations locally. Defaults to ``config.adjoint.local_gradient`` when not provided. Local gradients require more downloads but ensure autograd overrides take effect; remote gradients ignore those overrides. - max_num_adjoint_per_fwd: typing.Optional[int] = None + max_num_adjoint_per_fwd: Optional[int] = None Maximum number of adjoint simulations allowed to run automatically. Uses the autograd configuration when None. reduce_simulation: Literal["auto", True, False] = "auto" Whether to reduce structures in the simulation to the simulation domain only. 
Note: currently only implemented for the mode solver. - pay_type: typing.Union[PayType, str] = PayType.AUTO + pay_type: Union[PayType, str] = PayType.AUTO Specify the payment method. - priority: typing.Optional[int] = None + priority: Optional[int] = None Queue priority for vGPU simulations (1=lowest, 10=highest). lazy: Optional[bool] = None Whether to return lazy data proxies. Defaults to ``True`` for batch runs when @@ -398,7 +400,7 @@ def run_async( lazy=lazy, ) - return run_async_webapi( + return asynchronous_webapi.run_async( simulations=simulations, folder_name=folder_name, path_dir=path_dir, @@ -422,7 +424,7 @@ def _run( simulation: td.Simulation, task_name: str, local_gradient: bool = False, - max_num_adjoint_per_fwd: typing.Optional[int] = None, + max_num_adjoint_per_fwd: Optional[int] = None, **run_kwargs: Any, ) -> td.SimulationData: """User-facing ``web.run`` function, compatible with ``autograd`` differentiation.""" @@ -465,7 +467,7 @@ def _run( def _run_async( simulations: dict[str, td.Simulation], local_gradient: bool = False, - max_num_adjoint_per_fwd: typing.Optional[int] = None, + max_num_adjoint_per_fwd: Optional[int] = None, **run_async_kwargs: Any, ) -> dict[str, td.SimulationData]: """User-facing ``web.run_async`` function, compatible with ``autograd`` differentiation.""" @@ -483,7 +485,7 @@ def _run_async( if payload: sim_static.attrs[TRACED_FIELD_KEYS_ATTR] = payload sims_original[task_name] = sim_static - traced_fields_sim_dict = dict_ag(traced_fields_sim_dict) + traced_fields_sim_dict = TracedDict(traced_fields_sim_dict) # TODO: shortcut primitive running for any items with no tracers? @@ -511,7 +513,6 @@ def _run_async( def setup_run(simulation: td.Simulation) -> AutogradFieldMap: """Process a user-supplied ``Simulation`` into inputs to ``_run_primitive``.""" - # get a mapping of all the traced fields in the provided simulation return simulation._strip_traced_fields( include_untraced_data_arrays=False, starting_path=("structures",) @@ -563,7 +564,7 @@ def _run_primitive( ) else: sim_original = sim_original.updated_copy(simulation_type="autograd_fwd", deep=False) - restored_path, task_id_fwd = restore_simulation_if_cached( + restored_path, task_id_fwd = webapi.restore_simulation_if_cached( simulation=sim_original, path=run_kwargs.get("path", None), reduce_simulation=run_kwargs.get("reduce_simulation", "auto"), @@ -580,7 +581,7 @@ def _run_primitive( **run_kwargs, ) else: - sim_data_orig = load( + sim_data_orig = webapi.load( task_id=None, path=run_kwargs.get("path", None), verbose=run_kwargs.get("verbose", None), @@ -602,7 +603,7 @@ def _run_primitive( def _run_async_primitive( sim_fields_dict: dict[str, AutogradFieldMap], sims_original: dict[str, td.Simulation], - aux_data_dict: dict[dict[str, typing.Any]], + aux_data_dict: dict[str, dict[str, Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, **run_async_kwargs: Any, @@ -711,7 +712,7 @@ def _run_bwd( local_gradient: bool, max_num_adjoint_per_fwd: int, **run_kwargs: Any, -) -> typing.Callable[[AutogradFieldMap], AutogradFieldMap]: +) -> Callable[[AutogradFieldMap], AutogradFieldMap]: """VJP-maker for ``_run_primitive()``. 
Constructs and runs adjoint simulations, computes grad.""" # indicate this is an adjoint run @@ -832,11 +833,11 @@ def _run_async_bwd( data_fields_original_dict: dict[str, AutogradFieldMap], sim_fields_original_dict: dict[str, AutogradFieldMap], sims_original: dict[str, td.Simulation], - aux_data_dict: dict[str, dict[str, typing.Any]], + aux_data_dict: dict[str, dict[str, Any]], local_gradient: bool, max_num_adjoint_per_fwd: int, **run_async_kwargs: Any, -) -> typing.Callable[[dict[str, AutogradFieldMap]], dict[str, AutogradFieldMap]]: +) -> Callable[[dict[str, AutogradFieldMap]], dict[str, AutogradFieldMap]]: """VJP-maker for ``_run_primitive()``. Constructs and runs adjoint simulation, computes grad.""" # indicate this is an adjoint run diff --git a/tidy3d/web/api/autograd/backward.py b/tidy3d/web/api/autograd/backward.py index 2e90656caf..9573fa7df2 100644 --- a/tidy3d/web/api/autograd/backward.py +++ b/tidy3d/web/api/autograd/backward.py @@ -1,13 +1,13 @@ from __future__ import annotations from collections import defaultdict +from typing import TYPE_CHECKING import numpy as np import xarray as xr import tidy3d as td -from tidy3d import Medium -from tidy3d.components.autograd import AutogradFieldMap, get_static +from tidy3d.components.autograd import get_static from tidy3d.components.autograd.derivative_utils import DerivativeInfo from tidy3d.components.data.data_array import DataArray from tidy3d.config import config @@ -16,6 +16,10 @@ from .utils import E_to_D, get_derivative_maps +if TYPE_CHECKING: + from tidy3d import Medium + from tidy3d.components.autograd import AutogradFieldMap + def setup_adj( data_fields_vjp: AutogradFieldMap, @@ -27,10 +31,14 @@ td.log.info("Running custom vjp (adjoint) pipeline.") - # filter out any data_fields_vjp with all 0's - data_fields_vjp = { - k: get_static(v) for k, v in data_fields_vjp.items() if not np.allclose(v, 0) - } + # filter out any data_fields_vjp entries that are exactly all zeros + data_fields_vjp_static = {} + for k, v in data_fields_vjp.items(): + v_static = get_static(v) + if np.count_nonzero(v_static) == 0: + continue + data_fields_vjp_static[k] = v_static + data_fields_vjp = data_fields_vjp_static for k, v in data_fields_vjp.items(): if np.any(np.isnan(v)): @@ -146,7 +154,6 @@ def postprocess_adj( D_fwd = E_to_D(fld_fwd, eps_fwd) D_adj = E_to_D(fld_adj, eps_fwd) - # compute the derivatives for this structure structure = sim_data_fwd.simulation.structures[structure_index] # compute epsilon arrays for all frequencies @@ -165,19 +172,9 @@ f"but derivative map has: {adjoint_frequencies}. 
" ) - eps_in = _compute_eps_array(structure.medium, adjoint_frequencies) - eps_out = _compute_eps_array(sim_data_orig.simulation.medium, adjoint_frequencies) - - # handle background medium if present - if structure.background_medium: - eps_background = _compute_eps_array(structure.background_medium, adjoint_frequencies) - else: - eps_background = None - # auto permittivity detection sim_orig = sim_data_orig.simulation plane_eps = eps_fwd.monitor.geometry - sim_orig_grid_spec = td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid) # permittivity without this structure @@ -187,8 +184,23 @@ def postprocess_adj( structures=structs_no_struct, monitors=[], sources=[], grid_spec=sim_orig_grid_spec ) + # for the outside permittivity of the structure, resize the bounds of the permittivity region + # to make sure we capture data outside the structure bounds + low_coords = [center - 0.5 * size for center, size in zip(plane_eps.center, plane_eps.size)] + high_coords = [ + center + 0.5 * size for center, size in zip(plane_eps.center, plane_eps.size) + ] + + low_bounds = sim_orig.grid.boundaries.get_bounding_values(low_coords, "left", buffer=1) + high_bounds = sim_orig.grid.boundaries.get_bounding_values(high_coords, "right", buffer=1) + + resized_center = [0.5 * (low + high) for low, high in zip(low_bounds, high_bounds)] + resized_size = [(high - low) for low, high in zip(low_bounds, high_bounds)] + + resize_plane_eps = plane_eps.updated_copy(center=resized_center, size=resized_size) + eps_no_structure_data = [ - sim_no_structure.epsilon(box=plane_eps, coord_key="centers", freq=f) + sim_no_structure.epsilon(box=resize_plane_eps, coord_key="centers", freq=f) for f in adjoint_frequencies ] @@ -196,14 +208,29 @@ def postprocess_adj( f=adjoint_frequencies ) - if structure.medium.is_pec: + if structure.medium.is_custom: + # we can't make an infinite structure from a custom medium permittivity eps_inf_structure = None else: + geometry_box = structure.geometry.bounding_box + background_structures_2d = [] + sim_inf_background_medium = sim_orig.medium + if np.any(np.array(geometry_box.size) == 0.0): + zero_coordinate = tuple(geometry_box.size).index(0.0) + new_size = [td.inf, td.inf, td.inf] + new_size[zero_coordinate] = 0.0 + + background_structures_2d = [ + structure.updated_copy(geometry=geometry_box.updated_copy(size=new_size)) + ] + else: + sim_inf_background_medium = structure.medium + # permittivity with infinite structure structs_inf_struct = list(sim_orig.structures)[structure_index + 1 :] sim_inf_structure = sim_orig.updated_copy( - structures=structs_inf_struct, - medium=structure.medium, + structures=background_structures_2d + structs_inf_struct, + medium=sim_inf_background_medium, monitors=[], sources=[], grid_spec=sim_orig_grid_spec, @@ -220,7 +247,7 @@ def postprocess_adj( # compute bounds intersection struct_bounds = rmin_struct, rmax_struct = structure.geometry.bounds - rmin_sim, rmax_sim = sim_data_orig.simulation.bounds + rmin_sim, rmax_sim = sim_orig.bounds rmin_intersect = tuple([max(a, b) for a, b in zip(rmin_sim, rmin_struct)]) rmax_intersect = tuple([min(a, b) for a, b in zip(rmax_sim, rmax_struct)]) bounds_intersect = (rmin_intersect, rmax_intersect) @@ -271,11 +298,6 @@ def postprocess_adj( ) # slice epsilon arrays - eps_in_chunk = eps_in.sel(f=select_adjoint_freqs) - eps_out_chunk = eps_out.sel(f=select_adjoint_freqs) - eps_background_chunk = ( - eps_background.sel(f=select_adjoint_freqs) if eps_background is not None else None - ) eps_no_structure_chunk = ( 
eps_no_structure.sel(f=select_adjoint_freqs) if eps_no_structure is not None @@ -300,16 +322,15 @@ def postprocess_adj( H_fwd=H_fwd_chunk, H_adj=H_adj_chunk, eps_data=eps_data_chunk, - eps_in=eps_in_chunk, - eps_out=eps_out_chunk, - eps_background=eps_background_chunk, + eps_in=eps_inf_structure_chunk, + eps_out=eps_no_structure_chunk, frequencies=select_adjoint_freqs, # only chunk frequencies - eps_no_structure=eps_no_structure_chunk, - eps_inf_structure=eps_inf_structure_chunk, bounds=struct_bounds, bounds_intersect=bounds_intersect, simulation_bounds=sim_data_orig.simulation.bounds, is_medium_pec=structure.medium.is_pec, + background_medium_is_pec=structure.background_medium + and structure.background_medium.is_pec, ) # compute derivatives for chunk diff --git a/tidy3d/web/api/autograd/engine.py b/tidy3d/web/api/autograd/engine.py index c9f36e0a42..8f78edde73 100644 --- a/tidy3d/web/api/autograd/engine.py +++ b/tidy3d/web/api/autograd/engine.py @@ -11,7 +11,7 @@ def parse_run_kwargs(**run_kwargs: Any) -> dict[str, Any]: """Parse the ``run_kwargs`` to extract what should be passed to the ``Job``/``Batch`` init.""" - job_fields = [*list(Job._upload_fields), "solver_version", "pay_type", "lazy"] + job_fields = [*list(Job._upload_fields.default), "solver_version", "pay_type", "lazy"] job_init_kwargs = {k: v for k, v in run_kwargs.items() if k in job_fields} return job_init_kwargs diff --git a/tidy3d/web/api/autograd/forward.py b/tidy3d/web/api/autograd/forward.py index 7804f6a8a2..ba2ac1d75c 100644 --- a/tidy3d/web/api/autograd/forward.py +++ b/tidy3d/web/api/autograd/forward.py @@ -1,7 +1,10 @@ from __future__ import annotations -import tidy3d as td -from tidy3d.components.autograd import AutogradFieldMap +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import tidy3d as td + from tidy3d.components.autograd import AutogradFieldMap def setup_fwd( @@ -11,6 +14,9 @@ def setup_fwd( ) -> td.Simulation: """Return a forward simulation with adjoint monitors attached.""" + # Ensure there aren't any traced geometries with custom media + sim_original._check_custom_medium_geometry_overlap(sim_fields) + # Always try to build the variant that includes adjoint monitors so that # errors in monitor placement are caught early. 
sim_with_adj_mon = sim_original._with_adjoint_monitors(sim_fields) diff --git a/tidy3d/web/api/autograd/io_utils.py b/tidy3d/web/api/autograd/io_utils.py index 08d03b0f47..549995dfe5 100644 --- a/tidy3d/web/api/autograd/io_utils.py +++ b/tidy3d/web/api/autograd/io_utils.py @@ -2,9 +2,9 @@ import os import tempfile +from typing import TYPE_CHECKING import tidy3d as td -from tidy3d.components.autograd import AutogradFieldMap from tidy3d.components.autograd.field_map import FieldMap, TracerKeys from tidy3d.web.api.webapi import get_info, load_simulation from tidy3d.web.cache import resolve_local_cache @@ -12,6 +12,9 @@ from .constants import SIM_FIELDS_KEYS_FILE, SIM_VJP_FILE +if TYPE_CHECKING: + from tidy3d.components.autograd import AutogradFieldMap + def upload_sim_fields_keys( sim_fields_keys: list[tuple], task_id: str, verbose: bool = False @@ -50,7 +53,6 @@ def get_vjp_traced_fields(task_id_adj: str, verbose: bool) -> AutogradFieldMap: with tempfile.NamedTemporaryFile(suffix=".hdf5") as tmp_file: simulation = load_simulation(task_id_adj, path=tmp_file.name, verbose=False) simulation_cache.store_result( - stub_data=field_map, task_id=task_id_adj, path=fname, workflow_type=workflow_type, diff --git a/tidy3d/web/api/autograd/utils.py b/tidy3d/web/api/autograd/utils.py index f58f2f3c34..786dbcf56d 100644 --- a/tidy3d/web/api/autograd/utils.py +++ b/tidy3d/web/api/autograd/utils.py @@ -1,12 +1,15 @@ # utility functions for autograd web API from __future__ import annotations -import typing +from typing import TYPE_CHECKING import numpy as np import tidy3d as td +if TYPE_CHECKING: + from typing import Optional, Union + """ E and D field gradient map calculation helpers. """ @@ -15,7 +18,7 @@ def get_derivative_maps( eps_fwd: td.PermittivityData, fld_adj: td.FieldData, eps_adj: td.PermittivityData, -) -> dict[str, td.FieldData]: +) -> dict[str, Optional[td.FieldData]]: """Get electric and displacement field derivative maps.""" der_map_E = derivative_map_E(fld_fwd=fld_fwd, fld_adj=fld_adj) der_map_D = derivative_map_D(fld_fwd=fld_fwd, eps_fwd=eps_fwd, fld_adj=fld_adj, eps_adj=eps_adj) @@ -58,11 +61,11 @@ def E_to_D(fld_data: td.FieldData, eps_data: td.PermittivityData) -> td.FieldDat def multiply_field_data( - fld_1: td.FieldData, fld_2: typing.Union[td.FieldData, td.PermittivityData], fld_key: str + fld_1: td.FieldData, fld_2: Union[td.FieldData, td.PermittivityData], fld_key: str ) -> td.FieldData: """Elementwise multiply two field data objects, writes data into ``fld_1`` copy.""" - def get_field_key(dim: str, fld_data: typing.Union[td.FieldData, td.PermittivityData]) -> str: + def get_field_key(dim: str, fld_data: Union[td.FieldData, td.PermittivityData]) -> str: """Get the key corresponding to the scalar field along this dimension.""" return f"{fld_key}{dim}" if isinstance(fld_data, td.FieldData) else f"eps_{dim}{dim}" diff --git a/tidy3d/web/api/connect_util.py b/tidy3d/web/api/connect_util.py index 6148c67154..f14637b061 100644 --- a/tidy3d/web/api/connect_util.py +++ b/tidy3d/web/api/connect_util.py @@ -4,7 +4,7 @@ import time from functools import wraps -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any from requests import ReadTimeout from requests.exceptions import ConnectionError as ConnErr @@ -16,6 +16,9 @@ from tidy3d.web import common from tidy3d.web.common import REFRESH_TIME +if TYPE_CHECKING: + from typing import Callable, Optional + def wait_for_connection( decorated_fn: Optional[Callable[..., Any]] = None, diff --git 
a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index bbc7616760..0dd94f587e 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -10,27 +10,18 @@ import time import uuid from abc import ABC -from collections.abc import Iterator, Mapping +from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor -from os import PathLike from pathlib import Path -from typing import Any, Literal, Optional, Union - -import pydantic.v1 as pd -from pydantic.v1 import PrivateAttr -from rich.progress import ( - BarColumn, - Progress, - TaskID, - TaskProgressColumn, - TextColumn, - TimeElapsedColumn, -) +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +from pydantic import Field, PositiveInt, PrivateAttr, model_validator +from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn from tidy3d.components.base import Tidy3dBaseModel, cached_property from tidy3d.components.mode.mode_solver import ModeSolver -from tidy3d.components.types import annotate_type -from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType +from tidy3d.components.types.base import discriminated_union +from tidy3d.components.types.workflow import WorkflowType from tidy3d.exceptions import DataError from tidy3d.log import get_logging_console, log from tidy3d.web.api import webapi as web @@ -50,9 +41,17 @@ from tidy3d.web.cache import _store_mode_solver_in_cache from tidy3d.web.core.constants import TaskId, TaskName from tidy3d.web.core.task_core import Folder -from tidy3d.web.core.task_info import RunInfo, TaskInfo from tidy3d.web.core.types import PayType +if TYPE_CHECKING: + from collections.abc import Iterator + from os import PathLike + + from rich.progress import TaskID + + from tidy3d.components.types.workflow import WorkflowDataType + from tidy3d.web.core.task_info import RunInfo, TaskInfo + # Max # of workers for parallel upload / download: above 10, performance is same but with warnings DEFAULT_NUM_WORKERS = 10 DEFAULT_DATA_PATH = "simulation_data.hdf5" @@ -172,24 +171,25 @@ class Job(WebContainer): * `Inverse taper edge coupler <../../notebooks/EdgeCoupler.html>`_ """ - simulation: WorkflowType = pd.Field( - ..., + simulation: WorkflowType = Field( title="simulation", description="Simulation to run as a 'task'.", discriminator="type", ) - task_name: TaskName = pd.Field( + task_name: Optional[TaskName] = Field( None, title="Task Name", description="Unique name of the task. Will be auto-generated if not provided.", ) - folder_name: str = pd.Field( - "default", title="Folder Name", description="Name of folder to store task on web UI." + folder_name: str = Field( + "default", + title="Folder Name", + description="Name of folder to store task on web UI.", ) - callback_url: str = pd.Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. " @@ -197,28 +197,32 @@ class Job(WebContainer): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - solver_version: str = pd.Field( + solver_version: Optional[str] = Field( None, title="Solver Version", description="Custom solver version to use, " "otherwise uses default for the current front end version.", ) - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." 
+ verbose: bool = Field( + True, + title="Verbose", + description="Whether to print info messages and progressbars.", ) - simulation_type: BatchCategoryType = pd.Field( + simulation_type: BatchCategoryType = Field( "tidy3d", title="Simulation Type", description="Type of simulation, used internally only.", ) - parent_tasks: tuple[TaskId, ...] = pd.Field( - None, title="Parent Tasks", description="Tuple of parent task ids, used internally only." + parent_tasks: Optional[tuple[TaskId, ...]] = Field( + None, + title="Parent Tasks", + description="Tuple of parent task ids, used internally only.", ) - task_id_cached: TaskId = pd.Field( + task_id_cached: Optional[TaskId] = Field( None, title="Task ID (Cached)", description="Optional field to specify ``task_id``. Only used as a workaround internally " @@ -227,34 +231,36 @@ class Job(WebContainer): "fields that were not used to create the task will cause errors.", ) - reduce_simulation: Literal["auto", True, False] = pd.Field( + reduce_simulation: Literal["auto", True, False] = Field( "auto", title="Reduce Simulation", description="Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.", ) - pay_type: PayType = pd.Field( + pay_type: PayType = Field( PayType.AUTO, title="Payment Type", description="Specify the payment method.", ) - lazy: bool = pd.Field( + lazy: bool = Field( False, title="Lazy", description="Whether to load the actual data (lazy=False) or return a proxy that loads the data when accessed (lazy=True).", ) - _upload_fields = ( - "simulation", - "task_name", - "folder_name", - "callback_url", - "verbose", - "simulation_type", - "parent_tasks", - "solver_version", - "reduce_simulation", + _upload_fields: tuple[str, ...] = PrivateAttr( + ( + "simulation", + "task_name", + "folder_name", + "callback_url", + "verbose", + "simulation_type", + "parent_tasks", + "solver_version", + "reduce_simulation", + ) ) _stash_path: Optional[str] = PrivateAttr(default=None) @@ -483,7 +489,6 @@ def load(self, path: PathLike = DEFAULT_DATA_PATH) -> WorkflowDataType: _store_mode_solver_in_cache( self.task_id, self.simulation, - data, path, ) self.simulation._patch_data(data=data) @@ -545,16 +550,20 @@ def _check_path_dir(path: PathLike) -> None: if parent_dir != Path(".") and not parent_dir.exists(): parent_dir.mkdir(parents=True, exist_ok=True) - @pd.root_validator(pre=True) - def set_task_name_if_none(cls, values: dict[str, Any]) -> dict[str, Any]: + @model_validator(mode="before") + @classmethod + def set_task_name_if_none(cls, data: dict[str, Any]) -> dict[str, Any]: """ Auto-assign a task_name if user did not provide one. 
""" - if values.get("task_name") is None: - sim = values.get("simulation") + if not isinstance(data, dict): + return data + + if data.get("task_name") is None: + sim = data.get("simulation") stub = Tidy3dStub(simulation=sim) - values["task_name"] = stub.get_default_task_name() - return values + data["task_name"] = stub.get_default_task_name() + return data class BatchData(Tidy3dBaseModel, Mapping): @@ -582,32 +591,34 @@ class BatchData(Tidy3dBaseModel, Mapping): * `Performing parallel / batch processing of simulations <../../notebooks/ParameterScan.html>`_ """ - task_paths: dict[TaskName, str] = pd.Field( - ..., + task_paths: dict[TaskName, str] = Field( title="Data Paths", description="Mapping of task_name to path to corresponding data for each task in batch.", ) - task_ids: dict[TaskName, str] = pd.Field( - ..., title="Task IDs", description="Mapping of task_name to task_id for each task in batch." + task_ids: dict[TaskName, str] = Field( + title="Task IDs", + description="Mapping of task_name to task_id for each task in batch.", ) - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." + verbose: bool = Field( + True, + title="Verbose", + description="Whether to print info messages and progressbars.", ) - cached_tasks: Optional[dict[TaskName, bool]] = pd.Field( + cached_tasks: Optional[dict[TaskName, bool]] = Field( None, title="Cached Tasks", description="Whether the data of a task came from the cache.", ) - lazy: bool = pd.Field( + lazy: bool = Field( False, title="Lazy", description="Whether to load the actual data (lazy=False) or return a proxy that loads the data when accessed (lazy=True).", ) - is_downloaded: Optional[bool] = pd.Field( + is_downloaded: Optional[bool] = Field( False, title="Is Downloaded", description="Whether the simulation data was downloaded before.", @@ -624,7 +635,7 @@ def load_sim_data(self, task_name: str) -> WorkflowDataType: return web.load( task_id=None if from_cache else task_id, path=task_data_path, - verbose=self.verbose, + verbose=False, replace_existing=not (from_cache or self.is_downloaded), lazy=self.lazy, ) @@ -695,31 +706,33 @@ class Batch(WebContainer): """ simulations: Union[ - dict[TaskName, annotate_type(WorkflowType)], tuple[annotate_type(WorkflowType), ...] - ] = pd.Field( - ..., + dict[TaskName, discriminated_union(WorkflowType)], + tuple[discriminated_union(WorkflowType), ...], + ] = Field( title="Simulations", description="Mapping of task names to Simulations to run as a batch.", ) - folder_name: str = pd.Field( + folder_name: str = Field( "default", title="Folder Name", description="Name of folder to store member of each batch on web UI.", ) - verbose: bool = pd.Field( - True, title="Verbose", description="Whether to print info messages and progressbars." + verbose: bool = Field( + True, + title="Verbose", + description="Whether to print info messages and progressbars.", ) - solver_version: str = pd.Field( + solver_version: Optional[str] = Field( None, title="Solver Version", description="Custom solver version to use, " "otherwise uses default for the current front end version.", ) - callback_url: str = pd.Field( + callback_url: Optional[str] = Field( None, title="Callback URL", description="Http PUT url to receive simulation finish event. 
" @@ -727,19 +740,19 @@ class Batch(WebContainer): "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", ) - simulation_type: BatchCategoryType = pd.Field( + simulation_type: BatchCategoryType = Field( "tidy3d", title="Simulation Type", description="Type of each simulation in the batch, used internally only.", ) - parent_tasks: dict[str, tuple[TaskId, ...]] = pd.Field( + parent_tasks: Optional[dict[str, tuple[TaskId, ...]]] = Field( None, title="Parent Tasks", description="Collection of parent task ids for each job in batch, used internally only.", ) - num_workers: Optional[pd.PositiveInt] = pd.Field( + num_workers: Optional[PositiveInt] = Field( DEFAULT_NUM_WORKERS, title="Number of Workers", description="Number of workers for multi-threading upload and download of batch. " @@ -748,19 +761,19 @@ class Batch(WebContainer): "number of threads available on the system.", ) - reduce_simulation: Literal["auto", True, False] = pd.Field( + reduce_simulation: Literal["auto", True, False] = Field( "auto", title="Reduce Simulation", description="Whether to reduce structures in the simulation to the simulation domain only. Note: currently only implemented for the mode solver.", ) - pay_type: PayType = pd.Field( + pay_type: PayType = Field( PayType.AUTO, title="Payment Type", description="Specify the payment method.", ) - jobs_cached: dict[TaskName, Job] = pd.Field( + jobs_cached: Optional[dict[TaskName, Job]] = Field( None, title="Jobs (Cached)", description="Optional field to specify ``jobs``. Only used as a workaround internally " @@ -769,13 +782,13 @@ class Batch(WebContainer): "fields that were not used to create the task will cause errors.", ) - lazy: bool = pd.Field( + lazy: bool = Field( False, title="Lazy", description="Whether to load the actual data (lazy=False) or return a proxy that loads the data when accessed (lazy=True).", ) - _job_type = Job + _job_type: type = PrivateAttr(Job) def run( self, @@ -851,13 +864,13 @@ def jobs(self) -> dict[TaskName, Job]: # the type of job to upload (to generalize to subclasses) JobType = self._job_type - self_dict = self.dict() + self_dict = self.model_dump() jobs = {} for task_name, simulation in simulations.items(): job_kwargs = {} - for key in JobType._upload_fields: + for key in JobType._upload_fields.default: if key in self_dict: job_kwargs[key] = self_dict.get(key) @@ -945,7 +958,7 @@ def get_info(self) -> dict[TaskName, TaskInfo]: Returns ------- - Dict[str, :class:`TaskInfo`] + dict[str, :class:`TaskInfo`] Mapping of task name to data about task associated with each task. """ info_dict = {} @@ -986,7 +999,7 @@ def get_run_info(self) -> dict[TaskName, RunInfo]: Returns ------- - Dict[str: :class:`RunInfo`] + dict[str: :class:`RunInfo`] Maps task names to run info for each task in the :class:`Batch`. 
""" run_info_dict = {} @@ -1290,13 +1303,14 @@ def fn(job: Job = job, job_path: PathLike = job_path) -> None: pbar_message = f"Downloading data for {len(fns)} tasks" pbar = progress.add_task(pbar_message, total=len(fns)) completed = 0 - for _ in concurrent.futures.as_completed(futures): + for fut in concurrent.futures.as_completed(futures): + fut.result() completed += 1 progress.update(pbar, completed=completed) else: # Still ensure completion if verbose is off - for _ in concurrent.futures.as_completed(futures): - pass + for fut in concurrent.futures.as_completed(futures): + fut.result() def load( self, @@ -1358,7 +1372,7 @@ def load( job_data = data[task_name] if not loaded_from_cache[task_name]: _store_mode_solver_in_cache( - task_ids[task_name], job.simulation, job_data, task_paths[task_name] + task_ids[task_name], job.simulation, task_paths[task_name] ) job.simulation._patch_data(data=job_data) diff --git a/tidy3d/web/api/material_fitter.py b/tidy3d/web/api/material_fitter.py index f66dce6471..a070bcb835 100644 --- a/tidy3d/web/api/material_fitter.py +++ b/tidy3d/web/api/material_fitter.py @@ -50,16 +50,27 @@ class _FitterRequest(BaseModel): class MaterialFitterTask(Submittable): """Material Fitter Task.""" - id: str = Field(title="Task ID", description="Task ID") + id: str = Field( + title="Task ID", + description="Task ID", + ) dispersion_fitter: DispersionFitter = Field( - title="Dispersion Fitter", description="Dispersion Fitter data" + title="Dispersion Fitter", + description="Dispersion Fitter data", + ) + status: str = Field( + title="Task Status", + description="Task Status", ) - status: str = Field(title="Task Status", description="Task Status") file_name: str = Field( - ..., title="file name", description="fitter data file name", alias="fileName" + title="file name", + description="fitter data file name", + alias="fileName", ) resource_path: str = Field( - ..., title="resource path", description="resource path", alias="resourcePath" + title="resource path", + description="resource path", + alias="resourcePath", ) @classmethod diff --git a/tidy3d/web/api/material_library.py b/tidy3d/web/api/material_library.py new file mode 100644 index 0000000000..e51abe8ba4 --- /dev/null +++ b/tidy3d/web/api/material_library.py @@ -0,0 +1,64 @@ +"""Material Library API.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Optional + +from pydantic import Field, TypeAdapter, field_validator + +from tidy3d.components.medium import MediumType +from tidy3d.web.core.http_util import http +from tidy3d.web.core.types import Queryable + +if TYPE_CHECKING: + from tidy3d.web.core.http_util import JSONType + + +class MaterialLibrary(Queryable): + """Material Library Resource interface.""" + + id: str = Field( + title="Material Library ID", + description="Material Library ID", + ) + name: str = Field( + title="Material Library Name", + description="Material Library Name", + ) + medium: Optional[MediumType] = Field( + None, + title="medium", + description="medium", + alias="calcResult", + ) + medium_type: Optional[str] = Field( + None, + title="medium type", + description="medium type", + alias="mediumType", + ) + json_input: Optional[dict] = Field( + None, + title="json input", + description="original input", + alias="jsonInput", + ) + + @field_validator("medium", "json_input", mode="before") + @classmethod + def parse_result(cls, values: Any) -> JSONType: + """Automatically parsing medium and json_input from string to object.""" + return json.loads(values) + 
+    @classmethod
+    def list(cls) -> list[MaterialLibrary]:
+        """List all material libraries.
+
+        Returns
+        -------
+        tasks : list[:class:`.MaterialLibrary`]
+            List of material libraries.
+        """
+        resp = http.get("tidy3d/libraries")
+        return TypeAdapter(list[MaterialLibrary]).validate_python(resp) if resp else None
diff --git a/tidy3d/web/api/material_libray.py b/tidy3d/web/api/material_libray.py
deleted file mode 100644
index 2bc05cb419..0000000000
--- a/tidy3d/web/api/material_libray.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Material Library API."""
-
-from __future__ import annotations
-
-import builtins
-import json
-from typing import Any, Optional
-
-from pydantic.v1 import Field, parse_obj_as, validator
-
-from tidy3d.components.medium import MediumType
-from tidy3d.web.core.http_util import JSONType, http
-from tidy3d.web.core.types import Queryable
-
-
-class MaterialLibray(Queryable, smart_union=True):
-    """Material Library Resource interface."""
-
-    id: str = Field(title="Material Library ID", description="Material Library ID")
-    name: str = Field(title="Material Library Name", description="Material Library Name")
-    medium: Optional[MediumType] = Field(title="medium", description="medium", alias="calcResult")
-    medium_type: Optional[str] = Field(
-        title="medium type", description="medium type", alias="mediumType"
-    )
-    json_input: Optional[dict] = Field(
-        title="json input", description="original input", alias="jsonInput"
-    )
-
-    @validator("medium", "json_input", pre=True)
-    def parse_result(cls, values: dict[str, Any]) -> JSONType:
-        """Automatically parsing medium and json_input from string to object."""
-        return json.loads(values)
-
-    @classmethod
-    def list(cls) -> builtins.list[MaterialLibray]:
-        """List all material libraries.
-
-        Returns
-        -------
-        tasks : List[:class:`.MaterialLibray`]
-            List of material libraries/
-        """
-        resp = http.get("tidy3d/libraries")
-        return parse_obj_as(list[MaterialLibray], resp) if resp else None
diff --git a/tidy3d/web/api/mode.py b/tidy3d/web/api/mode.py
index 1927ac8b9f..8e769038b3 100644
--- a/tidy3d/web/api/mode.py
+++ b/tidy3d/web/api/mode.py
@@ -3,18 +3,15 @@
 from __future__ import annotations
 
 import os
-import pathlib
 import tempfile
 import time
 from datetime import datetime
-from os import PathLike
-from typing import Callable, Literal, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
-import pydantic.v1 as pydantic
-import requests
 from botocore.exceptions import ClientError
 from joblib import Parallel, delayed
-from rich.progress import Progress, TaskID
+from pydantic import Field
+from rich.progress import Progress
 
 from tidy3d.components.data.monitor_data import ModeSolverData
 from tidy3d.components.eme.simulation import EMESimulation
@@ -31,6 +28,14 @@
 from tidy3d.web.core.task_core import Folder
 from tidy3d.web.core.types import PayType, ResourceLifecycle, Submittable
 
+if TYPE_CHECKING:
+    import pathlib
+    from os import PathLike
+    from typing import Callable, Literal, Union
+
+    import requests
+    from rich.progress import TaskID
+
 SIMULATION_JSON = "simulation.json"
 SIM_FILE_HDF5_GZ = "simulation.hdf5.gz"
 MODESOLVER_API = "tidy3d/modesolver/py"
@@ -164,13 +169,13 @@ def run_batch(
     Parameters
     ----------
-    mode_solvers : List[ModeSolver]
+    mode_solvers : list[ModeSolver]
         List of mode solvers to be submitted to the server.
     task_name : str
         Base name for tasks. Each task in the batch will have a unique index appended to this base name.
     folder_name : str
         Name of the folder where tasks are stored on the server's web UI.
- results_files : List[str], optional + results_files : list[str], optional List of file paths where the results for each ModeSolver should be downloaded. If None, a default path based on the folder name and index is used. verbose : bool If True, displays a progress bar. If False, runs silently. @@ -188,7 +193,7 @@ def run_batch( Returns ------- - List[ModeSolverData] + list[ModeSolverData] A list of ModeSolverData objects containing the results from each simulation in the batch. ``None`` is placed in the list for simulations that fail after all retries. """ console = get_logging_console() @@ -233,6 +238,7 @@ def handle_mode_solver( if verbose: progress.update(pbar, advance=1) return None + return None if verbose: console.log(f"[cyan]Running a batch of [deep_pink4]{num_mode_solvers} mode solvers.\n") @@ -256,45 +262,45 @@ def handle_mode_solver( return results -class ModeSolverTask(ResourceLifecycle, Submittable, extra=pydantic.Extra.allow): +class ModeSolverTask(ResourceLifecycle, Submittable, extra="allow"): """Interface for managing the running of a :class:`.ModeSolver` task on server.""" - task_id: str = pydantic.Field( + task_id: Optional[str] = Field( None, title="task_id", description="Task ID number, set when the task is created, leave as None.", alias="refId", ) - solver_id: str = pydantic.Field( + solver_id: Optional[str] = Field( None, title="solver", description="Solver ID number, set when the task is created, leave as None.", alias="id", ) - real_flex_unit: float = pydantic.Field( + real_flex_unit: Optional[float] = Field( None, title="real FlexCredits", description="Billed FlexCredits.", alias="charge" ) - created_at: Optional[datetime] = pydantic.Field( + created_at: Optional[datetime] = Field( title="created_at", description="Time at which this task was created.", alias="createdAt" ) - status: str = pydantic.Field( + status: Optional[str] = Field( None, title="status", description="Mode solver task status.", ) - file_type: str = pydantic.Field( + file_type: Optional[str] = Field( None, title="file_type", description="File type used to upload the mode solver.", alias="fileType", ) - mode_solver: ModeSolver = pydantic.Field( + mode_solver: Optional[ModeSolver] = Field( None, title="mode_solver", description="Mode solver being run by this task.", @@ -574,7 +580,7 @@ def get_modesolver( progress_callback=progress_callback, ) mode_solver_dict["simulation"] = Simulation.from_json(sim_file) - mode_solver = ModeSolver.parse_obj(mode_solver_dict) + mode_solver = ModeSolver.model_validate(mode_solver_dict) # Store requested mode solver file mode_solver.to_file(to_file) diff --git a/tidy3d/web/api/run.py b/tidy3d/web/api/run.py index 892f8efc16..7c8de875f6 100644 --- a/tidy3d/web/api/run.py +++ b/tidy3d/web/api/run.py @@ -1,8 +1,8 @@ from __future__ import annotations import typing -from os import PathLike from pathlib import Path +from typing import TYPE_CHECKING from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType from tidy3d.config import config @@ -12,6 +12,9 @@ from tidy3d.web.api.container import DEFAULT_DATA_DIR, DEFAULT_DATA_PATH from tidy3d.web.core.types import PayType +if TYPE_CHECKING: + from os import PathLike + RunInput: typing.TypeAlias = typing.Union[ WorkflowType, list["RunInput"], diff --git a/tidy3d/web/api/states.py b/tidy3d/web/api/states.py index e6bf17a8ef..a0e2004141 100644 --- a/tidy3d/web/api/states.py +++ b/tidy3d/web/api/states.py @@ -10,7 +10,7 @@ "success", ) -MAX_STEPS = len(PROGRESSION_ORDER) - 1 +MAX_STEPS = 4 COMPLETED_PERCENT = 
100 PRE_ERROR_STATES = { @@ -46,16 +46,21 @@ "run_success", } -COMPLETED_STATES = { +SUCCESS_STATES = { "visualize", "success", "completed", "processed", "postprocess_success", +} + +DIVERGED_STATES = { "diverge", "diverged", } +COMPLETED_STATES = DIVERGED_STATES | SUCCESS_STATES + END_STATES = ERROR_STATES | COMPLETED_STATES POST_VALIDATE_STATES = {"validate_success", "validate_warn", "warning"} @@ -79,16 +84,22 @@ STATE_PROGRESS_PERCENTAGE = dict.fromkeys(ALL_STATES, 0) STATE_PROGRESS_PERCENTAGE.update(dict.fromkeys(COMPLETED_STATES, COMPLETED_PERCENT)) STATE_PROGRESS_PERCENTAGE.update( - {state: round((1 / MAX_STEPS) * COMPLETED_PERCENT) for state in QUEUED_STATES} + {state: round((0 / MAX_STEPS) * COMPLETED_PERCENT) for state in QUEUED_STATES} +) +STATE_PROGRESS_PERCENTAGE.update( + {state: round((0 / MAX_STEPS) * COMPLETED_PERCENT) for state in DIVERGED_STATES} +) +STATE_PROGRESS_PERCENTAGE.update( + {state: round((1 / MAX_STEPS) * COMPLETED_PERCENT) for state in PREPROCESS_STATES} ) STATE_PROGRESS_PERCENTAGE.update( - {state: round((2 / MAX_STEPS) * COMPLETED_PERCENT) for state in PREPROCESS_STATES} + {state: round((2 / MAX_STEPS) * COMPLETED_PERCENT) for state in RUNNING_STATES} ) STATE_PROGRESS_PERCENTAGE.update( - {state: round((3 / MAX_STEPS) * COMPLETED_PERCENT) for state in RUNNING_STATES} + {state: round((3 / MAX_STEPS) * COMPLETED_PERCENT) for state in POSTPROCESS_STATES} ) STATE_PROGRESS_PERCENTAGE.update( - {state: round((4 / MAX_STEPS) * COMPLETED_PERCENT) for state in POSTPROCESS_STATES} + {state: round((4 / MAX_STEPS) * COMPLETED_PERCENT) for state in SUCCESS_STATES} ) @@ -110,14 +121,14 @@ def status_to_stage(status: str) -> tuple[str, int]: if s in DRAFT_STATES: return ("draft", 0) if s in QUEUED_STATES: - return ("queued", 1) + return ("queued", 0) if s in PREPROCESS_STATES: - return ("preprocess", 2) + return ("preprocess", 1) if s in RUNNING_STATES: - return ("running", 3) + return ("running", 2) if s in POSTPROCESS_STATES: - return ("postprocess", 4) + return ("postprocess", 3) if s in COMPLETED_STATES: - return ("success", 5) + return ("success", 4) # Unknown states map to earliest stage to avoid showing 100% prematurely return (s or "unknown", 0) diff --git a/tidy3d/web/api/tidy3d_stub.py b/tidy3d/web/api/tidy3d_stub.py index eb49f90461..f4b9de526b 100644 --- a/tidy3d/web/api/tidy3d_stub.py +++ b/tidy3d/web/api/tidy3d_stub.py @@ -3,11 +3,9 @@ from __future__ import annotations from datetime import datetime -from os import PathLike -from typing import Callable, Optional +from typing import TYPE_CHECKING -import pydantic.v1 as pd -from pydantic.v1 import BaseModel +from pydantic import BaseModel, Field from tidy3d import log from tidy3d.components.base import Tidy3dBaseModel @@ -38,6 +36,10 @@ from tidy3d.web.core.stub import TaskStub, TaskStubData from tidy3d.web.core.types import TaskType +if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Optional + TYPE_MAP: dict[type, TaskType] = { Simulation: TaskType.FDTD, ModeSolver: TaskType.MODE_SOLVER, @@ -59,7 +61,7 @@ def task_type_name_of(simulation: WorkflowType) -> str: class Tidy3dStub(BaseModel, TaskStub): - simulation: WorkflowType = pd.Field(discriminator="type") + simulation: WorkflowType = Field(discriminator="type") @classmethod def from_file(cls, file_path: PathLike) -> WorkflowType: diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index c51ccc2e6a..98a92f0868 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -5,9 +5,8 @@ import json import 
tempfile import time -from os import PathLike from pathlib import Path -from typing import Callable, Literal, Optional, Union +from typing import TYPE_CHECKING from requests import HTTPError from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeElapsedColumn @@ -15,7 +14,6 @@ from tidy3d.components.medium import AbstractCustomMedium from tidy3d.components.mode.mode_solver import ModeSolver from tidy3d.components.mode.simulation import ModeSimulation -from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType from tidy3d.config import config from tidy3d.exceptions import WebError from tidy3d.log import get_logging_console, log @@ -27,7 +25,7 @@ STATE_PROGRESS_PERCENTAGE, status_to_stage, ) -from tidy3d.web.cache import CacheEntry, _store_mode_solver_in_cache, resolve_local_cache +from tidy3d.web.cache import _store_mode_solver_in_cache, resolve_local_cache from tidy3d.web.core.account import Account from tidy3d.web.core.constants import ( CM_DATA_HDF5_GZ, @@ -37,22 +35,23 @@ SIM_FILE_HDF5, SIM_FILE_HDF5_GZ, SIMULATION_DATA_HDF5_GZ, - TaskId, -) -from tidy3d.web.core.task_core import ( - BatchDetail, - BatchTask, - Folder, - SimulationTask, - TaskFactory, - WebTask, ) +from tidy3d.web.core.task_core import BatchTask, Folder, SimulationTask, TaskFactory, WebTask from tidy3d.web.core.task_info import ChargeType, TaskInfo from tidy3d.web.core.types import PayType, TaskType from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection from .tidy3d_stub import Tidy3dStub, Tidy3dStubData +if TYPE_CHECKING: + from os import PathLike + from typing import Callable, Literal, Optional, Union + + from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType + from tidy3d.web.cache import CacheEntry + from tidy3d.web.core.constants import TaskId + from tidy3d.web.core.task_core import BatchDetail + # time between checking run status RUN_REFRESH_TIME = 1.0 @@ -420,7 +419,7 @@ def run( if isinstance(simulation, ModeSolver): if task_id is not None: - _store_mode_solver_in_cache(task_id, simulation, data, path) + _store_mode_solver_in_cache(task_id, simulation, path) simulation._patch_data(data=data) return data @@ -479,7 +478,7 @@ def upload( Optional callback function called when uploading file with ``bytes_in_chunk`` as argument. simulation_type : str = "tidy3d" Type of simulation being uploaded. - parent_tasks : List[str] + parent_tasks : list[str] List of related task ids. source_required: bool = True If ``True``, simulations without sources will raise an error before being uploaded. @@ -958,7 +957,7 @@ def abort(task_id: TaskId) -> Optional[TaskInfo]: f"Task is aborting. View task using web UI at [link={url}]'{url}'[/link] to check the result." 
) return TaskInfo( - **{"taskId": task_id, "taskType": getattr(task, "task_type", None), **task.dict()} + **{"taskId": task_id, "taskType": getattr(task, "task_type", None), **task.model_dump()} ) @@ -1183,8 +1182,9 @@ def load( except Exception as e: log.info(f"Failed to load simulation for storing results: {e}.") return stub_data + else: + simulation = stub_data.simulation simulation_cache.store_result( - stub_data=stub_data, task_id=task_id, path=path, workflow_type=workflow_type, @@ -1221,6 +1221,7 @@ def _monitor_modeler_batch( TextColumn("[progress.description]{task.description}"), BarColumn(bar_width=25), TaskProgressColumn(), + TextColumn("[progress.description]{task.fields[status]}"), TimeElapsedColumn(), ) # Make the header @@ -1230,19 +1231,21 @@ def _monitor_modeler_batch( console.log(header) with Progress(*progress_columns, console=console, transient=False) as progress: # Phase: Run (aggregate + per-task) - p_run = progress.add_task("Run Total", total=1.0) - task_bars: dict[str, int] = {} stage = status_to_stage(status)[0] + p_run = progress.add_task("Run Total", total=1.0, status=f" {stage} ") + task_bars: dict[str, int] = {} prev_stage = status_to_stage(status)[0] console.log(f"Batch status = {status}") # Note: get_status errors if an erroring status occurred - while stage not in END_STATES: + end_monitor = False + while not end_monitor: total = len(detail.tasks) r = detail.runSuccess or 0 if stage != prev_stage: prev_stage = stage console.log(f"Batch status = {stage}") + progress.update(p_run, status=f" {stage} ") # Create per-task bars as soon as tasks appear if total and total <= max_detail_tasks and detail.tasks: @@ -1255,10 +1258,12 @@ def _monitor_modeler_batch( f" {name}", total=1.0, completed=STATE_PROGRESS_PERCENTAGE[tstatus] / 100, + status=f" {tstatus} ", ) task_bars[name] = pbar - # Aggregate run progress: average stage fraction across tasks + # Aggregate run progress: average stage fraction across tasks (80% weight) + # Final 20% achieved only when batch status is completed if detail.tasks: acc = 0.0 n_members = 0 @@ -1267,9 +1272,17 @@ def _monitor_modeler_batch( tstatus = (t.status or "draft").lower() _, idx = status_to_stage(tstatus) acc += max(0.0, min(1.0, idx / MAX_STEPS)) - run_frac = (acc / float(n_members)) if n_members else 0.0 + task_avg = (acc / float(n_members)) if n_members else 0.0 + run_frac = task_avg * 0.8 else: - run_frac = (r / total) if total else 0.0 + run_frac = (r / total) * 0.8 if total else 0.0 + + # Final 20% only when batch is completed + if status in END_STATES: + # Makes sure last state is logged + end_monitor = True + run_frac = 1.0 + progress.update(p_run, completed=run_frac) # Update per-task bars @@ -1281,11 +1294,11 @@ def _monitor_modeler_batch( continue tstatus = (t.status or "draft").lower() _, idx = status_to_stage(tstatus) - desc = f" {tname} [{tstatus or 'draft'}]" progress.update( pbar, completed=STATE_PROGRESS_PERCENTAGE[tstatus] / 100, - description=desc, + description=f" {tname}", + status=f" {tstatus} ", refresh=False, ) @@ -1321,7 +1334,7 @@ def delete(task_id: TaskId, versions: bool = False) -> TaskInfo: raise ValueError("Task id not found.") task = TaskFactory.get(task_id, verbose=False) task.delete(versions) - return TaskInfo(**{"taskId": task.task_id, **task.dict()}) + return TaskInfo(**{"taskId": task.task_id, **task.model_dump()}) @wait_for_connection @@ -1377,7 +1390,7 @@ def get_tasks( Returns ------- - List[Dict] + list[dict] List of dictionaries storing the information for each of the tasks last 
``num_tasks`` tasks. """ folder = Folder.get(folder, create=True) @@ -1390,7 +1403,7 @@ def get_tasks( tasks = sorted(tasks, key=lambda t: t.created_at) if num_tasks is not None: tasks = tasks[:num_tasks] - return [task.dict() for task in tasks] + return [task.model_dump() for task in tasks] @wait_for_connection diff --git a/tidy3d/web/cache.py b/tidy3d/web/cache.py index 5c96a7a860..1e0047934b 100644 --- a/tidy3d/web/cache.py +++ b/tidy3d/web/cache.py @@ -1,874 +1,56 @@ -"""Local simulation cache manager.""" +"""Compatibility shim for :mod:`tidy3d._common.web.cache`.""" -from __future__ import annotations - -import hashlib -import json -import os -import shutil -import tempfile -import threading -from collections.abc import Iterator -from contextlib import contextmanager -from dataclasses import dataclass -from datetime import datetime, timezone -from enum import Enum -from pathlib import Path -from typing import Any, Optional +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility -from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt +# marked as migrated to _common +from __future__ import annotations -from tidy3d import config -from tidy3d.components.mode.mode_solver import ModeSolver -from tidy3d.components.types.workflow import WorkflowDataType, WorkflowType -from tidy3d.log import log +from typing import TYPE_CHECKING + +from tidy3d._common.web.cache import ( + _CACHE, + CACHE_ARTIFACT_NAME, + CACHE_METADATA_NAME, + CACHE_STATS_NAME, + TMP_BATCH_PREFIX, + TMP_PREFIX, + CacheEntry, + CacheEntryMetadata, + CacheStats, + LocalCache, + _canonicalize, + _copy_and_hash, + _Hasher, + _now, + _read_metadata, + _timestamp_suffix, + _write_metadata, + build_cache_key, + build_entry_metadata, + clear, + get_cache_entry_dir, + register_get_workflow_type, + resolve_local_cache, +) +from tidy3d._common.web.core.types import TaskType from tidy3d.web.api.tidy3d_stub import Tidy3dStub -from tidy3d.web.core.constants import TaskId -from tidy3d.web.core.http_util import get_version as _get_protocol_version -from tidy3d.web.core.types import TaskType - -CACHE_ARTIFACT_NAME = "simulation_data.hdf5" -CACHE_METADATA_NAME = "metadata.json" -CACHE_STATS_NAME = "stats.json" - -TMP_PREFIX = "tidy3d-cache-" -TMP_BATCH_PREFIX = "tmp_batch" - -_CACHE: Optional[LocalCache] = None - - -def get_cache_entry_dir(root: os.PathLike, key: str) -> Path: - """ - Returns the cache directory for a given key. - A three-character prefix subdirectory is used to avoid hitting filesystem limits on the number of entries per folder. 
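That layout is the reason for the ``key[:3]`` prefix in the helper now re-exported from ``tidy3d._common.web.cache``; a small sketch of the scheme (paths invented):

from pathlib import Path


def cache_entry_dir(root: Path, key: str) -> Path:
    # shard on the first three characters of the key so that no single
    # directory accumulates an unbounded number of entries
    return root / key[:3] / key


# e.g. key "a1b2c3..." lands in <root>/a1b/a1b2c3...
print(cache_entry_dir(Path("/tmp/tidy3d-cache"), "a1b2c3deadbeef"))

Since the cache key is a SHA-256 hex digest, three hex characters spread entries across up to 4096 subdirectories.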
- """ - return Path(root) / key[:3] / key - - -class CacheStats(BaseModel): - """Lightweight summary of cache usage persisted in ``stats.json``.""" - - last_used: dict[str, str] = Field( - default_factory=dict, - description="Mapping from cache entry key to the most recent ISO-8601 access timestamp.", - ) - total_size: NonNegativeInt = Field( - default=0, - description="Aggregate size in bytes across cached artifacts captured in the stats file.", - ) - updated_at: Optional[datetime] = Field( - default=None, - description="UTC timestamp indicating when the statistics were last refreshed.", - ) - - model_config = ConfigDict(extra="allow", validate_assignment=True) - - @property - def total_entries(self) -> int: - return len(self.last_used) - - -class CacheEntryMetadata(BaseModel): - """Schema for cache entry metadata persisted on disk.""" - - cache_key: str - checksum: str - created_at: datetime - last_used: datetime - file_size: int = Field(ge=0) - simulation_hash: str - workflow_type: str - versions: Any - task_id: str - path: str - - model_config = ConfigDict(extra="allow", validate_assignment=True) - - def bump_last_used(self) -> None: - self.last_used = datetime.now(timezone.utc) - - def as_dict(self) -> dict[str, Any]: - return self.model_dump(mode="json") - - def get(self, key: str, default: Any = None) -> Any: - return self.as_dict().get(key, default) - - def __getitem__(self, key: str) -> Any: - data = self.as_dict() - if key not in data: - raise KeyError(key) - return data[key] - - -@dataclass -class CacheEntry: - """Internal representation of a cache entry.""" - - key: str - root: Path - metadata: CacheEntryMetadata - - @property - def path(self) -> Path: - return get_cache_entry_dir(self.root, self.key) - - @property - def artifact_path(self) -> Path: - return self.path / CACHE_ARTIFACT_NAME - - @property - def metadata_path(self) -> Path: - return self.path / CACHE_METADATA_NAME - - def exists(self) -> bool: - return self.path.exists() and self.artifact_path.exists() and self.metadata_path.exists() - - def verify(self) -> bool: - if not self.exists(): - return False - checksum = self.metadata.checksum - if not checksum: - return False - try: - actual_checksum, file_size = _copy_and_hash(self.artifact_path, None) - except FileNotFoundError: - return False - if checksum != actual_checksum: - log.warning( - "Simulation cache checksum mismatch for key '%s'. 
Removing stale entry.", self.key - ) - return False - if self.metadata.file_size != file_size: - self.metadata.file_size = file_size - _write_metadata(self.metadata_path, self.metadata) - return True - - def materialize(self, target: Path) -> Path: - """Copy cached artifact to ``target`` and return the resulting path.""" - target = Path(target) - target.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(self.artifact_path, target) - return target - - -class LocalCache: - """Manages storing and retrieving cached simulation artifacts.""" - - def __init__(self, directory: os.PathLike, max_size_gb: float, max_entries: int) -> None: - self.max_size_gb = max_size_gb - self.max_entries = max_entries - self._root = Path(directory) - self._lock = threading.RLock() - self._syncing_stats = False - self._sync_pending = False - - @property - def _stats_path(self) -> Path: - return self._root / CACHE_STATS_NAME - - def _schedule_sync(self) -> None: - self._sync_pending = True - - def _run_pending_sync(self) -> None: - if self._sync_pending and not self._syncing_stats: - self._sync_pending = False - self.sync_stats() - - @contextmanager - def _with_lock(self) -> Iterator[None]: - self._run_pending_sync() - with self._lock: - yield - self._run_pending_sync() - - def _write_stats(self, stats: CacheStats) -> CacheStats: - updated = stats.model_copy(update={"updated_at": datetime.now(timezone.utc)}) - payload = updated.model_dump(mode="json") - payload["total_entries"] = updated.total_entries - self._stats_path.parent.mkdir(parents=True, exist_ok=True) - _write_metadata(self._stats_path, payload) - self._sync_pending = False - return updated - - def _load_stats(self, *, rebuild: bool = False) -> CacheStats: - path = self._stats_path - if not path.exists(): - if not self._syncing_stats: - self._schedule_sync() - return CacheStats() - try: - data = json.loads(path.read_text(encoding="utf-8")) - if "last_used" not in data and "entries" in data: - data["last_used"] = data.pop("entries") - stats = CacheStats.model_validate(data) - except Exception: - if rebuild and not self._syncing_stats: - self._schedule_sync() - return CacheStats() - if stats.total_size < 0: - self._schedule_sync() - return CacheStats() - return stats - - def _record_store_stats( - self, - key: str, - *, - last_used: str, - file_size: int, - previous_size: int, - ) -> None: - stats = self._load_stats() - entries = dict(stats.last_used) - entries[key] = last_used - total_size = stats.total_size - previous_size + file_size - if total_size < 0: - total_size = 0 - self._schedule_sync() - updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) - self._write_stats(updated) - - def _record_touch_stats( - self, key: str, last_used: str, *, file_size: Optional[int] = None - ) -> None: - stats = self._load_stats() - entries = dict(stats.last_used) - existed = key in entries - total_size = stats.total_size - if not existed and file_size is not None: - total_size += file_size - if total_size < 0: - total_size = 0 - self._schedule_sync() - entries[key] = last_used - updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) - self._write_stats(updated) - - def _record_remove_stats(self, key: str, file_size: int) -> None: - stats = self._load_stats() - entries = dict(stats.last_used) - entries.pop(key, None) - total_size = stats.total_size - file_size - if total_size < 0: - total_size = 0 - self._schedule_sync() - updated = stats.model_copy(update={"last_used": entries, "total_size": total_size}) - 
self._write_stats(updated) - - def _enforce_limits_post_sync(self, entries: list[CacheEntry]) -> None: - if not entries: - return - - entries_map = {entry.key: entry.metadata.last_used.isoformat() for entry in entries} - - if self.max_entries > 0 and len(entries) > self.max_entries: - excess = len(entries) - self.max_entries - self._evict(entries_map, remove_count=excess, exclude_keys=set()) - - max_size_bytes = int(self.max_size_gb * (1024**3)) - if max_size_bytes > 0: - total_size = sum(entry.metadata.file_size for entry in entries) - if total_size > max_size_bytes: - bytes_to_free = total_size - max_size_bytes - self._evict_by_size(entries_map, bytes_to_free, exclude_keys=set()) - - def sync_stats(self) -> CacheStats: - with self._lock: - self._syncing_stats = True - log.debug("Syncing stats.json of local cache") - try: - entries: list[CacheEntry] = [] - last_used_map: dict[str, str] = {} - total_size = 0 - for entry in self._iter_entries(): - entries.append(entry) - total_size += entry.metadata.file_size - last_used_map[entry.key] = entry.metadata.last_used.isoformat() - stats = CacheStats(last_used=last_used_map, total_size=total_size) - written = self._write_stats(stats) - self._enforce_limits_post_sync(entries) - return written - finally: - self._syncing_stats = False - - @property - def root(self) -> Path: - return self._root - - def list(self) -> list[dict[str, Any]]: - """Return metadata for all cache entries.""" - with self._with_lock(): - entries = [entry.metadata.model_dump(mode="json") for entry in self._iter_entries()] - return entries - - def clear(self, hard: bool = False) -> None: - """Remove all cache contents. If set to hard, root directory is removed.""" - with self._with_lock(): - if self._root.exists(): - try: - shutil.rmtree(self._root) - if not hard: - self._root.mkdir(parents=True, exist_ok=True) - except (FileNotFoundError, OSError): - pass - if not hard: - self._write_stats(CacheStats()) - - def _fetch(self, key: str) -> Optional[CacheEntry]: - """Retrieve an entry by key, verifying checksum.""" - with self._with_lock(): - entry = self._load_entry(key) - if not entry or not entry.exists(): - return None - if not entry.verify(): - self._remove_entry(entry) - return None - self._touch(entry) - return entry - - def __len__(self) -> int: - """Return number of valid cache entries.""" - with self._with_lock(): - count = self._load_stats().total_entries - return count - - def _store( - self, key: str, source_path: Path, metadata: CacheEntryMetadata - ) -> Optional[CacheEntry]: - """Store a new cache entry from ``source_path``. - - Parameters - ---------- - key : str - Cache key computed from simulation hash and runtime context. - source_path : Path - Location of the artifact to cache. - metadata : CacheEntryMetadata - Metadata describing the cache entry to be persisted. - - Returns - ------- - CacheEntry - Representation of the stored cache entry. 
- """ - source_path = Path(source_path) - if not source_path.exists(): - raise FileNotFoundError(f"Cannot cache missing artifact: {source_path}") - os.makedirs(self._root, exist_ok=True) - tmp_dir = Path(tempfile.mkdtemp(prefix=TMP_PREFIX, dir=self._root)) - tmp_artifact = tmp_dir / CACHE_ARTIFACT_NAME - tmp_meta = tmp_dir / CACHE_METADATA_NAME - os.makedirs(tmp_dir, exist_ok=True) - - checksum, file_size = _copy_and_hash(source_path, tmp_artifact) - metadata.cache_key = key - metadata.created_at = datetime.now(timezone.utc) - metadata.last_used = metadata.created_at - metadata.checksum = checksum - metadata.file_size = file_size - - _write_metadata(tmp_meta, metadata) - entry: Optional[CacheEntry] = None - try: - with self._with_lock(): - self._root.mkdir(parents=True, exist_ok=True) - existing_entry = self._load_entry(key) - previous_size = ( - existing_entry.metadata.file_size if existing_entry is not None else 0 - ) - self._ensure_limits( - file_size, - incoming_key=key, - replacing_size=previous_size, - ) - final_dir = get_cache_entry_dir(self._root, key) - final_dir.parent.mkdir(parents=True, exist_ok=True) - if final_dir.exists(): - shutil.rmtree(final_dir) - os.replace(tmp_dir, final_dir) - entry = CacheEntry(key=key, root=self._root, metadata=metadata) - - self._record_store_stats( - key, - last_used=metadata.last_used.isoformat(), - file_size=file_size, - previous_size=previous_size, - ) - log.debug("Stored simulation cache entry '%s' (%d bytes).", key, file_size) - finally: - try: - if tmp_dir.exists(): - shutil.rmtree(tmp_dir, ignore_errors=True) - except FileNotFoundError: - pass - return entry - - def invalidate(self, key: str) -> None: - with self._with_lock(): - entry = self._load_entry(key) - if entry: - self._remove_entry(entry) - - def _ensure_limits( - self, - incoming_size: int, - *, - incoming_key: Optional[str] = None, - replacing_size: int = 0, - ) -> None: - max_entries = self.max_entries - max_size_bytes = int(self.max_size_gb * (1024**3)) - - try: - incoming_size_int = int(incoming_size) - except (TypeError, ValueError): - incoming_size_int = 0 - if incoming_size_int < 0: - incoming_size_int = 0 - - stats = self._load_stats() - entries_info = dict(stats.last_used) - existing_keys = set(entries_info) - projected_entries = stats.total_entries - if not incoming_key or incoming_key not in existing_keys: - projected_entries += 1 - - if projected_entries > max_entries > 0: - excess = projected_entries - max_entries - exclude = {incoming_key} if incoming_key else set() - self._evict(entries_info, remove_count=excess, exclude_keys=exclude) - stats = self._load_stats() - entries_info = dict(stats.last_used) - existing_keys = set(entries_info) - - if max_size_bytes == 0: # no limit - return - existing_size = stats.total_size - try: - replacing_size_int = int(replacing_size) - except (TypeError, ValueError): - replacing_size_int = 0 - if incoming_key and incoming_key in existing_keys: - projected_size = existing_size - replacing_size_int + incoming_size_int - else: - projected_size = existing_size + incoming_size_int +if TYPE_CHECKING: + import os - if max_size_bytes > 0 and projected_size > max_size_bytes: - bytes_to_free = projected_size - max_size_bytes - exclude = {incoming_key} if incoming_key else set() - self._evict_by_size(entries_info, bytes_to_free, exclude_keys=exclude) + from tidy3d.components.mode.mode_solver import ModeSolver + from tidy3d.components.types.workflow import WorkflowType - def _evict(self, entries: dict[str, str], *, remove_count: int, 
exclude_keys: set[str]) -> None: - if remove_count <= 0: - return - candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys] - if not candidates: - return - candidates.sort(key=lambda item: item[1] or "") - for key, _ in candidates[:remove_count]: - self._remove_entry_by_key(key) - def _evict_by_size( - self, entries: dict[str, str], bytes_to_free: int, *, exclude_keys: set[str] - ) -> None: - if bytes_to_free <= 0: - return - candidates = [(key, entries.get(key, "")) for key in entries if key not in exclude_keys] - if not candidates: - return - candidates.sort(key=lambda item: item[1] or "") - reclaimed = 0 - for key, _ in candidates: - if reclaimed >= bytes_to_free: - break - entry = self._load_entry(key) - if entry is None: - log.debug("Could not find entry for eviction.") - self._schedule_sync() - break - size = entry.metadata.file_size - self._remove_entry(entry) - reclaimed += size - log.info(f"Simulation cache evicted entry '{key}' to reclaim {size} bytes.") +def get_workflow_type(simulation: WorkflowType) -> str: + """Resolve workflow type name for cache logging.""" + return Tidy3dStub(simulation=simulation).get_type() - def _iter_entries(self) -> Iterator[CacheEntry]: - """Iterate lazily over all cache entries, including those in prefix subdirectories.""" - if not self._root.exists(): - return - for prefix_dir in self._root.iterdir(): - if not prefix_dir.is_dir() or prefix_dir.name.startswith( - (TMP_PREFIX, TMP_BATCH_PREFIX) - ): - continue +register_get_workflow_type(get_workflow_type) - # if cache is directly flat (no prefix directories), include that level too - subdirs = [prefix_dir] - if any((prefix_dir / name).is_dir() for name in prefix_dir.iterdir()): - subdirs = prefix_dir.iterdir() - for child in subdirs: - if not child.is_dir(): - continue - if child.name.startswith((TMP_PREFIX, TMP_BATCH_PREFIX)): - continue - - meta_path = child / CACHE_METADATA_NAME - if not meta_path.exists(): - continue - - try: - metadata = _read_metadata(meta_path, child / CACHE_ARTIFACT_NAME) - except Exception: - log.debug( - "Failed to parse metadata for '%s'; scheduling stats sync.", child.name - ) - self._schedule_sync() - continue - - yield CacheEntry(key=child.name, root=self._root, metadata=metadata) - - def _load_entry(self, key: str) -> Optional[CacheEntry]: - entry = CacheEntry(key=key, root=self._root, metadata={}) - if not entry.metadata_path.exists() or not entry.artifact_path.exists(): - return None - try: - metadata = _read_metadata(entry.metadata_path, entry.artifact_path) - except Exception: - return None - return CacheEntry(key=key, root=self._root, metadata=metadata) - - def _touch(self, entry: CacheEntry) -> None: - entry.metadata.bump_last_used() - _write_metadata(entry.metadata_path, entry.metadata) - self._record_touch_stats( - entry.key, - entry.metadata.last_used.isoformat(), - file_size=entry.metadata.file_size, - ) - - def _remove_entry_by_key(self, key: str) -> None: - entry = self._load_entry(key) - if entry is None: - path = get_cache_entry_dir(self._root, key) - if path.exists(): - shutil.rmtree(path, ignore_errors=True) - else: - log.debug("Could not find entry for key '%s' to delete.", key) - self._record_remove_stats(key, 0) - return - self._remove_entry(entry) - - def _remove_entry(self, entry: CacheEntry) -> None: - file_size = entry.metadata.file_size - if entry.path.exists(): - shutil.rmtree(entry.path, ignore_errors=True) - self._record_remove_stats(entry.key, file_size) - - def try_fetch( - self, - simulation: WorkflowType, 
- verbose: bool = False, - ) -> Optional[CacheEntry]: - """ - Attempt to resolve and fetch a cached result entry for the given simulation context. - On miss or any cache error, returns None (the caller should proceed with upload/run). - """ - try: - simulation_hash = simulation._hash_self() - workflow_type = Tidy3dStub(simulation=simulation).get_type() - - versions = _get_protocol_version() - - cache_key = build_cache_key( - simulation_hash=simulation_hash, - version=versions, - ) - - entry = self._fetch(cache_key) - if not entry: - return None - - if verbose: - log.info( - f"Simulation cache hit for workflow '{workflow_type}'; using local results." - ) - - return entry - except Exception as e: - log.error("Failed to fetch cache results: " + str(e)) - - def store_result( - self, - stub_data: WorkflowDataType, - task_id: TaskId, - path: str, - workflow_type: str, - simulation: Optional[WorkflowType] = None, - ) -> bool: - """ - Stores completed workflow results in the local cache using a canonical cache key. - - Parameters - ---------- - stub_data : :class:`.WorkflowDataType` - Object containing the workflow results, including references to the originating simulation. - task_id : str - Unique identifier of the finished workflow task. - path : str - Path to the results file on disk. - workflow_type : str - Type of workflow associated with the results (e.g., ``"SIMULATION"`` or ``"MODE_SOLVER"``). - simulation : Optional[:class:`.WorkflowDataType`] - Simulation object to use when computing the cache key. If not provided, - it will be inferred from ``stub_data.simulation`` when possible. - - Returns - ------- - bool - ``True`` if the result was successfully stored in the local cache, ``False`` otherwise. - - Notes - ----- - The cache entry is keyed by the simulation hash, workflow type, environment, and protocol version. - This enables automatic reuse of identical simulation results across future runs. - Legacy task ID mappings are recorded to support backward lookup compatibility. - """ - try: - if simulation is not None: - simulation_obj = simulation - else: - simulation_obj = getattr(stub_data, "simulation", None) - if simulation_obj is None: - log.debug( - "Failed storing local cache entry: Could not find simulation data in stub_data." - ) - return False - simulation_hash = simulation_obj._hash_self() if simulation_obj is not None else None - if not simulation_hash: - log.debug("Failed storing local cache entry: Could not hash simulation.") - return False - - version = _get_protocol_version() - - cache_key = build_cache_key( - simulation_hash=simulation_hash, - version=version, - ) - - metadata = build_entry_metadata( - simulation_hash=simulation_hash, - workflow_type=workflow_type, - task_id=task_id, - version=version, - path=Path(path), - ) - - self._store( - key=cache_key, - source_path=Path(path), - metadata=metadata, - ) - log.debug("Stored local cache entry for workflow type '%s'.", workflow_type) - except Exception as e: - log.error(f"Could not store cache entry: {e}") - return False - return True - - -def _copy_and_hash( - source: Path, dest: Optional[Path], existing_hash: Optional[str] = None -) -> tuple[str, int]: - """Copy ``source`` to ``dest`` while computing SHA256 checksum. - - Parameters - ---------- - source : Path - Source file path. - dest : Path or None - Destination file path. If ``None``, no copy is performed. - existing_hash : str, optional - If provided alongside ``dest`` and ``dest`` already exists, skip copying when hashes match. 
-
-
-def _copy_and_hash(
-    source: Path, dest: Optional[Path], existing_hash: Optional[str] = None
-) -> tuple[str, int]:
-    """Copy ``source`` to ``dest`` while computing SHA256 checksum.
-
-    Parameters
-    ----------
-    source : Path
-        Source file path.
-    dest : Path or None
-        Destination file path. If ``None``, no copy is performed.
-    existing_hash : str, optional
-        If provided alongside ``dest`` and ``dest`` already exists, skip copying when hashes match.
-
-    Returns
-    -------
-    tuple[str, int]
-        The hexadecimal digest and file size in bytes.
-    """
-    source = Path(source)
-    if dest is not None:
-        dest = Path(dest)
-    sha256 = _Hasher()
-    size = 0
-    with source.open("rb") as src:
-        if dest is None:
-            while chunk := src.read(1024 * 1024):
-                sha256.update(chunk)
-                size += len(chunk)
-        else:
-            dest.parent.mkdir(parents=True, exist_ok=True)
-            with dest.open("wb") as dst:
-                while chunk := src.read(1024 * 1024):
-                    dst.write(chunk)
-                    sha256.update(chunk)
-                    size += len(chunk)
-    return sha256.hexdigest(), size
-
-
-def _write_metadata(path: Path, metadata: CacheEntryMetadata | dict[str, Any]) -> None:
-    tmp_path = path.with_suffix(".tmp")
-    payload: dict[str, Any]
-    if isinstance(metadata, CacheEntryMetadata):
-        payload = metadata.model_dump(mode="json")
-    else:
-        payload = metadata
-    with tmp_path.open("w", encoding="utf-8") as fh:
-        json.dump(payload, fh, indent=2, sort_keys=True)
-    os.replace(tmp_path, path)
-
-
-def _now() -> str:
-    return datetime.now(timezone.utc).isoformat()
-
-
-def _timestamp_suffix() -> str:
-    return datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S%f")
-
-
-def _read_metadata(meta_path: Path, artifact_path: Path) -> CacheEntryMetadata:
-    raw = json.loads(meta_path.read_text(encoding="utf-8"))
-    if "file_size" not in raw:
-        try:
-            raw["file_size"] = artifact_path.stat().st_size
-        except FileNotFoundError:
-            raw["file_size"] = 0
-    raw.setdefault("created_at", _now())
-    raw.setdefault("last_used", raw["created_at"])
-    raw.setdefault("cache_key", meta_path.parent.name)
-    return CacheEntryMetadata.model_validate(raw)
-
-
-class _Hasher:
-    def __init__(self) -> None:
-        self._hasher = hashlib.sha256()
-
-    def update(self, data: bytes) -> None:
-        self._hasher.update(data)
-
-    def hexdigest(self) -> str:
-        return self._hasher.hexdigest()
-
-
-def clear() -> None:
-    """Remove all cache entries."""
-    cache = resolve_local_cache(use_cache=True)
-    if cache is not None:
-        cache.clear()
-
-
-def _canonicalize(value: Any) -> Any:
-    """Convert value into a JSON-serializable object for hashing/metadata."""
-
-    if isinstance(value, dict):
-        return {
-            str(k): _canonicalize(v)
-            for k, v in sorted(value.items(), key=lambda item: str(item[0]))
-        }
-    if isinstance(value, (list, tuple)):
-        return [_canonicalize(v) for v in value]
-    if isinstance(value, set):
-        return sorted(_canonicalize(v) for v in value)
-    if isinstance(value, Enum):
-        return value.value
-    if isinstance(value, Path):
-        return str(value)
-    if isinstance(value, datetime):
-        return value.isoformat()
-    if isinstance(value, bytes):
-        return value.decode("utf-8", errors="ignore")
-    return value
-
-
-def build_cache_key(
-    *,
-    simulation_hash: str,
-    version: str,
-) -> str:
-    """Construct a deterministic cache key."""
-
-    payload = {
-        "simulation_hash": simulation_hash,
-        "versions": _canonicalize(version),
-    }
-    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8")
-    return hashlib.sha256(encoded).hexdigest()
-
-
-def build_entry_metadata(
-    *,
-    simulation_hash: str,
-    workflow_type: str,
-    task_id: str,
-    version: str,
-    path: Path,
-) -> CacheEntryMetadata:
-    """Create metadata object for a cache entry."""
-
-    now = datetime.now(timezone.utc)
-    return CacheEntryMetadata(
-        cache_key="",
-        checksum="",
-        created_at=now,
-        last_used=now,
-        file_size=0,
-        simulation_hash=simulation_hash,
-        workflow_type=workflow_type,
-        versions=_canonicalize(version),
-        task_id=task_id,
-        path=str(path),
-    )
-
-
-def resolve_local_cache(use_cache: Optional[bool] = None) -> Optional[LocalCache]:
-    """
-    Returns LocalCache instance if enabled.
-    Returns None if use_cached=False or config-fetched 'enabled' is False.
-    Deletes old cache directory if existing.
-    """
-    global _CACHE
-
-    if use_cache is False or (use_cache is not True and not config.local_cache.enabled):
-        return None
-
-    if _CACHE is not None and _CACHE._root != Path(config.local_cache.directory):
-        old_root = _CACHE._root
-        new_root = Path(config.local_cache.directory)
-        log.debug(f"Moving cache directory from {old_root} → {new_root}")
-        try:
-            new_root.parent.mkdir(parents=True, exist_ok=True)
-            if old_root.exists():
-                shutil.move(old_root, new_root)
-        except Exception as e:
-            log.warning(f"Failed to move cache directory: {e}. Delete old cache.")
-            shutil.rmtree(old_root)
-
-    _CACHE = LocalCache(
-        directory=config.local_cache.directory,
-        max_entries=config.local_cache.max_entries,
-        max_size_gb=config.local_cache.max_size_gb,
-    )
-
-    try:
-        return _CACHE
-    except Exception as err:
-        log.debug(f"Simulation cache unavailable: {err}")
-        return None
-
-
-def _store_mode_solver_in_cache(
-    task_id: TaskId, simulation: ModeSolver, data: WorkflowDataType, path: os.PathLike
-) -> bool:
+def _store_mode_solver_in_cache(task_id: str, simulation: ModeSolver, path: os.PathLike) -> bool:
     """
     Stores the results of a :class:`.ModeSolver` run in the local cache, if available.
@@ -878,8 +60,6 @@ def _store_mode_solver_in_cache(
         Unique identifier of the mode solver task.
     simulation : :class:`.ModeSolver`
         Mode solver simulation object whose results should be cached.
-    data : :class:`.WorkflowDataType`
-        Data object containing the computed results to store.
     path : PathLike
         Path to the result file on disk.
@@ -896,7 +76,6 @@ def _store_mode_solver_in_cache(
     simulation_cache = resolve_local_cache()
     if simulation_cache is not None:
         stored = simulation_cache.store_result(
-            stub_data=data,
             task_id=task_id,
             path=path,
             workflow_type=TaskType.MODE_SOLVER.name,
@@ -904,6 +83,3 @@ def _store_mode_solver_in_cache(
         )
         return stored
     return False
-
-
-resolve_local_cache()
diff --git a/tidy3d/web/cli/cache.py b/tidy3d/web/cli/cache.py
index 8d9c5454cf..78b999b0b5 100644
--- a/tidy3d/web/cli/cache.py
+++ b/tidy3d/web/cli/cache.py
@@ -2,13 +2,18 @@
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import TYPE_CHECKING
 
 import click
 
 from tidy3d import config
-from tidy3d.web.cache import LocalCache, resolve_local_cache
 from tidy3d.web.cache import clear as clear_cache
+from tidy3d.web.cache import resolve_local_cache
+
+if TYPE_CHECKING:
+    from typing import Optional
+
+    from tidy3d.web.cache import LocalCache
 
 
 def _fmt_size(num_bytes: int) -> str:
diff --git a/tidy3d/web/cli/develop/documentation.py b/tidy3d/web/cli/develop/documentation.py
index ab86cc5509..00c37fbc90 100644
--- a/tidy3d/web/cli/develop/documentation.py
+++ b/tidy3d/web/cli/develop/documentation.py
@@ -17,13 +17,16 @@
 
 import json
 import os
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any
 
 import click
 
 from .index import develop
 from .utils import echo_and_check_subprocess, get_install_directory
 
+if TYPE_CHECKING:
+    from typing import Optional
+
 __all__ = [
     "build_documentation",
     # "build_documentation_pdf",
diff --git a/tidy3d/web/cli/develop/install.py b/tidy3d/web/cli/develop/install.py
index a80a5c9849..39fef01308 100644
--- a/tidy3d/web/cli/develop/install.py
+++ b/tidy3d/web/cli/develop/install.py
@@ -9,13 +9,16 @@
 
 import platform
 import re
 import subprocess
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any
 
 import click
 
 from .index import develop
 from .utils import echo_and_check_subprocess, echo_and_run_subprocess, get_install_directory
 
+if TYPE_CHECKING:
+    from typing import Optional
+
 __all__ = [
     "activate_correct_poetry_python",
     "configure_submodules",
@@ -113,25 +116,23 @@ def verify_pipx_is_installed() -> Optional[bool]:
         result = echo_and_run_subprocess(
             ["pipx", "--version"], capture_output=True, text=True, check=True
         )
-        # If the command was successful, it means pipx is installed
-        if result.returncode == 0:
-            print("pipx is installed: " + result.stdout)
-            return True
+        print("pipx is installed: " + result.stdout)
+        return True
     except subprocess.CalledProcessError:
         # This exception is raised if the command returned a non-zero exit status
         print("pipx is not installed or not found in the system PATH.")
         return False
 
 
-def verify_poetry_is_installed() -> Optional[bool]:
+def verify_poetry_is_installed() -> bool:
     """
     Check if Poetry is installed on the system.
 
     Returns
     -------
    bool
-        True if Poetry is installed, False otherwise.
+        ``True`` if Poetry is installed; an ``OSError`` is raised otherwise.
 
     Raises
     ------
@@ -144,9 +145,8 @@ def verify_poetry_is_installed() -> Optional[bool]:
             ["poetry", "--version"], capture_output=True, text=True, check=True
         )
         # If the command was successful, we'll get the version info
-        if result.returncode == 0:
-            print("Poetry is installed: " + result.stdout)
-            return True
+        print("Poetry is installed: " + result.stdout)
+        return True
     except subprocess.CalledProcessError as exc:
         # This exception is raised if the command returned a non-zero exit status
         raise OSError("Poetry is not installed or not found in the system PATH.") from exc
diff --git a/tidy3d/web/core/__init__.py b/tidy3d/web/core/__init__.py
index 281ccca03a..f353d09be6 100644
--- a/tidy3d/web/core/__init__.py
+++ b/tidy3d/web/core/__init__.py
@@ -1,8 +1,10 @@
-"""Tidy3d core package imports"""
+"""Compatibility shim for :mod:`tidy3d._common.web.core`."""
 
-from __future__ import annotations
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-# TODO(FXC-3827): Drop this import once the legacy shim is removed in Tidy3D 2.12.
-from . import environment
+# marked as migrated to _common
 from __future__ import annotations
 
-__all__ = ["environment"]
+from tidy3d._common.web.core import (
+    environment,
+)
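
Every shim in this patch follows the same recipe: the old module path stays importable while the implementation moves under `tidy3d._common`. One quick way to sanity-check that contract (a sketch that needs an installed tidy3d; the module pair below is the one being rewritten here) is to assert that both paths resolve to the very same object:

import importlib

def assert_same_export(old_module: str, new_module: str, name: str) -> None:
    """The legacy path must re-export the same object, not a copy of it."""
    old = getattr(importlib.import_module(old_module), name)
    new = getattr(importlib.import_module(new_module), name)
    assert old is new, f"{old_module}.{name} is not {new_module}.{name}"

assert_same_export(
    "tidy3d.web.core.account", "tidy3d._common.web.core.account", "Account"
)
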
diff --git a/tidy3d/web/core/account.py b/tidy3d/web/core/account.py
index 0eda27928f..fbc4c856da 100644
--- a/tidy3d/web/core/account.py
+++ b/tidy3d/web/core/account.py
@@ -1,66 +1,10 @@
-"""Tidy3d user account."""
-
-from __future__ import annotations
-
-from datetime import datetime
-from typing import Optional
-
-from pydantic.v1 import Extra, Field
-
-from .http_util import http
-from .types import Tidy3DResource
+"""Compatibility shim for :mod:`tidy3d._common.web.core.account`."""
 
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-class Account(Tidy3DResource, extra=Extra.allow):
-    """Tidy3D User Account."""
-
-    allowance_cycle_type: Optional[str] = Field(
-        None,
-        title="AllowanceCycleType",
-        description="Daily or Monthly",
-        alias="allowanceCycleType",
-    )
-    credit: Optional[float] = Field(
-        0, title="credit", description="Current FlexCredit balance", alias="credit"
-    )
-    credit_expiration: Optional[datetime] = Field(
-        None,
-        title="creditExpiration",
-        description="Expiration date",
-        alias="creditExpiration",
-    )
-    allowance_current_cycle_amount: Optional[float] = Field(
-        0,
-        title="allowanceCurrentCycleAmount",
-        description="Daily/Monthly free simulation balance",
-        alias="allowanceCurrentCycleAmount",
-    )
-    allowance_current_cycle_end_date: Optional[datetime] = Field(
-        None,
-        title="allowanceCurrentCycleEndDate",
-        description="Daily/Monthly free simulation balance expiration date",
-        alias="allowanceCurrentCycleEndDate",
-    )
-    daily_free_simulation_counts: Optional[int] = Field(
-        0,
-        title="dailyFreeSimulationCounts",
-        description="Daily free simulation counts",
-        alias="dailyFreeSimulationCounts",
-    )
-
-    @classmethod
-    def get(cls) -> Optional[Account]:
-        """Get user account information.
-
-        Parameters
-        ----------
+# marked as migrated to _common
 from __future__ import annotations
 
-        Returns
-        -------
-        account : Account
-        """
-        resp = http.get("tidy3d/py/account")
-        if resp:
-            account = Account(**resp)
-            return account
-        return None
+from tidy3d._common.web.core.account import (
+    Account,
+)
diff --git a/tidy3d/web/core/cache.py b/tidy3d/web/core/cache.py
index d83421ca21..115080c123 100644
--- a/tidy3d/web/core/cache.py
+++ b/tidy3d/web/core/cache.py
@@ -1,6 +1,11 @@
-"""Local caches."""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.cache`."""
+
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
+
+# marked as migrated to _common
 from __future__ import annotations
 
-FOLDER_CACHE = {}
-S3_STS_TOKENS = {}
+from tidy3d._common.web.core.cache import (
+    FOLDER_CACHE,
+    S3_STS_TOKENS,
+)
diff --git a/tidy3d/web/core/constants.py b/tidy3d/web/core/constants.py
index 623af2bba8..bb03702f91 100644
--- a/tidy3d/web/core/constants.py
+++ b/tidy3d/web/core/constants.py
@@ -1,38 +1,34 @@
-"""Defines constants for core."""
-
-# HTTP Header key and value
-from __future__ import annotations
-
-HEADER_APIKEY = "simcloud-api-key"
-HEADER_VERSION = "tidy3d-python-version"
-HEADER_SOURCE = "source"
-HEADER_SOURCE_VALUE = "Python"
-HEADER_USER_AGENT = "User-Agent"
-HEADER_APPLICATION = "Application"
-HEADER_APPLICATION_VALUE = "TIDY3D"
+"""Compatibility shim for :mod:`tidy3d._common.web.core.constants`."""
 
-SIMCLOUD_APIKEY = "SIMCLOUD_APIKEY"
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
-KEY_APIKEY = "apikey"
-JSON_TAG = "JSON_STRING"
-# type of the task_id
-TaskId = str
-# type of task_name
-TaskName = str
-
-
-SIMULATION_JSON = "simulation.json"
-SIMULATION_DATA_HDF5 = "output/monitor_data.hdf5"
-SIMULATION_DATA_HDF5_GZ = "output/simulation_data.hdf5.gz"
-RUNNING_INFO = "output/solver_progress.csv"
-SIM_LOG_FILE = "output/tidy3d.log"
-SIM_FILE_HDF5 = "simulation.hdf5"
-SIM_FILE_HDF5_GZ = "simulation.hdf5.gz"
-MODE_FILE_HDF5_GZ = "mode_solver.hdf5.gz"
-MODE_DATA_HDF5_GZ = "output/mode_solver_data.hdf5.gz"
-SIM_ERROR_FILE = "output/tidy3d_error.json"
-SIM_VALIDATION_FILE = "output/tidy3d_validation.json"
+# marked as migrated to _common
 from __future__ import annotations
 
-# Component modeler specific artifacts
-MODELER_FILE_HDF5_GZ = "modeler.hdf5.gz"
-CM_DATA_HDF5_GZ = "output/cm_data.hdf5.gz"
+from tidy3d._common.web.core.constants import (
+    CM_DATA_HDF5_GZ,
+    HEADER_APIKEY,
+    HEADER_APPLICATION,
+    HEADER_APPLICATION_VALUE,
+    HEADER_SOURCE,
+    HEADER_SOURCE_VALUE,
+    HEADER_USER_AGENT,
+    HEADER_VERSION,
+    JSON_TAG,
+    KEY_APIKEY,
+    MODE_DATA_HDF5_GZ,
+    MODE_FILE_HDF5_GZ,
+    MODELER_FILE_HDF5_GZ,
+    RUNNING_INFO,
+    SIM_ERROR_FILE,
+    SIM_FILE_HDF5,
+    SIM_FILE_HDF5_GZ,
+    SIM_LOG_FILE,
+    SIM_VALIDATION_FILE,
+    SIMCLOUD_APIKEY,
+    SIMULATION_DATA_HDF5,
+    SIMULATION_DATA_HDF5_GZ,
+    SIMULATION_JSON,
+    TaskId,
+    TaskName,
+)
diff --git a/tidy3d/web/core/core_config.py b/tidy3d/web/core/core_config.py
index 72d2853c3a..b3640c8df2 100644
--- a/tidy3d/web/core/core_config.py
+++ b/tidy3d/web/core/core_config.py
@@ -1,48 +1,14 @@
-"""Tidy3d core log, need init config from Tidy3d api"""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.core_config`."""
 
-from __future__ import annotations
-
-import logging as log
-
-from rich.console import Console
-
-from tidy3d.log import Logger
-
-# default setting
-config_setting = {
-    "logger": log,
-    "logger_console": None,
-    "version": "",
-}
-
-
-def set_config(logger: Logger, logger_console: Console, version: str) -> None:
-    """Init tidy3d core logger and logger console.
-
-    Parameters
-    ----------
-    logger : :class:`.Logger`
-        Tidy3d log Logger.
-    logger_console : :class:`.Console`
-        Get console from logging handlers.
-    version : str
-        tidy3d version
-    """
-    config_setting["logger"] = logger
-    config_setting["logger_console"] = logger_console
-    config_setting["version"] = version
-
-
-def get_logger() -> Logger:
-    """Get logging handlers."""
-    return config_setting["logger"]
-
-
-def get_logger_console() -> Console:
-    """Get console from logging handlers."""
-    return config_setting["logger_console"]
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
+# marked as migrated to _common
 from __future__ import annotations
 
-def get_version() -> str:
-    """Get version from cache."""
-    return config_setting["version"]
+from tidy3d._common.web.core.core_config import (
+    config_setting,
+    get_logger,
+    get_logger_console,
+    get_version,
+    set_config,
+)
diff --git a/tidy3d/web/core/environment.py b/tidy3d/web/core/environment.py
index ffe86d89d4..7873b9b4d2 100644
--- a/tidy3d/web/core/environment.py
+++ b/tidy3d/web/core/environment.py
@@ -1,42 +1,20 @@
-"""Legacy re-export of configuration environment helpers."""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.environment`."""
 
-from __future__ import annotations
-
-# TODO(FXC-3827): Remove this module-level legacy shim in Tidy3D 2.12.
-import warnings
-from typing import Any
-
-from tidy3d.config import Env, Environment, EnvironmentConfig
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-__all__ = [  # noqa: F822
-    "Env",
-    "Environment",
-    "EnvironmentConfig",
-    "dev",
-    "nexus",
-    "pre",
-    "prod",
-    "uat",
-]
+# marked as migrated to _common
 from __future__ import annotations
 
-_LEGACY_ENV_NAMES = {"dev", "uat", "pre", "prod", "nexus"}
-_DEPRECATION_MESSAGE = (
-    "'tidy3d.web.core.environment.{name}' is deprecated and will be removed in "
-    "Tidy3D 2.12. Transition to 'tidy3d.config.Env.{name}' or "
-    "'tidy3d.config.config.switch_profile(...)'."
+from tidy3d._common.web.core.environment import (
+    _DEPRECATION_MESSAGE,
+    _LEGACY_ENV_NAMES,
+    Env,
+    Environment,
+    EnvironmentConfig,
+    _get_legacy_env,
+    dev,
+    nexus,
+    pre,
+    prod,
+    uat,
 )
-
-
-def _get_legacy_env(name: str) -> Any:
-    warnings.warn(_DEPRECATION_MESSAGE.format(name=name), DeprecationWarning, stacklevel=2)
-    return getattr(Env, name)
-
-
-def __getattr__(name: str) -> Any:
-    if name in _LEGACY_ENV_NAMES:
-        return _get_legacy_env(name)
-    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-
-
-def __dir__() -> list[str]:
-    return sorted(set(__all__))
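
The module emptied just above relied on PEP 562's module-level `__getattr__` to warn on legacy attribute access before forwarding to the new location. The generic shape of that shim, as a minimal sketch with the forwarding target reduced to a placeholder dictionary (only the legacy names are taken from the removed code):

import warnings
from typing import Any

_LEGACY_NAMES = {"dev", "uat", "pre", "prod", "nexus"}
_NEW_HOME: dict[str, Any] = {name: object() for name in _LEGACY_NAMES}  # placeholder

def __getattr__(name: str) -> Any:
    """PEP 562 hook: warn, then forward legacy names to their new home."""
    if name in _LEGACY_NAMES:
        warnings.warn(
            f"'{__name__}.{name}' is deprecated; use the new config API instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return _NEW_HOME[name]
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
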
diff --git a/tidy3d/web/core/exceptions.py b/tidy3d/web/core/exceptions.py
index 3dd6f15a35..1900371e9b 100644
--- a/tidy3d/web/core/exceptions.py
+++ b/tidy3d/web/core/exceptions.py
@@ -1,21 +1,11 @@
-"""Custom Tidy3D exceptions"""
-
-from __future__ import annotations
-
-from typing import Optional
-
-from .core_config import get_logger
+"""Compatibility shim for :mod:`tidy3d._common.web.core.exceptions`."""
 
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-class WebError(Exception):
-    """Any error in tidy3d"""
-
-    def __init__(self, message: Optional[str] = None) -> None:
-        """Log just the error message and then raise the Exception."""
-        log = get_logger()
-        super().__init__(message)
-        log.error(message)
-
+# marked as migrated to _common
 from __future__ import annotations
 
-class WebNotFoundError(WebError):
-    """A generic error indicating an HTTP 404 (resource not found)."""
+from tidy3d._common.web.core.exceptions import (
+    WebError,
+    WebNotFoundError,
+)
diff --git a/tidy3d/web/core/file_util.py b/tidy3d/web/core/file_util.py
index b3987f495c..b908de53a2 100644
--- a/tidy3d/web/core/file_util.py
+++ b/tidy3d/web/core/file_util.py
@@ -1,87 +1,15 @@
-"""File compression utilities"""
-
-from __future__ import annotations
-
-import gzip
-import os
-import shutil
-import tempfile
-
-import h5py
-
-from tidy3d.web.core.constants import JSON_TAG
-
-
-def compress_file_to_gzip(input_file: os.PathLike, output_gz_file: os.PathLike) -> None:
-    """Compresses a file using gzip.
-
-    Parameters
-    ----------
-    input_file : PathLike
-        The path of the input file.
-    output_gz_file : PathLike
-        The path of the output gzip file.
-    """
-    with open(input_file, "rb") as file_in:
-        with gzip.open(output_gz_file, "wb") as file_out:
-            shutil.copyfileobj(file_in, file_out)
-
-
-def extract_gzip_file(input_gz_file: os.PathLike, output_file: os.PathLike) -> None:
-    """Extract a gzip file.
-
-    Parameters
-    ----------
-    input_gz_file : PathLike
-        The path of the gzip input file.
-    output_file : PathLike
-        The path of the output file.
-    """
-    with gzip.open(input_gz_file, "rb") as file_in:
-        with open(output_file, "wb") as file_out:
-            shutil.copyfileobj(file_in, file_out)
+"""Compatibility shim for :mod:`tidy3d._common.web.core.file_util`."""
 
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-def read_simulation_from_hdf5_gz(file_name: os.PathLike) -> str:
-    """read simulation str from hdf5.gz"""
-
-    hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5")
-    os.close(hdf5_file)
-    try:
-        extract_gzip_file(file_name, hdf5_file_path)
-        # Pass the uncompressed temporary file path to the reader
-        json_str = read_simulation_from_hdf5(hdf5_file_path)
-    finally:
-        os.unlink(hdf5_file_path)
-    return json_str
-
-
-"""TODO: _json_string_key and read_simulation_from_hdf5 are duplicated functions that also exist
-as methods in Tidy3dBaseModel. For consistency it would be best if this duplication is avoided."""
-
-
-def _json_string_key(index: int) -> str:
-    """Get json string key for string chunk number ``index``."""
-    if index:
-        return f"{JSON_TAG}_{index}"
-    return JSON_TAG
-
-
-def read_simulation_from_hdf5(file_name: os.PathLike) -> bytes:
-    """read simulation str from hdf5"""
-    with h5py.File(file_name, "r") as f_handle:
-        num_string_parts = len([key for key in f_handle.keys() if JSON_TAG in key])
-        json_string = b""
-        for ind in range(num_string_parts):
-            json_string += f_handle[_json_string_key(ind)][()]
-        return json_string
-
-
-"""End TODO"""
+# marked as migrated to _common
 from __future__ import annotations
 
-def read_simulation_from_json(file_name: os.PathLike) -> str:
-    """read simulation str from json"""
-    with open(file_name) as json_file:
-        json_data = json_file.read()
-    return json_data
+from tidy3d._common.web.core.file_util import (
+    _json_string_key,
+    compress_file_to_gzip,
+    extract_gzip_file,
+    read_simulation_from_hdf5,
+    read_simulation_from_hdf5_gz,
+    read_simulation_from_json,
+)
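
The removed `read_simulation_from_hdf5` documents a chunked-string convention: the JSON document is stored across HDF5 datasets named `JSON_STRING`, `JSON_STRING_1`, `JSON_STRING_2`, and so on, and reassembled in index order. A compact sketch of that convention (h5py assumed; the tag mirrors `JSON_TAG` from the removed module):

import h5py

JSON_TAG = "JSON_STRING"

def chunk_key(index: int) -> str:
    """Dataset name for chunk ``index``: bare tag for 0, suffixed otherwise."""
    return f"{JSON_TAG}_{index}" if index else JSON_TAG

def read_chunked_json(file_name: str) -> bytes:
    """Reassemble the JSON string from its numbered HDF5 datasets."""
    with h5py.File(file_name, "r") as handle:
        num_parts = len([key for key in handle.keys() if JSON_TAG in key])
        return b"".join(handle[chunk_key(i)][()] for i in range(num_parts))
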
diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py
index 35c73b3cbb..31170a69f5 100644
--- a/tidy3d/web/core/http_util.py
+++ b/tidy3d/web/core/http_util.py
@@ -1,280 +1,20 @@
-"""Http connection pool and authentication management."""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.http_util`."""
 
-from __future__ import annotations
-
-import json
-import os
-import ssl
-from enum import Enum
-from functools import wraps
-from typing import Any, Callable, Optional, TypeAlias
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-import requests
-from requests.adapters import HTTPAdapter
-from urllib3.util.ssl_ import create_urllib3_context
-
-from tidy3d import log
-from tidy3d.config import config
+# marked as migrated to _common
 from __future__ import annotations
 
-from . import core_config
-from .constants import (
-    HEADER_APIKEY,
-    HEADER_APPLICATION,
-    HEADER_APPLICATION_VALUE,
-    HEADER_SOURCE,
-    HEADER_SOURCE_VALUE,
-    HEADER_USER_AGENT,
-    HEADER_VERSION,
-    SIMCLOUD_APIKEY,
+from tidy3d._common.web.core.http_util import (
+    HttpSessionManager,
+    JSONType,
+    ResponseCodes,
+    TLSAdapter,
+    api_key,
+    api_key_auth,
+    get_headers,
+    get_user_agent,
+    get_version,
+    http,
+    http_interceptor,
 )
-from .core_config import get_logger
-from .exceptions import WebError, WebNotFoundError
-
-JSONType: TypeAlias = dict[str, Any] | list[Any] | str | int
-
-
-class ResponseCodes(Enum):
-    """HTTP response codes to handle individually."""
-
-    UNAUTHORIZED = 401
-    OK = 200
-    NOT_FOUND = 404
-
-
-def get_version() -> str:
-    """Get the version for the current environment."""
-    return core_config.get_version()
-    # return "2.10.0rc2.1"
-
-
-def get_user_agent() -> str:
-    """Get the user agent the current environment."""
-    return os.environ.get("TIDY3D_AGENT", f"Python-Client/{get_version()}")
-
-
-def api_key() -> None:
-    """Get the api key for the current environment."""
-
-    if os.environ.get(SIMCLOUD_APIKEY):
-        return os.environ.get(SIMCLOUD_APIKEY)
-
-    try:
-        apikey = config.web.apikey
-    except AttributeError:
-        return None
-
-    if apikey is None:
-        return None
-    if hasattr(apikey, "get_secret_value"):
-        return apikey.get_secret_value()
-    return str(apikey)
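
The removed `api_key()` encodes a precedence rule: an explicit `SIMCLOUD_APIKEY` environment variable wins over the configured secret, and pydantic-style secrets are unwrapped last. A standalone sketch of just that rule (`read_config_apikey` is a stand-in for `config.web.apikey`, not tidy3d API):

import os
from typing import Callable, Optional

def resolve_api_key(read_config_apikey: Callable[[], object] = lambda: None) -> Optional[str]:
    """Environment variable first, then the configured (possibly secret) value."""
    env_key = os.environ.get("SIMCLOUD_APIKEY")
    if env_key:
        return env_key
    apikey = read_config_apikey()
    if apikey is None:
        return None
    # pydantic SecretStr-style values expose get_secret_value()
    if hasattr(apikey, "get_secret_value"):
        return apikey.get_secret_value()
    return str(apikey)
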
-
-
-def api_key_auth(request: requests.request) -> requests.request:
-    """Save the authentication info in a request.
-
-    Parameters
-    ----------
-    request : requests.request
-        The original request to set authentication for.
-
-    Returns
-    -------
-    requests.request
-        The request with authentication set.
-    """
-    key = api_key()
-    version = get_version()
-    if not key:
-        raise ValueError(
-            "API key not found. To get your API key, sign into 'https://tidy3d.simulation.cloud' "
-            "and copy it from your 'Account' page. Then you can configure tidy3d through command "
-            "line 'tidy3d configure' and enter your API key when prompted. "
-            "Alternatively, especially if using windows, you can manually create the configuration "
-            "file by creating a file at their home directory '~/.tidy3d/config' (unix) or "
-            "'.tidy3d/config' (windows) containing the following line: "
-            "apikey = 'XXX'. Here XXX is your API key copied from your account page within quotes."
-        )
-    if not version:
-        raise ValueError("version not found.")
-
-    request.headers[HEADER_APIKEY] = key
-    request.headers[HEADER_VERSION] = version
-    request.headers[HEADER_SOURCE] = HEADER_SOURCE_VALUE
-    request.headers[HEADER_USER_AGENT] = get_user_agent()
-    return request
-
-
-def get_headers() -> dict[str, str]:
-    """get headers for http request.
-
-    Returns
-    -------
-    Dict[str, str]
-        dictionary with "Authorization" and "Application" keys.
-    """
-    return {
-        HEADER_APIKEY: api_key(),
-        HEADER_APPLICATION: HEADER_APPLICATION_VALUE,
-        HEADER_USER_AGENT: get_user_agent(),
-    }
-
-
-def http_interceptor(func: Callable[..., Any]) -> Callable[..., JSONType]:
-    """Intercept the response and raise an exception if the status code is not 200."""
-
-    @wraps(func)
-    def wrapper(*args: Any, **kwargs: Any) -> JSONType:
-        """The wrapper function."""
-        suppress_404 = kwargs.pop("suppress_404", False)
-
-        # Extend some capabilities of func
-        resp = func(*args, **kwargs)
-
-        if resp.status_code != ResponseCodes.OK.value:
-            if resp.status_code == ResponseCodes.NOT_FOUND.value:
-                if suppress_404:
-                    return None
-                raise WebNotFoundError("Resource not found (HTTP 404).")
-            try:
-                json_resp = resp.json()
-            except Exception:
-                json_resp = None
-
-            # Build a helpful error message using available fields
-            err_msg = None
-            if isinstance(json_resp, dict):
-                parts = []
-                for key in ("error", "message", "msg", "detail", "code", "httpStatus", "warning"):
-                    val = json_resp.get(key)
-                    if not val:
-                        continue
-                    if key == "error":
-                        # Always include the raw 'error' payload for debugging. Also try to extract a nested message.
-                        if isinstance(val, str):
-                            try:
-                                nested = json.loads(val)
-                                if isinstance(nested, dict):
-                                    nested_msg = (
-                                        nested.get("message")
-                                        or nested.get("error")
-                                        or nested.get("msg")
-                                    )
-                                    if nested_msg:
-                                        parts.append(str(nested_msg))
-                            except Exception:
-                                pass
-                            parts.append(f"error={val}")
-                        else:
-                            parts.append(f"error={val!s}")
-                        continue
-                    parts.append(str(val))
-                if parts:
-                    err_msg = "; ".join(parts)
-            if not err_msg:
-                # Fallback to response text or status code
-                err_msg = resp.text or f"HTTP {resp.status_code}"
-
-            # Append request context to aid debugging
-            try:
-                method = getattr(resp.request, "method", "")
-                url = getattr(resp.request, "url", "")
-                err_msg = f"{err_msg} [HTTP {resp.status_code} {method} {url}]"
-            except Exception:
-                pass
-
-            raise WebError(err_msg)
-
-        if not resp.text:
-            return None
-        result = resp.json()
-
-        if isinstance(result, dict):
-            warning = result.get("warning")
-            if warning:
-                log = get_logger()
-                log.warning(warning)
-
-        return result.get("data") if "data" in result else result
-
-    return wrapper
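
Stripped of the error-payload mining above, the interceptor's control flow is short: call the wrapped request, raise on non-200, return None on an empty body, and unwrap the `data` envelope on success. A reduced sketch under those assumptions (requests-like response object, error types simplified):

from functools import wraps
from typing import Any, Callable

def interceptor(func: Callable[..., Any]) -> Callable[..., Any]:
    """Raise on non-200, return None on empty body, unwrap the 'data' key."""
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        resp = func(*args, **kwargs)
        if resp.status_code != 200:
            raise RuntimeError(f"HTTP {resp.status_code}: {resp.text or 'no body'}")
        if not resp.text:
            return None
        result = resp.json()
        if isinstance(result, dict) and "data" in result:
            return result["data"]
        return result
    return wrapper
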
-
-
-class TLSAdapter(HTTPAdapter):
-    def init_poolmanager(self, *args: Any, **kwargs: Any) -> None:
-        try:
-            ssl_version = (
-                ssl.TLSVersion[config.web.ssl_version]
-                if config.web.ssl_version is not None
-                else None
-            )
-        except KeyError:
-            log.warning(f"Invalid SSL/TLS version '{config.web.ssl_version}', using default")
-            ssl_version = None
-        context = create_urllib3_context(ssl_version=ssl_version)
-        kwargs["ssl_context"] = context
-        return super().init_poolmanager(*args, **kwargs)
-
-
-class HttpSessionManager:
-    """Http util class."""
-
-    def __init__(self, session: requests.Session) -> None:
-        """Initialize the session."""
-        self.session = session
-        self._mounted_ssl_version = None
-        self._ensure_tls_adapter(config.web.ssl_version)
-        self.session.verify = config.web.ssl_verify
-
-    def reinit(self) -> None:
-        """Reinitialize the session."""
-        ssl_version = config.web.ssl_version
-        self._ensure_tls_adapter(ssl_version)
-        self.session.verify = config.web.ssl_verify
-
-    def _ensure_tls_adapter(self, ssl_version: str) -> None:
-        if not ssl_version:
-            self._mounted_ssl_version = None
-            return
-        if self._mounted_ssl_version != ssl_version:
-            self.session.mount("https://", TLSAdapter())
-            self._mounted_ssl_version = ssl_version
-
-    @http_interceptor
-    def get(
-        self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None
-    ) -> requests.Response:
-        """Get the resource."""
-        self.reinit()
-        return self.session.get(
-            url=config.web.build_api_url(path), auth=api_key_auth, json=json, params=params
-        )
-
-    @http_interceptor
-    def post(self, path: str, json: JSONType = None) -> requests.Response:
-        """Create the resource."""
-        self.reinit()
-        return self.session.post(config.web.build_api_url(path), json=json, auth=api_key_auth)
-
-    @http_interceptor
-    def put(
-        self, path: str, json: JSONType = None, files: Optional[dict[str, Any]] = None
-    ) -> requests.Response:
-        """Update the resource."""
-        self.reinit()
-        return self.session.put(
-            config.web.build_api_url(path), json=json, auth=api_key_auth, files=files
-        )
-
-    @http_interceptor
-    def delete(
-        self, path: str, json: JSONType = None, params: Optional[dict[str, Any]] = None
-    ) -> requests.Response:
-        """Delete the resource."""
-        self.reinit()
-        return self.session.delete(
-            config.web.build_api_url(path), auth=api_key_auth, json=json, params=params
-        )
-
-
-http = HttpSessionManager(requests.Session())
diff --git a/tidy3d/web/core/s3utils.py b/tidy3d/web/core/s3utils.py
index 1dfb3ec1e1..401347b271 100644
--- a/tidy3d/web/core/s3utils.py
+++ b/tidy3d/web/core/s3utils.py
@@ -1,456 +1,22 @@
-"""handles filesystem, storage"""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.s3utils`."""
 
-from __future__ import annotations
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-import os
-import tempfile
-import urllib
-from collections.abc import Mapping
-from datetime import datetime
-from enum import Enum
-from os import PathLike
-from pathlib import Path
-from typing import Any, Callable, Optional
+# marked as migrated to _common
+from __future__ import annotations
 
-import boto3
-import rich
-from boto3.s3.transfer import TransferConfig
-from pydantic import BaseModel, Field
-from rich.progress import (
-    BarColumn,
-    DownloadColumn,
-    Progress,
-    TextColumn,
-    TimeRemainingColumn,
-    TransferSpeedColumn,
+from tidy3d._common.web.core.s3utils import (
+    IN_TRANSIT_SUFFIX,
+    DownloadProgress,
+    UploadProgress,
+    _get_progress,
+    _s3_config,
+    _s3_sts_tokens,
+    _S3Action,
+    _S3STSToken,
+    _UserCredential,
+    download_file,
+    download_gz_file,
+    get_s3_sts_token,
+    upload_file,
 )
-
-from tidy3d.config import config
-
-from .core_config import get_logger_console
-from .exceptions import WebError
-from .file_util import extract_gzip_file
-from .http_util import http
-
-IN_TRANSIT_SUFFIX = ".tmp"
-
-
-class _UserCredential(BaseModel):
-    """Stores information about user credentials."""
-
-    access_key_id: str = Field(alias="accessKeyId")
-    expiration: datetime
-    secret_access_key: str = Field(alias="secretAccessKey")
-    session_token: str = Field(alias="sessionToken")
-
-
-class _S3STSToken(BaseModel):
-    """Stores information about S3 token."""
-
-    cloud_path: str = Field(alias="cloudpath")
-    user_credential: _UserCredential = Field(alias="userCredentials")
-
-    def get_bucket(self) -> str:
-        """Get the bucket name for this token."""
-
-        r = urllib.parse.urlparse(self.cloud_path)
-        return r.netloc
-
-    def get_s3_key(self) -> str:
-        """Get the s3 key for this token."""
-
-        r = urllib.parse.urlparse(self.cloud_path)
-        return r.path[1:]
-
-    def get_client(self) -> boto3.client:
-        """Get the boto client for this token.
-
-        Automatically configures custom S3 endpoint if specified in web.env_vars.
-        """
-
-        client_kwargs = {
-            "service_name": "s3",
-            "region_name": config.web.s3_region,
-            "aws_access_key_id": self.user_credential.access_key_id,
-            "aws_secret_access_key": self.user_credential.secret_access_key,
-            "aws_session_token": self.user_credential.session_token,
-            "verify": config.web.ssl_verify,
-        }
-
-        # Add custom S3 endpoint if configured (e.g., for Nexus deployments)
-        if config.web.env_vars and "AWS_ENDPOINT_URL_S3" in config.web.env_vars:
-            s3_endpoint = config.web.env_vars["AWS_ENDPOINT_URL_S3"]
-            client_kwargs["endpoint_url"] = s3_endpoint
-
-        return boto3.client(**client_kwargs)
-
-    def is_expired(self) -> bool:
-        """True if token is expired."""
-
-        return (
-            self.user_credential.expiration
-            - datetime.now(tz=self.user_credential.expiration.tzinfo)
-        ).total_seconds() < 300
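
The early-refresh rule in `_S3STSToken.is_expired` above treats a token as expired five minutes before its actual expiration, so an in-flight transfer never races the credential lifetime. The rule in isolation, as a runnable sketch:

from datetime import datetime, timedelta, timezone

def is_expired(expiration: datetime, margin_seconds: int = 300) -> bool:
    """Expired once less than ``margin_seconds`` of validity remain."""
    remaining = expiration - datetime.now(tz=expiration.tzinfo)
    return remaining.total_seconds() < margin_seconds

soon = datetime.now(timezone.utc) + timedelta(seconds=60)
assert is_expired(soon)  # under the 300 s safety margin, so refresh
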
-
-
-class UploadProgress:
-    """Updates progressbar with the upload status.
-
-    Attributes
-    ----------
-    progress : rich.progress.Progress()
-        Progressbar instance from rich
-    ul_task : rich.progress.Task
-        Progressbar task instance.
-    """
-
-    def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None:
-        """initialize with the size of file and rich.progress.Progress() instance.
-
-        Parameters
-        ----------
-        size_bytes: int
-            Number of total bytes to upload.
-        progress : rich.progress.Progress()
-            Progressbar instance from rich
-        """
-        self.progress = progress
-        self.ul_task = self.progress.add_task("[red]Uploading...", total=size_bytes)
-
-    def report(self, bytes_in_chunk: Any) -> None:
-        """Update the progressbar with the most recent chunk.
-
-        Parameters
-        ----------
-        bytes_in_chunk : int
-            Description
-        """
-        self.progress.update(self.ul_task, advance=bytes_in_chunk)
-
-
-class DownloadProgress:
-    """Updates progressbar using the download status.
-
-    Attributes
-    ----------
-    progress : rich.progress.Progress()
-        Progressbar instance from rich
-    ul_task : rich.progress.Task
-        Progressbar task instance.
-    """
-
-    def __init__(self, size_bytes: int, progress: rich.progress.Progress) -> None:
-        """initialize with the size of file and rich.progress.Progress() instance
-
-        Parameters
-        ----------
-        size_bytes: float
-            Number of total bytes to download.
-        progress : rich.progress.Progress()
-            Progressbar instance from rich
-        """
-        self.progress = progress
-        self.dl_task = self.progress.add_task("[red]Downloading...", total=size_bytes)
-
-    def report(self, bytes_in_chunk: int) -> None:
-        """Update the progressbar with the most recent chunk.
-
-        Parameters
-        ----------
-        bytes_in_chunk : float
-            Description
-        """
-        self.progress.update(self.dl_task, advance=bytes_in_chunk)
-
-
-class _S3Action(Enum):
-    UPLOADING = "↑"
-    DOWNLOADING = "↓"
-
-
-def _get_progress(action: _S3Action) -> Progress:
-    """Get the progress of an action."""
-
-    col = (
-        TextColumn(f"[bold green]{_S3Action.DOWNLOADING.value}")
-        if action == _S3Action.DOWNLOADING
-        else TextColumn(f"[bold red]{_S3Action.UPLOADING.value}")
-    )
-    return Progress(
-        col,
-        TextColumn("[bold blue]{task.fields[filename]}"),
-        BarColumn(),
-        "[progress.percentage]{task.percentage:>3.1f}%",
-        "•",
-        DownloadColumn(),
-        "•",
-        TransferSpeedColumn(),
-        "•",
-        TimeRemainingColumn(),
-        console=get_logger_console(),
-    )
-
-
-_s3_config = TransferConfig()
-
-_s3_sts_tokens: dict[str, _S3STSToken] = {}
-
-
-def get_s3_sts_token(
-    resource_id: str, file_name: PathLike, extra_arguments: Optional[Mapping[str, str]] = None
-) -> _S3STSToken:
-    """Get s3 sts token for the given resource id and file name.
-
-    Parameters
-    ----------
-    resource_id : str
-        The resource id, e.g. task id.
-    file_name : PathLike
-        The remote file name on S3.
-    extra_arguments : Mapping[str, str]
-        Additional arguments for the query url.
-
-    Returns
-    -------
-    _S3STSToken
-        The S3 STS token.
-    """
-    file_name = str(Path(file_name).as_posix())
-    cache_key = f"{resource_id}:{file_name}"
-    if cache_key not in _s3_sts_tokens or _s3_sts_tokens[cache_key].is_expired():
-        method = f"tidy3d/py/tasks/{resource_id}/file?filename={file_name}"
-        if extra_arguments is not None:
-            method += "&" + "&".join(f"{k}={v}" for k, v in extra_arguments.items())
-        resp = http.get(method)
-        token = _S3STSToken.model_validate(resp)
-        _s3_sts_tokens[cache_key] = token
-    return _s3_sts_tokens[cache_key]
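
The token cache in `get_s3_sts_token` above keeps one credential per `"resource_id:filename"` pair and refreshes only when the entry is missing or expired. The same shape in isolation (`fetch` stands in for the HTTP token request; token objects are assumed to expose `is_expired()`):

from typing import Any, Callable

_tokens: dict[str, Any] = {}

def get_token(resource_id: str, file_name: str, fetch: Callable[[str, str], Any]) -> Any:
    """Memoize per (resource, file) pair; refetch only when stale or absent."""
    key = f"{resource_id}:{file_name}"
    token = _tokens.get(key)
    if token is None or token.is_expired():
        token = fetch(resource_id, file_name)
        _tokens[key] = token
    return token
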
-
-
-def upload_file(
-    resource_id: str,
-    path: PathLike,
-    remote_filename: PathLike,
-    verbose: bool = True,
-    progress_callback: Optional[Callable[[float], None]] = None,
-    extra_arguments: Optional[Mapping[str, str]] = None,
-) -> None:
-    """Upload a file to S3.
-
-    Parameters
-    ----------
-    resource_id : str
-        The resource id, e.g. task id.
-    path : PathLike
-        Path to the file to upload.
-    remote_filename : PathLike
-        The remote file name on S3 relative to the resource context root path.
-    verbose : bool = True
-        Whether to display a progressbar for the upload.
-    progress_callback : Callable[[float], None] = None
-        User-supplied callback function with ``bytes_in_chunk`` as argument.
-    extra_arguments : Mapping[str, str]
-        Additional arguments used to specify the upload bucket.
-    """
-
-    path = Path(path)
-    token = get_s3_sts_token(resource_id, remote_filename, extra_arguments)
-
-    def _upload(_callback: Callable) -> None:
-        """Perform the upload with a callback function.
-
-        Parameters
-        ----------
-        _callback : Callable[[float], None]
-            Callback function for upload, accepts ``bytes_in_chunk``
-        """
-
-        with path.open("rb") as data:
-            token.get_client().upload_fileobj(
-                data,
-                Bucket=token.get_bucket(),
-                Key=token.get_s3_key(),
-                Callback=_callback,
-                Config=_s3_config,
-                ExtraArgs={"ContentEncoding": "gzip"}
-                if token.get_s3_key().endswith(".gz")
-                else None,
-            )
-
-    if progress_callback is not None:
-        _upload(progress_callback)
-    else:
-        if verbose:
-            with _get_progress(_S3Action.UPLOADING) as progress:
-                total_size = path.stat().st_size
-                task_id = progress.add_task(
-                    "upload", filename=str(remote_filename), total=total_size
-                )
-
-                def _callback(bytes_in_chunk: int) -> None:
-                    progress.update(task_id, advance=bytes_in_chunk)
-
-                _upload(_callback)
-
-                progress.update(task_id, completed=total_size, refresh=True)
-
-        else:
-            _upload(lambda bytes_in_chunk: None)
-
-
-def download_file(
-    resource_id: str,
-    remote_filename: PathLike,
-    to_file: Optional[PathLike] = None,
-    verbose: bool = True,
-    progress_callback: Optional[Callable[[float], None]] = None,
-) -> Path:
-    """Download file from S3.
-
-    Parameters
-    ----------
-    resource_id : str
-        The resource id, e.g. task id.
-    remote_filename : PathLike
-        Path to the remote file.
-    to_file : PathLike = None
-        Local filename to save to; if not specified, defaults to ``remote_filename`` in a
-        directory named after ``resource_id``.
-    verbose : bool = True
-        Whether to display a progressbar for the upload
-    progress_callback : Callable[[float], None] = None
-        User-supplied callback function with ``bytes_in_chunk`` as argument.
-    """
-
-    token = get_s3_sts_token(resource_id, remote_filename)
-    client = token.get_client()
-    meta_data = client.head_object(Bucket=token.get_bucket(), Key=token.get_s3_key())
-
-    # Get only last part of the remote file name
-    remote_basename = Path(remote_filename).name
-
-    # set to_file if None
-    if to_file is None:
-        to_path = Path(resource_id) / remote_basename
-    else:
-        to_path = Path(to_file)
-
-    # make the leading directories in the 'to_path', if any
-    to_path.parent.mkdir(parents=True, exist_ok=True)
-
-    def _download(_callback: Callable) -> None:
-        """Perform the download with a callback function.
-
-        Parameters
-        ----------
-        _callback : Callable[[float], None]
-            Callback function for download, accepts ``bytes_in_chunk``
-        """
-        # Caller can assume the existence of the file means download succeeded.
-        # So make sure this file does not exist until that assumption is true.
-        to_path.unlink(missing_ok=True)
-        # Download to a temporary file.
-        try:
-            fd, tmp_file_path_str = tempfile.mkstemp(suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent)
-            os.close(fd)  # `tempfile.mkstemp()` creates and opens a randomly named file. close it.
-            to_path_tmp = Path(tmp_file_path_str)
-            client.download_file(
-                Bucket=token.get_bucket(),
-                Filename=str(to_path_tmp),
-                Key=token.get_s3_key(),
-                Callback=_callback,
-                Config=_s3_config,
-            )
-            to_path_tmp.rename(to_path)
-        except Exception as e:
-            to_path_tmp.unlink(missing_ok=True)  # Delete incompletely downloaded file.
-            raise e
-
-    if progress_callback is not None:
-        _download(progress_callback)
-    else:
-        if verbose:
-            with _get_progress(_S3Action.DOWNLOADING) as progress:
-                total_size = meta_data.get("ContentLength", 0)
-                progress.start()
-                task_id = progress.add_task("download", filename=remote_basename, total=total_size)
-
-                def _callback(bytes_in_chunk: int) -> None:
-                    progress.update(task_id, advance=bytes_in_chunk)
-
-                _download(_callback)
-
-                progress.update(task_id, completed=total_size, refresh=True)
-
-        else:
-            _download(lambda bytes_in_chunk: None)
-
-    return to_path
-
-
-def download_gz_file(
-    resource_id: str,
-    remote_filename: PathLike,
-    to_file: Optional[PathLike] = None,
-    verbose: bool = True,
-    progress_callback: Optional[Callable[[float], None]] = None,
-) -> Path:
-    """Download a ``.gz`` file and unzip it into ``to_file``, unless ``to_file`` itself
-    ends in .gz
-
-    Parameters
-    ----------
-    resource_id : str
-        The resource id, e.g. task id.
-    remote_filename : PathLike
-        Path to the remote file.
-    to_file : Optional[PathLike] = None
-        Local filename to save to; if not specified, defaults to ``remote_filename`` with the
-        ``.gz`` suffix removed in a directory named after ``resource_id``.
-    verbose : bool = True
-        Whether to display a progressbar for the upload
-    progress_callback : Callable[[float], None] = None
-        User-supplied callback function with ``bytes_in_chunk`` as argument.
-    """
-
-    # If to_file is a gzip extension, just download
-    if to_file is None:
-        remote_basename = Path(remote_filename).name
-        if remote_basename.endswith(".gz"):
-            remote_basename = remote_basename[:-3]
-        to_path = Path(resource_id) / remote_basename
-    else:
-        to_path = Path(to_file)
-
-    suffixes = "".join(to_path.suffixes).lower()
-    if suffixes.endswith(".gz"):
-        return download_file(
-            resource_id,
-            remote_filename,
-            to_file=to_path,
-            verbose=verbose,
-            progress_callback=progress_callback,
-        )
-
-    # Otherwise, download and unzip
-    # The tempfile is set as ``hdf5.gz`` so that the mock download in the webapi tests works
-    tmp_file, tmp_file_path_str = tempfile.mkstemp(".hdf5.gz")
-    os.close(tmp_file)
-
-    # make the leading directories in the 'to_file', if any
-    to_path.parent.mkdir(parents=True, exist_ok=True)
-    try:
-        download_file(
-            resource_id,
-            remote_filename,
-            to_file=Path(tmp_file_path_str),
-            verbose=verbose,
-            progress_callback=progress_callback,
-        )
-        if os.path.exists(tmp_file_path_str):
-            extract_gzip_file(Path(tmp_file_path_str), to_path)
-        else:
-            raise WebError(f"Failed to download and extract '{remote_filename}'.")
-    finally:
-        os.unlink(tmp_file_path_str)
-    return to_path
diff --git a/tidy3d/web/core/stub.py b/tidy3d/web/core/stub.py
index c351551f63..91fc96ac90 100644
--- a/tidy3d/web/core/stub.py
+++ b/tidy3d/web/core/stub.py
@@ -1,81 +1,11 @@
-"""Defines interface that can be subclassed to use with the tidy3d webapi"""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.stub`."""
 
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from os import PathLike
-
-
-class TaskStubData(ABC):
-    @abstractmethod
-    def from_file(self, file_path: PathLike) -> TaskStubData:
-        """Loads a :class:`TaskStubData` from .yaml, .json, or .hdf5 file.
-
-        Parameters
-        ----------
-        file_path : PathLike
-            Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from.
-
-        Returns
-        -------
-        :class:`Stub`
-            An instance of the component class calling ``load``.
-
-        """
-
-    @abstractmethod
-    def to_file(self, file_path: PathLike) -> None:
-        """Loads a :class:`Stub` from .yaml, .json, or .hdf5 file.
-
-        Parameters
-        ----------
-        file_path : PathLike
-            Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from.
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-        Returns
-        -------
-        :class:`Stub`
-            An instance of the component class calling ``load``.
-        """
-
-
-class TaskStub(ABC):
-    @abstractmethod
-    def from_file(self, file_path: PathLike) -> TaskStub:
-        """Loads a :class:`TaskStubData` from .yaml, .json, or .hdf5 file.
-
-        Parameters
-        ----------
-        file_path : PathLike
-            Full path to the .yaml or .json or .hdf5 file to load the :class:`Stub` from.
-
-        Returns
-        -------
-        :class:`TaskStubData`
-            An instance of the component class calling ``load``.
-        """
-
-    @abstractmethod
-    def to_file(self, file_path: PathLike) -> None:
-        """Loads a :class:`TaskStub` from .yaml, .json, .hdf5 or .hdf5.gz file.
-
-        Parameters
-        ----------
-        file_path : PathLike
-            Full path to the .yaml or .json or .hdf5 file to load the :class:`TaskStub` from.
-
-        Returns
-        -------
-        :class:`Stub`
-            An instance of the component class calling ``load``.
-        """
-
-    @abstractmethod
-    def to_hdf5_gz(self, fname: PathLike) -> None:
-        """Exports :class:`TaskStub` instance to .hdf5.gz file.
+# marked as migrated to _common
 from __future__ import annotations
 
-        Parameters
-        ----------
-        fname : PathLike
-            Full path to the .hdf5.gz file to save the :class:`TaskStub` to.
-        """
+from tidy3d._common.web.core.stub import (
+    TaskStub,
+    TaskStubData,
+)
diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py
index f1e4ca8c7d..704961bae8 100644
--- a/tidy3d/web/core/task_core.py
+++ b/tidy3d/web/core/task_core.py
@@ -1,986 +1,14 @@
-"""Tidy3d webapi types."""
+"""Compatibility shim for :mod:`tidy3d._common.web.core.task_core`."""
 
-from __future__ import annotations
-
-import os
-import pathlib
-import tempfile
-from datetime import datetime
-from os import PathLike
-from typing import Callable, Optional, Union
-
-import requests
-from botocore.exceptions import ClientError
-from pydantic.v1 import Extra, Field, parse_obj_as
+# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility
 
-import tidy3d as td
-from tidy3d.config import config
-from tidy3d.exceptions import ValidationError
+# marked as migrated to _common
 from __future__ import annotations
 
-from . import http_util
-from .cache import FOLDER_CACHE
-from .constants import (
-    SIM_ERROR_FILE,
-    SIM_FILE_HDF5_GZ,
-    SIM_LOG_FILE,
-    SIM_VALIDATION_FILE,
-    SIMULATION_DATA_HDF5_GZ,
+from tidy3d._common.web.core.task_core import (
+    BatchTask,
+    Folder,
+    SimulationTask,
+    TaskFactory,
+    WebTask,
 )
-from .core_config import get_logger_console
-from .exceptions import WebError, WebNotFoundError
-from .file_util import read_simulation_from_hdf5
-from .http_util import get_version as _get_protocol_version
-from .http_util import http
-from .s3utils import download_file, download_gz_file, upload_file
-from .stub import TaskStub
-from .task_info import BatchDetail, TaskInfo
-from .types import PayType, Queryable, ResourceLifecycle, Submittable, Tidy3DResource
-
-
-class Folder(Tidy3DResource, Queryable, extra=Extra.allow):
-    """Tidy3D Folder."""
-
-    folder_id: str = Field(..., title="Folder id", description="folder id", alias="projectId")
-    folder_name: str = Field(
-        ..., title="Folder name", description="folder name", alias="projectName"
-    )
-
-    @classmethod
-    def list(cls, projects_endpoint: str = "tidy3d/projects") -> []:
-        """List all folders.
-
-        Returns
-        -------
-        folders : [Folder]
-            List of folders
-        """
-        resp = http.get(projects_endpoint)
-        return (
-            parse_obj_as(
-                list[Folder],
-                resp,
-            )
-            if resp
-            else None
-        )
-
-    @classmethod
-    def get(
-        cls,
-        folder_name: str,
-        create: bool = False,
-        projects_endpoint: str = "tidy3d/projects",
-        project_endpoint: str = "tidy3d/project",
-    ) -> Folder:
-        """Get folder by name.
-
-        Parameters
-        ----------
-        folder_name : str
-            Name of the folder.
-        create : str
-            If the folder doesn't exist, create it.
-
-        Returns
-        -------
-        folder : Folder
-        """
-        folder = FOLDER_CACHE.get(folder_name)
-        if not folder:
-            resp = http.get(project_endpoint, params={"projectName": folder_name})
-            if resp:
-                folder = Folder(**resp)
-        if create and not folder:
-            resp = http.post(projects_endpoint, {"projectName": folder_name})
-            if resp:
-                folder = Folder(**resp)
-        FOLDER_CACHE[folder_name] = folder
-        return folder
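
`Folder.get` above is a get-or-create guarded by the module-level `FOLDER_CACHE`. The same logic in isolation (the HTTP helpers are passed in as placeholders for the tidy3d client calls; endpoints follow the removed code):

from typing import Any, Callable, Optional

_FOLDER_CACHE: dict[str, Any] = {}

def get_folder(
    name: str,
    create: bool,
    http_get: Callable[..., Optional[dict]],
    http_post: Callable[..., Optional[dict]],
) -> Optional[dict]:
    """Look up a folder by name, optionally creating it, and memoize the result."""
    folder = _FOLDER_CACHE.get(name)
    if not folder:
        folder = http_get("tidy3d/project", params={"projectName": name})
        if create and not folder:
            folder = http_post("tidy3d/projects", {"projectName": name})
        _FOLDER_CACHE[name] = folder
    return folder
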
-
-    @classmethod
-    def create(cls, folder_name: str) -> Folder:
-        """Create a folder, return existing folder if there is one has the same name.
-
-        Parameters
-        ----------
-        folder_name : str
-            Name of the folder.
-
-        Returns
-        -------
-        folder : Folder
-        """
-        return Folder.get(folder_name, True)
-
-    def delete(self, projects_endpoint: str = "tidy3d/projects") -> None:
-        """Remove this folder."""
-
-        http.delete(f"{projects_endpoint}/{self.folder_id}")
-
-    def delete_old(self, days_old: int) -> int:
-        """Remove folder contents older than ``days_old``."""
-
-        return http.delete(
-            f"tidy3d/tasks/{self.folder_id}/tasks",
-            params={"daysOld": days_old},
-        )
-
-    def list_tasks(self, projects_endpoint: str = "tidy3d/projects") -> list[Tidy3DResource]:
-        """List all tasks in this folder.
-
-        Returns
-        -------
-        tasks : List[:class:`.SimulationTask`]
-            List of tasks in this folder
-        """
-        resp = http.get(f"{projects_endpoint}/{self.folder_id}/tasks")
-        return (
-            parse_obj_as(
-                list[SimulationTask],
-                resp,
-            )
-            if resp
-            else None
-        )
-
-
-class WebTask(ResourceLifecycle, Submittable, extra=Extra.allow):
-    """Interface for managing the running a task on the server."""
-
-    task_id: Optional[str] = Field(
-        ...,
-        title="task_id",
-        description="Task ID number, set when the task is uploaded, leave as None.",
-        alias="taskId",
-    )
-
-    @classmethod
-    def create(
-        cls,
-        task_type: str,
-        task_name: str,
-        folder_name: str = "default",
-        callback_url: Optional[str] = None,
-        simulation_type: str = "tidy3d",
-        parent_tasks: Optional[list[str]] = None,
-        file_type: str = "Gz",
-        projects_endpoint: str = "tidy3d/projects",
-    ) -> SimulationTask:
-        """Create a new task on the server.
-
-        Parameters
-        ----------
-        task_type: :class".TaskType"
-            The type of task.
-        task_name: str
-            The name of the task.
-        folder_name: str,
-            The name of the folder to store the task. Default is "default".
-        callback_url: str
-            Http PUT url to receive simulation finish event. The body content is a json file with
-            fields ``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.
-        simulation_type : str
-            Type of simulation being uploaded.
-        parent_tasks : List[str]
-            List of related task ids.
-        file_type: str
-            the simulation file type Json, Hdf5, Gz
-
-        Returns
-        -------
-        :class:`SimulationTask`
-            :class:`SimulationTask` object containing info about status, size,
-            credits of task and others.
-        """
-
-        # handle backwards compatibility, "tidy3d" is the default simulation_type
-        if simulation_type is None:
-            simulation_type = "tidy3d"
-
-        folder = Folder.get(folder_name, create=True)
-
-        if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]:
-            payload = {
-                "groupName": task_name,
-                "folderId": folder.folder_id,
-                "fileType": file_type,
-                "taskType": task_type,
-            }
-            resp = http.post("rf/task", payload)
-        else:
-            payload = {
-                "taskName": task_name,
-                "taskType": task_type,
-                "callbackUrl": callback_url,
-                "simulationType": simulation_type,
-                "parentTasks": parent_tasks,
-                "fileType": file_type,
-            }
-            resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload)
-
-        return SimulationTask(**resp, taskType=task_type, folder_name=folder_name)
-
-    def get_url(self) -> str:
-        base = str(config.web.website_endpoint or "")
-        if isinstance(self, BatchTask):
-            return "/".join([base.rstrip("/"), f"rf?taskId={self.task_id}"])
-        return "/".join([base.rstrip("/"), f"workbench?taskId={self.task_id}"])
-
-    def get_folder_url(self) -> Optional[str]:
-        folder_id = getattr(self, "folder_id", None)
-        if not folder_id:
-            return None
-        base = str(config.web.website_endpoint or "")
-        return "/".join([base.rstrip("/"), f"folders/{folder_id}"])
-
-    def get_log(
-        self,
-        to_file: PathLike,
-        verbose: bool = True,
-        progress_callback: Optional[Callable[[float], None]] = None,
-    ) -> pathlib.Path:
-        """Get log file from Server.
-
-        Parameters
-        ----------
-        to_file: PathLike
-            Save file to path.
-        verbose: bool = True
-            Whether to display progress bars.
-        progress_callback : Callable[[float], None] = None
-            Optional callback function called while downloading the data.
-
-        Returns
-        -------
-        path: pathlib.Path
-            Path to saved file.
- """ - - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - - return download_file( - self.task_id, - SIM_LOG_FILE, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - - def get_data_hdf5( - self, - to_file: PathLike, - remote_data_file_gz: PathLike = SIMULATION_DATA_HDF5_GZ, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> pathlib.Path: - """Download data artifact (simulation or batch) with gz fallback handling. - - Parameters - ---------- - remote_data_file_gz : PathLike - Gzipped remote filename. - to_file : PathLike - Local target path. - verbose : bool - Whether to log progress. - progress_callback : Optional[Callable[[float], None]] - Progress callback. - - Returns - ------- - pathlib.Path - Saved local path. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - target_path = pathlib.Path(to_file) - file = None - try: - file = download_gz_file( - resource_id=self.task_id, - remote_filename=remote_data_file_gz, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except ClientError: - if verbose: - console = get_logger_console() - console.log(f"Unable to download '{remote_data_file_gz}'.") - if not file: - try: - file = download_file( - resource_id=self.task_id, - remote_filename=str(remote_data_file_gz)[:-3], - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except Exception as e: - raise WebError( - "Failed to download the data file from the server. " - "Please confirm that the task completed successfully." - ) from e - return file - - @staticmethod - def is_batch(resource_id: str) -> bool: - """Checks if a given resource ID corresponds to a valid batch task. - - This is a utility function to verify a batch task's existence before - instantiating the class. - - Parameters - ---------- - resource_id : str - The unique identifier for the resource. - - Returns - ------- - bool - ``True`` if the resource is a valid batch task, ``False`` otherwise. - """ - try: - # TODO PROPERLY FIXME - # Disable non critical logs due to check for resourceId, until we have a dedicated API for this - resp = http.get( - f"rf/task/{resource_id}/statistics", - suppress_404=True, - ) - status = bool(resp and isinstance(resp, dict) and "status" in resp) - return status - except Exception: - return False - - def delete(self, versions: bool = False) -> None: - """Delete current task from server. - - Parameters - ---------- - versions : bool = False - If ``True``, delete all versions of the task in the task group. Otherwise, delete only - the version associated with the current task ID. 
- """ - if not self.task_id: - raise ValueError("Task id not found.") - - task_details = self.detail().dict() - - if task_details and "groupId" in task_details: - group_id = task_details["groupId"] - if versions: - http.delete("tidy3d/group", json={"groupIds": [group_id]}) - return - elif "version" in task_details: - version = task_details["version"] - http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) - return - - # Fallback to old method if we can't get the groupId and version - http.delete(f"tidy3d/tasks/{self.task_id}") - - -class SimulationTask(WebTask): - """Interface for managing the running of solver tasks on the server.""" - - folder_id: Optional[str] = Field( - None, - title="folder_id", - description="Folder ID number, set when the task is uploaded, leave as None.", - alias="folderId", - ) - status: Optional[str] = Field(title="status", description="Simulation task status.") - - real_flex_unit: float = Field( - None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" - ) - - created_at: Optional[datetime] = Field( - title="created_at", description="Time at which this task was created.", alias="createdAt" - ) - - task_type: Optional[str] = Field( - title="task_type", description="The type of task.", alias="taskType" - ) - - folder_name: Optional[str] = Field( - "default", - title="Folder Name", - description="Name of the folder associated with this task.", - alias="folderName", - ) - - callback_url: str = Field( - None, - title="Callback URL", - description="Http PUT url to receive simulation finish event. " - "The body content is a json file with fields " - "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", - ) - - # simulation_type: str = pd.Field( - # None, - # title="Simulation Type", - # description="Type of simulation, used internally only.", - # ) - - # parent_tasks: Tuple[TaskId, ...] = pd.Field( - # None, - # title="Parent Tasks", - # description="List of parent task ids for the simulation, used internally only." - # ) - - @classmethod - def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: - """Get task from the server by id. - - Parameters - ---------- - task_id: str - Unique identifier of task on server. - verbose: - If `True`, will print progressbars and status, otherwise, will run silently. - - Returns - ------- - :class:`.SimulationTask` - :class:`.SimulationTask` object containing info about status, - size, credits of task and others. - """ - try: - resp = http.get(f"tidy3d/tasks/{task_id}/detail") - except WebNotFoundError as e: - td.log.error(f"The requested task ID '{task_id}' does not exist.") - raise e - - task = SimulationTask(**resp) if resp else None - return task - - @classmethod - def get_running_tasks(cls) -> list[SimulationTask]: - """Get a list of running tasks from the server" - - Returns - ------- - List[:class:`.SimulationTask`] - :class:`.SimulationTask` object containing info about status, - size, credits of task and others. - """ - resp = http.get("tidy3d/py/tasks") - if not resp: - return [] - return parse_obj_as(list[SimulationTask], resp) - - def detail(self) -> TaskInfo: - """Fetches the detailed information and status of the task. - - Returns - ------- - TaskInfo - An object containing the task's latest data. 
- """ - resp = http.get(f"tidy3d/tasks/{self.task_id}/detail") - return TaskInfo(**{"taskId": self.task_id, "taskType": self.task_type, **resp}) - - def get_simulation_json(self, to_file: PathLike, verbose: bool = True) -> None: - """Get json file for a :class:`.Simulation` from server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - to_file = pathlib.Path(to_file) - - hdf5_file, hdf5_file_path = tempfile.mkstemp(".hdf5") - os.close(hdf5_file) - try: - self.get_simulation_hdf5(hdf5_file_path) - if os.path.exists(hdf5_file_path): - json_string = read_simulation_from_hdf5(hdf5_file_path) - to_file.parent.mkdir(parents=True, exist_ok=True) - with to_file.open("w", encoding="utf-8") as file: - # Write the string to the file - file.write(json_string.decode("utf-8")) - if verbose: - console = get_logger_console() - console.log(f"Generate {to_file} successfully.") - else: - raise WebError("Failed to download simulation.json.") - finally: - os.unlink(hdf5_file_path) - - def upload_simulation( - self, - stub: TaskStub, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - remote_sim_file: PathLike = SIM_FILE_HDF5_GZ, - ) -> None: - """Upload :class:`.Simulation` object to Server. - - Parameters - ---------- - stub: :class:`TaskStub` - An instance of TaskStub. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while uploading the data. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - if not stub: - raise WebError("Expected field 'simulation' is unset.") - # Also upload hdf5.gz containing all data. - file, file_name = tempfile.mkstemp() - os.close(file) - try: - # upload simulation - # compress .hdf5 to .hdf5.gz - stub.to_hdf5_gz(file_name) - upload_file( - self.task_id, - file_name, - remote_sim_file, - verbose=verbose, - progress_callback=progress_callback, - ) - finally: - os.unlink(file_name) - - def upload_file( - self, - local_file: PathLike, - remote_filename: str, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> None: - """ - Upload file to platform. Using this method when the json file is too large to parse - as :class".simulation". - Parameters - ---------- - local_file: PathLike - Local file path. - remote_filename: str - file name on the server - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while uploading the data. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - upload_file( - self.task_id, - local_file, - remote_filename, - verbose=verbose, - progress_callback=progress_callback, - ) - - def submit( - self, - solver_version: Optional[str] = None, - worker_group: Optional[str] = None, - pay_type: Union[PayType, str] = PayType.AUTO, - priority: Optional[int] = None, - ) -> None: - """Kick off this task. - - It will be uploaded to server before - starting the task. Otherwise, this method assumes that the Simulation has been uploaded by - the upload_file function, so the task will be kicked off directly. - - Parameters - ---------- - solver_version: str = None - target solver version. 
-
-    def submit(
-        self,
-        solver_version: Optional[str] = None,
-        worker_group: Optional[str] = None,
-        pay_type: Union[PayType, str] = PayType.AUTO,
-        priority: Optional[int] = None,
-    ) -> None:
-        """Kick off this task.
-
-        If the simulation has not yet been uploaded, it is uploaded to the server before the
-        task starts. Otherwise, this method assumes the simulation was already uploaded via
-        ``upload_file``, and the task is kicked off directly.
-
-        Parameters
-        ----------
-        solver_version: str = None
-            Target solver version.
-        worker_group: str = None
-            Worker group to run the task on.
-        pay_type: Union[PayType, str] = PayType.AUTO
-            Payment method for the simulation.
-        priority: int = None
-            Priority of the simulation in the Virtual GPU (vGPU) queue (1 = lowest, 10 = highest).
-            It affects only simulations from vGPU licenses and does not impact simulations using
-            FlexCredits.
-        """
-        pay_type = PayType(pay_type) if not isinstance(pay_type, PayType) else pay_type
-
-        if solver_version:
-            protocol_version = None
-        else:
-            protocol_version = http_util.get_version()
-
-        http.post(
-            f"tidy3d/tasks/{self.task_id}/submit",
-            {
-                "solverVersion": solver_version,
-                "workerGroup": worker_group,
-                "protocolVersion": protocol_version,
-                "enableCaching": config.web.enable_caching,
-                "payType": pay_type.value,
-                "priority": priority,
-            },
-        )
-
-    def estimate_cost(self, solver_version: Optional[str] = None) -> float:
-        """Compute the maximum flex unit charge for a given task, assuming the simulation runs for
-        the full ``run_time``. If early shut-off is triggered, the cost is adjusted proportionately.
-
-        Parameters
-        ----------
-        solver_version: str
-            Target solver version.
-
-        Returns
-        -------
-        flex_unit_cost: float
-            Estimated cost in FlexCredits.
-        """
-        if not self.task_id:
-            raise WebError("Expected field 'task_id' is unset.")
-
-        if solver_version:
-            protocol_version = None
-        else:
-            protocol_version = http_util.get_version()
-
-        resp = http.post(
-            f"tidy3d/tasks/{self.task_id}/metadata",
-            {
-                "solverVersion": solver_version,
-                "protocolVersion": protocol_version,
-            },
-        )
-        return resp
-
-    def get_simulation_hdf5(
-        self,
-        to_file: PathLike,
-        verbose: bool = True,
-        progress_callback: Optional[Callable[[float], None]] = None,
-        remote_sim_file: PathLike = SIM_FILE_HDF5_GZ,
-    ) -> pathlib.Path:
-        """Get the simulation.hdf5 file from the server.
-
-        Parameters
-        ----------
-        to_file: PathLike
-            Save file to path.
-        verbose: bool = True
-            Whether to display progress bars.
-        progress_callback : Callable[[float], None] = None
-            Optional callback function called while downloading the data.
-        remote_sim_file : PathLike = SIM_FILE_HDF5_GZ
-            Remote filename of the gzipped simulation to download.
-
-        Returns
-        -------
-        path: pathlib.Path
-            Path to saved file.
-        """
-        if not self.task_id:
-            raise WebError("Expected field 'task_id' is unset.")
-
-        target_path = pathlib.Path(to_file)
-
-        return download_gz_file(
-            resource_id=self.task_id,
-            remote_filename=remote_sim_file,
-            to_file=target_path,
-            verbose=verbose,
-            progress_callback=progress_callback,
-        )
-
-    def get_running_info(self) -> tuple[float, float]:
-        """Get the % done and field_decay for a running task.
-
-        Returns
-        -------
-        perc_done : float
-            Percentage of run done (in terms of max number of time steps).
-            Is ``None`` if run info not available.
-        field_decay : float
-            Average field intensity normalized to max value (1.0).
-            Is ``None`` if run info not available.
-        """
-        if not self.task_id:
-            raise WebError("Expected field 'task_id' is unset.")
-
-        resp = http.get(f"tidy3d/tasks/{self.task_id}/progress")
-        perc_done = resp.get("perc_done")
-        field_decay = resp.get("field_decay")
-        return perc_done, field_decay
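`submit` and `estimate_cost` share one versioning rule: pinning an explicit `solver_version` suppresses the protocol version; otherwise the client reports its own. A tiny sketch of that rule (version strings below are placeholders):

```python
from typing import Optional


def resolve_versions(solver_version: Optional[str], client_protocol: str) -> dict:
    # An explicit solver version takes precedence; the protocol version is then omitted.
    protocol_version = None if solver_version else client_protocol
    return {"solverVersion": solver_version, "protocolVersion": protocol_version}


assert resolve_versions(None, "2.7.0")["protocolVersion"] == "2.7.0"
assert resolve_versions("24.4.0", "2.7.0")["protocolVersion"] is None
```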
-
-    def get_log(
-        self,
-        to_file: PathLike,
-        verbose: bool = True,
-        progress_callback: Optional[Callable[[float], None]] = None,
-    ) -> pathlib.Path:
-        """Get the log file from the server.
-
-        Parameters
-        ----------
-        to_file: PathLike
-            Save file to path.
-        verbose: bool = True
-            Whether to display progress bars.
-        progress_callback : Callable[[float], None] = None
-            Optional callback function called while downloading the data.
-
-        Returns
-        -------
-        path: pathlib.Path
-            Path to saved file.
-        """
-
-        if not self.task_id:
-            raise WebError("Expected field 'task_id' is unset.")
-
-        target_path = pathlib.Path(to_file)
-
-        return download_file(
-            self.task_id,
-            SIM_LOG_FILE,
-            to_file=target_path,
-            verbose=verbose,
-            progress_callback=progress_callback,
-        )
-
-    def get_error_json(
-        self, to_file: PathLike, verbose: bool = True, validation: bool = False
-    ) -> pathlib.Path:
-        """Get the error json file for a :class:`.Simulation` from the server.
-
-        Parameters
-        ----------
-        to_file: PathLike
-            Save file to path.
-        verbose: bool = True
-            Whether to display progress bars.
-        validation: bool = False
-            Whether to get a validation error file or a solver error file.
-
-        Returns
-        -------
-        path: pathlib.Path
-            Path to saved file.
-        """
-        if not self.task_id:
-            raise WebError("Expected field 'task_id' is unset.")
-
-        target_path = pathlib.Path(to_file)
-        target_file = SIM_ERROR_FILE if not validation else SIM_VALIDATION_FILE
-
-        return download_file(
-            self.task_id,
-            target_file,
-            to_file=target_path,
-            verbose=verbose,
-        )
-
-    def abort(self) -> requests.Response:
-        """Abort the current task on the server."""
-        if not self.task_id:
-            raise ValueError("Task id not found.")
-        return http.put(
-            "tidy3d/tasks/abort", json={"taskType": self.task_type, "taskId": self.task_id}
-        )
-
-    def validate_post_upload(self, parent_tasks: Optional[list[str]] = None) -> None:
-        """Perform checks after the task is uploaded and metadata is processed."""
-        if self.task_type == "HEAT_CHARGE" and parent_tasks:
-            try:
-                if len(parent_tasks) > 1:
-                    raise ValueError(
-                        "A single parent 'task_id' corresponding to the task in which the meshing "
-                        "was run must be provided."
-                    )
-                try:
-                    # get mesh task info
-                    mesh_task = SimulationTask.get(parent_tasks[0], verbose=False)
-                    assert mesh_task.task_type == "VOLUME_MESH"
-                    assert mesh_task.status == "success"
-                    # get up-to-date task info
-                    task = SimulationTask.get(self.task_id, verbose=False)
-                    if task.fileMd5 != mesh_task.childFileMd5:
-                        raise ValidationError(
-                            "Simulation stored in parent task 'VolumeMesher' does not match the "
-                            "current simulation."
-                        )
-                except Exception as e:
-                    raise ValidationError(
-                        "The parent task must be a 'VolumeMesher' task which has been successfully "
-                        "run and is associated with the same 'HeatChargeSimulation' as provided here."
-                    ) from e
-
-            except Exception as e:
-                raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e
-
-
-class BatchTask(WebTask):
-    """Interface for managing a batch task on the server."""
-
-    @classmethod
-    def get(cls, task_id: str, verbose: bool = True) -> BatchTask:
-        """Get a batch task by id.
-
-        Parameters
-        ----------
-        task_id: str
-            Unique identifier of the batch on the server.
-        verbose: bool = True
-            If `True`, will print progress bars and status; otherwise, will run silently.
-
-        Returns
-        -------
-        :class:`.BatchTask` | None
-            BatchTask object if found, otherwise None.
-        """
-        try:
-            resp = http.get(f"rf/task/{task_id}/statistics")
-        except WebNotFoundError as e:
-            td.log.error(f"The requested batch ID '{task_id}' does not exist.")
-            raise e
-        # We only need to validate existence; store the id on the instance.
-        return BatchTask(taskId=task_id) if resp else None
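The nested try blocks in `validate_post_upload` above encode three checks. Distilled into a self-contained sketch (the md5 arguments are illustrative stand-ins for the server-side metadata fields):

```python
from typing import Optional


def check_mesh_parent(
    parent_ids: list[str],
    parent_task_type: str,
    parent_status: str,
    sim_md5: Optional[str],
    mesh_child_md5: Optional[str],
) -> None:
    # 1) Exactly one parent id may be given.
    if len(parent_ids) != 1:
        raise ValueError("A single parent 'task_id' must be provided.")
    # 2) The parent must be a successfully finished mesher run.
    if parent_task_type != "VOLUME_MESH" or parent_status != "success":
        raise ValueError("Parent must be a successfully run 'VOLUME_MESH' task.")
    # 3) The mesh must have been built from the same simulation.
    if sim_md5 != mesh_child_md5:
        raise ValueError("Parent mesh does not match the current simulation.")


check_mesh_parent(["mesh-task-id"], "VOLUME_MESH", "success", "abc123", "abc123")  # passes silently
```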
- """ - resp = http.get( - f"rf/task/{self.task_id}/statistics", - ) - # Some backends may return null for collection fields; coerce to sensible defaults - if isinstance(resp, dict): - if resp.get("tasks") is None: - resp["tasks"] = [] - return BatchDetail(**(resp or {})) - - def check( - self, - check_task_type: str, - solver_version: Optional[str] = None, - protocol_version: Optional[str] = None, - ) -> requests.Response: - """Submits a request to validate the batch configuration on the server. - - Parameters - ---------- - solver_version : Optional[str], default=None - The version of the solver to use for validation. - protocol_version : Optional[str], default=None - The data protocol version. Defaults to the current version. - - Returns - ------- - Any - The server's response to the check request. - """ - if protocol_version is None: - protocol_version = _get_protocol_version() - return http.post( - f"rf/task/{self.task_id}/check", - { - "solverVersion": solver_version, - "protocolVersion": protocol_version, - "taskType": check_task_type, - }, - ) - - def submit( - self, - solver_version: Optional[str] = None, - protocol_version: Optional[str] = None, - worker_group: Optional[str] = None, - pay_type: Union[PayType, str] = PayType.AUTO, - priority: Optional[int] = None, - ) -> requests.Response: - """Submits the batch for execution on the server. - - Parameters - ---------- - solver_version : Optional[str], default=None - The version of the solver to use for execution. - protocol_version : Optional[str], default=None - The data protocol version. Defaults to the current version. - worker_group : Optional[str], default=None - Optional identifier for a specific worker group to run on. - - Returns - ------- - Any - The server's response to the submit request. - """ - - # TODO: add support for pay_type and priority arguments - if pay_type != PayType.AUTO: - raise NotImplementedError( - "The 'pay_type' argument is not yet supported and will be ignored." - ) - if priority is not None: - raise NotImplementedError( - "The 'priority' argument is not yet supported and will be ignored." 
- ) - - if protocol_version is None: - protocol_version = _get_protocol_version() - return http.post( - f"rf/task/{self.task_id}/submit", - { - "solverVersion": solver_version, - "protocolVersion": protocol_version, - "workerGroup": worker_group, - }, - ) - - def abort(self) -> requests.Response: - """Abort the current task on the server.""" - if not self.task_id: - raise ValueError("Batch id not found.") - return http.put(f"rf/task/{self.task_id}/abort", {}) - - -class TaskFactory: - """Factory for obtaining the correct task subclass.""" - - _REGISTRY: dict[str, str] = {} - - @classmethod - def reset(cls) -> None: - """Clear the cached task kind registry (used in tests).""" - cls._REGISTRY.clear() - - @classmethod - def register(cls, task_id: str, kind: str) -> None: - cls._REGISTRY[task_id] = kind - - @classmethod - def get(cls, task_id: str, verbose: bool = True) -> WebTask: - kind = cls._REGISTRY.get(task_id) - if kind == "batch": - return BatchTask.get(task_id, verbose=verbose) - if kind == "simulation": - task = SimulationTask.get(task_id, verbose=verbose) - return task - if WebTask.is_batch(task_id): - cls.register(task_id, "batch") - return BatchTask.get(task_id, verbose=verbose) - task = SimulationTask.get(task_id, verbose=verbose) - if task: - cls.register(task_id, "simulation") - return task diff --git a/tidy3d/web/core/task_info.py b/tidy3d/web/core/task_info.py index 825ab9b6a5..f902fc1bcd 100644 --- a/tidy3d/web/core/task_info.py +++ b/tidy3d/web/core/task_info.py @@ -1,299 +1,18 @@ -"""Defines information about a task""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.task_info`.""" -from __future__ import annotations - -from abc import ABC -from datetime import datetime -from enum import Enum -from typing import Optional - -import pydantic.v1 as pydantic - - -class TaskBase(pydantic.BaseModel, ABC): - """Base configuration for all task objects.""" - - class Config: - """Configuration for TaskBase""" - - arbitrary_types_allowed = True - """Allow arbitrary types to be used within the model.""" - - -class ChargeType(str, Enum): - """The payment method of the task.""" - - FREE = "free" - """No payment required.""" - - PAID = "paid" - """Payment required.""" - - -class TaskBlockInfo(TaskBase): - """Information about the task's block status. - - This includes details about how the task can be blocked by various features - such as user limits and insufficient balance. 
- """ - - chargeType: ChargeType = None - """The type of charge applicable to the task (free or paid).""" - - maxFreeCount: int = None - """The maximum number of free tasks allowed.""" - - maxGridPoints: int = None - """The maximum number of grid points permitted.""" - - maxTimeSteps: int = None - """The maximum number of time steps allowed.""" - - -class TaskInfo(TaskBase): - """General information about a task.""" - - taskId: str - """Unique identifier for the task.""" - - taskName: str = None - """Name of the task.""" - - nodeSize: int = None - """Size of the node allocated for the task.""" - - completedAt: Optional[datetime] = None - """Timestamp when the task was completed.""" - - status: str = None - """Current status of the task.""" - - realCost: float = None - """Actual cost incurred by the task.""" - - timeSteps: int = None - """Number of time steps involved in the task.""" - - solverVersion: str = None - """Version of the solver used for the task.""" - - createAt: Optional[datetime] = None - """Timestamp when the task was created.""" - - estCostMin: float = None - """Estimated minimum cost for the task.""" - - estCostMax: float = None - """Estimated maximum cost for the task.""" - - realFlexUnit: float = None - """Actual flexible units used by the task.""" - - oriRealFlexUnit: float = None - """Original real flexible units.""" - - estFlexUnit: float = None - """Estimated flexible units for the task.""" - - estFlexCreditTimeStepping: float = None - """Estimated flexible credits for time stepping.""" - - estFlexCreditPostProcess: float = None - """Estimated flexible credits for post-processing.""" - - estFlexCreditMode: float = None - """Estimated flexible credits based on the mode.""" - - s3Storage: float = None - """Amount of S3 storage used by the task.""" - - startSolverTime: Optional[datetime] = None - """Timestamp when the solver started.""" - - finishSolverTime: Optional[datetime] = None - """Timestamp when the solver finished.""" - - totalSolverTime: int = None - """Total time taken by the solver.""" - - callbackUrl: str = None - """Callback URL for task notifications.""" - - taskType: str = None - """Type of the task.""" - - metadataStatus: str = None - """Status of the metadata for the task.""" - - taskBlockInfo: TaskBlockInfo = None - """Blocking information for the task.""" - - version: str = None - """Version of the task.""" - - -class RunInfo(TaskBase): - """Information about the run of a task.""" - - perc_done: pydantic.confloat(ge=0.0, le=100.0) - """Percentage of the task that is completed (0 to 100).""" - - field_decay: pydantic.confloat(ge=0.0, le=1.0) - """Field decay from the maximum value (0 to 1).""" - - def display(self) -> None: - """Print some info about the task's progress.""" - print(f" - {self.perc_done:.2f} (%) done") - print(f" - {self.field_decay:.2e} field decay from max") - - -# ---------------------- Batch (Modeler) detail schema ---------------------- # - - -class BatchTaskBlockInfo(TaskBlockInfo): - """ - Extends `TaskBlockInfo` with specific details for batch task blocking. - - Attributes: - accountLimit: A usage or cost limit imposed by the user's account. - taskBlockMsg: A human-readable message describing the reason for the block. - taskBlockType: The specific type of block (e.g., 'balance', 'limit'). - blockStatus: The current blocking status for the batch. - taskStatus: The status of the task when it was blocked. 
- """ - - accountLimit: float = None - taskBlockMsg: str = None - taskBlockType: str = None - blockStatus: str = None - taskStatus: str = None - - -class BatchMember(TaskBase): - """ - Represents a single task within a larger batch operation. - - Attributes: - refId: A reference identifier for the member task. - folderId: The identifier of the folder containing the task. - sweepId: The identifier for the parameter sweep, if applicable. - taskId: The unique identifier of the task. - linkedTaskId: The identifier of a task linked to this one. - groupId: The identifier of the group this task belongs to. - taskName: The name of the individual task. - status: The current status of this specific task. - sweepData: Data associated with a parameter sweep. - validateInfo: Information related to the task's validation. - replaceData: Data used for replacements or modifications. - protocolVersion: The version of the protocol used. - variable: The variable parameter for this task in a sweep. - createdAt: The timestamp when the member task was created. - updatedAt: The timestamp when the member task was last updated. - denormalizeStatus: The status of the data denormalization process. - summary: A dictionary containing summary information for the task. - """ - - refId: str = None - folderId: str = None - sweepId: str = None - taskId: str = None - linkedTaskId: str = None - groupId: str = None - taskName: str = None - status: str = None - sweepData: str = None - validateInfo: str = None - replaceData: str = None - protocolVersion: str = None - variable: str = None - createdAt: Optional[datetime] = None - updatedAt: Optional[datetime] = None - denormalizeStatus: str = None - summary: dict = None - - -class BatchDetail(TaskBase): - """ - Provides a detailed, top-level view of a batch of tasks. - - This model serves as the main payload for retrieving comprehensive - information about a batch operation. - - Attributes: - refId: A reference identifier for the entire batch. - optimizationId: Identifier for the optimization process, if any. - groupId: Identifier for the group the batch belongs to. - name: The user-defined name of the batch. - status: The current status of the batch. - totalTask: The total number of tasks in the batch. - preprocessSuccess: The count of tasks that completed preprocessing. - postprocessStatus: The status of the batch's postprocessing stage. - validateSuccess: The count of tasks that passed validation. - runSuccess: The count of tasks that ran successfully. - postprocessSuccess: The count of tasks that completed postprocessing. - taskBlockInfo: Information on what might be blocking the batch. - estFlexUnit: The estimated total flexible compute units for the batch. - totalSeconds: The total time in seconds the batch has taken. - totalCheckMillis: Total time in milliseconds spent on checks. - message: A general message providing information about the batch status. - tasks: A list of `BatchMember` objects, one for each task in the batch. - taskType: The type of tasks contained in the batch. 
- """ - - refId: str = None - optimizationId: str = None - groupId: str = None - name: str = None - status: str = None - totalTask: int = 0 - preprocessSuccess: int = 0 - postprocessStatus: str = None - validateSuccess: int = 0 - runSuccess: int = 0 - postprocessSuccess: int = 0 - taskBlockInfo: BatchTaskBlockInfo = None - estFlexUnit: float = None - realFlexUnit: float = None - totalSeconds: int = None - totalCheckMillis: int = None - message: str = None - tasks: list[BatchMember] = [] - validateErrors: dict = None - taskType: str = "RF" - version: str = None - - -class AsyncJobDetail(TaskBase): - """ - Provides a detailed view of an asynchronous job and its sub-tasks. - - This model represents a long-running operation. The 'result' attribute holds - the output of a completed job, which for orchestration jobs, is often a - JSON string mapping sub-task names to their unique IDs. - - Attributes: - asyncId: The unique identifier for the asynchronous job. - status: The current overall status of the job (e.g., 'RUNNING', 'COMPLETED'). - progress: The completion percentage of the job (from 0.0 to 100.0). - createdAt: The timestamp when the job was created. - completedAt: The timestamp when the job finished (successfully or not). - tasks: A dictionary mapping logical task keys to their unique task IDs. - This is often populated by parsing the 'result' of an orchestration task. - result: The raw string output of the completed job. If the job spawns other - tasks, this is expected to be a JSON string detailing those tasks. - taskBlockInfo: Information on any dependencies blocking the job from running. - message: A human-readable message about the job's status. - """ - - asyncId: str - status: str - progress: Optional[float] = None - createdAt: Optional[datetime] = None - completedAt: Optional[datetime] = None - tasks: Optional[dict[str, str]] = None - result: Optional[str] = None - taskBlockInfo: Optional[TaskBlockInfo] = None - message: Optional[str] = None +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as migrated to _common +from __future__ import annotations -AsyncJobDetail.update_forward_refs() +from tidy3d._common.web.core.task_info import ( + AsyncJobDetail, + BatchDetail, + BatchMember, + BatchTaskBlockInfo, + ChargeType, + RunInfo, + TaskBase, + TaskBlockInfo, + TaskInfo, +) diff --git a/tidy3d/web/core/types.py b/tidy3d/web/core/types.py index 1c1439366c..51437dbbdf 100644 --- a/tidy3d/web/core/types.py +++ b/tidy3d/web/core/types.py @@ -1,73 +1,15 @@ -"""Tidy3d abstraction types for the core.""" +"""Compatibility shim for :mod:`tidy3d._common.web.core.types`.""" -from __future__ import annotations - -from abc import ABC, abstractmethod -from enum import Enum -from typing import Any - -from pydantic.v1 import BaseModel - - -class Tidy3DResource(BaseModel, ABC): - """Abstract base class / template for a webservice that implements resource query.""" - - @classmethod - @abstractmethod - def get(cls, *args: Any, **kwargs: Any) -> Tidy3DResource: - """Get a resource from the server.""" - - -class ResourceLifecycle(Tidy3DResource, ABC): - """Abstract base class for a webservice that implements resource life cycle management.""" - - @classmethod - @abstractmethod - def create(cls, *args: Any, **kwargs: Any) -> Tidy3DResource: - """Create a new resource and return it.""" - - @abstractmethod - def delete(self, *args: Any, **kwargs: Any) -> None: - """Delete the resource.""" +# ruff: noqa: F401 - ignore unused imports, imports ensure compatibility +# marked as 
migrated to _common +from __future__ import annotations -class Submittable(BaseModel, ABC): - """Abstract base class / template for a webservice that implements a submit method.""" - - @abstractmethod - def submit(self, *args: Any, **kwargs: Any) -> None: - """Submit the task to the webservice.""" - - -class Queryable(BaseModel, ABC): - """Abstract base class / template for a webservice that implements a query method.""" - - @classmethod - @abstractmethod - def list(cls, *args: Any, **kwargs: Any) -> [Queryable]: - """List all resources of this type.""" - - -class TaskType(str, Enum): - FDTD = "FDTD" - MODE_SOLVER = "MODE_SOLVER" - HEAT = "HEAT" - HEAT_CHARGE = "HEAT_CHARGE" - EME = "EME" - MODE = "MODE" - VOLUME_MESH = "VOLUME_MESH" - MODAL_CM = "MODAL_CM" - TERMINAL_CM = "TERMINAL_CM" - - -class PayType(str, Enum): - CREDITS = "FLEX_CREDIT" - AUTO = "AUTO" - - @classmethod - def _missing_(cls, value: object) -> PayType: - if isinstance(value, str): - key = value.strip().replace(" ", "_").upper() - if key in cls.__members__: - return cls.__members__[key] - return super()._missing_(value) +from tidy3d._common.web.core.types import ( + PayType, + Queryable, + ResourceLifecycle, + Submittable, + TaskType, + Tidy3DResource, +) diff --git a/tidy3d/web/tests/conftest.py b/tidy3d/web/tests/conftest.py index 918dce5588..5b6585a547 100644 --- a/tidy3d/web/tests/conftest.py +++ b/tidy3d/web/tests/conftest.py @@ -1,10 +1,13 @@ from __future__ import annotations -from collections.abc import Generator +from typing import TYPE_CHECKING import pytest from tidy3d_frontend.tidy3d.web.core.task_core import TaskFactory +if TYPE_CHECKING: + from collections.abc import Generator + @pytest.fixture(autouse=True) def clear_task_factory_registry() -> Generator[None, None, None]:
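Two patterns recur throughout this diff and are worth spelling out. First, typing-only imports move under `TYPE_CHECKING` (as in `conftest.py` above): with postponed evaluation, annotations stay strings, so the import is never needed at runtime. A runnable sketch:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only for the annotation below; never executed at runtime.
    from collections.abc import Generator


def countdown(n: int) -> Generator[int, None, None]:
    while n > 0:
        yield n
        n -= 1


print(list(countdown(3)))  # [3, 2, 1]
```

Second, each migrated module keeps a shim at its old path that re-exports the `tidy3d._common` implementation, so existing imports keep resolving. An illustration of that intent, assuming both package paths are importable in your environment:

```python
# Both paths should expose the very same class object after the migration.
from tidy3d.web.core.types import PayType as shim_PayType
from tidy3d._common.web.core.types import PayType

assert shim_PayType is PayType
```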