
No minibatch for computation of logp_old in PPOPolicy #1164

Closed
7 of 9 tasks
jvasso opened this issue Jul 1, 2024 · 1 comment · Fixed by #1168
Labels
performance issues Slow execution or poor-quality results

Comments

jvasso (Contributor) commented Jul 1, 2024

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request
    • design request (i.e. "X should be changed to Y.")
  • I have visited the source website
  • I have searched through the issue tracker for duplicates
  • I have mentioned version numbers, operating system and environment, where applicable:

I have noticed that in the implementation of PPOPolicy, the old log probabilities logp_old are computed without minibatching:

with torch.no_grad():
    batch.logp_old = self(batch).dist.log_prob(batch.act)

This makes the algorithm unusable when the batch is too large to fit in memory, since there is no way to control the memory usage via batch_size.
I suggest adding minibatch support:

logp_old = []
with torch.no_grad():
    for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
        logp_old.append(self(minibatch).dist.log_prob(minibatch.act))
    batch.logp_old = torch.cat(logp_old, dim=0).flatten()

I am using Tianshou version 1.0.0.
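Independent of Tianshou's Batch API, the chunking pattern proposed above can be sketched in plain Python. This is a minimal illustration only: the helper names (`split_indices`, `log_probs_minibatched`) and the toy log-prob computation are hypothetical, not part of Tianshou.

```python
import math

def split_indices(n, size):
    """Yield (start, stop) pairs covering range(n) in sequential,
    non-overlapping chunks of at most `size` elements, mirroring the
    idea of batch.split(size, shuffle=False)."""
    for start in range(0, n, size):
        yield start, min(start + size, n)

def log_probs_minibatched(probs, size):
    """Compute log(p) for every entry, `size` elements at a time,
    instead of materializing one huge intermediate all at once."""
    out = []
    for start, stop in split_indices(len(probs), size):
        out.extend(math.log(p) for p in probs[start:stop])
    return out
```

Note that the sketch differs from the proposal in one detail: Tianshou's `merge_last=True` folds a final short chunk into the previous minibatch, whereas `split_indices` simply emits it as a smaller last chunk; the concatenated result is the same either way.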

MischaPanch (Collaborator) commented:

You're right, wanna make a PR for that? Otherwise I can also make one myself

@MischaPanch MischaPanch added the performance issues Slow execution or poor-quality results label Jul 6, 2024
MischaPanch added a commit that referenced this issue Jul 20, 2024
Closes #1164

In PPOPolicy, the method `process_fn()` now computes `logp_old` in
minibatch instead of all at once.

---------

Co-authored-by: Michael Panchenko <[email protected]>
2 participants