-
Notifications
You must be signed in to change notification settings - Fork 10
/
losses.py
173 lines (147 loc) · 5.3 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import math

import tensorflow as tf
from keras import backend as K
from keras import regularizers
from keras.layers import Layer
"""
=========================
SphereFace
=========================
"""
class SphereFace(Layer):
    """SphereFace (A-Softmax) classification head with a multiplicative angular margin.

    The layer is called on ``[features, labels]`` and returns softmax
    probabilities over ``n_classes``. Features and class-weight columns are
    both L2-normalised, so the raw logits are cosines ``cos(theta)``. For the
    target class the cosine is replaced by the monotonic margin function
    ``psi(theta) = (-1)**k * cos(m * theta) - 2 * k`` with
    ``k = floor(m * theta / pi)``. All logits are then multiplied by ``s``.

    Args:
        n_classes: Number of output classes.
        s: Scale applied to the margin-adjusted cosine logits.
        m: Multiplicative angular margin.
        regularizer: Regularizer for the class-weight matrix; any value
            accepted by ``keras.regularizers.get``.
    """

    def __init__(self, n_classes=10, s=30.0, m=1.35, regularizer=None, **kwargs):
        super(SphereFace, self).__init__(**kwargs)
        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.regularizer = regularizers.get(regularizer)

    def build(self, input_shape):
        # input_shape is [features_shape, labels_shape]; only the feature
        # dimension determines the weight shape.
        super(SphereFace, self).build(input_shape[0])
        self.W = self.add_weight(name='W',
                                 shape=(input_shape[0][-1], self.n_classes),
                                 initializer='glorot_uniform',
                                 trainable=True,
                                 regularizer=self.regularizer)

    def call(self, inputs):
        """Return class probabilities with the SphereFace margin on the target class.

        Args:
            inputs: ``[x, y]``. ``x`` is assumed to be (batch, features) and
                ``y`` one-hot labels broadcastable to (batch, n_classes)
                (TODO confirm against callers).
        """
        x, y = inputs
        x = tf.nn.l2_normalize(x, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)
        logits = x @ W  # cosine similarity to every class centre
        # Integer one-hot labels would otherwise fail in the float arithmetic below.
        y = tf.cast(y, logits.dtype)

        # Clip away from +-1 so the gradient of acos stays finite.
        theta = tf.acos(tf.clip_by_value(logits,
                                         -1.0 + K.epsilon(),
                                         1.0 - K.epsilon()))
        # cos(m * theta) alone is not monotonic on [0, pi] for m > 1; the
        # piecewise psi(theta) from the A-Softmax paper keeps the target logit
        # decreasing as theta grows and is continuous at every k * pi / m.
        m_theta = self.m * theta
        k = tf.floor(m_theta / math.pi)
        parity = k - 2.0 * tf.floor(k / 2.0)  # k mod 2
        sign = 1.0 - 2.0 * parity  # (-1) ** k
        target_logits = sign * tf.cos(m_theta) - 2.0 * k

        logits = logits * (1 - y) + target_logits * y
        logits *= self.s
        return tf.nn.softmax(logits)

    def compute_output_shape(self, input_shape):
        return (None, self.n_classes)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/reloaded."""
        config = super(SphereFace, self).get_config()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'regularizer': regularizers.serialize(self.regularizer),
        })
        return config
"""
=========================
CosFace
=========================
"""
class CosFace(Layer):
    """CosFace (large margin cosine) classification head.

    The layer is called on ``[features, labels]`` and returns softmax
    probabilities over ``n_classes``. Features and class-weight columns are
    both L2-normalised, so the raw logits are cosines. The additive margin
    ``m`` is subtracted from the target-class cosine, and all logits are then
    multiplied by ``s``.

    Args:
        n_classes: Number of output classes.
        s: Scale applied to the margin-adjusted cosine logits.
        m: Additive cosine margin.
        regularizer: Regularizer for the class-weight matrix; any value
            accepted by ``keras.regularizers.get``.
    """

    def __init__(self, n_classes=10, s=30.0, m=0.35, regularizer=None, **kwargs):
        super(CosFace, self).__init__(**kwargs)
        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.regularizer = regularizers.get(regularizer)

    def build(self, input_shape):
        # input_shape is [features_shape, labels_shape]; only the feature
        # dimension determines the weight shape.
        super(CosFace, self).build(input_shape[0])
        self.W = self.add_weight(name='W',
                                 shape=(input_shape[0][-1], self.n_classes),
                                 initializer='glorot_uniform',
                                 trainable=True,
                                 regularizer=self.regularizer)

    def call(self, inputs):
        """Return class probabilities with the cosine margin on the target class.

        Args:
            inputs: ``[x, y]``. ``x`` is assumed to be (batch, features) and
                ``y`` one-hot labels broadcastable to (batch, n_classes)
                (TODO confirm against callers).
        """
        x, y = inputs
        x = tf.nn.l2_normalize(x, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)
        logits = x @ W  # cosine similarity to every class centre
        # Integer one-hot labels would otherwise fail in the float arithmetic below.
        y = tf.cast(y, logits.dtype)

        target_logits = logits - self.m
        logits = logits * (1 - y) + target_logits * y
        logits *= self.s
        return tf.nn.softmax(logits)

    def compute_output_shape(self, input_shape):
        return (None, self.n_classes)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/reloaded."""
        config = super(CosFace, self).get_config()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'regularizer': regularizers.serialize(self.regularizer),
        })
        return config
