TomBombadyl committed on
Commit
81bed52
·
verified ·
1 Parent(s): 6a0c90b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -7
handler.py CHANGED
@@ -108,14 +108,14 @@ class EndpointHandler:
108
  raise e3
109
 
110
  def _load_model(self, model_path: str, base_path: str):
111
- """Load model with fallback methods"""
112
 
113
  try:
114
- # Try direct loading from model path
115
- logger.info(f"Trying to load model from {model_path}")
116
  model = AutoModelForCausalLM.from_pretrained(
117
  model_path,
118
- torch_dtype=torch.float16,
119
  device_map="auto",
120
  trust_remote_code=True,
121
  local_files_only=True,
@@ -125,11 +125,11 @@ class EndpointHandler:
125
  logger.warning(f"Failed to load from {model_path}: {e1}")
126
 
127
  try:
128
- # Try loading from base path
129
- logger.info(f"Trying to load model from {base_path}")
130
  model = AutoModelForCausalLM.from_pretrained(
131
  base_path,
132
- torch_dtype=torch.float16,
133
  device_map="auto",
134
  trust_remote_code=True,
135
  local_files_only=True,
 
108
  raise e3
109
 
110
  def _load_model(self, model_path: str, base_path: str):
111
+ """Load model with 8-bit quantization to fit memory limits"""
112
 
113
  try:
114
+ # Try direct loading from model path with 8-bit quantization
115
+ logger.info(f"Trying to load model from {model_path} with 8-bit quantization")
116
  model = AutoModelForCausalLM.from_pretrained(
117
  model_path,
118
+ load_in_8bit=True, # Use 8-bit quantization
119
  device_map="auto",
120
  trust_remote_code=True,
121
  local_files_only=True,
 
125
  logger.warning(f"Failed to load from {model_path}: {e1}")
126
 
127
  try:
128
+ # Try loading from base path with 8-bit quantization
129
+ logger.info(f"Trying to load model from {base_path} with 8-bit quantization")
130
  model = AutoModelForCausalLM.from_pretrained(
131
  base_path,
132
+ load_in_8bit=True, # Use 8-bit quantization
133
  device_map="auto",
134
  trust_remote_code=True,
135
  local_files_only=True,