Skip to content

Commit

Permalink
feat: improve performance of shortest path algorithm with rustworkx
Browse files Browse the repository at this point in the history
  • Loading branch information
mjlbach committed Nov 14, 2022
1 parent 32242dd commit 04fb3cc
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 37 deletions.
118 changes: 81 additions & 37 deletions igibson/scenes/indoor_scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from abc import ABCMeta

import cv2
import networkx as nx
import numpy as np
import rustworkx as rx
from future.utils import with_metaclass
from PIL import Image

Expand Down Expand Up @@ -73,6 +73,8 @@ def load_trav_map(self, maps_path):
for floor in range(len(self.floor_heights)):
if self.trav_map_type == "with_obj":
trav_map = np.array(Image.open(os.path.join(maps_path, "floor_trav_{}.png".format(floor))))
elif self.trav_map_type == "no_door":
trav_map = np.array(Image.open(os.path.join(maps_path, "floor_trav_no_door_{}.png".format(floor))))
else:
trav_map = np.array(Image.open(os.path.join(maps_path, "floor_trav_no_obj_{}.png".format(floor))))

Expand Down Expand Up @@ -100,47 +102,70 @@ def load_trav_map(self, maps_path):

# We search for the largest connected areas
if self.build_graph:
self.build_trav_graph(maps_path, floor, trav_map)
self.build_trav_graph(trav_map)

self.floor_map.append(trav_map)

# TODO: refactor into C++ for speedup
def build_trav_graph(self, maps_path, floor, trav_map):
def build_trav_graph(self, trav_map):
"""
Build traversibility graph and only take the largest connected component
:param maps_path: String with the path to the folder containing the traversability maps
:param floor: floor number
:param trav_map: traversability map
"""
log.debug("Building traversable graph")
g = nx.Graph()
for i in range(self.trav_map_size):
for j in range(self.trav_map_size):
if trav_map[i, j] == 0:
continue
g.add_node((i, j))
# 8-connected graph
neighbors = [(i - 1, j - 1), (i, j - 1), (i + 1, j - 1), (i - 1, j)]
for n in neighbors:
if 0 <= n[0] < self.trav_map_size and 0 <= n[1] < self.trav_map_size and trav_map[n[0], n[1]] > 0:
g.add_edge(n, (i, j), weight=l2_distance(n, (i, j)))
node_mapping = {}
g = rx.PyGraph(multigraph=False) # type: ignore

x, y = np.where(trav_map != 0)
nodes_to_add = np.stack((x, y)).T
nodes_to_add = list(map(tuple, nodes_to_add))
node_idxs = g.add_nodes_from(nodes_to_add)
node_mapping = {data: idx for data, idx in zip(nodes_to_add, node_idxs)}

edges = set()
from_nodes = []
from_nodes_pos = []
to_nodes = []
to_nodes_pos = []
for node_idx, node in zip(g.node_indexes(), g.nodes()):
i, j = node
neighbors = [(i - 1, j - 1), (i, j - 1), (i + 1, j - 1), (i - 1, j)]
for n in neighbors:
if n in node_mapping and (node_idx, n) not in edges:
from_nodes.append(node_idx)
from_nodes_pos.append(node)
to_nodes.append(node_mapping[n])
to_nodes_pos.append(n)
edges.add((node_idx, n))

distances = np.linalg.norm(np.array(from_nodes_pos) - np.array(to_nodes_pos), axis=1)
edges = [(x, y, z) for x, y, z in zip(from_nodes, to_nodes, distances)]
g.add_edges_from(edges)

# only take the largest connected component
largest_cc = max(nx.connected_components(g), key=len)
g = g.subgraph(largest_cc).copy()

self.floor_graph.append(g)

largest_cc = max(rx.connected_components(g), key=len) # type: ignore
g = g.subgraph(list(largest_cc), preserve_attrs=True).copy()
# update trav_map accordingly
# This overwrites the traversability map loaded before
# It sets everything to zero, then only sets to one the points where we have graph nodes
# Dangerous! if the traversability graph is not computed from the loaded map but from a file, it could overwrite
# it silently.
trav_map[:, :] = 0
for node in g.nodes:
for node in g.nodes():
trav_map[node[0], node[1]] = 255

nodes = g.nodes()
node_idxs = g.node_indexes()
map_to_idx = {data: idx for data, idx in zip(nodes, node_idxs)}
self.floor_graph.append(
{
"graph": g,
"map_to_idx": map_to_idx,
"nodes": np.array(nodes),
"node_idxs": node_idxs,
}
)
self.trav_map = trav_map

def get_random_point(self, floor=None):
"""
Sample a random point on the given floor number. If not given, sample a random floor number.
Expand Down Expand Up @@ -204,19 +229,38 @@ def get_shortest_path(self, floor, source_world, target_world, entire_path=False
source_map = tuple(self.world_to_map(source_world))
target_map = tuple(self.world_to_map(target_world))

g = self.floor_graph[floor]

if not g.has_node(target_map):
nodes = np.array(g.nodes)
closest_node = tuple(nodes[np.argmin(np.linalg.norm(nodes - target_map, axis=1))])
g.add_edge(closest_node, target_map, weight=l2_distance(closest_node, target_map))

if not g.has_node(source_map):
nodes = np.array(g.nodes)
closest_node = tuple(nodes[np.argmin(np.linalg.norm(nodes - source_map, axis=1))])
g.add_edge(closest_node, source_map, weight=l2_distance(closest_node, source_map))

path_map = np.array(nx.astar_path(g, source_map, target_map, heuristic=l2_distance))
g = self.floor_graph[floor]["graph"]
map_to_idx = self.floor_graph[floor]["map_to_idx"]
nodes = self.floor_graph[floor]["nodes"]
node_idxs = self.floor_graph[floor]["node_idxs"]

if target_map not in map_to_idx:
closest_node = np.argmin(np.linalg.norm(nodes - target_map, axis=1))
closest_node_idx = node_idxs[closest_node]
closest_node_data = nodes[closest_node_idx]
target_node = g.add_node(target_map)
map_to_idx[target_map] = target_node
g.add_edge(closest_node, target_node, l2_distance(closest_node_data, target_map))

if source_map not in map_to_idx:
closest_node = np.argmin(np.linalg.norm(nodes - source_map, axis=1))
closest_node_idx = node_idxs[closest_node]
closest_node_data = nodes[closest_node_idx]
source_node = g.add_node(source_map)
map_to_idx[source_map] = source_node
g.add_edge(closest_node, source_node, l2_distance(closest_node_data, source_map))

idx_to_map = {idx: data for data, idx in map_to_idx.items()}

path_map = rx.astar_shortest_path(
g,
map_to_idx[source_map],
goal_fn=lambda x: x == target_map,
edge_cost_fn=lambda x: x,
estimate_cost_fn=lambda _: 0,
)

path_map = np.array([idx_to_map[idx] for idx in path_map])

path_world = self.map_to_world(path_map)
geodesic_distance = np.sum(np.linalg.norm(path_world[1:] - path_world[:-1], axis=1))
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"bddl~=1.0.1",
"urllib3>=1.20",
"progressbar>=2.5",
"rustworkx",
"packaging",
]

Expand Down

0 comments on commit 04fb3cc

Please sign in to comment.