-
-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
43e0e9d
commit 09e9551
Showing
3 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[deps] | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" | ||
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" | ||
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" | ||
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" | ||
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" | ||
|
||
[compat] | ||
CUDA = "5" | ||
Flux = "0.14" | ||
MLDatasets = "0.7" | ||
Metalhead = "0.9" | ||
Optimisers = "0.3" | ||
ProgressMeter = "1.9" | ||
TimerOutputs = "0.5" | ||
UnicodePlots = "3.6" | ||
cuDNN = "1.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
|
||
using CUDA, cuDNN | ||
using Flux | ||
using Flux: logitcrossentropy, onecold, onehotbatch | ||
using Metalhead | ||
using MLDatasets | ||
using Optimisers | ||
using ProgressMeter | ||
using TimerOutputs | ||
using UnicodePlots | ||
|
||
include("tooling.jl") | ||
|
||
# Benchmark configuration.
epochs = 45
batchsize = 1000
device = gpu
allow_skips = true   # when false, abort the sweep on the first model that fails to train

# Data and bookkeeping shared by all models.
train_loader, test_loader, labels = load_cifar10(; batchsize)
nlabels = length(labels)
firstbatch = first(first(train_loader))
imsize = size(firstbatch)[1:2]

to = TimerOutput()

# these should all be the smallest variant of each that is tested in `/test`
modelstrings = (
    "AlexNet()",
    "VGG(11, batchnorm=true)",
    "SqueezeNet()",
    "ResNet(18)",
    "WideResNet(50)",
    "ResNeXt(50, cardinality=32, base_width=4)",
    "SEResNet(18)",
    "SEResNeXt(50, cardinality=32, base_width=4)",
    "Res2Net(50, base_width=26, scale=4)",
    "Res2NeXt(50)",
    "GoogLeNet(batchnorm=true)",
    "DenseNet(121)",
    "Inceptionv3()",
    "Inceptionv4()",
    "InceptionResNetv2()",
    "Xception()",
    "MobileNetv1(0.5)",
    "MobileNetv2(0.5)",
    "MobileNetv3(:small, 0.5)",
    "MNASNet(MNASNet, 0.5)",  # NOTE(review): first arg looks wrong — confirm against Metalhead's MNASNet signature
    "EfficientNet(:b0)",
    "EfficientNetv2(:small)",
    "ConvMixer(:small)",
    "ConvNeXt(:small)",
    # "MLPMixer()", # found no tests
    # "ResMLP()", # found no tests
    # "gMLP()", # found no tests
    "ViT(:tiny)",
    "UNet()"
)

for (i, modstring) in enumerate(modelstrings)
    @timeit to "$modstring" begin
        @info "Evaluating $i/$(length(modelstrings)) $modstring"
        # First eval times compilation + construction; second times construction alone.
        @timeit to "First Load" eval(Meta.parse(modstring))
        @timeit to "Second Load" model = eval(Meta.parse(modstring))
        ok = @timeit to "Training" train(model,
                                         train_loader,
                                         test_loader;
                                         to,
                                         device)
        # `train` signals failure with `false`; any other result (including
        # `nothing`) counts as success. The original `train(...) || ...` form
        # threw `TypeError: non-boolean (Nothing)` on the success path.
        if ok === false && !allow_skips
            break
        end
    end
end
print_timer(to; sortby = :firstexec)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
"""
    loss_and_accuracy(data_loader, model, device; limit = nothing)

Compute the mean logit-crossentropy loss and classification accuracy of
`model` over `data_loader`, moving each batch to `device` first.

When `limit` is an integer, at most `limit` batches are evaluated (the
original check broke only after processing `limit + 1` batches).

Returns `(loss, accuracy)`, both averaged over the samples actually seen.
"""
function loss_and_accuracy(data_loader, model, device; limit = nothing)
    acc = 0
    ls = 0.0f0
    num = 0              # number of samples seen
    nbatches = 0         # number of batches processed (for `limit`)
    for (x, y) in data_loader
        x, y = x |> device, y |> device
        ŷ = model(x)
        ls += logitcrossentropy(ŷ, y, agg=sum)
        acc += sum(onecold(ŷ) .== onecold(y))
        num += size(x)[end]
        nbatches += 1
        # Stop once exactly `limit` batches have been consumed.
        limit !== nothing && nbatches >= limit && break
    end
    return ls / num, acc / num
