From 92ec3e7e5cc21b816540d41dae592b75c4de4568 Mon Sep 17 00:00:00 2001
From: Ross
Date: Thu, 23 Nov 2023 15:57:01 +0000
Subject: [PATCH] Added SD Inpainting

---
 ...etrained_sd_inpainting_512_inference.ipynb | 406 ++++++++++++++++++
 1 file changed, 406 insertions(+)
 create mode 100644 torch-neuronx/inference/hf_pretrained_sd_inpainting_512_inference.ipynb

diff --git a/torch-neuronx/inference/hf_pretrained_sd_inpainting_512_inference.ipynb b/torch-neuronx/inference/hf_pretrained_sd_inpainting_512_inference.ipynb
new file mode 100644
index 0000000..e413b41
--- /dev/null
+++ b/torch-neuronx/inference/hf_pretrained_sd_inpainting_512_inference.ipynb
@@ -0,0 +1,406 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## HuggingFace Stable Diffusion Inpainting (512x512) Inference on Inf2"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Introduction**\n",
+    "\n",
+    "This notebook demonstrates how to compile and run the HuggingFace Stable Diffusion Inpainting (512x512) model for accelerated inference on Neuron.\n",
+    "\n",
+    "This Jupyter notebook should be run on an Inf2 instance (`inf2.8xlarge` or larger)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the [PyTorch Installation Guide](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/setup/torch-neuronx.html#setup-torch-neuronx). You can select the kernel from the 'Kernel -> Change Kernel' option at the top of this Jupyter notebook page."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Install Dependencies**\n",
+    "\n",
+    "This tutorial requires the following pip packages to be installed:\n",
+    "- `torch-neuronx`\n",
+    "- `neuronx-cc`\n",
+    "- `diffusers==0.20.0`\n",
+    "- `transformers==4.26.1`\n",
+    "- `accelerate==0.16.0`\n",
+    "- `matplotlib`\n",
+    "\n",
+    "`torch-neuronx` and `neuronx-cc` will be installed when you configure your environment following the Inf2 setup guide. The remaining dependencies can be installed below:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Suppresses tokenizer warnings, making errors easier to detect\n",
+    "%env TOKENIZERS_PARALLELISM=True\n",
+    "!pip install diffusers==0.20.0 transformers==4.26.1 accelerate==0.16.0 matplotlib"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Imports**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch_neuronx\n",
+    "from diffusers import StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler\n",
+    "from diffusers.models.unet_2d_condition import UNet2DConditionOutput\n",
+    "from diffusers.models.attention_processor import Attention\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "from matplotlib import image as mpimg\n",
+    "import time\n",
+    "import copy\n",
+    "from IPython.display import clear_output\n",
+    "from PIL import Image\n",
+    "\n",
+    "clear_output(wait=False)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Define utility classes and functions**\n",
+    "\n",
+    "The following section defines some utility classes and functions. In particular, we define a double wrapper for the UNet. These wrappers enable `torch_neuronx.trace` to trace the wrapped model for compilation with the Neuron compiler. In addition, the `get_attention_scores_neuron` utility function performs an optimized attention score calculation and is used to replace the original `get_attention_scores` function in the `diffusers` package via a monkey patch (see the next code block under \"Compile UNet and save\" for usage)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_attention_scores_neuron(self, query, key, attn_mask):\n",
+    "    if query.size() == key.size():\n",
+    "        attention_scores = custom_badbmm(\n",
+    "            key,\n",
+    "            query.transpose(-1, -2),\n",
+    "            self.scale\n",
+    "        )\n",
+    "        attention_probs = attention_scores.softmax(dim=1).permute(0, 2, 1)\n",
+    "\n",
+    "    else:\n",
+    "        attention_scores = custom_badbmm(\n",
+    "            query,\n",
+    "            key.transpose(-1, -2),\n",
+    "            self.scale\n",
+    "        )\n",
+    "        attention_probs = attention_scores.softmax(dim=-1)\n",
+    "\n",
+    "    return attention_probs\n",
+    "\n",
+    "\n",
+    "def custom_badbmm(a, b, scale):\n",
+    "    # Batched matrix multiply followed by scaling, equivalent to baddbmm with beta=0\n",
+    "    bmm = torch.bmm(a, b)\n",
+    "    scaled = bmm * scale\n",
+    "    return scaled\n",
+    "\n",
+    "\n",
+    "class UNetWrap(nn.Module):\n",
+    "    def __init__(self, unet):\n",
+    "        super().__init__()\n",
+    "        self.unet = unet\n",
+    "\n",
+    "    def forward(self, sample, timestep, encoder_hidden_states, text_embeds=None, time_ids=None):\n",
+    "        out_tuple = self.unet(sample,\n",
+    "                              timestep,\n",
+    "                              encoder_hidden_states,\n",
+    "                              return_dict=False)\n",
+    "        return out_tuple\n",
+    "\n",
+    "\n",
+    "class NeuronUNet(nn.Module):\n",
+    "    def __init__(self, unetwrap):\n",
+    "        super().__init__()\n",
+    "        self.unetwrap = unetwrap\n",
+    "        self.config = unetwrap.unet.config\n",
+    "        self.in_channels = unetwrap.unet.in_channels\n",
+    "        self.device = unetwrap.unet.device\n",
+    "\n",
+    "    def forward(self, sample, timestep, encoder_hidden_states, added_cond_kwargs=None, return_dict=False, cross_attention_kwargs=None):\n",
+    "        sample = self.unetwrap(sample,\n",
+    "                               timestep.float().expand((sample.shape[0],)),\n",
+    "                               encoder_hidden_states)[0]\n",
+    "        return UNet2DConditionOutput(sample=sample)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "COMPILER_WORKDIR_ROOT = 'sd_inpainting_compile_dir_512'\n",
+    "\n",
+    "# Model ID for the Stable Diffusion inpainting pipeline\n",
+    "model_id = \"runwayml/stable-diffusion-inpainting\"\n",
+    "\n",
+    "# --- Compile VAE decoder and save ---\n",
+    "\n",
+    "# Only keep the model being compiled in RAM to minimize memory pressure\n",
+    "pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)\n",
+    "decoder = copy.deepcopy(pipe.vae.decoder)\n",
+    "del pipe\n",
+    "\n",
+    "# Compile VAE decoder\n",
+    "decoder_in = torch.randn([1, 4, 64, 64])\n",
+    "decoder_neuron = torch_neuronx.trace(\n",
+    "    decoder,\n",
+    "    decoder_in,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder'),\n",
+    ")\n",
+    "\n",
+    "# Save the compiled VAE decoder\n",
+    "decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')\n",
+    "torch.jit.save(decoder_neuron, decoder_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del decoder\n",
+    "\n",
+    "# --- Compile UNet and save ---\n",
+    "\n",
+    "pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)\n",
+    "\n",
+    "# Replace the original cross-attention module with the custom cross-attention module for better performance\n",
+    "Attention.get_attention_scores = get_attention_scores_neuron\n",
+    "\n",
+    "# Apply double wrapper to deal with custom return type\n",
+    "pipe.unet = NeuronUNet(UNetWrap(pipe.unet))\n",
+    "\n",
+    "# Only keep the model being compiled in RAM to minimize memory pressure\n",
+    "unet = copy.deepcopy(pipe.unet.unetwrap)\n",
+    "del pipe\n",
+    "\n",
+    "# Compile UNet - FP32\n",
+    "# The inpainting UNet takes 9 input channels: 4 latent + 4 masked-image latent + 1 mask\n",
+    "sample_1b = torch.randn([1, 9, 64, 64])\n",
+    "timestep_1b = torch.tensor(999).float().expand((1,))\n",
+    "encoder_hidden_states_1b = torch.randn([1, 77, 768])\n",
+    "example_inputs = (sample_1b, timestep_1b, encoder_hidden_states_1b)\n",
+    "\n",
+    "unet_neuron = torch_neuronx.trace(\n",
+    "    unet,\n",
+    "    example_inputs,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'unet'),\n",
+    "    compiler_args=[\"--model-type=unet-inference\"]\n",
+    ")\n",
+    "\n",
+    "# Save the compiled UNet\n",
+    "unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')\n",
+    "torch.jit.save(unet_neuron, unet_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del unet\n",
+    "\n",
+    "\n",
+    "# --- Compile VAE post_quant_conv and save ---\n",
+    "\n",
+    "# Only keep the model being compiled in RAM to minimize memory pressure\n",
+    "pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)\n",
+    "post_quant_conv = copy.deepcopy(pipe.vae.post_quant_conv)\n",
+    "del pipe\n",
+    "\n",
+    "# Compile VAE post_quant_conv\n",
+    "post_quant_conv_in = torch.randn([1, 4, 64, 64])\n",
+    "post_quant_conv_neuron = torch_neuronx.trace(\n",
+    "    post_quant_conv,\n",
+    "    post_quant_conv_in,\n",
+    "    compiler_workdir=os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv'),\n",
+    ")\n",
+    "\n",
+    "# Save the compiled VAE post_quant_conv\n",
+    "post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')\n",
+    "torch.jit.save(post_quant_conv_neuron, post_quant_conv_filename)\n",
+    "\n",
+    "# Delete unused objects\n",
+    "del post_quant_conv"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Load the saved models and run them**\n",
+    "\n",
+    "Now that the models are compiled, you can reload them and run them with any number of prompts. Note the use of the `torch_neuronx.DataParallel` API to load the UNet onto two Neuron cores for data-parallel inference. Currently, the UNet is the only part of the pipeline that runs data-parallel on two cores; all other parts of the pipeline run on a single Neuron core.\n",
+    "\n",
+    "Edit the prompts below to see what you can create."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "# --- Load all compiled models ---\n",
+    "COMPILER_WORKDIR_ROOT = 'sd_inpainting_compile_dir_512'\n",
+    "model_id = \"runwayml/stable-diffusion-inpainting\"\n",
+    "decoder_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_decoder/model.pt')\n",
+    "unet_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'unet/model.pt')\n",
+    "post_quant_conv_filename = os.path.join(COMPILER_WORKDIR_ROOT, 'vae_post_quant_conv/model.pt')\n",
+    "\n",
+    "pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, torch_dtype=torch.float32)\n",
+    "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
+    "\n",
+    "# Load the compiled UNet onto two Neuron cores.\n",
+    "pipe.unet = NeuronUNet(UNetWrap(pipe.unet))\n",
+    "device_ids = [0, 1]\n",
+    "pipe.unet.unetwrap = torch_neuronx.DataParallel(torch.jit.load(unet_filename), device_ids, set_dynamic_batching=False)\n",
+    "\n",
+    "# Load the other compiled models onto a single Neuron core.\n",
+    "pipe.vae.decoder = torch.jit.load(decoder_filename)\n",
+    "pipe.vae.post_quant_conv = torch.jit.load(post_quant_conv_filename)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Load the Image and the Image Mask\n",
+    "`image` and `mask_image` should be PIL images. The mask should be white in the regions to inpaint and black in the regions to keep as is.\n",
+    "\n",
+    "See https://huggingface.co/runwayml/stable-diffusion-inpainting for sample images.\n",
+    "\n",
+    "Save the image and the image mask locally and modify the filenames below as needed, or use the optional download cell that follows."
+   ]
+  },
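+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If you do not already have an image and mask saved locally, the optional cell below sketches one way to fetch an example 512x512 image/mask pair. The URLs are an assumption taken from the `diffusers` inpainting examples rather than part of this tutorial; substitute your own files if they are unavailable."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional: download an example image/mask pair (assumed URLs from the diffusers inpainting examples).\n",
+    "# Skip this cell if you have already saved 'image.png' and 'image_mask.png' yourself.\n",
+    "import requests\n",
+    "\n",
+    "img_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png\"\n",
+    "mask_url = \"https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png\"\n",
+    "\n",
+    "for url, filename in [(img_url, 'image.png'), (mask_url, 'image_mask.png')]:\n",
+    "    response = requests.get(url, timeout=60)\n",
+    "    response.raise_for_status()\n",
+    "    # Write the downloaded PNG bytes to the filename expected by the next cell\n",
+    "    with open(filename, 'wb') as f:\n",
+    "        f.write(response.content)"
+   ]
+  },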
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# File names of the image and mask to load\n",
+    "image_filename = 'image.png'\n",
+    "mask_image_filename = 'image_mask.png'\n",
+    "\n",
+    "# Load images\n",
+    "image = Image.open(image_filename)\n",
+    "mask_image = Image.open(mask_image_filename)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Prompt\n",
+    "prompt = \"Face of a yellow cat, high resolution, sitting on a park bench\"\n",
+    "\n",
+    "# Run pipeline\n",
+    "total_time = 0\n",
+    "start_time = time.time()\n",
+    "image_output = pipe(prompt=prompt, image=image, mask_image=mask_image).images[0]\n",
+    "total_time = total_time + (time.time() - start_time)\n",
+    "image_output.save('image_output.png')\n",
+    "plt.title(\"Image\")\n",
+    "plt.xlabel(\"X pixels scaling\")\n",
+    "plt.ylabel(\"Y pixels scaling\")\n",
+    "plt.imshow(image_output)\n",
+    "plt.show()\n",
+    "print(\"Total time: \", np.round(total_time, 2), \"seconds\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Now have fun**\n",
+    "\n",
+    "Uncomment the cell below to experiment interactively with different prompts."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "# user_input = \"\"\n",
+    "# print(\"Enter a prompt, or type exit to quit\")\n",
+    "# while user_input != \"exit\":\n",
+    "#     total_time = 0\n",
+    "#     user_input = input(\"What prompt would you like to give? \")\n",
+    "#     if user_input == \"exit\":\n",
+    "#         break\n",
+    "#     start_time = time.time()\n",
+    "#     image_output = pipe(prompt=user_input, image=image, mask_image=mask_image).images[0]\n",
+    "#     total_time = total_time + (time.time() - start_time)\n",
+    "#     image_output.save(\"image_output.png\")\n",
+    "\n",
+    "#     plt.title(\"Image\")\n",
+    "#     plt.xlabel(\"X pixels scaling\")\n",
+    "#     plt.ylabel(\"Y pixels scaling\")\n",
+    "\n",
+    "#     image_output = mpimg.imread(\"image_output.png\")\n",
+    "#     plt.imshow(image_output)\n",
+    "#     plt.show()\n",
+    "#     print(\"Total time: \", np.round(total_time, 2), \"seconds\")\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}