-
Notifications
You must be signed in to change notification settings - Fork 73
/
contrastive.py
72 lines (51 loc) · 1.59 KB
/
contrastive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Time: 2021-10-13 4:01 下午
Author: huayang
Subject:
"""
import os
import sys
import json
import doctest
from typing import *
from collections import defaultdict
from torch.nn import functional as F # noqa
from my.pytorch.backend.distance_fn import euclidean_distance
from my.pytorch.loss.base import BaseLoss
# Names exported by a star-import of this module: only the loss class is listed;
# the functional form `contrastive_loss` is not.
__all__ = [
    'ContrastiveLoss'
]
def contrastive_loss(x1, x2, labels, distance_fn=euclidean_distance, margin=2.0):
    """Contrastive loss over pairs of inputs (expects 0 <= label <= 1).

    - When the label is 1 (the pair is similar), a larger distance yields a larger loss.
    - When the label is 0 (the pair is dissimilar), a distance smaller than
      ``margin`` yields a larger loss.

    Args:
        x1: first element of each pair; passed unchanged to ``distance_fn``.
        x2: second element of each pair; passed unchanged to ``distance_fn``.
        labels: similarity labels; converted with ``.float()`` before use.
        distance_fn: callable ``(x1, x2) -> distances``; defaults to the Euclidean distance.
        margin: threshold for dissimilar pairs; should be tuned to the distance
            function in use.

    Returns:
        ``0.5 * (labels * d**2 + (1 - labels) * relu(margin - d)**2)`` where ``d`` is
        the output of ``distance_fn``; no reduction is applied in this function.
    """
    # Labels are converted first, then the distance is computed (same order of calls as before).
    similar_weight = labels.float()
    dist = distance_fn(x1, x2)

    # Similar pairs are penalised by their squared distance.
    pull_term = similar_weight * dist.pow(2)
    # Dissimilar pairs are penalised only while they are closer than the margin.
    push_term = (1 - similar_weight) * F.relu(margin - dist).pow(2)

    return 0.5 * (pull_term + push_term)
class ContrastiveLoss(BaseLoss):
    """@Pytorch Loss
    Contrastive loss (the default distance function is the Euclidean distance).
    """

    def __init__(self, distance_fn=euclidean_distance, margin=1.0, **kwargs):
        """Store the loss settings and initialise the parent class.

        Args:
            distance_fn: callable ``(x1, x2) -> distances`` used by ``compute_loss``.
            margin: margin handed to ``contrastive_loss``. The default here is 1.0,
                whereas the functional ``contrastive_loss`` in this module defaults to 2.0.
            **kwargs: forwarded unchanged to the parent class initialiser.
        """
        # Attributes are set before the parent initialiser runs; keep this order in
        # case the parent's __init__ relies on them (not visible from this file).
        self.distance_fn = distance_fn
        self.margin = margin
        super().__init__(**kwargs)

    def compute_loss(self, x1, x2, labels):
        """Return ``contrastive_loss`` for the given pairs using this instance's settings."""
        return contrastive_loss(
            x1,
            x2,
            labels,
            margin=self.margin,
            distance_fn=self.distance_fn,
        )
def _test():
""""""
doctest.testmod()
# Script entry point: run the doctest helper only when executed directly.
if __name__ == '__main__':
    _test()