@@ -90,7 +90,11 @@ def __iter__(self) -> dataset.DatasetIterator:
9090 return _IteratorIdDatasetIterator (self ._parent .__iter__ ())
9191
9292
93- class InterleaveIterDatasetTest (parameterized .TestCase ):
93+ @absltest .skipThisClass ("Base class" )
94+ class _InterleaveIterDatasetTestBase (parameterized .TestCase ):
95+
96+ def _maybe_wrap_ds (self , ds ):
97+ return ds
9498
9599 @parameterized .named_parameters (* _INTERLEAVE_TEST_CASES )
96100 def test_interleaved_mix (self , to_mix , cycle_length , expected ):
@@ -99,6 +103,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected):
99103 for elements in to_mix
100104 ]
101105 ds = interleave .InterleaveIterDataset (datasets , cycle_length = cycle_length )
106+ ds = self ._maybe_wrap_ds (ds )
102107 self .assertEqual (list (ds ), expected )
103108 # Sanity check.
104109 flat_inputs = []
@@ -113,6 +118,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected):
113118 for elements in to_mix
114119 ]
115120 ds = interleave .InterleaveIterDataset (datasets , cycle_length = cycle_length )
121+ ds = self ._maybe_wrap_ds (ds )
116122 ds_iter = ds .__iter__ ()
117123 checkpoints = {}
118124 for i in range (len (expected )):
@@ -138,6 +144,7 @@ def test_checkpoint_with_extra_threads_creating_iterators(
138144 num_make_iter_threads = 10 ,
139145 make_iter_buffer_size = 10 ,
140146 )
147+ ds = self ._maybe_wrap_ds (ds )
141148 ds_iter = ds .__iter__ ()
142149 checkpoints = {}
143150 for i in range (len (expected )):
@@ -158,6 +165,7 @@ def make_dummy_source(filename):
158165 filenames = dataset .MapDataset .source (["11" , "2345" , "678" , "9999" ])
159166 sources = filenames .shuffle (seed = 42 ).map (make_dummy_source )
160167 ds = interleave .InterleaveIterDataset (sources , cycle_length = 2 )
168+ ds = self ._maybe_wrap_ds (ds )
161169 self .assertEqual (
162170 list (ds ),
163171 ["1" , "2" , "1" , "3" , "4" , "6" , "5" , "7" , "8" , "9" , "9" , "9" , "9" ],
@@ -168,6 +176,7 @@ def test_with_mp_prefetch(self):
168176 lambda i : dataset .MapDataset .source ([i ]).repeat (i ).to_iter_dataset ()
169177 )
170178 ds = interleave .InterleaveIterDataset (ds , cycle_length = 5 )
179+ ds = self ._maybe_wrap_ds (ds )
171180 ds = ds .mp_prefetch (options .MultiprocessingOptions (num_workers = 3 ))
172181 self .assertEqual (list (ds ), [1 , 2 , 3 , 4 , 5 , 3 , 4 , 2 , 3 , 4 , 5 , 4 , 5 , 5 , 5 ])
173182
@@ -176,6 +185,7 @@ def test_options_propagated(self):
176185 ds1 = ds1 .filter (lambda x : False )
177186 ds2 = dataset .MapDataset .source ([2 ]).repeat (1000 ).to_iter_dataset ()
178187 ds = interleave .InterleaveIterDataset ([ds1 , ds2 ], cycle_length = 1 )
188+ ds = self ._maybe_wrap_ds (ds )
179189 ds_options = base .DatasetOptions (filter_raise_threshold_ratio = 0.1 )
180190 ds = dataset .WithOptionsIterDataset (ds , ds_options )
181191 with self .assertRaisesRegex (ValueError , r"skipped 100\.00 %" ):
@@ -187,6 +197,7 @@ def test_checkpointing_comprehensive(self):
187197 for i in range (1 , 6 )
188198 ]
189199 ds = interleave .InterleaveIterDataset (ds , cycle_length = 5 )
200+ ds = self ._maybe_wrap_ds (ds )
190201 assert_equal_output_after_checkpoint (ds )
191202
192203 def test_set_state_does_not_recreate_iterators_if_not_needed (self ):
@@ -196,6 +207,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
196207 ds = interleave .InterleaveIterDataset (
197208 [ds ] * cycle_length , cycle_length = cycle_length
198209 )
210+ ds = self ._maybe_wrap_ds (ds )
199211 ds_iter = ds .__iter__ ()
200212 iter_ids1 = []
201213 for _ in range (cycle_length ):
@@ -211,6 +223,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
211223 def test_element_spec (self ):
212224 ds = dataset .MapDataset .range (3 ).to_iter_dataset ()
213225 ds = interleave .InterleaveIterDataset ([ds , ds ], cycle_length = 2 )
226+ ds = self ._maybe_wrap_ds (ds )
214227 spec = dataset .get_element_spec (ds )
215228 self .assertEqual (spec .dtype , np .int64 )
216229 self .assertEqual (spec .shape , ())
@@ -255,6 +268,7 @@ def test_interleave_stats_with_mismatched_dataset_structures(self):
255268 def test_get_next_index (self ):
256269 ds = dataset .MapDataset .range (10 ).to_iter_dataset ()
257270 ds = interleave .InterleaveIterDataset ([ds ], cycle_length = 1 )
271+ ds = self ._maybe_wrap_ds (ds )
258272 ds_iter = ds .__iter__ ()
259273 self .assertEqual (dataset .get_next_index (ds_iter ), 0 )
260274 for i in range (10 ):
@@ -264,6 +278,7 @@ def test_get_next_index(self):
264278 def test_set_next_index (self ):
265279 ds = dataset .MapDataset .range (10 ).to_iter_dataset ()
266280 ds = interleave .InterleaveIterDataset ([ds ], cycle_length = 1 )
281+ ds = self ._maybe_wrap_ds (ds )
267282 ds_iter = ds .__iter__ ()
268283 for i in reversed (range (10 )):
269284 dataset .set_next_index (ds_iter , i )
@@ -272,6 +287,7 @@ def test_set_next_index(self):
272287 def test_get_next_index_with_multiple_datasets (self ):
273288 ds = dataset .MapDataset .range (10 ).to_iter_dataset ()
274289 ds = interleave .InterleaveIterDataset ([ds , ds ], cycle_length = 2 )
290+ ds = self ._maybe_wrap_ds (ds )
275291 ds_iter = ds .__iter__ ()
276292 with self .assertRaisesRegex (
277293 NotImplementedError ,
@@ -283,6 +299,7 @@ def test_get_next_index_with_multiple_datasets(self):
283299 def test_set_next_index_with_multiple_datasets (self ):
284300 ds = dataset .MapDataset .range (10 ).to_iter_dataset ()
285301 ds = interleave .InterleaveIterDataset ([ds , ds ], cycle_length = 2 )
302+ ds = self ._maybe_wrap_ds (ds )
286303 ds_iter = ds .__iter__ ()
287304 with self .assertRaisesRegex (
288305 NotImplementedError ,
@@ -297,6 +314,7 @@ def test_future_states(self):
297314 dataset .MapDataset .source ([3 , 4 ]).to_iter_dataset (),
298315 ]
299316 ds = interleave .InterleaveIterDataset (datasets , cycle_length = 1 )
317+ ds = self ._maybe_wrap_ds (ds )
300318 ds_iter = ds .__iter__ ()
301319
302320 # Initialize the first iterator and get state.
@@ -323,5 +341,9 @@ def test_future_states(self):
323341 next (ds_iter )
324342
325343
344+ class InterleaveIterDatasetTest (_InterleaveIterDatasetTestBase ):
345+ """Runs tests without prefetch."""
346+
347+
326348if __name__ == "__main__" :
327349 absltest .main ()
0 commit comments