Skip to content

Commit 24dbd94

Browse files
sagunb authored and copybara-github committed
Internal
PiperOrigin-RevId: 893159239
1 parent 23d420f commit 24dbd94

File tree

6 files changed

+58
-4
lines changed

6 files changed

+58
-4
lines changed

grain/_src/core/BUILD

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,29 @@ py_library(
2727
srcs_version = "PY3",
2828
)
2929

30+
py_library(
31+
name = "executor",
32+
srcs = ["executor.py"],
33+
srcs_version = "PY3",
34+
)
35+
36+
py_library(
37+
name = "daemon_thread_pool_executor",
38+
srcs = ["daemon_thread_pool_executor.py"],
39+
srcs_version = "PY3",
40+
deps = [":executor"],
41+
)
42+
43+
py_test(
44+
name = "daemon_thread_pool_executor_test",
45+
srcs = ["daemon_thread_pool_executor_test.py"],
46+
srcs_version = "PY3",
47+
deps = [
48+
":daemon_thread_pool_executor",
49+
"@abseil-py//absl/testing:absltest",
50+
],
51+
)
52+
3053
py_library(
3154
name = "monitoring",
3255
srcs = ["monitoring.py"],

grain/_src/python/dataset/transformations/BUILD

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,6 @@ py_test(
6464
shard_count = 50,
6565
srcs_version = "PY3",
6666
deps = [
67-
"//grain/_src/core:config",
6867
"//grain/_src/core:transforms",
6968
"//grain/_src/python:options",
7069
"//grain/_src/python/dataset",

grain/_src/python/dataset/transformations/interleave.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -369,9 +369,11 @@ def _add_prefetch_and_make_iterator(
369369
ds, prefetch_buffer_size=interleave_iterator_obj._iter_buffer_size # pylint: disable=protected-access
370370
)
371371
iterator = iter_dataset.__iter__()
372+
372373
# Propagate options applied after InterleaveIterDataset to the iterators that
373374
# are being interleaved.
374375
iterator._ctx.dataset_options = interleave_iterator_obj._ctx.dataset_options.merge(iterator._ctx.dataset_options) # pylint: disable=protected-access
376+
375377
if start_prefetch:
376378
iterator.start_prefetch()
377379
return iterator

grain/_src/python/dataset/transformations/interleave_test.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,11 @@ def __iter__(self) -> dataset.DatasetIterator:
9090
return _IteratorIdDatasetIterator(self._parent.__iter__())
9191

9292

93-
class InterleaveIterDatasetTest(parameterized.TestCase):
93+
@absltest.skipThisClass("Base class")
94+
class _InterleaveIterDatasetTestBase(parameterized.TestCase):
95+
96+
def _maybe_wrap_ds(self, ds):
97+
return ds
9498

9599
@parameterized.named_parameters(*_INTERLEAVE_TEST_CASES)
96100
def test_interleaved_mix(self, to_mix, cycle_length, expected):
@@ -99,6 +103,7 @@ def test_interleaved_mix(self, to_mix, cycle_length, expected):
99103
for elements in to_mix
100104
]
101105
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
106+
ds = self._maybe_wrap_ds(ds)
102107
self.assertEqual(list(ds), expected)
103108
# Sanity check.
104109
flat_inputs = []
@@ -113,6 +118,7 @@ def test_checkpoint(self, to_mix, cycle_length, expected):
113118
for elements in to_mix
114119
]
115120
ds = interleave.InterleaveIterDataset(datasets, cycle_length=cycle_length)
121+
ds = self._maybe_wrap_ds(ds)
116122
ds_iter = ds.__iter__()
117123
checkpoints = {}
118124
for i in range(len(expected)):
@@ -138,6 +144,7 @@ def test_checkpoint_with_extra_threads_creating_iterators(
138144
num_make_iter_threads=10,
139145
make_iter_buffer_size=10,
140146
)
147+
ds = self._maybe_wrap_ds(ds)
141148
ds_iter = ds.__iter__()
142149
checkpoints = {}
143150
for i in range(len(expected)):
@@ -158,6 +165,7 @@ def make_dummy_source(filename):
158165
filenames = dataset.MapDataset.source(["11", "2345", "678", "9999"])
159166
sources = filenames.shuffle(seed=42).map(make_dummy_source)
160167
ds = interleave.InterleaveIterDataset(sources, cycle_length=2)
168+
ds = self._maybe_wrap_ds(ds)
161169
self.assertEqual(
162170
list(ds),
163171
["1", "2", "1", "3", "4", "6", "5", "7", "8", "9", "9", "9", "9"],
@@ -168,6 +176,7 @@ def test_with_mp_prefetch(self):
168176
lambda i: dataset.MapDataset.source([i]).repeat(i).to_iter_dataset()
169177
)
170178
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
179+
ds = self._maybe_wrap_ds(ds)
171180
ds = ds.mp_prefetch(options.MultiprocessingOptions(num_workers=3))
172181
self.assertEqual(list(ds), [1, 2, 3, 4, 5, 3, 4, 2, 3, 4, 5, 4, 5, 5, 5])
173182

@@ -176,6 +185,7 @@ def test_options_propagated(self):
176185
ds1 = ds1.filter(lambda x: False)
177186
ds2 = dataset.MapDataset.source([2]).repeat(1000).to_iter_dataset()
178187
ds = interleave.InterleaveIterDataset([ds1, ds2], cycle_length=1)
188+
ds = self._maybe_wrap_ds(ds)
179189
ds_options = base.DatasetOptions(filter_raise_threshold_ratio=0.1)
180190
ds = dataset.WithOptionsIterDataset(ds, ds_options)
181191
with self.assertRaisesRegex(ValueError, r"skipped 100\.00 %"):
@@ -187,6 +197,7 @@ def test_checkpointing_comprehensive(self):
187197
for i in range(1, 6)
188198
]
189199
ds = interleave.InterleaveIterDataset(ds, cycle_length=5)
200+
ds = self._maybe_wrap_ds(ds)
190201
assert_equal_output_after_checkpoint(ds)
191202

