diff --git a/examples/DynaCLR/annotation/_reader.py b/examples/DynaCLR/annotation/_reader.py new file mode 100644 index 00000000..ce023a13 --- /dev/null +++ b/examples/DynaCLR/annotation/_reader.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import dask.array as da +import numpy as np +import zarr +from iohub.ngff.nodes import ( + MultiScaleMeta, + NGFFNode, + OMEROMeta, + Plate, + Position, + Well, + open_ome_zarr, +) +from pydantic_extra_types.color import Color + +if TYPE_CHECKING: + from _typeshed import StrOrBytesPath + + +def napari_get_reader(path): + """A basic implementation of a Reader contribution. + + Parameters + ---------- + path : str or list of str + Path to file, or list of paths. + + Returns + ------- + function or None + If the path is a recognized format, return a function that accepts the + same path or list of paths, and returns a list of layer data tuples. + """ + if isinstance(path, list): + # reader plugins may be handed single path, or a list of paths. + # if it is a list, it is assumed to be an image stack... + # so we are only going to look at the first file. + path = path[0] + path = Path(path) + # if we know we cannot read the file, we immediately return None. + if not path.is_dir() and path.exists(): + return None + if not ((path / ".zattrs").exists() or (path / "zarr.json").exists()): + return None + # otherwise we return the *function* that can read ``path``. + return reader_function + + +def _get_node(path: StrOrBytesPath): + try: + zgroup = zarr.open_group(path, mode="r") + except Exception as e: + raise RuntimeError(e) + attrs = zgroup.attrs + if "ome" in attrs: + attrs = attrs["ome"] + if "well" in attrs: + first_pos_grp = next(zgroup.groups())[1] + channel_names = Position(first_pos_grp).channel_names + node = Well( + group=zgroup, + parse_meta=True, + channel_names=channel_names, + version="0.4", + ) + elif "plate" in attrs or "multiscales" in attrs: + zgroup.store.close() + node = open_ome_zarr(store_path=path, mode="r") + else: + raise KeyError(f"NGFF plate or well metadata not found under '{zgroup.name}'") + return node + + +def stitch_well_by_channel(well: Well, row_wrap: int): + logging.debug(f"Stitching well: {well.zgroup.name}") + levels = [] + pyramids: list[list] = [] + for i, (_, pos) in enumerate(well.positions()): + lv, ims = _get_multiscales(pos) + levels.append(lv) + pyramids.append(ims) + if i == 0: + layers_kwargs = _ome_to_napari_by_channel(pos.metadata) + stitched_arrays = [] + for i in range(max(levels)): + ims = [p[i] for p in pyramids if i < len(p)] + grid = _make_grid(ims, cols=row_wrap) + stitched_arrays.append(da.block(grid)) + return layers_kwargs, _find_ch_axis(well), stitched_arrays + + +def stack_well_by_position(well: Well): + logging.debug(f"Stacking well: {well.zgroup.name}") + levels = [] + pyramids: list[list] = [] + for i, (_, pos) in enumerate(well.positions()): + lv, ims = _get_multiscales(pos) + levels.append(lv) + pyramids.append(ims) + if i == 0: + layers_kwargs = _ome_to_napari_by_channel(pos.metadata) + stacked_arrays = [] + for i in range(max(levels)): + ims = [p[i] for p in pyramids if i < len(p)] + stacked_arrays.append(da.stack(ims, axis=0)) + return layers_kwargs, _find_ch_axis(well), stacked_arrays + + +def _get_multiscales(pos: Position): + ms: MultiScaleMeta = pos.metadata.multiscales[0] + images = [dataset.path for dataset in ms.datasets] + multiscales = [] + for im in images: + try: + 
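+            # Arrays listed in the multiscales metadata may be missing or
+            # unreadable on disk; skip those and keep the remaining levels.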
multiscales.append(pos[im].dask_array()) + except Exception as e: + logging.warning(f"Skipped array '{im}' at position {pos.zgroup.name}: {e}") + return len(multiscales), multiscales + + +def _make_grid(elements: list[da.Array], cols: int): + ct = len(elements) + rows = ct // cols + int(bool(ct % cols)) + grid = [elements[r * cols : (r + 1) * cols] for r in range(rows)] + diff = len(grid[0]) - len(grid[-1]) + if diff > 0: + fill_shape = grid[0][0].shape + fill_type = grid[0][0].dtype + grid[-1].extend([da.zeros(fill_shape, fill_type)] * diff) + return grid + + +def _ome_to_napari_by_channel( + metadata, parse_colormap: bool = True, num_channels: int | None = None +): + if metadata.omero is None: + if num_channels is None: + raise ValueError("num_channels must be set when omero is not present") + return [{"name": f"{i}", "blending": "additive"} for i in range(num_channels)] + omero: OMEROMeta = metadata.omero + layers_kwargs = [] + for channel in omero.channels: + meta = {"name": channel.label} + if channel.color and parse_colormap: + # alpha channel is optional + rgb = Color(channel.color).as_rgb_tuple(alpha=None) + start = [0.0] * 3 + if len(rgb) == 4: + start += [1] + meta["colormap"] = np.array( + [ + start, + [v / np.iinfo(np.uint8).max for v in rgb], + ] + ) + meta["blending"] = "additive" + layers_kwargs.append(meta) + return layers_kwargs + + +def _find_ch_axis(dataset: NGFFNode) -> int | None: + for i, axis in enumerate(dataset.axes): + if axis.type == "channel": + return i + + +def layers_from_arrays( + layers_kwargs: list, + ch_axis: int, + arrays: list, + mode: Literal["stitch", "stack"], + layer_type="image", +): + if mode == "stack": + ch_axis += 1 + elif mode != "stitch": + raise ValueError(f"Unknown mode '{mode}'") + if ch_axis is not None: + if ch_axis == 0: + pre_idx = [] + else: + pre_idx = [slice(None)] * ch_axis + layers = [] + for i, kwargs in enumerate(layers_kwargs): + if ch_axis is not None: + slc = tuple(pre_idx + [slice(i, i + 1)]) + data = [da.squeeze(arr[slc], axis=ch_axis) for arr in arrays] + else: + if i > 0: + raise RuntimeError("mismatched number of channels") + data = [arr for arr in arrays] + layer = (data, kwargs, layer_type) + layers.append(layer) + return layers + + +def fov_to_layers(fov: Position, layer_type: str = "image"): + ch_axis = _find_ch_axis(fov) + arrays = [arr.dask_array() for _, arr in fov.images()] + layers_kwargs = _ome_to_napari_by_channel( + fov.metadata, + parse_colormap=(layer_type == "image"), + num_channels=arrays[0].shape[ch_axis] if ch_axis is not None else 1, + ) + return layers_from_arrays( + layers_kwargs, ch_axis, arrays, mode="stitch", layer_type=layer_type + ) + + +def well_to_layers(well: Well, mode: Literal["stitch", "stack"], layer_type: str): + if mode == "stitch": + layers_kwargs, ch_axis, arrays = stitch_well_by_channel(well, row_wrap=4) + elif mode == "stack": + layers_kwargs, ch_axis, arrays = stack_well_by_position(well) + return layers_from_arrays( + layers_kwargs, ch_axis, arrays, mode=mode, layer_type=layer_type + ) + + +def make_bbox(bbox_extents): + """Copied from: + https://napari.org/stable/tutorials/segmentation/annotate_segmentation.html + Get the coordinates of the corners of a + bounding box from the extents + + Parameters + ---------- + bbox_extents : list (4xN) + List of the extents of the bounding boxes for each of the N regions. + Should be ordered: [min_row, min_column, max_row, max_column] + + Returns + ------- + bbox_rect : np.ndarray + The corners of the bounding box. 
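+        Can be input directly into a napari Shapes layer.
+
+    Examples
+    --------
+    One box spanning rows 0-4 and columns 0-6 (illustrative values):
+
+    >>> make_bbox([[0.0], [0.0], [4.0], [6.0]]).shape
+    (1, 4, 2)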
+    """
+    minr = bbox_extents[0]
+    minc = bbox_extents[1]
+    maxr = bbox_extents[2]
+    maxc = bbox_extents[3]
+
+    bbox_rect = np.array([[minr, minc], [maxr, minc], [maxr, maxc], [minr, maxc]])
+    bbox_rect = np.moveaxis(bbox_rect, 2, 0)
+
+    return bbox_rect
+
+
+def plate_to_layers(
+    plate: Plate,
+    row_range: tuple[int, int] | None = None,
+    col_range: tuple[int, int] | None = None,
+):
+    plate_arrays = []
+    rows = plate.metadata.rows
+    if row_range:
+        rows = rows[row_range[0] : row_range[1]]
+    columns = plate.metadata.columns
+    if col_range:
+        columns = columns[col_range[0] : col_range[1]]
+    boxes = [[] for _ in range(4)]
+    properties = {"fov": []}
+    well_paths = [w.path for w in plate.metadata.wells]
+    for i, row_meta in enumerate(rows):
+        row_name = row_meta.name
+        row_arrays = []
+        for j, col_meta in enumerate(columns):
+            col_name = col_meta.name
+            well_path = f"{row_name}/{col_name}"
+            if well_path in well_paths:
+                well = plate[row_name][col_name]
+                layers_kwargs, ch_axis, arrays = stack_well_by_position(well)
+                row_arrays.append([a[0] for a in arrays])
+                height, width = arrays[0][0].shape[-2:]
+                box_extents = [
+                    height * i,
+                    width * j,
+                    height * (i + 1),
+                    width * (j + 1),
+                ]
+                for k in range(len(boxes)):
+                    boxes[k].append(box_extents[k] - 0.5)
+                properties["fov"].append(well_path + "/" + next(well.positions())[0])
+            else:
+                row_arrays.append(None)
+        plate_arrays.append(row_arrays)
+    # Use the first non-empty well anywhere in the plate as the fill template
+    first_blocks = next(a for row in plate_arrays for a in row if a is not None)
+    fill_args = [(b.shape, b.dtype) for b in first_blocks]
+    plate_levels = []
+    for level, first_block in enumerate(first_blocks):
+        plate_level = []
+        for r in plate_arrays:
+            row_level = []
+            for c in r:
+                if c is None:
+                    arr = da.zeros(
+                        shape=fill_args[level][0],
+                        dtype=fill_args[level][1],
+                        chunks=first_block.chunksize,
+                    )
+                else:
+                    arr = c[level]
+                row_level.append(arr)
+            plate_level.append(row_level)
+        plate_levels.append(da.block(plate_level))
+    layers = layers_from_arrays(
+        layers_kwargs,
+        ch_axis,
+        plate_levels,
+        mode="stitch",
+        layer_type="image",
+    )
+    layers.append(
+        (
+            make_bbox(boxes),
+            {
+                "face_color": "transparent",
+                "edge_color": "black",
+                "properties": properties,
+                "text": {"string": "fov", "color": "orange"},
+                "name": "Plate Map",
+            },
+            "shapes",
+        )
+    )
+    return layers
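+
+
+# How napari consumes this module (a sketch; "plate.zarr" is a placeholder
+# path): napari calls ``napari_get_reader`` first and, if it returns a
+# callable, invokes that callable to obtain the layer data tuples.
+#
+#   reader = napari_get_reader("plate.zarr")
+#   if reader is not None:
+#       layers = reader("plate.zarr")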
+
+
+def reader_function(path):
    """Take a path or list of paths and return a list of LayerData tuples.
+
+    Readers are expected to return data as a list of tuples, where each tuple
+    is (data, [add_kwargs, [layer_type]]), and "add_kwargs" and "layer_type"
+    are both optional.
+
+    Parameters
+    ----------
+    path : str or list of str
+        Path to file, or list of paths.
+
+    Returns
+    -------
+    layer_data : list of tuples
+        A list of LayerData tuples where each tuple in the list contains
+        (data, metadata, layer_type), where data is a numpy array, metadata is
+        a dict of keyword arguments for the corresponding viewer.add_* method
+        in napari, and layer_type is a lower-case string naming the type of
+        layer. Both "metadata" and "layer_type" are optional; napari will
+        default to layer_type=="image" if they are not provided.
+    """
+    node = _get_node(path)
+    match node:
+        case Plate():
+            return plate_to_layers(node)
+        case Well():
+            return well_to_layers(node, mode="stitch", layer_type="image")
+        case Position():
+            return fov_to_layers(node)
diff --git a/examples/DynaCLR/annotation/napari_annotate.py b/examples/DynaCLR/annotation/napari_annotate.py
new file mode 100644
index 00000000..1868a858
--- /dev/null
+++ b/examples/DynaCLR/annotation/napari_annotate.py
@@ -0,0 +1,574 @@
+# %%
+import logging
+from pathlib import Path
+
+import click
+import napari
+import numpy as np
+import pandas as pd
+from _reader import fov_to_layers
+from iohub import open_ome_zarr
+from napari.types import LayerDataTuple
+
+from viscy.data.triplet import INDEX_COLUMNS
+
+_logger = logging.getLogger("viscy")
+_logger.setLevel(logging.INFO)
+console_handler = logging.StreamHandler()
+console_handler.setLevel(logging.INFO)
+_logger.addHandler(console_handler)
+
+
+def _ultrack_inv_tracks_df_forest(df: pd.DataFrame, no_parent=-1) -> dict[int, int]:
+    """
+    Vendored from ultrack.tracks.graph.inv_tracks_df_forest.
+    """
+    for col in ["track_id", "parent_track_id"]:
+        if col not in df.columns:
+            raise ValueError(
+                f"The input dataframe does not contain the column '{col}'."
+            )
+
+    df = df.drop_duplicates("track_id")
+    df = df[df["parent_track_id"] != no_parent]
+    graph = {}
+    for parent_id, id in zip(df["parent_track_id"], df["track_id"]):
+        graph[id] = parent_id
+    return graph
+
+
+def _ultrack_read_csv(path: Path | str) -> LayerDataTuple:
+    """
+    Vendored from ultrack.reader.napari_reader.read_csv.
+    """
+    if isinstance(path, str):
+        path = Path(path)
+
+    df = pd.read_csv(path)
+
+    _logger.info(f"Read {len(df)} tracks from {path}")
+    _logger.info(df.head())
+
+    # For the napari tracks layer, only use position columns: [track_id, t, z, y, x]
+    tracks_cols = [
+        "track_id",
+        "t",
+        "z",
+        "y",
+        "x",
+    ]
+    if "z" not in df.columns:
+        tracks_cols.remove("z")
+
+    if "parent_track_id" in df.columns:
+        graph = _ultrack_inv_tracks_df_forest(df)
+        _logger.info(f"Track lineage graph with length {len(graph)}")
+    else:
+        graph = None
+
+    kwargs = {
+        "features": df,  # Full dataframe with all columns is stored in features
+        "name": path.name.removesuffix(".csv"),
+        "graph": graph,
+    }
+
+    return (df[tracks_cols], kwargs, "tracks")
+
+
+# %%
+def open_image_and_tracks(
+    images_dataset: Path,
+    tracks_dataset: Path,
+    fov_name: str,
+    expand_z_for_tracking_labels: bool = True,
+    load_tracks_layer: bool = True,
+    tracks_z_index: int = -1,
+) -> list[napari.types.LayerDataTuple]:
+    """
+    Load images and tracking labels.
+    Also load predicted features (if supplied)
+    and associate them with the tracking labels.
+    To be used with the napari-clusters-plotter plugin.
+
+    Parameters
+    ----------
+    images_dataset : pathlib.Path
+        Path to the images dataset (HCS OME-Zarr).
+    tracks_dataset : pathlib.Path
+        Path to the tracking labels dataset (HCS OME-Zarr).
+        Potentially with a singleton Z dimension.
+    fov_name : str
+        Name of the FOV to load, e.g. `"A/12/2"`.
+    expand_z_for_tracking_labels : bool
+        Whether to expand the tracking labels to the Z dimension of the images.
+    load_tracks_layer : bool
+        Whether to load the tracks layer.
+    tracks_z_index : int
+        Index of the Z slice to place the 2D tracks, by default -1 (middle slice).
+
+    Returns
+    -------
+    List[napari.types.LayerDataTuple]
+        List of layers to add to the viewer
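+        (image layers and one labels layer).
+
+    Examples
+    --------
+    A typical call (paths and FOV name are placeholders):
+
+    >>> layers = open_image_and_tracks(
+    ...     Path("images.zarr"), Path("tracks.zarr"), "A/12/2"
+    ... )
+    >>> viewer = napari.Viewer()
+    >>> for data, kwargs, layer_type in layers:
+    ...     getattr(viewer, f"add_{layer_type}")(data, **kwargs)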
+    """
+    _logger.info(f"Loading images from {images_dataset}")
+    image_plate = open_ome_zarr(images_dataset)
+    image_fov = image_plate[fov_name]
+    image_layers = fov_to_layers(image_fov)
+    _logger.info(f"Loading tracking labels from {tracks_dataset}")
+    tracks_plate = open_ome_zarr(tracks_dataset)
+    tracks_fov = tracks_plate[fov_name]
+    labels_layer = fov_to_layers(tracks_fov, layer_type="labels")[0]
+    # TODO: remove this after https://github.com/napari/napari/issues/7327 is fixed
+    labels_layer[0][0] = labels_layer[0][0].astype("uint32")
+    image_z = image_fov["0"].slices
+    if expand_z_for_tracking_labels:
+        _logger.info(f"Expanding tracks to Z={image_z}")
+        labels_layer[0][0] = labels_layer[0][0].repeat(image_z, axis=1)
+    image_layers.append(labels_layer)
+    tracks_csv = next((tracks_dataset / fov_name.strip("/")).glob("*.csv"))
+    if load_tracks_layer:
+        _logger.info(f"Loading tracks from {str(tracks_csv)} with ultrack")
+        tracks_layer = _ultrack_read_csv(tracks_csv)
+        if tracks_z_index is not None:
+            # -1 selects the middle slice (see docstring);
+            # other values are used as given
+            if tracks_z_index == -1:
+                tracks_z_index = image_z // 2
+            _logger.info(f"Placing tracks at Z={tracks_z_index}")
+            tracks_layer[0].insert(loc=2, column="z", value=tracks_z_index)
+        image_layers.append(tracks_layer)
+    _logger.info(f"Finished loading {len(image_layers)} layers")
+    _logger.debug(f"Layers: {image_layers}")
+    return image_layers
+
+
+def setup_annotation_layers(viewer: napari.Viewer) -> None:
+    """
+    Create four annotation points layers (one per event type).
+    All layers default to 'add' mode for easy annotation.
+
+    Parameters
+    ----------
+    viewer : napari.Viewer
+        The napari viewer instance to add layers to.
+    """
+    # One points layer per event type: cell division (mitosis), infection,
+    # organelle remodeling, and cell death.
+    for name, color in [
+        ("_mitosis_events", "blue"),
+        ("_infected_events", "orange"),
+        ("_remodel_events", "purple"),
+        ("_death_events", "red"),
+    ]:
+        layer = viewer.add_points(
+            ndim=4,
+            size=20,
+            face_color=color,
+            name=name,
+        )
+        layer.mode = "add"
+
+
+def save_annotations(
+    viewer: napari.Viewer,
+    output_path: Path,
+    fov_name: str,
+    tracks_zarr,
+    tracks_csv_path: Path,
+    diameter: int = 10,
+) -> None:
+    """
+    Save napari point annotations to an ultrack-style CSV.
+    Expands annotations to all timepoints based on binary logic.
+
+    Parameters
+    ----------
+    viewer : napari.Viewer
+        The napari viewer instance.
+    output_path : Path
+        Path to save the annotations CSV.
+    fov_name : str
+        FOV name for the fov_name column.
+    tracks_zarr : Position
+        Opened OME-Zarr position with segmentation labels.
+    tracks_csv_path : Path
+        Path to tracks CSV file.
+    diameter : int
+        Window diameter for robust label lookup.
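+
+    Examples
+    --------
+    A minimal sketch (assumes an open ``napari.Viewer`` named ``viewer``;
+    paths and FOV name are placeholders):
+
+    >>> plate = open_ome_zarr("tracks.zarr")
+    >>> save_annotations(
+    ...     viewer,
+    ...     output_path=Path("annotations"),
+    ...     fov_name="A/12/2",
+    ...     tracks_zarr=plate["A/12/2"],
+    ...     tracks_csv_path=Path("tracks.zarr/A/12/2/tracks.csv"),
+    ... )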
+ """ + # Load tracks CSV to get all track-timepoint combinations + tracks_df = pd.read_csv(tracks_csv_path) + + # Collect marked events from each layer + marked_events = { + "cell_division_state": [], # mitosis events + "infection_state": [], # infected events + "organelle_state": [], # remodel events + "cell_death_state": [], # death events + } + + # Process the four annotation layers + layer_mapping = [ + ("_mitosis_events", "cell_division_state", "mitosis"), + ("_infected_events", "infection_state", "infected"), + ("_remodel_events", "organelle_state", "remodel"), + ("_death_events", "cell_death_state", "dead"), + ] + + for layer_name, event_type, event_state in layer_mapping: + if layer_name in viewer.layers: + points_layer = viewer.layers[layer_name] + points_data = points_layer.data # Shape: (n_points, 4) for [t, z, y, x] + + for point in points_data: + t, z, y, x = [int(coord) for coord in point] + + # Load segmentation for this timepoint + labels_image = tracks_zarr["0"][t, 0, 0] # (C, Z, Y, X) → take C=0, Z=0 + + # Get label value in window around point + y_slice = slice( + max(0, y - diameter), min(labels_image.shape[0], y + diameter) + ) + x_slice = slice( + max(0, x - diameter), min(labels_image.shape[1], x + diameter) + ) + label_value = int(labels_image[y_slice, x_slice].mean()) + + if label_value > 0: + marked_events[event_type].append({"track_id": label_value, "t": t}) + else: + _logger.warning( + f"Point at t={t}, y={y}, x={x} maps to background (label=0)" + ) + + # Expand annotations to all timepoints based on binary logic + all_annotations = [] + + # Get all track-timepoint combinations + all_track_timepoints = tracks_df[["track_id", "t"]].drop_duplicates() + + # Process each event type + for track_id in all_track_timepoints["track_id"].unique(): + track_timepoints = all_track_timepoints[ + all_track_timepoints["track_id"] == track_id + ]["t"].sort_values() + + # Cell Division: marked timepoints = mitosis, all others = interphase + division_events = [ + e for e in marked_events["cell_division_state"] if e["track_id"] == track_id + ] + mitosis_timepoints = [e["t"] for e in division_events] + + # Infection: first marked timepoint onwards = infected, before = uninfected + infection_events = [ + e for e in marked_events["infection_state"] if e["track_id"] == track_id + ] + first_infected_t = ( + min([e["t"] for e in infection_events]) if infection_events else None + ) + + # Organelle: first marked timepoint onwards = remodel, before = noremodel + organelle_events = [ + e for e in marked_events["organelle_state"] if e["track_id"] == track_id + ] + first_remodel_t = ( + min([e["t"] for e in organelle_events]) if organelle_events else None + ) + + # Cell death: first marked timepoint onwards = dead, before = alive + _death_events = [ + e for e in marked_events["cell_death_state"] if e["track_id"] == track_id + ] + first_death_t = min([e["t"] for e in _death_events]) if _death_events else None + + # Create one row per timepoint with all event states + for t in track_timepoints: + # Check if cell is dead at this timepoint + is_dead = first_death_t is not None and t >= first_death_t + + if is_dead: + # If dead, all other states are None + cell_division_state = None + infection_state = None + organelle_state = None + cell_death_state = "dead" + else: + # If alive, compute states normally + cell_division_state = ( + "mitosis" if t in mitosis_timepoints else "interphase" + ) + infection_state = ( + "infected" + if (first_infected_t is not None and t >= first_infected_t) + else 
"uninfected" + if first_infected_t is not None + else None + ) + # Organelle: always has a value - remodel if marked, otherwise noremodel + organelle_state = ( + "remodel" + if (first_remodel_t is not None and t >= first_remodel_t) + else "noremodel" + ) + cell_death_state = "alive" if first_death_t is not None else None + + all_annotations.append( + { + "track_id": track_id, + "t": t, + "cell_division_state": cell_division_state, + "infection_state": infection_state, + "organelle_state": organelle_state, + "cell_death_state": cell_death_state, + } + ) + + # Save to CSV + if all_annotations: + annotations_df = pd.DataFrame(all_annotations) + + # Merge with original tracks dataframe to preserve all INDEX_COLUMNS + # Add fov_name column first + tracks_df["fov_name"] = fov_name + + # Merge on track_id and t + merged_df = tracks_df.merge(annotations_df, on=["track_id", "t"], how="left") + + # Reorder columns to have fov_name first, followed by INDEX_COLUMNS, then annotation columns + index_cols = [col for col in INDEX_COLUMNS if col in merged_df.columns] + annotation_cols = [ + "cell_division_state", + "infection_state", + "organelle_state", + "cell_death_state", + ] + column_order = index_cols + annotation_cols + merged_df = merged_df[column_order] + + output_path.mkdir(parents=True, exist_ok=True) + csv_path = output_path / f"annotations_{fov_name.replace('/', '_')}.csv" + merged_df.to_csv(csv_path, index=False) + _logger.info(f"Saved {len(merged_df)} annotations to {csv_path}") + else: + _logger.warning("No annotations to save") + + +@click.command() +@click.option( + "--images-dataset", + "-i", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to OME-Zarr dataset with images", +) +@click.option( + "--tracks-dataset", + "-t", + type=click.Path(exists=True, path_type=Path), + required=True, + help="Path to OME-Zarr dataset with tracking labels", +) +@click.option( + "--fov-name", + "-f", + type=str, + required=True, + help="FOV name to annotate (e.g., 'A/1/000000')", +) +@click.option( + "--output-path", + "-o", + type=click.Path(path_type=Path), + default=None, + help="Path folder to save annotations CSV (default: tracks_dataset/fov_name). It will use the fov_name to create the file name.", +) +def main(images_dataset, tracks_dataset, fov_name, output_path): + """ + Interactive napari tool for annotating cell division, infection, remodeling, and death events. 
+
+    Keyboard shortcuts:
+    - a/d: Step backward/forward in time
+    - q/e: Cycle through annotation layers (mitosis → infected → remodel → death)
+    - r: Enable interpolation mode (click start point → press 'r' → click end
+      point to auto-interpolate; for cell division and organelle remodeling only)
+    - s: Save annotations
+
+    Annotation logic:
+    - Mitosis: marked timepoints = mitosis, others = interphase
+    - Infected: first marked timepoint onwards = infected, before = uninfected
+    - Remodel: first marked timepoint onwards = remodel, before = noremodel
+    - Death: first marked timepoint onwards = dead (all other states become None), before = alive
+    """
+    # Load image and track layers
+    _logger.info("Loading images and tracks...")
+    layers = open_image_and_tracks(images_dataset, tracks_dataset, fov_name)
+
+    # Create napari viewer
+    viewer = napari.Viewer()
+
+    # Add all layers to viewer
+    for layer_data, layer_kwargs, layer_type in layers:
+        if layer_type == "image":
+            viewer.add_image(layer_data, **layer_kwargs)
+        elif layer_type == "labels":
+            viewer.add_labels(layer_data, **layer_kwargs)
+        elif layer_type == "tracks":
+            viewer.add_tracks(layer_data, **layer_kwargs)
+
+    # Open tracks zarr for label lookup
+    tracks_plate = open_ome_zarr(tracks_dataset)
+    tracks_fov = tracks_plate[fov_name]
+
+    # Get tracks CSV path
+    tracks_csv_path = next((Path(tracks_dataset) / fov_name.strip("/")).glob("*.csv"))
+
+    # Setup annotation layers
+    _logger.info("Setting up annotation layers...")
+    setup_annotation_layers(viewer)
+
+    # Set default output path if not provided
+    if output_path is None:
+        output_path = Path(tracks_dataset) / fov_name.strip("/")
+
+    # State for interpolation mode
+    interpolation_mode = {"enabled": False, "start_point": None}
+
+    # List of annotation layers for cycling
+    annotation_layers = [
+        "_mitosis_events",
+        "_infected_events",
+        "_remodel_events",
+        "_death_events",
+    ]
+    current_layer_index = {"index": 0}
+
+    # Mouse callback: remembers the last clicked point and, when interpolation
+    # mode is enabled, fills in the intermediate timepoints.
+    def interpolate_points(layer, event):
+        if (
+            interpolation_mode["enabled"]
+            and interpolation_mode["start_point"] is not None
+        ):
+            # Get click position for end point
+            end_coords = np.array(layer.world_to_data(event.position))
+            start_coords = interpolation_mode["start_point"]
+
+            t1, t2 = int(start_coords[0]), int(end_coords[0])
+            if t1 > t2:
+                t1, t2 = t2, t1
+                start_coords, end_coords = end_coords, start_coords
+
+            # Add all intermediate timepoints (skip endpoints as they're already added)
+            for t in range(t1 + 1, t2):
+                alpha = (t - t1) / (t2 - t1)
+                interpolated = start_coords + alpha * (end_coords - start_coords)
+                interpolated[0] = t  # Set exact timepoint
+                layer.add(interpolated)
+
+            _logger.info(f"Interpolated {t2 - t1 - 1} points between t={t1} and t={t2}")
+
+            # Reset interpolation mode
+            interpolation_mode["enabled"] = False
+            interpolation_mode["start_point"] = None
+        else:
+            # Track last added point for potential interpolation
+            coords = np.array(layer.world_to_data(event.position))
+            interpolation_mode["start_point"] = coords
+
+    # Connect the callback to each annotation layer and add custom keybindings
+    for layer_name in annotation_layers:
+        layer = viewer.layers[layer_name]
+        layer.mouse_drag_callbacks.append(interpolate_points)
+
+        # Bind shortcuts directly to each layer so they work when the layer is active
+        @layer.bind_key("a")
+        def layer_step_backward(layer):
+            current_step = viewer.dims.current_step
+            if current_step[0] > 0:
+                viewer.dims.current_step = (current_step[0] - 1, *current_step[1:])
+                _logger.info(f"Time: {viewer.dims.current_step[0]}")
+
+        @layer.bind_key("d")
+        def layer_step_forward(layer):
+            current_step = viewer.dims.current_step
+            max_step = viewer.dims.range[0][1] - 1
+            if current_step[0] < max_step:
+                viewer.dims.current_step = (current_step[0] + 1, *current_step[1:])
+                _logger.info(f"Time: {viewer.dims.current_step[0]}")
+
+        @layer.bind_key("s")
+        def layer_save(layer):
+            _logger.info("Saving annotations...")
+            save_annotations(viewer, output_path, fov_name, tracks_fov, tracks_csv_path)
+
+        @layer.bind_key("q")
+        def layer_cycle_backward(layer):
+            current_layer_index["index"] = (current_layer_index["index"] - 1) % len(
+                annotation_layers
+            )
+            new_layer_name = annotation_layers[current_layer_index["index"]]
+            viewer.layers.selection.active = viewer.layers[new_layer_name]
+            _logger.info(f"Switched to {new_layer_name}")
+
+        @layer.bind_key("e")
+        def layer_cycle_forward(layer):
+            current_layer_index["index"] = (current_layer_index["index"] + 1) % len(
+                annotation_layers
+            )
+            new_layer_name = annotation_layers[current_layer_index["index"]]
+            viewer.layers.selection.active = viewer.layers[new_layer_name]
+            _logger.info(f"Switched to {new_layer_name}")
+
+        @layer.bind_key("r")
+        def layer_toggle_interpolation(layer):
+            if interpolation_mode["start_point"] is not None:
+                interpolation_mode["enabled"] = True
+                start_t = int(interpolation_mode["start_point"][0])
+                _logger.info(
+                    f"Interpolation mode ENABLED - click end point to interpolate from t={start_t}"
+                )
+            else:
+                _logger.info(
+                    "No start point - add a point first, then press 'r' to enable interpolation"
+                )
+
+    # Set initial active layer
+    viewer.layers.selection.active = viewer.layers["_mitosis_events"]
+
+    _logger.info("Viewer ready! Annotation layers in 'add' mode by default")
+    _logger.info(" Navigation: a/d = step backward/forward in time")
+    _logger.info(" Layers: q/e = cycle through annotation layers")
+    _logger.info(" Interpolation: click start point → press 'r' → click end point")
+    _logger.info(" Save: s = save annotations")
+    _logger.info(" Annotation layers: mitosis → infected → remodel → death")
+
+    # Run napari
+    napari.run()
+
+
+if __name__ == "__main__":
+    main()
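+
+# The saved file is named after the FOV (see ``save_annotations``), e.g.
+# ``annotations_A_1_000000.csv`` for FOV "A/1/000000", and can be read back
+# with pandas (a sketch):
+#
+#   import pandas as pd
+#   df = pd.read_csv("annotations_A_1_000000.csv")
+#   mitotic = df[df["cell_division_state"] == "mitosis"]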