r-g2-2024 committed on
Commit d05184f · verified · 1 Parent(s): cc16e9c

Update README.md

Files changed (1): README.md (+144, −0)
README.md CHANGED
@@ -45,3 +45,147 @@ pip install transformers==4.45.2

### 4. Inference
The following script loads the model and runs inference.

```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates

from PIL import Image
import copy
import torch

import warnings
warnings.filterwarnings("ignore")


# Load the tokenizer, model, and image processor.
pretrained = "r-g2-2024/Llama-3.1-70B-Instruct-multimodal-JP-Graph-v0.1"
model_name = "llava_llama"
device = "cuda"
device_map = "auto"
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map)

model.eval()

# Load and preprocess the input chart image.
image = Image.open("./画像14.png")

inputs = image_processor(image)
pixel_values = torch.tensor(inputs["pixel_values"]).to(dtype=torch.float16, device=device)
pixel_values = [pixel_values]  # the model expects a list of image tensors
_image_grid_thw = torch.tensor(inputs["image_grid_thw"], dtype=torch.long)
_image_grid_thw = [_image_grid_thw]
image_sizes = [image.size]

conv_template = "llava_llama_3"

# Question 1: "By how much did the non-consolidated (単体) value increase from FY22 to FY23?"
question = DEFAULT_IMAGE_TOKEN + "\nFY22からFY23にかけて単体の値はどれくらい増加したか?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

cont = model.generate(
    input_ids,
    images=pixel_values,
    image_sizes=image_sizes,
    image_grid_thws=_image_grid_thw,
    do_sample=False,  # greedy decoding
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)

# Question 2: "What is the consolidated (連結) value for FY2021?"
question = DEFAULT_IMAGE_TOKEN + "\nFY2021の連結の値はいくつか?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

cont = model.generate(
    input_ids,
    images=pixel_values,
    image_sizes=image_sizes,
    image_grid_thws=_image_grid_thw,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)

# Question 3: "What does this chart show?"
question = DEFAULT_IMAGE_TOKEN + "\nこの図は何を表しているか?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

cont = model.generate(
    input_ids,
    images=pixel_values,
    image_sizes=image_sizes,
    image_grid_thws=_image_grid_thw,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)

# Question 4: "Was net income (純利益) in FY2020 negative or positive?"
question = DEFAULT_IMAGE_TOKEN + "\nFY2020の純利益はマイナスか?プラスか?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

cont = model.generate(
    input_ids,
    images=pixel_values,
    image_sizes=image_sizes,
    image_grid_thws=_image_grid_thw,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)

# Question 5: "From when does the non-consolidated profit exceed the consolidated profit?"
question = DEFAULT_IMAGE_TOKEN + "\n単体が連結の利益を上回るのはいつからか?"
conv = copy.deepcopy(conv_templates[conv_template])
conv.append_message(conv.roles[0], question)
conv.append_message(conv.roles[1], None)
prompt_question = conv.get_prompt()

input_ids = tokenizer_image_token(prompt_question, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(device)

cont = model.generate(
    input_ids,
    images=pixel_values,
    image_sizes=image_sizes,
    image_grid_thws=_image_grid_thw,
    do_sample=False,
    temperature=0,
    max_new_tokens=4096,
)
text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
print(text_outputs)
```
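
The five question blocks above differ only in the question string, so in practice it can be convenient to wrap the prompt construction and generation in a small helper. The sketch below is one possible refactoring, not part of the original script: the helper name `ask` is made up, and it assumes the objects built above (`tokenizer`, `model`, `pixel_values`, `_image_grid_thw`, `image_sizes`, `device`) are already in scope.

```python
# Minimal sketch (not from the original README): reuse the objects built above
# to answer several questions about the same chart without repeating boilerplate.
def ask(question_text):
    # Build the prompt with the same "llava_llama_3" conversation template.
    conv = copy.deepcopy(conv_templates["llava_llama_3"])
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + question_text)
    conv.append_message(conv.roles[1], None)
    input_ids = tokenizer_image_token(
        conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    out = model.generate(
        input_ids,
        images=pixel_values,
        image_sizes=image_sizes,
        image_grid_thws=_image_grid_thw,
        do_sample=False,  # greedy decoding, as above; temperature is unused here
        max_new_tokens=4096,
    )
    return tokenizer.batch_decode(out, skip_special_tokens=True)[0]

for q in [
    "FY22からFY23にかけて単体の値はどれくらい増加したか?",
    "FY2021の連結の値はいくつか?",
    "この図は何を表しているか?",
]:
    print(ask(q))
```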