@@ -169,16 +169,26 @@ def _dtypes(self, kind):
169169 int32 = torch .int32
170170 int64 = torch .int64
171171 uint8 = torch .uint8
172- # uint16, uint32, and uint64 are present in newer versions of pytorch,
173- # but they aren't generally supported by the array API functions, so
174- # we omit them from this function.
172+ try :
173+ # pytorch >= 2.3
174+ uint16 = torch .uint16
175+ uint32 = torch .uint32
176+ uint64 = torch .uint64
177+ uint_kinds = {
178+ "uint16" : uint16 ,
179+ "uint32" : uint32 ,
180+ "uint64" : uint64 ,
181+ }
182+ except AttributeError :
183+ uint_kinds = {}
184+
175185 float32 = torch .float32
176186 float64 = torch .float64
177187 complex64 = torch .complex64
178188 complex128 = torch .complex128
179189
180190 if kind is None :
181- return {
191+ kinds = {
182192 "bool" : bool ,
183193 "int8" : int8 ,
184194 "int16" : int16 ,
@@ -190,6 +200,8 @@ def _dtypes(self, kind):
190200 "complex64" : complex64 ,
191201 "complex128" : complex128 ,
192202 }
203+ kinds .update (uint_kinds )
204+ return kinds
193205 if kind == "bool" :
194206 return {"bool" : bool }
195207 if kind == "signed integer" :
@@ -200,17 +212,21 @@ def _dtypes(self, kind):
200212 "int64" : int64 ,
201213 }
202214 if kind == "unsigned integer" :
203- return {
215+ kinds = {
204216 "uint8" : uint8 ,
205217 }
218+ kinds .update (uint_kinds )
219+ return kinds
206220 if kind == "integral" :
207- return {
221+ kinds = {
208222 "int8" : int8 ,
209223 "int16" : int16 ,
210224 "int32" : int32 ,
211225 "int64" : int64 ,
212226 "uint8" : uint8 ,
213227 }
228+ kinds .update (uint_kinds )
229+ return kinds
214230 if kind == "real floating" :
215231 return {
216232 "float32" : float32 ,
@@ -222,7 +238,7 @@ def _dtypes(self, kind):
222238 "complex128" : complex128 ,
223239 }
224240 if kind == "numeric" :
225- return {
241+ kinds = {
226242 "int8" : int8 ,
227243 "int16" : int16 ,
228244 "int32" : int32 ,
@@ -233,6 +249,9 @@ def _dtypes(self, kind):
233249 "complex64" : complex64 ,
234250 "complex128" : complex128 ,
235251 }
252+ kinds .update (uint_kinds )
253+ return kinds
254+
236255 if isinstance (kind , tuple ):
237256 res = {}
238257 for k in kind :
0 commit comments