The `@from_rrule` macro, which lets you use existing `ChainRules.rrule`s to define new `Tapir.rrule!!`s, is quite limited in the inputs it supports. In particular, it permits only
- `Float64`s, and
- types which are non-differentiable.

Moreover, the non-differentiable types must be bits types, and the `ChainRules.rrule` must return the correct type. For example, it won't work if the `ChainRules.rrule` returns a `Float64` gradient for an `Int` primal.

One narrow way to extend this would be to assume that we can just return a `Tapir.NoRData` whenever the primal is a non-differentiable bits type, regardless of what the `ChainRules.rrule` returns.
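To make the proposal concrete, here is a minimal sketch of the conversion step it implies. It is not Tapir's implementation: `NoRDataStandIn`, `is_non_differentiable`, and `convert_gradient` are hypothetical names invented for illustration, and the list of non-differentiable types is an assumption. It runs without Tapir or ChainRules installed.

```julia
# Hypothetical stand-in for Tapir.NoRData, so the sketch has no dependencies.
struct NoRDataStandIn end

# Hypothetical predicate for "non-differentiable" primal types.
# Assumption: integers (including Bool) and Char are a reasonable, incomplete list.
is_non_differentiable(::Type{<:Union{Integer, Char}}) = true
is_non_differentiable(::Type) = false

# Convert whatever a ChainRules pullback returned for `primal` into the
# value the generated rule would hand back.
function convert_gradient(primal::P, cr_gradient) where {P}
    # Proposed relaxation: for a non-differentiable bits type, ignore the
    # ChainRules result entirely and return the "no data" marker.
    if isbitstype(P) && is_non_differentiable(P)
        return NoRDataStandIn()
    end
    # Otherwise only the Float64 / Float64 pairing is accepted in this sketch.
    if primal isa Float64 && cr_gradient isa Float64
        return cr_gradient
    end
    throw(ArgumentError("unsupported pair: $P with $(typeof(cr_gradient))"))
end

# An Int primal whose ChainRules rule returned a Float64 gradient.
int_result = convert_gradient(3, 0.0)

# A Float64 primal passes its gradient through unchanged.
float_result = convert_gradient(2.0, 1.5)
```

With this relaxation, the `Int` primal case from the example above no longer depends on what the `ChainRules.rrule` returns, while any other unsupported pairing still raises an error.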
Motivated by TuringLang/Bijectors.jl#319 (comment).