Skip to content

Commit

Permalink
Run everything on 1.10 (#385)
Browse files Browse the repository at this point in the history
* Run everything on 1.10

* Expand buildkite runner into matrix

* Also run integration and ext testing on LTS

* Try LTS

* Just write 1.10

* Julia version number in matrix

* Document support policy

* Document support policy

* Restrict Julia compat properly

* Fix DynamicPPL tests on 1.10

* Simplify gitignore

* Formatting

* Relax performance bounds on nnlib scatter

* Improve error message

* Fix a problem

* Fix error introduced in the last PR

* Fix coverage

* Fix coverage

* Fix LuxLib

* Prevent small union inlining

* use static if everywhere

* Revert previous change

* Patch the problem

* Run error printing

* Bump patch
  • Loading branch information
willtebbutt authored Nov 23, 2024
1 parent eaf1fb8 commit 5397d28
Show file tree
Hide file tree
Showing 20 changed files with 170 additions and 48 deletions.
7 changes: 5 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ env:
SECRET_CODECOV_TOKEN: "nkcRFVXdaPNAbiI0x3qK/XUG8rWjBc8fU73YEyP35SeS465XORqrIYrHUbHuJTRyeyqNRdsHaBcV1P7TBbKAaTQAjHQ1Q0KYfd0uRMSWpZSCgTBz5AwttAxVfFrX+Ky3PzTi2TfDe0uPFZtFo0Asq6sUEr1on+Oo+j+q6br2NK6CrA5yKKuTX4Q2V/UPOIK4vNXY3+zDTKSNtr+HQOlcVEeRIk/0ZQ78Cjd52flEaVw8GWo/CC4YBzLtcOZgaFdgOTEDNHMr0mw6zLE4Y6nxq4lHVSoraSjxjhkB0pXTZ1c51yHX8Jc+q6HC5s87+2Zq5YtsuQSGao+eMtkTAYwfLw==;U2FsdGVkX18z27J3+gNgxsPNnXA0ad4LvZnXeohTam7/6UPqX5+3BYI0tAiVkCho4vlJyL7dd8JEyNtk9BFXsg=="

steps:
- label: "Julia v1"
- label: "Julia v{{matrix}}"
plugins:
- JuliaCI/julia#v1:
version: "1"
version: "{{matrix}}"
- JuliaCI/julia-coverage#v1:
dirs:
- src
Expand All @@ -19,3 +19,6 @@ steps:
env:
LABEL: cuda
TEST_TYPE: ext
matrix:
- "1"
- "1.10"
9 changes: 2 additions & 7 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ jobs:
'rrules/twice_precision',
]
version:
- 'lts'
- '1'
arch:
- x64
include:
- test_group: 'basic'
version: '1.10'
arch: x64
- test_group: 'basic'
version: '1.10'
arch: x86
Expand Down Expand Up @@ -95,12 +93,9 @@ jobs:
]
version:
- '1'
- 'lts'
arch:
- x64
include:
- test_group: {test_type: 'integration_testing', label: 'turing'}
version: '1.10'
arch: x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
6 changes: 1 addition & 5 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
/Manifest.toml
/Manifest-v1.11.toml
Manifest*
dev
bench/Manifest.toml
analysis_results
.vscode
profile.pb.gz
scratch.jl
docs/build/
docs/site/
docs/Manifest.toml
Manifest.toml
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.48"
version = "0.4.49"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -66,7 +66,7 @@ Setfield = "1"
SpecialFunctions = "2"
StableRNGs = "1"
Test = "1"
julia = "1"
julia = "1.10"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ If you encounter a new version of `Mooncake.jl` in the wild, please consult this

# Getting Started

Check that you're running a version of Julia that Mooncake.jl supports.
See the `SUPPORT_POLICY.md` file for more info.

