forked from princewen/tensorflow_practice
shixiaowen03 committed Dec 21, 2018
Commit c659ae9 (1 parent: 62bca23)
Showing 8 changed files with 1,458 additions and 168 deletions.
@@ -0,0 +1,120 @@
import tensorflow as tf
import numpy as np
import random
from collections import deque
from Config import Categorical_DQN_Config
from utils import conv, dense
import math


class Categorical_DQN():
    def __init__(self, env, config):
        self.sess = tf.InteractiveSession()
        self.config = config
        # support of the value distribution: atoms evenly spaced in [v_min, v_max]
        self.v_max = self.config.v_max
        self.v_min = self.config.v_min
        self.atoms = self.config.atoms

        self.time_step = 0
        self.epsilon = self.config.INITIAL_EPSILON
        self.state_shape = env.observation_space.shape
        self.action_dim = env.action_space.n

        target_state_shape = [1]
        target_state_shape.extend(self.state_shape)

        self.state_input = tf.placeholder(tf.float32, target_state_shape)
        self.action_input = tf.placeholder(tf.int32, [1, 1])

        # projected target distribution m, computed in train()
        self.m_input = tf.placeholder(tf.float32, [self.atoms])

        self.delta_z = (self.v_max - self.v_min) / (self.atoms - 1)
        self.z = [self.v_min + i * self.delta_z for i in range(self.atoms)]

        self.build_cate_dqn_net()

        self.saver = tf.train.Saver()

        self.sess.run(tf.global_variables_initializer())

        self.save_model()
        self.restore_model()

    def build_layers(self, state, action, c_names, units_1, units_2, w_i, b_i, reg=None):
        with tf.variable_scope('conv1'):
            conv1 = conv(state, [5, 5, 3, 6], [6], [1, 2, 2, 1], w_i, b_i)
        with tf.variable_scope('conv2'):
            conv2 = conv(conv1, [3, 3, 6, 12], [12], [1, 2, 2, 1], w_i, b_i)
        with tf.variable_scope('flatten'):
            flatten = tf.contrib.layers.flatten(conv2)

        with tf.variable_scope('dense1'):
            dense1 = dense(flatten, units_1, [units_1], w_i, b_i)
        with tf.variable_scope('dense2'):
            dense2 = dense(dense1, units_2, [units_2], w_i, b_i)
        with tf.variable_scope('concat'):
            concatenated = tf.concat([dense2, tf.cast(action, tf.float32)], 1)
        with tf.variable_scope('dense3'):
            dense3 = dense(concatenated, self.atoms, [self.atoms], w_i, b_i)  # output: one value per atom
        return dense3

    def build_cate_dqn_net(self):
        with tf.variable_scope('target_net'):
            c_names = ['target_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
            w_i = tf.random_uniform_initializer(-0.1, 0.1)
            b_i = tf.constant_initializer(0.1)
            self.z_target = self.build_layers(self.state_input, self.action_input, c_names, 24, 24, w_i, b_i)

        with tf.variable_scope('eval_net'):
            c_names = ['eval_net_params', tf.GraphKeys.GLOBAL_VARIABLES]
            w_i = tf.random_uniform_initializer(-0.1, 0.1)
            b_i = tf.constant_initializer(0.1)
            self.z_eval = self.build_layers(self.state_input, self.action_input, c_names, 24, 24, w_i, b_i)

        # Q(s, a) = sum_i z_i * p_i(s, a)
        self.q_eval = tf.reduce_sum(self.z_eval * self.z)
        self.q_target = tf.reduce_sum(self.z_target * self.z)

        # cross entropy between the projected target distribution m and the predicted distribution
        self.cross_entropy_loss = -tf.reduce_sum(self.m_input * tf.log(self.z_eval))

        self.optimizer = tf.train.AdamOptimizer(self.config.LEARNING_RATE).minimize(self.cross_entropy_loss)

    def train(self, s, r, action, s_, gamma):
        # greedy next action under the target network
        list_q_ = [self.sess.run(self.q_target, feed_dict={self.state_input: [s_], self.action_input: [[a]]}) for a in range(self.action_dim)]
        a_ = np.argmax(list_q_)
        m = np.zeros(self.atoms)
        p = self.sess.run(self.z_target, feed_dict={self.state_input: [s_], self.action_input: [[a_]]})[0]
        for j in range(self.atoms):
            Tz = min(self.v_max, max(self.v_min, r + gamma * self.z[j]))
            bj = (Tz - self.v_min) / self.delta_z  # which bin the shifted atom falls into
            l, u = math.floor(bj), math.ceil(bj)   # lower and upper neighbouring atoms

            pj = p[j]

            # distribute the probability mass to the two neighbouring atoms
            m[int(l)] += pj * (u - bj)
            m[int(u)] += pj * (bj - l)

        self.sess.run(self.optimizer, feed_dict={self.state_input: [s], self.action_input: [action], self.m_input: m})

    def save_model(self):
        print("Model saved in : ", self.saver.save(self.sess, self.config.MODEL_PATH))

    def restore_model(self):
        self.saver.restore(self.sess, self.config.MODEL_PATH)
        print("Model restored.")

    def greedy_action(self, s):
        # epsilon-greedy: decay epsilon towards FINAL_EPSILON
        self.epsilon = max(self.config.FINAL_EPSILON, self.epsilon * self.config.EPSILON_DECAY)
        if random.random() <= self.epsilon:
            return random.randint(0, self.action_dim - 1)
        return np.argmax(
            [self.sess.run(self.q_target, feed_dict={self.state_input: [s], self.action_input: [[a]]}) for a in range(self.action_dim)])
@@ -0,0 +1,22 @@
class Categorical_DQN_Config():
    v_min = 0
    v_max = 1000
    atoms = 51

    # ENV_NAME = "CartPole-v1"
    ENV_NAME = 'Breakout-v0'  # 0: hold 1: throw the ball 2: move right 3: move left
    # ENV_NAME = "Freeway-v0"
    GAMMA = 0.99  # discount factor for target Q
    START_TRAINING = 1000  # start training after this many steps
    BATCH_SIZE = 64  # size of minibatch
    UPDATE_TARGET_NET = 400  # update target network params every 400 steps
    LEARNING_RATE = 0.01
    MODEL_PATH = './model/C51DQN_model'

    INITIAL_EPSILON = 0.9  # starting value of epsilon
    FINAL_EPSILON = 0.05  # final value of epsilon
    EPSILON_DECAY = 0.9999

    replay_buffer_size = 2000
    iteration = 5
    episode = 300  # 300 games per iteration
@@ -0,0 +1,65 @@
import matplotlib.pyplot as plt
import tensorflow as tf
import gym
import numpy as np
import pickle
from Config import Categorical_DQN_Config
from Categorical_DQN import Categorical_DQN


def map_scores(dqfd_scores=None, ddqn_scores=None, xlabel=None, ylabel=None):
    if dqfd_scores is not None:
        plt.plot(dqfd_scores, 'r')
    if ddqn_scores is not None:
        plt.plot(ddqn_scores, 'b')
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.show()

def BreakOut_CDQN(index, env):
    with tf.variable_scope('DQfD_' + str(index)):
        agent = Categorical_DQN(env, Categorical_DQN_Config())
    scores = []
    for e in range(Categorical_DQN_Config.episode):
        done = False
        score = 0  # sum of rewards in one episode
        state = env.reset()
        # while done is False:
        last_lives = 5
        throw = True
        items_buffer = []
        while not done:
            env.render()
            action = 1 if throw else agent.greedy_action(state)
            next_state, real_reward, done, info = env.step(action)
            lives = info['ale.lives']
            train_reward = 1 if throw else -1 if lives < last_lives else real_reward
            score += real_reward
            throw = lives < last_lives
            last_lives = lives
            # agent.train(state, train_reward, [action], next_state, 0.1)
            items_buffer.append([state, [action], next_state, 0.1])
            state = next_state
            if train_reward != 0:  # train when the ball is missed, a point is scored, or the ball was just thrown
                for item in items_buffer:
                    agent.train(item[0], -1 if throw else train_reward, item[1], item[2], item[3])
                items_buffer = []
        scores.append(score)
        agent.save_model()
        # if np.mean(scores[-min(10, len(scores)):]) > 495:
        #     break
    return scores


if __name__ == '__main__':
    env = gym.make("Breakout-v0")
    CDQN_sum_scores = np.zeros(Categorical_DQN_Config.episode)
    for i in range(Categorical_DQN_Config.iteration):
        scores = BreakOut_CDQN(i, env)
        CDQN_sum_scores = np.array([a + b for a, b in zip(scores, CDQN_sum_scores)])  # accumulate scores across iterations
    C51DQN_mean_scores = CDQN_sum_scores / Categorical_DQN_Config.iteration
    with open('/Users/mahailong/C51DQN/C51DQN_mean_scores.p', 'wb') as f:
        pickle.dump(C51DQN_mean_scores, f, protocol=2)
@@ -0,0 +1,2 @@
Paper: A Distributional Perspective on Reinforcement Learning
Link: https://arxiv.org/abs/1707.06887
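
The projection loop in Categorical_DQN.train follows the categorical projection from this paper: each shifted atom r + gamma * z[j] is clipped to [v_min, v_max] and its probability mass is split between the two nearest atoms of the fixed support. Below is a minimal NumPy sketch of that step for illustration only; the function name project_distribution and the uniform example distribution are assumptions, not part of this repository.

import numpy as np

def project_distribution(p, r, gamma, z, v_min, v_max, delta_z):
    # Distribute the mass of each shifted atom r + gamma * z[j] onto the fixed support z.
    atoms = len(z)
    m = np.zeros(atoms)
    for j in range(atoms):
        Tz = min(v_max, max(v_min, r + gamma * z[j]))  # clip the shifted atom to the support
        bj = (Tz - v_min) / delta_z                    # fractional index on the fixed support
        l, u = int(np.floor(bj)), int(np.ceil(bj))
        if l == u:                                     # shifted atom lands exactly on a grid point
            m[l] += p[j]
        else:
            m[l] += p[j] * (u - bj)                    # split mass between the two neighbours
            m[u] += p[j] * (bj - l)
    return m

# Illustrative usage with the values from Categorical_DQN_Config
v_min, v_max, atoms = 0.0, 1000.0, 51
delta_z = (v_max - v_min) / (atoms - 1)
z = [v_min + i * delta_z for i in range(atoms)]
p = np.full(atoms, 1.0 / atoms)                        # a uniform next-state distribution, for demonstration
m = project_distribution(p, r=1.0, gamma=0.99, z=z, v_min=v_min, v_max=v_max, delta_z=delta_z)
print(m.sum())                                         # total mass is preserved (approximately 1.0)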
@@ -0,0 +1,35 @@
import tensorflow as tf
from tensorflow.python.framework import ops


def conv(inputs, kernel_shape, bias_shape, strides, w_i, b_i=None, activation=tf.nn.relu):
    # equivalent with tf.layers:
    # relu1 = tf.layers.conv2d(input_imgs, filters=24, kernel_size=[5, 5], strides=[2, 2],
    #                          padding='SAME', activation=tf.nn.relu,
    #                          kernel_initializer=w_i, bias_initializer=b_i)
    weights = tf.get_variable('weights', shape=kernel_shape, initializer=w_i)
    conv = tf.nn.conv2d(inputs, weights, strides=strides, padding='SAME')
    if bias_shape is not None:
        biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
        return activation(conv + biases) if activation is not None else conv + biases
    return activation(conv) if activation is not None else conv


def dense(inputs, units, bias_shape, w_i, b_i=None, activation=tf.nn.relu):
    # equivalent with tf.layers (note: flatten first):
    # dense1 = tf.layers.dense(tf.contrib.layers.flatten(relu5), activation=tf.nn.relu, units=50)
    if not isinstance(inputs, ops.Tensor):
        inputs = ops.convert_to_tensor(inputs, dtype='float')
    # dim_list = inputs.get_shape().as_list()
    # flatten_shape = dim_list[1] if len(dim_list) <= 2 else reduce(lambda x, y: x * y, dim_list[1:])
    # reshaped = tf.reshape(inputs, [dim_list[0], flatten_shape])
    if len(inputs.shape) > 2:
        inputs = tf.contrib.layers.flatten(inputs)
    flatten_shape = inputs.shape[1]
    weights = tf.get_variable('weights', shape=[flatten_shape, units], initializer=w_i)
    dense = tf.matmul(inputs, weights)
    if bias_shape is not None:
        assert bias_shape[0] == units
        biases = tf.get_variable('biases', shape=bias_shape, initializer=b_i)
        return activation(dense + biases) if activation is not None else dense + biases
    return activation(dense) if activation is not None else dense
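
Both helpers create their variables with tf.get_variable, so each call needs its own variable scope, which is how build_layers in Categorical_DQN uses them. A small usage sketch under TensorFlow 1.x follows; the placeholder shape and scope names are illustrative assumptions, not taken from the repository.

import tensorflow as tf
from utils import conv, dense

# one 210x160 RGB Atari frame as input (illustrative shape)
state = tf.placeholder(tf.float32, [1, 210, 160, 3])
w_i = tf.random_uniform_initializer(-0.1, 0.1)
b_i = tf.constant_initializer(0.1)

with tf.variable_scope('conv1'):
    h = conv(state, [5, 5, 3, 6], [6], [1, 2, 2, 1], w_i, b_i)  # 5x5 kernels, 3 -> 6 channels, stride 2
with tf.variable_scope('fc1'):
    out = dense(h, 24, [24], w_i, b_i)                          # flattens h, then a 24-unit fully connected layer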
@@ -0,0 +1,6 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 2
}