cispi has poor performance #1276
Comments
It should use the rule, but I suspect something else is causing broadcasting to take a slow path here. You can guess this because that path requires a couple of allocations per element to store pullbacks, and x1 million is not far off the reported amount.
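To make the suspected cost concrete, here is a minimal sketch in plain Julia of the two strategies. It is not Zygote's broadcast code: the names `slow_broadcast_pullback`, `cispi_with_pullback` and `fused_cispi_pullback` are hypothetical, and the projection `real(conj(dy/dx) * ȳ)` for a real input is assumed as the cotangent convention.

```julia
# Hypothetical illustration only; this is not Zygote's broadcast code.

# "Slow path": every element yields its value and a pullback closure.
# The closures are stored in an array until the backward pass runs.
function slow_broadcast_pullback(f_with_pullback, xs)
    pairs = map(f_with_pullback, xs)      # array of (value, closure) tuples
    ys = map(first, pairs)
    backs = map(last, pairs)
    pullback(dys) = map((back, dy) -> back(dy), backs, dys)
    return ys, pullback
end

# Per-element rule for cispi on a real input: d/dx cispi(x) = iπ * cispi(x).
# The cotangent is projected back onto the reals (assumed convention).
function cispi_with_pullback(x::Real)
    y = cispi(x)
    back(dy) = real(conj(im * π * y) * dy)
    return y, back
end

# "Fast path": the derivative is written once for the whole array,
# so no per-element closures need to be stored.
function fused_cispi_pullback(xs)
    ys = cispi.(xs)
    pullback(dys) = real.(conj.(im .* π .* ys) .* dys)
    return ys, pullback
end

xs = [0.1, 0.25, 0.5, 0.9]
dys = [1.0 + 2.0im, -0.5im, 3.0 + 0.0im, 0.25 - 1.0im]
ys_slow, back_slow = slow_broadcast_pullback(cispi_with_pullback, xs)
ys_fast, back_fast = fused_cispi_pullback(xs)
```

Both paths return the same values and the same input cotangents; they differ only in how much intermediate state is kept between the forward and backward pass.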
Ah, this is the same issue as #961 (comment) because
But And the linked issue was only for CUDA
I would not say it has good performance. Even though the memory allocations are better, both they and the overall runtime are strictly worse than for a function with a more "expensive" forward pass.
manyexp(x) = exp(x) * exp(x) + exp(x)
x = ones(1024, 1024)
julia> @btime cis.($x);
10.829 ms (2 allocations: 16.00 MiB)
julia> @btime manyexp.($x);
17.025 ms (2 allocations: 8.00 MiB)
julia> @btime gradient(x -> sum(abs2, cis.(x)), $x);
78.728 ms (38 allocations: 160.00 MiB)
julia> @btime gradient(x -> sum(abs2, manyexp.(x)), $x);
22.156 ms (31 allocations: 40.00 MiB)
It just so happens that the pullback for
The linked comment discusses a general-purpose solution that works for the CPU path as well. I'm not sure what level of effort would be required to implement it.
I recently discovered JuliaDiff/ForwardDiff.jl#583, which enables real -> complex function differentiation in ForwardDiff. @mcabbott do you think that could be adapted to what Zygote does for broadcast?
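As a rough sketch of what forward-mode differentiation of a real -> complex function involves, the following plain-Julia example carries a value and its derivative through `cis` and `abs2` in one pass per element. The `MiniDual` type and the helper names are hypothetical and exist only for illustration; they are not ForwardDiff's `Dual` type or anything Zygote defines.

```julia
# Hypothetical minimal dual number; not ForwardDiff's or Zygote's types.
struct MiniDual{T}
    val::T    # primal value
    der::T    # derivative with respect to the real input
end

# Wrap a real input with a unit derivative seed.
seed(x::Real) = MiniDual(float(x), one(float(x)))

# cis maps a real dual to a complex dual: d/dx cis(x) = im * cis(x).
function dual_cis(d::MiniDual{<:Real})
    y = cis(d.val)
    return MiniDual(y, im * y * d.der)
end

# abs2 maps a complex dual back to a real dual:
# d/dx |y|^2 = 2 * real(conj(y) * y').
function dual_abs2(d::MiniDual{<:Complex})
    return MiniDual(abs2(d.val), 2 * real(conj(d.val) * d.der))
end

# One forward pass per element yields the value and the derivative together,
# so nothing has to be stored for a later backward pass.
grad_elementwise(xs) = map(x -> dual_abs2(dual_cis(seed(x))).der, xs)

g = grad_elementwise(ones(3))
```

The intermediate dual has complex parts even though the input and the final result are real, which is the real -> complex case the linked pull request is about.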
Closed in #1324
julia> f1(x) = sum(abs2, cispi.(x))
julia> f2(x) = sum(abs2, cis.(x))
julia> @btime Zygote.gradient(f1, $(ones((1024,1024))));
29.061 ms (34 allocations: 72.00 MiB)
julia> @btime Zygote.gradient(f2, $(ones((1024,1024))));
28.697 ms (34 allocations: 72.00 MiB)
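Note that both benchmark functions are constant in `x`, because `abs2(cis(x))` and `abs2(cispi(x))` equal 1 for any real `x`. The expected gradient is therefore zero everywhere, and the timings measure differentiation overhead only. A quick sanity check that needs no AD package, using a central finite difference (the helper `fd_partial` is a hypothetical name introduced here):

```julia
f1(x) = sum(abs2, cispi.(x))
f2(x) = sum(abs2, cis.(x))

# Central finite difference of f along coordinate i.
function fd_partial(f, x, i; h = 1e-6)
    xp = copy(x); xp[i] += h
    xm = copy(x); xm[i] -= h
    return (f(xp) - f(xm)) / (2h)
end

x = ones(8, 8)
d1 = fd_partial(f1, x, 1)   # expected to be zero up to rounding error
d2 = fd_partial(f2, x, 1)   # expected to be zero up to rounding error
```

Comparing the output of `Zygote.gradient` against an array of zeros is a cheap way to confirm that a faster broadcast path still returns the correct result.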
Hi!
cispi seems to have type inferability issues:
Zygote uses ChainRules, doesn't it? There is this rule for sincospi implemented which should be called by cispi.
Best,
Felix