Skip to content

Commit

Permalink
Merge branch 'main' into tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Apr 20, 2024
2 parents 80bf0d8 + 9d6b969 commit 01bdf0d
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 79 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ EnzymeSpecialFunctionsExt = "SpecialFunctions"
[compat]
CEnum = "0.4, 0.5"
EnzymeCore = "0.7"
Enzyme_jll = "0.0.103"
Enzyme_jll = "0.0.104"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1"
Libdl = "1.6"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ Enzyme.autodiff(Reverse, f, Active(1.2), Const(Vector{Float64}(undef, 1)), Const
Passing in a dupliacted (e.g. differentiable) variable for `tmp` now leads to the correct answer.

```jldoctest storage
Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), Vector{Float64}(undef, 1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5)
Enzyme.autodiff(Reverse, f, Active(1.2), Duplicated(Vector{Float64}(undef, 1), zeros(1)), Const(1), Const(5)) # Correct (returns 10.367999999999999 == 1.2^4 * 5)
# output
Expand Down Expand Up @@ -539,4 +539,4 @@ For `d/d conj(z)`, $\frac12 \left( [u_x + i v_x] + i [u_y + i v_y] \right) = \fr
3.1 + 2.7im
```

Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space).
Note: when writing rules for complex scalar functions, in reverse mode one needs to conjugate the differential return, and similarly the true result will be the conjugate of that value (in essence you can think of reverse-mode AD as working in the conjugate space).
4 changes: 3 additions & 1 deletion lib/EnzymeTestUtils/test/test_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ end
elseif TT <: NamedTuple
x = (a=randn(T), b=randn(T))
else # TT <: TestStruct
VERSION v"1.8" && (@test_skip false; continue)
if VERSION <= v"1.8" && Tx == BatchDuplicated
continue
end
x = TestStruct(randn(T, 5), randn(T))
end
atol = rtol = sqrt(eps(real(T)))
Expand Down
208 changes: 137 additions & 71 deletions src/compiler.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function is_primitive_func(@nospecialize(TT))
end
end

if ft == typeof(Base.inv)
if ft == typeof(Base.inv) || ft == typeof(Base.sqrt)
if TT <: Tuple{ft, Complex{Float32}} || TT <: Tuple{ft, Complex{Float64}}
return true
end
Expand Down
7 changes: 5 additions & 2 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,18 @@ end
function EnzymeRules.inactive(::typeof(Core.kwfunc), args...)
return nothing
end
function EnzymeRules.inactive(::typeof(Random.rand), args...)
function EnzymeRules.inactive(::typeof(Random.rand), ::Random.AbstractRNG, ::Random.Sampler)
return nothing
end
function EnzymeRules.inactive(::typeof(Random.rand!), args...)
function EnzymeRules.inactive(::typeof(Random.rand!), ::Random.AbstractRNG, ::Random.Sampler, ::AbstractArray)
return nothing
end
function EnzymeRules.inactive(::typeof(Random.randn), args...)
return nothing
end
function EnzymeRules.inactive(::typeof(Random.randn!), args...)
return nothing
end
function EnzymeRules.inactive(::typeof(Random.default_rng), args...)
return nothing
end
Expand Down
2 changes: 1 addition & 1 deletion src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function jlcall_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
if in(name, ("ijl_f_getfield", "jl_f_getfield"))
return common_jl_getfield_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR)
end
if in(name, ("ijl_s_getfield", "jl_s_getfield"))
if in(name, ("ijl_f_setfield", "jl_f_setfield"))
return common_setfield_augfwd(2, B, orig, gutils, normalR, shadowR, tapeR)
end
if in(name, ("ijl_f__apply_iterate", "jl_f__apply_iterate"))
Expand Down
15 changes: 15 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using FiniteDifferences
using LinearAlgebra
using SparseArrays
using Test
import Random

struct TPair
a::Float64
Expand Down Expand Up @@ -432,4 +433,18 @@ end
end
end
end

@testset "rand and randn rules" begin
# Distributed as x + unit normal + uniform
struct MyDistribution
x::Float64
end

Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn() + rand()
Random.rand(d::MyDistribution) = rand(Random.default_rng(), d)

# Outer rand should be differentiated through, and inner rand and randn should be ignored.
@test autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) == ((1.0,),)
end

end # InternalRules
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ make3() = (1.0, 2.0, 3.0)
test_scalar(x->rem(x, 1), 0.7)
test_scalar(x->rem2pi(x,RoundDown), 0.7)
test_scalar(x->fma(x,x+1,x/3), 2.3)
test_scalar(sqrt, 1.7+2.1im)

@test autodiff(Forward, sincos, Duplicated(1.0, 1.0))[1][1] cos(1.0)

Expand Down

0 comments on commit 01bdf0d

Please sign in to comment.