diff --git a/applications/pseudotime_analysis/cell_cycle/cell_cycle_consensus.py b/applications/pseudotime_analysis/cell_cycle/cell_cycle_consensus.py new file mode 100644 index 00000000..6baaf02a --- /dev/null +++ b/applications/pseudotime_analysis/cell_cycle/cell_cycle_consensus.py @@ -0,0 +1,921 @@ +#%% +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from anndata import read_zarr +from iohub import open_ome_zarr +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from viscy.data.triplet import TripletDataset +from viscy.representation.pseudotime import CytoDtw + +#%% +logger = logging.getLogger("viscy") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + +""" +TODO +- We need to find a way to save the annotations, features and track information into one file. +- We need to standardize the naming convention. i.e The annotations fov_name is missing a / at the beginning. +- It would be nice to also select which will be the reference lineages and add that as a column. +- Figure out what is the best format to save the consensus lineage +- Does the consensus track generalize? +- There is a lot of fragmentation. Which tracking was used for the annotations? There is a script that unifies this but no record of which one was it. We can append these as extra columns + +""" + +# Configuration +NAPARI = True +if NAPARI: + import os + + import napari + os.environ["DISPLAY"] = ":1" + viewer = napari.Viewer() + +# File paths + +# ANNOTATIONS +cell_cycle_annotations_denv_dict= { + # "tomm20_cc_1": + # {'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/2-assemble/2024_11_21_A549_TOMM20_DENV.zarr", + # 'fov_name': "/C/2/001000", + # 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_cell_state_annotation.csv", + # 'features_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/1-predictions/phase_160patch_104ckpt_ver3max.zarr", + # 'tracks_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/2-assemble/tracking.zarr", + # }, + # "tomm20_cc_2": + # {'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/2-assemble/2024_11_21_A549_TOMM20_DENV.zarr", + # 'fov_name': "/B/3/000001", + # 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_cell_state_annotation.csv", + # 'features_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/4-phenotyping/1-predictions/phase_160patch_104ckpt_ver3max.zarr", + # # 'tracks_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_21_A549_TOMM20_DENV/2-assemble/tracking.zarr", + # }, + # "sec61b_cc_1": + # {'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr/B/3/001000", + # 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_cell_state_annotation.csv", + # 'tracks_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/2-assemble/tracking.zarr", + # }, + # "sec61b_cc_2": + # {'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr/C/2/000001", + # 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_cell_state_annotation.csv", + # }, + "g3bp1_cc_1": + {'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr", + 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/cytospeak_annotations/2025_07_24_annotations.csv", + 'features_path': "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/cell_cycle/output/phase_160patch_104ckpt_ver3max.anndata", + 'fov_name': "C/1/001000", + }, +} +output_root = Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/cell_cycle/output" +) +output_root.mkdir(parents=True, exist_ok=True) + +#%% +color_dict = { + "interphase": "blue", + "mitosis": "orange", +} +ANNOTATION_CELL_CYCLE = "predicted_cellstate" + +# Load each dataframe and find the lineages +key , cell_cycle_annotations_denv=next(iter(cell_cycle_annotations_denv_dict.items())) +cell_cycle_annotations_df = pd.read_csv(cell_cycle_annotations_denv["annotations_path"]) +data_path = cell_cycle_annotations_denv["data_path"] +fov_name = cell_cycle_annotations_denv["fov_name"] +features_path = cell_cycle_annotations_denv["features_path"] + +# Load AnnData directly +adata = read_zarr(features_path) +print("Loaded AnnData with shape:", adata.shape) +print("Available columns:", adata.obs.columns.tolist()) + +# Instantiate the CytoDtw object with AnnData +cytodtw = CytoDtw(adata) +feature_df = cytodtw.adata.obs + +min_timepoints = 7 +filtered_lineages = cytodtw.get_lineages(min_timepoints) +filtered_lineages = pd.DataFrame(filtered_lineages, columns=["fov_name", "track_id"]) +logger.info(f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints") + +#%% +n_timepoints_before = min_timepoints//2 +n_timepoints_after = min_timepoints//2 +valid_annotated_examples=[ + { + 'fov_name': "A/2/001001", + 'track_id': [136,137], + 'timepoints': (43-n_timepoints_before, 43+n_timepoints_after+1), + 'annotations': ["interphase"] * (n_timepoints_before) + ["mitosis"] + ["interphase"] * (n_timepoints_after-1), + 'weight': 1.0 +}, +# { +# 'fov_name': "C/1/001000", +# 'track_id': [47,48], +# 'timepoints': (45-n_timepoints_before, 45+n_timepoints_after+1), +# 'annotations': ["interphase"] * (n_timepoints_before) + ["mitosis"] + ["interphase"] * (n_timepoints_after-1), +# 'weight': 1.0 +# }, +{ + 'fov_name': "C/1/000000", + 'track_id': [118,119], + 'timepoints': (27-n_timepoints_before, 27+n_timepoints_after+1), + 'annotations': ["interphase"] * (n_timepoints_before) + ["mitosis"] + ["interphase"] * (n_timepoints_after-1), + 'weight': 1.0 +}, +# { +# 'fov_name': "C/1/001000", +# 'track_id': [59,60], +# 'timepoints': (52-n_timepoints_before, 52+n_timepoints_after+1), +# 'annotations': ["interphase"] * (n_timepoints_before) + ["mitosis"] + ["interphase"] * (n_timepoints_after-1), +# 'weight': 1.0 +# }, +{ + 'fov_name': "C/1/001001", + 'track_id': [93,94], + 'timepoints': (29-n_timepoints_before, 29+n_timepoints_after+1), + 'annotations': ["interphase"] * (n_timepoints_before) + ["mitosis"] + ["interphase"] * (n_timepoints_after-1), + 'weight': 1.0 +}, + +] +#%% +# Extract all reference patterns +patterns = [] +pattern_info = [] +REFERENCE_TYPE = "features" +DTW_CONSTRAINT_TYPE="sakoe_chiba" +DTW_BAND_WIDTH_RATIO=0.3 + +for i, example in enumerate(valid_annotated_examples): + pattern = cytodtw.get_reference_pattern( + fov_name=example['fov_name'], + track_id=example['track_id'], + timepoints=example['timepoints'], + reference_type=REFERENCE_TYPE, + ) + patterns.append(pattern) + pattern_info.append({ + 'index': i, + 'fov_name': example['fov_name'], + 'track_id': example['track_id'], + 'timepoints': example['timepoints'], + 'annotations': example['annotations'] + }) + +# Concatenate all patterns to fit PCA on full dataset +all_patterns_concat = np.vstack(patterns) + +#%% +# Plot the sample patterns + +# Fit PCA on all data +scaler = StandardScaler() +scaled_patterns = scaler.fit_transform(all_patterns_concat) +pca = PCA(n_components=3) +pca.fit(scaled_patterns) + +# Create subplots for PC1, PC2, PC3 over time +n_patterns = len(patterns) +fig, axes = plt.subplots(n_patterns, 3, figsize=(12, 3*n_patterns)) +if n_patterns == 1: + axes = axes.reshape(1, -1) + +# Plot each pattern +for i, (pattern, info) in enumerate(zip(patterns, pattern_info)): + # Transform this pattern to PC space + scaled_pattern = scaler.transform(pattern) + pc_pattern = pca.transform(scaled_pattern) + + # Create time axis + time_axis = np.arange(len(pattern)) + + # Plot PC1, PC2, PC3 + for pc_idx in range(3): + ax = axes[i, pc_idx] + + # Plot PC trajectory with colorblind-friendly colors + ax.plot(time_axis, pc_pattern[:, pc_idx], 'o-', color='blue', linewidth=2, markersize=4) + + # Color timepoints by annotation + annotations = info['annotations'] + for t, annotation in enumerate(annotations): + if annotation == 'mitosis': + ax.axvline(t, color='orange', alpha=0.7, linestyle='--', linewidth=2) + ax.scatter(t, pc_pattern[t, pc_idx], c='orange', s=50, zorder=5) + + # Formatting + ax.set_xlabel('Time') + ax.set_ylabel(f'PC{pc_idx+1}') + ax.set_title(f'Pattern {i+1}: FOV {info["fov_name"]}, Tracks {info["track_id"]}\nPC{pc_idx+1} over time') + ax.grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +#%% +# Create consensus pattern if we have valid examples +if len(valid_annotated_examples) >= 2: + consensus_result = cytodtw.create_consensus_reference_pattern( + annotated_samples=valid_annotated_examples, + reference_selection="median_length", + aggregation_method="median", + reference_type=REFERENCE_TYPE, + ) + consensus_lineage = consensus_result['pattern'] + consensus_annotations = consensus_result.get('annotations', None) + consensus_metadata = consensus_result['metadata'] + + logger.info(f"Created consensus pattern with shape: {consensus_lineage.shape}") + logger.info(f"Consensus method: {consensus_metadata['aggregation_method']}") + logger.info(f"Reference pattern: {consensus_metadata['reference_pattern']}") + if consensus_annotations: + logger.info(f"Consensus annotations length: {len(consensus_annotations)}") +else: + logger.warning("Not enough valid lineages found to create consensus pattern") + +#%% +# Perform DTW analysis for each embedding method +alignment_results = {} +top_n = 30 + +name = "consensus_lineage" +consensus_lineage = cytodtw.consensus_data['pattern'] +# Find pattern matches +matches = cytodtw.get_matches( + reference_pattern=consensus_lineage, + lineages=filtered_lineages.to_numpy(), + window_step=1, + num_candidates=top_n, + method="bernd_clifford", + metric="cosine", + save_path=output_root / f"{name}_matching_lineages_cosine.csv", + reference_type=REFERENCE_TYPE, + constraint_type=DTW_CONSTRAINT_TYPE, + band_width_ratio=DTW_BAND_WIDTH_RATIO +) + +alignment_results[name] = matches +logger.info(f"Found {len(matches)} matches for {name}") +#%% +# Save matches +print(f'Saving matches to {output_root / f"{name}_matching_lineages_cosine.csv"}') +# cytodtw.save_consensus(output_root / f"{name}_consensus_lineage.pkl") +# Add consensus path to the df all rows +# Add a new column 'consensus_path' to the matches DataFrame, with the same value for all rows. +# This is useful for downstream analysis to keep track of the consensus pattern used for matching. +# Reference: pandas.DataFrame.assign +matches['consensus_path'] = str(output_root / f"{name}_consensus_lineage.pkl") +# Save the pkl +cytodtw.save_consensus(output_root / f"{name}_consensus_lineage.pkl") + +matches.to_csv(output_root / f"{name}_matching_lineages_cosine.csv", index=False) +#%% +top_matches = matches.head(top_n) + +# Use the new enhanced alignment dataframe method instead of manual alignment +enhanced_df = cytodtw.create_enhanced_alignment_dataframe(top_matches, consensus_lineage, alignment_name="cell_division", reference_type=REFERENCE_TYPE) + +logger.info(f"Enhanced DataFrame created with {len(enhanced_df)} rows") +logger.info(f"Lineages: {enhanced_df['lineage_id'].nunique()} (including consensus)") +logger.info(f"Cell division aligned timepoints: {enhanced_df['dtw_cell_division_aligned'].sum()}/{len(enhanced_df)} ({100*enhanced_df['dtw_cell_division_aligned'].mean():.1f}%)") +# PCA plotting and alignment visualization is now handled by the enhanced alignment dataframe method +logger.info("Cell division consensus analysis completed successfully!") +print(f"Enhanced DataFrame columns: {enhanced_df.columns.tolist()}") + +#%% +# Prototype video alignment based on DTW matches + +z_range = slice(0, 1) +initial_yx_patch_size = (192, 192) + +positions = [] +tracks_tables = [] +images_plate = open_ome_zarr(data_path) + +# Load matching positions +print(f"Loading positions for {len(top_matches)} FOV matches...") +matches_found = 0 +for _, pos in images_plate.positions(): + pos_name = pos.zgroup.name + pos_normalized = pos_name.lstrip('/') + + if pos_normalized in top_matches['fov_name'].values: + positions.append(pos) + matches_found += 1 + + # Get ALL tracks for this FOV to ensure TripletDataset has complete access + tracks_df = cytodtw.adata.obs[cytodtw.adata.obs["fov_name"] == pos_normalized].copy() + + if len(tracks_df) > 0: + tracks_df = tracks_df.dropna(subset=['x', 'y']) + tracks_df['x'] = tracks_df['x'].astype(int) + tracks_df['y'] = tracks_df['y'].astype(int) + tracks_tables.append(tracks_df) + + if matches_found == 1: + processing_channels = pos.channel_names + +print(f"Loaded {matches_found} positions with {sum(len(df) for df in tracks_tables)} total tracks") + +# Create TripletDataset if we have valid positions +if len(positions) > 0 and len(tracks_tables) > 0: + if 'processing_channels' not in locals(): + processing_channels = positions[0].channel_names + + # Use all three channels for overlay visualization + selected_channels = processing_channels # Use all available channels + print(f"Creating TripletDataset with {len(selected_channels)} channels: {selected_channels}") + + dataset = TripletDataset( + positions=positions, + tracks_tables=tracks_tables, + channel_names=selected_channels, + initial_yx_patch_size=initial_yx_patch_size, + z_range=z_range, + fit=False, + predict_cells=False, + include_fov_names=None, + include_track_ids=None, + time_interval=1, + return_negative=False, + ) + print(f"TripletDataset created with {len(dataset.valid_anchors)} valid anchors") +else: + print("Cannot create TripletDataset - no valid positions or tracks") + dataset = None + +# %% +# Simplified sequence alignment using existing DTW results +def get_aligned_image_sequences(dataset: TripletDataset, candidates_df:pd.DataFrame): + """Get image sequences aligned to consensus timeline using DTW warp paths.""" + + aligned_sequences = {} + for idx, row in candidates_df.iterrows(): + fov_name = row['fov_name'] + track_ids = row['track_ids'] + warp_path = row['warp_path'] + start_time = int(row['start_track_timepoint']) if not pd.isna(row['start_track_timepoint']) else 0 + + # Determine alignment length from warp path + alignment_length = max(ref_idx for ref_idx, _ in warp_path) + 1 + + # Find matching dataset indices + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if (anchor_row['fov_name'] == fov_name and anchor_row['track_id'] in track_ids): + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning(f"No matching indices found for FOV {fov_name}, tracks {track_ids}") + continue + + # Get images and sort by time + batch_data = dataset.__getitems__(matching_indices) + + # Extract individual images from batch + images = [] + for i in range(len(matching_indices)): + img_data = { + 'anchor': batch_data['anchor'][i], + 'index': batch_data['index'][i] + } + images.append(img_data) + + images.sort(key=lambda x: x['index']['t']) + time_to_image = {img['index']['t']: img for img in images} + + # Create warp_path mapping and align images + # Note: query_idx is now actual t value, not relative index + ref_to_query = {ref_idx: query_t for ref_idx, query_t in warp_path} + aligned_images = [None] * alignment_length + + for ref_idx in range(alignment_length): + if ref_idx in ref_to_query: + query_time = ref_to_query[ref_idx] # query_time is already actual t value + if query_time in time_to_image: + aligned_images[ref_idx] = time_to_image[query_time] + else: + # Find closest available time + available_times = list(time_to_image.keys()) + if available_times: + closest_time = min(available_times, key=lambda x: abs(x - query_time)) + aligned_images[ref_idx] = time_to_image[closest_time] + + # Fill None values with nearest neighbor + for i in range(alignment_length): + if aligned_images[i] is None: + for offset in range(1, alignment_length): + for direction in [-1, 1]: + neighbor_idx = i + direction * offset + if 0 <= neighbor_idx < alignment_length and aligned_images[neighbor_idx] is not None: + aligned_images[i] = aligned_images[neighbor_idx] + break + if aligned_images[i] is not None: + break + + aligned_sequences[idx] = { + 'aligned_images': aligned_images, + 'metadata': { + 'fov_name': fov_name, + 'track_ids': track_ids, + 'distance': row['distance'], + 'alignment_length': alignment_length + } + } + + return aligned_sequences + +# Get aligned sequences using consolidated function +aligned_sequences = get_aligned_image_sequences(dataset, top_matches) + +logger.info(f"Retrieved {len(aligned_sequences)} aligned sequences") +for idx, seq in aligned_sequences.items(): + meta = seq['metadata'] + index=seq['aligned_images'][0]['index'] + logger.info(f"Track id {index['track_id']}: FOV {meta['fov_name']} aligned images, distance={meta['distance']:.3f}") + +# %% +# Load aligned sequences into napari +if NAPARI and len(aligned_sequences) > 0: + import numpy as np + + for idx, seq_data in aligned_sequences.items(): + aligned_images = seq_data['aligned_images'] + meta = seq_data['metadata'] + index=seq_data['aligned_images'][0]['index'] + + if len(aligned_images) == 0: + continue + + # Stack images into time series (T, C, Z, Y, X) + image_stack = [] + for img_sample in aligned_images: + img_tensor = img_sample['anchor'] # Shape should be (Z, C, Y, X) + img_np = img_tensor.cpu().numpy() + image_stack.append(img_np) + + if len(image_stack) > 0: + # Stack into (T, Z, C, Y, X) or (T, C, Z, Y, X) + time_series = np.stack(image_stack, axis=0) + + # Add to napari viewer + layer_name = f"track_id_{index['track_id']}_FOV_{meta['fov_name']}_dist_{meta['distance']:.3f}" + viewer.add_image( + time_series, + name=layer_name, + contrast_limits=(time_series.min(), time_series.max()), + ) + logger.info(f"Added {layer_name} with shape {time_series.shape}") +# Enhanced DataFrame was already created above with PCA plotting - skip duplicate +logger.info(f"Cell division aligned timepoints: {enhanced_df['dtw_cell_division_aligned'].sum()}/{len(enhanced_df)} ({100*enhanced_df['dtw_cell_division_aligned'].mean():.1f}%)") +logger.info(f"Columns: {list(enhanced_df.columns)}") + +# Show sample of the enhanced DataFrame +print("\nSample of enhanced DataFrame:") +sample_df = enhanced_df[enhanced_df['lineage_id'] != -1].head(10) +display_cols = ['lineage_id', 'track_id', 't', 'dtw_cell_division_aligned', 'dtw_cell_division_consensus_mapping', 'PC1'] +print(sample_df[display_cols].to_string()) + +#%% + +# Clean function that works directly with enhanced DataFrame +def plot_concatenated_from_dataframe(df, alignment_name="cell_division", + feature_columns=['PC1', 'PC2', 'PC3'], + max_lineages=5, y_offset_step=2.0, + aligned_scale=1.0, unaligned_scale=1.0): + """ + Plot concatenated [DTW-aligned portion] + [unaligned portion] sequences + using ONLY the enhanced DataFrame and alignment information stored in it. + + This function reconstructs the aligned portions using the consensus mapping + information already stored in the DataFrame. + + Parameters + ---------- + df : pd.DataFrame + Enhanced DataFrame with alignment information + alignment_name : str + Name of alignment to plot (e.g., "cell_division") + feature_columns : list + Feature columns to plot + max_lineages : int + Maximum number of lineages to display + y_offset_step : float + Vertical separation between lineages + aligned_scale : float + Scale factor for DTW-aligned portions (line width & marker size) + unaligned_scale : float + Scale factor for unaligned portions (line width & marker size) + """ + import matplotlib.pyplot as plt + + # Calculate line widths and marker sizes based on separate scale factors + aligned_linewidth = 5 * aligned_scale + unaligned_linewidth = 2 * unaligned_scale + aligned_markersize = 8 * aligned_scale + unaligned_markersize = 4 * unaligned_scale + + # Dynamic column names based on alignment_name + aligned_col = f'dtw_{alignment_name}_aligned' + mapping_col = f'dtw_{alignment_name}_consensus_mapping' + distance_col = f'dtw_{alignment_name}_distance' + + # Check if alignment columns exist + if aligned_col not in df.columns: + logger.error(f"Alignment '{alignment_name}' not found in DataFrame") + return + + # Get consensus and lineages + consensus_df = df[df['lineage_id'] == -1].sort_values('t').copy() + lineages = df[df['lineage_id'] != -1]['lineage_id'].unique()[:max_lineages] + + if consensus_df.empty: + logger.error("No consensus found in DataFrame") + return + + consensus_length = len(consensus_df) + + # Create concatenated sequences for each lineage + concatenated_lineages = {} + + for lineage_id in lineages: + lineage_df = df[df['lineage_id'] == lineage_id].copy().sort_values('t') + if lineage_df.empty: + continue + + # Split into aligned and unaligned portions + aligned_rows = lineage_df[lineage_df[aligned_col]].copy() + unaligned_rows = lineage_df[~lineage_df[aligned_col]].copy() + + # Create consensus-length aligned portion using mapping information + aligned_portion = {} # consensus_idx -> feature_values + + for _, row in aligned_rows.iterrows(): + consensus_idx = row[mapping_col] + if not pd.isna(consensus_idx): + consensus_idx = int(consensus_idx) + if 0 <= consensus_idx < consensus_length: + aligned_portion[consensus_idx] = {col: row[col] for col in feature_columns} + + # Fill gaps in aligned portion (interpolate missing consensus indices) + if aligned_portion: + filled_aligned = {} + for i in range(consensus_length): + if i in aligned_portion: + filled_aligned[i] = aligned_portion[i] + else: + # Find nearest available index + available_indices = list(aligned_portion.keys()) + if available_indices: + closest_idx = min(available_indices, key=lambda x: abs(x - i)) + filled_aligned[i] = aligned_portion[closest_idx] + else: + # Use consensus values if no aligned portion available + consensus_row = consensus_df.iloc[i] + filled_aligned[i] = {col: consensus_row[col] for col in feature_columns} + + # Convert aligned portion to arrays + aligned_arrays = {} + for col in feature_columns: + aligned_arrays[col] = np.array([filled_aligned[i][col] for i in range(consensus_length)]) + else: + # No aligned portion, use consensus as fallback + aligned_arrays = {} + for col in feature_columns: + aligned_arrays[col] = consensus_df[col].values.copy() + + # Get unaligned portion (sorted by original time) + unaligned_arrays = {} + if not unaligned_rows.empty: + unaligned_rows = unaligned_rows.sort_values('t') + for col in feature_columns: + unaligned_arrays[col] = unaligned_rows[col].values + else: + for col in feature_columns: + unaligned_arrays[col] = np.array([]) + + # Concatenate aligned + unaligned portions + concatenated_arrays = {} + for col in feature_columns: + if len(unaligned_arrays[col]) > 0: + concatenated_arrays[col] = np.concatenate([aligned_arrays[col], unaligned_arrays[col]]) + else: + concatenated_arrays[col] = aligned_arrays[col] + + # Store concatenated data + concatenated_lineages[lineage_id] = { + 'concatenated': concatenated_arrays, + 'aligned_length': len(aligned_arrays[feature_columns[0]]), + 'unaligned_length': len(unaligned_arrays[feature_columns[0]]), + 'dtw_distance': lineage_df[distance_col].iloc[0] if not pd.isna(lineage_df[distance_col].iloc[0]) else np.nan + } + + # Plotting + n_features = len(feature_columns) + fig, axes = plt.subplots(n_features, 1, figsize=(15, 4*n_features)) + if n_features == 1: + axes = [axes] + + # Generate colors using a colormap that works for all scenarios + cmap = plt.cm.get_cmap('tab10' if len(concatenated_lineages) <= 10 else 'tab20' if len(concatenated_lineages) <= 20 else 'hsv') + colors = [cmap(i / max(len(concatenated_lineages), 1)) for i in range(len(concatenated_lineages))] + + for feat_idx, feat_col in enumerate(feature_columns): + ax = axes[feat_idx] + + # Plot consensus (no offset) + consensus_values = consensus_df[feat_col].values + consensus_time = np.arange(len(consensus_values)) + ax.plot(consensus_time, consensus_values, 'o-', + color='black', linewidth=4, markersize=8, + label=f'Consensus ({alignment_name})', alpha=0.9, zorder=5) + + # Add consensus annotations if available + if alignment_name == "cell_division" and 'consensus_annotations' in globals(): + for t, annotation in enumerate(consensus_annotations): + if annotation == 'mitosis': + ax.axvline(t, color='orange', alpha=0.7, + linestyle='--', linewidth=2, zorder=1) + + # Plot each concatenated lineage + for lineage_idx, (lineage_id, data) in enumerate(concatenated_lineages.items()): + # Remove the color limit - now we have enough colors + + y_offset = -(lineage_idx + 1) * y_offset_step + color = colors[lineage_idx] + + # Get concatenated sequence values + concat_values = data['concatenated'][feat_col] + y_offset + time_axis = np.arange(len(concat_values)) + + # Plot full concatenated sequence + ax.plot(time_axis, concat_values, '.-', + color=color, linewidth=unaligned_linewidth, markersize=unaligned_markersize, alpha=0.8, + label=f'Lineage {lineage_id} (d={data["dtw_distance"]:.3f})') + + # Highlight aligned portion with thicker line + aligned_length = data['aligned_length'] + if aligned_length > 0: + aligned_time = time_axis[:aligned_length] + aligned_values = concat_values[:aligned_length] + + ax.plot(aligned_time, aligned_values, 's-', + color=color, linewidth=aligned_linewidth, markersize=aligned_markersize, + alpha=0.9, zorder=4) + + # Mark boundary between aligned and unaligned + if aligned_length > 0 and aligned_length < len(concat_values): + ax.axvline(aligned_length, color=color, alpha=0.5, + linestyle=':', linewidth=1) + + # Formatting + ax.set_xlabel('Concatenated Time: [DTW Aligned] + [Unaligned Continuation]') + ax.set_ylabel(f'{feat_col} (vertically separated)') + ax.set_title(f'{feat_col}: Concatenated {alignment_name.replace("_", " ").title()} Trajectories') + ax.grid(True, alpha=0.3) + + if feat_idx == 0: + ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + + plt.suptitle(f'DataFrame-Based Concatenated Alignment: {alignment_name.replace("_", " ").title()}\n' + f'Thick lines = DTW-aligned portions, Dotted lines = segment boundaries', + fontsize=14) + plt.tight_layout() + plt.show() + + # Print summary statistics + logger.info(f"\nConcatenated alignment summary for '{alignment_name}':") + logger.info(f"Processed {len(concatenated_lineages)} lineages") + for lineage_id, data in concatenated_lineages.items(): + logger.info(f" Lineage {lineage_id}: A={data['aligned_length']} + U={data['unaligned_length']} = {data['aligned_length'] + data['unaligned_length']}, d={data['dtw_distance']:.3f}") + +# Plot using the clean DataFrame-only function +plot_concatenated_from_dataframe(enhanced_df, alignment_name="cell_division", + feature_columns=['PC1','PC2','PC3'], max_lineages=15, + aligned_scale=0.5, unaligned_scale=0.7) + +# %% + + +def get_concatenated_image_sequences_from_dataframe(dataset, df, alignment_name="cell_division", max_lineages=5): + """ + Create concatenated [DTW-aligned portion] + [unaligned portion] image sequences + using the enhanced DataFrame alignment information, similar to plot_concatenated_from_dataframe(). + + Parameters + ---------- + dataset : TripletDataset + Dataset containing the images + df : pd.DataFrame + Enhanced DataFrame with alignment information + alignment_name : str + Name of alignment to use (e.g., "cell_division") + max_lineages : int + Maximum number of lineages to process + + Returns + ------- + dict + Dictionary mapping lineage_id to concatenated image sequences + Each entry contains: + - 'concatenated_images': List of concatenated image tensors + - 'aligned_length': Number of DTW-aligned images + - 'unaligned_length': Number of unaligned continuation images + - 'metadata': Lineage metadata + """ + + # Dynamic column names based on alignment_name + aligned_col = f'dtw_{alignment_name}_aligned' + mapping_col = f'dtw_{alignment_name}_consensus_mapping' + distance_col = f'dtw_{alignment_name}_distance' + + # Check if alignment columns exist + if aligned_col not in df.columns: + logger.error(f"Alignment '{alignment_name}' not found in DataFrame") + return {} + + # Get consensus and lineages + consensus_df = df[df['lineage_id'] == -1].sort_values('t').copy() + lineages = df[df['lineage_id'] != -1]['lineage_id'].unique()[:max_lineages] + + if consensus_df.empty: + logger.error("No consensus found in DataFrame") + return {} + + consensus_length = len(consensus_df) + concatenated_sequences = {} + + for lineage_id in lineages: + lineage_df = df[df['lineage_id'] == lineage_id].copy().sort_values('t') + if lineage_df.empty: + continue + + # Get FOV name and track IDs for this lineage + fov_name = lineage_df['fov_name'].iloc[0] + track_ids = lineage_df['track_id'].unique() + + # Find matching dataset indices + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if (anchor_row['fov_name'] == fov_name and anchor_row['track_id'] in track_ids): + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning(f"No matching indices found for lineage {lineage_id}, FOV {fov_name}, tracks {track_ids}") + continue + + # Get images and create time mapping + batch_data = dataset.__getitems__(matching_indices) + images = [] + for i in range(len(matching_indices)): + img_data = { + 'anchor': batch_data['anchor'][i], + 'index': batch_data['index'][i] + } + images.append(img_data) + + images.sort(key=lambda x: x['index']['t']) + time_to_image = {img['index']['t']: img for img in images} + + # Split DataFrame into aligned and unaligned portions + aligned_rows = lineage_df[lineage_df[aligned_col]].copy() + unaligned_rows = lineage_df[~lineage_df[aligned_col]].copy() + + # Create consensus-length aligned portion using mapping information + aligned_images = [None] * consensus_length + + for _, row in aligned_rows.iterrows(): + consensus_idx = row[mapping_col] + timepoint = row['t'] + + if not pd.isna(consensus_idx) and timepoint in time_to_image: + consensus_idx = int(consensus_idx) + if 0 <= consensus_idx < consensus_length: + aligned_images[consensus_idx] = time_to_image[timepoint] + + # Fill gaps in aligned portion with nearest neighbor + for i in range(consensus_length): + if aligned_images[i] is None: + # Find nearest available aligned image + available_indices = [j for j, img in enumerate(aligned_images) if img is not None] + if available_indices: + closest_idx = min(available_indices, key=lambda x: abs(x - i)) + aligned_images[i] = aligned_images[closest_idx] + else: + # Use first available image from time_to_image as fallback + if time_to_image: + aligned_images[i] = next(iter(time_to_image.values())) + + # Get unaligned continuation images (sorted by original time) + unaligned_images = [] + if not unaligned_rows.empty: + unaligned_rows = unaligned_rows.sort_values('t') + for _, row in unaligned_rows.iterrows(): + timepoint = row['t'] + if timepoint in time_to_image: + unaligned_images.append(time_to_image[timepoint]) + + # Concatenate aligned + unaligned portions + concatenated_images = aligned_images + unaligned_images + + # Store results + concatenated_sequences[lineage_id] = { + 'concatenated_images': concatenated_images, + 'aligned_length': len(aligned_images), + 'unaligned_length': len(unaligned_images), + 'metadata': { + 'fov_name': fov_name, + 'track_ids': list(track_ids), + 'dtw_distance': lineage_df[distance_col].iloc[0] if not pd.isna(lineage_df[distance_col].iloc[0]) else np.nan, + 'lineage_id': lineage_id + } + } + + logger.info(f"Created concatenated sequences for {len(concatenated_sequences)} lineages") + for lineage_id, data in concatenated_sequences.items(): + logger.info(f" Lineage {lineage_id}: A={data['aligned_length']} + U={data['unaligned_length']} = {len(data['concatenated_images'])}, d={data['metadata']['dtw_distance']:.3f}") + + return concatenated_sequences + + +#%% +# Create concatenated image sequences + +# Create concatenated image sequences using the DataFrame alignment information +if dataset is not None: + concatenated_image_sequences = get_concatenated_image_sequences_from_dataframe( + dataset, enhanced_df, alignment_name="cell_division", max_lineages=30 + ) +else: + print("Skipping image sequence creation - no valid dataset available") + concatenated_image_sequences = {} + +# Load concatenated sequences into napari +if NAPARI and dataset is not None and len(concatenated_image_sequences) > 0: + import numpy as np + + for lineage_id, seq_data in concatenated_image_sequences.items(): + concatenated_images = seq_data['concatenated_images'] + meta = seq_data['metadata'] + aligned_length = seq_data['aligned_length'] + unaligned_length = seq_data['unaligned_length'] + + if len(concatenated_images) == 0: + continue + + # Stack images into time series (T, C, Z, Y, X) + image_stack = [] + for img_sample in concatenated_images: + if img_sample is not None: + img_tensor = img_sample['anchor'] # Shape should be (C, Z, Y, X) + img_np = img_tensor.cpu().numpy() + image_stack.append(img_np) + + if len(image_stack) > 0: + # Stack into (T, C, Z, Y, X) + time_series = np.stack(image_stack, axis=0) + n_channels = time_series.shape[1] + + logger.info(f"Processing lineage {lineage_id} with {n_channels} channels, shape {time_series.shape}") + + # Set up colormap based on number of channels + if n_channels == 2: + colormap = ['green', 'magenta'] + elif n_channels == 3: + colormap = ['gray', 'green', 'magenta'] + else: + colormap = ['gray'] * n_channels # Default fallback + + # Add each channel as a separate layer in napari + for channel_idx in range(n_channels): + # Extract single channel: (T, Z, Y, X) + channel_data = time_series[:, channel_idx, :, :, :] + + # Get channel name if available + channel_name = processing_channels[channel_idx] if channel_idx < len(processing_channels) else f"ch{channel_idx}" + + layer_name = f"track_id_{meta['track_ids'][0]}_FOV_{meta['fov_name']}_dist_{meta['dtw_distance']:.3f}_{channel_name}" + + viewer.add_image( + channel_data, + name=layer_name, + contrast_limits=(channel_data.min(), channel_data.max()), + colormap=colormap[channel_idx], + blending='additive' + ) + logger.info(f"Added {channel_name} channel for lineage {lineage_id} with shape {channel_data.shape}") +# %% diff --git a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py index 17c15d1d..dfe0de32 100644 --- a/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py +++ b/applications/pseudotime_analysis/evaluation/compare_dtw_embeddings.py @@ -1,3 +1,10 @@ +#!/usr/bin/env python3 +""" +Refactored DTW embedding comparison using the new viscy.representation.pseudotime API. + +This demonstrates how to use the integrated DTW functionality without local imports. +""" + # %% import ast import logging @@ -7,26 +14,25 @@ import numpy as np import pandas as pd import seaborn as sns -from plotting_utils import ( - find_pattern_matches, - identify_lineages, + +from viscy.representation.evaluation.pseudotime_plotting import ( + align_image_stacks, plot_pc_trajectories, ) -from tqdm import tqdm -from viscy.data.triplet import TripletDataModule -from viscy.representation.embedding_writer import read_embedding_dataset +# Use the new integrated DTW API +from viscy.representation.pseudotime import DTWAnalyzer, identify_lineages logger = logging.getLogger("viscy") logger.setLevel(logging.INFO) console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) -formatter = logging.Formatter("%(message)s") # Simplified format +formatter = logging.Formatter("%(message)s") console_handler.setFormatter(formatter) logger.addHandler(console_handler) - -NAPARI = True +# Configuration +NAPARI = False if NAPARI: import os @@ -34,9 +40,8 @@ os.environ["DISPLAY"] = ":1" viewer = napari.Viewer() -# %% -# Organelle and Phate aligned to infection +# File paths input_data_path = Path( "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_DENV.zarr" ) @@ -44,24 +49,15 @@ "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/4-track-gt/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr" ) infection_annotations_path = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/combined_annotations_n_tracks_infection.csv" + "/hpc/projects/intracellular_dashboard/organelle_dynamics/rerun/2024_11_07_A549_SEC61_DENV/4-phenotyping/0-annotation/track_infection_annotation.csv" ) pretrain_features_root = Path( "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/prediction_pretrained_models" ) -# Phase n organelle -# dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" -# pahe n sensor +# Embedding paths dynaclr_features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions_infection/2chan_192patch_100ckpt_timeAware_ntxent_GT.zarr" - -output_root = Path( - "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/4-phenotyping/figure/SEC61B/model_comparison" -) - - -# Load embeddings imagenet_features_path = ( pretrain_features_root / "ImageNet/20241107_sensor_n_phase_imagenet.zarr" ) @@ -69,694 +65,280 @@ pretrain_features_root / "OpenPhenom/20241107_sensor_n_phase_openphenom.zarr" ) -dynaclr_embeddings = read_embedding_dataset(dynaclr_features_path) -imagenet_embeddings = read_embedding_dataset(imagenet_features_path) -openphenom_embeddings = read_embedding_dataset(openphenom_features_path) - -# Load infection annotations -infection_annotations_df = pd.read_csv(infection_annotations_path) -infection_annotations_df["fov_name"] = "/C/2/000001" - -process_embeddings = [ - (dynaclr_embeddings, "dynaclr"), - (imagenet_embeddings, "imagenet"), - (openphenom_embeddings, "openphenom"), -] - - -output_root.mkdir(parents=True, exist_ok=True) +# %% Check that the directories exist +print(f"Input data path exists: {input_data_path.exists()}") +print(f"Tracks path exists: {tracks_path.exists()}") +print(f"Infection annotations path exists: {infection_annotations_path.exists()}") +print(f"Pretrain features root exists: {pretrain_features_root.exists()}") +print(f"Dynaclr features path exists: {dynaclr_features_path.exists()}") +print(f"Imagenet features path exists: {imagenet_features_path.exists()}") +print(f"Openphenom features path exists: {openphenom_features_path.exists()}") # %% -feature_df = dynaclr_embeddings["sample"].to_dataframe().reset_index(drop=True) - -# Logic to find lineages -lineages = identify_lineages(feature_df) -logger.info(f"Found {len(lineages)} distinct lineages") -filtered_lineages = [] -min_timepoints = 20 -for fov_id, track_ids in lineages: - # Get all rows for this lineage - lineage_rows = feature_df[ - (feature_df["fov_name"] == fov_id) & (feature_df["track_id"].isin(track_ids)) - ] - - # Count the total number of timepoints - total_timepoints = len(lineage_rows) - # Only keep lineages with at least min_timepoints - if total_timepoints >= min_timepoints: - filtered_lineages.append((fov_id, track_ids)) -logger.info( - f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints" +output_root = Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/output" ) +output_root.mkdir(parents=True, exist_ok=True) -# %% -# Aligning condition embeddings to infection -# OPTION 1: Use the infection annotations to find the reference lineage -reference_lineage_fov = "/C/2/001000" -reference_lineage_track_id = [129] -reference_timepoints = [8, 70] # sensor rellocalization and partial remodelling - -# Option 2: from the filtered lineages find one from FOV C/2/000001 -reference_lineage_fov = "/C/2/000001" -for fov_id, track_ids in filtered_lineages: - if reference_lineage_fov == fov_id: - break -reference_lineage_track_id = track_ids -reference_timepoints = [8, 70] # sensor rellocalization and partial remodelling # %% -# Dictionary to store alignment results for comparison -alignment_results = {} +def main(): + """Main analysis pipeline using the new DTW API.""" + + # Initialize DTW analyzers for each embedding method + analyzers = { + "dynaclr": DTWAnalyzer(dynaclr_features_path), + "imagenet": DTWAnalyzer(imagenet_features_path), + "openphenom": DTWAnalyzer(openphenom_features_path), + } -for embeddings, name in process_embeddings: - # Get the reference pattern from the current embedding space - reference_pattern = None - reference_lineage = [] - for fov_id, track_ids in filtered_lineages: - if fov_id == reference_lineage_fov and all( - track_id in track_ids for track_id in reference_lineage_track_id - ): - logger.info( - f"Found reference pattern for {fov_id} {reference_lineage_track_id} using {name} embeddings" - ) - reference_pattern = embeddings.sel( - sample=(fov_id, reference_lineage_track_id) - ).features.values - reference_lineage.append(reference_pattern) - break - if reference_pattern is None: - logger.info(f"Reference pattern not found for {name} embeddings. Skipping.") - continue - reference_pattern = np.concatenate(reference_lineage) - reference_pattern = reference_pattern[ - reference_timepoints[0] : reference_timepoints[1] - ] + # Load infection annotations + infection_annotations_df = pd.read_csv(infection_annotations_path) + infection_annotations_df["fov_name"] = "/C/2/000001" - # Find all matches to the reference pattern - metric = "cosine" - all_match_positions = find_pattern_matches( - reference_pattern, - filtered_lineages, - embeddings, - window_step_fraction=0.1, - num_candidates=4, - method="bernd_clifford", - save_path=output_root / f"{name}_matching_lineages_{metric}.csv", - metric=metric, + # Identify lineages from the first dataset + feature_df = ( + analyzers["dynaclr"].embeddings["sample"].to_dataframe().reset_index(drop=True) ) + all_lineages = identify_lineages(feature_df) + logger.info(f"Found {len(all_lineages)} distinct lineages") + + # Filter lineages by minimum timepoints + min_timepoints = 20 + filtered_lineages = [] + for fov_id, track_ids in all_lineages: + lineage_rows = feature_df[ + (feature_df["fov_name"] == fov_id) + & (feature_df["track_id"].isin(track_ids)) + ] + total_timepoints = len(lineage_rows) + if total_timepoints >= min_timepoints: + filtered_lineages.append((fov_id, track_ids)) - # Store results for later comparison - alignment_results[name] = all_match_positions - -# Visualize warping paths in PC space instead of raw embedding dimensions -for name, match_positions in alignment_results.items(): - if match_positions is not None and not match_positions.empty: - # Call the new function from plotting_utils - plot_pc_trajectories( - reference_lineage_fov=reference_lineage_fov, - reference_lineage_track_id=reference_lineage_track_id, - reference_timepoints=reference_timepoints, - match_positions=match_positions, - embeddings_dataset=next( - emb for emb, emb_name in process_embeddings if emb_name == name - ), - filtered_lineages=filtered_lineages, - name=name, - save_path=output_root / f"{name}_pc_lineage_alignment.png", - ) - + logger.info( + f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints" + ) -# %% -# Compare DTW performance between embedding methods - -# Create a DataFrame to collect the alignment statistics for comparison -match_data = [] -for name, match_positions in alignment_results.items(): - if match_positions is not None and not match_positions.empty: - for i, row in match_positions.head(10).iterrows(): # Take top 10 matches - warping_path = ( - ast.literal_eval(row["warp_path"]) - if isinstance(row["warp_path"], str) - else row["warp_path"] - ) - match_data.append( - { - "model": name, - "match_position": row["start_timepoint"], - "dtw_distance": row["distance"], - "path_skewness": row["skewness"], - "path_length": len(warping_path), - } - ) + # Reference pattern configuration + reference_lineage_fov = "/C/2/000001" + reference_lineage_track_id = [129] + reference_timepoints = (8, 70) # sensor relocalization and partial remodelling -comparison_df = pd.DataFrame(match_data) - -# Create visualizations to compare alignment quality -plt.figure(figsize=(12, 10)) - -# 1. Compare DTW distances -plt.subplot(2, 2, 1) -sns.boxplot(x="model", y="dtw_distance", data=comparison_df) -plt.title("DTW Distance by Model") -plt.ylabel("DTW Distance (lower is better)") - -# 2. Compare path skewness -plt.subplot(2, 2, 2) -sns.boxplot(x="model", y="path_skewness", data=comparison_df) -plt.title("Path Skewness by Model") -plt.ylabel("Skewness (lower is better)") - -# 3. Compare path lengths -plt.subplot(2, 2, 3) -sns.boxplot(x="model", y="path_length", data=comparison_df) -plt.title("Warping Path Length by Model") -plt.ylabel("Path Length") - -# 4. Scatterplot of distance vs skewness -plt.subplot(2, 2, 4) -scatter = sns.scatterplot( - x="dtw_distance", y="path_skewness", hue="model", data=comparison_df -) -plt.title("DTW Distance vs Path Skewness") -plt.xlabel("DTW Distance") -plt.ylabel("Path Skewness") -plt.legend(title="Model") + # Find a valid reference lineage from filtered lineages + for fov_id, track_ids in filtered_lineages: + if reference_lineage_fov == fov_id: + reference_lineage_track_id = track_ids + break -plt.tight_layout() -plt.savefig(output_root / "dtw_alignment_comparison.png", dpi=300) -plt.close() + # Perform DTW analysis for each embedding method + alignment_results = {} -# %% -# Analyze warping path step patterns for better understanding of alignment quality - -# Step pattern analysis -step_pattern_counts = { - name: {"diagonal": 0, "horizontal": 0, "vertical": 0, "total": 0} - for name in alignment_results.keys() -} - -for name, match_positions in alignment_results.items(): - if match_positions is not None and not match_positions.empty: - # Get the top match - top_match = match_positions.iloc[0] - path = ( - ast.literal_eval(top_match["warp_path"]) - if isinstance(top_match["warp_path"], str) - else top_match["warp_path"] - ) + for name, analyzer in analyzers.items(): + logger.info(f"Processing {name} embeddings") - # Count step types - for i in range(1, len(path)): - prev_i, prev_j = path[i - 1] - curr_i, curr_j = path[i] - - step_i = curr_i - prev_i - step_j = curr_j - prev_j - - if step_i == 1 and step_j == 1: - step_pattern_counts[name]["diagonal"] += 1 - elif step_i == 1 and step_j == 0: - step_pattern_counts[name]["vertical"] += 1 - elif step_i == 0 and step_j == 1: - step_pattern_counts[name]["horizontal"] += 1 - - step_pattern_counts[name]["total"] += 1 - -# Convert to percentages -for name in step_pattern_counts: - total = step_pattern_counts[name]["total"] - if total > 0: - for key in ["diagonal", "horizontal", "vertical"]: - step_pattern_counts[name][key] = ( - step_pattern_counts[name][key] / total - ) * 100 - -# Visualize step pattern distributions -step_df = pd.DataFrame( - { - "model": [name for name in step_pattern_counts.keys() for _ in range(3)], - "step_type": ["diagonal", "horizontal", "vertical"] * len(step_pattern_counts), - "percentage": [ - step_pattern_counts[name]["diagonal"] for name in step_pattern_counts.keys() - ] - + [ - step_pattern_counts[name]["horizontal"] - for name in step_pattern_counts.keys() - ] - + [ - step_pattern_counts[name]["vertical"] for name in step_pattern_counts.keys() - ], - } -) + try: + # Extract reference pattern + reference_pattern = analyzer.get_reference_pattern( + fov_name=reference_lineage_fov, + track_id=reference_lineage_track_id, + timepoints=reference_timepoints, + ) -plt.figure(figsize=(10, 6)) -sns.barplot(x="model", y="percentage", hue="step_type", data=step_df) -plt.title("Step Pattern Distribution in Warping Paths") -plt.ylabel("Percentage (%)") -plt.savefig(output_root / "step_pattern_distribution.png", dpi=300) -plt.close() + logger.info( + f"Found reference pattern for {name} with shape {reference_pattern.shape}" + ) -# %% -# Find all matches to the reference pattern -MODEL = "openphenom" -alignment_df_path = output_root / f"{MODEL}_matching_lineages_cosine.csv" -alignment_df = pd.read_csv(alignment_df_path) - -# Get the top N aligned cells - -source_channels = [ - "Phase3D", - "raw GFP EX488 EM525-45", - "raw mCherry EX561 EM600-37", -] -yx_patch_size = (192, 192) -z_range = (10, 30) -view_ref_sector_only = (True,) - -all_lineage_images = [] -all_aligned_stacks = [] -all_unaligned_stacks = [] - -# Get aligned and unaligned stacks -top_aligned_cells = alignment_df.head(5) -napari_viewer = viewer if NAPARI else None -# Plot the aligned and unaligned stacks -for idx, row in tqdm( - top_aligned_cells.iterrows(), - total=len(top_aligned_cells), - desc="Aligning images", -): - fov_name = row["fov_name"] - track_ids = ast.literal_eval(row["track_ids"]) - warp_path = ast.literal_eval(row["warp_path"]) - start_time = int(row["start_timepoint"]) - - print(f"Aligning images for {fov_name} with track ids: {track_ids}") - data_module = TripletDataModule( - data_path=input_data_path, - tracks_path=tracks_path, - source_channel=source_channels, - z_range=z_range, - initial_yx_patch_size=yx_patch_size, - final_yx_patch_size=yx_patch_size, - batch_size=1, - num_workers=12, - predict_cells=True, - include_fov_names=[fov_name] * len(track_ids), - include_track_ids=track_ids, - ) - data_module.setup("predict") - - # Get the images for the lineage - lineage_images = [] - for batch in data_module.predict_dataloader(): - image = batch["anchor"].numpy()[0] - lineage_images.append(image) - - lineage_images = np.array(lineage_images) - all_lineage_images.append(lineage_images) - print(f"Lineage images shape: {np.array(lineage_images).shape}") - - # Create an aligned stack based on the warping path - if view_ref_sector_only: - aligned_stack = np.zeros( - (len(reference_pattern),) + lineage_images.shape[-4:], - dtype=lineage_images.dtype, - ) - unaligned_stack = np.zeros( - (len(reference_pattern),) + lineage_images.shape[-4:], - dtype=lineage_images.dtype, - ) + # Find pattern matches + matches = analyzer.find_pattern_matches( + reference_pattern=reference_pattern, + filtered_lineages=filtered_lineages, + window_step_fraction=0.1, + num_candidates=4, + method="bernd_clifford", + metric="cosine", + save_path=output_root / f"{name}_matching_lineages_cosine.csv", + ) - # Map each reference timepoint to the corresponding lineage timepoint - for ref_idx in range(len(reference_pattern)): - # Find matches in warping path for this reference index - matches = [(i, q) for i, q in warp_path if i == ref_idx] - unaligned_stack[ref_idx] = lineage_images[ref_idx] - if matches: - # Get the corresponding lineage timepoint (first match if multiple) - print(f"Found match for ref idx: {ref_idx}") - match = matches[0] - query_idx = match[1] - lineage_idx = int(start_time + query_idx) - print( - f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}" + alignment_results[name] = matches + logger.info(f"Found {len(matches)} matches for {name}") + + # Generate PC trajectory visualization + if not matches.empty: + plot_pc_trajectories( + reference_lineage_fov=reference_lineage_fov, + reference_lineage_track_id=reference_lineage_track_id, + reference_timepoints=list(reference_timepoints), + match_positions=matches, + embeddings_dataset=analyzer.embeddings, + filtered_lineages=filtered_lineages, + name=name, + save_path=output_root / f"{name}_pc_lineage_alignment.png", ) - # Copy the image if it's within bounds - if 0 <= lineage_idx < len(lineage_images): - aligned_stack[ref_idx] = lineage_images[lineage_idx] - else: - # Find nearest valid timepoint if out of bounds - nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1) - aligned_stack[ref_idx] = lineage_images[nearest_idx] - else: - # If no direct match, find closest reference timepoint in warping path - print(f"No match found for ref idx: {ref_idx}") - all_ref_indices = [i for i, _ in warp_path] - if all_ref_indices: - closest_ref_idx = min( - all_ref_indices, key=lambda x: abs(x - ref_idx) - ) - closest_matches = [ - (i, q) for i, q in warp_path if i == closest_ref_idx - ] - - if closest_matches: - closest_query_idx = closest_matches[0][1] - lineage_idx = int(start_time + closest_query_idx) - - if 0 <= lineage_idx < len(lineage_images): - aligned_stack[ref_idx] = lineage_images[lineage_idx] - else: - # Bound to valid range - nearest_idx = min( - max(0, lineage_idx), len(lineage_images) - 1 - ) - aligned_stack[ref_idx] = lineage_images[nearest_idx] - - all_aligned_stacks.append(aligned_stack) - all_unaligned_stacks.append(unaligned_stack) - -all_aligned_stacks = np.array(all_aligned_stacks) -all_unaligned_stacks = np.array(all_unaligned_stacks) -# %% -if NAPARI: - for idx, row in tqdm( - top_aligned_cells.reset_index().iterrows(), - total=len(top_aligned_cells), - desc="Plotting aligned and unaligned stacks", - ): - fov_name = row["fov_name"] - # track_ids = ast.literal_eval(row["track_ids"]) - track_ids = row["track_ids"] - - aligned_stack = all_aligned_stacks[idx] - unaligned_stack = all_unaligned_stacks[idx] - - unaligned_gfp_mip = np.max(unaligned_stack[:, 1, :, :], axis=1) - aligned_gfp_mip = np.max(aligned_stack[:, 1, :, :], axis=1) - unaligned_mcherry_mip = np.max(unaligned_stack[:, 2, :, :], axis=1) - aligned_mcherry_mip = np.max(aligned_stack[:, 2, :, :], axis=1) - - z_slice = 15 - unaligned_phase = unaligned_stack[:, 0, z_slice, :] - aligned_phase = aligned_stack[:, 0, z_slice, :] - - # unaligned - viewer.add_image( - unaligned_gfp_mip, - name=f"unaligned_gfp_{fov_name}_{track_ids[0]}", - colormap="green", - contrast_limits=(106, 215), - ) - viewer.add_image( - unaligned_mcherry_mip, - name=f"unaligned_mcherry_{fov_name}_{track_ids[0]}", - colormap="magenta", - contrast_limits=(106, 190), - ) - viewer.add_image( - unaligned_phase, - name=f"unaligned_phase_{fov_name}_{track_ids[0]}", - colormap="gray", - contrast_limits=(-0.74, 0.4), - ) - # aligned - viewer.add_image( - aligned_gfp_mip, - name=f"aligned_gfp_{fov_name}_{track_ids[0]}", - colormap="green", - contrast_limits=(106, 215), - ) - viewer.add_image( - aligned_mcherry_mip, - name=f"aligned_mcherry_{fov_name}_{track_ids[0]}", - colormap="magenta", - contrast_limits=(106, 190), - ) - viewer.add_image( - aligned_phase, - name=f"aligned_phase_{fov_name}_{track_ids[0]}", - colormap="gray", - contrast_limits=(-0.74, 0.4), - ) - viewer.grid.enabled = True - viewer.grid.shape = (-1, 6) -# %% -# Evaluate model performance based on infection state warping accuracy -# Check unique infection status values -unique_infection_statuses = infection_annotations_df["infection_status"].unique() -logger.info(f"Unique infection status values: {unique_infection_statuses}") - -# If "infected" is not in the unique values, this could explain zero precision/recall -if "infected" not in unique_infection_statuses: - logger.warning('The label "infected" is not found in the infection_status column!') - logger.info(f"Using these values instead: {unique_infection_statuses}") - - # If we need to map values, we could do it here - if len(unique_infection_statuses) >= 2: - logger.info( - f'Will treat "{unique_infection_statuses[1]}" as "infected" for metrics calculation' - ) - infection_target_value = unique_infection_statuses[1] - else: - infection_target_value = unique_infection_statuses[0] -else: - infection_target_value = "infected" - -logger.info(f'Using "{infection_target_value}" as positive class for F1 calculation') -# Check if the reference track is in the annotations -logger.info( - f"Looking for infection annotations for reference lineage: {reference_lineage_fov}, tracks: {reference_lineage_track_id}" -) -print(f"Sample of infection_annotations_df: {infection_annotations_df.head()}") - -reference_infection_states = {} -for track_id in reference_lineage_track_id: - reference_annotations = infection_annotations_df[ - (infection_annotations_df["fov_name"] == reference_lineage_fov) - & (infection_annotations_df["track_id"] == track_id) - ] - - # Add annotations for this reference track - annotation_count = len(reference_annotations) - logger.info(f"Found {annotation_count} annotations for track {track_id}") - if annotation_count > 0: - print( - f"Sample annotations for track {track_id}: {reference_annotations.head()}" + except Exception as e: + logger.error(f"Failed to process {name}: {e}") + continue + + # Compare DTW performance between embedding methods + create_dtw_comparison_plots(alignment_results, output_root) + + # Demonstrate image alignment for the best model + if alignment_results: + best_model = min( + alignment_results.keys(), + key=lambda k: ( + alignment_results[k]["distance"].min() + if not alignment_results[k].empty + else float("inf") + ), ) - for _, row in reference_annotations.iterrows(): - reference_infection_states[row["t"]] = row["infection_status"] + logger.info(f"Best performing model: {best_model}") + demonstrate_image_alignment( + analyzers[best_model], + alignment_results[best_model], + reference_pattern, + output_root, + ) -if reference_infection_states: - logger.info( - f"Total reference timepoints with infection status: {len(reference_infection_states)}" - ) - reference_t_range = range(reference_timepoints[0], reference_timepoints[1]) - reference_gt_states = [ - reference_infection_states.get(t, "unknown") for t in reference_t_range - ] - logger.info(f"Reference track infection states: {reference_gt_states[:5]}...") - # Evaluate warping accuracy for each model - model_performance = [] +def create_dtw_comparison_plots(alignment_results, output_root): + """Create comparison plots for DTW performance across models.""" + # Collect alignment statistics + match_data = [] for name, match_positions in alignment_results.items(): if match_positions is not None and not match_positions.empty: - total_correct = 0 - total_predictions = 0 - true_positives = 0 - false_positives = 0 - false_negatives = 0 - - # Analyze top alignments for this model - alignment_details = [] for i, row in match_positions.head(10).iterrows(): - fov_name = row["fov_name"] - track_ids = row[ - "track_ids" - ] # This is already a list of track IDs for the lineage warp_path = ( ast.literal_eval(row["warp_path"]) if isinstance(row["warp_path"], str) else row["warp_path"] ) - start_time = int(row["start_timepoint"]) - - # Get annotations for all tracks in this lineage - track_infection_states = {} - for track_id in track_ids: - track_annotations = infection_annotations_df[ - (infection_annotations_df["fov_name"] == fov_name) - & (infection_annotations_df["track_id"] == track_id) - ] - - # Add annotations for this track to the combined dictionary - for _, annotation_row in track_annotations.iterrows(): - # Use t + track-specific offset if needed to handle timepoint overlaps between tracks - track_infection_states[annotation_row["t"]] = annotation_row[ - "infection_status" - ] - - # Only proceed if we found annotations for at least one track - if track_infection_states: - # For each reference timepoint, check if the warped timepoint maintains the infection state - track_correct = 0 - track_predictions = 0 - track_tp = 0 - track_fp = 0 - track_fn = 0 - - for ref_idx, query_idx in warp_path: - # Map to actual timepoints - ref_t = reference_timepoints[0] + ref_idx - query_t = start_time + query_idx - - # Get ground truth infection states - ref_state = reference_infection_states.get(ref_t, "unknown") - query_state = track_infection_states.get(query_t, "unknown") - - # Skip unknown states - if ref_state != "unknown" and query_state != "unknown": - track_predictions += 1 - - # Count correct alignments - if ref_state == query_state: - track_correct += 1 - - # Calculate F1 score components for "infected" state - if ( - ref_state == infection_target_value - and query_state == infection_target_value - ): - track_tp += 1 - elif ( - ref_state != infection_target_value - and query_state == infection_target_value - ): - track_fp += 1 - elif ( - ref_state == infection_target_value - and query_state != infection_target_value - ): - track_fn += 1 - - # Calculate track-specific metrics - if track_predictions > 0: - track_accuracy = track_correct / track_predictions - track_precision = ( - track_tp / (track_tp + track_fp) - if (track_tp + track_fp) > 0 - else 0 - ) - track_recall = ( - track_tp / (track_tp + track_fn) - if (track_tp + track_fn) > 0 - else 0 - ) - track_f1 = ( - 2 - * (track_precision * track_recall) - / (track_precision + track_recall) - if (track_precision + track_recall) > 0 - else 0 - ) - - alignment_details.append( - { - "fov_name": fov_name, - "track_ids": track_ids, - "accuracy": track_accuracy, - "precision": track_precision, - "recall": track_recall, - "f1_score": track_f1, - "correct": track_correct, - "total": track_predictions, - } - ) - - # Add to model totals - total_correct += track_correct - total_predictions += track_predictions - true_positives += track_tp - false_positives += track_fp - false_negatives += track_fn - - # Calculate metrics - accuracy = total_correct / total_predictions if total_predictions > 0 else 0 - precision = ( - true_positives / (true_positives + false_positives) - if (true_positives + false_positives) > 0 - else 0 - ) - recall = ( - true_positives / (true_positives + false_negatives) - if (true_positives + false_negatives) > 0 - else 0 - ) - f1 = ( - 2 * (precision * recall) / (precision + recall) - if (precision + recall) > 0 - else 0 - ) - - # Store alignment details for this model - if alignment_details: - alignment_details_df = pd.DataFrame(alignment_details) - print(f"\nDetailed alignment results for {name}:") - print(alignment_details_df) - alignment_details_df.to_csv( - output_root / f"{name}_alignment_details.csv", index=False + match_data.append( + { + "model": name, + "match_position": row["start_timepoint"], + "dtw_distance": row["distance"], + "path_skewness": row["skewness"], + "path_length": len(warp_path), + } ) - model_performance.append( - { - "model": name, - "accuracy": accuracy, - "precision": precision, - "recall": recall, - "f1_score": f1, - "total_predictions": total_predictions, - } - ) + if not match_data: + logger.warning("No match data available for comparison plots") + return - # Create performance DataFrame and visualize - performance_df = pd.DataFrame(model_performance) - print(performance_df) + comparison_df = pd.DataFrame(match_data) - # Plot performance metrics - plt.figure(figsize=(12, 8)) + # Create comparison visualizations + plt.figure(figsize=(12, 10)) - # Accuracy plot + # DTW distances comparison plt.subplot(2, 2, 1) - sns.barplot(x="model", y="accuracy", data=performance_df) - plt.title("Infection State Warping Accuracy") - plt.ylabel("Accuracy") + sns.boxplot(x="model", y="dtw_distance", data=comparison_df) + plt.title("DTW Distance by Model") + plt.ylabel("DTW Distance (lower is better)") - # Precision plot + # Path skewness comparison plt.subplot(2, 2, 2) - sns.barplot(x="model", y="precision", data=performance_df) - plt.title("Precision for Infected State") - plt.ylabel("Precision") + sns.boxplot(x="model", y="path_skewness", data=comparison_df) + plt.title("Path Skewness by Model") + plt.ylabel("Skewness (lower is better)") - # Recall plot + # Path lengths comparison plt.subplot(2, 2, 3) - sns.barplot(x="model", y="recall", data=performance_df) - plt.title("Recall for Infected State") - plt.ylabel("Recall") + sns.boxplot(x="model", y="path_length", data=comparison_df) + plt.title("Warping Path Length by Model") + plt.ylabel("Path Length") - # F1 score plot + # Distance vs skewness scatterplot plt.subplot(2, 2, 4) - sns.barplot(x="model", y="f1_score", data=performance_df) - plt.title("F1 Score for Infected State") - plt.ylabel("F1 Score") + sns.scatterplot( + x="dtw_distance", y="path_skewness", hue="model", data=comparison_df + ) + plt.title("DTW Distance vs Path Skewness") + plt.xlabel("DTW Distance") + plt.ylabel("Path Skewness") + plt.legend(title="Model") plt.tight_layout() - # plt.savefig(output_root / "infection_state_warping_performance.png", dpi=300) - # plt.close() -else: - logger.warning("Reference track annotations not found in infection_annotations_df") + plt.savefig(output_root / "dtw_alignment_comparison.png", dpi=300) + plt.close() + + logger.info("Saved DTW comparison plots") + + +def demonstrate_image_alignment(analyzer, matches, reference_pattern, output_root): + """Demonstrate image alignment using DTW results.""" + + if matches.empty: + logger.warning("No matches available for image alignment") + return + + # Configuration for image alignment + source_channels = [ + "Phase3D", + "raw GFP EX488 EM525-45", + "raw mCherry EX561 EM600-37", + ] + yx_patch_size = (192, 192) + z_range = (10, 30) + + # Get top aligned cells + top_aligned_cells = matches.head(5) + napari_viewer = viewer if NAPARI else None + + try: + # Align image stacks + all_lineage_images, all_aligned_stacks = align_image_stacks( + reference_pattern=reference_pattern, + top_aligned_cells=top_aligned_cells, + input_data_path=input_data_path, + tracks_path=tracks_path, + source_channels=source_channels, + yx_patch_size=yx_patch_size, + z_range=z_range, + view_ref_sector_only=True, + napari_viewer=napari_viewer, + ) + + logger.info(f"Aligned {len(all_aligned_stacks)} image stacks") + + # Display aligned stacks in napari if available + if NAPARI and napari_viewer: + for idx, stack in enumerate(all_aligned_stacks): + # Display different channels + gfp_mip = np.max(stack[:, 1, :, :], axis=1) + mcherry_mip = np.max(stack[:, 2, :, :], axis=1) + phase_slice = stack[:, 0, 15, :] # middle z-slice + + napari_viewer.add_image( + gfp_mip, + name=f"Aligned_GFP_{idx}", + colormap="green", + contrast_limits=(106, 215), + ) + napari_viewer.add_image( + mcherry_mip, + name=f"Aligned_mCherry_{idx}", + colormap="magenta", + contrast_limits=(106, 190), + ) + napari_viewer.add_image( + phase_slice, + name=f"Aligned_Phase_{idx}", + colormap="gray", + contrast_limits=(-0.74, 0.4), + ) + + napari_viewer.grid.enabled = True + napari_viewer.grid.shape = (-1, 3) + + except Exception as e: + logger.error(f"Failed to align images: {e}") + +if __name__ == "__main__": + main() # %% diff --git a/applications/pseudotime_analysis/get_tracking_stat.py b/applications/pseudotime_analysis/get_tracking_stat.py new file mode 100644 index 00000000..653d0218 --- /dev/null +++ b/applications/pseudotime_analysis/get_tracking_stat.py @@ -0,0 +1,154 @@ +# %% +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from anndata import read_zarr + +from viscy.representation.pseudotime import ( + CytoDtw, +) + +# %% +logger = logging.getLogger("viscy") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + + +features_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_06_26_A549_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/phase_160patch_104ckpt_ver3max.zarr" +# %% +# Load AnnData directly +adata = read_zarr(features_path) +print("Loaded AnnData with shape:", adata.shape) +print("Available columns:", adata.obs.columns.tolist()) + +# Instantiate the CytoDtw object with AnnData +cytodtw = CytoDtw(adata) +feature_df = cytodtw.adata.obs + +min_timepoints = 0 +filtered_lineages = cytodtw.get_lineages(min_timepoints) + +fov_stats = cytodtw.get_track_statistics(filtered_lineages, per_fov=True) +logger.info("\n=== Confluence Table Format ===") +logger.info( + "| FOV Name | Lineages | Total Tracks | Tracks/Lineage (mean ± std) | Total Timepoints/Lineage (mean ± std) | Timepoints/Track (mean ± std) |" +) +logger.info( + "|----------|----------|--------------|------------------------------|---------------------------------------|-------------------------------|" +) +for _, row in fov_stats.iterrows(): + logger.info( + f"| {row['fov_name']} | {row['n_lineages']} | {row['total_tracks']} | " + f"{row['mean_tracks_per_lineage']:.2f} ± {row['std_tracks_per_lineage']:.2f} | " + f"{row['mean_total_timepoints']:.2f} ± {row['std_total_timepoints']:.2f} | " + f"{row['mean_timepoints_per_track']:.2f} ± {row['std_timepoints_per_track']:.2f} |" + ) + +logger.info("\n=== Global Statistics (All FOVs) ===") +min_t = adata.obs["t"].min() +max_t = adata.obs["t"].max() +n_timepoints = max_t - min_t + 1 +global_lineages = fov_stats["n_lineages"].sum() +global_tracks = fov_stats["total_tracks"].sum() +logger.info(f"Total Timepoints: ({n_timepoints})") +logger.info(f"Total lineages: {global_lineages}") +logger.info(f"Total tracks: {global_tracks}") +logger.info( + f"Tracks per lineage (global): {fov_stats['mean_tracks_per_lineage'].mean():.2f} ± {fov_stats['mean_tracks_per_lineage'].std():.2f}" +) +logger.info( + f"Total timepoints per lineage (global): {fov_stats['mean_total_timepoints'].mean():.2f} ± {fov_stats['mean_total_timepoints'].std():.2f}" +) +logger.info( + f"Timepoints per track (global): {fov_stats['mean_timepoints_per_track'].mean():.2f} ± {fov_stats['mean_timepoints_per_track'].std():.2f}" +) + +track_stats = cytodtw.get_track_statistics(filtered_lineages, per_fov=False) + +# %% +fig, axes = plt.subplots(2, 2, figsize=(12, 10)) + +axes[0, 0].hist( + track_stats["total_timepoints"], + bins=30, + color="#1f77b4", + alpha=0.7, + edgecolor="black", +) +axes[0, 0].axvline( + track_stats["total_timepoints"].mean(), + color="#ff7f0e", + linestyle="--", + linewidth=2, + label=f'Mean: {track_stats["total_timepoints"].mean():.1f}', +) +axes[0, 0].set_xlabel("Total Timepoints per Lineage") +axes[0, 0].set_ylabel("Count") +axes[0, 0].set_title("Distribution of Total Timepoints per Lineage") +axes[0, 0].legend() +axes[0, 0].grid(alpha=0.3) + +axes[0, 1].hist( + track_stats["n_tracks"], + bins=range(1, int(track_stats["n_tracks"].max()) + 2), + color="#1f77b4", + alpha=0.7, + edgecolor="black", +) +axes[0, 1].axvline( + track_stats["n_tracks"].mean(), + color="#ff7f0e", + linestyle="--", + linewidth=2, + label=f'Mean: {track_stats["n_tracks"].mean():.2f}', +) +axes[0, 1].set_xlabel("Number of Tracks per Lineage") +axes[0, 1].set_ylabel("Count") +axes[0, 1].set_title("Distribution of Tracks per Lineage") +axes[0, 1].legend() +axes[0, 1].grid(alpha=0.3) + +axes[1, 0].hist( + track_stats["mean_timepoints_per_track"], + bins=30, + color="#1f77b4", + alpha=0.7, + edgecolor="black", +) +axes[1, 0].axvline( + track_stats["mean_timepoints_per_track"].mean(), + color="#ff7f0e", + linestyle="--", + linewidth=2, + label=f'Mean: {track_stats["mean_timepoints_per_track"].mean():.1f}', +) +axes[1, 0].set_xlabel("Mean Timepoints per Track") +axes[1, 0].set_ylabel("Count") +axes[1, 0].set_title("Distribution of Mean Timepoints per Track") +axes[1, 0].legend() +axes[1, 0].grid(alpha=0.3) + +axes[1, 1].scatter( + track_stats["n_tracks"], + track_stats["total_timepoints"], + alpha=0.6, + s=50, + color="#1f77b4", + edgecolor="black", + linewidth=0.5, +) +axes[1, 1].set_xlabel("Number of Tracks") +axes[1, 1].set_ylabel("Total Timepoints") +axes[1, 1].set_title("Tracks vs Total Timepoints") +axes[1, 1].grid(alpha=0.3) +plt.tight_layout() +plt.show() + +# %% diff --git a/applications/pseudotime_analysis/infection_state/compute_alignment.py b/applications/pseudotime_analysis/infection_state/compute_alignment.py new file mode 100644 index 00000000..54650d57 --- /dev/null +++ b/applications/pseudotime_analysis/infection_state/compute_alignment.py @@ -0,0 +1,987 @@ +# %% +import logging +import pickle +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from anndata import read_zarr +from iohub import open_ome_zarr +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from viscy.data.triplet import TripletDataset +from viscy.representation.pseudotime import ( + CytoDtw, + compute_hpi_from_absolute_time, + create_synchronized_warped_sequences, + get_aligned_image_sequences, +) + +# FIXME: standardize the naming convention for the computed features columns. (i.e replace time_point with t) +# FIXME: merge the computed features and the features in AnnData object +# FIXME: the pipeline should take the Anndata objects instea of pd.Dataframes +# FIXME: aligned_df should be an Anndata object instead of pandas +# FIXME: generalize the merging to use the tracking Dictionary instead of hardcoding the column names +# FIXME: be able to load the csv from another file and align the new embeddings w.r.t to this. +# %% +logger = logging.getLogger("viscy") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + +""" +TODO +- We need to find a way to save the annotations, features and track information into one file. +- We need to standardize the naming convention. i.e The annotations fov_name is missing a / at the beginning. +- It would be nice to also select which will be the reference lineages and add that as a column. +- Figure out what is the best format to save the consensus lineage +- Does the consensus track generalize? +- There is a lot of fragmentation. Which tracking was used for the annotations? There is a script that unifies this but no record of which one was it. We can append these as extra columns + +""" + +# Configuration +NAPARI = True +TIME_INTERVAL_MINUTES = 30 # Frame sampling interval in minutes + +if NAPARI: + import os + + import napari + + os.environ["DISPLAY"] = ":1" + viewer = napari.Viewer() +# %% +# File paths for infection state analysis +# +# FIXME combine the annotations,computed features into 1 single file +perturbations_dict = { + # 'denv': { + # 'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/train-test/2024_11_21_A549_TOMM20_DENV.zarr", + # 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_cell_state_annotation.csv", + # 'features_path': "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/phase_160patch_104ckpt_ver3max.zarr", + # }, + "zikv": { + "data_path": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr", + "annotations_path": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/0-annotations/track_cell_state_annotation.csv", + "features_path_sensor": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/sensor_160patch_104ckpt_ver3max.zarr", + "features_path_phase": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/phase_160patch_104ckpt_ver3max.zarr", + "features_path_organelle": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/organelle_160patch_104ckpt_ver3max.zarr", + "computed_features_path": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/quantify_remodeling/feature_list_all.csv", + "segmentation_features_path": "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/organelle_segmentation/output/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_mito_features.csv", + # "segmentation_features_path": "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/organelle_segmentation/output/train_test_mito_seg_2/train_test_mito_seg_2_mito_features_nellie.csv", + }, +} + + +ALIGN_TYPE = "infection_apoptotic" # Options: "cell_division" or "infection_state" or "apoptosis" +ALIGNMENT_CHANNEL = "sensor" # sensor,phase,organelle + +output_root = Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" +) +output_root.mkdir(parents=True, exist_ok=True) + +# FIXME: find a better logic to manage this +consensus_path = None +# consensus_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/2025_06_26_A549_G3BP1_ZIKV/consensus_lineage_infection_apoptotic_sensor.pkl" + + +# %% +color_dict = { + "uninfected": "blue", + "infected": "orange", +} + + +for key in perturbations_dict.keys(): + data_path = perturbations_dict[key]["data_path"] + annotations_path = perturbations_dict[key]["annotations_path"] + + if ALIGNMENT_CHANNEL not in ["sensor", "phase", "organelle"]: + raise ValueError( + "ALIGNMENT_CHANNEL must be one of 'sensor', 'phase', or 'organelle'" + ) + + computed_features_path = perturbations_dict[key]["computed_features_path"] + segmentation_features_path = perturbations_dict[key]["segmentation_features_path"] + + channel_types = ["sensor", "phase", "organelle"] + features_paths = {} + ad_features = {} + + for channel in channel_types: + path_key = f"features_path_{channel}" + if path_key in perturbations_dict[key]: + path = perturbations_dict[key][path_key] + features_paths[channel] = path + ad_features[channel] = read_zarr(path) + + n_pca_components = 8 + scaler = StandardScaler() + pca = PCA(n_components=n_pca_components) + + for channel, adata in ad_features.items(): + scaled_features = scaler.fit_transform(adata.X) + pca_features = pca.fit_transform(scaled_features) + + adata.obsm["X_pca"] = pca_features + + logger.info( + f"Computed PCA for {channel} channel: explained variance ratio = {pca.explained_variance_ratio_}" + ) + + ad_features_alignment = ad_features[ALIGNMENT_CHANNEL] + + print("Loaded AnnData with shape:", ad_features_alignment.shape) + print("Available columns:", ad_features_alignment.obs.columns.tolist()) + + logger.info(f"Processing dataset: {key}") + logger.info(f"Data path: {data_path}") + logger.info(f"Annotations path: {annotations_path}") + logger.info(f"Alignment channel: {ALIGNMENT_CHANNEL}") + logger.info(f"Features path for alignment: {features_paths[ALIGNMENT_CHANNEL]}") + logger.info(f"Computed features path: {computed_features_path}") + + # Instantiate the CytoDtw object with AnnData + cytodtw = CytoDtw(ad_features_alignment) + feature_df = cytodtw.adata.obs + break + +min_timepoints = 10 +filtered_lineages = cytodtw.get_lineages(min_timepoints) +filtered_lineages = pd.DataFrame(filtered_lineages, columns=["fov_name", "track_id"]) +logger.info( + f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints" +) + +# %% +# TODO: cleanup annotations +# Annotations +n_timepoints_before = min_timepoints // 2 +n_timepoints_after = min_timepoints // 2 + +# Annotations on the 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV dataset +cell_div_infected_annotations = [ + { + "fov_name": "A/2/001000", + "track_id": [239], + "timepoints": (25 - n_timepoints_before, 25 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, + { + "fov_name": "C/2/001001", + "track_id": [120], + "timepoints": (30 - n_timepoints_before, 30 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, +] +# apoptotic infected annotation from the 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV dataset +apoptotic_infected_annotations = [ + { + "fov_name": "B/2/001000", + "track_id": [109], + "timepoints": (25 - n_timepoints_before, 25 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, + { + "fov_name": "B/2/001000", + "track_id": [77], + "timepoints": (21 - n_timepoints_before, 21 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, + # Dies but there is no infection. there is some that look infected from the phase but dont die. + # in A/2/001000 we see cells that never get infected and all around die + # { + # "fov_name": "A/2/000001", + # "track_id": [137], + # "timepoints": (21 - n_timepoints_before, 21 + n_timepoints_after), + # "annotations": ["uninfected"] * (n_timepoints_before) + # + ["infected"] * (n_timepoints_after), + # "weight": 1.0, + # }, + # { + # "fov_name": "C/2/000001", + # "track_id": [40], + # "timepoints": (24 - n_timepoints_before, 24 + n_timepoints_after), + # "annotations": ["uninfected"] * (n_timepoints_before) + # + ["infected"] * (n_timepoints_after), + # "weight": 1.0, + # }, + # { + # "fov_name": "C/2/001001", + # "track_id": [115], + # "timepoints": (21 - n_timepoints_before, 2 + n_timepoints_after), + # "annotations": ["uninfected"] * (n_timepoints_before) + # + ["infected"] * (n_timepoints_after), + # "weight": 1.0, + # }, +] + +# Annotations on the 2024_11_21_A549_TOMM20_DENV dataset +infection_annotations = [ + { + "fov_name": "C/2/001001", + "track_id": [193], + "timepoints": (31 - n_timepoints_before, 31 + n_timepoints_after), + "annotations": ["uinfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, + { + "fov_name": "C/2/001000", + "track_id": [66], + "timepoints": (19 - n_timepoints_before, 19 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, + { + "fov_name": "C/2/001000", + "track_id": [54], + "timepoints": (27 - n_timepoints_before, 27 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, + { + "fov_name": "C/2/001000", + "track_id": [53], + "timepoints": (21 - n_timepoints_before, 21 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, +] + +if ALIGN_TYPE == "infection_state": + aligning_annotations = infection_annotations +elif "apoptotic" in ALIGN_TYPE: + aligning_annotations = apoptotic_infected_annotations +else: + NotImplementedError("Only infection_state alignment is implemented in this example") + + +# %% +# Extract all reference patterns +patterns = [] +pattern_info = [] +REFERENCE_TYPE = "features" +DTW_CONSTRAINT_TYPE = "sakoe_chiba" +DTW_BAND_WIDTH_RATIO = 0.4 + +for i, example in enumerate(aligning_annotations): + pattern = cytodtw.get_reference_pattern( + fov_name=example["fov_name"], + track_id=example["track_id"], + timepoints=example["timepoints"], + reference_type=REFERENCE_TYPE, + ) + patterns.append(pattern) + pattern_info.append( + { + "index": i, + "fov_name": example["fov_name"], + "track_id": example["track_id"], + "timepoints": example["timepoints"], + "annotations": example["annotations"], + } + ) +all_patterns_concat = np.vstack(patterns) + +# %% +# Create consensus pattern +# If loading existing consensus, skip pattern plotting and go directly to alignment +if consensus_path is not None and Path(consensus_path).exists(): + logger.info(f"Loading existing consensus from {consensus_path}") + consensus_result = np.load(consensus_path, allow_pickle=True) + cytodtw.consensus_data = consensus_result + logger.info("Skipping pattern plotting - using existing consensus") +else: + # Plot the sample patterns when creating new consensus + logger.info("Creating new consensus - plotting sample patterns") + fig = cytodtw.plot_sample_patterns( + annotated_samples=aligning_annotations, + reference_type=REFERENCE_TYPE, + n_pca_components=3, + ) + plt.show() + + # Create consensus pattern + consensus_result = cytodtw.create_consensus_reference_pattern( + annotated_samples=aligning_annotations, + reference_selection="median_length", + aggregation_method="median", + reference_type=REFERENCE_TYPE, + ) + + # Plot consensus validation + logger.info("Plotting consensus validation") + fig = cytodtw.plot_consensus_validation( + annotated_samples=aligning_annotations, + reference_type=REFERENCE_TYPE, + metric="cosine", + constraint_type=DTW_CONSTRAINT_TYPE, + band_width_ratio=DTW_BAND_WIDTH_RATIO, + n_pca_components=3, + ) + plt.show() + +# Extract consensus data for use in alignment +consensus_lineage = consensus_result["pattern"] +consensus_annotations = consensus_result.get("annotations", None) +consensus_metadata = consensus_result["metadata"] + +logger.info(f"Consensus pattern shape: {consensus_lineage.shape}") +logger.info(f"Consensus method: {consensus_metadata['aggregation_method']}") +logger.info(f"Reference pattern: {consensus_metadata['reference_pattern']}") +if consensus_annotations: + logger.info(f"Consensus annotations length: {len(consensus_annotations)}") + +# Extract raw infection timepoint from consensus annotations +raw_infection_timepoint = None +if consensus_annotations is not None and "infected" in consensus_annotations: + consensus_infection_idx = consensus_annotations.index("infected") + # For apoptotic infections, find the reference cell's infection timepoint + # We'll update this after alignment with the top-1 cell's actual timepoint + logger.info( + f"Consensus infection marker at index {consensus_infection_idx} in consensus space" + ) + +# %% +# Perform DTW analysis for each embedding method +alignment_results = {} +top_n = 30 + +name = f"consensus_lineage_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}" +consensus_lineage = cytodtw.consensus_data["pattern"] +# Find pattern matches +matches = cytodtw.get_matches( + reference_pattern=consensus_lineage, + lineages=filtered_lineages.to_numpy(), + window_step=1, + num_candidates=top_n, + method="bernd_clifford", + metric="cosine", + save_path=output_root / f"{name}_matching_lineages_cosine.csv", + reference_type=REFERENCE_TYPE, + constraint_type=DTW_CONSTRAINT_TYPE, + band_width_ratio=DTW_BAND_WIDTH_RATIO, +) + +alignment_results[name] = matches +logger.info(f"Found {len(matches)} matches for {name}") +# %% +# Save matches +print(f"Saving matches to {output_root / f'{name}_matching_lineages_cosine.csv'}") +matches["consensus_path"] = str(output_root / f"{name}.pkl") +cytodtw.save_consensus(output_root / f"{name}.pkl") +matches.to_csv(output_root / f"{name}_matching_lineages_cosine.csv", index=False) +# %% +# Filtering and creating one with just the top n matches +alignment_df = cytodtw.create_alignment_dataframe( + matches, + consensus_lineage, + alignment_name=ALIGN_TYPE, + reference_type=REFERENCE_TYPE, +) + +logger.info(f"Enhanced DataFrame created with {len(alignment_df)} rows") +logger.info(f"Lineages: {alignment_df['lineage_id'].nunique()} (including consensus)") + +# Extract reference cell info from alignment_df +distance_col = f"dtw_{ALIGN_TYPE}_distance" +consensus_mapping_col = f"dtw_{ALIGN_TYPE}_consensus_mapping" + +# Find reference cell: cell with minimum DTW distance (NaN distances are automatically skipped) +reference_lineage_id = alignment_df.loc[ + alignment_df[distance_col].idxmin(), "lineage_id" +] +reference_cell_rows = alignment_df[ + alignment_df["lineage_id"] == reference_lineage_id +].copy() + +# Build reference cell info +reference_cell_info = { + "fov_name": reference_cell_rows.iloc[0]["fov_name"], + "track_ids": reference_cell_rows["track_id"].unique().tolist(), + "dtw_distance": reference_cell_rows.iloc[0][distance_col], + "lineage_id": reference_lineage_id, +} + +# Map consensus infection index to raw timepoint +matching_row = reference_cell_rows[ + reference_cell_rows[consensus_mapping_col] == consensus_infection_idx +] +raw_infection_timepoint = matching_row.iloc[0]["t"] if len(matching_row) > 0 else None +reference_cell_info["raw_infection_timepoint"] = raw_infection_timepoint + +logger.info( + f"Reference cell (top-1 match): FOV={reference_cell_info['fov_name']}, " + f"lineage_id={reference_lineage_id}, " + f"track_ids={reference_cell_info['track_ids']}, " + f"distance={reference_cell_info['dtw_distance']:.4f}" +) + +if raw_infection_timepoint is not None: + logger.info( + f"Mapped consensus infection (idx={consensus_infection_idx}) to raw timepoint t={raw_infection_timepoint}" + ) +else: + logger.warning( + f"Could not map consensus infection index {consensus_infection_idx} to raw timepoint for reference cell" + ) + +# %% +# Prototype video alignment based on DTW matches +z_range = slice(0, 1) +initial_yx_patch_size = (192, 192) +# Top matches should be unique fov_name and lineage_id combinations +top_matches = alignment_df.drop_duplicates(subset=["fov_name", "lineage_id"]).head( + top_n +) + +positions = [] +tracks_tables = [] +images_plate = open_ome_zarr(data_path) + +# Load matching positions +print(f"Loading positions for {len(top_matches)} FOV matches...") +matches_found = 0 +for _, pos in images_plate.positions(): + pos_name = pos.zgroup.name + pos_normalized = pos_name.lstrip("/") + + if pos_normalized in top_matches["fov_name"].values: + positions.append(pos) + matches_found += 1 + + # Get ALL tracks for this FOV to ensure TripletDataset has complete access + tracks_df = cytodtw.adata.obs[ + cytodtw.adata.obs["fov_name"] == pos_normalized + ].copy() + + if len(tracks_df) > 0: + tracks_df = tracks_df.dropna(subset=["x", "y"]) + tracks_df["x"] = tracks_df["x"].astype(int) + tracks_df["y"] = tracks_df["y"].astype(int) + tracks_tables.append(tracks_df) + + if matches_found == 1: + processing_channels = pos.channel_names + +print( + f"Loaded {matches_found} positions with {sum(len(df) for df in tracks_tables)} total tracks" +) + +# Create TripletDataset if we have valid positions +if len(positions) > 0 and len(tracks_tables) > 0: + if "processing_channels" not in locals(): + processing_channels = positions[0].channel_names + + # Use all three channels for overlay visualization + selected_channels = processing_channels # Use all available channels + print( + f"Creating TripletDataset with {len(selected_channels)} channels: {selected_channels}" + ) + + dataset = TripletDataset( + positions=positions, + tracks_tables=tracks_tables, + channel_names=selected_channels, + initial_yx_patch_size=initial_yx_patch_size, + z_range=z_range, + fit=False, + predict_cells=False, + include_fov_names=None, + include_track_ids=None, + time_interval=1, + return_negative=False, + ) + print(f"TripletDataset created with {len(dataset.valid_anchors)} valid anchors") +else: + print("Cannot create TripletDataset - no valid positions or tracks") + dataset = None + +# %% +# Get aligned sequences using consolidated function +if dataset is not None: + + def load_images_from_triplet_dataset(fov_name, track_ids): + """Load images from TripletDataset for given FOV and track IDs.""" + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if ( + anchor_row["fov_name"] == fov_name + and anchor_row["track_id"] in track_ids + ): + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning( + f"No matching indices found for FOV {fov_name}, tracks {track_ids}" + ) + return {} + + # Get images and create time mapping + batch_data = dataset.__getitems__(matching_indices) + images = [] + for i in range(len(matching_indices)): + img_data = { + "anchor": batch_data["anchor"][i], + "index": batch_data["index"][i], + } + images.append(img_data) + + images.sort(key=lambda x: x["index"]["t"]) + return {img["index"]["t"]: img for img in images} + + # Filter alignment_df to only aligned rows for loading just the aligned region + alignment_col = f"dtw_{ALIGN_TYPE}_aligned" + aligned_only_df = alignment_df[alignment_df[alignment_col]].copy() + + # Use filtered alignment_df since get_aligned_image_sequences expects 'track_id' column + aligned_sequences = get_aligned_image_sequences( + cytodtw_instance=cytodtw, + df=aligned_only_df, + alignment_name=ALIGN_TYPE, + image_loader_fn=load_images_from_triplet_dataset, + max_lineages=top_n, + ) +else: + aligned_sequences = {} + +logger.info(f"Retrieved {len(aligned_sequences)} aligned sequences") +for idx, seq in aligned_sequences.items(): + meta = seq["metadata"] + # Handle both possible keys depending on return structure + images_key = "aligned_images" if "aligned_images" in seq else "concatenated_images" + if images_key in seq and len(seq[images_key]) > 0: + index = seq[images_key][0]["index"] + logger.info( + f"Track id {index['track_id']}: FOV {meta['fov_name']} aligned images, distance={meta.get('distance', meta.get('dtw_distance', 'N/A')):.3f}" + ) + +# %% +logger.info( + f"{ALIGN_TYPE.capitalize()} aligned timepoints: {alignment_df[f'dtw_{ALIGN_TYPE}_aligned'].sum()}/{len(alignment_df)} ({100 * alignment_df[f'dtw_{ALIGN_TYPE}_aligned'].mean():.1f}%)" +) +logger.info(f"Columns: {list(alignment_df.columns)}") + +print("\nSample of enhanced DataFrame:") +sample_df = alignment_df[alignment_df["lineage_id"] != -1].head(10) +display_cols = [ + "lineage_id", + "track_id", + "t", + f"dtw_{ALIGN_TYPE}_aligned", + f"dtw_{ALIGN_TYPE}_consensus_mapping", + "PC1", +] +print(sample_df[display_cols].to_string()) + + +# Plot using the CytoDtw method +fig = cytodtw.plot_individual_lineages( + alignment_df, + alignment_name=ALIGN_TYPE, + feature_columns=["PC1", "PC2", "PC3"], + max_lineages=15, + aligned_linewidth=2.5, + unaligned_linewidth=1.4, + y_offset_step=0, +) + + +# %% +# Create concatenated image sequences using the DataFrame alignment information +# Filter for infection wells only for specific organelles +fov_name_patterns = ["consensus", "B/2"] + +filtered_alignment_df = alignment_df[ + alignment_df["fov_name"].str.contains("|".join(fov_name_patterns)) +] +if dataset is not None: + # Define TripletDataset-specific image loader + def load_images_from_triplet_dataset(fov_name, track_ids): + """Load images from TripletDataset for given FOV and track IDs.""" + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if ( + anchor_row["fov_name"] == fov_name + and anchor_row["track_id"] in track_ids + ): + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning( + f"No matching indices found for FOV {fov_name}, tracks {track_ids}" + ) + return {} + + # Get images and create time mapping + batch_data = dataset.__getitems__(matching_indices) + images = [] + for i in range(len(matching_indices)): + img_data = { + "anchor": batch_data["anchor"][i], + "index": batch_data["index"][i], + } + images.append(img_data) + + images.sort(key=lambda x: x["index"]["t"]) + return {img["index"]["t"]: img for img in images} + + concatenated_image_sequences = get_aligned_image_sequences( + cytodtw_instance=cytodtw, + df=filtered_alignment_df, + alignment_name=ALIGN_TYPE, + image_loader_fn=load_images_from_triplet_dataset, + max_lineages=10, + ) +else: + print("Skipping image sequence creation - no valid dataset available") + concatenated_image_sequences = {} + +# Load concatenated sequences into napari using WARPED SYNCHRONIZED coordinates +if NAPARI and dataset is not None and len(concatenated_image_sequences) > 0: + import numpy as np + + logger.info("\n" + "=" * 70) + logger.info("Creating synchronized warped sequences for napari visualization") + logger.info("=" * 70) + + # Create synchronized warped sequences with HPI metadata + warped_result = create_synchronized_warped_sequences( + concatenated_image_sequences=concatenated_image_sequences, + alignment_df=filtered_alignment_df, + alignment_name=ALIGN_TYPE, + consensus_infection_idx=consensus_infection_idx, + time_interval_minutes=TIME_INTERVAL_MINUTES, + ) + + warped_sequences = warped_result["warped_sequences"] + alignment_shifts = warped_result["alignment_shifts"] + warped_metadata = warped_result["warped_metadata"] + + # Log HPI mapping information + logger.info("\n" + "=" * 70) + logger.info("Alignment Shifts and HPI Metadata") + logger.info("=" * 70) + for lineage_id, shift_info in alignment_shifts.items(): + logger.info(f"\nLineage {lineage_id}:") + logger.info( + f" FOV: {shift_info['fov_name']}, tracks: {shift_info['track_ids']}" + ) + logger.info( + f" first_aligned_t (pseudotime 0): {shift_info['first_aligned_t']}" + ) + logger.info( + f" infection_t_abs (biological event): {shift_info['infection_t_abs']}" + ) + logger.info(f" shift (t_viz = t_abs + shift): {shift_info['shift']}") + logger.info( + f" infection_offset_in_viz: {shift_info['infection_offset_in_viz']} frames from pseudotime 0" + ) + + # Example HPI calculation + example_t = shift_info["infection_t_abs"] + 5 + example_hpi = compute_hpi_from_absolute_time( + t_abs=example_t, + alignment_shifts=alignment_shifts, + lineage_id=lineage_id, + time_interval_minutes=TIME_INTERVAL_MINUTES, + ) + logger.info( + f" Example: t_abs={example_t} → HPI={example_hpi:.2f} hours post infection" + ) + + # Load into napari using warped coordinates + logger.info("\n" + "=" * 70) + logger.info("Loading warped sequences into napari") + logger.info("=" * 70) + + aligned_start_idx = warped_metadata["aligned_region_start_idx"] + aligned_end_idx = warped_metadata["aligned_region_end_idx"] + + logger.info( + f"Warped time structure: {warped_metadata['total_warped_length']} total frames" + ) + logger.info(f" Frames [0 to {aligned_start_idx - 1}]: Before aligned region") + logger.info( + f" Frames [{aligned_start_idx} to {aligned_end_idx - 1}]: ALIGNED REGION (synchronized!)" + ) + logger.info(f" Frames [{aligned_end_idx} to end]: After aligned region") + logger.info(" Black frames = zero-padding (cell has no data at this warped time)") + + for lineage_id, warped_time_series in warped_sequences.items(): + meta = concatenated_image_sequences[lineage_id]["metadata"] + n_channels = warped_time_series.shape[1] + + logger.info( + f"Lineage {lineage_id}: {warped_time_series.shape[0]} warped frames, {n_channels} channels" + ) + + # Set up colormap based on number of channels + # FIXME: This is hardcoded for specific datasets - improve logic as needed + if n_channels == 2: + colormap = ["green", "magenta"] + elif n_channels == 3: + colormap = ["gray", "green", "magenta"] + else: + colormap = ["gray"] * n_channels # Default fallback + + # Add each channel as a separate layer in napari + for channel_idx in range(n_channels): + channel_data = warped_time_series[:, channel_idx, :, :, :] + channel_name = ( + processing_channels[channel_idx] + if channel_idx < len(processing_channels) + else f"ch{channel_idx}" + ) + # Use WARPED prefix to indicate warped/synchronized time coordinates + layer_name = f"WARPED_track_id_{meta['track_ids'][0]}_FOV_{meta['fov_name']}_dist_{meta['dtw_distance']:.3f}_{channel_name}" + + viewer.add_image( + channel_data, + name=layer_name, + contrast_limits=(channel_data.min(), channel_data.max()), + colormap=colormap[channel_idx], + blending="additive", + ) + logger.debug(f"Added {channel_name} channel for lineage {lineage_id}") + + # Create shapes layer with circle markers for aligned region + img_height = warped_time_series.shape[3] # y dimension + img_width = warped_time_series.shape[4] # x dimension + circle_center_y = img_height * 0.1 # 10% from top + circle_center_x = img_width * 0.1 # 10% from left + circle_radius = min(img_height, img_width) * 0.05 # 5% of smaller dimension + + # Create circles for each aligned frame in warped time + shapes_data_warped = [] + for frame_idx in range(aligned_start_idx, aligned_end_idx): + # Ellipse defined by 4 corners of bounding box: [t, z, y, x] + ellipse = np.array( + [ + [ + frame_idx, + 0, + circle_center_y - circle_radius, + circle_center_x - circle_radius, + ], # top-left + [ + frame_idx, + 0, + circle_center_y - circle_radius, + circle_center_x + circle_radius, + ], # top-right + [ + frame_idx, + 0, + circle_center_y + circle_radius, + circle_center_x + circle_radius, + ], # bottom-right + [ + frame_idx, + 0, + circle_center_y + circle_radius, + circle_center_x - circle_radius, + ], # bottom-left + ] + ) + shapes_data_warped.append(ellipse) + + if len(shapes_data_warped) > 0: + # Add shapes layer with cyan circles marking aligned frames + viewer.add_shapes( + shapes_data_warped, + shape_type="ellipse", + edge_width=3, + edge_color="cyan", + face_color="transparent", + name=f"WARPED_ALIGNED_MARKER_track_id_{meta['track_ids'][0]}_FOV_{meta['fov_name']}", + ndim=4, + ) + logger.info( + f"Added alignment marker for lineage {lineage_id} " + f"(frames {aligned_start_idx} to {aligned_end_idx - 1} marked with cyan circles)" + ) + + logger.info("\n" + "=" * 70) + logger.info("Napari visualization complete!") + logger.info("=" * 70) + logger.info("All cells are synchronized:") + logger.info( + f" - Frame {aligned_start_idx} = START of aligned region (pseudotime 0)" + ) + logger.info( + f" - Frame {aligned_start_idx + consensus_infection_idx} = INFECTION EVENT" + ) + logger.info(f" - Frame {aligned_end_idx - 1} = END of aligned region") + logger.info("\nVisual indicators:") + logger.info(" - Cyan circles in top-left corner = ALIGNED frames (synchronized!)") + logger.info(" - Black frames = zero-padding (cell has no data at that time)") + logger.info("\nScrub through frames to see synchronized biological progression!") + logger.info("=" * 70) +# %% +if segmentation_features_path is not None or Path(segmentation_features_path).exists(): + # Get the segmentation based features and compute per-cell aggregates + segmentation_features_df = pd.read_csv(segmentation_features_path) + segmentation_features_df["fov_name"] = segmentation_features_df[ + "fov_name" + ].str.lstrip("/") + + # Compute per-cell mitochondria population statistics + segs_population_features = [] + for (fov, track, t), group in segmentation_features_df.groupby( + ["fov_name", "track_id", "t"] + ): + stats = { + "fov_name": fov, + "track_id": track, + "t": t, + # Count metrics + "segs_count": len(group), + # Area/volume metrics + "segs_total_area": group["area"].sum(), + "segs_mean_area": group["area"].mean(), + "segs_std_area": group["area"].std(), + "segs_median_area": group["area"].median(), + # Shape metrics + "segs_mean_eccentricity": group["eccentricity"].mean(), + "segs_std_eccentricity": group["eccentricity"].std(), + "segs_mean_solidity": group["solidity"].mean(), + "segs_std_solidity": group["solidity"].std(), + "segs_circularity_mean": group["circularity"].mean(), + "segs_circularity_std": group["circularity"].std(), + # Intensity metrics + # "segs_mean_intensity": group["mean_intensity"].mean(), + # "segs_std_intensity_across_mitos": group["mean_intensity"].std(), + # "segs_mean_max_intensity": group["max_intensity"].mean(), + # Texture metrics (aggregated) + # "segs_mean_texture_contrast": group["texture_contrast"].mean(), + # "segs_mean_texture_homogeneity": group["texture_homogeneity"].mean(), + # Frangi filter metrics (tubularity/network structure) + "segs_mean_frangi_mean": group["frangi_mean_intensity"].mean(), + "segs_mean_frangi_std": group["frangi_std_intensity"].mean(), + # Shape diversity (coefficient of variation) + "segs_area_cv": group["area"].std() / (group["area"].mean() + 1e-6), + "segs_eccentricity_cv": group["eccentricity"].std() + / (group["eccentricity"].mean() + 1e-6), + "segs_solidity_cv": group["solidity"].std() + / (group["solidity"].mean() + 1e-6), + "segs_frangi_cv": group["frangi_mean_intensity"].std() + / (group["frangi_mean_intensity"].mean() + 1e-6), + "segs_circularity_cv": group["circularity"].std() + / (group["circularity"].mean() + 1e-6), + } + segs_population_features.append(stats) + + segs_population_df = pd.DataFrame(segs_population_features) + + logger.info( + f"Computed mitochondria population features for {len(segs_population_df)} (fov, track, t) combinations" + ) + logger.info( + f"Mitochondria population feature columns: {list(segs_population_df.columns)}" + ) + +# Load the computed features and PCs +computed_features_df = pd.read_csv(computed_features_path) +# Rename time_point to t for merging +computed_features_df = computed_features_df.rename(columns={"time_point": "t"}) +# Remove the first forward slash from the fov_name +computed_features_df["fov_name"] = computed_features_df["fov_name"].str.lstrip("/") + + +# %% +if segmentation_features_path is not None and Path(segmentation_features_path).exists(): + # Merge the computed features and mitochondria population features + combined_features_df = computed_features_df.merge( + segs_population_df, on=["fov_name", "track_id", "t"], how="left" + ) + +# Add PCs from each channel to the combined features +for channel, adata in ad_features.items(): + # Create a temporary dataframe with PCs from this channel + pcs_df = adata.obs[["fov_name", "track_id", "t"]].copy() + + # Add PC columns with channel prefix + for i in range(n_pca_components): + pcs_df[f"{channel}_PC{i + 1}"] = adata.obsm["X_pca"][:, i] + + # Merge with combined features + combined_features_df = combined_features_df.merge( + pcs_df, + on=["fov_name", "track_id", "t"], + how="left", + ) + logger.info( + f"Added {n_pca_components} PCs from {channel} channel to combined features" + ) + +# %% +# Merge alignment_df with combined_features_df to create master features dataframe +master_features_df = alignment_df.merge( + combined_features_df, + on=["fov_name", "track_id", "t", "x", "y"], + how="outer", # Use outer to keep all tracking data, not just aligned +) + +logger.info(f"Created master features dataframe. Shape: {master_features_df.shape}") +logger.info(f"Columns: {list(master_features_df.columns)}") + +# Save master features dataframe +output_path = output_root / f"master_features_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}.csv" +master_features_df.to_csv(output_path, index=False) +logger.info(f"Saved master features dataframe to {output_path}") + +# Save alignment metadata (including warped visualization metadata if available) +metadata = { + "consensus_pattern": consensus_lineage, + "consensus_annotations": consensus_annotations, + "consensus_metadata": consensus_metadata, + "reference_cell_info": reference_cell_info, # Top-1 cell's full trajectory info + "raw_infection_timepoint": raw_infection_timepoint, # Infection timepoint in raw data space + "infection_timepoint": raw_infection_timepoint, # Backward compatibility + "aligned_region_bounds": None, # Will be computed in visualization script + "alignment_type": ALIGN_TYPE, + "alignment_channel": ALIGNMENT_CHANNEL, + "time_interval_minutes": TIME_INTERVAL_MINUTES, # For HPI conversion +} + +# Add warped visualization metadata if napari visualization was run +if NAPARI and "alignment_shifts" in locals(): + metadata["alignment_shifts"] = alignment_shifts + metadata["warped_metadata"] = warped_metadata + logger.info("Added warped visualization metadata to save file") + +metadata_path = output_root / f"alignment_metadata_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}.pkl" +with open(metadata_path, "wb") as f: + pickle.dump(metadata, f) +logger.info(f"Saved alignment metadata to {metadata_path}") + +logger.info("\n## Pipeline Complete!") +logger.info("To visualize results, run visualize_alignment.py with the saved outputs:") + +# %% diff --git a/applications/pseudotime_analysis/infection_state/infection_state.py b/applications/pseudotime_analysis/infection_state/infection_state.py new file mode 100644 index 00000000..e7e5594f --- /dev/null +++ b/applications/pseudotime_analysis/infection_state/infection_state.py @@ -0,0 +1,2357 @@ +# %% +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from anndata import read_zarr +from iohub import open_ome_zarr +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler + +from viscy.data.triplet import INDEX_COLUMNS, TripletDataset +from viscy.representation.pseudotime import ( + CytoDtw, + align_embedding_patterns, + get_aligned_image_sequences, +) + +# FIXME: standardize the naming convention for the computed features columns. (i.e replace time_point with t) +# FIXME: merge the computed features and the features in AnnData object +# FIXME: the pipeline should take the Anndata objects instea of pd.Dataframes +# FIXME: aligned_df should be an Anndata object instead of pandas +# FIXME: generalize the merging to use the tracking Dictionary instead of hardcoding the column names +# FIXME: be able to load the csv from another file and align the new embeddings w.r.t to this. +# %% +logger = logging.getLogger("viscy") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + +""" +TODO +- We need to find a way to save the annotations, features and track information into one file. +- We need to standardize the naming convention. i.e The annotations fov_name is missing a / at the beginning. +- It would be nice to also select which will be the reference lineages and add that as a column. +- Figure out what is the best format to save the consensus lineage +- Does the consensus track generalize? +- There is a lot of fragmentation. Which tracking was used for the annotations? There is a script that unifies this but no record of which one was it. We can append these as extra columns + +""" + +# Configuration +NAPARI = True +if NAPARI: + import os + + import napari + + os.environ["DISPLAY"] = ":1" + viewer = napari.Viewer() +# %% +# File paths for infection state analysis +# +# FIXME combine the annotations,computed features into 1 single file +perturbations_dict = { + # 'denv': { + # 'data_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/train-test/2024_11_21_A549_TOMM20_DENV.zarr", + # 'annotations_path': "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/0-annotations/track_cell_state_annotation.csv", + # 'features_path': "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/phase_160patch_104ckpt_ver3max.zarr", + # }, + "zikv": { + "data_path": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr", + "annotations_path": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/0-annotations/track_cell_state_annotation.csv", + "features_path_sensor": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/sensor_160patch_104ckpt_ver3max.zarr", + "features_path_phase": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/phase_160patch_104ckpt_ver3max.zarr", + "features_path_organelle": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/organelle_160patch_104ckpt_ver3max.zarr", + "computed_features_path": "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/quantify_remodeling/feature_list_all.csv", + "segmentation_features_path": "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/organelle_segmentation/output/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_segs_features.csv", + }, +} + + +ALIGN_TYPE = "infection_apoptotic" # Options: "cell_division" or "infection_state" or "apoptosis" +ALIGNMENT_CHANNEL = "sensor" # sensor,phase,organelle + +output_root = Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" +) +output_root.mkdir(parents=True, exist_ok=True) + +# FIXME: find a better logic to manage this +consensus_path = None +# consensus_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/2025_06_26_A549_G3BP1_ZIKV/consensus_lineage_infection_apoptotic_sensor.pkl" + + +# %% +color_dict = { + "uninfected": "blue", + "infected": "orange", +} + + +for key in perturbations_dict.keys(): + data_path = perturbations_dict[key]["data_path"] + annotations_path = perturbations_dict[key]["annotations_path"] + + if ALIGNMENT_CHANNEL not in ["sensor", "phase", "organelle"]: + raise ValueError( + "ALIGNMENT_CHANNEL must be one of 'sensor', 'phase', or 'organelle'" + ) + + computed_features_path = perturbations_dict[key]["computed_features_path"] + segmentation_features_path = perturbations_dict[key]["segmentation_features_path"] + + channel_types = ["sensor", "phase", "organelle"] + features_paths = {} + ad_features = {} + + for channel in channel_types: + path_key = f"features_path_{channel}" + if path_key in perturbations_dict[key]: + path = perturbations_dict[key][path_key] + features_paths[channel] = path + ad_features[channel] = read_zarr(path) + + n_pca_components = 8 + scaler = StandardScaler() + pca = PCA(n_components=n_pca_components) + + for channel, adata in ad_features.items(): + scaled_features = scaler.fit_transform(adata.X) + pca_features = pca.fit_transform(scaled_features) + + adata.obsm["X_pca"] = pca_features + + logger.info( + f"Computed PCA for {channel} channel: explained variance ratio = {pca.explained_variance_ratio_}" + ) + + ad_features_alignment = ad_features[ALIGNMENT_CHANNEL] + + print("Loaded AnnData with shape:", ad_features_alignment.shape) + print("Available columns:", ad_features_alignment.obs.columns.tolist()) + + logger.info(f"Processing dataset: {key}") + logger.info(f"Data path: {data_path}") + logger.info(f"Annotations path: {annotations_path}") + logger.info(f"Alignment channel: {ALIGNMENT_CHANNEL}") + logger.info(f"Features path for alignment: {features_paths[ALIGNMENT_CHANNEL]}") + logger.info(f"Computed features path: {computed_features_path}") + + # Instantiate the CytoDtw object with AnnData + cytodtw = CytoDtw(ad_features_alignment) + feature_df = cytodtw.adata.obs + break + +min_timepoints = 20 +filtered_lineages = cytodtw.get_lineages(min_timepoints) +filtered_lineages = pd.DataFrame(filtered_lineages, columns=["fov_name", "track_id"]) +logger.info( + f"Found {len(filtered_lineages)} lineages with at least {min_timepoints} timepoints" +) + +# %% +# Annotations +n_timepoints_before = min_timepoints // 2 +n_timepoints_after = min_timepoints // 2 + +# Annotations on the 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV dataset +cell_div_infected_annotations = [ + { + "fov_name": "A/2/001000", + "track_id": [239], + "timepoints": (25 - n_timepoints_before, 25 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, + { + "fov_name": "C/2/001001", + "track_id": [120], + "timepoints": (30 - n_timepoints_before, 30 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, +] +# apoptotic infected annotation from the 2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV dataset +apoptotic_infected_annotations = [ + { + "fov_name": "B/2/001000", + "track_id": [109], + "timepoints": (25 - n_timepoints_before, 25 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, + { + "fov_name": "B/2/001000", + "track_id": [77], + "timepoints": (21 - n_timepoints_before, 21 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] * (n_timepoints_after), + "weight": 1.0, + }, + # Dies but there is no infection. there is some that look infected from the phase but dont die. + # in A/2/001000 we see cells that never get infected and all around die + # { + # "fov_name": "A/2/000001", + # "track_id": [137], + # "timepoints": (21 - n_timepoints_before, 21 + n_timepoints_after), + # "annotations": ["uninfected"] * (n_timepoints_before) + # + ["infected"] * (n_timepoints_after), + # "weight": 1.0, + # }, + # { + # "fov_name": "C/2/000001", + # "track_id": [40], + # "timepoints": (24 - n_timepoints_before, 24 + n_timepoints_after), + # "annotations": ["uninfected"] * (n_timepoints_before) + # + ["infected"] * (n_timepoints_after), + # "weight": 1.0, + # }, + # { + # "fov_name": "C/2/001001", + # "track_id": [115], + # "timepoints": (21 - n_timepoints_before, 2 + n_timepoints_after), + # "annotations": ["uninfected"] * (n_timepoints_before) + # + ["infected"] * (n_timepoints_after), + # "weight": 1.0, + # }, +] + +# Annotations on the 2024_11_21_A549_TOMM20_DENV dataset +infection_annotations = [ + { + "fov_name": "C/2/001001", + "track_id": [193], + "timepoints": (31 - n_timepoints_before, 31 + n_timepoints_after), + "annotations": ["uinfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, + { + "fov_name": "C/2/001000", + "track_id": [66], + "timepoints": (19 - n_timepoints_before, 19 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, + { + "fov_name": "C/2/001000", + "track_id": [54], + "timepoints": (27 - n_timepoints_before, 27 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, + { + "fov_name": "C/2/001000", + "track_id": [53], + "timepoints": (21 - n_timepoints_before, 21 + n_timepoints_after), + "annotations": ["uninfected"] * (n_timepoints_before) + + ["infected"] + + ["uninfected"] * (n_timepoints_after - 1), + "weight": 1.0, + }, +] + +if ALIGN_TYPE == "infection_state": + aligning_annotations = infection_annotations +elif "apoptotic" in ALIGN_TYPE: + aligning_annotations = apoptotic_infected_annotations +else: + NotImplementedError("Only infection_state alignment is implemented in this example") + + +# %% +# Extract all reference patterns +patterns = [] +pattern_info = [] +REFERENCE_TYPE = "features" +DTW_CONSTRAINT_TYPE = "sakoe_chiba" +DTW_BAND_WIDTH_RATIO = 0.3 + +for i, example in enumerate(aligning_annotations): + pattern = cytodtw.get_reference_pattern( + fov_name=example["fov_name"], + track_id=example["track_id"], + timepoints=example["timepoints"], + reference_type=REFERENCE_TYPE, + ) + patterns.append(pattern) + pattern_info.append( + { + "index": i, + "fov_name": example["fov_name"], + "track_id": example["track_id"], + "timepoints": example["timepoints"], + "annotations": example["annotations"], + } + ) +all_patterns_concat = np.vstack(patterns) + +# %% +# Plot the sample patterns +scaler = StandardScaler() +scaled_patterns = scaler.fit_transform(all_patterns_concat) +pca = PCA(n_components=3) +pca.fit(scaled_patterns) + +n_patterns = len(patterns) +fig, axes = plt.subplots(n_patterns, 3, figsize=(12, 3 * n_patterns)) +if n_patterns == 1: + axes = axes.reshape(1, -1) + +for i, (pattern, info) in enumerate(zip(patterns, pattern_info)): + scaled_pattern = scaler.transform(pattern) + pc_pattern = pca.transform(scaled_pattern) + time_axis = np.arange(len(pattern)) + + for pc_idx in range(3): + ax = axes[i, pc_idx] + + ax.plot( + time_axis, + pc_pattern[:, pc_idx], + "o-", + color="blue", + linewidth=2, + markersize=4, + ) + + annotations = info["annotations"] + for t, annotation in enumerate(annotations): + if annotation == "mitosis": + ax.axvline(t, color="orange", alpha=0.7, linestyle="--", linewidth=2) + ax.scatter(t, pc_pattern[t, pc_idx], c="orange", s=50, zorder=5) + elif annotation == "infected": + ax.axvline(t, color="red", alpha=0.5, linestyle="--", linewidth=1) + ax.scatter(t, pc_pattern[t, pc_idx], c="red", s=30, zorder=5) + break # Only mark the first infection timepoint + + ax.set_xlabel("Time") + ax.set_ylabel(f"PC{pc_idx+1}") + ax.set_title( + f'Pattern {i+1}: FOV {info["fov_name"]}, Tracks {info["track_id"]}\nPC{pc_idx+1} over time' + ) + ax.grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() + +# %% +# Create consensus pattern +if consensus_path is not None and Path(consensus_path).exists(): + logger.info(f"Loading existing consensus from {consensus_path}") + consensus_result = np.load(consensus_path, allow_pickle=True) + cytodtw.consensus_data = consensus_result +else: + consensus_result = cytodtw.create_consensus_reference_pattern( + annotated_samples=aligning_annotations, + reference_selection="median_length", + aggregation_method="median", + reference_type=REFERENCE_TYPE, + ) + +consensus_lineage = consensus_result["pattern"] +consensus_annotations = consensus_result.get("annotations", None) +consensus_metadata = consensus_result["metadata"] + +logger.info(f"Created consensus pattern with shape: {consensus_lineage.shape}") +logger.info(f"Consensus method: {consensus_metadata['aggregation_method']}") +logger.info(f"Reference pattern: {consensus_metadata['reference_pattern']}") +if consensus_annotations: + logger.info(f"Consensus annotations length: {len(consensus_annotations)}") + + +# %% +# Plot all aligned consensus patterns together to validate alignment +# The patterns need to be DTW-aligned to the consensus for proper visualization +# We'll align each pattern to the consensus using the same method used internally + +aligned_patterns_list = [] +aligned_annotations_list = [] +for i, example in enumerate(aligning_annotations): + # Extract pattern + pattern = cytodtw.get_reference_pattern( + fov_name=example["fov_name"], + track_id=example["track_id"], + timepoints=example["timepoints"], + reference_type=REFERENCE_TYPE, + ) + + # Align to consensus + if len(pattern) == len(consensus_lineage): + # Already same length, likely the reference pattern + aligned_patterns_list.append(pattern) + aligned_annotations_list.append(example.get("annotations", None)) + else: + # Align to consensus + alignment_result = align_embedding_patterns( + query_pattern=pattern, + reference_pattern=consensus_lineage, + metric="cosine", + query_annotations=example.get("annotations", None), + constraint_type=DTW_CONSTRAINT_TYPE, + band_width_ratio=DTW_BAND_WIDTH_RATIO, + ) + aligned_patterns_list.append(alignment_result["pattern"]) + aligned_annotations_list.append(alignment_result.get("annotations", None)) + +fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + +for pc_idx in range(3): + ax = axes[pc_idx] + + # Transform each aligned pattern to PC space and plot + for i, pattern in enumerate(aligned_patterns_list): + scaled_ref = scaler.transform(pattern) + pc_ref = pca.transform(scaled_ref) + + time_axis = np.arange(len(pc_ref)) + ax.plot( + time_axis, + pc_ref[:, pc_idx], + "o-", + label=f"Ref {i+1}", + alpha=0.7, + linewidth=2, + markersize=4, + ) + + # Mark infection timepoint for this aligned trajectory + if aligned_annotations_list[i] and "infected" in aligned_annotations_list[i]: + infection_t = aligned_annotations_list[i].index("infected") + ax.axvline( + infection_t, color="orange", alpha=0.4, linestyle="--", linewidth=1 + ) + + # Overlay consensus pattern + scaled_consensus = scaler.transform(consensus_lineage) + pc_consensus = pca.transform(scaled_consensus) + time_axis = np.arange(len(pc_consensus)) + ax.plot( + time_axis, + pc_consensus[:, pc_idx], + "s-", + color="black", + linewidth=3, + markersize=6, + label="Consensus", + zorder=10, + ) + + # Mark consensus infection timepoint with a thicker, more prominent line + if consensus_annotations and "infected" in consensus_annotations: + consensus_infection_t = consensus_annotations.index("infected") + ax.axvline( + consensus_infection_t, + color="orange", + alpha=0.9, + linestyle="--", + linewidth=2.5, + label="Infection", + ) + + ax.set_xlabel("Aligned Time") + ax.set_ylabel(f"PC{pc_idx+1}") + ax.set_title(f"PC{pc_idx+1}: All DTW-Aligned References + Consensus") + ax.grid(True, alpha=0.3) + ax.legend() + +plt.suptitle( + "Consensus Validation: DTW-Aligned References + Computed Consensus", fontsize=14 +) +plt.tight_layout() +plt.show() +logger.info("Plotted DTW-aligned consensus patterns for validation") + +# %% +# Perform DTW analysis for each embedding method +alignment_results = {} +top_n = 30 + +name = f"consensus_lineage_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}" +consensus_lineage = cytodtw.consensus_data["pattern"] +# Find pattern matches +matches = cytodtw.get_matches( + reference_pattern=consensus_lineage, + lineages=filtered_lineages.to_numpy(), + window_step=1, + num_candidates=top_n, + method="bernd_clifford", + metric="cosine", + save_path=output_root / f"{name}_matching_lineages_cosine.csv", + reference_type=REFERENCE_TYPE, + constraint_type=DTW_CONSTRAINT_TYPE, + band_width_ratio=DTW_BAND_WIDTH_RATIO, +) + +alignment_results[name] = matches +logger.info(f"Found {len(matches)} matches for {name}") +# %% +# Save matches +print(f'Saving matches to {output_root / f"{name}_matching_lineages_cosine.csv"}') +matches["consensus_path"] = str(output_root / f"{name}.pkl") +cytodtw.save_consensus(output_root / f"{name}.pkl") +matches.to_csv(output_root / f"{name}_matching_lineages_cosine.csv", index=False) +# %% +top_matches = matches.head(top_n) + +# Use the new alignment dataframe method instead of manual alignment +alignment_df = cytodtw.create_alignment_dataframe( + top_matches, + consensus_lineage, + alignment_name=ALIGN_TYPE, + reference_type=REFERENCE_TYPE, +) + +logger.info(f"Enhanced DataFrame created with {len(alignment_df)} rows") +logger.info(f"Lineages: {alignment_df['lineage_id'].nunique()} (including consensus)") + +# %% +# Prototype video alignment based on DTW matches + +z_range = slice(0, 1) +initial_yx_patch_size = (192, 192) + +positions = [] +tracks_tables = [] +images_plate = open_ome_zarr(data_path) + +# Load matching positions +print(f"Loading positions for {len(top_matches)} FOV matches...") +matches_found = 0 +for _, pos in images_plate.positions(): + pos_name = pos.zgroup.name + pos_normalized = pos_name.lstrip("/") + + if pos_normalized in top_matches["fov_name"].values: + positions.append(pos) + matches_found += 1 + + # Get ALL tracks for this FOV to ensure TripletDataset has complete access + tracks_df = cytodtw.adata.obs[ + cytodtw.adata.obs["fov_name"] == pos_normalized + ].copy() + + if len(tracks_df) > 0: + tracks_df = tracks_df.dropna(subset=["x", "y"]) + tracks_df["x"] = tracks_df["x"].astype(int) + tracks_df["y"] = tracks_df["y"].astype(int) + tracks_tables.append(tracks_df) + + if matches_found == 1: + processing_channels = pos.channel_names + +print( + f"Loaded {matches_found} positions with {sum(len(df) for df in tracks_tables)} total tracks" +) + +# Create TripletDataset if we have valid positions +if len(positions) > 0 and len(tracks_tables) > 0: + if "processing_channels" not in locals(): + processing_channels = positions[0].channel_names + + # Use all three channels for overlay visualization + selected_channels = processing_channels # Use all available channels + print( + f"Creating TripletDataset with {len(selected_channels)} channels: {selected_channels}" + ) + + dataset = TripletDataset( + positions=positions, + tracks_tables=tracks_tables, + channel_names=selected_channels, + initial_yx_patch_size=initial_yx_patch_size, + z_range=z_range, + fit=False, + predict_cells=False, + include_fov_names=None, + include_track_ids=None, + time_interval=1, + return_negative=False, + ) + print(f"TripletDataset created with {len(dataset.valid_anchors)} valid anchors") +else: + print("Cannot create TripletDataset - no valid positions or tracks") + dataset = None + +# %% +# Get aligned sequences using consolidated function +if dataset is not None: + + def load_images_from_triplet_dataset(fov_name, track_ids): + """Load images from TripletDataset for given FOV and track IDs.""" + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if ( + anchor_row["fov_name"] == fov_name + and anchor_row["track_id"] in track_ids + ): + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning( + f"No matching indices found for FOV {fov_name}, tracks {track_ids}" + ) + return {} + + # Get images and create time mapping + batch_data = dataset.__getitems__(matching_indices) + images = [] + for i in range(len(matching_indices)): + img_data = { + "anchor": batch_data["anchor"][i], + "index": batch_data["index"][i], + } + images.append(img_data) + + images.sort(key=lambda x: x["index"]["t"]) + return {img["index"]["t"]: img for img in images} + + # Filter alignment_df to only aligned rows for loading just the aligned region + alignment_col = f"dtw_{ALIGN_TYPE}_aligned" + aligned_only_df = alignment_df[alignment_df[alignment_col] == True].copy() + + # Use filtered alignment_df since get_aligned_image_sequences expects 'track_id' column + aligned_sequences = get_aligned_image_sequences( + cytodtw_instance=cytodtw, + df=aligned_only_df, + alignment_name=ALIGN_TYPE, + image_loader_fn=load_images_from_triplet_dataset, + max_lineages=top_n, + ) +else: + aligned_sequences = {} + +logger.info(f"Retrieved {len(aligned_sequences)} aligned sequences") +for idx, seq in aligned_sequences.items(): + meta = seq["metadata"] + # Handle both possible keys depending on return structure + images_key = "aligned_images" if "aligned_images" in seq else "concatenated_images" + if images_key in seq and len(seq[images_key]) > 0: + index = seq[images_key][0]["index"] + logger.info( + f"Track id {index['track_id']}: FOV {meta['fov_name']} aligned images, distance={meta.get('distance', meta.get('dtw_distance', 'N/A')):.3f}" + ) + +# %% +# # Load aligned sequences into napari (ALIGNED REGION ONLY) +# # Note: This loads only the aligned portion of the trajectory. +# # For complete trajectories (unaligned + aligned + unaligned), see the concatenated_image_sequences section below. +# if NAPARI and len(aligned_sequences) > 0: +# import numpy as np + +# for idx, seq_data in aligned_sequences.items(): +# # Handle both possible keys depending on return structure +# images_key = "aligned_images" if "aligned_images" in seq_data else "concatenated_images" + +# if images_key not in seq_data or len(seq_data[images_key]) == 0: +# continue + +# aligned_images = seq_data[images_key] +# meta = seq_data["metadata"] +# index = aligned_images[0]["index"] + +# # Stack images into time series (T, C, Z, Y, X) +# image_stack = [] +# for img_sample in aligned_images: +# if img_sample is not None: +# img_tensor = img_sample["anchor"] # Shape should be (Z, C, Y, X) +# img_np = img_tensor.cpu().numpy() +# image_stack.append(img_np) + +# if len(image_stack) > 0: +# # Stack into (T, Z, C, Y, X) or (T, C, Z, Y, X) +# time_series = np.stack(image_stack, axis=0) + +# # Add to napari viewer (prefix with ALIGNED_ to distinguish from full trajectories) +# distance = meta.get('distance', meta.get('dtw_distance', 0.0)) +# layer_name = f"ALIGNED_track_id_{index['track_id']}_FOV_{meta['fov_name']}_dist_{distance:.3f}" +# viewer.add_image( +# time_series, +# name=layer_name, +# contrast_limits=(time_series.min(), time_series.max()), +# ) +# logger.info( +# f"Added {layer_name} with shape {time_series.shape} (aligned region only)" +# ) +# %% +# Enhanced DataFrame was already created above with PCA plotting - skip duplicate +logger.info( + f"{ALIGN_TYPE.capitalize()} aligned timepoints: {alignment_df[f'dtw_{ALIGN_TYPE}_aligned'].sum()}/{len(alignment_df)} ({100*alignment_df[f'dtw_{ALIGN_TYPE}_aligned'].mean():.1f}%)" +) +logger.info(f"Columns: {list(alignment_df.columns)}") + +print("\nSample of enhanced DataFrame:") +sample_df = alignment_df[alignment_df["lineage_id"] != -1].head(10) +display_cols = [ + "lineage_id", + "track_id", + "t", + f"dtw_{ALIGN_TYPE}_aligned", + f"dtw_{ALIGN_TYPE}_consensus_mapping", + "PC1", +] +print(sample_df[display_cols].to_string()) + + +# Plot using the CytoDtw method +fig = cytodtw.plot_individual_lineages( + alignment_df, + alignment_name=ALIGN_TYPE, + feature_columns=["PC1", "PC2", "PC3"], + max_lineages=15, + aligned_linewidth=2.5, + unaligned_linewidth=1.4, + y_offset_step=0, +) + + +# %% +# Create concatenated image sequences using the DataFrame alignment information +# Filter for infection wells only for specific organelles +fov_name_patterns = ["consensus", "B/2"] +filtered_alignment_df = alignment_df[ + alignment_df["fov_name"].str.contains("|".join(fov_name_patterns)) +] +if dataset is not None: + # Define TripletDataset-specific image loader + def load_images_from_triplet_dataset(fov_name, track_ids): + """Load images from TripletDataset for given FOV and track IDs.""" + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if ( + anchor_row["fov_name"] == fov_name + and anchor_row["track_id"] in track_ids + ): + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning( + f"No matching indices found for FOV {fov_name}, tracks {track_ids}" + ) + return {} + + # Get images and create time mapping + batch_data = dataset.__getitems__(matching_indices) + images = [] + for i in range(len(matching_indices)): + img_data = { + "anchor": batch_data["anchor"][i], + "index": batch_data["index"][i], + } + images.append(img_data) + + images.sort(key=lambda x: x["index"]["t"]) + return {img["index"]["t"]: img for img in images} + + concatenated_image_sequences = get_aligned_image_sequences( + cytodtw_instance=cytodtw, + df=filtered_alignment_df, + alignment_name=ALIGN_TYPE, + image_loader_fn=load_images_from_triplet_dataset, + max_lineages=30, + ) +else: + print("Skipping image sequence creation - no valid dataset available") + concatenated_image_sequences = {} + +# Load concatenated sequences into napari (includes unaligned + aligned + unaligned timepoints) +if NAPARI and dataset is not None and len(concatenated_image_sequences) > 0: + import numpy as np + + for lineage_id, seq_data in concatenated_image_sequences.items(): + concatenated_images = seq_data["concatenated_images"] + meta = seq_data["metadata"] + aligned_length = seq_data["aligned_length"] + unaligned_length = seq_data["unaligned_length"] + + if len(concatenated_images) == 0: + continue + + # Stack images into time series (T, C, Z, Y, X) + image_stack = [] + for img_sample in concatenated_images: + if img_sample is not None: + img_tensor = img_sample["anchor"] + img_np = img_tensor.cpu().numpy() + image_stack.append(img_np) + + if len(image_stack) > 0: + time_series = np.stack(image_stack, axis=0) + n_channels = time_series.shape[1] + + logger.info( + f"Processing lineage {lineage_id} with {n_channels} channels, shape {time_series.shape}" + ) + logger.info( + f" Aligned length: {aligned_length}, Unaligned length: {unaligned_length}, Total: {len(image_stack)}" + ) + + # Set up colormap based on number of channels + # FIXME: This is hardcoded for specific datasets - improve logic as needed + if n_channels == 2: + colormap = ["green", "magenta"] + elif n_channels == 3: + colormap = ["gray", "green", "magenta"] + else: + colormap = ["gray"] * n_channels # Default fallback + + # Add each channel as a separate layer in napari + for channel_idx in range(n_channels): + channel_data = time_series[:, channel_idx, :, :, :] + channel_name = ( + processing_channels[channel_idx] + if channel_idx < len(processing_channels) + else f"ch{channel_idx}" + ) + # Indicate that this includes full trajectory (unaligned + aligned + unaligned) + layer_name = f"FULL_track_id_{meta['track_ids'][0]}_FOV_{meta['fov_name']}_dist_{meta['dtw_distance']:.3f}_{channel_name}" + + viewer.add_image( + channel_data, + name=layer_name, + contrast_limits=(channel_data.min(), channel_data.max()), + colormap=colormap[channel_idx], + blending="additive", + ) + logger.debug( + f"Added {channel_name} channel for lineage {lineage_id} with shape {channel_data.shape}" + ) +# %% +# Get the segmentation based features and compute per-cell aggregates +segmentation_features_df = pd.read_csv(segmentation_features_path) +segmentation_features_df["fov_name"] = segmentation_features_df["fov_name"].str.lstrip( + "/" +) + +# Compute per-cell mitochondria population statistics +segs_population_features = [] +for (fov, track, t), group in segmentation_features_df.groupby( + ["fov_name", "track_id", "t"] +): + stats = { + "fov_name": fov, + "track_id": track, + "t": t, + # Count metrics + "segs_count": len(group), + # Area/volume metrics + "segs_total_area": group["area"].sum(), + "segs_mean_area": group["area"].mean(), + "segs_std_area": group["area"].std(), + "segs_median_area": group["area"].median(), + # Shape metrics + "segs_mean_eccentricity": group["eccentricity"].mean(), + "segs_std_eccentricity": group["eccentricity"].std(), + "segs_mean_solidity": group["solidity"].mean(), + "segs_std_solidity": group["solidity"].std(), + "segs_circularity_mean": group["circularity"].mean(), + "segs_circularity_std": group["circularity"].std(), + # Intensity metrics + "segs_mean_intensity": group["mean_intensity"].mean(), + "segs_std_intensity_across_mitos": group["mean_intensity"].std(), + "segs_mean_max_intensity": group["max_intensity"].mean(), + # Texture metrics (aggregated) + "segs_mean_texture_contrast": group["texture_contrast"].mean(), + "segs_mean_texture_homogeneity": group["texture_homogeneity"].mean(), + # Frangi filter metrics (tubularity/network structure) + "segs_mean_frangi_mean": group["frangi_mean_intensity"].mean(), + "segs_mean_frangi_std": group["frangi_std_intensity"].mean(), + # Shape diversity (coefficient of variation) + "segs_area_cv": group["area"].std() / (group["area"].mean() + 1e-6), + "segs_eccentricity_cv": group["eccentricity"].std() + / (group["eccentricity"].mean() + 1e-6), + "segs_solidity_cv": group["solidity"].std() / (group["solidity"].mean() + 1e-6), + "segs_frangi_cv": group["frangi_mean_intensity"].std() + / (group["frangi_mean_intensity"].mean() + 1e-6), + "segs_circularity_cv": group["circularity"].std() + / (group["circularity"].mean() + 1e-6), + } + segs_population_features.append(stats) + +segs_population_df = pd.DataFrame(segs_population_features) + +logger.info( + f"Computed mitochondria population features for {len(segs_population_df)} (fov, track, t) combinations" +) +logger.info( + f"Mitochondria population feature columns: {list(segs_population_df.columns)}" +) + +# Load the computed features and PCs +computed_features_df = pd.read_csv(computed_features_path) +# Rename time_point to t for merging +computed_features_df = computed_features_df.rename(columns={"time_point": "t"}) +# Remove the first forward slash from the fov_name +computed_features_df["fov_name"] = computed_features_df["fov_name"].str.lstrip("/") + +# Population-based normalization to measure conserved remodeling states across cells +cf_of_interests = ["homogeneity", "contrast", "edge_density", "organelle_volume"] +percentile = 10 + +for cf in cf_of_interests: + # Use population-wide baseline (same for all cells) to preserve absolute differences + population_baseline = computed_features_df[cf].quantile(percentile / 100) + computed_features_df[f"normalized_{cf}"] = ( + computed_features_df[cf] - population_baseline + ) / (population_baseline + 1e-6) +# %% +# Merge the computed features and mitochondria population features +combined_features_df = computed_features_df.merge( + segs_population_df, on=["fov_name", "track_id", "t"], how="left" +) + +# Add PCs from each channel to the combined features +for channel, adata in ad_features.items(): + # Create a temporary dataframe with PCs from this channel + pcs_df = adata.obs[["fov_name", "track_id", "t"]].copy() + + # Add PC columns with channel prefix + for i in range(n_pca_components): + pcs_df[f"{channel}_PC{i+1}"] = adata.obsm["X_pca"][:, i] + + # Merge with combined features + combined_features_df = combined_features_df.merge( + pcs_df, + on=["fov_name", "track_id", "t"], + how="left", + ) + logger.info( + f"Added {n_pca_components} PCs from {channel} channel to combined features" + ) + +# %% +# Create dataframe with uninfected cells (B/1) - no alignment needed +# Start from original tracking data for B/1 wells only +uninfected_features_df = cytodtw.adata.obs[ + cytodtw.adata.obs["fov_name"].str.contains("B/1") +].merge( + combined_features_df, + on=["fov_name", "track_id", "t", "x", "y"], + how="left", +) + +logger.info( + f"Created uninfected_features_df with B/1 wells. Shape: {uninfected_features_df.shape}" +) +logger.info(f"Wells included: {sorted(uninfected_features_df['fov_name'].unique())}") + +# %% +# Create filtered dataframe with only B/2 (infected) wells for specific analysis +align_n_comp_feat_df = filtered_alignment_df.merge( + combined_features_df, + on=["fov_name", "track_id", "t", "x", "y"], + how="left", +) + +logger.info( + f"Created align_n_comp_feat_df with B/2 wells only. Shape: {align_n_comp_feat_df.shape}" +) + +all_infected_tracking_df = cytodtw.adata.obs[ + cytodtw.adata.obs["fov_name"].str.contains("B/2") +][["fov_name", "track_id", "t", "x", "y"]].copy() + +all_infected_features_df = all_infected_tracking_df.merge( + combined_features_df, + on=["fov_name", "track_id", "t", "x", "y"], + how="left", +) + +# %% +fig = cytodtw.plot_individual_lineages( + align_n_comp_feat_df, + alignment_name=ALIGN_TYPE, + feature_columns=[ + "sensor_PC1", + "homogeneity", + "contrast", + "edge_density", + "segs_count", + "segs_total_area", + "segs_mean_area", + ], + max_lineages=8, + aligned_linewidth=2.5, + unaligned_linewidth=1.4, + y_offset_step=0.0, +) +# %% +# Heatmap showing all tracks +fig = cytodtw.plot_global_trends( + align_n_comp_feat_df, + alignment_name=ALIGN_TYPE, + plot_type="heatmap", + cmap="RdBu", + figsize=(12, 12), + feature_columns=[ + # "sensor_PC1", + # "sensor_PC2", + # "sensor_PC3", + # "phase_PC1", + # "phase_PC2", + # "phase_PC3", + "organelle_PC1", + "organelle_PC2", + "organelle_PC3", + # "homogeneity", + "edge_density", + "segs_count", + "segs_total_area", + "segs_mean_area", + "segs_circularity_mean", + "segs_mean_frangi_mean", + ], + max_lineages=8, +) + + +# %% +# Unified trajectory aggregation function +def aggregate_trajectory_by_time( + df: pd.DataFrame, + feature_columns: list, +) -> pd.DataFrame: + """ + Unified function to aggregate trajectory data by timepoint. + + Parameters + ---------- + df : pd.DataFrame + Pre-filtered dataframe with features and 't' column + feature_columns : list + Features to aggregate + + Returns + ------- + pd.DataFrame + Aggregated trajectory with columns: t, n_cells, {feature}_{stat} + where stat is one of: mean, median, std, q25, q75 + """ + aggregated_data = [] + + for t in sorted(df["t"].unique()): + time_slice = df[df["t"] == t] + + row_data = { + "t": t, + "n_cells": len(time_slice), + } + + # Compute statistics for each feature + for feature in feature_columns: + if feature not in time_slice.columns: + # Feature doesn't exist - set all stats to NaN + row_data[f"{feature}_median"] = np.nan + row_data[f"{feature}_mean"] = np.nan + row_data[f"{feature}_std"] = np.nan + row_data[f"{feature}_q25"] = np.nan + row_data[f"{feature}_q75"] = np.nan + continue + + values = time_slice[feature].dropna() + + if len(values) == 0: + # No valid values - set all stats to NaN + row_data[f"{feature}_median"] = np.nan + row_data[f"{feature}_mean"] = np.nan + row_data[f"{feature}_std"] = np.nan + row_data[f"{feature}_q25"] = np.nan + row_data[f"{feature}_q75"] = np.nan + continue + + # Compute all statistics + row_data[f"{feature}_mean"] = values.mean() + row_data[f"{feature}_median"] = values.median() + row_data[f"{feature}_std"] = values.std() + row_data[f"{feature}_q25"] = values.quantile(0.25) + row_data[f"{feature}_q75"] = values.quantile(0.75) + + aggregated_data.append(row_data) + + result_df = pd.DataFrame(aggregated_data) + + logger.info( + f"Aggregated {len(result_df)} timepoints from {len(df)} total observations" + ) + + return result_df + + +# Select features to compute common response for +common_response_features = [ + "organelle_PC1", + "organelle_PC2", + "organelle_PC3", + "phase_PC1", + "phase_PC2", + "phase_PC3", + "edge_density", + # "organelle_volume", + "segs_count", + # "segs_total_area", + "segs_mean_area", + "segs_mean_eccentricity", + # "segs_mean_texture_contrast", + "segs_mean_frangi_mean", + "segs_circularity_mean", + "segs_circularity_cv", + "segs_eccentricity_cv", + "segs_area_cv", +] + +# Compute common response from top N aligned cells +# First, select top N lineages by DTW distance +top_n_cells = 10 +alignment_col = f"dtw_{ALIGN_TYPE}_aligned" + +# Get aligned cells only +aligned_cells = align_n_comp_feat_df[align_n_comp_feat_df[alignment_col] == True].copy() + +# Select top N lineages by distance +if "distance" in aligned_cells.columns and "lineage_id" in aligned_cells.columns: + # Drop duplicates to get one row per lineage, then select N with smallest distance + top_lineages = ( + aligned_cells.drop_duplicates("lineage_id") + .nsmallest(top_n_cells, "distance")["lineage_id"] + .tolist() + ) + logger.info( + f"Selected top {len(top_lineages)} lineages by DTW distance: {top_lineages}" + ) +else: + top_lineages = aligned_cells["lineage_id"].unique()[:top_n_cells].tolist() + logger.info(f"Selected {len(top_lineages)} lineages (no distance info)") + +# Filter to top lineages - include ALL timepoints (aligned and unaligned) +top_cells_df = align_n_comp_feat_df[ + align_n_comp_feat_df["lineage_id"].isin(top_lineages) +].copy() + +logger.info( + f"Filtered to {len(top_cells_df)} observations from {len(top_lineages)} lineages" +) + +# Aggregate using unified function +common_response_df = aggregate_trajectory_by_time( + top_cells_df, + feature_columns=common_response_features, +) + +# Compute infection timepoint and aligned region for visualization +# These are derived from the alignment metadata BEFORE aggregation +infection_timepoint = None +aligned_region_bounds = None + +if alignment_col in top_cells_df.columns: + aligned_mask = top_cells_df[alignment_col] == True + if aligned_mask.any(): + aligned_subset = top_cells_df[aligned_mask] + + # Aligned region: compute median aligned span across lineages + # Each lineage has an aligned window; we want the typical/median window + lineage_aligned_regions = [] + for lineage_id in aligned_subset["lineage_id"].unique(): + lineage_aligned = aligned_subset[aligned_subset["lineage_id"] == lineage_id] + lineage_times = sorted(lineage_aligned["t"].unique()) + if len(lineage_times) > 0: + lineage_aligned_regions.append((lineage_times[0], lineage_times[-1])) + + if len(lineage_aligned_regions) > 0: + # Use median start and end across all lineages + starts = [r[0] for r in lineage_aligned_regions] + ends = [r[1] for r in lineage_aligned_regions] + aligned_region_bounds = (int(np.median(starts)), int(np.median(ends))) + + # Infection timepoint: propagate consensus annotations via DTW alignment + consensus_mapping_col = f"dtw_{ALIGN_TYPE}_consensus_mapping" + if consensus_annotations and "infected" in consensus_annotations: + if consensus_mapping_col in aligned_subset.columns: + # For each aligned cell, look up its annotation from consensus + # consensus_mapping tells us which position in the consensus pattern + aligned_subset_copy = aligned_subset.copy() + + # Map consensus position to annotation + def get_annotation(consensus_pos): + idx = int(round(consensus_pos)) + if 0 <= idx < len(consensus_annotations): + return consensus_annotations[idx] + return None + + aligned_subset_copy["propagated_annotation"] = aligned_subset_copy[ + consensus_mapping_col + ].apply(get_annotation) + + # Find first appearance of "infected" for each lineage + first_infected_times = [] + for lineage_id in aligned_subset_copy["lineage_id"].unique(): + lineage_data = aligned_subset_copy[ + aligned_subset_copy["lineage_id"] == lineage_id + ].sort_values("t") + infected_rows = lineage_data[ + lineage_data["propagated_annotation"] == "infected" + ] + if len(infected_rows) > 0: + first_infected_times.append(infected_rows.iloc[0]["t"]) + + if len(first_infected_times) > 0: + # Use mean of first infection timepoints across lineages + infection_timepoint = int(np.mean(first_infected_times)) + +logger.info(f"Infection timepoint: {infection_timepoint}") +logger.info(f"Aligned region: {aligned_region_bounds}") + +# Debug: check if consensus_annotations is available +if consensus_annotations: + logger.info( + f"Consensus annotations available: 'infected' at position {consensus_annotations.index('infected') if 'infected' in consensus_annotations else 'NOT FOUND'}" + ) +else: + logger.warning("consensus_annotations is None or empty!") + +# Debug: check what columns are available in top_cells_df +logger.info( + f"Available alignment columns in top_cells_df: {[c for c in top_cells_df.columns if 'dtw' in c.lower() or 'align' in c.lower()]}" +) + + +# %% +# Compute uninfected baseline from control wells (B/1) +# Filter to uninfected FOVs with sufficient track length +uninfected_fov_pattern = "B/1" +min_track_length = 20 + +uninfected_filtered = uninfected_features_df[ + uninfected_features_df["fov_name"].str.contains(uninfected_fov_pattern) +].copy() + +# Filter by track length +track_lengths = uninfected_filtered.groupby(["fov_name", "track_id"]).size() +valid_tracks = track_lengths[track_lengths >= min_track_length].index + +uninfected_filtered = uninfected_filtered[ + uninfected_filtered.set_index(["fov_name", "track_id"]).index.isin(valid_tracks) +].reset_index(drop=True) + +logger.info( + f"Filtered uninfected cells: {len(valid_tracks)} tracks with >= {min_track_length} timepoints" +) +logger.info(f"Total observations: {len(uninfected_filtered)}") + +# Aggregate using unified function +uninfected_baseline_df = aggregate_trajectory_by_time( + uninfected_filtered, + feature_columns=common_response_features, +) + +logger.info(f"Uninfected baseline shape: {uninfected_baseline_df.shape}") + + +# %% +# Compute global average of ALL infected cells (B/2) without alignment +# Filter to infected FOVs with sufficient track length +infected_fov_pattern = "B/2" + +global_infected_filtered = all_infected_features_df[ + all_infected_features_df["fov_name"].str.contains(infected_fov_pattern) +].copy() + +if len(global_infected_filtered) == 0: + logger.warning(f"No cells found matching pattern '{infected_fov_pattern}'") + global_infected_df = pd.DataFrame() +else: + # Filter by track length - use (fov_name, track_id) since lineage_id may not exist + track_lengths = global_infected_filtered.groupby(["fov_name", "track_id"]).size() + valid_tracks = track_lengths[track_lengths >= min_track_length].index + + global_infected_filtered = global_infected_filtered[ + global_infected_filtered.set_index(["fov_name", "track_id"]).index.isin( + valid_tracks + ) + ].reset_index(drop=True) + + n_tracks = len(valid_tracks) + logger.info( + f"Filtered global infected: {n_tracks} tracks with >= {min_track_length} timepoints" + ) + logger.info(f"Total observations: {len(global_infected_filtered)}") + + # Aggregate using unified function + global_infected_df = aggregate_trajectory_by_time( + global_infected_filtered, + feature_columns=common_response_features, + ) + + logger.info(f"Global infected average shape: {global_infected_df.shape}") + + +# %% +# Compute baseline normalization values BEFORE plotting +def compute_baseline_normalization_values( + infected_df: pd.DataFrame, + uninfected_df: pd.DataFrame, + global_infected_df: pd.DataFrame, + feature_columns: list, + n_baseline_timepoints: int = 10, +): + """ + Compute baseline normalization values from the first n timepoints of each trajectory. + + Each trajectory (infected, uninfected, global infected) is normalized by its own + baseline computed from its first n timepoints. This allows comparison of relative + changes across different trajectories. + + Parameters + ---------- + infected_df : pd.DataFrame + Infected common response aggregated dataframe with 't' column + uninfected_df : pd.DataFrame + Uninfected baseline aggregated dataframe with 't' column + global_infected_df : pd.DataFrame + Global infected average aggregated dataframe with 't' column + feature_columns : list + Features to compute baselines for + n_baseline_timepoints : int + Number of initial timepoints to use as baseline (default: 10) + + Returns + ------- + dict + Baseline values for each feature and trajectory type + """ + # Compute baseline normalization values from first n timepoints of each trajectory + baseline_values = {} + + logger.info( + f"Computing baseline from first {n_baseline_timepoints} timepoints of each trajectory" + ) + + for feature in feature_columns: + median_col = f"{feature}_median" + + # Initialize baseline dict for this feature + baseline_values[feature] = {} + + # Infected trajectory baseline - use first n timepoints + if median_col in infected_df.columns: + sorted_times = sorted(infected_df["t"].unique()) + baseline_times = sorted_times[:n_baseline_timepoints] + baseline_mask = infected_df["t"].isin(baseline_times) + baseline_vals = infected_df.loc[baseline_mask, median_col].dropna() + if len(baseline_vals) > 0: + baseline_values[feature]["infected"] = baseline_vals.mean() + logger.debug( + f" {feature} infected baseline: {baseline_vals.mean():.3f} (from times {min(baseline_times)}-{max(baseline_times)})" + ) + else: + baseline_values[feature]["infected"] = None + + # Uninfected trajectory baseline - use first n timepoints + if median_col in uninfected_df.columns: + sorted_times = sorted(uninfected_df["t"].unique()) + baseline_times = sorted_times[:n_baseline_timepoints] + baseline_mask = uninfected_df["t"].isin(baseline_times) + baseline_vals = uninfected_df.loc[baseline_mask, median_col].dropna() + if len(baseline_vals) > 0: + baseline_values[feature]["uninfected"] = baseline_vals.mean() + logger.debug( + f" {feature} uninfected baseline: {baseline_vals.mean():.3f} (from times {min(baseline_times)}-{max(baseline_times)})" + ) + else: + baseline_values[feature]["uninfected"] = None + + # Global infected trajectory baseline - use first n timepoints + if global_infected_df is not None and median_col in global_infected_df.columns: + sorted_times = sorted(global_infected_df["t"].unique()) + baseline_times = sorted_times[:n_baseline_timepoints] + baseline_mask = global_infected_df["t"].isin(baseline_times) + baseline_vals = global_infected_df.loc[baseline_mask, median_col].dropna() + if len(baseline_vals) > 0: + baseline_values[feature]["global"] = baseline_vals.mean() + logger.debug( + f" {feature} global infected baseline: {baseline_vals.mean():.3f} (from times {min(baseline_times)}-{max(baseline_times)})" + ) + else: + baseline_values[feature]["global"] = None + + return baseline_values + + +# Compute baseline values +baseline_normalization_values = compute_baseline_normalization_values( + common_response_df, + uninfected_baseline_df, + global_infected_df, + feature_columns=common_response_features, + n_baseline_timepoints=int(np.floor(min_track_length * 0.75)), +) + +logger.info( + f"Computed baseline values for {len(baseline_normalization_values)} features" +) + + +# %% +def plot_binned_period_comparison( + infected_df: pd.DataFrame, + uninfected_df: pd.DataFrame, + feature_columns: list, + infection_time: int, + baseline_values: dict = None, + global_infected_df: pd.DataFrame = None, + output_root: Path = None, + figsize=(18, 14), + plot_type: str = "line", + add_stats: bool = True, +): + """ + Plot binned period comparison showing fold-change across biological phases. + + Creates line plots or bar plots comparing infected vs uninfected trajectories across + biologically meaningful periods (baseline, peri-infection, fragmentation, late/death). + Each period's value is normalized to the baseline period to show fold-change. + Includes statistical testing to identify significant differences. + + Parameters + ---------- + infected_df : pd.DataFrame + Infected common response aggregated dataframe + uninfected_df : pd.DataFrame + Uninfected baseline aggregated dataframe + feature_columns : list + Features to plot + infection_time : int + Infection timepoint (used to define period boundaries) + baseline_values : dict, optional + Pre-computed baseline normalization values + global_infected_df : pd.DataFrame, optional + Global average of all infected cells + output_root : Path, optional + Directory to save output figure + figsize : tuple + Figure size + plot_type : str + 'line' for connected line plots or 'bar' for grouped bar plots (default: 'line') + add_stats : bool + If True, perform statistical testing and mark significant differences (default: True) + """ + from scipy.stats import ttest_ind + + # Define biologically meaningful periods relative to infection + periods = { + "Baseline": (infection_time - 10, infection_time), + "Peri-infection": (infection_time - 2, infection_time + 3), + "Fragmentation": (infection_time + 5, infection_time + 15), + "Late/Death": (infection_time + 15, infection_time + 25), + } + + period_names = list(periods.keys()) + n_periods = len(periods) + + n_features = len(feature_columns) + ncols = 3 + nrows = int(np.ceil(n_features / ncols)) + + fig, axes = plt.subplots(nrows, ncols, figsize=figsize) + axes = axes.flatten() if n_features > 1 else [axes] + + # Colorblind-friendly palette + colors = { + "uninfected": "#1f77b4", # blue + "infected": "#ff7f0e", # orange + "global": "#2ca02c", # green + } + + # Store statistical results for logging + stats_results = {} + + for idx, feature in enumerate(feature_columns): + ax = axes[idx] + + median_col = f"{feature}_median" + + # Check if feature exists + if ( + median_col not in infected_df.columns + or median_col not in uninfected_df.columns + ): + ax.text(0.5, 0.5, f"{feature}\nno data", ha="center", va="center") + ax.set_title(feature) + continue + + # Check if CV/SEM feature (no normalization) + is_cv_feature = feature.endswith("_cv") or feature.endswith("_sem") + + # Compute values for each period + period_values = {"uninfected": [], "infected": [], "global": []} + period_errors = {"uninfected": [], "infected": [], "global": []} + + # Baseline period values for normalization (use mean within baseline period) + baseline_period = periods["Baseline"] + + def compute_baseline_value(df, feature_col): + """Compute baseline value from baseline period.""" + mask = (df["t"] >= baseline_period[0]) & (df["t"] <= baseline_period[1]) + values = df.loc[mask, feature_col].dropna() + return values.mean() if len(values) > 0 else None + + # Compute baseline for this feature from each trajectory + uninfected_baseline = compute_baseline_value(uninfected_df, median_col) + infected_baseline = compute_baseline_value(infected_df, median_col) + global_baseline = None + if global_infected_df is not None and median_col in global_infected_df.columns: + global_baseline = compute_baseline_value(global_infected_df, median_col) + + # For each period, compute aggregate value and normalize + for period_name, (t_start, t_end) in periods.items(): + # Uninfected + mask = (uninfected_df["t"] >= t_start) & (uninfected_df["t"] <= t_end) + values = uninfected_df.loc[mask, median_col].dropna() + if len(values) > 0: + mean_val = values.mean() + std_val = values.std() + + # Normalize to baseline if not CV feature + if not is_cv_feature and uninfected_baseline is not None: + mean_val = mean_val / uninfected_baseline + std_val = std_val / (np.abs(uninfected_baseline) + 1e-6) + + period_values["uninfected"].append(mean_val) + period_errors["uninfected"].append(std_val) + else: + period_values["uninfected"].append(np.nan) + period_errors["uninfected"].append(np.nan) + + # Infected + mask = (infected_df["t"] >= t_start) & (infected_df["t"] <= t_end) + values = infected_df.loc[mask, median_col].dropna() + if len(values) > 0: + mean_val = values.mean() + std_val = values.std() + + if not is_cv_feature and infected_baseline is not None: + mean_val = mean_val / infected_baseline + std_val = std_val / (np.abs(infected_baseline) + 1e-6) + + period_values["infected"].append(mean_val) + period_errors["infected"].append(std_val) + else: + period_values["infected"].append(np.nan) + period_errors["infected"].append(np.nan) + + # Global infected + if ( + global_infected_df is not None + and median_col in global_infected_df.columns + ): + mask = (global_infected_df["t"] >= t_start) & ( + global_infected_df["t"] <= t_end + ) + values = global_infected_df.loc[mask, median_col].dropna() + if len(values) > 0: + mean_val = values.mean() + std_val = values.std() + + if not is_cv_feature and global_baseline is not None: + mean_val = mean_val / global_baseline + std_val = std_val / (np.abs(global_baseline) + 1e-6) + + period_values["global"].append(mean_val) + period_errors["global"].append(std_val) + else: + period_values["global"].append(np.nan) + period_errors["global"].append(np.nan) + + # Statistical testing between infected and uninfected at each period + p_values = [] + if add_stats: + stats_results[feature] = {} + for period_name, (t_start, t_end) in periods.items(): + # Get raw values (not aggregated medians) for statistical testing + uninfected_mask = (uninfected_df["t"] >= t_start) & ( + uninfected_df["t"] <= t_end + ) + infected_mask = (infected_df["t"] >= t_start) & ( + infected_df["t"] <= t_end + ) + + uninfected_vals = uninfected_df.loc[ + uninfected_mask, median_col + ].dropna() + infected_vals = infected_df.loc[infected_mask, median_col].dropna() + + if len(uninfected_vals) >= 3 and len(infected_vals) >= 3: + _, p_val = ttest_ind(uninfected_vals, infected_vals) + p_values.append(p_val) + stats_results[feature][period_name] = p_val + else: + p_values.append(np.nan) + stats_results[feature][period_name] = np.nan + + x = np.arange(n_periods) + + if plot_type == "line": + # Line plot with error bars + ax.errorbar( + x, + period_values["uninfected"], + yerr=period_errors["uninfected"], + label="Uninfected (B/1)", + color=colors["uninfected"], + marker="o", + markersize=8, + linewidth=2.5, + capsize=4, + capthick=2, + ) + ax.errorbar( + x, + period_values["infected"], + yerr=period_errors["infected"], + label="Infected top-N (B/2)", + color=colors["infected"], + marker="s", + markersize=8, + linewidth=2.5, + capsize=4, + capthick=2, + ) + + if global_infected_df is not None: + ax.errorbar( + x, + period_values["global"], + yerr=period_errors["global"], + label="All B/2 cells", + color=colors["global"], + marker="^", + markersize=7, + linewidth=2, + linestyle="--", + capsize=4, + capthick=1.5, + alpha=0.8, + ) + + # Mark significant differences with asterisks + if add_stats and len(p_values) > 0: + y_max = max( + [ + max( + [v for v in period_values["uninfected"] if not np.isnan(v)] + ), + max([v for v in period_values["infected"] if not np.isnan(v)]), + ] + ) + y_offset = 0.1 * (y_max - 1.0) if not is_cv_feature else 0.1 * y_max + + for i, p_val in enumerate(p_values): + if not np.isnan(p_val): + # Determine significance level + if p_val < 0.001: + marker = "***" + elif p_val < 0.01: + marker = "**" + elif p_val < 0.05: + marker = "*" + else: + marker = "ns" + + if marker != "ns": + # Position text above the higher of the two values + max_val = max( + period_values["uninfected"][i], + period_values["infected"][i], + ) + ax.text( + x[i], + max_val + y_offset, + marker, + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + color="black", + ) + + else: # bar plot + width = 0.25 + + ax.bar( + x - width, + period_values["uninfected"], + width, + label="Uninfected (B/1)", + color=colors["uninfected"], + yerr=period_errors["uninfected"], + capsize=3, + ) + ax.bar( + x, + period_values["infected"], + width, + label="Infected top-N (B/2)", + color=colors["infected"], + yerr=period_errors["infected"], + capsize=3, + ) + + if global_infected_df is not None: + ax.bar( + x + width, + period_values["global"], + width, + label="All B/2 cells", + color=colors["global"], + yerr=period_errors["global"], + capsize=3, + alpha=0.8, + ) + + # Mark significant differences with brackets + if add_stats and len(p_values) > 0: + y_max = max( + [ + max( + [v for v in period_values["uninfected"] if not np.isnan(v)] + ), + max([v for v in period_values["infected"] if not np.isnan(v)]), + ] + ) + y_offset = 0.1 * (y_max - 1.0) if not is_cv_feature else 0.1 * y_max + + for i, p_val in enumerate(p_values): + if not np.isnan(p_val) and p_val < 0.05: + if p_val < 0.001: + marker = "***" + elif p_val < 0.01: + marker = "**" + else: + marker = "*" + + max_val = max( + period_values["uninfected"][i], period_values["infected"][i] + ) + ax.text( + x[i], + max_val + y_offset, + marker, + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + ) + + # Add horizontal line at 1.0 (no change from baseline) + if not is_cv_feature: + ax.axhline(1.0, color="gray", linestyle="--", linewidth=1, alpha=0.5) + + ax.set_xlabel("Period") + if is_cv_feature: + ax.set_ylabel(f"{feature}\n(raw value)") + else: + ax.set_ylabel(f"{feature}\n(fold-change from baseline)") + ax.set_title(feature) + ax.set_xticks(x) + ax.set_xticklabels(period_names, rotation=45, ha="right") + # ax.legend(loc="best", fontsize=7) + ax.grid(True, alpha=0.3, axis="y") + + # Hide unused subplots + for idx in range(n_features, len(axes)): + axes[idx].axis("off") + + # Create title with statistical note + title = "Binned Period Comparison: Fold-Change Across Infection Phases" + if add_stats: + title += "\n(* p<0.05, ** p<0.01, *** p<0.001)" + + plt.suptitle( + title, + fontsize=14, + y=1.00, + ) + plt.tight_layout() + + if output_root is not None: + save_path = output_root / f"binned_period_comparison_{ALIGN_TYPE}.png" + plt.savefig(save_path, dpi=150, bbox_inches="tight") + logger.info(f"Saved binned period comparison to {save_path}") + + plt.show() + + # Log summary in markdown format + logger.info("\n## Binned Period Comparison Summary") + logger.info(f"**Infection timepoint:** {infection_time}") + logger.info("\n### Period Definitions") + for period_name, (t_start, t_end) in periods.items(): + logger.info(f"- **{period_name}:** t={t_start} to t={t_end}") + + if add_stats and len(stats_results) > 0: + logger.info("\n### Statistical Significance (t-tests)") + logger.info( + "Comparing infected top-N vs uninfected at each period. Significance levels: * p<0.05, ** p<0.01, *** p<0.001\n" + ) + + # Create markdown table + logger.info(f"| Feature | {' | '.join(period_names)} |") + logger.info(f"|---------|{'---------|-' * (len(period_names) - 1)}---------|\n") + + for feature, period_results in stats_results.items(): + sig_markers = [] + for period_name in period_names: + p_val = period_results.get(period_name, np.nan) + if np.isnan(p_val): + sig_markers.append("N/A") + elif p_val < 0.001: + sig_markers.append(f"***({p_val:.3e})") + elif p_val < 0.01: + sig_markers.append(f"**({p_val:.3f})") + elif p_val < 0.05: + sig_markers.append(f"*({p_val:.3f})") + else: + sig_markers.append(f"ns({p_val:.3f})") + + logger.info(f"| {feature} | {' | '.join(sig_markers)} |") + + logger.info("\nns = not significant (p >= 0.05)") + + +# %% +# Compare infected (aligned) vs uninfected (baseline) trajectories +def plot_infected_vs_uninfected_comparison( + infected_df: pd.DataFrame, + uninfected_df: pd.DataFrame, + feature_columns: list, + baseline_values: dict = None, + infection_time: int = None, + aligned_region: tuple = None, + figsize=(18, 14), + n_consecutive_divergence: int = 5, + global_infected_df: pd.DataFrame = None, + normalize_to_baseline: bool = True, +): + """ + Plot comparison of infected (DTW-aligned) vs uninfected (baseline) trajectories. + + Shows where infected cells diverge from normal behavior (crossover points). + Includes infection timepoint marker, aligned region highlighting, and consecutive divergence detection. + + NOTE: Features ending in '_cv' or '_sem' are plotted as raw values without baseline + normalization, since CV and SEM are already relative/uncertainty metrics. + + Parameters + ---------- + infected_df : pd.DataFrame + Infected common response with 'time', 'is_aligned', and 'aligned_time' columns + uninfected_df : pd.DataFrame + Uninfected baseline + feature_columns : list + Features to plot + baseline_values : dict, optional + Pre-computed baseline normalization values for each feature and trajectory + infection_time : int, optional + Pre-computed infection timepoint in raw time + aligned_region : tuple, optional + Pre-computed aligned region boundaries (start, end) + figsize : tuple + Figure size + n_consecutive_divergence : int + Number of consecutive timepoints required to confirm divergence (default: 5) + global_infected_df : pd.DataFrame, optional + Global average of ALL infected cells without alignment (if provided, will be plotted) + normalize_to_baseline : bool + If True, normalize all trajectories to pre-infection baseline showing fold-change (default: True). + CV and SEM features (ending in '_cv' or '_sem') are always plotted as raw values regardless of this setting. + """ + from scipy.interpolate import interp1d + + n_features = len(feature_columns) + ncols = 3 + nrows = int(np.ceil(n_features / ncols)) + + fig, axes = plt.subplots(nrows, ncols, figsize=figsize) + axes = axes.flatten() if n_features > 1 else [axes] + + # Colorblind-friendly palette: blue (uninfected) vs orange (infected) + uninfected_color = "#1f77b4" # blue + infected_color = "#ff7f0e" # orange + + # Use pre-computed baseline values or initialize empty dict + if baseline_values is None: + baseline_values = {} + logger.warning("No baseline_values provided - plots will not be normalized") + + for idx, feature in enumerate(feature_columns): + ax = axes[idx] + + median_col = f"{feature}_median" + q25_col = f"{feature}_q25" + q75_col = f"{feature}_q75" + + # Check if data exists + if ( + median_col not in infected_df.columns + or median_col not in uninfected_df.columns + ): + ax.text(0.5, 0.5, f"{feature}\nno data", ha="center", va="center") + ax.set_title(feature) + continue + + # Highlight aligned region first (background layer) + if aligned_region is not None: + ax.axvspan( + aligned_region[0], + aligned_region[1], + alpha=0.1, + color="gray", + label="Aligned region", + zorder=0, + ) + + # Plot uninfected baseline + uninfected_time = uninfected_df["t"].values + uninfected_median = uninfected_df[median_col].values + uninfected_q25 = uninfected_df[q25_col].values + uninfected_q75 = uninfected_df[q75_col].values + + # Check if this is a CV or SEM feature (skip normalization for relative/uncertainty metrics) + is_cv_feature = feature.endswith("_cv") or feature.endswith("_sem") + + # Apply baseline normalization if requested (but skip for CV features) + if ( + normalize_to_baseline + and not is_cv_feature + and feature in baseline_values + and baseline_values[feature]["uninfected"] is not None + ): + baseline = baseline_values[feature]["uninfected"] + uninfected_median = (uninfected_median - baseline) / ( + np.abs(baseline) + 1e-6 + ) + uninfected_q25 = (uninfected_q25 - baseline) / (np.abs(baseline) + 1e-6) + uninfected_q75 = (uninfected_q75 - baseline) / (np.abs(baseline) + 1e-6) + + ax.plot( + uninfected_time, + uninfected_median, + color=uninfected_color, + linewidth=2.5, + label="Uninfected (B/1)", + linestyle="-", + ) + ax.fill_between( + uninfected_time, + uninfected_q25, + uninfected_q75, + color=uninfected_color, + alpha=0.2, + ) + + # Plot infected aligned response + infected_time = infected_df["t"].values + infected_median = infected_df[median_col].values + infected_q25 = infected_df[q25_col].values + infected_q75 = infected_df[q75_col].values + + # Apply baseline normalization if requested (but skip for CV features) + if ( + normalize_to_baseline + and not is_cv_feature + and feature in baseline_values + and baseline_values[feature]["infected"] is not None + ): + baseline = baseline_values[feature]["infected"] + infected_median = (infected_median - baseline) / (np.abs(baseline) + 1e-6) + infected_q25 = (infected_q25 - baseline) / (np.abs(baseline) + 1e-6) + infected_q75 = (infected_q75 - baseline) / (np.abs(baseline) + 1e-6) + + ax.plot( + infected_time, + infected_median, + color=infected_color, + linewidth=2.5, + label="Infected top-N aligned (B/2)", + linestyle="-", + ) + ax.fill_between( + infected_time, + infected_q25, + infected_q75, + color=infected_color, + alpha=0.2, + ) + + # Plot global infected average (all B/2 cells, no alignment) if provided + if global_infected_df is not None and median_col in global_infected_df.columns: + global_time = global_infected_df["t"].values + global_median = global_infected_df[median_col].values + global_q25 = ( + global_infected_df[q25_col].values + if q25_col in global_infected_df.columns + else None + ) + global_q75 = ( + global_infected_df[q75_col].values + if q75_col in global_infected_df.columns + else None + ) + + # Apply baseline normalization if requested (but skip for CV features) + if ( + normalize_to_baseline + and not is_cv_feature + and feature in baseline_values + and baseline_values[feature]["global"] is not None + ): + baseline = baseline_values[feature]["global"] + global_median = (global_median - baseline) / (np.abs(baseline) + 1e-6) + if global_q25 is not None: + global_q25 = (global_q25 - baseline) / (np.abs(baseline) + 1e-6) + if global_q75 is not None: + global_q75 = (global_q75 - baseline) / (np.abs(baseline) + 1e-6) + + ax.plot( + global_time, + global_median, + color="#15ba10", # red + linewidth=2, + label="All B/2 cells (no alignment)", + linestyle="--", + alpha=0.8, + ) + if global_q25 is not None and global_q75 is not None: + ax.fill_between( + global_time, + global_q25, + global_q75, + color="#15ba10", + alpha=0.15, + ) + + # Mark infection timepoint + if infection_time is not None: + ax.axvline( + infection_time, + color="red", + linestyle="-", + alpha=0.8, + linewidth=2.5, + label="Infection event", + zorder=5, + ) + + # Find consecutive divergence points + if len(uninfected_median) > 0 and n_consecutive_divergence > 0: + uninfected_std = np.nanstd(uninfected_median) + + # Interpolate uninfected to match infected timepoints for comparison + if len(uninfected_time) > 1 and len(infected_time) > 1: + # Only interpolate within the range of uninfected data + min_t = max(uninfected_time.min(), infected_time.min()) + max_t = min(uninfected_time.max(), infected_time.max()) + + if min_t < max_t: + interp_func = interp1d( + uninfected_time, + uninfected_median, + kind="linear", + fill_value="extrapolate", + ) + + # Find timepoints where infected is significantly different + # Constrain to only timepoints AFTER infection + overlap_mask = (infected_time >= min_t) & (infected_time <= max_t) + if infection_time is not None: + overlap_mask = overlap_mask & (infected_time >= infection_time) + + overlap_times = infected_time[overlap_mask] + overlap_infected = infected_median[overlap_mask] + overlap_uninfected = interp_func(overlap_times) + + divergence = np.abs(overlap_infected - overlap_uninfected) + threshold = 1.5 * uninfected_std + + divergent_mask = divergence > threshold + + # Find consecutive divergence streaks + if np.any(divergent_mask): + # Find runs of consecutive True values + consecutive_start = None + consecutive_count = 0 + + for i, is_divergent in enumerate(divergent_mask): + if is_divergent: + if consecutive_start is None: + consecutive_start = i + consecutive_count += 1 + + # Check if we've reached the required consecutive count + if consecutive_count >= n_consecutive_divergence: + # Mark the first divergence point + first_divergence = overlap_times[consecutive_start] + ax.axvline( + first_divergence, + color="red", + linestyle="--", + alpha=0.6, + linewidth=2, + label=f"Divergence (t={first_divergence:.0f})", + zorder=4, + ) + break # Only mark the first sustained divergence + else: + # Reset counter if streak breaks + consecutive_start = None + consecutive_count = 0 + + ax.set_xlabel("Time") + # Update y-axis label based on whether CV/SEM feature or normalized + if feature.endswith("_cv"): + ax.set_ylabel(f"{feature}\n(raw CV)") + elif feature.endswith("_sem"): + ax.set_ylabel(f"{feature}\n(raw SEM)") + elif normalize_to_baseline and feature in baseline_values: + ax.set_ylabel(f"{feature}\n(fold-change from baseline)") + else: + ax.set_ylabel(feature) + ax.set_title(feature) + ax.grid(True, alpha=0.3) + ax.legend(loc="best", fontsize=7) + + # Hide unused subplots + for idx in range(n_features, len(axes)): + axes[idx].axis("off") + + plt.tight_layout() + plt.savefig( + output_root / f"infected_vs_uninfected_comparison_{ALIGN_TYPE}.png", + dpi=150, + bbox_inches="tight", + ) + plt.show() + + +# Plot comparison +plot_infected_vs_uninfected_comparison( + common_response_df, + uninfected_baseline_df, + feature_columns=common_response_features, + baseline_values=baseline_normalization_values, + infection_time=infection_timepoint, + aligned_region=aligned_region_bounds, + figsize=(18, 14), + # global_infected_df=global_infected_df, + normalize_to_baseline=True, +) +# %% +# Plot binned period comparison with statistical testing +plot_binned_period_comparison( + common_response_df, + uninfected_baseline_df, + feature_columns=common_response_features, + infection_time=infection_timepoint, + baseline_values=baseline_normalization_values, + # global_infected_df=global_infected_df, + output_root=output_root, + figsize=(12, 24), + plot_type="line", # Use line plots to show trends across periods + add_stats=True, # Include statistical testing with significance markers +) +# %% +# Plot PC/PHATE for all the cells grayed out and the top N aligned cells highlighted with fancyarrows +# Shows unaligned + aligned + unaligned timepoints like in the time-series plots +from matplotlib.patches import FancyArrowPatch +import matplotlib.cm as cm + +fig, ax = plt.subplots(figsize=(12, 10)) +if ( + "organelle_PC1" in combined_features_df.columns + and "organelle_PC2" in combined_features_df.columns +): + + # Highlight top N aligned cells - include ALL timepoints (unaligned + aligned + unaligned) + top_aligned_cells = align_n_comp_feat_df[ + align_n_comp_feat_df["lineage_id"].isin(top_lineages[:3]) + ] + + # Filter the top_n_aligned_cells to only include wells B/1 and B/2 + top_n_aligned_cells = top_aligned_cells[ + top_aligned_cells["fov_name"].str.contains("B/1|B/2") + ] + # Filter combined features to only include wells B/1 and B/2 + filter_combined_features = combined_features_df[ + combined_features_df["fov_name"].str.contains("B/1|B/2") + ] + ax.scatter( + filter_combined_features["organelle_PC1"], + filter_combined_features["organelle_PC2"], + color="lightgray", + alpha=0.3, + s=10, + label="All cells", + zorder=1, + ) + + # Get colormap for lineages + n_lineages = len(top_n_aligned_cells["lineage_id"].unique()) + colors = cm.tab10(np.linspace(0, 1, n_lineages)) + + # Column name for alignment status + alignment_col = f"dtw_{ALIGN_TYPE}_aligned" + + # One color per track with temporal arrows, showing aligned vs unaligned timepoints + for idx, lineage_id in enumerate(top_n_aligned_cells["lineage_id"].unique()): + lineage_data = top_n_aligned_cells[ + top_n_aligned_cells["lineage_id"] == lineage_id + ].sort_values("t") + + color = colors[idx] + + # Split into aligned and unaligned portions + if alignment_col in lineage_data.columns: + aligned_data = lineage_data[lineage_data[alignment_col] == True] + unaligned_data = lineage_data[lineage_data[alignment_col] == False] + else: + aligned_data = lineage_data + unaligned_data = pd.DataFrame() + + # Plot unaligned timepoints (pre/post alignment) with smaller, more transparent markers + if len(unaligned_data) > 0: + n_unaligned = len(unaligned_data) + alphas_unaligned = np.linspace(0.3, 0.5, n_unaligned) + + for i, (_, row) in enumerate(unaligned_data.iterrows()): + ax.scatter( + row["organelle_PC1"], + row["organelle_PC2"], + color=color, + alpha=alphas_unaligned[i], + s=15, # Smaller size for unaligned + zorder=2, + edgecolors="gray", + linewidths=0.3, + marker="s", # Square marker for unaligned + ) + + # Plot aligned timepoints with larger, more prominent markers + if len(aligned_data) > 0: + n_aligned = len(aligned_data) + alphas_aligned = np.linspace(0.5, 1.0, n_aligned) + + for i, (_, row) in enumerate(aligned_data.iterrows()): + ax.scatter( + row["organelle_PC1"], + row["organelle_PC2"], + color=color, + alpha=alphas_aligned[i], + s=40, # Larger size for aligned + zorder=3, + edgecolors="white", + linewidths=0.8, + marker="o", # Circle marker for aligned + ) + + # Add fancy arrows connecting ALL consecutive timepoints + for i in range(len(lineage_data) - 1): + row_start = lineage_data.iloc[i] + row_end = lineage_data.iloc[i + 1] + + # Check if this arrow is within aligned region + is_aligned_arrow = False + if alignment_col in row_start.index and alignment_col in row_end.index: + is_aligned_arrow = ( + row_start[alignment_col] == True and row_end[alignment_col] == True + ) + + # Create arrow with different styles for aligned vs unaligned + arrow = FancyArrowPatch( + (row_start["organelle_PC1"], row_start["organelle_PC2"]), + (row_end["organelle_PC1"], row_end["organelle_PC2"]), + arrowstyle="->,head_width=0.4,head_length=0.4", + color=color, + alpha=0.7 if is_aligned_arrow else 0.3, + linewidth=2.0 if is_aligned_arrow else 1.0, + linestyle="-" if is_aligned_arrow else "--", + zorder=2, + ) + ax.add_patch(arrow) + + # Mark first timepoint with a star + first_row = lineage_data.iloc[0] + ax.scatter( + first_row["organelle_PC1"], + first_row["organelle_PC2"], + marker="*", + s=300, + color=color, + edgecolors="black", + linewidths=1.5, + zorder=4, + label=f"Track {lineage_id}", + ) + + # Add legend elements for aligned vs unaligned + from matplotlib.lines import Line2D + + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor="gray", + markersize=8, + label="Aligned timepoints", + markeredgecolor="white", + markeredgewidth=0.8, + ), + Line2D( + [0], + [0], + marker="s", + color="w", + markerfacecolor="gray", + markersize=6, + label="Unaligned timepoints", + markeredgecolor="gray", + markeredgewidth=0.3, + ), + ] + + ax.set_xlabel("Organelle PC1") + ax.set_ylabel("Organelle PC2") + ax.set_title( + "PCA of Organelle Channel: Complete Temporal Trajectories\n(Unaligned + Aligned + Unaligned)" + ) + + # Combine legend elements + handles, labels = ax.get_legend_handles_labels() + ax.legend( + handles=handles + legend_elements, + loc="upper left", + bbox_to_anchor=(1.05, 1), + fontsize=8, + ) + ax.grid(True, alpha=0.3) + plt.tight_layout() + plt.savefig( + output_root / f"trajectory_plot_{ALIGN_TYPE}.png", dpi=150, bbox_inches="tight" + ) + plt.show() +# %% diff --git a/applications/pseudotime_analysis/infection_state/visualize_alignment.py b/applications/pseudotime_analysis/infection_state/visualize_alignment.py new file mode 100644 index 00000000..036c4fd0 --- /dev/null +++ b/applications/pseudotime_analysis/infection_state/visualize_alignment.py @@ -0,0 +1,1842 @@ +# %% +import logging +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from anndata import read_zarr +from divergence_utils import quantify_divergence +from iohub import open_ome_zarr + +from viscy.data.triplet import TripletDataset +from viscy.representation.pseudotime import ( + CytoDtw, + filter_tracks_by_fov_and_length, + get_aligned_image_sequences, +) + +# %% +logger = logging.getLogger("viscy") +logger.setLevel(logging.INFO) +console_handler = logging.StreamHandler() +console_handler.setLevel(logging.INFO) +formatter = logging.Formatter("%(message)s") +console_handler.setFormatter(formatter) +logger.addHandler(console_handler) + +# Configuration +NAPARI = True +if NAPARI: + import os + + import napari + + os.environ["DISPLAY"] = ":1" + viewer = napari.Viewer() + +# %% +# File paths and configuration +output_root = Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/infection_state/output/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV" +) + +ALIGN_TYPE = "infection_apoptotic" # Options: "cell_division" or "infection_state" or "apoptosis" +ALIGNMENT_CHANNEL = "sensor" # sensor, phase, organelle + +# Cropping configuration +CROP_TO_ABSOLUTE_BOUNDS = ( + True # If True, crop before/after regions to stay within [0, max_absolute_time] +) + +NORMALIZE_N_TIMEPOINTS = 5 +NORMALIZE_N_CELLS_FOR_BASELINE = 5 + +# FOV filtering configuration +# Modify these for different datasets/experimental conditions +INFECTED_FOV_PATTERN = ( + "B/2" # Pattern to match infected FOVs (e.g., "B/2", "infected", "treatment") +) +UNINFECTED_FOV_PATTERN = ( + "B/1" # Pattern to match uninfected/control FOVs (e.g., "B/1", "control") +) +INFECTED_LABEL = "Infected" # Label for infected condition in plots +UNINFECTED_LABEL = "Uninfected" # Label for uninfected/control condition in plots + +# Data paths +data_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr" +segmentation_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/train_test_mito_seg_2.zarr" # Segmentation masks +features_path_sensor = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/sensor_160patch_104ckpt_ver3max.zarr" +features_path_phase = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/phase_160patch_104ckpt_ver3max.zarr" +features_path_organelle = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/predictions/anndata_predictions/organelle_160patch_104ckpt_ver3max.zarr" +metadata_path = output_root / f"alignment_metadata_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}.pkl" + +# %% +# Load master features dataframe +master_features_path = ( + output_root / f"master_features_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}.csv" +) +master_df = pd.read_csv(master_features_path) +logger.info(f"Loaded master features from {master_features_path}") +logger.info(f"Shape: {master_df.shape}") +logger.info(f"Columns: {list(master_df.columns)}") + +# Load AnnData for CytoDtw methods (needed for plotting) +ad_features_alignment = read_zarr( + features_path_sensor + if ALIGNMENT_CHANNEL == "sensor" + else ( + features_path_phase if ALIGNMENT_CHANNEL == "phase" else features_path_organelle + ) +) + +cytodtw = CytoDtw(ad_features_alignment) +cytodtw.load_consensus(metadata_path) + +# TODO: we should get rid of this redundancy later +metadata = cytodtw.consensus_data +consensus_lineage = metadata.get("consensus_pattern") +consensus_annotations = metadata.get("consensus_annotations") +consensus_metadata = metadata.get("consensus_metadata") +reference_cell_info = metadata.get("reference_cell_info") +alignment_infection_timepoint = metadata.get("raw_infection_timepoint") +aligned_region_bounds = metadata.get("aligned_region_bounds") + +# Get infection timepoint in absolute time coordinates (mapped from reference cell) +# NOTE: This is the actual 't' value where infection occurs, NOT the consensus index +absolute_infection_timepoint = alignment_infection_timepoint + +# Also get consensus infection index for validation +consensus_infection_idx = ( + consensus_annotations.index("infected") + if consensus_annotations and "infected" in consensus_annotations + else None +) + +# Log both values to help debug coordinate system alignment +logger.info("Infection timepoint mapping:") +logger.info( + f" - Consensus infection index (within consensus window): {consensus_infection_idx}" +) +logger.info( + f" - Absolute infection timepoint (reference cell's t value): {absolute_infection_timepoint}" +) +if consensus_infection_idx is not None and absolute_infection_timepoint is not None: + logger.info( + f" - Difference: {abs(absolute_infection_timepoint - consensus_infection_idx)} timepoints" + ) + logger.info( + f" - Using absolute timepoint ({absolute_infection_timepoint}) for coordinate alignment" + ) + +# %% +# Add warped coordinates if not already present +warped_col = f"dtw_{ALIGN_TYPE}_warped_t" +if warped_col not in master_df.columns: + logger.info(f"\n{'=' * 70}") + logger.info(f"Adding warped coordinates for {ALIGN_TYPE} alignment...") + logger.info(f"{'=' * 70}") + master_df = cytodtw.add_warped_coordinates(master_df, alignment_name=ALIGN_TYPE) + # Save augmented master_df + master_df.to_csv(master_features_path, index=False) + logger.info(f"Saved master_df with warped coordinates to {master_features_path}") +else: + logger.info(f"\nWarped coordinates already present: {warped_col}") + # Load warped metadata from consensus_data if available + if cytodtw.consensus_data and "warped_metadata" in cytodtw.consensus_data: + warped_meta = cytodtw.consensus_data["warped_metadata"] + logger.info(f"Warped metadata: {warped_meta}") + +# %% +# Data filtering and preparation +min_track_length = 20 + +# Filter to infected and uninfected cells using utility function +filtered_infected_df = filter_tracks_by_fov_and_length( + master_df, + fov_pattern=INFECTED_FOV_PATTERN, + min_timepoints=min_track_length, +) +logger.info( + f"Filtered {INFECTED_LABEL} cells: " + f"{filtered_infected_df.groupby(['fov_name', 'track_id']).ngroups} tracks" +) + +uninfected_filtered_df = filter_tracks_by_fov_and_length( + master_df, + fov_pattern=UNINFECTED_FOV_PATTERN, + min_timepoints=min_track_length, +) +logger.info( + f"Filtered {UNINFECTED_LABEL} cells: " + f"{uninfected_filtered_df.groupby(['fov_name', 'track_id']).ngroups} tracks" +) + +consensus_df = master_df[master_df["lineage_id"] == -1].copy() + +# %% +# Select features for analysis +common_response_features = [ + "organelle_PC1", + "organelle_PC2", + "organelle_PC3", + "phase_PC1", + "phase_PC2", + "phase_PC3", + "edge_density", + "organelle_volume", + "homogeneity", + "contrast", + "segs_count", + "segs_mean_area", + "segs_mean_eccentricity", + "segs_mean_frangi_mean", + "segs_circularity_mean", + "segs_circularity_cv", + "segs_eccentricity_cv", + "segs_area_cv", +] + +# %% +# Compute aggregated trajectories of the common response from top N aligned cells +top_n_cells = 30 + +# Get aligned cells only +aligned_cells = filtered_infected_df[ + filtered_infected_df[f"dtw_{ALIGN_TYPE}_aligned"].fillna(False) +].copy() + +# Select top N lineages by DTW distance +top_lineages_df = aligned_cells.drop_duplicates(["fov_name", "lineage_id"]).nsmallest( + top_n_cells, f"dtw_{ALIGN_TYPE}_distance" +)[["fov_name", "lineage_id"]] + +logger.info( + f"Selected top {len(top_lineages_df)} lineages by DTW distance from fovs {top_lineages_df['fov_name'].unique()}" +) + +# Filter using merge +top_cells_df = filtered_infected_df.merge( + top_lineages_df, on=["fov_name", "lineage_id"], how="inner" +).copy() +top_cells_df = top_cells_df.sort_values( + [f"dtw_{ALIGN_TYPE}_distance", "fov_name", "lineage_id", "t"] +) + + +# %% +# Helper function to aggregate trajectories +def aggregate_trajectory( + df: pd.DataFrame, + feature_columns: list, + baseline_n_timepoints: int = 3, + time_column: str = "t", + min_cell_count_for_baseline: int = 5, +) -> pd.DataFrame: + """ + Aggregate trajectories across cells by computing median and IQR at each timepoint. + + Also normalizes features to baseline (window with sufficient cell coverage). + + Parameters + ---------- + df : pd.DataFrame + Dataframe with feature columns and time column + feature_columns : list + List of feature columns to aggregate + baseline_n_timepoints : int + Number of consecutive timepoints to use as baseline for normalization + time_column : str + Name of the time column to group by (default: "t" for absolute time, + can be "dtw_{alignment}_warped_t" for warped time) + min_cell_count_for_baseline : int + Minimum number of cells required per timepoint to be included in baseline window + + Returns + ------- + pd.DataFrame + Aggregated dataframe with columns: {time_column}, {feature}_median, {feature}_q25, {feature}_q75, + and normalized versions {feature}_median_normalized, etc. + """ + # Group by timepoint + grouped = df.groupby(time_column) + + agg_data = [] + for t, group in grouped: + row = {time_column: t} + for feature in feature_columns: + if feature in group.columns: + values = group[feature].dropna() + if len(values) > 0: + row[f"{feature}_median"] = values.median() + row[f"{feature}_q25"] = values.quantile(0.25) + row[f"{feature}_q75"] = values.quantile(0.75) + row[f"{feature}_count"] = len(values) + else: + row[f"{feature}_median"] = np.nan + row[f"{feature}_q25"] = np.nan + row[f"{feature}_q75"] = np.nan + row[f"{feature}_count"] = 0 + agg_data.append(row) + + agg_df = pd.DataFrame(agg_data).sort_values(time_column).reset_index(drop=True) + + # Normalize to baseline (window with sufficient cell coverage) + baseline_mask = None + baseline_timepoints = None + + if baseline_n_timepoints > 0: + # Get cell counts per timepoint (use any feature's count column) + count_col = f"{feature_columns[0]}_count" + if count_col in agg_df.columns: + # Find timepoints with sufficient cell coverage + sufficient_cells_mask = agg_df[count_col] >= min_cell_count_for_baseline + sufficient_timepoints = agg_df[sufficient_cells_mask][time_column].values + + if len(sufficient_timepoints) >= baseline_n_timepoints: + # Find consecutive window of N timepoints with sufficient cells + best_window_start = None + for i in range(len(agg_df) - baseline_n_timepoints + 1): + window_slice = agg_df.iloc[i : i + baseline_n_timepoints] + if all(window_slice[count_col] >= min_cell_count_for_baseline): + best_window_start = i + break + + if best_window_start is not None: + # Use this window + baseline_mask = pd.Series(False, index=agg_df.index) + baseline_mask.iloc[ + best_window_start : best_window_start + baseline_n_timepoints + ] = True + baseline_timepoints = agg_df.loc[baseline_mask, time_column].values + logger.info( + f" Baseline window: t={baseline_timepoints[0]:.1f}-{baseline_timepoints[-1]:.1f}, " + f"cell counts: {agg_df.loc[baseline_mask, count_col].min():.0f}-{agg_df.loc[baseline_mask, count_col].max():.0f}" + ) + else: + # Fallback: use all timepoints with sufficient cells + baseline_mask = sufficient_cells_mask + baseline_timepoints = sufficient_timepoints + logger.warning( + f" Could not find {baseline_n_timepoints} consecutive timepoints with ≥{min_cell_count_for_baseline} cells. " + f"Using all {len(baseline_timepoints)} timepoints with sufficient coverage." + ) + else: + # Not enough timepoints with sufficient cells - use all available + logger.warning( + f" Insufficient timepoints with ≥{min_cell_count_for_baseline} cells " + f"(found {len(sufficient_timepoints)}, need {baseline_n_timepoints}). " + f"Using all timepoints for baseline." + ) + baseline_mask = pd.Series(True, index=agg_df.index) + baseline_timepoints = agg_df[time_column].values + else: + # Fallback to original logic if count column not found + logger.warning( + " Cell count column not found, using first N timepoints for baseline" + ) + baseline_mask = agg_df[time_column] <= ( + agg_df[time_column].min() + baseline_n_timepoints - 1 + ) + baseline_timepoints = agg_df.loc[baseline_mask, time_column].values + + # Normalize using selected baseline + if baseline_mask is not None: + for feature in feature_columns: + median_col = f"{feature}_median" + q25_col = f"{feature}_q25" + q75_col = f"{feature}_q75" + + # Skip CV/SEM features, PC features, and already normalized features + # PCs are relative measures and should not be baseline-normalized + if ( + feature.endswith("_cv") + or feature.endswith("_sem") + or feature.startswith("normalized_") + or "_PC" in feature # Skip PC1, PC2, PC3, etc. + ): + continue + + if median_col in agg_df.columns: + # Compute baseline mean from median values + baseline_values = agg_df.loc[baseline_mask, median_col].dropna() + if len(baseline_values) > 0: + baseline_mean = baseline_values.mean() + if not np.isnan(baseline_mean) and baseline_mean != 0: + # Normalize + agg_df[f"{median_col}_normalized"] = ( + agg_df[median_col] / baseline_mean + ) + agg_df[f"{q25_col}_normalized"] = ( + agg_df[q25_col] / baseline_mean + ) + agg_df[f"{q75_col}_normalized"] = ( + agg_df[q75_col] / baseline_mean + ) + + return agg_df + + +# %% +logger.info("\n" + "=" * 70) +logger.info("Computing aggregated trajectories in ABSOLUTE TIME") +logger.info("=" * 70) + +# Aggregate top N infected cells (common response) +logger.info("Aggregating top-N infected cells (common response):") +common_response_df = aggregate_trajectory( + top_cells_df, + common_response_features, + baseline_n_timepoints=NORMALIZE_N_TIMEPOINTS, + min_cell_count_for_baseline=NORMALIZE_N_CELLS_FOR_BASELINE, +) +logger.info( + f"Common response (top-{top_n_cells}): {len(common_response_df)} timepoints" +) + +# Aggregate uninfected baseline +logger.info("\nAggregating uninfected baseline:") +uninfected_baseline_df = aggregate_trajectory( + uninfected_filtered_df, + common_response_features, + baseline_n_timepoints=NORMALIZE_N_TIMEPOINTS, + min_cell_count_for_baseline=NORMALIZE_N_CELLS_FOR_BASELINE, +) +logger.info(f"Uninfected baseline: {len(uninfected_baseline_df)} timepoints") + +# Aggregate all infected cells (global average) +logger.info("\nAggregating all infected cells (global average):") +global_infected_df = aggregate_trajectory( + filtered_infected_df, + common_response_features, + baseline_n_timepoints=NORMALIZE_N_TIMEPOINTS, + min_cell_count_for_baseline=5, +) +logger.info(f"Global infected average: {len(global_infected_df)} timepoints") + +# Get anchor metadata from consensus +anchor_metadata = { + "anchor_start": ( + aligned_region_bounds[0] if aligned_region_bounds is not None else None + ), + "anchor_end": ( + aligned_region_bounds[1] if aligned_region_bounds is not None else None + ), + "window_start": common_response_df["t"].min(), + "window_end": common_response_df["t"].max(), +} +logger.info(f"Anchor metadata: {anchor_metadata}") + +# %% +logger.info("\n" + "=" * 70) +logger.info("Computing aggregated trajectories in WARPED TIME") +logger.info("=" * 70) + +# Filter to cells with valid warped coordinates (includes aligned + unaligned before/after) +warped_col = f"dtw_{ALIGN_TYPE}_warped_t" +top_cells_warped_df = top_cells_df[~top_cells_df[warped_col].isna()].copy() + +logger.info( + f"Cells with warped coordinates (full concatenated sequence): {len(top_cells_warped_df)} rows, " + f"{top_cells_warped_df.groupby(['fov_name', 'lineage_id']).ngroups} lineages" +) + +# Aggregate top N infected cells in warped time +logger.info("\nAggregating top-N infected cells in warped time:") +common_response_warped_df = aggregate_trajectory( + top_cells_warped_df, + common_response_features, + baseline_n_timepoints=NORMALIZE_N_TIMEPOINTS, + time_column=warped_col, + min_cell_count_for_baseline=NORMALIZE_N_CELLS_FOR_BASELINE, +) +logger.info( + f"Common response in warped time (top-{top_n_cells}): {len(common_response_warped_df)} timepoints" +) + +# Rename warped_t column to "t" for plotting compatibility +common_response_warped_df = common_response_warped_df.rename(columns={warped_col: "t"}) + +# Aggregate all infected cells in warped time (no alignment requirement) +# This shows what the average infected trajectory looks like WITHOUT DTW synchronization +filtered_infected_warped_df = filtered_infected_df[ + ~filtered_infected_df[warped_col].isna() +].copy() + +logger.info( + f"All infected cells with warped coordinates: {len(filtered_infected_warped_df)} rows, " + f"{filtered_infected_warped_df.groupby(['fov_name', 'lineage_id']).ngroups} lineages" +) + +logger.info("\nAggregating all infected cells in warped time (no alignment):") +global_infected_warped_df = aggregate_trajectory( + filtered_infected_warped_df, + common_response_features, + baseline_n_timepoints=NORMALIZE_N_TIMEPOINTS, + time_column=warped_col, + min_cell_count_for_baseline=NORMALIZE_N_CELLS_FOR_BASELINE, +) +logger.info( + f"Global infected in warped time (no alignment): {len(global_infected_warped_df)} timepoints" +) + +# Rename warped_t column to "t" for plotting compatibility +global_infected_warped_df = global_infected_warped_df.rename(columns={warped_col: "t"}) + +# Get warped metadata for period definitions +if cytodtw.consensus_data and "warped_metadata" in cytodtw.consensus_data: + warped_meta = cytodtw.consensus_data["warped_metadata"] + logger.info(f"Warped metadata: {warped_meta}") + + +# %% +def plot_binned_period_comparison( + infected_df: pd.DataFrame, + uninfected_df: pd.DataFrame, + feature_columns: list, + periods: dict, + baseline_period_name: str = None, + infection_time: int = None, + global_infected_df: pd.DataFrame = None, + output_root: Path = None, + figsize=(18, 14), + plot_type: str = "line", + add_stats: bool = True, + infected_label: str = "Infected", + uninfected_label: str = "Uninfected", +): + """ + Plot binned period comparison showing fold-change across biological phases. + + Parameters + ---------- + infected_df : pd.DataFrame + Infected common response aggregated dataframe + uninfected_df : pd.DataFrame + Uninfected baseline aggregated dataframe + feature_columns : list + Features to plot + periods : dict + Dictionary defining time periods for binning. + Keys are period labels (str), values are tuples (start_time, end_time). + Example: {"Baseline": (0, 10), "Early": (10, 20), "Late": (20, 30)} + baseline_period_name : str, optional + Name of the period to use as baseline for normalization. + If None, uses the first period in the periods dict. + infection_time : int, optional + Infection timepoint (only used to mark infection event on plot) + global_infected_df : pd.DataFrame, optional + Global average of all infected cells + output_root : Path, optional + Directory to save output figure + figsize : tuple + Figure size + plot_type : str + 'line' for connected line plots or 'bar' for grouped bar plots (default: 'line') + add_stats : bool + If True, perform statistical testing and mark significant differences (default: True) + infected_label : str + Label for infected condition in plots + uninfected_label : str + Label for uninfected/control condition in plots + """ + from scipy.stats import ttest_ind + + # Determine baseline period for normalization + if baseline_period_name is None: + # Use first period as baseline + baseline_period_name = list(periods.keys())[0] + + if baseline_period_name not in periods: + raise ValueError( + f"Baseline period '{baseline_period_name}' not found in periods dict. " + f"Available periods: {list(periods.keys())}" + ) + + period_names = list(periods.keys()) + n_periods = len(periods) + + n_features = len(feature_columns) + ncols = 3 + nrows = int(np.ceil(n_features / ncols)) + + fig, axes = plt.subplots(nrows, ncols, figsize=figsize) + axes = axes.flatten() if n_features > 1 else [axes] + + # Colorblind-friendly palette + colors = { + "uninfected": "#1f77b4", # blue + "infected": "#ff7f0e", # orange + "global": "#2ca02c", # green + } + + # Store statistical results for logging + stats_results = {} + + for idx, feature in enumerate(feature_columns): + ax = axes[idx] + + median_col = f"{feature}_median" + + # Check if feature exists + if ( + median_col not in infected_df.columns + or median_col not in uninfected_df.columns + ): + ax.text(0.5, 0.5, f"{feature}\nno data", ha="center", va="center") + ax.set_title(feature) + continue + + # Check if CV/SEM, PC, or pre-normalized feature (no additional normalization) + is_cv_feature = feature.endswith("_cv") or feature.endswith("_sem") + is_pc_feature = "_PC" in feature # PC features should not be normalized + is_prenormalized = feature.startswith("normalized_") + + # Compute values for each period + period_values = {"uninfected": [], "infected": [], "global": []} + period_errors = {"uninfected": [], "infected": [], "global": []} + + # Baseline period values for normalization + baseline_period = periods[baseline_period_name] + + def compute_baseline_value(df, feature_col): + """Compute baseline value from baseline period.""" + mask = (df["t"] >= baseline_period[0]) & (df["t"] <= baseline_period[1]) + values = df.loc[mask, feature_col].dropna() + return values.mean() if len(values) > 0 else None + + # Compute baseline for this feature from each trajectory + uninfected_baseline = compute_baseline_value(uninfected_df, median_col) + infected_baseline = compute_baseline_value(infected_df, median_col) + global_baseline = None + if global_infected_df is not None and median_col in global_infected_df.columns: + global_baseline = compute_baseline_value(global_infected_df, median_col) + + # For each period, compute aggregate value and normalize + for period_name, (t_start, t_end) in periods.items(): + # Uninfected + mask = (uninfected_df["t"] >= t_start) & (uninfected_df["t"] <= t_end) + values = uninfected_df.loc[mask, median_col].dropna() + if len(values) > 0: + mean_val = values.mean() + std_val = values.std() + + # Normalize to baseline if not CV/PC feature and not pre-normalized + if ( + not is_cv_feature + and not is_pc_feature + and not is_prenormalized + and uninfected_baseline is not None + ): + mean_val = mean_val / uninfected_baseline + std_val = std_val / (np.abs(uninfected_baseline) + 1e-6) + + period_values["uninfected"].append(mean_val) + period_errors["uninfected"].append(std_val) + else: + period_values["uninfected"].append(np.nan) + period_errors["uninfected"].append(np.nan) + + # Infected + mask = (infected_df["t"] >= t_start) & (infected_df["t"] <= t_end) + values = infected_df.loc[mask, median_col].dropna() + if len(values) > 0: + mean_val = values.mean() + std_val = values.std() + + if ( + not is_cv_feature + and not is_pc_feature + and not is_prenormalized + and infected_baseline is not None + ): + mean_val = mean_val / infected_baseline + std_val = std_val / (np.abs(infected_baseline) + 1e-6) + + period_values["infected"].append(mean_val) + period_errors["infected"].append(std_val) + else: + period_values["infected"].append(np.nan) + period_errors["infected"].append(np.nan) + + # Global infected + if ( + global_infected_df is not None + and median_col in global_infected_df.columns + ): + mask = (global_infected_df["t"] >= t_start) & ( + global_infected_df["t"] <= t_end + ) + values = global_infected_df.loc[mask, median_col].dropna() + if len(values) > 0: + mean_val = values.mean() + std_val = values.std() + + if ( + not is_cv_feature + and not is_pc_feature + and not is_prenormalized + and global_baseline is not None + ): + mean_val = mean_val / global_baseline + std_val = std_val / (np.abs(global_baseline) + 1e-6) + + period_values["global"].append(mean_val) + period_errors["global"].append(std_val) + else: + period_values["global"].append(np.nan) + period_errors["global"].append(np.nan) + + # Statistical testing between infected and uninfected at each period + p_values = [] + if add_stats: + stats_results[feature] = {} + for period_name, (t_start, t_end) in periods.items(): + # Get raw values for statistical testing + uninfected_mask = (uninfected_df["t"] >= t_start) & ( + uninfected_df["t"] <= t_end + ) + infected_mask = (infected_df["t"] >= t_start) & ( + infected_df["t"] <= t_end + ) + + uninfected_vals = uninfected_df.loc[ + uninfected_mask, median_col + ].dropna() + infected_vals = infected_df.loc[infected_mask, median_col].dropna() + + if len(uninfected_vals) >= 3 and len(infected_vals) >= 3: + _, p_val = ttest_ind(uninfected_vals, infected_vals) + p_values.append(p_val) + stats_results[feature][period_name] = p_val + else: + p_values.append(np.nan) + stats_results[feature][period_name] = np.nan + + x = np.arange(n_periods) + + if plot_type == "line": + # Line plot with error bars + ax.errorbar( + x, + period_values["uninfected"], + yerr=period_errors["uninfected"], + label=f"{uninfected_label}", + color=colors["uninfected"], + marker="o", + markersize=8, + linewidth=2.5, + capsize=4, + capthick=2, + ) + ax.errorbar( + x, + period_values["infected"], + yerr=period_errors["infected"], + label=f"{infected_label} (top-N)", + color=colors["infected"], + marker="s", + markersize=8, + linewidth=2.5, + capsize=4, + capthick=2, + ) + + if global_infected_df is not None: + ax.errorbar( + x, + period_values["global"], + yerr=period_errors["global"], + label=f"All {infected_label}", + color=colors["global"], + marker="^", + markersize=7, + linewidth=2, + linestyle="--", + capsize=4, + capthick=1.5, + alpha=0.8, + ) + + # Mark significant differences with asterisks + if add_stats and len(p_values) > 0: + # Get valid (non-NaN) values for computing y_max + valid_uninfected = [ + v for v in period_values["uninfected"] if not np.isnan(v) + ] + valid_infected = [ + v for v in period_values["infected"] if not np.isnan(v) + ] + + # Only add markers if we have valid values + if len(valid_uninfected) > 0 and len(valid_infected) > 0: + y_max = max(max(valid_uninfected), max(valid_infected)) + y_offset = 0.1 * (y_max - 1.0) if not is_cv_feature else 0.1 * y_max + + for i, p_val in enumerate(p_values): + if not np.isnan(p_val): + # Determine significance level + if p_val < 0.001: + marker = "***" + elif p_val < 0.01: + marker = "**" + elif p_val < 0.05: + marker = "*" + else: + marker = "ns" + + if marker != "ns": + # Position text above the higher of the two values + # Skip if either value is NaN + if not np.isnan( + period_values["uninfected"][i] + ) and not np.isnan(period_values["infected"][i]): + max_val = max( + period_values["uninfected"][i], + period_values["infected"][i], + ) + ax.text( + x[i], + max_val + y_offset, + marker, + ha="center", + va="bottom", + fontsize=10, + fontweight="bold", + color="black", + ) + + else: # bar plot + width = 0.25 + + ax.bar( + x - width, + period_values["uninfected"], + width, + label=f"{uninfected_label}", + color=colors["uninfected"], + yerr=period_errors["uninfected"], + capsize=3, + ) + ax.bar( + x, + period_values["infected"], + width, + label=f"{infected_label} (top-N)", + color=colors["infected"], + yerr=period_errors["infected"], + capsize=3, + ) + + if global_infected_df is not None: + ax.bar( + x + width, + period_values["global"], + width, + label=f"All {infected_label}", + color=colors["global"], + yerr=period_errors["global"], + capsize=3, + alpha=0.8, + ) + + # Add horizontal line at 1.0 (no change from baseline) - but not for PCs or CVs + if not is_cv_feature and not is_pc_feature: + ax.axhline(1.0, color="gray", linestyle="--", linewidth=1, alpha=0.5) + + ax.set_xlabel("Period") + if is_cv_feature: + ax.set_ylabel(f"{feature}\n(raw value)") + elif is_pc_feature: + ax.set_ylabel(f"{feature}\n(PC units, not normalized)") + elif is_prenormalized: + ax.set_ylabel(f"{feature}\n(pre-normalized, baseline t=1-10)") + else: + ax.set_ylabel(f"{feature}\n(fold-change from baseline)") + ax.set_title(feature) + ax.set_xticks(x) + ax.set_xticklabels(period_names, rotation=45, ha="right") + ax.grid(True, alpha=0.3, axis="y") + + # Hide unused subplots + for idx in range(n_features, len(axes)): + axes[idx].axis("off") + + # Create a single shared legend for the entire figure + # Get handles and labels from the first subplot (they're all the same) + handles, labels = axes[0].get_legend_handles_labels() + # Place legend on the right side of the figure + fig.legend( + handles, + labels, + loc="center right", + bbox_to_anchor=(0.98, 0.5), + fontsize=9, + frameon=True, + ) + + # Create title with statistical note + title = "Binned Period Comparison: Fold-Change Across Infection Phases" + if add_stats: + title += "\n(* p<0.05, ** p<0.01, *** p<0.001)" + + plt.suptitle( + title, + fontsize=14, + y=1.00, + ) + # Use tight_layout with extra space on the right for the shared legend + plt.tight_layout(rect=[0, 0, 0.85, 1]) + + if output_root is not None: + # Detect if this is warped time by checking the label + is_warped = "DTW-aligned" in infected_label + suffix = "warped" if is_warped else "absolute" + save_path = output_root / f"binned_period_comparison_{ALIGN_TYPE}_{suffix}.png" + plt.savefig(save_path, dpi=150, bbox_inches="tight") + logger.info(f"Saved binned period comparison to {save_path}") + + plt.show() + + # Log summary in markdown format + logger.info("\n## Binned Period Comparison Summary") + logger.info(f"**Infection timepoint:** {infection_time}") + logger.info("\n### Period Definitions") + for period_name, (t_start, t_end) in periods.items(): + logger.info(f"- **{period_name}:** t={t_start} to t={t_end}") + + if add_stats and len(stats_results) > 0: + logger.info("\n### Statistical Significance (t-tests)") + logger.info( + "Comparing infected top-N vs uninfected at each period. Significance levels: * p<0.05, ** p<0.01, *** p<0.001\n" + ) + + # Create markdown table + logger.info(f"| Feature | {' | '.join(period_names)} |") + logger.info(f"|---------|{'---------|-' * (len(period_names) - 1)}---------|") + + for feature, period_results in stats_results.items(): + sig_markers = [] + for period_name in period_names: + p_val = period_results.get(period_name, np.nan) + if np.isnan(p_val): + sig_markers.append("N/A") + elif p_val < 0.001: + sig_markers.append(f"***({p_val:.3e})") + elif p_val < 0.01: + sig_markers.append(f"**({p_val:.3f})") + elif p_val < 0.05: + sig_markers.append(f"*({p_val:.3f})") + else: + sig_markers.append(f"ns({p_val:.3f})") + + logger.info(f"| {feature} | {' | '.join(sig_markers)} |") + + logger.info("\nns = not significant (p >= 0.05)") + + +# %% +def plot_infected_vs_uninfected_comparison( + infected_df: pd.DataFrame, + uninfected_df: pd.DataFrame, + feature_columns: list, + warped_metadata: dict, + anchor_metadata: dict = None, + figsize=(18, 14), + n_consecutive_divergence: int = 5, + global_infected_df: pd.DataFrame = None, + normalize_to_baseline: bool = True, + infected_label: str = "Infected", + uninfected_label: str = "Uninfected", + output_root: Path = None, + plot_cell_counts: bool = True, + min_cell_count_threshold: int = 10, +): + """ + Plot comparison of infected vs uninfected trajectories in warped/pseudotime. + + Compares trajectories in synchronized biological time (warped/pseudotime coordinates) + where infected cells' aligned regions are synchronized. Shows where infected cells + diverge from normal behavior (crossover points). + + Uses pre-computed normalized columns from dataframes (e.g., '{feature}_median_normalized') + added by aggregate_trajectory(). + + NOTE: Features ending in '_cv' or '_sem' are plotted as raw values without baseline + normalization, since CV and SEM are already relative/uncertainty metrics. + + Parameters + ---------- + infected_df : pd.DataFrame + Infected common response aggregated in warped time with 't' column (warped coordinates) + and '{feature}_*_normalized' columns + uninfected_df : pd.DataFrame + Uninfected baseline shifted to warped time coordinates with '{feature}_*_normalized' columns + feature_columns : list + Features to plot + warped_metadata : dict + Metadata from warped coordinate system containing: + - max_unaligned_before: warped time where aligned region starts + - consensus_aligned_length: length of the synchronized aligned region + - total_warped_length: total length of warped time axis + anchor_metadata : dict, optional + Metadata for highlighting the aligned region bounds: + - anchor_start: warped time where aligned region starts + - anchor_end: warped time where aligned region ends + - window_start: warped time where aggregated data starts + - window_end: warped time where aggregated data ends + figsize : tuple + Figure size + n_consecutive_divergence : int + Number of consecutive timepoints required to confirm divergence (default: 5) + global_infected_df : pd.DataFrame, optional + Global average of ALL infected cells in warped time, with normalized columns + normalize_to_baseline : bool + If True, use normalized columns to show fold-change + infected_label : str + Label for infected condition in plots + uninfected_label : str + Label for uninfected/control condition in plots + output_root : Path, optional + Directory to save output figure + """ + from scipy.interpolate import interp1d + + n_features = len(feature_columns) + ncols = 3 + nrows = int(np.ceil(n_features / ncols)) + + fig, axes = plt.subplots(nrows, ncols, figsize=figsize) + axes = axes.flatten() if n_features > 1 else [axes] + + # Colorblind-friendly palette + uninfected_color = "#1f77b4" # blue + infected_color = "#ff7f0e" # orange + + for idx, feature in enumerate(feature_columns): + ax = axes[idx] + + median_col = f"{feature}_median" + q25_col = f"{feature}_q25" + q75_col = f"{feature}_q75" + + # Check if data exists + if ( + median_col not in infected_df.columns + or median_col not in uninfected_df.columns + ): + ax.text(0.5, 0.5, f"{feature}\nno data", ha="center", va="center") + ax.set_title(feature) + continue + + # Highlight aligned/anchor region in warped time (background layer) + if anchor_metadata is not None: + anchor_start = anchor_metadata.get("anchor_start") + anchor_end = anchor_metadata.get("anchor_end") + if anchor_start is not None and anchor_end is not None: + ax.axvspan( + anchor_start, + anchor_end, + alpha=0.15, + color="gray", + label="Synchronized aligned region", + zorder=0, + ) + + # Filter timepoints based on cell count threshold + count_col = f"{feature}_count" + + # Plot uninfected baseline + uninfected_time = uninfected_df["t"].values + + # Check if this is a CV/SEM, PC, or pre-normalized feature + is_cv_feature = feature.endswith("_cv") or feature.endswith("_sem") + is_pc_feature = "_PC" in feature # PC features should not be normalized + is_prenormalized = feature.startswith("normalized_") + + # Check if normalized columns exist in dataframe (from normalize_aggregated_trajectory) + normalized_median_col = f"{median_col}_normalized" + normalized_q25_col = f"{q25_col}_normalized" + normalized_q75_col = f"{q75_col}_normalized" + has_normalized_columns = ( + normalized_median_col in uninfected_df.columns + and normalized_q25_col in uninfected_df.columns + and normalized_q75_col in uninfected_df.columns + ) + + # Use normalized columns if available and requested (but not for PCs) + if ( + normalize_to_baseline + and not is_cv_feature + and not is_pc_feature + and not is_prenormalized + and has_normalized_columns + ): + uninfected_median = uninfected_df[normalized_median_col].values + uninfected_q25 = uninfected_df[normalized_q25_col].values + uninfected_q75 = uninfected_df[normalized_q75_col].values + else: + # Use raw values + uninfected_median = uninfected_df[median_col].values + uninfected_q25 = uninfected_df[q25_col].values + uninfected_q75 = uninfected_df[q75_col].values + + # Filter by cell count threshold + if count_col in uninfected_df.columns: + valid_mask = uninfected_df[count_col].values >= min_cell_count_threshold + uninfected_time_filtered = uninfected_time[valid_mask] + uninfected_median_filtered = uninfected_median[valid_mask] + uninfected_q25_filtered = uninfected_q25[valid_mask] + uninfected_q75_filtered = uninfected_q75[valid_mask] + else: + uninfected_time_filtered = uninfected_time + uninfected_median_filtered = uninfected_median + uninfected_q25_filtered = uninfected_q25 + uninfected_q75_filtered = uninfected_q75 + + ax.plot( + uninfected_time_filtered, + uninfected_median_filtered, + color=uninfected_color, + linewidth=2.5, + label=f"{uninfected_label}", + linestyle="-", + ) + ax.fill_between( + uninfected_time_filtered, + uninfected_q25_filtered, + uninfected_q75_filtered, + color=uninfected_color, + alpha=0.2, + ) + + # Plot infected aligned response + infected_time = infected_df["t"].values + + # Check if normalized columns exist for infected trajectory + has_normalized_columns_infected = ( + normalized_median_col in infected_df.columns + and normalized_q25_col in infected_df.columns + and normalized_q75_col in infected_df.columns + ) + + # Use normalized columns if available and requested (but not for PCs) + if ( + normalize_to_baseline + and not is_cv_feature + and not is_pc_feature + and not is_prenormalized + and has_normalized_columns_infected + ): + infected_median = infected_df[normalized_median_col].values + infected_q25 = infected_df[normalized_q25_col].values + infected_q75 = infected_df[normalized_q75_col].values + else: + # Use raw values + infected_median = infected_df[median_col].values + infected_q25 = infected_df[q25_col].values + infected_q75 = infected_df[q75_col].values + + # Filter by cell count threshold + if count_col in infected_df.columns: + valid_mask = infected_df[count_col].values >= min_cell_count_threshold + infected_time_filtered = infected_time[valid_mask] + infected_median_filtered = infected_median[valid_mask] + infected_q25_filtered = infected_q25[valid_mask] + infected_q75_filtered = infected_q75[valid_mask] + else: + infected_time_filtered = infected_time + infected_median_filtered = infected_median + infected_q25_filtered = infected_q25 + infected_q75_filtered = infected_q75 + + ax.plot( + infected_time_filtered, + infected_median_filtered, + color=infected_color, + linewidth=2.5, + label=f"{infected_label} (top-N aligned)", + linestyle="-", + ) + ax.fill_between( + infected_time_filtered, + infected_q25_filtered, + infected_q75_filtered, + color=infected_color, + alpha=0.2, + ) + + # Plot global infected average + if global_infected_df is not None and median_col in global_infected_df.columns: + global_time = global_infected_df["t"].values + + # Check if normalized columns exist for global trajectory + has_normalized_columns_global = ( + normalized_median_col in global_infected_df.columns + and normalized_q25_col in global_infected_df.columns + and normalized_q75_col in global_infected_df.columns + ) + + # Use normalized columns if available and requested (but not for PCs) + if ( + normalize_to_baseline + and not is_cv_feature + and not is_pc_feature + and not is_prenormalized + and has_normalized_columns_global + ): + global_median = global_infected_df[normalized_median_col].values + global_q25 = global_infected_df[normalized_q25_col].values + global_q75 = global_infected_df[normalized_q75_col].values + else: + # Use raw values + global_median = global_infected_df[median_col].values + global_q25 = ( + global_infected_df[q25_col].values + if q25_col in global_infected_df.columns + else None + ) + global_q75 = ( + global_infected_df[q75_col].values + if q75_col in global_infected_df.columns + else None + ) + + # Filter by cell count threshold + if count_col in global_infected_df.columns: + valid_mask = ( + global_infected_df[count_col].values >= min_cell_count_threshold + ) + global_time_filtered = global_time[valid_mask] + global_median_filtered = global_median[valid_mask] + global_q25_filtered = ( + global_q25[valid_mask] if global_q25 is not None else None + ) + global_q75_filtered = ( + global_q75[valid_mask] if global_q75 is not None else None + ) + else: + global_time_filtered = global_time + global_median_filtered = global_median + global_q25_filtered = global_q25 + global_q75_filtered = global_q75 + + ax.plot( + global_time_filtered, + global_median_filtered, + color="#15ba10", # green + linewidth=2, + label=f"All {infected_label} (no alignment)", + linestyle="--", + alpha=0.8, + ) + if global_q25_filtered is not None and global_q75_filtered is not None: + ax.fill_between( + global_time_filtered, + global_q25_filtered, + global_q75_filtered, + color="#15ba10", + alpha=0.15, + ) + + # Mark alignment start (where synchronized biological response begins) + if warped_metadata is not None: + alignment_start = warped_metadata.get("max_unaligned_before") + if alignment_start is not None: + ax.axvline( + alignment_start, + color="red", + linestyle="-", + alpha=0.8, + linewidth=2.5, + label="Alignment start (infection)", + zorder=5, + ) + + # Find consecutive divergence points (use filtered data) + if len(uninfected_median_filtered) > 0 and n_consecutive_divergence > 0: + uninfected_std = np.nanstd(uninfected_median_filtered) + + # Interpolate uninfected to match infected timepoints (use filtered data) + if len(uninfected_time_filtered) > 1 and len(infected_time_filtered) > 1: + min_t = max( + uninfected_time_filtered.min(), infected_time_filtered.min() + ) + max_t = min( + uninfected_time_filtered.max(), infected_time_filtered.max() + ) + + if min_t < max_t: + interp_func = interp1d( + uninfected_time_filtered, + uninfected_median_filtered, + kind="linear", + fill_value="extrapolate", + ) + + # Find timepoints where infected is significantly different + # Allow divergence detection across all timepoints (including before infection) + overlap_mask = (infected_time_filtered >= min_t) & ( + infected_time_filtered <= max_t + ) + + overlap_times = infected_time_filtered[overlap_mask] + overlap_infected = infected_median_filtered[overlap_mask] + overlap_uninfected = interp_func(overlap_times) + + divergence = np.abs(overlap_infected - overlap_uninfected) + threshold = 1.5 * uninfected_std + + divergent_mask = divergence > threshold + + # Find consecutive divergence streaks + if np.any(divergent_mask): + consecutive_start = None + consecutive_count = 0 + + for i, is_divergent in enumerate(divergent_mask): + if is_divergent: + if consecutive_start is None: + consecutive_start = i + consecutive_count += 1 + + if consecutive_count >= n_consecutive_divergence: + first_divergence = overlap_times[consecutive_start] + ax.axvline( + first_divergence, + color="red", + linestyle="--", + alpha=0.6, + linewidth=2, + label=f"Divergence (t={first_divergence:.0f})", + zorder=4, + ) + break + else: + consecutive_start = None + consecutive_count = 0 + + ax.set_xlabel("Warped Pseudotime") + # Update y-axis label + if feature.endswith("_cv"): + ax.set_ylabel(f"{feature}\n(raw CV)") + elif feature.endswith("_sem"): + ax.set_ylabel(f"{feature}\n(raw SEM)") + elif is_pc_feature: + ax.set_ylabel(f"{feature}\n(PC units, not normalized)") + elif is_prenormalized: + ax.set_ylabel(f"{feature}\n(pre-normalized, baseline t=1-10)") + elif normalize_to_baseline and has_normalized_columns: + ax.set_ylabel(f"{feature}\n(fold-change from baseline)") + else: + ax.set_ylabel(feature) + ax.set_title(feature) + ax.grid(True, alpha=0.3) + + # Hide unused subplots + for idx in range(n_features, len(axes)): + axes[idx].axis("off") + + # Create a single shared legend for the entire figure + # Get handles and labels from the first subplot (they're all the same) + handles, labels = axes[0].get_legend_handles_labels() + # Place legend on the right side of the figure + fig.legend( + handles, + labels, + loc="center right", + bbox_to_anchor=(0.98, 0.5), + fontsize=9, + frameon=True, + ) + + # Use tight_layout with extra space on the right for the shared legend + plt.tight_layout(rect=[0, 0, 0.85, 1]) + + # Save with warped time suffix + if output_root is not None: + save_path = ( + output_root / f"infected_vs_uninfected_comparison_{ALIGN_TYPE}_warped.png" + ) + plt.savefig(save_path, dpi=150, bbox_inches="tight") + logger.info(f"Saved warped time comparison plot to {save_path}") + + plt.show() + + +# %% +# Plot individual lineages +# Use top_cells_df which has BOTH min_track_length filter AND top-N by DTW distance +# Important: Include "consensus" to ensure plotting methods have the reference pattern +consensus_df = master_df[master_df["lineage_id"] == -1].copy() + +alignment_df_for_plotting = pd.concat([top_cells_df, consensus_df], ignore_index=True) + +logger.info( + f"Filtered plotting dataframe ({INFECTED_FOV_PATTERN}): {len(alignment_df_for_plotting)} rows, " + f"{alignment_df_for_plotting['lineage_id'].nunique()} unique lineages" +) +logger.info( + f"Includes consensus: {(alignment_df_for_plotting['fov_name'] == 'consensus').any()}" +) +logger.info( + f"All lineages have minimum {min_track_length} timepoints and are top-{top_n_cells} by DTW distance (except consensus)" +) + +fig = cytodtw.plot_individual_lineages( + alignment_df_for_plotting, + alignment_name=ALIGN_TYPE, + feature_columns=[ + "sensor_PC1", + "homogeneity", + "contrast", + "edge_density", + "segs_count", + "segs_total_area", + "segs_mean_area", + ], + max_lineages=8, + aligned_linewidth=2.5, + unaligned_linewidth=1.4, + y_offset_step=0.0, +) + +# %% +# Heatmap showing all tracks +fig = cytodtw.plot_global_trends( + alignment_df_for_plotting, + alignment_name=ALIGN_TYPE, + plot_type="heatmap", + cmap="RdBu", + figsize=(12, 12), + feature_columns=[ + "organelle_PC1", + "organelle_PC2", + "organelle_PC3", + "edge_density", + "segs_count", + "segs_total_area", + "segs_mean_area", + "segs_circularity_mean", + "segs_mean_frangi_mean", + ], + max_lineages=10, +) +# %% +# Infected vs uninfected comparison - WARPED TIME +logger.info("\n" + "=" * 70) +logger.info("WARPED TIME: Infected vs Uninfected Comparison") +logger.info("=" * 70) + +# Get warped metadata +if cytodtw.consensus_data and "warped_metadata" in cytodtw.consensus_data: + warped_meta = cytodtw.consensus_data["warped_metadata"] + max_unaligned_before = warped_meta["max_unaligned_before"] + consensus_aligned_length = warped_meta["consensus_aligned_length"] + + # Shift uninfected trajectory to align with warped time + # Strategy: align infection time in uninfected with start of aligned region in warped time + uninfected_shifted_df = uninfected_baseline_df.copy() + time_shift = max_unaligned_before - absolute_infection_timepoint + uninfected_shifted_df["t"] = uninfected_shifted_df["t"] + time_shift + + # Shift global infected trajectory to align with warped time (same shift) + # This shows all infected cells without alignment in ABSOLUTE time, shifted to warped coordinates for comparison + global_infected_shifted_df = global_infected_df.copy() + global_infected_shifted_df["t"] = global_infected_df["t"] + time_shift + + # Create warped anchor metadata + warped_anchor_metadata = { + "anchor_start": max_unaligned_before, + "anchor_end": max_unaligned_before + consensus_aligned_length - 1, + "window_start": common_response_warped_df["t"].min(), + "window_end": common_response_warped_df["t"].max(), + } + + logger.info(f"Warped anchor metadata: {warped_anchor_metadata}") + logger.info( + f"Shifted uninfected time by {time_shift} frames to align with warped time" + ) + logger.info( + f"Shifted global infected (no alignment) by same {time_shift} frames to align with warped time" + ) + + # Plot warped time comparison + plot_infected_vs_uninfected_comparison( + common_response_warped_df, # Infected top-N in warped time (aligned) + uninfected_shifted_df, # Uninfected shifted to warped time + common_response_features, + warped_metadata=warped_meta, + anchor_metadata=warped_anchor_metadata, + figsize=(18, 14), + n_consecutive_divergence=5, + global_infected_df=global_infected_shifted_df, # All infected in warped time (no alignment) + normalize_to_baseline=True, + infected_label=INFECTED_LABEL, + uninfected_label=UNINFECTED_LABEL, + output_root=output_root, + min_cell_count_threshold=NORMALIZE_N_CELLS_FOR_BASELINE, + ) + + # Divergence quantification analysis + logger.info("\n" + "=" * 70) + logger.info("DIVERGENCE TIMING ANALYSIS: Quantifying Organelle Remodeling") + logger.info("=" * 70) + + # Configuration for divergence detection + n_consecutive = 5 + threshold_multiplier = 1.5 + + # Collect divergence results + divergence_results = [] + + for feature in common_response_features: + logger.info(f"\nAnalyzing divergence for: {feature}") + + # Comparison 1: Aligned infected vs Uninfected (conserved response) + result_aligned = quantify_divergence( + test_df=common_response_warped_df, + reference_df=uninfected_shifted_df, + feature=f"{feature}_median", + n_consecutive=n_consecutive, + threshold_std_multiplier=threshold_multiplier, + normalize_to_baseline=False, # Already normalized in dataframes + ) + result_aligned["feature"] = feature + result_aligned["comparison"] = "aligned_vs_uninfected" + divergence_results.append(result_aligned) + + # Comparison 2: Unaligned infected vs Uninfected (population average) + result_unaligned = quantify_divergence( + test_df=global_infected_shifted_df, + reference_df=uninfected_shifted_df, + feature=f"{feature}_median", + n_consecutive=n_consecutive, + threshold_std_multiplier=threshold_multiplier, + normalize_to_baseline=False, + ) + result_unaligned["feature"] = feature + result_unaligned["comparison"] = "unaligned_vs_uninfected" + divergence_results.append(result_unaligned) + + # Comparison 3: Aligned vs Unaligned infected (effect of synchronization) + result_aligned_vs_unaligned = quantify_divergence( + test_df=common_response_warped_df, + reference_df=global_infected_shifted_df, + feature=f"{feature}_median", + n_consecutive=n_consecutive, + threshold_std_multiplier=threshold_multiplier, + normalize_to_baseline=False, + ) + result_aligned_vs_unaligned["feature"] = feature + result_aligned_vs_unaligned["comparison"] = "aligned_vs_unaligned" + divergence_results.append(result_aligned_vs_unaligned) + + # Create results dataframe + divergence_df = pd.DataFrame(divergence_results) + + # Save to CSV + divergence_csv_path = output_root / f"divergence_analysis_{ALIGN_TYPE}.csv" + divergence_df.to_csv(divergence_csv_path, index=False) + logger.info(f"\nSaved divergence analysis to: {divergence_csv_path}") + + # Log results in markdown format + logger.info("\n## Divergence Timing Analysis Results") + logger.info("**Analysis**: Organelle remodeling timing during infection") + logger.info( + "**Method**: DTW-synchronized (aligned) vs unsynchronized (unaligned) populations" + ) + logger.info( + f"**Detection**: {n_consecutive} consecutive timepoints above {threshold_multiplier}x reference IQR\n" + ) + + # Summary statistics by comparison type + for comparison in [ + "aligned_vs_uninfected", + "unaligned_vs_uninfected", + "aligned_vs_unaligned", + ]: + comparison_data = divergence_df[divergence_df["comparison"] == comparison] + + if comparison == "aligned_vs_uninfected": + comp_label = "**DTW-Aligned Infected vs Uninfected Control**" + description = "Conserved response timing in synchronized cells" + elif comparison == "unaligned_vs_uninfected": + comp_label = "**Unaligned Infected vs Uninfected Control**" + description = "Population average without synchronization" + else: + comp_label = "**DTW-Aligned vs Unaligned Infected**" + description = "Effect of DTW synchronization" + + logger.info(f"\n### {comp_label}") + logger.info(f"_{description}_\n") + + # Table header + logger.info( + "| Feature | Divergence Time | Time from Start | Magnitude | Detected |" + ) + logger.info( + "|---------|----------------|-----------------|-----------|----------|" + ) + + for _, row in comparison_data.iterrows(): + divergence_time = ( + f"{row['divergence_time']:.1f}" if row["has_divergence"] else "N/A" + ) + time_from_start = ( + f"{row['time_from_start']:.1f}" if row["has_divergence"] else "N/A" + ) + magnitude = f"{row['divergence_magnitude']:.3f}" + detected = "✓" if row["has_divergence"] else "✗" + + logger.info( + f"| {row['feature']} | {divergence_time} | {time_from_start} | " + f"{magnitude} | {detected} |" + ) + + # Summary insights + logger.info("\n### Key Insights") + + # Which features diverge earliest in aligned cells? + aligned_divergent = divergence_df[ + (divergence_df["comparison"] == "aligned_vs_uninfected") + & (divergence_df["has_divergence"]) + ].sort_values("divergence_time") + + if len(aligned_divergent) > 0: + earliest_features = aligned_divergent.head(3) + logger.info("\n**Earliest remodeling (DTW-aligned):**") + for _, row in earliest_features.iterrows(): + logger.info( + f"- **{row['feature']}**: t={row['divergence_time']:.1f} (Δt={row['time_from_start']:.1f})" + ) + + # Does synchronization help reveal timing? + features_with_both = [] + for feature in common_response_features: + aligned_div = divergence_df[ + (divergence_df["feature"] == feature) + & (divergence_df["comparison"] == "aligned_vs_uninfected") + ].iloc[0] + unaligned_div = divergence_df[ + (divergence_df["feature"] == feature) + & (divergence_df["comparison"] == "unaligned_vs_uninfected") + ].iloc[0] + + if aligned_div["has_divergence"] and unaligned_div["has_divergence"]: + time_diff = ( + aligned_div["divergence_time"] - unaligned_div["divergence_time"] + ) + features_with_both.append( + { + "feature": feature, + "time_diff": time_diff, + "aligned_time": aligned_div["divergence_time"], + "unaligned_time": unaligned_div["divergence_time"], + } + ) + + if len(features_with_both) > 0: + logger.info( + "\n**Impact of DTW synchronization (features with divergence in both):**" + ) + for item in sorted( + features_with_both, key=lambda x: abs(x["time_diff"]), reverse=True + )[:3]: + direction = "earlier" if item["time_diff"] < 0 else "later" + logger.info( + f"- **{item['feature']}**: Aligned diverges {abs(item['time_diff']):.1f} " + f"timepoints {direction} (t={item['aligned_time']:.1f} vs t={item['unaligned_time']:.1f})" + ) + + logger.info("\n" + "=" * 70) +else: + logger.warning("Warped metadata not available, skipping warped time comparison") + + +# %% +from cmap import Colormap +from skimage.exposure import adjust_gamma, rescale_intensity + +z_range = slice(0, 1) +initial_yx_patch_size = (192, 192) +# Top matches should be unique fov_name and lineage_id combinations +matches_path = ( + output_root + / f"consensus_lineage_{ALIGN_TYPE}_{ALIGNMENT_CHANNEL}_matching_lineages_cosine.csv" +) +matches = pd.read_csv(matches_path) +top_matches = matches.head(top_n_cells) + +positions = [] +tracks_tables = [] +images_plate = open_ome_zarr(data_path) +selected_channels = images_plate.channel_names +# Load matching positions +print(f"Loading positions for {len(top_matches)} FOV matches...") +matches_found = 0 +for _, pos in images_plate.positions(): + pos_name = pos.zgroup.name + pos_normalized = pos_name.lstrip("/") + + if pos_normalized in top_matches["fov_name"].values: + positions.append(pos) + matches_found += 1 + + # Get ALL tracks for this FOV to ensure TripletDataset has complete access + tracks_df = cytodtw.adata.obs[ + cytodtw.adata.obs["fov_name"] == pos_normalized + ].copy() + + if len(tracks_df) > 0: + tracks_df = tracks_df.dropna(subset=["x", "y"]) + tracks_df["x"] = tracks_df["x"].astype(int) + tracks_df["y"] = tracks_df["y"].astype(int) + tracks_tables.append(tracks_df) + + if matches_found == 1: + processing_channels = pos.channel_names + +print( + f"Loaded {matches_found} positions with {sum(len(df) for df in tracks_tables)} total tracks" +) + +dataset = TripletDataset( + positions=positions, + tracks_tables=tracks_tables, + channel_names=selected_channels, + initial_yx_patch_size=initial_yx_patch_size, + z_range=z_range, + fit=False, + predict_cells=False, + include_fov_names=None, + include_track_ids=None, + time_interval=1, + return_negative=False, +) + + +def load_images_from_triplet_dataset(fov_name, track_ids): + """Load images from TripletDataset for given FOV and track IDs.""" + matching_indices = [] + for dataset_idx in range(len(dataset.valid_anchors)): + anchor_row = dataset.valid_anchors.iloc[dataset_idx] + if anchor_row["fov_name"] == fov_name and anchor_row["track_id"] in track_ids: + matching_indices.append(dataset_idx) + + if not matching_indices: + logger.warning( + f"No matching indices found for FOV {fov_name}, tracks {track_ids}" + ) + return {} + + # Get images and create time mapping + batch_data = dataset.__getitems__(matching_indices) + images = [] + for i in range(len(matching_indices)): + img_data = { + "anchor": batch_data["anchor"][i], + "index": batch_data["index"][i], + } + images.append(img_data) + + images.sort(key=lambda x: x["index"]["t"]) + return {img["index"]["t"]: img for img in images} + + +# Use alignment_df_for_plotting directly to get full concatenated sequences +# (includes unaligned before + aligned + unaligned after) +# The function internally handles filtering to aligned cells and extracting all timepoints +concatenated_image_sequences = get_aligned_image_sequences( + cytodtw_instance=cytodtw, + df=alignment_df_for_plotting, + alignment_name=ALIGN_TYPE, + image_loader_fn=load_images_from_triplet_dataset, + max_lineages=30, +) + +figure_output_path = output_root / "figure_parts" +figure_output_path.mkdir(exist_ok=True, parents=True) + +green_cmap = Colormap("green") +magenta_cmap = Colormap("magenta") + +seq_values = list(concatenated_image_sequences.keys()) + +# Taking the first lineage for example +lineage_id = seq_values[0] + +concatenated_images = concatenated_image_sequences[lineage_id]["concatenated_images"] + +# Stack images into time series +image_stack = [] +for img_sample in concatenated_images: + if img_sample is not None: + img_tensor = img_sample["anchor"] + img_np = img_tensor.cpu().numpy() + image_stack.append(img_np) + + if len(image_stack) > 0: + time_series = np.stack(image_stack, axis=0) + n_channels = time_series.shape[1] + +infection_timepoint = ( + absolute_infection_timepoint # Use the computed value from line 103 +) +tidx_figures = [ + min(0, infection_timepoint - 5), + infection_timepoint, + min(infection_timepoint + 10, time_series.shape[0] - 1), + min(infection_timepoint + 20, time_series.shape[0] - 1), +] + +organelle_clims = (104, 383) +sensor_clims = (102, 165) +phase_clims = (-0.79, 0.6) +for tidx in tidx_figures: + # FIXME: hardcoded channel order the current dataset + img_phase = time_series[tidx, 0, 0] + img_organelle = time_series[tidx, 1, 0] + img_sensor = time_series[tidx, 2, 0] + + # Apply gamma correction first (optional) + img_sensor = adjust_gamma(img_sensor, gamma=1) + img_organelle = adjust_gamma(img_organelle, gamma=1) + img_phase = rescale_intensity(img_phase, in_range=phase_clims, out_range=(0, 1)) + + # Use in_range to specify your contrast limits + img_sensor = rescale_intensity(img_sensor, in_range=sensor_clims, out_range=(0, 1)) + img_organelle = rescale_intensity( + img_organelle, in_range=organelle_clims, out_range=(0, 1) + ) + + # Apply colormaps + img_sensor = magenta_cmap(img_sensor) + img_organelle = green_cmap(img_organelle) + img_rgb = np.clip(img_sensor[..., :3] + img_organelle[..., :3], 0, 1) + + # Fluorescence only + fig = plt.figure(figsize=(4, 4)) + plt.imshow(img_rgb) + plt.axis("off") + plt.tight_layout() + plt.savefig( + figure_output_path / f"lineage_{lineage_id}_t{tidx}_fluor.png", + dpi=300, + bbox_inches="tight", + ) + plt.show() + + # Phase only + fig = plt.figure(figsize=(4, 4)) + plt.imshow(img_phase, cmap="gray") + plt.axis("off") + plt.tight_layout() + plt.savefig( + figure_output_path / f"lineage_{lineage_id}_t{tidx}_phase.png", + dpi=300, + bbox_inches="tight", + ) + plt.show() + +# %% diff --git a/applications/pseudotime_analysis/organelle_segmentation/extract_features.py b/applications/pseudotime_analysis/organelle_segmentation/extract_features.py new file mode 100644 index 00000000..c933c5a5 --- /dev/null +++ b/applications/pseudotime_analysis/organelle_segmentation/extract_features.py @@ -0,0 +1,297 @@ +import warnings as warning + +import numpy as np +import pandas as pd +from numpy.typing import ArrayLike +from skimage import measure +from skimage.feature import graycomatrix, graycoprops + + +def extract_features_zyx( + labels_zyx: ArrayLike, + intensity_zyx: ArrayLike = None, + frangi_zyx: ArrayLike = None, + spacing: tuple = (1.0, 1.0), + properties: list = None, + extra_properties: list = None, +): + """ + Extract morphological and intensity features from labeled organelles. + + Handles both 2D (Z=1) and 3D (Z>1) data automatically + For 2D data, processes the single Z-slice. For 3D data, performs max projection + along Z axis before feature extraction. + + Based on: + Lefebvre, A.E.Y.T., Sturm, G., Lin, TY. et al. + Nellie (2025) https://doi.org/10.1038/s41592-025-02612-7 + + Parameters + ---------- + labels_zyx : ndarray + Labeled segmentation mask with shape (Z, Y, X). + Each unique integer value represents a different organelle instance. + intensity_zyx : ndarray, optional + Original intensity image with shape (Z, Y, X) for computing + intensity-based features. If None, only morphological features computed. + frangi_image : ndarray, optional + Frangi vesselness response with shape (Z, Y, X) for computing + tubularity/filament features. + spacing : tuple + Physical spacing in same units (e.g., µm). + For 2D (Z=1): (pixel_size_y, pixel_size_x) + For 3D (Z>1): (pixel_size_z, pixel_size_y, pixel_size_x) + properties : list of str, optional + List of standard regionprops features to compute. If None, uses default set. + Available: 'label', 'area', 'perimeter', 'axis_major_length', + 'axis_minor_length', 'solidity', 'extent', 'orientation', + 'equivalent_diameter_area', 'convex_area', 'eccentricity', + 'mean_intensity', 'min_intensity', 'max_intensity' + extra_properties : list of str, optional + Additional features beyond standard regionprops. Options: + - 'moments_hu': Hu moments (shape descriptors, 7 features) + - 'texture': Haralick texture features (4 features: contrast, homogeneity, energy, correlation) + - 'aspect_ratio': Major axis / minor axis ratio + - 'circularity': area / perimeter + - 'frangi_intensity': Mean/min/max/sum/std of Frangi vesselness + - 'feret_diameter_max': Maximum Feret diameter (expensive) + - 'sum_intensity': Sum of intensity values + - 'std_intensity': Standard deviation of intensity values + + Returns + ------- + features_df : pd.DataFrame + DataFrame where each row represents one labeled object with columns + for each computed feature. Always includes 'label' and 'channel' columns. + + Examples + -------- + >>> # Basic morphology only + >>> df = extract_features_zyx(labels_zyx) + + >>> # With intensity features + >>> df = extract_features_zyx(labels_zyx, intensity_zyx=intensity) + + >>> # Custom property selection + >>> df = extract_features_zyx( + ... labels_zyx, + ... intensity_zyx=intensity, + ... properties=['label', 'area', 'mean_intensity'], + ... extra_properties=['aspect_ratio', 'circularity'] + ... ) + + >>> # Full feature set including Frangi + >>> df = extract_features_zyx( + ... labels_zyx, + ... intensity_zyx=intensity, + ... frangi_image=vesselness, + ... extra_properties=['moments_hu', 'texture', 'frangi_intensity'] + ... ) + """ + + if intensity_zyx is not None: + assert intensity_zyx.shape == labels_zyx.shape, ( + "Image and labels must have same shape" + ) + + Z, _, _ = labels_zyx.shape + + # Default properties if not specified + if properties is None: + properties = [ + "label", + "area", + "perimeter", + "axis_major_length", + "axis_minor_length", + "solidity", + "extent", + "orientation", + "equivalent_diameter_area", + "convex_area", + "eccentricity", + ] + # Add intensity features if image provided + if intensity_zyx is not None: + properties.extend(["mean_intensity", "min_intensity", "max_intensity"]) + + if extra_properties is None: + extra_properties = [] + + # Determine 2D vs 3D mode and set appropriate spacing + spacing_2d = spacing if len(spacing) == 2 else spacing[-2:] + + if Z == 1: + # Squeeze Z dimension for 2D processing + labels_processed = labels_zyx[0] # Shape: (Y, X) + intensity_processed = intensity_zyx[0] if intensity_zyx is not None else None + frangi_processed = frangi_zyx[0] if frangi_zyx is not None else None + else: + # Use max projection along Z for 3D -> 2D + labels_processed = np.max(labels_zyx, axis=0) # Shape: (Y, X) + intensity_processed = ( + np.max(intensity_zyx, axis=0) if intensity_zyx is not None else None + ) + frangi_processed = ( + np.max(frangi_zyx, axis=0) if frangi_zyx is not None else None + ) + + # Check if we have any objects to process + if labels_processed.max() == 0: + return pd.DataFrame() + + # Compute base regionprops features (those that support spacing) + props_with_spacing = [p for p in properties if p not in ["moments_hu"]] + + try: + props_dict = measure.regionprops_table( + labels_processed, + intensity_image=intensity_processed, + properties=tuple(props_with_spacing), + spacing=spacing_2d, + ) + df = pd.DataFrame(props_dict) + except Exception as e: + warning.warn(f"Error computing base regionprops: {e}") + return pd.DataFrame() + + # Add Hu moments separately (without spacing) + if "moments_hu" in properties or "moments_hu" in extra_properties: + try: + hu_props = measure.regionprops_table( + labels_processed, properties=("label", "moments_hu"), spacing=(1, 1) + ) + hu_df = pd.DataFrame(hu_props) + # Rename columns to be clearer + hu_rename = {f"moments_hu-{i}": f"hu_moment_{i}" for i in range(7)} + hu_df = hu_df.rename(columns=hu_rename) + df = df.merge(hu_df, on="label", how="left") + except Exception as e: + warning.warn(f"Could not compute Hu moments: {e}") + + # Add derived metrics + if "aspect_ratio" in extra_properties: + minor_axis = df["axis_minor_length"].replace(0, 1) # Avoid division by zero + df["aspect_ratio"] = df["axis_major_length"] / minor_axis + + if "circularity" in extra_properties: + perimeter_sq = df["perimeter"] ** 2 + df["circularity"] = np.divide( + 4 * np.pi * df["area"], + perimeter_sq, + out=np.ones_like(perimeter_sq), + where=perimeter_sq != 0, + ) + + # Add expensive/iterative features + if any( + prop in extra_properties + for prop in ["texture", "feret_diameter_max", "frangi_intensity"] + ): + regions = measure.regionprops( + labels_processed, intensity_image=intensity_processed + ) + extra_features = [] + + for region in regions: + features = {"label": region.label} + + # Haralick texture features + if "texture" in extra_properties and intensity_processed is not None: + min_r, min_c, max_r, max_c = region.bbox + region_intensity = ( + intensity_processed[min_r:max_r, min_c:max_c] * region.image + ) + + # Normalize to uint8 + min_val, max_val = region_intensity.min(), region_intensity.max() + if max_val > min_val: + region_uint8 = ( + (region_intensity - min_val) / (max_val - min_val) * 255 + ).astype(np.uint8) + else: + region_uint8 = np.zeros_like(region_intensity, dtype=np.uint8) + + try: + glcm = graycomatrix( + region_uint8, + distances=[1], + angles=[0], + levels=256, + symmetric=True, + normed=True, + ) + features["texture_contrast"] = graycoprops(glcm, "contrast")[0, 0] + features["texture_homogeneity"] = graycoprops(glcm, "homogeneity")[ + 0, 0 + ] + features["texture_energy"] = graycoprops(glcm, "energy")[0, 0] + features["texture_correlation"] = graycoprops(glcm, "correlation")[ + 0, 0 + ] + except Exception: + features["texture_contrast"] = np.nan + features["texture_homogeneity"] = np.nan + features["texture_energy"] = np.nan + features["texture_correlation"] = np.nan + + # Feret diameter + if "feret_diameter_max" in extra_properties: + features["feret_diameter_max"] = region.feret_diameter_max + + # Frangi intensity features + if "frangi_intensity" in extra_properties and frangi_processed is not None: + min_r, min_c, max_r, max_c = region.bbox + region_frangi = frangi_processed[min_r:max_r, min_c:max_c][region.image] + + if region_frangi.size > 0: + features["frangi_mean_intensity"] = np.mean(region_frangi) + features["frangi_min_intensity"] = np.min(region_frangi) + features["frangi_max_intensity"] = np.max(region_frangi) + features["frangi_sum_intensity"] = np.sum(region_frangi) + features["frangi_std_intensity"] = np.std(region_frangi) + else: + features["frangi_mean_intensity"] = np.nan + features["frangi_min_intensity"] = np.nan + features["frangi_max_intensity"] = np.nan + features["frangi_sum_intensity"] = np.nan + features["frangi_std_intensity"] = np.nan + + extra_features.append(features) + + if extra_features: + extra_df = pd.DataFrame(extra_features) + df = df.merge(extra_df, on="label", how="left") + + # Add sum and std intensity if we have intensity image + if intensity_processed is not None and ( + "sum_intensity" in extra_properties or "std_intensity" in extra_properties + ): + regions = measure.regionprops( + labels_processed, intensity_image=intensity_processed + ) + sum_std_features = [] + + for region in regions: + min_r, min_c, max_r, max_c = region.bbox + region_pixels = intensity_processed[min_r:max_r, min_c:max_c][region.image] + + features = {"label": region.label} + if region_pixels.size > 0: + if "sum_intensity" in extra_properties: + features["sum_intensity"] = np.sum(region_pixels) + if "std_intensity" in extra_properties: + features["std_intensity"] = np.std(region_pixels) + else: + if "sum_intensity" in extra_properties: + features["sum_intensity"] = np.nan + if "std_intensity" in extra_properties: + features["std_intensity"] = np.nan + + sum_std_features.append(features) + + if sum_std_features: + sum_std_df = pd.DataFrame(sum_std_features) + df = df.merge(sum_std_df, on="label", how="left") + + return df diff --git a/applications/pseudotime_analysis/organelle_segmentation/segment_mito.py b/applications/pseudotime_analysis/organelle_segmentation/segment_mito.py new file mode 100644 index 00000000..4f68ca9f --- /dev/null +++ b/applications/pseudotime_analysis/organelle_segmentation/segment_mito.py @@ -0,0 +1,381 @@ +# %% code for organelle and nuclear segmentation and feature extraction + +import os +from pathlib import Path + +import napari +import numpy as np +import pandas as pd +from extract_features import ( + extract_features_zyx, +) +from iohub import open_ome_zarr +from matplotlib import pyplot as plt +from segment_organelles import ( + calculate_nellie_sigmas, + segment_zyx, +) +from skimage.exposure import rescale_intensity +from tqdm import tqdm + +os.environ["DISPLAY"] = ":1" +viewer = napari.Viewer() +# %% + +# input_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_21_A549_TOMM20_DENV/4-phenotyping/train-test/2024_11_21_A549_TOMM20_DENV.zarr" +input_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV.zarr" +) +# input_path = "/hpc/projects/intracellular_dashboard/organelle_box/2025_04_04_organelle_box_Live_dye/4-concatenate/organelle_box_live_dye_TOMM20.zarr" +input_zarr = open_ome_zarr(input_path, mode="r", layout="hcs") +in_chans = input_zarr.channel_names +organelle_channel_name = "GFP EX488 EM525-45" +# Organelle_chan = "4-MultiCam_GFP_Cy5-BSI_Express" + +output_root = ( + Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/organelle_segmentation/output" + ) + / input_path.stem +) +output_root.mkdir(parents=True, exist_ok=True) + +# %% +# Frangi parameters - WORKING configuration for 2D mitochondria segmentation +frangi_params = { + "clahe_clip_limit": 0.01, # Mild contrast enhancement (0.01-0.03 range) + "sigma_steps": 2, # Multiple scales to capture size variation + "auto_optimize_sigma": False, # Use multi-scale (max across scales) + "frangi_alpha": 0.5, # Standard tubular structure sensitivity + "frangi_beta": 0.5, # Standard blob rejection + "threshold_method": "nellie_max", # CRITICAL: Manual threshold where auto methods fail + "min_object_size": 10, # Remove small noise clusters (20-50 pixels) + "apply_morphology": False, # Connect fragmented mitochondria +} + + +position_names = [] +for ds, position in input_zarr.positions(): + position_names.append(tuple(ds.split("/"))) + + +for well_id, well_data in input_zarr.wells(): + # print(well_id) + # well_id, well_data = next(input_zarr.wells()) + well_name, well_no = well_id.split("/") + + if "B/1" not in well_id: + continue + # if well_id == 'C/4': + # print(well_name, well_no) + for pos_id, pos_data in well_data.positions(): + if pos_id != "000000": + continue + scale = ( + pos_data.metadata.multiscales[0] + .datasets[0] + .coordinate_transformations[0] + .scale + ) + pixel_size_um = scale[-1] # XY pixel size in micrometers + z_spacing_um = scale[-3] # Z spacing in micrometers + print(f" Pixel size: {pixel_size_um:.4f} µm, Z spacing: {z_spacing_um:.4f} µm") + + in_data = pos_data.data.numpy() + if in_data.shape[-3] != 1: + in_data = np.max(in_data, axis=-3, keepdims=True) + T, C, Z, Y, X = in_data.shape + print(f"Input data shape: {in_data.shape} (T={T}, C={C}, Z={Z}, Y={Y}, X={X})") + + # Extract and normalize organelle channel (keep Z dimension) + organelle_data = in_data[ + :, in_chans.index(organelle_channel_name), : + ] # (T, Z, Y, X) + organelle_data = rescale_intensity(organelle_data, out_range=(0, 1)) + print(f"Organelle data shape after extraction: {organelle_data.shape}") + + # Calculate sigma range - ADJUSTED for your pixel size + # With 0.1494 µm/pixel, mitochondria (0.3-1.0 µm diameter) = 2-7 pixels diameter + # Sigma should be ~radius/2, so for diameter 2-7px, sigma = 0.5-1.75 px + min_radius_um = 0.15 # 300 nm diameter = ~2 pixels + max_radius_um = 0.6 # 1 µm diameter = ~6.7 pixels + sigma_range = calculate_nellie_sigmas( + min_radius_um, + max_radius_um, + pixel_size_um, + num_sigma=frangi_params["sigma_steps"], + ) + + print(f"Using sigma range: {sigma_range[0]:.2f} to {sigma_range[1]:.2f} pixels") + + # Frangi filtering and segmentation + print( + f"Computing Frangi segmentation and feature extraction for {well_id}/{pos_id}..." + ) + frangi_seg_masks = [] + frangi_vesselness_maps = [] + all_features = [] + + # FIXME: temporary for testing + selected_timepoints = np.linspace(0, T - 1, 3).astype(int) + for t in tqdm(selected_timepoints, desc="Processing timepoints"): + labels, vesselness, optimal_sigma = segment_zyx( + organelle_data[t], sigma_range=sigma_range, **frangi_params + ) + frangi_seg_masks.append(labels[0]) + frangi_vesselness_maps.append(vesselness[0]) + + # Extract features from this timepoint + features_df = extract_features_zyx( + labels_zyx=labels, + intensity_zyx=organelle_data[t], + frangi_zyx=vesselness, + spacing=(pixel_size_um, pixel_size_um), + extra_properties=[ + "aspect_ratio", + "circularity", + "frangi_intensity", + # "texture", + # "moments_hu", + ], + ) + + if not features_df.empty: + features_df["well_id"] = well_id + features_df["position_id"] = pos_id + features_df["timepoint"] = t + all_features.append(features_df) + + frangi_seg_masks = np.array(frangi_seg_masks) + frangi_vesselness_maps = np.array(frangi_vesselness_maps) + + # Save combined features + if all_features: + combined_features = pd.concat(all_features, ignore_index=True) + output_csv = output_root / f"features_{well_name}_{well_no}_{pos_id}.csv" + combined_features.to_csv(output_csv, index=False) + print(f" Saved {len(combined_features)} object features to {output_csv}") + + # Convert to output format (T_actual, C=1, Z, Y, X) + T_actual = frangi_seg_masks.shape[0] + out_data = frangi_seg_masks[:, :, :].astype(np.uint32) + print(f" Processed {T_actual} timepoints, output shape: {out_data.shape}") + + position_key = (well_name, well_no, pos_id) + +# %% + +viewer.add_image(organelle_data[selected_timepoints, 0]) +viewer.add_labels(frangi_seg_masks) + + +# %% +# Plot mitochondrial dynamics: elongation and fragmentation + +if all_features: + df = combined_features + + # Aggregate features per timepoint + timepoint_summary = ( + df.groupby("timepoint") + .agg( + { + "label": "count", # Number of mitochondrial objects + "area": ["mean", "median", "sum"], # Size metrics + "aspect_ratio": ["mean", "median"], # Elongation metric + "circularity": ["mean", "median"], # Shape metric + "frangi_mean_intensity": ["mean", "median"], # Tubularity metric + # "moments_hu_1": ["mean", "median"], # Shape descriptor + # "moments_hu_2": ["mean", "median"], # Shape descriptor + # "moments_hu_3": ["mean", "median"], # Shape descriptor + # "moments_hu_4": ["mean", "median"], # Shape descriptor + # "contrast": ["mean", "median"], # Texture metric + } + ) + .reset_index() + ) + + # Flatten column names + timepoint_summary.columns = [ + "_".join(col).strip("_") for col in timepoint_summary.columns.values + ] + + # Create figure with subplots + fig, axes = plt.subplots(2, 3, figsize=(15, 10)) + fig.suptitle( + f"Mitochondrial Dynamics: {well_id}/{pos_id}", fontsize=14, fontweight="bold" + ) + + # Plot 1: Number of objects (fragmentation indicator) + ax = axes[0, 0] + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["label_count"], + marker="o", + linewidth=2, + markersize=8, + color="#1f77b4", + ) + ax.set_xlabel("Timepoint", fontsize=11) + ax.set_ylabel("Number of Objects", fontsize=11) + ax.set_title("Fragmentation (Object Count)", fontsize=12, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 2: Mean area per object + ax = axes[0, 1] + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["area_mean"], + marker="o", + linewidth=2, + markersize=8, + color="#ff7f0e", + label="Mean", + ) + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["area_median"], + marker="s", + linewidth=2, + markersize=7, + color="#d62728", + label="Median", + alpha=0.7, + ) + ax.set_xlabel("Timepoint", fontsize=11) + ax.set_ylabel("Area (µm²)", fontsize=11) + ax.set_title("Mitochondrial Size", fontsize=12, fontweight="bold") + ax.legend() + ax.grid(True, alpha=0.3) + + # Plot 3: Total area (network coverage) + ax = axes[0, 2] + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["area_sum"], + marker="o", + linewidth=2, + markersize=8, + color="#2ca02c", + ) + ax.set_xlabel("Timepoint", fontsize=11) + ax.set_ylabel("Total Area (µm²)", fontsize=11) + ax.set_title("Total Mitochondrial Coverage", fontsize=12, fontweight="bold") + ax.grid(True, alpha=0.3) + + # Plot 4: Aspect ratio (elongation) + ax = axes[1, 0] + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["aspect_ratio_mean"], + marker="o", + linewidth=2, + markersize=8, + color="#9467bd", + label="Mean", + ) + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["aspect_ratio_median"], + marker="s", + linewidth=2, + markersize=7, + color="#8c564b", + label="Median", + alpha=0.7, + ) + ax.set_xlabel("Timepoint", fontsize=11) + ax.set_ylabel("Aspect Ratio", fontsize=11) + ax.set_title("Elongation (Aspect Ratio)", fontsize=12, fontweight="bold") + ax.legend() + ax.grid(True, alpha=0.3) + + # Plot 5: Circularity + ax = axes[1, 1] + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["circularity_mean"], + marker="o", + linewidth=2, + markersize=8, + color="#e377c2", + label="Mean", + ) + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["circularity_median"], + marker="s", + linewidth=2, + markersize=7, + color="#7f7f7f", + label="Median", + alpha=0.7, + ) + ax.set_xlabel("Timepoint", fontsize=11) + ax.set_ylabel("Circularity", fontsize=11) + ax.set_title("Shape Circularity", fontsize=12, fontweight="bold") + ax.legend() + ax.grid(True, alpha=0.3) + + # Plot 6: Frangi vesselness (tubularity) + ax = axes[1, 2] + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["frangi_mean_intensity_mean"], + marker="o", + linewidth=2, + markersize=8, + color="#bcbd22", + label="Mean", + ) + ax.plot( + timepoint_summary["timepoint"], + timepoint_summary["frangi_mean_intensity_median"], + marker="s", + linewidth=2, + markersize=7, + color="#17becf", + label="Median", + alpha=0.7, + ) + ax.set_xlabel("Timepoint", fontsize=11) + ax.set_ylabel("Frangi Vesselness", fontsize=11) + ax.set_title("Tubularity (Frangi)", fontsize=12, fontweight="bold") + ax.legend() + ax.grid(True, alpha=0.3) + + plt.tight_layout() + + # Save figure + output_fig = output_root / f"dynamics_{well_name}_{well_no}_{pos_id}.png" + plt.savefig(output_fig, dpi=300, bbox_inches="tight") + print(f" Saved dynamics plot to {output_fig}") + + plt.show() + + # Print summary statistics + print(f"\n=== Mitochondrial Dynamics Summary ===") + print(f"Position: {well_id}/{pos_id}") + print(f"\nTimepoint range: {selected_timepoints[0]} -> {selected_timepoints[-1]}") + print(f"\nFragmentation (Object Count):") + print(f" Start: {timepoint_summary['label_count'].iloc[0]:.0f} objects") + print(f" End: {timepoint_summary['label_count'].iloc[-1]:.0f} objects") + print( + f" Change: {timepoint_summary['label_count'].iloc[-1] - timepoint_summary['label_count'].iloc[0]:+.0f} ({(timepoint_summary['label_count'].iloc[-1]/timepoint_summary['label_count'].iloc[0] - 1)*100:+.1f}%)" + ) + + print(f"\nElongation (Aspect Ratio):") + print(f" Start: {timepoint_summary['aspect_ratio_mean'].iloc[0]:.2f}") + print(f" End: {timepoint_summary['aspect_ratio_mean'].iloc[-1]:.2f}") + print( + f" Change: {timepoint_summary['aspect_ratio_mean'].iloc[-1] - timepoint_summary['aspect_ratio_mean'].iloc[0]:+.2f} ({(timepoint_summary['aspect_ratio_mean'].iloc[-1]/timepoint_summary['aspect_ratio_mean'].iloc[0] - 1)*100:+.1f}%)" + ) + + print(f"\nMean Object Size (Area):") + print(f" Start: {timepoint_summary['area_mean'].iloc[0]:.2f} µm²") + print(f" End: {timepoint_summary['area_mean'].iloc[-1]:.2f} µm²") + print( + f" Change: {timepoint_summary['area_mean'].iloc[-1] - timepoint_summary['area_mean'].iloc[0]:+.2f} µm² ({(timepoint_summary['area_mean'].iloc[-1]/timepoint_summary['area_mean'].iloc[0] - 1)*100:+.1f}%)" + ) + +# %% diff --git a/applications/pseudotime_analysis/organelle_segmentation/segment_mito_sc.py b/applications/pseudotime_analysis/organelle_segmentation/segment_mito_sc.py new file mode 100644 index 00000000..c1b4ce8b --- /dev/null +++ b/applications/pseudotime_analysis/organelle_segmentation/segment_mito_sc.py @@ -0,0 +1,209 @@ +# %% +import os +from logging import warning +from pathlib import Path + +import numpy as np +import pandas as pd +from extract_features import ( + extract_features_zyx, +) +from iohub import open_ome_zarr +from joblib import Parallel, delayed + + +# %% +def get_patch(data, cell_centroid, PATCH_SIZE): + """ + Extract a patch of PATCH_SIZE x PATCH_SIZE centered on the cell centroid. + If the patch would extend beyond image boundaries, slide it to fit while + keeping the centroid within the patch. + + Returns None if the image is smaller than PATCH_SIZE in any dimension. + """ + x_centroid, y_centroid = cell_centroid + height, width = data.shape + + # Check if image is large enough for patch + if height < PATCH_SIZE or width < PATCH_SIZE: + return None + + # Calculate ideal patch boundaries centered on centroid + x_start = x_centroid - PATCH_SIZE // 2 + x_end = x_centroid + PATCH_SIZE // 2 + y_start = y_centroid - PATCH_SIZE // 2 + y_end = y_centroid + PATCH_SIZE // 2 + + # Slide patch if it extends beyond left/top boundaries + if x_start < 0: + x_start = 0 + x_end = PATCH_SIZE + if y_start < 0: + y_start = 0 + y_end = PATCH_SIZE + + # Slide patch if it extends beyond right/bottom boundaries + if x_end > width: + x_end = width + x_start = width - PATCH_SIZE + if y_end > height: + y_end = height + y_start = height - PATCH_SIZE + + # Extract patch (should always be PATCH_SIZE x PATCH_SIZE now) + patch = data[int(y_start) : int(y_end), int(x_start) : int(x_end)] + return patch + + +# TODO add the intesnity zarr and parsing + + +def process_position( + well_id, + pos_id, + segmentations_zarr, + nuclear_labels_path, + patch_size, +): + """Process a single position and return the features DataFrame.""" + # Open zarr stores (each worker needs its own file handles) + input_zarr = open_ome_zarr(segmentations_zarr, mode="r") + + well_name, well_no = well_id.split("/") + + # Load position data + position = input_zarr[well_id + "/" + pos_id] + T, C, Z, Y, X = position.data.shape + channel_names = position.channel_names + scale_um = position.scale + + in_data = position.data.numpy() + + # Read the csv stored in each nucl seg zarr folder + file_name = "tracks_" + well_name + "_" + well_no + "_" + pos_id + ".csv" + nuclear_labels_csv = os.path.join( + nuclear_labels_path, well_id + "/" + pos_id + "/" + file_name + ) + nuclear_labels_df = pd.read_csv(nuclear_labels_csv) + + for chan_name in channel_names: + if "_labels" in chan_name: + labels_cidx = channel_names.index(chan_name) + if "_vesselness" in chan_name: + vesselness_cidx = channel_names.index(chan_name) + + labels = in_data[:, labels_cidx] + vesselness = in_data[:, vesselness_cidx] + + # Initialize an empty list to store values from each row of the csv + position_features = [] + for idx, row in nuclear_labels_df.iterrows(): + cell_centroid = row["x"], row["y"] + timepoint = row["t"] + + # Extract patches (will slide to fit within boundaries) + labels_patch = get_patch(labels[int(timepoint), 0], cell_centroid, patch_size) + vesselness_patch = get_patch( + vesselness[int(timepoint), 0], cell_centroid, patch_size + ) + + # Skip if patches couldn't be extracted (image too small) + if labels_patch is None or vesselness_patch is None: + continue + + label_patch = labels_patch[np.newaxis].astype(np.uint32) + vesselness_patch = vesselness_patch[np.newaxis] + + features_df = extract_features_zyx( + labels_zyx=label_patch, + intensity_zyx=None, + frangi_zyx=vesselness_patch, + spacing=(scale_um[-1], scale_um[-1]), + extra_properties=[ + "aspect_ratio", + "circularity", + "eccentricity", + "solidity", + "frangi_intensity", + "texture", + "moments_hu", + ], + ) + + if not features_df.empty: + features_df["fov_name"] = well_id + "/" + pos_id + features_df["track_id"] = row["track_id"] + features_df["t"] = timepoint + features_df["x"] = row["x"] + features_df["y"] = row["y"] + position_features.append(features_df) + + input_zarr.close() + if position_features: + # Concatenate the list of DataFrames + position_df = pd.concat(position_features, ignore_index=True) + return position_df + else: + warning(f"No valid features extracted for position {well_id}/{pos_id}.") + return pd.DataFrame() + + +# %% +if __name__ == "__main__": + segmentations_zarr = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/4-phenotyping/train-test/train_test_mito_seg_2.zarr" + ) + input_zarr = open_ome_zarr(segmentations_zarr, mode="r") + in_chans = input_zarr.channel_names + + nuclear_labels_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV/1-preprocess/label-free/3-track/2025_07_24_A549_SEC61_TOMM20_G3BP1_ZIKV_cropped.zarr" + + output_root = ( + Path( + "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/organelle_segmentation/output" + ) + / segmentations_zarr.stem + ) + output_root.mkdir(parents=True, exist_ok=True) + + output_csv = output_root / f"{segmentations_zarr.stem}_mito_features_nellie.csv" + + PATCH_SIZE = 160 + + # Number of parallel jobs - adjust based on your system + # -1 means use all available cores, or set to a specific number + n_jobs = -1 + + # Collect all (well_id, pos_id) pairs to process + with open_ome_zarr(segmentations_zarr, mode="r") as input_zarr: + positions_to_process = [] + for well_id, well_data in input_zarr.wells(): + for pos_id, pos_data in well_data.positions(): + positions_to_process.append((well_id, pos_id)) + + print( + f"Processing {len(positions_to_process)} positions using {n_jobs} parallel jobs..." + ) + + # Process positions in parallel + results = Parallel(n_jobs=n_jobs, verbose=10)( + delayed(process_position)( + well_id=well_id, + pos_id=pos_id, + segmentations_zarr=segmentations_zarr, + nuclear_labels_path=nuclear_labels_path, + patch_size=PATCH_SIZE, + ) + for well_id, pos_id in positions_to_process + ) + + # Combine all results and save + all_features = [df for df in results if not df.empty] + if all_features: + final_df = pd.concat(all_features, ignore_index=True) + # Save all at once instead of appending + output_csv.parent.mkdir(parents=True, exist_ok=True) + final_df.to_csv(output_csv, index=False) + print(f"Saved {len(final_df)} features to {output_csv}") + +# %% diff --git a/applications/pseudotime_analysis/organelle_segmentation/segment_organelles.py b/applications/pseudotime_analysis/organelle_segmentation/segment_organelles.py new file mode 100644 index 00000000..be255520 --- /dev/null +++ b/applications/pseudotime_analysis/organelle_segmentation/segment_organelles.py @@ -0,0 +1,295 @@ +import logging + +import numpy as np +from numpy.typing import ArrayLike +from skimage import measure, morphology +from skimage.exposure import equalize_adapthist +from skimage.filters import frangi, threshold_otsu, threshold_triangle + +_logger = logging.getLogger("viscy") + + +def segment_zyx( + input_zyx: ArrayLike, + clahe_kernel_size=None, + clahe_clip_limit=0.01, + sigma_range=(0.5, 3.0), + sigma_steps=5, + auto_optimize_sigma=True, + frangi_alpha=0.5, + frangi_beta=0.5, + frangi_gamma=None, + threshold_method="otsu", + min_object_size=10, + apply_morphology=True, +): + """ + Segment mitochondria from a 2D or 3D input_zyx using CLAHE preprocessing, + Frangi filtering, and connected component labeling. + + Based on: + Lefebvre, A.E.Y.T., Sturm, G., Lin, TY. et al. + Nellie (2025) https://doi.org/10.1038/s41592-025-02612-7 + + Parameters + ---------- + input_zyx : ndarray + Input image with shape (Z, Y, X) for 3D. + If 2D, uses 2D Frangi filter. If 3D with Z=1, squeezes to 2D. + clahe_kernel_size : int or None + Kernel size for CLAHE (Contrast Limited Adaptive Histogram Equalization). + If None, automatically set to max(input_zyx.shape) // 8. + clahe_clip_limit : float + Clipping limit for CLAHE, normalized between 0 and 1 (default: 0.01). + sigma_range : tuple of float + Range of sigma values to test for Frangi filter (min_sigma, max_sigma). + Represents the scale of structures to detect. + sigma_steps : int + Number of sigma values to test in the range. + auto_optimize_sigma : bool + If True, automatically finds optimal sigma by maximizing vesselness response. + If False, uses all sigmas in range for multi-scale filtering. + frangi_alpha : float + Frangi filter sensitivity to plate-j like structures (2D) or blob-like (3D). + frangi_beta : float + Frangi filter sensitivity to blob-like structures (2D) or tube-like (3D). + frangi_gamma : float or None + Frangi filter sensitivity to background noise. If None, auto-computed. + threshold_method : str + Thresholding method: 'otsu', 'triangle', 'percentile', 'manual_X'. + min_object_size : int + Minimum object size in pixels for connected components. + apply_morphology : bool + If True, applies morphological closing to connect nearby structures. + + Returns + ------- + labels : ndarray + Instance segmentation labels with same dimensionality as input. + vesselness : ndarray + Filtered vesselness response with same dimensionality as input. + optimal_sigma : float or None + The optimal sigma value if auto_optimize_sigma=True, else None. + """ + + assert input_zyx.ndim == 3 + Z, Y, X = input_zyx.shape[-3:] + + if clahe_kernel_size is None: + clahe_kernel_size = max(Z, Y, X) // 8 + + # Apply CLAHE for contrast enhancement + enhanced_zyx = equalize_adapthist( + input_zyx, + kernel_size=clahe_kernel_size, + clip_limit=clahe_clip_limit, + ) + + # Generate sigma values + sigmas = np.linspace(sigma_range[0], sigma_range[1], sigma_steps) + + # Auto-optimize sigma or use multi-scale + if auto_optimize_sigma: + optimal_sigma, vesselness = _find_optimal_sigma( + enhanced_zyx, sigmas, frangi_alpha, frangi_beta, frangi_gamma + ) + else: + optimal_sigma = None + vesselness = _multiscale_frangi( + enhanced_zyx, sigmas, frangi_alpha, frangi_beta, frangi_gamma + ) + + # Threshold the vesselness response + if threshold_method == "otsu": + threshold = threshold_otsu(vesselness) + _logger.debug(f"Otsu threshold: {threshold:.4f}") + elif threshold_method == "triangle": + threshold = threshold_triangle(vesselness) + _logger.debug(f"Triangle threshold: {threshold:.4f}") + elif threshold_method == "nellie_min": + threshold_otsu_val = threshold_otsu(vesselness) + threshold_triangle_val = threshold_triangle(vesselness) + threshold = min(threshold_otsu_val, threshold_triangle_val) + _logger.debug( + f"Nellie-min threshold: otsu={threshold_otsu_val:.4f}, triangle={threshold_triangle_val:.4f}, using min={threshold:.4f}" + ) + elif threshold_method == "nellie_max": + threshold_otsu_val = threshold_otsu(vesselness) + threshold_triangle_val = threshold_triangle(vesselness) + threshold = max(threshold_otsu_val, threshold_triangle_val) + _logger.debug( + f"Nellie-max threshold: otsu={threshold_otsu_val:.4f}, triangle={threshold_triangle_val:.4f}, using max={threshold:.4f}" + ) + elif threshold_method == "percentile": + # Use percentile-based threshold (good for sparse features) + threshold = np.percentile(vesselness[vesselness > 0], 95) # Keep top 5% + _logger.debug(f"Percentile (95th) threshold: {threshold:.4f}") + elif threshold_method.startswith("manual_"): + # Manual threshold: "manual_0.05" means threshold at 0.05 + threshold = float(threshold_method.split("_")[1]) + _logger.debug(f"Manual threshold: {threshold:.4f}") + else: + raise ValueError(f"Unknown threshold method: {threshold_method}") + + binary_mask = vesselness > threshold + + _logger.debug( + f" Selected {binary_mask.sum()} / {binary_mask.size} pixels ({100*binary_mask.sum()/binary_mask.size:.2f}%)" + ) + + # Apply morphological operations + if apply_morphology: + binary_mask = morphology.binary_closing( + binary_mask, footprint=morphology.ball(1) + ) + binary_mask = morphology.remove_small_holes(binary_mask, area_threshold=64) + + # Label connected components + labels = measure.label(binary_mask, connectivity=2) + + # Remove small objects + labels = morphology.remove_small_objects(labels, min_size=min_object_size) + + if Z == 1: + labels = labels[np.newaxis, ...] + vesselness = vesselness[np.newaxis, ...] + + return labels, vesselness, optimal_sigma + + +def _find_optimal_sigma(input_zyx, sigmas, alpha, beta, gamma): + """ + Find the optimal sigma that maximizes the vesselness response. + + Parameters + ---------- + input_zyx : ndarray + 3D input_zyx (Z, Y, X). + sigmas : array-like + Sigma values to test. + alpha, beta, gamma : float + Frangi filter parameters. + + Returns + ------- + optimal_sigma : float + The sigma with the highest mean vesselness response. + vesselness : ndarray + The vesselness response using optimal sigma. + """ + best_sigma = sigmas[0] + best_score = -np.inf + best_vesselness = None + + if input_zyx.shape[0] == 1: + input_zyx = input_zyx[0] + + for sigma in sigmas: + vessel = frangi( + input_zyx, + sigmas=[sigma], + alpha=alpha, + beta=beta, + gamma=gamma, + black_ridges=False, + ) + + # Score is the mean of top 10% vesselness values + score = np.mean( + np.partition(vessel.ravel(), -int(0.1 * vessel.size))[ + -int(0.1 * vessel.size) : + ] + ) + + if score > best_score: + best_score = score + best_sigma = sigma + best_vesselness = vessel + + if input_zyx.shape[0] == 1: + best_vesselness = best_vesselness[np.newaxis, ...] + + return best_sigma, best_vesselness + + +def _multiscale_frangi( + input_zyx, sigmas: ArrayLike, alpha: float, beta: float, gamma: float +): + """ + Apply Frangi filter at multiple scales and return the maximum response. + + Parameters + ---------- + input_zyx : ndarray + 3D input_zyx (Z, Y, X). + sigmas : array-like + Sigma values for multi-scale filtering. + alpha, beta, gamma : float + Frangi filter parameters. + + Returns + ------- + vesselness : ndarray + Maximum vesselness response across all scales. + """ + if input_zyx.shape[0] == 1: + input_zyx = input_zyx[0] + vesselness = frangi( + input_zyx, + sigmas=sigmas, + alpha=alpha, + beta=beta, + gamma=gamma, + black_ridges=False, + ) + if input_zyx.shape[0] == 1: + vesselness = vesselness[np.newaxis, ...] + return vesselness + + +def calculate_nellie_sigmas( + min_radius_um, max_radius_um, pixel_size_um, num_sigma=5, min_step_size_px=0.2 +): + """ + Calculate sigma values following Nellie's approach. + + Parameters + ---------- + min_radius_um : float + Minimum structure radius in micrometers (e.g., 0.2 for diffraction limit) + max_radius_um : float + Maximum structure radius in micrometers (e.g., 1.0 for thick tubules) + pixel_size_um : float + Pixel size in micrometers + num_sigma : int + Target number of sigma values + min_step_size_px : float + Minimum step size between sigmas in pixels + + Returns + ------- + tuple : (sigma_min, sigma_max) + Sigma range in pixels + """ + min_radius_px = min_radius_um / pixel_size_um + max_radius_px = max_radius_um / pixel_size_um + + # Nellie uses radius/2 to radius/3 as sigma + sigma_1 = min_radius_px / 2 + sigma_2 = max_radius_px / 3 + sigma_min = min(sigma_1, sigma_2) + sigma_max = max(sigma_1, sigma_2) + + # Calculate step size with minimum constraint + sigma_step_calculated = (sigma_max - sigma_min) / num_sigma + sigma_step = max(min_step_size_px, sigma_step_calculated) + + sigmas = list(np.arange(sigma_min, sigma_max + sigma_step, sigma_step)) + + _logger.debug(f" Nellie-style sigmas: {sigma_min:.3f} to {sigma_max:.3f} pixels") + _logger.debug( + f" Radius range: {min_radius_um:.3f}-{max_radius_um:.3f} µm = {min_radius_px:.2f}-{max_radius_px:.2f} pixels" + ) + _logger.debug(f" Sigma values: {[f'{s:.2f}' for s in sigmas]}") + + return (sigma_min, sigma_max) diff --git a/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py b/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py index 601bbda8..8744f0f5 100644 --- a/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py +++ b/applications/pseudotime_analysis/simulation/demo_ndim_dtw.py @@ -136,7 +136,7 @@ def compute_dtw_matrix(s1, s2): best_path: The optimal warping path """ # Compute pairwise distances between all timepoints - distance_matrix = cdist(s1, s2) + distance_matrix = cdist(s1, s2, metric="cosine") n, m = distance_matrix.shape diff --git a/viscy/representation/evaluation/pseudotime_plotting.py b/viscy/representation/evaluation/pseudotime_plotting.py new file mode 100644 index 00000000..eab09761 --- /dev/null +++ b/viscy/representation/evaluation/pseudotime_plotting.py @@ -0,0 +1,903 @@ +import ast +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import xarray as xr +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm + +from viscy.data.triplet import TripletDataModule + + +def plot_reference_aligned_average( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, + save_path: str | None = None, +) -> np.ndarray: + """Plot the reference embedding, aligned embeddings, and average aligned embedding. + + Parameters + ---------- + reference_pattern : np.ndarray + The reference pattern embeddings + top_aligned_cells : pd.DataFrame + DataFrame with alignment information + embeddings_dataset : xr.Dataset + Dataset containing embeddings + save_path : str, optional + Path to save the figure + + Returns + ------- + np.ndarray + Average aligned embeddings + """ + plt.figure(figsize=(15, 10)) + + # Get the reference pattern embeddings + reference_embeddings = reference_pattern + + # Calculate average aligned embeddings + all_aligned_embeddings = [] + for idx, row in top_aligned_cells.iterrows(): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + + # Reconstruct the concatenated lineage + lineages = [] + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Create aligned embeddings using the warping path + aligned_embeddings = np.zeros( + (len(reference_pattern), lineage_embeddings.shape[1]), + dtype=lineage_embeddings.dtype, + ) + + # Create mapping from reference to lineage + ref_to_lineage = {} + for ref_idx, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + ref_to_lineage[ref_idx] = lineage_idx + + # Fill aligned embeddings + for ref_idx in range(len(reference_pattern)): + if ref_idx in ref_to_lineage: + aligned_embeddings[ref_idx] = lineage_embeddings[ + ref_to_lineage[ref_idx] + ] + elif ref_to_lineage: + closest_ref_idx = min( + ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx) + ) + aligned_embeddings[ref_idx] = lineage_embeddings[ + ref_to_lineage[closest_ref_idx] + ] + + all_aligned_embeddings.append(aligned_embeddings) + + # Calculate average aligned embeddings + average_aligned_embeddings = np.mean(all_aligned_embeddings, axis=0) + + # Plot dimension 0 + plt.subplot(2, 1, 1) + plt.plot( + range(len(reference_embeddings)), + reference_embeddings[:, 0], + label="Reference", + color="black", + linewidth=3, + ) + + # Plot each aligned embedding + for i, aligned_embeddings in enumerate(all_aligned_embeddings): + plt.plot( + range(len(aligned_embeddings)), + aligned_embeddings[:, 0], + label=f"Aligned {i}", + alpha=0.4, + linestyle="--", + ) + + # Plot average aligned embedding + plt.plot( + range(len(average_aligned_embeddings)), + average_aligned_embeddings[:, 0], + label="Average Aligned", + color="orange", # Changed from red for colorblind friendly + linewidth=2, + ) + + plt.title("Dimension 0: Reference, Aligned, and Average Embeddings") + plt.xlabel("Reference Time Index") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + # Plot dimension 1 + plt.subplot(2, 1, 2) + plt.plot( + range(len(reference_embeddings)), + reference_embeddings[:, 1], + label="Reference", + color="black", + linewidth=3, + ) + + # Plot each aligned embedding + for i, aligned_embeddings in enumerate(all_aligned_embeddings): + plt.plot( + range(len(aligned_embeddings)), + aligned_embeddings[:, 1], + label=f"Aligned {i}", + alpha=0.4, + linestyle="--", + ) + + # Plot average aligned embedding + plt.plot( + range(len(average_aligned_embeddings)), + average_aligned_embeddings[:, 1], + label="Average Aligned", + color="orange", # Changed from red for colorblind friendly + linewidth=2, + ) + + plt.title("Dimension 1: Reference, Aligned, and Average Embeddings") + plt.xlabel("Reference Time Index") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=300) + plt.show() + + return average_aligned_embeddings + + +def plot_reference_vs_full_lineages( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, + save_path: str | None = None, +) -> None: + """Visualize where the reference pattern matches in each full lineage. + + Parameters + ---------- + reference_pattern : np.ndarray + The reference pattern embeddings + top_aligned_cells : pd.DataFrame + DataFrame with alignment information + embeddings_dataset : xr.Dataset + Dataset containing embeddings + save_path : str, optional + Path to save the figure + """ + plt.figure(figsize=(15, 15)) + + # First, plot the reference pattern for comparison + plt.subplot(len(top_aligned_cells) + 1, 2, 1) + plt.plot( + range(len(reference_pattern)), + reference_pattern[:, 0], + label="Reference Dim 0", + color="black", + linewidth=2, + ) + plt.title("Reference Pattern - Dimension 0") + plt.xlabel("Time Index") + plt.ylabel("Embedding Value") + plt.grid(True, alpha=0.3) + plt.legend() + + plt.subplot(len(top_aligned_cells) + 1, 2, 2) + plt.plot( + range(len(reference_pattern)), + reference_pattern[:, 1], + label="Reference Dim 1", + color="black", + linewidth=2, + ) + plt.title("Reference Pattern - Dimension 1") + plt.xlabel("Time Index") + plt.ylabel("Embedding Value") + plt.grid(True, alpha=0.3) + plt.legend() + + # Then plot each lineage with the matched section highlighted + for i, (_, row) in enumerate(top_aligned_cells.iterrows()): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = row["start_timepoint"] + distance = row["distance"] + + # Get the full lineage embeddings + lineage_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_ids) + ).features.values + + # Create a subplot for dimension 0 + plt.subplot(len(top_aligned_cells) + 1, 2, 2 * i + 3) + + # Plot the full lineage + plt.plot( + range(len(lineage_embeddings)), + lineage_embeddings[:, 0], + label="Full Lineage", + color="blue", + alpha=0.7, + ) + + # Highlight the matched section + matched_indices = set() + for _, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + matched_indices.add(lineage_idx) + + matched_indices = sorted(list(matched_indices)) + if matched_indices: + plt.plot( + matched_indices, + [lineage_embeddings[idx, 0] for idx in matched_indices], + "o-", + color="orange", # Changed from red for colorblind friendly + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline( + x=min(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + plt.axvline( + x=max(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + + # Add text labels + plt.text( + min(matched_indices), + min(lineage_embeddings[:, 0]), + f"Start: {min(matched_indices)}", + color="orange", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(lineage_embeddings[:, 0]), + f"End: {max(matched_indices)}", + color="orange", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) Track {track_ids[0]} - Dimension 0") + plt.xlabel("Lineage Time") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + # Create a subplot for dimension 1 + plt.subplot(len(top_aligned_cells) + 1, 2, 2 * i + 4) + + # Plot the full lineage + plt.plot( + range(len(lineage_embeddings)), + lineage_embeddings[:, 1], + label="Full Lineage", + color="blue", # Changed from green for consistency + alpha=0.7, + ) + + # Highlight the matched section + if matched_indices: + plt.plot( + matched_indices, + [lineage_embeddings[idx, 1] for idx in matched_indices], + "o-", + color="orange", # Changed from red for colorblind friendly + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline( + x=min(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + plt.axvline( + x=max(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + + # Add text labels + plt.text( + min(matched_indices), + min(lineage_embeddings[:, 1]), + f"Start: {min(matched_indices)}", + color="orange", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(lineage_embeddings[:, 1]), + f"End: {max(matched_indices)}", + color="orange", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) - Dimension 1") + plt.xlabel("Lineage Time") + plt.ylabel("Embedding Value") + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + if save_path: + plt.savefig(save_path, dpi=300) + plt.show() + + +def plot_pc_trajectories( + reference_lineage_fov: str, + reference_lineage_track_id: list[int], + reference_timepoints: list[int], + match_positions: pd.DataFrame, + embeddings_dataset: xr.Dataset, + filtered_lineages: list[tuple[str, list[int]]], + name: str, + save_path: Path, +): + """Visualize warping paths in PC space, comparing reference pattern with aligned lineages. + + Parameters + ---------- + reference_lineage_fov : str + FOV name for the reference lineage + reference_lineage_track_id : list[int] + Track ID for the reference lineage + reference_timepoints : list[int] + Time range [start, end] to use from reference + match_positions : pd.DataFrame + DataFrame with alignment matches + embeddings_dataset : xr.Dataset + Dataset with embeddings + filtered_lineages : list[tuple[str, list[int]]] + List of lineages to search in (fov_name, track_ids) + name : str + Name of the embedding model + save_path : Path + Path to save the figure + """ + # Get reference pattern + ref_pattern = None + for fov_id, track_ids in filtered_lineages: + if fov_id == reference_lineage_fov and all( + track_id in track_ids for track_id in reference_lineage_track_id + ): + ref_pattern = embeddings_dataset.sel( + sample=(fov_id, reference_lineage_track_id) + ).features.values + break + + if ref_pattern is None: + print(f"Reference pattern not found for {name}. Skipping PC trajectory plot.") + return + + ref_pattern = np.concatenate([ref_pattern]) + ref_pattern = ref_pattern[reference_timepoints[0] : reference_timepoints[1]] + + # Get top matches + top_n_aligned_cells = match_positions.head(5) + + # Compute PCA directly with sklearn + scaler = StandardScaler() + ref_pattern_scaled = scaler.fit_transform(ref_pattern) + + # Create and fit PCA model + pca_model = PCA(n_components=2, random_state=42) + pca_ref = pca_model.fit_transform(ref_pattern_scaled) + + # Create a figure to display the results + plt.figure(figsize=(15, 15)) + + # Plot the reference pattern PCs + plt.subplot(len(top_n_aligned_cells) + 1, 2, 1) + plt.plot( + range(len(pca_ref)), + pca_ref[:, 0], + label="Reference PC1", + color="black", + linewidth=2, + ) + plt.title(f"{name} - Reference Pattern - PC1") + plt.xlabel("Time Index") + plt.ylabel("PC1 Value") + plt.grid(True, alpha=0.3) + plt.legend() + + plt.subplot(len(top_n_aligned_cells) + 1, 2, 2) + plt.plot( + range(len(pca_ref)), + pca_ref[:, 1], + label="Reference PC2", + color="black", + linewidth=2, + ) + plt.title(f"{name} - Reference Pattern - PC2") + plt.xlabel("Time Index") + plt.ylabel("PC2 Value") + plt.grid(True, alpha=0.3) + plt.legend() + + # Then plot each lineage with the matched section highlighted + for i, (_, row) in enumerate(top_n_aligned_cells.iterrows()): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + if isinstance(track_ids, str): + track_ids = ast.literal_eval(track_ids) + warp_path = row["warp_path"] + if isinstance(warp_path, str): + warp_path = ast.literal_eval(warp_path) + start_time = row["start_timepoint"] + distance = row["distance"] + + # Get the full lineage embeddings + lineage_embeddings = [] + for track_id in track_ids: + try: + track_emb = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + lineage_embeddings.append(track_emb) + except KeyError: + pass + + if not lineage_embeddings: + continue + + lineage_embeddings = np.concatenate(lineage_embeddings, axis=0) + + # Transform lineage embeddings using the same PCA model + lineage_scaled = scaler.transform(lineage_embeddings) + pca_lineage = pca_model.transform(lineage_scaled) + + # Create a subplot for PC1 + plt.subplot(len(top_n_aligned_cells) + 1, 2, 2 * i + 3) + + # Plot the full lineage PC1 + plt.plot( + range(len(pca_lineage)), + pca_lineage[:, 0], + label="Full Lineage", + color="blue", + alpha=0.7, + ) + + # Highlight the matched section + matched_indices = set() + for _, query_idx in warp_path: + lineage_idx = ( + int(start_time) + query_idx if not pd.isna(start_time) else query_idx + ) + if 0 <= lineage_idx < len(pca_lineage): + matched_indices.add(lineage_idx) + + matched_indices = sorted(list(matched_indices)) + if matched_indices: + plt.plot( + matched_indices, + [pca_lineage[idx, 0] for idx in matched_indices], + "o-", + color="orange", # Changed from red for colorblind friendly + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline( + x=min(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + plt.axvline( + x=max(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + + # Add text labels + plt.text( + min(matched_indices), + min(pca_lineage[:, 0]), + f"Start: {min(matched_indices)}", + color="orange", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(pca_lineage[:, 0]), + f"End: {max(matched_indices)}", + color="orange", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) Track {track_ids[0]} - PC1") + plt.xlabel("Lineage Time") + plt.ylabel("PC1 Value") + plt.legend() + plt.grid(True, alpha=0.3) + + # Create a subplot for PC2 + plt.subplot(len(top_n_aligned_cells) + 1, 2, 2 * i + 4) + + # Plot the full lineage PC2 + plt.plot( + range(len(pca_lineage)), + pca_lineage[:, 1], + label="Full Lineage", + color="blue", # Changed from green for consistency + alpha=0.7, + ) + + # Highlight the matched section + if matched_indices: + plt.plot( + matched_indices, + [pca_lineage[idx, 1] for idx in matched_indices], + "o-", + color="orange", # Changed from red for colorblind friendly + label=f"Matched Section (DTW dist={distance:.2f})", + linewidth=2, + ) + + # Add vertical lines to mark the start and end of the matched section + plt.axvline( + x=min(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + plt.axvline( + x=max(matched_indices), color="orange", linestyle="--", alpha=0.5 + ) + + # Add text labels + plt.text( + min(matched_indices), + min(pca_lineage[:, 1]), + f"Start: {min(matched_indices)}", + color="orange", + fontsize=10, + ) + plt.text( + max(matched_indices), + min(pca_lineage[:, 1]), + f"End: {max(matched_indices)}", + color="orange", + fontsize=10, + ) + + plt.title(f"Lineage {i} ({fov_name}) - PC2") + plt.xlabel("Lineage Time") + plt.ylabel("PC2 Value") + plt.legend() + plt.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(save_path, dpi=300) + plt.close() + + +def align_and_average_embeddings( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, + use_median: bool = False, +) -> np.ndarray: + """Align embeddings from multiple lineages to a reference pattern and compute their average. + + Parameters + ---------- + reference_pattern : np.ndarray + The reference pattern embeddings + top_aligned_cells : pd.DataFrame + DataFrame with alignment information + embeddings_dataset : xr.Dataset + Dataset containing embeddings + use_median : bool + If True, use median instead of mean for averaging + + Returns + ------- + np.ndarray + The average (or median) aligned embeddings + """ + all_aligned_embeddings = [] + + for idx, row in top_aligned_cells.iterrows(): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + + # Reconstruct the concatenated lineage + lineages = [] + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Create aligned embeddings using the warping path + aligned_segment = np.zeros_like(reference_pattern) + + # Map each reference timepoint to the corresponding lineage timepoint + ref_to_lineage = {} + for ref_idx, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + ref_to_lineage[ref_idx] = lineage_idx + aligned_segment[ref_idx] = lineage_embeddings[lineage_idx] + + # Fill in missing values by using the closest available reference index + for ref_idx in range(len(reference_pattern)): + if ref_idx not in ref_to_lineage and ref_to_lineage: + closest_ref_idx = min( + ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx) + ) + aligned_segment[ref_idx] = lineage_embeddings[ + ref_to_lineage[closest_ref_idx] + ] + + all_aligned_embeddings.append(aligned_segment) + + all_aligned_embeddings = np.array(all_aligned_embeddings) + + # Compute average or median + if use_median: + return np.median(all_aligned_embeddings, axis=0) + else: + return np.mean(all_aligned_embeddings, axis=0) + + +def align_image_stacks( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + input_data_path: Path, + tracks_path: Path, + source_channels: list[str], + yx_patch_size: tuple[int, int] = (192, 192), + z_range: tuple[int, int] = (0, 1), + view_ref_sector_only: bool = True, + napari_viewer=None, +) -> tuple[list, list]: + """Align image stacks from multiple lineages to a reference pattern. + + Parameters + ---------- + reference_pattern : np.ndarray + The reference pattern embeddings + top_aligned_cells : pd.DataFrame + DataFrame with alignment information + input_data_path : Path + Path to the input data + tracks_path : Path + Path to the tracks data + source_channels : list[str] + List of channels to include + yx_patch_size : tuple[int, int] + Patch size for images + z_range : tuple[int, int] + Z-range to include + view_ref_sector_only : bool + If True, only show the section that matches the reference pattern + napari_viewer : optional + Optional napari viewer for visualization + + Returns + ------- + tuple[list, list] + Tuple of (all_lineage_images, all_aligned_stacks) + """ + all_lineage_images = [] + all_aligned_stacks = [] + + for idx, row in tqdm( + top_aligned_cells.iterrows(), + total=len(top_aligned_cells), + desc="Aligning images", + ): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + + print(f"Aligning images for {fov_name} with track ids: {track_ids}") + data_module = TripletDataModule( + data_path=input_data_path, + tracks_path=tracks_path, + source_channel=source_channels, + z_range=z_range, + initial_yx_patch_size=yx_patch_size, + final_yx_patch_size=yx_patch_size, + batch_size=1, + num_workers=12, + predict_cells=True, + include_fov_names=[fov_name] * len(track_ids), + include_track_ids=track_ids, + ) + data_module.setup("predict") + + # Get the images for the lineage + lineage_images = [] + for batch in data_module.predict_dataloader(): + image = batch["anchor"].numpy()[0] + lineage_images.append(image) + + lineage_images = np.array(lineage_images) + all_lineage_images.append(lineage_images) + print(f"Lineage images shape: {np.array(lineage_images).shape}") + + # Create an aligned stack based on the warping path + if view_ref_sector_only: + aligned_stack = np.zeros( + (len(reference_pattern),) + lineage_images.shape[-4:], + dtype=lineage_images.dtype, + ) + + # Map each reference timepoint to the corresponding lineage timepoint + for ref_idx in range(len(reference_pattern)): + # Find matches in warping path for this reference index + matches = [(i, q) for i, q in warp_path if i == ref_idx] + + if matches: + # Get the corresponding lineage timepoint (first match if multiple) + print(f"Found match for ref idx: {ref_idx}") + match = matches[0] + query_idx = match[1] + lineage_idx = int(start_time + query_idx) + print( + f"Lineage index: {lineage_idx}, start time: {start_time}, query idx: {query_idx}, ref idx: {ref_idx}" + ) + # Copy the image if it's within bounds + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Find nearest valid timepoint if out of bounds + nearest_idx = min(max(0, lineage_idx), len(lineage_images) - 1) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + else: + # If no direct match, find closest reference timepoint in warping path + print(f"No match found for ref idx: {ref_idx}") + all_ref_indices = [i for i, _ in warp_path] + if all_ref_indices: + closest_ref_idx = min( + all_ref_indices, key=lambda x: abs(x - ref_idx) + ) + closest_matches = [ + (i, q) for i, q in warp_path if i == closest_ref_idx + ] + + if closest_matches: + closest_query_idx = closest_matches[0][1] + lineage_idx = int(start_time + closest_query_idx) + + if 0 <= lineage_idx < len(lineage_images): + aligned_stack[ref_idx] = lineage_images[lineage_idx] + else: + # Bound to valid range + nearest_idx = min( + max(0, lineage_idx), len(lineage_images) - 1 + ) + aligned_stack[ref_idx] = lineage_images[nearest_idx] + + all_aligned_stacks.append(aligned_stack) + if napari_viewer: + napari_viewer.add_image( + aligned_stack, + name=f"Aligned_{fov_name}_track_{track_ids[0]}", + channel_axis=1, + ) + else: + # View the whole lineage shifted by the start time + start_idx = int(start_time) + aligned_stack = lineage_images[start_idx:] + all_aligned_stacks.append(aligned_stack) + if napari_viewer: + napari_viewer.add_image( + aligned_stack, + name=f"Aligned_{fov_name}_track_{track_ids[0]}", + channel_axis=1, + ) + + return all_lineage_images, all_aligned_stacks + + +def create_consensus_embedding( + reference_pattern: np.ndarray, + top_aligned_cells: pd.DataFrame, + embeddings_dataset: xr.Dataset, +) -> np.ndarray: + """Create a consensus embedding from multiple aligned embeddings using weighted approach. + + Parameters + ---------- + reference_pattern : np.ndarray + The reference pattern embeddings + top_aligned_cells : pd.DataFrame + DataFrame with alignment information + embeddings_dataset : xr.Dataset + Dataset containing embeddings + + Returns + ------- + np.ndarray + The consensus embedding + """ + all_aligned_embeddings = [] + distances = [] + + for idx, row in top_aligned_cells.iterrows(): + fov_name = row["fov_name"] + track_ids = row["track_ids"] + warp_path = row["warp_path"] + start_time = int(row["start_timepoint"]) + distance = row["distance"] + + # Get lineage embeddings + lineages = [] + for track_id in track_ids: + track_embeddings = embeddings_dataset.sel( + sample=(fov_name, track_id) + ).features.values + lineages.append(track_embeddings) + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Create aligned embeddings using the warping path + aligned_segment = np.zeros_like(reference_pattern) + + # Map each reference timepoint to the corresponding lineage timepoint + ref_to_lineage = {} + for ref_idx, query_idx in warp_path: + lineage_idx = int(start_time + query_idx) + if 0 <= lineage_idx < len(lineage_embeddings): + ref_to_lineage[ref_idx] = lineage_idx + aligned_segment[ref_idx] = lineage_embeddings[lineage_idx] + + # Fill in missing values + for ref_idx in range(len(reference_pattern)): + if ref_idx not in ref_to_lineage and ref_to_lineage: + closest_ref_idx = min( + ref_to_lineage.keys(), key=lambda x: abs(x - ref_idx) + ) + aligned_segment[ref_idx] = lineage_embeddings[ + ref_to_lineage[closest_ref_idx] + ] + + all_aligned_embeddings.append(aligned_segment) + distances.append(distance) + + all_aligned_embeddings = np.array(all_aligned_embeddings) + + # Convert distances to weights (smaller distance = higher weight) + weights = 1.0 / ( + np.array(distances) + 1e-10 + ) # Add small epsilon to avoid division by zero + weights = weights / np.sum(weights) # Normalize weights + + # Create weighted consensus + consensus_embedding = np.zeros_like(reference_pattern) + for i, aligned_embedding in enumerate(all_aligned_embeddings): + consensus_embedding += weights[i] * aligned_embedding + + return consensus_embedding diff --git a/viscy/representation/pseudotime.py b/viscy/representation/pseudotime.py new file mode 100644 index 00000000..5a55a053 --- /dev/null +++ b/viscy/representation/pseudotime.py @@ -0,0 +1,3676 @@ +import logging +from pathlib import Path +from typing import Literal, Tuple + +import anndata as ad +import numpy as np +import pandas as pd +from numpy.typing import ArrayLike +from scipy.spatial.distance import cdist +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +from tqdm import tqdm +from typing_extensions import TypedDict + +_logger = logging.getLogger("lightning.pytorch") + + +# Utility functions +def filter_tracks_by_fov_and_length( + df: pd.DataFrame, + fov_pattern: str | list[str], + min_timepoints: int, +) -> pd.DataFrame: + """ + Filter dataframe by FOV pattern(s) and minimum timepoints per track. + + Convenience function for filtering tracking data by field of view pattern + and track length. Useful for separating experimental conditions (e.g., + infected vs uninfected) and ensuring sufficient temporal coverage. + + Parameters + ---------- + df : pd.DataFrame + Dataframe with 'fov_name', 'track_id', and 't' columns + fov_pattern : str or list[str] + Pattern(s) to match FOV names (used with str.contains()). + If list, matches any of the patterns (OR logic). + min_timepoints : int + Minimum number of timepoints required per track + + Returns + ------- + pd.DataFrame + Filtered dataframe containing only tracks that: + 1. Match the FOV pattern(s) + 2. Have at least min_timepoints timepoints + + Examples + -------- + >>> # Single pattern + >>> infected_df = filter_tracks_by_fov_and_length( + ... master_df, fov_pattern="B/2", min_timepoints=20 + ... ) + >>> + >>> # Multiple patterns (OR logic) + >>> infected_df = filter_tracks_by_fov_and_length( + ... master_df, fov_pattern=["B/2", "C/2"], min_timepoints=20 + ... ) + """ + # Handle single pattern or multiple patterns + if isinstance(fov_pattern, str): + fov_patterns = [fov_pattern] + else: + fov_patterns = fov_pattern + + # Filter by FOV pattern(s) - OR logic + fov_mask = pd.Series(False, index=df.index) + for pattern in fov_patterns: + fov_mask |= df["fov_name"].str.contains(pattern) + + fov_filtered = df[fov_mask].copy() + + if len(fov_filtered) == 0: + _logger.warning(f"No FOVs matched pattern(s): {fov_patterns}") + return pd.DataFrame() + + # Count timepoints per track + track_lengths = fov_filtered.groupby(["fov_name", "track_id"]).size() + valid_tracks = track_lengths[track_lengths >= min_timepoints].index + + # Filter to valid tracks only + result = fov_filtered[ + fov_filtered.set_index(["fov_name", "track_id"]).index.isin(valid_tracks) + ].reset_index(drop=True) + + pattern_str = fov_patterns if len(fov_patterns) > 1 else fov_patterns[0] + _logger.info( + f"Filtered by pattern '{pattern_str}' with min_timepoints={min_timepoints}: " + f"{len(valid_tracks)} tracks from {fov_filtered['fov_name'].nunique()} FOVs" + ) + + return result + + +# Annotated Example TypeDict +class AnnotatedSample(TypedDict): + fov_name: str + track_id: int | list[int] + timepoints: tuple[int, int] + annotations: dict | list + weight: float + + +class DtwSample(TypedDict, total=False): + pattern: np.ndarray + annotations: list[str, int, float] | None + distance: float + skewness: float + warping_path: list[tuple[int, int]] + metadata: dict + + +class CytoDtw: + def __init__(self, adata: ad.AnnData): + """ + DTW for Dynamic Cell Embeddings + + Parameters + ---------- + adata : ad.AnnData + AnnData object containing: + - X: features/embeddings + - obs: tracking info (fov_name, track_id, t, x, y, etc.) and annotations + - obsm: multi-dimensional embeddings (X_PCA, X_UMAP, etc.) + """ + self.adata = adata + self.lineages = None + self.consensus_data = None + self.reference_patterns = None + + def _validate_input(self): + raise NotImplementedError("Validation of input not implemented") + + def save_consensus(self, path: str): + """Save consensus pattern to a file.""" + import pickle + + if self.consensus_data is None: + raise ValueError("Consensus pattern not found") + with open(path, "wb") as f: + pickle.dump(self.consensus_data, f) + + def load_consensus(self, path: str): + """Load consensus pattern from a file.""" + import pickle + + with open(path, "rb") as f: + self.consensus_data = pickle.load(f) + + def get_lineages(self, min_timepoints: int = 15) -> list[tuple[str, list[int]]]: + """Get identified lineages with specified minimum timepoints.""" + return self._identify_lineages(min_timepoints) + + def get_track_statistics( + self, + lineages: list[tuple[str, list[int]]] = None, + min_timepoints: int = 15, + per_fov: bool = False, + ) -> pd.DataFrame: + """Get statistics for tracks/lineages. + + Parameters + ---------- + lineages : list[tuple[str, list[int]]], optional + List of (fov_name, track_ids) to analyze. If None, uses all lineages. + min_timepoints : int + Minimum timepoints required for lineage identification if lineages is None + per_fov : bool + If True, return aggregated statistics per FOV. If False, return per-lineage statistics. + + Returns + ------- + pd.DataFrame + If per_fov=False: DataFrame with per-lineage statistics: + - fov_name: FOV identifier + - track_ids: List of track IDs in lineage + - n_tracks: Number of tracks in lineage + - total_timepoints: Total timepoints across all tracks + - mean_timepoints_per_track: Average timepoints per track + - std_timepoints_per_track: Standard deviation of timepoints per track + - min_t: Earliest timepoint + - max_t: Latest timepoint + + If per_fov=True: DataFrame with aggregated per-FOV statistics: + - fov_name: FOV identifier + - n_lineages: Number of lineages in FOV + - total_tracks: Total number of tracks across all lineages + - mean_tracks_per_lineage: Average tracks per lineage + - std_tracks_per_lineage: Standard deviation of tracks per lineage + - mean_total_timepoints: Average total timepoints per lineage + - std_total_timepoints: Standard deviation of total timepoints + - mean_timepoints_per_track: Average timepoints per track across all tracks + - std_timepoints_per_track: Standard deviation across all tracks + """ + if lineages is None: + lineages = self.get_lineages(min_timepoints) + + stats_list = [] + for fov_name, track_ids in lineages: + lineage_rows = self.adata.obs[ + (self.adata.obs["fov_name"] == fov_name) + & (self.adata.obs["track_id"].isin(track_ids)) + ] + + total_timepoints = len(lineage_rows) + timepoints_per_track = [] + + for track_id in track_ids: + track_rows = lineage_rows[lineage_rows["track_id"] == track_id] + timepoints_per_track.append(len(track_rows)) + + stats_list.append( + { + "fov_name": fov_name, + "track_ids": track_ids, + "n_tracks": len(track_ids), + "total_timepoints": total_timepoints, + "mean_timepoints_per_track": np.mean(timepoints_per_track), + "std_timepoints_per_track": np.std(timepoints_per_track), + "timepoints_per_track": timepoints_per_track, + "min_t": lineage_rows["t"].min(), + "max_t": lineage_rows["t"].max(), + } + ) + + df = pd.DataFrame(stats_list) + + if not per_fov: + return df.drop(columns=["timepoints_per_track"]) + + # Aggregate by FOV + fov_stats = [] + for fov_name in df["fov_name"].unique(): + fov_df = df[df["fov_name"] == fov_name] + all_timepoints = [ + tp for tps in fov_df["timepoints_per_track"] for tp in tps + ] + + fov_stats.append( + { + "fov_name": fov_name, + "n_lineages": len(fov_df), + "total_tracks": fov_df["n_tracks"].sum(), + "mean_tracks_per_lineage": fov_df["n_tracks"].mean(), + "std_tracks_per_lineage": fov_df["n_tracks"].std(), + "mean_total_timepoints": fov_df["total_timepoints"].mean(), + "std_total_timepoints": fov_df["total_timepoints"].std(), + "mean_timepoints_per_track": np.mean(all_timepoints), + "std_timepoints_per_track": np.std(all_timepoints), + } + ) + + return pd.DataFrame(fov_stats) + + def _identify_lineages( + self, min_timepoints: int = 15 + ) -> list[tuple[str, list[int]]]: + """Auto-identify lineages from the data.""" + # Use parent_track_id if available for proper lineage identification + if "parent_track_id" in self.adata.obs.columns: + all_lineages = identify_lineages(self.adata.obs, return_both_branches=False) + else: + # Fallback: treat each track as individual lineage + all_lineages = [] + for (fov, track_id), group in self.adata.obs.groupby( + ["fov_name", "track_id"] + ): + all_lineages.append((fov, [track_id])) + + # Filter lineages by total timepoints across all tracks in lineage + filtered_lineages = [] + for fov_id, track_ids in all_lineages: + lineage_rows = self.adata.obs[ + (self.adata.obs["fov_name"] == fov_id) + & (self.adata.obs["track_id"].isin(track_ids)) + ] + total_timepoints = len(lineage_rows) + if total_timepoints >= min_timepoints: + filtered_lineages.append((fov_id, track_ids)) + self.lineages = filtered_lineages + return self.lineages + + def get_reference_pattern( + self, + fov_name: str, + track_id: int | list[int], + timepoints: tuple[int, int], + reference_type: str = "features", + ) -> np.ndarray: + """ + Extract reference pattern from embeddings. + + Parameters + ---------- + fov_name : str + FOV identifier + track_id : int | list[int] + Track ID(s) to use as reference + timepoints : tuple[int, int] + Start and end timepoints (start, end) + reference_type : str + Type of embedding to use ('features' for X, or obsm key like 'X_PCA') + Returns + ------- + np.ndarray + Reference pattern embeddings + """ + if isinstance(track_id, int): + track_id = [track_id] + + reference_embeddings = [] + for tid in track_id: + # Filter by fov_name and track_id + mask = (self.adata.obs["fov_name"] == fov_name) & ( + self.adata.obs["track_id"] == tid + ) + track_data = self.adata[mask] + + # Sort by timepoint to ensure correct order + time_order = np.argsort(track_data.obs["t"].values) + track_data = track_data[time_order] + + if reference_type == "features": + track_emb = track_data.X + else: + # Assume it's an obsm key + track_emb = track_data.obsm[reference_type] + + # Handle 1D arrays (PC components) by reshaping to (time, 1) + if track_emb.ndim == 1: + track_emb = track_emb.reshape(-1, 1) + + reference_embeddings.append(track_emb) + + reference_pattern = np.concatenate(reference_embeddings, axis=0) + + start_t, end_t = timepoints + reference_pattern = reference_pattern[start_t:end_t] + + return reference_pattern + + def get_matches( + self, + reference_pattern: np.ndarray = None, + lineages: list[tuple[str, list[int]]] = None, + window_step: int = 5, + num_candidates: int | None = None, + max_distance: float = float("inf"), + max_skew: float = 0.8, + method: str = "bernd_clifford", + normalize: bool = True, + metric: str = "euclidean", + reference_type: str = "features", + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, + save_path: str | Path = None, + ) -> pd.DataFrame: + """Find pattern matches across lineages using DTW. + + Parameters + ---------- + reference_pattern : np.ndarray + Reference pattern to search for + lineages : list[tuple[str, list[int]]], optional + List of (fov_name, track_ids) to search in. If None, searches all. + window_step : int + Step size for sliding window search + num_candidates : int + Number of best candidates per lineage + max_distance : float + Maximum DTW distance threshold + max_skew : float + Maximum path skewness (0-1) + method : str + DTW method ('bernd_clifford' or 'dtai') + normalize : bool + Whether to normalize DTW distance by path length + metric : str + Distance metric for embeddings + save_path : str | Path, optional + Path to save results CSV + + Returns + ------- + pd.DataFrame + Match results with distances and warping paths + """ + if reference_pattern is None: + reference_pattern = self.consensus_data["pattern"] + if lineages is None: + # FIXME: Auto-identify lineages from tracking data + lineages = self.get_lineages() + + return find_pattern_matches( + reference_pattern=reference_pattern, + filtered_lineages=lineages, + adata=self.adata, + window_step=window_step, + num_candidates=num_candidates, + max_distance=max_distance, + max_skew=max_skew, + method=method, + normalize=normalize, + metric=metric, + reference_type=reference_type, + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + save_path=save_path, + ) + + def create_consensus_reference_pattern( + self, + annotated_samples: list[AnnotatedSample], + reference_selection: str = "median_length", + aggregation_method: str = "mean", + annotations_name: str = "annotations", + reference_type: str = "features", + **kwargs, + ) -> DtwSample: + """ + Create consensus reference pattern from annotated samples. + + This method takes one or more annotated cell examples and creates a + consensus reference pattern. For single annotations, uses it directly. For + multiple annotations, aligns them with DTW and aggregates. + + Parameters + ---------- + annotated_samples : list[AnnotatedSample] + List of annotated examples (minimum 1 required) + reference_selection : str + mode of selection of reference: "median_length", "first", "longest", "shortest" + aggregation_method : str + mode of aggregation: "mean", "median", "weighted_mean" + annotations_name : str + name of the annotations column + reference_type : str + Type of embedding to use ("features", "projections", "PC1", etc.) + Returns + ------- + DtwSample + DtwSample containing: + - 'pattern': np.ndarray - The consensus embedding pattern + - 'annotations': list - Consensus annotations (if available) + - 'metadata': dict - Information about consensus creation including method used + - 'distance': float - DTW distance + - 'skewness': float - Path skewness + - 'warping_path': list - DTW warping path + + Examples + -------- + >>> analyzer = CytoDtw("embeddings.zarr") + >>> examples = [ + ... AnnotatedSample( + ... 'fov_name': '/FOV1', 'track_id': 129, + ... 'timepoints': (8, 70), 'annotations': ['G1', 'S', 'G2', ...] + ... ), + ... AnnotatedSample( + ... 'fov_name': '/FOV2', 'track_id': 45, + ... 'timepoints': (5, 55), 'weight': 1.2 + ... ) + ... ] + >>> consensus = analyzer.create_consensus_reference_pattern(examples) + """ + if not annotated_samples: + raise ValueError("At least one annotated example is required") + + # Extract embedding patterns from each example + extracted_patterns = {} + for i, example in enumerate(annotated_samples): + pattern = self.get_reference_pattern( + fov_name=example["fov_name"], + track_id=example["track_id"], + timepoints=example["timepoints"], + reference_type=reference_type, + ) + + extracted_patterns[ + f"example_{example['fov_name']}_{example['track_id']}" + ] = { + "pattern": pattern, + "annotations": example.get(annotations_name, None), + "weight": example.get("weight", 1.0), + "source": example, + } + + self.consensus_data = create_consensus_from_patterns( + patterns=extracted_patterns, + reference_selection=reference_selection, + aggregation_method=aggregation_method, + **kwargs, + ) + return self.consensus_data + + def create_alignment_dataframe( + self, + top_matches: pd.DataFrame, + consensus_lineage: np.ndarray, + alignment_name: str = "cell_division", + reference_type: str = "features", + ) -> pd.DataFrame: + """ + Create alignment DataFrame that: + 1. Preserves lineage relationships (lineage_id groups related tracks) + 2. Supports multiple alignment types (cell_division, apoptosis, migration, etc.) + 3. Stores computed features once, reused across alignments + 4. Maintains original timepoint relationships + + Parameters + ---------- + top_matches : pd.DataFrame + DTW match results + consensus_lineage : np.ndarray + Consensus pattern for this alignment + alignment_name : str + Name for this alignment type (e.g., "cell_division", "apoptosis") + reference_type : str + Feature type to use + + Returns + ------- + pd.DataFrame + DataFrame with lineage preservation and extensible alignments + """ + alignment_data = [] + track_lineage_mapping = {} + + available_obsm = list(self.adata.obsm.keys()) + has_pca_obsm = "X_PCA" in available_obsm + pc_components = [] + + if has_pca_obsm: + n_pca_components = self.adata.obsm["X_PCA"].shape[1] + pc_components = [f"PC{i + 1}" for i in range(n_pca_components)] + else: + pc_components = [ + col + for col in self.adata.obs.columns + if col.startswith("PC") and col[2:].isdigit() + ] + pc_components.sort(key=lambda x: int(x[2:])) # Sort PC1, PC2, PC3, etc. + + has_pc_components = len(pc_components) > 0 + + lineage_counter = 0 + for idx, match_row in top_matches.iterrows(): + fov_name = match_row["fov_name"] + track_ids = match_row["track_ids"] + + lineage_id = lineage_counter + lineage_counter += 1 + + for track_id in track_ids: + track_lineage_mapping[(fov_name, track_id)] = lineage_id + + pca = None + scaler = None + consensus_pca = None + + if not has_pc_components: + all_embeddings = [] + + for idx, match_row in top_matches.iterrows(): + fov_name = match_row["fov_name"] + track_ids = match_row["track_ids"] + + for track_id in track_ids: + try: + mask = (self.adata.obs["fov_name"] == fov_name) & ( + self.adata.obs["track_id"] == track_id + ) + track_data = self.adata[mask] + + time_order = np.argsort(track_data.obs["t"].values) + track_data = track_data[time_order] + + if reference_type == "features": + track_embeddings = track_data.X + else: + track_embeddings = track_data.obsm[reference_type] + + all_embeddings.append(track_embeddings) + except KeyError: + continue + all_embeddings.append(consensus_lineage) + all_concat = np.vstack(all_embeddings) + n_components = min(8, all_concat.shape[1]) + scaler = StandardScaler() + scaled_all = scaler.fit_transform(all_concat) + pca = PCA(n_components=n_components) + pca_all = pca.fit_transform(scaled_all) + + consensus_pca = pca_all[-len(consensus_lineage) :] + + for idx, match_row in top_matches.iterrows(): + fov_name = match_row["fov_name"] + track_ids = match_row["track_ids"] + warp_path = match_row["warp_path"] + dtw_distance = match_row.get("distance", np.nan) + + # Create mapping from query timepoint to consensus timepoint + query_to_consensus = {} + for consensus_idx, query_timepoint in warp_path: + query_to_consensus[query_timepoint] = consensus_idx + + # Process each track in this lineage + for track_id in track_ids: + try: + # Filter by fov_name and track_id + mask = (self.adata.obs["fov_name"] == fov_name) & ( + self.adata.obs["track_id"] == track_id + ) + track_data = self.adata[mask] + + # Sort by timepoint to ensure correct order + time_order = np.argsort(track_data.obs["t"].values) + track_data = track_data[time_order] + + track_timepoints = track_data.obs["t"].values + + # Get PC features - either from obsm/obs or computed PCA + pc_values = {} + if has_pc_components: + if has_pca_obsm: + # Extract from X_PCA obsm + pca_data = track_data.obsm["X_PCA"] + for i, pc_name in enumerate(pc_components): + pc_values[pc_name] = pca_data[:, i] + else: + # Extract from individual PC columns in obs + for pc_name in pc_components: + pc_values[pc_name] = track_data.obs[pc_name].values + else: + # Use computed PCA + if reference_type == "features": + track_embeddings = track_data.X + else: + track_embeddings = track_data.obsm[reference_type] + scaled_embeddings = scaler.transform(track_embeddings) + track_pca = pca.transform(scaled_embeddings) + # Create PC values for all computed components + for i in range(n_components): + pc_name = f"PC{i + 1}" + pc_values[pc_name] = track_pca[:, i] + + # Get lineage ID for this track + lineage_id = track_lineage_mapping.get((fov_name, track_id), -1) + + # Create row for each timepoint + for i, t in enumerate(track_timepoints): + # Get spatial coordinates from track_data.obs (which is already filtered) + obs_row = track_data.obs.iloc[i] + x_coord = obs_row.get("x", np.nan) + y_coord = obs_row.get("y", np.nan) + + # Determine alignment status for this specific alignment type + is_aligned = t in query_to_consensus + consensus_mapping = query_to_consensus.get(t, np.nan) + + # Create dynamic column names based on alignment_name + row_data = { + # Core tracking info (preserves lineage relationships) + "fov_name": fov_name, + "lineage_id": lineage_id, + "track_id": track_id, + "t": t, + "x": x_coord, + "y": y_coord, + # Alignment-specific columns (dynamic based on alignment_name) + f"dtw_{alignment_name}_consensus_mapping": consensus_mapping, + f"dtw_{alignment_name}_aligned": is_aligned, + f"dtw_{alignment_name}_distance": dtw_distance, + f"dtw_{alignment_name}_match_rank": idx, + } + + # Add all PC components dynamically + for pc_name, pc_vals in pc_values.items(): + row_data[pc_name] = pc_vals[i] + alignment_data.append(row_data) + + except KeyError as e: + _logger.warning( + f"Could not find track {track_id} in FOV {fov_name}: {e}" + ) + continue + + consensus_pc_values = {} + if has_pc_components: + for pc_name in pc_components: + consensus_pc_values[pc_name] = [np.nan] * len(consensus_lineage) + else: + for i in range(n_components): + pc_name = f"PC{i + 1}" + consensus_pc_values[pc_name] = consensus_pca[:, i] + + for i in range(len(consensus_lineage)): + consensus_row = { + "fov_name": "consensus", + "lineage_id": -1, + "track_id": -1, + "t": i, + "x": np.nan, + "y": np.nan, + f"dtw_{alignment_name}_consensus_mapping": i, # Maps to itself + f"dtw_{alignment_name}_aligned": True, + f"dtw_{alignment_name}_distance": np.nan, + f"dtw_{alignment_name}_match_rank": -1, + } + + for pc_name, pc_vals in consensus_pc_values.items(): + consensus_row[pc_name] = pc_vals[i] + + alignment_data.append(consensus_row) + + return pd.DataFrame(alignment_data) + + def get_concatenated_sequences( + self, + df: pd.DataFrame, + alignment_name: str = "cell_division", + feature_columns: list[str] = None, + max_lineages: int = None, + ) -> dict: + """ + Extract concatenated [unaligned_before + aligned + unaligned_after] sequences from enhanced DataFrame. + + This is a shared method used by both plotting and image sequence functions. + + Parameters + ---------- + df : pd.DataFrame + Enhanced DataFrame with alignment information + alignment_name : str + Name of alignment type (e.g., "cell_division") + feature_columns : list[str], optional + Feature columns to extract. If None, extracts PC components only. + max_lineages : int, optional + Maximum number of lineages to process + + Returns + ------- + dict + Dictionary mapping lineage_id to: + - 'unaligned_before_data': dict of arrays/dicts for timepoints before aligned region + - 'aligned_data': dict of consensus-length aligned arrays/dicts + - 'unaligned_after_data': dict of arrays/dicts for timepoints after aligned region + - 'metadata': lineage metadata (fov_name, track_ids, etc.) + """ + aligned_col = f"dtw_{alignment_name}_aligned" + mapping_col = f"dtw_{alignment_name}_consensus_mapping" + distance_col = f"dtw_{alignment_name}_distance" + + if aligned_col not in df.columns: + _logger.error(f"Alignment '{alignment_name}' not found in DataFrame") + return {} + + consensus_df = df[df["lineage_id"] == -1].sort_values("t").copy() + lineages = df[df["lineage_id"] != -1]["lineage_id"].unique() + + if max_lineages is not None: + lineages = lineages[:max_lineages] + + if consensus_df.empty: + _logger.error("No consensus found in DataFrame") + return {} + + consensus_length = len(consensus_df) + concatenated_sequences = {} + + for lineage_id in lineages: + lineage_df = df[df["lineage_id"] == lineage_id].copy().sort_values("t") + if lineage_df.empty: + continue + + aligned_rows = lineage_df[lineage_df[aligned_col]].copy() + + # Extract unaligned timepoints BEFORE and AFTER the aligned portion + if not aligned_rows.empty: + min_aligned_t = aligned_rows["t"].min() + max_aligned_t = aligned_rows["t"].max() + unaligned_before_rows = lineage_df[ + (~lineage_df[aligned_col]) & (lineage_df["t"] < min_aligned_t) + ].copy() + unaligned_after_rows = lineage_df[ + (~lineage_df[aligned_col]) & (lineage_df["t"] > max_aligned_t) + ].copy() + else: + # No aligned portion - treat all as unaligned_before + unaligned_before_rows = lineage_df[~lineage_df[aligned_col]].copy() + unaligned_after_rows = pd.DataFrame() + + # Create consensus-length aligned portion using mapping + # Make dict to map pseudotime to real time + aligned_portion = {} + for _, row in aligned_rows.iterrows(): + consensus_idx = row[mapping_col] + if not pd.isna(consensus_idx): + consensus_idx = int(consensus_idx) + if 0 <= consensus_idx < consensus_length: + row_dict = {"t": row["t"], "row": row} + if feature_columns: + row_dict.update({col: row[col] for col in feature_columns}) + aligned_portion[consensus_idx] = row_dict + + # Fill gaps in aligned portion + filled_aligned = {} + if aligned_portion: + for i in range(consensus_length): + if i in aligned_portion: + filled_aligned[i] = aligned_portion[i] + else: + available_indices = list(aligned_portion.keys()) + if available_indices: + closest_idx = min( + available_indices, key=lambda x: abs(x - i) + ) + filled_aligned[i] = aligned_portion[closest_idx] + else: + consensus_row = consensus_df.iloc[i] + row_dict = {"row": consensus_row} + if feature_columns: + row_dict.update( + {col: consensus_row[col] for col in feature_columns} + ) + filled_aligned[i] = row_dict + + # Convert to arrays/lists for features + aligned_data = {"length": consensus_length, "mapping": filled_aligned} + if feature_columns: + aligned_data["features"] = {} + for col in feature_columns: + aligned_data["features"][col] = np.array( + [filled_aligned[i][col] for i in range(consensus_length)] + ) + + # Process unaligned BEFORE portion + unaligned_before_data = { + "length": len(unaligned_before_rows), + "rows": unaligned_before_rows, + } + if feature_columns and not unaligned_before_rows.empty: + unaligned_before_rows = unaligned_before_rows.sort_values("t") + unaligned_before_data["features"] = {} + for col in feature_columns: + unaligned_before_data["features"][col] = unaligned_before_rows[ + col + ].values + + # Process unaligned AFTER portion + unaligned_after_data = { + "length": len(unaligned_after_rows), + "rows": unaligned_after_rows, + } + if feature_columns and not unaligned_after_rows.empty: + unaligned_after_rows = unaligned_after_rows.sort_values("t") + unaligned_after_data["features"] = {} + for col in feature_columns: + unaligned_after_data["features"][col] = unaligned_after_rows[ + col + ].values + + concatenated_sequences[lineage_id] = { + "unaligned_before_data": unaligned_before_data, + "aligned_data": aligned_data, + "unaligned_after_data": unaligned_after_data, + "metadata": { + "fov_name": lineage_df["fov_name"].iloc[0], + "track_ids": list(lineage_df["track_id"].unique()), + "dtw_distance": ( + lineage_df[distance_col].iloc[0] + if not pd.isna(lineage_df[distance_col].iloc[0]) + else np.nan + ), + "lineage_id": lineage_id, + "consensus_length": consensus_length, + }, + } + + return concatenated_sequences + + def add_warped_coordinates( + self, + df: pd.DataFrame, + alignment_name: str = "cell_division", + ) -> pd.DataFrame: + """ + Add warped time coordinates to dataframe for synchronized pseudotime analysis. + + This method computes synchronized pseudotime coordinates (warped_t) for all cells + aligned with the specified alignment type. Warped time synchronizes all cells so + their aligned regions start at the same timepoint, enabling proper aggregation + across cells in biological time rather than experimental time. + + Parameters + ---------- + df : pd.DataFrame + Master dataframe with DTW alignment columns (dtw_{alignment}_aligned, + dtw_{alignment}_consensus_mapping, etc.) + alignment_name : str + Which alignment to compute warped coordinates for + (e.g., "infection_state", "cell_division", "apoptosis") + + Returns + ------- + pd.DataFrame + Augmented dataframe with new columns: + - dtw_{alignment_name}_warped_t : int + Synchronized pseudotime coordinate (0 to total_warped_length-1). + All cells' aligned regions start at max_unaligned_before. + - dtw_{alignment_name}_warped_offset : int + Per-lineage padding offset to synchronize sequences. + + Notes + ----- + Warped time structure: + - [0 to max_unaligned_before-1]: Before aligned region (padded/variable per cell) + - [max_unaligned_before to max_unaligned_before+consensus_length-1]: ALIGNED region (synchronized) + - [max_unaligned_before+consensus_length to end]: After aligned region (padded/variable per cell) + + The metadata is stored in self.consensus_data['warped_metadata'] containing: + - max_unaligned_before, max_unaligned_after, consensus_length, total_warped_length + + Examples + -------- + >>> # Add warped coordinates for infection alignment + >>> master_df = cytodtw.add_warped_coordinates( + ... master_df, + ... alignment_name="infection_state" + ... ) + >>> # Now aggregate in warped time + >>> warped_agg = master_df.groupby('dtw_infection_state_warped_t').median() + """ + import numpy as np + import pandas as pd + + aligned_col = f"dtw_{alignment_name}_aligned" + mapping_col = f"dtw_{alignment_name}_consensus_mapping" + + # Check if this alignment exists + if aligned_col not in df.columns: + raise ValueError( + f"Alignment '{alignment_name}' not found in dataframe. " + f"Expected column '{aligned_col}' not present." + ) + + # Get consensus length + consensus_df = df[df["lineage_id"] == -1].copy() + if consensus_df.empty: + raise ValueError( + "No consensus pattern found in dataframe (lineage_id == -1)" + ) + consensus_length = len(consensus_df) + + # Pass 1: Compute global offsets + _logger.info(f"Computing warped coordinates for alignment: {alignment_name}") + _logger.info("Pass 1: Computing global offsets across all lineages...") + + max_unaligned_before = 0 + max_unaligned_after = 0 + lineage_info = {} + + # Get unique lineages (excluding consensus) + lineages = df[df["lineage_id"] != -1]["lineage_id"].unique() + + for lineage_id in lineages: + lineage_df = df[df["lineage_id"] == lineage_id].copy().sort_values("t") + if lineage_df.empty: + continue + + # Get aligned rows + aligned_rows = lineage_df[lineage_df[aligned_col].fillna(False)].copy() + + if not aligned_rows.empty: + # Find unaligned before and after regions + min_aligned_t = aligned_rows["t"].min() + max_aligned_t = aligned_rows["t"].max() + + unaligned_before = lineage_df[ + (~lineage_df[aligned_col].fillna(False)) + & (lineage_df["t"] < min_aligned_t) + ] + unaligned_after = lineage_df[ + (~lineage_df[aligned_col].fillna(False)) + & (lineage_df["t"] > max_aligned_t) + ] + + unaligned_before_length = len(unaligned_before) + unaligned_after_length = len(unaligned_after) + + # Update global maxes + max_unaligned_before = max( + max_unaligned_before, unaligned_before_length + ) + max_unaligned_after = max(max_unaligned_after, unaligned_after_length) + + # Store lineage info for pass 2 + lineage_info[lineage_id] = { + "unaligned_before_length": unaligned_before_length, + "unaligned_after_length": unaligned_after_length, + "min_aligned_t": min_aligned_t, + "max_aligned_t": max_aligned_t, + } + + total_warped_length = ( + max_unaligned_before + consensus_length + max_unaligned_after + ) + + _logger.info(f" Max unaligned before: {max_unaligned_before}") + _logger.info(f" Consensus length: {consensus_length}") + _logger.info(f" Max unaligned after: {max_unaligned_after}") + _logger.info(f" Total warped length: {total_warped_length}") + + # Store metadata in consensus_data + if self.consensus_data is None: + self.consensus_data = {} + self.consensus_data["warped_metadata"] = { + "max_unaligned_before": max_unaligned_before, + "max_unaligned_after": max_unaligned_after, + "consensus_length": consensus_length, + "total_warped_length": total_warped_length, + } + + # Pass 2: Assign warped_t to each row + _logger.info("Pass 2: Assigning warped_t coordinates to all rows...") + + warped_t_col = f"dtw_{alignment_name}_warped_t" + warped_offset_col = f"dtw_{alignment_name}_warped_offset" + + # Initialize new columns + df[warped_t_col] = np.nan + df[warped_offset_col] = np.nan + + for lineage_id in lineages: + if lineage_id not in lineage_info: + continue + + lineage_mask = df["lineage_id"] == lineage_id + lineage_df = df[lineage_mask].copy().sort_values("t") + + info = lineage_info[lineage_id] + warped_offset = max_unaligned_before - info["unaligned_before_length"] + + # Assign warped_t for each row + for idx, row in lineage_df.iterrows(): + t = row["t"] + is_aligned = row[aligned_col] + + if pd.isna(is_aligned): + is_aligned = False + + if is_aligned: + # Aligned region: use consensus mapping + consensus_idx = row[mapping_col] + if not pd.isna(consensus_idx): + warped_t = max_unaligned_before + int(consensus_idx) + else: + warped_t = np.nan + elif t < info["min_aligned_t"]: + # Unaligned before + before_rows = lineage_df[ + lineage_df["t"] < info["min_aligned_t"] + ].sort_values("t") + position_in_before = before_rows.index.get_loc(idx) + warped_t = warped_offset + position_in_before + elif t > info["max_aligned_t"]: + # Unaligned after + after_rows = lineage_df[ + lineage_df["t"] > info["max_aligned_t"] + ].sort_values("t") + position_in_after = after_rows.index.get_loc(idx) + warped_t = ( + max_unaligned_before + consensus_length + position_in_after + ) + else: + warped_t = np.nan + + # Assign to dataframe + df.loc[idx, warped_t_col] = warped_t + df.loc[idx, warped_offset_col] = warped_offset + + # Also add warped_t for consensus rows (maps to itself in aligned region) + consensus_mask = df["lineage_id"] == -1 + for idx, row in df[consensus_mask].iterrows(): + consensus_idx = ( + int(row[mapping_col]) if not pd.isna(row[mapping_col]) else row.name + ) + df.loc[idx, warped_t_col] = max_unaligned_before + consensus_idx + df.loc[idx, warped_offset_col] = 0 + + _logger.info( + f"Added warped coordinates: {warped_t_col} and {warped_offset_col}" + ) + _logger.info( + f"Warped time range: [0, {total_warped_length - 1}] ({total_warped_length} timepoints)" + ) + + return df + + def plot_global_trends( + self, + df: pd.DataFrame, + alignment_name: str = "cell_division", + feature_columns: list[str] = None, + max_lineages: int = None, + plot_type: str = "mean_bands", + figsize: tuple = (15, 12), + colors: tuple = ("#1f77b4", "#ff7f0e"), + cmap: str = "RdBu_r", + remove_outliers: bool = False, + outlier_percentile: tuple = (1, 99), + ): + """ + Plot global trends across all aligned lineages. + + Parameters + ---------- + df : pd.DataFrame + Enhanced DataFrame with alignment information + alignment_name : str + Name of alignment type + feature_columns : list[str], optional + Feature columns to plot. If None, uses PC1, PC2, PC3 + max_lineages : int, optional + Maximum number of lineages to include + plot_type : str + Type of plot: "mean_bands", "heatmap", "quantile_bands", or "individual_with_mean" + figsize : tuple + Figure size + colors : tuple + Colors for (aligned, unaligned) portions in line plots + cmap : str + Colormap for heatmap plot + remove_outliers : bool + Whether to clip outlier values for better visualization + outlier_percentile : tuple + (lower, upper) percentile bounds for clipping (default: 1st-99th percentile) + Returns + ------- + matplotlib.figure.Figure + The created figure + """ + import matplotlib.pyplot as plt + + if feature_columns is None: + feature_columns = ["PC1", "PC2", "PC3"] + + # Get concatenated sequences + concatenated_seqs = self.get_concatenated_sequences( + df=df, + alignment_name=alignment_name, + feature_columns=feature_columns, + max_lineages=max_lineages, + ) + + if not concatenated_seqs: + _logger.error("No concatenated sequences found") + return None + + consensus_df = df[df["lineage_id"] == -1].sort_values("t").copy() + consensus_length = len(consensus_df) + + n_features = len(feature_columns) + fig, axes = plt.subplots(n_features, 1, figsize=figsize) + if n_features == 1: + axes = [axes] + + for feat_idx, feat_col in enumerate(feature_columns): + ax = axes[feat_idx] + + all_unaligned_before = [] + all_aligned = [] + all_unaligned_after = [] + + for lineage_id, seq_data in concatenated_seqs.items(): + unaligned_before_features = seq_data["unaligned_before_data"].get( + "features", {} + ) + aligned_features = seq_data["aligned_data"]["features"] + unaligned_after_features = seq_data["unaligned_after_data"].get( + "features", {} + ) + + if feat_col in unaligned_before_features: + all_unaligned_before.append(unaligned_before_features[feat_col]) + all_aligned.append(aligned_features[feat_col]) + if feat_col in unaligned_after_features: + all_unaligned_after.append(unaligned_after_features[feat_col]) + + aligned_array = np.array(all_aligned) + max_unaligned_before_len = ( + max([len(u) for u in all_unaligned_before]) + if all_unaligned_before + else 0 + ) + max_unaligned_after_len = ( + max([len(u) for u in all_unaligned_after]) if all_unaligned_after else 0 + ) + + if remove_outliers: + all_values = [] + all_values.extend(consensus_df[feat_col].values) + if all_unaligned_before: + for u in all_unaligned_before: + all_values.extend(u) + all_values.extend(aligned_array.flatten()) + if all_unaligned_after: + for u in all_unaligned_after: + all_values.extend(u) + + all_values = np.array(all_values) + all_values = all_values[~np.isnan(all_values)] + + if len(all_values) > 0: + lower_bound = np.percentile(all_values, outlier_percentile[0]) + upper_bound = np.percentile(all_values, outlier_percentile[1]) + _logger.info( + f"{feat_col}: setting y-limits to [{lower_bound:.3f}, {upper_bound:.3f}]" + ) + + # Pad unaligned arrays to same length + # Pad "before" arrays: prepend NaN to align shorter sequences to the right + if max_unaligned_before_len > 0: + unaligned_before_array = np.full( + (len(concatenated_seqs), max_unaligned_before_len), np.nan + ) + for i, u in enumerate(all_unaligned_before): + padding_needed = max_unaligned_before_len - len(u) + # Right-align by prepending NaN padding + unaligned_before_array[i, padding_needed:] = u + else: + unaligned_before_array = np.array([]).reshape(0, 0) + + # Pad "after" arrays: append NaN (left-aligned) + if max_unaligned_after_len > 0: + unaligned_after_array = np.full( + (len(concatenated_seqs), max_unaligned_after_len), np.nan + ) + for i, u in enumerate(all_unaligned_after): + unaligned_after_array[i, : len(u)] = u + else: + unaligned_after_array = np.array([]).reshape(0, 0) + + if plot_type == "mean_bands": + # Plot mean ± SEM for all three segments + time_offset = 0 + + # 1. Plot unaligned BEFORE + if unaligned_before_array.size > 0: + unaligned_before_mean = np.nanmean(unaligned_before_array, axis=0) + unaligned_before_sem = np.nanstd( + unaligned_before_array, axis=0 + ) / np.sqrt(np.sum(~np.isnan(unaligned_before_array), axis=0)) + unaligned_before_time = np.arange( + time_offset, time_offset + max_unaligned_before_len + ) + + ax.plot( + unaligned_before_time, + unaligned_before_mean, + "--", + color=colors[1], + linewidth=2, + label="Unaligned before", + alpha=0.7, + zorder=2, + ) + ax.fill_between( + unaligned_before_time, + unaligned_before_mean - unaligned_before_sem, + unaligned_before_mean + unaligned_before_sem, + alpha=0.2, + color=colors[1], + zorder=1, + ) + time_offset += max_unaligned_before_len + + # 2. Plot ALIGNED (highlight with thick lines) + aligned_mean = np.nanmean(aligned_array, axis=0) + aligned_sem = np.nanstd(aligned_array, axis=0) / np.sqrt( + np.sum(~np.isnan(aligned_array), axis=0) + ) + aligned_time = np.arange(time_offset, time_offset + consensus_length) + + ax.plot( + aligned_time, + aligned_mean, + "-", + color=colors[0], + linewidth=3, + label="Aligned mean", + zorder=4, + ) + ax.fill_between( + aligned_time, + aligned_mean - aligned_sem, + aligned_mean + aligned_sem, + alpha=0.3, + color=colors[0], + zorder=3, + ) + time_offset += consensus_length + + # 3. Plot unaligned AFTER + if unaligned_after_array.size > 0: + unaligned_after_mean = np.nanmean(unaligned_after_array, axis=0) + unaligned_after_sem = np.nanstd( + unaligned_after_array, axis=0 + ) / np.sqrt(np.sum(~np.isnan(unaligned_after_array), axis=0)) + unaligned_after_time = np.arange( + time_offset, time_offset + max_unaligned_after_len + ) + + ax.plot( + unaligned_after_time, + unaligned_after_mean, + "--", + color=colors[1], + linewidth=2, + label="Unaligned after", + alpha=0.7, + zorder=2, + ) + ax.fill_between( + unaligned_after_time, + unaligned_after_mean - unaligned_after_sem, + unaligned_after_mean + unaligned_after_sem, + alpha=0.2, + color=colors[1], + zorder=1, + ) + + elif plot_type == "quantile_bands": + # Plot median + quartiles for all three segments + time_offset = 0 + + # 1. Plot unaligned BEFORE + if unaligned_before_array.size > 0: + unaligned_before_median = np.nanmedian( + unaligned_before_array, axis=0 + ) + unaligned_before_q25 = np.nanpercentile( + unaligned_before_array, 25, axis=0 + ) + unaligned_before_q75 = np.nanpercentile( + unaligned_before_array, 75, axis=0 + ) + unaligned_before_time = np.arange( + time_offset, time_offset + max_unaligned_before_len + ) + + ax.plot( + unaligned_before_time, + unaligned_before_median, + "--", + color=colors[1], + linewidth=2, + label="Unaligned before", + alpha=0.7, + zorder=2, + ) + ax.fill_between( + unaligned_before_time, + unaligned_before_q25, + unaligned_before_q75, + alpha=0.2, + color=colors[1], + zorder=1, + ) + time_offset += max_unaligned_before_len + + # 2. Plot ALIGNED (highlight with thick lines) + aligned_median = np.nanmedian(aligned_array, axis=0) + aligned_q25 = np.nanpercentile(aligned_array, 25, axis=0) + aligned_q75 = np.nanpercentile(aligned_array, 75, axis=0) + aligned_time = np.arange(time_offset, time_offset + consensus_length) + + ax.plot( + aligned_time, + aligned_median, + "-", + color=colors[0], + linewidth=3, + label="Aligned median", + zorder=4, + ) + ax.fill_between( + aligned_time, + aligned_q25, + aligned_q75, + alpha=0.3, + color=colors[0], + zorder=3, + ) + time_offset += consensus_length + + # 3. Plot unaligned AFTER + if unaligned_after_array.size > 0: + unaligned_after_median = np.nanmedian(unaligned_after_array, axis=0) + unaligned_after_q25 = np.nanpercentile( + unaligned_after_array, 25, axis=0 + ) + unaligned_after_q75 = np.nanpercentile( + unaligned_after_array, 75, axis=0 + ) + unaligned_after_time = np.arange( + time_offset, time_offset + max_unaligned_after_len + ) + + ax.plot( + unaligned_after_time, + unaligned_after_median, + "--", + color=colors[1], + linewidth=2, + label="Unaligned after", + alpha=0.7, + zorder=2, + ) + ax.fill_between( + unaligned_after_time, + unaligned_after_q25, + unaligned_after_q75, + alpha=0.2, + color=colors[1], + zorder=1, + ) + + elif plot_type == "heatmap": + # Stack all data for heatmap [before, aligned, after] + # Use padded arrays so boundaries align across all lineages + full_data = [] + for i in range(len(all_aligned)): + parts = [] + # Add padded before segment + if unaligned_before_array.size > 0: + parts.append(unaligned_before_array[i]) + # Add aligned + parts.append(all_aligned[i]) + # Add padded after segment + if unaligned_after_array.size > 0: + parts.append(unaligned_after_array[i]) + + full_seq = np.concatenate(parts) if len(parts) > 1 else parts[0] + full_data.append(full_seq) + + # Stack into 2D array (all sequences should have same length now due to padding) + heatmap_data = np.array(full_data) + + im = ax.imshow( + heatmap_data, aspect="auto", cmap=cmap, interpolation="nearest" + ) + # Mark aligned region boundaries with gray lines + if max_unaligned_before_len > 0: + ax.axvline( + max_unaligned_before_len - 0.5, + color="gray", + linewidth=2, + linestyle=":", + alpha=0.7, + label="Aligned region start", + ) + if max_unaligned_after_len > 0: + ax.axvline( + max_unaligned_before_len + consensus_length - 0.5, + color="gray", + linewidth=2, + linestyle=":", + alpha=0.7, + label="Aligned region end", + ) + plt.colorbar(im, ax=ax, label=feat_col) + ax.set_ylabel("Lineage") + + elif plot_type == "individual_with_mean": + # Plot all individual traces + mean overlay + time_offset = 0 + + # 1. Plot unaligned BEFORE individual traces + if max_unaligned_before_len > 0: + for i in range(len(all_unaligned_before)): + unaligned_before_time = np.arange( + time_offset, time_offset + len(all_unaligned_before[i]) + ) + ax.plot( + unaligned_before_time, + all_unaligned_before[i], + "--", + color=colors[1], + alpha=0.2, + linewidth=1, + zorder=1, + ) + time_offset += max_unaligned_before_len + + # 2. Plot ALIGNED individual traces + aligned_time = np.arange(time_offset, time_offset + consensus_length) + for i in range(len(all_aligned)): + ax.plot( + aligned_time, + all_aligned[i], + "-", + color=colors[0], + alpha=0.2, + linewidth=1, + zorder=1, + ) + time_offset += consensus_length + + # 3. Plot unaligned AFTER individual traces + if max_unaligned_after_len > 0: + for i in range(len(all_unaligned_after)): + unaligned_after_time = np.arange( + time_offset, time_offset + len(all_unaligned_after[i]) + ) + ax.plot( + unaligned_after_time, + all_unaligned_after[i], + "--", + color=colors[1], + alpha=0.2, + linewidth=1, + zorder=1, + ) + + # Overlay mean traces + time_offset = 0 + + # 1. Plot unaligned BEFORE mean + if unaligned_before_array.size > 0: + unaligned_before_mean = np.nanmean(unaligned_before_array, axis=0) + unaligned_before_time = np.arange( + time_offset, time_offset + max_unaligned_before_len + ) + ax.plot( + unaligned_before_time, + unaligned_before_mean, + "--", + color="black", + linewidth=3, + label="Mean (before)", + zorder=3, + ) + time_offset += max_unaligned_before_len + + # 2. Plot ALIGNED mean + aligned_mean = np.nanmean(aligned_array, axis=0) + aligned_time = np.arange(time_offset, time_offset + consensus_length) + ax.plot( + aligned_time, + aligned_mean, + "-", + color="black", + linewidth=3, + label="Mean (aligned)", + zorder=3, + ) + time_offset += consensus_length + + # 3. Plot unaligned AFTER mean + if unaligned_after_array.size > 0: + unaligned_after_mean = np.nanmean(unaligned_after_array, axis=0) + unaligned_after_time = np.arange( + time_offset, time_offset + max_unaligned_after_len + ) + ax.plot( + unaligned_after_time, + unaligned_after_mean, + "--", + color="black", + linewidth=3, + label="Mean (after)", + zorder=3, + ) + + # Mark infection state if available in consensus + if self.consensus_data is not None: + consensus_annotations = self.consensus_data.get("annotations", None) + if consensus_annotations and "infected" in consensus_annotations: + infection_idx = consensus_annotations.index("infected") + if plot_type == "heatmap": + # Offset by max_unaligned_before_len to align with padded structure + infection_t = max_unaligned_before_len + infection_idx + ax.axvline( + infection_t - 0.5, # -0.5 for proper pixel alignment + color="orange", + linewidth=2.5, + linestyle="-", + alpha=0.9, + label="Infection", + ) + else: + # For line plots, use time offset + infection_t = max_unaligned_before_len + infection_idx + ax.axvline( + infection_t, + color="orange", + alpha=0.7, + linestyle="--", + linewidth=2.5, + label="Infection", + ) + + # Mark alignment boundary + if plot_type != "heatmap": + ax.axvline( + consensus_length, + color="gray", + alpha=0.5, + linestyle=":", + linewidth=2, + ) + ax.text( + consensus_length, + ax.get_ylim()[1], + " Alignment end", + rotation=90, + verticalalignment="top", + fontsize=9, + alpha=0.7, + ) + + # Plot consensus reference + if plot_type != "heatmap": + consensus_values = consensus_df[feat_col].values.copy() + consensus_time = np.arange(len(consensus_values)) + ax.plot( + consensus_time, + consensus_values, + "o-", + color="black", + linewidth=2, + markersize=6, + label="Consensus", + alpha=0.6, + zorder=4, + ) + + # Set y-axis limits based on outlier bounds if requested + if remove_outliers and plot_type != "heatmap": + ax.set_ylim(lower_bound, upper_bound) + + ax.set_ylabel(feat_col) + if feat_idx == 0: + ax.legend(loc="best") + if feat_idx == n_features - 1: + ax.set_xlabel("Time: [Aligned] | [Unaligned continuation]") + ax.set_title( + f"{feat_col} - {plot_type} (n={len(concatenated_seqs)} lineages)" + ) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig + + def plot_sample_patterns( + self, + annotated_samples: list[AnnotatedSample], + reference_type: str = "features", + n_pca_components: int = 3, + figsize: tuple = None, + ): + """ + Plot PCA projections of reference patterns with annotations before consensus creation. + + This method provides quality control visualization of the reference patterns + before creating a consensus. It extracts patterns from annotated samples, + computes PCA for dimensionality reduction, and plots each pattern's trajectory + in PC space with annotation markers (e.g., mitosis, infection). + + Parameters + ---------- + annotated_samples : list[AnnotatedSample] + List of annotated examples to extract and plot patterns from. Each should contain: + - fov_name: str - FOV identifier + - track_id: int | list[int] - Track ID(s) + - timepoints: tuple[int, int] - (start, end) timepoints + - annotations: list - Optional annotations for each timepoint + reference_type : str + Type of embedding to use ("features" for X, or obsm key like "X_PCA") + n_pca_components : int + Number of PCA components to compute and plot (default 3) + figsize : tuple, optional + Figure size (width, height). If None, auto-computed based on number of patterns. + + Returns + ------- + matplotlib.figure.Figure + The created matplotlib figure + + Examples + -------- + >>> cytodtw = CytoDtw(adata) + >>> samples = [ + ... {'fov_name': 'A/2/001', 'track_id': [10], + ... 'timepoints': (0, 30), 'annotations': ['G1']*15 + ['mitosis'] + ['G1']*14} + ... ] + >>> fig = cytodtw.plot_sample_patterns(samples) + """ + import matplotlib.pyplot as plt + + # Extract patterns from annotated samples + patterns = [] + pattern_info = [] + + for i, example in enumerate(annotated_samples): + pattern = self.get_reference_pattern( + fov_name=example["fov_name"], + track_id=example["track_id"], + timepoints=example["timepoints"], + reference_type=reference_type, + ) + patterns.append(pattern) + pattern_info.append( + { + "index": i, + "fov_name": example["fov_name"], + "track_id": example["track_id"], + "timepoints": example["timepoints"], + "annotations": example.get("annotations", None), + } + ) + + # Concatenate all patterns and fit PCA + all_patterns_concat = np.vstack(patterns) + scaler = StandardScaler() + scaled_patterns = scaler.fit_transform(all_patterns_concat) + pca = PCA(n_components=n_pca_components) + pca.fit(scaled_patterns) + + # Create figure + n_patterns = len(patterns) + if figsize is None: + figsize = (12, 3 * n_patterns) + + fig, axes = plt.subplots(n_patterns, n_pca_components, figsize=figsize) + if n_patterns == 1: + axes = axes.reshape(1, -1) + + # Plot each pattern + for i, (pattern, info) in enumerate(zip(patterns, pattern_info)): + scaled_pattern = scaler.transform(pattern) + pc_pattern = pca.transform(scaled_pattern) + time_axis = np.arange(len(pattern)) + + for pc_idx in range(n_pca_components): + ax = axes[i, pc_idx] + + ax.plot( + time_axis, + pc_pattern[:, pc_idx], + "o-", + color="blue", + linewidth=2, + markersize=4, + ) + + # Add annotation markers + annotations = info.get("annotations") + if annotations: + for t, annotation in enumerate(annotations): + if annotation == "mitosis": + ax.axvline( + t, + color="orange", + alpha=0.7, + linestyle="--", + linewidth=2, + ) + ax.scatter( + t, pc_pattern[t, pc_idx], c="orange", s=50, zorder=5 + ) + elif annotation == "infected": + ax.axvline( + t, color="red", alpha=0.5, linestyle="--", linewidth=1 + ) + ax.scatter( + t, pc_pattern[t, pc_idx], c="red", s=30, zorder=5 + ) + break # Only mark the first infection timepoint + + ax.set_xlabel("Time") + ax.set_ylabel(f"PC{pc_idx + 1}") + ax.set_title( + f"Pattern {i + 1}: FOV {info['fov_name']}, Tracks {info['track_id']}\nPC{pc_idx + 1} over time" + ) + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig + + def plot_consensus_validation( + self, + annotated_samples: list[AnnotatedSample], + reference_type: str = "features", + metric: str = "cosine", + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, + n_pca_components: int = 3, + figsize: tuple = (18, 5), + ): + """ + Plot DTW-aligned reference patterns overlaid with consensus for quality validation. + + This method validates consensus quality by aligning all reference patterns to the + consensus using DTW, then visualizing them together in PCA space. This helps verify + that the consensus captures the common structure across all reference patterns. + + Parameters + ---------- + annotated_samples : list[AnnotatedSample] + List of annotated examples that were used to create the consensus + reference_type : str + Type of embedding to use ("features" for X, or obsm key like "X_PCA") + metric : str + Distance metric for DTW alignment (default "cosine") + constraint_type : str + DTW constraint type ("unconstrained", "sakoe_chiba") + band_width_ratio : float + DTW band width ratio for Sakoe-Chiba constraint (default 0.0) + n_pca_components : int + Number of PCA components to compute and plot (default 3) + figsize : tuple + Figure size (width, height) + + Returns + ------- + matplotlib.figure.Figure + The created matplotlib figure + + Raises + ------ + ValueError + If consensus_data has not been set in the instance + + Examples + -------- + >>> cytodtw = CytoDtw(adata) + >>> samples = [...] + >>> consensus = cytodtw.create_consensus_reference_pattern(samples) + >>> fig = cytodtw.plot_consensus_validation(samples) + """ + import matplotlib.pyplot as plt + + if self.consensus_data is None: + raise ValueError( + "Consensus data not found. Please create or load consensus first using " + "create_consensus_reference_pattern() or load_consensus()." + ) + + consensus_lineage = self.consensus_data["pattern"] + consensus_annotations = self.consensus_data.get("annotations", None) + + # Extract and align patterns to consensus + aligned_patterns_list = [] + aligned_annotations_list = [] + all_patterns_for_pca = [] + + for i, example in enumerate(annotated_samples): + # Extract pattern + pattern = self.get_reference_pattern( + fov_name=example["fov_name"], + track_id=example["track_id"], + timepoints=example["timepoints"], + reference_type=reference_type, + ) + + # Align to consensus + if len(pattern) == len(consensus_lineage): + # Already same length, likely the reference pattern + aligned_patterns_list.append(pattern) + aligned_annotations_list.append(example.get("annotations", None)) + else: + # Align to consensus using DTW + alignment_result = align_embedding_patterns( + query_pattern=pattern, + reference_pattern=consensus_lineage, + metric=metric, + query_annotations=example.get("annotations", None), + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + aligned_patterns_list.append(alignment_result["pattern"]) + aligned_annotations_list.append( + alignment_result.get("annotations", None) + ) + + all_patterns_for_pca.append(aligned_patterns_list[-1]) + + # Add consensus to patterns for PCA fitting + all_patterns_for_pca.append(consensus_lineage) + all_patterns_concat = np.vstack(all_patterns_for_pca) + + # Fit PCA on all patterns (aligned references + consensus) + scaler = StandardScaler() + scaled_patterns = scaler.fit_transform(all_patterns_concat) + pca = PCA(n_components=n_pca_components) + pca.fit(scaled_patterns) + + # Create figure + fig, axes = plt.subplots(1, n_pca_components, figsize=figsize) + if n_pca_components == 1: + axes = [axes] + + # Plot each aligned pattern + for pc_idx in range(n_pca_components): + ax = axes[pc_idx] + + # Transform each aligned pattern to PC space and plot + for i, pattern in enumerate(aligned_patterns_list): + scaled_ref = scaler.transform(pattern) + pc_ref = pca.transform(scaled_ref) + + time_axis = np.arange(len(pc_ref)) + ax.plot( + time_axis, + pc_ref[:, pc_idx], + "o-", + label=f"Ref {i + 1}", + alpha=0.7, + linewidth=2, + markersize=4, + ) + + # Overlay consensus pattern + scaled_consensus = scaler.transform(consensus_lineage) + pc_consensus = pca.transform(scaled_consensus) + time_axis = np.arange(len(pc_consensus)) + ax.plot( + time_axis, + pc_consensus[:, pc_idx], + "s-", + color="black", + linewidth=3, + markersize=6, + label="Consensus", + zorder=10, + ) + + # Mark consensus infection timepoint with a thicker, more prominent line + if consensus_annotations and "infected" in consensus_annotations: + consensus_infection_t = consensus_annotations.index("infected") + ax.axvline( + consensus_infection_t, + color="orange", + alpha=0.9, + linestyle="--", + linewidth=2.5, + label="Infection", + ) + + ax.set_xlabel("Aligned Time") + ax.set_ylabel(f"PC{pc_idx + 1}") + ax.set_title(f"PC{pc_idx + 1}: All DTW-Aligned References + Consensus") + ax.grid(True, alpha=0.3) + ax.legend() + + plt.suptitle( + "Consensus Validation: DTW-Aligned References + Computed Consensus", + fontsize=14, + ) + plt.tight_layout() + return fig + + def plot_individual_lineages( + self, + df: pd.DataFrame, + alignment_name: str = "cell_division", + feature_columns: list[str] = None, + max_lineages: int = 5, + y_offset_step: float = 0.0, + figsize: tuple = (15, 12), + aligned_linewidth: float = 2.5, + unaligned_linewidth: float = 1.0, + aligned_markersize: float = 4.0, + unaligned_markersize: float = 2.0, + remove_outliers: bool = False, + outlier_percentile: tuple = (1, 99), + ): + """ + Plot individual lineages with y-offsets (waterfall plot). + + Each lineage is shown as a separate trace with vertical offset for clarity. + The aligned portion is highlighted with thicker lines/markers. + + Parameters + ---------- + df : pd.DataFrame + Enhanced DataFrame with alignment information + alignment_name : str + Name of alignment type + feature_columns : list[str], optional + Feature columns to plot. If None, uses PC1, PC2, PC3 + max_lineages : int + Maximum number of lineages to display + y_offset_step : float + Vertical separation between lineages + figsize : tuple + Figure size + aligned_linewidth : float + Line width for aligned portions + unaligned_linewidth : float + Line width for unaligned portions + aligned_markersize : float + Marker size for aligned portions + unaligned_markersize : float + Marker size for unaligned portions + remove_outliers : bool + Whether to clip outlier values for better visualization + outlier_percentile : tuple + (lower, upper) percentile bounds for clipping + + Returns + ------- + matplotlib.figure.Figure + The created figure + """ + import matplotlib.pyplot as plt + + if feature_columns is None: + feature_columns = ["PC1", "PC2", "PC3"] + + # Get concatenated sequences + concatenated_seqs = self.get_concatenated_sequences( + df=df, + alignment_name=alignment_name, + feature_columns=feature_columns, + max_lineages=max_lineages, + ) + + if not concatenated_seqs: + _logger.error("No concatenated sequences found") + return None + + # Get consensus for reference + consensus_df = df[df["lineage_id"] == -1].sort_values("t").copy() + if consensus_df.empty: + _logger.error("No consensus found in DataFrame") + return None + + # Find maximum "before" length across all lineages to align them + max_unaligned_before_length = max( + [ + seq_data["unaligned_before_data"]["length"] + for seq_data in concatenated_seqs.values() + ], + default=0, + ) + + # Prepare concatenated lineages data with padding + concatenated_lineages = {} + for lineage_id, seq_data in concatenated_seqs.items(): + unaligned_before_features = seq_data["unaligned_before_data"].get( + "features", {} + ) + aligned_features = seq_data["aligned_data"]["features"] + unaligned_after_features = seq_data["unaligned_after_data"].get( + "features", {} + ) + + unaligned_before_length = seq_data["unaligned_before_data"]["length"] + padding_needed = max_unaligned_before_length - unaligned_before_length + + concatenated_arrays = {} + for col in feature_columns: + # Pad "before" segment to max length for alignment + parts = [] + if ( + col in unaligned_before_features + and len(unaligned_before_features[col]) > 0 + ): + before_vals = unaligned_before_features[col] + if padding_needed > 0: + # Prepend NaN padding + before_padded = np.concatenate( + [np.full(padding_needed, np.nan), before_vals] + ) + else: + before_padded = before_vals + parts.append(before_padded) + else: + # No before data - pad with full max length + parts.append(np.full(max_unaligned_before_length, np.nan)) + + # Add aligned and after portions + parts.append(aligned_features[col]) + if ( + col in unaligned_after_features + and len(unaligned_after_features[col]) > 0 + ): + parts.append(unaligned_after_features[col]) + + concatenated_arrays[col] = np.concatenate(parts) + + concatenated_lineages[lineage_id] = { + "concatenated": concatenated_arrays, + "unaligned_before_length": seq_data["unaligned_before_data"]["length"], + "aligned_length": seq_data["aligned_data"]["length"], + "unaligned_after_length": seq_data["unaligned_after_data"]["length"], + "dtw_distance": seq_data["metadata"]["dtw_distance"], + "fov_name": seq_data["metadata"]["fov_name"], + "track_ids": seq_data["metadata"]["track_ids"], + } + + # Find maximum concatenated length across ALL features and lineages + # This ensures all features have the same x-axis length + max_concat_length = 0 + for lineage_data in concatenated_lineages.values(): + for feat_array in lineage_data["concatenated"].values(): + max_concat_length = max(max_concat_length, len(feat_array)) + + # Pad all feature arrays to max_concat_length to ensure consistent plot lengths + for lineage_data in concatenated_lineages.values(): + for col in feature_columns: + current_array = lineage_data["concatenated"][col] + if len(current_array) < max_concat_length: + # Pad with NaN at the end + padding_length = max_concat_length - len(current_array) + padded_array = np.concatenate( + [current_array, np.full(padding_length, np.nan)] + ) + lineage_data["concatenated"][col] = padded_array + + # Compute outlier bounds per feature if requested + outlier_bounds = {} + if remove_outliers: + for feat_col in feature_columns: + all_values = [] + all_values.extend(consensus_df[feat_col].values) + for data in concatenated_lineages.values(): + all_values.extend(data["concatenated"][feat_col]) + + all_values = np.array(all_values) + all_values = all_values[~np.isnan(all_values)] + + if len(all_values) > 0: + lower_bound = np.percentile(all_values, outlier_percentile[0]) + upper_bound = np.percentile(all_values, outlier_percentile[1]) + outlier_bounds[feat_col] = (lower_bound, upper_bound) + _logger.info( + f"{feat_col}: removing outliers outside [{lower_bound:.3f}, {upper_bound:.3f}]" + ) + + n_features = len(feature_columns) + fig, axes = plt.subplots(n_features, 1, figsize=figsize) + if n_features == 1: + axes = [axes] + + cmap = plt.cm.get_cmap( + "tab10" + if len(concatenated_lineages) <= 10 + else "tab20" + if len(concatenated_lineages) <= 20 + else "hsv" + ) + colors = [ + cmap(i / max(len(concatenated_lineages), 1)) + for i in range(len(concatenated_lineages)) + ] + + # Create a single time axis for all features based on max_concat_length + # This ensures all feature plots show the same x-axis range + time_axis = np.arange(max_concat_length) + + for feat_idx, feat_col in enumerate(feature_columns): + ax = axes[feat_idx] + + # Plot consensus at offset where aligned regions are + consensus_values = consensus_df[feat_col].values.copy() + consensus_time = np.arange( + max_unaligned_before_length, + max_unaligned_before_length + len(consensus_values), + ) + ax.plot( + consensus_time, + consensus_values, + "o-", + color="black", + linewidth=4, + markersize=8, + label=f"Consensus ({alignment_name})", + alpha=0.9, + zorder=5, + ) + + for lineage_idx, (lineage_id, data) in enumerate( + concatenated_lineages.items() + ): + y_offset = -(lineage_idx + 1) * y_offset_step + color = colors[lineage_idx] + + concat_values = data["concatenated"][feat_col].copy() + y_offset + + track_id_str = ",".join(map(str, data["track_ids"])) + ax.plot( + time_axis, + concat_values, + ".-", + color=color, + linewidth=unaligned_linewidth, + markersize=unaligned_markersize, + alpha=0.8, + label=f"{data['fov_name']}, track={track_id_str} (d={data['dtw_distance']:.3f})", + ) + + # Overlay aligned portion with thick lines + # All aligned portions now start at max_unaligned_before_length due to padding + aligned_length = data["aligned_length"] + aligned_start = max_unaligned_before_length + aligned_end = aligned_start + aligned_length + + if aligned_length > 0: + aligned_time = time_axis[aligned_start:aligned_end] + aligned_vals = concat_values[aligned_start:aligned_end] + ax.plot( + aligned_time, + aligned_vals, + "o-", + color=color, + linewidth=aligned_linewidth, + markersize=aligned_markersize, + alpha=1.0, + zorder=3, + ) + + # Add vertical lines at aligned region boundaries + if max_unaligned_before_length > 0 and aligned_length > 0: + ax.axvline( + aligned_start - 0.5, + color=color, + linestyle=":", + alpha=0.3, + linewidth=1, + ) + if aligned_length > 0 and data["unaligned_after_length"] > 0: + ax.axvline( + aligned_end - 0.5, + color=color, + linestyle=":", + alpha=0.3, + linewidth=1, + ) + + # Mark infection state if available in consensus + if self.consensus_data is not None: + consensus_annotations = self.consensus_data.get("annotations", None) + if consensus_annotations and "infected" in consensus_annotations: + infection_idx = consensus_annotations.index("infected") + # Offset by max_unaligned_before_length to align with consensus position + infection_t = max_unaligned_before_length + infection_idx + ax.axvline( + infection_t, + color="orange", + alpha=0.7, + linestyle="--", + linewidth=2.5, + label="Infection", + zorder=6, + ) + + if remove_outliers and feat_col in outlier_bounds: + lower, upper = outlier_bounds[feat_col] + max_offset = -(len(concatenated_lineages)) * y_offset_step + ax.set_ylim(max_offset + lower - 1, upper + 1) + + ax.set_ylabel(feat_col) + ax.set_xlabel("Time: [before alignment] + [aligned] + [after alignment]") + ax.set_title( + f"{feat_col} - Individual lineages (n={len(concatenated_lineages)})" + ) + # Only add legend to first subplot - all subplots have same legend + if feat_idx == 0: + ax.legend( + loc="upper left", + bbox_to_anchor=(1.01, 1), + fontsize=8, + ncol=1, + frameon=True, + ) + ax.grid(True, alpha=0.3) + # Set explicit x-axis limits to ensure all feature subplots show the same range + ax.set_xlim(0, max_concat_length - 1) + + plt.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right + return fig + + +def get_aligned_image_sequences( + cytodtw_instance, + df: pd.DataFrame, + alignment_name: str, + image_loader_fn, + max_lineages: int = None, +) -> dict: + """ + Create concatenated [unaligned_before + aligned + unaligned_after] image sequences from alignment. + + This is a generic function that works with any dataset by accepting a custom + image loader function. + + Parameters + ---------- + cytodtw_instance : CytoDtw + CytoDtw instance with get_concatenated_sequences method + df : pd.DataFrame + Enhanced DataFrame with alignment information + alignment_name : str + Name of alignment to use (e.g., "cell_division", "infection_state") + image_loader_fn : callable + Function that takes (fov_name, track_ids) and returns a dict mapping + timepoint -> image_data. Signature: fn(fov_name, track_ids) -> {t: img_data} + max_lineages : int, optional + Maximum number of lineages to process + + Returns + ------- + dict + Dictionary mapping lineage_id to: + - 'concatenated_images': List of concatenated images (before + aligned + after) + - 'unaligned_before_length': Number of unaligned images before aligned region + - 'aligned_length': Number of DTW-aligned images + - 'unaligned_after_length': Number of unaligned images after aligned region + - 'unaligned_length': Total number of unaligned images (before + after, for backward compatibility) + - 'metadata': Lineage metadata (fov_name, track_ids, dtw_distance, etc.) + + Examples + -------- + >>> # Define custom image loader for your dataset + >>> def load_images(fov_name, track_ids): + ... # Your dataset-specific loading logic + ... images = dataset.get_images_for_tracks(fov_name, track_ids) + ... return {img['t']: img for img in images} + >>> + >>> # Get image sequences + >>> image_seqs = get_image_sequences_from_alignment( + ... cytodtw, alignment_df, "cell_division", + ... image_loader_fn=load_images, max_lineages=10 + ... ) + """ + aligned_col = f"dtw_{alignment_name}_aligned" + if aligned_col not in df.columns: + _logger.error(f"Alignment '{alignment_name}' not found in DataFrame") + return {} + + consensus_df = df[df["lineage_id"] == -1].sort_values("t").copy() + if consensus_df.empty: + _logger.error("No consensus found in DataFrame") + return {} + + concatenated_seqs = cytodtw_instance.get_concatenated_sequences( + df=df, + alignment_name=alignment_name, + feature_columns=None, + max_lineages=max_lineages, + ) + + if not concatenated_seqs: + _logger.error("No concatenated sequences found") + return {} + + concatenated_image_sequences = {} + + for lineage_id, seq_data in concatenated_seqs.items(): + fov_name = seq_data["metadata"]["fov_name"] + track_ids = seq_data["metadata"]["track_ids"] + + try: + time_to_image = image_loader_fn(fov_name, track_ids) + except Exception as e: + _logger.warning(f"Failed to load images for lineage {lineage_id}: {e}") + continue + + if not time_to_image: + _logger.warning( + f"No images found for lineage {lineage_id}, FOV {fov_name}, tracks {track_ids}" + ) + continue + + aligned_mapping = seq_data["aligned_data"]["mapping"] + consensus_length = seq_data["metadata"]["consensus_length"] + aligned_images = [None] * consensus_length + + # pseudotime-indexed image array for the aligned portion + for i in range(consensus_length): + if i in aligned_mapping: + timepoint = aligned_mapping[i]["t"] + if timepoint in time_to_image: + aligned_images[i] = time_to_image[timepoint] + else: + available_times = list(time_to_image.keys()) + if available_times: + closest_time = min( + available_times, key=lambda x: abs(x - timepoint) + ) + aligned_images[i] = time_to_image[closest_time] + + # Gap filling for any missing aligned images + for i in range(consensus_length): + if aligned_images[i] is None: + available_indices = [ + j for j, img in enumerate(aligned_images) if img is not None + ] + if available_indices: + closest_idx = min(available_indices, key=lambda x: abs(x - i)) + aligned_images[i] = aligned_images[closest_idx] + elif time_to_image: + aligned_images[i] = next(iter(time_to_image.values())) + + # Map unaligned BEFORE portion + unaligned_before_images = [] + unaligned_before_rows = seq_data["unaligned_before_data"]["rows"] + if not unaligned_before_rows.empty: + unaligned_before_rows = unaligned_before_rows.sort_values("t") + for _, row in unaligned_before_rows.iterrows(): + timepoint = row["t"] + if timepoint in time_to_image: + unaligned_before_images.append(time_to_image[timepoint]) + else: + # Find closest available time + available_times = list(time_to_image.keys()) + if available_times: + closest_time = min( + available_times, key=lambda x: abs(x - timepoint) + ) + unaligned_before_images.append(time_to_image[closest_time]) + + # Map unaligned AFTER portion + unaligned_after_images = [] + unaligned_after_rows = seq_data["unaligned_after_data"]["rows"] + if not unaligned_after_rows.empty: + unaligned_after_rows = unaligned_after_rows.sort_values("t") + for _, row in unaligned_after_rows.iterrows(): + timepoint = row["t"] + if timepoint in time_to_image: + unaligned_after_images.append(time_to_image[timepoint]) + else: + # Find closest available time + available_times = list(time_to_image.keys()) + if available_times: + closest_time = min( + available_times, key=lambda x: abs(x - timepoint) + ) + unaligned_after_images.append(time_to_image[closest_time]) + + # Concatenate in order: [before, aligned, after] + concatenated_images = ( + unaligned_before_images + aligned_images + unaligned_after_images + ) + + concatenated_image_sequences[lineage_id] = { + "concatenated_images": concatenated_images, + "unaligned_before_length": len(unaligned_before_images), + "aligned_length": len(aligned_images), + "unaligned_after_length": len(unaligned_after_images), + "unaligned_length": len(unaligned_before_images) + + len(unaligned_after_images), # Keep for backward compatibility + "metadata": seq_data["metadata"], + } + + _logger.debug( + f"Created image sequences for {len(concatenated_image_sequences)} lineages" + ) + return concatenated_image_sequences + + +def create_synchronized_warped_sequences( + concatenated_image_sequences: dict, + alignment_df: pd.DataFrame, + alignment_name: str, + consensus_infection_idx: int, + time_interval_minutes: float = 30, +) -> dict: + """ + Create synchronized warped sequences with zero-padding and HPI metadata tracking. + + This function takes concatenated image sequences (from get_aligned_image_sequences) + and creates synchronized warped sequences where all cells' aligned regions start + at the same frame index. It also computes metadata for mapping to Hours Post Infection (HPI). + + Parameters + ---------- + concatenated_image_sequences : dict + Output from get_aligned_image_sequences(), mapping lineage_id to: + - 'concatenated_images': List of [before + aligned + after] images + - 'unaligned_before_length': Number of frames before aligned region + - 'aligned_length': Number of aligned frames (consensus_length) + - 'unaligned_after_length': Number of frames after aligned region + - 'metadata': Lineage metadata (fov_name, track_ids, etc.) + alignment_df : pd.DataFrame + Enhanced DataFrame with DTW alignment columns including: + - 'fov_name', 'track_id', 't' (absolute time) + - f'dtw_{alignment_name}_aligned': Boolean marking aligned frames + - f'dtw_{alignment_name}_consensus_mapping': Pseudotime index (0 to consensus_length-1) + alignment_name : str + Name of alignment (e.g., "cell_division", "infection_state") + consensus_infection_idx : int + Pseudotime index where biological event occurs (e.g., infection starts) + This is used to map the event to absolute time for each cell. + time_interval_minutes : float, optional + Time between frames in minutes (default: 30) + + Returns + ------- + dict + Dictionary containing: + - 'warped_sequences': dict mapping lineage_id to synchronized numpy arrays with shape + (time, channels, z, y, x). All arrays have the same time dimension. + - 'alignment_shifts': dict mapping lineage_id to: + - 'fov_name': FOV name + - 'track_ids': List of track IDs + - 'first_aligned_t': Absolute time where pseudotime 0 starts + - 'infection_t_abs': Absolute time where biological event occurs + - 'shift': Value to add to absolute time for pseudotime-anchored coordinates + - 'infection_offset_in_viz': Frame offset where event occurs in warped coordinates + - 'warped_metadata': dict with: + - 'max_unaligned_before': Maximum frames before aligned region across all cells + - 'max_unaligned_after': Maximum frames after aligned region across all cells + - 'consensus_aligned_length': Length of aligned region (same for all cells) + - 'time_interval_minutes': Time interval for HPI conversion + + Notes + ----- + Coordinate Systems: + 1. Absolute time (t_abs): Original experimental timepoints + 2. Pseudotime (consensus_idx): DTW-aligned indices (0 to consensus_length-1) + 3. Visualization time (t_viz): Pseudotime-anchored coordinates where t_viz=0 at pseudotime 0 + 4. Warped frame index: Synchronized frame indices in output arrays + 5. Hours Post Infection (HPI): Biological time relative to event + + Conversions: + - t_viz = t_abs + shift (where shift = -first_aligned_t) + - HPI = (t_abs - infection_t_abs) * (time_interval_minutes / 60) + - warped_frame_idx = max_unaligned_before + (varies by cell's history) + + Warped Time Structure: + - Frames [0 to max_unaligned_before-1]: Before aligned region (padded with zeros where needed) + - Frames [max_unaligned_before to max_unaligned_before+consensus_length-1]: ALIGNED region (synchronized!) + - Frames [max_unaligned_before+consensus_length to end]: After aligned region (padded where needed) + + Examples + -------- + >>> result = create_synchronized_warped_sequences( + ... concatenated_image_sequences, + ... alignment_df, + ... alignment_name="infection_state", + ... consensus_infection_idx=5, + ... time_interval_minutes=30 + ... ) + >>> warped_arrays = result['warped_sequences'] + >>> shifts = result['alignment_shifts'] + >>> + >>> # Get Hours Post Infection for a specific cell at absolute time t=25 + >>> lineage_id = 1 + >>> t_abs = 25 + >>> infection_t_abs = shifts[lineage_id]['infection_t_abs'] + >>> hpi = (t_abs - infection_t_abs) * (30 / 60) # Convert to hours + """ + import numpy as np + + if len(concatenated_image_sequences) == 0: + _logger.warning("No concatenated sequences provided") + return { + "warped_sequences": {}, + "alignment_shifts": {}, + "warped_metadata": {}, + } + + # Step 1: Find maximum unaligned lengths across all cells + max_unaligned_before = 0 + max_unaligned_after = 0 + consensus_aligned_length = None + + for lineage_id, seq_data in concatenated_image_sequences.items(): + max_unaligned_before = max( + max_unaligned_before, seq_data["unaligned_before_length"] + ) + max_unaligned_after = max( + max_unaligned_after, seq_data["unaligned_after_length"] + ) + if consensus_aligned_length is None: + consensus_aligned_length = seq_data["aligned_length"] + + _logger.info(f"Consensus aligned length: {consensus_aligned_length}") + _logger.info(f"Max unaligned before across all cells: {max_unaligned_before}") + _logger.info(f"Max unaligned after across all cells: {max_unaligned_after}") + + total_warped_length = ( + max_unaligned_before + consensus_aligned_length + max_unaligned_after + ) + _logger.info(f"Total synchronized warped time length: {total_warped_length}") + + # Prepare output dictionaries + warped_sequences = {} + alignment_shifts = {} + + # Column names for alignment + aligned_col = f"dtw_{alignment_name}_aligned" + consensus_mapping_col = f"dtw_{alignment_name}_consensus_mapping" + + # Step 2: Process each lineage + for lineage_id, seq_data in concatenated_image_sequences.items(): + meta = seq_data["metadata"] + fov_name = meta["fov_name"] + track_ids = meta["track_ids"] + concatenated_images = seq_data["concatenated_images"] + + unaligned_before_length = seq_data["unaligned_before_length"] + seq_data["aligned_length"] + unaligned_after_length = seq_data["unaligned_after_length"] + + # Extract images from concatenated sequence + image_stack = [] + for img_sample in concatenated_images: + if img_sample is not None: + img_tensor = img_sample["anchor"] + img_np = img_tensor.cpu().numpy() + image_stack.append(img_np) + + if len(image_stack) == 0: + _logger.warning(f"No images found for lineage {lineage_id}") + continue + + # Stack the raw concatenated sequence + concatenated_stack = np.stack(image_stack, axis=0) + + # Compute padding needed for synchronization + pad_before = max_unaligned_before - unaligned_before_length + pad_after = max_unaligned_after - unaligned_after_length + + # Create zero frames for padding (same shape as images) + dummy_frame = np.zeros_like(concatenated_stack[0]) + + # Pad before and after + padded_before = [dummy_frame] * pad_before + padded_after = [dummy_frame] * pad_after + + # Create synchronized sequence + synchronized_frames = padded_before + list(concatenated_stack) + padded_after + warped_time_series = np.stack(synchronized_frames, axis=0) + + warped_sequences[lineage_id] = warped_time_series + + # Step 3: Compute alignment shifts and HPI metadata + # Get alignment information for this lineage from alignment_df + lineage_rows = alignment_df[ + (alignment_df["fov_name"] == fov_name) + & (alignment_df["track_id"].isin(track_ids)) + ].copy() + + if len(lineage_rows) == 0: + _logger.warning( + f"No alignment data found for lineage {lineage_id} (FOV: {fov_name}, tracks: {track_ids})" + ) + continue + + # Find aligned rows (where DTW matched) + aligned_rows = lineage_rows[lineage_rows[aligned_col]].copy() + + if len(aligned_rows) == 0: + _logger.warning(f"No aligned frames found for lineage {lineage_id}") + continue + + # Find first_aligned_t: absolute time where pseudotime 0 starts + # This is the minimum absolute time among frames mapped to consensus_idx == 0 + pseudotime_zero_rows = aligned_rows[ + aligned_rows[consensus_mapping_col] == 0 + ].copy() + + if len(pseudotime_zero_rows) > 0: + first_aligned_t = pseudotime_zero_rows["t"].min() + else: + # Fallback: use minimum aligned time + first_aligned_t = aligned_rows["t"].min() + _logger.warning( + f"No consensus_idx=0 found for lineage {lineage_id}, using min aligned time" + ) + + # Find infection_t_abs: absolute time where biological event occurs + infection_rows = aligned_rows[ + aligned_rows[consensus_mapping_col] == consensus_infection_idx + ].copy() + + if len(infection_rows) > 0: + infection_t_abs = infection_rows["t"].iloc[0] + else: + # Fallback: estimate based on first_aligned_t + infection_t_abs = first_aligned_t + consensus_infection_idx + _logger.warning( + f"No consensus_idx={consensus_infection_idx} found for lineage {lineage_id}, estimating infection time" + ) + + # Compute shift for pseudotime-anchored coordinates + shift = -first_aligned_t + + # Compute infection offset in visualization coordinates + infection_offset_in_viz = infection_t_abs - first_aligned_t + + # Store shift metadata + alignment_shifts[lineage_id] = { + "fov_name": fov_name, + "track_ids": track_ids, + "first_aligned_t": float(first_aligned_t), + "infection_t_abs": float(infection_t_abs), + "shift": float(shift), + "infection_offset_in_viz": float(infection_offset_in_viz), + } + + _logger.debug( + f"Lineage {lineage_id}: first_aligned_t={first_aligned_t}, " + f"infection_t_abs={infection_t_abs}, shift={shift}, " + f"warped_shape={warped_time_series.shape}" + ) + + # Step 4: Create metadata summary + warped_metadata = { + "max_unaligned_before": int(max_unaligned_before), + "max_unaligned_after": int(max_unaligned_after), + "consensus_aligned_length": int(consensus_aligned_length), + "time_interval_minutes": float(time_interval_minutes), + "total_warped_length": int(total_warped_length), + "aligned_region_start_idx": int(max_unaligned_before), + "aligned_region_end_idx": int(max_unaligned_before + consensus_aligned_length), + } + + _logger.info( + f"Created synchronized warped sequences for {len(warped_sequences)} lineages" + ) + _logger.info( + f"Aligned region in warped coordinates: frames [{warped_metadata['aligned_region_start_idx']} to {warped_metadata['aligned_region_end_idx'] - 1}]" + ) + + return { + "warped_sequences": warped_sequences, + "alignment_shifts": alignment_shifts, + "warped_metadata": warped_metadata, + } + + +def compute_hpi_from_absolute_time( + t_abs: float, + alignment_shifts: dict, + lineage_id: int, + time_interval_minutes: float = 30, +) -> float: + """ + Convert absolute time to Hours Post Infection (HPI) for a specific cell. + + Parameters + ---------- + t_abs : float + Absolute timepoint from experimental data + alignment_shifts : dict + Alignment shifts dictionary from create_synchronized_warped_sequences() + lineage_id : int + Lineage ID to compute HPI for + time_interval_minutes : float, optional + Time between frames in minutes (default: 30) + + Returns + ------- + float + Hours post infection (negative values = before infection) + + Examples + -------- + >>> # Get HPI for cell 1 at absolute time t=25 + >>> hpi = compute_hpi_from_absolute_time( + ... t_abs=25, + ... alignment_shifts=shifts, + ... lineage_id=1, + ... time_interval_minutes=30 + ... ) + >>> print(f"t=25 corresponds to {hpi:.2f} hours post infection") + """ + if lineage_id not in alignment_shifts: + raise ValueError(f"Lineage {lineage_id} not found in alignment_shifts") + + infection_t_abs = alignment_shifts[lineage_id]["infection_t_abs"] + hpi = (t_abs - infection_t_abs) * (time_interval_minutes / 60.0) + return hpi + + +def identify_lineages( + tracking_df: pd.DataFrame, return_both_branches: bool = False +) -> list[tuple[str, list[int]]]: + """Identify distinct lineages in cell tracking data. + + Parameters + ---------- + tracking_df : pd.DataFrame + Tracking dataframe with columns: fov_name, track_id, parent_track_id + return_both_branches : bool + If True, return both branches after division. If False, return only first branch. + + Returns + ------- + list[tuple[str, list[int]]] + List of (fov_name, track_ids) representing lineages + """ + all_lineages = [] + + for fov_id, fov_df in tracking_df.groupby("fov_name"): + # Create parent-child mapping + child_to_parent = {} + for track_id, track_group in fov_df.groupby("track_id"): + first_row = track_group.iloc[0] + parent_track_id = first_row["parent_track_id"] + if parent_track_id != -1: + child_to_parent[track_id] = parent_track_id + + # Find root tracks + all_tracks = set(fov_df["track_id"].unique()) + root_tracks = set() + for track_id in all_tracks: + track_data = fov_df[fov_df["track_id"] == track_id] + if ( + track_data.iloc[0]["parent_track_id"] == -1 + or track_data.iloc[0]["parent_track_id"] not in all_tracks + ): + root_tracks.add(track_id) + + # Build parent-to-children mapping + parent_to_children = {} + for child, parent in child_to_parent.items(): + if parent not in parent_to_children: + parent_to_children[parent] = [] + parent_to_children[parent].append(child) + + def get_all_branches(track_id): + """Get all branches from a parent track.""" + branches = [] + current_branch = [track_id] + + if track_id in parent_to_children: + for child in parent_to_children[track_id]: + child_branches = get_all_branches(child) + for branch in child_branches: + branches.append(current_branch + branch) + else: + branches.append(current_branch) + return branches + + # Build lineages from root tracks + for root_track in root_tracks: + lineage_tracks = get_all_branches(root_track) + if return_both_branches: + for branch in lineage_tracks: + all_lineages.append((fov_id, branch)) + else: + all_lineages.append((fov_id, lineage_tracks[0])) + + return all_lineages + + +def find_pattern_matches( + reference_pattern: np.ndarray, + filtered_lineages: list[tuple[str, list[int]]], + adata: ad.AnnData, + window_step: int = 5, + num_candidates: int = 3, + max_distance: float = float("inf"), + max_skew: float = 0.8, + save_path: str | None = None, + method: str = "bernd_clifford", + normalize: bool = True, + metric: str = "euclidean", + reference_type: Literal["features", "X_PCA", "X_UMAP", "X_PHATE"] = "features", + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, +) -> pd.DataFrame: + """Find best matches of a reference pattern in multiple lineages using DTW. + + Parameters + ---------- + reference_pattern : np.ndarray + Reference pattern embeddings + filtered_lineages : list[tuple[str, list[int]]] + List of lineages to search in (fov_name, track_ids) + embeddings_dataset : xr.Dataset + Dataset containing embeddings + window_step : int + Step size for sliding window search + num_candidates : int + Number of best candidates to consider per lineage + max_distance : float + Maximum distance threshold to consider a match + max_skew : float + Maximum allowed path skewness (0-1, where 0=perfect diagonal) + save_path : str, optional + Path to save the results CSV + method : str + DTW method to use - 'bernd_clifford' or 'dtai' + normalize : bool + Whether to normalize DTW distance by path length + metric : str + Distance metric for computing distance matrix + reference_type : str + Type of embedding to use for reference pattern + Returns + ------- + pd.DataFrame + Match results with distances and warping paths + """ + # Use window_step directly as step size + window_step = max(1, window_step) + + all_match_positions = { + "fov_name": [], + "track_ids": [], + "distance": [], + "skewness": [], + "warp_path": [], + "start_track_timepoint": [], + "end_track_timepoint": [], + } + + for fov_name, track_ids in tqdm(filtered_lineages, desc="Finding pattern matches"): + lineages = [] + t_values = [] + for track_id in track_ids: + # Filter by fov_name and track_id + mask = (adata.obs["fov_name"] == fov_name) & ( + adata.obs["track_id"] == track_id + ) + track_data = adata[mask] + + # Sort by timepoint to ensure correct order + time_order = np.argsort(track_data.obs["t"].values) + track_data = track_data[time_order] + + if reference_type == "features": + track_embeddings = track_data.X + else: + # Assume it's an obsm key + track_embeddings = track_data.obsm[reference_type] + + track_t = track_data.obs["t"].values + + # Handle 1D arrays (PC components) by reshaping to (time, 1) + if track_embeddings.ndim == 1: + track_embeddings = track_embeddings.reshape(-1, 1) + + lineages.append(track_embeddings) + t_values.extend(track_t) # Add t values to our mapping + + lineage_embeddings = np.concatenate(lineages, axis=0) + + # Find best matches using the selected DTW method + if method == "bernd_clifford": + matches_df = find_best_match_dtw_bernd_clifford( + lineage_embeddings, + reference_pattern=reference_pattern, + num_candidates=num_candidates, + window_step=window_step, + max_distance=max_distance, + max_skew=max_skew, + normalize=normalize, + metric=metric, + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + else: + matches_df = find_best_match_dtw( + lineage_embeddings, + reference_pattern=reference_pattern, + num_candidates=num_candidates, + window_step=window_step, + max_distance=max_distance, + max_skew=max_skew, + normalize=normalize, + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + + if not matches_df.empty: + # Get the best match (first row) + best_match = matches_df.iloc[0] + best_pos = best_match["position"] + best_path = best_match["path"] + best_dist = best_match["distance"] + best_skew = best_match["skewness"] + + # warping path is relative to the reference pattern + # query_idx is relative to the lineage + converted_path = [] + for ref_idx, query_idx in best_path: + query_t_idx = best_pos + query_idx + if query_t_idx < len(t_values): + actual_t = t_values[query_t_idx] + converted_path.append((ref_idx, actual_t)) + + start_t = t_values[best_pos] if best_pos < len(t_values) else None + end_pos = best_pos + len(reference_pattern) - 1 + end_t = t_values[end_pos] if end_pos < len(t_values) else None + + all_match_positions["fov_name"].append(fov_name) + all_match_positions["track_ids"].append(track_ids) + all_match_positions["distance"].append(best_dist) + all_match_positions["skewness"].append(best_skew) + all_match_positions["warp_path"].append(converted_path) + all_match_positions["start_track_timepoint"].append(start_t) + all_match_positions["end_track_timepoint"].append(end_t) + else: + # No matches found + all_match_positions["fov_name"].append(fov_name) + all_match_positions["track_ids"].append(track_ids) + all_match_positions["distance"].append(None) + all_match_positions["skewness"].append(None) + all_match_positions["warp_path"].append(None) + all_match_positions["start_track_timepoint"].append(None) + all_match_positions["end_track_timepoint"].append(None) + + all_match_positions = pd.DataFrame(all_match_positions) + all_match_positions = all_match_positions.dropna() + + all_match_positions = all_match_positions.sort_values( + by=["distance", "skewness"], ascending=[True, True] + ) + + if save_path: + all_match_positions.to_csv(save_path, index=False) + + return all_match_positions + + +def find_best_match_dtw( + lineage: np.ndarray, + reference_pattern: np.ndarray, + num_candidates: int = 5, + window_step: int = 5, + max_distance: float = float("inf"), + max_skew: float = 0.8, + normalize: bool = True, + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, +) -> pd.DataFrame: + """Find best matches using DTW with dtaidistance library. + + Note: constraint_type and band_width_ratio are ignored for dtaidistance method. + + Parameters + ---------- + lineage : np.ndarray + The lineage to search (t, embeddings) + reference_pattern : np.ndarray + The pattern to search for (t, embeddings) + num_candidates : int + Number of candidates to return + window_step : int + Step size for sliding window + max_distance : float + Maximum distance threshold + max_skew : float + Maximum allowed path skewness (0-1) + normalize : bool + Whether to normalize distance by path length + constraint_type : str + Ignored for this method + band_width_ratio : float + Ignored for this method + + Returns + ------- + pd.DataFrame + Results with position, path, distance, and skewness + """ + from dtaidistance.dtw_ndim import warping_path + + dtw_results = [] + n_windows = len(lineage) - len(reference_pattern) + 1 + + if n_windows <= 0: + return pd.DataFrame(columns=["position", "path", "distance", "skewness"]) + + for i in range(0, n_windows, window_step): + window = lineage[i : i + len(reference_pattern)] + path, dist = warping_path( + reference_pattern, + window, + include_distance=True, + ) + if normalize: + # Normalize by path length + dist = dist / len(path) + + # Calculate skewness + skewness = path_skew(path, len(reference_pattern), len(window)) + + if dist <= max_distance and skewness <= max_skew: + dtw_results.append( + {"position": i, "path": path, "distance": dist, "skewness": skewness} + ) + + # Convert to DataFrame and sort + results_df = pd.DataFrame(dtw_results) + if not results_df.empty: + results_df = results_df.sort_values(by=["distance", "skewness"]).head( + num_candidates + ) + + return results_df + + +def find_best_match_dtw_bernd_clifford( + lineage: np.ndarray, + reference_pattern: np.ndarray, + num_candidates: int = 5, + window_step: int = 5, + normalize: bool = True, + max_distance: float = float("inf"), + max_skew: float = 0.8, + metric: str = "euclidean", + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, +) -> pd.DataFrame: + """Find best matches using custom DTW implementation. + + Parameters + ---------- + lineage : np.ndarray + The lineage to search (t, embeddings) + reference_pattern : np.ndarray + The pattern to search for (t, embeddings) + num_candidates : int + Number of candidates to return + window_step : int + Step size for sliding window + normalize : bool + Whether to normalize distance by path length + max_distance : float + Maximum distance threshold + max_skew : float + Maximum allowed path skewness (0-1) + metric : str + Distance metric for computing distance matrix + + Returns + ------- + pd.DataFrame + Results with position, path, distance, and skewness + """ + dtw_results = [] + n_windows = len(lineage) - len(reference_pattern) + 1 + + if n_windows <= 0: + return pd.DataFrame(columns=["position", "path", "distance", "skewness"]) + + for i in range(0, n_windows, window_step): + window = lineage[i : i + len(reference_pattern)] + distance_matrix = cdist(reference_pattern, window, metric=metric) + distance, _, path = dtw_with_matrix( + distance_matrix, + normalize=normalize, + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + skewness = path_skew(path, len(reference_pattern), len(window)) + + # Only add if both thresholds are met + if distance <= max_distance and skewness <= max_skew: + dtw_results.append( + { + "position": i, + "path": path, + "distance": distance, + "skewness": skewness, + } + ) + + # Convert to DataFrame and sort + results_df = pd.DataFrame(dtw_results) + if not results_df.empty: + results_df = results_df.sort_values(by=["distance", "skewness"]).head( + num_candidates + ) + + return results_df + + +def compute_dtw_distance( + s1: ArrayLike, + s2: ArrayLike, + metric: Literal["cosine", "euclidean"] = "cosine", + constraint_type: str = "unconstrained", + band_width_ratio: float = None, +) -> dict: + """Compute DTW distance between two embedding sequences. + + Parameters + ---------- + s1 : ArrayLike + First embedding sequence + s2 : ArrayLike + Second embedding sequence + metric : Literal["cosine", "euclidean"] + Distance metric to use + + Returns + ------- + dict + - 'distance': float - DTW distance + - 'skewness': float - Path skewness + - 'warping_path': list - Warping path + """ + # Create distance matrix + distance_matrix = cdist(s1, s2, metric=metric) + + # Compute DTW + dtw_distance, _, warping_path = dtw_with_matrix( + distance_matrix, + normalize=True, + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + + # Compute path skewness + skewness = path_skew(warping_path, len(s1), len(s2)) + + return { + "distance": dtw_distance, + "skewness": skewness, + "warping_path": warping_path, + } + + +def dtw_with_matrix( + distance_matrix: np.ndarray, + normalize: bool = True, + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, +) -> Tuple[float, np.ndarray, list]: + """Compute DTW using a pre-computed distance matrix with constraints. + + Parameters + ---------- + distance_matrix : np.ndarray + Pre-computed distance matrix between two sequences + normalize : bool + Whether to normalize the distance by path length + constraint_type : str + Type of constraint: "sakoe_chiba", "unconstrained" + band_width_ratio : float + Ratio of matrix size for Sakoe-Chiba band constraint + + Returns + ------- + Tuple[float, np.ndarray, list] + DTW distance, warping matrix, and optimal warping path + """ + n, m = distance_matrix.shape + warping_matrix = np.full((n, m), np.inf) + + if constraint_type == "sakoe_chiba": + # Sakoe-Chiba band constraint + band_width = int(max(n, m) * band_width_ratio) + + for i in range(n): + for j in range(m): + # Only allow alignment within the band + diagonal_position = j * n / m + if abs(i - diagonal_position) <= band_width: + if i == 0 and j == 0: + warping_matrix[i, j] = distance_matrix[i, j] + elif i == 0 and j > 0 and warping_matrix[i, j - 1] != np.inf: + warping_matrix[i, j] = ( + warping_matrix[i, j - 1] + distance_matrix[i, j] + ) + elif j == 0 and i > 0 and warping_matrix[i - 1, j] != np.inf: + warping_matrix[i, j] = ( + warping_matrix[i - 1, j] + distance_matrix[i, j] + ) + elif i > 0 and j > 0: + candidates = [] + if warping_matrix[i - 1, j] != np.inf: + candidates.append(warping_matrix[i - 1, j]) + if warping_matrix[i, j - 1] != np.inf: + candidates.append(warping_matrix[i, j - 1]) + if warping_matrix[i - 1, j - 1] != np.inf: + candidates.append(warping_matrix[i - 1, j - 1]) + + if candidates: + warping_matrix[i, j] = distance_matrix[i, j] + min( + candidates + ) + else: + # Unconstrained DTW + warping_matrix[0, 0] = distance_matrix[0, 0] + + # Fill first column and row + for i in range(1, n): + warping_matrix[i, 0] = warping_matrix[i - 1, 0] + distance_matrix[i, 0] + for j in range(1, m): + warping_matrix[0, j] = warping_matrix[0, j - 1] + distance_matrix[0, j] + + # Fill the rest of the matrix + for i in range(1, n): + for j in range(1, m): + warping_matrix[i, j] = distance_matrix[i, j] + min( + warping_matrix[i - 1, j], # insertion + warping_matrix[i, j - 1], # deletion + warping_matrix[i - 1, j - 1], # match + ) + + # Backtrack to find optimal path + i, j = n - 1, m - 1 + warping_path = [(i, j)] + + while i > 0 or j > 0: + if i == 0: + j -= 1 + elif j == 0: + i -= 1 + else: + min_cost = min( + warping_matrix[i - 1, j], + warping_matrix[i, j - 1], + warping_matrix[i - 1, j - 1], + ) + + if min_cost == warping_matrix[i - 1, j - 1]: + i, j = i - 1, j - 1 + elif min_cost == warping_matrix[i - 1, j]: + i -= 1 + else: + j -= 1 + + warping_path.append((i, j)) + + warping_path.reverse() + + dtw_distance = warping_matrix[n - 1, m - 1] + + if normalize: + dtw_distance = dtw_distance / len(warping_path) + + return dtw_distance, warping_matrix, warping_path + + +def path_skew(warping_path: list, ref_len: int, query_len: int) -> float: + """Calculate skewness of a DTW warping path. + + Parameters + ---------- + warping_path : list + List of (ref_idx, query_idx) tuples representing the warping path + ref_len : int + Length of the reference sequence + query_len : int + Length of the query sequence + + Returns + ------- + float + Skewness metric between 0 and 1, where 0 means perfectly diagonal path + and 1 means completely skewed path + """ + # Calculate "ideal" diagonal indices + diagonal_x = np.linspace(0, ref_len - 1, len(warping_path)) + diagonal_y = np.linspace(0, query_len - 1, len(warping_path)) + diagonal_path = np.column_stack((diagonal_x, diagonal_y)) + + max_distance = max(ref_len, query_len) + + distances = [] + for i, (x, y) in enumerate(warping_path): + dx, dy = diagonal_path[i] + dist = np.sqrt((x - dx) ** 2 + (y - dy) ** 2) + distances.append(dist) + + skew = np.mean(distances) / max_distance + + return skew + + +def create_consensus_from_patterns( + patterns: dict[str, dict], + reference_selection: str = "median_length", + aggregation_method: str = "mean", + metric: Literal["cosine", "euclidean"] = "cosine", + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, +) -> dict: + """Create consensus pattern from one or more embedding patterns using DTW alignment. + + For single patterns, uses it directly. For multiple patterns, aligns them with DTW + and aggregates. + + Parameters + ---------- + patterns : dict[str, dict] + Dictionary where keys are pattern identifiers and values contain: + - 'pattern': np.ndarray - The embedding pattern (time, features) + - 'annotations': list or dict - Optional annotations/labels + - 'weight': float - Optional weight for this pattern (default 1.0) + - Other metadata fields are preserved + reference_selection : str + How to select reference: "median_length", "first", "longest", "shortest" + aggregation_method : str + How to aggregate: "mean", "median", "weighted_mean" + metric: Literal["cosine", "euclidean"] + Distance metric for DTW alignment + + Returns + ------- + dict + Dictionary containing: + - 'pattern': np.ndarray - The consensus embedding pattern + - 'annotations': list - Consensus annotations (if available) + - 'metadata': dict - Information about the consensus creation process + """ + if not patterns: + raise ValueError("At least one pattern is required") + + for pattern_id, pattern_data in patterns.items(): + if "pattern" not in pattern_data: + raise ValueError(f"Pattern '{pattern_id}' missing 'pattern' key") + if not isinstance(pattern_data["pattern"], np.ndarray): + raise ValueError(f"Pattern '{pattern_id}' must be numpy array") + + # Handle single pattern case - use it directly as consensus + if len(patterns) == 1: + pattern_id = list(patterns.keys())[0] + pattern_data = patterns[pattern_id] + + consensus = DtwSample( + pattern=pattern_data["pattern"], + annotations=pattern_data.get("annotations"), + distance=np.nan, + skewness=0.0, + warping_path=[(i, i) for i in range(len(pattern_data["pattern"]))], + ) + + consensus["metadata"] = { + "reference_pattern": pattern_id, + "source_patterns": [pattern_id], + "reference_selection": "single_pattern", + "aggregation_method": "none", + "n_patterns": 1, + } + + return consensus + + # Multiple patterns - perform DTW alignment and aggregation + reference_id = _select_reference_pattern(patterns, reference_selection) + reference_pattern = patterns[reference_id]["pattern"] + + _logger.debug(f"Selected reference pattern: {reference_id}") + _logger.debug(f"Reference shape: {reference_pattern.shape}") + + reference_pattern = patterns[reference_id]["pattern"] + aligned_patterns = {reference_id: patterns[reference_id]} + + for pattern_id, pattern_data in patterns.items(): + if pattern_id == reference_id: + continue # Skip reference + + query_pattern = pattern_data["pattern"] + alignment_result = align_embedding_patterns( + query_pattern, + reference_pattern, + metric=metric, + query_annotations=pattern_data.get("annotations"), + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + aligned_data = { + "pattern": alignment_result["pattern"], + "annotations": alignment_result["annotations"], + "weight": pattern_data.get("weight", 1.0), + "dtw_distance": alignment_result["distance"], + "dtw_skewness": alignment_result["skewness"], + "alignment_path": alignment_result["warping_path"], + } + # Copy other metadata + for key, value in pattern_data.items(): + if key not in ["pattern", "annotations", "weight"]: + aligned_data[key] = value + + aligned_patterns[pattern_id] = aligned_data + consensus = _aggregate_aligned_patterns(aligned_patterns, aggregation_method) + + consensus = DtwSample( + pattern=consensus["pattern"], + annotations=consensus["annotations"], + distance=alignment_result["distance"], + skewness=alignment_result["skewness"], + warping_path=alignment_result["warping_path"], + ) + + consensus["metadata"] = { + "reference_pattern": reference_id, + "source_patterns": list(patterns.keys()), + "reference_selection": reference_selection, + "aggregation_method": aggregation_method, + "n_patterns": len(patterns), + } + + return consensus + + +def _select_reference_pattern(patterns: dict, method: str) -> str: + """Select which pattern to use as reference for DTW alignment.""" + if method == "first": + return list(patterns.keys())[0] + + elif method == "median_length": + lengths = {pid: len(pdata["pattern"]) for pid, pdata in patterns.items()} + median_length = np.median(list(lengths.values())) + closest_id = min(lengths.keys(), key=lambda x: abs(lengths[x] - median_length)) + return closest_id + + elif method == "longest": + lengths = {pid: len(pdata["pattern"]) for pid, pdata in patterns.items()} + return max(lengths.keys(), key=lambda x: lengths[x]) + + elif method == "shortest": + lengths = {pid: len(pdata["pattern"]) for pid, pdata in patterns.items()} + return min(lengths.keys(), key=lambda x: lengths[x]) + + else: + raise ValueError(f"Unknown reference selection method: {method}") + + +def align_embedding_patterns( + query_pattern: np.ndarray, + reference_pattern: np.ndarray, + metric: str = "cosine", + query_annotations: list = None, + constraint_type: str = "unconstrained", + band_width_ratio: float = 0.0, +) -> DtwSample: + """Align two embedding patterns using DTW. + + This is a modular function that aligns two embedding sequences (T, ndim) + using DTW and returns comprehensive alignment information. + + Parameters + ---------- + query_pattern : np.ndarray + Query embedding pattern with shape (T1, ndim) + reference_pattern : np.ndarray + Reference embedding pattern with shape (T2, ndim) + metric : str + Distance metric for DTW alignment + query_annotations : list, optional + Optional annotations for query pattern to align alongside the embeddings + + Returns + ------- + DtwSample + """ + + dtw_result = compute_dtw_distance( + query_pattern, + reference_pattern, + metric=metric, + constraint_type=constraint_type, + band_width_ratio=band_width_ratio, + ) + + # Apply warping path once for both pattern and annotations + aligned_query, aligned_annotations = _apply_warping_path( + query_pattern=query_pattern, + reference_pattern=reference_pattern, + warping_path=dtw_result["warping_path"], + query_annotations=query_annotations, + ) + + return DtwSample( + pattern=aligned_query, + annotations=aligned_annotations, + distance=dtw_result["distance"], + skewness=dtw_result["skewness"], + warping_path=dtw_result["warping_path"], + ) + + +def _apply_warping_path( + query_pattern: np.ndarray, + reference_pattern: np.ndarray, + warping_path: list[tuple[int, int]], + query_annotations: list = None, +) -> tuple[np.ndarray, list]: + """Apply DTW warping path to align query pattern to reference pattern. + + This is a modular helper function that applies a DTW warping path + to align embedding patterns and their annotations. + + Parameters + ---------- + query_pattern : np.ndarray + Query embedding pattern to be aligned (time, features) + reference_pattern : np.ndarray + Reference pattern to align to (time, features) + warping_path : list[tuple[int, int]] + DTW warping path as list of (query_idx, ref_idx) tuples + query_annotations : list, optional + Optional annotations for query pattern + + Returns + ------- + tuple[np.ndarray, list] + Aligned pattern and aligned annotations (if provided) + """ + ref_length, n_features = reference_pattern.shape + aligned_pattern = np.zeros_like(reference_pattern) + + # Apply warping path to align the embedding vectors + for query_idx, ref_idx in warping_path: + if ref_idx < ref_length and query_idx < len(query_pattern): + aligned_pattern[ref_idx] = query_pattern[query_idx] + + # Align annotations if present + aligned_annotations = None + if query_annotations is not None: + aligned_annotations = ["Unknown"] * ref_length + for query_idx, ref_idx in warping_path: + if ref_idx < ref_length and query_idx < len(query_annotations): + aligned_annotations[ref_idx] = query_annotations[query_idx] + + return aligned_pattern, aligned_annotations + + +def _aggregate_aligned_patterns( + aligned_patterns: DtwSample, method: Literal["mean", "median", "weighted_mean"] +) -> DtwSample: + """Aggregate aligned embedding patterns into consensus.""" + consensus = {} + + # Extract patterns and weights + pattern_arrays = [] + weights = [] + + for pattern_data in aligned_patterns.values(): + pattern_arrays.append(pattern_data["pattern"]) + weights.append(pattern_data.get("weight", 1.0)) + + pattern_arrays = np.array(pattern_arrays) + weights = np.array(weights) + + # Aggregate embedding patterns + if method == "mean": + pattern = np.mean(pattern_arrays, axis=0) + elif method == "median": + pattern = np.median(pattern_arrays, axis=0) + elif method == "weighted_mean": + weights = weights / np.sum(weights) + pattern = np.average(pattern_arrays, axis=0, weights=weights) + + consensus["pattern"] = pattern + + # Aggregate annotations if present (use most common at each timepoint) + annotation_lists = [] + for pattern_data in aligned_patterns.values(): + if pattern_data.get("annotations") is not None: + annotation_lists.append(pattern_data["annotations"]) + + if annotation_lists: + annotations = [] + time_length = pattern.shape[0] + + for t in range(time_length): + annotations_at_t = [] + for ann_list in annotation_lists: + if t < len(ann_list) and ann_list[t] != "Unknown": + annotations_at_t.append(ann_list[t]) + + if annotations_at_t: + # Find most common annotation + most_common = max(set(annotations_at_t), key=annotations_at_t.count) + annotations.append(most_common) + else: + annotations.append("Unknown") + + consensus["annotations"] = annotations + + return consensus