Inferring Behavior-Specific Context Improves Zero-Shot Generalization in Reinforcement Learning

This repository contains the code for the paper "Inferring Behavior-Specific Context Improves Zero-Shot Generalization in Reinforcement Learning".

Environments with different dynamics (contexts) are generated using CARL, and the agent is trained and evaluated on subsets of these environments.

Usage

Installation

First, create a conda environment with the required dependencies:

conda env create -f environment.yaml

Install the CARL library:

git clone https://github.com/automl/CARL.git --recursive
cd CARL
pip install .

Run experiments

The experiments in the paper use the Soft Actor-Critic (SAC) algorithm and can be run with the following script:

python3 scripts/run_sac.py env=brax_ant context_mode=learned_jcpl

The script accepts the following values for the env and context_mode parameters:

Environments

  • brax_ant: 3D quadruped (ant) robot with 8 degrees of freedom.
  • pendulum: 2D pendulum with 1 degree of freedom.
  • cartpole_continuous_tau: 2D cartpole with 2 degrees of freedom.
  • mountain_car: 1D mountain car with 1 degree of freedom.

Context learning methods

  • explicit: the context (dynamics parameters) is given to the model as additional state input, at both training and test time.
  • hidden: no context is given to the model, neither at training nor at test time.
  • learned_iida: based on "Context is Everything" (IIDA). Using previously generated trajectories, a predictor model is trained to predict next states from the current state and the action taken. The trained predictor is then used as a context encoder whose output is provided to the RL agent as additional state input.
  • learned_jcpl: Joint Context and Policy Learning (JCPL). The context encoder is not trained on a separate prediction task; instead, it is trained jointly with the policy network.
  • default_value: the agent is trained only on the default context value. Used as a baseline.
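The modes above differ mainly in what the agent receives as input and which contexts it is trained on. The following is a minimal, hypothetical sketch of that distinction in plain Python; the helper names (training_contexts, agent_input) and the context keys are illustrative and not taken from scripts/run_sac.py. It assumes that default_value, like hidden, feeds only the state to the agent, and it treats the learned encoder output as a precomputed vector rather than modelling the encoder itself.

```python
from typing import Dict, List, Optional, Sequence

Context = Dict[str, float]


def training_contexts(
    mode: str, train_contexts: Sequence[Context], default_context: Context
) -> List[Context]:
    """Contexts used during training for a given (hypothetical) context mode."""
    if mode == "default_value":
        # Baseline: the agent only ever sees the default context during training.
        return [dict(default_context)]
    # All other modes train on the full set of training contexts.
    return [dict(c) for c in train_contexts]


def agent_input(
    mode: str,
    state: Sequence[float],
    context: Context,
    encoded_context: Optional[Sequence[float]] = None,
) -> List[float]:
    """Vector fed to the policy/critic networks for a given context mode."""
    obs = list(state)
    if mode == "explicit":
        # Ground-truth context values appended to the state (train and test time).
        # Keys are sorted so the feature order is stable across calls.
        return obs + [context[k] for k in sorted(context)]
    if mode in ("hidden", "default_value"):
        # No context information; default_value is ASSUMED to behave like hidden here.
        return obs
    if mode in ("learned_iida", "learned_jcpl"):
        # Output of a learned context encoder, passed in precomputed.
        # learned_iida: encoder obtained from a separately trained next-state predictor.
        # learned_jcpl: encoder trained jointly with the policy network.
        if encoded_context is None:
            raise ValueError(f"{mode} requires an encoded context vector")
        return obs + list(encoded_context)
    raise ValueError(f"unknown context mode: {mode!r}")


if __name__ == "__main__":
    state = [0.1, -0.2]
    ctx = {"length": 1.0, "gravity": 9.81}  # illustrative context keys
    print(agent_input("explicit", state, ctx))  # [0.1, -0.2, 9.81, 1.0]
    print(agent_input("hidden", state, ctx))    # [0.1, -0.2]
```

Keeping the environment-side choice (which contexts are sampled) separate from the input-side choice (what the networks see) makes it explicit that default_value changes the training distribution, while explicit, hidden and the learned modes change the observation.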

Sweep over multiple configurations using Hydra

Multiple configurations can be run in a single call with Hydra's multirun flag (-m), using Hydra's submitit launcher plugin:

python3 scripts/run_sac.py -m context_mode=default_value,explicit,hidden,learned_iida,learned_jcpl

The script can also be run over multiple seeds using the seed parameter. The results in the paper were obtained with seeds 0 to 9:

python3 scripts/run_sac.py -m seed=0,1,2,3,4,5,6,7,8,9
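When several parameters are swept in one -m call, Hydra's basic sweeper launches one job per combination of values (a Cartesian product). The snippet below is a plain-Python illustration of that expansion, not Hydra's implementation; expand_sweep is a hypothetical helper. It shows, for instance, that sweeping all five context modes over seeds 0 to 9 yields 50 jobs.

```python
from itertools import product
from typing import Dict, List


def expand_sweep(overrides: Dict[str, List[str]]) -> List[str]:
    """Return one override string per job, mimicking a Cartesian-product sweep."""
    keys = list(overrides)
    jobs = []
    # product() varies the last parameter fastest, so seeds cycle within each mode.
    for values in product(*(overrides[k] for k in keys)):
        jobs.append(" ".join(f"{k}={v}" for k, v in zip(keys, values)))
    return jobs


sweep = {
    "context_mode": ["default_value", "explicit", "hidden", "learned_iida", "learned_jcpl"],
    "seed": [str(s) for s in range(10)],
}
jobs = expand_sweep(sweep)
print(len(jobs))  # 50
print(jobs[0])    # context_mode=default_value seed=0
```

In Hydra's own syntax, the same sweep corresponds to passing both overrides to a single call, e.g. python3 scripts/run_sac.py -m context_mode=default_value,explicit,hidden,learned_iida,learned_jcpl seed=0,1,2,3,4,5,6,7,8,9.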

Generate plots

Learning curves

Learning curves are logged with the wandb (Weights & Biases) library and can be viewed and downloaded from the wandb dashboard.
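As an alternative to the dashboard, runs can be fetched programmatically with the wandb public API (wandb.Api). The sketch below assumes you know the entity/project path the runs were logged to and the names of the metrics of interest; download_learning_curves is a hypothetical helper, and the path and metric key in the usage line are placeholders, not the names used by this repository.

```python
from typing import Any, Dict, List, Optional


def download_learning_curves(
    project_path: str, keys: List[str], api: Optional[Any] = None
) -> Dict[str, Any]:
    """Return the logged history of every run in a wandb project, keyed by run id.

    project_path is "<entity>/<project>"; keys lists the metric names to fetch.
    An api object can be injected for testing; otherwise wandb.Api() is used.
    """
    if api is None:
        import wandb  # imported lazily so the function can be tested without wandb

        api = wandb.Api()
    curves: Dict[str, Any] = {}
    for run in api.runs(project_path):
        # run.history returns a (sampled) pandas DataFrame with the requested keys.
        # Run ids are used as dict keys because run names are not guaranteed unique.
        curves[run.id] = run.history(keys=keys)
    return curves


# Usage (placeholder entity, project and metric names):
# curves = download_learning_curves("my-entity/my-project", keys=["eval/return"])
```

run.history samples at most 500 points per run by default; run.scan_history iterates over every logged row if full-resolution curves are needed.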

Violin plots

Violin plots can be generated with the plot_eval_violins.py script by passing the path to the Hydra results folder of a multirun (the date/time folder below is an example):

python3 scripts/plots/plot_eval_violins.py --results_folder_path results/hydra/multirun/2024-04-12/11-31-34

Interquartile Means and performance profiles

The interquartile mean (IQM) and performance profiles (PP) can be plotted with the plot_eval_stats.py script by passing the path to the Hydra results folder:

python3 scripts/plots/plot_eval_stats.py --results_folder_path results/hydra/multirun/2024-04-12/11-31-34
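For reference, the two statistics can be computed as follows. This is a plain-Python illustration, not the code of plot_eval_stats.py (which may rely on a library such as rliable): the IQM is the mean of the middle 50% of scores, dropping the lowest and highest 25% as scipy.stats.trim_mean(scores, 0.25) does, and a performance profile gives, for each threshold τ, the fraction of scores above τ. Which scores are pooled (e.g. all seeds and evaluation contexts, possibly normalized) is an assumption here.

```python
from typing import List, Sequence


def interquartile_mean(scores: Sequence[float]) -> float:
    """Mean of the middle 50% of scores (25% trimmed from each end)."""
    if not scores:
        raise ValueError("scores must be non-empty")
    ordered = sorted(scores)
    n = len(ordered)
    # Number of scores dropped from each end, rounded down as in scipy.stats.trim_mean.
    k = int(0.25 * n)
    middle = ordered[k : n - k]
    return sum(middle) / len(middle)


def performance_profile(scores: Sequence[float], thresholds: Sequence[float]) -> List[float]:
    """Fraction of scores strictly above each threshold tau."""
    n = len(scores)
    return [sum(s > tau for s in scores) / n for tau in thresholds]


if __name__ == "__main__":
    scores = [1, 2, 3, 4, 5, 6, 7, 8]
    print(interquartile_mean(scores))              # 4.5
    print(performance_profile(scores, [0, 4, 8]))  # [1.0, 0.5, 0.0]
```

Trimming both tails makes the IQM less sensitive to a few outlier runs than the mean, while still using more of the data than the median; the performance profile shows the whole score distribution rather than a single summary number.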