Skip to content

Commit

Permalink
fix passing of predictive distribution inside update_empirical_prior …
Browse files Browse the repository at this point in the history
…method for mmp and vmp algos
  • Loading branch information
dimarkov committed Jun 21, 2024
1 parent 0575fca commit 264ee7d
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pymdp/jax/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,11 @@ def update_empirical_prior(self, action, qs):

qs_last = jtu.tree_map( lambda x: x[-1], qs)
# this computation of the predictive prior is correct only for fully factorised Bs.
pred = control.compute_expected_state(qs_last, self.B, action, B_dependencies=self.B_dependencies)
if self.inference_algo in ['mmp', 'vmp']:
# in the case of the 'mmp' or 'vmp' we have to use D as prior parameter for infer states
pred = self.D
else:
pred = control.compute_expected_state(qs_last, self.B, action, B_dependencies=self.B_dependencies)

return (pred, qs)

Expand Down

0 comments on commit 264ee7d

Please sign in to comment.