neural_network_dqn.py
#%%
import math
import os
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Set up module-level logger
logger = logging.getLogger(__name__)

def combined_shape(length, shape=None):
    """
    Combine the size of the replay buffer with a per-sample shape to
    produce the full dimensions for matrix initialization.

    Args:
        length (int): Number of sample slots in the replay buffer.
        shape (int or tuple, optional): Per-sample dimensions, typically
            state_dims or action_dims. Defaults to None.

    Returns:
        tuple: Dimensions for matrix initialization.
    """
    if shape is None:
        return (length,)
    return (length, shape) if np.isscalar(shape) else (length, *shape)
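
# A quick sketch of the three return shapes of combined_shape:
#   combined_shape(100)          -> (100,)
#   combined_shape(100, 4)       -> (100, 4)
#   combined_shape(100, (4, 84)) -> (100, 4, 84)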

class NetworkUtils:
    """
    Provides useful functionality for neural network classes. Inherit
    from this class when creating a new neural network to use it.

    Functionality includes:
        * saving and loading network parameters
    """

    def save_checkpoint(self):
        if hasattr(self, "checkpoint_dir") and hasattr(self, "name"):
            # Create the checkpoint directory on first use
            os.makedirs(self.checkpoint_dir, exist_ok=True)
            checkpoint_file_path = os.path.join(self.checkpoint_dir, self.name)
            logger.info("--- Saving Checkpoint ---")
            torch.save(self.state_dict(), checkpoint_file_path)
        else:
            logger.error(
                "--- Could not save checkpoint, some attributes are missing ---"
            )

    def load_checkpoint(self):
        if hasattr(self, "checkpoint_dir") and hasattr(self, "name"):
            checkpoint_file_path = os.path.join(self.checkpoint_dir, self.name)
            logger.info("--- Loading Checkpoint ---")
            # map_location keeps CPU-only machines from failing on
            # checkpoints that were saved from a GPU
            self.load_state_dict(
                torch.load(
                    checkpoint_file_path,
                    map_location=getattr(self, "device", None),
                )
            )
            # Optional: self.eval()
            # Remember that you must call model.eval() to set dropout and
            # batch normalization layers to evaluation mode before running
            # inference. Failing to do this will yield inconsistent
            # inference results.
        else:
            logger.error(
                "--- Could not load checkpoint, some attributes are missing ---"
            )

class DQN(nn.Module, NetworkUtils):
    def __init__(
        self,
        name: str,
        state_dims: int,
        action_dims: int,
        env=None,
        checkpoint_dir: str = "models/tmp/dqn",
    ):
        super().__init__()
        self.name = name
        self.checkpoint_dir = checkpoint_dir
        self.env = env
        self.action_dims = action_dims
        # Use the explicit state/action dimensions so the network can also
        # be built without an environment instance
        self.layers = nn.Sequential(
            nn.Linear(state_dims, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dims),
        )
        # Initialize the device to which the network should be passed
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Pass the network to said device
        self.to(self.device)

    def forward(self, x):
        return self.layers(x)

    def act(self, state, epsilon):
        """Epsilon-greedy action selection."""
        # Fall back to NumPy's global RNG if no environment (or no env.rng)
        # is available
        rng = getattr(self.env, "rng", np.random)
        if rng.rand() > epsilon:
            # Greedy: pick the action with the highest Q-value
            with torch.no_grad():
                q_value = self.forward(
                    torch.as_tensor(state, dtype=torch.float32).to(self.device)
                )
            action = q_value.argmax().item()
        else:
            # Random exploration
            if self.env is not None:
                action = self.env.action_space.sample()
            else:
                action = np.random.randint(self.action_dims)
        return action
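
# A common companion to DQN.act (a sketch, not part of the original training
# loop): exponentially decay epsilon from eps_start toward eps_final over the
# course of training. All parameter values here are assumptions.
def epsilon_by_frame(frame_idx, eps_start=1.0, eps_final=0.01, eps_decay=500):
    return eps_final + (eps_start - eps_final) * math.exp(-frame_idx / eps_decay)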

#%%
if __name__ == "__main__":
    from buffer import ReplayBuffer

    STATE_DIMS = 2
    ACTION_DIMS = 1
    MEM_FACTOR = 100
    N_PATHS = 5

    # Build the network without an environment instance; the state and
    # action dimensions are given explicitly
    dqn = DQN("test_dqn", STATE_DIMS, ACTION_DIMS)
    dqn.save_checkpoint()
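    # Round-trip check (a small addition): reload the parameters that were
    # just saved to verify the checkpoint path works
    dqn.load_checkpoint()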
    replay_buffer = ReplayBuffer(STATE_DIMS, ACTION_DIMS, MEM_FACTOR, N_PATHS)

    # Generate an arbitrary state transition
    states = np.arange(N_PATHS * STATE_DIMS).reshape((N_PATHS, STATE_DIMS))
    next_states = np.arange(N_PATHS * STATE_DIMS).reshape((N_PATHS, STATE_DIMS))
    actions = np.arange(N_PATHS * ACTION_DIMS).reshape((N_PATHS, ACTION_DIMS))
    rewards = np.arange(N_PATHS)
    dones = np.zeros(N_PATHS)
    replay_buffer.store_transition(states, actions, rewards, next_states, dones)

    # Sample a batch and move it to the network's device
    memory_sample = replay_buffer.sample_batch(1)
    states = memory_sample["states"].to(dqn.device)
    actions = memory_sample["actions"].to(dqn.device)
    rewards = memory_sample["rewards"].to(dqn.device)
    next_states = memory_sample["next_states"].to(dqn.device)
    dones = memory_sample["dones"].to(dqn.device)

    # Epsilon-greedy action on the sampled state (epsilon near 0 -> greedy)
    action = dqn.act(states, 0.0001)
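
    # A minimal sketch of the standard one-step DQN TD loss on this batch.
    # Assumptions not in the original file: a discount factor gamma of 0.99,
    # float/long casts for the buffer tensors, and no separate target
    # network (the online network bootstraps its own targets).
    gamma = 0.99
    q_values = dqn(states.float())
    # The toy transitions above store action values 0..4 while the network
    # has only ACTION_DIMS outputs, so clamp the indices for this demo
    action_idx = actions.long().clamp(max=ACTION_DIMS - 1)
    # Q-value of the action actually taken
    q_value = q_values.gather(1, action_idx).squeeze(1)
    with torch.no_grad():
        # Bootstrap target: r + gamma * max_a' Q(s', a') * (1 - done)
        next_q_value = dqn(next_states.float()).max(dim=1)[0]
        expected_q = rewards.float() + gamma * next_q_value * (1 - dones.float())
    loss = F.mse_loss(q_value, expected_q)

    # One gradient step with Adam (the learning rate is an assumption)
    optimizer = optim.Adam(dqn.parameters(), lr=1e-3)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()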
# %%