-
Notifications
You must be signed in to change notification settings - Fork 0
/
DDPG_test_models.py
94 lines (79 loc) · 2.96 KB
/
DDPG_test_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import math
import random
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import pprint
import highway_env
from DDPG_net import *
from highway_env.vehicle.behavior import IDMVehicle
import json
from collections import defaultdict
# Use the GPU when one is available.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# "lvxinfei-v0" is a custom environment id, presumably registered by the
# (modified) highway_env package imported above — confirm.
env = gym.make("lvxinfei-v0")
# NOTE(review): redundant — every episode below calls env.reset() again; kept so
# the environment's RNG consumption stays the same as before.
env.reset()

# Two pre-trained checkpoints: one for straight road segments, one for curved
# ones. They are full pickled objects (the rollout uses `.policy_net` on them),
# so the classes pulled in by `from DDPG_net import *` must be importable.
# When CUDA is unavailable, remap GPU-saved tensors onto the CPU; without this
# torch.load fails on CPU-only machines for checkpoints saved on a GPU. With
# CUDA available, map_location=None keeps torch.load's default placement.
# SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects full
# pickled models; pass weights_only=False there if loading fails — confirm.
_map_location = None if use_cuda else device
ddpg = torch.load('./weights_test/ddpg_net1.pth', map_location=_map_location)  # straight-road model
ddpg1 = torch.load('./weights_test/ddpg_net2-2.pth', map_location=_map_location)  # curved-road model

max_steps = 1  # number of evaluation episodes (despite the name)
# NOTE(review): the following four globals appear unused in this script.
rewards = []
batch_size = 32
output1 = []
output2 = []
# Per-step log of selected `info` fields, keyed by field name; written to JSON
# after the rollout.
info_out = defaultdict(list)
# Fields of the `info` dict returned by env.step(), per the original notes
# (values come from the ego vehicle `self.vehicle` inside the environment):
#   "speed":           self.vehicle.speed
#   "crashed":         self.vehicle.crashed
#   "vehicle heading": self.vehicle.heading — heading w.r.t. the world frame
#                      (original note: "in units of pi")
#   "action":          action
#   "x", "y":          self.vehicle.position[0], self.vehicle.position[1]
#   "vx", "vy":        self.vehicle.velocity[0], self.vehicle.velocity[1]
#                      NOTE(review): the original note calls vx "speed * sin_h";
#                      since cos_h is direction[0], vx is more likely
#                      speed * cos_h — confirm.
#   "sin_h", "cos_h":  self.vehicle.direction[1], self.vehicle.direction[0]
#   "road heading":    read below but absent from the original list; presumably
#                      added by the custom environment — confirm.
# Fields copied into info_out on every step, in this order (the order also
# fixes the key order of the JSON output).
_LOGGED_INFO_KEYS = (
    "speed",
    "x",
    "y",
    "vx",
    "vy",
    "sin_h",
    "cos_h",
    "vehicle heading",
    "road heading",
)

try:
    # Evaluation only: gradients are never needed.
    with torch.no_grad():
        for step in range(max_steps):
            print("================第{}回合======================================".format(step+1))
            state = env.reset()
            state = torch.flatten(torch.tensor(state))
            done = False
            t = 0  # step index within the current episode
            while not done:
                # Choose the agent from the road heading logged on the previous
                # two steps: a change means a curved segment. The first two
                # steps of an episode have no such history yet, so they use the
                # straight-road model. (info_out is only indexed when t >= 2 so
                # the defaultdict does not create the key early.)
                if t < 2:
                    agent, label = ddpg, "直线形"  # straight
                else:
                    road_heading = info_out['road heading']
                    # NOTE(review): exact float comparison; assumes straight
                    # lanes report a bit-identical heading every step — confirm.
                    if road_heading[-1] - road_heading[-2] != 0:
                        agent, label = ddpg1, "曲线形"  # curved
                    else:
                        agent, label = ddpg, "直线形"  # straight
                action = agent.policy_net.get_action(state)
                print(label)

                next_state, reward, done, info = env.step(action)
                for key in _LOGGED_INFO_KEYS:
                    info_out[key].append(info[key])

                next_state = torch.flatten(torch.tensor(next_state))
                state = next_state
                env.render()
                t += 1
finally:
    # Release the environment (and its render window) even if a step raises.
    env.close()
# Save the logged trajectory as JSON. Create the output directory first:
# otherwise open() raises FileNotFoundError after the whole rollout has run,
# and the collected data is lost.
import os  # local import keeps this script-level fix self-contained

os.makedirs("./JSON", exist_ok=True)
with open("./JSON/v00.json", 'w', encoding='UTF-8') as f:
    f.write(json.dumps(info_out))