
Thunder gives your PyTorch models superpowers for training and inference. Unlock out-of-the-box optimizations for performance, memory, and parallelism, or roll your own.

Lightning-AI/lightning-thunder

Give your PyTorch models superpowers ⚡

Source-to-source compiler for PyTorch. Fast. Understandable. Extensible.


Thunder makes optimizing PyTorch models easy, augmenting them with custom kernels, fusions, quantization, distributed strategies, and more.

For end users, Thunder comes with plugins that provide model speed-ups out of the box, for optimal utilization of latest-generation hardware.

For performance experts, Thunder is the most ergonomic framework for understanding, modifying, and optimizing AI models through composable transformations.

✅ Run PyTorch 40% faster   ✅ Quantization                ✅ Kernel fusion
✅ Training recipes         ✅ FP4/FP6/FP8 precision       ✅ Distributed TP/PP/DP
✅ Inference recipes        ✅ Ready for NVIDIA Blackwell  ✅ CUDA Graphs
✅ LLMs, non LLMs and more  ✅ Custom Triton kernels       ✅ Compose all the above

Quick start

Install Thunder via pip (more options):

pip install torch==2.6.0 torchvision==0.21 nvfuser-cu124-torch26

pip install lightning-thunder
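
To confirm what actually got installed, a small check like the following can help. It is a minimal sketch that only uses the standard library; the distribution names are the ones from the pip commands above, and `installed_versions` is a hypothetical helper name, not part of Thunder.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the pip commands above.
PACKAGES = ["torch", "torchvision", "nvfuser-cu124-torch26", "lightning-thunder"]


def installed_versions(packages):
    """Return {package: version string, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}: {ver or 'not installed'}")
```

A `None` entry means pip did not install that distribution in the current environment.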
Advanced install options

Blackwell support

For Blackwell you'll need CUDA 12.8

pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu128
pip install --pre nvfuser-cu128 --extra-index-url https://pypi.nvidia.com

pip install lightning-thunder

Install additional executors

These are optional; feel free to mix and match.

# cuDNN SDPA
pip install nvidia-cudnn-frontend

# Float8 support (this will compile from source, be patient)
pip install "transformer_engine[pytorch]"

Install Thunder bleeding edge

pip install git+https://github.com/Lightning-AI/lightning-thunder.git@main

Install Thunder for development

git clone https://github.com/Lightning-AI/lightning-thunder.git
cd lightning-thunder
pip install -e .

Hello world

Define a function or a torch module:

import torch.nn as nn

model = nn.Sequential(nn.Linear(2048, 4096), nn.ReLU(), nn.Linear(4096, 64))

Optimize it with thunder:

import thunder
import torch

thunder_model = thunder.compile(model)

x = torch.randn(64, 2048)

y = thunder_model(x)

# assert_close raises on mismatch and returns None, so it needs no `assert`
torch.testing.assert_close(y, model(x))

Examples

Speed up LLM training

Install LitGPT (without updating other dependencies)

pip install --no-deps 'litgpt[all]'

and run

import thunder
import torch
import litgpt

with torch.device("cuda"):
    model = litgpt.GPT.from_name("Llama-3.2-1B").to(torch.bfloat16)

thunder_model = thunder.compile(model)

inp = torch.ones((1, 2048), device="cuda", dtype=torch.int64)

out = thunder_model(inp)
out.sum().backward()

Speed up HuggingFace BERT inference

Install Hugging Face Transformers (version 4.50.2 or later is recommended)

pip install -U transformers

and run

import thunder
import torch
import transformers

model_name = "bert-large-uncased"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
    model.requires_grad_(False)
    model.eval()

    inp = tokenizer(["Hello world!"], return_tensors="pt")

thunder_model = thunder.compile(model)

out = thunder_model(**inp)
print(out)

Speed up HuggingFace DeepSeek R1 distill inference

Install Hugging Face Transformers (version 4.50.2 or later is recommended)

pip install -U transformers

and run

import torch
import transformers
import thunder

model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

with torch.device("cuda"):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16
    )
    model.requires_grad_(False)
    model.eval()

    inp = tokenizer(["Hello world! Here's a long story"], return_tensors="pt")

thunder_model = thunder.compile(model)

out = thunder_model.generate(
    **inp, do_sample=False, cache_implementation="static", max_new_tokens=100
)
print(out)

To get an idea of the speedups, just run

python examples/quickstart/hf_llm.py

Here is what you get on an L4 machine from Lightning Studio:

Eager: 2273.22ms
Thunder: 1254.39ms

81% faster 🏎️! Quite the speedup ⚡️
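
For reference, the 81% figure follows from the two timings above when "faster" is read as a throughput gain (eager time divided by Thunder time). A short sketch of the arithmetic, which also shows the same result expressed as a latency reduction:

```python
eager_ms = 2273.22    # eager timing reported above
thunder_ms = 1254.39  # Thunder timing reported above

speedup = eager_ms / thunder_ms                          # times faster
faster_pct = (speedup - 1.0) * 100.0                     # "X% faster"
latency_cut_pct = (1.0 - thunder_ms / eager_ms) * 100.0  # time saved per run

print(f"speedup: {speedup:.2f}x")                      # speedup: 1.81x
print(f"faster by: {faster_pct:.0f}%")                 # faster by: 81%
print(f"latency reduced by: {latency_cut_pct:.0f}%")   # latency reduced by: 45%
```

The two percentages describe the same measurement, so it is worth stating which one is meant when comparing results.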

Speed up Vision Transformer inference

import thunder
import torch
import torchvision as tv

with torch.device("cuda"):
    model = tv.models.vit_b_16()
    model.requires_grad_(False)
    model.eval()

    inp = torch.randn(128, 3, 224, 224)

out = model(inp)

thunder_model = thunder.compile(model)

out = thunder_model(inp)

Plugins

Plugins are a way to apply optimizations to a model, such as parallelism and quantization.

Thunder comes with a few plugins included out of the box, but it's easy to write new ones.

  • scale up with distributed strategies such as DDP, FSDP, and TP
  • optimize numerical precision with FP8, MXFP8
  • save memory with quantization
  • reduce latency with CUDAGraphs
  • debugging and profiling

For example, in order to reduce CPU overheads via CUDAGraphs you can add "reduce-overhead" to the plugins= argument of thunder.compile:

thunder_model = thunder.compile(model, plugins="reduce-overhead")

This may or may not make a big difference. The point of Thunder is that you can easily swap optimizations in and out and explore the best combination for your setup.
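
One way to explore combinations is to time each variant with the same inputs and rank them. The sketch below is a generic helper and not part of Thunder: `median_ms` and `rank_candidates` are hypothetical names, and the candidates are plain Python stand-ins so the code runs anywhere.

```python
import time


def median_ms(fn, *args, warmup=3, iters=10):
    """Median wall-clock time of fn(*args) in milliseconds."""
    for _ in range(warmup):  # warm-up runs absorb one-time compilation cost
        fn(*args)
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - start) * 1e3)
    samples.sort()
    return samples[len(samples) // 2]  # upper middle for even counts


def rank_candidates(candidates, *args):
    """Time every candidate and return (name, ms) pairs, fastest first."""
    timings = {name: median_ms(fn, *args) for name, fn in candidates.items()}
    return sorted(timings.items(), key=lambda item: item[1])


# Stand-in callables. With Thunder installed, the candidates could instead be:
#   {"eager": model,
#    "thunder": thunder.compile(model),
#    "thunder+cudagraphs": thunder.compile(model, plugins="reduce-overhead")}
candidates = {
    "list": lambda n: sum([i * i for i in range(n)]),
    "generator": lambda n: sum(i * i for i in range(n)),
}
for name, ms in rank_candidates(candidates, 10_000):
    print(f"{name}: {ms:.3f} ms")
```

On a GPU, kernels launch asynchronously, so a real measurement would also need to call torch.cuda.synchronize() before reading the clock.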

How it works

Thunder works in three stages:

  1. ⚡️ It acquires your model by interpreting Python bytecode and producing a straight-line Python program (the computation trace)

  2. ⚡️ It transforms the computation trace, for example to make it distributed or to change its precision

  3. ⚡️ It routes parts of the trace for execution, using:

    • fusion (NVFuser, torch.compile)
    • specialized libraries (e.g. cuDNN SDPA, TransformerEngine)
    • custom Triton and CUDA kernels
    • PyTorch eager operations
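
The three stages above can be pictured with a toy model. This is only an illustration of the acquire, transform, route idea and not Thunder's implementation: `Op`, the `EXECUTORS` table, and every function name below are hypothetical.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Op:
    name: str           # operation name, e.g. "linear"
    dtype: str = "f32"  # precision the op runs in
    executor: str = ""  # filled in by the routing stage


def acquire():
    """Stage 1: stand-in for tracing a model into a straight-line op list."""
    return [Op("linear"), Op("relu"), Op("linear"), Op("softmax")]


def to_bf16(trace):
    """Stage 2: a trace-to-trace transform, here a precision change."""
    return [replace(op, dtype="bf16") for op in trace]


# Stage 3: executors are tried in priority order, each claiming the ops it
# supports. Anything left unclaimed goes to the fallback.
EXECUTORS = [
    ("fusion", {"relu"}),
    ("library", {"linear"}),
]


def route(trace, fallback="eager"):
    routed = []
    for op in trace:
        chosen = next((n for n, ops in EXECUTORS if op.name in ops), fallback)
        routed.append(replace(op, executor=chosen))
    return routed


trace = route(to_bf16(acquire()))
for op in trace:
    print(f"{op.name:8s} {op.dtype} -> {op.executor}")
```

Each stage takes a trace and returns a new one, which is what lets transformations compose.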

 

This is what the trace looks like for a simple MLP:

import thunder
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 2048), nn.ReLU(), nn.Linear(2048, 256))

thunder_model = thunder.compile(model)
y = thunder_model(torch.randn(4, 1024))

print(thunder.last_traces(thunder_model)[-1])

This is the acquired trace, ready to be transformed and executed:

def computation(input, t_0_bias, t_0_weight, t_2_bias, t_2_weight):
    # input: "cuda:0 f32[4, 1024]"
    # t_0_bias: "cuda:0 f32[2048]"
    # t_0_weight: "cuda:0 f32[2048, 1024]"
    # t_2_bias: "cuda:0 f32[256]"
    # t_2_weight: "cuda:0 f32[256, 2048]"
    t3 = ltorch.linear(input, t_0_weight, t_0_bias)  # t3: "cuda:0 f32[4, 2048]"
    t6 = ltorch.relu(t3, False)  # t6: "cuda:0 f32[4, 2048]"
    t10 = ltorch.linear(t6, t_2_weight, t_2_bias)  # t10: "cuda:0 f32[4, 256]"
    return (t10,)

Note how Thunder's intermediate representation is just (a subset of) Python!
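
Because the trace is Python source, standard tooling can read it. As a small sketch, the snippet below parses the trace shown above with the standard library `ast` module and lists the operations it contains. The trace text is copied here with its body indented and its comments dropped; nothing in this snippet uses Thunder itself.

```python
import ast

# The trace printed above, with the function body indented so that it is
# syntactically valid Python.
TRACE_SOURCE = '''
def computation(input, t_0_bias, t_0_weight, t_2_bias, t_2_weight):
    t3 = ltorch.linear(input, t_0_weight, t_0_bias)
    t6 = ltorch.relu(t3, False)
    t10 = ltorch.linear(t6, t_2_weight, t_2_bias)
    return (t10,)
'''

module = ast.parse(TRACE_SOURCE)
func = module.body[0]  # the single FunctionDef node

# Each assignment in the body is one traced operation: target = callee(args).
ops = []
for stmt in func.body:
    if isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call):
        target = stmt.targets[0].id
        callee = ast.unparse(stmt.value.func)  # requires Python 3.9+
        ops.append((target, callee))

print(func.name, [arg.arg for arg in func.args.args])
for target, callee in ops:
    print(f"{target} <- {callee}")
```

Parsing only checks the syntax; running the trace would additionally require the `ltorch` namespace that Thunder provides.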

Performance

Thunder is fast. Here are the speed-ups obtained on a pre-training task using LitGPT on H100 and B200 hardware, relative to PyTorch eager.

[Figure: Thunder speed-ups over PyTorch eager for LitGPT pre-training on H100 and B200]

Community

Thunder is an open source project, developed in collaboration with the community with significant contributions from NVIDIA.

💬 Get help on Discord 📋 License: Apache 2.0