"""
=========================
ArcFace
=========================
"""
class ArcFace(Layer):
    """ArcFace (additive angular margin) classification head.

    The layer is called on ``[features, labels]`` and returns softmax
    probabilities over ``n_classes``. Features and class-weight columns are
    both L2-normalised, so the raw logits are cosines ``cos(theta)``. For the
    target class the logit becomes ``cos(theta + m)`` while
    ``theta + m < pi``, and ``cos(theta) - m * sin(m)`` beyond that point.
    All logits are then multiplied by ``s``.

    Args:
        n_classes: Number of output classes.
        s: Scale applied to the margin-adjusted cosine logits.
        m: Additive angular margin in radians.
        regularizer: Regularizer for the class-weight matrix; any value
            accepted by ``keras.regularizers.get``.
    """

    def __init__(self, n_classes=10, s=30.0, m=0.50, regularizer=None, **kwargs):
        super(ArcFace, self).__init__(**kwargs)
        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.regularizer = regularizers.get(regularizer)

    def build(self, input_shape):
        # input_shape is [features_shape, labels_shape]; only the feature
        # dimension determines the weight shape.
        super(ArcFace, self).build(input_shape[0])
        self.W = self.add_weight(name='W',
                                 shape=(input_shape[0][-1], self.n_classes),
                                 initializer='glorot_uniform',
                                 trainable=True,
                                 regularizer=self.regularizer)

    def call(self, inputs):
        """Return class probabilities with the angular margin on the target class.

        Args:
            inputs: ``[x, y]``. ``x`` is assumed to be (batch, features) and
                ``y`` one-hot labels broadcastable to (batch, n_classes)
                (TODO confirm against callers).
        """
        x, y = inputs
        x = tf.nn.l2_normalize(x, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)
        logits = x @ W  # cosine similarity to every class centre
        # Integer one-hot labels would otherwise fail in the float arithmetic below.
        y = tf.cast(y, logits.dtype)

        # Clip away from +-1 so the gradient of acos stays finite.
        theta = tf.acos(tf.clip_by_value(logits,
                                         -1.0 + K.epsilon(),
                                         1.0 - K.epsilon()))
        # cos(theta + m) only decreases in theta while theta + m < pi; past
        # that point fall back to the additive penalty cos(theta) - m*sin(m)
        # (as in the reference InsightFace implementation) so the target
        # logit keeps decreasing as the sample moves away from its class.
        target_logits = tf.where(theta + self.m < math.pi,
                                 tf.cos(theta + self.m),
                                 logits - self.m * math.sin(self.m))

        logits = logits * (1 - y) + target_logits * y
        logits *= self.s
        return tf.nn.softmax(logits)

    def compute_output_shape(self, input_shape):
        return (None, self.n_classes)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved/reloaded."""
        config = super(ArcFace, self).get_config()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'regularizer': regularizers.serialize(self.regularizer),
        })
        return config
"""
=========================
Circle loss
=========================
"""
def circle_loss(y_true,
                y_pred,
                gamma: float = 256,
                margin: float = 0.25,):
    """Circle loss for single-label classification, computed as softmax cross-entropy.

    ``y_pred`` is treated as a matrix of similarity scores (presumably cosine
    similarities to class centres; confirm with callers). ``y_true`` is
    assumed to be one-hot and broadcastable to ``y_pred``.

    Positive scores ``s_p`` become the logit ``gamma * alpha_p * (s_p - Delta_p)``.
    Negative scores ``s_n`` become the logit ``gamma * alpha_n * (s_n - Delta_n)``.
    The weights are ``alpha_p = relu(1 + margin - s_p)`` and
    ``alpha_n = relu(s_n + margin)``, and are held constant for
    back-propagation.

    Args:
        y_true: One-hot targets.
        y_pred: Similarity scores.
        gamma: Scale factor applied to the re-weighted logits.
        margin: Relaxation margin; sets the optima (1 + margin, -margin) and
            the decision margins (1 - margin, margin).

    Returns:
        The per-sample loss tensor from
        ``tf.nn.softmax_cross_entropy_with_logits``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Cast labels to the score dtype instead of hard-coding float32, so that
    # float16 (mixed precision) or float64 scores do not hit a dtype mismatch.
    y_true = tf.cast(y_true, y_pred.dtype)

    # Optimal scores and decision margins for positive / negative pairs.
    O_p = 1 + margin
    O_n = -margin
    Delta_p = 1 - margin
    Delta_n = margin

    # Self-paced weights are constants: gradients flow only through the
    # (y_pred - Delta) terms.
    frozen_scores = tf.stop_gradient(y_pred)
    alpha_p = tf.nn.relu(O_p - frozen_scores)
    alpha_n = tf.nn.relu(frozen_scores - O_n)

    logits = (y_true * (alpha_p * (y_pred - Delta_p)) +
              (1 - y_true) * (alpha_n * (y_pred - Delta_n))) * gamma
    return tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=logits)