This repository has been archived by the owner on Jul 4, 2023. It is now read-only.

MaxTokensBatchSampler #114

Open

salvacarrion opened this issue May 15, 2021 · 1 comment

Is there any straightforward way to specify the maximum number of tokens per batch in a sampler (e.g., BucketBatchSampler)?

Reducing the amount of padding per batch is critical for performance, and the BucketBatchSampler class does an excellent job of that. However, in my opinion, in NLP tasks the concept of batch_size comes second to the number of tokens per batch, since the former is an optimization choice while the latter is a hard constraint needed to avoid out-of-memory errors. (For instance, I can train with a batch size of 128 at max_length 100, but not at max_length 512.)
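
To make the constraint concrete, here is a rough back-of-the-envelope sketch (the numbers are only illustrative and assume every sequence is padded out to max_length):

# Worst-case padded tokens per batch = batch_size * max_length (illustrative numbers only)
batch_size = 128
print(batch_size * 100)  # 12,800 padded tokens per batch -> fits in memory in the example above
print(batch_size * 512)  # 65,536 padded tokens per batch -> roughly 5x the activations, hits OOM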

salvacarrion changed the title from MaxTokensSampler to MaxTokensBatchSampler on May 15, 2021

salvacarrion commented May 15, 2021

I have written this piece of code. It is not a clean solution, but it works.

import random
from torch.utils.data.sampler import BatchSampler, RandomSampler, SubsetRandomSampler
from torchnlp.utils import identity


class MaxTokensBatchSampler(BatchSampler):

    def __init__(self,
                 sampler,
                 batch_size,
                 max_tokens,
                 drop_last,
                 sort_key=identity,
                 bucket_size_multiplier=100,
                 shuffle=True):
        super().__init__(sampler, batch_size, drop_last)
        self.max_tokens = max_tokens
        self.sort_key = sort_key
        self.bucket_size_multiplier = bucket_size_multiplier
        self.shuffle = shuffle

        # Not a clean solution
        self.bucket_batches = []
        self._build_buckets()

    def __iter__(self):
        # Iterate over buckets
        for batches, batch_sizes in self.bucket_batches:
            # Shuffle bucket-batch order
            batches = SubsetRandomSampler(batches) if self.shuffle else batches
            for batch in batches:
                if self.shuffle:  # Shuffle inner batch
                    random.shuffle(batch)
                yield batch  # Batch indexes [sent1_idx, sent2_idx,...]

    def __len__(self):
        return sum([len(x[0]) for x in self.bucket_batches])

    def _build_buckets(self):
        # Randomize samples
        tmp_sampler = RandomSampler(self.sampler) if self.shuffle else self.sampler

        # Split samples in N batches (or "buckets")
        tmp_sampler = BatchSampler(tmp_sampler, min(self.batch_size * self.bucket_size_multiplier, len(self.sampler)),
                                   False)

        # Sort samples
        self.bucket_batches = []
        for bucket in tmp_sampler:
            bucket_sorted = sorted([(i, self.sort_key(i)) for i in bucket], key=lambda x: x[1])

            # Create batches constrained
            batches = []
            batch_sizes = []

            last_batch = []
            last_batch_size = 0
            for sample_i, length_i in bucket_sorted:
                if (last_batch_size + length_i) < self.max_tokens:
                    # Sample still fits into the current batch
                    last_batch.append(sample_i)
                    last_batch_size += length_i
                else:
                    # Close the current batch (guard against emitting an empty batch
                    # when a single sample is longer than max_tokens)
                    if last_batch:
                        batches.append(last_batch)
                        batch_sizes.append(last_batch_size)

                    # Start a new batch with the current sample
                    last_batch = [sample_i]
                    last_batch_size = length_i

            # Add the last (possibly partial) batch
            if last_batch:
                batches.append(last_batch)
                batch_sizes.append(last_batch_size)

            # Add bucket batches
            self.bucket_batches.append((batches, batch_sizes))

It works as follows:

  1. Randomize all samples/sentences: [0,1,2,3,...n] => [6, 12, 60,... , 31]
  2. Split samples into buckets => [6, 12, 60,...], [92, 1, 52,... , 24], [95, 234, 33,... , 31]
  3. Sort each bucket by sentence length and greedily pack the sorted samples into batches of at most max_tokens tokens
  4. Shuffle the batch order inside each bucket: bucket1 (batch1, batch2,...batchN => batch5, batch10,... batch3), bucket2...
  5. Shuffle the sentences inside each batch: batch1: [23, 51, 12...] => [391, 2, 33,...]
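
To see those steps in action, here is a small sanity-check run (not from the original comment; it assumes the MaxTokensBatchSampler class above is defined and uses made-up sentence lengths, with shuffle=False so the packing is easy to inspect):

from torch.utils.data.sampler import SequentialSampler

# Made-up sentence lengths (in tokens); sample i is lengths[i] tokens long
lengths = [3, 5, 4, 2, 6, 5, 4, 3]

sampler = MaxTokensBatchSampler(SequentialSampler(lengths),
                                batch_size=4,                   # only used to size the buckets
                                max_tokens=12,                  # token budget per batch
                                drop_last=False,
                                sort_key=lambda i: lengths[i],  # length of sample i
                                shuffle=False)                  # deterministic, for inspection

for batch in sampler:
    print(batch, "->", sum(lengths[i] for i in batch), "tokens")
# With these lengths it prints batches such as [3, 0, 7] -> 8 tokens, [2, 6] -> 8 tokens,
# [1, 5] -> 10 tokens, [4] -> 6 tokens: every batch stays under max_tokens.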

You can call it like this:

from torch.utils.data.sampler import SequentialSampler

train_sampler = MaxTokensBatchSampler(SequentialSampler(train_ds), shuffle=True, batch_size=BATCH_SIZE,
                                      max_tokens=MAX_TOKENS, drop_last=False,
                                      sort_key=lambda i: len(train_ds.datasets.iloc[i]["src"].split()))
val_sampler = MaxTokensBatchSampler(SequentialSampler(val_ds), shuffle=False, batch_size=BATCH_SIZE,
                                    max_tokens=MAX_TOKENS, drop_last=False,
                                    sort_key=lambda i: len(val_ds.datasets.iloc[i]["src"].split()))

train_ds and val_ds are instances of a torch Dataset subclass (class TranslationDataset(Dataset)); a rough sketch of such a dataset and of the DataLoader hookup follows.
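
The original comment does not include the dataset code, so this is only a guess at its shape, inferred from the sort_key above: the datasets DataFrame attribute and the "src"/"trg" column names are assumptions, as is the placeholder collate_fn, and train_ds / train_sampler are reused from the snippet above.

import pandas as pd
from torch.utils.data import DataLoader, Dataset


class TranslationDataset(Dataset):
    # Minimal sketch: wraps a pandas DataFrame with "src"/"trg" text columns.
    # Only the `datasets` attribute and the "src" column are implied by the
    # sort_key used above; the rest is assumed.
    def __init__(self, datasets: pd.DataFrame):
        self.datasets = datasets

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, idx):
        row = self.datasets.iloc[idx]
        return row["src"], row["trg"]


# Hooking it up: the sampler goes in as batch_sampler, so batch_size, shuffle and
# drop_last must be left at their defaults on the DataLoader itself.
train_loader = DataLoader(train_ds, batch_sampler=train_sampler,
                          collate_fn=lambda batch: batch)  # placeholder; tokenize/pad here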
