Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 92 additions & 74 deletions src/client/pydaos/torch/torch_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import stat
import sys
from multiprocessing import Process, Queue
from multiprocessing import Pool, Process, Queue, current_process
from pathlib import Path

from torch.utils.data import Dataset as TorchDataset
Expand Down Expand Up @@ -69,7 +69,8 @@ class Dataset(TorchDataset):
Number of directory entries to read for each readdir call.
dir_cache_size: int (optional)
Number of directory object entries to cache in memory.

readdir_workers: int (optional)
Number of parallel workers for namespace scanning.

Methods
-------
Expand All @@ -92,7 +93,8 @@ class Dataset(TorchDataset):
def __init__(self, pool=None, cont=None, path=None,
transform_fn=transform_fn_default,
readdir_batch_size=READDIR_BATCH_SIZE,
dir_cache_size=DIR_CACHE_SIZE):
dir_cache_size=DIR_CACHE_SIZE,
readdir_workers=PARALLEL_SCAN_WORKERS):
super().__init__()

self._pool = pool
Expand All @@ -102,7 +104,8 @@ def __init__(self, pool=None, cont=None, path=None,
self._readdir_batch_size = readdir_batch_size
self._closed = False

self.objects = self._dfs.parallel_list(path, readdir_batch_size=self._readdir_batch_size)
self.objects = self._dfs.parallel_list(
path, readdir_batch_size=self._readdir_batch_size, workers=readdir_workers)

def __len__(self):
""" Returns number of items in this dataset """
Expand Down Expand Up @@ -216,6 +219,8 @@ class IterableDataset(TorchIterableDataset):
Number of samples to fetch per iteration.
dir_cache_size: int (optional)
Number of directory object entries to cache in memory.
readdir_workers: int (optional)
Number of parallel workers for namespace scanning.


Methods
Expand All @@ -233,7 +238,8 @@ def __init__(self, pool=None, cont=None, path=None,
transform_fn=transform_fn_default,
readdir_batch_size=READDIR_BATCH_SIZE,
batch_size=ITER_BATCH_SIZE,
dir_cache_size=DIR_CACHE_SIZE):
dir_cache_size=DIR_CACHE_SIZE,
readdir_workers=PARALLEL_SCAN_WORKERS):
super().__init__()

self._pool = pool
Expand All @@ -244,7 +250,8 @@ def __init__(self, pool=None, cont=None, path=None,
self._batch_size = batch_size
self._closed = False

self.objects = self._dfs.parallel_list(path, readdir_batch_size=self._readdir_batch_size)
self.objects = self._dfs.parallel_list(
path, readdir_batch_size=self._readdir_batch_size, workers=readdir_workers)
self.workset = self.objects

def __iter__(self):
Expand Down Expand Up @@ -646,6 +653,35 @@ def writer(self, file, ensure_path=True):
self._chunks_limit, self._workers)


def _readdir_worker_init(dfs, readdir_batch_size):
    """
    Initializer run once in every worker process of the parallel readdir pool.

    It calls `dfs.worker_init()` (needed to re-init DAOS after fork, per worker process)
    and then parks both `dfs` and `readdir_batch_size` as attributes on the current
    process object, which is where `_readdir_batch` looks them up for every work item.

    This is deliberately a module-level function: the multiprocessing.Pool machinery
    pickles the initializer, and an instance method would be pickled together with
    the main process's _Dfs class reference.
    """
    dfs.worker_init()

    # Per-worker state lives on the process object itself, no module globals involved.
    worker = current_process()
    setattr(worker, "dfs", dfs)
    setattr(worker, "readdir_batch_size", readdir_batch_size)


def _readdir_batch(work):
    """
    Pool job: scan a single (path, anchor_index) work item.

    `work` is unpacked into `path` and `anchor_index`; the scan itself is delegated to
    the `readdir_anchored` method of the `dfs` object stored on the current process,
    using the `readdir_batch_size` stored next to it. Whatever that call returns
    (the discovered directories and files) is returned unchanged.

    This is deliberately a module-level function: the multiprocessing.Pool machinery
    pickles submitted jobs, and an instance method would be pickled together with
    the main process's _Dfs class reference.
    """
    path, anchor_index = work

    # State was attached to the process object by the pool initializer.
    state = current_process()
    scan = state.dfs.readdir_anchored
    return scan(path, anchor_index, state.readdir_batch_size)


