diff --git a/README.rst b/README.rst index 0deacc5..552a90f 100644 --- a/README.rst +++ b/README.rst @@ -70,6 +70,12 @@ You should then be able to run the examples and tutorials in the examples folder .. _PyPI: https://pypi.org/project/SiPANN/ +JAX support +=========== + +SiPANN provides optional support for [JAX](https://github.com/google/jax). You need to have JAX installed in order to let SiPANN use `jax.numpy`(which is faster) instead of `numpy`. For installation instructions, see the [JAX documentation](https://github.com/google/jax#installation). + + References ========== diff --git a/SiPANN/comp.py b/SiPANN/comp.py index 45b99ff..a529161 100644 --- a/SiPANN/comp.py +++ b/SiPANN/comp.py @@ -2,7 +2,10 @@ from abc import ABC, abstractmethod import gdspy -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import pkg_resources import scipy.integrate as integrate import scipy.special as special diff --git a/SiPANN/import_nn.py b/SiPANN/import_nn.py index d4d533c..2d127fa 100644 --- a/SiPANN/import_nn.py +++ b/SiPANN/import_nn.py @@ -1,7 +1,10 @@ import pickle from itertools import combinations_with_replacement as comb_w_r -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import tensorflow as tf diff --git a/SiPANN/nn.py b/SiPANN/nn.py index 73533ab..845cdee 100644 --- a/SiPANN/nn.py +++ b/SiPANN/nn.py @@ -17,7 +17,10 @@ # ---------------------------------------------------------------------------- # # Import libraries # ---------------------------------------------------------------------------- # -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import pkg_resources import skrf as rf from scipy.interpolate import UnivariateSpline diff --git a/SiPANN/scee.py b/SiPANN/scee.py index 33de86c..82e58c7 100644 --- a/SiPANN/scee.py +++ b/SiPANN/scee.py @@ -2,7 +2,10 @@ from abc import ABC, abstractmethod import gdspy -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np import pkg_resources from scipy import special from scipy.integrate import quad diff --git a/SiPANN/scee_int.py b/SiPANN/scee_int.py index 1913823..1003772 100644 --- a/SiPANN/scee_int.py +++ b/SiPANN/scee_int.py @@ -1,4 +1,7 @@ -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np from simphony import Model from simphony.layout import Circuit from simphony.pins import Pin, PinList diff --git a/SiPANN/scee_opt.py b/SiPANN/scee_opt.py index 25f06b1..4a2b9cf 100644 --- a/SiPANN/scee_opt.py +++ b/SiPANN/scee_opt.py @@ -2,7 +2,10 @@ import os import matplotlib.pyplot as plt -import numpy as np +try: + import jax.numpy as np +except ImportError: + import numpy as np from numba import njit, vectorize from numba.extending import get_cython_function_address from scipy import special