From 11f6c21b7b02a458ddbeaf62fbc6926afb733a3f Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 2 Oct 2025 15:30:23 -0700 Subject: [PATCH 01/13] initial example dynaclr tracking --- .../cell_tracking/cell_tracking_ctc.py | 297 ++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 applications/cell_tracking/cell_tracking_ctc.py diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py new file mode 100644 index 000000000..6170d0c94 --- /dev/null +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -0,0 +1,297 @@ +#!/usr/bin/env -S uv run --script +# +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "tracksdata", +# "onnxruntime-gpu", +# "napari[pyqt5]", +# "gurobipy", +# "py-ctcmetrics", +# ] +# /// + +from pathlib import Path +import polars as pl +import numpy as np +import napari +import onnxruntime as ort + +from dask.array.image import imread +from numpy.typing import NDArray +from toolz import curry +from rich import print +from scipy.ndimage import zoom + +import tracksdata as td + + +def _seg_dir(dataset_dir: Path, dataset_num: str) -> Path: + return dataset_dir / f"{dataset_num}_ERR_SEG" + + +def _pad(image: NDArray, shape: tuple[int, int]) -> NDArray: + """ + Pad the image to the given shape. + """ + diff = np.asarray(shape) - np.asarray(image.shape) + + if diff.sum() == 0: + return image + + left = diff // 2 + right = diff - left + + return np.pad(image, ((left[0], right[0]), (left[1], right[1])), mode="reflect") + + +@curry +def _crop_embedding( + frame: NDArray, + mask: list[td.nodes.Mask], + shape: tuple[int, int], + session: ort.InferenceSession, + input_name: str, + upscale: float = 1.0, +) -> NDArray[np.float32]: + """ + Crop the frame and compute the DynaCLR embedding. + + Parameters + ---------- + frame : NDArray + The frame to crop. + mask : Mask + The mask to crop the frame. + shape : tuple[int, int] + The shape of the crop. + session : ort.InferenceSession + The session to use for the embedding. + input_name : str + The name of the input tensor. + upscale : float, optional + The upscale factor to apply to the crop. + + Returns + ------- + NDArray[np.float32] + The embedding of the crop. + """ + crops = [] + for m in mask: + crop = m.crop(frame, shape=shape).astype(np.float32) + crop = _pad(crop, shape) + # mean, std = crop.mean(), crop.std() + mean, std = np.median(crop), np.quantile(crop, 0.75) - np.quantile(crop, 0.25) + crop = (crop - mean) / (std + 1e-8) + if upscale != 1.0: + crop = zoom(crop, upscale, order=1, mode="nearest") + + crops.append(crop) + + # expanding batch, channel, and z dimensions + crops = np.stack(crops, axis=0) + crops = crops[:, np.newaxis, np.newaxis, ...] + output = session.run(None, {input_name: crops}) + + # embedding = output[-1] # projected 32-dimensional embedding + embedding = output[0] # 768-dimensional embedding + embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True) + + return [e for e in embedding] + + +def _add_dynaclr_attrs( + model_path: Path, + graph: td.graph.InMemoryGraph, + images: NDArray, +) -> None: + """ + Add DynaCLR embedding attributes to each node in the graph + and compute the cosine similarity for existing edges. + + Parameters + ---------- + graph : td.graph.InMemoryGraph + The graph to add the attributes to. + images : NDArray + The images to use for the embedding. + """ + + session = ort.InferenceSession(model_path) + + input_name = session.get_inputs()[0].name + input_dim = session.get_inputs()[0].shape + input_type = session.get_inputs()[0].type + + print(f"Model input name: '{input_name}'") + print(f"Expected input dimensions: {input_dim}") + print(f"Expected input type: {input_type}") + + crop_attr_func = _crop_embedding( + shape=(64, 64), + session=session, + input_name=input_name, + upscale=1.0, + ) + + print("Adding DynaCLR embedding attributes ...") + td.nodes.GenericFuncNodeAttrs( + func=crop_attr_func, + output_key="dynaclr_embedding", + attr_keys=["mask"], + batch_size=128, + ).add_node_attrs(graph, frames=images) + + print("Adding cosine similarity attributes ...") + td.edges.GenericFuncEdgeAttrs( + func=np.dot, + output_key="dynaclr_similarity", + attr_keys="dynaclr_embedding", + ).add_edge_attrs(graph) + + +def track_single_dataset( + dataset_dir: Path, + dataset_num: str, + show_napari_viewer: bool, + dynaclr_model_path: Path | None, +) -> None: + """ + Main function to track cells in a dataset. + + Parameters + ---------- + dataset_dir : Path + Path to the dataset directory. + dataset_num : str + Number of the dataset. + show_napari_viewer : bool + Whether to show the napari viewer. + dynaclr_model_path : Path | None + Path to the DynaCLR model. If None, the model will not be used. + """ + assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." + + print(f"Loading labels from {dataset_dir} ...") + labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) + images = imread(str(dataset_dir / dataset_num / "*.tif")) + + gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{dataset_num}_GT" / "TRA") + + print("Starting tracking ...") + graph = td.graph.InMemoryGraph() + + nodes_operator = td.nodes.RegionPropsNodes() + nodes_operator.add_nodes(graph, labels=labels) + print(f"Number of nodes: {graph.num_nodes}") + + dist_operator = td.edges.DistanceEdges( + distance_threshold=325.0, # 50, + n_neighbors=10, + delta_t=5, # 30, + ) + dist_operator.add_edges(graph) + print(f"Number of edges: {graph.num_edges}") + + td.edges.GenericFuncEdgeAttrs( + func=lambda x, y: abs(x - y), + output_key="delta_t", + attr_keys="t", + ).add_edge_attrs(graph) + + dist_weight = (-td.EdgeAttr(td.DEFAULT_ATTR_KEYS.EDGE_DIST) / dist_operator.distance_threshold).exp() + + if dynaclr_model_path is not None: + _add_dynaclr_attrs(dynaclr_model_path, graph, images) + # decrease dynaclr similarity given the distance? + edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight + + else: + iou_operator = td.edges.IoUEdgeAttr(output_key="iou") + iou_operator.add_edge_attrs(graph) + + edge_weight = -(td.EdgeAttr("iou") + 0.1) * dist_weight + + edge_weight = edge_weight / td.EdgeAttr("delta_t").clip(lower_bound=1) + + solver = td.solvers.ILPSolver( + edge_weight=edge_weight, + appearance_weight=0, + disappearance_weight=0, + division_weight=0.5, + node_weight=-10, # we assume all segmentations are correct + ) + + solution_graph = solver.solve(graph) + + print("Evaluating results ...") + metrics = td.metrics.evaluate_ctc_metrics( + solution_graph, + gt_graph, + input_reset=False, + reference_reset=False, + ) + + if show_napari_viewer: + print("Converting to napari format ...") + labels, tracks_df, track_graph = td.functional.to_napari_format(graph, labels.shape) + + print("Opening napari viewer ...") + viewer = napari.Viewer() + viewer.add_image(images) + viewer.add_labels(labels) + viewer.add_tracks(tracks_df, graph=track_graph) + napari.run() + + return metrics + + +if __name__ == "__main__": + + model_paths = [ + None, + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), + ] + + results = [] + + dataset_root = Path("/hpc/reference/group.royer/CTC/training/") + + for model_path in model_paths: + for dataset_dir in list(dataset_root.iterdir()): + + # executing only 2D datasets + if "2D" not in dataset_dir.name: + print(f"Skipping {dataset_dir.name} because it is not a 2D dataset") + continue + + for dataset_num in ["01", "02"]: + # processing only datasets with segmentation (linking challenge) + seg_dir = _seg_dir(dataset_dir, dataset_num) + if not seg_dir.exists(): + print(f"Skipping {dataset_dir.name} because it does not have segmentation") + continue + + metrics = track_single_dataset( + dataset_dir=dataset_dir, + dataset_num=dataset_num, + show_napari_viewer=False, + dynaclr_model_path=model_path, + ) + metrics["model"] = "None" if model_path is None else model_path.stem + metrics["dataset"] = dataset_dir.name + metrics["dataset_num"] = dataset_num + print(metrics) + results.append(metrics) + + df = pl.DataFrame(results) + df.write_csv("results.csv") + + print( + df.group_by("model", "dataset").mean().select( + "model", "dataset", "LNK", "BIO(0)", "OP_CLB(0)", "CHOTA" + ) + ) From a09048c186384d2b06c188acffb9eefa2313b7d5 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 10 Oct 2025 09:26:10 -0700 Subject: [PATCH 02/13] adding rescaling and 3D inference --- .../cell_tracking/cell_tracking_ctc.py | 78 ++++++++++++++----- 1 file changed, 60 insertions(+), 18 deletions(-) mode change 100644 => 100755 applications/cell_tracking/cell_tracking_ctc.py diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py old mode 100644 new mode 100755 index 6170d0c94..6924eac2e --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -12,6 +12,7 @@ # /// from pathlib import Path +import yaml import polars as pl import numpy as np import napari @@ -49,7 +50,7 @@ def _pad(image: NDArray, shape: tuple[int, int]) -> NDArray: def _crop_embedding( frame: NDArray, mask: list[td.nodes.Mask], - shape: tuple[int, int], + final_shape: tuple[int, int], session: ort.InferenceSession, input_name: str, upscale: float = 1.0, @@ -79,13 +80,27 @@ def _crop_embedding( """ crops = [] for m in mask: - crop = m.crop(frame, shape=shape).astype(np.float32) - crop = _pad(crop, shape) - # mean, std = crop.mean(), crop.std() - mean, std = np.median(crop), np.quantile(crop, 0.75) - np.quantile(crop, 0.25) - crop = (crop - mean) / (std + 1e-8) - if upscale != 1.0: + crop_shape = tuple(np.round(np.asarray(final_shape) / upscale).astype(int)) + + if frame.ndim == 3: + crop_shape = (1, *crop_shape) + upscale = (1, *upscale) + + crop = m.crop(frame, shape=crop_shape).astype(np.float32) + crop = _pad(crop, crop_shape) + + mu, sigma = np.median(crop), np.quantile(crop, 0.75) - np.quantile(crop, 0.25) + crop = (crop - mu) / (sigma + 1e-8) + + if np.all(upscale != 1.0): crop = zoom(crop, upscale, order=1, mode="nearest") + + if crop.shape[-2:] != final_shape: + raise ValueError(f"Crop shape {crop.shape} does not match final shape {final_shape}") + + if crop.ndim == 3: + assert crop.shape[0] == 1, f"Expected 1 z-slice in 3D crop. Found {crop.shape[0]}" + crop = crop[0] crops.append(crop) @@ -105,6 +120,8 @@ def _add_dynaclr_attrs( model_path: Path, graph: td.graph.InMemoryGraph, images: NDArray, + dataset_scale: tuple[float, float], + model_scale: tuple[float, float], ) -> None: """ Add DynaCLR embedding attributes to each node in the graph @@ -129,10 +146,10 @@ def _add_dynaclr_attrs( print(f"Expected input type: {input_type}") crop_attr_func = _crop_embedding( - shape=(64, 64), + final_shape=(192, 192), session=session, input_name=input_name, - upscale=1.0, + upscale=np.asarray(model_scale) / np.asarray(dataset_scale), ) print("Adding DynaCLR embedding attributes ...") @@ -154,8 +171,10 @@ def _add_dynaclr_attrs( def track_single_dataset( dataset_dir: Path, dataset_num: str, + dataset_scale: tuple[float, float], show_napari_viewer: bool, dynaclr_model_path: Path | None, + model_scale: tuple[float, float], ) -> None: """ Main function to track cells in a dataset. @@ -166,10 +185,14 @@ def track_single_dataset( Path to the dataset directory. dataset_num : str Number of the dataset. + dataset_scale: tuple[float, float], + The scale of the dataset. show_napari_viewer : bool Whether to show the napari viewer. dynaclr_model_path : Path | None Path to the DynaCLR model. If None, the model will not be used. + model_scale : tuple[float, float] + The scale of the model. """ assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." @@ -203,7 +226,13 @@ def track_single_dataset( dist_weight = (-td.EdgeAttr(td.DEFAULT_ATTR_KEYS.EDGE_DIST) / dist_operator.distance_threshold).exp() if dynaclr_model_path is not None: - _add_dynaclr_attrs(dynaclr_model_path, graph, images) + _add_dynaclr_attrs( + dynaclr_model_path, + graph, + images, + dataset_scale=dataset_scale, + model_scale=model_scale, + ) # decrease dynaclr similarity given the distance? edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight @@ -247,27 +276,34 @@ def track_single_dataset( return metrics -if __name__ == "__main__": - - model_paths = [ - None, - Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), - Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), - Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), +def main() -> None: + models = [ + # (None, (1, 1)), + (Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), (0.325, 0.325)), + (Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), (0.325, 0.325)), + (Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), (0.150, 0.150)), + (Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), (0.150, 0.150)), ] results = [] dataset_root = Path("/hpc/reference/group.royer/CTC/training/") - for model_path in model_paths: + scale_metadata_fpath = dataset_root.parent / "metadata.yaml" + with open(scale_metadata_fpath, "r") as f: + scale_metadata = yaml.safe_load(f) + + for model_path, model_scale in models: for dataset_dir in list(dataset_root.iterdir()): + # FIXME: include 3D datasets with center crop # executing only 2D datasets if "2D" not in dataset_dir.name: print(f"Skipping {dataset_dir.name} because it is not a 2D dataset") continue + dataset_scale = scale_metadata[dataset_dir.name] + for dataset_num in ["01", "02"]: # processing only datasets with segmentation (linking challenge) seg_dir = _seg_dir(dataset_dir, dataset_num) @@ -278,8 +314,10 @@ def track_single_dataset( metrics = track_single_dataset( dataset_dir=dataset_dir, dataset_num=dataset_num, + dataset_scale=dataset_scale[-2:], show_napari_viewer=False, dynaclr_model_path=model_path, + model_scale=model_scale, ) metrics["model"] = "None" if model_path is None else model_path.stem metrics["dataset"] = dataset_dir.name @@ -295,3 +333,7 @@ def track_single_dataset( "model", "dataset", "LNK", "BIO(0)", "OP_CLB(0)", "CHOTA" ) ) + + +if __name__ == "__main__": + main() From ebad3c5c5268a38c753da4dfc6c0292eaeab3ace Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 10 Oct 2025 10:29:28 -0700 Subject: [PATCH 03/13] working but not optimal version --- .../cell_tracking/cell_tracking_ctc.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index 6924eac2e..f1ca9c63a 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -3,7 +3,7 @@ # /// script # requires-python = ">=3.10" # dependencies = [ -# "tracksdata", +# "tracksdata @ git+https://github.com/royerlab/tracksdata.git", # "onnxruntime-gpu", # "napari[pyqt5]", # "gurobipy", @@ -17,6 +17,7 @@ import numpy as np import napari import onnxruntime as ort +import warnings from dask.array.image import imread from numpy.typing import NDArray @@ -43,7 +44,7 @@ def _pad(image: NDArray, shape: tuple[int, int]) -> NDArray: left = diff // 2 right = diff - left - return np.pad(image, ((left[0], right[0]), (left[1], right[1])), mode="reflect") + return np.pad(image, tuple(zip(left, right)), mode="reflect") @curry @@ -53,7 +54,7 @@ def _crop_embedding( final_shape: tuple[int, int], session: ort.InferenceSession, input_name: str, - upscale: float = 1.0, + upscale: tuple[float, float], ) -> NDArray[np.float32]: """ Crop the frame and compute the DynaCLR embedding. @@ -80,28 +81,27 @@ def _crop_embedding( """ crops = [] for m in mask: - crop_shape = tuple(np.round(np.asarray(final_shape) / upscale).astype(int)) + crop_shape = tuple(np.ceil(np.asarray(final_shape) / upscale).astype(int)) if frame.ndim == 3: crop_shape = (1, *crop_shape) - upscale = (1, *upscale) crop = m.crop(frame, shape=crop_shape).astype(np.float32) crop = _pad(crop, crop_shape) + if crop.ndim == 3: + assert crop.shape[0] == 1, f"Expected 1 z-slice in 3D crop. Found {crop.shape[0]}" + crop = crop[0] + mu, sigma = np.median(crop), np.quantile(crop, 0.75) - np.quantile(crop, 0.25) crop = (crop - mu) / (sigma + 1e-8) if np.all(upscale != 1.0): - crop = zoom(crop, upscale, order=1, mode="nearest") + crop = zoom(crop, upscale, order=2, mode="nearest") - if crop.shape[-2:] != final_shape: - raise ValueError(f"Crop shape {crop.shape} does not match final shape {final_shape}") + if crop.shape != final_shape: + warnings.warn(f"Crop shape {crop.shape} does not match final shape {final_shape}") - if crop.ndim == 3: - assert crop.shape[0] == 1, f"Expected 1 z-slice in 3D crop. Found {crop.shape[0]}" - crop = crop[0] - crops.append(crop) # expanding batch, channel, and z dimensions @@ -145,11 +145,13 @@ def _add_dynaclr_attrs( print(f"Expected input dimensions: {input_dim}") print(f"Expected input type: {input_type}") + cell_zoom_factor = 1 + upscale = cell_zoom_factor * np.asarray(model_scale) / np.asarray(dataset_scale) crop_attr_func = _crop_embedding( - final_shape=(192, 192), + final_shape=(64, 64), session=session, input_name=input_name, - upscale=np.asarray(model_scale) / np.asarray(dataset_scale), + upscale=upscale, ) print("Adding DynaCLR embedding attributes ...") @@ -196,7 +198,7 @@ def track_single_dataset( """ assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." - print(f"Loading labels from {dataset_dir} ...") + print(f"Loading labels from '{dataset_dir}' with scale {dataset_scale} ...") labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) images = imread(str(dataset_dir / dataset_num / "*.tif")) @@ -295,12 +297,9 @@ def main() -> None: for model_path, model_scale in models: for dataset_dir in list(dataset_root.iterdir()): - - # FIXME: include 3D datasets with center crop - # executing only 2D datasets - if "2D" not in dataset_dir.name: - print(f"Skipping {dataset_dir.name} because it is not a 2D dataset") - continue + # FIXME: remove this + # if "-ce" not in dataset_dir.name.lower(): + # continue dataset_scale = scale_metadata[dataset_dir.name] From 9dca69924a7aa61bdacac6b814be19f3b031b9f3 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 10 Oct 2025 11:01:00 -0700 Subject: [PATCH 04/13] using mask information for cropping --- .../cell_tracking/cell_tracking_ctc.py | 92 +++++++++---------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index f1ca9c63a..f2707cc88 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -23,7 +23,7 @@ from numpy.typing import NDArray from toolz import curry from rich import print -from scipy.ndimage import zoom +from scipy.ndimage import gaussian_filter import tracksdata as td @@ -32,7 +32,7 @@ def _seg_dir(dataset_dir: Path, dataset_num: str) -> Path: return dataset_dir / f"{dataset_num}_ERR_SEG" -def _pad(image: NDArray, shape: tuple[int, int]) -> NDArray: +def _pad(image: NDArray, shape: tuple[int, int], mode: str) -> NDArray: """ Pad the image to the given shape. """ @@ -44,7 +44,7 @@ def _pad(image: NDArray, shape: tuple[int, int]) -> NDArray: left = diff // 2 right = diff - left - return np.pad(image, tuple(zip(left, right)), mode="reflect") + return np.pad(image, tuple(zip(left, right)), mode=mode) @curry @@ -54,7 +54,7 @@ def _crop_embedding( final_shape: tuple[int, int], session: ort.InferenceSession, input_name: str, - upscale: tuple[float, float], + padding: int, ) -> NDArray[np.float32]: """ Crop the frame and compute the DynaCLR embedding. @@ -71,36 +71,53 @@ def _crop_embedding( The session to use for the embedding. input_name : str The name of the input tensor. - upscale : float, optional - The upscale factor to apply to the crop. + padding : int, optional + The padding to apply to the crop. Returns ------- NDArray[np.float32] The embedding of the crop. """ + label_img = np.zeros_like(frame, dtype=np.int16) + crops = [] - for m in mask: - crop_shape = tuple(np.ceil(np.asarray(final_shape) / upscale).astype(int)) + for i, m in enumerate(mask, start=1): + + ndim = len(m.bbox) // 2 + bbox_start = m.bbox[:ndim] + bbox_end = m.bbox[ndim:] + max_length = (bbox_end - bbox_start)[-2:].max() + 2 * padding + + crop_shape = np.minimum(final_shape, max_length) if frame.ndim == 3: crop_shape = (1, *crop_shape) + + label_img[m.mask_indices()] = i crop = m.crop(frame, shape=crop_shape).astype(np.float32) - crop = _pad(crop, crop_shape) + crop_mask = (m.crop(label_img, shape=crop_shape) == i).astype(np.float32) if crop.ndim == 3: assert crop.shape[0] == 1, f"Expected 1 z-slice in 3D crop. Found {crop.shape[0]}" crop = crop[0] + crop_mask = crop_mask[0] + + crop = _pad(crop, final_shape, mode="reflect") + crop_mask = _pad(crop_mask, final_shape, mode="constant") + + blurred_mask = gaussian_filter(crop_mask, sigma=5) + crop_mask = np.maximum(crop_mask, blurred_mask / blurred_mask.max()) mu, sigma = np.median(crop), np.quantile(crop, 0.75) - np.quantile(crop, 0.25) crop = (crop - mu) / (sigma + 1e-8) - if np.all(upscale != 1.0): - crop = zoom(crop, upscale, order=2, mode="nearest") - + # removing background + crop = crop * crop_mask + if crop.shape != final_shape: - warnings.warn(f"Crop shape {crop.shape} does not match final shape {final_shape}") + raise ValueError(f"Crop shape {crop.shape} does not match final shape {final_shape}") crops.append(crop) @@ -109,6 +126,12 @@ def _crop_embedding( crops = crops[:, np.newaxis, np.newaxis, ...] output = session.run(None, {input_name: crops}) + # import napari + # viewer = napari.Viewer() + # viewer.add_image(frame) + # viewer.add_image(np.squeeze(crops)) + # napari.run() + # embedding = output[-1] # projected 32-dimensional embedding embedding = output[0] # 768-dimensional embedding embedding = embedding / np.linalg.norm(embedding, axis=1, keepdims=True) @@ -120,8 +143,6 @@ def _add_dynaclr_attrs( model_path: Path, graph: td.graph.InMemoryGraph, images: NDArray, - dataset_scale: tuple[float, float], - model_scale: tuple[float, float], ) -> None: """ Add DynaCLR embedding attributes to each node in the graph @@ -145,13 +166,11 @@ def _add_dynaclr_attrs( print(f"Expected input dimensions: {input_dim}") print(f"Expected input type: {input_type}") - cell_zoom_factor = 1 - upscale = cell_zoom_factor * np.asarray(model_scale) / np.asarray(dataset_scale) crop_attr_func = _crop_embedding( final_shape=(64, 64), session=session, input_name=input_name, - upscale=upscale, + padding=3, ) print("Adding DynaCLR embedding attributes ...") @@ -173,10 +192,8 @@ def _add_dynaclr_attrs( def track_single_dataset( dataset_dir: Path, dataset_num: str, - dataset_scale: tuple[float, float], show_napari_viewer: bool, dynaclr_model_path: Path | None, - model_scale: tuple[float, float], ) -> None: """ Main function to track cells in a dataset. @@ -187,18 +204,14 @@ def track_single_dataset( Path to the dataset directory. dataset_num : str Number of the dataset. - dataset_scale: tuple[float, float], - The scale of the dataset. show_napari_viewer : bool Whether to show the napari viewer. dynaclr_model_path : Path | None Path to the DynaCLR model. If None, the model will not be used. - model_scale : tuple[float, float] - The scale of the model. """ assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." - print(f"Loading labels from '{dataset_dir}' with scale {dataset_scale} ...") + print(f"Loading labels from '{dataset_dir}'...") labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) images = imread(str(dataset_dir / dataset_num / "*.tif")) @@ -232,8 +245,6 @@ def track_single_dataset( dynaclr_model_path, graph, images, - dataset_scale=dataset_scale, - model_scale=model_scale, ) # decrease dynaclr similarity given the distance? edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight @@ -250,7 +261,7 @@ def track_single_dataset( edge_weight=edge_weight, appearance_weight=0, disappearance_weight=0, - division_weight=0.5, + division_weight=0.0, node_weight=-10, # we assume all segmentations are correct ) @@ -280,28 +291,19 @@ def track_single_dataset( def main() -> None: models = [ - # (None, (1, 1)), - (Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), (0.325, 0.325)), - (Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), (0.325, 0.325)), - (Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), (0.150, 0.150)), - (Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), (0.150, 0.150)), + None, + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), ] results = [] dataset_root = Path("/hpc/reference/group.royer/CTC/training/") - scale_metadata_fpath = dataset_root.parent / "metadata.yaml" - with open(scale_metadata_fpath, "r") as f: - scale_metadata = yaml.safe_load(f) - - for model_path, model_scale in models: - for dataset_dir in list(dataset_root.iterdir()): - # FIXME: remove this - # if "-ce" not in dataset_dir.name.lower(): - # continue - - dataset_scale = scale_metadata[dataset_dir.name] + for model_path in models: + for dataset_dir in sorted(dataset_root.iterdir()): for dataset_num in ["01", "02"]: # processing only datasets with segmentation (linking challenge) @@ -313,10 +315,8 @@ def main() -> None: metrics = track_single_dataset( dataset_dir=dataset_dir, dataset_num=dataset_num, - dataset_scale=dataset_scale[-2:], show_napari_viewer=False, dynaclr_model_path=model_path, - model_scale=model_scale, ) metrics["model"] = "None" if model_path is None else model_path.stem metrics["dataset"] = dataset_dir.name From 4381ce59e4c9f8e2f8a3108789198a5edea11d5a Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 10 Oct 2025 16:09:17 -0700 Subject: [PATCH 05/13] minor changes --- .../cell_tracking/cell_tracking_ctc.py | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index f2707cc88..fc3a1ec22 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -54,7 +54,6 @@ def _crop_embedding( final_shape: tuple[int, int], session: ort.InferenceSession, input_name: str, - padding: int, ) -> NDArray[np.float32]: """ Crop the frame and compute the DynaCLR embedding. @@ -84,15 +83,10 @@ def _crop_embedding( crops = [] for i, m in enumerate(mask, start=1): - ndim = len(m.bbox) // 2 - bbox_start = m.bbox[:ndim] - bbox_end = m.bbox[ndim:] - max_length = (bbox_end - bbox_start)[-2:].max() + 2 * padding - - crop_shape = np.minimum(final_shape, max_length) - if frame.ndim == 3: - crop_shape = (1, *crop_shape) + crop_shape = (1, *final_shape) + else: + crop_shape = final_shape label_img[m.mask_indices()] = i @@ -110,8 +104,8 @@ def _crop_embedding( blurred_mask = gaussian_filter(crop_mask, sigma=5) crop_mask = np.maximum(crop_mask, blurred_mask / blurred_mask.max()) - mu, sigma = np.median(crop), np.quantile(crop, 0.75) - np.quantile(crop, 0.25) - crop = (crop - mu) / (sigma + 1e-8) + mu, sigma = np.mean(crop), np.std(crop) + crop = (crop - mu) / np.maximum(sigma, 1e-8) # removing background crop = crop * crop_mask @@ -170,7 +164,6 @@ def _add_dynaclr_attrs( final_shape=(64, 64), session=session, input_name=input_name, - padding=3, ) print("Adding DynaCLR embedding attributes ...") @@ -261,7 +254,7 @@ def track_single_dataset( edge_weight=edge_weight, appearance_weight=0, disappearance_weight=0, - division_weight=0.0, + division_weight=0.5, node_weight=-10, # we assume all segmentations are correct ) @@ -291,11 +284,11 @@ def track_single_dataset( def main() -> None: models = [ - None, - Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), - Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), + None, ] results = [] @@ -324,8 +317,9 @@ def main() -> None: print(metrics) results.append(metrics) - df = pl.DataFrame(results) - df.write_csv("results.csv") + # update for every new result + df = pl.DataFrame(results) + df.write_csv("results.csv") print( df.group_by("model", "dataset").mean().select( From bed0543d836231f29d095ea46aaafcf823e86ecc Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 10 Oct 2025 16:27:34 -0700 Subject: [PATCH 06/13] adding comment --- applications/cell_tracking/cell_tracking_ctc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index fc3a1ec22..34e211b0c 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -105,6 +105,8 @@ def _crop_embedding( crop_mask = np.maximum(crop_mask, blurred_mask / blurred_mask.max()) mu, sigma = np.mean(crop), np.std(crop) + # mu = np.median(crop) + # sigma = np.quantile(crop, 0.99) - mu crop = (crop - mu) / np.maximum(sigma, 1e-8) # removing background From ceb29ffdbbb0bf8f74ee16633bc6c584d3f84e16 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 14 Oct 2025 09:11:46 -0700 Subject: [PATCH 07/13] adding plotting code --- applications/cell_tracking/report.py | 39 ++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100755 applications/cell_tracking/report.py diff --git a/applications/cell_tracking/report.py b/applications/cell_tracking/report.py new file mode 100755 index 000000000..638e3c238 --- /dev/null +++ b/applications/cell_tracking/report.py @@ -0,0 +1,39 @@ +#!/usr/bin/env -S uv run --script +# +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "polars", +# "altair>=5.4.0", +# ] +# /// +import polars as pl +import altair as alt + + +def main() -> None: + df = pl.read_csv("results.csv") + gdf = df.group_by("model", "dataset").mean() + + alt.renderers.enable("browser") # use system browser + + metric = "LNK" + p = gdf.plot.bar(x="model", y=metric, color="model") + p = p.properties( + title=metric, + ) + + facet_chart = p.facet( + facet="dataset", + columns=5, + ) + + facet_chart.show() + + df.group_by("model").mean().plot.bar(x="model", y=metric, color="model").properties( + title=metric, + ).show() + + +if __name__ == "__main__": + main() \ No newline at end of file From 87ec936629de1738fed8d222874e68ab78e8746e Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 14 Oct 2025 09:52:35 -0700 Subject: [PATCH 08/13] fixed division by zero --- applications/cell_tracking/cell_tracking_ctc.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index 34e211b0c..27b2238a1 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -8,16 +8,15 @@ # "napari[pyqt5]", # "gurobipy", # "py-ctcmetrics", +# "spatial-graph", # ] # /// from pathlib import Path -import yaml import polars as pl import numpy as np import napari import onnxruntime as ort -import warnings from dask.array.image import imread from numpy.typing import NDArray @@ -102,7 +101,9 @@ def _crop_embedding( crop_mask = _pad(crop_mask, final_shape, mode="constant") blurred_mask = gaussian_filter(crop_mask, sigma=5) - crop_mask = np.maximum(crop_mask, blurred_mask / blurred_mask.max()) + blurred_coef = blurred_mask.max() + if blurred_coef > 1e-8: # if too small use the binary mask + crop_mask = np.maximum(crop_mask, blurred_mask / blurred_coef) mu, sigma = np.mean(crop), np.std(crop) # mu = np.median(crop) @@ -272,7 +273,9 @@ def track_single_dataset( if show_napari_viewer: print("Converting to napari format ...") - labels, tracks_df, track_graph = td.functional.to_napari_format(graph, labels.shape) + tracks_df, track_graph, labels = td.functional.to_napari_format( + graph, labels.shape, mask_key=td.DEFAULT_ATTR_KEYS.MASK + ) print("Opening napari viewer ...") viewer = napari.Viewer() From 48ac955479015db3528fdfb66bb945a35dcd823c Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 14 Oct 2025 16:34:56 -0700 Subject: [PATCH 09/13] updating metric --- applications/cell_tracking/report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/cell_tracking/report.py b/applications/cell_tracking/report.py index 638e3c238..322ff5511 100755 --- a/applications/cell_tracking/report.py +++ b/applications/cell_tracking/report.py @@ -17,7 +17,7 @@ def main() -> None: alt.renderers.enable("browser") # use system browser - metric = "LNK" + metric = "OP_CLB(0)" p = gdf.plot.bar(x="model", y=metric, color="model") p = p.properties( title=metric, From 1f21c24eeb8608590561d055adbd4ce45a11fe06 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 5 Nov 2025 16:38:15 -0800 Subject: [PATCH 10/13] fixing ctc metrics version --- applications/cell_tracking/cell_tracking_ctc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index 27b2238a1..a199cb604 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -7,7 +7,7 @@ # "onnxruntime-gpu", # "napari[pyqt5]", # "gurobipy", -# "py-ctcmetrics", +# "py-ctcmetrics==1.2.2", # "spatial-graph", # ] # /// From 8aece9bb77ffa5bae92aa9544551dbdb395cc0a5 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 5 Nov 2025 17:15:43 -0800 Subject: [PATCH 11/13] loading ground-truth data --- applications/cell_tracking/track_proj.py | 248 +++++++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100755 applications/cell_tracking/track_proj.py diff --git a/applications/cell_tracking/track_proj.py b/applications/cell_tracking/track_proj.py new file mode 100755 index 000000000..702e5f98a --- /dev/null +++ b/applications/cell_tracking/track_proj.py @@ -0,0 +1,248 @@ +#!/usr/bin/env -S uv run --script +# +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "tracksdata @ git+https://github.com/royerlab/tracksdata.git", +# "contrastive-td @ git+https://github.com/royerlab/contrastive-td.git", +# "onnxruntime-gpu", +# "napari[pyqt5]", +# "gurobipy", +# "py-ctcmetrics==1.2.2", +# "spatial-graph", +# ] +# /// + +from pathlib import Path +import polars as pl +import napari +import zarr +import dask.array as da +from tracksdata.io._ctc import _add_edges_from_tracklet_ids + +from dask.array.image import imread +from rich import print + +import tracksdata as td +from cell_tracking_ctc import _add_dynaclr_attrs + + +def track_single_dataset( + dataset_dir: Path, + dataset_num: str, + show_napari_viewer: bool, + dynaclr_model_path: Path | None, +) -> None: + """ + Main function to track cells in a dataset. + + Parameters + ---------- + dataset_dir : Path + Path to the dataset directory. + dataset_num : str + Number of the dataset. + show_napari_viewer : bool + Whether to show the napari viewer. + dynaclr_model_path : Path | None + Path to the DynaCLR model. If None, the model will not be used. + """ + assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." + + print(f"Loading labels from '{dataset_dir}'...") + labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) + images = imread(str(dataset_dir / dataset_num / "*.tif")) + + gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{dataset_num}_GT" / "TRA") + + print("Starting tracking ...") + graph = td.graph.InMemoryGraph() + + nodes_operator = td.nodes.RegionPropsNodes() + nodes_operator.add_nodes(graph, labels=labels) + print(f"Number of nodes: {graph.num_nodes}") + + dist_operator = td.edges.DistanceEdges( + distance_threshold=325.0, # 50, + n_neighbors=10, + delta_t=5, # 30, + ) + dist_operator.add_edges(graph) + print(f"Number of edges: {graph.num_edges}") + + td.edges.GenericFuncEdgeAttrs( + func=lambda x, y: abs(x - y), + output_key="delta_t", + attr_keys="t", + ).add_edge_attrs(graph) + + dist_weight = (-td.EdgeAttr(td.DEFAULT_ATTR_KEYS.EDGE_DIST) / dist_operator.distance_threshold).exp() + + if dynaclr_model_path is not None: + _add_dynaclr_attrs( + dynaclr_model_path, + graph, + images, + ) + # decrease dynaclr similarity given the distance? + edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight + + else: + iou_operator = td.edges.IoUEdgeAttr(output_key="iou") + iou_operator.add_edge_attrs(graph) + + edge_weight = -(td.EdgeAttr("iou") + 0.1) * dist_weight + + edge_weight = edge_weight / td.EdgeAttr("delta_t").clip(lower_bound=1) + + solver = td.solvers.ILPSolver( + edge_weight=edge_weight, + appearance_weight=0, + disappearance_weight=0, + division_weight=0.5, + node_weight=-10, # we assume all segmentations are correct + ) + + solution_graph = solver.solve(graph) + + print("Evaluating results ...") + metrics = td.metrics.evaluate_ctc_metrics( + solution_graph, + gt_graph, + input_reset=False, + reference_reset=False, + ) + + if show_napari_viewer: + print("Converting to napari format ...") + tracks_df, track_graph, labels = td.functional.to_napari_format( + graph, labels.shape, mask_key=td.DEFAULT_ATTR_KEYS.MASK + ) + + print("Opening napari viewer ...") + viewer = napari.Viewer() + viewer.add_image(images) + viewer.add_labels(labels) + viewer.add_tracks(tracks_df, graph=track_graph) + napari.run() + + return metrics + + +def _load_gt_graph( + segm: da.Array, + tracklets_df: pl.DataFrame, +) -> td.graph.InMemoryGraph: + + tracklets_df = tracklets_df.filter(pl.col("parent_track_id") != -1) + tracklet_id_graph = dict(zip( + tracklets_df["track_id"].to_list(), + tracklets_df["parent_track_id"].to_list(), + )) + + gt_graph = td.graph.InMemoryGraph() + + td.nodes.RegionPropsNodes( + extra_properties=["label"], + ).add_nodes(gt_graph, labels=segm) + + _add_edges_from_tracklet_ids( + gt_graph, + gt_graph.node_attrs(attr_keys=[ + td.DEFAULT_ATTR_KEYS.NODE_ID, + td.DEFAULT_ATTR_KEYS.T, + "label", + ]), + tracklet_id_graph=tracklet_id_graph, + tracklet_id_key="label", + ) + + return gt_graph + + +def main() -> None: + + img_dir = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_08_26_A549_SEC61_TOMM20_ZIKV/4-phenotyping/0-train-test/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr") + + segm_dir = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_08_26_A549_SEC61_TOMM20_ZIKV/2-assemble/tracking_annotation.zarr") + + pos_key = "A/2/000000/0" + img_ds = zarr.open(img_dir, mode="r") + img = da.from_zarr(img_ds[pos_key])[:, 0, 0] + + segm_ds = zarr.open(segm_dir, mode="r") + segm = da.from_zarr(segm_ds[pos_key])[:, 0, 0] + + print(img.shape) + print(segm.shape) + + short_pos_key = pos_key[:-2] + suffix = "_".join(short_pos_key.split("/")) + + tracklets_graph_path = segm_dir / short_pos_key / f"tracks_{suffix}.csv" + tracklets_df = pl.read_csv(tracklets_graph_path) + gt_graph = _load_gt_graph(segm, tracklets_df) + + tracks_df, track_graph, labels = td.functional.to_napari_format( + gt_graph, + segm.shape, + solution_key=None, + mask_key=td.DEFAULT_ATTR_KEYS.MASK, + ) + + viewer = napari.Viewer() + viewer.add_image(img) + viewer.add_labels(segm) + viewer.add_tracks(tracks_df, graph=track_graph) + napari.run() + + + + return + models = [ + Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), + Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), + None, + ] + + results = [] + + dataset_root = Path("/hpc/reference/group.royer/CTC/training/") + + for model_path in models: + for dataset_dir in sorted(dataset_root.iterdir()): + + for dataset_num in ["01", "02"]: + # processing only datasets with segmentation (linking challenge) + seg_dir = _seg_dir(dataset_dir, dataset_num) + if not seg_dir.exists(): + print(f"Skipping {dataset_dir.name} because it does not have segmentation") + continue + + metrics = track_single_dataset( + dataset_dir=dataset_dir, + dataset_num=dataset_num, + show_napari_viewer=False, + dynaclr_model_path=model_path, + ) + metrics["model"] = "None" if model_path is None else model_path.stem + metrics["dataset"] = dataset_dir.name + metrics["dataset_num"] = dataset_num + print(metrics) + results.append(metrics) + + # update for every new result + df = pl.DataFrame(results) + df.write_csv("results.csv") + + print( + df.group_by("model", "dataset").mean().select( + "model", "dataset", "LNK", "BIO(0)", "OP_CLB(0)", "CHOTA" + ) + ) + + +if __name__ == "__main__": + main() From 2187a1b4c5721b3b1dfdd0d6d7eddd2eec7ea68b Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 6 Nov 2025 14:44:33 -0800 Subject: [PATCH 12/13] experiments with internal data --- .../cell_tracking/cell_tracking_ctc.py | 96 ++++++--- applications/cell_tracking/track_proj.py | 194 ++++-------------- 2 files changed, 109 insertions(+), 181 deletions(-) diff --git a/applications/cell_tracking/cell_tracking_ctc.py b/applications/cell_tracking/cell_tracking_ctc.py index a199cb604..7666903f8 100755 --- a/applications/cell_tracking/cell_tracking_ctc.py +++ b/applications/cell_tracking/cell_tracking_ctc.py @@ -13,6 +13,7 @@ # /// from pathlib import Path +from typing import Any import polars as pl import numpy as np import napari @@ -185,34 +186,30 @@ def _add_dynaclr_attrs( ).add_edge_attrs(graph) -def track_single_dataset( - dataset_dir: Path, - dataset_num: str, - show_napari_viewer: bool, +def _track( dynaclr_model_path: Path | None, -) -> None: + images: NDArray, + labels: NDArray, + dist_edge_kwargs: dict[str, Any] | None = None, + ilp_kwargs: dict[str, Any] | None = None, +) -> tuple[td.graph.InMemoryGraph, td.graph.InMemoryGraph]: """ - Main function to track cells in a dataset. + Track cells in a graph. Parameters ---------- - dataset_dir : Path - Path to the dataset directory. - dataset_num : str - Number of the dataset. - show_napari_viewer : bool - Whether to show the napari viewer. dynaclr_model_path : Path | None Path to the DynaCLR model. If None, the model will not be used. - """ - assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." - - print(f"Loading labels from '{dataset_dir}'...") - labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) - images = imread(str(dataset_dir / dataset_num / "*.tif")) - - gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{dataset_num}_GT" / "TRA") + images : NDArray + The images to use for the embedding. + labels : NDArray + The labels to use for the tracking. + Returns + ------- + tuple[td.graph.InMemoryGraph, td.graph.InMemoryGraph] + The original graph and the solution graph. + """ print("Starting tracking ...") graph = td.graph.InMemoryGraph() @@ -220,10 +217,14 @@ def track_single_dataset( nodes_operator.add_nodes(graph, labels=labels) print(f"Number of nodes: {graph.num_nodes}") + if dist_edge_kwargs is None: + dist_edge_kwargs = {} + dist_operator = td.edges.DistanceEdges( - distance_threshold=325.0, # 50, - n_neighbors=10, - delta_t=5, # 30, + distance_threshold=dist_edge_kwargs.pop("distance_threshold", 325.0), + n_neighbors=dist_edge_kwargs.pop("n_neighbors", 10), + delta_t=dist_edge_kwargs.pop("delta_t", 5), + **dist_edge_kwargs, ) dist_operator.add_edges(graph) print(f"Number of edges: {graph.num_edges}") @@ -253,16 +254,57 @@ def track_single_dataset( edge_weight = edge_weight / td.EdgeAttr("delta_t").clip(lower_bound=1) + if ilp_kwargs is None: + ilp_kwargs = {} + solver = td.solvers.ILPSolver( edge_weight=edge_weight, - appearance_weight=0, - disappearance_weight=0, - division_weight=0.5, - node_weight=-10, # we assume all segmentations are correct + appearance_weight=ilp_kwargs.pop("appearance_weight", 0), + disappearance_weight=ilp_kwargs.pop("disappearance_weight", 0), + division_weight=ilp_kwargs.pop("division_weight", 0.5), + node_weight=ilp_kwargs.pop("node_weight", -10), + **ilp_kwargs, ) solution_graph = solver.solve(graph) + return graph, solution_graph + + +def track_single_dataset( + dataset_dir: Path, + dataset_num: str, + show_napari_viewer: bool, + dynaclr_model_path: Path | None, +) -> None: + """ + Main function to track cells in a dataset. + + Parameters + ---------- + dataset_dir : Path + Path to the dataset directory. + dataset_num : str + Number of the dataset. + show_napari_viewer : bool + Whether to show the napari viewer. + dynaclr_model_path : Path | None + Path to the DynaCLR model. If None, the model will not be used. + """ + assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." + + print(f"Loading labels from '{dataset_dir}'...") + labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) + images = imread(str(dataset_dir / dataset_num / "*.tif")) + + gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{dataset_num}_GT" / "TRA") + + graph, solution_graph = _track( + dynaclr_model_path, + images, + labels, + ) + print("Evaluating results ...") metrics = td.metrics.evaluate_ctc_metrics( solution_graph, diff --git a/applications/cell_tracking/track_proj.py b/applications/cell_tracking/track_proj.py index 702e5f98a..89fa0475c 100755 --- a/applications/cell_tracking/track_proj.py +++ b/applications/cell_tracking/track_proj.py @@ -20,113 +20,10 @@ import dask.array as da from tracksdata.io._ctc import _add_edges_from_tracklet_ids -from dask.array.image import imread from rich import print import tracksdata as td -from cell_tracking_ctc import _add_dynaclr_attrs - - -def track_single_dataset( - dataset_dir: Path, - dataset_num: str, - show_napari_viewer: bool, - dynaclr_model_path: Path | None, -) -> None: - """ - Main function to track cells in a dataset. - - Parameters - ---------- - dataset_dir : Path - Path to the dataset directory. - dataset_num : str - Number of the dataset. - show_napari_viewer : bool - Whether to show the napari viewer. - dynaclr_model_path : Path | None - Path to the DynaCLR model. If None, the model will not be used. - """ - assert dataset_dir.exists(), f"Data directory {dataset_dir} does not exist." - - print(f"Loading labels from '{dataset_dir}'...") - labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif")) - images = imread(str(dataset_dir / dataset_num / "*.tif")) - - gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{dataset_num}_GT" / "TRA") - - print("Starting tracking ...") - graph = td.graph.InMemoryGraph() - - nodes_operator = td.nodes.RegionPropsNodes() - nodes_operator.add_nodes(graph, labels=labels) - print(f"Number of nodes: {graph.num_nodes}") - - dist_operator = td.edges.DistanceEdges( - distance_threshold=325.0, # 50, - n_neighbors=10, - delta_t=5, # 30, - ) - dist_operator.add_edges(graph) - print(f"Number of edges: {graph.num_edges}") - - td.edges.GenericFuncEdgeAttrs( - func=lambda x, y: abs(x - y), - output_key="delta_t", - attr_keys="t", - ).add_edge_attrs(graph) - - dist_weight = (-td.EdgeAttr(td.DEFAULT_ATTR_KEYS.EDGE_DIST) / dist_operator.distance_threshold).exp() - - if dynaclr_model_path is not None: - _add_dynaclr_attrs( - dynaclr_model_path, - graph, - images, - ) - # decrease dynaclr similarity given the distance? - edge_weight = -td.EdgeAttr("dynaclr_similarity") * dist_weight - - else: - iou_operator = td.edges.IoUEdgeAttr(output_key="iou") - iou_operator.add_edge_attrs(graph) - - edge_weight = -(td.EdgeAttr("iou") + 0.1) * dist_weight - - edge_weight = edge_weight / td.EdgeAttr("delta_t").clip(lower_bound=1) - - solver = td.solvers.ILPSolver( - edge_weight=edge_weight, - appearance_weight=0, - disappearance_weight=0, - division_weight=0.5, - node_weight=-10, # we assume all segmentations are correct - ) - - solution_graph = solver.solve(graph) - - print("Evaluating results ...") - metrics = td.metrics.evaluate_ctc_metrics( - solution_graph, - gt_graph, - input_reset=False, - reference_reset=False, - ) - - if show_napari_viewer: - print("Converting to napari format ...") - tracks_df, track_graph, labels = td.functional.to_napari_format( - graph, labels.shape, mask_key=td.DEFAULT_ATTR_KEYS.MASK - ) - - print("Opening napari viewer ...") - viewer = napari.Viewer() - viewer.add_image(images) - viewer.add_labels(labels) - viewer.add_tracks(tracks_df, graph=track_graph) - napari.run() - - return metrics +from cell_tracking_ctc import _track def _load_gt_graph( @@ -183,65 +80,54 @@ def main() -> None: tracklets_df = pl.read_csv(tracklets_graph_path) gt_graph = _load_gt_graph(segm, tracklets_df) - tracks_df, track_graph, labels = td.functional.to_napari_format( + model_path = Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/deploy/dynaclr2d_timeaware_bagchannels_160patch_ckpt104.onnx") + + graph, solution_graph = _track( + model_path, + img, + segm, + dist_edge_kwargs={"delta_t": 1}, + ilp_kwargs={"division_weight": 0.75}, + ) + + gt_tracks_df, gt_track_graph, gt_labels = td.functional.to_napari_format( gt_graph, segm.shape, solution_key=None, mask_key=td.DEFAULT_ATTR_KEYS.MASK, ) - viewer = napari.Viewer() - viewer.add_image(img) - viewer.add_labels(segm) - viewer.add_tracks(tracks_df, graph=track_graph) - napari.run() + tracks_df, track_graph, labels = td.functional.to_napari_format( + solution_graph, + segm.shape, + solution_key=None, + mask_key=td.DEFAULT_ATTR_KEYS.MASK, + ) + + metrics = td.metrics.evaluate_ctc_metrics( + gt_graph, + solution_graph, + input_reset=False, + reference_reset=False, + ) + print(metrics) + viewer = napari.Viewer() + viewer.add_image(img) + viewer.add_labels(segm, name="GT labels") + viewer.add_tracks(gt_tracks_df, graph=gt_track_graph, name="GT tracks") + viewer.add_tracks(tracks_df, graph=track_graph, name="Solution tracks") + viewer.add_labels(labels, name="Solution labels") + + # solution_graph.match(gt_graph) + # td.metrics.visualize_matches( + # solution_graph, + # gt_graph, + # viewer=viewer, + # ) - return - models = [ - Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_classical_gfp_rfp_ph_temp0p5_batch128_ckpt146.onnx"), - Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/time_interval/dynaclr_gfp_rfp_ph_2D/deploy/dynaclr2d_timeaware_gfp_rfp_ph_temp0p5_batch128_ckpt185.onnx"), - Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_phase_brightfield_temp0p2_batch256_ckpt33.onnx"), - Path("/hpc/projects/organelle_phenotyping/models/dynamorph_microglia/deploy/dynaclr2d_timeaware_phase_brightfield_temp0p2_batch256_ckpt13.onnx"), - None, - ] - - results = [] - - dataset_root = Path("/hpc/reference/group.royer/CTC/training/") - - for model_path in models: - for dataset_dir in sorted(dataset_root.iterdir()): - - for dataset_num in ["01", "02"]: - # processing only datasets with segmentation (linking challenge) - seg_dir = _seg_dir(dataset_dir, dataset_num) - if not seg_dir.exists(): - print(f"Skipping {dataset_dir.name} because it does not have segmentation") - continue - - metrics = track_single_dataset( - dataset_dir=dataset_dir, - dataset_num=dataset_num, - show_napari_viewer=False, - dynaclr_model_path=model_path, - ) - metrics["model"] = "None" if model_path is None else model_path.stem - metrics["dataset"] = dataset_dir.name - metrics["dataset_num"] = dataset_num - print(metrics) - results.append(metrics) - - # update for every new result - df = pl.DataFrame(results) - df.write_csv("results.csv") - - print( - df.group_by("model", "dataset").mean().select( - "model", "dataset", "LNK", "BIO(0)", "OP_CLB(0)", "CHOTA" - ) - ) + napari.run() if __name__ == "__main__": From 721ca2ef3e0d6954f47d3e7171715379ee8636ac Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 6 Nov 2025 14:47:33 -0800 Subject: [PATCH 13/13] updated parameter --- applications/cell_tracking/track_proj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/cell_tracking/track_proj.py b/applications/cell_tracking/track_proj.py index 89fa0475c..deb78f9af 100755 --- a/applications/cell_tracking/track_proj.py +++ b/applications/cell_tracking/track_proj.py @@ -87,7 +87,7 @@ def main() -> None: img, segm, dist_edge_kwargs={"delta_t": 1}, - ilp_kwargs={"division_weight": 0.75}, + ilp_kwargs={"division_weight": 0.95}, ) gt_tracks_df, gt_track_graph, gt_labels = td.functional.to_napari_format(