-
Notifications
You must be signed in to change notification settings - Fork 1
/
Priority_Replay.py
119 lines (102 loc) · 4.52 KB
/
Priority_Replay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 25 00:43:00 2018
@author: wuyuankai
"""
import numpy as np
class SumTree(object):
    """Binary sum tree that stores transitions together with their priorities.

    This SumTree code is a modified version; the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py

    Priorities live in the leaves of ``tree`` and every internal node holds
    the sum of its two children, so ``tree[0]`` is the total priority.
    The transitions themselves live in ``data``, a ring buffer whose slot
    ``i`` corresponds to leaf ``i + capacity - 1``.
    """

    # Class-level default kept for backward compatibility; every instance
    # gets its own pointer in __init__.
    data_pointer = 0

    def __init__(self, capacity):
        """Create an empty tree able to hold ``capacity`` transitions.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1, got %r" % (capacity,))
        self.capacity = capacity  # for all priority values
        self.data_pointer = 0  # next ring-buffer slot to (over)write
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        """Store ``data`` with priority ``p``, overwriting the oldest slot once full."""
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data frame
        self.update(tree_idx, p)  # update tree frame
        # Advance the ring-buffer pointer, wrapping to 0 when it reaches capacity.
        self.data_pointer = (self.data_pointer + 1) % self.capacity

    def update(self, tree_idx, p):
        """Set the priority of leaf ``tree_idx`` to ``p`` and refresh every ancestor sum."""
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # Propagate the change up to the root; the loop is faster than the
        # recursive version in the reference code.
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        r"""Find the leaf whose cumulative-priority interval contains ``v``.

        Tree structure and array storage::

            Tree index:
                     0         -> storing priority sum
                    / \
                  1     2
                 / \   / \
                3   4 5   6    -> storing priority for transitions

            Array type for storing:
                [0, 1, 2, 3, 4, 5, 6]

        Starting at the root, go left when ``v`` fits within the left
        child's sum; otherwise subtract that sum and go right. A right
        subtree whose sum is zero contains only zero-priority leaves (for
        example slots that have not been filled yet), so the search never
        enters it. This keeps a ``v`` that overshoots the total (e.g.
        through floating-point rounding) from returning an empty slot.

        Returns:
            (leaf_idx, priority, data): the tree index of the leaf, its
            priority, and the transition stored in the matching data slot.
        """
        parent_idx = 0
        n_nodes = len(self.tree)
        while True:  # the while loop is faster than the reference recursion
            cl_idx = 2 * parent_idx + 1  # this node's left and right kids
            cr_idx = cl_idx + 1
            if cl_idx >= n_nodes:  # reached a leaf, end search
                leaf_idx = parent_idx
                break
            if v <= self.tree[cl_idx] or self.tree[cr_idx] <= 0:
                parent_idx = cl_idx
            else:
                v -= self.tree[cl_idx]
                parent_idx = cr_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        """Total priority of all stored transitions (the root node)."""
        return self.tree[0]  # the root
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """Prioritized experience replay buffer backed by a ``SumTree``.

    This Memory code is a modified version; the original code is from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py

    Transitions are assumed to be flat NumPy arrays of equal length:
    ``sample`` sizes its output from ``.size`` of the first stored one.
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        """Add ``transition`` with the current maximum leaf priority.

        ``abs_err_upper`` is used when every leaf is still zero, so each new
        transition gets at least one chance of being sampled.
        """
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # set the max p for new p

    def sample(self, n):
        """Draw ``n`` transitions by stratified, priority-proportional sampling.

        The total priority is split into ``n`` equal segments and one value is
        drawn uniformly from each. Each call also anneals ``beta`` toward 1 by
        ``beta_increment_per_sampling``.

        Returns:
            (b_idx, b_memory, ISWeights): int32 tree indices of shape (n,),
            the sampled transitions of shape (n, transition_size), and the
            importance-sampling weights of shape (n, 1). Weights are scaled so
            that the least probable stored transition gets weight 1.

        Raises:
            ValueError: if the memory holds no transition with positive
            priority.
        """
        total_p = self.tree.total_p
        if total_p <= 0:
            raise ValueError("cannot sample from an empty Memory")
        b_idx = np.empty((n,), dtype=np.int32)
        b_memory = np.empty((n, self.tree.data[0].size))
        ISWeights = np.empty((n, 1))
        pri_seg = total_p / n  # priority segment
        self.beta = min(1., self.beta + self.beta_increment_per_sampling)  # max = 1
        # Only filled (positive-priority) leaves count toward the minimum.
        # Unfilled leaves are 0 and would make every weight collapse to 0.
        leaves = self.tree.tree[-self.tree.capacity:]
        min_prob = np.min(leaves[leaves > 0]) / total_p
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        """Reset the priorities of leaves ``tree_idx`` from their TD errors.

        The caller's ``abs_errors`` is not modified. Errors are made absolute,
        offset by ``epsilon`` to avoid zero priority, clipped at
        ``abs_err_upper`` and raised to the power ``alpha``.
        """
        abs_errors = np.abs(np.asarray(abs_errors, dtype=float)) + self.epsilon
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)