Skip to content

Commit

Permalink
fastmath
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 10, 2024
1 parent e52859b commit 49f2157
Showing 1 changed file with 98 additions and 87 deletions.
185 changes: 98 additions & 87 deletions src/rrules/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function rrule!!(
::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P}
) where {P<:IEEEFloat}
yp = Base.FastMath.exp2_fast(primal(x))
exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2)
exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * P(log(2))
return CoDual(yp, NoFData()), exp2_fast_pb!!
end

Expand All @@ -21,7 +21,7 @@ function rrule!!(
::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P}
) where {P<:IEEEFloat}
yp = Base.FastMath.exp10_fast(primal(x))
exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10)
exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * P(log(10))
return CoDual(yp, NoFData()), exp2_fast_pb!!
end

Expand All @@ -36,96 +36,107 @@ end
@zero_adjoint MinimalCtx Tuple{typeof(log),Int}

function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath})
test_cases = Any[
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, 0.5),
(false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, 0.5),
(false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, 5.0),
(false, :stability_and_allocs, nothing, Base.FastMath.sincos, 3.0),
]
test_cases = reduce(
vcat,
map([Float64, Float32]) do P
return Any[
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)),
(false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)),
(false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)),
(false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)),
]
end,
)
memory = Any[]
return test_cases, memory
end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:fastmath})
test_cases = Any[
(false, :allocs, nothing, Base.FastMath.abs2_fast, -5.0),
(false, :allocs, nothing, Base.FastMath.abs_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.acos_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.acosh_fast, 1.2),
(false, :allocs, nothing, Base.FastMath.add_fast, 1.0, 2.0),
(false, :allocs, nothing, Base.FastMath.angle_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.asin_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.asinh_fast, 1.3),
(false, :allocs, nothing, Base.FastMath.atan_fast, 5.4),
(false, :allocs, nothing, Base.FastMath.atanh_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.cbrt_fast, 0.4),
(false, :allocs, nothing, Base.FastMath.cis_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.cmp_fast, 0.5, 0.4),
(false, :allocs, nothing, Base.FastMath.conj_fast, 0.4),
(false, :allocs, nothing, Base.FastMath.conj_fast, ComplexF64(0.5, 0.4)),
(false, :allocs, nothing, Base.FastMath.cos_fast, 0.4),
(false, :allocs, nothing, Base.FastMath.cosh_fast, 0.3),
(false, :allocs, nothing, Base.FastMath.div_fast, 5.0, 1.1),
(false, :allocs, nothing, Base.FastMath.eq_fast, 5.5, 5.5),
(false, :allocs, nothing, Base.FastMath.eq_fast, 5.5, 5.4),
(false, :allocs, nothing, Base.FastMath.expm1_fast, 5.4),
(false, :allocs, nothing, Base.FastMath.ge_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.ge_fast, 4.0, 5.0),
(false, :allocs, nothing, Base.FastMath.gt_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.gt_fast, 4.0, 5.0),
(false, :allocs, nothing, Base.FastMath.hypot_fast, 5.1, 3.2),
(false, :allocs, nothing, Base.FastMath.inv_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.isfinite_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.isinf_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.isnan_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.issubnormal_fast, 0.3),
(false, :allocs, nothing, Base.FastMath.le_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.log10_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.log1p_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.log2_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.log_fast, 0.5),
(false, :allocs, nothing, Base.FastMath.lt_fast, 0.5, 4.0),
(false, :allocs, nothing, Base.FastMath.lt_fast, 5.0, 0.4),
(false, :allocs, nothing, Base.FastMath.max_fast, 5.0, 4.0),
(
false,
:none,
nothing,
Base.FastMath.maximum!_fast,
sin,
[0.0, 0.0],
[5.0 4.0; 3.0 2.0],
),
(false, :allocs, nothing, Base.FastMath.maximum_fast, [5.0, 4.0, 3.0]),
(false, :allocs, nothing, Base.FastMath.min_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.min_fast, 4.0, 5.0),
(
false,
:none,
nothing,
Base.FastMath.minimum!_fast,
sin,
[0.0, 0.0],
[5.0 4.0; 3.0 2.0],
),
(false, :allocs, nothing, Base.FastMath.minimum_fast, [5.0, 3.0, 4.0]),
(false, :allocs, nothing, Base.FastMath.minmax_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.mul_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.ne_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.pow_fast, 5.0, 2.0),
# (:allocs, Base.FastMath.pow_fast, 5.0, 2), # errors -- ADD A RULE FOR ME!
# (:allocs, Base.FastMath.rem_fast, 5.0, 2.0), # error -- ADD A RULE FOR ME!
(false, :allocs, nothing, Base.FastMath.sign_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.sign_fast, -5.0),
(false, :allocs, nothing, Base.FastMath.sin_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.cos_fast, 4.0),
(false, :allocs, nothing, Base.FastMath.sincos_fast, 4.0),
(false, :allocs, nothing, Base.FastMath.sinh_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.sqrt_fast, 5.0),
(false, :allocs, nothing, Base.FastMath.sub_fast, 5.0, 4.0),
(false, :allocs, nothing, Base.FastMath.tan_fast, 4.0),
(false, :allocs, nothing, Base.FastMath.tanh_fast, 0.5),
]
test_cases = reduce(
vcat,
map([Float64, Float32]) do P
C = P === Float64 ? ComplexF64 : ComplexF32
return Any[
(false, :allocs, nothing, Base.FastMath.abs2_fast, P(-5.0)),
(false, :allocs, nothing, Base.FastMath.abs_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.acos_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.acosh_fast, P(1.2)),
(false, :allocs, nothing, Base.FastMath.add_fast, P(1.0), P(2.0)),
(false, :allocs, nothing, Base.FastMath.angle_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.asin_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.asinh_fast, P(1.3)),
(false, :allocs, nothing, Base.FastMath.atan_fast, P(5.4)),
(false, :allocs, nothing, Base.FastMath.atanh_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.cbrt_fast, P(0.4)),
(false, :allocs, nothing, Base.FastMath.cis_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.cmp_fast, P(0.5), P(0.4)),
(false, :allocs, nothing, Base.FastMath.conj_fast, P(0.4)),
(false, :allocs, nothing, Base.FastMath.conj_fast, C(0.5, 0.4)),
(false, :allocs, nothing, Base.FastMath.cos_fast, P(0.4)),
(false, :allocs, nothing, Base.FastMath.cosh_fast, P(0.3)),
(false, :allocs, nothing, Base.FastMath.div_fast, P(5.0), P(1.1)),
(false, :allocs, nothing, Base.FastMath.eq_fast, P(5.5), P(5.5)),
(false, :allocs, nothing, Base.FastMath.eq_fast, P(5.5), P(5.4)),
(false, :allocs, nothing, Base.FastMath.expm1_fast, P(5.4)),
(false, :allocs, nothing, Base.FastMath.ge_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.ge_fast, P(4.0), P(5.0)),
(false, :allocs, nothing, Base.FastMath.gt_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.gt_fast, P(4.0), P(5.0)),
(false, :allocs, nothing, Base.FastMath.hypot_fast, P(5.1), P(3.2)),
(false, :allocs, nothing, Base.FastMath.inv_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.isfinite_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.isinf_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.isnan_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.issubnormal_fast, P(0.3)),
(false, :allocs, nothing, Base.FastMath.le_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.log10_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.log1p_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.log2_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.log_fast, P(0.5)),
(false, :allocs, nothing, Base.FastMath.lt_fast, P(0.5), P(4.0)),
(false, :allocs, nothing, Base.FastMath.lt_fast, P(5.0), P(0.4)),
(false, :allocs, nothing, Base.FastMath.max_fast, P(5.0), P(4.0)),
(
false,
:none,
nothing,
Base.FastMath.maximum!_fast,
sin,
P.([0.0, 0.0]),
P.([5.0 4.0; 3.0 2.0]),
),
(false, :allocs, nothing, Base.FastMath.maximum_fast, P.([5.0, 4.0, 3.0])),
(false, :allocs, nothing, Base.FastMath.min_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.min_fast, P(4.0), P(5.0)),
(
false,
:none,
nothing,
Base.FastMath.minimum!_fast,
sin,
P.([0.0, 0.0]),
P.([5.0 4.0; 3.0 2.0]),
),
(false, :allocs, nothing, Base.FastMath.minimum_fast, P.([5.0, 3.0, 4.0])),
(false, :allocs, nothing, Base.FastMath.minmax_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.mul_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.ne_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.pow_fast, P(5.0), P(2.0)),
# (:allocs, Base.FastMath.pow_fast, P(5.0), 2), # errors -- NEEDS RULE!
# (:allocs, Base.FastMath.rem_fast, P(5.0), P(2.0)), # error -- NEEDS RULE!
(false, :allocs, nothing, Base.FastMath.sign_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.sign_fast, P(-5.0)),
(false, :allocs, nothing, Base.FastMath.sin_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.cos_fast, P(4.0)),
(false, :allocs, nothing, Base.FastMath.sincos_fast, P(4.0)),
(false, :allocs, nothing, Base.FastMath.sinh_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.sqrt_fast, P(5.0)),
(false, :allocs, nothing, Base.FastMath.sub_fast, P(5.0), P(4.0)),
(false, :allocs, nothing, Base.FastMath.tan_fast, P(4.0)),
(false, :allocs, nothing, Base.FastMath.tanh_fast, P(0.5)),
]
end,
)
memory = Any[]
return test_cases, memory
end

0 comments on commit 49f2157

Please sign in to comment.