Skip to content

Commit

Permalink
Confusion matrix (ultralytics#1474)
Browse files Browse the repository at this point in the history
* initial commit

* add plotting

* matrix to cpu

* bug fix

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* seaborn pandas to requirements.txt

* seaborn pandas to requirements.txt

* update wandb plotting

* remove pandas

* if plots

* if plots

* if plots

* if plots

* if plots

* initial commit

* add plotting

* matrix to cpu

* bug fix

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* update plot

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* seaborn pandas to requirements.txt

* seaborn pandas to requirements.txt

* update wandb plotting

* remove pandas

* if plots

* if plots

* if plots

* if plots

* if plots

* Cat apriori to autolabels

* cleanup
  • Loading branch information
glenn-jocher authored Nov 23, 2020
1 parent 95fa653 commit 0a3ff71
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 10 deletions.
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ tqdm>=4.41.0
# logging -------------------------------------
# wandb

# coco ----------------------------------------
# pycocotools>=2.0
# plotting ------------------------------------
seaborn
pandas

# export --------------------------------------
# coremltools==4.0
Expand All @@ -26,4 +27,4 @@ tqdm>=4.41.0

# extras --------------------------------------
# thop # FLOPS computation
# seaborn # plotting
# pycocotools>=2.0 # COCO mAP
15 changes: 10 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
from utils.loss import compute_loss
from utils.metrics import ap_per_class
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target
from utils.torch_utils import select_device, time_synchronized

Expand Down Expand Up @@ -89,6 +89,7 @@ def test(data,
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]

seen = 0
confusion_matrix = ConfusionMatrix(nc=nc)
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
coco91class = coco80_to_coco91_class()
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', '[email protected]', '[email protected]:.95')
Expand Down Expand Up @@ -176,6 +177,8 @@ def test(data,
# target boxes
tbox = xywh2xyxy(labels[:, 1:5])
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
if plots:
confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))

# Per target class
for cls in torch.unique(tcls_tensor):
Expand Down Expand Up @@ -218,10 +221,12 @@ def test(data,
else:
nt = torch.zeros(1)

# W&B logging
if plots and wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(x), caption=x.name) for x in sorted(save_dir.glob('test*.jpg'))]})
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

# Print results
pf = '%20s' + '%12.3g' * 6 # print format
Expand Down
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb:
wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in
['results.png', 'precision_recall_curve.png']]})
files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
else:
dist.destroy_process_group()
Expand Down
81 changes: 81 additions & 0 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import matplotlib.pyplot as plt
import numpy as np
import torch

from . import general


def fitness(x):
Expand Down Expand Up @@ -102,6 +105,84 @@ def compute_ap(recall, precision):
return ap, mpre, mrec


class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
def __init__(self, nc, conf=0.25, iou_thres=0.45):
self.matrix = np.zeros((nc + 1, nc + 1))
self.nc = nc # number of classes
self.conf = conf
self.iou_thres = iou_thres

def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
"""
detections = detections[detections[:, 4] > self.conf]
gt_classes = labels[:, 0].int()
detection_classes = detections[:, 5].int()
iou = general.box_iou(labels[:, 1:], detections[:, :4])

x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
else:
matches = np.zeros((0, 3))

n = matches.shape[0] > 0
m0, m1, _ = matches.transpose().astype(np.int16)
for i, gc in enumerate(gt_classes):
j = m0 == i
if n and sum(j) == 1:
self.matrix[gc, detection_classes[m1[j]]] += 1 # correct
else:
self.matrix[gc, self.nc] += 1 # background FP

if n:
for i, dc in enumerate(detection_classes):
if not any(m1 == i):
self.matrix[self.nc, dc] += 1 # background FN

def matrix(self):
return self.matrix

def plot(self, save_dir='', names=()):
try:
import seaborn as sn

array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig = plt.figure(figsize=(12, 9))
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FN'] if labels else "auto",
yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
fig.tight_layout()
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
except Exception as e:
pass

def print(self):
for i in range(self.nc + 1):
print(' '.join(map(str, self.matrix[i])))


# Plots ----------------------------------------------------------------------------------------------------------------

def plot_pr_curve(px, py, ap, save_dir='.', names=()):
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
py = np.stack(py, axis=1)
Expand Down

0 comments on commit 0a3ff71

Please sign in to comment.