import comfy.sd
import comfy.utils
import comfy.model_base
import comfy.model_management
import comfy.model_sampling
import torch
import folder_paths  # Assuming this is available from the original context
import json
import os

# `args` comes from a full ComfyUI install; fall back to a minimal stand-in so
# this module can still be imported (e.g. for standalone testing) without it.
try:
    from comfy.cli_args import args
except ImportError:
    class ArgsMock:
        disable_metadata = False

    args = ArgsMock()


class ModelMergeSimple:
    """Blend two models' diffusion weights.

    Result per key: ``model1 * ratio + model2 * (1 - ratio)``. ``ratio`` is the
    weight of *model1* (1.0 keeps model1, 0.0 yields model2). This is the
    original ``(1.0 - ratio, ratio)`` argument order, kept because the node is
    registered under the same ``"ModelMergeSimple"`` key as the original, so
    saved workflows keep their meaning.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, ratio):
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            # add_patches(patches, strength_for_incoming_patch, strength_for_self):
            # model2's weight gets (1 - ratio), model1's current weight gets ratio.
            m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
        return (m, )


class ModelMergeMultiSimple:
    """Weighted average of up to five models' diffusion weights.

    Each ``ratio{i}`` is a relative (non-normalized) weight for ``model{i}``.
    Among connected models with a ratio > 0 the result is
    ``sum(ratio_i * model_i) / sum(ratio_i)``. ``model1`` is required;
    ``model2``..``model5`` are optional and unconnected slots are skipped.
    If every connected model has ratio 0, a clone of the first connected model
    is returned unchanged.
    """

    MAX_MODELS = 5

    @classmethod
    def INPUT_TYPES(s):
        inputs = {"required": {}, "optional": {}}
        for i in range(1, s.MAX_MODELS + 1):
            # Only model1 must be connected so users can merge fewer than five.
            section = "required" if i == 1 else "optional"
            inputs[section][f"model{i}"] = ("MODEL",)
            inputs["required"][f"ratio{i}"] = ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01})
        return inputs

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge_five"  # Distinct name to avoid clashing with `merge` elsewhere
    CATEGORY = "advanced/model_merging"

    def merge_five(self, **kwargs):
        """Merge the connected models and return ``(merged_model,)``.

        Raises:
            ValueError: if no model is connected at all.
        """
        # Keep each model paired with its own ratio so skipped slots can never
        # shift ratios onto the wrong model.
        provided = []
        for i in range(1, self.MAX_MODELS + 1):
            model = kwargs.get(f"model{i}")
            ratio = kwargs.get(f"ratio{i}") or 0.0
            if model is None:
                if ratio > 0:
                    print(f"Warning: Ratio {ratio} provided for model{i} but model is missing. Ignoring.")
                continue
            provided.append((model, float(ratio)))

        if not provided:
            raise ValueError("No models provided for merging.")

        # Zero-weight models contribute nothing; dropping them also avoids a
        # zero total weight.
        active = [(model, ratio) for model, ratio in provided if ratio > 0]
        if not active:
            print("Warning: All model ratios are 0. Returning the first provided model without changes.")
            return (provided[0][0].clone(), )

        merged_model = active[0][0].clone()
        # Sum of the raw weights already folded into merged_model.
        accumulated = active[0][1]
        for model, weight in active[1:]:
            total = accumulated + weight
            # Keep merged_model a normalized weighted average at every step:
            #   merged = merged * (accumulated / total) + model * (weight / total)
            # After the last model this equals sum(w_i * m_i) / sum(w_i).
            key_patches = model.get_key_patches("diffusion_model.")
            for k in key_patches:
                merged_model.add_patches({k: key_patches[k]}, weight / total, accumulated / total)
            accumulated = total

        return (merged_model, )


class ModelSubtract:
    """Scaled difference of two models: ``(model1 - model2) * multiplier``.

    Uses the ``(-multiplier, multiplier)`` form of the originally provided
    code, which is what the ``"ModelMergeSubtract"`` key is expected to
    produce. The result is typically fed to ModelAdd for add-difference
    merging.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, multiplier):
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            # model1's current weight is scaled by `multiplier`; model2's weight is
            # added with `-multiplier` => multiplier * (model1 - model2).
            m.add_patches({k: kp[k]}, -multiplier, multiplier)
        return (m, )


