Skip to content

Commit

Permalink
Merge pull request #67 from jcapriot/prism_cython
Browse files Browse the repository at this point in the history
Prism cython
  • Loading branch information
jcapriot committed Oct 30, 2023
2 parents 5408065 + 3cf25a7 commit 0d2a7da
Show file tree
Hide file tree
Showing 8 changed files with 396 additions and 581 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test_with_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ jobs:
matplotlib
jupyter
utm
numba
pytest
pytest-cov
sphinx
Expand Down Expand Up @@ -87,6 +88,7 @@ jobs:
matplotlib
jupyter
utm
numba
pytest
pytest-cov
sphinx
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ docs/api/generated/*

# Jupyter
*.ipynb

#Cython generated files
geoana/kernels/_extensions/rTE.cpp
geoana/kernels/_extensions/potential_field_prism.c
geoana/kernels/_extensions/potential_field_prism_api.h

# setuptools_scm
geoana/version.py
54 changes: 54 additions & 0 deletions geoana/kernels/_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
try:
# register numba jitable versions of the prism functions
# if numba is available (and this module is installed).
from numba.extending import (
overload,
get_cython_function_address
)
from numba import types
import ctypes

from .potential_field_prism import (
prism_f,
prism_fz,
prism_fzz,
prism_fzx,
prism_fzy,
prism_fzzz,
prism_fxxy,
prism_fxxz,
prism_fxyz,
)
funcs = [
prism_f,
prism_fz,
prism_fzz,
prism_fzx,
prism_fzy,
prism_fzzz,
prism_fxxy,
prism_fxxz,
prism_fxyz,
]

def _numba_register_prism_func(prism_func):
module = 'geoana.kernels._extensions.potential_field_prism'
name = prism_func.__name__

func_address = get_cython_function_address(module, name)
func_type = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double, ctypes.c_double)
c_func = func_type(func_address)

@overload(prism_func)
def numba_func(x, y, z):
if isinstance(x, types.Float):
if isinstance(y, types.Float):
if isinstance(z, types.Float):
def f(x, y, z):
return c_func(x, y, z)
return f
for func in funcs:
_numba_register_prism_func(func)

except ImportError as err:
pass
Loading

0 comments on commit 0d2a7da

Please sign in to comment.