YuLu0713 committed on
Commit 9accbc6 · verified · 1 Parent(s): 06f2e79

Create RM_demo.py

Files changed (1)
RM_demo.py +54 -0
RM_demo.py ADDED
@@ -0,0 +1,54 @@
+ import logging
+ from typing import List
+
+ import torch
+ import torch.nn as nn
+ from transformers import AutoTokenizer, AutoConfig, MistralForCausalLM
+ from safetensors.torch import load_file
+
+
+ class RewardModel:
+     def __init__(self, model_dir) -> None:
+         config = AutoConfig.from_pretrained(model_dir)
+         # config._attn_implementation = "flash_attention_2"
+         self.device = torch.device('cuda')
+         self.model = MistralForCausalLM(config)
+         # Replace the LM head with a scalar head: the model emits one reward logit per position.
+         self.model.lm_head = nn.Linear(config.hidden_size, 1, bias=False)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+         state_dict = load_file(f"{model_dir}/model.safetensors")
+         # strict=False: the checkpoint supplies the trained weights for the replaced head.
+         self.model.load_state_dict(state_dict, strict=False)
+         self.model.to(dtype=torch.bfloat16)
+         self.model.to(device=self.device)
+         self.model.eval()
+         logging.info("Model loading completed.")
+
+     @torch.no_grad()
+     def score(self, prompts, chosens) -> List[float]:
+         # Concatenate each prompt with its candidate, then append eos_id
+         input_ids_list = [
+             self.tokenizer.encode(prompt) + self.tokenizer.encode(chosen) + [self.tokenizer.eos_token_id]
+             for prompt, chosen in zip(prompts, chosens)
+         ]
+
+         # Right-pad all sequences to the maximum length in the batch
+         max_length = max(len(ids) for ids in input_ids_list)
+         pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
+         padded_input_ids = [ids + [pad_id] * (max_length - len(ids)) for ids in input_ids_list]
+
+         # Forward pass
+         input_ids = torch.tensor(padded_input_ids).to(device=self.device)
+         logits = self.model(input_ids).logits
+
+         # Extract the logit at each sequence's eos position; with causal attention,
+         # right padding cannot influence it
+         scores = []
+         for i, ids in enumerate(input_ids_list):
+             eos_position = ids.index(self.tokenizer.eos_token_id)
+             eos_logit = logits[i, eos_position, :].squeeze().item()
+             scores.append(eos_logit)
+
+         return scores
+
+
+ if __name__ == '__main__':
+     local_model_dir = "Your local model dir"
+     model_dir = f"{local_model_dir}/Seed-X-RM-7B"
+     prompt = ["Translate the following English sentence into Chinese:\nMay the force be with you <zh>"]
+     candidate = ["愿原力与你同在"]
+     model = RewardModel(model_dir)
+     scores = model.score(prompt, candidate)  # output: [score]
+     print(scores)
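
Since score takes parallel lists of prompts and candidates, several candidate translations of the same sentence can be ranked in a single call. A minimal sketch reusing model and prompt from the demo above (the second candidate string and the selection logic are illustrative, not part of the commit):

    candidates = ["愿原力与你同在", "愿力量与你同在"]
    scores = model.score(prompt * len(candidates), candidates)
    best = candidates[scores.index(max(scores))]  # candidate with the highest reward
    print(best)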