
Caching for UPSequentialSimulator._get_applicable_actions() #572

Open
dhrchan opened this issue Feb 22, 2024 · 3 comments

@dhrchan
dhrchan commented Feb 22, 2024

I want to implement an MCTS planning algorithm, which makes frequent calls to UPSequentialSimulator._get_applicable_actions(). At each step of an MCTS rollout, the set of applicable actions must be generated so that an action can be sampled. When the space of grounded actions is large, _get_applicable_actions() becomes expensive to compute. With 100 grounded actions (arising from 3 lifted actions), I find that a single depth-100 MCTS rollout takes up to 1 second, with that time dominated by calls to _get_applicable_actions().

I fixed this by adding the @cache method decorator from functools to UPSequentialSimulator._get_applicable_actions(), which yielded an 80x speedup for the MCTS rollouts. Does this use of caching violate any assumptions made by UPSequentialSimulator? Are there other intended ways to speed up calls to _get_applicable_actions()?
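For reference, the memoization pattern looks roughly like this. This is a toy sketch with hypothetical stand-ins for the simulator and state, not the real unified-planning classes; note that @cache requires the state argument to be hashable, and that this toy method returns a tuple rather than a generator:

```python
from functools import cache

class ToySimulator:
    """Hypothetical stand-in for UPSequentialSimulator."""

    def __init__(self, actions):
        self._actions = tuple(actions)
        self.calls = 0  # counts real (uncached) computations

    @cache  # memoizes per (self, state); state must be hashable
    def _get_applicable_actions(self, state):
        self.calls += 1
        # toy applicability test: an action is applicable if it is in the state
        return tuple(a for a in self._actions if a in state)

sim = ToySimulator(["up", "down", "left"])
state = frozenset({"up", "left"})

sim._get_applicable_actions(state)
sim._get_applicable_actions(state)  # cache hit: no recomputation
# sim.calls == 1
```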

@alvalentini
Member

Hi @dhrchan! The _get_applicable_actions method has only one parameter, state, which is an immutable object, so the caching doesn't violate any assumptions.

However, the State class doesn't implement the __eq__ method, so the caching will only work when the method receives the very same object.
Is it correct that you are calling the _get_applicable_actions method several times with the same State object in a single MCTS rollout?

@dhrchan
Author

dhrchan commented Feb 29, 2024

Thank you for your reply! I suppose that if there are cycles in my MCTS rollout, repeated states won't be the same State object and will therefore cause cache misses. I think I can fix this by implementing __eq__ and __hash__ methods for State, as #555 pointed out.
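Something along these lines, as a rough sketch. HashableState and its field are illustrative only, not the real UP State class; a frozen dataclass auto-generates value-based __eq__ and __hash__, so logically equal states hit the same cache entry even when they are distinct objects:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class HashableState:
    """Illustrative value-based state (the real State stores fluent valuations)."""
    assignments: frozenset  # frozenset of (fluent, value) pairs

# two distinct objects representing the same logical state
s1 = HashableState(frozenset({("at", "room1"), ("holding", "key")}))
s2 = HashableState(frozenset({("holding", "key"), ("at", "room1")}))

assert s1 == s2             # value equality, not identity
assert hash(s1) == hash(s2) # equal objects must hash equally for caching
assert s1 is not s2         # yet they are different objects
```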

I also found another issue with using the caching decorator: since _get_applicable_actions returns a generator, it can't be cached the way a list can with functools.cache. The cache stores the generator object itself, so every call after the first returns the same generator, already advanced to wherever the previous consumer stopped, and eventually an exhausted, empty iterator. So the 80x speedup I observed is not actually achievable this way.
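The pitfall can be reproduced with nothing but the standard library:

```python
from functools import cache

@cache
def numbers():
    # functools.cache stores the generator object, not the values it yields
    yield from range(3)

first = list(numbers())   # consumes the single cached generator
second = list(numbers())  # same generator is returned, now exhausted

print(first, second)  # [0, 1, 2] []
```

One workaround is to materialize the result before caching, e.g. have the cached function return tuple(...) instead of yielding.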

@alvalentini
Member

alvalentini commented Mar 1, 2024

To solve the problem with the generator, the caching decorator can be applied to the _is_applicable method instead: the computation time of _get_applicable_actions should be dominated by calls to _is_applicable, so most of the benefit is preserved while the generator itself stays uncached.
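A toy sketch of that arrangement (TinySimulator and its names are illustrative, not the real UP API): the boolean result is memoized per (state, action), while _get_applicable_actions still returns a fresh generator on every call.

```python
from functools import cache

class TinySimulator:
    """Hypothetical stand-in for UPSequentialSimulator."""

    def __init__(self, actions):
        self._actions = tuple(actions)
        self.applicability_checks = 0  # counts real (uncached) checks

    @cache  # memoizes the boolean per (self, state, action)
    def _is_applicable(self, state, action):
        self.applicability_checks += 1
        return action in state  # toy applicability test

    def _get_applicable_actions(self, state):
        # remains a generator, so it is always fresh; only the
        # per-(state, action) applicability result is cached
        return (a for a in self._actions if self._is_applicable(state, a))

sim = TinySimulator(["a", "b", "c"])
state = frozenset({"a", "c"})  # hashable state, required by @cache

first = list(sim._get_applicable_actions(state))
second = list(sim._get_applicable_actions(state))
# first == second == ["a", "c"]; the second pass hits the cache,
# so only 3 real applicability checks are ever performed
```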
