Adapters
GhostNetworkUser committed on
Commit 6b02503 · verified · 1 Parent(s): 28365d2

Upload 73 files

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. .gitattributes +1 -0
  2. __init__.cpython-312.pyc +0 -0
  3. __init__.py +7 -9
  4. __init__.pyi +591 -0
  5. activations.py +239 -0
  6. activations_tf.py +147 -0
  7. audio_utils.py +1123 -0
  8. base_tokenizer.cpython-312.pyc +0 -0
  9. base_tokenizer.py +418 -0
  10. bert_wordpiece.cpython-312.pyc +0 -0
  11. bert_wordpiece.py +151 -0
  12. byte_level_bpe.cpython-312.pyc +0 -0
  13. byte_level_bpe.py +122 -0
  14. cache_utils.py +0 -0
  15. char_level_bpe.cpython-312.pyc +0 -0
  16. char_level_bpe.py +150 -0
  17. configuration_utils.py +1187 -0
  18. convert_graph_to_onnx.py +551 -0
  19. convert_pytorch_checkpoint_to_tf2.py +446 -0
  20. convert_slow_tokenizer.py +1642 -0
  21. convert_slow_tokenizers_checkpoints_to_fast.py +130 -0
  22. convert_tf_hub_seq_to_seq_bert_to_pytorch.py +87 -0
  23. debug_utils.py +346 -0
  24. dependency_versions_check.py +63 -0
  25. dependency_versions_table.py +102 -0
  26. dynamic_module_utils.py +685 -0
  27. feature_extraction_sequence_utils.py +372 -0
  28. feature_extraction_utils.py +702 -0
  29. file_utils.py +133 -0
  30. hf_argparser.py +437 -0
  31. hyperparameter_search.py +141 -0
  32. image_processing_base.py +559 -0
  33. image_processing_utils.py +287 -0
  34. image_processing_utils_fast.py +133 -0
  35. image_transforms.py +860 -0
  36. image_utils.py +871 -0
  37. keras_callbacks.py +413 -0
  38. modelcard.py +908 -0
  39. modeling_attn_mask_utils.py +481 -0
  40. modeling_flash_attention_utils.py +389 -0
  41. modeling_flax_outputs.py +700 -0
  42. modeling_flax_pytorch_utils.py +492 -0
  43. modeling_flax_utils.py +1290 -0
  44. modeling_gguf_pytorch_utils.py +471 -0
  45. modeling_outputs.py +0 -0
  46. modeling_rope_utils.py +568 -0
  47. modeling_tf_outputs.py +991 -0
  48. modeling_tf_pytorch_utils.py +673 -0
  49. modeling_tf_utils.py +0 -0
  50. modeling_utils.py +0 -0
.gitattributes CHANGED
@@ -59,3 +59,4 @@ torchrun.exe filter=lfs diff=lfs merge=lfs -text
  tqdm.exe filter=lfs diff=lfs merge=lfs -text
  transformers-cli.exe filter=lfs diff=lfs merge=lfs -text
  wheel.exe filter=lfs diff=lfs merge=lfs -text
+ tokenizers.pyd filter=lfs diff=lfs merge=lfs -text
__init__.cpython-312.pyc ADDED
Binary file (458 Bytes).
 
__init__.py CHANGED
@@ -1,10 +1,8 @@
- """torchgen
-
- This module contains codegeneration utilities for PyTorch. It is used to
- build PyTorch from source, but may also be used for out-of-tree projects
- that extend PyTorch.
-
- Note well that we provide no BC guarantees for torchgen. If you're interested
- in using torchgen and want the PyTorch team to be aware, please reach out
- on GitHub.
- """
+ # Generated content DO NOT EDIT
+ from .. import models
+
+ Model = models.Model
+ BPE = models.BPE
+ Unigram = models.Unigram
+ WordLevel = models.WordLevel
+ WordPiece = models.WordPiece
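For context, this generated `__init__.py` simply re-exports the Rust-backed model classes from the parent `tokenizers` package. A minimal usage sketch, assuming the `tokenizers` package is installed (`corpus.txt` is a hypothetical file name):

    from tokenizers import Tokenizer
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import Whitespace
    from tokenizers.trainers import BpeTrainer

    # Wrap an (initially empty) BPE model in a Tokenizer and train it on plain text.
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.train(["corpus.txt"], BpeTrainer(special_tokens=["[UNK]"]))

    print(tokenizer.encode("Hello, world!").tokens)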
 
 
 
__init__.pyi ADDED
@@ -0,0 +1,591 @@
1
+ # Generated content DO NOT EDIT
2
+ class Model:
3
+ """
4
+ Base class for all models
5
+
6
+ The model represents the actual tokenization algorithm. This is the part that
7
+ will contain and manage the learned vocabulary.
8
+
9
+ This class cannot be constructed directly. Please use one of the concrete models.
10
+ """
11
+ def get_trainer(self):
12
+ """
13
+ Get the associated :class:`~tokenizers.trainers.Trainer`
14
+
15
+ Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this
16
+ :class:`~tokenizers.models.Model`.
17
+
18
+ Returns:
19
+ :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
20
+ """
21
+ pass
22
+
23
+ def id_to_token(self, id):
24
+ """
25
+ Get the token associated to an ID
26
+
27
+ Args:
28
+ id (:obj:`int`):
29
+ An ID to convert to a token
30
+
31
+ Returns:
32
+ :obj:`str`: The token associated to the ID
33
+ """
34
+ pass
35
+
36
+ def save(self, folder, prefix):
37
+ """
38
+ Save the current model
39
+
40
+ Save the current model in the given folder, using the given prefix for the various
41
+ files that will get created.
42
+ Any file with the same name that already exists in this folder will be overwritten.
43
+
44
+ Args:
45
+ folder (:obj:`str`):
46
+ The path to the target folder in which to save the various files
47
+
48
+ prefix (:obj:`str`, `optional`):
49
+ An optional prefix, used to prefix each file name
50
+
51
+ Returns:
52
+ :obj:`List[str]`: The list of saved files
53
+ """
54
+ pass
55
+
56
+ def token_to_id(self, tokens):
57
+ """
58
+ Get the ID associated to a token
59
+
60
+ Args:
61
+ token (:obj:`str`):
62
+ A token to convert to an ID
63
+
64
+ Returns:
65
+ :obj:`int`: The ID associated to the token
66
+ """
67
+ pass
68
+
69
+ def tokenize(self, sequence):
70
+ """
71
+ Tokenize a sequence
72
+
73
+ Args:
74
+ sequence (:obj:`str`):
75
+ A sequence to tokenize
76
+
77
+ Returns:
78
+ A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens
79
+ """
80
+ pass
81
+
82
+ class BPE(Model):
83
+ """
84
+ An implementation of the BPE (Byte-Pair Encoding) algorithm
85
+
86
+ Args:
87
+ vocab (:obj:`Dict[str, int]`, `optional`):
88
+ A dictionary of string keys and their ids :obj:`{"am": 0,...}`
89
+
90
+ merges (:obj:`List[Tuple[str, str]]`, `optional`):
91
+ A list of pairs of tokens (:obj:`Tuple[str, str]`) :obj:`[("a", "b"),...]`
92
+
93
+ cache_capacity (:obj:`int`, `optional`):
94
+ The number of words that the BPE cache can contain. The cache allows
95
+ to speed-up the process by keeping the result of the merge operations
96
+ for a number of words.
97
+
98
+ dropout (:obj:`float`, `optional`):
99
+ A float between 0 and 1 that represents the BPE dropout to use.
100
+
101
+ unk_token (:obj:`str`, `optional`):
102
+ The unknown token to be used by the model.
103
+
104
+ continuing_subword_prefix (:obj:`str`, `optional`):
105
+ The prefix to attach to subword units that don't represent a beginning of word.
106
+
107
+ end_of_word_suffix (:obj:`str`, `optional`):
108
+ The suffix to attach to subword units that represent an end of word.
109
+
110
+ fuse_unk (:obj:`bool`, `optional`):
111
+ Whether to fuse any subsequent unknown tokens into a single one
112
+
113
+ byte_fallback (:obj:`bool`, `optional`):
114
+ Whether to use spm byte-fallback trick (defaults to False)
115
+
116
+ ignore_merges (:obj:`bool`, `optional`):
117
+ Whether or not to match tokens with the vocab before using merges.
118
+ """
119
+ def __init__(
120
+ self,
121
+ vocab=None,
122
+ merges=None,
123
+ cache_capacity=None,
124
+ dropout=None,
125
+ unk_token=None,
126
+ continuing_subword_prefix=None,
127
+ end_of_word_suffix=None,
128
+ fuse_unk=None,
129
+ byte_fallback=False,
130
+ ignore_merges=False,
131
+ ):
132
+ pass
133
+
134
+ @staticmethod
135
+ def from_file(cls, vocab, merge, **kwargs):
136
+ """
137
+ Instantiate a BPE model from the given files.
138
+
139
+ This method is roughly equivalent to doing::
140
+
141
+ vocab, merges = BPE.read_file(vocab_filename, merges_filename)
142
+ bpe = BPE(vocab, merges)
143
+
144
+ If you don't need to keep the :obj:`vocab, merges` values lying around,
145
+ this method is more optimized than manually calling
146
+ :meth:`~tokenizers.models.BPE.read_file` to initialize a :class:`~tokenizers.models.BPE`
147
+
148
+ Args:
149
+ vocab (:obj:`str`):
150
+ The path to a :obj:`vocab.json` file
151
+
152
+ merges (:obj:`str`):
153
+ The path to a :obj:`merges.txt` file
154
+
155
+ Returns:
156
+ :class:`~tokenizers.models.BPE`: An instance of BPE loaded from these files
157
+ """
158
+ pass
159
+
160
+ def get_trainer(self):
161
+ """
162
+ Get the associated :class:`~tokenizers.trainers.Trainer`
163
+
164
+ Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this
165
+ :class:`~tokenizers.models.Model`.
166
+
167
+ Returns:
168
+ :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
169
+ """
170
+ pass
171
+
172
+ def id_to_token(self, id):
173
+ """
174
+ Get the token associated to an ID
175
+
176
+ Args:
177
+ id (:obj:`int`):
178
+ An ID to convert to a token
179
+
180
+ Returns:
181
+ :obj:`str`: The token associated to the ID
182
+ """
183
+ pass
184
+
185
+ @staticmethod
186
+ def read_file(self, vocab, merges):
187
+ """
188
+ Read a :obj:`vocab.json` and a :obj:`merges.txt` files
189
+
190
+ This method provides a way to read and parse the content of these files,
191
+ returning the relevant data structures. If you want to instantiate some BPE models
192
+ from memory, this method gives you the expected input from the standard files.
193
+
194
+ Args:
195
+ vocab (:obj:`str`):
196
+ The path to a :obj:`vocab.json` file
197
+
198
+ merges (:obj:`str`):
199
+ The path to a :obj:`merges.txt` file
200
+
201
+ Returns:
202
+ A :obj:`Tuple` with the vocab and the merges:
203
+ The vocabulary and merges loaded into memory
204
+ """
205
+ pass
206
+
207
+ def save(self, folder, prefix):
208
+ """
209
+ Save the current model
210
+
211
+ Save the current model in the given folder, using the given prefix for the various
212
+ files that will get created.
213
+ Any file with the same name that already exists in this folder will be overwritten.
214
+
215
+ Args:
216
+ folder (:obj:`str`):
217
+ The path to the target folder in which to save the various files
218
+
219
+ prefix (:obj:`str`, `optional`):
220
+ An optional prefix, used to prefix each file name
221
+
222
+ Returns:
223
+ :obj:`List[str]`: The list of saved files
224
+ """
225
+ pass
226
+
227
+ def token_to_id(self, tokens):
228
+ """
229
+ Get the ID associated to a token
230
+
231
+ Args:
232
+ token (:obj:`str`):
233
+ A token to convert to an ID
234
+
235
+ Returns:
236
+ :obj:`int`: The ID associated to the token
237
+ """
238
+ pass
239
+
240
+ def tokenize(self, sequence):
241
+ """
242
+ Tokenize a sequence
243
+
244
+ Args:
245
+ sequence (:obj:`str`):
246
+ A sequence to tokenize
247
+
248
+ Returns:
249
+ A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens
250
+ """
251
+ pass
252
+
253
+ class Unigram(Model):
254
+ """
255
+ An implementation of the Unigram algorithm
256
+
257
+ Args:
258
+ vocab (:obj:`List[Tuple[str, float]]`, `optional`):
259
+ A list of vocabulary items and their relative score [("am", -0.2442),...]
260
+ """
261
+ def __init__(self, vocab, unk_id, byte_fallback):
262
+ pass
263
+
264
+ def get_trainer(self):
265
+ """
266
+ Get the associated :class:`~tokenizers.trainers.Trainer`
267
+
268
+ Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this
269
+ :class:`~tokenizers.models.Model`.
270
+
271
+ Returns:
272
+ :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
273
+ """
274
+ pass
275
+
276
+ def id_to_token(self, id):
277
+ """
278
+ Get the token associated to an ID
279
+
280
+ Args:
281
+ id (:obj:`int`):
282
+ An ID to convert to a token
283
+
284
+ Returns:
285
+ :obj:`str`: The token associated to the ID
286
+ """
287
+ pass
288
+
289
+ def save(self, folder, prefix):
290
+ """
291
+ Save the current model
292
+
293
+ Save the current model in the given folder, using the given prefix for the various
294
+ files that will get created.
295
+ Any file with the same name that already exists in this folder will be overwritten.
296
+
297
+ Args:
298
+ folder (:obj:`str`):
299
+ The path to the target folder in which to save the various files
300
+
301
+ prefix (:obj:`str`, `optional`):
302
+ An optional prefix, used to prefix each file name
303
+
304
+ Returns:
305
+ :obj:`List[str]`: The list of saved files
306
+ """
307
+ pass
308
+
309
+ def token_to_id(self, tokens):
310
+ """
311
+ Get the ID associated to a token
312
+
313
+ Args:
314
+ token (:obj:`str`):
315
+ A token to convert to an ID
316
+
317
+ Returns:
318
+ :obj:`int`: The ID associated to the token
319
+ """
320
+ pass
321
+
322
+ def tokenize(self, sequence):
323
+ """
324
+ Tokenize a sequence
325
+
326
+ Args:
327
+ sequence (:obj:`str`):
328
+ A sequence to tokenize
329
+
330
+ Returns:
331
+ A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens
332
+ """
333
+ pass
334
+
335
+ class WordLevel(Model):
336
+ """
337
+ An implementation of the WordLevel algorithm
338
+
339
+ Most simple tokenizer model based on mapping tokens to their corresponding id.
340
+
341
+ Args:
342
+ vocab (:obj:`Dict[str, int]`, `optional`):
343
+ A dictionary of string keys and their ids :obj:`{"am": 0,...}`
344
+
345
+ unk_token (:obj:`str`, `optional`):
346
+ The unknown token to be used by the model.
347
+ """
348
+ def __init__(self, vocab, unk_token):
349
+ pass
350
+
351
+ @staticmethod
352
+ def from_file(vocab, unk_token):
353
+ """
354
+ Instantiate a WordLevel model from the given file
355
+
356
+ This method is roughly equivalent to doing::
357
+
358
+ vocab = WordLevel.read_file(vocab_filename)
359
+ wordlevel = WordLevel(vocab)
360
+
361
+ If you don't need to keep the :obj:`vocab` values lying around, this method is
362
+ more optimized than manually calling :meth:`~tokenizers.models.WordLevel.read_file` to
363
+ initialize a :class:`~tokenizers.models.WordLevel`
364
+
365
+ Args:
366
+ vocab (:obj:`str`):
367
+ The path to a :obj:`vocab.json` file
368
+
369
+ Returns:
370
+ :class:`~tokenizers.models.WordLevel`: An instance of WordLevel loaded from file
371
+ """
372
+ pass
373
+
374
+ def get_trainer(self):
375
+ """
376
+ Get the associated :class:`~tokenizers.trainers.Trainer`
377
+
378
+ Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this
379
+ :class:`~tokenizers.models.Model`.
380
+
381
+ Returns:
382
+ :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
383
+ """
384
+ pass
385
+
386
+ def id_to_token(self, id):
387
+ """
388
+ Get the token associated to an ID
389
+
390
+ Args:
391
+ id (:obj:`int`):
392
+ An ID to convert to a token
393
+
394
+ Returns:
395
+ :obj:`str`: The token associated to the ID
396
+ """
397
+ pass
398
+
399
+ @staticmethod
400
+ def read_file(vocab):
401
+ """
402
+ Read a :obj:`vocab.json`
403
+
404
+ This method provides a way to read and parse the content of a vocabulary file,
405
+ returning the relevant data structures. If you want to instantiate some WordLevel models
406
+ from memory, this method gives you the expected input from the standard files.
407
+
408
+ Args:
409
+ vocab (:obj:`str`):
410
+ The path to a :obj:`vocab.json` file
411
+
412
+ Returns:
413
+ :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
414
+ """
415
+ pass
416
+
417
+ def save(self, folder, prefix):
418
+ """
419
+ Save the current model
420
+
421
+ Save the current model in the given folder, using the given prefix for the various
422
+ files that will get created.
423
+ Any file with the same name that already exists in this folder will be overwritten.
424
+
425
+ Args:
426
+ folder (:obj:`str`):
427
+ The path to the target folder in which to save the various files
428
+
429
+ prefix (:obj:`str`, `optional`):
430
+ An optional prefix, used to prefix each file name
431
+
432
+ Returns:
433
+ :obj:`List[str]`: The list of saved files
434
+ """
435
+ pass
436
+
437
+ def token_to_id(self, tokens):
438
+ """
439
+ Get the ID associated to a token
440
+
441
+ Args:
442
+ token (:obj:`str`):
443
+ A token to convert to an ID
444
+
445
+ Returns:
446
+ :obj:`int`: The ID associated to the token
447
+ """
448
+ pass
449
+
450
+ def tokenize(self, sequence):
451
+ """
452
+ Tokenize a sequence
453
+
454
+ Args:
455
+ sequence (:obj:`str`):
456
+ A sequence to tokenize
457
+
458
+ Returns:
459
+ A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens
460
+ """
461
+ pass
462
+
463
+ class WordPiece(Model):
464
+ """
465
+ An implementation of the WordPiece algorithm
466
+
467
+ Args:
468
+ vocab (:obj:`Dict[str, int]`, `optional`):
469
+ A dictionary of string keys and their ids :obj:`{"am": 0,...}`
470
+
471
+ unk_token (:obj:`str`, `optional`):
472
+ The unknown token to be used by the model.
473
+
474
+ max_input_chars_per_word (:obj:`int`, `optional`):
475
+ The maximum number of characters to authorize in a single word.
476
+ """
477
+ def __init__(self, vocab, unk_token, max_input_chars_per_word):
478
+ pass
479
+
480
+ @staticmethod
481
+ def from_file(vocab, **kwargs):
482
+ """
483
+ Instantiate a WordPiece model from the given file
484
+
485
+ This method is roughly equivalent to doing::
486
+
487
+ vocab = WordPiece.read_file(vocab_filename)
488
+ wordpiece = WordPiece(vocab)
489
+
490
+ If you don't need to keep the :obj:`vocab` values lying around, this method is
491
+ more optimized than manually calling :meth:`~tokenizers.models.WordPiece.read_file` to
492
+ initialize a :class:`~tokenizers.models.WordPiece`
493
+
494
+ Args:
495
+ vocab (:obj:`str`):
496
+ The path to a :obj:`vocab.txt` file
497
+
498
+ Returns:
499
+ :class:`~tokenizers.models.WordPiece`: An instance of WordPiece loaded from file
500
+ """
501
+ pass
502
+
503
+ def get_trainer(self):
504
+ """
505
+ Get the associated :class:`~tokenizers.trainers.Trainer`
506
+
507
+ Retrieve the :class:`~tokenizers.trainers.Trainer` associated to this
508
+ :class:`~tokenizers.models.Model`.
509
+
510
+ Returns:
511
+ :class:`~tokenizers.trainers.Trainer`: The Trainer used to train this model
512
+ """
513
+ pass
514
+
515
+ def id_to_token(self, id):
516
+ """
517
+ Get the token associated to an ID
518
+
519
+ Args:
520
+ id (:obj:`int`):
521
+ An ID to convert to a token
522
+
523
+ Returns:
524
+ :obj:`str`: The token associated to the ID
525
+ """
526
+ pass
527
+
528
+ @staticmethod
529
+ def read_file(vocab):
530
+ """
531
+ Read a :obj:`vocab.txt` file
532
+
533
+ This method provides a way to read and parse the content of a standard `vocab.txt`
534
+ file as used by the WordPiece Model, returning the relevant data structures. If you
535
+ want to instantiate some WordPiece models from memory, this method gives you the
536
+ expected input from the standard files.
537
+
538
+ Args:
539
+ vocab (:obj:`str`):
540
+ The path to a :obj:`vocab.txt` file
541
+
542
+ Returns:
543
+ :obj:`Dict[str, int]`: The vocabulary as a :obj:`dict`
544
+ """
545
+ pass
546
+
547
+ def save(self, folder, prefix):
548
+ """
549
+ Save the current model
550
+
551
+ Save the current model in the given folder, using the given prefix for the various
552
+ files that will get created.
553
+ Any file with the same name that already exists in this folder will be overwritten.
554
+
555
+ Args:
556
+ folder (:obj:`str`):
557
+ The path to the target folder in which to save the various files
558
+
559
+ prefix (:obj:`str`, `optional`):
560
+ An optional prefix, used to prefix each file name
561
+
562
+ Returns:
563
+ :obj:`List[str]`: The list of saved files
564
+ """
565
+ pass
566
+
567
+ def token_to_id(self, tokens):
568
+ """
569
+ Get the ID associated to a token
570
+
571
+ Args:
572
+ token (:obj:`str`):
573
+ A token to convert to an ID
574
+
575
+ Returns:
576
+ :obj:`int`: The ID associated to the token
577
+ """
578
+ pass
579
+
580
+ def tokenize(self, sequence):
581
+ """
582
+ Tokenize a sequence
583
+
584
+ Args:
585
+ sequence (:obj:`str`):
586
+ A sequence to tokenize
587
+
588
+ Returns:
589
+ A :obj:`List` of :class:`~tokenizers.Token`: The generated tokens
590
+ """
591
+ pass
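The `from_file`/`read_file` docstrings above describe two equivalent ways to build a model from serialized vocabularies. A minimal sketch of both paths, assuming the `tokenizers` package is installed (the file names are placeholders):

    from tokenizers.models import BPE

    # One-step load; more efficient when the raw vocab/merges are not needed afterwards.
    bpe = BPE.from_file("vocab.json", "merges.txt", unk_token="[UNK]")

    # Equivalent two-step load that keeps the parsed structures around.
    vocab, merges = BPE.read_file("vocab.json", "merges.txt")
    bpe = BPE(vocab, merges, unk_token="[UNK]")

    # The model tokenizes a single pre-split word into subword Tokens.
    print([token.value for token in bpe.tokenize("tokenization")])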
activations.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from collections import OrderedDict
17
+
18
+ import torch
19
+ from packaging import version
20
+ from torch import Tensor, nn
21
+
22
+ from .utils import logging
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ class PytorchGELUTanh(nn.Module):
29
+ """
30
+ A fast C implementation of the tanh approximation of the GeLU activation function. See
31
+ https://arxiv.org/abs/1606.08415.
32
+
33
+ This implementation is equivalent to NewGELU and FastGELU but much faster. However, it is not an exact numerical
34
+ match due to rounding errors.
35
+ """
36
+
37
+ def __init__(self):
38
+ super().__init__()
39
+ if version.parse(torch.__version__) < version.parse("1.12.0"):
40
+ raise ImportError(
41
+ f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
42
+ "PytorchGELUTanh. Please upgrade torch."
43
+ )
44
+
45
+ def forward(self, input: Tensor) -> Tensor:
46
+ return nn.functional.gelu(input, approximate="tanh")
47
+
48
+
49
+ class NewGELUActivation(nn.Module):
50
+ """
51
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
52
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
53
+ """
54
+
55
+ def forward(self, input: Tensor) -> Tensor:
56
+ return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
57
+
58
+
59
+ class GELUActivation(nn.Module):
60
+ """
61
+ Original Implementation of the GELU activation function in Google BERT repo when initially created. For
62
+ information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
63
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
64
+ Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
65
+ """
66
+
67
+ def __init__(self, use_gelu_python: bool = False):
68
+ super().__init__()
69
+ if use_gelu_python:
70
+ self.act = self._gelu_python
71
+ else:
72
+ self.act = nn.functional.gelu
73
+
74
+ def _gelu_python(self, input: Tensor) -> Tensor:
75
+ return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
76
+
77
+ def forward(self, input: Tensor) -> Tensor:
78
+ return self.act(input)
79
+
80
+
81
+ class FastGELUActivation(nn.Module):
82
+ """
83
+ Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
84
+ """
85
+
86
+ def forward(self, input: Tensor) -> Tensor:
87
+ return 0.5 * input * (1.0 + torch.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))
88
+
89
+
90
+ class QuickGELUActivation(nn.Module):
91
+ """
92
+ Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
93
+ """
94
+
95
+ def forward(self, input: Tensor) -> Tensor:
96
+ return input * torch.sigmoid(1.702 * input)
97
+
98
+
99
+ class ClippedGELUActivation(nn.Module):
100
+ """
101
+ Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as
102
+ it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
103
+ https://arxiv.org/abs/2004.09602.
104
+
105
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
106
+ initially created.
107
+
108
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
109
+ torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
110
+ """
111
+
112
+ def __init__(self, min: float, max: float):
113
+ if min > max:
114
+ raise ValueError(f"min should be < max (got min: {min}, max: {max})")
115
+
116
+ super().__init__()
117
+ self.min = min
118
+ self.max = max
119
+
120
+ def forward(self, x: Tensor) -> Tensor:
121
+ return torch.clip(gelu(x), self.min, self.max)
122
+
123
+
124
+ class AccurateGELUActivation(nn.Module):
125
+ """
126
+ Applies GELU approximation that is faster than default and more accurate than QuickGELU. See:
127
+ https://github.com/hendrycks/GELUs
128
+
129
+ Implemented along with MEGA (Moving Average Equipped Gated Attention)
130
+ """
131
+
132
+ def __init__(self):
133
+ super().__init__()
134
+ self.precomputed_constant = math.sqrt(2 / math.pi)
135
+
136
+ def forward(self, input: Tensor) -> Tensor:
137
+ return 0.5 * input * (1 + torch.tanh(self.precomputed_constant * (input + 0.044715 * torch.pow(input, 3))))
138
+
139
+
140
+ class MishActivation(nn.Module):
141
+ """
142
+ See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
143
+ visit the official repository for the paper: https://github.com/digantamisra98/Mish
144
+ """
145
+
146
+ def __init__(self):
147
+ super().__init__()
148
+ if version.parse(torch.__version__) < version.parse("1.9.0"):
149
+ self.act = self._mish_python
150
+ else:
151
+ self.act = nn.functional.mish
152
+
153
+ def _mish_python(self, input: Tensor) -> Tensor:
154
+ return input * torch.tanh(nn.functional.softplus(input))
155
+
156
+ def forward(self, input: Tensor) -> Tensor:
157
+ return self.act(input)
158
+
159
+
160
+ class LinearActivation(nn.Module):
161
+ """
162
+ Applies the linear activation function, i.e. forwarding input directly to output.
163
+ """
164
+
165
+ def forward(self, input: Tensor) -> Tensor:
166
+ return input
167
+
168
+
169
+ class LaplaceActivation(nn.Module):
170
+ """
171
+ Applies elementwise activation based on Laplace function, introduced in MEGA as an attention activation. See
172
+ https://arxiv.org/abs/2209.10655
173
+
174
+ Inspired by squared relu, but with bounded range and gradient for better stability
175
+ """
176
+
177
+ def forward(self, input, mu=0.707107, sigma=0.282095):
178
+ input = (input - mu).div(sigma * math.sqrt(2.0))
179
+ return 0.5 * (1.0 + torch.erf(input))
180
+
181
+
182
+ class ReLUSquaredActivation(nn.Module):
183
+ """
184
+ Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
185
+ """
186
+
187
+ def forward(self, input):
188
+ relu_applied = nn.functional.relu(input)
189
+ squared = torch.square(relu_applied)
190
+ return squared
191
+
192
+
193
+ class ClassInstantier(OrderedDict):
194
+ def __getitem__(self, key):
195
+ content = super().__getitem__(key)
196
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
197
+ return cls(**kwargs)
198
+
199
+
200
+ ACT2CLS = {
201
+ "gelu": GELUActivation,
202
+ "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
203
+ "gelu_fast": FastGELUActivation,
204
+ "gelu_new": NewGELUActivation,
205
+ "gelu_python": (GELUActivation, {"use_gelu_python": True}),
206
+ "gelu_pytorch_tanh": PytorchGELUTanh,
207
+ "gelu_accurate": AccurateGELUActivation,
208
+ "laplace": LaplaceActivation,
209
+ "leaky_relu": nn.LeakyReLU,
210
+ "linear": LinearActivation,
211
+ "mish": MishActivation,
212
+ "quick_gelu": QuickGELUActivation,
213
+ "relu": nn.ReLU,
214
+ "relu2": ReLUSquaredActivation,
215
+ "relu6": nn.ReLU6,
216
+ "sigmoid": nn.Sigmoid,
217
+ "silu": nn.SiLU,
218
+ "swish": nn.SiLU,
219
+ "tanh": nn.Tanh,
220
+ }
221
+ ACT2FN = ClassInstantier(ACT2CLS)
222
+
223
+
224
+ def get_activation(activation_string):
225
+ if activation_string in ACT2FN:
226
+ return ACT2FN[activation_string]
227
+ else:
228
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
229
+
230
+
231
+ # For backwards compatibility with: from activations import gelu_python
232
+ gelu_python = get_activation("gelu_python")
233
+ gelu_new = get_activation("gelu_new")
234
+ gelu = get_activation("gelu")
235
+ gelu_fast = get_activation("gelu_fast")
236
+ quick_gelu = get_activation("quick_gelu")
237
+ silu = get_activation("silu")
238
+ mish = get_activation("mish")
239
+ linear_act = get_activation("linear")
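To illustrate the `ClassInstantier` indirection above: indexing `ACT2FN` instantiates a fresh module on every lookup, so `get_activation` hands back a ready-to-use `nn.Module`. A minimal sketch, assuming the file is importable under its upstream name `transformers.activations`:

    import torch
    from transformers.activations import get_activation

    # "gelu_new" maps to NewGELUActivation; the returned object is a plain nn.Module.
    act = get_activation("gelu_new")
    x = torch.randn(2, 4)
    y = act(x)  # same shape as x

    # "gelu_10" uses the (class, kwargs) form: ClippedGELUActivation(min=-10, max=10).
    clipped = get_activation("gelu_10")
    print(y.shape, bool(clipped(x).min() >= -10))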
activations_tf.py ADDED
@@ -0,0 +1,147 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+
17
+ import tensorflow as tf
18
+ from packaging.version import parse
19
+
20
+
21
+ try:
22
+ import tf_keras as keras
23
+ except (ModuleNotFoundError, ImportError):
24
+ import keras
25
+
26
+ if parse(keras.__version__).major > 2:
27
+ raise ValueError(
28
+ "Your currently installed version of Keras is Keras 3, but this is not yet supported in "
29
+ "Transformers. Please install the backwards-compatible tf-keras package with "
30
+ "`pip install tf-keras`."
31
+ )
32
+
33
+
34
+ def _gelu(x):
35
+ """
36
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
37
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
38
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
39
+ https://arxiv.org/abs/1606.08415
40
+ """
41
+ x = tf.convert_to_tensor(x)
42
+ cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
43
+
44
+ return x * cdf
45
+
46
+
47
+ def _gelu_new(x):
48
+ """
49
+ Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415
50
+
51
+ Args:
52
+ x: float Tensor to perform activation
53
+
54
+ Returns:
55
+ `x` with the GELU activation applied.
56
+ """
57
+ x = tf.convert_to_tensor(x)
58
+ pi = tf.cast(math.pi, x.dtype)
59
+ coeff = tf.cast(0.044715, x.dtype)
60
+ cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))
61
+
62
+ return x * cdf
63
+
64
+
65
+ def mish(x):
66
+ x = tf.convert_to_tensor(x)
67
+
68
+ return x * tf.tanh(tf.math.softplus(x))
69
+
70
+
71
+ def gelu_fast(x):
72
+ x = tf.convert_to_tensor(x)
73
+ coeff1 = tf.cast(0.044715, x.dtype)
74
+ coeff2 = tf.cast(0.7978845608, x.dtype)
75
+
76
+ return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
77
+
78
+
79
+ def quick_gelu(x):
80
+ x = tf.convert_to_tensor(x)
81
+ coeff = tf.cast(1.702, x.dtype)
82
+ return x * tf.math.sigmoid(coeff * x)
83
+
84
+
85
+ def gelu_10(x):
86
+ """
87
+ Clip the range of possible GeLU outputs between [-10, 10]. This is especially useful for quantization purpose, as
88
+ it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
89
+ https://arxiv.org/abs/2004.09602
90
+
91
+ Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
92
+ initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
93
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see
94
+ https://arxiv.org/abs/1606.08415 :param x: :return:
95
+ """
96
+ return tf.clip_by_value(_gelu(x), -10, 10)
97
+
98
+
99
+ def glu(x, axis=-1):
100
+ """
101
+ Gated Linear Unit. Implementation as defined in the original paper (see https://arxiv.org/abs/1612.08083), where
102
+ the input `x` is split in two halves across a dimension (`axis`), A and B, returning A * sigmoid(B).
103
+
104
+ Args:
105
+ `x`: float Tensor to perform activation
106
+ `axis`: dimension across which `x` be split in half
107
+
108
+ Returns:
109
+ `x` with the GLU activation applied (with its size halved across the dimension `axis`).
110
+ """
111
+ a, b = tf.split(x, 2, axis=axis)
112
+ return a * tf.math.sigmoid(b)
113
+
114
+
115
+ if parse(tf.version.VERSION) >= parse("2.4"):
116
+
117
+ def approximate_gelu_wrap(x):
118
+ return keras.activations.gelu(x, approximate=True)
119
+
120
+ gelu = keras.activations.gelu
121
+ gelu_new = approximate_gelu_wrap
122
+ else:
123
+ gelu = _gelu
124
+ gelu_new = _gelu_new
125
+
126
+
127
+ ACT2FN = {
128
+ "gelu": gelu,
129
+ "gelu_10": gelu_10,
130
+ "gelu_fast": gelu_fast,
131
+ "gelu_new": gelu_new,
132
+ "glu": glu,
133
+ "mish": mish,
134
+ "quick_gelu": quick_gelu,
135
+ "relu": keras.activations.relu,
136
+ "sigmoid": keras.activations.sigmoid,
137
+ "silu": keras.activations.swish,
138
+ "swish": keras.activations.swish,
139
+ "tanh": keras.activations.tanh,
140
+ }
141
+
142
+
143
+ def get_tf_activation(activation_string):
144
+ if activation_string in ACT2FN:
145
+ return ACT2FN[activation_string]
146
+ else:
147
+ raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
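The TensorFlow counterpart resolves the same activation names through its own `ACT2FN`, but the entries are plain callables rather than module classes. A minimal sketch, assuming the file is importable under its upstream name `transformers.activations_tf`:

    import tensorflow as tf
    from transformers.activations_tf import get_tf_activation

    act = get_tf_activation("gelu_new")   # returns a function, not a Keras layer
    x = tf.random.normal((2, 4))
    y = act(x)                            # elementwise tanh-approximated GELU, same shape as x
    print(y.shape)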
audio_utils.py ADDED
@@ -0,0 +1,1123 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team and the librosa & torchaudio authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Audio processing functions to extract features from audio waveforms. This code is pure numpy to support all frameworks
17
+ and remove unnecessary dependencies.
18
+ """
19
+
20
+ import warnings
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+
25
+
26
+ def hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
27
+ """
28
+ Convert frequency from hertz to mels.
29
+
30
+ Args:
31
+ freq (`float` or `np.ndarray`):
32
+ The frequency, or multiple frequencies, in hertz (Hz).
33
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
34
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
35
+
36
+ Returns:
37
+ `float` or `np.ndarray`: The frequencies on the mel scale.
38
+ """
39
+
40
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
41
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
42
+
43
+ if mel_scale == "htk":
44
+ return 2595.0 * np.log10(1.0 + (freq / 700.0))
45
+ elif mel_scale == "kaldi":
46
+ return 1127.0 * np.log(1.0 + (freq / 700.0))
47
+
48
+ min_log_hertz = 1000.0
49
+ min_log_mel = 15.0
50
+ logstep = 27.0 / np.log(6.4)
51
+ mels = 3.0 * freq / 200.0
52
+
53
+ if isinstance(freq, np.ndarray):
54
+ log_region = freq >= min_log_hertz
55
+ mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep
56
+ elif freq >= min_log_hertz:
57
+ mels = min_log_mel + np.log(freq / min_log_hertz) * logstep
58
+
59
+ return mels
60
+
61
+
62
+ def mel_to_hertz(mels: Union[float, np.ndarray], mel_scale: str = "htk") -> Union[float, np.ndarray]:
63
+ """
64
+ Convert frequency from mels to hertz.
65
+
66
+ Args:
67
+ mels (`float` or `np.ndarray`):
68
+ The frequency, or multiple frequencies, in mels.
69
+ mel_scale (`str`, *optional*, `"htk"`):
70
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
71
+
72
+ Returns:
73
+ `float` or `np.ndarray`: The frequencies in hertz.
74
+ """
75
+
76
+ if mel_scale not in ["slaney", "htk", "kaldi"]:
77
+ raise ValueError('mel_scale should be one of "htk", "slaney" or "kaldi".')
78
+
79
+ if mel_scale == "htk":
80
+ return 700.0 * (np.power(10, mels / 2595.0) - 1.0)
81
+ elif mel_scale == "kaldi":
82
+ return 700.0 * (np.exp(mels / 1127.0) - 1.0)
83
+
84
+ min_log_hertz = 1000.0
85
+ min_log_mel = 15.0
86
+ logstep = np.log(6.4) / 27.0
87
+ freq = 200.0 * mels / 3.0
88
+
89
+ if isinstance(mels, np.ndarray):
90
+ log_region = mels >= min_log_mel
91
+ freq[log_region] = min_log_hertz * np.exp(logstep * (mels[log_region] - min_log_mel))
92
+ elif mels >= min_log_mel:
93
+ freq = min_log_hertz * np.exp(logstep * (mels - min_log_mel))
94
+
95
+ return freq
96
+
97
+
98
+ def hertz_to_octave(
99
+ freq: Union[float, np.ndarray], tuning: Optional[float] = 0.0, bins_per_octave: Optional[int] = 12
100
+ ):
101
+ """
102
+ Convert frequency from hertz to fractional octave numbers.
103
+ Adapted from *librosa*.
104
+
105
+ Args:
106
+ freq (`float` or `np.ndarray`):
107
+ The frequency, or multiple frequencies, in hertz (Hz).
108
+ tuning (`float`, defaults to `0.`):
109
+ Tuning deviation from the Stuttgart pitch (A440) in (fractional) bins per octave.
110
+ bins_per_octave (`int`, defaults to `12`):
111
+ Number of bins per octave.
112
+
113
+ Returns:
114
+ `float` or `np.ndarray`: The frequencies on the octave scale.
115
+ """
116
+ stuttgart_pitch = 440.0 * 2.0 ** (tuning / bins_per_octave)
117
+ octave = np.log2(freq / (float(stuttgart_pitch) / 16))
118
+ return octave
119
+
120
+
121
+ def _create_triangular_filter_bank(fft_freqs: np.ndarray, filter_freqs: np.ndarray) -> np.ndarray:
122
+ """
123
+ Creates a triangular filter bank.
124
+
125
+ Adapted from *torchaudio* and *librosa*.
126
+
127
+ Args:
128
+ fft_freqs (`np.ndarray` of shape `(num_frequency_bins,)`):
129
+ Discrete frequencies of the FFT bins in Hz.
130
+ filter_freqs (`np.ndarray` of shape `(num_mel_filters,)`):
131
+ Center frequencies of the triangular filters to create, in Hz.
132
+
133
+ Returns:
134
+ `np.ndarray` of shape `(num_frequency_bins, num_mel_filters)`
135
+ """
136
+ filter_diff = np.diff(filter_freqs)
137
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
138
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
139
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
140
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
141
+
142
+
143
+ def chroma_filter_bank(
144
+ num_frequency_bins: int,
145
+ num_chroma: int,
146
+ sampling_rate: int,
147
+ tuning: float = 0.0,
148
+ power: Optional[float] = 2.0,
149
+ weighting_parameters: Optional[Tuple[float]] = (5.0, 2),
150
+ start_at_c_chroma: Optional[bool] = True,
151
+ ):
152
+ """
153
+ Creates a chroma filter bank, i.e a linear transformation to project spectrogram bins onto chroma bins.
154
+
155
+ Adapted from *librosa*.
156
+
157
+ Args:
158
+ num_frequency_bins (`int`):
159
+ Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
160
+ num_chroma (`int`):
161
+ Number of chroma bins (i.e pitch classes).
162
+ sampling_rate (`float`):
163
+ Sample rate of the audio waveform.
164
+ tuning (`float`):
165
+ Tuning deviation from A440 in fractions of a chroma bin.
166
+ power (`float`, *optional*, defaults to 2.0):
167
+ If 2.0, normalizes each column with its L2 norm. If 1.0, normalizes each column with its L1 norm.
168
+ weighting_parameters (`Tuple[float]`, *optional*, defaults to `(5., 2.)`):
169
+ If specified, apply a Gaussian weighting parameterized by the first element of the tuple being the center and
170
+ the second element being the Gaussian half-width.
171
+ start_at_c_chroma (`float`, *optional*, defaults to `True`):
172
+ If True, the filter bank will start at the 'C' pitch class. Otherwise, it will start at 'A'.
173
+ Returns:
174
+ `np.ndarray` of shape `(num_frequency_bins, num_chroma)`
175
+ """
176
+ # Get the FFT bins, not counting the DC component
177
+ frequencies = np.linspace(0, sampling_rate, num_frequency_bins, endpoint=False)[1:]
178
+
179
+ freq_bins = num_chroma * hertz_to_octave(frequencies, tuning=tuning, bins_per_octave=num_chroma)
180
+
181
+ # make up a value for the 0 Hz bin = 1.5 octaves below bin 1
182
+ # (so chroma is 50% rotated from bin 1, and bin width is broad)
183
+ freq_bins = np.concatenate(([freq_bins[0] - 1.5 * num_chroma], freq_bins))
184
+
185
+ bins_width = np.concatenate((np.maximum(freq_bins[1:] - freq_bins[:-1], 1.0), [1]))
186
+
187
+ chroma_filters = np.subtract.outer(freq_bins, np.arange(0, num_chroma, dtype="d")).T
188
+
189
+ num_chroma2 = np.round(float(num_chroma) / 2)
190
+
191
+ # Project into range -num_chroma/2 .. num_chroma/2
192
+ # add on fixed offset of 10*num_chroma to ensure all values passed to
193
+ # rem are positive
194
+ chroma_filters = np.remainder(chroma_filters + num_chroma2 + 10 * num_chroma, num_chroma) - num_chroma2
195
+
196
+ # Gaussian bumps - 2*D to make them narrower
197
+ chroma_filters = np.exp(-0.5 * (2 * chroma_filters / np.tile(bins_width, (num_chroma, 1))) ** 2)
198
+
199
+ # normalize each column
200
+ if power is not None:
201
+ chroma_filters = chroma_filters / np.sum(chroma_filters**power, axis=0, keepdims=True) ** (1.0 / power)
202
+
203
+ # Maybe apply scaling for fft bins
204
+ if weighting_parameters is not None:
205
+ center, half_width = weighting_parameters
206
+ chroma_filters *= np.tile(
207
+ np.exp(-0.5 * (((freq_bins / num_chroma - center) / half_width) ** 2)),
208
+ (num_chroma, 1),
209
+ )
210
+
211
+ if start_at_c_chroma:
212
+ chroma_filters = np.roll(chroma_filters, -3 * (num_chroma // 12), axis=0)
213
+
214
+ # remove aliasing columns, copy to ensure row-contiguity
215
+ return np.ascontiguousarray(chroma_filters[:, : int(1 + num_frequency_bins / 2)])
216
+
217
+
218
+ def mel_filter_bank(
219
+ num_frequency_bins: int,
220
+ num_mel_filters: int,
221
+ min_frequency: float,
222
+ max_frequency: float,
223
+ sampling_rate: int,
224
+ norm: Optional[str] = None,
225
+ mel_scale: str = "htk",
226
+ triangularize_in_mel_space: bool = False,
227
+ ) -> np.ndarray:
228
+ """
229
+ Creates a frequency bin conversion matrix used to obtain a mel spectrogram. This is called a *mel filter bank*, and
230
+ various implementations exist, which differ in the number of filters, the shape of the filters, the way the filters
231
+ are spaced, the bandwidth of the filters, and the manner in which the spectrum is warped. The goal of these
232
+ features is to approximate the non-linear human perception of the variation in pitch with respect to the frequency.
233
+
234
+ Different banks of mel filters were introduced in the literature. The following variations are supported:
235
+
236
+ - MFCC FB-20: introduced in 1980 by Davis and Mermelstein, it assumes a sampling frequency of 10 kHz and a speech
237
+ bandwidth of `[0, 4600]` Hz.
238
+ - MFCC FB-24 HTK: from the Cambridge HMM Toolkit (HTK) (1995) uses a filter bank of 24 filters for a speech
239
+ bandwidth of `[0, 8000]` Hz. This assumes sampling rate ≥ 16 kHz.
240
+ - MFCC FB-40: from the Auditory Toolbox for MATLAB written by Slaney in 1998, assumes a sampling rate of 16 kHz and
241
+ speech bandwidth of `[133, 6854]` Hz. This version also includes area normalization.
242
+ - HFCC-E FB-29 (Human Factor Cepstral Coefficients) of Skowronski and Harris (2004), assumes a sampling rate of
243
+ 12.5 kHz and speech bandwidth of `[0, 6250]` Hz.
244
+
245
+ This code is adapted from *torchaudio* and *librosa*. Note that the default parameters of torchaudio's
246
+ `melscale_fbanks` implement the `"htk"` filters while librosa uses the `"slaney"` implementation.
247
+
248
+ Args:
249
+ num_frequency_bins (`int`):
250
+ Number of frequencies used to compute the spectrogram (should be the same as in `stft`).
251
+ num_mel_filters (`int`):
252
+ Number of mel filters to generate.
253
+ min_frequency (`float`):
254
+ Lowest frequency of interest in Hz.
255
+ max_frequency (`float`):
256
+ Highest frequency of interest in Hz. This should not exceed `sampling_rate / 2`.
257
+ sampling_rate (`int`):
258
+ Sample rate of the audio waveform.
259
+ norm (`str`, *optional*):
260
+ If `"slaney"`, divide the triangular mel weights by the width of the mel band (area normalization).
261
+ mel_scale (`str`, *optional*, defaults to `"htk"`):
262
+ The mel frequency scale to use, `"htk"`, `"kaldi"` or `"slaney"`.
263
+ triangularize_in_mel_space (`bool`, *optional*, defaults to `False`):
264
+ If this option is enabled, the triangular filter is applied in mel space rather than frequency space. This
265
+ should be set to `true` in order to get the same results as `torchaudio` when computing mel filters.
266
+
267
+ Returns:
268
+ `np.ndarray` of shape (`num_frequency_bins`, `num_mel_filters`): Triangular filter bank matrix. This is a
269
+ projection matrix to go from a spectrogram to a mel spectrogram.
270
+ """
271
+ if norm is not None and norm != "slaney":
272
+ raise ValueError('norm must be one of None or "slaney"')
273
+
274
+ # center points of the triangular mel filters
275
+ mel_min = hertz_to_mel(min_frequency, mel_scale=mel_scale)
276
+ mel_max = hertz_to_mel(max_frequency, mel_scale=mel_scale)
277
+ mel_freqs = np.linspace(mel_min, mel_max, num_mel_filters + 2)
278
+ filter_freqs = mel_to_hertz(mel_freqs, mel_scale=mel_scale)
279
+
280
+ if triangularize_in_mel_space:
281
+ # frequencies of FFT bins in Hz, but filters triangularized in mel space
282
+ fft_bin_width = sampling_rate / (num_frequency_bins * 2)
283
+ fft_freqs = hertz_to_mel(fft_bin_width * np.arange(num_frequency_bins), mel_scale=mel_scale)
284
+ filter_freqs = mel_freqs
285
+ else:
286
+ # frequencies of FFT bins in Hz
287
+ fft_freqs = np.linspace(0, sampling_rate // 2, num_frequency_bins)
288
+
289
+ mel_filters = _create_triangular_filter_bank(fft_freqs, filter_freqs)
290
+
291
+ if norm is not None and norm == "slaney":
292
+ # Slaney-style mel is scaled to be approx constant energy per channel
293
+ enorm = 2.0 / (filter_freqs[2 : num_mel_filters + 2] - filter_freqs[:num_mel_filters])
294
+ mel_filters *= np.expand_dims(enorm, 0)
295
+
296
+ if (mel_filters.max(axis=0) == 0.0).any():
297
+ warnings.warn(
298
+ "At least one mel filter has all zero values. "
299
+ f"The value for `num_mel_filters` ({num_mel_filters}) may be set too high. "
300
+ f"Or, the value for `num_frequency_bins` ({num_frequency_bins}) may be set too low."
301
+ )
302
+
303
+ return mel_filters
304
+
305
+
306
+ def optimal_fft_length(window_length: int) -> int:
307
+ """
308
+ Finds the best FFT input size for a given `window_length`. This function takes a given window length and, if not
309
+ already a power of two, rounds it up to the next power of two.
310
+
311
+ The FFT algorithm works fastest when the length of the input is a power of two, which may be larger than the size
312
+ of the window or analysis frame. For example, if the window is 400 samples, using an FFT input size of 512 samples
313
+ is more optimal than an FFT size of 400 samples. Using a larger FFT size does not affect the detected frequencies,
314
+ it simply gives a higher frequency resolution (i.e. the frequency bins are smaller).
315
+ """
316
+ return 2 ** int(np.ceil(np.log2(window_length)))
317
+
318
+
319
+ def window_function(
320
+ window_length: int,
321
+ name: str = "hann",
322
+ periodic: bool = True,
323
+ frame_length: Optional[int] = None,
324
+ center: bool = True,
325
+ ) -> np.ndarray:
326
+ """
327
+ Returns an array containing the specified window. This window is intended to be used with `stft`.
328
+
329
+ The following window types are supported:
330
+
331
+ - `"boxcar"`: a rectangular window
332
+ - `"hamming"`: the Hamming window
333
+ - `"hann"`: the Hann window
334
+ - `"povey"`: the Povey window
335
+
336
+ Args:
337
+ window_length (`int`):
338
+ The length of the window in samples.
339
+ name (`str`, *optional*, defaults to `"hann"`):
340
+ The name of the window function.
341
+ periodic (`bool`, *optional*, defaults to `True`):
342
+ Whether the window is periodic or symmetric.
343
+ frame_length (`int`, *optional*):
344
+ The length of the analysis frames in samples. Provide a value for `frame_length` if the window is smaller
345
+ than the frame length, so that it will be zero-padded.
346
+ center (`bool`, *optional*, defaults to `True`):
347
+ Whether to center the window inside the FFT buffer. Only used when `frame_length` is provided.
348
+
349
+ Returns:
350
+ `np.ndarray` of shape `(window_length,)` or `(frame_length,)` containing the window.
351
+ """
352
+ length = window_length + 1 if periodic else window_length
353
+
354
+ if name == "boxcar":
355
+ window = np.ones(length)
356
+ elif name in ["hamming", "hamming_window"]:
357
+ window = np.hamming(length)
358
+ elif name in ["hann", "hann_window"]:
359
+ window = np.hanning(length)
360
+ elif name in ["povey"]:
361
+ window = np.power(np.hanning(length), 0.85)
362
+ else:
363
+ raise ValueError(f"Unknown window function '{name}'")
364
+
365
+ if periodic:
366
+ window = window[:-1]
367
+
368
+ if frame_length is None:
369
+ return window
370
+
371
+ if window_length > frame_length:
372
+ raise ValueError(
373
+ f"Length of the window ({window_length}) may not be larger than frame_length ({frame_length})"
374
+ )
375
+
376
+ padded_window = np.zeros(frame_length)
377
+ offset = (frame_length - window_length) // 2 if center else 0
378
+ padded_window[offset : offset + window_length] = window
379
+ return padded_window
380
+
381
+
382
+ # TODO This method does not support batching yet as we are mainly focused on inference.
383
+ def spectrogram(
384
+ waveform: np.ndarray,
385
+ window: np.ndarray,
386
+ frame_length: int,
387
+ hop_length: int,
388
+ fft_length: Optional[int] = None,
389
+ power: Optional[float] = 1.0,
390
+ center: bool = True,
391
+ pad_mode: str = "reflect",
392
+ onesided: bool = True,
393
+ preemphasis: Optional[float] = None,
394
+ mel_filters: Optional[np.ndarray] = None,
395
+ mel_floor: float = 1e-10,
396
+ log_mel: Optional[str] = None,
397
+ reference: float = 1.0,
398
+ min_value: float = 1e-10,
399
+ db_range: Optional[float] = None,
400
+ remove_dc_offset: Optional[bool] = None,
401
+ dtype: np.dtype = np.float32,
402
+ ) -> np.ndarray:
403
+ """
404
+ Calculates a spectrogram over one waveform using the Short-Time Fourier Transform.
405
+
406
+ This function can create the following kinds of spectrograms:
407
+
408
+ - amplitude spectrogram (`power = 1.0`)
409
+ - power spectrogram (`power = 2.0`)
410
+ - complex-valued spectrogram (`power = None`)
411
+ - log spectrogram (use `log_mel` argument)
412
+ - mel spectrogram (provide `mel_filters`)
413
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
414
+
415
+ How this works:
416
+
417
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
418
+ - hop_length` samples.
419
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
420
+ 3. The DFT is taken of each windowed frame.
421
+ 4. The results are stacked into a spectrogram.
422
+
423
+ We make a distinction between the following "blocks" of sample data, each of which may have a different length:
424
+
425
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
426
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
427
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
428
+
429
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
430
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
431
+ typically the next power of two.
432
+
433
+ Note: This function is not optimized for speed yet. It should be mostly compatible with `librosa.stft` and
434
+ `torchaudio.functional.transforms.Spectrogram`, although it is more flexible due to the different ways spectrograms
435
+ can be constructed.
436
+
437
+ Args:
438
+ waveform (`np.ndarray` of shape `(length,)`):
439
+ The input waveform. This must be a single real-valued, mono waveform.
440
+ window (`np.ndarray` of shape `(frame_length,)`):
441
+ The windowing function to apply, including zero-padding if necessary. The actual window length may be
442
+ shorter than `frame_length`, but we're assuming the array has already been zero-padded.
443
+ frame_length (`int`):
444
+ The length of the analysis frames in samples. With librosa this is always equal to `fft_length` but we also
445
+ allow smaller sizes.
446
+ hop_length (`int`):
447
+ The stride between successive analysis frames in samples.
448
+ fft_length (`int`, *optional*):
449
+ The size of the FFT buffer in samples. This determines how many frequency bins the spectrogram will have.
450
+ For optimal speed, this should be a power of two. If `None`, uses `frame_length`.
451
+ power (`float`, *optional*, defaults to 1.0):
452
+ If 1.0, returns the amplitude spectrogram. If 2.0, returns the power spectrogram. If `None`, returns
453
+ complex numbers.
454
+ center (`bool`, *optional*, defaults to `True`):
455
+ Whether to pad the waveform so that frame `t` is centered around time `t * hop_length`. If `False`, frame
456
+ `t` will start at time `t * hop_length`.
457
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
458
+ Padding mode used when `center` is `True`. Possible values are: `"constant"` (pad with zeros), `"edge"`
459
+ (pad with edge values), `"reflect"` (pads with mirrored values).
460
+ onesided (`bool`, *optional*, defaults to `True`):
461
+ If True, only computes the positive frequencies and returns a spectrogram containing `fft_length // 2 + 1`
462
+ frequency bins. If False, also computes the negative frequencies and returns `fft_length` frequency bins.
463
+ preemphasis (`float`, *optional*):
464
+ Coefficient for the pre-emphasis filter applied to each frame before the DFT.
465
+ mel_filters (`np.ndarray` of shape `(num_freq_bins, num_mel_filters)`, *optional*):
466
+ The mel filter bank. If supplied, applies this filter bank to create a mel spectrogram.
467
+ mel_floor (`float`, *optional*, defaults to 1e-10):
468
+ Minimum value of the mel frequency bands, applied as a floor after the filter bank.
469
+ log_mel (`str`, *optional*):
470
+ How to convert the spectrogram to log scale. Possible options are: `None` (don't convert), `"log"` (take
471
+ the natural logarithm), `"log10"` (take the base-10 logarithm), `"dB"` (convert to decibels). Can only be
472
+ used when `power` is not `None`.
473
+ reference (`float`, *optional*, defaults to 1.0):
474
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
475
+ the loudest part to 0 dB. Must be greater than zero.
476
+ min_value (`float`, *optional*, defaults to `1e-10`):
477
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
478
+ `log(0)`. For a power spectrogram, the default of `1e-10` corresponds to a minimum of -100 dB. For an
479
+ amplitude spectrogram, the value `1e-5` corresponds to -100 dB. Must be greater than zero.
480
+ db_range (`float`, *optional*):
481
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
482
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
483
+ remove_dc_offset (`bool`, *optional*):
484
+ Subtract mean from waveform on each frame, applied before pre-emphasis. This should be set to `true` in
485
+ order to get the same results as `torchaudio.compliance.kaldi.fbank` when computing mel filters.
486
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
487
+ Data type of the spectrogram tensor. If `power` is None, this argument is ignored and the dtype will be
488
+ `np.complex64`.
489
+
490
+ Returns:
491
+ `np.ndarray` containing a spectrogram of shape `(num_frequency_bins, length)` for a regular spectrogram or shape
492
+ `(num_mel_filters, length)` for a mel spectrogram.
493
+ """
494
+ window_length = len(window)
495
+
496
+ if fft_length is None:
497
+ fft_length = frame_length
498
+
499
+ if frame_length > fft_length:
500
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
501
+
502
+ if window_length != frame_length:
503
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
504
+
505
+ if hop_length <= 0:
506
+ raise ValueError("hop_length must be greater than zero")
507
+
508
+ if waveform.ndim != 1:
509
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
510
+
511
+ if np.iscomplexobj(waveform):
512
+ raise ValueError("Complex-valued input waveforms are not currently supported")
513
+
514
+ if power is None and mel_filters is not None:
515
+ raise ValueError(
516
+ "You have provided `mel_filters` but `power` is `None`. Mel spectrogram computation is not yet supported for complex-valued spectrograms. "
517
+ "Specify `power` to fix this issue."
518
+ )
519
+
520
+ # center pad the waveform
521
+ if center:
522
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
523
+ waveform = np.pad(waveform, padding, mode=pad_mode)
524
+
525
+ # promote to float64, since np.fft uses float64 internally
526
+ waveform = waveform.astype(np.float64)
527
+ window = window.astype(np.float64)
528
+
529
+ # split waveform into frames of frame_length size
530
+ num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))
531
+
532
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
533
+ spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)
534
+
535
+ # rfft is faster than fft
536
+ fft_func = np.fft.rfft if onesided else np.fft.fft
537
+ buffer = np.zeros(fft_length)
538
+
539
+ timestep = 0
540
+ for frame_idx in range(num_frames):
541
+ buffer[:frame_length] = waveform[timestep : timestep + frame_length]
542
+
543
+ if remove_dc_offset:
544
+ buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()
545
+
546
+ if preemphasis is not None:
547
+ buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]
548
+ buffer[0] *= 1 - preemphasis
549
+
550
+ buffer[:frame_length] *= window
551
+
552
+ spectrogram[frame_idx] = fft_func(buffer)
553
+ timestep += hop_length
554
+
555
+ # note: ** is much faster than np.power
556
+ if power is not None:
557
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
558
+
559
+ spectrogram = spectrogram.T
560
+
561
+ if mel_filters is not None:
562
+ spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))
563
+
564
+ if power is not None and log_mel is not None:
565
+ if log_mel == "log":
566
+ spectrogram = np.log(spectrogram)
567
+ elif log_mel == "log10":
568
+ spectrogram = np.log10(spectrogram)
569
+ elif log_mel == "dB":
570
+ if power == 1.0:
571
+ spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
572
+ elif power == 2.0:
573
+ spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
574
+ else:
575
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
576
+ else:
577
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
578
+
579
+ spectrogram = np.asarray(spectrogram, dtype)
580
+
581
+ return spectrogram
582
+
583
+
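# Editor's note: a minimal usage sketch, not part of the uploaded file. It shows how `spectrogram`
# can be combined with `window_function` and `mel_filter_bank` (both defined earlier in this module)
# to produce a log-mel spectrogram. The waveform and every parameter value below are illustrative
# assumptions, not values required by the library.
import numpy as np
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

sampling_rate = 16000
waveform = np.random.randn(sampling_rate).astype(np.float32)  # one second of noise

# Hann window padded to the analysis frame length
window = window_function(window_length=400, name="hann", frame_length=400)

# 80 mel filters over the 201 frequency bins produced by fft_length=400
mel_filters = mel_filter_bank(
    num_frequency_bins=201,
    num_mel_filters=80,
    min_frequency=0.0,
    max_frequency=8000.0,
    sampling_rate=sampling_rate,
    norm="slaney",
    mel_scale="slaney",
)

log_mel = spectrogram(
    waveform,
    window,
    frame_length=400,
    hop_length=160,
    fft_length=400,
    power=2.0,
    mel_filters=mel_filters,
    log_mel="log10",
)
print(log_mel.shape)  # (num_mel_filters, num_frames)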
584
+ def spectrogram_batch(
585
+ waveform_list: List[np.ndarray],
586
+ window: np.ndarray,
587
+ frame_length: int,
588
+ hop_length: int,
589
+ fft_length: Optional[int] = None,
590
+ power: Optional[float] = 1.0,
591
+ center: bool = True,
592
+ pad_mode: str = "reflect",
593
+ onesided: bool = True,
594
+ preemphasis: Optional[float] = None,
595
+ mel_filters: Optional[np.ndarray] = None,
596
+ mel_floor: float = 1e-10,
597
+ log_mel: Optional[str] = None,
598
+ reference: float = 1.0,
599
+ min_value: float = 1e-10,
600
+ db_range: Optional[float] = None,
601
+ remove_dc_offset: Optional[bool] = None,
602
+ dtype: np.dtype = np.float32,
603
+ ) -> List[np.ndarray]:
604
+ """
605
+ Calculates spectrograms for a list of waveforms using the Short-Time Fourier Transform, optimized for batch processing.
606
+ This function extends the capabilities of the `spectrogram` function to handle multiple waveforms efficiently by leveraging broadcasting.
607
+
608
+ It supports generating various types of spectrograms:
609
+
610
+ - amplitude spectrogram (`power = 1.0`)
611
+ - power spectrogram (`power = 2.0`)
612
+ - complex-valued spectrogram (`power = None`)
613
+ - log spectrogram (use `log_mel` argument)
614
+ - mel spectrogram (provide `mel_filters`)
615
+ - log-mel spectrogram (provide `mel_filters` and `log_mel`)
616
+
617
+ How this works:
618
+
619
+ 1. The input waveform is split into frames of size `frame_length` that are partially overlapping by `frame_length
620
+ - hop_length` samples.
621
+ 2. Each frame is multiplied by the window and placed into a buffer of size `fft_length`.
622
+ 3. The DFT is taken of each windowed frame.
623
+ 4. The results are stacked into a spectrogram.
624
+
625
+ We make a distinction between the following "blocks" of sample data, each of which may have a different length:
626
+
627
+ - The analysis frame. This is the size of the time slices that the input waveform is split into.
628
+ - The window. Each analysis frame is multiplied by the window to avoid spectral leakage.
629
+ - The FFT input buffer. The length of this determines how many frequency bins are in the spectrogram.
630
+
631
+ In this implementation, the window is assumed to be zero-padded to have the same size as the analysis frame. A
632
+ padded window can be obtained from `window_function()`. The FFT input buffer may be larger than the analysis frame,
633
+ typically the next power of two.
634
+
635
+ Note: This function is designed for efficient batch processing of multiple waveforms but retains compatibility with individual waveform processing methods like `librosa.stft`.
636
+
637
+ Args:
638
+ waveform_list (`List[np.ndarray]` with arrays of shape `(length,)`):
639
+ The list of input waveforms, each a single-channel (mono) signal.
640
+ window (`np.ndarray` of shape `(frame_length,)`):
641
+ The windowing function to apply, including zero-padding if necessary.
642
+ frame_length (`int`):
643
+ The length of each frame for analysis.
644
+ hop_length (`int`):
645
+ The step size between successive frames.
646
+ fft_length (`int`, *optional*):
647
+ The size of the FFT buffer, defining frequency bin resolution.
648
+ power (`float`, *optional*, defaults to 1.0):
649
+ Determines the type of spectrogram: 1.0 for amplitude, 2.0 for power, None for complex.
650
+ center (`bool`, *optional*, defaults to `True`):
651
+ Whether to center-pad the waveform frames.
652
+ pad_mode (`str`, *optional*, defaults to `"reflect"`):
653
+ The padding strategy when `center` is `True`.
654
+ onesided (`bool`, *optional*, defaults to `True`):
655
+ If True, returns a one-sided spectrogram for real input signals.
656
+ preemphasis (`float`, *optional*):
657
+ Applies a pre-emphasis filter to each frame.
658
+ mel_filters (`np.ndarray`, *optional*):
659
+ Mel filter bank for converting to mel spectrogram.
660
+ mel_floor (`float`, *optional*, defaults to 1e-10):
661
+ Floor value for mel spectrogram to avoid log(0).
662
+ log_mel (`str`, *optional*):
663
+ Specifies log scaling strategy; options are None, "log", "log10", "dB".
664
+ reference (`float`, *optional*, defaults to 1.0):
665
+ Reference value for dB conversion in log_mel.
666
+ min_value (`float`, *optional*, defaults to 1e-10):
667
+ Minimum floor value for log scale conversions.
668
+ db_range (`float`, *optional*):
669
+ Dynamic range for dB scale spectrograms.
670
+ remove_dc_offset (`bool`, *optional*):
671
+ Whether to remove the DC offset from each frame.
672
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
673
+ Data type of the output spectrogram.
674
+
675
+ Returns:
676
+ List[`np.ndarray`]: A list of spectrogram arrays, one for each input waveform.
677
+ """
678
+ window_length = len(window)
679
+
680
+ if fft_length is None:
681
+ fft_length = frame_length
682
+
683
+ if frame_length > fft_length:
684
+ raise ValueError(f"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})")
685
+
686
+ if window_length != frame_length:
687
+ raise ValueError(f"Length of the window ({window_length}) must equal frame_length ({frame_length})")
688
+
689
+ if hop_length <= 0:
690
+ raise ValueError("hop_length must be greater than zero")
691
+
692
+ # Check the dimensions of each waveform, and whether it is complex-valued
693
+ for waveform in waveform_list:
694
+ if waveform.ndim != 1:
695
+ raise ValueError(f"Input waveform must have only one dimension, shape is {waveform.shape}")
696
+ if np.iscomplexobj(waveform):
697
+ raise ValueError("Complex-valued input waveforms are not currently supported")
698
+ # Center pad the waveform
699
+ if center:
700
+ padding = [(int(frame_length // 2), int(frame_length // 2))]
701
+ waveform_list = [
702
+ np.pad(
703
+ waveform,
704
+ padding,
705
+ mode=pad_mode,
706
+ )
707
+ for waveform in waveform_list
708
+ ]
709
+ original_waveform_lengths = [
710
+ len(waveform) for waveform in waveform_list
711
+ ] # these lengths will be used to remove padding later
712
+
713
+ # Batch pad the waveform
714
+ max_length = max(original_waveform_lengths)
715
+ padded_waveform_batch = np.array(
716
+ [
717
+ np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
718
+ for waveform in waveform_list
719
+ ],
720
+ dtype=dtype,
721
+ )
722
+
723
+ # Promote to float64, since np.fft uses float64 internally
724
+ padded_waveform_batch = padded_waveform_batch.astype(np.float64)
725
+ window = window.astype(np.float64)
726
+
727
+ # Split waveform into frames of frame_length size
728
+ num_frames = int(1 + np.floor((padded_waveform_batch.shape[1] - frame_length) / hop_length))
729
+ # these lengths will be used to remove padding later
730
+ true_num_frames = [int(1 + np.floor((length - frame_length) / hop_length)) for length in original_waveform_lengths]
731
+ num_batches = padded_waveform_batch.shape[0]
732
+
733
+ num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length
734
+ spectrogram = np.empty((num_batches, num_frames, num_frequency_bins), dtype=np.complex64)
735
+
736
+ # rfft is faster than fft
737
+ fft_func = np.fft.rfft if onesided else np.fft.fft
738
+ buffer = np.zeros((num_batches, fft_length))
739
+
740
+ for frame_idx in range(num_frames):
741
+ timestep = frame_idx * hop_length
742
+ buffer[:, :frame_length] = padded_waveform_batch[:, timestep : timestep + frame_length]
743
+
744
+ if remove_dc_offset:
745
+ buffer[:, :frame_length] -= buffer[:, :frame_length].mean(axis=1, keepdims=True)
746
+
747
+ if preemphasis is not None:
748
+ buffer[:, 1:frame_length] -= preemphasis * buffer[:, : frame_length - 1]
749
+ buffer[:, 0] *= 1 - preemphasis
750
+
751
+ buffer[:, :frame_length] *= window
752
+
753
+ spectrogram[:, frame_idx] = fft_func(buffer)
754
+
755
+ # Note: ** is much faster than np.power
756
+ if power is not None:
757
+ spectrogram = np.abs(spectrogram, dtype=np.float64) ** power
758
+
759
+ # Apply mel filters if provided
760
+ if mel_filters is not None:
761
+ result = np.tensordot(spectrogram, mel_filters.T, axes=([2], [1]))
762
+ spectrogram = np.maximum(mel_floor, result)
763
+
764
+ # Convert to log scale if specified
765
+ if power is not None and log_mel is not None:
766
+ if log_mel == "log":
767
+ spectrogram = np.log(spectrogram)
768
+ elif log_mel == "log10":
769
+ spectrogram = np.log10(spectrogram)
770
+ elif log_mel == "dB":
771
+ if power == 1.0:
772
+ spectrogram = amplitude_to_db_batch(spectrogram, reference, min_value, db_range)
773
+ elif power == 2.0:
774
+ spectrogram = power_to_db_batch(spectrogram, reference, min_value, db_range)
775
+ else:
776
+ raise ValueError(f"Cannot use log_mel option '{log_mel}' with power {power}")
777
+ else:
778
+ raise ValueError(f"Unknown log_mel option: {log_mel}")
779
+
780
+ spectrogram = np.asarray(spectrogram, dtype)
781
+
782
+ spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
783
+
784
+ return spectrogram_list
785
+
786
+
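# Editor's note: an illustrative sketch, not part of the uploaded file. `spectrogram_batch` accepts
# waveforms of different lengths and returns one spectrogram per input; the padding it adds
# internally is stripped from each output. The lengths below are arbitrary assumptions.
import numpy as np
from transformers.audio_utils import spectrogram_batch, window_function

waveforms = [np.random.randn(n).astype(np.float32) for n in (8000, 12000, 16000)]
window = window_function(window_length=400, name="hann")

specs = spectrogram_batch(waveforms, window, frame_length=400, hop_length=160, power=2.0)
print(len(specs))                   # 3
print([s.shape[0] for s in specs])  # each has fft_length // 2 + 1 = 201 frequency bins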
787
+ def power_to_db(
788
+ spectrogram: np.ndarray,
789
+ reference: float = 1.0,
790
+ min_value: float = 1e-10,
791
+ db_range: Optional[float] = None,
792
+ ) -> np.ndarray:
793
+ """
794
+ Converts a power spectrogram to the decibel scale. This computes `10 * log10(spectrogram / reference)`, using basic
795
+ logarithm properties for numerical stability.
796
+
797
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
798
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
799
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
800
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
801
+
802
+ Based on the implementation of `librosa.power_to_db`.
803
+
804
+ Args:
805
+ spectrogram (`np.ndarray`):
806
+ The input power (mel) spectrogram. Note that a power spectrogram has the amplitudes squared!
807
+ reference (`float`, *optional*, defaults to 1.0):
808
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
809
+ the loudest part to 0 dB. Must be greater than zero.
810
+ min_value (`float`, *optional*, defaults to `1e-10`):
811
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
812
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
813
+ db_range (`float`, *optional*):
814
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
815
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
816
+
817
+ Returns:
818
+ `np.ndarray`: the spectrogram in decibels
819
+ """
820
+ if reference <= 0.0:
821
+ raise ValueError("reference must be greater than zero")
822
+ if min_value <= 0.0:
823
+ raise ValueError("min_value must be greater than zero")
824
+
825
+ reference = max(min_value, reference)
826
+
827
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
828
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
829
+
830
+ if db_range is not None:
831
+ if db_range <= 0.0:
832
+ raise ValueError("db_range must be greater than zero")
833
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
834
+
835
+ return spectrogram
836
+
837
+
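# Editor's note: a tiny worked example, not part of the uploaded file. power_to_db computes
# 10 * log10(spectrogram / reference) after clipping to min_value, so with the assumed toy values
# below 1.0 -> 0 dB, 0.1 -> -10 dB, and 1e-12 is clipped to 1e-10 -> -100 dB.
import numpy as np
from transformers.audio_utils import power_to_db

print(power_to_db(np.array([[1.0, 0.1, 1e-12]])))  # approximately [[   0.  -10. -100.]]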
838
+ def power_to_db_batch(
839
+ spectrogram: np.ndarray,
840
+ reference: float = 1.0,
841
+ min_value: float = 1e-10,
842
+ db_range: Optional[float] = None,
843
+ ) -> np.ndarray:
844
+ """
845
+ Converts a batch of power spectrograms to the decibel scale. This computes `10 * log10(spectrogram / reference)`,
846
+ using basic logarithm properties for numerical stability.
847
+
848
+ This function supports batch processing, where each item in the batch is an individual power (mel) spectrogram.
849
+
850
+ Args:
851
+ spectrogram (`np.ndarray`):
852
+ The input batch of power (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
853
+ Note that a power spectrogram has the amplitudes squared!
854
+ reference (`float`, *optional*, defaults to 1.0):
855
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
856
+ the loudest part to 0 dB. Must be greater than zero.
857
+ min_value (`float`, *optional*, defaults to `1e-10`):
858
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
859
+ `log(0)`. The default of `1e-10` corresponds to a minimum of -100 dB. Must be greater than zero.
860
+ db_range (`float`, *optional*):
861
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
862
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
863
+
864
+ Returns:
865
+ `np.ndarray`: the batch of spectrograms in decibels
866
+ """
867
+ if reference <= 0.0:
868
+ raise ValueError("reference must be greater than zero")
869
+ if min_value <= 0.0:
870
+ raise ValueError("min_value must be greater than zero")
871
+
872
+ reference = max(min_value, reference)
873
+
874
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
875
+ spectrogram = 10.0 * (np.log10(spectrogram) - np.log10(reference))
876
+
877
+ if db_range is not None:
878
+ if db_range <= 0.0:
879
+ raise ValueError("db_range must be greater than zero")
880
+ # Apply db_range clipping per batch item
881
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
882
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
883
+
884
+ return spectrogram
885
+
886
+
887
+ def amplitude_to_db(
888
+ spectrogram: np.ndarray,
889
+ reference: float = 1.0,
890
+ min_value: float = 1e-5,
891
+ db_range: Optional[float] = None,
892
+ ) -> np.ndarray:
893
+ """
894
+ Converts an amplitude spectrogram to the decibel scale. This computes `20 * log10(spectrogram / reference)`, using
895
+ basic logarithm properties for numerical stability.
896
+
897
+ The motivation behind applying the log function on the (mel) spectrogram is that humans do not hear loudness on a
898
+ linear scale. Generally to double the perceived volume of a sound we need to put 8 times as much energy into it.
899
+ This means that large variations in energy may not sound all that different if the sound is loud to begin with.
900
+ This compression operation makes the (mel) spectrogram features match more closely what humans actually hear.
901
+
902
+ Args:
903
+ spectrogram (`np.ndarray`):
904
+ The input amplitude (mel) spectrogram.
905
+ reference (`float`, *optional*, defaults to 1.0):
906
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
907
+ the loudest part to 0 dB. Must be greater than zero.
908
+ min_value (`float`, *optional*, defaults to `1e-5`):
909
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
910
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
911
+ db_range (`float`, *optional*):
912
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
913
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
914
+
915
+ Returns:
916
+ `np.ndarray`: the spectrogram in decibels
917
+ """
918
+ if reference <= 0.0:
919
+ raise ValueError("reference must be greater than zero")
920
+ if min_value <= 0.0:
921
+ raise ValueError("min_value must be greater than zero")
922
+
923
+ reference = max(min_value, reference)
924
+
925
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
926
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
927
+
928
+ if db_range is not None:
929
+ if db_range <= 0.0:
930
+ raise ValueError("db_range must be greater than zero")
931
+ spectrogram = np.clip(spectrogram, a_min=spectrogram.max() - db_range, a_max=None)
932
+
933
+ return spectrogram
934
+
935
+
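# Editor's note: a tiny worked example, not part of the uploaded file. amplitude_to_db uses
# 20 * log10 instead of 10 * log10, so an amplitude of 0.1 maps to -20 dB and the default
# floor of 1e-5 maps to -100 dB.
import numpy as np
from transformers.audio_utils import amplitude_to_db

print(amplitude_to_db(np.array([[1.0, 0.1, 1e-7]])))  # approximately [[   0.  -20. -100.]]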
936
+ def amplitude_to_db_batch(
937
+ spectrogram: np.ndarray, reference: float = 1.0, min_value: float = 1e-5, db_range: Optional[float] = None
938
+ ) -> np.ndarray:
939
+ """
940
+ Converts a batch of amplitude spectrograms to the decibel scale. This computes `20 * log10(spectrogram / reference)`,
941
+ using basic logarithm properties for numerical stability.
942
+
943
+ The function supports batch processing, where each item in the batch is an individual amplitude (mel) spectrogram.
944
+
945
+ Args:
946
+ spectrogram (`np.ndarray`):
947
+ The input batch of amplitude (mel) spectrograms. Expected shape is (batch_size, *spectrogram_shape).
948
+ reference (`float`, *optional*, defaults to 1.0):
949
+ Sets the input spectrogram value that corresponds to 0 dB. For example, use `np.max(spectrogram)` to set
950
+ the loudest part to 0 dB. Must be greater than zero.
951
+ min_value (`float`, *optional*, defaults to `1e-5`):
952
+ The spectrogram will be clipped to this minimum value before conversion to decibels, to avoid taking
953
+ `log(0)`. The default of `1e-5` corresponds to a minimum of -100 dB. Must be greater than zero.
954
+ db_range (`float`, *optional*):
955
+ Sets the maximum dynamic range in decibels. For example, if `db_range = 80`, the difference between the
956
+ peak value and the smallest value will never be more than 80 dB. Must be greater than zero.
957
+
958
+ Returns:
959
+ `np.ndarray`: the batch of spectrograms in decibels
960
+ """
961
+ if reference <= 0.0:
962
+ raise ValueError("reference must be greater than zero")
963
+ if min_value <= 0.0:
964
+ raise ValueError("min_value must be greater than zero")
965
+
966
+ reference = max(min_value, reference)
967
+
968
+ spectrogram = np.clip(spectrogram, a_min=min_value, a_max=None)
969
+ spectrogram = 20.0 * (np.log10(spectrogram) - np.log10(reference))
970
+
971
+ if db_range is not None:
972
+ if db_range <= 0.0:
973
+ raise ValueError("db_range must be greater than zero")
974
+ # Apply db_range clipping per batch item
975
+ max_values = spectrogram.max(axis=(1, 2), keepdims=True)
976
+ spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
977
+
978
+ return spectrogram
979
+
980
+
981
+ ### deprecated functions below this line ###
982
+
983
+
984
+ def get_mel_filter_banks(
985
+ nb_frequency_bins: int,
986
+ nb_mel_filters: int,
987
+ frequency_min: float,
988
+ frequency_max: float,
989
+ sample_rate: int,
990
+ norm: Optional[str] = None,
991
+ mel_scale: str = "htk",
992
+ ) -> np.array:
993
+ warnings.warn(
994
+ "The function `get_mel_filter_banks` is deprecated and will be removed in version 4.31.0 of Transformers",
995
+ FutureWarning,
996
+ )
997
+ return mel_filter_bank(
998
+ num_frequency_bins=nb_frequency_bins,
999
+ num_mel_filters=nb_mel_filters,
1000
+ min_frequency=frequency_min,
1001
+ max_frequency=frequency_max,
1002
+ sampling_rate=sample_rate,
1003
+ norm=norm,
1004
+ mel_scale=mel_scale,
1005
+ )
1006
+
1007
+
1008
+ def fram_wave(waveform: np.array, hop_length: int = 160, fft_window_size: int = 400, center: bool = True):
1009
+ """
1010
+ In order to compute the short-time Fourier transform, the waveform needs to be split into overlapping windowed
1011
+ segments called `frames`.
1012
+
1013
+ The window length (window_length) defines how much of the signal is contained in each frame, while the hop length
1014
+ defines the step between the beginning of each new frame.
1015
+
1016
+
1017
+ Args:
1018
+ waveform (`np.array` of shape `(sample_length,)`):
1019
+ The raw waveform which will be split into smaller chunks.
1020
+ hop_length (`int`, *optional*, defaults to 160):
1021
+ Step between each window of the waveform.
1022
+ fft_window_size (`int`, *optional*, defaults to 400):
1023
+ Defines the size of the window.
1024
+ center (`bool`, defaults to `True`):
1025
+ Whether or not to center each frame around the middle of the frame. Centering is done by reflecting the
1026
+ waveform on the left and on the right.
1027
+
1028
+ Return:
1029
+ framed_waveform (`np.array` of shape `(waveform.shape // hop_length, fft_window_size)`):
1030
+ The framed waveforms that can be fed to `np.fft`.
1031
+ """
1032
+ warnings.warn(
1033
+ "The function `fram_wave` is deprecated and will be removed in version 4.31.0 of Transformers",
1034
+ FutureWarning,
1035
+ )
1036
+ frames = []
1037
+ for i in range(0, waveform.shape[0] + 1, hop_length):
1038
+ if center:
1039
+ half_window = (fft_window_size - 1) // 2 + 1
1040
+ start = i - half_window if i > half_window else 0
1041
+ end = i + half_window if i < waveform.shape[0] - half_window else waveform.shape[0]
1042
+ frame = waveform[start:end]
1043
+ if start == 0:
1044
+ padd_width = (-i + half_window, 0)
1045
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
1046
+
1047
+ elif end == waveform.shape[0]:
1048
+ padd_width = (0, (i - waveform.shape[0] + half_window))
1049
+ frame = np.pad(frame, pad_width=padd_width, mode="reflect")
1050
+
1051
+ else:
1052
+ frame = waveform[i : i + fft_window_size]
1053
+ frame_width = frame.shape[0]
1054
+ if frame_width < waveform.shape[0]:
1055
+ frame = np.lib.pad(
1056
+ frame, pad_width=(0, fft_window_size - frame_width), mode="constant", constant_values=0
1057
+ )
1058
+ frames.append(frame)
1059
+
1060
+ frames = np.stack(frames, 0)
1061
+ return frames
1062
+
1063
+
1064
+ def stft(frames: np.array, windowing_function: np.array, fft_window_size: int = None):
1065
+ """
1066
+ Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. Should give the same results
1067
+ as `torch.stft`.
1068
+
1069
+ Args:
1070
+ frames (`np.array` of dimension `(num_frames, fft_window_size)`):
1071
+ A framed audio signal obtained using `audio_utils.fram_wav`.
1072
+ windowing_function (`np.array` of dimension `(nb_frequency_bins, nb_mel_filters)`:
1073
+ A array representing the function that will be used to reduces the amplitude of the discontinuities at the
1074
+ boundaries of each frame when computing the STFT. Each frame will be multiplied by the windowing_function.
1075
+ For more information on the discontinuities, called *Spectral leakage*, refer to [this
1076
+ tutorial](https://download.ni.com/evaluation/pxi/Understanding%20FFTs%20and%20Windowing.pdf)
1077
+ fft_window_size (`int`, *optional*):
1078
+ Size of the window on which the Fourier transform is applied. This controls the frequency resolution of the
1079
+ spectrogram. 400 means that the Fourier transform is computed on windows of 400 samples. The number of
1080
+ frequency bins (`nb_frequency_bins`) used to divide the window into equal strips is equal to
1081
+ `(fft_window_size // 2) + 1`. Increasing fft_window_size slows the computation proportionally.
1082
+
1083
+ Example:
1084
+
1085
+ ```python
1086
+ >>> from transformers.audio_utils import stft, fram_wave
1087
+ >>> import numpy as np
1088
+
1089
+ >>> audio = np.random.rand(50)
1090
+ >>> fft_window_size = 10
1091
+ >>> hop_length = 2
1092
+ >>> framed_audio = fram_wave(audio, hop_length, fft_window_size)
1093
+ >>> spectrogram = stft(framed_audio, np.hanning(fft_window_size + 1))
1094
+ ```
1095
+
1096
+ Returns:
1097
+ spectrogram (`np.ndarray`):
1098
+ A spectrogram of shape `(num_frames, nb_frequency_bins)` obtained using the STFT algorithm
1099
+ """
1100
+ warnings.warn(
1101
+ "The function `stft` is deprecated and will be removed in version 4.31.0 of Transformers",
1102
+ FutureWarning,
1103
+ )
1104
+ frame_size = frames.shape[1]
1105
+
1106
+ if fft_window_size is None:
1107
+ fft_window_size = frame_size
1108
+
1109
+ if fft_window_size < frame_size:
1110
+ raise ValueError("FFT size must be greater than or equal to the frame size")
1111
+ # number of FFT bins to store
1112
+ nb_frequency_bins = (fft_window_size >> 1) + 1
1113
+
1114
+ spectrogram = np.empty((len(frames), nb_frequency_bins), dtype=np.complex64)
1115
+ fft_signal = np.zeros(fft_window_size)
1116
+
1117
+ for f, frame in enumerate(frames):
1118
+ if windowing_function is not None:
1119
+ np.multiply(frame, windowing_function, out=fft_signal[:frame_size])
1120
+ else:
1121
+ fft_signal[:frame_size] = frame
1122
+ spectrogram[f] = np.fft.fft(fft_signal, axis=0)[:nb_frequency_bins]
1123
+ return spectrogram.T
base_tokenizer.cpython-312.pyc ADDED
Binary file (19.4 kB). View file
 
base_tokenizer.py ADDED
@@ -0,0 +1,418 @@
1
+ from typing import Dict, List, Optional, Tuple, Union
2
+
3
+ from tokenizers import AddedToken, EncodeInput, Encoding, InputSequence, Tokenizer
4
+ from tokenizers.decoders import Decoder
5
+ from tokenizers.models import Model
6
+ from tokenizers.normalizers import Normalizer
7
+ from tokenizers.pre_tokenizers import PreTokenizer
8
+ from tokenizers.processors import PostProcessor
9
+
10
+
11
+ Offsets = Tuple[int, int]
12
+
13
+
14
+ class BaseTokenizer:
15
+ def __init__(self, tokenizer: Tokenizer, parameters=None):
16
+ self._tokenizer = tokenizer
17
+ self._parameters = parameters if parameters is not None else {}
18
+
19
+ def __repr__(self):
20
+ return "Tokenizer(vocabulary_size={}, {})".format(
21
+ self._tokenizer.get_vocab_size(),
22
+ ", ".join(k + "=" + str(v) for k, v in self._parameters.items()),
23
+ )
24
+
25
+ def num_special_tokens_to_add(self, is_pair: bool) -> int:
26
+ """
27
+ Return the number of special tokens that would be added for single/pair sentences.
28
+ :param is_pair: Boolean indicating if the input would be a single sentence or a pair
29
+ :return: The number of special tokens that would be added
30
+ """
31
+ return self._tokenizer.num_special_tokens_to_add(is_pair)
32
+
33
+ def get_vocab(self, with_added_tokens: bool = True) -> Dict[str, int]:
34
+ """Returns the vocabulary
35
+
36
+ Args:
37
+ with_added_tokens: boolean:
38
+ Whether to include the added tokens in the vocabulary
39
+
40
+ Returns:
41
+ The vocabulary
42
+ """
43
+ return self._tokenizer.get_vocab(with_added_tokens=with_added_tokens)
44
+
45
+ def get_added_tokens_decoder(self) -> Dict[int, AddedToken]:
46
+ """Returns the added reverse vocabulary
47
+
48
+ Returns:
49
+ The added vocabulary mapping ints to AddedTokens
50
+ """
51
+ return self._tokenizer.get_added_tokens_decoder()
52
+
53
+ def get_vocab_size(self, with_added_tokens: bool = True) -> int:
54
+ """Return the size of vocabulary, with or without added tokens.
55
+
56
+ Args:
57
+ with_added_tokens: (`optional`) bool:
58
+ Whether or not to include the added tokens in the count
59
+
60
+ Returns:
61
+ Size of vocabulary
62
+ """
63
+ return self._tokenizer.get_vocab_size(with_added_tokens=with_added_tokens)
64
+
65
+ def enable_padding(
66
+ self,
67
+ direction: Optional[str] = "right",
68
+ pad_to_multiple_of: Optional[int] = None,
69
+ pad_id: Optional[int] = 0,
70
+ pad_type_id: Optional[int] = 0,
71
+ pad_token: Optional[str] = "[PAD]",
72
+ length: Optional[int] = None,
73
+ ):
74
+ """Change the padding strategy
75
+
76
+ Args:
77
+ direction: (`optional`) str:
78
+ Can be one of: `right` or `left`
79
+
80
+ pad_to_multiple_of: (`optional`) unsigned int:
81
+ If specified, the padding length should always snap to the next multiple of
82
+ the given value. For example if we were going to pad with a length of 250 but
83
+ `pad_to_multiple_of=8` then we will pad to 256.
84
+
85
+ pad_id: (`optional`) unsigned int:
86
+ The index to be used when padding
87
+
88
+ pad_type_id: (`optional`) unsigned int:
89
+ The type index to be used when padding
90
+
91
+ pad_token: (`optional`) str:
92
+ The pad token to be used when padding
93
+
94
+ length: (`optional`) unsigned int:
95
+ If specified, the length at which to pad. If not specified
96
+ we pad using the size of the longest sequence in a batch
97
+ """
98
+ return self._tokenizer.enable_padding(
99
+ direction=direction,
100
+ pad_to_multiple_of=pad_to_multiple_of,
101
+ pad_id=pad_id,
102
+ pad_type_id=pad_type_id,
103
+ pad_token=pad_token,
104
+ length=length,
105
+ )
106
+
107
+ def no_padding(self):
108
+ """Disable padding"""
109
+ return self._tokenizer.no_padding()
110
+
111
+ @property
112
+ def padding(self) -> Optional[dict]:
113
+ """Get the current padding parameters
114
+
115
+ Returns:
116
+ None if padding is disabled, a dict with the currently set parameters
117
+ if the padding is enabled.
118
+ """
119
+ return self._tokenizer.padding
120
+
121
+ def enable_truncation(self, max_length: int, stride: Optional[int] = 0, strategy: Optional[str] = "longest_first"):
122
+ """Change the truncation options
123
+
124
+ Args:
125
+ max_length: unsigned int:
126
+ The maximum length at which to truncate
127
+
128
+ stride: (`optional`) unsigned int:
129
+ The length of the previous first sequence to be included
130
+ in the overflowing sequence
131
+
132
+ strategy: (`optional`) str:
133
+ Can be one of `longest_first`, `only_first` or `only_second`
134
+ """
135
+ return self._tokenizer.enable_truncation(max_length, stride=stride, strategy=strategy)
136
+
137
+ def no_truncation(self):
138
+ """Disable truncation"""
139
+ return self._tokenizer.no_truncation()
140
+
141
+ @property
142
+ def truncation(self) -> Optional[dict]:
143
+ """Get the current truncation parameters
144
+
145
+ Returns:
146
+ None if truncation is disabled, a dict with the current truncation parameters if
147
+ truncation is enabled
148
+ """
149
+ return self._tokenizer.truncation
150
+
151
+ def add_tokens(self, tokens: List[Union[str, AddedToken]]) -> int:
152
+ """Add the given tokens to the vocabulary
153
+
154
+ Args:
155
+ tokens: List[Union[str, AddedToken]]:
156
+ A list of tokens to add to the vocabulary. Each token can either be
157
+ a string, or an instance of AddedToken
158
+
159
+ Returns:
160
+ The number of tokens that were added to the vocabulary
161
+ """
162
+ return self._tokenizer.add_tokens(tokens)
163
+
164
+ def add_special_tokens(self, special_tokens: List[Union[str, AddedToken]]) -> int:
165
+ """Add the given special tokens to the vocabulary, and treat them as special tokens.
166
+
167
+ The special tokens will never be processed by the model, and will be
168
+ removed while decoding.
169
+
170
+ Args:
171
+ special_tokens: List[Union[str, AddedToken]]:
172
+ A list of special tokens to add to the vocabulary. Each token can either be
173
+ a string, or an instance of AddedToken
174
+
175
+ Returns:
176
+ The number of tokens that were added to the vocabulary
177
+ """
178
+ return self._tokenizer.add_special_tokens(special_tokens)
179
+
180
+ def normalize(self, sequence: str) -> str:
181
+ """Normalize the given sequence
182
+
183
+ Args:
184
+ sequence: str:
185
+ The sequence to normalize
186
+
187
+ Returns:
188
+ The normalized string
189
+ """
190
+ return self._tokenizer.normalize(sequence)
191
+
192
+ def encode(
193
+ self,
194
+ sequence: InputSequence,
195
+ pair: Optional[InputSequence] = None,
196
+ is_pretokenized: bool = False,
197
+ add_special_tokens: bool = True,
198
+ ) -> Encoding:
199
+ """Encode the given sequence and pair. This method can process raw text sequences as well
200
+ as already pre-tokenized sequences.
201
+
202
+ Args:
203
+ sequence: InputSequence:
204
+ The sequence we want to encode. This sequence can be either raw text or
205
+ pre-tokenized, according to the `is_pretokenized` argument:
206
+
207
+ - If `is_pretokenized=False`: `InputSequence` is expected to be `str`
208
+ - If `is_pretokenized=True`: `InputSequence` is expected to be
209
+ `Union[List[str], Tuple[str]]`
210
+
211
+ is_pretokenized: bool:
212
+ Whether the input is already pre-tokenized.
213
+
214
+ add_special_tokens: bool:
215
+ Whether to add the special tokens while encoding.
216
+
217
+ Returns:
218
+ An Encoding
219
+ """
220
+ if sequence is None:
221
+ raise ValueError("encode: `sequence` can't be `None`")
222
+
223
+ return self._tokenizer.encode(sequence, pair, is_pretokenized, add_special_tokens)
224
+
225
+ def encode_batch(
226
+ self,
227
+ inputs: List[EncodeInput],
228
+ is_pretokenized: bool = False,
229
+ add_special_tokens: bool = True,
230
+ ) -> List[Encoding]:
231
+ """Encode the given inputs. This method accepts both raw text sequences and already
232
+ pre-tokenized sequences.
233
+
234
+ Args:
235
+ inputs: List[EncodeInput]:
236
+ A list of single sequences or pair sequences to encode. Each `EncodeInput` is
237
+ expected to be of the following form:
238
+ `Union[InputSequence, Tuple[InputSequence, InputSequence]]`
239
+
240
+ Each `InputSequence` can either be raw text or pre-tokenized,
241
+ according to the `is_pretokenized` argument:
242
+
243
+ - If `is_pretokenized=False`: `InputSequence` is expected to be `str`
244
+ - If `is_pretokenized=True`: `InputSequence` is expected to be
245
+ `Union[List[str], Tuple[str]]`
246
+
247
+ is_pretokenized: bool:
248
+ Whether the input is already pre-tokenized.
249
+
250
+ add_special_tokens: bool:
251
+ Whether to add the special tokens while encoding.
252
+
253
+ Returns:
254
+ A list of Encoding
255
+ """
256
+
257
+ if inputs is None:
258
+ raise ValueError("encode_batch: `inputs` can't be `None`")
259
+
260
+ return self._tokenizer.encode_batch(inputs, is_pretokenized, add_special_tokens)
261
+
262
+ def decode(self, ids: List[int], skip_special_tokens: Optional[bool] = True) -> str:
263
+ """Decode the given list of ids to a string sequence
264
+
265
+ Args:
266
+ ids: List[unsigned int]:
267
+ A list of ids to be decoded
268
+
269
+ skip_special_tokens: (`optional`) boolean:
270
+ Whether to remove all the special tokens from the output string
271
+
272
+ Returns:
273
+ The decoded string
274
+ """
275
+ if ids is None:
276
+ raise ValueError("None input is not valid. Should be a list of integers.")
277
+
278
+ return self._tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)
279
+
280
+ def decode_batch(self, sequences: List[List[int]], skip_special_tokens: Optional[bool] = True) -> List[str]:
281
+ """Decode the list of sequences to a list of string sequences
282
+
283
+ Args:
284
+ sequences: List[List[unsigned int]]:
285
+ A list of sequence of ids to be decoded
286
+
287
+ skip_special_tokens: (`optional`) boolean:
288
+ Whether to remove all the special tokens from the output strings
289
+
290
+ Returns:
291
+ A list of decoded strings
292
+ """
293
+ if sequences is None:
294
+ raise ValueError("None input is not valid. Should be list of list of integers.")
295
+
296
+ return self._tokenizer.decode_batch(sequences, skip_special_tokens=skip_special_tokens)
297
+
298
+ def token_to_id(self, token: str) -> Optional[int]:
299
+ """Convert the given token to its corresponding id
300
+
301
+ Args:
302
+ token: str:
303
+ The token to convert
304
+
305
+ Returns:
306
+ The corresponding id if it exists, None otherwise
307
+ """
308
+ return self._tokenizer.token_to_id(token)
309
+
310
+ def id_to_token(self, id: int) -> Optional[str]:
311
+ """Convert the given token id to its corresponding string
312
+
313
+ Args:
314
+ id: int:
315
+ The token id to convert
316
+
317
+ Returns:
318
+ The corresponding string if it exists, None otherwise
319
+ """
320
+ return self._tokenizer.id_to_token(id)
321
+
322
+ def save_model(self, directory: str, prefix: Optional[str] = None):
323
+ """Save the current model to the given directory
324
+
325
+ Args:
326
+ directory: str:
327
+ A path to the destination directory
328
+
329
+ prefix: (Optional) str:
330
+ An optional prefix, used to prefix each file name
331
+ """
332
+ return self._tokenizer.model.save(directory, prefix=prefix)
333
+
334
+ def save(self, path: str, pretty: bool = True):
335
+ """Save the current Tokenizer at the given path
336
+
337
+ Args:
338
+ path: str:
339
+ A path to the destination Tokenizer file
340
+ """
341
+ return self._tokenizer.save(path, pretty)
342
+
343
+ def to_str(self, pretty: bool = False):
344
+ """Get a serialized JSON version of the Tokenizer as a str
345
+
346
+ Args:
347
+ pretty: bool:
348
+ Whether the JSON string should be prettified
349
+
350
+ Returns:
351
+ str
352
+ """
353
+ return self._tokenizer.to_str(pretty)
354
+
355
+ def post_process(
356
+ self, encoding: Encoding, pair: Optional[Encoding] = None, add_special_tokens: bool = True
357
+ ) -> Encoding:
358
+ """Apply all the post-processing steps to the given encodings.
359
+
360
+ The various steps are:
361
+ 1. Truncate according to global params (provided to `enable_truncation`)
362
+ 2. Apply the PostProcessor
363
+ 3. Pad according to global params. (provided to `enable_padding`)
364
+
365
+ Args:
366
+ encoding: Encoding:
367
+ The main Encoding to post process
368
+
369
+ pair: Optional[Encoding]:
370
+ An optional pair Encoding
371
+
372
+ add_special_tokens: bool:
373
+ Whether to add special tokens
374
+
375
+ Returns:
376
+ The resulting Encoding
377
+ """
378
+ return self._tokenizer.post_process(encoding, pair, add_special_tokens)
379
+
380
+ @property
381
+ def model(self) -> Model:
382
+ return self._tokenizer.model
383
+
384
+ @model.setter
385
+ def model(self, model: Model):
386
+ self._tokenizer.model = model
387
+
388
+ @property
389
+ def normalizer(self) -> Normalizer:
390
+ return self._tokenizer.normalizer
391
+
392
+ @normalizer.setter
393
+ def normalizer(self, normalizer: Normalizer):
394
+ self._tokenizer.normalizer = normalizer
395
+
396
+ @property
397
+ def pre_tokenizer(self) -> PreTokenizer:
398
+ return self._tokenizer.pre_tokenizer
399
+
400
+ @pre_tokenizer.setter
401
+ def pre_tokenizer(self, pre_tokenizer: PreTokenizer):
402
+ self._tokenizer.pre_tokenizer = pre_tokenizer
403
+
404
+ @property
405
+ def post_processor(self) -> PostProcessor:
406
+ return self._tokenizer.post_processor
407
+
408
+ @post_processor.setter
409
+ def post_processor(self, post_processor: PostProcessor):
410
+ self._tokenizer.post_processor = post_processor
411
+
412
+ @property
413
+ def decoder(self) -> Decoder:
414
+ return self._tokenizer.decoder
415
+
416
+ @decoder.setter
417
+ def decoder(self, decoder: Decoder):
418
+ self._tokenizer.decoder = decoder
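# Editor's note: an illustrative sketch, not part of the uploaded file. BaseTokenizer is a thin
# wrapper around `tokenizers.Tokenizer`; the concrete classes in the following files
# (BertWordPieceTokenizer, ByteLevelBPETokenizer, CharBPETokenizer) inherit the
# encode/decode/padding/truncation helpers exercised here. The toy vocabulary and text are
# assumptions made purely for illustration.
from tokenizers import Tokenizer
from tokenizers.implementations import BaseTokenizer  # the class defined above, as exposed upstream
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace

tok = Tokenizer(WordLevel({"[UNK]": 0, "hello": 1, "world": 2, "[PAD]": 3}, unk_token="[UNK]"))
tok.pre_tokenizer = Whitespace()
wrapper = BaseTokenizer(tok, parameters={"model": "WordLevel"})

wrapper.add_special_tokens(["[PAD]"])   # mark [PAD] as special so decode() can drop it
wrapper.enable_truncation(max_length=8)
wrapper.enable_padding(pad_id=3, pad_token="[PAD]", length=8)

encoding = wrapper.encode("hello world")
print(encoding.ids)                  # [1, 2, 3, 3, 3, 3, 3, 3]
print(wrapper.decode(encoding.ids))  # 'hello world'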
bert_wordpiece.cpython-312.pyc ADDED
Binary file (5.83 kB). View file
 
bert_wordpiece.py ADDED
@@ -0,0 +1,151 @@
1
+ from typing import Dict, Iterator, List, Optional, Union
2
+
3
+ from tokenizers import AddedToken, Tokenizer, decoders, trainers
4
+ from tokenizers.models import WordPiece
5
+ from tokenizers.normalizers import BertNormalizer
6
+ from tokenizers.pre_tokenizers import BertPreTokenizer
7
+ from tokenizers.processors import BertProcessing
8
+
9
+ from .base_tokenizer import BaseTokenizer
10
+
11
+
12
+ class BertWordPieceTokenizer(BaseTokenizer):
13
+ """Bert WordPiece Tokenizer"""
14
+
15
+ def __init__(
16
+ self,
17
+ vocab: Optional[Union[str, Dict[str, int]]] = None,
18
+ unk_token: Union[str, AddedToken] = "[UNK]",
19
+ sep_token: Union[str, AddedToken] = "[SEP]",
20
+ cls_token: Union[str, AddedToken] = "[CLS]",
21
+ pad_token: Union[str, AddedToken] = "[PAD]",
22
+ mask_token: Union[str, AddedToken] = "[MASK]",
23
+ clean_text: bool = True,
24
+ handle_chinese_chars: bool = True,
25
+ strip_accents: Optional[bool] = None,
26
+ lowercase: bool = True,
27
+ wordpieces_prefix: str = "##",
28
+ ):
29
+ if vocab is not None:
30
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(unk_token)))
31
+ else:
32
+ tokenizer = Tokenizer(WordPiece(unk_token=str(unk_token)))
33
+
34
+ # Let the tokenizer know about special tokens if they are part of the vocab
35
+ if tokenizer.token_to_id(str(unk_token)) is not None:
36
+ tokenizer.add_special_tokens([str(unk_token)])
37
+ if tokenizer.token_to_id(str(sep_token)) is not None:
38
+ tokenizer.add_special_tokens([str(sep_token)])
39
+ if tokenizer.token_to_id(str(cls_token)) is not None:
40
+ tokenizer.add_special_tokens([str(cls_token)])
41
+ if tokenizer.token_to_id(str(pad_token)) is not None:
42
+ tokenizer.add_special_tokens([str(pad_token)])
43
+ if tokenizer.token_to_id(str(mask_token)) is not None:
44
+ tokenizer.add_special_tokens([str(mask_token)])
45
+
46
+ tokenizer.normalizer = BertNormalizer(
47
+ clean_text=clean_text,
48
+ handle_chinese_chars=handle_chinese_chars,
49
+ strip_accents=strip_accents,
50
+ lowercase=lowercase,
51
+ )
52
+ tokenizer.pre_tokenizer = BertPreTokenizer()
53
+
54
+ if vocab is not None:
55
+ sep_token_id = tokenizer.token_to_id(str(sep_token))
56
+ if sep_token_id is None:
57
+ raise TypeError("sep_token not found in the vocabulary")
58
+ cls_token_id = tokenizer.token_to_id(str(cls_token))
59
+ if cls_token_id is None:
60
+ raise TypeError("cls_token not found in the vocabulary")
61
+
62
+ tokenizer.post_processor = BertProcessing((str(sep_token), sep_token_id), (str(cls_token), cls_token_id))
63
+ tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
64
+
65
+ parameters = {
66
+ "model": "BertWordPiece",
67
+ "unk_token": unk_token,
68
+ "sep_token": sep_token,
69
+ "cls_token": cls_token,
70
+ "pad_token": pad_token,
71
+ "mask_token": mask_token,
72
+ "clean_text": clean_text,
73
+ "handle_chinese_chars": handle_chinese_chars,
74
+ "strip_accents": strip_accents,
75
+ "lowercase": lowercase,
76
+ "wordpieces_prefix": wordpieces_prefix,
77
+ }
78
+
79
+ super().__init__(tokenizer, parameters)
80
+
81
+ @staticmethod
82
+ def from_file(vocab: str, **kwargs):
83
+ vocab = WordPiece.read_file(vocab)
84
+ return BertWordPieceTokenizer(vocab, **kwargs)
85
+
86
+ def train(
87
+ self,
88
+ files: Union[str, List[str]],
89
+ vocab_size: int = 30000,
90
+ min_frequency: int = 2,
91
+ limit_alphabet: int = 1000,
92
+ initial_alphabet: List[str] = [],
93
+ special_tokens: List[Union[str, AddedToken]] = [
94
+ "[PAD]",
95
+ "[UNK]",
96
+ "[CLS]",
97
+ "[SEP]",
98
+ "[MASK]",
99
+ ],
100
+ show_progress: bool = True,
101
+ wordpieces_prefix: str = "##",
102
+ ):
103
+ """Train the model using the given files"""
104
+
105
+ trainer = trainers.WordPieceTrainer(
106
+ vocab_size=vocab_size,
107
+ min_frequency=min_frequency,
108
+ limit_alphabet=limit_alphabet,
109
+ initial_alphabet=initial_alphabet,
110
+ special_tokens=special_tokens,
111
+ show_progress=show_progress,
112
+ continuing_subword_prefix=wordpieces_prefix,
113
+ )
114
+ if isinstance(files, str):
115
+ files = [files]
116
+ self._tokenizer.train(files, trainer=trainer)
117
+
118
+ def train_from_iterator(
119
+ self,
120
+ iterator: Union[Iterator[str], Iterator[Iterator[str]]],
121
+ vocab_size: int = 30000,
122
+ min_frequency: int = 2,
123
+ limit_alphabet: int = 1000,
124
+ initial_alphabet: List[str] = [],
125
+ special_tokens: List[Union[str, AddedToken]] = [
126
+ "[PAD]",
127
+ "[UNK]",
128
+ "[CLS]",
129
+ "[SEP]",
130
+ "[MASK]",
131
+ ],
132
+ show_progress: bool = True,
133
+ wordpieces_prefix: str = "##",
134
+ length: Optional[int] = None,
135
+ ):
136
+ """Train the model using the given iterator"""
137
+
138
+ trainer = trainers.WordPieceTrainer(
139
+ vocab_size=vocab_size,
140
+ min_frequency=min_frequency,
141
+ limit_alphabet=limit_alphabet,
142
+ initial_alphabet=initial_alphabet,
143
+ special_tokens=special_tokens,
144
+ show_progress=show_progress,
145
+ continuing_subword_prefix=wordpieces_prefix,
146
+ )
147
+ self._tokenizer.train_from_iterator(
148
+ iterator,
149
+ trainer=trainer,
150
+ length=length,
151
+ )
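# Editor's note: an illustrative usage sketch, not part of the uploaded file. It trains the
# tokenizer defined above from scratch and then reloads it from the saved vocabulary; the
# [CLS]/[SEP] post-processor is only attached when a vocab is provided (see __init__ above).
# "corpus.txt" is a hypothetical plain-text file, one sentence per line.
from tokenizers import BertWordPieceTokenizer  # the class defined above, as exposed upstream

tokenizer = BertWordPieceTokenizer(lowercase=True)
tokenizer.train(files="corpus.txt", vocab_size=30000, min_frequency=2)
tokenizer.save_model(".")  # writes ./vocab.txt

tokenizer = BertWordPieceTokenizer.from_file("vocab.txt", lowercase=True)
encoding = tokenizer.encode("How are you?", "I am fine.")
print(encoding.tokens)     # starts with '[CLS]'; '[SEP]' separates the two sentences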
byte_level_bpe.cpython-312.pyc ADDED
Binary file (4.85 kB). View file
 
byte_level_bpe.py ADDED
@@ -0,0 +1,122 @@
1
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
2
+
3
+ from tokenizers import AddedToken, Tokenizer, decoders, pre_tokenizers, processors, trainers
4
+ from tokenizers.models import BPE
5
+ from tokenizers.normalizers import Lowercase, Sequence, unicode_normalizer_from_str
6
+
7
+ from .base_tokenizer import BaseTokenizer
8
+
9
+
10
+ class ByteLevelBPETokenizer(BaseTokenizer):
11
+ """ByteLevelBPETokenizer
12
+
13
+ Represents a Byte-level BPE as introduced by OpenAI with their GPT-2 model
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ vocab: Optional[Union[str, Dict[str, int]]] = None,
19
+ merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]] = None,
20
+ add_prefix_space: bool = False,
21
+ lowercase: bool = False,
22
+ dropout: Optional[float] = None,
23
+ unicode_normalizer: Optional[str] = None,
24
+ continuing_subword_prefix: Optional[str] = None,
25
+ end_of_word_suffix: Optional[str] = None,
26
+ trim_offsets: bool = False,
27
+ ):
28
+ if vocab is not None and merges is not None:
29
+ tokenizer = Tokenizer(
30
+ BPE(
31
+ vocab,
32
+ merges,
33
+ dropout=dropout,
34
+ continuing_subword_prefix=continuing_subword_prefix or "",
35
+ end_of_word_suffix=end_of_word_suffix or "",
36
+ )
37
+ )
38
+ else:
39
+ tokenizer = Tokenizer(BPE())
40
+
41
+ # Check for Unicode normalization first (before everything else)
42
+ normalizers = []
43
+
44
+ if unicode_normalizer:
45
+ normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
46
+
47
+ if lowercase:
48
+ normalizers += [Lowercase()]
49
+
50
+ # Create the normalizer structure
51
+ if len(normalizers) > 0:
52
+ if len(normalizers) > 1:
53
+ tokenizer.normalizer = Sequence(normalizers)
54
+ else:
55
+ tokenizer.normalizer = normalizers[0]
56
+
57
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
58
+ tokenizer.decoder = decoders.ByteLevel()
59
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=trim_offsets)
60
+
61
+ parameters = {
62
+ "model": "ByteLevelBPE",
63
+ "add_prefix_space": add_prefix_space,
64
+ "lowercase": lowercase,
65
+ "dropout": dropout,
66
+ "unicode_normalizer": unicode_normalizer,
67
+ "continuing_subword_prefix": continuing_subword_prefix,
68
+ "end_of_word_suffix": end_of_word_suffix,
69
+ "trim_offsets": trim_offsets,
70
+ }
71
+
72
+ super().__init__(tokenizer, parameters)
73
+
74
+ @staticmethod
75
+ def from_file(vocab_filename: str, merges_filename: str, **kwargs):
76
+ vocab, merges = BPE.read_file(vocab_filename, merges_filename)
77
+ return ByteLevelBPETokenizer(vocab, merges, **kwargs)
78
+
79
+ def train(
80
+ self,
81
+ files: Union[str, List[str]],
82
+ vocab_size: int = 30000,
83
+ min_frequency: int = 2,
84
+ show_progress: bool = True,
85
+ special_tokens: List[Union[str, AddedToken]] = [],
86
+ ):
87
+ """Train the model using the given files"""
88
+
89
+ trainer = trainers.BpeTrainer(
90
+ vocab_size=vocab_size,
91
+ min_frequency=min_frequency,
92
+ show_progress=show_progress,
93
+ special_tokens=special_tokens,
94
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
95
+ )
96
+ if isinstance(files, str):
97
+ files = [files]
98
+ self._tokenizer.train(files, trainer=trainer)
99
+
100
+ def train_from_iterator(
101
+ self,
102
+ iterator: Union[Iterator[str], Iterator[Iterator[str]]],
103
+ vocab_size: int = 30000,
104
+ min_frequency: int = 2,
105
+ show_progress: bool = True,
106
+ special_tokens: List[Union[str, AddedToken]] = [],
107
+ length: Optional[int] = None,
108
+ ):
109
+ """Train the model using the given iterator"""
110
+
111
+ trainer = trainers.BpeTrainer(
112
+ vocab_size=vocab_size,
113
+ min_frequency=min_frequency,
114
+ show_progress=show_progress,
115
+ special_tokens=special_tokens,
116
+ initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
117
+ )
118
+ self._tokenizer.train_from_iterator(
119
+ iterator,
120
+ trainer=trainer,
121
+ length=length,
122
+ )
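# Editor's note: an illustrative usage sketch, not part of the uploaded file. It trains the
# GPT-2-style byte-level BPE defined above on a tiny in-memory corpus via train_from_iterator.
# The corpus, vocabulary size, and special token are assumptions chosen for illustration.
from tokenizers import ByteLevelBPETokenizer  # the class defined above, as exposed upstream

corpus = ["low lower lowest", "new newer newest"]
tokenizer = ByteLevelBPETokenizer(add_prefix_space=True)
tokenizer.train_from_iterator(corpus, vocab_size=1000, min_frequency=1, special_tokens=["<|endoftext|>"])

encoding = tokenizer.encode("lower newest")
print(encoding.tokens)                 # byte-level tokens such as 'Ġlower' once the merges cover the words
print(tokenizer.decode(encoding.ids))  # recovers the text (up to the prefix space added before encoding)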
cache_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
char_level_bpe.cpython-312.pyc ADDED
Binary file (5.84 kB). View file
 
char_level_bpe.py ADDED
@@ -0,0 +1,150 @@
1
+ from typing import Dict, Iterator, List, Optional, Tuple, Union
2
+
3
+ from .. import AddedToken, Tokenizer, decoders, pre_tokenizers, trainers
4
+ from ..models import BPE
5
+ from ..normalizers import BertNormalizer, Lowercase, Sequence, unicode_normalizer_from_str
6
+ from .base_tokenizer import BaseTokenizer
7
+
8
+
9
+ class CharBPETokenizer(BaseTokenizer):
10
+ """Original BPE Tokenizer
11
+
12
+ Represents the BPE algorithm, as introduced by Rico Sennrich
13
+ (https://arxiv.org/abs/1508.07909)
14
+
15
+ The default settings correspond to the OpenAI GPT BPE tokenizer and differ from the original
16
+ Sennrich subword-nmt implementation by the following options that you can deactivate:
17
+ - adding a normalizer to clean up the text (deactivate with `bert_normalizer=False`) by:
18
+ * removing any control characters and replacing all whitespace characters with the classic space.
19
+ * handling Chinese chars by putting spaces around them.
20
+ * stripping all accents.
21
+ - splitting on punctuation in addition to whitespaces (deactivate it with
22
+ `split_on_whitespace_only=True`)
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ vocab: Optional[Union[str, Dict[str, int]]] = None,
28
+ merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]] = None,
29
+ unk_token: Union[str, AddedToken] = "<unk>",
30
+ suffix: str = "</w>",
31
+ dropout: Optional[float] = None,
32
+ lowercase: bool = False,
33
+ unicode_normalizer: Optional[str] = None,
34
+ bert_normalizer: bool = True,
35
+ split_on_whitespace_only: bool = False,
36
+ ):
37
+ if vocab is not None and merges is not None:
38
+ tokenizer = Tokenizer(
39
+ BPE(
40
+ vocab,
41
+ merges,
42
+ dropout=dropout,
43
+ unk_token=str(unk_token),
44
+ end_of_word_suffix=suffix,
45
+ )
46
+ )
47
+ else:
48
+ tokenizer = Tokenizer(BPE(unk_token=str(unk_token), dropout=dropout, end_of_word_suffix=suffix))
49
+
50
+ if tokenizer.token_to_id(str(unk_token)) is not None:
51
+ tokenizer.add_special_tokens([str(unk_token)])
52
+
53
+ # Check for Unicode normalization first (before everything else)
54
+ normalizers = []
55
+
56
+ if unicode_normalizer:
57
+ normalizers += [unicode_normalizer_from_str(unicode_normalizer)]
58
+
59
+ if bert_normalizer:
60
+ normalizers += [BertNormalizer(lowercase=False)]
61
+
62
+ if lowercase:
63
+ normalizers += [Lowercase()]
64
+
65
+ # Create the normalizer structure
66
+ if len(normalizers) > 0:
67
+ if len(normalizers) > 1:
68
+ tokenizer.normalizer = Sequence(normalizers)
69
+ else:
70
+ tokenizer.normalizer = normalizers[0]
71
+
72
+ if split_on_whitespace_only:
73
+ tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
74
+ else:
75
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
76
+
77
+ tokenizer.decoder = decoders.BPEDecoder(suffix=suffix)
78
+
79
+ parameters = {
80
+ "model": "BPE",
81
+ "unk_token": unk_token,
82
+ "suffix": suffix,
83
+ "dropout": dropout,
84
+ "lowercase": lowercase,
85
+ "unicode_normalizer": unicode_normalizer,
86
+ "bert_normalizer": bert_normalizer,
87
+ "split_on_whitespace_only": split_on_whitespace_only,
88
+ }
89
+
90
+ super().__init__(tokenizer, parameters)
91
+
92
+ @staticmethod
93
+ def from_file(vocab_filename: str, merges_filename: str, **kwargs):
94
+ vocab, merges = BPE.read_file(vocab_filename, merges_filename)
95
+ return CharBPETokenizer(vocab, merges, **kwargs)
96
+
97
+ def train(
98
+ self,
99
+ files: Union[str, List[str]],
100
+ vocab_size: int = 30000,
101
+ min_frequency: int = 2,
102
+ special_tokens: List[Union[str, AddedToken]] = ["<unk>"],
103
+ limit_alphabet: int = 1000,
104
+ initial_alphabet: List[str] = [],
105
+ suffix: Optional[str] = "</w>",
106
+ show_progress: bool = True,
107
+ ):
108
+ """Train the model using the given files"""
109
+
110
+ trainer = trainers.BpeTrainer(
111
+ vocab_size=vocab_size,
112
+ min_frequency=min_frequency,
113
+ special_tokens=special_tokens,
114
+ limit_alphabet=limit_alphabet,
115
+ initial_alphabet=initial_alphabet,
116
+ end_of_word_suffix=suffix,
117
+ show_progress=show_progress,
118
+ )
119
+ if isinstance(files, str):
120
+ files = [files]
121
+ self._tokenizer.train(files, trainer=trainer)
122
+
123
+ def train_from_iterator(
124
+ self,
125
+ iterator: Union[Iterator[str], Iterator[Iterator[str]]],
126
+ vocab_size: int = 30000,
127
+ min_frequency: int = 2,
128
+ special_tokens: List[Union[str, AddedToken]] = ["<unk>"],
129
+ limit_alphabet: int = 1000,
130
+ initial_alphabet: List[str] = [],
131
+ suffix: Optional[str] = "</w>",
132
+ show_progress: bool = True,
133
+ length: Optional[int] = None,
134
+ ):
135
+ """Train the model using the given iterator"""
136
+
137
+ trainer = trainers.BpeTrainer(
138
+ vocab_size=vocab_size,
139
+ min_frequency=min_frequency,
140
+ special_tokens=special_tokens,
141
+ limit_alphabet=limit_alphabet,
142
+ initial_alphabet=initial_alphabet,
143
+ end_of_word_suffix=suffix,
144
+ show_progress=show_progress,
145
+ )
146
+ self._tokenizer.train_from_iterator(
147
+ iterator,
148
+ trainer=trainer,
149
+ length=length,
150
+ )
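For context, a minimal usage sketch of the `CharBPETokenizer` above (not part of the uploaded file; `corpus.txt` is a placeholder path):
# Minimal sketch, assuming the tokenizers package exposes CharBPETokenizer at the top level.
from tokenizers import CharBPETokenizer

tokenizer = CharBPETokenizer(lowercase=True)
# Train on a plain-text file; "corpus.txt" is a hypothetical path.
tokenizer.train(files=["corpus.txt"], vocab_size=30000, min_frequency=2, special_tokens=["<unk>"])
print(tokenizer.encode("Hello world").tokens)  # subwords carry the "</w>" end-of-word suffix
tokenizer.save("char_bpe.json")  # full serialization through the BaseTokenizer wrapper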
configuration_utils.py ADDED
@@ -0,0 +1,1187 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Configuration base class and utilities."""
17
+
18
+ import copy
19
+ import json
20
+ import os
21
+ import re
22
+ import warnings
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+
25
+ from packaging import version
26
+
27
+ from . import __version__
28
+ from .dynamic_module_utils import custom_object_save
29
+ from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
30
+ from .utils import (
31
+ CONFIG_NAME,
32
+ PushToHubMixin,
33
+ add_model_info_to_auto_map,
34
+ add_model_info_to_custom_pipelines,
35
+ cached_file,
36
+ copy_func,
37
+ download_url,
38
+ extract_commit_hash,
39
+ is_remote_url,
40
+ is_torch_available,
41
+ logging,
42
+ )
43
+ from .utils.generic import is_timm_config_dict
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
49
+
50
+
51
+ class PretrainedConfig(PushToHubMixin):
52
+ # no-format
53
+ r"""
54
+ Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
55
+ methods for loading/downloading/saving configurations.
56
+
57
+ <Tip>
58
+
59
+ A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to
60
+ initialize a model does **not** load the model weights. It only affects the model's configuration.
61
+
62
+ </Tip>
63
+
64
+ Class attributes (overridden by derived classes):
65
+
66
+ - **model_type** (`str`) -- An identifier for the model type, serialized into the JSON file, and used to recreate
67
+ the correct object in [`~transformers.AutoConfig`].
68
+ - **is_composition** (`bool`) -- Whether the config class is composed of multiple sub-configs. In this case the
69
+ config has to be initialized from two or more configs of type [`~transformers.PretrainedConfig`] like:
70
+ [`~transformers.EncoderDecoderConfig`] or [`~transformers.RagConfig`].
71
+ - **keys_to_ignore_at_inference** (`List[str]`) -- A list of keys to ignore by default when looking at dictionary
72
+ outputs of the model during inference.
73
+ - **attribute_map** (`Dict[str, str]`) -- A dict that maps model specific attribute names to the standardized
74
+ naming of attributes.
75
+ - **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-module FQNs of a base model to a tensor
76
+ parallel plan applied to the sub-module when `model.tensor_parallel` is called.
77
+
78
+ Common attributes (present in all subclasses):
79
+
80
+ - **vocab_size** (`int`) -- The number of tokens in the vocabulary, which is also the first dimension of the
81
+ embeddings matrix (this attribute may be missing for models that don't have a text modality like ViT).
82
+ - **hidden_size** (`int`) -- The hidden size of the model.
83
+ - **num_attention_heads** (`int`) -- The number of attention heads used in the multi-head attention layers of the
84
+ model.
85
+ - **num_hidden_layers** (`int`) -- The number of blocks in the model.
86
+
87
+ <Tip warning={true}>
88
+
89
+ Setting parameters for sequence generation in the model config is deprecated. For backward compatibility, loading
90
+ some of them will still be possible, but attempting to overwrite them will throw an exception -- you should set
91
+ them in a [`~transformers.GenerationConfig`]. Check the documentation of [`~transformers.GenerationConfig`] for more
92
+ information about the individual parameters.
93
+
94
+ </Tip>
95
+
96
+ Args:
97
+ name_or_path (`str`, *optional*, defaults to `""`):
98
+ Store the string that was passed to [`PreTrainedModel.from_pretrained`] or
99
+ [`TFPreTrainedModel.from_pretrained`] as `pretrained_model_name_or_path` if the configuration was created
100
+ with such a method.
101
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
102
+ Whether or not the model should return all hidden-states.
103
+ output_attentions (`bool`, *optional*, defaults to `False`):
104
+ Whether or not the model should return all attentions.
105
+ return_dict (`bool`, *optional*, defaults to `True`):
106
+ Whether or not the model should return a [`~transformers.utils.ModelOutput`] instead of a plain tuple.
107
+ is_encoder_decoder (`bool`, *optional*, defaults to `False`):
108
+ Whether the model is used as an encoder/decoder or not.
109
+ is_decoder (`bool`, *optional*, defaults to `False`):
110
+ Whether the model is used as decoder or not (in which case it's used as an encoder).
111
+ cross_attention_hidden_size (`int`, *optional*):
112
+ The hidden size of the cross-attention layer in case the model is used as a decoder in an encoder-decoder
113
+ setting and the cross-attention hidden dimension differs from `self.config.hidden_size`.
114
+ add_cross_attention (`bool`, *optional*, defaults to `False`):
115
+ Whether cross-attention layers should be added to the model. Note, this option is only relevant for models
116
+ that can be used as decoder models within the [`EncoderDecoderModel`] class, which consists of all models
117
+ in `AUTO_MODELS_FOR_CAUSAL_LM`.
118
+ tie_encoder_decoder (`bool`, *optional*, defaults to `False`):
119
+ Whether all encoder weights should be tied to their equivalent decoder weights. This requires the encoder
120
+ and decoder model to have the exact same parameter names.
121
+ prune_heads (`Dict[int, List[int]]`, *optional*, defaults to `{}`):
122
+ Pruned heads of the model. The keys are the selected layer indices and the associated values, the list of
123
+ heads to prune in said layer.
124
+
125
+ For instance `{1: [0, 2], 2: [2, 3]}` will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
126
+ chunk_size_feed_forward (`int`, *optional*, defaults to `0`):
127
+ The chunk size of all feed forward layers in the residual attention blocks. A chunk size of `0` means that
128
+ the feed forward layer is not chunked. A chunk size of n means that the feed forward layer processes `n` <
129
+ sequence_length embeddings at a time. For more information on feed forward chunking, see [How does Feed
130
+ Forward Chunking work?](../glossary.html#feed-forward-chunking).
131
+
132
+ > Parameters for fine-tuning tasks
133
+
134
+ architectures (`List[str]`, *optional*):
135
+ Model architectures that can be used with the model pretrained weights.
136
+ finetuning_task (`str`, *optional*):
137
+ Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow
138
+ or PyTorch) checkpoint.
139
+ id2label (`Dict[int, str]`, *optional*):
140
+ A map from index (for instance prediction index, or target index) to label.
141
+ label2id (`Dict[str, int]`, *optional*): A map from label to index for the model.
142
+ num_labels (`int`, *optional*):
143
+ Number of labels to use in the last layer added to the model, typically for a classification task.
144
+ task_specific_params (`Dict[str, Any]`, *optional*):
145
+ Additional keyword arguments to store for the current task.
146
+ problem_type (`str`, *optional*):
147
+ Problem type for `XxxForSequenceClassification` models. Can be one of `"regression"`,
148
+ `"single_label_classification"` or `"multi_label_classification"`.
149
+
150
+ > Parameters linked to the tokenizer
151
+
152
+ tokenizer_class (`str`, *optional*):
153
+ The name of the associated tokenizer class to use (if none is set, will use the tokenizer associated to the
154
+ model by default).
155
+ prefix (`str`, *optional*):
156
+ A specific prompt that should be added at the beginning of each text before calling the model.
157
+ bos_token_id (`int`, *optional*): The id of the _beginning-of-stream_ token.
158
+ pad_token_id (`int`, *optional*): The id of the _padding_ token.
159
+ eos_token_id (`int`, *optional*): The id of the _end-of-stream_ token.
160
+ decoder_start_token_id (`int`, *optional*):
161
+ If an encoder-decoder model starts decoding with a different token than _bos_, the id of that token.
162
+ sep_token_id (`int`, *optional*): The id of the _separation_ token.
163
+
164
+ > PyTorch specific parameters
165
+
166
+ torchscript (`bool`, *optional*, defaults to `False`):
167
+ Whether or not the model should be used with Torchscript.
168
+ tie_word_embeddings (`bool`, *optional*, defaults to `True`):
169
+ Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
170
+ model has an output word embedding layer.
171
+ torch_dtype (`str`, *optional*):
172
+ The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype`
173
+ (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved
174
+ model is `float16`, ideally we want to load it back using the minimal amount of memory needed to load
175
+ `float16` weights. Since the config object is stored in plain text, this attribute contains just the
176
+ floating type string without the `torch.` prefix. For example, for `torch.float16`, `torch_dtype` is the
177
+ `"float16"` string.
178
+
179
+ This attribute is currently not being used during model loading time, but this may change in the future
180
+ versions. But we can already start preparing for the future by saving the dtype with save_pretrained.
181
+
182
+ > TensorFlow specific parameters
183
+
184
+ use_bfloat16 (`bool`, *optional*, defaults to `False`):
185
+ Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
186
+ tf_legacy_loss (`bool`, *optional*, defaults to `False`):
187
+ Whether the model should use legacy TensorFlow losses. Legacy losses have variable output shapes and may
188
+ not be XLA-compatible. This option is here for backward compatibility and will be removed in Transformers
189
+ v5.
190
+ loss_type (`str`, *optional*):
191
+ The type of loss that the model should use. It should be in `LOSS_MAPPING`'s keys, otherwise the loss will
192
+ be automatically inferred from the model architecture.
193
+ """
194
+
195
+ model_type: str = ""
196
+ base_config_key: str = ""
197
+ sub_configs: Dict[str, "PretrainedConfig"] = {}
198
+ is_composition: bool = False
199
+ attribute_map: Dict[str, str] = {}
200
+ base_model_tp_plan: Optional[Dict[str, Any]] = None
201
+ _auto_class: Optional[str] = None
202
+
203
+ def __setattr__(self, key, value):
204
+ if key in super().__getattribute__("attribute_map"):
205
+ key = super().__getattribute__("attribute_map")[key]
206
+ super().__setattr__(key, value)
207
+
208
+ def __getattribute__(self, key):
209
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
210
+ key = super().__getattribute__("attribute_map")[key]
211
+ return super().__getattribute__(key)
212
+
213
+ def __init__(self, **kwargs):
214
+ # Attributes with defaults
215
+ self.return_dict = kwargs.pop("return_dict", True)
216
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
217
+ self.output_attentions = kwargs.pop("output_attentions", False)
218
+ self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
219
+ self.torch_dtype = kwargs.pop("torch_dtype", None) # Only used by PyTorch models
220
+ self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
221
+ self.tf_legacy_loss = kwargs.pop("tf_legacy_loss", False) # Only used by TensorFlow models
222
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
223
+ self.tie_word_embeddings = kwargs.pop(
224
+ "tie_word_embeddings", True
225
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
226
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
227
+
228
+ # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
229
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
230
+ self.is_decoder = kwargs.pop("is_decoder", False)
231
+ self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
232
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
233
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
234
+
235
+ # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
236
+ # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
237
+ for parameter_name, default_value in self._get_global_generation_defaults().items():
238
+ setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
239
+
240
+ # Fine-tuning task arguments
241
+ self.architectures = kwargs.pop("architectures", None)
242
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
243
+ self.id2label = kwargs.pop("id2label", None)
244
+ self.label2id = kwargs.pop("label2id", None)
245
+ if self.label2id is not None and not isinstance(self.label2id, dict):
246
+ raise ValueError("Argument label2id should be a dictionary.")
247
+ if self.id2label is not None:
248
+ if not isinstance(self.id2label, dict):
249
+ raise ValueError("Argument id2label should be a dictionary.")
250
+ num_labels = kwargs.pop("num_labels", None)
251
+ if num_labels is not None and len(self.id2label) != num_labels:
252
+ logger.warning(
253
+ f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
254
+ f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
255
+ )
256
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
257
+ # Keys are always strings in JSON so convert ids to int here.
258
+ else:
259
+ self.num_labels = kwargs.pop("num_labels", 2)
260
+
261
+ if self.torch_dtype is not None and isinstance(self.torch_dtype, str):
262
+ # we will start using self.torch_dtype in v5, but to be consistent with
263
+ # from_pretrained's torch_dtype arg convert it to an actual torch.dtype object
264
+ if is_torch_available():
265
+ import torch
266
+
267
+ self.torch_dtype = getattr(torch, self.torch_dtype)
268
+
269
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
270
+ self.tokenizer_class = kwargs.pop("tokenizer_class", None)
271
+ self.prefix = kwargs.pop("prefix", None)
272
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
273
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
274
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
275
+ self.sep_token_id = kwargs.pop("sep_token_id", None)
276
+
277
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
278
+
279
+ # task specific arguments
280
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
281
+
282
+ # regression / multi-label classification
283
+ self.problem_type = kwargs.pop("problem_type", None)
284
+ allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
285
+ if self.problem_type is not None and self.problem_type not in allowed_problem_types:
286
+ raise ValueError(
287
+ f"The config parameter `problem_type` was not understood: received {self.problem_type} "
288
+ "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
289
+ )
290
+
291
+ # TPU arguments
292
+ if kwargs.pop("xla_device", None) is not None:
293
+ logger.warning(
294
+ "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can "
295
+ "safely remove it from your `config.json` file."
296
+ )
297
+
298
+ # Name or path to the pretrained checkpoint
299
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
300
+ # Config hash
301
+ self._commit_hash = kwargs.pop("_commit_hash", None)
302
+
303
+ # Attention implementation to use, if relevant.
304
+ self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
305
+ self._attn_implementation_autoset = False
306
+
307
+ # Drop the transformers version info
308
+ self.transformers_version = kwargs.pop("transformers_version", None)
309
+
310
+ # Deal with gradient checkpointing
311
+ if kwargs.get("gradient_checkpointing", False):
312
+ warnings.warn(
313
+ "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
314
+ "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
315
+ "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
316
+ )
317
+
318
+ # Additional attributes without default values
319
+ for key, value in kwargs.items():
320
+ try:
321
+ setattr(self, key, value)
322
+ except AttributeError as err:
323
+ logger.error(f"Can't set {key} with value {value} for {self}")
324
+ raise err
325
+
326
+ @property
327
+ def name_or_path(self) -> str:
328
+ return getattr(self, "_name_or_path", None)
329
+
330
+ @name_or_path.setter
331
+ def name_or_path(self, value):
332
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
333
+
334
+ @property
335
+ def use_return_dict(self) -> bool:
336
+ """
337
+ `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
338
+ """
339
+ # If torchscript is set, force `return_dict=False` to avoid jit errors
340
+ return self.return_dict and not self.torchscript
341
+
342
+ @property
343
+ def num_labels(self) -> int:
344
+ """
345
+ `int`: The number of labels for classification models.
346
+ """
347
+ return len(self.id2label)
348
+
349
+ @num_labels.setter
350
+ def num_labels(self, num_labels: int):
351
+ if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
352
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
353
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
354
+
355
+ @property
356
+ def _attn_implementation(self):
357
+ # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.)
358
+ if hasattr(self, "_attn_implementation_internal"):
359
+ if self._attn_implementation_internal is None:
360
+ # `config.attn_implementation` should never be None, for backward compatibility.
361
+ return "eager"
362
+ else:
363
+ return self._attn_implementation_internal
364
+ else:
365
+ return "eager"
366
+
367
+ @_attn_implementation.setter
368
+ def _attn_implementation(self, value):
369
+ self._attn_implementation_internal = value
370
+
371
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
372
+ """
373
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
374
+ [`~PretrainedConfig.from_pretrained`] class method.
375
+
376
+ Args:
377
+ save_directory (`str` or `os.PathLike`):
378
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
379
+ push_to_hub (`bool`, *optional*, defaults to `False`):
380
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
381
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
382
+ namespace).
383
+ kwargs (`Dict[str, Any]`, *optional*):
384
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
385
+ """
386
+ self._set_token_in_kwargs(kwargs)
387
+
388
+ if os.path.isfile(save_directory):
389
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
390
+
391
+ non_default_generation_parameters = self._get_non_default_generation_parameters()
392
+ if len(non_default_generation_parameters) > 0:
393
+ # TODO (joao): this should be an exception if the user has modified the loaded config. See #33886
394
+ warnings.warn(
395
+ "Some non-default generation parameters are set in the model config. These should go into either a) "
396
+ "`model.generation_config` (as opposed to `model.config`); OR b) a GenerationConfig file "
397
+ "(https://huggingface.co/docs/transformers/generation_strategies#save-a-custom-decoding-strategy-with-your-model)."
398
+ "This warning will become an exception in the future."
399
+ f"\nNon-default generation parameters: {str(non_default_generation_parameters)}",
400
+ UserWarning,
401
+ )
402
+
403
+ os.makedirs(save_directory, exist_ok=True)
404
+
405
+ if push_to_hub:
406
+ commit_message = kwargs.pop("commit_message", None)
407
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
408
+ repo_id = self._create_repo(repo_id, **kwargs)
409
+ files_timestamps = self._get_files_timestamps(save_directory)
410
+
411
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
412
+ # loaded from the Hub.
413
+ if self._auto_class is not None:
414
+ custom_object_save(self, save_directory, config=self)
415
+
416
+ # If we save using the predefined names, we can load using `from_pretrained`
417
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
418
+
419
+ self.to_json_file(output_config_file, use_diff=True)
420
+ logger.info(f"Configuration saved in {output_config_file}")
421
+
422
+ if push_to_hub:
423
+ self._upload_modified_files(
424
+ save_directory,
425
+ repo_id,
426
+ files_timestamps,
427
+ commit_message=commit_message,
428
+ token=kwargs.get("token"),
429
+ )
430
+
431
+ @staticmethod
432
+ def _set_token_in_kwargs(kwargs, token=None):
433
+ """Temporary method to deal with `token` and `use_auth_token`.
434
+
435
+ This method is to avoid applying the same changes in all model config classes that overwrite `from_pretrained`.
436
+
437
+ Need to clean up `use_auth_token` in a follow-up PR.
438
+ """
439
+ # Some model config classes like CLIP define their own `from_pretrained` without the new argument `token` yet.
440
+ if token is None:
441
+ token = kwargs.pop("token", None)
442
+ use_auth_token = kwargs.pop("use_auth_token", None)
443
+
444
+ if use_auth_token is not None:
445
+ warnings.warn(
446
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
447
+ FutureWarning,
448
+ )
449
+ if token is not None:
450
+ raise ValueError(
451
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
452
+ )
453
+ token = use_auth_token
454
+
455
+ if token is not None:
456
+ kwargs["token"] = token
457
+
458
+ @classmethod
459
+ def from_pretrained(
460
+ cls,
461
+ pretrained_model_name_or_path: Union[str, os.PathLike],
462
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
463
+ force_download: bool = False,
464
+ local_files_only: bool = False,
465
+ token: Optional[Union[str, bool]] = None,
466
+ revision: str = "main",
467
+ **kwargs,
468
+ ) -> "PretrainedConfig":
469
+ r"""
470
+ Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
471
+
472
+ Args:
473
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
474
+ This can be either:
475
+
476
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
477
+ huggingface.co.
478
+ - a path to a *directory* containing a configuration file saved using the
479
+ [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
480
+ - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
481
+ cache_dir (`str` or `os.PathLike`, *optional*):
482
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
483
+ standard cache should not be used.
484
+ force_download (`bool`, *optional*, defaults to `False`):
485
+ Whether or not to force (re-)downloading the configuration files and override the cached versions if
486
+ they exist.
487
+ resume_download:
488
+ Deprecated and ignored. All downloads are now resumed by default when possible.
489
+ Will be removed in v5 of Transformers.
490
+ proxies (`Dict[str, str]`, *optional*):
491
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
492
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
493
+ token (`str` or `bool`, *optional*):
494
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
495
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
496
+ revision (`str`, *optional*, defaults to `"main"`):
497
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
498
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
499
+ identifier allowed by git.
500
+
501
+ <Tip>
502
+
503
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
504
+
505
+ </Tip>
506
+
507
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
508
+ If `False`, then this function returns just the final configuration object.
509
+
510
+ If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
511
+ dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
512
+ part of `kwargs` which has not been used to update `config` and is otherwise ignored.
513
+ subfolder (`str`, *optional*, defaults to `""`):
514
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
515
+ specify the folder name here.
516
+ kwargs (`Dict[str, Any]`, *optional*):
517
+ The values in kwargs of any keys which are configuration attributes will be used to override the loaded
518
+ values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
519
+ by the `return_unused_kwargs` keyword parameter.
520
+
521
+ Returns:
522
+ [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
523
+
524
+ Examples:
525
+
526
+ ```python
527
+ # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
528
+ # derived class: BertConfig
529
+ config = BertConfig.from_pretrained(
530
+ "google-bert/bert-base-uncased"
531
+ ) # Download configuration from huggingface.co and cache.
532
+ config = BertConfig.from_pretrained(
533
+ "./test/saved_model/"
534
+ ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
535
+ config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
536
+ config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
537
+ assert config.output_attentions == True
538
+ config, unused_kwargs = BertConfig.from_pretrained(
539
+ "google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
540
+ )
541
+ assert config.output_attentions == True
542
+ assert unused_kwargs == {"foo": False}
543
+ ```"""
544
+ kwargs["cache_dir"] = cache_dir
545
+ kwargs["force_download"] = force_download
546
+ kwargs["local_files_only"] = local_files_only
547
+ kwargs["revision"] = revision
548
+
549
+ cls._set_token_in_kwargs(kwargs, token)
550
+
551
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
552
+ if cls.base_config_key and cls.base_config_key in config_dict:
553
+ config_dict = config_dict[cls.base_config_key]
554
+
555
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
556
+ # sometimes the config has no `base_config_key` if the config is used in several composite models
557
+ # e.g. LlamaConfig. In that case we try to see if there is a match in `model_type` before raising a warning
558
+ for k, v in config_dict.items():
559
+ if isinstance(v, dict) and v.get("model_type") == cls.model_type:
560
+ config_dict = v
561
+
562
+ # raise warning only if we still can't see a match in `model_type`
563
+ if config_dict["model_type"] != cls.model_type:
564
+ logger.warning(
565
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
566
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
567
+ )
568
+
569
+ return cls.from_dict(config_dict, **kwargs)
570
+
571
+ @classmethod
572
+ def get_config_dict(
573
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
574
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
575
+ """
576
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
577
+ [`PretrainedConfig`] using `from_dict`.
578
+
579
+ Parameters:
580
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
581
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
582
+
583
+ Returns:
584
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
585
+
586
+ """
587
+ cls._set_token_in_kwargs(kwargs)
588
+
589
+ original_kwargs = copy.deepcopy(kwargs)
590
+ # Get config dict associated with the base config file
591
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
592
+ if config_dict is None:
593
+ return {}, kwargs
594
+ if "_commit_hash" in config_dict:
595
+ original_kwargs["_commit_hash"] = config_dict["_commit_hash"]
596
+
597
+ # That config file may point us toward another config file to use.
598
+ if "configuration_files" in config_dict:
599
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
600
+ config_dict, kwargs = cls._get_config_dict(
601
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
602
+ )
603
+
604
+ return config_dict, kwargs
605
+
606
+ @classmethod
607
+ def _get_config_dict(
608
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
609
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
610
+ cache_dir = kwargs.pop("cache_dir", None)
611
+ force_download = kwargs.pop("force_download", False)
612
+ resume_download = kwargs.pop("resume_download", None)
613
+ proxies = kwargs.pop("proxies", None)
614
+ token = kwargs.pop("token", None)
615
+ local_files_only = kwargs.pop("local_files_only", False)
616
+ revision = kwargs.pop("revision", None)
617
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
618
+ subfolder = kwargs.pop("subfolder", "")
619
+ from_pipeline = kwargs.pop("_from_pipeline", None)
620
+ from_auto_class = kwargs.pop("_from_auto", False)
621
+ commit_hash = kwargs.pop("_commit_hash", None)
622
+
623
+ gguf_file = kwargs.get("gguf_file", None)
624
+
625
+ if trust_remote_code is True:
626
+ logger.warning(
627
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
628
+ " ignored."
629
+ )
630
+
631
+ user_agent = {"file_type": "config", "from_auto_class": from_auto_class}
632
+ if from_pipeline is not None:
633
+ user_agent["using_pipeline"] = from_pipeline
634
+
635
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
636
+
637
+ is_local = os.path.isdir(pretrained_model_name_or_path)
638
+ if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
639
+ # Special case when pretrained_model_name_or_path is a local file
640
+ resolved_config_file = pretrained_model_name_or_path
641
+ is_local = True
642
+ elif is_remote_url(pretrained_model_name_or_path):
643
+ configuration_file = pretrained_model_name_or_path if gguf_file is None else gguf_file
644
+ resolved_config_file = download_url(pretrained_model_name_or_path)
645
+ else:
646
+ configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) if gguf_file is None else gguf_file
647
+
648
+ try:
649
+ # Load from local folder or from cache or download from model Hub and cache
650
+ resolved_config_file = cached_file(
651
+ pretrained_model_name_or_path,
652
+ configuration_file,
653
+ cache_dir=cache_dir,
654
+ force_download=force_download,
655
+ proxies=proxies,
656
+ resume_download=resume_download,
657
+ local_files_only=local_files_only,
658
+ token=token,
659
+ user_agent=user_agent,
660
+ revision=revision,
661
+ subfolder=subfolder,
662
+ _commit_hash=commit_hash,
663
+ )
664
+ if resolved_config_file is None:
665
+ return None, kwargs
666
+ commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
667
+ except EnvironmentError:
668
+ # Re-raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
669
+ # the original exception.
670
+ raise
671
+ except Exception:
672
+ # For any other exception, we throw a generic error.
673
+ raise EnvironmentError(
674
+ f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
675
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the same"
676
+ f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
677
+ f" containing a {configuration_file} file"
678
+ )
679
+
680
+ try:
681
+ if gguf_file:
682
+ config_dict = load_gguf_checkpoint(resolved_config_file, return_tensors=False)["config"]
683
+ else:
684
+ # Load config dict
685
+ config_dict = cls._dict_from_json_file(resolved_config_file)
686
+
687
+ config_dict["_commit_hash"] = commit_hash
688
+ except (json.JSONDecodeError, UnicodeDecodeError):
689
+ raise EnvironmentError(
690
+ f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
691
+ )
692
+
693
+ if is_local:
694
+ logger.info(f"loading configuration file {resolved_config_file}")
695
+ else:
696
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
697
+
698
+ if "auto_map" in config_dict and not is_local:
699
+ config_dict["auto_map"] = add_model_info_to_auto_map(
700
+ config_dict["auto_map"], pretrained_model_name_or_path
701
+ )
702
+ if "custom_pipelines" in config_dict and not is_local:
703
+ config_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
704
+ config_dict["custom_pipelines"], pretrained_model_name_or_path
705
+ )
706
+
707
+ # timm models are not saved with the model_type in the config file
708
+ if "model_type" not in config_dict and is_timm_config_dict(config_dict):
709
+ config_dict["model_type"] = "timm_wrapper"
710
+
711
+ return config_dict, kwargs
712
+
713
+ @classmethod
714
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
715
+ """
716
+ Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
717
+
718
+ Args:
719
+ config_dict (`Dict[str, Any]`):
720
+ Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
721
+ retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
722
+ kwargs (`Dict[str, Any]`):
723
+ Additional parameters from which to initialize the configuration object.
724
+
725
+ Returns:
726
+ [`PretrainedConfig`]: The configuration object instantiated from those parameters.
727
+ """
728
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
729
+ # Those arguments may be passed along for our internal telemetry.
730
+ # We remove them so they don't appear in `return_unused_kwargs`.
731
+ kwargs.pop("_from_auto", None)
732
+ kwargs.pop("_from_pipeline", None)
733
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
734
+ if "_commit_hash" in kwargs and "_commit_hash" in config_dict:
735
+ kwargs["_commit_hash"] = config_dict["_commit_hash"]
736
+
737
+ # We remove it from kwargs so that it does not appear in `return_unused_kwargs`.
738
+ config_dict["attn_implementation"] = kwargs.pop("attn_implementation", None)
739
+
740
+ config = cls(**config_dict)
741
+
742
+ if hasattr(config, "pruned_heads"):
743
+ config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
744
+
745
+ # Update config with kwargs if needed
746
+ if "num_labels" in kwargs and "id2label" in kwargs:
747
+ num_labels = kwargs["num_labels"]
748
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
749
+ if len(id2label) != num_labels:
750
+ raise ValueError(
751
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
752
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
753
+ "one of them."
754
+ )
755
+ to_remove = []
756
+ for key, value in kwargs.items():
757
+ if hasattr(config, key):
758
+ current_attr = getattr(config, key)
759
+ # To authorize passing a custom subconfig as kwarg in models that have nested configs.
760
+ if isinstance(current_attr, PretrainedConfig) and isinstance(value, dict):
761
+ value = current_attr.__class__(**value)
762
+ setattr(config, key, value)
763
+ if key != "torch_dtype":
764
+ to_remove.append(key)
765
+ for key in to_remove:
766
+ kwargs.pop(key, None)
767
+
768
+ logger.info(f"Model config {config}")
769
+ if return_unused_kwargs:
770
+ return config, kwargs
771
+ else:
772
+ return config
773
+
774
+ @classmethod
775
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
776
+ """
777
+ Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
778
+
779
+ Args:
780
+ json_file (`str` or `os.PathLike`):
781
+ Path to the JSON file containing the parameters.
782
+
783
+ Returns:
784
+ [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
785
+
786
+ """
787
+ config_dict = cls._dict_from_json_file(json_file)
788
+ return cls(**config_dict)
789
+
790
+ @classmethod
791
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
792
+ with open(json_file, "r", encoding="utf-8") as reader:
793
+ text = reader.read()
794
+ return json.loads(text)
795
+
796
+ def __eq__(self, other):
797
+ return isinstance(other, PretrainedConfig) and (self.__dict__ == other.__dict__)
798
+
799
+ def __repr__(self):
800
+ return f"{self.__class__.__name__} {self.to_json_string()}"
801
+
802
+ def __iter__(self):
803
+ for attr in self.__dict__:
804
+ yield attr
805
+
806
+ def to_diff_dict(self) -> Dict[str, Any]:
807
+ """
808
+ Removes all attributes from config which correspond to the default config attributes for better readability and
809
+ serializes to a Python dictionary.
810
+
811
+ Returns:
812
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
813
+ """
814
+ config_dict = self.to_dict()
815
+
816
+ # get the default config dict
817
+ default_config_dict = PretrainedConfig().to_dict()
818
+
819
+ # get class specific config dict
820
+ class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
821
+
822
+ serializable_config_dict = {}
823
+
824
+ # only serialize values that differ from the default config
825
+ for key, value in config_dict.items():
826
+ if (
827
+ isinstance(getattr(self, key, None), PretrainedConfig)
828
+ and key in class_config_dict
829
+ and isinstance(class_config_dict[key], dict)
830
+ ):
831
+ # For nested configs we need to clean the diff recursively
832
+ diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
833
+ if "model_type" in value:
834
+ # Needs to be set even if it's not in the diff
835
+ diff["model_type"] = value["model_type"]
836
+ if len(diff) > 0:
837
+ serializable_config_dict[key] = diff
838
+ elif (
839
+ key not in default_config_dict
840
+ or key == "transformers_version"
841
+ or value != default_config_dict[key]
842
+ or (key in class_config_dict and value != class_config_dict[key])
843
+ ):
844
+ serializable_config_dict[key] = value
845
+
846
+ if hasattr(self, "quantization_config"):
847
+ serializable_config_dict["quantization_config"] = (
848
+ self.quantization_config.to_dict()
849
+ if not isinstance(self.quantization_config, dict)
850
+ else self.quantization_config
851
+ )
852
+
853
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
854
+ _ = serializable_config_dict.pop("_pre_quantization_dtype", None)
855
+
856
+ self.dict_torch_dtype_to_str(serializable_config_dict)
857
+
858
+ if "_attn_implementation_internal" in serializable_config_dict:
859
+ del serializable_config_dict["_attn_implementation_internal"]
860
+ # Do not serialize `base_model_tp_plan` for now
861
+ if "base_model_tp_plan" in serializable_config_dict:
862
+ del serializable_config_dict["base_model_tp_plan"]
863
+
864
+ return serializable_config_dict
865
+
866
+ def to_dict(self) -> Dict[str, Any]:
867
+ """
868
+ Serializes this instance to a Python dictionary.
869
+
870
+ Returns:
871
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
872
+ """
873
+ output = copy.deepcopy(self.__dict__)
874
+ if hasattr(self.__class__, "model_type"):
875
+ output["model_type"] = self.__class__.model_type
876
+ if "_auto_class" in output:
877
+ del output["_auto_class"]
878
+ if "_commit_hash" in output:
879
+ del output["_commit_hash"]
880
+ if "_attn_implementation_internal" in output:
881
+ del output["_attn_implementation_internal"]
882
+ # Do not serialize `base_model_tp_plan` for now
883
+ if "base_model_tp_plan" in output:
884
+ del output["base_model_tp_plan"]
885
+
886
+ # Transformers version when serializing the model
887
+ output["transformers_version"] = __version__
888
+
889
+ for key, value in output.items():
890
+ # Deal with nested configs like CLIP
891
+ if isinstance(value, PretrainedConfig):
892
+ value = value.to_dict()
893
+ del value["transformers_version"]
894
+
895
+ output[key] = value
896
+
897
+ if hasattr(self, "quantization_config"):
898
+ output["quantization_config"] = (
899
+ self.quantization_config.to_dict()
900
+ if not isinstance(self.quantization_config, dict)
901
+ else self.quantization_config
902
+ )
903
+
904
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
905
+ _ = output.pop("_pre_quantization_dtype", None)
906
+
907
+ self.dict_torch_dtype_to_str(output)
908
+
909
+ return output
910
+
911
+ def to_json_string(self, use_diff: bool = True) -> str:
912
+ """
913
+ Serializes this instance to a JSON string.
914
+
915
+ Args:
916
+ use_diff (`bool`, *optional*, defaults to `True`):
917
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
918
+ is serialized to JSON string.
919
+
920
+ Returns:
921
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
922
+ """
923
+ if use_diff is True:
924
+ config_dict = self.to_diff_dict()
925
+ else:
926
+ config_dict = self.to_dict()
927
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
928
+
929
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
930
+ """
931
+ Save this instance to a JSON file.
932
+
933
+ Args:
934
+ json_file_path (`str` or `os.PathLike`):
935
+ Path to the JSON file in which this configuration instance's parameters will be saved.
936
+ use_diff (`bool`, *optional*, defaults to `True`):
937
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
938
+ is serialized to JSON file.
939
+ """
940
+ with open(json_file_path, "w", encoding="utf-8") as writer:
941
+ writer.write(self.to_json_string(use_diff=use_diff))
942
+
943
+ def update(self, config_dict: Dict[str, Any]):
944
+ """
945
+ Updates attributes of this class with attributes from `config_dict`.
946
+
947
+ Args:
948
+ config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
949
+ """
950
+ for key, value in config_dict.items():
951
+ setattr(self, key, value)
952
+
953
+ def update_from_string(self, update_str: str):
954
+ """
955
+ Updates attributes of this class with attributes from `update_str`.
956
+
957
+ The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
958
+ "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
959
+
960
+ The keys to change have to already exist in the config object.
961
+
962
+ Args:
963
+ update_str (`str`): String with attributes that should be updated for this class.
964
+
965
+ """
966
+
967
+ d = dict(x.split("=") for x in update_str.split(","))
968
+ for k, v in d.items():
969
+ if not hasattr(self, k):
970
+ raise ValueError(f"key {k} isn't in the original config dict")
971
+
972
+ old_v = getattr(self, k)
973
+ if isinstance(old_v, bool):
974
+ if v.lower() in ["true", "1", "y", "yes"]:
975
+ v = True
976
+ elif v.lower() in ["false", "0", "n", "no"]:
977
+ v = False
978
+ else:
979
+ raise ValueError(f"can't derive true or false from {v} (key {k})")
980
+ elif isinstance(old_v, int):
981
+ v = int(v)
982
+ elif isinstance(old_v, float):
983
+ v = float(v)
984
+ elif not isinstance(old_v, str):
985
+ raise TypeError(
986
+ f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
987
+ )
988
+
989
+ setattr(self, k, v)
990
+
991
+ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
992
+ """
993
+ Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
994
+ converts torch.dtype to a string of just the type. For example, `torch.float32` gets converted into *"float32"*
995
+ string, which can then be stored in the json format.
996
+ """
997
+ if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
998
+ d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
999
+ for value in d.values():
1000
+ if isinstance(value, dict):
1001
+ self.dict_torch_dtype_to_str(value)
1002
+
1003
+ @classmethod
1004
+ def register_for_auto_class(cls, auto_class="AutoConfig"):
1005
+ """
1006
+ Register this class with a given auto class. This should only be used for custom configurations as the ones in
1007
+ the library are already mapped with `AutoConfig`.
1008
+
1009
+ <Tip warning={true}>
1010
+
1011
+ This API is experimental and may have some slight breaking changes in the next releases.
1012
+
1013
+ </Tip>
1014
+
1015
+ Args:
1016
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoConfig"`):
1017
+ The auto class to register this new configuration with.
1018
+ """
1019
+ if not isinstance(auto_class, str):
1020
+ auto_class = auto_class.__name__
1021
+
1022
+ import transformers.models.auto as auto_module
1023
+
1024
+ if not hasattr(auto_module, auto_class):
1025
+ raise ValueError(f"{auto_class} is not a valid auto class.")
1026
+
1027
+ cls._auto_class = auto_class
1028
+
1029
+ @staticmethod
1030
+ def _get_global_generation_defaults() -> Dict[str, Any]:
1031
+ return {
1032
+ "max_length": 20,
1033
+ "min_length": 0,
1034
+ "do_sample": False,
1035
+ "early_stopping": False,
1036
+ "num_beams": 1,
1037
+ "num_beam_groups": 1,
1038
+ "diversity_penalty": 0.0,
1039
+ "temperature": 1.0,
1040
+ "top_k": 50,
1041
+ "top_p": 1.0,
1042
+ "typical_p": 1.0,
1043
+ "repetition_penalty": 1.0,
1044
+ "length_penalty": 1.0,
1045
+ "no_repeat_ngram_size": 0,
1046
+ "encoder_no_repeat_ngram_size": 0,
1047
+ "bad_words_ids": None,
1048
+ "num_return_sequences": 1,
1049
+ "output_scores": False,
1050
+ "return_dict_in_generate": False,
1051
+ "forced_bos_token_id": None,
1052
+ "forced_eos_token_id": None,
1053
+ "remove_invalid_values": False,
1054
+ "exponential_decay_length_penalty": None,
1055
+ "suppress_tokens": None,
1056
+ "begin_suppress_tokens": None,
1057
+ }
1058
+
1059
+ def _get_non_default_generation_parameters(self) -> Dict[str, Any]:
1060
+ """
1061
+ Gets the non-default generation parameters on the PretrainedConfig instance
1062
+ """
1063
+ non_default_generation_parameters = {}
1064
+ decoder_attribute_name = None
1065
+
1066
+ # Composite models don't have a default config, use their decoder config as a fallback for default values
1067
+ # If no known pattern is matched, then `default_config = None` -> check against the global generation defaults
1068
+ try:
1069
+ default_config = self.__class__()
1070
+ except ValueError:
1071
+ decoder_config = self.get_text_config(decoder=True)
1072
+ if decoder_config is not self:
1073
+ default_config = decoder_config.__class__()
1074
+ else:
1075
+ default_config = None
1076
+
1077
+ # If it is a composite model, we want to check the subconfig that will be used for generation
1078
+ self_decoder_config = self if decoder_attribute_name is None else getattr(self, decoder_attribute_name)
1079
+
1080
+ for parameter_name, default_global_value in self._get_global_generation_defaults().items():
1081
+ if hasattr(self_decoder_config, parameter_name):
1082
+ is_default_in_config = is_default_generation_value = None
1083
+ parameter_value = getattr(self_decoder_config, parameter_name)
1084
+ # Three cases in which it is okay for the model config to hold generation config parameters:
1085
+ # 1. The parameter is set to `None`, effectively delegating its value to the generation config
1086
+ if parameter_value is None:
1087
+ continue
1088
+ # 2. If we have a default config, then the instance should hold the same generation defaults
1089
+ if default_config is not None:
1090
+ is_default_in_config = parameter_value == getattr(default_config, parameter_name)
1091
+ # 3. if we don't have a default config, then the instance should hold the global generation defaults
1092
+ else:
1093
+ is_default_generation_value = parameter_value == default_global_value
1094
+
1095
+ is_non_default = (is_default_in_config is False) or (
1096
+ is_default_in_config is None and is_default_generation_value is False
1097
+ )
1098
+ if is_non_default:
1099
+ non_default_generation_parameters[parameter_name] = getattr(self_decoder_config, parameter_name)
1100
+
1101
+ return non_default_generation_parameters
1102
+
1103
+ def get_text_config(self, decoder=False) -> "PretrainedConfig":
1104
+ """
1105
+ Returns the config that is meant to be used with text IO. On most models, it is the original config instance
1106
+ itself. On specific composite models, it is under a set of valid names.
1107
+
1108
+ If `decoder` is set to `True`, then only search for decoder config names.
1109
+ """
1110
+ decoder_possible_text_config_names = ("decoder", "generator", "text_config")
1111
+ encoder_possible_text_config_names = ("text_encoder",)
1112
+ if decoder:
1113
+ possible_text_config_names = decoder_possible_text_config_names
1114
+ else:
1115
+ possible_text_config_names = encoder_possible_text_config_names + decoder_possible_text_config_names
1116
+
1117
+ valid_text_config_names = []
1118
+ for text_config_name in possible_text_config_names:
1119
+ if hasattr(self, text_config_name):
1120
+ text_config = getattr(self, text_config_name, None)
1121
+ if text_config is not None:
1122
+ valid_text_config_names += [text_config_name]
1123
+
1124
+ if len(valid_text_config_names) > 1:
1125
+ raise ValueError(
1126
+ f"Multiple valid text configs were found in the model config: {valid_text_config_names}. In this "
1127
+ "case, using `get_text_config()` would be ambiguous. Please specify the desied text config directly."
1128
+ )
1129
+ elif len(valid_text_config_names) == 1:
1130
+ return getattr(self, valid_text_config_names[0])
1131
+ return self
1132
+
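+ # Illustrative usage sketch (annotation, not part of the uploaded file): on a composite
+ # model config such as LLaVA, `get_text_config()` returns the `text_config` sub-config,
+ # while on a plain text-only config it returns the instance itself.
+ #
+ #     from transformers import BertConfig, LlavaConfig
+ #
+ #     llava = LlavaConfig()
+ #     assert llava.get_text_config() is llava.text_config
+ #     bert = BertConfig()
+ #     assert bert.get_text_config() is bert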
1133
+
1134
+ def get_configuration_file(configuration_files: List[str]) -> str:
1135
+ """
1136
+ Get the configuration file to use for this version of transformers.
1137
+
1138
+ Args:
1139
+ configuration_files (`List[str]`): The list of available configuration files.
1140
+
1141
+ Returns:
1142
+ `str`: The configuration file to use.
1143
+ """
1144
+ configuration_files_map = {}
1145
+ for file_name in configuration_files:
1146
+ search = _re_configuration_file.search(file_name)
1147
+ if search is not None:
1148
+ v = search.groups()[0]
1149
+ configuration_files_map[v] = file_name
1150
+ available_versions = sorted(configuration_files_map.keys())
1151
+
1152
+ # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
1153
+ configuration_file = CONFIG_NAME
1154
+ transformers_version = version.parse(__version__)
1155
+ for v in available_versions:
1156
+ if version.parse(v) <= transformers_version:
1157
+ configuration_file = configuration_files_map[v]
1158
+ else:
1159
+ # No point going further since the versions are sorted.
1160
+ break
1161
+
1162
+ return configuration_file
1163
+
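+ # Illustrative sketch (annotation, not part of the uploaded file): assuming transformers
+ # 4.30.0 is installed, the newest versioned file not exceeding that version is picked.
+ #
+ #     get_configuration_file(["config.json", "config.4.27.0.json", "config.4.35.0.json"])
+ #     # -> "config.4.27.0.json"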
1164
+
1165
+ def recursive_diff_dict(dict_a, dict_b, config_obj=None):
1166
+ """
1167
+ Helper function to recursively take the diff between two nested dictionaries. The resulting diff only contains the
1168
+ values from `dict_a` that are different from values in `dict_b`.
1169
+ """
1170
+ diff = {}
1171
+ default = config_obj.__class__().to_dict() if config_obj is not None else {}
1172
+ for key, value in dict_a.items():
1173
+ obj_value = getattr(config_obj, str(key), None)
1174
+ if isinstance(obj_value, PretrainedConfig) and key in dict_b and isinstance(dict_b[key], dict):
1175
+ diff_value = recursive_diff_dict(value, dict_b[key], config_obj=obj_value)
1176
+ if len(diff_value) > 0:
1177
+ diff[key] = diff_value
1178
+ elif key not in dict_b or value != dict_b[key] or key not in default or value != default[key]:
1179
+ diff[key] = value
1180
+ return diff
1181
+
1182
+
1183
+ PretrainedConfig.push_to_hub = copy_func(PretrainedConfig.push_to_hub)
1184
+ if PretrainedConfig.push_to_hub.__doc__ is not None:
1185
+ PretrainedConfig.push_to_hub.__doc__ = PretrainedConfig.push_to_hub.__doc__.format(
1186
+ object="config", object_class="AutoConfig", object_files="configuration file"
1187
+ )
convert_graph_to_onnx.py ADDED
@@ -0,0 +1,551 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import warnings
16
+ from argparse import ArgumentParser
17
+ from os import listdir, makedirs
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple
20
+
21
+ from packaging.version import Version, parse
22
+
23
+ from transformers.pipelines import Pipeline, pipeline
24
+ from transformers.tokenization_utils import BatchEncoding
25
+ from transformers.utils import ModelOutput, is_tf_available, is_torch_available
26
+
27
+
28
+ # This is the minimal required version to
29
+ # support some ONNX Runtime features
30
+ ORT_QUANTIZE_MINIMUM_VERSION = parse("1.4.0")
31
+
32
+
33
+ SUPPORTED_PIPELINES = [
34
+ "feature-extraction",
35
+ "ner",
36
+ "sentiment-analysis",
37
+ "fill-mask",
38
+ "question-answering",
39
+ "text-generation",
40
+ "translation_en_to_fr",
41
+ "translation_en_to_de",
42
+ "translation_en_to_ro",
43
+ ]
44
+
45
+
46
+ class OnnxConverterArgumentParser(ArgumentParser):
47
+ """
48
+ Wraps all the script arguments supported to export transformers models to ONNX IR
49
+ """
50
+
51
+ def __init__(self):
52
+ super().__init__("ONNX Converter")
53
+
54
+ self.add_argument(
55
+ "--pipeline",
56
+ type=str,
57
+ choices=SUPPORTED_PIPELINES,
58
+ default="feature-extraction",
59
+ )
60
+ self.add_argument(
61
+ "--model",
62
+ type=str,
63
+ required=True,
64
+ help="Model's id or path (ex: google-bert/bert-base-cased)",
65
+ )
66
+ self.add_argument("--tokenizer", type=str, help="Tokenizer's id or path (ex: google-bert/bert-base-cased)")
67
+ self.add_argument(
68
+ "--framework",
69
+ type=str,
70
+ choices=["pt", "tf"],
71
+ help="Framework for loading the model",
72
+ )
73
+ self.add_argument("--opset", type=int, default=11, help="ONNX opset to use")
74
+ self.add_argument(
75
+ "--check-loading",
76
+ action="store_true",
77
+ help="Check ONNX is able to load the model",
78
+ )
79
+ self.add_argument(
80
+ "--use-external-format",
81
+ action="store_true",
82
+ help="Allow exporting model >= than 2Gb",
83
+ )
84
+ self.add_argument(
85
+ "--quantize",
86
+ action="store_true",
87
+ help="Quantize the neural network to be run with int8",
88
+ )
89
+ self.add_argument("output")
90
+
91
+
92
+ def generate_identified_filename(filename: Path, identifier: str) -> Path:
93
+ """
94
+ Append a string-identifier at the end (before the extension, if any) to the provided filepath
95
+
96
+ Args:
97
+ filename: pathlib.Path The actual path object we would like to add an identifier suffix
98
+ identifier: The suffix to add
99
+
100
+ Returns: Path with the identifier appended to the end of the filename stem (before the extension)
101
+ """
102
+ return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix)
103
+
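+ # Illustrative sketch (annotation, not part of the uploaded file):
+ #
+ #     generate_identified_filename(Path("onnx/model.onnx"), "-optimized")
+ #     # -> Path("onnx/model-optimized.onnx")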
104
+
105
+ def check_onnxruntime_requirements(minimum_version: Version):
106
+ """
107
+ Check onnxruntime is installed and if the installed version match is recent enough
108
+
109
+ Raises:
110
+ ImportError: If onnxruntime is not installed or too old version is found
111
+ """
112
+ try:
113
+ import onnxruntime
114
+
115
+ # Parse the version of the installed onnxruntime
116
+ ort_version = parse(onnxruntime.__version__)
117
+
118
+ # We require 1.4.0 minimum
119
+ if ort_version < ORT_QUANTIZE_MINIMUM_VERSION:
120
+ raise ImportError(
121
+ f"We found an older version of onnxruntime ({onnxruntime.__version__}) "
122
+ f"but we require onnxruntime to be >= {minimum_version} to enable all the conversions options.\n"
123
+ "Please update onnxruntime by running `pip install --upgrade onnxruntime`"
124
+ )
125
+
126
+ except ImportError:
127
+ raise ImportError(
128
+ "onnxruntime doesn't seem to be currently installed. "
129
+ "Please install the onnxruntime by running `pip install onnxruntime`"
130
+ " and relaunch the conversion."
131
+ )
132
+
133
+
134
+ def ensure_valid_input(model, tokens, input_names):
135
+ """
136
+ Ensure inputs are presented in the correct order, without any None values
137
+
138
+ Args:
139
+ model: The model used to forward the input data
140
+ tokens: BatchEncoding holding the input data
141
+ input_names: The name of the inputs
142
+
143
+ Returns: Tuple
144
+
145
+ """
146
+ print("Ensuring inputs are in correct order")
147
+
148
+ model_args_name = model.forward.__code__.co_varnames
149
+ model_args, ordered_input_names = [], []
150
+ for arg_name in model_args_name[1:]: # start at index 1 to skip "self" argument
151
+ if arg_name in input_names:
152
+ ordered_input_names.append(arg_name)
153
+ model_args.append(tokens[arg_name])
154
+ else:
155
+ print(f"{arg_name} is not present in the generated input list.")
156
+ break
157
+
158
+ print(f"Generated inputs order: {ordered_input_names}")
159
+ return ordered_input_names, tuple(model_args)
160
+
161
+
162
+ def infer_shapes(nlp: Pipeline, framework: str) -> Tuple[List[str], List[str], Dict, BatchEncoding]:
163
+ """
164
+ Attempt to infer the static vs dynamic axes for each input and output tensors for a specific model
165
+
166
+ Args:
167
+ nlp: The pipeline object holding the model to be exported
168
+ framework: The framework identifier to dispatch to the correct inference scheme (pt/tf)
169
+
170
+ Returns:
171
+
172
+ - List of the inferred input variable names
173
+ - List of the inferred output variable names
174
+ - Dictionary with input/output variables names as key and shape tensor as value
175
+ - a BatchEncoding reference which was used to infer all the above information
176
+ """
177
+
178
+ def build_shape_dict(name: str, tensor, is_input: bool, seq_len: int):
179
+ if isinstance(tensor, (tuple, list)):
180
+ return [build_shape_dict(name, t, is_input, seq_len) for t in tensor]
181
+
182
+ else:
183
+ # Let's assume batch is the first axis with only 1 element (which might not always be true ...)
184
+ axes = {[axis for axis, numel in enumerate(tensor.shape) if numel == 1][0]: "batch"}
185
+ if is_input:
186
+ if len(tensor.shape) == 2:
187
+ axes[1] = "sequence"
188
+ else:
189
+ raise ValueError(f"Unable to infer tensor axes ({len(tensor.shape)})")
190
+ else:
191
+ seq_axes = [dim for dim, shape in enumerate(tensor.shape) if shape == seq_len]
192
+ axes.update({dim: "sequence" for dim in seq_axes})
193
+
194
+ print(f"Found {'input' if is_input else 'output'} {name} with shape: {axes}")
195
+ return axes
196
+
197
+ tokens = nlp.tokenizer("This is a sample output", return_tensors=framework)
198
+ seq_len = tokens.input_ids.shape[-1]
199
+ outputs = nlp.model(**tokens) if framework == "pt" else nlp.model(tokens)
200
+ if isinstance(outputs, ModelOutput):
201
+ outputs = outputs.to_tuple()
202
+ if not isinstance(outputs, (list, tuple)):
203
+ outputs = (outputs,)
204
+
205
+ # Generate input names & axes
206
+ input_vars = list(tokens.keys())
207
+ input_dynamic_axes = {k: build_shape_dict(k, v, True, seq_len) for k, v in tokens.items()}
208
+
209
+ # flatten potentially grouped outputs (past for gpt2, attentions)
210
+ outputs_flat = []
211
+ for output in outputs:
212
+ if isinstance(output, (tuple, list)):
213
+ outputs_flat.extend(output)
214
+ else:
215
+ outputs_flat.append(output)
216
+
217
+ # Generate output names & axes
218
+ output_names = [f"output_{i}" for i in range(len(outputs_flat))]
219
+ output_dynamic_axes = {k: build_shape_dict(k, v, False, seq_len) for k, v in zip(output_names, outputs_flat)}
220
+
221
+ # Create the aggregated axes representation
222
+ dynamic_axes = dict(input_dynamic_axes, **output_dynamic_axes)
223
+ return input_vars, output_names, dynamic_axes, tokens
224
+
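+ # Illustrative sketch (annotation, not part of the uploaded file): for a BERT
+ # feature-extraction pipeline, the inferred dynamic axes typically look like
+ #
+ #     input_ids:      {0: "batch", 1: "sequence"}
+ #     attention_mask: {0: "batch", 1: "sequence"}
+ #     output_0:       {0: "batch", 1: "sequence"}   # last_hidden_state
+ #     output_1:       {0: "batch"}                  # pooler_output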
225
+
226
+ def load_graph_from_args(
227
+ pipeline_name: str, framework: str, model: str, tokenizer: Optional[str] = None, **models_kwargs
228
+ ) -> Pipeline:
229
+ """
230
+ Convert the set of arguments provided through the CLI to an actual pipeline reference (tokenizer + model)
231
+
232
+ Args:
233
+ pipeline_name: The kind of pipeline to use (ner, question-answering, etc.)
234
+ framework: The actual model to convert the pipeline from ("pt" or "tf")
235
+ model: The model name which will be loaded by the pipeline
236
+ tokenizer: The tokenizer name which will be loaded by the pipeline, default to the model's value
237
+
238
+ Returns: Pipeline object
239
+
240
+ """
241
+ # If no tokenizer provided
242
+ if tokenizer is None:
243
+ tokenizer = model
244
+
245
+ # Check the wanted framework is available
246
+ if framework == "pt" and not is_torch_available():
247
+ raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
248
+ if framework == "tf" and not is_tf_available():
249
+ raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
250
+
251
+ print(f"Loading pipeline (model: {model}, tokenizer: {tokenizer})")
252
+
253
+ # Allocate tokenizer and model
254
+ return pipeline(pipeline_name, model=model, tokenizer=tokenizer, framework=framework, model_kwargs=models_kwargs)
255
+
256
+
257
+ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format: bool):
258
+ """
259
+ Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR)
260
+
261
+ Args:
262
+ nlp: The pipeline to be exported
263
+ opset: The actual version of the ONNX operator set to use
264
+ output: Path where the generated ONNX model will be stored
265
+ use_external_format: Split the model definition from its parameters to allow models bigger than 2GB
266
+
267
+ Returns:
268
+
269
+ """
270
+ if not is_torch_available():
271
+ raise Exception("Cannot convert because PyTorch is not installed. Please install torch first.")
272
+
273
+ import torch
274
+ from torch.onnx import export
275
+
276
+ print(f"Using framework PyTorch: {torch.__version__}")
277
+
278
+ with torch.no_grad():
279
+ input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "pt")
280
+ ordered_input_names, model_args = ensure_valid_input(nlp.model, tokens, input_names)
281
+
282
+ export(
283
+ nlp.model,
284
+ model_args,
285
+ f=output.as_posix(),
286
+ input_names=ordered_input_names,
287
+ output_names=output_names,
288
+ dynamic_axes=dynamic_axes,
289
+ do_constant_folding=True,
290
+ opset_version=opset,
291
+ )
292
+
293
+
294
+ def convert_tensorflow(nlp: Pipeline, opset: int, output: Path):
295
+ """
296
+ Export a TensorFlow backed pipeline to ONNX Intermediate Representation (IR)
297
+
298
+ Args:
299
+ nlp: The pipeline to be exported
300
+ opset: The actual version of the ONNX operator set to use
301
+ output: Path where the generated ONNX model will be stored
302
+
303
+ Notes: TensorFlow cannot export models bigger than 2GB due to an internal TensorFlow constraint
304
+
305
+ """
306
+ if not is_tf_available():
307
+ raise Exception("Cannot convert because TF is not installed. Please install tensorflow first.")
308
+
309
+ print("/!\\ Please note TensorFlow doesn't support exporting model > 2Gb /!\\")
310
+
311
+ try:
312
+ import tensorflow as tf
313
+ import tf2onnx
314
+ from tf2onnx import __version__ as t2ov
315
+
316
+ print(f"Using framework TensorFlow: {tf.version.VERSION}, tf2onnx: {t2ov}")
317
+
318
+ # Build
319
+ input_names, output_names, dynamic_axes, tokens = infer_shapes(nlp, "tf")
320
+
321
+ # Forward
322
+ nlp.model.predict(tokens.data)
323
+ input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in tokens.items()]
324
+ model_proto, _ = tf2onnx.convert.from_keras(
325
+ nlp.model, input_signature, opset=opset, output_path=output.as_posix()
326
+ )
327
+
328
+ except ImportError as e:
329
+ raise Exception(
330
+ f"Cannot import {e.name} required to convert TF model to ONNX. Please install {e.name} first. {e}"
331
+ )
332
+
333
+
334
+ def convert(
335
+ framework: str,
336
+ model: str,
337
+ output: Path,
338
+ opset: int,
339
+ tokenizer: Optional[str] = None,
340
+ use_external_format: bool = False,
341
+ pipeline_name: str = "feature-extraction",
342
+ **model_kwargs,
343
+ ):
344
+ """
345
+ Convert the pipeline object to the ONNX Intermediate Representation (IR) format
346
+
347
+ Args:
348
+ framework: The framework the pipeline is backed by ("pt" or "tf")
349
+ model: The name of the model to load for the pipeline
350
+ output: The path where the ONNX graph will be stored
351
+ opset: The actual version of the ONNX operator set to use
352
+ tokenizer: The name of the tokenizer to load for the pipeline, defaults to the model's name if not provided
353
+ use_external_format:
354
+ Split the model definition from its parameters to allow models bigger than 2GB (PyTorch only)
355
+ pipeline_name: The kind of pipeline to instantiate (ner, question-answering, etc.)
356
+ model_kwargs: Keyword arguments to be forwarded to the model constructor
357
+
358
+ Returns:
359
+
360
+ """
361
+ warnings.warn(
362
+ "The `transformers.convert_graph_to_onnx` package is deprecated and will be removed in version 5 of"
363
+ " Transformers",
364
+ FutureWarning,
365
+ )
366
+ print(f"ONNX opset version set to: {opset}")
367
+
368
+ # Load the pipeline
369
+ nlp = load_graph_from_args(pipeline_name, framework, model, tokenizer, **model_kwargs)
370
+
371
+ if not output.parent.exists():
372
+ print(f"Creating folder {output.parent}")
373
+ makedirs(output.parent.as_posix())
374
+ elif len(listdir(output.parent.as_posix())) > 0:
375
+ raise Exception(f"Folder {output.parent.as_posix()} is not empty, aborting conversion")
376
+
377
+ # Export the graph
378
+ if framework == "pt":
379
+ convert_pytorch(nlp, opset, output, use_external_format)
380
+ else:
381
+ convert_tensorflow(nlp, opset, output)
382
+
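+ # Illustrative sketch (annotation, not part of the uploaded file): a typical invocation
+ # of this script, exporting a PyTorch BERT checkpoint and quantizing the result. The
+ # output path must point into an empty (or not yet existing) folder.
+ #
+ #     python -m transformers.convert_graph_to_onnx --framework pt \
+ #         --model google-bert/bert-base-cased --quantize onnx/bert-base-cased.onnx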
383
+
384
+ def optimize(onnx_model_path: Path) -> Path:
385
+ """
386
+ Load the model at the specified path and let onnxruntime look at transformations on the graph to enable all the
387
+ optimizations possible
388
+
389
+ Args:
390
+ onnx_model_path: filepath where the model binary description is stored
391
+
392
+ Returns: Path where the optimized model binary description has been saved
393
+
394
+ """
395
+ from onnxruntime import InferenceSession, SessionOptions
396
+
397
+ # Generate model name with suffix "optimized"
398
+ opt_model_path = generate_identified_filename(onnx_model_path, "-optimized")
399
+ sess_option = SessionOptions()
400
+ sess_option.optimized_model_filepath = opt_model_path.as_posix()
401
+ _ = InferenceSession(onnx_model_path.as_posix(), sess_option)
402
+
403
+ print(f"Optimized model has been written at {opt_model_path}: \N{HEAVY CHECK MARK}")
404
+ print("/!\\ Optimized model contains hardware specific operators which might not be portable. /!\\")
405
+
406
+ return opt_model_path
407
+
408
+
409
+ def quantize(onnx_model_path: Path) -> Path:
410
+ """
411
+ Quantize the weights of the model from float32 to int8 to allow very efficient inference on modern CPUs
412
+
413
+ Args:
414
+ onnx_model_path: Path to the location where the exported ONNX model is stored
415
+
416
+ Returns: The Path generated for the quantized model
417
+ """
418
+ import onnx
419
+ import onnxruntime
420
+ from onnx.onnx_pb import ModelProto
421
+ from onnxruntime.quantization import QuantizationMode
422
+ from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
423
+ from onnxruntime.quantization.registry import IntegerOpsRegistry
424
+
425
+ # Load the ONNX model
426
+ onnx_model = onnx.load(onnx_model_path.as_posix())
427
+
428
+ if parse(onnx.__version__) < parse("1.5.0"):
429
+ print(
430
+ "Models larger than 2GB will fail to quantize due to protobuf constraint.\n"
431
+ "Please upgrade to onnxruntime >= 1.5.0."
432
+ )
433
+
434
+ # Copy it
435
+ copy_model = ModelProto()
436
+ copy_model.CopyFrom(onnx_model)
437
+
438
+ # Construct quantizer
439
+ # onnxruntime renamed input_qType to activation_qType in v1.13.1, so we
440
+ # check the onnxruntime version to ensure backward compatibility.
441
+ # See also: https://github.com/microsoft/onnxruntime/pull/12873
442
+ if parse(onnxruntime.__version__) < parse("1.13.1"):
443
+ quantizer = ONNXQuantizer(
444
+ model=copy_model,
445
+ per_channel=False,
446
+ reduce_range=False,
447
+ mode=QuantizationMode.IntegerOps,
448
+ static=False,
449
+ weight_qType=True,
450
+ input_qType=False,
451
+ tensors_range=None,
452
+ nodes_to_quantize=None,
453
+ nodes_to_exclude=None,
454
+ op_types_to_quantize=list(IntegerOpsRegistry),
455
+ )
456
+ else:
457
+ quantizer = ONNXQuantizer(
458
+ model=copy_model,
459
+ per_channel=False,
460
+ reduce_range=False,
461
+ mode=QuantizationMode.IntegerOps,
462
+ static=False,
463
+ weight_qType=True,
464
+ activation_qType=False,
465
+ tensors_range=None,
466
+ nodes_to_quantize=None,
467
+ nodes_to_exclude=None,
468
+ op_types_to_quantize=list(IntegerOpsRegistry),
469
+ )
470
+
471
+ # Quantize and export
472
+ quantizer.quantize_model()
473
+
474
+ # Append "-quantized" at the end of the model's name
475
+ quantized_model_path = generate_identified_filename(onnx_model_path, "-quantized")
476
+
477
+ # Save model
478
+ print(f"Quantized model has been written at {quantized_model_path}: \N{HEAVY CHECK MARK}")
479
+ onnx.save_model(quantizer.model.model, quantized_model_path.as_posix())
480
+
481
+ return quantized_model_path
482
+
483
+
484
+ def verify(path: Path):
485
+ from onnxruntime import InferenceSession, SessionOptions
486
+ from onnxruntime.capi.onnxruntime_pybind11_state import RuntimeException
487
+
488
+ print(f"Checking ONNX model loading from: {path} ...")
489
+ try:
490
+ onnx_options = SessionOptions()
491
+ _ = InferenceSession(path.as_posix(), onnx_options, providers=["CPUExecutionProvider"])
492
+ print(f"Model {path} correctly loaded: \N{HEAVY CHECK MARK}")
493
+ except RuntimeException as re:
494
+ print(f"Error while loading the model {re}: \N{HEAVY BALLOT X}")
495
+
496
+
497
+ if __name__ == "__main__":
498
+ parser = OnnxConverterArgumentParser()
499
+ args = parser.parse_args()
500
+
501
+ # Make sure output is absolute path
502
+ args.output = Path(args.output).absolute()
503
+
504
+ try:
505
+ print("\n====== Converting model to ONNX ======")
506
+ # Convert
507
+ convert(
508
+ args.framework,
509
+ args.model,
510
+ args.output,
511
+ args.opset,
512
+ args.tokenizer,
513
+ args.use_external_format,
514
+ args.pipeline,
515
+ )
516
+
517
+ if args.quantize:
518
+ # Ensure requirements for quantization on onnxruntime is met
519
+ check_onnxruntime_requirements(ORT_QUANTIZE_MINIMUM_VERSION)
520
+
521
+ # onnxruntime optimizations don't provide the same level of performance on TensorFlow as on PyTorch
522
+ if args.framework == "tf":
523
+ print(
524
+ "\t Using TensorFlow might not provide the same optimization level compared to PyTorch.\n"
525
+ "\t For TensorFlow users you can try optimizing the model directly through onnxruntime_tools.\n"
526
+ "\t For more information, please refer to the onnxruntime documentation:\n"
527
+ "\t\thttps://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers\n"
528
+ )
529
+
530
+ print("\n====== Optimizing ONNX model ======")
531
+
532
+ # Quantization works best when using the optimized version of the model
533
+ args.optimized_output = optimize(args.output)
534
+
535
+ # Do the quantization on the right graph
536
+ args.quantized_output = quantize(args.optimized_output)
537
+
538
+ # And verify
539
+ if args.check_loading:
540
+ print("\n====== Check exported ONNX model(s) ======")
541
+ verify(args.output)
542
+
543
+ if hasattr(args, "optimized_output"):
544
+ verify(args.optimized_output)
545
+
546
+ if hasattr(args, "quantized_output"):
547
+ verify(args.quantized_output)
548
+
549
+ except Exception as e:
550
+ print(f"Error while converting the model: {e}")
551
+ exit(1)
convert_pytorch_checkpoint_to_tf2.py ADDED
@@ -0,0 +1,446 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert pytorch checkpoints to TensorFlow"""
16
+
17
+ import argparse
18
+ import os
19
+
20
+ from . import (
21
+ AlbertConfig,
22
+ BartConfig,
23
+ BertConfig,
24
+ CamembertConfig,
25
+ CTRLConfig,
26
+ DistilBertConfig,
27
+ DPRConfig,
28
+ ElectraConfig,
29
+ FlaubertConfig,
30
+ GPT2Config,
31
+ LayoutLMConfig,
32
+ LxmertConfig,
33
+ OpenAIGPTConfig,
34
+ RobertaConfig,
35
+ T5Config,
36
+ TFAlbertForPreTraining,
37
+ TFBartForConditionalGeneration,
38
+ TFBartForSequenceClassification,
39
+ TFBertForPreTraining,
40
+ TFBertForQuestionAnswering,
41
+ TFBertForSequenceClassification,
42
+ TFCamembertForMaskedLM,
43
+ TFCTRLLMHeadModel,
44
+ TFDistilBertForMaskedLM,
45
+ TFDistilBertForQuestionAnswering,
46
+ TFDPRContextEncoder,
47
+ TFDPRQuestionEncoder,
48
+ TFDPRReader,
49
+ TFElectraForPreTraining,
50
+ TFFlaubertWithLMHeadModel,
51
+ TFGPT2LMHeadModel,
52
+ TFLayoutLMForMaskedLM,
53
+ TFLxmertForPreTraining,
54
+ TFLxmertVisualFeatureEncoder,
55
+ TFOpenAIGPTLMHeadModel,
56
+ TFRobertaForCausalLM,
57
+ TFRobertaForMaskedLM,
58
+ TFRobertaForSequenceClassification,
59
+ TFT5ForConditionalGeneration,
60
+ TFTransfoXLLMHeadModel,
61
+ TFWav2Vec2Model,
62
+ TFXLMRobertaForMaskedLM,
63
+ TFXLMWithLMHeadModel,
64
+ TFXLNetLMHeadModel,
65
+ TransfoXLConfig,
66
+ Wav2Vec2Config,
67
+ Wav2Vec2Model,
68
+ XLMConfig,
69
+ XLMRobertaConfig,
70
+ XLNetConfig,
71
+ is_torch_available,
72
+ load_pytorch_checkpoint_in_tf2_model,
73
+ )
74
+ from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
75
+
76
+
77
+ if is_torch_available():
78
+ import numpy as np
79
+ import torch
80
+
81
+ from . import (
82
+ AlbertForPreTraining,
83
+ BartForConditionalGeneration,
84
+ BertForPreTraining,
85
+ BertForQuestionAnswering,
86
+ BertForSequenceClassification,
87
+ CamembertForMaskedLM,
88
+ CTRLLMHeadModel,
89
+ DistilBertForMaskedLM,
90
+ DistilBertForQuestionAnswering,
91
+ DPRContextEncoder,
92
+ DPRQuestionEncoder,
93
+ DPRReader,
94
+ ElectraForPreTraining,
95
+ FlaubertWithLMHeadModel,
96
+ GPT2LMHeadModel,
97
+ LayoutLMForMaskedLM,
98
+ LxmertForPreTraining,
99
+ LxmertVisualFeatureEncoder,
100
+ OpenAIGPTLMHeadModel,
101
+ RobertaForMaskedLM,
102
+ RobertaForSequenceClassification,
103
+ T5ForConditionalGeneration,
104
+ TransfoXLLMHeadModel,
105
+ XLMRobertaForMaskedLM,
106
+ XLMWithLMHeadModel,
107
+ XLNetLMHeadModel,
108
+ )
109
+
110
+
111
+ logging.set_verbosity_info()
112
+
113
+ MODEL_CLASSES = {
114
+ "bart": (
115
+ BartConfig,
116
+ TFBartForConditionalGeneration,
117
+ TFBartForSequenceClassification,
118
+ BartForConditionalGeneration,
119
+ ),
120
+ "bert": (
121
+ BertConfig,
122
+ TFBertForPreTraining,
123
+ BertForPreTraining,
124
+ ),
125
+ "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad": (
126
+ BertConfig,
127
+ TFBertForQuestionAnswering,
128
+ BertForQuestionAnswering,
129
+ ),
130
+ "google-bert/bert-large-cased-whole-word-masking-finetuned-squad": (
131
+ BertConfig,
132
+ TFBertForQuestionAnswering,
133
+ BertForQuestionAnswering,
134
+ ),
135
+ "google-bert/bert-base-cased-finetuned-mrpc": (
136
+ BertConfig,
137
+ TFBertForSequenceClassification,
138
+ BertForSequenceClassification,
139
+ ),
140
+ "dpr": (
141
+ DPRConfig,
142
+ TFDPRQuestionEncoder,
143
+ TFDPRContextEncoder,
144
+ TFDPRReader,
145
+ DPRQuestionEncoder,
146
+ DPRContextEncoder,
147
+ DPRReader,
148
+ ),
149
+ "openai-community/gpt2": (
150
+ GPT2Config,
151
+ TFGPT2LMHeadModel,
152
+ GPT2LMHeadModel,
153
+ ),
154
+ "xlnet": (
155
+ XLNetConfig,
156
+ TFXLNetLMHeadModel,
157
+ XLNetLMHeadModel,
158
+ ),
159
+ "xlm": (
160
+ XLMConfig,
161
+ TFXLMWithLMHeadModel,
162
+ XLMWithLMHeadModel,
163
+ ),
164
+ "xlm-roberta": (
165
+ XLMRobertaConfig,
166
+ TFXLMRobertaForMaskedLM,
167
+ XLMRobertaForMaskedLM,
168
+ ),
169
+ "transfo-xl": (
170
+ TransfoXLConfig,
171
+ TFTransfoXLLMHeadModel,
172
+ TransfoXLLMHeadModel,
173
+ ),
174
+ "openai-community/openai-gpt": (
175
+ OpenAIGPTConfig,
176
+ TFOpenAIGPTLMHeadModel,
177
+ OpenAIGPTLMHeadModel,
178
+ ),
179
+ "roberta": (
180
+ RobertaConfig,
181
+ TFRobertaForCausalLM,
182
+ TFRobertaForMaskedLM,
183
+ RobertaForMaskedLM,
184
+ ),
185
+ "layoutlm": (
186
+ LayoutLMConfig,
187
+ TFLayoutLMForMaskedLM,
188
+ LayoutLMForMaskedLM,
189
+ ),
190
+ "FacebookAI/roberta-large-mnli": (
191
+ RobertaConfig,
192
+ TFRobertaForSequenceClassification,
193
+ RobertaForSequenceClassification,
194
+ ),
195
+ "camembert": (
196
+ CamembertConfig,
197
+ TFCamembertForMaskedLM,
198
+ CamembertForMaskedLM,
199
+ ),
200
+ "flaubert": (
201
+ FlaubertConfig,
202
+ TFFlaubertWithLMHeadModel,
203
+ FlaubertWithLMHeadModel,
204
+ ),
205
+ "distilbert": (
206
+ DistilBertConfig,
207
+ TFDistilBertForMaskedLM,
208
+ DistilBertForMaskedLM,
209
+ ),
210
+ "distilbert-base-distilled-squad": (
211
+ DistilBertConfig,
212
+ TFDistilBertForQuestionAnswering,
213
+ DistilBertForQuestionAnswering,
214
+ ),
215
+ "lxmert": (
216
+ LxmertConfig,
217
+ TFLxmertForPreTraining,
218
+ LxmertForPreTraining,
219
+ ),
220
+ "lxmert-visual-feature-encoder": (
221
+ LxmertConfig,
222
+ TFLxmertVisualFeatureEncoder,
223
+ LxmertVisualFeatureEncoder,
224
+ ),
225
+ "Salesforce/ctrl": (
226
+ CTRLConfig,
227
+ TFCTRLLMHeadModel,
228
+ CTRLLMHeadModel,
229
+ ),
230
+ "albert": (
231
+ AlbertConfig,
232
+ TFAlbertForPreTraining,
233
+ AlbertForPreTraining,
234
+ ),
235
+ "t5": (
236
+ T5Config,
237
+ TFT5ForConditionalGeneration,
238
+ T5ForConditionalGeneration,
239
+ ),
240
+ "electra": (
241
+ ElectraConfig,
242
+ TFElectraForPreTraining,
243
+ ElectraForPreTraining,
244
+ ),
245
+ "wav2vec2": (
246
+ Wav2Vec2Config,
247
+ TFWav2Vec2Model,
248
+ Wav2Vec2Model,
249
+ ),
250
+ }
251
+
252
+
253
+ def convert_pt_checkpoint_to_tf(
254
+ model_type, pytorch_checkpoint_path, config_file, tf_dump_path, compare_with_pt_model=False, use_cached_models=True
255
+ ):
256
+ if model_type not in MODEL_CLASSES:
257
+ raise ValueError(f"Unrecognized model type, should be one of {list(MODEL_CLASSES.keys())}.")
258
+
259
+ config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]
260
+
261
+ # Initialise TF model
262
+ if config_file in aws_config_map:
263
+ config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
264
+ config = config_class.from_json_file(config_file)
265
+ config.output_hidden_states = True
266
+ config.output_attentions = True
267
+ print(f"Building TensorFlow model from configuration: {config}")
268
+ tf_model = model_class(config)
269
+
270
+ # Load weights from tf checkpoint
271
+ if pytorch_checkpoint_path in aws_config_map.keys():
272
+ pytorch_checkpoint_path = cached_file(
273
+ pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
274
+ )
275
+ # Load PyTorch checkpoint in tf2 model:
276
+ tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
277
+
278
+ if compare_with_pt_model:
279
+ tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network
280
+
281
+ weights_only_kwarg = {"weights_only": True}
282
+ state_dict = torch.load(
283
+ pytorch_checkpoint_path,
284
+ map_location="cpu",
285
+ **weights_only_kwarg,
286
+ )
287
+ pt_model = pt_model_class.from_pretrained(
288
+ pretrained_model_name_or_path=None, config=config, state_dict=state_dict
289
+ )
290
+
291
+ with torch.no_grad():
292
+ pto = pt_model(**pt_model.dummy_inputs)
293
+
294
+ np_pt = pto[0].numpy()
295
+ np_tf = tfo[0].numpy()
296
+ diff = np.amax(np.abs(np_pt - np_tf))
297
+ print(f"Max absolute difference between models outputs {diff}")
298
+ assert diff <= 2e-2, f"Error, model absolute difference is >2e-2: {diff}"
299
+
300
+ # Save pytorch-model
301
+ print(f"Save TensorFlow model to {tf_dump_path}")
302
+ tf_model.save_weights(tf_dump_path, save_format="h5")
303
+
304
+
305
+ def convert_all_pt_checkpoints_to_tf(
306
+ args_model_type,
307
+ tf_dump_path,
308
+ model_shortcut_names_or_path=None,
309
+ config_shortcut_names_or_path=None,
310
+ compare_with_pt_model=False,
311
+ use_cached_models=False,
312
+ remove_cached_files=False,
313
+ only_convert_finetuned_models=False,
314
+ ):
315
+ if args_model_type is None:
316
+ model_types = list(MODEL_CLASSES.keys())
317
+ else:
318
+ model_types = [args_model_type]
319
+
320
+ for j, model_type in enumerate(model_types, start=1):
321
+ print("=" * 100)
322
+ print(f" Converting model type {j}/{len(model_types)}: {model_type}")
323
+ print("=" * 100)
324
+ if model_type not in MODEL_CLASSES:
325
+ raise ValueError(f"Unrecognized model type {model_type}, should be one of {list(MODEL_CLASSES.keys())}.")
326
+
327
+ config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
328
+
329
+ if model_shortcut_names_or_path is None:
330
+ model_shortcut_names_or_path = list(aws_model_maps.keys())
331
+ if config_shortcut_names_or_path is None:
332
+ config_shortcut_names_or_path = model_shortcut_names_or_path
333
+
334
+ for i, (model_shortcut_name, config_shortcut_name) in enumerate(
335
+ zip(model_shortcut_names_or_path, config_shortcut_names_or_path), start=1
336
+ ):
337
+ print("-" * 100)
338
+ if "-squad" in model_shortcut_name or "-mrpc" in model_shortcut_name or "-mnli" in model_shortcut_name:
339
+ if not only_convert_finetuned_models:
340
+ print(f" Skipping finetuned checkpoint {model_shortcut_name}")
341
+ continue
342
+ model_type = model_shortcut_name
343
+ elif only_convert_finetuned_models:
344
+ print(f" Skipping not finetuned checkpoint {model_shortcut_name}")
345
+ continue
346
+ print(
347
+ f" Converting checkpoint {i}/{len(aws_config_map)}: {model_shortcut_name} - model_type {model_type}"
348
+ )
349
+ print("-" * 100)
350
+
351
+ if config_shortcut_name in aws_config_map:
352
+ config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
353
+ else:
354
+ config_file = config_shortcut_name
355
+
356
+ if model_shortcut_name in aws_model_maps:
357
+ model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
358
+ else:
359
+ model_file = model_shortcut_name
360
+
361
+ if os.path.isfile(model_shortcut_name):
362
+ model_shortcut_name = "converted_model"
363
+
364
+ convert_pt_checkpoint_to_tf(
365
+ model_type=model_type,
366
+ pytorch_checkpoint_path=model_file,
367
+ config_file=config_file,
368
+ tf_dump_path=os.path.join(tf_dump_path, model_shortcut_name + "-tf_model.h5"),
369
+ compare_with_pt_model=compare_with_pt_model,
370
+ )
371
+ if remove_cached_files:
372
+ os.remove(config_file)
373
+ os.remove(model_file)
374
+
375
+
376
+ if __name__ == "__main__":
377
+ parser = argparse.ArgumentParser()
378
+ # Required parameters
379
+ parser.add_argument(
380
+ "--tf_dump_path", default=None, type=str, required=True, help="Path to the output Tensorflow dump file."
381
+ )
382
+ parser.add_argument(
383
+ "--model_type",
384
+ default=None,
385
+ type=str,
386
+ help=(
387
+ f"Model type selected in the list of {list(MODEL_CLASSES.keys())}. If not given, will download and "
388
+ "convert all the models from AWS."
389
+ ),
390
+ )
391
+ parser.add_argument(
392
+ "--pytorch_checkpoint_path",
393
+ default=None,
394
+ type=str,
395
+ help=(
396
+ "Path to the PyTorch checkpoint path or shortcut name to download from AWS. "
397
+ "If not given, will download and convert all the checkpoints from AWS."
398
+ ),
399
+ )
400
+ parser.add_argument(
401
+ "--config_file",
402
+ default=None,
403
+ type=str,
404
+ help=(
405
+ "The config json file corresponding to the pre-trained model. \n"
406
+ "This specifies the model architecture. If not given and "
407
+ "--pytorch_checkpoint_path is not given or is a shortcut name "
408
+ "use the configuration associated to the shortcut name on the AWS"
409
+ ),
410
+ )
411
+ parser.add_argument(
412
+ "--compare_with_pt_model", action="store_true", help="Compare Tensorflow and PyTorch model predictions."
413
+ )
414
+ parser.add_argument(
415
+ "--use_cached_models",
416
+ action="store_true",
417
+ help="Use cached models if possible instead of updating to latest checkpoint versions.",
418
+ )
419
+ parser.add_argument(
420
+ "--remove_cached_files",
421
+ action="store_true",
422
+ help="Remove pytorch models after conversion (save memory when converting in batches).",
423
+ )
424
+ parser.add_argument("--only_convert_finetuned_models", action="store_true", help="Only convert finetuned models.")
425
+ args = parser.parse_args()
426
+
427
+ # if args.pytorch_checkpoint_path is not None:
428
+ # convert_pt_checkpoint_to_tf(args.model_type.lower(),
429
+ # args.pytorch_checkpoint_path,
430
+ # args.config_file if args.config_file is not None else args.pytorch_checkpoint_path,
431
+ # args.tf_dump_path,
432
+ # compare_with_pt_model=args.compare_with_pt_model,
433
+ # use_cached_models=args.use_cached_models)
434
+ # else:
435
+ convert_all_pt_checkpoints_to_tf(
436
+ args.model_type.lower() if args.model_type is not None else None,
437
+ args.tf_dump_path,
438
+ model_shortcut_names_or_path=[args.pytorch_checkpoint_path]
439
+ if args.pytorch_checkpoint_path is not None
440
+ else None,
441
+ config_shortcut_names_or_path=[args.config_file] if args.config_file is not None else None,
442
+ compare_with_pt_model=args.compare_with_pt_model,
443
+ use_cached_models=args.use_cached_models,
444
+ remove_cached_files=args.remove_cached_files,
445
+ only_convert_finetuned_models=args.only_convert_finetuned_models,
446
+ )
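+ # Illustrative sketch (annotation, not part of the uploaded file): converting a single
+ # local PyTorch checkpoint to a TensorFlow .h5 dump and comparing both models' outputs.
+ #
+ #     python -m transformers.convert_pytorch_checkpoint_to_tf2 \
+ #         --model_type bert \
+ #         --pytorch_checkpoint_path ./bert/pytorch_model.bin \
+ #         --config_file ./bert/config.json \
+ #         --tf_dump_path ./bert-tf/ \
+ #         --compare_with_pt_model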
convert_slow_tokenizer.py ADDED
@@ -0,0 +1,1642 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Utilities to convert slow tokenizers in their fast tokenizers counterparts.
17
+
18
+ All the conversions are grouped here to gather SentencePiece dependencies outside of the fast tokenizers files and
19
+ allow making our dependency on SentencePiece optional.
20
+ """
21
+
22
+ import warnings
23
+ from typing import Dict, List, Tuple
24
+
25
+ from packaging import version
26
+ from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
27
+ from tokenizers.models import BPE, Unigram, WordPiece
28
+
29
+ from .utils import is_protobuf_available, is_sentencepiece_available, logging, requires_backends
30
+ from .utils.import_utils import PROTOBUF_IMPORT_ERROR
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def import_protobuf(error_message=""):
37
+ if is_sentencepiece_available():
38
+ from sentencepiece import sentencepiece_model_pb2
39
+
40
+ return sentencepiece_model_pb2
41
+ if is_protobuf_available():
42
+ import google.protobuf
43
+
44
+ if version.parse(google.protobuf.__version__) < version.parse("4.0.0"):
45
+ from transformers.utils import sentencepiece_model_pb2
46
+ else:
47
+ from transformers.utils import sentencepiece_model_pb2_new as sentencepiece_model_pb2
48
+ return sentencepiece_model_pb2
49
+ else:
50
+ raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
51
+
52
+
53
+ def _get_prepend_scheme(add_prefix_space: bool, original_tokenizer) -> str:
54
+ if add_prefix_space:
55
+ prepend_scheme = "always"
56
+ if not getattr(original_tokenizer, "legacy", True):
57
+ prepend_scheme = "first"
58
+ else:
59
+ prepend_scheme = "never"
60
+ return prepend_scheme
61
+
62
+
63
+ def generate_merges(vocab, vocab_scores):
64
+ reverse = vocab_scores is not None
65
+ vocab_scores = dict(vocab_scores) if reverse else vocab
66
+
67
+ merges = []
68
+ for merge, piece_score in vocab_scores.items():
69
+ local = []
70
+ for index in range(1, len(merge)):
71
+ piece_l, piece_r = merge[:index], merge[index:]
72
+ if piece_l in vocab and piece_r in vocab:
73
+ local.append((piece_l, piece_r, piece_score))
74
+ local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]]))
75
+ merges.extend(local)
76
+
77
+ merges = sorted(merges, key=lambda val: (val[2], len(val[0]), len(val[1])), reverse=reverse)
78
+ merges = [(val[0], val[1]) for val in merges]
79
+ return merges
80
+
81
+
82
+ class SentencePieceExtractor:
83
+ """
84
+ Extractor implementation for SentencePiece trained models. https://github.com/google/sentencepiece
85
+ """
86
+
87
+ def __init__(self, model: str):
88
+ requires_backends(self, "sentencepiece")
89
+ from sentencepiece import SentencePieceProcessor
90
+
91
+ self.sp = SentencePieceProcessor()
92
+ self.sp.Load(model)
93
+
94
+ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
95
+ """
96
+ By default, this returns the vocab and merges with respect to their order; by sending `vocab_scores` we're going to
97
+ order the merges with respect to the piece scores instead.
98
+ """
99
+ sp = self.sp
100
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
101
+
102
+ merges = generate_merges(vocab, vocab_scores)
103
+
104
+ return vocab, merges
105
+
106
+
107
+ class GemmaSentencePieceExtractor(SentencePieceExtractor):
108
+ def extract(self, vocab_scores=None) -> Tuple[Dict[str, int], List[Tuple]]:
109
+ """
110
+ By default, this returns the vocab and merges with respect to their order; by sending `vocab_scores` we're going to
111
+ order the merges with respect to the piece scores instead.
112
+ """
113
+ sp = self.sp
114
+ vocab = {sp.id_to_piece(index): index for index in range(sp.GetPieceSize())}
115
+
116
+ # there is a missing token in the vocab. We have to do this to support merges
117
+ # "<0x09>" is the bytefallback for `\t`
118
+ vocab["\t"] = vocab.get("<0x09>")
119
+
120
+ merges = generate_merges(vocab, vocab_scores)
121
+ return vocab, merges
122
+
123
+
124
+ def check_number_comma(piece: str) -> bool:
125
+ return len(piece) < 2 or piece[-1] != "," or not piece[-2].isdigit()
126
+
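+ # Illustrative sketch (annotation, not part of the uploaded file): `check_number_comma`
+ # rejects pieces ending in a comma that directly follows a digit.
+ #
+ #     check_number_comma("hello")  # True
+ #     check_number_comma("1,")     # False
+ #     check_number_comma(",")      # True (shorter than 2 characters)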
127
+
128
+ class Converter:
129
+ def __init__(self, original_tokenizer):
130
+ self.original_tokenizer = original_tokenizer
131
+
132
+ def converted(self) -> Tokenizer:
133
+ raise NotImplementedError()
134
+
135
+
136
+ class BertConverter(Converter):
137
+ def converted(self) -> Tokenizer:
138
+ vocab = self.original_tokenizer.vocab
139
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
140
+
141
+ tokenize_chinese_chars = False
142
+ strip_accents = False
143
+ do_lower_case = False
144
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
145
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
146
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
147
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
148
+
149
+ tokenizer.normalizer = normalizers.BertNormalizer(
150
+ clean_text=True,
151
+ handle_chinese_chars=tokenize_chinese_chars,
152
+ strip_accents=strip_accents,
153
+ lowercase=do_lower_case,
154
+ )
155
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
156
+
157
+ cls = str(self.original_tokenizer.cls_token)
158
+ sep = str(self.original_tokenizer.sep_token)
159
+ cls_token_id = self.original_tokenizer.cls_token_id
160
+ sep_token_id = self.original_tokenizer.sep_token_id
161
+
162
+ tokenizer.post_processor = processors.TemplateProcessing(
163
+ single=f"{cls}:0 $A:0 {sep}:0",
164
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
165
+ special_tokens=[
166
+ (cls, cls_token_id),
167
+ (sep, sep_token_id),
168
+ ],
169
+ )
170
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
171
+
172
+ return tokenizer
173
+
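+ # Illustrative usage sketch (annotation, not part of the uploaded file): turning a slow
+ # BERT tokenizer into a `tokenizers.Tokenizer` backend through its converter.
+ #
+ #     from transformers import BertTokenizer
+ #
+ #     slow = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
+ #     fast_backend = BertConverter(slow).converted()  # tokenizers.Tokenizer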
174
+
175
+ class SplinterConverter(Converter):
176
+ def converted(self) -> Tokenizer:
177
+ vocab = self.original_tokenizer.vocab
178
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
179
+
180
+ tokenize_chinese_chars = False
181
+ strip_accents = False
182
+ do_lower_case = False
183
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
184
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
185
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
186
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
187
+
188
+ tokenizer.normalizer = normalizers.BertNormalizer(
189
+ clean_text=True,
190
+ handle_chinese_chars=tokenize_chinese_chars,
191
+ strip_accents=strip_accents,
192
+ lowercase=do_lower_case,
193
+ )
194
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
195
+
196
+ cls = str(self.original_tokenizer.cls_token)
197
+ sep = str(self.original_tokenizer.sep_token)
198
+ question = str(self.original_tokenizer.question_token)
199
+ dot = "."
200
+ cls_token_id = self.original_tokenizer.cls_token_id
201
+ sep_token_id = self.original_tokenizer.sep_token_id
202
+ question_token_id = self.original_tokenizer.question_token_id
203
+ dot_token_id = self.original_tokenizer.convert_tokens_to_ids(".")
204
+
205
+ if self.original_tokenizer.padding_side == "right":
206
+ pair = f"{cls}:0 $A:0 {question} {dot} {sep}:0 $B:1 {sep}:1"
207
+ else:
208
+ pair = f"{cls}:0 $A:0 {sep}:0 $B:1 {question} {dot} {sep}:1"
209
+
210
+ tokenizer.post_processor = processors.TemplateProcessing(
211
+ single=f"{cls}:0 $A:0 {sep}:0",
212
+ pair=pair,
213
+ special_tokens=[
214
+ (cls, cls_token_id),
215
+ (sep, sep_token_id),
216
+ (question, question_token_id),
217
+ (dot, dot_token_id),
218
+ ],
219
+ )
220
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
221
+
222
+ return tokenizer
223
+
224
+
225
+ class FunnelConverter(Converter):
226
+ def converted(self) -> Tokenizer:
227
+ vocab = self.original_tokenizer.vocab
228
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
229
+
230
+ tokenize_chinese_chars = False
231
+ strip_accents = False
232
+ do_lower_case = False
233
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
234
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
235
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
236
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
237
+
238
+ tokenizer.normalizer = normalizers.BertNormalizer(
239
+ clean_text=True,
240
+ handle_chinese_chars=tokenize_chinese_chars,
241
+ strip_accents=strip_accents,
242
+ lowercase=do_lower_case,
243
+ )
244
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
245
+
246
+ cls = str(self.original_tokenizer.cls_token)
247
+ sep = str(self.original_tokenizer.sep_token)
248
+ cls_token_id = self.original_tokenizer.cls_token_id
249
+ sep_token_id = self.original_tokenizer.sep_token_id
250
+
251
+ tokenizer.post_processor = processors.TemplateProcessing(
252
+ single=f"{cls}:2 $A:0 {sep}:0", # token_type_id is 2 for Funnel transformer
253
+ pair=f"{cls}:2 $A:0 {sep}:0 $B:1 {sep}:1",
254
+ special_tokens=[
255
+ (cls, cls_token_id),
256
+ (sep, sep_token_id),
257
+ ],
258
+ )
259
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
260
+
261
+ return tokenizer
262
+
263
+
264
+ class MPNetConverter(Converter):
265
+ def converted(self) -> Tokenizer:
266
+ vocab = self.original_tokenizer.vocab
267
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
268
+
269
+ tokenize_chinese_chars = False
270
+ strip_accents = False
271
+ do_lower_case = False
272
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
273
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
274
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
275
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
276
+
277
+ tokenizer.normalizer = normalizers.BertNormalizer(
278
+ clean_text=True,
279
+ handle_chinese_chars=tokenize_chinese_chars,
280
+ strip_accents=strip_accents,
281
+ lowercase=do_lower_case,
282
+ )
283
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
284
+
285
+ cls = str(self.original_tokenizer.cls_token)
286
+ sep = str(self.original_tokenizer.sep_token)
287
+ cls_token_id = self.original_tokenizer.cls_token_id
288
+ sep_token_id = self.original_tokenizer.sep_token_id
289
+
290
+ tokenizer.post_processor = processors.TemplateProcessing(
291
+ single=f"{cls}:0 $A:0 {sep}:0",
292
+ pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
293
+ special_tokens=[
294
+ (cls, cls_token_id),
295
+ (sep, sep_token_id),
296
+ ],
297
+ )
298
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
299
+
300
+ return tokenizer
301
+
302
+
303
+ class OpenAIGPTConverter(Converter):
304
+ def converted(self) -> Tokenizer:
305
+ vocab = self.original_tokenizer.encoder
306
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
307
+ unk_token = self.original_tokenizer.unk_token
308
+
309
+ tokenizer = Tokenizer(
310
+ BPE(
311
+ vocab=vocab,
312
+ merges=merges,
313
+ dropout=None,
314
+ unk_token=str(unk_token),
315
+ end_of_word_suffix="</w>",
316
+ fuse_unk=False,
317
+ )
318
+ )
319
+
320
+ if tokenizer.token_to_id(str(unk_token)) is not None:
321
+ tokenizer.add_special_tokens([str(unk_token)])
322
+
323
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=True)
324
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
325
+ tokenizer.decoder = decoders.BPEDecoder(suffix="</w>")
326
+
327
+ return tokenizer
328
+
329
+
330
+ class GPT2Converter(Converter):
331
+ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
332
+ if not vocab:
333
+ vocab = self.original_tokenizer.encoder
334
+ if not merges:
335
+ merges = list(self.original_tokenizer.bpe_ranks)
336
+
337
+ tokenizer = Tokenizer(
338
+ BPE(
339
+ vocab=vocab,
340
+ merges=merges,
341
+ dropout=None,
342
+ continuing_subword_prefix="",
343
+ end_of_word_suffix="",
344
+ fuse_unk=False,
345
+ )
346
+ )
347
+
348
+ add_prefix_space = getattr(self.original_tokenizer, "add_prefix_space", False)
349
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=add_prefix_space)
350
+ tokenizer.decoder = decoders.ByteLevel()
351
+ if getattr(self.original_tokenizer, "add_bos_token", False):
352
+ bos = self.original_tokenizer.bos_token
353
+ bos_token_id = self.original_tokenizer.bos_token_id
354
+ tokenizer.post_processor = processors.TemplateProcessing(
355
+ single=f"{bos}:0 $A:0",
356
+ pair=f"{bos}:0 $A:0 $B:1",
357
+ special_tokens=[
358
+ (bos, bos_token_id),
359
+ ],
360
+ )
361
+ else:
362
+ # XXX trim_offsets=False actually means this post_processor doesn't
363
+ # really do anything.
364
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
365
+ return tokenizer
366
+
367
+
368
+ class HerbertConverter(Converter):
369
+ def converted(self) -> Tokenizer:
370
+ tokenizer_info_str = "#version:"
371
+ token_suffix = "</w>"
372
+
373
+ vocab = self.original_tokenizer.encoder
374
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
375
+ if tokenizer_info_str in merges[0][0]:
376
+ merges = merges[1:]
377
+
378
+ tokenizer = Tokenizer(
379
+ BPE(
380
+ vocab,
381
+ merges,
382
+ dropout=None,
383
+ unk_token=self.original_tokenizer.unk_token,
384
+ end_of_word_suffix=token_suffix,
385
+ )
386
+ )
387
+
388
+ tokenizer.normalizer = normalizers.BertNormalizer(lowercase=False, strip_accents=False)
389
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
390
+ tokenizer.decoder = decoders.BPEDecoder(suffix=token_suffix)
391
+ tokenizer.post_processor = processors.BertProcessing(
392
+ sep=(self.original_tokenizer.sep_token, self.original_tokenizer.sep_token_id),
393
+ cls=(self.original_tokenizer.cls_token, self.original_tokenizer.cls_token_id),
394
+ )
395
+
396
+ return tokenizer
397
+
398
+
399
+ class Qwen2Converter(Converter):
400
+ def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
401
+ if not vocab:
402
+ vocab = self.original_tokenizer.encoder
403
+ if not merges:
404
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
405
+
406
+ tokenizer = Tokenizer(
407
+ BPE(
408
+ vocab=vocab,
409
+ merges=merges,
410
+ dropout=None,
411
+ unk_token=None,
412
+ continuing_subword_prefix="",
413
+ end_of_word_suffix="",
414
+ fuse_unk=False,
415
+ byte_fallback=False,
416
+ )
417
+ )
418
+
419
+ tokenizer.normalizer = normalizers.NFC()
420
+
421
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
422
+ [
423
+ pre_tokenizers.Split(
424
+ Regex(
425
+ r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
426
+ ),
427
+ behavior="isolated",
428
+ invert=False,
429
+ ),
430
+ pre_tokenizers.ByteLevel(
431
+ add_prefix_space=getattr(self.original_tokenizer, "add_prefix_space", False),
432
+ use_regex=False,
433
+ ),
434
+ ]
435
+ )
436
+
437
+ tokenizer.decoder = decoders.ByteLevel()
438
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
439
+
440
+ return tokenizer
441
+
442
+
443
+ class RobertaConverter(Converter):
444
+ def converted(self) -> Tokenizer:
445
+ ot = self.original_tokenizer
446
+ vocab = ot.encoder
447
+ merges = list(ot.bpe_ranks.keys())
448
+
449
+ tokenizer = Tokenizer(
450
+ BPE(
451
+ vocab=vocab,
452
+ merges=merges,
453
+ dropout=None,
454
+ continuing_subword_prefix="",
455
+ end_of_word_suffix="",
456
+ fuse_unk=False,
457
+ )
458
+ )
459
+
460
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
461
+ tokenizer.decoder = decoders.ByteLevel()
462
+ tokenizer.post_processor = processors.RobertaProcessing(
463
+ sep=(ot.sep_token, ot.sep_token_id),
464
+ cls=(ot.cls_token, ot.cls_token_id),
465
+ add_prefix_space=ot.add_prefix_space,
466
+ trim_offsets=True, # True by default on Roberta (historical)
467
+ )
468
+
469
+ return tokenizer
470
+
471
+
472
+ class RoFormerConverter(Converter):
473
+ def converted(self) -> Tokenizer:
474
+ from .models.roformer.tokenization_utils import JiebaPreTokenizer
475
+
476
+ vocab = self.original_tokenizer.vocab
477
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
478
+
479
+ strip_accents = False
480
+ do_lower_case = False
481
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
482
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
483
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
484
+
485
+ tokenizer.normalizer = normalizers.BertNormalizer(
486
+ clean_text=True,
487
+ handle_chinese_chars=False,
488
+ strip_accents=strip_accents,
489
+ lowercase=do_lower_case,
490
+ )
491
+ tokenizer.pre_tokenizer = pre_tokenizers.PreTokenizer.custom(JiebaPreTokenizer(vocab))
492
+
493
+ cls = str(self.original_tokenizer.cls_token)
494
+ sep = str(self.original_tokenizer.sep_token)
495
+ cls_token_id = self.original_tokenizer.cls_token_id
496
+ sep_token_id = self.original_tokenizer.sep_token_id
497
+
498
+ tokenizer.post_processor = processors.TemplateProcessing(
499
+ single=f"{cls}:0 $A:0 {sep}:0",
500
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
501
+ special_tokens=[
502
+ (cls, cls_token_id),
503
+ (sep, sep_token_id),
504
+ ],
505
+ )
506
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
507
+
508
+ return tokenizer
509
+
510
+
511
+ class DebertaConverter(Converter):
512
+ def converted(self) -> Tokenizer:
513
+ ot = self.original_tokenizer
514
+ vocab = ot.encoder
515
+ merges = list(ot.bpe_ranks.keys())
516
+
517
+ tokenizer = Tokenizer(
518
+ BPE(
519
+ vocab=vocab,
520
+ merges=merges,
521
+ dropout=None,
522
+ continuing_subword_prefix="",
523
+ end_of_word_suffix="",
524
+ fuse_unk=False,
525
+ )
526
+ )
527
+
528
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
529
+ tokenizer.decoder = decoders.ByteLevel()
530
+ tokenizer.post_processor = processors.TemplateProcessing(
531
+ single="[CLS]:0 $A:0 [SEP]:0",
532
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
533
+ special_tokens=[
534
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
535
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
536
+ ],
537
+ )
538
+
539
+ return tokenizer
540
+
541
+
542
+ class SpmConverter(Converter):
543
+ handle_byte_fallback = False
544
+ SpmExtractor = SentencePieceExtractor
545
+ special_tokens = {}
546
+
547
+ def __init__(self, *args):
548
+ requires_backends(self, "protobuf")
549
+
550
+ super().__init__(*args)
551
+
552
+ # from .utils import sentencepiece_model_pb2 as model_pb2
553
+ model_pb2 = import_protobuf()
554
+
555
+ m = model_pb2.ModelProto()
556
+ with open(self.original_tokenizer.vocab_file, "rb") as f:
557
+ m.ParseFromString(f.read())
558
+ self.proto = m
559
+
560
+ if self.proto.trainer_spec.byte_fallback and not self.handle_byte_fallback:
561
+ warnings.warn(
562
+ "The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option"
563
+ " which is not implemented in the fast tokenizers. In practice this means that the fast version of the"
564
+ " tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these "
565
+ "unknown tokens into a sequence of byte tokens matching the original piece of text."
566
+ )
567
+
568
+ def vocab(self, proto):
569
+ return [(piece.piece, piece.score) for piece in proto.pieces]
570
+
571
+ def unk_id(self, proto):
572
+ return proto.trainer_spec.unk_id
573
+
574
+ def tokenizer(self, proto):
575
+ model_type = proto.trainer_spec.model_type
576
+ vocab_scores = self.vocab(proto)
577
+
578
+ if model_type == 1:
579
+ tokenizer = Tokenizer(
580
+ Unigram(
581
+ vocab_scores,
582
+ unk_id=self.unk_id(proto),
583
+ byte_fallback=self.handle_byte_fallback,
584
+ )
585
+ )
586
+
587
+ elif model_type == 2:
588
+ _, merges = self.SpmExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
589
+ bpe_vocab = {word: i for i, (word, score) in enumerate(vocab_scores)}
590
+ tokenizer = Tokenizer(
591
+ BPE(
592
+ bpe_vocab,
593
+ merges,
594
+ unk_token=proto.trainer_spec.unk_piece,
595
+ fuse_unk=True,
596
+ byte_fallback=self.handle_byte_fallback,
597
+ dropout=None,
598
+ )
599
+ )
600
+
601
+ else:
602
+ raise Exception(
603
+ "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
604
+ )
605
+
606
+ # control tokens are special
607
+ # user defined symbols are not
608
+ # both user and control tokens are AddedTokens
609
+ # Add user defined symbols (type == 4) from sentencepiece (https://github.com/google/sentencepiece/blob/6225e08edb2577757163b3f5dbba4c0b670ef445/src/sentencepiece_model.proto#L299C29-L299C33)
610
+ spm_added_tokens = [
611
+ (id, p.piece, p.type == 3 or p.piece in self.special_tokens)
612
+ for id, p in enumerate(proto.pieces)
613
+ if p.type in [3, 4]
614
+ ]
615
+ tokenizer.add_tokens(
616
+ [
617
+ AddedToken(token, normalized=False, special=special)
618
+ for id, token, special in sorted(spm_added_tokens, key=lambda x: x[0])
619
+ ]
620
+ )
621
+
622
+ return tokenizer
623
+
624
+ def normalizer(self, proto):
625
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
626
+ _normalizers = [
627
+ normalizers.Strip(left=False, right=True), # stripping is important
628
+ normalizers.Replace(Regex(" {2,}"), "▁"),
629
+ ]
630
+ if not precompiled_charsmap:
631
+ return normalizers.Sequence(_normalizers)
632
+ else:
633
+ return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
634
+
635
+ def pre_tokenizer(self, replacement, add_prefix_space):
636
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
637
+ return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
638
+
639
+ def post_processor(self):
640
+ return None
641
+
642
+ def decoder(self, replacement, add_prefix_space):
643
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
644
+ return decoders.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme)
645
+
646
+ def converted(self) -> Tokenizer:
647
+ tokenizer = self.tokenizer(self.proto)
648
+
649
+ # Tokenizer assemble
650
+ normalizer = self.normalizer(self.proto)
651
+ if normalizer is not None:
652
+ tokenizer.normalizer = normalizer
653
+
654
+ replacement = "▁"
655
+ add_prefix_space = True
656
+ if hasattr(self.original_tokenizer, "add_prefix_space"):
657
+ add_prefix_space = self.original_tokenizer.add_prefix_space
658
+
659
+ pre_tokenizer = self.pre_tokenizer(replacement, add_prefix_space)
660
+ if pre_tokenizer is not None:
661
+ tokenizer.pre_tokenizer = pre_tokenizer
662
+
663
+ tokenizer.decoder = self.decoder(replacement, add_prefix_space)
664
+ post_processor = self.post_processor()
665
+ if post_processor:
666
+ tokenizer.post_processor = post_processor
667
+
668
+ return tokenizer
669
+
670
+
671
+ class AlbertConverter(SpmConverter):
672
+ def vocab(self, proto):
673
+ return [
674
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
675
+ for piece in proto.pieces
676
+ ]
677
+
678
+ def normalizer(self, proto):
679
+ list_normalizers = [
680
+ normalizers.Replace("``", '"'),
681
+ normalizers.Replace("''", '"'),
682
+ ]
683
+ if not self.original_tokenizer.keep_accents:
684
+ list_normalizers.append(normalizers.NFKD())
685
+ list_normalizers.append(normalizers.StripAccents())
686
+ if self.original_tokenizer.do_lower_case:
687
+ list_normalizers.append(normalizers.Lowercase())
688
+
689
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
690
+
691
+ if precompiled_charsmap:
692
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
693
+
694
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
695
+ return normalizers.Sequence(list_normalizers)
696
+
697
+ def post_processor(self):
698
+ return processors.TemplateProcessing(
699
+ single="[CLS]:0 $A:0 [SEP]:0",
700
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
701
+ special_tokens=[
702
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
703
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
704
+ ],
705
+ )
706
+
707
+
708
+ class BarthezConverter(SpmConverter):
709
+ def unk_id(self, proto):
710
+ unk_id = 3
711
+ return unk_id
712
+
713
+ def post_processor(self):
714
+ return processors.TemplateProcessing(
715
+ single="<s> $A </s>",
716
+ pair="<s> $A </s> </s> $B </s>",
717
+ special_tokens=[
718
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
719
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
720
+ ],
721
+ )
722
+
723
+
724
+ class CamembertConverter(SpmConverter):
725
+ def vocab(self, proto):
726
+ vocab = [
727
+ ("<s>NOTUSED", 0.0),
728
+ ("<pad>", 0.0),
729
+ ("</s>NOTUSED", 0.0),
730
+ ("<unk>", 0.0),
731
+ ("<unk>NOTUSED", -100),
732
+ ]
733
+ # We down-grade the original SentencePiece by -100 to avoid using it and use our added token instead
734
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[1:]]
735
+ vocab += [("<mask>", 0.0)]
736
+ return vocab
737
+
738
+ def unk_id(self, proto):
739
+ # See vocab unk position
740
+ return 3
741
+
742
+ def post_processor(self):
743
+ return processors.TemplateProcessing(
744
+ single="<s> $A </s>",
745
+ pair="<s> $A </s> </s> $B </s>",
746
+ special_tokens=[
747
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
748
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
749
+ ],
750
+ )
751
+
752
+
753
+ class DebertaV2Converter(SpmConverter):
754
+ def pre_tokenizer(self, replacement, add_prefix_space):
755
+ list_pretokenizers = []
756
+ if self.original_tokenizer.split_by_punct:
757
+ list_pretokenizers.append(pre_tokenizers.Punctuation(behavior="isolated"))
758
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
759
+ list_pretokenizers.append(pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme))
760
+ return pre_tokenizers.Sequence(list_pretokenizers)
761
+
762
+ def normalizer(self, proto):
763
+ list_normalizers = []
764
+ if self.original_tokenizer.do_lower_case:
765
+ list_normalizers.append(normalizers.Lowercase())
766
+ list_normalizers.append(normalizers.Strip())
767
+
768
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
769
+ if precompiled_charsmap:
770
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
771
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
772
+
773
+ return normalizers.Sequence(list_normalizers)
774
+
775
+ def post_processor(self):
776
+ return processors.TemplateProcessing(
777
+ single="[CLS]:0 $A:0 [SEP]:0",
778
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
779
+ special_tokens=[
780
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
781
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
782
+ ],
783
+ )
784
+
785
+
786
+ class MBartConverter(SpmConverter):
787
+ def vocab(self, proto):
788
+ vocab = [
789
+ ("<s>", 0.0),
790
+ ("<pad>", 0.0),
791
+ ("</s>", 0.0),
792
+ ("<unk>", 0.0),
793
+ ]
794
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
795
+ vocab += [
796
+ ("ar_AR", 0.0),
797
+ ("cs_CZ", 0.0),
798
+ ("de_DE", 0.0),
799
+ ("en_XX", 0.0),
800
+ ("es_XX", 0.0),
801
+ ("et_EE", 0.0),
802
+ ("fi_FI", 0.0),
803
+ ("fr_XX", 0.0),
804
+ ("gu_IN", 0.0),
805
+ ("hi_IN", 0.0),
806
+ ("it_IT", 0.0),
807
+ ("ja_XX", 0.0),
808
+ ("kk_KZ", 0.0),
809
+ ("ko_KR", 0.0),
810
+ ("lt_LT", 0.0),
811
+ ("lv_LV", 0.0),
812
+ ("my_MM", 0.0),
813
+ ("ne_NP", 0.0),
814
+ ("nl_XX", 0.0),
815
+ ("ro_RO", 0.0),
816
+ ("ru_RU", 0.0),
817
+ ("si_LK", 0.0),
818
+ ("tr_TR", 0.0),
819
+ ("vi_VN", 0.0),
820
+ ("zh_CN", 0.0),
821
+ ]
822
+ vocab += [("<mask>", 0.0)]
823
+ return vocab
824
+
825
+ def unk_id(self, proto):
826
+ return 3
827
+
828
+ def post_processor(self):
829
+ return processors.TemplateProcessing(
830
+ single="$A </s> en_XX",
831
+ pair="$A $B </s> en_XX",
832
+ special_tokens=[
833
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
834
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
835
+ ],
836
+ )
837
+
838
+
839
+ class MBart50Converter(SpmConverter):
840
+ def vocab(self, proto):
841
+ vocab = [
842
+ ("<s>", 0.0),
843
+ ("<pad>", 0.0),
844
+ ("</s>", 0.0),
845
+ ("<unk>", 0.0),
846
+ ]
847
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
848
+ vocab += [("ar_AR", 0.0), ("cs_CZ", 0.0), ("de_DE", 0.0), ("en_XX", 0.0), ("es_XX", 0.0), ("et_EE", 0.0), ("fi_FI", 0.0), ("fr_XX", 0.0), ("gu_IN", 0.0), ("hi_IN", 0.0), ("it_IT", 0.0), ("ja_XX", 0.0), ("kk_KZ", 0.0), ("ko_KR", 0.0), ("lt_LT", 0.0), ("lv_LV", 0.0), ("my_MM", 0.0), ("ne_NP", 0.0), ("nl_XX", 0.0), ("ro_RO", 0.0), ("ru_RU", 0.0), ("si_LK", 0.0), ("tr_TR", 0.0), ("vi_VN", 0.0), ("zh_CN", 0.0), ("af_ZA", 0.0), ("az_AZ", 0.0), ("bn_IN", 0.0), ("fa_IR", 0.0), ("he_IL", 0.0), ("hr_HR", 0.0), ("id_ID", 0.0), ("ka_GE", 0.0), ("km_KH", 0.0), ("mk_MK", 0.0), ("ml_IN", 0.0), ("mn_MN", 0.0), ("mr_IN", 0.0), ("pl_PL", 0.0), ("ps_AF", 0.0), ("pt_XX", 0.0), ("sv_SE", 0.0), ("sw_KE", 0.0), ("ta_IN", 0.0), ("te_IN", 0.0), ("th_TH", 0.0), ("tl_XX", 0.0), ("uk_UA", 0.0), ("ur_PK", 0.0), ("xh_ZA", 0.0), ("gl_ES", 0.0), ("sl_SI", 0.0)] # fmt: skip
849
+ vocab += [("<mask>", 0.0)]
850
+ return vocab
851
+
852
+ def unk_id(self, proto):
853
+ return 3
854
+
855
+ def post_processor(self):
856
+ return processors.TemplateProcessing(
857
+ single="en_XX $A </s>",
858
+ pair="en_XX $A $B </s>",
859
+ special_tokens=[
860
+ ("en_XX", self.original_tokenizer.convert_tokens_to_ids("en_XX")),
861
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
862
+ ],
863
+ )
864
+
865
+
866
+ class NllbConverter(SpmConverter):
867
+ def vocab(self, proto):
868
+ vocab = [
869
+ ("<s>", 0.0),
870
+ ("<pad>", 0.0),
871
+ ("</s>", 0.0),
872
+ ("<unk>", 0.0),
873
+ ]
874
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
875
+ return vocab
876
+
877
+ def unk_id(self, proto):
878
+ return 3
879
+
880
+ def post_processor(self):
881
+ return processors.TemplateProcessing(
882
+ single="eng_Latn $A </s>",
883
+ pair="eng_Latn $A $B </s>",
884
+ special_tokens=[
885
+ ("eng_Latn", self.original_tokenizer.convert_tokens_to_ids("eng_Latn")),
886
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
887
+ ],
888
+ )
889
+
890
+
891
+ class SeamlessM4TConverter(SpmConverter):
892
+ def vocab(self, proto):
893
+ vocab = [
894
+ ("<pad>", 0.0),
895
+ ("<unk>", 0.0),
896
+ ("<s>", 0.0),
897
+ ("</s>", 0.0),
898
+ ]
899
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
900
+ return vocab
901
+
902
+ def unk_id(self, proto):
903
+ return self.original_tokenizer.unk_token_id
904
+
905
+ def post_processor(self):
906
+ return processors.TemplateProcessing(
907
+ single="__eng__ $A </s>",
908
+ pair="__eng__ $A $B </s>",
909
+ special_tokens=[
910
+ ("__eng__", self.original_tokenizer.convert_tokens_to_ids("__eng__")),
911
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
912
+ ],
913
+ )
914
+
915
+
916
+ class XLMRobertaConverter(SpmConverter):
917
+ def vocab(self, proto):
918
+ vocab = [
919
+ ("<s>", 0.0),
920
+ ("<pad>", 0.0),
921
+ ("</s>", 0.0),
922
+ ("<unk>", 0.0),
923
+ ]
924
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
925
+ vocab += [("<mask>", 0.0)]
926
+ return vocab
927
+
928
+ def unk_id(self, proto):
929
+ unk_id = 3
930
+ return unk_id
931
+
932
+ def post_processor(self):
933
+ return processors.TemplateProcessing(
934
+ single="<s> $A </s>",
935
+ pair="<s> $A </s> </s> $B </s>",
936
+ special_tokens=[
937
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
938
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
939
+ ],
940
+ )
941
+
942
+
943
+ class XLNetConverter(SpmConverter):
944
+ def vocab(self, proto):
945
+ return [
946
+ (piece.piece, piece.score) if check_number_comma(piece.piece) else (piece.piece, piece.score - 100)
947
+ for piece in proto.pieces
948
+ ]
949
+
950
+ def normalizer(self, proto):
951
+ list_normalizers = [
952
+ normalizers.Replace("``", '"'),
953
+ normalizers.Replace("''", '"'),
954
+ ]
955
+ if not self.original_tokenizer.keep_accents:
956
+ list_normalizers.append(normalizers.NFKD())
957
+ list_normalizers.append(normalizers.StripAccents())
958
+ if self.original_tokenizer.do_lower_case:
959
+ list_normalizers.append(normalizers.Lowercase())
960
+
961
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
962
+
963
+ if precompiled_charsmap:
964
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
965
+
966
+ list_normalizers.append(normalizers.Replace(Regex(" {2,}"), " "))
967
+ return normalizers.Sequence(list_normalizers)
968
+
969
+ def post_processor(self):
970
+ return processors.TemplateProcessing(
971
+ single="$A:0 <sep>:0 <cls>:2",
972
+ pair="$A:0 <sep>:0 $B:1 <sep>:1 <cls>:2",
973
+ special_tokens=[
974
+ ("<sep>", self.original_tokenizer.convert_tokens_to_ids("<sep>")),
975
+ ("<cls>", self.original_tokenizer.convert_tokens_to_ids("<cls>")),
976
+ ],
977
+ )
978
+
979
+
980
+ class ReformerConverter(SpmConverter):
981
+ pass
982
+
983
+
984
+ class RemBertConverter(SpmConverter):
985
+ # Inspired from AlbertConverter
986
+ def normalizer(self, proto):
987
+ list_normalizers = [
988
+ normalizers.Replace("``", '"'),
989
+ normalizers.Replace("''", '"'),
990
+ normalizers.Replace(Regex(" {2,}"), " "),
991
+ ]
992
+ if not self.original_tokenizer.keep_accents:
993
+ list_normalizers.append(normalizers.NFKD())
994
+ list_normalizers.append(normalizers.StripAccents())
995
+ if self.original_tokenizer.do_lower_case:
996
+ list_normalizers.append(normalizers.Lowercase())
997
+
998
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
999
+
1000
+ if precompiled_charsmap:
1001
+ list_normalizers.append(normalizers.Precompiled(precompiled_charsmap))
1002
+
1003
+ return normalizers.Sequence(list_normalizers)
1004
+
1005
+ def post_processor(self):
1006
+ return processors.TemplateProcessing(
1007
+ single="[CLS]:0 $A:0 [SEP]:0",
1008
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
1009
+ special_tokens=[
1010
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
1011
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
1012
+ ],
1013
+ )
1014
+
1015
+
1016
+ class BertGenerationConverter(SpmConverter):
1017
+ pass
1018
+
1019
+
1020
+ class PegasusConverter(SpmConverter):
1021
+ def vocab(self, proto):
1022
+ vocab = [
1023
+ (self.original_tokenizer.pad_token, 0.0),
1024
+ (self.original_tokenizer.eos_token, 0.0),
1025
+ ]
1026
+
1027
+ if self.original_tokenizer.mask_token_sent is not None:
1028
+ vocab += [(self.original_tokenizer.mask_token_sent, 0.0)]
1029
+
1030
+ if (
1031
+ self.original_tokenizer.mask_token is not None
1032
+ and self.original_tokenizer.mask_token_id < self.original_tokenizer.offset
1033
+ ):
1034
+ vocab += [(self.original_tokenizer.mask_token, 0.0)]
1035
+
1036
+ vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
1037
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
1038
+ return vocab
1039
+
1040
+ def unk_id(self, proto):
1041
+ return proto.trainer_spec.unk_id + self.original_tokenizer.offset
1042
+
1043
+ def pre_tokenizer(self, replacement, add_prefix_space):
1044
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
1045
+ return pre_tokenizers.Sequence(
1046
+ [
1047
+ pre_tokenizers.WhitespaceSplit(),
1048
+ pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme),
1049
+ ]
1050
+ )
1051
+
1052
+ def post_processor(self):
1053
+ eos = self.original_tokenizer.eos_token
1054
+ special_tokens = [
1055
+ (eos, self.original_tokenizer.eos_token_id),
1056
+ ]
1057
+ return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)
1058
+
1059
+
1060
+ class T5Converter(SpmConverter):
1061
+ def vocab(self, proto):
1062
+ num_extra_ids = self.original_tokenizer._extra_ids
1063
+ vocab = [(piece.piece, piece.score) for piece in proto.pieces]
1064
+ vocab += [(f"<extra_id_{i}>", 0.0) for i in range(num_extra_ids - 1, -1, -1)]
1065
+ return vocab
1066
+
1067
+ def post_processor(self):
1068
+ return processors.TemplateProcessing(
1069
+ single=["$A", "</s>"],
1070
+ pair=["$A", "</s>", "$B", "</s>"],
1071
+ special_tokens=[
1072
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
1073
+ ],
1074
+ )
1075
+
1076
+
1077
+ class UdopConverter(SpmConverter):
1078
+ def post_processor(self):
1079
+ return processors.TemplateProcessing(
1080
+ single=["$A", "</s>"],
1081
+ pair=["$A", "</s>", "$B", "</s>"],
1082
+ special_tokens=[
1083
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
1084
+ ],
1085
+ )
1086
+
1087
+
1088
+ class WhisperConverter(Converter):
1089
+ def converted(self) -> Tokenizer:
1090
+ vocab = self.original_tokenizer.encoder
1091
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
1092
+
1093
+ tokenizer = Tokenizer(
1094
+ BPE(
1095
+ vocab=vocab,
1096
+ merges=merges,
1097
+ dropout=None,
1098
+ continuing_subword_prefix="",
1099
+ end_of_word_suffix="",
1100
+ fuse_unk=False,
1101
+ )
1102
+ )
1103
+
1104
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=self.original_tokenizer.add_prefix_space)
1105
+ tokenizer.decoder = decoders.ByteLevel()
1106
+
1107
+ prefix_token_ids = self.original_tokenizer.prefix_tokens
1108
+ prefixes = self.original_tokenizer.convert_ids_to_tokens(prefix_token_ids)
1109
+ eos = self.original_tokenizer.eos_token
1110
+ eos_token_id = self.original_tokenizer.eos_token_id
1111
+ prefix_template = " ".join([f"{token}:0" for token in prefixes])
1112
+ tokenizer.post_processor = processors.TemplateProcessing(
1113
+ single=f"{prefix_template} $A:0 {eos}:0",
1114
+ pair=f"{prefix_template} $A:0 $B:1 {eos}:1",
1115
+ special_tokens=[
1116
+ (eos, eos_token_id),
1117
+ *zip(prefixes, prefix_token_ids),
1118
+ ],
1119
+ )
1120
+
1121
+ return tokenizer
1122
+
1123
+
1124
+ class BigBirdConverter(SpmConverter):
1125
+ def post_processor(self):
1126
+ return processors.TemplateProcessing(
1127
+ single="[CLS]:0 $A:0 [SEP]:0",
1128
+ pair="[CLS]:0 $A:0 [SEP]:0 $B:1 [SEP]:1",
1129
+ special_tokens=[
1130
+ ("[CLS]", self.original_tokenizer.convert_tokens_to_ids("[CLS]")),
1131
+ ("[SEP]", self.original_tokenizer.convert_tokens_to_ids("[SEP]")),
1132
+ ],
1133
+ )
1134
+
1135
+
1136
+ class CLIPConverter(Converter):
1137
+ def converted(self) -> Tokenizer:
1138
+ vocab = self.original_tokenizer.encoder
1139
+ merges = list(self.original_tokenizer.bpe_ranks.keys())
1140
+ unk_token = self.original_tokenizer.unk_token
1141
+
1142
+ tokenizer = Tokenizer(
1143
+ BPE(
1144
+ vocab=vocab,
1145
+ merges=merges,
1146
+ dropout=None,
1147
+ continuing_subword_prefix="",
1148
+ end_of_word_suffix="</w>",
1149
+ fuse_unk=False,
1150
+ unk_token=str(unk_token),
1151
+ )
1152
+ )
1153
+
1154
+ tokenizer.normalizer = normalizers.Sequence(
1155
+ [normalizers.NFC(), normalizers.Replace(Regex(r"\s+"), " "), normalizers.Lowercase()]
1156
+ )
1157
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
1158
+ [
1159
+ pre_tokenizers.Split(
1160
+ Regex(r"""'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""),
1161
+ behavior="removed",
1162
+ invert=True,
1163
+ ),
1164
+ pre_tokenizers.ByteLevel(add_prefix_space=False),
1165
+ ]
1166
+ )
1167
+ tokenizer.decoder = decoders.ByteLevel()
1168
+
1169
+ # Hack to have a ByteLevel and TemplateProcessor
1170
+ tokenizer.post_processor = processors.RobertaProcessing(
1171
+ sep=(self.original_tokenizer.eos_token, self.original_tokenizer.eos_token_id),
1172
+ cls=(self.original_tokenizer.bos_token, self.original_tokenizer.bos_token_id),
1173
+ add_prefix_space=False,
1174
+ trim_offsets=False,
1175
+ )
1176
+ return tokenizer
1177
+
1178
+
1179
+ class LayoutLMv2Converter(Converter):
1180
+ def converted(self) -> Tokenizer:
1181
+ vocab = self.original_tokenizer.vocab
1182
+ tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
1183
+
1184
+ tokenize_chinese_chars = False
1185
+ strip_accents = False
1186
+ do_lower_case = True
1187
+ if hasattr(self.original_tokenizer, "basic_tokenizer"):
1188
+ tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
1189
+ strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
1190
+ do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
1191
+
1192
+ tokenizer.normalizer = normalizers.BertNormalizer(
1193
+ clean_text=True,
1194
+ handle_chinese_chars=tokenize_chinese_chars,
1195
+ strip_accents=strip_accents,
1196
+ lowercase=do_lower_case,
1197
+ )
1198
+ tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
1199
+
1200
+ cls = str(self.original_tokenizer.cls_token)
1201
+ sep = str(self.original_tokenizer.sep_token)
1202
+ cls_token_id = self.original_tokenizer.cls_token_id
1203
+ sep_token_id = self.original_tokenizer.sep_token_id
1204
+
1205
+ tokenizer.post_processor = processors.TemplateProcessing(
1206
+ single=f"{cls}:0 $A:0 {sep}:0",
1207
+ pair=f"{cls}:0 $A:0 {sep}:0 $B:1 {sep}:1",
1208
+ special_tokens=[
1209
+ (cls, cls_token_id),
1210
+ (sep, sep_token_id),
1211
+ ],
1212
+ )
1213
+ tokenizer.decoder = decoders.WordPiece(prefix="##")
1214
+
1215
+ return tokenizer
1216
+
1217
+
1218
+ class BlenderbotConverter(Converter):
1219
+ def converted(self) -> Tokenizer:
1220
+ ot = self.original_tokenizer
1221
+ vocab = ot.encoder
1222
+ merges = list(ot.bpe_ranks.keys())
1223
+
1224
+ tokenizer = Tokenizer(
1225
+ BPE(
1226
+ vocab=vocab,
1227
+ merges=merges,
1228
+ dropout=None,
1229
+ continuing_subword_prefix="",
1230
+ end_of_word_suffix="",
1231
+ fuse_unk=False,
1232
+ )
1233
+ )
1234
+
1235
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
1236
+ tokenizer.decoder = decoders.ByteLevel()
1237
+ tokenizer.post_processor = processors.TemplateProcessing(
1238
+ single=f"$A:0 {ot.eos_token}:0",
1239
+ special_tokens=[
1240
+ (ot.eos_token, ot.eos_token_id),
1241
+ ],
1242
+ )
1243
+
1244
+ return tokenizer
1245
+
1246
+
1247
+ class XGLMConverter(SpmConverter):
1248
+ def vocab(self, proto):
1249
+ vocab = [
1250
+ ("<s>", 0.0),
1251
+ ("<pad>", 0.0),
1252
+ ("</s>", 0.0),
1253
+ ("<unk>", 0.0),
1254
+ ]
1255
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
1256
+ vocab += [("<madeupword0>", 0.0), ("<madeupword1>", 0.0), ("<madeupword2>", 0.0), ("<madeupword3>", 0.0), ("<madeupword4>", 0.0), ("<madeupword5>", 0.0), ("<madeupword6>", 0.0)] # fmt: skip
1257
+ return vocab
1258
+
1259
+ def unk_id(self, proto):
1260
+ unk_id = 3
1261
+ return unk_id
1262
+
1263
+ def post_processor(self):
1264
+ return processors.TemplateProcessing(
1265
+ single="</s> $A",
1266
+ pair="</s> $A </s> </s> $B",
1267
+ special_tokens=[
1268
+ ("<s>", self.original_tokenizer.convert_tokens_to_ids("<s>")),
1269
+ ("</s>", self.original_tokenizer.convert_tokens_to_ids("</s>")),
1270
+ ],
1271
+ )
1272
+
1273
+
1274
+ class GemmaConverter(SpmConverter):
1275
+ handle_byte_fallback = True
1276
+ SpmExtractor = GemmaSentencePieceExtractor
1277
+ # start and end of turn tokens must be marked as special
1278
+ special_tokens = {"<start_of_turn>", "<end_of_turn>"}
1279
+
1280
+ """
1281
+ split_by_unicode_script: true
1282
+ split_by_number: true
1283
+ split_by_whitespace: true
1284
+ treat_whitespace_as_suffix: false
1285
+ allow_whitespace_only_pieces: true
1286
+ split_digits: true
1287
+ byte_fallback: true
1288
+ """
1289
+
1290
+ def normalizer(self, proto):
1291
+ return normalizers.Replace(" ", "▁")
1292
+
1293
+ def vocab(self, proto):
1294
+ vocab = [
1295
+ (self.original_tokenizer.pad_token, 0.0),
1296
+ (self.original_tokenizer.eos_token, 0.0),
1297
+ (self.original_tokenizer.bos_token, 0.0),
1298
+ ]
1299
+ for piece in proto.pieces[3:]:
1300
+ if piece.piece == "<0x09>":
1301
+ vocab += [("\t", piece.score)]
1302
+ else:
1303
+ vocab += [(piece.piece, piece.score)]
1304
+ # vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
1305
+ return vocab
1306
+
1307
+ def pre_tokenizer(self, replacement, add_prefix_space):
1308
+ return pre_tokenizers.Split(" ", "merged_with_previous")
1309
+
1310
+ def unk_id(self, proto):
1311
+ unk_id = 3
1312
+ return unk_id
1313
+
1314
+ def decoder(self, replacement, add_prefix_space):
1315
+ return decoders.Sequence(
1316
+ [
1317
+ decoders.Replace("▁", " "),
1318
+ decoders.ByteFallback(),
1319
+ decoders.Fuse(),
1320
+ ]
1321
+ )
1322
+
1323
+
1324
+ class LlamaConverter(SpmConverter):
1325
+ handle_byte_fallback = True
1326
+
1327
+ def vocab(self, proto):
1328
+ vocab = [
1329
+ (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
1330
+ (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
1331
+ (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
1332
+ ]
1333
+ vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
1334
+ return vocab
1335
+
1336
+ def unk_id(self, proto):
1337
+ unk_id = 0
1338
+ return unk_id
1339
+
1340
+ def decoder(self, replacement, add_prefix_space):
1341
+ sequence = [
1342
+ decoders.Replace("▁", " "),
1343
+ decoders.ByteFallback(),
1344
+ decoders.Fuse(),
1345
+ ]
1346
+ if add_prefix_space:
1347
+ sequence += [decoders.Strip(content=" ", left=1)]
1348
+ return decoders.Sequence(sequence)
1349
+
1350
+ def normalizer(self, proto):
1351
+ if getattr(self.original_tokenizer, "legacy", True):
1352
+ sequence = []
1353
+ if getattr(self.original_tokenizer, "add_prefix_space", True):
1354
+ sequence += [normalizers.Prepend(prepend="▁")]
1355
+ sequence += [normalizers.Replace(pattern=" ", content="▁")]
1356
+ return normalizers.Sequence(sequence)
1357
+ return None # non-legacy, no normalizer
1358
+
1359
+ def pre_tokenizer(self, replacement, add_prefix_space):
1360
+ if not getattr(self.original_tokenizer, "legacy", True): # non-legacy, we need a replace
1361
+ prepend_scheme = _get_prepend_scheme(add_prefix_space, self.original_tokenizer)
1362
+ return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
1363
+ return None
1364
+
1365
+ def post_processor(self):
1366
+ # the processor is defined in the LlamaTokenizerFast class.
1367
+ return None
1368
+
1369
+
1370
+ class MarkupLMConverter(Converter):
1371
+ def converted(self) -> Tokenizer:
1372
+ ot = self.original_tokenizer
1373
+ vocab = ot.encoder
1374
+ merges = list(ot.bpe_ranks.keys())
1375
+
1376
+ tokenizer = Tokenizer(
1377
+ BPE(
1378
+ vocab=vocab,
1379
+ merges=merges,
1380
+ dropout=None,
1381
+ continuing_subword_prefix="",
1382
+ end_of_word_suffix="",
1383
+ fuse_unk=False,
1384
+ unk_token=self.original_tokenizer.unk_token,
1385
+ )
1386
+ )
1387
+
1388
+ tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=ot.add_prefix_space)
1389
+ tokenizer.decoder = decoders.ByteLevel()
1390
+
1391
+ cls = str(self.original_tokenizer.cls_token)
1392
+ sep = str(self.original_tokenizer.sep_token)
1393
+ cls_token_id = self.original_tokenizer.cls_token_id
1394
+ sep_token_id = self.original_tokenizer.sep_token_id
1395
+
1396
+ tokenizer.post_processor = processors.TemplateProcessing(
1397
+ single=f"{cls} $A {sep}",
1398
+ pair=f"{cls} $A {sep} $B {sep}",
1399
+ special_tokens=[
1400
+ (cls, cls_token_id),
1401
+ (sep, sep_token_id),
1402
+ ],
1403
+ )
1404
+
1405
+ return tokenizer
1406
+
1407
+
1408
+ class MoshiConverter(SpmConverter):
1409
+ handle_byte_fallback = True
1410
+
1411
+ def __init__(self, vocab_file, model_max_length=None, **kwargs):
1412
+ requires_backends(self, "protobuf")
1413
+
1414
+ Converter.__init__(self, vocab_file)
1415
+
1416
+ # from .utils import sentencepiece_model_pb2 as model_pb2
1417
+ model_pb2 = import_protobuf()
1418
+
1419
+ m = model_pb2.ModelProto()
1420
+ with open(vocab_file, "rb") as f:
1421
+ m.ParseFromString(f.read())
1422
+ self.proto = m
1423
+
1424
+ def normalizer(self, proto):
1425
+ precompiled_charsmap = proto.normalizer_spec.precompiled_charsmap
1426
+ _normalizers = [
1427
+ normalizers.Replace(" ", "▁"),
1428
+ ]
1429
+ if not precompiled_charsmap:
1430
+ return normalizers.Sequence(_normalizers)
1431
+ else:
1432
+ return normalizers.Sequence([normalizers.Precompiled(precompiled_charsmap)] + _normalizers)
1433
+
1434
+ def decoder(self, replacement, add_prefix_space):
1435
+ sequence = [
1436
+ decoders.Replace("▁", " "),
1437
+ decoders.ByteFallback(),
1438
+ decoders.Fuse(),
1439
+ ]
1440
+ if add_prefix_space:
1441
+ sequence += [decoders.Strip(content=" ", left=1)]
1442
+ return decoders.Sequence(sequence)
1443
+
1444
+ def pre_tokenizer(self, replacement, add_prefix_space):
1445
+ prepend_scheme = "first"
1446
+ return pre_tokenizers.Metaspace(replacement=replacement, prepend_scheme=prepend_scheme, split=False)
1447
+
1448
+
1449
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
1450
+ def bytes_to_unicode():
1451
+ """
1452
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
1453
+ characters the bpe code barfs on.
1454
+
1455
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
1456
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
1457
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
1458
+ tables between utf-8 bytes and unicode strings.
1459
+ """
1460
+ bs = (
1461
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
1462
+ )
1463
+ cs = bs[:]
1464
+ n = 0
1465
+ for b in range(2**8):
1466
+ if b not in bs:
1467
+ bs.append(b)
1468
+ cs.append(2**8 + n)
1469
+ n += 1
1470
+ cs = [chr(n) for n in cs]
1471
+ return dict(zip(bs, cs))
1472
+
1473
+
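A quick illustration of the mapping built by `bytes_to_unicode` (a sketch for the reader, not part of the file above):

```python
# Every one of the 256 byte values gets a printable unicode character, so arbitrary
# bytes can be stored losslessly as vocabulary strings; printable ASCII maps to itself.
byte_encoder = bytes_to_unicode()
assert len(byte_encoder) == 256
assert byte_encoder[ord("A")] == "A"                    # printable bytes are unchanged
byte_decoder = {v: k for k, v in byte_encoder.items()}  # the mapping is reversible
print(byte_encoder[32])                                 # the space byte is remapped (e.g. "Ġ")
```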
1474
+ class TikTokenConverter:
1475
+ """
1476
+ A general tiktoken converter.
1477
+ """
1478
+
1479
+ def __init__(
1480
+ self,
1481
+ vocab_file=None,
1482
+ pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
1483
+ add_prefix_space=False,
1484
+ additional_special_tokens=None,
1485
+ *args,
1486
+ **kwargs,
1487
+ ):
1488
+ super().__init__(*args)
1489
+ self.vocab_file = vocab_file
1490
+ self.pattern = pattern
1491
+ self.add_prefix_space = add_prefix_space
1492
+ self.additional_special_tokens = additional_special_tokens
1493
+
1494
+ def extract_vocab_merges_from_model(self, tiktoken_url: str):
1495
+ try:
1496
+ from tiktoken.load import load_tiktoken_bpe
1497
+ except Exception:
1498
+ raise ValueError(
1499
+ "`tiktoken` is required to read a `tiktoken` file. Install it with " "`pip install tiktoken`."
1500
+ )
1501
+
1502
+ bpe_ranks = load_tiktoken_bpe(tiktoken_url)
1503
+ byte_encoder = bytes_to_unicode()
1504
+
1505
+ def token_bytes_to_string(b):
1506
+ return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
1507
+
1508
+ merges = []
1509
+ vocab = {}
1510
+ for token, rank in bpe_ranks.items():
1511
+ vocab[token_bytes_to_string(token)] = rank
1512
+ if len(token) == 1:
1513
+ continue
1514
+ local = []
1515
+ for index in range(1, len(token)):
1516
+ piece_l, piece_r = token[:index], token[index:]
1517
+ if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
1518
+ local.append((piece_l, piece_r, rank))
1519
+ local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
1520
+ merges.extend(local)
1521
+ merges = sorted(merges, key=lambda val: val[2], reverse=False)
1522
+ merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
1523
+ return vocab, merges
1524
+
1525
+ def tokenizer(self):
1526
+ vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab_file)
1527
+ tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
1528
+ if hasattr(tokenizer.model, "ignore_merges"):
1529
+ tokenizer.model.ignore_merges = True
1530
+ return tokenizer
1531
+
1532
+ def converted(self) -> Tokenizer:
1533
+ tokenizer = self.tokenizer()
1534
+ tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
1535
+ [
1536
+ pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
1537
+ pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
1538
+ ]
1539
+ )
1540
+ tokenizer.decoder = decoders.ByteLevel()
1541
+ tokenizer.add_special_tokens(self.additional_special_tokens)
1542
+
1543
+ tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
1544
+
1545
+ return tokenizer
1546
+
1547
+
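And a hedged sketch of driving `TikTokenConverter` directly; the file path and the special token below are placeholders, not values this module requires:

```python
# Sketch: build a fast tokenizer backend from a tiktoken BPE file (requires `tiktoken`).
converter = TikTokenConverter(
    vocab_file="/path/to/tokenizer.model",         # placeholder tiktoken file
    additional_special_tokens=["<|endoftext|>"],   # example special token
)
backend = converter.converted()  # returns a tokenizers.Tokenizer with ByteLevel pre-tokenization
```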
1548
+ SLOW_TO_FAST_CONVERTERS = {
1549
+ "AlbertTokenizer": AlbertConverter,
1550
+ "BartTokenizer": RobertaConverter,
1551
+ "BarthezTokenizer": BarthezConverter,
1552
+ "BertTokenizer": BertConverter,
1553
+ "BigBirdTokenizer": BigBirdConverter,
1554
+ "BlenderbotTokenizer": BlenderbotConverter,
1555
+ "CamembertTokenizer": CamembertConverter,
1556
+ "CLIPTokenizer": CLIPConverter,
1557
+ "CodeGenTokenizer": GPT2Converter,
1558
+ "ConvBertTokenizer": BertConverter,
1559
+ "DebertaTokenizer": DebertaConverter,
1560
+ "DebertaV2Tokenizer": DebertaV2Converter,
1561
+ "DistilBertTokenizer": BertConverter,
1562
+ "DPRReaderTokenizer": BertConverter,
1563
+ "DPRQuestionEncoderTokenizer": BertConverter,
1564
+ "DPRContextEncoderTokenizer": BertConverter,
1565
+ "ElectraTokenizer": BertConverter,
1566
+ "FNetTokenizer": AlbertConverter,
1567
+ "FunnelTokenizer": FunnelConverter,
1568
+ "GPT2Tokenizer": GPT2Converter,
1569
+ "HerbertTokenizer": HerbertConverter,
1570
+ "LayoutLMTokenizer": BertConverter,
1571
+ "LayoutLMv2Tokenizer": BertConverter,
1572
+ "LayoutLMv3Tokenizer": RobertaConverter,
1573
+ "LayoutXLMTokenizer": XLMRobertaConverter,
1574
+ "LongformerTokenizer": RobertaConverter,
1575
+ "LEDTokenizer": RobertaConverter,
1576
+ "LxmertTokenizer": BertConverter,
1577
+ "MarkupLMTokenizer": MarkupLMConverter,
1578
+ "MBartTokenizer": MBartConverter,
1579
+ "MBart50Tokenizer": MBart50Converter,
1580
+ "MPNetTokenizer": MPNetConverter,
1581
+ "MobileBertTokenizer": BertConverter,
1582
+ "MvpTokenizer": RobertaConverter,
1583
+ "NllbTokenizer": NllbConverter,
1584
+ "OpenAIGPTTokenizer": OpenAIGPTConverter,
1585
+ "PegasusTokenizer": PegasusConverter,
1586
+ "Qwen2Tokenizer": Qwen2Converter,
1587
+ "RealmTokenizer": BertConverter,
1588
+ "ReformerTokenizer": ReformerConverter,
1589
+ "RemBertTokenizer": RemBertConverter,
1590
+ "RetriBertTokenizer": BertConverter,
1591
+ "RobertaTokenizer": RobertaConverter,
1592
+ "RoFormerTokenizer": RoFormerConverter,
1593
+ "SeamlessM4TTokenizer": SeamlessM4TConverter,
1594
+ "SqueezeBertTokenizer": BertConverter,
1595
+ "T5Tokenizer": T5Converter,
1596
+ "UdopTokenizer": UdopConverter,
1597
+ "WhisperTokenizer": WhisperConverter,
1598
+ "XLMRobertaTokenizer": XLMRobertaConverter,
1599
+ "XLNetTokenizer": XLNetConverter,
1600
+ "SplinterTokenizer": SplinterConverter,
1601
+ "XGLMTokenizer": XGLMConverter,
1602
+ "LlamaTokenizer": LlamaConverter,
1603
+ "CodeLlamaTokenizer": LlamaConverter,
1604
+ "GemmaTokenizer": GemmaConverter,
1605
+ "Phi3Tokenizer": LlamaConverter,
1606
+ }
1607
+
1608
+
1609
+ def convert_slow_tokenizer(transformer_tokenizer, from_tiktoken=False) -> Tokenizer:
1610
+ """
1611
+ Utilities to convert a slow tokenizer instance into a fast tokenizer instance.
1612
+
1613
+ Args:
1614
+ transformer_tokenizer ([`~tokenization_utils_base.PreTrainedTokenizer`]):
1615
+ Instance of a slow tokenizer to convert in the backend tokenizer for
1616
+ [`~tokenization_utils_base.PreTrainedTokenizerFast`].
1617
+ from_tiktoken (bool, optional): Whether to use the `tiktoken` library to convert the tokenizer instead of sentencepiece.
1618
+ Defaults to False.
1619
+
1620
+ Return:
1621
+ An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
1622
+ [`~tokenization_utils_base.PreTrainedTokenizerFast`]
1623
+ """
1624
+
1625
+ tokenizer_class_name = transformer_tokenizer.__class__.__name__
1626
+ if tokenizer_class_name in SLOW_TO_FAST_CONVERTERS and not from_tiktoken:
1627
+ converter_class = SLOW_TO_FAST_CONVERTERS[tokenizer_class_name]
1628
+ return converter_class(transformer_tokenizer).converted()
1629
+
1630
+ else:
1631
+ try:
1632
+ logger.info("Converting from Tiktoken")
1633
+ return TikTokenConverter(
1634
+ vocab_file=transformer_tokenizer.vocab_file,
1635
+ additional_special_tokens=transformer_tokenizer.additional_special_tokens,
1636
+ ).converted()
1637
+ except Exception:
1638
+ raise ValueError(
1639
+ f"Converting from Tiktoken failed. If a converter for SentencePiece is available, provide a model path "
1640
+ f"with a SentencePiece tokenizer.model file. "
1641
+ f"Currently available slow->fast converters: {list(SLOW_TO_FAST_CONVERTERS.keys())}"
1642
+ )
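For context, a minimal usage sketch of the `convert_slow_tokenizer` entry point defined above. The import path and the `bert-base-uncased` checkpoint are illustrative assumptions, not requirements of this file:

```python
# Sketch: convert a slow (Python) tokenizer into a `tokenizers` backend and wrap it.
# Assumes the module is importable as transformers.convert_slow_tokenizer.
from transformers import BertTokenizer, PreTrainedTokenizerFast
from transformers.convert_slow_tokenizer import convert_slow_tokenizer

slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # slow tokenizer instance
backend = convert_slow_tokenizer(slow_tokenizer)                     # tokenizers.Tokenizer
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=backend)
print(fast_tokenizer.tokenize("Hello world!"))
```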
convert_slow_tokenizers_checkpoints_to_fast.py ADDED
@@ -0,0 +1,130 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert slow tokenizers checkpoints in fast (serialization format of the `tokenizers` library)"""
16
+
17
+ import argparse
18
+ import os
19
+
20
+ import transformers
21
+
22
+ from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
23
+ from .utils import logging
24
+
25
+
26
+ logging.set_verbosity_info()
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ TOKENIZER_CLASSES = {
32
+ # Phi3 uses Llama tokenizer
33
+ name: getattr(transformers, "LlamaTokenizerFast" if name == "Phi3Tokenizer" else name + "Fast")
34
+ for name in SLOW_TO_FAST_CONVERTERS
35
+ }
36
+
37
+
38
+ def convert_slow_checkpoint_to_fast(tokenizer_name, checkpoint_name, dump_path, force_download):
39
+ if tokenizer_name is not None and tokenizer_name not in TOKENIZER_CLASSES:
40
+ raise ValueError(f"Unrecognized tokenizer name, should be one of {list(TOKENIZER_CLASSES.keys())}.")
41
+
42
+ if tokenizer_name is None:
43
+ tokenizer_names = TOKENIZER_CLASSES
44
+ else:
45
+ tokenizer_names = {tokenizer_name: getattr(transformers, tokenizer_name + "Fast")}
46
+
47
+ logger.info(f"Loading tokenizer classes: {tokenizer_names}")
48
+
49
+ for tokenizer_name in tokenizer_names:
50
+ tokenizer_class = TOKENIZER_CLASSES[tokenizer_name]
51
+
52
+ add_prefix = True
53
+ if checkpoint_name is None:
54
+ checkpoint_names = list(tokenizer_class.max_model_input_sizes.keys())
55
+ else:
56
+ checkpoint_names = [checkpoint_name]
57
+
58
+ logger.info(f"For tokenizer {tokenizer_class.__class__.__name__} loading checkpoints: {checkpoint_names}")
59
+
60
+ for checkpoint in checkpoint_names:
61
+ logger.info(f"Loading {tokenizer_class.__class__.__name__} {checkpoint}")
62
+
63
+ # Load tokenizer
64
+ tokenizer = tokenizer_class.from_pretrained(checkpoint, force_download=force_download)
65
+
66
+ # Save fast tokenizer
67
+ logger.info(f"Save fast tokenizer to {dump_path} with prefix {checkpoint} add_prefix {add_prefix}")
68
+
69
+ # For organization names we create sub-directories
70
+ if "/" in checkpoint:
71
+ checkpoint_directory, checkpoint_prefix_name = checkpoint.split("/")
72
+ dump_path_full = os.path.join(dump_path, checkpoint_directory)
73
+ elif add_prefix:
74
+ checkpoint_prefix_name = checkpoint
75
+ dump_path_full = dump_path
76
+ else:
77
+ checkpoint_prefix_name = None
78
+ dump_path_full = dump_path
79
+
80
+ logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
81
+
82
+ if checkpoint in list(tokenizer.pretrained_vocab_files_map.values())[0]:
83
+ file_path = list(tokenizer.pretrained_vocab_files_map.values())[0][checkpoint]
84
+ next_char = file_path.split(checkpoint)[-1][0]
85
+ if next_char == "/":
86
+ dump_path_full = os.path.join(dump_path_full, checkpoint_prefix_name)
87
+ checkpoint_prefix_name = None
88
+
89
+ logger.info(f"=> {dump_path_full} with prefix {checkpoint_prefix_name}, add_prefix {add_prefix}")
90
+
91
+ file_names = tokenizer.save_pretrained(
92
+ dump_path_full, legacy_format=False, filename_prefix=checkpoint_prefix_name
93
+ )
94
+ logger.info(f"=> File names {file_names}")
95
+
96
+ for file_name in file_names:
97
+ if not file_name.endswith("tokenizer.json"):
98
+ os.remove(file_name)
99
+ logger.info(f"=> removing {file_name}")
100
+
101
+
102
+ if __name__ == "__main__":
103
+ parser = argparse.ArgumentParser()
104
+ # Required parameters
105
+ parser.add_argument(
106
+ "--dump_path", default=None, type=str, required=True, help="Path to output generated fast tokenizer files."
107
+ )
108
+ parser.add_argument(
109
+ "--tokenizer_name",
110
+ default=None,
111
+ type=str,
112
+ help=(
113
+ f"Optional tokenizer type selected in the list of {list(TOKENIZER_CLASSES.keys())}. If not given, will "
114
+ "download and convert all the checkpoints from AWS."
115
+ ),
116
+ )
117
+ parser.add_argument(
118
+ "--checkpoint_name",
119
+ default=None,
120
+ type=str,
121
+ help="Optional checkpoint name. If not given, will download and convert the canonical checkpoints from AWS.",
122
+ )
123
+ parser.add_argument(
124
+ "--force_download",
125
+ action="store_true",
126
+ help="Re-download checkpoints.",
127
+ )
128
+ args = parser.parse_args()
129
+
130
+ convert_slow_checkpoint_to_fast(args.tokenizer_name, args.checkpoint_name, args.dump_path, args.force_download)
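For reference, the script above can also be driven programmatically. A minimal sketch follows; the import path assumes the file ships inside the `transformers` package, and the checkpoint name is only an example:

```python
# Sketch: convert a single slow checkpoint and keep only the generated tokenizer.json.
from transformers.convert_slow_tokenizers_checkpoints_to_fast import convert_slow_checkpoint_to_fast

convert_slow_checkpoint_to_fast(
    tokenizer_name="BertTokenizer",        # must be a key of TOKENIZER_CLASSES
    checkpoint_name="bert-base-uncased",   # example checkpoint
    dump_path="./fast_tokenizers",
    force_download=False,
)
```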
convert_tf_hub_seq_to_seq_bert_to_pytorch.py ADDED
@@ -0,0 +1,87 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Convert Seq2Seq TF Hub checkpoint."""
16
+
17
+ import argparse
18
+
19
+ from . import (
20
+ BertConfig,
21
+ BertGenerationConfig,
22
+ BertGenerationDecoder,
23
+ BertGenerationEncoder,
24
+ load_tf_weights_in_bert_generation,
25
+ logging,
26
+ )
27
+
28
+
29
+ logging.set_verbosity_info()
30
+
31
+
32
+ def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
33
+ # Initialise PyTorch model
34
+ bert_config = BertConfig.from_pretrained(
35
+ "google-bert/bert-large-cased",
36
+ vocab_size=vocab_size,
37
+ max_position_embeddings=512,
38
+ is_decoder=True,
39
+ add_cross_attention=True,
40
+ )
41
+ bert_config_dict = bert_config.to_dict()
42
+ del bert_config_dict["type_vocab_size"]
43
+ config = BertGenerationConfig(**bert_config_dict)
44
+ if is_encoder:
45
+ model = BertGenerationEncoder(config)
46
+ else:
47
+ model = BertGenerationDecoder(config)
48
+ print(f"Building PyTorch model from configuration: {config}")
49
+
50
+ # Load weights from tf checkpoint
51
+ load_tf_weights_in_bert_generation(
52
+ model,
53
+ tf_hub_path,
54
+ model_class="bert",
55
+ is_encoder_named_decoder=is_encoder_named_decoder,
56
+ is_encoder=is_encoder,
57
+ )
58
+
59
+ # Save pytorch-model
60
+ print(f"Save PyTorch model and config to {pytorch_dump_path}")
61
+ model.save_pretrained(pytorch_dump_path)
62
+
63
+
64
+ if __name__ == "__main__":
65
+ parser = argparse.ArgumentParser()
66
+ # Required parameters
67
+ parser.add_argument(
68
+ "--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
69
+ )
70
+ parser.add_argument(
71
+ "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
72
+ )
73
+ parser.add_argument(
74
+ "--is_encoder_named_decoder",
75
+ action="store_true",
76
+ help="If decoder has to be renamed to encoder in PyTorch model.",
77
+ )
78
+ parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
79
+ parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
80
+ args = parser.parse_args()
81
+ convert_tf_checkpoint_to_pytorch(
82
+ args.tf_hub_path,
83
+ args.pytorch_dump_path,
84
+ args.is_encoder_named_decoder,
85
+ args.vocab_size,
86
+ is_encoder=args.is_encoder,
87
+ )
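A programmatic sketch of the same conversion; the paths are placeholders, the import path is an assumption, and the values mirror the argparse defaults above:

```python
# Sketch: convert a TF Hub seq2seq BERT checkpoint into a BertGeneration encoder.
from transformers.convert_tf_hub_seq_to_seq_bert_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    tf_hub_path="/path/to/tf_hub_checkpoint",       # placeholder TF checkpoint path
    pytorch_dump_path="./bert_generation_encoder",  # output directory
    is_encoder_named_decoder=False,
    vocab_size=50358,                               # matches the script's default
    is_encoder=True,
)
```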
debug_utils.py ADDED
@@ -0,0 +1,346 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import collections
16
+
17
+ from .utils import ExplicitEnum, is_torch_available, logging
18
+
19
+
20
+ if is_torch_available():
21
+ import torch
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ class DebugUnderflowOverflow:
28
+ """
29
+ This debug class helps detect and understand where the model starts getting very large or very small, and more
30
+ importantly `nan` or `inf` weight and activation elements.
31
+
32
+ There are 2 working modes:
33
+
34
+ 1. Underflow/overflow detection (default)
35
+ 2. Specific batch absolute min/max tracing without detection
36
+
37
+ Mode 1: Underflow/overflow detection
38
+
39
+ To activate the underflow/overflow detection, initialize the object with the model :
40
+
41
+ ```python
42
+ debug_overflow = DebugUnderflowOverflow(model)
43
+ ```
44
+
45
+ then run the training as normal and if `nan` or `inf` gets detected in at least one of the weight, input or output
46
+ elements this module will throw an exception and will print `max_frames_to_save` frames that lead to this event,
47
+ each frame reporting
48
+
49
+ 1. the fully qualified module name plus the class name whose `forward` was run
50
+ 2. the absolute min and max value of all elements for each module weights, and the inputs and output
51
+
52
+ For example, here is the header and the last few frames of the detection report for `google/mt5-small` run in fp16
53
+ mixed precision:
54
+
55
+ ```
56
+ Detected inf/nan during batch_number=0
57
+ Last 21 forward frames:
58
+ abs min abs max metadata
59
+ [...]
60
+ encoder.block.2.layer.1.DenseReluDense.wi_0 Linear
61
+ 2.17e-07 4.50e+00 weight
62
+ 1.79e-06 4.65e+00 input[0]
63
+ 2.68e-06 3.70e+01 output
64
+ encoder.block.2.layer.1.DenseReluDense.wi_1 Linear
65
+ 8.08e-07 2.66e+01 weight
66
+ 1.79e-06 4.65e+00 input[0]
67
+ 1.27e-04 2.37e+02 output
68
+ encoder.block.2.layer.1.DenseReluDense.wo Linear
69
+ 1.01e-06 6.44e+00 weight
70
+ 0.00e+00 9.74e+03 input[0]
71
+ 3.18e-04 6.27e+04 output
72
+ encoder.block.2.layer.1.DenseReluDense T5DenseGatedGeluDense
73
+ 1.79e-06 4.65e+00 input[0]
74
+ 3.18e-04 6.27e+04 output
75
+ encoder.block.2.layer.1.dropout Dropout
76
+ 3.18e-04 6.27e+04 input[0]
77
+ 0.00e+00 inf output
78
+ ```
79
+
80
+ You can see here that `T5DenseGatedGeluDense.forward` produced output activations whose absolute max value was
+ around 62.7K, which is very close to fp16's upper limit of 64K. In the next frame we have `Dropout`, which
+ renormalizes its output after zeroing some of the elements; this pushes the absolute max value above 64K and we
+ get an overflow.
+ 
+ As you can see, it's the preceding frames that we need to inspect when the numbers start becoming very large for
+ fp16.
87
+
88
+ The tracking is done in a forward hook, which gets invoked immediately after `forward` has completed.
89
+
90
+ By default the last 21 frames are printed. You can change the default to adjust for your needs. For example:
91
+
92
+ ```python
93
+ debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=100)
94
+ ```
95
+
96
+ To validate that you have set up this debugging feature correctly, before using it in a training run that may
+ take hours to complete, first run it with normal tracing enabled for one or a few batches, as explained in
+ the next section.
99
+
100
+
101
+ Mode 2. Specific batch absolute min/max tracing without detection
102
+
103
+ The second working mode is per-batch tracing with the underflow/overflow detection feature turned off.
104
+
105
+ Let's say you want to watch the absolute min and max values for all the ingredients of each `forward` call of a
106
+ given batch, and only do that for batches 1 and 3. Then you instantiate this class as:
107
+
108
+ ```python
109
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3])
110
+ ```
111
+
112
+ And now full batches 1 and 3 will be traced using the same format as explained above. Batches are 0-indexed.
113
+
114
+ This is helpful if you know that the program starts misbehaving after a certain batch number, so you can
115
+ fast-forward right to that area.
116
+
117
+
118
+ Early stopping:
119
+
120
+ You can also specify the batch number after which to stop the training, with:
121
+
122
+ ```python
123
+ debug_overflow = DebugUnderflowOverflow(model, trace_batch_nums=[1, 3], abort_after_batch_num=3)
124
+ ```
125
+
126
+ This feature is mainly useful in the tracing mode, but you can use it for any mode.
127
+
128
+
129
+ **Performance**:
130
+
131
+ As this module measures the absolute `min`/`max` of each weight of the model on every forward pass, it will slow the training
132
+ down. Therefore remember to turn it off once the debugging needs have been met.
133
+
134
+ Args:
135
+ model (`nn.Module`):
136
+ The model to debug.
137
+ max_frames_to_save (`int`, *optional*, defaults to 21):
138
+ How many frames back to record
139
+ trace_batch_nums (`List[int]`, *optional*, defaults to `[]`):
140
+ Which batch numbers to trace (turns detection off)
141
+ abort_after_batch_num (`int`, *optional*):
142
+ If set, abort training after this batch number has finished
143
+ """
144
+
145
+ def __init__(self, model, max_frames_to_save=21, trace_batch_nums=[], abort_after_batch_num=None):
146
+ self.model = model
147
+ self.trace_batch_nums = trace_batch_nums
148
+ self.abort_after_batch_num = abort_after_batch_num
149
+
150
+ # keep a LIFO buffer of frames to dump as soon as inf/nan is encountered to give context to the problem emergence
151
+ self.frames = collections.deque([], max_frames_to_save)
152
+ self.frame = []
153
+ self.batch_number = 0
154
+ self.total_calls = 0
155
+ self.detected_overflow = False
156
+ self.prefix = " "
157
+
158
+ self.analyse_model()
159
+
160
+ self.register_forward_hook()
161
+
162
+ def save_frame(self, frame=None):
163
+ if frame is not None:
164
+ self.expand_frame(frame)
165
+ self.frames.append("\n".join(self.frame))
166
+ self.frame = [] # start a new frame
167
+
168
+ def expand_frame(self, line):
169
+ self.frame.append(line)
170
+
171
+ def trace_frames(self):
172
+ print("\n".join(self.frames))
173
+ self.frames = []
174
+
175
+ def reset_saved_frames(self):
176
+ self.frames = []
177
+
178
+ def dump_saved_frames(self):
179
+ print(f"\nDetected inf/nan during batch_number={self.batch_number}")
180
+ print(f"Last {len(self.frames)} forward frames:")
181
+ print(f"{'abs min':8} {'abs max':8} metadata")
182
+ print("\n".join(self.frames))
183
+ print("\n\n")
184
+ self.frames = []
185
+
186
+ def analyse_model(self):
187
+ # extract the fully qualified module names, to be able to report at run time. e.g.:
188
+ # encoder.block.2.layer.0.SelfAttention.o
189
+ #
190
+ # for shared weights only the first shared module name will be registered
191
+ self.module_names = {m: name for name, m in self.model.named_modules()}
192
+ # self.longest_module_name = max(len(v) for v in self.module_names.values())
193
+
194
+ def analyse_variable(self, var, ctx):
195
+ if torch.is_tensor(var):
196
+ self.expand_frame(get_abs_min_max(var, ctx))
197
+ if detect_overflow(var, ctx):
198
+ self.detected_overflow = True
199
+ elif var is None:
200
+ self.expand_frame(f"{'None':>17} {ctx}")
201
+ else:
202
+ self.expand_frame(f"{'not a tensor':>17} {ctx}")
203
+
204
+ def batch_start_frame(self):
205
+ self.expand_frame(f"\n\n{self.prefix} *** Starting batch number={self.batch_number} ***")
206
+ self.expand_frame(f"{'abs min':8} {'abs max':8} metadata")
207
+
208
+ def batch_end_frame(self):
209
+ self.expand_frame(f"{self.prefix} *** Finished batch number={self.batch_number-1} ***\n\n")
210
+
211
+ def create_frame(self, module, input, output):
212
+ self.expand_frame(f"{self.prefix} {self.module_names[module]} {module.__class__.__name__}")
213
+
214
+ # params
215
+ for name, p in module.named_parameters(recurse=False):
216
+ self.analyse_variable(p, name)
217
+
218
+ # inputs
219
+ if isinstance(input, tuple):
220
+ for i, x in enumerate(input):
221
+ self.analyse_variable(x, f"input[{i}]")
222
+ else:
223
+ self.analyse_variable(input, "input")
224
+
225
+ # outputs
226
+ if isinstance(output, tuple):
227
+ for i, x in enumerate(output):
228
+ # possibly a tuple of tuples
229
+ if isinstance(x, tuple):
230
+ for j, y in enumerate(x):
231
+ self.analyse_variable(y, f"output[{i}][{j}]")
232
+ else:
233
+ self.analyse_variable(x, f"output[{i}]")
234
+ else:
235
+ self.analyse_variable(output, "output")
236
+
237
+ self.save_frame()
238
+
239
+ def register_forward_hook(self):
240
+ self.model.apply(self._register_forward_hook)
241
+
242
+ def _register_forward_hook(self, module):
243
+ module.register_forward_hook(self.forward_hook)
244
+
245
+ def forward_hook(self, module, input, output):
246
+ # - input is a tuple of packed inputs (could be non-Tensors)
247
+ # - output could be a Tensor or a tuple of Tensors and non-Tensors
248
+
249
+ last_frame_of_batch = False
250
+
251
+ trace_mode = self.batch_number in self.trace_batch_nums
252
+ if trace_mode:
253
+ self.reset_saved_frames()
254
+
255
+ if self.total_calls == 0:
256
+ self.batch_start_frame()
257
+ self.total_calls += 1
258
+
259
+ # count batch numbers - the very first forward hook of the batch will be called when the
260
+ # batch completes - i.e. it gets called very last - we know this batch has finished
261
+ if module == self.model:
262
+ self.batch_number += 1
263
+ last_frame_of_batch = True
264
+
265
+ self.create_frame(module, input, output)
266
+
267
+ # if last_frame_of_batch:
268
+ # self.batch_end_frame()
269
+
270
+ if trace_mode:
271
+ self.trace_frames()
272
+
273
+ if last_frame_of_batch:
274
+ self.batch_start_frame()
275
+
276
+ if self.detected_overflow and not trace_mode:
277
+ self.dump_saved_frames()
278
+
279
+ # now we can abort, as it's pointless to continue running
280
+ raise ValueError(
281
+ "DebugUnderflowOverflow: inf/nan detected, aborting as there is no point running further. "
282
+ "Please scroll up above this traceback to see the activation values prior to this event."
283
+ )
284
+
285
+ # abort after certain batch if requested to do so
286
+ if self.abort_after_batch_num is not None and self.batch_number > self.abort_after_batch_num:
287
+ raise ValueError(
288
+ f"DebugUnderflowOverflow: aborting after {self.batch_number} batches due to"
289
+ f" `abort_after_batch_num={self.abort_after_batch_num}` arg"
290
+ )
291
+
292
+
293
+ def get_abs_min_max(var, ctx):
294
+ abs_var = var.abs()
295
+ return f"{abs_var.min():8.2e} {abs_var.max():8.2e} {ctx}"
296
+
297
+
298
+ def detect_overflow(var, ctx):
299
+ """
300
+ Report whether the tensor contains any `nan` or `inf` entries.
301
+
302
+ This is useful for detecting overflows/underflows, and is best called right after the function that did the math that
+ modified the tensor in question.
304
+
305
+ This function contains a few other helper features that you can enable and tweak directly if you want to track
306
+ various other things.
307
+
308
+ Args:
309
+ var: the tensor variable to check
310
+ ctx: the message to print as a context
311
+
312
+ Return:
313
+ `True` if `inf` or `nan` was detected, `False` otherwise
314
+ """
315
+ detected = False
316
+ if torch.isnan(var).any().item():
317
+ detected = True
318
+ print(f"{ctx} has nans")
319
+ if torch.isinf(var).any().item():
320
+ detected = True
321
+ print(f"{ctx} has infs")
322
+
323
+ # if needed to monitor large elements can enable the following
324
+ if 0: # and detected:
325
+ n100 = var[torch.ge(var.abs(), 100)]
326
+ if n100.numel() > 0:
327
+ print(f"{ctx}: n100={n100.numel()}")
328
+ n1000 = var[torch.ge(var.abs(), 1000)]
329
+ if n1000.numel() > 0:
330
+ print(f"{ctx}: n1000={n1000.numel()}")
331
+ n10000 = var[torch.ge(var.abs(), 10000)]
332
+ if n10000.numel() > 0:
333
+ print(f"{ctx}: n10000={n10000.numel()}")
334
+
335
+ if 0:
336
+ print(f"min={var.min():9.2e} max={var.max():9.2e}")
337
+
338
+ if 0:
339
+ print(f"min={var.min():9.2e} max={var.max():9.2e} var={var.var():9.2e} mean={var.mean():9.2e} ({ctx})")
340
+
341
+ return detected
342
+
343
+
344
+ class DebugOption(ExplicitEnum):
345
+ UNDERFLOW_OVERFLOW = "underflow_overflow"
346
+ TPU_METRICS_DEBUG = "tpu_metrics_debug"
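As an illustration of the two modes described in the docstring above, here is a minimal, self-contained sketch of attaching the debugger to a toy model outside of `Trainer`. It is not part of the uploaded file; the toy model, data, and optimizer are placeholders.

```python
# Minimal sketch (not part of debug_utils.py): attach DebugUnderflowOverflow to a toy model.
# With a real fp16 model, any inf/nan in a weight, input, or output makes the forward hook
# raise and print the last `max_frames_to_save` frames.
import torch
from torch import nn
from transformers.debug_utils import DebugUnderflowOverflow

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
debug_overflow = DebugUnderflowOverflow(model, max_frames_to_save=21)  # hooks registered here

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(2):
    x = torch.randn(4, 8)
    loss = model(x).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

Passing `trace_batch_nums=[1, 3]` instead would switch this sketch to the tracing mode described in the docstring.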
dependency_versions_check.py ADDED
@@ -0,0 +1,63 @@
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .dependency_versions_table import deps
16
+ from .utils.versions import require_version, require_version_core
17
+
18
+
19
+ # define which module versions we always want to check at run time
20
+ # (usually the ones defined in `install_requires` in setup.py)
21
+ #
22
+ # order specific notes:
23
+ # - tqdm must be checked before tokenizers
24
+
25
+ pkgs_to_check_at_runtime = [
26
+ "python",
27
+ "tqdm",
28
+ "regex",
29
+ "requests",
30
+ "packaging",
31
+ "filelock",
32
+ "numpy",
33
+ "tokenizers",
34
+ "huggingface-hub",
35
+ "safetensors",
36
+ "accelerate",
37
+ "pyyaml",
38
+ ]
39
+
40
+ for pkg in pkgs_to_check_at_runtime:
41
+ if pkg in deps:
42
+ if pkg == "tokenizers":
43
+ # must be loaded here, or else tqdm check may fail
44
+ from .utils import is_tokenizers_available
45
+
46
+ if not is_tokenizers_available():
47
+ continue # not required, check version only if installed
48
+ elif pkg == "accelerate":
49
+ # must be loaded here, or else tqdm check may fail
50
+ from .utils import is_accelerate_available
51
+
52
+ # Maybe switch to is_torch_available in the future here so that Accelerate is a hard dep of
53
+ # Transformers with PyTorch
54
+ if not is_accelerate_available():
55
+ continue # not required, check version only if installed
56
+
57
+ require_version_core(deps[pkg])
58
+ else:
59
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
60
+
61
+
62
+ def dep_version_check(pkg, hint=None):
63
+ require_version(deps[pkg], hint)
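For context, here is a short, hedged sketch of how these helpers are typically used; it assumes an installed `transformers` where this module is importable under the path shown.

```python
# Importing the module runs the core runtime checks above; `dep_version_check` can then be
# reused at call sites that need an extra dependency. "tqdm" and "numpy" are core deps, so
# these calls should pass on any working install.
from transformers.dependency_versions_check import dep_version_check
from transformers.utils.versions import require_version

dep_version_check("tqdm", hint="pip install -U tqdm")        # looks up the pin in `deps`
require_version("numpy>=1.17", hint="pip install -U numpy")  # ad-hoc requirement string
```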
dependency_versions_table.py ADDED
@@ -0,0 +1,102 @@
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow>=10.0.1,<=15.0",
6
+ "accelerate": "accelerate>=0.26.0",
7
+ "av": "av==9.2.0",
8
+ "beautifulsoup4": "beautifulsoup4",
9
+ "blobfile": "blobfile",
10
+ "codecarbon": "codecarbon>=2.8.1",
11
+ "cookiecutter": "cookiecutter==1.7.3",
12
+ "dataclasses": "dataclasses",
13
+ "datasets": "datasets!=2.5.0",
14
+ "deepspeed": "deepspeed>=0.9.3",
15
+ "diffusers": "diffusers",
16
+ "dill": "dill<0.3.5",
17
+ "evaluate": "evaluate>=0.2.0",
18
+ "faiss-cpu": "faiss-cpu",
19
+ "fastapi": "fastapi",
20
+ "filelock": "filelock",
21
+ "flax": "flax>=0.4.1,<=0.7.0",
22
+ "fsspec": "fsspec<2023.10.0",
23
+ "ftfy": "ftfy",
24
+ "fugashi": "fugashi>=1.0",
25
+ "GitPython": "GitPython<3.1.19",
26
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
27
+ "huggingface-hub": "huggingface-hub>=0.24.0,<1.0",
28
+ "importlib_metadata": "importlib_metadata",
29
+ "ipadic": "ipadic>=1.0.0,<2.0",
30
+ "isort": "isort>=5.5.4",
31
+ "jax": "jax>=0.4.1,<=0.4.13",
32
+ "jaxlib": "jaxlib>=0.4.1,<=0.4.13",
33
+ "jieba": "jieba",
34
+ "jinja2": "jinja2>=3.1.0",
35
+ "kenlm": "kenlm",
36
+ "keras": "keras>2.9,<2.16",
37
+ "keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
38
+ "librosa": "librosa",
39
+ "nltk": "nltk<=3.8.1",
40
+ "natten": "natten>=0.14.6,<0.15.0",
41
+ "numpy": "numpy>=1.17",
42
+ "onnxconverter-common": "onnxconverter-common",
43
+ "onnxruntime-tools": "onnxruntime-tools>=1.4.2",
44
+ "onnxruntime": "onnxruntime>=1.4.0",
45
+ "opencv-python": "opencv-python",
46
+ "optimum-benchmark": "optimum-benchmark>=0.3.0",
47
+ "optuna": "optuna",
48
+ "optax": "optax>=0.0.8,<=0.1.4",
49
+ "packaging": "packaging>=20.0",
50
+ "parameterized": "parameterized",
51
+ "phonemizer": "phonemizer",
52
+ "protobuf": "protobuf",
53
+ "psutil": "psutil",
54
+ "pyyaml": "pyyaml>=5.1",
55
+ "pydantic": "pydantic",
56
+ "pytest": "pytest>=7.2.0,<8.0.0",
57
+ "pytest-asyncio": "pytest-asyncio",
58
+ "pytest-timeout": "pytest-timeout",
59
+ "pytest-xdist": "pytest-xdist",
60
+ "python": "python>=3.9.0",
61
+ "ray[tune]": "ray[tune]>=2.7.0",
62
+ "regex": "regex!=2019.12.17",
63
+ "requests": "requests",
64
+ "rhoknp": "rhoknp>=1.1.0,<1.3.1",
65
+ "rjieba": "rjieba",
66
+ "rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
67
+ "ruff": "ruff==0.5.1",
68
+ "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
69
+ "sacremoses": "sacremoses",
70
+ "safetensors": "safetensors>=0.4.1",
71
+ "sagemaker": "sagemaker>=2.31.0",
72
+ "schedulefree": "schedulefree>=1.2.6",
73
+ "scikit-learn": "scikit-learn",
74
+ "scipy": "scipy<1.13.0",
75
+ "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
76
+ "sigopt": "sigopt",
77
+ "starlette": "starlette",
78
+ "sudachipy": "sudachipy>=0.6.6",
79
+ "sudachidict_core": "sudachidict_core>=20220729",
80
+ "tensorboard": "tensorboard",
81
+ "tensorflow-cpu": "tensorflow-cpu>2.9,<2.16",
82
+ "tensorflow": "tensorflow>2.9,<2.16",
83
+ "tensorflow-text": "tensorflow-text<2.16",
84
+ "tensorflow-probability": "tensorflow-probability<0.24",
85
+ "tf2onnx": "tf2onnx",
86
+ "timeout-decorator": "timeout-decorator",
87
+ "tiktoken": "tiktoken",
88
+ "timm": "timm<=1.0.11",
89
+ "tokenizers": "tokenizers>=0.21,<0.22",
90
+ "torch": "torch>=2.0",
91
+ "torchaudio": "torchaudio",
92
+ "torchvision": "torchvision",
93
+ "pyctcdecode": "pyctcdecode>=0.4.0",
94
+ "tqdm": "tqdm>=4.27",
95
+ "unidic": "unidic>=1.0.2",
96
+ "unidic_lite": "unidic_lite>=1.0.7",
97
+ "urllib3": "urllib3<2.0.0",
98
+ "uvicorn": "uvicorn",
99
+ "pytest-rich": "pytest-rich",
100
+ "libcst": "libcst",
101
+ "rich": "rich",
102
+ }
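The table is consumed as a plain dict; a quick sketch of what a lookup returns (values as pinned in this copy of the file):

```python
from transformers.dependency_versions_table import deps

print(deps["tokenizers"])  # "tokenizers>=0.21,<0.22"
print(deps["torch"])       # "torch>=2.0"
```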
dynamic_module_utils.py ADDED
@@ -0,0 +1,685 @@
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+
17
+ import filecmp
18
+ import hashlib
19
+ import importlib
20
+ import importlib.util
21
+ import os
22
+ import re
23
+ import shutil
24
+ import signal
25
+ import sys
26
+ import threading
27
+ import typing
28
+ import warnings
29
+ from pathlib import Path
30
+ from types import ModuleType
31
+ from typing import Any, Dict, List, Optional, Union
32
+
33
+ from huggingface_hub import try_to_load_from_cache
34
+
35
+ from .utils import (
36
+ HF_MODULES_CACHE,
37
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
38
+ cached_file,
39
+ extract_commit_hash,
40
+ is_offline_mode,
41
+ logging,
42
+ )
43
+
44
+
45
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
46
+ _HF_REMOTE_CODE_LOCK = threading.Lock()
47
+
48
+
49
+ def init_hf_modules():
50
+ """
51
+ Creates the cache directory for modules with an init, and adds it to the Python path.
52
+ """
53
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
54
+ if HF_MODULES_CACHE in sys.path:
55
+ return
56
+
57
+ sys.path.append(HF_MODULES_CACHE)
58
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
59
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
60
+ if not init_path.exists():
61
+ init_path.touch()
62
+ importlib.invalidate_caches()
63
+
64
+
65
+ def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
66
+ """
67
+ Creates a dynamic module in the cache directory for modules.
68
+
69
+ Args:
70
+ name (`str` or `os.PathLike`):
71
+ The name of the dynamic module to create.
72
+ """
73
+ init_hf_modules()
74
+ dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
75
+ # If the parent module does not exist yet, recursively create it.
76
+ if not dynamic_module_path.parent.exists():
77
+ create_dynamic_module(dynamic_module_path.parent)
78
+ os.makedirs(dynamic_module_path, exist_ok=True)
79
+ init_path = dynamic_module_path / "__init__.py"
80
+ if not init_path.exists():
81
+ init_path.touch()
82
+ # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
83
+ # with errors about modules that do not exist. Same for all other `invalidate_caches` in this file.
84
+ importlib.invalidate_caches()
85
+
86
+
87
+ def get_relative_imports(module_file: Union[str, os.PathLike]) -> List[str]:
88
+ """
89
+ Get the list of modules that are relatively imported in a module file.
90
+
91
+ Args:
92
+ module_file (`str` or `os.PathLike`): The module file to inspect.
93
+
94
+ Returns:
95
+ `List[str]`: The list of relative imports in the module.
96
+ """
97
+ with open(module_file, "r", encoding="utf-8") as f:
98
+ content = f.read()
99
+
100
+ # Imports of the form `import .xxx`
101
+ relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
102
+ # Imports of the form `from .xxx import yyy`
103
+ relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
104
+ # Unique-ify
105
+ return list(set(relative_imports))
106
+
107
+
108
+ def get_relative_import_files(module_file: Union[str, os.PathLike]) -> List[str]:
109
+ """
110
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
111
+ imports (if a imports b and b imports c, it will return module files for b and c).
112
+
113
+ Args:
114
+ module_file (`str` or `os.PathLike`): The module file to inspect.
115
+
116
+ Returns:
117
+ `List[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
118
+ of module files a given module needs.
119
+ """
120
+ no_change = False
121
+ files_to_check = [module_file]
122
+ all_relative_imports = []
123
+
124
+ # Let's recurse through all relative imports
125
+ while not no_change:
126
+ new_imports = []
127
+ for f in files_to_check:
128
+ new_imports.extend(get_relative_imports(f))
129
+
130
+ module_path = Path(module_file).parent
131
+ new_import_files = [str(module_path / m) for m in new_imports]
132
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
133
+ files_to_check = [f"{f}.py" for f in new_import_files]
134
+
135
+ no_change = len(new_import_files) == 0
136
+ all_relative_imports.extend(files_to_check)
137
+
138
+ return all_relative_imports
139
+
140
+
141
+ def get_imports(filename: Union[str, os.PathLike]) -> List[str]:
142
+ """
143
+ Extracts all the libraries (not relative imports this time) that are imported in a file.
144
+
145
+ Args:
146
+ filename (`str` or `os.PathLike`): The module file to inspect.
147
+
148
+ Returns:
149
+ `List[str]`: The list of all packages required to use the input module.
150
+ """
151
+ with open(filename, "r", encoding="utf-8") as f:
152
+ content = f.read()
153
+
154
+ # filter out try/except block so in custom code we can have try/except imports
155
+ content = re.sub(r"\s*try\s*:.*?except.*?:", "", content, flags=re.DOTALL)
156
+
157
+ # filter out imports under is_flash_attn_2_available block to avoid import issues in CPU-only environments
158
+ content = re.sub(
159
+ r"if is_flash_attn[a-zA-Z0-9_]+available\(\):\s*(from flash_attn\s*.*\s*)+", "", content, flags=re.MULTILINE
160
+ )
161
+
162
+ # Imports of the form `import xxx`
163
+ imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
164
+ # Imports of the form `from xxx import yyy`
165
+ imports += re.findall(r"^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
166
+ # Only keep the top-level module
167
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
168
+ return list(set(imports))
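To make the two scanners above concrete, here is an illustration with a hypothetical remote-code module mixing absolute and relative imports; it is an editorial example, not part of the uploaded file, and it assumes these helpers are importable from an installed `transformers`.

```python
# Illustration only: what get_imports / get_relative_imports return for a small module.
import tempfile
import textwrap

from transformers.dynamic_module_utils import get_imports, get_relative_imports

source = textwrap.dedent(
    """
    import torch
    from packaging import version
    from .configuration_my_model import MyModelConfig
    """
)
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as tmp:
    tmp.write(source)

print(sorted(get_imports(tmp.name)))   # ['packaging', 'torch']  (relative import is skipped)
print(get_relative_imports(tmp.name))  # ['configuration_my_model']
```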
169
+
170
+
171
+ def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
172
+ """
173
+ Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
174
+ library is missing.
175
+
176
+ Args:
177
+ filename (`str` or `os.PathLike`): The module file to check.
178
+
179
+ Returns:
180
+ `List[str]`: The list of relative imports in the file.
181
+ """
182
+ imports = get_imports(filename)
183
+ missing_packages = []
184
+ for imp in imports:
185
+ try:
186
+ importlib.import_module(imp)
187
+ except ImportError as exception:
188
+ logger.warning(f"Encountered exception while importing {imp}: {exception}")
189
+ # Some packages can fail with an ImportError because of a dependency issue.
190
+ # This check avoids hiding such errors.
191
+ # See https://github.com/huggingface/transformers/issues/33604
192
+ if "No module named" in str(exception):
193
+ missing_packages.append(imp)
194
+ else:
195
+ raise
196
+
197
+ if len(missing_packages) > 0:
198
+ raise ImportError(
199
+ "This modeling file requires the following packages that were not found in your environment: "
200
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
201
+ )
202
+
203
+ return get_relative_imports(filename)
204
+
205
+
206
+ def get_class_in_module(
207
+ class_name: str,
208
+ module_path: Union[str, os.PathLike],
209
+ *,
210
+ force_reload: bool = False,
211
+ ) -> typing.Type:
212
+ """
213
+ Import a module on the cache directory for modules and extract a class from it.
214
+
215
+ Args:
216
+ class_name (`str`): The name of the class to import.
217
+ module_path (`str` or `os.PathLike`): The path to the module to import.
218
+ force_reload (`bool`, *optional*, defaults to `False`):
219
+ Whether to reload the dynamic module from file if it already exists in `sys.modules`.
220
+ Otherwise, the module is only reloaded if the file has changed.
221
+
222
+ Returns:
223
+ `typing.Type`: The class looked for.
224
+ """
225
+ name = os.path.normpath(module_path)
226
+ if name.endswith(".py"):
227
+ name = name[:-3]
228
+ name = name.replace(os.path.sep, ".")
229
+ module_file: Path = Path(HF_MODULES_CACHE) / module_path
230
+ with _HF_REMOTE_CODE_LOCK:
231
+ if force_reload:
232
+ sys.modules.pop(name, None)
233
+ importlib.invalidate_caches()
234
+ cached_module: Optional[ModuleType] = sys.modules.get(name)
235
+ module_spec = importlib.util.spec_from_file_location(name, location=module_file)
236
+
237
+ # Hash the module file and all its relative imports to check if we need to reload it
238
+ module_files: List[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
239
+ module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
240
+
241
+ module: ModuleType
242
+ if cached_module is None:
243
+ module = importlib.util.module_from_spec(module_spec)
244
+ # insert it into sys.modules before any loading begins
245
+ sys.modules[name] = module
246
+ else:
247
+ module = cached_module
248
+ # reload in both cases, unless the module is already imported and the hash hits
249
+ if getattr(module, "__transformers_module_hash__", "") != module_hash:
250
+ module_spec.loader.exec_module(module)
251
+ module.__transformers_module_hash__ = module_hash
252
+ return getattr(module, class_name)
253
+
254
+
255
+ def get_cached_module_file(
256
+ pretrained_model_name_or_path: Union[str, os.PathLike],
257
+ module_file: str,
258
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
259
+ force_download: bool = False,
260
+ resume_download: Optional[bool] = None,
261
+ proxies: Optional[Dict[str, str]] = None,
262
+ token: Optional[Union[bool, str]] = None,
263
+ revision: Optional[str] = None,
264
+ local_files_only: bool = False,
265
+ repo_type: Optional[str] = None,
266
+ _commit_hash: Optional[str] = None,
267
+ **deprecated_kwargs,
268
+ ) -> str:
269
+ """
270
+ Downloads a module from a local folder or a remote repo and returns its path inside the cached
271
+ Transformers module.
272
+
273
+ Args:
274
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
275
+ This can be either:
276
+
277
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
278
+ huggingface.co.
279
+ - a path to a *directory* containing a configuration file saved using the
280
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
281
+
282
+ module_file (`str`):
283
+ The name of the module file containing the class to look for.
284
+ cache_dir (`str` or `os.PathLike`, *optional*):
285
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
286
+ cache should not be used.
287
+ force_download (`bool`, *optional*, defaults to `False`):
288
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
289
+ exist.
290
+ resume_download:
291
+ Deprecated and ignored. All downloads are now resumed by default when possible.
292
+ Will be removed in v5 of Transformers.
293
+ proxies (`Dict[str, str]`, *optional*):
294
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
295
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
296
+ token (`str` or *bool*, *optional*):
297
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
298
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
299
+ revision (`str`, *optional*, defaults to `"main"`):
300
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
301
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
302
+ identifier allowed by git.
303
+ local_files_only (`bool`, *optional*, defaults to `False`):
304
+ If `True`, will only try to load the tokenizer configuration from local files.
305
+ repo_type (`str`, *optional*):
306
+ Specify the repo type (useful when downloading from a space for instance).
307
+
308
+ <Tip>
309
+
310
+ Passing `token=True` is required when you want to use a private model.
311
+
312
+ </Tip>
313
+
314
+ Returns:
315
+ `str`: The path to the module inside the cache.
316
+ """
317
+ use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
318
+ if use_auth_token is not None:
319
+ warnings.warn(
320
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
321
+ FutureWarning,
322
+ )
323
+ if token is not None:
324
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
325
+ token = use_auth_token
326
+
327
+ if is_offline_mode() and not local_files_only:
328
+ logger.info("Offline mode: forcing local_files_only=True")
329
+ local_files_only = True
330
+
331
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
332
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
333
+ is_local = os.path.isdir(pretrained_model_name_or_path)
334
+ if is_local:
335
+ submodule = os.path.basename(pretrained_model_name_or_path)
336
+ else:
337
+ submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
338
+ cached_module = try_to_load_from_cache(
339
+ pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
340
+ )
341
+
342
+ new_files = []
343
+ try:
344
+ # Load from URL or cache if already cached
345
+ resolved_module_file = cached_file(
346
+ pretrained_model_name_or_path,
347
+ module_file,
348
+ cache_dir=cache_dir,
349
+ force_download=force_download,
350
+ proxies=proxies,
351
+ resume_download=resume_download,
352
+ local_files_only=local_files_only,
353
+ token=token,
354
+ revision=revision,
355
+ repo_type=repo_type,
356
+ _commit_hash=_commit_hash,
357
+ )
358
+ if not is_local and cached_module != resolved_module_file:
359
+ new_files.append(module_file)
360
+
361
+ except EnvironmentError:
362
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
363
+ raise
364
+
365
+ # Check we have all the requirements in our environment
366
+ modules_needed = check_imports(resolved_module_file)
367
+
368
+ # Now we move the module inside our cached dynamic modules.
369
+ full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
370
+ create_dynamic_module(full_submodule)
371
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
372
+ if submodule == os.path.basename(pretrained_model_name_or_path):
373
+ # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
374
+ # has changed since last copy.
375
+ if not (submodule_path / module_file).exists() or not filecmp.cmp(
376
+ resolved_module_file, str(submodule_path / module_file)
377
+ ):
378
+ shutil.copy(resolved_module_file, submodule_path / module_file)
379
+ importlib.invalidate_caches()
380
+ for module_needed in modules_needed:
381
+ module_needed = f"{module_needed}.py"
382
+ module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
383
+ if not (submodule_path / module_needed).exists() or not filecmp.cmp(
384
+ module_needed_file, str(submodule_path / module_needed)
385
+ ):
386
+ shutil.copy(module_needed_file, submodule_path / module_needed)
387
+ importlib.invalidate_caches()
388
+ else:
389
+ # Get the commit hash
390
+ commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
391
+
392
+ # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
393
+ # benefit of versioning.
394
+ submodule_path = submodule_path / commit_hash
395
+ full_submodule = full_submodule + os.path.sep + commit_hash
396
+ create_dynamic_module(full_submodule)
397
+
398
+ if not (submodule_path / module_file).exists():
399
+ shutil.copy(resolved_module_file, submodule_path / module_file)
400
+ importlib.invalidate_caches()
401
+ # Make sure we also have every file with relative imports
402
+ for module_needed in modules_needed:
403
+ if not (submodule_path / f"{module_needed}.py").exists():
404
+ get_cached_module_file(
405
+ pretrained_model_name_or_path,
406
+ f"{module_needed}.py",
407
+ cache_dir=cache_dir,
408
+ force_download=force_download,
409
+ resume_download=resume_download,
410
+ proxies=proxies,
411
+ token=token,
412
+ revision=revision,
413
+ local_files_only=local_files_only,
414
+ _commit_hash=commit_hash,
415
+ )
416
+ new_files.append(f"{module_needed}.py")
417
+
418
+ if len(new_files) > 0 and revision is None:
419
+ new_files = "\n".join([f"- {f}" for f in new_files])
420
+ repo_type_str = "" if repo_type is None else f"{repo_type}s/"
421
+ url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
422
+ logger.warning(
423
+ f"A new version of the following files was downloaded from {url}:\n{new_files}"
424
+ "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
425
+ "versions of the code file, you can pin a revision."
426
+ )
427
+
428
+ return os.path.join(full_submodule, module_file)
429
+
430
+
431
+ def get_class_from_dynamic_module(
432
+ class_reference: str,
433
+ pretrained_model_name_or_path: Union[str, os.PathLike],
434
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
435
+ force_download: bool = False,
436
+ resume_download: Optional[bool] = None,
437
+ proxies: Optional[Dict[str, str]] = None,
438
+ token: Optional[Union[bool, str]] = None,
439
+ revision: Optional[str] = None,
440
+ local_files_only: bool = False,
441
+ repo_type: Optional[str] = None,
442
+ code_revision: Optional[str] = None,
443
+ **kwargs,
444
+ ) -> typing.Type:
445
+ """
446
+ Extracts a class from a module file, present in the local folder or repository of a model.
447
+
448
+ <Tip warning={true}>
449
+
450
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
451
+ therefore only be called on trusted repos.
452
+
453
+ </Tip>
454
+
455
+
456
+
457
+ Args:
458
+ class_reference (`str`):
459
+ The full name of the class to load, including its module and optionally its repo.
460
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
461
+ This can be either:
462
+
463
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
464
+ huggingface.co.
465
+ - a path to a *directory* containing a configuration file saved using the
466
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
467
+
468
+ This is used when `class_reference` does not specify another repo.
469
+ module_file (`str`):
470
+ The name of the module file containing the class to look for.
471
+ class_name (`str`):
472
+ The name of the class to import in the module.
473
+ cache_dir (`str` or `os.PathLike`, *optional*):
474
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
475
+ cache should not be used.
476
+ force_download (`bool`, *optional*, defaults to `False`):
477
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
478
+ exist.
479
+ resume_download:
480
+ Deprecated and ignored. All downloads are now resumed by default when possible.
481
+ Will be removed in v5 of Transformers.
482
+ proxies (`Dict[str, str]`, *optional*):
483
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
484
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
485
+ token (`str` or `bool`, *optional*):
486
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
487
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
488
+ revision (`str`, *optional*, defaults to `"main"`):
489
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
490
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
491
+ identifier allowed by git.
492
+ local_files_only (`bool`, *optional*, defaults to `False`):
493
+ If `True`, will only try to load the tokenizer configuration from local files.
494
+ repo_type (`str`, *optional*):
495
+ Specify the repo type (useful when downloading from a space for instance).
496
+ code_revision (`str`, *optional*, defaults to `"main"`):
497
+ The specific revision to use for the code on the Hub, if the code lives in a different repository than the
498
+ rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
499
+ storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
500
+
501
+ <Tip>
502
+
503
+ Passing `token=True` is required when you want to use a private model.
504
+
505
+ </Tip>
506
+
507
+ Returns:
508
+ `typing.Type`: The class, dynamically imported from the module.
509
+
510
+ Examples:
511
+
512
+ ```python
513
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
514
+ # module.
515
+ cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
516
+
517
+ # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
518
+ # module.
519
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
520
+ ```"""
521
+ use_auth_token = kwargs.pop("use_auth_token", None)
522
+ if use_auth_token is not None:
523
+ warnings.warn(
524
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
525
+ FutureWarning,
526
+ )
527
+ if token is not None:
528
+ raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
529
+ token = use_auth_token
530
+
531
+ # Catch the name of the repo if it's specified in `class_reference`
532
+ if "--" in class_reference:
533
+ repo_id, class_reference = class_reference.split("--")
534
+ else:
535
+ repo_id = pretrained_model_name_or_path
536
+ module_file, class_name = class_reference.split(".")
537
+
538
+ if code_revision is None and pretrained_model_name_or_path == repo_id:
539
+ code_revision = revision
540
+ # And lastly we get the class inside our newly created module
541
+ final_module = get_cached_module_file(
542
+ repo_id,
543
+ module_file + ".py",
544
+ cache_dir=cache_dir,
545
+ force_download=force_download,
546
+ resume_download=resume_download,
547
+ proxies=proxies,
548
+ token=token,
549
+ revision=code_revision,
550
+ local_files_only=local_files_only,
551
+ repo_type=repo_type,
552
+ )
553
+ return get_class_in_module(class_name, final_module, force_reload=force_download)
554
+
555
+
556
+ def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[Dict] = None) -> List[str]:
557
+ """
558
+ Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
559
+ adds the proper fields in a config.
560
+
561
+ Args:
562
+ obj (`Any`): The object for which to save the module files.
563
+ folder (`str` or `os.PathLike`): The folder where to save.
564
+ config (`PretrainedConfig` or dictionary, `optional`):
565
+ A config in which to register the auto_map corresponding to this custom object.
566
+
567
+ Returns:
568
+ `List[str]`: The list of files saved.
569
+ """
570
+ if obj.__module__ == "__main__":
571
+ logger.warning(
572
+ f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
573
+ "this code in a separate module so we can include it in the saved folder and make it easier to share via "
574
+ "the Hub."
575
+ )
576
+ return
577
+
578
+ def _set_auto_map_in_config(_config):
579
+ module_name = obj.__class__.__module__
580
+ last_module = module_name.split(".")[-1]
581
+ full_name = f"{last_module}.{obj.__class__.__name__}"
582
+ # Special handling for tokenizers
583
+ if "Tokenizer" in full_name:
584
+ slow_tokenizer_class = None
585
+ fast_tokenizer_class = None
586
+ if obj.__class__.__name__.endswith("Fast"):
587
+ # Fast tokenizer: we have the fast tokenizer class and we may have the slow one as an attribute.
588
+ fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
589
+ if getattr(obj, "slow_tokenizer_class", None) is not None:
590
+ slow_tokenizer = getattr(obj, "slow_tokenizer_class")
591
+ slow_tok_module_name = slow_tokenizer.__module__
592
+ last_slow_tok_module = slow_tok_module_name.split(".")[-1]
593
+ slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
594
+ else:
595
+ # Slow tokenizer: no way to have the fast class
596
+ slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
597
+
598
+ full_name = (slow_tokenizer_class, fast_tokenizer_class)
599
+
600
+ if isinstance(_config, dict):
601
+ auto_map = _config.get("auto_map", {})
602
+ auto_map[obj._auto_class] = full_name
603
+ _config["auto_map"] = auto_map
604
+ elif getattr(_config, "auto_map", None) is not None:
605
+ _config.auto_map[obj._auto_class] = full_name
606
+ else:
607
+ _config.auto_map = {obj._auto_class: full_name}
608
+
609
+ # Add object class to the config auto_map
610
+ if isinstance(config, (list, tuple)):
611
+ for cfg in config:
612
+ _set_auto_map_in_config(cfg)
613
+ elif config is not None:
614
+ _set_auto_map_in_config(config)
615
+
616
+ result = []
617
+ # Copy module file to the output folder.
618
+ object_file = sys.modules[obj.__module__].__file__
619
+ dest_file = Path(folder) / (Path(object_file).name)
620
+ shutil.copy(object_file, dest_file)
621
+ result.append(dest_file)
622
+
623
+ # Gather all relative imports recursively and make sure they are copied as well.
624
+ for needed_file in get_relative_import_files(object_file):
625
+ dest_file = Path(folder) / (Path(needed_file).name)
626
+ shutil.copy(needed_file, dest_file)
627
+ result.append(dest_file)
628
+
629
+ return result
630
+
631
+
632
+ def _raise_timeout_error(signum, frame):
633
+ raise ValueError(
634
+ "Loading this model requires you to execute custom code contained in the model repository on your local "
635
+ "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
636
+ )
637
+
638
+
639
+ TIME_OUT_REMOTE_CODE = 15
640
+
641
+
642
+ def resolve_trust_remote_code(trust_remote_code, model_name, has_local_code, has_remote_code):
643
+ if trust_remote_code is None:
644
+ if has_local_code:
645
+ trust_remote_code = False
646
+ elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
647
+ prev_sig_handler = None
648
+ try:
649
+ prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
650
+ signal.alarm(TIME_OUT_REMOTE_CODE)
651
+ while trust_remote_code is None:
652
+ answer = input(
653
+ f"The repository for {model_name} contains custom code which must be executed to correctly "
654
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
655
+ f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
656
+ f"Do you wish to run the custom code? [y/N] "
657
+ )
658
+ if answer.lower() in ["yes", "y", "1"]:
659
+ trust_remote_code = True
660
+ elif answer.lower() in ["no", "n", "0", ""]:
661
+ trust_remote_code = False
662
+ signal.alarm(0)
663
+ except Exception:
664
+ # OS which does not support signal.SIGALRM
665
+ raise ValueError(
666
+ f"The repository for {model_name} contains custom code which must be executed to correctly "
667
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
668
+ f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
669
+ )
670
+ finally:
671
+ if prev_sig_handler is not None:
672
+ signal.signal(signal.SIGALRM, prev_sig_handler)
673
+ signal.alarm(0)
674
+ elif has_remote_code:
675
+ # For the CI which puts the timeout at 0
676
+ _raise_timeout_error(None, None)
677
+
678
+ if has_remote_code and not has_local_code and not trust_remote_code:
679
+ raise ValueError(
680
+ f"Loading {model_name} requires you to execute the configuration file in that"
681
+ " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
682
+ " set the option `trust_remote_code=True` to remove this error."
683
+ )
684
+
685
+ return trust_remote_code
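For orientation, the usual entry point into this machinery is loading a Hub repository that ships custom code through the auto classes; a hedged sketch follows, with a placeholder repo id.

```python
# Sketch only: "some-org/custom-model" is a placeholder repo id. trust_remote_code=True skips
# the interactive prompt in resolve_trust_remote_code; pinning `revision` avoids silently
# picking up new code, as the warning in get_cached_module_file suggests.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("some-org/custom-model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "some-org/custom-model",
    trust_remote_code=True,
    revision="main",
)
```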
feature_extraction_sequence_utils.py ADDED
@@ -0,0 +1,372 @@
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Sequence feature extraction class for common feature extractors to preprocess sequences.
17
+ """
18
+
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
24
+ from .utils import PaddingStrategy, TensorType, is_tf_tensor, is_torch_tensor, logging, to_numpy
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class SequenceFeatureExtractor(FeatureExtractionMixin):
31
+ """
32
+ This is a general feature extraction class for speech recognition.
33
+
34
+ Args:
35
+ feature_size (`int`):
36
+ The feature dimension of the extracted features.
37
+ sampling_rate (`int`):
38
+ The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
39
+ padding_value (`float`):
40
+ The value that is used to fill the padding values / vectors.
41
+ """
42
+
43
+ def __init__(self, feature_size: int, sampling_rate: int, padding_value: float, **kwargs):
44
+ self.feature_size = feature_size
45
+ self.sampling_rate = sampling_rate
46
+ self.padding_value = padding_value
47
+
48
+ self.padding_side = kwargs.pop("padding_side", "right")
49
+ self.return_attention_mask = kwargs.pop("return_attention_mask", True)
50
+
51
+ super().__init__(**kwargs)
52
+
53
+ def pad(
54
+ self,
55
+ processed_features: Union[
56
+ BatchFeature,
57
+ List[BatchFeature],
58
+ Dict[str, BatchFeature],
59
+ Dict[str, List[BatchFeature]],
60
+ List[Dict[str, BatchFeature]],
61
+ ],
62
+ padding: Union[bool, str, PaddingStrategy] = True,
63
+ max_length: Optional[int] = None,
64
+ truncation: bool = False,
65
+ pad_to_multiple_of: Optional[int] = None,
66
+ return_attention_mask: Optional[bool] = None,
67
+ return_tensors: Optional[Union[str, TensorType]] = None,
68
+ ) -> BatchFeature:
69
+ """
70
+ Pad input values / input vectors or a batch of input values / input vectors up to predefined length or to the
71
+ max sequence length in the batch.
72
+
73
+ Padding side (left/right) and padding value are defined at the feature extractor level (with `self.padding_side`,
74
+ `self.padding_value`)
75
+
76
+ <Tip>
77
+
78
+ If the `processed_features` passed are a dictionary of numpy arrays, PyTorch tensors or TensorFlow tensors, the
79
+ result will use the same type unless you provide a different tensor type with `return_tensors`. In the case of
80
+ PyTorch tensors, you will lose the specific device of your tensors however.
81
+
82
+ </Tip>
83
+
84
+ Args:
85
+ processed_features ([`BatchFeature`], list of [`BatchFeature`], `Dict[str, List[float]]`, `Dict[str, List[List[float]]` or `List[Dict[str, List[float]]]`):
86
+ Processed inputs. Can represent one input ([`BatchFeature`] or `Dict[str, List[float]]`) or a batch of
87
+ input values / vectors (list of [`BatchFeature`], *Dict[str, List[List[float]]]* or *List[Dict[str,
88
+ List[float]]]*) so you can use this method during preprocessing as well as in a PyTorch Dataloader
89
+ collate function.
90
+
91
+ Instead of `List[float]` you can have tensors (numpy arrays, PyTorch tensors or TensorFlow tensors),
92
+ see the note above for the return type.
93
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
94
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
95
+ index) among:
96
+
97
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
98
+ sequence is provided).
99
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
100
+ acceptable input length for the model if that argument is not provided.
101
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
102
+ lengths).
103
+ max_length (`int`, *optional*):
104
+ Maximum length of the returned list and optionally padding length (see above).
105
+ truncation (`bool`):
106
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
107
+ pad_to_multiple_of (`int`, *optional*):
108
+ If set will pad the sequence to a multiple of the provided value.
109
+
110
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
111
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
112
+ return_attention_mask (`bool`, *optional*):
113
+ Whether to return the attention mask. If left to the default, will return the attention mask according
114
+ to the specific feature_extractor's default.
115
+
116
+ [What are attention masks?](../glossary#attention-mask)
117
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
118
+ If set, will return tensors instead of list of python integers. Acceptable values are:
119
+
120
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
121
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
122
+ - `'np'`: Return Numpy `np.ndarray` objects.
123
+ """
124
+ # If we have a list of dicts, let's convert it in a dict of lists
125
+ # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
126
+ if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
127
+ processed_features = {
128
+ key: [example[key] for example in processed_features] for key in processed_features[0].keys()
129
+ }
130
+
131
+ # The model's main input name, usually `input_values`, has to be passed for padding
132
+ if self.model_input_names[0] not in processed_features:
133
+ raise ValueError(
134
+ "You should supply an instance of `transformers.BatchFeature` or list of `transformers.BatchFeature`"
135
+ f" to this method that includes {self.model_input_names[0]}, but you provided"
136
+ f" {list(processed_features.keys())}"
137
+ )
138
+
139
+ required_input = processed_features[self.model_input_names[0]]
140
+ return_attention_mask = (
141
+ return_attention_mask if return_attention_mask is not None else self.return_attention_mask
142
+ )
143
+
144
+ if len(required_input) == 0:
145
+ if return_attention_mask:
146
+ processed_features["attention_mask"] = []
147
+ return processed_features
148
+
149
+ # If we have PyTorch/TF tensors or lists as inputs, we cast them as Numpy arrays
150
+ # and rebuild them afterwards if no return_tensors is specified
151
+ # Note that we lose the specific device the tensor may be on for PyTorch
152
+
153
+ first_element = required_input[0]
154
+ if isinstance(first_element, (list, tuple)):
155
+ # first_element might be an empty list/tuple in some edge cases so we grab the first non empty element.
156
+ index = 0
157
+ while len(required_input[index]) == 0:
158
+ index += 1
159
+ if index < len(required_input):
160
+ first_element = required_input[index][0]
161
+
162
+ if return_tensors is None:
163
+ if is_tf_tensor(first_element):
164
+ return_tensors = "tf"
165
+ elif is_torch_tensor(first_element):
166
+ return_tensors = "pt"
167
+ elif isinstance(first_element, (int, float, list, tuple, np.ndarray)):
168
+ return_tensors = "np"
169
+ else:
170
+ raise ValueError(
171
+ f"type of {first_element} unknown: {type(first_element)}. "
172
+ "Should be one of a python, numpy, pytorch or tensorflow object."
173
+ )
174
+
175
+ for key, value in processed_features.items():
176
+ if isinstance(value[0], (int, float)):
177
+ processed_features[key] = to_numpy(value)
178
+ else:
179
+ processed_features[key] = [to_numpy(v) for v in value]
180
+
181
+ # Convert padding_strategy in PaddingStrategy
182
+ padding_strategy = self._get_padding_strategies(padding=padding, max_length=max_length)
183
+
184
+ required_input = processed_features[self.model_input_names[0]]
185
+
186
+ batch_size = len(required_input)
187
+ if not all(len(v) == batch_size for v in processed_features.values()):
188
+ raise ValueError("Some items in the output dictionary have a different batch size than others.")
189
+
190
+ truncated_inputs = []
191
+ for i in range(batch_size):
192
+ inputs = {k: v[i] for k, v in processed_features.items()}
193
+ # truncation
194
+ inputs_slice = self._truncate(
195
+ inputs,
196
+ max_length=max_length,
197
+ pad_to_multiple_of=pad_to_multiple_of,
198
+ truncation=truncation,
199
+ )
200
+ truncated_inputs.append(inputs_slice)
201
+
202
+ if padding_strategy == PaddingStrategy.LONGEST:
203
+ # make sure that `max_length` cannot be longer than the longest truncated length
204
+ max_length = max(len(input_slice[self.model_input_names[0]]) for input_slice in truncated_inputs)
205
+ padding_strategy = PaddingStrategy.MAX_LENGTH
206
+
207
+ batch_outputs = {}
208
+ for i in range(batch_size):
209
+ # padding
210
+ outputs = self._pad(
211
+ truncated_inputs[i],
212
+ max_length=max_length,
213
+ padding_strategy=padding_strategy,
214
+ pad_to_multiple_of=pad_to_multiple_of,
215
+ return_attention_mask=return_attention_mask,
216
+ )
217
+
218
+ for key, value in outputs.items():
219
+ if key not in batch_outputs:
220
+ batch_outputs[key] = []
221
+ if value.dtype is np.dtype(np.float64):
222
+ value = value.astype(np.float32)
223
+ batch_outputs[key].append(value)
224
+
225
+ return BatchFeature(batch_outputs, tensor_type=return_tensors)
226
+
227
+ def _pad(
228
+ self,
229
+ processed_features: Union[Dict[str, np.ndarray], BatchFeature],
230
+ max_length: Optional[int] = None,
231
+ padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
232
+ pad_to_multiple_of: Optional[int] = None,
233
+ return_attention_mask: Optional[bool] = None,
234
+ ) -> dict:
235
+ """
236
+ Pad inputs (on left/right and up to predefined length or max length in the batch)
237
+
238
+ Args:
239
+ processed_features (`Union[Dict[str, np.ndarray], BatchFeature]`):
240
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
241
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
242
+ max_length (`int`, *optional*):
243
+ Maximum length of the returned list and optionally padding length (see below)
244
+ padding_strategy (`PaddingStrategy`, *optional*, defaults to `PaddingStrategy.DO_NOT_PAD`):
245
+ PaddingStrategy to use for padding.
246
+
247
+ - PaddingStrategy.LONGEST: Pad to the longest sequence in the batch
248
+ - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
249
+ - PaddingStrategy.DO_NOT_PAD: Do not pad
250
+ The feature_extractor padding sides are defined in self.padding_side:
251
+
252
+ - 'left': pads on the left of the sequences
253
+ - 'right': pads on the right of the sequences
254
+ pad_to_multiple_of (`int`, *optional*):
255
+ If set, will pad the sequence to a multiple of the provided value. This is especially useful to
256
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
257
+ which benefit from having sequence lengths be a multiple of 128.
258
+ return_attention_mask (`bool`, *optional*):
259
+ Set to False to avoid returning attention mask (default: set to model specifics)
260
+ """
261
+ required_input = processed_features[self.model_input_names[0]]
262
+
263
+ if padding_strategy == PaddingStrategy.LONGEST:
264
+ max_length = len(required_input)
265
+
266
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
267
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
268
+
269
+ needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) < max_length
270
+
271
+ if return_attention_mask and "attention_mask" not in processed_features:
272
+ processed_features["attention_mask"] = np.ones(len(required_input), dtype=np.int32)
273
+
274
+ if needs_to_be_padded:
275
+ difference = max_length - len(required_input)
276
+ if self.padding_side == "right":
277
+ if return_attention_mask:
278
+ processed_features["attention_mask"] = np.pad(
279
+ processed_features["attention_mask"], (0, difference)
280
+ )
281
+ padding_shape = ((0, difference), (0, 0)) if self.feature_size > 1 else (0, difference)
282
+ processed_features[self.model_input_names[0]] = np.pad(
283
+ required_input, padding_shape, "constant", constant_values=self.padding_value
284
+ )
285
+ elif self.padding_side == "left":
286
+ if return_attention_mask:
287
+ processed_features["attention_mask"] = np.pad(
288
+ processed_features["attention_mask"], (difference, 0)
289
+ )
290
+ padding_shape = ((difference, 0), (0, 0)) if self.feature_size > 1 else (difference, 0)
291
+ processed_features[self.model_input_names[0]] = np.pad(
292
+ required_input, padding_shape, "constant", constant_values=self.padding_value
293
+ )
294
+ else:
295
+ raise ValueError("Invalid padding strategy:" + str(self.padding_side))
296
+
297
+ return processed_features
298
+
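As a quick check of the `pad_to_multiple_of` rounding used in `_pad` above (and repeated in `_truncate` below), here is the same arithmetic in isolation; the numbers are made up:

```python
# Mirrors the rounding in `_pad`/`_truncate`: bump max_length up to the next
# multiple of pad_to_multiple_of when it is not already one.
max_length, pad_to_multiple_of = 250, 128

if max_length % pad_to_multiple_of != 0:
    max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

print(max_length)  # 256, i.e. 250 rounded up to the next multiple of 128
```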
299
+ def _truncate(
300
+ self,
301
+ processed_features: Union[Dict[str, np.ndarray], BatchFeature],
302
+ max_length: Optional[int] = None,
303
+ pad_to_multiple_of: Optional[int] = None,
304
+ truncation: Optional[bool] = None,
305
+ ):
306
+ """
307
+ Truncate inputs to predefined length or max length in the batch
308
+
309
+ Args:
310
+ processed_features(`Union[Dict[str, np.ndarray], BatchFeature]`):
311
+ Dictionary of input values (`np.ndarray[float]`) / input vectors (`List[np.ndarray[float]]`) or batch
312
+ of inputs values (`List[np.ndarray[int]]`) / input vectors (`List[np.ndarray[int]]`)
313
+ max_length (`int`, *optional*):
314
+ Maximum length of the returned list and optionally padding length (see below)
315
+ pad_to_multiple_of (`int`, *optional*):
316
+ If set, will pad the sequence to a multiple of the provided value. This is especially useful to
317
+ enable the use of Tensor Core on NVIDIA hardware with compute capability `>= 7.5` (Volta), or on TPUs
318
+ which benefit from having sequence lengths be a multiple of 128.
319
+ truncation (`bool`, *optional*):
320
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
321
+ """
322
+ if not truncation:
323
+ return processed_features
324
+ elif truncation and max_length is None:
325
+ raise ValueError("When setting ``truncation=True``, make sure that ``max_length`` is defined.")
326
+
327
+ required_input = processed_features[self.model_input_names[0]]
328
+
329
+ # find `max_length` that fits `pad_to_multiple_of`
330
+ if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
331
+ max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
332
+
333
+ needs_to_be_truncated = len(required_input) > max_length
334
+
335
+ if needs_to_be_truncated:
336
+ processed_features[self.model_input_names[0]] = processed_features[self.model_input_names[0]][:max_length]
337
+ if "attention_mask" in processed_features:
338
+ processed_features["attention_mask"] = processed_features["attention_mask"][:max_length]
339
+
340
+ return processed_features
341
+
342
+ def _get_padding_strategies(self, padding=False, max_length=None):
343
+ """
344
+ Find the correct padding strategy
345
+ """
346
+
347
+ # Get padding strategy
348
+ if padding is not False:
349
+ if padding is True:
350
+ padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
351
+ elif not isinstance(padding, PaddingStrategy):
352
+ padding_strategy = PaddingStrategy(padding)
353
+ elif isinstance(padding, PaddingStrategy):
354
+ padding_strategy = padding
355
+ else:
356
+ padding_strategy = PaddingStrategy.DO_NOT_PAD
357
+
358
+ # Set max length if needed
359
+ if max_length is None:
360
+ if padding_strategy == PaddingStrategy.MAX_LENGTH:
361
+ raise ValueError(
362
+ f"When setting ``padding={PaddingStrategy.MAX_LENGTH}``, make sure that max_length is defined"
363
+ )
364
+
365
+ # Test if we have a padding value
366
+ if padding_strategy != PaddingStrategy.DO_NOT_PAD and (self.padding_value is None):
367
+ raise ValueError(
368
+ "Asking to pad but the feature_extractor does not have a padding value. Please select a value to use"
369
+ " as `padding_value`. For example: `feature_extractor.padding_value = 0.0`."
370
+ )
371
+
372
+ return padding_strategy
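A minimal sketch of how this padding/truncation path is usually driven from user code. It assumes a derived extractor such as `Wav2Vec2FeatureExtractor` plus `numpy` and PyTorch are installed; the checkpoint name and waveforms are only illustrative:

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# Two mono waveforms of different lengths (the checkpoint expects 16 kHz audio).
raw_speech = [np.zeros(16000, dtype=np.float32), np.zeros(24000, dtype=np.float32)]
features = extractor(raw_speech, sampling_rate=16000)

# Pad the processed features to the longest example, rounded up to a multiple
# of 128, and get PyTorch tensors plus an attention mask back.
batch = extractor.pad(
    features,
    padding="longest",
    pad_to_multiple_of=128,
    return_attention_mask=True,
    return_tensors="pt",
)
print(batch["input_values"].shape, batch["attention_mask"].shape)
```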
feature_extraction_utils.py ADDED
@@ -0,0 +1,702 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extraction saving/loading class for common feature extractors.
17
+ """
18
+
19
+ import copy
20
+ import json
21
+ import os
22
+ import warnings
23
+ from collections import UserDict
24
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
25
+
26
+ import numpy as np
27
+
28
+ from .dynamic_module_utils import custom_object_save
29
+ from .utils import (
30
+ FEATURE_EXTRACTOR_NAME,
31
+ PushToHubMixin,
32
+ TensorType,
33
+ add_model_info_to_auto_map,
34
+ add_model_info_to_custom_pipelines,
35
+ cached_file,
36
+ copy_func,
37
+ download_url,
38
+ is_flax_available,
39
+ is_jax_tensor,
40
+ is_numpy_array,
41
+ is_offline_mode,
42
+ is_remote_url,
43
+ is_tf_available,
44
+ is_torch_available,
45
+ is_torch_device,
46
+ is_torch_dtype,
47
+ logging,
48
+ requires_backends,
49
+ )
50
+
51
+
52
+ if TYPE_CHECKING:
53
+ if is_torch_available():
54
+ import torch # noqa
55
+
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ PreTrainedFeatureExtractor = Union["SequenceFeatureExtractor"] # noqa: F821
60
+
61
+
62
+ class BatchFeature(UserDict):
63
+ r"""
64
+ Holds the output of the [`~SequenceFeatureExtractor.pad`] and feature extractor specific `__call__` methods.
65
+
66
+ This class is derived from a python dictionary and can be used as a dictionary.
67
+
68
+ Args:
69
+ data (`dict`, *optional*):
70
+ Dictionary of lists/arrays/tensors returned by the __call__/pad methods ('input_values', 'attention_mask',
71
+ etc.).
72
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
73
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
74
+ initialization.
75
+ """
76
+
77
+ def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None):
78
+ super().__init__(data)
79
+ self.convert_to_tensors(tensor_type=tensor_type)
80
+
81
+ def __getitem__(self, item: str) -> Union[Any]:
82
+ """
83
+ If the key is a string, returns the value of the dict associated to `key` ('input_values', 'attention_mask',
84
+ etc.).
85
+ """
86
+ if isinstance(item, str):
87
+ return self.data[item]
88
+ else:
89
+ raise KeyError("Indexing with integers is not available when using Python based feature extractors")
90
+
91
+ def __getattr__(self, item: str):
92
+ try:
93
+ return self.data[item]
94
+ except KeyError:
95
+ raise AttributeError
96
+
97
+ def __getstate__(self):
98
+ return {"data": self.data}
99
+
100
+ def __setstate__(self, state):
101
+ if "data" in state:
102
+ self.data = state["data"]
103
+
104
+ # Copied from transformers.tokenization_utils_base.BatchEncoding.keys
105
+ def keys(self):
106
+ return self.data.keys()
107
+
108
+ # Copied from transformers.tokenization_utils_base.BatchEncoding.values
109
+ def values(self):
110
+ return self.data.values()
111
+
112
+ # Copied from transformers.tokenization_utils_base.BatchEncoding.items
113
+ def items(self):
114
+ return self.data.items()
115
+
116
+ def _get_is_as_tensor_fns(self, tensor_type: Optional[Union[str, TensorType]] = None):
117
+ if tensor_type is None:
118
+ return None, None
119
+
120
+ # Convert to TensorType
121
+ if not isinstance(tensor_type, TensorType):
122
+ tensor_type = TensorType(tensor_type)
123
+
124
+ # Get a function reference for the correct framework
125
+ if tensor_type == TensorType.TENSORFLOW:
126
+ if not is_tf_available():
127
+ raise ImportError(
128
+ "Unable to convert output to TensorFlow tensors format, TensorFlow is not installed."
129
+ )
130
+ import tensorflow as tf
131
+
132
+ as_tensor = tf.constant
133
+ is_tensor = tf.is_tensor
134
+ elif tensor_type == TensorType.PYTORCH:
135
+ if not is_torch_available():
136
+ raise ImportError("Unable to convert output to PyTorch tensors format, PyTorch is not installed.")
137
+ import torch # noqa
138
+
139
+ def as_tensor(value):
140
+ if isinstance(value, (list, tuple)) and len(value) > 0:
141
+ if isinstance(value[0], np.ndarray):
142
+ value = np.array(value)
143
+ elif (
144
+ isinstance(value[0], (list, tuple))
145
+ and len(value[0]) > 0
146
+ and isinstance(value[0][0], np.ndarray)
147
+ ):
148
+ value = np.array(value)
149
+ if isinstance(value, np.ndarray):
150
+ return torch.from_numpy(value)
151
+ else:
152
+ return torch.tensor(value)
153
+
154
+ is_tensor = torch.is_tensor
155
+ elif tensor_type == TensorType.JAX:
156
+ if not is_flax_available():
157
+ raise ImportError("Unable to convert output to JAX tensors format, JAX is not installed.")
158
+ import jax.numpy as jnp # noqa: F811
159
+
160
+ as_tensor = jnp.array
161
+ is_tensor = is_jax_tensor
162
+ else:
163
+
164
+ def as_tensor(value, dtype=None):
165
+ if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
166
+ value_lens = [len(val) for val in value]
167
+ if len(set(value_lens)) > 1 and dtype is None:
168
+ # we have a ragged list so handle explicitly
169
+ value = as_tensor([np.asarray(val) for val in value], dtype=object)
170
+ return np.asarray(value, dtype=dtype)
171
+
172
+ is_tensor = is_numpy_array
173
+ return is_tensor, as_tensor
174
+
175
+ def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
176
+ """
177
+ Convert the inner content to tensors.
178
+
179
+ Args:
180
+ tensor_type (`str` or [`~utils.TensorType`], *optional*):
181
+ The type of tensors to use. If `str`, should be one of the values of the enum [`~utils.TensorType`]. If
182
+ `None`, no modification is done.
183
+ """
184
+ if tensor_type is None:
185
+ return self
186
+
187
+ is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type)
188
+
189
+ # Do the tensor conversion in batch
190
+ for key, value in self.items():
191
+ try:
192
+ if not is_tensor(value):
193
+ tensor = as_tensor(value)
194
+
195
+ self[key] = tensor
196
+ except: # noqa E722
197
+ if key == "overflowing_values":
198
+ raise ValueError("Unable to create tensor returning overflowing values of different lengths. ")
199
+ raise ValueError(
200
+ "Unable to create tensor, you should probably activate padding "
201
+ "with 'padding=True' to have batched tensors with the same length."
202
+ )
203
+
204
+ return self
205
+
206
+ def to(self, *args, **kwargs) -> "BatchFeature":
207
+ """
208
+ Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
209
+ different `dtypes` and sending the `BatchFeature` to a different `device`.
210
+
211
+ Args:
212
+ args (`Tuple`):
213
+ Will be passed to the `to(...)` function of the tensors.
214
+ kwargs (`Dict`, *optional*):
215
+ Will be passed to the `to(...)` function of the tensors.
216
+ To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
217
+
218
+ Returns:
219
+ [`BatchFeature`]: The same instance after modification.
220
+ """
221
+ requires_backends(self, ["torch"])
222
+ import torch # noqa
223
+
224
+ new_data = {}
225
+ device = kwargs.get("device")
226
+ non_blocking = kwargs.get("non_blocking", False)
227
+ # Check if the args are a device or a dtype
228
+ if device is None and len(args) > 0:
229
+ # device should be always the first argument
230
+ arg = args[0]
231
+ if is_torch_dtype(arg):
232
+ # The first argument is a dtype
233
+ pass
234
+ elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int):
235
+ device = arg
236
+ else:
237
+ # it's something else
238
+ raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.")
239
+ # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor`
240
+ for k, v in self.items():
241
+ # check if v is a floating point
242
+ if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
243
+ # cast and send to device
244
+ new_data[k] = v.to(*args, **kwargs)
245
+ elif isinstance(v, torch.Tensor) and device is not None:
246
+ new_data[k] = v.to(device=device, non_blocking=non_blocking)
247
+ else:
248
+ new_data[k] = v
249
+ self.data = new_data
250
+ return self
251
+
252
+
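A short sketch of how `BatchFeature` behaves once it leaves the extractor, assuming PyTorch is installed; the array contents are placeholders:

```python
import numpy as np
from transformers import BatchFeature

# Wrap plain numpy data and convert it to PyTorch tensors on construction.
batch = BatchFeature(
    data={
        "input_values": np.zeros((2, 16000), dtype=np.float32),
        "attention_mask": np.ones((2, 16000), dtype=np.int32),
    },
    tensor_type="pt",
)

# Dict-style and attribute-style access both read from the underlying `data`.
assert batch["input_values"].shape == batch.input_values.shape

# `to()` casts floating-point tensors and/or moves everything to a device.
batch = batch.to("cpu")  # e.g. batch.to("cuda:0") or batch.to(torch.float16) where available
```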
253
+ class FeatureExtractionMixin(PushToHubMixin):
254
+ """
255
+ This is a feature extraction mixin used to provide saving/loading functionality for sequential and image feature
256
+ extractors.
257
+ """
258
+
259
+ _auto_class = None
260
+
261
+ def __init__(self, **kwargs):
262
+ """Set elements of `kwargs` as attributes."""
263
+ # Pop "processor_class" as it should be saved as private attribute
264
+ self._processor_class = kwargs.pop("processor_class", None)
265
+ # Additional attributes without default values
266
+ for key, value in kwargs.items():
267
+ try:
268
+ setattr(self, key, value)
269
+ except AttributeError as err:
270
+ logger.error(f"Can't set {key} with value {value} for {self}")
271
+ raise err
272
+
273
+ def _set_processor_class(self, processor_class: str):
274
+ """Sets processor class as an attribute."""
275
+ self._processor_class = processor_class
276
+
277
+ @classmethod
278
+ def from_pretrained(
279
+ cls,
280
+ pretrained_model_name_or_path: Union[str, os.PathLike],
281
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
282
+ force_download: bool = False,
283
+ local_files_only: bool = False,
284
+ token: Optional[Union[str, bool]] = None,
285
+ revision: str = "main",
286
+ **kwargs,
287
+ ):
288
+ r"""
289
+ Instantiate a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a feature extractor, *e.g.* a
290
+ derived class of [`SequenceFeatureExtractor`].
291
+
292
+ Args:
293
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
294
+ This can be either:
295
+
296
+ - a string, the *model id* of a pretrained feature_extractor hosted inside a model repo on
297
+ huggingface.co.
298
+ - a path to a *directory* containing a feature extractor file saved using the
299
+ [`~feature_extraction_utils.FeatureExtractionMixin.save_pretrained`] method, e.g.,
300
+ `./my_model_directory/`.
301
+ - a path or url to a saved feature extractor JSON *file*, e.g.,
302
+ `./my_model_directory/preprocessor_config.json`.
303
+ cache_dir (`str` or `os.PathLike`, *optional*):
304
+ Path to a directory in which a downloaded pretrained model feature extractor should be cached if the
305
+ standard cache should not be used.
306
+ force_download (`bool`, *optional*, defaults to `False`):
307
+ Whether or not to force to (re-)download the feature extractor files and override the cached versions
308
+ if they exist.
309
+ resume_download:
310
+ Deprecated and ignored. All downloads are now resumed by default when possible.
311
+ Will be removed in v5 of Transformers.
312
+ proxies (`Dict[str, str]`, *optional*):
313
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
314
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
315
+ token (`str` or `bool`, *optional*):
316
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
317
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
318
+ revision (`str`, *optional*, defaults to `"main"`):
319
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
320
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
321
+ identifier allowed by git.
322
+
323
+
324
+ <Tip>
325
+
326
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
327
+
328
+ </Tip>
329
+
330
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
331
+ If `False`, then this function returns just the final feature extractor object. If `True`, then this
332
+ function returns a `Tuple(feature_extractor, unused_kwargs)` where *unused_kwargs* is a dictionary
333
+ consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of
334
+ `kwargs` which has not been used to update `feature_extractor` and is otherwise ignored.
335
+ kwargs (`Dict[str, Any]`, *optional*):
336
+ The values in kwargs of any keys which are feature extractor attributes will be used to override the
337
+ loaded values. Behavior concerning key/value pairs whose keys are *not* feature extractor attributes is
338
+ controlled by the `return_unused_kwargs` keyword parameter.
339
+
340
+ Returns:
341
+ A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`].
342
+
343
+ Examples:
344
+
345
+ ```python
346
+ # We can't directly instantiate the base class *FeatureExtractionMixin* nor *SequenceFeatureExtractor*, so let's show the examples on a
347
+ # derived class: *Wav2Vec2FeatureExtractor*
348
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
349
+ "facebook/wav2vec2-base-960h"
350
+ ) # Download feature_extraction_config from huggingface.co and cache.
351
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
352
+ "./test/saved_model/"
353
+ ) # E.g. feature_extractor (or model) was saved using *save_pretrained('./test/saved_model/')*
354
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./test/saved_model/preprocessor_config.json")
355
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
356
+ "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False
357
+ )
358
+ assert feature_extractor.return_attention_mask is False
359
+ feature_extractor, unused_kwargs = Wav2Vec2FeatureExtractor.from_pretrained(
360
+ "facebook/wav2vec2-base-960h", return_attention_mask=False, foo=False, return_unused_kwargs=True
361
+ )
362
+ assert feature_extractor.return_attention_mask is False
363
+ assert unused_kwargs == {"foo": False}
364
+ ```"""
365
+ kwargs["cache_dir"] = cache_dir
366
+ kwargs["force_download"] = force_download
367
+ kwargs["local_files_only"] = local_files_only
368
+ kwargs["revision"] = revision
369
+
370
+ use_auth_token = kwargs.pop("use_auth_token", None)
371
+ if use_auth_token is not None:
372
+ warnings.warn(
373
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
374
+ FutureWarning,
375
+ )
376
+ if token is not None:
377
+ raise ValueError(
378
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
379
+ )
380
+ token = use_auth_token
381
+
382
+ if token is not None:
383
+ kwargs["token"] = token
384
+
385
+ feature_extractor_dict, kwargs = cls.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
386
+
387
+ return cls.from_dict(feature_extractor_dict, **kwargs)
388
+
389
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
390
+ """
391
+ Save a feature_extractor object to the directory `save_directory`, so that it can be re-loaded using the
392
+ [`~feature_extraction_utils.FeatureExtractionMixin.from_pretrained`] class method.
393
+
394
+ Args:
395
+ save_directory (`str` or `os.PathLike`):
396
+ Directory where the feature extractor JSON file will be saved (will be created if it does not exist).
397
+ push_to_hub (`bool`, *optional*, defaults to `False`):
398
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
399
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
400
+ namespace).
401
+ kwargs (`Dict[str, Any]`, *optional*):
402
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
403
+ """
404
+ use_auth_token = kwargs.pop("use_auth_token", None)
405
+
406
+ if use_auth_token is not None:
407
+ warnings.warn(
408
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
409
+ FutureWarning,
410
+ )
411
+ if kwargs.get("token", None) is not None:
412
+ raise ValueError(
413
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
414
+ )
415
+ kwargs["token"] = use_auth_token
416
+
417
+ if os.path.isfile(save_directory):
418
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
419
+
420
+ os.makedirs(save_directory, exist_ok=True)
421
+
422
+ if push_to_hub:
423
+ commit_message = kwargs.pop("commit_message", None)
424
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
425
+ repo_id = self._create_repo(repo_id, **kwargs)
426
+ files_timestamps = self._get_files_timestamps(save_directory)
427
+
428
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
429
+ # loaded from the Hub.
430
+ if self._auto_class is not None:
431
+ custom_object_save(self, save_directory, config=self)
432
+
433
+ # If we save using the predefined names, we can load using `from_pretrained`
434
+ output_feature_extractor_file = os.path.join(save_directory, FEATURE_EXTRACTOR_NAME)
435
+
436
+ self.to_json_file(output_feature_extractor_file)
437
+ logger.info(f"Feature extractor saved in {output_feature_extractor_file}")
438
+
439
+ if push_to_hub:
440
+ self._upload_modified_files(
441
+ save_directory,
442
+ repo_id,
443
+ files_timestamps,
444
+ commit_message=commit_message,
445
+ token=kwargs.get("token"),
446
+ )
447
+
448
+ return [output_feature_extractor_file]
449
+
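A hedged sketch of the round trip these two methods give you, again using `Wav2Vec2FeatureExtractor` as the derived class; the directory is a throwaway local path:

```python
from transformers import Wav2Vec2FeatureExtractor

extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

# Writes preprocessor_config.json into the directory (created if missing) and
# returns the list of written files.
saved_files = extractor.save_pretrained("./tmp_feature_extractor")
print(saved_files)  # e.g. ['./tmp_feature_extractor/preprocessor_config.json']

# Reloading from that directory reproduces an equivalent extractor.
reloaded = Wav2Vec2FeatureExtractor.from_pretrained("./tmp_feature_extractor")
assert reloaded.to_dict() == extractor.to_dict()
```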
450
+ @classmethod
451
+ def get_feature_extractor_dict(
452
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
453
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
454
+ """
455
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
456
+ feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] using `from_dict`.
457
+
458
+ Parameters:
459
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
460
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
461
+
462
+ Returns:
463
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the feature extractor object.
464
+ """
465
+ cache_dir = kwargs.pop("cache_dir", None)
466
+ force_download = kwargs.pop("force_download", False)
467
+ resume_download = kwargs.pop("resume_download", None)
468
+ proxies = kwargs.pop("proxies", None)
469
+ subfolder = kwargs.pop("subfolder", None)
470
+ token = kwargs.pop("token", None)
471
+ use_auth_token = kwargs.pop("use_auth_token", None)
472
+ local_files_only = kwargs.pop("local_files_only", False)
473
+ revision = kwargs.pop("revision", None)
474
+
475
+ if use_auth_token is not None:
476
+ warnings.warn(
477
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
478
+ FutureWarning,
479
+ )
480
+ if token is not None:
481
+ raise ValueError(
482
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
483
+ )
484
+ token = use_auth_token
485
+
486
+ from_pipeline = kwargs.pop("_from_pipeline", None)
487
+ from_auto_class = kwargs.pop("_from_auto", False)
488
+
489
+ user_agent = {"file_type": "feature extractor", "from_auto_class": from_auto_class}
490
+ if from_pipeline is not None:
491
+ user_agent["using_pipeline"] = from_pipeline
492
+
493
+ if is_offline_mode() and not local_files_only:
494
+ logger.info("Offline mode: forcing local_files_only=True")
495
+ local_files_only = True
496
+
497
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
498
+ is_local = os.path.isdir(pretrained_model_name_or_path)
499
+ if os.path.isdir(pretrained_model_name_or_path):
500
+ feature_extractor_file = os.path.join(pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME)
501
+ if os.path.isfile(pretrained_model_name_or_path):
502
+ resolved_feature_extractor_file = pretrained_model_name_or_path
503
+ is_local = True
504
+ elif is_remote_url(pretrained_model_name_or_path):
505
+ feature_extractor_file = pretrained_model_name_or_path
506
+ resolved_feature_extractor_file = download_url(pretrained_model_name_or_path)
507
+ else:
508
+ feature_extractor_file = FEATURE_EXTRACTOR_NAME
509
+ try:
510
+ # Load from local folder or from cache or download from model Hub and cache
511
+ resolved_feature_extractor_file = cached_file(
512
+ pretrained_model_name_or_path,
513
+ feature_extractor_file,
514
+ cache_dir=cache_dir,
515
+ force_download=force_download,
516
+ proxies=proxies,
517
+ resume_download=resume_download,
518
+ local_files_only=local_files_only,
519
+ subfolder=subfolder,
520
+ token=token,
521
+ user_agent=user_agent,
522
+ revision=revision,
523
+ )
524
+ except EnvironmentError:
525
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted to
526
+ # the original exception.
527
+ raise
528
+ except Exception:
529
+ # For any other exception, we throw a generic error.
530
+ raise EnvironmentError(
531
+ f"Can't load feature extractor for '{pretrained_model_name_or_path}'. If you were trying to load"
532
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
533
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
534
+ f" directory containing a {FEATURE_EXTRACTOR_NAME} file"
535
+ )
536
+
537
+ try:
538
+ # Load feature_extractor dict
539
+ with open(resolved_feature_extractor_file, "r", encoding="utf-8") as reader:
540
+ text = reader.read()
541
+ feature_extractor_dict = json.loads(text)
542
+
543
+ except json.JSONDecodeError:
544
+ raise EnvironmentError(
545
+ f"It looks like the config file at '{resolved_feature_extractor_file}' is not a valid JSON file."
546
+ )
547
+
548
+ if is_local:
549
+ logger.info(f"loading configuration file {resolved_feature_extractor_file}")
550
+ else:
551
+ logger.info(
552
+ f"loading configuration file {feature_extractor_file} from cache at {resolved_feature_extractor_file}"
553
+ )
554
+
555
+ if not is_local:
556
+ if "auto_map" in feature_extractor_dict:
557
+ feature_extractor_dict["auto_map"] = add_model_info_to_auto_map(
558
+ feature_extractor_dict["auto_map"], pretrained_model_name_or_path
559
+ )
560
+ if "custom_pipelines" in feature_extractor_dict:
561
+ feature_extractor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
562
+ feature_extractor_dict["custom_pipelines"], pretrained_model_name_or_path
563
+ )
564
+
565
+ return feature_extractor_dict, kwargs
566
+
567
+ @classmethod
568
+ def from_dict(cls, feature_extractor_dict: Dict[str, Any], **kwargs) -> PreTrainedFeatureExtractor:
569
+ """
570
+ Instantiates a type of [`~feature_extraction_utils.FeatureExtractionMixin`] from a Python dictionary of
571
+ parameters.
572
+
573
+ Args:
574
+ feature_extractor_dict (`Dict[str, Any]`):
575
+ Dictionary that will be used to instantiate the feature extractor object. Such a dictionary can be
576
+ retrieved from a pretrained checkpoint by leveraging the
577
+ [`~feature_extraction_utils.FeatureExtractionMixin.to_dict`] method.
578
+ kwargs (`Dict[str, Any]`):
579
+ Additional parameters from which to initialize the feature extractor object.
580
+
581
+ Returns:
582
+ [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature extractor object instantiated from those
583
+ parameters.
584
+ """
585
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
586
+
587
+ # Update feature_extractor with kwargs if needed
588
+ to_remove = []
589
+ for key, value in kwargs.items():
590
+ if key in feature_extractor_dict:
591
+ feature_extractor_dict[key] = value
592
+ to_remove.append(key)
593
+ for key in to_remove:
594
+ kwargs.pop(key, None)
595
+
596
+ feature_extractor = cls(**feature_extractor_dict)
597
+
598
+ logger.info(f"Feature extractor {feature_extractor}")
599
+ if return_unused_kwargs:
600
+ return feature_extractor, kwargs
601
+ else:
602
+ return feature_extractor
603
+
604
+ def to_dict(self) -> Dict[str, Any]:
605
+ """
606
+ Serializes this instance to a Python dictionary.
+
+ Returns:
607
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this feature extractor instance.
608
+ """
609
+ output = copy.deepcopy(self.__dict__)
610
+ output["feature_extractor_type"] = self.__class__.__name__
611
+ if "mel_filters" in output:
612
+ del output["mel_filters"]
613
+ if "window" in output:
614
+ del output["window"]
615
+ return output
616
+
617
+ @classmethod
618
+ def from_json_file(cls, json_file: Union[str, os.PathLike]) -> PreTrainedFeatureExtractor:
619
+ """
620
+ Instantiates a feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`] from the path to
621
+ a JSON file of parameters.
622
+
623
+ Args:
624
+ json_file (`str` or `os.PathLike`):
625
+ Path to the JSON file containing the parameters.
626
+
627
+ Returns:
628
+ A feature extractor of type [`~feature_extraction_utils.FeatureExtractionMixin`]: The feature_extractor
629
+ object instantiated from that JSON file.
630
+ """
631
+ with open(json_file, "r", encoding="utf-8") as reader:
632
+ text = reader.read()
633
+ feature_extractor_dict = json.loads(text)
634
+ return cls(**feature_extractor_dict)
635
+
636
+ def to_json_string(self) -> str:
637
+ """
638
+ Serializes this instance to a JSON string.
639
+
640
+ Returns:
641
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
642
+ """
643
+ dictionary = self.to_dict()
644
+
645
+ for key, value in dictionary.items():
646
+ if isinstance(value, np.ndarray):
647
+ dictionary[key] = value.tolist()
648
+
649
+ # make sure private name "_processor_class" is correctly
650
+ # saved as "processor_class"
651
+ _processor_class = dictionary.pop("_processor_class", None)
652
+ if _processor_class is not None:
653
+ dictionary["processor_class"] = _processor_class
654
+
655
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
656
+
657
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
658
+ """
659
+ Save this instance to a JSON file.
660
+
661
+ Args:
662
+ json_file_path (`str` or `os.PathLike`):
663
+ Path to the JSON file in which this feature_extractor instance's parameters will be saved.
664
+ """
665
+ with open(json_file_path, "w", encoding="utf-8") as writer:
666
+ writer.write(self.to_json_string())
667
+
668
+ def __repr__(self):
669
+ return f"{self.__class__.__name__} {self.to_json_string()}"
670
+
671
+ @classmethod
672
+ def register_for_auto_class(cls, auto_class="AutoFeatureExtractor"):
673
+ """
674
+ Register this class with a given auto class. This should only be used for custom feature extractors as the ones
675
+ in the library are already mapped with `AutoFeatureExtractor`.
676
+
677
+ <Tip warning={true}>
678
+
679
+ This API is experimental and may have some slight breaking changes in the next releases.
680
+
681
+ </Tip>
682
+
683
+ Args:
684
+ auto_class (`str` or `type`, *optional*, defaults to `"AutoFeatureExtractor"`):
685
+ The auto class to register this new feature extractor with.
686
+ """
687
+ if not isinstance(auto_class, str):
688
+ auto_class = auto_class.__name__
689
+
690
+ import transformers.models.auto as auto_module
691
+
692
+ if not hasattr(auto_module, auto_class):
693
+ raise ValueError(f"{auto_class} is not a valid auto class.")
694
+
695
+ cls._auto_class = auto_class
696
+
697
+
698
+ FeatureExtractionMixin.push_to_hub = copy_func(FeatureExtractionMixin.push_to_hub)
699
+ if FeatureExtractionMixin.push_to_hub.__doc__ is not None:
700
+ FeatureExtractionMixin.push_to_hub.__doc__ = FeatureExtractionMixin.push_to_hub.__doc__.format(
701
+ object="feature extractor", object_class="AutoFeatureExtractor", object_files="feature extractor file"
702
+ )
file_utils.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ File utilities: utilities related to download and cache models
16
+
17
+ This module should not be updated anymore and is only left for backward compatibility.
18
+ """
19
+
20
+ from huggingface_hub import get_full_repo_name # for backward compatibility
21
+ from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
22
+
23
+ from . import __version__
24
+
25
+ # Backward compatibility imports, to make sure all those objects can be found in file_utils
26
+ from .utils import (
27
+ CLOUDFRONT_DISTRIB_PREFIX,
28
+ CONFIG_NAME,
29
+ DUMMY_INPUTS,
30
+ DUMMY_MASK,
31
+ ENV_VARS_TRUE_AND_AUTO_VALUES,
32
+ ENV_VARS_TRUE_VALUES,
33
+ FEATURE_EXTRACTOR_NAME,
34
+ FLAX_WEIGHTS_NAME,
35
+ HF_MODULES_CACHE,
36
+ HUGGINGFACE_CO_PREFIX,
37
+ HUGGINGFACE_CO_RESOLVE_ENDPOINT,
38
+ MODEL_CARD_NAME,
39
+ MULTIPLE_CHOICE_DUMMY_INPUTS,
40
+ PYTORCH_PRETRAINED_BERT_CACHE,
41
+ PYTORCH_TRANSFORMERS_CACHE,
42
+ S3_BUCKET_PREFIX,
43
+ SENTENCEPIECE_UNDERLINE,
44
+ SPIECE_UNDERLINE,
45
+ TF2_WEIGHTS_NAME,
46
+ TF_WEIGHTS_NAME,
47
+ TORCH_FX_REQUIRED_VERSION,
48
+ TRANSFORMERS_CACHE,
49
+ TRANSFORMERS_DYNAMIC_MODULE_NAME,
50
+ USE_JAX,
51
+ USE_TF,
52
+ USE_TORCH,
53
+ WEIGHTS_INDEX_NAME,
54
+ WEIGHTS_NAME,
55
+ ContextManagers,
56
+ DummyObject,
57
+ EntryNotFoundError,
58
+ ExplicitEnum,
59
+ ModelOutput,
60
+ PaddingStrategy,
61
+ PushToHubMixin,
62
+ RepositoryNotFoundError,
63
+ RevisionNotFoundError,
64
+ TensorType,
65
+ _LazyModule,
66
+ add_code_sample_docstrings,
67
+ add_end_docstrings,
68
+ add_start_docstrings,
69
+ add_start_docstrings_to_model_forward,
70
+ cached_property,
71
+ copy_func,
72
+ default_cache_path,
73
+ define_sagemaker_information,
74
+ get_cached_models,
75
+ get_file_from_repo,
76
+ get_torch_version,
77
+ has_file,
78
+ http_user_agent,
79
+ is_apex_available,
80
+ is_bs4_available,
81
+ is_coloredlogs_available,
82
+ is_datasets_available,
83
+ is_detectron2_available,
84
+ is_faiss_available,
85
+ is_flax_available,
86
+ is_ftfy_available,
87
+ is_g2p_en_available,
88
+ is_in_notebook,
89
+ is_ipex_available,
90
+ is_librosa_available,
91
+ is_offline_mode,
92
+ is_onnx_available,
93
+ is_pandas_available,
94
+ is_phonemizer_available,
95
+ is_protobuf_available,
96
+ is_psutil_available,
97
+ is_py3nvml_available,
98
+ is_pyctcdecode_available,
99
+ is_pytesseract_available,
100
+ is_pytorch_quantization_available,
101
+ is_rjieba_available,
102
+ is_sagemaker_dp_enabled,
103
+ is_sagemaker_mp_enabled,
104
+ is_scipy_available,
105
+ is_sentencepiece_available,
106
+ is_seqio_available,
107
+ is_sklearn_available,
108
+ is_soundfile_available,
109
+ is_spacy_available,
110
+ is_speech_available,
111
+ is_tensor,
112
+ is_tensorflow_probability_available,
113
+ is_tf2onnx_available,
114
+ is_tf_available,
115
+ is_timm_available,
116
+ is_tokenizers_available,
117
+ is_torch_available,
118
+ is_torch_bf16_available,
119
+ is_torch_cuda_available,
120
+ is_torch_fx_available,
121
+ is_torch_fx_proxy,
122
+ is_torch_mps_available,
123
+ is_torch_tf32_available,
124
+ is_torch_xla_available,
125
+ is_torchaudio_available,
126
+ is_training_run_on_sagemaker,
127
+ is_vision_available,
128
+ replace_return_docstrings,
129
+ requires_backends,
130
+ to_numpy,
131
+ to_py_obj,
132
+ torch_only_method,
133
+ )
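Because this module is a pure re-export layer, older import paths keep resolving to the same objects as their new homes under `transformers.utils`; a minimal sketch:

```python
# Legacy import path (kept for backward compatibility) vs. the current one.
from transformers.file_utils import PaddingStrategy, is_torch_available
from transformers.utils import PaddingStrategy as CurrentPaddingStrategy

assert PaddingStrategy is CurrentPaddingStrategy
print(is_torch_available())
```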
hf_argparser.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import dataclasses
16
+ import json
17
+ import os
18
+ import sys
19
+ import types
20
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError
21
+ from copy import copy
22
+ from enum import Enum
23
+ from inspect import isclass
24
+ from pathlib import Path
25
+ from typing import Any, Callable, Dict, Iterable, List, Literal, NewType, Optional, Tuple, Union, get_type_hints
26
+
27
+ import yaml
28
+
29
+
30
+ DataClass = NewType("DataClass", Any)
31
+ DataClassType = NewType("DataClassType", Any)
32
+
33
+
34
+ # From https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
35
+ def string_to_bool(v):
36
+ if isinstance(v, bool):
37
+ return v
38
+ if v.lower() in ("yes", "true", "t", "y", "1"):
39
+ return True
40
+ elif v.lower() in ("no", "false", "f", "n", "0"):
41
+ return False
42
+ else:
43
+ raise ArgumentTypeError(
44
+ f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
45
+ )
46
+
47
+
48
+ def make_choice_type_function(choices: list) -> Callable[[str], Any]:
49
+ """
50
+ Creates a mapping function from each choice's string representation to the actual value. Used to support multiple
51
+ value types for a single argument.
52
+
53
+ Args:
54
+ choices (list): List of choices.
55
+
56
+ Returns:
57
+ Callable[[str], Any]: Mapping function from string representation to actual value for each choice.
58
+ """
59
+ str_to_choice = {str(choice): choice for choice in choices}
60
+ return lambda arg: str_to_choice.get(arg, arg)
61
+
62
+
63
+ def HfArg(
64
+ *,
65
+ aliases: Union[str, List[str]] = None,
66
+ help: str = None,
67
+ default: Any = dataclasses.MISSING,
68
+ default_factory: Callable[[], Any] = dataclasses.MISSING,
69
+ metadata: dict = None,
70
+ **kwargs,
71
+ ) -> dataclasses.Field:
72
+ """Argument helper enabling a concise syntax to create dataclass fields for parsing with `HfArgumentParser`.
73
+
74
+ Example comparing the use of `HfArg` and `dataclasses.field`:
75
+ ```
76
+ @dataclass
77
+ class Args:
78
+ regular_arg: str = dataclasses.field(default="Huggingface", metadata={"aliases": ["--example", "-e"], "help": "This syntax could be better!"})
79
+ hf_arg: str = HfArg(default="Huggingface", aliases=["--example", "-e"], help="What a nice syntax!")
80
+ ```
81
+
82
+ Args:
83
+ aliases (Union[str, List[str]], optional):
84
+ Single string or list of strings of aliases to pass on to argparse, e.g. `aliases=["--example", "-e"]`.
85
+ Defaults to None.
86
+ help (str, optional): Help string to pass on to argparse that can be displayed with --help. Defaults to None.
87
+ default (Any, optional):
88
+ Default value for the argument. If neither default nor default_factory is specified, the argument is required.
89
+ Defaults to dataclasses.MISSING.
90
+ default_factory (Callable[[], Any], optional):
91
+ The default_factory is a 0-argument function called to initialize a field's value. It is useful to provide
92
+ default values for mutable types, e.g. lists: `default_factory=list`. Mutually exclusive with `default=`.
93
+ Defaults to dataclasses.MISSING.
94
+ metadata (dict, optional): Further metadata to pass on to `dataclasses.field`. Defaults to None.
95
+
96
+ Returns:
97
+ Field: A `dataclasses.Field` with the desired properties.
98
+ """
99
+ if metadata is None:
100
+ # Important, don't use as default param in function signature because dict is mutable and shared across function calls
101
+ metadata = {}
102
+ if aliases is not None:
103
+ metadata["aliases"] = aliases
104
+ if help is not None:
105
+ metadata["help"] = help
106
+
107
+ return dataclasses.field(metadata=metadata, default=default, default_factory=default_factory, **kwargs)
108
+
109
+
110
+ class HfArgumentParser(ArgumentParser):
111
+ """
112
+ This subclass of `argparse.ArgumentParser` uses type hints on dataclasses to generate arguments.
113
+
114
+ The class is designed to play well with the native argparse. In particular, you can add more (non-dataclass backed)
115
+ arguments to the parser after initialization and you'll get the output back after parsing as an additional
116
+ namespace. Optional: To create sub argument groups use the `_argument_group_name` attribute in the dataclass.
117
+ """
118
+
119
+ dataclass_types: Iterable[DataClassType]
120
+
121
+ def __init__(self, dataclass_types: Union[DataClassType, Iterable[DataClassType]], **kwargs):
122
+ """
123
+ Args:
124
+ dataclass_types:
125
+ Dataclass type, or list of dataclass types for which we will "fill" instances with the parsed args.
126
+ kwargs (`Dict[str, Any]`, *optional*):
127
+ Passed to `argparse.ArgumentParser()` in the regular way.
128
+ """
129
+ # To make the default appear when using --help
130
+ if "formatter_class" not in kwargs:
131
+ kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
132
+ super().__init__(**kwargs)
133
+ if dataclasses.is_dataclass(dataclass_types):
134
+ dataclass_types = [dataclass_types]
135
+ self.dataclass_types = list(dataclass_types)
136
+ for dtype in self.dataclass_types:
137
+ self._add_dataclass_arguments(dtype)
138
+
139
+ @staticmethod
140
+ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
141
+ # Long-option strings are conventionally separated by hyphens rather
142
+ # than underscores, e.g., "--long-format" rather than "--long_format".
143
+ # Argparse converts hyphens to underscores so that the destination
144
+ # string is a valid attribute name. Hf_argparser should do the same.
145
+ long_options = [f"--{field.name}"]
146
+ if "_" in field.name:
147
+ long_options.append(f"--{field.name.replace('_', '-')}")
148
+
149
+ kwargs = field.metadata.copy()
150
+ # field.metadata is not used at all by Data Classes,
151
+ # it is provided as a third-party extension mechanism.
152
+ if isinstance(field.type, str):
153
+ raise RuntimeError(
154
+ "Unresolved type detected, which should have been done with the help of "
155
+ "`typing.get_type_hints` method by default"
156
+ )
157
+
158
+ aliases = kwargs.pop("aliases", [])
159
+ if isinstance(aliases, str):
160
+ aliases = [aliases]
161
+
162
+ origin_type = getattr(field.type, "__origin__", field.type)
163
+ if origin_type is Union or (hasattr(types, "UnionType") and isinstance(origin_type, types.UnionType)):
164
+ if str not in field.type.__args__ and (
165
+ len(field.type.__args__) != 2 or type(None) not in field.type.__args__
166
+ ):
167
+ raise ValueError(
168
+ "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because"
169
+ " the argument parser only supports one type per argument."
170
+ f" Problem encountered in field '{field.name}'."
171
+ )
172
+ if type(None) not in field.type.__args__:
173
+ # filter `str` in Union
174
+ field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
175
+ origin_type = getattr(field.type, "__origin__", field.type)
176
+ elif bool not in field.type.__args__:
177
+ # filter `NoneType` in Union (except for `Union[bool, NoneType]`)
178
+ field.type = (
179
+ field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
180
+ )
181
+ origin_type = getattr(field.type, "__origin__", field.type)
182
+
183
+ # A variable to store kwargs for a boolean field, if needed
184
+ # so that we can init a `no_*` complement argument (see below)
185
+ bool_kwargs = {}
186
+ if origin_type is Literal or (isinstance(field.type, type) and issubclass(field.type, Enum)):
187
+ if origin_type is Literal:
188
+ kwargs["choices"] = field.type.__args__
189
+ else:
190
+ kwargs["choices"] = [x.value for x in field.type]
191
+
192
+ kwargs["type"] = make_choice_type_function(kwargs["choices"])
193
+
194
+ if field.default is not dataclasses.MISSING:
195
+ kwargs["default"] = field.default
196
+ else:
197
+ kwargs["required"] = True
198
+ elif field.type is bool or field.type == Optional[bool]:
199
+ # Copy the current kwargs to use to instantiate a `no_*` complement argument below.
200
+ # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument
201
+ bool_kwargs = copy(kwargs)
202
+
203
+ # Hack because type=bool in argparse does not behave as we want.
204
+ kwargs["type"] = string_to_bool
205
+ if field.type is bool or (field.default is not None and field.default is not dataclasses.MISSING):
206
+ # Default value is False if we have no default when of type bool.
207
+ default = False if field.default is dataclasses.MISSING else field.default
208
+ # This is the value that will get picked if we don't include --{field.name} in any way
209
+ kwargs["default"] = default
210
+ # This tells argparse we accept 0 or 1 value after --{field.name}
211
+ kwargs["nargs"] = "?"
212
+ # This is the value that will get picked if we do --{field.name} (without value)
213
+ kwargs["const"] = True
214
+ elif isclass(origin_type) and issubclass(origin_type, list):
215
+ kwargs["type"] = field.type.__args__[0]
216
+ kwargs["nargs"] = "+"
217
+ if field.default_factory is not dataclasses.MISSING:
218
+ kwargs["default"] = field.default_factory()
219
+ elif field.default is dataclasses.MISSING:
220
+ kwargs["required"] = True
221
+ else:
222
+ kwargs["type"] = field.type
223
+ if field.default is not dataclasses.MISSING:
224
+ kwargs["default"] = field.default
225
+ elif field.default_factory is not dataclasses.MISSING:
226
+ kwargs["default"] = field.default_factory()
227
+ else:
228
+ kwargs["required"] = True
229
+ parser.add_argument(*long_options, *aliases, **kwargs)
230
+
231
+ # Add a complement `no_*` argument for a boolean field AFTER the initial field has already been added.
232
+ # Order is important for arguments with the same destination!
233
+ # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down
234
+ # here and we do not need those changes/additional keys.
235
+ if field.default is True and (field.type is bool or field.type == Optional[bool]):
236
+ bool_kwargs["default"] = False
237
+ parser.add_argument(
238
+ f"--no_{field.name}",
239
+ f"--no-{field.name.replace('_', '-')}",
240
+ action="store_false",
241
+ dest=field.name,
242
+ **bool_kwargs,
243
+ )
244
+
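A small sketch of how the boolean handling above plays out for a dataclass field; the field and flag names are illustrative, and `parse_args_into_dataclasses` is defined further down in this file:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class TrainingFlags:
    # Because the default is True, the parser also registers `--no_use_cache`
    # (and `--no-use-cache`) as a store_false complement on the same dest.
    use_cache: bool = field(default=True, metadata={"help": "Enable the cache."})


parser = HfArgumentParser(TrainingFlags)

(flags,) = parser.parse_args_into_dataclasses(args=["--no_use_cache"])
print(flags.use_cache)  # False

(flags,) = parser.parse_args_into_dataclasses(args=["--use_cache", "false"])
print(flags.use_cache)  # False, via string_to_bool

(flags,) = parser.parse_args_into_dataclasses(args=[])
print(flags.use_cache)  # True, the dataclass default
```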
245
+ def _add_dataclass_arguments(self, dtype: DataClassType):
246
+ if hasattr(dtype, "_argument_group_name"):
247
+ parser = self.add_argument_group(dtype._argument_group_name)
248
+ else:
249
+ parser = self
250
+
251
+ try:
252
+ type_hints: Dict[str, type] = get_type_hints(dtype)
253
+ except NameError:
254
+ raise RuntimeError(
255
+ f"Type resolution failed for {dtype}. Try declaring the class in global scope or "
256
+ "removing the line `from __future__ import annotations`, which opts into Postponed "
257
+ "Evaluation of Annotations (PEP 563)"
258
+ )
259
+ except TypeError as ex:
260
+ # Remove this block when we drop Python 3.9 support
261
+ if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex):
262
+ python_version = ".".join(map(str, sys.version_info[:3]))
263
+ raise RuntimeError(
264
+ f"Type resolution failed for {dtype} on Python {python_version}. Try removing "
265
+ "the line `from __future__ import annotations`, which opts into union types as "
266
+ "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To "
267
+ "support Python versions lower than 3.10, you need to use "
268
+ "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of "
269
+ "`X | None`."
270
+ ) from ex
271
+ raise
272
+
273
+ for field in dataclasses.fields(dtype):
274
+ if not field.init:
275
+ continue
276
+ field.type = type_hints[field.name]
277
+ self._parse_dataclass_field(parser, field)
278
+
279
+ def parse_args_into_dataclasses(
280
+ self,
281
+ args=None,
282
+ return_remaining_strings=False,
283
+ look_for_args_file=True,
284
+ args_filename=None,
285
+ args_file_flag=None,
286
+ ) -> Tuple[DataClass, ...]:
287
+ """
288
+ Parse command-line args into instances of the specified dataclass types.
289
+
290
+ This relies on argparse's `ArgumentParser.parse_known_args`. See the doc at:
291
+ docs.python.org/3.7/library/argparse.html#argparse.ArgumentParser.parse_args
292
+
293
+ Args:
294
+ args:
295
+ List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
296
+ return_remaining_strings:
297
+ If true, also return a list of remaining argument strings.
298
+ look_for_args_file:
299
+ If true, will look for a ".args" file with the same base name as the entry point script for this
300
+ process, and will append its potential content to the command line args.
301
+ args_filename:
302
+ If not None, will use this file instead of the ".args" file specified in the previous argument.
303
+ args_file_flag:
304
+ If not None, will look for a file in the command-line args specified with this flag. The flag can be
305
+ specified multiple times and precedence is determined by the order (last one wins).
306
+
307
+ Returns:
308
+ Tuple consisting of:
309
+
310
+                - the dataclass instances in the same order as they were passed to the initializer.
311
+ - if applicable, an additional namespace for more (non-dataclass backed) arguments added to the parser
312
+ after initialization.
313
+ - The potential list of remaining argument strings. (same as argparse.ArgumentParser.parse_known_args)
314
+ """
315
+
316
+ if args_file_flag or args_filename or (look_for_args_file and len(sys.argv)):
317
+ args_files = []
318
+
319
+ if args_filename:
320
+ args_files.append(Path(args_filename))
321
+ elif look_for_args_file and len(sys.argv):
322
+ args_files.append(Path(sys.argv[0]).with_suffix(".args"))
323
+
324
+ # args files specified via command line flag should overwrite default args files so we add them last
325
+ if args_file_flag:
326
+ # Create special parser just to extract the args_file_flag values
327
+ args_file_parser = ArgumentParser()
328
+ args_file_parser.add_argument(args_file_flag, type=str, action="append")
329
+
330
+ # Use only remaining args for further parsing (remove the args_file_flag)
331
+ cfg, args = args_file_parser.parse_known_args(args=args)
332
+ cmd_args_file_paths = vars(cfg).get(args_file_flag.lstrip("-"), None)
333
+
334
+ if cmd_args_file_paths:
335
+ args_files.extend([Path(p) for p in cmd_args_file_paths])
336
+
337
+ file_args = []
338
+ for args_file in args_files:
339
+ if args_file.exists():
340
+ file_args += args_file.read_text().split()
341
+
342
+ # in case of duplicate arguments the last one has precedence
343
+ # args specified via the command line should overwrite args from files, so we add them last
344
+ args = file_args + args if args is not None else file_args + sys.argv[1:]
345
+ namespace, remaining_args = self.parse_known_args(args=args)
346
+ outputs = []
347
+ for dtype in self.dataclass_types:
348
+ keys = {f.name for f in dataclasses.fields(dtype) if f.init}
349
+ inputs = {k: v for k, v in vars(namespace).items() if k in keys}
350
+ for k in keys:
351
+ delattr(namespace, k)
352
+ obj = dtype(**inputs)
353
+ outputs.append(obj)
354
+ if len(namespace.__dict__) > 0:
355
+ # additional namespace.
356
+ outputs.append(namespace)
357
+ if return_remaining_strings:
358
+ return (*outputs, remaining_args)
359
+ else:
360
+ if remaining_args:
361
+ raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {remaining_args}")
362
+
363
+ return (*outputs,)
364
+
365
+ def parse_dict(self, args: Dict[str, Any], allow_extra_keys: bool = False) -> Tuple[DataClass, ...]:
366
+ """
367
+        Alternative helper method that does not use `argparse` at all, instead using a dict to populate the dataclass
368
+ types.
369
+
370
+ Args:
371
+ args (`dict`):
372
+ dict containing config values
373
+ allow_extra_keys (`bool`, *optional*, defaults to `False`):
374
+ Defaults to False. If False, will raise an exception if the dict contains keys that are not parsed.
375
+
376
+ Returns:
377
+ Tuple consisting of:
378
+
379
+ - the dataclass instances in the same order as they were passed to the initializer.
380
+ """
381
+ unused_keys = set(args.keys())
382
+ outputs = []
383
+ for dtype in self.dataclass_types:
384
+ keys = {f.name for f in dataclasses.fields(dtype) if f.init}
385
+ inputs = {k: v for k, v in args.items() if k in keys}
386
+ unused_keys.difference_update(inputs.keys())
387
+ obj = dtype(**inputs)
388
+ outputs.append(obj)
389
+ if not allow_extra_keys and unused_keys:
390
+ raise ValueError(f"Some keys are not used by the HfArgumentParser: {sorted(unused_keys)}")
391
+ return tuple(outputs)
392
+
393
+ def parse_json_file(
394
+ self, json_file: Union[str, os.PathLike], allow_extra_keys: bool = False
395
+ ) -> Tuple[DataClass, ...]:
396
+ """
397
+ Alternative helper method that does not use `argparse` at all, instead loading a json file and populating the
398
+ dataclass types.
399
+
400
+ Args:
401
+ json_file (`str` or `os.PathLike`):
402
+ File name of the json file to parse
403
+ allow_extra_keys (`bool`, *optional*, defaults to `False`):
404
+ Defaults to False. If False, will raise an exception if the json file contains keys that are not
405
+ parsed.
406
+
407
+ Returns:
408
+ Tuple consisting of:
409
+
410
+ - the dataclass instances in the same order as they were passed to the initializer.
411
+ """
412
+ with open(Path(json_file), encoding="utf-8") as open_json_file:
413
+ data = json.loads(open_json_file.read())
414
+ outputs = self.parse_dict(data, allow_extra_keys=allow_extra_keys)
415
+ return tuple(outputs)
416
+
417
+ def parse_yaml_file(
418
+ self, yaml_file: Union[str, os.PathLike], allow_extra_keys: bool = False
419
+ ) -> Tuple[DataClass, ...]:
420
+ """
421
+ Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
422
+ dataclass types.
423
+
424
+ Args:
425
+ yaml_file (`str` or `os.PathLike`):
426
+ File name of the yaml file to parse
427
+ allow_extra_keys (`bool`, *optional*, defaults to `False`):
428
+                Defaults to False. If False, will raise an exception if the yaml file contains keys that are not
429
+ parsed.
430
+
431
+ Returns:
432
+ Tuple consisting of:
433
+
434
+ - the dataclass instances in the same order as they were passed to the initializer.
435
+ """
436
+ outputs = self.parse_dict(yaml.safe_load(Path(yaml_file).read_text()), allow_extra_keys=allow_extra_keys)
437
+ return tuple(outputs)
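
Note: the `parse_*` helpers above all resolve to the same contract, one dataclass instance per type passed to the parser, returned as a tuple in the same order. A minimal usage sketch follows; the `TrainingArgs` dataclass is purely hypothetical and invented for illustration, it is not part of this upload.

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TrainingArgs:
    # Hypothetical example fields; any dataclass with typed, defaulted fields works.
    model_name: str = field(default="bert-base-uncased")
    learning_rate: float = field(default=5e-5)
    do_eval: bool = field(default=False)
    run_name: Optional[str] = field(default=None)


parser = HfArgumentParser(TrainingArgs)

# Command-line style parsing (sys.argv is used when `args` is omitted).
(cli_args,) = parser.parse_args_into_dataclasses(args=["--learning_rate", "3e-5", "--do_eval"])

# Dict-based parsing, bypassing argparse entirely; unknown keys raise unless allow_extra_keys=True.
(dict_args,) = parser.parse_dict({"model_name": "roberta-base", "learning_rate": 1e-4})

print(cli_args.learning_rate, cli_args.do_eval, dict_args.model_name)
```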
hyperparameter_search.py ADDED
@@ -0,0 +1,141 @@
1
+ # coding=utf-8
2
+ # Copyright 2023-present the HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from .integrations import (
17
+ is_optuna_available,
18
+ is_ray_tune_available,
19
+ is_sigopt_available,
20
+ is_wandb_available,
21
+ run_hp_search_optuna,
22
+ run_hp_search_ray,
23
+ run_hp_search_sigopt,
24
+ run_hp_search_wandb,
25
+ )
26
+ from .trainer_utils import (
27
+ HPSearchBackend,
28
+ default_hp_space_optuna,
29
+ default_hp_space_ray,
30
+ default_hp_space_sigopt,
31
+ default_hp_space_wandb,
32
+ )
33
+ from .utils import logging
34
+
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ class HyperParamSearchBackendBase:
40
+ name: str
41
+ pip_package: str = None
42
+
43
+ @staticmethod
44
+ def is_available():
45
+ raise NotImplementedError
46
+
47
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
48
+ raise NotImplementedError
49
+
50
+ def default_hp_space(self, trial):
51
+ raise NotImplementedError
52
+
53
+ def ensure_available(self):
54
+ if not self.is_available():
55
+ raise RuntimeError(
56
+ f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
57
+ )
58
+
59
+ @classmethod
60
+ def pip_install(cls):
61
+ return f"`pip install {cls.pip_package or cls.name}`"
62
+
63
+
64
+ class OptunaBackend(HyperParamSearchBackendBase):
65
+ name = "optuna"
66
+
67
+ @staticmethod
68
+ def is_available():
69
+ return is_optuna_available()
70
+
71
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
72
+ return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
73
+
74
+ def default_hp_space(self, trial):
75
+ return default_hp_space_optuna(trial)
76
+
77
+
78
+ class RayTuneBackend(HyperParamSearchBackendBase):
79
+ name = "ray"
80
+ pip_package = "'ray[tune]'"
81
+
82
+ @staticmethod
83
+ def is_available():
84
+ return is_ray_tune_available()
85
+
86
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
87
+ return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
88
+
89
+ def default_hp_space(self, trial):
90
+ return default_hp_space_ray(trial)
91
+
92
+
93
+ class SigOptBackend(HyperParamSearchBackendBase):
94
+ name = "sigopt"
95
+
96
+ @staticmethod
97
+ def is_available():
98
+ return is_sigopt_available()
99
+
100
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
101
+ return run_hp_search_sigopt(trainer, n_trials, direction, **kwargs)
102
+
103
+ def default_hp_space(self, trial):
104
+ return default_hp_space_sigopt(trial)
105
+
106
+
107
+ class WandbBackend(HyperParamSearchBackendBase):
108
+ name = "wandb"
109
+
110
+ @staticmethod
111
+ def is_available():
112
+ return is_wandb_available()
113
+
114
+ def run(self, trainer, n_trials: int, direction: str, **kwargs):
115
+ return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
116
+
117
+ def default_hp_space(self, trial):
118
+ return default_hp_space_wandb(trial)
119
+
120
+
121
+ ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
122
+ HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, SigOptBackend, WandbBackend]
123
+ }
124
+
125
+
126
+ def default_hp_search_backend() -> str:
127
+ available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
128
+ if len(available_backends) > 0:
129
+ name = available_backends[0].name
130
+ if len(available_backends) > 1:
131
+ logger.info(
132
+ f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
133
+ )
134
+ return name
135
+ raise RuntimeError(
136
+ "No hyperparameter search backend available.\n"
137
+ + "\n".join(
138
+ f" - To install {backend.name} run {backend.pip_install()}"
139
+ for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
140
+ )
141
+ )
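
For context, this registry is what `Trainer.hyperparameter_search` consumes: resolve a backend (explicitly or via `default_hp_search_backend`), check availability, then delegate to its `run` method. A hedged sketch of that flow, assuming the internal module path `transformers.hyperparameter_search` (these helpers are not re-exported at the package top level):

```python
from transformers.hyperparameter_search import (
    ALL_HYPERPARAMETER_SEARCH_BACKENDS,
    default_hp_search_backend,
)
from transformers.trainer_utils import HPSearchBackend

# Picks the first installed backend in registry order (optuna, ray, sigopt, wandb);
# raises a RuntimeError with per-backend install hints if none is installed.
backend_name = default_hp_search_backend()

# The registry is keyed by the HPSearchBackend enum, mirroring how the classes were registered above.
backend = ALL_HYPERPARAMETER_SEARCH_BACKENDS[HPSearchBackend(backend_name)]()
backend.ensure_available()

# A real call would pass a transformers.Trainer built with a `model_init`; only sketched here.
# best_run = backend.run(trainer, n_trials=10, direction="minimize")
```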
image_processing_base.py ADDED
@@ -0,0 +1,559 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import copy
18
+ import json
19
+ import os
20
+ import warnings
21
+ from io import BytesIO
22
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
23
+
24
+ import numpy as np
25
+ import requests
26
+
27
+ from .dynamic_module_utils import custom_object_save
28
+ from .feature_extraction_utils import BatchFeature as BaseBatchFeature
29
+ from .utils import (
30
+ IMAGE_PROCESSOR_NAME,
31
+ PushToHubMixin,
32
+ add_model_info_to_auto_map,
33
+ add_model_info_to_custom_pipelines,
34
+ cached_file,
35
+ copy_func,
36
+ download_url,
37
+ is_offline_mode,
38
+ is_remote_url,
39
+ is_vision_available,
40
+ logging,
41
+ )
42
+
43
+
44
+ if is_vision_available():
45
+ from PIL import Image
46
+
47
+
48
+ ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
49
+
50
+
51
+ logger = logging.get_logger(__name__)
52
+
53
+
54
+ # TODO: Move BatchFeature to be imported by both image_processing_utils and image_processing_utils
55
+ # We override the class string here, but logic is the same.
56
+ class BatchFeature(BaseBatchFeature):
57
+ r"""
58
+ Holds the output of the image processor specific `__call__` methods.
59
+
60
+ This class is derived from a python dictionary and can be used as a dictionary.
61
+
62
+ Args:
63
+ data (`dict`):
64
+ Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
65
+ tensor_type (`Union[None, str, TensorType]`, *optional*):
66
+ You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
67
+ initialization.
68
+ """
69
+
70
+
71
+ # TODO: (Amy) - factor out the common parts of this and the feature extractor
72
+ class ImageProcessingMixin(PushToHubMixin):
73
+ """
74
+ This is an image processor mixin used to provide saving/loading functionality for sequential and image feature
75
+ extractors.
76
+ """
77
+
78
+ _auto_class = None
79
+
80
+ def __init__(self, **kwargs):
81
+ """Set elements of `kwargs` as attributes."""
82
+ # This key was saved while we still used `XXXFeatureExtractor` for image processing. Now we use
83
+        # `XXXImageProcessor`, so this attribute and its value are misleading.
84
+ kwargs.pop("feature_extractor_type", None)
85
+ # Pop "processor_class" as it should be saved as private attribute
86
+ self._processor_class = kwargs.pop("processor_class", None)
87
+ # Additional attributes without default values
88
+ for key, value in kwargs.items():
89
+ try:
90
+ setattr(self, key, value)
91
+ except AttributeError as err:
92
+ logger.error(f"Can't set {key} with value {value} for {self}")
93
+ raise err
94
+
95
+ def _set_processor_class(self, processor_class: str):
96
+ """Sets processor class as an attribute."""
97
+ self._processor_class = processor_class
98
+
99
+ @classmethod
100
+ def from_pretrained(
101
+ cls: Type[ImageProcessorType],
102
+ pretrained_model_name_or_path: Union[str, os.PathLike],
103
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
104
+ force_download: bool = False,
105
+ local_files_only: bool = False,
106
+ token: Optional[Union[str, bool]] = None,
107
+ revision: str = "main",
108
+ **kwargs,
109
+ ) -> ImageProcessorType:
110
+ r"""
111
+ Instantiate a type of [`~image_processing_utils.ImageProcessingMixin`] from an image processor.
112
+
113
+ Args:
114
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
115
+ This can be either:
116
+
117
+ - a string, the *model id* of a pretrained image_processor hosted inside a model repo on
118
+ huggingface.co.
119
+                - a path to a *directory* containing an image processor file saved using the
120
+ [`~image_processing_utils.ImageProcessingMixin.save_pretrained`] method, e.g.,
121
+ `./my_model_directory/`.
122
+ - a path or url to a saved image processor JSON *file*, e.g.,
123
+ `./my_model_directory/preprocessor_config.json`.
124
+ cache_dir (`str` or `os.PathLike`, *optional*):
125
+ Path to a directory in which a downloaded pretrained model image processor should be cached if the
126
+ standard cache should not be used.
127
+ force_download (`bool`, *optional*, defaults to `False`):
128
+                Whether or not to force (re-)downloading the image processor files and override the cached versions if
129
+ they exist.
130
+ resume_download:
131
+ Deprecated and ignored. All downloads are now resumed by default when possible.
132
+ Will be removed in v5 of Transformers.
133
+ proxies (`Dict[str, str]`, *optional*):
134
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
135
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
136
+ token (`str` or `bool`, *optional*):
137
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
138
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
139
+ revision (`str`, *optional*, defaults to `"main"`):
140
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
141
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
142
+ identifier allowed by git.
143
+
144
+
145
+ <Tip>
146
+
147
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
148
+
149
+ </Tip>
150
+
151
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
152
+ If `False`, then this function returns just the final image processor object. If `True`, then this
153
+ functions returns a `Tuple(image_processor, unused_kwargs)` where *unused_kwargs* is a dictionary
154
+ consisting of the key/value pairs whose keys are not image processor attributes: i.e., the part of
155
+ `kwargs` which has not been used to update `image_processor` and is otherwise ignored.
156
+ subfolder (`str`, *optional*, defaults to `""`):
157
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
158
+ specify the folder name here.
159
+ kwargs (`Dict[str, Any]`, *optional*):
160
+ The values in kwargs of any keys which are image processor attributes will be used to override the
161
+ loaded values. Behavior concerning key/value pairs whose keys are *not* image processor attributes is
162
+ controlled by the `return_unused_kwargs` keyword parameter.
163
+
164
+ Returns:
165
+            An image processor of type [`~image_processing_utils.ImageProcessingMixin`].
166
+
167
+ Examples:
168
+
169
+ ```python
170
+ # We can't instantiate directly the base class *ImageProcessingMixin* so let's show the examples on a
171
+ # derived class: *CLIPImageProcessor*
172
+ image_processor = CLIPImageProcessor.from_pretrained(
173
+ "openai/clip-vit-base-patch32"
174
+ ) # Download image_processing_config from huggingface.co and cache.
175
+ image_processor = CLIPImageProcessor.from_pretrained(
176
+ "./test/saved_model/"
177
+ ) # E.g. image processor (or model) was saved using *save_pretrained('./test/saved_model/')*
178
+ image_processor = CLIPImageProcessor.from_pretrained("./test/saved_model/preprocessor_config.json")
179
+ image_processor = CLIPImageProcessor.from_pretrained(
180
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False
181
+ )
182
+ assert image_processor.do_normalize is False
183
+ image_processor, unused_kwargs = CLIPImageProcessor.from_pretrained(
184
+ "openai/clip-vit-base-patch32", do_normalize=False, foo=False, return_unused_kwargs=True
185
+ )
186
+ assert image_processor.do_normalize is False
187
+ assert unused_kwargs == {"foo": False}
188
+ ```"""
189
+ kwargs["cache_dir"] = cache_dir
190
+ kwargs["force_download"] = force_download
191
+ kwargs["local_files_only"] = local_files_only
192
+ kwargs["revision"] = revision
193
+
194
+ use_auth_token = kwargs.pop("use_auth_token", None)
195
+ if use_auth_token is not None:
196
+ warnings.warn(
197
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
198
+ FutureWarning,
199
+ )
200
+ if token is not None:
201
+ raise ValueError(
202
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
203
+ )
204
+ token = use_auth_token
205
+
206
+ if token is not None:
207
+ kwargs["token"] = token
208
+
209
+ image_processor_dict, kwargs = cls.get_image_processor_dict(pretrained_model_name_or_path, **kwargs)
210
+
211
+ return cls.from_dict(image_processor_dict, **kwargs)
212
+
213
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
214
+ """
215
+ Save an image processor object to the directory `save_directory`, so that it can be re-loaded using the
216
+ [`~image_processing_utils.ImageProcessingMixin.from_pretrained`] class method.
217
+
218
+ Args:
219
+ save_directory (`str` or `os.PathLike`):
220
+ Directory where the image processor JSON file will be saved (will be created if it does not exist).
221
+ push_to_hub (`bool`, *optional*, defaults to `False`):
222
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
223
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
224
+ namespace).
225
+ kwargs (`Dict[str, Any]`, *optional*):
226
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
227
+ """
228
+ use_auth_token = kwargs.pop("use_auth_token", None)
229
+
230
+ if use_auth_token is not None:
231
+ warnings.warn(
232
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
233
+ FutureWarning,
234
+ )
235
+ if kwargs.get("token", None) is not None:
236
+ raise ValueError(
237
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
238
+ )
239
+ kwargs["token"] = use_auth_token
240
+
241
+ if os.path.isfile(save_directory):
242
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
243
+
244
+ os.makedirs(save_directory, exist_ok=True)
245
+
246
+ if push_to_hub:
247
+ commit_message = kwargs.pop("commit_message", None)
248
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
249
+ repo_id = self._create_repo(repo_id, **kwargs)
250
+ files_timestamps = self._get_files_timestamps(save_directory)
251
+
252
+ # If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
253
+ # loaded from the Hub.
254
+ if self._auto_class is not None:
255
+ custom_object_save(self, save_directory, config=self)
256
+
257
+ # If we save using the predefined names, we can load using `from_pretrained`
258
+ output_image_processor_file = os.path.join(save_directory, IMAGE_PROCESSOR_NAME)
259
+
260
+ self.to_json_file(output_image_processor_file)
261
+ logger.info(f"Image processor saved in {output_image_processor_file}")
262
+
263
+ if push_to_hub:
264
+ self._upload_modified_files(
265
+ save_directory,
266
+ repo_id,
267
+ files_timestamps,
268
+ commit_message=commit_message,
269
+ token=kwargs.get("token"),
270
+ )
271
+
272
+ return [output_image_processor_file]
273
+
274
+ @classmethod
275
+ def get_image_processor_dict(
276
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
277
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
278
+ """
279
+        From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating an
281
+        image processor of type [`~image_processing_utils.ImageProcessingMixin`] using `from_dict`.
281
+
282
+ Parameters:
283
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
284
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
285
+ subfolder (`str`, *optional*, defaults to `""`):
286
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
287
+ specify the folder name here.
288
+            image_processor_filename (`str`, *optional*, defaults to `"preprocessor_config.json"`):
289
+ The name of the file in the model directory to use for the image processor config.
290
+
291
+ Returns:
292
+ `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the image processor object.
293
+ """
294
+ cache_dir = kwargs.pop("cache_dir", None)
295
+ force_download = kwargs.pop("force_download", False)
296
+ resume_download = kwargs.pop("resume_download", None)
297
+ proxies = kwargs.pop("proxies", None)
298
+ token = kwargs.pop("token", None)
299
+ use_auth_token = kwargs.pop("use_auth_token", None)
300
+ local_files_only = kwargs.pop("local_files_only", False)
301
+ revision = kwargs.pop("revision", None)
302
+ subfolder = kwargs.pop("subfolder", "")
303
+ image_processor_filename = kwargs.pop("image_processor_filename", IMAGE_PROCESSOR_NAME)
304
+
305
+ from_pipeline = kwargs.pop("_from_pipeline", None)
306
+ from_auto_class = kwargs.pop("_from_auto", False)
307
+
308
+ if use_auth_token is not None:
309
+ warnings.warn(
310
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
311
+ FutureWarning,
312
+ )
313
+ if token is not None:
314
+ raise ValueError(
315
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
316
+ )
317
+ token = use_auth_token
318
+
319
+ user_agent = {"file_type": "image processor", "from_auto_class": from_auto_class}
320
+ if from_pipeline is not None:
321
+ user_agent["using_pipeline"] = from_pipeline
322
+
323
+ if is_offline_mode() and not local_files_only:
324
+ logger.info("Offline mode: forcing local_files_only=True")
325
+ local_files_only = True
326
+
327
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
328
+ is_local = os.path.isdir(pretrained_model_name_or_path)
329
+ if os.path.isdir(pretrained_model_name_or_path):
330
+ image_processor_file = os.path.join(pretrained_model_name_or_path, image_processor_filename)
331
+ if os.path.isfile(pretrained_model_name_or_path):
332
+ resolved_image_processor_file = pretrained_model_name_or_path
333
+ is_local = True
334
+ elif is_remote_url(pretrained_model_name_or_path):
335
+ image_processor_file = pretrained_model_name_or_path
336
+ resolved_image_processor_file = download_url(pretrained_model_name_or_path)
337
+ else:
338
+ image_processor_file = image_processor_filename
339
+ try:
340
+ # Load from local folder or from cache or download from model Hub and cache
341
+ resolved_image_processor_file = cached_file(
342
+ pretrained_model_name_or_path,
343
+ image_processor_file,
344
+ cache_dir=cache_dir,
345
+ force_download=force_download,
346
+ proxies=proxies,
347
+ resume_download=resume_download,
348
+ local_files_only=local_files_only,
349
+ token=token,
350
+ user_agent=user_agent,
351
+ revision=revision,
352
+ subfolder=subfolder,
353
+ )
354
+ except EnvironmentError:
355
+ # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
356
+ # the original exception.
357
+ raise
358
+ except Exception:
359
+ # For any other exception, we throw a generic error.
360
+ raise EnvironmentError(
361
+ f"Can't load image processor for '{pretrained_model_name_or_path}'. If you were trying to load"
362
+ " it from 'https://huggingface.co/models', make sure you don't have a local directory with the"
363
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
364
+ f" directory containing a {image_processor_filename} file"
365
+ )
366
+
367
+ try:
368
+ # Load image_processor dict
369
+ with open(resolved_image_processor_file, "r", encoding="utf-8") as reader:
370
+ text = reader.read()
371
+ image_processor_dict = json.loads(text)
372
+
373
+ except json.JSONDecodeError:
374
+ raise EnvironmentError(
375
+ f"It looks like the config file at '{resolved_image_processor_file}' is not a valid JSON file."
376
+ )
377
+
378
+ if is_local:
379
+ logger.info(f"loading configuration file {resolved_image_processor_file}")
380
+ else:
381
+ logger.info(
382
+ f"loading configuration file {image_processor_file} from cache at {resolved_image_processor_file}"
383
+ )
384
+ if "auto_map" in image_processor_dict:
385
+ image_processor_dict["auto_map"] = add_model_info_to_auto_map(
386
+ image_processor_dict["auto_map"], pretrained_model_name_or_path
387
+ )
388
+ if "custom_pipelines" in image_processor_dict:
389
+ image_processor_dict["custom_pipelines"] = add_model_info_to_custom_pipelines(
390
+ image_processor_dict["custom_pipelines"], pretrained_model_name_or_path
391
+ )
392
+
393
+ return image_processor_dict, kwargs
394
+
395
+ @classmethod
396
+ def from_dict(cls, image_processor_dict: Dict[str, Any], **kwargs):
397
+ """
398
+ Instantiates a type of [`~image_processing_utils.ImageProcessingMixin`] from a Python dictionary of parameters.
399
+
400
+ Args:
401
+ image_processor_dict (`Dict[str, Any]`):
402
+ Dictionary that will be used to instantiate the image processor object. Such a dictionary can be
403
+ retrieved from a pretrained checkpoint by leveraging the
404
+ [`~image_processing_utils.ImageProcessingMixin.to_dict`] method.
405
+ kwargs (`Dict[str, Any]`):
406
+ Additional parameters from which to initialize the image processor object.
407
+
408
+ Returns:
409
+ [`~image_processing_utils.ImageProcessingMixin`]: The image processor object instantiated from those
410
+ parameters.
411
+ """
412
+ image_processor_dict = image_processor_dict.copy()
413
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
414
+
415
+ # The `size` parameter is a dict and was previously an int or tuple in feature extractors.
416
+ # We set `size` here directly to the `image_processor_dict` so that it is converted to the appropriate
417
+ # dict within the image processor and isn't overwritten if `size` is passed in as a kwarg.
418
+ if "size" in kwargs and "size" in image_processor_dict:
419
+ image_processor_dict["size"] = kwargs.pop("size")
420
+ if "crop_size" in kwargs and "crop_size" in image_processor_dict:
421
+ image_processor_dict["crop_size"] = kwargs.pop("crop_size")
422
+
423
+ image_processor = cls(**image_processor_dict)
424
+
425
+ # Update image_processor with kwargs if needed
426
+ to_remove = []
427
+ for key, value in kwargs.items():
428
+ if hasattr(image_processor, key):
429
+ setattr(image_processor, key, value)
430
+ to_remove.append(key)
431
+ for key in to_remove:
432
+ kwargs.pop(key, None)
433
+
434
+ logger.info(f"Image processor {image_processor}")
435
+ if return_unused_kwargs:
436
+ return image_processor, kwargs
437
+ else:
438
+ return image_processor
439
+
440
+ def to_dict(self) -> Dict[str, Any]:
441
+ """
442
+ Serializes this instance to a Python dictionary.
443
+
444
+ Returns:
445
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this image processor instance.
446
+ """
447
+ output = copy.deepcopy(self.__dict__)
448
+ output["image_processor_type"] = self.__class__.__name__
449
+
450
+ return output
451
+
452
+ @classmethod
453
+ def from_json_file(cls, json_file: Union[str, os.PathLike]):
454
+ """
455
+ Instantiates a image processor of type [`~image_processing_utils.ImageProcessingMixin`] from the path to a JSON
456
+ file of parameters.
457
+
458
+ Args:
459
+ json_file (`str` or `os.PathLike`):
460
+ Path to the JSON file containing the parameters.
461
+
462
+ Returns:
463
+            An image processor of type [`~image_processing_utils.ImageProcessingMixin`]: The image_processor object
464
+ instantiated from that JSON file.
465
+ """
466
+ with open(json_file, "r", encoding="utf-8") as reader:
467
+ text = reader.read()
468
+ image_processor_dict = json.loads(text)
469
+ return cls(**image_processor_dict)
470
+
471
+ def to_json_string(self) -> str:
472
+ """
473
+ Serializes this instance to a JSON string.
474
+
475
+ Returns:
476
+ `str`: String containing all the attributes that make up this feature_extractor instance in JSON format.
477
+ """
478
+ dictionary = self.to_dict()
479
+
480
+ for key, value in dictionary.items():
481
+ if isinstance(value, np.ndarray):
482
+ dictionary[key] = value.tolist()
483
+
484
+ # make sure private name "_processor_class" is correctly
485
+ # saved as "processor_class"
486
+ _processor_class = dictionary.pop("_processor_class", None)
487
+ if _processor_class is not None:
488
+ dictionary["processor_class"] = _processor_class
489
+
490
+ return json.dumps(dictionary, indent=2, sort_keys=True) + "\n"
491
+
492
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
493
+ """
494
+ Save this instance to a JSON file.
495
+
496
+ Args:
497
+ json_file_path (`str` or `os.PathLike`):
498
+ Path to the JSON file in which this image_processor instance's parameters will be saved.
499
+ """
500
+ with open(json_file_path, "w", encoding="utf-8") as writer:
501
+ writer.write(self.to_json_string())
502
+
503
+ def __repr__(self):
504
+ return f"{self.__class__.__name__} {self.to_json_string()}"
505
+
506
+ @classmethod
507
+ def register_for_auto_class(cls, auto_class="AutoImageProcessor"):
508
+ """
509
+ Register this class with a given auto class. This should only be used for custom image processors as the ones
510
+        in the library are already mapped with `AutoImageProcessor`.
511
+
512
+ <Tip warning={true}>
513
+
514
+ This API is experimental and may have some slight breaking changes in the next releases.
515
+
516
+ </Tip>
517
+
518
+ Args:
519
+            auto_class (`str` or `type`, *optional*, defaults to `"AutoImageProcessor"`):
520
+ The auto class to register this new image processor with.
521
+ """
522
+ if not isinstance(auto_class, str):
523
+ auto_class = auto_class.__name__
524
+
525
+ import transformers.models.auto as auto_module
526
+
527
+ if not hasattr(auto_module, auto_class):
528
+ raise ValueError(f"{auto_class} is not a valid auto class.")
529
+
530
+ cls._auto_class = auto_class
531
+
532
+ def fetch_images(self, image_url_or_urls: Union[str, List[str]]):
533
+ """
534
+ Convert a single or a list of urls into the corresponding `PIL.Image` objects.
535
+
536
+ If a single url is passed, the return value will be a single object. If a list is passed a list of objects is
537
+ returned.
538
+ """
539
+ headers = {
540
+ "User-Agent": (
541
+ "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0"
542
+ " Safari/537.36"
543
+ )
544
+ }
545
+ if isinstance(image_url_or_urls, list):
546
+ return [self.fetch_images(x) for x in image_url_or_urls]
547
+ elif isinstance(image_url_or_urls, str):
548
+ response = requests.get(image_url_or_urls, stream=True, headers=headers)
549
+ response.raise_for_status()
550
+ return Image.open(BytesIO(response.content))
551
+ else:
552
+ raise TypeError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
553
+
554
+
555
+ ImageProcessingMixin.push_to_hub = copy_func(ImageProcessingMixin.push_to_hub)
556
+ if ImageProcessingMixin.push_to_hub.__doc__ is not None:
557
+ ImageProcessingMixin.push_to_hub.__doc__ = ImageProcessingMixin.push_to_hub.__doc__.format(
558
+ object="image processor", object_class="AutoImageProcessor", object_files="image processor file"
559
+ )
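
As a quick orientation for the mixin above, the serialization surface is symmetric: `to_dict`/`to_json_string` feed `save_pretrained`, and `from_dict`/`from_json_file` feed `from_pretrained`. A short round-trip sketch using `CLIPImageProcessor` as the derived class (the same one the docstring uses), assuming network access to the Hub for the first call:

```python
from transformers import CLIPImageProcessor

# Downloads (or reads from cache) preprocessor_config.json for this checkpoint.
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# to_dict() serializes the public attributes plus the class name under `image_processor_type`.
config = image_processor.to_dict()
assert config["image_processor_type"] == "CLIPImageProcessor"

# Round-trip through a local directory: save_pretrained writes the JSON file,
# from_pretrained reads it back into a fresh instance.
image_processor.save_pretrained("./my_clip_processor")
reloaded = CLIPImageProcessor.from_pretrained("./my_clip_processor")
assert reloaded.image_mean == image_processor.image_mean
```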
image_processing_utils.py ADDED
@@ -0,0 +1,287 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Iterable, Optional, Union
17
+
18
+ import numpy as np
19
+
20
+ from .image_processing_base import BatchFeature, ImageProcessingMixin
21
+ from .image_transforms import center_crop, normalize, rescale
22
+ from .image_utils import ChannelDimension
23
+ from .utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ INIT_SERVICE_KWARGS = [
30
+ "processor_class",
31
+ "image_processor_type",
32
+ ]
33
+
34
+
35
+ class BaseImageProcessor(ImageProcessingMixin):
36
+ def __init__(self, **kwargs):
37
+ super().__init__(**kwargs)
38
+
39
+ def __call__(self, images, **kwargs) -> BatchFeature:
40
+ """Preprocess an image or a batch of images."""
41
+ return self.preprocess(images, **kwargs)
42
+
43
+ def preprocess(self, images, **kwargs) -> BatchFeature:
44
+ raise NotImplementedError("Each image processor must implement its own preprocess method")
45
+
46
+ def rescale(
47
+ self,
48
+ image: np.ndarray,
49
+ scale: float,
50
+ data_format: Optional[Union[str, ChannelDimension]] = None,
51
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
52
+ **kwargs,
53
+ ) -> np.ndarray:
54
+ """
55
+ Rescale an image by a scale factor. image = image * scale.
56
+
57
+ Args:
58
+ image (`np.ndarray`):
59
+ Image to rescale.
60
+ scale (`float`):
61
+ The scaling factor to rescale pixel values by.
62
+ data_format (`str` or `ChannelDimension`, *optional*):
63
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
64
+ image is used. Can be one of:
65
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
66
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
67
+ input_data_format (`ChannelDimension` or `str`, *optional*):
68
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
69
+ from the input image. Can be one of:
70
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
71
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
72
+
73
+ Returns:
74
+ `np.ndarray`: The rescaled image.
75
+ """
76
+ return rescale(image, scale=scale, data_format=data_format, input_data_format=input_data_format, **kwargs)
77
+
78
+ def normalize(
79
+ self,
80
+ image: np.ndarray,
81
+ mean: Union[float, Iterable[float]],
82
+ std: Union[float, Iterable[float]],
83
+ data_format: Optional[Union[str, ChannelDimension]] = None,
84
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
85
+ **kwargs,
86
+ ) -> np.ndarray:
87
+ """
88
+ Normalize an image. image = (image - image_mean) / image_std.
89
+
90
+ Args:
91
+ image (`np.ndarray`):
92
+ Image to normalize.
93
+ mean (`float` or `Iterable[float]`):
94
+ Image mean to use for normalization.
95
+ std (`float` or `Iterable[float]`):
96
+ Image standard deviation to use for normalization.
97
+ data_format (`str` or `ChannelDimension`, *optional*):
98
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
99
+ image is used. Can be one of:
100
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
101
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
102
+ input_data_format (`ChannelDimension` or `str`, *optional*):
103
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
104
+ from the input image. Can be one of:
105
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
106
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
107
+
108
+ Returns:
109
+ `np.ndarray`: The normalized image.
110
+ """
111
+ return normalize(
112
+ image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
113
+ )
114
+
115
+ def center_crop(
116
+ self,
117
+ image: np.ndarray,
118
+ size: Dict[str, int],
119
+ data_format: Optional[Union[str, ChannelDimension]] = None,
120
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
121
+ **kwargs,
122
+ ) -> np.ndarray:
123
+ """
124
+ Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
125
+ any edge, the image is padded with 0's and then center cropped.
126
+
127
+ Args:
128
+ image (`np.ndarray`):
129
+ Image to center crop.
130
+ size (`Dict[str, int]`):
131
+ Size of the output image.
132
+ data_format (`str` or `ChannelDimension`, *optional*):
133
+ The channel dimension format for the output image. If unset, the channel dimension format of the input
134
+ image is used. Can be one of:
135
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
136
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
137
+ input_data_format (`ChannelDimension` or `str`, *optional*):
138
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
139
+ from the input image. Can be one of:
140
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
141
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
142
+ """
143
+ size = get_size_dict(size)
144
+ if "height" not in size or "width" not in size:
145
+ raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
146
+ return center_crop(
147
+ image,
148
+ size=(size["height"], size["width"]),
149
+ data_format=data_format,
150
+ input_data_format=input_data_format,
151
+ **kwargs,
152
+ )
153
+
154
+ def to_dict(self):
155
+ encoder_dict = super().to_dict()
156
+ encoder_dict.pop("_valid_processor_keys", None)
157
+ return encoder_dict
158
+
159
+
160
+ VALID_SIZE_DICT_KEYS = (
161
+ {"height", "width"},
162
+ {"shortest_edge"},
163
+ {"shortest_edge", "longest_edge"},
164
+ {"longest_edge"},
165
+ {"max_height", "max_width"},
166
+ )
167
+
168
+
169
+ def is_valid_size_dict(size_dict):
170
+ if not isinstance(size_dict, dict):
171
+ return False
172
+
173
+ size_dict_keys = set(size_dict.keys())
174
+ for allowed_keys in VALID_SIZE_DICT_KEYS:
175
+ if size_dict_keys == allowed_keys:
176
+ return True
177
+ return False
178
+
179
+
180
+ def convert_to_size_dict(
181
+ size, max_size: Optional[int] = None, default_to_square: bool = True, height_width_order: bool = True
182
+ ):
183
+ # By default, if size is an int we assume it represents a tuple of (size, size).
184
+ if isinstance(size, int) and default_to_square:
185
+ if max_size is not None:
186
+ raise ValueError("Cannot specify both size as an int, with default_to_square=True and max_size")
187
+ return {"height": size, "width": size}
188
+ # In other configs, if size is an int and default_to_square is False, size represents the length of
189
+ # the shortest edge after resizing.
190
+ elif isinstance(size, int) and not default_to_square:
191
+ size_dict = {"shortest_edge": size}
192
+ if max_size is not None:
193
+ size_dict["longest_edge"] = max_size
194
+ return size_dict
195
+ # Otherwise, if size is a tuple it's either (height, width) or (width, height)
196
+ elif isinstance(size, (tuple, list)) and height_width_order:
197
+ return {"height": size[0], "width": size[1]}
198
+ elif isinstance(size, (tuple, list)) and not height_width_order:
199
+ return {"height": size[1], "width": size[0]}
200
+ elif size is None and max_size is not None:
201
+ if default_to_square:
202
+ raise ValueError("Cannot specify both default_to_square=True and max_size")
203
+ return {"longest_edge": max_size}
204
+
205
+ raise ValueError(f"Could not convert size input to size dict: {size}")
206
+
207
+
208
+ def get_size_dict(
209
+ size: Union[int, Iterable[int], Dict[str, int]] = None,
210
+ max_size: Optional[int] = None,
211
+ height_width_order: bool = True,
212
+ default_to_square: bool = True,
213
+ param_name="size",
214
+ ) -> dict:
215
+ """
216
+ Converts the old size parameter in the config into the new dict expected in the config. This is to ensure backwards
217
+ compatibility with the old image processor configs and removes ambiguity over whether the tuple is in (height,
218
+ width) or (width, height) format.
219
+
220
+    - If `size` is a tuple, it is converted to `{"height": size[0], "width": size[1]}` or `{"height": size[1], "width":
221
+ size[0]}` if `height_width_order` is `False`.
222
+ - If `size` is an int, and `default_to_square` is `True`, it is converted to `{"height": size, "width": size}`.
223
+ - If `size` is an int and `default_to_square` is False, it is converted to `{"shortest_edge": size}`. If `max_size`
224
+ is set, it is added to the dict as `{"longest_edge": max_size}`.
225
+
226
+ Args:
227
+ size (`Union[int, Iterable[int], Dict[str, int]]`, *optional*):
228
+ The `size` parameter to be cast into a size dictionary.
229
+ max_size (`Optional[int]`, *optional*):
230
+ The `max_size` parameter to be cast into a size dictionary.
231
+ height_width_order (`bool`, *optional*, defaults to `True`):
232
+ If `size` is a tuple, whether it's in (height, width) or (width, height) order.
233
+ default_to_square (`bool`, *optional*, defaults to `True`):
234
+ If `size` is an int, whether to default to a square image or not.
235
+ """
236
+ if not isinstance(size, dict):
237
+ size_dict = convert_to_size_dict(size, max_size, default_to_square, height_width_order)
238
+ logger.info(
239
+            f"{param_name} should be a dictionary with one of the following sets of keys: {VALID_SIZE_DICT_KEYS}, got {size}."
240
+ f" Converted to {size_dict}.",
241
+ )
242
+ else:
243
+ size_dict = size
244
+
245
+ if not is_valid_size_dict(size_dict):
246
+ raise ValueError(
247
+            f"{param_name} must have one of the following sets of keys: {VALID_SIZE_DICT_KEYS}, got {size_dict.keys()}"
248
+ )
249
+ return size_dict
250
+
251
+
252
+ def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple:
253
+ """
254
+ Selects the best resolution from a list of possible resolutions based on the original size.
255
+
256
+ This is done by calculating the effective and wasted resolution for each possible resolution.
257
+
258
+ The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution.
259
+
260
+ Args:
261
+ original_size (tuple):
262
+ The original size of the image in the format (height, width).
263
+ possible_resolutions (list):
264
+ A list of possible resolutions in the format [(height1, width1), (height2, width2), ...].
265
+
266
+ Returns:
267
+ tuple: The best fit resolution in the format (height, width).
268
+ """
269
+ original_height, original_width = original_size
270
+ best_fit = None
271
+ max_effective_resolution = 0
272
+ min_wasted_resolution = float("inf")
273
+
274
+ for height, width in possible_resolutions:
275
+ scale = min(width / original_width, height / original_height)
276
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
277
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
278
+ wasted_resolution = (width * height) - effective_resolution
279
+
280
+ if effective_resolution > max_effective_resolution or (
281
+ effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution
282
+ ):
283
+ max_effective_resolution = effective_resolution
284
+ min_wasted_resolution = wasted_resolution
285
+ best_fit = (height, width)
286
+
287
+ return best_fit
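
A few concrete cases for the two module-level helpers above, using the module path this file ships under (`transformers.image_processing_utils`); the expected values follow directly from the branches in `convert_to_size_dict` and `select_best_resolution`:

```python
from transformers.image_processing_utils import get_size_dict, select_best_resolution

# An int with default_to_square=True becomes a square {"height", "width"} dict...
assert get_size_dict(224) == {"height": 224, "width": 224}

# ...while default_to_square=False treats the int as the shortest edge, optionally capped by max_size.
assert get_size_dict(224, max_size=448, default_to_square=False) == {
    "shortest_edge": 224,
    "longest_edge": 448,
}

# Tuples are disambiguated explicitly as (height, width) or (width, height).
assert get_size_dict((480, 640), height_width_order=True) == {"height": 480, "width": 640}

# select_best_resolution maximizes effective resolution first, then minimizes wasted area:
# only (336, 672) preserves the full 300x500 image after scaling, so it wins.
assert select_best_resolution((300, 500), [(336, 336), (336, 672), (672, 336)]) == (336, 672)
```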
image_processing_utils_fast.py ADDED
@@ -0,0 +1,133 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import functools
17
+ from dataclasses import dataclass
18
+ from typing import Any, Iterable, List, Optional, Tuple
19
+
20
+ from .image_processing_utils import BaseImageProcessor
21
+ from .utils.import_utils import is_torch_available, is_torchvision_available
22
+
23
+
24
+ if is_torchvision_available():
25
+ from torchvision.transforms import Compose
26
+
27
+ if is_torch_available():
28
+ import torch
29
+
30
+
31
+ @dataclass(frozen=True)
32
+ class SizeDict:
33
+ """
34
+ Hashable dictionary to store image size information.
35
+ """
36
+
37
+ height: int = None
38
+ width: int = None
39
+ longest_edge: int = None
40
+ shortest_edge: int = None
41
+ max_height: int = None
42
+ max_width: int = None
43
+
44
+ def __getitem__(self, key):
45
+ if hasattr(self, key):
46
+ return getattr(self, key)
47
+ raise KeyError(f"Key {key} not found in SizeDict.")
48
+
49
+
50
+ class BaseImageProcessorFast(BaseImageProcessor):
51
+ _transform_params = None
52
+
53
+ def _build_transforms(self, **kwargs) -> "Compose":
54
+ """
55
+ Given the input settings e.g. do_resize, build the image transforms.
56
+ """
57
+ raise NotImplementedError
58
+
59
+ def _validate_params(self, **kwargs) -> None:
60
+ for k, v in kwargs.items():
61
+ if k not in self._transform_params:
62
+ raise ValueError(f"Invalid transform parameter {k}={v}.")
63
+
64
+ @functools.lru_cache(maxsize=1)
65
+ def get_transforms(self, **kwargs) -> "Compose":
66
+ self._validate_params(**kwargs)
67
+ return self._build_transforms(**kwargs)
68
+
69
+ def to_dict(self):
70
+ encoder_dict = super().to_dict()
71
+ encoder_dict.pop("_transform_params", None)
72
+ return encoder_dict
73
+
74
+
75
+ def get_image_size_for_max_height_width(
76
+ image_size: Tuple[int, int],
77
+ max_height: int,
78
+ max_width: int,
79
+ ) -> Tuple[int, int]:
80
+ """
81
+    Computes the output image size given the input image size and the maximum allowed height and width, keeping the aspect ratio.
83
+    Important: even if image_height < max_height and image_width < max_width, the image will be resized
84
+    so that at least one of the edges is equal to max_height or max_width.
84
+
85
+ For example:
86
+ - input_size: (100, 200), max_height: 50, max_width: 50 -> output_size: (25, 50)
87
+ - input_size: (100, 200), max_height: 200, max_width: 500 -> output_size: (200, 400)
88
+
89
+ Args:
90
+ image_size (`Tuple[int, int]`):
91
+            The size (height, width) of the image to resize.
92
+ max_height (`int`):
93
+ The maximum allowed height.
94
+ max_width (`int`):
95
+ The maximum allowed width.
96
+ """
97
+ height, width = image_size
98
+ height_scale = max_height / height
99
+ width_scale = max_width / width
100
+ min_scale = min(height_scale, width_scale)
101
+ new_height = int(height * min_scale)
102
+ new_width = int(width * min_scale)
103
+ return new_height, new_width
104
+
105
+
106
+ def safe_squeeze(tensor: "torch.Tensor", axis: Optional[int] = None) -> "torch.Tensor":
107
+ """
108
+ Squeezes a tensor, but only if the axis specified has dim 1.
109
+ """
110
+ if axis is None:
111
+ return tensor.squeeze()
112
+
113
+ try:
114
+ return tensor.squeeze(axis=axis)
115
+ except ValueError:
116
+ return tensor
117
+
118
+
119
+ def max_across_indices(values: Iterable[Any]) -> List[Any]:
120
+ """
121
+ Return the maximum value across all indices of an iterable of values.
122
+ """
123
+ return [max(values_i) for values_i in zip(*values)]
124
+
125
+
126
+ def get_max_height_width(images: List["torch.Tensor"]) -> Tuple[int]:
127
+ """
128
+ Get the maximum height and width across all images in a batch.
129
+ """
130
+
131
+ _, max_height, max_width = max_across_indices([img.shape for img in images])
132
+
133
+ return (max_height, max_width)
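
The two examples in the `get_image_size_for_max_height_width` docstring can be checked directly, and `SizeDict` being a frozen (hashable) dataclass is what lets `get_transforms` be memoized with `lru_cache`. A small sketch, assuming the module path `transformers.image_processing_utils_fast`:

```python
from transformers.image_processing_utils_fast import SizeDict, get_image_size_for_max_height_width

# Aspect ratio is preserved; the smaller of the two scale factors is applied to both edges.
assert get_image_size_for_max_height_width((100, 200), max_height=50, max_width=50) == (25, 50)
assert get_image_size_for_max_height_width((100, 200), max_height=200, max_width=500) == (200, 400)

# SizeDict supports dict-style access while staying hashable.
size = SizeDict(height=224, width=224)
assert size["height"] == 224 and size["width"] == 224
```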
image_transforms.py ADDED
@@ -0,0 +1,860 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warnings
17
+ from math import ceil
18
+ from typing import Iterable, List, Optional, Tuple, Union
19
+
20
+ import numpy as np
21
+
22
+ from .image_utils import (
23
+ ChannelDimension,
24
+ ImageInput,
25
+ get_channel_dimension_axis,
26
+ get_image_size,
27
+ infer_channel_dimension_format,
28
+ )
29
+ from .utils import ExplicitEnum, TensorType, is_jax_tensor, is_tf_tensor, is_torch_tensor
30
+ from .utils.import_utils import (
31
+ is_flax_available,
32
+ is_tf_available,
33
+ is_torch_available,
34
+ is_torchvision_available,
35
+ is_torchvision_v2_available,
36
+ is_vision_available,
37
+ requires_backends,
38
+ )
39
+
40
+
41
+ if is_vision_available():
42
+ import PIL
43
+
44
+ from .image_utils import PILImageResampling
45
+
46
+ if is_torch_available():
47
+ import torch
48
+
49
+ if is_tf_available():
50
+ import tensorflow as tf
51
+
52
+ if is_flax_available():
53
+ import jax.numpy as jnp
54
+
55
+ if is_torchvision_v2_available():
56
+ from torchvision.transforms.v2 import functional as F
57
+ elif is_torchvision_available():
58
+ from torchvision.transforms import functional as F
59
+
60
+
61
+ def to_channel_dimension_format(
62
+ image: np.ndarray,
63
+ channel_dim: Union[ChannelDimension, str],
64
+ input_channel_dim: Optional[Union[ChannelDimension, str]] = None,
65
+ ) -> np.ndarray:
66
+ """
67
+ Converts `image` to the channel dimension format specified by `channel_dim`.
68
+
69
+ Args:
70
+ image (`numpy.ndarray`):
71
+ The image to have its channel dimension set.
72
+ channel_dim (`ChannelDimension`):
73
+ The channel dimension format to use.
74
+ input_channel_dim (`ChannelDimension`, *optional*):
75
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
76
+
77
+ Returns:
78
+ `np.ndarray`: The image with the channel dimension set to `channel_dim`.
79
+ """
80
+ if not isinstance(image, np.ndarray):
81
+ raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
82
+
83
+ if input_channel_dim is None:
84
+ input_channel_dim = infer_channel_dimension_format(image)
85
+
86
+ target_channel_dim = ChannelDimension(channel_dim)
87
+ if input_channel_dim == target_channel_dim:
88
+ return image
89
+
90
+ if target_channel_dim == ChannelDimension.FIRST:
91
+ image = image.transpose((2, 0, 1))
92
+ elif target_channel_dim == ChannelDimension.LAST:
93
+ image = image.transpose((1, 2, 0))
94
+ else:
95
+ raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))
96
+
97
+ return image
98
+
99
+
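As a quick illustration of the helper above, here is a minimal sketch (not part of the uploaded file) that round-trips a NumPy image between channels-last and channels-first layouts, assuming the module's names are in scope:

import numpy as np

img_hwc = np.zeros((32, 48, 3), dtype=np.uint8)  # (height, width, channels)
img_chw = to_channel_dimension_format(img_hwc, ChannelDimension.FIRST)
assert img_chw.shape == (3, 32, 48)
# Converting back infers the input layout automatically
assert to_channel_dimension_format(img_chw, ChannelDimension.LAST).shape == (32, 48, 3)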
100
+ def rescale(
101
+ image: np.ndarray,
102
+ scale: float,
103
+ data_format: Optional[ChannelDimension] = None,
104
+ dtype: np.dtype = np.float32,
105
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
106
+ ) -> np.ndarray:
107
+ """
108
+ Rescales `image` by `scale`.
109
+
110
+ Args:
111
+ image (`np.ndarray`):
112
+ The image to rescale.
113
+ scale (`float`):
114
+ The scale to use for rescaling the image.
115
+ data_format (`ChannelDimension`, *optional*):
116
+ The channel dimension format of the image. If not provided, it will be the same as the input image.
117
+ dtype (`np.dtype`, *optional*, defaults to `np.float32`):
118
+ The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
119
+ extractors.
120
+ input_data_format (`ChannelDimension`, *optional*):
121
+ The channel dimension format of the input image. If not provided, it will be inferred from the input image.
122
+
123
+ Returns:
124
+ `np.ndarray`: The rescaled image.
125
+ """
126
+ if not isinstance(image, np.ndarray):
127
+ raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
128
+
129
+ rescaled_image = image.astype(np.float64) * scale # Numpy type promotion has changed, so always upcast first
130
+ if data_format is not None:
131
+ rescaled_image = to_channel_dimension_format(rescaled_image, data_format, input_data_format)
132
+
133
+ rescaled_image = rescaled_image.astype(dtype) # Finally downcast to the desired dtype at the end
134
+
135
+ return rescaled_image
136
+
137
+
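A minimal usage sketch for `rescale` (illustration only, assuming the function above is importable): bringing a uint8 image into the [0, 1] float range.

import numpy as np

img = np.full((4, 6, 3), 255, dtype=np.uint8)
scaled = rescale(img, scale=1 / 255)
assert scaled.dtype == np.float32  # default output dtype
assert scaled.max() == 1.0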
138
+ def _rescale_for_pil_conversion(image):
139
+ """
140
+ Detects whether or not the image needs to be rescaled before being converted to a PIL image.
141
+
142
+ The assumption is that if the image is of a floating dtype and all values are between 0 and 1, it needs to be
143
+ rescaled.
144
+ """
145
+ if image.dtype == np.uint8:
146
+ do_rescale = False
147
+ elif np.allclose(image, image.astype(int)):
148
+ if np.all(0 <= image) and np.all(image <= 255):
149
+ do_rescale = False
150
+ else:
151
+ raise ValueError(
152
+ "The image to be converted to a PIL image contains values outside the range [0, 255], "
153
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
154
+ )
155
+ elif np.all(0 <= image) and np.all(image <= 1):
156
+ do_rescale = True
157
+ else:
158
+ raise ValueError(
159
+ "The image to be converted to a PIL image contains values outside the range [0, 1], "
160
+ f"got [{image.min()}, {image.max()}] which cannot be converted to uint8."
161
+ )
162
+ return do_rescale
163
+
164
+
165
+ def to_pil_image(
166
+ image: Union[np.ndarray, "PIL.Image.Image", "torch.Tensor", "tf.Tensor", "jnp.ndarray"],
167
+ do_rescale: Optional[bool] = None,
168
+ image_mode: Optional[str] = None,
169
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
170
+ ) -> "PIL.Image.Image":
171
+ """
172
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
173
+ needed.
174
+
175
+ Args:
176
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
177
+ The image to convert to the `PIL.Image` format.
178
+ do_rescale (`bool`, *optional*):
179
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
180
+ to `True` if the image type is a floating type and casting to `int` would result in a loss of precision,
181
+ and `False` otherwise.
182
+ image_mode (`str`, *optional*):
183
+ The mode to use for the PIL image. If unset, will use the default mode for the input image type.
184
+ input_data_format (`ChannelDimension`, *optional*):
185
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
186
+
187
+ Returns:
188
+ `PIL.Image.Image`: The converted image.
189
+ """
190
+ requires_backends(to_pil_image, ["vision"])
191
+
192
+ if isinstance(image, PIL.Image.Image):
193
+ return image
194
+
195
+ # Convert all tensors to numpy arrays before converting to PIL image
196
+ if is_torch_tensor(image) or is_tf_tensor(image):
197
+ image = image.numpy()
198
+ elif is_jax_tensor(image):
199
+ image = np.array(image)
200
+ elif not isinstance(image, np.ndarray):
201
+ raise ValueError("Input image type not supported: {}".format(type(image)))
202
+
203
+ # If the channel has been moved to first dim, we put it back at the end.
204
+ image = to_channel_dimension_format(image, ChannelDimension.LAST, input_data_format)
205
+
206
+ # If there is a single channel, we squeeze it, as otherwise PIL can't handle it.
207
+ image = np.squeeze(image, axis=-1) if image.shape[-1] == 1 else image
208
+
209
+ # PIL.Image can only store uint8 values so we rescale the image to be between 0 and 255 if needed.
210
+ do_rescale = _rescale_for_pil_conversion(image) if do_rescale is None else do_rescale
211
+
212
+ if do_rescale:
213
+ image = rescale(image, 255)
214
+
215
+ image = image.astype(np.uint8)
216
+ return PIL.Image.fromarray(image, mode=image_mode)
217
+
218
+
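A hedged sketch of `to_pil_image` (requires Pillow; illustration only): a float image with values in [0, 1] is detected by `_rescale_for_pil_conversion` and scaled to [0, 255] before conversion.

import numpy as np

img = np.random.rand(64, 64, 3).astype(np.float32)
pil_img = to_pil_image(img)        # rescaled to uint8 automatically
print(pil_img.size, pil_img.mode)  # expected: (64, 64) RGB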
219
+ # Logic adapted from torchvision resizing logic: https://github.com/pytorch/vision/blob/511924c1ced4ce0461197e5caa64ce5b9e558aab/torchvision/transforms/functional.py#L366
220
+ def get_resize_output_image_size(
221
+ input_image: np.ndarray,
222
+ size: Union[int, Tuple[int, int], List[int], Tuple[int]],
223
+ default_to_square: bool = True,
224
+ max_size: Optional[int] = None,
225
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
226
+ ) -> tuple:
227
+ """
228
+ Find the target (height, width) dimension of the output image after resizing given the input image and the desired
229
+ size.
230
+
231
+ Args:
232
+ input_image (`np.ndarray`):
233
+ The image to resize.
234
+ size (`int` or `Tuple[int, int]` or List[int] or `Tuple[int]`):
235
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
236
+ this.
237
+
238
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
239
+ `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
240
+ number. i.e., if height > width, then the image will be rescaled to (size * height / width, size).
241
+ default_to_square (`bool`, *optional*, defaults to `True`):
242
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
243
+ (`size`,`size`). If set to `False`, will replicate
244
+ [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
245
+ with support for resizing only the smallest edge and providing an optional `max_size`.
246
+ max_size (`int`, *optional*):
247
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
248
+ than `max_size` after being resized according to `size`, then the image is resized again so that the longer
249
+ edge is equal to `max_size`. As a result, `size` might be overruled, i.e. the smaller edge may be shorter
250
+ than `size`. Only used if `default_to_square` is `False`.
251
+ input_data_format (`ChannelDimension`, *optional*):
252
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
253
+
254
+ Returns:
255
+ `tuple`: The target (height, width) dimension of the output image after resizing.
256
+ """
257
+ if isinstance(size, (tuple, list)):
258
+ if len(size) == 2:
259
+ return tuple(size)
260
+ elif len(size) == 1:
261
+ # Perform same logic as if size was an int
262
+ size = size[0]
263
+ else:
264
+ raise ValueError("size must have 1 or 2 elements if it is a list or tuple")
265
+
266
+ if default_to_square:
267
+ return (size, size)
268
+
269
+ height, width = get_image_size(input_image, input_data_format)
270
+ short, long = (width, height) if width <= height else (height, width)
271
+ requested_new_short = size
272
+
273
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
274
+
275
+ if max_size is not None:
276
+ if max_size <= requested_new_short:
277
+ raise ValueError(
278
+ f"max_size = {max_size} must be strictly greater than the requested "
279
+ f"size for the smaller edge size = {size}"
280
+ )
281
+ if new_long > max_size:
282
+ new_short, new_long = int(max_size * new_short / new_long), max_size
283
+
284
+ return (new_long, new_short) if width <= height else (new_short, new_long)
285
+
286
+
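A small worked example of the shortest-edge logic above (illustration only, assuming the helper is in scope): for a 480x640 channels-last image and `size=256` with `default_to_square=False`, the shorter edge (height) is matched to 256 and the width is scaled to preserve the aspect ratio.

import numpy as np

img = np.zeros((480, 640, 3), dtype=np.uint8)
print(get_resize_output_image_size(img, size=256, default_to_square=False))  # (256, 341)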
287
+ def resize(
288
+ image: np.ndarray,
289
+ size: Tuple[int, int],
290
+ resample: "PILImageResampling" = None,
291
+ reducing_gap: Optional[int] = None,
292
+ data_format: Optional[ChannelDimension] = None,
293
+ return_numpy: bool = True,
294
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
295
+ ) -> np.ndarray:
296
+ """
297
+ Resizes `image` to `(height, width)` specified by `size` using the PIL library.
298
+
299
+ Args:
300
+ image (`np.ndarray`):
301
+ The image to resize.
302
+ size (`Tuple[int, int]`):
303
+ The size to use for resizing the image.
304
+ resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
305
+ The filter to use for resampling.
306
+ reducing_gap (`int`, *optional*):
307
+ Apply optimization by resizing the image in two steps. The bigger `reducing_gap`, the closer the result to
308
+ the fair resampling. See corresponding Pillow documentation for more details.
309
+ data_format (`ChannelDimension`, *optional*):
310
+ The channel dimension format of the output image. If unset, will use the inferred format from the input.
311
+ return_numpy (`bool`, *optional*, defaults to `True`):
312
+ Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
313
+ returned.
314
+ input_data_format (`ChannelDimension`, *optional*):
315
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
316
+
317
+ Returns:
318
+ `np.ndarray`: The resized image.
319
+ """
320
+ requires_backends(resize, ["vision"])
321
+
322
+ resample = resample if resample is not None else PILImageResampling.BILINEAR
323
+
324
+ if not len(size) == 2:
325
+ raise ValueError("size must have 2 elements")
326
+
327
+ # For all transformations, we want to keep the same data format as the input image unless otherwise specified.
328
+ # The resized image from PIL will always have channels last, so find the input format first.
329
+ if input_data_format is None:
330
+ input_data_format = infer_channel_dimension_format(image)
331
+ data_format = input_data_format if data_format is None else data_format
332
+
333
+ # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
334
+ # the pillow library to resize the image and then convert back to numpy
335
+ do_rescale = False
336
+ if not isinstance(image, PIL.Image.Image):
337
+ do_rescale = _rescale_for_pil_conversion(image)
338
+ image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
339
+ height, width = size
340
+ # PIL images are in the format (width, height)
341
+ resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
342
+
343
+ if return_numpy:
344
+ resized_image = np.array(resized_image)
345
+ # If the input image channel dimension was of size 1, then it is dropped when converting to a PIL image
346
+ # so we need to add it back if necessary.
347
+ resized_image = np.expand_dims(resized_image, axis=-1) if resized_image.ndim == 2 else resized_image
348
+ # The image is always in channels last format after converting from a PIL image
349
+ resized_image = to_channel_dimension_format(
350
+ resized_image, data_format, input_channel_dim=ChannelDimension.LAST
351
+ )
352
+ # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
353
+ # rescale it back to the original range.
354
+ resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
355
+ return resized_image
356
+
357
+
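A minimal sketch of `resize` (requires Pillow; illustration only), resizing a uint8 channels-last image to 224x224:

import numpy as np

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
resized = resize(img, size=(224, 224))
assert resized.shape == (224, 224, 3)  # output keeps the channels-last layout of the input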
358
+ def normalize(
359
+ image: np.ndarray,
360
+ mean: Union[float, Iterable[float]],
361
+ std: Union[float, Iterable[float]],
362
+ data_format: Optional[ChannelDimension] = None,
363
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
364
+ ) -> np.ndarray:
365
+ """
366
+ Normalizes `image` using the mean and standard deviation specified by `mean` and `std`.
367
+
368
+ image = (image - mean) / std
369
+
370
+ Args:
371
+ image (`np.ndarray`):
372
+ The image to normalize.
373
+ mean (`float` or `Iterable[float]`):
374
+ The mean to use for normalization.
375
+ std (`float` or `Iterable[float]`):
376
+ The standard deviation to use for normalization.
377
+ data_format (`ChannelDimension`, *optional*):
378
+ The channel dimension format of the output image. If unset, will use the inferred format from the input.
379
+ input_data_format (`ChannelDimension`, *optional*):
380
+ The channel dimension format of the input image. If unset, will use the inferred format from the input.
381
+ """
382
+ if not isinstance(image, np.ndarray):
383
+ raise ValueError("image must be a numpy array")
384
+
385
+ if input_data_format is None:
386
+ input_data_format = infer_channel_dimension_format(image)
387
+
388
+ channel_axis = get_channel_dimension_axis(image, input_data_format=input_data_format)
389
+ num_channels = image.shape[channel_axis]
390
+
391
+ # We cast to float32 to avoid errors that can occur when subtracting uint8 values.
392
+ # We preserve the original dtype if it is a float type to prevent upcasting float16.
393
+ if not np.issubdtype(image.dtype, np.floating):
394
+ image = image.astype(np.float32)
395
+
396
+ if isinstance(mean, Iterable):
397
+ if len(mean) != num_channels:
398
+ raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}")
399
+ else:
400
+ mean = [mean] * num_channels
401
+ mean = np.array(mean, dtype=image.dtype)
402
+
403
+ if isinstance(std, Iterable):
404
+ if len(std) != num_channels:
405
+ raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}")
406
+ else:
407
+ std = [std] * num_channels
408
+ std = np.array(std, dtype=image.dtype)
409
+
410
+ if input_data_format == ChannelDimension.LAST:
411
+ image = (image - mean) / std
412
+ else:
413
+ image = ((image.T - mean) / std).T
414
+
415
+ image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
416
+ return image
417
+
418
+
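A hedged sketch of `normalize` (illustration only; the mean/std values below are the commonly published ImageNet statistics, not something defined in this file):

import numpy as np

img = np.random.rand(3, 224, 224).astype(np.float32)  # channels-first float image
out = normalize(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
assert out.shape == (3, 224, 224)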
419
+ def center_crop(
420
+ image: np.ndarray,
421
+ size: Tuple[int, int],
422
+ data_format: Optional[Union[str, ChannelDimension]] = None,
423
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
424
+ return_numpy: Optional[bool] = None,
425
+ ) -> np.ndarray:
426
+ """
427
+ Crops the `image` to the specified `size` using a center crop. Note that if the image is too small to be cropped to
428
+ the size given, it will be padded (so the returned result will always be of size `size`).
429
+
430
+ Args:
431
+ image (`np.ndarray`):
432
+ The image to crop.
433
+ size (`Tuple[int, int]`):
434
+ The target size for the cropped image.
435
+ data_format (`str` or `ChannelDimension`, *optional*):
436
+ The channel dimension format for the output image. Can be one of:
437
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
438
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
439
+ If unset, will use the inferred format of the input image.
440
+ input_data_format (`str` or `ChannelDimension`, *optional*):
441
+ The channel dimension format for the input image. Can be one of:
442
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
443
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
444
+ If unset, will use the inferred format of the input image.
445
+ return_numpy (`bool`, *optional*):
446
+ Whether or not to return the cropped image as a numpy array. Used for backwards compatibility with the
447
+ previous ImageFeatureExtractionMixin method.
448
+ - Unset: will return the same type as the input image.
449
+ - `True`: will return a numpy array.
450
+ - `False`: will return a `PIL.Image.Image` object.
451
+ Returns:
452
+ `np.ndarray`: The cropped image.
453
+ """
454
+ requires_backends(center_crop, ["vision"])
455
+
456
+ if return_numpy is not None:
457
+ warnings.warn("return_numpy is deprecated and will be removed in v.4.33", FutureWarning)
458
+
459
+ return_numpy = True if return_numpy is None else return_numpy
460
+
461
+ if not isinstance(image, np.ndarray):
462
+ raise TypeError(f"Input image must be of type np.ndarray, got {type(image)}")
463
+
464
+ if not isinstance(size, Iterable) or len(size) != 2:
465
+ raise ValueError("size must have 2 elements representing the height and width of the output image")
466
+
467
+ if input_data_format is None:
468
+ input_data_format = infer_channel_dimension_format(image)
469
+ output_data_format = data_format if data_format is not None else input_data_format
470
+
471
+ # We perform the crop in (C, H, W) format and then convert to the output format
472
+ image = to_channel_dimension_format(image, ChannelDimension.FIRST, input_data_format)
473
+
474
+ orig_height, orig_width = get_image_size(image, ChannelDimension.FIRST)
475
+ crop_height, crop_width = size
476
+ crop_height, crop_width = int(crop_height), int(crop_width)
477
+
478
+ # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
479
+ top = (orig_height - crop_height) // 2
480
+ bottom = top + crop_height
481
+ # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
482
+ left = (orig_width - crop_width) // 2
483
+ right = left + crop_width
484
+
485
+ # Check if cropped area is within image boundaries
486
+ if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
487
+ image = image[..., top:bottom, left:right]
488
+ image = to_channel_dimension_format(image, output_data_format, ChannelDimension.FIRST)
489
+ return image
490
+
491
+ # Otherwise, we may need to pad if the image is too small. Oh joy...
492
+ new_height = max(crop_height, orig_height)
493
+ new_width = max(crop_width, orig_width)
494
+ new_shape = image.shape[:-2] + (new_height, new_width)
495
+ new_image = np.zeros_like(image, shape=new_shape)
496
+
497
+ # If the image is too small, pad it with zeros
498
+ top_pad = ceil((new_height - orig_height) / 2)
499
+ bottom_pad = top_pad + orig_height
500
+ left_pad = ceil((new_width - orig_width) / 2)
501
+ right_pad = left_pad + orig_width
502
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
503
+
504
+ top += top_pad
505
+ bottom += top_pad
506
+ left += left_pad
507
+ right += left_pad
508
+
509
+ new_image = new_image[..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)]
510
+ new_image = to_channel_dimension_format(new_image, output_data_format, ChannelDimension.FIRST)
511
+
512
+ if not return_numpy:
513
+ new_image = to_pil_image(new_image)
514
+
515
+ return new_image
516
+
517
+
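A minimal sketch of `center_crop` (the backend check means Pillow must be installed; illustration only):

import numpy as np

img = np.random.rand(3, 256, 320).astype(np.float32)
cropped = center_crop(img, size=(224, 224))
assert cropped.shape == (3, 224, 224)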
518
+ def _center_to_corners_format_torch(bboxes_center: "torch.Tensor") -> "torch.Tensor":
519
+ center_x, center_y, width, height = bboxes_center.unbind(-1)
520
+ bbox_corners = torch.stack(
521
+ # top left x, top left y, bottom right x, bottom right y
522
+ [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)],
523
+ dim=-1,
524
+ )
525
+ return bbox_corners
526
+
527
+
528
+ def _center_to_corners_format_numpy(bboxes_center: np.ndarray) -> np.ndarray:
529
+ center_x, center_y, width, height = bboxes_center.T
530
+ bboxes_corners = np.stack(
531
+ # top left x, top left y, bottom right x, bottom right y
532
+ [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
533
+ axis=-1,
534
+ )
535
+ return bboxes_corners
536
+
537
+
538
+ def _center_to_corners_format_tf(bboxes_center: "tf.Tensor") -> "tf.Tensor":
539
+ center_x, center_y, width, height = tf.unstack(bboxes_center, axis=-1)
540
+ bboxes_corners = tf.stack(
541
+ # top left x, top left y, bottom right x, bottom right y
542
+ [center_x - 0.5 * width, center_y - 0.5 * height, center_x + 0.5 * width, center_y + 0.5 * height],
543
+ axis=-1,
544
+ )
545
+ return bboxes_corners
546
+
547
+
548
+ # 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
549
+ def center_to_corners_format(bboxes_center: TensorType) -> TensorType:
550
+ """
551
+ Converts bounding boxes from center format to corners format.
552
+
553
+ center format: contains the coordinate for the center of the box and its width, height dimensions
554
+ (center_x, center_y, width, height)
555
+ corners format: contains the coordinates for the top-left and bottom-right corners of the box
556
+ (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
557
+ """
558
+ # Function is used during model forward pass, so we use the input framework if possible, without
559
+ # converting to numpy
560
+ if is_torch_tensor(bboxes_center):
561
+ return _center_to_corners_format_torch(bboxes_center)
562
+ elif isinstance(bboxes_center, np.ndarray):
563
+ return _center_to_corners_format_numpy(bboxes_center)
564
+ elif is_tf_tensor(bboxes_center):
565
+ return _center_to_corners_format_tf(bboxes_center)
566
+
567
+ raise ValueError(f"Unsupported input type {type(bboxes_center)}")
568
+
569
+
570
+ def _corners_to_center_format_torch(bboxes_corners: "torch.Tensor") -> "torch.Tensor":
571
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.unbind(-1)
572
+ b = [
573
+ (top_left_x + bottom_right_x) / 2, # center x
574
+ (top_left_y + bottom_right_y) / 2, # center y
575
+ (bottom_right_x - top_left_x), # width
576
+ (bottom_right_y - top_left_y), # height
577
+ ]
578
+ return torch.stack(b, dim=-1)
579
+
580
+
581
+ def _corners_to_center_format_numpy(bboxes_corners: np.ndarray) -> np.ndarray:
582
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = bboxes_corners.T
583
+ bboxes_center = np.stack(
584
+ [
585
+ (top_left_x + bottom_right_x) / 2, # center x
586
+ (top_left_y + bottom_right_y) / 2, # center y
587
+ (bottom_right_x - top_left_x), # width
588
+ (bottom_right_y - top_left_y), # height
589
+ ],
590
+ axis=-1,
591
+ )
592
+ return bboxes_center
593
+
594
+
595
+ def _corners_to_center_format_tf(bboxes_corners: "tf.Tensor") -> "tf.Tensor":
596
+ top_left_x, top_left_y, bottom_right_x, bottom_right_y = tf.unstack(bboxes_corners, axis=-1)
597
+ bboxes_center = tf.stack(
598
+ [
599
+ (top_left_x + bottom_right_x) / 2, # center x
600
+ (top_left_y + bottom_right_y) / 2, # center y
601
+ (bottom_right_x - top_left_x), # width
602
+ (bottom_right_y - top_left_y), # height
603
+ ],
604
+ axis=-1,
605
+ )
606
+ return bboxes_center
607
+
608
+
609
+ def corners_to_center_format(bboxes_corners: TensorType) -> TensorType:
610
+ """
611
+ Converts bounding boxes from corners format to center format.
612
+
613
+ corners format: contains the coordinates for the top-left and bottom-right corners of the box
614
+ (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
615
+ center format: contains the coordinate for the center of the box and its width, height dimensions
616
+ (center_x, center_y, width, height)
617
+ """
618
+ # Inverse function accepts different input types so implemented here too
619
+ if is_torch_tensor(bboxes_corners):
620
+ return _corners_to_center_format_torch(bboxes_corners)
621
+ elif isinstance(bboxes_corners, np.ndarray):
622
+ return _corners_to_center_format_numpy(bboxes_corners)
623
+ elif is_tf_tensor(bboxes_corners):
624
+ return _corners_to_center_format_tf(bboxes_corners)
625
+
626
+ raise ValueError(f"Unsupported input type {type(bboxes_corners)}")
627
+
628
+
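A quick round-trip sketch for the two box-format converters above (illustration only), using a single normalized box:

import numpy as np

boxes_center = np.array([[0.5, 0.5, 0.2, 0.4]])   # (center_x, center_y, width, height)
corners = center_to_corners_format(boxes_center)  # approx. [[0.4, 0.3, 0.6, 0.7]]
assert np.allclose(corners_to_center_format(corners), boxes_center)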
629
+ # 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
630
+ # Copyright (c) 2018, Alexander Kirillov
631
+ # All rights reserved.
632
+ def rgb_to_id(color):
633
+ """
634
+ Converts RGB color to unique ID.
635
+ """
636
+ if isinstance(color, np.ndarray) and len(color.shape) == 3:
637
+ if color.dtype == np.uint8:
638
+ color = color.astype(np.int32)
639
+ return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
640
+ return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
641
+
642
+
643
+ def id_to_rgb(id_map):
644
+ """
645
+ Converts unique ID to RGB color.
646
+ """
647
+ if isinstance(id_map, np.ndarray):
648
+ id_map_copy = id_map.copy()
649
+ rgb_shape = tuple(list(id_map.shape) + [3])
650
+ rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
651
+ for i in range(3):
652
+ rgb_map[..., i] = id_map_copy % 256
653
+ id_map_copy //= 256
654
+ return rgb_map
655
+ color = []
656
+ for _ in range(3):
657
+ color.append(id_map % 256)
658
+ id_map //= 256
659
+ return color
660
+
661
+
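A tiny sketch of the panoptic id encoding above (illustration only): ids are R + 256 * G + 256 * 256 * B.

import numpy as np

assert rgb_to_id(np.array([10, 1, 0])) == 266  # 10 + 256 * 1
assert id_to_rgb(266) == [10, 1, 0]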
662
+ class PaddingMode(ExplicitEnum):
663
+ """
664
+ Enum class for the different padding modes to use when padding images.
665
+ """
666
+
667
+ CONSTANT = "constant"
668
+ REFLECT = "reflect"
669
+ REPLICATE = "replicate"
670
+ SYMMETRIC = "symmetric"
671
+
672
+
673
+ def pad(
674
+ image: np.ndarray,
675
+ padding: Union[int, Tuple[int, int], Iterable[Tuple[int, int]]],
676
+ mode: PaddingMode = PaddingMode.CONSTANT,
677
+ constant_values: Union[float, Iterable[float]] = 0.0,
678
+ data_format: Optional[Union[str, ChannelDimension]] = None,
679
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
680
+ ) -> np.ndarray:
681
+ """
682
+ Pads the `image` with the specified (height, width) `padding` and `mode`.
683
+
684
+ Args:
685
+ image (`np.ndarray`):
686
+ The image to pad.
687
+ padding (`int` or `Tuple[int, int]` or `Iterable[Tuple[int, int]]`):
688
+ Padding to apply to the edges of the height, width axes. Can be one of three formats:
689
+ - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
690
+ - `((before, after),)` yields same before and after pad for height and width.
691
+ - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
692
+ mode (`PaddingMode`):
693
+ The padding mode to use. Can be one of:
694
+ - `"constant"`: pads with a constant value.
695
+ - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
696
+ vector along each axis.
697
+ - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
698
+ - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
699
+ constant_values (`float` or `Iterable[float]`, *optional*):
700
+ The value to use for the padding if `mode` is `"constant"`.
701
+ data_format (`str` or `ChannelDimension`, *optional*):
702
+ The channel dimension format for the output image. Can be one of:
703
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
704
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
705
+ If unset, will use same as the input image.
706
+ input_data_format (`str` or `ChannelDimension`, *optional*):
707
+ The channel dimension format for the input image. Can be one of:
708
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
709
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
710
+ If unset, will use the inferred format of the input image.
711
+
712
+ Returns:
713
+ `np.ndarray`: The padded image.
714
+
715
+ """
716
+ if input_data_format is None:
717
+ input_data_format = infer_channel_dimension_format(image)
718
+
719
+ def _expand_for_data_format(values):
720
+ """
721
+ Convert values to be in the format expected by np.pad based on the data format.
722
+ """
723
+ if isinstance(values, (int, float)):
724
+ values = ((values, values), (values, values))
725
+ elif isinstance(values, tuple) and len(values) == 1:
726
+ values = ((values[0], values[0]), (values[0], values[0]))
727
+ elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
728
+ values = (values, values)
729
+ elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
730
+ values = values
731
+ else:
732
+ raise ValueError(f"Unsupported format: {values}")
733
+
734
+ # add 0 for channel dimension
735
+ values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))
736
+
737
+ # Add additional padding if there's a batch dimension
738
+ values = (0, *values) if image.ndim == 4 else values
739
+ return values
740
+
741
+ padding = _expand_for_data_format(padding)
742
+
743
+ if mode == PaddingMode.CONSTANT:
744
+ constant_values = _expand_for_data_format(constant_values)
745
+ image = np.pad(image, padding, mode="constant", constant_values=constant_values)
746
+ elif mode == PaddingMode.REFLECT:
747
+ image = np.pad(image, padding, mode="reflect")
748
+ elif mode == PaddingMode.REPLICATE:
749
+ image = np.pad(image, padding, mode="edge")
750
+ elif mode == PaddingMode.SYMMETRIC:
751
+ image = np.pad(image, padding, mode="symmetric")
752
+ else:
753
+ raise ValueError(f"Invalid padding mode: {mode}")
754
+
755
+ image = to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
756
+ return image
757
+
758
+
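A minimal sketch of `pad` (illustration only): zero-padding the height and width of a channels-last image by 2 pixels on every side.

import numpy as np

img = np.ones((4, 6, 3), dtype=np.float32)
padded = pad(img, padding=2, mode=PaddingMode.CONSTANT, constant_values=0.0)
assert padded.shape == (8, 10, 3)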
759
+ # TODO (Amy): Accept 1/3/4 channel numpy array as input and return np.array as default
760
+ def convert_to_rgb(image: ImageInput) -> ImageInput:
761
+ """
762
+ Converts an image to RGB format. Only converts if the image is of type PIL.Image.Image, otherwise returns the image
763
+ as is.
764
+ Args:
765
+ image (Image):
766
+ The image to convert.
767
+ """
768
+ requires_backends(convert_to_rgb, ["vision"])
769
+
770
+ if not isinstance(image, PIL.Image.Image):
771
+ return image
772
+
773
+ if image.mode == "RGB":
774
+ return image
775
+
776
+ image = image.convert("RGB")
777
+ return image
778
+
779
+
780
+ def flip_channel_order(
781
+ image: np.ndarray,
782
+ data_format: Optional[ChannelDimension] = None,
783
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
784
+ ) -> np.ndarray:
785
+ """
786
+ Flips the channel order of the image.
787
+
788
+ If the image is in RGB format, it will be converted to BGR and vice versa.
789
+
790
+ Args:
791
+ image (`np.ndarray`):
792
+ The image to flip.
793
+ data_format (`ChannelDimension`, *optional*):
794
+ The channel dimension format for the output image. Can be one of:
795
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
796
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
797
+ If unset, will use same as the input image.
798
+ input_data_format (`ChannelDimension`, *optional*):
799
+ The channel dimension format for the input image. Can be one of:
800
+ - `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
801
+ - `ChannelDimension.LAST`: image in (height, width, num_channels) format.
802
+ If unset, will use the inferred format of the input image.
803
+ """
804
+ input_data_format = infer_channel_dimension_format(image) if input_data_format is None else input_data_format
805
+
806
+ if input_data_format == ChannelDimension.LAST:
807
+ image = image[..., ::-1]
808
+ elif input_data_format == ChannelDimension.FIRST:
809
+ image = image[::-1, ...]
810
+ else:
811
+ raise ValueError(f"Unsupported channel dimension: {input_data_format}")
812
+
813
+ if data_format is not None:
814
+ image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
815
+ return image
816
+
817
+
818
+ def _cast_tensor_to_float(x):
819
+ if x.is_floating_point():
820
+ return x
821
+ return x.float()
822
+
823
+
824
+ class FusedRescaleNormalize:
825
+ """
826
+ Rescale and normalize the input image in one step.
827
+ """
828
+
829
+ def __init__(self, mean, std, rescale_factor: float = 1.0, inplace: bool = False):
830
+ self.mean = torch.tensor(mean) * (1.0 / rescale_factor)
831
+ self.std = torch.tensor(std) * (1.0 / rescale_factor)
832
+ self.inplace = inplace
833
+
834
+ def __call__(self, image: "torch.Tensor"):
835
+ image = _cast_tensor_to_float(image)
836
+ return F.normalize(image, self.mean, self.std, inplace=self.inplace)
837
+
838
+
839
+ class Rescale:
840
+ """
841
+ Rescale the input image by rescale factor: image *= rescale_factor.
842
+ """
843
+
844
+ def __init__(self, rescale_factor: float = 1.0):
845
+ self.rescale_factor = rescale_factor
846
+
847
+ def __call__(self, image: "torch.Tensor"):
848
+ image = image * self.rescale_factor
849
+ return image
850
+
851
+
852
+ class NumpyToTensor:
853
+ """
854
+ Convert a numpy array to a PyTorch tensor.
855
+ """
856
+
857
+ def __call__(self, image: np.ndarray):
858
+ # Same as in PyTorch, we assume incoming numpy images are in HWC format
859
+ # c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
860
+ return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
image_utils.py ADDED
@@ -0,0 +1,871 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import base64
17
+ import os
18
+ from io import BytesIO
19
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import requests
23
+ from packaging import version
24
+
25
+ from .utils import (
26
+ ExplicitEnum,
27
+ TensorType,
28
+ is_jax_tensor,
29
+ is_numpy_array,
30
+ is_tf_tensor,
31
+ is_torch_available,
32
+ is_torch_tensor,
33
+ is_torchvision_available,
34
+ is_vision_available,
35
+ logging,
36
+ requires_backends,
37
+ to_numpy,
38
+ )
39
+ from .utils.constants import ( # noqa: F401
40
+ IMAGENET_DEFAULT_MEAN,
41
+ IMAGENET_DEFAULT_STD,
42
+ IMAGENET_STANDARD_MEAN,
43
+ IMAGENET_STANDARD_STD,
44
+ OPENAI_CLIP_MEAN,
45
+ OPENAI_CLIP_STD,
46
+ )
47
+
48
+
49
+ if is_vision_available():
50
+ import PIL.Image
51
+ import PIL.ImageOps
52
+
53
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
54
+ PILImageResampling = PIL.Image.Resampling
55
+ else:
56
+ PILImageResampling = PIL.Image
57
+
58
+ if is_torchvision_available():
59
+ from torchvision.transforms import InterpolationMode
60
+
61
+ pil_torch_interpolation_mapping = {
62
+ PILImageResampling.NEAREST: InterpolationMode.NEAREST,
63
+ PILImageResampling.BOX: InterpolationMode.BOX,
64
+ PILImageResampling.BILINEAR: InterpolationMode.BILINEAR,
65
+ PILImageResampling.HAMMING: InterpolationMode.HAMMING,
66
+ PILImageResampling.BICUBIC: InterpolationMode.BICUBIC,
67
+ PILImageResampling.LANCZOS: InterpolationMode.LANCZOS,
68
+ }
69
+
70
+
71
+ if TYPE_CHECKING:
72
+ if is_torch_available():
73
+ import torch
74
+
75
+
76
+ logger = logging.get_logger(__name__)
77
+
78
+
79
+ ImageInput = Union[
80
+ "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"]
81
+ ] # noqa
82
+
83
+
84
+ VideoInput = Union[
85
+ List["PIL.Image.Image"],
86
+ "np.ndarray",
87
+ "torch.Tensor",
88
+ List["np.ndarray"],
89
+ List["torch.Tensor"],
90
+ List[List["PIL.Image.Image"]],
91
+ List[List["np.ndarray"]],
92
+ List[List["torch.Tensor"]],
93
+ ] # noqa
94
+
95
+
96
+ class ChannelDimension(ExplicitEnum):
97
+ FIRST = "channels_first"
98
+ LAST = "channels_last"
99
+
100
+
101
+ class AnnotationFormat(ExplicitEnum):
102
+ COCO_DETECTION = "coco_detection"
103
+ COCO_PANOPTIC = "coco_panoptic"
104
+
105
+
106
+ class AnnotionFormat(ExplicitEnum):
107
+ COCO_DETECTION = AnnotationFormat.COCO_DETECTION.value
108
+ COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value
109
+
110
+
111
+ AnnotationType = Dict[str, Union[int, str, List[Dict]]]
112
+
113
+
114
+ def is_pil_image(img):
115
+ return is_vision_available() and isinstance(img, PIL.Image.Image)
116
+
117
+
118
+ class ImageType(ExplicitEnum):
119
+ PIL = "pillow"
120
+ TORCH = "torch"
121
+ NUMPY = "numpy"
122
+ TENSORFLOW = "tensorflow"
123
+ JAX = "jax"
124
+
125
+
126
+ def get_image_type(image):
127
+ if is_pil_image(image):
128
+ return ImageType.PIL
129
+ if is_torch_tensor(image):
130
+ return ImageType.TORCH
131
+ if is_numpy_array(image):
132
+ return ImageType.NUMPY
133
+ if is_tf_tensor(image):
134
+ return ImageType.TENSORFLOW
135
+ if is_jax_tensor(image):
136
+ return ImageType.JAX
137
+ raise ValueError(f"Unrecognised image type {type(image)}")
138
+
139
+
140
+ def is_valid_image(img):
141
+ return is_pil_image(img) or is_numpy_array(img) or is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
142
+
143
+
144
+ def valid_images(imgs):
145
+ # If we have a list of images, make sure every image is valid
146
+ if isinstance(imgs, (list, tuple)):
147
+ for img in imgs:
148
+ if not valid_images(img):
149
+ return False
150
+ # If not a list or tuple, we have been given a single image or a batched tensor of images
151
+ elif not is_valid_image(imgs):
152
+ return False
153
+ return True
154
+
155
+
156
+ def is_batched(img):
157
+ if isinstance(img, (list, tuple)):
158
+ return is_valid_image(img[0])
159
+ return False
160
+
161
+
162
+ def is_scaled_image(image: np.ndarray) -> bool:
163
+ """
164
+ Checks to see whether the pixel values have already been rescaled to [0, 1].
165
+ """
166
+ if image.dtype == np.uint8:
167
+ return False
168
+
169
+ # It's possible the image has pixel values in [0, 255] but is of floating type
170
+ return np.min(image) >= 0 and np.max(image) <= 1
171
+
172
+
173
+ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
174
+ """
175
+ Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
176
+ If the input is a batch of images, it is converted to a list of images.
177
+
178
+ Args:
179
+ images (`ImageInput`):
180
+ Image of images to turn into a list of images.
181
+ expected_ndims (`int`, *optional*, defaults to 3):
182
+ Expected number of dimensions for a single input image. If the input image has a different number of
183
+ dimensions, an error is raised.
184
+ """
185
+ if is_batched(images):
186
+ return images
187
+
188
+ # Either the input is a single image, in which case we create a list of length 1
189
+ if isinstance(images, PIL.Image.Image):
190
+ # PIL images are never batched
191
+ return [images]
192
+
193
+ if is_valid_image(images):
194
+ if images.ndim == expected_ndims + 1:
195
+ # Batch of images
196
+ images = list(images)
197
+ elif images.ndim == expected_ndims:
198
+ # Single image
199
+ images = [images]
200
+ else:
201
+ raise ValueError(
202
+ f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
203
+ f" {images.ndim} dimensions."
204
+ )
205
+ return images
206
+ raise ValueError(
207
+ "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
208
+ f"jax.ndarray, but got {type(images)}."
209
+ )
210
+
211
+
212
+ def to_numpy_array(img) -> np.ndarray:
213
+ if not is_valid_image(img):
214
+ raise ValueError(f"Invalid image type: {type(img)}")
215
+
216
+ if is_vision_available() and isinstance(img, PIL.Image.Image):
217
+ return np.array(img)
218
+ return to_numpy(img)
219
+
220
+
221
+ def infer_channel_dimension_format(
222
+ image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
223
+ ) -> ChannelDimension:
224
+ """
225
+ Infers the channel dimension format of `image`.
226
+
227
+ Args:
228
+ image (`np.ndarray`):
229
+ The image to infer the channel dimension of.
230
+ num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
231
+ The number of channels of the image.
232
+
233
+ Returns:
234
+ The channel dimension of the image.
235
+ """
236
+ num_channels = num_channels if num_channels is not None else (1, 3)
237
+ num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
238
+
239
+ if image.ndim == 3:
240
+ first_dim, last_dim = 0, 2
241
+ elif image.ndim == 4:
242
+ first_dim, last_dim = 1, 3
243
+ else:
244
+ raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
245
+
246
+ if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
247
+ logger.warning(
248
+ f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension."
249
+ )
250
+ return ChannelDimension.FIRST
251
+ elif image.shape[first_dim] in num_channels:
252
+ return ChannelDimension.FIRST
253
+ elif image.shape[last_dim] in num_channels:
254
+ return ChannelDimension.LAST
255
+ raise ValueError("Unable to infer channel dimension format")
256
+
257
+
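A short sketch of the inference rule above (illustration only): an axis of size 1 or 3 at the first or last position is treated as the channel axis.

import numpy as np

assert infer_channel_dimension_format(np.zeros((3, 224, 224))) == ChannelDimension.FIRST
assert infer_channel_dimension_format(np.zeros((224, 224, 3))) == ChannelDimension.LAST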
258
+ def get_channel_dimension_axis(
259
+ image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
260
+ ) -> int:
261
+ """
262
+ Returns the channel dimension axis of the image.
263
+
264
+ Args:
265
+ image (`np.ndarray`):
266
+ The image to get the channel dimension axis of.
267
+ input_data_format (`ChannelDimension` or `str`, *optional*):
268
+ The channel dimension format of the image. If `None`, will infer the channel dimension from the image.
269
+
270
+ Returns:
271
+ The channel dimension axis of the image.
272
+ """
273
+ if input_data_format is None:
274
+ input_data_format = infer_channel_dimension_format(image)
275
+ if input_data_format == ChannelDimension.FIRST:
276
+ return image.ndim - 3
277
+ elif input_data_format == ChannelDimension.LAST:
278
+ return image.ndim - 1
279
+ raise ValueError(f"Unsupported data format: {input_data_format}")
280
+
281
+
282
+ def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
283
+ """
284
+ Returns the (height, width) dimensions of the image.
285
+
286
+ Args:
287
+ image (`np.ndarray`):
288
+ The image to get the dimensions of.
289
+ channel_dim (`ChannelDimension`, *optional*):
290
+ Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.
291
+
292
+ Returns:
293
+ A tuple of the image's height and width.
294
+ """
295
+ if channel_dim is None:
296
+ channel_dim = infer_channel_dimension_format(image)
297
+
298
+ if channel_dim == ChannelDimension.FIRST:
299
+ return image.shape[-2], image.shape[-1]
300
+ elif channel_dim == ChannelDimension.LAST:
301
+ return image.shape[-3], image.shape[-2]
302
+ else:
303
+ raise ValueError(f"Unsupported data format: {channel_dim}")
304
+
305
+
306
+ def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
307
+ if (
308
+ isinstance(annotation, dict)
309
+ and "image_id" in annotation
310
+ and "annotations" in annotation
311
+ and isinstance(annotation["annotations"], (list, tuple))
312
+ and (
313
+ # an image can have no annotations
314
+ len(annotation["annotations"]) == 0 or isinstance(annotation["annotations"][0], dict)
315
+ )
316
+ ):
317
+ return True
318
+ return False
319
+
320
+
321
+ def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
322
+ if (
323
+ isinstance(annotation, dict)
324
+ and "image_id" in annotation
325
+ and "segments_info" in annotation
326
+ and "file_name" in annotation
327
+ and isinstance(annotation["segments_info"], (list, tuple))
328
+ and (
329
+ # an image can have no segments
330
+ len(annotation["segments_info"]) == 0 or isinstance(annotation["segments_info"][0], dict)
331
+ )
332
+ ):
333
+ return True
334
+ return False
335
+
336
+
337
+ def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
338
+ return all(is_valid_annotation_coco_detection(ann) for ann in annotations)
339
+
340
+
341
+ def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
342
+ return all(is_valid_annotation_coco_panoptic(ann) for ann in annotations)
343
+
344
+
345
+ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
346
+ """
347
+ Loads `image` to a PIL Image.
348
+
349
+ Args:
350
+ image (`str` or `PIL.Image.Image`):
351
+ The image to convert to the PIL Image format.
352
+ timeout (`float`, *optional*):
353
+ The timeout value in seconds for the URL request.
354
+
355
+ Returns:
356
+ `PIL.Image.Image`: A PIL Image.
357
+ """
358
+ requires_backends(load_image, ["vision"])
359
+ if isinstance(image, str):
360
+ if image.startswith("http://") or image.startswith("https://"):
361
+ # We need to actually check for a real protocol, otherwise it's impossible to use a local file
362
+ # like http_huggingface_co.png
363
+ image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
364
+ elif os.path.isfile(image):
365
+ image = PIL.Image.open(image)
366
+ else:
367
+ if image.startswith("data:image/"):
368
+ image = image.split(",")[1]
369
+
370
+ # Try to load as base64
371
+ try:
372
+ b64 = base64.decodebytes(image.encode())
373
+ image = PIL.Image.open(BytesIO(b64))
374
+ except Exception as e:
375
+ raise ValueError(
376
+ f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
377
+ )
378
+ elif isinstance(image, PIL.Image.Image):
379
+ image = image
380
+ else:
381
+ raise TypeError(
382
+ "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
383
+ )
384
+ image = PIL.ImageOps.exif_transpose(image)
385
+ image = image.convert("RGB")
386
+ return image
387
+
388
+
389
+ def load_images(
390
+ images: Union[List, Tuple, str, "PIL.Image.Image"], timeout: Optional[float] = None
391
+ ) -> Union["PIL.Image.Image", List["PIL.Image.Image"], List[List["PIL.Image.Image"]]]:
392
+ """Loads images, handling different levels of nesting.
393
+
394
+ Args:
395
+ images: A single image, a list of images, or a list of lists of images to load.
396
+ timeout: Timeout for loading images.
397
+
398
+ Returns:
399
+ A single image, a list of images, or a list of lists of images.
400
+ """
401
+ if isinstance(images, (list, tuple)):
402
+ if len(images) and isinstance(images[0], (list, tuple)):
403
+ return [[load_image(image, timeout=timeout) for image in image_group] for image_group in images]
404
+ else:
405
+ return [load_image(image, timeout=timeout) for image in images]
406
+ else:
407
+ return load_image(images, timeout=timeout)
408
+
409
+
410
+ def validate_preprocess_arguments(
411
+ do_rescale: Optional[bool] = None,
412
+ rescale_factor: Optional[float] = None,
413
+ do_normalize: Optional[bool] = None,
414
+ image_mean: Optional[Union[float, List[float]]] = None,
415
+ image_std: Optional[Union[float, List[float]]] = None,
416
+ do_pad: Optional[bool] = None,
417
+ size_divisibility: Optional[int] = None,
418
+ do_center_crop: Optional[bool] = None,
419
+ crop_size: Optional[Dict[str, int]] = None,
420
+ do_resize: Optional[bool] = None,
421
+ size: Optional[Dict[str, int]] = None,
422
+ resample: Optional["PILImageResampling"] = None,
423
+ ):
424
+ """
425
+ Checks validity of typically used arguments in an `ImageProcessor` `preprocess` method.
426
+ Raises `ValueError` if an incompatible combination of arguments is detected.
427
+ Many incompatibilities are model-specific. `do_pad` sometimes needs `size_divisor`,
428
+ sometimes `size_divisibility`, and sometimes `size`. Newly added models and processors should follow
429
+ the existing argument names when possible.
430
+
431
+ """
432
+ if do_rescale and rescale_factor is None:
433
+ raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")
434
+
435
+ if do_pad and size_divisibility is None:
436
+ # Here, size_divisor might be passed as the value of size
437
+ raise ValueError(
438
+ "Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
439
+ )
440
+
441
+ if do_normalize and (image_mean is None or image_std is None):
442
+ raise ValueError("`image_mean` and `image_std` must both be specified if `do_normalize` is `True`.")
443
+
444
+ if do_center_crop and crop_size is None:
445
+ raise ValueError("`crop_size` must be specified if `do_center_crop` is `True`.")
446
+
447
+ if do_resize and (size is None or resample is None):
448
+ raise ValueError("`size` and `resample` must be specified if `do_resize` is `True`.")
449
+
450
+
451
+ def validate_fast_preprocess_arguments(
452
+ do_rescale: Optional[bool] = None,
453
+ rescale_factor: Optional[float] = None,
454
+ do_normalize: Optional[bool] = None,
455
+ image_mean: Optional[Union[float, List[float]]] = None,
456
+ image_std: Optional[Union[float, List[float]]] = None,
457
+ do_pad: Optional[bool] = None,
458
+ size_divisibility: Optional[int] = None,
459
+ do_center_crop: Optional[bool] = None,
460
+ crop_size: Optional[Dict[str, int]] = None,
461
+ do_resize: Optional[bool] = None,
462
+ size: Optional[Dict[str, int]] = None,
463
+ resample: Optional["PILImageResampling"] = None,
464
+ return_tensors: Optional[Union[str, TensorType]] = None,
465
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
466
+ ):
467
+ """
468
+ Checks validity of typically used arguments in an `ImageProcessorFast` `preprocess` method.
469
+ Raises `ValueError` if an incompatible combination of arguments is detected.
470
+ """
471
+ validate_preprocess_arguments(
472
+ do_rescale=do_rescale,
473
+ rescale_factor=rescale_factor,
474
+ do_normalize=do_normalize,
475
+ image_mean=image_mean,
476
+ image_std=image_std,
477
+ do_resize=do_resize,
478
+ size=size,
479
+ resample=resample,
480
+ )
481
+ # Extra checks for ImageProcessorFast
482
+ if return_tensors != "pt":
483
+ raise ValueError("Only returning PyTorch tensors is currently supported.")
484
+
485
+ if data_format != ChannelDimension.FIRST:
486
+ raise ValueError("Only channel first data format is currently supported.")
487
+
488
+
489
+ # In the future we can add a TF implementation here when we have TF models.
490
+ class ImageFeatureExtractionMixin:
491
+ """
492
+ Mixin that contain utilities for preparing image features.
493
+ """
494
+
495
+ def _ensure_format_supported(self, image):
496
+ if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
497
+ raise ValueError(
498
+ f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
499
+ "`torch.Tensor` are."
500
+ )
501
+
502
+ def to_pil_image(self, image, rescale=None):
503
+ """
504
+ Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
505
+ needed.
506
+
507
+ Args:
508
+ image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
509
+ The image to convert to the PIL Image format.
510
+ rescale (`bool`, *optional*):
511
+ Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
512
+ default to `True` if the image type is a floating type, `False` otherwise.
513
+ """
514
+ self._ensure_format_supported(image)
515
+
516
+ if is_torch_tensor(image):
517
+ image = image.numpy()
518
+
519
+ if isinstance(image, np.ndarray):
520
+ if rescale is None:
521
+ # rescale default to the array being of floating type.
522
+ rescale = isinstance(image.flat[0], np.floating)
523
+ # If the channel has been moved to the first dim, we put it back at the end.
524
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
525
+ image = image.transpose(1, 2, 0)
526
+ if rescale:
527
+ image = image * 255
528
+ image = image.astype(np.uint8)
529
+ return PIL.Image.fromarray(image)
530
+ return image
531
+
532
+ def convert_rgb(self, image):
533
+ """
534
+ Converts `PIL.Image.Image` to RGB format.
535
+
536
+ Args:
537
+ image (`PIL.Image.Image`):
538
+ The image to convert.
539
+ """
540
+ self._ensure_format_supported(image)
541
+ if not isinstance(image, PIL.Image.Image):
542
+ return image
543
+
544
+ return image.convert("RGB")
545
+
546
+ def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
547
+ """
548
+ Rescale a numpy image by scale amount
549
+ """
550
+ self._ensure_format_supported(image)
551
+ return image * scale
552
+
553
+ def to_numpy_array(self, image, rescale=None, channel_first=True):
554
+ """
555
+ Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
556
+ dimension.
557
+
558
+ Args:
559
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
560
+ The image to convert to a NumPy array.
561
+ rescale (`bool`, *optional*):
562
+ Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
563
+ default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
564
+ channel_first (`bool`, *optional*, defaults to `True`):
565
+ Whether or not to permute the dimensions of the image to put the channel dimension first.
566
+ """
567
+ self._ensure_format_supported(image)
568
+
569
+ if isinstance(image, PIL.Image.Image):
570
+ image = np.array(image)
571
+
572
+ if is_torch_tensor(image):
573
+ image = image.numpy()
574
+
575
+ rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale
576
+
577
+ if rescale:
578
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
579
+
580
+ if channel_first and image.ndim == 3:
581
+ image = image.transpose(2, 0, 1)
582
+
583
+ return image
584
+
585
+ def expand_dims(self, image):
586
+ """
587
+ Expands 2-dimensional `image` to 3 dimensions.
588
+
589
+ Args:
590
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
591
+ The image to expand.
592
+ """
593
+ self._ensure_format_supported(image)
594
+
595
+ # Do nothing if PIL image
596
+ if isinstance(image, PIL.Image.Image):
597
+ return image
598
+
599
+ if is_torch_tensor(image):
600
+ image = image.unsqueeze(0)
601
+ else:
602
+ image = np.expand_dims(image, axis=0)
603
+ return image
604
+
605
+ def normalize(self, image, mean, std, rescale=False):
606
+ """
607
+ Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
608
+ if it's a PIL Image.
609
+
610
+ Args:
611
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
612
+ The image to normalize.
613
+ mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
614
+ The mean (per channel) to use for normalization.
615
+ std (`List[float]` or `np.ndarray` or `torch.Tensor`):
616
+ The standard deviation (per channel) to use for normalization.
617
+ rescale (`bool`, *optional*, defaults to `False`):
618
+ Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
619
+ happen automatically.
620
+ """
621
+ self._ensure_format_supported(image)
622
+
623
+ if isinstance(image, PIL.Image.Image):
624
+ image = self.to_numpy_array(image, rescale=True)
625
+ # If the input image is a PIL image, it automatically gets rescaled. If it's another
626
+ # type it may need rescaling.
627
+ elif rescale:
628
+ if isinstance(image, np.ndarray):
629
+ image = self.rescale(image.astype(np.float32), 1 / 255.0)
630
+ elif is_torch_tensor(image):
631
+ image = self.rescale(image.float(), 1 / 255.0)
632
+
633
+ if isinstance(image, np.ndarray):
634
+ if not isinstance(mean, np.ndarray):
635
+ mean = np.array(mean).astype(image.dtype)
636
+ if not isinstance(std, np.ndarray):
637
+ std = np.array(std).astype(image.dtype)
638
+ elif is_torch_tensor(image):
639
+ import torch
640
+
641
+ if not isinstance(mean, torch.Tensor):
642
+ if isinstance(mean, np.ndarray):
643
+ mean = torch.from_numpy(mean)
644
+ else:
645
+ mean = torch.tensor(mean)
646
+ if not isinstance(std, torch.Tensor):
647
+ if isinstance(std, np.ndarray):
648
+ std = torch.from_numpy(std)
649
+ else:
650
+ std = torch.tensor(std)
651
+
652
+ if image.ndim == 3 and image.shape[0] in [1, 3]:
653
+ return (image - mean[:, None, None]) / std[:, None, None]
654
+ else:
655
+ return (image - mean) / std
656
+
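A quick worked example of the per-channel broadcasting used in `normalize` above (a standalone sketch, not part of the uploaded file; all values are hypothetical): `mean[:, None, None]` reshapes the channel statistics to `(C, 1, 1)` so they broadcast over height and width.

```python
import numpy as np

# Hypothetical channel-first image (3 channels, 2x2) with per-channel stats.
image = np.arange(12, dtype=np.float32).reshape(3, 2, 2) / 12.0
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
std = np.array([0.25, 0.25, 0.25], dtype=np.float32)

# Same broadcasting as the channel-first branch of `normalize` above:
# (3, 2, 2) - (3, 1, 1) stays (3, 2, 2).
normalized = (image - mean[:, None, None]) / std[:, None, None]
print(normalized.shape)  # (3, 2, 2)
```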
657
+ def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
658
+ """
659
+ Resizes `image`. Enforces conversion of input to PIL.Image.
660
+
661
+ Args:
662
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
663
+ The image to resize.
664
+ size (`int` or `Tuple[int, int]`):
665
+ The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
666
+ matched to this.
667
+
668
+ If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
669
+ `size` is an int and `default_to_square` is `False`, then the smaller edge of the image will be matched to
670
+ this number, i.e., if height > width, the image will be rescaled to (size * height / width, size).
671
+ resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
672
+ The filter to use for resampling.
673
+ default_to_square (`bool`, *optional*, defaults to `True`):
674
+ How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
675
+ square (`size`,`size`). If set to `False`, will replicate
676
+ [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
677
+ with support for resizing only the smallest edge and providing an optional `max_size`.
678
+ max_size (`int`, *optional*, defaults to `None`):
679
+ The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
680
+ greater than `max_size` after being resized according to `size`, then the image is resized again so
681
+ that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
682
+ edge may be shorter than `size`. Only used if `default_to_square` is `False`.
683
+
684
+ Returns:
685
+ image: A resized `PIL.Image.Image`.
686
+ """
687
+ resample = resample if resample is not None else PILImageResampling.BILINEAR
688
+
689
+ self._ensure_format_supported(image)
690
+
691
+ if not isinstance(image, PIL.Image.Image):
692
+ image = self.to_pil_image(image)
693
+
694
+ if isinstance(size, list):
695
+ size = tuple(size)
696
+
697
+ if isinstance(size, int) or len(size) == 1:
698
+ if default_to_square:
699
+ size = (size, size) if isinstance(size, int) else (size[0], size[0])
700
+ else:
701
+ width, height = image.size
702
+ # specified size only for the smallest edge
703
+ short, long = (width, height) if width <= height else (height, width)
704
+ requested_new_short = size if isinstance(size, int) else size[0]
705
+
706
+ if short == requested_new_short:
707
+ return image
708
+
709
+ new_short, new_long = requested_new_short, int(requested_new_short * long / short)
710
+
711
+ if max_size is not None:
712
+ if max_size <= requested_new_short:
713
+ raise ValueError(
714
+ f"max_size = {max_size} must be strictly greater than the requested "
715
+ f"size for the smaller edge size = {size}"
716
+ )
717
+ if new_long > max_size:
718
+ new_short, new_long = int(max_size * new_short / new_long), max_size
719
+
720
+ size = (new_short, new_long) if width <= height else (new_long, new_short)
721
+
722
+ return image.resize(size, resample=resample)
723
+
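To make the shortest-edge behaviour above concrete, here is a small standalone sketch (not part of the uploaded file; all numbers are hypothetical) that reproduces the arithmetic for `default_to_square=False` with a `max_size` cap.

```python
# Reproduce the shortest-edge computation from `resize` for a 640x480 input,
# size=256 and max_size=512 (hypothetical numbers).
width, height = 640, 480
size, max_size = 256, 512

short, long = (width, height) if width <= height else (height, width)
new_short, new_long = size, int(size * long / short)  # 256, 341
if max_size is not None and new_long > max_size:
    new_short, new_long = int(max_size * new_short / new_long), max_size

# Keep the original orientation, as in the method above.
target = (new_short, new_long) if width <= height else (new_long, new_short)
print(target)  # (341, 256): the 480 side becomes 256, the 640 side scales to 341
```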
724
+ def center_crop(self, image, size):
725
+ """
726
+ Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
727
+ size given, it will be padded (so the returned result has the requested size).
728
+
729
+ Args:
730
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
731
+ The image to crop.
732
+ size (`int` or `Tuple[int, int]`):
733
+ The size to which to crop the image.
734
+
735
+ Returns:
736
+ new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
737
+ height, width).
738
+ """
739
+ self._ensure_format_supported(image)
740
+
741
+ if not isinstance(size, tuple):
742
+ size = (size, size)
743
+
744
+ # PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
745
+ if is_torch_tensor(image) or isinstance(image, np.ndarray):
746
+ if image.ndim == 2:
747
+ image = self.expand_dims(image)
748
+ image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
749
+ else:
750
+ image_shape = (image.size[1], image.size[0])
751
+
752
+ top = (image_shape[0] - size[0]) // 2
753
+ bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
754
+ left = (image_shape[1] - size[1]) // 2
755
+ right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.
756
+
757
+ # For PIL Images we have a method to crop directly.
758
+ if isinstance(image, PIL.Image.Image):
759
+ return image.crop((left, top, right, bottom))
760
+
761
+ # Check if image is in (n_channels, height, width) or (height, width, n_channels) format
762
+ channel_first = True if image.shape[0] in [1, 3] else False
763
+
764
+ # Transpose (height, width, n_channels) format images
765
+ if not channel_first:
766
+ if isinstance(image, np.ndarray):
767
+ image = image.transpose(2, 0, 1)
768
+ if is_torch_tensor(image):
769
+ image = image.permute(2, 0, 1)
770
+
771
+ # Check if cropped area is within image boundaries
772
+ if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
773
+ return image[..., top:bottom, left:right]
774
+
775
+ # Otherwise, we may need to pad if the image is too small. Oh joy...
776
+ new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
777
+ if isinstance(image, np.ndarray):
778
+ new_image = np.zeros_like(image, shape=new_shape)
779
+ elif is_torch_tensor(image):
780
+ new_image = image.new_zeros(new_shape)
781
+
782
+ top_pad = (new_shape[-2] - image_shape[0]) // 2
783
+ bottom_pad = top_pad + image_shape[0]
784
+ left_pad = (new_shape[-1] - image_shape[1]) // 2
785
+ right_pad = left_pad + image_shape[1]
786
+ new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image
787
+
788
+ top += top_pad
789
+ bottom += top_pad
790
+ left += left_pad
791
+ right += left_pad
792
+
793
+ new_image = new_image[
794
+ ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
795
+ ]
796
+
797
+ return new_image
798
+
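The padding branch above is easy to miss: asking for a crop larger than the image yields the image centred on a zero background. A simplified standalone sketch (hypothetical shapes, re-deriving the padding rather than calling the mixin):

```python
import numpy as np

# A 1x4x4 channel-first image "cropped" to 6x6 ends up zero-padded around the
# original content, which is what the padding branch of `center_crop` produces.
image = np.ones((1, 4, 4), dtype=np.float32)
size = (6, 6)

out_h, out_w = max(size[0], image.shape[1]), max(size[1], image.shape[2])
padded = np.zeros((1, out_h, out_w), dtype=image.dtype)
top = (out_h - image.shape[1]) // 2
left = (out_w - image.shape[2]) // 2
padded[:, top : top + image.shape[1], left : left + image.shape[2]] = image
print(padded.shape)  # (1, 6, 6)
```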
799
+ def flip_channel_order(self, image):
800
+ """
801
+ Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
802
+ `image` to a NumPy array if it's a PIL Image.
803
+
804
+ Args:
805
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
806
+ The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
807
+ be first.
808
+ """
809
+ self._ensure_format_supported(image)
810
+
811
+ if isinstance(image, PIL.Image.Image):
812
+ image = self.to_numpy_array(image)
813
+
814
+ return image[::-1, :, :]
815
+
816
+ def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
817
+ """
818
+ Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
819
+ counter clockwise around its centre.
820
+
821
+ Args:
822
+ image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
823
+ The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
824
+ rotating.
825
+
826
+ Returns:
827
+ image: A rotated `PIL.Image.Image`.
828
+ """
829
+ resample = resample if resample is not None else PIL.Image.NEAREST
830
+
831
+ self._ensure_format_supported(image)
832
+
833
+ if not isinstance(image, PIL.Image.Image):
834
+ image = self.to_pil_image(image)
835
+
836
+ return image.rotate(
837
+ angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
838
+ )
839
+
840
+
841
+ def validate_annotations(
842
+ annotation_format: AnnotationFormat,
843
+ supported_annotation_formats: Tuple[AnnotationFormat, ...],
844
+ annotations: List[Dict],
845
+ ) -> None:
846
+ if annotation_format not in supported_annotation_formats:
847
+ raise ValueError(f"Unsupported annotation format: {annotation_format} must be one of {supported_annotation_formats}")
848
+
849
+ if annotation_format is AnnotationFormat.COCO_DETECTION:
850
+ if not valid_coco_detection_annotations(annotations):
851
+ raise ValueError(
852
+ "Invalid COCO detection annotations. Annotations must be a dict (single image) or list of dicts "
853
+ "(batch of images) with the following keys: `image_id` and `annotations`, with the latter "
854
+ "being a list of annotations in the COCO format."
855
+ )
856
+
857
+ if annotation_format is AnnotationFormat.COCO_PANOPTIC:
858
+ if not valid_coco_panoptic_annotations(annotations):
859
+ raise ValueError(
860
+ "Invalid COCO panoptic annotations. Annotations must be a dict (single image) or list of dicts "
861
+ "(batch of images) with the following keys: `image_id`, `file_name` and `segments_info`, with "
862
+ "the latter being a list of annotations in the COCO format."
863
+ )
864
+
865
+
866
+ def validate_kwargs(valid_processor_keys: List[str], captured_kwargs: List[str]):
867
+ unused_keys = set(captured_kwargs).difference(set(valid_processor_keys))
868
+ if unused_keys:
869
+ unused_key_str = ", ".join(unused_keys)
870
+ # TODO raise a warning here instead of simply logging?
871
+ logger.warning(f"Unused or unrecognized kwargs: {unused_key_str}.")
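For orientation, a minimal sketch of how the `ImageFeatureExtractionMixin` helpers above are typically chained by the legacy feature extractors (not part of the diff; it assumes the file is exposed as `transformers.image_utils` and uses made-up sizes and statistics):

```python
import numpy as np
from transformers.image_utils import ImageFeatureExtractionMixin  # assumed import path

extractor = ImageFeatureExtractionMixin()
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # HWC uint8

resized = extractor.resize(image, size=256, default_to_square=False)  # returns a PIL image
array = extractor.to_numpy_array(resized)                             # scaled to [0, 1], channel-first
cropped = extractor.center_crop(array, 224)
normalized = extractor.normalize(cropped, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
print(normalized.shape)  # (3, 224, 224)
```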
keras_callbacks.py ADDED
@@ -0,0 +1,413 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from time import sleep
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import tensorflow as tf
9
+ from huggingface_hub import Repository, create_repo
10
+ from packaging.version import parse
11
+
12
+ from . import IntervalStrategy, PreTrainedTokenizerBase
13
+ from .modelcard import TrainingSummary
14
+ from .modeling_tf_utils import keras
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class KerasMetricCallback(keras.callbacks.Callback):
21
+ """
22
+ Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be
23
+ compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string
24
+ operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the
25
+ `eval_dataset` before being passed to the `metric_fn` in `np.ndarray` format. The `metric_fn` should compute
26
+ metrics and return a dict mapping metric names to metric values.
27
+
28
+ We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that
29
+ this example skips some post-processing for readability and simplicity, and should probably not be used as-is!
30
+
31
+ ```py
32
+ from datasets import load_metric
33
+
34
+ rouge_metric = load_metric("rouge")
35
+
36
+
37
+ def rouge_fn(predictions, labels):
38
+ decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
39
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
40
+ result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
41
+ return {key: value.mid.fmeasure * 100 for key, value in result.items()}
42
+ ```
43
+
44
+ The above function will return a dict containing values which will be logged like any other Keras metric:
45
+
46
+ ```
47
+ {'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}
48
+ ```
49
+
50
+ Args:
51
+ metric_fn (`Callable`):
52
+ Metric function provided by the user. It will be called with two arguments - `predictions` and `labels`.
53
+ These contain the model's outputs and matching labels from the dataset. It should return a dict mapping
54
+ metric names to numerical values.
55
+ eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
56
+ Validation data to be used to generate predictions for the `metric_fn`.
57
+ output_cols (`List[str]`, *optional*):
58
+ A list of columns to be retained from the model output as the predictions. Defaults to all.
59
+ label_cols (`List[str]`, *optional*):
60
+ A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not
61
+ supplied.
62
+ batch_size (`int`, *optional*):
63
+ Batch size. Only used when the data is not a pre-batched `tf.data.Dataset`.
64
+ predict_with_generate (`bool`, *optional*, defaults to `False`):
65
+ Whether we should use `model.generate()` to get outputs for the model.
66
+ use_xla_generation (`bool`, *optional*, defaults to `False`):
67
+ If we're generating, whether to compile model generation with XLA. This can massively increase the speed of
68
+ generation (up to 100X speedup) but will require a new XLA compilation for each input shape. When using XLA
69
+ generation, it's a good idea to pad your inputs to the same size, or to use the `pad_to_multiple_of`
70
+ argument in your `tokenizer` or `DataCollator`, which will reduce the number of unique input shapes and
71
+ save a lot of compilation time. This option has no effect if `predict_with_generate` is `False`.
72
+ generate_kwargs (`dict`, *optional*):
73
+ Keyword arguments to pass to `model.generate()` when generating. Has no effect if `predict_with_generate`
74
+ is `False`.
75
+
76
+ """
77
+
78
+ def __init__(
79
+ self,
80
+ metric_fn: Callable,
81
+ eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
82
+ output_cols: Optional[List[str]] = None,
83
+ label_cols: Optional[List[str]] = None,
84
+ batch_size: Optional[int] = None,
85
+ predict_with_generate: bool = False,
86
+ use_xla_generation: bool = False,
87
+ generate_kwargs: Optional[dict] = None,
88
+ ):
89
+ super().__init__()
90
+ self.metric_fn = metric_fn
91
+ self.batch_size = batch_size
92
+ if not isinstance(eval_dataset, tf.data.Dataset):
93
+ if batch_size is None:
94
+ raise ValueError(
95
+ "When passing data to KerasMetricCallback that is not a pre-batched tf.data.Dataset "
96
+ "the batch_size argument must be set."
97
+ )
98
+ # Wrap a tf.data.Dataset around it
99
+ eval_dataset = tf.data.Dataset.from_tensor_slices(eval_dataset).batch(batch_size, drop_remainder=False)
100
+ self.eval_dataset = eval_dataset
101
+ self.predict_with_generate = predict_with_generate
102
+ self.output_cols = output_cols
103
+
104
+ # This next block attempts to parse out which elements of the dataset should be appended to the labels list
105
+ # that is passed to the metric_fn
106
+ if isinstance(eval_dataset.element_spec, tuple) and len(eval_dataset.element_spec) == 2:
107
+ input_spec, label_spec = eval_dataset.element_spec
108
+ else:
109
+ input_spec = eval_dataset.element_spec
110
+ label_spec = None
111
+ if label_cols is not None:
112
+ for label in label_cols:
113
+ if label not in input_spec:
114
+ raise ValueError(f"Label {label} is in label_cols but could not be found in the dataset inputs!")
115
+ self.label_cols = label_cols
116
+ self.use_keras_label = False
117
+ elif label_spec is not None:
118
+ # If the dataset inputs are split into a 2-tuple of inputs and labels,
119
+ # assume the second element is the labels
120
+ self.label_cols = None
121
+ self.use_keras_label = True
122
+ elif "labels" in input_spec:
123
+ self.label_cols = ["labels"]
124
+ self.use_keras_label = False
125
+ logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
126
+ elif "start_positions" in input_spec and "end_positions" in input_spec:
127
+ self.label_cols = ["start_positions", "end_positions"]
128
+ self.use_keras_label = False
129
+ logging.warning(
130
+ "No label_cols specified for KerasMetricCallback, assuming you want the "
131
+ "start_positions and end_positions keys."
132
+ )
133
+ else:
134
+ raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
135
+ if parse(tf.__version__) < parse("2.7"):
136
+ logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
137
+
138
+ self.use_xla_generation = use_xla_generation
139
+ self.generate_kwargs = {} if generate_kwargs is None else generate_kwargs
140
+
141
+ self.generation_function = None
142
+
143
+ @staticmethod
144
+ def _concatenate_batches(batches, padding_index=-100):
145
+ # If all batches are unidimensional or same length, do a simple concatenation
146
+ if batches[0].ndim == 1 or all(batch.shape[1] == batches[0].shape[1] for batch in batches):
147
+ return np.concatenate(batches, axis=0)
148
+
149
+ # Welp, they're not the same length. Let's do some padding
150
+ max_len = max([batch.shape[1] for batch in batches])
151
+ num_samples = sum([batch.shape[0] for batch in batches])
152
+ output = np.full_like(
153
+ batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
154
+ )
155
+ # i keeps track of which part of the concatenated array we're writing the next batch to
156
+ i = 0
157
+ for batch in batches:
158
+ output[i : i + len(batch), : batch.shape[1]] = batch
159
+ i += len(batch)
160
+ return output
161
+
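A small worked example of the padded concatenation above (assuming the file is importable as `transformers.keras_callbacks`; the inputs are made up): two batches with different sequence lengths are stacked after padding the shorter rows with the default `-100` index.

```python
import numpy as np
from transformers.keras_callbacks import KerasMetricCallback  # assumed import path

batches = [np.array([[1, 2, 3, 4]]), np.array([[5, 6], [7, 8]])]
padded = KerasMetricCallback._concatenate_batches(batches)
print(padded)
# [[   1    2    3    4]
#  [   5    6 -100 -100]
#  [   7    8 -100 -100]]
```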
162
+ def _postprocess_predictions_or_labels(self, inputs):
163
+ if isinstance(inputs[0], dict):
164
+ outputs = {}
165
+ for key in inputs[0].keys():
166
+ outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
167
+ # If it's a dict with only one key, just return the array
168
+ if len(outputs) == 1:
169
+ outputs = list(outputs.values())[0]
170
+ elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
171
+ outputs = []
172
+ for input_list in zip(*inputs):
173
+ outputs.append(self._concatenate_batches(input_list))
174
+ if len(outputs) == 1:
175
+ outputs = outputs[0] # If it's a list with only one element, just return the array
176
+ elif isinstance(inputs[0], np.ndarray):
177
+ outputs = self._concatenate_batches(inputs)
178
+ elif isinstance(inputs[0], tf.Tensor):
179
+ outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
180
+ else:
181
+ raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
182
+ return outputs
183
+
184
+ def on_epoch_end(self, epoch, logs=None):
185
+ if hasattr(self.model, "config"):
186
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
187
+ else:
188
+ ignore_keys = []
189
+
190
+ main_input_name = None
191
+ if self.predict_with_generate:
192
+ # This dense conditional recognizes the case where we have an encoder-decoder model, but
193
+ # avoids getting tangled up when we just have a model with a layer called 'encoder'
194
+ if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
195
+ main_input_name = self.model.encoder.main_input_name
196
+ else:
197
+ main_input_name = getattr(self.model, "main_input_name", "input_ids")
198
+
199
+ if self.use_xla_generation and self.generation_function is None:
200
+
201
+ def generation_function(inputs, attention_mask):
202
+ return self.model.generate(inputs, attention_mask=attention_mask, **self.generate_kwargs)
203
+
204
+ self.generation_function = tf.function(generation_function, jit_compile=True)
205
+
206
+ prediction_list = []
207
+ label_list = []
208
+
209
+ # The whole predict/generate loop is handled inside this method
210
+ for batch in self.eval_dataset:
211
+ if isinstance(batch, tuple):
212
+ batch, labels = batch
213
+ else:
214
+ labels = None
215
+ if self.predict_with_generate:
216
+ if isinstance(batch, dict):
217
+ generation_inputs = batch[main_input_name]
218
+ attention_mask = batch.get("attention_mask", None)
219
+ else:
220
+ generation_inputs = batch
221
+ attention_mask = None
222
+ if self.use_xla_generation:
223
+ predictions = self.generation_function(generation_inputs, attention_mask=attention_mask)
224
+ else:
225
+ predictions = self.model.generate(
226
+ generation_inputs, attention_mask=attention_mask, **self.generate_kwargs
227
+ )
228
+ else:
229
+ predictions = self.model.predict_on_batch(batch)
230
+ if isinstance(predictions, dict):
231
+ # This converts any dict-subclass to a regular dict
232
+ # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
233
+ predictions = dict(predictions)
234
+ if self.output_cols is not None:
235
+ predictions = {key: predictions[key] for key in self.output_cols}
236
+ else:
237
+ predictions = {
238
+ key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]
239
+ }
240
+ prediction_list.append(predictions)
241
+ if not self.use_keras_label:
242
+ labels = {key: batch[key].numpy() for key in self.label_cols}
243
+ elif isinstance(labels, dict):
244
+ labels = {key: array.numpy() for key, array in labels.items()}
245
+ elif isinstance(labels, list) or isinstance(labels, tuple):
246
+ labels = [array.numpy() for array in labels]
247
+ elif isinstance(labels, tf.Tensor):
248
+ labels = labels.numpy()
249
+ else:
250
+ raise TypeError(f"Confused by labels of type {type(labels)}")
251
+ label_list.append(labels)
252
+
253
+ all_preds = self._postprocess_predictions_or_labels(prediction_list)
254
+ all_labels = self._postprocess_predictions_or_labels(label_list)
255
+
256
+ metric_output = self.metric_fn((all_preds, all_labels))
257
+ if not isinstance(metric_output, dict):
258
+ raise TypeError(
259
+ f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"
260
+ )
261
+ # This is the critical bit - Keras passes a dict containing the loss and standard metric values for this epoch
262
+ # in the logs argument. Ordinarily, this is so the callback can read them, but in this case we write a bunch of
263
+ # new keys in there, which will then get read by the History callback and treated like any other metric value.
264
+ # I promise that I have it in writing from Chollet that this is okay.
265
+ logs.update(metric_output)
266
+
267
+
268
+ class PushToHubCallback(keras.callbacks.Callback):
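A minimal wiring sketch for the callback defined above (assumptions: `model`, `tf_train_set`, and a pre-batched `tf_validation_set` already exist, and the file is importable as `transformers.keras_callbacks`). Note that, as implemented in `on_epoch_end` above, `metric_fn` receives a single `(predictions, labels)` tuple, so the function should unpack it.

```python
from transformers.keras_callbacks import KerasMetricCallback  # assumed import path


def my_metric_fn(eval_predictions):
    predictions, labels = eval_predictions  # the callback passes one (preds, labels) tuple
    # ... compute whatever metrics you need from the NumPy arrays ...
    return {"my_metric": 0.0}  # placeholder value for the sketch


metric_callback = KerasMetricCallback(metric_fn=my_metric_fn, eval_dataset=tf_validation_set)
model.fit(tf_train_set, validation_data=tf_validation_set, epochs=3, callbacks=[metric_callback])
```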
269
+ """
270
+ Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can
271
+ be changed with the `save_strategy` argument. Pushed models can be accessed like any other model on the hub, such
272
+ as with the `from_pretrained` method.
273
+
274
+ ```py
275
+ from transformers.keras_callbacks import PushToHubCallback
276
+
277
+ push_to_hub_callback = PushToHubCallback(
278
+ output_dir="./model_save",
279
+ tokenizer=tokenizer,
280
+ hub_model_id="gpt5-7xlarge",
281
+ )
282
+
283
+ model.fit(train_dataset, callbacks=[push_to_hub_callback])
284
+ ```
285
+
286
+ Args:
287
+ output_dir (`str`):
288
+ The output directory where the model predictions and checkpoints will be written and synced with the
289
+ repository on the Hub.
290
+ save_strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"epoch"`):
291
+ The checkpoint save strategy to adopt during training. Possible values are:
292
+
293
+ - `"no"`: Save is done at the end of training.
294
+ - `"epoch"`: Save is done at the end of each epoch.
295
+ - `"steps"`: Save is done every `save_steps`
296
+ save_steps (`int`, *optional*):
297
+ The number of steps between saves when using the "steps" `save_strategy`.
298
+ tokenizer (`PreTrainedTokenizerBase`, *optional*):
299
+ The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
300
+ hub_model_id (`str`, *optional*):
301
+ The name of the repository to keep in sync with the local `output_dir`. It can be a simple model ID in
302
+ which case the model will be pushed in your namespace. Otherwise it should be the whole repository name,
303
+ for instance `"user_name/model"`, which allows you to push to an organization you are a member of with
304
+ `"organization_name/model"`.
305
+
306
+ Will default to the name of `output_dir`.
307
+ hub_token (`str`, *optional*):
308
+ The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
309
+ `huggingface-cli login`.
310
+ checkpoint (`bool`, *optional*, defaults to `False`):
311
+ Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be
312
+ resumed. Only usable when `save_strategy` is `"epoch"`.
313
+ """
314
+
315
+ def __init__(
316
+ self,
317
+ output_dir: Union[str, Path],
318
+ save_strategy: Union[str, IntervalStrategy] = "epoch",
319
+ save_steps: Optional[int] = None,
320
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
321
+ hub_model_id: Optional[str] = None,
322
+ hub_token: Optional[str] = None,
323
+ checkpoint: bool = False,
324
+ **model_card_args,
325
+ ):
326
+ super().__init__()
327
+ if checkpoint and save_strategy != "epoch":
328
+ raise ValueError("Cannot save checkpoints when save_strategy is not 'epoch'!")
329
+ if isinstance(save_strategy, str):
330
+ save_strategy = IntervalStrategy(save_strategy.lower())
331
+ self.save_strategy = save_strategy
332
+ if self.save_strategy == IntervalStrategy.STEPS and (not isinstance(save_steps, int) or save_steps <= 0):
333
+ raise ValueError("Please supply a positive integer argument for save_steps when save_strategy == 'steps'!")
334
+ self.save_steps = save_steps
335
+ output_dir = Path(output_dir)
336
+
337
+ # Create repo and retrieve repo_id
338
+ if hub_model_id is None:
339
+ hub_model_id = output_dir.absolute().name
340
+ self.hub_model_id = create_repo(repo_id=hub_model_id, exist_ok=True, token=hub_token).repo_id
341
+
342
+ self.output_dir = output_dir
343
+ self.repo = Repository(str(self.output_dir), clone_from=self.hub_model_id, token=hub_token)
344
+
345
+ self.tokenizer = tokenizer
346
+ self.last_job = None
347
+ self.checkpoint = checkpoint
348
+ self.training_history = None
349
+ self.model_card_args = model_card_args
350
+
351
+ def on_train_begin(self, logs=None):
352
+ # Although we can access model.history, we have no guarantees that the History callback will fire before this
353
+ # one, so we keep track of it here too
354
+ self.training_history = []
355
+
356
+ def on_train_batch_end(self, batch, logs=None):
357
+ if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
358
+ if self.last_job is not None and not self.last_job.is_done:
359
+ return # The last upload is still running, don't start another
360
+ self.model.save_pretrained(self.output_dir)
361
+ if self.tokenizer is not None:
362
+ self.tokenizer.save_pretrained(self.output_dir)
363
+ _, self.last_job = self.repo.push_to_hub(
364
+ commit_message=f"Training in progress steps {batch}", blocking=False
365
+ )
366
+
367
+ def on_epoch_end(self, epoch, logs=None):
368
+ logs = logs.copy() # Don't accidentally write things that Keras will read later
369
+ if "epoch" not in logs:
370
+ logs["epoch"] = epoch
371
+ self.training_history.append(logs)
372
+ if self.save_strategy == IntervalStrategy.EPOCH:
373
+ if self.last_job is not None and not self.last_job.is_done:
374
+ return # The last upload is still running, don't start another
375
+ self.model.save_pretrained(self.output_dir)
376
+ if self.tokenizer is not None:
377
+ self.tokenizer.save_pretrained(self.output_dir)
378
+ if self.checkpoint:
379
+ checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
380
+ self.model._save_checkpoint(checkpoint_dir, epoch)
381
+ train_summary = TrainingSummary.from_keras(
382
+ model=self.model,
383
+ model_name=self.hub_model_id,
384
+ keras_history=self.training_history,
385
+ **self.model_card_args,
386
+ )
387
+ model_card = train_summary.to_model_card()
388
+ with (self.output_dir / "README.md").open("w") as f:
389
+ f.write(model_card)
390
+ _, self.last_job = self.repo.push_to_hub(
391
+ commit_message=f"Training in progress epoch {epoch}", blocking=False
392
+ )
393
+
394
+ def on_train_end(self, logs=None):
395
+ # Makes sure the latest version of the model is uploaded
396
+ if self.last_job is not None and not self.last_job.is_done:
397
+ logging.info("Pushing the last epoch to the Hub, this may take a while...")
398
+ while not self.last_job.is_done:
399
+ sleep(1)
400
+ else:
401
+ self.model.save_pretrained(self.output_dir)
402
+ if self.tokenizer is not None:
403
+ self.tokenizer.save_pretrained(self.output_dir)
404
+ train_summary = TrainingSummary.from_keras(
405
+ model=self.model,
406
+ model_name=self.hub_model_id,
407
+ keras_history=self.training_history,
408
+ **self.model_card_args,
409
+ )
410
+ model_card = train_summary.to_model_card()
411
+ with (self.output_dir / "README.md").open("w") as f:
412
+ f.write(model_card)
413
+ self.repo.push_to_hub(commit_message="End of training", blocking=True)
modelcard.py ADDED
@@ -0,0 +1,908 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Model card base class and utilities."""
16
+
17
+ import copy
18
+ import json
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Union
24
+
25
+ import requests
26
+ import yaml
27
+ from huggingface_hub import model_info
28
+ from huggingface_hub.utils import HFValidationError
29
+
30
+ from . import __version__
31
+ from .models.auto.modeling_auto import (
32
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
33
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
34
+ MODEL_FOR_CTC_MAPPING_NAMES,
35
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
36
+ MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
37
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
38
+ MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
39
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
40
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
41
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
42
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
43
+ MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
44
+ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
45
+ MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
46
+ )
47
+ from .training_args import ParallelMode
48
+ from .utils import (
49
+ MODEL_CARD_NAME,
50
+ cached_file,
51
+ is_datasets_available,
52
+ is_offline_mode,
53
+ is_tf_available,
54
+ is_tokenizers_available,
55
+ is_torch_available,
56
+ logging,
57
+ )
58
+
59
+
60
+ TASK_MAPPING = {
61
+ "text-generation": MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
62
+ "image-classification": MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
63
+ "image-segmentation": MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES,
64
+ "fill-mask": MODEL_FOR_MASKED_LM_MAPPING_NAMES,
65
+ "object-detection": MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES,
66
+ "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
67
+ "text2text-generation": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
68
+ "text-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
69
+ "table-question-answering": MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
70
+ "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
71
+ "audio-classification": MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
72
+ "automatic-speech-recognition": {**MODEL_FOR_CTC_MAPPING_NAMES, **MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES},
73
+ "zero-shot-image-classification": MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES,
74
+ }
75
+
76
+ logger = logging.get_logger(__name__)
77
+
78
+
79
+ class ModelCard:
80
+ r"""
81
+ Structured Model Card class. Store model card as well as methods for loading/downloading/saving model cards.
82
+
83
+ Please read the following paper for details and explanation on the sections: "Model Cards for Model Reporting" by
84
+ Margaret Mitchell, Simone Wu, Andrew Zaldivar, Parker Barnes, Lucy Vasserman, Ben Hutchinson, Elena Spitzer,
85
+ Inioluwa Deborah Raji and Timnit Gebru for the proposal behind model cards. Link: https://arxiv.org/abs/1810.03993
86
+
87
+ Note: A model card can be loaded and saved to disk.
88
+ """
89
+
90
+ def __init__(self, **kwargs):
91
+ warnings.warn(
92
+ "The class `ModelCard` is deprecated and will be removed in version 5 of Transformers", FutureWarning
93
+ )
94
+ # Recommended attributes from https://arxiv.org/abs/1810.03993 (see papers)
95
+ self.model_details = kwargs.pop("model_details", {})
96
+ self.intended_use = kwargs.pop("intended_use", {})
97
+ self.factors = kwargs.pop("factors", {})
98
+ self.metrics = kwargs.pop("metrics", {})
99
+ self.evaluation_data = kwargs.pop("evaluation_data", {})
100
+ self.training_data = kwargs.pop("training_data", {})
101
+ self.quantitative_analyses = kwargs.pop("quantitative_analyses", {})
102
+ self.ethical_considerations = kwargs.pop("ethical_considerations", {})
103
+ self.caveats_and_recommendations = kwargs.pop("caveats_and_recommendations", {})
104
+
105
+ # Open additional attributes
106
+ for key, value in kwargs.items():
107
+ try:
108
+ setattr(self, key, value)
109
+ except AttributeError as err:
110
+ logger.error(f"Can't set {key} with value {value} for {self}")
111
+ raise err
112
+
113
+ def save_pretrained(self, save_directory_or_file):
114
+ """Save a model card object to the directory or file `save_directory_or_file`."""
115
+ if os.path.isdir(save_directory_or_file):
116
+ # If we save using the predefined names, we can load using `from_pretrained`
117
+ output_model_card_file = os.path.join(save_directory_or_file, MODEL_CARD_NAME)
118
+ else:
119
+ output_model_card_file = save_directory_or_file
120
+
121
+ self.to_json_file(output_model_card_file)
122
+ logger.info(f"Model card saved in {output_model_card_file}")
123
+
124
+ @classmethod
125
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
126
+ r"""
127
+ Instantiate a [`ModelCard`] from a pre-trained model model card.
128
+
129
+ Parameters:
130
+ pretrained_model_name_or_path: either:
131
+
132
+ - a string, the *model id* of a pretrained model card hosted inside a model repo on huggingface.co.
133
+ - a path to a *directory* containing a model card file saved using the [`~ModelCard.save_pretrained`]
134
+ method, e.g.: `./my_model_directory/`.
135
+ - a path or url to a saved model card JSON *file*, e.g.: `./my_model_directory/modelcard.json`.
136
+
137
+ cache_dir: (*optional*) string:
138
+ Path to a directory in which a downloaded pre-trained model card should be cached if the standard cache
139
+ should not be used.
140
+
141
+ kwargs: (*optional*) dict: key/value pairs with which to update the ModelCard object after loading.
142
+
143
+ - The values in kwargs of any keys which are model card attributes will be used to override the loaded
144
+ values.
145
+ - Behavior concerning key/value pairs whose keys are *not* model card attributes is controlled by the
146
+ *return_unused_kwargs* keyword parameter.
147
+
148
+ proxies: (*optional*) dict, default None:
149
+ A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
150
+ 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
151
+
152
+ return_unused_kwargs: (*optional*) bool:
153
+
154
+ - If False, then this function returns just the final model card object.
155
+ - If True, then this functions returns a tuple *(model card, unused_kwargs)* where *unused_kwargs* is a
156
+ dictionary consisting of the key/value pairs whose keys are not model card attributes: ie the part of
157
+ kwargs which has not been used to update *ModelCard* and is otherwise ignored.
158
+
159
+ Examples:
160
+
161
+ ```python
162
+ # Download model card from huggingface.co and cache.
163
+ modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased")
164
+ # Model card was saved using *save_pretrained('./test/saved_model/')*
165
+ modelcard = ModelCard.from_pretrained("./test/saved_model/")
166
+ modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
167
+ modelcard = ModelCard.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
168
+ ```"""
169
+ cache_dir = kwargs.pop("cache_dir", None)
170
+ proxies = kwargs.pop("proxies", None)
171
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
172
+ from_pipeline = kwargs.pop("_from_pipeline", None)
173
+
174
+ user_agent = {"file_type": "model_card"}
175
+ if from_pipeline is not None:
176
+ user_agent["using_pipeline"] = from_pipeline
177
+
178
+ is_local = os.path.isdir(pretrained_model_name_or_path)
179
+ if os.path.isfile(pretrained_model_name_or_path):
180
+ resolved_model_card_file = pretrained_model_name_or_path
181
+ is_local = True
182
+ else:
183
+ try:
184
+ # Load from URL or cache if already cached
185
+ resolved_model_card_file = cached_file(
186
+ pretrained_model_name_or_path,
187
+ filename=MODEL_CARD_NAME,
188
+ cache_dir=cache_dir,
189
+ proxies=proxies,
190
+ user_agent=user_agent,
191
+ )
192
+ if is_local:
193
+ logger.info(f"loading model card file {resolved_model_card_file}")
194
+ else:
195
+ logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
196
+ # Load model card
197
+ modelcard = cls.from_json_file(resolved_model_card_file)
198
+
199
+ except (EnvironmentError, json.JSONDecodeError):
200
+ # We fall back on creating an empty model card
201
+ modelcard = cls()
202
+
203
+ # Update model card with kwargs if needed
204
+ to_remove = []
205
+ for key, value in kwargs.items():
206
+ if hasattr(modelcard, key):
207
+ setattr(modelcard, key, value)
208
+ to_remove.append(key)
209
+ for key in to_remove:
210
+ kwargs.pop(key, None)
211
+
212
+ logger.info(f"Model card: {modelcard}")
213
+ if return_unused_kwargs:
214
+ return modelcard, kwargs
215
+ else:
216
+ return modelcard
217
+
218
+ @classmethod
219
+ def from_dict(cls, json_object):
220
+ """Constructs a `ModelCard` from a Python dictionary of parameters."""
221
+ return cls(**json_object)
222
+
223
+ @classmethod
224
+ def from_json_file(cls, json_file):
225
+ """Constructs a `ModelCard` from a json file of parameters."""
226
+ with open(json_file, "r", encoding="utf-8") as reader:
227
+ text = reader.read()
228
+ dict_obj = json.loads(text)
229
+ return cls(**dict_obj)
230
+
231
+ def __eq__(self, other):
232
+ return self.__dict__ == other.__dict__
233
+
234
+ def __repr__(self):
235
+ return str(self.to_json_string())
236
+
237
+ def to_dict(self):
238
+ """Serializes this instance to a Python dictionary."""
239
+ output = copy.deepcopy(self.__dict__)
240
+ return output
241
+
242
+ def to_json_string(self):
243
+ """Serializes this instance to a JSON string."""
244
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
245
+
246
+ def to_json_file(self, json_file_path):
247
+ """Save this instance to a json file."""
248
+ with open(json_file_path, "w", encoding="utf-8") as writer:
249
+ writer.write(self.to_json_string())
250
+
251
+
252
+ AUTOGENERATED_TRAINER_COMMENT = """
253
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
254
+ should probably proofread and complete it, then remove this comment. -->
255
+ """
256
+
257
+ AUTOGENERATED_KERAS_COMMENT = """
258
+ <!-- This model card has been generated automatically according to the information Keras had access to. You should
259
+ probably proofread and complete it, then remove this comment. -->
260
+ """
261
+
262
+
263
+ TASK_TAG_TO_NAME_MAPPING = {
264
+ "fill-mask": "Masked Language Modeling",
265
+ "image-classification": "Image Classification",
266
+ "image-segmentation": "Image Segmentation",
267
+ "multiple-choice": "Multiple Choice",
268
+ "object-detection": "Object Detection",
269
+ "question-answering": "Question Answering",
270
+ "summarization": "Summarization",
271
+ "table-question-answering": "Table Question Answering",
272
+ "text-classification": "Text Classification",
273
+ "text-generation": "Causal Language Modeling",
274
+ "text2text-generation": "Sequence-to-sequence Language Modeling",
275
+ "token-classification": "Token Classification",
276
+ "translation": "Translation",
277
+ "zero-shot-classification": "Zero Shot Classification",
278
+ "automatic-speech-recognition": "Automatic Speech Recognition",
279
+ "audio-classification": "Audio Classification",
280
+ }
281
+
282
+
283
+ METRIC_TAGS = [
284
+ "accuracy",
285
+ "bleu",
286
+ "f1",
287
+ "matthews_correlation",
288
+ "pearsonr",
289
+ "precision",
290
+ "recall",
291
+ "rouge",
292
+ "sacrebleu",
293
+ "spearmanr",
294
+ "wer",
295
+ ]
296
+
297
+
298
+ def _listify(obj):
299
+ if obj is None:
300
+ return []
301
+ elif isinstance(obj, str):
302
+ return [obj]
303
+ else:
304
+ return obj
305
+
306
+
307
+ def _insert_values_as_list(metadata, name, values):
308
+ if values is None:
309
+ return metadata
310
+ if isinstance(values, str):
311
+ values = [values]
312
+ values = [v for v in values if v is not None]
313
+ if len(values) == 0:
314
+ return metadata
315
+ metadata[name] = values
316
+ return metadata
317
+
318
+
319
+ def infer_metric_tags_from_eval_results(eval_results):
320
+ if eval_results is None:
321
+ return {}
322
+ result = {}
323
+ for key in eval_results.keys():
324
+ if key.lower().replace(" ", "_") in METRIC_TAGS:
325
+ result[key.lower().replace(" ", "_")] = key
326
+ elif key.lower() == "rouge1":
327
+ result["rouge"] = key
328
+ return result
329
+
330
+
331
+ def _insert_value(metadata, name, value):
332
+ if value is None:
333
+ return metadata
334
+ metadata[name] = value
335
+ return metadata
336
+
337
+
338
+ def is_hf_dataset(dataset):
339
+ if not is_datasets_available():
340
+ return False
341
+
342
+ from datasets import Dataset, IterableDataset
343
+
344
+ return isinstance(dataset, (Dataset, IterableDataset))
345
+
346
+
347
+ def _get_mapping_values(mapping):
348
+ result = []
349
+ for v in mapping.values():
350
+ if isinstance(v, (tuple, list)):
351
+ result += list(v)
352
+ else:
353
+ result.append(v)
354
+ return result
355
+
356
+
357
+ @dataclass
358
+ class TrainingSummary:
359
+ model_name: str
360
+ language: Optional[Union[str, List[str]]] = None
361
+ license: Optional[str] = None
362
+ tags: Optional[Union[str, List[str]]] = None
363
+ finetuned_from: Optional[str] = None
364
+ tasks: Optional[Union[str, List[str]]] = None
365
+ dataset: Optional[Union[str, List[str]]] = None
366
+ dataset_tags: Optional[Union[str, List[str]]] = None
367
+ dataset_args: Optional[Union[str, List[str]]] = None
368
+ dataset_metadata: Optional[Dict[str, Any]] = None
369
+ eval_results: Optional[Dict[str, float]] = None
370
+ eval_lines: Optional[List[str]] = None
371
+ hyperparameters: Optional[Dict[str, Any]] = None
372
+ source: Optional[str] = "trainer"
373
+
374
+ def __post_init__(self):
375
+ # Infer default license from the checkpoint used, if possible.
376
+ if (
377
+ self.license is None
378
+ and not is_offline_mode()
379
+ and self.finetuned_from is not None
380
+ and len(self.finetuned_from) > 0
381
+ ):
382
+ try:
383
+ info = model_info(self.finetuned_from)
384
+ for tag in info.tags:
385
+ if tag.startswith("license:"):
386
+ self.license = tag[8:]
387
+ except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError, HFValidationError):
388
+ pass
389
+
390
+ def create_model_index(self, metric_mapping):
391
+ model_index = {"name": self.model_name}
392
+
393
+ # Dataset mapping tag -> name
394
+ dataset_names = _listify(self.dataset)
395
+ dataset_tags = _listify(self.dataset_tags)
396
+ dataset_args = _listify(self.dataset_args)
397
+ dataset_metadata = _listify(self.dataset_metadata)
398
+ if len(dataset_args) < len(dataset_tags):
399
+ dataset_args = dataset_args + [None] * (len(dataset_tags) - len(dataset_args))
400
+ dataset_mapping = dict(zip(dataset_tags, dataset_names))
401
+ dataset_arg_mapping = dict(zip(dataset_tags, dataset_args))
402
+ dataset_metadata_mapping = dict(zip(dataset_tags, dataset_metadata))
403
+
404
+ task_mapping = {
405
+ task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
406
+ }
407
+
408
+ model_index["results"] = []
409
+
410
+ if len(task_mapping) == 0 and len(dataset_mapping) == 0:
411
+ return [model_index]
412
+ if len(task_mapping) == 0:
413
+ task_mapping = {None: None}
414
+ if len(dataset_mapping) == 0:
415
+ dataset_mapping = {None: None}
416
+
417
+ # One entry per dataset and per task
418
+ all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
419
+ for task_tag, ds_tag in all_possibilities:
420
+ result = {}
421
+ if task_tag is not None:
422
+ result["task"] = {"name": task_mapping[task_tag], "type": task_tag}
423
+
424
+ if ds_tag is not None:
425
+ metadata = dataset_metadata_mapping.get(ds_tag, {})
426
+ result["dataset"] = {
427
+ "name": dataset_mapping[ds_tag],
428
+ "type": ds_tag,
429
+ **metadata,
430
+ }
431
+ if dataset_arg_mapping[ds_tag] is not None:
432
+ result["dataset"]["args"] = dataset_arg_mapping[ds_tag]
433
+
434
+ if len(metric_mapping) > 0:
435
+ result["metrics"] = []
436
+ for metric_tag, metric_name in metric_mapping.items():
437
+ result["metrics"].append(
438
+ {
439
+ "name": metric_name,
440
+ "type": metric_tag,
441
+ "value": self.eval_results[metric_name],
442
+ }
443
+ )
444
+
445
+ # Remove partial results to avoid the model card being rejected.
446
+ if "task" in result and "dataset" in result and "metrics" in result:
447
+ model_index["results"].append(result)
448
+ else:
449
+ logger.info(f"Dropping the following result as it does not have all the necessary fields:\n{result}")
450
+
451
+ return [model_index]
452
+
453
+ def create_metadata(self):
454
+ metric_mapping = infer_metric_tags_from_eval_results(self.eval_results)
455
+
456
+ metadata = {}
457
+ metadata = _insert_value(metadata, "library_name", "transformers")
458
+ metadata = _insert_values_as_list(metadata, "language", self.language)
459
+ metadata = _insert_value(metadata, "license", self.license)
460
+ if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0:
461
+ metadata = _insert_value(metadata, "base_model", self.finetuned_from)
462
+ metadata = _insert_values_as_list(metadata, "tags", self.tags)
463
+ metadata = _insert_values_as_list(metadata, "datasets", self.dataset_tags)
464
+ metadata = _insert_values_as_list(metadata, "metrics", list(metric_mapping.keys()))
465
+ metadata["model-index"] = self.create_model_index(metric_mapping)
466
+
467
+ return metadata
468
+
469
+ def to_model_card(self):
470
+ model_card = ""
471
+
472
+ metadata = yaml.dump(self.create_metadata(), sort_keys=False)
473
+ if len(metadata) > 0:
474
+ model_card = f"---\n{metadata}---\n"
475
+
476
+ # Now the model card for realsies.
477
+ if self.source == "trainer":
478
+ model_card += AUTOGENERATED_TRAINER_COMMENT
479
+ else:
480
+ model_card += AUTOGENERATED_KERAS_COMMENT
481
+
482
+ model_card += f"\n# {self.model_name}\n\n"
483
+
484
+ if self.finetuned_from is None:
485
+ model_card += "This model was trained from scratch on "
486
+ else:
487
+ model_card += (
488
+ "This model is a fine-tuned version of"
489
+ f" [{self.finetuned_from}](https://huggingface.co/{self.finetuned_from}) on "
490
+ )
491
+
492
+ if self.dataset is None:
493
+ model_card += "an unknown dataset."
494
+ else:
495
+ if isinstance(self.dataset, str):
496
+ model_card += f"the {self.dataset} dataset."
497
+ elif isinstance(self.dataset, (tuple, list)) and len(self.dataset) == 1:
498
+ model_card += f"the {self.dataset[0]} dataset."
499
+ else:
500
+ model_card += (
501
+ ", ".join([f"the {ds}" for ds in self.dataset[:-1]]) + f" and the {self.dataset[-1]} datasets."
502
+ )
503
+
504
+ if self.eval_results is not None:
505
+ model_card += "\nIt achieves the following results on the evaluation set:\n"
506
+ model_card += "\n".join([f"- {name}: {_maybe_round(value)}" for name, value in self.eval_results.items()])
507
+ model_card += "\n"
508
+
509
+ model_card += "\n## Model description\n\nMore information needed\n"
510
+ model_card += "\n## Intended uses & limitations\n\nMore information needed\n"
511
+ model_card += "\n## Training and evaluation data\n\nMore information needed\n"
512
+
513
+ model_card += "\n## Training procedure\n"
514
+ model_card += "\n### Training hyperparameters\n"
515
+ if self.hyperparameters is not None:
516
+ model_card += "\nThe following hyperparameters were used during training:\n"
517
+ model_card += "\n".join([f"- {name}: {value}" for name, value in self.hyperparameters.items()])
518
+ model_card += "\n"
519
+ else:
520
+ model_card += "\nMore information needed\n"
521
+
522
+ if self.eval_lines is not None:
523
+ model_card += "\n### Training results\n\n"
524
+ model_card += make_markdown_table(self.eval_lines)
525
+ model_card += "\n"
526
+
527
+ model_card += "\n### Framework versions\n\n"
528
+ model_card += f"- Transformers {__version__}\n"
529
+
530
+ if self.source == "trainer" and is_torch_available():
531
+ import torch
532
+
533
+ model_card += f"- Pytorch {torch.__version__}\n"
534
+ elif self.source == "keras" and is_tf_available():
535
+ import tensorflow as tf
536
+
537
+ model_card += f"- TensorFlow {tf.__version__}\n"
538
+ if is_datasets_available():
539
+ import datasets
540
+
541
+ model_card += f"- Datasets {datasets.__version__}\n"
542
+ if is_tokenizers_available():
543
+ import tokenizers
544
+
545
+ model_card += f"- Tokenizers {tokenizers.__version__}\n"
546
+
547
+ return model_card
548
+
549
+ @classmethod
550
+ def from_trainer(
551
+ cls,
552
+ trainer,
553
+ language=None,
554
+ license=None,
555
+ tags=None,
556
+ model_name=None,
557
+ finetuned_from=None,
558
+ tasks=None,
559
+ dataset_tags=None,
560
+ dataset_metadata=None,
561
+ dataset=None,
562
+ dataset_args=None,
563
+ ):
564
+ # Infer default from dataset
565
+ one_dataset = trainer.eval_dataset if trainer.eval_dataset is not None else trainer.train_dataset
566
+ if is_hf_dataset(one_dataset) and (dataset_tags is None or dataset_args is None or dataset_metadata is None):
567
+ default_tag = one_dataset.builder_name
568
+ # Those are not real datasets from the Hub so we exclude them.
569
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
570
+ if dataset_metadata is None:
571
+ dataset_metadata = [{"config": one_dataset.config_name, "split": str(one_dataset.split)}]
572
+ if dataset_tags is None:
573
+ dataset_tags = [default_tag]
574
+ if dataset_args is None:
575
+ dataset_args = [one_dataset.config_name]
576
+
577
+ if dataset is None and dataset_tags is not None:
578
+ dataset = dataset_tags
579
+
580
+ # Infer default finetuned_from
581
+ if (
582
+ finetuned_from is None
583
+ and hasattr(trainer.model.config, "_name_or_path")
584
+ and not os.path.isdir(trainer.model.config._name_or_path)
585
+ ):
586
+ finetuned_from = trainer.model.config._name_or_path
587
+
588
+ # Infer default task tag:
589
+ if tasks is None:
590
+ model_class_name = trainer.model.__class__.__name__
591
+ for task, mapping in TASK_MAPPING.items():
592
+ if model_class_name in _get_mapping_values(mapping):
593
+ tasks = task
594
+
595
+ if model_name is None:
596
+ model_name = Path(trainer.args.output_dir).name
597
+ if len(model_name) == 0:
598
+ model_name = finetuned_from
599
+
600
+ # Add `generated_from_trainer` to the tags
601
+ if tags is None:
602
+ tags = ["generated_from_trainer"]
603
+ elif isinstance(tags, str) and tags != "generated_from_trainer":
604
+ tags = [tags, "generated_from_trainer"]
605
+ elif "generated_from_trainer" not in tags:
606
+ tags.append("generated_from_trainer")
607
+
608
+ _, eval_lines, eval_results = parse_log_history(trainer.state.log_history)
609
+ hyperparameters = extract_hyperparameters_from_trainer(trainer)
610
+
611
+ return cls(
612
+ language=language,
613
+ license=license,
614
+ tags=tags,
615
+ model_name=model_name,
616
+ finetuned_from=finetuned_from,
617
+ tasks=tasks,
618
+ dataset=dataset,
619
+ dataset_tags=dataset_tags,
620
+ dataset_args=dataset_args,
621
+ dataset_metadata=dataset_metadata,
622
+ eval_results=eval_results,
623
+ eval_lines=eval_lines,
624
+ hyperparameters=hyperparameters,
625
+ )
626
+
627
+ @classmethod
628
+ def from_keras(
629
+ cls,
630
+ model,
631
+ model_name,
632
+ keras_history=None,
633
+ language=None,
634
+ license=None,
635
+ tags=None,
636
+ finetuned_from=None,
637
+ tasks=None,
638
+ dataset_tags=None,
639
+ dataset=None,
640
+ dataset_args=None,
641
+ ):
642
+ # Infer default from dataset
643
+ if dataset is not None:
644
+ if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
645
+ default_tag = dataset.builder_name
646
+ # Those are not real datasets from the Hub so we exclude them.
647
+ if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
648
+ if dataset_tags is None:
649
+ dataset_tags = [default_tag]
650
+ if dataset_args is None:
651
+ dataset_args = [dataset.config_name]
652
+
653
+ if dataset is None and dataset_tags is not None:
654
+ dataset = dataset_tags
655
+
656
+ # Infer default finetuned_from
657
+ if (
658
+ finetuned_from is None
659
+ and hasattr(model.config, "_name_or_path")
660
+ and not os.path.isdir(model.config._name_or_path)
661
+ ):
662
+ finetuned_from = model.config._name_or_path
663
+
664
+ # Infer default task tag:
665
+ if tasks is None:
666
+ model_class_name = model.__class__.__name__
667
+ for task, mapping in TASK_MAPPING.items():
668
+ if model_class_name in _get_mapping_values(mapping):
669
+ tasks = task
670
+
671
+ # Add `generated_from_keras_callback` to the tags
672
+ if tags is None:
673
+ tags = ["generated_from_keras_callback"]
674
+ elif isinstance(tags, str) and tags != "generated_from_keras_callback":
675
+ tags = [tags, "generated_from_keras_callback"]
676
+ elif "generated_from_keras_callback" not in tags:
677
+ tags.append("generated_from_keras_callback")
678
+
679
+ if keras_history is not None:
680
+ _, eval_lines, eval_results = parse_keras_history(keras_history)
681
+ else:
682
+ eval_lines = []
683
+ eval_results = {}
684
+ hyperparameters = extract_hyperparameters_from_keras(model)
685
+
686
+ return cls(
687
+ language=language,
688
+ license=license,
689
+ tags=tags,
690
+ model_name=model_name,
691
+ finetuned_from=finetuned_from,
692
+ tasks=tasks,
693
+ dataset_tags=dataset_tags,
694
+ dataset=dataset,
695
+ dataset_args=dataset_args,
696
+ eval_results=eval_results,
697
+ eval_lines=eval_lines,
698
+ hyperparameters=hyperparameters,
699
+ source="keras",
700
+ )
701
+
702
+
703
+ def parse_keras_history(logs):
704
+ """
705
+ Parse the `logs` of either a `keras.History` object returned by `model.fit()` or an accumulated logs `dict`
706
+ passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
707
+ """
708
+ if hasattr(logs, "history"):
709
+ # This looks like a `History` object
710
+ if not hasattr(logs, "epoch"):
711
+ # This history looks empty, return empty results
712
+ return None, [], {}
713
+ logs.history["epoch"] = logs.epoch
714
+ logs = logs.history
715
+ else:
716
+ # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
717
+ logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
718
+
719
+ lines = []
720
+ for i in range(len(logs["epoch"])):
721
+ epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
722
+ values = {}
723
+ for k, v in epoch_dict.items():
724
+ if k.startswith("val_"):
725
+ k = "validation_" + k[4:]
726
+ elif k != "epoch":
727
+ k = "train_" + k
728
+ splits = k.split("_")
729
+ name = " ".join([part.capitalize() for part in splits])
730
+ values[name] = v
731
+ lines.append(values)
732
+
733
+ eval_results = lines[-1]
734
+
735
+ return logs, lines, eval_results
736
+
737
+
738
+ def parse_log_history(log_history):
739
+ """
740
+ Parse the `log_history` of a Trainer to get the intermediate and final evaluation results.
741
+ """
742
+ idx = 0
743
+ while idx < len(log_history) and "train_runtime" not in log_history[idx]:
744
+ idx += 1
745
+
746
+ # If there are no training logs
747
+ if idx == len(log_history):
748
+ idx -= 1
749
+ while idx >= 0 and "eval_loss" not in log_history[idx]:
750
+ idx -= 1
751
+
752
+ if idx >= 0:
753
+ return None, None, log_history[idx]
754
+ else:
755
+ return None, None, None
756
+
757
+ # From now on we can assume we have training logs:
758
+ train_log = log_history[idx]
759
+ lines = []
760
+ training_loss = "No log"
761
+ for i in range(idx):
762
+ if "loss" in log_history[i]:
763
+ training_loss = log_history[i]["loss"]
764
+ if "eval_loss" in log_history[i]:
765
+ metrics = log_history[i].copy()
766
+ _ = metrics.pop("total_flos", None)
767
+ epoch = metrics.pop("epoch", None)
768
+ step = metrics.pop("step", None)
769
+ _ = metrics.pop("eval_runtime", None)
770
+ _ = metrics.pop("eval_samples_per_second", None)
771
+ _ = metrics.pop("eval_steps_per_second", None)
772
+ _ = metrics.pop("eval_jit_compilation_time", None)
773
+ values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
774
+ for k, v in metrics.items():
775
+ if k == "eval_loss":
776
+ values["Validation Loss"] = v
777
+ else:
778
+ splits = k.split("_")
779
+ name = " ".join([part.capitalize() for part in splits[1:]])
780
+ values[name] = v
781
+ lines.append(values)
782
+
783
+ idx = len(log_history) - 1
784
+ while idx >= 0 and "eval_loss" not in log_history[idx]:
785
+ idx -= 1
786
+
787
+ if idx > 0:
788
+ eval_results = {}
789
+ for key, value in log_history[idx].items():
790
+ if key.startswith("eval_"):
791
+ key = key[5:]
792
+ if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
793
+ camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
794
+ eval_results[camel_cased_key] = value
795
+ return train_log, lines, eval_results
796
+ else:
797
+ return train_log, lines, None
798
+
799
+
800
+ def extract_hyperparameters_from_keras(model):
801
+ from .modeling_tf_utils import keras
802
+
803
+ hyperparameters = {}
804
+ if hasattr(model, "optimizer") and model.optimizer is not None:
805
+ hyperparameters["optimizer"] = model.optimizer.get_config()
806
+ else:
807
+ hyperparameters["optimizer"] = None
808
+ hyperparameters["training_precision"] = keras.mixed_precision.global_policy().name
809
+
810
+ return hyperparameters
811
+
812
+
813
+ def _maybe_round(v, decimals=4):
814
+ if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
815
+ return f"{v:.{decimals}f}"
816
+ return str(v)
817
+
818
+
819
+ def _regular_table_line(values, col_widths):
820
+ values_with_space = [f"| {v}" + " " * (w - len(v) + 1) for v, w in zip(values, col_widths)]
821
+ return "".join(values_with_space) + "|\n"
822
+
823
+
824
+ def _second_table_line(col_widths):
825
+ values = ["|:" + "-" * w + ":" for w in col_widths]
826
+ return "".join(values) + "|\n"
827
+
828
+
829
+ def make_markdown_table(lines):
830
+ """
831
+ Create a nice Markdown table from the results in `lines`.
832
+ """
833
+ if lines is None or len(lines) == 0:
834
+ return ""
835
+ col_widths = {key: len(str(key)) for key in lines[0].keys()}
836
+ for line in lines:
837
+ for key, value in line.items():
838
+ if col_widths[key] < len(_maybe_round(value)):
839
+ col_widths[key] = len(_maybe_round(value))
840
+
841
+ table = _regular_table_line(list(lines[0].keys()), list(col_widths.values()))
842
+ table += _second_table_line(list(col_widths.values()))
843
+ for line in lines:
844
+ table += _regular_table_line([_maybe_round(v) for v in line.values()], list(col_widths.values()))
845
+ return table
846
+
847
+
848
+ _TRAINING_ARGS_KEYS = [
849
+ "learning_rate",
850
+ "train_batch_size",
851
+ "eval_batch_size",
852
+ "seed",
853
+ ]
854
+
855
+
856
+ def extract_hyperparameters_from_trainer(trainer):
857
+ hyperparameters = {k: getattr(trainer.args, k) for k in _TRAINING_ARGS_KEYS}
858
+
859
+ if trainer.args.parallel_mode not in [ParallelMode.NOT_PARALLEL, ParallelMode.NOT_DISTRIBUTED]:
860
+ hyperparameters["distributed_type"] = (
861
+ "multi-GPU" if trainer.args.parallel_mode == ParallelMode.DISTRIBUTED else trainer.args.parallel_mode.value
862
+ )
863
+ if trainer.args.world_size > 1:
864
+ hyperparameters["num_devices"] = trainer.args.world_size
865
+ if trainer.args.gradient_accumulation_steps > 1:
866
+ hyperparameters["gradient_accumulation_steps"] = trainer.args.gradient_accumulation_steps
867
+
868
+ total_train_batch_size = (
869
+ trainer.args.train_batch_size * trainer.args.world_size * trainer.args.gradient_accumulation_steps
870
+ )
871
+ if total_train_batch_size != hyperparameters["train_batch_size"]:
872
+ hyperparameters["total_train_batch_size"] = total_train_batch_size
873
+ total_eval_batch_size = trainer.args.eval_batch_size * trainer.args.world_size
874
+ if total_eval_batch_size != hyperparameters["eval_batch_size"]:
875
+ hyperparameters["total_eval_batch_size"] = total_eval_batch_size
876
+
877
+ if trainer.args.optim:
878
+ optimizer_name = trainer.args.optim
879
+ optimizer_args = trainer.args.optim_args if trainer.args.optim_args else "No additional optimizer arguments"
880
+
881
+ if "adam" in optimizer_name.lower():
882
+ hyperparameters["optimizer"] = (
883
+ f"Use {optimizer_name} with betas=({trainer.args.adam_beta1},{trainer.args.adam_beta2}) and"
884
+ f" epsilon={trainer.args.adam_epsilon} and optimizer_args={optimizer_args}"
885
+ )
886
+ else:
887
+ hyperparameters["optimizer"] = f"Use {optimizer_name} and the args are:\n{optimizer_args}"
888
+
889
+ hyperparameters["lr_scheduler_type"] = trainer.args.lr_scheduler_type.value
890
+ if trainer.args.warmup_ratio != 0.0:
891
+ hyperparameters["lr_scheduler_warmup_ratio"] = trainer.args.warmup_ratio
892
+ if trainer.args.warmup_steps != 0.0:
893
+ hyperparameters["lr_scheduler_warmup_steps"] = trainer.args.warmup_steps
894
+ if trainer.args.max_steps != -1:
895
+ hyperparameters["training_steps"] = trainer.args.max_steps
896
+ else:
897
+ hyperparameters["num_epochs"] = trainer.args.num_train_epochs
898
+
899
+ if trainer.args.fp16:
900
+ if trainer.use_apex:
901
+ hyperparameters["mixed_precision_training"] = f"Apex, opt level {trainer.args.fp16_opt_level}"
902
+ else:
903
+ hyperparameters["mixed_precision_training"] = "Native AMP"
904
+
905
+ if trainer.args.label_smoothing_factor != 0.0:
906
+ hyperparameters["label_smoothing_factor"] = trainer.args.label_smoothing_factor
907
+
908
+ return hyperparameters
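
A quick sanity check on the table helpers defined above (not part of the uploaded file): the sketch below feeds two invented evaluation rows through `make_markdown_table` and prints the resulting Markdown. The import path assumes this module lands at its upstream location, `transformers.modelcard`; the rows themselves are made up for illustration only.

```python
# Minimal sketch, assuming the module is importable as `transformers.modelcard`.
# The evaluation rows are invented; floats with more than 4 decimals are rounded by `_maybe_round`.
from transformers.modelcard import make_markdown_table

eval_lines = [
    {"Training Loss": 0.6931, "Epoch": 1.0, "Step": 500, "Validation Loss": 0.58214, "Accuracy": 0.7512},
    {"Training Loss": 0.4123, "Epoch": 2.0, "Step": 1000, "Validation Loss": 0.51037, "Accuracy": 0.7904},
]

print(make_markdown_table(eval_lines))
```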
modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,481 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+
19
+ from .utils.import_utils import is_torchdynamo_compiling
20
+
21
+
22
+ @dataclass
23
+ class AttentionMaskConverter:
24
+ """
25
+ A utility attention mask class that allows one to:
26
+ - Create a causal 4d mask
27
+ - Create a causal 4d mask with sliding window
28
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
29
+ key_value_length) that can be multiplied with attention scores
30
+
31
+ Examples:
32
+
33
+ ```python
34
+ >>> import torch
35
+ >>> from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+
37
+ >>> converter = AttentionMaskConverter(True)
38
+ >>> converter.to_4d(torch.tensor([[0, 0, 0, 1, 1]]), 5, key_value_length=5, dtype=torch.float32)
39
+ tensor([[[[-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
40
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
41
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38],
42
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, -3.4028e+38],
43
+ [-3.4028e+38, -3.4028e+38, -3.4028e+38, 0.0000e+00, 0.0000e+00]]]])
44
+ ```
45
+
46
+ Parameters:
47
+ is_causal (`bool`):
48
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
49
+
50
+ sliding_window (`int`, *optional*):
51
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
52
+ """
53
+
54
+ is_causal: bool
55
+ sliding_window: int
56
+
57
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
58
+ self.is_causal = is_causal
59
+ self.sliding_window = sliding_window
60
+
61
+ if self.sliding_window is not None and self.sliding_window <= 0:
62
+ raise ValueError(
63
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
64
+ )
65
+
66
+ def to_causal_4d(
67
+ self,
68
+ batch_size: int,
69
+ query_length: int,
70
+ key_value_length: int,
71
+ dtype: torch.dtype,
72
+ device: Union[torch.device, "str"] = "cpu",
73
+ ) -> Optional[torch.Tensor]:
74
+ """
75
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
76
+ bias to upper right hand triangular matrix (causal mask).
77
+ """
78
+ if not self.is_causal:
79
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
80
+
81
+ # If shape is not cached, create a new causal mask and cache it
82
+ input_shape = (batch_size, query_length)
83
+ past_key_values_length = key_value_length - query_length
84
+
85
+ # create causal mask
86
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
87
+ causal_4d_mask = None
88
+ if input_shape[-1] > 1 or self.sliding_window is not None:
89
+ causal_4d_mask = self._make_causal_mask(
90
+ input_shape,
91
+ dtype,
92
+ device=device,
93
+ past_key_values_length=past_key_values_length,
94
+ sliding_window=self.sliding_window,
95
+ )
96
+
97
+ return causal_4d_mask
98
+
99
+ def to_4d(
100
+ self,
101
+ attention_mask_2d: torch.Tensor,
102
+ query_length: int,
103
+ dtype: torch.dtype,
104
+ key_value_length: Optional[int] = None,
105
+ ) -> torch.Tensor:
106
+ """
107
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
108
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
109
+ causal, a causal mask will be added.
110
+ """
111
+ input_shape = (attention_mask_2d.shape[0], query_length)
112
+
113
+ # create causal mask
114
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
115
+ causal_4d_mask = None
116
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
117
+ if key_value_length is None:
118
+ raise ValueError(
119
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
120
+ )
121
+
122
+ past_key_values_length = key_value_length - query_length
123
+ causal_4d_mask = self._make_causal_mask(
124
+ input_shape,
125
+ dtype,
126
+ device=attention_mask_2d.device,
127
+ past_key_values_length=past_key_values_length,
128
+ sliding_window=self.sliding_window,
129
+ )
130
+ elif self.sliding_window is not None:
131
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
132
+
133
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
134
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
135
+ attention_mask_2d.device
136
+ )
137
+
138
+ if causal_4d_mask is not None:
139
+ expanded_attn_mask = causal_4d_mask.masked_fill(expanded_attn_mask.bool(), torch.finfo(dtype).min)
140
+
141
+ # expanded_attn_mask + causal_4d_mask can cause some overflow
142
+ expanded_4d_mask = expanded_attn_mask
143
+
144
+ return expanded_4d_mask
145
+
146
+ @staticmethod
147
+ def _make_causal_mask(
148
+ input_ids_shape: torch.Size,
149
+ dtype: torch.dtype,
150
+ device: torch.device,
151
+ past_key_values_length: int = 0,
152
+ sliding_window: Optional[int] = None,
153
+ ):
154
+ """
155
+ Make causal mask used for bi-directional self-attention.
156
+ """
157
+ bsz, tgt_len = input_ids_shape
158
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
159
+ mask_cond = torch.arange(mask.size(-1), device=device)
160
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
161
+
162
+ mask = mask.to(dtype)
163
+
164
+ if past_key_values_length > 0:
165
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
166
+
167
+ # add lower triangular sliding window mask if necessary
168
+ if sliding_window is not None:
169
+ diagonal = past_key_values_length - sliding_window - 1
170
+
171
+ context_mask = torch.tril(torch.ones_like(mask, dtype=torch.bool), diagonal=diagonal)
172
+ # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
173
+ # See https://github.com/pytorch/pytorch/issues/127571
174
+ if is_torchdynamo_compiling():
175
+ mask = mask.clone()
176
+ mask.masked_fill_(context_mask, torch.finfo(dtype).min)
177
+
178
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
179
+
180
+ @staticmethod
181
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
182
+ """
183
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
184
+ """
185
+ bsz, src_len = mask.size()
186
+ tgt_len = tgt_len if tgt_len is not None else src_len
187
+
188
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
189
+
190
+ inverted_mask = 1.0 - expanded_mask
191
+
192
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
193
+
194
+ @staticmethod
195
+ def _unmask_unattended(
196
+ expanded_mask: torch.FloatTensor,
197
+ min_dtype: float,
198
+ ):
199
+ # fmt: off
200
+ """
201
+ Attend to all tokens in masked rows from the expanded attention mask, for example the relevant first rows when
202
+ using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
203
+ Details: https://github.com/pytorch/pytorch/issues/110213
204
+
205
+ `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
206
+ `attention_mask` is [bsz, src_seq_len].
207
+
208
+ The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
209
+
210
+ For example, if `expanded_mask` is (e.g. here left-padding case)
211
+ ```
212
+ [[[[0, 0, 0],
213
+ [0, 0, 0],
214
+ [0, 0, 1]]],
215
+ [[[1, 0, 0],
216
+ [1, 1, 0],
217
+ [1, 1, 1]]],
218
+ [[[0, 0, 0],
219
+ [0, 1, 0],
220
+ [0, 1, 1]]]]
221
+ ```
222
+ then the modified `expanded_mask` will be
223
+ ```
224
+ [[[[1, 1, 1], <-- modified
225
+ [1, 1, 1], <-- modified
226
+ [0, 0, 1]]],
227
+ [[[1, 0, 0],
228
+ [1, 1, 0],
229
+ [1, 1, 1]]],
230
+ [[[1, 1, 1], <-- modified
231
+ [0, 1, 0],
232
+ [0, 1, 1]]]]
233
+ ```
234
+ """
235
+ # fmt: on
236
+ if expanded_mask.dtype == torch.bool:
237
+ raise ValueError(
238
+ "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
239
+ )
240
+
241
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
242
+
243
+ @staticmethod
244
+ def _ignore_causal_mask_sdpa(
245
+ attention_mask: Optional[torch.Tensor],
246
+ inputs_embeds: torch.Tensor,
247
+ past_key_values_length: int,
248
+ sliding_window: Optional[int] = None,
249
+ is_training: bool = False,
250
+ ) -> bool:
251
+ """
252
+ Detects whether the optional user-specified attention_mask & the automatically created causal mask can be
253
+ ignored in case PyTorch's SDPA is used, relying instead on SDPA's `is_causal` argument.
254
+
255
+ In case no token is masked in the `attention_mask` argument, if `query_length == 1` or
256
+ `key_value_length == query_length`, we rather rely on SDPA `is_causal` argument to use causal/non-causal masks,
257
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
258
+ passed).
259
+ """
260
+
261
+ _, query_length = inputs_embeds.shape[0], inputs_embeds.shape[1]
262
+ key_value_length = query_length + past_key_values_length
263
+
264
+ is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
265
+
266
+ ignore_causal_mask = False
267
+
268
+ if attention_mask is None:
269
+ # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input
270
+ # shape, thus SDPA's `is_causal` argument is rightfully updated
271
+ # (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using
272
+ # `torch.export` or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is
273
+ # hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True`
274
+ # which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
275
+ # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
276
+ #
277
+ # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal`
278
+ # ("TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor").
279
+ if (
280
+ (is_training or not is_tracing)
281
+ and (query_length == 1 or key_value_length == query_length)
282
+ and (sliding_window is None or key_value_length < sliding_window)
283
+ ):
284
+ ignore_causal_mask = True
285
+ elif sliding_window is None or key_value_length < sliding_window:
286
+ if len(attention_mask.shape) == 4:
287
+ return False
288
+ elif not is_tracing and torch.all(attention_mask == 1):
289
+ if query_length == 1 or key_value_length == query_length:
290
+ # For query_length == 1, causal attention and bi-directional attention are the same.
291
+ ignore_causal_mask = True
292
+
293
+ # Unfortunately, for query_length > 1 and key_value_length != query_length, we cannot generally ignore
294
+ # the attention mask, as SDPA causal mask generation may be wrong. We will set `is_causal=False` in
295
+ # SDPA and rely on Transformers attention_mask instead, hence not setting it to None here.
296
+ # Reference: https://github.com/pytorch/pytorch/issues/108108
297
+ # TODO: maybe revisit this with https://github.com/pytorch/pytorch/pull/114823 in PyTorch 2.3.
298
+
299
+ return ignore_causal_mask
300
+
301
+
302
+ def _prepare_4d_causal_attention_mask(
303
+ attention_mask: Optional[torch.Tensor],
304
+ input_shape: Union[torch.Size, Tuple, List],
305
+ inputs_embeds: torch.Tensor,
306
+ past_key_values_length: int,
307
+ sliding_window: Optional[int] = None,
308
+ ):
309
+ """
310
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
311
+ `(batch_size, key_value_length)`
312
+
313
+ Args:
314
+ attention_mask (`torch.Tensor` or `None`):
315
+ A 2D attention mask of shape `(batch_size, key_value_length)`
316
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
317
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
318
+ inputs_embeds (`torch.Tensor`):
319
+ The embedded inputs as a torch Tensor.
320
+ past_key_values_length (`int`):
321
+ The length of the key value cache.
322
+ sliding_window (`int`, *optional*):
323
+ If the model uses windowed attention, a sliding window should be passed.
324
+ """
325
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
326
+
327
+ key_value_length = input_shape[-1] + past_key_values_length
328
+
329
+ # 4d mask is passed through the layers
330
+ if attention_mask is not None and len(attention_mask.shape) == 2:
331
+ attention_mask = attn_mask_converter.to_4d(
332
+ attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
333
+ )
334
+ elif attention_mask is not None and len(attention_mask.shape) == 4:
335
+ expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
336
+ if tuple(attention_mask.shape) != expected_shape:
337
+ raise ValueError(
338
+ f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
339
+ )
340
+ else:
341
+ # if the 4D mask has correct shape - invert it and fill with negative infinity
342
+ inverted_mask = 1.0 - attention_mask
343
+ attention_mask = inverted_mask.masked_fill(
344
+ inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
345
+ )
346
+ else:
347
+ attention_mask = attn_mask_converter.to_causal_4d(
348
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
349
+ )
350
+
351
+ return attention_mask
352
+
353
+
354
+ # Adapted from _prepare_4d_causal_attention_mask
355
+ def _prepare_4d_causal_attention_mask_for_sdpa(
356
+ attention_mask: Optional[torch.Tensor],
357
+ input_shape: Union[torch.Size, Tuple, List],
358
+ inputs_embeds: torch.Tensor,
359
+ past_key_values_length: int,
360
+ sliding_window: Optional[int] = None,
361
+ ):
362
+ """
363
+ Prepares the correct `attn_mask` argument to be used by `torch.nn.functional.scaled_dot_product_attention`.
364
+
365
+ In case no token is masked in the `attention_mask` argument, we simply set it to `None` for the cases `query_length == 1` and
366
+ `key_value_length == query_length`, and rely instead on SDPA `is_causal` argument to use causal/non-causal masks,
367
+ allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is passed).
368
+ """
369
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
370
+
371
+ key_value_length = input_shape[-1] + past_key_values_length
372
+
373
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
374
+ # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
375
+ # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
376
+ is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy) or is_torchdynamo_compiling()
377
+
378
+ ignore_causal_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
379
+ attention_mask=attention_mask,
380
+ inputs_embeds=inputs_embeds,
381
+ past_key_values_length=past_key_values_length,
382
+ sliding_window=sliding_window,
383
+ )
384
+
385
+ if ignore_causal_mask:
386
+ expanded_4d_mask = None
387
+ elif attention_mask is None:
388
+ expanded_4d_mask = attn_mask_converter.to_causal_4d(
389
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
390
+ )
391
+ else:
392
+ if attention_mask.dim() == 4:
393
+ expanded_4d_mask = attention_mask
394
+ else:
395
+ expanded_4d_mask = attn_mask_converter.to_4d(
396
+ attention_mask,
397
+ input_shape[-1],
398
+ dtype=inputs_embeds.dtype,
399
+ key_value_length=key_value_length,
400
+ )
401
+
402
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
403
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
404
+ # Details: https://github.com/pytorch/pytorch/issues/110213
405
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
406
+ expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
407
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
408
+ )
409
+
410
+ return expanded_4d_mask
411
+
412
+
413
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
414
+ """
415
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
416
+ `(batch_size, key_value_length)`
417
+
418
+ Args:
419
+ mask (`torch.Tensor`):
420
+ A 2D attention mask of shape `(batch_size, key_value_length)`
421
+ dtype (`torch.dtype`):
422
+ The torch dtype the created mask shall have.
423
+ tgt_len (`int`):
424
+ The target length or query length the created mask shall have.
425
+ """
426
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
427
+
428
+
429
+ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
430
+ """
431
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
432
+ `(batch_size, key_value_length)`
433
+
434
+ Args:
435
+ mask (`torch.Tensor`):
436
+ A 2D attention mask of shape `(batch_size, key_value_length)`
437
+ dtype (`torch.dtype`):
438
+ The torch dtype the created mask shall have.
439
+ tgt_len (`int`):
440
+ The target length or query length the created mask shall have.
441
+ """
442
+ _, key_value_length = mask.shape
443
+ tgt_len = tgt_len if tgt_len is not None else key_value_length
444
+
445
+ is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy) or is_torchdynamo_compiling()
446
+
447
+ # torch.jit.trace, symbolic_trace and torchdynamo with fullgraph=True are unable to capture data-dependent controlflows.
448
+ if not is_tracing and torch.all(mask == 1):
449
+ return None
450
+ else:
451
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
452
+
453
+
454
+ def _create_4d_causal_attention_mask(
455
+ input_shape: Union[torch.Size, Tuple, List],
456
+ dtype: torch.dtype,
457
+ device: torch.device,
458
+ past_key_values_length: int = 0,
459
+ sliding_window: Optional[int] = None,
460
+ ) -> Optional[torch.Tensor]:
461
+ """
462
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
463
+
464
+ Args:
465
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
466
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
467
+ dtype (`torch.dtype`):
468
+ The torch dtype the created mask shall have.
469
+ device (`int`):
470
+ The torch device the created mask shall have.
471
+ sliding_window (`int`, *optional*):
472
+ If the model uses windowed attention, a sliding window should be passed.
473
+ """
474
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
475
+
476
+ key_value_length = past_key_values_length + input_shape[-1]
477
+ attention_mask = attn_mask_converter.to_causal_4d(
478
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
479
+ )
480
+
481
+ return attention_mask
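
As a rough usage sketch of the helpers above (not part of the uploaded file): converting a padded 2D mask into the additive 4D causal mask that decoder models consume. The import path assumes the upstream location `transformers.modeling_attn_mask_utils`; the shapes and mask values are arbitrary examples.

```python
# Minimal sketch, assuming the module is importable as `transformers.modeling_attn_mask_utils`.
import torch

from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_len, hidden = 2, 4, 8
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])  # second sequence is left-padded

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_len),
    inputs_embeds=inputs_embeds,
    past_key_values_length=0,
)
# Shape (2, 1, 4, 4): 0.0 where attention is allowed, torch.finfo(torch.float32).min elsewhere.
print(mask_4d.shape)
```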
modeling_flash_attention_utils.py ADDED
@@ -0,0 +1,389 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import os
18
+ from typing import Optional, Tuple, TypedDict
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+ from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ if is_flash_attn_2_available():
30
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
32
+
33
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
34
+
35
+
36
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
37
+ """
38
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
39
+
40
+ Arguments:
41
+ attention_mask (`torch.Tensor`):
42
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
43
+
44
+ Return:
45
+ indices (`torch.Tensor`):
46
+ The indices of non-masked tokens from the flattened input sequence.
47
+ cu_seqlens (`torch.Tensor`):
48
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
49
+ max_seqlen_in_batch (`int`):
50
+ Maximum sequence length in batch.
51
+ """
52
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
53
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
54
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
55
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
56
+ return (
57
+ indices,
58
+ cu_seqlens,
59
+ max_seqlen_in_batch,
60
+ )
61
+
62
+
63
+ def _upad_input(
64
+ query_layer: torch.Tensor,
65
+ key_layer: torch.Tensor,
66
+ value_layer: torch.Tensor,
67
+ attention_mask: torch.Tensor,
68
+ query_length: int,
69
+ ):
70
+ """
71
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
72
+
73
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
74
+ tensors for query, key, value tensors.
75
+
76
+ Arguments:
77
+ query_layer (`torch.Tensor`):
78
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
79
+ key_layer (`torch.Tensor`):
80
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
81
+ value_layer (`torch.Tensor`):
82
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
83
+ attention_mask (`torch.Tensor`):
84
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
85
+ query_length (`int`):
86
+ Target length.
87
+
88
+ Return:
89
+ query_layer (`torch.Tensor`):
90
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
91
+ key_layer (`torch.Tensor`):
92
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
93
+ value_layer (`torch.Tensor`):
94
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
95
+ indices_q (`torch.Tensor`):
96
+ The indices of non-masked tokens from the flattened input target sequence.
97
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
98
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
99
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
100
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
101
+ """
102
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
103
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
104
+
105
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
106
+ value_layer = index_first_axis(
107
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
108
+ )
109
+ if query_length == kv_seq_len:
110
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
111
+ cu_seqlens_q = cu_seqlens_k
112
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
113
+ indices_q = indices_k
114
+ elif query_length == 1:
115
+ max_seqlen_in_batch_q = 1
116
+ cu_seqlens_q = torch.arange(
117
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
118
+ ) # There is a memcpy here, that is very bad.
119
+ indices_q = cu_seqlens_q[:-1]
120
+ query_layer = query_layer.squeeze(1)
121
+ else:
122
+ # The -q_len: slice assumes left padding.
123
+ attention_mask = attention_mask[:, -query_length:]
124
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
125
+
126
+ return (
127
+ query_layer,
128
+ key_layer,
129
+ value_layer,
130
+ indices_q,
131
+ (cu_seqlens_q, cu_seqlens_k),
132
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
133
+ )
134
+
135
+
136
+ def prepare_fa2_from_position_ids(query, key, value, position_ids):
137
+ """
138
+ This function returns necessary arguments to call `flash_attn_varlen_func`.
139
+ All three query, key, value states will be flattened.
140
+ Cumulative lengths of each example in the batch will be extracted from position_ids.
141
+
142
+ NOTE: ideally, cumulative lengths should be prepared at the data collator stage
143
+
144
+ Arguments:
145
+ query (`torch.Tensor`):
146
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
147
+ key (`torch.Tensor`):
148
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
149
+ value (`torch.Tensor`):
150
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
151
+ position_ids (`torch.Tensor`):
152
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
153
+
154
+ Return:
155
+ query (`torch.Tensor`):
156
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
157
+ key (`torch.Tensor`):
158
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
159
+ value (`torch.Tensor`):
160
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
161
+ indices_q (`torch.Tensor`):
162
+ The indices of non-masked tokens from the flattened input target sequence.
163
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
164
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
165
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
166
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
167
+ """
168
+ query = query.view(-1, query.size(-2), query.size(-1))
169
+ key = key.contiguous().view(-1, key.size(-2), key.size(-1))
170
+ value = value.contiguous().view(-1, value.size(-2), value.size(-1))
171
+ position_ids = position_ids.flatten()
172
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
173
+
174
+ cu_seq_lens = torch.cat(
175
+ (
176
+ indices_q[position_ids == 0],
177
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
178
+ )
179
+ )
180
+
181
+ max_length = position_ids.max() + 1
182
+
183
+ return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
184
+
185
+
186
+ def fa_peft_integration_check(
187
+ query: torch.Tensor,
188
+ key: torch.Tensor,
189
+ value: torch.Tensor,
190
+ target_dtype: Optional[torch.dtype] = None,
191
+ ):
192
+ """
193
+ PEFT usually casts the layer norms in float32 for training stability reasons
194
+ therefore the input hidden states get silently cast to float32. Hence, we need to
195
+ cast them back to float16 / bfloat16 just to be sure everything works as expected.
196
+ This might slow down training & inference, so it is recommended not to cast the LayerNorms!
197
+
198
+ Args:
199
+ query (`torch.Tensor`):
200
+ Input query states to be passed to Flash Attention API
201
+ key (`torch.Tensor`):
202
+ Input key states to be passed to Flash Attention API
203
+ value (`torch.Tensor`):
204
+ Input value states to be passed to Flash Attention API
205
+ target_dtype (`torch.dtype`, *optional*):
206
+ The dtype to convert the attention tensors to. Conversion can be ignored by
207
+ not providing the target dtype.
208
+ """
209
+ if target_dtype is None:
210
+ return query, key, value
211
+
212
+ input_dtype = value.dtype
213
+ if input_dtype == torch.float32:
214
+ logger.warning_once(
215
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
216
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
217
+ f" {target_dtype}."
218
+ )
219
+
220
+ query = query.to(target_dtype)
221
+ key = key.to(target_dtype)
222
+ value = value.to(target_dtype)
223
+
224
+ return query, key, value
225
+
226
+
227
+ flash_241 = is_flash_attn_greater_or_equal("2.4.1")
228
+ deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
229
+
230
+
231
+ def _flash_attention_forward(
232
+ query_states: torch.Tensor,
233
+ key_states: torch.Tensor,
234
+ value_states: torch.Tensor,
235
+ attention_mask: torch.Tensor,
236
+ query_length: int,
237
+ is_causal: bool,
238
+ dropout: float = 0.0,
239
+ position_ids: Optional[torch.Tensor] = None,
240
+ softmax_scale: Optional[float] = None,
241
+ sliding_window: Optional[int] = None,
242
+ use_top_left_mask: bool = False,
243
+ softcap: Optional[float] = None,
244
+ deterministic: bool = None,
245
+ cu_seq_lens_q: Optional[torch.LongTensor] = None,
246
+ cu_seq_lens_k: Optional[torch.LongTensor] = None,
247
+ max_length_q: Optional[int] = None,
248
+ max_length_k: Optional[int] = None,
249
+ target_dtype: Optional[torch.dtype] = None,
250
+ **kwargs,
251
+ ):
252
+ """
253
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
254
+ first unpad the input, then computes the attention scores and pad the final attention scores.
255
+
256
+ Args:
257
+ query_states (`torch.Tensor`):
258
+ Input query states to be passed to Flash Attention API
259
+ key_states (`torch.Tensor`):
260
+ Input key states to be passed to Flash Attention API
261
+ value_states (`torch.Tensor`):
262
+ Input value states to be passed to Flash Attention API
263
+ attention_mask (`torch.Tensor`):
264
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
265
+ position of padding tokens and 1 for the position of non-padding tokens.
266
+ dropout (`float`):
267
+ Attention dropout
268
+ softmax_scale (`float`, *optional*):
269
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
270
+ use_top_left_mask (`bool`, defaults to `False`):
271
+ flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
272
+ softcap (`float`, *optional*):
273
+ Softcap for the attention logits, used e.g. in gemma2.
274
+ deterministic (`bool`, *optional*):
275
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
276
+ """
277
+ if not use_top_left_mask:
278
+ causal = is_causal
279
+ else:
280
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
281
+ causal = is_causal and query_length != 1
282
+
283
+ # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
284
+ use_sliding_windows = (
285
+ _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
286
+ )
287
+ flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
288
+
289
+ if flash_241:
290
+ if deterministic is None:
291
+ deterministic = deterministic_g
292
+ flash_kwargs["deterministic"] = deterministic
293
+
294
+ if softcap is not None:
295
+ flash_kwargs["softcap"] = softcap
296
+
297
+ # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
298
+ query_states, key_states, value_states = fa_peft_integration_check(
299
+ query_states, key_states, value_states, target_dtype
300
+ )
301
+
302
+ # Contains at least one padding token in the sequence
303
+ if attention_mask is not None:
304
+ batch_size = query_states.shape[0]
305
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
306
+ query_states, key_states, value_states, attention_mask, query_length
307
+ )
308
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
309
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
310
+
311
+ attn_output_unpad = flash_attn_varlen_func(
312
+ query_states,
313
+ key_states,
314
+ value_states,
315
+ cu_seqlens_q=cu_seqlens_q,
316
+ cu_seqlens_k=cu_seqlens_k,
317
+ max_seqlen_q=max_seqlen_in_batch_q,
318
+ max_seqlen_k=max_seqlen_in_batch_k,
319
+ dropout_p=dropout,
320
+ softmax_scale=softmax_scale,
321
+ causal=causal,
322
+ **flash_kwargs,
323
+ )
324
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
325
+
326
+ # If position_ids is provided and not all examples contain only one sequence: if the tensor is monotonically increasing
327
+ # then we probably have one sequence, otherwise the batch is packed. Additionally, check that we are in the pre-fill/training stage.
328
+ # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
329
+ elif position_ids is not None and (
330
+ max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
331
+ ):
332
+ batch_size = query_states.size(0)
333
+
334
+ if cu_seq_lens_q is None or cu_seq_lens_k is None:
335
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
336
+ prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
337
+ )
338
+
339
+ cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
340
+ max_length_q, max_length_k = max_seq_lens
341
+
342
+ else:
343
+ query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
344
+ key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
345
+ value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))
346
+
347
+ attn_output = flash_attn_varlen_func(
348
+ query_states,
349
+ key_states,
350
+ value_states,
351
+ cu_seqlens_q=cu_seq_lens_q,
352
+ cu_seqlens_k=cu_seq_lens_k,
353
+ max_seqlen_q=max_length_q,
354
+ max_seqlen_k=max_length_k,
355
+ dropout_p=dropout,
356
+ softmax_scale=softmax_scale,
357
+ causal=causal,
358
+ **flash_kwargs,
359
+ )
360
+
361
+ attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
362
+
363
+ else:
364
+ attn_output = flash_attn_func(
365
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
366
+ )
367
+
368
+ return attn_output
369
+
370
+
371
+ class FlashAttentionKwargs(TypedDict, total=False):
372
+ """
373
+ Keyword arguments for Flash Attention with Compile.
374
+
375
+ Attributes:
376
+ cu_seq_lens_q (`torch.LongTensor`, *optional*)
377
+ Gets cumulative sequence length for query state.
378
+ cu_seq_lens_k (`torch.LongTensor`, *optional*)
379
+ Gets cumulative sequence length for key state.
380
+ max_length_q (`int`, *optional*):
381
+ Maximum sequence length for query state.
382
+ max_length_k (`int`, *optional*):
383
+ Maximum sequence length for key state.
384
+ """
385
+
386
+ cu_seq_lens_q: Optional[torch.LongTensor]
387
+ cu_seq_lens_k: Optional[torch.LongTensor]
388
+ max_length_q: Optional[int]
389
+ max_length_k: Optional[int]
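
A small sketch of the unpadding bookkeeping performed by `_get_unpad_data` before the varlen kernels are called (again, not part of the uploaded file). It assumes the upstream import path `transformers.modeling_flash_attention_utils`; the helper itself only needs PyTorch, so flash-attn does not have to be installed to try it.

```python
# Minimal sketch, assuming the module is importable as `transformers.modeling_flash_attention_utils`.
import torch

from transformers.modeling_flash_attention_utils import _get_unpad_data

# Two sequences of lengths 3 and 2; the second one is left-padded.
attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])

indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
print(indices)     # tensor([0, 1, 2, 4, 5]) -> flattened positions of the real tokens
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32) -> cumulative sequence lengths
print(max_seqlen)  # 3
```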
modeling_flax_outputs.py ADDED
@@ -0,0 +1,700 @@
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple
15
+
16
+ import flax
17
+ import jax.numpy as jnp
18
+
19
+ from .utils import ModelOutput
20
+
21
+
22
+ @flax.struct.dataclass
23
+ class FlaxBaseModelOutput(ModelOutput):
24
+ """
25
+ Base class for model's outputs, with potential hidden states and attentions.
26
+
27
+ Args:
28
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
29
+ Sequence of hidden-states at the output of the last layer of the model.
30
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
31
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
32
+ `(batch_size, sequence_length, hidden_size)`.
33
+
34
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
35
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
36
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
37
+ sequence_length)`.
38
+
39
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
40
+ heads.
41
+ """
42
+
43
+ last_hidden_state: jnp.ndarray = None
44
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
45
+ attentions: Optional[Tuple[jnp.ndarray]] = None
46
+
47
+
48
+ @flax.struct.dataclass
49
+ class FlaxBaseModelOutputWithNoAttention(ModelOutput):
50
+ """
51
+ Base class for model's outputs, with potential hidden states.
52
+
53
+ Args:
54
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
55
+ Sequence of hidden-states at the output of the last layer of the model.
56
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
57
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
58
+ for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
59
+ model at the output of each layer plus the optional initial embedding outputs.
60
+ """
61
+
62
+ last_hidden_state: jnp.ndarray = None
63
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
64
+
65
+
66
+ @flax.struct.dataclass
67
+ class FlaxBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
68
+ """
69
+ Base class for model's outputs that also contains a pooling of the last hidden states.
70
+
71
+ Args:
72
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):
73
+ Sequence of hidden-states at the output of the last layer of the model.
74
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
75
+ Last layer hidden-state after a pooling operation on the spatial dimensions.
76
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
77
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
78
+ for the output of each layer) of shape `(batch_size, num_channels, height, width)`. Hidden-states of the
79
+ model at the output of each layer plus the optional initial embedding outputs.
80
+ """
81
+
82
+ last_hidden_state: jnp.ndarray = None
83
+ pooler_output: jnp.ndarray = None
84
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
85
+
86
+
87
+ @flax.struct.dataclass
88
+ class FlaxImageClassifierOutputWithNoAttention(ModelOutput):
89
+ """
90
+ Base class for outputs of image classification models.
91
+
92
+ Args:
93
+ logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
94
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
95
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when
96
+ `config.output_hidden_states=True`):
97
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
98
+ for the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
99
+ called feature maps) of the model at the output of each stage.
100
+ """
101
+
102
+ logits: jnp.ndarray = None
103
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
104
+
105
+
106
+ @flax.struct.dataclass
107
+ class FlaxBaseModelOutputWithPast(ModelOutput):
108
+ """
109
+ Base class for model's outputs, with potential hidden states and attentions.
110
+
111
+ Args:
112
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
113
+ Sequence of hidden-states at the output of the last layer of the model.
114
+ past_key_values (`Dict[str, jnp.ndarray]`):
115
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
116
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
117
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
118
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
119
+ `(batch_size, sequence_length, hidden_size)`.
120
+
121
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
122
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
123
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
124
+ sequence_length)`.
125
+
126
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
127
+ heads.
128
+ """
129
+
130
+ last_hidden_state: jnp.ndarray = None
131
+ past_key_values: Optional[Dict[str, jnp.ndarray]] = None
132
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
133
+ attentions: Optional[Tuple[jnp.ndarray]] = None
134
+
135
+
136
+ @flax.struct.dataclass
137
+ class FlaxBaseModelOutputWithPooling(ModelOutput):
138
+ """
139
+ Base class for model's outputs that also contains a pooling of the last hidden states.
140
+
141
+ Args:
142
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
143
+ Sequence of hidden-states at the output of the last layer of the model.
144
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
145
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
146
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
147
+ prediction (classification) objective during pretraining.
148
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
149
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
150
+ `(batch_size, sequence_length, hidden_size)`.
151
+
152
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
153
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
154
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
155
+ sequence_length)`.
156
+
157
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
158
+ heads.
159
+ """
160
+
161
+ last_hidden_state: jnp.ndarray = None
162
+ pooler_output: jnp.ndarray = None
163
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
164
+ attentions: Optional[Tuple[jnp.ndarray]] = None
165
+
166
+
167
+ @flax.struct.dataclass
168
+ class FlaxBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
169
+ """
170
+ Base class for model's outputs that also contains a pooling of the last hidden states.
171
+
172
+ Args:
173
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
174
+ Sequence of hidden-states at the output of the last layer of the model.
175
+ pooler_output (`jnp.ndarray` of shape `(batch_size, hidden_size)`):
176
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
177
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
178
+ the classification token after processing through a linear layer and a tanh activation function. The linear
179
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
180
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
181
+ Tuple of `jnp.ndarray` (one for the output of the embeddings, if the model has an embedding layer, + one
182
+ for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
183
+
184
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
185
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
186
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
187
+ sequence_length)`.
188
+
189
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
190
+ heads.
191
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
192
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
193
+ sequence_length)`.
194
+
195
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
196
+ weighted average in the cross-attention heads.
197
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
198
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
199
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
200
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
201
+ encoder_sequence_length, embed_size_per_head)`.
202
+
203
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
204
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
205
+ input) to speed up sequential decoding.
206
+ """
207
+
208
+ last_hidden_state: jnp.ndarray = None
209
+ pooler_output: jnp.ndarray = None
210
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
211
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
212
+ attentions: Optional[Tuple[jnp.ndarray]] = None
213
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
214
+
215
+
216
+ @flax.struct.dataclass
217
+ class FlaxBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
218
+ """
219
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
220
+
221
+ Args:
222
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
223
+ Sequence of hidden-states at the output of the last layer of the model.
224
+
225
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
226
+ hidden_size)` is output.
227
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
228
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
229
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if
230
+ `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
231
+ encoder_sequence_length, embed_size_per_head)`.
232
+
233
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
234
+ `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
235
+ input) to speed up sequential decoding.
236
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
237
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
238
+ `(batch_size, sequence_length, hidden_size)`.
239
+
240
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
241
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
242
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
243
+ sequence_length)`.
244
+
245
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
246
+ heads.
247
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
248
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
249
+ sequence_length)`.
250
+
251
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
252
+ weighted average in the cross-attention heads.
253
+ """
254
+
255
+ last_hidden_state: jnp.ndarray = None
256
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
257
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
258
+ attentions: Optional[Tuple[jnp.ndarray]] = None
259
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
260
+
261
+
262
+ @flax.struct.dataclass
263
+ class FlaxSeq2SeqModelOutput(ModelOutput):
264
+ """
265
+ Base class for model encoder's outputs that also contains pre-computed hidden states that can speed up sequential
266
+ decoding.
267
+
268
+ Args:
269
+ last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
270
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
271
+
272
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
273
+ hidden_size)` is output.
274
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
275
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
276
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
277
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
278
+
279
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
280
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
281
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
282
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
283
+ `(batch_size, sequence_length, hidden_size)`.
284
+
285
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
286
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
287
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
288
+ sequence_length)`.
289
+
290
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
291
+ self-attention heads.
292
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
293
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
294
+ sequence_length)`.
295
+
296
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
297
+ weighted average in the cross-attention heads.
298
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
299
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
300
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
301
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
302
+ `(batch_size, sequence_length, hidden_size)`.
303
+
304
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
305
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
306
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
307
+ sequence_length)`.
308
+
309
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
310
+ self-attention heads.
311
+ """
312
+
313
+ last_hidden_state: jnp.ndarray = None
314
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
315
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
316
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
317
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
318
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
319
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
320
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
321
+
322
+
323
+ @flax.struct.dataclass
324
+ class FlaxCausalLMOutputWithCrossAttentions(ModelOutput):
325
+ """
326
+ Base class for causal language model (or autoregressive) outputs.
327
+
328
+ Args:
329
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
330
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
331
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
332
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
333
+ `(batch_size, sequence_length, hidden_size)`.
334
+
335
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
336
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
337
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
338
+ sequence_length)`.
339
+
340
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
341
+ heads.
342
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
343
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
344
+ sequence_length)`.
345
+
346
+ Cross attentions weights after the attention softmax, used to compute the weighted average in the
347
+ cross-attention heads.
348
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
349
+ Tuple of `jnp.ndarray` tuples of length `config.n_layers`, with each tuple containing the cached key, value
350
+ states of the self-attention and the cross-attention layers if model is used in encoder-decoder setting.
351
+ Only relevant if `config.is_decoder = True`.
352
+
353
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
354
+ `past_key_values` input) to speed up sequential decoding.
355
+ """
356
+
357
+ logits: jnp.ndarray = None
358
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
359
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
360
+ attentions: Optional[Tuple[jnp.ndarray]] = None
361
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
362
+
363
+
364
+ @flax.struct.dataclass
365
+ class FlaxMaskedLMOutput(ModelOutput):
366
+ """
367
+ Base class for masked language models outputs.
368
+
369
+ Args:
370
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
371
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
372
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
373
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
374
+ `(batch_size, sequence_length, hidden_size)`.
375
+
376
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
377
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
378
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
379
+ sequence_length)`.
380
+
381
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
382
+ heads.
383
+ """
384
+
385
+ logits: jnp.ndarray = None
386
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
387
+ attentions: Optional[Tuple[jnp.ndarray]] = None
388
+
389
+
390
+ FlaxCausalLMOutput = FlaxMaskedLMOutput
391
+
392
+
393
+ @flax.struct.dataclass
394
+ class FlaxSeq2SeqLMOutput(ModelOutput):
395
+ """
396
+ Base class for sequence-to-sequence language models outputs.
397
+
398
+ Args:
399
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
400
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
401
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
402
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
403
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
404
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
405
+
406
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
407
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
408
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
409
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
410
+ `(batch_size, sequence_length, hidden_size)`.
411
+
412
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
413
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
414
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
415
+ sequence_length)`.
416
+
417
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
418
+ self-attention heads.
419
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
420
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
421
+ sequence_length)`.
422
+
423
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
424
+ weighted average in the cross-attention heads.
425
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
426
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
427
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
428
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
429
+ `(batch_size, sequence_length, hidden_size)`.
430
+
431
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
432
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
433
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
434
+ sequence_length)`.
435
+
436
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
437
+ self-attention heads.
438
+ """
439
+
440
+ logits: jnp.ndarray = None
441
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
442
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
443
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
444
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
445
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
446
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
447
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
448
+
449
+
450
+ @flax.struct.dataclass
451
+ class FlaxNextSentencePredictorOutput(ModelOutput):
452
+ """
453
+ Base class for outputs of models predicting if two sentences are consecutive or not.
454
+
455
+ Args:
456
+ logits (`jnp.ndarray` of shape `(batch_size, 2)`):
457
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
458
+ before SoftMax).
459
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
460
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
461
+ `(batch_size, sequence_length, hidden_size)`.
462
+
463
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
464
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
465
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
466
+ sequence_length)`.
467
+
468
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
469
+ heads.
470
+ """
471
+
472
+ logits: jnp.ndarray = None
473
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
474
+ attentions: Optional[Tuple[jnp.ndarray]] = None
475
+
476
+
477
+ @flax.struct.dataclass
478
+ class FlaxSequenceClassifierOutput(ModelOutput):
479
+ """
480
+ Base class for outputs of sentence classification models.
481
+
482
+ Args:
483
+ logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
484
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
485
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
486
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
487
+ `(batch_size, sequence_length, hidden_size)`.
488
+
489
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
490
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
491
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
492
+ sequence_length)`.
493
+
494
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
495
+ heads.
496
+ """
497
+
498
+ logits: jnp.ndarray = None
499
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
500
+ attentions: Optional[Tuple[jnp.ndarray]] = None
501
+
502
+
503
+ @flax.struct.dataclass
504
+ class FlaxSeq2SeqSequenceClassifierOutput(ModelOutput):
505
+ """
506
+ Base class for outputs of sequence-to-sequence sentence classification models.
507
+
508
+ Args:
509
+ logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):
510
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
511
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
512
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
513
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
514
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
515
+
516
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
517
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
518
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
519
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
520
+ `(batch_size, sequence_length, hidden_size)`.
521
+
522
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
523
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
524
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
525
+ sequence_length)`.
526
+
527
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
528
+ self-attention heads.
529
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
530
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
531
+ sequence_length)`.
532
+
533
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
534
+ weighted average in the cross-attention heads.
535
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
536
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
537
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
538
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
539
+ `(batch_size, sequence_length, hidden_size)`.
540
+
541
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
542
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
543
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
544
+ sequence_length)`.
545
+
546
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
547
+ self-attention heads.
548
+ """
549
+
550
+ logits: jnp.ndarray = None
551
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
552
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
553
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
554
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
555
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
556
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
557
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
558
+
559
+
560
+ @flax.struct.dataclass
561
+ class FlaxMultipleChoiceModelOutput(ModelOutput):
562
+ """
563
+ Base class for outputs of multiple choice models.
564
+
565
+ Args:
566
+ logits (`jnp.ndarray` of shape `(batch_size, num_choices)`):
567
+ *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
568
+
569
+ Classification scores (before SoftMax).
570
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
571
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
572
+ `(batch_size, sequence_length, hidden_size)`.
573
+
574
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
575
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
576
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
577
+ sequence_length)`.
578
+
579
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
580
+ heads.
581
+ """
582
+
583
+ logits: jnp.ndarray = None
584
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
585
+ attentions: Optional[Tuple[jnp.ndarray]] = None
586
+
587
+
588
+ @flax.struct.dataclass
589
+ class FlaxTokenClassifierOutput(ModelOutput):
590
+ """
591
+ Base class for outputs of token classification models.
592
+
593
+ Args:
594
+ logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.num_labels)`):
595
+ Classification scores (before SoftMax).
596
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
597
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
598
+ `(batch_size, sequence_length, hidden_size)`.
599
+
600
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
601
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
602
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
603
+ sequence_length)`.
604
+
605
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
606
+ heads.
607
+ """
608
+
609
+ logits: jnp.ndarray = None
610
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
611
+ attentions: Optional[Tuple[jnp.ndarray]] = None
612
+
613
+
614
+ @flax.struct.dataclass
615
+ class FlaxQuestionAnsweringModelOutput(ModelOutput):
616
+ """
617
+ Base class for outputs of question answering models.
618
+
619
+ Args:
620
+ start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
621
+ Span-start scores (before SoftMax).
622
+ end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
623
+ Span-end scores (before SoftMax).
624
+ hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
625
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
626
+ `(batch_size, sequence_length, hidden_size)`.
627
+
628
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
629
+ attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
630
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
631
+ sequence_length)`.
632
+
633
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
634
+ heads.
635
+ """
636
+
637
+ start_logits: jnp.ndarray = None
638
+ end_logits: jnp.ndarray = None
639
+ hidden_states: Optional[Tuple[jnp.ndarray]] = None
640
+ attentions: Optional[Tuple[jnp.ndarray]] = None
641
+
642
+
643
+ @flax.struct.dataclass
644
+ class FlaxSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
645
+ """
646
+ Base class for outputs of sequence-to-sequence question answering models.
647
+
648
+ Args:
649
+ start_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
650
+ Span-start scores (before SoftMax).
651
+ end_logits (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
652
+ Span-end scores (before SoftMax).
653
+ past_key_values (`tuple(tuple(jnp.ndarray))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
654
+ Tuple of `tuple(jnp.ndarray)` of length `config.n_layers`, with each tuple having 2 tensors of shape
655
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
656
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
657
+
658
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
659
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
660
+ decoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
661
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
662
+ `(batch_size, sequence_length, hidden_size)`.
663
+
664
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
665
+ decoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
666
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
667
+ sequence_length)`.
668
+
669
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
670
+ self-attention heads.
671
+ cross_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
672
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
673
+ sequence_length)`.
674
+
675
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
676
+ weighted average in the cross-attention heads.
677
+ encoder_last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
678
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
679
+ encoder_hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
680
+ Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
681
+ `(batch_size, sequence_length, hidden_size)`.
682
+
683
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
684
+ encoder_attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
685
+ Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
686
+ sequence_length)`.
687
+
688
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
689
+ self-attention heads.
690
+ """
691
+
692
+ start_logits: jnp.ndarray = None
693
+ end_logits: jnp.ndarray = None
694
+ past_key_values: Optional[Tuple[Tuple[jnp.ndarray]]] = None
695
+ decoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
696
+ decoder_attentions: Optional[Tuple[jnp.ndarray]] = None
697
+ cross_attentions: Optional[Tuple[jnp.ndarray]] = None
698
+ encoder_last_hidden_state: Optional[jnp.ndarray] = None
699
+ encoder_hidden_states: Optional[Tuple[jnp.ndarray]] = None
700
+ encoder_attentions: Optional[Tuple[jnp.ndarray]] = None
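For context, the dataclasses above are what Flax models in the library return when `return_dict` is left at its default. A minimal usage sketch follows; the model class and checkpoint name (`FlaxBertModel`, `bert-base-uncased`) are illustrative assumptions, not something defined in this file.

# Sketch only: how fields of a pooled output dataclass are typically accessed.
from transformers import AutoTokenizer, FlaxBertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello world", return_tensors="np")
outputs = model(**inputs, output_hidden_states=True, output_attentions=True)

# outputs is one of the pooled output dataclasses defined above; because
# ModelOutput is dict-like, fields can be read as attributes or by key.
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
print(outputs.pooler_output.shape)      # (batch_size, hidden_size)
print(len(outputs.hidden_states))       # embeddings + one entry per layer
print(len(outputs.attentions))          # one entry per layer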
modeling_flax_pytorch_utils.py ADDED
@@ -0,0 +1,492 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch - Flax general utilities."""
16
+
17
+ import os
18
+ from pickle import UnpicklingError
19
+ from typing import Dict, Tuple
20
+
21
+ import jax
22
+ import jax.numpy as jnp
23
+ import numpy as np
24
+ from flax.serialization import from_bytes
25
+ from flax.traverse_util import flatten_dict, unflatten_dict
26
+
27
+ import transformers
28
+
29
+ from . import is_safetensors_available, is_torch_available
30
+ from .utils import logging
31
+
32
+
33
+ if is_torch_available():
34
+ import torch
35
+
36
+ if is_safetensors_available():
37
+ from safetensors import safe_open
38
+ from safetensors.flax import load_file as safe_load_file
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ #####################
45
+ # PyTorch => Flax #
46
+ #####################
47
+
48
+
49
+ def load_pytorch_checkpoint_in_flax_state_dict(
50
+ flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
51
+ ):
52
+ """Load PyTorch checkpoints into a Flax model"""
53
+
54
+ if not is_sharded:
55
+ pt_path = os.path.abspath(pytorch_checkpoint_path)
56
+ logger.info(f"Loading PyTorch weights from {pt_path}")
57
+
58
+ if pt_path.endswith(".safetensors"):
59
+ pt_state_dict = {}
60
+ with safe_open(pt_path, framework="flax") as f:
61
+ for k in f.keys():
62
+ pt_state_dict[k] = f.get_tensor(k)
63
+ else:
64
+ try:
65
+ import torch # noqa: F401
66
+ except (ImportError, ModuleNotFoundError):
67
+ logger.error(
68
+ "Loading a PyTorch model in Flax requires both PyTorch and Flax to be installed. Please see"
69
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
70
+ " instructions."
71
+ )
72
+ raise
73
+
74
+ weights_only_kwarg = {"weights_only": True}
75
+ pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
76
+ logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
77
+
78
+ flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
79
+ else:
80
+ # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
81
+ flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)
82
+ return flax_state_dict
83
+
84
+
85
+ def rename_key_and_reshape_tensor(
86
+ pt_tuple_key: Tuple[str],
87
+ pt_tensor: np.ndarray,
88
+ random_flax_state_dict: Dict[str, jnp.ndarray],
89
+ model_prefix: str,
90
+ ) -> Tuple[Tuple[str], np.ndarray]:
91
+ """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
92
+
93
+ def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
94
+ """Checks if `key` or `(prefix,) + key` is in random_flax_state_dict"""
95
+ return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0
96
+
97
+ # layer norm
98
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
99
+ if pt_tuple_key[-1] in ["weight", "gamma"] and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
100
+ return renamed_pt_tuple_key, pt_tensor
101
+
102
+ # batch norm layer mean
103
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("mean",)
104
+ if pt_tuple_key[-1] == "running_mean" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
105
+ return renamed_pt_tuple_key, pt_tensor
106
+
107
+ # batch norm layer var
108
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("var",)
109
+ if pt_tuple_key[-1] == "running_var" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
110
+ return renamed_pt_tuple_key, pt_tensor
111
+
112
+ # embedding
113
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
114
+ if pt_tuple_key[-1] == "weight" and is_key_or_prefix_key_in_dict(renamed_pt_tuple_key):
115
+ return renamed_pt_tuple_key, pt_tensor
116
+
117
+ # conv layer
118
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
119
+ if pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and not is_key_or_prefix_key_in_dict(pt_tuple_key):
120
+ pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
121
+ return renamed_pt_tuple_key, pt_tensor
122
+
123
+ # linear layer
124
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
125
+ if pt_tuple_key[-1] == "weight" and not is_key_or_prefix_key_in_dict(pt_tuple_key):
126
+ pt_tensor = pt_tensor.T
127
+ return renamed_pt_tuple_key, pt_tensor
128
+
129
+ # old PyTorch layer norm weight
130
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
131
+ if pt_tuple_key[-1] == "gamma":
132
+ return renamed_pt_tuple_key, pt_tensor
133
+
134
+ # old PyTorch layer norm bias
135
+ renamed_pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
136
+ if pt_tuple_key[-1] == "beta":
137
+ return renamed_pt_tuple_key, pt_tensor
138
+
139
+ # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
140
+ name = None
141
+ if pt_tuple_key[-3::2] == ("parametrizations", "original0"):
142
+ name = pt_tuple_key[-2] + "_g"
143
+ elif pt_tuple_key[-3::2] == ("parametrizations", "original1"):
144
+ name = pt_tuple_key[-2] + "_v"
145
+ if name is not None:
146
+ renamed_pt_tuple_key = pt_tuple_key[:-3] + (name,)
147
+ return renamed_pt_tuple_key, pt_tensor
148
+
149
+ return pt_tuple_key, pt_tensor
150
+
151
+
152
+ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
153
+ # convert pytorch tensor to numpy
154
+ from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
155
+ bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
156
+
157
+ weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
158
+
159
+ if from_bin:
160
+ for k, v in pt_state_dict.items():
161
+ # numpy currently does not support bfloat16; go through float32 here to avoid losing precision
162
+ if v.dtype == bfloat16:
163
+ v = v.float()
164
+ pt_state_dict[k] = v.cpu().numpy()
165
+
166
+ model_prefix = flax_model.base_model_prefix
167
+
168
+ # use params dict if the model contains batch norm layers
169
+ if "params" in flax_model.params:
170
+ flax_model_params = flax_model.params["params"]
171
+ else:
172
+ flax_model_params = flax_model.params
173
+ random_flax_state_dict = flatten_dict(flax_model_params)
174
+
175
+ # add batch_stats keys,values to dict
176
+ if "batch_stats" in flax_model.params:
177
+ flax_batch_stats = flatten_dict(flax_model.params["batch_stats"])
178
+ random_flax_state_dict.update(flax_batch_stats)
179
+
180
+ flax_state_dict = {}
181
+
182
+ load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
183
+ model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
184
+ )
185
+ load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
186
+ model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
187
+ )
188
+
189
+ # Need to change some parameters name to match Flax names
190
+ for pt_key, pt_tensor in pt_state_dict.items():
191
+ pt_tuple_key = tuple(pt_key.split("."))
192
+ is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
193
+
194
+ # remove base model prefix if necessary
195
+ has_base_model_prefix = pt_tuple_key[0] == model_prefix
196
+ if load_model_with_head_into_base_model and has_base_model_prefix:
197
+ pt_tuple_key = pt_tuple_key[1:]
198
+
199
+ # Correctly rename weight parameters
200
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(
201
+ pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
202
+ )
203
+
204
+ # add model prefix if necessary
205
+ require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
206
+ if load_base_model_into_model_with_head and require_base_model_prefix:
207
+ flax_key = (model_prefix,) + flax_key
208
+
209
+ if flax_key in random_flax_state_dict:
210
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
211
+ raise ValueError(
212
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
213
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
214
+ )
215
+
216
+ # add batch stats if the model contains batchnorm layers
217
+ if "batch_stats" in flax_model.params:
218
+ if "mean" in flax_key[-1] or "var" in flax_key[-1]:
219
+ flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
220
+ continue
221
+ # remove num_batches_tracked key
222
+ if "num_batches_tracked" in flax_key[-1]:
223
+ flax_state_dict.pop(flax_key, None)
224
+ continue
225
+
226
+ # also add unexpected weight so that warning is thrown
227
+ flax_state_dict[("params",) + flax_key] = (
228
+ jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
229
+ )
230
+ else:
231
+ # also add unexpected weight so that warning is thrown
232
+ flax_state_dict[flax_key] = (
233
+ jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
234
+ )
235
+
236
+ return unflatten_dict(flax_state_dict)
237
+
238
+
239
+ ############################
240
+ # Sharded Pytorch => Flax #
241
+ ############################
242
+
243
+
244
+ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
245
+ import torch
246
+
247
+ # Load the index
248
+ flax_state_dict = {}
249
+ for shard_file in shard_filenames:
250
+ # load each PyTorch shard with torch.load
251
+ weights_only_kwarg = {"weights_only": True}
252
+ pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
253
+ weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
254
+ pt_state_dict = {
255
+ k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
256
+ }
257
+
258
+ model_prefix = flax_model.base_model_prefix
259
+
260
+ # use params dict if the model contains batch norm layers and then add batch_stats keys,values to dict
261
+ if "batch_stats" in flax_model.params:
262
+ flax_model_params = flax_model.params["params"]
263
+
264
+ random_flax_state_dict = flatten_dict(flax_model_params)
265
+ random_flax_state_dict.update(flatten_dict(flax_model.params["batch_stats"]))
266
+ else:
267
+ flax_model_params = flax_model.params
268
+ random_flax_state_dict = flatten_dict(flax_model_params)
269
+
270
+ load_model_with_head_into_base_model = (model_prefix not in flax_model_params) and (
271
+ model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
272
+ )
273
+ load_base_model_into_model_with_head = (model_prefix in flax_model_params) and (
274
+ model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
275
+ )
276
+ # Need to change some parameters name to match Flax names
277
+ for pt_key, pt_tensor in pt_state_dict.items():
278
+ pt_tuple_key = tuple(pt_key.split("."))
279
+ is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
280
+
281
+ # remove base model prefix if necessary
282
+ has_base_model_prefix = pt_tuple_key[0] == model_prefix
283
+ if load_model_with_head_into_base_model and has_base_model_prefix:
284
+ pt_tuple_key = pt_tuple_key[1:]
285
+
286
+ # Correctly rename weight parameters
287
+ flax_key, flax_tensor = rename_key_and_reshape_tensor(
288
+ pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix
289
+ )
290
+ # add model prefix if necessary
291
+ require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict
292
+ if load_base_model_into_model_with_head and require_base_model_prefix:
293
+ flax_key = (model_prefix,) + flax_key
294
+
295
+ if flax_key in random_flax_state_dict:
296
+ if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
297
+ raise ValueError(
298
+ f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
299
+ f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
300
+ )
301
+
302
+ # add batch stats if the model contains batchnorm layers
303
+ if "batch_stats" in flax_model.params:
304
+ if "mean" in flax_key[-1]:
305
+ flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
306
+ continue
307
+ if "var" in flax_key[-1]:
308
+ flax_state_dict[("batch_stats",) + flax_key] = jnp.asarray(flax_tensor)
309
+ continue
310
+ # remove num_batches_tracked key
311
+ if "num_batches_tracked" in flax_key[-1]:
312
+ flax_state_dict.pop(flax_key, None)
313
+ continue
314
+
315
+ # also add unexpected weight so that warning is thrown
316
+ flax_state_dict[("params",) + flax_key] = (
317
+ jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
318
+ )
319
+
320
+ else:
321
+ # also add unexpected weight so that warning is thrown
322
+ flax_state_dict[flax_key] = (
323
+ jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
324
+ )
325
+ return unflatten_dict(flax_state_dict)
326
+
327
+
328
+ #####################
329
+ # Flax => PyTorch #
330
+ #####################
331
+
332
+
333
+ def load_flax_checkpoint_in_pytorch_model(model, flax_checkpoint_path):
334
+ """Load flax checkpoints in a PyTorch model"""
335
+ flax_checkpoint_path = os.path.abspath(flax_checkpoint_path)
336
+ logger.info(f"Loading Flax weights from {flax_checkpoint_path}")
337
+
338
+ # import correct flax class
339
+ flax_cls = getattr(transformers, "Flax" + model.__class__.__name__)
340
+
341
+ # load flax weight dict
342
+ if flax_checkpoint_path.endswith(".safetensors"):
343
+ flax_state_dict = safe_load_file(flax_checkpoint_path)
344
+ flax_state_dict = unflatten_dict(flax_state_dict, sep=".")
345
+ else:
346
+ with open(flax_checkpoint_path, "rb") as state_f:
347
+ try:
348
+ flax_state_dict = from_bytes(flax_cls, state_f.read())
349
+ except UnpicklingError:
350
+ raise EnvironmentError(f"Unable to convert {flax_checkpoint_path} to Flax deserializable object. ")
351
+
352
+ return load_flax_weights_in_pytorch_model(model, flax_state_dict)
353
+
354
+
355
+ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
356
+ """Load flax checkpoints in a PyTorch model"""
357
+
358
+ try:
359
+ import torch # noqa: F401
360
+ except (ImportError, ModuleNotFoundError):
361
+ logger.error(
362
+ "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
363
+ " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
364
+ " instructions."
365
+ )
366
+ raise
367
+
368
+ # check if we have bf16 weights
369
+ is_type_bf16 = flatten_dict(jax.tree_util.tree_map(lambda x: x.dtype == jnp.bfloat16, flax_state)).values()
370
+ if any(is_type_bf16):
371
+ # convert all weights to fp32 if they are bf16 since torch.from_numpy cannot handle bf16
372
+ # and bf16 is not fully supported in PT yet.
373
+ logger.warning(
374
+ "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
375
+ "before loading those in PyTorch model."
376
+ )
377
+ flax_state = jax.tree_util.tree_map(
378
+ lambda params: params.astype(np.float32) if params.dtype == jnp.bfloat16 else params, flax_state
379
+ )
380
+
381
+ flax_state_dict = flatten_dict(flax_state)
382
+ pt_model_dict = pt_model.state_dict()
383
+
384
+ load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
385
+ pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
386
+ )
387
+ load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
388
+ pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
389
+ )
390
+
391
+ # keep track of unexpected & missing keys
392
+ unexpected_keys = []
393
+ missing_keys = set(pt_model_dict.keys())
394
+
395
+ for flax_key_tuple, flax_tensor in flax_state_dict.items():
396
+ has_base_model_prefix = flax_key_tuple[0] == pt_model.base_model_prefix
397
+ require_base_model_prefix = ".".join((pt_model.base_model_prefix,) + flax_key_tuple) in pt_model_dict
398
+
399
+ # adapt flax_key to prepare for loading from/to base model only
400
+ if load_model_with_head_into_base_model and has_base_model_prefix:
401
+ flax_key_tuple = flax_key_tuple[1:]
402
+ elif load_base_model_into_model_with_head and require_base_model_prefix:
403
+ flax_key_tuple = (pt_model.base_model_prefix,) + flax_key_tuple
404
+
405
+ # rename flax weights to PyTorch format
406
+ if flax_key_tuple[-1] == "kernel" and flax_tensor.ndim == 4 and ".".join(flax_key_tuple) not in pt_model_dict:
407
+ # conv layer
408
+ flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
409
+ flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
410
+ elif flax_key_tuple[-1] == "kernel" and ".".join(flax_key_tuple) not in pt_model_dict:
411
+ # linear layer
412
+ flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
413
+ flax_tensor = flax_tensor.T
414
+ elif flax_key_tuple[-1] in ["scale", "embedding"]:
415
+ flax_key_tuple = flax_key_tuple[:-1] + ("weight",)
416
+
417
+ # adding batch stats from flax batch norm to pt
418
+ elif "mean" in flax_key_tuple[-1]:
419
+ flax_key_tuple = flax_key_tuple[:-1] + ("running_mean",)
420
+ elif "var" in flax_key_tuple[-1]:
421
+ flax_key_tuple = flax_key_tuple[:-1] + ("running_var",)
422
+
423
+ if "batch_stats" in flax_state:
424
+ flax_key = ".".join(flax_key_tuple[1:]) # Remove the params/batch_stats header
425
+ else:
426
+ flax_key = ".".join(flax_key_tuple)
427
+
428
+ # We also need to look at `pt_model_dict` and see if there are keys requiring further transformation.
429
+ special_pt_names = {}
430
+ # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
431
+ for key in pt_model_dict:
432
+ key_components = key.split(".")
433
+ name = None
434
+ if key_components[-3::2] == ["parametrizations", "original0"]:
435
+ name = key_components[-2] + "_g"
436
+ elif key_components[-3::2] == ["parametrizations", "original1"]:
437
+ name = key_components[-2] + "_v"
438
+ if name is not None:
439
+ key_components = key_components[:-3] + [name]
440
+ key_to_check = ".".join(key_components)
441
+ special_pt_names[key_to_check] = key
442
+
443
+ if flax_key in special_pt_names:
444
+ flax_key = special_pt_names[flax_key]
445
+
446
+ if flax_key in pt_model_dict:
447
+ if flax_tensor.shape != pt_model_dict[flax_key].shape:
448
+ raise ValueError(
449
+ f"Flax checkpoint seems to be incorrect. Weight {flax_key_tuple} was expected "
450
+ f"to be of shape {pt_model_dict[flax_key].shape}, but is {flax_tensor.shape}."
451
+ )
452
+ else:
453
+ # add weight to pytorch dict
454
+ flax_tensor = np.asarray(flax_tensor) if not isinstance(flax_tensor, np.ndarray) else flax_tensor
455
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
456
+ # remove from missing keys
457
+ missing_keys.remove(flax_key)
458
+ else:
459
+ # weight is not expected by PyTorch model
460
+ unexpected_keys.append(flax_key)
461
+
462
+ pt_model.load_state_dict(pt_model_dict)
463
+
464
+ # re-transform missing_keys to list
465
+ missing_keys = list(missing_keys)
466
+
467
+ if len(unexpected_keys) > 0:
468
+ logger.warning(
469
+ "Some weights of the Flax model were not used when initializing the PyTorch model"
470
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
471
+ f" {pt_model.__class__.__name__} from a Flax model trained on another task or with another architecture"
472
+ " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
473
+ f" IS NOT expected if you are initializing {pt_model.__class__.__name__} from a Flax model that you expect"
474
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
475
+ " FlaxBertForSequenceClassification model)."
476
+ )
477
+ else:
478
+ logger.warning(f"All Flax model weights were used when initializing {pt_model.__class__.__name__}.\n")
479
+ if len(missing_keys) > 0:
480
+ logger.warning(
481
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the Flax model and are newly"
482
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
483
+ " use it for predictions and inference."
484
+ )
485
+ else:
486
+ logger.warning(
487
+ f"All the weights of {pt_model.__class__.__name__} were initialized from the Flax model.\n"
488
+ "If your task is similar to the task the model of the checkpoint was trained on, "
489
+ f"you can already use {pt_model.__class__.__name__} for predictions without further training."
490
+ )
491
+
492
+ return pt_model
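A minimal illustrative sketch of the weight-naming convention handled by the loop above (not part of the uploaded file; the tensor shapes are made up): a Flax `Dense` layer stores its matrix as `kernel` with shape `(in_features, out_features)`, while the corresponding PyTorch `Linear` expects `weight` with shape `(out_features, in_features)`, hence the rename plus transpose; 4D conv kernels follow the `(3, 2, 0, 1)` permutation instead.

```python
# Hedged sketch: shapes and values are arbitrary, for illustration only.
import numpy as np
import torch

flax_kernel = np.random.randn(16, 32).astype(np.float32)  # Flax Dense kernel: (in, out)
pt_weight = torch.from_numpy(flax_kernel.T)                # PyTorch Linear weight: (out, in)
assert pt_weight.shape == (32, 16)

# Conv kernels follow the 4D branch: Flax (H, W, in, out) -> PyTorch (out, in, H, W).
flax_conv = np.random.randn(3, 3, 8, 16).astype(np.float32)
pt_conv = torch.from_numpy(np.transpose(flax_conv, (3, 2, 0, 1)))
assert pt_conv.shape == (16, 8, 3, 3)
```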
modeling_flax_utils.py ADDED
@@ -0,0 +1,1290 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import gc
18
+ import json
19
+ import os
20
+ import re
21
+ import warnings
22
+ from functools import partial
23
+ from pickle import UnpicklingError
24
+ from typing import Any, Dict, Optional, Set, Tuple, Union
25
+
26
+ import flax.linen as nn
27
+ import jax
28
+ import jax.numpy as jnp
29
+ import msgpack.exceptions
30
+ from flax.core.frozen_dict import FrozenDict, unfreeze
31
+ from flax.serialization import from_bytes, to_bytes
32
+ from flax.traverse_util import flatten_dict, unflatten_dict
33
+ from jax.random import PRNGKey
34
+
35
+ from .configuration_utils import PretrainedConfig
36
+ from .dynamic_module_utils import custom_object_save
37
+ from .generation import FlaxGenerationMixin, GenerationConfig
38
+ from .modeling_flax_pytorch_utils import load_pytorch_checkpoint_in_flax_state_dict
39
+ from .utils import (
40
+ FLAX_WEIGHTS_INDEX_NAME,
41
+ FLAX_WEIGHTS_NAME,
42
+ SAFE_WEIGHTS_INDEX_NAME,
43
+ SAFE_WEIGHTS_NAME,
44
+ WEIGHTS_INDEX_NAME,
45
+ WEIGHTS_NAME,
46
+ PushToHubMixin,
47
+ add_code_sample_docstrings,
48
+ add_start_docstrings_to_model_forward,
49
+ cached_file,
50
+ copy_func,
51
+ download_url,
52
+ has_file,
53
+ is_offline_mode,
54
+ is_remote_url,
55
+ logging,
56
+ replace_return_docstrings,
57
+ )
58
+ from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
59
+ from .utils.import_utils import is_safetensors_available
60
+
61
+
62
+ if is_safetensors_available():
63
+ from safetensors import safe_open
64
+ from safetensors.flax import load_file as safe_load_file
65
+ from safetensors.flax import save_file as safe_save_file
66
+
67
+ logger = logging.get_logger(__name__)
68
+
69
+
70
+ def quick_gelu(x):
71
+ return x * jax.nn.sigmoid(1.702 * x)
72
+
73
+
74
+ ACT2FN = {
75
+ "gelu": partial(nn.gelu, approximate=False),
76
+ "relu": nn.relu,
77
+ "silu": nn.swish,
78
+ "swish": nn.swish,
79
+ "gelu_new": partial(nn.gelu, approximate=True),
80
+ "quick_gelu": quick_gelu,
81
+ "gelu_pytorch_tanh": partial(nn.gelu, approximate=True),
82
+ }
83
+
84
+
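A short usage sketch of the `ACT2FN` mapping defined above (illustrative only; the array values are arbitrary): model modules typically look up the activation by the string stored in their config.

```python
import jax.numpy as jnp

hidden_states = jnp.array([-1.0, 0.0, 1.0])
act_fn = ACT2FN["gelu_new"]      # tanh-approximated GELU (nn.gelu with approximate=True)
print(act_fn(hidden_states))     # same shape as the input, (3,)
```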
85
+ def dtype_byte_size(dtype):
86
+ """
87
+ Returns the size (in bytes) occupied by one parameter of type `dtype`. Example:
88
+ ```py
89
+ >>> dtype_byte_size(np.float32)
90
+ 4
91
+ ```
92
+ """
93
+ if dtype is bool:
94
+ return 1 / 8
95
+ bit_search = re.search(r"[^\d](\d+)$", dtype.name)
96
+ if bit_search is None:
97
+ raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
98
+ bit_size = int(bit_search.groups()[0])
99
+ return bit_size // 8
100
+
101
+
102
+ def flax_shard_checkpoint(params, max_shard_size="10GB"):
103
+ """
104
+ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
105
+ given size. The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so
106
+ there is no optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For
107
+ example, if the limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as
108
+ [6GB], [6+2GB], [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].
109
+
110
+ <Tip warning={true}>
111
+
112
+ If one of the model's weights is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
113
+ have a size greater than `max_shard_size`.
114
+
115
+ </Tip>
116
+
117
+ Args:
118
+ params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters.
119
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
120
+ The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
121
+ (like `"5MB"`).
122
+ """
123
+ max_shard_size = convert_file_size_to_int(max_shard_size)
124
+
125
+ sharded_state_dicts = []
126
+ current_block = {}
127
+ current_block_size = 0
128
+ total_size = 0
129
+
130
+ # flatten the weights to chunk
131
+ weights = flatten_dict(params, sep="/")
132
+ for item in weights:
133
+ weight_size = weights[item].size * dtype_byte_size(weights[item].dtype)
134
+
135
+ # If adding this weight would tip the current block over the maximal size, we split.
136
+ if current_block_size + weight_size > max_shard_size:
137
+ sharded_state_dicts.append(current_block)
138
+ current_block = {}
139
+ current_block_size = 0
140
+
141
+ current_block[item] = weights[item]
142
+ current_block_size += weight_size
143
+ total_size += weight_size
144
+
145
+ # Add the last block
146
+ sharded_state_dicts.append(current_block)
147
+
148
+ # If we only have one shard, we return it
149
+ if len(sharded_state_dicts) == 1:
150
+ return {FLAX_WEIGHTS_NAME: sharded_state_dicts[0]}, None
151
+
152
+ # Otherwise, let's build the index
153
+ weight_map = {}
154
+ shards = {}
155
+ for idx, shard in enumerate(sharded_state_dicts):
156
+ shard_file = FLAX_WEIGHTS_NAME.replace(".msgpack", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.msgpack")
157
+ shards[shard_file] = shard
158
+ for weight_name in shard.keys():
159
+ weight_map[weight_name] = shard_file
160
+
161
+ # Add the metadata
162
+ metadata = {"total_size": total_size}
163
+ index = {"metadata": metadata, "weight_map": weight_map}
164
+ return shards, index
165
+
166
+
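An illustrative sketch of `flax_shard_checkpoint` with toy weights (the sizes and the tiny `max_shard_size` are made up to force a split): two ~4MB float32 arrays do not fit together under a 5MB limit, so two shards plus an index are produced.

```python
import numpy as np

params = {
    "layer_0": {"kernel": np.zeros((1024, 1024), dtype=np.float32)},  # ~4 MB
    "layer_1": {"kernel": np.zeros((1024, 1024), dtype=np.float32)},  # ~4 MB
}
shards, index = flax_shard_checkpoint(params, max_shard_size="5MB")
print(sorted(shards))                   # ['flax_model-00001-of-00002.msgpack', 'flax_model-00002-of-00002.msgpack']
print(index["metadata"]["total_size"])  # total byte size of all weights
```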
167
+ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
168
+ r"""
169
+ Base class for all models.
170
+
171
+ [`FlaxPreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
172
+ downloading and saving models.
173
+
174
+ Class attributes (overridden by derived classes):
175
+
176
+ - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
177
+ for this model architecture.
178
+ - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
179
+ classes of the same architecture adding modules on top of the base model.
180
+ - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
181
+ models, `pixel_values` for vision models and `input_values` for speech models).
182
+ """
183
+
184
+ config_class = None
185
+ base_model_prefix = ""
186
+ main_input_name = "input_ids"
187
+ _auto_class = None
188
+ _missing_keys = set()
189
+
190
+ def __init__(
191
+ self,
192
+ config: PretrainedConfig,
193
+ module: nn.Module,
194
+ input_shape: Tuple = (1, 1),
195
+ seed: int = 0,
196
+ dtype: jnp.dtype = jnp.float32,
197
+ _do_init: bool = True,
198
+ ):
199
+ if config is None:
200
+ raise ValueError("config cannot be None")
201
+
202
+ if module is None:
203
+ raise ValueError("module cannot be None")
204
+
205
+ # Those are private to be exposed as typed property on derived classes.
206
+ self._config = config
207
+ self._module = module
208
+
209
+ # Those are public as their type is generic to every derived classes.
210
+ self.key = PRNGKey(seed)
211
+ self.dtype = dtype
212
+ self.input_shape = input_shape
213
+ self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
214
+
215
+ # To check if the model was initialized automatically.
216
+ self._is_initialized = _do_init
217
+
218
+ if _do_init:
219
+ # randomly initialized parameters
220
+ random_params = self.init_weights(self.key, input_shape)
221
+ params_shape_tree = jax.eval_shape(lambda params: params, random_params)
222
+ else:
223
+ init_fn = partial(self.init_weights, input_shape=input_shape)
224
+ params_shape_tree = jax.eval_shape(init_fn, self.key)
225
+
226
+ logger.info(
227
+ "Model weights are not initialized as `_do_init` is set to `False`. "
228
+ f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
229
+ )
230
+
231
+ # get the shape of the parameters
232
+ self._params_shape_tree = params_shape_tree
233
+
234
+ # save required_params as set
235
+ self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
236
+
237
+ # initialize the parameters
238
+ if _do_init:
239
+ self.params = random_params
240
+
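A hedged sketch of the `_do_init=False` path set up in the constructor above (the checkpoint name is the one already used in this file's docstrings): the weights are returned separately and must be passed explicitly at call time, since `model.params` is unavailable in this mode.

```python
import jax.numpy as jnp
from transformers import FlaxBertModel

# Returns (model, params) instead of a fully initialized model.
model, params = FlaxBertModel.from_pretrained("google-bert/bert-base-cased", _do_init=False)
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids, params=params)  # accessing `model.params` would raise instead
```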
241
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
242
+ raise NotImplementedError(f"init method has to be implemented for {self}")
243
+
244
+ def enable_gradient_checkpointing(self):
245
+ raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}")
246
+
247
+ @classmethod
248
+ def _from_config(cls, config, **kwargs):
249
+ """
250
+ All context managers that the model should be initialized under go here.
251
+ """
252
+ return cls(config, **kwargs)
253
+
254
+ @property
255
+ def framework(self) -> str:
256
+ """
257
+ :str: Identifies that this is a Flax model.
258
+ """
259
+ return "flax"
260
+
261
+ @property
262
+ def config(self) -> PretrainedConfig:
263
+ return self._config
264
+
265
+ @property
266
+ def module(self) -> nn.Module:
267
+ return self._module
268
+
269
+ @property
270
+ def params(self) -> Union[Dict, FrozenDict]:
271
+ if not self._is_initialized:
272
+ raise ValueError(
273
+ "`params` cannot be accessed from model when the model is created with `_do_init=False`. "
274
+ "You must call `init_weights` manually and store the params outside of the model and "
275
+ "pass it explicitly where needed."
276
+ )
277
+ return self._params
278
+
279
+ @property
280
+ def required_params(self) -> Set:
281
+ return self._required_params
282
+
283
+ @property
284
+ def params_shape_tree(self) -> Dict:
285
+ return self._params_shape_tree
286
+
287
+ @params.setter
288
+ def params(self, params: Union[Dict, FrozenDict]):
289
+ # don't set params if the model is not initialized
290
+ if not self._is_initialized:
291
+ raise ValueError(
292
+ "`params` cannot be set from model when the model is created with `_do_init=False`. "
293
+ "You store the params outside of the model."
294
+ )
295
+
296
+ if isinstance(params, FrozenDict):
297
+ params = unfreeze(params)
298
+ param_keys = set(flatten_dict(params).keys())
299
+ if len(self.required_params - param_keys) > 0:
300
+ raise ValueError(
301
+ "Some parameters are missing. Make sure that `params` include the following "
302
+ f"parameters {self.required_params - param_keys}"
303
+ )
304
+ self._params = params
305
+
306
+ def _cast_floating_to(self, params: Union[Dict, FrozenDict], dtype: jnp.dtype, mask: Any = None) -> Any:
307
+ """
308
+ Helper method to cast floating-point values of given parameter `PyTree` to given `dtype`.
309
+ """
310
+
311
+ # taken from https://github.com/deepmind/jmp/blob/3a8318abc3292be38582794dbf7b094e6583b192/jmp/_src/policy.py#L27
312
+ def conditional_cast(param):
313
+ if isinstance(param, jnp.ndarray) and jnp.issubdtype(param.dtype, jnp.floating):
314
+ param = param.astype(dtype)
315
+ return param
316
+
317
+ if mask is None:
318
+ return jax.tree_util.tree_map(conditional_cast, params)
319
+
320
+ flat_params = flatten_dict(params)
321
+ flat_mask, _ = jax.tree_util.tree_flatten(mask)
322
+
323
+ for masked, key in zip(flat_mask, sorted(flat_params.keys())):
324
+ if masked:
325
+ flat_params[key] = conditional_cast(flat_params[key])
326
+
327
+ return unflatten_dict(flat_params)
328
+
329
+ def to_bf16(self, params: Union[Dict, FrozenDict], mask: Any = None):
330
+ r"""
331
+ Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast
332
+ the `params` in place.
333
+
334
+ This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full
335
+ half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed.
336
+
337
+ Arguments:
338
+ params (`Union[Dict, FrozenDict]`):
339
+ A `PyTree` of model parameters.
340
+ mask (`Union[Dict, FrozenDict]`):
341
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
342
+ you want to cast, and should be `False` for those you want to skip.
343
+
344
+ Examples:
345
+
346
+ ```python
347
+ >>> from transformers import FlaxBertModel
348
+
349
+ >>> # load model
350
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
351
+ >>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
352
+ >>> model.params = model.to_bf16(model.params)
353
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
354
+ >>> # then pass the mask as follows
355
+ >>> from flax import traverse_util
356
+
357
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
358
+ >>> flat_params = traverse_util.flatten_dict(model.params)
359
+ >>> mask = {
360
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
361
+ ... for path in flat_params
362
+ ... }
363
+ >>> mask = traverse_util.unflatten_dict(mask)
364
+ >>> model.params = model.to_bf16(model.params, mask)
365
+ ```"""
366
+ return self._cast_floating_to(params, jnp.bfloat16, mask)
367
+
368
+ def to_fp32(self, params: Union[Dict, FrozenDict], mask: Any = None):
369
+ r"""
370
+ Cast the floating-point `params` to `jax.numpy.float32`. This method can be used to explicitly convert the
371
+ model parameters to fp32 precision. This returns a new `params` tree and does not cast the `params` in place.
372
+
373
+ Arguments:
374
+ params (`Union[Dict, FrozenDict]`):
375
+ A `PyTree` of model parameters.
376
+ mask (`Union[Dict, FrozenDict]`):
377
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
378
+ you want to cast, and should be `False` for those you want to skip
379
+
380
+ Examples:
381
+
382
+ ```python
383
+ >>> from transformers import FlaxBertModel
384
+
385
+ >>> # Download model and configuration from huggingface.co
386
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
387
+ >>> # By default, the model params will be in fp32, to illustrate the use of this method,
388
+ >>> # we'll first cast to fp16 and back to fp32
389
+ >>> model.params = model.to_fp16(model.params)
390
+ >>> # now cast back to fp32
391
+ >>> model.params = model.to_fp32(model.params)
392
+ ```"""
393
+ return self._cast_floating_to(params, jnp.float32, mask)
394
+
395
+ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
396
+ r"""
397
+ Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the
398
+ `params` in place.
399
+
400
+ This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full
401
+ half-precision training or to save weights in float16 for inference in order to save memory and improve speed.
402
+
403
+ Arguments:
404
+ params (`Union[Dict, FrozenDict]`):
405
+ A `PyTree` of model parameters.
406
+ mask (`Union[Dict, FrozenDict]`):
407
+ A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params
408
+ you want to cast, and should be `False` for those you want to skip
409
+
410
+ Examples:
411
+
412
+ ```python
413
+ >>> from transformers import FlaxBertModel
414
+
415
+ >>> # load model
416
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
417
+ >>> # By default, the model params will be in fp32, to cast these to float16
418
+ >>> model.params = model.to_fp16(model.params)
419
+ >>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
420
+ >>> # then pass the mask as follows
421
+ >>> from flax import traverse_util
422
+
423
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
424
+ >>> flat_params = traverse_util.flatten_dict(model.params)
425
+ >>> mask = {
426
+ ... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
427
+ ... for path in flat_params
428
+ ... }
429
+ >>> mask = traverse_util.unflatten_dict(mask)
430
+ >>> model.params = model.to_fp16(model.params, mask)
431
+ ```"""
432
+ return self._cast_floating_to(params, jnp.float16, mask)
433
+
434
+ @classmethod
435
+ def load_flax_weights(cls, resolved_archive_file):
436
+ try:
437
+ if resolved_archive_file.endswith(".safetensors"):
438
+ state = safe_load_file(resolved_archive_file)
439
+ state = unflatten_dict(state, sep=".")
440
+ else:
441
+ with open(resolved_archive_file, "rb") as state_f:
442
+ state = from_bytes(cls, state_f.read())
443
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
444
+ try:
445
+ with open(resolved_archive_file) as f:
446
+ if f.read().startswith("version"):
447
+ raise OSError(
448
+ "You seem to have cloned a repository without having git-lfs installed. Please"
449
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
450
+ " folder you cloned."
451
+ )
452
+ else:
453
+ raise ValueError from e
454
+ except (UnicodeDecodeError, ValueError):
455
+ raise EnvironmentError(f"Unable to convert {resolved_archive_file} to Flax deserializable object. ")
456
+
457
+ return state
458
+
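A hedged usage sketch of `load_flax_weights` (the path is a placeholder): it returns the raw nested state dict from either a `.msgpack` or a `.safetensors` file, without building a model instance.

```python
from transformers import FlaxBertModel

state = FlaxBertModel.load_flax_weights("./my_model_directory/flax_model.msgpack")
print(type(state))  # a (possibly nested) dict of arrays
```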
459
+ @classmethod
460
+ def load_flax_sharded_weights(cls, shard_files):
461
+ """
462
+ This is the same as [`flax.serialization.from_bytes`]
463
+ (https://flax.readthedocs.io/en/latest/_modules/flax/serialization.html#from_bytes) but for a sharded checkpoint.
464
+
465
+ This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
466
+ loaded in the model.
467
+
468
+ Args:
469
+ shard_files (`List[str]`):
470
+ The list of shard files to load.
471
+
472
+ Returns:
473
+ `Dict`: A nested dictionary of the model parameters, in the expected format for flax models : `{'model':
474
+ {'params': {'...'}}}`.
475
+ """
476
+
477
+ # Load the index
478
+ state_sharded_dict = {}
479
+
480
+ for shard_file in shard_files:
481
+ # load using msgpack utils
482
+ try:
483
+ with open(shard_file, "rb") as state_f:
484
+ state = from_bytes(cls, state_f.read())
485
+ except (UnpicklingError, msgpack.exceptions.ExtraData) as e:
486
+ with open(shard_file) as f:
487
+ if f.read().startswith("version"):
488
+ raise OSError(
489
+ "You seem to have cloned a repository without having git-lfs installed. Please"
490
+ " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
491
+ " folder you cloned."
492
+ )
493
+ else:
494
+ raise ValueError from e
495
+ except (UnicodeDecodeError, ValueError):
496
+ raise EnvironmentError(f"Unable to convert {shard_file} to Flax deserializable object. ")
497
+
498
+ state = flatten_dict(state, sep="/")
499
+ state_sharded_dict.update(state)
500
+ del state
501
+ gc.collect()
502
+
503
+ # the state dict is unflattened to match the format of model.params
504
+ return unflatten_dict(state_sharded_dict, sep="/")
505
+
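A conceptual sketch of the merge step above, using toy dicts instead of real msgpack shards: each shard is flattened with a `/` separator, accumulated into one mapping, and unflattened back into the nested `params` structure.

```python
from flax.traverse_util import flatten_dict, unflatten_dict

shard_a = {"encoder": {"layer_0": {"kernel": 1}}}
shard_b = {"encoder": {"layer_1": {"kernel": 2}}}

merged = {}
for shard in (shard_a, shard_b):
    merged.update(flatten_dict(shard, sep="/"))

params = unflatten_dict(merged, sep="/")
print(sorted(params["encoder"]))  # ['layer_0', 'layer_1']
```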
506
+ @classmethod
507
+ def can_generate(cls) -> bool:
508
+ """
509
+ Returns whether this model can generate sequences with `.generate()`. Returns:
510
+ `bool`: Whether this model can generate sequences with `.generate()`.
511
+ """
512
+ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
513
+ # Alternatively, the model can also have a custom `generate` function.
514
+ if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
515
+ return False
516
+ return True
517
+
518
+ @classmethod
519
+ def from_pretrained(
520
+ cls,
521
+ pretrained_model_name_or_path: Union[str, os.PathLike],
522
+ dtype: jnp.dtype = jnp.float32,
523
+ *model_args,
524
+ config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
525
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
526
+ ignore_mismatched_sizes: bool = False,
527
+ force_download: bool = False,
528
+ local_files_only: bool = False,
529
+ token: Optional[Union[str, bool]] = None,
530
+ revision: str = "main",
531
+ **kwargs,
532
+ ):
533
+ r"""
534
+ Instantiate a pretrained flax model from a pre-trained model configuration.
535
+
536
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
537
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
538
+ task.
539
+
540
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
541
+ weights are discarded.
542
+
543
+ Parameters:
544
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
545
+ Can be either:
546
+
547
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
548
+ - A path to a *directory* containing model weights saved using
549
+ [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
550
+ - A path or url to a *pt index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In this case,
551
+ `from_pt` should be set to `True`.
552
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
553
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
554
+ `jax.numpy.bfloat16` (on TPUs).
555
+
556
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
557
+ specified all the computation will be performed with the given `dtype`.
558
+
559
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
560
+ parameters.**
561
+
562
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
563
+ [`~FlaxPreTrainedModel.to_bf16`].
564
+ model_args (sequence of positional arguments, *optional*):
565
+ All remaining positional arguments will be passed to the underlying model's `__init__` method.
566
+ config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
567
+ Can be either:
568
+
569
+ - an instance of a class derived from [`PretrainedConfig`],
570
+ - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
571
+
572
+ Configuration for the model to use instead of an automatically loaded configuration. Configuration can
573
+ be automatically loaded when:
574
+
575
+ - The model is a model provided by the library (loaded with the *model id* string of a pretrained
576
+ model).
577
+ - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
578
+ save directory.
579
+ - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
580
+ configuration JSON file named *config.json* is found in the directory.
581
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
582
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
583
+ standard cache should not be used.
584
+ from_pt (`bool`, *optional*, defaults to `False`):
585
+ Load the model weights from a PyTorch checkpoint save file (see docstring of
586
+ `pretrained_model_name_or_path` argument).
587
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
588
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
589
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
590
+ checkpoint with 3 labels).
591
+ force_download (`bool`, *optional*, defaults to `False`):
592
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
593
+ cached versions if they exist.
594
+ resume_download:
595
+ Deprecated and ignored. All downloads are now resumed by default when possible.
596
+ Will be removed in v5 of Transformers.
597
+ proxies (`Dict[str, str]`, *optional*):
598
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
599
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
600
+ local_files_only(`bool`, *optional*, defaults to `False`):
601
+ Whether or not to only look at local files (i.e., do not try to download the model).
602
+ token (`str` or `bool`, *optional*):
603
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
604
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
605
+ revision (`str`, *optional*, defaults to `"main"`):
606
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
607
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
608
+ identifier allowed by git.
609
+
610
+
611
+ <Tip>
612
+
613
+ To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
614
+
615
+ </Tip>
616
+
617
+ subfolder (`str`, *optional*, defaults to `""`):
618
+ In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
619
+ specify the folder name here.
620
+ kwargs (remaining dictionary of keyword arguments, *optional*):
621
+ Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
622
+ `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
623
+ automatically loaded:
624
+
625
+ - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
626
+ underlying model's `__init__` method (we assume all relevant updates to the configuration have
627
+ already been done)
628
+ - If a configuration is not provided, `kwargs` will be first passed to the configuration class
629
+ initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
630
+ corresponds to a configuration attribute will be used to override said attribute with the
631
+ supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
632
+ will be passed to the underlying model's `__init__` function.
633
+
634
+ Examples:
635
+
636
+ ```python
637
+ >>> from transformers import BertConfig, FlaxBertModel
638
+
639
+ >>> # Download model and configuration from huggingface.co and cache.
640
+ >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")
641
+ >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
642
+ >>> model = FlaxBertModel.from_pretrained("./test/saved_model/")
643
+ >>> # Loading from a PyTorch checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
644
+ >>> config = BertConfig.from_json_file("./pt_model/config.json")
645
+ >>> model = FlaxBertModel.from_pretrained("./pt_model/pytorch_model.bin", from_pt=True, config=config)
646
+ ```"""
647
+ from_pt = kwargs.pop("from_pt", False)
648
+ resume_download = kwargs.pop("resume_download", None)
649
+ proxies = kwargs.pop("proxies", None)
650
+ use_auth_token = kwargs.pop("use_auth_token", None)
651
+ trust_remote_code = kwargs.pop("trust_remote_code", None)
652
+ from_pipeline = kwargs.pop("_from_pipeline", None)
653
+ from_auto_class = kwargs.pop("_from_auto", False)
654
+ _do_init = kwargs.pop("_do_init", True)
655
+ subfolder = kwargs.pop("subfolder", "")
656
+ commit_hash = kwargs.pop("_commit_hash", None)
657
+
658
+ # Not relevant for Flax Models
659
+ _ = kwargs.pop("adapter_kwargs", None)
660
+
661
+ if use_auth_token is not None:
662
+ warnings.warn(
663
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
664
+ FutureWarning,
665
+ )
666
+ if token is not None:
667
+ raise ValueError(
668
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
669
+ )
670
+ token = use_auth_token
671
+
672
+ if trust_remote_code is True:
673
+ logger.warning(
674
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
675
+ " ignored."
676
+ )
677
+
678
+ user_agent = {"file_type": "model", "framework": "flax", "from_auto_class": from_auto_class}
679
+ if from_pipeline is not None:
680
+ user_agent["using_pipeline"] = from_pipeline
681
+
682
+ if is_offline_mode() and not local_files_only:
683
+ logger.info("Offline mode: forcing local_files_only=True")
684
+ local_files_only = True
685
+
686
+ # Load config if we don't provide a configuration
687
+ if not isinstance(config, PretrainedConfig):
688
+ config_path = config if config is not None else pretrained_model_name_or_path
689
+ config, model_kwargs = cls.config_class.from_pretrained(
690
+ config_path,
691
+ cache_dir=cache_dir,
692
+ return_unused_kwargs=True,
693
+ force_download=force_download,
694
+ resume_download=resume_download,
695
+ proxies=proxies,
696
+ local_files_only=local_files_only,
697
+ token=token,
698
+ revision=revision,
699
+ subfolder=subfolder,
700
+ _from_auto=from_auto_class,
701
+ _from_pipeline=from_pipeline,
702
+ _commit_hash=commit_hash,
703
+ **kwargs,
704
+ )
705
+ else:
706
+ model_kwargs = kwargs.copy()
707
+
708
+ if commit_hash is None:
709
+ commit_hash = getattr(config, "_commit_hash", None)
710
+
711
+ # Add the dtype to model_kwargs
712
+ model_kwargs["dtype"] = dtype
713
+
714
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
715
+ # index of the files.
716
+ is_sharded = False
717
+
718
+ # Load model
719
+ if pretrained_model_name_or_path is not None:
720
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
721
+ is_local = os.path.isdir(pretrained_model_name_or_path)
722
+ if os.path.isdir(pretrained_model_name_or_path):
723
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
724
+ # Load from a Flax checkpoint
725
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
726
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)):
727
+ # Load from a sharded Flax checkpoint
728
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_INDEX_NAME)
729
+ is_sharded = True
730
+ elif is_safetensors_available() and os.path.isfile(
731
+ os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
732
+ ):
733
+ # Load from a safetensors checkpoint
734
+ archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_NAME)
735
+ elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
736
+ # Load from a PyTorch checkpoint
737
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
738
+ elif from_pt and os.path.isfile(
739
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
740
+ ):
741
+ # Load from a sharded pytorch checkpoint
742
+ archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
743
+ is_sharded = True
744
+ # At this stage we don't have a weight file so we will raise an error.
745
+ elif is_safetensors_available() and os.path.isfile(
746
+ os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
747
+ ):
748
+ # Load from a sharded safetensors checkpoint
749
+ archive_file = os.path.join(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
750
+ is_sharded = True
751
+ raise NotImplementedError("Support for sharded checkpoints using safetensors is coming soon!")
752
+ elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
753
+ raise EnvironmentError(
754
+ f"Error no file named {FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
755
+ "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
756
+ "weights."
757
+ )
758
+ else:
759
+ raise EnvironmentError(
760
+ f"Error no file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME} found in directory "
761
+ f"{pretrained_model_name_or_path}."
762
+ )
763
+ elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
764
+ archive_file = pretrained_model_name_or_path
765
+ is_local = True
766
+ elif is_remote_url(pretrained_model_name_or_path):
767
+ filename = pretrained_model_name_or_path
768
+ resolved_archive_file = download_url(pretrained_model_name_or_path)
769
+ else:
770
+ if from_pt:
771
+ filename = WEIGHTS_NAME
772
+ else:
773
+ filename = FLAX_WEIGHTS_NAME
774
+
775
+ try:
776
+ # Load from URL or cache if already cached
777
+ cached_file_kwargs = {
778
+ "cache_dir": cache_dir,
779
+ "force_download": force_download,
780
+ "proxies": proxies,
781
+ "resume_download": resume_download,
782
+ "local_files_only": local_files_only,
783
+ "token": token,
784
+ "user_agent": user_agent,
785
+ "revision": revision,
786
+ "subfolder": subfolder,
787
+ "_raise_exceptions_for_gated_repo": False,
788
+ "_raise_exceptions_for_missing_entries": False,
789
+ "_commit_hash": commit_hash,
790
+ }
791
+ resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
792
+
793
+ # Maybe the checkpoint is sharded; we try to grab the index name in this case.
794
+ if resolved_archive_file is None and filename == FLAX_WEIGHTS_NAME:
795
+ resolved_archive_file = cached_file(
796
+ pretrained_model_name_or_path, FLAX_WEIGHTS_INDEX_NAME, **cached_file_kwargs
797
+ )
798
+ if resolved_archive_file is not None:
799
+ is_sharded = True
800
+
801
+ # Maybe the checkpoint is a sharded PyTorch checkpoint; we try to grab the PyTorch index name in this case.
802
+ if resolved_archive_file is None and from_pt:
803
+ resolved_archive_file = cached_file(
804
+ pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
805
+ )
806
+ if resolved_archive_file is not None:
807
+ is_sharded = True
808
+
809
+ # If we still haven't found anything, look for `safetensors`.
810
+ if resolved_archive_file is None:
811
+ # No support for sharded safetensors yet, so we'll raise an error if that's all we find.
812
+ filename = SAFE_WEIGHTS_NAME
813
+ resolved_archive_file = cached_file(
814
+ pretrained_model_name_or_path, SAFE_WEIGHTS_NAME, **cached_file_kwargs
815
+ )
816
+
817
+ # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
818
+ # result when internet is up, the repo and revision exist, but the file does not.
819
+ if resolved_archive_file is None:
820
+ # Otherwise, maybe there is a TF or Torch model file. We try those to give a helpful error
821
+ # message.
822
+ has_file_kwargs = {
823
+ "revision": revision,
824
+ "proxies": proxies,
825
+ "token": token,
826
+ "cache_dir": cache_dir,
827
+ "local_files_only": local_files_only,
828
+ }
829
+ if has_file(pretrained_model_name_or_path, SAFE_WEIGHTS_INDEX_NAME, **has_file_kwargs):
830
+ is_sharded = True
831
+ raise NotImplementedError(
832
+ "Support for sharded checkpoints using safetensors is coming soon!"
833
+ )
834
+ elif has_file(pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs):
835
+ raise EnvironmentError(
836
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
837
+ f" {FLAX_WEIGHTS_NAME} but there is a file for PyTorch weights. Use `from_pt=True` to"
838
+ " load this model from those weights."
839
+ )
840
+ elif has_file(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **has_file_kwargs):
841
+ raise EnvironmentError(
842
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
843
+ f" {FLAX_WEIGHTS_INDEX_NAME} but there is a sharded file for PyTorch weights. Use"
844
+ " `from_pt=True` to load this model from those weights."
845
+ )
846
+ else:
847
+ raise EnvironmentError(
848
+ f"{pretrained_model_name_or_path} does not appear to have a file named"
849
+ f" {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
850
+ )
851
+ except EnvironmentError:
852
+ # Raise any environment error raised by `cached_file`. It will have a helpful error message adapted
853
+ # to the original exception.
854
+ raise
855
+ except Exception:
856
+ # For any other exception, we throw a generic error.
857
+ raise EnvironmentError(
858
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
859
+ " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
860
+ f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
861
+ f" directory containing a file named {FLAX_WEIGHTS_NAME} or {WEIGHTS_NAME}."
862
+ )
863
+
864
+ if is_local:
865
+ logger.info(f"loading weights file {archive_file}")
866
+ resolved_archive_file = archive_file
867
+ filename = resolved_archive_file.split(os.path.sep)[-1]
868
+ else:
869
+ logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
870
+ else:
871
+ resolved_archive_file = None
872
+
873
+ # We'll need to download and cache each checkpoint shard if the checkpoint is sharded.
874
+ if is_sharded:
875
+ # resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.
876
+ resolved_archive_file, _ = get_checkpoint_shard_files(
877
+ pretrained_model_name_or_path,
878
+ resolved_archive_file,
879
+ cache_dir=cache_dir,
880
+ force_download=force_download,
881
+ proxies=proxies,
882
+ resume_download=resume_download,
883
+ local_files_only=local_files_only,
884
+ token=token,
885
+ user_agent=user_agent,
886
+ revision=revision,
887
+ subfolder=subfolder,
888
+ _commit_hash=commit_hash,
889
+ )
890
+
891
+ safetensors_from_pt = False
892
+ if filename == SAFE_WEIGHTS_NAME:
893
+ with safe_open(resolved_archive_file, framework="flax") as f:
894
+ safetensors_metadata = f.metadata()
895
+ if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
896
+ raise OSError(
897
+ f"The safetensors archive passed at {resolved_archive_file} does not contain the valid metadata."
898
+ " Make sure you save your model with the `save_pretrained` method."
899
+ )
900
+ safetensors_from_pt = safetensors_metadata.get("format") == "pt"
901
+
902
+ # init random models
903
+ model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
904
+
905
+ if from_pt or safetensors_from_pt:
906
+ state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
907
+ else:
908
+ if is_sharded:
909
+ state = cls.load_flax_sharded_weights(resolved_archive_file)
910
+ else:
911
+ state = cls.load_flax_weights(resolved_archive_file)
912
+ # make sure all arrays are stored as jnp.arrays
913
+ # NOTE: This is to prevent a bug that will be fixed in Flax >= v0.3.4:
914
+ # https://github.com/google/flax/issues/1261
915
+ if _do_init:
916
+ state = jax.tree_util.tree_map(jnp.array, state)
917
+ else:
918
+ # keep the params on CPU if we don't want to initialize
919
+ state = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.local_devices(backend="cpu")[0]), state)
920
+
921
+ if "batch_stats" in state: # if flax model contains batch norm layers
922
+ # if model is base model only use model_prefix key
923
+ if (
924
+ cls.base_model_prefix not in dict(model.params_shape_tree["params"])
925
+ and cls.base_model_prefix in state["params"]
926
+ ):
927
+ state["params"] = state["params"][cls.base_model_prefix]
928
+ state["batch_stats"] = state["batch_stats"][cls.base_model_prefix]
929
+
930
+ # if model is head model and we are loading weights from base model
931
+ # we initialize new params dict with base_model_prefix
932
+ if (
933
+ cls.base_model_prefix in dict(model.params_shape_tree["params"])
934
+ and cls.base_model_prefix not in state["params"]
935
+ ):
936
+ state = {
937
+ "params": {cls.base_model_prefix: state["params"]},
938
+ "batch_stats": {cls.base_model_prefix: state["batch_stats"]},
939
+ }
940
+
941
+ else:
942
+ # if model is base model only use model_prefix key
943
+ if cls.base_model_prefix not in dict(model.params_shape_tree) and cls.base_model_prefix in state:
944
+ state = state[cls.base_model_prefix]
945
+
946
+ # if model is head model and we are loading weights from base model
947
+ # we initialize new params dict with base_model_prefix
948
+ if cls.base_model_prefix in dict(model.params_shape_tree) and cls.base_model_prefix not in state:
949
+ state = {cls.base_model_prefix: state}
950
+
951
+ # flatten dicts
952
+ state = flatten_dict(state)
953
+
954
+ random_state = flatten_dict(unfreeze(model.params if _do_init else model.params_shape_tree))
955
+
956
+ missing_keys = model.required_params - set(state.keys())
957
+ unexpected_keys = set(state.keys()) - model.required_params
958
+
959
+ # Disable the warning when porting PyTorch weights to Flax; Flax does not use num_batches_tracked
960
+ for unexpected_key in unexpected_keys.copy():
961
+ if "num_batches_tracked" in unexpected_key[-1]:
962
+ unexpected_keys.remove(unexpected_key)
963
+
964
+ if missing_keys and not _do_init:
965
+ logger.warning(
966
+ f"The checkpoint {pretrained_model_name_or_path} is missing required keys: {missing_keys}. "
967
+ "Make sure to call model.init_weights to initialize the missing weights."
968
+ )
969
+ cls._missing_keys = missing_keys
970
+
971
+ # Mismatched keys contain tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
972
+ # matching the weights in the model.
973
+ mismatched_keys = []
974
+ for key in state.keys():
975
+ if key in random_state and state[key].shape != random_state[key].shape:
976
+ if ignore_mismatched_sizes:
977
+ mismatched_keys.append((key, state[key].shape, random_state[key].shape))
978
+ state[key] = random_state[key]
979
+ else:
980
+ raise ValueError(
981
+ f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
982
+ f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
983
+ "Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
984
+ "model."
985
+ )
986
+
987
+ # add missing keys as random parameters if we are initializing
988
+ if missing_keys and _do_init:
989
+ for missing_key in missing_keys:
990
+ state[missing_key] = random_state[missing_key]
991
+
992
+ # remove unexpected keys to not be saved again
993
+ for unexpected_key in unexpected_keys:
994
+ del state[unexpected_key]
995
+
996
+ if len(unexpected_keys) > 0:
997
+ logger.warning(
998
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
999
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1000
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
1001
+ " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1002
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1003
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
1004
+ " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
1005
+ )
1006
+ else:
1007
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1008
+
1009
+ if len(missing_keys) > 0:
1010
+ logger.warning(
1011
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1012
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1013
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1014
+ )
1015
+ elif len(mismatched_keys) == 0:
1016
+ logger.info(
1017
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1018
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
1019
+ f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
1020
+ " training."
1021
+ )
1022
+ if len(mismatched_keys) > 0:
1023
+ mismatched_warning = "\n".join(
1024
+ [
1025
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1026
+ for key, shape1, shape2 in mismatched_keys
1027
+ ]
1028
+ )
1029
+ logger.warning(
1030
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1031
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1032
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
1033
+ " to use it for predictions and inference."
1034
+ )
1035
+
1036
+ # dictionary of key: dtypes for the model params
1037
+ param_dtypes = jax.tree_util.tree_map(lambda x: x.dtype, state)
1038
+ # extract keys of parameters not in jnp.float32
1039
+ fp16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.float16]
1040
+ bf16_params = [k for k in param_dtypes if param_dtypes[k] == jnp.bfloat16]
1041
+
1042
+ # raise a warning if any of the parameters are not in jnp.float32
1043
+ if len(fp16_params) > 0:
1044
+ logger.warning(
1045
+ f"Some of the weights of {model.__class__.__name__} were initialized in float16 precision from "
1046
+ f"the model checkpoint at {pretrained_model_name_or_path}:\n{fp16_params}\n"
1047
+ "You should probably UPCAST the model weights to float32 if this was not intended. "
1048
+ "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
1049
+ )
1050
+
1051
+ if len(bf16_params) > 0:
1052
+ logger.warning(
1053
+ f"Some of the weights of {model.__class__.__name__} were initialized in bfloat16 precision from "
1054
+ f"the model checkpoint at {pretrained_model_name_or_path}:\n{bf16_params}\n"
1055
+ "You should probably UPCAST the model weights to float32 if this was not intended. "
1056
+ "See [`~FlaxPreTrainedModel.to_fp32`] for further information on how to do this."
1057
+ )
1058
+
1059
+ # If it is a model with generation capabilities, attempt to load the generation config
1060
+ if model.can_generate():
1061
+ try:
1062
+ model.generation_config = GenerationConfig.from_pretrained(
1063
+ pretrained_model_name_or_path,
1064
+ cache_dir=cache_dir,
1065
+ force_download=force_download,
1066
+ resume_download=resume_download,
1067
+ proxies=proxies,
1068
+ local_files_only=local_files_only,
1069
+ token=token,
1070
+ revision=revision,
1071
+ subfolder=subfolder,
1072
+ _from_auto=from_auto_class,
1073
+ _from_pipeline=from_pipeline,
1074
+ **kwargs,
1075
+ )
1076
+ except OSError:
1077
+ logger.info(
1078
+ "Generation config file not found, using a generation config created from the model config."
1079
+ )
1080
+ pass
1081
+
1082
+ if _do_init:
1083
+ # set correct parameters
1084
+ model.params = unflatten_dict(state)
1085
+ return model
1086
+ else:
1087
+ return model, unflatten_dict(state)
1088
+
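A hedged summary sketch of the loading branches above (the checkpoint name is the one already used in the docstrings): native Flax weights load directly, while `from_pt=True` routes through the PyTorch-to-Flax conversion.

```python
from transformers import FlaxBertModel

# Native Flax weights (flax_model.msgpack or model.safetensors in the repo).
model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

# PyTorch weights, converted on the fly via load_pytorch_checkpoint_in_flax_state_dict.
model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased", from_pt=True)
```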
1089
+ def save_pretrained(
1090
+ self,
1091
+ save_directory: Union[str, os.PathLike],
1092
+ params=None,
1093
+ push_to_hub=False,
1094
+ max_shard_size="10GB",
1095
+ token: Optional[Union[str, bool]] = None,
1096
+ safe_serialization: bool = False,
1097
+ **kwargs,
1098
+ ):
1099
+ """
1100
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
1101
+ [`~FlaxPreTrainedModel.from_pretrained`] class method.
1102
+
1103
+ Arguments:
1104
+ save_directory (`str` or `os.PathLike`):
1105
+ Directory to which to save. Will be created if it doesn't exist.
1106
+ push_to_hub (`bool`, *optional*, defaults to `False`):
1107
+ Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
1108
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
1109
+ namespace).
1110
+ max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
1111
+ The maximum size for a checkpoint before being sharded. Each checkpoint shard will then have a size
1112
+ lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
1113
+
1114
+ <Tip warning={true}>
1115
+
1116
+ If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
1117
+ which will be bigger than `max_shard_size`.
1118
+
1119
+ </Tip>
1120
+
1121
+ token (`str` or `bool`, *optional*):
1122
+ The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
1123
+ the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
1124
+ kwargs (`Dict[str, Any]`, *optional*):
1125
+ Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
1126
+ safe_serialization (`bool`, *optional*, defaults to `False`):
1127
+ Whether to save the model using `safetensors` or through msgpack.
1128
+ """
1129
+ use_auth_token = kwargs.pop("use_auth_token", None)
1130
+
1131
+ if use_auth_token is not None:
1132
+ warnings.warn(
1133
+ "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
1134
+ FutureWarning,
1135
+ )
1136
+ if token is not None:
1137
+ raise ValueError(
1138
+ "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
1139
+ )
1140
+ token = use_auth_token
1141
+
1142
+ if token is not None:
1143
+ kwargs["token"] = token
1144
+
1145
+ if os.path.isfile(save_directory):
1146
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
1147
+ return
1148
+
1149
+ os.makedirs(save_directory, exist_ok=True)
1150
+
1151
+ if push_to_hub:
1152
+ commit_message = kwargs.pop("commit_message", None)
1153
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
1154
+ repo_id = self._create_repo(repo_id, **kwargs)
1155
+ files_timestamps = self._get_files_timestamps(save_directory)
1156
+
1157
+ # get abs dir
1158
+ save_directory = os.path.abspath(save_directory)
1159
+ # save config as well
1160
+ self.config.architectures = [self.__class__.__name__[4:]]
1161
+
1162
+ # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
1163
+ # loaded from the Hub.
1164
+ if self._auto_class is not None:
1165
+ custom_object_save(self, save_directory, config=self.config)
1166
+
1167
+ self.config.save_pretrained(save_directory)
1168
+ if self.can_generate():
1169
+ self.generation_config.save_pretrained(save_directory)
1170
+
1171
+ # save model
1172
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else FLAX_WEIGHTS_NAME
1173
+ output_model_file = os.path.join(save_directory, weights_name)
1174
+
1175
+ shards, index = flax_shard_checkpoint(params if params is not None else self.params, max_shard_size)
1176
+ # Clean the folder from a previous save
1177
+ for filename in os.listdir(save_directory):
1178
+ full_filename = os.path.join(save_directory, filename)
1179
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
1180
+ if (
1181
+ filename.startswith(weights_no_suffix)
1182
+ and os.path.isfile(full_filename)
1183
+ and filename not in shards.keys()
1184
+ ):
1185
+ os.remove(full_filename)
1186
+
1187
+ if index is None:
1188
+ if safe_serialization:
1189
+ params = params if params is not None else self.params
1190
+ flat_dict = flatten_dict(params, sep=".")
1191
+ safe_save_file(flat_dict, output_model_file, metadata={"format": "flax"})
1192
+ else:
1193
+ with open(output_model_file, "wb") as f:
1194
+ params = params if params is not None else self.params
1195
+ model_bytes = to_bytes(params)
1196
+ f.write(model_bytes)
1197
+
1198
+ else:
1199
+ save_index_file = os.path.join(save_directory, FLAX_WEIGHTS_INDEX_NAME)
1200
+ # Save the index as well
1201
+ with open(save_index_file, "w", encoding="utf-8") as f:
1202
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
1203
+ f.write(content)
1204
+ logger.info(
1205
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
1206
+ f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
1207
+ f"index located at {save_index_file}."
1208
+ )
1209
+ for shard_file, shard in shards.items():
1210
+ # the shard items are flattened; to save them we need to unflatten them before serializing
1211
+ with open(os.path.join(save_directory, shard_file), mode="wb") as f:
1212
+ params = unflatten_dict(shard, sep="/")
1213
+ shard_bytes = to_bytes(params)
1214
+ f.write(shard_bytes)
1215
+
1216
+ logger.info(f"Model weights saved in {output_model_file}")
1217
+
1218
+ if push_to_hub:
1219
+ self._upload_modified_files(
1220
+ save_directory,
1221
+ repo_id,
1222
+ files_timestamps,
1223
+ commit_message=commit_message,
1224
+ token=token,
1225
+ )
1226
+
1227
+ @classmethod
1228
+ def register_for_auto_class(cls, auto_class="FlaxAutoModel"):
1229
+ """
1230
+ Register this class with a given auto class. This should only be used for custom models as the ones in the
1231
+ library are already mapped with an auto class.
1232
+
1233
+ <Tip warning={true}>
1234
+
1235
+ This API is experimental and may have some slight breaking changes in the next releases.
1236
+
1237
+ </Tip>
1238
+
1239
+ Args:
1240
+ auto_class (`str` or `type`, *optional*, defaults to `"FlaxAutoModel"`):
1241
+ The auto class to register this new model with.
1242
+ """
1243
+ if not isinstance(auto_class, str):
1244
+ auto_class = auto_class.__name__
1245
+
1246
+ import transformers.models.auto as auto_module
1247
+
1248
+ if not hasattr(auto_module, auto_class):
1249
+ raise ValueError(f"{auto_class} is not a valid auto class.")
1250
+
1251
+ cls._auto_class = auto_class
1252
+
1253
+
1254
+ # To update the docstring, we need to copy the method, otherwise we change the original docstring.
1255
+ FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
1256
+ if FlaxPreTrainedModel.push_to_hub.__doc__ is not None:
1257
+ FlaxPreTrainedModel.push_to_hub.__doc__ = FlaxPreTrainedModel.push_to_hub.__doc__.format(
1258
+ object="model", object_class="FlaxAutoModel", object_files="model checkpoint"
1259
+ )
1260
+
1261
+
1262
+ def overwrite_call_docstring(model_class, docstring):
1263
+ # copy __call__ function to be sure docstring is changed only for this function
1264
+ model_class.__call__ = copy_func(model_class.__call__)
1265
+ # delete existing docstring
1266
+ model_class.__call__.__doc__ = None
1267
+ # set correct docstring
1268
+ model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
1269
+
1270
+
1271
+ def append_call_sample_docstring(
1272
+ model_class, checkpoint, output_type, config_class, mask=None, revision=None, real_checkpoint=None
1273
+ ):
1274
+ model_class.__call__ = copy_func(model_class.__call__)
1275
+ model_class.__call__ = add_code_sample_docstrings(
1276
+ checkpoint=checkpoint,
1277
+ output_type=output_type,
1278
+ config_class=config_class,
1279
+ model_cls=model_class.__name__,
1280
+ revision=revision,
1281
+ real_checkpoint=real_checkpoint,
1282
+ )(model_class.__call__)
1283
+
1284
+
1285
+ def append_replace_return_docstrings(model_class, output_type, config_class):
1286
+ model_class.__call__ = copy_func(model_class.__call__)
1287
+ model_class.__call__ = replace_return_docstrings(
1288
+ output_type=output_type,
1289
+ config_class=config_class,
1290
+ )(model_class.__call__)
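Taken together, the saving logic above supports both sharding and safetensors output. A minimal usage sketch, assuming Flax is installed; the checkpoint name and target directories are illustrative placeholders, not part of this commit:

    # Illustrative sketch only; "google-bert/bert-base-cased" and the output paths are example values.
    from transformers import FlaxBertModel

    model = FlaxBertModel.from_pretrained("google-bert/bert-base-cased")

    # Forces the sharded branch above: shards of at most ~100MB plus a JSON index file.
    model.save_pretrained("./bert-flax-sharded", max_shard_size="100MB")

    # Single-file checkpoint serialized with safetensors instead of msgpack.
    model.save_pretrained("./bert-flax-safe", safe_serialization=True)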
modeling_gguf_pytorch_utils.py ADDED
@@ -0,0 +1,471 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The ggml.ai team and The HuggingFace Inc. team. and pygguf author (github.com/99991)
3
+ # https://github.com/99991/pygguf
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import re
18
+ from typing import Dict, NamedTuple, Optional
19
+
20
+ import numpy as np
21
+ from tqdm import tqdm
22
+
23
+ from .integrations import (
24
+ GGUF_CONFIG_MAPPING,
25
+ GGUF_TOKENIZER_MAPPING,
26
+ _gguf_parse_value,
27
+ )
28
+ from .utils import is_torch_available
29
+ from .utils.import_utils import is_gguf_available
30
+ from .utils.logging import get_logger
31
+
32
+
33
+ if is_torch_available():
34
+ import torch
35
+
36
+ logger = get_logger(__name__)
37
+
38
+
39
+ GGUF_TO_TRANSFORMERS_MAPPING = {
40
+ "ignore": {
41
+ "GGUF": {
42
+ "version": "version",
43
+ "tensor_count": "tensor_count",
44
+ "kv_count": "kv_count",
45
+ },
46
+ "general": {"file_type": "file_type", "quantization_version": "quantization_version"},
47
+ },
48
+ "config": GGUF_CONFIG_MAPPING,
49
+ "tokenizer": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer"]},
50
+ "tokenizer_config": {"tokenizer": GGUF_TOKENIZER_MAPPING["tokenizer_config"]},
51
+ }
52
+
53
+ GGUF_SUPPORTED_ARCHITECTURES = list(GGUF_TO_TRANSFORMERS_MAPPING["config"].keys())
54
+
55
+
56
+ class GGUFTensor(NamedTuple):
57
+ weights: np.ndarray
58
+ name: str
59
+ metadata: dict
60
+
61
+
62
+ class TensorProcessor:
63
+ def __init__(self, config=None):
64
+ self.config = config or {}
65
+
66
+ def process(self, weights, name, **kwargs):
67
+ return GGUFTensor(weights, name, {})
68
+
69
+
70
+ class LlamaTensorProcessor(TensorProcessor):
71
+ def __init__(self, config=None):
72
+ super().__init__(config=config)
73
+
74
+ def process(self, weights, name, **kwargs):
75
+ if ".attn_k." in name or ".attn_q." in name:
76
+ num_heads = self.config.get("num_attention_heads")
77
+ num_kv_heads = self.config.get("num_key_value_heads")
78
+
79
+ if None in (num_heads, num_kv_heads):
80
+ return GGUFTensor(weights, name, {})
81
+ if ".attn_q." in name:
82
+ weights = self._reverse_permute_weights(weights, num_heads, num_heads)
83
+ elif ".attn_k." in name:
84
+ weights = self._reverse_permute_weights(weights, num_heads, num_kv_heads)
85
+ return GGUFTensor(weights, name, {})
86
+
87
+ def _reverse_permute_weights(
88
+ self, weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None
89
+ ) -> np.ndarray:
90
+ # Original permutation implementation
91
+ # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
92
+ if num_kv_heads is not None and n_head != num_kv_heads:
93
+ n_head = num_kv_heads
94
+
95
+ dim = weights.shape[0] // n_head // 2
96
+ w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
97
+ return w.swapaxes(2, 1).reshape(weights.shape)
98
+
99
+
100
+ class Qwen2MoeTensorProcessor(TensorProcessor):
101
+ def __init__(self, config=None):
102
+ super().__init__(config=config)
103
+
104
+ def process(self, weights, name, **kwargs):
105
+ if "_exp" in name:
106
+ tensor_key_mapping = kwargs.get("tensor_key_mapping")
107
+ parsed_parameters = kwargs.get("parsed_parameters")
108
+ if tensor_key_mapping:
109
+ self._split_moe_expert_tensor(weights, parsed_parameters, name, tensor_key_mapping)
110
+ return GGUFTensor(weights, None, {})
111
+ if "ffn_gate_inp_shexp" in name:
112
+ # for compatibility tensor shared_expert_gate must be (1, 2048) dim,
113
+ # quantized one is (2048)
114
+ weights = np.expand_dims(weights, axis=0)
115
+ return GGUFTensor(weights, name, {})
116
+
117
+ def _split_moe_expert_tensor(
118
+ self, weights: np.ndarray, parsed_parameters: Dict[str, Dict], name: str, tensor_key_mapping: dict
119
+ ):
120
+ # Original merge implementation
121
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L1994-L2022
122
+ name = tensor_key_mapping[name]
123
+ w_counter = self.config.get("num_experts", 60)
124
+ for i in range(0, w_counter):
125
+ temp_name = name.replace("mlp.experts.", f"mlp.experts.{i}.")
126
+ exp_weight = weights[i]
127
+ parsed_parameters["tensors"][temp_name] = torch.from_numpy(np.copy(exp_weight))
128
+
129
+
130
+ class BloomTensorProcessor(TensorProcessor):
131
+ def __init__(self, config=None):
132
+ super().__init__(config=config)
133
+
134
+ def process(self, weights, name, **kwargs):
135
+ if "attn_qkv" in name:
136
+ num_heads = self.config["n_head"]
137
+ n_embed = self.config["hidden_size"]
138
+ if "weight" in name:
139
+ weights = self._reverse_reshape_weights(weights, num_heads, n_embed)
140
+ else:
141
+ weights = self._reverse_reshape_bias(weights, num_heads, n_embed)
142
+ return GGUFTensor(weights, name, {})
143
+
144
+ def _reverse_reshape_weights(self, weights: np.ndarray, n_head: int, n_embed: int):
145
+ # Original reshape implementation
146
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L972-L985
147
+ q, k, v = np.array_split(weights, 3, axis=0)
148
+
149
+ q = q.reshape(n_head, n_embed // n_head, n_embed)
150
+ k = k.reshape(n_head, n_embed // n_head, n_embed)
151
+ v = v.reshape(n_head, n_embed // n_head, n_embed)
152
+ qkv_weights = np.stack([q, k, v], axis=1)
153
+
154
+ return qkv_weights.reshape(n_head * 3 * (n_embed // n_head), n_embed)
155
+
156
+ def _reverse_reshape_bias(self, weights: np.ndarray, n_head: int, n_embed: int):
157
+ # Original reshape implementation
158
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L986-L998
159
+ q_bias, k_bias, v_bias = np.array_split(weights, 3)
160
+
161
+ q_bias = q_bias.reshape(n_head, n_embed // n_head)
162
+ k_bias = k_bias.reshape(n_head, n_embed // n_head)
163
+ v_bias = v_bias.reshape(n_head, n_embed // n_head)
164
+
165
+ qkv_bias = np.stack([q_bias, k_bias, v_bias], axis=1).flatten()
166
+ return qkv_bias
167
+
168
+
169
+ class T5TensorProcessor(TensorProcessor):
170
+ def __init__(self, config=None):
171
+ super().__init__(config=config)
172
+
173
+ def process(self, weights, name, **kwargs):
174
+ bid = None
175
+ for chunk in name.split("."):
176
+ if chunk.isdigit():
177
+ bid = int(chunk)
178
+ break
179
+ return GGUFTensor(weights, name, {"bid": bid})
180
+
181
+
182
+ class GPT2TensorProcessor(TensorProcessor):
183
+ def __init__(self, config=None):
184
+ super().__init__(config=config)
185
+
186
+ def process(self, weights, name, **kwargs):
187
+ # Original transpose implementation
188
+ # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L2060-L2061
189
+ if (
190
+ "attn_qkv.weight" in name
191
+ or "ffn_down.weight" in name
192
+ or "ffn_up.weight" in name
193
+ or "attn_output.weight" in name
194
+ ):
195
+ weights = weights.T
196
+
197
+ # Handle special case for output.weight
198
+ if name == "output.weight":
199
+ # output.weight has conflicts with attn_output.weight in name checking
200
+ # Store the tensor directly and signal to skip further processing
201
+ name = "lm_head.weight"
202
+ parsed_parameters = kwargs.get("parsed_parameters", {})
203
+ parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
204
+ name = None # Signal to skip further processing
205
+ return GGUFTensor(weights, name, {})
206
+
207
+
208
+ class MambaTensorProcessor(TensorProcessor):
209
+ def __init__(self, config=None):
210
+ super().__init__(config=config)
211
+
212
+ def process(self, weights, name, **kwargs):
213
+ if "ssm_conv1d.weight" in name:
214
+ # for compatibility, tensor ssm_conv1d must be (5120, 1, 4) dim,
215
+ # quantized one is (5120, 4)
216
+ weights = np.expand_dims(weights, axis=1)
217
+ if "ssm_a" in name:
218
+ # Original exponential implementation
219
+ # https://github.com/ggerganov/llama.cpp/blob/master/convert_hf_to_gguf.py#L2975-L2977
220
+ weights = np.log(-weights)
221
+ return GGUFTensor(weights, name, {})
222
+
223
+
224
+ class Gemma2TensorProcessor(TensorProcessor):
225
+ def __init__(self, config=None):
226
+ super().__init__(config=config)
227
+
228
+ # ref: https://github.com/ggerganov/llama.cpp/blob/d79d8f39b4da6deca4aea8bf130c6034c482b320/convert_hf_to_gguf.py#L3191
229
+ # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
230
+ def process(self, weights, name, **kwargs):
231
+ if "norm.weight" in name:
232
+ weights = weights - 1
233
+ return GGUFTensor(weights, name, {})
234
+
235
+
236
+ TENSOR_PROCESSORS = {
237
+ "llama": LlamaTensorProcessor,
238
+ "qwen2moe": Qwen2MoeTensorProcessor,
239
+ "bloom": BloomTensorProcessor,
240
+ "t5": T5TensorProcessor,
241
+ "t5encoder": T5TensorProcessor,
242
+ "gpt2": GPT2TensorProcessor,
243
+ "mamba": MambaTensorProcessor,
244
+ "gemma2": Gemma2TensorProcessor,
245
+ }
246
+
247
+
248
+ def read_field(reader, field):
249
+ value = reader.fields[field]
250
+ return [_gguf_parse_value(value.parts[_data_index], value.types) for _data_index in value.data]
251
+
252
+
253
+ # modified from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/loader.py#L1115-L1147
254
+ def get_gguf_hf_weights_map(
255
+ hf_model,
256
+ model_type: Optional[str] = None,
257
+ num_layers: Optional[int] = None,
258
+ qual_name: str = "",
259
+ ):
260
+ """
261
+ GGUF uses this naming convention for their tensors from HF checkpoint:
262
+ `blk.N.BB.weight` and `blk.N.BB.bias`
263
+ where N signifies the block number of a layer, and BB signifies the
264
+ attention/mlp layer components.
265
+ See "Standardized tensor names" in
266
+ https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
267
+ """
268
+ if is_gguf_available() and is_torch_available():
269
+ from gguf import MODEL_ARCH_NAMES, get_tensor_name_map
270
+ else:
271
+ logger.error(
272
+ "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
273
+ "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
274
+ )
275
+ raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
276
+
277
+ model_type = hf_model.config.model_type if model_type is None else model_type
278
+ num_layers = hf_model.config.num_hidden_layers if num_layers is None else num_layers
279
+ # hack: ggufs have a different name for cohere
280
+ if model_type == "cohere":
281
+ model_type = "command-r"
282
+ if model_type == "qwen2_moe":
283
+ model_type = "qwen2moe"
284
+ arch = None
285
+ for key, value in MODEL_ARCH_NAMES.items():
286
+ if value == model_type:
287
+ arch = key
288
+ break
289
+ if arch is None:
290
+ raise NotImplementedError(
291
+ f"Unknown gguf model_type: {model_type} in gguf-py. "
292
+ "This might because you're using an outdated version of gguf-py package, "
293
+ "you can install `gguf` package from source refer to "
294
+ "https://github.com/ggerganov/llama.cpp/tree/master/gguf-py#development"
295
+ )
296
+ name_map = get_tensor_name_map(arch, num_layers)
297
+
298
+ # Use a dummy conversion to get the mapping, because
299
+ # hf => gguf and gguf => hf mappings are reversed
300
+ gguf_to_hf_name_map = {}
301
+ state_dict = hf_model.state_dict()
302
+ for hf_name in state_dict.keys():
303
+ # An exception for qwen2moe model, where the expert layers are packed
304
+ if model_type == "qwen2moe" and "mlp.experts." in hf_name:
305
+ hf_name = re.sub(r"mlp.experts.\d+.", "mlp.experts.", hf_name)
306
+
307
+ name, suffix = hf_name, ""
308
+ if hf_name.endswith(".weight") or hf_name.endswith(".bias"):
309
+ name, suffix = hf_name.rsplit(".", 1)
310
+ suffix = "." + suffix
311
+
312
+ gguf_name = name_map.get_name(name)
313
+ if gguf_name is None:
314
+ continue
315
+
316
+ gguf_to_hf_name_map[gguf_name + suffix] = qual_name + hf_name
317
+
318
+ # Some models, like Bloom, are converted from BloomModel instead of BloomForCausalLM.
319
+ # Therefore, we need to check submodules as well to get a correct mapping.
320
+ if named_children := hf_model.named_children():
321
+ for name, child in named_children:
322
+ sub_map = get_gguf_hf_weights_map(child, model_type, num_layers, qual_name=f"{qual_name}{name}.")
323
+ # Ignore the keys that are already in the main map to avoid overwriting
324
+ sub_map = {k: v for k, v in sub_map.items() if k not in gguf_to_hf_name_map}
325
+ gguf_to_hf_name_map.update(sub_map)
326
+
327
+ return gguf_to_hf_name_map
328
+
329
+
330
+ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_load=None):
331
+ """
332
+ Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed
333
+ tokenizer and config attributes.
334
+
335
+ Args:
336
+ gguf_checkpoint_path (`str`):
337
+ The path to the GGUF file to load.
338
+ return_tensors (`bool`, *optional*, defaults to `False`):
339
+ Whether to read the tensors from the file and return them. Not doing so is faster
340
+ and only loads the metadata in memory.
341
+ """
342
+ if is_gguf_available() and is_torch_available():
343
+ from gguf import GGUFReader, dequantize
344
+ else:
345
+ logger.error(
346
+ "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
347
+ "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
348
+ )
349
+ raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
350
+
351
+ reader = GGUFReader(gguf_checkpoint_path)
352
+ fields = reader.fields
353
+ reader_keys = list(fields.keys())
354
+
355
+ parsed_parameters = {k: {} for k in GGUF_TO_TRANSFORMERS_MAPPING}
356
+
357
+ architecture = read_field(reader, "general.architecture")[0]
358
+ model_name = read_field(reader, "general.name")
359
+
360
+ # in llama.cpp mistral models use the same architecture as llama. We need
361
+ # to add this patch to ensure things work correctly on our side.
362
+ if "llama" in architecture and "mistral" in model_name:
363
+ updated_architecture = "mistral"
364
+ # FIXME: Currnetly this implementation is only for flan-t5 architecture.
365
+ # It needs to be developed for supporting legacy t5.
366
+ elif "t5" in architecture or "t5encoder" in architecture:
367
+ parsed_parameters["config"]["is_gated_act"] = True
368
+ updated_architecture = "t5"
369
+ else:
370
+ updated_architecture = architecture
371
+
372
+ if "qwen2moe" in architecture:
373
+ updated_architecture = "qwen2_moe"
374
+
375
+ # For stablelm architecture, we need to set qkv_bias and use_parallel_residual from tensors
376
+ # If `qkv_bias=True`, qkv_proj with bias will be present in the tensors
377
+ # If `use_parallel_residual=False`, ffn_norm will be present in the tensors
378
+ if "stablelm" in architecture:
379
+ attn_bias_name = {"attn_q.bias", "attn_k.bias", "attn_v.bias"}
380
+ ffn_norm_name = "ffn_norm"
381
+ qkv_bias = any(bias_name in tensor.name for tensor in reader.tensors for bias_name in attn_bias_name)
382
+ use_parallel_residual = any(ffn_norm_name in tensor.name for tensor in reader.tensors)
383
+ parsed_parameters["config"]["use_qkv_bias"] = qkv_bias
384
+ parsed_parameters["config"]["use_parallel_residual"] = not use_parallel_residual
385
+
386
+ if architecture not in GGUF_SUPPORTED_ARCHITECTURES:
387
+ raise ValueError(f"GGUF model with architecture {architecture} is not supported yet.")
388
+
389
+ # Handle tie_word_embeddings: if lm_head.weight is not present in the tensors,
390
+ # tie_word_embeddings is True, otherwise False
391
+ parsed_parameters["config"]["tie_word_embeddings"] = all(
392
+ "output.weight" != tensor.name for tensor in reader.tensors
393
+ )
394
+
395
+ # Map the GGUF metadata key-value pairs onto the config/tokenizer parameters
396
+ for gguf_key, field in reader.fields.items():
397
+ gguf_key = gguf_key.replace(architecture, updated_architecture)
398
+ split = gguf_key.split(".")
399
+ prefix = split[0]
400
+ config_key = ".".join(split[1:])
401
+
402
+ value = [_gguf_parse_value(field.parts[_data_index], field.types) for _data_index in field.data]
403
+
404
+ if len(value) == 1:
405
+ value = value[0]
406
+
407
+ if isinstance(value, str) and architecture in value:
408
+ value = value.replace(architecture, updated_architecture)
409
+
410
+ for parameter in GGUF_TO_TRANSFORMERS_MAPPING:
411
+ parameter_renames = GGUF_TO_TRANSFORMERS_MAPPING[parameter]
412
+ if prefix in parameter_renames and config_key in parameter_renames[prefix]:
413
+ renamed_config_key = parameter_renames[prefix][config_key]
414
+ if renamed_config_key == -1:
415
+ continue
416
+
417
+ if renamed_config_key is not None:
418
+ parsed_parameters[parameter][renamed_config_key] = value
419
+
420
+ if gguf_key in reader_keys:
421
+ reader_keys.remove(gguf_key)
422
+
423
+ if gguf_key in reader_keys:
424
+ logger.info(f"Some keys were not parsed and added into account {gguf_key} | {value}")
425
+
426
+ # retrieve config vocab_size from tokenizer
427
+ # Please refer to https://github.com/huggingface/transformers/issues/32526 for more details
428
+ if "vocab_size" not in parsed_parameters["config"]:
429
+ tokenizer_parameters = parsed_parameters["tokenizer"]
430
+ if "tokens" in tokenizer_parameters:
431
+ parsed_parameters["config"]["vocab_size"] = len(tokenizer_parameters["tokens"])
432
+ else:
433
+ logger.warning(
434
+ "Can't find a way to retrieve missing config vocab_size from tokenizer parameters. "
435
+ "This will use default value from model config class and cause unexpected behavior."
436
+ )
437
+
438
+ if return_tensors:
439
+ parsed_parameters["tensors"] = {}
440
+
441
+ tensor_key_mapping = get_gguf_hf_weights_map(model_to_load)
442
+ config = parsed_parameters.get("config", {})
443
+
444
+ ProcessorClass = TENSOR_PROCESSORS.get(architecture, TensorProcessor)
445
+ processor = ProcessorClass(config=config)
446
+
447
+ for tensor in tqdm(reader.tensors, desc="Converting and de-quantizing GGUF tensors..."):
448
+ name = tensor.name
449
+ weights = dequantize(tensor.data, tensor.tensor_type)
450
+
451
+ result = processor.process(
452
+ weights=weights,
453
+ name=name,
454
+ tensor_key_mapping=tensor_key_mapping,
455
+ parsed_parameters=parsed_parameters,
456
+ )
457
+
458
+ weights = result.weights
459
+ name = result.name
460
+
461
+ if name not in tensor_key_mapping:
462
+ continue
463
+
464
+ name = tensor_key_mapping[name]
465
+
466
+ parsed_parameters["tensors"][name] = torch.from_numpy(np.copy(weights))
467
+
468
+ if len(reader_keys) > 0:
469
+ logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")
470
+
471
+ return parsed_parameters
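`load_gguf_checkpoint` above is the helper that backs the `gguf_file` argument of `from_pretrained`. A hedged sketch of the intended end-to-end flow; the repository and file names are examples, not values taken from this commit:

    # Illustrative sketch only; any GGUF checkpoint of a supported architecture works the same way.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"   # example repository
    gguf_file = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"   # example quantized file

    # Internally this calls load_gguf_checkpoint(..., return_tensors=True) and de-quantizes
    # every tensor to a torch tensor before loading it into the model.
    tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
    model = AutoModelForCausalLM.from_pretrained(repo_id, gguf_file=gguf_file)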
modeling_outputs.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_rope_utils.py ADDED
@@ -0,0 +1,568 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Optional, Tuple
17
+
18
+ from .configuration_utils import PretrainedConfig
19
+ from .utils import is_torch_available, logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ if is_torch_available():
26
+ import torch
27
+
28
+
29
+ def _compute_default_rope_parameters(
30
+ config: Optional[PretrainedConfig] = None,
31
+ device: Optional["torch.device"] = None,
32
+ seq_len: Optional[int] = None,
33
+ **rope_kwargs,
34
+ ) -> Tuple["torch.Tensor", float]:
35
+ """
36
+ Computes the inverse frequencies according to the original RoPE implementation
37
+ Args:
38
+ config ([`~transformers.PretrainedConfig`]):
39
+ The model configuration.
40
+ device (`torch.device`):
41
+ The device to use for initialization of the inverse frequencies.
42
+ seq_len (`int`, *optional*):
43
+ The current sequence length. Unused for this type of RoPE.
44
+ rope_kwargs (`Dict`, *optional*):
45
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
46
+ Returns:
47
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
48
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
49
+ """
50
+ if config is not None and len(rope_kwargs) > 0:
51
+ raise ValueError(
52
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
53
+ f"`_compute_default_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
54
+ )
55
+ if len(rope_kwargs) > 0:
56
+ base = rope_kwargs["base"]
57
+ dim = rope_kwargs["dim"]
58
+ elif config is not None:
59
+ base = config.rope_theta
60
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
61
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
62
+ dim = int(head_dim * partial_rotary_factor)
63
+
64
+ attention_factor = 1.0 # Unused in this type of RoPE
65
+
66
+ # Compute the inverse frequencies
67
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
68
+ return inv_freq, attention_factor
69
+
70
+
71
+ def _compute_linear_scaling_rope_parameters(
72
+ config: Optional[PretrainedConfig] = None,
73
+ device: Optional["torch.device"] = None,
74
+ seq_len: Optional[int] = None,
75
+ **rope_kwargs,
76
+ ) -> Tuple["torch.Tensor", float]:
77
+ """
78
+ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev
79
+ Args:
80
+ config ([`~transformers.PretrainedConfig`]):
81
+ The model configuration.
82
+ device (`torch.device`):
83
+ The device to use for initialization of the inverse frequencies.
84
+ seq_len (`int`, *optional*):
85
+ The current sequence length. Unused for this type of RoPE.
86
+ rope_kwargs (`Dict`, *optional*):
87
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
88
+ Returns:
89
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
90
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
91
+ """
92
+ if config is not None and len(rope_kwargs) > 0:
93
+ raise ValueError(
94
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
95
+ f"`_compute_linear_scaling_rope_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
96
+ )
97
+ if len(rope_kwargs) > 0:
98
+ factor = rope_kwargs["factor"]
99
+ elif config is not None:
100
+ factor = config.rope_scaling["factor"]
101
+
102
+ # Gets the default RoPE parameters
103
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
104
+
105
+ # Then applies linear scaling to the frequencies.
106
+ # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so
107
+ # applying scaling to the inverse frequencies is equivalent.
108
+ inv_freq /= factor
109
+ return inv_freq, attention_factor
110
+
111
+
112
+ def _compute_dynamic_ntk_parameters(
113
+ config: Optional[PretrainedConfig] = None,
114
+ device: Optional["torch.device"] = None,
115
+ seq_len: Optional[int] = None,
116
+ **rope_kwargs,
117
+ ) -> Tuple["torch.Tensor", float]:
118
+ """
119
+ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla
120
+ Args:
121
+ config ([`~transformers.PretrainedConfig`]):
122
+ The model configuration.
123
+ device (`torch.device`):
124
+ The device to use for initialization of the inverse frequencies.
125
+ seq_len (`int`, *optional*):
126
+ The current sequence length, used to update the dynamic RoPE at inference time.
127
+ rope_kwargs (`Dict`, *optional*):
128
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
129
+ Returns:
130
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
131
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
132
+ """
133
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
134
+ if config is not None and len(rope_kwargs) > 0:
135
+ raise ValueError(
136
+ "Unexpected arguments: `**rope_kwargs` and `config` are mutually exclusive in "
137
+ f"`_compute_dynamic_ntk_parameters`, got `rope_kwargs`={rope_kwargs} and `config`={config}"
138
+ )
139
+ if len(rope_kwargs) > 0:
140
+ base = rope_kwargs["base"]
141
+ dim = rope_kwargs["dim"]
142
+ max_position_embeddings = rope_kwargs["max_position_embeddings"]
143
+ factor = rope_kwargs["factor"]
144
+ elif config is not None:
145
+ base = config.rope_theta
146
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
147
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
148
+ dim = int(head_dim * partial_rotary_factor)
149
+ max_position_embeddings = config.max_position_embeddings
150
+ factor = config.rope_scaling["factor"]
151
+
152
+ attention_factor = 1.0 # Unused in this type of RoPE
153
+
154
+ # seq_len: default to max_position_embeddings, e.g. at init time
155
+ seq_len = seq_len if seq_len is not None and seq_len > max_position_embeddings else max_position_embeddings
156
+
157
+ # Compute the inverse frequencies
158
+ base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))
159
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
160
+ return inv_freq, attention_factor
161
+
162
+
163
+ def _compute_yarn_parameters(
164
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
165
+ ) -> Tuple["torch.Tensor", float]:
166
+ """
167
+ Computes the inverse frequencies with NTK scaling. Please refer to the
168
+ [original paper](https://arxiv.org/abs/2309.00071)
169
+ Args:
170
+ config ([`~transformers.PretrainedConfig`]):
171
+ The model configuration.
172
+ device (`torch.device`):
173
+ The device to use for initialization of the inverse frequencies.
174
+ seq_len (`int`, *optional*):
175
+ The current sequence length. Unused for this type of RoPE.
176
+ rope_kwargs (`Dict`, *optional*):
177
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
178
+ Returns:
179
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
180
+ post-processing scaling factor applied to the computed cos/sin.
181
+ """
182
+ # No need to keep BC with yarn, unreleased when this new pattern was created.
183
+ if len(rope_kwargs) > 0:
184
+ raise ValueError(
185
+ f"Unexpected arguments: `**rope_kwargs` should be unset in `_compute_yarn_parameters`, got {rope_kwargs}"
186
+ )
187
+
188
+ base = config.rope_theta
189
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
190
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
191
+ dim = int(head_dim * partial_rotary_factor)
192
+ max_position_embeddings = config.max_position_embeddings
193
+ factor = config.rope_scaling["factor"]
194
+
195
+ # Sets the attention factor as suggested in the paper
196
+ attention_factor = config.rope_scaling.get("attention_factor")
197
+ if attention_factor is None:
198
+ attention_factor = 0.1 * math.log(factor) + 1.0
199
+
200
+ # Optional config options
201
+ # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly)
202
+ beta_fast = config.rope_scaling.get("beta_fast") or 32
203
+ beta_slow = config.rope_scaling.get("beta_slow") or 1
204
+
205
+ # Compute the inverse frequencies
206
+ def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
207
+ """Inverse dimension formula to find the dimension based on the number of rotations"""
208
+ return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
209
+
210
+ def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
211
+ """Find dimension range bounds based on rotations"""
212
+ low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
213
+ high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
214
+ return max(low, 0), min(high, dim - 1)
215
+
216
+ def linear_ramp_factor(min, max, dim):
217
+ if min == max:
218
+ max += 0.001 # Prevent singularity
219
+
220
+ linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
221
+ ramp_func = torch.clamp(linear_func, 0, 1)
222
+ return ramp_func
223
+
224
+ # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs
225
+ # to expand the possible context length. In other words, interpolation = apply scaling factor.
226
+ pos_freqs = base ** (torch.arange(0, dim, 2).float().to(device) / dim)
227
+ inv_freq_extrapolation = 1.0 / pos_freqs
228
+ inv_freq_interpolation = 1.0 / (factor * pos_freqs)
229
+
230
+ low, high = find_correction_range(beta_fast, beta_slow, dim, base, max_position_embeddings)
231
+
232
+ # Get n-dimensional rotational scaling corrected for extrapolation
233
+ inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).float().to(device)
234
+ inv_freq = (
235
+ inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
236
+ + inv_freq_extrapolation * inv_freq_extrapolation_factor
237
+ )
238
+
239
+ return inv_freq, attention_factor
240
+
241
+
242
+ def _compute_longrope_parameters(
243
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
244
+ ) -> Tuple["torch.Tensor", float]:
245
+ """
246
+ Computes the inverse frequencies with LongRoPE scaling. Please refer to the
247
+ [original implementation](https://github.com/microsoft/LongRoPE)
248
+ Args:
249
+ config ([`~transformers.PretrainedConfig`]):
250
+ The model configuration.
251
+ device (`torch.device`):
252
+ The device to use for initialization of the inverse frequencies.
253
+ seq_len (`int`, *optional*):
254
+ The current sequence length.
255
+ rope_kwargs (`Dict`, *optional*):
256
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
257
+ Returns:
258
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
259
+ post-processing scaling factor applied to the computed cos/sin.
260
+ """
261
+ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling
262
+ # No need to keep BC with longrope, unreleased when this new pattern was created.
263
+ if len(rope_kwargs) > 0:
264
+ raise ValueError(
265
+ "Unexpected arguments: `**rope_kwargs` should be unset in `_compute_longrope_parameters`, got "
266
+ f"{rope_kwargs}"
267
+ )
268
+
269
+ base = config.rope_theta
270
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
271
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
272
+ dim = int(head_dim * partial_rotary_factor)
273
+ long_factor = config.rope_scaling["long_factor"]
274
+ short_factor = config.rope_scaling["short_factor"]
275
+ factor = config.rope_scaling.get("factor")
276
+ attention_factor = config.rope_scaling.get("attention_factor")
277
+
278
+ # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a
279
+ # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two
280
+ # values to compute the default attention scaling factor, instead of using `factor`.
281
+ if hasattr(config, "original_max_position_embeddings"):
282
+ original_max_position_embeddings = config.original_max_position_embeddings
283
+ factor = config.max_position_embeddings / config.original_max_position_embeddings
284
+ else:
285
+ original_max_position_embeddings = config.max_position_embeddings
286
+
287
+ # Sets the attention factor as suggested in the paper
288
+ if attention_factor is None:
289
+ if factor <= 1.0:
290
+ attention_factor = 1.0
291
+ else:
292
+ attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings))
293
+
294
+ # Compute the inverse frequencies -- scaled based on the target sequence length
295
+ if seq_len and seq_len > original_max_position_embeddings:
296
+ ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device)
297
+ else:
298
+ ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device)
299
+ inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim
300
+ inv_freq = 1.0 / (ext_factors * base**inv_freq_shape)
301
+
302
+ return inv_freq, attention_factor
303
+
304
+
305
+ def _compute_llama3_parameters(
306
+ config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None, **rope_kwargs
307
+ ) -> Tuple["torch.Tensor", float]:
308
+ """
309
+ Computes the inverse frequencies for llama 3.1.
310
+
311
+ Args:
312
+ config ([`~transformers.PretrainedConfig`]):
313
+ The model configuration.
314
+ device (`torch.device`):
315
+ The device to use for initialization of the inverse frequencies.
316
+ seq_len (`int`, *optional*):
317
+ The current sequence length. Unused for this type of RoPE.
318
+ rope_kwargs (`Dict`, *optional*):
319
+ BC compatibility with the previous RoPE class instantiation, will be removed in v4.45.
320
+ Returns:
321
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
322
+ post-processing scaling factor applied to the computed cos/sin.
323
+ """
324
+ # Gets the default RoPE parameters
325
+ inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
326
+
327
+ factor = config.rope_scaling["factor"] # `8` in the original implementation
328
+ low_freq_factor = config.rope_scaling["low_freq_factor"] # `1` in the original implementation
329
+ high_freq_factor = config.rope_scaling["high_freq_factor"] # `4` in the original implementation
330
+ old_context_len = config.rope_scaling["original_max_position_embeddings"] # `8192` in the original implementation
331
+
332
+ low_freq_wavelen = old_context_len / low_freq_factor
333
+ high_freq_wavelen = old_context_len / high_freq_factor
334
+
335
+ wavelen = 2 * math.pi / inv_freq
336
+ # wavelen < high_freq_wavelen: do nothing
337
+ # wavelen > low_freq_wavelen: divide by factor
338
+ inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
339
+ # otherwise: interpolate between the two, using a smooth factor
340
+ smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
341
+ smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
342
+ is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
343
+ inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
344
+
345
+ return inv_freq_llama, attention_factor
346
+
347
+
348
+ # This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
349
+ # from the model config. You can append new {'rope_type': callable} pairs to this dictionary to enable custom RoPE
350
+ # parameterizations, as long as the callable has the same signature.
351
+ ROPE_INIT_FUNCTIONS = {
352
+ "default": _compute_default_rope_parameters,
353
+ "linear": _compute_linear_scaling_rope_parameters,
354
+ "dynamic": _compute_dynamic_ntk_parameters,
355
+ "yarn": _compute_yarn_parameters,
356
+ "longrope": _compute_longrope_parameters,
357
+ "llama3": _compute_llama3_parameters,
358
+ }
359
+
360
+
361
+ def _check_received_keys(
362
+ rope_type: str,
363
+ received_keys: set,
364
+ required_keys: set,
365
+ optional_keys: Optional[set] = None,
366
+ ignore_keys: Optional[set] = None,
367
+ ):
368
+ """Compare the received keys in `config.rope_scaling` against the expected and optional keys"""
369
+ # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present
370
+ if "type" in received_keys:
371
+ received_keys -= {"type"}
372
+ required_keys.add("rope_type")
373
+
374
+ # Some models need to store model-specific keys, and we don't want to throw warning at them
375
+ if ignore_keys is not None:
376
+ received_keys -= ignore_keys
377
+
378
+ missing_keys = required_keys - received_keys
379
+ if missing_keys:
380
+ raise KeyError(f"Missing required keys in `rope_scaling` for 'rope_type'='{rope_type}': {missing_keys}")
381
+
382
+ if optional_keys is not None:
383
+ unused_keys = received_keys - required_keys - optional_keys
384
+ else:
385
+ unused_keys = received_keys - required_keys
386
+ if unused_keys:
387
+ logger.warning(f"Unrecognized keys in `rope_scaling` for 'rope_type'='{rope_type}': {unused_keys}")
388
+
389
+
390
+ def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
391
+ rope_scaling = config.rope_scaling
392
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
393
+ required_keys = {"rope_type"}
394
+ received_keys = set(rope_scaling.keys())
395
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
396
+
397
+
398
+ def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
399
+ rope_scaling = config.rope_scaling
400
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
401
+ required_keys = {"rope_type", "factor"}
402
+ received_keys = set(rope_scaling.keys())
403
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
404
+
405
+ factor = rope_scaling["factor"]
406
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
407
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
408
+
409
+
410
+ def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
411
+ rope_scaling = config.rope_scaling
412
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
413
+ required_keys = {"rope_type", "factor"}
414
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
415
+ optional_keys = {"original_max_position_embeddings"}
416
+ received_keys = set(rope_scaling.keys())
417
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
418
+
419
+ factor = rope_scaling["factor"]
420
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
421
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
422
+
423
+
424
+ def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
425
+ rope_scaling = config.rope_scaling
426
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
427
+ required_keys = {"rope_type", "factor"}
428
+ optional_keys = {"attention_factor", "beta_fast", "beta_slow"}
429
+ received_keys = set(rope_scaling.keys())
430
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
431
+
432
+ factor = rope_scaling["factor"]
433
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
434
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
435
+
436
+ attention_factor = rope_scaling.get("attention_factor")
437
+ if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0):
438
+ logger.warning(
439
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
440
+ )
441
+ beta_fast = rope_scaling.get("beta_fast")
442
+ if beta_fast is not None and not isinstance(beta_fast, float):
443
+ logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}")
444
+ beta_slow = rope_scaling.get("beta_slow")
445
+ if beta_slow is not None and not isinstance(beta_slow, float):
446
+ logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}")
447
+
448
+ if (beta_fast or 32) < (beta_slow or 1):
449
+ logger.warning(
450
+ f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} "
451
+ f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)"
452
+ )
453
+
454
+
455
+ def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
456
+ rope_scaling = config.rope_scaling
457
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
458
+ required_keys = {"rope_type", "short_factor", "long_factor"}
459
+ # TODO (joao): update logic for the inclusion of `original_max_position_embeddings`
460
+ optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"}
461
+ received_keys = set(rope_scaling.keys())
462
+ _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
463
+
464
+ partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
465
+ head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
466
+ dim = int(head_dim * partial_rotary_factor)
467
+
468
+ short_factor = rope_scaling.get("short_factor")
469
+ if not (isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor)):
470
+ logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}")
471
+ if not len(short_factor) == dim // 2:
472
+ logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}")
473
+
474
+ long_factor = rope_scaling.get("long_factor")
475
+ if not (isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor)):
476
+ logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}")
477
+ if not len(long_factor) == dim // 2:
478
+ logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}")
479
+
480
+ # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over
481
+ # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is
482
+ # unique to longrope (= undesirable)
483
+ if hasattr(config, "original_max_position_embeddings"):
484
+ logger.warning_once(
485
+ "This model has set a `original_max_position_embeddings` field, to be used together with "
486
+ "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`"
487
+ "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, "
488
+ "as it is compatible with most model architectures."
489
+ )
490
+ else:
491
+ factor = rope_scaling.get("factor")
492
+ if factor is None:
493
+ logger.warning("Missing required keys in `rope_scaling`: 'factor'")
494
+ elif not isinstance(factor, float) or factor < 1.0:
495
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
496
+
497
+ attention_factor = rope_scaling.get("attention_factor")
498
+ if attention_factor is not None:
499
+ if not isinstance(attention_factor, float) or attention_factor < 0.0:
500
+ logger.warning(
501
+ f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}"
502
+ )
503
+
504
+
505
+ def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None):
506
+ rope_scaling = config.rope_scaling
507
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type"
508
+ required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"}
509
+ received_keys = set(rope_scaling.keys())
510
+ _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
511
+
512
+ factor = rope_scaling["factor"]
513
+ if factor is None or not isinstance(factor, float) or factor < 1.0:
514
+ logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}")
515
+
516
+ low_freq_factor = rope_scaling["low_freq_factor"]
517
+ high_freq_factor = rope_scaling["high_freq_factor"]
518
+ if low_freq_factor is None or not isinstance(low_freq_factor, float):
519
+ logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}")
520
+ if high_freq_factor is None or not isinstance(high_freq_factor, float):
521
+ logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}")
522
+ if high_freq_factor <= low_freq_factor:
523
+ logger.warning(
524
+ "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor="
525
+ f"{high_freq_factor} and low_freq_factor={low_freq_factor}"
526
+ )
527
+
528
+ original_max_position_embeddings = rope_scaling["original_max_position_embeddings"]
529
+ if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int):
530
+ logger.warning(
531
+ "`rope_scaling`'s original_max_position_embeddings field must be an integer, got "
532
+ f"{original_max_position_embeddings}"
533
+ )
534
+ if original_max_position_embeddings >= config.max_position_embeddings:
535
+ logger.warning(
536
+ "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got "
537
+ f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}"
538
+ )
539
+
540
+
541
+ # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types.
542
+ ROPE_VALIDATION_FUNCTIONS = {
543
+ "default": _validate_default_rope_parameters,
544
+ "linear": _validate_linear_scaling_rope_parameters,
545
+ "dynamic": _validate_dynamic_scaling_rope_parameters,
546
+ "yarn": _validate_yarn_parameters,
547
+ "longrope": _validate_longrope_parameters,
548
+ "llama3": _validate_llama3_parameters,
549
+ }
550
+
551
+
552
+ def rope_config_validation(config: PretrainedConfig, ignore_keys: Optional[set] = None):
553
+ """
554
+ Validate the RoPE config arguments, given a `PretrainedConfig` object
555
+ """
556
+ rope_scaling = getattr(config, "rope_scaling", None) # not a default parameter in `PretrainedConfig`
557
+ if rope_scaling is None:
558
+ return
559
+
560
+ # BC: "rope_type" was originally "type"
561
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
562
+ validation_fn = ROPE_VALIDATION_FUNCTIONS.get(rope_type)
563
+ if validation_fn is not None:
564
+ validation_fn(config, ignore_keys=ignore_keys)
565
+ else:
566
+ logger.warning(
567
+ f"Missing validation function mapping in `ROPE_VALIDATION_FUNCTIONS` for 'rope_type'='{rope_type}'"
568
+ )
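`ROPE_INIT_FUNCTIONS` and `ROPE_VALIDATION_FUNCTIONS` above are meant to be extended with custom `rope_type` entries. A minimal sketch, assuming a hypothetical `"my_scaled"` type whose `rope_scaling` config carries a `factor` field; nothing below is part of this commit:

    # Illustrative sketch only; "my_scaled" and its behaviour are made up for this example.
    from transformers.modeling_rope_utils import (
        ROPE_INIT_FUNCTIONS,
        ROPE_VALIDATION_FUNCTIONS,
        _compute_default_rope_parameters,
    )

    def _compute_my_scaled_parameters(config=None, device=None, seq_len=None, **rope_kwargs):
        # Start from the default inverse frequencies, then apply one extra uniform scaling.
        inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, **rope_kwargs)
        return inv_freq / config.rope_scaling["factor"], attention_factor

    def _validate_my_scaled_parameters(config, ignore_keys=None):
        if "factor" not in config.rope_scaling:
            raise KeyError("`rope_scaling` must define `factor` for rope_type='my_scaled'")

    ROPE_INIT_FUNCTIONS["my_scaled"] = _compute_my_scaled_parameters
    ROPE_VALIDATION_FUNCTIONS["my_scaled"] = _validate_my_scaled_parameters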
modeling_tf_outputs.py ADDED
@@ -0,0 +1,991 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import warnings
18
+ from dataclasses import dataclass
19
+ from typing import List, Optional, Tuple
20
+
21
+ import tensorflow as tf
22
+
23
+ from .utils import ModelOutput
24
+
25
+
26
+ @dataclass
27
+ class TFBaseModelOutput(ModelOutput):
28
+ """
29
+ Base class for model's outputs, with potential hidden states and attentions.
30
+
31
+ Args:
32
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
33
+ Sequence of hidden-states at the output of the last layer of the model.
34
+ hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
35
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
36
+ `(batch_size, sequence_length, hidden_size)`.
37
+
38
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
39
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
40
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
41
+ sequence_length)`.
42
+
43
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
44
+ heads.
45
+ """
46
+
47
+ last_hidden_state: tf.Tensor = None
48
+ hidden_states: Tuple[tf.Tensor] | None = None
49
+ attentions: Tuple[tf.Tensor] | None = None
50
+
51
+
52
+ @dataclass
53
+ class TFBaseModelOutputWithNoAttention(ModelOutput):
54
+ """
55
+ Base class for model's outputs, with potential hidden states.
56
+
57
+ Args:
58
+ last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
59
+ Sequence of hidden-states at the output of the last layer of the model.
60
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
61
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
62
+ the output of each layer) of shape `(batch_size, num_channels, height, width)`.
63
+
64
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
65
+ """
66
+
67
+ last_hidden_state: tf.Tensor = None
68
+ hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
69
+
70
+
71
+ @dataclass
72
+ class TFBaseModelOutputWithPooling(ModelOutput):
73
+ """
74
+ Base class for model's outputs that also contains a pooling of the last hidden states.
75
+
76
+ Args:
77
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
78
+ Sequence of hidden-states at the output of the last layer of the model.
79
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
80
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
81
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
82
+ prediction (classification) objective during pretraining.
83
+
84
+ This output is usually *not* a good summary of the semantic content of the input, you're often better with
85
+ averaging or pooling the sequence of hidden-states for the whole input sequence.
86
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
87
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
88
+ `(batch_size, sequence_length, hidden_size)`.
89
+
90
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
91
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
92
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
93
+ sequence_length)`.
94
+
95
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
96
+ heads.
97
+ """
98
+
99
+ last_hidden_state: tf.Tensor = None
100
+ pooler_output: tf.Tensor = None
101
+ hidden_states: Tuple[tf.Tensor] | None = None
102
+ attentions: Tuple[tf.Tensor] | None = None
103
+
104
+
105
+ @dataclass
106
+ class TFBaseModelOutputWithPoolingAndNoAttention(ModelOutput):
107
+ """
108
+ Base class for model's outputs that also contains a pooling of the last hidden states.
109
+
110
+ Args:
111
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
112
+ Sequence of hidden-states at the output of the last layer of the model.
113
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
114
+ Last layer hidden-state after a pooling operation on the spatial dimensions.
115
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
116
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
117
+ the output of each layer) of shape `(batch_size, num_channels, height, width)`.
118
+
119
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
120
+ """
121
+
122
+ last_hidden_state: tf.Tensor = None
123
+ pooler_output: tf.Tensor = None
124
+ hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
125
+
126
+
127
+ @dataclass
128
+ class TFBaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
129
+ """
130
+ Base class for model's outputs that also contains a pooling of the last hidden states.
131
+
132
+ Args:
133
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
134
+ Sequence of hidden-states at the output of the last layer of the model.
135
+ pooler_output (`tf.Tensor` of shape `(batch_size, hidden_size)`):
136
+ Last layer hidden-state of the first token of the sequence (classification token) further processed by a
137
+ Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
138
+ prediction (classification) objective during pretraining.
139
+
140
+ This output is usually *not* a good summary of the semantic content of the input, you're often better with
141
+ averaging or pooling the sequence of hidden-states for the whole input sequence.
142
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
143
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
144
+ sequence_length, embed_size_per_head)`).
145
+
146
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
147
+ `past_key_values` input) to speed up sequential decoding.
148
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
149
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
150
+ `(batch_size, sequence_length, hidden_size)`.
151
+
152
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
153
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
154
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
155
+ sequence_length)`.
156
+
157
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
158
+ heads.
159
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
160
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
161
+ sequence_length)`.
162
+
163
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
164
+ weighted average in the cross-attention heads.
165
+ """
166
+
167
+ last_hidden_state: tf.Tensor = None
168
+ pooler_output: tf.Tensor = None
169
+ past_key_values: List[tf.Tensor] | None = None
170
+ hidden_states: Tuple[tf.Tensor] | None = None
171
+ attentions: Tuple[tf.Tensor] | None = None
172
+ cross_attentions: Tuple[tf.Tensor] | None = None
173
+
174
+
175
+ @dataclass
176
+ class TFBaseModelOutputWithPast(ModelOutput):
177
+ """
178
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
179
+
180
+ Args:
181
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
182
+ Sequence of hidden-states at the output of the last layer of the model.
183
+
184
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
185
+ hidden_size)` is output.
186
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
187
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
188
+ sequence_length, embed_size_per_head)`).
189
+
190
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
191
+ `past_key_values` input) to speed up sequential decoding.
192
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
193
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
194
+ `(batch_size, sequence_length, hidden_size)`.
195
+
196
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
197
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
198
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
199
+ sequence_length)`.
200
+
201
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
202
+ heads.
203
+ """
204
+
205
+ last_hidden_state: tf.Tensor = None
206
+ past_key_values: List[tf.Tensor] | None = None
207
+ hidden_states: Tuple[tf.Tensor] | None = None
208
+ attentions: Tuple[tf.Tensor] | None = None
209
+
210
+
211
+ @dataclass
212
+ class TFBaseModelOutputWithCrossAttentions(ModelOutput):
213
+ """
214
+ Base class for model's outputs, with potential hidden states and attentions.
215
+
216
+ Args:
217
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
218
+ Sequence of hidden-states at the output of the last layer of the model.
219
+ hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
220
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
221
+ `(batch_size, sequence_length, hidden_size)`.
222
+
223
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
224
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
225
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
226
+ sequence_length)`.
227
+
228
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
229
+ heads.
230
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
231
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
232
+ sequence_length)`.
233
+
234
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
235
+ weighted average in the cross-attention heads.
236
+ """
237
+
238
+ last_hidden_state: tf.Tensor = None
239
+ hidden_states: Tuple[tf.Tensor] | None = None
240
+ attentions: Tuple[tf.Tensor] | None = None
241
+ cross_attentions: Tuple[tf.Tensor] | None = None
242
+
243
+
244
+ @dataclass
245
+ class TFBaseModelOutputWithPastAndCrossAttentions(ModelOutput):
246
+ """
247
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
248
+
249
+ Args:
250
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
251
+ Sequence of hidden-states at the output of the last layer of the model.
252
+
253
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
254
+ hidden_size)` is output.
255
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
256
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
257
+ sequence_length, embed_size_per_head)`).
258
+
259
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
260
+ `past_key_values` input) to speed up sequential decoding.
261
+ hidden_states (`tuple(tf.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
262
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
263
+ `(batch_size, sequence_length, hidden_size)`.
264
+
265
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
266
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
267
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
268
+ sequence_length)`.
269
+
270
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
271
+ heads.
272
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
273
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
274
+ sequence_length)`.
275
+
276
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
277
+ weighted average in the cross-attention heads.
278
+ """
279
+
280
+ last_hidden_state: tf.Tensor = None
281
+ past_key_values: List[tf.Tensor] | None = None
282
+ hidden_states: Tuple[tf.Tensor] | None = None
283
+ attentions: Tuple[tf.Tensor] | None = None
284
+ cross_attentions: Tuple[tf.Tensor] | None = None
285
+
286
+
287
+ @dataclass
288
+ class TFSeq2SeqModelOutput(ModelOutput):
289
+ """
290
+ Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
291
+ decoding.
292
+
293
+ Args:
294
+ last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
295
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
296
+
297
+ If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
298
+ hidden_size)` is output.
299
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
300
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
301
+ sequence_length, embed_size_per_head)`).
302
+
303
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
304
+ used (see `past_key_values` input) to speed up sequential decoding.
305
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
306
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
307
+ `(batch_size, sequence_length, hidden_size)`.
308
+
309
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
310
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
311
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
312
+ sequence_length)`.
313
+
314
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
315
+ self-attention heads.
316
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
317
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
318
+ sequence_length)`.
319
+
320
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
321
+ weighted average in the cross-attention heads.
322
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
323
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
324
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
325
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
326
+ `(batch_size, sequence_length, hidden_size)`.
327
+
328
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
329
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
330
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
331
+ sequence_length)`.
332
+
333
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
334
+ self-attention heads.
335
+ """
336
+
337
+ last_hidden_state: tf.Tensor = None
338
+ past_key_values: List[tf.Tensor] | None = None
339
+ decoder_hidden_states: Tuple[tf.Tensor] | None = None
340
+ decoder_attentions: Tuple[tf.Tensor] | None = None
341
+ cross_attentions: Tuple[tf.Tensor] | None = None
342
+ encoder_last_hidden_state: tf.Tensor | None = None
343
+ encoder_hidden_states: Tuple[tf.Tensor] | None = None
344
+ encoder_attentions: Tuple[tf.Tensor] | None = None
345
+
346
+
347
+ @dataclass
348
+ class TFCausalLMOutput(ModelOutput):
349
+ """
350
+ Base class for causal language model (or autoregressive) outputs.
351
+
352
+ Args:
353
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
354
+ Language modeling loss (for next-token prediction).
355
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
356
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
357
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
358
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
359
+ `(batch_size, sequence_length, hidden_size)`.
360
+
361
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
362
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
363
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
364
+ sequence_length)`.
365
+
366
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
367
+ heads.
368
+ """
369
+
370
+ loss: tf.Tensor | None = None
371
+ logits: tf.Tensor = None
372
+ hidden_states: Tuple[tf.Tensor] | None = None
373
+ attentions: Tuple[tf.Tensor] | None = None
374
+
375
+
376
+ @dataclass
377
+ class TFCausalLMOutputWithPast(ModelOutput):
378
+ """
379
+ Base class for causal language model (or autoregressive) outputs.
380
+
381
+ Args:
382
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
383
+ Language modeling loss (for next-token prediction).
384
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
385
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
386
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
387
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
388
+ sequence_length, embed_size_per_head)`).
389
+
390
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
391
+ `past_key_values` input) to speed up sequential decoding.
392
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
393
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
394
+ `(batch_size, sequence_length, hidden_size)`.
395
+
396
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
397
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
398
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
399
+ sequence_length)`.
400
+
401
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
402
+ heads.
403
+ """
404
+
405
+ loss: tf.Tensor | None = None
406
+ logits: tf.Tensor = None
407
+ past_key_values: List[tf.Tensor] | None = None
408
+ hidden_states: Tuple[tf.Tensor] | None = None
409
+ attentions: Tuple[tf.Tensor] | None = None
410
+
411
+
412
+ @dataclass
413
+ class TFCausalLMOutputWithCrossAttentions(ModelOutput):
414
+ """
415
+ Base class for causal language model (or autoregressive) outputs.
416
+
417
+ Args:
418
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
419
+ Language modeling loss (for next-token prediction).
420
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
421
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
422
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
423
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
424
+ `(batch_size, sequence_length, hidden_size)`.
425
+
426
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
427
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
428
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
429
+ sequence_length)`.
430
+
431
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
432
+ heads.
433
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
434
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
435
+ sequence_length)`.
436
+
437
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
438
+ weighted average in the cross-attention heads.
439
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
440
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
441
+ sequence_length, embed_size_per_head)`).
442
+
443
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
444
+ `past_key_values` input) to speed up sequential decoding.
445
+ """
446
+
447
+ loss: tf.Tensor | None = None
448
+ logits: tf.Tensor = None
449
+ past_key_values: List[tf.Tensor] | None = None
450
+ hidden_states: Tuple[tf.Tensor] | None = None
451
+ attentions: Tuple[tf.Tensor] | None = None
452
+ cross_attentions: Tuple[tf.Tensor] | None = None
453
+
454
+
455
+ @dataclass
456
+ class TFMaskedLMOutput(ModelOutput):
457
+ """
458
+ Base class for masked language models outputs.
459
+
460
+ Args:
461
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
462
+ Masked language modeling (MLM) loss.
463
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
464
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
465
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
466
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
467
+ `(batch_size, sequence_length, hidden_size)`.
468
+
469
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
470
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
471
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
472
+ sequence_length)`.
473
+
474
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
475
+ heads.
476
+ """
477
+
478
+ loss: tf.Tensor | None = None
479
+ logits: tf.Tensor = None
480
+ hidden_states: Tuple[tf.Tensor] | None = None
481
+ attentions: Tuple[tf.Tensor] | None = None
482
+
483
+
484
+ @dataclass
485
+ class TFSeq2SeqLMOutput(ModelOutput):
486
+ """
487
+ Base class for sequence-to-sequence language models outputs.
488
+
489
+ Args:
490
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `labels` is provided):
491
+ Language modeling loss.
492
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
493
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
494
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
495
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
496
+ sequence_length, embed_size_per_head)`).
497
+
498
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
499
+ used (see `past_key_values` input) to speed up sequential decoding.
500
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
501
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
502
+ `(batch_size, sequence_length, hidden_size)`.
503
+
504
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
505
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
506
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
507
+ sequence_length)`.
508
+
509
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
510
+ self-attention heads.
511
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
512
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
513
+ sequence_length)`.
514
+
515
+ Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
516
+ weighted average in the cross-attention heads.
517
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
518
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
519
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
520
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
521
+ `(batch_size, sequence_length, hidden_size)`.
522
+
523
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
524
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
525
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
526
+ sequence_length)`.
527
+
528
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
529
+ self-attention heads.
530
+ """
531
+
532
+ loss: tf.Tensor | None = None
533
+ logits: tf.Tensor = None
534
+ past_key_values: List[tf.Tensor] | None = None
535
+ decoder_hidden_states: Tuple[tf.Tensor] | None = None
536
+ decoder_attentions: Tuple[tf.Tensor] | None = None
537
+ cross_attentions: Tuple[tf.Tensor] | None = None
538
+ encoder_last_hidden_state: tf.Tensor | None = None
539
+ encoder_hidden_states: Tuple[tf.Tensor] | None = None
540
+ encoder_attentions: Tuple[tf.Tensor] | None = None
541
+
542
+
543
+ @dataclass
544
+ class TFNextSentencePredictorOutput(ModelOutput):
545
+ """
546
+ Base class for outputs of models predicting if two sentences are consecutive or not.
547
+
548
+ Args:
549
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of non-masked labels, returned when `next_sentence_label` is provided):
550
+ Next sentence prediction loss.
551
+ logits (`tf.Tensor` of shape `(batch_size, 2)`):
552
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
553
+ before SoftMax).
554
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
555
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
556
+ `(batch_size, sequence_length, hidden_size)`.
557
+
558
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
559
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
560
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
561
+ sequence_length)`.
562
+
563
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
564
+ heads.
565
+ """
566
+
567
+ loss: tf.Tensor | None = None
568
+ logits: tf.Tensor = None
569
+ hidden_states: Tuple[tf.Tensor] | None = None
570
+ attentions: Tuple[tf.Tensor] | None = None
571
+
572
+
573
+ @dataclass
574
+ class TFSequenceClassifierOutput(ModelOutput):
575
+ """
576
+ Base class for outputs of sentence classification models.
577
+
578
+ Args:
579
+ loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
580
+ Classification (or regression if config.num_labels==1) loss.
581
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
582
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
583
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
584
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
585
+ `(batch_size, sequence_length, hidden_size)`.
586
+
587
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
588
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
589
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
590
+ sequence_length)`.
591
+
592
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
593
+ heads.
594
+ """
595
+
596
+ loss: tf.Tensor | None = None
597
+ logits: tf.Tensor = None
598
+ hidden_states: Tuple[tf.Tensor] | None = None
599
+ attentions: Tuple[tf.Tensor] | None = None
600
+
601
+
602
+ @dataclass
603
+ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
604
+ """
605
+ Base class for outputs of sequence-to-sequence sentence classification models.
606
+
607
+ Args:
608
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `label` is provided):
609
+ Classification (or regression if config.num_labels==1) loss.
610
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
611
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
612
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
613
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
614
+ sequence_length, embed_size_per_head)`).
615
+
616
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
617
+ used (see `past_key_values` input) to speed up sequential decoding.
618
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
619
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
620
+ `(batch_size, sequence_length, hidden_size)`.
621
+
622
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
623
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
624
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
625
+ sequence_length)`.
626
+
627
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
628
+ self-attention heads.
629
+ cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
630
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
631
+ sequence_length)`
632
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
633
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
634
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
635
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
636
+ `(batch_size, sequence_length, hidden_size)`.
637
+
638
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
639
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
640
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
641
+ sequence_length)`.
642
+
643
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
644
+ self-attention heads.
645
+ """
646
+
647
+ loss: tf.Tensor | None = None
648
+ logits: tf.Tensor = None
649
+ past_key_values: List[tf.Tensor] | None = None
650
+ decoder_hidden_states: Tuple[tf.Tensor] | None = None
651
+ decoder_attentions: Tuple[tf.Tensor] | None = None
652
+ cross_attentions: Tuple[tf.Tensor] | None = None
653
+ encoder_last_hidden_state: tf.Tensor | None = None
654
+ encoder_hidden_states: Tuple[tf.Tensor] | None = None
655
+ encoder_attentions: Tuple[tf.Tensor] | None = None
656
+
657
+
658
+ @dataclass
659
+ class TFSemanticSegmenterOutput(ModelOutput):
660
+ """
661
+ Base class for outputs of semantic segmentation models.
662
+
663
+ Args:
664
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
665
+ Classification (or regression if config.num_labels==1) loss.
666
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
667
+ Classification scores for each pixel.
668
+
669
+ <Tip warning={true}>
670
+
671
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
672
+ to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
673
+ original image size as post-processing. You should always check your logits shape and resize as needed.
674
+
675
+ </Tip>
676
+
677
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
678
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
679
+ the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
680
+
681
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
682
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
683
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
684
+
685
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
686
+ heads.
687
+ """
688
+
689
+ loss: tf.Tensor | None = None
690
+ logits: tf.Tensor = None
691
+ hidden_states: Tuple[tf.Tensor] | None = None
692
+ attentions: Tuple[tf.Tensor] | None = None
693
+
694
+
695
+ @dataclass
696
+ class TFSemanticSegmenterOutputWithNoAttention(ModelOutput):
697
+ """
698
+ Base class for outputs of semantic segmentation models that do not output attention scores.
699
+
700
+ Args:
701
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
702
+ Classification (or regression if config.num_labels==1) loss.
703
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
704
+ Classification scores for each pixel.
705
+
706
+ <Tip warning={true}>
707
+
708
+ The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
709
+ to avoid doing two interpolations and lose some quality when a user needs to resize the logits to the
710
+ original image size as post-processing. You should always check your logits shape and resize as needed.
711
+
712
+ </Tip>
713
+
714
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
715
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
716
+ the output of each layer) of shape `(batch_size, patch_size, hidden_size)`.
717
+
718
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
719
+ """
720
+
721
+ loss: tf.Tensor | None = None
722
+ logits: tf.Tensor = None
723
+ hidden_states: Tuple[tf.Tensor] | None = None
724
+
725
+
726
+ @dataclass
727
+ class TFImageClassifierOutput(ModelOutput):
728
+ """
729
+ Base class for outputs of image classification models.
730
+
731
+ Args:
732
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
733
+ Classification (or regression if config.num_labels==1) loss.
734
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
735
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
736
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
737
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
738
+ the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
739
+ feature maps) of the model at the output of each stage.
740
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
741
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
742
+
743
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
744
+ heads.
745
+ """
746
+
747
+ loss: tf.Tensor | None = None
748
+ logits: tf.Tensor = None
749
+ hidden_states: Tuple[tf.Tensor] | None = None
750
+ attentions: Tuple[tf.Tensor] | None = None
751
+
752
+
753
+ @dataclass
754
+ class TFMultipleChoiceModelOutput(ModelOutput):
755
+ """
756
+ Base class for outputs of multiple choice models.
757
+
758
+ Args:
759
+ loss (`tf.Tensor` of shape *(batch_size, )*, *optional*, returned when `labels` is provided):
760
+ Classification loss.
761
+ logits (`tf.Tensor` of shape `(batch_size, num_choices)`):
762
+ *num_choices* is the second dimension of the input tensors. (see *input_ids* above).
763
+
764
+ Classification scores (before SoftMax).
765
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
766
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
767
+ `(batch_size, sequence_length, hidden_size)`.
768
+
769
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
770
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
771
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
772
+ sequence_length)`.
773
+
774
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
775
+ heads.
776
+ """
777
+
778
+ loss: tf.Tensor | None = None
779
+ logits: tf.Tensor = None
780
+ hidden_states: Tuple[tf.Tensor] | None = None
781
+ attentions: Tuple[tf.Tensor] | None = None
782
+
783
+
784
+ @dataclass
785
+ class TFTokenClassifierOutput(ModelOutput):
786
+ """
787
+ Base class for outputs of token classification models.
788
+
789
+ Args:
790
+ loss (`tf.Tensor` of shape `(n,)`, *optional*, where n is the number of unmasked labels, returned when `labels` is provided) :
791
+ Classification loss.
792
+ logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.num_labels)`):
793
+ Classification scores (before SoftMax).
794
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
795
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
796
+ `(batch_size, sequence_length, hidden_size)`.
797
+
798
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
799
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
800
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
801
+ sequence_length)`.
802
+
803
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
804
+ heads.
805
+ """
806
+
807
+ loss: tf.Tensor | None = None
808
+ logits: tf.Tensor = None
809
+ hidden_states: Tuple[tf.Tensor] | None = None
810
+ attentions: Tuple[tf.Tensor] | None = None
811
+
812
+
813
+ @dataclass
814
+ class TFQuestionAnsweringModelOutput(ModelOutput):
815
+ """
816
+ Base class for outputs of question answering models.
817
+
818
+ Args:
819
+ loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `start_positions` and `end_positions` are provided):
820
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
821
+ start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
822
+ Span-start scores (before SoftMax).
823
+ end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
824
+ Span-end scores (before SoftMax).
825
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
826
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
827
+ `(batch_size, sequence_length, hidden_size)`.
828
+
829
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
830
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
831
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
832
+ sequence_length)`.
833
+
834
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
835
+ heads.
836
+ """
837
+
838
+ loss: tf.Tensor | None = None
839
+ start_logits: tf.Tensor = None
840
+ end_logits: tf.Tensor = None
841
+ hidden_states: Tuple[tf.Tensor] | None = None
842
+ attentions: Tuple[tf.Tensor] | None = None
843
+
844
+
845
+ @dataclass
846
+ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
847
+ """
848
+ Base class for outputs of sequence-to-sequence question answering models.
849
+
850
+ Args:
851
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
852
+ Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
853
+ start_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
854
+ Span-start scores (before SoftMax).
855
+ end_logits (`tf.Tensor` of shape `(batch_size, sequence_length)`):
856
+ Span-end scores (before SoftMax).
857
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
858
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
859
+ sequence_length, embed_size_per_head)`).
860
+
861
+ Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
862
+ used (see `past_key_values` input) to speed up sequential decoding.
863
+ decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
864
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
865
+ `(batch_size, sequence_length, hidden_size)`.
866
+
867
+ Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
868
+ decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
869
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
870
+ sequence_length)`.
871
+
872
+ Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
873
+ self-attention heads.
874
+ encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
875
+ Sequence of hidden-states at the output of the last layer of the encoder of the model.
876
+ encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
877
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
878
+ `(batch_size, sequence_length, hidden_size)`.
879
+
880
+ Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
881
+ encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
882
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
883
+ sequence_length)`.
884
+
885
+ Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
886
+ self-attention heads.
887
+ """
888
+
889
+ loss: tf.Tensor | None = None
890
+ start_logits: tf.Tensor = None
891
+ end_logits: tf.Tensor = None
892
+ past_key_values: List[tf.Tensor] | None = None
893
+ decoder_hidden_states: Tuple[tf.Tensor] | None = None
894
+ decoder_attentions: Tuple[tf.Tensor] | None = None
895
+ encoder_last_hidden_state: tf.Tensor | None = None
896
+ encoder_hidden_states: Tuple[tf.Tensor] | None = None
897
+ encoder_attentions: Tuple[tf.Tensor] | None = None
898
+
899
+
900
+ @dataclass
901
+ class TFSequenceClassifierOutputWithPast(ModelOutput):
902
+ """
903
+ Base class for outputs of sentence classification models.
904
+
905
+ Args:
906
+ loss (`tf.Tensor` of shape `(batch_size, )`, *optional*, returned when `labels` is provided):
907
+ Classification (or regression if config.num_labels==1) loss.
908
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
909
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
910
+ past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
911
+ List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
912
+ sequence_length, embed_size_per_head)`).
913
+
914
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
915
+ `past_key_values` input) to speed up sequential decoding.
916
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
917
+ Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
918
+ `(batch_size, sequence_length, hidden_size)`.
919
+
920
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
921
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
922
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
923
+ sequence_length)`.
924
+
925
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
926
+ heads.
927
+ """
928
+
929
+ loss: tf.Tensor | None = None
930
+ logits: tf.Tensor = None
931
+ past_key_values: List[tf.Tensor] | None = None
932
+ hidden_states: Tuple[tf.Tensor] | None = None
933
+ attentions: Tuple[tf.Tensor] | None = None
934
+
935
+
936
+ @dataclass
937
+ class TFImageClassifierOutputWithNoAttention(ModelOutput):
938
+ """
939
+ Base class for outputs of image classification models.
940
+
941
+ Args:
942
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
943
+ Classification (or regression if config.num_labels==1) loss.
944
+ logits (`tf.Tensor` of shape `(batch_size, config.num_labels)`):
945
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
946
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
947
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
948
+ the output of each stage) of shape `(batch_size, num_channels, height, width)`. Hidden-states (also called
949
+ feature maps) of the model at the output of each stage.
950
+ """
951
+
952
+ loss: tf.Tensor | None = None
953
+ logits: tf.Tensor = None
954
+ hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
955
+
956
+
957
+ @dataclass
958
+ class TFMaskedImageModelingOutput(ModelOutput):
959
+ """
960
+ Base class for outputs of masked image completion / in-painting models.
961
+
962
+ Args:
963
+ loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
964
+ Reconstruction loss.
965
+ reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
966
+ Reconstructed / completed images.
967
+ hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
968
+ `config.output_hidden_states=True`):
969
+ Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
970
+ the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
971
+ feature maps) of the model at the output of each stage.
972
+ attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
973
+ `config.output_attentions=True`):
974
+ Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
975
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
976
+ heads.
977
+ """
978
+
979
+ loss: tf.Tensor | None = None
980
+ reconstruction: tf.Tensor = None
981
+ hidden_states: Tuple[tf.Tensor] | None = None
982
+ attentions: Tuple[tf.Tensor] | None = None
983
+
984
+ @property
985
+ def logits(self):
986
+ warnings.warn(
987
+ "logits attribute is deprecated and will be removed in version 5 of Transformers."
988
+ " Please use the reconstruction attribute to retrieve the final output instead.",
989
+ FutureWarning,
990
+ )
991
+ return self.reconstruction
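
All of the classes in this file are thin `ModelOutput` subclasses, so they behave both as dataclasses and as read-only, dict-like containers. A minimal sketch of that behaviour, assuming the file is importable as `transformers.modeling_tf_outputs` and using a toy tensor rather than a real model's output:

```python
# Sketch: TFBaseModelOutput supports attribute, key and positional access,
# and `to_tuple()` drops the optional fields that were left as None.
import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput

hidden = tf.zeros((1, 4, 8))  # (batch_size, sequence_length, hidden_size)
out = TFBaseModelOutput(last_hidden_state=hidden)

assert out.last_hidden_state is out["last_hidden_state"]  # dict-style access
assert out[0] is hidden                                    # positional access via to_tuple()
assert out.attentions is None                              # optional fields default to None
assert len(out.to_tuple()) == 1                            # None fields are omitted from the tuple
```

The same access pattern applies to the seq2seq, classifier and question-answering outputs above; only the field names differ.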
modeling_tf_pytorch_utils.py ADDED
@@ -0,0 +1,673 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch - TF 2.0 general utilities."""
17
+
18
+ import os
19
+ import re
20
+
21
+ import numpy
22
+
23
+ from .utils import (
24
+ ExplicitEnum,
25
+ expand_dims,
26
+ is_numpy_array,
27
+ is_safetensors_available,
28
+ is_torch_tensor,
29
+ logging,
30
+ reshape,
31
+ squeeze,
32
+ tensor_size,
33
+ )
34
+ from .utils import transpose as transpose_func
35
+
36
+
37
+ if is_safetensors_available():
38
+ from safetensors import safe_open
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+
44
+ class TransposeType(ExplicitEnum):
45
+ """
46
+ Possible ...
47
+ """
48
+
49
+ NO = "no"
50
+ SIMPLE = "simple"
51
+ CONV1D = "conv1d"
52
+ CONV2D = "conv2d"
53
+
54
+
55
+ def convert_tf_weight_name_to_pt_weight_name(
56
+ tf_name, start_prefix_to_remove="", tf_weight_shape=None, name_scope=None
57
+ ):
58
+ """
59
+ Convert a TF 2.0 model variable name to a PyTorch model weight name.
60
+
61
+ Conventions for TF2.0 scopes -> PyTorch attribute names conversions:
62
+
63
+ - '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
64
+ - '_._' is replaced by a new level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
65
+
66
+ return tuple with:
67
+
68
+ - pytorch model weight name
69
+ - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
70
+ transposed with regard to each other
71
+ """
72
+ if name_scope is not None:
73
+ if not tf_name.startswith(name_scope) and "final_logits_bias" not in tf_name:
74
+ raise ValueError(
75
+ f"Weight name {tf_name} does not start with name_scope {name_scope}. This is an internal error "
76
+ "in Transformers, so (unless you were doing something really evil) please open an issue to report it!"
77
+ )
78
+ tf_name = tf_name[len(name_scope) :]
79
+ tf_name = tf_name.lstrip("/")
80
+ tf_name = tf_name.replace(":0", "") # device ids
81
+ tf_name = re.sub(
82
+ r"/[^/]*___([^/]*)/", r"/\1/", tf_name
83
+ ) # '$1___$2' is replaced by $2 (can be used to duplicate or remove layers in TF2.0 vs PyTorch)
84
+ tf_name = tf_name.replace(
85
+ "_._", "/"
86
+ ) # '_._' is replaced by a level separation (can be used to convert TF2.0 lists in PyTorch nn.ModulesList)
87
+ tf_name = re.sub(r"//+", "/", tf_name) # Collapse repeated '/' separators (remove empty levels)
88
+ tf_name = tf_name.split("/") # Split on TF2.0 '/' separators (re-joined with PyTorch '.' separators below)
89
+ # Some weights have a single name without "/" such as final_logits_bias in BART
90
+ if len(tf_name) > 1:
91
+ tf_name = tf_name[1:] # Remove level zero
92
+
93
+ tf_weight_shape = list(tf_weight_shape)
94
+
95
+ # When should we transpose the weights
96
+ if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
97
+ transpose = TransposeType.CONV2D
98
+ elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
99
+ transpose = TransposeType.CONV1D
100
+ elif bool(
101
+ tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
102
+ or "emb_projs" in tf_name
103
+ or "out_projs" in tf_name
104
+ ):
105
+ transpose = TransposeType.SIMPLE
106
+ else:
107
+ transpose = TransposeType.NO
108
+
109
+ # Convert standard TF2.0 names in PyTorch names
110
+ if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
111
+ tf_name[-1] = "weight"
112
+ if tf_name[-1] == "beta":
113
+ tf_name[-1] = "bias"
114
+
115
+ # The SeparableConv1D TF layer contains two weights that are translated to PyTorch Conv1D here
116
+ if tf_name[-1] == "pointwise_kernel" or tf_name[-1] == "depthwise_kernel":
117
+ tf_name[-1] = tf_name[-1].replace("_kernel", ".weight")
118
+
119
+ # Remove prefix if needed
120
+ tf_name = ".".join(tf_name)
121
+ if start_prefix_to_remove:
122
+ tf_name = tf_name.replace(start_prefix_to_remove, "", 1)
123
+
124
+ return tf_name, transpose
125
+
126
+
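To make the renaming conventions documented above concrete, the following usage sketch runs the conversion on a hypothetical BERT-style TF variable name (the name and the 768x768 shape are illustrative, not taken from this diff); it assumes the function is importable from an installed `transformers` package:

from transformers.modeling_tf_pytorch_utils import (
    TransposeType,
    convert_tf_weight_name_to_pt_weight_name,
)

# Hypothetical TF 2.0 variable name: "_._" marks a layer list, ":0" is the device suffix.
tf_name = "tf_bert_model/bert/encoder/layer_._0/attention/self/query/kernel:0"

pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(tf_name, tf_weight_shape=[768, 768])

# Per the rules above: level zero is dropped, "_._" becomes a ".", "kernel" becomes "weight",
# and a 2D dense kernel is marked for a simple transpose.
assert pt_name == "bert.encoder.layer.0.attention.self.query.weight"
assert transpose is TransposeType.SIMPLE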
127
+ def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
128
+ """
129
+ Apply a transpose to some weight then tries to reshape the weight to the same shape as a given shape, all in a
130
+ framework agnostic way.
131
+ """
132
+ if transpose is TransposeType.CONV2D:
133
+ # Conv2D weight:
134
+ # PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
135
+ # -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
136
+ axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
137
+ weight = transpose_func(weight, axes=axes)
138
+ elif transpose is TransposeType.CONV1D:
139
+ # Conv1D weight:
140
+ # PT: (num_out_channel, num_in_channel, kernel)
141
+ # -> TF: (kernel, num_in_channel, num_out_channel)
142
+ weight = transpose_func(weight, axes=(2, 1, 0))
143
+ elif transpose is TransposeType.SIMPLE:
144
+ weight = transpose_func(weight)
145
+
146
+ if match_shape is None:
147
+ return weight
148
+
149
+ if len(match_shape) < len(weight.shape):
150
+ weight = squeeze(weight)
151
+ elif len(match_shape) > len(weight.shape):
152
+ weight = expand_dims(weight, axis=0)
153
+
154
+ if list(match_shape) != list(weight.shape):
155
+ try:
156
+ weight = reshape(weight, match_shape)
157
+ except AssertionError as e:
158
+ e.args += (match_shape, match_shape)
159
+ raise e
160
+
161
+ return weight
162
+
163
+
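The `TransposeType.CONV2D` branch above permutes PyTorch's `(out_channels, in_channels, kH, kW)` kernel layout into TensorFlow's `(kH, kW, in_channels, out_channels)`. A small numpy-only illustration of those axis orders (the shapes are arbitrary example values):

import numpy as np

pt_kernel = np.zeros((64, 3, 7, 7))                # PT: (num_out_channel, num_in_channel, kH, kW)

tf_kernel = np.transpose(pt_kernel, (2, 3, 1, 0))  # pt_to_tf=True axes used by apply_transpose
assert tf_kernel.shape == (7, 7, 3, 64)            # TF: (kH, kW, num_in_channel, num_out_channel)

pt_again = np.transpose(tf_kernel, (3, 2, 0, 1))   # reverse direction (pt_to_tf=False)
assert pt_again.shape == pt_kernel.shape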
164
+ #####################
165
+ # PyTorch => TF 2.0 #
166
+ #####################
167
+
168
+
169
+ def load_pytorch_checkpoint_in_tf2_model(
170
+ tf_model,
171
+ pytorch_checkpoint_path,
172
+ tf_inputs=None,
173
+ allow_missing_keys=False,
174
+ output_loading_info=False,
175
+ _prefix=None,
176
+ tf_to_pt_weight_rename=None,
177
+ ):
178
+ """Load pytorch checkpoints in a TF 2.0 model"""
179
+ try:
180
+ import tensorflow as tf # noqa: F401
181
+ import torch # noqa: F401
182
+ from safetensors.torch import load_file as safe_load_file # noqa: F401
183
+ except ImportError:
184
+ logger.error(
185
+ "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
186
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
187
+ )
188
+ raise
189
+
190
+ # Treats a single file as a collection of shards with 1 shard.
191
+ if isinstance(pytorch_checkpoint_path, str):
192
+ pytorch_checkpoint_path = [pytorch_checkpoint_path]
193
+
194
+ # Loads all shards into a single state dictionary
195
+ pt_state_dict = {}
196
+ for path in pytorch_checkpoint_path:
197
+ pt_path = os.path.abspath(path)
198
+ logger.info(f"Loading PyTorch weights from {pt_path}")
199
+ if pt_path.endswith(".safetensors"):
200
+ state_dict = safe_load_file(pt_path)
201
+ else:
202
+ weights_only_kwarg = {"weights_only": True}
203
+ state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
204
+
205
+ pt_state_dict.update(state_dict)
206
+
207
+ logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters")
208
+
209
+ return load_pytorch_weights_in_tf2_model(
210
+ tf_model,
211
+ pt_state_dict,
212
+ tf_inputs=tf_inputs,
213
+ allow_missing_keys=allow_missing_keys,
214
+ output_loading_info=output_loading_info,
215
+ _prefix=_prefix,
216
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
217
+ )
218
+
219
+
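As a usage sketch for the checkpoint loader above: it accepts either a single checkpoint path or a list of shard paths. The model class is a real transformers class, but the file names are placeholders for whatever checkpoint is actually on disk:

from transformers import BertConfig, TFBertForSequenceClassification
from transformers.modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

tf_model = TFBertForSequenceClassification(BertConfig())

# A single file is internally treated as a one-shard list.
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, "pytorch_model.bin")

# Sharded checkpoints (or .safetensors files) can be passed as an explicit list of paths.
tf_model, info = load_pytorch_checkpoint_in_tf2_model(
    tf_model,
    ["pytorch_model-00001-of-00002.bin", "pytorch_model-00002-of-00002.bin"],
    allow_missing_keys=True,
    output_loading_info=True,
)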
220
+ def load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=None, allow_missing_keys=False):
221
+ """Load pytorch checkpoints in a TF 2.0 model"""
222
+ pt_state_dict = pt_model.state_dict()
223
+
224
+ return load_pytorch_weights_in_tf2_model(
225
+ tf_model, pt_state_dict, tf_inputs=tf_inputs, allow_missing_keys=allow_missing_keys
226
+ )
227
+
228
+
229
+ def load_pytorch_weights_in_tf2_model(
230
+ tf_model,
231
+ pt_state_dict,
232
+ tf_inputs=None,
233
+ allow_missing_keys=False,
234
+ output_loading_info=False,
235
+ _prefix=None,
236
+ tf_to_pt_weight_rename=None,
237
+ ):
238
+ """Load pytorch state_dict in a TF 2.0 model."""
239
+ try:
240
+ import tensorflow as tf # noqa: F401
241
+ import torch # noqa: F401
242
+ except ImportError:
243
+ logger.error(
244
+ "Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
245
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
246
+ )
247
+ raise
248
+
249
+ # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
250
+ pt_state_dict = {
251
+ k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
252
+ }
253
+ return load_pytorch_state_dict_in_tf2_model(
254
+ tf_model,
255
+ pt_state_dict,
256
+ tf_inputs=tf_inputs,
257
+ allow_missing_keys=allow_missing_keys,
258
+ output_loading_info=output_loading_info,
259
+ _prefix=_prefix,
260
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
261
+ )
262
+
263
+
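The dict comprehension at the top of this function exists because numpy has no native bfloat16 dtype, so bfloat16 tensors are upcast to float32 before the numpy conversion. A minimal sketch of that step in isolation (the tensors are toy values):

import torch

pt_state_dict = {"w": torch.ones(2, 2, dtype=torch.bfloat16), "b": torch.zeros(2)}

converted = {
    k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy()
    for k, v in pt_state_dict.items()
}

assert converted["w"].dtype.name == "float32"   # bfloat16 upcast to float32 for numpy
assert converted["b"].dtype.name == "float32"   # regular float32 tensors pass through unchanged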
264
+ def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
265
+ if len(unexpected_keys) > 0:
266
+ logger.warning(
267
+ "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
268
+ f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
269
+ f" {class_name} from a PyTorch model trained on another task or with another architecture"
270
+ " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
271
+ f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
272
+ " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
273
+ " BertForSequenceClassification model)."
274
+ )
275
+ else:
276
+ logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
277
+ if len(missing_keys) > 0:
278
+ logger.warning(
279
+ f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
280
+ f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
281
+ " down-stream task to be able to use it for predictions and inference."
282
+ )
283
+ else:
284
+ logger.warning(
285
+ f"All the weights of {class_name} were initialized from the PyTorch model.\n"
286
+ "If your task is similar to the task the model of the checkpoint was trained on, "
287
+ f"you can already use {class_name} for predictions without further training."
288
+ )
289
+
290
+ if len(mismatched_keys) > 0:
291
+ mismatched_warning = "\n".join(
292
+ [
293
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
294
+ for key, shape1, shape2 in mismatched_keys
295
+ ]
296
+ )
297
+ logger.warning(
298
+ f"Some weights of {class_name} were not initialized from the model checkpoint"
299
+ f" are newly initialized because the shapes did not"
300
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
301
+ " to use it for predictions and inference."
302
+ )
303
+
304
+
305
+ def load_pytorch_state_dict_in_tf2_model(
306
+ tf_model,
307
+ pt_state_dict,
308
+ tf_inputs=None,
309
+ allow_missing_keys=False,
310
+ output_loading_info=False,
311
+ _prefix=None,
312
+ tf_to_pt_weight_rename=None,
313
+ ignore_mismatched_sizes=False,
314
+ skip_logger_warnings=False,
315
+ ):
316
+ """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
317
+ safetensors archive created with the safe_open() function."""
318
+ import tensorflow as tf
319
+
320
+ if tf_inputs is None:
321
+ tf_inputs = tf_model.dummy_inputs
322
+
323
+ if _prefix is None:
324
+ _prefix = ""
325
+ if tf_inputs:
326
+ with tf.name_scope(_prefix):
327
+ tf_model(tf_inputs, training=False) # Make sure model is built
328
+ # Convert old-format PyTorch state_dict keys to the new format if needed
329
+ tf_keys_to_pt_keys = {}
330
+ for key in pt_state_dict.keys():
331
+ new_key = None
332
+ if "gamma" in key:
333
+ new_key = key.replace("gamma", "weight")
334
+ if "beta" in key:
335
+ new_key = key.replace("beta", "bias")
336
+ if "running_var" in key:
337
+ new_key = key.replace("running_var", "moving_variance")
338
+ if "running_mean" in key:
339
+ new_key = key.replace("running_mean", "moving_mean")
340
+
341
+ # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
342
+ key_components = key.split(".")
343
+ name = None
344
+ if key_components[-3::2] == ["parametrizations", "original0"]:
345
+ name = key_components[-2] + "_g"
346
+ elif key_components[-3::2] == ["parametrizations", "original1"]:
347
+ name = key_components[-2] + "_v"
348
+ if name is not None:
349
+ key_components = key_components[:-3] + [name]
350
+ new_key = ".".join(key_components)
351
+
352
+ if new_key is None:
353
+ new_key = key
354
+ tf_keys_to_pt_keys[new_key] = key
355
+
356
+ # Matt: All TF models store the actual model stem in a MainLayer class, including the base model.
357
+ # In PT, the derived models (with heads) use the base model class as the stem instead,
358
+ # and there is no MainLayer class. This means that TF base classes have one
359
+ # extra layer in their weight names, corresponding to the MainLayer class. This code block compensates for that.
360
+ start_prefix_to_remove = ""
361
+ if not any(s.startswith(tf_model.base_model_prefix) for s in tf_keys_to_pt_keys.keys()):
362
+ start_prefix_to_remove = tf_model.base_model_prefix + "."
363
+
364
+ symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
365
+ tf_loaded_numel = 0
366
+ all_pytorch_weights = set(tf_keys_to_pt_keys.keys())
367
+ missing_keys = []
368
+ mismatched_keys = []
369
+ is_safetensor_archive = hasattr(pt_state_dict, "get_tensor")
370
+ for symbolic_weight in symbolic_weights:
371
+ sw_name = symbolic_weight.name
372
+ name, transpose = convert_tf_weight_name_to_pt_weight_name(
373
+ sw_name,
374
+ start_prefix_to_remove=start_prefix_to_remove,
375
+ tf_weight_shape=symbolic_weight.shape,
376
+ name_scope=_prefix,
377
+ )
378
+ if tf_to_pt_weight_rename is not None:
379
+ aliases = tf_to_pt_weight_rename(name) # Is a tuple to account for possible name aliasing
380
+ for alias in aliases: # The aliases are in priority order, take the first one that matches
381
+ if alias in tf_keys_to_pt_keys:
382
+ name = alias
383
+ break
384
+ else:
385
+ # If none of the aliases match, just use the first one (it'll be reported as missing)
386
+ name = aliases[0]
387
+
388
+ # Find associated numpy array in pytorch model state dict
389
+ if name not in tf_keys_to_pt_keys:
390
+ if allow_missing_keys:
391
+ missing_keys.append(name)
392
+ continue
393
+ elif tf_model._keys_to_ignore_on_load_missing is not None:
394
+ # authorized missing keys don't have to be loaded
395
+ if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
396
+ continue
397
+ raise AttributeError(f"{name} not found in PyTorch model")
398
+ state_dict_name = tf_keys_to_pt_keys[name]
399
+ if is_safetensor_archive:
400
+ array = pt_state_dict.get_tensor(state_dict_name)
401
+ else:
402
+ array = pt_state_dict[state_dict_name]
403
+ try:
404
+ array = apply_transpose(transpose, array, symbolic_weight.shape)
405
+ except tf.errors.InvalidArgumentError as e:
406
+ if not ignore_mismatched_sizes:
407
+ error_msg = str(e)
408
+ error_msg += (
409
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
410
+ )
411
+ raise tf.errors.InvalidArgumentError(error_msg)
412
+ else:
413
+ mismatched_keys.append((name, array.shape, symbolic_weight.shape))
414
+ continue
415
+
416
+ tf_loaded_numel += tensor_size(array)
417
+
418
+ symbolic_weight.assign(tf.cast(array, symbolic_weight.dtype))
419
+ del array # Immediately free memory to keep peak usage as low as possible
420
+ all_pytorch_weights.discard(name)
421
+
422
+ logger.info(f"Loaded {tf_loaded_numel:,} parameters in the TF 2.0 model.")
423
+
424
+ unexpected_keys = list(all_pytorch_weights)
425
+
426
+ if tf_model._keys_to_ignore_on_load_missing is not None:
427
+ for pat in tf_model._keys_to_ignore_on_load_missing:
428
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
429
+ if tf_model._keys_to_ignore_on_load_unexpected is not None:
430
+ for pat in tf_model._keys_to_ignore_on_load_unexpected:
431
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
432
+ if not skip_logger_warnings:
433
+ _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
434
+
435
+ if output_loading_info:
436
+ loading_info = {
437
+ "missing_keys": missing_keys,
438
+ "unexpected_keys": unexpected_keys,
439
+ "mismatched_keys": mismatched_keys,
440
+ }
441
+ return tf_model, loading_info
442
+
443
+ return tf_model
444
+
445
+
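The key-renaming loop near the top of this function maps legacy PyTorch parameter names onto the names the TF variables expect. Below is a simplified, standalone mirror of those substitutions (it collapses the branching into chained replaces, so it is illustrative rather than byte-for-byte identical, and the example keys are made up):

def rename_legacy_pt_key(key: str) -> str:
    # Legacy norm / batch-norm statistics names.
    key = key.replace("gamma", "weight").replace("beta", "bias")
    key = key.replace("running_var", "moving_variance").replace("running_mean", "moving_mean")
    # torch.nn.utils.parametrizations.weight_norm stores two tensors per weight.
    parts = key.split(".")
    if parts[-3::2] == ["parametrizations", "original0"]:
        parts = parts[:-3] + [parts[-2] + "_g"]
    elif parts[-3::2] == ["parametrizations", "original1"]:
        parts = parts[:-3] + [parts[-2] + "_v"]
    return ".".join(parts)

assert rename_legacy_pt_key("encoder.LayerNorm.gamma") == "encoder.LayerNorm.weight"
assert rename_legacy_pt_key("backbone.bn.running_mean") == "backbone.bn.moving_mean"
assert rename_legacy_pt_key("conv.parametrizations.weight.original0") == "conv.weight_g"
assert rename_legacy_pt_key("conv.parametrizations.weight.original1") == "conv.weight_v"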
446
+ def load_sharded_pytorch_safetensors_in_tf2_model(
447
+ tf_model,
448
+ safetensors_shards,
449
+ tf_inputs=None,
450
+ allow_missing_keys=False,
451
+ output_loading_info=False,
452
+ _prefix=None,
453
+ tf_to_pt_weight_rename=None,
454
+ ignore_mismatched_sizes=False,
455
+ ):
456
+ all_loading_infos = []
457
+ for shard in safetensors_shards:
458
+ with safe_open(shard, framework="tf") as safetensors_archive:
459
+ tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
460
+ tf_model,
461
+ safetensors_archive,
462
+ tf_inputs=tf_inputs,
463
+ allow_missing_keys=allow_missing_keys,
464
+ output_loading_info=True,
465
+ _prefix=_prefix,
466
+ tf_to_pt_weight_rename=tf_to_pt_weight_rename,
467
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
468
+ skip_logger_warnings=True, # We will emit merged warnings at the end
469
+ )
470
+ all_loading_infos.append(loading_info)
471
+ # Now we just need to merge the loading info
472
+ # Keys are missing only if they're missing in *every* shard
473
+ missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
474
+ # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
475
+ unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
476
+ mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
477
+
478
+ _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
479
+
480
+ if output_loading_info:
481
+ loading_info = {
482
+ "missing_keys": missing_keys,
483
+ "unexpected_keys": unexpected_keys,
484
+ "mismatched_keys": mismatched_keys,
485
+ }
486
+ return tf_model, loading_info
487
+
488
+ return tf_model
489
+
490
+
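The merge step above treats a key as missing only when every shard reports it as missing, while unexpected and mismatched keys accumulate across shards. A tiny self-contained sketch of that reduction, using made-up shard reports:

all_loading_infos = [
    {"missing_keys": ["head.bias", "head.weight"], "unexpected_keys": ["extra.0"], "mismatched_keys": []},
    {"missing_keys": ["head.bias"], "unexpected_keys": [], "mismatched_keys": []},
]

missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

assert missing_keys == ["head.bias"]     # missing in *every* shard
assert unexpected_keys == ["extra.0"]    # unexpected in *any* shard
assert mismatched_keys == []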
491
+ #####################
492
+ # TF 2.0 => PyTorch #
493
+ #####################
494
+
495
+
496
+ def load_tf2_checkpoint_in_pytorch_model(
497
+ pt_model, tf_checkpoint_path, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
498
+ ):
499
+ """
500
+ Load TF 2.0 HDF5 checkpoint in a PyTorch model. We use HDF5 to easily do transfer learning (see
501
+ https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
502
+ """
503
+ try:
504
+ import tensorflow as tf # noqa: F401
505
+ import torch # noqa: F401
506
+ except ImportError:
507
+ logger.error(
508
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
509
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
510
+ )
511
+ raise
512
+
513
+ import transformers
514
+
515
+ from .modeling_tf_utils import load_tf_weights
516
+
517
+ logger.info(f"Loading TensorFlow weights from {tf_checkpoint_path}")
518
+
519
+ # Instantiate and load the associated TF 2.0 model
520
+ tf_model_class_name = "TF" + pt_model.__class__.__name__ # Add "TF" at the beginning
521
+ tf_model_class = getattr(transformers, tf_model_class_name)
522
+ tf_model = tf_model_class(pt_model.config)
523
+
524
+ if tf_inputs is None:
525
+ tf_inputs = tf_model.dummy_inputs
526
+
527
+ if tf_inputs is not None:
528
+ tf_model(tf_inputs, training=False) # Make sure model is built
529
+
530
+ load_tf_weights(tf_model, tf_checkpoint_path)
531
+
532
+ return load_tf2_model_in_pytorch_model(
533
+ pt_model, tf_model, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
534
+ )
535
+
536
+
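The TF counterpart class above is located purely by prepending "TF" to the PyTorch class name and looking it up on the `transformers` module. A hedged illustration of that lookup with a real pair of classes (TensorFlow must be installed, and no checkpoint is actually loaded here):

import transformers
from transformers import BertConfig, BertForSequenceClassification

pt_model = BertForSequenceClassification(BertConfig())

tf_model_class_name = "TF" + pt_model.__class__.__name__   # -> "TFBertForSequenceClassification"
tf_model_class = getattr(transformers, tf_model_class_name)
tf_model = tf_model_class(pt_model.config)                  # same config, TF implementation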
537
+ def load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=False, output_loading_info=False):
538
+ """Load TF 2.0 model in a pytorch model"""
539
+ weights = tf_model.weights
540
+
541
+ return load_tf2_weights_in_pytorch_model(
542
+ pt_model, weights, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
543
+ )
544
+
545
+
546
+ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False, output_loading_info=False):
547
+ """Load TF2.0 symbolic weights in a PyTorch model"""
548
+ try:
549
+ import tensorflow as tf # noqa: F401
550
+ import torch # noqa: F401
551
+ except ImportError:
552
+ logger.error(
553
+ "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed. Please see "
554
+ "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions."
555
+ )
556
+ raise
557
+
558
+ tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
559
+ return load_tf2_state_dict_in_pytorch_model(
560
+ pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
561
+ )
562
+
563
+
564
+ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
565
+ import torch
566
+
567
+ new_pt_params_dict = {}
568
+ current_pt_params_dict = dict(pt_model.named_parameters())
569
+
570
+ # Make sure we are able to load PyTorch base models as well as derived models (with heads)
571
+ # TF models always have a prefix; some PyTorch models (the base ones) don't
572
+ start_prefix_to_remove = ""
573
+ if not any(s.startswith(pt_model.base_model_prefix) for s in current_pt_params_dict.keys()):
574
+ start_prefix_to_remove = pt_model.base_model_prefix + "."
575
+
576
+ # Build a map from potential PyTorch weight names to TF 2.0 Variables
577
+ tf_weights_map = {}
578
+ for name, tf_weight in tf_state_dict.items():
579
+ pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
580
+ name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
581
+ )
582
+ tf_weights_map[pt_name] = (tf_weight, transpose)
583
+
584
+ all_tf_weights = set(tf_weights_map.keys())
585
+ loaded_pt_weights_data_ptr = {}
586
+ missing_keys_pt = []
587
+ for pt_weight_name, pt_weight in current_pt_params_dict.items():
588
+ # Handle PyTorch shared weights (not duplicated in TF 2.0)
589
+ if pt_weight.data_ptr() in loaded_pt_weights_data_ptr:
590
+ new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
591
+ continue
592
+
593
+ pt_weight_name_to_check = pt_weight_name
594
+ # New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
595
+ key_components = pt_weight_name.split(".")
596
+ name = None
597
+ if key_components[-3::2] == ["parametrizations", "original0"]:
598
+ name = key_components[-2] + "_g"
599
+ elif key_components[-3::2] == ["parametrizations", "original1"]:
600
+ name = key_components[-2] + "_v"
601
+ if name is not None:
602
+ key_components = key_components[:-3] + [name]
603
+ pt_weight_name_to_check = ".".join(key_components)
604
+
605
+ # Find associated numpy array in pytorch model state dict
606
+ if pt_weight_name_to_check not in tf_weights_map:
607
+ if allow_missing_keys:
608
+ missing_keys_pt.append(pt_weight_name)
609
+ continue
610
+
611
+ raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")
612
+
613
+ array, transpose = tf_weights_map[pt_weight_name_to_check]
614
+
615
+ array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
616
+
617
+ if numpy.isscalar(array):
618
+ array = numpy.array(array)
619
+ if not is_torch_tensor(array) and not is_numpy_array(array):
620
+ array = array.numpy()
621
+ if is_numpy_array(array):
622
+ # Convert to torch tensor
623
+ array = torch.from_numpy(array)
624
+
625
+ new_pt_params_dict[pt_weight_name] = array
626
+ loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
627
+ all_tf_weights.discard(pt_weight_name)
628
+
629
+ missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
630
+ missing_keys += missing_keys_pt
631
+
632
+ # Some models may have keys that are not in the state by design, removing them before needlessly warning
633
+ # the user.
634
+ if pt_model._keys_to_ignore_on_load_missing is not None:
635
+ for pat in pt_model._keys_to_ignore_on_load_missing:
636
+ missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
637
+
638
+ if pt_model._keys_to_ignore_on_load_unexpected is not None:
639
+ for pat in pt_model._keys_to_ignore_on_load_unexpected:
640
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
641
+
642
+ if len(unexpected_keys) > 0:
643
+ logger.warning(
644
+ "Some weights of the TF 2.0 model were not used when initializing the PyTorch model"
645
+ f" {pt_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
646
+ f" {pt_model.__class__.__name__} from a TF 2.0 model trained on another task or with another architecture"
647
+ " (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n- This IS"
648
+ f" NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect"
649
+ " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
650
+ " TFBertForSequenceClassification model)."
651
+ )
652
+ else:
653
+ logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
654
+ if len(missing_keys) > 0:
655
+ logger.warning(
656
+ f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model and are newly"
657
+ f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
658
+ " use it for predictions and inference."
659
+ )
660
+ else:
661
+ logger.warning(
662
+ f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
663
+ "If your task is similar to the task the model of the checkpoint was trained on, "
664
+ f"you can already use {pt_model.__class__.__name__} for predictions without further training."
665
+ )
666
+
667
+ logger.info(f"Weights or buffers not loaded from TF 2.0 model: {all_tf_weights}")
668
+
669
+ if output_loading_info:
670
+ loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
671
+ return pt_model, loading_info
672
+
673
+ return pt_model
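The `loaded_pt_weights_data_ptr` bookkeeping above exists because PyTorch models may tie parameters (e.g. shared input/output embeddings) that appear under several names but share one storage, while the TF side keeps a single variable. A minimal sketch of how `data_ptr()` detects that sharing, on a made-up toy module pair:

import torch
from torch import nn

embed = nn.Embedding(10, 4)
decoder = nn.Linear(4, 10, bias=False)
decoder.weight = embed.weight                     # tie the two parameters

seen = {}
for name, param in [("embed.weight", embed.weight), ("decoder.weight", decoder.weight)]:
    if param.data_ptr() in seen:
        # Second occurrence of the same storage: reuse whatever was converted for the first name.
        print(f"{name} shares storage with {seen[param.data_ptr()]}")
        continue
    seen[param.data_ptr()] = name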
modeling_tf_utils.py ADDED
The diff for this file is too large to render. See raw diff
 
modeling_utils.py ADDED
The diff for this file is too large to render. See raw diff