from safetensors.torch import safe_open, save_file
import torch 
import glob 
from tqdm import tqdm 


def convertCheckpoint_toTensor(checkpoint):
    """Flatten every tensor of a checkpoint and concatenate them into one 1-D tensor.

    Args:
        checkpoint: Mapping of parameter name -> tensor. The mapping's iteration
            order fixes the order of the parameters inside the result.

    Returns:
        tensor: 1-D tensor holding all parameters back to back (``torch.cat``
            applies its usual type promotion if the parameters' dtypes differ).
        parameter_sizes: List of ``(parameter_name, shape)`` pairs in the same
            order, which ``convertTensor_toCheckpoint`` uses to undo the flattening.

    Raises:
        RuntimeError: raised by ``torch.cat`` when ``checkpoint`` is empty.
    """
    parameters = []
    parameter_sizes = []
    for parameter_name, parameter_value in checkpoint.items():
        parameters.append(parameter_value.flatten())
        parameter_sizes.append((parameter_name, parameter_value.shape))
    # torch.cat accepts non-contiguous inputs and always writes into a freshly
    # allocated (contiguous) output, so no .contiguous() calls are needed.
    # No torch.cuda.empty_cache() in the loop either: every flattened piece is still
    # referenced by `parameters`, so there is nothing to free, and emptying the cache
    # right before the large cat allocation only forces fresh device allocations.
    tensor = torch.cat(parameters, dim=0)
    return tensor, parameter_sizes


def convertTensor_toCheckpoint(tensor, parameter_sizes):
    """Split a flat tensor back into a checkpoint dict (inverse of the flattening step).

    Args:
        tensor: Tensor whose elements, read in row-major order, are the parameters
            laid out back to back. It is flattened first, so a ``(1, N)`` tensor is
            handled the same way as an ``(N,)`` one.
        parameter_sizes: Iterable of ``(parameter_name, shape)`` pairs in the order
            the parameters appear in ``tensor``. ``shape`` may be a ``torch.Size``
            or any sequence of ints.

    Returns:
        checkpoint: dict mapping each parameter name to a contiguous tensor of the
        given shape. Every entry is an independent copy, not a view of ``tensor``.
    """
    flat = tensor.reshape(-1)
    checkpoint = {}
    start_idx = 0
    for parameter_name, parameter_shape in parameter_sizes:
        parameter_shape = torch.Size(parameter_shape)
        end_idx = start_idx + parameter_shape.numel()

        # Clone so each entry owns its own storage; slices/views would keep the whole
        # flat tensor alive as long as any single entry is referenced (without cloning
        # this caused memory issues). contiguous_format yields a dense result in one copy.
        checkpoint[parameter_name] = (
            flat[start_idx:end_idx]
            .reshape(parameter_shape)
            .clone(memory_format=torch.contiguous_format)
        )
        start_idx = end_idx
    return checkpoint


def topk_values_mask(M, K=20, return_mask=False):
    """Keep the largest-magnitude entries in each row of ``M`` and zero out the rest.

    Args:
        M: 1-D or 2-D tensor; a 1-D input is treated as a single row.
        K: Fraction of entries to keep per row. Values greater than 1 are read as a
            percentage (``20`` -> 20%); values in ``[0, 1]`` are used as-is, so ``1``
            means "keep everything". ``int(d * K)`` entries are kept per row of
            length ``d`` (plus any entries tied with the threshold magnitude).
        return_mask: If True, also return the kept fraction and the mask.

    Returns:
        The masked tensor. For a 1-D input it has shape ``(1, d)``. With
        ``return_mask=True``, the tuple ``(masked, kept_fraction, mask)``, where
        ``mask`` is boolean with the input's shape and ``kept_fraction`` is the mask's
        mean over its last dim (per row for 2-D input, a 0-dim tensor for 1-D input).
    """
    if K > 1:
        K /= 100

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    _, d = M.shape
    keep = int(d * K)  # entries to keep per row

    if keep <= 0:
        mask = torch.zeros_like(M, dtype=torch.bool)
    elif keep >= d:
        # kthvalue would need k = 1 here; short-circuit, everything is kept.
        mask = torch.ones_like(M, dtype=torch.bool)
    else:
        magnitudes = M.abs()  # computed once; M may be an entire flattened model
        # The keep-th largest magnitude is the (d - keep + 1)-th smallest.
        kth_values, _ = magnitudes.kthvalue(d - keep + 1, dim=1, keepdim=True)
        # Entries tied with the threshold are kept as well.
        mask = magnitudes >= kth_values

    final_mask = mask.reshape(original_shape)

    if return_mask:
        # dim=-1 works for both the 1-D mask (1-D input) and the 2-D mask.
        return M * final_mask, final_mask.float().mean(dim=-1), final_mask

    return M * final_mask


def _load_safetensors(path, device):
    """Read every tensor in the safetensors file at ``path`` into a dict on ``device``."""
    with safe_open(path, framework="pt", device=device) as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def main():
    """Trim two checkpoints to their top 20% entries by magnitude and save the results.

    Reads ``clip.safetensors`` and ``info.safetensors`` from the working directory and
    writes ``clip_trimmed``/``info_trimmed`` as both ``.safetensors`` (per-parameter
    dicts) and ``.pt`` (single flat tensor) files.
    """
    import os

    # Previously hard-coded to device=0; fall back to CPU when CUDA is unavailable.
    device = 0 if torch.cuda.is_available() else "cpu"

    model1 = "clip.safetensors"
    model2 = "info.safetensors"

    m1 = _load_safetensors(model1, device)
    m2 = _load_safetensors(model2, device)

    # NOTE(review): this copies the output of an earlier merge run into
    # "an.safetensors", which the merge command in the notes at the bottom of this
    # file passes as --ancestor. It looks like experiment scaffolding; confirm it is
    # intended. Guarded so the trimming below still runs before any merge output exists.
    merged_path = "merged_AB_ties.safetensors"
    if os.path.exists(merged_path):
        save_file(_load_safetensors(merged_path, device), "an.safetensors")

    m1t, m1s = convertCheckpoint_toTensor(m1)
    m2t, m2s = convertCheckpoint_toTensor(m2)
    # torch.cat copied the data into m1t/m2t; drop the per-parameter dicts to lower
    # peak memory before the masking step allocates more.
    del m1, m2

    # Keep the top 20% of entries by magnitude in each flattened model and zero the
    # rest. squeeze() removes the leading row dim topk_values_mask adds to 1-D inputs.
    m1t = topk_values_mask(m1t, 20).squeeze()
    m2t = topk_values_mask(m2t, 20).squeeze()

    # Reshape the flat trimmed tensors back into per-parameter checkpoints.
    m1_trimmed = convertTensor_toCheckpoint(m1t, m1s)
    m2_trimmed = convertTensor_toCheckpoint(m2t, m2s)

    save_file(m1_trimmed, "clip_trimmed.safetensors")
    save_file(m2_trimmed, "info_trimmed.safetensors")

    # Also save the flat trimmed tensors as .pt files.
    torch.save(m1t, "clip_trimmed.pt")
    torch.save(m2t, "info_trimmed.pt")


if __name__ == "__main__":
    main()
    
"""
Example merge command; --aux_data takes the *_trimmed.safetensors files written above:

evalmerge git-theta-merge-cli \
  --models clip.safetensors \
           info.safetensors \
  --aux_data clip_trimmed.safetensors \
             info_trimmed.safetensors \
  --ancestor an.safetensors \
  --output merged_AB_ties.safetensors \
  --x:merge_lambda 0.5 \
  --merge lora-ties-gpu

infograph-building
clipart-cloth

"""