diff --git a/requirements/requirements.txt b/requirements/requirements.txt index fd85350ef..258bf30fb 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,6 +4,7 @@ aiohttp>=3.8.1 albumentations>=1.3.0 bokeh>=3.1.1, <3.6.0 Click>=8.1.3, <8.2.0 +dask>=0.12.1 defusedxml>=0.7.1 filelock>=3.9.0 flask>=2.2.2 diff --git a/tests/conftest.py b/tests/conftest.py index 2b7de0fd6..cdc53dee0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -531,6 +531,7 @@ def sample_wsi_dict(remote_sample: Callable) -> dict: "wsi4_4k_4k_svs", "wsi3_20k_20k_pred", "wsi4_4k_4k_pred", + "wsi4_1k_1k_svs", ] return {name: remote_sample(name) for name in file_names} diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index 0cbff82c1..04f39d7cb 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -6,10 +6,11 @@ import logging import shutil from pathlib import Path -from typing import TYPE_CHECKING, NoReturn +from typing import NoReturn import numpy as np import pytest +import torch import torchvision.models as torch_models from typing_extensions import Unpack @@ -26,8 +27,7 @@ ) from tiatoolbox.models.engine.io_config import ModelIOConfigABC -if TYPE_CHECKING: - import torch.nn +device = "cuda:0" if torch.cuda.is_available() else "cpu" class TestEngineABC(EngineABC): @@ -69,6 +69,8 @@ def post_process_wsi( """Post process WSI output.""" return super().post_process_wsi( raw_predictions=raw_predictions, + prediction_shape=(self.batch_size, 1), + prediction_dtype=int, **kwargs, ) @@ -79,7 +81,7 @@ def infer_wsi( **kwargs: dict, ) -> dict | np.ndarray: """Test infer_wsi.""" - return super().infer_wsi( + return super().infer_wsi( # skipcq: PYL-E1121 dataloader, save_path, **kwargs, @@ -362,7 +364,7 @@ def test_engine_run_with_verbose() -> NoReturn: out = eng.run( images=np.zeros((10, 224, 224, 3), dtype=np.uint8), labels=list(range(10)), - on_gpu=False, + device=device, ) assert "probabilities" in out @@ -457,37 +459,6 @@ def test_patch_pred_zarr_store(tmp_path: pytest.TempPathFactory) -> NoReturn: ) -def test_cache_mode_patches(tmp_path: pytest.TempPathFactory) -> NoReturn: - """Test the caching mode.""" - save_dir = tmp_path / "patch_output" - - eng = TestEngineABC(model="alexnet-kather100k") - out = eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - on_gpu=False, - save_dir=save_dir, - overwrite=True, - cache_mode=True, - ) - assert out.exists(), "Zarr output file does not exist" - - output_file_name = "output2.zarr" - cache_size = 4 - out = eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), - on_gpu=False, - save_dir=save_dir, - overwrite=True, - cache_mode=True, - cache_size=4, - batch_size=8, - output_file=output_file_name, - ) - assert out.stem == output_file_name.split(".")[0] - assert eng.batch_size == cache_size - assert out.exists(), "Zarr output file does not exist" - - def test_get_dataloader(sample_svs: Path) -> None: """Test the get_dataloader function.""" eng = TestEngineABC(model="alexnet-kather100k") @@ -521,7 +492,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) eng = TestEngineABC(model=model) kwargs = { - "patch_input_shape": [512, 512], + "patch_input_shape": [224, 224], "input_resolutions": [{"units": "mpp", "resolution": 1.75}], } with caplog.at_level(logging.WARNING): @@ -537,7 +508,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) # test providing config / full input info for non pretrained models ioconfig = ModelIOConfigABC( - patch_input_shape=(512, 512), + patch_input_shape=(224, 224), stride_shape=(256, 256), input_resolutions=[{"resolution": 1.35, "units": "mpp"}], ) @@ -547,7 +518,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) save_dir=f"{tmp_path}/dump", ioconfig=ioconfig, ) - assert eng._ioconfig.patch_input_shape == (512, 512) + assert eng._ioconfig.patch_input_shape == (224, 224) assert eng._ioconfig.stride_shape == (256, 256) assert eng._ioconfig.input_resolutions == [{"resolution": 1.35, "units": "mpp"}] shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -558,15 +529,15 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) save_dir=f"{tmp_path}/dump", **kwargs, ) - assert eng._ioconfig.patch_input_shape == [512, 512] - assert eng._ioconfig.stride_shape == [512, 512] + assert eng._ioconfig.patch_input_shape == [224, 224] + assert eng._ioconfig.stride_shape == [224, 224] assert eng._ioconfig.input_resolutions == [{"resolution": 1.75, "units": "mpp"}] shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test overwriting pretrained ioconfig eng = TestEngineABC(model="alexnet-kather100k") eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + images=np.zeros((10, 300, 300, 3), dtype=np.uint8), patch_input_shape=(300, 300), stride_shape=(300, 300), input_resolutions=[{"units": "baseline", "resolution": 1.99}], @@ -580,7 +551,7 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) shutil.rmtree(tmp_path / "dump", ignore_errors=True) eng.run( - images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + images=np.zeros((10, 300, 300, 3), dtype=np.uint8), patch_input_shape=(300, 300), stride_shape=(300, 300), input_resolutions=None, @@ -618,3 +589,11 @@ def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) stride_shape=(1, 1), input_resolutions=_kwargs["input_resolutions"], ) + + +def test_save_predictions_incorrect_output_type() -> None: + """Engine should raise TypeError if incorrect output type is requested.""" + eng = TestEngineABC(model="alexnet-kather100k") + + with pytest.raises(TypeError, match=r".*Unsupported output type.* "): + eng.save_predictions({"predictions": np.zeros((20, 9))}, output_type="random") diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py index 76574c099..c6034f826 100644 --- a/tests/engines/test_patch_predictor.py +++ b/tests/engines/test_patch_predictor.py @@ -54,6 +54,7 @@ def _test_predictor_output( cache_mode=cache_mode, save_dir=save_dir, output_type=output_type, + return_probabilities=True, ) if tmp_path is not None: @@ -188,6 +189,7 @@ def test_patch_predictor_api( output = predictor.run( inputs, device="cpu", + return_probabilities=True, ) assert sorted(output.keys()) == ["predictions", "probabilities"] assert len(output["probabilities"]) == 2 @@ -198,10 +200,11 @@ def test_patch_predictor_api( inputs, labels=["1", "a"], return_labels=True, + return_probabilities=True, ) assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) assert len(output["probabilities"]) == len(output["labels"]) - assert output["labels"].tolist() == ["1", "a"] + assert list(output["labels"]) == ["1", "a"] shutil.rmtree(save_dir_path, ignore_errors=True) # test loading user weight @@ -234,10 +237,12 @@ def test_patch_predictor_api( labels=[1, 2], return_labels=True, ioconfig=ioconfig, + return_probabilities=True, + num_workers=1, ) assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) assert len(output["probabilities"]) == len(output["labels"]) - assert output["labels"].tolist() == [1, 2] + assert list(output["labels"]) == [1, 2] def test_wsi_predictor_api( @@ -274,6 +279,7 @@ def test_wsi_predictor_api( images=[mini_wsi_svs, mini_wsi_jpg], masks=[mini_wsi_msk, mini_wsi_msk], patch_mode=False, + return_probabilities=True, **_kwargs, ) @@ -324,7 +330,6 @@ def test_patch_predictor_kather100k_output( classification_check=[6, 3], ) - # cache mode for model, expected_prob in pretrained_info.items(): _test_predictor_output( inputs, diff --git a/tests/engines/test_semantic_segmentor.py b/tests/engines/test_semantic_segmentor.py new file mode 100644 index 000000000..5675889ff --- /dev/null +++ b/tests/engines/test_semantic_segmentor.py @@ -0,0 +1,509 @@ +"""Test SemanticSegmentor.""" + +from __future__ import annotations + +import json +import sqlite3 +import tempfile +from pathlib import Path +from typing import Callable +from unittest import mock + +import dask.array as da +import numpy as np +import pytest +import torch +import zarr +from click.testing import CliRunner + +from tiatoolbox import cli +from tiatoolbox.annotation import SQLiteStore +from tiatoolbox.models.engine import semantic_segmentor +from tiatoolbox.models.engine.semantic_segmentor import ( + SemanticSegmentor, + merge_vertical_chunkwise, +) +from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.misc import imread +from tiatoolbox.wsicore import WSIReader + +device = "cuda" if toolbox_env.has_gpu() else "cpu" + + +def test_semantic_segmentor_init() -> None: + """Tests SemanticSegmentor initialization.""" + segmentor = SemanticSegmentor(model="fcn-tissue_mask", device=device) + + assert isinstance(segmentor, SemanticSegmentor) + assert isinstance(segmentor.model, torch.nn.Module) + + +def test_semantic_segmentor_patches(remote_sample: Callable, tmp_path: Path) -> None: + """Tests SemanticSegmentor on image patches.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image, sample_image] + + assert not segmentor.patch_mode + + output = segmentor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + ) + + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert 0.48 < np.mean(output["probabilities"][:]) < 0.52 + + assert ( + tuple(segmentor._ioconfig.patch_output_shape) + == output["probabilities"][0].shape[:-1] + ) + + assert ( + tuple(segmentor._ioconfig.patch_output_shape) == output["predictions"][0].shape + ) + + output = segmentor.run( + images=inputs, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "output0", + ) + + assert output == tmp_path / "output0" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert 0.48 < np.mean(output["probabilities"][:]) < 0.52 + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + output_type="zarr", + save_dir=tmp_path / "output1", + ) + + assert output == tmp_path / "output1" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert "probabilities" not in output.keys() # noqa: SIM118 + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "output2", + output_type="zarr", + ) + + assert output == tmp_path / "output2" / "output.zarr" + + output = zarr.open(output, mode="r") + assert 0.62 < np.mean(output["predictions"][:]) < 0.66 + assert "probabilities" not in output + assert "predictions" in output + + +def _test_store_output_patch(output: Path) -> None: + """Helper method to test annotation store output for a patch.""" + store_ = SQLiteStore.open(output) + annotations_ = store_.values() + annotations_geometry_type = [ + str(annotation_.geometry_type) for annotation_ in annotations_ + ] + assert "Polygon" in annotations_geometry_type + + con = sqlite3.connect(output) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + out = [] + + for item in annotations_properties: + for json_str in item: + probs = json.loads(json_str) + if "type" in probs: + out.append(probs.pop("type")) + + assert "mask" in out + + assert annotations_properties is not None + + +def test_save_annotation_store(remote_sample: Callable, tmp_path: Path) -> None: + """Test for saving output as annotation store.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + inputs = [sample_image] + + output = segmentor.run( + images=inputs, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "output1", + output_type="annotationstore", + verbose=True, + ) + + assert output[0] == tmp_path / "output1" / (sample_image.stem + ".db") + assert len(output) == 1 + _test_store_output_patch(output[0]) + + +def test_save_annotation_store_nparray( + remote_sample: Callable, tmp_path: Path, caplog: pytest.LogCaptureFixture +) -> None: + """Test for saving output as annotation store using a numpy array.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", batch_size=32, verbose=False, device=device + ) + + sample_image = remote_sample("thumbnail-1k-1k") + + input_image = imread(sample_image) + inputs_list = np.array([input_image, input_image]) + + output = segmentor.run( + images=inputs_list, + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "output1", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output1" / "0.db" + assert output[1] == tmp_path / "output1" / "1.db" + + assert (tmp_path / "output1" / "output.zarr").exists() + + zarr_group = zarr.open(str(tmp_path / "output1" / "output.zarr"), mode="r") + assert "probabilities" in zarr_group + + assert "Probability maps cannot be saved as AnnotationStore." in caplog.text + + _test_store_output_patch(output[0]) + _test_store_output_patch(output[1]) + + output = segmentor.run( + images=inputs_list, + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=True, + save_dir=tmp_path / "output2", + output_type="annotationstore", + ) + + assert output[0] == tmp_path / "output2" / "0.db" + assert output[1] == tmp_path / "output2" / "1.db" + assert not (tmp_path / "output2" / "output.zarr").exists() + + assert len(output) == 2 + + _test_store_output_patch(output[0]) + _test_store_output_patch(output[1]) + + +def test_non_overlapping_blocks() -> None: + """Test for non-overlapping merge to canvas.""" + blocks = np.array([np.ones((2, 2, 1)), np.ones((2, 2, 1)) * 2]) + output_locations = np.array([[0, 0, 2, 2], [2, 0, 4, 2]]) + merged_shape = (2, 4, 1) + canvas, count = semantic_segmentor.merge_batch_to_canvas( + blocks, output_locations, merged_shape + ) + assert np.array_equal(canvas[:, :2, :], np.ones((2, 2, 1))) + assert np.array_equal(canvas[:, 2:, :], np.ones((2, 2, 1)) * 2) + assert np.array_equal(count, np.ones((2, 4, 1))) + + +def test_overlapping_blocks() -> None: + """Test for overlapping merge to canvas.""" + blocks = np.array([np.ones((2, 2, 1)), np.ones((2, 2, 1)) * 3]) + output_locations = np.array([[0, 0, 2, 2], [1, 0, 3, 2]]) + merged_shape = (2, 3, 1) + canvas, count = semantic_segmentor.merge_batch_to_canvas( + blocks, output_locations, merged_shape + ) + expected_canvas = np.array([[[1], [4], [3]], [[1], [4], [3]]]) + expected_count = np.array([[[1], [2], [1]], [[1], [2], [1]]]) + assert np.array_equal(canvas, expected_canvas) + assert np.array_equal(count, expected_count) + + +def test_zero_block() -> None: + """Test for zero merge to canvas.""" + blocks = np.array([np.zeros((2, 2, 1)), np.ones((2, 2, 1))]) + output_locations = np.array([[0, 0, 2, 2], [2, 0, 4, 2]]) + merged_shape = (2, 4, 1) + canvas, count = semantic_segmentor.merge_batch_to_canvas( + blocks, output_locations, merged_shape + ) + assert np.array_equal(canvas[:, :2, :], np.zeros((2, 2, 1))) + assert np.array_equal(canvas[:, 2:, :], np.ones((2, 2, 1))) + assert np.array_equal(count[:, :2, :], np.zeros((2, 2, 1))) + assert np.array_equal(count[:, 2:, :], np.ones((2, 2, 1))) + + +def test_empty_blocks() -> None: + """Test for empty merge to canvas.""" + blocks = np.empty((0, 2, 2, 1)) + output_locations = np.empty((0, 4)) + merged_shape = (2, 2, 1) + canvas, count = semantic_segmentor.merge_batch_to_canvas( + blocks, output_locations, merged_shape + ) + assert np.array_equal(canvas, np.zeros((2, 2, 1))) + assert np.array_equal(count, np.zeros((2, 2, 1), dtype=np.uint8)) + + +def test_merge_vertical_chunkwise_memory_threshold_triggered() -> None: + """Test merge vertical chunkwise for memory threshold.""" + # Create dummy canvas and count arrays with 3 vertical chunks + data = np.ones((30, 10), dtype=np.uint8) + canvas = da.from_array(data, chunks=(10, 10)) + count = da.from_array(data, chunks=(10, 10)) + + # Output locations to simulate overlaps + output_locs_y_ = np.array([[0, 10], [10, 20], [20, 30]]) + + # Temporary Zarr group + with tempfile.TemporaryDirectory() as tmpdir: + save_path = Path(tmpdir) + + # Mock psutil to simulate low memory + with mock.patch( + "tiatoolbox.models.engine.semantic_segmentor.psutil.virtual_memory" + ) as mock_vm: + mock_vm.return_value.free = 1 # Very low free memory + + result = merge_vertical_chunkwise( + canvas=canvas, + count=count, + output_locs_y_=output_locs_y_, + zarr_group=None, + save_path=save_path, + memory_threshold=0.01, # Very low threshold to trigger the condition + ) + + # Assertions + assert isinstance(result, da.Array) + assert hasattr(result, "name") + assert result.name.startswith("from-zarr") + assert np.all(result.compute() == data) + + zarr_group = zarr.open(tmpdir, mode="r") + assert np.all(zarr_group["probabilities"][:] == data) + + +def test_raise_value_error_return_labels_wsi( + sample_svs: Path, + tmp_path: Path, +) -> None: + """Test for raises value error for return_labels in wsi mode.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=64, + verbose=False, + num_workers=1, + ) + with pytest.raises( + ValueError, + match=r".*return_labels` is not supported when `patch_mode` is False", + ): + _ = segmentor.run( + images=[sample_svs], + return_probabilities=False, + return_labels=True, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + ) + + +def test_wsi_segmentor_zarr( + remote_sample: Callable, + sample_svs: Path, + tmp_path: Path, +) -> None: + """Test SemanticSegmentor for WSIs with zarr output.""" + wsi1_2k_2k_svs = Path(remote_sample("wsi1_2k_2k_svs")) + + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=64, + verbose=False, + num_workers=1, + ) + # Return Probabilities is False + output = segmentor.run( + images=[sample_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + batch_size=2, + output_type="zarr", + memory_threshold=1, + ) + + output_ = zarr.open(output[sample_svs], mode="r") + assert 0.17 < np.mean(output_["predictions"][:]) < 0.19 + assert "probabilities" not in output_ + assert "canvas" not in output_ + assert "count" not in output_ + + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=64, + verbose=False, + num_workers=1, + ) + # Return Probabilities is True + # Testing with WSIReader + output = segmentor.run( + images=[WSIReader.open(sample_svs)], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "task_length_cache", + batch_size=2, + output_type="zarr", + memory_threshold=1, + ) + + output_ = zarr.open(output[sample_svs], mode="r") + assert 0.17 < np.mean(output_["predictions"][:]) < 0.19 + assert "probabilities" in output_ + assert "canvas" not in output_ + assert "count" not in output_ + + # Return Probabilities is True + # Using small image for faster run + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=32, + verbose=False, + num_workers=1, + ) + segmentor.drop_keys = [] + output = segmentor.run( + images=[sample_svs, wsi1_2k_2k_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check_prob", + output_type="zarr", + ) + + output_ = zarr.open(output[sample_svs], mode="r") + assert 0.17 < np.mean(output_["predictions"][:]) < 0.19 + assert 0.52 < np.mean(output_["probabilities"][:]) < 0.56 + + output_ = zarr.open(output[wsi1_2k_2k_svs], mode="r") + assert 0.24 < np.mean(output_["predictions"][:]) < 0.25 + assert 0.48 < np.mean(output_["probabilities"][:]) < 0.52 + + +def test_wsi_segmentor_annotationstore( + sample_svs: Path, tmp_path: Path, caplog: pytest.CaptureFixture +) -> None: + """Test SemanticSegmentor for WSIs with AnnotationStore output.""" + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=32, + verbose=False, + ) + # Return Probabilities is False + output = segmentor.run( + images=[sample_svs], + return_probabilities=False, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + verbose=True, + output_type="annotationstore", + ) + + assert output[sample_svs] == tmp_path / "wsi_out_check" / (sample_svs.stem + ".db") + + # Return Probabilities + segmentor = SemanticSegmentor( + model="fcn-tissue_mask", + batch_size=32, + verbose=False, + ) + # Return Probabilities is False + output = segmentor.run( + images=[sample_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_prob_out_check", + verbose=True, + output_type="annotationstore", + ) + + assert output[sample_svs] == tmp_path / "wsi_prob_out_check" / ( + sample_svs.stem + ".db" + ) + assert output[sample_svs].with_suffix(".zarr").exists() + + zarr_group = zarr.open(output[sample_svs].with_suffix(".zarr"), mode="r") + assert "probabilities" in zarr_group + assert "Probability maps cannot be saved as AnnotationStore." in caplog.text + + +# ------------------------------------------------------------------------------------- +# Command Line Interface +# ------------------------------------------------------------------------------------- + + +def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: + """Test for models CLI single file.""" + runner = CliRunner() + models_wsi_result = runner.invoke( + cli.main, + [ + "semantic-segmentor", + "--img-input", + str(sample_svs), + "--patch-mode", + "False", + "--output-path", + str(tmp_path / "output"), + ], + ) + + assert models_wsi_result.exit_code == 0 + assert (tmp_path / "output" / (sample_svs.stem + ".db")).exists() diff --git a/tests/models/test_arch_micronet.py b/tests/models/test_arch_micronet.py index e7aa23d5b..d1e84c1dc 100644 --- a/tests/models/test_arch_micronet.py +++ b/tests/models/test_arch_micronet.py @@ -7,7 +7,7 @@ import pytest import torch -from tiatoolbox.models import MicroNet, SemanticSegmentor +from tiatoolbox.models import MicroNet, NucleusInstanceSegmentor from tiatoolbox.models.architecture import fetch_pretrained_weights from tiatoolbox.utils import env_detection as toolbox_env from tiatoolbox.utils.misc import select_device @@ -63,7 +63,7 @@ def test_micronet_output(remote_sample: Callable, tmp_path: Path) -> None: num_loader_workers = 0 num_postproc_workers = 0 - predictor = SemanticSegmentor( + predictor = NucleusInstanceSegmentor( pretrained_model=pretrained_model, batch_size=batch_size, num_loader_workers=num_loader_workers, diff --git a/tests/models/test_arch_vanilla.py b/tests/models/test_arch_vanilla.py index a87424dfd..b19fce924 100644 --- a/tests/models/test_arch_vanilla.py +++ b/tests/models/test_arch_vanilla.py @@ -33,7 +33,7 @@ def test_functional() -> None: "mobilenet_v3_large", "mobilenet_v3_small", ] - assert CNNModel.postproc([1, 2]) == 1 + assert CNNModel.postproc(np.array([1, 2])) == 1 b = 4 h = w = 512 @@ -60,7 +60,7 @@ def test_timm_functional() -> None: backbones = [ "efficientnet_b0", ] - assert TimmModel.postproc([1, 2]) == 1 + assert TimmModel.postproc(np.array([1, 2])) == 1 b = 4 h = w = 224 diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index ab9a6033f..8edbf4112 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -20,6 +20,7 @@ ) from tiatoolbox.utils import download_data, imread, imwrite, unzip_data from tiatoolbox.utils import env_detection as toolbox_env +from tiatoolbox.utils.exceptions import DimensionMismatchError from tiatoolbox.wsicore import WSIReader RNG = np.random.default_rng() # Numpy Random Generator @@ -120,7 +121,9 @@ def test_kather_dataset(tmp_path: Path) -> None: assert len(dataset.inputs) == len(dataset.labels) # to actually get the image, we feed it to PatchDataset - actual_ds = PatchDataset(dataset.inputs, dataset.labels) + actual_ds = PatchDataset( + dataset.inputs, dataset.labels, patch_input_shape=(224, 224) + ) sample_patch = actual_ds[89] assert isinstance(sample_patch["image"], np.ndarray) assert sample_patch["label"] is not None @@ -129,6 +132,18 @@ def test_kather_dataset(tmp_path: Path) -> None: shutil.rmtree(save_dir_path, ignore_errors=True) +def test_incorrect_input_shape() -> None: + """Incorrect input patch dimensions should raise DimensionMismatchError.""" + size = (5, 5, 3) + img = RNG.integers(low=0, high=255, size=size) + list_imgs = [img, img, img] + dataset = PatchDataset(list_imgs, patch_input_shape=(100, 100)) + with pytest.raises( + DimensionMismatchError, match=r".*\(100, 100\), but got \(5, 5\).*" + ): + _ = dataset[0] + + def test_patch_dataset_path_imgs( sample_patch1: str | Path, sample_patch2: str | Path, @@ -136,7 +151,9 @@ def test_patch_dataset_path_imgs( """Test for patch dataset with a list of file paths as input.""" size = (224, 224, 3) - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) + dataset = PatchDataset( + [Path(sample_patch1), Path(sample_patch2)], patch_input_shape=size[:-1] + ) for _, sample_data in enumerate(dataset): sampled_img_shape = sample_data["image"].shape @@ -152,7 +169,7 @@ def test_patch_dataset_list_imgs(tmp_path: Path) -> None: size = (5, 5, 3) img = RNG.integers(low=0, high=255, size=size) list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) + dataset = PatchDataset(list_imgs, patch_input_shape=size[:-1]) dataset.preproc_func = lambda x: x @@ -197,14 +214,14 @@ def test_patch_datasetarray_imgs() -> None: array_imgs = np.array(list_imgs) # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) + dataset = PatchDataset(array_imgs, labels=labels, patch_input_shape=(5, 5)) an_item = dataset[2] assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) + dataset = PatchDataset(array_imgs, labels=None, patch_input_shape=(5, 5)) an_item = dataset[2] assert "label" not in an_item - dataset = PatchDataset(array_imgs) + dataset = PatchDataset(array_imgs, patch_input_shape=size[:-1]) for _, sample_data in enumerate(dataset): sampled_img_shape = sample_data["image"].shape assert sampled_img_shape[0] == size[0] @@ -329,16 +346,15 @@ def test_wsi_patch_dataset( # noqa: PLR0915 """A test for creation and bare output.""" # convert to pathlib Path to prevent wsireader complaint mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) + return WSIPatchDataset(input_img=img_path, **kwargs) def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: """Testing function.""" - return reuse_init(mode="wsi", **kwargs) + return reuse_init(**kwargs) # test for ABC validate # intentionally created to check error @@ -360,10 +376,9 @@ def __getitem__(self: Proto, idx: int) -> object: Proto() # skipcq # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): + with pytest.raises(ValueError, match=r".*`input_img` must be a valid file path.*"): WSIPatchDataset( - img_path="aaaa", - mode="wsi", + input_img="aaaa", patch_input_shape=[512, 512], stride_shape=[256, 256], auto_get_mask=False, @@ -372,9 +387,8 @@ def __getitem__(self: Proto, idx: int) -> object: # invalid mask path input with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): WSIPatchDataset( - img_path=mini_wsi_svs, + input_img=mini_wsi_svs, mask_path="aaaa", - mode="wsi", patch_input_shape=[512, 512], stride_shape=[256, 256], resolution=1.0, @@ -382,10 +396,6 @@ def __getitem__(self: Proto, idx: int) -> object: auto_get_mask=False, ) - # invalid mode - with pytest.raises(ValueError, match="`X` is not supported."): - reuse_init(mode="X") - # invalid patch with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): reuse_init() @@ -427,9 +437,10 @@ def __getitem__(self: Proto, idx: int) -> object: # * for wsi # dummy test for analysing the output # stride and patch size should be as expected - patch_size = [512, 512] - stride_size = [256, 256] - ds = reuse_init_wsi( + patch_size = (512, 512) + stride_size = (256, 256) + ds = WSIPatchDataset( + input_img=WSIReader.open(mini_wsi_svs), patch_input_shape=patch_size, stride_shape=stride_size, resolution=1.0, @@ -457,7 +468,8 @@ def __getitem__(self: Proto, idx: int) -> object: assert np.min(correlation) > 0.9, correlation # test creation with auto mask gen and input mask - ds = reuse_init_wsi( + ds = WSIPatchDataset( + input_img=mini_wsi_svs, patch_input_shape=patch_size, stride_shape=stride_size, resolution=1.0, @@ -465,12 +477,11 @@ def __getitem__(self: Proto, idx: int) -> object: auto_get_mask=True, ) assert len(ds) > 0 - ds = WSIPatchDataset( - img_path=mini_wsi_svs, + _ = WSIPatchDataset( + input_img=mini_wsi_svs, mask_path=mini_wsi_msk, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], + patch_input_shape=(512, 512), + stride_shape=(256, 256), auto_get_mask=False, resolution=1.0, units="mpp", @@ -480,44 +491,16 @@ def __getitem__(self: Proto, idx: int) -> object: negative_mask_path = tmp_path / "negative_mask.png" imwrite(negative_mask_path, negative_mask) with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): - ds = WSIPatchDataset( - img_path=mini_wsi_svs, + _ = WSIPatchDataset( + input_img=mini_wsi_svs, mask_path=negative_mask_path, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], + patch_input_shape=(512, 512), + stride_shape=(256, 256), auto_get_mask=False, resolution=1.0, units="mpp", ) - # * for tile - reader = WSIReader.open(mini_wsi_jpg) - tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, - mode="tile", - patch_input_shape=patch_size, - stride_shape=stride_size, - auto_get_mask=False, - ) - step_idx = 3 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - roi2 = reader.read_bounds( - start + end, - resolution=1.0, - units="baseline", - coord_space="resolution", - ) - roi1 = tile_ds[3]["image"] # match with step_index - correlation = np.corrcoef( - cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert roi1.shape[0] == roi2.shape[0] - assert roi1.shape[1] == roi2.shape[1] - assert np.min(correlation) > 0.9, correlation - def test_patch_dataset_abc() -> None: """Test for ABC methods. diff --git a/tests/test_utils.py b/tests/test_utils.py index 77b5b6646..b9720826a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, NoReturn import cv2 +import dask.array as da import joblib import numpy as np import pandas as pd @@ -34,6 +35,7 @@ ) from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError +from tiatoolbox.utils.misc import cast_to_min_dtype from tiatoolbox.utils.transforms import locsize2bounds if TYPE_CHECKING: @@ -1672,7 +1674,7 @@ def test_patch_pred_store() -> None: "other": "other", } - store = misc.dict_to_store(patch_output, (1.0, 1.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1685,7 +1687,7 @@ def test_patch_pred_store() -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) patch_output = { "predictions": [1, 0, 1], @@ -1693,7 +1695,7 @@ def test_patch_pred_store() -> None: "other": "other", } - store = misc.dict_to_store(patch_output, (1.0, 1.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1710,7 +1712,9 @@ def test_patch_pred_store_cdict() -> None: "other": "other", } class_dict = {0: "class0", 1: "class1"} - store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict) + store = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), class_dict=class_dict + ) # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1731,7 +1735,7 @@ def test_patch_pred_store_sf() -> None: "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], "labels": [1, 0, 1], } - store = misc.dict_to_store(patch_output, (2.0, 2.0)) + store = misc.dict_to_store_patch_predictions(patch_output, (2.0, 2.0)) # Check that its an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) @@ -1740,43 +1744,6 @@ def test_patch_pred_store_sf() -> None: assert annotation.geometry.area == 4 -def test_patch_pred_store_zarr(tmp_path: pytest.TempPathFactory) -> None: - """Test patch_pred_store_zarr.""" - # Define a mock patch_output - patch_output = { - "predictions": [1, 0, 1], - "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], - "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], - "labels": [1, 0, 1], - } - - save_path = tmp_path / "patch_output" / "output.zarr" - - store_path = misc.dict_to_zarr(patch_output, save_path=save_path) - - print("Zarr path: ", store_path) - assert Path.exists(store_path), "Zarr output file does not exist" - - -def test_patch_pred_store_zarr_ext(tmp_path: pytest.TempPathFactory) -> None: - """Test patch_pred_store_zarr and ensures the output file extension is `.zarr`.""" - # Define a mock patch_output - patch_output = { - "predictions": [1, 0, 1], - "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], - "probabilities": [[0.1, 0.9], [0.9, 0.1], [0.4, 0.6]], - "labels": [1, 0, 1], - } - - # sends the path of a jpeg source image, expects .zarr file in the same directory - save_path = tmp_path / "patch_output" / "patch.jpeg" - - store_path = misc.dict_to_zarr(patch_output, save_path=save_path) - - print("Zarr path: ", store_path) - assert Path.exists(store_path), "Zarr output file does not exist" - - def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: """Test patch_pred_store. and persists store output to a .db file.""" # Define a mock patch_output @@ -1788,7 +1755,9 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: } save_path = tmp_path / "patch_output" / "output.db" - store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path) + store_path = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), save_path=save_path + ) print("Annotation store path: ", store_path) assert Path.exists(store_path), "Annotation Store output file does not exist" @@ -1806,7 +1775,7 @@ def test_patch_pred_store_persist(tmp_path: pytest.TempPathFactory) -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: @@ -1822,7 +1791,9 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: # sends the path of a jpeg source image, expects .db file in the same directory save_path = tmp_path / "patch_output" / "output.jpeg" - store_path = misc.dict_to_store(patch_output, (1.0, 1.0), save_path=save_path) + store_path = misc.dict_to_store_patch_predictions( + patch_output, (1.0, 1.0), save_path=save_path + ) print("Annotation store path: ", store_path) assert Path.exists(store_path), "Annotation Store output file does not exist" @@ -1840,7 +1811,7 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: patch_output.pop("coordinates") # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): - misc.dict_to_store(patch_output, (1.0, 1.0)) + misc.dict_to_store_patch_predictions(patch_output, (1.0, 1.0)) def test_torch_compile_already_compiled() -> None: @@ -2210,3 +2181,45 @@ def test_save_zarr_array_probability_ome_tiff( assert_ome_metadata_value(ome_xml, "PhysicalSizeY", "0.25") assert_ome_metadata_value(ome_xml, "PhysicalSizeXUnit", "µm") assert_ome_metadata_value(ome_xml, "PhysicalSizeYUnit", "µm") + + +@pytest.mark.parametrize( + ("input_array", "expected_dtype"), + [ + (np.array([0, 1]), np.bool_), # Should cast to bool + (np.array([0, 255]), np.uint8), # Should cast to uint8 + (np.array([0, 256]), np.uint16), # Should cast to uint16 + (np.array([0, 70000]), np.uint32), # Should cast to uint32 + (np.array([0, 2**32]), np.uint64), # Should cast to uint64 + ], +) +def test_cast_to_min_dtype_numpy(input_array: np.ndarray, expected_dtype: type) -> None: + """Check expected np array dtype cast_to_min_dtype.""" + result = cast_to_min_dtype(input_array) + assert isinstance(result, np.ndarray) + assert result.dtype == expected_dtype + + +@pytest.mark.parametrize( + ("input_array", "expected_dtype"), + [ + (da.from_array(np.array([0, 1])), np.bool_), # Should cast to bool + (da.from_array(np.array([0, 255])), np.uint8), # Should cast to uint8 + (da.from_array(np.array([0, 256])), np.uint16), # Should cast to uint16 + (da.from_array(np.array([0, 70000])), np.uint32), # Should cast to uint32 + (da.from_array(np.array([0, 2**32])), np.uint64), # Should cast to uint64 + ], +) +def test_cast_to_min_dtype_dask(input_array: da.Array, expected_dtype: type) -> None: + """Check expected dask array dtype cast_to_min_dtype.""" + result = cast_to_min_dtype(input_array) + assert isinstance(result, da.Array) + assert result.dtype == expected_dtype + + +def test_cast_to_min_dtype_numpy_large_value() -> None: + """Check if return type is changed for large value.""" + large_value = np.array([np.iinfo(np.uint64).max + 1], dtype=object) + result = cast_to_min_dtype(large_value) + assert result == large_value + assert result.dtype == object diff --git a/tiatoolbox/cli/__init__.py b/tiatoolbox/cli/__init__.py index cf6b35701..38c69aa85 100644 --- a/tiatoolbox/cli/__init__.py +++ b/tiatoolbox/cli/__init__.py @@ -11,7 +11,7 @@ from tiatoolbox.cli.patch_predictor import patch_predictor from tiatoolbox.cli.read_bounds import read_bounds from tiatoolbox.cli.save_tiles import save_tiles -from tiatoolbox.cli.semantic_segment import semantic_segment +from tiatoolbox.cli.semantic_segmentor import semantic_segmentor from tiatoolbox.cli.show_wsi import show_wsi from tiatoolbox.cli.slide_info import slide_info from tiatoolbox.cli.slide_thumbnail import slide_thumbnail @@ -42,7 +42,7 @@ def main() -> click.BaseCommand: main.add_command(patch_predictor) main.add_command(read_bounds) main.add_command(save_tiles) -main.add_command(semantic_segment) +main.add_command(semantic_segmentor) main.add_command(slide_info) main.add_command(slide_thumbnail) main.add_command(tissue_mask) diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 6a29032d9..bff4bf7de 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -387,14 +387,28 @@ def cli_masks( ) -def cli_auto_generate_mask( +def cli_memory_threshold( + usage_help: str = ( + "Memory usage threshold (in percentage) to trigger caching behavior." + ), + default: int = 80, +) -> Callable: + """Enables --batch-size option for cli.""" + return click.option( + "--memory-threshold", + help=add_default_to_usage_help(usage_help, default=default), + default=default, + ) + + +def cli_auto_get_mask( usage_help: str = "Automatically generate tile/WSI tissue mask.", *, default: bool = False, ) -> Callable: """Enables --auto-generate-mask option for cli.""" return click.option( - "--auto-generate-mask", + "--auto-get-mask", help=add_default_to_usage_help(usage_help, default=default), type=bool, default=default, @@ -415,27 +429,14 @@ def cli_yaml_config_path( ) -def cli_num_loader_workers( +def cli_num_workers( usage_help: str = "Number of workers to load the data. Please note that they will " "also perform preprocessing.", default: int = 0, ) -> Callable: """Enables --num-loader-workers option for cli.""" return click.option( - "--num-loader-workers", - help=add_default_to_usage_help(usage_help, default=default), - type=int, - default=default, - ) - - -def cli_num_postproc_workers( - usage_help: str = "Number of workers to post-process the network output.", - default: int = 0, -) -> Callable: - """Enables --num-postproc-workers option for cli.""" - return click.option( - "--num-postproc-workers", + "--num-workers", help=add_default_to_usage_help(usage_help, default=default), type=int, default=default, diff --git a/tiatoolbox/cli/nucleus_instance_segment.py b/tiatoolbox/cli/nucleus_instance_segment.py index 94a9208c2..707e71f5b 100644 --- a/tiatoolbox/cli/nucleus_instance_segment.py +++ b/tiatoolbox/cli/nucleus_instance_segment.py @@ -5,15 +5,14 @@ import click from tiatoolbox.cli.common import ( - cli_auto_generate_mask, + cli_auto_get_mask, cli_batch_size, cli_device, cli_file_type, cli_img_input, cli_masks, cli_mode, - cli_num_loader_workers, - cli_num_postproc_workers, + cli_num_workers, cli_output_path, cli_pretrained_model, cli_pretrained_weights, @@ -45,10 +44,9 @@ @cli_batch_size() @cli_masks(default=None) @cli_yaml_config_path(default=None) -@cli_num_loader_workers() +@cli_num_workers() @cli_verbose(default=True) -@cli_num_postproc_workers(default=0) -@cli_auto_generate_mask(default=False) +@cli_auto_get_mask(default=False) def nucleus_instance_segment( pretrained_model: str, pretrained_weights: str, @@ -60,7 +58,6 @@ def nucleus_instance_segment( batch_size: int, yaml_config_path: str, num_loader_workers: int, - num_postproc_workers: int, device: str, *, auto_generate_mask: bool, @@ -91,7 +88,6 @@ def nucleus_instance_segment( pretrained_weights=pretrained_weights, batch_size=batch_size, num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, auto_generate_mask=auto_generate_mask, verbose=verbose, ) diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index a33e8e98a..17ed9ebf9 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -3,13 +3,15 @@ from __future__ import annotations from tiatoolbox.cli.common import ( + cli_auto_get_mask, cli_batch_size, cli_device, cli_file_type, cli_img_input, cli_masks, + cli_memory_threshold, cli_model, - cli_num_loader_workers, + cli_num_workers, cli_output_path, cli_output_type, cli_patch_mode, @@ -39,13 +41,15 @@ @cli_batch_size(default=1) @cli_yaml_config_path() @cli_masks(default=None) -@cli_num_loader_workers(default=0) +@cli_num_workers(default=0) @cli_output_type( default="AnnotationStore", ) +@cli_memory_threshold(default=80) @cli_patch_mode(default=False) @cli_return_probabilities(default=True) @cli_return_labels(default=False) +@cli_auto_get_mask(default=True) @cli_verbose(default=True) def patch_predictor( model: str, @@ -56,16 +60,18 @@ def patch_predictor( output_path: str, batch_size: int, yaml_config_path: str, - num_loader_workers: int, + num_workers: int, device: str, output_type: str, + memory_threshold: int, *, + patch_mode: bool, return_probabilities: bool, return_labels: bool, - patch_mode: bool, + auto_get_mask: bool, verbose: bool, ) -> None: - """Process an image/directory of input images with a patch classification CNN.""" + """Process an image/directory of input images with a patch classification engine.""" from tiatoolbox.models.engine.io_config import ( # noqa: PLC0415 IOPatchPredictorConfig, ) @@ -82,7 +88,7 @@ def patch_predictor( model=model, weights=weights, batch_size=batch_size, - num_loader_workers=num_loader_workers, + num_workers=num_workers, verbose=verbose, ) @@ -102,4 +108,6 @@ def patch_predictor( output_type=output_type, return_probabilities=return_probabilities, return_labels=return_labels, + auto_get_mask=auto_get_mask, + memory_threshold=memory_threshold, ) diff --git a/tiatoolbox/cli/semantic_segment.py b/tiatoolbox/cli/semantic_segment.py deleted file mode 100644 index c494e06eb..000000000 --- a/tiatoolbox/cli/semantic_segment.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Command line interface for semantic segmentation.""" - -from __future__ import annotations - -import click - -from tiatoolbox.cli.common import ( - cli_batch_size, - cli_device, - cli_file_type, - cli_img_input, - cli_masks, - cli_mode, - cli_num_loader_workers, - cli_output_path, - cli_pretrained_model, - cli_pretrained_weights, - cli_verbose, - cli_yaml_config_path, - prepare_ioconfig, - prepare_model_cli, - tiatoolbox_cli, -) - - -@tiatoolbox_cli.command() -@cli_img_input() -@cli_output_path( - usage_help="Output directory where model predictions will be saved.", - default="semantic_segmentation", -) -@cli_file_type( - default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", -) -@cli_mode( - usage_help="Type of input file to process.", - default="wsi", - input_type=click.Choice(["patch", "wsi", "tile"], case_sensitive=False), -) -@cli_pretrained_model(default="fcn-tissue_mask") -@cli_pretrained_weights(default=None) -@cli_device() -@cli_batch_size() -@cli_masks(default=None) -@cli_yaml_config_path() -@cli_num_loader_workers() -@cli_verbose() -def semantic_segment( - pretrained_model: str, - pretrained_weights: str, - img_input: str, - file_types: str, - masks: str | None, - mode: str, - output_path: str, - batch_size: int, - yaml_config_path: str, - num_loader_workers: int, - device: str, - *, - verbose: bool, -) -> None: - """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor # noqa: PLC0415 - from tiatoolbox.utils import save_as_json # noqa: PLC0415 - - files_all, masks_all, output_path = prepare_model_cli( - img_input=img_input, - output_path=output_path, - masks=masks, - file_types=file_types, - ) - - ioconfig = prepare_ioconfig( - IOSegmentorConfig, - pretrained_weights, - yaml_config_path, - ) - - predictor = SemanticSegmentor( - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, - batch_size=batch_size, - num_loader_workers=num_loader_workers, - verbose=verbose, - ) - - output = predictor.predict( - imgs=files_all, - masks=masks_all, - mode=mode, - device=device, - save_dir=output_path, - ioconfig=ioconfig, - ) - - save_as_json(output, str(output_path.joinpath("results.json"))) diff --git a/tiatoolbox/cli/semantic_segmentor.py b/tiatoolbox/cli/semantic_segmentor.py new file mode 100644 index 000000000..78b27b83c --- /dev/null +++ b/tiatoolbox/cli/semantic_segmentor.py @@ -0,0 +1,110 @@ +"""Command line interface for semantic segmentation.""" + +from __future__ import annotations + +from tiatoolbox.cli.common import ( + cli_auto_get_mask, + cli_batch_size, + cli_device, + cli_file_type, + cli_img_input, + cli_masks, + cli_memory_threshold, + cli_model, + cli_num_workers, + cli_output_path, + cli_output_type, + cli_patch_mode, + cli_return_labels, + cli_return_probabilities, + cli_verbose, + cli_weights, + cli_yaml_config_path, + prepare_ioconfig, + prepare_model_cli, + tiatoolbox_cli, +) + + +@tiatoolbox_cli.command() +@cli_img_input() +@cli_output_path( + usage_help="Output directory where model segmentation will be saved.", + default="semantic_segmentation", +) +@cli_file_type( + default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", +) +@cli_model(default="fcn-tissue_mask") +@cli_weights() +@cli_device(default="cpu") +@cli_batch_size(default=1) +@cli_yaml_config_path() +@cli_masks(default=None) +@cli_num_workers(default=0) +@cli_output_type( + default="AnnotationStore", +) +@cli_memory_threshold(default=80) +@cli_patch_mode(default=False) +@cli_return_probabilities(default=True) +@cli_return_labels(default=False) +@cli_auto_get_mask(default=True) +@cli_verbose(default=True) +def semantic_segmentor( + model: str, + weights: str, + img_input: str, + file_types: str, + masks: str | None, + output_path: str, + batch_size: int, + yaml_config_path: str, + num_workers: int, + device: str, + output_type: str, + memory_threshold: int, + *, + patch_mode: bool, + return_probabilities: bool, + return_labels: bool, + auto_get_mask: bool, + verbose: bool, +) -> None: + """Process a set of input images with a semantic segmentation engine.""" + from tiatoolbox.models import IOSegmentorConfig, SemanticSegmentor # noqa: PLC0415 + + files_all, masks_all, output_path = prepare_model_cli( + img_input=img_input, + output_path=output_path, + masks=masks, + file_types=file_types, + ) + + ioconfig = prepare_ioconfig( + IOSegmentorConfig, + pretrained_weights=weights, + yaml_config_path=yaml_config_path, + ) + + segmentor = SemanticSegmentor( + model=model, + weights=weights, + batch_size=batch_size, + num_workers=num_workers, + verbose=verbose, + ) + + _ = segmentor.run( + images=files_all, + masks=masks_all, + patch_mode=patch_mode, + ioconfig=ioconfig, + device=device, + save_dir=output_path, + output_type=output_type, + return_probabilities=return_probabilities, + return_labels=return_labels, + auto_get_mask=auto_get_mask, + memory_threshold=memory_threshold, + ) diff --git a/tiatoolbox/data/pretrained_model.yaml b/tiatoolbox/data/pretrained_model.yaml index 3dd13f0d8..594036ed0 100644 --- a/tiatoolbox/data/pretrained_model.yaml +++ b/tiatoolbox/data/pretrained_model.yaml @@ -587,7 +587,7 @@ fcn-tissue_mask: - {'units': 'mpp', 'resolution': 2.0} patch_input_shape: [1024, 1024] patch_output_shape: [512, 512] - stride_shape: [256, 256] + stride_shape: [450, 450] save_resolution: {'units': 'mpp', 'resolution': 8.0} fcn_resnet50_unet-bcss: @@ -608,7 +608,7 @@ fcn_resnet50_unet-bcss: - {'units': 'mpp', 'resolution': 0.25} patch_input_shape: [1024, 1024] patch_output_shape: [512, 512] - stride_shape: [256, 256] + stride_shape: [450, 450] save_resolution: {'units': 'mpp', 'resolution': 0.25} unet_tissue_mask_tsef: diff --git a/tiatoolbox/data/remote_samples.yaml b/tiatoolbox/data/remote_samples.yaml index 1b7bf2bf1..44e7d3492 100644 --- a/tiatoolbox/data/remote_samples.yaml +++ b/tiatoolbox/data/remote_samples.yaml @@ -21,6 +21,8 @@ files: extract: True svs-1-small: url: [*wsis, "CMU-1-Small-Region.svs"] + thumbnail-1k-1k: + url: [*wsis, "CMU-2_1k_1k-thumbnail.png"] tiled-tiff-1-small-jpeg: url: [*wsis, "CMU-1-Small-Region.jpeg.tiff"] tiled-tiff-1-small-jp2k: diff --git a/tiatoolbox/models/__init__.py b/tiatoolbox/models/__init__.py index ab52740ed..5de543aad 100644 --- a/tiatoolbox/models/__init__.py +++ b/tiatoolbox/models/__init__.py @@ -20,11 +20,10 @@ from .engine.multi_task_segmentor import MultiTaskSegmentor from .engine.nucleus_instance_segmentor import NucleusInstanceSegmentor from .engine.patch_predictor import PatchPredictor -from .engine.semantic_segmentor import DeepFeatureExtractor, SemanticSegmentor +from .engine.semantic_segmentor import SemanticSegmentor __all__ = [ "SCCNN", - "DeepFeatureExtractor", "HoVerNet", "HoVerNetPlus", "IDaRS", diff --git a/tiatoolbox/models/architecture/unet.py b/tiatoolbox/models/architecture/unet.py index 6385e7587..4af54713a 100644 --- a/tiatoolbox/models/architecture/unet.py +++ b/tiatoolbox/models/architecture/unet.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn.functional as F # noqa: N812 @@ -10,9 +10,16 @@ from torchvision.models.resnet import Bottleneck as ResNetBottleneck from torchvision.models.resnet import ResNet -from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop +from tiatoolbox.models.architecture.utils import ( + UpSample2x, + argmax_last_axis, + centre_crop, +) from tiatoolbox.models.models_abc import ModelABC +if TYPE_CHECKING: # pragma: no cover + import numpy as np + class ResNetEncoder(ResNet): """A subclass of ResNet defined in torch. @@ -416,7 +423,7 @@ def infer_batch( batch_data: torch.Tensor, *, device: str, - ) -> list: + ) -> np.ndarray: """Run inference on an input batch. This contains logic for forward operation as well as i/o @@ -432,9 +439,8 @@ def infer_batch( Transfers model to the specified device. Default is "cpu". Returns: - list: - List of network output head, each output is an - :class:`numpy.ndarray`. + dict: + A dict with "probabilities" key and a :class:`numpy.ndarray` as output. """ model.eval() @@ -457,7 +463,14 @@ def infer_batch( align_corners=False, ) probs = centre_crop(probs, crop_shape) - probs = probs.permute(0, 2, 3, 1) # to NHWC + output = probs.permute(0, 2, 3, 1) # to NHWC + + return output.cpu().numpy() - probs = probs.cpu().numpy() - return [probs] + def postproc(self: UNetModel, image: np.ndarray) -> np.ndarray: + """Define post-processing of this class of model. + + This simply applies argmax along last axis of the input. + + """ + return argmax_last_axis(image=image) diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index e9560e59e..72d430e78 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -234,3 +234,20 @@ def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor: ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw ret = ret.permute(0, 1, 2, 4, 3, 5) return ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2)) + + +def argmax_last_axis(image: np.ndarray) -> np.ndarray: + """Define the post-processing of this class of model. + + This simply applies argmax along last axis of the input. + + Args: + image (np.ndarray): + The input image array. + + Returns: + np.ndarray: + The post-processed image array. + + """ + return image.argmax(axis=-1) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index b471d026c..a99f5e695 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -4,16 +4,17 @@ from typing import TYPE_CHECKING -import numpy as np import timm import torch import torchvision.models as torch_models from timm.layers import SwiGLUPacked from torch import nn +from tiatoolbox.models.architecture.utils import argmax_last_axis from tiatoolbox.models.models_abc import ModelABC if TYPE_CHECKING: # pragma: no cover + import numpy as np from torchvision.models import WeightsEnum @@ -205,28 +206,11 @@ def _get_timm_architecture( raise ValueError(msg) -def _postproc(image: np.ndarray) -> np.ndarray: - """Define the post-processing of this class of model. - - This simply applies argmax along last axis of the input. - - Args: - image (np.ndarray): - The input image array. - - Returns: - np.ndarray: - The post-processed image array. - - """ - return np.argmax(image, axis=-1) - - def _infer_batch( model: nn.Module, batch_data: torch.Tensor, device: str, -) -> dict[str, np.ndarray]: +) -> np.ndarray: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -260,7 +244,7 @@ def _infer_batch( with torch.inference_mode(): output = model(img_patches_device) # Output should be a single tensor or scalar - return {"probabilities": output.cpu().numpy()} + return output.cpu().numpy() class CNNModel(ModelABC): @@ -339,14 +323,14 @@ def postproc(image: np.ndarray) -> np.ndarray: The post-processed image array. """ - return _postproc(image=image) + return argmax_last_axis(image=image) @staticmethod def infer_batch( model: nn.Module, batch_data: torch.Tensor, device: str = "cpu", - ) -> dict[str, np.ndarray]: + ) -> np.ndarray: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -463,14 +447,14 @@ def postproc(image: np.ndarray) -> np.ndarray: The post-processed image array. """ - return _postproc(image=image) + return argmax_last_axis(image=image) @staticmethod def infer_batch( model: nn.Module, batch_data: torch.Tensor, device: str, - ) -> dict[str, np.ndarray]: + ) -> np.ndarray: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -482,10 +466,10 @@ def infer_batch( A batch of data generated by `torch.utils.data.DataLoader`. device (str): - Transfers model to the specified device. Default is "cpu". + Transfers model to the specified device. Returns: - dict[str, np.ndarray]: + np.ndarray: The model predictions as a NumPy array. Example: @@ -573,7 +557,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, device: str, - ) -> list[dict[str, np.ndarray]]: + ) -> list[np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -588,7 +572,7 @@ def infer_batch( Transfers model to the specified device. Default is "cpu". Returns: - list[dict[str, np.ndarray]]: + list[np.ndarray]: list of dictionary values with numpy arrays. Example: @@ -665,7 +649,7 @@ def infer_batch( model: nn.Module, batch_data: torch.Tensor, device: str, - ) -> list[dict[str, np.ndarray]]: + ) -> list[np.ndarray]: """Run inference on an input batch. Contains logic for forward operation as well as i/o aggregation. @@ -680,7 +664,7 @@ def infer_batch( Transfers model to the specified device. Default is "cpu". Returns: - list[dict[str, np.ndarray]]: + list[np.ndarray]: list of dictionary values with numpy arrays. Example: diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index d03fd7b38..4f7de08e5 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import os from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, Callable, Union @@ -15,6 +16,7 @@ from tiatoolbox import logger from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils import imread +from tiatoolbox.utils.exceptions import DimensionMismatchError from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIMeta, WSIReader if TYPE_CHECKING: # pragma: no cover @@ -22,7 +24,7 @@ from multiprocessing.managers import Namespace from tiatoolbox.models.engine.io_config import IOSegmentorConfig - from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.type_hints import IntPair, Resolution, Units try: from typing import TypeGuard @@ -361,12 +363,12 @@ class WSIPatchDataset(PatchDatasetABC): """ - def __init__( # skipcq: PY-R1000 # noqa: PLR0915 + def __init__( # skipcq: PY-R1000 self: WSIPatchDataset, - img_path: str | Path, - mode: str = "wsi", + input_img: str | Path | WSIReader, mask_path: str | Path | None = None, patch_input_shape: IntPair = None, + patch_output_shape: IntPair = None, stride_shape: IntPair = None, resolution: Resolution = None, units: Units = None, @@ -378,13 +380,8 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 """Create a WSI-level patch dataset. Args: - mode (str): - Can be either `wsi` or `tile` to denote the image to - read is either a whole-slide image or a large image - tile. - img_path (str or Path): - Valid to pyramidal whole-slide image or large tile to - read. + input_img (str or Path or WSIReader): + Valid path to a whole-slide image class:`WSIReader`. mask_path (str or Path): Valid mask image. patch_input_shape: @@ -393,6 +390,12 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 and `units`. Expected to be positive and of (height, width). Note, this is not at `resolution` coordinate space. + patch_output_shape: + A tuple (int, int) or ndarray of shape (2,). Expected + output shape from the model at requested `resolution` + and `units`. Expected to be positive and of (height, + width). Note, this is not at `resolution` coordinate + space. stride_shape: A tuple (int, int) or ndarray of shape (2,). Expected stride shape to read at requested `resolution` and @@ -421,8 +424,7 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 >>> # Create a dataset to get patches from WSI with above >>> # preprocessing function >>> ds = WSIPatchDataset( - ... img_path='/A/B/C/wsi.svs', - ... mode="wsi", + ... input_img='/A/B/C/wsi.svs', ... patch_input_shape=[512, 512], ... stride_shape=[256, 256], ... auto_get_mask=False, @@ -432,52 +434,51 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 """ super().__init__() + valid_path = bool( + isinstance(input_img, (str, Path)) and Path(input_img).is_file() + ) # Is there a generic func for path test in toolbox? - if not Path.is_file(Path(img_path)): - msg = "`img_path` must be a valid file path." - raise ValueError(msg) - if mode not in ["wsi", "tile"]: - msg = f"`{mode}` is not supported." + if not valid_path and not isinstance(input_img, WSIReader): + msg = "`input_img` must be a valid file path or a `WSIReader` instance." raise ValueError(msg) patch_input_shape = np.array(patch_input_shape) stride_shape = np.array(stride_shape) - if ( - not np.issubdtype(patch_input_shape.dtype, np.integer) - or np.size(patch_input_shape) > 2 # noqa: PLR2004 - or np.any(patch_input_shape < 0) - ): - msg = f"Invalid `patch_input_shape` value {patch_input_shape}." - raise ValueError(msg) - if ( - not np.issubdtype(stride_shape.dtype, np.integer) - or np.size(stride_shape) > 2 # noqa: PLR2004 - or np.any(stride_shape < 0) - ): - msg = f"Invalid `stride_shape` value {stride_shape}." - raise ValueError(msg) + _validate_patch_stride_shape(patch_input_shape, stride_shape) self.preproc_func = preproc_func + img_path = ( + input_img if not isinstance(input_img, WSIReader) else input_img.input_path + ) self.img_path = Path(img_path) - self.mode = mode - self.reader = None - reader = self._get_reader(self.img_path) - if mode != "wsi": - units = "mpp" - resolution = 1.0 - + reader = ( + input_img + if isinstance(input_img, WSIReader) + else WSIReader.open(self.img_path) + ) + # To support multi-threading on Windows + # Helps pickle using Path + self.reader = None if os.name == "nt" else reader # may decouple into misc ? # the scaling factor will scale base level to requested read resolution/units wsi_shape = reader.slide_dimensions(resolution=resolution, units=units) self.reader_info = reader.info # use all patches, as long as it overlaps source image - self.inputs = PatchExtractor.get_coordinates( - image_shape=wsi_shape, - patch_input_shape=patch_input_shape[::-1], - stride_shape=stride_shape[::-1], - input_within_bound=False, - ) + if patch_output_shape is not None: + self.inputs, self.outputs = PatchExtractor.get_coordinates( + image_shape=wsi_shape, + patch_input_shape=patch_input_shape[::-1], + stride_shape=stride_shape[::-1], + patch_output_shape=patch_output_shape, + ) + self.full_outputs = self.outputs + else: + self.inputs = PatchExtractor.get_coordinates( + image_shape=wsi_shape, + patch_input_shape=patch_input_shape[::-1], + stride_shape=stride_shape[::-1], + ) mask_reader = None if mask_path is not None: @@ -490,13 +491,13 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 mask = np.array(mask > 0, dtype=np.uint8) mask_reader = VirtualWSIReader(mask) - mask_reader.info = reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: + mask_reader.info = self.reader_info + elif auto_get_mask and mask_path is None: # if no mask provided and `wsi` mode, generate basic tissue # mask on the fly mask_reader = reader.tissue_mask(resolution=1.25, units="power") # ? will this mess up ? - mask_reader.info = reader.info + mask_reader.info = self.reader_info if mask_reader is not None: selected = PatchExtractor.filter_coordinates( @@ -506,10 +507,11 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 min_mask_ratio=min_mask_ratio, ) self.inputs = self.inputs[selected] + if hasattr(self, "outputs"): + self.full_outputs = self.outputs # Full list of outputs + self.outputs = self.outputs[selected] - if len(self.inputs) == 0: - msg = "No patch coordinates remain after filtering." - raise ValueError(msg) + self._check_inputs() self.patch_input_shape = patch_input_shape self.resolution = resolution @@ -518,44 +520,26 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 # Perform check on the input self._check_input_integrity(mode="wsi") + def _check_inputs(self: WSIPatchDataset) -> None: + """Check if input length is valid after filtering.""" + if len(self.inputs) == 0: + msg = "No patch coordinates remain after filtering." + raise ValueError(msg) + def _get_reader(self: WSIPatchDataset, img_path: str | Path) -> WSIReader: """Get a reader for the image.""" - if self.mode == "wsi": - reader = WSIReader.open(img_path) - else: - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"` and `resolution=1.0`.', - stacklevel=2, - ) - img = imread(img_path) - axes = "YXS"[: len(img.shape)] - # initialise metadata for VirtualWSIReader. - # here, we simulate a whole-slide image, but with a single level. - # ! should we expose this so that use can provide their metadata ? - metadata = WSIMeta( - mpp=np.array([1.0, 1.0]), - axes=axes, - objective_power=10, - slide_dimensions=np.array(img.shape[:2][::-1]), - level_downsamples=[1.0], - level_dimensions=[np.array(img.shape[:2][::-1])], - ) - # infer value such that read if mask provided is through - # 'mpp' or 'power' as varying 'baseline' is locked atm - reader = VirtualWSIReader( - img, - info=metadata, - ) - return reader + # To avoid ruff errors and compatibility with base class. + return self.reader if self.reader else WSIReader.open(img_path) def __getitem__(self: WSIPatchDataset, idx: int) -> dict: """Get an item from the dataset.""" coords = self.inputs[idx] + output_locs = None + if hasattr(self, "outputs"): + output_locs = self.outputs[idx] + # Read image patch from the whole-slide image - if self.reader is None: - # only set the reader on first call so that it is initially picklable - self.reader = self._get_reader(self.img_path) + self.reader = self._get_reader(self.img_path) patch = self.reader.read_bounds( coords, resolution=self.resolution, @@ -567,6 +551,13 @@ def __getitem__(self: WSIPatchDataset, idx: int) -> dict: # Apply preprocessing to selected patch patch = self._preproc(patch) + if output_locs is not None: + return { + "image": patch, + "coords": np.array(coords), + "output_locs": output_locs, + } + return {"image": patch, "coords": np.array(coords)} @@ -584,6 +575,10 @@ class PatchDataset(PatchDatasetABC): labels (list): List of labels for sample at the same index in `inputs`. Default is `None`. + patch_input_shape (tuple): + Size of patches input to the model. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. Examples: >>> # A user defined preproc func and expected behavior @@ -593,6 +588,7 @@ class PatchDataset(PatchDatasetABC): >>> ds = PatchDataset( ... inputs=['/A/B/C/img1.png', '/A/B/C/img2.png'], ... labels=["labels1", "labels2"], + ... patch_input_shape=(224, 224), ... ) """ @@ -601,6 +597,7 @@ def __init__( self: PatchDataset, inputs: np.ndarray | list, labels: list | None = None, + patch_input_shape: IntPair | None = None, ) -> None: """Initialize :class:`PatchDataset`.""" super().__init__() @@ -609,6 +606,7 @@ def __init__( self.inputs = inputs self.labels = labels + self.patch_input_shape = patch_input_shape # perform check on the input self._check_input_integrity(mode="patch") @@ -621,6 +619,18 @@ def __getitem__(self: PatchDataset, idx: int) -> dict: if not self.data_is_npy_alike: patch = self.load_img(patch) + if patch.shape[:-1] != tuple(self.patch_input_shape): + msg = ( + f"Patch size is not compatible with the model. " + f"Expected dimensions {tuple(self.patch_input_shape)}, but got " + f"{patch.shape[:-1]}." + ) + logger.error(msg=msg) + raise DimensionMismatchError( + expected_dims=tuple(self.patch_input_shape), + actual_dims=patch.shape[:-1], + ) + # Apply preprocessing to selected patch patch = self._preproc(patch) @@ -632,3 +642,40 @@ def __getitem__(self: PatchDataset, idx: int) -> dict: return data return data + + +def _validate_patch_stride_shape( + patch_input_shape: np.ndarray, stride_shape: np.ndarray +) -> None: + """Validate patch and stride shape inputs for semantic segmentation. + + Checks that both `patch_input_shape` and `stride_shape` are integer arrays of + length ≤ 2 and contain non-negative values. Raises a ValueError if any + condition fails. + + Parameters: + patch_input_shape (np.ndarray): + Shape of the input patch (e.g., height, width). + stride_shape (np.ndarray): + Stride dimensions used for patch extraction. + + Raises: + ValueError: + If either input is not a valid integer array of appropriate + shape and values. + + """ + if ( + not np.issubdtype(patch_input_shape.dtype, np.integer) + or np.size(patch_input_shape) > 2 # noqa: PLR2004 + or np.any(patch_input_shape < 0) + ): + msg = f"Invalid `patch_input_shape` value {patch_input_shape}." + raise ValueError(msg) + if ( + not np.issubdtype(stride_shape.dtype, np.integer) + or np.size(stride_shape) > 2 # noqa: PLR2004 + or np.any(stride_shape < 0) + ): + msg = f"Invalid `stride_shape` value {stride_shape}." + raise ValueError(msg) diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 4e9ea70f1..a9adc2779 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -1,17 +1,51 @@ -"""Defines Abstract Base Class for TIAToolbox Engines.""" +"""Abstract Base Class for TIAToolbox Deep Learning Engines. + +This module defines the `EngineABC` class, which serves as a base for implementing +deep learning inference workflows in TIAToolbox. It supports both patch-based and +whole slide image (WSI) processing, and provides a unified interface for model +initialization, data loading, inference, post-processing, and output saving. + +Classes: + - EngineABC: Abstract base class for deep learning engines. + - EngineABCRunParams: TypedDict for runtime configuration parameters. + +Functions: + - prepare_engines_save_dir: Utility to create or validate output directories. + +Features: + - Supports patch and WSI modes. + - Handles caching and memory-efficient inference using Dask. + - Integrates with TIAToolbox models and IO configurations. + - Outputs predictions in multiple formats including dict, zarr, and AnnotationStore. + +Intended Usage: + Subclass `EngineABC` to implement specific inference logic by overriding + abstract methods such as preprocessing, postprocessing, and model-specific behavior. + +Example: + >>> class MyEngine(EngineABC): + >>> def __init__(self, model, weights, verbose): + >>> super().__init__(model=model, weights=weights, verbose=verbose) + >>> # Implement base class functions and then call. + >>> engine = MyEngine(model="resnet18-kather100k") + >>> output = engine.run(images, patch_mode=True) + +""" from __future__ import annotations import copy -import shutil from abc import ABC from pathlib import Path from typing import TYPE_CHECKING, TypedDict +import dask +import dask.array as da import numpy as np import torch -import tqdm import zarr +from dask import compute +from dask.diagnostics import ProgressBar from torch import nn from typing_extensions import Unpack @@ -21,10 +55,10 @@ from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model from tiatoolbox.utils.misc import ( - dict_to_store, - dict_to_zarr, - write_to_zarr_in_cache_mode, + dict_to_store_patch_predictions, + get_tqdm, ) +from tiatoolbox.wsicore.wsireader import WSIReader, is_zarr from .io_config import ModelIOConfigABC @@ -36,133 +70,61 @@ from tiatoolbox.annotation import AnnotationStore from tiatoolbox.models.models_abc import ModelABC from tiatoolbox.type_hints import IntPair, Resolution, Units - from tiatoolbox.wsicore.wsireader import WSIReader - - -def prepare_engines_save_dir( - save_dir: os | Path | None, - *, - patch_mode: bool, - overwrite: bool = False, -) -> Path | None: - """Create a save directory. - - If patch_mode is False and the save directory is not defined, - this function will raise an error. - - If patch_mode is True and the save directory is defined it will - create save_dir otherwise returns None. - - Args: - save_dir (str or Path): - Path to output directory. - patch_mode(bool): - Whether to treat input image as a patch or WSI. - overwrite (bool): - Whether to overwrite the results. Default = False. - - Returns: - :class:`Path`: - Path to output directory. - - Raises: - OSError: - If the save directory is not defined. - - """ - if patch_mode is True: - if save_dir is not None: - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=overwrite) - return save_dir - - if save_dir is None: - msg = ( - "Input WSIs detected but no save directory provided." - "Please provide a 'save_dir'." - ) - raise OSError(msg) - - logger.info( - "When providing multiple whole slide images, " - "the outputs will be saved and the locations of outputs " - "will be returned to the calling function when `run()`" - "finishes successfully.", - ) - - save_dir = Path(save_dir) - save_dir.mkdir(parents=True, exist_ok=overwrite) - - return save_dir class EngineABCRunParams(TypedDict, total=False): - """Class describing the input parameters for the :func:`EngineABC.run()` method. + """Parameters for configuring the :func:`EngineABC.run()` method. - Attributes: + Optional Keys: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. batch_size (int): - Number of image patches to feed to the model in a forward pass. - cache_mode (bool): - Whether to run the Engine in cache_mode. For large datasets, - we recommend to set this to True to avoid out of memory errors. - For smaller datasets, the cache_mode is set to False as - the results can be saved in memory. - cache_size (int): - Specifies how many image patches to process in a batch when - cache_mode is set to True. If cache_size is less than the batch_size - batch_size is set to cache_size. + Number of image patches per forward pass. class_dict (dict): - Optional dictionary mapping classification outputs to class names. + Mapping of classification outputs to class names. device (str): - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. + Device to run the model on (e.g., "cpu", "cuda"). + See https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details. + input_resolutions (list[dict[Units, Resolution]]): + Resolution settings for input heads. Supported units are `level`, + `power`, and `mpp`. Keys should be "units" and "resolution", e.g., + [{"units": "mpp", "resolution": 0.25}]. See :class:`WSIReader` for details. ioconfig (ModelIOConfigABC): - Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. - return_labels (bool): - Whether to return the labels with the predictions. - num_loader_workers (int): - Number of workers used in :class:`torch.utils.data.DataLoader`. - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. + IO configuration (:class:`ModelIOConfigABC`) for model input/output. + memory_threshold (int): + Memory usage threshold (in percentage) to trigger caching behavior. + num_workers (int): + Number of workers for DataLoader and post-processing. output_file (str): - Output file name to save "zarr" or "db". If None, path to output is - returned by the engine. - patch_input_shape (tuple): - Shape of patches input to the model as tuple of height and width (HW). - Patches are requested at read resolution, not with respect to level 0, - and must be positive. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported - units are `level`, `power` and `mpp`. Keys should be "units" and - "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see - :class:`WSIReader` for details. + Filename for saving output (e.g., "zarr" or "annotationstore"). + patch_input_shape (IntPair): + Shape of input patches (height, width), requested at read resolution. + Must be positive. + return_labels (bool): + Whether to return labels with predictions. scale_factor (tuple[float, float]): - The scale factor to use when loading the annotations. All coordinates - will be multiplied by this factor to allow conversion of annotations - saved at non-baseline resolution to baseline. Should be model_mpp/slide_mpp. - stride_shape (tuple): - Stride used during WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. + Scale factor for annotations (model_mpp / slide_mpp). + Used to convert coordinates from non-baseline to baseline resolution. + stride_shape (IntPair): + Stride used during WSI processing, at requested read resolution. + Must be positive. Defaults to `patch_input_shape` if not provided. verbose (bool): - Whether to output logging information. + Whether to enable verbose logging. """ + auto_get_mask: bool batch_size: int - cache_mode: bool - cache_size: int class_dict: dict device: str + input_resolutions: list[dict[Units, Resolution]] ioconfig: ModelIOConfigABC - num_loader_workers: int - num_post_proc_workers: int + memory_threshold: int + num_workers: int output_file: str patch_input_shape: IntPair - input_resolutions: list[dict[Units, Resolution]] return_labels: bool scale_factor: tuple[float, float] stride_shape: IntPair @@ -172,9 +134,13 @@ class EngineABCRunParams(TypedDict, total=False): class EngineABC(ABC): # noqa: B024 """Abstract base class for TIAToolbox deep learning engines to run CNN models. + This class provides a unified interface for running inference on image patches + or whole slide images (WSIs), handling preprocessing, batching, postprocessing, + and saving predictions. + Args: model (str | ModelABC): - A PyTorch model. Default is `None`. + Model name from TIAToolbox or a PyTorch model instance. The user can request pretrained models from the toolbox model zoo using the list of pretrained models available at this `link `_ @@ -182,16 +148,11 @@ class EngineABC(ABC): # noqa: B024 be downloaded. However, you can override with your own set of weights using the `weights` parameter. batch_size (int): - Number of image patches fed into the model each time in a - forward/backward pass. Default value is 8. - num_loader_workers (int): - Number of workers to load the data using :class:`torch.utils.data.Dataset`. - Please note that they will also perform preprocessing. Default value is 0. - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. - Default value is 0. - weights (str or Path): - Path to the weight of the corresponding `model`. + Number of patches per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. >>> engine = EngineABC( ... model="pretrained-model", @@ -199,34 +160,35 @@ class EngineABC(ABC): # noqa: B024 ... ) device (str): - Select the device to run the model. Please see + Device to run the model on (e.g., "cpu", "cuda"). Please see https://pytorch.org/docs/stable/tensor_attributes.html#torch.device for more details on input parameters for device. Default is "cpu". verbose (bool): - Whether to output logging information. Default value is False. + Enable verbose logging. Default is False. Attributes: - images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + images (list[str | Path] | np.ndarray): + Input images or patches. A list of image patches in NHWC format as a numpy array or a list of str/paths to WSIs. - masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + masks (list[str | Path] | np.ndarray): + Optional masks for WSIs. A list of tissue masks or binary masks corresponding to processing area of input images. These can be a list of numpy arrays or paths to the saved image masks. These are only utilized when patch_mode is False. Patches are only generated within a masked area. If not provided, then a tissue mask will be automatically - generated for whole slide images. - patch_mode (str): - Whether to treat input images as a set of image patches. TIAToolbox defines + generated for whole slide images, if auto_get_mask is True. + patch_mode (bool): + Whether input is treated as patches. TIAToolbox defines an image as a patch if HWC of the input image matches with the HWC expected by the model. If HWC of the input image does not match with the HWC expected by the model, then the patch_mode must be set to False which will allow the engine to extract patches from the input image. In this case, when the patch_mode is False the input images are treated as WSIs. Default value is True. - model (str | ModelABC): - A PyTorch model or a name of an existing model from the TIAToolbox model zoo - for processing the data. For a full list of pretrained models, + model (ModelABC): + Loaded PyTorch model. For a full list of pretrained models, refer to the `docs `_ By default, the corresponding pretrained weights will also @@ -234,72 +196,59 @@ class EngineABC(ABC): # noqa: B024 of weights via the `weights` argument. Argument is case-insensitive. ioconfig (ModelIOConfigABC): - Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. - _ioconfig (ModelIOConfigABC): - Runtime ioconfig. + IO configuration (:class:`ModelIOConfigABC`) for model input/output. + dataloader (DataLoader): + Torch DataLoader for inference. return_labels (bool): - Whether to return the labels with the predictions. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported + Whether to return labels with probabilities and predictions. + input_resolutions (list[dict[Units, Resolution]]): + Resolution settings for input heads. Supported units are `level`, `power` and `mpp`. Keys should be "units" and "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see :class:`WSIReader` for details. - patch_input_shape (tuple): - Shape of patches input to the model as tupled of HW. Patches are at + patch_input_shape (IntPair): + Shape of input patches. Patches are at requested read resolution, not with respect to level 0, and must be positive. - stride_shape (tuple): + stride_shape (IntPair): Stride used during WSI processing. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. batch_size (int): - Number of images fed into the model each time. - cache_mode (bool): - Whether to run the Engine in cache_mode. For large datasets, - we recommend to set this to True to avoid out of memory errors. - For smaller datasets, the cache_mode is set to False as - the results can be saved in memory. cache_mode is always True when - processing WSIs i.e., when `patch_mode` is False. Default value is False. - cache_size (int): - Specifies how many image patches to process in a batch when - cache_mode is set to True. If cache_size is less than the batch_size - batch_size is set to cache_size. Default value is 10,000. + Number of patches per forward pass. labels (list | None): - List of labels. Only a single label per image is supported. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - num_loader_workers (int): - Number of workers used in :class:`torch.utils.data.DataLoader`. - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. + Optional labels for input images. Only a single label per image is + supported. + num_workers (int): + Number of workers for data loading. + patch_input_shape (IntPair | None): + Shape of input patches. + input_resolutions (list[dict[Units, Resolution]] | None): + Resolution settings for input heads. return_labels (bool): - Whether to return the output labels. Default value is False. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported - units are `level`, `power` and `mpp`. When `patch_mode` is `True`, - the input image patches are expected to be at the correct resolution and - units. When `patch_mode` is `False`, the patches are extracted at the - requested resolution and units. Default value is [{"units": "baseline", - "resolution": 1.0}]. + Whether to return labels with predictions. + stride_shape (IntPair | None): + Stride used during WSI processing. + verbose (bool): + Whether to enable verbose logging. + dataloader (DataLoader | None): + Torch DataLoader for inference. + drop_keys (list): + Keys to exclude from model output. + output_type (Any): + Format of output ("dict", "zarr", "AnnotationStore"). verbose (bool): - Whether to output logging information. Default value is False. + Whether to enable verbose logging. - Examples: + Example: >>> # Inherit from EngineABC - >>> class TestEngineABC(EngineABC): - >>> def __init__( - >>> self, - >>> model, - >>> weights, - >>> verbose, - >>> ): - >>> super().__init__(model=model, weights=weights, verbose=verbose) + >>> class MyEngine(EngineABC): + >>> def __init__(self, model, weights, verbose): + >>> super().__init__(model=model, weights=weights, verbose=verbose) + >>> engine = MyEngine(model="resnet18-kather100k") + >>> output = engine.run(images, patch_mode=True) + >>> # Define all the abstract classes >>> data = np.array([np.ndarray, np.ndarray]) @@ -327,14 +276,29 @@ def __init__( self: EngineABC, model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_post_proc_workers: int = 0, + num_workers: int = 0, weights: str | Path | None = None, *, device: str = "cpu", verbose: bool = False, ) -> None: - """Initialize Engine.""" + """Initialize the EngineABC instance. + + Args: + model (str | ModelABC): + Model name from TIAToolbox or a PyTorch model instance. + batch_size (int): + Number of patches per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Enable verbose logging. Default is False. + + """ self.images = None self.masks = None self.patch_mode = None @@ -342,8 +306,7 @@ def __init__( # Initialize model with specified weights and ioconfig. self.model, self.ioconfig = self._initialize_model_ioconfig( - model=model, - weights=weights, + model=model, weights=weights ) self.model.to(device=self.device) self.model = ( @@ -353,35 +316,32 @@ def __init__( ) ) self._ioconfig = self.ioconfig # runtime ioconfig - self.batch_size = batch_size - self.cache_mode: bool = False - self.cache_size: int = self.batch_size if self.batch_size else 10000 self.labels: list | None = None - self.num_loader_workers = num_loader_workers - self.num_post_proc_workers = num_post_proc_workers + self.num_workers = num_workers self.patch_input_shape: IntPair | None = None self.input_resolutions: list[dict[Units, Resolution]] | None = None self.return_labels: bool = False self.stride_shape: IntPair | None = None - self.verbose = verbose + self.verbose: bool = verbose + self.dataloader: DataLoader | None = None + self.drop_keys: list = [] + self.output_type = None @staticmethod def _initialize_model_ioconfig( model: str | ModelABC, weights: str | Path | None, ) -> tuple[nn.Module, ModelIOConfigABC | None]: - """Helper function to initialize model and ioconfig attributes. + """Helper function to initialize model and IO configuration. - If a pretrained model provided by the TIAToolbox is requested. The model - can be specified as a string otherwise :class:`torch.nn.Module` is required. - This function also loads the :class:`ModelIOConfigABC` using the information - from the pretrained models in TIAToolbox. If ioconfig is not available then it - should be provided in the :func:`run()` function. + If a pretrained model from TIAToolbox is specified by name, this function + loads the model and its associated IO configuration. If a custom model is + provided, it loads the weights if specified and returns None for IO config. Args: model (str | ModelABC): - A PyTorch model. Default is `None`. + A model name from TIAToolbox or a PyTorch model instance. The user can request pretrained models from the toolbox model zoo using the list of pretrained models available at this `link `_ @@ -390,18 +350,18 @@ def _initialize_model_ioconfig( of weights using the `weights` parameter. weights (str | Path | None): - Path to pretrained weights. If no pretrained weights are provided - and the `model` is provided by TIAToolbox, then pretrained weights will - be automatically loaded from the TIA servers. + Path to pretrained weights. If None and a TIAToolbox model is used, + default weights are automatically downloaded. Returns: - ModelABC: - The requested PyTorch model as a :class:`ModelABC` instance. + tuple[nn.Module, ModelIOConfigABC | None]: + A tuple containing the loaded PyTorch model and its IO configuration. + If the model is not from TIAToolbox, IO config will be None. - ModelIOConfigABC | None: - The model io configuration for TIAToolbox pretrained models. - If the specified model is not in TIAToolbox model zoo, then the function - returns None. + Raises: + TypeError: + If the model is neither a string (TIAToolbox model) + nor a torch.nn.Module. """ if not isinstance(model, (str, nn.Module)): @@ -428,29 +388,34 @@ def get_dataloader( ioconfig: ModelIOConfigABC | None = None, *, patch_mode: bool = True, + auto_get_mask: bool = True, ) -> torch.utils.data.DataLoader: - """Pre-process images and masks and return dataloader for inference. + """Pre-process images and masks and return a DataLoader for inference. Args: - images (list of str or :class:`Path` or :class:`numpy.ndarray`): - A list of image patches in NHWC format as a numpy array - or a list of str/paths to WSIs. When `patch_mode` is False - the function expects list of str/paths to WSIs. - masks (list | None): - List of masks. Only utilised when patch_mode is False. - Patches are only generated within a masked area. - If not provided, then a tissue mask will be automatically - generated for whole slide images. + images (list[str | Path] | np.ndarray): + A list of image patches in NHWC format as a numpy array, + or a list of file paths to WSIs. When `patch_mode` is False, + expects file paths to WSIs. + masks (Path | None): + Optional list of masks used when `patch_mode` is False. + Patches are generated only within masked areas. If not provided, + tissue masks are automatically generated. labels (list | None): - List of labels. Only a single label per image is supported. - ioconfig (ModelIOConfigABC): - A :class:`ModelIOConfigABC` object. + Optional list of labels. Only one label per image is supported. + ioconfig (ModelIOConfigABC | None): + IO configuration object specifying patch size, stride, and resolution. patch_mode (bool): - Whether to treat input image as a patch or WSI. + Whether to treat input as patches (`True`) or WSIs (`False`). + auto_get_mask (bool): + Whether to automatically generate a tissue mask using + `wsireader.tissue_mask()` when `patch_mode` is False. + If `True`, only tissue regions are processed. If `False`, + all patches are processed. Default is `True`. Returns: torch.utils.data.DataLoader: - :class:`torch.utils.data.DataLoader` for inference. + A PyTorch DataLoader configured for inference. """ if labels: @@ -459,13 +424,13 @@ def get_dataloader( if not patch_mode: dataset = WSIPatchDataset( - img_path=images, - mode="wsi", + input_img=images, mask_path=masks, patch_input_shape=ioconfig.patch_input_shape, stride_shape=ioconfig.stride_shape, resolution=ioconfig.input_resolutions[0]["resolution"], units=ioconfig.input_resolutions[0]["units"], + auto_get_mask=auto_get_mask, ) dataset.preproc_func = self.model.preproc_func @@ -473,294 +438,343 @@ def get_dataloader( # preprocessing must be defined with the dataset return torch.utils.data.DataLoader( dataset, - num_workers=self.num_loader_workers, + num_workers=self.num_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, + persistent_workers=self.num_workers > 0, ) - dataset = PatchDataset(inputs=images, labels=labels) + dataset = PatchDataset( + inputs=images, labels=labels, patch_input_shape=ioconfig.patch_input_shape + ) + dataset.preproc_func = self.model.preproc_func # preprocessing must be defined with the dataset return torch.utils.data.DataLoader( dataset, - num_workers=self.num_loader_workers, + num_workers=self.num_workers, batch_size=self.batch_size, drop_last=False, shuffle=False, ) - @staticmethod - def _update_model_output(raw_predictions: dict, raw_output: dict) -> dict: - """Helper function to append raw output during inference.""" - for key, value in raw_output.items(): - if raw_predictions[key] is None: - raw_predictions[key] = value - else: - raw_predictions[key] = np.append(raw_predictions[key], value, axis=0) + def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray: + """Extract coordinates for each image patch in a batch. - return raw_predictions + This method returns coordinates for each patch, either based on + the patch dimensions (if in patch mode) or from precomputed values + (if in WSI mode). - def _get_coordinates(self: EngineABC, batch_data: dict) -> np.ndarray: - """Helper function to collect coordinates for AnnotationStore.""" + Args: + batch_data (dict): + Dictionary containing batch data, including image and + optional coordinates. + + Returns: + np.ndarray: + Array of coordinates for each patch in the batch. + Shape is (N, 4), where N is the number of patches. + + """ if self.patch_mode: coordinates = [0, 0, *batch_data["image"].shape[1:3]] return np.tile(coordinates, reps=(batch_data["image"].shape[0], 1)) - return batch_data["coords"].numpy() + return np.array(batch_data["coords"]) def infer_patches( self: EngineABC, dataloader: DataLoader, - save_path: Path | None, *, return_coordinates: bool = False, - ) -> dict | Path: - """Runs model inference on image patches and returns output as a dictionary. + ) -> dict[str, da.Array]: + """Run model inference on image patches and return predictions. + + This method performs batched inference using a PyTorch DataLoader, + and accumulates predictions in Dask arrays. It supports optional inclusion + of coordinates and labels in the output. Args: dataloader (DataLoader): - An :class:`torch.utils.data.DataLoader` object to run inference. - save_path (Path | None): - If `cache_mode` is True then path to save zarr file must be provided. + PyTorch DataLoader containing image patches for inference. return_coordinates (bool): - Whether to save coordinates in the output. This is required when - this function is called by `infer_wsi` and `patch_mode` is False. + Whether to include coordinates in the output. Required when + called by `infer_wsi` and `patch_mode` is False. Returns: - dict or Path: - Result of model inference as a dictionary. Returns path to - saved zarr file if `cache_mode` is True. + dict[str, dask.array.Array]: + Dictionary containing prediction results as Dask arrays. + Keys include: + - "probabilities": Model output probabilities. + - "labels": Ground truth labels (if `return_labels` is True). + - "coordinates": Patch coordinates (if `return_coordinates` is + True). """ - progress_bar = None - - if self.verbose: - progress_bar = tqdm.tqdm( - total=len(dataloader), - leave=True, - ncols=80, - ascii=True, - position=0, - ) - keys = ["probabilities"] + probabilities, labels, coordinates = [], [], [] if self.return_labels: keys.append("labels") + labels = [] if return_coordinates: keys.append("coordinates") + coordinates = [] - raw_predictions = dict.fromkeys(keys) + # Main output dictionary + raw_predictions = dict(zip(keys, [[]] * len(keys))) - zarr_group = None - - if self.cache_mode: - zarr_group = zarr.open(save_path, mode="w") + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else self.dataloader + ) - for _, batch_data in enumerate(dataloader): + for batch_data in tqdm_loop: batch_output = self.model.infer_batch( self.model, batch_data["image"], device=self.device, ) - if return_coordinates: - batch_output["coordinates"] = self._get_coordinates(batch_data) - - if self.return_labels: # be careful of `s` - if isinstance(batch_data["label"], torch.Tensor): - batch_output["labels"] = batch_data["label"].numpy() - else: - batch_output["labels"] = batch_data["label"] - raw_predictions = self._update_model_output( - raw_predictions=raw_predictions, - raw_output=batch_output, + probabilities.append( + da.from_array( + batch_output, # probabilities + ) ) - if self.cache_mode: - zarr_group = write_to_zarr_in_cache_mode( - zarr_group=zarr_group, output_data_to_save=raw_predictions + if return_coordinates: + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) ) - raw_predictions = dict.fromkeys(keys) - if progress_bar: - progress_bar.update() + if self.return_labels: + labels.append(da.from_array(np.array(batch_data["label"]))) - if progress_bar: - progress_bar.close() + raw_predictions["probabilities"] = da.concatenate(probabilities, axis=0) - return save_path if self.cache_mode else raw_predictions + if return_coordinates: + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) + + if self.return_labels: + labels = [label.reshape(-1) for label in labels] + raw_predictions["labels"] = da.concatenate(labels, axis=0) + + return raw_predictions - def post_process_patches( + def post_process_patches( # skipcq: PYL-R0201 self: EngineABC, - raw_predictions: dict | Path, - **kwargs: Unpack[EngineABCRunParams], - ) -> dict | Path: + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], # noqa: ARG002 + prediction_dtype: type, # noqa: ARG002 + **kwargs: Unpack[EngineABCRunParams], # noqa: ARG002 + ) -> dask.array.Array: """Post-process raw patch predictions from inference. - The output of :func:`infer_patches()` with patch prediction information will be - post-processed using this function. The processed output will be saved in the - respective input format. If `cache_mode` is True, the function processes the - input using zarr group with size specified by `cache_size`. + This method applies a post-processing function (e.g., smoothing, filtering) + to the raw model predictions. It supports delayed execution using Dask + and returns a Dask array for efficient computation. Args: - raw_predictions (dict | Path): - A dictionary or path to zarr with patch prediction information. + raw_predictions (dask.array.Array): + Raw model predictions as a dask array. + prediction_shape (tuple[int, ...]): + Shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. **kwargs (EngineABCRunParams): - Keyword Args to update setup_patch_dataset() method attributes. See - :class:`EngineRunParams` for accepted keyword arguments. + Additional runtime parameters used for post-processing. Returns: - dict or Path: - Returns patch based output after post-processing. Returns path to - saved zarr file if `cache_mode` is True. + dask.array.Array: + Post-processed predictions as a Dask array. """ - _ = kwargs.get("return_labels") # Key values required for post-processing - - if self.cache_mode: # cache mode - _ = zarr.open(raw_predictions, mode="w") - return raw_predictions def save_predictions( self: EngineABC, - processed_predictions: dict | Path, + processed_predictions: dict, output_type: str, - save_dir: Path | None = None, - **kwargs: dict, + save_path: Path | None = None, + **kwargs: Unpack[EngineABCRunParams], ) -> dict | AnnotationStore | Path: - """Save model predictions. + """Save model predictions to disk or return them in memory. + + Depending on the output type, this method saves predictions as a zarr group, + an AnnotationStore (SQLite database), or returns them as a dictionary. Args: - processed_predictions (dict | Path): - A dictionary or path to zarr with model prediction information. - save_dir (Path): - Optional output path to directory to save the patch dataset output to a - `.zarr` or `.db` file, provided `patch_mode` is True. If the - `patch_mode` is False then `save_dir` is required. + processed_predictions (dict): + Dictionary containing processed model predictions. output_type (str): - The desired output type for resulting patch dataset. + Desired output format. + Supported values are "dict", "zarr", and "annotationstore". + save_path (Path | None): + Path to save the output file. + Required for "zarr" and "annotationstore" formats. **kwargs (EngineABCRunParams): - Keyword Args required to save the output. + Additional runtime parameters including: + - output_file: Name of the output file. + - scale_factor: Scaling factor for annotations. + - class_dict: Mapping of class indices to names. Returns: - dict or Path or :class:`AnnotationStore`: - If the `output_type` is "AnnotationStore", the function returns - the patch predictor output as an SQLiteStore containing Annotations - for each or the Path to a `.db` file depending on whether a - save_dir Path is provided. Otherwise, the function defaults to - returning patch predictor output, either as a dict or the Path to a - `.zarr` file depending on whether a save_dir Path is provided. + dict | AnnotationStore | Path: + - If output_type is "dict": returns predictions as a dictionary. + - If output_type is "zarr": returns path to saved zarr file. + - If output_type is "annotationstore": returns an AnnotationStore + or path to .db file. + + Raises: + TypeError: + If an unsupported output_type is provided. """ - if ( - self.cache_mode or not save_dir - ) and output_type.lower() != "annotationstore": - return processed_predictions + keys_to_compute = [k for k in processed_predictions if k not in self.drop_keys] + + if output_type.lower() == "zarr": + if is_zarr(save_path): + zarr_group = zarr.open(save_path, mode="r") + keys_to_compute = [k for k in keys_to_compute if k not in zarr_group] + write_tasks = [] + for key in keys_to_compute: + dask_array = processed_predictions[key].rechunk("auto") + task = dask_array.to_zarr( + url=save_path, + component=key, + compute=False, + ) + write_tasks.append(task) + msg = f"Saving output to {save_path}." + logger.info(msg=msg) + with ProgressBar(): + compute(*write_tasks) + + zarr_group = zarr.open(save_path, mode="r+") + for key in self.drop_keys: + if key in zarr_group: + del zarr_group[key] - save_path = Path(kwargs.get("output_file", save_dir / "output.db")) + return save_path + + values_to_compute = [processed_predictions[k] for k in keys_to_compute] + + # Compute all at once + computed_values = compute(*values_to_compute) + + # Assign computed values + processed_predictions = dict(zip(keys_to_compute, computed_values)) + + if output_type.lower() == "dict": + return processed_predictions if output_type.lower() == "annotationstore": + save_path = Path(kwargs.get("output_file", save_path.parent / "output.db")) + # scale_factor set from kwargs scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs class_dict = kwargs.get("class_dict") - processed_predictions_path: str | Path | None = None - - # Need to add support for zarr conversion. - if self.cache_mode: - processed_predictions_path = processed_predictions - processed_predictions = zarr.open(processed_predictions, mode="r") - - out_file = dict_to_store( + return dict_to_store_patch_predictions( processed_predictions, scale_factor, class_dict, save_path, ) - if processed_predictions_path is not None: - shutil.rmtree(processed_predictions_path) - return out_file - - return ( - dict_to_zarr( - processed_predictions, - save_path, - **kwargs, - ) - if isinstance(processed_predictions, dict) - else processed_predictions - ) + msg = f"Unsupported output type: {output_type}" + raise TypeError(msg) def infer_wsi( self: EngineABC, dataloader: DataLoader, save_path: Path, - **kwargs: EngineABCRunParams, - ) -> Path: - """Model inference on a WSI. + **kwargs: Unpack[EngineABCRunParams], + ) -> dict: + """Run model inference on a whole slide image (WSI). + + This method performs inference on a WSI using the provided DataLoader, + and accumulates predictions in Dask arrays. Optionally includes + coordinates and labels in the output. Args: dataloader (DataLoader): - A torch dataloader to process WSIs. - + PyTorch DataLoader configured for WSI processing. save_path (Path): Path to save the intermediate output. The intermediate output is saved in a zarr file. **kwargs (EngineABCRunParams): - Keyword Args to update setup_patch_dataset() method attributes. See - :class:`EngineRunParams` for accepted keyword arguments. + Additional runtime parameters used during inference. Returns: - save_path (Path): - Path to zarr file where intermediate output is saved. + dict: + Dictionary containing prediction results as Dask arrays. """ _ = kwargs.get("patch_mode", False) + _ = save_path return self.infer_patches( dataloader=dataloader, - save_path=save_path, return_coordinates=True, ) # This is not a static model for child classes. def post_process_wsi( # skipcq: PYL-R0201 self: EngineABC, - raw_predictions: dict | Path, - **kwargs: Unpack[EngineABCRunParams], - ) -> dict | Path: - """Post process WSI output. + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], # noqa: ARG002 + prediction_dtype: type, # noqa: ARG002 + **kwargs: Unpack[EngineABCRunParams], # noqa: ARG002 + ) -> dask.array.Array: + """Post-process predictions from whole slide image (WSI) inference. + + This method applies a post-processing function (e.g., smoothing, filtering) + to the raw model predictions. It supports delayed execution using Dask + and returns a Dask array for efficient computation. - Takes the raw output from patch predictions and post-processes it to improve the - results e.g., using information from neighbouring patches. + Args: + raw_predictions (dask.array.Array): + Raw model predictions as a Dask array. + prediction_shape (tuple[int, ...]): + Shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. + **kwargs (EngineABCRunParams): + Additional runtime parameters used for post-processing. + + Returns: + dask.array.Array: + Post-processed predictions as a Dask array. """ - _ = kwargs.get("return_labels") # Key values required for post-processing return raw_predictions def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfigABC: - """Helper function to load ioconfig. + """Load or validate the IO configuration for the engine. - If the model is provided by TIAToolbox it will load the default ioconfig. - Otherwise, ioconfig must be specified. + If the model is from TIAToolbox and no IO configuration is provided, + this method attempts to use the default configuration. Otherwise, + it validates and sets the provided configuration. Args: ioconfig (ModelIOConfigABC): - IO configuration to run the engines. - - Raises: - ValueError: - If no io configuration is provided or found in the pretrained TIAToolbox - models. + IO configuration to use for model inference. Returns: ModelIOConfigABC: - The ioconfig used for the run. + The IO configuration to be used during inference. + + Raises: + ValueError: + If no IO configuration is provided and none is available from the model. """ if self.ioconfig is None and ioconfig is None: @@ -782,29 +796,39 @@ def _update_ioconfig( stride_shape: IntPair, input_resolutions: list[dict[Units, Resolution]], ) -> ModelIOConfigABC: - """Update IOConfig. + """Update the IO configuration used for patch-based inference. + + This method updates the patch input shape, stride, and input resolutions + in the IO configuration. If no configuration is provided, it creates a new one. Args: - ioconfig (:class:`ModelIOConfigABC`): - Input ioconfig for PatchPredictor. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at + ioconfig (ModelIOConfigABC): + Existing IO configuration to update. If None, a new one is created. + patch_input_shape (IntPair): + Size of patches input to the model (height, width). Patches are at requested read resolution, not with respect to level 0, and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to + stride_shape (IntPair): + Stride used during patch extraction. + If None, defaults to patch_input_shape. + Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported - units are `level`, `power` and `mpp`. Keys should be "units" and - "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see + input_resolutions (list[dict[Units, Resolution]]): + List of dictionaries specifying resolution and units + for each input head. Supported units are `level`, `power` and `mpp`. + Keys should be "units" and "resolution" + e.g., [{"units": "mpp", "resolution": 0.25}]. Please see :class:`WSIReader` for details. Returns: - Updated Patch Predictor IO configuration. + ModelIOConfigABC: + Updated IO configuration for patch-based inference. + + Raises: + ValueError: + If neither an IO configuration nor patch/resolution parameters + are provided. """ config_flag = ( @@ -847,7 +871,26 @@ def _update_ioconfig( @staticmethod def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: - """Validate input images for a run.""" + """Validate the format and shape of input images or masks. + + Ensures that the input is either a list of file paths or a 4D NumPy array + in NHWC format. + + Args: + images (list | np.ndarray): + List of image paths or a NumPy array of image patches. + + Returns: + list | np.ndarray: + The validated input images or masks. + + Raises: + TypeError: + If the input is neither a list nor a NumPy array. + ValueError: + If the input is a NumPy array but not 4D (NHWC). + + """ if not isinstance(images, (list, np.ndarray)): msg = "Input must be a list of file paths or a numpy array." raise TypeError( @@ -866,10 +909,30 @@ def _validate_images_masks(images: list | np.ndarray) -> list | np.ndarray: @staticmethod def _validate_input_numbers( images: list | np.ndarray, - masks: list[os | Path] | np.ndarray | None = None, + masks: list[os.PathLike] | np.ndarray | None = None, labels: list | None = None, ) -> None: - """Validates number of input images, masks and labels.""" + """Validate that the number of images, masks, and labels match. + + Ensures that the lengths of masks and labels (if provided) match + the number of input images. + + Args: + images (list | np.ndarray): + List of input images or a NumPy array. + masks (list[PathLike] | np.ndarray | None): + Optional list of masks corresponding to the input images. + labels (list | None): + Optional list of labels corresponding to the input images. + + Returns: + None + + Raises: + ValueError: + If the number of masks or labels does not match the number of images. + + """ if masks is None and labels is None: return @@ -896,10 +959,10 @@ def _validate_input_numbers( def _update_run_params( self: EngineABC, - images: list[os | Path | WSIReader] | np.ndarray, - masks: list[os | Path] | np.ndarray | None = None, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, labels: list | None = None, - save_dir: os | Path | None = None, + save_dir: os.PathLike | Path | None = None, ioconfig: ModelIOConfigABC | None = None, output_type: str = "dict", *, @@ -907,26 +970,74 @@ def _update_run_params( patch_mode: bool, **kwargs: Unpack[EngineABCRunParams], ) -> Path | None: - """Updates runtime parameters. + """Update runtime parameters for the engine before running inference. + + This method sets internal attributes such as caching, batch size, + IO configuration, and output format based on user input and keyword arguments. + + Args: + images (list[PathLike | Path | WSIReader] | np.ndarray): + List of input images or a NumPy array of patches. + masks (list[PathLike | Path] | np.ndarray | None): + Optional list of masks for WSI processing. + labels (list | None): + Optional list of labels for input images. + save_dir (PathLike | Path | None): + Directory to save output files. Required for WSI mode. + ioconfig (ModelIOConfigABC | None): + IO configuration for patch extraction and resolution settings. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". + overwrite (bool): + Whether to overwrite existing output files. Default is False. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + **kwargs (EngineABCRunParams): + Additional runtime parameters to update engine attributes. - Updates runtime parameters for an EngineABC for EngineABC.run(). + Returns: + Path | None: + Path to the save directory if applicable, otherwise None. + + Raises: + TypeError: + If an unsupported output_type is provided. + ValueError: + If required configuration or input parameters are missing. """ for key in kwargs: setattr(self, key, kwargs.get(key)) - self.patch_mode = patch_mode - if not self.patch_mode: - self.cache_mode = True # if input is WSI run using cache mode. + if self.num_workers > 0: + dask.config.set(scheduler="threads", num_workers=self.num_workers) + else: + dask.config.set(scheduler="threads") - if self.cache_mode and self.batch_size > self.cache_size: - self.batch_size = self.cache_size + if not self.return_labels: + self.drop_keys.append("label") + + self.patch_mode = patch_mode self._validate_input_numbers(images=images, masks=masks, labels=labels) if output_type.lower() not in ["dict", "zarr", "annotationstore"]: msg = "output_type must be 'dict' or 'zarr' or 'annotationstore'." raise TypeError(msg) + self.output_type = output_type + + if save_dir is not None and output_type.lower() not in [ + "zarr", + "annotationstore", + ]: + self.output_type = "zarr" + msg = ( + f"output_type has been updated to 'zarr' " + f"for saving the file to {save_dir}." + f"Remove `save_dir` input to return the output as a `dict`." + ) + logger.info(msg) + self.images = self._validate_images_masks(images=images) if masks is not None: @@ -936,7 +1047,7 @@ def _update_run_params( # if necessary move model parameters to "cpu" or "gpu" and update ioconfig self._ioconfig = self._load_ioconfig(ioconfig=ioconfig) - self.model.to(device=self.device) + self.model = self.model.to(device=self.device) self._ioconfig = self._update_ioconfig( ioconfig, self.patch_input_shape, @@ -951,66 +1062,95 @@ def _update_run_params( ) def _run_patch_mode( - self: EngineABC, output_type: str, save_dir: Path, **kwargs: EngineABCRunParams + self: EngineABC, + output_type: str, + save_dir: Path, + **kwargs: EngineABCRunParams, ) -> dict | AnnotationStore | Path: - """Runs the Engine in the patch mode. + """Run the engine in patch mode. - Input arguments are passed from :func:`EngineABC.run()`. + This method performs inference on image patches, post-processes the predictions, + and saves the output in the specified format. + + Args: + output_type (str): + Desired output format. Supported values are "dict", "zarr", + and "annotationstore". + save_dir (Path): + Directory to save the output files. + **kwargs (EngineABCRunParams): + Additional runtime parameters including: + - output_file: Name of the output file. + - scale_factor: Scaling factor for annotations. + - class_dict: Mapping of class indices to names. + + Returns: + dict | AnnotationStore | Path: + - If output_type is "dict": returns predictions as a dictionary. + - If output_type is "zarr": returns path to saved zarr file. + - If output_type is "annotationstore": returns an AnnotationStore + or path to .db file. """ save_path = None - if self.cache_mode: + if save_dir: output_file = Path(kwargs.get("output_file", "output.zarr")) save_path = save_dir / (str(output_file.stem) + ".zarr") duplicate_filter = DuplicateFilter() logger.addFilter(duplicate_filter) - dataloader = self.get_dataloader( + self.dataloader = self.get_dataloader( images=self.images, masks=self.masks, labels=self.labels, patch_mode=True, + ioconfig=self._ioconfig, ) raw_predictions = self.infer_patches( - dataloader=dataloader, - save_path=save_path, + dataloader=self.dataloader, return_coordinates=output_type == "annotationstore", ) - processed_predictions = self.post_process_patches( - raw_predictions=raw_predictions, + + raw_predictions["predictions"] = self.post_process_patches( + raw_predictions=raw_predictions["probabilities"], + prediction_shape=raw_predictions["probabilities"].shape[:-1], + prediction_dtype=raw_predictions["probabilities"].dtype, **kwargs, ) + logger.removeFilter(duplicate_filter) out = self.save_predictions( - processed_predictions=processed_predictions, + processed_predictions=raw_predictions, output_type=output_type, - save_dir=save_dir, + save_path=save_path, **kwargs, ) - if save_dir: - msg = f"Output file saved at {out}." - logger.info(msg=msg) - return out - + msg = f"Output file saved at {out}." + logger.info(msg=msg) return out @staticmethod def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, float]: - """Calculates scale factor for final output. + """Calculate the scale factor for final output based on dataloader resolution. - Uses the dataloader resolution and the WSI resolution to calculate scale - factor for final WSI output. + This method compares the resolution used during reading with the slide's + baseline resolution to compute a scale factor for coordinate transformation. Args: dataloader (DataLoader): - Dataloader for the current run. + PyTorch DataLoader used for WSI inference. Must contain resolution + and unit metadata in its dataset. Returns: - scale_factor (float | tuple[float, float]): - Scale factor for final output. + float | tuple[float, float]: + Scale factor for converting coordinates to baseline resolution. + - If units are "mpp": returns (model_mpp / slide_mpp). + - If units are "level": returns downsample ratio. + - If units are "power": returns objective_power / model_power. + - If units are "baseline": returns the resolution directly. """ # get units and resolution from dataloader. @@ -1048,122 +1188,151 @@ def _run_wsi_mode( save_dir: Path, **kwargs: Unpack[EngineABCRunParams], ) -> dict | AnnotationStore | Path: - """Runs the Engine in the WSI mode (patch_mode = False). + """Run the engine in WSI mode (patch_mode = False). + + This method performs inference on each whole slide image (WSI), + post-processes the predictions, and saves the output in the specified format. + + Args: + output_type (str): + Desired output format. Supported values are "dict", "zarr", + and "annotationstore". + save_dir (Path): + Directory to save the output files. + **kwargs (EngineABCRunParams): + Additional runtime parameters including: + - output_file: Name of the output file. + - scale_factor: Scaling factor for annotations. + - class_dict: Mapping of class indices to names. - Input arguments are passed from :func:`EngineABC.run()`. + Returns: + dict | AnnotationStore | Path: + Dictionary mapping each input WSI to its corresponding output path. + Output may be a zarr file, SQLite database, or in-memory dictionary. """ + progress_bar = None + tqdm = get_tqdm() + + if self.verbose: + progress_bar = tqdm( + total=len(self.images), + desc="Processing WSIs", + ) suffix = ".zarr" if output_type == "AnnotationStore": suffix = ".db" - out = {image: save_dir / (str(image.stem) + suffix) for image in self.images} + def get_path(image: Path | WSIReader) -> Path: + """Return path to output file.""" + return image.input_path if isinstance(image, WSIReader) else image + + out = { + get_path(image): save_dir / (get_path(image).stem + suffix) + for image in self.images + } save_path = { - image: save_dir / (str(image.stem) + ".zarr") for image in self.images + get_path(image): save_dir / (get_path(image).stem + ".zarr") + for image in self.images } for image_num, image in enumerate(self.images): duplicate_filter = DuplicateFilter() logger.addFilter(duplicate_filter) mask = self.masks[image_num] if self.masks is not None else None - dataloader = self.get_dataloader( + self.dataloader = self.get_dataloader( images=image, masks=mask, patch_mode=False, ioconfig=self._ioconfig, + auto_get_mask=kwargs.get("auto_get_mask", True), ) - scale_factor = self._calculate_scale_factor(dataloader=dataloader) + scale_factor = self._calculate_scale_factor(dataloader=self.dataloader) raw_predictions = self.infer_wsi( - dataloader=dataloader, - save_path=save_path[image], + dataloader=self.dataloader, + save_path=save_path[get_path(image)], **kwargs, ) - processed_predictions = self.post_process_wsi( - raw_predictions=raw_predictions, + + raw_predictions["predictions"] = self.post_process_wsi( + raw_predictions=raw_predictions["probabilities"], + prediction_shape=raw_predictions["probabilities"].shape[:-1], + prediction_dtype=raw_predictions["probabilities"].dtype, **kwargs, ) - kwargs["output_file"] = out[image] + + kwargs["output_file"] = out[get_path(image)] kwargs["scale_factor"] = scale_factor - out[image] = self.save_predictions( - processed_predictions=processed_predictions, + out[get_path(image)] = self.save_predictions( + processed_predictions=raw_predictions, output_type=output_type, - save_dir=save_dir, + save_path=save_path[get_path(image)], **kwargs, ) logger.removeFilter(duplicate_filter) - msg = f"Output file saved at {out[image]}." + msg = f"Output file saved at {out[get_path(image)]}." logger.info(msg=msg) + if progress_bar: + progress_bar.update() + + if progress_bar: + progress_bar.close() + return out def run( self: EngineABC, - images: list[os | Path | WSIReader] | np.ndarray, - masks: list[os | Path] | np.ndarray | None = None, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, labels: list | None = None, ioconfig: ModelIOConfigABC | None = None, *, patch_mode: bool = True, - save_dir: os | Path | None = None, # None will not save output + save_dir: os.PathLike | Path | None = None, overwrite: bool = False, output_type: str = "dict", **kwargs: Unpack[EngineABCRunParams], ) -> AnnotationStore | Path | str | dict: """Run the engine on input images. + This method orchestrates the full inference pipeline, including preprocessing, + model inference, post-processing, and saving results. It supports both patch + and WSI modes. + Args: - images (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. - masks (list | None): - List of masks. Only utilised when patch_mode is False. + images (list[PathLike | Path | WSIReader] | np.ndarray): + List of input images or a NumPy array of patches. + masks (list[PathLike | Path] | np.ndarray | None): + Optional list of masks for WSI processing. + Only utilised when patch_mode is False. Patches are only generated within a masked area. If not provided, then a tissue mask will be automatically generated for whole slide images. labels (list | None): - List of labels. Only a single label per image is supported. + Optional list of labels for input images. + ioconfig (ModelIOConfigABC | None): + IO configuration for patch extraction and resolution settings. patch_mode (bool): - Whether to treat input image as a patch or WSI. - default = True. - ioconfig (IOPatchPredictorConfig): - IO configuration. - save_dir (str or pathlib.Path): - Output directory to save the results. - If save_dir is not provided when patch_mode is False, - then for a single image the output is created in the current directory. - If there are multiple WSIs as input then the user must provide - path to save directory otherwise an OSError will be raised. + Whether to treat input as patches (`True`) or WSIs (`False`). + Default is True. + save_dir (PathLike | Path | None): + Directory to save output files. Required for WSI mode. overwrite (bool): - Whether to overwrite the results. Default = False. + Whether to overwrite existing output files. Default is False. output_type (str): - The format of the output type. "output_type" can be - "dict", "zarr" or "AnnotationStore". Default value is "zarr". - When saving in the zarr format the output is saved using the - `python zarr library `__ - as a zarr group. If the required output type is an "AnnotationStore" - then the output will be intermediately saved as zarr but converted - to :class:`AnnotationStore` and saved as a `.db` file - at the end of the loop. + Desired output format: "dict", "zarr", or "annotationstore". **kwargs (EngineABCRunParams): - Keyword Args to update :class:`EngineABC` attributes during runtime. + Additional runtime parameters to update engine attributes. Returns: - (:class:`numpy.ndarray`, dict): - Model predictions of the input dataset. If multiple - whole slide images are provided as input, - or save_output is True, then results are saved to - `save_dir` and a dictionary indicating save location for - each input is returned. - - The dict has the following format: - - - img_path: path of the input image. - - raw: path to save location for raw prediction, - saved in .json. + AnnotationStore | Path | str | dict: + - If patch_mode is True: returns predictions or path to saved output. + - If patch_mode is False: returns a dictionary mapping each WSI to + its output path. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] @@ -1201,7 +1370,7 @@ def run( if patch_mode: return self._run_patch_mode( - output_type=output_type, + output_type=self.output_type, save_dir=save_dir, **kwargs, ) @@ -1211,7 +1380,59 @@ def run( # pre-processing, post-processing and save_output # for WSIs separately. return self._run_wsi_mode( - output_type=output_type, + output_type=self.output_type, save_dir=save_dir, **kwargs, ) + + +def prepare_engines_save_dir( + save_dir: str | Path | None, + *, + patch_mode: bool, + overwrite: bool = False, +) -> Path | None: + """Create or validate the save directory for engine outputs. + + Args: + save_dir (str | Path | None): + Path to the output directory. + patch_mode (bool): + Whether the input is treated as patches. + overwrite (bool): + Whether to overwrite existing directory. Default is False. + + Returns: + Path | None: + Path to the output directory if created or validated, else None. + + Raises: + OSError: + If patch_mode is False and save_dir is not provided. + + """ + if patch_mode: + if save_dir is not None: + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=overwrite) + return save_dir + return None + + if save_dir is None: + msg = ( + "Input WSIs detected but no save directory provided. " + "Please provide a 'save_dir'." + ) + raise OSError(msg) + + logger.info( + "When providing multiple whole slide images, " + "the outputs will be saved and the locations of outputs " + "will be returned to the calling function when `run()` " + "finishes successfully." + ) + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=overwrite) + + return save_dir diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 23e6b1804..b940598ca 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -1,19 +1,38 @@ -"""Defines PatchPredictor Engine.""" +"""Defines the PatchPredictor engine for patch-level inference in digital pathology. + +This module implements the PatchPredictor class, which extends the EngineABC base +class to support patch-based and whole slide image (WSI) inference using deep learning +models from TIAToolbox. It provides utilities for model initialization, post-processing, +and output management, including support for multiple output formats. + +Classes: + - PatchPredictor: + Engine for performing patch-level predictions. + - PredictorRunParams: + TypedDict for configuring runtime parameters. + +Example: + >>> images = [np.ndarray, np.ndarray] + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(images, patch_mode=True) + +""" from __future__ import annotations -import math from typing import TYPE_CHECKING -import zarr from typing_extensions import Unpack +from tiatoolbox.utils.misc import cast_to_min_dtype + from .engine_abc import EngineABC, EngineABCRunParams if TYPE_CHECKING: # pragma: no cover import os from pathlib import Path + import dask.array as da import numpy as np from tiatoolbox.annotation import AnnotationStore @@ -22,60 +41,44 @@ from tiatoolbox.wsicore import WSIReader -class PredictorRunParams(EngineABCRunParams): - """Class describing the input parameters for the :func:`EngineABC.run()` method. +class PredictorRunParams(EngineABCRunParams, total=False): + """Parameters for configuring the `PatchPredictor.run()` method. + + This class extends `EngineABCRunParams` with additional parameters specific + to patch-level prediction workflows. Attributes: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. batch_size (int): Number of image patches to feed to the model in a forward pass. - cache_mode (bool): - Whether to run the Engine in cache_mode. For large datasets, - we recommend to set this to True to avoid out of memory errors. - For smaller datasets, the cache_mode is set to False as - the results can be saved in memory. - cache_size (int): - Specifies how many image patches to process in a batch when - cache_mode is set to True. If cache_size is less than the batch_size - batch_size is set to cache_size. class_dict (dict): Optional dictionary mapping classification outputs to class names. device (str): - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. + Device to run the model on (e.g., "cpu", "cuda"). + input_resolutions (list[dict]): + Resolution used for reading the image. See `WSIReader` for details. ioconfig (ModelIOConfigABC): - Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine. - return_labels (bool): - Whether to return the labels with the predictions. - num_loader_workers (int): - Number of workers used in :class:`torch.utils.data.DataLoader`. - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. + Input/output configuration for patch extraction and resolution. + memory_threshold (int): + Memory usage threshold (in percentage) to trigger caching behavior. + num_workers (int): + Number of workers used in DataLoader. output_file (str): - Output file name to save "zarr" or "db". If None, path to output is - returned by the engine. - patch_input_shape (tuple): - Shape of patches input to the model as tuple of height and width (HW). - Patches are requested at read resolution, not with respect to level 0, - and must be positive. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported - units are `level`, `power` and `mpp`. Keys should be "units" and - "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see - :class:`WSIReader` for details. + Output file name for saving results (e.g., .zarr or .db). + patch_input_shape (tuple[int, int]): + Shape of input patches (height, width). + return_labels (bool): + Whether to return labels with predictions. return_probabilities (bool): - Whether to return per-class probabilities. + Whether to return per-class probabilities in the output. + If False, only predicted labels are returned. scale_factor (tuple[float, float]): - The scale factor to use when loading the - annotations. All coordinates will be multiplied by this factor to allow - conversion of annotations saved at non-baseline resolution to baseline. - Should be model_mpp/slide_mpp. - stride_shape (tuple): - Stride used during WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. + Scale factor for converting annotations to baseline resolution. + Typically model_mpp / slide_mpp. + stride_shape (tuple[int, int]): + Stride used during WSI processing. Defaults to patch_input_shape. verbose (bool): Whether to output logging information. @@ -85,112 +88,113 @@ class PredictorRunParams(EngineABCRunParams): class PatchPredictor(EngineABC): - r"""Patch level prediction for digital histology images. - - The models provided by TIAToolbox should give the following results: - - .. list-table:: PatchPredictor performance on the Kather100K dataset [1] - :widths: 15 15 - :header-rows: 1 - - * - Model name - - F\ :sub:`1`\ score - * - alexnet-kather100k - - 0.965 - * - resnet18-kather100k - - 0.990 - * - resnet34-kather100k - - 0.991 - * - resnet50-kather100k - - 0.989 - * - resnet101-kather100k - - 0.989 - * - resnext50_32x4d-kather100k - - 0.992 - * - resnext101_32x8d-kather100k - - 0.991 - * - wide_resnet50_2-kather100k - - 0.989 - * - wide_resnet101_2-kather100k - - 0.990 - * - densenet121-kather100k - - 0.993 - * - densenet161-kather100k - - 0.992 - * - densenet169-kather100k - - 0.992 - * - densenet201-kather100k - - 0.991 - * - mobilenet_v2-kather100k - - 0.990 - * - mobilenet_v3_large-kather100k - - 0.991 - * - mobilenet_v3_small-kather100k - - 0.992 - * - googlenet-kather100k - - 0.992 - - .. list-table:: PatchPredictor performance on the PCam dataset [2] - :widths: 15 15 - :header-rows: 1 - - * - Model name - - F\ :sub:`1`\ score - * - alexnet-pcam - - 0.840 - * - resnet18-pcam - - 0.888 - * - resnet34-pcam - - 0.889 - * - resnet50-pcam - - 0.892 - * - resnet101-pcam - - 0.888 - * - resnext50_32x4d-pcam - - 0.900 - * - resnext101_32x8d-pcam - - 0.892 - * - wide_resnet50_2-pcam - - 0.901 - * - wide_resnet101_2-pcam - - 0.898 - * - densenet121-pcam - - 0.897 - * - densenet161-pcam - - 0.893 - * - densenet169-pcam - - 0.895 - * - densenet201-pcam - - 0.891 - * - mobilenet_v2-pcam - - 0.899 - * - mobilenet_v3_large-pcam - - 0.895 - * - mobilenet_v3_small-pcam - - 0.890 - * - googlenet-pcam - - 0.867 + r"""Patch-level prediction engine for digital histology images. + + This class extends `EngineABC` to support patch-based inference using + pretrained or custom models from TIAToolbox. It supports both patch and + whole slide image (WSI) modes, and provides utilities for post-processing + and saving predictions. + + Supported Models: + .. list-table:: PatchPredictor performance on the Kather100K dataset [1]. + :widths: 15 15 + :header-rows: 1 + + * - Model name + - F\ :sub:`1`\ score + * - alexnet-kather100k + - 0.965 + * - resnet18-kather100k + - 0.990 + * - resnet34-kather100k + - 0.991 + * - resnet50-kather100k + - 0.989 + * - resnet101-kather100k + - 0.989 + * - resnext50_32x4d-kather100k + - 0.992 + * - resnext101_32x8d-kather100k + - 0.991 + * - wide_resnet50_2-kather100k + - 0.989 + * - wide_resnet101_2-kather100k + - 0.990 + * - densenet121-kather100k + - 0.993 + * - densenet161-kather100k + - 0.992 + * - densenet169-kather100k + - 0.992 + * - densenet201-kather100k + - 0.991 + * - mobilenet_v2-kather100k + - 0.990 + * - mobilenet_v3_large-kather100k + - 0.991 + * - mobilenet_v3_small-kather100k + - 0.992 + * - googlenet-kather100k + - 0.992 + + .. list-table:: PatchPredictor performance on the PCam dataset [2] + :widths: 15 15 + :header-rows: 1 + + * - Model name + - F\ :sub:`1`\ score + * - alexnet-pcam + - 0.840 + * - resnet18-pcam + - 0.888 + * - resnet34-pcam + - 0.889 + * - resnet50-pcam + - 0.892 + * - resnet101-pcam + - 0.888 + * - resnext50_32x4d-pcam + - 0.900 + * - resnext101_32x8d-pcam + - 0.892 + * - wide_resnet50_2-pcam + - 0.901 + * - wide_resnet101_2-pcam + - 0.898 + * - densenet121-pcam + - 0.897 + * - densenet161-pcam + - 0.893 + * - densenet169-pcam + - 0.895 + * - densenet201-pcam + - 0.891 + * - mobilenet_v2-pcam + - 0.899 + * - mobilenet_v3_large-pcam + - 0.895 + * - mobilenet_v3_small-pcam + - 0.890 + * - googlenet-pcam + - 0.867 Args: model (str | ModelABC): - A PyTorch model or name of pretrained model. + A PyTorch model instance or name of a pretrained model from TIAToolbox. + If a string is provided, pretrained weights + will be downloaded unless overridden via `weights`. The user can request pretrained models from the toolbox model zoo using the list of pretrained models available at this `link `_ By default, the corresponding pretrained weights will also - be downloaded. However, you can override with your own set - of weights using the `weights` parameter. Default is `None`. + be downloaded. batch_size (int): - Number of image patches fed into the model each time in a - forward/backward pass. Default value is 8. - num_loader_workers (int): - Number of workers to load the data using :class:`torch.utils.data.Dataset`. - Please note that they will also perform preprocessing. Default value is 0. - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. - Default value is 0. - weights (str or Path): - Path to the weight of the corresponding `model`. + Number of image patches processed per forward pass. + Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. >>> engine = PatchPredictor( ... model="pretrained-model", @@ -198,98 +202,50 @@ class PatchPredictor(EngineABC): ... ) device (str): - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default is "cpu". + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". verbose (bool): - Whether to output logging information. Default value is False. + Whether to enable verbose logging. Default is True. + Attributes: - images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): - A list of image patches in NHWC format as a numpy array - or a list of str/paths to WSIs. - masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): - A list of tissue masks or binary masks corresponding to processing area of - input images. These can be a list of numpy arrays or paths to - the saved image masks. These are only utilized when patch_mode is False. - Patches are only generated within a masked area. + images (list[str | Path] | np.ndarray): + Input image patches or WSI paths. + masks (list[str | Path] | np.ndarray): + Optional tissue masks for WSI processing. + These are only utilized when patch_mode is False. If not provided, then a tissue mask will be automatically generated for whole slide images. - patch_mode (str): - Whether to treat input images as a set of image patches. TIAToolbox defines - an image as a patch if HWC of the input image matches with the HWC expected - by the model. If HWC of the input image does not match with the HWC expected - by the model, then the patch_mode must be set to False which will allow the - engine to extract patches from the input image. - In this case, when the patch_mode is False the input images are treated - as WSIs. Default value is True. - model (str | ModelABC): - A PyTorch model or a name of an existing model from the TIAToolbox model zoo - for processing the data. For a full list of pretrained models, - refer to the `docs - `_ - By default, the corresponding pretrained weights will also - be downloaded. However, you can override with your own set - of weights via the `weights` argument. Argument - is case-insensitive. + patch_mode (bool): + Whether input is treated as patches (`True`) or WSIs (`False`). + model (ModelABC): + Loaded PyTorch model. ioconfig (ModelIOConfigABC): - Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. - _ioconfig (ModelIOConfigABC): - Runtime ioconfig. + IO configuration for patch extraction and resolution. return_labels (bool): - Whether to return the labels with the predictions. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported + Whether to include labels in the output. + input_resolutions (list[dict]): + Resolution settings for model input. Supported units are `level`, `power` and `mpp`. Keys should be "units" and "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see :class:`WSIReader` for details. - patch_input_shape (tuple): - Shape of patches input to the model as tupled of HW. Patches are at + patch_input_shape (tuple[int, int]): + Shape of input patches (height, width). Patches are at requested read resolution, not with respect to level 0, and must be positive. - stride_shape (tuple): - Stride used during WSI processing. Stride is + stride_shape (tuple[int, int]): + Stride used during patch extraction. Stride is at requested read resolution, not with respect to level 0, and must be positive. If not provided, `stride_shape=patch_input_shape`. - batch_size (int): - Number of images fed into the model each time. - cache_mode (bool): - Whether to run the Engine in cache_mode. For large datasets, - we recommend to set this to True to avoid out of memory errors. - For smaller datasets, the cache_mode is set to False as - the results can be saved in memory. cache_mode is always True when - processing WSIs i.e., when `patch_mode` is False. Default value is False. - cache_size (int): - Specifies how many image patches to process in a batch when - cache_mode is set to True. If cache_size is less than the batch_size - batch_size is set to cache_size. Default value is 10,000. labels (list | None): - List of labels. Only a single label per image is supported. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - num_loader_workers (int): - Number of workers used in :class:`torch.utils.data.DataLoader`. - num_post_proc_workers (int): - Number of workers to postprocess the results of the model. - return_labels (bool): - Whether to return the output labels. Default value is False. - input_resolutions (list(dict(Units, Resolution))): - List of Python dictionaries with units and resolution for each - input head for model inference for reading the image. Supported - units are `level`, `power` and `mpp`. When `patch_mode` is `True`, - the input image patches are expected to be at the correct resolution and - units. When `patch_mode` is `False`, the patches are extracted at the - requested resolution and units. Default value is [{"units": "baseline", - "resolution": 1.0}]. - verbose (bool): - Whether to output logging information. Default value is False. - - Examples: + Optional labels for input images. + Only a single label per image is supported. + drop_keys (list): + Keys to exclude from model output. + output_type (str): + Format of output ("dict", "zarr", "annotationstore"). + + Example: >>> # list of 2 image patches as input >>> data = ['path/img.svs', 'path/img.svs'] >>> predictor = PatchPredictor(model="resnet18-kather100k") @@ -330,182 +286,219 @@ def __init__( self: PatchPredictor, model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_post_proc_workers: int = 0, + num_workers: int = 0, weights: str | Path | None = None, *, device: str = "cpu", verbose: bool = True, ) -> None: - """Initialize :class:`PatchPredictor`.""" + """Initialize the PatchPredictor engine. + + Args: + model (str | ModelABC): + A PyTorch model instance or name of a pretrained model from TIAToolbox. + If a string is provided, the corresponding pretrained + weights will be downloaded unless overridden via `weights`. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): Path to model weights. + If None, default weights are used. + device (str): + device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. + + """ super().__init__( model=model, batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_post_proc_workers=num_post_proc_workers, + num_workers=num_workers, weights=weights, device=device, verbose=verbose, ) - def post_process_cache_mode( - self: PatchPredictor, - raw_predictions: Path, - **kwargs: Unpack[PredictorRunParams], - ) -> Path: - """Returns an array from raw predictions.""" - return_probabilities = kwargs.get("return_probabilities") - zarr_group = zarr.open(str(raw_predictions), mode="r+") - - num_iter = math.ceil(len(zarr_group["probabilities"]) / self.batch_size) - start = 0 - for _ in range(num_iter): - # Probabilities for post-processing - probabilities = zarr_group["probabilities"][start : start + self.batch_size] - start = start + self.batch_size - predictions = self.model.postproc_func( - probabilities, - ) - if "predictions" in zarr_group: - zarr_group["predictions"].append(predictions) - continue - - zarr_dataset = zarr_group.create_dataset( - name="predictions", - shape=predictions.shape, - compressor=zarr_group["probabilities"].compressor, - ) - zarr_dataset[:] = predictions - - if return_probabilities is not False: - return raw_predictions - - del zarr_group["probabilities"] - - return raw_predictions - def post_process_patches( self: PatchPredictor, - raw_predictions: dict | Path, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, **kwargs: Unpack[PredictorRunParams], - ) -> dict | Path: - """Post-process raw patch predictions from inference. + ) -> da.Array: + """Post-process raw patch predictions from model inference. - The output of :func:`infer_patches()` with patch prediction information will be - post-processed using this function. The processed output will be saved in the - respective input format. If `cache_mode` is True, the function processes the - input using zarr group with size specified by `cache_size`. + This method applies the model's post-processing function to the raw predictions + obtained from `infer_patches()`. The output is wrapped in a Dask array for + efficient computation and memory handling. Args: - raw_predictions (dict | Path): - A dictionary or path to zarr with patch prediction information. + raw_predictions (da.Array | np.ndarray): + Raw model predictions. + prediction_shape (tuple[int, ...]): + Expected shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. **kwargs (PredictorRunParams): - Keyword Args to update setup_patch_dataset() method attributes. See - :class:`PredictorRunParams` for accepted keyword arguments. + Additional runtime parameters, including `return_probabilities`. Returns: - dict or Path: - Returns patch based output after post-processing. Returns path to - saved zarr file if `cache_mode` is True. + dask.array.Array: Post-processed predictions as a Dask array. """ - return_probabilities = kwargs.get("return_probabilities") - if self.cache_mode: - return self.post_process_cache_mode(raw_predictions, **kwargs) - - probabilities = raw_predictions.get("probabilities") + _ = kwargs.get("return_probabilities") + _ = prediction_shape + _ = prediction_dtype + raw_predictions = self.model.postproc_func(raw_predictions) + return cast_to_min_dtype(raw_predictions) - predictions = self.model.postproc_func( - probabilities, - ) + def post_process_wsi( + self: PatchPredictor, + raw_predictions: da.Array, + prediction_shape: tuple[int, ...], + prediction_dtype: type, + **kwargs: Unpack[PredictorRunParams], + ) -> da.Array: + """Post-process predictions from whole slide image (WSI) inference. - raw_predictions["predictions"] = predictions + This method refines the raw patch-level predictions obtained from WSI inference. + It typically applies spatial smoothing or other contextual operations using + neighboring patch information. Internally, it delegates to + `post_process_patches()`. - if return_probabilities is not False: - return raw_predictions + Args: + raw_predictions (dask.array.Array): + Raw model predictions. + prediction_shape (tuple[int, ...]): + Expected shape of the prediction output. + prediction_dtype (type): + Data type of the prediction output. + **kwargs (PredictorRunParams): + Additional runtime parameters, including `return_probabilities`. - del raw_predictions["probabilities"] + Returns: + dask.array.Array: Post-processed predictions as a Dask array. - return raw_predictions + """ + return self.post_process_patches( + raw_predictions=raw_predictions, + prediction_shape=prediction_shape, + prediction_dtype=prediction_dtype, + **kwargs, + ) - def post_process_wsi( + def _update_run_params( self: PatchPredictor, - raw_predictions: dict | Path, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + save_dir: os.PathLike | Path | None = None, + ioconfig: ModelIOConfigABC | None = None, + output_type: str = "dict", + *, + overwrite: bool = False, + patch_mode: bool, **kwargs: Unpack[PredictorRunParams], - ) -> dict | Path: - """Post process WSI output. + ) -> Path | None: + """Update runtime parameters for the PatchPredictor engine. - Takes the raw output from patch predictions and post-processes it to improve the - results e.g., using information from neighbouring patches. + This method sets internal attributes such as caching, batch size, + IO configuration, and output format based on user input and keyword arguments. + It also configures whether to include probabilities in the output. + + Args: + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. + labels (list | None): + Optional labels for input images. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + ioconfig (ModelIOConfigABC | None): + IO configuration for patch extraction and resolution. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". + overwrite (bool): + Whether to overwrite existing output files. Default is False. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + **kwargs (PredictorRunParams): + Additional runtime parameters. + + Returns: + Path | None: + Path to the save directory if applicable, otherwise None. """ - return self.post_process_cache_mode(raw_predictions, **kwargs) + return_probabilities = kwargs.get("return_probabilities") + if not return_probabilities: + self.drop_keys.append("probabilities") + return super()._update_run_params( + images=images, + masks=masks, + labels=labels, + save_dir=save_dir, + ioconfig=ioconfig, + overwrite=overwrite, + patch_mode=patch_mode, + output_type=output_type, + **kwargs, + ) def run( self: PatchPredictor, - images: list[os | Path | WSIReader] | np.ndarray, - masks: list[os | Path] | np.ndarray | None = None, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, labels: list | None = None, ioconfig: ModelIOConfigABC | None = None, *, patch_mode: bool = True, - save_dir: os | Path | None = None, # None will not save output + save_dir: os.PathLike | Path | None = None, overwrite: bool = False, output_type: str = "dict", **kwargs: Unpack[PredictorRunParams], ) -> AnnotationStore | Path | str | dict: - """Run the engine on input images. + """Run the PatchPredictor engine on input images. + + This method orchestrates the full inference pipeline, including preprocessing, + model inference, post-processing, and saving results. It supports both patch + and whole slide image (WSI) modes. Args: - images (list, ndarray): - List of inputs to process. when using `patch` mode, the + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. When using `patch` mode, the input must be either a list of images, a list of image file paths or a numpy array of an image list. - masks (list | None): - List of masks. Only utilised when patch_mode is False. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. + Only utilised when patch_mode is False. Patches are only generated within a masked area. If not provided, then a tissue mask will be automatically generated for whole slide images. labels (list | None): - List of labels. Only a single label per image is supported. + Optional labels for input images. + Only a single label per image is supported. + ioconfig (ModelIOConfigABC | None): + IO configuration for patch extraction and resolution. patch_mode (bool): - Whether to treat input image as a patch or WSI. - default = True. - ioconfig (IOPatchPredictorConfig): - IO configuration. - save_dir (str or pathlib.Path): - Output directory to save the results. - If save_dir is not provided when patch_mode is False, - then for a single image the output is created in the current directory. - If there are multiple WSIs as input then the user must provide - path to save directory otherwise an OSError will be raised. + Whether to treat input as patches (`True`) or WSIs (`False`). + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. overwrite (bool): - Whether to overwrite the results. Default = False. + Whether to overwrite existing output files. Default is False. output_type (str): - The format of the output type. "output_type" can be - "zarr" or "AnnotationStore". Default value is "zarr". - When saving in the zarr format the output is saved using the - `python zarr library `__ - as a zarr group. If the required output type is an "AnnotationStore" - then the output will be intermediately saved as zarr but converted - to :class:`AnnotationStore` and saved as a `.db` file - at the end of the loop. + Desired output format: "dict", "zarr", or "annotationstore". + Default value is "zarr". **kwargs (PredictorRunParams): - Keyword Args to update :class:`EngineABC` attributes during runtime. + Additional runtime parameters. Returns: - (:class:`numpy.ndarray`, dict): - Model predictions of the input dataset. If multiple - whole slide images are provided as input, - or save_output is True, then results are saved to - `save_dir` and a dictionary indicating save location for - each input is returned. - - The dict has the following format: - - - img_path: path of the input image. - - raw: path to save location for raw prediction, - saved in .json. + AnnotationStore | Path | str | dict: + - If `patch_mode` is True: returns predictions or path to saved output. + - If `patch_mode` is False: returns a dictionary mapping each WSI to + its output path. Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] diff --git a/tiatoolbox/models/engine/semantic_segmentor.py b/tiatoolbox/models/engine/semantic_segmentor.py index ff9988e71..361630617 100644 --- a/tiatoolbox/models/engine/semantic_segmentor.py +++ b/tiatoolbox/models/engine/semantic_segmentor.py @@ -1,1372 +1,1245 @@ -"""This module implements semantic segmentation.""" +"""Semantic Segmentation Engine for Whole Slide Images (WSIs) using TIAToolbox. + +This module defines the `SemanticSegmentor` class, which extends the `PatchPredictor` +engine to support semantic segmentation workflows on digital pathology images. +It leverages deep learning models from TIAToolbox to perform patch-level and +WSI-level inference, and includes utilities for preprocessing, postprocessing, +and saving predictions in various formats. + +Key Components: +--------------- +Classes: +- SemanticSegmentorRunParams: + Configuration parameters for controlling runtime behavior during segmentation. +- SemanticSegmentor: + Core engine for performing semantic segmentation on image patches or WSIs. + +Functions: +- concatenate_none: + Concatenate arrays while gracefully handling None values. +- merge_horizontal: + Incrementally merge horizontal patches and update location arrays. +- save_to_cache: + Save intermediate canvas and count arrays to Zarr cache. +- merge_vertical_chunkwise: + Merge vertically chunked canvas and count arrays into a probability map. +- store_probabilities: + Store computed probability data in Zarr or Dask arrays. +- prepare_full_batch: + Align patch-level predictions with global output locations. + +Example: +>>> from tiatoolbox.models.engine.semantic_segmentor import SemanticSegmentor +>>> segmentor = SemanticSegmentor(model="fcn_resnet50_unet-bcss") +>>> wsis = ["slide1.svs", "slide2.svs"] +>>> output = segmentor.run(wsis, patch_mode=False) +>>> +>>> patches = [np.ndarray, np.ndarray] +>>> segmentor = SemanticSegmentor(model="fcn_resnet50_unet-bcss") +>>> output = segmentor.run(patches, patch_mode=True, output_type="dict") + +Notes: +------ +- Supports both patch-based and WSI-based segmentation. +- Compatible with TIAToolbox pretrained models and custom PyTorch models. +- Outputs can be saved as dictionaries, Zarr arrays, or AnnotationStore databases. +- Includes memory-aware caching and efficient merging strategies for large-scale + inference. + +""" from __future__ import annotations -import copy -import logging -import shutil -from concurrent.futures import ProcessPoolExecutor +import gc from pathlib import Path -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING -import cv2 -import joblib +import dask.array as da import numpy as np +import psutil import torch -import torch.multiprocessing as torch_mp -import torch.utils.data as torch_data -import tqdm +import zarr +from dask import compute +from typing_extensions import Unpack -from tiatoolbox import logger, rcParam -from tiatoolbox.models.architecture import get_pretrained_model -from tiatoolbox.models.architecture.utils import compile_model -from tiatoolbox.models.dataset.dataset_abc import WSIStreamDataset -from tiatoolbox.models.models_abc import model_to -from tiatoolbox.tools.patchextraction import PatchExtractor -from tiatoolbox.utils import imread -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from tiatoolbox import logger +from tiatoolbox.models.dataset.dataset_abc import WSIPatchDataset +from tiatoolbox.utils.misc import ( + dict_to_store_semantic_segmentor, + get_tqdm, +) +from tiatoolbox.wsicore.wsireader import is_zarr -from .io_config import IOSegmentorConfig +from .patch_predictor import PatchPredictor, PredictorRunParams if TYPE_CHECKING: # pragma: no cover - from tiatoolbox.type_hints import IntPair, Resolution, Units + import os + from torch.utils.data import DataLoader -def _estimate_canvas_parameters( - sample_prediction: np.ndarray, - canvas_shape: np.ndarray, -) -> tuple[tuple, tuple, bool]: - """Estimates canvas parameters. + from tiatoolbox.annotation import AnnotationStore + from tiatoolbox.models.engine.io_config import IOSegmentorConfig + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.type_hints import Resolution + from tiatoolbox.wsicore import WSIReader - Args: - sample_prediction (:class:`numpy.ndarray`): - Patch prediction assuming to be of shape HWC. - canvas_shape (:class:`numpy.ndarray`): - HW of the supposed assembled image. - Returns: - (tuple, tuple, bool): - Canvas Shape, Canvas Count and whether to add singleton dimension. +class SemanticSegmentorRunParams(PredictorRunParams, total=False): + """Runtime parameters for configuring the `SemanticSegmentor.run()` method. + + This class extends `PredictorRunParams`, which itself extends `EngineABCRunParams`, + and adds parameters specific to semantic segmentation workflows. + + Attributes: + auto_get_mask (bool): + Whether to automatically generate segmentation masks using + `wsireader.tissue_mask()` during processing. + batch_size (int): + Number of image patches to feed to the model in a forward pass. + class_dict (dict): + Optional dictionary mapping classification outputs to class names. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). + input_resolutions (list[dict]): + Resolution used for reading the image. See `WSIReader` for details. + ioconfig (ModelIOConfigABC): + Input/output configuration for patch extraction and resolution. + memory_threshold (int): + Memory usage threshold (in percentage) to trigger caching behavior. + num_workers (int): + Number of workers used in DataLoader. + output_file (str): + Output file name for saving results (e.g., .zarr or .db). + output_resolutions (Resolution): + Resolution used for writing output predictions. + patch_input_shape (tuple[int, int]): + Shape of input patches (height, width). + patch_output_shape (tuple[int, int]): + Shape of output patches (height, width). + return_labels (bool): + Whether to return labels with predictions. + return_probabilities (bool): + Whether to return per-class probabilities. + scale_factor (tuple[float, float]): + Scale factor for converting annotations to baseline resolution. + Typically model_mpp / slide_mpp. + stride_shape (tuple[int, int]): + Stride used during WSI processing. Defaults to patch_input_shape. + verbose (bool): + Whether to output logging information. """ - if len(sample_prediction.shape) == 3: # noqa: PLR2004 - num_output_ch = sample_prediction.shape[-1] - canvas_cum_shape_ = (*tuple(canvas_shape), num_output_ch) - canvas_count_shape_ = (*tuple(canvas_shape), 1) - add_singleton_dim = num_output_ch == 1 - else: - canvas_cum_shape_ = (*tuple(canvas_shape), 1) - canvas_count_shape_ = (*tuple(canvas_shape), 1) - add_singleton_dim = True - - return canvas_cum_shape_, canvas_count_shape_, add_singleton_dim - - -def _prepare_save_output( - save_path: str | Path, - cache_count_path: str | Path, - canvas_cum_shape_: tuple[int, ...], - canvas_count_shape_: tuple[int, ...], -) -> tuple: - """Prepares for saving the cached output.""" - if save_path is not None: - save_path = Path(save_path) - cache_count_path = Path(cache_count_path) - if Path.exists(save_path) and Path.exists(cache_count_path): - cum_canvas = np.load(str(save_path), mmap_mode="r+") - count_canvas = np.load(str(cache_count_path), mmap_mode="r+") - if canvas_cum_shape_ != cum_canvas.shape: - msg = "Existing image shape in `save_path` does not match." - raise ValueError(msg) - if canvas_count_shape_ != count_canvas.shape: - msg = "Existing image shape in `cache_count_path` does not match." - raise ValueError( - msg, - ) - else: - cum_canvas = np.lib.format.open_memmap( - save_path, - mode="w+", - shape=canvas_cum_shape_, - dtype=np.float32, - ) - # assuming no more than 255 overlapping times - count_canvas = np.lib.format.open_memmap( - cache_count_path, - mode="w+", - shape=canvas_count_shape_, - dtype=np.uint8, - ) - # flush fill - count_canvas[:] = 0 - is_on_drive = True - else: - is_on_drive = False - cum_canvas = np.zeros( - shape=canvas_cum_shape_, - dtype=np.float32, - ) - # for pixel occurrence counting - count_canvas = np.zeros(canvas_count_shape_, dtype=np.float32) - - return is_on_drive, count_canvas, cum_canvas - - -class SemanticSegmentor: - """Pixel-wise segmentation predictor. - - The tiatoolbox model should produce the following results on the BCSS dataset - using fcn_resnet50_unet-bcss. - - .. list-table:: Semantic segmentation performance on the BCSS dataset - :widths: 15 15 15 15 15 15 15 - :header-rows: 1 - - * - - - Tumour - - Stroma - - Inflammatory - - Necrosis - - Other - - All - * - Amgad et al. - - 0.851 - - 0.800 - - 0.712 - - 0.723 - - 0.666 - - 0.750 - * - TIAToolbox - - 0.885 - - 0.825 - - 0.761 - - 0.765 - - 0.581 - - 0.763 - - Note, if `model` is supplied in the arguments, it will ignore the - `pretrained_model` and `pretrained_weights` arguments. + + patch_output_shape: tuple[int, int] + output_resolutions: Resolution + + +class SemanticSegmentor(PatchPredictor): + r"""Semantic segmentation engine for digital histology images. + + This class extends `PatchPredictor` to support semantic segmentation tasks + using pretrained or custom models from TIAToolbox. It supports both patch-level + and whole slide image (WSI) processing, and provides utilities for merging, + post-processing, and saving predictions. + + Performance: + The TIAToolbox model `fcn_resnet50_unet-bcss` achieves the following + results on the BCSS dataset: + + .. list-table:: Semantic segmentation performance on the BCSS dataset + :widths: 15 15 15 15 15 15 15 + :header-rows: 1 + + * - + - Tumour + - Stroma + - Inflammatory + - Necrosis + - Other + - All + * - Amgad et al. + - 0.851 + - 0.800 + - 0.712 + - 0.723 + - 0.666 + - 0.750 + * - TIAToolbox + - 0.885 + - 0.825 + - 0.761 + - 0.765 + - 0.581 + - 0.763 Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, - refer to the `docs - `_. + model (str | ModelABC): + A PyTorch model instance or name of a pretrained model from TIAToolbox. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link + `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case-insensitive. - pretrained_weights (str): - Path to the weight of the corresponding `pretrained_model`. + of weights using the `weights` parameter. Default is `None`. batch_size (int): - Number of images fed into the model each time. - num_loader_workers (int): - Number of workers to load the data. Take note that they will - also perform preprocessing. - num_postproc_workers (int): - This value is there to maintain input compatibility with - `tiatoolbox.models.classification` and is not used. + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + + >>> engine = SemanticSegmentor( + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... ) + + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". verbose (bool): - Whether to output logging information. - dataset_class (obj): - Dataset class to be used instead of default. - auto_generate_mask (bool): - To automatically generate tile/WSI tissue mask if is not - provided. + Whether to enable verbose logging. Default is True. Attributes: - process_prediction_per_batch (bool): - A flag to denote whether post-processing for inference - output is applied after each batch or after finishing an entire - tile or WSI. + images (list[str | Path] | np.ndarray): + Input image patches or WSI paths. + masks (list[str | Path] | np.ndarray): + Optional tissue masks for WSI processing. + These are only utilized when patch_mode is False. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + patch_mode (bool): + Whether input is treated as patches (`True`) or WSIs (`False`). + model (ModelABC): + Loaded PyTorch model. + ioconfig (ModelIOConfigABC): + IO configuration for patch extraction and resolution. + return_labels (bool): + Whether to include labels in the output. + input_resolutions (list[dict]): + Resolution settings for model input. Supported + units are `level`, `power` and `mpp`. Keys should be "units" and + "resolution" e.g., [{"units": "mpp", "resolution": 0.25}]. Please see + :class:`WSIReader` for details. + patch_input_shape (tuple[int, int]): + Shape of input patches (height, width). Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple[int, int]): + Stride used during patch extraction. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. + labels (list | None): + Optional labels for input images. + Only a single label per image is supported. + drop_keys (list): + Keys to exclude from model output. + output_type (str): + Format of output ("dict", "zarr", "annotationstore"). + output_locations (list | None): + Coordinates of output patches used during WSI processing. Examples: - >>> # Sample output of a network - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> predictor = SemanticSegmentor(model='fcn-tissue_mask') - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0.raw') , ('B/wsi.svs', 'output/1.raw')] - >>> # if a network have 2 output heads, each head output of 'A/wsi.svs' - >>> # will be respectively stored in 'output/0.raw.0', 'output/0.raw.1' + >>> # list of 2 image patches as input + >>> wsis = ['path/img.svs', 'path/img.svs'] + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") + >>> output = segmentor.run(wsis, patch_mode=False) + + >>> # array of list of 2 image patches as input + >>> image_patches = [np.ndarray, np.ndarray] + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") + >>> output = segmentor.run(data, patch_mode=True) + + >>> # list of 2 image patch files as input + >>> data = ['path/img.png', 'path/img.png'] + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") + >>> output = segmentor.run(data, patch_mode=False) + + >>> # list of 2 image tile files as input + >>> tile_file = ['path/tile1.png', 'path/tile2.png'] + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") + >>> output = segmentor.run(tile_file, patch_mode=False) + + >>> # list of 2 wsi files as input + >>> wsis = ['path/wsi1.svs', 'path/wsi2.svs'] + >>> segmentor = SemanticSegmentor(model="resnet18-kather100k") + >>> output = segmentor.run(wsis, patch_mode=False) + + References: + [1] Amgad M, Elfandy H, ..., Gutman DA, Cooper LAD. Structured crowdsourcing + enables convolutional segmentation of histology images. Bioinformatics 2019. + doi: 10.1093/bioinformatics/btz083 """ def __init__( self: SemanticSegmentor, + model: str | ModelABC, batch_size: int = 8, - num_loader_workers: int = 0, - num_postproc_workers: int = 0, - model: torch.nn.Module | None = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, - dataset_class: Callable = WSIStreamDataset, + num_workers: int = 0, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, - auto_generate_mask: bool = False, ) -> None: - """Initialize :class:`SemanticSegmentor`.""" - super().__init__() + """Initialize :class:`SemanticSegmentor`. - if model is None and pretrained_model is None: - msg = "Must provide either of `model` or `pretrained_model`" - raise ValueError(msg) + Args: + model (str | ModelABC): + A PyTorch model instance or name of a pretrained model from TIAToolbox. + If a string is provided, the corresponding pretrained weights will be + downloaded unless overridden via `weights`. + batch_size (int): + Number of image patches processed per forward pass. Default is 8. + num_workers (int): + Number of workers for data loading. Default is 0. + weights (str | Path | None): + Path to model weights. If None, default weights are used. + device (str): + Device to run the model on (e.g., "cpu", "cuda"). Default is "cpu". + verbose (bool): + Whether to enable verbose logging. Default is True. - if model is not None: - self.model = model - # template ioconfig, usually coming from pretrained - self.ioconfig = None - else: - model, ioconfig = get_pretrained_model(pretrained_model, pretrained_weights) - self.ioconfig = ioconfig - self.model = model - - # local variables for flagging mode within class, - # subclass should have overwritten to alter some specific behavior - self.process_prediction_per_batch = True - - # for runtime, such as after wrapping with nn.DataParallel - self._cache_dir = None - self._loader = None - self._model = None - self._device = None - self._mp_shared_space = None - self._postproc_workers = None - self.num_postproc_workers = num_postproc_workers - self._futures = None - self._outputs = [] - self.imgs = None - self.masks = None - - self.dataset_class: WSIStreamDataset = dataset_class - self.model = compile_model( - model, - mode=rcParam["torch_compile_mode"], + """ + super().__init__( + model=model, + batch_size=batch_size, + num_workers=num_workers, + weights=weights, + device=device, + verbose=verbose, ) - self.pretrained_model = pretrained_model - self.batch_size = batch_size - self.num_loader_workers = num_loader_workers - self.num_postproc_workers = None - self.verbose = verbose - self.auto_generate_mask = auto_generate_mask - - @staticmethod - def get_coordinates( - image_shape: tuple[int, int] | np.ndarray, - ioconfig: IOSegmentorConfig, - ) -> tuple[np.ndarray, np.ndarray]: - """Calculate patch tiling coordinates. - - By default, internally, it will call the - `PatchExtractor.get_coordinates`. To use your own approach, - either subclass to overwrite or directly assign your own - function to this name. In either cases, the function must obey - the API defined here. + self.output_locations: list | None = None + + def get_dataloader( + self: SemanticSegmentor, + images: str | Path | list[str | Path] | np.ndarray, + masks: Path | None = None, + labels: list | None = None, + ioconfig: SemanticSegmentorRunParams | None = None, + *, + patch_mode: bool = True, + auto_get_mask: bool = True, + ) -> torch.utils.data.DataLoader: + """Pre-process images and masks and return a DataLoader for inference. + + This method prepares the dataset and returns a PyTorch DataLoader + for either patch-based or WSI-based semantic segmentation. It overrides + the base method to support additional WSI-specific logic, including + patch output shape and output location tracking. Args: - image_shape (tuple(int), :class:`numpy.ndarray`): - This argument specifies the shape of mother image (the - image we want to extract patches from) at requested - `resolution` and `units` and it is expected to be in - (width, height) format. - ioconfig (:class:`IOSegmentorConfig`): - Object that contains information about input and output - placement of patches. Check `IOSegmentorConfig` for - details about available attributes. + images (str | Path | list[str | Path] | np.ndarray): + Input images. Can be a list of file paths or a NumPy array + of image patches in NHWC format. + masks (Path | None): + Optional tissue masks for WSI processing. Only used when + `patch_mode` is False. + labels (list | None): + Optional labels for input images. Only one label per image is supported. + ioconfig (SemanticSegmentorRunParams | None): + IO configuration for patch extraction and resolution. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + auto_get_mask (bool): + Whether to automatically generate a tissue mask using + `wsireader.tissue_mask()` when `patch_mode` is False. + If `True`, only tissue regions are processed. If `False`, + all patches are processed. Default is `True`. Returns: - tuple: - List of patch inputs and outputs + torch.utils.data.DataLoader: + A PyTorch DataLoader configured for semantic segmentation inference. - - :py:obj:`list` - patch_inputs: - A list of corrdinates in `[start_x, start_y, end_x, - end_y]` format indicating the read location of the - patch in the mother image. + """ + # Overwrite when patch_mode is False. + if not patch_mode: + dataset = WSIPatchDataset( + input_img=images, + mask_path=masks, + patch_input_shape=ioconfig.patch_input_shape, + patch_output_shape=ioconfig.patch_output_shape, + stride_shape=ioconfig.stride_shape, + resolution=ioconfig.input_resolutions[0]["resolution"], + units=ioconfig.input_resolutions[0]["units"], + auto_get_mask=auto_get_mask, + ) - - :py:obj:`list` - patch_outputs: - A list of corrdinates in `[start_x, start_y, end_x, - end_y]` format indicating to write location of the - patch in the mother image. + dataset.preproc_func = self.model.preproc_func + self.output_locations = dataset.outputs - Examples: - >>> # API of function expected to overwrite `get_coordinates` - >>> def func(image_shape, ioconfig): - ... patch_inputs = np.array([[0, 0, 256, 256]]) - ... patch_outputs = np.array([[0, 0, 256, 256]]) - ... return patch_inputs, patch_outputs - >>> segmentor = SemanticSegmentor(model='unet') - >>> segmentor.get_coordinates = func + # preprocessing must be defined with the dataset + return torch.utils.data.DataLoader( + dataset, + num_workers=self.num_workers, + batch_size=self.batch_size, + drop_last=False, + shuffle=False, + ) - """ - results = PatchExtractor.get_coordinates( - patch_output_shape=ioconfig.patch_output_shape, - image_shape=image_shape, - patch_input_shape=ioconfig.patch_input_shape, - stride_shape=ioconfig.stride_shape, + return super().get_dataloader( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, ) - return results[0], results[1] - @staticmethod - def filter_coordinates( - mask_reader: VirtualWSIReader, - bounds: np.ndarray, - resolution: Resolution | None = None, - units: Units | None = None, - ) -> np.ndarray: - """Indicates which coordinate is valid basing on the mask. + def infer_wsi( + self: SemanticSegmentor, + dataloader: DataLoader, + save_path: Path, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict[str, da.Array]: + """Perform model inference on a whole slide image (WSI). - To use your own approaches, either subclass to overwrite or - directly assign your own function to this name. In either cases, - the function must obey the API defined here. + This method processes a WSI using the provided DataLoader, merges + patch-level predictions into a full-resolution canvas, and returns + the aggregated output. It supports memory-aware caching and optional + inclusion of coordinates and labels. Args: - mask_reader (:class:`.VirtualReader`): - A virtual pyramidal reader of the mask related to the - WSI from which we want to extract the patches. - bounds (ndarray and np.int32): - Coordinates to be checked via the `func`. They must be - in the same resolution as requested `resolution` and - `units`. The shape of `coordinates` is (N, K) where N is - the number of coordinate sets and K is either 2 for - centroids or 4 for bounding boxes. When using the - default `func=None`, K should be 4, as we expect the - `coordinates` to be bounding boxes in `[start_x, - start_y, end_x, end_y]` format. - resolution (Resolution): - Resolution of the requested patch. - units (Units): - Units of the requested patch. + dataloader (DataLoader): + PyTorch DataLoader configured for WSI processing. + save_path (Path): + Path to save the intermediate output. The intermediate output + is saved in a Zarr file. + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters, including: + - return_probabilities (bool): Whether to return probability maps. + - return_labels (bool): Whether to include labels in the output. + - memory_threshold (int): Memory usage threshold to trigger disk + caching. Returns: - :class:`numpy.ndarray`: - List of flags to indicate which coordinate is valid. - - Examples: - >>> # API of function expected to overwrite `filter_coordinates` - >>> def func(reader, bounds, resolution, units): - ... # as example, only select first bound - ... return np.array([1, 0]) - >>> coords = [[0, 0, 256, 256], [128, 128, 384, 384]] - >>> segmentor = SemanticSegmentor(model='unet') - >>> segmentor.filter_coordinates = func + dict[str, dask.array.Array]: + Dictionary containing merged prediction results: + - "probabilities": Full-resolution probability map. + - "coordinates": Patch coordinates. + - "labels": Ground truth labels (if `return_labels` is True). """ - if not isinstance(mask_reader, VirtualWSIReader): - msg = "`mask_reader` should be VirtualWSIReader." - raise TypeError(msg) - - if not isinstance(bounds, np.ndarray) or not np.issubdtype( - bounds.dtype, - np.integer, - ): - msg = "`coordinates` should be ndarray of integer type." - raise ValueError(msg) - - mask_real_shape = mask_reader.img.shape[:2] - mask_resolution_shape = mask_reader.slide_dimensions( - resolution=resolution, - units=units, - )[::-1] - mask_real_shape = np.array(mask_real_shape) - mask_resolution_shape = np.array(mask_resolution_shape) - scale_factor = mask_real_shape / mask_resolution_shape - scale_factor = scale_factor[0] # what if ratio x != y - - def sel_func(coord: np.ndarray) -> bool: - """Accept coord as long as its box contains part of mask.""" - coord_in_real_mask = np.ceil(scale_factor * coord).astype(np.int32) - start_x, start_y, end_x, end_y = coord_in_real_mask - roi = mask_reader.img[start_y:end_y, start_x:end_x] - return np.sum(roi > 0) > 0 - - flags = [sel_func(bound) for bound in bounds] - return np.array(flags) - - @staticmethod - def get_reader( - img_path: str | Path, - mask_path: str | Path, - mode: str, - *, - auto_get_mask: bool, - ) -> tuple[WSIReader, WSIReader]: - """Define how to get reader for mask and source image.""" - img_path = Path(img_path) - reader = WSIReader.open(img_path) - - mask_reader = None - if mask_path is not None: - mask_path = Path(mask_path) - if not Path.is_file(mask_path): - msg = "`mask_path` must be a valid file path." - raise ValueError(msg) - mask = imread(mask_path) # assume to be gray - mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY) - mask = np.array(mask > 0, dtype=np.uint8) - - mask_reader = VirtualWSIReader(mask) - mask_reader.info = reader.info - elif auto_get_mask and mode == "wsi" and mask_path is None: - # if no mask provided and `wsi` mode, generate basic tissue - # mask on the fly - mask_reader = reader.tissue_mask(resolution=1.25, units="power") - mask_reader.info = reader.info - return reader, mask_reader - - def _predict_one_wsi( - self: SemanticSegmentor, - wsi_idx: int, - ioconfig: IOSegmentorConfig, - save_path: str, - mode: str, - ) -> None: - """Make a prediction on tile/wsi. + # Default Memory threshold percentage is 80. + memory_threshold = kwargs.get("memory_threshold", 80) + vm = psutil.virtual_memory() + + keys = ["probabilities", "coordinates"] + coordinates = [] + + # Main output dictionary + raw_predictions = dict(zip(keys, [da.empty(shape=(0, 0))] * len(keys))) + + # Inference loop + tqdm = get_tqdm() + tqdm_loop = ( + tqdm(dataloader, leave=False, desc="Inferring patches") + if self.verbose + else dataloader + ) - Args: - wsi_idx (int): - Index of the tile/wsi to be processed within `self`. - ioconfig (:class:`IOSegmentorConfig`): - Object which defines I/O placement during inference and - when assembling back to full tile/wsi. - save_path (str): - Location to save output prediction as well as possible - intermediate results. - mode (str): - Either `"tile"` or `"wsi"` to indicate run mode. + canvas_np, output_locs_y_ = None, None + canvas, count, output_locs = None, None, None + canvas_zarr, count_zarr = None, None - """ - cache_dir = self._cache_dir / str(wsi_idx) - cache_dir.mkdir(parents=True) - - wsi_path = self.imgs[wsi_idx] - mask_path = None if self.masks is None else self.masks[wsi_idx] - wsi_reader, mask_reader = self.get_reader( - wsi_path, - mask_path, - mode, - auto_get_mask=self.auto_generate_mask, + full_output_locs = ( + dataloader.dataset.full_outputs + if hasattr(dataloader.dataset, "full_outputs") + else dataloader.dataset.outputs ) - # assume ioconfig has already been converted to `baseline` for `tile` mode - resolution = ioconfig.highest_input_resolution - wsi_proc_shape = wsi_reader.slide_dimensions(**resolution) - - # * retrieve patch and tile placement - # this is in XY - (patch_inputs, patch_outputs) = self.get_coordinates(wsi_proc_shape, ioconfig) - if mask_reader is not None: - sel = self.filter_coordinates(mask_reader, patch_outputs, **resolution) - patch_outputs = patch_outputs[sel] - patch_inputs = patch_inputs[sel] - - # modify the shared space so that we can update worker info - # without needing to re-create the worker. There should be no - # race-condition because only the following enumerate loop - # triggers the parallelism, and this portion is still in - # sequential execution order - patch_inputs = torch.from_numpy(patch_inputs).share_memory_() - patch_outputs = torch.from_numpy(patch_outputs).share_memory_() - self._mp_shared_space.patch_inputs = patch_inputs - self._mp_shared_space.patch_outputs = patch_outputs - self._mp_shared_space.wsi_idx = torch.Tensor([wsi_idx]).share_memory_() - - pbar_desc = "Process Batch: " - pbar = tqdm.tqdm( - desc=pbar_desc, - leave=True, - total=len(self._loader), - ncols=80, - ascii=True, - position=0, - ) + for batch_idx, batch_data in enumerate(tqdm_loop): + batch_output = self.model.infer_batch( + self.model, + batch_data["image"], + device=self.device, + ) + + batch_locs = batch_data["output_locs"].numpy() - cum_output = [] - for _, batch_data in enumerate(self._loader): - sample_datas, sample_infos = batch_data - batch_size = sample_infos.shape[0] - # ! depending on the protocol of the output within infer_batch - # ! this may change, how to enforce/document/expose this in a - # ! sensible way? - - # assume to return a list of L output, - # each of shape N x etc. (N=batch size) - sample_outputs = self.model.infer_batch( - self._model, - sample_datas, - device=self._device, + # Interpolate outputs for masked regions + full_batch_output, full_output_locs, output_locs = prepare_full_batch( + batch_output, + batch_locs, + full_output_locs, + output_locs, + is_last=(batch_idx == (len(dataloader) - 1)), ) - # repackage so that it's an N list, each contains - # L x etc. output - sample_outputs = [np.split(v, batch_size, axis=0) for v in sample_outputs] - sample_outputs = list(zip(*sample_outputs)) - - # tensor to numpy, costly? - sample_infos = sample_infos.numpy() - sample_infos = np.split(sample_infos, batch_size, axis=0) - - sample_outputs = list(zip(sample_infos, sample_outputs)) - if self.process_prediction_per_batch: - self._process_predictions( - sample_outputs, - wsi_reader, - ioconfig, - save_path, - cache_dir, + + canvas_np = concatenate_none(old_arr=canvas_np, new_arr=full_batch_output) + + # Determine if dataloader is moved to next row of patches + change_indices = np.where(np.diff(output_locs[:, 1]) != 0)[0] + 1 + + # If a row of patches has been processed. + if change_indices.size > 0: + canvas, count, canvas_np, output_locs, output_locs_y_ = ( + merge_horizontal( + canvas, + count, + output_locs_y_, + canvas_np, + output_locs, + change_indices, + ) ) - else: - cum_output.extend(sample_outputs) - pbar.update() - pbar.close() - - self._process_predictions( - cum_output, - wsi_reader, - ioconfig, + + used_percent = vm.percent + canvas_used_percent = (canvas.nbytes / vm.free) * 100 + if ( + used_percent > memory_threshold + or canvas_used_percent > memory_threshold + ): + tqdm_loop.desc = "Spill intermediate data to disk" + used_percent = ( + canvas_used_percent + if (canvas_used_percent > memory_threshold) + else used_percent + ) + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." + ) + tqdm.write(msg) + # Flush data in Memory and clear dask graph + canvas_zarr, count_zarr = save_to_cache( + canvas, + count, + canvas_zarr, + count_zarr, + save_path=save_path, + ) + canvas, count = None, None + gc.collect() + tqdm_loop.desc = "Inferring patches" + + coordinates.append( + da.from_array( + self._get_coordinates(batch_data), + ) + ) + + canvas, count, _, _, output_locs_y_ = merge_horizontal( + canvas, + count, + output_locs_y_, + canvas_np, + output_locs, + change_indices=[len(output_locs)], + ) + + zarr_group = None + if canvas_zarr is not None: + canvas_zarr, count_zarr = save_to_cache( + canvas, count, canvas_zarr, count_zarr + ) + # Wrap zarr in dask array + canvas = da.from_zarr(canvas_zarr, chunks=canvas_zarr.chunks) + count = da.from_zarr(count_zarr, chunks=count_zarr.chunks) + zarr_group = zarr.open(canvas_zarr.store.path, mode="a") + + # Final vertical merge + raw_predictions["probabilities"] = merge_vertical_chunkwise( + canvas, + count, + output_locs_y_, + zarr_group, save_path, - cache_dir, + memory_threshold, ) + raw_predictions["coordinates"] = da.concatenate(coordinates, axis=0) - # clean up the cache directories - shutil.rmtree(cache_dir) + return raw_predictions - def _process_predictions( + def save_predictions( self: SemanticSegmentor, - cum_batch_predictions: list, - wsi_reader: WSIReader, - ioconfig: IOSegmentorConfig, - save_path: str, - cache_dir: str, - ) -> None: - """Define how the aggregated predictions are processed. + processed_predictions: dict, + output_type: str, + save_path: Path | None = None, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> dict | AnnotationStore | Path: + """Save semantic segmentation predictions to disk or return them in memory. - This includes merging the prediction if necessary and also saving afterwards. - Note that items within `cum_batch_predictions` will be consumed during - the operation. + This method saves predictions in one of the supported formats: + - "dict": returns predictions as a Python dictionary. + - "zarr": saves predictions as a Zarr group and returns the path. + - "annotationstore": converts predictions to an AnnotationStore (.db file). - Args: - cum_batch_predictions (list): - List of batch predictions. Each item within the list - should be of (location, patch_predictions). - wsi_reader (:class:`WSIReader`): - A reader for the image where the predictions come from. - ioconfig (:class:`IOSegmentorConfig`): - A configuration object contains input and output - information. - save_path (str): - Root path to save current WSI predictions. - cache_dir (str): - Root path to cache current WSI data. - - """ - if len(cum_batch_predictions) == 0: - return - - # assume predictions is N, each item has L output element - locations, predictions = list(zip(*cum_batch_predictions)) - # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of - # output patch this can exceed the image bound at the requested - # resolution remove singleton due to split. - locations = np.array([v[0] for v in locations]) - for index, output_resolution in enumerate(ioconfig.output_resolutions): - # assume resolution index to be in the same order as L - merged_resolution = ioconfig.highest_input_resolution - merged_locations = locations - # ! location is w.r.t the highest resolution, hence still need conversion - if ioconfig.save_resolution is not None: - merged_resolution = ioconfig.save_resolution - output_shape = wsi_reader.slide_dimensions(**output_resolution) - merged_shape = wsi_reader.slide_dimensions(**merged_resolution) - fx = merged_shape[0] / output_shape[0] - merged_locations = np.ceil(locations * fx).astype(np.int64) - merged_shape = wsi_reader.slide_dimensions(**merged_resolution) - # 0 idx is to remove singleton without removing other axes singleton - to_merge_predictions = [v[index][0] for v in predictions] - sub_save_path = f"{save_path}.raw.{index}.npy" - sub_count_path = f"{cache_dir}/count.{index}.npy" - self.merge_prediction( - merged_shape[::-1], # XY to YX - to_merge_predictions, - merged_locations, - save_path=sub_save_path, - cache_count_path=sub_count_path, - ) - - @staticmethod - def merge_prediction( - canvas_shape: tuple[int] | list[int] | np.ndarray, - predictions: list[np.ndarray], - locations: list | np.ndarray, - save_path: str | Path | None = None, - cache_count_path: str | Path | None = None, - ) -> np.ndarray: - """Merge patch-level predictions to form a 2-dimensional prediction map. - - When accumulating the raw prediction onto a same canvas (via - calling the function multiple times), `save_path` and - `cache_count_path` must be the same. If either of these two do - not exist, the function will create new files. However, if - `save_path` is `None`, the function will perform the - accumulation using CPU-RAM as storage. + If `patch_mode` is True, predictions are saved per image. If False, + predictions are merged and saved as a single output. Args: - canvas_shape (:class:`numpy.ndarray`): - HW of the supposed assembled image. - predictions (list): - List of :class:`np.ndarray`, each item is a patch prediction, - assuming to be of shape HWC. - locations (list): - List of :class:`np.ndarray`, each item is the location of the patch - at the same index within `predictions`. The location is - in the to be assembled canvas and of the form - `(top_left_x, top_left_y, bottom_right_x, - bottom_right_x)`. - save_path (str): - Location to save the assembled image. - cache_count_path (str): - Location to store the canvas for counting how many times - each pixel get overlapped when assembling. + processed_predictions (dict): + Dictionary containing processed model predictions. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". + save_path (Path | None): + Path to save the output file. Required for "zarr" and "annotationstore". + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters including: + - scale_factor (tuple[float, float]): For coordinate transformation. + - class_dict (dict): Mapping of class indices to names. + - return_probabilities (bool): Whether to save probability maps. Returns: - :class:`numpy.ndarray`: - An image contains merged data. - - Examples: - >>> SemanticSegmentor.merge_prediction( - ... canvas_shape=[4, 4], - ... predictions=[ - ... np.full((2, 2), 1), - ... np.full((2, 2), 2)], - ... locations=[ - ... [0, 0, 2, 2], - ... [2, 2, 4, 4]], - ... save_path=None, - ... ) - ... array([[1, 1, 0, 0], - ... [1, 1, 0, 0], - ... [0, 0, 2, 2], - ... [0, 0, 2, 2]]) + dict | AnnotationStore | Path: + - If output_type is "dict": returns predictions as a dictionary. + - If output_type is "zarr": returns path to saved Zarr file. + - If output_type is "annotationstore": returns AnnotationStore + or path to .db file. """ - canvas_shape = np.array(canvas_shape) + # Conversion to annotationstore uses a different function for SemanticSegmentor + if output_type.lower() != "annotationstore": + return super().save_predictions( + processed_predictions, output_type, save_path=save_path, **kwargs + ) - sample_prediction = predictions[0] + return_probabilities = kwargs.get("return_probabilities", False) + output_type_ = ( + "zarr" + if is_zarr(save_path.with_suffix(".zarr")) or return_probabilities + else "dict" + ) - if len(sample_prediction.shape) not in (2, 3): - msg = f"Prediction is no HW or HWC: {sample_prediction.shape}." - raise ValueError(msg) + processed_predictions = super().save_predictions( + processed_predictions, + output_type=output_type_, + save_path=save_path.with_suffix(".zarr"), + **kwargs, + ) - ( - canvas_cum_shape_, - canvas_count_shape_, - add_singleton_dim, - ) = _estimate_canvas_parameters(sample_prediction, canvas_shape) + if isinstance(processed_predictions, Path): + processed_predictions = zarr.open(str(processed_predictions), mode="r") + + # scale_factor set from kwargs + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # class_dict set from kwargs + class_dict = kwargs.get("class_dict") + + # Need to add support for zarr conversion. + save_paths = [] + + logger.info("Saving predictions as AnnotationStore.") + if self.patch_mode: + for i, predictions in enumerate(processed_predictions["predictions"]): + if isinstance(self.images[i], Path): + output_path = save_path.parent / (self.images[i].stem + ".db") + else: + output_path = save_path.parent / (str(i) + ".db") + + out_file = dict_to_store_semantic_segmentor( + patch_output={"predictions": predictions}, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=output_path, + ) - is_on_drive, count_canvas, cum_canvas = _prepare_save_output( - save_path, - cache_count_path, - canvas_cum_shape_, - canvas_count_shape_, - ) + save_paths.append(out_file) + else: + out_file = dict_to_store_semantic_segmentor( + patch_output=processed_predictions, + scale_factor=scale_factor, + class_dict=class_dict, + save_path=save_path.with_suffix(".db"), + ) + save_paths = out_file - def index(arr: np.ndarray, tl: np.ndarray, br: np.ndarray) -> np.ndarray: - """Helper to shorten indexing.""" - return arr[tl[0] : br[0], tl[1] : br[1]] - - patch_infos = list(zip(locations, predictions)) - for _, patch_info in enumerate(patch_infos): - # position is assumed to be in XY coordinate - (bound_in_wsi, prediction) = patch_info - # convert to XY to YX, and in tl, br - tl_in_wsi = np.array(bound_in_wsi[:2][::-1]) - br_in_wsi = np.array(bound_in_wsi[2:][::-1]) - old_tl_in_wsi = tl_in_wsi.copy() - - # need to do conversion - patch_shape_in_wsi = tuple(br_in_wsi - tl_in_wsi) - # conversion to make cv2 happy - prediction = prediction.astype(np.float32) - prediction = cv2.resize(prediction, patch_shape_in_wsi[::-1]) - # ! cv2 resize will remove singleton ! - if add_singleton_dim: - prediction = prediction[..., None] - - sel = tl_in_wsi < 0 - tl_in_wsi[sel] = 0 - - if np.any(tl_in_wsi >= canvas_shape): - continue - - sel = br_in_wsi > canvas_shape - br_in_wsi[sel] = canvas_shape[sel] - - # re-calibrate the position in case patch passing the image bound - br_in_patch = br_in_wsi - old_tl_in_wsi - patch_actual_shape = br_in_wsi - tl_in_wsi - tl_in_patch = br_in_patch - patch_actual_shape - - # now cropping the prediction region - patch_pred = prediction[ - tl_in_patch[0] : br_in_patch[0], - tl_in_patch[1] : br_in_patch[1], - ] - - patch_count = np.ones(patch_pred.shape[:2])[..., None] - if not is_on_drive: - index(cum_canvas, tl_in_wsi, br_in_wsi)[:] += patch_pred - index(count_canvas, tl_in_wsi, br_in_wsi)[:] += patch_count - else: - old_avg_pred = np.array(index(cum_canvas, tl_in_wsi, br_in_wsi)) - old_count = np.array(index(count_canvas, tl_in_wsi, br_in_wsi)) - # ! there will be precision error, but we have to live with this - new_count = old_count + patch_count - # retrieve old raw probabilities after summation - old_raw_pred = old_avg_pred * old_count - new_avg_pred = (old_raw_pred + patch_pred) / new_count - index(cum_canvas, tl_in_wsi, br_in_wsi)[:] = new_avg_pred - index(count_canvas, tl_in_wsi, br_in_wsi)[:] = new_count - if not is_on_drive: - cum_canvas /= count_canvas + 1.0e-6 - return cum_canvas - - @staticmethod - def _prepare_save_dir(save_dir: str | Path | None) -> tuple[Path, Path]: - """Prepare save directory and cache.""" - if save_dir is None: - logger.warning( - "Segmentor will only output to directory. " - "All subsequent output will be saved to current runtime " - "location under folder 'output'. Overwriting may happen! ", - stacklevel=2, + if return_probabilities: + msg = ( + f"Probability maps cannot be saved as AnnotationStore. " + f"To visualise heatmaps in TIAToolbox Visualization tool," + f"convert heatmaps in {save_path} to ome.tiff using" + f"tiatoolbox.utils.misc.write_probability_heatmap_as_ome_tiff." ) - save_dir = Path.cwd() / "output" + logger.info(msg) - save_dir = Path(save_dir).resolve() - if save_dir.is_dir(): - msg = f"`save_dir` already exists! {save_dir}" - raise ValueError(msg) - save_dir.mkdir(parents=True) - cache_dir = Path(f"{save_dir}/cache") - Path.mkdir(cache_dir, parents=True) - - return save_dir, cache_dir - - @staticmethod - def _update_ioconfig( - ioconfig: IOSegmentorConfig, - mode: str, - patch_input_shape: IntPair, - patch_output_shape: IntPair, - stride_shape: IntPair, - resolution: Resolution, - units: Units, - ) -> IOSegmentorConfig: - """Update ioconfig according to input parameters. + return save_paths + + def _update_run_params( + self: SemanticSegmentor, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + save_dir: os.PathLike | Path | None = None, + ioconfig: IOSegmentorConfig | None = None, + output_type: str = "dict", + *, + overwrite: bool = False, + patch_mode: bool, + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> Path | None: + """Update runtime parameters for the PatchPredictor engine. + + This method sets internal attributes such as caching, batch size, + IO configuration, and output format based on user input and keyword arguments. + It also configures whether to include probabilities in the output. Args: - ioconfig (:class:`IOSegmentorConfig`): - Object defines information about input and output - placement of patches. When provided, - `patch_input_shape`, `patch_output_shape`, - `stride_shape`, `resolution`, and `units` arguments are - ignored. Otherwise, those arguments will be internally - converted to a :class:`IOSegmentorConfig` object. - mode (str): - Type of input to process. Choose from either `tile` or - `wsi`. - patch_input_shape (tuple): - Size of patches input to the model. The values - are at requested read resolution and must be positive. - patch_output_shape (tuple): - Size of patches output by the model. The values are at - the requested read resolution and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. The values - are at requested read resolution and must be positive. - If not provided, `stride_shape=patch_input_shape` is - used. - resolution (Resolution): - Resolution used for reading the image. - units (Units): - Units of resolution used for reading the image. + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. + labels (list | None): + Optional labels for input images. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + ioconfig (ModelIOConfigABC | None): + IO configuration for patch extraction and resolution. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". + overwrite (bool): + Whether to overwrite existing output files. Default is False. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters. Returns: - :class:`IOSegmentorConfig`: - Updated ioconfig. + Path | None: + Path to the save directory if applicable, otherwise None. + + Raises: + ValueError: + If `labels` are requested for WSI processing. """ - if patch_output_shape is None: - patch_output_shape = patch_input_shape - if stride_shape is None: - stride_shape = patch_output_shape - - if ioconfig is None: - ioconfig = IOSegmentorConfig( - input_resolutions=[{"resolution": resolution, "units": units}], - output_resolutions=[{"resolution": resolution, "units": units}], - patch_input_shape=patch_input_shape, - patch_output_shape=patch_output_shape, - stride_shape=stride_shape, - ) - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - return ioconfig.to_baseline() + return_labels = kwargs.get("return_labels") - return ioconfig + if return_labels and not patch_mode: + msg = "`return_labels` is not supported when `patch_mode` is False." + raise ValueError(msg) - def _prepare_workers(self: SemanticSegmentor) -> None: - """Prepare number of workers.""" - self._postproc_workers = None - if self.num_postproc_workers is not None: - self._postproc_workers = ProcessPoolExecutor( - max_workers=self.num_postproc_workers, - ) + return super()._update_run_params( + images=images, + masks=masks, + labels=labels, + save_dir=save_dir, + ioconfig=ioconfig, + overwrite=overwrite, + patch_mode=patch_mode, + output_type=output_type, + **kwargs, + ) - def _memory_cleanup(self: SemanticSegmentor) -> None: - """Memory clean up.""" - self.imgs = None - self.masks = None - self._cache_dir = None - self._model = None - self._loader = None - self._device = None - self._futures = None - self._mp_shared_space = None - if self._postproc_workers is not None: - self._postproc_workers.shutdown() - self._postproc_workers = None - - def _predict_wsi_handle_exception( + def run( self: SemanticSegmentor, - imgs: list, - wsi_idx: int, - img_path: str | Path, - mode: str, - ioconfig: IOSegmentorConfig, - save_dir: str | Path, + images: list[os.PathLike | Path | WSIReader] | np.ndarray, + masks: list[os.PathLike | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: IOSegmentorConfig | None = None, *, - crash_on_exception: bool, - ) -> None: - """Predict on multiple WSIs. + patch_mode: bool = True, + save_dir: os.PathLike | Path | None = None, + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[SemanticSegmentorRunParams], + ) -> AnnotationStore | Path | str | dict | list[Path]: + """Run the semantic segmentation engine on input images. + + This method orchestrates the full inference pipeline, including preprocessing, + model inference, post-processing, and saving results. It supports both + patch-level and whole slide image (WSI) modes. Args: - imgs (list, ndarray): - List of inputs to process. When using `"patch"` mode, - the input must be either a list of images, a list of - image file paths or a numpy array of an image list. When - using `"tile"` or `"wsi"` mode, the input must be a list - of file paths. - wsi_idx (int): - index of current WSI being processed. - img_path(str or Path): - Path to current image. - mode (str): - Type of input to process. Choose from either `tile` or - `wsi`. - ioconfig (:class:`IOSegmentorConfig`): - Object defines information about input and output - placement of patches. When provided, - `patch_input_shape`, `patch_output_shape`, - `stride_shape`, `resolution`, and `units` arguments are - ignored. Otherwise, those arguments will be internally - converted to a :class:`IOSegmentorConfig` object. - save_dir (str or Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - crash_on_exception (bool): - If `True`, the running loop will crash if there is any - error during processing a WSI. Otherwise, the loop will - move on to the next wsi for processing. + images (list[PathLike | WSIReader] | np.ndarray): + Input images or patches. Can be a list of file paths, WSIReader objects, + or a NumPy array of image patches. + masks (list[PathLike] | np.ndarray | None): + Optional masks for WSI processing. Only used when `patch_mode` is False. + labels (list | None): + Optional labels for input images. Only one label per image is supported. + ioconfig (IOSegmentorConfig | None): + IO configuration for patch extraction and resolution. + patch_mode (bool): + Whether to treat input as patches (`True`) or WSIs (`False`). Default + is True. + save_dir (PathLike | None): + Directory to save output files. Required for WSI mode. + overwrite (bool): + Whether to overwrite existing output files. Default is False. + output_type (str): + Desired output format: "dict", "zarr", or "annotationstore". Default + is "dict". + **kwargs (SemanticSegmentorRunParams): + Additional runtime parameters to update engine attributes. Returns: - list: - A list of tuple(input_path, save_path) where - `input_path` is the path of the input wsi while - `save_path` corresponds to the output predictions. + AnnotationStore | Path | str | dict | list[Path]: + - If `patch_mode` is True: returns predictions or path to saved output. + - If `patch_mode` is False: returns a dictionary mapping each WSI + to its output path. + + Examples: + >>> wsis = ['wsi1.svs', 'wsi2.svs'] + >>> image_patches = [np.ndarray, np.ndarray] + >>> segmentor = SemanticSegmentor(model="fcn-tissue_mask") + >>> output = segmentor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + + >>> output = segmentor.run( + ... image_patches, + ... patch_mode=True, + ... output_type="zarr" + ... ) + >>> output + ... "/path/to/Output.zarr" + + >>> output = segmentor.run(wsis, patch_mode=False) + >>> output.keys() + ... ['wsi1.svs', 'wsi2.svs'] + >>> output['wsi1.svs'] + ... "/path/to/wsi1.db" """ - try: - wsi_save_path = save_dir / f"{wsi_idx}" - self._predict_one_wsi(wsi_idx, ioconfig, str(wsi_save_path), mode) - - # Do not use dict with file name as key, because it can be - # overwritten. It may be user intention to provide files with a - # same name multiple times (maybe they have different root path) - self._outputs.append([str(img_path), str(wsi_save_path)]) - - # ? will this corrupt old version if control + c midway? - map_file_path = save_dir / "file_map.dat" - # backup old version first - if Path.exists(map_file_path): - old_map_file_path = save_dir / "file_map_old.dat" - shutil.copy(map_file_path, old_map_file_path) - joblib.dump(self._outputs, map_file_path) - - # verbose mode, error by passing ? - logging.info("Finish: %d", wsi_idx / len(imgs)) - logging.info("--Input: %s", str(img_path)) - logging.info("--Output: %s", str(wsi_save_path)) - # prevent deep source check because this is bypass and - # delegating error message - except Exception as err: # skipcq: PYL-W0703 - wsi_save_path = save_dir.joinpath(f"{wsi_idx}") - if crash_on_exception: - raise err # noqa: TRY201 - logging.exception("Crashed on %s", wsi_save_path) - - def predict( # noqa: PLR0913 - self: SemanticSegmentor, - imgs: list, - masks: list | None = None, - mode: str = "tile", - ioconfig: IOSegmentorConfig = None, - patch_input_shape: IntPair = None, - patch_output_shape: IntPair = None, - stride_shape: IntPair = None, - resolution: Resolution = 1.0, - units: Units = "baseline", - save_dir: str | Path | None = None, - device: str = "cpu", - *, - crash_on_exception: bool = False, - ) -> list[tuple[Path, Path]]: - """Make a prediction for a list of input data. - - By default, if the input model at the object instantiation time - is a pretrained model in the toolbox as well as - `patch_input_shape`, `patch_output_shape`, `stride_shape`, - `resolution`, `units` and `ioconfig` are `None`. The method will - use the `ioconfig` retrieved together with the pretrained model. - Otherwise, either `patch_input_shape`, `patch_output_shape`, - `stride_shape`, `resolution`, `units` or `ioconfig` must be set - else a `Value Error` will be raised. + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, + ) - Args: - imgs (list, ndarray): - List of inputs to process. When using `"patch"` mode, - the input must be either a list of images, a list of - image file paths or a numpy array of an image list. When - using `"tile"` or `"wsi"` mode, the input must be a list - of file paths. - masks (list): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - mode (str): - Type of input to process. Choose from either `tile` or - `wsi`. - ioconfig (:class:`IOSegmentorConfig`): - Object defines information about input and output - placement of patches. When provided, - `patch_input_shape`, `patch_output_shape`, - `stride_shape`, `resolution`, and `units` arguments are - ignored. Otherwise, those arguments will be internally - converted to a :class:`IOSegmentorConfig` object. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - patch_input_shape (tuple): - Size of patches input to the model. The values - are at requested read resolution and must be positive. - patch_output_shape (tuple): - Size of patches output by the model. The values are at - the requested read resolution and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. The values - are at requested read resolution and must be positive. - If not provided, `stride_shape=patch_input_shape` is - used. - resolution (float): - Resolution used for reading the image. - units (Units): - Units of resolution used for reading the image. Choose - from either `"level"`, `"power"` or `"mpp"`. - save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - crash_on_exception (bool): - If `True`, the running loop will crash if there is any - error during processing a WSI. Otherwise, the loop will - move on to the next wsi for processing. - Returns: - list: - A list of tuple(input_path, save_path) where - `input_path` is the path of the input wsi while - `save_path` corresponds to the output predictions. +def concatenate_none( + old_arr: np.ndarray | da.Array, + new_arr: np.ndarray | da.Array, +) -> np.ndarray | da.Array: + """Concatenate arrays, handling None values gracefully. - Examples: - >>> # Sample output of a network - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> predictor = SemanticSegmentor(model='fcn-tissue_mask') - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0.raw') , ('B/wsi.svs', 'output/1.raw')] - >>> # if a network have 2 output heads, each head output of 'A/wsi.svs' - >>> # will be respectively stored in 'output/0.raw.0', 'output/0.raw.1' + This utility function concatenates `new_arr` to `old_arr` along the first axis. + If `old_arr` is None, it returns `new_arr` directly. Supports both NumPy and Dask + arrays. - """ - if mode not in ["wsi", "tile"]: - msg = f"{mode} is not a valid mode. Use either `tile` or `wsi`." - raise ValueError(msg) + Args: + old_arr (np.ndarray | da.Array): + Existing array to append to. Can be None. + new_arr (np.ndarray | da.Array): + New array to append. + + Returns: + np.ndarray | da.Array: + Concatenated array of the same type as `new_arr`. - save_dir, self._cache_dir = self._prepare_save_dir(save_dir) + """ + if isinstance(new_arr, np.ndarray): + return ( + new_arr if old_arr is None else np.concatenate((old_arr, new_arr), axis=0) + ) - if ioconfig is None: - ioconfig = copy.deepcopy(self.ioconfig) + return new_arr if old_arr is None else da.concatenate([old_arr, new_arr], axis=0) - if ioconfig is None and patch_input_shape is None: - msg = ( - "Must provide either `ioconfig` or " - "`patch_input_shape` and `patch_output_shape`" - ) - raise ValueError( - msg, - ) - if resolution is None and units is None: - if ioconfig is None: - msg = f"Invalid resolution: `{resolution}` and units: `{units}`. " - raise ValueError( - msg, - ) +def merge_batch_to_canvas( + blocks: np.ndarray, + output_locations: np.ndarray, + merged_shape: tuple[int, int, int], +) -> tuple[np.ndarray, np.ndarray]: + """Merge patch-level predictions into a single canvas. + + This function aggregates overlapping patch predictions into a unified + output canvas and maintains a count map to normalize overlapping regions. + + Args: + blocks (np.ndarray): + Array of predicted blocks with shape (N, H, W, C), where N is the + number of patches. + output_locations (np.ndarray): + Array of coordinates for each block in the format + [start_x, start_y, end_x, end_y] with shape (N, 4). + merged_shape (tuple[int, int, int]): + Shape of the final merged canvas (H, W, C). + + Returns: + tuple[np.ndarray, np.ndarray]: + - canvas: Merged prediction map of shape (H, W, C). + - count: Count map indicating how many times each pixel was updated, + shape (H, W). + + """ + canvas = np.zeros(merged_shape, dtype=blocks.dtype) + count = np.zeros((*merged_shape[:2], 1), dtype=np.uint8) + for i, block in enumerate(blocks): + xs, ys, xe, ye = output_locations[i] + if not np.any(block): + continue + # To deal with edge cases + canvas[0 : ye - ys, xs:xe, :] += block[0 : ye - ys, 0 : xe - xs, :] + count[ys:ye, xs:xe, 0] += 1 + return canvas, count + + +def merge_horizontal( + canvas: None | da.Array, + count: None | da.Array, + output_locs_y_: np.ndarray, + canvas_np: np.ndarray, + output_locs: np.ndarray, + change_indices: np.ndarray | list[np.ndarray], +) -> tuple[da.Array, da.Array, np.ndarray, np.ndarray, np.ndarray]: + """Merge horizontal patches incrementally for each row of patches. + + This function processes segments of NumPy patch arrays (`canvas_np`, `count_np`, + `output_locs`) based on `change_indices`, merging them horizontally and appending + the results to Dask arrays. It also updates the vertical output locations + (`output_locs_y_`) for downstream vertical merging. - resolution = ioconfig.input_resolutions[0]["resolution"] - units = ioconfig.input_resolutions[0]["units"] - - ioconfig = self._update_ioconfig( - ioconfig, - mode, - patch_input_shape, - patch_output_shape, - stride_shape, - resolution, - units, + Args: + canvas (None | da.Array): + Existing Dask array for canvas data, or None if uninitialized. + count (None | da.Array): + Existing Dask array for count data, or None if uninitialized. + output_locs_y_ (np.ndarray): + Array tracking vertical output locations for merged patches. + canvas_np (np.ndarray): + NumPy array of canvas patches to be merged. + output_locs (np.ndarray): + Array of output locations for each patch. + change_indices (np.ndarray | list[np.ndarray]): + Indices indicating where to flush and merge patches. + + Returns: + tuple: + Updated canvas and count Dask arrays, along with remaining canvas_np, + count_np, output_locs, and output_locs_y_ arrays after processing. + + """ + start_idx = 0 + for c_idx in change_indices: + output_locs_ = output_locs[: c_idx - start_idx] + canvas_np_ = canvas_np[: c_idx - start_idx] + + batch_xs = np.min(output_locs[:, 0], axis=0) + batch_xe = np.max(output_locs[:, 2], axis=0) + + merged_shape = (canvas_np_.shape[1], batch_xe - batch_xs, canvas_np.shape[3]) + + canvas_merge, count_merge = merge_batch_to_canvas( + blocks=canvas_np_, + output_locations=output_locs_, + merged_shape=merged_shape, ) - # use external for testing - self._device = device - self._model = model_to(model=self.model, device=device) + canvas_merge = da.from_array(canvas_merge, chunks=canvas_merge.shape) + count_merge = da.from_array(count_merge, chunks=count_merge.shape) - # workers should be > 0 else Value Error will be thrown - self._prepare_workers() + canvas = concatenate_none(old_arr=canvas, new_arr=canvas_merge) + count = concatenate_none(old_arr=count, new_arr=count_merge) - mp_manager = torch_mp.Manager() - mp_shared_space = mp_manager.Namespace() - self._mp_shared_space = mp_shared_space + output_locs_y_ = concatenate_none( + old_arr=output_locs_y_, new_arr=output_locs[:, (1, 3)] + ) - ds = self.dataset_class( - ioconfig=ioconfig, - preproc=self.model.preproc_func, - wsi_paths=imgs, - mp_shared_space=mp_shared_space, - mode=mode, + canvas_np = canvas_np[c_idx - start_idx :] + output_locs = output_locs[c_idx - start_idx :] + start_idx = c_idx + + return canvas, count, canvas_np, output_locs, output_locs_y_ + + +def save_to_cache( + canvas: da.Array, + count: da.Array, + canvas_zarr: zarr.Array, + count_zarr: zarr.Array, + save_path: str | Path = "temp.zarr", +) -> tuple[zarr.Array, zarr.Array]: + """Save computed canvas and count arrays to Zarr cache. + + This function computes the given Dask arrays (`canvas` and `count`), resizes the + corresponding Zarr datasets to accommodate the new data, and appends the results. + If the Zarr datasets do not exist, it initializes them within the specified + Zarr group. + + Args: + canvas (da.Array): + Dask array representing image or feature data. + count (da.Array): + Dask array representing count or normalization data. + canvas_zarr (zarr.Array): + Existing Zarr dataset for canvas data. If None, a new one is created. + count_zarr (zarr.Array): + Existing Zarr dataset for count data. If None, a new one is created. + save_path (str | Path): + Path to the Zarr group for saving datasets. Defaults to "temp.zarr". + + Returns: + tuple[zarr.Array, zarr.Array]: + Updated Zarr datasets for canvas and count arrays. + + """ + computed_values = compute(*[canvas, count]) + canvas_computed, count_computed = computed_values + + chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) + if canvas_zarr is None: + zarr_group = zarr.open(str(save_path), mode="w") + + canvas_zarr = zarr_group.create_dataset( + name="canvas", + shape=(0, *canvas_computed.shape[1:]), + chunks=(chunk_shape[0], *canvas_computed.shape[1:]), + dtype=canvas_computed.dtype, + overwrite=True, ) - loader = torch_data.DataLoader( - ds, - drop_last=False, - batch_size=self.batch_size, - num_workers=self.num_loader_workers, - persistent_workers=self.num_loader_workers > 0, + count_zarr = zarr_group.create_dataset( + name="count", + shape=(0, *count_computed.shape[1:]), + dtype=count_computed.dtype, + chunks=(chunk_shape[0], *count_computed.shape[1:]), + overwrite=True, ) - self._loader = loader - self.imgs = imgs - self.masks = masks - - # contain input / output prediction mapping - self._outputs = [] - # ? what will happen if this crash midway? - # => may not be able to retrieve the result dict - for wsi_idx, img_path in enumerate(imgs): - self._predict_wsi_handle_exception( - imgs=imgs, - wsi_idx=wsi_idx, - img_path=img_path, - mode=mode, - ioconfig=ioconfig, - save_dir=save_dir, - crash_on_exception=crash_on_exception, + canvas_zarr.resize( + (canvas_zarr.shape[0] + canvas_computed.shape[0], *canvas_zarr.shape[1:]) + ) + canvas_zarr[-canvas_computed.shape[0] :] = canvas_computed + + count_zarr.resize( + (count_zarr.shape[0] + count_computed.shape[0], *count_zarr.shape[1:]) + ) + count_zarr[-count_computed.shape[0] :] = count_computed + + return canvas_zarr, count_zarr + + +def merge_vertical_chunkwise( + canvas: da.Array, + count: da.Array, + output_locs_y_: np.ndarray, + zarr_group: zarr.Group, + save_path: Path, + memory_threshold: int = 80, +) -> da.Array: + """Merge vertically chunked canvas and count arrays into a single probability map. + + This function processes vertically stacked image blocks (`canvas`) and their + associated count arrays to compute normalized probabilities. It handles overlapping + regions between chunks by applying seam folding and trimming halos to ensure smooth + transitions. If a Zarr group is provided, the result is stored incrementally. + + Args: + canvas (da.Array): + Dask array containing image data split into vertical chunks. + count (da.Array): + Dask array containing count data corresponding to the canvas. + output_locs_y_ (np.ndarray): + Array of shape (N, 2) specifying vertical output locations + for each chunk, used to compute overlaps. + zarr_group (zarr.Group): + Zarr group to store the merged probability dataset. + save_path (Path): + Path to save the intermediate output. The intermediate output + is saved in a Zarr file. + memory_threshold (int): + Memory usage threshold (in percentage) to trigger caching behavior. + + Returns: + da.Array: + A merged Dask array of normalized probabilities, either loaded from Zarr + or constructed in memory. + + """ + y0s, y1s = np.unique(output_locs_y_[:, 0]), np.unique(output_locs_y_[:, 1]) + overlaps = np.append(y1s[:-1] - y0s[1:], 0) + + num_chunks = canvas.numblocks[0] + probabilities_zarr, probabilities_da = None, None + chunk_shape = tuple(chunk[0] for chunk in canvas.chunks) + + tqdm = get_tqdm() + tqdm_loop = tqdm(overlaps, leave=False, desc="Merging rows") + + used_percent = 0 + + curr_chunk = canvas.blocks[0, 0].compute() + curr_count = count.blocks[0, 0].compute() + next_chunk = canvas.blocks[1, 0].compute() if num_chunks > 1 else None + next_count = count.blocks[1, 0].compute() if num_chunks > 1 else None + + for i, overlap in enumerate(tqdm_loop): + if next_chunk is not None and overlap > 0: + curr_chunk[-overlap:] += next_chunk[:overlap] + curr_count[-overlap:] += next_count[:overlap] + + # Normalize + curr_count = np.where(curr_count == 0, 1, curr_count) + probabilities = curr_chunk / curr_count.astype(np.float32) + + probabilities_zarr, probabilities_da = store_probabilities( + probabilities=probabilities, + chunk_shape=chunk_shape, + probabilities_zarr=probabilities_zarr, + probabilities_da=probabilities_da, + zarr_group=zarr_group, + ) + + if probabilities_da is not None: + vm = psutil.virtual_memory() + used_percent = (probabilities_da.nbytes / vm.free) * 100 + if probabilities_zarr is None and used_percent > memory_threshold: + msg = ( + f"Current Memory usage: {used_percent} % " + f"exceeds specified threshold: {memory_threshold}. " + f"Saving intermediate results to disk." + ) + tqdm.write(msg) + zarr_group = zarr.open(str(save_path), mode="a") + probabilities_zarr = zarr_group.create_dataset( + name="probabilities", + shape=probabilities_da.shape, + chunks=(chunk_shape[0], *probabilities.shape[1:]), + dtype=probabilities.dtype, + overwrite=True, ) + probabilities_zarr[:] = probabilities_da.compute() - # clean up the cache directories - try: - shutil.rmtree(self._cache_dir) - except PermissionError: # pragma: no cover - logger.warning("Unable to remove %s", self._cache_dir) + probabilities_da = None - self._memory_cleanup() + if next_chunk is not None: + curr_chunk, curr_count = next_chunk[overlap:], next_count[overlap:] + + if i + 2 < num_chunks: + next_chunk = canvas.blocks[i + 2, 0].compute() + next_count = count.blocks[i + 2, 0].compute() + else: + next_chunk, next_count = None, None + + if probabilities_zarr: + if "canvas" in zarr_group: + del zarr_group["canvas"] + if "count" in zarr_group: + del zarr_group["count"] + return da.from_zarr( + probabilities_zarr, chunks=(chunk_shape[0], *probabilities.shape[1:]) + ) - return self._outputs + return probabilities_da -class DeepFeatureExtractor(SemanticSegmentor): - """Generic CNN Feature Extractor. +def store_probabilities( + probabilities: np.ndarray, + chunk_shape: tuple[int, ...], + probabilities_zarr: zarr.Array | None, + probabilities_da: da.Array | None, + zarr_group: zarr.Group | None, +) -> tuple[zarr.Array | None, da.Array | None]: + """Store computed probability data into a Zarr dataset or accumulate in memory. - AN engine for using any CNN model as a feature extractor. Note, if - `model` is supplied in the arguments, it will ignore the - `pretrained_model` and `pretrained_weights` arguments. + If a Zarr group is provided, the function appends the given probability array + to the 'probabilities' dataset, resizing as needed. Otherwise, it concatenates + the array into an existing Dask array for in-memory accumulation. Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. By default, the corresponding - pretrained weights will also be downloaded. However, you can - override with your own set of weights via the - `pretrained_weights` argument. Argument is case-insensitive. - Refer to - :class:`tiatoolbox.models.architecture.vanilla.CNNBackbone` - for list of supported pretrained models. - pretrained_weights (str): - Path to the weight of the corresponding `pretrained_model`. - batch_size (int): - Number of images fed into the model each time. - num_loader_workers (int): - Number of workers to load the data. Take note that they will - also perform preprocessing. - num_postproc_workers (int): - This value is there to maintain input compatibility with - `tiatoolbox.models.classification` and is not used. - verbose (bool): - Whether to output logging information. - dataset_class (obj): - Dataset class to be used instead of default. - auto_generate_mask(bool): - To automatically generate tile/WSI tissue mask if is not - provided. + probabilities (np.ndarray): + Computed probability array to store. + chunk_shape (tuple[int, ...]): + Chunk shape used for Zarr dataset creation. + probabilities_zarr (zarr.Array | None): + Existing Zarr dataset, or None to initialize. + probabilities_da (da.Array | None): + Existing Dask array for in-memory accumulation. + zarr_group (zarr.Group | None): + Zarr group used to create or access the dataset. - Examples: - >>> # Sample output of a network - >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> # create resnet50 with pytorch pretrained weights - >>> model = CNNBackbone('resnet50') - >>> predictor = DeepFeatureExtractor(model=model) - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] - >>> # If a network have 2 output heads, for 'A/wsi.svs', - >>> # there will be 3 outputs, and they are respectively stored at - >>> # 'output/0.position.npy' # will always be output - >>> # 'output/0.features.0.npy' # output of head 0 - >>> # 'output/0.features.1.npy' # output of head 1 - >>> # Each file will contain a same number of items, and the item at each - >>> # index corresponds to 1 patch. The item in `.*position.npy` will - >>> # be the corresponding patch bounding box. The box coordinates are at - >>> # the inference resolution defined within the provided `ioconfig`. + Returns: + tuple[zarr.Array | None, da.Array | None]: + Updated Zarr dataset and/or Dask array. """ + if zarr_group is not None: + if probabilities_zarr is None: + probabilities_zarr = zarr_group.create_dataset( + name="probabilities", + shape=(0, *probabilities.shape[1:]), + chunks=(chunk_shape[0], *probabilities.shape[1:]), + dtype=probabilities.dtype, + ) - def __init__( - self: DeepFeatureExtractor, - batch_size: int = 8, - num_loader_workers: int = 0, - num_postproc_workers: int = 0, - model: torch.nn.Module | None = None, - pretrained_model: str | None = None, - pretrained_weights: str | None = None, - dataset_class: Callable = WSIStreamDataset, - *, - verbose: bool = True, - auto_generate_mask: bool = False, - ) -> None: - """Initialize :class:`DeepFeatureExtractor`.""" - super().__init__( - batch_size=batch_size, - num_loader_workers=num_loader_workers, - num_postproc_workers=num_postproc_workers, - model=model, - pretrained_model=pretrained_model, - pretrained_weights=pretrained_weights, - verbose=verbose, - auto_generate_mask=auto_generate_mask, - dataset_class=dataset_class, + probabilities_zarr.resize( + ( + probabilities_zarr.shape[0] + probabilities.shape[0], + *probabilities_zarr.shape[1:], + ) + ) + probabilities_zarr[-probabilities.shape[0] :] = probabilities + else: + probabilities_da = concatenate_none( + old_arr=probabilities_da, + new_arr=da.from_array( + probabilities, chunks=(chunk_shape[0], *probabilities.shape[1:]) + ), ) - self.process_prediction_per_batch = False - - def _process_predictions( - self: DeepFeatureExtractor, - cum_batch_predictions: list, - wsi_reader: WSIReader, # skipcq: PYL-W0613 # noqa: ARG002 - ioconfig: IOSegmentorConfig, - save_path: str, - cache_dir: str, # skipcq: PYL-W0613 # noqa: ARG002 - ) -> None: - """Define how the aggregated predictions are processed. - This includes merging the prediction if necessary and also - saving afterward. + return probabilities_zarr, probabilities_da - Args: - cum_batch_predictions (list): - List of batch predictions. Each item within the list - should be of (location, patch_predictions). - wsi_reader (:class:`WSIReader`): - A reader for the image where the predictions come from. - Not used here. Added for consistency with the API. - ioconfig (:class:`IOSegmentorConfig`): - A configuration object contains input and output - information. - save_path (str): - Root path to save current WSI predictions. - cache_dir (str): - Root path to cache current WSI data. - Not used here. Added for consistency with the API. - """ - # assume prediction_list is N, each item has L output elements - location_list, prediction_list = list(zip(*cum_batch_predictions)) - # Nx4 (N x [tl_x, tl_y, br_x, br_y), denotes the location of output - # patch, this can exceed the image bound at the requested resolution - # remove singleton due to split. - location_list = np.array([v[0] for v in location_list]) - np.save(f"{save_path}.position.npy", location_list) - for idx, _ in enumerate(ioconfig.output_resolutions): - # assume resolution idx to be in the same order as L - # 0 idx is to remove singleton without removing other axes singleton - prediction_list = [v[idx][0] for v in prediction_list] - prediction_list = np.array(prediction_list) - np.save(f"{save_path}.features.{idx}.npy", prediction_list) - - def predict( # noqa: PLR0913 - self: DeepFeatureExtractor, - imgs: list, - masks: list | None = None, - mode: str = "tile", - ioconfig: IOSegmentorConfig | None = None, - patch_input_shape: IntPair | None = None, - patch_output_shape: IntPair | None = None, - stride_shape: IntPair = None, - resolution: Resolution = 1.0, - units: Units = "baseline", - save_dir: str | Path | None = None, - device: str = "cpu", - *, - crash_on_exception: bool = False, - ) -> list[tuple[Path, Path]]: - """Make a prediction for a list of input data. - - By default, if the input model at the time of object - instantiation is a pretrained model in the toolbox as well as - `patch_input_shape`, `patch_output_shape`, `stride_shape`, - `resolution`, `units` and `ioconfig` are `None`. The method will - use the `ioconfig` retrieved together with the pretrained model. - Otherwise, either `patch_input_shape`, `patch_output_shape`, - `stride_shape`, `resolution`, `units` or `ioconfig` must be set - - else a `Value Error` will be raised. +def prepare_full_batch( + batch_output: np.ndarray, + batch_locs: np.ndarray, + full_output_locs: np.ndarray, + output_locs: np.ndarray, + *, + is_last: bool, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Prepare full-sized output and count arrays for a batch of patch predictions. - Args: - imgs (list, ndarray): - List of inputs to process. When using `"patch"` mode, - the input must be either a list of images, a list of - image file paths or a numpy array of an image list. When - using `"tile"` or `"wsi"` mode, the input must be a list - of file paths. - masks (list): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for each - whole-slide image or all image tiles in the entire image - are processed. - mode (str): - Type of input to process. Choose from either `tile` or - `wsi`. - ioconfig (:class:`IOSegmentorConfig`): - Object that defines information about input and output - placement of patches. When provided, - `patch_input_shape`, `patch_output_shape`, - `stride_shape`, `resolution`, and `units` arguments are - ignored. Otherwise, those arguments will be internally - converted to a :class:`IOSegmentorConfig` object. - device (str): - :class:`torch.device` to run the model. - Select the device to run the model. Please see - https://pytorch.org/docs/stable/tensor_attributes.html#torch.device - for more details on input parameters for device. Default value is "cpu". - patch_input_shape (IntPair): - Size of patches input to the model. The values are at - requested read resolution and must be positive. - patch_output_shape (tuple): - Size of patches output by the model. The values are at - the requested read resolution and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. The values - are at requested read resolution and must be positive. - If not provided, `stride_shape=patch_input_shape` is - used. - resolution (Resolution): - Resolution used for reading the image. - units (Units): - Units of resolution used for reading the image. - save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - crash_on_exception (bool): - If `True`, the running loop will crash if there is any - error during processing a WSI. Otherwise, the loop will - move on to the next wsi for processing. + This function aligns patch-level predictions with global output locations when + a mask (e.g., auto_get_mask) is applied. It initializes full-sized arrays and + fills them using matched indices. If the batch is the last in the sequence, + it pads the arrays to cover remaining locations. - Returns: - list: - A list of tuple(input_path, save_path) where - `input_path` is the path of the input wsi while - `save_path` corresponds to the output predictions. + Args: + batch_output (np.ndarray): + Patch-level model predictions of shape (N, H, W, C). + batch_locs (np.ndarray): + Output locations corresponding to `batch_output`. + full_output_locs (np.ndarray): + Remaining global output locations to be matched. + output_locs (np.ndarray): + Accumulated output location array across batches. + is_last (bool): + Flag indicating whether this is the final batch. - Examples: - >>> # Sample output of a network - >>> from tiatoolbox.models.architecture.vanilla import CNNBackbone - >>> wsis = ['A/wsi.svs', 'B/wsi.svs'] - >>> # create resnet50 with pytorch pretrained weights - >>> model = CNNBackbone('resnet50') - >>> predictor = DeepFeatureExtractor(model=model) - >>> output = predictor.predict(wsis, mode='wsi') - >>> list(output.keys()) - [('A/wsi.svs', 'output/0') , ('B/wsi.svs', 'output/1')] - >>> # If a network have 2 output heads, for 'A/wsi.svs', - >>> # there will be 3 outputs, and they are respectively stored at - >>> # 'output/0.position.npy' # will always be output - >>> # 'output/0.features.0.npy' # output of head 0 - >>> # 'output/0.features.1.npy' # output of head 1 - >>> # Each file will contain a same number of items, and the item at each - >>> # index corresponds to 1 patch. The item in `.*position.npy` will - >>> # be the corresponding patch bounding box. The box coordinates are at - >>> # the inference resolution defined within the provided `ioconfig`. + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + - full_batch_output: Full-sized output array with predictions placed. + - full_output_locs: Updated remaining global output locations. + - output_locs: Updated accumulated output locations. - """ - return super().predict( - imgs=imgs, - masks=masks, - mode=mode, - device=device, - ioconfig=ioconfig, - patch_input_shape=patch_input_shape, - patch_output_shape=patch_output_shape, - stride_shape=stride_shape, - resolution=resolution, - units=units, - save_dir=save_dir, - crash_on_exception=crash_on_exception, + """ + # Use np.intersect1d once numpy version is upgraded to 2.0 + full_output_dict = {tuple(row): i for i, row in enumerate(full_output_locs)} + matches = [full_output_dict[tuple(row)] for row in batch_locs] + + total_size = np.max(matches).astype(np.uint16) + 1 + + # Initialize full output array + full_batch_output = np.zeros( + shape=(total_size, *batch_output.shape[1:]), + dtype=batch_output.dtype, + ) + + # Place matching outputs using matching indices + full_batch_output[matches] = batch_output + + output_locs = concatenate_none( + old_arr=output_locs, new_arr=full_output_locs[:total_size] + ) + full_output_locs = full_output_locs[total_size:] + + if is_last: + output_locs = concatenate_none(old_arr=output_locs, new_arr=full_output_locs) + full_batch_output = concatenate_none( + old_arr=full_batch_output, + new_arr=np.zeros( + shape=(len(full_output_locs), *batch_output.shape[1:]), dtype=np.uint8 + ), ) + + return full_batch_output, full_output_locs, output_locs diff --git a/tiatoolbox/models/models_abc.py b/tiatoolbox/models/models_abc.py index 412e95607..43a3cec5d 100644 --- a/tiatoolbox/models/models_abc.py +++ b/tiatoolbox/models/models_abc.py @@ -102,7 +102,9 @@ def forward( @staticmethod @abstractmethod - def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dict: + def infer_batch( + model: nn.Module, batch_data: np.ndarray | torch.Tensor, *, device: str + ) -> np.ndarray | tuple[np.ndarray, ...] | dict: """Run inference on an input batch. Contains logic for forward operation as well as I/O aggregation. @@ -110,13 +112,15 @@ def infer_batch(model: nn.Module, batch_data: np.ndarray, *, device: str) -> dic Args: model (nn.Module): PyTorch defined model. - batch_data (np.ndarray): + batch_data (np.ndarray | torch.Tensor): A batch of data generated by `torch.utils.data.DataLoader`. device (str): - Transfers model to the specified device. Default is "cpu". + Transfers model to the specified device. Returns: + np.ndarray: + The inference results as a numpy array. dict: Returns a dictionary of predictions and other expected outputs depending on the network architecture. diff --git a/tiatoolbox/utils/exceptions.py b/tiatoolbox/utils/exceptions.py index db74af710..2f9f2a126 100644 --- a/tiatoolbox/utils/exceptions.py +++ b/tiatoolbox/utils/exceptions.py @@ -33,3 +33,23 @@ def __init__( ) -> None: """Initialize :class:`MethodNotSupportedError`.""" super().__init__(message) + + +class DimensionMismatchError(Exception): + """Raise dimension mismatch error. + + Args: + expected_dims (list or tuple) : Expected dimensions. + actual_dims (list or tuple) : Actual dimensions. + + """ + + def __init__( + self: DimensionMismatchError, + expected_dims: list | tuple, + actual_dims: list | tuple, + ) -> None: + """Initialize :class:`DimensionMismatchError`.""" + self.expected_dims = expected_dims + self.actual_dims = actual_dims + super().__init__(f"Expected dimensions {expected_dims}, but got {actual_dims}.") diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index cff2c2a6b..4dcd82d48 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -11,8 +11,8 @@ from typing import IO, TYPE_CHECKING, cast import cv2 +import dask.array as da import joblib -import numcodecs import numpy as np import pandas as pd import requests @@ -24,10 +24,12 @@ from shapely.geometry import Polygon from shapely.geometry import shape as feature2geometry from skimage import exposure -from tqdm import trange +from tqdm import notebook as tqdm_notebook +from tqdm import tqdm, trange from tiatoolbox import logger from tiatoolbox.annotation.storage import Annotation, AnnotationStore, SQLiteStore +from tiatoolbox.utils.env_detection import is_notebook from tiatoolbox.utils.exceptions import FileNotSupportedError if TYPE_CHECKING: # pragma: no cover @@ -163,7 +165,7 @@ def imwrite(image_path: PathLike, img: np.ndarray) -> None: def imread(image_path: PathLike, *, as_uint8: bool | None = None) -> np.ndarray: - """Read an image as a NumPy array. + """Read an image as :class:`numpy.ndarray`. Args: image_path (PathLike): @@ -1367,12 +1369,11 @@ def dict_to_store_semantic_segmentor( for each patch. """ - preds = patch_output["predictions"] + preds = da.from_array(patch_output["predictions"], chunks="auto") # Get the number of unique predictions - layer_list = np.unique(preds) - - layer_list = np.delete(layer_list, np.where(layer_list == 0)) + layer_list = da.unique(preds).compute() + layer_list = layer_list[layer_list != 0] store = SQLiteStore() @@ -1381,12 +1382,13 @@ def dict_to_store_semantic_segmentor( annotations_list: list[Annotation] = [] for type_class in layer_list: - layer = np.where(preds == type_class, 1, 0) + layer = da.where(preds == type_class, 1, 0).astype("uint8").compute() contours, hierarchy = cv2.findContours( - layer.astype("uint8"), + layer, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE, ) + contours = cast("list[np.ndarray]", contours) annotations_list_ = process_contours(contours, hierarchy, scale_factor) @@ -1409,13 +1411,13 @@ def dict_to_store_semantic_segmentor( return store -def dict_to_store( +def dict_to_store_patch_predictions( patch_output: dict | zarr.group, scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, ) -> AnnotationStore | Path: - """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore. + """Converts output of TIAToolbox PatchPredictor engine to AnnotationStore. Args: patch_output (dict | zarr.Group): @@ -1451,7 +1453,7 @@ def dict_to_store( patch_coords = np.array(patch_output.get("coordinates", [])) if not np.all(np.array(scale_factor) == 1): patch_coords = patch_coords * (np.tile(scale_factor, 2)) # to baseline mpp - patch_coords = patch_coords.astype(float) + labels = patch_output.get("labels", []) # get classes to consider if len(class_probs) == 0: @@ -1472,11 +1474,11 @@ def dict_to_store( # put patch predictions into a store annotations = patch_predictions_as_annotations( - preds, + preds.astype(float), keys, class_dict, - class_probs, - patch_coords, + class_probs.astype(float), + patch_coords.astype(float), classes_predicted, labels, ) @@ -1502,6 +1504,27 @@ def _tiles( colormap: int = cv2.COLORMAP_JET, level: int = 0, ) -> Iterator[np.ndarray]: + """Generate color-mapped tiles from an input image or Zarr array. + + This function iterates over the input image in non-overlapping tiles of the + specified size, optionally downsampling by a power-of-two factor (`level`), + and applies a colormap to each tile before yielding it. + + Parameters: + in_img (np.ndarray | zarr.core.Array): + Input image or Zarr array to be tiled. + tile_size (tuple[int, int]): + Height and width of each tile. + colormap (int, optional): + OpenCV colormap to apply to each tile. Defaults to cv2.COLORMAP_JET. + level (int, optional): + Downsampling factor as a power of two. Defaults to 0 (no downsampling). + + Yields: + np.ndarray: + A color-mapped tile extracted from the input image. + + """ for y in trange(0, in_img.shape[0], tile_size[0]): for x in range(0, in_img.shape[1], tile_size[1]): in_img_ = in_img[ @@ -1607,185 +1630,42 @@ def write_probability_heatmap_as_ome_tiff( logger.info(msg) -def dict_to_zarr( - raw_predictions: dict, - save_path: Path, - **kwargs: dict, -) -> Path: - """Saves the output of TIAToolbox engines to a zarr file. +def get_tqdm() -> type[tqdm_notebook | tqdm]: + """Returns appropriate tqdm tqdm object.""" + if is_notebook(): # pragma: no cover + return tqdm_notebook.tqdm + return tqdm - Args: - raw_predictions (dict): - A dictionary in the TIAToolbox Engines output format. - save_path (str or Path): - Path to save the zarr file. - **kwargs (dict): - Keyword Args to update patch_pred_store_zarr attributes. +def cast_to_min_dtype(array: np.ndarray | da.Array) -> np.ndarray | da.Array: + """Cast the input array to the minimal data type required to represent its values. - Returns: - Path to zarr file storing the patch predictor output - - """ - # Default values for Compressor and Chunks set if not received from kwargs. - compressor = ( - kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) - ) - chunks = kwargs.get("chunks", 10000) - - # ensure proper zarr extension - save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") - - # save to zarr - probabilities_array = np.array(raw_predictions["probabilities"]) - z = zarr.open( - str(save_path), - mode="w", - shape=probabilities_array.shape, - chunks=chunks, - compressor=compressor, - ) - z[:] = probabilities_array - - return save_path + This function determines the maximum value in the array and casts it to the smallest + unsigned integer type (or boolean) that can accommodate all values. It supports both + NumPy and Dask arrays and preserves the input type in the output. - -def wsi_batch_output_to_zarr_group( - wsi_batch_zarr_group: zarr.group | None, - batch_output_probabilities: np.ndarray, - batch_output_predictions: np.ndarray, - batch_output_coordinates: np.ndarray | None, - batch_output_label: np.ndarray | None, - save_path: Path, - **kwargs: dict, -) -> zarr.group | Path: - """Saves the intermediate batch outputs of TIAToolbox engines to a zarr file. + For Dask arrays, the maximum value is computed lazily and only when needed. Args: - wsi_batch_zarr_group (zarr.group): - Optional zarr group name consisting of zarrs to save the batch output - values. - batch_output_probabilities (np.ndarray): - Probability batch output from infer wsi. - batch_output_predictions (np.ndarray): - Predictions batch output from infer wsi. - batch_output_coordinates (np.ndarray): - Coordinates batch output from infer wsi. - batch_output_label (np.ndarray): - Labels batch output from infer wsi. - save_path (str or Path): - Path to save the zarr file. - **kwargs (dict): - Keyword Args to update wsi_batch_output_to_zarr_group attributes. + array (Union[np.ndarray, da.Array]): Input array containing integer values. Returns: - Path to the zarr file storing the :class:`EngineABC` output. + (np.ndarray or da.Array): + A copy of the input array cast to the minimal required dtype. + - If the maximum value is 1, the array is cast to boolean. + - Otherwise, it is cast to the smallest suitable unsigned integer type. """ - # Default values for Compressor and Chunks set if not received from kwargs. - compressor = ( - kwargs["compressor"] if "compressor" in kwargs else numcodecs.Zstd(level=1) - ) - chunks = kwargs.get("chunks", 10000) - - # case 1 - new zarr group - if not wsi_batch_zarr_group: - # ensure proper zarr extension and create persistant zarr group - save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") - wsi_batch_zarr_group = zarr.open(save_path, mode="w") - - # populate the zarr group for the first time - probabilities_zarr = wsi_batch_zarr_group.create_dataset( - name="probabilities", - shape=batch_output_probabilities.shape, - chunks=chunks, - compressor=compressor, - ) - probabilities_zarr[:] = batch_output_probabilities - - predictions_zarr = wsi_batch_zarr_group.create_dataset( - name="predictions", - shape=batch_output_predictions.shape, - chunks=chunks, - compressor=compressor, - ) - predictions_zarr[:] = batch_output_predictions - - if batch_output_coordinates is not None: - coordinates_zarr = wsi_batch_zarr_group.create_dataset( - name="coordinates", - shape=batch_output_coordinates.shape, - chunks=chunks, - compressor=compressor, - ) - coordinates_zarr[:] = batch_output_coordinates - - if batch_output_label is not None: - labels_zarr = wsi_batch_zarr_group.create_dataset( - name="labels", - shape=batch_output_label.shape, - chunks=chunks, - compressor=compressor, - ) - labels_zarr[:] = batch_output_label - - # case 2 - append to existing zarr group - probabilities_zarr = wsi_batch_zarr_group["probabilities"] - probabilities_zarr.append(batch_output_probabilities) - - predictions_zarr = wsi_batch_zarr_group["predictions"] - predictions_zarr.append(batch_output_predictions) - - if batch_output_coordinates is not None: - coordinates_zarr = wsi_batch_zarr_group["coordinates"] - coordinates_zarr.append(batch_output_coordinates) - - if batch_output_label is not None: - labels_zarr = wsi_batch_zarr_group["labels"] - labels_zarr.append(batch_output_label) - - return wsi_batch_zarr_group - - -def write_to_zarr_in_cache_mode( - zarr_group: zarr.group, - output_data_to_save: dict, - **kwargs: dict, -) -> zarr.group | Path: - """Saves the intermediate batch outputs of TIAToolbox engines to a zarr file. - - Args: - zarr_group (zarr.group): - Zarr group name consisting of zarr(s) to save the batch output - values. - output_data_to_save (dict): - Output data from the Engine to save to Zarr. Expects the data saved in - dictionary to be a numpy array. - **kwargs (dict): - Keyword Args to update zarr_group attributes. - - Returns: - Path to the zarr file storing the :class:`EngineABC` output. - - """ - # Default values for Compressor and Chunks set if not received from kwargs. - compressor = kwargs.get("compressor", numcodecs.Zstd(level=1)) - - # case 1 - new zarr group - if not zarr_group: - for key, value in output_data_to_save.items(): - # populate the zarr group for the first time - zarr_dataset = zarr_group.create_dataset( - name=key, - shape=value.shape, - compressor=compressor, - ) - zarr_dataset[:] = value + is_dask = isinstance(array, da.Array) + max_value = da.max(array) if is_dask else np.max(array) + max_value = max_value.compute() if is_dask else max_value - return zarr_group + if max_value == 1: + return array.astype(bool) - # case 2 - append to existing zarr group - for key, value in output_data_to_save.items(): - zarr_group[key].append(value) + dtypes = [np.uint8, np.uint16, np.uint32, np.uint64] + for dtype in dtypes: + if max_value <= np.iinfo(dtype).max: + return array.astype(dtype) - return zarr_group + return array diff --git a/tiatoolbox/utils/transforms.py b/tiatoolbox/utils/transforms.py index 8c2817b75..9f1f901c6 100644 --- a/tiatoolbox/utils/transforms.py +++ b/tiatoolbox/utils/transforms.py @@ -95,7 +95,7 @@ def imresize( img: np.ndarray, scale_factor: float | tuple[float, float] | None = None, output_size: int | tuple[int, int] | None = None, - interpolation: str = "optimise", + interpolation: str | int = "optimise", ) -> np.ndarray: """Resize input image.