Skip to content

Commit

Permalink
noisy net
Browse files Browse the repository at this point in the history
  • Loading branch information
shixiaowen03 committed Dec 18, 2018
1 parent 5aa1e47 commit 48c04b3
Show file tree
Hide file tree
Showing 8 changed files with 700 additions and 154 deletions.
296 changes: 142 additions & 154 deletions .idea/workspace.xml

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions RL/Basic-NoisyNet-Demo/Config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
class NoisyNetDQNConfig:
# ENV_NAME = "CartPole-v1"
ENV_NAME = 'Breakout-v0' # 0: hold 1: throw the ball 2: move right 3: move left
# ENV_NAME = "Freeway-v0"
GAMMA = 0.99 # discount factor for target Q
START_TRAINING = 1000 # experience replay buffer size
BATCH_SIZE = 64 # size of minibatch
UPDATE_TARGET_NET = 400 # update eval_network params every 200 steps
LEARNING_RATE = 0.01
MODEL_PATH = './model/NoisyNetDQN_model'

INITIAL_EPSILON = 1.0 # starting value of epsilon
FINAL_EPSILON = 0.01 # final value of epsilon
EPSILIN_DECAY = 0.999

replay_buffer_size = 2000
iteration = 5
episode = 300 # 300 games per iteration

noisy_distribution = 'factorised' # independent or factorised




148 changes: 148 additions & 0 deletions RL/Basic-NoisyNet-Demo/NoisyNetDQN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import tensorflow as tf
import numpy as np
import random
from collections import deque

from utils import conv,noisy_dense

class NoisyNetDQN():
def __init__(self,env,config):
self.sess = tf.InteractiveSession()
self.config = config

self.replay_buffer = deque(maxlen = self.config.replay_buffer_size)
self.time_step = 0

self.state_dim = env.observation_space.shape
self.action_dim = env.action_space.n

print('state_dim:', self.state_dim)
print('action_dim:', self.action_dim)

self.action_batch = tf.placeholder('int32',[None])
self.y_input = tf.placeholder('float',[None,self.action_dim])

batch_shape = [None]
batch_shape.extend(self.state_dim)

self.eval_input = tf.placeholder('float',batch_shape)
self.target_input = tf.placeholder('float',batch_shape)

self.build_noisy_dqn_net()

self.saver = tf.train.Saver()

self.sess.run(tf.global_variables_initializer())

self.save_model()
self.restore_model()

def build_layers(self,state,c_names,units_1,units_2,w_i,b_i,reg=None):
with tf.variable_scope('conv1'):
conv1 = conv(state,[5,5,3,6],[6],[1,2,2,1],w_i,b_i)
with tf.variable_scope('conv2'):
conv2 = conv(conv1,[3,3,6,12],[12],[1,2,2,1],w_i,b_i)
with tf.variable_scope('flatten'):
flatten = tf.contrib.layers.flatten(conv2)

with tf.variable_scope('dense1'):
dense1 = noisy_dense(flatten,units_1,[units_1],c_names,w_i,b_i,noisy_distribution = self.config.noisy_distribution)

with tf.variable_scope('dense2'):
dense2 = noisy_dense(dense1,units_2,[units_2],c_names,w_i,b_i,noisy_distribution = self.config.noisy_distribution)

with tf.variable_scope('dense3'):
dense3 = noisy_dense(dense2,self.action_dim,[self.action_dim],c_names,w_i,b_i,noisy_distribution = self.config.noisy_distribution)

return dense3

def build_noisy_dqn_net(self):
with tf.variable_scope('target_net'):
c_names = ['target_net_arams',tf.GraphKeys.GLOBAL_VARIABLES]
w_i = tf.random_uniform_initializer(-0.1,0.1)
b_i = tf.constant_initializer(0.1)
self.q_target = self.build_layers(self.target_input,c_names,24,24,w_i,b_i)

