
Commit

update gpu devices
prbzrg committed Sep 23, 2024
1 parent f7813e9 commit db2bed9
Showing 3 changed files with 44 additions and 33 deletions.
33 changes: 19 additions & 14 deletions src/exts/mlj_ext/core_cond_icnf.jl
@@ -26,19 +26,21 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
     X, Y = XY
     x = collect(transpose(MLJModelInterface.matrix(X)))
     y = collect(transpose(MLJModelInterface.matrix(Y)))
+    tdev = if model.m.resource isa ComputationalResources.CUDALibs
+        Lux.gpu_device()
+    else
+        Lux.cpu_device()
+    end
     ps, st = LuxCore.setup(model.m.rng, model.m)
     ps = ComponentArrays.ComponentArray(ps)
-    if model.m.resource isa ComputationalResources.CUDALibs
-        gdev = Lux.gpu_device()
-        x = gdev(x)
-        y = gdev(y)
-        ps = gdev(ps)
-        st = gdev(st)
-    end
-    if model.m.compute_mode isa VectorMode
-        data = MLUtils.DataLoader((x, y); batchsize = -1, shuffle = true, partial = true)
+    x = tdev(x)
+    y = tdev(y)
+    ps = tdev(ps)
+    st = tdev(st)
+    data = if model.m.compute_mode isa VectorMode
+        MLUtils.DataLoader((x, y); batchsize = -1, shuffle = true, partial = true)
     elseif model.m.compute_mode isa MatrixMode
-        data = MLUtils.DataLoader(
+        MLUtils.DataLoader(
             (x, y);
             batchsize = if model.use_batch
                 model.batch_size
@@ -51,6 +53,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
     else
         error("Not Implemented")
     end
+    data = tdev(data)
     optfunc = SciMLBase.OptimizationFunction(
         make_opt_loss(model.m, TrainMode(), st, model.loss),
         model.adtype,
@@ -89,11 +92,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
     Xnew, Ynew = XYnew
     xnew = collect(transpose(MLJModelInterface.matrix(Xnew)))
     ynew = collect(transpose(MLJModelInterface.matrix(Ynew)))
-    if model.m.resource isa ComputationalResources.CUDALibs
-        gdev = Lux.gpu_device()
-        xnew = gdev(xnew)
-        ynew = gdev(ynew)
+    tdev = if model.m.resource isa ComputationalResources.CUDALibs
+        Lux.gpu_device()
+    else
+        Lux.cpu_device()
     end
+    xnew = tdev(xnew)
+    ynew = tdev(ynew)
     (ps, st) = fitresult

     tst = @timed if model.m.compute_mode isa VectorMode
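For reference, this is the device-selection pattern the commit introduces, pulled out of the diff into a minimal, self-contained sketch. The `select_device` helper and the toy `Lux.Dense` layer are illustrative stand-ins, not part of the package; the real models and field names are the ones shown in the diff above.

using ComponentArrays, ComputationalResources, Lux, LuxCore, Random

# Map the MLJ computational resource to a Lux transfer device.
# gpu_device() is expected to fall back to the CPU device (with a warning)
# when no functional GPU backend is loaded, so this stays safe on CPU-only machines.
select_device(resource) =
    resource isa ComputationalResources.CUDALibs ? Lux.gpu_device() : Lux.cpu_device()

rng = Random.default_rng()
model = Lux.Dense(4 => 2)                 # illustrative stand-in for an ICNF model
tdev = select_device(ComputationalResources.CPU1())

ps, st = LuxCore.setup(rng, model)
ps = ComponentArrays.ComponentArray(ps)

# One call per object moves data, parameters, and state to the chosen device.
x = tdev(rand(Float32, 4, 8))
ps = tdev(ps)
st = tdev(st)

y, _ = model(x, ps, st)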
29 changes: 17 additions & 12 deletions src/exts/mlj_ext/core_icnf.jl
@@ -24,18 +24,20 @@ end

 function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
     x = collect(transpose(MLJModelInterface.matrix(X)))
+    tdev = if model.m.resource isa ComputationalResources.CUDALibs
+        Lux.gpu_device()
+    else
+        Lux.cpu_device()
+    end
     ps, st = LuxCore.setup(model.m.rng, model.m)
     ps = ComponentArrays.ComponentArray(ps)
-    if model.m.resource isa ComputationalResources.CUDALibs
-        gdev = Lux.gpu_device()
-        x = gdev(x)
-        ps = gdev(ps)
-        st = gdev(st)
-    end
-    if model.m.compute_mode isa VectorMode
-        data = MLUtils.DataLoader((x,); batchsize = -1, shuffle = true, partial = true)
+    x = tdev(x)
+    ps = tdev(ps)
+    st = tdev(st)
+    data = if model.m.compute_mode isa VectorMode
+        MLUtils.DataLoader((x,); batchsize = -1, shuffle = true, partial = true)
     elseif model.m.compute_mode isa MatrixMode
-        data = MLUtils.DataLoader(
+        MLUtils.DataLoader(
             (x,);
             batchsize = if model.use_batch
                 model.batch_size
@@ -48,6 +50,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
     else
         error("Not Implemented")
     end
+    data = tdev(data)
     optfunc = SciMLBase.OptimizationFunction(
         make_opt_loss(model.m, TrainMode(), st, model.loss),
         model.adtype,
@@ -85,10 +88,12 @@ end

 function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
     xnew = collect(transpose(MLJModelInterface.matrix(Xnew)))
-    if model.m.resource isa ComputationalResources.CUDALibs
-        gdev = Lux.gpu_device()
-        xnew = gdev(xnew)
+    tdev = if model.m.resource isa ComputationalResources.CUDALibs
+        Lux.gpu_device()
+    else
+        Lux.cpu_device()
     end
+    xnew = tdev(xnew)
     (ps, st) = fitresult

     tst = @timed if model.m.compute_mode isa VectorMode
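Both fit methods now also push the DataLoader itself through the device (`data = tdev(data)`). A small sketch of that step, assuming MLUtils and a CPU device; with the MLDataDevices-backed devices Lux exposes, applying a device to a loader is expected to hand back batches already on that device as they are iterated.

using Lux, MLUtils

x = rand(Float32, 4, 32)
tdev = Lux.cpu_device()   # Lux.gpu_device() when a GPU resource is requested

data = MLUtils.DataLoader((x,); batchsize = 8, shuffle = true, partial = true)
data = tdev(data)         # batches now come back already on `tdev`

for (xb,) in data
    @assert size(xb) == (4, 8)   # 32 samples, batchsize 8 → only full batches
end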
15 changes: 8 additions & 7 deletions test/call_tests.jl
@@ -52,7 +52,6 @@ Test.@testset "Call Tests" begin
     resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()]
     if CUDA.has_cuda_gpu() && USE_GPU
         push!(resources, ComputationalResources.CUDALibs())
-        gdev = Lux.gpu_device()
    end

    Test.@testset "$resource | $data_type | $compute_mode | inplace = $inplace | aug & steer = $aug_steer | nvars = $nvars | $omode | $mt" for resource in
@@ -145,13 +144,15 @@ Test.@testset "Call Tests" begin
         )
         ps, st = Lux.setup(icnf.rng, icnf)
         ps = ComponentArrays.ComponentArray(ps)
-        if resource isa ComputationalResources.CUDALibs
-            r = gdev(r)
-            r2 = gdev(r2)
-            ps = gdev(ps)
-            st = gdev(st)
+        tdev = if resource isa ComputationalResources.CUDALibs
+            Lux.gpu_device()
+        else
+            Lux.cpu_device()
         end
-
+        r = tdev(r)
+        r2 = tdev(r2)
+        ps = tdev(ps)
+        st = tdev(st)
         if mt <: Union{
             ContinuousNormalizingFlows.CondRNODE,
             ContinuousNormalizingFlows.CondFFJORD,
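On the test side, each case now derives its own transfer device from the resource under test instead of sharing a single `gdev`. A hedged sketch of the gating, assuming CUDA.jl is loaded and using a placeholder for the suite's `USE_GPU` flag:

using ComputationalResources, CUDA, Lux

USE_GPU = true   # placeholder for the test-suite flag defined elsewhere

# Only offer the GPU resource when a CUDA device is actually present.
resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()]
if CUDA.has_cuda_gpu() && USE_GPU
    push!(resources, ComputationalResources.CUDALibs())
end

# Each test case picks its transfer device from the resource under test.
for resource in resources
    tdev = resource isa ComputationalResources.CUDALibs ? Lux.gpu_device() : Lux.cpu_device()
    @info "selected device" resource tdev
end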
