Improvements for the Generic Bayesian Neural Network? #334

Open
Gregliest opened this issue Nov 10, 2021 · 1 comment


In the Bayesian Neural Network tutorial, there is code for a "Generic Bayesian Neural Network" that rebuilds the Flux model from the newly sampled parameters. The tutorial also has a fully hard-coded version. I think a multiple-dispatch approach might make the code more flexible and easier to extend to other layer types. Here's a prototype:

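using Flux, Zygote

# Rebuild a Dense layer from the leading entries of the flat vector `params`.
# Without `import Base: similar`, this is a separate function, not an extension of `Base.similar`.
function similar(layer::Dense, params)
  i = 1

  # Weights come first and are reshaped to the layer's weight matrix shape.
  weights = params[i:i + length(layer.weight) - 1]
  w = reshape(weights, size(layer.weight))
  i += length(layer.weight)

  # The bias entries follow directly after the weights.
  b = params[i:i + length(layer.bias) - 1]
  i += length(layer.bias)

  return Dense(w, b, layer.σ)
end

# Rebuild each layer of a Chain in turn, advancing the offset `i` by that layer's parameter count.
# Dispatch is recursive, so a Chain nested inside a Chain goes through this same method.
function similar(model::Chain, params)
  layers = Array{Any}(undef, length(model))
  i = 1
  for (layerIndex, layer) in enumerate(model)
    layers[layerIndex] = similar(layer, params[i:end])
    i += numParams(Flux.params(layer))
  end

  return Chain(layers...)
end

# Total number of scalar parameters across all parameter arrays.
function numParams(params::Zygote.Params)
  return sum(length, params; init=0)
end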

It can then be used like this:

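using Turing

model = Chain(
  Dense(2, 3, tanh),
  Dense(3, 2, tanh),
  Dense(2, 1, σ)
)

# Total number of parameters to sample
n = numParams(Flux.params(model))
# Prior scale: standard deviation sqrt(1 / alpha)
alpha = 0.09
sig = sqrt(1.0 / alpha)

@model function bayesNN(input, labels)
  # Zero-mean Gaussian prior; the second argument is a vector of standard deviations
  nn_params ~ MvNormal(zeros(n), fill(sig, n))

  # Rebuild the network from the sampled parameter vector
  nn = similar(model, nn_params)
  predictions = nn(input)

  for i in 1:length(labels)
    labels[i] ~ Bernoulli(predictions[i])
  end
end

# `xs` and `labels` are the training data from the tutorial (not defined here)
numSamples = 5000
chain = sample(bayesNN(xs, labels), HMC(0.05, 4), numSamples)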
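
As a quick sanity check on the numbers in this example (a worked calculation added for illustration, assuming each Dense(in, out) layer contributes in × out weights plus out biases):

$$
n = (2\cdot 3 + 3) + (3\cdot 2 + 2) + (2\cdot 1 + 1) = 9 + 8 + 3 = 20,
\qquad
\text{sig} = \sqrt{1/\alpha} = \sqrt{1/0.09} = \frac{1}{0.3} \approx 3.33
$$

So the sampler works with a 20-dimensional parameter vector under a zero-mean Gaussian prior with standard deviation of about 3.33 per entry.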

This code throws a MethodError if similar hasn't been implemented for a Flux component in the model, and a BoundsError if there aren't enough parameters. Surplus parameters are silently ignored, so more error checking could be added. The approach extends to new layer types by adding one method per type, and it handles nested models (like a Chain within a Chain). If this approach is of interest, I can add it to the tutorial.
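
To make the extension and error-checking points concrete, here is a minimal sketch in plain Julia. It deliberately avoids Flux so that it runs on its own: `Affine`, `Scale`, and `Stack` are hypothetical stand-ins for Dense, some other layer type, and Chain, and the names `rebuild`, `nparams`, and `rebuild_checked` are assumptions made for illustration, not part of the prototype above.

```julia
# Hypothetical stand-in types (not Flux types) so the sketch runs without packages.
struct Affine{M<:AbstractMatrix,V<:AbstractVector}   # plays the role of Dense
    W::M
    b::V
end

struct Scale{V<:AbstractVector}                      # a second layer type
    s::V
end

struct Stack                                         # plays the role of Chain
    layers::Vector{Any}
end

# Parameter count: one method per type, recursive for containers.
nparams(l::Affine) = length(l.W) + length(l.b)
nparams(l::Scale) = length(l.s)
nparams(m::Stack) = sum(nparams, m.layers; init=0)

# Rebuild: one method per type, each reading from the start of θ.
function rebuild(l::Affine, θ::AbstractVector)
    nw = length(l.W)
    W = reshape(θ[1:nw], size(l.W))
    b = θ[nw+1:nw+length(l.b)]
    return Affine(W, b)
end

rebuild(l::Scale, θ::AbstractVector) = Scale(θ[1:length(l.s)])

function rebuild(m::Stack, θ::AbstractVector)
    layers = Any[]
    offset = 0
    for layer in m.layers
        k = nparams(layer)
        # Hand each layer exactly its own slice; nested Stacks recurse here.
        push!(layers, rebuild(layer, view(θ, offset+1:offset+k)))
        offset += k
    end
    return Stack(layers)
end

# Strict entry point: reject vectors that are too short or too long.
function rebuild_checked(m, θ::AbstractVector)
    n = nparams(m)
    length(θ) == n ||
        throw(DimensionMismatch("expected $n parameters, got $(length(θ))"))
    return rebuild(m, θ)
end

# Usage: a Stack nested inside a Stack.
inner = Stack(Any[Affine(zeros(2, 2), zeros(2)), Scale(ones(2))])
model = Stack(Any[Affine(zeros(2, 3), zeros(2)), inner])
θ = collect(1.0:nparams(model))
new_model = rebuild_checked(model, θ)
```

Supporting another layer type in this sketch means adding one `nparams` method and one `rebuild` method; the container method does not change.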

One caveat is that similar is not a perfect analogue for this situation. The docs for Base.similar say that it creates an uninitialized array, whereas here I'm creating a fully initialized model with the new params.
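
For comparison, Flux itself provides `Flux.destructure`, which also rebuilds a model from a flat parameter vector and avoids the naming question. The sketch below assumes a Flux version that accepts the `Dense(2, 3, tanh)` constructor form used in this issue; the variable names `θ`, `re`, and `new_model` are chosen here for illustration.

```julia
using Flux

model = Chain(
  Dense(2, 3, tanh),
  Dense(3, 2, tanh),
  Dense(2, 1, σ)
)

# Flatten all parameters into one vector; `re` rebuilds a model from such a vector.
θ, re = Flux.destructure(model)

# Rebuild with new values; layer structure and activations are taken from `model`.
new_model = re(zeros(Float32, length(θ)))
```

This covers every layer type Flux knows how to flatten, while the dispatch-based prototype keeps the per-layer unpacking logic explicit.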

@yebai transferred this issue from TuringLang/Turing.jl on Nov 14, 2022

@shravanngoswamii (Member)

@yebai Is this worth changing in the Bayesian NN tutorial, or are the current changes fine?
