
Commit cbd31d9
Correction for matrix-valued M sampling; it now works properly.
+ a fix for the nightly version
Kertoo committed May 23, 2024
1 parent 426c411 commit cbd31d9
Showing 3 changed files with 77 additions and 47 deletions.
8 changes: 7 additions & 1 deletion Manifest.toml
@@ -2,7 +2,7 @@

 julia_version = "1.10.2"
 manifest_format = "2.0"
-project_hash = "18400d7847a4a2861ac95edee3b678e21d64e1d1"
+project_hash = "3e7eef2be031d31b50e076fd1144ec85dd5aef91"

 [[deps.ADTypes]]
 git-tree-sha1 = "daf26bbdec60d9ca1c0003b70f389d821ddb4224"
@@ -933,6 +933,12 @@ git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5"
 uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e"
 version = "0.3.4"

+[[deps.StyledStrings]]
+deps = ["TOML"]
+git-tree-sha1 = "d108f10ee6a0f3955ed73b1ddf7dda09b7de6b21"
+uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
+version = "1.0.1"
+
 [[deps.SuiteSparse]]
 deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
 uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1 change: 1 addition & 0 deletions Project.toml
@@ -18,6 +18,7 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"

 [compat]
 CSV = "0.10"
115 changes: 69 additions & 46 deletions src/binomial_model_sampling_covariates.jl
@@ -1,3 +1,9 @@
+
+
+
+# TODO:: This is waaaaay too slow
+# TODO:: This goes to infinity when sampling; fix it!
+#        probably a math mistake in the norm_const calc
 function sample_M_matrix_variate_cond_random_eff(n, N, m, γ₀, γ₁, γ₂, u, M)
     # compute ξ, μ
     μ = reduce(vcat, [(k .^ γ₁)' for k in eachrow(N)]) .* reduce(vcat, [(k .^ γ₂)' for k in eachrow(n ./ N)])
@@ -8,68 +14,85 @@ function sample_M_matrix_variate_cond_random_eff(n, N, m, γ₀, γ₁, γ₂, u

     norm_const = zeros(size(M)...)

-    for i in axes(M, 1)
-        for j in axes(M, 2)
-            # add computable first term
-            res = BigFloat(0)
-            # compute minimum
-            Wⱼ = minimum(M[i, setdiff(axes(M, 2), j)])
-            if Wⱼ >= m[i, j]
-                for x in m[i, j]:Wⱼ
-                    k = collect(0:x)
-                    ress = k * log(ξ₀[i]) - k * sum(log.(ξ[i, :])) - logfactorial.(x .- k)
-                    ress .-= [sum(logfactorial.(M[i, setdiff(axes(M, 2), j)] .- t)) for t in k]
-                    ress = sum(ress)
-                    # this is a scalar
-                    ress *= exp(BigFloat(
-                        x * log(ξ[i, j]) + x * log(1 - u[i] * μ[i, j]) +
-                        sum(logfactorial.(M[i, setdiff(axes(M, 2), j)])) +
-                        logfactorial(x) - logfactorial(x - m[i, j])
-                    ))
-                    res += ress
-                end # end for x
-            end # end if
-            # add second hypergeometric term
-            for k in 0:Wⱼ
-                Tⱼ = max(Wⱼ, m[i, j])
-                ress = exp(BigFloat(
-                    (Tⱼ + 1) * log(ξ[i, j] * (1 - u[i] * μ[i, j])) - logfactorial(Tⱼ + 1 - m[i, j]) - logfactorial(Tⱼ + 1 - k) +
-                    log(ξ₀[i]) - sum(log.(ξ[i, :])) + sum(log.(M[i, setdiff(axes(M, 2), j)])) - sum(log.(M[i, setdiff(axes(M, 2), j)] .- k))
-                ))
-                ress *= pFq((1, Tⱼ + 2), (Tⱼ - m[i, j] + 2, Tⱼ + 2 - k), ξ[i, j] * (1 - u[i] * μ[i, j]))
-                res += ress
-            end # end for k
-            # copy needed?
-            norm_const[i, j] = copy(res)
-        end # end for j
-    end # end for i
+    setprecision(100) do
+        for i in axes(M, 1)
+            for j in axes(M, 2)
+                # add computable first term
+                res = BigFloat(0)
+                # compute minimum
+                M_no_j = M[i, setdiff(axes(M, 2), j)]
+                Wⱼ = minimum(M_no_j)
+                if Wⱼ >= m[i, j]
+                    for x in m[i, j]:Wⱼ
+                        k = collect(0:x)
+                        ress = BigFloat.(k .* log(ξ₀[i]) .- k .* sum(log.(ξ[i, :])) .- logfactorial.(x .- k))
+                        ress .-= [sum(logfactorial.(M_no_j .- t) .+ logfactorial(t)) for t in k]
+                        ress = sum(exp.(ress))
+                        ress *= exp(BigFloat(
+                            x * log(ξ[i, j]) + x * log(1 - u[i] * μ[i, j]) +
+                            sum(logfactorial.(M_no_j)) + logfactorial(x) - logfactorial(x - m[i, j])
+                        ))
+                        res += ress
+                    end # end for x
+                end # end if
+                # add second hypergeometric term
+                for k in 0:Wⱼ
+                    Tⱼ = max(Wⱼ + 1, m[i, j])
+                    ress = exp(BigFloat(
+                        Tⱼ * log(ξ[i, j] * (1 - u[i] * μ[i, j])) - logfactorial(Tⱼ - m[i, j]) - logfactorial(Tⱼ - k) +
+                        k * log(ξ₀[i]) - k * sum(log.(ξ[i, :])) + sum(logfactorial.(M_no_j)) -
+                        sum(logfactorial.(M_no_j .- k)) - sum(logfactorial.(M_no_j * 0 .+ k))
+                    ))
+                    ress *= pFq((1, Tⱼ + 1), (Tⱼ - m[i, j] + 1, Tⱼ + 1 - k), ξ[i, j] * (1 - u[i] * μ[i, j]))
+                    res += ress
+                end # end for k
+                # copy needed?
+                norm_const[i, j] = copy(res)
+            end # end for j
+        end # end for i
+    end # end setprecision

     # Conditional mass function at point x + m[i, j] for M[i, j]
     function mass_function(x, i, j)
-        mm = min(minimum(M[i, setdiff(axes(M, 2), j)]), x)
         res = BigFloat(0)

-        for z in 1:mm
-            xxx = sum(logfactorial.(M[i, setdiff(axes(M, 2), j)]) - logfactorial.(M[i, setdiff(axes(M, 2), j)] .- z))
-            xxx += z * log.(ξ₀[i]) - z * sum(log.(ξ[i, :]))
-            res += exp(BigFloat(xxx))
-        end # end for
+        setprecision(100) do
+            mm = min(minimum(M[i, setdiff(axes(M, 2), j)]), x + m[i, j])
+            for z in 1:mm
+                xxx = sum(BigFloat.(logfactorial.(M[i, setdiff(axes(M, 2), j)]) - logfactorial.(M[i, setdiff(axes(M, 2), j)] .- z) .- logfactorial(z)))
+                xxx += logfactorial(x + m[i, j]) - logfactorial(x + m[i, j] - z)
+                xxx += z * log(ξ₀[i]) - z * sum(log.(ξ[i, :]))
+                res += exp(xxx)
+            end # end for

-        # compute normalization constant
-        res *= exp(BigFloat(x * log(ξ[i, j] * (1 - μ[i, j] * u[i])) - logfactorial(x) - norm_const[i, j]))
+            res *= exp(BigFloat((x + m[i, j]) * log(ξ[i, j] * (1 - μ[i, j] * u[i])) - logfactorial(x)))
+        end # end setprecision
+
         res
     end # end function

     ## Sampling
-    #println(norm_const)
-    #error("abc")
     U = rand(size(M)...)
     x = zeros(Int, size(M)...)
-    xxx = reshape([mass_function(x[i, j], i, j) for j in axes(x, 2) for i in axes(x, 1)], size(x))
-    cond = xxx .<= U
+    # not-normalized CDF
+    CDF = reshape([mass_function(x[i, j], i, j) for j in axes(x, 2) for i in axes(x, 1)], size(x))
+    # multiplying U by norm_const is less prone to numerical errors than dividing every mass
+    U .*= norm_const
+    cond = CDF .<= U

+    # TODO:: Computationally this is the bottleneck; think about jumping by, say, 10 or even 100,
+    #        since on the test code there are updates of over 1300
     while any(cond)
+        #println("--------")
+        #println(sum(cond))
+        #println(maximum(x))
+        #println([minimum(U[cond] .- CDF[cond]), maximum(U[cond] .- CDF[cond])])
         x[cond] .+= 1

-        xxx += reshape([mass_function(x[i, j], i, j) for j in axes(x, 2) for i in axes(x, 1)], size(x))
-        cond = xxx .< U
+        CDF[cond] += reshape([cond[i, j] ? mass_function(x[i, j], i, j) : 0 for j in axes(x, 2) for i in axes(x, 1)], size(x))[cond]
+        cond = CDF .< U
     end

     # return x + m
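For anyone tracing the norm_const change: the parameter pattern in the new call, pFq((1, Tⱼ + 1), (Tⱼ - m[i, j] + 1, Tⱼ + 1 - k), z) with z = ξ[i, j] * (1 - u[i] * μ[i, j]), matches the standard closed form for the tail of the series being summed. This is a reading of the code, not a derivation supplied by the commit:

\[
\sum_{x \ge T} \frac{x!\, z^{x}}{(x - m)!\,(x - k)!}
  \;=\; \frac{T!\, z^{T}}{(T - m)!\,(T - k)!}\;
  {}_2F_2\!\bigl(1,\; T + 1;\; T - m + 1,\; T - k + 1;\; z\bigr),
\]

which holds because with a unit numerator parameter the n-th series term reduces to (T + 1)_n z^n / ((T - m + 1)_n (T - k + 1)_n), exactly the ratio of the x = T + n summand to the x = T one. The TODO at the top of the file suggests the constant may still be off, so treat this as a reference point rather than a verified derivation.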
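The core of the numerical fix is doing each term's arithmetic in log space as a BigFloat and pinning the working precision with setprecision(100). A minimal self-contained sketch of that pattern with toy numbers — only logfactorial from SpecialFunctions is shared with the package; unsafe_term and stable_term are hypothetical names:

using SpecialFunctions: logfactorial

# Terms like ξ^x / x! overflow Float64 long before their sum does;
# computing x*log(ξ) - logfactorial(x) and exponentiating as a BigFloat
# at a fixed precision keeps every term finite.
unsafe_term(x, ξ) = ξ^x / factorial(big(x))                       # Inf for large x
stable_term(x, ξ) = exp(BigFloat(x * log(ξ) - logfactorial(x)))   # always finite

setprecision(100) do
    println(unsafe_term(2000, 300.0))                    # Inf: 300.0^2000 overflows
    println(sum(stable_term(x, 300.0) for x in 0:2000))  # ≈ exp(300) ≈ 1.94e130
end

Fixing the precision at 100 bits (below the BigFloat default of 256) also caps the cost of each operation, which matters in a loop this hot.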

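The sampling block is an inverse-CDF walk driven by unnormalized masses: instead of dividing every mass_function value by norm_const, the commit multiplies the uniform draws U by it once. A toy version of the same scheme against a Poisson mass — sample_unnormalized and unnorm_mass are hypothetical names, not package API:

using SpecialFunctions: logfactorial

# Draw from a pmf known only up to its normalizing constant by scaling
# the uniform draw with that constant instead of normalizing each mass.
function sample_unnormalized(λ)
    unnorm_mass(x) = exp(x * log(λ) - logfactorial(x))  # ∝ Poisson(λ) mass
    U = rand() * exp(λ)          # exp(λ) plays the role of norm_const
    x = 0
    cdf = unnorm_mass(0)
    while cdf < U                # walk the support until the CDF passes U
        x += 1
        cdf += unnorm_mass(x)
    end
    return x
end

sample_unnormalized(4.0)         # a Poisson(4) draw

Scaling U once per entry avoids re-normalizing inside every mass evaluation, and the while-loop/condition structure mirrors the matrix-valued version in the diff.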
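The remaining TODO asks about jumping the walk by 10 or 100 steps at a time, since the test code sees counters climb past 1300. One possible shape for that idea — an assumption about the intent, not code from the commit, reusing the toy mass above:

using SpecialFunctions: logfactorial

# Same inverse-CDF walk, but accumulating a whole block of masses per
# iteration and locating the exact crossing inside the block with cumsum.
function sample_unnormalized_blocked(λ; block = 32)
    unnorm_mass(x) = exp(x * log(λ) - logfactorial(x))
    U = rand() * exp(λ)
    x, cdf = 0, unnorm_mass(0)
    while cdf < U
        partial = cumsum([unnorm_mass(x + k) for k in 1:block])
        if cdf + partial[end] < U    # U lies beyond this block: take it whole
            cdf += partial[end]
            x += block
        else                         # U lies inside: find the first crossing
            return x + findfirst(>=(U - cdf), partial)
        end
    end
    return x
end

Each iteration still evaluates every mass once, so the win is fewer passes over the cond mask and fewer full-matrix reshape calls, not fewer mass evaluations.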
