|
from diffusers.modular_pipelines import ( |
|
PipelineBlock, |
|
InputParam, |
|
OutputParam, |
|
ConfigSpec, |
|
) |
|
|
|
from diffusers.utils import load_image |
|
from PIL import Image |
|
from typing import Union, Tuple |
|
import numpy as np |
|
|
|
|
|
|
|
class Wan14BImageProcessor(PipelineBlock):
    """Load and resize the conditioning image for Wan 14B image-to-video.

    The image is resized, preserving its aspect ratio, so that its area is
    approximately ``max_area`` pixels. Each side is then snapped down to a
    multiple of ``vae_stride[i] * patch_size[i]`` along that axis.
    """

    # Area used when the caller does not provide ``max_area``.
    _DEFAULT_MAX_AREA = 480 * 832

    @property
    def description(self):
        return (
            "default Image Processor for wan14B i2v (for both Wan2.1 and Wan2.2); "
            "it resizes the image to roughly `max_area` pixels while preserving its aspect ratio"
        )

    @property
    def inputs(self):
        return [
            InputParam(name="image", type_hint=Union[Image.Image, str], description="the Image to process"),
            InputParam(
                name="max_area",
                type_hint=int,
                description="the maximum area of the Image to process (defaults to 480 * 832)",
            ),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(name="processed_image", type_hint=Image.Image, description="the processed Image"),
        ]

    @property
    def expected_configs(self):
        return [
            ConfigSpec(name="patch_size", default=(1, 2, 2)),
            ConfigSpec(name="vae_stride", default=(4, 8, 8)),
        ]

    @staticmethod
    def _load_rgb_image(image) -> Image.Image:
        """Return ``image`` as an RGB PIL image, loading it via ``load_image`` if it is a string."""
        if isinstance(image, str):
            return load_image(image).convert("RGB")
        if isinstance(image, Image.Image):
            # Normalise the mode (e.g. RGBA / L / P) so both input kinds yield RGB,
            # matching the string branch above.
            return image.convert("RGB")
        raise ValueError(f"Invalid image type: {type(image)}; only support PIL Image or url string")

    @staticmethod
    def _compute_target_size(width, height, max_area, mod_value_height, mod_value_width) -> Tuple[int, int]:
        """Return ``(width, height)`` with area ~``max_area``, same aspect ratio, sides snapped to the mods."""
        aspect_ratio = height / width
        # int() guarantees plain Python ints (PIL's resize requires ints) regardless
        # of what round() returns for a NumPy scalar.
        target_height = int(round(np.sqrt(max_area * aspect_ratio))) // mod_value_height * mod_value_height
        target_width = int(round(np.sqrt(max_area / aspect_ratio))) // mod_value_width * mod_value_width
        # Never snap a side down to 0: a tiny max_area or an extreme aspect ratio
        # would otherwise request a zero-sized image.
        return max(target_width, mod_value_width), max(target_height, mod_value_height)

    def __call__(self, components, state):
        """Resize ``image`` from the block state and store it as ``processed_image``.

        Args:
            components: Pipeline components; must expose ``vae_stride`` and
                ``patch_size`` (declared in ``expected_configs``), indexed at 1
                (height axis) and 2 (width axis).
            state: Pipeline state holding this block's ``image`` and ``max_area`` inputs.

        Returns:
            The ``(components, state)`` pair, with ``processed_image`` written to the block state.

        Raises:
            ValueError: If ``image`` is neither a PIL image nor a string, or if
                ``max_area`` is not positive.
        """
        import logging

        logger = logging.getLogger(__name__)
        block_state = self.get_block_state(state)

        image = self._load_rgb_image(block_state.image)

        max_area = block_state.max_area
        if max_area is None:
            max_area = self._DEFAULT_MAX_AREA
        if max_area <= 0:
            raise ValueError(f"max_area must be a positive integer, got {max_area}")

        width, height = self._compute_target_size(
            image.width,
            image.height,
            max_area,
            components.vae_stride[1] * components.patch_size[1],
            components.vae_stride[2] * components.patch_size[2],
        )
        resized_image = image.resize((width, height))

        logger.debug("initial image size: %s", image.size)
        logger.debug("processed image size: %s", resized_image.size)

        block_state.processed_image = resized_image
        self.set_block_state(state, block_state)
        return components, state