diff --git a/agentevolver/module/trainer/ae_ray_trainer.py b/agentevolver/module/trainer/ae_ray_trainer.py index 5ba5082..f26d9ee 100644 --- a/agentevolver/module/trainer/ae_ray_trainer.py +++ b/agentevolver/module/trainer/ae_ray_trainer.py @@ -197,11 +197,12 @@ def compute_grpo_outcome_advantage( id2score[index[i]].append(scores[i]) for idx in id2score: if len(id2score[idx]) == 1: - id2mean[idx] = torch.tensor(0.0) - id2std[idx] = torch.tensor(1.0) + id2mean[idx] = id2score[idx][0] + id2std[idx] = torch.tensor(1.0, device=scores.device) elif len(id2score[idx]) > 1: - id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) - id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + group_scores = torch.stack(id2score[idx]) + id2mean[idx] = group_scores.mean() + id2std[idx] = group_scores.std() else: raise ValueError(f"no score in prompt index: {idx}") for i in range(bsz):