-
Notifications
You must be signed in to change notification settings - Fork 15
/
metrics.py
79 lines (63 loc) · 2.46 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import tensorflow as tf
from keras import backend as K
def mean_iou(y_true, y_pred):
    """Keras metric: two-class mean IoU averaged over ten binarisation thresholds.

    The prediction tensor is binarised at each threshold in
    ``np.arange(0.5, 1.0, 0.05)``, a streaming ``tf.metrics.mean_iou`` (with
    2 classes) is built for each binarised tensor, and the resulting scores
    are stacked and averaged.

    NOTE: relies on TensorFlow 1.x style APIs (``tf.to_int32``,
    ``tf.metrics``, ``K.get_session``).

    Args:
        y_true: Ground-truth tensor passed straight to ``tf.metrics.mean_iou``.
        y_pred: Prediction tensor; compared against each threshold with ``>``.

    Returns:
        Tensor holding the mean of the per-threshold IoU scores.
    """
    per_threshold_scores = []
    for threshold in np.arange(0.5, 1.0, 0.05):
        binarised = tf.to_int32(y_pred > threshold)
        iou_value, update_op = tf.metrics.mean_iou(y_true, binarised, 2)
        # The metric keeps its accumulator in local variables, which must be
        # initialised in the Keras session before the metric is evaluated.
        K.get_session().run(tf.local_variables_initializer())
        # Force the accumulator update to run whenever the score is read.
        with tf.control_dependencies([update_op]):
            per_threshold_scores.append(tf.identity(iou_value))
    stacked_scores = K.stack(per_threshold_scores)
    return K.mean(stacked_scores, axis=0)
def _class_split_range(values):
    """Return the (low, high) value range used to split *values* into two bins.

    For a mask containing at least two distinct values this is simply
    ``(min, max)``, i.e. the range ``np.histogram`` would auto-detect.  For a
    constant mask the range is widened to include 0 and 1 so that an all-zero
    mask lands in the background bin; the auto-detected range would be
    ``[v - 0.5, v + 0.5]``, which puts a constant mask in the upper
    (foreground) bin regardless of its value.
    """
    if values.size == 0:
        return (0.0, 1.0)
    low, high = float(values.min()), float(values.max())
    if low == high:
        low, high = min(low, 0.0), max(high, 1.0)
    return (low, high)


def iou_metric(y_true_in, y_pred_in, print_table=False):
    """Average precision of a binary mask over IoU thresholds 0.5 .. 0.95.

    Both masks are split into a background bin and a foreground bin.  The IoU
    of the two foreground regions is then compared against each threshold in
    ``np.arange(0.5, 1.0, 0.05)``; the precision ``TP / (TP + FP + FN)`` at
    every threshold is averaged.

    Special cases:
      * both masks have no foreground -> IoU is defined as 1 (score 1.0);
      * exactly one mask has no foreground -> IoU is 0 (score 0.0).

    Args:
        y_true_in: Ground-truth mask (array-like, any shape).
        y_pred_in: Predicted mask with the same number of elements.
        print_table: When true, print a per-threshold TP/FP/FN/precision
            table followed by the average precision.

    Returns:
        The mean precision over the ten thresholds (NumPy float).
    """
    labels = np.asarray(y_true_in, dtype=float).ravel()
    y_pred = np.asarray(y_pred_in, dtype=float).ravel()

    true_objects = 2
    pred_objects = 2

    # Explicit ranges keep constant (e.g. all-zero) masks in the correct bin.
    bin_ranges = (_class_split_range(labels), _class_split_range(y_pred))

    # Joint histogram: rows are true classes, columns are predicted classes.
    intersection = np.histogram2d(
        labels, y_pred, bins=(true_objects, pred_objects), range=bin_ranges
    )[0]

    # Per-class areas are the marginals of the joint histogram
    # (needed for finding the union between all objects).
    area_true = intersection.sum(axis=1, keepdims=True)
    area_pred = intersection.sum(axis=0, keepdims=True)

    # Compute union
    union = area_true + area_pred - intersection

    # Exclude background from the analysis
    intersection = intersection[1:, 1:]
    union = union[1:, 1:]

    # A zero union means neither mask has any foreground: perfect agreement.
    both_empty = union == 0
    safe_union = np.where(both_empty, 1.0, union)

    # Compute the intersection over union
    iou = np.where(both_empty, 1.0, intersection / safe_union)

    # Precision helper function
    def precision_at(threshold, iou):
        matches = iou > threshold
        true_positives = np.sum(matches, axis=1) == 1  # Correct objects
        false_positives = np.sum(matches, axis=0) == 0  # Extra (unmatched predicted) objects
        false_negatives = np.sum(matches, axis=1) == 0  # Missed (unmatched true) objects
        tp, fp, fn = np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
        return tp, fp, fn

    # Loop over IoU thresholds
    prec = []
    if print_table:
        print("Thresh\tTP\tFP\tFN\tPrec.")
    for t in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(t, iou)
        total = tp + fp + fn
        p = tp / total if total > 0 else 0
        if print_table:
            print("{:1.3f}\t{}\t{}\t{}\t{:1.3f}".format(t, tp, fp, fn, p))
        prec.append(p)

    average_precision = np.mean(prec)
    if print_table:
        print("AP\t-\t-\t-\t{:1.3f}".format(average_precision))
    return average_precision
def iou_metric_batch(y_true_in, y_pred_in):
    """Return the mean of ``iou_metric`` over the leading (batch) axis.

    Args:
        y_true_in: Array of ground-truth masks; ``shape[0]`` is the batch size.
        y_pred_in: Array of predicted masks, indexed the same way.

    Returns:
        The mean of the per-sample ``iou_metric`` values.
    """
    sample_count = y_true_in.shape[0]
    per_sample_scores = [
        iou_metric(y_true_in[index], y_pred_in[index])
        for index in range(sample_count)
    ]
    return np.mean(per_sample_scores)