Skip to content

Commit

Permalink
Test LogExpFunctions on F32
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 10, 2024
1 parent eacf751 commit 9e3ec58
Showing 1 changed file with 37 additions and 33 deletions.
70 changes: 37 additions & 33 deletions test/integration_testing/logexpfunctions/logexpfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,43 @@ using Mooncake.TestUtils: test_rule
sr(n::Int) = StableRNG(n)

@testset "logexpfunctions" begin
@testset for (perf_flag, f, x...) in [
(:allocs, xlogx, 1.1),
(:allocs, xlogy, 0.3, 1.2),
(:allocs, xlog1py, 0.3, -0.5),
(:allocs, xexpx, -0.5),
(:allocs, xexpy, 1.0, -0.7),
(:allocs, logistic, 0.5),
(:allocs, logit, 0.3),
(:allocs, logcosh, 1.5),
(:allocs, logabssinh, 0.3),
(:allocs, log1psq, 0.3),
(:allocs, log1pexp, 0.1),
(:allocs, log1mexp, -0.5),
(:allocs, log2mexp, 0.1),
(:allocs, logexpm1, 0.1),
(:allocs, log1pmx, -0.95),
(:allocs, logmxp1, 0.02),
(:allocs, logaddexp, -0.5, 0.4),
(:allocs, logsubexp, -0.5, -5.0),
(:allocs, logsumexp, randn(sr(1), 5)),
(:allocs, logsumexp, randn(sr(2), 5, 4)),
(:allocs, logsumexp, randn(sr(3), 5, 4, 3)),
(:none, x -> logsumexp(x; dims=1), randn(sr(4), 5, 4)),
(:none, x -> logsumexp(x; dims=2), randn(sr(5), 5, 4)),
(:none, logsumexp!, rand(sr(6), 5), randn(sr(7), 5, 4)),
(:none, softmax, randn(sr(7), 10)),
(:allocs, cloglog, 0.5),
(:allocs, cexpexp, -0.3),
(:allocs, loglogistic, 0.5),
(:allocs, logitexp, -0.3),
(:allocs, log1mlogistic, -0.9),
(:allocs, logit1mexp, -0.6),
]
@testset for (perf_flag, f, x...) in vcat(
map([Float64, Float32]) do P
return Any[
(:allocs, xlogx, P(1.1)),
(:allocs, xlogy, P(0.3), P(1.2)),
(:allocs, xlog1py, P(0.3), -P(0.5)),
(:allocs, xexpx, -P(0.5)),
(:allocs, xexpy, P(1.0), -P(0.7)),
(:allocs, logistic, P(0.5)),
(:allocs, logit, P(0.3)),
(:allocs, logcosh, P(1.5)),
(:allocs, logabssinh, P(0.3)),
(:allocs, log1psq, P(0.3)),
(:allocs, log1pexp, P(0.1)),
(:allocs, log1mexp, -P(0.5)),
(:allocs, log2mexp, P(0.1)),
(:allocs, logexpm1, P(0.1)),
(:allocs, log1pmx, -P(0.95)),
(:allocs, logmxp1, P(0.02)),
(:allocs, logaddexp, -P(0.5), P(0.4)),
(:allocs, logsubexp, -P(0.5), -P(5.0)),
(:allocs, logsumexp, randn(sr(1), P, 5)),
(:allocs, logsumexp, randn(sr(2), P, 5, 4)),
(:allocs, logsumexp, randn(sr(3), P, 5, 4, 3)),
(:none, x -> logsumexp(x; dims=1), randn(sr(4), P, 5, 4)),
(:none, x -> logsumexp(x; dims=2), randn(sr(5), P, 5, 4)),
(:none, logsumexp!, rand(sr(6), 5), randn(sr(7), P, 5, 4)),
(:none, softmax, randn(sr(7), P, 10)),
(:allocs, cloglog, P(0.5)),
(:allocs, cexpexp, -P(0.3)),
(:allocs, loglogistic, P(0.5)),
(:allocs, logitexp, -P(0.3)),
(:allocs, log1mlogistic, -P(0.9)),
(:allocs, logit1mexp, -P(0.6)),
]
end...,
)
test_rule(sr(123456), f, x...; perf_flag, is_primitive=false)
end
end

0 comments on commit 9e3ec58

Please sign in to comment.