
Conversation


@Xing-lil Xing-lil commented Sep 12, 2025

Summary by CodeRabbit

  • New Features

    • Introduces shard-based, distributed data loading with automatic cycling across multiple datasets/devices.
    • Adds buffered iteration with accurate total length for more consistent epoch handling.
  • Refactor

    • Simplifies sampling by using a unified sampler path across environments.
    • Updates logging to always report sid as 0 for training logs.
    • Changes BufferedIterator initialization to require a total-length argument (breaking change for custom integrations).

Copilot AI review requested due to automatic review settings September 12, 2025 06:34
Contributor

Copilot AI left a comment


Pull Request Overview

This PR implements a temporary fix for the shard_dataloader functionality in the auto-parallel training system. The changes address distributed data loading by temporarily disabling the original distributed batch sampling logic and implementing a new approach using paddle's shard_dataloader.

  • Commented out existing distributed batch sampling logic in favor of a new implementation
  • Updated BufferedIterator to accept an explicit data length parameter instead of relying on len(iterable)
  • Implemented new data loading approach using paddle.distributed.shard_dataloader with cycling through multiple dataloaders
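The chaining-and-cycling pattern described in the last bullet can be sketched as follows. Plain lists stand in for the `paddle.distributed.shard_dataloader` objects so the sketch runs without Paddle; the variable names (`all_dataloaders`, `all_dlen`) mirror the PR, but the snippet is illustrative, not the actual implementation.

```python
from itertools import chain, cycle

# Stand-ins for the per-device shard dataloaders created via
# paddle.distributed.shard_dataloader in the PR; plain lists keep
# the sketch runnable without Paddle.
loader_a = [{"sid": 0, "batch": i} for i in range(2)]
loader_b = [{"sid": 1, "batch": i} for i in range(3)]

all_dataloaders = [loader_a, loader_b]
# Total number of batches across all shards, tracked as self.all_dlen
# in the PR so the buffered iterator can report an accurate length.
all_dlen = sum(len(dl) for dl in all_dataloaders)

# chain() walks each loader once in order; cycle() then repeats the
# whole chained sequence indefinitely, so iteration never ends mid-epoch.
endless = cycle(chain(*all_dataloaders))

first_epoch = [next(endless) for _ in range(all_dlen)]
wrapped = next(endless)  # cycling has restarted from the first batch
```

Note that `cycle()` caches the items it sees, so the pattern works even when the chained inputs are one-shot iterators rather than lists.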

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
deepmd/pd/utils/dataloader.py Comments out distributed batch sampling logic and updates BufferedIterator constructor to accept explicit length parameter
deepmd/pd/train/training.py Replaces original dataloader creation with shard_dataloader implementation and adds tracking for total data length

log_dict = {}
if "fid" in batch_data:
    log_dict["fid"] = batch_data["fid"]
log_dict["sid"] = batch_data["sid"]

Copilot AI Sep 12, 2025


The log_dict['sid'] is being overwritten immediately after assignment. Line 1171 unconditionally sets it to 0, making the previous assignment from batch_data['sid'] useless.

Suggested change
log_dict["sid"] = batch_data["sid"]

self.all_dlen += dlen
all_dataloaders.append(shard_dataloader)
_shard_dataloader = cycle(chain(*all_dataloaders))
_data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)

Copilot AI Sep 12, 2025


[nitpick] Missing space after comma in function call parameters. Should be BufferedIterator(iter(_shard_dataloader), self.all_dlen) for consistency with Python style guidelines.

Suggested change
_data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)
_data_buffered = BufferedIterator(iter(_shard_dataloader), self.all_dlen)

def __init__(self, iterable, alldlen) -> None:
    self._queue = queue.Queue(QUEUESIZE)
    self._iterable = iterable
    self._iterator = iter(iterable)

Copilot AI Sep 12, 2025


[nitpick] The new _iterator field is added but appears unused in the class. If it's intended for future use, consider adding a comment explaining its purpose, or remove it if it's not needed.

Suggested change
self._iterator = iter(iterable)

Contributor

coderabbitai bot commented Sep 12, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

Replaces single DataLoader usage with a distributed shard-based, cyclic iterator across multiple shard dataloaders; tracks total length via a new all_dlen attribute and uses a length-aware BufferedIterator. Modifies BufferedIterator’s constructor to accept total length. Removes distributed sampler path in utils, always using BatchSampler. Adjusts logging of sid to 0.

Changes

Cohort / File(s) Summary
Training pipeline: shard-based loading and iteration
deepmd/pd/train/training.py
Switches from a single DataLoader to per-dataloader shards via paddle.distributed.shard_dataloader; chains and cycles shard iterators; computes and stores total shard length in self.all_dlen; returns a length-aware BufferedIterator(iterable, total_len); on StopIteration, reinitializes with the new BufferedIterator; forces log_dict["sid"] = 0; updates imports (adds chain, cycle; removes DataLoader, NUM_WORKERS).
Dataloader utils and iterator API
deepmd/pd/utils/dataloader.py
Removes DistributedBatchSampler path; always constructs BatchSampler per system; updates BufferedIterator.__init__ signature to (iterable, alldlen), stores self.total = alldlen, initializes self._iterator = iter(iterable); changes iteration length reporting and external construction requirements.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant ShardDL as Shard Dataloaders
  participant ChainCycle as chain()+cycle()
  participant BufIter as BufferedIterator

  Trainer->>ShardDL: create per-dataloader shards (shard_dataloader)
  Note right of ShardDL: Compute total_len = sum(len(shard) for all shards)
  ShardDL-->>Trainer: shard iterators + total_len

  Trainer->>ChainCycle: chain(shards) then cycle(...)
  ChainCycle-->>Trainer: infinite iterator over shards

  Trainer->>BufIter: BufferedIterator(iterable=ChainCycle, total_len)
  BufIter-->>Trainer: length-aware data iterator

  loop training step
    Trainer->>BufIter: next()
    alt data available
      BufIter-->>Trainer: batch
    else StopIteration
      Note over Trainer: Rebuild BufferedIterator(iter(...), total_len)
    end
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • njzjz
  • iProzd

Pre-merge checks (1 passed, 1 warning, 1 inconclusive)

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 10.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title Check ❓ Inconclusive The title "[Auto-Paralllel] tmp fix shard_dataloader" does reference the core change (shard_dataloader) but is noisy and vague: it contains a misspelled tag ("Auto-Paralllel") and the phrase "tmp fix" which does not clearly describe the substantive change (switching to per-dataloader shard iterators and a length-aware BufferedIterator). Please replace the title with a concise, descriptive single sentence—for example, "Fix shard_dataloader: use per-dataloader shard iterators and BufferedIterator(total_len)"—and remove the "tmp" qualifier and fix the "Auto-Paralllel" typo so the main intent is clear to reviewers.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/pd/train/training.py (2)

1116-1118: Validation data BufferedIterator missing length parameter

The BufferedIterator for validation data is still using the old single-parameter constructor, which will cause a TypeError at runtime.

Update the validation data BufferedIterator to use the new two-parameter constructor:

-                    self.validation_data = BufferedIterator(
-                        iter(self.validation_dataloader)
-                    )
+                    self.validation_data = BufferedIterator(
+                        iter(self.validation_dataloader), len(self.validation_dataloader)
+                    )

1126-1128: Multi-task BufferedIterator calls missing length parameter

The BufferedIterator instantiations for multi-task scenarios are also using the old single-parameter constructor and will fail at runtime.

Update all multi-task BufferedIterator calls:

# For training data (lines 1126-1128):
-                    self.training_data[task_key] = BufferedIterator(
-                        iter(self.training_dataloader[task_key])
-                    )
+                    self.training_data[task_key] = BufferedIterator(
+                        iter(self.training_dataloader[task_key]), len(self.training_dataloader[task_key])
+                    )

# For validation data (lines 1136-1138):
-                    self.validation_data[task_key] = BufferedIterator(
-                        iter(self.validation_dataloader[task_key])
-                    )
+                    self.validation_data[task_key] = BufferedIterator(
+                        iter(self.validation_dataloader[task_key]), len(self.validation_dataloader[task_key])
+                    )

Also applies to: 1136-1138

♻️ Duplicate comments (3)
deepmd/pd/utils/dataloader.py (1)

296-296: Unused _iterator field

deepmd/pd/train/training.py (2)

1172-1173: sid value is overwritten, losing actual system ID

The actual system ID from batch_data["sid"] is immediately overwritten with 0, which removes useful debugging information and makes the logging less informative.


186-186: all_dlen is being re-initialized unnecessarily

The self.all_dlen = 0 on line 186 resets the value that was already initialized on line 146. This could be confusing and the initialization on line 146 seems unnecessary.

Remove the initialization on line 146 since the value is properly set within get_dataloader_and_buffer:

-        self.all_dlen = 0
🧹 Nitpick comments (2)
deepmd/pd/train/training.py (2)

146-146: New attribute should be initialized consistently

The all_dlen attribute is being used to track total shard length but is initialized to 0 outside of the dataloader setup. Consider initializing it closer to where it's actually computed or document why it needs early initialization.


195-196: Consider adding spacing for readability

-                _data_buffered = BufferedIterator(
-                    iter(_shard_dataloader), self.all_dlen
-                )
+                _data_buffered = BufferedIterator(
+                    iter(_shard_dataloader), self.all_dlen
+                )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 40f0b1c and 3998ff7.

📒 Files selected for processing (2)
  • deepmd/pd/train/training.py (4 hunks)
  • deepmd/pd/utils/dataloader.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (.github/copilot-instructions.md)

**/*.py: Run ruff check and ruff format on all Python code before committing
Ensure Python code is formatted with ruff’s formatter

Files:

  • deepmd/pd/train/training.py
  • deepmd/pd/utils/dataloader.py
🧬 Code graph analysis (1)
deepmd/pd/train/training.py (1)
deepmd/pd/utils/dataloader.py (1)
  • BufferedIterator (292-337)
🔇 Additional comments (3)
deepmd/pd/utils/dataloader.py (1)

171-189: Simplified sampler construction looks good

The removal of the DistributedBatchSampler path simplifies the code while maintaining the shuffle logic based on distributed initialization status.

deepmd/pd/train/training.py (2)

180-198: Good implementation of distributed shard-based loading

The new approach using paddle.distributed.shard_dataloader with cyclic iteration through chained shards provides better distributed data handling. The accumulation of total length in self.all_dlen properly supports the BufferedIterator's length tracking.


1107-1107: BufferedIterator constructor call is correct

The updated call to BufferedIterator with two arguments properly uses the new API with self.all_dlen for length tracking.

Comment on lines +293 to +296
def __init__(self, iterable, alldlen) -> None:
    self._queue = queue.Queue(QUEUESIZE)
    self._iterable = iterable
    self._iterator = iter(iterable)
Contributor


🛠️ Refactor suggestion

API change needs documentation

The BufferedIterator constructor signature change from __init__(self, iterable) to __init__(self, iterable, alldlen) is a breaking API change. Consider adding a docstring to document this new parameter and its purpose.

 class BufferedIterator:
     def __init__(self, iterable, alldlen) -> None:
+        """Initialize BufferedIterator with an iterable and total length.
+        
+        Parameters
+        ----------
+        iterable
+            The iterable data source to buffer.
+        alldlen : int
+            The total length across all data sources for proper iteration tracking.
+        """
         self._queue = queue.Queue(QUEUESIZE)
         self._iterable = iterable
         self._iterator = iter(iterable)

Comment on lines +171 to +176
# if dist.is_available() and dist.is_initialized():
# system_batch_sampler = DistributedBatchSampler(
# system,
# shuffle=(
# (not (dist.is_available() and dist.is_initialized()))
# and shuffle

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
