"""Custom endpoint handler that loads a model and tokenizer from a directory."""

import logging
from typing import Any, Dict

import torch
from transformers import AutoModel, AutoTokenizer

# Module-level logger; the original code referenced ``logger`` without ever
# defining it, which raised NameError on every request.
logger = logging.getLogger(__name__)


class EndpointHandler:
    """Callable handler holding a model and its tokenizer.

    Both objects are loaded once in ``__init__`` from ``model_dir`` and kept
    on the instance as ``self.model`` and ``self.tokenizer``.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the model and tokenizer found at ``model_dir``.

        Args:
            model_dir: Path or identifier forwarded unchanged to both
                ``from_pretrained`` calls.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        self.model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=False,
            # NOTE: trust_remote_code=True executes Python code shipped with
            # the checkpoint; only point model_dir at sources you trust.
            trust_remote_code=True,
            device_map="auto",
        ).eval()  # switch to evaluation mode right after loading
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=False,
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Handle one incoming request.

        Currently this only logs the request payload and returns ``None``.

        Args:
            data: The request payload; it is logged via ``repr``.

        Returns:
            ``None``.
        """
        # TODO: no inference is performed yet; the payload is logged only.
        # Lazy %-formatting yields the same text as the former f"{data=}".
        logger.info("Received incoming request with data=%r", data)
        return None


if __name__ == "__main__":
    handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
    print(handler)