
Commit 2bf2ca8

feat: update ConvMixer to support reactant (#1063)
* fix: update to new reactant changes
* fix: use enzyme correctly
* fix: update training code
* feat: handle optimisers correctly
* fix: don't transfer via CPU
* refactor: remove promote_to handling
* revert: specific version bumps
* docs: add multiple CIFAR10 examples using Reactant
* feat: BF16 training + inference
* fix: incorrect AdamW handling
* feat: implement resnet20 baseline
* fix: conv mixer working again
1 parent 46a012d commit 2bf2ca8

File tree: 13 files changed (+405, -240 lines)


Project.toml (+1, -1)

```diff
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <[email protected]> and contributors"]
-version = "1.4.3"
+version = "1.4.4"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
```

docs/Project.toml (+1, -1)

```diff
@@ -56,7 +56,7 @@ Optimisers = "0.4.1"
 Pkg = "1.10"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.12"
+Reactant = "0.2.11"
 StableRNGs = "1"
 StaticArrays = "1"
 WeightInitializers = "1"
```

docs/src/.vitepress/config.mts (+2, -2)

```diff
@@ -243,8 +243,8 @@ export default defineConfig({
         link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/DDIM",
       },
       {
-        text: "ConvMixer on CIFAR-10",
-        link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer",
+        text: "Different Vision Models on CIFAR-10",
+        link: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10",
       },
     ],
   },
```

docs/src/tutorials/index.md (+3, -3)

```diff
@@ -97,10 +97,10 @@ const large_models = [
     desc: "Train a Diffusion Model to generate images from Gaussian noises."
   },
   {
-    href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/ConvMixer",
+    href: "https://github.com/LuxDL/Lux.jl/tree/main/examples/CIFAR10",
     src: "https://datasets.activeloop.ai/wp-content/uploads/2022/09/CIFAR-10-dataset-Activeloop-Platform-visualization-image-1.webp",
-    caption: "ConvMixer on CIFAR-10",
-    desc: "Train ConvMixer on CIFAR-10 to 90% accuracy within 10 minutes."
+    caption: "Vision Models on CIFAR-10",
+    desc: "Train different vision models on CIFAR-10 to 90% accuracy within 10 minutes."
   }
 ];
```
examples/CIFAR10/Project.toml

```diff
@@ -1,7 +1,9 @@
 [deps]
+BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
 Comonicon = "863f3e99-da2a-4334-8734-de3dacbe5542"
 ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
 ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
 Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
@@ -11,18 +13,20 @@ MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
-PreferenceTools = "ba661fbb-e901-4445-b070-854aec6bfbc5"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
-ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
+ProgressTables = "e0b4b9f6-8cc7-451e-9c86-94c5316e9f73"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
+BFloat16s = "0.5.0"
 Comonicon = "1.0.8"
 ConcreteStructs = "0.2.3"
 DataAugmentation = "0.3"
+Enzyme = "0.13.14"
 ImageCore = "0.10.2"
 ImageShow = "0.3.8"
 Interpolations = "0.15.1"
@@ -32,10 +36,9 @@ MLDatasets = "0.7.14"
 MLUtils = "0.4.4"
 OneHotArrays = "0.2.5"
 Optimisers = "0.4.1"
-PreferenceTools = "0.1.2"
 Printf = "1.10"
-ProgressBars = "1.5.1"
 Random = "1.10"
+Reactant = "0.2.12"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
```
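Before running any of the scripts, this environment needs to be instantiated; a minimal sketch using the standard Pkg workflow, assuming the shell is already in the example's directory:

```julia
# Standard Julia project setup: resolves the [deps]/[compat] entries above
using Pkg
Pkg.activate(".")    # activate this example's Project.toml
Pkg.instantiate()    # download and precompile the pinned dependencies
```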

examples/CIFAR10/README.md (new file, +69)

# Train Vision Models on CIFAR-10

✈️ 🚗 🐦 🐈 🦌 🐕 🐸 🐎 🚢 🚚

We have the following scripts to train vision models on CIFAR-10:

1. `simple_cnn.jl`: a simple CNN with a sequence of convolutional layers.
2. `mlp_mixer.jl`: an MLP-Mixer model.
3. `conv_mixer.jl`: a ConvMixer model.

To see the options for each script, run it with the `--help` flag.

> [!NOTE]
> To train the models using Reactant.jl, pass `--backend=reactant` to the script. This is
> the recommended way to train the models in this directory.

> [!NOTE]
> Passing `--bfloat16` trains in BFloat16 precision. This requires Julia 1.11 or above.
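For example, the Comonicon-generated help for `conv_mixer.jl` can be printed like this (the flag set is derived from the script's keyword arguments):

```bash
# Prints the auto-generated usage text, including flags such as --backend
julia --startup-file=no --project=. conv_mixer.jl --help
```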
## Simple CNN

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    simple_cnn.jl \
    --backend=reactant
```

On an RTX 4050 6GB Laptop GPU, training takes approximately 3 minutes, and the final
training and test accuracies are 97% and 65%, respectively.
## ResNet 20

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    resnet20.jl \
    --backend=reactant
```

On an RTX 3060 GPU, each epoch takes about 4.5 seconds, and the final training and test
accuracies are 89% and 75%, respectively.
## ConvMixer

> [!NOTE]
> This code has been adapted from https://github.com/locuslab/convmixer-cifar10

This is a simple ConvMixer training script for CIFAR-10. It is probably a good starting
point for new experiments on small datasets.

You can reach around **90.0%** accuracy in just **25 epochs** by running the script with
the following arguments, which train a ConvMixer-256/8 with kernel size 5 and patch size 2:

```bash
julia --startup-file=no \
    --project=. \
    --threads=auto \
    conv_mixer.jl \
    --backend=reactant
```
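The defaults in `conv_mixer.jl` already encode this configuration; an illustrative variant overriding a few of the flags (Comonicon maps the script's keyword arguments to options, with underscores becoming dashes):

```bash
# Illustrative only: a longer BFloat16 run with a higher peak learning rate
julia --startup-file=no --project=. --threads=auto conv_mixer.jl \
    --backend=reactant --epochs=50 --lr-max=0.1 --bfloat16
```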
### Notes

1. To match the results from the original repo, we need additional augmentation
   strategies that are not currently implemented in DataAugmentation.jl.

examples/CIFAR10/common.jl (new file, +146)

```julia
using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays,
      Printf, ProgressTables, Random, BFloat16s
using Reactant, LuxCUDA

# Wraps a dataset so that the transform is applied lazily on indexing
@concrete struct TensorDataset
    dataset
    transform
end

Base.length(ds::TensorDataset) = length(ds.dataset)

function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    img = Image.(eachslice(convert2image(ds.dataset, idxs); dims=3))
    y = onehotbatch(ds.dataset.targets[idxs], 0:9)
    return stack(parent ∘ itemdata ∘ Base.Fix1(apply, ds.transform), img), y
end

function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
    cifar10_mean = (0.4914, 0.4822, 0.4465) .|> T
    cifar10_std = (0.2471, 0.2435, 0.2616) .|> T

    # Random crops and horizontal flips are applied to the training split only
    train_transform = RandomResizeCrop((32, 32)) |>
                      Maybe(FlipX{2}()) |>
                      ImageToTensor() |>
                      Normalize(cifar10_mean, cifar10_std) |>
                      ToEltype(T)

    test_transform = ImageToTensor() |> Normalize(cifar10_mean, cifar10_std) |> ToEltype(T)

    trainset = TensorDataset(CIFAR10(; Tx=T, split=:train), train_transform)
    trainloader = DataLoader(trainset; batchsize, shuffle=true, kwargs...)

    testset = TensorDataset(CIFAR10(; Tx=T, split=:test), test_transform)
    testloader = DataLoader(testset; batchsize, shuffle=false, kwargs...)

    return trainloader, testloader
end

# Classification accuracy; predictions are moved to the CPU for `onecold`
function accuracy(model, ps, st, dataloader)
    total_correct, total = 0, 0
    cdev = cpu_device()
    for (x, y) in dataloader
        target_class = onecold(cdev(y))
        predicted_class = onecold(cdev(first(model(x, ps, st))))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

function get_accelerator_device(backend::String)
    if backend == "gpu_if_available"
        return gpu_device()
    elseif backend == "gpu"
        return gpu_device(; force=true)
    elseif backend == "reactant"
        return reactant_device(; force=true)
    elseif backend == "cpu"
        return cpu_device()
    else
        error("Invalid backend: $(backend). Valid options are: `gpu_if_available`, `gpu`, \
               `reactant`, and `cpu`.")
    end
end

function train_model(
        model, opt, scheduler=nothing;
        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25,
        bfloat16::Bool=false
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    prec = bfloat16 ? bf16 : f32
    prec_jl = bfloat16 ? BFloat16 : Float32
    prec_str = bfloat16 ? "BFloat16" : "Float32"
    @printf "[Info] Using %s precision\n" prec_str

    accelerator_device = get_accelerator_device(backend)
    kwargs = accelerator_device isa ReactantDevice ? (; partial=false) : ()
    trainloader, testloader = get_cifar10_dataloaders(prec_jl, batchsize; kwargs...) |>
                              accelerator_device

    ps, st = Lux.setup(rng, model) |> prec |> accelerator_device

    train_state = Training.TrainState(model, ps, st, opt)

    # Enzyme drives AD under Reactant; Zygote is used everywhere else
    adtype = backend == "reactant" ? AutoEnzyme() : AutoZygote()

    if backend == "reactant"
        x_ra = rand(rng, prec_jl, size(first(trainloader)[1])) |> accelerator_device
        @printf "[Info] Compiling model with Reactant.jl\n"
        st_test = Lux.testmode(st)
        model_compiled = Reactant.compile(model, (x_ra, ps, st_test))
        @printf "[Info] Model compiled!\n"
    else
        model_compiled = model
    end

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    pt = ProgressTable(;
        header=[
            "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
        ],
        widths=[24, 24, 24, 24, 24],
        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
        color=[:normal, :normal, :blue, :blue, :normal],
        border=true,
        alignment=[:center, :center, :center, :center, :center]
    )

    @printf "[Info] Training model\n"
    initialize(pt)

    for epoch in 1:epochs
        stime = time()
        lr = 0
        for (i, (x, y)) in enumerate(trainloader)
            if scheduler !== nothing
                lr = scheduler((epoch - 1) + (i + 1) / length(trainloader))
                train_state = Optimisers.adjust!(train_state, lr)
            end
            (_, loss, _, train_state) = Training.single_train_step!(
                adtype, loss_fn, (x, y), train_state
            )
            isnan(loss) && error("NaN loss encountered!")
        end
        ttime = time() - stime

        train_acc = accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), trainloader
        ) * 100
        test_acc = accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), testloader
        ) * 100

        scheduler === nothing && (lr = NaN32)
        next(pt, [epoch, lr, train_acc, test_acc, ttime])
    end

    finalize(pt)
    @printf "[Info] Finished training\n"
end
```
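For orientation, a minimal sketch of how a driver script is expected to use `common.jl`; the model below is a hypothetical stand-in, not one of the scripts shipped in this directory:

```julia
using Comonicon, Lux, Optimisers, Printf, Random, Statistics, Zygote

include("common.jl")

# Hypothetical stand-in classifier: flatten 32x32x3 CIFAR-10 images, one linear layer
model = Chain(FlattenLayer(), Dense(32 * 32 * 3 => 10))

# Constant learning rate, so no scheduler is passed (the LR column then prints NaN)
train_model(model, Adam(0.001f0), nothing; backend="cpu", batchsize=256, epochs=2)
```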

examples/CIFAR10/conv_mixer.jl (new file, +50)

```julia
using Comonicon, Interpolations, Lux, Optimisers, Printf, Random, Statistics, Zygote

include("common.jl")

function ConvMixer(; dim, depth, kernel_size=5, patch_size=2)
    #! format: off
    return Chain(
        # Patch embedding: non-overlapping patches via a strided convolution
        Conv((patch_size, patch_size), 3 => dim, gelu; stride=patch_size),
        BatchNorm(dim),
        [
            Chain(
                # Depthwise spatial mixing with a residual connection
                SkipConnection(
                    Chain(
                        Conv(
                            (kernel_size, kernel_size), dim => dim, gelu;
                            groups=dim, pad=SamePad()
                        ),
                        BatchNorm(dim)
                    ),
                    +
                ),
                # Pointwise (1x1) channel mixing
                Conv((1, 1), dim => dim, gelu),
                BatchNorm(dim)
            )
            for _ in 1:depth
        ]...,
        GlobalMeanPool(),
        FlattenLayer(),
        Dense(dim => 10)
    )
    #! format: on
end

Comonicon.@main function main(;
        batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
        patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=0.0001,
        clip_norm::Bool=false, seed::Int=1234, epochs::Int=25, lr_max::Float64=0.05,
        backend::String="reactant", bfloat16::Bool=false
)
    model = ConvMixer(; dim=hidden_dim, depth, kernel_size, patch_size)

    opt = AdamW(; eta=lr_max, lambda=weight_decay)
    clip_norm && (opt = OptimiserChain(ClipNorm(), opt))

    # Piecewise-linear LR schedule: warm up to lr_max, then decay to zero
    lr_schedule = linear_interpolation(
        [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
    )

    return train_model(model, opt, lr_schedule; backend, batchsize, seed, epochs, bfloat16)
end
```
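To make the schedule concrete, a quick check of the values `lr_schedule` produces at the defaults (`epochs=25`, `lr_max=0.05`), with knots at 2/5 and 4/5 of the run:

```julia
using Interpolations

epochs, lr_max = 25, 0.05
sched = linear_interpolation(
    [0, epochs * 2 ÷ 5, epochs * 4 ÷ 5, epochs + 1], [0, lr_max, lr_max / 20, 0]
)

sched(0.0)   # 0.0    : warmup starts from zero
sched(10.0)  # 0.05   : peaks at lr_max at epoch 10
sched(20.0)  # 0.0025 : down to lr_max / 20 at epoch 20
sched(26.0)  # 0.0    : decays to zero just past the final epoch
```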
