From 9ee8a8f372eb6a2b9c53428503f4696c09befc80 Mon Sep 17 00:00:00 2001 From: stijn Date: Thu, 11 Dec 2025 16:41:27 +0100 Subject: [PATCH 1/2] feat: add first version of fakeinversion model --- .../data/datasets/genimagedataset.py | 2 - .../models/detection/fakeinversion.py | 154 ++++++++++++++ pdm.lock | 196 +++++++++++++++++- pyproject.toml | 2 + 4 files changed, 351 insertions(+), 3 deletions(-) create mode 100644 deepfake_detection/models/detection/fakeinversion.py diff --git a/deepfake_detection/data/datasets/genimagedataset.py b/deepfake_detection/data/datasets/genimagedataset.py index 7a7ac0ba..5b45cc84 100644 --- a/deepfake_detection/data/datasets/genimagedataset.py +++ b/deepfake_detection/data/datasets/genimagedataset.py @@ -82,10 +82,8 @@ def _format_label(self, label: str) -> str: label = label.lower() # For stable diffusion labels, correctly format version number - print(label) if label.startswith('stable diffusion'): label = re.sub(r'v_(\d)_(\d)', r'v$1.$2', label) - print(label) # Replace any underscores with spaces label = label.replace('_', ' ') diff --git a/deepfake_detection/models/detection/fakeinversion.py b/deepfake_detection/models/detection/fakeinversion.py new file mode 100644 index 00000000..c95170d6 --- /dev/null +++ b/deepfake_detection/models/detection/fakeinversion.py @@ -0,0 +1,154 @@ +from typing import Union, List + +import torch +from torchvision import transforms, models +from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration +from diffusers import StableDiffusionPipeline, DDIMScheduler +from torch.nn import functional as F +import numpy as np + +from deepfake_detection.data import Instance, Dataset, FileImageInstance, ImageInstance +from deepfake_detection.models.model import Model +from deepfake_detection.models.prediction import Prediction + + +def process_instance(instance: Union[FileImageInstance, ImageInstance]) -> torch.Tensor: + preprocess = transforms.Compose([ + transforms.Resize((512, 512)), + transforms.ToTensor() + ]) + return preprocess(instance.data) + + +class FakeInversion(Model): + """ + Implementation of the FakeInversion model by Cazenavette et al. (2024). + + More info about the model can be found here: https://fake-inversion.github.io. 
+ """ + + def __init__(self, device: str = 'cuda'): + super().__init__("FakeInversion") + self.classifier = None + self.captioning = None + self.embedding = None + self.feature_extractor = None + self.device = device + + + def load_model(self): + # Define captioning model + self.captioning = ImageCaptioning() + + # Define embedding model + self.embedding = TextEmbedding() + + # Define feature extractor model + self.feature_extractor = FeatureExtractor() + + # Define classifier + self.classifier = models.resnet50(pretrained=True) + self.classifier.fc = torch.nn.Linear(self.classifier.fc.in_features, 2) # Binary classification + self.classifier.to(self.device).eval() + + + def predict(self, instance: Union[ImageInstance, FileImageInstance]) -> Prediction: + + if not self.classifier: + self.load_model() + + # Create img tensor + img_tensor = process_instance(instance).to(self.device) + + # Generate caption + caption = self.captioning.get_caption(instance.data) + + # Get text embedding + text_embedding = self.embedding.get_embedding(caption) + + # Extract features + latent, noise, reconstructed_image = self.feature_extractor.extract_features(img_tensor, text_embedding) + reconstructed_image = reconstructed_image.squeeze(0) # Remove batch dimension if present + + # Pass reconstructed_image directly to the classifier + output = self.classifier(reconstructed_image.unsqueeze(0)) + + # Transform to prediction object + prediction = torch.argmax(F.softmax(output, dim=1), dim=1) + + return Prediction(classification={'fake': float(prediction[0]), 'real': 1 - float(prediction[0])}, + embedding=latent.cpu().detach().numpy(), + text=caption, + image=(reconstructed_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + ) + + + def predict_batch(self, instances: Union[List[Instance], Dataset]) -> List[Prediction]: + pass + + + +# BLIP: Image Captioning +class ImageCaptioning: + def __init__(self, device: str = 'cuda'): + self.device = device + self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") + self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", + use_safetensors=True).to(self.device) + + def get_caption(self, image): + inputs = self.processor(image, return_tensors="pt").to(self.device) + caption_ids = self.model.generate(**inputs) + caption = self.processor.decode(caption_ids[0], skip_special_tokens=True) + return caption + + +# CLIP: Text Embedding +class TextEmbedding: + def __init__(self, device: str = 'cuda'): + self.device = device + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", + use_safetensors=True) + self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", + use_safetensors=True).to(self.device) + + def get_embedding(self, caption): + inputs = self.processor(text=[caption], + return_tensors="pt", + padding=True).to(self.device) + text_embedding = self.model.get_text_features(**inputs) + return text_embedding + + +# Stable Diffusion Feature Extraction +class FeatureExtractor: + + def __init__(self, model_name="runwayml/stable-diffusion-v1-5", device: str = 'cuda'): + self.device = device + self.pipe = StableDiffusionPipeline.from_pretrained(model_name, use_safetensors=True).to(self.device) + # The scheduler configuration is likely located under the model's directory. 
+ # Specify the 'scheduler' subfolder: + self.scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + # Call set_timesteps to initialize num_inference_steps + self.scheduler.set_timesteps(50) # You can adjust the number of steps here + + + def extract_features(self, image, text_embedding): + # Encode the image to latent space + latents = self.pipe.vae.encode(image.unsqueeze(0).to(self.device)).latent_dist.sample() + latents = latents * self.pipe.vae.config.scaling_factor + + # Invert using DDIM + noise = torch.randn_like(latents).to(self.device) + inverted_latents = self.scheduler.add_noise(latents, + noise, + torch.tensor([49], device=self.device, dtype=torch.long) + ) + + # Reconstruct image from inverted latent + with torch.no_grad(): + # Passing noise and inverted_latents as arguments and removing text_embedding as it is not the timestep + reconstructed_latents = self.scheduler.step(noise, 49, inverted_latents).prev_sample + reconstructed_image = self.pipe.vae.decode(reconstructed_latents / self.pipe.vae.config.scaling_factor).sample + + return latents, noise, reconstructed_image \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index 4f9a663a..ad3b28f2 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,11 +5,31 @@ groups = ["default", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:700a0a4b7fda345f81916e3eb213fcc48ca7bdbf1797d18ec343a781adf94579" +content_hash = "sha256:fe5af1de4d27693ef3dda4faf175ac049893c0a9d87bb15a48352c2e5465279d" [[metadata.targets]] requires_python = ">=3.10" +[[package]] +name = "accelerate" +version = "1.12.0" +requires_python = ">=3.10.0" +summary = "Accelerate" +groups = ["default"] +dependencies = [ + "huggingface-hub>=0.21.0", + "numpy>=1.17", + "packaging>=20.0", + "psutil", + "pyyaml", + "safetensors>=0.4.3", + "torch>=2.0.0", +] +files = [ + {file = "accelerate-1.12.0-py3-none-any.whl", hash = "sha256:3e2091cd341423207e2f084a6654b1efcd250dc326f2a37d6dde446e07cabb11"}, + {file = "accelerate-1.12.0.tar.gz", hash = "sha256:70988c352feb481887077d2ab845125024b2a137a5090d6d7a32b57d03a45df6"}, +] + [[package]] name = "aiofiles" version = "25.1.0" @@ -475,6 +495,20 @@ files = [ {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"}, ] +[[package]] +name = "click" +version = "8.3.1" +requires_python = ">=3.10" +summary = "Composable command line interface toolkit" +groups = ["default"] +dependencies = [ + "colorama; platform_system == \"Windows\"", +] +files = [ + {file = "click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6"}, + {file = "click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a"}, +] + [[package]] name = "colorama" version = "0.4.6" @@ -676,6 +710,28 @@ files = [ {file = "deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223"}, ] +[[package]] +name = "diffusers" +version = "0.36.0" +requires_python = ">=3.8.0" +summary = "State-of-the-art diffusion in PyTorch and JAX." 
+groups = ["default"] +dependencies = [ + "Pillow", + "filelock", + "httpx<1.0.0", + "huggingface-hub<2.0,>=0.34.0", + "importlib-metadata", + "numpy", + "regex!=2019.12.17", + "requests", + "safetensors>=0.3.1", +] +files = [ + {file = "diffusers-0.36.0-py3-none-any.whl", hash = "sha256:525d42abc74bfc3b2db594999961295c054b48ef40a11724dacf50e6abd1af98"}, + {file = "diffusers-0.36.0.tar.gz", hash = "sha256:a9cde8721b415bde6a678f2d02abb85396487e1b0e0d2b4abb462d14a9825ab0"}, +] + [[package]] name = "dill" version = "0.4.0" @@ -947,6 +1003,41 @@ files = [ {file = "h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1"}, ] +[[package]] +name = "hf-xet" +version = "1.2.1" +requires_python = ">=3.8" +summary = "Fast transfer of large files with the Hugging Face Hub." +groups = ["default"] +marker = "platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\"" +files = [ + {file = "hf_xet-1.2.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:d9b8118b8b171f0482a61d40a473335857d2b85fcceca499fe16db887e5bf0fb"}, + {file = "hf_xet-1.2.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:29c0603f7b27d58dc35b4f859cf62a8d905782db4e3cd0253cb45b3225e6c6f1"}, + {file = "hf_xet-1.2.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ad56e0f6cfcdde30c436aee3a2adfd66d7432b0e1598353321c644576b97623a"}, + {file = "hf_xet-1.2.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29cc367533e338f2a0c65d186882c05e6e6841f5c591562c629de652d5ec219e"}, + {file = "hf_xet-1.2.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:6df5382f854cedd0cf78bffb2d93ffeb877717fe4ab0871735dbeb77a02c3c3c"}, + {file = "hf_xet-1.2.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:efd65ead9913c199031250d154c242a6ce3876e2ac725155bb66483b610e799e"}, + {file = "hf_xet-1.2.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:72fc277b3655861121dbbdd3ac0315263fe5de63f6a844dd395245dcb66928e6"}, + {file = "hf_xet-1.2.1-cp313-cp313t-win_amd64.whl", hash = "sha256:cc3902fa877c15f36ee149897f54cd97340aed602e4544bc1e80151b47635edc"}, + {file = "hf_xet-1.2.1-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:5d66cf732d3d46f3b12f1d14d6f5c639433dd332af8e851d148a80bcfcb56f25"}, + {file = "hf_xet-1.2.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0429f6de40d2d6d4e04f412e3f9074fd5a208f9145b5247a2ea66c689b1a5768"}, + {file = "hf_xet-1.2.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2b898f9d106f0ef83bd2f73e973e3e702d1368a0404315fc8ad64156434a443b"}, + {file = "hf_xet-1.2.1-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b45cfc180275f69ecc570ef98946752fbd9424ddc9bca895942fef3f0b85f3a3"}, + {file = "hf_xet-1.2.1-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:bf957da3fe3571e2f161b98193065baa54c9eb8670b914daccb60aa1c9ee94a4"}, + {file = "hf_xet-1.2.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:cc0fde31706c0626beb5430665512e53da23004c6c020c80256be057b6248255"}, + {file = "hf_xet-1.2.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7ea0eb9dab2276fc2e502dc8c095c0c2b00cf89ae3d7899b6878dfa5ea1c1464"}, + {file = "hf_xet-1.2.1-cp314-cp314t-win_amd64.whl", hash = "sha256:ac0b51005cf12e7f88654d4e127f1b8d2378a2db575782fc3ea25d4c0e7e5c49"}, + {file = "hf_xet-1.2.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = 
"sha256:0b0742e1258335686c82cbda57d6587081c25f9e004b135de485b76166c3f172"}, + {file = "hf_xet-1.2.1-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:1cae780544f0d8849174e82047bd7a8c1d9e14b7c13dce99b1462574bedf3358"}, + {file = "hf_xet-1.2.1-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ce3a09c77f1ba29fedb2b1f397948d8a237a8062afa946821811b39abb9903b4"}, + {file = "hf_xet-1.2.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e722352e5e9414030b44eacce598f840aab8b5d5ea36ea66a6c9121e52a9455b"}, + {file = "hf_xet-1.2.1-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:3214bae661c600232ce71bff88d22b6cb81f669be4e98ce4b57114c135fd208d"}, + {file = "hf_xet-1.2.1-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:397c6dadecf1303d084b675efa4afd920229ccc13b84bfc095f48dfd508f464e"}, + {file = "hf_xet-1.2.1-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:885ac1975b064cf3919650f5e5efdddf72a4ebe3e1f6fe027a4a9a319e43a17b"}, + {file = "hf_xet-1.2.1-cp37-abi3-win_amd64.whl", hash = "sha256:c76704cdda11cac957519dc8cb868eb103e561e9bfea284c8678e7d975bb4369"}, + {file = "hf_xet-1.2.1.tar.gz", hash = "sha256:6c7a48d40b25f06f7f1b0fdd96c6f9d222dd4d0c833990684723eb77163b7372"}, +] + [[package]] name = "hpack" version = "4.1.0" @@ -990,6 +1081,29 @@ files = [ {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, ] +[[package]] +name = "huggingface-hub" +version = "1.2.2" +requires_python = ">=3.9.0" +summary = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" +groups = ["default"] +dependencies = [ + "filelock", + "fsspec>=2023.5.0", + "hf-xet<2.0.0,>=1.2.0; platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"arm64\" or platform_machine == \"aarch64\"", + "httpx<1,>=0.23.0", + "packaging>=20.9", + "pyyaml>=5.1", + "shellingham", + "tqdm>=4.42.1", + "typer-slim", + "typing-extensions>=3.7.4.3", +] +files = [ + {file = "huggingface_hub-1.2.2-py3-none-any.whl", hash = "sha256:0f55d7d22058fbf8b29d8095aeee80a7b695aa764f906a21e886c1f87223718f"}, + {file = "huggingface_hub-1.2.2.tar.gz", hash = "sha256:b5b97bd37f4fe5b898a467373044649c94ee32006c032ce8fb835abe9d92ea28"}, +] + [[package]] name = "humanize" version = "4.14.0" @@ -1075,6 +1189,21 @@ files = [ {file = "imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a"}, ] +[[package]] +name = "importlib-metadata" +version = "8.7.0" +requires_python = ">=3.9" +summary = "Read metadata from Python packages" +groups = ["default"] +dependencies = [ + "typing-extensions>=3.6.4; python_version < \"3.8\"", + "zipp>=3.20", +] +files = [ + {file = "importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd"}, + {file = "importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000"}, +] + [[package]] name = "inflate64" version = "1.0.4" @@ -2911,6 +3040,34 @@ files = [ {file = "s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920"}, ] +[[package]] +name = "safetensors" +version = "0.7.0" +requires_python = ">=3.9" +summary = "" +groups = ["default"] +files = [ + {file = "safetensors-0.7.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c82f4d474cf725255d9e6acf17252991c3c8aac038d6ef363a4bf8be2f6db517"}, + {file = 
"safetensors-0.7.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:94fd4858284736bb67a897a41608b5b0c2496c9bdb3bf2af1fa3409127f20d57"}, + {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e07d91d0c92a31200f25351f4acb2bc6aff7f48094e13ebb1d0fb995b54b6542"}, + {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8469155f4cb518bafb4acf4865e8bb9d6804110d2d9bdcaa78564b9fd841e104"}, + {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54bef08bf00a2bff599982f6b08e8770e09cc012d7bba00783fc7ea38f1fb37d"}, + {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:42cb091236206bb2016d245c377ed383aa7f78691748f3bb6ee1bfa51ae2ce6a"}, + {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac7252938f0696ddea46f5e855dd3138444e82236e3be475f54929f0c510d48"}, + {file = "safetensors-0.7.0-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1d060c70284127fa805085d8f10fbd0962792aed71879d00864acda69dbab981"}, + {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:cdab83a366799fa730f90a4ebb563e494f28e9e92c4819e556152ad55e43591b"}, + {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:672132907fcad9f2aedcb705b2d7b3b93354a2aec1b2f706c4db852abe338f85"}, + {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:5d72abdb8a4d56d4020713724ba81dac065fedb7f3667151c4a637f1d3fb26c0"}, + {file = "safetensors-0.7.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b0f6d66c1c538d5a94a73aa9ddca8ccc4227e6c9ff555322ea40bdd142391dd4"}, + {file = "safetensors-0.7.0-cp38-abi3-win32.whl", hash = "sha256:c74af94bf3ac15ac4d0f2a7c7b4663a15f8c2ab15ed0fc7531ca61d0835eccba"}, + {file = "safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755"}, + {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4729811a6640d019a4b7ba8638ee2fd21fa5ca8c7e7bdf0fed62068fcaac737"}, + {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12f49080303fa6bb424b362149a12949dfbbf1e06811a88f2307276b0c131afd"}, + {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0071bffba4150c2f46cae1432d31995d77acfd9f8db598b5d1a2ce67e8440ad2"}, + {file = "safetensors-0.7.0-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:473b32699f4200e69801bf5abf93f1a4ecd432a70984df164fc22ccf39c4a6f3"}, + {file = "safetensors-0.7.0.tar.gz", hash = "sha256:07663963b67e8bd9f0b8ad15bb9163606cd27cc5a1b96235a50d8369803b96b0"}, +] + [[package]] name = "scikit-image" version = "0.25.1" @@ -3076,6 +3233,17 @@ files = [ {file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"}, ] +[[package]] +name = "shellingham" +version = "1.5.4" +requires_python = ">=3.7" +summary = "Tool to Detect Surrounding Shell" +groups = ["default"] +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "six" version = "1.17.0" @@ -3410,6 
+3578,21 @@ files = [ {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"}, ] +[[package]] +name = "typer-slim" +version = "0.20.0" +requires_python = ">=3.8" +summary = "Typer, build great CLIs. Easy to code. Based on Python type hints." +groups = ["default"] +dependencies = [ + "click>=8.0.0", + "typing-extensions>=3.7.4.3", +] +files = [ + {file = "typer_slim-0.20.0-py3-none-any.whl", hash = "sha256:f42a9b7571a12b97dddf364745d29f12221865acef7a2680065f9bb29c7dc89d"}, + {file = "typer_slim-0.20.0.tar.gz", hash = "sha256:9fc6607b3c6c20f5c33ea9590cbeb17848667c51feee27d9e314a579ab07d1a3"}, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -3633,3 +3816,14 @@ files = [ {file = "xmltodict-1.0.2-py3-none-any.whl", hash = "sha256:62d0fddb0dcbc9f642745d8bbf4d81fd17d6dfaec5a15b5c1876300aad92af0d"}, {file = "xmltodict-1.0.2.tar.gz", hash = "sha256:54306780b7c2175a3967cad1db92f218207e5bc1aba697d887807c0fb68b7649"}, ] + +[[package]] +name = "zipp" +version = "3.23.0" +requires_python = ">=3.9" +summary = "Backport of pathlib-compatible object wrapper for zip files" +groups = ["default"] +files = [ + {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, + {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, +] diff --git a/pyproject.toml b/pyproject.toml index d73cb7bb..f4d916ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ dependencies = [ "torch>=2.7.1", "fiftyone>=1.11.0", "torchvision>=0.22.1", + "diffusers>=0.36.0", + "accelerate>=1.12.0", ] [dependency-groups] From ae913ff044d7eea049a0d3ce461b04c844a3e9a7 Mon Sep 17 00:00:00 2001 From: stijn Date: Thu, 18 Dec 2025 16:01:10 +0100 Subject: [PATCH 2/2] fix: made batch processing possible with fakeinversion model --- .../models/detection/fakeinversion.py | 181 ++++++++++-------- 1 file changed, 105 insertions(+), 76 deletions(-) diff --git a/deepfake_detection/models/detection/fakeinversion.py b/deepfake_detection/models/detection/fakeinversion.py index c95170d6..8cc24f53 100644 --- a/deepfake_detection/models/detection/fakeinversion.py +++ b/deepfake_detection/models/detection/fakeinversion.py @@ -1,23 +1,24 @@ -from typing import Union, List +from typing import Union, List, Sequence import torch +from PIL.Image import Image from torchvision import transforms, models -from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditionalGeneration +from transformers import BlipProcessor, BlipForConditionalGeneration from diffusers import StableDiffusionPipeline, DDIMScheduler from torch.nn import functional as F import numpy as np -from deepfake_detection.data import Instance, Dataset, FileImageInstance, ImageInstance +from deepfake_detection.data import Dataset, FileImageInstance, ImageInstance from deepfake_detection.models.model import Model from deepfake_detection.models.prediction import Prediction -def process_instance(instance: Union[FileImageInstance, ImageInstance]) -> torch.Tensor: +def process_images(images: Sequence[Image]) -> torch.Tensor: preprocess = transforms.Compose([ transforms.Resize((512, 512)), transforms.ToTensor() ]) - return preprocess(instance.data) + return torch.stack([preprocess(i) for i in images]) class FakeInversion(Model): @@ -27,97 +28,85 @@ class FakeInversion(Model): More info about the model can be found here: 
https://fake-inversion.github.io.
     """
 
-    def __init__(self, device: str = 'cuda'):
+    def __init__(self, ckpt: str, device: str = 'cuda'):
         super().__init__("FakeInversion")
         self.classifier = None
         self.captioning = None
         self.embedding = None
         self.feature_extractor = None
         self.device = device
+        self.ckpt = ckpt
 
 
     def load_model(self):
         # Define captioning model
         self.captioning = ImageCaptioning()
 
-        # Define embedding model
-        self.embedding = TextEmbedding()
-
         # Define feature extractor model
         self.feature_extractor = FeatureExtractor()
 
         # Define classifier
         self.classifier = models.resnet50(pretrained=True)
-        self.classifier.fc = torch.nn.Linear(self.classifier.fc.in_features, 2)  # Binary classification
+        self.classifier.fc = torch.nn.Linear(self.classifier.fc.in_features, 2)
+        state_dict = torch.load(self.ckpt, weights_only=True, map_location='cpu')
+        self.classifier.load_state_dict(state_dict['model'])
         self.classifier.to(self.device).eval()
 
 
-    def predict(self, instance: Union[ImageInstance, FileImageInstance]) -> Prediction:
-
+    def predict_batch(self, instances: Union[List[Union[ImageInstance, FileImageInstance]], Dataset]) -> List[Prediction]:
         if not self.classifier:
             self.load_model()
 
-        # Create img tensor
-        img_tensor = process_instance(instance).to(self.device)
-
-        # Generate caption
-        caption = self.captioning.get_caption(instance.data)
+        # Preprocess images and stack them into a single 4D tensor [B, C, H, W]
+        imgs = [i.data for i in instances]
+        img_tensor = process_images(imgs).to(self.device)
 
-        # Get text embedding
-        text_embedding = self.embedding.get_embedding(caption)
+        # Generate captions
+        captions = self.captioning.get_captions(imgs)
 
         # Extract features
-        latent, noise, reconstructed_image = self.feature_extractor.extract_features(img_tensor, text_embedding)
-        reconstructed_image = reconstructed_image.squeeze(0)  # Remove batch dimension if present
+        latents, noises, reconstructed_images = self.feature_extractor.extract_features(img_tensor, captions)
 
-        # Pass reconstructed_image directly to the classifier
-        output = self.classifier(reconstructed_image.unsqueeze(0))
+        # Pass the reconstructed images directly to the classifier
+        outputs = self.classifier(reconstructed_images.float())
 
-        # Transform to prediction object
-        prediction = torch.argmax(F.softmax(output, dim=1), dim=1)
+        # Softmax over the logits, then argmax for the hard class prediction
+        class_predictions = torch.argmax(F.softmax(outputs, dim=1), dim=1)
 
-        return Prediction(classification={'fake': float(prediction[0]), 'real': 1 - float(prediction[0])},
-                          embedding=latent.cpu().detach().numpy(),
-                          text=caption,
-                          image=(reconstructed_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
-                          )
+        # Transfer results to the CPU
+        class_predictions = class_predictions.cpu().detach().numpy()
+        latents = latents.cpu().detach().numpy()
+        reconstructed_images = (reconstructed_images.permute(0, 2, 3, 1).cpu().detach().numpy() * 255).astype(np.uint8)
+        # Transform to prediction objects (hard 0/1 scores; class index 1 = fake)
+        predictions = []
+        for i in range(len(instances)):
+            pred = Prediction(classification={'fake': float(class_predictions[i]),
+                                              'real': 1 - float(class_predictions[i])
+                                              },
+                              embedding=latents[i],
+                              text=captions[i],
+                              image=reconstructed_images[i]
+                              )
+            predictions.append(pred)
 
-    def predict_batch(self, instances: Union[List[Instance], Dataset]) -> List[Prediction]:
-        pass
-
+        return predictions
 
 
 # BLIP: Image Captioning
 class ImageCaptioning:
+
     def __init__(self, device: str = 'cuda'):
         self.device = device
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
         self.model = 
BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", use_safetensors=True).to(self.device) - def get_caption(self, image): - inputs = self.processor(image, return_tensors="pt").to(self.device) - caption_ids = self.model.generate(**inputs) - caption = self.processor.decode(caption_ids[0], skip_special_tokens=True) - return caption - - -# CLIP: Text Embedding -class TextEmbedding: - def __init__(self, device: str = 'cuda'): - self.device = device - self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", - use_safetensors=True) - self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", - use_safetensors=True).to(self.device) - - def get_embedding(self, caption): - inputs = self.processor(text=[caption], - return_tensors="pt", - padding=True).to(self.device) - text_embedding = self.model.get_text_features(**inputs) - return text_embedding + def get_captions(self, images: List[Image]): + inputs = self.processor(images, return_tensors="pt").to(self.device) + caption_ids = self.model.generate(**inputs, do_sample=False) + captions = self.processor.batch_decode(caption_ids, skip_special_tokens=True) + return captions # Stable Diffusion Feature Extraction @@ -125,30 +114,70 @@ class FeatureExtractor: def __init__(self, model_name="runwayml/stable-diffusion-v1-5", device: str = 'cuda'): self.device = device - self.pipe = StableDiffusionPipeline.from_pretrained(model_name, use_safetensors=True).to(self.device) - # The scheduler configuration is likely located under the model's directory. - # Specify the 'scheduler' subfolder: + self.pipe = StableDiffusionPipeline.from_pretrained(model_name, + use_safetensors=True, + torch_dtype=torch.float16).to(self.device) self.scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - # Call set_timesteps to initialize num_inference_steps - self.scheduler.set_timesteps(50) # You can adjust the number of steps here + self.scheduler.set_timesteps(50) - def extract_features(self, image, text_embedding): - # Encode the image to latent space - latents = self.pipe.vae.encode(image.unsqueeze(0).to(self.device)).latent_dist.sample() - latents = latents * self.pipe.vae.config.scaling_factor + def extract_features(self, image_tensor: torch.Tensor, captions: List[str], seed: int = 42): + """ + Performs the 'FakeInversion' logic: + 1. Encode image to latent. + 2. Noise latent to t=49. + 3. Predict noise using U-Net conditioned on text embedding. + 4. Reconstruct images from predicted noise. - # Invert using DDIM - noise = torch.randn_like(latents).to(self.device) - inverted_latents = self.scheduler.add_noise(latents, - noise, - torch.tensor([49], device=self.device, dtype=torch.long) - ) + :param image_tensor: Tensor of shape [B, C, H, W] representing the input image. + :param captions: List of length B containing the text prompts. + :param seed: Seed to use for sampling. 
+        """
-        # Reconstruct image from inverted latent
-        with torch.no_grad():
-            # Passing noise and inverted_latents as arguments and removing text_embedding as it is not the timestep
-            reconstructed_latents = self.scheduler.step(noise, 49, inverted_latents).prev_sample
-            reconstructed_image = self.pipe.vae.decode(reconstructed_latents / self.pipe.vae.config.scaling_factor).sample
+        # A local Generator keeps sampling reproducible without touching global PyTorch RNG state
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+
+        # Cast the input to float16 to match the pipeline weights
+        image_tensor = image_tensor.to(dtype=torch.float16)
-        return latents, noise, reconstructed_image
\ No newline at end of file
+        with torch.no_grad():
+            # Encode images into the VAE latent space
+            latents = self.pipe.vae.encode(image_tensor).latent_dist.sample(generator=generator)
+            latents = latents * self.pipe.vae.config.scaling_factor
+
+            # Forward diffusion: noise the latents to timestep index 49,
+            # a low noise level out of the scheduler's 1000 training steps
+            t_idx = 49
+            # One timestep entry per batch element
+            timesteps = torch.full((latents.shape[0],), t_idx, device=self.device, dtype=torch.long)
+            noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents.dtype)
+            noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)
+
+            # Get text embeddings of the captions
+            text_embeddings, _ = self.pipe.encode_prompt(
+                prompt=captions,
+                device=self.device,
+                num_images_per_prompt=1,
+                do_classifier_free_guidance=False
+            )
+
+            # Predict the noise with the caption-conditioned U-Net. For diffusion-generated
+            # images the prediction closely matches the injected noise, so the reconstruction
+            # below is faithful; for real images it diverges, and that gap is the signal
+            # the downstream classifier picks up.
+            noise_pred = self.pipe.unet(
+                noisy_latents,
+                timesteps,
+                encoder_hidden_states=text_embeddings
+            ).sample
+
+            # Denoise: step back from t=49 to the previous timestep using the model's prediction
+            reconstructed_latents = self.scheduler.step(noise_pred, t_idx, noisy_latents).prev_sample
+
+            # Decode the latents back to image space (VAE)
+            reconstructed_images = self.pipe.vae.decode(
+                reconstructed_latents / self.pipe.vae.config.scaling_factor
+            ).sample
+
+            # Keep outputs on the GPU; callers transfer or cast as needed
+            return latents, noise, reconstructed_images
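
Note on the inversion round trip: the sketch below is a minimal, self-contained illustration (not part of this patch) of the add_noise / step contract that extract_features relies on. A default DDIMScheduler stands in for the SD v1.5 scheduler config, and the shapes and tolerance are assumptions. When the "predicted" noise fed to scheduler.step equals the noise that add_noise injected, the scheduler's estimate of the original sample recovers the input latents almost exactly; a U-Net conditioned on a matching caption approximates this for diffusion-generated images but not for real ones, which is the gap the ResNet-50 head classifies.

import torch
from diffusers import DDIMScheduler

# Illustrative stand-in for the SD v1.5 scheduler config; clip_sample is
# disabled so the recovered latents are not clamped to [-1, 1].
scheduler = DDIMScheduler(clip_sample=False)  # default: 1000 training timesteps
scheduler.set_timesteps(50)

latents = torch.randn(2, 4, 64, 64)   # stand-in for VAE latents of a 2-image batch
noise = torch.randn_like(latents)
t = torch.full((2,), 49, dtype=torch.long)

# Forward diffusion to timestep index 49 (a low noise level), as in extract_features
noisy = scheduler.add_noise(latents, noise, t)

# Feeding the injected noise back as the "model output" makes the scheduler's
# x0 estimate reproduce the original latents up to floating-point error
out = scheduler.step(noise, 49, noisy)
print(torch.allclose(out.pred_original_sample, latents, atol=1e-4))  # True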
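
A hypothetical end-to-end call of the new batch API, for illustration only; the checkpoint path, the image paths, and the FileImageInstance constructor arguments are assumptions rather than part of this change:

# Hypothetical usage of FakeInversion.predict_batch; paths are placeholders.
from deepfake_detection.data import FileImageInstance
from deepfake_detection.models.detection.fakeinversion import FakeInversion

model = FakeInversion(ckpt="weights/fakeinversion_resnet50.pt", device="cuda")
instances = [FileImageInstance("samples/photo.jpg"),
             FileImageInstance("samples/sd_render.png")]

# The first call loads BLIP, Stable Diffusion v1.5, and the ResNet-50 head, then
# returns one Prediction per instance with class scores, the VAE latents,
# the generated caption, and the DDIM-reconstructed image.
for pred in model.predict_batch(instances):
    print(pred.classification, pred.text)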