TRPO on Atari #415

Open
ling-pan opened this issue Mar 18, 2019 · 2 comments

@ling-pan

Hi,

I am wondering whether chainerrl supports running TRPO on Atari. I tried to do so by following the code for training PPO on Atari, but I run into the following error:

Traceback (most recent call last):
  File "train_trpo_ale.py", line 187, in <module>
    main()
  File "train_trpo_ale.py", line 182, in main
    eval_interval=args.eval_interval,
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/experiments/train_agent.py", line 174, in train_agent_with_evaluation
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/experiments/train_agent.py", line 59, in train_agent
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/agents/trpo.py", line 521, in act_and_train
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainerrl-0.6.0-py3.6.egg/chainerrl/misc/batch_states.py", line 23, in batch_states
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainer-6.0.0b3-py3.6.egg/chainer/dataset/convert.py", line 58, in wrap_call
    return func(*args, **kwargs)
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainer-6.0.0b3-py3.6.egg/chainer/dataset/convert.py", line 249, in concat_examples
    return to_device(device, _concat_arrays(batch, padding))
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/chainer-6.0.0b3-py3.6.egg/chainer/dataset/convert.py", line 256, in _concat_arrays
    arrays = numpy.asarray(arrays)
  File "/anaconda3/envs/chainerrl/lib/python3.6/site-packages/numpy/core/numeric.py", line 538, in asarray
    return array(a, dtype, copy=False, order=order)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'LazyFrames'

It seems that PPO can handle LazyFrames, but I don't know why it fails with TRPO.
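
For context, here is a minimal sketch of the kind of conversion involved (my understanding of LazyFrames, not ChainerRL internals): a LazyFrames observation from the DeepMind-style Atari wrappers only becomes a real numpy array once it is explicitly converted, which is what a feature extractor like the one in the PPO example does before observations are batched.

import numpy as np

def phi(obs):
    # Materialize the stacked LazyFrames into a float32 array and scale
    # uint8 pixels to [0, 1]. Batching raw LazyFrames objects without a
    # conversion like this seems to be what triggers the TypeError above.
    return np.asarray(obs, dtype=np.float32) / 255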

Thanks!

@muupan (Member) commented Mar 18, 2019

I think it should work, but it is not tested. Can you share train_trpo_ale.py?

@ling-pan (Author)

It's here:

from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import
from builtins import * # NOQA
from future import standard_library
standard_library.install_aliases() # NOQA

import argparse
import logging
import os

import chainer
from chainer import functions as F
import gym
import gym.wrappers
import numpy as np

import chainerrl
from chainerrl import links
import cupy

from chainerrl.wrappers import atari_wrappers

def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4',
                        help='Gym Env ID')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU device ID. Set to -1 to use CPUs only.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--outdir', type=str, default='results',
                        help='Directory path to save output files.'
                             ' If it does not exist, it will be created.')
    parser.add_argument('--steps', type=int, default=10 ** 6,
                        help='Total time steps for training.')
    parser.add_argument('--max-frames', type=int, default=30 * 60 * 60,  # 30 minutes with 60 fps
                        help='Maximum number of frames for each episode.')
    parser.add_argument('--eval-interval', type=int, default=10 ** 5,
                        help='Interval between evaluation phases in steps.')
    parser.add_argument('--eval-n-runs', type=int, default=10,
                        help='Number of episodes ran in an evaluation phase')
    parser.add_argument('--demo', action='store_true', default=False,
                        help='Run demo episodes, not training')
    parser.add_argument('--load', type=str, default='',
                        help='Directory path to load a saved agent data from'
                             ' if it is a non-empty string.')
    parser.add_argument('--logger-level', type=int, default=logging.INFO,
                        help='Level of the root logger.')
    parser.add_argument('--render', action='store_true', default=False,
                        help='Render the env')
    parser.add_argument('--monitor', action='store_true',
                        help='Monitor the env by gym.wrappers.Monitor. Videos and additional log will be saved.')
    parser.add_argument('--trpo-update-interval', type=int, default=5000,
                        help='Interval steps of TRPO iterations.')
    args = parser.parse_args()

    logging.basicConfig(level=args.logger_level)

    # Set random seed
    chainerrl.misc.set_random_seed(args.seed, gpus=(args.gpu,))

    args.outdir = chainerrl.experiments.prepare_output_dir(args, args.outdir)
    print('Output files are saved in {}'.format(args.outdir))

    def make_env(test):
        # Use different random seeds for train and test envs
        env_seed = 2 ** 32 - args.seed if test else args.seed
        env = atari_wrappers.wrap_deepmind(
            atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
            episode_life=not test,
            clip_rewards=not test)
        env.seed(int(env_seed))
        if args.monitor:
            env = gym.wrappers.Monitor(
                env, args.outdir,
                mode='evaluation' if test else 'training')
        if args.render:
            env = chainerrl.wrappers.Render(env)
        return env

    env = make_env(test=False)
    eval_env = make_env(test=True)
    obs_space = env.observation_space
    action_space = env.action_space
    print('Observation space:', obs_space)
    print('Action space:', action_space)

    if not isinstance(obs_space, gym.spaces.Box):
        print("""This example only supports gym.spaces.Box observation spaces. To apply it to other observation spaces, use a custom phi function that convert an observation to numpy.ndarray of numpy.float32.""")  # NOQA
        return

    # Normalize observations based on their empirical mean and variance
    obs_normalizer = chainerrl.links.EmpiricalNormalization(obs_space.low.size)

    # Use a Softmax policy for discrete action spaces
    policy = chainerrl.policies.FCSoftmaxPolicy(
        obs_space.low.size,
        action_space.n,
        n_hidden_channels=64,
        n_hidden_layers=2,
        last_wscale=0.01,
        nonlinearity=F.tanh,
    )

    # Use a value function to reduce variance
    vf = chainerrl.v_functions.FCVFunction(
        obs_space.low.size,
        n_hidden_channels=64,
        n_hidden_layers=2,
        last_wscale=0.01,
        nonlinearity=F.tanh,
    )

    if args.gpu >= 0:
        chainer.cuda.get_device_from_id(args.gpu).use()
        policy.to_gpu(args.gpu)
        vf.to_gpu(args.gpu)
        obs_normalizer.to_gpu(args.gpu)

    # TRPO's policy is optimized via CG and line search, so it doesn't require a chainer.Optimizer. Only the value function needs it.
    vf_opt = chainer.optimizers.Adam()
    vf_opt.setup(vf)

    # Draw the computational graph and save it in the output directory.
    if policy.xp == cupy:
        formatted_obs_space_low = cupy.array(obs_space.low)
    else:
        formatted_obs_space_low = obs_space.low

    fake_obs = chainer.Variable(
        policy.xp.zeros_like(formatted_obs_space_low, dtype=policy.xp.float32)[None], name='observation'
    )
    chainerrl.misc.draw_computational_graph([policy(fake_obs)], os.path.join(args.outdir, 'policy'))
    chainerrl.misc.draw_computational_graph([vf(fake_obs)], os.path.join(args.outdir, 'vf'))

    # Feature extractor
    def phi(x):
        return np.asarray(x, dtype=np.float32) / 255

    # Hyperparameters in http://arxiv.org/abs/1709.06560
    agent = chainerrl.agents.TRPO(
        policy=policy,
        vf=vf,
        vf_optimizer=vf_opt,
        obs_normalizer=obs_normalizer,
        update_interval=args.trpo_update_interval,
        conjugate_gradient_max_iter=20,
        conjugate_gradient_damping=1e-1,
        gamma=0.995,
        lambd=0.97,
        vf_epochs=5,
        entropy_coef=0,
    )

    if args.load:
        agent.load(args.load)

    if args.demo:
        env = make_env(test=True)
        eval_stats = chainerrl.experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
        )
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'], eval_stats['stdev']))
    else:
        chainerrl.experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            eval_env=eval_env,
            outdir=args.outdir,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            eval_interval=args.eval_interval,
        )

if __name__ == '__main__':
    main()
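
One thing I notice while pasting this (just a guess based on the traceback going through batch_states, not a confirmed fix): phi is defined in the script but never passed to chainerrl.agents.TRPO, so observations would reach batch_states as raw LazyFrames. Assuming TRPO accepts a phi argument the way the PPO example's agent does, the constructor call would look like:

# Hypothetical variant of the constructor call above, assuming TRPO takes
# a `phi` argument like the PPO example's agent does.
agent = chainerrl.agents.TRPO(
    policy=policy,
    vf=vf,
    vf_optimizer=vf_opt,
    obs_normalizer=obs_normalizer,
    phi=phi,  # convert LazyFrames observations to float32 arrays before batching
    update_interval=args.trpo_update_interval,
    conjugate_gradient_max_iter=20,
    conjugate_gradient_damping=1e-1,
    gamma=0.995,
    lambd=0.97,
    vf_epochs=5,
    entropy_coef=0,
)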
