Support batch-level transformations in Encodings #251

Description

@lorenzoh

Sometimes an encoding needs to take information about the whole batch into account, as in sequence learning tasks where the samples in a batch should be padded to the length of the longest sequence.

Currently, all Encodings transform individual samples, which is great for simplicity and composability, but doesn't allow implementing these batch-level transformations.

Encodings are used in basically every training loop through `taskdataloaders`, which always yields batches of encoded data. We could have it use a new function `encodebatch(encoding, context, block, samples)` that transforms multiple samples at a time. This function would operate on vectors of samples rather than on a collated batch, since not all kinds of data can be collated (e.g. images of different sizes).
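To make the "vectors of samples, not a collated batch" point concrete, here is a minimal, self-contained sketch of how a batching loader could hand whole vectors of samples to `encodebatch`. It is an assumption for illustration only, not how `taskdataloaders` is actually implemented: `encodedbatches` is a hypothetical helper, and `encode` is stubbed as the identity so the example runs without FastAI.jl.

```julia
# Stand-in for a real encoding: returns each observation unchanged.
encode(encoding, context, block, obs) = obs

# Default batch-level behaviour: encode each observation on its own.
encodebatch(encoding, context, block, observations::AbstractVector) =
    map(obs -> encode(encoding, context, block, obs), observations)

# Hypothetical helper: split `samples` into consecutive vectors of at most
# `batchsize` samples and pass each vector to `encodebatch` as a whole.
function encodedbatches(encoding, context, block, samples::AbstractVector, batchsize::Integer)
    return [encodebatch(encoding, context, block, collect(chunk))
            for chunk in Iterators.partition(samples, batchsize)]
end

# Images of different sizes fit in a Vector but cannot be stacked into one array.
images = [zeros(Float32, 4, 4), zeros(Float32, 8, 6), zeros(Float32, 5, 5)]
batches = encodedbatches(nothing, nothing, nothing, images, 2)
```

Each element of `batches` is a `Vector` of differently sized matrices, which `encodebatch` can still process, whereas concatenating the same images along a new dimension fails.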

By default, it would simply delegate to the single-sample encode function:

```julia
# Fallback: encode each observation independently with the single-sample `encode`.
function encodebatch(encoding, context, block, observations::AbstractVector)
    map(obs -> encode(encoding, context, block, obs), observations)
end
```
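As a quick check that the fallback leaves existing encodings unaffected, the sketch below applies it to a toy encoding. `Encoding`, `Training` and `Scale` are hypothetical stand-ins defined here so the example is self-contained; they are not FastAI.jl's actual types.

```julia
# Hypothetical stand-ins for the library's abstractions.
abstract type Encoding end
struct Training end  # stand-in context marker

# Toy encoding: multiplies every element of a numeric sample by `factor`.
struct Scale <: Encoding
    factor::Float64
end

# Single-sample encoding, following the `encode(encoding, context, block, obs)` shape.
encode(enc::Scale, context, block, obs) = obs .* enc.factor

# Default batch-level function: delegate to `encode` for every observation.
function encodebatch(encoding, context, block, observations::AbstractVector)
    map(obs -> encode(encoding, context, block, obs), observations)
end

samples = [[1.0, 2.0], [3.0]]
encoded = encodebatch(Scale(10.0), Training(), nothing, samples)
```

Because the fallback only maps `encode` over the vector, any encoding that never defines `encodebatch` behaves exactly as it does today.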

Individual encodings could then override it:

```julia
function encodebatch(encoding::PadSequences, context, block, observations::AbstractVector)
    # dummy padding code: `pad(obs, n)` stands for any function padding `obs` to length `n`
    n = maximum(length, observations)
    return map(obs -> pad(obs, n), observations)
end
```
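A runnable version of the padding override might look like the sketch below. The `PadSequences` struct, its `padvalue` field and the three-argument `pad` helper are assumptions made for illustration; the point is that Julia's multiple dispatch picks the more specific `encodebatch(::PadSequences, ...)` method over the generic fallback.

```julia
abstract type Encoding end

# Hypothetical encoding: pads every sequence in a batch to the longest one.
struct PadSequences{T} <: Encoding
    padvalue::T
end

# A single sample cannot know the batch's maximum length, so per-sample
# encoding leaves it unchanged.
encode(::PadSequences, context, block, obs) = obs

# Generic fallback (not used for `PadSequences`, which has a more specific method).
encodebatch(encoding, context, block, observations::AbstractVector) =
    map(obs -> encode(encoding, context, block, obs), observations)

# Hypothetical helper: right-pad `seq` with `value` until it has length `n`.
pad(seq::AbstractVector, n::Integer, value) = vcat(seq, fill(value, n - length(seq)))

# Batch-level override: inspect all samples first, then pad each one.
function encodebatch(encoding::PadSequences, context, block, observations::AbstractVector)
    n = maximum(length, observations)
    return map(obs -> pad(obs, n, encoding.padvalue), observations)
end

batch = encodebatch(PadSequences(0), nothing, nothing, [[1, 2, 3], [4], [5, 6]])
```

After padding, all samples share one length, so the batch can now be collated, e.g. with `reduce(hcat, batch)`.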

Tagging relevant parties @Chandu-4444 @darsnack @ToucheSir for discussion.

Labels: api-proposal (Implementation or suggestion for new APIs and improvements to existing APIs)