Skip to content

Commit 0141ae5

Browse files
committed
Fixed an issue downloading Torchvision datasets with download = True.
1 parent fdb256b commit 0141ae5

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

plato/datasources/torchvision.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,13 @@ def __init__(self, **kwargs):
298298
common_kwargs.setdefault("root", config.params["data_path"])
299299

300300
download_flag = getattr(data_cfg, "download", kwargs.get("download", True))
301-
download_supported = "download" in signature.parameters
301+
# Some torchvision datasets accept download via **kwargs (e.g., EMNIST).
302+
# Treat VAR_KEYWORD as supporting download to avoid skipping it.
303+
has_var_kwargs = any(
304+
param.kind == inspect.Parameter.VAR_KEYWORD
305+
for param in signature.parameters.values()
306+
)
307+
download_supported = "download" in signature.parameters or has_var_kwargs
302308

303309
default_train_transform = dataset_defaults.get("train_transform")
304310
if default_train_transform is None:

0 commit comments

Comments
 (0)