Update code/inference.py
code/inference.py  (+10 -0)
@@ -3,6 +3,8 @@ from PIL import Image
 from io import BytesIO
 import torch
 from transformers import AutoTokenizer
+import boto3
+import tempfile
 
 from llava.model import LlavaLlamaForCausalLM
 from llava.utils import disable_torch_init
@@ -68,6 +70,14 @@ def predict_fn(data, model_and_tokenizer):
     if image_file.startswith("http") or image_file.startswith("https"):
         response = requests.get(image_file)
         image = Image.open(BytesIO(response.content)).convert("RGB")
+    elif image_file.startswith("s3://"):
+        s3 = boto3.client("s3")
+        s3_path = image_file[5:]  # drop the "s3://" prefix
+        bucket = s3_path.split('/')[0]
+        s3_key = '/'.join(s3_path.split('/')[1:])
+        with tempfile.NamedTemporaryFile() as temp_file:
+            s3.download_file(bucket, s3_key, temp_file.name)
+            image = Image.open(temp_file.name).convert("RGB")
     else:
         image = Image.open(image_file).convert("RGB")
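
Taken together, the commit gives predict_fn a three-way dispatch on image_file: HTTP(S) URLs are fetched with requests, s3:// URIs are downloaded with boto3, and anything else is treated as a local path. An s3:// URI such as "s3://my-bucket/images/cat.png" parses to bucket "my-bucket" and key "images/cat.png". The sketch below reassembles that logic as a standalone helper so it can be read and tested in isolation; the name load_image and the standalone framing are illustrative, since in the actual file the logic sits inline in predict_fn.

import tempfile
from io import BytesIO

import boto3
import requests
from PIL import Image


def load_image(image_file: str) -> Image.Image:
    if image_file.startswith("http") or image_file.startswith("https"):
        # Remote URL: fetch the bytes and decode them in memory.
        response = requests.get(image_file)
        return Image.open(BytesIO(response.content)).convert("RGB")
    elif image_file.startswith("s3://"):
        # S3 URI: split "s3://bucket/key/parts" into bucket and key,
        # then download to a temporary file that PIL can read from.
        s3 = boto3.client("s3")
        s3_path = image_file[5:]
        bucket = s3_path.split("/")[0]
        s3_key = "/".join(s3_path.split("/")[1:])
        with tempfile.NamedTemporaryFile() as temp_file:
            s3.download_file(bucket, s3_key, temp_file.name)
            # .convert() forces a full decode before the temp file vanishes.
            return Image.open(temp_file.name).convert("RGB")
    else:
        # Anything else is treated as a local filesystem path.
        return Image.open(image_file).convert("RGB")

Calling .convert("RGB") inside the with block matters: it forces PIL to decode the pixel data while the temporary file still exists on disk.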
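
Since BytesIO is already imported for the HTTP branch, a possible refinement, not part of this commit, is to read the S3 object into memory with get_object instead of round-tripping through a named temporary file. That mirrors the HTTP branch and sidesteps the caveat that a NamedTemporaryFile cannot be reopened by name on Windows. A minimal sketch, with a hypothetical function name:

from io import BytesIO

import boto3
from PIL import Image


def load_image_s3_inmemory(image_file: str) -> Image.Image:
    # Hypothetical variant of the S3 branch: read the object body into
    # memory and decode it there, mirroring what the HTTP branch already
    # does with response.content.
    s3 = boto3.client("s3")
    s3_path = image_file[5:]                    # "bucket/key/parts"
    bucket, _, s3_key = s3_path.partition("/")  # split at the first "/"
    body = s3.get_object(Bucket=bucket, Key=s3_key)["Body"].read()
    return Image.open(BytesIO(body)).convert("RGB")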