1717from tqdm import tqdm
1818import pandas as pd
1919
20+
2021plt .style .use ("../evaluation/figure.mplstyle" )
2122
2223# plotting
@@ -36,10 +37,9 @@ def compute_piece_wise_dissimilarity(
3637 piece_wise_rank_difference_per_track = []
3738 for name , subdata in features_df .groupby (["fov_name" , "track_id" ]):
3839 if len (subdata ) > 1 :
39- single_track_dissimilarity = select_block (cross_dist , subdata .index .values )
40- single_track_rank_fraction = select_block (
41- rank_fractions , subdata .index .values
42- )
40+ indices = subdata .index .values
41+ single_track_dissimilarity = select_block (cross_dist , indices )
42+ single_track_rank_fraction = select_block (rank_fractions , indices )
4343 piece_wise_dissimilarity = compare_time_offset (
4444 single_track_dissimilarity , time_offset = 1
4545 )
@@ -64,20 +64,25 @@ def plot_histogram(
6464
6565
6666# %%
67+ PATH_TO_GDRIVE_FIGUE = "/home/eduardo.hirata/mydata/gdrive/publications/learning_impacts_of_infection/fig_manuscript/rev2_ICLR_fig/"
68+
6769prediction_path_1 = Path (
68- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_1 .zarr"
70+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2 .zarr"
6971)
7072prediction_path_2 = Path (
71- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_1 .zarr"
73+ "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2 .zarr"
7274)
73- for prediction_path in tqdm ([prediction_path_1 , prediction_path_2 ]):
75+
76+ for prediction_path , loss_name in tqdm (
77+ [(prediction_path_1 , "ntxent" ), (prediction_path_2 , "triplet" )]
78+ ):
7479
7580 # Read the dataset
7681 embeddings = read_embedding_dataset (prediction_path )
7782 features = embeddings ["features" ]
7883
7984 scaled_features = StandardScaler ().fit_transform (features .values )
80- # COmpute the cosine dissimilarity
85+ # Compute the cosine dissimilarity
8186 cross_dist = cross_dissimilarity (scaled_features , metric = "cosine" )
8287 rank_fractions = rank_nearest_neighbors (cross_dist , normalize = True )
8388
@@ -91,43 +96,58 @@ def plot_histogram(
9196 compute_piece_wise_dissimilarity (features_df , cross_dist , rank_fractions )
9297 )
9398
94- # Get the median/mode of the off diagonal elements
95- median_piece_wise_dissimilarity = [
96- np .median (track ) for track in piece_wise_dissimilarity_per_track
97- ]
98- p99_piece_wise_dissimilarity = [
99- np .percentile (track , 99 ) for track in piece_wise_dissimilarity_per_track
100- ]
101- p1_percentile_piece_wise_dissimilarity = [
102- np .percentile (track , 1 ) for track in piece_wise_dissimilarity_per_track
103- ]
99+ all_dissimilarity = np .concatenate (piece_wise_dissimilarity_per_track )
100+
101+ # # Get the median/mode of the off diagonal elements
102+ # median_piece_wise_dissimilarity = np.array(
103+ # [np.median(track) for track in piece_wise_dissimilarity_per_track]
104+ # )
105+ p99_piece_wise_dissimilarity = np .array (
106+ [np .percentile (track , 99 ) for track in piece_wise_dissimilarity_per_track ]
107+ )
108+ p1_percentile_piece_wise_dissimilarity = np .array (
109+ [np .percentile (track , 1 ) for track in piece_wise_dissimilarity_per_track ]
110+ )
104111
105112 # Random sampling values in the dissimilarity matrix
106- n_samples = 2000
107- sampled_values = [
108- cross_dist [
109- np .random .randint (0 , len (cross_dist )), np .random .randint (0 , len (cross_dist ))
110- ]
111- for _ in range (n_samples )
112- ]
113+ n_samples = 3000
114+ random_indices = np .random .randint (0 , len (cross_dist ), size = (n_samples , 2 ))
115+ sampled_values = cross_dist [random_indices [:, 0 ], random_indices [:, 1 ]]
116+
117+ print (f"Dissimilarity Statistics for { prediction_path .stem } " )
118+ print (f"Mean: { np .mean (all_dissimilarity )} " )
119+ print (f"Std: { np .std (all_dissimilarity )} " )
120+ print (f"Median: { np .median (all_dissimilarity )} " )
121+
122+ print (f"Distance Statistics for random sampling" )
123+ print (f"Mean: { np .mean (sampled_values )} " )
124+ print (f"Std: { np .std (sampled_values )} " )
125+ print (f"Median: { np .median (sampled_values )} " )
113126
114127 if VERBOSE :
115128 # Plot histograms
129+ # plot_histogram(
130+ # median_piece_wise_dissimilarity,
131+ # "Adjacent Frame Median Dissimilarity per Track",
132+ # "Cosine Dissimilarity",
133+ # "Frequency",
134+ # )
135+ # plot_histogram(
136+ # p1_percentile_piece_wise_dissimilarity,
137+ # "Adjacent Frame 1 Percentile Dissimilarity per Track",
138+ # "Cosine Dissimilarity",
139+ # "Frequency",
140+ # )
141+ # plot_histogram(
142+ # p99_piece_wise_dissimilarity,
143+ # "Adjacent Frame 99 Percentile Dissimilarity per Track",
144+ # "Cosine Dissimilarity",
145+ # "Frequency",
146+ # )
147+
116148 plot_histogram (
117- median_piece_wise_dissimilarity ,
118- "Adjacent Frame Median Dissimilarity per Track" ,
119- "Cosine Dissimilarity" ,
120- "Frequency" ,
121- )
122- plot_histogram (
123- p1_percentile_piece_wise_dissimilarity ,
124- "Adjacent Frame 1 Percentile Dissimilarity per Track" ,
125- "Cosine Dissimilarity" ,
126- "Frequency" ,
127- )
128- plot_histogram (
129- p99_piece_wise_dissimilarity ,
130- "Adjacent Frame 99 Percentile Dissimilarity per Track" ,
149+ piece_wise_dissimilarity_per_track ,
150+ "Adjacent Frame Dissimilarity per Track" ,
131151 "Cosine Dissimilarity" ,
132152 "Frequency" ,
133153 )
@@ -145,7 +165,7 @@ def plot_histogram(
145165 # Plot the median and the random sampling in one plot each with different colors
146166 fig = plt .figure ()
147167 sns .histplot (
148- median_piece_wise_dissimilarity ,
168+ all_dissimilarity ,
149169 bins = 30 ,
150170 kde = True ,
151171 color = "cyan" ,
@@ -161,8 +181,8 @@ def plot_histogram(
161181 plt .legend (["Adjacent Frame" , "Random Sample" ])
162182 plt .show ()
163183 fig .savefig (
164- f". /cosine_dissimilarity_smoothness_{ prediction_path .stem } .pdf" ,
165- dpi = 300 ,
184+ f"{ PATH_TO_GDRIVE_FIGUE } /cosine_dissimilarity_smoothness_{ prediction_path .stem } _ { loss_name } .pdf" ,
185+ dpi = 600 ,
166186 )
167187
168188# %%
0 commit comments