chinnadhurai sankar
		
	commited on
		
		
					Commit 
							
							·
						
						5028f79
	
1
								Parent(s):
							
							896f67a
								
initial commit
Browse files- elm/infer_elm.py +2 -2
- elm/infer_elm_for_demo_app.py +143 -0
- elm/model.py +1 -1
- elm/positional_embeddings.py +0 -2
- elm/utils.py +1 -6
    	
        elm/infer_elm.py
    CHANGED
    
    | @@ -1,4 +1,4 @@ | |
| 1 | 
            -
            # Copyright (c) 2024, SliceX AI, Inc. | 
| 2 |  | 
| 3 | 
             
            from elm.model import *
         | 
| 4 | 
             
            from elm.utils import batchify
         | 
| @@ -129,4 +129,4 @@ def generate_elm_responses(elm_model_path, | |
| 129 | 
             
                            print(json.dumps({"prompt": prompt, "response": response}, indent=4))
         | 
| 130 | 
             
                            print("\n***\n")
         | 
| 131 | 
             
                return result
         | 
| 132 | 
            -
                
         | 
|  | |
| 1 | 
            +
            # Copyright (c) 2024, SliceX AI, Inc.
         | 
| 2 |  | 
| 3 | 
             
            from elm.model import *
         | 
| 4 | 
             
            from elm.utils import batchify
         | 
|  | |
| 129 | 
             
                            print(json.dumps({"prompt": prompt, "response": response}, indent=4))
         | 
| 130 | 
             
                            print("\n***\n")
         | 
| 131 | 
             
                return result
         | 
| 132 | 
            +
                
         | 
    	
        elm/infer_elm_for_demo_app.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2024, SliceX AI, Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from elm.model import *
         | 
| 4 | 
            +
            from elm.utils import batchify
         | 
| 5 | 
            +
            from transformers import AutoTokenizer
         | 
| 6 | 
            +
            import json
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def load_elm_model_and_tokenizer(local_path, 
         | 
| 10 | 
            +
                                             model_config_dict,
         | 
| 11 | 
            +
                                             device="cuda",
         | 
| 12 | 
            +
                                             load_partial=True,
         | 
| 13 | 
            +
                                             get_num_layers_from_ckpt=True):
         | 
| 14 | 
            +
                """Load ELM model and tokenizer from local checkpoint."""
         | 
| 15 | 
            +
                model_args = ModelArgs(**model_config_dict)
         | 
| 16 | 
            +
                model = load_elm_model_from_ckpt(local_path, device=device, model_args=model_args, load_partial=load_partial, get_num_layers_from_ckpt=get_num_layers_from_ckpt)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                tokenizer = AutoTokenizer.from_pretrained(local_path)
         | 
| 19 | 
            +
                tokenizer.padding_side = "left"
         | 
| 20 | 
            +
                tokenizer.truncation_side = "left"
         | 
| 21 | 
            +
                return model, tokenizer
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            def generate_elm_response_given_model(prompts, model, tokenizer, 
         | 
| 25 | 
            +
                                      device="cuda",
         | 
| 26 | 
            +
                                      max_ctx_word_len=1024,
         | 
| 27 | 
            +
                                      max_ctx_token_len=0,
         | 
| 28 | 
            +
                                      max_new_tokens=500,
         | 
| 29 | 
            +
                                      temperature=0.8, # set to 0 for greedy decoding
         | 
| 30 | 
            +
                                      top_k=200,
         | 
| 31 | 
            +
                                      return_tok_cnt=False,
         | 
| 32 | 
            +
                                      return_gen_only=False,
         | 
| 33 | 
            +
                                      early_stop_on_eos=False):
         | 
| 34 | 
            +
                """Generate responses from ELM model given an input list of prompts ([str])."""
         | 
| 35 | 
            +
                if max_ctx_token_len > 0:
         | 
| 36 | 
            +
                    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=max_ctx_token_len).to(device)
         | 
| 37 | 
            +
                else:
         | 
| 38 | 
            +
                    prompts = [" ".join(p.split(" ")[-max_ctx_word_len:]) for p in prompts]
         | 
| 39 | 
            +
                    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
         | 
| 40 | 
            +
                
         | 
| 41 | 
            +
                results = []
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
                input_tok_cnt = torch.numel(inputs.input_ids)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                model.eval()
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                out_tok_cnt = 0
         | 
| 48 | 
            +
                with torch.no_grad():
         | 
| 49 | 
            +
                    temperature = temperature
         | 
| 50 | 
            +
                    top_k = top_k
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    outputs = model.generate(inputs.input_ids, max_new_tokens, temperature=temperature, top_k=top_k,
         | 
| 53 | 
            +
                                             return_gen_only=return_gen_only)
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    if return_tok_cnt:
         | 
