Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Macenko JAX backend support #36

Open
wants to merge 16 commits into
base: development
Choose a base branch
from

Conversation

andreped
Copy link
Collaborator

@andreped andreped commented Jan 29, 2023

This PR adds JAX backend support to Macenko.

Changes:

  • Implemented Macenko with JAX backend and added it to base
  • Added JAX unit test CI jobs (check that JAX yields similar results to numpy backend)
  • Renamed CI names to better match their actual purpose
  • Updated README regarding JAX backend support
  • Fixed setup.py to support installation through pip install torchstain[jax]
  • Fixed np.float32 deprecation in numpy macenko
  • Removed unwanted numpy import in macenko tf backend

Note that the JAX backend runtime-wise is not as optimized as the other backends. Hence, I would perhaps say that we only have experimental JAX support as of now. Here is how JAX backend compared to the other backends:

backends numpy jax torch tf
runtime [s] 0.455 2.427 0.201 0.442

Further optimization to the JAX implementation should be done in future work, but this is outside my area of expertise. Hence, for that, it would be great if more experienced JAX developers could contribute.

@andreped andreped mentioned this pull request Jan 29, 2023
@andreped andreped added the enhancement New feature or request label Jan 29, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
Status: No status
Development

Successfully merging this pull request may close these issues.

None yet

1 participant