def _pad(image: NDArray, shape: tuple[int, int], mode: str) -> NDArray:
    """
    Pad the image symmetrically so that it reaches the given shape.

    Parameters
    ----------
    image : NDArray
        The array to pad.
    shape : tuple[int, int]
        Target shape, one entry per axis of ``image``.
    mode : str
        Padding mode forwarded to ``np.pad``.

    Returns
    -------
    NDArray
        ``image`` itself when it already has the target shape, otherwise a
        padded copy. When the missing amount along an axis is odd, the extra
        element is added on the trailing side.

    Raises
    ------
    ValueError
        If ``image`` is larger than ``shape`` along any axis.
    """
    missing = np.asarray(shape) - np.asarray(image.shape)

    # Check every axis on its own: a summed difference is also zero when a
    # deficit on one axis cancels an excess on another.
    if not missing.any():
        return image

    if (missing < 0).any():
        raise ValueError(
            f"Cannot pad image of shape {image.shape} to smaller shape {tuple(shape)}"
        )

    before = missing // 2
    after = missing - before

    pad_width = tuple(zip(before.tolist(), after.tolist()))
    return np.pad(image, pad_width, mode=mode)
def _add_dynaclr_attrs(
    model_path: Path,
    graph: td.graph.InMemoryGraph,
    images: NDArray,
) -> None:
    """
    Attach DynaCLR embeddings to the nodes of ``graph`` and score its edges.

    Every node receives a ``"dynaclr_embedding"`` attribute computed by
    ``_crop_embedding`` from its ``"mask"`` attribute, and every edge receives
    a ``"dynaclr_similarity"`` attribute computed with ``np.dot`` from the
    ``"dynaclr_embedding"`` attribute. Because ``_crop_embedding`` returns
    unit-norm vectors, that dot product is a cosine similarity.

    Parameters
    ----------
    model_path : Path
        Path of the ONNX model loaded into an ``ort.InferenceSession``.
    graph : td.graph.InMemoryGraph
        The graph to add the attributes to. It is modified in place.
    images : NDArray
        The images to use for the embedding, forwarded as ``frames``.
    """
    session = ort.InferenceSession(model_path)

    # Only the first declared model input is used.
    model_input = session.get_inputs()[0]
    input_name = model_input.name

    print(f"Model input name: '{input_name}'")
    print(f"Expected input dimensions: {model_input.shape}")
    print(f"Expected input type: {model_input.type}")

    # `_crop_embedding` is curried: binding everything except the frame and
    # the masks leaves a callable that the node operator can invoke.
    embed_masks = _crop_embedding(
        final_shape=(64, 64),
        session=session,
        input_name=input_name,
    )

    print("Adding DynaCLR embedding attributes ...")
    embedding_operator = td.nodes.GenericFuncNodeAttrs(
        func=embed_masks,
        output_key="dynaclr_embedding",
        attr_keys=["mask"],
        batch_size=128,
    )
    embedding_operator.add_node_attrs(graph, frames=images)

    print("Adding cosine similarity attributes ...")
    similarity_operator = td.edges.GenericFuncEdgeAttrs(
        func=np.dot,
        output_key="dynaclr_similarity",
        attr_keys="dynaclr_embedding",
    )
    similarity_operator.add_edge_attrs(graph)
def track_single_dataset(
    dataset_dir: Path,
    dataset_num: str,
    show_napari_viewer: bool,
    dynaclr_model_path: Path | None,
) -> dict[str, Any]:
    """
    Track one sequence of a CTC dataset and score it against its ground truth.

    Parameters
    ----------
    dataset_dir : Path
        Path to the dataset directory.
    dataset_num : str
        Number of the dataset (sequence), used to build the image,
        segmentation and ground-truth sub-directory names.
    show_napari_viewer : bool
        Whether to show the napari viewer. The viewer blocks until closed.
    dynaclr_model_path : Path | None
        Path to the DynaCLR model. If None, the model will not be used.

    Returns
    -------
    dict[str, Any]
        The value returned by ``td.metrics.evaluate_ctc_metrics``. The caller
        in this file adds string keys to it, so it is presumably a plain dict.

    Raises
    ------
    FileNotFoundError
        If ``dataset_dir`` does not exist.
    """
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if not dataset_dir.exists():
        raise FileNotFoundError(f"Data directory {dataset_dir} does not exist.")

    print(f"Loading labels from '{dataset_dir}'...")
    labels = imread(str(_seg_dir(dataset_dir, dataset_num) / "*.tif"))
    images = imread(str(dataset_dir / dataset_num / "*.tif"))

    gt_graph = td.graph.InMemoryGraph.from_ctc(dataset_dir / f"{dataset_num}_GT" / "TRA")

    graph, solution_graph = _track(
        dynaclr_model_path,
        images,
        labels,
    )

    print("Evaluating results ...")
    metrics = td.metrics.evaluate_ctc_metrics(
        solution_graph,
        gt_graph,
        input_reset=False,
        reference_reset=False,
    )

    if show_napari_viewer:
        print("Converting to napari format ...")
        # NOTE(review): the full candidate graph is passed here rather than
        # `solution_graph`; presumably `to_napari_format` selects the solved
        # sub-graph through its default solution key. Confirm.
        tracks_df, track_graph, track_labels = td.functional.to_napari_format(
            graph, labels.shape, mask_key=td.DEFAULT_ATTR_KEYS.MASK
        )

        print("Opening napari viewer ...")
        viewer = napari.Viewer()
        viewer.add_image(images)
        viewer.add_labels(track_labels)
        viewer.add_tracks(tracks_df, graph=track_graph)
        napari.run()

    return metrics
