Add option in matcher to directly accept PIL images or image paths (#73)
spagnoloG authored Oct 19, 2024
1 parent 905ff76 commit 0e50b44
Showing 1 changed file with 36 additions and 24 deletions.
romatch/models/matcher.py (60 changes: 36 additions & 24 deletions)
@@ -593,18 +593,25 @@ def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return
     @torch.inference_mode()
     def match(
         self,
-        im_A_path,
-        im_B_path,
+        im_A_input,
+        im_B_input,
         *args,
         batched=False,
-        device = None,
+        device=None,
     ):
         if device is None:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if isinstance(im_A_path, (str, os.PathLike)):
-            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
+
+        # Check if inputs are file paths or already loaded images
+        if isinstance(im_A_input, (str, os.PathLike)):
+            im_A = Image.open(im_A_input).convert("RGB")
+        else:
+            im_A = im_A_input
+
+        if isinstance(im_B_input, (str, os.PathLike)):
+            im_B = Image.open(im_B_input).convert("RGB")
         else:
-            im_A, im_B = im_A_path, im_B_path
+            im_B = im_B_input
+
         symmetric = self.symmetric
         self.train(False)
@@ -616,9 +623,9 @@ def match
                 # Get images in good format
                 ws = self.w_resized
                 hs = self.h_resized

                 test_transform = get_tuple_transform_ops(
-                    resize=(hs, ws), normalize=True, clahe = False
+                    resize=(hs, ws), normalize=True, clahe=False
                 )
                 im_A, im_B = test_transform((im_A, im_B))
                 batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -633,70 +640,75 @@ def match
             finest_scale = 1
             # Run matcher
             if symmetric:
-                corresps = self.forward_symmetric(batch)
+                corresps = self.forward_symmetric(batch)
             else:
-                corresps = self.forward(batch, batched = True)
+                corresps = self.forward(batch, batched=True)

             if self.upsample_preds:
                 hs, ws = self.upsample_res

             if self.attenuate_cert:
                 low_res_certainty = F.interpolate(
-                    corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+                    corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
                 )
                 cert_clamp = 0
                 factor = 0.5
-                low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+                low_res_certainty = factor * low_res_certainty * (low_res_certainty < cert_clamp)

             if self.upsample_preds:
                 finest_corresps = corresps[finest_scale]
                 torch.cuda.empty_cache()
                 test_transform = get_tuple_transform_ops(
                     resize=(hs, ws), normalize=True
                 )
-                im_A, im_B = test_transform((Image.open(im_A_path).convert('RGB'), Image.open(im_B_path).convert('RGB')))
+                if isinstance(im_A_input, (str, os.PathLike)):
+                    im_A, im_B = test_transform(
+                        (Image.open(im_A_input).convert('RGB'), Image.open(im_B_input).convert('RGB')))
+                else:
+                    im_A, im_B = test_transform((im_A_input, im_B_input))
+
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
                 batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
                 if symmetric:
-                    corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
+                    corresps = self.forward_symmetric(batch, upsample=True, batched=True, scale_factor=scale_factor)
                 else:
-                    corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
-            im_A_to_im_B = corresps[finest_scale]["flow"]
+                    corresps = self.forward(batch, batched=True, upsample=True, scale_factor=scale_factor)
+
+            im_A_to_im_B = corresps[finest_scale]["flow"]
             certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
             if finest_scale != 1:
                 im_A_to_im_B = F.interpolate(
-                    im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                    im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
                 )
                 certainty = F.interpolate(
-                    certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+                    certainty, size=(hs, ws), align_corners=False, mode="bilinear"
                 )
             im_A_to_im_B = im_A_to_im_B.permute(
                 0, 2, 3, 1
-            )
+            )
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
                     torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
                     torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                 ),
-                indexing = 'ij'
+                indexing='ij'
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
             im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
             certainty = certainty.sigmoid() # logits -> probs
             im_A_coords = im_A_coords.permute(0, 2, 3, 1)
             if (im_A_to_im_B.abs() > 1).any() and True:
                 wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-                certainty[wrong[:,None]] = 0
+                certainty[wrong[:, None]] = 0
             im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
             if symmetric:
                 A_to_B, B_to_A = im_A_to_im_B.chunk(2)
                 q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
                 im_B_coords = im_A_coords
                 s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-                warp = torch.cat((q_warp, s_warp),dim=2)
+                warp = torch.cat((q_warp, s_warp), dim=2)
                 certainty = torch.cat(certainty.chunk(2), dim=3)
             else:
                 warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
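
For reference, a minimal usage sketch of the changed match signature, assuming the roma_outdoor constructor exposed by the romatch package and that match returns the (warp, certainty) pair assembled at the end of this hunk (the return statement sits below the shown context); the image paths are placeholders.

    import torch
    from PIL import Image
    from romatch import roma_outdoor  # assumed public model constructor from this repo

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    roma_model = roma_outdoor(device=device)

    # Previous behaviour: pass file paths and let match() open the images.
    warp, certainty = roma_model.match("path/to/im_A.jpg", "path/to/im_B.jpg", device=device)

    # New with this commit: pass already-loaded PIL images directly.
    im_A = Image.open("path/to/im_A.jpg").convert("RGB")
    im_B = Image.open("path/to/im_B.jpg").convert("RGB")
    warp, certainty = roma_model.match(im_A, im_B, device=device)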
