Commit

v5.1.3 perf(rnn): optimized representation model (#34, #51)

1. updated README
2. optimized representation model
StepNeverStop committed Aug 29, 2021
1 parent 92d4b9a commit 23f910c
Showing 16 changed files with 301 additions and 126 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -238,7 +238,7 @@ If using this repository for your research, please cite:
```
@misc{RLs,
author = {Keavnn},
title = {RLs: Reinforcement Learning research framework for Unity3D and Gym},
title = {RLs: A Featureless Reinforcement Learning Repository},
year = {2019},
publisher = {GitHub},
journal = {GitHub repository},
2 changes: 1 addition & 1 deletion rls/_metadata.py
@@ -8,7 +8,7 @@
# We follow Semantic Versioning (https://semver.org/)
_MAJOR_VERSION = '5'
_MINOR_VERSION = '1'
_PATCH_VERSION = '2'
_PATCH_VERSION = '3'

# Example: '0.4.2'
__version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION])
12 changes: 6 additions & 6 deletions rls/algorithms/base/policy.py
@@ -36,17 +36,17 @@ def __init__(self,
normalize_vector_obs=False,
obs_with_pre_action=False,
rep_net_params={
'use_encoder': False,
'use_rnn': False, # always false, using -r to active RNN
'vector_net_params': {
'h_dim': 16,
'network_type': 'adaptive' # rls.nn.represents.vectors
},
'visual_net_params': {
'visual_feature': 128,
'h_dim': 128,
'network_type': 'simple' # rls.nn.represents.visuals
},
'encoder_net_params': {
'output_dim': 16
'h_dim': 16,
'network_type': 'identity' # rls.nn.represents.encoders
},
'memory_net_params': {
'rnn_units': 16,
@@ -77,12 +77,12 @@ def __init__(self,

super().__init__()

# TODO: optimization
self.use_rnn = rep_net_params.get('use_rnn', False)
self.memory_net_params = rep_net_params.get('memory_net_params', {
'rnn_units': 16,
'network_type': 'lstm'
})
self.use_rnn = self.memory_net_params.get(
'network_type', 'identity') != 'identity'

self.cp_dir, self.log_dir = [os.path.join(
base_dir, i) for i in ['model', 'log']]
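The functional change in this file is that `use_rnn` is no longer read from a separate `use_rnn` flag; it is derived from the memory network type. A minimal sketch of that convention, using only what the hunk above shows (the `memory_net_params` default and the `!= 'identity'` check); the standalone helper name is illustrative, not part of the repository:

```python
# Sketch of the convention introduced above: the RNN path is enabled purely by
# memory_net_params['network_type']. The helper name `uses_rnn` is hypothetical.
def uses_rnn(rep_net_params: dict) -> bool:
    memory_net_params = rep_net_params.get('memory_net_params', {
        'rnn_units': 16,
        'network_type': 'lstm'
    })
    return memory_net_params.get('network_type', 'identity') != 'identity'


assert uses_rnn({'memory_net_params': {'rnn_units': 16, 'network_type': 'lstm'}})
assert not uses_rnn({'memory_net_params': {'rnn_units': 16, 'network_type': 'identity'}})
```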
6 changes: 4 additions & 2 deletions rls/algorithms/multi/qplex.py
@@ -76,9 +76,11 @@ def _train(self, BATCH_DICT):
q_target_next_max = (
q_target * next_max_action_one_hot).sum(-1, keepdim=True) # [T, B, 1]

q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1]
q_target_next_choose_maxs.append(
q_target_next_max) # N * [T, B, 1]
q_target_actions.append(next_max_action_one_hot) # N * [T, B, A]
q_target_next_maxs.append(q_target.max(-1, keepdim=True)[0]) # N * [T, B, 1]
q_target_next_maxs.append(
q_target.max(-1, keepdim=True)[0]) # N * [T, B, 1]

q_eval_tot = self.mixer(BATCH_DICT['global'].obs,
q_evals,
3 changes: 2 additions & 1 deletion rls/algorithms/multi/qtran.py
@@ -94,7 +94,8 @@ def _train(self, BATCH_DICT):
# [T, B, 1]
q_target_next_max = q_target.max(-1, keepdim=True)[0]

q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1]
q_target_next_choose_maxs.append(
q_target_next_max) # N * [T, B, 1]
q_target_cell_states.append(q_target_cell_state) # N * [T, B, *]
q_target_actions.append(next_max_action_one_hot) # N * [T, B, A]

3 changes: 2 additions & 1 deletion rls/algorithms/multi/vdn.py
@@ -127,7 +127,8 @@ def _train(self, BATCH_DICT):
# [T, B, 1]
q_target_next_max = q_target.max(-1, keepdim=True)[0]

q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1]
q_target_next_choose_maxs.append(
q_target_next_max) # N * [T, B, 1]
q_eval_tot = self.mixer(
q_evals, BATCH_DICT['global'].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1]
q_target_next_max_tot = self.mixer.t(
12 changes: 6 additions & 6 deletions rls/configs/algorithms.yaml
@@ -11,18 +11,18 @@ policy: &policy

# ----- can be overridden in specific algorithms, e.g. dqn, so as to use a different type of visual net or memory net.
rep_net_params: &rep_net_params
use_encoder: false
use_rnn: false
vector_net_params:
h_dim: 16
network_type: "adaptive" # rls.nn.represents.vectors
visual_net_params:
visual_feature: 128
h_dim: 128
network_type: "simple" # rls.nn.represents.visuals
encoder_net_params:
output_dim: 16
h_dim: 16
network_type: "identity" # rls.nn.represents.encoders
memory_net_params:
rnn_units: 16
network_type: "lstm"
network_type: "identity" # rls.nn.represents.memories
# -----

sarl_policy: &sarl_policy
@@ -81,7 +81,7 @@ dqn: &dqn
rep_net_params:
<<: *rep_net_params
visual_net_params:
visual_feature: 128
h_dim: 128
network_type: "nature"

ddqn: *dqn
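For orientation, the anchors above merge into the following default `rep_net_params`, written here as an equivalent Python dict (an illustrative sketch; the YAML stays the source of truth). With `memory_net_params.network_type` defaulting to `"identity"`, the recurrent path is off unless an algorithm overrides it, in the same way the `dqn` block overrides `visual_net_params`:

```python
# Equivalent Python view of the YAML defaults above; values are copied from this commit.
default_rep_net_params = {
    'vector_net_params':  {'h_dim': 16,  'network_type': 'adaptive'},   # rls.nn.represents.vectors
    'visual_net_params':  {'h_dim': 128, 'network_type': 'simple'},     # rls.nn.represents.visuals
    'encoder_net_params': {'h_dim': 16,  'network_type': 'identity'},   # rls.nn.represents.encoders
    'memory_net_params':  {'rnn_units': 16, 'network_type': 'identity'},  # rls.nn.represents.memories
}

# A hypothetical per-algorithm override that turns the recurrent representation on
# by swapping only the memory entry (e.g. to the LSTM default used in policy.py).
recurrent_rep_net_params = {**default_rep_net_params,
                            'memory_net_params': {'rnn_units': 16, 'network_type': 'lstm'}}
```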
2 changes: 1 addition & 1 deletion rls/nn/mixers/__init__.py
@@ -1,8 +1,8 @@
from .qatten import QattenMixer
from .qmix import QMixer
from .qplex.qplex import QPLEXMixer
from .qtran_base import QTranBase
from .vdn import VDNMixer
from .qplex.qplex import QPLEXMixer

Mixer_REGISTER = {}

11 changes: 7 additions & 4 deletions rls/nn/mixers/qplex/qplex.py
@@ -1,10 +1,11 @@
import torch as t
import numpy as np
from .si_weight import SI_Weight
import torch as t

from rls.nn.mlps import MLP
from rls.nn.represent_nets import RepresentationNetwork

from .si_weight import SI_Weight


class QPLEXMixer(t.nn.Module):
'''https://github.com/wjh720/QPLEX/'''
@@ -73,9 +74,11 @@ def forward(self, state, q_values, actions, max_q_i, **kwargs):
adv_w_final = self.si_weight(state_feat, actions) # [T, B, N]

if self.is_minus_one:
adv_tot = t.sum(adv_q * (adv_w_final - 1.), dim=-1, keepdim=True) # [T, B, 1]
adv_tot = t.sum(adv_q * (adv_w_final - 1.),
dim=-1, keepdim=True) # [T, B, 1]
else:
adv_tot = t.sum(adv_q * adv_w_final, dim=-1, keepdim=True) # [T, B, 1]
adv_tot = t.sum(adv_q * adv_w_final,
dim=-1, keepdim=True) # [T, B, 1]

q_tot = v_tot + adv_tot # [T, B, 1]

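The re-wrapped lines above do not change behaviour; for readers skimming the hunk, here is a self-contained sketch of the mixing step it touches, with toy tensors whose shapes match the comments (`T` timesteps, `B` batch, `N` agents). Only the `is_minus_one` branch from this hunk is shown:

```python
import torch as t

T, B, N = 2, 3, 4                 # toy sizes for illustration
v_tot = t.zeros(T, B, 1)          # mixed state value, [T, B, 1]
adv_q = t.randn(T, B, N)          # per-agent advantages, [T, B, N]
adv_w_final = t.rand(T, B, N)     # SI_Weight attention weights, [T, B, N]

# is_minus_one branch from the hunk: weights enter as (lambda - 1).
adv_tot = t.sum(adv_q * (adv_w_final - 1.), dim=-1, keepdim=True)  # [T, B, 1]
q_tot = v_tot + adv_tot                                            # [T, B, 1]
```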
11 changes: 7 additions & 4 deletions rls/nn/mixers/qplex/si_weight.py
@@ -1,5 +1,5 @@
import torch as t
import numpy as np
import torch as t

from rls.nn.mlps import MLP

@@ -33,9 +33,12 @@ def forward(self, state_feat, actions):
'''
data = t.cat([state_feat]+actions, dim=-1) # [T, B, *]

all_head_key = [k_ext(state_feat) for k_ext in self.key_extractors] # List[[T, B, 1]]
all_head_agents = [k_ext(state_feat) for k_ext in self.agents_extractors] # List[[T, B, N]]
all_head_action = [sel_ext(data) for sel_ext in self.action_extractors] # List[[T, B, N]]
all_head_key = [k_ext(state_feat)
for k_ext in self.key_extractors] # List[[T, B, 1]]
all_head_agents = [k_ext(state_feat)
for k_ext in self.agents_extractors] # List[[T, B, N]]
# List[[T, B, N]]
all_head_action = [sel_ext(data) for sel_ext in self.action_extractors]

head_attend_weights = []
for curr_head_key, curr_head_agents, curr_head_action in zip(all_head_key, all_head_agents, all_head_action):
65 changes: 65 additions & 0 deletions rls/nn/models.py
@@ -8,6 +8,8 @@
from rls.nn.represent_nets import RepresentationNetwork
from rls.utils.torch_utils import clip_nn_log_std

Model_REGISTER = {}


class BaseModel(t.nn.Module):

@@ -46,6 +48,9 @@ def forward(self, x, **kwargs):
return self.net(x)


Model_REGISTER['actor_dpg'] = ActorDPG


class ActorMuLogstd(BaseModel):
'''
use for PPO/PG algorithms' actor network.
@@ -83,6 +88,9 @@ def forward(self, x, **kwargs):
return (mu, log_std)


Model_REGISTER['actor_mulogstd'] = ActorMuLogstd


class ActorCts(BaseModel):
'''
use for continuous action space.
@@ -118,6 +126,9 @@ def forward(self, x, **kwargs):
return (mu, log_std)


Model_REGISTER['actor_continuous'] = ActorCts


class ActorDct(BaseModel):
'''
use for discrete action space.
@@ -136,6 +147,9 @@ def forward(self, x, **kwargs):
return logits


Model_REGISTER['actor_discrete'] = ActorDct


class CriticQvalueOne(BaseModel):
'''
used to evaluate the value of a given state-action pair.
@@ -154,6 +168,9 @@ def forward(self, x, a, **kwargs):
return q


Model_REGISTER['critic_q1'] = CriticQvalueOne


class CriticQvalueOneDDPG(BaseModel):
'''
Original architecture in DDPG paper.
@@ -175,6 +192,9 @@ def forward(self, x, a, **kwargs):
return q


Model_REGISTER['critic_q1_ddpg'] = CriticQvalueOneDDPG


class CriticQvalueOneTD3(BaseModel):
'''
Original architecture in TD3 paper.
@@ -197,6 +217,9 @@ def forward(self, x, a, **kwargs):
return q


Model_REGISTER['critic_q1_td3'] = CriticQvalueOneTD3


class CriticValue(BaseModel):
'''
used to evaluate the value of a given state.
@@ -214,6 +237,9 @@ def forward(self, x, **kwargs):
return v


Model_REGISTER['critic_v'] = CriticValue


class CriticQvalueAll(BaseModel):
'''
used to evaluate all values of Q(S,A) for a given state; the action space must be discrete.
@@ -232,6 +258,9 @@ def forward(self, x, **kwargs):
return q


Model_REGISTER['critic_q_all'] = CriticQvalueAll


class CriticQvalueBootstrap(BaseModel):
'''
use for bootstrapped dqn.
@@ -249,6 +278,9 @@ def forward(self, x, **kwargs):
return q


Model_REGISTER['critic_q_bootstrap'] = CriticQvalueBootstrap


class CriticDueling(BaseModel):
'''
Neural network for dueling deep Q network.
@@ -278,6 +310,9 @@ def forward(self, x, **kwargs):
return q


Model_REGISTER['critic_dueling'] = CriticDueling


class OcIntraOption(BaseModel):
'''
Intra Option Neural network of Option-Critic.
@@ -298,6 +333,9 @@ def forward(self, x, **kwargs):
return pi


Model_REGISTER['oc_intra_option'] = OcIntraOption


class AocShare(BaseModel):
'''
Neural network for AOC.
@@ -329,6 +367,9 @@ def forward(self, x, **kwargs):
return q, pi, beta


Model_REGISTER['aoc_share'] = AocShare


class PpocShare(BaseModel):
'''
Neural network for PPOC.
@@ -363,6 +404,9 @@ def forward(self, x, **kwargs):
return q, pi, beta, o


Model_REGISTER['ppoc_share'] = PpocShare


class ActorCriticValueCts(BaseModel):
'''
combines the actor and critic networks, sharing some layers; used for continuous action spaces.
@@ -405,6 +449,9 @@ def forward(self, x, **kwargs):
return (mu, log_std, v)


Model_REGISTER['ac_v_continuous'] = ActorCriticValueCts


class ActorCriticValueDct(BaseModel):
'''
combines the actor and critic networks, sharing some layers; used for discrete action spaces.
@@ -431,6 +478,9 @@ def forward(self, x, **kwargs):
return (logits, v)


Model_REGISTER['ac_v_discrete'] = ActorCriticValueDct


class C51Distributional(BaseModel):
'''
neural network for C51
@@ -452,6 +502,9 @@ def forward(self, x, **kwargs):
return q_dist


Model_REGISTER['c51'] = C51Distributional


class QrdqnDistributional(BaseModel):
'''
neural network for QRDQN
@@ -472,6 +525,9 @@ def forward(self, x, **kwargs):
return q_dist


Model_REGISTER['qrdqn'] = QrdqnDistributional


class RainbowDueling(BaseModel):
'''
Neural network for Rainbow.
@@ -511,6 +567,9 @@ def forward(self, x, **kwargs):
return qs # [B, A, N] or [T, B, A, N]


Model_REGISTER['rainbow'] = RainbowDueling


class IqnNet(BaseModel):
def __init__(self, obs_spec, rep_net_params, action_dim, quantiles_idx, network_settings):
super().__init__(obs_spec, rep_net_params)
@@ -551,6 +610,9 @@ def forward(self, x, quantiles_tiled, **kwargs):
return quantiles_value # [N, B, A] or [T, N, B, A]


Model_REGISTER['iqn'] = IqnNet


class MACriticQvalueOne(t.nn.Module):
'''
used to evaluate the value of a given state-action pair.
@@ -576,3 +638,6 @@ def forward(self, x, a, **kwargs):
x = t.cat(outs, -1)
q = self.net(t.cat((x, a), -1))
return q


Model_REGISTER['ma_critic_q1'] = MACriticQvalueOne
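Every addition in this file follows the same pattern: a network class is registered in the new `Model_REGISTER` dict under a string key right after its definition. A small sketch of how such a registry is typically consumed; the lookup call site below is an assumption for illustration, as this commit only shows the registrations themselves:

```python
from rls.nn.models import Model_REGISTER

# Keys come from the registrations added in this commit; resolving them to classes
# by name is the assumed use case, not code taken from the diff.
actor_cls = Model_REGISTER['actor_discrete']   # -> ActorDct
critic_cls = Model_REGISTER['critic_q_all']    # -> CriticQvalueAll
print(actor_cls.__name__, critic_cls.__name__)
```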