Update handler.py
handler.py  CHANGED  +23 -1
@@ -3,6 +3,7 @@ import torch
 import json
 import os
 import glob
+import tempfile
 from transformers import PreTrainedTokenizerFast, PreTrainedModel
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 import logging
@@ -19,6 +20,14 @@ class EndpointHandler:
         logger.info(f"Loading model from {path}")
 
         try:
+            # Set cache directories to temp to avoid memory issues
+            os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers_cache'
+            os.environ['HF_HOME'] = '/tmp/hf_home'
+            os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+            # Clear any existing cache
+            self._clear_cache()
+
             # Find the actual model files
             model_path = self._discover_model_files(path)
             logger.info(f"Model files found at: {model_path}")
@@ -39,6 +48,18 @@ class EndpointHandler:
             logger.error(f"Failed to initialize: {str(e)}")
             raise e
 
+    def _clear_cache(self):
+        """Clear any cached model data to free memory"""
+        try:
+            import shutil
+            cache_dirs = ['/tmp/transformers_cache', '/tmp/hf_home']
+            for cache_dir in cache_dirs:
+                if os.path.exists(cache_dir):
+                    shutil.rmtree(cache_dir)
+                    logger.info(f"Cleared cache: {cache_dir}")
+        except Exception as e:
+            logger.warning(f"Could not clear cache: {e}")
+
     def _discover_model_files(self, base_path: str) -> str:
         """Find where the actual model files are located"""
 
@@ -126,7 +147,8 @@ class EndpointHandler:
             tokenizer = AutoTokenizer.from_pretrained(
                 model_path,
                 trust_remote_code=True,
-                local_files_only=True
+                local_files_only=True,
+                cache_dir='/tmp/tokenizer_cache'  # Use temp cache
             )
         except Exception as e:
             logger.error(f"Failed to load tokenizer: {e}")