# REF: https://github.com/deepseek-ai/Janus
import argparse
import os

import numpy as np
import torch
from axengine import InferenceSession
from einops import rearrange
from ml_dtypes import bfloat16
from tqdm import tqdm
from transformers import AutoConfig

from janus.models import VLChatProcessor
from janus.models.modeling_vlm import MultiModalityConfig
from janus.utils.io import load_pil_images
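# Example invocation (script name hypothetical; paths are the argparse
# defaults below):
#   python3 infer_axmodel.py \
#       --tokenizer_dir Janus-Pro-1B \
#       --axmodel_path janus_pro_1B_axmodel \
#       -i ./imgs/image.png \
#       --vit_axmodel_path vit_axmodel/janus_warp_vit.axmodel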


parser = argparse.ArgumentParser(description="Model configuration parameters")
parser.add_argument("--tokenizer_dir", type=str, default="Janus-Pro-1B",
                    help="Path to the HuggingFace model directory (tokenizer and config)")
parser.add_argument("--axmodel_path", type=str, default="janus_pro_1B_axmodel",
                    help="Path to the directory holding the compiled LLM axmodel files")
parser.add_argument("-i", "--test_img_path", type=str, default="./imgs/image.png",
                    help="Test image path (supports png/jpg formats)")
parser.add_argument("--vit_axmodel_path", type=str, default="vit_axmodel/janus_warp_vit.axmodel",
                    help="Path to the ViT encoder's axmodel")

args = parser.parse_args()

# base info
tokenizer_dir = args.tokenizer_dir
axmodel_path = args.axmodel_path
test_img_path = args.test_img_path
vit_axmodel_path = args.vit_axmodel_path
embeds = np.load(os.path.join(args.axmodel_path, "model.embed_tokens.weight.npy"))
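# The token-embedding table is loaded on the host as a plain .npy matrix so
# prompt and decode-step embeddings can be gathered with np.take / fancy
# indexing instead of running an embedding layer on the NPU.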


def prepare_inputs_embeds(
    input_ids: torch.LongTensor,
    pixel_values: torch.FloatTensor,
    images_seq_mask: torch.BoolTensor,
    images_emb_mask: torch.BoolTensor,
    **kwargs,
):
    """
    Merge text-token embeddings and ViT image embeddings into one sequence.

    Args:
        input_ids (torch.LongTensor): [b, T]
        pixel_values (torch.FloatTensor):   [b, n_images, 3, h, w]
        images_seq_mask (torch.BoolTensor): [b, T]
        images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]

        assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)

    Returns:
        input_embeds (np.ndarray): [b, T, D]
    """

    bs, n = pixel_values.shape[0:2]

    # Run the ViT encoder on the NPU. This assumes a single sample with a
    # single image, i.e. pixel_values is [1, 1, 3, 384, 384].
    vit_session = InferenceSession(vit_axmodel_path)
    images_embeds = vit_session.run(None, {"image": pixel_values[0].numpy()})[0]
    print(f"vit_output.shape is {images_embeds.shape}, ViT feature extraction done!")

    # [b x n, T2, D] -> [b, n x T2, D]
    images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
    # [b, n, T2] -> [b, n x T2]
    images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")

    # Negative ids mark image-placeholder positions; clamp them to 0 so the
    # np.take gather below stays in range (they are overwritten right after).
    input_ids[input_ids < 0] = 0
    inputs_embeds = np.take(embeds, input_ids[0].cpu().numpy().tolist(), axis=0)[None, ...]
    # Splice the ViT embeddings into the placeholder slots (force boolean
    # masks so numpy does mask indexing, not integer indexing).
    inputs_embeds[images_seq_mask.numpy().astype(bool)] = images_embeds[images_emb_mask.numpy().astype(bool)]

    return inputs_embeds

def post_process(data, topk=1, topp=0.9, temperature=0.6):
    """Sample the next token from the lm_head logits (top-k, then top-p)."""

    def top_p(probs: np.ndarray, p: float) -> np.ndarray:
        # Keep the highest-probability entries until their cumulative mass
        # reaches p, zero the rest, then renormalize over what was kept.
        index = np.argsort(probs)
        res = probs.copy()
        sum_p = 0
        for i in index[::-1]:
            if sum_p >= p:
                res[i] = 0
            sum_p += res[i]
        return res / sum_p

    def softmax(l: np.ndarray) -> np.ndarray:
        # Shift by the max for numerical stability.
        l_exp = np.exp(l - l.max())
        return (l_exp / np.sum(l_exp)).astype(np.float64)

    r = data.astype(np.float32).flatten()
    # Top-k filter on the raw logits, then temperature, softmax, top-p.
    candidate_index = np.argpartition(r, -topk)[-topk:]
    candidate_value = r[candidate_index] / temperature
    candidate_soft = softmax(candidate_value)
    candidate_soft = top_p(candidate_soft, topp)
    candidate_soft = candidate_soft.astype(np.float64) / candidate_soft.sum()
    pos = np.random.multinomial(1, candidate_soft).argmax()
    next_token = candidate_index[pos]
    return next_token, candidate_index, candidate_soft
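# NOTE: with the default topk=1 the candidate set has a single entry, so the
# multinomial draw is deterministic and decoding is effectively greedy;
# a larger topk (e.g. 40) is what makes temperature and top-p meaningful.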

config: MultiModalityConfig = AutoConfig.from_pretrained(tokenizer_dir, trust_remote_code=True)
vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(tokenizer_dir)
tokenizer = vl_chat_processor.tokenizer

# Alternative prompt in Chinese: "请尝试理解这幅图中的内容."
# ("Please try to understand the content of this picture.")
question = "Please describe the picture."
conversation = [
    {
        "role": "User",
        "content": f"<image_placeholder>\n{question}",
        "images": [test_img_path],
    },
    {"role": "Assistant", "content": ""},
]

# load images and prepare for inputs
pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
)
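# prepare_inputs carries input_ids, pixel_values, images_seq_mask, and
# images_emb_mask -- exactly the keyword arguments prepare_inputs_embeds
# consumes above.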

input_embedding = prepare_inputs_embeds(**prepare_inputs)
token_ids = prepare_inputs["input_ids"].squeeze().numpy().tolist()
prefill_data = input_embedding.astype(bfloat16)
token_len = len(token_ids)

lastN = 1023  # maximum cached sequence length (prompt + generated tokens)
cfg = config.language_config

# Preallocate per-layer KV caches on the host. With grouped-query attention
# the cache width is head_dim * num_key_value_heads.
kv_dim = cfg.hidden_size // cfg.num_attention_heads * cfg.num_key_value_heads
k_caches = [
    np.zeros((1, lastN, kv_dim), dtype=bfloat16)
    for _ in range(cfg.num_hidden_layers)
]
v_caches = [
    np.zeros((1, lastN, kv_dim), dtype=bfloat16)
    for _ in range(cfg.num_hidden_layers)
]

prefill_decoder_sessions = []
for i in tqdm(range(cfg.num_hidden_layers), desc="Init InferenceSession"):
    session = InferenceSession(
        f"{axmodel_path}/llama_p640_l{i}_together.axmodel"
    )
    prefill_decoder_sessions.append(session)
post_process_session = InferenceSession(
    f"{axmodel_path}/llama_post.axmodel"
)
print("model load done!")

"""
    prefill
"""
prefill_len = 640

if prefill_len > 0:
    assert token_len <= prefill_len, "prompt longer than the compiled prefill window"
    # Position indices for the prefill window; positions past the real
    # prompt are pinned to 0 (they are masked out anyway).
    indices = np.arange(prefill_len, dtype=np.uint32).reshape((1, prefill_len))
    indices[:, token_len:] = 0
    # Additive causal attention mask: -65536 acts as -inf in bfloat16.
    mask = np.zeros((1, prefill_len, prefill_len)) - 65536
    data = np.zeros((1, prefill_len, cfg.hidden_size)).astype(bfloat16)
    data[:, 0:token_len] = prefill_data
    for i in range(token_len):
        mask[:, i, : i + 1] = 0
    mask = mask.astype(bfloat16)
    for i in range(cfg.num_hidden_layers):
        input_feed = {
            # Dummy KV inputs: the prefill graph (shape_group=1) does not
            # read past keys/values.
            "K_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
            "V_cache": np.zeros((1, 1, cfg.hidden_size), dtype=bfloat16),
            "indices": indices,
            "input": data,
            "mask": mask,
        }
        outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=1)
        # Stash this layer's K/V for the decode phase and feed its hidden
        # states to the next layer.
        k_caches[i][:, :token_len, :] = outputs[0][:, :token_len, :]
        v_caches[i][:, :token_len, :] = outputs[1][:, :token_len, :]
        data[:, :token_len] = outputs[2][:, :token_len, :]

# Sample the first generated token from the last prompt position.
post_out = post_process_session.run(None, {"input": data[:, token_len - 1, :][None, ...]})[0]
next_token, possible_tokens, possible_soft = post_process(post_out, topk=1)
# Decoded candidates, useful when debugging the sampler.
possible_strs = [tokenizer.decode([t]) for t in possible_tokens]
possible_pairs = [str((t, s)) for t, s in zip(possible_strs, possible_soft)]
token_ids.append(next_token)
print("prefill done!")

"""
    decode
"""
# Decode-phase attention mask over the lastN cache slots plus the current
# token: unwritten cache positions stay at -65536 (-inf in bfloat16) and are
# opened up one by one as the cache fills.
mask = np.zeros((1, 1, lastN + 1), dtype=np.float32).astype(bfloat16)
mask[:, :, :lastN] -= 65536
mask[:, :, :token_len] = 0
for start_index in tqdm(range(lastN + 1), desc="Decoder"):
    # Positions already covered by the prefill are skipped.
    if prefill_len > 0 and start_index < token_len:
        continue
    next_token = token_ids[start_index]
    indices = np.array([start_index], np.uint32).reshape((1, 1))
    data = embeds[next_token, :].reshape((1, 1, cfg.hidden_size)).astype(bfloat16)

    for i in range(cfg.num_hidden_layers):
        input_feed = {
            "K_cache": k_caches[i],
            "V_cache": v_caches[i],
            "indices": indices,
            "input": data,
            "mask": mask,
        }
        outputs = prefill_decoder_sessions[i].run(None, input_feed, shape_group=0)
        # Write this step's K/V into the cache slot for start_index.
        k_caches[i][:, start_index, :] = outputs[0][:, :, :]
        v_caches[i][:, start_index, :] = outputs[1][:, :, :]
        data = outputs[2]

    mask[..., start_index] = 0
    if start_index >= token_len - 1:
        # Past the prompt: sample the next token.
        post_out = post_process_session.run(None, {"input": data})[0]
        next_token, possible_tokens, possible_soft = post_process(post_out)
        token_ids.append(next_token)
    if next_token == tokenizer.eos_token_id:
        print("hit eos!")
        break
print("Janus Answers: ", tokenizer.decode(token_ids[token_len:], skip_special_tokens=True))