YiYiXu's picture
Update block.py
829f347 verified
import logging
from typing import Union, Tuple

import numpy as np
from diffusers.modular_pipelines import (
    PipelineBlock,
    InputParam,
    OutputParam,
    ConfigSpec,
)
from diffusers.utils import load_image
from PIL import Image
logger = logging.getLogger(__name__)


class Wan14BImageProcessor(PipelineBlock):
    """Resize the conditioning image for Wan 14B image-to-video (Wan2.1 and Wan2.2).

    The image is scaled so that its area is approximately ``max_area`` while
    keeping its aspect ratio. Each side is then floored to a multiple of
    ``vae_stride * patch_size`` along that spatial axis (never below one such
    multiple). The result is stored as ``processed_image`` on the block state.
    """

    # Pixel budget used when the caller does not supply ``max_area`` (480 x 832).
    _DEFAULT_MAX_AREA = 480 * 832

    @property
    def description(self):
        return (
            "default Image Processor for wan14B i2v (for both Wan2.1 and Wan2.2); "
            "it resizes the image to roughly max_area pixels with each side a "
            "multiple of vae_stride * patch_size"
        )

    @property
    def inputs(self):
        return [
            InputParam(name="image", type_hint=Union[Image.Image, str], description="the Image to process"),
            InputParam(name="max_area", type_hint=int, description="the maximum area of the Image to process"),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(name="processed_image", type_hint=Image.Image, description="the processed Image"),
        ]

    @property
    def expected_configs(self):
        return [
            ConfigSpec(name="patch_size", default=(1, 2, 2)),
            ConfigSpec(name="vae_stride", default=(4, 8, 8)),
        ]

    @staticmethod
    def _load_rgb_image(image):
        """Return ``image`` as an RGB PIL image, loading it first when it is a string.

        Raises:
            ValueError: if ``image`` is neither a PIL image nor a string, or has a zero-sized side.
        """
        if isinstance(image, str):
            image = load_image(image)
        elif not isinstance(image, Image.Image):
            raise ValueError(f"Invalid image type: {type(image)}; only support PIL Image or url string")
        # Normalize mode for both input kinds so downstream code always sees 3-channel RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")
        if image.width == 0 or image.height == 0:
            raise ValueError(f"Image must have non-zero width and height, got size {image.size}")
        return image

    @staticmethod
    def _snap_down(length, multiple):
        """Round ``length`` to an int, floor it to a multiple of ``multiple``, and keep at least ``multiple``."""
        # int() guarantees a plain Python int for PIL even if round() returns a NumPy scalar.
        snapped = int(round(length)) // multiple * multiple
        # Without this floor, extreme aspect ratios / small areas would yield a 0-pixel side.
        return max(multiple, snapped)

    def __call__(self, components, state):
        """Resize ``block_state.image`` and store it as ``block_state.processed_image``.

        Args:
            components: object exposing ``vae_stride`` and ``patch_size`` sequences;
                index 1 is used for height and index 2 for width.
            state: pipeline state read/written via ``get_block_state``/``set_block_state``.

        Returns:
            The ``(components, state)`` pair, with ``state`` updated.

        Raises:
            ValueError: on an unsupported image type, a zero-sized image, or a non-positive ``max_area``.
        """
        block_state = self.get_block_state(state)

        image = self._load_rgb_image(block_state.image)

        max_area = self._DEFAULT_MAX_AREA if block_state.max_area is None else block_state.max_area
        if max_area <= 0:
            raise ValueError(f"max_area must be positive, got {max_area}")

        mod_value_height = components.vae_stride[1] * components.patch_size[1]
        mod_value_width = components.vae_stride[2] * components.patch_size[2]

        # Choose (h, w) with h / w == aspect_ratio and h * w == max_area, then snap each side.
        aspect_ratio = image.height / image.width
        height = self._snap_down(np.sqrt(max_area * aspect_ratio), mod_value_height)
        width = self._snap_down(np.sqrt(max_area / aspect_ratio), mod_value_width)

        resized_image = image.resize((width, height))
        block_state.processed_image = resized_image

        logger.debug("initial image size: %s", image.size)
        logger.debug("processed image size: %s", resized_image.size)

        self.set_block_state(state, block_state)
        return components, state