File size: 2,019 Bytes
a3504d5 3a96b74 a3504d5 628532f 460210d d36c36e 12f2dcd 460210d 12f2dcd 460210d a3504d5 3a96b74 a3504d5 3a96b74 a3504d5 3a96b74 a3504d5 3a96b74 d36c36e 3576cce 3a96b74 a3504d5 3576cce a3504d5 d36c36e 3a96b74 a3504d5 d36c36e 3a96b74 a3504d5 3a96b74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
from typing import Dict, List, Any
from io import BytesIO
import base64
class EndpointHandler:
    """Inference handler that makes an image's background transparent.

    A segmentation model predicts a foreground mask, which is attached as the
    alpha channel of the original image. The result is returned as a
    base64-encoded PNG.
    """

    def __init__(self, path: str = "."):
        """Load the segmentation model and build the preprocessing pipeline.

        Args:
            path: Location passed to ``from_pretrained``. Defaults to ``"."``
                (the previously hard-coded value); an empty string is treated
                the same way. NOTE(review): the Inference Endpoints toolkit
                appears to construct handlers as ``EndpointHandler(model_dir)``
                — confirm against the deployed toolkit version.
        """
        # trust_remote_code=True lets from_pretrained execute model code that
        # is stored alongside the weights.
        self.pipeline = AutoModelForImageSegmentation.from_pretrained(
            path or ".", trust_remote_code=True
        )
        # Run on GPU when one is available; otherwise stay on CPU as before.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipeline.to(self.device)
        # Inference only: switch to eval mode once rather than on every request.
        self.pipeline.eval()
        # 'high' permits faster, slightly lower-precision float32 matmuls.
        torch.set_float32_matmul_precision("high")
        self.image_size = (1024, 1024)
        # Built once and reused: resize to self.image_size, convert to a
        # tensor, then normalise each of the 3 RGB channels with the ImageNet
        # mean/std values.
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _load_image(inputs: Any) -> Image.Image:
        """Turn the request's ``inputs`` value into a (lazily loaded) PIL image.

        Accepts an already-decoded ``PIL.Image.Image`` or base64 text/bytes,
        optionally prefixed with a data-URL header such as
        ``data:image/png;base64,``.

        Raises:
            ValueError: if ``inputs`` is missing or empty, or (as
                ``binascii.Error``) if it is not valid base64.
        """
        if isinstance(inputs, Image.Image):
            return inputs
        if not inputs:
            raise ValueError("Request must provide a base64-encoded image in 'inputs'.")
        if isinstance(inputs, str) and inputs.startswith("data:"):
            # Keep only the payload after the data-URL header; the header's
            # letters are valid base64 characters and would corrupt decoding.
            inputs = inputs.partition(",")[2]
        image_data = base64.b64decode(inputs)
        return Image.open(BytesIO(image_data))

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Remove the background of the image in ``data["inputs"]``.

        Args:
            data: Request payload; ``data["inputs"]`` holds the image (see
                ``_load_image`` for the accepted forms).

        Returns:
            ``[{"image": <base64 PNG>}]`` — the original image with the
            predicted mask as its alpha channel, at the original resolution.

        Raises:
            ValueError: if ``inputs`` is missing, empty, or not valid base64.
            PIL.UnidentifiedImageError: if the decoded bytes are not an image.
        """
        # Decode once; derive both the model input and the output from it.
        source_image = self._load_image(data.get("inputs"))

        # The normaliser has 3 channel statistics, so feed the model RGB.
        input_images = (
            self.transform_image(source_image.convert("RGB"))
            .unsqueeze(0)
            .to(self.device)
        )

        with torch.no_grad():
            # The last element of the model output is used as the mask logits.
            preds = self.pipeline(input_images)[-1].sigmoid()
        pred = preds[0].squeeze().cpu()

        # Scale the mask back to the original image size and use it as alpha.
        mask = transforms.ToPILImage()(pred).resize(source_image.size)
        # convert() returns a new image, so a caller-supplied PIL image is not
        # modified by putalpha.
        output_image = source_image.convert("RGBA")
        output_image.putalpha(mask)

        buffered = BytesIO()
        output_image.save(buffered, format="PNG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return [{"image": base64_image}]