forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update PR curve * legend outside * list(Path().glob())
- Loading branch information
1 parent
8d2d6d2
commit 4250f84
Showing
3 changed files
with
35 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -213,7 +213,7 @@ def test(data, | |
# Compute statistics | ||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | ||
if len(stats) and stats[0].any(): | ||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png') | ||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) | ||
p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, [email protected], [email protected]:0.95] | ||
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() | ||
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
# Model validation metrics | ||
|
||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
|
@@ -10,7 +12,7 @@ def fitness(x): | |
return (x[:, :4] * w).sum(1) | ||
|
||
|
||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'): | ||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision-recall_curve.png', names=[]): | ||
""" Compute the average precision, given the recall and precision curves. | ||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | ||
# Arguments | ||
|
@@ -19,7 +21,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re | |
pred_cls: Predicted object classes (nparray). | ||
target_cls: True object classes (nparray). | ||
plot: Plot precision-recall curve at [email protected] | ||
fname: Plot filename | ||
save_dir: Plot save directory | ||
# Returns | ||
The average precision as computed in py-faster-rcnn. | ||
""" | ||
|
@@ -66,17 +68,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-re | |
f1 = 2 * p * r / (p + r + 1e-16) | ||
|
||
if plot: | ||
py = np.stack(py, axis=1) | ||
fig, ax = plt.subplots(1, 1, figsize=(5, 5)) | ||
ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision) | ||
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes %.3f [email protected]' % ap[:, 0].mean()) | ||
ax.set_xlabel('Recall') | ||
ax.set_ylabel('Precision') | ||
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
plt.legend() | ||
fig.tight_layout() | ||
fig.savefig(fname, dpi=200) | ||
plot_pr_curve(px, py, ap, save_dir, names) | ||
|
||
return p, r, ap, f1, unique_classes.astype('int32') | ||
|
||
|
@@ -108,3 +100,23 @@ def compute_ap(recall, precision): | |
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve | ||
|
||
return ap, mpre, mrec | ||
|
||
|
||
def plot_pr_curve(px, py, ap, save_dir='.', names=()): | ||
fig, ax = plt.subplots(1, 1, figsize=(9, 6)) | ||
py = np.stack(py, axis=1) | ||
|
||
if 0 < len(names) < 21: # show mAP in legend if < 10 classes | ||
for i, y in enumerate(py.T): | ||
ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0]) # plot(recall, precision) | ||
else: | ||
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) | ||
|
||
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f [email protected]' % ap[:, 0].mean()) | ||
ax.set_xlabel('Recall') | ||
ax.set_ylabel('Precision') | ||
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") | ||
fig.tight_layout() | ||
fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,7 +65,7 @@ def plot_one_box(x, img, color=None, label=None, line_thickness=None): | |
cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA) | ||
|
||
|
||
def plot_wh_methods(): # from utils.general import *; plot_wh_methods() | ||
def plot_wh_methods(): # from utils.plots import *; plot_wh_methods() | ||
# Compares the two methods for width-height anchor multiplication | ||
# https://github.com/ultralytics/yolov3/issues/168 | ||
x = np.arange(-4.0, 4.0, .1) | ||
|
@@ -200,7 +200,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''): | |
plt.savefig(Path(save_dir) / 'LR.png', dpi=200) | ||
|
||
|
||
def plot_test_txt(): # from utils.general import *; plot_test() | ||
def plot_test_txt(): # from utils.plots import *; plot_test() | ||
# Plot test.txt histograms | ||
x = np.loadtxt('test.txt', dtype=np.float32) | ||
box = xyxy2xywh(x[:, :4]) | ||
|
@@ -217,7 +217,7 @@ def plot_test_txt(): # from utils.general import *; plot_test() | |
plt.savefig('hist1d.png', dpi=200) | ||
|
||
|
||
def plot_targets_txt(): # from utils.general import *; plot_targets_txt() | ||
def plot_targets_txt(): # from utils.plots import *; plot_targets_txt() | ||
# Plot targets.txt histograms | ||
x = np.loadtxt('targets.txt', dtype=np.float32).T | ||
s = ['x targets', 'y targets', 'width targets', 'height targets'] | ||
|
@@ -230,7 +230,7 @@ def plot_targets_txt(): # from utils.general import *; plot_targets_txt() | |
plt.savefig('targets.jpg', dpi=200) | ||
|
||
|
||
def plot_study_txt(f='study.txt', x=None): # from utils.general import *; plot_study_txt() | ||
def plot_study_txt(f='study.txt', x=None): # from utils.plots import *; plot_study_txt() | ||
# Plot study.txt generated by test.py | ||
fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True) | ||
ax = ax.ravel() | ||
|
@@ -294,7 +294,7 @@ def plot_labels(labels, save_dir=''): | |
pass | ||
|
||
|
||
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general import *; plot_evolution() | ||
def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution() | ||
# Plot hyperparameter evolution results in evolve.txt | ||
with open(yaml_file) as f: | ||
hyp = yaml.load(f, Loader=yaml.FullLoader) | ||
|
@@ -318,7 +318,7 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.general im | |
print('\nPlot saved as evolve.png') | ||
|
||
|
||
def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_results_overlay() | ||
def plot_results_overlay(start=0, stop=0): # from utils.plots import *; plot_results_overlay() | ||
# Plot training 'results*.txt', overlaying train and val losses | ||
s = ['train', 'train', 'train', 'Precision', '[email protected]', 'val', 'val', 'val', 'Recall', '[email protected]:0.95'] # legends | ||
t = ['Box', 'Objectness', 'Classification', 'P-R', 'mAP-F1'] # titles | ||
|
@@ -342,20 +342,18 @@ def plot_results_overlay(start=0, stop=0): # from utils.general import *; plot_ | |
|
||
|
||
def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): | ||
# from utils.general import *; plot_results(save_dir='runs/train/exp0') | ||
# Plot training 'results*.txt' as seen in https://github.com/ultralytics/yolov5#reproduce-our-training | ||
# Plot training 'results*.txt'. from utils.plots import *; plot_results(save_dir='runs/train/exp') | ||
fig, ax = plt.subplots(2, 5, figsize=(12, 6)) | ||
ax = ax.ravel() | ||
s = ['Box', 'Objectness', 'Classification', 'Precision', 'Recall', | ||
'val Box', 'val Objectness', 'val Classification', '[email protected]', '[email protected]:0.95'] | ||
if bucket: | ||
# os.system('rm -rf storage.googleapis.com') | ||
# files = ['https://storage.googleapis.com/%s/results%g.txt' % (bucket, x) for x in id] | ||
files = ['results%g.txt' % x for x in id] | ||
c = ('gsutil cp ' + '%s ' * len(files) + '.') % tuple('gs://%s/results%g.txt' % (bucket, x) for x in id) | ||
os.system(c) | ||
else: | ||
files = glob.glob(str(Path(save_dir) / 'results*.txt')) + glob.glob('../../Downloads/results*.txt') | ||
files = list(Path(save_dir).glob('results*.txt')) | ||
assert len(files), 'No results.txt files found in %s, nothing to plot.' % os.path.abspath(save_dir) | ||
for fi, f in enumerate(files): | ||
try: | ||
|
@@ -367,7 +365,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): | |
if i in [0, 1, 2, 5, 6, 7]: | ||
y[y == 0] = np.nan # don't show zero loss values | ||
# y /= y[0] # normalize | ||
label = labels[fi] if len(labels) else Path(f).stem | ||
label = labels[fi] if len(labels) else f.stem | ||
ax[i].plot(x, y, marker='.', label=label, linewidth=1, markersize=6) | ||
ax[i].set_title(s[i]) | ||
# if i in [5, 6, 7]: # share train and val loss y axes | ||
|