
Commit cd34e5c

sovrasov authored and AlexanderDokuchaev committed

Update face recognition scripts (openvinotoolkit#134)

* Face recognition: make all models accessible via config
* Update face recognition readme
* Face recognition: add new config, fixes in scripts
* Add link to mobilenetv2_2x
1 parent df1a4db commit cd34e5c

14 files changed (+142 −106 lines)

pytorch_toolkit/face_recognition/README.md

+20 −3
@@ -36,14 +36,16 @@ cd $FR_ROOT/
 ```
 
 2. To start training FR model:
+
 ```bash
 python train.py --train_data_root $VGGFace2_ROOT/train/ --train_list $VGGFace2_ROOT/meta/train_list.txt
 --train_landmarks $VGGFace2_ROOT/bb_landmark/ --val_data_root $LFW_ROOT/lfw/ --val_list $LFW_ROOT/pairs.txt
 --val_landmarks $LFW_ROOT/lfw_landmark.txt --train_batch_size 200 --snap_prefix mobilenet_256 --lr 0.35
---embed_size 256 --model mobilenet --device 1
+--embed_size 256 --model mobilenetv2 --device 1
 ```
 
 3. To evaluate FR snapshot (let's say we have MobileNet with 256 embedding size trained for 300k):
+
 ```bash
 python evaluate_lfw.py --val_data_root $LFW_ROOT/lfw/ --val_list $LFW_ROOT/pairs.txt
 --val_landmarks $LFW_ROOT/lfw_landmark.txt --snap /path/to/snapshot/mobilenet_256_300000.pt --model mobilenet --embed_size 256
@@ -62,7 +64,7 @@ margin_type: cos
 s: 30
 m: 0.35
 #model parameters
-model: mobilenet
+model: mobilenetv2
 embed_size: 256
 #misc
 snap_prefix: MobileFaceNet
@@ -81,14 +83,17 @@ python train.py -m 0.35 @./my_config.yml #here m can be overwritten with the val
 
 ## Models
 
-1. You can download pretrained model from fileshare as well - https://download.01.org/opencv/openvino_training_extensions/models/face_recognition/Mobilenet_se_focal_121000.pt
+1. You can download pretrained models from the fileshare as well - [mobilenetv2](https://download.01.org/opencv/openvino_training_extensions/models/face_recognition/Mobilenet_se_focal_121000.pt),
+[mobilenetv2_2x](https://download.01.org/opencv/openvino_training_extensions/models/face_recognition/Mobilenet_2x_se_121000.pt).
+
 ```bash
 cd $FR_ROOT
 python evaluate_lfw.py --val_data_root $LFW_ROOT/lfw/ --val_list $LFW_ROOT/pairs.txt --val_landmarks $LFW_ROOT/lfw_landmark.txt
 --snap /path/to/snapshot/Mobilenet_se_focal_121000.pt --model mobilenet --embed_size 256
 ```
 
 2. You should get the following output:
+- for `mobilenetv2`:
 ```
 I1114 09:33:37.846870 10544 evaluate_lfw.py:242] Accuracy/Val_same_accuracy mean: 0.9923
 I1114 09:33:37.847019 10544 evaluate_lfw.py:243] Accuracy/Val_diff_accuracy mean: 0.9970
@@ -97,6 +102,18 @@ I1114 09:33:37.847179 10544 evaluate_lfw.py:245] Accuracy/Val_accuracy std dev:
 I1114 09:33:37.847229 10544 evaluate_lfw.py:246] AUC: 0.9995
 I1114 09:33:37.847305 10544 evaluate_lfw.py:247] Estimated threshold: 0.7241
 ```
+- for `mobilenetv2_2x`:
+```
+I0820 15:48:06.307454 23328 evaluate_lfw.py:262] Accuracy/Val_same_accuracy mean: 0.9893
+I0820 15:48:06.307612 23328 evaluate_lfw.py:263] Accuracy/Val_diff_accuracy mean: 0.9990
+I0820 15:48:06.307647 23328 evaluate_lfw.py:264] Accuracy/Val_accuracy mean: 0.9942
+I0820 15:48:06.307732 23328 evaluate_lfw.py:265] Accuracy/Val_accuracy std dev: 0.0061
+I0820 15:48:06.307766 23328 evaluate_lfw.py:266] AUC: 0.9992
+I0820 15:48:06.307812 23328 evaluate_lfw.py:267] Estimated threshold: 0.6721
+```
+
+`mobilenetv2_2x` scores slightly worse on the LFW benchmark than `mobilenetv2`, but this heavier model achieves a higher score on the
+uncleaned version of the [MegaFace](http://megaface.cs.washington.edu/participate/challenge.html) benchmark: 73.77% rank-1 at 1M distractors in the reidentification protocol vs 70.2%.
 
 ## Face Recognition Demo
 
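The downloaded snapshots are ordinary PyTorch checkpoint files consumed through `--snap`. A minimal sketch of inspecting one by hand, assuming it holds either a bare state dict or a wrapper dict with a `state_dict` key (the exact layout is not shown in this diff):

```python
import torch

# Hedged sketch: peek into a downloaded snapshot. Whether the file
# stores a bare state_dict or a wrapper dict is an assumption here.
snapshot = torch.load('Mobilenet_se_focal_121000.pt', map_location='cpu')
if isinstance(snapshot, dict):
    state_dict = snapshot.get('state_dict', snapshot)
    print('checkpoint holds', len(state_dict), 'entries')
```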
pytorch_toolkit/face_recognition/configs/ (new file)

+16 −0
@@ -0,0 +1,16 @@
+#optimizer parameters
+lr: 0.4
+train_batch_size: 256
+#loss options
+margin_type: cos
+s: 30
+m: 0.35
+mining_type: sv
+t: 1.1
+#model parameters
+model: mobilenetv2_2x
+embed_size: 256
+
+train_dataset: vgg
+snap_prefix: MobileFaceNet
+devices: [0, 1]
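This new config mirrors `mobilefacenet_vgg2.yml` below but selects the heavier `mobilenetv2_2x` backbone and trains on two devices. Per the README section above, `train.py` consumes such files through the `@` prefix (`python train.py @./my_config.yml`), with explicitly passed CLI flags overriding config values. A minimal sketch of reading the same YAML by hand; the file path is hypothetical since the new file's name is not shown:

```python
import yaml  # PyYAML

# Hypothetical path for the new mobilenetv2_2x config shown above.
with open('configs/mobilenetv2_2x_vgg2.yml') as f:
    cfg = yaml.safe_load(f)

print(cfg['model'])         # 'mobilenetv2_2x'
print(cfg['devices'])       # [0, 1]
print(cfg['lr'], cfg['m'])  # 0.4 0.35
```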

pytorch_toolkit/face_recognition/configs/mobilefacenet_vgg2.yml

+1 −1
@@ -8,7 +8,7 @@ m: 0.35
 mining_type: sv
 t: 1.1
 #model parameters
-model: mobilenet
+model: mobilenetv2
 embed_size: 256
 
 train_dataset: vgg

pytorch_toolkit/face_recognition/dump_features.py

+1 −1
@@ -120,7 +120,7 @@ def main(args):
 
     emb_array = np.zeros((nrof_images, args.embedding_size), dtype=np.float32)
 
-    dataset.transform = t.Compose([ResizeNumpy(models_backbones[args.model].get_input_res()),
+    dataset.transform = t.Compose([ResizeNumpy(models_backbones[args.model]().get_input_res()),
                                    NumpyToTensor(switch_rb=True)])
     val_loader = DataLoader(dataset, batch_size=args.batch_size, num_workers=5, shuffle=False)
 
pytorch_toolkit/face_recognition/evaluate_lfw.py

+1 −1
@@ -271,7 +271,7 @@ def evaluate(args, dataset, model, compute_embeddings_fun, val_batch_size=16,
 
 def load_test_dataset(arguments):
     """Loads and configures the LFW dataset"""
-    input_size = models_backbones[arguments.model].get_input_res()
+    input_size = models_backbones[arguments.model]().get_input_res()
     lfw = LFW(arguments.val, arguments.v_list, arguments.v_land)
     assert lfw.use_landmarks
     log.info('Using landmarks for the LFW images.')
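This one-character fix, applied here and in `dump_features.py` above, reflects that `models_backbones` maps model names to backbone classes rather than instances, so an object must be constructed before `get_input_res()` can be called. A minimal sketch of the failure mode, with stub registry contents standing in for the toolkit's real classes:

```python
import torch.nn as nn

class MobileNetV2Stub(nn.Module):
    """Stand-in for a toolkit backbone; the resolution is made up."""
    def get_input_res(self):
        return 128, 128

models_backbones = {'mobilenetv2': MobileNetV2Stub}  # name -> class

# Broken: calls the method on the class itself, so `self` is missing
# and Python raises a TypeError.
# models_backbones['mobilenetv2'].get_input_res()

# Fixed: instantiate first, exactly what the added () does above.
input_size = models_backbones['mobilenetv2']().get_input_res()
```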

pytorch_toolkit/face_recognition/model/backbones/resnet.py

+12 −3
@@ -18,7 +18,7 @@
 
 
 class ResNet(nn.Module):
-    def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU):
+    def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU, head=False):
         self.inplanes = 64
         super(ResNet, self).__init__()
         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
@@ -32,7 +32,12 @@ def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU):
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, activation=activation)
         self.avgpool = nn.Conv2d(512 * block.expansion, 512 * block.expansion, 7,
                                  groups=512 * block.expansion, bias=False)
-        self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False)
+        self.head = head
+        if not self.head:
+            self.output_channels = 512 * block.expansion
+        else:
+            self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False)
+            self.output_channels = num_classes
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -70,10 +75,14 @@ def forward(self, x):
         x = self.layer4(x)
 
         x = self.avgpool(x)
-        x = self.fc(x)
+        if self.head:
+            x = self.fc(x)
 
         return x
 
+    def get_output_channels(self):
+        return self.output_channels
+
 
 def resnet50(**kwargs):
     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
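The new `head` flag (added to `SEResNet` and `SEResNeXt` below as well) turns the backbone into a headless feature extractor by default: the 1x1-conv classifier `fc` is built and applied only when `head=True`, and `get_output_channels()` exposes the feature width so downstream code can size its own layers. A sketch of how an embedding head might consume this interface; the `EmbeddingHead` wrapper is illustrative, not the toolkit's actual code:

```python
import torch.nn as nn

class EmbeddingHead(nn.Module):
    """Illustrative wrapper (not toolkit code): projects headless
    backbone features to a fixed-size embedding."""
    def __init__(self, backbone, embed_size=256):
        super(EmbeddingHead, self).__init__()
        self.backbone = backbone
        # get_output_channels() reports 512 * block.expansion when
        # head=False, so the projection adapts to any block type.
        self.project = nn.Conv2d(backbone.get_output_channels(),
                                 embed_size, 1, bias=False)

    def forward(self, x):
        return self.project(self.backbone(x))

# Usage sketch: resnet50() now defaults to head=False, so it returns
# pooled feature maps rather than class scores.
# model = EmbeddingHead(resnet50(), embed_size=256)
```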

pytorch_toolkit/face_recognition/model/backbones/se_resnet.py

+13 −5
@@ -12,16 +12,15 @@
 """
 
 import math
-
 import torch.nn as nn
 
 from model.blocks.se_resnet_blocks import SEBottleneck
 
 
 class SEResNet(nn.Module):
-    def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU):
-        self.inplanes = 64
+    def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU, head=False):
         super(SEResNet, self).__init__()
+        self.inplanes = 64
         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                                bias=False)
         self.bn1 = nn.BatchNorm2d(64)
@@ -33,7 +32,12 @@ def __init__(self, block, layers, num_classes=1000, activation=nn.ReLU):
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, activation=activation)
         self.avgpool = nn.Conv2d(512 * block.expansion, 512 * block.expansion, 7,
                                  groups=512 * block.expansion, bias=False)
-        self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False)
+        self.head = head
+        if not self.head:
+            self.output_channels = 512 * block.expansion
+        else:
+            self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False)
+            self.output_channels = num_classes
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -72,10 +76,14 @@ def forward(self, x):
         x = self.layer4(x)
 
         x = self.avgpool(x)
-        x = self.fc(x)
+        if self.head:
+            x = self.fc(x)
 
         return x
 
+    def get_output_channels(self):
+        return self.output_channels
+
 
 def se_resnet50(**kwargs):
     model = SEResNet(SEBottleneck, [3, 4, 6, 3], **kwargs)

pytorch_toolkit/face_recognition/model/backbones/se_resnext.py

+22 −16
@@ -18,25 +18,29 @@
 
 
 class SEResNeXt(nn.Module):
-
-    def __init__(self, block, layers, cardinality=32, num_classes=1000):
+    def __init__(self, block, layers, cardinality=32, num_classes=1000, activation=nn.ReLU, head=False):
         super(SEResNeXt, self).__init__()
         self.cardinality = cardinality
         self.inplanes = 64
 
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3,
                                bias=False)
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
 
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(512 * block.expansion, num_classes)
+        self.layer1 = self._make_layer(block, 64, layers[0], activation=activation)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, activation=activation)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, activation=activation)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, activation=activation)
+        self.avgpool = nn.Conv2d(512 * block.expansion, 512 * block.expansion, 7,
+                                 groups=512 * block.expansion, bias=False)
+        self.head = head
+        if not self.head:
+            self.output_channels = 512 * block.expansion
+        else:
+            self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1, stride=1, padding=0, bias=False)
+            self.output_channels = num_classes
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -48,7 +52,7 @@ def __init__(self, block, layers, cardinality=32, num_classes=1000):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
 
-    def _make_layer(self, block, planes, blocks, stride=1):
+    def _make_layer(self, block, planes, blocks, stride=1, activation=nn.ReLU):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
@@ -58,10 +62,10 @@ def _make_layer(self, block, planes, blocks, stride=1):
             )
 
         layers = []
-        layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
+        layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample, activation=activation))
         self.inplanes = planes * block.expansion
         for _ in range(1, blocks):
-            layers.append(block(self.inplanes, planes, self.cardinality))
+            layers.append(block(self.inplanes, planes, self.cardinality, activation=activation))
 
         return nn.Sequential(*layers)
 
@@ -77,12 +81,14 @@ def forward(self, x):
         x = self.layer4(x)
 
         x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-
-        x = self.fc(x)
+        if self.head:
+            x = self.fc(x)
 
         return x
 
+    def get_output_channels(self):
+        return self.output_channels
+
 
 def se_resnext50(**kwargs):
     model = SEResNeXt(SEBottleneckX, [3, 4, 6, 3], **kwargs)
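Beyond threading `activation` and `head` through, this rewrite also changes the SEResNeXt stem and pooling: the first convolution's stride drops from 2 to 1, presumably to preserve spatial detail on small face crops, and `AdaptiveAvgPool2d` plus `Linear` give way to a 7x7 depthwise convolution followed by an optional 1x1 convolution, keeping the network fully convolutional. A sketch contrasting the two pooling styles; channel count and feature-map size are illustrative:

```python
import torch
import torch.nn as nn

channels = 2048  # e.g. 512 * block.expansion for these bottleneck blocks
feat = torch.randn(1, channels, 7, 7)  # illustrative final feature map

# Old head: parameter-free global average pooling, then a Linear layer.
gap = nn.AdaptiveAvgPool2d(1)
pooled = gap(feat)                          # -> (1, 2048, 1, 1)

# New head: a 7x7 depthwise conv learns a per-channel spatial weighting
# (a "global depthwise convolution"), and remains an ordinary conv op.
gdc = nn.Conv2d(channels, channels, 7, groups=channels, bias=False)
weighted = gdc(feat)                        # -> (1, 2048, 1, 1)

print(pooled.shape, weighted.shape)
```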

pytorch_toolkit/face_recognition/model/blocks/se_resnet_blocks.py

+8 −5
@@ -30,7 +30,10 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, activation=nn.ReLU):
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
 
-        self.relu = make_activation(activation)
+        self.relu1 = make_activation(activation)
+        self.relu2 = make_activation(activation)
+        self.relu3 = make_activation(activation)
+        self.relu4 = make_activation(activation)
 
         # SE
         self.global_pool = nn.AdaptiveAvgPool2d(1)
@@ -47,25 +50,25 @@ def forward(self, x):
 
         out = self.conv1(x)
         out = self.bn1(out)
-        out = self.relu(out)
+        out = self.relu1(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
-        out = self.relu(out)
+        out = self.relu2(out)
 
         out = self.conv3(out)
         out = self.bn3(out)
 
         out1 = self.global_pool(out)
         out1 = self.conv_down(out1)
-        out1 = self.relu(out1)
+        out1 = self.relu3(out1)
         out1 = self.conv_up(out1)
         out1 = self.sig(out1)
 
         if self.downsample is not None:
             residual = self.downsample(x)
 
         res = out1 * out + residual
-        res = self.relu(res)
+        res = self.relu4(res)
 
         return res
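Splitting the single shared `self.relu` into `relu1`..`relu4` (mirrored in `SEBottleneckX` below) matters once `make_activation` can return a parametric activation: a single shared `nn.PReLU`, for instance, would tie its learnable slope across all four call sites, whereas distinct instances let each site learn its own and give each activation a unique name in the module hierarchy. A small sketch of the difference, assuming `make_activation` may return `nn.PReLU`:

```python
import torch.nn as nn

# One shared module: a single learnable slope updated from every call site.
shared = nn.PReLU()
print(sum(p.numel() for p in shared.parameters()))    # 1

# Four independent modules, as the relu1..relu4 split provides:
# each call site trains its own slope.
separate = nn.ModuleList(nn.PReLU() for _ in range(4))
print(sum(p.numel() for p in separate.parameters()))  # 4
```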

pytorch_toolkit/face_recognition/model/blocks/se_resnext_blocks.py

+10 −6
@@ -14,12 +14,13 @@
 import torch.nn as nn
 
 from model.blocks.shared_blocks import SELayer
+from model.blocks.shared_blocks import make_activation
 
 
 class SEBottleneckX(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None, activation=nn.ReLU):
         super(SEBottleneckX, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
         self.bn1 = nn.BatchNorm2d(planes * 2)
@@ -31,9 +32,12 @@ def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
         self.conv3 = nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
 
-        self.selayer = SELayer(planes * 4, 16, nn.ReLU)
+        self.selayer = SELayer(planes * 4, 16, activation)
+
+        self.relu1 = make_activation(activation)
+        self.relu2 = make_activation(activation)
+        self.relu3 = make_activation(activation)
 
-        self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride
 
4246

4347
out = self.conv1(x)
4448
out = self.bn1(out)
45-
out = self.relu(out)
49+
out = self.relu1(out)
4650

4751
out = self.conv2(out)
4852
out = self.bn2(out)
49-
out = self.relu(out)
53+
out = self.relu2(out)
5054

5155
out = self.conv3(out)
5256
out = self.bn3(out)
@@ -57,6 +61,6 @@ def forward(self, x):
5761
residual = self.downsample(x)
5862

5963
out += residual
60-
out = self.relu(out)
64+
out = self.relu3(out)
6165

6266
return out
