main.py
import random
import gym
import turtle_robot_gym
import numpy as np
import matplotlib.pyplot as plt
import pickle
# create the turtle environment
# 3x3 maze
'''setup = { 'width': 3,
             'height': 3,
             'walls': [(1,1),(0,2)],
             'start': (0,0),
             'goal': (1,2),
             'theta': 0
}'''
# 4x4 maze
'''setup = { 'width': 4,
             'height': 4,
             'walls': [(1,1),(2,0),(2,1),(3,1),(3,3)],
             'start': (0,0),
             'goal': (2,3),
             'theta': 0
}'''
# 5x5 maze
'''setup = { 'width': 5,
             'height': 5,
             'walls': [(1,1),(3,0),(2,2),(2,3),(3,1),(4,2)],
             'start': (0,0),
             'goal': (3,2),
             'theta': 0
}'''
# 6x6 maze
setup = { 'width': 6,
          'height': 6,
          'walls': [(1,1),(0,5),(1,2),(1,3),(3,3),(2,4),(2,5),(5,4)],
          'start': (0,0),
          'goal': (5,5),
          'theta': 0
}
env = gym.make('TurtleRobotEnv-v1_2', **setup)

def choose_action(epsilon, state):
    # epsilon-greedy policy: explore with probability epsilon,
    # otherwise pick randomly among the highest-valued actions
    if np.random.random() <= epsilon:
        return random.randint(0, 2)
    else:
        return np.random.choice((np.argwhere(Q[state, :] == np.amax(Q[state, :]))).flatten())

for i in range(1):
    epsilon = 1.0
    Q = np.zeros([400, 3])   # Q-table: rows indexed via visited_states, one column per action
    lr = 0.15                # learning rate
    y = 0.99                 # discount factor
    eps = 20000              # number of training episodes
    visited_states = []      # maps each observed state string to a row index of Q
    list_acciones = []       # actions taken per episode (learning curve)
    print(i)
    for i in range(eps):
        # initialize the environment
        s = env.reset()
        s = list(map(str, s))
        OldStrState = ''.join(s)
        if OldStrState not in visited_states: visited_states.append(OldStrState)
        OldState = visited_states.index(OldStrState)
        done = False
        n_acciones = 0
        while not done:
            # choose an action with the epsilon-greedy policy
            #action = random.randint(0, 2)
            action = choose_action(epsilon, OldState)
            # take the action and get the information from the environment
            new_state, reward, done, info = env.step(action)
            new_state = list(map(str, new_state))
            StrState = ''.join(new_state)
            if StrState not in visited_states: visited_states.append(StrState)
            NewState = visited_states.index(StrState)
            # update the Q-table
            Q[OldState, action] = Q[OldState, action] + lr*(reward + y*np.max(Q[NewState, :]) - Q[OldState, action])
            OldState = NewState
            n_acciones += 1
            # show the current position and reward
            #env.render(action=action, reward=reward)
        list_acciones.append(n_acciones)
        if done:
            # decay exploration after each episode, with a floor of 0.05
            epsilon = max(epsilon*0.99, 0.05)
            #epsilon = epsilon*0.9999
    #f = open('data/AccionesQLeaning.txt', 'a')
    #f.write(str(list_acciones))
    #f.write('\n')
    #f.close()
    '''x = [i for i in range(len(list_acciones))]
    plt.plot(x, list_acciones)
    plt.grid()
    plt.show()'''

print(Q[:len(visited_states), :])
print(visited_states)
data = [Q[:len(visited_states), :], visited_states]
with open('models/Qlear6x6.pkl', 'wb') as f:
    pickle.dump(data, f)

# run the greedy policy using the Q-table obtained after all episodes
s = env.reset()
s = list(map(str, s))
OldStrState = ''.join(s)
if OldStrState not in visited_states: visited_states.append(OldStrState)
OldState = visited_states.index(OldStrState)
done = False
contador = 1
while not done:
    # always take (one of) the highest-valued actions for the current state
    action = np.random.choice((np.argwhere(Q[OldState, :] == np.amax(Q[OldState, :]))).flatten())
    new_state, reward, done, info = env.step(action)
    new_state = list(map(str, new_state))
    StrState = ''.join(new_state)
    #if StrState not in visited_states: visited_states.append(StrState)
    NewState = visited_states.index(StrState)
    OldState = NewState
    env.render(action=action, reward=reward)
    contador += 1
print(contador)
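
# The block below is a minimal sketch (not part of the original script) of how the
# pickled Q-table could be reloaded in a later run to replay the greedy policy
# without retraining. It assumes the 'models/Qlear6x6.pkl' file written above and
# the same environment setup; like the other optional blocks in this file, it is
# left commented out.
'''with open('models/Qlear6x6.pkl', 'rb') as f:
    Q_loaded, states_loaded = pickle.load(f)
s = env.reset()
state = states_loaded.index(''.join(map(str, s)))
done = False
while not done:
    # greedy action from the reloaded table
    action = int(np.argmax(Q_loaded[state, :]))
    new_state, reward, done, info = env.step(action)
    state = states_loaded.index(''.join(map(str, new_state)))
    env.render(action=action, reward=reward)'''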