This repository contains the code for the paper "Inferring Behavior-Specific Context Improves Zero-Shot Generalization in Reinforcement Learning".
Environments with varying dynamics are generated using CARL, and the agent is trained and evaluated on subsets of those environments.
First, create a conda environment with the required dependencies:
```bash
conda env create -f environment.yaml
```
Then install the CARL library:
```bash
git clone https://github.com/automl/CARL.git --recursive
cd CARL
pip install .
```
The experiments shown in the paper are based on the SAC algorithm and can be run using the following script:
```bash
python3 scripts/run_sac.py env=brax_ant context_mode=learned_jcpl
```
The `env` parameter selects the environment:
- `brax_ant`: 3D ant robot (quadruped) with 8 degrees of freedom.
- `pendulum`: 2D pendulum with 2 degrees of freedom.
- `cartpole_continuous_tau`: 2D cartpole with 2 degrees of freedom.
- `mountain_car`: 1D mountain car with 1 degree of freedom.
The `context_mode` parameter controls how the environment dynamics are exposed to the agent:
- `explicit`: the dynamics are given to the model as additional state input, at both training and testing time.
- `hidden`: no dynamics are given to the model, at either training or testing time.
- `learned_iida`: from "Context is Everything": using previously generated trajectories, a predictor model is trained to predict the next state from the current state and the action taken. The predictor model is then used as a context encoder, whose output is provided as additional state input to the RL agent.
- `learned_jcpl`: Joint Context and Policy Learning (JCPL): the context encoder is not trained on a separate prediction task but jointly with the policy network.
- `default_value`: the agent is trained only on the default value of the context. Used as a baseline.
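As a rough illustration (the function and argument names here are hypothetical, not the repo's actual API), the modes differ mainly in what the policy observes:

```python
import numpy as np

def build_observation(state, context, mode, encoder=None):
    """Sketch of how each context mode shapes the policy input.

    `state` is the raw environment state, `context` holds the dynamics
    parameters (or trajectory features, for the learned modes), and
    `encoder` is a learned context encoder for the `learned_*` modes.
    """
    if mode == "explicit":
        # ground-truth dynamics parameters appended to the state
        return np.concatenate([state, context])
    if mode in ("hidden", "default_value"):
        # the policy sees only the raw state
        return np.asarray(state)
    if mode in ("learned_iida", "learned_jcpl"):
        # a learned encoder maps context information to a latent vector
        return np.concatenate([state, encoder(context)])
    raise ValueError(f"unknown context mode: {mode}")

# explicit mode: 4-dim state + 2 dynamics parameters -> 6-dim observation
obs = build_observation(np.zeros(4), np.ones(2), "explicit")
```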
The script can be run with multiple configurations using Hydra's multirun mode and its submitit launcher plugin:
```bash
python3 scripts/run_sac.py -m context_mode=default_value,explicit,hidden,learned_iida,learned_jcpl
```
It is also possible to run the script with multiple seeds using the `seed` parameter. The results of the paper are obtained by running seeds 0 through 9:
```bash
python3 scripts/run_sac.py -m seed=0,1,2,3,4,5,6,7,8,9
```
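Note that combining several swept parameters in a single `-m` call launches the Cartesian product of all listed values: sweeping the five context modes together with ten seeds, for example, yields 50 runs. A quick illustration of the enumeration Hydra performs:

```python
from itertools import product

context_modes = ["default_value", "explicit", "hidden", "learned_iida", "learned_jcpl"]
seeds = range(10)

# Hydra's multirun expands comma-separated values into one job per combination.
jobs = [f"context_mode={m} seed={s}" for m, s in product(context_modes, seeds)]
print(len(jobs))  # 50 jobs in total
```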
The learning curves are logged with the wandb library and can be viewed and downloaded from the wandb dashboard.
The violin plots can be generated with the `plot_eval_violins.py` script by specifying the path to the Hydra results folder:
```bash
python3 scripts/plots/plot_eval_violins.py --results_folder_path results/hydra/multirun/2024-04-12/11-31-34
```
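The script handles loading the evaluation returns from the results folder; for reference, a violin plot over per-mode return distributions boils down to something like the following matplotlib sketch (the data here is synthetic, and the mode names are just examples):

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is needed
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
# synthetic evaluation returns for three example context modes
returns = [rng.normal(loc=mu, scale=50.0, size=200) for mu in (800, 950, 1000)]

fig, ax = plt.subplots()
parts = ax.violinplot(returns, showmedians=True)
ax.set_xticks([1, 2, 3])
ax.set_xticklabels(["hidden", "explicit", "learned_jcpl"])
ax.set_ylabel("evaluation return")
fig.savefig("eval_violins.png")
```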
The interquartile means (IQM) and performance profiles (PP) can be plotted via the `plot_eval_stats.py` script, again specifying the path to the Hydra results folder:
```bash
python3 scripts/plots/plot_eval_stats.py --results_folder_path results/hydra/multirun/2024-04-12/11-31-34
```
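For reference, the IQM is the mean of the middle 50% of scores (the top and bottom quartiles are discarded), which makes it more robust to outlier runs than a plain mean. A minimal sketch of the statistic:

```python
import numpy as np

def interquartile_mean(scores):
    """Mean of the middle 50% of scores: drop the lowest and highest 25%."""
    scores = np.sort(np.asarray(scores, dtype=float))
    cut = int(np.floor(len(scores) * 0.25))
    return scores[cut:len(scores) - cut].mean()

print(interquartile_mean([1, 2, 3, 4, 5, 6, 7, 8]))  # 4.5: mean of [3, 4, 5, 6]
```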