handler.py
Browse files- handler.py +50 -0
    	
        handler.py
    ADDED
    
    | @@ -0,0 +1,50 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import  Dict, List, Any
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, AutoPipelineForInpainting, AutoPipelineForImage2Image, StableDiffusionXLImg2ImgPipeline
         | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
            import base64
         | 
| 6 | 
            +
            from io import BytesIO
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            # set device
         | 
| 10 | 
            +
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            if device.type != 'cuda':
         | 
| 13 | 
            +
                raise ValueError("need to run on GPU")
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            class EndpointHandler():
         | 
| 16 | 
            +
                def __init__(self, path=""):
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                    #self.fast_pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to("cuda")
         | 
| 19 | 
            +
                    #self.generator = torch.Generator(device="cuda").manual_seed(0)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
                    self.smooth_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
         | 
| 23 | 
            +
                      "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
         | 
| 24 | 
            +
                    )
         | 
| 25 | 
            +
                    self.smooth_pipe.to("cuda")
         | 
| 26 | 
            +
                    
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
                def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
                    :param data: A dictionary contains `inputs` and optional `image` field.
         | 
| 32 | 
            +
                    :return: A dictionary with `image` field contains image in base64.
         | 
| 33 | 
            +
                    """
         | 
| 34 | 
            +
                    encoded_image = data.pop("image", None)
         | 
| 35 | 
            +
                    
         | 
| 36 | 
            +
                    prompt = data.pop("prompt", "")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
             | 
| 39 | 
            +
                    if encoded_image is not None:
         | 
| 40 | 
            +
                        image = self.decode_base64_image(encoded_image)
         | 
| 41 | 
            +
                        out = self.smooth_pipe(prompt, image=image).images[0]
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                        return out
         | 
| 44 | 
            +
                
         | 
| 45 | 
            +
                # helper to decode input image
         | 
| 46 | 
            +
                def decode_base64_image(self, image_string):
         | 
| 47 | 
            +
                    base64_image = base64.b64decode(image_string)
         | 
| 48 | 
            +
                    buffer = BytesIO(base64_image)
         | 
| 49 | 
            +
                    image = Image.open(buffer)
         | 
| 50 | 
            +
                    return image
         |