Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make tests run against local code, fix syntax warning breaking coverage, and add code coverage generation to github actions #140

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,22 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install coverage
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest test
- name: Run tests with pytest, calculating coverage
run: coverage run --source=pymdp -m pytest test/
- name: Generate coverage HTML report
run: coverage html
# expect actions/upload-artifact@v4 to fail when run locally with `act`
- name: Upload coverage HTML report for pymdp as a build artifact
uses: actions/upload-artifact@v4
with:
name: pymdp-${{ matrix.python-version }}--coverage-report
path: htmlcov/
retention-days: 30
- name: Print coverage report to console
run: coverage report
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
__pycache__
.DS_Store
.ipynb_checkpoints
.idea
.rope*
.vscode/
.ipynb_checkpoints/
.pytest_cache
env/
pymdp.egg-info
inferactively_pymdp.egg-info
htmlcov
.coverage
4 changes: 2 additions & 2 deletions pymdp/jax/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

def select_probs(positions, matrix, dependency_list, actions=None):
args = tuple(p for i, p in enumerate(positions) if i in dependency_list)
args += () if actions is None else (actions,)
args = args + (actions,) if actions is not None else args

return matrix[..., *args]
return matrix[(...,) + args]

def cat_sample(key, p):
a = jnp.arange(p.shape[-1])
Expand Down
5 changes: 5 additions & 0 deletions test/test_SPM_validation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os
import sys
import unittest

import numpy as np
from scipy.io import loadmat

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.agent import Agent
from pymdp.utils import to_obj_array, build_xn_vn_array, get_model_dimensions, convert_observation_array
from pymdp.maths import dirichlet_log_evidence
Expand Down
6 changes: 5 additions & 1 deletion test/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,14 @@
"""

import os
import sys
import unittest

import numpy as np
from copy import deepcopy

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.agent import Agent
from pymdp import utils, maths
Expand Down
8 changes: 5 additions & 3 deletions test/test_agent_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
"""

import os
import sys
import unittest

import numpy as np
import jax.numpy as jnp
from jax import vmap, nn, random
import jax.tree_util as jtu

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.jax.maths import compute_log_likelihood_single_modality
from pymdp.jax.utils import norm_dist
from equinox import Module
from typing import Any, List

class TestAgentJax(unittest.TestCase):

Expand Down
5 changes: 5 additions & 0 deletions test/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
"""

import os
import sys
import unittest

import numpy as np

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp import utils, maths
from pymdp import control

Expand Down
7 changes: 5 additions & 2 deletions test/test_control_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@
"""

import os
import sys
import unittest
import pytest

import numpy as np
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

import pymdp.jax.control as ctl_jax
import pymdp.control as ctl_np

from pymdp.jax.maths import factor_dot
from pymdp import utils

cfg = {"source_key": 0, "num_models": 4}
Expand Down
7 changes: 6 additions & 1 deletion test/test_demos.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import unittest
import numpy as np
import os
import sys
import copy
import seaborn as sns
import matplotlib.pyplot as plt

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.agent import Agent
from pymdp.utils import plot_beliefs, plot_likelihood
from pymdp import utils, maths, default_models
from pymdp import control
from pymdp.envs import TMazeEnv, TMazeEnvNullOutcome
Expand Down
5 changes: 5 additions & 0 deletions test/test_fpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
"""

import os
import sys
import unittest

import numpy as np

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp import utils, maths
from pymdp.algos import run_vanilla_fpi, run_vanilla_fpi_factorized

Expand Down
5 changes: 5 additions & 0 deletions test/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
"""

import os
import sys
import unittest

import numpy as np

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp import utils, maths
from pymdp import inference

Expand Down
7 changes: 6 additions & 1 deletion test/test_inference_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,19 @@
"""

import os
import sys
import unittest

import numpy as np
import jax.numpy as jnp

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.jax.algos import run_vanilla_fpi as fpi_jax
from pymdp.algos import run_vanilla_fpi as fpi_numpy
from pymdp import utils, maths
from pymdp import utils

class TestInferenceJax(unittest.TestCase):

Expand Down
6 changes: 6 additions & 0 deletions test/test_learning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import os, sys
import unittest

import numpy as np

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp import utils, maths, learning

from copy import deepcopy
Expand Down
5 changes: 5 additions & 0 deletions test/test_learning_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
"""

import os
import sys
import unittest

import numpy as np
import jax.numpy as jnp
import jax.tree_util as jtu

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.learning import update_obs_likelihood_dirichlet as update_pA_numpy
from pymdp.learning import update_obs_likelihood_dirichlet_factorized as update_pA_numpy_factorized
from pymdp.jax.learning import update_obs_likelihood_dirichlet as update_pA_jax
Expand Down
12 changes: 7 additions & 5 deletions test/test_message_passing_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import os
import sys
import unittest
from functools import partial

Expand All @@ -15,16 +16,17 @@
from jax import vmap, nn
from jax import random as jr

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.jax.algos import run_vanilla_fpi as fpi_jax
from pymdp.jax.algos import run_factorized_fpi as fpi_jax_factorized
from pymdp.jax.algos import update_variational_filtering as ovf_jax
from pymdp.algos import run_vanilla_fpi as fpi_numpy
from pymdp.algos import run_mmp as mmp_numpy
from pymdp.jax.algos import run_mmp as mmp_jax
from pymdp.jax.algos import run_vmp as vmp_jax
from pymdp import utils, maths
from pymdp import utils

from typing import Any, List, Dict
from typing import List, Dict


def make_model_configs(source_seed=0, num_models=4) -> Dict:
Expand Down
5 changes: 5 additions & 0 deletions test/test_mmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,16 @@
"""

import os
import sys
import unittest

import numpy as np
from scipy.io import loadmat

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.utils import get_model_dimensions, convert_observation_array
from pymdp.algos import run_mmp
from pymdp.maths import get_joint_likelihood_seq
Expand Down
6 changes: 5 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
__author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein

"""

import os, sys
import unittest

import numpy as np

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp import utils

class TestUtils(unittest.TestCase):
Expand Down
7 changes: 6 additions & 1 deletion test/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import os
import sys
import unittest
from pathlib import Path

# import the library directly from local source (rather than relying on the library being installed)
# insert the dependency so it's prioritized over an installed variant
sys.path.insert(0, os.path.abspath('../pymdp'))

from pymdp.utils import Dimensions, get_model_dimensions_from_labels

class TestWrappers(unittest.TestCase):
Expand Down