[Auto-Parallel] tmp fix shard_dataloader #4974
base: devel
Changes from 1 commit
```diff
@@ -147,6 +147,7 @@ def __init__(
             "change_bias_after_training", False
         )
         self.lcurve_should_print_header = True
+        self.all_dlen = 0

         def get_opt_param(params):
             opt_type = params.get("opt_type", "Adam")
```
```diff
@@ -166,19 +167,33 @@ def get_dataloader_and_buffer(_data, _params):
                 log.warning(
                     "Sampler not specified!"
                 )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
-            _dataloader = DataLoader(
-                _data,
-                batch_sampler=paddle.io.BatchSampler(
-                    sampler=_sampler,
-                    drop_last=False,
-                ),
-                num_workers=NUM_WORKERS
-                if dist.is_available()
-                else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
-                collate_fn=lambda batch: batch[0],  # prevent extra conversion
-            )
-            _data_buffered = BufferedIterator(iter(_dataloader))
-            return _dataloader, _data_buffered
+            # _dataloader = DataLoader(
+            #     _data,
+            #     batch_sampler=paddle.io.BatchSampler(
+            #         sampler=_sampler,
+            #         drop_last=False,
+            #     ),
+            #     num_workers=NUM_WORKERS
+            #     if dist.is_available()
+            #     else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
+            #     collate_fn=lambda batch: batch[0],  # prevent extra conversion
+            # )
+            # _data_buffered = BufferedIterator(iter(_dataloader))
+            # return _dataloader, _data_buffered
+
+            from itertools import chain, cycle
+            all_dataloaders = []
+            self.all_dlen = 0
+            for dataloader in _data.dataloaders:
+                shard_dataloader = paddle.distributed.shard_dataloader(
+                    dataloader, dist.get_mesh(), shard_dims="dp"
+                )
+                dlen = len(shard_dataloader)
+                self.all_dlen += dlen
+                all_dataloaders.append(shard_dataloader)
+            _shard_dataloader = cycle(chain(*all_dataloaders))
+            _data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)
+            return _shard_dataloader, _data_buffered

         training_dataloader, training_data_buffered = get_dataloader_and_buffer(
             _training_data, _training_params["training_data"]
```
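The replacement path wraps each per-system dataloader with `paddle.distributed.shard_dataloader` and splices the results into one endless stream. A minimal standalone sketch of that stream construction, with plain Python lists standing in for the shard dataloaders (the data below is invented for illustration):

```python
from itertools import chain, cycle

# Two toy "dataloaders"; in the PR these are Paddle shard dataloaders.
loaders = [[("sys0", 0), ("sys0", 1)], [("sys1", 0)]]

# Accumulate the total batch count up front, as the PR does with
# self.all_dlen, because the cycled stream itself has no len().
all_dlen = sum(len(dl) for dl in loaders)  # 3

stream = cycle(chain(*loaders))

# One "epoch" is all_dlen draws; cycle() then replays the items it has
# already consumed, so iteration never raises StopIteration.
for _ in range(2 * all_dlen):
    print(next(stream))
```

This is why the total length must be threaded through to `BufferedIterator` explicitly: `len()` cannot be taken on a `cycle` object.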
```diff
@@ -1087,7 +1102,7 @@ def get_data(self, is_train=True, task_key="Default"):
         except StopIteration:
             # Refresh the status of the dataloader to start from a new epoch
             self.training_data = BufferedIterator(
-                iter(self.training_dataloader)
+                iter(self.training_dataloader), self.all_dlen
             )
             batch_data = next(iter(self.training_data))
         else:
```
```diff
@@ -1153,6 +1168,7 @@ def get_data(self, is_train=True, task_key="Default"):
         if "fid" in batch_data:
             log_dict["fid"] = batch_data["fid"]
         log_dict["sid"] = batch_data["sid"]
+        log_dict["sid"] = batch_data["sid"]
```
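For context, the refresh branch above implements the usual rebuild-on-exhaustion pattern. A generic sketch, with placeholder names rather than the trainer's actual attributes:

```python
# Rebuild-on-exhaustion, as in get_data above; "dataloader" and "iterator"
# are placeholders, not the trainer's real attributes.
def next_batch(state):
    try:
        return next(state["iterator"])
    except StopIteration:
        # Start a new epoch by recreating the iterator from the dataloader.
        state["iterator"] = iter(state["dataloader"])
        return next(state["iterator"])

state = {"dataloader": [1, 2, 3], "iterator": iter([])}  # exhausted on purpose
print(next_batch(state))  # 1
```

Note that the cycle-based stream introduced above never raises StopIteration itself, so after this change the refresh branch is presumably reached only as a safety net.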
deepmd/pd/utils/dataloader.py
```diff
@@ -168,26 +168,26 @@ def construct_dataset(system):
             self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
         assert len(self.systems) == len(self.batch_sizes)
         for system, batch_size in zip(self.systems, self.batch_sizes):
-            if dist.is_available() and dist.is_initialized():
-                system_batch_sampler = DistributedBatchSampler(
-                    system,
-                    shuffle=(
-                        (not (dist.is_available() and dist.is_initialized()))
-                        and shuffle
-                    ),
-                    batch_size=int(batch_size),
-                )
-                self.sampler_list.append(system_batch_sampler)
-            else:
-                system_batch_sampler = BatchSampler(
-                    system,
-                    shuffle=(
-                        (not (dist.is_available() and dist.is_initialized()))
-                        and shuffle
-                    ),
-                    batch_size=int(batch_size),
-                )
-                self.sampler_list.append(system_batch_sampler)
+            # if dist.is_available() and dist.is_initialized():
+            #     system_batch_sampler = DistributedBatchSampler(
+            #         system,
+            #         shuffle=(
+            #             (not (dist.is_available() and dist.is_initialized()))
+            #             and shuffle
+            #         ),
+            #         batch_size=int(batch_size),
+            #     )
+            #     self.sampler_list.append(system_batch_sampler)
+            # else:
+            system_batch_sampler = BatchSampler(
+                system,
+                shuffle=(
+                    (not (dist.is_available() and dist.is_initialized()))
+                    and shuffle
+                ),
+                batch_size=int(batch_size),
+            )
+            self.sampler_list.append(system_batch_sampler)
             system_dataloader = DataLoader(
                 dataset=system,
                 num_workers=0,  # Should be 0 to avoid too many threads forked
```

Comment on lines +171 to +176. Check notice (Code scanning / CodeQL): Commented-out code. This comment appears to contain commented-out code.
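With the DistributedBatchSampler branch commented out, every rank now builds the same plain BatchSampler, and the shuffle expression evaluates to False whenever distributed mode is initialized, so rank-level splitting is deferred to shard_dataloader. A toy sketch of the retained path (ToySystem and the flag variable are invented for illustration):

```python
import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset


class ToySystem(Dataset):
    """Stand-in for a per-system dataset; not part of the PR."""

    def __init__(self, n: int) -> None:
        self.frames = np.arange(n, dtype="float32")

    def __getitem__(self, idx):
        return self.frames[idx]

    def __len__(self) -> int:
        return len(self.frames)


system = ToySystem(8)
shuffle = True
dist_initialized = True  # stands in for dist.is_available() and dist.is_initialized()

# Mirrors the kept branch: under distributed runs the computed flag is False,
# so each rank sees the same deterministic order before sharding.
system_batch_sampler = BatchSampler(
    system,
    shuffle=(not dist_initialized) and shuffle,
    batch_size=2,
)
system_dataloader = DataLoader(
    dataset=system,
    batch_sampler=system_batch_sampler,
    num_workers=0,  # Should be 0 to avoid too many threads forked
)
for batch in system_dataloader:
    print(batch)
```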
```diff
@@ -291,14 +291,15 @@ def run(self) -> None:


 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(self, iterable, alldlen) -> None:
         self._queue = queue.Queue(QUEUESIZE)
         self._iterable = iterable
         self._iterator = iter(iterable)
```
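The extra constructor argument appears to exist because the cycle/chain stream now fed into BufferedIterator has no `__len__` of its own to query. A minimal sketch of the pattern; the LengthAwareIterator name is invented here, and the real BufferedIterator additionally buffers items through the `self._queue` shown in the diff:

```python
from itertools import chain, cycle


class LengthAwareIterator:
    """Illustrative stand-in for BufferedIterator's length handling."""

    def __init__(self, iterable, alldlen) -> None:
        self._iterable = iterable
        self._iterator = iter(iterable)
        # Supplied externally: len() on the underlying cycle() would fail.
        self._total = alldlen

    def __len__(self) -> int:
        return self._total

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)


stream = LengthAwareIterator(cycle(chain([1, 2], [3])), alldlen=3)
print(len(stream))   # 3
print(next(stream))  # 1
```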
🛠️ Refactor suggestion

API change needs documentation

The BufferedIterator constructor signature change from `__init__(self, iterable)` to `__init__(self, iterable, alldlen)` is a breaking API change. Consider adding a docstring to document this new parameter and its purpose.

```diff
 class BufferedIterator:
     def __init__(self, iterable, alldlen) -> None:
+        """Initialize BufferedIterator with an iterable and total length.
+
+        Parameters
+        ----------
+        iterable
+            The iterable data source to buffer.
+        alldlen : int
+            The total length across all data sources for proper iteration tracking.
+        """
         self._queue = queue.Queue(QUEUESIZE)
         self._iterable = iterable
         self._iterator = iter(iterable)
```
🤖 Prompt for AI Agents
In deepmd/pd/utils/dataloader.py around lines 293 to 296, the BufferedIterator
constructor signature was changed from __init__(self, iterable) to
__init__(self, iterable, alldlen) but there is no documentation for the new
parameter; add a concise docstring to the class or its __init__ method that
explains the new alldlen parameter (type, meaning, whether it’s optional,
expected units/semantics and how it affects behavior), update any public API
docs or README references to mention the new required argument, and include a
short example or note about backward compatibility or migration if callers
previously relied on the single-argument form.
[nitpick] Missing space after comma in function call parameters. Should be `BufferedIterator(iter(_shard_dataloader), self.all_dlen)` for consistency with Python style guidelines.