cache var used by each iteration in grid persistent kernel, e.g. weight in layer norm backward #2525

@liqiangxl

🚀 The feature, motivation and pitch

In layer norm backward:
For DataType::Half inputs, the persistent buffers are projected to the three inputs (dy, x, weight), for a total size of 3 * sizeof(half) * dim1.
For DataType::Float inputs, the persistent buffers are NOT projected; they are xhat and d_xhat, for a total size of 2 * sizeof(float) * dim1.
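
For concreteness, the buffer sizes above can be worked out with a small Python sketch. The helper name and the element-size table are illustrative only and not part of any library API. The third case (Float with projection) assumes projection would target the same three inputs (dy, x, weight) as in the Half case.

```python
# Element sizes in bytes (half = 2, float = 4).
SIZEOF = {"half": 2, "float": 4}


def persistent_buffer_bytes(dtype, dim1, projected):
    """Total persistent buffer size per row, in bytes.

    projected=True  -> buffers are the 3 inputs (dy, x, weight)
    projected=False -> buffers are the 2 intermediates (xhat, d_xhat)
    """
    n_buffers = 3 if projected else 2
    return n_buffers * SIZEOF[dtype] * dim1


dim1 = 10240
half_projected = persistent_buffer_bytes("half", dim1, projected=True)      # 61440
float_unprojected = persistent_buffer_bytes("float", dim1, projected=False)  # 81920
# Assumption: enforced projection for Float uses the same 3 inputs.
float_projected = persistent_buffer_bytes("float", dim1, projected=True)     # 122880
```

Under these assumptions, enforcing projection for Float grows the persistent footprint from 2 to 3 float buffers per row, which is why it costs more registers per thread.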
If I enforce projection for DataType::Float inputs, there is a significant speedup: for the 2048 x 10240 case the time drops from 274 us to 207 us, and for the 2048 x 1024 case from 39 us to 36 us. The reason is that weight is shared across rows. If we keep it persistent, we don't need to reload it on each iteration over the rows. The projected version needs more registers per thread, but that doesn't reduce occupancy, because all blocks must be active at the same time in this grid persistent kernel.
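
A rough way to see the effect is a toy count of how many weight elements each block reads, together with the speedups implied by the timings above. This is only a sketch: the function names and the block count of 128 are hypothetical, and the model ignores L1/L2 caching, so it shows the direction of the effect and does not predict runtime.

```python
import math


def weight_loads_per_block(n_rows, n_blocks, dim1, weight_persistent):
    """Toy count of weight elements one block reads from global memory."""
    # Each block loops over its share of the rows.
    rows_per_block = math.ceil(n_rows / n_blocks)
    if weight_persistent:
        # Loaded once, then reused by every row iteration.
        return dim1
    # Reloaded on every row iteration.
    return rows_per_block * dim1


def speedup(before_us, after_us):
    """Ratio of the old time to the new time."""
    return before_us / after_us


# Hypothetical launch: 128 blocks sharing 2048 rows (16 rows per block).
reloaded = weight_loads_per_block(2048, 128, 10240, weight_persistent=False)
cached = weight_loads_per_block(2048, 128, 10240, weight_persistent=True)

# Timings reported in the issue.
large_case = speedup(274, 207)  # 2048 x 10240, roughly 1.32x
small_case = speedup(39, 36)    # 2048 x 1024, roughly 1.08x
```

In this model the saving grows with the number of rows each block processes, since the persistent weight is read once per block regardless of the row count.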

Alternatives

No response

Additional context

No response