There are several ways to interact with `Mooncake.jl`.
The one that we recommend people begin with is [`DifferentiationInterface.jl`](https://github.com/gdalle/DifferentiationInterface.jl/).
For example, use it as follows to compute the gradient of a function mapping a `Vector{Float64}` to `Float64`.
Expand Down
42 changes: 42 additions & 0 deletions SUPPORT_POLICY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Summary

At any given point in time, `Mooncake.jl` supports the current Long Term Support (LTS) release of Julia, and the latest release version of Julia 1.
Consequently, the versions of Julia which are officially supported by `Mooncake.jl` will change (almost) _immediately_ whenever a new Julia LTS version is declared, or a minor release of Julia is made.

For example, the LTS is 1.10 and the latest release is 1.11 at the time of writing. When 1.12 is released, we will
1. bump the Julia compat bounds in `Mooncake.jl` to require either 1.10 or 1.12,
1. cease to run CI on 1.11,
1. cease to provide bug fixes for 1.11,
1. cease to accept 1.11-specific bug fixes, as we will not be running CI for 1.11 and therefore will not be able to test that they have worked.

In short: as far as `Mooncake.jl`'s future releases, 1.11 ceases to exist the moment 1.12 is released.

Note that these changes are not applied retrospectively to existing releases of `Mooncake.jl`.
Suppose that `Mooncake.jl` is at `v0.4.50` when 1.12 is released.
Then the above changes would be relevant to `Mooncake.jl` versions `v0.4.51` and higher.

# Patch Versions

The above only discussed minor versions of Julia (1.10, 1.11, 1.12, etc).
However, it also applies to patch versions of Julia.
For example, at the time of writing, Julia version 1.10.6 is _actually_ the LTS, and 1.11.1 the current release of Julia.
The moment that 1.10.7 is released, we will cease to run any CI on 1.10.6, and will not accept fixes for it.
The same is true of 1.11.2.

Since patch releases of Julia are less invasive than minor releases, this should generally not cause users problems.

# Context

In order to support a particular version of Julia, we must
1. always run CI for that version,
1. accept and proactively produce fixes for that version,
1. maintain version-specific code in the `Mooncake.jl` codebase.

This requires a surprisingly amount of overhead to the development of `Mooncake.jl`, and has the potential to substantially increase the complexity of the codebase.
All of this makes it harder to improve `Mooncake.jl`.
Consequently, this policy represents a decision to tradeoff support for a range of minor Julia versions in exchange for easing the development burden associated to `Mooncake.jl`.

## Why not gently drop support?

In the JuliaGaussianProcesses ecosystem, we had a loosely-defined policy of keeping support for an older version until we ran into a large problem which could not be fixed easily, at which point we would drop support.
While this sounds appealing, in practice it makes it hard to know exactly when to drop support for a particular version of Julia, increases the burden for maintainers, and makes it hard for users to know exactly what to expect.
27 changes: 27 additions & 0 deletions ext/MooncakeLuxLibSLEEFPiratesExtension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using LuxLib, Mooncake, SLEEFPirates
using Base: IEEEFloat
using Mooncake: @from_rrule, DefaultCtx

@static if VERSION >= v"1.11"

# Workaround for package load order problems. See
# https://github.com/JuliaLang/julia/issues/56204#issuecomment-2419553167 for more context.
function __init__()
Expand Down Expand Up @@ -31,4 +33,29 @@ function __init__()
end
end

else

for f in Any[
LuxLib.NNlib.sigmoid_fast,
LuxLib.NNlib.softplus,
LuxLib.NNlib.logsigmoid,
LuxLib.NNlib.swish,
LuxLib.NNlib.lisht,
Base.tanh,
LuxLib.NNlib.tanh_fast,
]
f_fast = LuxLib.Impl.sleefpirates_fast_act(f)
@eval @from_rrule DefaultCtx Tuple{typeof($f_fast), IEEEFloat}
@eval @from_rrule(
DefaultCtx,
Tuple{
typeof(Broadcast.broadcasted),
typeof($f_fast),
Union{IEEEFloat, Array{<:IEEEFloat}},
},
)
end

end

end
4 changes: 2 additions & 2 deletions src/interpreter/abstract_interpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ function CC.method_table(interp::MooncakeInterpreter)
return CC.OverlayMethodTable(interp.world, mooncake_method_table)
end

if VERSION < v"1.11.0"
@static if VERSION < v"1.11.0"
CC.get_world_counter(interp::MooncakeInterpreter) = interp.world
get_inference_world(interp::CC.AbstractInterpreter) = CC.get_world_counter(interp)
else
Expand Down Expand Up @@ -160,7 +160,7 @@ function Core.Compiler.abstract_call_gf_by_type(
end
end

if VERSION < v"1.11-"
@static if VERSION < v"1.11-"

function CC.inlining_policy(
interp::MooncakeInterpreter{C},
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Observe that this information means that whether or not something is a primitive
particular context depends only on static information, not any run-time information that
might live in a particular instance of `Ctx`.
"""
is_primitive(::Type{MinimalCtx}, ::Any) = false
is_primitive(::Type{MinimalCtx}, sig::Type{<:Tuple}) = false
is_primitive(::Type{DefaultCtx}, sig) = is_primitive(MinimalCtx, sig)

"""
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ __get_arg(x::QuoteNode) = x.value
__get_arg(x) = x

# memoryrefget and memoryrefset! were introduced in 1.11.
if VERSION >= v"1.11-"
@static if VERSION >= v"1.11-"

"""
lift_memoryrefget_and_memoryrefset_builtins(inst)
Expand Down
4 changes: 2 additions & 2 deletions src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
ir = __strip_coverage!(ir)
ir = CC.sroa_pass!(ir, inline_state)

if VERSION < v"1.11-"
@static if VERSION < v"1.11-"
ir = CC.adce_pass!(ir, inline_state)
else
ir, _ = CC.adce_pass!(ir, inline_state)
Expand Down Expand Up @@ -227,7 +227,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_u
asts = []
for match in get_matches(matches.matches)
match = match::Core.MethodMatch
if VERSION < v"1.11-"
@static if VERSION < v"1.11-"
meth = Base.func_for_method_checked(match.method, tt, match.sparams)
(code, ty) = CC.typeinf_ircode(
interp, meth, match.spec_types, match.sparams, optimize_until
Expand Down
58 changes: 47 additions & 11 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ get_primal_type(::ADInfo, x) = _typeof(x)
function get_primal_type(::ADInfo, x::GlobalRef)
return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty
end
function get_primal_type(::ADInfo, x::Expr)
x.head === :boundscheck && return Bool
error("Unrecognised expression $x found in argument slot.")
end

"""
get_rev_data_id(info::ADInfo, x)
Expand Down Expand Up @@ -394,7 +398,11 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2)))
return ad_stmt_info(line, nothing, inc_args(stmt), rvs)
else
fwds = ReturnNode(const_codual(stmt.val, info))
const_id = ID()
fwds = [
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
(ID(), new_inst(ReturnNode(const_id))),
]
return ad_stmt_info(line, nothing, fwds, nothing)
end
end
Expand Down Expand Up @@ -457,7 +465,11 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
else
# If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to
# do on the reverse-pass.
fwds = PiNode(const_codual(stmt.val, info), fcodual_type(_type(stmt.typ)))
const_id = ID()
fwds = [
(const_id, new_inst(const_codual_stmt(stmt.val, info))),
(line, new_inst(PiNode(const_id, fcodual_type(_type(stmt.typ))))),
]
rvs = nothing
end

