Skip to content

8185 test refactor 2 #8405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 26 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
57d7b09
Merge branch 'dev' into 8185-test-refactor-2
garciadias Feb 27, 2025
5342cba
Merge remote-tracking branch 'upstream/dev' into 8185-test-refactor-2
garciadias Mar 28, 2025
0139125
Refactor self-attention test cases to use dict_product for parameter …
garciadias Mar 28, 2025
836cf6e
Refactor PatchEmbeddingBlock test cases to use dict_product for param…
garciadias Mar 28, 2025
b2ccb26
Refactor RetinaNet test cases to use dict_product for parameter combi…
garciadias Mar 28, 2025
0735dfd
Refactor test_meta_tensor to use dict_product for parameter combinations
garciadias Mar 28, 2025
45622d4
Refactor test_box_transform to use dict_product for parameter combina…
garciadias Mar 28, 2025
129f778
Autofix
garciadias Mar 28, 2025
ebae4e3
Fix mypy error
garciadias Mar 29, 2025
482e5bf
Fix missing parameter
garciadias Apr 1, 2025
9a0d5b1
DCO Remediation Commit for R. Garcia-Dias <[email protected]>
garciadias Apr 1, 2025
c06cad1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2025
a5596d7
revert change to tests/apps/detection/test_box_transform.py
garciadias Apr 10, 2025
b46ccc0
redesign dict_product to make more readable
garciadias Apr 11, 2025
6077ccd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
cd5f790
add license to temporary test
garciadias Apr 11, 2025
84d85ea
fix test param name
garciadias Apr 11, 2025
336b287
Simplify with list comprehension
garciadias Apr 11, 2025
0cedbb5
autofix
garciadias Apr 11, 2025
a9b0fc9
N806 not catch locally, but in CI
garciadias Apr 11, 2025
d2259b9
dict_product function to accept Iterable[Any] for improved flexibility
garciadias Apr 11, 2025
5922e2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
ef2388e
autofix
garciadias Apr 11, 2025
016d3b0
Fix test mistakes
garciadias Apr 11, 2025
ce798d7
autofix
garciadias Apr 11, 2025
8fad281
fix mypy errora
garciadias Apr 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions tests/apps/detection/networks/test_retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from monai.networks import eval_mode
from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from monai.utils import ensure_tuple, optional_import
from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_onnx_save, test_script_save
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_onnx_save, test_script_save

_, has_torchvision = optional_import("torchvision")

Expand Down Expand Up @@ -86,15 +86,12 @@
(2, 1, 32, 64),
]

TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
# Create all test case combinations using dict_product
CASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]
MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]

TEST_CASES_TS = []
for case in [TEST_CASE_1]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES_TS.append([model, *case])
TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]
TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]


@SkipIfBeforePyTorchVersion((1, 12))
Expand Down
10 changes: 5 additions & 5 deletions tests/data/meta_tensor/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
from monai.data.utils import decollate_batch, list_data_collate
from monai.transforms import BorderPadd, Compose, DivisiblePadd, FromMetaTensord, ToMetaTensord
from monai.utils.enums import PostFix
from tests.test_utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda
from tests.test_utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, dict_product, skip_if_no_cuda

DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32], [None]]
TESTS = []
for _device in TEST_DEVICES:
for _dtype in DTYPES:
TESTS.append((*_device, *_dtype)) # type: ignore

# Replace nested loops with dict_product

TESTS = [(*params["device"], *params["dtype"]) for params in dict_product(device=TEST_DEVICES, dtype=DTYPES)]


def rand_string(min_len=5, max_len=10):
Expand Down
81 changes: 32 additions & 49 deletions tests/networks/blocks/test_patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,58 +21,41 @@
from monai.networks import eval_mode
from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock
from monai.utils import optional_import
from tests.test_utils import SkipIfBeforePyTorchVersion
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product

einops, has_einops = optional_import("einops")

