1616 UnboundedContinuousTensorSpec ,
1717 OneHotDiscreteTensorSpec ,
1818)
19- from torchrl .data .tensordict .tensordict import _TensorDict , TensorDict
19+ from torchrl .data .tensordict .tensordict import TensorDictBase , TensorDict
2020from torchrl .envs .common import _EnvClass
2121
2222spec_dict = {
@@ -110,15 +110,15 @@ def _step(self, tensordict):
110110 done = torch .tensor ([done ], dtype = torch .bool , device = self .device )
111111 return TensorDict ({"reward" : n , "done" : done , "next_observation" : n }, [])
112112
113- def _reset (self , tensordict : _TensorDict , ** kwargs ) -> _TensorDict :
113+ def _reset (self , tensordict : TensorDictBase , ** kwargs ) -> TensorDictBase :
114114 self .max_val = max (self .counter + 100 , self .counter * 2 )
115115
116116 n = torch .tensor ([self .counter ]).to (self .device ).to (torch .get_default_dtype ())
117117 done = self .counter >= self .max_val
118118 done = torch .tensor ([done ], dtype = torch .bool , device = self .device )
119119 return TensorDict ({"done" : done , "next_observation" : n }, [])
120120
121- def rand_step (self , tensordict : Optional [_TensorDict ] = None ) -> _TensorDict :
121+ def rand_step (self , tensordict : Optional [TensorDictBase ] = None ) -> TensorDictBase :
122122 return self .step (tensordict )
123123
124124
@@ -144,7 +144,7 @@ def _get_in_obs(self, obs):
144144 def _get_out_obs (self , obs ):
145145 return obs
146146
147- def _reset (self , tensordict : _TensorDict ) -> _TensorDict :
147+ def _reset (self , tensordict : TensorDictBase ) -> TensorDictBase :
148148 self .counter += 1
149149 state = torch .zeros (self .size ) + self .counter
150150 tensordict = tensordict .select ().set (
@@ -156,8 +156,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
156156
157157 def _step (
158158 self ,
159- tensordict : _TensorDict ,
160- ) -> _TensorDict :
159+ tensordict : TensorDictBase ,
160+ ) -> TensorDictBase :
161161 tensordict = tensordict .to (self .device )
162162 a = tensordict .get ("action" )
163163 assert (a .sum (- 1 ) == 1 ).all ()
@@ -199,7 +199,7 @@ def _get_in_obs(self, obs):
199199 def _get_out_obs (self , obs ):
200200 return obs
201201
202- def _reset (self , tensordict : _TensorDict ) -> _TensorDict :
202+ def _reset (self , tensordict : TensorDictBase ) -> TensorDictBase :
203203 self .counter += 1
204204 self .step_count = 0
205205 state = torch .zeros (self .size ) + self .counter
@@ -211,8 +211,8 @@ def _reset(self, tensordict: _TensorDict) -> _TensorDict:
211211
212212 def _step (
213213 self ,
214- tensordict : _TensorDict ,
215- ) -> _TensorDict :
214+ tensordict : TensorDictBase ,
215+ ) -> TensorDictBase :
216216 self .step_count += 1
217217 tensordict = tensordict .to (self .device )
218218 a = tensordict .get ("action" )
0 commit comments