# REF: https://github.com/deepseek-ai/Janus
import argparse
import os

import numpy as np
import torch
from axengine import InferenceSession
from einops import rearrange
from ml_dtypes import bfloat16
from tqdm import tqdm
from transformers import AutoConfig

from janus.models import VLChatProcessor
from janus.models.modeling_vlm import MultiModalityConfig
from janus.utils.io import load_pil_images

parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument("--tokenizer_dir", type=str, default="Janus-Pro-1B",
help="Path to HuggingFace model")
parser.add_argument("--axmodel_path", type=str, default="janus_pro_1B_axmodel",
help="Path to save compiled axmodel of llama model")
parser.add_argument("-i", "--test_img_path", type=str, default="./imgs/image.png",
help="Test image path (supports png/jpg formats)")
parser.add_argument("--vit_axmodel_path", type=str, default="vit_axmodel/janus_warp_vit.axmodel",
help="Path to ViT model's axmodel")
args = parser.parse_args()
# base info
tokenizer_dir = args.tokenizer_dir
axmodel_path = args.axmodel_path
test_img_path = args.test_img_path
vit_axmodel_path = args.vit_axmodel_path
embeds = np.load(os.path.join(args.axmodel_path, "model.embed_tokens.weight.npy"))
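# The token-embedding table is exported separately as a .npy file so token
# embeddings can be gathered on the host with np.take instead of running an
# embedding op on the NPU. A minimal sketch of the lookup (the hidden size of
# 2048 is an assumption for Janus-Pro-1B):
#
#   ids = [100, 200, 300]            # token ids
#   np.take(embeds, ids, axis=0)     # -> float array of shape [3, 2048]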
def prepare_inputs_embeds(
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.LongTensor,
**kwargs,
):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
    bs, n = pixel_values.shape[0:2]

    # Run the ViT encoder on the NPU; pixel_values: [1, 1, 3, 384, 384]
    vit_session = InferenceSession(vit_axmodel_path)
    images_embeds = vit_session.run(None, {"image": pixel_values[0].numpy()})[0]  # [b x n, T2, D]
    print(f"vit_output.shape is {images_embeds.shape}, vit feature extract done!")
print(f"vit_output.shape is {images_embeds.shape}, vit feature extract done!")
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
    # Clamp the negative image-placeholder ids so the embedding gather below is
    # valid; those positions are overwritten with the ViT features right after.
    input_ids[input_ids < 0] = 0
    inputs_embeds = np.take(embeds, input_ids[0].cpu().numpy().tolist(), axis=0)[None, ...]
    # Splice the image embeddings into the text-embedding sequence.
    inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds
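# With the default 384x384 input, the Janus SigLIP-style encoder is expected
# to emit 576 image tokens (24x24 patches); the exact count depends on the
# exported ViT axmodel.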
def post_process(data, topk=1, topp=0.9, temperature=0.6):
    """Sample the next token from raw logits via top-k, temperature and top-p."""

    def top_p(l: np.ndarray, p: float) -> np.ndarray:
        # Keep the highest-probability entries until their cumulative mass
        # reaches p, zero out the remainder, then renormalize.
        index = np.argsort(l)
        res = l.copy()
        sum_p = 0
        for i in index[::-1]:
            if sum_p >= p:
                res[i] = 0
            sum_p += res[i]
        return res / sum_p

    def softmax(l: np.ndarray) -> np.ndarray:
        l_max = l - l.max()  # subtract the max for numerical stability
        l_exp = np.exp(l_max)
        res = l_exp / np.sum(l_exp)
        return res.astype(np.float64)

    r = data.astype(np.float32).flatten()
    # Pick the top-k logits, apply temperature, then sample from the
    # renormalized top-p distribution over those candidates.
    candidate_index = np.argpartition(r, -topk)[-topk:]
    candidate_value = r[candidate_index]
    candidate_value /= temperature
    candidate_soft = softmax(candidate_value)
    candidate_soft = top_p(candidate_soft, topp)
    candidate_soft = candidate_soft.astype(np.float64) / candidate_soft.sum()
    pos = np.random.multinomial(1, candidate_soft).argmax()
    next_token = candidate_index[pos]
    return next_token, candidate_index, candidate_soft
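# Illustrative, not executed: with the default topk=1 the multinomial draw is
# degenerate and post_process reduces to greedy decoding, e.g.
#
#   logits = np.array([0.1, 2.0, -1.0], dtype=np.float32)
#   post_process(logits, topk=1)  # -> (1, array([1]), array([1.]))
#
# Larger topk with topp < 1 gives nucleus sampling over the top-k candidates.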
config: MultiModalityConfig = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(tokenizer_dir)
tokenizer = vl_chat_processor.tokenizer
# question = "请尝试理解这幅图中的内容."  # i.e. "Please try to understand the content of this picture."
question = "Please describe the picture."
conversation = [
{
"role": "User",
"content": f"<image_placeholder>\n{question}",
"images": [test_img_path],
},
{"role": "Assistant", "content": ""},
]
# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
conversations=conversation, images=pil_images, force_batchify=True
)
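# Build the multimodal embedding sequence: token embeddings for the prompt
# with the <image_placeholder> span replaced by ViT features.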
input_embedding = prepare_inputs_embeds(**prepare_inputs)
token_ids = prepare_inputs['input_ids'].squeeze().numpy().tolist()
prefill_data = input_embedding.astype(bfloat16)
token_len = len(token_ids)
lastN = 1023  # KV-cache length per layer; total context window is lastN + 1 = 1024 tokens
cfg = config.language_config
kv_dim = cfg.hidden_size // cfg.num_attention_heads * cfg.num_key_value_heads  # head_dim * num_kv_heads
k_caches = [
np.zeros((1, lastN, kv_dim), dtype=bfloat16)
for _ in range(cfg.num_hidden_layers)
]
v_caches = [
np.zeros((1, lastN, kv_dim), dtype=bfloat16)
for _ in range(cfg.num_hidden_layers)
]
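# One K and one V cache per decoder layer, preallocated for the full context.
# The prefill pass fills the first token_len slots; decode writes one slot per
# generated token.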
prefill_decoder_sessions = []
for i in tqdm(range(cfg.num_hidden_layers), desc="Init InferenceSession"):
    session = InferenceSession(
        f"{axmodel_path}/llama_p640_l{i}_together.axmodel"
    )
    prefill_decoder_sessions.append(session)
post_process_session = InferenceSession(
f"{axmodel_path}/llama_post.axmodel"
)
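# Each per-layer axmodel appears to bundle two graphs selected via shape_group:
# group 1 is the 640-token prefill graph (the "p640" in the file name) and
# group 0 is the single-token decode graph used further below.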
print("model load done!")
"""
prefill
"""
prefill_len = 640
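# The prompt is padded to the fixed 640-token prefill window: positions past
# token_len get zero embeddings, position index 0, and fully masked attention
# rows, so only the first token_len outputs are kept.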
if prefill_len > 0:
indices = np.array(list(range(prefill_len)), np.uint32).reshape(
(1, prefill_len)
)
indices[:, token_len:] = 0
    mask = np.zeros((1, prefill_len, prefill_len)) - 65536  # -65536 acts as -inf in bfloat16
    data = np.zeros((1, prefill_len, cfg.hidden_size)).astype(bfloat16)
    data[:, 0:token_len] = prefill_data
    # Causal mask over the real tokens; padded rows past token_len stay masked.
    for i in range(token_len):
        mask[:, i, : i + 1] = 0
    mask = mask.astype(bfloat16)
for i in range(cfg.num_hidden_layers):
input_feed = {
"K_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
"V_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
"indices": indices,
"input": data,
"mask": mask,
}
        outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=1)  # prefill graph
k_caches[i][:, :token_len, :] = outputs[0][:, :token_len, :]
v_caches[i][:, :token_len, :] = outputs[1][:, :token_len, :]
data[:, :token_len] = outputs[2][:, :token_len, :]
    post_out = post_process_session.run(None, {"input": data[:, token_len - 1, :][None, ...]})[0]
    next_token, possible_tokens, possible_soft = post_process(post_out, topk=1)
    # Decoded (token, probability) pairs, handy when debugging the sampler.
    possible_decoded = [tokenizer.decode([t]) for t in possible_tokens]
    possible_pairs = [str((t, s)) for t, s in zip(possible_decoded, possible_soft)]
    token_ids.append(next_token)
print("prefill done!")
"""
decode
"""
mask = np.zeros((1, 1, lastN + 1), dtype=np.float32).astype(bfloat16)
mask[:, :, :lastN] -= 65536
mask[:, :, :token_len] = 0
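# Decode-mask layout: column lastN (the current token) is always visible,
# columns below token_len cover the prompt, and each step unmasks its own
# cache slot via mask[..., start_indice] = 0 below.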
for start_indice in tqdm(range(lastN + 1), desc="Decoder"):
if prefill_len > 0 and start_indice < token_len:
continue
next_token = token_ids[start_indice]
indices = np.array([start_indice], np.uint32).reshape((1, 1))
data = embeds[next_token, :].reshape((1, 1, cfg.hidden_size)).astype(bfloat16)
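    # Push the new token's embedding through every layer, updating the KV
    # caches in place at position start_indice.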
for i in range(cfg.num_hidden_layers):
input_feed = {
"K_cache": k_caches[i],
"V_cache": v_caches[i],
"indices": indices,
"input": data,
"mask": mask,
}
        outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=0)  # decode graph
k_caches[i][:, start_indice, :] = outputs[0][:, :, :]
v_caches[i][:, start_indice, :] = outputs[1][:, :, :]
data = outputs[2]
mask[..., start_indice] = 0
    # Only sample once the whole prompt has been consumed.
    if start_indice >= token_len - 1:
post_out = post_process_session.run(None, {"input": data})[0]
        next_token, possible_tokens, possible_soft = post_process(post_out)
token_ids.append(next_token)
if next_token == tokenizer.eos_token_id:
print("hit eos!")
break
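# Strip the prompt and print only the generated continuation.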
print("Janus Answers: ", tokenizer.decode(token_ids[token_len:], skip_special_tokens=True))