Merge pull request #170 from FluxML/faulty-test
Fix faulty test leading to "scalar indexing issue" on GPU builds
ablaom authored Jun 24, 2021
2 parents 4bcde16 + 74c5e8b commit 5e5d698
Showing 2 changed files with 16 additions and 10 deletions.
20 changes: 12 additions & 8 deletions src/core.jl
@@ -44,13 +44,17 @@ vector of arrays where the last dimension is the batch size. `y`
 is the target observation vector.
 """
 function train!(loss_func, parameters, optimiser, X, y)
-    for i=1:length(X)
+    n_batches = length(y)
+    training_loss = zero(Float32)
+    for i in 1:n_batches
         gs = Flux.gradient(parameters) do
-            training_loss = loss_func(X[i], y[i])
-            return training_loss
+            batch_loss = loss_func(X[i], y[i])
+            training_loss += batch_loss
+            return batch_loss
         end
         Flux.update!(optimiser, parameters, gs)
     end
+    return training_loss/n_batches
 end
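Note: with this change `train!` returns the mean of the per-batch training losses for the epoch. A minimal sketch of calling it directly, assuming this commit; the toy model, batches, and optimiser below are illustrative and not taken from the package or its tests:

    using Flux, MLJFlux

    model = Flux.Dense(5, 1)                              # toy model (illustrative)
    loss_func(x, y) = Flux.Losses.mse(vec(model(x)), y)

    X = [rand(Float32, 5, 10) for _ in 1:3]               # three input batches, batch size 10
    y = [rand(Float32, 10) for _ in 1:3]                  # matching target batches

    mean_loss = MLJFlux.train!(loss_func, Flux.params(model),
                               Flux.Descent(0.01), X, y)  # one epoch; returns the mean batch loss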


@@ -124,15 +128,15 @@ function fit!(chain, optimiser, loss, epochs,
     loss_func(x, y) = loss(chain(x), y)
 
     # initiate history:
-    prev_loss = mean(loss_func(X[i], y[i]) for i=1:length(X))
-    history = [prev_loss,]
+    n_batches = length(y)
+
+    training_loss = mean(loss_func(X[i], y[i]) for i in 1:n_batches)
+    history = [training_loss,]
 
     for i in 1:epochs
         # We're taking data in a Flux-fashion.
         # @show i rand()
-        train!(loss_func, Flux.params(chain), optimiser, X, y)
-        current_loss =
-            mean(loss_func(X[i], y[i]) for i=1:length(X))
+        current_loss = train!(loss_func, Flux.params(chain), optimiser, X, y)
         verbosity < 2 ||
             @info "Loss is $(round(current_loss; sigdigits=4))"
         push!(history, current_loss)
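Taken together, the two hunks above mean `fit!` no longer re-evaluates `loss_func` over every batch after each epoch; it records the mean loss that `train!` now returns. A condensed paraphrase of the resulting loop, with the verbosity handling elided (`mean` is `Statistics.mean`):

    n_batches = length(y)
    history = [mean(loss_func(X[i], y[i]) for i in 1:n_batches)]    # pre-training loss
    for i in 1:epochs
        current_loss = train!(loss_func, Flux.params(chain), optimiser, X, y)
        push!(history, current_loss)                                # one entry per epoch
    end
    # length(history) == epochs + 1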
6 changes: 4 additions & 2 deletions test/builders.jl
@@ -1,5 +1,5 @@
 # to control chain initialization:
-myinit(n, m) = reshape(float(1:n*m), n , m)
+myinit(n, m) = reshape(convert(Vector{Float32}, (1:n*m)), n , m)
 
 mutable struct TESTBuilder <: MLJFlux.Builder end
 MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
@@ -10,7 +10,8 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
 # data:
 n = 100
 d = 5
-Xmat = rand(Float64, n, d)
+Xmat = rand(Float32, n, d)
+# Xmat = fill(one(Float32), n, d)
 X = MLJBase.table(Xmat);
 y = X.x1 .^2 + X.x2 .* X.x3 - 4 * X.x4
 
@@ -31,6 +32,7 @@ MLJFlux.build(builder::TESTBuilder, rng, n_in, n_out) =
 pretraining_yhat = Xmat*chain0' |> vec
 @test y isa Vector && pretraining_yhat isa Vector
 pretraining_loss_by_hand = MLJBase.l2(pretraining_yhat, y) |> mean
+mean(((pretraining_yhat - y).^2)[1:2])
 
 # compare:
 @test pretraining_loss ≈ pretraining_loss_by_hand
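For reference, the substance of the `myinit` and `Xmat` changes is the element type: `float(1:n*m)` yields `Float64`, whereas the new definitions keep the test data `Float32` (per the commit message, the original test triggered a scalar-indexing issue on GPU builds). A quick check of the two initializers; `myinit_old` and `myinit_new` are illustrative names, not from the test file:

    myinit_old(n, m) = reshape(float(1:n*m), n, m)                       # old definition
    myinit_new(n, m) = reshape(convert(Vector{Float32}, (1:n*m)), n, m)  # new definition

    eltype(myinit_old(3, 2))   # Float64
    eltype(myinit_new(3, 2))   # Float32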
