diff --git a/README.rst b/README.rst index de93996..097c04b 100644 --- a/README.rst +++ b/README.rst @@ -95,6 +95,20 @@ To export Images or Plates via the OMERO API:: # By default, a tile (chunk) size of 1024 is used. Specify values with $ omero zarr export Image:1 --tile_width 256 --tile_height 256 + # To exclude Wells from a Plate export, based on Key-Value pairs + # (map annotations) use --skip_wells_map key:value. Supports wildcards. + $ omero zarr export Plate:1 --skip_wells_map my_key:my_value + $ omero zarr export Plate:1 --skip_wells_map my_key:my* + $ omero zarr export Plate:1 --skip_wells_map my_key:*value + $ omero zarr export Plate:1 --skip_wells_map my_key:*val* + $ omero zarr export Plate:1 --skip_wells_map my_key:* + + # Use --metadata_only to export only metadata, no pixel data + $ omero zarr export Image:1 --metadata_only + $ omero zarr export Plate:2 --metadata_only + + # To export Key-Value pairs from Wells in a Plate as a CSV file, + $ omero zarr export_csv Plate:1 --skip_wells_map my_key:* NB: If the connection to OMERO is lost and the Image is partially exported, re-running the command will attempt to complete the export. diff --git a/src/omero_zarr/cli.py b/src/omero_zarr/cli.py index 98bb79a..91f6782 100644 --- a/src/omero_zarr/cli.py +++ b/src/omero_zarr/cli.py @@ -29,6 +29,7 @@ from zarr.hierarchy import open_group from zarr.storage import FSStore +from .kvp_tables import plate_to_table from .masks import ( MASK_DTYPE_SIZE, MaskSaver, @@ -202,15 +203,6 @@ def _configure(self, parser: Parser) -> None: help=("Name of the array that will be stored. Ignored for --style=split"), default="0", ) - polygons.add_argument( - "--name_by", - default="id", - choices=["id", "name"], - help=( - "How the existing Image or Plate zarr is named. Default 'id' is " - "[ID].ome.zarr. 
'name' is [NAME].ome.zarr" - ), - ) masks = parser.add(sub, self.masks, MASKS_HELP) masks.add_argument( @@ -261,15 +253,6 @@ def _configure(self, parser: Parser) -> None: "overlapping labels" ), ) - masks.add_argument( - "--name_by", - default="id", - choices=["id", "name"], - help=( - "How the existing Image or Plate zarr is named. Default 'id' is " - "[ID].ome.zarr. 'name' is [NAME].ome.zarr" - ), - ) export = parser.add(sub, self.export, EXPORT_HELP) export.add_argument( @@ -305,24 +288,44 @@ def _configure(self, parser: Parser) -> None: help="Maximum number of workers (only for use with bioformats2raw)", ) export.add_argument( - "--name_by", - default="id", - choices=["id", "name"], - help=( - "How to name the Image or Plate zarr. Default 'id' is [ID].ome.zarr. " - "'name' is [NAME].ome.zarr" - ), + "object", + type=ProxyStringType("Image"), + help="The Image to export.", ) export.add_argument( + "--metadata_only", + action="store_true", + help="Only write metadata, do not export pixel data", + ) + + # CSV export + csv = parser.add(sub, self.export_csv, "Export Key-Value pairs as csv") + csv.add_argument( "object", type=ProxyStringType("Image"), - help="The Image to export.", + help="The Plate from which to export Key-Value pairs.", ) - for subcommand in (polygons, masks, export): + # Need same arguments for Images and Masks + for subcommand in (polygons, masks, export, csv): subcommand.add_argument( "--output", type=str, default="", help="The output directory" ) + subcommand.add_argument( + "--skip_wells_map", + type=str, + help="For Plates, skip wells with MapAnnotation values" + "matching this key-value pair. e.g. 'MyKey:MyVal*'", + ) + subcommand.add_argument( + "--name_by", + default="id", + choices=["id", "name"], + help=( + "How to name the Image or Plate zarr. Default 'id' is " + "[ID].ome.zarr. 
def plate_to_table(
    plate: omero.gateway._PlateWrapper, args: argparse.Namespace
) -> None:
    """
    Exports Well KVPs (map annotations) of a Plate to a CSV table.

    The CSV has one row per exported Well, ordered by row then column.
    The first two columns are "Plate" (plate name) and "Well" (well
    position); the remaining columns are the sorted union of all keys found
    on the exported Wells. Multiple values for one key are joined with ";".

    plate: the Plate whose Wells are exported.
    args: parsed CLI arguments. Reads ``output`` and ``name_by`` (used to
        derive the file name via get_zarr_name) and ``skip_wells_map``
        (optional "key:value" filter, wildcards allowed, for Wells to omit).
    """
    name = get_zarr_name(plate, args.output, args.name_by)
    skip_wells_map = args.skip_wells_map

    # sort by row then column...
    wells = sorted(plate.listChildren(), key=lambda x: (x.row, x.column))
    well_count = len(wells)

    # Annotations are loaded for every well, since the skip filter needs them
    well_kvps_by_id = get_map_anns(wells)

    if skip_wells_map:
        # skip_wells_map is like MyKey:MyValue.
        # Or wild-card MyKey:* or MyKey:Val*
        wells = [
            well
            for well in wells
            if not map_anns_match(well_kvps_by_id.get(well.id, {}), skip_wells_map)
        ]
        print(
            f"Skipping {well_count - len(wells)} out of {well_count} wells"
            f" with skip_wells_map: {skip_wells_map}"
        )

    # Columns are the union of keys on the wells that are actually exported
    column_names = sorted(
        {key for well in wells for key in well_kvps_by_id.get(well.id, {})}
    )

    print("column_names", column_names)

    plate_name = plate.getName()

    # Only swap the trailing ".ome.zarr" for ".csv". A plain str.replace()
    # would also rewrite an output directory that contains ".ome.zarr".
    # NOTE(review): assumes get_zarr_name() returns a name ending in
    # ".ome.zarr"; otherwise ".csv" is appended - confirm.
    zarr_suffix = ".ome.zarr"
    if name.endswith(zarr_suffix):
        csv_name = name[: -len(zarr_suffix)] + ".csv"
    else:
        csv_name = f"{name}.csv"

    # write csv file...
    print(f"Writing CSV file: {csv_name}")
    # Explicit UTF-8 so non-ASCII keys/values don't depend on the locale
    with open(csv_name, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Plate", "Well"] + column_names)

        for well in wells:
            kvps = well_kvps_by_id.get(well.id, {})
            row = [plate_name, f"{well.getWellPos()}"]
            # a key missing on this well gives an empty cell
            row.extend(";".join(kvps.get(key, [])) for key in column_names)
            writer.writerow(row)
args.skip_wells_map + wells = list(plate.listChildren()) + if skip_wells_map: + # skip_wells_map is like MyKey:MyValue. + # Or wildcard MyKey:* or MyKey:Val* + well_kvps_by_id = get_map_anns(wells) + wells = [ + well + for well in wells + if not map_anns_match(well_kvps_by_id.get(well.id, {}), skip_wells_map) + ] + + # sort by row then column... + wells = sorted(wells, key=lambda x: (x.row, x.column)) + + for well in wells: row = plate.getRowLabels()[well.row] col = plate.getColumnLabels()[well.column] for field in range(n_fields[0], n_fields[1] + 1): @@ -293,12 +311,13 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None: # Figure out whether we can flatten some dimensions unique_dims: Dict[str, Set[int]] = { "T": {unwrap(mask.theT) for shapes in masks for mask in shapes}, + "C": {unwrap(mask.theC) for shapes in masks for mask in shapes}, "Z": {unwrap(mask.theZ) for shapes in masks for mask in shapes}, } ignored_dimensions: Set[str] = set() - # We always ignore the C dimension - ignored_dimensions.add("C") print(f"Unique dimensions: {unique_dims}") + if unique_dims["C"] == {None} or len(unique_dims["C"]) == 1: + ignored_dimensions.add("C") for d in "TZ": if unique_dims[d] == {None}: @@ -308,8 +327,6 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None: # Verify that we are linking this mask to a real ome-zarr source_image = self.source_image - print(f"source_image ??? 
needs to be None to use filename: {source_image}") - print(f"filename: {filename}", self.output, self.name_by) source_image_link = self.source_image if source_image is None: # Assume that we're using the output directory @@ -320,18 +337,28 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None: assert self.plate_path, "Need image path within the plate" source_image = f"{source_image}/{self.plate_path}" - print(f"source_image {source_image}") + print(f"Exporting labels for image at {source_image}") image_path = source_image if self.output: image_path = os.path.join(self.output, source_image) src = parse_url(image_path) assert src, f"Source image does not exist at {image_path}" + + store = open_store(image_path) + try: + # Check if labels group already exists... + open_group(store, path=f"labels/{name}", mode="r") + print(f"Labels group: {name} already exists in {image_path}") + # and if so, we assume that array data is already there + return + except GroupNotFoundError: + pass + input_pyramid = Node(src, []) assert input_pyramid.load(Multiscales), "No multiscales metadata found" input_pyramid_levels = len(input_pyramid.data) - store = open_store(image_path) - label_group = open_group(store) + image_group = open_group(store) _mask_shape: List[int] = list(self.image_shape) mask_shape: Tuple[int, ...] 
= tuple(_mask_shape) @@ -385,7 +412,7 @@ def save(self, masks: List[omero.model.Shape], name: str) -> None: write_multiscale_labels( label_pyramid, - label_group, + image_group, name, axes=axes, coordinate_transformations=transformations, @@ -557,6 +584,7 @@ def masks_to_labels( if shape.fillColor: fillColors[shape_value] = unwrap(shape.fillColor) binim_yx, (t, c, z, y, x, h, w) = self.shape_to_binim_yx(shape) + # if z, c or t are None, we apply the mask to all Z, C or T indices for i_t in self._get_indices(ignored_dimensions, "T", t, size_t): for i_c in self._get_indices(ignored_dimensions, "C", c, size_c): for i_z in self._get_indices( diff --git a/src/omero_zarr/raw_pixels.py b/src/omero_zarr/raw_pixels.py index ce66725..6f39253 100644 --- a/src/omero_zarr/raw_pixels.py +++ b/src/omero_zarr/raw_pixels.py @@ -49,7 +49,9 @@ from . import __version__ from . import ngff_version as VERSION from .util import ( + get_map_anns, get_zarr_name, + map_anns_match, marshal_axes, marshal_transformations, open_store, @@ -64,7 +66,13 @@ def image_to_zarr(image: omero.gateway.ImageWrapper, args: argparse.Namespace) - print(f"Exporting to {name} ({VERSION})") store = open_store(name) root = open_group(store) - add_image(image, root, tile_width=tile_width, tile_height=tile_height) + add_image( + image, + root, + tile_width=tile_width, + tile_height=tile_height, + metadata_only=args.metadata_only, + ) add_omero_metadata(root, image) add_toplevel_metadata(root) print("Finished.") @@ -75,6 +83,7 @@ def add_image( parent: Group, tile_width: Optional[int] = None, tile_height: Optional[int] = None, + metadata_only: bool = False, ) -> Tuple[int, List[Dict[str, Any]]]: """Adds an OMERO image pixel data as array to the given parent zarr group. Returns the number of resolution levels generated for the image. 
@@ -91,7 +100,9 @@ def add_image( longest = longest // 2 level_count += 1 - paths = add_raw_image(image, parent, level_count, tile_width, tile_height) + paths = add_raw_image( + image, parent, level_count, tile_width, tile_height, metadata_only + ) axes = marshal_axes(image) transformations = marshal_transformations(image, len(paths)) @@ -111,6 +122,7 @@ def add_raw_image( level_count: int, tile_width: Optional[int] = None, tile_height: Optional[int] = None, + metadata_only: bool = False, ) -> List[str]: pixels = image.getPrimaryPixels() omero_dtype = image.getPixelsType() @@ -159,6 +171,12 @@ def add_raw_image( chunks=chunks, dtype=d_type, ) + paths = [str(level) for level in range(level_count)] + + if metadata_only: + # Skip export of pixel data, but still create empty arrays + downsample_pyramid_on_disk(parent, paths) + return paths # Need to be sure that dims match (if array already existed) assert zarray.shape == shape @@ -200,8 +218,6 @@ def add_raw_image( tile = pixels.getTile(z, c, t, tile_dims) zarray[tuple(indices)] = tile - paths = [str(level) for level in range(level_count)] - downsample_pyramid_on_disk(parent, paths) return paths @@ -269,6 +285,7 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) n_fields = plate.getNumberOfFields() total = n_rows * n_cols * (n_fields[1] - n_fields[0] + 1) name = get_zarr_name(plate, args.output, args.name_by) + skip_wells_map = args.skip_wells_map store = open_store(name) print(f"Exporting to {name} ({VERSION})") @@ -278,7 +295,7 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) max_fields = 0 t0 = time.time() - well_paths = set() + well_paths = [] col_names = [str(name) for name in plate.getColumnLabels()] row_names = [str(name) for name in plate.getRowLabels()] @@ -289,9 +306,24 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) if acquisitions: plate_acq = [marshal_acquisition(x) for x in acquisitions] - wells = 
plate.listChildren() + wells = list(plate.listChildren()) # sort by row then column... wells = sorted(wells, key=lambda x: (x.row, x.column)) + well_count = len(wells) + + if skip_wells_map: + # skip_wells_map is like MyKey:MyValue. + # Or wild-card MyKey:* or MyKey:Val* + well_kvps_by_id = get_map_anns(wells) + wells = [ + well + for well in wells + if not map_anns_match(well_kvps_by_id.get(well.id, {}), skip_wells_map) + ] + print( + f"Skipping {well_count - len(wells)} out of {well_count} wells" + f" with skip_wells_map: {skip_wells_map}" + ) for well in wells: row = plate.getRowLabels()[well.row] @@ -304,7 +336,8 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) field_name = "%d" % field count += 1 img = ws.getImage() - well_paths.add(f"{row}/{col}") + if f"{row}/{col}" not in well_paths: + well_paths.append(f"{row}/{col}") field_info = {"path": f"{field_name}"} if ac: field_info["acquisition"] = ac.id @@ -312,7 +345,7 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) row_group = root.require_group(row) col_group = row_group.require_group(col) field_group = col_group.require_group(field_name) - add_image(img, field_group) + add_image(img, field_group, metadata_only=args.metadata_only) add_omero_metadata(field_group, img) # Update Well metadata after each image write_well_metadata(col_group, fields) @@ -326,7 +359,7 @@ def plate_to_zarr(plate: omero.gateway._PlateWrapper, args: argparse.Namespace) root, row_names, col_names, - wells=list(well_paths), + wells=well_paths, field_count=max_fields, acquisitions=plate_acq, name=plate.name, diff --git a/src/omero_zarr/util.py b/src/omero_zarr/util.py index b562929..3182565 100644 --- a/src/omero_zarr/util.py +++ b/src/omero_zarr/util.py @@ -18,9 +18,10 @@ import os import time +from collections import defaultdict from typing import Dict, List, Optional -from omero.gateway import BlitzObjectWrapper, ImageWrapper +from omero.gateway import 
def get_map_anns(objs: "List[BlitzObjectWrapper]") -> Dict[int, Dict[str, List[str]]]:
    """
    Returns a map of {obj_id: {key: [value, ...]}} for all MapAnnotations

    objs: wrapped OMERO objects (e.g. Wells). listAnnotations() is called
        once per object and only MapAnnotationWrapper results are used.

    Every object in objs gets an entry; objects without map annotations
    map to an empty dict. Values for a repeated key are kept in the order
    they are encountered.
    """
    map_anns_by_id: Dict[int, Dict[str, List[str]]] = {}
    for obj in objs:
        map_anns = defaultdict(list)
        for ann in obj.listAnnotations():
            if isinstance(ann, MapAnnotationWrapper):
                for key_value in ann.getValue():
                    map_anns[key_value[0]].append(key_value[1])
        # store a plain dict so later lookups can't create keys by accident
        map_anns_by_id[obj.id] = dict(map_anns)
    return map_anns_by_id


def _wildcard_match(pattern: str, text: str) -> bool:
    """
    Returns True if text matches pattern, where every "*" in pattern matches
    any run of characters (including none). All other characters are literal.
    """
    if "*" not in pattern:
        return text == pattern
    head, *middle, tail = pattern.split("*")
    # head and tail are anchored at the ends and must not overlap in text
    if len(text) < len(head) + len(tail):
        return False
    if not (text.startswith(head) and text.endswith(tail)):
        return False
    # remaining fragments must appear, in order, between head and tail
    pos = len(head)
    end = len(text) - len(tail)
    for fragment in middle:
        found = text.find(fragment, pos, end)
        if found < 0:
            return False
        pos = found + len(fragment)
    return True


def map_anns_match(kvps: Dict[str, List[str]], key_value: str) -> bool:
    """
    Returns True if the key_value pair matches any of the MapAnnotation values.

    kvps: a map of {key: [value, ...]}
    key_value: a string in the format "key:value". The key must match
        exactly. The value may contain "*" wildcards anywhere, e.g. "Val*",
        "*val", "*val*", "my*value" or just "*" to match any value.
        Only the first ":" separates key from value, so the value itself
        may contain ":".

    Returns False if key_value has no ":" or the key is not in kvps.
    """
    if ":" not in key_value:
        return False
    key, value = key_value.split(":", 1)
    return any(_wildcard_match(value, v) for v in kvps.get(key, []))
omero.rtypes import rint, rstring from omero.testlib.cli import AbstractCLITest @@ -39,6 +39,25 @@ def setup_method(self, method: str) -> None: self.args = self.login_args() self.cli.register("zarr", ZarrControl, "TEST") self.args += ["zarr"] + self.plate = self.create_plate() + + def create_plate(self) -> PlateWrapper: + plates = self.import_plates( + client=self.client, + plates=1, + plate_acqs=1, + plate_cols=2, + plate_rows=2, + fields=1, + ) + plate_id = plates[0].id.val + print("Plate Created ID:", plate_id) + + conn = BlitzGateway(client_obj=self.client) + plate = conn.getObject("Plate", plate_id) + self.add_polygons_to_plate(plate) + self.add_kvps_to_wells(plate) + return plate def add_shape_to_image(self, shape: PolygonI, image: ImageI) -> None: roi = RoiI() @@ -79,13 +98,52 @@ def add_polygons_to_plate(self, plate: PlateWrapper) -> None: y = 100 + (i * 5) # Rectangles don't overlap self.add_polygon_to_image(image, xywh=[x, y, 40, 40], z=0, t=0) + # for "B1", add overlapping polygon + if wellPos == "B1": + # Add an overlapping polygon to the image in B1 + self.add_polygon_to_image(image, xywh=[20, 100, 40, 40], z=0, t=0) + + def add_kvps_to_wells(self, plate: PlateWrapper) -> None: + """Add key-value pairs of "label:A1" to each well in the plate.""" + for well in plate.listChildren(): + wellPos = well.getWellPos() + map_ann = MapAnnotationWrapper() + vals = [["label", wellPos]] + if "A" in wellPos: + vals.append(["rowA", "True"]) + map_ann.setValue(vals) + well.linkAnnotation(map_ann) + + def check_well(self, well_path: Path, label_count: int) -> None: + label_text = (well_path / "0" / "labels" / "0" / ".zattrs").read_text( + encoding="utf-8" + ) + label_image_json = json.loads(label_text) + assert "multiscales" in label_image_json + assert "image-label" in label_image_json + datasets = label_image_json["multiscales"][0]["datasets"] + for dataset in datasets: + label_path = dataset["path"] + print("label_path", well_path / "0" / "labels" / "0" / 
label_path) + arr_data = da.from_zarr(well_path / "0" / "labels" / "0" / label_path) + print("arr_data", arr_data) + if label_path == "0": + assert arr_data.shape == (512, 512) + max_value = arr_data.max().compute() + print("max_value", max_value) + assert max_value == label_count # export tests # ======================================================================== + @pytest.mark.parametrize("metadata_only", [False, True]) @pytest.mark.parametrize("name_by", ["id", "name"]) def test_export_zarr( - self, capsys: pytest.CaptureFixture, tmp_path: Path, name_by: str + self, + capsys: pytest.CaptureFixture, + tmp_path: Path, + name_by: str, + metadata_only: bool, ) -> None: """Test export of a Zarr image.""" sizec = 2 @@ -99,6 +157,8 @@ def test_export_zarr( "--name_by", name_by, ] + if metadata_only: + exp_args.append("--metadata_only") self.cli.invoke( self.args + exp_args, strict=True, @@ -131,20 +191,21 @@ def test_export_zarr( arr_json = json.loads(arr_text) assert arr_json["shape"] == [sizec, 512, 512] + arr_data = da.from_zarr(tmp_path / zarr_name / "0") + max_value = arr_data.max().compute() + assert metadata_only == (max_value == 0) + + @pytest.mark.parametrize("metadata_only", [False, True]) @pytest.mark.parametrize("name_by", ["id", "name"]) def test_export_plate( - self, capsys: pytest.CaptureFixture, tmp_path: Path, name_by: str + self, + capsys: pytest.CaptureFixture, + tmp_path: Path, + name_by: str, + metadata_only: bool, ) -> None: - plates = self.import_plates( - client=self.client, - plates=1, - plate_acqs=1, - plate_cols=2, - plate_rows=2, - fields=1, - ) - plate_id = plates[0].id.val + plate_id = self.plate.id exp_args = [ "export", f"Plate:{plate_id}", @@ -153,12 +214,13 @@ def test_export_plate( "--name_by", name_by, ] + if metadata_only: + exp_args.append("--metadata_only") self.cli.invoke( self.args + exp_args, strict=True, ) - plate = self.query.get("Plate", plate_id) - plate_name = plate.name.val + plate_name = self.plate.name zarr_name = 
( f"{plate_name}.ome.zarr" if name_by == "name" else f"{plate_id}.ome.zarr" ) @@ -176,12 +238,12 @@ def test_export_plate( assert len(attrs_json["plate"]["wells"]) == 4 assert attrs_json["plate"]["rows"] == [{"name": "A"}, {"name": "B"}] assert attrs_json["plate"]["columns"] == [{"name": "1"}, {"name": "2"}] - - arr_text = (tmp_path / zarr_name / "A" / "1" / "0" / "0" / ".zarray").read_text( - encoding="utf-8" - ) - arr_json = json.loads(arr_text) - assert arr_json["shape"] == [512, 512] + # check first well A1 + arr_data = da.from_zarr(tmp_path / zarr_name / "A" / "1" / "0" / "0") + assert arr_data.shape == (512, 512) + print("arr_data", arr_data) + max_value = arr_data.max().compute() + assert metadata_only == (max_value == 0) @pytest.mark.parametrize("name_by", ["id", "name"]) def test_export_masks( @@ -192,20 +254,21 @@ def test_export_masks( img_id = images[0].id.val size_xy = 512 - # Create a mask + # Create a mask for each channel from skimage.data import binary_blobs - blobs = binary_blobs(length=size_xy, volume_fraction=0.1, n_dim=2).astype( - "int8" - ) red = [255, 0, 0, 255] - mask = mask_from_binary_image(blobs, rgba=red, z=0, c=0, t=0) - - roi = RoiI() - roi.setImage(images[0]) - roi.addShape(mask) - updateService = self.client.sf.getUpdateService() - updateService.saveAndReturnObject(roi) + green = [0, 255, 0, 255] + for ch, color in enumerate([red, green]): + blobs = binary_blobs(length=size_xy, volume_fraction=0.1, n_dim=2).astype( + "int8" + ) + mask = mask_from_binary_image(blobs, rgba=color, z=0, c=ch, t=0) + roi = RoiI() + roi.setImage(images[0]) + roi.addShape(mask) + updateService = self.client.sf.getUpdateService() + updateService.saveAndReturnObject(roi) print("tmp_path", tmp_path) @@ -232,39 +295,43 @@ def test_export_masks( all_lines = ", ".join(lines) assert "Exporting to" in all_lines assert "Finished" in all_lines - assert "Found 1 mask shapes in 1 ROIs" in all_lines + assert "Found 2 mask shapes in 2 ROIs" in all_lines labels_text = 
(tmp_path / zarr_name / "labels" / "0" / ".zattrs").read_text( encoding="utf-8" ) labels_json = json.loads(labels_text) - assert labels_json["image-label"]["colors"] == [{"label-value": 1, "rgba": red}] + assert labels_json["image-label"]["colors"] == [ + {"label-value": 1, "rgba": red}, + {"label-value": 2, "rgba": green}, + ] arr_text = (tmp_path / zarr_name / "labels" / "0" / "0" / ".zarray").read_text( encoding="utf-8" ) arr_json = json.loads(arr_text) - assert arr_json["shape"] == [1, 512, 512] + assert arr_json["shape"] == [2, 512, 512] + + SKIP_WELLS_MAPS = { + "": ["A1", "A2", "B1", "B2"], + "label:B*": ["A1", "A2"], + "label:A1": ["A2", "B1", "B2"], + "label:*2": ["A1", "B1"], + "rowA:*": ["B1", "B2"], + } @pytest.mark.parametrize("name_by", ["id", "name"]) + @pytest.mark.parametrize("skip_wells_map", SKIP_WELLS_MAPS.keys()) def test_export_plate_polygons( - self, capsys: pytest.CaptureFixture, tmp_path: Path, name_by: str + self, + capsys: pytest.CaptureFixture, + tmp_path: Path, + name_by: str, + skip_wells_map: str, ) -> None: - plates = self.import_plates( - client=self.client, - plates=1, - plate_acqs=1, - plate_cols=2, - plate_rows=2, - fields=1, - ) - plate_id = plates[0].id.val - - conn = BlitzGateway(client_obj=self.client) - plate = conn.getObject("Plate", plate_id) - self.add_polygons_to_plate(plate) - + plate = self.plate + plate_id = plate.id print("Plate ID:", plate_id) extra_args = [ f"Plate:{plate_id}", @@ -272,14 +339,18 @@ def test_export_plate_polygons( str(tmp_path), "--name_by", name_by, + "--skip_wells_map", + skip_wells_map, ] self.cli.invoke( self.args + ["export"] + extra_args, strict=True, ) + # Don't fail on "B1" due to overlapping polygons + overlap_args = ["--overlaps", "dtype_max"] self.cli.invoke( - self.args + ["polygons"] + extra_args, + self.args + ["polygons"] + extra_args + overlap_args, strict=True, ) @@ -287,27 +358,67 @@ def test_export_plate_polygons( f"{plate.name}.ome.zarr" if name_by == "name" else 
def test_plate_polygons_overlap(
    self,
    capsys: pytest.CaptureFixture,
    tmp_path: Path,
) -> None:
    """
    Polygon export without --overlaps fails on the overlapping polygons,
    and a second run with --overlaps dtype_max completes the export.
    """
    plate_id = self.plate.id
    print("Plate ID:", plate_id)

    common_args = [f"Plate:{plate_id}", "--output", str(tmp_path)]

    # Export the plate itself so that there is something to add labels to
    self.cli.invoke(self.args + ["export"] + common_args, strict=True)

    # Without --overlaps, this should fail on "B1" (overlapping polygons)
    with pytest.raises(Exception) as exc_info:
        self.cli.invoke(self.args + ["polygons"] + common_args, strict=True)
    assert "overlaps with existing labels" in str(exc_info.value)

    plate_dir = tmp_path / f"{plate_id}.ome.zarr"
    row_b_wells = [
        (plate_dir / "B" / "1", 127),
        (plate_dir / "B" / "2", 4),
    ]

    # Row A wells were handled before the failure, so their labels exist
    self.check_well(plate_dir / "A" / "1", 1)
    self.check_well(plate_dir / "A" / "2", 2)
    # Wells B1 and B2 - no labels written yet
    for well_path, label_count in row_b_wells:
        with pytest.raises(FileNotFoundError):
            self.check_well(well_path, label_count)

    # Re-running with overlaps allowed picks up where the export left off
    self.cli.invoke(
        self.args + ["polygons"] + common_args + ["--overlaps", "dtype_max"],
        strict=True,
    )

    for well_path, label_count in row_b_wells:
        self.check_well(well_path, label_count)