-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
147 lines (125 loc) · 5.96 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils.dice_score import multiclass_dice_coeff, dice_coeff
# add
from PIL import Image
import numpy as np
from torch.functional import Tensor
from utils.miou import IOU, MIOU
# NOTE(review): the triple-quoted string below is a bare expression statement,
# not bound to any name, so it has no runtime effect. It only preserves
# placeholder metric stubs (miou, recall, map, warpingError, randError), each of
# which just returns 0.0. If ever uncommented, the `map` stub would shadow the
# builtin `map` for this module.
'''
def miou(mask_pred,mask_true):
total_miou=0.0
return total_miou
def recall(mask_pred,mask_true):
total_recall=0.0
return total_recall
def map(mask_pred,mask_true):
total_map=0.0
return total_map
def warpingError(mask_pred,mask_true):
total_warp=0.0
return total_warp
def randError(mask_pred,mask_true):
total_rand=0.0
return total_rand
'''
def evaluate(net, dataloader, device, is_gpus=False):
    """Run one validation pass over ``dataloader`` and return averaged metrics.

    The network is put in eval mode for the pass. Its previous train/eval
    mode is restored afterwards, even if an exception is raised.

    Args:
        net: segmentation model exposing ``n_classes``. When ``is_gpus`` is
            True the attribute is read from ``net.module`` (presumably a
            ``DataParallel``-style wrapper).
        dataloader: sized iterable (``len()`` is called) of dicts with
            ``'image'`` and ``'mask'`` tensors. The mask is cast to ``long``
            and treated as class labels.
        device: device the image and mask are moved to.
        is_gpus: read model attributes from ``net.module`` instead of ``net``.

    Returns:
        A 7-tuple ``(class_iou, miou, dice_softmax_nobg, dice_softmax_bg,
        dice_onehot_nobg, dice_onehot_bg, accuracy)``:

        * Every entry except ``accuracy`` is the sum of per-batch values
          divided by the number of batches.
        * ``*_nobg`` entries skip channel 0 (background).
        * ``softmax`` entries score the softmax probabilities; ``onehot``
          entries score the one-hot argmax prediction.
        * ``class_iou`` and ``miou`` come from ``MIOU`` on the
          non-background channels.
        * ``accuracy`` is the fraction of equal one-hot elements over the
          non-background channels. Every (pixel, class) element counts, so
          true negatives contribute.

        For a single-output network (``n_classes == 1``) the prediction is
        ``sigmoid > 0.5``. There is no softmax and no background channel, so
        all four Dice entries hold the Dice of that thresholded mask, and
        IoU and accuracy are computed on the single channel.

        An empty ``dataloader`` yields zeros instead of raising
        ``ZeroDivisionError``.
    """
    model = net.module if is_gpus else net
    output_classes = model.n_classes
    num_val_batches = len(dataloader)

    dice_binary = 0
    dice_softmax_nobg = 0
    dice_softmax_bg = 0
    dice_onehot_nobg = 0
    dice_onehot_bg = 0
    miou_eva = 0
    class_iou_eva = []
    num_correct = 0
    num_pixels = 0

    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader,
                              total=num_val_batches,
                              desc='Validation round',
                              unit='batch',
                              leave=False):
                image = batch['image'].to(device=device, dtype=torch.float32)
                mask_true = batch['mask'].to(device=device, dtype=torch.long)
                mask_pred = net(image)

                if output_classes == 1:
                    # F.one_hot(..., num_classes=1) rejects label 1, so a
                    # sigmoid model is scored against the raw mask, shaped
                    # (B, 1, H, W) to match its prediction.
                    if mask_true.dim() == 3:
                        mask_true = mask_true.unsqueeze(1)
                    mask_true = mask_true.float()
                    mask_pred_binary = (torch.sigmoid(mask_pred) > 0.5).float()
                    dice_binary += dice_coeff(mask_pred_binary,
                                              mask_true,
                                              reduce_batch_first=False)
                    # Single channel: there is no background channel to drop.
                    scored_pred, scored_true = mask_pred_binary, mask_true
                else:
                    mask_true = F.one_hot(torch.squeeze(mask_true, dim=1),
                                          output_classes).permute(0, 3, 1, 2).float()
                    mask_pred_softmax = F.softmax(mask_pred, dim=1).float()
                    mask_pred_onehot = F.one_hot(mask_pred_softmax.argmax(dim=1),
                                                 output_classes).permute(0, 3, 1, 2).float()

                    # Dice on softmax probabilities: without / with background.
                    dice_softmax_nobg += multiclass_dice_coeff(
                        mask_pred_softmax[:, 1:, ...],
                        mask_true[:, 1:, ...],
                        reduce_batch_first=True)
                    dice_softmax_bg += multiclass_dice_coeff(
                        mask_pred_softmax, mask_true, reduce_batch_first=True)

                    # Dice on the one-hot argmax: without / with background.
                    dice_onehot_nobg += multiclass_dice_coeff(
                        mask_pred_onehot[:, 1:, ...],
                        mask_true[:, 1:, ...],
                        reduce_batch_first=True)
                    dice_onehot_bg += multiclass_dice_coeff(
                        mask_pred_onehot, mask_true, reduce_batch_first=True)

                    # Accuracy and IoU ignore background channel 0.
                    scored_pred = mask_pred_onehot[:, 1:, ...]
                    scored_true = mask_true[:, 1:, ...]

                num_correct += (scored_pred == scored_true).sum()
                num_pixels += torch.numel(scored_pred)

                class_iou, miou = MIOU(scored_pred, scored_true)
                miou_eva += miou
                # (Re)size the per-class accumulator whenever the class count
                # differs from it, which is always the case on the first batch.
                if len(class_iou) != len(class_iou_eva):
                    class_iou_eva = [0] * len(class_iou)
                class_iou_eva = [acc + iou
                                 for acc, iou in zip(class_iou_eva, class_iou)]
    finally:
        net.train(was_training)

    if output_classes == 1:
        dice_softmax_nobg = dice_softmax_bg = dice_binary
        dice_onehot_nobg = dice_onehot_bg = dice_binary

    # Guard the averages against an empty dataloader.
    batches = max(num_val_batches, 1)
    accuracy = num_correct / num_pixels if num_pixels else 0.0
    return (np.array(class_iou_eva) / batches,
            miou_eva / batches,
            dice_softmax_nobg / batches,
            dice_softmax_bg / batches,
            dice_onehot_nobg / batches,
            dice_onehot_bg / batches,
            accuracy)