Benchmark against Flash attention repo #137

Open

zzhhjjj opened this issue Aug 28, 2024 · 0 comments
zzhhjjj commented Aug 28, 2024

🚀 The feature, motivation and pitch

It would be interesting to compare against the Flash-attention repo: https://github.com/Dao-AILab/flash-attention.

I've benchmarked RMSNorm; results below.

[benchmark results image]

Here is my code. Suppose sequence length = 8192 and batch_size = 4, so the input is flattened to 8192 * 4 = 32768 rows.
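The snippet assumes roughly the following imports; the exact module paths may differ across flash-attn and liger-kernel versions, so treat them as a sketch:

import torch
import torch.nn as nn
import triton
import triton.testing
from flash_attn.ops.triton.layer_norm import layer_norm_fn   # flash-attention's fused Triton (layer|rms)norm
from liger_kernel.transformers.rms_norm import LigerRMSNorm  # Liger RMSNorm kernel
from utils import _test_memory  # memory helper from Liger-Kernel's benchmark utilities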

# Thin wrapper around flash-attention's Triton layer_norm_fn, run in RMSNorm mode
class TritonRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)

    def forward(
        self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
    ):
        return layer_norm_fn(
            input,
            self.weight,
            None,
            residual=residual,
            eps=self.eps,
            dropout_p=dropout_p,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,
            return_dropout_mask=return_dropout_mask,
        )

# Eager PyTorch baseline, equivalent to Hugging Face's LlamaRMSNorm
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, input):
        input_dtype = input.dtype
        input = input.to(torch.float32)
        variance = input.pow(2).mean(-1, keepdim=True)
        input = input * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * input.to(input_dtype)

@triton.testing.perf_report(
    [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface","flash-attention"],
            line_names=["Liger", "Hugging Face", "flash-attention"],
            styles=[("blue", "solid"), ("orange", "solid"), ("red", "solid")],
            ylabel="time (ms)",
            plot_name="rmsnorm-fwd-speed-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": "forward"},
        ),
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface","flash-attention"],
            line_names=["Liger", "Hugging Face", "flash-attention"],
            styles=[("blue", "solid"), ("orange", "solid"), ("red", "solid")],
            ylabel="time (ms)",
            plot_name="rmsnorm-bwd-speed-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": "backward"},
        ),
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface","flash-attention"],
            line_names=["Liger", "Hugging Face", "flash-attention"],
            styles=[("blue", "solid"), ("orange", "solid"), ("red", "solid")],
            ylabel="time (ms)",
            plot_name="rmsnorm-full-speed-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": "full"},
        ),
    ]
)
def bench_speed_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
    x_shape = (8192 * 4, N)  # batch_size (4) * seq_len (8192) rows; the M arg from the benchmark config is unused

    triton_rms = LigerRMSNorm(hidden_size=N).to("cuda")
    llama_rms = RMSNorm(hidden_size=N).to("cuda")
    triton_rmsnorm = TritonRMSNorm(hidden_size=N).to("cuda")

    x = torch.randn(x_shape, dtype=dtype, device="cuda")
    dy = torch.randn_like(x)
    x.requires_grad_(True)
    quantiles = [0.5, 0.2, 0.8]

    # utility functions

    def y_fwd():
        if provider == "liger":
            return triton_rms(x)
        if provider == "huggingface":
            return llama_rms(x)
        if provider == "flash-attention":
            return triton_rmsnorm(x)

    if mode == "forward":
        ms, min_ms, max_ms = triton.testing.do_bench(
            y_fwd, quantiles=quantiles, grad_to_none=[x], rep=500
        )
    elif mode == "backward":
        y = y_fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=quantiles,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms, min_ms, max_ms = triton.testing.do_bench(
            full, quantiles=quantiles, grad_to_none=[x], rep=500
        )

    return ms, max_ms, min_ms

bench_speed_rms_norm.run(save_path='', print_data=True)
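
Not part of the benchmark itself, but before timing it may be worth checking that the three implementations actually agree numerically. A minimal, hypothetical sketch, assuming the classes and imports above (eps is passed to LigerRMSNorm explicitly in case its default differs from 1e-5):

def sanity_check(N=4096, dtype=torch.bfloat16, device="cuda"):
    x = torch.randn(8192 * 4, N, dtype=dtype, device=device)
    liger = LigerRMSNorm(hidden_size=N, eps=1e-5).to(device=device, dtype=dtype)
    hf = RMSNorm(hidden_size=N).to(device=device, dtype=dtype)
    fa = TritonRMSNorm(hidden_size=N, device=device, dtype=dtype)
    ref = hf(x).float()
    # bf16 tolerances are loose; this only flags gross disagreements
    torch.testing.assert_close(liger(x).float(), ref, atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(fa(x).float(), ref, atol=1e-2, rtol=1e-2)

sanity_check()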

@triton.testing.perf_report(
    [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface","flash-attention"],
            line_names=["Liger", "Hugging Face", "flash-attention"],
            styles=[("blue", "solid"), ("orange", "solid"), ("red", "solid")],
            ylabel="GPU memory usage (MB)",
            plot_name="rmsnorm-full-memory-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": "full"},
        )
    ]
)
def bench_memory_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
    x_shape = (8192 * 4, N)  # batch_size (4) * seq_len (8192) rows; M is unused here as well

    triton_rms = LigerRMSNorm(hidden_size=N).to("cuda")
    llama_rms = RMSNorm(hidden_size=N).to("cuda")
    triton_rmsnorm = TritonRMSNorm(hidden_size=N).to("cuda")

    x = torch.randn(x_shape, dtype=dtype, device="cuda")
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # utility functions
    def y_fwd():
        if provider == "liger":
            return triton_rms(x)
        if provider == "huggingface":
            return llama_rms(x)
        if provider == "flash-attention":
            return triton_rmsnorm(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem = _test_memory(full)

    return mem / 2**20

bench_memory_rms_norm.run(save_path='', print_data=True)

Alternatives

No response

Additional context

No response
