From 23f910cfb418727b8457b0f85a3316e8f500ca98 Mon Sep 17 00:00:00 2001 From: Keavnn Date: Sun, 29 Aug 2021 15:06:11 +0800 Subject: [PATCH] v5.1.3 perf(rnn): optimized representation model (#34, #51) 1. updated README 2. optimized representation model --- README.md | 2 +- rls/_metadata.py | 2 +- rls/algorithms/base/policy.py | 12 ++-- rls/algorithms/multi/qplex.py | 6 +- rls/algorithms/multi/qtran.py | 3 +- rls/algorithms/multi/vdn.py | 3 +- rls/configs/algorithms.yaml | 12 ++-- rls/nn/mixers/__init__.py | 2 +- rls/nn/mixers/qplex/qplex.py | 11 +-- rls/nn/mixers/qplex/si_weight.py | 11 +-- rls/nn/models.py | 65 ++++++++++++++++++ rls/nn/networks.py | 89 +++++++++---------------- rls/nn/represent_nets.py | 41 +++++------- rls/nn/represents/encoders.py | 29 ++++++++ rls/nn/represents/memories.py | 111 +++++++++++++++++++++++++++++++ rls/nn/represents/vectors.py | 28 ++++---- 16 files changed, 301 insertions(+), 126 deletions(-) create mode 100644 rls/nn/represents/encoders.py create mode 100644 rls/nn/represents/memories.py diff --git a/README.md b/README.md index cd80e1f..69694bb 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,7 @@ If using this repository for your research, please cite: ``` @misc{RLs, author = {Keavnn}, - title = {RLs: Reinforcement Learning research framework for Unity3D and Gym}, + title = {RLs: A Featureless Reinforcement Learning Repository}, year = {2019}, publisher = {GitHub}, journal = {GitHub repository}, diff --git a/rls/_metadata.py b/rls/_metadata.py index 95ed85b..ea9e582 100644 --- a/rls/_metadata.py +++ b/rls/_metadata.py @@ -8,7 +8,7 @@ # We follow Semantic Versioning (https://semver.org/) _MAJOR_VERSION = '5' _MINOR_VERSION = '1' -_PATCH_VERSION = '2' +_PATCH_VERSION = '3' # Example: '0.4.2' __version__ = '.'.join([_MAJOR_VERSION, _MINOR_VERSION, _PATCH_VERSION]) diff --git a/rls/algorithms/base/policy.py b/rls/algorithms/base/policy.py index 750a974..8d920e2 100644 --- a/rls/algorithms/base/policy.py +++ b/rls/algorithms/base/policy.py @@ -36,17 +36,17 @@ def __init__(self, normalize_vector_obs=False, obs_with_pre_action=False, rep_net_params={ - 'use_encoder': False, - 'use_rnn': False, # always false, using -r to active RNN 'vector_net_params': { + 'h_dim': 16, 'network_type': 'adaptive' # rls.nn.represents.vectors }, 'visual_net_params': { - 'visual_feature': 128, + 'h_dim': 128, 'network_type': 'simple' # rls.nn.represents.visuals }, 'encoder_net_params': { - 'output_dim': 16 + 'h_dim': 16, + 'network_type': 'identity' # rls.nn.represents.encoders }, 'memory_net_params': { 'rnn_units': 16, @@ -77,12 +77,12 @@ def __init__(self, super().__init__() - # TODO: optimization - self.use_rnn = rep_net_params.get('use_rnn', False) self.memory_net_params = rep_net_params.get('memory_net_params', { 'rnn_units': 16, 'network_type': 'lstm' }) + self.use_rnn = self.memory_net_params.get( + 'network_type', 'identity') != 'identity' self.cp_dir, self.log_dir = [os.path.join( base_dir, i) for i in ['model', 'log']] diff --git a/rls/algorithms/multi/qplex.py b/rls/algorithms/multi/qplex.py index 3eb1177..03ba588 100644 --- a/rls/algorithms/multi/qplex.py +++ b/rls/algorithms/multi/qplex.py @@ -76,9 +76,11 @@ def _train(self, BATCH_DICT): q_target_next_max = ( q_target * next_max_action_one_hot).sum(-1, keepdim=True) # [T, B, 1] - q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1] + q_target_next_choose_maxs.append( + q_target_next_max) # N * [T, B, 1] q_target_actions.append(next_max_action_one_hot) # N * [T, B, A] - 
q_target_next_maxs.append(q_target.max(-1, keepdim=True)[0]) # N * [T, B, 1] + q_target_next_maxs.append( + q_target.max(-1, keepdim=True)[0]) # N * [T, B, 1] q_eval_tot = self.mixer(BATCH_DICT['global'].obs, q_evals, diff --git a/rls/algorithms/multi/qtran.py b/rls/algorithms/multi/qtran.py index 2dee090..8009b1b 100644 --- a/rls/algorithms/multi/qtran.py +++ b/rls/algorithms/multi/qtran.py @@ -94,7 +94,8 @@ def _train(self, BATCH_DICT): # [T, B, 1] q_target_next_max = q_target.max(-1, keepdim=True)[0] - q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1] + q_target_next_choose_maxs.append( + q_target_next_max) # N * [T, B, 1] q_target_cell_states.append(q_target_cell_state) # N * [T, B, *] q_target_actions.append(next_max_action_one_hot) # N * [T, B, A] diff --git a/rls/algorithms/multi/vdn.py b/rls/algorithms/multi/vdn.py index b585eeb..4f40079 100644 --- a/rls/algorithms/multi/vdn.py +++ b/rls/algorithms/multi/vdn.py @@ -127,7 +127,8 @@ def _train(self, BATCH_DICT): # [T, B, 1] q_target_next_max = q_target.max(-1, keepdim=True)[0] - q_target_next_choose_maxs.append(q_target_next_max) # N * [T, B, 1] + q_target_next_choose_maxs.append( + q_target_next_max) # N * [T, B, 1] q_eval_tot = self.mixer( q_evals, BATCH_DICT['global'].obs, begin_mask=BATCH_DICT['global'].begin_mask) # [T, B, 1] q_target_next_max_tot = self.mixer.t( diff --git a/rls/configs/algorithms.yaml b/rls/configs/algorithms.yaml index d28dafc..8e85326 100644 --- a/rls/configs/algorithms.yaml +++ b/rls/configs/algorithms.yaml @@ -11,18 +11,18 @@ policy: &policy # ----- could be overrided in specific algorithms, i.e. dqn, so as to using different type of visual net, memory net. rep_net_params: &rep_net_params - use_encoder: false - use_rnn: false vector_net_params: + h_dim: 16 network_type: "adaptive" # rls.nn.represents.vectors visual_net_params: - visual_feature: 128 + h_dim: 128 network_type: "simple" # rls.nn.represents.visuals encoder_net_params: - output_dim: 16 + h_dim: 16 + network_type: "identity" # rls.nn.represents.encoders memory_net_params: rnn_units: 16 - network_type: "lstm" + network_type: "identity" # rls.nn.represents.memories # ----- sarl_policy: &sarl_policy @@ -81,7 +81,7 @@ dqn: &dqn rep_net_params: <<: *rep_net_params visual_net_params: - visual_feature: 128 + h_dim: 128 network_type: "nature" ddqn: *dqn diff --git a/rls/nn/mixers/__init__.py b/rls/nn/mixers/__init__.py index b3da9ce..0786344 100644 --- a/rls/nn/mixers/__init__.py +++ b/rls/nn/mixers/__init__.py @@ -1,8 +1,8 @@ from .qatten import QattenMixer from .qmix import QMixer +from .qplex.qplex import QPLEXMixer from .qtran_base import QTranBase from .vdn import VDNMixer -from .qplex.qplex import QPLEXMixer Mixer_REGISTER = {} diff --git a/rls/nn/mixers/qplex/qplex.py b/rls/nn/mixers/qplex/qplex.py index ac93f91..ffdeac7 100644 --- a/rls/nn/mixers/qplex/qplex.py +++ b/rls/nn/mixers/qplex/qplex.py @@ -1,10 +1,11 @@ -import torch as t import numpy as np -from .si_weight import SI_Weight +import torch as t from rls.nn.mlps import MLP from rls.nn.represent_nets import RepresentationNetwork +from .si_weight import SI_Weight + class QPLEXMixer(t.nn.Module): '''https://github.com/wjh720/QPLEX/''' @@ -73,9 +74,11 @@ def forward(self, state, q_values, actions, max_q_i, **kwargs): adv_w_final = self.si_weight(state_feat, actions) # [T, B, N] if self.is_minus_one: - adv_tot = t.sum(adv_q * (adv_w_final - 1.), dim=-1, keepdim=True) # [T, B, 1] + adv_tot = t.sum(adv_q * (adv_w_final - 1.), + dim=-1, keepdim=True) # [T, B, 1] else: - adv_tot 
= t.sum(adv_q * adv_w_final, dim=-1, keepdim=True)  # [T, B, 1]
+            adv_tot = t.sum(adv_q * adv_w_final,
+                            dim=-1, keepdim=True)  # [T, B, 1]
 
         q_tot = v_tot + adv_tot  # [T, B, 1]
 
diff --git a/rls/nn/mixers/qplex/si_weight.py b/rls/nn/mixers/qplex/si_weight.py
index 2332247..c9091d1 100644
--- a/rls/nn/mixers/qplex/si_weight.py
+++ b/rls/nn/mixers/qplex/si_weight.py
@@ -1,5 +1,5 @@
-import torch as t
 import numpy as np
+import torch as t
 
 from rls.nn.mlps import MLP
 
@@ -33,9 +33,12 @@ def forward(self, state_feat, actions):
         '''
         data = t.cat([state_feat]+actions, dim=-1)  # [T, B, *]
 
-        all_head_key = [k_ext(state_feat) for k_ext in self.key_extractors]  # List[[T, B, 1]]
-        all_head_agents = [k_ext(state_feat) for k_ext in self.agents_extractors]  # List[[T, B, N]]
-        all_head_action = [sel_ext(data) for sel_ext in self.action_extractors]  # List[[T, B, N]]
+        all_head_key = [k_ext(state_feat)
+                        for k_ext in self.key_extractors]  # List[[T, B, 1]]
+        all_head_agents = [k_ext(state_feat)
+                           for k_ext in self.agents_extractors]  # List[[T, B, N]]
+        # List[[T, B, N]]
+        all_head_action = [sel_ext(data) for sel_ext in self.action_extractors]
 
         head_attend_weights = []
         for curr_head_key, curr_head_agents, curr_head_action in zip(all_head_key, all_head_agents, all_head_action):
diff --git a/rls/nn/models.py b/rls/nn/models.py
index a8716ad..a905976 100644
--- a/rls/nn/models.py
+++ b/rls/nn/models.py
@@ -8,6 +8,8 @@ from rls.nn.represent_nets import RepresentationNetwork
 from rls.utils.torch_utils import clip_nn_log_std
 
+Model_REGISTER = {}
+
 
 class BaseModel(t.nn.Module):
 
@@ -46,6 +48,9 @@ def forward(self, x, **kwargs):
         return self.net(x)
 
 
+Model_REGISTER['actor_dpg'] = ActorDPG
+
+
 class ActorMuLogstd(BaseModel):
     '''
     use for PPO/PG algorithms' actor network.
@@ -83,6 +88,9 @@ def forward(self, x, **kwargs):
         return (mu, log_std)
 
 
+Model_REGISTER['actor_mulogstd'] = ActorMuLogstd
+
+
 class ActorCts(BaseModel):
     '''
     use for continuous action space.
@@ -118,6 +126,9 @@ def forward(self, x, **kwargs):
         return (mu, log_std)
 
 
+Model_REGISTER['actor_continuous'] = ActorCts
+
+
 class ActorDct(BaseModel):
     '''
     use for discrete action space.
@@ -136,6 +147,9 @@ def forward(self, x, **kwargs):
         return logits
 
 
+Model_REGISTER['actor_discrete'] = ActorDct
+
+
 class CriticQvalueOne(BaseModel):
     '''
     use for evaluate the value given a state-action pair.
@@ -154,6 +168,9 @@ def forward(self, x, a, **kwargs):
         return q
 
 
+Model_REGISTER['critic_q1'] = CriticQvalueOne
+
+
 class CriticQvalueOneDDPG(BaseModel):
     '''
     Original architecture in DDPG paper.
@@ -175,6 +192,9 @@ def forward(self, x, a, **kwargs):
         return q
 
 
+Model_REGISTER['critic_q1_ddpg'] = CriticQvalueOneDDPG
+
+
 class CriticQvalueOneTD3(BaseModel):
     '''
     Original architecture in TD3 paper.
@@ -197,6 +217,9 @@ def forward(self, x, a, **kwargs):
         return q
 
 
+Model_REGISTER['critic_q1_td3'] = CriticQvalueOneTD3
+
+
 class CriticValue(BaseModel):
     '''
     use for evaluate the value given a state.
@@ -214,6 +237,9 @@ def forward(self, x, **kwargs):
         return v
 
 
+Model_REGISTER['critic_v'] = CriticValue
+
+
 class CriticQvalueAll(BaseModel):
     '''
     use for evaluate all values of Q(S,A) given a state. must be discrete action space.
@@ -232,6 +258,9 @@ def forward(self, x, **kwargs):
         return q
 
 
+Model_REGISTER['critic_q_all'] = CriticQvalueAll
+
+
 class CriticQvalueBootstrap(BaseModel):
     '''
     use for bootstrapped dqn.
@@ -249,6 +278,9 @@ def forward(self, x, **kwargs):
         return q
 
 
+Model_REGISTER['critic_q_bootstrap'] = CriticQvalueBootstrap
+
+
 class CriticDueling(BaseModel):
     '''
     Neural network for dueling deep Q network.
@@ -278,6 +310,9 @@ def forward(self, x, **kwargs):
         return q
 
 
+Model_REGISTER['critic_dueling'] = CriticDueling
+
+
 class OcIntraOption(BaseModel):
     '''
     Intra Option Neural network of Option-Critic.
@@ -298,6 +333,9 @@ def forward(self, x, **kwargs):
         return pi
 
 
+Model_REGISTER['oc_intra_option'] = OcIntraOption
+
+
 class AocShare(BaseModel):
     '''
     Neural network for AOC.
@@ -329,6 +367,9 @@ def forward(self, x, **kwargs):
         return q, pi, beta
 
 
+Model_REGISTER['aoc_share'] = AocShare
+
+
 class PpocShare(BaseModel):
     '''
     Neural network for PPOC.
@@ -363,6 +404,9 @@ def forward(self, x, **kwargs):
         return q, pi, beta, o
 
 
+Model_REGISTER['ppoc_share'] = PpocShare
+
+
 class ActorCriticValueCts(BaseModel):
     '''
     combine actor network and critic network, share some nn layers. use for continuous action space.
@@ -405,6 +449,9 @@ def forward(self, x, **kwargs):
         return (mu, log_std, v)
 
 
+Model_REGISTER['ac_v_continuous'] = ActorCriticValueCts
+
+
 class ActorCriticValueDct(BaseModel):
     '''
     combine actor network and critic network, share some nn layers. use for discrete action space.
@@ -431,6 +478,9 @@ def forward(self, x, **kwargs):
         return (logits, v)
 
 
+Model_REGISTER['ac_v_discrete'] = ActorCriticValueDct
+
+
 class C51Distributional(BaseModel):
     '''
     neural network for C51
@@ -452,6 +502,9 @@ def forward(self, x, **kwargs):
         return q_dist
 
 
+Model_REGISTER['c51'] = C51Distributional
+
+
 class QrdqnDistributional(BaseModel):
     '''
     neural network for QRDQN
@@ -472,6 +525,9 @@ def forward(self, x, **kwargs):
         return q_dist
 
 
+Model_REGISTER['qrdqn'] = QrdqnDistributional
+
+
 class RainbowDueling(BaseModel):
     '''
     Neural network for Rainbow.
@@ -511,6 +567,9 @@ def forward(self, x, **kwargs):
         return qs  # [B, A, N] or [T, B, A, N]
 
 
+Model_REGISTER['rainbow'] = RainbowDueling
+
+
 class IqnNet(BaseModel):
     def __init__(self, obs_spec, rep_net_params, action_dim, quantiles_idx, network_settings):
         super().__init__(obs_spec, rep_net_params)
@@ -551,6 +610,9 @@ def forward(self, x, quantiles_tiled, **kwargs):
         return quantiles_value  # [N, B, A] or [T, N, B, A]
 
 
+Model_REGISTER['iqn'] = IqnNet
+
+
 class MACriticQvalueOne(t.nn.Module):
     '''
     use for evaluate the value given a state-action pair.
@@ -576,3 +638,6 @@ def forward(self, x, a, **kwargs): x = t.cat(outs, -1) q = self.net(t.cat((x, a), -1)) return q + + +Model_REGISTER['ma_critic_q1'] = MACriticQvalueOne diff --git a/rls/nn/networks.py b/rls/nn/networks.py index 631905b..3881a87 100644 --- a/rls/nn/networks.py +++ b/rls/nn/networks.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 # encoding: utf-8 -from collections import defaultdict from typing import Dict, Optional, Tuple import numpy as np @@ -9,16 +8,19 @@ from torch.nn import Linear, Sequential from rls.nn.activations import Act_REGISTER, default_act +from rls.nn.represents.encoders import End_REGISTER +from rls.nn.represents.memories import Rnn_REGISTER from rls.nn.represents.vectors import Vec_REGISTER from rls.nn.represents.visuals import Vis_REGISTER class MultiVectorNetwork(t.nn.Module): - def __init__(self, vector_dim=[], network_type='concat'): + def __init__(self, vector_dim=[], h_dim=16, network_type='identity'): super().__init__() self.nets = t.nn.ModuleList() for in_dim in vector_dim: - self.nets.append(Vec_REGISTER[network_type](in_dim=in_dim)) + self.nets.append(Vec_REGISTER[network_type]( + in_dim=in_dim, h_dim=h_dim)) self.h_dim = sum([net.h_dim for net in self.nets]) def forward(self, *vector_inputs): @@ -31,7 +33,7 @@ def forward(self, *vector_inputs): class MultiVisualNetwork(t.nn.Module): - def __init__(self, visual_dim=[], visual_feature=128, network_type='nature'): + def __init__(self, visual_dim=[], h_dim=128, network_type='nature'): super().__init__() self.dense_nets = t.nn.ModuleList() for vd in visual_dim: @@ -39,11 +41,11 @@ def __init__(self, visual_dim=[], visual_feature=128, network_type='nature'): self.dense_nets.append( Sequential( net, - Linear(net.output_dim, visual_feature), + Linear(net.output_dim, h_dim), Act_REGISTER[default_act]() ) ) - self.h_dim = visual_feature * len(self.dense_nets) + self.h_dim = h_dim * len(self.dense_nets) def forward(self, *visual_inputs): # h, w, c => c, h, w @@ -61,26 +63,21 @@ def forward(self, *visual_inputs): class EncoderNetwork(t.nn.Module): - def __init__(self, feat_dim=64, output_dim=64): + def __init__(self, feat_dim=64, h_dim=64, network_type='identity'): super().__init__() - self.h_dim = output_dim - self.net = Linear(feat_dim, output_dim) - self.act = Act_REGISTER[default_act]() + self.net = End_REGISTER[network_type](in_dim=feat_dim, h_dim=h_dim) + self.h_dim = self.net.h_dim def forward(self, feat): - return self.act(self.net(feat)) + return self.net(feat) class MemoryNetwork(t.nn.Module): - def __init__(self, feat_dim=64, rnn_units=8, *, network_type='lstm'): + def __init__(self, feat_dim=64, rnn_units=8, network_type='lstm'): super().__init__() - self.h_dim = rnn_units - self.network_type = network_type - - if self.network_type == 'gru': - self.rnn = t.nn.GRUCell(feat_dim, rnn_units) - elif self.network_type == 'lstm': - self.rnn = t.nn.LSTMCell(feat_dim, rnn_units) + self.net = Rnn_REGISTER[network_type]( + in_dim=feat_dim, rnn_units=rnn_units) + self.h_dim = self.net.h_dim def forward(self, feat, cell_state: Optional[Dict], begin_mask: Optional[t.Tensor]): ''' @@ -91,43 +88,21 @@ def forward(self, feat, cell_state: Optional[Dict], begin_mask: Optional[t.Tenso output: [T, B, *] or [B, *] cell_states: [T, B, *] or [B, *] ''' - T, B = feat.shape[:2] - output = [] - cell_states = defaultdict(list) - if self.network_type == 'gru': - if cell_state: - hx = cell_state['hx'][0] - else: - hx = t.zeros(size=(B, self.h_dim)) - for i in range(T): # T - if begin_mask is not None: - hx *= (1 - 
begin_mask[i]) - hx = self.rnn(feat[i, ...], hx) - - output.append(hx) - cell_states['hx'].append(hx) - - elif self.network_type == 'lstm': + _squeeze = False + if feat.ndim == 2: # [B, *] + _squeeze = True + feat = feat.unsqueeze(0) # [B, *] => [1, B, *] if cell_state: - hx, cx = cell_state['hx'][0], cell_state['cx'][0] - else: - hx, cx = t.zeros(size=(B, self.h_dim)), t.zeros( - size=(B, self.h_dim)) - for i in range(T): # T - if begin_mask is not None: - hx *= (1 - begin_mask[i]) - cx *= (1 - begin_mask[i]) - hx, cx = self.rnn(feat[i, ...], (hx, cx)) - - output.append(hx) - cell_states['hx'].append(hx) - cell_states['cx'].append(cx) - if T > 1: - output = t.stack(output, dim=0) # [T, B, N] - cell_states = {k: t.stack(v, 0) - for k, v in cell_states.items()} # [T, B, N] - return output, cell_states - else: - # [B, *] - return output[0], {k: v[0] for k, v in cell_states.items()} + cell_state = {k: v.unsqueeze(0) # [1, B, *] + for k, v in cell_state.items()} + + output, cell_states = self.net( + feat, cell_state, begin_mask) # [B, *] or [T, B, *] + + if _squeeze: + output = output.squeeze(0) # [B, *] + if cell_states: + cell_states = {k: v.squeeze(0) + for k, v in cell_states.items()} # [B, *] + return output, cell_states diff --git a/rls/nn/represent_nets.py b/rls/nn/represent_nets.py index 00b901b..9925dde 100644 --- a/rls/nn/represent_nets.py +++ b/rls/nn/represent_nets.py @@ -54,21 +54,17 @@ def __init__(self, self.use_other_info = True self.h_dim += self.obs_spec.other_dims - self.use_encoder = bool(rep_net_params.get('use_encoder', False)) - if self.use_encoder: - encoder_net_params = dict( - rep_net_params.get('encoder_net_params', {})) - self.encoder_net = EncoderNetwork(self.h_dim, **encoder_net_params) - logger.debug('initialize encoder network successfully.') - self.h_dim = self.encoder_net.h_dim - - self.use_rnn = bool(rep_net_params.get('use_rnn', False)) - if self.use_rnn: - memory_net_params = dict( - rep_net_params.get('memory_net_params', {})) - self.memory_net = MemoryNetwork(self.h_dim, **memory_net_params) - logger.debug('initialize memory network successfully.') - self.h_dim = self.memory_net.h_dim + encoder_net_params = dict( + rep_net_params.get('encoder_net_params', {})) + self.encoder_net = EncoderNetwork(self.h_dim, **encoder_net_params) + logger.debug('initialize encoder network successfully.') + self.h_dim = self.encoder_net.h_dim + + memory_net_params = dict( + rep_net_params.get('memory_net_params', {})) + self.memory_net = MemoryNetwork(self.h_dim, **memory_net_params) + logger.debug('initialize memory network successfully.') + self.h_dim = self.memory_net.h_dim def forward(self, obs, cell_state=None, begin_mask=None): ''' @@ -90,17 +86,10 @@ def forward(self, obs, cell_state=None, begin_mask=None): if self.use_other_info: feat = t.cat([feat, obs.other], -1) - if self.use_encoder: - feat = self.encoder_net(feat) # [T, B, *] or [B, *] - - if self.use_rnn: - if feat.ndim == 2: # [B, *] - feat = feat.unsqueeze(0) # [B, *] => [1, B, *] - if cell_state: - cell_state = {k: v.unsqueeze(0) # [1, B, *] - for k, v in cell_state.items()} - feat, cell_state = self.memory_net( - feat, cell_state, begin_mask) # [T, B, *] or [B, *] + feat = self.encoder_net(feat) # [T, B, *] or [B, *] + + feat, cell_state = self.memory_net( + feat, cell_state, begin_mask) # [T, B, *] or [B, *] return feat, cell_state diff --git a/rls/nn/represents/encoders.py b/rls/nn/represents/encoders.py new file mode 100644 index 0000000..05439a4 --- /dev/null +++ b/rls/nn/represents/encoders.py @@ -0,0 
+1,29 @@
+
+
+from torch.nn import Identity, Linear, Sequential
+
+from rls.nn.activations import Act_REGISTER, default_act
+
+End_REGISTER = {}
+
+
+class EncoderIdentityNetwork(Sequential):
+
+    def __init__(self, in_dim, *args, **kwargs):
+        super().__init__()
+        self.h_dim = self.in_dim = in_dim
+        self.add_module('identity', Identity())
+
+
+class EncoderMlpNetwork(Sequential):
+
+    def __init__(self, in_dim, h_dim=16, **kwargs):
+        super().__init__()
+        self.in_dim = in_dim
+        self.h_dim = h_dim
+        self.add_module('linear', Linear(self.in_dim, self.h_dim))
+        self.add_module('activation', Act_REGISTER[default_act]())
+
+
+End_REGISTER['identity'] = EncoderIdentityNetwork
+End_REGISTER['mlp'] = EncoderMlpNetwork
diff --git a/rls/nn/represents/memories.py b/rls/nn/represents/memories.py
new file mode 100644
index 0000000..2cccf71
--- /dev/null
+++ b/rls/nn/represents/memories.py
@@ -0,0 +1,111 @@
+
+from collections import defaultdict
+from typing import Dict, Optional, Tuple
+
+import torch as t
+from torch.nn import Identity, Linear, Sequential
+
+from rls.nn.activations import Act_REGISTER, default_act
+
+Rnn_REGISTER = {}
+
+
+class IdentityRNN(t.nn.Module):
+
+    def __init__(self, in_dim, *args, **kwargs):
+        super().__init__()
+        self.h_dim = self.in_dim = in_dim
+
+    def forward(self, x, *args, **kwargs):
+        return x, None
+
+
+class GRU_RNN(t.nn.Module):
+
+    def __init__(self, in_dim, rnn_units=16, **kwargs):
+        super().__init__()
+        self.in_dim = in_dim
+        self.rnn = t.nn.GRUCell(self.in_dim, rnn_units)
+        self.h_dim = rnn_units
+
+    def forward(self, feat, cell_state: Optional[Dict], begin_mask: Optional[t.Tensor]):
+        '''
+        params:
+            feat: [T, B, *]
+            cell_state: [T, B, *]
+        returns:
+            output: [T, B, *] or [B, *]
+            cell_states: [T, B, *] or [B, *]
+        '''
+        T, B = feat.shape[:2]
+
+        output = []
+        cell_states = defaultdict(list)
+
+        if cell_state:
+            hx = cell_state['hx'][0]
+        else:
+            hx = t.zeros(size=(B, self.h_dim))
+        for i in range(T):  # T
+            if begin_mask is not None:
+                hx *= (1 - begin_mask[i])
+            hx = self.rnn(feat[i, ...], hx)
+
+            output.append(hx)
+            cell_states['hx'].append(hx)
+
+        output = t.stack(output, dim=0)  # [T, B, N]
+        if cell_states:
+            cell_states = {k: t.stack(v, 0)
+                           for k, v in cell_states.items()}  # [T, B, N]
+        return output, cell_states
+
+
+class LSTM_RNN(t.nn.Module):
+
+    def __init__(self, in_dim, rnn_units=16, **kwargs):
+        super().__init__()
+        self.in_dim = in_dim
+        self.rnn = t.nn.LSTMCell(self.in_dim, rnn_units)
+        self.h_dim = rnn_units
+
+    def forward(self, feat, cell_state: Optional[Dict], begin_mask: Optional[t.Tensor]):
+        '''
+        params:
+            feat: [T, B, *]
+            cell_state: [T, B, *]
+        returns:
+            output: [T, B, *] or [B, *]
+            cell_states: [T, B, *] or [B, *]
+        '''
+        T, B = feat.shape[:2]
+
+        output = []
+        cell_states = defaultdict(list)
+
+        if cell_state:
+            hx, cx = cell_state['hx'][0], cell_state['cx'][0]
+        else:
+            hx, cx = t.zeros(size=(B, self.h_dim)), t.zeros(
+                size=(B, self.h_dim))
+        for i in range(T):  # T
+            if begin_mask is not None:
+                hx *= (1 - begin_mask[i])
+                cx *= (1 - begin_mask[i])
+            hx, cx = self.rnn(feat[i, ...], (hx, cx))
+
+            output.append(hx)
+            cell_states['hx'].append(hx)
+            cell_states['cx'].append(cx)
+
+        output = t.stack(output, dim=0)  # [T, B, N]
+        if cell_states:
+            cell_states = {k: t.stack(v, 0)
+                           for k, v in cell_states.items()}  # [T, B, N]
+
+        return output, cell_states
+
+
+Rnn_REGISTER['identity'] = Rnn_REGISTER['none'] = IdentityRNN
+Rnn_REGISTER['gru'] = GRU_RNN
+Rnn_REGISTER['lstm'] = LSTM_RNN
diff --git a/rls/nn/represents/vectors.py b/rls/nn/represents/vectors.py
index 6819771..187e483 100644
--- a/rls/nn/represents/vectors.py
+++ b/rls/nn/represents/vectors.py
@@ -1,32 +1,28 @@
 import math
 
-from torch.nn import Linear, Sequential
+from torch.nn import Identity, Linear, Sequential
 
 from rls.nn.activations import Act_REGISTER, default_act
 
 Vec_REGISTER = {}
 
 
-class VectorConcatNetwork:
+class VectorIdentityNetwork(Sequential):
 
-    def __init__(self, *args, **kwargs):
-        assert 'in_dim' in kwargs.keys(), "assert dim in kwargs.keys()"
-        self.h_dim = self.in_dim = int(kwargs['in_dim'])
-        pass
-
-    def __call__(self, x):
-        return x
+    def __init__(self, in_dim, *args, **kwargs):
+        super().__init__()
+        self.h_dim = self.in_dim = in_dim
+        self.add_module('identity', Identity())
 
 
 class VectorAdaptiveNetwork(Sequential):
 
-    def __init__(self, **kwargs):
+    def __init__(self, in_dim, h_dim=16, **kwargs):
         super().__init__()
-        assert 'in_dim' in kwargs.keys(), "assert dim in kwargs.keys()"
-        self.in_dim = int(kwargs['in_dim'])
-        self.h_dim = self.out_dim = int(kwargs.get('out_dim', 16))
-        x = math.log2(self.out_dim)
+        self.in_dim = in_dim
+        self.h_dim = h_dim
+        x = math.log2(self.h_dim)
         y = math.log2(self.in_dim)
         l = math.ceil(x) + 1 if math.ceil(x) == math.floor(x) else math.ceil(x)
         r = math.floor(y) if math.ceil(y) == math.floor(y) else math.ceil(y)
@@ -42,9 +38,9 @@ def __init__(self, **kwargs):
             ins = outs[-1]
         else:
             ins = self.in_dim
-        self.add_module('linear', Linear(ins, self.out_dim))
+        self.add_module('linear', Linear(ins, self.h_dim))
         self.add_module(f'{default_act}', Act_REGISTER[default_act]())
 
 
-Vec_REGISTER['concat'] = VectorConcatNetwork
+Vec_REGISTER['identity'] = VectorIdentityNetwork
 Vec_REGISTER['adaptive'] = VectorAdaptiveNetwork
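Usage notes (illustrative sketches, not part of the diff):

rls/nn/models.py now fills Model_REGISTER at class-definition time, so call sites can build a network from a string key instead of importing each class. A minimal consumer sketch; build_model and its extra keyword arguments are hypothetical, since every registered class defines its own constructor on top of BaseModel(obs_spec, rep_net_params):

    from rls.nn.models import Model_REGISTER

    def build_model(name, obs_spec, rep_net_params, **model_kwargs):
        # Fail fast with the list of known keys instead of a bare KeyError.
        if name not in Model_REGISTER:
            raise KeyError(
                f'unknown model {name!r}, expected one of {sorted(Model_REGISTER)}')
        return Model_REGISTER[name](obs_spec, rep_net_params, **model_kwargs)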
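The begin_mask loop shared by GRU_RNN and LSTM_RNN in rls/nn/represents/memories.py zeroes the recurrent state at every step flagged as an episode start, so memory never leaks across episode boundaries inside a [T, B, *] rollout. A standalone distillation of that loop in plain PyTorch, with toy dimensions:

    import torch as t

    T, B, D, H = 5, 3, 4, 8            # time, batch, feature, hidden sizes
    rnn = t.nn.GRUCell(D, H)
    feat = t.randn(T, B, D)            # [T, B, D]
    begin_mask = t.zeros(T, B, 1)      # [T, B, 1]; 1. marks an episode start
    begin_mask[2, 1] = 1.              # batch item 1 starts a new episode at t=2

    hx = t.zeros(B, H)
    outputs = []
    for i in range(T):
        hx = hx * (1 - begin_mask[i])  # reset hidden state at episode starts
        hx = rnn(feat[i], hx)
        outputs.append(hx)
    output = t.stack(outputs, dim=0)   # [T, B, H]
    print(output.shape)                # torch.Size([5, 3, 8])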
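With this commit the representation stack is assembled purely from rep_net_params: every stage is looked up in a registry, and Policy derives use_rnn from memory_net_params['network_type'] != 'identity' instead of a separate use_rnn flag. A sketch of that wiring, assuming this version of the package is importable; the input width of 32 is illustrative:

    from rls.nn.represents.encoders import End_REGISTER
    from rls.nn.represents.memories import Rnn_REGISTER

    rep_net_params = {
        'encoder_net_params': {'h_dim': 16, 'network_type': 'identity'},
        'memory_net_params': {'rnn_units': 16, 'network_type': 'lstm'},
    }

    enc = rep_net_params['encoder_net_params']
    # the identity encoder keeps in_dim, so encoder.h_dim == 32 here
    encoder = End_REGISTER[enc['network_type']](in_dim=32, h_dim=enc['h_dim'])

    mem = rep_net_params['memory_net_params']
    memory = Rnn_REGISTER[mem['network_type']](
        in_dim=encoder.h_dim, rnn_units=mem['rnn_units'])

    # An RNN is active exactly when the memory type is not 'identity'.
    use_rnn = mem['network_type'] != 'identity'
    print(encoder.h_dim, memory.h_dim, use_rnn)  # 32 16 True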