Expand All @@ -475,11 +487,11 @@ end
function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo)
isconst(stmt) && return const_ad_stmt(stmt, line, info)

x = const_codual(getglobal(stmt.mod, stmt.name), info)
globalref_id = ID()
const_id, globalref_id = ID(), ID()
fwds = [
(globalref_id, new_inst(stmt)),
(line, new_inst(Expr(:call, __verify_const, globalref_id, x))),
(const_id, new_inst(const_codual_stmt(getglobal(stmt.mod, stmt.name), info))),
(line, new_inst(Expr(:call, __verify_const, globalref_id, const_id))),
]
return ad_stmt_info(line, nothing, fwds, nothing)
end
Expand All @@ -502,8 +514,22 @@ make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info)
Implementation of `make_ad_stmts!` used for constants.
"""
function const_ad_stmt(stmt, line::ID, info::ADInfo)
x = const_codual(stmt, info)
return ad_stmt_info(line, nothing, x isa ID ? Expr(:call, identity, x) : x, nothing)
return ad_stmt_info(line, nothing, const_codual_stmt(stmt, info), nothing)
end

"""
const_codual_stmt(stmt, info::ADInfo)
Returns a `:call` expression which will return a `CoDual` whose primal is `stmt`, and whose
tangent is whatever `uninit_tangent` returns.
"""
function const_codual_stmt(stmt, info::ADInfo)
v = get_const_primal_value(stmt)
if safe_for_literal(v)
return Expr(:call, uninit_fcodual, v)
else
return Expr(:call, identity, add_data!(info, uninit_fcodual(v)))
end
end

"""
Expand All @@ -519,10 +545,21 @@ function const_codual(stmt, info::ADInfo)
return safe_for_literal(v) ? x : add_data!(info, x)
end

