
ldexp does not maintain type of Float32 arguments #604

Closed
ptiede opened this issue Nov 1, 2022 · 4 comments

Comments

@ptiede
Contributor

ptiede commented Nov 1, 2022

Hi,

I uncovered this while working on FluxML/Zygote.jl#1324 (see FluxML/Zygote.jl#1324 (comment)). The relevant rule is

https://github.com/JuliaDiff/DiffRules.jl/blob/489e2942e10776c96ab70c5044f595951bbbcaab/src/rules.jl#L88

The rule uses NaN for the derivative with respect to the second argument. Running the following shows the problem:

using ForwardDiff
using InteractiveUtils  # provides @code_warntype outside the REPL
x = ForwardDiff.Dual(1f0, 1f0)
@code_warntype Base.ldexp(x, 1)

# output
MethodInstance for ldexp(::ForwardDiff.Dual{Nothing, Float32, 1}, ::Int64)
  from ldexp(x::ForwardDiff.Dual{Tx}, y::Real) where Tx in ForwardDiff at /home/ptiede/.julia/dev/ForwardDiff/src/dual.jl:144
Static Parameters
  Tx = Nothing
Arguments
  #self#::Core.Const(ldexp)
  x::ForwardDiff.Dual{Nothing, Float32, 1}
  y::Int64
Locals
  dvx::Float64
  val::Float32
  719::Float64
  718::Float32
  vx::Float32
Body::ForwardDiff.Dual{Nothing, Float64, 1}
1 ─       nothing
│   %2  = Base.getproperty(ForwardDiff, :value)::Core.Const(ForwardDiff.value)
│         (vx = (%2)(x))
│         (718 = ForwardDiff.ldexp(vx, y))
│         (719 = ForwardDiff.exp2(y))
│         (val = 718)
│         (dvx = 719)
│   %8  = Base.getproperty(ForwardDiff, :dual_definition_retval)::Core.Const(ForwardDiff.dual_definition_retval)
│   %9  = Core.apply_type(ForwardDiff.Val, $(Expr(:static_parameter, 1)))::Core.Const(Val{Nothing})
│   %10 = (%9)()::Core.Const(Val{Nothing}())
│   %11 = val::Float32
│   %12 = dvx::Float64
│   %13 = Base.getproperty(ForwardDiff, :partials)::Core.Const(ForwardDiff.partials)
│   %14 = (%13)(x)::ForwardDiff.Partials{1, Float32}
│   %15 = (%8)(%10, %11, %12, %14)::ForwardDiff.Dual{Nothing, Float64, 1}
└──       return %15

Tracking this down, it appears the cause is the diff rule for ldexp:

@define_diffrule Base.ldexp(x, y)  = :( exp2($y)                                                ), :NaN

Since y is an Int, exp2(y) returns a Float64, so the partial dvx is a Float64 and the returned Dual is promoted to Float64. I was originally going to file this as an issue in DiffRules.jl, but I couldn't figure out how to change the rule to respect Float32.
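To see why a single Float64 partial is enough to promote the whole result, here is a toy forward-mode dual number. `ToyDual` and `naive_ldexp` are hypothetical illustrations, not ForwardDiff's `Dual` or its generated `ldexp` method; the sketch only assumes, as the `@code_warntype` output suggests, that the value and partial end up sharing one element type.

```julia
# Hypothetical toy dual number; NOT ForwardDiff.Dual.
struct ToyDual{T<:Real}
    value::T    # primal value
    partial::T  # derivative with respect to the input
end

# Mixed element types are promoted to a common type (the same pattern
# Base uses for Complex), so Float32 + Float64 becomes ToyDual{Float64}.
ToyDual(v::Real, p::Real) = ToyDual(promote(v, p)...)

# ldexp(x, y) = x * 2^y, so the partial with respect to x is exp2(y).
# With y::Int, exp2(y) is a Float64, which drags the result to Float64.
naive_ldexp(d::ToyDual, y::Integer) =
    ToyDual(ldexp(d.value, y), exp2(y) * d.partial)

x = ToyDual(1f0, 1f0)
println(typeof(ldexp(1f0, 3)))      # Float32: the primal result is fine
println(typeof(exp2(3)))            # Float64: the derivative term is not
println(typeof(naive_ldexp(x, 3)))  # ToyDual{Float64}
```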

@mcabbott
Member

mcabbott commented Nov 2, 2022

Some rules there contain oftype(float($x), NaN), which I presume was inserted for similar reasons. But many contain just :NaN.
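For reference, `oftype(float(x), NaN)` converts the Float64 literal `NaN` to the floating-point type of `x`, so a Float32 primal gets a Float32 NaN instead of a Float64 one (and `float` turns an integer `x` into Float64). A minimal check, with `x = 1f0` chosen arbitrarily:

```julia
x = 1f0                        # a Float32 primal value
plain = NaN                    # what a bare :NaN rule produces
typed = oftype(float(x), NaN)  # what oftype(float($x), NaN) produces

println(typeof(plain))  # Float64
println(typeof(typed))  # Float32
```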

@ptiede
Contributor Author

ptiede commented Nov 2, 2022

Good point! I was being silly; it looks like defining the rule as

@define_diffrule Base.ldexp(x, y)  = :( exp2(one(float($x)) * $y) ), :NaN

fixes this. I'll file an issue at DiffRules tomorrow.
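A quick way to check the proposed partial outside of DiffRules, with plain variables `x` and `y` standing in for the interpolated `$x` and `$y` (the values are arbitrary). Multiplying by `one(float(x))` converts the integer exponent to the floating type of `x`, so `exp2` dispatches on Float32 and returns Float32:

```julia
x, y = 1f0, 3  # Float32 primal, Int exponent

old_partial = exp2(y)                  # original rule: exp2(::Int)
new_partial = exp2(one(float(x)) * y)  # proposed rule: exp2(3.0f0)

println(typeof(old_partial))       # Float64
println(typeof(new_partial))       # Float32
println(new_partial ≈ old_partial) # true
```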

@mcabbott
Member

mcabbott commented Nov 2, 2022

Oh right. I was looking at the wrong bit. Another option might be oftype(float(x), exp2(y)), as I think exp2(::Int) is just some bit-shifting while exp2(::Float32) is real work.
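The `oftype` variant keeps the `exp2(::Int)` call and only converts its result afterwards. A sketch comparing it with the `one(float(x)) * y` form for a Float32 `x` (the exponent range is arbitrary); both should yield Float32 partials:

```julia
x = 1f0  # Float32 primal

for y in -3:3
    a = oftype(float(x), exp2(y))  # exp2 on the Int, then convert
    b = exp2(one(float(x)) * y)    # exp2 evaluated in Float32
    @assert typeof(a) == typeof(b) == Float32
    println(y, " => ", a, "  ", b)
end
```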

@ptiede
Contributor Author

ptiede commented Nov 10, 2022

Fixed in JuliaDiff/DiffRules.jl#89

ptiede closed this as completed Nov 10, 2022

2 participants