-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
59 lines (51 loc) · 1.47 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from DQN import DQN
from Cube import *
from Env import Env
from Player import *
# Set up the environment and the DQN agent that will be trained on it.
# NOTE(review): the arguments 6 and 3 are presumably the state width and the
# number of actions -- confirm against DQN's constructor.
env = Env()
agent = DQN(6, 3)

# One entry per episode: the player's score when the episode ended.
r_avg_list = []

for episode in range(EPISODES):
    episode_reward = 0
    # Observations are handed to the agent as a (1, 6) row vector.
    observation = np.reshape(env.reset(), [1, 6])

    while True:
        chosen_action = agent.act(observation)
        raw_next, step_reward, done = env.step(chosen_action)
        following = np.reshape(raw_next, [1, 6])
        # Record the transition so the agent can replay it later.
        agent.remember(observation, chosen_action, step_reward, following, done)
        observation = following
        episode_reward += step_reward
        env.render()
        if done:
            break

    print("episode: {}/{}, score: {}, reward: {}".format(
        episode, EPISODES, env.player.score, episode_reward))
    r_avg_list.append(env.player.score)
    # NOTE(review): replay is assumed to run once per finished episode rather
    # than once per step -- confirm against the intended training schedule.
    agent.replay(MINI_BATCH)
# Uncomment the block below to play manually with the arrow keys.
# done = False
# env.reset()
# while not done:
#     screen.fill((0, 0, 0))
#     clock.tick(FPS)
#     pygame.event.get()
#     keys = pygame.key.get_pressed()
#     if keys[pygame.K_LEFT]:
#         s, r, done = env.step(2)
#         print(s)
#     elif keys[pygame.K_RIGHT]:
#         s, r, done = env.step(1)
#         print(s)
#     elif keys[pygame.K_UP]:
#         s, r, done = env.step(0)
#         print(s)
#     else:
#         s, r, done = env.step(0)
#         print(s)
#
#     env.sprites().draw(screen)
#     pygame.display.flip()
#     pygame.time.delay(200)
# Persist the trained model before plotting: plt.show() blocks until the plot
# window is closed, so an error or a killed process at that stage would
# otherwise discard the training result.
agent.save_model()

import matplotlib.pyplot as plt

try:
    # Plot the values collected in r_avg_list during training.
    plt.plot(r_avg_list)
    plt.show()
finally:
    # Shut pygame down even if plotting fails (e.g. no display available).
    pygame.quit()