Skip to content

Commit

Permalink
Merge pull request #131 from ran-wei-verses/obj_array_130
Browse files Browse the repository at this point in the history
make safe obj_array_from_list to address issue #130 with test
  • Loading branch information
conorheins authored May 8, 2024
2 parents 9aa1e2c + 5731428 commit 6c23ab9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pymdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,10 @@ def obj_array_from_list(list_input):
"""
Takes a list of `numpy.ndarray` and converts them to a `numpy.ndarray` of `dtype = object`
"""
return np.array(list_input, dtype = object)
arr = obj_array(len(list_input))
for i, item in enumerate(list_input):
arr[i] = item
return arr

def process_observation_seq(obs_seq, n_modalities, n_observations):
"""
Expand Down
28 changes: 28 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" Agent Class
__author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein
"""

import unittest

import numpy as np

from pymdp import utils

class TestUtils(unittest.TestCase):
def test_obj_array_from_list(self):
"""
Tests `obj_array_from_list`
"""
# make arrays with same leading dimensions. naive method trigger numpy broadcasting error.
arrs = [np.zeros((3, 6)), np.zeros((3, 4, 5))]
obs_arrs = utils.obj_array_from_list(arrs)

self.assertTrue(all([np.all(a == b) for a, b in zip(arrs, obs_arrs)]))

if __name__ == "__main__":
unittest.main()

0 comments on commit 6c23ab9

Please sign in to comment.