Weights & Biases integration for inference #543

Open · wants to merge 5 commits into main
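
Usage sketch for reviewers: the only new options are --wandb_project and --wandb_entity; the other arguments are the script's existing ones, and the project/entity values below are placeholders, not part of this PR.

# Hypothetical CLI invocation (using the existing --weights/--source flags):
#   python tools/infer.py --weights yolov6s.pt --source data/images \
#       --wandb_project yolov6-demo --wandb_entity my-team
# Equivalent programmatic call, assuming the repository root is on PYTHONPATH
# and the remaining run() defaults from tools/infer.py:
from tools.infer import run

run(
    weights="yolov6s.pt",
    source="data/images",
    wandb_project="yolov6-demo",  # creates a W&B run and logs a predictions table
    wandb_entity="my-team",
)
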
1 change: 1 addition & 0 deletions requirements.txt
@@ -8,6 +8,7 @@ opencv-python>=4.1.2
PyYAML>=5.3.1
scipy>=1.4.1
tqdm>=4.41.0
wget>=3.2
addict>=2.4.0
tensorboard>=2.7.0
pycocotools>=2.0
67 changes: 65 additions & 2 deletions tools/infer.py
@@ -3,8 +3,15 @@
import argparse
import os
import sys
import requests
import os.path as osp
from pathlib import Path
from typing import Optional

import wget
import urllib.request
import urllib.error

import wandb
import torch

ROOT = os.getcwd()
@@ -14,6 +21,8 @@
from yolov6.utils.events import LOGGER
from yolov6.core.inferer import Inferer

from yolov6.logger.wandb_inference_logger import WandbInferenceLogger


def get_args_parser(add_help=True):
parser = argparse.ArgumentParser(description='YOLOv6 PyTorch Inference.', add_help=add_help)
@@ -36,6 +45,8 @@ def get_args_parser(add_help=True):
parser.add_argument('--hide-labels', default=False, action='store_true', help='hide labels.')
parser.add_argument('--hide-conf', default=False, action='store_true', help='hide confidences.')
parser.add_argument('--half', action='store_true', help='whether to use FP16 half-precision inference.')
parser.add_argument("--wandb_project", type=str, default=None, help="Name of Weights & Biases Project.")
parser.add_argument("--wandb_entity", type=str, default=None, help="Name of Weights & Biases Entity.")

args = parser.parse_args()
LOGGER.info(args)
@@ -62,10 +73,12 @@ def run(weights=osp.join(ROOT, 'yolov6s.pt'),
hide_labels=False,
hide_conf=False,
half=False,
wandb_project: Optional[str] = None,
wandb_entity: Optional[str] = None,
):
""" Inference process, supporting inference on a single image file or a directory containing images.
Args:
weights: The path of model.pt, e.g. yolov6s.pt
weights: The path of model.pt, e.g. yolov6s.pt. Note that weights shipped with the latest release can be given by filename without downloading them manually; they will be fetched automatically.
source: Source path, supporting image files or dirs containing images.
yaml: Data yaml file.
img_size: Inference image-size, e.g. 640
@@ -84,6 +97,53 @@ def run(weights=osp.join(ROOT, 'yolov6s.pt'),
hide_conf: Hide confidences
half: Use FP16 half-precision inference, e.g. False
"""
if wandb_project is not None:
wandb.init(
project=wandb_project,
name=name,
entity=wandb_entity,
job_type="inference"
)

if name is None:
name = wandb.run.name

config = wandb.config
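# Mirror the inference arguments into the run config so they appear in the W&B run's Config panel.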
config.weights = Path(weights).name
config.source = source
config.yaml = yaml
config.img_size = img_size
config.conf_thres = conf_thres
config.iou_thres = iou_thres
config.max_det = max_det
config.device = device
config.save_txt = save_txt
config.save_img = not not_save_img
config.classes = classes
config.agnostic_nms = agnostic_nms
config.hide_labels = hide_labels
config.hide_conf = hide_conf
config.half = half

if not osp.isfile(weights):
try:
print("Downloading weights...")
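# The latest-release "html_url" has the form .../releases/tag/<version>; swapping
# "tag" for "download" and appending the weights filename yields the direct asset
# URL, e.g. .../releases/download/<version>/yolov6s.pt.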
weights_url = requests.get(
"https://api.github.com/repos/meituan/YOLOv6/releases/latest"
).json()["html_url"].replace("tag", "download") + f"/{weights}"
urllib.request.urlretrieve(weights_url, weights)
print("\nDone.")
except urllib.error.HTTPError:
print("Unable to download model.")

if not osp.isfile(source) and not osp.isdir(source):
try:
print("Downloading image...")
source = wget.download(source)
print("\nDone.")
except urllib.error.HTTPError:
print("Unable to download image.")

# create save dir
if save_dir is None:
save_dir = osp.join(project, name)
@@ -100,11 +160,14 @@ def run(weights=osp.join(ROOT, 'yolov6s.pt'),
os.makedirs(save_txt_path)

# Inference
inferer = Inferer(source, weights, device, yaml, img_size, half)
inferer = Inferer(source, weights, device, yaml, img_size, half, inference_logger=WandbInferenceLogger() if wandb.run is not None else None)
inferer.infer(conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir, save_txt, not not_save_img, hide_labels, hide_conf, view_img)

if save_txt or not not_save_img:
LOGGER.info(f"Results saved to {save_dir}")

if wandb.run is not None:
wandb.finish()


def main(args):
21 changes: 20 additions & 1 deletion yolov6/core/inferer.py
@@ -7,6 +7,7 @@
import torch
import numpy as np
import os.path as osp
from typing import Optional

from tqdm import tqdm
from pathlib import Path
@@ -20,8 +21,10 @@
from yolov6.utils.nms import non_max_suppression
from yolov6.utils.torch_utils import get_model_info

from yolov6.logger.wandb_inference_logger import WandbInferenceLogger

class Inferer:
def __init__(self, source, weights, device, yaml, img_size, half):
def __init__(self, source, weights, device, yaml, img_size, half, inference_logger: Optional[WandbInferenceLogger] = None):

self.__dict__.update(locals())

@@ -33,8 +36,15 @@ def __init__(self, source, weights, device, yaml, img_size, half):
self.model = DetectBackend(weights, device=self.device)
self.stride = self.model.stride
self.class_names = load_yaml(yaml)['names']

if self.inference_logger is not None:
self.inference_logger.label_dictionary = {
idx: self.class_names[idx] for idx in range(len(self.class_names))
}

self.img_size = self.check_img_size(self.img_size, s=self.stride) # check image size
self.half = half
self.inference_logger = inference_logger

# Switch model to deploy status
self.model_switch(self.model.model, self.img_size)
@@ -94,6 +104,7 @@ def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir,

if len(det):
det[:, :4] = self.rescale(img.shape[2:], det[:, :4], img_src.shape).round()

for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (self.box_convert(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
Expand All @@ -108,6 +119,12 @@ def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir,
self.plot_box_and_label(img_ori, max(round(sum(img_ori.shape) / 2 * 0.003), 2), xyxy, label, color=self.generate_colors(class_num, True))

img_src = np.asarray(img_ori)

if isinstance(self.inference_logger, WandbInferenceLogger):
self.inference_logger.in_infer(
np.array(img_ori), img_path, reversed(det)
)


# FPS counter
fps_calculator.update(1.0 / (t2 - t1))
@@ -150,6 +167,8 @@ def infer(self, conf_thres, iou_thres, classes, agnostic_nms, max_det, save_dir,
save_path = str(Path(save_path).with_suffix('.mp4')) # force *.mp4 suffix on results videos
vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
vid_writer.write(img_src)

if self.inference_logger is not None:
    self.inference_logger.on_infer_end()

@staticmethod
def precess_image(img_src, img_size, stride, half):
Empty file added yolov6/logger/__init__.py
60 changes: 60 additions & 0 deletions yolov6/logger/wandb_inference_logger.py
@@ -0,0 +1,60 @@
import wandb
from pathlib import Path


class WandbInferenceLogger:
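    """Collect per-image predictions into a wandb.Table of bounding-box overlays and log the table when inference ends."""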
def __init__(self) -> None:
self._label_dictionary = {}
self.table = wandb.Table(
columns=[
"Image-File",
"Predictions",
"Number-of-Objects",
"Prediction-Confidence",
]
)

@property
def label_dictionary(self):
return self._label_dictionary

@label_dictionary.setter
def label_dictionary(self, new_dict):
self._label_dictionary = new_dict

def in_infer(self, image, image_file, detection_results):
bbox_data, confidences = [], []
height, width, _ = image.shape
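# W&B box overlays treat float values in [0, 1] as coordinates relative to the
# image size, hence the division by width and height below.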
for idx, (*xyxy, confidence, class_id) in enumerate(detection_results):
confidences.append(float(confidence))
xyxy = [int(coord) for coord in xyxy]
bbox_data.append(
{
"position": {
"minX": xyxy[0] / width,
"maxX": xyxy[2] / width,
"minY": xyxy[1] / height,
"maxY": xyxy[3] / height,
},
"class_id": int(class_id),
"box_caption": f"Key {idx}: {self.label_dictionary[int(class_id)]} {float(confidence)}",
"scores": {"confidence": float(confidence)},
}
)
self.table.add_data(
Path(image_file).stem,
wandb.Image(
image_file,
boxes={
"predictions": {
"box_data": bbox_data,
"class_labels": self.label_dictionary,
}
},
),
len(detection_results),
confidences,
)

def on_infer_end(self):
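# Log the accumulated predictions table once, at the end of inference.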
wandb.log({"Inference": self.table})
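
For reference, a minimal standalone sketch of the call sequence the Inferer drives on this logger; an active wandb run is assumed, and the project name, labels, path, and box values are illustrative only:

import numpy as np
import wandb

from yolov6.logger.wandb_inference_logger import WandbInferenceLogger

wandb.init(project="yolov6-demo", job_type="inference")  # placeholder project

logger = WandbInferenceLogger()
# Normally populated by Inferer from the data yaml's class names.
logger.label_dictionary = {0: "person", 1: "bicycle"}

# One call per image: an (H, W, 3) array, the path to the image on disk
# (wandb.Image reads the file), and rows of [x1, y1, x2, y2, conf, class_id].
image = np.zeros((480, 640, 3), dtype=np.uint8)
detections = [[10, 20, 110, 220, 0.87, 0]]
logger.in_infer(image, "data/images/example.jpg", detections)

logger.on_infer_end()  # logs the table under the "Inference" key
wandb.finish()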