
Commit 4ab19c5

czxttkl authored and facebook-github-bot committed
Add an internal product model manager for signal loss
Summary: Since the code will become more and more specific to the ads signal loss use case, it is better to create a dedicated version which does not sync to OSS.

Reviewed By: j-jiafei

Differential Revision: D32591299

fbshipit-source-id: 02600fd68062a24ff22933e91faae3804a9da2fa
1 parent 4c470f4 commit 4ab19c5

File tree: 3 files changed (+20 lines, −2 lines)

reagent/core/types.py

Lines changed: 7 additions & 0 deletions
@@ -1063,6 +1063,13 @@ class RewardNetworkOutput(TensorDataClass):
     predicted_reward: torch.Tensor
 
 
+@dataclass
+class SyntheticRewardNetworkOutput(TensorDataClass):
+    predicted_reward: torch.Tensor
+    mask: torch.Tensor
+    output: torch.Tensor
+
+
 @dataclass
 class FrechetSortConfig:
     shape: float
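
For context, a minimal sketch (not part of the commit) of how the new dataclass could be constructed and read back; the tensor shapes are illustrative assumptions:

import torch
import reagent.core.types as rlt

# Assumed shapes: batch of 4 sequences, horizon of 3 steps.
output = torch.rand(4, 3)   # per-step reward predictions
mask = torch.ones(4, 3)     # 1.0 for valid steps, 0.0 for padding
pred_reward = (output * mask).sum(dim=1, keepdim=True)  # aggregated reward, shape (4, 1)

result = rlt.SyntheticRewardNetworkOutput(
    predicted_reward=pred_reward,
    mask=mask,
    output=output,
)
# Unlike RewardNetworkOutput, the per-step outputs and the mask stay
# available to the caller, not just their masked sum.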

reagent/models/synthetic_reward.py

Lines changed: 5 additions & 1 deletion
@@ -262,7 +262,11 @@ def forward(self, training_batch: rlt.MemoryNetworkInput):
         output_masked = output * mask
 
         pred_reward = output_masked.sum(dim=1, keepdim=True)
-        return rlt.RewardNetworkOutput(predicted_reward=pred_reward)
+        return rlt.SyntheticRewardNetworkOutput(
+            predicted_reward=pred_reward,
+            mask=mask,
+            output=output,
+        )
 
     def export_mlp(self):
         """

reagent/training/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-
+from reagent.core.fb_checker import IS_FB_ENVIRONMENT
 from reagent.training.c51_trainer import C51Trainer
 from reagent.training.cem_trainer import CEMTrainer
 from reagent.training.cfeval import BanditRewardNetTrainer
@@ -68,3 +68,10 @@
     "PPOTrainer",
     "PPOTrainerParameters",
 ]
+
+if IS_FB_ENVIRONMENT:
+    from reagent.training.fb.signal_loss_reward_decomp_trainer import (  # noqa
+        SignalLossRewardDecompTrainer,
+    )
+
+    __all__.append("SignalLossRewardDecompTrainer")
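
The guarded import keeps the internal trainer out of OSS builds while still exposing it through reagent.training inside the FB environment. The same check works at call sites; a small sketch (only IS_FB_ENVIRONMENT and the trainer name come from this commit, the rest is assumed):

import reagent.training as training
from reagent.core.fb_checker import IS_FB_ENVIRONMENT

if IS_FB_ENVIRONMENT:
    # Available only internally; registered in __all__ above.
    trainer_cls = training.SignalLossRewardDecompTrainer
else:
    # The fb/ module is never synced to OSS, so there is nothing to import here.
    trainer_cls = None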
