Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(
"change_bias_after_training", False
)
self.lcurve_should_print_header = True
self.all_dlen = 0

def get_opt_param(params):
opt_type = params.get("opt_type", "Adam")
Expand All @@ -166,19 +167,33 @@ def get_dataloader_and_buffer(_data, _params):
log.warning(
"Sampler not specified!"
) # A None sampler will lead to a premature StopIteration. The sampler's replacement attribute should be True to produce the expected number of items in one iteration.
_dataloader = DataLoader(
_data,
batch_sampler=paddle.io.BatchSampler(
sampler=_sampler,
drop_last=False,
),
num_workers=NUM_WORKERS
if dist.is_available()
else 0, # setting to 0 diverges the behavior of its iterator; should be >=1
collate_fn=lambda batch: batch[0], # prevent extra conversion
)
_data_buffered = BufferedIterator(iter(_dataloader))
return _dataloader, _data_buffered
# _dataloader = DataLoader(
# _data,
# batch_sampler=paddle.io.BatchSampler(
# sampler=_sampler,
# drop_last=False,
# ),
# num_workers=NUM_WORKERS
# if dist.is_available()
# else 0, # setting to 0 diverges the behavior of its iterator; should be >=1
# collate_fn=lambda batch: batch[0], # prevent extra conversion
# )
# _data_buffered = BufferedIterator(iter(_dataloader))
# return _dataloader, _data_buffered

from itertools import chain, cycle
all_dataloaders = []
self.all_dlen = 0
for dataloader in _data.dataloaders:
shard_dataloader = paddle.distributed.shard_dataloader(
dataloader, dist.get_mesh(), shard_dims="dp"
)
dlen = len(shard_dataloader)
self.all_dlen += dlen
all_dataloaders.append(shard_dataloader)
_shard_dataloader = cycle(chain(*all_dataloaders))
_data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Missing space after the comma in the function call arguments. Should be BufferedIterator(iter(_shard_dataloader), self.all_dlen) for consistency with PEP 8.

Suggested change
_data_buffered = BufferedIterator(iter(_shard_dataloader),self.all_dlen)
_data_buffered = BufferedIterator(iter(_shard_dataloader), self.all_dlen)

Copilot uses AI. Check for mistakes.
return _shard_dataloader, _data_buffered

