-
Notifications
You must be signed in to change notification settings - Fork 6
/
multiagent_traffic_simulator.py
205 lines (174 loc) · 7.28 KB
/
multiagent_traffic_simulator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import numpy as np
from dataclasses import replace
from example_adapter import get_observation_adapter
from utils import get_vehicle_start_at_time
from smarts.core.smarts import SMARTS
from smarts.core.agent import AgentSpec
from smarts.core.agent_interface import AgentInterface
from smarts.core.controllers import ActionSpaceType
from smarts.core.scenario import Scenario
from smarts.core.traffic_history_provider import TrafficHistoryProvider
from smarts.env.wrappers.parallel_env import ParallelEnv
def get_action_adapter():
    """Build the action adapter used by ``MATrafficSim``.

    Returns:
        A function that converts a two-element model action (any indexable
        sequence with ``len``) into a ``(first, second)`` tuple.
    """

    def action_adapter(model_action):
        """Return ``model_action`` as a 2-tuple.

        Raises:
            ValueError: if ``model_action`` does not have exactly two elements.
        """
        # Explicit check rather than ``assert`` so validation still runs under
        # ``python -O``.
        if len(model_action) != 2:
            raise ValueError(
                f"expected a 2-element action, got {len(model_action)} elements"
            )
        return (model_action[0], model_action[1])

    return action_adapter
class MATrafficSim:
    """Multi-agent SMARTS environment built on recorded traffic histories.

    On each ``reset`` the next ``agent_number`` recorded vehicles (ordered by
    mission start time) are assigned to ``agent_0`` ... ``agent_{n-1}``. Their
    recorded missions, re-based to the replay start time, become the agents'
    ego missions. The replay start time is set on SMARTS's
    ``TrafficHistoryProvider``.
    """

    def __init__(self, scenarios, agent_number, obs_stacked_size=1):
        """Create the environment.

        Args:
            scenarios: scenario paths passed to ``Scenario.scenario_variations``;
                the first variation is loaded on construction.
            agent_number: number of ego agents; each is assigned one recorded
                vehicle's mission per episode.
            obs_stacked_size: forwarded to ``get_observation_adapter``.
        """
        self.scenarios_iterator = Scenario.scenario_variations(scenarios, [])
        self._init_scenario()
        self.obs_stacked_size = obs_stacked_size
        self.n_agents = agent_number
        self.agentid_to_vehid = {}
        self.agent_ids = [f"agent_{i}" for i in range(self.n_agents)]
        self.agent_spec = AgentSpec(
            interface=AgentInterface(
                max_episode_steps=None,
                waypoints=False,
                neighborhood_vehicles=True,
                ogm=False,
                rgb=False,
                lidar=False,
                action=ActionSpaceType.Imitation,
            ),
            action_adapter=get_action_adapter(),
            observation_adapter=get_observation_adapter(obs_stacked_size),
        )
        self.smarts = SMARTS(
            agent_interfaces={},
            traffic_sim=None,
            envision=None,
        )

    def seed(self, seed):
        """Seed NumPy's global RNG, which drives vehicle and start-time sampling."""
        np.random.seed(seed)

    def step(self, action):
        """Advance the simulation by one step.

        Args:
            action: dict mapping agent id to a raw model action. Entries whose
                id is in ``self.agent_ids`` are passed through the agent
                spec's action adapter; any other entries are forwarded
                unchanged. The caller's dict is not modified.

        Returns:
            ``(observations, rewards, dones, info)``. Observations are passed
            through the observation adapter. ``dones["__all__"]`` is True when
            every reported agent is done. ``info`` is always an empty dict.
        """
        # Adapt a copy so the caller's action dict is left untouched.
        adapted = dict(action)
        for agent_id in self.agent_ids:
            if agent_id in adapted:
                adapted[agent_id] = self.agent_spec.action_adapter(adapted[agent_id])

        observations, rewards, dones, _ = self.smarts.step(adapted)
        observations = {
            agent_id: self.agent_spec.observation_adapter(obs)
            for agent_id, obs in observations.items()
        }
        dones["__all__"] = all(dones.values())
        info = {}
        return observations, rewards, dones, info

    def reset(self, internal_replacement=False, min_successor_time=5.0):
        """Start a new episode with the next window of ``n_agents`` vehicles.

        Vehicles are taken in start-time order from ``vehicle_ids`` starting at
        ``vehicle_itr``. The window wraps back to the beginning once it would
        run past the end of the list.

        Args:
            internal_replacement: if True, the replay starts at a random time
                (0.1 s grid) after the first selected vehicle's start time,
                chosen so that roughly ``min_successor_time`` seconds remain
                before the first selected vehicle exits the recording.
                Otherwise it starts at the first selected vehicle's start time.
            min_successor_time: seconds kept before that first exit when
                ``internal_replacement`` is True.

        Returns:
            Dict mapping agent id to its adapted initial observation.

        Raises:
            ValueError: if the scenario has fewer recorded vehicles than agents.
            RuntimeError: if SMARTS has no ``TrafficHistoryProvider``.
        """
        n_vehicles = len(self.vehicle_ids)
        if n_vehicles < self.n_agents:
            raise ValueError(
                f"scenario has {n_vehicles} recorded vehicles but "
                f"{self.n_agents} agents were requested"
            )
        # Wrap only when the window [itr, itr + n_agents) would overrun the
        # list. The former ``>= len - 1`` test wrapped early, so the last
        # vehicles were never used as egos.
        if self.vehicle_itr + self.n_agents > n_vehicles:
            self.vehicle_itr = 0
        self.vehicle_id = self.vehicle_ids[
            self.vehicle_itr : self.vehicle_itr + self.n_agents
        ]

        traffic_history_provider = self.smarts.get_provider_by_type(
            TrafficHistoryProvider
        )
        if not traffic_history_provider:
            raise RuntimeError("SMARTS has no TrafficHistoryProvider")

        # agent_ids are agent_0..agent_{n-1}, in the same order as the window.
        self.agentid_to_vehid.update(zip(self.agent_ids, self.vehicle_id))
        # vehicle_ids is sorted by start time, so vehicle_id[0] starts first.
        history_start_time = self.vehicle_missions[self.vehicle_id[0]].start_time
        agent_interfaces = {a_id: self.agent_spec.interface for a_id in self.agent_ids}

        if internal_replacement:
            # NOTE(zbzhu): we use the first-end vehicle to compute the end time
            # to make sure all vehicles can exist on the map.
            history_end_time = min(
                self.scenario.traffic_history.vehicle_final_exit_time(v_id)
                for v_id in self.vehicle_id
            )
            alive_time = history_end_time - history_start_time
            traffic_history_provider.start_time = (
                history_start_time
                + self._random_start_offset(alive_time, min_successor_time)
            )
        else:
            traffic_history_provider.start_time = history_start_time

        replay_start = traffic_history_provider.start_time
        ego_missions = {
            agent_id: self._ego_mission(self.agentid_to_vehid[agent_id], replay_start)
            for agent_id in self.agent_ids
        }
        self.scenario.set_ego_missions(ego_missions)
        self.smarts.switch_ego_agents(agent_interfaces)
        observations = self.smarts.reset(self.scenario)
        observations = {
            agent_id: self.agent_spec.observation_adapter(obs)
            for agent_id, obs in observations.items()
        }
        self.vehicle_itr += self.n_agents
        return observations

    @staticmethod
    def _random_start_offset(alive_time, min_successor_time):
        """Draw a random replay-start offset in seconds on a 0.1 s grid.

        The offset is uniform over ``{0, 0.1, ..., (k - 1) / 10}`` with
        ``k = round(alive_time * 10) - round(min_successor_time * 10)``.
        When ``k <= 0`` the offset is 0. Previously that case called
        ``np.random.choice(0)``, which raises ``ValueError``.
        """
        n_slots = round(alive_time * 10) - round(min_successor_time * 10)
        return np.random.choice(max(1, n_slots)) / 10

    def _ego_mission(self, vehicle_id, replay_start):
        """Return the recorded mission of ``vehicle_id`` re-based to ``replay_start``.

        ``start_time`` becomes relative to ``replay_start``, clamped at 0.
        ``start`` is obtained from ``get_vehicle_start_at_time`` at the later
        of ``replay_start`` and the mission's own start time, rounded to 0.1 s.
        """
        mission = self.vehicle_missions[vehicle_id]
        return replace(
            mission,
            start_time=max(0, mission.start_time - replay_start),
            start=get_vehicle_start_at_time(
                vehicle_id,
                round(max(replay_start, mission.start_time), 1),
                self.scenario.traffic_history,
            ),
        )

    def _init_scenario(self):
        """Load the next scenario and order its recorded vehicles by start time.

        Populates ``scenario``, ``vehicle_missions``, ``veh_start_times`` and
        ``vehicle_ids`` (sorted by mission start time). Also picks a random
        initial ``vehicle_itr``.
        """
        self.scenario = next(self.scenarios_iterator)
        self.vehicle_missions = self.scenario.discover_missions_of_traffic_histories()
        self.veh_start_times = {
            v_id: mission.start_time for v_id, mission in self.vehicle_missions.items()
        }
        # Sort the ids directly instead of round-tripping them through an
        # int-typed numpy structured array. That required numeric ids and could
        # rewrite e.g. "007" as "7", which is then absent from vehicle_missions.
        # ``sorted`` is stable: vehicles with equal start times keep their
        # discovery order.
        self.vehicle_ids = sorted(
            self.veh_start_times, key=self.veh_start_times.__getitem__
        )
        self.vehicle_itr = np.random.choice(len(self.vehicle_ids))

    def close(self):
        """Destroy the SMARTS instance; safe to call more than once."""
        if self.smarts is not None:
            self.smarts.destroy()
            self.smarts = None
