Skip to content

Commit

Permalink
Merge pull request #629 from borglab/netvlad-plot-fix
Browse files Browse the repository at this point in the history
Fix netvlad output dir
  • Loading branch information
akshay-krishnan authored Apr 26, 2023
2 parents 3ad18e5 + 635bdad commit 4a9dc6a
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 35 deletions.
8 changes: 5 additions & 3 deletions gtsfm/retriever/joint_netvlad_sequential_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
Authors: John Lambert
"""

from typing import List, Tuple
from pathlib import Path
from typing import List, Optional, Tuple

import dask
from dask.delayed import Delayed
Expand Down Expand Up @@ -43,17 +44,18 @@ def create_computation_graph(self, loader: LoaderBase) -> Delayed:
"""
return self.run(loader=loader)

def run(self, loader: LoaderBase) -> Delayed:
def run(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> Delayed:
"""Compute potential image pairs.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
plots_output_dir: Directory to save plots to. If None, plots are not saved.
Return:
pair_indices: (i1,i2) image pairs.
"""
sim_pairs = self._similarity_retriever.create_computation_graph(loader)
sim_pairs = self._similarity_retriever.create_computation_graph(loader, plots_output_dir=plots_output_dir)
seq_pairs = self._seq_retriever.create_computation_graph(loader)

return dask.delayed(self.aggregate_pairs)(sim_pairs=sim_pairs, seq_pairs=seq_pairs)
Expand Down
30 changes: 15 additions & 15 deletions gtsfm/retriever/netvlad_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@

logger = logger_utils.get_logger()

PLOT_SAVE_DIR = Path(__file__).parent.parent.parent / "plots"

MAX_NUM_IMAGES = 10000


Expand All @@ -54,32 +52,35 @@ def __init__(self, num_matched: int, min_score: float = 0.1, blocksize: int = 50
self._blocksize = blocksize
self._min_score = min_score

def create_computation_graph(self, loader: LoaderBase) -> Delayed:
def create_computation_graph(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> Delayed:
"""Compute potential image pairs.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
plots_output_dir: Directory to save plots to. If None, plots are not saved.
Return:
Delayed task that evaluates to a list of (i1,i2) image pairs.
"""
return self.run(loader=loader)
return self.run(loader=loader, plots_output_dir=plots_output_dir)

def run(self, loader: LoaderBase, visualize: bool = True) -> Delayed:
def run(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> Delayed:
"""Compute potential image pairs.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
visualize:
plots_output_dir: Directory to save plots to. If None, plots are not saved.
Return:
Delayed task which evaluates to a list of (i1,i2) image pairs.
"""
num_images = len(loader)
sim = self.compute_similarity_matrix(loader, num_images)
return dask.delayed(self.compute_pairs_from_similarity_matrix)(sim=sim, loader=loader, visualize=visualize)
return dask.delayed(self.compute_pairs_from_similarity_matrix)(
sim=sim, loader=loader, plots_output_dir=plots_output_dir
)

def compute_similarity_matrix(self, loader: LoaderBase, num_images: int) -> Delayed:
"""Compute a similarity matrix between all pairs of images.
Expand Down Expand Up @@ -184,14 +185,14 @@ def _aggregate_subblocks(self, subblock_results: List[SubBlockSimilarityResult],
return sim

def compute_pairs_from_similarity_matrix(
self, sim: torch.Tensor, loader: LoaderBase, visualize: bool = True
self, sim: torch.Tensor, loader: LoaderBase, plots_output_dir: Optional[Path] = None
) -> List[Tuple[int, int]]:
"""
Args:
sim: tensor of shape (num_images, num_images) representing similarity matrix.
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
visualize: whether to save a visual plot of the computed image similarity matrix.
plots_output_dir: Directory to save plots to. If None, plots are not saved.
Returns:
pair_indices: (i1,i2) image pairs.
Expand All @@ -206,30 +207,29 @@ def compute_pairs_from_similarity_matrix(
)
named_pairs = [(query_names[i], query_names[j]) for i, j in pairs]

if visualize:
os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
if plots_output_dir:
os.makedirs(plots_output_dir, exist_ok=True)

# Save image of similarity matrix.
plt.imshow(np.triu(sim.detach().cpu().numpy()))
plt.title("Image Similarity Matrix")
plt.savefig(os.path.join(PLOT_SAVE_DIR, "netvlad_similarity_matrix.jpg"), dpi=500)
plt.savefig(str(plots_output_dir / "netvlad_similarity_matrix.jpg"), dpi=500)
plt.close("all")

# Save values in similarity matrix.
np.savetxt(
fname=os.path.join(PLOT_SAVE_DIR, "netvlad_similarity_matrix.txt"),
fname=str(plots_output_dir / "netvlad_similarity_matrix.txt"),
X=sim.detach().cpu().numpy(),
fmt="%.2f",
delimiter=",",
)

# Save named pairs and scores.
with open(os.path.join(PLOT_SAVE_DIR, "netvlad_named_pairs.txt"), "w") as fid:
with open(plots_output_dir / "netvlad_named_pairs.txt", "w") as fid:
for (_named_pair, _pair_ind) in zip(named_pairs, pairs):
fid.write("%.4f %s %s\n" % (sim[_pair_ind[0], _pair_ind[1]], _named_pair[0], _named_pair[1]))

logger.info("Found %d pairs from the NetVLAD Retriever.", len(pairs))
logger.info("Image Name Pairs:" + str(named_pairs))
return pairs


Expand Down
11 changes: 7 additions & 4 deletions gtsfm/retriever/retriever_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import abc
from enum import Enum
from typing import List, Tuple
from pathlib import Path
from typing import List, Optional, Tuple

import dask
from dask.delayed import Delayed
Expand Down Expand Up @@ -44,25 +45,27 @@ def get_ui_metadata() -> UiMetadata:
)

@abc.abstractmethod
def run(self, loader: LoaderBase) -> List[Tuple[int, int]]:
def run(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> List[Tuple[int, int]]:
"""Compute potential image pairs.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
plots_output_dir: Directory to save plots to.
Return:
pair_indices: (i1,i2) image pairs.
"""

def create_computation_graph(self, loader: LoaderBase) -> Delayed:
def create_computation_graph(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> Delayed:
"""Create Dask graph for image retriever.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
plots_output_dir: Directory to save plots to.
Return:
Delayed task that evaluates to a list of (i1,i2) image pairs.
"""
return dask.delayed(self.run)(loader=loader)
return dask.delayed(self.run)(loader=loader, plots_output_dir=plots_output_dir)
7 changes: 4 additions & 3 deletions gtsfm/retriever/rig_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
Author: Frank Dellaert
"""

from typing import List, Tuple
from pathlib import Path
from typing import List, Optional, Tuple

import gtsfm.utils.logger as logger_utils
from gtsfm.loader.hilti_loader import HiltiLoader
Expand All @@ -29,11 +29,12 @@ def __init__(self, threshold: int = 100):
super().__init__(matching_regime=ImageMatchingRegime.RIG_HILTI)
self._threshold = threshold

def run(self, loader: LoaderBase) -> List[Tuple[int, int]]:
def run(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> List[Tuple[int, int]]:
"""Compute potential image pairs.
Args:
loader: image loader.
plots_output_dir: Directory to save plots to. Unused in this retriever.
Return:
pair_indices: (i1,i2) image pairs.
Expand Down
6 changes: 4 additions & 2 deletions gtsfm/retriever/sequential_hilti_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Authors: Ayush Baid.
"""
from typing import List, Tuple
from pathlib import Path
from typing import List, Optional, Tuple

import gtsfm.utils.logger as logger_utils
from gtsfm.loader.hilti_loader import HiltiLoader
Expand Down Expand Up @@ -48,12 +49,13 @@ def is_valid_pair(self, loader, idx1: int, idx2: int) -> bool:
elif rig_idx_i1 < rig_idx_i2 and rig_idx_i2 - rig_idx_i1 <= self._max_frame_lookahead:
return (cam_idx_i1, cam_idx_i2) in INTER_RIG_VALID_PAIRS

def run(self, loader: HiltiLoader) -> List[Tuple[int, int]]:
def run(self, loader: HiltiLoader, plots_output_dir: Optional[Path] = None) -> List[Tuple[int, int]]:
"""Compute potential image pairs.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
plots_output_dir: Directory to save plots to. Unused in this retriever.
Return:
pair_indices: (i1,i2) image pairs.
Expand Down
6 changes: 4 additions & 2 deletions gtsfm/retriever/sequential_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
Authors: John Lambert
"""
from typing import List, Tuple
from pathlib import Path
from typing import List, Optional, Tuple

import gtsfm.utils.logger as logger_utils
from gtsfm.loader.loader_base import LoaderBase
Expand All @@ -22,12 +23,13 @@ def __init__(self, max_frame_lookahead: int) -> None:
super().__init__(matching_regime=ImageMatchingRegime.SEQUENTIAL)
self._max_frame_lookahead = max_frame_lookahead

def run(self, loader: LoaderBase) -> List[Tuple[int, int]]:
def run(self, loader: LoaderBase, plots_output_dir: Optional[Path] = None) -> List[Tuple[int, int]]:
"""Compute potential image pairs.
Args:
loader: image loader. The length of this loader will provide the total number of images
for exhaustive global descriptor matching.
plots_output_dir: Directory to save plots to. Unused in this retriever.
Return:
pair_indices: (i1,i2) image pairs.
Expand Down
4 changes: 3 additions & 1 deletion gtsfm/runner/gtsfm_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def run(self) -> None:
process_graph_generator.is_image_correspondence = True
process_graph_generator.save_graph()

pairs_graph = self.scene_optimizer.retriever.create_computation_graph(self.loader)
pairs_graph = self.scene_optimizer.retriever.create_computation_graph(
self.loader, plots_output_dir=self.scene_optimizer._plot_base_path
)
with performance_report(filename="retriever-dask-report.html"):
image_pair_indices = pairs_graph.compute()

Expand Down
4 changes: 3 additions & 1 deletion gtsfm/runner/run_scene_optimizer_astrovision.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def run(self) -> None:
)

with Client(cluster) as client, performance_report(filename="dask-report.html"):
pairs_graph = self.scene_optimizer.retriever.create_computation_graph(self.loader)
pairs_graph = self.scene_optimizer.retriever.create_computation_graph(
self.loader, self.scene_optimizer._plot_base_path
)
image_pair_indices = pairs_graph.compute()

(
Expand Down
10 changes: 6 additions & 4 deletions gtsfm/scene_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,13 @@ def save_full_frontend_metrics(

def _save_retrieval_two_view_metrics(metrics_path: Path, plot_base_path: Path) -> None:
"""Compare 2-view similarity scores with their 2-view pose errors after viewgraph estimation."""
sim_fpath = os.path.join(plot_base_path, "netvlad_similarity_matrix.txt")
sim = np.loadtxt(sim_fpath, delimiter=",")
sim_fpath = plot_base_path / "netvlad_similarity_matrix.txt"
if not sim_fpath.exists():
logger.warning(msg="NetVLAD similarity matrix not found. Skipping retrieval metrics.")
return

json_fpath = os.path.join(metrics_path, "two_view_report_VIEWGRAPH_2VIEW_REPORT.json")
json_data = io_utils.read_json_file(json_fpath)
sim = np.loadtxt(str(sim_fpath), delimiter=",")
json_data = io_utils.read_json_file(metrics_path / "two_view_report_VIEWGRAPH_2VIEW_REPORT.json")

sim_scores = []
R_errors = []
Expand Down

0 comments on commit 4a9dc6a

Please sign in to comment.