Skip to content

Commit 82ffa12

Browse files
author
Zoltán Szarvas
committed
hack: temporary hack to eliminate slow isinstance(..., *Protocol) calls
see: thu-ml#1225
1 parent 1b0f2fa commit 82ffa12

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tianshou/data/batch.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ def __init__(
614614
if copy:
615615
batch_dict = deepcopy(batch_dict)
616616
if batch_dict is not None:
617-
if isinstance(batch_dict, dict | BatchProtocol):
617+
if isinstance(batch_dict, dict | Batch):
618618
_assert_type_keys(batch_dict.keys())
619619
for batch_key, obj in batch_dict.items():
620620
self.__dict__[batch_key] = _parse_value(obj)
@@ -951,7 +951,7 @@ def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
951951
self.__dict__[key][sum_lens[i] : sum_lens[i + 1]] = value
952952

953953
def cat_(self, batches: BatchProtocol | Sequence[dict | BatchProtocol]) -> None:
954-
if isinstance(batches, BatchProtocol | dict):
954+
if isinstance(batches, Batch | dict):
955955
batches = [batches]
956956
# check input format
957957
batch_list = []
@@ -1037,7 +1037,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10371037
{
10381038
batch_key
10391039
for batch_key, obj in batch.items()
1040-
if not (isinstance(obj, BatchProtocol) and len(obj.get_keys()) == 0)
1040+
if not (isinstance(obj, Batch) and len(obj.get_keys()) == 0)
10411041
}
10421042
for batch in batches
10431043
]
@@ -1048,7 +1048,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10481048
if all(isinstance(element, torch.Tensor) for element in value):
10491049
self.__dict__[shared_key] = torch.stack(value, axis)
10501050
# third often
1051-
elif all(isinstance(element, BatchProtocol | dict) for element in value):
1051+
elif all(isinstance(element, Batch | dict) for element in value):
10521052
self.__dict__[shared_key] = Batch.stack(value, axis)
10531053
else: # most often case is np.ndarray
10541054
try:
@@ -1082,7 +1082,7 @@ def stack_(self, batches: Sequence[dict | BatchProtocol], axis: int = 0) -> None
10821082
value = batch.get(key)
10831083
# TODO: fix code/annotations s.t. the ignores can be removed
10841084
if (
1085-
isinstance(value, BatchProtocol) # type: ignore
1085+
isinstance(value, Batch) # type: ignore
10861086
and len(value.get_keys()) == 0 # type: ignore
10871087
):
10881088
continue # type: ignore
@@ -1250,7 +1250,7 @@ def set_array_at_key(
12501250
) from exception
12511251
else:
12521252
existing_entry = self[key]
1253-
if isinstance(existing_entry, BatchProtocol):
1253+
if isinstance(existing_entry, Batch):
12541254
raise ValueError(
12551255
f"Cannot set sequence at key {key} because it is a nested batch, "
12561256
f"can only set a subsequence of an array.",
@@ -1274,7 +1274,7 @@ def hasnull(self) -> bool:
12741274

12751275
def is_any_true(boolean_batch: BatchProtocol) -> bool:
12761276
for val in boolean_batch.values():
1277-
if isinstance(val, BatchProtocol):
1277+
if isinstance(val, Batch):
12781278
if is_any_true(val):
12791279
return True
12801280
else:
@@ -1325,7 +1325,7 @@ def _apply_batch_values_func_recursively(
13251325
"""
13261326
result = batch if inplace else deepcopy(batch)
13271327
for key, val in batch.__dict__.items():
1328-
if isinstance(val, BatchProtocol):
1328+
if isinstance(val, Batch):
13291329
result[key] = _apply_batch_values_func_recursively(val, values_transform, inplace=False)
13301330
else:
13311331
result[key] = values_transform(val)

0 commit comments

Comments
 (0)