Commit 
							
							·
						
						5f12b41
	
1
								Parent(s):
							
							0127b45
								
up
Browse files- controlnet_img2img.py +77 -0
    	
        controlnet_img2img.py
    ADDED
    
    | @@ -0,0 +1,77 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python3
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            from huggingface_hub import HfApi
         | 
| 5 | 
            +
            from pathlib import Path
         | 
| 6 | 
            +
            from diffusers.utils import load_image
         | 
| 7 | 
            +
            import cv2
         | 
| 8 | 
            +
            from PIL import Image
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from diffusers import (
         | 
| 12 | 
            +
                ControlNetModel,
         | 
| 13 | 
            +
                StableDiffusionControlNetImg2ImgPipeline,
         | 
| 14 | 
            +
                StableDiffusionControlNetInpaintPipeline,
         | 
| 15 | 
            +
                DiffusionPipeline,
         | 
| 16 | 
            +
                UniPCMultistepScheduler,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
            import sys
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            checkpoint = sys.argv[1]
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            # image = load_image(
         | 
| 23 | 
            +
            #     "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png"
         | 
| 24 | 
            +
            # )
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
         | 
| 27 | 
            +
            mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
         | 
| 28 | 
            +
            image = load_image(img_url).resize((512, 512))
         | 
| 29 | 
            +
            mask_image = load_image(mask_url).resize((512, 512))
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            np_image = np.array(image)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            low_threshold = 100
         | 
| 34 | 
            +
            high_threshold = 200
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            np_image = cv2.Canny(np_image, low_threshold, high_threshold)
         | 
| 37 | 
            +
            np_image = np_image[:, :, None]
         | 
| 38 | 
            +
            np_image = np.concatenate([np_image, np_image, np_image], axis=2)
         | 
| 39 | 
            +
            canny_image = Image.fromarray(np_image)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            controlnet = ControlNetModel.from_pretrained(checkpoint, torch_dtype=torch.float16)
         | 
| 42 | 
            +
            # pipe = DiffusionPipeline.from_pretrained(
         | 
| 43 | 
            +
            #    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, custom_pipeline="stable_diffusion_controlnet_inpaint"
         | 
| 44 | 
            +
            # )
         | 
| 45 | 
            +
            pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
         | 
| 46 | 
            +
                "runwayml/stable-diffusion-inpainting",
         | 
| 47 | 
            +
                controlnet=controlnet,
         | 
| 48 | 
            +
                torch_dtype=torch.float16,
         | 
| 49 | 
            +
            )
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
         | 
| 52 | 
            +
            pipe.enable_model_cpu_offload()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            generator = torch.manual_seed(0)
         | 
| 55 | 
            +
            text_prompt="a blue dog"
         | 
| 56 | 
            +
            # out_image = pipe("A blue dog", num_inference_steps=50, generator=generator, image=image, mask_image=mask_image, controlnet_conditioning_image=canny_image).images[0]
         | 
| 57 | 
            +
            out_image = pipe(
         | 
| 58 | 
            +
                text_prompt,
         | 
| 59 | 
            +
                num_inference_steps=20,
         | 
| 60 | 
            +
                generator=generator,
         | 
| 61 | 
            +
                image=image,
         | 
| 62 | 
            +
                mask_image=mask_image,
         | 
| 63 | 
            +
                control_image=canny_image,
         | 
| 64 | 
            +
            ).images[0]
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            path = os.path.join(Path.home(), "images", "aa.png")
         | 
| 67 | 
            +
            out_image.save(path)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            api = HfApi()
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            api.upload_file(
         | 
| 72 | 
            +
                path_or_fileobj=path,
         | 
| 73 | 
            +
                path_in_repo=path.split("/")[-1],
         | 
| 74 | 
            +
                repo_id="patrickvonplaten/images",
         | 
| 75 | 
            +
                repo_type="dataset",
         | 
| 76 | 
            +
            )
         | 
| 77 | 
            +
            print("https://huggingface.co/datasets/patrickvonplaten/images/blob/main/aa.png")
         | 

