
Try to make MultiModel faster #204

Open
jajcayn opened this issue May 11, 2022 · 1 comment
Assignees
Labels
enhancement Enhancing neurolib.

jajcayn commented May 11, 2022

I got a couple of requests about the speed of MultiModel. It seems that for larger networks it is painfully slow. I suspect a couple of things:

  • jitcdde compilation time for large networks - the numba backend should be faster in this case
  • noise precomputation - I know this is slow and not optimized. E.g. when we have 80 nodes and 2 noise sources per node, 160 noise processes in total are precomputed. However, due to the current design, this happens sequentially rather than in parallel, and even when the parameters are identical, so the computation could be vectorized, each process is still generated one by one.
  • a last resort would be to disable the jitcdde backend altogether; then we could move the noise integration inside the numba-compiled loop, which should be much faster...
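The vectorization idea from the second bullet could look roughly like this. This is a generic sketch, not neurolib's actual noise implementation: the function name, parameters, and the Ornstein-Uhlenbeck form are illustrative assumptions, but it shows how 160 processes can share a single vectorised Euler-Maruyama loop instead of being generated one by one:

```python
import numpy as np

def simulate_ou_noise(n_processes, n_steps, dt, tau=5.0, sigma=0.05, seed=42):
    """Simulate many Ornstein-Uhlenbeck noise processes at once.

    All `n_processes` realisations are advanced together: the time loop
    remains, but each step updates the whole batch as one array operation.
    """
    rng = np.random.default_rng(seed)
    x = np.zeros((n_processes, n_steps))
    # draw all Gaussian increments in a single call instead of per process
    increments = rng.standard_normal((n_processes, n_steps - 1))
    sqrt_dt = np.sqrt(dt)
    for t in range(1, n_steps):
        # vectorised Euler-Maruyama step for the whole batch of processes
        x[:, t] = (
            x[:, t - 1]
            - x[:, t - 1] / tau * dt
            + sigma * sqrt_dt * increments[:, t - 1]
        )
    return x

# e.g. 80 nodes with 2 noise sources each -> 160 processes in one batch
noise = simulate_ou_noise(n_processes=160, n_steps=10000, dt=0.1)
print(noise.shape)  # (160, 10000)
```

The same batching trick would work for any noise type whose update rule is elementwise, which is why identical-parameter processes in particular should never need a sequential loop.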
@jajcayn jajcayn added the enhancement Enhancing neurolib. label May 11, 2022
@jajcayn jajcayn self-assigned this May 11, 2022
@jajcayn
Copy link
Collaborator Author

jajcayn commented Feb 20, 2023

Just in: I tested a new framework for solving differential equations: https://github.com/patrick-kidger/diffrax
For now they support ODEs/SDEs/CDEs (controlled differential equations - interesting for Lena?); DDEs are in a PR and will hopefully be merged soon.
The HUGE (I cannot stress this enough) advantage is that it uses jax as a backend, which means:

  • speed should be even better than with numba, since it uses jax.jit()
  • it is GPU/TPU-capable out of the box (everything in jax can run on a GPU just like that)
  • autodifferentiation for free: jax has autograd built in, so you can compute derivatives, partial derivatives, Jacobians, Hessians, etc. with one line of code (proper fixed-point and bifurcation analyses, anyone?)
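To illustrate the jit + autodiff point on a toy right-hand side (a hypothetical Wilson-Cowan-style sketch, not MultiModel's actual equations; all names and weights here are made up for illustration):

```python
import jax
import jax.numpy as jnp

def rhs(state, params):
    """Toy two-population rate model; purely illustrative."""
    exc, inh = state
    w_ee, w_ei, w_ie = params
    d_exc = -exc + jnp.tanh(w_ee * exc - w_ei * inh)
    d_inh = -inh + jnp.tanh(w_ie * exc)
    return jnp.stack([d_exc, d_inh])

rhs_jit = jax.jit(rhs)                 # compiled; runs on CPU/GPU/TPU
jacobian = jax.jit(jax.jacobian(rhs))  # exact Jacobian w.r.t. state, via autodiff

state = jnp.array([0.1, 0.05])
params = jnp.array([1.5, 1.0, 0.8])

dx = rhs_jit(state, params)   # one evaluation of the right-hand side
J = jacobian(state, params)   # 2x2 Jacobian at this state
print(J.shape)  # (2, 2)
```

The one-liner Jacobian is exactly what fixed-point and linear-stability (bifurcation) analyses need, and the same function compiles for GPU without any changes.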

Once DDEs are an official part of the package, I'll check how difficult it would be to rewrite MultiModel on top of it. My quick initial assessment: it should not be that hard, since the right-hand sides are passed to diffrax in the same way as to jitcdde; handling delays is a bit different, though. I would rewrite all symengine primitives to jax.numpy and rewrite the backend module, and that should be it.

Imagine creating a MultiModel to your wishes, having the same speed as a "native" numba model, and having the option to do state-space analysis, since jax autodifferentiation comes for free.

@caglorithm
