TomBombadyl committed (verified)
Commit 0711c19 · Parent(s): ab1c2e4

Update handler.py

Files changed (1): handler.py (+29 -20)
handler.py CHANGED
@@ -121,10 +121,18 @@ class EndpointHandler:
         return tokenizer
 
     def _load_model_manual(self, model_path: str):
-        """Load model completely manually"""
+        """Load model completely manually with memory optimization"""
 
         logger.info("Loading model manually...")
 
+        # Check GPU availability and memory
+        if torch.cuda.is_available():
+            logger.info(f"CUDA available: {torch.cuda.get_device_name()}")
+            logger.info(f"GPU memory before loading: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
+            logger.info(f"GPU memory total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f}GB")
+        else:
+            logger.warning("CUDA not available, using CPU")
+
         # Load config manually
         config_path = os.path.join(model_path, "config.json")
         with open(config_path, 'r') as f:
@@ -145,8 +153,12 @@ class EndpointHandler:
 
         # Create model
         model = Qwen2ForCausalLM(config)
+        logger.info("Model architecture created")
 
-        # Load weights manually from safetensors
+        if torch.cuda.is_available():
+            logger.info(f"GPU memory after model creation: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
+
+        # Load weights manually from safetensors with memory optimization
         import glob
         safetensors_files = glob.glob(os.path.join(model_path, "*.safetensors"))
         logger.info(f"Found {len(safetensors_files)} safetensors files")
@@ -154,15 +166,20 @@ class EndpointHandler:
         if safetensors_files:
             from safetensors.torch import load_file
 
-            # Load weights in chunks
-            state_dict = {}
+            # Load weights directly into model without accumulating in state_dict
            for i, file in enumerate(sorted(safetensors_files)):
                 logger.info(f"Loading weights from file {i+1}/{len(safetensors_files)}: {os.path.basename(file)}")
 
+                # Load partial weights
                 partial_state_dict = load_file(file)
-                state_dict.update(partial_state_dict)
 
-                # Clear partial dict to free memory
+                if torch.cuda.is_available():
+                    logger.info(f"GPU memory after loading file {i+1}: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
+
+                # Load this partial state dict directly into the model
+                missing_keys, unexpected_keys = model.load_state_dict(partial_state_dict, strict=False)
+
+                # Clear partial dict immediately to free memory
                 del partial_state_dict
 
                 # Force garbage collection
@@ -171,27 +188,19 @@ class EndpointHandler:
 
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
-
-            # Load weights into model
-            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
-
-            if missing_keys:
-                logger.warning(f"Missing keys: {len(missing_keys)} keys missing")
-            if unexpected_keys:
-                logger.warning(f"Unexpected keys: {len(unexpected_keys)} unexpected keys")
-
-            # Clear state dict to free memory
-            del state_dict
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+                    logger.info(f"GPU memory after cleanup: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
 
         # Convert to half precision and move to GPU
+        logger.info("Converting model to half precision...")
         model = model.half()
+
         if torch.cuda.is_available():
+            logger.info(f"GPU memory after half precision: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
             model = model.cuda()
+            logger.info(f"GPU memory after moving to GPU: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
 
         model.eval()
+        logger.info("Model loaded successfully and set to eval mode")
         return model
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
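
Editor's note: the core change in this commit is that weights are now copied into the model one safetensors shard at a time instead of first being merged into a single full state_dict, so peak host memory during loading drops from roughly "model + all shards" to roughly "model + one shard". Below is a minimal standalone sketch of that pattern, under stated assumptions: MODEL_DIR is a placeholder path, and while Qwen2ForCausalLM appears in the diff, Qwen2Config and the way config is built from config.json are assumptions; this illustrates the technique, it is not the handler code itself.

import gc
import glob
import json
import os

import torch
from safetensors.torch import load_file
from transformers import Qwen2Config, Qwen2ForCausalLM

MODEL_DIR = "/path/to/model"  # assumption: contains config.json and *.safetensors shards

# Build the empty model from its config, as the handler does
with open(os.path.join(MODEL_DIR, "config.json"), "r") as f:
    config = Qwen2Config(**json.load(f))
model = Qwen2ForCausalLM(config)

# Copy one shard at a time into the model, releasing each shard before
# reading the next; strict=False is required because every shard covers
# only a subset of the model's parameters.
for shard in sorted(glob.glob(os.path.join(MODEL_DIR, "*.safetensors"))):
    partial = load_file(shard)
    model.load_state_dict(partial, strict=False)
    del partial
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

model = model.half()
if torch.cuda.is_available():
    model = model.cuda()
model.eval()

One trade-off visible in the deleted block: the old code checked missing and unexpected keys once after the full load, which was a meaningful sanity check. Per shard that check is inherently noisy (every shard "misses" the keys held by the other shards), so the new code captures missing_keys/unexpected_keys but no longer reports them, meaning genuinely absent weights go unlogged.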
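
Separately, the diff repeats the same GPU-memory f-string at every stage. A small helper would keep each call site to one line; this is a hypothetical refactor, not part of the commit, and it assumes the module's existing logger and torch imports:

def log_gpu_memory(stage: str) -> None:
    # Hypothetical helper collapsing the repeated GPU-memory log lines above
    if torch.cuda.is_available():
        used = torch.cuda.memory_allocated() / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logger.info(f"GPU memory {stage}: {used:.2f}GB / {total:.2f}GB")

# Usage at each stage, e.g.:
# log_gpu_memory("after model creation")
# log_gpu_memory("after cleanup")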