[Auto-Parallel] tmp fix shard_dataloader #4974
base: devel
Conversation
Pull Request Overview
This PR implements a temporary fix for the shard_dataloader functionality in the auto-parallel training system. The changes address distributed data loading by temporarily disabling the original distributed batch sampling logic and implementing a new approach using paddle's shard_dataloader.
- Commented out existing distributed batch sampling logic in favor of a new implementation
- Updated BufferedIterator to accept an explicit data length parameter instead of relying on len(iterable)
- Implemented new data loading approach using paddle.distributed.shard_dataloader with cycling through multiple dataloaders
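The cycling approach described above can be sketched in plain Python. This is a hypothetical illustration (names like `combine_shards` are not part of the deepmd codebase): several per-shard dataloaders are chained into one stream, wrapped in `cycle` so iteration never ends, while the true total length is tracked separately, mirroring the `self.all_dlen` accumulation in the PR.

```python
from itertools import chain, cycle

def combine_shards(dataloaders):
    """Return an endless iterator over all shards plus the total length."""
    total_len = sum(len(dl) for dl in dataloaders)  # analogous to self.all_dlen
    endless = cycle(chain(*dataloaders))            # replay all shards forever
    return endless, total_len

# Stand-in "dataloaders": plain lists of batches.
shard_a = [{"batch": 0}, {"batch": 1}]
shard_b = [{"batch": 2}]
it, total = combine_shards([shard_a, shard_b])
print(total)  # 3
```

Because `cycle` caches the items it sees, this also works when the shards are one-shot iterators rather than lists.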
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| deepmd/pd/utils/dataloader.py | Comments out distributed batch sampling logic and updates BufferedIterator constructor to accept explicit length parameter |
| deepmd/pd/train/training.py | Replaces original dataloader creation with shard_dataloader implementation and adds tracking for total data length |
```python
log_dict = {}
if "fid" in batch_data:
    log_dict["fid"] = batch_data["fid"]
log_dict["sid"] = batch_data["sid"]
```
Copilot AI commented on Sep 12, 2025
The log_dict['sid'] is being overwritten immediately after assignment. Line 1171 unconditionally sets it to 0, making the previous assignment from batch_data['sid'] useless.
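A minimal sketch of the pattern the reviewer is pointing at, using a hypothetical helper (`build_log_dict` is illustrative, not deepmd code): rather than assigning a value that a later line unconditionally clobbers, apply the default only when the key is absent, so the real system id survives.

```python
def build_log_dict(batch_data):
    log_dict = {}
    if "fid" in batch_data:
        log_dict["fid"] = batch_data["fid"]
    # keep the real system id when available instead of overwriting it with 0
    log_dict["sid"] = batch_data.get("sid", 0)
    return log_dict

print(build_log_dict({"sid": 7}))  # {'sid': 7}
```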
deepmd/pd/train/training.py (outdated)
```python
self.all_dlen += dlen
all_dataloaders.append(shard_dataloader)
_shard_dataloader = cycle(chain(*all_dataloaders))
_data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)
```
Copilot AI commented on Sep 12, 2025
[nitpick] Missing space after comma in function call parameters. Should be BufferedIterator(iter(_shard_dataloader), self.all_dlen) for consistency with Python style guidelines.
```diff
-_data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)
+_data_buffered = BufferedIterator(iter(_shard_dataloader), self.all_dlen)
```
```python
def __init__(self, iterable, alldlen) -> None:
    self._queue = queue.Queue(QUEUESIZE)
    self._iterable = iterable
    self._iterator = iter(iterable)
```
Copilot AI commented on Sep 12, 2025
[nitpick] The new _iterator field is added but appears unused in the class. If it's intended for future use, consider adding a comment explaining its purpose, or remove it if it's not needed.
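For context, a queue-backed buffered iterator like the one being modified might look as follows. This is a minimal, hypothetical sketch, not the deepmd implementation: a background thread prefetches items into a bounded queue, and `len()` reports the externally supplied `alldlen` instead of `len(iterable)`, which is what makes the class usable with endless `cycle`-based iterables.

```python
import queue
import threading

QUEUESIZE = 4
_SENTINEL = object()

class BufferedIterator:
    def __init__(self, iterable, alldlen) -> None:
        self._queue = queue.Queue(QUEUESIZE)
        self._len = alldlen  # explicit total length across all shards
        self._thread = threading.Thread(
            target=self._fill, args=(iter(iterable),), daemon=True
        )
        self._thread.start()

    def _fill(self, iterator):
        # Producer: push every item, then a sentinel marking exhaustion.
        for item in iterator:
            self._queue.put(item)
        self._queue.put(_SENTINEL)

    def __len__(self):
        return self._len

    def __iter__(self):
        return self

    def __next__(self):
        item = self._queue.get()
        if item is _SENTINEL:
            raise StopIteration
        return item

buf = BufferedIterator(range(3), alldlen=3)
print(list(buf))  # [0, 1, 2]
print(len(buf))   # 3
```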
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Replaces single DataLoader usage with a distributed shard-based, cyclic iterator across multiple shard dataloaders; tracks total length via a new all_dlen attribute and uses a length-aware BufferedIterator. Modifies BufferedIterator's constructor to accept total length. Removes the distributed sampler path in utils, always using BatchSampler. Adjusts logging of sid to 0.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Trainer
    participant ShardDL as Shard Dataloaders
    participant ChainCycle as chain()+cycle()
    participant BufIter as BufferedIterator
    Trainer->>ShardDL: create per-dataloader shards (shard_dataloader)
    Note right of ShardDL: Compute total_len = sum(len(shard) for all shards)
    ShardDL-->>Trainer: shard iterators + total_len
    Trainer->>ChainCycle: chain(shards) then cycle(...)
    ChainCycle-->>Trainer: infinite iterator over shards
    Trainer->>BufIter: BufferedIterator(iterable=ChainCycle, total_len)
    BufIter-->>Trainer: length-aware data iterator
    loop training step
        Trainer->>BufIter: next()
        alt data available
            BufIter-->>Trainer: batch
        else StopIteration
            Note over Trainer: Rebuild BufferedIterator(iter(...), total_len)
        end
    end
```
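The training-step loop in the diagram can be sketched in plain Python. The names here (`LengthAwareIterator`, `next_batch`) are illustrative stand-ins, not the deepmd API: the point is the fetch-or-rebuild pattern, where `StopIteration` triggers reconstruction of the length-aware iterator over the same data.

```python
class LengthAwareIterator:
    """Stand-in for a BufferedIterator carrying an explicit total length."""

    def __init__(self, iterable, total_len):
        self._iterator = iter(iterable)
        self._len = total_len

    def __len__(self):
        return self._len

    def __next__(self):
        return next(self._iterator)

def next_batch(state, make_iterable, total_len):
    """Fetch one batch, rebuilding the iterator on StopIteration."""
    try:
        return next(state["it"]), state
    except StopIteration:
        state["it"] = LengthAwareIterator(make_iterable(), total_len)
        return next(state["it"]), state

data = [10, 20]
state = {"it": LengthAwareIterator(data, len(data))}
batches = []
for _ in range(5):
    b, state = next_batch(state, lambda: data, len(data))
    batches.append(b)
print(batches)  # [10, 20, 10, 20, 10]
```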
Estimated code review effort: 🎯 4 (Complex) · ⏱️ ~60 minutes
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pd/train/training.py (2)
1116-1118: Validation data BufferedIterator missing length parameter

The BufferedIterator for validation data is still using the old single-parameter constructor, which will cause a TypeError at runtime.
Update the validation data BufferedIterator to use the new two-parameter constructor:
```diff
-self.validation_data = BufferedIterator(
-    iter(self.validation_dataloader)
-)
+self.validation_data = BufferedIterator(
+    iter(self.validation_dataloader), len(self.validation_dataloader)
+)
```
1126-1128: Multi-task BufferedIterator calls missing length parameter

The BufferedIterator instantiations for multi-task scenarios are also using the old single-parameter constructor and will fail at runtime.
Update all multi-task BufferedIterator calls:
```diff
 # For training data (lines 1126-1128):
-self.training_data[task_key] = BufferedIterator(
-    iter(self.training_dataloader[task_key])
-)
+self.training_data[task_key] = BufferedIterator(
+    iter(self.training_dataloader[task_key]), len(self.training_dataloader[task_key])
+)
 # For validation data (lines 1136-1138):
-self.validation_data[task_key] = BufferedIterator(
-    iter(self.validation_dataloader[task_key])
-)
+self.validation_data[task_key] = BufferedIterator(
+    iter(self.validation_dataloader[task_key]), len(self.validation_dataloader[task_key])
+)
```

Also applies to: 1136-1138
♻️ Duplicate comments (3)
deepmd/pd/utils/dataloader.py (1)
296-296: Unused `_iterator` field

deepmd/pd/train/training.py (2)
1172-1173: `sid` value is overwritten, losing actual system ID

The actual system ID from `batch_data["sid"]` is immediately overwritten with 0, which removes useful debugging information and makes the logging less informative.
186-186: `all_dlen` is being re-initialized unnecessarily

The `self.all_dlen = 0` on line 186 resets the value that was already initialized on line 146. This could be confusing, and the initialization on line 146 seems unnecessary. Remove the initialization on line 146 since the value is properly set within `get_dataloader_and_buffer`:

```diff
-        self.all_dlen = 0
```
🧹 Nitpick comments (2)
deepmd/pd/train/training.py (2)
146-146: New attribute should be initialized consistently

The `all_dlen` attribute is being used to track total shard length but is initialized to 0 outside of the dataloader setup. Consider initializing it closer to where it's actually computed, or document why it needs early initialization.
195-196: Consider adding spacing for readability

```diff
-_data_buffered = BufferedIterator(
-    iter(_shard_dataloader), self.all_dlen
-)
+_data_buffered = BufferedIterator(
+    iter(_shard_dataloader), self.all_dlen
+)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pd/train/training.py (4 hunks)
- deepmd/pd/utils/dataloader.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (.github/copilot-instructions.md)
**/*.py: Run ruff check and ruff format on all Python code before committing
Ensure Python code is formatted with ruff’s formatter
Files: deepmd/pd/train/training.py, deepmd/pd/utils/dataloader.py
🧬 Code graph analysis (1)
deepmd/pd/train/training.py (1)
deepmd/pd/utils/dataloader.py (1)
BufferedIterator (292-337)
🔇 Additional comments (3)
deepmd/pd/utils/dataloader.py (1)
171-189: Simplified sampler construction looks good

The removal of the DistributedBatchSampler path simplifies the code while maintaining the shuffle logic based on distributed initialization status.
deepmd/pd/train/training.py (2)
180-198: Good implementation of distributed shard-based loading

The new approach using `paddle.distributed.shard_dataloader` with cyclic iteration through chained shards provides better distributed data handling. The accumulation of total length in `self.all_dlen` properly supports the BufferedIterator's length tracking.
1107-1107: BufferedIterator constructor call is correct

The updated call to BufferedIterator with two arguments properly uses the new API with `self.all_dlen` for length tracking.
```python
def __init__(self, iterable, alldlen) -> None:
    self._queue = queue.Queue(QUEUESIZE)
    self._iterable = iterable
    self._iterator = iter(iterable)
```
🛠️ Refactor suggestion
API change needs documentation
The BufferedIterator constructor signature change from __init__(self, iterable) to __init__(self, iterable, alldlen) is a breaking API change. Consider adding a docstring to document this new parameter and its purpose.
```diff
 class BufferedIterator:
     def __init__(self, iterable, alldlen) -> None:
+        """Initialize BufferedIterator with an iterable and total length.
+
+        Parameters
+        ----------
+        iterable
+            The iterable data source to buffer.
+        alldlen : int
+            The total length across all data sources for proper iteration tracking.
+        """
         self._queue = queue.Queue(QUEUESIZE)
         self._iterable = iterable
         self._iterator = iter(iterable)
```
```python
# if dist.is_available() and dist.is_initialized():
#     system_batch_sampler = DistributedBatchSampler(
#         system,
#         shuffle=(
#             (not (dist.is_available() and dist.is_initialized()))
#             and shuffle
```
Check notice (Code scanning / CodeQL): Commented-out code