Skip to content

Commit f2618b4

Browse files
authored
Merge pull request #3 from ajinkya-kulkarni/patch-1
Vectorized nested loop(s)
2 parents 47a0f81 + c508255 commit f2618b4

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/rna2seg/models.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,24 @@ def forward(self, shape, list_gene, array_coord
109109
device=self.device)
110110

111111
batch_size = shape[0]
112-
for b in range(batch_size):
113-
if self.radius_rna is None:
114-
rna_imgs[b, list_y[b], list_x[b]] = emb[b]
115-
else:
116-
for i in range(-self.radius_rna, self.radius_rna + 1):
117-
for j in range(-self.radius_rna, self.radius_rna + 1):
118-
list_x = torch.clamp(list_x + i, 0, shape[-1] - 1)
119-
list_y = torch.clamp(list_y + j, 0, shape[-2] - 1)
120-
rna_imgs[b, list_y[b], list_x[b]] = emb[b]
112+
# Vectorized approach:
113+
if self.radius_rna is None:
114+
# Use scatter operation for all batches at once
115+
for b in range(batch_size):
116+
rna_imgs[b].index_put_((list_y[b], list_x[b]), emb[b])
117+
else:
118+
# For the radius case, precompute offsets
119+
offsets_y, offsets_x = torch.meshgrid(
120+
torch.arange(-self.radius_rna, self.radius_rna + 1, device=self.device),
121+
torch.arange(-self.radius_rna, self.radius_rna + 1, device=self.device)
122+
)
123+
offsets = torch.stack([offsets_y.flatten(), offsets_x.flatten()], dim=1)
124+
125+
for b in range(batch_size):
126+
for offset_y, offset_x in offsets:
127+
y_coords = torch.clamp(list_y[b] + offset_y, 0, shape[-2] - 1)
128+
x_coords = torch.clamp(list_x[b] + offset_x, 0, shape[-1] - 1)
129+
rna_imgs[b].index_put_((y_coords, x_coords), emb[b])
121130

122131
rna_imgs = rna_imgs.permute(0, 3, 1, 2)
123132

0 commit comments

Comments
 (0)