| 56 | 
            +
                        out_tok_cnt += torch.numel(outputs)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    if early_stop_on_eos:
         | 
| 59 | 
            +
                        mod_outputs = []
         | 
| 60 | 
            +
                        for i in range(len(outputs)):
         | 
| 61 | 
            +
                            curr_out = outputs[i]
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                            eos_loc_id = -1
         | 
| 64 | 
            +
                            for j in range(len(outputs[i])):
         | 
| 65 | 
            +
                                tok_id = outputs[i][j]
         | 
| 66 | 
            +
                                if tok_id == tokenizer.eos_token_id:
         | 
| 67 | 
            +
                                    eos_loc_id = j
         | 
| 68 | 
            +
                                    break
         | 
| 69 | 
            +
                            if eos_loc_id >= 0:
         | 
| 70 | 
            +
                                curr_out = outputs[i][:eos_loc_id]
         | 
| 71 | 
            +
                            mod_outputs.append(curr_out)
         | 
| 72 | 
            +
                        outputs = mod_outputs
         | 
| 73 | 
            +
                    detokenized_output = tokenizer.batch_decode(outputs, skip_special_tokens=False)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    results = detokenized_output
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                if return_tok_cnt:
         | 
| 78 | 
            +
                    return results, (input_tok_cnt, out_tok_cnt)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                return results
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            def load_elm_model_given_path(elm_model_path, elm_model_config={}, device=None):
         | 
| 83 | 
            +
                if not device:
         | 
| 84 | 
            +
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 85 | 
            +
                print(f"Setting device to {device}")
         | 
| 86 | 
            +
                model_config_dict = {
         | 
| 87 | 
            +
                        "hidden_size": elm_model_config.get("hidden_size", 2048),
         | 
| 88 | 
            +
                        "max_inp_len": elm_model_config.get("max_inp_len", 2048),
         | 
| 89 | 
            +
                        "num_attention_heads": elm_model_config.get("num_attention_heads", 32),
         | 
| 90 | 
            +
                        "num_layers": elm_model_config.get("num_layers", 48),
         | 
| 91 | 
            +
                        "bits": elm_model_config.get("bits", 256),
         | 
| 92 | 
            +
                        "vocab_size": elm_model_config.get("vocab_size", 50304),
         | 
| 93 | 
            +
                        "dropout": elm_model_config.get("dropout", 0.1),
         | 
| 94 | 
            +
                        "use_rotary_embeddings": elm_model_config.get("use_rotary_embeddings", True)
         | 
| 95 | 
            +
                    }
         | 
| 96 | 
            +
                    
         | 
| 97 | 
            +
                model, tokenizer = load_elm_model_and_tokenizer(local_path=elm_model_path, model_config_dict=model_config_dict, device=device, load_partial=True)
         | 
| 98 | 
            +
                return {"model": model, "tokenizer": tokenizer}
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            def generate_elm_responses(elm_model_path, 
         | 
| 101 | 
            +
                                       prompts,
         | 
| 102 | 
            +
                                       device=None, 
         | 
| 103 | 
            +
                                       elm_model_config={},
         | 
| 104 | 
            +
                                       eval_batch_size=1,
         | 
| 105 | 
            +
                                       verbose=True,
         | 
| 106 | 
            +
                                       model_info=None):
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
                if not device:
         | 
| 110 | 
            +
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         | 
| 111 | 
            +
                print(f"Setting device to {device}")
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                if not model_info:
         | 
| 114 | 
            +
                    model_info = load_elm_model_given_path(elm_model_path, elm_model_config=elm_model_config, device=device)
         | 
| 115 | 
            +
                
         | 
| 116 | 
            +
                model, tokenizer = model_info["model"], model_info["tokenizer"]
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                #prompts = [prompt if "[INST]" in prompt else f"[INST]{prompt}[/INST]" for prompt in prompts]
         | 
| 119 | 
            +
                max_new_tokens = 128
         | 
| 120 | 
            +
                if "classification" in elm_model_path or "detection" in elm_model_path:
         | 
| 121 | 
            +
                    max_new_tokens = 12
         | 
| 122 | 
            +
                result = []
         | 
| 123 | 
            +
                for prompt_batch in batchify(prompts, eval_batch_size):
         | 
| 124 | 
            +
                    responses, _ = generate_elm_response_given_model(prompt_batch,
         | 
| 125 | 
            +
                                                                        model, 
         | 
| 126 | 
            +
                                                                        tokenizer, 
         | 
| 127 | 
            +
                                                                        device=device,
         | 
| 128 | 
            +
                                                                        max_ctx_word_len=1024,
         | 
| 129 | 
            +
                                                                        max_ctx_token_len=512,
         | 
| 130 | 
            +
                                                                        max_new_tokens=max_new_tokens,
         | 
| 131 | 
            +
                                                                        return_tok_cnt=True, 
         | 
| 132 | 
            +
                                                                        return_gen_only=False, 
         | 
| 133 | 
            +
                                                                        temperature=0.0, 
         | 
| 134 | 
            +
                                                                        early_stop_on_eos=True)
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
                    for prompt, response in zip(prompt_batch, responses):
         | 
