@@ -686,6 +686,10 @@ def __init__(
686686 policy = RandomPolicy (env .full_action_spec )
687687 elif policy_factory is not None :
688688 raise TypeError ("policy_factory cannot be used with policy argument." )
689+ # If the underlying policy has a state_dict, we keep a reference to the policy and
690+ # do all policy weight saving/loading through it
691+ if hasattr (policy , "state_dict" ):
692+ self ._policy_w_state_dict = policy
689693
690694 if trust_policy is None :
691695 trust_policy = isinstance (policy , (RandomPolicy , CudaGraphModule ))
@@ -1686,8 +1690,8 @@ def state_dict(self) -> OrderedDict:
16861690 else :
16871691 env_state_dict = OrderedDict ()
16881692
1689- if hasattr (self . policy , "state_dict " ):
1690- policy_state_dict = self .policy .state_dict ()
1693+ if hasattr (self , "_policy_w_state_dict " ):
1694+ policy_state_dict = self ._policy_w_state_dict .state_dict ()
16911695 state_dict = OrderedDict (
16921696 policy_state_dict = policy_state_dict ,
16931697 env_state_dict = env_state_dict ,
@@ -1711,7 +1715,13 @@ def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:
17111715 if strict or "env_state_dict" in state_dict :
17121716 self .env .load_state_dict (state_dict ["env_state_dict" ], ** kwargs )
17131717 if strict or "policy_state_dict" in state_dict :
1714- self .policy .load_state_dict (state_dict ["policy_state_dict" ], ** kwargs )
1718+ if not hasattr (self , "_policy_w_state_dict" ):
1719+ raise ValueError (
1720+ "Underlying policy does not have state_dict to load policy_state_dict into."
1721+ )
1722+ self ._policy_w_state_dict .load_state_dict (
1723+ state_dict ["policy_state_dict" ], ** kwargs
1724+ )
17151725 self ._frames = state_dict ["frames" ]
17161726 self ._iter = state_dict ["iter" ]
17171727
0 commit comments