# remiai3's picture
# Upload 4 files
# ce307e6 verified
import argparse, torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image
def _parse_args(argv=None):
    """Parse the command-line options for the captioning script.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace with ``image`` (str) and ``max_length`` (int).

    Raises:
        SystemExit: On any argparse usage error, including a
            ``--max_length`` that is not a positive integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--image", type=str, required=True)
    parser.add_argument("--max_length", type=int, default=20)
    args = parser.parse_args(argv)
    if args.max_length < 1:
        # Reject this here so the user gets a usage message immediately
        # instead of an error after the model has already been loaded.
        parser.error("--max_length must be a positive integer")
    return args


def main(argv=None):
    """Generate a caption for one image and print it to stdout.

    Loads the ``nlpconnect/vit-gpt2-image-captioning`` checkpoint, runs it on
    CUDA when ``torch.cuda.is_available()`` and on the CPU otherwise, and
    prints the decoded caption with special tokens stripped.

    Args:
        argv: Optional list of argument strings (``--image PATH`` is required,
            ``--max_length N`` defaults to 20). ``None`` reads ``sys.argv``.

    Raises:
        SystemExit: If the command-line arguments are invalid.
    """
    args = _parse_args(argv)

    # Open the image before loading the model so that an unreadable path
    # fails fast. The context manager closes the underlying file; convert()
    # returns a separate image that stays usable afterwards.
    with Image.open(args.image) as source:
        img = source.convert("RGB")

    model_id = "nlpconnect/vit-gpt2-image-captioning"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VisionEncoderDecoderModel.from_pretrained(model_id).to(device)
    feature_extractor = ViTImageProcessor.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = feature_extractor(images=[img], return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        # A single image was passed in, so take the first generated sequence.
        output_ids = model.generate(pixel_values, max_length=args.max_length)[0]

    caption = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(caption)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()