safe_for_literal(v) = v isa String || v isa Type || isbitstype(_typeof(v))
function safe_for_literal(v)
v isa Expr && v.head === :boundscheck && return true
v isa String && return true
v isa Type && return true
v isa Tuple && all(safe_for_literal, v) && return true
isbitstype(_typeof(v)) && return true
return false
end

inc_or_const(stmt, info::ADInfo) = is_active(stmt) ? __inc(stmt) : const_codual(stmt, info)

function inc_or_const_stmt(stmt, info::ADInfo)
return is_active(stmt) ? Expr(:call, identity, __inc(stmt)) : const_codual_stmt(stmt, info)
end

"""
get_const_primal_value(x::GlobalRef)
Expand Down Expand Up @@ -616,7 +653,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
# Make arguments to rrule call. Things which are not already CoDual must be made so.
codual_arg_ids = map(_ -> ID(), collect(args))
codual_args = map(args, codual_arg_ids) do arg, id
return (id, new_inst(Expr(:call, identity, inc_or_const(arg, info))))
return (id, new_inst(inc_or_const_stmt(arg, info)))
end

# Make call to rule.
Expand Down Expand Up @@ -691,8 +728,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)

elseif Meta.isexpr(stmt, :copyast)
# Get constant out and shove it in shared storage.
x = const_codual(stmt.args[1], info)
return ad_stmt_info(line, nothing, Expr(:call, identity, x), nothing)
return ad_stmt_info(line, nothing, const_codual_stmt(stmt.args[1], info), nothing)

elseif Meta.isexpr(stmt, :loopinfo)
# Cannot pass loopinfo back through the optimiser for some reason.
Expand Down
9 changes: 8 additions & 1 deletion src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,17 @@ function rrule!!(f::CoDual{<:Core.Builtin}, args...)
"which is specialised to this case. " *
"Either way, please consider commenting on " *
"https://github.com/compintell/Mooncake.jl/issues/208/ so that the issue can be " *
"fixed more widely."
"fixed more widely.\n" *
"For reproducibility, note that the full signature is:\n" *
"$(typeof((f, args...)))"
))
end

function Base.showerror(io::IO, err::MissingRuleForBuiltinException)
print(io, "MissingRuleForBuiltinException: ")
println(io, err.msg)
end

module IntrinsicsWrappers

using Base: IEEEFloat
Expand Down
3 changes: 3 additions & 0 deletions src/rrules/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ function rrule!!(::CoDual{typeof(Base.FastMath.sincos)}, x::CoDual{P}) where {P<
return CoDual(y, NoFData()), sincos_fast_adj!!
end

@is_primitive MinimalCtx Tuple{typeof(Base.log), Union{IEEEFloat, Int}}
@zero_adjoint MinimalCtx Tuple{typeof(log), Int}

function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath})
test_cases = Any[
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, 0.5),
Expand Down
2 changes: 1 addition & 1 deletion src/rrules/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# Required to avoid an ambiguity.
@zero_adjoint MinimalCtx Tuple{Type{Symbol}, TypeVar, Type}

if VERSION >= v"1.11-"
@static if VERSION >= v"1.11-"
@zero_adjoint MinimalCtx Tuple{typeof(Random.hash_seed), Vararg}
@zero_adjoint MinimalCtx Tuple{typeof(Base.dataids), Memory}
end
Expand Down
2 changes: 1 addition & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ end
_new_excluded(::Type) = false
_new_excluded(::Type{<:Union{String}}) = true

if VERSION < v"1.11-"
@static if VERSION < v"1.11-"
# Prior to 1.11, Arrays are special objects, with special constructors that don't
# involve calling the `:new` instruction. From 1.11 onwards, they behave more like
# regular mutable composite types, so calling `_new_` becomes meaningful.
Expand Down
Loading

2 comments on commit 5397d28

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120035

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.49 -m "<description of version>" 5397d284e89f08afbb22cce32483c1ada6baec10
git push origin v0.4.49

Please sign in to comment.