|
|
|
import argparse
import os

import numpy as np
import onnxruntime
import PIL.Image
import torch
from axengine import InferenceSession
from loguru import logger
from ml_dtypes import bfloat16
from tqdm import tqdm
from transformers import AutoConfig

from janus.models import VLChatProcessor
from janus.models.modeling_vlm import MultiModalityConfig
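
# This script generates an image from a text prompt with Janus-Pro: the LLM
# decoder layers run as per-layer .axmodel files on an Axera NPU through
# axengine, while the gen_aligner / gen_head / post_norm heads and the vision
# decoder run as ONNX models on the CPU. Image tokens are sampled
# autoregressively with classifier-free guidance and finally decoded into a
# 384x384 image.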
|
|
|
|
|
parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument("--tokenizer_dir", type=str, default="Janus-Pro-1B",
                    help="Path to the HuggingFace Janus-Pro model directory (tokenizer and processor config)")
parser.add_argument("--axmodel_path", type=str, default="janus_pro_1B_axmodel",
                    help="Path to the directory of compiled axmodel files for the LLM decoder layers")
args = parser.parse_args()

tokenizer_dir = args.tokenizer_dir
axmodel_path = args.axmodel_path
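
# Example invocation (the script filename below is illustrative; the directory
# names follow the argparse defaults above and may differ in your setup):
#   python generate_image_axmodel.py --tokenizer_dir Janus-Pro-1B --axmodel_path janus_pro_1B_axmodel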
|
|
|
"""ONNX MODEL""" |
|
# gen_vision_model_decode: VQ decoder that turns quantized latents into pixels
gen_vision_model_decode = onnxruntime.InferenceSession("./img_gen_onnx/gen_vision_model_decode_sim.onnx", providers=["CPUExecutionProvider"])
# gen_aligner: maps image-token embeddings into the LLM input embedding space
gen_aligner = onnxruntime.InferenceSession("./img_gen_onnx/gen_aligner.onnx", providers=["CPUExecutionProvider"])
# gen_head: projects the final hidden state to image-token logits
gen_head = onnxruntime.InferenceSession("./img_gen_onnx/post_head.onnx", providers=["CPUExecutionProvider"])
# post_norm: final normalization applied before the generation head
post_norm = onnxruntime.InferenceSession("./img_gen_onnx/post_norm.onnx", providers=["CPUExecutionProvider"])
"""ONNX MODEL"""

"""EMBEDDINGS"""
# embeds: text-token embedding table of the language model
embeds = np.load(f"{axmodel_path}/model.embed_tokens.weight.npy")
# gen_embed: embedding table for generated image tokens
gen_embed = np.load("./embeds/gen_embed.npy")
# codebook_entry_embedding: VQ codebook entries used to rebuild the latent grid
codebook_entry_embedding = torch.load('./embeds/codebook_entry_embedding.pt', map_location=torch.device('cpu'))
# These .npy/.pt files are assumed to have been exported from the original
# Janus-Pro checkpoint during model conversion (the export step is not shown here).
"""EMBEDDINGS"""
|
|
def prefill(
    cfg,
    prefill_decoder_sessions,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 1,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
):
    input_ids = vl_chat_processor.tokenizer.encode(prompt)
    input_ids = torch.LongTensor(input_ids)

    # Build paired conditional / unconditional sequences for classifier-free
    # guidance: even rows keep the full prompt, odd rows pad everything except
    # the first and last tokens (BOS and the image-start tag).
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = embeds[tokens.numpy()]
    batch, token_len, seq_dim = inputs_embeds.shape
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int)
    prefill_len = 640
    token_ids = tokens
|
    # Per-sequence KV caches for every decoder layer, sized for the maximum
    # context length supported by the compiled model.
    lastN = 1023
    kv_dim = cfg.hidden_size // cfg.num_attention_heads * cfg.num_key_value_heads
    batch_k_caches = {}
    batch_v_caches = {}

    for bid in range(batch):
        batch_k_caches[bid] = [
            np.zeros((1, lastN, kv_dim), dtype=bfloat16)
            for _ in range(cfg.num_hidden_layers)
        ]
        batch_v_caches[bid] = [
            np.zeros((1, lastN, kv_dim), dtype=bfloat16)
            for _ in range(cfg.num_hidden_layers)
        ]

    # Causal attention mask over the prefill window: positions beyond the
    # current token stay masked with a large negative value.
    mask = np.zeros((1, prefill_len, prefill_len)) - 65536
    for j in range(token_len):
        mask[:, j, :j + 1] = 0
    mask = mask.astype(bfloat16)

    indices = np.array(list(range(prefill_len)), np.uint32).reshape((1, prefill_len))
    indices[:, token_len:] = 0
    hidden_states = np.zeros((batch, token_len, cfg.hidden_size)).astype(bfloat16)
|
    # Run the prefill graph once per sequence (conditional and unconditional),
    # layer by layer, collecting the per-layer KV caches and hidden states.
    for bid in range(batch):
        data = np.zeros((1, prefill_len, cfg.hidden_size)).astype(bfloat16)
        data[:, 0:token_len] = inputs_embeds[bid].astype(bfloat16)
        k_caches = batch_k_caches[bid]
        v_caches = batch_v_caches[bid]

        for i in range(cfg.num_hidden_layers):
            input_feed = {
                "K_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
                "V_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
                "indices": indices,
                "input": data,
                "mask": mask,
            }
            outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=1)
            k_caches[i][:, :token_len, :] = outputs[0][:, :token_len, :]
            v_caches[i][:, :token_len, :] = outputs[1][:, :token_len, :]
            data[:, :token_len] = outputs[2][:, :token_len, :]

        hidden_states[bid] = data[:, :token_len]
        batch_k_caches[bid] = k_caches
        batch_v_caches[bid] = v_caches
|
    # Normalize the last prompt position and project to image-token logits.
    hidden_states = post_norm.run(["output"], {"input": hidden_states[:, -1:, :].astype(np.float32)})[0]
    logits = gen_head.run(["output"], {"input": hidden_states[:, -1, :]})[0]

    # Classifier-free guidance: blend conditional and unconditional logits,
    # then sample the first image token.
    logits = torch.from_numpy(logits)
    logit_cond = logits[0::2, :]
    logit_uncond = logits[1::2, :]
    logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
    probs = torch.softmax(logits / temperature, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)
    generated_tokens[:, 0] = next_token.squeeze(dim=-1)
    # Duplicate the sampled token so the conditional and unconditional streams
    # receive the same input at the next step.
    next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

    # Look up the image-token embedding and map it into the LLM input space.
    gen_embed_res = np.take(gen_embed, next_token.numpy().tolist(), axis=0)
    img_embeds = gen_aligner.run(["output"], {"input": gen_embed_res})[0]
    inputs_embeds = np.expand_dims(img_embeds, axis=1)
    return inputs_embeds, token_ids, generated_tokens, batch_k_caches, batch_v_caches
|
|
@torch.inference_mode()
def generate(
    cfg,
    prefill_decoder_sessions,
    vl_chat_processor: VLChatProcessor,
    prompt: str,
    temperature: float = 1,
    parallel_size: int = 1,
    cfg_weight: float = 5,
    image_token_num_per_image: int = 576,
    img_size: int = 384,
    patch_size: int = 16,
):
    inputs_embeds, token_ids, generated_tokens, batch_k_caches, batch_v_caches = prefill(
        cfg, prefill_decoder_sessions, vl_chat_processor,
        prompt, temperature, parallel_size, cfg_weight, image_token_num_per_image
    )

    logger.debug("prefill completed!")
    token_len = token_ids.shape[1]

    lastN = 1023
    batch = parallel_size * 2

    # Decode-phase attention mask: only the prompt positions (and, as decoding
    # proceeds, the already generated image-token positions) are visible.
    mask = np.zeros((1, 1, lastN + 1), dtype=np.float32).astype(bfloat16)
    mask[:, :, :lastN] -= 65536
    mask[:, :, :token_len] = 0
|
    # Autoregressively sample the remaining image tokens, one position at a time.
    for image_token_i in tqdm(range(1, image_token_num_per_image), desc="ImageToken"):
        start_index = image_token_i + token_len - 1
        indices = np.array([start_index], np.uint32).reshape((1, 1))
        hidden_states = np.zeros((batch, 1, cfg.hidden_size)).astype(bfloat16)
        # The conditional and unconditional streams share the same image-token
        # embedding, so a single copy of the input can be fed to both.
        assert (inputs_embeds[0] == inputs_embeds[1]).all()

        for bid in range(batch):
            k_caches = batch_k_caches[bid]
            v_caches = batch_v_caches[bid]
            data = inputs_embeds[:1, ...].astype(bfloat16)

            for i in range(cfg.num_hidden_layers):
                input_feed = {
                    "K_cache": k_caches[i],
                    "V_cache": v_caches[i],
                    "indices": indices,
                    "input": data,
                    "mask": mask,
                }
                outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=0)
                k_caches[i][:, start_index, :] = outputs[0][:, :, :]
                v_caches[i][:, start_index, :] = outputs[1][:, :, :]
                data = outputs[2]

            hidden_states[bid] = data
            batch_k_caches[bid] = k_caches
            batch_v_caches[bid] = v_caches

        # Unmask the position that was just written into the KV cache.
        mask[..., start_index] = 0
|
        # Normalize, project to image-token logits, apply classifier-free
        # guidance, and sample the next image token.
        hidden_states = post_norm.run(["output"], {"input": hidden_states.astype(np.float32)})[0]
        logits = gen_head.run(["output"], {"input": hidden_states[:, -1, :]})[0]

        logits = torch.from_numpy(logits)
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, image_token_i] = next_token.squeeze(dim=-1)
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

        gen_embed_res = np.take(gen_embed, next_token.numpy().tolist(), axis=0)
        img_embeds = gen_aligner.run(["output"], {"input": gen_embed_res})[0]
        inputs_embeds = np.expand_dims(img_embeds, axis=1)
|
    # Map the sampled token ids back to VQ codebook entries, reshape them into
    # the latent grid, decode to pixels with the vision decoder, and save.
    indices = generated_tokens.to(dtype=torch.int)
    shape = [parallel_size, 8, img_size // patch_size, img_size // patch_size]
    z_q = codebook_entry_embedding[indices]
    z_q = z_q.reshape(shape[0], shape[2], shape[3], shape[1])
    z_q = z_q.permute(0, 3, 1, 2)

    dec = gen_vision_model_decode.run(['image'], {'quant': z_q.to(dtype=torch.float32).numpy()})[0]
    dec = dec.transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    os.makedirs('generated_samples', exist_ok=True)
    for i in range(parallel_size):
        save_path = os.path.join('generated_samples', "img_{}.jpg".format(i))
        PIL.Image.fromarray(visual_img[i]).save(save_path)
|
|
# --- Script entry: load the processor and config, build the generation prompt,
# initialize the per-layer axmodel sessions, and run image generation. ---
config: MultiModalityConfig = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(tokenizer_dir)
tokenizer = vl_chat_processor.tokenizer
|
|
|
description = "A close-up high-contrast photo of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of roiling energy, exploding yellow stars, and radiating swirls of blue." |
|
|
|
conversation = [
    {
        "role": "User",
        "content": description,
    },
    {"role": "Assistant", "content": ""},
]

# Build the SFT-style prompt and append the image-start tag so that the model
# begins emitting image tokens.
sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
    conversations=conversation,
    sft_format=vl_chat_processor.sft_format,
    system_prompt="",
)
prompt = sft_format + vl_chat_processor.image_start_tag
|
cfg = config.language_config

# One compiled axmodel per decoder layer; each session serves both the prefill
# graph (shape_group=1) and the single-token decode graph (shape_group=0).
prefill_decoder_sessions = []
for i in tqdm(range(cfg.num_hidden_layers), desc="Init InferenceSession"):
    session = InferenceSession(
        f"{axmodel_path}/llama_p640_l{i}_together.axmodel"
    )
    prefill_decoder_sessions.append(session)

logger.info("model load done!")
|
generate(
    cfg,
    prefill_decoder_sessions,
    vl_chat_processor,
    prompt,
)
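
# The generated image(s) are written to ./generated_samples/img_<i>.jpg.
# The guidance scale can also be passed explicitly; the value below is purely
# illustrative, not a recommendation from the original script:
#   generate(cfg, prefill_decoder_sessions, vl_chat_processor, prompt, cfg_weight=7.5)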
|
|