Skip to content

Commit

Permalink
Refactoring face recognizer module. API changed, added optional retur…
Browse files Browse the repository at this point in the history
…n predictions. Preprocessing module.

Former-commit-id: 8ddc9db
  • Loading branch information
ldulcic committed Aug 1, 2019
1 parent df6662a commit 04c21ba
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 82 deletions.
2 changes: 1 addition & 1 deletion __init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .arsfutura_face_recognition import face_recogniser_factory

face_recognizer = face_recogniser_factory()
face_recognizer = face_recogniser_factory(include_predictions=True)


def recognise_faces(img):
Expand Down
7 changes: 6 additions & 1 deletion arsfutura_face_recognition/aligner/mtcnn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from facenet_pytorch import MTCNN
from . import Aligner
from torchvision import transforms
from .. import preprocessing


class MTCNNAligner(Aligner):
def __init__(self):
self.mtcnn = MTCNN(keep_all=True)
self.preprocess = transforms.Compose([
preprocessing.ExifOrientationNormalize()
])

def align(self, img):
return self.mtcnn.detect(img)
return self.mtcnn.detect(self.preprocess(img))
10 changes: 0 additions & 10 deletions arsfutura_face_recognition/classifier/__init__.py

This file was deleted.

11 changes: 0 additions & 11 deletions arsfutura_face_recognition/classifier/classifier.py

This file was deleted.

5 changes: 0 additions & 5 deletions arsfutura_face_recognition/classifier/factory.py

This file was deleted.

63 changes: 31 additions & 32 deletions arsfutura_face_recognition/face_recogniser.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,59 @@
import torch
import pickle
import os
from .aligner.factory import aligner_factory
from .facenet.factory import facenet_factory
from .classifier.factory import classifier_factory
from facenet_pytorch.models.utils.detect_face import extract_face
from facenet_pytorch.models.mtcnn import prewhiten
from collections import namedtuple

Face = namedtuple('Face', 'bb identity probability')
Prediction = namedtuple('Prediction', 'id name confidence')
Face = namedtuple('Face', 'top_prediction bb all_predictions')
BoundingBox = namedtuple('BoundingBox', 'left top right bottom')


class BoundingBox:
def __init__(self, left, top, right, bottom):
self._left = left
self._top = top
self._right = right
self._bottom = bottom

def left(self):
return self._left

def top(self):
return self._top

def right(self):
return self._right

def bottom(self):
return self._bottom


def face_recogniser_factory():
def face_recogniser_factory(include_predictions=False):
return FaceRecogniser(
aligner=aligner_factory(),
facenet=facenet_factory(),
classifier=classifier_factory()
include_predictions=include_predictions
)


def top_prediction(le, probs):
    """
    Builds the Prediction for the most probable class.

    :param le: object exposing `classes_`, indexable by class id
        (presumably the label encoder unpickled with the classifier -- confirm).
    :param probs: per-class probabilities for one face; must support
        `argmax()` and integer indexing.
    :return: Prediction whose id is the arg-max index, whose name is the
        matching entry of `le.classes_` and whose confidence is the
        probability at that index.
    """
    best = probs.argmax()
    best_name = le.classes_[best]
    best_confidence = probs[best]
    return Prediction(id=best, name=best_name, confidence=best_confidence)


def to_predictions(le, probs):
    """
    Converts a vector of class probabilities into a list of Prediction tuples.

    :param le: object exposing `classes_`, indexable by class id
        (presumably the label encoder unpickled with the classifier -- confirm).
    :param probs: iterable of per-class probabilities for one face, ordered by
        class id.
    :return: list with one Prediction per entry of `probs`, in the same order.
    """
    predictions = []
    for class_id, confidence in enumerate(probs):
        class_name = le.classes_[class_id]
        predictions.append(Prediction(id=class_id, name=class_name, confidence=confidence))
    return predictions


class FaceRecogniser:
def __init__(self, aligner, facenet, classifier):
def __init__(self, aligner, facenet, include_predictions):
self.aligner = aligner
self.facenet = facenet
self.classifier = classifier
self.le, self.classifier = pickle.load(
open(os.path.join(os.path.dirname(__file__), '../models/model.pkl'), 'rb'))
self.include_predictions = include_predictions

