@@ -123,7 +123,7 @@ def test_egreedy_masked(self, module, eps_init, spec_class):
123123 {"observation" : torch .zeros (* batch_size , action_size )},
124124 batch_size = batch_size ,
125125 )
126- with pytest .raises (KeyError , match = "Action mask key action_mask not found in " ):
126+ with pytest .raises (RuntimeError , match = "Failed while executing module " ):
127127 explorative_policy (td )
128128
129129 torch .manual_seed (0 )
@@ -182,9 +182,7 @@ def test_no_spec_error(
182182 batch_size = batch_size ,
183183 )
184184
185- with pytest .raises (
186- RuntimeError , match = "spec must be provided to the exploration wrapper."
187- ):
185+ with pytest .raises (RuntimeError , match = "Failed while executing module" ):
188186 explorative_policy (td )
189187
190188 @pytest .mark .parametrize ("module" , [True , False ])
@@ -201,9 +199,7 @@ def test_wrong_action_shape(self, module):
201199 policy ,
202200 )
203201 td = TensorDict ({"observation" : torch .zeros (10 , 4 )}, batch_size = [10 ])
204- with pytest .raises (
205- ValueError , match = "Action spec shape does not match the action shape"
206- ):
202+ with pytest .raises (RuntimeError , match = "Failed while executing module" ):
207203 explorative_policy (td )
208204
209205
@@ -383,9 +379,8 @@ def test_nested(
383379 )
384380
385381 action_spec = env .action_spec
386- d_act = action_spec .shape [- 1 ]
382+ action_spec .shape [- 1 ]
387383
388- net = nn .LazyLinear (d_act ).to (device )
389384 policy = TensorDictModule (
390385 CountingEnvCountModule (action_spec = action_spec ),
391386 in_keys = [("data" , "states" ) if nested_obs_action else "observation" ],
0 commit comments