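"""Caption an image with the pretrained nlpconnect/vit-gpt2-image-captioning model.

Example invocation (the script and image file names below are illustrative):

    python caption_image.py --image path/to/photo.jpg --max_length 20
"""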
import argparse

import torch
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

def main():
    parser = argparse.ArgumentParser(description="Generate a caption for an image.")
    parser.add_argument("--image", type=str, required=True, help="Path to the input image")
    parser.add_argument("--max_length", type=int, default=20, help="Maximum caption length in tokens")
    args = parser.parse_args()

    # ViT encoder + GPT-2 decoder captioning checkpoint; run on GPU if one is available
    model_id = "nlpconnect/vit-gpt2-image-captioning"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)
    model.eval()  # inference only: disable dropout
    feature_extractor = ViTImageProcessor.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Load the image and preprocess it into the pixel_values tensor expected by the ViT encoder
    img = Image.open(args.image).convert("RGB")
    pixel_values = feature_extractor(images=[img], return_tensors="pt").pixel_values.to(device)

    # Generate token ids (greedy decoding up to --max_length) and decode them back to text
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=args.max_length)[0]
    caption = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(caption)

if __name__ == "__main__":
    main()