class ModelAdd:
    """Sum of two models' diffusion weights: ``model1 + model2``."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2):
        m = model1.clone()
        kp = model2.get_key_patches("diffusion_model.")
        for k in kp:
            m.add_patches({k: kp[k]}, 1.0, 1.0)  # model1*1.0 + model2*1.0
        return (m, )
# The classes below stand in for the rest of the original model-merging module
# (CLIP merging, block merging and the save nodes). They keep the original node
# interfaces so the mappings at the bottom stay valid. They raise
# NotImplementedError instead of silently returning an input unchanged or
# reporting a save that never happened. Paste the real implementations in
# before using these nodes.


def _placeholder(node_name):
    """Raise NotImplementedError for a stub node that has no implementation yet."""
    raise NotImplementedError(
        f"{node_name} is a placeholder in this file; "
        "replace it with the original implementation before using it."
    )


class CLIPMergeSimple:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip1": ("CLIP",),
            "clip2": ("CLIP",),
            "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, ratio):
        _placeholder("CLIPMergeSimple")


class CLIPSubtract:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip1": ("CLIP",),
            "clip2": ("CLIP",),
            "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2, multiplier):
        _placeholder("CLIPSubtract")


class CLIPAdd:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "clip1": ("CLIP",),
            "clip2": ("CLIP",),
        }}

    RETURN_TYPES = ("CLIP",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, clip1, clip2):
        _placeholder("CLIPAdd")


class ModelMergeBlocks:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model1": ("MODEL",),
            "model2": ("MODEL",),
            "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
            "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "merge"
    CATEGORY = "advanced/model_merging"

    def merge(self, model1, model2, **kwargs):
        _placeholder("ModelMergeBlocks")


def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefix=None, output_dir=None, prompt=None, extra_pnginfo=None):
    _placeholder("save_checkpoint")


class CheckpointSave:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
                "filename_prefix": ("STRING", {"default": "checkpoints/ComfyUI"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        _placeholder("CheckpointSave")


class CLIPSave:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "clip": ("CLIP",),
                "filename_prefix": ("STRING", {"default": "clip/ComfyUI"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
        _placeholder("CLIPSave")


class VAESave:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "vae": ("VAE",),
                "filename_prefix": ("STRING", {"default": "vae/ComfyUI_vae"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
        _placeholder("VAESave")


class ModelSave:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),
            },
            "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
        }

    RETURN_TYPES = ()
    FUNCTION = "save"
    OUTPUT_NODE = True
    CATEGORY = "advanced/model_merging"

    def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
        _placeholder("ModelSave")


# Registration: the original mappings plus the new ModelMergeMultiSimple node.
# NOTE(review): every key except "ModelMergeMultiSimple" reuses an original
# node name. Confirm the stubs above never replace working implementations in
# the environment this file is loaded into.
NODE_CLASS_MAPPINGS = {
    "ModelMergeSimple": ModelMergeSimple,
    "ModelMergeMultiSimple": ModelMergeMultiSimple,  # New node
    "ModelMergeBlocks": ModelMergeBlocks,
    "ModelMergeSubtract": ModelSubtract,
    "ModelMergeAdd": ModelAdd,
    "CheckpointSave": CheckpointSave,
    "CLIPMergeSimple": CLIPMergeSimple,
    "CLIPMergeSubtract": CLIPSubtract,
    "CLIPMergeAdd": CLIPAdd,
    "CLIPSave": CLIPSave,
    "VAESave": VAESave,
    "ModelSave": ModelSave,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "ModelMergeSimple": "Model Merge Simple (2 Models)",
    "ModelMergeMultiSimple": "Model Merge Multi Simple (5 Models)",
    "ModelMergeBlocks": "Model Merge Blocks",
    "ModelMergeSubtract": "Model Subtract",
    "ModelMergeAdd": "Model Add",
    "CheckpointSave": "Save Checkpoint",
    "CLIPMergeSimple": "CLIP Merge Simple",
    "CLIPMergeSubtract": "CLIP Subtract",
    "CLIPMergeAdd": "CLIP Add",
    "CLIPSave": "CLIP Save",
    "VAESave": "VAE Save",
    "ModelSave": "Model Save",
}

print("Custom model merging nodes loaded.")