@@ -614,7 +614,7 @@ def __init__(
614
614
if copy :
615
615
batch_dict = deepcopy (batch_dict )
616
616
if batch_dict is not None :
617
- if isinstance (batch_dict , dict | BatchProtocol ):
617
+ if isinstance (batch_dict , dict | Batch ):
618
618
_assert_type_keys (batch_dict .keys ())
619
619
for batch_key , obj in batch_dict .items ():
620
620
self .__dict__ [batch_key ] = _parse_value (obj )
@@ -951,7 +951,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
951
951
self .__dict__ [key ][sum_lens [i ] : sum_lens [i + 1 ]] = value
952
952
953
953
def cat_ (self , batches : BatchProtocol | Sequence [dict | BatchProtocol ]) -> None :
954
- if isinstance (batches , BatchProtocol | dict ):
954
+ if isinstance (batches , Batch | dict ):
955
955
batches = [batches ]
956
956
# check input format
957
957
batch_list = []
@@ -1037,7 +1037,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
1037
1037
{
1038
1038
batch_key
1039
1039
for batch_key , obj in batch .items ()
1040
- if not (isinstance (obj , BatchProtocol ) and len (obj .get_keys ()) == 0 )
1040
+ if not (isinstance (obj , Batch ) and len (obj .get_keys ()) == 0 )
1041
1041
}
1042
1042
for batch in batches
1043
1043
]
@@ -1048,7 +1048,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
1048
1048
if all (isinstance (element , torch .Tensor ) for element in value ):
1049
1049
self .__dict__ [shared_key ] = torch .stack (value , axis )
1050
1050
# third often
1051
- elif all (isinstance (element , BatchProtocol | dict ) for element in value ):
1051
+ elif all (isinstance (element , Batch | dict ) for element in value ):
1052
1052
self .__dict__ [shared_key ] = Batch .stack (value , axis )
1053
1053
else : # most often case is np.ndarray
1054
1054
try :
@@ -1082,7 +1082,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
1082
1082
value = batch .get (key )
1083
1083
# TODO: fix code/annotations s.t. the ignores can be removed
1084
1084
if (
1085
- isinstance (value , BatchProtocol ) # type: ignore
1085
+ isinstance (value , Batch ) # type: ignore
1086
1086
and len (value .get_keys ()) == 0 # type: ignore
1087
1087
):
1088
1088
continue # type: ignore
@@ -1250,7 +1250,7 @@ def set_array_at_key(
1250
1250
) from exception
1251
1251
else :
1252
1252
existing_entry = self [key ]
1253
- if isinstance (existing_entry , BatchProtocol ):
1253
+ if isinstance (existing_entry , Batch ):
1254
1254
raise ValueError (
1255
1255
f"Cannot set sequence at key { key } because it is a nested batch, "
1256
1256
f"can only set a subsequence of an array." ,
@@ -1274,7 +1274,7 @@ def hasnull(self) -> bool:
1274
1274
1275
1275
def is_any_true (boolean_batch : BatchProtocol ) -> bool :
1276
1276
for val in boolean_batch .values ():
1277
- if isinstance (val , BatchProtocol ):
1277
+ if isinstance (val , Batch ):
1278
1278
if is_any_true (val ):
1279
1279
return True
1280
1280
else :
@@ -1325,7 +1325,7 @@ def _apply_batch_values_func_recursively(
1325
1325
"""
1326
1326
result = batch if inplace else deepcopy (batch )
1327
1327
for key , val in batch .__dict__ .items ():
1328
- if isinstance (val , BatchProtocol ):
1328
+ if isinstance (val , Batch ):
1329
1329
result [key ] = _apply_batch_values_func_recursively (val , values_transform , inplace = False )
1330
1330
else :
1331
1331
result [key ] = values_transform (val )
0 commit comments