if __name__ == "__main__":
    # Dummy rollout: one environment driven by random actions.
    env = MATrafficSim(["./ngsim"], agent_number=5)
    try:
        obs = env.reset()
        done = {a_id: False for a_id in obs}
        n_steps = 100000
        for step in range(n_steps):
            # Skip agents that already finished; ``step and ...`` avoids
            # touching ``done`` on the very first step.
            act_n = {
                agent_id: np.random.normal(0, 1, size=(2,))
                for agent_id in obs
                if not (step and done[agent_id])
            }
            obs, rew, done, info = env.step(act_n)
            if done["__all__"]:
                print("done")
                obs = env.reset()
                done = {a_id: False for a_id in obs}
            print(rew)
        print("finished")
    finally:
        # Release the simulator even if the rollout raises.
        env.close()

    # Parallel rollout: several environments stepped together via ParallelEnv.
    env_num = 2

    def env_creator():
        """Construct one environment for ParallelEnv."""
        return MATrafficSim(["./ngsim"], agent_number=5)

    vector_env = ParallelEnv([env_creator] * env_num, auto_reset=True)
    try:
        vec_obs = vector_env.reset()
        vec_act = [
            {a_id: np.random.normal(0, 1, size=(2,)) for a_id in env_obs}
            for env_obs in vec_obs
        ]
        vec_next_obs, vec_rew, vec_done, vec_info = vector_env.step(vec_act)
        print("parallel finished!")
    finally:
        vector_env.close()