
replace tensorflow by jax #23

Open
joamatab opened this issue Jun 15, 2022 · 3 comments

Comments

@joamatab
Contributor

TensorFlow is a heavy dependency (>500 MB); it would be nice to replace it with JAX.

JAX also has neural network support (e.g. through libraries such as Flax).

@jaspreetj
@flaport
@sequoiap
@SkandanC
@AustP

@SkandanC
Collaborator

SkandanC commented Sep 6, 2022

I have started working on this. The biggest issue, as I see it, is that SiPANN uses pre-saved TensorFlow MetaGraphs to load its neural network models. I am not sure how to replace that with JAX yet, but I will figure it out eventually.

@flaport

flaport commented Sep 6, 2022

I think a better approach is to write the neural network in JAX/Flax and just save out the weights, then create a loader/saver to read/write the weights from/to a file.

Saving the whole graph as a binary blob such as a saved MetaGraph is not transparent enough for an open source package in my opinion.
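A minimal sketch of the "define the network in code, store only the weights" idea, written in plain JAX (no Flax) to keep dependencies small. The tanh MLP architecture, the layer sizes, and the `W{i}`/`b{i}` key naming in the `.npz` file are all hypothetical illustrations, not SiPANN's actual models or file format. The point is that the saved file contains nothing but named NumPy arrays, which anyone can inspect with `np.load`.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def init_mlp(key, sizes):
    """Create (W, b) pairs for a fully connected network (hypothetical architecture)."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        W = jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params


def mlp(params, x):
    """Forward pass: tanh hidden layers, linear output layer."""
    for W, b in params[:-1]:
        x = jnp.tanh(x @ W + b)
    W, b = params[-1]
    return x @ W + b


def save_weights(path, params):
    """Write every layer's weights and biases as plain named arrays in an .npz file."""
    arrays = {}
    for i, (W, b) in enumerate(params):
        arrays[f"W{i}"] = np.asarray(W)
        arrays[f"b{i}"] = np.asarray(b)
    np.savez(path, **arrays)


def load_weights(path):
    """Rebuild the (W, b) list from an .npz file written by save_weights."""
    with np.load(path) as data:
        n_layers = len(data.files) // 2
        return [
            (jnp.asarray(data[f"W{i}"]), jnp.asarray(data[f"b{i}"]))
            for i in range(n_layers)
        ]


# Usage: round-trip the weights through a temporary file.
params = init_mlp(jax.random.PRNGKey(0), [3, 16, 2])
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.npz")
    save_weights(path, params)
    restored = load_weights(path)
```

Because the architecture lives in version-controlled Python and the file holds only arrays, a reviewer can read exactly what the model computes, which addresses the transparency concern with binary MetaGraphs.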

@SkandanC
Collaborator

SkandanC commented Sep 8, 2022

Added support for jax.numpy #31
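To illustrate what `jax.numpy` support enables, here is a hypothetical function (not taken from SiPANN or PR #31) written with `jnp` instead of `np`. It uses the textbook coupled-mode result for a lossless, phase-matched directional coupler, where the power remaining in the through port is cos²(κL). Because it is written with `jax.numpy`, the same code can be JIT-compiled with `jax.jit` and differentiated with `jax.grad`, e.g. for gradient-based design.

```python
import jax
import jax.numpy as jnp


def through_power(kappa, length):
    """Illustrative only: through-port power of an ideal directional coupler, cos^2(kappa * L)."""
    return jnp.cos(kappa * length) ** 2


# The jnp-based function can be compiled...
through_power_jit = jax.jit(through_power)

# ...and differentiated with respect to the coupling length (argument index 1).
d_power_d_length = jax.grad(through_power, argnums=1)

kappa = 0.01   # coupling coefficient in 1/um (illustrative value)
length = 50.0  # coupling length in um (illustrative value)
power = through_power_jit(kappa, length)
slope = d_power_d_length(kappa, length)
```

The analytic derivative, -κ·sin(2κL), matches what `jax.grad` computes, without any hand-written gradient code.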
