Skip to content

Commit

Permalink
Daemon thread plotting (ultralytics#1561)
Browse files Browse the repository at this point in the history
* Daemon thread plotting

* remove process_batch

* plot after print
  • Loading branch information
glenn-jocher authored Nov 30, 2020
1 parent 68211f7 commit b6ed110
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
23 changes: 12 additions & 11 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
from pathlib import Path
from threading import Thread

import numpy as np
import torch
Expand Down Expand Up @@ -206,10 +207,10 @@ def test(data,

# Plot images
if plots and batch_i < 3:
f = save_dir / f'test_batch{batch_i}_labels.jpg' # filename
plot_images(img, targets, paths, f, names) # labels
f = save_dir / f'test_batch{batch_i}_pred.jpg'
plot_images(img, output_to_target(output), paths, f, names) # predictions
f = save_dir / f'test_batch{batch_i}_labels.jpg' # labels
Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions
Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start()

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
Expand All @@ -221,13 +222,6 @@ def test(data,
else:
nt = torch.zeros(1)

# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

# Print results
pf = '%20s' + '%12.3g' * 6 # print format
print(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
Expand All @@ -242,6 +236,13 @@ def test(data,
if not training:
print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)

# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

# Save JSON
if save_json and len(jdict):
w = Path(weights[0] if isinstance(weights, list) else weights).stem if weights is not None else '' # weights
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import argparse
import logging
import math
import os
import random
import time
from pathlib import Path
from threading import Thread
from warnings import warn

import math
import numpy as np
import torch.distributed as dist
import torch.nn as nn
Expand Down Expand Up @@ -134,6 +135,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
name=save_dir.stem,
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
loggers = {'wandb': wandb} # loggers dict

# Resume
start_epoch, best_fitness = 0, 0.0
Expand Down Expand Up @@ -201,11 +203,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if plots:
plot_labels(labels, save_dir=save_dir)
Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
if tb_writer:
tb_writer.add_histogram('classes', c, 0)
if wandb:
wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})

# Anchors
if not opt.noautoanchor:
Expand Down Expand Up @@ -311,7 +311,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Plot
if plots and ni < 3:
f = save_dir / f'train_batch{ni}.jpg' # filename
plot_images(images=imgs, targets=targets, paths=paths, fname=f)
Thread(target=plot_images, args=(imgs, targets, paths, f), daemon=True).start()
# if tb_writer:
# tb_writer.add_image(f, result, dataformats='HWC', global_step=epoch)
# tb_writer.add_graph(model, imgs) # add model to tensorboard
Expand Down
11 changes: 8 additions & 3 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
plt.savefig('test_study.png', dpi=300)


def plot_labels(labels, save_dir=''):
def plot_labels(labels, save_dir=Path(''), loggers=None):
# plot dataset labels
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
Expand All @@ -264,7 +264,7 @@ def plot_labels(labels, save_dir=''):
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
diag_kws=dict(bins=50))
plt.savefig(Path(save_dir) / 'labels_correlogram.png', dpi=200)
plt.savefig(save_dir / 'labels_correlogram.png', dpi=200)
plt.close()
except Exception as e:
pass
Expand Down Expand Up @@ -292,9 +292,14 @@ def plot_labels(labels, save_dir=''):
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(Path(save_dir) / 'labels.png', dpi=200)
plt.savefig(save_dir / 'labels.png', dpi=200)
plt.close()

# loggers
for k, v in loggers.items() or {}:
if k == 'wandb' and v:
v.log({"Labels": [v.Image(str(x), caption=x.name) for x in save_dir.glob('*labels*.png')]})


def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
# Plot hyperparameter evolution results in evolve.txt
Expand Down

0 comments on commit b6ed110

Please sign in to comment.