Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding complex broadcasting for gradients on the GPU #1324
Adding complex broadcasting for gradients on the GPU #1324
Changes from 8 commits
807d689
2972faf
51dc882
0635ba4
739e896
a0e21e6
5a83493
6742644
2aa06c6
851ab33
f42d940
95a6b5b
b29f090
40fdb29
15c33ad
5e53ada
2c4857b
9fc2180
c685798
efc4f67
51e3ba3
c888db8
83ed917
7b0044b
2bb3b65
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be
Union{Dual, Dual{<:Complex}}
? You'd have to try pretty hard but I think the Complex path expects Dual inside.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought is was the other way around? At least that is what I am constructing in the
dual_function
.ForwardDiff.jl
also definesDual <: Real
so I think defining it the other way would break things. However, I probably want to be a little more specific here and doThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, sorry, that's what I was thinking but didn't type...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why define
x
here at all?Also, this
y
has zero imaginary part.rand(ComplexF64, 50)
would be a stronger test.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops! That
x
was for a test I was doing on my machine. I think overall that the testing could be a bit better though so I've added another test that uses both real and complex arguments. I probably need to add some additional tests.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool. I think
x.^2 .*y .+ y
uses only functions which have special rules, and ought to work without this PR. I think even broadcasting trivial functions likeadd(x,y) = x+y
will change the path it takes. But messy examples (e.g. with trig, conj/real/imag, in all sorts of ways) are much more likely to expose mistakes like aconj
missing somewhere.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying to invent some functions, did not try them on GPU:
But locally, with this branch, I expected them to use the new code... but adding printing doesn't seem to work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I looked into this and this occurred because I hadn't added a Complex method for
_dual_safearg
. When I added this some issues started to appear. One of them was because the partials for the real and complex parts had different lengths.However, that is not the big issue. The big issue is that certain functions seem to be causing some type instabilities during the evaluation of the dual numbers. For instance,
Has a problem where the broadcast can't seem to figure out that eltype of the partial field in
Dual
should be aFloat32
. What is really annoying is that this problem does not occur forFloat64
where I getThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok looking into this more. It appears the
log
withComplex{Dual{Float32}}
arguments is type unstable.My guess is that this occurs because there isn't using the specific forward rule for a complex number for log, or likely any common functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is weird,
@code_warntype log(Dual(1f0, 1f0) + im)
is bad. InsideBase.ssqs
, it looks likeldexp(Dual(1f0, 2f0), 3)
makes a Float64 dual, by a method from ForwardDiff.Anyway not this PR's problem! Maybe make an issue on ForwardDiff (or DiffRules) and test inference etc. with other functions here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok sounds good! I'll skip log for now and make tests for other functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright I was able to add the last test,
and everything passes! The other two tests suggested both run into the
ldexp
problem with Float32. I have opened up an issue JuliaDiff/ForwardDiff.jl#604 detailing the problem. The good news is that when I fix the problem locally all the tests pass!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here are a couple of updates on my end. First, I just realized I was running the previous test on the CPU. When I run it on the GPU, I get a scalar indexing error. The stack trace is
From the look of the stack trace, this isn't due to this pull request. In fact, if I change the function definition to
then everything is fine, so my guess is that this is some funkiness related to the pullback of an adjoint of a real vector. I'll take a look into this, but I am not sure if that's part of this pull request.
Second, I have added some additional tests to ensure we hit every one of the
_broadcast_forward
branches.