Create RM_demo.py
RM_demo.py (new file, +54 −0)
import logging
from typing import List

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, MistralForCausalLM
from safetensors.torch import load_file


class RewardModel:
    def __init__(self, model_dir: str) -> None:
        config = AutoConfig.from_pretrained(model_dir)
        # config._attn_implementation = "flash_attention_2"
        self.device = torch.device('cuda')
        # Build the base Mistral model, then replace the LM head with a
        # single-output linear layer that produces the scalar reward.
        self.model = MistralForCausalLM(config)
        self.model.lm_head = nn.Linear(config.hidden_size, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # strict=False because the checkpoint's head differs in shape from
        # the stock LM head described by the config.
        state_dict = load_file(f"{model_dir}/model.safetensors")
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(dtype=torch.bfloat16)
        self.model.to(device=self.device)
        self.model.eval()
        logging.info("Model loading completed.")

    @torch.no_grad()
    def score(self, prompts: List[str], chosens: List[str]) -> List[float]:
        # Concatenate each prompt with its candidate and append eos_token_id.
        input_ids_list = [
            self.tokenizer.encode(prompt) + self.tokenizer.encode(chosen) + [self.tokenizer.eos_token_id]
            for prompt, chosen in zip(prompts, chosens)
        ]

        # Right-pad all sequences to the batch maximum; fall back to
        # eos_token_id when the tokenizer defines no pad token.
        max_length = max(len(ids) for ids in input_ids_list)
        pad_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        padded_input_ids = [ids + [pad_id] * (max_length - len(ids)) for ids in input_ids_list]

        # Forward pass; with the replaced head, logits have shape
        # (batch, seq_len, 1).
        input_ids = torch.tensor(padded_input_ids).to(device=self.device)
        logits = self.model(input_ids).logits

        # The reward is the head's output at the appended eos position.
        # Search the unpadded ids so that right-padding (which may reuse
        # eos_token_id) cannot shift the position.
        scores = []
        for i, ids in enumerate(input_ids_list):
            eos_position = ids.index(self.tokenizer.eos_token_id)
            eos_logit = logits[i, eos_position, :].squeeze().item()
            scores.append(eos_logit)

        return scores


if __name__ == '__main__':

    local_model_dir = "Your local model dir"
    model_dir = f"{local_model_dir}/Seed-X-RM-7B"
    prompt = ["Translate the following English sentence into Chinese:\nMay the force be with you <zh>"]
    candidate = ["愿原力与你同在"]
    model = RewardModel(model_dir)
    scores = model.score(prompt, candidate)  # output: [score]
    print(scores)
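
A natural extension of the demo is ranking several candidate translations of one prompt: since score already batches, this only requires repeating the prompt once per candidate. A minimal sketch, reusing the RewardModel instance and model_dir from the demo above (the second candidate string is an illustrative alternative, not part of the original demo):

# Score multiple candidates for the same prompt and pick the one the
# reward model prefers (highest scalar score).
prompt = "Translate the following English sentence into Chinese:\nMay the force be with you <zh>"
candidates = ["愿原力与你同在", "愿力量与你同在"]  # second string is illustrative

scores = model.score([prompt] * len(candidates), candidates)
best_candidate, best_score = max(zip(candidates, scores), key=lambda pair: pair[1])
print(scores, "->", best_candidate)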