issue in lora loading with nunchaku fp4 fluxkontext

#3
by tradetheunicorn - opened

I tried LoRA loading with the Nunchaku FP4 FluxKontext transformer (https://huggingface.co/mit-han-lab/nunchaku-flux.1-kontext-dev) for these LoRAs:

style_type_lora_dict = {
    "3D_Chibi": "3D_Chibi_lora_weights.safetensors",
    "American_Cartoon": "American_Cartoon_lora_weights.safetensors",
    "Chinese_Ink": "Chinese_Ink_lora_weights.safetensors",
    "Clay_Toy": "Clay_Toy_lora_weights.safetensors",
    "Fabric": "Fabric_lora_weights.safetensors",
    "Ghibli": "Ghibli_lora_weights.safetensors",
    "Irasutoya": "Irasutoya_lora_weights.safetensors",
    "Jojo": "Jojo_lora_weights.safetensors",
    "Oil_Painting": "Oil_Painting_lora_weights.safetensors",
    "Pixel": "Pixel_lora_weights.safetensors",
    "Snoopy": "Snoopy_lora_weights.safetensors",
    "Poly": "Poly_lora_weights.safetensors",
    "LEGO": "LEGO_lora_weights.safetensors",
    "Origami": "Origami_lora_weights.safetensors",
    "Pop_Art": "Pop_Art_lora_weights.safetensors",
    "Van_Gogh": "Van_Gogh_lora_weights.safetensors",
    "Paper_Cutting": "Paper_Cutting_lora_weights.safetensors",
    "Line": "Line_lora_weights.safetensors",
    "Vector": "Vector_lora_weights.safetensors",
    "Picasso": "Picasso_lora_weights.safetensors",
    "Macaron": "Macaron_lora_weights.safetensors",
    "Rick_Morty": "Rick_Morty_lora_weights.safetensors",
}

And I got the same error for each one:

Target modules {'attn.to_k', 'ff.net.0.proj', 'attn.add_q_proj', 'attn.add_k_proj', 'ff_context.net.0.proj', 'ff.net.2', 'ff_context.net.2', 'attn.to_v', 'attn.add_v_proj', 'attn.to_add_out', 'attn.to_q', 'attn.to_out.0'} not found in the base model. Please check the target modules and try again.

@Owen718 please let me know how I can resolve this.

I think the error you're encountering comes from loading the LoRA through the pipeline.load_lora_weights method:

pipeline.load_lora_weights("Kontext-Style-Loras/Ghibli_lora_weights.safetensors", adapter_name="lora")
pipeline.set_adapters(["lora"], adapter_weights=[1])

Instead, we used the following code, which avoids this issue:

transformer.update_lora_params(
    "Kontext-Style-Loras/Ghibli_lora_weights.safetensors"
)  # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(1)  # Your LoRA strength here

Here’s the complete code:

import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel, NunchakuT5EncoderModel
from nunchaku.utils import get_precision

transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
)
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")

pipeline = FluxKontextPipeline.from_pretrained(
    "FLUX.1-Kontext-dev", transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=torch.bfloat16
).to("cuda")

### LoRA Related Code ###
transformer.update_lora_params(
    "Kontext-Style-Loras/Ghibli_lora_weights.safetensors"
)  # Path to your LoRA safetensors, can also be a remote HuggingFace path
transformer.set_lora_strength(1)  # Your LoRA strength here
### End of LoRA Related Code ###

image = load_image(
    "ba8c75293f0f60353f6afb4b76e7eda0_input_image.png"
).convert("RGB")

# Prepare the prompt
# The style_name is used in the prompt and for the output filename.
style_name = "Ghibli"
prompt = f"convert this image into the {style_name} style."

# Run inference
result_image = pipeline(
    image=image, 
    prompt=prompt, 
    height=1024, 
    width=1024, 
    num_inference_steps=24,
    guidance_scale=2.5
).images[0]

# Save the result
output_filename = f"{style_name.replace(' ', '_')}.png"
result_image.save(output_filename)

print(f"Image saved as {output_filename}")
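If you want to run all of the styles from your dictionary in one go, something like the sketch below should work with the same Nunchaku API. The helper names (style_paths, render_all_styles) are mine, not part of any library, and it assumes the LoRA files sit in a local Kontext-Style-Loras/ directory and that update_lora_params replaces the previously loaded LoRA, so the pipeline only needs to be built once:

```python
import os

def style_paths(style_name, lora_dict, lora_dir="Kontext-Style-Loras"):
    """Return (lora_path, output_filename) for one entry of the style dict."""
    lora_path = os.path.join(lora_dir, lora_dict[style_name])
    output_filename = f"{style_name.replace(' ', '_')}.png"
    return lora_path, output_filename

def render_all_styles(pipeline, transformer, image, lora_dict):
    """Apply each style LoRA in turn and save one stylized image per style.

    Assumes `pipeline`, `transformer`, and `image` are set up exactly as in
    the script above, and that `update_lora_params` swaps out the previous
    LoRA rather than stacking on top of it.
    """
    for style_name in lora_dict:
        lora_path, output_filename = style_paths(style_name, lora_dict)
        transformer.update_lora_params(lora_path)
        transformer.set_lora_strength(1)
        result = pipeline(
            image=image,
            prompt=f"convert this image into the {style_name} style.",
            height=1024,
            width=1024,
            num_inference_steps=24,
            guidance_scale=2.5,
        ).images[0]
        result.save(output_filename)
```

Then render_all_styles(pipeline, transformer, image, style_type_lora_dict) would produce Ghibli.png, LEGO.png, and so on, one per key.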
