Feature/extension collector buffer (#1196)
Adds important functionality to the buffer and the collector. The PR is very
large, but I didn't want to split it up. It's easiest to review commit by
commit, and I think various people should have a look. One can also go
file by file. Together we can do this ;)

I'll edit the description when the review is done

@Trinkle23897: pls mainly have a look at the changes in buffer-related
things, and if you want, also at the computation of the n_step return. I had
to slightly modify one of the tests: it was changing the private
`_insertion_index`, leading to a malformed buffer, which now raises an error.
Ofc you are very welcome to look at the rest as well :)
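For context, the n_step return mentioned above is the standard bootstrapped quantity; the snippet below is a generic, self-contained illustration of that computation (it is not code from this PR, and the function name is made up for the example):

```python
import numpy as np


def n_step_return(
    rewards: np.ndarray,   # r_t, ..., r_{t+n-1} along the sampled trajectory
    next_value: float,     # V(s_{t+n}), the bootstrap value estimate
    terminated: bool,      # True if the episode ended within the n steps
    gamma: float = 0.99,
) -> float:
    """Generic n-step return: sum_k gamma^k * r_{t+k} + gamma^n * V(s_{t+n})."""
    n = len(rewards)
    discounted_rewards = sum(gamma**k * r for k, r in enumerate(rewards))
    bootstrap = 0.0 if terminated else gamma**n * next_value
    return float(discounted_rewards + bootstrap)


# Three rewards of 1.0, bootstrapping from a value estimate of 10.0:
print(n_step_return(np.array([1.0, 1.0, 1.0]), next_value=10.0, terminated=False))  # ~12.673
```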

@opcode81 and @maxhuettenrauch: pls have a look at the extensions in
Collector. They are untested for now; I wanted to get your opinion on the
design first. Also, a quick glance at the trainer would be nice.

Ah, also @Trinkle23897: I think I found a bug in the PPO implementation,
see the corresponding commit.

@dantp-ai: the changes to the buffer here will make the task of fixing the
slicing issues easier, especially the new names and additional comments.
I'd also be happy about your review, if you have time!
MischaPanch authored Aug 20, 2024
2 parents 616e6a9 + bd58581 commit 002ffd9
Showing 97 changed files with 1,976 additions and 651 deletions.
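The most visible API change across these files is mechanical: call sites now import `CollectStats` and construct `Collector[CollectStats](...)` instead of `Collector(...)`. The sketch below shows, in generic terms, why parameterizing the class this way helps a type checker; it is a simplified stand-in written for this description, not Tianshou's actual class hierarchy.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class CollectStats:
    n_collected_steps: int = 0  # illustrative field only


TStats = TypeVar("TStats", bound=CollectStats)


class Collector(Generic[TStats]):
    """Simplified stand-in: the type parameter fixes what collect() returns."""

    def collect(self, n_step: int) -> TStats:
        raise NotImplementedError("environment stepping omitted in this sketch")


# Writing Collector[CollectStats](...) at the call site tells the type checker
# that collect() yields a CollectStats, so its fields are known statically;
# a subclass could instead be parameterized with a richer stats class.
```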
19 changes: 7 additions & 12 deletions docs/02_notebooks/L0_overview.ipynb
@@ -15,15 +15,6 @@
"Before we get started, we must first install Tianshou's library and Gym environment by running the commands below. This tutorials will always keep up with the latest version of Tianshou since they also serve as a test for the latest version. If you are using an older version of Tianshou, please refer to the [documentation](https://tianshou.readthedocs.io/en/latest/) of your version.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# !pip install tianshou gym"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -67,7 +58,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
@@ -114,8 +105,12 @@
")\n",
"\n",
"# collector\n",
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector[CollectStats](\n",
" policy,\n",
" train_envs,\n",
" VectorReplayBuffer(20000, len(train_envs)),\n",
")\n",
"test_collector = Collector[CollectStats](policy, test_envs)\n",
"\n",
"# trainer\n",
"train_result = OnpolicyTrainer(\n",
6 changes: 3 additions & 3 deletions docs/02_notebooks/L5_Collector.ipynb
@@ -58,7 +58,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.utils.net.common import Net\n",
@@ -94,7 +94,7 @@
" action_space=env.action_space,\n",
" action_scaling=False,\n",
")\n",
"test_collector = Collector(policy, test_envs)"
"test_collector = Collector[CollectStats](policy, test_envs)"
]
},
{
@@ -187,7 +187,7 @@
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
"train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)"
]
},
{
21 changes: 6 additions & 15 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -54,12 +54,8 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:02.969675Z",
"start_time": "2024-05-06T15:34:00.747309Z"
},
"editable": true,
"id": "do-xZ-8B7nVH",
"slideshow": {
@@ -77,7 +73,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
@@ -88,13 +84,8 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-06T15:34:07.536452Z",
"start_time": "2024-05-06T15:34:03.636670Z"
}
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_env_num = 4\n",
@@ -131,8 +122,8 @@
"\n",
"# Create the replay buffer and the collector\n",
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
"test_collector = Collector[CollectStats](policy, test_envs)\n",
"train_collector = Collector[CollectStats](policy, train_envs, replayBuffer)"
]
},
{
6 changes: 3 additions & 3 deletions docs/02_notebooks/L7_Experiment.ipynb
@@ -71,7 +71,7 @@
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.data import Collector, CollectStats, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
@@ -202,12 +202,12 @@
},
"outputs": [],
"source": [
"train_collector = Collector(\n",
"train_collector = Collector[CollectStats](\n",
" policy=policy,\n",
" env=train_envs,\n",
" buffer=VectorReplayBuffer(20000, len(train_envs)),\n",
")\n",
"test_collector = Collector(policy=policy, env=test_envs)"
"test_collector = Collector[CollectStats](policy=policy, env=test_envs)"
]
},
{
6 changes: 6 additions & 0 deletions docs/spelling_wordlist.txt
@@ -282,3 +282,9 @@ autocompletion
codebase
indexable
sliceable
gaussian
logprob
monte
carlo
subclass
subclassing
8 changes: 4 additions & 4 deletions examples/atari/atari_c51.py
@@ -9,7 +9,7 @@
from atari_network import C51
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy
from tianshou.policy.base import BasePolicy
@@ -112,8 +112,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -173,7 +173,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_dqn.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
@@ -148,8 +148,8 @@ def main(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -215,7 +215,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_fqf.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
@@ -125,8 +125,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -186,7 +186,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_iqn.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import IQNPolicy
from tianshou.policy.base import BasePolicy
@@ -122,8 +122,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -183,7 +183,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_ppo.py
@@ -11,7 +11,7 @@
from torch.distributions import Categorical
from torch.optim.lr_scheduler import LambdaLR

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.policy.base import BasePolicy
@@ -190,8 +190,8 @@ def dist(logits: torch.Tensor) -> Categorical:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -243,7 +243,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_qrdqn.py
@@ -9,7 +9,7 @@
from atari_network import QRDQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import QRDQNPolicy
from tianshou.policy.base import BasePolicy
@@ -116,8 +116,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -177,7 +177,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
13 changes: 9 additions & 4 deletions examples/atari/atari_rainbow.py
@@ -9,7 +9,12 @@
from atari_network import Rainbow
from atari_wrapper import make_atari_env

from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
from tianshou.data import (
Collector,
CollectStats,
PrioritizedVectorReplayBuffer,
VectorReplayBuffer,
)
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy, RainbowPolicy
from tianshou.policy.base import BasePolicy
@@ -142,8 +147,8 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
weight_norm=not args.no_weight_norm,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -213,7 +218,7 @@ def watch() -> None:
alpha=args.alpha,
beta=args.beta,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
8 changes: 4 additions & 4 deletions examples/atari/atari_sac.py
@@ -9,7 +9,7 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.data import Collector, CollectStats, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import DiscreteSACPolicy, ICMPolicy
from tianshou.policy.base import BasePolicy
@@ -173,8 +173,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
stack_num=args.frames_stack,
)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)

# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
@@ -226,7 +226,7 @@ def watch() -> None:
save_only_last_obs=True,
stack_num=args.frames_stack,
)
collector = Collector(policy, test_envs, buffer, exploration_noise=True)
collector = Collector[CollectStats](policy, test_envs, buffer, exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
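Condensing the recurring pattern from the example scripts above into one place: a hypothetical helper (the function name and its arguments are made up for illustration; `policy`, `train_envs`, and `test_envs` would come from the usual setup in those scripts) might wire the collectors up as follows.

```python
from tianshou.data import Collector, CollectStats, VectorReplayBuffer


def build_collectors(policy, train_envs, test_envs, buffer_size: int = 20000):
    """Sketch of the collector wiring used throughout the changed example scripts."""
    buffer = VectorReplayBuffer(buffer_size, len(train_envs))
    train_collector = Collector[CollectStats](policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector[CollectStats](policy, test_envs, exploration_noise=True)
    # collect() then returns a CollectStats instance, e.g.
    # stats = train_collector.collect(n_step=1024)
    return train_collector, test_collector
```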