Skip to content

Commit

Permalink
keep only last state in the environment
Browse files Browse the repository at this point in the history
  • Loading branch information
dimarkov committed Jul 3, 2024
1 parent fe4313d commit 4fa006f
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions pymdp/jax/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ def step(self, key: PRNGKeyArray, actions: Optional[Array] = None):

keys = list(jr.split(key_state, len(state_probs)))
new_states = jtu.tree_map(cat_sample, keys, state_probs)

states.append(new_states)

else:
new_states = states[-1]

Expand All @@ -76,4 +73,4 @@ def step(self, key: PRNGKeyArray, actions: Optional[Array] = None):
keys = list(jr.split(key_obs, len(obs_probs)))
new_obs = jtu.tree_map(cat_sample, keys, obs_probs)

return new_obs, tree_at(lambda x: (x.states), self, states)
return new_obs, tree_at(lambda x: (x.states), self, [new_states])

0 comments on commit 4fa006f

Please sign in to comment.