Commit

removed infer_policies_old from Agent class in numpy backend
conorheins committed Jun 6, 2024
1 parent 51c799f commit 540e855
Showing 1 changed file with 0 additions and 68 deletions.
68 changes: 0 additions & 68 deletions pymdp/agent.py
@@ -604,74 +604,6 @@ def _infer_states_test(self, observation, distr_obs=False):
return qs, xn, vn
else:
return qs

def infer_policies_old(self):
"""
Perform policy inference by optimizing a posterior (categorical) distribution over policies.
This distribution is computed as the softmax of ``G * gamma + lnE``, where ``G`` is the negative expected
free energy of each policy, ``gamma`` is the policy precision, and ``lnE`` is the log prior probability of policies.
This function returns the posterior over policies as well as the negative expected free energy of each policy.
Returns
----------
q_pi: 1D ``numpy.ndarray``
Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.
G: 1D ``numpy.ndarray``
Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
"""

if self.inference_algo == "VANILLA":
q_pi, G = control.update_posterior_policies(
self.qs,
self.A,
self.B,
self.C,
self.policies,
self.use_utility,
self.use_states_info_gain,
self.use_param_info_gain,
self.pA,
self.pB,
E=self.E,
I=self.I,
gamma=self.gamma
)
elif self.inference_algo == "MMP":
if self.factorized:
raise NotImplementedError("Factorized inference not implemented for MMP")

if self.sophisticated:
raise NotImplementedError("Sophisticated inference not implemented for MMP")


future_qs_seq = self.get_future_qs()

q_pi, G = control.update_posterior_policies_full(
future_qs_seq,
self.A,
self.B,
self.C,
self.policies,
self.use_utility,
self.use_states_info_gain,
self.use_param_info_gain,
self.latest_belief,
self.pA,
self.pB,
F=self.F,
E=self.E,
I=self.I,
gamma=self.gamma
)

if hasattr(self, "q_pi_hist"):
self.q_pi_hist.append(q_pi)
if len(self.q_pi_hist) > self.inference_horizon:
self.q_pi_hist = self.q_pi_hist[-(self.inference_horizon-1):]

self.q_pi = q_pi
self.G = G
return q_pi, G

def infer_policies(self):
"""
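The removed docstring describes the policy posterior as the softmax of ``G * gamma + lnE``. The sketch below is a rough illustration of that single step, not code from this commit: the standalone ``softmax`` helper, the three-policy array values, and the uniform prior are all assumptions made for the example.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    x = x - x.max()
    e = np.exp(x)
    return e / e.sum()

# Hypothetical values for three policies (illustrative only):
G = np.array([-3.2, -1.1, -2.5])   # negative expected free energy per policy
gamma = 16.0                       # policy precision
lnE = np.log(np.ones(3) / 3.0)     # log prior over policies (uniform here)

q_pi = softmax(G * gamma + lnE)    # posterior beliefs over policies
print(q_pi)                        # sums to 1; concentrates on the policy with the highest G
```

In the deleted method, ``control.update_posterior_policies`` (VANILLA) and ``control.update_posterior_policies_full`` (MMP) return this ``q_pi`` together with ``G``.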
