|
| 1 | +using ConcreteStructs, DataAugmentation, ImageShow, Lux, MLDatasets, MLUtils, OneHotArrays, |
| 2 | + Printf, ProgressTables, Random, BFloat16s |
| 3 | +using Reactant, LuxCUDA |
| 4 | + |
"""
    TensorDataset(dataset, transform)

Wrapper pairing a raw `dataset` with the `transform` applied to its images when the
wrapper is indexed. Each field's type is captured by a type parameter, so every instance
has a concrete type.
"""
struct TensorDataset{D, F}
    dataset::D
    transform::F
end
| 9 | + |
# The number of observations is delegated to the wrapped dataset.
function Base.length(ds::TensorDataset)
    return length(ds.dataset)
end
| 11 | + |
# Batched indexing: returns `(x, y)`, where `x` stacks the transformed images selected by
# `idxs` and `y` is the one-hot encoding of their targets over the labels `0:9`.
function Base.getindex(ds::TensorDataset, idxs::Union{Vector{<:Integer}, AbstractRange})
    # `convert2image` yields one array for the whole selection; slice it per observation
    # along the third dimension and wrap each slice as a DataAugmentation `Image`.
    raw_images = convert2image(ds.dataset, idxs)
    items = map(Image, eachslice(raw_images; dims=3))

    labels = onehotbatch(ds.dataset.targets[idxs], 0:9)

    # Run the transform on every item and unwrap the result down to its parent array.
    batch = stack(items) do item
        return parent(itemdata(apply(ds.transform, item)))
    end
    return batch, labels
end
| 17 | + |
"""
    get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...)

Build the CIFAR-10 train and test `DataLoader`s with features of element type `T`.

The training pipeline applies a random resize-crop to `32×32` and an optional flip before
tensor conversion, normalisation and conversion to `T`; the test pipeline only performs
the last three steps. The train loader shuffles, the test loader does not. Extra `kwargs`
are forwarded to both `DataLoader` constructors.

Returns the tuple `(trainloader, testloader)`.
"""
function get_cifar10_dataloaders(::Type{T}, batchsize; kwargs...) where {T}
    # Per-channel statistics used for normalisation, converted to the requested eltype.
    channel_mean = T.((0.4914, 0.4822, 0.4465))
    channel_std = T.((0.2471, 0.2435, 0.2616))

    augmentation = RandomResizeCrop((32, 32)) |> Maybe(FlipX{2}())
    train_transform = augmentation |>
                      ImageToTensor() |>
                      Normalize(channel_mean, channel_std) |>
                      ToEltype(T)

    test_transform = ImageToTensor() |>
                     Normalize(channel_mean, channel_std) |>
                     ToEltype(T)

    # Both loaders differ only in split, transform and whether they shuffle.
    function make_loader(split, transform, shuffle)
        data = TensorDataset(CIFAR10(; Tx=T, split), transform)
        return DataLoader(data; batchsize, shuffle, kwargs...)
    end

    trainloader = make_loader(:train, train_transform, true)
    testloader = make_loader(:test, test_transform, false)
    return trainloader, testloader
end
| 38 | + |
"""
    accuracy(model, ps, st, dataloader)

Return the fraction of observations in `dataloader` for which the class predicted by
`model(x, ps, st)` matches the target class. Targets and model outputs are moved to the
CPU before being decoded with `onecold`.
"""
function accuracy(model, ps, st, dataloader)
    to_cpu = cpu_device()
    n_correct = 0
    n_seen = 0
    for (x, y) in dataloader
        truth = onecold(to_cpu(y))
        # The model call returns a collection whose first element is the output.
        guess = onecold(to_cpu(first(model(x, ps, st))))
        n_correct += count(truth .== guess)
        n_seen += length(truth)
    end
    return n_correct / n_seen
end
| 50 | + |
"""
    get_accelerator_device(backend::String)

Map a backend name to a device object.

- `"gpu_if_available"`: result of `gpu_device()`.
- `"gpu"`: result of `gpu_device(; force=true)`.
- `"reactant"`: result of `reactant_device(; force=true)`.
- `"cpu"`: result of `cpu_device()`.

Any other string raises an error listing the valid options.
"""
function get_accelerator_device(backend::String)
    backend == "gpu_if_available" && return gpu_device()
    backend == "gpu" && return gpu_device(; force=true)
    backend == "reactant" && return reactant_device(; force=true)
    backend == "cpu" && return cpu_device()
    return error("Invalid backend: $(backend). Valid Options are: `gpu_if_available`, `gpu`, \
                  `reactant`, and `cpu`.")
