Conversation

@frankyoujian

Hi community, thanks for your great work on this open-source project!

I am training the segmentation model on multiple GPUs following the tutorials:

trainer = Trainer(devices=4, accelerator="gpu", strategy="ddp")
trainer.fit(model)

But I find that each GPU iterates over all the batches instead of 1/4 of them. I then found that the current TrainDataset is based on IterableDataset, so it cannot use Lightning's DistributedSampler. It looks like we need to manually split the batches per rank and world size.
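
Roughly, the kind of per-rank splitting I mean looks like this (a minimal, illustrative sketch, not the actual pyannote code; it assumes torch.distributed has already been initialized by Lightning, and ShardedIterable / generate_samples are made-up names):

import torch
from torch.utils.data import IterableDataset

class ShardedIterable(IterableDataset):
    """Illustrative IterableDataset that keeps only the samples belonging to this rank."""

    def __init__(self, generate_samples):
        # generate_samples: callable returning an iterator over training samples
        self.generate_samples = generate_samples

    def __iter__(self):
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
            world_size = torch.distributed.get_world_size()
        else:
            rank, world_size = 0, 1
        # deterministic round-robin: rank r keeps samples r, r + world_size, r + 2 * world_size, ...
        for i, sample in enumerate(self.generate_samples()):
            if i % world_size == rank:
                yield sample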

So I drafted this PR and verified that it works locally. Please correct me if anything is wrong. Thanks!

@hbredin
Member

hbredin commented Nov 8, 2025

Thanks for your contribution.

Genuine question: can you explain what this change allows that cannot be done with the current codebase, and provide examples?

The current (weird, I agree) sampling process is the result of lots of iterations, so I am worried this might break part of it.

In particular, we made sure that every single worker never generates the same samples as the other workers. Also, I am aware that, currently, the total number of samples per epoch grows linearly with the number of nodes (this should indeed be documented).

@frankyoujian
Author

@hbredin In my local test, in the multi-GPU scenario, the current codebase makes each GPU process iterate over the whole dataset, which should be unnecessary (personal opinion). With this PR, train__len__ first limits the number of batches per GPU process by dividing the total number of batches by the world size, so that each GPU process gets its own share of the batches.

Then train__iter__helper yields each process's share of the samples via a deterministic round-robin, and lastly train__iter__ only yields train__len__() batches.

In this way, each GPU only iterates over its own share of the samples, which should be more efficient in the multi-GPU scenario.
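
To make the idea concrete, here is a simplified, self-contained sketch of that batching logic (the names PerRankBatches, samples, total_num_batches, and batch_size are illustrative, not the actual pyannote attributes, and the bodies simplify what the PR does):

import itertools

import torch
from torch.utils.data import IterableDataset

class PerRankBatches(IterableDataset):
    """Illustrative dataset that only builds this rank's share of the batches."""

    def __init__(self, samples, total_num_batches, batch_size):
        self.samples = samples                    # any re-iterable collection of training samples
        self.total_num_batches = total_num_batches
        self.batch_size = batch_size
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            self.rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()
        else:
            self.rank, self.world_size = 0, 1

    def __len__(self):
        # each rank only claims its share of the total number of batches
        return self.total_num_batches // self.world_size

    def _iter_samples(self):
        # deterministic round-robin over ranks: rank r keeps samples r, r + world_size, ...
        for i, sample in enumerate(itertools.cycle(self.samples)):
            if i % self.world_size == self.rank:
                yield sample

    def __iter__(self):
        samples = self._iter_samples()
        # stop after exactly len(self) batches, so each rank only sees its share
        for _ in range(len(self)):
            yield [next(samples) for _ in range(self.batch_size)]

With devices=4, each rank would then report only a quarter of the total batches from __len__ and build exactly that many in __iter__.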
