forked from ultralytics/yolov5
Utils reorganization (ultralytics#1392)
* Utils reorganization
* Add new utils files
* cleanup
* simplify
* reduce datasets.py
* remove evolve.sh
* loadWebcam cleanup
1 parent 379396e · commit fe341fa
Showing 14 changed files with 890 additions and 988 deletions. The diffs below cover train.py, models/activations.py, and the new utils/autoanchor.py.
train.py
```diff
@@ -3,7 +3,6 @@
 import math
 import os
 import random
-import shutil
 import time
 from pathlib import Path
 from warnings import warn
```
```diff
@@ -23,13 +22,15 @@
 import test  # import test.py to get mAP after each epoch
 from models.yolo import Model
+from utils.autoanchor import check_anchors
 from utils.datasets import create_dataloader
-from utils.general import (
-    torch_distributed_zero_first, labels_to_class_weights, plot_labels, check_anchors, labels_to_image_weights,
-    compute_loss, plot_images, fitness, strip_optimizer, plot_results, get_latest_run, check_dataset, check_file,
-    check_git_status, check_img_size, increment_path, print_mutation, plot_evolution, set_logging, init_seeds)
+from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
+    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
+    print_mutation, set_logging
 from utils.google_utils import attempt_download
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts
+from utils.loss import compute_loss
+from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
 
 logger = logging.getLogger(__name__)
```
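The reorganization mostly moves helpers out of utils.general into focused modules. A minimal sketch of how downstream imports change under the new layout (module and function names taken from the diff above):

```python
# Old layout (before this commit): everything came from utils.general
# from utils.general import compute_loss, plot_images, plot_labels, torch_distributed_zero_first

# New layout (after this commit): loss, plotting, and torch helpers live in their own modules
from utils.loss import compute_loss
from utils.plots import plot_images, plot_labels
from utils.torch_utils import torch_distributed_zero_first
```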
```diff
@@ -209,7 +210,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Start training
     t0 = time.time()
-    nw = max(round(hyp['warmup_epochs'] * nb), 1e3)  # number of warmup iterations, max(3 epochs, 1k iterations)
+    nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5:.95, val_loss(box, obj, cls)
```
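A plausible reading of the 1e3 → 1000 change is that it keeps nw an int rather than a float, since Python's max() returns the float literal whenever it wins. A small illustration with assumed values:

```python
# Assumed values for illustration only: nb = batches per epoch, warmup_epochs from hyp
nb = 100
warmup_epochs = 3.0

old_nw = max(round(warmup_epochs * nb), 1e3)   # 1000.0 -> the float literal wins the max()
new_nw = max(round(warmup_epochs * nb), 1000)  # 1000   -> stays an int
print(type(old_nw).__name__, type(new_nw).__name__)  # float int
```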
```diff
@@ -334,9 +335,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name))
 
             # Log
-            tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
+            tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',  # train loss
                     'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
-                    'val/giou_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
+                    'val/box_loss', 'val/obj_loss', 'val/cls_loss',  # val loss
                     'x/lr0', 'x/lr1', 'x/lr2']  # params
             for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
                 if tb_writer:
```
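The renamed tags must stay aligned, position by position, with the values being logged: three train-loss components from mloss (the running total in its last slot is dropped by mloss[:-1]), seven entries from results, and three learning rates. A self-contained sketch of the pairing, with made-up numbers:

```python
# Made-up values, only to show how zip() pairs each tag with one scalar.
mloss = [0.05, 0.03, 0.01, 0.09]  # running box, obj, cls, total loss (total dropped below)
results = (0.60, 0.55, 0.50, 0.30, 0.04, 0.02, 0.01)  # P, R, mAP@.5, mAP@.5:.95, val box/obj/cls loss
lr = [0.0032, 0.0032, 0.01]  # learning rate per optimizer param group
tags = ['train/box_loss', 'train/obj_loss', 'train/cls_loss',
        'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95',
        'val/box_loss', 'val/obj_loss', 'val/cls_loss',
        'x/lr0', 'x/lr1', 'x/lr2']
for x, tag in zip(list(mloss[:-1]) + list(results) + lr, tags):
    print(tag, x)  # in train.py this value goes to TensorBoard and/or W&B instead
```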
models/activations.py
```diff
@@ -1,3 +1,5 @@
+# Activation functions
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
```
utils/autoanchor.py (new file, 152 lines added)
```python
# Auto-anchor utils

import numpy as np
import torch
import yaml
from scipy.cluster.vq import kmeans
from tqdm import tqdm


def check_anchor_order(m):
    # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
    a = m.anchor_grid.prod(-1).view(-1)  # anchor area
    da = a[-1] - a[0]  # delta a
    ds = m.stride[-1] - m.stride[0]  # delta s
    if da.sign() != ds.sign():  # order mismatch: flip anchors to match stride order
        print('Reversing anchor order')
        m.anchors[:] = m.anchors.flip(0)
        m.anchor_grid[:] = m.anchor_grid.flip(0)
```
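To see the check fire, a hedged toy example using a hypothetical stand-in for the Detect() module (a plain SimpleNamespace, not the real class), with anchors stored largest-first while strides run smallest-first:

```python
import torch
from types import SimpleNamespace

# Hypothetical Detect() stand-in: anchor areas decrease while strides increase
m = SimpleNamespace(
    anchors=torch.tensor([[[116., 90.]], [[30., 61.]], [[10., 13.]]]),
    anchor_grid=torch.tensor([[[116., 90.]], [[30., 61.]], [[10., 13.]]]),
    stride=torch.tensor([8., 16., 32.]),
)
check_anchor_order(m)                # prints 'Reversing anchor order'
print(m.anchor_grid.view(-1, 2)[0])  # tensor([10., 13.]) -- smallest anchor now first
```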
```python
def check_anchors(dataset, model, thr=4.0, imgsz=640):
    # Check anchor fit to data, recompute if necessary
    print('\nAnalyzing anchors... ', end='')
    m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1]  # Detect()
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1))  # augment scale
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float()  # wh

    def metric(k):  # compute metric
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        best = x.max(1)[0]  # best_x
        aat = (x > 1. / thr).float().sum(1).mean()  # anchors above threshold
        bpr = (best > 1. / thr).float().mean()  # best possible recall
        return bpr, aat

    bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2))
    print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='')
    if bpr < 0.98:  # threshold to recompute
        print('. Attempting to improve anchors, please wait...')
        na = m.anchor_grid.numel() // 2  # number of anchors
        new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
        new_bpr = metric(new_anchors.reshape(-1, 2))[0]
        if new_bpr > bpr:  # replace anchors
            new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
            m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid)  # for inference
            m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1)  # loss
            check_anchor_order(m)
            print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
        else:
            print('Original anchors better than new anchors. Proceeding with original anchors.')
    print('')  # newline
```
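A worked example of the ratio metric with one label and two candidate anchors (numbers invented for illustration):

```python
import torch

wh = torch.tensor([[20., 40.]])              # one label box: 20x40 px
k = torch.tensor([[10., 10.], [25., 50.]])   # two candidate anchors
r = wh[:, None] / k[None]                    # per-dimension size ratios, shape (1, 2, 2)
x = torch.min(r, 1. / r).min(2)[0]           # worst-dimension ratio per anchor: [[0.25, 0.80]]
best = x.max(1)[0]                           # best anchor for this label: 0.80
print(best > 1. / 4.0)                       # True -> covered at thr=4, counts toward BPR
```

A label counts toward the best possible recall when its best anchor's worst-dimension ratio exceeds 1/thr; check_anchors only recomputes anchors when fewer than 98% of labels pass.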
```python
def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
    """ Creates kmeans-evolved anchors from training dataset

        Arguments:
            path: path to dataset *.yaml, or a loaded dataset
            n: number of anchors
            img_size: image size used for training
            thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
            gen: generations to evolve anchors using genetic algorithm
            verbose: print all results

        Return:
            k: kmeans evolved anchors

        Usage:
            from utils.autoanchor import *; _ = kmean_anchors()
    """
    thr = 1. / thr

    def metric(k, wh):  # compute metrics
        r = wh[:, None] / k[None]
        x = torch.min(r, 1. / r).min(2)[0]  # ratio metric
        # x = wh_iou(wh, torch.tensor(k))  # iou metric
        return x, x.max(1)[0]  # x, best_x

    def anchor_fitness(k):  # mutation fitness
        _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
        return (best * (best > thr).float()).mean()  # fitness

    def print_results(k):
        k = k[np.argsort(k.prod(1))]  # sort small to large
        x, best = metric(k, wh0)
        bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n  # best possible recall, anch > thr
        print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat))
        print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' %
              (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='')
        for i, x in enumerate(k):
            print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n')  # use in *.cfg
        return k

    if isinstance(path, str):  # *.yaml file
        with open(path) as f:
            data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
        from utils.datasets import LoadImagesAndLabels
        dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
    else:
        dataset = path  # dataset

    # Get label wh
    shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
    wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh

    # Filter
    i = (wh0 < 3.0).any(1).sum()
    if i:
        print('WARNING: Extremely small objects found. '
              '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
    wh = wh0[(wh0 >= 2.0).any(1)]  # keep labels with width or height >= 2 pixels

    # Kmeans calculation
    print('Running kmeans for %g anchors on %g points...' % (n, len(wh)))
    s = wh.std(0)  # sigmas for whitening
    k, dist = kmeans(wh / s, n, iter=30)  # points, mean distance
    k *= s
    wh = torch.tensor(wh, dtype=torch.float32)  # filtered
    wh0 = torch.tensor(wh0, dtype=torch.float32)  # unfiltered
    k = print_results(k)

    # Plot
    # k, d = [None] * 20, [None] * 20
    # for i in tqdm(range(1, 21)):
    #     k[i-1], d[i-1] = kmeans(wh / s, i)  # points, mean distance
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))
    # ax = ax.ravel()
    # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
    # fig, ax = plt.subplots(1, 2, figsize=(14, 7))  # plot wh
    # ax[0].hist(wh[wh[:, 0] < 100, 0], 400)
    # ax[1].hist(wh[wh[:, 1] < 100, 1], 400)
    # fig.tight_layout()
    # fig.savefig('wh.png', dpi=200)

    # Evolve
    npr = np.random
    f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1  # fitness, anchor shape, mutation probability, sigma
    pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm')  # progress bar
    for _ in pbar:
        v = np.ones(sh)
        while (v == 1).all():  # mutate until a change occurs (prevent duplicates)
            v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
        kg = (k.copy() * v).clip(min=2.0)
        fg = anchor_fitness(kg)
        if fg > f:
            f, k = fg, kg.copy()
            pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f
            if verbose:
                print_results(k)

    return print_results(k)
```
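Whitening (dividing wh by its per-dimension standard deviation before kmeans, then multiplying the centroids back by s) puts width and height on a comparable scale so neither dimension dominates the clustering. A hedged usage sketch using the function's own default arguments; the printed numbers vary per dataset:

```python
# Recompute 9 anchors for 640 px training, then paste the printed w,h pairs
# into the model's *.yaml. Arguments shown are the function defaults.
from utils.autoanchor import kmean_anchors

anchors = kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000)
# anchors: (9, 2) array of evolved (width, height) pairs, sorted small to large
```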