TomBombadyl committed
Commit 1eff8e4 · verified · 1 Parent(s): 28f4095

Update handler.py

Files changed (1)
  handler.py +44 -7
handler.py CHANGED
@@ -174,14 +174,45 @@ class EndpointHandler:
         if not safetensors_files:
             raise FileNotFoundError("No safetensors files found")
 
-        # Load weights manually
+        # Load weights manually with memory optimization
         from safetensors.torch import load_file
 
+        # Convert to half precision before loading weights to save memory
+        model = model.half()
+        logger.info("Converted model to half precision")
+
+        # Load weights in chunks to avoid memory spikes
         state_dict = {}
-        for file in sorted(safetensors_files):
-            logger.info(f"Loading weights from: {file}")
-            partial_state_dict = load_file(file)
-            state_dict.update(partial_state_dict)
+        total_files = len(safetensors_files)
+
+        for i, file in enumerate(sorted(safetensors_files)):
+            logger.info(f"Loading weights from file {i+1}/{total_files}: {os.path.basename(file)}")
+
+            try:
+                # Load partial weights
+                partial_state_dict = load_file(file)
+
+                # Convert to half precision immediately
+                partial_state_dict = {k: v.half() for k, v in partial_state_dict.items()}
+
+                # Update state dict
+                state_dict.update(partial_state_dict)
+
+                # Clear partial dict to free memory
+                del partial_state_dict
+
+                # Force garbage collection
+                import gc
+                gc.collect()
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+                logger.info(f"Loaded file {i+1}/{total_files}, current memory usage: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
+
+            except Exception as e:
+                logger.error(f"Failed to load file {file}: {e}")
+                raise e
 
         logger.info(f"Total state dict keys: {len(state_dict)}")
 
@@ -196,10 +227,16 @@ class EndpointHandler:
         logger.warning(f"Unexpected keys: {len(unexpected_keys)} unexpected keys")
         logger.warning(f"First few unexpected: {unexpected_keys[:5]}")
 
-        # Convert to half precision and move to GPU
-        model = model.half()
+        # Clear state dict to free memory
+        del state_dict
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Move to GPU if available
         if torch.cuda.is_available():
             model = model.cuda()
+            logger.info(f"Model moved to GPU, final memory usage: {torch.cuda.memory_allocated() / 1024**3:.2f}GB")
 
         model.eval()
         return model
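
Note: the step between the two hunks (new lines 219-226) is elided from this diff. Below is a minimal sketch of how the assembled state_dict is presumably applied there, assuming the handler calls load_state_dict with strict=False to obtain the missing/unexpected key lists that the context lines log as warnings; the helper name apply_loaded_weights is hypothetical and not part of this commit.

import logging

import torch

logger = logging.getLogger(__name__)

def apply_loaded_weights(model: torch.nn.Module, state_dict: dict) -> torch.nn.Module:
    # strict=False returns (missing_keys, unexpected_keys) instead of raising,
    # matching the warnings logged in the surrounding context lines
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        logger.warning(f"Missing keys: {len(missing_keys)} missing keys")
        logger.warning(f"First few missing: {missing_keys[:5]}")
    if unexpected_keys:
        logger.warning(f"Unexpected keys: {len(unexpected_keys)} unexpected keys")
        logger.warning(f"First few unexpected: {unexpected_keys[:5]}")
    return model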