192203
def test_set_state_does_not_recreate_iterators_if_not_needed(self):
@@ -196,6 +207,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
196207
ds = interleave.InterleaveIterDataset(
197208
[ds] * cycle_length, cycle_length=cycle_length
198209
)
210+
ds = self._maybe_wrap_ds(ds)
199211
ds_iter = ds.__iter__()
200212
iter_ids1 = []
201213
for _ in range(cycle_length):
@@ -211,6 +223,7 @@ def test_set_state_does_not_recreate_iterators_if_not_needed(self):
211223
def test_element_spec(self):
212224
ds = dataset.MapDataset.range(3).to_iter_dataset()
213225
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
226+
ds = self._maybe_wrap_ds(ds)
214227
spec = dataset.get_element_spec(ds)
215228
self.assertEqual(spec.dtype, np.int64)
216229
self.assertEqual(spec.shape, ())
@@ -255,6 +268,7 @@ def test_interleave_stats_with_mismatched_dataset_structures(self):
255268
def test_get_next_index(self):
256269
ds = dataset.MapDataset.range(10).to_iter_dataset()
257270
ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
271+
ds = self._maybe_wrap_ds(ds)
258272
ds_iter = ds.__iter__()
259273
self.assertEqual(dataset.get_next_index(ds_iter), 0)
260274
for i in range(10):
@@ -264,6 +278,7 @@ def test_get_next_index(self):
264278
def test_set_next_index(self):
265279
ds = dataset.MapDataset.range(10).to_iter_dataset()
266280
ds = interleave.InterleaveIterDataset([ds], cycle_length=1)
281+
ds = self._maybe_wrap_ds(ds)
267282
ds_iter = ds.__iter__()
268283
for i in reversed(range(10)):
269284
dataset.set_next_index(ds_iter, i)
@@ -272,6 +287,7 @@ def test_set_next_index(self):
272287
def test_get_next_index_with_multiple_datasets(self):
273288
ds = dataset.MapDataset.range(10).to_iter_dataset()
274289
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
290+
ds = self._maybe_wrap_ds(ds)
275291
ds_iter = ds.__iter__()
276292
with self.assertRaisesRegex(
277293
NotImplementedError,
@@ -283,6 +299,7 @@ def test_get_next_index_with_multiple_datasets(self):
283299
def test_set_next_index_with_multiple_datasets(self):
284300
ds = dataset.MapDataset.range(10).to_iter_dataset()
285301
ds = interleave.InterleaveIterDataset([ds, ds], cycle_length=2)
302+
ds = self._maybe_wrap_ds(ds)
286303
ds_iter = ds.__iter__()
287304
with self.assertRaisesRegex(
288305
NotImplementedError,
@@ -297,6 +314,7 @@ def test_future_states(self):
297314
dataset.MapDataset.source([3, 4]).to_iter_dataset(),
298315
]
299316
ds = interleave.InterleaveIterDataset(datasets, cycle_length=1)
317+
ds = self._maybe_wrap_ds(ds)
300318
ds_iter = ds.__iter__()
301319

302320
# Initialize the first iterator and get state.
@@ -323,5 +341,9 @@ def test_future_states(self):
323341
next(ds_iter)
324342

325343

344+
class InterleaveIterDatasetTest(_InterleaveIterDatasetTestBase):
345+
"""Runs tests without prefetch."""
346+
347+
326348
if __name__ == "__main__":
327349
absltest.main()

grain/_src/python/dataset/transformations/prefetch.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -534,6 +534,7 @@ def start_prefetch(self):
534534
return
535535

536536
self._prefetch_should_stop.clear()
537+
537538
self._prefetch_thread = threading.Thread(
538539
target=functools.partial(
539540
_put_iterator_elements_in_buffer,
@@ -595,8 +596,10 @@ def _stop_prefetch(self):
595596
# is shutting down. Attempting to join can lead to hanging in Python
596597
# 3.13 as daemon threads can hang during interpreter shutdown. See
597598
# https://github.com/python/cpython/issues/123940#issuecomment-2976446309
598-
self._prefetch_thread.join()
599+
if self._prefetch_thread is not None:
600+
self._prefetch_thread.join()
599601
self._prefetch_thread = None
602+
600603
# Clear the buffer again in case the prefetch loop added more elements on
601604
# exit.
602605
self._clear_buffer()

grain/_src/python/dataset/transformations/prefetch_test.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -448,7 +448,8 @@ def test_set_next_index(self):
448448
self.assertEqual(next(ds_iter), i)
449449

450450

451-
class ThreadPrefetchIterDatasetTest(parameterized.TestCase):
451+
@absltest.skipThisClass('Base class')
452+
class _ThreadPrefetchIterDatasetTestBase(parameterized.TestCase):
452453

453454
def setUp(self):
454455
super().setUp()
@@ -768,6 +769,10 @@ def new_get_state(self):
768769
self.assertEqual(get_state_counter.call_count - get_state_count, 1)
769770

770771

772+
class ThreadPrefetchIterDatasetTest(_ThreadPrefetchIterDatasetTestBase):
773+
"""Runs tests without provided executor."""
774+
775+
771776
class _MpContextCheckIterDataset(dataset.IterDataset[_T]):
772777

773778
def __iter__(self) -> dataset.DatasetIterator[_T]:

0 commit comments

Comments
 (0)