class _Dfs():
"""
Class encapsulating libdfs interface to load PyTorch Dataset
Expand Down Expand Up @@ -676,49 +712,10 @@ def disconnect(self):
raise OSError(ret, os.strerror(ret))
self._dfs = None

def list_worker_fn(self, in_work, out_dirs, out_files, readdir_batch_size=READDIR_BATCH_SIZE):
    """
    Body of one directory-scanning worker process.

    Work items are (path, index) tuples taken from the `in_work` queue, where `index`
    is the anchor index to scan the directory at `path` with. A `None` item tells
    the worker to stop.

    For every work item a tuple (scanned, to_scan) is put on the `out_dirs` queue:
    `scanned` is the number of scanned directories (always 1 here) and `to_scan` is
    the list of new work items for the sub-directories found in that batch.

    After the stop item is received, the accumulated list of (full_path, size) file
    tuples is put on the `out_files` queue as a single item.

    Raises OSError if the listing call returns a non-zero code.
    """

    self.worker_init()

    collected = []
    # iter() with a sentinel keeps calling in_work.get() until it yields None
    for path, index in iter(in_work.get, None):
        dir_names = []
        file_entries = []
        ret = torch_shim.torch_list_with_anchor(
            DAOS_MAGIC, self._dfs, path, index, file_entries, dir_names, readdir_batch_size)
        if ret != 0:
            raise OSError(ret, os.strerror(ret), path)

        to_scan = []
        for name in dir_names:
            to_scan.extend(self.split_dir_for_parallel_scan(os.path.join(path, name)))

        # Even if there are no dirs, we should emit the tuple to notify the main process
        out_dirs.put((1, to_scan))

        collected += [(os.path.join(path, name), size) for (name, size) in file_entries]

    out_files.put(collected)

def split_dir_for_parallel_scan(self, path):
"""
Splits dir for parallel readdir.
It returns list of tuples (dirname, anchor index) to be consumed by worker function
It returns list of tuples (dirname, anchor_index) to be consumed by workers
"""

ret, splits = torch_shim.torch_recommended_dir_split(DAOS_MAGIC, self._dfs, path)
Expand All @@ -727,6 +724,28 @@ def split_dir_for_parallel_scan(self, path):

return [(path, idx) for idx in range(0, splits)]

def readdir_anchored(self, path, anchor_index, readdir_batch_size):
    """
    Scans the directory at `path`, anchored by `anchor_index`.

    Returns a tuple (subdirs, files):
        `subdirs` is the list of (path, anchor_index) work items produced by
        `split_dir_for_parallel_scan` for every directory found in this batch,
        `files` is the list of (full_path, size) tuples for the files found.

    Raises OSError if the listing call returns a non-zero code.
    """
    # Output lists handed to the shim call (presumably filled in place by it).
    found_dirs = []
    found_files = []
    rc = torch_shim.torch_list_with_anchor(
        DAOS_MAGIC, self._dfs, path, anchor_index,
        found_files, found_dirs, readdir_batch_size)
    if rc != 0:
        raise OSError(rc, os.strerror(rc), path)

    # Every discovered directory becomes one or more work items for later scanning.
    subdirs = []
    for name in found_dirs:
        subdirs.extend(self.split_dir_for_parallel_scan(os.path.join(path, name)))

    return subdirs, [(os.path.join(path, name), size) for (name, size) in found_files]

def parallel_list(self, path=None,
readdir_batch_size=READDIR_BATCH_SIZE,
workers=PARALLEL_SCAN_WORKERS):
Expand All @@ -736,43 +755,42 @@ def parallel_list(self, path=None,

To fully use this feature the container should be configured with directory object classes
supporting this mode, e.g. OC_SX.

Using multiprocessing.Pool ensures propagation of errors in the workers and cleaning up
resources, regardless of operation outcome.

It would be even better to use `concurrent.futures.ProcessPoolExecutor`; however,
its `initializer` and `initargs` arguments are available only in Python 3.7+.

Although Python 3.6 is EOL, many distributions still ship it by default.
Keeping `_readdir_worker_init` and `_readdir_batch` as module-level functions
instead of private class methods, is a small price that allows us to support
a much broader range of platforms.
"""

if path is None:
path = os.sep

if not path.startswith(os.sep):
raise ValueError("relative path is unacceptable")

procs = []
work = Queue()
dirs = Queue()
files = Queue()
for _ in range(workers):
worker = Process(target=self.list_worker_fn, args=(
work, dirs, files, readdir_batch_size))
worker.start()
procs.append(worker)

queued = 0
processed = 0
for anchored_dir in self.split_dir_for_parallel_scan(path):
work.put(anchored_dir)
queued += 1

while processed < queued:
(scanned, to_scan) = dirs.get()
processed += scanned
for d in to_scan:
work.put(d)
queued += 1
if readdir_batch_size <= 0:
raise ValueError("readdir batch size should be a positive number")

result = []
for _ in range(workers):
work.put(None)
result.extend(files.get())
if workers <= 0:
raise ValueError("at least one worker is required for namespace scanning")

for worker in procs:
worker.join()
result = []
batch = self.split_dir_for_parallel_scan(path)
with Pool(workers,
initializer=_readdir_worker_init,
initargs=(self, readdir_batch_size)) as pool:
while batch:
next_batch = []
for dirs, files in pool.imap_unordered(_readdir_batch, batch):
next_batch.extend(dirs)
result.extend(files)
batch = next_batch
Comment thread
daltonbohning marked this conversation as resolved.

return result

Expand Down
Loading