def main() -> None:
    """
    Plot the benchmark stored in ``results.csv`` as bar charts.

    Two charts of the ``"OP_CLB(0)"`` column are shown: the mean per
    (model, dataset) pair, faceted by dataset with five facets per row, and
    the mean per model over all rows.
    """
    metric = "OP_CLB(0)"

    results = pl.read_csv("results.csv")
    per_dataset = results.group_by("model", "dataset").mean()

    # Render the charts in the system browser.
    alt.renderers.enable("browser")

    dataset_bars = per_dataset.plot.bar(x="model", y=metric, color="model")
    dataset_bars = dataset_bars.properties(title=metric)
    dataset_bars.facet(facet="dataset", columns=5).show()

    per_model = results.group_by("model").mean()
    model_bars = per_model.plot.bar(x="model", y=metric, color="model")
    model_bars.properties(title=metric).show()
def main() -> None:
    """
    Track one annotated position with DynaCLR and compare it with the ground truth.

    The function loads the image and the annotated segmentation of a single
    position from two zarr stores, builds the ground-truth graph from the
    segmentation and its tracklet CSV, and runs ``_track`` with the DynaCLR
    model. It then prints the CTC metrics of the solution and opens a napari
    viewer with both sets of tracks. The viewer blocks until closed.
    """
    img_dir = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_08_26_A549_SEC61_TOMM20_ZIKV/4-phenotyping/0-train-test/2025_08_26_A549_SEC61_TOMM20_ZIKV.zarr")

    segm_dir = Path("/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_08_26_A549_SEC61_TOMM20_ZIKV/2-assemble/tracking_annotation.zarr")

    pos_key = "A/2/000000/0"

    # Keep index 0 on axes 1 and 2 (presumably channel and z; confirm against
    # the store layout), leaving one 2D frame per time point.
    img_ds = zarr.open(img_dir, mode="r")
    img = da.from_zarr(img_ds[pos_key])[:, 0, 0]

    segm_ds = zarr.open(segm_dir, mode="r")
    segm = da.from_zarr(segm_ds[pos_key])[:, 0, 0]

    print(img.shape)
    print(segm.shape)

    # Drop the trailing "/0" of the array key to get the position group.
    short_pos_key = pos_key[:-2]
    suffix = "_".join(short_pos_key.split("/"))

    tracklets_graph_path = segm_dir / short_pos_key / f"tracks_{suffix}.csv"
    tracklets_df = pl.read_csv(tracklets_graph_path)
    gt_graph = _load_gt_graph(segm, tracklets_df)

    model_path = Path("/hpc/projects/organelle_phenotyping/models/SEC61_TOMM20_G3BP1_Sensor/deploy/dynaclr2d_timeaware_bagchannels_160patch_ckpt104.onnx")

    # The full candidate graph is not needed here, only the solution.
    _, solution_graph = _track(
        model_path,
        img,
        segm,
        dist_edge_kwargs={"delta_t": 1},
        ilp_kwargs={"division_weight": 0.95},
    )

    # NOTE(review): the conversions run before the evaluation, as in the
    # original; with the *_reset flags disabled the metrics may rely on state
    # set by `to_napari_format`. Confirm before reordering.
    gt_tracks_df, gt_track_graph, gt_labels = td.functional.to_napari_format(
        gt_graph,
        segm.shape,
        solution_key=None,
        mask_key=td.DEFAULT_ATTR_KEYS.MASK,
    )

    tracks_df, track_graph, labels = td.functional.to_napari_format(
        solution_graph,
        segm.shape,
        solution_key=None,
        mask_key=td.DEFAULT_ATTR_KEYS.MASK,
    )

    # The prediction is the input and the annotation is the reference, the
    # same order as the call in cell_tracking_ctc.py. The original call here
    # had the two graphs swapped.
    metrics = td.metrics.evaluate_ctc_metrics(
        solution_graph,
        gt_graph,
        input_reset=False,
        reference_reset=False,
    )

    print(metrics)

    viewer = napari.Viewer()
    viewer.add_image(img)
    viewer.add_labels(segm, name="GT labels")
    viewer.add_tracks(gt_tracks_df, graph=gt_track_graph, name="GT tracks")
    viewer.add_tracks(tracks_df, graph=track_graph, name="Solution tracks")
    viewer.add_labels(labels, name="Solution labels")

    # solution_graph.match(gt_graph)
    # td.metrics.visualize_matches(
    #     solution_graph,
    #     gt_graph,
    #     viewer=viewer,
    # )

    napari.run()