enalis committed on
Commit
bf999fb
·
verified ·
1 Parent(s): fbaffd0

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +33 -33
inference.py CHANGED
@@ -1,33 +1,33 @@
1
- import torch
2
- from model import LVL
3
- from transformers import RobertaTokenizer
4
- from PIL import Image
5
- from torchvision import transforms
6
-
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
-
9
- # Load model
10
- model = LVL()
11
- model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
12
- model.to(device)
13
- model.eval()
14
-
15
- # Load tokenizer
16
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
17
-
18
- # Image transform
19
- transform = transforms.Compose([
20
- transforms.Resize((224, 224)),
21
- transforms.ToTensor()
22
- ])
23
-
24
-
25
- def predict(image_path, text):
26
- image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
27
- tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
28
-
29
- with torch.no_grad():
30
- img_feat, txt_feat = model(image, tokens["input_ids"], tokens["attention_mask"])
31
- similarity = torch.matmul(img_feat, txt_feat.T).squeeze()
32
-
33
- return similarity.item()
 
1
+ import torch
2
+ from model import LVL
3
+ from transformers import RobertaTokenizer
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+
9
+ # Load model
10
+ model = LVL()
11
+ model.load_state_dict(torch.load("scold.pth", map_location=device))
12
+ model.to(device)
13
+ model.eval()
14
+
15
+ # Load tokenizer
16
+ tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
17
+
18
+ # Image transform
19
+ transform = transforms.Compose([
20
+ transforms.Resize((224, 224)),
21
+ transforms.ToTensor()
22
+ ])
23
+
24
+
25
+ def predict(image_path, text):
26
+ image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(device)
27
+ tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
28
+
29
+ with torch.no_grad():
30
+ img_feat, txt_feat = model(image, tokens["input_ids"], tokens["attention_mask"])
31
+ similarity = torch.matmul(img_feat, txt_feat.T).squeeze()
32
+
33
+ return similarity.item()