metrics.py
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import paddle
import numpy as np


def calc_accuracy_score(true_labels, pred_labels):
    """Return the fraction of predictions that match the true labels."""
    assert len(true_labels) == len(pred_labels)
    num = 0
    for i in range(len(true_labels)):
        if int(true_labels[i]) == int(pred_labels[i]):
            num += 1
    return num / len(true_labels)


def calc_f1_score(true_labels, pred_labels):
    # Binary F1, precision, and recall computed with scikit-learn.
    return f1_score(true_labels, pred_labels), precision_score(true_labels, pred_labels), recall_score(true_labels, pred_labels)
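

# Worked example (assumed toy values): with true_labels = [0, 1, 1, 0] and
# pred_labels = [0, 1, 0, 0], calc_accuracy_score returns 0.75 and
# calc_f1_score returns approximately (0.667, 1.0, 0.5), i.e. the binary
# F1, precision, and recall as computed by scikit-learn.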


# This function is adapted from a Baidu AI Studio project:
# <https://aistudio.baidu.com/aistudio/projectdetail/1968542>
@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
    """
    Evaluates the model on the given dataset and computes the metrics.

    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        criterion(obj:`paddle.nn.Layer`): It can compute the loss.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
    """
    model.eval()
    metric.reset()
    losses = []
    result = []
    full_target = []
    for batch in data_loader:
        input_ids, token_type_ids, labels = batch
        logits = model(input_ids, token_type_ids)
        loss = criterion(logits, labels)
        losses.append(loss.numpy())
        # Collect per-batch predictions and targets for the sklearn metrics.
        pred = logits.argmax(1)
        result.extend(pred.cpu().tolist())
        full_target.extend(labels.cpu().tolist())
        correct = metric.compute(logits, labels)
        metric.update(correct)
    # Aggregate the paddle metric and the macro-averaged sklearn scores.
    accu = metric.accumulate()
    precision = precision_score(full_target, result, average='macro')
    recall = recall_score(full_target, result, average='macro')
    f1 = f1_score(full_target, result, average='macro')
    print("eval loss: %.5f, accu: %.5f, F1: %.4f, Precision: %.4f, Recall: %.4f"
          % (np.mean(losses), accu, f1, precision, recall))
    model.train()
    metric.reset()
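

# A minimal usage sketch with assumed toy values; `model` and `dev_data_loader`
# for evaluate() are hypothetical objects that would come from the training
# script, so that call is only shown commented out.
if __name__ == "__main__":
    y_true = [0, 1, 1, 0]
    y_pred = [0, 1, 0, 0]
    print(calc_accuracy_score(y_true, y_pred))  # 0.75
    print(calc_f1_score(y_true, y_pred))        # approximately (0.667, 1.0, 0.5)
    # evaluate(model, paddle.nn.CrossEntropyLoss(), paddle.metric.Accuracy(), dev_data_loader)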