HowWebWorks committed on
Commit
c9feb62
·
1 Parent(s): 5c35d44

update tokenizer

Browse files
Files changed (1) hide show
  1. handler.py +24 -10
handler.py CHANGED
@@ -1,16 +1,30 @@
1
- from transformers import AutoModelForCausalLM, AutoTokenizer
2
  from typing import Dict
 
 
3
 
4
class EndpointHandler:
    """Inference Endpoints handler: load a causal LM and serve text generation.

    ``__init__(path)`` loads tokenizer + model from the endpoint's model
    directory; ``__call__(data)`` handles one request dict.
    """

    def __init__(self, path=""):
        # Both artifacts come from the same repository path supplied by the
        # Inference Endpoints runtime.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)

    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        """Generate a completion for ``data["inputs"]``.

        Returns ``{"generated_text": ...}`` on success, or
        ``{"error": ...}`` when the input is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            return {"error": "No input provided."}
        encoded = self.tokenizer(text, return_tensors="pt")
        generated = self.model.generate(**encoded, max_new_tokens=100)
        decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
        return {"generated_text": decoded}
 
 
 
1
  from typing import Dict
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
class EndpointHandler:
    """
    Minimal custom handler for InternLM2 / NuExtract-2-8B.

    Implements the Hugging Face Inference Endpoints contract:
    ``__init__(path)`` loads the model once at startup and
    ``__call__(data)`` serves a single request dict.
    """

    def __init__(self, path: str = "./model"):
        """Load tokenizer and model from *path*.

        ``trust_remote_code=True`` is required because the repo ships
        custom model code; fp16 + ``device_map="auto"`` keep the weights
        within a single A10/T4 GPU when one is available.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,     # allow execution of custom model code
            torch_dtype=torch.float16,  # fp16 halves memory vs fp32
            device_map="auto",          # send to GPU if available
        ).eval()                        # inference mode (disables dropout)

    def __call__(self, data: Dict[str, str]) -> Dict[str, str]:
        """Generate a completion for ``data["inputs"]``.

        Also honors the standard Inference Endpoints request shape: an
        optional ``data["parameters"]["max_new_tokens"]`` overrides the
        default generation budget of 128 tokens.

        Returns ``{"generated_text": ...}`` on success, or
        ``{"error": ...}`` when the input is missing or empty.
        """
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "No input provided."}

        # Optional per-request generation settings (backward-compatible:
        # absent "parameters" keeps the previous fixed budget of 128).
        params = data.get("parameters") or {}
        max_new_tokens = params.get("max_new_tokens", 128)

        # Move the encoded batch to the model's device — required when
        # device_map placed the weights on a GPU while tensors start on CPU.
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Explicit no_grad: generation needs no autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": answer}