import base64
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the segmentation model from the model repository directory.
        # Hugging Face Inference Endpoints pass the repository path as `path`;
        # fall back to the current directory for local use.
        self.pipeline = AutoModelForImageSegmentation.from_pretrained(
            path or ".", trust_remote_code=True
        )
        self.pipeline.eval()
        # Allow TF32 matmuls for faster float32 inference on supported GPUs.
        torch.set_float32_matmul_precision("high")
        self.image_size = (1024, 1024)
        # Preprocessing: resize to the model's input size, convert to a
        # tensor, and normalize with ImageNet statistics.
        self.transform_image = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Decode the base64-encoded image from the request payload,
        # e.g. {"inputs": "<base64-encoded PNG/JPEG>"}.
        image_b64 = data.get("inputs", "")
        image_data = base64.b64decode(image_b64)
        # The model expects a 3-channel RGB input.
        image = Image.open(BytesIO(image_data)).convert("RGB")
        input_images = self.transform_image(image).unsqueeze(0)
        # Run inference: the model returns a list of outputs with the final
        # prediction last; apply sigmoid to get a [0, 1] foreground mask.
        with torch.no_grad():
            preds = self.pipeline(input_images)[-1].sigmoid()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)

        # Resize the mask back to the original resolution and apply it as
        # the alpha channel of the original image.
        original_image = Image.open(BytesIO(image_data)).convert("RGBA")
        mask = pred_pil.resize(original_image.size)
        original_image.putalpha(mask)
    
        # Encode the masked image as base64 PNG (PNG preserves the alpha channel).
        buffered = BytesIO()
        original_image.save(buffered, format="PNG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")

        # Return the result as a list of dictionaries.
        return [{"image": base64_image}]
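

# --- Usage sketch (illustrative, not part of the handler) ---
# A minimal local smoke test, assuming a model repository in the current
# directory; the file names "input.png" and "output.png" are hypothetical.
if __name__ == "__main__":
    with open("input.png", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}
    handler = EndpointHandler(".")
    result = handler(payload)
    # Decode the returned base64 PNG and save the cutout with transparency.
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result[0]["image"]))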