Skip to content

Commit bfe05d0

Browse files
authored
Merge pull request #2071 from HotBento/patch-2
fix: only use .cpu() in the final step in collector.py to reduce the usage of cpu
2 parents d64724a + f77c3fe commit bfe05d0

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

recbole/evaluator/collector.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,12 @@ def set(self, name: str, value):
4343

4444
def update_tensor(self, name: str, value: torch.Tensor):
4545
if name not in self._data_dict:
46-
self._data_dict[name] = value.cpu().clone().detach()
46+
self._data_dict[name] = value.clone().detach()
4747
else:
4848
if not isinstance(self._data_dict[name], torch.Tensor):
4949
raise ValueError("{} is not a tensor.".format(name))
5050
self._data_dict[name] = torch.cat(
51-
(self._data_dict[name], value.cpu().clone().detach()), dim=0
51+
(self._data_dict[name], value.clone().detach()), dim=0
5252
)
5353

5454
def __str__(self):
@@ -149,13 +149,15 @@ def eval_batch_collect(
149149
positive_i(Torch.Tensor): the positive item id for each user.
150150
"""
151151
if self.register.need("rec.items"):
152+
152153
# get topk
153154
_, topk_idx = torch.topk(
154155
scores_tensor, max(self.topk), dim=-1
155156
) # n_users x k
156157
self.data_struct.update_tensor("rec.items", topk_idx)
157158

158159
if self.register.need("rec.topk"):
160+
159161
_, topk_idx = torch.topk(
160162
scores_tensor, max(self.topk), dim=-1
161163
) # n_users x k
@@ -167,6 +169,7 @@ def eval_batch_collect(
167169
self.data_struct.update_tensor("rec.topk", result)
168170

169171
if self.register.need("rec.meanrank"):
172+
170173
desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True)
171174

172175
# get the index of positive items in the ranking list
@@ -185,6 +188,7 @@ def eval_batch_collect(
185188
self.data_struct.update_tensor("rec.meanrank", result)
186189

187190
if self.register.need("rec.score"):
191+
188192
self.data_struct.update_tensor("rec.score", scores_tensor)
189193

190194
if self.register.need("data.label"):
@@ -219,6 +223,8 @@ def get_data_struct(self):
219223
"""Get all the evaluation resource that been collected.
220224
And reset some of outdated resource.
221225
"""
226+
for key in self.data_struct._data_dict:
227+
self.data_struct._data_dict[key] = self.data_struct._data_dict[key].cpu()
222228
returned_struct = copy.deepcopy(self.data_struct)
223229
for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]:
224230
if key in self.data_struct:

0 commit comments

Comments
 (0)