
cispi has poor performance #1276

Closed
roflmaostc opened this issue Aug 1, 2022 · 6 comments

Comments

@roflmaostc

Hi!

`cispi` seems to have type-inferability issues:

```julia
using Zygote

f1(x) = sum(abs2, cispi.(x))
f2(x) = sum(abs2, cis.(x))

@time Zygote.gradient(f1, ones((1024,1024)))
# 1.121823 seconds (18.87 M allocations: 648.002 MiB, 31.55% gc time)

@time Zygote.gradient(f2, ones((1024,1024)))
#   0.066483 seconds (40 allocations: 168.001 MiB, 30.13% gc time)
```
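As an aside, for real inputs both `cis` and `cispi` have unit modulus, so both test functions are constant (here N = 1024 × 1024) and their exact gradients are zero. The comparison above therefore mostly measures AD overhead rather than any interesting derivative:

```math
f_1(x) = \sum_{j=1}^{N} \left|e^{i\pi x_j}\right|^2 = N,
\qquad
f_2(x) = \sum_{j=1}^{N} \left|e^{i x_j}\right|^2 = N,
\qquad
\frac{\partial f_1}{\partial x_j} = \frac{\partial f_2}{\partial x_j} = 0
```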

Zygote uses ChainRules, doesn't it?
ChainRules implements a rule for `sincospi`, which `cispi` should end up calling.

Best,

Felix
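For reference, the link between the two functions can be sketched as follows. This is a minimal illustration assuming `cispi` for real arguments is built from `sincospi`, as in recent Julia versions (the exact Base definition may differ between releases); `cispi_via_sincospi` is a hypothetical name used only here.

```julia
# Hypothetical helper: cispi expressed through sincospi.
# Assumption: this mirrors how Base builds cispi for real arguments;
# the exact Base definition may vary between Julia versions.
function cispi_via_sincospi(x::Real)
    s, c = sincospi(x)    # s = sin(πx), c = cos(πx), computed together
    return complex(c, s)  # cispi(x) = cos(πx) + i*sin(πx)
end

x = range(-2, 2; length = 9)
@assert all(cispi_via_sincospi.(x) .≈ cispi.(x))
```

Because the real and imaginary parts come out of `sincospi` as a tuple and are then reordered, a rule for `sincospi` is what AD would hit when differentiating `cispi`.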

@ToucheSir
Member

It should use the rule, but I suspect something else is forcing broadcasting onto a slow path here. You can guess this from the allocation count: that path needs a couple of allocations per element to store pullbacks, and multiplying that by roughly one million elements is not far off the reported amount.

@ToucheSir
Member

ToucheSir commented Aug 1, 2022

Ah, this is the same issue as #961 (comment) because cis and cispi return complex numbers.

@roflmaostc
Author

But `cis` has good performance?

And the linked issue was only for CUDA

@ToucheSir
Member

I would not say it has good performance. Even though its memory allocations are lower, both they and the overall runtime are strictly worse than for a function with a more "expensive" forward pass:

```julia
using BenchmarkTools, Zygote  # for @btime and gradient

manyexp(x) = exp(x) * exp(x) + exp(x)
x = ones(1024, 1024)

julia> @btime cis.($x);
  10.829 ms (2 allocations: 16.00 MiB)

julia> @btime manyexp.($x);
  17.025 ms (2 allocations: 8.00 MiB)

julia> @btime gradient(x -> sum(abs2, cis.(x)), $x);
  78.728 ms (38 allocations: 160.00 MiB)

julia> @btime gradient(x -> sum(abs2, manyexp.(x)), $x);
  22.156 ms (31 allocations: 40.00 MiB)
```

It just so happens that the pullback for `cis` is simple enough that it captures no values and can therefore be inlined even on the slow path. `cispi`'s pullback is not so fortunate (I think because of the `reverse` and splat?), so it has to be materialized for every array element.

> And the linked issue was only for CUDA

The linked comment discusses a general-purpose solution that works for the CPU path as well. I'm not sure what level of effort would be required to implement it.
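Given the observation that `cis`'s pullback stays cheap on the slow path, one possible user-side workaround (an assumption, not something benchmarked in this thread) is to rewrite `cispi.(x)` as `cis.(π .* x)`. The sketch below only checks that the forward values agree; `f1_alt` is a hypothetical name, and its gradient would be obtained with `Zygote.gradient(f1_alt, x)`.

```julia
# Hypothetical workaround (not verified in this thread): express cispi via cis,
# whose pullback was observed to avoid per-element materialization.
f1_alt(x) = sum(abs2, cis.(π .* x))

x = rand(4, 4)
# Forward values agree up to floating-point rounding of π * x.
@assert cispi.(x) ≈ cis.(π .* x)
@assert f1_alt(x) ≈ sum(abs2, cispi.(x))
```

The trade-off is an extra broadcast (`π .* x`) and the loss of `cispi`'s exactness at integer and half-integer arguments, so whether it pays off would need measuring.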

@ToucheSir
Member

I recently discovered JuliaDiff/ForwardDiff.jl#583, which enables real -> complex function differentiation in ForwardDiff. @mcabbott do you think that could be adapted to what Zygote does for broadcast?
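To make the forward-mode idea concrete, here is a hand-rolled sketch (an assumption, not Zygote's or ForwardDiff's implementation) of a broadcast pullback for a real-to-complex function that does not store one closure per element. The forward pass computes each element's derivative eagerly, as forward mode would, and the backward pass becomes a single elementwise product. It uses the analytic derivative d/dx cispi(x) = iπ·cispi(x) in place of dual numbers; `broadcast_with_derivative` and `dcispi` are hypothetical names.

```julia
# Hypothetical helper: the forward pass stores values and per-element derivatives
# as plain arrays (no per-element closures), so the pullback is one broadcast.
function broadcast_with_derivative(f, df, x::AbstractArray{<:Real})
    y = f.(x)    # complex outputs
    dy = df.(x)  # complex derivatives dy/dx, one per element
    # For real inputs and complex outputs, the cotangent of x is
    # real(conj(dy/dx) * ybar) under Zygote's convention for complex numbers.
    pullback(ybar) = real.(conj.(dy) .* ybar)
    return y, pullback
end

dcispi(x) = im * π * cispi(x)  # analytic derivative of cispi

x = [0.0, 0.25, 0.5, 1.0]
y, back = broadcast_with_derivative(cispi, dcispi, x)

# Loss L = sum(real, y) = sum(cos.(π .* x)); its cotangent w.r.t. y is all ones.
xbar = back(ones(ComplexF64, length(x)))
@assert xbar ≈ -π .* sinpi.(x)  # d/dx cos(πx) = -π sin(πx)
```

With the loss from the original report, `sum(abs2, y)`, the cotangent would be `2 .* y` and the same pullback returns zeros, matching the fact that the loss is constant.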

@CarloLucibello
Member

Closed in #1324

```julia
julia> using BenchmarkTools, Zygote

julia> f1(x) = sum(abs2, cispi.(x))

julia> f2(x) = sum(abs2, cis.(x))

julia> @btime Zygote.gradient(f1, $(ones((1024,1024))));
  29.061 ms (34 allocations: 72.00 MiB)

julia> @btime Zygote.gradient(f2, $(ones((1024,1024))));
  28.697 ms (34 allocations: 72.00 MiB)
```
