Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 166 additions & 5 deletions metadata/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

from urllib.parse import urlsplit
import numpy as np
import zarr

import argparse
Expand All @@ -29,19 +30,21 @@
# from getpass import getpass

from omero.cli import cli_login
from omero.gateway import BlitzGateway
from omero.gateway import BlitzGateway, ColorHolder
from omero.gateway import OMERO_NUMPY_TYPES

import omero
import omero_rois

from omero.model.enums import PixelsTypeint8, PixelsTypeuint8, PixelsTypeint16
from omero.model.enums import PixelsTypeuint16, PixelsTypeint32
from omero.model.enums import PixelsTypeuint32, PixelsTypefloat
from omero.model.enums import PixelsTypecomplex, PixelsTypedouble

from omero.model import ExternalInfoI
from omero.model import ExternalInfoI, RoiI, MaskI
from omero.rtypes import rbool, rdouble, rint, rlong, rstring, rtime


AWS_DEFAULT_ENDPOINT = "s3.us-east-1.amazonaws.com"

PIXELS_TYPE = {'int8': PixelsTypeint8,
Expand Down Expand Up @@ -88,6 +91,125 @@ def load_attrs(store, path=None):
return attrs


def masks_from_labels_nd(
labels_nd, axes="tcz", label_props=None):
rois = {}

colors_by_value = {}
if "colors" in label_props:
for color in label_props["colors"]:
pixel_value = color.get("label-value", None)
rgba = color.get("rgba", None)
if pixel_value and rgba and len(rgba) == 4:
colors_by_value[pixel_value] = rgba

text_by_value = {}
if "properties" in label_props:
for props in label_props["properties"]:
pixel_value = props.get("label-value", None)
text = props.get("omero:text", None)
if pixel_value and text:
text_by_value[pixel_value] = text

# For each label value, we create an ROI that
# contains 2D masks for each time point, channel, and z-slice.
for i in range(1, int(labels_nd.max()) + 1):
print("Mask value", i)
if not np.any(labels_nd == i):
continue

masks = []
bin_img = labels_nd == i

sizes = {dim: labels_nd.shape[axes.index(dim)] for dim in axes}
size_t = sizes.get("t", 1)
size_c = sizes.get("c", 1)
size_z = sizes.get("z", 1)

for t in range(size_t):
for c in range(size_c):
for z in range(size_z):
print("t, c, z", t, c, z)

indices = []
if "t" in axes:
indices.append(t)
if "c" in axes:
indices.append(c)
if "z" in axes:
indices.append(z)

# indices.append(np.s_[::])
# indices.append(np.s_[x:x_max:])

# slice down to 2D plane
plane = bin_img[tuple(indices)]

if not np.any(plane):
continue

# plane = plane.compute()

# Find bounding box to minimise size of mask
xmask = plane.sum(0).nonzero()[0]
ymask = plane.sum(1).nonzero()[0]
print("xmask", xmask, "ymask", ymask)
# if any(xmask) and any(ymask):
x0 = min(xmask)
w = max(xmask) - x0 + 1
y0 = min(ymask)
h = max(ymask) - y0 + 1
print("cropping to x, y, w, h", x0, y0, w, h)
submask = plane[y0:(y0 + h), x0:(x0 + w)]

mask = MaskI()
mask.setBytes(np.packbits(np.asarray(submask, dtype=int)))
mask.setWidth(rdouble(w))
mask.setHeight(rdouble(h))
mask.setX(rdouble(x0))
mask.setY(rdouble(y0))

if i in colors_by_value:
ch = ColorHolder.fromRGBA(*colors_by_value[i])
mask.setFillColor(rint(ch.getInt()))
if "z" in axes:
mask.setTheZ(rint(z))
if "c" in axes:
mask.setTheC(rint(c))
if "t" in axes:
mask.setTheT(rint(t))
if i in text_by_value:
mask.setTextValue(rstring(text_by_value[i]))

masks.append(mask)

rois[i] = masks

return rois


def rois_from_labels_nd(conn, img, labels_nd, axes="tcz", label_props=None):
# Text is set on Mask shapes, not ROIs
rois = masks_from_labels_nd(labels_nd, axes, label_props)

for label, masks in rois.items():
if len(masks) > 0:
create_roi(conn, img=img, shapes=masks)


def create_roi(conn, img, shapes, name=None):
# create an ROI, link it to Image
roi = RoiI()
roi.setImage(omero.model.ImageI(img.id, False))
if name is not None:
roi.setName(rstring(name))
for shape in shapes:
roi.addShape(shape)
# Save the ROI (saves any linked shapes too)
print(f"Save ROI for image {img.getName()}")
return conn.getUpdateService().saveAndReturnObject(roi)


def parse_image_metadata(store, img_attrs, image_path=None):
"""
Parse the image metadata
Expand All @@ -113,6 +235,29 @@ def parse_image_metadata(store, img_attrs, image_path=None):
return sizes, pixels_type


def create_labels(conn, store, image, labels_path):

"""
Create labels for the image
"""
label_image = load_attrs(store, labels_path)

axes = label_image["multiscales"][0]["axes"]
axes_names = [axis["name"] for axis in axes]
label_props = label_image.get("image-label", None)

ds_path = label_image["multiscales"][0]["datasets"][0]["path"]
array_path = f"{labels_path}/{ds_path}/"
labels_nd = load_array(store, array_path)
labels_data = labels_nd[slice(None)]
print("labels_nd", labels_nd)
print("labels_data", labels_data)
print("axes_names", axes_names)

# Create ROIs from the labels
rois_from_labels_nd(conn, image, labels_data, axes_names, label_props)


def create_image(conn, store, image_attrs, object_name, families, models, args, image_path=None):
'''
Create an Image/Pixels object
Expand All @@ -124,7 +269,6 @@ def create_image(conn, store, image_attrs, object_name, families, models, args,
size_z = sizes.get("z", 1)
size_x = sizes.get("x", 1)
size_y = sizes.get("y", 1)
size_c = sizes.get("c", 1)
# if channels is None or len(channels) != size_c:
channels = list(range(sizes.get("c", 1)))
omero_pixels_type = query_service.findByQuery("from PixelsType as p where p.value='%s'" % PIXELS_TYPE[pixels_type], None)
Expand All @@ -141,7 +285,21 @@ def create_image(conn, store, image_attrs, object_name, families, models, args,

img_obj = image._obj
set_external_info(img_obj, args, image_path)


# check for labels...
if args.labels:
labels_path = "labels/"
if image_path is not None:
labels_path = image_path.rstrip("/") + "/" + labels_path
print("checking for labels at", labels_path)
try:
labels_attrs = load_attrs(store, labels_path)
print("labels_attrs", labels_attrs)
if "labels" in labels_attrs:
for pth in labels_attrs["labels"]:
create_labels(conn, store, image, f"{labels_path}/{pth}/")
except FileNotFoundError:
pass
return img_obj, rnd_def

def hex_to_rgba(hex_color):
Expand Down Expand Up @@ -170,6 +328,8 @@ def get_channels(omero_info):

def set_channel_names(conn, iid, omero_attrs):
channel_names = get_channels(omero_attrs)
if len(channel_names) == 0:
return
nameDict = dict((i + 1, name) for i, name in enumerate(channel_names))
conn.setChannelNames("Image", [iid], nameDict)

Expand Down Expand Up @@ -527,7 +687,8 @@ def main():
parser.add_argument("--nosignrequest", required=False, action='store_true', help="Indicate to sign anonymously")
parser.add_argument("--target", required=False, type=str, help="The id of the target (dataset/screen)")
parser.add_argument("--target-by-name", required=False, type=str, help="The name of the target (dataset/screen)")

parser.add_argument("--labels", required=False, action='store_true', help="Also import any OME-Zarr labels found")

args = parser.parse_args()

with cli_login() as cli:
Expand Down