From 77b7cb67178b11e94924b362ecda8713f2262603 Mon Sep 17 00:00:00 2001 From: David <1282675518@qq.com> Date: Tue, 27 Feb 2024 16:48:01 +0800 Subject: [PATCH] fix neg sample eval bugs of diffrec and ldiffrec --- recbole/model/general_recommender/diffrec.py | 2 +- recbole/model/general_recommender/ldiffrec.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/recbole/model/general_recommender/diffrec.py b/recbole/model/general_recommender/diffrec.py index 82e33a75c..84716ac76 100644 --- a/recbole/model/general_recommender/diffrec.py +++ b/recbole/model/general_recommender/diffrec.py @@ -328,7 +328,7 @@ def full_sort_predict(self, interaction): def predict(self, interaction): item = interaction[self.ITEM_ID] x_t = self.full_sort_predict(interaction) - scores = x_t[:, item] + scores = x_t[torch.arange(len(item)).to(self.device), item] return scores def calculate_loss(self, interaction): diff --git a/recbole/model/general_recommender/ldiffrec.py b/recbole/model/general_recommender/ldiffrec.py index 7d8364f0b..a537db7d2 100644 --- a/recbole/model/general_recommender/ldiffrec.py +++ b/recbole/model/general_recommender/ldiffrec.py @@ -335,7 +335,7 @@ def full_sort_predict(self, interaction): def predict(self, interaction): item = interaction[self.ITEM_ID] x_t = self.full_sort_predict(interaction) - scores = x_t[:, item] + scores = x_t[torch.arange(len(item)).to(self.device), item] return scores