HowWebWorks commited on
Commit
ed658dd
·
1 Parent(s): 5777436

point handler to model

Browse files
Files changed (1) hide show
  1. handler.py +4 -2
handler.py CHANGED
@@ -1,18 +1,20 @@
1
  from typing import Dict
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
 
5
  class EndpointHandler:
6
  """Custom handler for NuExtract-2-8B (InternLM2 based)."""
7
 
8
  def __init__(self, path: str = "") -> None:
 
9
  # ↓↓↓ allow the repo’s custom configuration & modelling code
10
  self.tokenizer = AutoTokenizer.from_pretrained(
11
- path,
12
  trust_remote_code=True # ← mandatory
13
  )
14
  self.model = AutoModelForCausalLM.from_pretrained(
15
- path,
16
  trust_remote_code=True, # ← mandatory
17
  torch_dtype=torch.float16, # fits on a 16 GB GPU
18
  device_map="auto" # put tensors on the GPU
 
1
  from typing import Dict
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import os
5
 
6
  class EndpointHandler:
7
  """Custom handler for NuExtract-2-8B (InternLM2 based)."""
8
 
9
  def __init__(self, path: str = "") -> None:
10
+ model_path = os.path.join(path, "model")
11
  # ↓↓↓ allow the repo’s custom configuration & modelling code
12
  self.tokenizer = AutoTokenizer.from_pretrained(
13
+ model_path,
14
  trust_remote_code=True # ← mandatory
15
  )
16
  self.model = AutoModelForCausalLM.from_pretrained(
17
+ model_path,
18
  trust_remote_code=True, # ← mandatory
19
  torch_dtype=torch.float16, # fits on a 16 GB GPU
20
  device_map="auto" # put tensors on the GPU