| 137 | 
            +
                        response = response.split("[/INST]")[-1].strip()
         | 
| 138 | 
            +
                        result.append(response)
         | 
| 139 | 
            +
                        if verbose:
         | 
| 140 | 
            +
                            print(json.dumps({"prompt": prompt, "response": response}, indent=4))
         | 
| 141 | 
            +
                            print("\n***\n")
         | 
| 142 | 
            +
                return result
         | 
| 143 | 
            +
                
         | 
    	
        elm/model.py
    CHANGED
    
    | @@ -413,4 +413,4 @@ def sample_top_p(probs, threshold): | |
| 413 | 
             
                next_token = torch.multinomial(probs_sort, num_samples=1)
         | 
| 414 | 
             
                next_token = torch.gather(probs_idx, -1, next_token)
         | 
| 415 |  | 
| 416 | 
            -
                return next_token
         | 
|  | |
| 413 | 
             
                next_token = torch.multinomial(probs_sort, num_samples=1)
         | 
| 414 | 
             
                next_token = torch.gather(probs_idx, -1, next_token)
         | 
| 415 |  | 
| 416 | 
            +
                return next_token
         | 
    	
        elm/positional_embeddings.py
    CHANGED
    
    | @@ -9,8 +9,6 @@ def rotate_half(x): | |
| 9 |  | 
| 10 | 
             
            @torch.jit.script
         | 
| 11 | 
             
            def apply_rotary_pos_emb(x, cos, sin):
         | 
| 12 | 
            -
                # NOTE: This could probably be moved to Triton
         | 
| 13 | 
            -
             | 
| 14 | 
             
                # Handle a possible sequence length mismatch in between q and k
         | 
| 15 | 
             
                cos = cos[:, :, : x.shape[-2], :]
         | 
| 16 | 
             
                sin = sin[:, :, : x.shape[-2], :]
         | 
|  | |
| 9 |  | 
| 10 | 
             
            @torch.jit.script
         | 
| 11 | 
             
            def apply_rotary_pos_emb(x, cos, sin):
         | 
|  | |
|  | |
| 12 | 
             
                # Handle a possible sequence length mismatch in between q and k
         | 
| 13 | 
             
                cos = cos[:, :, : x.shape[-2], :]
         | 
| 14 | 
             
                sin = sin[:, :, : x.shape[-2], :]
         | 
    	
        elm/utils.py
    CHANGED
    
    | @@ -1,21 +1,16 @@ | |
| 1 | 
            -
            # Copyright (c) 2024, SliceX AI, Inc. | 
| 2 |  | 
| 3 | 
            -
            from prettytable import PrettyTable
         | 
| 4 |  | 
| 5 | 
             
            def count_parameters(model):
         | 
| 6 | 
             
                """Count the number of parameters in the model."""
         | 
| 7 | 
            -
                table = PrettyTable(["Modules", "Parameters"])
         | 
| 8 | 
             
                total_params = 0
         | 
| 9 |  | 
| 10 | 
             
                for name, parameter in model.named_parameters():
         | 
| 11 | 
             
                    if not parameter.requires_grad: continue
         | 
| 12 | 
             
                    params = parameter.numel()
         | 
| 13 | 
            -
                    table.add_row([name, params])
         | 
| 14 | 
             
                    total_params+=params
         | 
| 15 |  | 
| 16 | 
            -
                print(table)
         | 
| 17 | 
             
                print(f"Total Trainable Params: {total_params}")
         | 
| 18 | 
            -
                
         | 
| 19 | 
             
                return total_params
         | 
| 20 |  | 
| 21 |  | 
|  | |
| 1 | 
            +
            # Copyright (c) 2024, SliceX AI, Inc.
         | 
| 2 |  | 
|  | |
| 3 |  | 
| 4 | 
             
            def count_parameters(model):
         | 
| 5 | 
             
                """Count the number of parameters in the model."""
         | 
|  | |
| 6 | 
             
                total_params = 0
         | 
| 7 |  | 
| 8 | 
             
                for name, parameter in model.named_parameters():
         | 
| 9 | 
             
                    if not parameter.requires_grad: continue
         | 
| 10 | 
             
                    params = parameter.numel()
         | 
|  | |
| 11 | 
             
                    total_params+=params
         | 
| 12 |  | 
|  | |
| 13 | 
             
                print(f"Total Trainable Params: {total_params}")
         | 
|  | |
| 14 | 
             
                return total_params
         | 
| 15 |  | 
| 16 |  | 
