
Support for Jax #31

Open
swamidass opened this issue Dec 24, 2022 · 24 comments
Assignees
Labels
enhancement New feature or request

Comments

@swamidass

Excellent tool. We are likely to cite you in the future.

Would you consider building out a Jax backend?

@andreped andreped added the enhancement New feature or request label Dec 24, 2022
@andreped
Collaborator

Hello, @swamidass! I hope you are having a nice Christmas holiday.

I've been considering getting into Jax, just never had the time. Mostly use TF myself, but use torch from time to time.

I might make an attempt to add a Jax backend if it is crucial for a study of yours. Do you have a deadline for when you need this feature to be ready?

I made PRs to add full backend support for Reinhard and a modified Reinhard variant a few days ago. I could add Jax support, but would need to make an attempt first.

@swamidass
Author

Thanks for the response, and I hope you also are having a nice holiday.

We are in the middle of porting our WSI pipeline from TensorFlow to Jax. So there isn't a rush, because we are still using TensorFlow. However, we would be well situated to test a Jax backend.

@andreped
Collaborator

andreped commented Jan 2, 2023

We are in the middle of porting our WSI pipeline from TensorFlow to Jax.

That's interesting. What is the main reason you are moving to Jax from TF? I was considering it, but for training purposes tf/keras offers most of what I need. For deployment, on the other hand, I find that inference engines such as ONNX Runtime or TensorRT work better, but that is not a problem, as "all" TF models can be converted to the ONNX format, which both IEs support.

However, we would be well situated to test a Jax backend.

I'm partly back from vacation now and will likely make an attempt at the Jax backend soon. Will keep you updated :]

@andreped
Collaborator

andreped commented Jan 2, 2023

It was too tempting...

The Jax backend with Macenko seems to work. I did not spend a lot of time on it, so there are probably many ways to optimize it, but at least we have a baseline that you could test and benchmark, if you'd like, @swamidass. Then we can try to improve it later on if necessary.

To test, be sure to install from this branch by:

pip install git+https://github.com/andreped/torchstain.git@jax-backend

Then to use, simply do:

import torchstain
import numpy as np

# 'target' and 'to_transform' are your own RGB images (H x W x 3 arrays)
normalizer_jax = torchstain.normalizers.MacenkoNormalizer(backend='jax')
normalizer_jax.fit(target)
norm_jax, _, _ = normalizer_jax.normalize(to_transform)
norm = np.asarray(norm_jax)

When @carloalbertobarbano is back from vacation, we will resolve some existing PRs, and then I can start adding Jax to the main branch for all stain normalization techniques. It should be quite seamless, as Jax has a surprisingly neat numpy-like API when using jax.numpy as a direct replacement for numpy.
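For example, most numpy-style code ports directly (a trivial illustration, not torchstain code):

import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)
y = jnp.log1p(x).mean(axis=0)  # same calls as numpy, but traceable/compilable by JAX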

When all that is working, I believe we are ready for a new release :] Will keep you updated, @swamidass.

[Screenshot: 2023-01-02 at 15:33:57]

@andreped andreped self-assigned this Jan 2, 2023
@swamidass
Author

Very nice! In our experience, JAX is substantially faster than TF and pytorch, which is why we are migrating to it.

So it will be interesting to see if that's true in this case too.

@swamidass
Author

swamidass commented Jan 3, 2023

One thing you want to be sure to test, though, is whether the entrypoint you are providing is compilable. You can test this easily (and it should be a test case) with this modification to your code:

import jax

@jax.jit
def jax_normalize(to_transform):
    # normalizer_jax as constructed in the snippet above
    norm_jax, _, _ = normalizer_jax.normalize(to_transform)
    return norm_jax

norm_jax = jax_normalize(to_transform)

If that throws an error, there is more work to do. Usually, there will be an internal function you are calling that needs to be wrapped in jax.jit, identifying which arguments are static.

If that does not throw an error, you might be done!

@swamidass
Author

So, running this, we do get an error. It's a classic example of where code needs to be refactored for JAX:

ODhat = OD[~jnp.any(OD < beta, axis=1)]

This needs to be refactored so that all intermediate matrices are a fixed size. The boolean select here yields an array with indeterminate size. The way to refactor this is with a mask.
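For illustration, a toy sketch of both the failure and the mask pattern (made-up function names, not the library code):

import jax
import jax.numpy as jnp

x = jnp.array([0.1, 0.5, 0.05, 0.9])

@jax.jit
def bad(x):
    # Fails under jit: the output shape depends on the data
    # (raises NonConcreteBooleanIndexError when called).
    return x[x > 0.2]

@jax.jit
def good(x):
    mask = x > 0.2
    # Fixed output shape: masked entries are neutralized instead of dropped.
    return jnp.where(mask, x, 0.0)

good(x)  # [0. , 0.5, 0. , 0.9]; bad(x) raises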

I'm willing to make some of the changes, if I can. If I do, would you mind adding me as an author? I was also thinking of adding a utility we developed to enable parallelized application of this to large slides, too.

@swamidass
Author

swamidass commented Jan 3, 2023

I also saw a few of these:

Inorm.at[Inorm > 255].set(255)

They are no-ops as written. You'd need to do:

 Inorm = Inorm.at[Inorm > 255].set(255)

But I'm not sure they will compile regardless. A better way to write that is:

Inorm = jnp.where(Inorm > 255, 255, Inorm) 

@andreped
Collaborator

andreped commented Jan 3, 2023

It's rather late in Norway, but I can try to incorporate your ideas, @swamidass. Give me a second, and I will make a commit if I get it working as intended.

@andreped
Collaborator

andreped commented Jan 3, 2023

Hmm, as far as I can tell, Jax does not yet have a masking mechanism that mimics y = x[mask]. Rather strange, I must say, but there are numerous threads about it.

Even more surprising: jax-ml/jax#11557

Any ideas? It is this line that is giving me a headache:

ODhat = OD[~jnp.any(OD < beta, axis=1)]

@andreped
Collaborator

andreped commented Jan 3, 2023

I got further, but the masking in Jax is giving me a headache. At least now, assuming there is a fix for this one line, adding the @jax.jit decorator works, and hence Macenko should work with the Jax backend.

You do not need to add it yourself externally; it is added directly within the class (see here).

Just reinstall the latest version of the same branch and run the same commands as mentioned previously to test: #31 (comment)

@andreped
Collaborator

andreped commented Jan 3, 2023

I'm willing to make some of the changes, if I can. If I do, would you mind adding me as an author? I was also thinking of adding a utility we developed to enable parallelized application of this to large slides, too.

@swamidass Oh, and of course, contributors are always welcome. Regarding authorship, I am not the owner of this project, just contributing to it, but I'm open to the idea :] You have already been helpful in the Jax-backend implementation.

@swamidass
Author

I will take a look at it in a moment, but regarding:

ODhat = OD[~jnp.any(OD < beta, axis=1)]

The trick is to refactor the code so that ODhat isn't needed. I believe the correct solution would be to change line 61,

_, eigvecs = jnp.linalg.eigh(jnp.cov(ODhat.T))

to something close to:

mask = ~jnp.any(OD < beta, axis=1)
cov = jnp.cov(OD.T, fweights=mask)
_, eigvecs = jnp.linalg.eigh(cov)

That leaves lines 27-32 to be refactored:

That = ODhat.dot(eigvecs[:, 1:3])

phi = jnp.arctan2(That[:, 1], That[:, 0])

minPhi = jnp.percentile(phi, alpha)
maxPhi = jnp.percentile(phi, 100 - alpha)

Into something like:

Th = OD.dot(eigvecs[:, 1:3])

phi = jnp.arctan2(Th[:, 1], Th[:, 0])

phi = jnp.where(mask, phi, jnp.inf)
pvalid = mask.mean() # proportion that is valid and not masked

minPhi = jnp.percentile(phi, alpha * pvalid)
maxPhi = jnp.percentile(phi, (100 - alpha) * pvalid)

I think those two changes will make the fix complete.
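Putting the two pieces together, a rough self-contained sketch (with random data and the default alpha/beta values; note the integer cast, since fweights expects integer weights):

import jax
import jax.numpy as jnp

@jax.jit
def eigvecs_and_angles(OD, alpha=1.0, beta=0.15):
    # Mask instead of boolean indexing, so all shapes stay static under jit.
    mask = ~jnp.any(OD < beta, axis=1)

    # Weight out masked rows instead of dropping them; fweights must be integer.
    cov = jnp.cov(OD.T, fweights=mask.astype(jnp.int32))
    _, eigvecs = jnp.linalg.eigh(cov)

    Th = OD.dot(eigvecs[:, 1:3])
    phi = jnp.arctan2(Th[:, 1], Th[:, 0])

    # Push masked entries to +inf and rescale the percentiles accordingly.
    phi = jnp.where(mask, phi, jnp.inf)
    pvalid = mask.mean()  # proportion that is valid and not masked
    minPhi = jnp.percentile(phi, alpha * pvalid)
    maxPhi = jnp.percentile(phi, (100 - alpha) * pvalid)
    return eigvecs, minPhi, maxPhi

OD = jax.random.uniform(jax.random.PRNGKey(0), (1000, 3))
eigvecs, minPhi, maxPhi = eigvecs_and_angles(OD)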

@swamidass
Author

Also, the way you are doing jit in the code base is incorrect. You should not do:

    @partial(jax.jit, static_argnums=(0,))
    def normalize(self, I, Io=240, alpha=1, beta=0.15, stains=True):

The problem is that the 'stains' argument is also a static argument, and as written, it will not work correctly with self. This is one of the weird things about Jax, which they explain at length: https://jax.readthedocs.io/en/latest/faq.html?highlight=self#how-to-use-jit-with-methods. You are taking Strategy #2, and the docs explain how that breaks.
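A minimal sketch of Strategy #3 (registering the class as a PyTree) might look like this; the class and fields here are toy placeholders, not the actual torchstain attributes:

import jax
import jax.numpy as jnp
from functools import partial

@jax.tree_util.register_pytree_node_class
class ToyNormalizer:
    def __init__(self, reference):
        self.reference = reference  # dynamic (traced) state

    def tree_flatten(self):
        # children = dynamic values, aux_data = static values (none here)
        return (self.reference,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @partial(jax.jit, static_argnums=(2,))  # 'clip' drives Python control flow, so static
    def normalize(self, I, clip=True):
        out = I / self.reference
        return jnp.clip(out, 0, 255) if clip else out

norm = ToyNormalizer(jnp.asarray(240.0))
result = norm.normalize(jnp.arange(6, dtype=jnp.float32), True)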

@andreped
Collaborator

andreped commented Jan 3, 2023

Incorporated your suggestions and it seems to work. Great! :) Cheers!

Will add a unit test later on to verify that we get "identical" results compared to the other backends.

It's 7 AM in Norway right now, so I think I will head to bed. I just started using Jax today, so there are lots of new things to consider, such as how to use jit with methods.

I will look into Strategy #3 tomorrow, which seems like the appropriate approach, but I'm pushing what I have now to the jax-backend branch for you to test, if you want.

@carloalbertobarbano
Member

JAX support is a good idea, given we already have tf and torch. We can mark it as experimental for now.

@andreped
Collaborator

JAX support is a good idea, given we already have tf and torch. We can mark it as experimental for now.

I have the changes ready that add Jax backend support for Macenko (as discussed above). I will make a PR after all the existing PRs have been merged.

@swamidass
Author

@andreped can you link me to it to review for any issues?

@andreped
Collaborator

andreped commented Jan 18, 2023

@andreped can you link me to it to review for any issues?

The edits are in this branch. I noticed now that I had not yet fixed the methods issue we discussed above; I have been focusing on my own PhD work lately. Note that it is Reinhard we are adding Jax backend support for, not Macenko (yet).

I have a deadline tomorrow, but I can make an attempt to fix that afterwards. Alternatively, if you have time, @swamidass, you can fork my branch and add the final fix there.

We can look into adding Macenko (+ modified Reinhard) jax backend support in a future PR, if that is of interest to you, @swamidass :]

But it might be a good idea to wait until the development branch has been merged into main, as there will likely be several merge conflicts.

@andreped
Collaborator

andreped commented Jan 22, 2023

Added the PyTree fix we discussed in a new commit, @swamidass. See the latest version in the branch here.

At least it runs with the @jax.jit decorator. Added to both normalize and fit methods.

I'm not sure I understood the whole static/dynamic value distinction in the PyTree; for this specific class, it made sense not to have any. What do you think? Note that the class itself does not have any arguments, but maybe that is required for it to work properly? Not sure.
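(A toy sketch of what an argument-free PyTree registration might look like; not the actual torchstain code:)

import jax

@jax.tree_util.register_pytree_node_class
class StatelessNormalizer:
    def tree_flatten(self):
        return (), None  # no dynamic children, no static aux data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls()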

@andreped
Collaborator

andreped commented Jan 23, 2023

Oh, important note! I just added a unit test which is run for each new commit, and I observed that the Jax output differs from the numpy output.

From the look of the CI log, by quite a lot: https://github.com/andreped/torchstain/actions/runs/3982389028/jobs/6826815783

I swear it used to work, at least before we added the three main fixes discussed above to get it working with the @jax.jit decorator. I can see if I can make a gist tomorrow that reproduces the issue, which we can use to debug this further and hopefully finalize this feature.


EDIT: @carloalbertobarbano I believe this feature is not ready for the upcoming new release v1.3.0. It is not a critical feature either, as I believe @swamidass does not need it for his current study.

@andreped
Collaborator

I believe I have fixed it. It both yields output that visually appears similar to the numpy backend output and passes the unit test.

Hence, I made a PR that adds Macenko JAX backend support: #36

Runtime-wise, the current JAX implementation is a lot slower than the alternative backends. There is a long thread here likely explaining why that might be the case. However, it is also likely that further improvements can be made to optimize the JAX backend, but right now I'm happy with having something that runs.

I ran a simple benchmark (single run) which yielded:

backend      numpy   jax     torch   tf
runtime [s]  0.455   2.427   0.201   0.442
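(One caveat when timing JAX, which may or may not apply here: dispatch is asynchronous and the first call includes one-time compilation, so a fairer measurement would look roughly like this toy sketch:)

import time
import jax
import jax.numpy as jnp

@jax.jit
def fn(x):  # placeholder workload, not the normalizer
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((512, 512))
fn(x).block_until_ready()  # warm-up run: excludes one-time compilation

start = time.perf_counter()
fn(x).block_until_ready()  # block, since JAX dispatch is asynchronous
print(time.perf_counter() - start)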

If you have time, @swamidass, it would be great if you could review the implementation in the PR.

@swamidass
Author

Thanks for doing this. Performance issues are important, but not required to solve for the first implementation.

@andreped
Collaborator

andreped commented Feb 6, 2023

Thanks for doing this. Performance issues are important, but not required to solve for the first implementation.

Always happy to contribute :] Let me know when you start testing it further, and if any modifications are made to make it faster. And of course, PRs are always welcome!

When you have time, @carloalbertobarbano, you can make the next release :] After the PRs have been merged into the development branch, and merged with master, of course.
