@@ -109,15 +109,24 @@ def forward(self, shape, list_gene, array_coord
            device=self.device)

        batch_size = shape[0]
-        for b in range(batch_size):
-            if self.radius_rna is None:
-                rna_imgs[b, list_y[b], list_x[b]] = emb[b]
-            else:
-                for i in range(-self.radius_rna, self.radius_rna + 1):
-                    for j in range(-self.radius_rna, self.radius_rna + 1):
-                        list_x = torch.clamp(list_x + i, 0, shape[-1] - 1)
-                        list_y = torch.clamp(list_y + j, 0, shape[-2] - 1)
-                        rna_imgs[b, list_y[b], list_x[b]] = emb[b]
+        # Vectorized approach:
+        if self.radius_rna is None:
+            # Scatter each batch's embeddings at their (y, x) coordinates
+            for b in range(batch_size):
+                rna_imgs[b].index_put_((list_y[b], list_x[b]), emb[b])
+        else:
+            # For the radius case, precompute the (dy, dx) offsets once
+            offsets_y, offsets_x = torch.meshgrid(
+                torch.arange(-self.radius_rna, self.radius_rna + 1, device=self.device),
+                torch.arange(-self.radius_rna, self.radius_rna + 1, device=self.device)
+            )
+            offsets = torch.stack([offsets_y.flatten(), offsets_x.flatten()], dim=1)
+
+            for b in range(batch_size):
+                for offset_y, offset_x in offsets:
+                    y_coords = torch.clamp(list_y[b] + offset_y, 0, shape[-2] - 1)
+                    x_coords = torch.clamp(list_x[b] + offset_x, 0, shape[-1] - 1)
+                    rna_imgs[b].index_put_((y_coords, x_coords), emb[b])


        rna_imgs = rna_imgs.permute(0, 3, 1, 2)
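
For reference, a minimal, self-contained sketch of the `index_put_` scatter pattern used in the added code. The toy shapes (`B`, `H`, `W`, `C`, `N`) and the `radius` value are invented for illustration and are not taken from the repository.

```python
import torch

# Hypothetical toy dimensions: 2 images, 8x8 grid, 4-dim embeddings, 3 points each.
B, H, W, C, N = 2, 8, 8, 4, 3
radius = 1

rna_imgs = torch.zeros(B, H, W, C)
emb = torch.randn(B, N, C)            # per-point embedding vectors
list_y = torch.randint(0, H, (B, N))  # point rows
list_x = torch.randint(0, W, (B, N))  # point columns

# Precompute all (dy, dx) offsets inside the square window once.
offsets_y, offsets_x = torch.meshgrid(
    torch.arange(-radius, radius + 1),
    torch.arange(-radius, radius + 1),
    indexing="ij",
)
offsets = torch.stack([offsets_y.flatten(), offsets_x.flatten()], dim=1)

for b in range(B):
    for dy, dx in offsets:
        y = torch.clamp(list_y[b] + dy, 0, H - 1)
        x = torch.clamp(list_x[b] + dx, 0, W - 1)
        # index_put_ scatters the (N, C) embeddings into the N positions given
        # by (y, x); later writes overwrite earlier ones where windows overlap.
        rna_imgs[b].index_put_((y, x), emb[b])

print(rna_imgs.permute(0, 3, 1, 2).shape)  # (B, C, H, W), channels-first like the forward pass
```

Note that `index_put_` with the default `accumulate=False` overwrites on colliding indices, which matches the behaviour of the indexed assignment it replaces.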