diff --git "a/create_handler.ipynb" "b/create_handler.ipynb" new file mode 100644--- /dev/null +++ "b/create_handler.ipynb" @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Custom Handler for Inference Endpoints\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: diffusers in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (0.9.0)\n", + "Collecting diffusers\n", + " Using cached diffusers-0.10.2-py3-none-any.whl (503 kB)\n", + "Requirement already satisfied: importlib-metadata in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (4.11.4)\n", + "Requirement already satisfied: filelock in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (3.8.0)\n", + "Requirement already satisfied: numpy in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (1.22.4)\n", + "Requirement already satisfied: Pillow in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (9.2.0)\n", + "Requirement already satisfied: huggingface-hub>=0.10.0 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (0.11.1)\n", + "Requirement already satisfied: requests in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (2.28.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (2022.7.25)\n", + "Requirement already satisfied: packaging>=20.9 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (21.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (4.3.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (6.0)\n", + "Requirement already satisfied: tqdm in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (4.64.0)\n", + "Requirement already satisfied: zipp>=0.5 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from importlib-metadata->diffusers) (3.8.1)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (3.3)\n", + "Requirement already satisfied: charset-normalizer<3,>=2 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (2.1.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (2022.6.15)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (1.26.11)\n", + "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from packaging>=20.9->huggingface-hub>=0.10.0->diffusers) (3.0.9)\n", + "Installing collected packages: diffusers\n", + " Attempting uninstall: diffusers\n", + " Found existing installation: diffusers 0.9.0\n", + " Uninstalling diffusers-0.9.0:\n", + " Successfully uninstalled diffusers-0.9.0\n", + "Successfully installed diffusers-0.10.2\n" + ] + } + ], + "source": [ + "!pip install diffusers --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "\n", + "if device.type != 'cuda':\n", + " raise ValueError(\"need to run on GPU\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting handler.py\n" + ] + } + ], + "source": [ + "%%writefile handler.py\n", + "from typing import Dict, List, Any\n", + "import torch\n", + "from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline\n", + "from PIL import Image\n", + "import base64\n", + "from io import BytesIO\n", + "\n", + "\n", + "# set device\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "if device.type != 'cuda':\n", + " raise ValueError(\"need to run on GPU\")\n", + "\n", + "class EndpointHandler():\n", + " def __init__(self, path=\"\"):\n", + " # load StableDiffusionInpaintPipeline pipeline\n", + " self.pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)\n", + " # use DPMSolverMultistepScheduler\n", + " self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)\n", + " # move to device \n", + " self.pipe = self.pipe.to(device)\n", + "\n", + "\n", + " def __call__(self, data: Any) -> List[List[Dict[str, float]]]:\n", + " \"\"\"\n", + " :param data: A dictionary contains `inputs` and optional `image` field.\n", + " :return: A dictionary with `image` field contains image in base64.\n", + " \"\"\"\n", + " inputs = data.pop(\"inputs\", data)\n", + " encoded_image = data.pop(\"image\", None)\n", + " encoded_mask_image = data.pop(\"mask_image\", None)\n", + " \n", + " # hyperparamters\n", + " num_inference_steps = data.pop(\"num_inference_steps\", 25)\n", + " guidance_scale = data.pop(\"guidance_scale\", 7.5)\n", + " negative_prompt = data.pop(\"negative_prompt\", None)\n", + " height = data.pop(\"height\", None)\n", + " width = data.pop(\"width\", None)\n", + " \n", + " # process image\n", + " if encoded_image is not None and encoded_mask_image is not None:\n", + " image = self.decode_base64_image(encoded_image)\n", + " mask_image = self.decode_base64_image(encoded_mask_image)\n", + " else:\n", + " image = None\n", + " mask_image = None \n", + " \n", + " # run inference pipeline\n", + " out = self.pipe(inputs, \n", + " image=image, \n", + " mask_image=mask_image, \n", + " num_inference_steps=num_inference_steps,\n", + " guidance_scale=guidance_scale,\n", + " num_images_per_prompt=1,\n", + " negative_prompt=negative_prompt,\n", + " height=height,\n", + " width=width\n", + " )\n", + " \n", + " # return first generate PIL image\n", + " return out.images[0]\n", + " \n", + " # helper to decode input image\n", + " def decode_base64_image(self, image_string):\n", + " base64_image = base64.b64decode(image_string)\n", + " buffer = BytesIO(base64_image)\n", + " image = Image.open(buffer)\n", + " return image" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from handler import EndpointHandler\n", + "\n", + "# init handler\n", + "my_handler = EndpointHandler(path=\".\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8523b8998b74472ead35a11270dff3a5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/25 [00:00" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "pred.save(\"result.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.13 ('dev': conda)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}