diff --git a/metadata/register.py b/metadata/register.py index a83c385..0fe42f8 100644 --- a/metadata/register.py +++ b/metadata/register.py @@ -20,6 +20,7 @@ # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. from urllib.parse import urlsplit +import numpy as np import zarr import argparse @@ -29,19 +30,21 @@ # from getpass import getpass from omero.cli import cli_login -from omero.gateway import BlitzGateway +from omero.gateway import BlitzGateway, ColorHolder from omero.gateway import OMERO_NUMPY_TYPES import omero +import omero_rois from omero.model.enums import PixelsTypeint8, PixelsTypeuint8, PixelsTypeint16 from omero.model.enums import PixelsTypeuint16, PixelsTypeint32 from omero.model.enums import PixelsTypeuint32, PixelsTypefloat from omero.model.enums import PixelsTypecomplex, PixelsTypedouble -from omero.model import ExternalInfoI +from omero.model import ExternalInfoI, RoiI, MaskI from omero.rtypes import rbool, rdouble, rint, rlong, rstring, rtime + AWS_DEFAULT_ENDPOINT = "s3.us-east-1.amazonaws.com" PIXELS_TYPE = {'int8': PixelsTypeint8, @@ -88,6 +91,125 @@ def load_attrs(store, path=None): return attrs +def masks_from_labels_nd( + labels_nd, axes="tcz", label_props=None): + rois = {} + + colors_by_value = {} + if "colors" in label_props: + for color in label_props["colors"]: + pixel_value = color.get("label-value", None) + rgba = color.get("rgba", None) + if pixel_value and rgba and len(rgba) == 4: + colors_by_value[pixel_value] = rgba + + text_by_value = {} + if "properties" in label_props: + for props in label_props["properties"]: + pixel_value = props.get("label-value", None) + text = props.get("omero:text", None) + if pixel_value and text: + text_by_value[pixel_value] = text + + # For each label value, we create an ROI that + # contains 2D masks for each time point, channel, and z-slice. + for i in range(1, int(labels_nd.max()) + 1): + print("Mask value", i) + if not np.any(labels_nd == i): + continue + + masks = [] + bin_img = labels_nd == i + + sizes = {dim: labels_nd.shape[axes.index(dim)] for dim in axes} + size_t = sizes.get("t", 1) + size_c = sizes.get("c", 1) + size_z = sizes.get("z", 1) + + for t in range(size_t): + for c in range(size_c): + for z in range(size_z): + print("t, c, z", t, c, z) + + indices = [] + if "t" in axes: + indices.append(t) + if "c" in axes: + indices.append(c) + if "z" in axes: + indices.append(z) + + # indices.append(np.s_[::]) + # indices.append(np.s_[x:x_max:]) + + # slice down to 2D plane + plane = bin_img[tuple(indices)] + + if not np.any(plane): + continue + + # plane = plane.compute() + + # Find bounding box to minimise size of mask + xmask = plane.sum(0).nonzero()[0] + ymask = plane.sum(1).nonzero()[0] + print("xmask", xmask, "ymask", ymask) + # if any(xmask) and any(ymask): + x0 = min(xmask) + w = max(xmask) - x0 + 1 + y0 = min(ymask) + h = max(ymask) - y0 + 1 + print("cropping to x, y, w, h", x0, y0, w, h) + submask = plane[y0:(y0 + h), x0:(x0 + w)] + + mask = MaskI() + mask.setBytes(np.packbits(np.asarray(submask, dtype=int))) + mask.setWidth(rdouble(w)) + mask.setHeight(rdouble(h)) + mask.setX(rdouble(x0)) + mask.setY(rdouble(y0)) + + if i in colors_by_value: + ch = ColorHolder.fromRGBA(*colors_by_value[i]) + mask.setFillColor(rint(ch.getInt())) + if "z" in axes: + mask.setTheZ(rint(z)) + if "c" in axes: + mask.setTheC(rint(c)) + if "t" in axes: + mask.setTheT(rint(t)) + if i in text_by_value: + mask.setTextValue(rstring(text_by_value[i])) + + masks.append(mask) + + rois[i] = masks + + return rois + + +def rois_from_labels_nd(conn, img, labels_nd, axes="tcz", label_props=None): + # Text is set on Mask shapes, not ROIs + rois = masks_from_labels_nd(labels_nd, axes, label_props) + + for label, masks in rois.items(): + if len(masks) > 0: + create_roi(conn, img=img, shapes=masks) + + +def create_roi(conn, img, shapes, name=None): + # create an ROI, link it to Image + roi = RoiI() + roi.setImage(omero.model.ImageI(img.id, False)) + if name is not None: + roi.setName(rstring(name)) + for shape in shapes: + roi.addShape(shape) + # Save the ROI (saves any linked shapes too) + print(f"Save ROI for image {img.getName()}") + return conn.getUpdateService().saveAndReturnObject(roi) + + def parse_image_metadata(store, img_attrs, image_path=None): """ Parse the image metadata @@ -113,6 +235,29 @@ def parse_image_metadata(store, img_attrs, image_path=None): return sizes, pixels_type +def create_labels(conn, store, image, labels_path): + + """ + Create labels for the image + """ + label_image = load_attrs(store, labels_path) + + axes = label_image["multiscales"][0]["axes"] + axes_names = [axis["name"] for axis in axes] + label_props = label_image.get("image-label", None) + + ds_path = label_image["multiscales"][0]["datasets"][0]["path"] + array_path = f"{labels_path}/{ds_path}/" + labels_nd = load_array(store, array_path) + labels_data = labels_nd[slice(None)] + print("labels_nd", labels_nd) + print("labels_data", labels_data) + print("axes_names", axes_names) + + # Create ROIs from the labels + rois_from_labels_nd(conn, image, labels_data, axes_names, label_props) + + def create_image(conn, store, image_attrs, object_name, families, models, args, image_path=None): ''' Create an Image/Pixels object @@ -124,7 +269,6 @@ def create_image(conn, store, image_attrs, object_name, families, models, args, size_z = sizes.get("z", 1) size_x = sizes.get("x", 1) size_y = sizes.get("y", 1) - size_c = sizes.get("c", 1) # if channels is None or len(channels) != size_c: channels = list(range(sizes.get("c", 1))) omero_pixels_type = query_service.findByQuery("from PixelsType as p where p.value='%s'" % PIXELS_TYPE[pixels_type], None) @@ -141,7 +285,21 @@ def create_image(conn, store, image_attrs, object_name, families, models, args, img_obj = image._obj set_external_info(img_obj, args, image_path) - + + # check for labels... + if args.labels: + labels_path = "labels/" + if image_path is not None: + labels_path = image_path.rstrip("/") + "/" + labels_path + print("checking for labels at", labels_path) + try: + labels_attrs = load_attrs(store, labels_path) + print("labels_attrs", labels_attrs) + if "labels" in labels_attrs: + for pth in labels_attrs["labels"]: + create_labels(conn, store, image, f"{labels_path}/{pth}/") + except FileNotFoundError: + pass return img_obj, rnd_def def hex_to_rgba(hex_color): @@ -170,6 +328,8 @@ def get_channels(omero_info): def set_channel_names(conn, iid, omero_attrs): channel_names = get_channels(omero_attrs) + if len(channel_names) == 0: + return nameDict = dict((i + 1, name) for i, name in enumerate(channel_names)) conn.setChannelNames("Image", [iid], nameDict) @@ -527,7 +687,8 @@ def main(): parser.add_argument("--nosignrequest", required=False, action='store_true', help="Indicate to sign anonymously") parser.add_argument("--target", required=False, type=str, help="The id of the target (dataset/screen)") parser.add_argument("--target-by-name", required=False, type=str, help="The name of the target (dataset/screen)") - + parser.add_argument("--labels", required=False, action='store_true', help="Also import any OME-Zarr labels found") + args = parser.parse_args() with cli_login() as cli: