Type hints #293

Draft · wants to merge 1 commit into base: master

20 changes: 13 additions & 7 deletions cleanrl/dqn.py
@@ -4,6 +4,7 @@
import random
import time
from distutils.util import strtobool
+from typing import Callable

import gym
import numpy as np
@@ -15,7 +16,7 @@
from torch.utils.tensorboard import SummaryWriter


-def parse_args():
+def parse_args() -> argparse.Namespace:
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
@@ -65,8 +66,8 @@ def parse_args():
    return args


-def make_env(env_id, seed, idx, capture_video, run_name):
-    def thunk():
+def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str) -> Callable[[], gym.Env]:
+    def thunk() -> gym.Env:
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
@@ -82,7 +83,10 @@ def thunk():

# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
-    def __init__(self, env):
+

Owner (on the added blank line above): Could you remove this space?

+    network: nn.Sequential
+
+    def __init__(self, env: gym.vector.SyncVectorEnv):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),
@@ -92,11 +96,11 @@ def __init__(self, env):
            nn.Linear(84, env.single_action_space.n),
        )

-    def forward(self, x):
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:

Collaborator Author: I've used FloatTensor here, but we should just use torch.Tensor if type hints are pursued further.

Owner: This is fine.

        return self.network(x)
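
To make the suggestion in the thread above concrete, here is a minimal, self-contained sketch of the plain torch.Tensor variant. It is not part of this PR; obs_dim and n_actions are made-up stand-ins for the sizes the script derives from the env spaces.

import torch
import torch.nn as nn


class QNetwork(nn.Module):
    network: nn.Sequential

    def __init__(self, obs_dim: int = 4, n_actions: int = 2):
        # obs_dim / n_actions stand in for the sizes taken from
        # env.single_observation_space and env.single_action_space in the script.
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(obs_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.Tensor covers all dtypes and devices, so it is the broader and
        # more conventional annotation than the legacy torch.FloatTensor alias.
        return self.network(x)


q_net = QNetwork()
print(q_net(torch.zeros(1, 4)).shape)  # torch.Size([1, 2])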


-def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

@@ -131,7 +135,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
-    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
+    envs = gym.vector.SyncVectorEnv(
+        [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]
+    ) # type:ignore[abstract]
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    q_network = QNetwork(envs).to(device)
40 changes: 28 additions & 12 deletions cleanrl/ppo.py
@@ -5,6 +5,7 @@
import random
import time
from distutils.util import strtobool
+from typing import Callable, Optional, cast

import gym
import numpy as np
@@ -15,7 +16,7 @@
from torch.utils.tensorboard import SummaryWriter


-def parse_args():
+def parse_args() -> argparse.Namespace:
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
@@ -77,8 +78,8 @@ def parse_args():
    return args


-def make_env(env_id, seed, idx, capture_video, run_name):
-    def thunk():
+def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str) -> Callable[[], gym.Env]:
+    def thunk() -> gym.Env:
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
@@ -92,14 +93,18 @@ def thunk():
    return thunk


-def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+def layer_init(layer: nn.Linear, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Agent(nn.Module):
-    def __init__(self, envs):
+
+    critic: nn.Sequential
+    actor: nn.Sequential
+
+    def __init__(self, envs: gym.vector.SyncVectorEnv):
        super().__init__()
        self.critic = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
@@ -116,10 +121,12 @@ def __init__(self, envs):
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
        )

-    def get_value(self, x):
+    def get_value(self, x: torch.Tensor) -> torch.Tensor:
        return self.critic(x)

-    def get_action_and_value(self, x, action=None):
+    def get_action_and_value(
+        self, x: torch.Tensor, action: Optional[torch.Tensor] = None
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
@@ -159,15 +166,24 @@ def get_action_and_value(self, x, action=None):
    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
-    )
+    ) # type:ignore[abstract]

Collaborator Author (@timoklein, Oct 14, 2022): SyncVectorEnv inherits from VectorEnv, which inherits from Env. For older gym versions (I'm currently on 0.23.1), Env is an ABC with an abstract render method that is not overridden by any of the vector envs. Since this is fixed in the newest gym release and is not our issue, I'm ignoring it here.
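
For illustration, a gym-free sketch of how a mypy [abstract] error like this can arise while the code still runs. The class names only mirror the hierarchy described in the comment and are not gym's actual source.

from abc import abstractmethod


class Env:
    # Declared abstract, but without ABCMeta nothing enforces this at runtime,
    # so the subclass below can still be instantiated.
    @abstractmethod
    def render(self) -> None:
        raise NotImplementedError


class SyncVectorEnv(Env):
    pass  # render() is not overridden


# Without the ignore comment, mypy reports something like:
#   error: Cannot instantiate abstract class "SyncVectorEnv" with abstract
#   attribute "render"  [abstract]
# The targeted ignore silences exactly that error code and nothing else.
envs = SyncVectorEnv()  # type: ignore[abstract]
print(type(envs).__name__)  # SyncVectorEnv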

    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+    # Handling gym shapes being Optionals (variant 1)
+    # Personally I'd prefer the asserts
+    assert isinstance(envs.single_observation_space.shape, tuple), "shape of observation space must be defined"
+    assert isinstance(envs.single_action_space.shape, tuple), "shape of action space must be defined"
+
+    # Handling gym shapes being Optionals (variant 2)
+    # One could also cast inside each call, but in my eyes that's not conducive to readability
+    obs_space_shape = cast(tuple[int, ...], envs.single_observation_space.shape)
+    action_space_shape = cast(tuple[int, ...], envs.single_action_space.shape)

Comment on lines +171 to +179

Collaborator Author: Gym spaces can in theory return None shapes. Mypy will complain about this when concatenating the shape tuples later.

  • Option 1 is to use a cast, either once here or every time the spaces are accessed. I don't think that's very readable.
  • Option 2 is to assert that the space shapes are tuples. Doing it once here fixes all errors for the rest of the code. Since there's already an assert in this place anyway, I think this is the better option.

Owner: Option 1 is preferable.
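
To make the two options concrete, a small standalone sketch, assuming Python 3.9+ for the built-in tuple generics used in the PR; the shape value is made up and only mimics the Optional-typed shape attribute of a gym space.

from typing import Optional, cast

# Stand-in for envs.single_observation_space.shape, which gym types as Optional.
shape: Optional[tuple[int, ...]] = (4,)

# Option 1 - cast at the point of use: concise, but repeated at every access
# unless the result is stored once.
obs_dims = (128, 1) + cast(tuple[int, ...], shape)

# Option 2 - assert once: mypy narrows Optional[tuple[int, ...]] to
# tuple[int, ...] for every use that follows.
assert isinstance(shape, tuple), "shape of observation space must be defined"
act_dims = (128, 1) + shape

print(obs_dims, act_dims)  # (128, 1, 4) (128, 1, 4)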


    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
-    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+    obs = torch.zeros((args.num_steps, args.num_envs) + obs_space_shape).to(device)
+    actions = torch.zeros((args.num_steps, args.num_envs) + action_space_shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
@@ -228,9 +244,9 @@ def get_action_and_value(self, x, action=None):
returns = advantages + values

        # flatten the batch
-        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
+        b_obs = obs.reshape((-1,) + obs_space_shape)
        b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+        b_actions = actions.reshape((-1,) + action_space_shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)