
JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference #132

Merged
merged 273 commits into master from agent_jax on Jun 6, 2024

Conversation

@conorheins (Collaborator) commented on Jun 6, 2024

This pull request introduces a JAX re-implementation of most of the active-inference functionality provided by the pymdp package. It also adds new features, including three inference algorithms (variational message passing, marginal message passing, and online variational filtering) and the ability to simulate active inference agents in parallel via an additional batch dimension prepended as the leading axis of all parameter tensors, actions, posterior beliefs, and observations.

As an example, a typical A array in the numpy backend has the following shape

>>> A[modality_m].shape
(num_obs[modality_m], num_states[0], num_states[1], ..., num_states[-1])

With the new changes, a given A[modality_m] tensor will have shape

>>> A[modality_m].shape
(N, num_obs[modality_m], num_states[0], num_states[1], ..., num_states[-1])

where N is an additional batch dimension that indicates the number of generative models / agents one is parallelizing active inference processes across.
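For concreteness, here is a minimal sketch (with made-up modality and factor sizes) of how such a batched A tensor could be built with jax.numpy; this is only an illustration of the shape convention, not code from the package:

```python
import jax.numpy as jnp

# Hypothetical sizes: one modality with 3 outcomes, two hidden state factors with 2 and 4 levels
num_obs, num_states = [3], [2, 4]
N = 16  # number of agents simulated in parallel

# Single-agent likelihood tensor for modality 0: shape (3, 2, 4), uniform for simplicity
A_single = jnp.ones((num_obs[0], *num_states)) / num_obs[0]

# Batched version: shape (16, 3, 2, 4), i.e. one copy of the generative model per agent
A_batched = jnp.broadcast_to(A_single, (N, *A_single.shape))
assert A_batched.shape == (N, num_obs[0], *num_states)
```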

Most importantly, the Agent API has been amended in the following ways:

  • the Agent class is now an equinox.Module, which means agents can be treated as pytrees. Using the Agent class thus requires both jax and equinox to be included in the requirements.
  • methods are now vmap-decorated, so that the methods of an Agent can be used to simulate N agents in parallel. This also makes the methods more functional, with fewer in-place operations on object attributes than in the numpy version of Agent (a rough sketch of this pytree/vmap pattern follows below)
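As a rough illustration of the pytree/vmap pattern (this is not the actual Agent class; the field and method names below are made up):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class ToyAgent(eqx.Module):
    # Illustrative stand-in for the real Agent: an eqx.Module is a pytree, so its
    # array fields can carry a leading batch axis of size N (one slice per agent).
    A: jax.Array  # hypothetical likelihood, shape (num_obs, num_states) per agent

    def infer_states(self, obs):
        # Single-agent computation: normalized posterior over states for a one-hot observation
        qs = self.A.T @ obs
        return qs / qs.sum()

N, num_obs, num_states = 8, 3, 4
key = jax.random.PRNGKey(0)

# Build N agents at once by giving the field a leading batch axis of size N
A_batched = jax.nn.softmax(jax.random.normal(key, (N, num_obs, num_states)), axis=1)
agents = ToyAgent(A=A_batched)

# vmap maps jointly over the agent pytree (axis 0 of its leaves) and the batched observations
obs = jnp.tile(jnp.array([1.0, 0.0, 0.0]), (N, 1))
qs_all = jax.vmap(lambda agent, o: agent.infer_states(o))(agents, obs)
print(qs_all.shape)  # (8, 4)
```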

Other features:

  • allows subsets of hidden state factors to influence subsets of observation modalities (see this branch; a shape-level sketch follows after this list)
  • interactions among hidden state factors in the transition dynamics (see this branch)
  • sophisticated inference and inductive planning available in numpy backend (thanks to @tverbele's fork)
  • inductive planning in jax based on the implementation introduced in @tverbele's fork
  • preliminary ability to perform parameter estimation using numpyro and @dimarkov's pybefit package. This required adding numpyro, optax and arviz to the requirements of the package. See the Model Inversion Notebook for a worked example of fitting the parameters of a T-Maze-navigating agent to simulated pairs of (action, observation) data; a generic numpyro SVI sketch also follows after this list. Warning: parameter estimation is still buggy and not thoroughly tested. We find it is currently error-prone when fitting active inference agents equipped with advanced features like learning of A and B, and we sometimes see nan-valued gradients when using numpyro's svi routine, so this remains a WIP feature.
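To illustrate the sparse likelihood-dependency idea at the level of array shapes (the dependency-list name used below is illustrative, not necessarily the exact keyword argument exposed by the package):

```python
import jax.numpy as jnp

# Two hidden state factors (sizes 2 and 4) and two observation modalities (sizes 3 and 5)
num_states = [2, 4]
num_obs = [3, 5]

# Illustrative dependency spec: modality 0 depends only on factor 0,
# modality 1 depends on both factors.
A_deps = [[0], [0, 1]]

# Each A[m] only carries lagging state dimensions for the factors it depends on,
# rather than the full (num_obs[m], 2, 4) tensor for every modality.
A = [
    jnp.ones((num_obs[m], *[num_states[f] for f in deps])) / num_obs[m]
    for m, deps in enumerate(A_deps)
]
print([a.shape for a in A])  # [(3, 2), (5, 2, 4)]
```

And a generic sketch of the parameter-estimation pattern with numpyro's SVI and optax, using a toy model in which a single precision parameter gamma stands in for the agent's forward pass; this is not the pybefit/pymdp fitting code itself:

```python
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import optax

# Toy stand-in for per-trial expected free energies over 3 policies; in the real
# workflow these would come from running the agent's planning step on the data.
G = jnp.tile(jnp.array([0.1, 0.9, 0.5]), (10, 1))        # (num_trials, num_policies)
observed_actions = jnp.zeros(10, dtype=jnp.int32)         # simulated choices, (num_trials,)

def model(G, actions):
    # Illustrative prior over the policy precision gamma
    gamma = numpyro.sample("gamma", dist.Gamma(2.0, 1.0))
    probs = jax.nn.softmax(-gamma * G, axis=-1)            # policy probabilities per trial
    with numpyro.plate("trials", G.shape[0]):
        numpyro.sample("actions", dist.Categorical(probs=probs), obs=actions)

guide = AutoNormal(model)
svi = SVI(model, guide, optax.adam(1e-2), loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, G, observed_actions)
print(svi_result.params)
```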

dimarkov and others added 30 commits September 19, 2022 18:15
…istering the class as a pytree node. @dimarkov see my comments under the @classmethod `tree_unflatten(cls, aux_data, children)`
…broadcasting of `gamma` variable more intuitive/readable
… of jupyter notebook `model_inversion.ipynb`
…'t work in fixing JAX fixed-point iteration :(
… If `None`, they will default to being `all` factors

- infer `num_controls` from the last dimension of each `B[f]`, if not given. This needs to be changed from dimension `2` of each B sub-array, because now we allow inter-factor dependencies in `B`
- fixed D vector normalization error message
@conorheins conorheins changed the title JAX backend for pymdp JAX backend for pymdp and sparse likelihood dependencies Jun 6, 2024
@conorheins conorheins changed the title JAX backend for pymdp and sparse likelihood dependencies JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference Jun 6, 2024
@conorheins conorheins changed the title JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference JAX backend, sparse likelihood dependencies, inductive inference Jun 6, 2024
@conorheins conorheins changed the title JAX backend, sparse likelihood dependencies, inductive inference JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference Jun 6, 2024
@conorheins conorheins merged commit bfc1346 into master Jun 6, 2024
3 checks passed
@conorheins conorheins deleted the agent_jax branch September 20, 2024 07:39