training_dataloader, training_data_buffered = get_dataloader_and_buffer(
_training_data, _training_params["training_data"]
Expand Down Expand Up @@ -1087,7 +1102,7 @@ def get_data(self, is_train=True, task_key="Default"):
except StopIteration:
# Refresh the status of the dataloader to start from a new epoch
self.training_data = BufferedIterator(
iter(self.training_dataloader)
iter(self.training_dataloader), self.all_dlen
)
batch_data = next(iter(self.training_data))
else:
Expand Down Expand Up @@ -1153,6 +1168,7 @@ def get_data(self, is_train=True, task_key="Default"):
if "fid" in batch_data:
log_dict["fid"] = batch_data["fid"]
log_dict["sid"] = batch_data["sid"]
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

log_dict["sid"] is overwritten immediately after it is assigned: line 1171 unconditionally sets it to 0, so the preceding assignment from batch_data["sid"] has no effect.

Suggested change
log_dict["sid"] = batch_data["sid"]

Copilot uses AI. Check for mistakes.
log_dict["sid"] = 0
return input_dict, label_dict, log_dict

def print_header(self, fout, train_results, valid_results) -> None:
Expand Down
46 changes: 23 additions & 23 deletions deepmd/pd/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,26 +168,26 @@ def construct_dataset(system):
self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
assert len(self.systems) == len(self.batch_sizes)
for system, batch_size in zip(self.systems, self.batch_sizes):
if dist.is_available() and dist.is_initialized():
system_batch_sampler = DistributedBatchSampler(
system,
shuffle=(
(not (dist.is_available() and dist.is_initialized()))
and shuffle
),
batch_size=int(batch_size),
)
self.sampler_list.append(system_batch_sampler)
else:
system_batch_sampler = BatchSampler(
system,
shuffle=(
(not (dist.is_available() and dist.is_initialized()))
and shuffle
),
batch_size=int(batch_size),
)
self.sampler_list.append(system_batch_sampler)
# if dist.is_available() and dist.is_initialized():
# system_batch_sampler = DistributedBatchSampler(
# system,
# shuffle=(
# (not (dist.is_available() and dist.is_initialized()))
# and shuffle
Comment on lines +171 to +176

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.
# ),
# batch_size=int(batch_size),
# )
# self.sampler_list.append(system_batch_sampler)
# else:
system_batch_sampler = BatchSampler(
system,
shuffle=(
(not (dist.is_available() and dist.is_initialized()))
and shuffle
),
batch_size=int(batch_size),
)
self.sampler_list.append(system_batch_sampler)
system_dataloader = DataLoader(
dataset=system,
num_workers=0, # Should be 0 to avoid too many threads forked
Expand Down Expand Up @@ -291,14 +291,15 @@ def run(self) -> None:


class BufferedIterator:
def __init__(self, iterable) -> None:
def __init__(self, iterable, alldlen) -> None:
self._queue = queue.Queue(QUEUESIZE)
self._iterable = iterable
self._iterator = iter(iterable)
Copy link

Copilot AI Sep 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The new _iterator attribute is assigned but appears to be unused in the class (_create_consumer still passes self._iterable to BackgroundConsumer). If it's intended for future use, consider adding a comment explaining its purpose; otherwise, remove it.

Suggested change
self._iterator = iter(iterable)

Copilot uses AI. Check for mistakes.
Comment on lines +293 to +296
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

API change needs documentation

The BufferedIterator constructor signature change from __init__(self, iterable) to __init__(self, iterable, alldlen) is a breaking API change. Consider adding a docstring to document this new parameter and its purpose.

 class BufferedIterator:
     def __init__(self, iterable, alldlen) -> None:
+        """Initialize BufferedIterator with an iterable and total length.
+        
+        Parameters
+        ----------
+        iterable
+            The iterable data source to buffer.
+        alldlen : int
+            The total length across all data sources for proper iteration tracking.
+        """
         self._queue = queue.Queue(QUEUESIZE)
         self._iterable = iterable
         self._iterator = iter(iterable)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def __init__(self, iterable, alldlen) -> None:
self._queue = queue.Queue(QUEUESIZE)
self._iterable = iterable
self._iterator = iter(iterable)
def __init__(self, iterable, alldlen) -> None:
"""Initialize BufferedIterator with an iterable and total length.
Parameters
----------
iterable
The iterable data source to buffer.
alldlen : int
The total length across all data sources for proper iteration tracking.
"""
self._queue = queue.Queue(QUEUESIZE)
self._iterable = iterable
self._iterator = iter(iterable)
🤖 Prompt for AI Agents
In deepmd/pd/utils/dataloader.py around lines 293 to 296, the BufferedIterator
constructor signature was changed from __init__(self, iterable) to
__init__(self, iterable, alldlen) but there is no documentation for the new
parameter; add a concise docstring to the class or its __init__ method that
explains the new alldlen parameter (type, meaning, whether it’s optional,
expected units/semantics and how it affects behavior), update any public API
docs or README references to mention the new required argument, and include a
short example or note about backward compatibility or migration if callers
previously relied on the single-argument form.

self._consumer = None

self.start_time = time.time()
self.warning_time = None
self.total = len(iterable)
self.total = alldlen

def _create_consumer(self) -> None:
self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total)
Expand Down Expand Up @@ -328,7 +329,6 @@ def __next__(self):
"number of workers (--num-workers) may help."
)
self.warning_time = time.time()

# Get next example
item = self._queue.get()
if isinstance(item, Exception):
Expand Down