@@ -43,12 +43,12 @@ def set(self, name: str, value):
43
43
44
44
def update_tensor (self , name : str , value : torch .Tensor ):
45
45
if name not in self ._data_dict :
46
- self ._data_dict [name ] = value .cpu (). clone ().detach ()
46
+ self ._data_dict [name ] = value .clone ().detach ()
47
47
else :
48
48
if not isinstance (self ._data_dict [name ], torch .Tensor ):
49
49
raise ValueError ("{} is not a tensor." .format (name ))
50
50
self ._data_dict [name ] = torch .cat (
51
- (self ._data_dict [name ], value .cpu (). clone ().detach ()), dim = 0
51
+ (self ._data_dict [name ], value .clone ().detach ()), dim = 0
52
52
)
53
53
54
54
def __str__ (self ):
@@ -149,13 +149,15 @@ def eval_batch_collect(
149
149
positive_i(Torch.Tensor): the positive item id for each user.
150
150
"""
151
151
if self .register .need ("rec.items" ):
152
+
152
153
# get topk
153
154
_ , topk_idx = torch .topk (
154
155
scores_tensor , max (self .topk ), dim = - 1
155
156
) # n_users x k
156
157
self .data_struct .update_tensor ("rec.items" , topk_idx )
157
158
158
159
if self .register .need ("rec.topk" ):
160
+
159
161
_ , topk_idx = torch .topk (
160
162
scores_tensor , max (self .topk ), dim = - 1
161
163
) # n_users x k
@@ -167,6 +169,7 @@ def eval_batch_collect(
167
169
self .data_struct .update_tensor ("rec.topk" , result )
168
170
169
171
if self .register .need ("rec.meanrank" ):
172
+
170
173
desc_scores , desc_index = torch .sort (scores_tensor , dim = - 1 , descending = True )
171
174
172
175
# get the index of positive items in the ranking list
@@ -185,6 +188,7 @@ def eval_batch_collect(
185
188
self .data_struct .update_tensor ("rec.meanrank" , result )
186
189
187
190
if self .register .need ("rec.score" ):
191
+
188
192
self .data_struct .update_tensor ("rec.score" , scores_tensor )
189
193
190
194
if self .register .need ("data.label" ):
@@ -219,6 +223,8 @@ def get_data_struct(self):
219
223
"""Get all the evaluation resource that been collected.
220
224
And reset some of outdated resource.
221
225
"""
226
+ for key in self .data_struct ._data_dict :
227
+ self .data_struct ._data_dict [key ] = self .data_struct ._data_dict [key ].cpu ()
222
228
returned_struct = copy .deepcopy (self .data_struct )
223
229
for key in ["rec.topk" , "rec.meanrank" , "rec.score" , "rec.items" , "data.label" ]:
224
230
if key in self .data_struct :
0 commit comments