-
Notifications
You must be signed in to change notification settings - Fork 0
/
metric.py
49 lines (40 loc) · 1.68 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torchmetrics import Metric
import torch
from sklearn.metrics import f1_score
class MyAccuracy(Metric):
    """Multiclass accuracy accumulated over batches.

    Counts predicted labels that equal the target labels and divides by the
    number of target elements seen. Both counters are registered as metric
    states with ``dist_reduce_fx="sum"``, so torchmetrics sums them across
    processes when the metric is synchronized.
    """

    # update() only adds to sum-reducible counters, so torchmetrics can use
    # its cheaper reduce-state forward path instead of calling update() twice.
    full_state_update = False

    def __init__(self):
        super().__init__(dist_sync_on_step=False)
        # Running number of target elements seen so far.
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        # Running number of elements whose predicted label equals the target.
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        """Accumulate one batch of predictions.

        Args:
            preds: scores whose class axis is ``dim=1`` (reduced with argmax),
                or integer predicted labels with exactly ``target``'s shape
                (used as-is).
            target: ground-truth class labels.

        Raises:
            ValueError: if the label predictions and ``target`` differ in shape.
        """
        # Only reduce scores to labels when needed: integer label tensors that
        # already line up with target would otherwise break in argmax(dim=1).
        if preds.is_floating_point() or preds.shape != target.shape:
            preds = preds.argmax(dim=1)
        if preds.shape != target.shape:
            raise ValueError("Predictions and targets must have the same shape")
        self.correct += (preds == target).sum()
        self.total += target.numel()

    def compute(self):
        """Return ``correct / total`` as a float tensor.

        With no accumulated elements this is ``0 / 0`` and evaluates to NaN.
        """
        return self.correct.float() / self.total.float()
class MyF1Score(Metric):
    """Per-class (one-vs-rest) F1 score accumulated over batches.

    Every batch's predicted and target labels are stored. At ``compute()``
    time each class in ``range(num_classes)`` is scored as a binary problem
    with scikit-learn's ``f1_score``.
    """

    # update() only appends to "cat"-reducible list states, so torchmetrics
    # can use its cheaper reduce-state forward path.
    full_state_update = False

    def __init__(self, num_classes):
        super().__init__(dist_sync_on_step=False)
        # Classes scored by compute(); labels outside this range are not scored.
        self.num_classes = num_classes
        # Per-batch label tensors, concatenated across processes on sync.
        self.add_state("all_targets", default=[], dist_reduce_fx="cat")
        self.add_state("all_preds", default=[], dist_reduce_fx="cat")

    def update(self, preds, target):
        """Store one batch of predicted and target labels.

        Args:
            preds: scores whose class axis is ``dim=1`` (reduced with argmax),
                or integer predicted labels with exactly ``target``'s shape
                (stored as-is).
            target: ground-truth class labels.
        """
        # Integer label tensors that already match target skip the argmax,
        # which would otherwise fail on them.
        if preds.is_floating_point() or preds.shape != target.shape:
            preds = torch.argmax(preds, dim=1)
        self.all_preds.append(preds)
        self.all_targets.append(target)

    @staticmethod
    def _concat_state(state):
        """Return a "cat" state as one tensor, whether or not it was synced."""
        # Before a distributed sync the state is a list of per-batch tensors.
        # After sync torchmetrics has already concatenated it into a single
        # tensor, which torch.cat cannot accept directly.
        if isinstance(state, torch.Tensor):
            return state
        return torch.cat(state)

    def compute(self):
        """Return a dict mapping each class index to its binary F1 score.

        ``zero_division=1`` is passed to scikit-learn, which per its
        documentation yields 1.0 for a class with no true or predicted samples.

        Raises:
            RuntimeError: if called before any ``update()``.
        """
        if isinstance(self.all_preds, list) and not self.all_preds:
            raise RuntimeError("MyF1Score.compute() called before any update()")
        # Move to CPU once rather than once per class, since sklearn needs numpy.
        y_pred = self._concat_state(self.all_preds).cpu()
        y_true = self._concat_state(self.all_targets).cpu()
        return {
            cls: f1_score(
                (y_true == cls).int().numpy(),
                (y_pred == cls).int().numpy(),
                zero_division=1,
            )
            for cls in range(self.num_classes)
        }
# Usage example: an accuracy metric and a per-class F1 metric for a
# three-class problem.
num_classes = 3
accuracy_metric, f1_score_metric = MyAccuracy(), MyF1Score(num_classes=num_classes)