-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactoring face recognizer module. API changed, added optional return predictions. Preprocessing module. Former-commit-id: 8ddc9db
- Loading branch information
Showing
13 changed files
with
103 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,15 @@ | ||
from facenet_pytorch import MTCNN | ||
from . import Aligner | ||
from torchvision import transforms | ||
from .. import preprocessing | ||
|
||
|
||
class MTCNNAligner(Aligner): | ||
def __init__(self): | ||
self.mtcnn = MTCNN(keep_all=True) | ||
self.preprocess = transforms.Compose([ | ||
preprocessing.ExifOrientationNormalize() | ||
]) | ||
|
||
def align(self, img): | ||
return self.mtcnn.detect(img) | ||
return self.mtcnn.detect(self.preprocess(img)) |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,60 +1,59 @@ | ||
import torch | ||
import pickle | ||
import os | ||
from .aligner.factory import aligner_factory | ||
from .facenet.factory import facenet_factory | ||
from .classifier.factory import classifier_factory | ||
from facenet_pytorch.models.utils.detect_face import extract_face | ||
from facenet_pytorch.models.mtcnn import prewhiten | ||
from collections import namedtuple | ||
|
||
Face = namedtuple('Face', 'bb identity probability') | ||
Prediction = namedtuple('Prediction', 'id name confidence') | ||
Face = namedtuple('Face', 'top_prediction bb all_predictions') | ||
BoundingBox = namedtuple('BoundingBox', 'left top right bottom') | ||
|
||
|
||
class BoundingBox: | ||
def __init__(self, left, top, right, bottom): | ||
self._left = left | ||
self._top = top | ||
self._right = right | ||
self._bottom = bottom | ||
|
||
def left(self): | ||
return self._left | ||
|
||
def top(self): | ||
return self._top | ||
|
||
def right(self): | ||
return self._right | ||
|
||
def bottom(self): | ||
return self._bottom | ||
|
||
|
||
def face_recogniser_factory(): | ||
def face_recogniser_factory(include_predictions=False): | ||
return FaceRecogniser( | ||
aligner=aligner_factory(), | ||
facenet=facenet_factory(), | ||
classifier=classifier_factory() | ||
include_predictions=include_predictions | ||
) | ||
|
||
|
||
def top_prediction(le, probs):
    """Build the Prediction for the highest-scoring class.

    ``probs.argmax()`` selects the label index; that index is used as the
    prediction id, to look up the name in ``le.classes_``, and to read the
    confidence from ``probs``.

    NOTE(review): ``le`` is presumably a fitted label encoder and ``probs`` a
    1-D array of per-class probabilities -- confirm against the caller.
    """
    best = probs.argmax()
    return Prediction(
        id=best,
        name=le.classes_[best],
        confidence=probs[best],
    )
|
||
|
||
def to_predictions(le, probs):
    """Convert every entry of ``probs`` into a Prediction, in index order.

    The position of each value is used both as the prediction id and as the
    index into ``le.classes_`` for the name; the value itself becomes the
    confidence.
    """
    results = []
    for label, confidence in enumerate(probs):
        results.append(
            Prediction(id=label, name=le.classes_[label], confidence=confidence)
        )
    return results