TEST_CASE_PATCHEMBEDDINGBLOCK = []
for dropout_rate in (0.5,):
for in_channels in [1, 4]:
for hidden_size in [96, 288]:
for img_size in [32, 64]:
for patch_size in [8, 16]:
for num_heads in [8, 12]:
for proj_type in ["conv", "perceptron"]:
for pos_embed_type in ["none", "learnable", "sincos"]:
# for classification in (False, True): # TODO: add classification tests
for nd in (2, 3):
test_case = [
{
"in_channels": in_channels,
"img_size": (img_size,) * nd,
"patch_size": (patch_size,) * nd,
"hidden_size": hidden_size,
"num_heads": num_heads,
"proj_type": proj_type,
"pos_embed_type": pos_embed_type,
"dropout_rate": dropout_rate,
},
(2, in_channels, *([img_size] * nd)),
(2, (img_size // patch_size) ** nd, hidden_size),
]
if nd == 2:
test_case[0]["spatial_dims"] = 2 # type: ignore
TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)

TEST_CASE_PATCHEMBED = []
for patch_size in [2]:
for in_chans in [1, 4]:
for img_size in [96]:
for embed_dim in [6, 12]:
for norm_layer in [nn.LayerNorm]:
for nd in [2, 3]:
test_case = [
{
"patch_size": (patch_size,) * nd,
"in_chans": in_chans,
"embed_dim": embed_dim,
"norm_layer": norm_layer,
"spatial_dims": nd,
},
(2, in_chans, *([img_size] * nd)),
(2, embed_dim, *([img_size // patch_size] * nd)),
]
TEST_CASE_PATCHEMBED.append(test_case)

TEST_CASE_PATCHEMBEDDINGBLOCK = [
[
params,
(2, params["in_channels"], *([params["img_size"]] * params["spatial_dims"])),
(2, (params["img_size"] // params["patch_size"]) ** params["spatial_dims"], params["hidden_size"]),
]
for params in dict_product(
dropout_rate=[0.5],
in_channels=[1, 4],
hidden_size=[96, 288],
img_size=[32, 64],
patch_size=[8, 16],
num_heads=[8, 12],
proj_type=["conv", "perceptron"],
pos_embed_type=["none", "learnable", "sincos"],
spatial_dims=[2, 3],
)
]

img_size = 96
TEST_CASE_PATCHEMBED = [
[
params,
(2, params["in_chans"], *([img_size] * params["spatial_dims"])),
(2, params["embed_dim"], *([img_size // params["patch_size"]]) * params["spatial_dims"]),
]
for params in dict_product(
patch_size=[2], in_chans=[1, 4], embed_dim=[6, 12], norm_layer=[nn.LayerNorm], spatial_dims=[2, 3]
)
]


@SkipIfBeforePyTorchVersion((1, 11, 1))
Expand Down
37 changes: 13 additions & 24 deletions tests/networks/blocks/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,22 @@
from monai.networks.blocks.selfattention import SABlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product, test_script_save

einops, has_einops = optional_import("einops")

TEST_CASE_SABLOCK = []
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 6, 8, 12]:
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
for include_fc in [True, False]:
for use_combined_linear in [True, False]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
"include_fc": include_fc,
"use_combined_linear": use_combined_linear,
"use_flash_attention": True if rel_pos_embedding is None else False,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)
TEST_CASE_SABLOCK = [
[params, (2, 512, params["hidden_size"]), (2, 512, params["hidden_size"])]
for params in dict_product(
dropout_rate=np.linspace(0, 1, 4),
hidden_size=[360, 480, 600, 768],
num_heads=[4, 6, 8, 12],
rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED],
input_size=[(16, 32), (8, 8, 8)],
include_fc=[True, False],
use_combined_linear=[True, False],
)
]


class TestResBlock(unittest.TestCase):
Expand Down
85 changes: 85 additions & 0 deletions tests/test_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

from tests.test_utils import dict_product


class TestTestUtils(unittest.TestCase):
def setUp(self):
test_case_patchembeddingblock = []
for dropout_rate in (0.5,):
for in_channels in [1, 4]:
for hidden_size in [96, 288]:
for img_size in [32, 64]:
for patch_size in [8, 16]:
for num_heads in [8, 12]:
for proj_type in ["conv", "perceptron"]:
for pos_embed_type in ["none", "learnable", "sincos"]:
# for classification in (False, True): # TODO: add classification tests
for nd in (2, 3):
test_case = [
{
"in_channels": in_channels,
"img_size": (img_size,) * nd,
"patch_size": (patch_size,) * nd,
"hidden_size": hidden_size,
"num_heads": num_heads,
"proj_type": proj_type,
"pos_embed_type": pos_embed_type,
"dropout_rate": dropout_rate,
"spatial_dims": nd,
},
(2, in_channels, *([img_size] * nd)),
(2, (img_size // patch_size) ** nd, hidden_size),
]
test_case_patchembeddingblock.append(test_case)

self.test_case_patchembeddingblock = test_case_patchembeddingblock

def test_case_patchembeddingblock(self):
test_case_patchembeddingblock = dict_product(
dropout_rate=[0.5],
in_channels=[1, 4],
hidden_size=[96, 288],
img_size=[32, 64],
patch_size=[8, 16],
num_heads=[8, 12],
proj_type=["conv", "perceptron"],
pos_embed_type=["none", "learnable", "sincos"],
nd=[2, 3],
)
test_case_patchembeddingblock = [
[
params,
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
(2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"]),
]
for params in test_case_patchembeddingblock
]

self.assertIsInstance(test_case_patchembeddingblock, list)
self.assertGreater(len(test_case_patchembeddingblock), 0)
self.assertEqual(len(test_case_patchembeddingblock), len(self.test_case_patchembeddingblock))
self.assertEqual(len(test_case_patchembeddingblock[0]), len(self.test_case_patchembeddingblock[0]))
self.assertEqual(len(test_case_patchembeddingblock[0][0]), len(self.test_case_patchembeddingblock[0][0]))
self.assertEqual(
test_case_patchembeddingblock[0][0]["in_channels"], self.test_case_patchembeddingblock[0][0]["in_channels"]
)
self.assertEqual(test_case_patchembeddingblock[0][1], self.test_case_patchembeddingblock[0][1])
self.assertEqual(test_case_patchembeddingblock[0][2], self.test_case_patchembeddingblock[0][2])


if __name__ == "__main__":
unittest.main()
29 changes: 18 additions & 11 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
import traceback
import unittest
import warnings
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial, reduce
from itertools import product
from pathlib import Path
from subprocess import PIPE, Popen
from typing import Callable, Literal
from typing import Any, Callable
from urllib.error import ContentTooShortError, HTTPError

import numpy as np
Expand Down Expand Up @@ -864,18 +865,24 @@ def equal_state_dict(st_1, st_2):
TEST_DEVICES.append([torch.device("cuda")])


def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items):
def dict_product(**items: Iterable[Any]) -> list[dict]:
"""Create cartesian product, equivalent to a nested for-loop, combinations of the items dict.

Args:
items: dict of items to be combined.

Returns:
list: list of dictionaries with the combinations of the input items.

Example:
>>> dict_product(x=[1, 2], y=[3, 4])
[{'x': 1, 'y': 3}, {'x': 1, 'y': 4}, {'x': 2, 'y': 3}, {'x': 2, 'y': 4}]
"""
keys = items.keys()
values = items.values()
for pvalues in product(*values):
dict_comb = dict(zip(keys, pvalues))
if format == "dict":
if trailing:
yield [dict_comb] + list(pvalues)
else:
yield dict_comb
else:
yield pvalues
prod_values = product(*values)
prod_dict = [dict(zip(keys, v)) for v in prod_values]
return prod_dict


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions tests/transforms/test_gibbs_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@

_, has_torch_fft = optional_import("torch.fft", name="fftshift")

params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]}
TEST_CASES = list(dict_product(format="list", **params))
shapes = ((128, 64), (64, 48, 80))
input_types = TEST_NDARRAYS if has_torch_fft else [np.array]
TEST_CASES = [[p_dict["shape"], p_dict["input_type"]] for p_dict in dict_product(shape=shapes, input_type=input_types)]


class TestGibbsNoise(unittest.TestCase):
Expand Down
Loading