@@ -196,7 +196,7 @@ def get_reporter(self):
     # in utils.py
 
     def serving_module_names(self):
-        module_names = ["default_model", "dqn", "actor_dqn"]
+        module_names = ["default_model", "dqn", "actor_dqn", "reward"]
         if len(self.action_names) == 2:
             module_names.append("binary_difference_scorer")
         return module_names
@@ -219,6 +219,7 @@ def build_serving_modules(
             "dqn": self._build_dqn_module(
                 trainer_module.q1_network, normalization_data_map
             ),
+            "reward": self.build_reward_module(trainer_module, normalization_data_map),
             "actor_dqn": self._build_dqn_module(
                 ActorDQN(trainer_module.actor_network), normalization_data_map
             ),
@@ -286,6 +287,23 @@ def build_actor_module(
286287 action_feature_ids = list (range (len (self .action_names ))),
287288 )
288289
290+ def build_reward_module (
291+ self ,
292+ trainer_module : DiscreteCRRTrainer ,
293+ normalization_data_map : Dict [str , NormalizationData ],
294+ ) -> torch .nn .Module :
295+ """
296+ Returns a TorchScript predictor module
297+ """
298+ net_builder = self .cpe_net_builder .value
299+ assert trainer_module .reward_network is not None
300+ return net_builder .build_serving_module (
301+ trainer_module .reward_network ,
302+ normalization_data_map [NormalizationKey .STATE ],
303+ action_names = self .action_names ,
304+ state_feature_config = self .state_feature_config ,
305+ )
306+
289307
290308class ActorDQN (ModelBase ):
291309 def __init__ (self , actor ):
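
For context, a minimal, self-contained sketch of the idea behind the new `build_reward_module`: wrapping a trained reward network in TorchScript so it can be shipped as a serving module. `RewardNet` and its shapes here are hypothetical stand-ins; in the actual change, the serving module (including feature normalization) is produced by `cpe_net_builder.value.build_serving_module(...)`.

```python
import torch

# Hypothetical stand-in for trainer_module.reward_network; the real network
# and its input preprocessing come from the CPE net builder in the diff above.
class RewardNet(torch.nn.Module):
    def __init__(self, state_dim: int, num_actions: int):
        super().__init__()
        self.linear = torch.nn.Linear(state_dim, num_actions)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # One predicted reward per discrete action.
        return self.linear(state)

reward_network = RewardNet(state_dim=4, num_actions=2)
# Compile to TorchScript, analogous to the "reward" entry that
# build_serving_modules() now returns alongside "dqn" and "actor_dqn".
serving_module = torch.jit.script(reward_network)
print(serving_module(torch.randn(1, 4)))  # shape: (1, 2)
```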