with tf.variable_scope('eval_net'):
c_names = ['eval_net_params',tf.GraphKeys.GLOBAL_VARIABLES]
w_i = tf.random_uniform_initializer(-0.1,0.1)
b_i = tf.constant_initializer(0.1)
self.q_eval = self.build_layers(self.eval_input,c_names,24,24,w_i,b_i)

self.loss = tf.reduce_mean(tf.squared_difference(self.q_eval,self.y_input))

self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.loss)

eval_params = tf.get_collection("eval_net_params")
target_params = tf.get_collection('target_net_params')

self.update_target_net = [tf.assign(t,e) for t,e in zip(target_params,eval_params)]


def save_model(self):
print("Model saved in : ", self.saver.save(self.sess, self.config.MODEL_PATH))

def restore_model(self):
self.saver.restore(self.sess, self.config.MODEL_PATH)
print("Model restored.")


def perceive(self,state,action,reward,next_state,done):
self.replay_buffer.append((state,action,reward,next_state,done))


def train_q_network(self,update=True):

if len(self.replay_buffer) < self.config.START_TRAINING:
return

self.time_step += 1
minibatch = random.sample(self.replay_buffer,self.config.BATCH_SIZE)

np.random.shuffle(minibatch)

state_batch = [data[0] for data in minibatch]
action_batch = [data[1] for data in minibatch]
reward_batch = [data[2] for data in minibatch]
next_state_batch = [data[3] for data in minibatch]
done = [data[4] for data in minibatch]

q_target = self.sess.run(self.q_target,feed_dict={self.target_input:next_state_batch})
q_eval = self.sess.run(self.q_eval,feed_dict={self.eval_input:state_batch})

done = np.array(done) + 0

# DQN的结构 r + max q_target[a]
y_batch = np.zeros((self.config.BATCH_SIZE,self.action_dim))
for i in range(0,self.config.BATCH_SIZE):
temp = q_eval[i]
action = np.argmax(q_target[i])
temp[action_batch[i]] = reward_batch[i] + (1 - done[i]) * self.config.GAMMA * q_target[i][action]
y_batch[i] = temp


self.sess.run(self.optimizer,feed_dict={
self.y_input:y_batch,
self.eval_input:state_batch,
self.action_batch:action_batch
})

if update and self.time_step % self.config.UPDATE_TARGET_NET == 0:
self.sess.run(self.update_target_net)



def noisy_action(self, state):

return np.argmax(self.sess.run(self.q_target,feed_dict={self.target_input: [state]})[0])











71 changes: 71 additions & 0 deletions RL/Basic-NoisyNet-Demo/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import matplotlib.pyplot as plt
import gym

import numpy as np
import tensorflow as tf

import pickle

from Config import NoisyNetDQNConfig
from NoisyNetDQN import NoisyNetDQN

def map_scores(dqfd_scores=None, ddqn_scores=None, xlabel=None, ylabel=None):
if dqfd_scores is not None:
plt.plot(dqfd_scores, 'r')
if ddqn_scores is not None:
plt.plot(ddqn_scores, 'b')
if xlabel is not None:
plt.xlabel(xlabel)
if ylabel is not None:
plt.ylabel(ylabel)
plt.show()


def BreakOut_NoisyNetDQN(index,env):
with tf.variable_scope('DQfD_' + str(index)):
agent = NoisyNetDQN(env,NoisyNetDQNConfig())
scores = []
for e in range(NoisyNetDQNConfig.episode):
done = False
score = 0 # sum of reward in one episode
state = env.reset()
# while done is False:
last_lives = 5
throw = True
items_buffer = []
while not done:
env.render()
action = 1 if throw else agent.noisy_action(state)
next_state, real_reward, done, info = env.step(action)
lives = info['ale.lives']
train_reward = 1 if throw else -1 if lives < last_lives else real_reward
score += real_reward
throw = lives < last_lives
last_lives = lives
# agent.perceive(state, action, train_reward, next_state, done) # miss: -1 break: reward nothing: 0
items_buffer.append([state, action, next_state, done]) # miss: -1 break: reward nothing: 0
state = next_state
if train_reward != 0: # train when miss the ball or score or throw the ball in the beginning
print ('len(items_buffer):', len(items_buffer))
for item in items_buffer:
agent.perceive(item[0], item[1], -1 if throw else train_reward, item[2], item[3])
agent.train_q_network(update=False)
items_buffer = []
scores.append(score)
agent.sess.run(agent.update_target_net)
print("episode:", e, " score:", score, " memory length:", len(agent.replay_buffer))