def recognise_faces(self, img):
bbs, _ = self.aligner(img)
if bbs is None:
# if no face is detected
return []

faces = torch.stack([prewhiten(extract_face(img, bb)) for bb in bbs])
faces = torch.stack([extract_face(img, bb) for bb in bbs])
embeddings = self.facenet(faces).detach().numpy()
people = self.classifier(embeddings)

return [Face(BoundingBox(left=bb[0], top=bb[1], right=bb[2], bottom=bb[3]), person, 100)
for bb, person in zip(bbs, people)]
predictions = self.classifier.predict_proba(embeddings)

return [
Face(
top_prediction=top_prediction(self.le, probs),
bb=BoundingBox(left=bb[0], top=bb[1], right=bb[2], bottom=bb[3]),
all_predictions=None if not self.include_predictions else to_predictions(self.le, probs)
)
for bb, probs in zip(bbs, predictions)
]

def __call__(self, *args, **kwargs):
return self.recognise_faces(*args, **kwargs)
7 changes: 6 additions & 1 deletion arsfutura_face_recognition/facenet/facenet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from . import FaceNet
from facenet_pytorch import InceptionResnetV1
from torchvision import transforms
from .. import preprocessing


class FaceNetImpl(FaceNet):
def __init__(self):
self.facenet = InceptionResnetV1(pretrained='vggface2').eval()
self.preprocess = transforms.Compose([
preprocessing.Whitening()
])

def forward(self, img):
return self.facenet(img)
return self.facenet(self.preprocess(img))
42 changes: 42 additions & 0 deletions arsfutura_face_recognition/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from PIL import Image

# EXIF orientation info http://sylvana.net/jpegcrop/exif_orientation.html
# Key under which the orientation value is looked up in the parsed EXIF data.
exif_orientation_tag = 0x0112
# Lookup table indexed by the EXIF orientation value. Each entry is the list
# of PIL transpose operations that are applied, in order, to normalize an
# image stored with that orientation. The trailing comment on each row gives
# the orientation value and where the image's 0th row / 0th column are stored.
exif_transpose_sequences = [ # Val 0th row 0th col
[], # 0 (reserved)
[], # 1 top left (already upright, nothing to do)
[Image.FLIP_LEFT_RIGHT], # 2 top right
[Image.ROTATE_180], # 3 bottom right
[Image.FLIP_TOP_BOTTOM], # 4 bottom left
[Image.FLIP_LEFT_RIGHT, Image.ROTATE_90], # 5 left top
[Image.ROTATE_270], # 6 right top
[Image.FLIP_TOP_BOTTOM, Image.ROTATE_90], # 7 right bottom
[Image.ROTATE_90], # 8 left bottom
]


class ExifOrientationNormalize(object):
    """
    Normalizes rotation of the image based on exif orientation info (if exists.)

    The image is returned unchanged when it carries no parsed EXIF data, when
    the EXIF data has no orientation entry, or when the orientation value is
    not a valid index into `exif_transpose_sequences`.
    """

    def __call__(self, img):
        """
        :param img: image object exposing an `info` mapping and a
            `transpose(method)` method (a PIL image in this project).
        :return: the re-oriented image, or `img` itself if nothing had to be
            (or could be) done.
        """
        if 'parsed_exif' not in img.info:
            return img

        try:
            orientation = img.info['parsed_exif'][exif_orientation_tag]
        except (KeyError, TypeError):
            # EXIF is present but has no orientation entry (or is not a
            # mapping at all): there is nothing to normalize.
            return img

        # Reject corrupt values. Without this check an out-of-range value
        # raises IndexError and a negative one silently selects the wrong row.
        if orientation not in range(len(exif_transpose_sequences)):
            return img

        for trans in exif_transpose_sequences[orientation]:
            img = img.transpose(trans)
        return img


class Whitening(object):
    """
    Whitens the image: subtracts the mean and divides by the standard deviation.

    Both statistics are taken over every element of the input tensor.
    """

    def __call__(self, img):
        """
        :param img: tensor supporting `mean()`, `std()` and `numel()`.
        :return: tensor of the same shape as `img`, centered and scaled.
        """
        # Lower bound for the divisor, 1 / sqrt(number of elements), so that a
        # constant (zero-variance) input does not cause a division by zero.
        min_std = 1.0 / (float(img.numel()) ** 0.5)
        centered = img - img.mean()
        return centered / img.std().clamp(min=min_std)
