Skip to content

Commit

Permalink
Standard testing infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 6, 2024
1 parent 578e41b commit b5d34b2
Showing 1 changed file with 36 additions and 31 deletions.
67 changes: 36 additions & 31 deletions test/interpreter/s2s_forward_mode_ad.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,52 @@
using MistyClosures
using Mooncake
using Test
using Core.Compiler: SSAValue
const CC = Core.Compiler
# using MistyClosures
# using Mooncake
# using Test
# using Core.Compiler: SSAValue
# const CC = Core.Compiler

@testset "s2s_forward_mode_ad" begin
test_cases = collect(enumerate(TestResources.generate_test_functions()))[1:1]
@testset "$(_typeof((f, x...)))" for (n, (interface_only, _, _, f, x...)) in test_cases
test_cases = collect(enumerate(TestResources.generate_test_functions()))[3:4]
@testset "$(_typeof((f, x...)))" for (n, (int_only, pf, _, f, x...)) in test_cases
sig = _typeof((f, x...))
@info "$n: $sig"
TestUtils.test_rule(
Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false
Xoshiro(123456),
f,
x...;
perf_flag=pf,
interface_only=int_only,
is_primitive=false,
forward=true,
)
end
end

# #=
# x, dx = 2.0, 3.0
# xdual = Dual(x, dx)

#=
x, dx = 2.0, 3.0
xdual = Dual(x, dx)
# sin_rule = build_frule(sin, x)
# ydual = sin_rule(zero_dual(sin), xdual)

sin_rule = build_frule(sin, x)
ydual = sin_rule(zero_dual(sin), xdual)
# @test primal(ydual) == sin(x)
# @test tangent(ydual) == dx * cos(x)
# =#

@test primal(ydual) == sin(x)
@test tangent(ydual) == dx * cos(x)
=#
# function func2(x)
# if x > 0.0
# y = sin(x)
# else
# y = cos(x)
# end
# return y
# end

function func2(x)
if x > 0.0
y = sin(x)
else
y = cos(x)
end
return y
end

x = 1.0
xdual = Dual(1.0, 2.0)
# x = 1.0
# xdual = Dual(1.0, 2.0)

ir = Base.code_ircode(func2, (typeof(x),))[1][1]
# ir = Base.code_ircode(func2, (typeof(x),))[1][1]

func_rule = build_frule(func2, x)
ydual = func_rule(zero_dual(func2), xdual)
# func_rule = build_frule(func2, x)
# ydual = func_rule(zero_dual(func2), xdual)

2cos(1)
# 2cos(1)

0 comments on commit b5d34b2

Please sign in to comment.