1111from .subject import Subject
1212
1313
14- class SubjectsDataset (Dataset ):
14+ class SubjectsDataset (Dataset [ Subject ] ):
1515 """Base TorchIO dataset.
1616
1717 Reader of 3D medical images that directly inherits from the PyTorch
@@ -61,12 +61,12 @@ class SubjectsDataset(Dataset):
6161 def __init__ (
6262 self ,
6363 subjects : Sequence [Subject ],
64- transform : Callable | None = None ,
64+ transform : Callable [[ Subject ], Subject ] | None = None ,
6565 load_getitem : bool = True ,
6666 ):
6767 self ._parse_subjects_list (subjects )
6868 self ._subjects = subjects
69- self ._transform : Callable | None
69+ self ._transform : Callable [[ Subject ], Subject ] | None
7070 self .set_transform (transform )
7171 self .load_getitem = load_getitem
7272
@@ -104,7 +104,7 @@ def from_batch(cls, batch: dict) -> SubjectsDataset:
104104 subjects : list [Subject ] = get_subjects_from_batch (batch )
105105 return cls (subjects )
106106
107- def dry_iter (self ):
107+ def dry_iter (self ) -> Sequence [ Subject ] :
108108 """Return the internal list of subjects.
109109
110110 This can be used to iterate over the subjects without loading the data
@@ -115,7 +115,10 @@ def dry_iter(self):
115115 """
116116 return self ._subjects
117117
118- def set_transform (self , transform : Callable | None ) -> None :
118+ def set_transform (
119+ self ,
120+ transform : Callable [[Subject ], Subject ] | None ,
121+ ) -> None :
119122 """Set the `transform` attribute.
120123
121124 Args:
0 commit comments