# Model_Blend_Node_w_Workflow / MultiMergeSimple.py
# Uploaded by ND911 (commit 4432cc6, "Upload 3 files").
import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management
import comfy.model_sampling
import torch
import folder_paths # Assuming this is available from the original context
import json
import os
# Use ComfyUI's parsed command-line arguments when they can be imported; otherwise
# fall back to a tiny stand-in so this module still imports outside a full ComfyUI.
try:
    from comfy.cli_args import args
except ImportError:  # e.g. running this file as a standalone script
    class ArgsMock:
        """Minimal stand-in for ``comfy.cli_args.args``.

        Only the ``disable_metadata`` flag is provided, and it defaults to False.
        """

        disable_metadata = False

    args = ArgsMock()
class ModelMergeSimple:
    """Blend the diffusion weights of two models.

    Every ``diffusion_model.`` key of the result is
    ``(1 - ratio) * model1 + ratio * model2``; everything else comes from
    ``model1``, which is the model that gets cloned. Per the original in-line
    note, an earlier version passed the strengths the other way round
    (``1.0 - ratio, ratio``); here ``ratio`` is the weight of ``model2``.
    """

    @classmethod
    def INPUT_TYPES(s):
        ratio_options = {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}
        return {
            "required": {
                "model1": ("MODEL",),
                "model2": ("MODEL",),
                "ratio": ("FLOAT", ratio_options),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        blended = model1.clone()
        keep_weight = 1.0 - ratio
        incoming = model2.get_key_patches("diffusion_model.")
        for key, patch in incoming.items():
            # add_patches(patches, strength_patch, strength_model): model2's
            # weights get `ratio`, the existing model1 weights get `1 - ratio`.
            blended.add_patches({key: patch}, ratio, keep_weight)
        return (blended,)
class ModelMergeMultiSimple:
    """Weighted merge of up to five models' diffusion weights in one node.

    Each ``model{i}`` input is paired with ``ratio{i}``. Ratios are relative
    weights: only models whose ratio is greater than 0 take part, and their
    ratios are normalized by their sum, so the merged diffusion weights are
    ``sum(ratio_i * W_i) / sum(ratio_i)`` over the participating models.
    Everything outside the ``diffusion_model.`` prefix is inherited from the
    first participating model, because that is the model that gets cloned.
    """

    # Number of model/ratio input pairs. INPUT_TYPES and merge_five both read it,
    # so the exposed inputs and the merge loop cannot drift apart.
    NUM_MODELS = 5

    @classmethod
    def INPUT_TYPES(s):
        inputs = {"required": {}}
        for i in range(1, s.NUM_MODELS + 1):
            inputs["required"][f"model{i}"] = ("MODEL",)
            inputs["required"][f"ratio{i}"] = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return inputs

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_five"  # distinct from ModelMergeSimple.merge to avoid confusion
    CATEGORY = "advanced/model_merging"

    def _collect_inputs(self, kwargs):
        """Return parallel ``(models, ratios)`` lists for every connected model slot.

        A slot whose model is missing contributes to neither list, so the i-th
        ratio always belongs to the i-th model. A missing ratio counts as 0.
        """
        models = []
        ratios = []
        for i in range(1, self.NUM_MODELS + 1):
            model = kwargs.get(f"model{i}")
            ratio = kwargs.get(f"ratio{i}") or 0.0
            if model is None:
                if ratio > 0:
                    print(f"Warning: Ratio {ratio} provided for model{i} but model is missing. Ignoring.")
                continue
            models.append(model)
            ratios.append(ratio)
        return models, ratios

    def merge_five(self, **kwargs):
        """Merge the connected models according to their ratios.

        Args:
            **kwargs: ``model1``..``model5`` (MODEL) and ``ratio1``..``ratio5`` (float).

        Returns:
            A 1-tuple holding the merged model. It is always a clone; the input
            models are only read.

        Raises:
            ValueError: if no model is connected at all.
        """
        models, ratios = self._collect_inputs(kwargs)
        if not models:
            raise ValueError("No models provided for merging.")

        # Zero-ratio models contribute nothing, so skip them entirely.
        active = [(model, ratio) for model, ratio in zip(models, ratios) if ratio > 0]
        if not active:
            print("Warning: All model ratios are 0. Returning the first provided model without changes.")
            return (models[0].clone(),)

        first_model, accumulated = active[0]
        merged_model = first_model.clone()
        # Fold the remaining models in one at a time:
        #   merged <- merged * acc / (acc + r) + next * r / (acc + r)
        # After each step the merged weights are the ratio-weighted mean of every
        # model folded in so far. Only the proportions between ratios matter, so
        # the raw ratios can be used without normalizing them first. Every ratio
        # in `active` is > 0, so `total` is never 0.
        for next_model, ratio in active[1:]:
            total = accumulated + ratio
            key_patches = next_model.get_key_patches("diffusion_model.")
            # add_patches(patches, strength_patch, strength_model) takes a dict of
            # key -> patch, so all keys are added in a single call.
            merged_model.add_patches(key_patches, ratio / total, accumulated / total)
            accumulated = total
        return (merged_model,)
# --- Remaining model-merging nodes, included for completeness ---
class ModelSubtract:
    """Subtract a scaled copy of ``model2``'s diffusion weights from ``model1``.

    Every ``diffusion_model.`` key of the result is
    ``model1 - multiplier * model2``; ``model1`` keeps full strength, and
    everything else comes from ``model1``, which is the model that gets cloned.

    NOTE(review): an earlier version of this node passed
    ``(-multiplier, multiplier)``, i.e. ``multiplier * (model1 - model2)``;
    this one intentionally keeps model1 unscaled. Confirm which is wanted.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model1": ("MODEL",),
                             "model2": ("MODEL",),
                             "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, multiplier):
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            # add_patches(patches, strength_patch, strength_model): keep model1's
            # weights at 1.0 and add model2's weights scaled by -multiplier.
            m.add_patches({k: kp[k]}, -multiplier, 1.0)
        return (m, )
class ModelAdd:
    """Sum two models' diffusion weights: ``model1 + model2``.

    Both sides keep full strength on every ``diffusion_model.`` key; everything
    else comes from ``model1``, which is the model that gets cloned.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model1": ("MODEL",),
                "model2": ("MODEL",),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2):
        combined = model1.clone()
        for key, patch in model2.get_key_patches("diffusion_model.").items():
            # strength_patch=1.0 for model2, strength_model=1.0 for model1
            combined.add_patches({key: patch}, 1.0, 1.0)
        return (combined,)
# --- Placeholder stand-ins for the remaining merging/saving nodes ---
# CLIPMergeSimple, CLIPSubtract, CLIPAdd, ModelMergeBlocks, save_checkpoint,
# CheckpointSave, CLIPSave, VAESave and ModelSave below are minimal placeholders
# that keep this file complete and importable as a standalone script (e.g. for
# testing); they do not perform real merges or saves.
# The original node mappings this file was derived from were:
#   "ModelMergeSimple": ModelMergeSimple, "ModelMergeBlocks": ModelMergeBlocks,
#   "ModelMergeSubtract": ModelSubtract, "ModelMergeAdd": ModelAdd, ...
# The new ModelMergeMultiSimple node must appear in NODE_CLASS_MAPPINGS and
# NODE_DISPLAY_NAME_MAPPINGS, which are defined at the end of this file.
class CLIPMergeSimple:
    """Placeholder for a two-CLIP weighted merge.

    NOTE: not implemented in this file. ``merge`` ignores ``clip2`` and
    ``ratio`` and returns ``clip1`` unchanged.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip1": ("CLIP",),
                "clip2": ("CLIP",),
                "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, ratio):
        # Placeholder: pass the first CLIP through untouched.
        result = (clip1,)
        return result
class CLIPSubtract:
    """Placeholder for subtracting one CLIP from another.

    NOTE: not implemented in this file. ``merge`` ignores ``clip2`` and
    ``multiplier`` and returns ``clip1`` unchanged.
    """

    @classmethod
    def INPUT_TYPES(s):
        multiplier_options = {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}
        return {
            "required": {
                "clip1": ("CLIP",),
                "clip2": ("CLIP",),
                "multiplier": ("FLOAT", multiplier_options),
            }
        }

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, multiplier):
        # Placeholder: pass the first CLIP through untouched.
        return (clip1,)
class CLIPAdd:
    """Placeholder for adding two CLIP models together.

    NOTE: not implemented in this file. ``merge`` ignores ``clip2`` and
    returns ``clip1`` unchanged.
    """

    @classmethod
    def INPUT_TYPES(s):
        clip_inputs = {"clip1": ("CLIP",), "clip2": ("CLIP",)}
        return {"required": clip_inputs}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2):
        # Placeholder: pass the first CLIP through untouched.
        return (clip1,)
class ModelMergeBlocks:
    """Placeholder for a two-model merge with separate input/middle/out ratios.

    NOTE: not implemented in this file. ``merge`` ignores ``model2`` and the
    ratios and returns ``model1`` unchanged.
    """

    @classmethod
    def INPUT_TYPES(s):
        required = {"model1": ("MODEL",), "model2": ("MODEL",)}
        for block_name in ("input", "middle", "out"):
            # A fresh options dict per entry, as in the original literal.
            required[block_name] = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return {"required": required}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        # Placeholder: pass the first model through untouched.
        return (model1,)
def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
    """Placeholder for checkpoint serialization.

    NOTE: not implemented in this file. All arguments are ignored, nothing is
    written, and the function returns None.
    """
    return None
class CheckpointSave:
    """Placeholder for an output node that saves model, CLIP and VAE as a checkpoint.

    NOTE: not implemented in this file. ``save`` writes nothing and returns an
    empty dict.
    """

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "clip": ("CLIP",),
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        # Placeholder: no checkpoint is written.
        return {}
class CLIPSave:
    """Placeholder for an output node that saves a CLIP model.

    NOTE: not implemented in this file. ``save`` writes nothing and returns an
    empty dict.
    """

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "clip": ("CLIP",),
            "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
        # Placeholder: nothing is written.
        return {}
class VAESave:
    """Placeholder for an output node that saves a VAE.

    NOTE: not implemented in this file. ``save`` writes nothing and returns an
    empty dict.
    """

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "vae": ("VAE",),
            "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        # Placeholder: nothing is written.
        return {}
class ModelSave:
    """Placeholder for an output node that saves a diffusion model.

    NOTE: not implemented in this file. ``save`` writes nothing and returns an
    empty dict.
    """

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
        }
        hidden = {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}
        return {"required": required, "hidden": hidden}

    def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
        # Placeholder: nothing is written.
        return {}
# Register only the node this file actually adds. The other classes defined
# above reuse ComfyUI's built-in node names but are placeholders or altered
# variants: CheckpointSave here saves nothing, and ModelMergeSimple swaps the
# ratio direction. Exporting them under those names risks shadowing the working
# built-ins, and their display names could relabel the built-in nodes.
# NOTE(review): whether ComfyUI lets a custom node override a built-in class
# name depends on its loader; leaving them out avoids relying on that.
NODE_CLASS_MAPPINGS = {
    "ModelMergeMultiSimple": ModelMergeMultiSimple,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ModelMergeMultiSimple": "Model Merge Multi Simple (5 Models)",
}

print("Custom model merging nodes loaded.")