metadata
license: other
language:
  - ja
base_model:
  - tokyotech-llm/Llama-3.1-Swallow-70B-Instruct-v0.3
  - Qwen/Qwen2.5-VL-7B-Instruct
pipeline_tag: visual-question-answering

Llama-3.1-70B-Instruct-multimodal-JP-Graph - Built with Llama

Llama-3.1-70B-Instruct-multimodal-JP-Graph is a Japanese large vision-language model. It combines Llama-3.1-Swallow-70B as the language model with the image encoder of Qwen2-VL-7B.
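In LLaVA-style models, the prompt contains a placeholder token that is replaced by the image encoder's output vectors before the sequence reaches the language model. The sketch below illustrates only that splicing idea in plain Python, assuming the LLaVA convention that the placeholder id is -200 (`IMAGE_TOKEN_INDEX`). The helper `splice_image_features` and the toy embeddings are hypothetical. Real implementations operate on tensors, and this model's actual logic lives in the attached `llava_arch.py`, which this sketch does not reproduce.

```python
from typing import Callable, List, Sequence

# LLaVA's placeholder id marking where an image goes in input_ids
IMAGE_TOKEN_INDEX = -200

Vector = List[float]


def splice_image_features(
    input_ids: Sequence[int],
    embed_token: Callable[[int], Vector],
    image_features: Sequence[Sequence[Vector]],
) -> List[Vector]:
    """Replace each image placeholder with that image's visual-token vectors.

    Hypothetical helper for illustration only.
    """
    sequence: List[Vector] = []
    images = iter(image_features)
    for token_id in input_ids:
        if token_id == IMAGE_TOKEN_INDEX:
            features = next(images, None)
            if features is None:
                raise ValueError("more image placeholders than images")
            # One placeholder expands into every visual token of one image
            sequence.extend(list(vec) for vec in features)
        else:
            # Ordinary text tokens go through the LLM's embedding table
            sequence.append(embed_token(token_id))
    return sequence


# Toy usage: 2-dimensional "embeddings", one image with 3 visual tokens
def toy_embed(token_id: int) -> Vector:
    return [float(token_id), 0.0]


image = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]
seq = splice_image_features([1, IMAGE_TOKEN_INDEX, 5, 6], toy_embed, [image])
print(len(seq))  # 6
```

The single placeholder grows into as many positions as the image has visual tokens, which is why the prompt length seen by the language model depends on the image size.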

How to use

1. Install LLaVA-NeXT

  • First, install LLaVA-NeXT following the instructions in its repository:
git clone https://github.com/LLaVA-VL/LLaVA-NeXT
cd LLaVA-NeXT
conda create -n llava python=3.10 -y
conda activate llava
pip install --upgrade pip  # Enable PEP 660 support.
pip install -e ".[train]"

2. Install dependencies

pip install flash-attn==2.6.3
pip install transformers==4.45.2
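The pinned versions can be confirmed from Python before moving on. This is a minimal sketch using the standard library's `importlib.metadata`; the `PINNED` mapping simply mirrors the two `pip install` lines above, and `check_pins` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict

# Versions pinned in step 2 of this README
PINNED: Dict[str, str] = {
    "flash-attn": "2.6.3",
    "transformers": "4.45.2",
}


def check_pins(
    pins: Dict[str, str],
    lookup: Callable[[str], str] = version,
) -> Dict[str, str]:
    """Return a {package: problem} mapping for every pin that is not satisfied."""
    problems: Dict[str, str] = {}
    for package, wanted in pins.items():
        try:
            installed = lookup(package)
        except PackageNotFoundError:
            problems[package] = "not installed"
            continue
        if installed != wanted:
            problems[package] = f"installed {installed}, expected {wanted}"
    return problems


if __name__ == "__main__":
    issues = check_pins(PINNED)
    for package, problem in issues.items():
        print(f"{package}: {problem}")
    if not issues:
        print("All pinned versions are installed.")
```

The `lookup` parameter defaults to the real installed-package query but can be swapped for a stub, which keeps the check easy to test without the packages present.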

3. Modify LLaVA-NeXT

  • Modify the LLaVA-NeXT code as follows.
    • Create the LLaVA-NeXT/llava/model/multimodal_encoder/qwen2_vl directory and copy the contents of the attached qwen2_vl directory into it.
    • Overwrite LLaVA-NeXT/llava/model/multimodal_encoder/builder.py with the attached "builder.py".
    • Copy the attached "qwen2vl_encoder.py" into LLaVA-NeXT/llava/model/multimodal_encoder/.
    • Overwrite LLaVA-NeXT/llava/model/language_model/llava_llama.py with the attached "llava_llama.py".
    • Overwrite LLaVA-NeXT/llava/model/llava_arch.py with the attached "llava_arch.py".
    • Overwrite LLaVA-NeXT/llava/conversation.py with the attached "conversation.py".
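The file operations above can be scripted. This is a minimal sketch that assumes the attached files have been downloaded from this model repository into a local directory (called `attached_files` below, a hypothetical name) and that LLaVA-NeXT was cloned into `./LLaVA-NeXT`; adjust both paths to your setup. `apply_patches`, `FILE_MAP`, and `QWEN2_VL_DIR` are illustrative names, not part of LLaVA-NeXT.

```python
import shutil
from pathlib import Path

# Destinations inside the LLaVA-NeXT checkout, relative to its root.
# Keys are file names in the (hypothetical) download directory.
FILE_MAP = {
    "builder.py": "llava/model/multimodal_encoder/builder.py",
    "qwen2vl_encoder.py": "llava/model/multimodal_encoder/qwen2vl_encoder.py",
    "llava_llama.py": "llava/model/language_model/llava_llama.py",
    "llava_arch.py": "llava/model/llava_arch.py",
    "conversation.py": "llava/conversation.py",
}
QWEN2_VL_DIR = "llava/model/multimodal_encoder/qwen2_vl"


def apply_patches(src_dir: Path, repo_dir: Path) -> None:
    """Copy the attached qwen2_vl directory and files into LLaVA-NeXT."""
    # Create qwen2_vl (if needed) and copy the attached directory's contents into it
    shutil.copytree(src_dir / "qwen2_vl", repo_dir / QWEN2_VL_DIR, dirs_exist_ok=True)
    # Add qwen2vl_encoder.py and overwrite the remaining files
    for name, target in FILE_MAP.items():
        destination = repo_dir / target
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src_dir / name, destination)
```

Usage, under the path assumptions above: `apply_patches(Path("attached_files"), Path("LLaVA-NeXT"))`.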

4. Inference

The following script loads the model and asks several questions about a sample chart image.

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX
from llava.conversation import conv_templates, SeparatorStyle

from PIL import Image
import copy
import torch

import warnings
warnings.filterwarnings("ignore")


# Load the tokenizer, model, and image processor (device_map="auto" spreads the 70B model across the available GPUs)
pretrained = 'r-g2-2024/Llama-3.1-70B-Instruct-multimodal-JP-Graph-v0.1'
model_name = "llava_llama"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map)

model.eval()

# Load the input chart ("画像14.png" means "image 14"; replace it with your own file)
image = Image.open("./画像14.png")

# Preprocess the image into pixel values and the (t, h, w) patch grid used by the Qwen2-VL encoder
inputs = image_processor(image)
pixel_values = torch.tensor(inputs['pixel_values']).to(dtype=torch.float16, device=device)
pixel_values = [pixel_values]
_image_grid_thw = torch.tensor(inputs['image_grid_thw'], dtype=torch.long)
_image_grid_thw = [_image_grid_thw]

conv_template = "llava_llama_3"


def ask(question_text):
    # Build the chat prompt, with the image placeholder token before the question
    question = DEFAULT_IMAGE_TOKEN + "\n" + question_text
    conv = copy.deepcopy(conv_templates[conv_template])
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    # Tokenize the prompt, mapping the image placeholder to IMAGE_TOKEN_INDEX
    input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)
    image_sizes = [image.size]

    # Greedy decoding
    cont = model.generate(
        input_ids,
        images=pixel_values,
        image_sizes=image_sizes,
        image_grid_thws=_image_grid_thw,
        do_sample=False,
        temperature=0,
        max_new_tokens=4096,
    )
    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    print(text_outputs)


questions = [
    "FY22からFY23にかけて単体の値はどれくらい増加したか?",  # How much did the non-consolidated value increase from FY22 to FY23?
    "FY2021の連結の値はいくつか?",  # What is the consolidated value for FY2021?
    "この図は何を表しているか?",  # What does this figure represent?
    "FY2020の純利益はマイナスか?プラスか?",  # Was net income for FY2020 negative or positive?
    "単体が連結の利益を上回るのはいつからか?",  # From when does non-consolidated profit exceed consolidated profit?
]
for q in questions:
    ask(q)
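`image_grid_thw` records the image's grid in vision patches as (temporal, height, width). In the stock Qwen2-VL image processor, an image is resized so each side is a multiple of 28 pixels, cut into 14×14-pixel patches, and each 2×2 group of patches is merged into one token for the language model. The sketch below estimates that grid and token count, assuming the `Qwen2VLImageProcessor` constructor defaults (`min_pixels=56*56`, `max_pixels=28*28*1280`); checkpoint configs, or the attached encoder, may override these values. `smart_resize` here is a simplified re-implementation and `visual_token_count` is a hypothetical helper.

```python
import math
from typing import Tuple

PATCH_SIZE = 14  # assumed: pixels per side of one vision patch
MERGE_SIZE = 2   # assumed: patches merged per side before the LLM
FACTOR = PATCH_SIZE * MERGE_SIZE


def smart_resize(height: int, width: int,
                 min_pixels: int = 56 * 56,
                 max_pixels: int = 28 * 28 * 1280) -> Tuple[int, int]:
    """Round (height, width) to multiples of FACTOR within the pixel budget."""
    if height < FACTOR or width < FACTOR:
        raise ValueError("image sides must be at least FACTOR pixels")
    h_bar = round(height / FACTOR) * FACTOR
    w_bar = round(width / FACTOR) * FACTOR
    if h_bar * w_bar > max_pixels:
        # Scale down, keeping the aspect ratio, until under max_pixels
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / FACTOR) * FACTOR
        w_bar = math.floor(width / beta / FACTOR) * FACTOR
    elif h_bar * w_bar < min_pixels:
        # Scale up until at least min_pixels
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / FACTOR) * FACTOR
        w_bar = math.ceil(width * beta / FACTOR) * FACTOR
    return h_bar, w_bar


def visual_token_count(width: int, height: int) -> Tuple[Tuple[int, int, int], int]:
    """Return the (t, h, w) patch grid and the number of LLM visual tokens."""
    resized_h, resized_w = smart_resize(height, width)
    # t = 1 for a single still image
    grid = (1, resized_h // PATCH_SIZE, resized_w // PATCH_SIZE)
    tokens = grid[0] * grid[1] * grid[2] // (MERGE_SIZE ** 2)
    return grid, tokens


print(visual_token_count(640, 480))  # ((1, 34, 46), 391)
```

Since PIL's `image.size` is `(width, height)`, `visual_token_count(*image.size)` can be compared with `inputs['image_grid_thw']` from the script above to check whether the loaded processor uses these defaults.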