diff --git a/igibson/scenes/indoor_scene.py b/igibson/scenes/indoor_scene.py index 404e476c4..41c6669e9 100644 --- a/igibson/scenes/indoor_scene.py +++ b/igibson/scenes/indoor_scene.py @@ -5,8 +5,8 @@ from abc import ABCMeta import cv2 -import networkx as nx import numpy as np +import rustworkx as rx from future.utils import with_metaclass from PIL import Image @@ -73,6 +73,8 @@ def load_trav_map(self, maps_path): for floor in range(len(self.floor_heights)): if self.trav_map_type == "with_obj": trav_map = np.array(Image.open(os.path.join(maps_path, "floor_trav_{}.png".format(floor)))) + elif self.trav_map_type == "no_door": + trav_map = np.array(Image.open(os.path.join(maps_path, "floor_trav_no_door_{}.png".format(floor)))) else: trav_map = np.array(Image.open(os.path.join(maps_path, "floor_trav_no_obj_{}.png".format(floor)))) @@ -100,47 +102,70 @@ def load_trav_map(self, maps_path): # We search for the largest connected areas if self.build_graph: - self.build_trav_graph(maps_path, floor, trav_map) + self.build_trav_graph(trav_map) self.floor_map.append(trav_map) - # TODO: refactor into C++ for speedup - def build_trav_graph(self, maps_path, floor, trav_map): + def build_trav_graph(self, trav_map): """ Build traversibility graph and only take the largest connected component - :param maps_path: String with the path to the folder containing the traversability maps - :param floor: floor number :param trav_map: traversability map """ - log.debug("Building traversable graph") - g = nx.Graph() - for i in range(self.trav_map_size): - for j in range(self.trav_map_size): - if trav_map[i, j] == 0: - continue - g.add_node((i, j)) - # 8-connected graph - neighbors = [(i - 1, j - 1), (i, j - 1), (i + 1, j - 1), (i - 1, j)] - for n in neighbors: - if 0 <= n[0] < self.trav_map_size and 0 <= n[1] < self.trav_map_size and trav_map[n[0], n[1]] > 0: - g.add_edge(n, (i, j), weight=l2_distance(n, (i, j))) + node_mapping = {} + g = rx.PyGraph(multigraph=False) # type: ignore + + x, y = np.where(trav_map != 0) + nodes_to_add = np.stack((x, y)).T + nodes_to_add = list(map(tuple, nodes_to_add)) + node_idxs = g.add_nodes_from(nodes_to_add) + node_mapping = {data: idx for data, idx in zip(nodes_to_add, node_idxs)} + + edges = set() + from_nodes = [] + from_nodes_pos = [] + to_nodes = [] + to_nodes_pos = [] + for node_idx, node in zip(g.node_indexes(), g.nodes()): + i, j = node + neighbors = [(i - 1, j - 1), (i, j - 1), (i + 1, j - 1), (i - 1, j)] + for n in neighbors: + if n in node_mapping and (node_idx, n) not in edges: + from_nodes.append(node_idx) + from_nodes_pos.append(node) + to_nodes.append(node_mapping[n]) + to_nodes_pos.append(n) + edges.add((node_idx, n)) + + distances = np.linalg.norm(np.array(from_nodes_pos) - np.array(to_nodes_pos), axis=1) + edges = [(x, y, z) for x, y, z in zip(from_nodes, to_nodes, distances)] + g.add_edges_from(edges) # only take the largest connected component - largest_cc = max(nx.connected_components(g), key=len) - g = g.subgraph(largest_cc).copy() - - self.floor_graph.append(g) - + largest_cc = max(rx.connected_components(g), key=len) # type: ignore + g = g.subgraph(list(largest_cc), preserve_attrs=True).copy() # update trav_map accordingly # This overwrites the traversability map loaded before # It sets everything to zero, then only sets to one the points where we have graph nodes # Dangerous! if the traversability graph is not computed from the loaded map but from a file, it could overwrite # it silently. trav_map[:, :] = 0 - for node in g.nodes: + for node in g.nodes(): trav_map[node[0], node[1]] = 255 + nodes = g.nodes() + node_idxs = g.node_indexes() + map_to_idx = {data: idx for data, idx in zip(nodes, node_idxs)} + self.floor_graph.append( + { + "graph": g, + "map_to_idx": map_to_idx, + "nodes": np.array(nodes), + "node_idxs": node_idxs, + } + ) + self.trav_map = trav_map + def get_random_point(self, floor=None): """ Sample a random point on the given floor number. If not given, sample a random floor number. @@ -204,19 +229,38 @@ def get_shortest_path(self, floor, source_world, target_world, entire_path=False source_map = tuple(self.world_to_map(source_world)) target_map = tuple(self.world_to_map(target_world)) - g = self.floor_graph[floor] - - if not g.has_node(target_map): - nodes = np.array(g.nodes) - closest_node = tuple(nodes[np.argmin(np.linalg.norm(nodes - target_map, axis=1))]) - g.add_edge(closest_node, target_map, weight=l2_distance(closest_node, target_map)) - - if not g.has_node(source_map): - nodes = np.array(g.nodes) - closest_node = tuple(nodes[np.argmin(np.linalg.norm(nodes - source_map, axis=1))]) - g.add_edge(closest_node, source_map, weight=l2_distance(closest_node, source_map)) - - path_map = np.array(nx.astar_path(g, source_map, target_map, heuristic=l2_distance)) + g = self.floor_graph[floor]["graph"] + map_to_idx = self.floor_graph[floor]["map_to_idx"] + nodes = self.floor_graph[floor]["nodes"] + node_idxs = self.floor_graph[floor]["node_idxs"] + + if target_map not in map_to_idx: + closest_node = np.argmin(np.linalg.norm(nodes - target_map, axis=1)) + closest_node_idx = node_idxs[closest_node] + closest_node_data = nodes[closest_node_idx] + target_node = g.add_node(target_map) + map_to_idx[target_map] = target_node + g.add_edge(closest_node, target_node, l2_distance(closest_node_data, target_map)) + + if source_map not in map_to_idx: + closest_node = np.argmin(np.linalg.norm(nodes - source_map, axis=1)) + closest_node_idx = node_idxs[closest_node] + closest_node_data = nodes[closest_node_idx] + source_node = g.add_node(source_map) + map_to_idx[source_map] = source_node + g.add_edge(closest_node, source_node, l2_distance(closest_node_data, source_map)) + + idx_to_map = {idx: data for data, idx in map_to_idx.items()} + + path_map = rx.astar_shortest_path( + g, + map_to_idx[source_map], + goal_fn=lambda x: x == target_map, + edge_cost_fn=lambda x: x, + estimate_cost_fn=lambda _: 0, + ) + + path_map = np.array([idx_to_map[idx] for idx in path_map]) path_world = self.map_to_world(path_map) geodesic_distance = np.sum(np.linalg.norm(path_world[1:] - path_world[:-1], axis=1)) diff --git a/pyproject.toml b/pyproject.toml index 3c933091e..8bcaa7a10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "bddl~=1.0.1", "urllib3>=1.20", "progressbar>=2.5", + "rustworkx", "packaging", ]