end
|
||
"""
    load_cifar10(; batchsize = 1000)

Load the CIFAR-10 train and test splits and wrap them in `Flux.DataLoader`s
with one-hot-encoded labels. The training loader shuffles; the test loader
does not.

Returns `(train_loader, test_loader, labels)` where `labels` is the vector
of class names shared by both splits.
"""
function load_cifar10(; batchsize=1000)
    @info "loading CIFAR-10 dataset"
    trainset = CIFAR10(split=:train)
    testset = CIFAR10(split=:test)
    xtrain, ytrain = trainset[:]
    xtest, ytest = testset[:]
    @assert trainset.metadata["class_names"] == testset.metadata["class_names"]
    labels = trainset.metadata["class_names"]

    # CIFAR10 label indices seem to be zero-indexed
    ytrain .+= 1
    ytest .+= 1

    ohb_train = Flux.onehotbatch(ytrain, eachindex(labels))
    ohb_test = Flux.onehotbatch(ytest, eachindex(labels))

    trainloader = Flux.DataLoader((data=xtrain, labels=ohb_train); batchsize, shuffle=true)
    testloader = Flux.DataLoader((data=xtest, labels=ohb_test); batchsize)

    return trainloader, testloader, labels
end
|
||
"""
    _train(model, train_loader, test_loader; epochs, device, limit, gpu_gc, gpu_stats, show_plots, to)

Train `model` on `train_loader` with Adam for `epochs` epochs, evaluating on
both loaders after each epoch.

# Keywords
- `device`: function that moves arrays/models to the target device (e.g. `gpu`).
- `limit`: if set, process at most `limit` batches per epoch and per evaluation.
- `gpu_gc`: force a GC after every epoch when training on GPU.
- `gpu_stats`: print CUDA memory stats after every batch.
- `show_plots`: draw Unicode loss/accuracy curves after every epoch.
- `to`: `TimerOutput` used for `@timeit` sections.

Returns `true` on completion so callers can use the result as a success flag
(previously it fell off the loop and returned `nothing`, which broke boolean
use at the call site).
"""
function _train(model, train_loader, test_loader; epochs = 45, device = gpu, limit=nothing, gpu_gc=true, gpu_stats=false, show_plots=false, to=TimerOutput())

    model = model |> device

    opt = Optimisers.Adam()
    state = Optimisers.setup(opt, model)

    train_loss_hist, train_acc_hist = Float64[], Float64[]
    test_loss_hist, test_acc_hist = Float64[], Float64[]

    @info "starting training"
    for epoch in 1:epochs
        nbatches = 0
        @showprogress "training epoch $epoch/$epochs" for (x, y) in train_loader
            x, y = x |> device, y |> device
            @timeit to "batch step" begin
                # Differentiate w.r.t. the model only; `x` is a fixed argument.
                gs, _ = gradient(model, x) do m, _x
                    logitcrossentropy(m(_x), y)
                end
                state, model = Optimisers.update(state, model, gs)
            end

            device === gpu && gpu_stats && CUDA.memory_status()
            # Stop once exactly `limit` batches have been consumed
            # (the original post-check processed `limit + 1` batches).
            nbatches += 1
            limit !== nothing && nbatches >= limit && break
        end

        @info "epoch $epoch complete. Testing..."
        train_loss, train_acc = loss_and_accuracy(train_loader, model, device; limit)
        @timeit to "testing" test_loss, test_acc = loss_and_accuracy(test_loader, model, device; limit)
        @info map(x->round(x, digits=3), (; train_loss, train_acc, test_loss, test_acc))

        if show_plots
            push!(train_loss_hist, train_loss); push!(train_acc_hist, train_acc);
            push!(test_loss_hist, test_loss); push!(test_acc_hist, test_acc);
            plt = lineplot(1:epoch, train_loss_hist, name = "train_loss", xlabel="epoch", ylabel="loss")
            lineplot!(plt, 1:epoch, test_loss_hist, name = "test_loss")
            display(plt)
            plt = lineplot(1:epoch, train_acc_hist, name = "train_acc", xlabel="epoch", ylabel="acc")
            lineplot!(plt, 1:epoch, test_acc_hist, name = "test_acc")
            display(plt)
        end
        if device === gpu && gpu_gc
            GC.gc() # GPU will OOM without this
        end
    end
    return true
end
|
||
# because Flux stacktraces are ludicrously big on <1.10 so don't show them
"""
    train(args...; kwargs...)

Wrapper around `_train` that converts any exception into a compact logged
error plus a `false` return value, instead of a huge stacktrace.

Returns `true` on success, `false` on failure. (The original placed an
unconditional `rethrow()` first in the `catch`, making the logging and the
`return false` unreachable and defeating the wrapper's stated purpose.)
"""
function train(args...; kwargs...)
    try
        _train(args...; kwargs...)
        return true   # explicit Bool so callers can branch on the result
    catch ex
        println()
        @error sprint(showerror, ex)
        GC.gc()   # release device memory held by the failed run
        return false
    end
end