Skip to content

Commit a98e882

Browse files
committed
fix plots for smoothnes
1 parent cf6dbe7 commit a98e882

File tree

1 file changed

+62
-42
lines changed

1 file changed

+62
-42
lines changed

applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from tqdm import tqdm
1818
import pandas as pd
1919

20+
2021
plt.style.use("../evaluation/figure.mplstyle")
2122

2223
# plotting
@@ -36,10 +37,9 @@ def compute_piece_wise_dissimilarity(
3637
piece_wise_rank_difference_per_track = []
3738
for name, subdata in features_df.groupby(["fov_name", "track_id"]):
3839
if len(subdata) > 1:
39-
single_track_dissimilarity = select_block(cross_dist, subdata.index.values)
40-
single_track_rank_fraction = select_block(
41-
rank_fractions, subdata.index.values
42-
)
40+
indices = subdata.index.values
41+
single_track_dissimilarity = select_block(cross_dist, indices)
42+
single_track_rank_fraction = select_block(rank_fractions, indices)
4343
piece_wise_dissimilarity = compare_time_offset(
4444
single_track_dissimilarity, time_offset=1
4545
)
@@ -64,20 +64,25 @@ def plot_histogram(
6464

6565

6666
# %%
67+
PATH_TO_GDRIVE_FIGUE = "/home/eduardo.hirata/mydata/gdrive/publications/learning_impacts_of_infection/fig_manuscript/rev2_ICLR_fig/"
68+
6769
prediction_path_1 = Path(
68-
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_1.zarr"
70+
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr"
6971
)
7072
prediction_path_2 = Path(
71-
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_1.zarr"
73+
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr"
7274
)
73-
for prediction_path in tqdm([prediction_path_1, prediction_path_2]):
75+
76+
for prediction_path, loss_name in tqdm(
77+
[(prediction_path_1, "ntxent"), (prediction_path_2, "triplet")]
78+
):
7479

7580
# Read the dataset
7681
embeddings = read_embedding_dataset(prediction_path)
7782
features = embeddings["features"]
7883

7984
scaled_features = StandardScaler().fit_transform(features.values)
80-
# COmpute the cosine dissimilarity
85+
# Compute the cosine dissimilarity
8186
cross_dist = cross_dissimilarity(scaled_features, metric="cosine")
8287
rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True)
8388

@@ -91,43 +96,58 @@ def plot_histogram(
9196
compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions)
9297
)
9398

94-
# Get the median/mode of the off diagonal elements
95-
median_piece_wise_dissimilarity = [
96-
np.median(track) for track in piece_wise_dissimilarity_per_track
97-
]
98-
p99_piece_wise_dissimilarity = [
99-
np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track
100-
]
101-
p1_percentile_piece_wise_dissimilarity = [
102-
np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track
103-
]
99+
all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track)
100+
101+
# # Get the median/mode of the off diagonal elements
102+
# median_piece_wise_dissimilarity = np.array(
103+
# [np.median(track) for track in piece_wise_dissimilarity_per_track]
104+
# )
105+
p99_piece_wise_dissimilarity = np.array(
106+
[np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track]
107+
)
108+
p1_percentile_piece_wise_dissimilarity = np.array(
109+
[np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track]
110+
)
104111

105112
# Random sampling values in the dissimilarity matrix
106-
n_samples = 2000
107-
sampled_values = [
108-
cross_dist[
109-
np.random.randint(0, len(cross_dist)), np.random.randint(0, len(cross_dist))
110-
]
111-
for _ in range(n_samples)
112-
]
113+
n_samples = 3000
114+
random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2))
115+
sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]]
116+
117+
print(f"Dissimilarity Statistics for {prediction_path.stem}")
118+
print(f"Mean: {np.mean(all_dissimilarity)}")
119+
print(f"Std: {np.std(all_dissimilarity)}")
120+
print(f"Median: {np.median(all_dissimilarity)}")
121+
122+
print(f"Distance Statistics for random sampling")
123+
print(f"Mean: {np.mean(sampled_values)}")
124+
print(f"Std: {np.std(sampled_values)}")
125+
print(f"Median: {np.median(sampled_values)}")
113126

114127
if VERBOSE:
115128
# Plot histograms
129+
# plot_histogram(
130+
# median_piece_wise_dissimilarity,
131+
# "Adjacent Frame Median Dissimilarity per Track",
132+
# "Cosine Dissimilarity",
133+
# "Frequency",
134+
# )
135+
# plot_histogram(
136+
# p1_percentile_piece_wise_dissimilarity,
137+
# "Adjacent Frame 1 Percentile Dissimilarity per Track",
138+
# "Cosine Dissimilarity",
139+
# "Frequency",
140+
# )
141+
# plot_histogram(
142+
# p99_piece_wise_dissimilarity,
143+
# "Adjacent Frame 99 Percentile Dissimilarity per Track",
144+
# "Cosine Dissimilarity",
145+
# "Frequency",
146+
# )
147+
116148
plot_histogram(
117-
median_piece_wise_dissimilarity,
118-
"Adjacent Frame Median Dissimilarity per Track",
119-
"Cosine Dissimilarity",
120-
"Frequency",
121-
)
122-
plot_histogram(
123-
p1_percentile_piece_wise_dissimilarity,
124-
"Adjacent Frame 1 Percentile Dissimilarity per Track",
125-
"Cosine Dissimilarity",
126-
"Frequency",
127-
)
128-
plot_histogram(
129-
p99_piece_wise_dissimilarity,
130-
"Adjacent Frame 99 Percentile Dissimilarity per Track",
149+
piece_wise_dissimilarity_per_track,
150+
"Adjacent Frame Dissimilarity per Track",
131151
"Cosine Dissimilarity",
132152
"Frequency",
133153
)
@@ -145,7 +165,7 @@ def plot_histogram(
145165
# Plot the median and the random sampling in one plot each with different colors
146166
fig = plt.figure()
147167
sns.histplot(
148-
median_piece_wise_dissimilarity,
168+
all_dissimilarity,
149169
bins=30,
150170
kde=True,
151171
color="cyan",
@@ -161,8 +181,8 @@ def plot_histogram(
161181
plt.legend(["Adjacent Frame", "Random Sample"])
162182
plt.show()
163183
fig.savefig(
164-
f"./cosine_dissimilarity_smoothness_{prediction_path.stem}.pdf",
165-
dpi=300,
184+
f"{PATH_TO_GDRIVE_FIGUE}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf",
185+
dpi=600,
166186
)
167187

168188
# %%

0 commit comments

Comments
 (0)