Skip to content

Commit 5f17b97

Browse files
David Vengerov authored and facebook-github-bot committed
Allow for publishing of reward network in discrete CRR (#588)
Summary: Pull Request resolved: #588 Allow for publishing of reward network in discrete_crr.py Differential Revision: D32711991 fbshipit-source-id: d13fcf724cd5de0c04609378a86b779c07db9efb
1 parent 4ab19c5 commit 5f17b97

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

reagent/model_managers/discrete/discrete_crr.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def get_reporter(self):
196196
# in utils.py
197197

198198
def serving_module_names(self):
199-
module_names = ["default_model", "dqn", "actor_dqn"]
199+
module_names = ["default_model", "dqn", "actor_dqn", "reward"]
200200
if len(self.action_names) == 2:
201201
module_names.append("binary_difference_scorer")
202202
return module_names
@@ -219,6 +219,7 @@ def build_serving_modules(
219219
"dqn": self._build_dqn_module(
220220
trainer_module.q1_network, normalization_data_map
221221
),
222+
"reward": self.build_reward_module(trainer_module, normalization_data_map),
222223
"actor_dqn": self._build_dqn_module(
223224
ActorDQN(trainer_module.actor_network), normalization_data_map
224225
),
@@ -286,6 +287,23 @@ def build_actor_module(
286287
action_feature_ids=list(range(len(self.action_names))),
287288
)
288289

290+
def build_reward_module(
291+
self,
292+
trainer_module: DiscreteCRRTrainer,
293+
normalization_data_map: Dict[str, NormalizationData],
294+
) -> torch.nn.Module:
295+
"""
296+
Returns a TorchScript predictor module
297+
"""
298+
net_builder = self.cpe_net_builder.value
299+
assert trainer_module.reward_network is not None
300+
return net_builder.build_serving_module(
301+
trainer_module.reward_network,
302+
normalization_data_map[NormalizationKey.STATE],
303+
action_names=self.action_names,
304+
state_feature_config=self.state_feature_config,
305+
)
306+
289307

290308
class ActorDQN(ModelBase):
291309
def __init__(self, actor):

reagent/training/discrete_crr_trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def __init__(
135135
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
136136
self.reward_boosts[0, i] = rl.reward_boost[k]
137137

138+
# The function below adds reward_network as a member object to DQNTrainerBaseLightning,
139+
# from which DiscreteCRRTrainer is derived.
138140
self._initialize_cpe(
139141
reward_network,
140142
q_network_cpe,

0 commit comments

Comments (0)