🚀 The feature, motivation and pitch
It would be interesting to compare with the flash-attention repo: https://github.com/Dao-AILab/flash-attention.
I've benchmarked RMSNorm.

Here is my code. Assume sequence length = 8192 and batch_size = 4, so the input is flattened to shape (8192 * 4, hidden_size).
```python
import torch
import torch.nn as nn
import triton

# layer_norm_fn is flash-attention's fused Triton kernel (is_rms_norm=True gives RMSNorm).
# The module path below matches recent flash-attn releases; older ones exposed it
# as flash_attn.ops.triton.layernorm instead.
from flash_attn.ops.triton.layer_norm import layer_norm_fn
from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Peak-memory helper from Liger-Kernel's benchmark utilities; adjust the import
# path to wherever it lives in your checkout.
from utils import _test_memory


class TritonRMSNorm(nn.Module):
    """Thin wrapper around flash-attention's fused Triton RMSNorm kernel."""

    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)

    def forward(
        self,
        input,
        residual=None,
        dropout_p=0.0,
        prenorm=False,
        residual_in_fp32=False,
        return_dropout_mask=False,
    ):
        return layer_norm_fn(
            input,
            self.weight,
            None,
            residual=residual,
            eps=self.eps,
            dropout_p=dropout_p,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True,
            return_dropout_mask=return_dropout_mask,
        )


class RMSNorm(nn.Module):
    """Eager Hugging Face baseline; LlamaRMSNorm is equivalent to T5LayerNorm."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, input):
        input_dtype = input.dtype
        input = input.to(torch.float32)
        variance = input.pow(2).mean(-1, keepdim=True)
        input = input * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * input.to(input_dtype)


def _speed_benchmark(mode, tag):
    """The three speed plots differ only in mode and plot name."""
    return triton.testing.Benchmark(
        x_names=["N"],
        x_vals=[2**i for i in range(10, 16)],
        xlabel="hidden size",
        line_arg="provider",
        line_vals=["liger", "huggingface", "flash-attention"],
        line_names=["Liger", "Hugging Face", "flash-attention"],
        styles=[("blue", "solid"), ("orange", "solid"), ("red", "solid")],
        ylabel="time (ms)",
        plot_name=f"rmsnorm-{tag}-speed-benchmark",
        args={"M": 2048, "dtype": torch.bfloat16, "mode": mode},
    )


@triton.testing.perf_report(
    [
        _speed_benchmark("forward", "fwd"),
        _speed_benchmark("backward", "bwd"),
        _speed_benchmark("full", "full"),
    ]
)
def bench_speed_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
    # batch_size = 4 and seqlen = 8192, flattened to (batch * seqlen, hidden).
    # Note this hardcoded shape overrides the M = 2048 the config passes in.
    x_shape = (8192 * 4, N)

    triton_rms = LigerRMSNorm(hidden_size=N).to(device)
    llama_rms = RMSNorm(hidden_size=N).to(device)
    triton_rmsnorm = TritonRMSNorm(hidden_size=N).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)
    quantiles = [0.5, 0.2, 0.8]  # median plus 20th/80th-percentile error bars

    def y_fwd():
        if provider == "liger":
            return triton_rms(x)
        if provider == "huggingface":
            return llama_rms(x)
        if provider == "flash-attention":
            return triton_rmsnorm(x)

    if mode == "forward":
        ms, min_ms, max_ms = triton.testing.do_bench(
            y_fwd, quantiles=quantiles, grad_to_none=[x], rep=500
        )
    elif mode == "backward":
        y = y_fwd()
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=quantiles,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms, min_ms, max_ms = triton.testing.do_bench(
            full, quantiles=quantiles, grad_to_none=[x], rep=500
        )
    # Return (median, lower bound, upper bound) for the plot.
    return ms, min_ms, max_ms


bench_speed_rms_norm.run(save_path="", print_data=True)


@triton.testing.perf_report(
    [
        triton.testing.Benchmark(
            x_names=["N"],
            x_vals=[2**i for i in range(10, 16)],
            xlabel="hidden size",
            line_arg="provider",
            line_vals=["liger", "huggingface", "flash-attention"],
            line_names=["Liger", "Hugging Face", "flash-attention"],
            styles=[("blue", "solid"), ("orange", "solid"), ("red", "solid")],
            ylabel="GPU memory usage (MB)",
            plot_name="rmsnorm-full-memory-benchmark",
            args={"M": 2048, "dtype": torch.bfloat16, "mode": "full"},
        )
    ]
)
def bench_memory_rms_norm(M, N, dtype, provider, mode, eps=1e-5, device="cuda"):
    x_shape = (8192 * 4, N)

    triton_rms = LigerRMSNorm(hidden_size=N).to(device)
    llama_rms = RMSNorm(hidden_size=N).to(device)
    triton_rmsnorm = TritonRMSNorm(hidden_size=N).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_rms(x)
        if provider == "huggingface":
            return llama_rms(x)
        if provider == "flash-attention":
            return triton_rmsnorm(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem = _test_memory(full)
    return mem / 2**20  # bytes -> MB


bench_memory_rms_norm.run(save_path="", print_data=True)
```
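Before trusting the timings, it may be worth a quick numerical check that the three implementations agree. A minimal sketch, reusing the classes from the script above; the bfloat16 tolerances are my guess, not a vetted threshold:

```python
# Sanity check: every weight is initialized to ones, so all three modules should
# compute the same function, and any mismatch points at a kernel bug.
N = 4096
x = torch.randn(8192 * 4, N, dtype=torch.bfloat16, device="cuda")

reference = RMSNorm(hidden_size=N).to("cuda")  # eager Hugging Face baseline
candidates = {
    "liger": LigerRMSNorm(hidden_size=N).to("cuda"),
    "flash-attention": TritonRMSNorm(hidden_size=N, dtype=torch.bfloat16).to("cuda"),
}

y_ref = reference(x).float()
for name, module in candidates.items():
    y = module(x).float()
    # Loose tolerances for bfloat16; tighten them for fp32 inputs.
    assert torch.allclose(y, y_ref, atol=1e-2, rtol=1e-2), f"{name} diverges from baseline"
```

This costs a few milliseconds per hidden size and rules out benchmarking a fast-but-wrong kernel against a correct one.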