Proposed fixes to control.py and utils.py #56

Open
wants to merge 5 commits into
base: master
6 changes: 3 additions & 3 deletions pymdp/control.py
@@ -492,9 +492,9 @@ def sample_action(q_pi, policies, num_controls, action_selection="deterministic"

     # weight each action according to its integrated posterior probability over policies and timesteps
     for pol_idx, policy in enumerate(policies):
-        for t in range(policy.shape[0]):
-            for factor_i, action_i in enumerate(policy[t, :]):
-                action_marginals[factor_i][action_i] += q_pi[pol_idx]
+        for factor_i, action_i in enumerate(policy[0, :]):
+            # to get the marginals we just want to add up the actions at time 0
+            action_marginals[factor_i][action_i] += q_pi[pol_idx]

     selected_policy = np.zeros(num_factors)
     for factor_i in range(num_factors):
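
The effect of the change above can be illustrated with a small standalone sketch. It is not the pymdp implementation: the helper names `first_step_marginals` and `all_steps_marginals` are hypothetical, plain NumPy arrays stand in for pymdp's object arrays, and the toy policies and `q_pi` values are made up for illustration. It contrasts summing policy probabilities over every timestep (the removed loop) with summing only over the action at time 0 (the added loop).

```python
import numpy as np

def first_step_marginals(q_pi, policies, num_controls):
    """Hypothetical helper mirroring the added loop: only the action at time 0 counts."""
    marginals = [np.zeros(n) for n in num_controls]
    for pol_idx, policy in enumerate(policies):
        for factor_i, action_i in enumerate(policy[0, :]):
            marginals[factor_i][int(action_i)] += q_pi[pol_idx]
    return marginals

def all_steps_marginals(q_pi, policies, num_controls):
    """Hypothetical helper mirroring the removed loop: every timestep contributes."""
    marginals = [np.zeros(n) for n in num_controls]
    for pol_idx, policy in enumerate(policies):
        for t in range(policy.shape[0]):
            for factor_i, action_i in enumerate(policy[t, :]):
                marginals[factor_i][int(action_i)] += q_pi[pol_idx]
    return marginals

# Toy setup (assumed): three two-step policies over one control factor with two actions.
# Each policy has shape (timesteps, factors).
policies = [np.array([[0], [0]]), np.array([[0], [1]]), np.array([[1], [1]])]
q_pi = np.array([0.5, 0.3, 0.2])
num_controls = [2]

first = first_step_marginals(q_pi, policies, num_controls)
every = all_steps_marginals(q_pi, policies, num_controls)

print(first[0])  # mass on each first action; sums to 1 when q_pi does
print(every[0])  # sums to the policy length, and later actions shift the weights
```

In this toy case the first-step version already sums to one, while the all-timesteps version mixes in actions that would only be taken later, which appears to be the behaviour the pull request is trying to avoid.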
2 changes: 1 addition & 1 deletion pymdp/utils.py
@@ -13,7 +13,7 @@
 import itertools

 def sample(probabilities):
-    sample_onehot = np.random.multinomial(1, probabilities.squeeze())
+    sample_onehot = np.random.multinomial(1, probabilities)
     return np.where(sample_onehot == 1)[0][0]

 def sample_obj_array(arr):
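
The pull request does not state why `.squeeze()` is dropped. One plausible reading, offered here as an assumption, is the single-outcome case: squeezing a length-1 array yields a zero-dimensional array, which is no longer a sequence of probabilities. The sketch below copies the patched `sample` and checks that it handles such an input; the test arrays are made up for illustration.

```python
import numpy as np

def sample(probabilities):
    # Patched version from the diff: the array is passed through without squeezing.
    sample_onehot = np.random.multinomial(1, probabilities)
    return np.where(sample_onehot == 1)[0][0]

# A distribution over a single outcome (assumed example input).
single = np.array([1.0])

# squeeze() collapses the only axis, leaving a 0-d array with no length.
squeezed_shape = single.squeeze().shape
print(squeezed_shape)

# Without squeeze, the one-element distribution can be sampled directly.
single_idx = sample(single)

# A degenerate two-outcome distribution always returns the certain outcome.
certain_idx = sample(np.array([0.0, 1.0]))
print(single_idx, certain_idx)
```

A side effect worth checking in review: with the squeeze removed, inputs with extra singleton axes are no longer flattened inside `sample`, so callers would need to pass one-dimensional arrays.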