Clarification Needed on Implementing Action Masking in DQN with preprocess_fn in Collector #1159

Open
NeoBerekov opened this issue Jun 7, 2024 · 0 comments

Tag request: please add the "documentation request" tag.

Hi Tianshou Team,

I am currently working on a gymnasium DQN project with action masking and noticed that in Tianshou, all action masks need to be added to the Batch as a "mask" item so that DQNPolicy can handle the masking automatically. To achieve this, I tried passing a preprocess_fn hook when constructing the Collector class, as described in the documentation. However, I found the documentation a bit unclear and couldn't find any relevant examples in the referenced file (test/base/test_collector.py).
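To make clear what I mean by "handle the masking", here is a minimal NumPy sketch of masked greedy action selection. It is only an illustration of the general technique, not Tianshou's actual implementation; the function name `masked_greedy_action` and the sample values are hypothetical.

```python
import numpy as np


def masked_greedy_action(q_values, mask):
    """Return the greedy action per row, never choosing a masked-out action.

    q_values: array of shape (batch, n_actions)
    mask:     array of the same shape, 1/True for valid actions, 0/False otherwise
    """
    q_values = np.asarray(q_values, dtype=np.float64)
    mask = np.asarray(mask, dtype=bool)
    # Invalid actions get -inf so that argmax can never select them.
    masked_q = np.where(mask, q_values, -np.inf)
    return masked_q.argmax(axis=-1)


# Hypothetical Q-values for two environments with three actions each.
q = np.array([[1.0, 5.0, 3.0],
              [2.0, 0.5, 4.0]])
mask = np.array([[1, 0, 1],
                 [1, 1, 0]])
actions = masked_greedy_action(q, mask)
```

In the first row the highest Q-value belongs to an invalid action, so the next-best valid action is chosen instead.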

The documentation states:

The "preprocess_fn" is a function called before the data has been added to the buffer with batch format. It will receive only "obs" and "env_id" when the collector resets the environment, and will receive the keys "obs_next", "rew", "terminated", "truncated", "info", "policy" and "env_id" in a normal env step. Alternatively, it may also accept the keys "obs_next", "rew", "done", "info", "policy" and "env_id". It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in "test/base/test_collector.py".
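My reading of that description, as a sketch: the hook is called with keyword arguments, and it returns only the keys it changed. The code below is based solely on the quoted text, so the exact way the Collector calls the hook is an assumption, and `to_float32` is a hypothetical placeholder transformation.

```python
import numpy as np


def to_float32(obs):
    # Hypothetical transformation, used only to show where changes would go.
    return np.asarray(obs, dtype=np.float32)


def preprocess_fn(**kwargs):
    """Return a dict holding only the keys that were modified."""
    modified = {}
    if "obs" in kwargs:
        # Reset call: per the quoted docs, only "obs" and "env_id" are passed.
        modified["obs"] = to_float32(kwargs["obs"])
    if "obs_next" in kwargs:
        # Step call: "obs_next", "rew", "terminated", "truncated", "info", ...
        modified["obs_next"] = to_float32(kwargs["obs_next"])
    return modified


# Simulated calls for two environments (the values are made up).
reset_out = preprocess_fn(obs=np.zeros((2, 4), dtype=np.int16),
                          env_id=np.arange(2))
step_out = preprocess_fn(obs_next=np.ones((2, 4), dtype=np.int16),
                         rew=np.zeros(2),
                         terminated=np.array([False, False]),
                         truncated=np.array([False, False]),
                         info=[{}, {}],
                         policy=None,
                         env_id=np.arange(2))
```

Accepting `**kwargs` lets the same function serve both the reset call and the step call, since the two receive different sets of keys.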

In my current DQN project, the observation space is a dictionary that includes an action mask, which complicates things further:

self.observation_space = gym.spaces.Dict({
    'local_obs': gym.spaces.Box(low=-1, high=5000, shape=(9, local_obs_window, local_obs_window), dtype=np.int16),
    'global_obs': gym.spaces.Box(low=0, high=5000, shape=(9, map_size[0], map_size[1]), dtype=np.int16),
    'action_mask': gym.spaces.MultiBinary(11)
})
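One layout I am considering is to split each observation into an "obs" part and a "mask" part before it reaches the policy. The sketch below assumes that the policy looks for the mask under a "mask" key next to the remaining observation data under "obs"; that layout is my assumption and is exactly what I would like confirmed. The helper name `split_observation` and the window and map sizes are hypothetical.

```python
import numpy as np


def split_observation(observation):
    """Move 'action_mask' out of the observation dict into a boolean 'mask'."""
    mask = np.asarray(observation["action_mask"], dtype=bool)
    # Everything except the mask stays together as the actual observation.
    features = {key: value for key, value in observation.items()
                if key != "action_mask"}
    return {"obs": features, "mask": mask}


# A made-up observation matching the space above (5x5 window, 8x8 map).
raw = {
    "local_obs": np.zeros((9, 5, 5), dtype=np.int16),
    "global_obs": np.zeros((9, 8, 8), dtype=np.int16),
    "action_mask": np.array([1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1], dtype=np.int8),
}
restructured = split_observation(raw)
```

A helper like this could be called from the `observation` method of a `gymnasium.ObservationWrapper` subclass, or from a preprocess_fn, depending on which approach is recommended.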

Do you have any plans to improve the documentation or provide examples regarding this issue? Or should the mask be added to the Batch in a different way other than using preprocess_fn, which I might have overlooked?

Here are the versions of the relevant libraries I am using:
Tianshou version: 0.5.0
Gymnasium version: 0.29.1

Thank you for your assistance.

Best regards
