On an HPC cluster, each term in a mean square loss can be calculated using embarrassingly parallel logic.
Unfortunately, the native way of doing this with `jax` (using `jax.vmap` and `jax.pmap`) is not compatible with the input we must parallelize over: the `Molecule` object. This is because its data is stored in a "ragged" structure. That is, the dimensions of the grid for one molecule are very often different from those of the grid for another, and the same holds for the dimensions of the 1-RDM, so `jnp.array([rdm1_1, rdm1_2])` will not work.
This means that for loss parallelism, we need to think differently. Sharding may be the way forward, but this requires more thought. A good reference is here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
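To make the ragged-data problem concrete, here is a minimal sketch of one possible workaround: zero-pad every per-molecule array to a common length, carry a boolean mask, and then apply `jax.vmap` over the padded batch. This is not what the codebase does, and it is not the sharding approach from the reference above. The 1-D `grids`, the `targets`, and the helpers `pad_and_mask`, `predict` and `squared_error` are all hypothetical stand-ins for illustration and do not reflect the actual `Molecule` layout.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-ins for per-molecule grid data: each "molecule" has a
# different number of grid points, so the arrays cannot be stacked directly.
grids = [np.arange(3.0), np.arange(5.0), np.arange(4.0)]
targets = jnp.array([1.0, 10.0, 6.0])


def pad_and_mask(arrays):
    """Zero-pad 1-D arrays to a common length; return (batch, mask)."""
    n_max = max(a.shape[0] for a in arrays)
    batch = np.zeros((len(arrays), n_max))
    mask = np.zeros((len(arrays), n_max), dtype=bool)
    for i, a in enumerate(arrays):
        batch[i, : a.shape[0]] = a
        mask[i, : a.shape[0]] = True
    return jnp.asarray(batch), jnp.asarray(mask)


def predict(grid, mask):
    # Toy "model" (assumption): sum over the valid grid points only,
    # so the padding never contributes to the prediction.
    return jnp.sum(jnp.where(mask, grid, 0.0))


def squared_error(grid, mask, target):
    # One term of the mean square loss, for a single molecule.
    return (predict(grid, mask) - target) ** 2


padded, mask = pad_and_mask(grids)

# Every row now has the same shape, so the loss terms can be vectorized.
per_molecule = jax.vmap(squared_error)(padded, mask, targets)
loss = jnp.mean(per_molecule)
```

The trade-off is wasted memory and compute on the padding, which may be significant if grid sizes vary widely between molecules. The padded batch is an ordinary rectangular array, though, so its leading axis could in principle be split across devices.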
I don't think we will get around to solving this problem before our release deadline, but if we want to do something with HPC, getting this right is non-negotiable.