diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py
index 4e5fea081f..fe5da9b253 100644
--- a/deepmd/pd/train/training.py
+++ b/deepmd/pd/train/training.py
@@ -23,9 +23,6 @@
 from paddle.framework import (
     core,
 )
-from paddle.io import (
-    DataLoader,
-)
 
 from deepmd.common import (
     symlink_prefix_files,
@@ -58,7 +55,6 @@
     DEFAULT_PRECISION,
     DEVICE,
     JIT,
-    NUM_WORKERS,
     SAMPLER_RECORD,
     enable_prim,
 )
@@ -147,6 +143,7 @@ def __init__(
             "change_bias_after_training", False
         )
         self.lcurve_should_print_header = True
+        self.all_dlen = 0
 
         def get_opt_param(params):
             opt_type = params.get("opt_type", "Adam")
@@ -166,19 +163,39 @@ def get_dataloader_and_buffer(_data, _params):
                 log.warning(
                     "Sampler not specified!"
                 )  # None sampler will lead to a premature stop iteration. Replacement should be True in attribute of the sampler to produce expected number of items in one iteration.
-            _dataloader = DataLoader(
-                _data,
-                batch_sampler=paddle.io.BatchSampler(
-                    sampler=_sampler,
-                    drop_last=False,
-                ),
-                num_workers=NUM_WORKERS
-                if dist.is_available()
-                else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
-                collate_fn=lambda batch: batch[0],  # prevent extra conversion
+            # _dataloader = DataLoader(
+            #     _data,
+            #     batch_sampler=paddle.io.BatchSampler(
+            #         sampler=_sampler,
+            #         drop_last=False,
+            #     ),
+            #     num_workers=NUM_WORKERS
+            #     if dist.is_available()
+            #     else 0,  # setting to 0 diverges the behavior of its iterator; should be >=1
+            #     collate_fn=lambda batch: batch[0],  # prevent extra conversion
+            # )
+            # _data_buffered = BufferedIterator(iter(_dataloader))
+            # return _dataloader, _data_buffered
+
+            from itertools import (
+                chain,
+                cycle,
+            )
+
+            all_dataloaders = []
+            self.all_dlen = 0
+            for dataloader in _data.dataloaders:
+                shard_dataloader = paddle.distributed.shard_dataloader(
+                    dataloader, dist.get_mesh(), shard_dims="dp"
+                )
+                dlen = len(shard_dataloader)
+                self.all_dlen += dlen
+                all_dataloaders.append(shard_dataloader)
+            _shard_dataloader = cycle(chain(*all_dataloaders))
+            _data_buffered = BufferedIterator(
+                iter(_shard_dataloader), self.all_dlen
             )
-            _data_buffered = BufferedIterator(iter(_dataloader))
-            return _dataloader, _data_buffered
+            return _shard_dataloader, _data_buffered
 
         training_dataloader, training_data_buffered = get_dataloader_and_buffer(
             _training_data, _training_params["training_data"]
@@ -1087,7 +1104,7 @@ def get_data(self, is_train=True, task_key="Default"):
             except StopIteration:
                 # Refresh the status of the dataloader to start from a new epoch
                 self.training_data = BufferedIterator(
-                    iter(self.training_dataloader)
+                    iter(self.training_dataloader), self.all_dlen
                 )
                 batch_data = next(iter(self.training_data))
         else:
@@ -1153,6 +1170,7 @@ def get_data(self, is_train=True, task_key="Default"):
         if "fid" in batch_data:
             log_dict["fid"] = batch_data["fid"]
         log_dict["sid"] = batch_data["sid"]
+        log_dict["sid"] = 0
         return input_dict, label_dict, log_dict
 
     def print_header(self, fout, train_results, valid_results) -> None:
diff --git a/deepmd/pd/utils/dataloader.py b/deepmd/pd/utils/dataloader.py
index 0cb8adbc63..0ed3894a9b 100644
--- a/deepmd/pd/utils/dataloader.py
+++ b/deepmd/pd/utils/dataloader.py
@@ -168,26 +168,25 @@ def construct_dataset(system):
             self.batch_sizes = batch_size * np.ones(len(systems), dtype=int)
         assert len(self.systems) == len(self.batch_sizes)
         for system, batch_size in zip(self.systems, self.batch_sizes):
-            if dist.is_available() and dist.is_initialized():
-                system_batch_sampler = DistributedBatchSampler(
-                    system,
-                    shuffle=(
-                        (not (dist.is_available() and dist.is_initialized()))
-                        and shuffle
-                    ),
-                    batch_size=int(batch_size),
-                )
-                self.sampler_list.append(system_batch_sampler)
-            else:
-                system_batch_sampler = BatchSampler(
-                    system,
-                    shuffle=(
-                        (not (dist.is_available() and dist.is_initialized()))
-                        and shuffle
-                    ),
-                    batch_size=int(batch_size),
-                )
-                self.sampler_list.append(system_batch_sampler)
+            # if dist.is_available() and dist.is_initialized():
+            #     system_batch_sampler = DistributedBatchSampler(
+            #         system,
+            #         shuffle=(
+            #             (not (dist.is_available() and dist.is_initialized()))
+            #             and shuffle
+            #         ),
+            #         batch_size=int(batch_size),
+            #     )
+            #     self.sampler_list.append(system_batch_sampler)
+            # else:
+            system_batch_sampler = BatchSampler(
+                system,
+                shuffle=(
+                    (not (dist.is_available() and dist.is_initialized())) and shuffle
+                ),
+                batch_size=int(batch_size),
+            )
+            self.sampler_list.append(system_batch_sampler)
             system_dataloader = DataLoader(
                 dataset=system,
                 num_workers=0,  # Should be 0 to avoid too many threads forked
@@ -291,14 +290,15 @@ def run(self) -> None:
 
 
 class BufferedIterator:
-    def __init__(self, iterable) -> None:
+    def __init__(self, iterable, alldlen) -> None:
         self._queue = queue.Queue(QUEUESIZE)
         self._iterable = iterable
+        self._iterator = iter(iterable)
         self._consumer = None
 
         self.start_time = time.time()
         self.warning_time = None
-        self.total = len(iterable)
+        self.total = alldlen
 
     def _create_consumer(self) -> None:
         self._consumer = BackgroundConsumer(self._queue, self._iterable, self.total)
@@ -328,7 +328,6 @@ def __next__(self):
                         "number of workers (--num-workers) may help."
                     )
                     self.warning_time = time.time()
-        # Get next example
        item = self._queue.get()
 
        if isinstance(item, Exception):
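
Note on the iteration pattern introduced above (illustrative sketch, not part of the patch): chaining the per-system dataloaders and wrapping the result in itertools.cycle yields a single endless stream of batches, so the combined iterator no longer has a __len__; the total number of batches per epoch has to be summed up front and passed alongside it, which is what the new alldlen argument of BufferedIterator carries. A minimal stdlib-only sketch of that bookkeeping, using plain ranges as hypothetical stand-ins for the sharded dataloaders:

from itertools import chain, cycle

# Hypothetical stand-ins for the per-system dataloaders; in the patch these
# are paddle.distributed.shard_dataloader(...) wrappers around paddle.io.DataLoader.
loaders = [range(3), range(2)]

total = sum(len(dl) for dl in loaders)  # plays the role of self.all_dlen
stream = cycle(chain(*loaders))         # endless stream of batches, system by system

# cycle(...) has no __len__, so the epoch length must travel separately,
# e.g. BufferedIterator(iter(stream), total) in the patched code.
epoch = [next(stream) for _ in range(total)]
print(total, epoch)  # 5 [0, 1, 2, 0, 1]

One consequence of this choice worth noting: itertools.cycle replays its cached items after the first pass, so any per-epoch reshuffling done by the underlying samplers only affects the first pass over the data.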