RASMUS commited on
Commit
d7c6080
·
verified ·
1 Parent(s): 53abc6e

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +32 -28
handler.py CHANGED
@@ -31,33 +31,37 @@ class EndpointHandler():
31
  Return:
32
  A :obj:`list` | `dict`: will be serialized and returned
33
  """
34
- context = data.pop("context",None)
35
- question = data.pop("question",None)
36
- messages = generate_rag_prompt_message(context, question)
37
-
38
- inputs = self.tokenizer(
39
- [
40
- self.tokenizer.apply_chat_template(messages, tokenize=False)
41
- ]*1, return_tensors = "pt").to("cuda")
42
-
43
-
44
- with torch.no_grad():
45
- generated_ids = self.model.generate(
46
- input_ids=inputs["input_ids"],
47
- attention_mask=inputs["attention_mask"],
48
- generation_config=self.generation_config, **{
49
- "temperature": 0.1,
50
- "penalty_alpha": 0.6,
51
- "min_p": 0.5,
52
- "do_sample": True,
53
- "repetition_penalty": 1.28,
54
- "min_length": 10,
55
- "max_new_tokens": 250
56
- })
57
-
58
- generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True)[0]
59
  try:
60
- generated_answer = generated_text.split('[/INST]')[1].strip()
61
- return json.dumps({"answer": generated_answer})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  except Exception as e:
63
- return json.dumps({"answer": str(e)})
 
31
  Return:
32
  A :obj:`list` | `dict`: will be serialized and returned
33
  """
34
+ print(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  try:
36
+ context = data.pop("context",None)
37
+ question = data.pop("question",None)
38
+ messages = generate_rag_prompt_message(context, question)
39
+
40
+ inputs = self.tokenizer(
41
+ [
42
+ self.tokenizer.apply_chat_template(messages, tokenize=False)
43
+ ]*1, return_tensors = "pt").to("cuda")
44
+
45
+
46
+ with torch.no_grad():
47
+ generated_ids = self.model.generate(
48
+ input_ids=inputs["input_ids"],
49
+ attention_mask=inputs["attention_mask"],
50
+ generation_config=self.generation_config, **{
51
+ "temperature": 0.1,
52
+ "penalty_alpha": 0.6,
53
+ "min_p": 0.5,
54
+ "do_sample": True,
55
+ "repetition_penalty": 1.28,
56
+ "min_length": 10,
57
+ "max_new_tokens": 250
58
+ })
59
+
60
+ generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True)[0]
61
+ try:
62
+ generated_answer = generated_text.split('[/INST]')[1].strip()
63
+ return json.dumps({"answer": generated_answer})
64
+ except Exception as e:
65
+ return json.dumps({"answer": str(e)})
66
  except Exception as e:
67
+ print(e)