return scores


if __name__ == '__main__':
env = gym.make('Breakout-v0') # 打砖块游戏

NoisyNetDQN_sum_scores = np.zeros(NoisyNetDQNConfig.episode)

for i in range(NoisyNetDQNConfig.iteration):
scores = BreakOut_NoisyNetDQN(i,env)
dqfd_sum_scores = [a + b for a, b in zip(scores, NoisyNetDQN_sum_scores)]
NoisyNetDQN_mean_scores = NoisyNetDQN_sum_scores / NoisyNetDQNConfig.iteration
with open('/Users/mahailong/DQfD/NoisyNetDQN_mean_scores.p', 'wb') as f:
pickle.dump(NoisyNetDQN_mean_scores, f, protocol=2)
2 changes: 2 additions & 0 deletions RL/Basic-NoisyNet-Demo/readme
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
论文名称:Noisy Networks for Exploration
论文下载地址:https://arxiv.org/abs/1706.10295v1
48 changes: 48 additions & 0 deletions RL/Basic-NoisyNet-Demo/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import tensorflow as tf
from tensorflow.python.framework import ops


def conv(inputs, kernel_shape, bias_shape, strides, w_i, b_i=None, activation=tf.nn.relu):

weights = tf.get_variable('weights', shape=kernel_shape, initializer=w_i)
conv = tf.nn.conv2d(inputs, weights, strides=strides, padding='SAME')
if bias_shape is not None:
biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
return activation(conv + biases) if activation is not None else conv+biases
return activation(conv) if activation is not None else conv

def noisy_dense(inputs, units, bias_shape, c_names, w_i, b_i=None, activation=tf.nn.relu, noisy_distribution='factorised'):
def f(e_list):
return tf.multiply(tf.sign(e_list), tf.pow(tf.abs(e_list), 0.5))

if not isinstance(inputs, ops.Tensor):
inputs = ops.convert_to_tensor(inputs, dtype='float')

if len(inputs.shape) > 2:
inputs = tf.contrib.layers.flatten(inputs)
flatten_shape = inputs.shape[1]
weights = tf.get_variable('weights', shape=[flatten_shape, units], initializer=w_i)
w_noise = tf.get_variable('w_noise', [flatten_shape, units], initializer=w_i, collections=c_names)
if noisy_distribution == 'independent':
weights += tf.multiply(tf.random_normal(shape=w_noise.shape), w_noise)
elif noisy_distribution == 'factorised':
noise_1 = f(tf.random_normal(tf.TensorShape([flatten_shape, 1]), dtype=tf.float32)) # 注意是列向量形式,方便矩阵乘法
noise_2 = f(tf.random_normal(tf.TensorShape([1, units]), dtype=tf.float32))
weights += tf.multiply(noise_1 * noise_2, w_noise)
dense = tf.matmul(inputs, weights)
if bias_shape is not None:
assert bias_shape[0] == units
biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
b_noise = tf.get_variable('b_noise', [1, units], initializer=b_i, collections=c_names)
if noisy_distribution == 'independent':
biases += tf.multiply(tf.random_normal(shape=b_noise.shape), b_noise)
elif noisy_distribution == 'factorised':
biases += tf.multiply(noise_2, b_noise)
return activation(dense + biases) if activation is not None else dense + biases
return activation(dense) if activation is not None else dense






Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 48c04b3

Please sign in to comment.