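"""Tokenization classes for kanana-1.5-v.

Defines slow and fast tokenizer wrappers that split prompts on the
"AI: " / "Human: " role markers, expand each <image> placeholder into a grid
of visual-token placeholders, and pad/collate encoded samples into batches
for the multimodal model.
"""
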
from collections import defaultdict
import logging
import re

import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Role tokens
AI = "AI: "
HUMAN = "Human: "
_AI = "\n" + AI
_HUMAN = "\n" + HUMAN

# Special media tokens
IMAGE = "<image>"
IMAGE_ROW_SEPARATOR = "\n"
IMAGE_GLOBAL_LOCAL_SEPARATOR = "\n"
MEDIA_TOKENS = {
    "image": [IMAGE],
}

_INFINITE = int(1e12)  # infinite token length for no-truncation

logger = logging.getLogger("kanana-1.5-v")


class AttrDict(dict):
    __slots__ = ()

    def __getattr__(self, name):
        try:
            val = self[name]
        except KeyError:
            raise AttributeError(name) from None

        if isinstance(val, dict) and not isinstance(val, AttrDict):
            val = AttrDict(val)
            self[name] = val
        return val

    def __setattr__(self, name, value):
        if name.startswith("_"):
            return super().__setattr__(name, value)
        if isinstance(value, dict) and not isinstance(value, AttrDict):
            value = AttrDict(value)
        self[name] = value

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


def to_attrdict(obj):
    if isinstance(obj, dict) and not isinstance(obj, AttrDict):
        return AttrDict({k: to_attrdict(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [to_attrdict(x) for x in obj]
    if isinstance(obj, tuple):
        return tuple(to_attrdict(x) for x in obj)
    return obj


def _pad_trunc(
    x: list[list[int]],
    padding: str,
    padding_side: str,
    pad_value: int,
    max_length: int,
) -> torch.LongTensor:
    """Pad and truncate sequences to the same length

    Args:
        x (list[list[int]])
        padding ("longest" or "max_length")
        padding_side ("left" or "right")
        pad_value (int)
        max_length (int or None): if padding == "max_length", max_length should be given.
    """
    assert padding in ["longest", "max_length"]
    assert padding_side in ["left", "right"]

    lengths = [len(sample) for sample in x]
    if padding == "longest":
        max_length = max(lengths)

    new_x = []
    for sample, length in zip(x, lengths):
        if torch.is_tensor(sample):
            sample = sample.tolist()

        if length >= max_length:
            new_x.append(sample[:max_length])
            continue

        padding_size = max_length - length
        pads = [pad_value] * padding_size
        if padding_side == "right":
            new_x.append(sample + pads)
        else:
            new_x.append(pads + sample)

    return torch.as_tensor(new_x, dtype=torch.long)
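
# Illustrative behavior of _pad_trunc (example values only): right-padding two
# sequences to the longest length with pad_value=0:
#   _pad_trunc([[1, 2, 3], [4]], "longest", "right", 0, 16)
#   -> tensor([[1, 2, 3],
#              [4, 0, 0]])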


class KananaVTokenizerMixin:
    def mllm_setup(self, num_visual_tokens: int):
        self.num_visual_tokens = num_visual_tokens

        # Currently, only the image modality is supported among media modalities.
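        # Each media token is mapped to a negative pseudo-id so it can never
        # collide with a real (non-negative) vocabulary id.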
        self.media_tokens = {k: -int(i + 1) for i, k in enumerate(MEDIA_TOKENS["image"])}
        self.media_lengths = {MEDIA_TOKENS["image"][0]: num_visual_tokens}

    def repeat_image_tokens(
        self, hw_tokens, with_row_separator=True, add_global_local_separator=False
    ):
        if len(hw_tokens) == 3:
            T, H, W = hw_tokens
        else:
            H, W = hw_tokens

        repeated_tokens = []

        if add_global_local_separator:
            global_local_separator = self(IMAGE_GLOBAL_LOCAL_SEPARATOR, add_special_tokens=False)[
                "input_ids"
            ]

            repeated_tokens += global_local_separator

        if with_row_separator:
            row_sep = self(IMAGE_ROW_SEPARATOR, add_special_tokens=False)["input_ids"]

        for h_idx in range(H):
            repeated_tokens += [self.media_tokens[IMAGE]] * W
            if with_row_separator and h_idx != H - 1:
                repeated_tokens += row_sep

        return repeated_tokens
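
    # Illustrative expansion (example values only): for hw_tokens=(2, 3) with
    # with_row_separator=True and add_global_local_separator=False, the result
    # is three <image> pseudo-ids, the token ids of "\n", then three more
    # <image> pseudo-ids.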

    def encode_prompt(
        self, prompt: str, max_length: int | None = None, image_meta: dict | None = None
    ) -> dict:
        """Tokenize prompt which consists of image-text or text only, with role tokens.
        Role pattern is "AI: " or "Human: ".

        Args:
            prompt
            max_length (int or None): here, max_length is used for truncation.
                If max_length is None, no truncation is applied.
        """
        max_length = max_length or _INFINITE  # if None, set to infinite for no-truncation

        # output enc_chunk
        enc_chunk = []

        # Text-only or Image-Text Data
        # split prompt into chunks by media and role tokens
        tokens_to_split = list(self.media_tokens.keys()) + [_AI, _HUMAN]
        pattern = "|".join(map(re.escape, tokens_to_split))
        chunk_strs = re.split(f"({pattern})", prompt)
        chunk_strs = [x for x in chunk_strs if len(x) > 0]
        # tokenize chunks
        img_idx = 0  # for sync with image_meta
        for idx, chunk_str in enumerate(chunk_strs):
            if chunk_str in self.media_tokens:
                if chunk_str == IMAGE:
                    image_token_thw = (
                        image_meta["image_token_thw"][img_idx]
                        if image_meta.get("image_token_thw")
                        else None
                    )

                    media_tokens = self.repeat_image_tokens(
                        image_token_thw,
                        with_row_separator=True,
                        add_global_local_separator=True,
                    )
                    # increment image index
                    img_idx += 1

                else:
                    raise ValueError("Unknown chunk str", chunk_str)

                enc_chunk += media_tokens

            else:
                curr_chunk = self(chunk_str, add_special_tokens=False)["input_ids"]
                enc_chunk += curr_chunk

        L = len(enc_chunk)

        input_ids = torch.as_tensor(enc_chunk, dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)

        assert L <= max_length, (
            f"[Length exceeded] Input sequence length ({L}) is greater than "
            f"the allowed max_length ({max_length}). "
            "Please truncate the sequence or increase max_length."
        )

        return {
            "input_ids": input_ids,  # [L]
            "seq_length": L,  # int
            "attention_mask": attention_mask,  # [L]
        }
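
    # Illustrative call (example prompt and metadata only): a prompt such as
    #   "Human: <image> What is shown here?\nAI: "
    # is split on the media/role markers; the <image> chunk is expanded via
    # repeat_image_tokens using image_meta["image_token_thw"][0], and the text
    # chunks are tokenized with add_special_tokens=False.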

    def batch_collate_pad(
        self,
        batch: list,
        padding: str,
        padding_side: str,
        max_length: int | None,
    ) -> dict[str, torch.LongTensor]:
        """Collate batch and pad/truncate to the same length

        Args:
            batch (list of dict): outputs of encode_prompt.
            padding ("longest" or "max_length")
            padding_side ("left" or "right")
            max_length (int or None): if padding == "max_length", max_length must be given.
        """
        if padding == "max_length":
            assert max_length is not None, "max_length should be given if padding == 'max_length'"
        else:
            # if padding == 'longest' and max_length is None, set to infinite for no-truncation
            max_length = max_length or _INFINITE

        input_ids = [sample["input_ids"] for sample in batch]
        attention_mask = [sample["attention_mask"] for sample in batch]
        seq_length = [sample["seq_length"] for sample in batch]

        input_ids = _pad_trunc(input_ids, padding, padding_side, self.pad_token_id, max_length)
        attention_mask = _pad_trunc(attention_mask, padding, padding_side, 0, max_length)
        seq_length = torch.as_tensor(seq_length, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "seq_length": seq_length,
        }
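
    # Illustrative collation (example settings only): given two encode_prompt
    # outputs, padding="longest" with padding_side="left" left-pads the shorter
    # sample's input_ids with self.pad_token_id and its attention_mask with 0.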

    def get_chat_template(self) -> str:
        """Method for bw-compat: old HF transformers (e.g., 4.41.0) does not have get_chat_template
        """
        return self.chat_template


class KananaVTokenizer(PreTrainedTokenizer, KananaVTokenizerMixin):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, text, *args, **kwargs):
        assert isinstance(text, str), "Only str is supported for tokenization."

        # split prompt into chunks by role tokens: text (str) -> chunk_strs (list)
        tokens_to_split = [_AI, _HUMAN]
        pattern = "|".join(map(re.escape, tokens_to_split))
        if re.search(pattern, text):
            chunk_strs = re.split(f"({pattern})", text)
            chunk_strs = [x for x in chunk_strs if len(x) > 0]

            # encode chunk strs
            kwargs["add_special_tokens"] = False
            encodings = defaultdict(list)
            for chunk_str in chunk_strs:
                encoding = super().__call__(chunk_str, *args, **kwargs)
                for k, v in encoding.items():
                    encodings[k].extend(v)
            encodings = to_attrdict(encodings)
            return encodings
        else:
            return super().__call__(text, *args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.__call__(*args, **kwargs)["input_ids"]


class KananaVTokenizerFast(PreTrainedTokenizerFast, KananaVTokenizerMixin):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, text, *args, **kwargs):
        assert isinstance(text, str), "Only str is supported for fast tokenization."

        # split prompt into chunks by role tokens: text (str) -> chunk_strs (list)
        tokens_to_split = [_AI, _HUMAN]
        pattern = "|".join(map(re.escape, tokens_to_split))
        if re.search(pattern, text):
            chunk_strs = re.split(f"({pattern})", text)
            chunk_strs = [x for x in chunk_strs if len(x) > 0]

            # encode chunk strs
            kwargs["add_special_tokens"] = False
            encodings = defaultdict(list)
            for chunk_str in chunk_strs:
                encoding = super().__call__(chunk_str, *args, **kwargs)
                for k, v in encoding.items():
                    encodings[k].extend(v)
            encodings = to_attrdict(encodings)
            return encodings
        else:
            return super().__call__(text, *args, **kwargs)

    def encode(self, *args, **kwargs):
        return self.__call__(*args, **kwargs)["input_ids"]
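

# Illustrative end-to-end usage (the checkpoint path, visual token count, and
# image metadata below are placeholders, not values shipped with this file):
#   tokenizer = KananaVTokenizerFast.from_pretrained("path/to/kanana-1.5-v")
#   tokenizer.mllm_setup(num_visual_tokens=256)
#   enc = tokenizer.encode_prompt(
#       "Human: <image> Describe the image.\nAI: ",
#       image_meta={"image_token_thw": [(2, 3)]},
#   )
#   batch = tokenizer.batch_collate_pad([enc], padding="longest",
#                                       padding_side="left", max_length=None)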