From e473bf5eed9c6d2692e8567e01c989ec7069c0ae Mon Sep 17 00:00:00 2001 From: SkandanC Date: Tue, 6 Sep 2022 12:27:07 -0400 Subject: [PATCH 1/3] Add jax support --- SiPANN/import_nn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/SiPANN/import_nn.py b/SiPANN/import_nn.py index d4d533c..0ea4817 100644 --- a/SiPANN/import_nn.py +++ b/SiPANN/import_nn.py @@ -1,7 +1,11 @@ import pickle from itertools import combinations_with_replacement as comb_w_r -import numpy as np +try: + import jax.numpy as np + print("JAX imported") +except ImportError: + import numpy as np import tensorflow as tf From f88f32c6481fa5d5c0506bf31b651641f8c5432a Mon Sep 17 00:00:00 2001 From: SkandanC Date: Tue, 6 Sep 2022 12:27:18 -0400 Subject: [PATCH 2/3] Add JAX support text --- README.rst | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.rst b/README.rst index 0deacc5..552a90f 100644 --- a/README.rst +++ b/README.rst @@ -70,6 +70,12 @@ You should then be able to run the examples and tutorials in the examples folder .. _PyPI: https://pypi.org/project/SiPANN/ +JAX support +=========== + +SiPANN provides optional support for [JAX](https://github.com/google/jax). You need to have JAX installed in order to let SiPANN use `jax.numpy`(which is faster) instead of `numpy`. For installation instructions, see the [JAX documentation](https://github.com/google/jax#installation). + + References ========== From e9675ae3f10798634f4ad1fa919ed83ca132750d Mon Sep 17 00:00:00 2001 From: SkandanC Date: Tue, 6 Sep 2022 12:34:21 -0400 Subject: [PATCH 3/3] Add jax support --- SiPANN/comp.py | 5 ++++- SiPANN/import_nn.py | 1 - SiPANN/nn.py | 5 ++++- SiPANN/scee.py | 5 ++++- SiPANN/scee_int.py | 5 ++++- SiPANN/scee_opt.py | 5 ++++- 6 files changed, 20 insertions(+), 6 deletions(-) diff --git a/SiPANN/comp.py b/SiPANN/comp.py index 45b99ff..a529161 100644 --- a/SiPANN/comp.py +++ b/SiPANN/comp.py @@ -2,7 +2,10 @@ from abc import ABC, abstractmethod import gdspy -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import pkg_resources import scipy.integrate as integrate import scipy.special as special diff --git a/SiPANN/import_nn.py b/SiPANN/import_nn.py index 0ea4817..2d127fa 100644 --- a/SiPANN/import_nn.py +++ b/SiPANN/import_nn.py @@ -3,7 +3,6 @@ try: import jax.numpy as np - print("JAX imported") except ImportError: import numpy as np import tensorflow as tf diff --git a/SiPANN/nn.py b/SiPANN/nn.py index 73533ab..845cdee 100644 --- a/SiPANN/nn.py +++ b/SiPANN/nn.py @@ -17,7 +17,10 @@ # ---------------------------------------------------------------------------- # # Import libraries # ---------------------------------------------------------------------------- # -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import pkg_resources import skrf as rf from scipy.interpolate import UnivariateSpline diff --git a/SiPANN/scee.py b/SiPANN/scee.py index 33de86c..82e58c7 100644 --- a/SiPANN/scee.py +++ b/SiPANN/scee.py @@ -2,7 +2,10 @@ from abc import ABC, abstractmethod import gdspy -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import pkg_resources from scipy import special from scipy.integrate import quad diff --git a/SiPANN/scee_int.py b/SiPANN/scee_int.py index 1913823..1003772 100644 --- a/SiPANN/scee_int.py +++ b/SiPANN/scee_int.py @@ -1,4 +1,7 @@ -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np from simphony import Model from simphony.layout import Circuit from simphony.pins import Pin, PinList diff --git a/SiPANN/scee_opt.py b/SiPANN/scee_opt.py index 25f06b1..4a2b9cf 100644 --- a/SiPANN/scee_opt.py +++ b/SiPANN/scee_opt.py @@ -2,7 +2,10 @@ import os import matplotlib.pyplot as plt -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np from numba import njit, vectorize from numba.extending import get_cython_function_address from scipy import special