from safetensors.torch import safe_open, save_file
import torch 
import glob 
import os
from tqdm import tqdm 



def convertCheckpoint_toTensor(checkpoint):
    """Flatten every tensor of a checkpoint into a single 1-D tensor.

    Args:
        checkpoint: Mapping from parameter name to tensor. It is walked with
            ``.items()``, so the mapping's own iteration order determines the
            layout of the result.

    Returns:
        tensor: 1-D concatenation of all flattened parameters, in iteration
            order.
        parameter_sizes: List of ``(parameter_name, shape)`` pairs in the same
            order, which is what ``convertTensor_toCheckpoint`` needs to undo
            the flattening.
    """
    parameters = []
    parameter_sizes = []
    for parameter_name, parameter_value in checkpoint.items():
        parameters.append(parameter_value.flatten().contiguous())
        parameter_sizes.append((parameter_name, parameter_value.shape))

    # Every flattened tensor is still referenced by `parameters`, so nothing
    # becomes free inside the loop. Release the allocator cache once, right
    # before the single large allocation made by `torch.cat`.
    torch.cuda.empty_cache()

    # `torch.cat` always materialises a new, contiguous tensor.
    tensor = torch.cat(parameters, dim=0)
    return tensor, parameter_sizes


def convertTensor_toCheckpoint(tensor, parameter_sizes):
    """Split a flat tensor back into named, shaped parameters.

    Args:
        tensor: 1-D tensor holding all parameters back to back.
        parameter_sizes: Iterable of ``(parameter_name, shape)`` pairs, in the
            order the parameters are laid out in ``tensor``. Each shape must
            provide ``numel()`` (e.g. ``torch.Size``).

    Returns:
        checkpoint: Dict mapping each parameter name to its own copy of the
            matching slice of ``tensor``, reshaped to the recorded shape.
    """
    restored = {}
    offset = 0
    for name, shape in parameter_sizes:
        stop = offset + shape.numel()
        chunk = tensor[offset:stop].reshape(shape).contiguous()
        # Store a clone rather than the view: the original author saw memory
        # issues without it, probably because views keep the whole flat
        # tensor alive.
        restored[name] = chunk.clone()
        offset = stop
    return restored


def topk_values_mask(M, K=20, return_mask=False):
    """Zero out all but the largest-magnitude entries in each row of ``M``.

    Args:
        M: 1-D or 2-D tensor. A 1-D tensor is treated as a single row.
        K: Share of entries to keep per row. Values greater than 1 are read
            as percentages (``20`` -> ``0.2``); values up to 1 are used as a
            fraction directly.
        return_mask: When true, also return the kept fraction and the mask.

    Returns:
        ``M * mask``. A 1-D input comes back with shape ``(1, d)``.
        With ``return_mask=True`` a tuple ``(masked, kept_fraction, mask)``,
        where ``kept_fraction`` is the mean of the mask over its last
        dimension.

    Note:
        The threshold entry itself is kept (``>=``), so without ties
        ``int(d * K) + 1`` entries survive per row. Tied magnitudes at the
        threshold are all kept.
    """
    if K > 1:
        K /= 100

    original_shape = M.shape
    if M.dim() == 1:
        M = M.unsqueeze(0)

    _, d = M.shape
    # Rank (1-indexed, by magnitude) of the smallest entry that is kept.
    threshold_rank = d - int(d * K)

    # Computed once: for whole-model tensors a second abs() is a full extra copy.
    magnitudes = M.abs()
    if threshold_rank <= 0:
        # K covers the whole row (e.g. K=100 or K=1.0); kthvalue(0) would
        # raise, so keep everything.
        mask = torch.ones_like(magnitudes, dtype=torch.bool)
    else:
        # k-th smallest magnitude of each row is the keep threshold.
        kth_values, _ = magnitudes.kthvalue(threshold_rank, dim=1, keepdim=True)
        mask = magnitudes >= kth_values
    final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask

    if return_mask:
        # Reduce over the last dim: final_mask is 1-D when the input was 1-D,
        # in which case dim=1 does not exist.
        return M * final_mask, final_mask.float().mean(dim=-1), final_mask

    return M * final_mask


if __name__ == "__main__":
    # Trim every "*AB.safetensors" checkpoint in `ckpts_dir` to its
    # largest-magnitude 20% of values and write the results, plus an all-zero
    # "ancestor" checkpoint, to `save_dir`.
    ckpts_dir = "/scratch/ssd004/scratch/calvinyu/git_merge/stats/"
    save_dir = "/h/calvinyu/evalmerge/stats/"

    # Sorted so the checkpoint used as the ancestor template does not depend
    # on the arbitrary order in which glob returns files.
    ab_ckpts = sorted(glob.glob(os.path.join(ckpts_dir, "*AB.safetensors")))
    if not ab_ckpts:
        raise FileNotFoundError(
            f"No '*AB.safetensors' checkpoints found in {ckpts_dir}"
        )

    os.makedirs(save_dir, exist_ok=True)

    # Ancestor: same keys and shapes as the first checkpoint, all zeros.
    # zeros_like (rather than multiplying by 0.0) preserves the dtype and
    # cannot turn inf/NaN entries into NaN.
    ancestor = {}
    with safe_open(ab_ckpts[0], framework="pt", device=0) as f:
        for key in f.keys():
            ancestor[key] = torch.zeros_like(f.get_tensor(key))
    save_file(ancestor, os.path.join(save_dir, "ancestor.safetensors"))
    del ancestor

    # Trimmed checkpoints: the top-k selection runs over the whole flattened
    # model, not per parameter.
    for ab_ckpt in tqdm(ab_ckpts):
        model = {}
        with safe_open(ab_ckpt, framework="pt", device=0) as f:
            for key in f.keys():
                model[key] = f.get_tensor(key)
        model_tensor, model_sizes = convertCheckpoint_toTensor(model)
        # The flat tensor is an independent copy; drop the per-parameter
        # tensors before the next large allocations.
        del model
        model_topk = topk_values_mask(model_tensor, 20).squeeze()
        del model_tensor
        model_trimmed = convertTensor_toCheckpoint(model_topk, model_sizes)
        del model_topk
        file_name = os.path.basename(ab_ckpt).replace("AB.safetensors", "AB_trimmed.safetensors")
        save_file(model_trimmed, os.path.join(save_dir, file_name))