|
||
|
||
class FaceRecogniser: | ||
def __init__(self, aligner, facenet, classifier): | ||
def __init__(self, aligner, facenet, include_predictions): | ||
self.aligner = aligner | ||
self.facenet = facenet | ||
self.classifier = classifier | ||
self.le, self.classifier = pickle.load( | ||
open(os.path.join(os.path.dirname(__file__), '../models/model.pkl'), 'rb')) | ||
self.include_predictions = include_predictions | ||
|
||
def recognise_faces(self, img): | ||
bbs, _ = self.aligner(img) | ||
if bbs is None: | ||
# if no face is detected | ||
return [] | ||
|
||
faces = torch.stack([prewhiten(extract_face(img, bb)) for bb in bbs]) | ||
faces = torch.stack([extract_face(img, bb) for bb in bbs]) | ||
embeddings = self.facenet(faces).detach().numpy() | ||
people = self.classifier(embeddings) | ||
|
||
return [Face(BoundingBox(left=bb[0], top=bb[1], right=bb[2], bottom=bb[3]), person, 100) | ||
for bb, person in zip(bbs, people)] | ||
predictions = self.classifier.predict_proba(embeddings) | ||
|
||
return [ | ||
Face( | ||
top_prediction=top_prediction(self.le, probs), | ||
bb=BoundingBox(left=bb[0], top=bb[1], right=bb[2], bottom=bb[3]), | ||
all_predictions=None if not self.include_predictions else to_predictions(self.le, probs) | ||
) | ||
for bb, probs in zip(bbs, predictions) | ||
] | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.recognise_faces(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,15 @@ | ||
from . import FaceNet | ||
from facenet_pytorch import InceptionResnetV1 | ||
from torchvision import transforms | ||
from .. import preprocessing | ||
|
||
|
||
class FaceNetImpl(FaceNet): | ||
def __init__(self): | ||
self.facenet = InceptionResnetV1(pretrained='vggface2').eval() | ||
self.preprocess = transforms.Compose([ | ||
preprocessing.Whitening() | ||
]) | ||
|
||
def forward(self, img): | ||
return self.facenet(img) | ||
return self.facenet(self.preprocess(img)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from PIL import Image | ||
|
||
# EXIF orientation info http://sylvana.net/jpegcrop/exif_orientation.html | ||
exif_orientation_tag = 0x0112 | ||
exif_transpose_sequences = [ # Val 0th row 0th col | ||
[], # 0 (reserved) | ||
[], # 1 top left | ||
[Image.FLIP_LEFT_RIGHT], # 2 top right | ||
[Image.ROTATE_180], # 3 bottom right | ||
[Image.FLIP_TOP_BOTTOM], # 4 bottom left | ||
[Image.FLIP_LEFT_RIGHT, Image.ROTATE_90], # 5 left top | ||
[Image.ROTATE_270], # 6 right top | ||
[Image.FLIP_TOP_BOTTOM, Image.ROTATE_90], # 7 right bottom | ||
[Image.ROTATE_90], # 8 left bottom | ||
] | ||
|
||
|
||
class ExifOrientationNormalize(object):
    """
    Normalizes rotation of the image based on exif orientation info (if exists.)

    Calling an instance with an image looks up the orientation value stored
    under ``exif_orientation_tag`` in ``img.info['parsed_exif']`` and applies
    the matching sequence of transposes from ``exif_transpose_sequences``.

    The image is returned unchanged when any of the following holds:
    ``img.info`` has no ``'parsed_exif'`` entry, that entry is empty, it does
    not contain the orientation tag, or the stored orientation is not an
    integer index into ``exif_transpose_sequences``.
    """

    def __call__(self, img):
        if 'parsed_exif' not in img.info:
            return img

        parsed_exif = img.info['parsed_exif']
        if not parsed_exif:
            return img

        # EXIF data may be present without an orientation entry; treat that
        # the same as "no rotation needed" instead of raising KeyError.
        try:
            orientation = parsed_exif[exif_orientation_tag]
        except KeyError:
            return img

        # Guard against malformed values: anything outside the table would
        # raise IndexError, and a negative value would silently wrap around.
        if not isinstance(orientation, int):
            return img
        if not 0 <= orientation < len(exif_transpose_sequences):
            return img

        for trans in exif_transpose_sequences[orientation]:
            img = img.transpose(trans)
        return img
|
||
|
||
class Whitening(object):
    """
    Whitens the image.

    Calling an instance with a tensor subtracts the tensor's overall mean and
    divides by its overall standard deviation. The divisor is clamped from
    below at ``1 / sqrt(number of elements)`` so a constant (zero-variance)
    input yields zeros instead of dividing by zero.

    Non-floating-point tensors are converted with ``.float()`` first, because
    ``mean()`` / ``std()`` are not defined for integer dtypes. An empty tensor
    is returned unchanged, since the clamp bound cannot be computed for it.
    """

    def __call__(self, img):
        count = img.numel()
        if count == 0:
            # Nothing to normalise; also avoids 1.0 / 0.0 below.
            return img

        if not img.is_floating_point():
            img = img.float()

        mean = img.mean()
        std = img.std()
        # Lower bound on the divisor keeps near-constant inputs finite.
        std_adj = std.clamp(min=1.0 / (float(count) ** 0.5))
        return (img - mean) / std_adj
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters