Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start forward mode AD #389

Draft
wants to merge 34 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle Nov 24, 2024
deac913
First working autodiff
gdalle Nov 24, 2024
9c96c8d
Docstring
gdalle Nov 24, 2024
136aff6
Apply suggestions from code review
gdalle Nov 24, 2024
f65cc53
Moving files around
gdalle Nov 24, 2024
053a8bb
Primitives already known
gdalle Nov 24, 2024
6d8ec04
Merge branch 'main' into gd/forward
gdalle Nov 25, 2024
a3107a8
Keep pushing forward (pun intended)
gdalle Nov 25, 2024
2836ac8
Still buggy, don't touch
gdalle Nov 25, 2024
09d63bd
Keep instruction mapping one to one
gdalle Nov 26, 2024
fa679eb
Use replace_call
gdalle Nov 26, 2024
a68257c
Ignore code cov
gdalle Nov 27, 2024
7a096ba
No Aqua piracies test
gdalle Nov 27, 2024
46c3e5a
Start control flow
gdalle Nov 28, 2024
ad3f98a
Fix intrinsic
gdalle Nov 28, 2024
9071574
Import
gdalle Nov 28, 2024
dcfe282
Typos
gdalle Nov 28, 2024
e44380d
Co-authored-by: Will Tebbutt <[email protected]>
gdalle Dec 6, 2024
dd89e57
Figure out incremental additions
gdalle Dec 6, 2024
9bdb57f
Initial test case additions
willtebbutt Dec 6, 2024
4bb9911
Formatting
willtebbutt Dec 6, 2024
9b037e7
Add verify_dual_type
willtebbutt Dec 6, 2024
6dea624
test_frule_interface runs
willtebbutt Dec 6, 2024
a614846
Fix ReturnNode
willtebbutt Dec 6, 2024
eadae95
Correctness testing runs
willtebbutt Dec 6, 2024
345b3fd
Add randn_dual
willtebbutt Dec 6, 2024
f58c394
Improve sin and cos frules
willtebbutt Dec 6, 2024
c8d8895
Performance tests run
willtebbutt Dec 6, 2024
578e41b
Tidy up implementation
willtebbutt Dec 6, 2024
b5d34b2
Standard testing infrastructure
willtebbutt Dec 6, 2024
205e716
Fix typos
willtebbutt Dec 6, 2024
d328db0
Fix return node to return dual
gdalle Dec 6, 2024
66a48c8
Handle PiNode
gdalle Dec 6, 2024
e455cf6
Deleted line
gdalle Dec 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 21 additions & 149 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,31 @@ jobs:
test_group: [
'quality',
'basic',
'rrules/avoiding_non_differentiable_code',
'rrules/blas',
'rrules/builtins',
'rrules/fastmath',
'rrules/foreigncall',
'rrules/functionwrappers',
'rrules/iddict',
'rrules/lapack',
'rrules/linear_algebra',
'rrules/low_level_maths',
'rrules/memory',
'rrules/misc',
'rrules/new',
'rrules/tasks',
'rrules/twice_precision',
# 'rrules/avoiding_non_differentiable_code',
# 'rrules/blas',
# 'rrules/builtins',
# 'rrules/fastmath',
# 'rrules/foreigncall',
# 'rrules/functionwrappers',
# 'rrules/iddict',
# 'rrules/lapack',
# 'rrules/linear_algebra',
# 'rrules/low_level_maths',
# 'rrules/memory',
# 'rrules/misc',
# 'rrules/new',
# 'rrules/tasks',
# 'rrules/twice_precision',
]
version:
- 'lts'
# - 'lts'
- '1'
arch:
- x64
include:
- test_group: 'basic'
version: '1.10'
arch: x86
# include:
# - test_group: 'basic'
# version: '1.10'
# arch: x86
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -66,132 +66,4 @@ jobs:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
extra:
name: ${{matrix.test_group.test_type}}-${{ matrix.test_group.label }}-${{ matrix.version }}-${{ matrix.arch }}
runs-on: ubuntu-latest
if: github.event_name != 'schedule'
strategy:
fail-fast: false
matrix:
test_group: [
{test_type: 'ext', label: 'differentiation_interface'},
{test_type: 'ext', label: 'dynamic_ppl'},
{test_type: 'ext', label: 'luxlib'},
{test_type: 'ext', label: 'nnlib'},
{test_type: 'ext', label: 'special_functions'},
{test_type: 'integration_testing', label: 'array'},
{test_type: 'integration_testing', label: 'bijectors'},
{test_type: 'integration_testing', label: 'diff_tests'},
{test_type: 'integration_testing', label: 'distributions'},
{test_type: 'integration_testing', label: 'gp'},
{test_type: 'integration_testing', label: 'logexpfunctions'},
{test_type: 'integration_testing', label: 'lux'},
{test_type: 'integration_testing', label: 'battery_tests'},
{test_type: 'integration_testing', label: 'misc_abstract_array'},
{test_type: 'integration_testing', label: 'temporalgps'},
{test_type: 'integration_testing', label: 'turing'},
]
version:
- '1'
- 'lts'
arch:
- x64
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
include-all-prereleases: false
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- run: |
if [ ${{ matrix.test_group.test_type }} == 'ext' ]; then
julia --code-coverage=user --eval 'include("test/run_extra.jl")'
else
julia --eval 'include("test/run_extra.jl")'
fi
env:
LABEL: ${{ matrix.test_group.label }}
TEST_TYPE: ${{ matrix.test_group.test_type }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v5
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
perf:
name: "Performance (${{ matrix.perf_group }})"
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
perf_group:
- 'hand_written'
- 'derived'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- run: julia --project=bench --eval 'include("bench/run_benchmarks.jl"); main()'
env:
PERF_GROUP: ${{ matrix.perf_group }}
shell: bash
compperf:
name: "Performance (inter-AD)"
runs-on: ubuntu-latest
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
strategy:
fail-fast: false
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: x64
include-all-prereleases: false
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- run: mkdir bench_results
- run: julia --project=bench --eval 'include("bench/run_benchmarks.jl"); main()'
env:
PERF_GROUP: 'comparison'
GKSwstype: '100'
shell: bash
- uses: actions/upload-artifact@v4
with:
name: benchmarking-results
path: bench_results/
# Useful code for testing action.
# - run: |
# text="this is line one
# this is line two
# this is line three"
# echo "$text" > benchmark_results.txt
- name: Read file content
id: read-file
run: |
{
echo "table<<EOF"
cat bench/benchmark_results.txt
echo "EOF"
} >> $GITHUB_OUTPUT
- name: Find Comment
uses: peter-evans/find-comment@v3
id: fc
with:
issue-number: ${{ github.event.pull_request.number }}
comment-author: github-actions[bot]
- id: post-report-as-pr-comment
name: Post Report as Pull Request Comment
uses: peter-evans/create-or-update-comment@v4
with:
issue-number: ${{ github.event.pull_request.number }}
body: "Performance Ratio:\nRatio of time to compute gradient and time to compute function.\nWarning: results are very approximate! See [here](https://github.com/compintell/Mooncake.jl/tree/main/bench#inter-framework-benchmarking) for more context.\n```\n${{ steps.read-file.outputs.table }}\n```"
comment-id: ${{ steps.fc.outputs.comment-id }}
edit-mode: replace

32 changes: 0 additions & 32 deletions .github/workflows/documentation.yml

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ profile.pb.gz
scratch.jl
docs/build/
docs/site/
playground.jl
16 changes: 16 additions & 0 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ using Base:
twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Base.Meta: isexpr
using Core:
Intrinsics,
bitcast,
Expand All @@ -50,6 +51,13 @@ using FunctionWrappers: FunctionWrapper
# Needs to be defined before various other things.
function _foreigncall_ end

"""
frule!!(f::Dual, x::Dual...)

Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`.
"""
function frule!! end

"""
rrule!!(f::CoDual, x::CoDual...)

Expand All @@ -75,8 +83,11 @@ pb!!(1.0)
"""
function rrule!! end

include("interpreter/diffractor_compiler_utils.jl")

include("utils.jl")
include("tangents.jl")
include("dual.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("debug_mode.jl")
Expand All @@ -88,6 +99,7 @@ include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "bbcode.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_forward_mode_ad.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("tools_for_rules.jl")
Expand Down Expand Up @@ -133,9 +145,13 @@ export primal,
_add_to_primal,
_diff,
_dot,
Dual,
zero_dual,
zero_codual,
codual_type,
frule!!,
rrule!!,
build_frule,
build_rrule,
value_and_gradient!!,
value_and_pullback!!,
Expand Down
1 change: 1 addition & 0 deletions src/debug_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DebugFRule(rule) = rule # TODO: make it non-trivial

Check warning on line 1 in src/debug_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_mode.jl#L1

Added line #L1 was not covered by tests

"""
DebugPullback(pb, y, x)
Expand Down
38 changes: 38 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
struct Dual{P,T}
primal::P
tangent::T
end

primal(x::Dual) = x.primal
tangent(x::Dual) = x.tangent
Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:Dual} = x

Check warning on line 9 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L8-L9

Added lines #L8 - L9 were not covered by tests

zero_dual(x) = Dual(x, zero_tangent(x))
randn_dual(rng::AbstractRNG, x) = Dual(x, randn_tangent(rng, x))

function dual_type(::Type{P}) where {P}
P == DataType && return Dual
P isa Union && return Union{dual_type(P.a),dual_type(P.b)}
P <: UnionAll && return Dual # P is abstract, so we don't know its tangent type.
return isconcretetype(P) ? Dual{P,tangent_type(P)} : Dual
end

function dual_type(p::Type{Type{P}}) where {P}
return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent}

Check warning on line 22 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
end

_primal(x) = x
_primal(x::Dual) = primal(x)

Check warning on line 26 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L25-L26

Added lines #L25 - L26 were not covered by tests

_dual(x) = zero_dual(x)
_dual(x::Dual) = x

Check warning on line 29 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L28-L29

Added lines #L28 - L29 were not covered by tests

"""
verify_dual_type(x::Dual)

Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`.
"""
verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x))

@inline uninit_dual(x::P) where {P} = Dual(x, uninit_tangent(x))

Check warning on line 38 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L38

Added line #L38 was not covered by tests
11 changes: 11 additions & 0 deletions src/interpreter/bbcode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,19 @@
"""
inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...)
inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x
inc_args(x::GotoIfNot) = GotoIfNot(__inc(x.cond), x.dest)

Check warning on line 892 in src/interpreter/bbcode.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/bbcode.jl#L892

Added line #L892 was not covered by tests
inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest)
inc_args(x::IDGotoNode) = x
inc_args(x::PiNode) = PiNode(__inc(x.val), x.typ)
function inc_args(x::PhiNode)
new_values = Vector{Any}(undef, length(x.values))
for n in eachindex(x.values)
if isassigned(x.values, n)
new_values[n] = __inc(x.values[n])

Check warning on line 900 in src/interpreter/bbcode.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/bbcode.jl#L895-L900

Added lines #L895 - L900 were not covered by tests
end
end
return PhiNode(x.edges, new_values)

Check warning on line 903 in src/interpreter/bbcode.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/bbcode.jl#L902-L903

Added lines #L902 - L903 were not covered by tests
end
function inc_args(x::IDPhiNode)
new_values = Vector{Any}(undef, length(x.values))
for n in eachindex(x.values)
Expand Down
Loading
Loading