Update ppo_pettingzoo_ma_atari.py #408

Open: wants to merge 21 commits into base: master.

Changes shown from 7 of 21 commits.

Commits
d9b9b11 Update ppo_pettingzoo_ma_atari.py (elliottower, Jul 12, 2023)
edc79d6 Pre-commit (elliottower, Jul 13, 2023)
d39da5e Update PZ version (elliottower, Jul 13, 2023)
2b2dfce Update Super (elliottower, Jul 13, 2023)
6d37313 Run pre-commit --hook-stage manual --all-files (elliottower, Jul 13, 2023)
0168986 run poetry lock --no-update to fix inconsistencies with versions (elliottower, Jul 13, 2023)
b7bffe9 re-run pre-commit with --hook-stage manual (elliottower, Jul 13, 2023)
2c76bb1 Change torch.maximum to torch.logical_or for dones (elliottower, Jul 17, 2023)
025f491 Use np.logical_or instead of torch (allows subtraction) (elliottower, Jul 18, 2023)
09f7a7f Merge remote-tracking branch 'upstream/master' into patch-1 (elliottower, Jan 18, 2024)
16e0764 Finish merge with upstream master (elliottower, Jan 18, 2024)
928b7b3 Fix SuperSuit to most recent version (elliottower, Jan 18, 2024)
d7a2aa2 Fix SuperSuit version in poetry lockfile and tinyscaler in pettingzoo… (elliottower, Jan 18, 2024)
d77cca0 Fix pettingzoo-requirements export (pre-commit hooks) (elliottower, Jan 18, 2024)
afba4e8 Test updating pettingzoo to new version 1.24.3 (elliottower, Jan 18, 2024)
8671154 Update ma_atari to match regular atari (tyro, minor code style changes) (elliottower, Jan 18, 2024)
d2cf1a5 pre-commit (elliottower, Jan 18, 2024)
981bc63 Revert accidentally changed files (zoo and ipynb, which randomly seem… (elliottower, Jan 18, 2024)
454364d Revert ipynb change (elliottower, Jan 18, 2024)
06473b2 Update dead pettingzoo.ml links to Farama foundation links (elliottower, Jan 18, 2024)
1b725cf Update to newly release SuperSuit 3.9.2 (minor bugfixes but best to k… (elliottower, Jan 18, 2024)
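Two of the commits above (2c76bb1 and 025f491) concern how the per-env termination and truncation flags are merged into a single done signal before the GAE bootstrap. Below is a minimal sketch of the difference; the flag values are illustrative and none of this is the exact PR code.

```python
import numpy as np
import torch

# Illustrative done flags for 4 parallel envs (1.0 = flag set); in the
# script these come from the vectorized env step.
termination = np.array([1.0, 0.0, 0.0, 1.0], dtype=np.float32)
truncation = np.array([0.0, 1.0, 0.0, 0.0], dtype=np.float32)

# torch.maximum on 0/1 float tensors acts as an element-wise OR and keeps
# the float dtype, so `1.0 - next_done` works directly.
next_done_torch = torch.maximum(torch.tensor(termination), torch.tensor(truncation))

# np.logical_or returns a boolean array; numpy upcasts it to float in
# expressions like `1.0 - next_done`, whereas PyTorch rejects the `-`
# operator on bool tensors, which is presumably what the commit message
# "(allows subtraction)" refers to.
next_done_numpy = np.logical_or(termination, truncation)

print(next_done_torch)        # tensor([1., 1., 0., 1.])
print(1.0 - next_done_numpy)  # [0. 0. 1. 0.]
```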
28 changes: 19 additions & 9 deletions cleanrl/ppo_pettingzoo_ma_atari.py
@@ -6,7 +6,7 @@
import time
from distutils.util import strtobool

import gym
import gymnasium as gym
import numpy as np
import supersuit as ss
import torch
@@ -156,11 +156,10 @@ def get_action_and_value(self, x, action=None):
env = ss.frame_stack_v1(env, 4)
env = ss.agent_indicator_v0(env, type_only=False)
env = ss.pettingzoo_env_to_vec_env_v1(env)
envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gym")
envs = ss.concat_vec_envs_v1(env, args.num_envs // 2, num_cpus=0, base_class="gymnasium")
envs.single_observation_space = envs.observation_space
envs.single_action_space = envs.action_space
envs.is_vector_env = True
envs = gym.wrappers.RecordEpisodeStatistics(envs)
if args.capture_video:
envs = gym.wrappers.RecordVideo(envs, f"videos/{run_name}")
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
@@ -173,14 +172,17 @@ def get_action_and_value(self, x, action=None):
actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
terminations = torch.zeros((args.num_steps, args.num_envs)).to(device)
truncations = torch.zeros((args.num_steps, args.num_envs)).to(device)
values = torch.zeros((args.num_steps, args.num_envs)).to(device)

# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
next_obs = torch.Tensor(envs.reset()).to(device)
next_done = torch.zeros(args.num_envs).to(device)
next_obs, info = envs.reset(seed=args.seed)
next_obs = torch.Tensor(next_obs).to(device)
next_termination = torch.zeros(args.num_envs).to(device)
next_truncation = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size

for update in range(1, num_updates + 1):
@@ -193,7 +195,8 @@ def get_action_and_value(self, x, action=None):
for step in range(0, args.num_steps):
global_step += 1 * args.num_envs
obs[step] = next_obs
dones[step] = next_done
terminations[step] = next_termination
truncations[step] = next_truncation

# ALGO LOGIC: action logic
with torch.no_grad():
@@ -203,10 +206,15 @@ def get_action_and_value(self, x, action=None):
logprobs[step] = logprob

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, reward, done, info = envs.step(action.cpu().numpy())
next_obs, reward, termination, truncation, info = envs.step(action.cpu().numpy())
rewards[step] = torch.tensor(reward).to(device).view(-1)
next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
next_obs, next_termination, next_truncation = (
torch.Tensor(next_obs).to(device),
torch.Tensor(termination).to(device),
torch.Tensor(truncation).to(device),
)

# TODO: fix this
for idx, item in enumerate(info):
player_idx = idx % 2
if "episode" in item.keys():
@@ -219,6 +227,8 @@ def get_action_and_value(self, x, action=None):
next_value = agent.get_value(next_obs).reshape(1, -1)
advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
next_done = torch.maximum(next_termination, next_truncation)
dones = torch.maximum(terminations, truncations)
for t in reversed(range(args.num_steps)):
if t == args.num_steps - 1:
nextnonterminal = 1.0 - next_done
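The hunks above track the Gymnasium API migration: `reset()` now returns `(obs, info)`, `step()` returns five values, and the separate termination/truncation buffers are folded back into `dones` just before GAE. The following is a toy, self-contained sketch of that API shape only; it uses a plain CartPole vector env rather than the PR's SuperSuit-wrapped PettingZoo env, so `CartPole-v1`, the env count, and the seed are stand-ins, not part of the PR.

```python
import gymnasium as gym
import numpy as np
import torch

# A plain CartPole vector env stands in for the SuperSuit-wrapped
# PettingZoo env; only the shape of the Gymnasium API matters here.
envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])

# reset() now returns (obs, info) and takes a seed keyword argument.
next_obs, info = envs.reset(seed=1)
next_obs = torch.Tensor(next_obs)

# step() now returns five values: terminated and truncated replace done.
actions = np.array([envs.single_action_space.sample() for _ in range(2)])
next_obs, reward, termination, truncation, info = envs.step(actions)

# Before the GAE bootstrap, the diff folds both flag arrays back into a
# single done signal per env (as 0/1 floats, torch.maximum acts as an OR).
next_done = torch.maximum(
    torch.Tensor(termination.astype(np.float32)),
    torch.Tensor(truncation.astype(np.float32)),
)
print(next_done)
```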