2 changes: 2 additions & 0 deletions bin/align-mtcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import argparse
from torchvision import datasets, transforms
from facenet_pytorch.models.mtcnn import MTCNN
from arsfutura_face_recognition import preprocessing
from PIL import Image


Expand All @@ -27,6 +28,7 @@ def create_dirs(root_dir, classes):
def main():
args = parse_args()
trans = transforms.Compose([
preprocessing.ExifOrientationNormalize(),
transforms.Resize(1024)
])

Expand Down
22 changes: 11 additions & 11 deletions face_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,22 @@ def parse_args(args=None):
parser = argparse.ArgumentParser(
'Script for recognising faces on picture. Output of this script is json with list of people on picture and '
'base64 encoded picture which has bounding boxes of people.')
image_group = parser.add_mutually_exclusive_group(required=True)
image_group.add_argument('--image-path', help='Path to image file.')
parser.add_argument('--classifier-path', required=True, help='Path to serialized classifier.')
parser.add_argument('--image-path', required=True, help='Path to image file.')
return parser.parse_args(args)


def draw_bb_on_img(faces, img):
for face in faces:
cv2.rectangle(img, (int(face.bb.left()), int(face.bb.top())), (int(face.bb.right()), int(face.bb.bottom())),
cv2.rectangle(img, (int(face.bb.left), int(face.bb.top)), (int(face.bb.right), int(face.bb.bottom)),
(0, 255, 0), 2)
cv2.putText(img, "%s %.2f%%" % (face.identity, face.probability),
(int(face.bb.left()), int(face.bb.bottom()) + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(img, "%s %.2f%%" % (face.top_prediction.name.upper(), face.top_prediction.confidence * 100),
(int(face.bb.left), int(face.bb.top) - 10), cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 255, 255), 4,
cv2.LINE_AA)


def _recognise_faces(args):
img = Image.open(args.image_path)
faces = face_recogniser_factory()(img)
faces = face_recogniser_factory(include_predictions=True)(img)
img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
draw_bb_on_img(faces, img_cv)
return faces, img_cv
Expand All @@ -36,10 +35,11 @@ def _recognise_faces(args):
def main():
args = parse_args()
faces, img = _recognise_faces(args)
cv2.imshow('image', img)
cv2.waitKey()
cv2.destroyAllWindows()
cv2.waitKey(1)
cv2.imwrite('img.jpg', img)
# cv2.imshow('image', img)
# cv2.waitKey()
# cv2.destroyAllWindows()
# cv2.waitKey(1)


if __name__ == '__main__':
Expand Down
5 changes: 0 additions & 5 deletions generate_embeddings.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
#!/usr/bin/env bash

python -m bin.exif_orientation_normalize --images-path data/images
echo -e "\n===================================="
echo "Exif orientation normalization done."
echo "===================================="

python -m bin.align-mtcnn --input-folder data/images --output-folder data/aligned
echo -e "\n===================================="
echo "Cropping faces (aligning) done."
Expand Down
Binary file modified models/model.pkl
Binary file not shown.
9 changes: 4 additions & 5 deletions real_time_face_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ def main():
faces = face_recogniser(Image.fromarray(frame))
if faces is not None:
for face in faces:
cv2.rectangle(frame, (int(face.bb.left()), int(face.bb.top())),
(int(face.bb.right()), int(face.bb.bottom())),
cv2.rectangle(frame, (int(face.bb.left), int(face.bb.top)), (int(face.bb.right), int(face.bb.bottom)),
(0, 255, 0), 2)
cv2.putText(frame, "%s %.2f%%" % (face.identity, face.probability),
(int(face.bb.left()), int(face.bb.bottom()) + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5,
(255, 255, 255), 1)
cv2.putText(frame, "%s %.2f%%" % (face.top_prediction.name.upper(), face.top_prediction.confidence * 100),
(int(face.bb.left), int(face.bb.top) - 10), cv2.FONT_HERSHEY_SIMPLEX, 3, (255, 255, 255), 4,
cv2.LINE_AA)

# Display the resulting frame
cv2.imshow('frame', frame)
Expand Down

0 comments on commit 04c21ba

Please sign in to comment.