From 49f21577d11238a7eefbd0610bf5b62b9cc93bb6 Mon Sep 17 00:00:00 2001
From: willtebbutt
Date: Tue, 10 Dec 2024 14:25:24 +0000
Subject: [PATCH] fastmath

---
# Reviewer notes (commentary placed between the `---` line and the diffstat;
# it is not part of any hunk below):
#
# * Pullbacks: `log(2)` / `log(10)` are wrapped as `P(log(2))` / `P(log(10))`.
#   `log` of an Int literal is a Float64, so for P == Float32 the old product
#   `dy * yp * log(2)` would promote to Float64; the conversion keeps the
#   product in P. Presumably this is what lets the Float32 test cases added
#   below pass -- confirm against the test suite.
# * The pullback inside the exp10_fast rule is still named `exp2_fast_pb!!`
#   on both the removed and the added line. It is a name local to that rule,
#   so behaviour is unaffected, but `exp10_fast_pb!!` would read better in a
#   follow-up.
# * Both test-case generators now build their list once per element type by
#   mapping over [Float64, Float32] and concatenating with reduce(vcat, ...).
#   Every literal argument is converted with `P(...)` / `P.(...)`, and the
#   complex `conj_fast` case uses `C`, chosen as ComplexF64 or ComplexF32 to
#   match P.
# * NOTE(review): the `le_fast` case passes a single argument while the other
#   comparison cases (`lt_fast`, `ge_fast`, `gt_fast`, ...) pass two. This is
#   unchanged from the removed line -- confirm it is intentional.
# * The two commented-out cases (`pow_fast` with an Int exponent, `rem_fast`)
#   remain disabled; only their trailing remark is shortened.
# * NOTE(review): line breaks and indentation in this copy were restored from
#   a whitespace-collapsed version and checked against the hunk line counts;
#   re-check against commit 49f21577 before applying.
#
 src/rrules/fastmath.jl | 185 ++++++++++++++++++++++-------------------
 1 file changed, 98 insertions(+), 87 deletions(-)

diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl
index 935c0947f..6cfa99751 100644
--- a/src/rrules/fastmath.jl
+++ b/src/rrules/fastmath.jl
@@ -12,7 +12,7 @@ function rrule!!(
     ::CoDual{typeof(Base.FastMath.exp2_fast)}, x::CoDual{P}
 ) where {P<:IEEEFloat}
     yp = Base.FastMath.exp2_fast(primal(x))
-    exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(2)
+    exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * P(log(2))
     return CoDual(yp, NoFData()), exp2_fast_pb!!
 end
 
@@ -21,7 +21,7 @@ function rrule!!(
     ::CoDual{typeof(Base.FastMath.exp10_fast)}, x::CoDual{P}
 ) where {P<:IEEEFloat}
     yp = Base.FastMath.exp10_fast(primal(x))
-    exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * log(10)
+    exp2_fast_pb!!(dy::P) = NoRData(), dy * yp * P(log(10))
     return CoDual(yp, NoFData()), exp2_fast_pb!!
 end
 
@@ -36,96 +36,107 @@ end
 @zero_adjoint MinimalCtx Tuple{typeof(log),Int}
 
 function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath})
-    test_cases = Any[
-        (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, 0.5),
-        (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, 0.5),
-        (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, 5.0),
-        (false, :stability_and_allocs, nothing, Base.FastMath.sincos, 3.0),
-    ]
+    test_cases = reduce(
+        vcat,
+        map([Float64, Float32]) do P
+            return Any[
+                (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)),
+                (false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)),
+                (false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)),
+                (false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)),
+            ]
+        end,
+    )
     memory = Any[]
     return test_cases, memory
 end
 
 function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:fastmath})
-    test_cases = Any[
-        (false, :allocs, nothing, Base.FastMath.abs2_fast, -5.0),
-        (false, :allocs, nothing, Base.FastMath.abs_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.acos_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.acosh_fast, 1.2),
-        (false, :allocs, nothing, Base.FastMath.add_fast, 1.0, 2.0),
-        (false, :allocs, nothing, Base.FastMath.angle_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.asin_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.asinh_fast, 1.3),
-        (false, :allocs, nothing, Base.FastMath.atan_fast, 5.4),
-        (false, :allocs, nothing, Base.FastMath.atanh_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.cbrt_fast, 0.4),
-        (false, :allocs, nothing, Base.FastMath.cis_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.cmp_fast, 0.5, 0.4),
-        (false, :allocs, nothing, Base.FastMath.conj_fast, 0.4),
-        (false, :allocs, nothing, Base.FastMath.conj_fast, ComplexF64(0.5, 0.4)),
-        (false, :allocs, nothing, Base.FastMath.cos_fast, 0.4),
-        (false, :allocs, nothing, Base.FastMath.cosh_fast, 0.3),
-        (false, :allocs, nothing, Base.FastMath.div_fast, 5.0, 1.1),
-        (false, :allocs, nothing, Base.FastMath.eq_fast, 5.5, 5.5),
-        (false, :allocs, nothing, Base.FastMath.eq_fast, 5.5, 5.4),
-        (false, :allocs, nothing, Base.FastMath.expm1_fast, 5.4),
-        (false, :allocs, nothing, Base.FastMath.ge_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.ge_fast, 4.0, 5.0),
-        (false, :allocs, nothing, Base.FastMath.gt_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.gt_fast, 4.0, 5.0),
-        (false, :allocs, nothing, Base.FastMath.hypot_fast, 5.1, 3.2),
-        (false, :allocs, nothing, Base.FastMath.inv_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.isfinite_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.isinf_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.isnan_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.issubnormal_fast, 0.3),
-        (false, :allocs, nothing, Base.FastMath.le_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.log10_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.log1p_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.log2_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.log_fast, 0.5),
-        (false, :allocs, nothing, Base.FastMath.lt_fast, 0.5, 4.0),
-        (false, :allocs, nothing, Base.FastMath.lt_fast, 5.0, 0.4),
-        (false, :allocs, nothing, Base.FastMath.max_fast, 5.0, 4.0),
-        (
-            false,
-            :none,
-            nothing,
-            Base.FastMath.maximum!_fast,
-            sin,
-            [0.0, 0.0],
-            [5.0 4.0; 3.0 2.0],
-        ),
-        (false, :allocs, nothing, Base.FastMath.maximum_fast, [5.0, 4.0, 3.0]),
-        (false, :allocs, nothing, Base.FastMath.min_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.min_fast, 4.0, 5.0),
-        (
-            false,
-            :none,
-            nothing,
-            Base.FastMath.minimum!_fast,
-            sin,
-            [0.0, 0.0],
-            [5.0 4.0; 3.0 2.0],
-        ),
-        (false, :allocs, nothing, Base.FastMath.minimum_fast, [5.0, 3.0, 4.0]),
-        (false, :allocs, nothing, Base.FastMath.minmax_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.mul_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.ne_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.pow_fast, 5.0, 2.0),
-        # (:allocs, Base.FastMath.pow_fast, 5.0, 2), # errors -- ADD A RULE FOR ME!
-        # (:allocs, Base.FastMath.rem_fast, 5.0, 2.0), # error -- ADD A RULE FOR ME!
-        (false, :allocs, nothing, Base.FastMath.sign_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.sign_fast, -5.0),
-        (false, :allocs, nothing, Base.FastMath.sin_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.cos_fast, 4.0),
-        (false, :allocs, nothing, Base.FastMath.sincos_fast, 4.0),
-        (false, :allocs, nothing, Base.FastMath.sinh_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.sqrt_fast, 5.0),
-        (false, :allocs, nothing, Base.FastMath.sub_fast, 5.0, 4.0),
-        (false, :allocs, nothing, Base.FastMath.tan_fast, 4.0),
-        (false, :allocs, nothing, Base.FastMath.tanh_fast, 0.5),
-    ]
+    test_cases = reduce(
+        vcat,
+        map([Float64, Float32]) do P
+            C = P === Float64 ? ComplexF64 : ComplexF32
+            return Any[
+                (false, :allocs, nothing, Base.FastMath.abs2_fast, P(-5.0)),
+                (false, :allocs, nothing, Base.FastMath.abs_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.acos_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.acosh_fast, P(1.2)),
+                (false, :allocs, nothing, Base.FastMath.add_fast, P(1.0), P(2.0)),
+                (false, :allocs, nothing, Base.FastMath.angle_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.asin_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.asinh_fast, P(1.3)),
+                (false, :allocs, nothing, Base.FastMath.atan_fast, P(5.4)),
+                (false, :allocs, nothing, Base.FastMath.atanh_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.cbrt_fast, P(0.4)),
+                (false, :allocs, nothing, Base.FastMath.cis_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.cmp_fast, P(0.5), P(0.4)),
+                (false, :allocs, nothing, Base.FastMath.conj_fast, P(0.4)),
+                (false, :allocs, nothing, Base.FastMath.conj_fast, C(0.5, 0.4)),
+                (false, :allocs, nothing, Base.FastMath.cos_fast, P(0.4)),
+                (false, :allocs, nothing, Base.FastMath.cosh_fast, P(0.3)),
+                (false, :allocs, nothing, Base.FastMath.div_fast, P(5.0), P(1.1)),
+                (false, :allocs, nothing, Base.FastMath.eq_fast, P(5.5), P(5.5)),
+                (false, :allocs, nothing, Base.FastMath.eq_fast, P(5.5), P(5.4)),
+                (false, :allocs, nothing, Base.FastMath.expm1_fast, P(5.4)),
+                (false, :allocs, nothing, Base.FastMath.ge_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.ge_fast, P(4.0), P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.gt_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.gt_fast, P(4.0), P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.hypot_fast, P(5.1), P(3.2)),
+                (false, :allocs, nothing, Base.FastMath.inv_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.isfinite_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.isinf_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.isnan_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.issubnormal_fast, P(0.3)),
+                (false, :allocs, nothing, Base.FastMath.le_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.log10_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.log1p_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.log2_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.log_fast, P(0.5)),
+                (false, :allocs, nothing, Base.FastMath.lt_fast, P(0.5), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.lt_fast, P(5.0), P(0.4)),
+                (false, :allocs, nothing, Base.FastMath.max_fast, P(5.0), P(4.0)),
+                (
+                    false,
+                    :none,
+                    nothing,
+                    Base.FastMath.maximum!_fast,
+                    sin,
+                    P.([0.0, 0.0]),
+                    P.([5.0 4.0; 3.0 2.0]),
+                ),
+                (false, :allocs, nothing, Base.FastMath.maximum_fast, P.([5.0, 4.0, 3.0])),
+                (false, :allocs, nothing, Base.FastMath.min_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.min_fast, P(4.0), P(5.0)),
+                (
+                    false,
+                    :none,
+                    nothing,
+                    Base.FastMath.minimum!_fast,
+                    sin,
+                    P.([0.0, 0.0]),
+                    P.([5.0 4.0; 3.0 2.0]),
+                ),
+                (false, :allocs, nothing, Base.FastMath.minimum_fast, P.([5.0, 3.0, 4.0])),
+                (false, :allocs, nothing, Base.FastMath.minmax_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.mul_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.ne_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.pow_fast, P(5.0), P(2.0)),
+                # (:allocs, Base.FastMath.pow_fast, P(5.0), 2), # errors -- NEEDS RULE!
+                # (:allocs, Base.FastMath.rem_fast, P(5.0), P(2.0)), # error -- NEEDS RULE!
+                (false, :allocs, nothing, Base.FastMath.sign_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.sign_fast, P(-5.0)),
+                (false, :allocs, nothing, Base.FastMath.sin_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.cos_fast, P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.sincos_fast, P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.sinh_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.sqrt_fast, P(5.0)),
+                (false, :allocs, nothing, Base.FastMath.sub_fast, P(5.0), P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.tan_fast, P(4.0)),
+                (false, :allocs, nothing, Base.FastMath.tanh_fast, P(0.5)),
+            ]
+        end,
+    )
     memory = Any[]
     return test_cases, memory
 end