end
| 65 | + |
"""
    train_model(model, opt, scheduler=nothing; backend, batchsize=512, seed=1234,
                epochs=25, bfloat16=false)

Train `model` on CIFAR-10 with optimiser `opt`, printing one progress-table row per epoch
(epoch, learning rate, train accuracy, test accuracy, training time).

# Arguments
- `model`: model handed to `Lux.setup` and `Training.TrainState`.
- `opt`: optimiser handed to `Training.TrainState`.
- `scheduler`: optional callable mapping a fractional epoch position to a learning rate.
  When it is `nothing` the learning rate is never adjusted and is reported as `NaN`.

# Keywords
- `backend`: required; one of the names accepted by `get_accelerator_device`. With
  `"reactant"` the model is compiled once for evaluation and `AutoEnzyme()` is used for
  differentiation; otherwise `AutoZygote()` is used.
- `batchsize`: batch size of both data loaders.
- `seed`: seed applied to the default RNG before anything else happens.
- `epochs`: number of passes over the training loader.
- `bfloat16`: use `BFloat16` instead of `Float32` for data and parameters.

Raises an error as soon as a training step produces a `NaN` loss. Returns `nothing`.
"""
function train_model(
        model, opt, scheduler=nothing;
        backend::String, batchsize::Int=512, seed::Int=1234, epochs::Int=25,
        bfloat16::Bool=false
)
    rng = Random.default_rng()
    Random.seed!(rng, seed)

    # Precision is described three ways: the Lux converter, the Julia eltype, and a label.
    if bfloat16
        to_precision, eltype_jl, precision_name = bf16, BFloat16, "BFloat16"
    else
        to_precision, eltype_jl, precision_name = f32, Float32, "Float32"
    end
    @printf "[Info] Using %s precision\n" precision_name

    device = get_accelerator_device(backend)
    use_reactant = backend == "reactant"

    # Reactant loaders are built with `partial=false`; other backends keep the defaults.
    loader_kwargs = device isa ReactantDevice ? (; partial=false) : ()
    trainloader, testloader = device(
        get_cifar10_dataloaders(eltype_jl, batchsize; loader_kwargs...)
    )

    ps, st = device(to_precision(Lux.setup(rng, model)))
    train_state = Training.TrainState(model, ps, st, opt)

    adtype = use_reactant ? AutoEnzyme() : AutoZygote()

    # Evaluation uses a Reactant-compiled model when requested, compiled against a random
    # input with the shape of one training batch and test-mode states.
    model_compiled = if use_reactant
        batch_dims = size(first(trainloader)[1])
        x_ra = device(rand(rng, eltype_jl, batch_dims))
        @printf "[Info] Compiling model with Reactant.jl\n"
        st_test = Lux.testmode(st)
        compiled = Reactant.compile(model, (x_ra, ps, st_test))
        @printf "[Info] Model compiled!\n"
        compiled
    else
        model
    end

    loss_fn = CrossEntropyLoss(; logits=Val(true))

    column_names = [
        "Epoch", "Learning Rate", "Train Accuracy (%)", "Test Accuracy (%)", "Time (s)"
    ]
    n_columns = length(column_names)
    pt = ProgressTable(;
        header=column_names,
        widths=fill(24, n_columns),
        format=["%3d", "%.6f", "%.6f", "%.6f", "%.6f"],
        color=[:normal, :normal, :blue, :blue, :normal],
        border=true,
        alignment=fill(:center, n_columns)
    )

    @printf "[Info] Training model\n"
    initialize(pt)

    for epoch in 1:epochs
        epoch_start = time()
        lr = 0
        for (step, batch) in enumerate(trainloader)
            if scheduler !== nothing
                # Position within the schedule, measured in (fractional) epochs.
                position = (epoch - 1) + (step + 1) / length(trainloader)
                lr = scheduler(position)
                train_state = Optimisers.adjust!(train_state, lr)
            end
            x, y = batch
            _, loss, _, train_state = Training.single_train_step!(
                adtype, loss_fn, (x, y), train_state
            )
            if isnan(loss)
                error("NaN loss encountered!")
            end
        end
        # Only the optimisation loop is timed; evaluation below is excluded.
        elapsed = time() - epoch_start

        train_acc = 100 * accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), trainloader
        )
        test_acc = 100 * accuracy(
            model_compiled, train_state.parameters,
            Lux.testmode(train_state.states), testloader
        )

        # Without a scheduler there is no learning rate to report.
        if scheduler === nothing
            lr = NaN32
        end
        next(pt, [epoch, lr, train_acc, test_acc, elapsed])
    end

    finalize(pt)
    @printf "[Info] Finished training\n"
    return nothing
end
0 commit comments