import base64
import logging
from io import BytesIO
from typing import Any, Dict, List

import matplotlib.pyplot as plt  # noqa: F401  (kept from the original import list)
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Custom inference handler that cuts the foreground out of an image.

    The handler loads an image-segmentation model once, and for each request
    predicts a mask which is written into the alpha channel of the submitted
    image. The result is returned as a base64-encoded PNG.
    """

    def __init__(self, path: str = ""):
        """Load the segmentation model and build the preprocessing transform.

        Args:
            path: Directory holding the model files. Hosting toolkits pass the
                model directory positionally; when it is empty the current
                directory is used, which is what the handler did originally.
        """
        # trust_remote_code=True executes the modelling code shipped with the
        # checkpoint, so only point this at a repository you trust.
        self.pipeline = AutoModelForImageSegmentation.from_pretrained(
            path or ".", trust_remote_code=True
        )
        torch.set_float32_matmul_precision("high")
        # Inference only: switch to eval mode once instead of on every request.
        self.pipeline.eval()

        self.image_size = (1024, 1024)
        # Built once and reused for every request. The normalisation constants
        # are the commonly used ImageNet channel mean / std values.
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the submitted image and return it with the mask as alpha.

        Args:
            data: Request payload. ``data["inputs"]`` must hold the image as a
                base64-encoded string (or bytes).

        Returns:
            A single-element list ``[{"image": <base64 PNG>}]`` where the PNG
            is the input image in RGBA with the predicted mask as its alpha
            channel, at the input image's original size.

        Raises:
            ValueError: If ``"inputs"`` is missing or empty, or is not valid
                base64 (``binascii.Error`` is a ``ValueError`` subclass).
            PIL.UnidentifiedImageError: If the decoded bytes are not an image
                Pillow can read.
        """
        image_b64 = data.get("inputs", "")
        if not image_b64:
            raise ValueError(
                'Request payload must contain a base64-encoded image under "inputs".'
            )
        image_data = base64.b64decode(image_b64)

        # Decode the bytes once and derive both views from the same source:
        # RGB feeds the model, RGBA is the canvas that receives the mask.
        with Image.open(BytesIO(image_data)) as source:
            rgb_image = source.convert("RGB")
            original_image = source.convert("RGBA")
        logger.debug("Decoded input image with size %s", original_image.size)

        # Add a leading batch dimension of 1.
        input_images = self.transform_image(rgb_image).unsqueeze(0)

        with torch.no_grad():
            # The model output is indexed with [-1], i.e. its last element is
            # used as the prediction (presumably the final-stage mask --
            # confirm against the checkpoint's remote code).
            preds = self.pipeline(input_images)[-1].sigmoid()
        pred = preds[0].squeeze()
        logger.debug("Predicted mask tensor with shape %s", tuple(pred.shape))

        # Scale the mask back to the original resolution and use it as alpha;
        # putalpha replaces any alpha channel the input already had.
        mask = transforms.ToPILImage()(pred).resize(original_image.size)
        original_image.putalpha(mask)

        buffered = BytesIO()
        original_image.save(buffered, format="PNG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
        logger.debug("Encoded result PNG (%d base64 characters)", len(base64_image))

        return [{"image": base64_image}]