JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference #132
Merged
Conversation
…istering the class as a pytree node. @dimarkov see my comments under the @classmethod `tree_unflatten(cls, aux_data, children)`
…broadcasting of `gamma` variable more intuitive/readable
… of jupyter notebook `model_inversion.ipynb`
… output and revised docstring a bit
…rameterization more readable
Agent jax vmap agent class
…'t work in fixing JAX fixed-point iteration :(
… If `None`, they will default to being `all` factors - infer `num_controls` from the last dimension of each `B[f]`, if not given. This needs to be changed from dimension `2` of each B sub-array, because now we allow inter-factor dependencies in `B`
… `A` and `B` arrays
…hey are not provided (`None`)
- fixed D vector normalization error message
…ions, in case those priors exist
…erm coming from parameters
….py unit tests for `test_get_expected_stats_interactions_{X}_factor`
…rs, to avoid numpy deprecation warnings about converting an ndarray with ndim > 0 to a scalar (relevant if using numpy >= 1.25)
…maths.py` library
…ey are computed in `pymdp.jax.algos.get_mmp_messages`
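The numpy deprecation mentioned in the commit above can be illustrated with a generic sketch (the actual call sites in this PR are not shown here): on numpy >= 1.25, implicitly converting a size-1 array with `ndim > 0` to a Python scalar emits a `DeprecationWarning`, and the warning-free alternative is explicit element extraction, e.g. via `.item()`.

```python
import warnings
import numpy as np

arr = np.array([3.0])  # ndim == 1, size 1

with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    try:
        # Deprecated on numpy >= 1.25: implicit ndim > 0 array-to-scalar conversion
        x = float(arr)
    except DeprecationWarning:
        # Explicit, warning-free element extraction
        x = arr.item()

print(x)  # 3.0
```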
conorheins changed the title from "JAX backend for pymdp" to "JAX backend for pymdp and sparse likelihood dependencies" on Jun 6, 2024
conorheins changed the title from "JAX backend for pymdp and sparse likelihood dependencies" to "JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference" on Jun 6, 2024
conorheins changed the title from "JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference" to "JAX backend, sparse likelihood dependencies, inductive inference" on Jun 6, 2024
conorheins changed the title from "JAX backend, sparse likelihood dependencies, inductive inference" to "JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference" on Jun 6, 2024
This pull request introduces a `jax` re-implementation of most of the active inference-related functionality provided by the `pymdp` package. It also introduces new features, including 3 types of inference algorithms (variational message passing, marginal message passing, and online variational filtering) and the ability to simulate active inference agents in parallel using an additional batch dimension that is appended to the leading axes of all parameter tensors, actions, posterior beliefs, and observations.

As an example, whereas a typical `A` array in the numpy backend has no batch dimension, with the new changes a given `A[modality_m]` tensor will have an additional leading batch dimension `N`, which indicates the number of generative models / agents one is parallelizing active inference processes across.
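The batching convention can be sketched with plain numpy as follows. The dimensions here (one observation modality with 5 outcomes, two hidden-state factors, `N = 8` agents) are hypothetical and not taken from the PR itself:

```python
import numpy as np

# Hypothetical dimensions (for illustration only):
num_obs = 5          # outcomes for observation modality m
num_states = (3, 4)  # two hidden-state factors
N = 8                # batch dimension: number of parallel agents / generative models

# Unbatched, numpy-backend style: A[m] maps hidden states to observation likelihoods
A_m = np.ones((num_obs, *num_states)) / num_obs

# Batched, jax-backend style: the same tensor with a leading batch axis of size N
A_m_batched = np.broadcast_to(A_m, (N, num_obs, *num_states))

print(A_m.shape)          # (5, 3, 4)
print(A_m_batched.shape)  # (8, 5, 3, 4)
```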
Most importantly, the `Agent` API has been amended in the following ways:
- `Agent` is now an `equinox.Module`, which means agents can be treated as pytrees. Using the `Agent` class thus requires both `jax` and the `equinox` package to be included in the requirements.
- The methods of `Agent` are `vmap`-decorated, so that they can be used to simulate `N` agents in parallel. This also means the methods are much more functional, with fewer in-place operations on object properties than in the numpy version of `Agent`.
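The general pattern of `vmap`-ping purely functional agent methods over a batch of agent pytrees can be sketched as follows. This is an illustrative toy, not the actual `Agent` class: the `make_agent` / `infer_states` functions and their shapes are invented (the real implementation uses an `equinox.Module`, but any pytree works with `vmap`):

```python
import jax
import jax.numpy as jnp

# A toy "agent" represented as a pytree (here: a dict of parameters).
def make_agent(key, num_obs=5, num_states=3):
    A = jax.random.uniform(key, (num_obs, num_states))
    return {"A": A / A.sum(axis=0, keepdims=True)}  # column-normalized likelihood

# A purely functional "method": takes the agent pytree and data, returns beliefs.
def infer_states(agent, obs_onehot, prior):
    likelihood = agent["A"].T @ obs_onehot  # p(o | s) for the observed outcome
    posterior = likelihood * prior
    return posterior / posterior.sum()

# Build a batch of N agents by stacking their parameters along a leading axis.
N = 4
keys = jax.random.split(jax.random.PRNGKey(0), N)
agents = jax.vmap(make_agent)(keys)  # pytree whose leaves have a leading axis of size N

obs = jax.nn.one_hot(jnp.zeros(N, dtype=jnp.int32), 5)  # each agent observes outcome 0
prior = jnp.ones((N, 3)) / 3

# vmap maps over the batch axis of the agent pytree, observations, and priors at once.
posteriors = jax.vmap(infer_states)(agents, obs, prior)
print(posteriors.shape)  # (4, 3)
```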
Other features:
- … `numpy` backend (thanks to @tverbele's fork)
- … `jax`, based on the implementation introduced in @tverbele's fork
- … `numpyro` and @dimarkov's `pybefit` package. This required adding `numpyro`, `optax`, and `arviz` to the requirements of the package. See the Model Inversion Notebook for a worked example of fitting the parameters of a T-Maze-navigating agent to simulated pairs of (action, observation) data. Warning: parameter estimation is still buggy and not thoroughly tested. We find it is currently error-prone when fitting active inference agents equipped with advanced features like learning of `A` and `B`. Sometimes we see `nan`-valued gradients when using `numpyro`'s `svi` routine, so this remains a WIP feature.
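One simple way to diagnose `nan`- or `inf`-valued gradients of the kind described above is to inspect the gradient pytree directly. The sketch below is generic plain-`jax` code; the loss function is invented and unrelated to the actual `numpyro` SVI objective:

```python
import jax
import jax.numpy as jnp

# A made-up loss whose gradient blows up near zero, mimicking an unstable objective:
# d/dtheta sqrt(theta) = 0.5 / sqrt(theta), which is inf at theta = 0.
def loss(theta):
    return jnp.sqrt(theta).sum()

def has_bad_grads(grads):
    """True if any leaf of the gradient pytree contains nan or inf."""
    leaves = jax.tree_util.tree_leaves(grads)
    return any(bool(jnp.any(~jnp.isfinite(g))) for g in leaves)

good = jax.grad(loss)(jnp.array([1.0, 4.0]))
bad = jax.grad(loss)(jnp.array([0.0, 4.0]))  # gradient is inf at theta = 0

print(has_bad_grads(good))  # False
print(has_bad_grads(bad))   # True
```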