From 573142825936409755369e7b4d9714a779b5fdb6 Mon Sep 17 00:00:00 2001 From: Ran Wei Date: Wed, 8 May 2024 15:00:00 -0500 Subject: [PATCH] make safe obj_array_from_list to address issue #130 with test --- pymdp/utils.py | 5 ++++- test/test_utils.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 test/test_utils.py diff --git a/pymdp/utils.py b/pymdp/utils.py index 005bd4ef..df4b3a1e 100644 --- a/pymdp/utils.py +++ b/pymdp/utils.py @@ -259,7 +259,10 @@ def obj_array_from_list(list_input): """ Takes a list of `numpy.ndarray` and converts them to a `numpy.ndarray` of `dtype = object` """ - return np.array(list_input, dtype = object) + arr = obj_array(len(list_input)) + for i, item in enumerate(list_input): + arr[i] = item + return arr def process_observation_seq(obs_seq, n_modalities, n_observations): """ diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..033dd8f6 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" Agent Class + +__author__: Conor Heins, Alexander Tschantz, Daphne Demekas, Brennan Klein + +""" + +import unittest + +import numpy as np + +from pymdp import utils + +class TestUtils(unittest.TestCase): + def test_obj_array_from_list(self): + """ + Tests `obj_array_from_list` + """ + # make arrays with same leading dimensions. naive method trigger numpy broadcasting error. + arrs = [np.zeros((3, 6)), np.zeros((3, 4, 5))] + obs_arrs = utils.obj_array_from_list(arrs) + + self.assertTrue(all([np.all(a == b) for a, b in zip(arrs, obs_arrs)])) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file