File size: 105,353 Bytes
6c5dce9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 |
import torch
import types
import inspect
import importlib
import transformers
import torch.nn as nn
from transformers import Cache, GenerationConfig
from typing import Any, Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers import Cache, GenerationConfig
# `generate()` arguments that cannot be combined with SepCache-based generation.
UNSUPPORTED_GENERATION_ARGS = [
    # Cache selection / configuration: SepCache is always the cache in use.
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    # Beam search and related multi-hypothesis decoding.
    "num_beams",
    # torch.compile is not supported by SepCache.
    "compile_config",
    # Speculative (assisted) decoding is not supported either.
    "assistant_model",
]
##################################################### Functions to Patch #######################################################
def truncate_input_ids_4_autoregression(input_ids, key_states):
    """Keep only the trailing token ids that correspond to the given key states.

    The number of kept ids is ``key_states.shape[-2]`` and they are taken from the
    end of the last dimension of `input_ids`. If the lengths already match,
    `input_ids` is returned unchanged (same object).

    Args:
        input_ids: tensor of token ids; its last dimension is the sequence axis.
        key_states: tensor whose dimension -2 is the sequence axis.

    Returns:
        `input_ids` or a slice of it with ``key_states.shape[-2]`` entries in the last dim.

    Raises:
        AssertionError: if `input_ids` has fewer ids than `key_states` has positions.
    """
    num_ids = input_ids.shape[-1]
    num_keys = key_states.shape[-2]
    if num_ids == num_keys:
        return input_ids
    assert num_ids >= num_keys, (
        f"`input_ids` has {num_ids} ids but `key_states` has {num_keys} positions."
    )
    # Slice from an explicit non-negative start: `input_ids[..., -num_keys:]` would
    # return *all* ids when num_keys == 0 (because -0 == 0).
    return input_ids[..., num_ids - num_keys:]
def llama_atten_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    """Attention `forward` replacement that stores keys/values in a `SepCache`.

    Meant to be bound as the `forward` of a Llama-style attention module: it uses
    `self.q_proj`/`k_proj`/`v_proj`/`o_proj`, `self.layer_idx`, `self.scaling`,
    `self.attention_dropout` and `self.config._attn_implementation`.
    `apply_rotary_pos_emb`, `eager_attention_forward` and `ALL_ATTENTION_FUNCTIONS`
    are looked up in the module named by `self.__module__`. This lets the same
    function patch any modeling file that exposes those names.

    Args:
        hidden_states: input activations, unpacked as `(bsz, q_len, hidden)`.
        position_embeddings: `(cos, sin)` pair. RoPE is applied to queries/keys
            here only when the cache's `APPLY_PE_SHIFT` is False.
        attention_mask: optional mask. Its last dimension is cut to the cache's
            usable length after the cache update.
        past_key_value: must be a `SepCache`.
        cache_position: accepted for signature compatibility; unused here.
        **kwargs: must provide `input_ids`, either directly or inside
            `kwargs["sepllm_kwargs"]`. `position_ids` is read if present.
            The whole dict is also forwarded to the attention function.

    Returns:
        `(attn_output, attn_weights)` from the attention function, with
        `attn_output` projected through `self.o_proj`.

    Raises:
        AttributeError: if the module defines neither `head_dim` nor `head_size`.
        AssertionError: if the cache is not a `SepCache` or `input_ids` is missing.
        NotImplementedError: for `APPLY_PE_SHIFT=True` with `APPLY_PES_INSIDE=False`.
    """
    input_shape = hidden_states.shape[:-1]
    if hasattr(self, "head_dim"):
        head_dim = self.head_dim
    elif hasattr(self, "head_size"):
        head_dim = self.head_size
    else:
        # Fail loudly here instead of with an UnboundLocalError a line later.
        raise AttributeError(
            f"`{type(self).__name__}` defines neither `head_dim` nor `head_size`; cannot split attention heads."
        )
    hidden_shape = (*input_shape, -1, head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    ########################### SepCache ###########################
    assert isinstance(past_key_value, SepCache), f"`past_key_value` must be of the type: `SepCache`."
    APPLY_PE_SHIFT = past_key_value.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = past_key_value.APPLY_PES_INSIDE

    ####################### Monkey Patching ########################
    # Resolve helpers from the modeling file of the patched class.
    module = importlib.import_module(self.__module__)
    apply_rotary_pos_emb = module.apply_rotary_pos_emb
    eager_attention_forward = module.eager_attention_forward
    ALL_ATTENTION_FUNCTIONS = module.ALL_ATTENTION_FUNCTIONS

    if not APPLY_PE_SHIFT:
        # Rotate queries/keys here, before they enter the cache.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    elif not APPLY_PES_INSIDE:
        # This mode would need shifted `cos`/`sin` plus `cos_q`/`sin_q` handed to
        # `SepCache.update` via `cache_kwargs`, but nothing in this function computes
        # them. The previous code crashed here with UnboundLocalError/NameError.
        raise NotImplementedError(
            "`APPLY_PE_SHIFT=True` with `APPLY_PES_INSIDE=False` is not supported by this attention forward: "
            "the shifted `cos`/`sin` and `cos_q`/`sin_q` it requires are not computed here. "
            "Set `APPLY_PES_INSIDE=True` (the default) instead."
        )
    # On every supported path no RoPE tensors are passed to the cache from here.
    cache_kwargs = {}

    # `SepCache.update` needs the current `input_ids`; accept them directly or
    # nested inside `sepllm_kwargs`.
    if "input_ids" in kwargs:
        input_ids = kwargs.get("input_ids", None)
    else:
        sepllm_kwargs = kwargs.get("sepllm_kwargs", None)
        assert sepllm_kwargs is not None, f"`sepllm_kwargs` must be provided when `input_ids` is not given."
        input_ids = sepllm_kwargs.get("input_ids", None)
    assert input_ids is not None, f"`input_ids` must be properly provided directly or through `sepllm_kwargs` when calling `update()` in `SepCache`."
    position_ids = kwargs.get("position_ids")

    _, q_len, _ = hidden_states.size()
    # Keep only the trailing ids that match the key states being cached.
    input_ids = truncate_input_ids_4_autoregression(input_ids=input_ids, key_states=key_states)

    update_kwargs = {
        "key_states": key_states,
        "value_states": value_states,
        "input_ids": input_ids,
        "layer_idx": self.layer_idx,
        "position_ids": position_ids,
        "PREFILLING_FLAG": q_len > 1,
        "cache_kwargs": cache_kwargs,
    }
    if APPLY_PE_SHIFT:
        # With PE shifting the cache also returns (re-positioned) query states.
        key_states, value_states, query_states = past_key_value.update(query_states=query_states, **update_kwargs)
    else:
        key_states, value_states = past_key_value.update(**update_kwargs)

    # The cache may hold fewer KVs than the full sequence; trim the mask to match.
    seq_len = past_key_value.get_usable_length(self.layer_idx)
    if attention_mask is not None:
        attention_mask = attention_mask[..., :seq_len]

    attention_interface = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
    """Reject generation kwargs that the model cannot consume (SepCache-aware variant).

    Follows the Hugging Face generation validation step, with one relaxation: any
    key whose lowercased name contains "sep" is allowed through, so SepLLM-specific
    arguments (e.g. `sepllm_kwargs`) can reach the model. Generate-argument typos
    still surface as unused kwargs.

    Side effect: for encoder-decoder models, `decoder_input_ids` is popped from
    `model_kwargs` in place.

    Raises:
        ValueError: if a `Cache` instance is passed to a model without cache-class
            support, or if any non-None kwarg is accepted by none of the checked
            signatures.
    """
    cache_arg = model_kwargs.get("past_key_values", None)
    if isinstance(cache_arg, Cache) and not self._supports_cache_class:
        raise ValueError(
            f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please "
            "check the model documentation for supported cache formats."
        )

    is_encoder_decoder = self.config.is_encoder_decoder
    if is_encoder_decoder:
        # Consumed by `generate` itself before any model function is called.
        model_kwargs.pop("decoder_input_ids", None)

    def _param_names(fn):
        return set(inspect.signature(fn).parameters)

    accepted = _param_names(self.prepare_inputs_for_generation)
    # A catch-all `kwargs`/`model_kwargs` means extra inputs are forwarded on to `forward`.
    if "kwargs" in accepted or "model_kwargs" in accepted:
        accepted |= _param_names(self.forward)

    if is_encoder_decoder:
        base_model = getattr(self, self.base_model_prefix, None)

        def _submodule(attr_name):
            # Some models (the upstream code cites `MusicgenForConditionalGeneration`)
            # have no `self.encoder`/`self.decoder`; fall back to the base model's.
            found = getattr(self, attr_name, None)
            if found is None and base_model is not None:
                found = getattr(base_model, attr_name, None)
            return found

        encoder = _submodule("encoder")
        if encoder is not None:
            accepted |= _param_names(encoder.forward)
        decoder = _submodule("decoder")
        if decoder is not None:
            accepted |= {"decoder_" + name for name in _param_names(decoder.forward)}

    # SepCache relaxation: keys mentioning "sep" are never reported as unused.
    unused_model_args = [
        key
        for key, value in model_kwargs.items()
        if value is not None and key not in accepted and "sep" not in str(key).lower()
    ]
    if unused_model_args:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
            " generate arguments will also show up in this list)"
        )
#############################################################End################################################################
########################################################## SepCache ############################################################
class SepCache(Cache):
"""
A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase,
SepLLM condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the
corresponding SepCache only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.
It stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
Frequently-Used Parameters:
`init_cache_size: Union[int, List]`:
The maximum number of KVs to be stored for initial tokens.
In the paper, the hyperparameter `a` is an abbreviated alias for `self.init_cache_size`.
`sep_cache_size: Union[int, List]`:
The maximum number of KVs to be stored for separator tokens.
In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
`local_size: Union[int, List]`:
The maximum number of KVs to be stored for local tokens (i.e., sliding window).
In the paper, the hyperparameter `w` is an abbreviated alias for `self.local_size`.
`cache_size: Union[int, List]`:
The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache.
In the paper, the hyperparameter `c` is an abbreviated alias for `self.cache_size`.
Concerning these four parameters above:
When a list is passed (its length must be `layer_num`), it represents different values for each layer.
When an integer is passed, it means the setting is the same for all layers.
`USE_MAX_SEP_CACHE: bool`:
If True, it means we only keep at most `self.sep_cache_size` separators' KVs.
If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs.
In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
`separator_token_ids: List[int]`:
The token ids of the separator tokens for the current model's tokenizer.
We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you
to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).
`PADDING_ID: int`:
The token id of the padding token. You can just set `PADDING_ID` to the id of "<|endoftext|>" token of the tokenizer for the pretrained model.
Important Note:
When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers), and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
However, you must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`.
Here, `left_padding_offset` denotes the number of padding tokens in the record with the largest left paddings within a runtime batch. `left_padding_offset` can only be determined at runtime.
To guarantee the above inequality always holds during runtime, when setting, you can intentionally create a sufficient margin between both sides of the following inequality:
`init_cache_size` + `sep_cache_size` + `local_size` < `cache_size`, i.e., `a`+`s`+`w`<`c` in the [SepLLM paper - ICML 2025]
to leave room for `left_padding_offset`.
Please refer to the `__init__` function's comments for more details on the parameters.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, SepCache
>>> import torch
>>> from huggingface_hub import login
>>> login("hf_xxxXXXxxx")
>>> def to_cuda(a_dict: dict) -> dict:
>>> new_dict = {}
>>> for k,v in a_dict.items():
>>> if isinstance(v, torch.Tensor):
>>> new_dict[k] = v.cuda()
>>> else:
>>> new_dict[k] = v
>>> return new_dict
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", attn_implementation="flash_attention_2", device_map="cuda:0")
>>> model.bfloat16().cuda()
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
>>> inputs = tokenizer(text="My name is Llama 3", return_tensors="pt")
>>> inputs = to_cuda(inputs)
>>> # Prepare a cache and pass it to model's forward; `layer_num` is the number of layers for the pretrained model.
>>> past_key_values = SepCache(init_cache_size=4, sep_cache_size=128, local_size=256, cache_size=512, layer_num=32, USE_MAX_SEP_CACHE=True, model_type='llama')
>>> # `separator_token_ids` and `PADDING_ID` must also be provided if you are not using `model_type='llama'` like this demo.
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access SepCache filled with keys/values
SepCache()
```
```python
>>> ## When using the `update` function of SepCache to update the keys/values and the past token ids (necessary in SepCache), the current `input_ids` must also be provided.
>>> key_states, value_states = past_key_values.update(
key_states = key_states,
value_states = value_states,
input_ids = input_ids,
layer_idx = layer_idx,
PREFILLING_FLAG = q_len > 1, ## `q_len` is the sequence length of the current `query_states`
)
```
For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
"""
# is_sliding = True
@staticmethod
def slice_on_1d(x, start, end):
    """Return `x` restricted to positions [start, end) of dimension 1; all other dimensions are kept whole."""
    window = slice(start, end)
    return x[(slice(None), window, Ellipsis)]
@staticmethod
def slice_on_2d(x, start, end):
    """Return `x` restricted to positions [start, end) of dimension 2; all other dimensions are kept whole."""
    window = slice(start, end)
    return x[(slice(None), slice(None), window, Ellipsis)]
@staticmethod
def slice_on_3d(x, start, end):
    """Return `x` restricted to positions [start, end) of dimension 3; all other dimensions are kept whole."""
    window = slice(start, end)
    return x[(slice(None), slice(None), slice(None), window, Ellipsis)]
@staticmethod
def sep_1bat_select_on_1d(x, Bid, sep_index, min_sep_num=None, max_sep_num=None, SEP_PADDING_IN_BATCH=True):
    """
    Gather the separator K/V states of batch record `Bid` when dimension 1 of `x` is the sequence axis.

    The selection is `x[Bid, sep_index, ...]` (annotated in this file as batch x seqlen x head x dim -> sep_num x head x dim).
    - `SEP_PADDING_IN_BATCH=True`: the selection is topped up to `max_sep_num` rows by appending the last
      `max_sep_num - sep_num` positions of record `Bid`. A selection that is already long enough is returned as-is.
    - `SEP_PADDING_IN_BATCH=False`: the selection is returned whole, or cut to its first `min_sep_num` rows
      when `min_sep_num` is given.
      NOTE(review): `__init__`'s docs say *older* separators are truncated, but this keeps the leading
      entries of `sep_index`. Confirm the intended order.
    """
    sep_index = sep_index.to(x.device)
    if not SEP_PADDING_IN_BATCH:
        picked = x[Bid, sep_index, ...]
        if min_sep_num is not None:
            picked = picked[:min_sep_num, ...]
        return picked
    assert max_sep_num is not None, f"if `SEP_PADDING_IN_BATCH=True`, `max_sep_num` should not be None"
    picked = x[Bid, sep_index, ...]
    shortfall = max_sep_num - picked.shape[0]
    if shortfall <= 0:
        return picked
    assert shortfall <= x.shape[1], f"`padding_num` should be <= `x.shape[1]`, i.e. x's seqlen"
    filler = x[Bid, -shortfall:, ...]
    return torch.cat([picked, filler], dim=0)
@staticmethod
def sep_1bat_select_on_2d(x, Bid, sep_index, min_sep_num=None, max_sep_num=None, SEP_PADDING_IN_BATCH=True):
    """
    Gather the separator K/V states of batch record `Bid` when dimension 2 of `x` is the sequence axis.

    The selection is `x[Bid, :, sep_index, ...]` (annotated in this file as batch x head x seqlen x dim -> head x sep_num x dim).
    Its dimension -2 counts the selected separators.
    - `SEP_PADDING_IN_BATCH=True`: topped up to `max_sep_num` along dim -2 with the last missing positions
      of record `Bid`. A selection that is already long enough is returned as-is.
    - `SEP_PADDING_IN_BATCH=False`: returned whole, or cut to its first `min_sep_num` entries along dim 1
      when `min_sep_num` is given.
      NOTE(review): this keeps the leading entries of `sep_index`, although `__init__`'s docs speak of
      truncating older separators. Confirm the intended order.
    """
    sep_index = sep_index.to(x.device)
    if not SEP_PADDING_IN_BATCH:
        picked = x[Bid, :, sep_index, ...]
        if min_sep_num is not None:
            picked = picked[:, :min_sep_num, ...]
        return picked
    assert max_sep_num is not None, f"if `SEP_PADDING_IN_BATCH=True`, `max_sep_num` should not be None"
    picked = x[Bid, :, sep_index, ...]
    shortfall = max_sep_num - picked.shape[-2]
    if shortfall <= 0:
        return picked
    assert shortfall <= x.shape[-2], f"`padding_num` should be <= `x.shape[-2]`, i.e. x's seqlen"
    filler = x[Bid, :, -shortfall:, ...]
    return torch.cat([picked, filler], dim=-2)
@staticmethod
def sep_1bat_select_on_3d(x, Bid, sep_index, min_sep_num=None, max_sep_num=None, SEP_PADDING_IN_BATCH=True):
    """
    Gather the separator K/V states of batch record `Bid` when dimension 3 of `x` is the sequence axis.

    The selection is `x[Bid, :, :, sep_index, ...]` (annotated in this file as batch x head x dim x seqlen -> head x dim x sep_num).
    Its last dimension counts the selected separators.
    - `SEP_PADDING_IN_BATCH=True`: topped up to `max_sep_num` along dim -1 with the last missing positions
      of record `Bid`. A selection that is already long enough is returned as-is.
    - `SEP_PADDING_IN_BATCH=False`: returned whole, or cut to its first `min_sep_num` entries along dim 2
      when `min_sep_num` is given.
      NOTE(review): this keeps the leading entries of `sep_index`, although `__init__`'s docs speak of
      truncating older separators. Confirm the intended order.
    """
    sep_index = sep_index.to(x.device)
    if not SEP_PADDING_IN_BATCH:
        picked = x[Bid, :, :, sep_index, ...]
        if min_sep_num is not None:
            picked = picked[:, :, :min_sep_num, ...]
        return picked
    assert max_sep_num is not None, f"if `SEP_PADDING_IN_BATCH=True`, `max_sep_num` should not be None"
    picked = x[Bid, :, :, sep_index, ...]
    shortfall = max_sep_num - picked.shape[-1]
    if shortfall <= 0:
        return picked
    assert shortfall <= x.shape[-1], f"`padding_num` should be <= `x.shape[-1]`, i.e. x's seqlen"
    filler = x[Bid, :, :, -shortfall:, ...]
    return torch.cat([picked, filler], dim=-1)
# Dispatch tables keyed by the sequence dimension (`k_seq_dim` / `v_seq_dim`).
# `__init__` uses them to pick `self.k_slice`, `self.v_slice`, `self.k_bat_dim_select`
# and `self.v_bat_dim_select`.
# In the class body the names below are `staticmethod` objects. `__init__` stores the
# looked-up values as instance attributes, which bypasses descriptor unwrapping, and a bare
# `staticmethod` object is only callable on Python >= 3.10. Storing the underlying plain
# functions (`.__func__`) makes those calls work on every supported Python version.
DIM_TO_SLICE = {
    1: slice_on_1d.__func__,
    2: slice_on_2d.__func__,
    3: slice_on_3d.__func__,
}
BAT_DIM_TO_SELECT = {
    1: sep_1bat_select_on_1d.__func__,
    2: sep_1bat_select_on_2d.__func__,
    3: sep_1bat_select_on_3d.__func__,
}
def __init__(self,
## For SepLLM
init_cache_size: Union[int, List] = 4,
sep_cache_size: Union[int, List] = 64,
local_size: Union[int, List]=256,
cache_size: Union[int, List]=512,
SEP_ACCUMULATION: bool = True,
USE_MAX_SEP_CACHE: bool = False,
SEP_PADDING_IN_BATCH: bool = False,
separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided.
PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.
## For inheritance & initialization states
past_tok_ids: List[torch.Tensor] = None, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
key_cache: List[torch.Tensor] = None,
value_cache: List[torch.Tensor] = None,
## For debugging
PRINT_KV_RATIO_INSIDE: bool = False,
print_KV_inside_per_steps: int = 1000,
_seen_tokens: int = 0,
_kept_kv_ratio: List[Tuple[int]] = None,
### For positional encoding shifting
APPLY_PE_SHIFT: bool = False,
APPLY_PES_INSIDE: bool = True,
_shifted_position_ids: List[torch.Tensor] = None,
_rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
_rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
pe_scaling_factor:float = 1.0,
pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
max_position_embeddings: int = 8192,
base: int=10000, ## The base for RoPE.
## For basic transformer architecture
k_seq_dim: int=2, ## The dimension for seq_len in key tensors
v_seq_dim: int=2, ## The dimension for seq_len in value tensors
layer_num: int = None, ## required for initialization
model_type: str = None, ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
device = None
) -> None:
"""
`SEP_ACCUMULATION`: If True, it means we will try to accumulate all the KVs for separators. If False, only the `new_sep_kv` compressed from the `past_win_kv` will be kept (see function `compress_kv_cache_and_tokids_layer_wise`).
`USE_MAX_SEP_CACHE`: If True, it means we only keep at most `self.sep_cache_size` separators' KVs. If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs. In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
`SEP_PADDING_IN_BATCH`: If True, it means that SepCache will pad separator tokens in other records to be aligned with the record with the most separators in a batch. If False, it means that SepCache will truncate older separator tokens in other records to be aligned with the record with the fewest separators in a batch.
Note: If `SEP_ACCUMULATION=True` and `USE_MAX_SEP_CACHE=False`, as the number of input tokens increases, the number of separators in the KV cache will also accumulate endlessly
and `self.cache_size` will also be infinitely expanded (no longer fixed).
When `SEP_PADDING_IN_BATCH=True` is used in combination with `USE_MAX_SEP_CACHE=False` and `SEP_ACCUMULATION=True`, the KV cache will accumulate indefinitely,
and since `SEP_PADDING_IN_BATCH=True`, the KVs of all separators will be retained (rather than being truncated).
For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
"""
super().__init__()
if (key_cache is not None) or (value_cache is not None) or (past_tok_ids is not None):
assert isinstance(key_cache, list)
assert isinstance(value_cache, list)
assert isinstance(past_tok_ids, list), f"For SepCache, if `key_cache` and `value_cache` are given (e.g., provided from legacy `past_key_values`), `past_tok_ids` corresponding to `key_cache` and `value_cache` must also be provided to initialize SepCache."
assert len(key_cache) == len(past_tok_ids), f"The length of `key_cache` ({len(key_cache)}) should be equal to that of `past_tok_ids` ({len(past_tok_ids)})."
assert len(value_cache) == len(past_tok_ids), f"The length of `value_cache` ({len(value_cache)}) should be equal to that of `past_tok_ids` ({len(past_tok_ids)})."
assert layer_num is not None, f"`layer_num` must be provided according to the pretrained model."
## For basic parameters & states
self.key_cache: List[torch.Tensor] = key_cache if key_cache is not None else []
self.value_cache: List[torch.Tensor] = value_cache if value_cache is not None else []
self.k_seq_dim = k_seq_dim ## The dimension for the seq_len in key states. Typically, 2.
self.v_seq_dim = v_seq_dim ## The dimension for the seq_len in value states. Typically, 2.
self.k_slice = self.DIM_TO_SLICE[k_seq_dim]
self.v_slice = self.DIM_TO_SLICE[v_seq_dim]
self.k_bat_dim_select = self.BAT_DIM_TO_SELECT[k_seq_dim]
self.v_bat_dim_select = self.BAT_DIM_TO_SELECT[v_seq_dim]
self._seen_tokens: int = _seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen as well as performing statistics.
self.layer_num = layer_num
self.device = device # Deprecated
## For debugging
self.PRINT_KV_RATIO_INSIDE = PRINT_KV_RATIO_INSIDE
self.print_KV_inside_per_steps = print_KV_inside_per_steps
self._print_kv_ratio_count = 0
self._kept_kv_ratio: List[Tuple[int]] = _kept_kv_ratio if _kept_kv_ratio is not None else []
## For Streaming SepLLM
self.past_tok_ids: List[torch.Tensor] = past_tok_ids if past_tok_ids is not None else [] ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache
self.left_padding_offset = None
self._set_layer_wise_attribute("init_cache_size", init_cache_size, layer_num)
self._set_layer_wise_attribute("local_size", local_size, layer_num)
self._set_layer_wise_attribute("cache_size", cache_size, layer_num)
self._set_layer_wise_attribute("sep_cache_size", sep_cache_size, layer_num)
self._set_layer_wise_attribute("sep_exrange", 0, layer_num) # runtime right boundary for separators, excluded
self._set_layer_wise_attribute("max_sep_exidx", self._list_element_add(self.sep_cache_size, self.init_cache_size), layer_num) # max right boundary for separators, excluded
self.SEP_ACCUMULATION = SEP_ACCUMULATION
self.USE_MAX_SEP_CACHE = USE_MAX_SEP_CACHE
self.SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH
### For positional encoding shifting
self.APPLY_PE_SHIFT = APPLY_PE_SHIFT
self.APPLY_PES_INSIDE = APPLY_PES_INSIDE
self.cos_sin_rerotation_cache = {}
self._cos_cache = None
self._sin_cache = None
self._shifted_position_ids: List[torch.Tensor] = _shifted_position_ids if _shifted_position_ids is not None else []
self._rope_unsqueeze_dim = _rope_unsqueeze_dim
self._rope_seq_dim = _rope_seq_dim
self.pe_dim = pe_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.pe_dim, 2, dtype=torch.int64).float().to(device) / self.pe_dim))
self.inv_freq = inv_freq
self.pe_scaling_factor = pe_scaling_factor
self._sin_cached = None
self._cos_cached = None
if model_type is None:
assert isinstance(separator_token_ids, list), f"`separator_token_ids: List[int]` must be correctly provided for initialization unless `model_type` is properly given, which will auto-fiil `separator_token_ids`."
assert len(separator_token_ids) > 0, f"`separator_token_ids: List[int]` should NOT be empty."
for i in range(len(separator_token_ids)):
assert isinstance(separator_token_ids[i], int), f"The ids in `separator_token_ids` must be of the type `int`."
assert isinstance(PADDING_ID, int), f"`PADDING_ID: int` must be correctly provided for initialization unless `model_type` is properly given, which will auto-fiil `PADDING_ID`."
self.separator_token_ids = separator_token_ids
self.PADDING_ID = PADDING_ID
else:
assert isinstance(model_type, str), f"`model_type` should be a `str` or `None`."
if 'llama' in model_type.lower():
# print("Debug: For Llama's default separators")
self.separator_token_ids = [128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256,262] # llama3 8b
self.PADDING_ID = 128009
elif ( 'pythia' in model_type.lower() ) or ( 'gpt_neox' in model_type.lower() ):
# print("Debug: For GPTNeox's default separators")
self.separator_token_ids = [15, 13, 32, 2, 28, 27, 209, 186, 187, 964, 1157, 3736, 2195, 3706, 1163, 2490, 50276, 586, 4928, 50275 ] # pythia 14b
self.PADDING_ID = 0
elif 'falcon' in model_type.lower():
# print(f"Debug: For Falcon's default separators")
self.separator_token_ids = [25, 23, 42, 12, 38, 37, 193, 4610, 204, 258, 1212, 23787, 466 ] # falcon-40b
self.PADDING_ID = 11
else:
raise NotImplementedError(f"NOT implemented for the tokenizer of the backbone model type: `{model_type}`. You must provide `separator_token_ids: List[int]` and `PADDING_ID: int` for initialization in this case! ")
if APPLY_PE_SHIFT:
print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
print(">>>>>>>>--------- -----------<<<<<<<<")
print(">>>>>>>>--------- Warning: When `APPLY_PE_SHIFT=True`, SepCache must store the key/value states ----------<<<<<<<<")
print(">>>>>>>>--------- before applying positional encoding (specifically RoPE) -----------<<<<<<<<")
print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
if APPLY_PES_INSIDE:
print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
print(">>>>>>>>--------- -----------<<<<<<<<")
print(">>>>>>>>--------- Warning: When `APPLY_PES_INSIDE=True`, there is no need to apply rotary positional embedding--<<<<<<<<")
print(">>>>>>>>--------- within the self_attention function, as this operation will be handled inside the `update` ---<<<<<<<<")
print(">>>>>>>>--------- function of SepCache. Note that `APPLY_PES_INSIDE=True` is typically used together with ---<<<<<<<<")
print(">>>>>>>>--------- `APPLY_PE_SHIFT=True`. ---<<<<<<<<")
print(">>>>>>>>---------#####################################################################################-----------<<<<<<<<")
def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    input_ids: torch.Tensor = None,
    PREFILLING_FLAG: bool = True,
    query_states: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor]=None,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor],Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
    Parameters:
        `key_states` (`torch.Tensor`):
            The new key states to cache.
        `value_states` (`torch.Tensor`):
            The new value states to cache.
        `layer_idx` (`int`):
            The index of the layer to cache the states for.
        `input_ids` (`torch.Tensor`):
            The ids of the input tokens (context tokens or autoregressive tokens).
            If omitted, it is looked up under the key "input_ids" in `cache_kwargs`.
        `PREFILLING_FLAG` (`bool`):
            It should be `True` at pre-filling phase and `False` when decoding.
        `query_states` (`Optional[torch.Tensor]`):
            The query states that need positional encoding shifting. Required when `self.APPLY_PE_SHIFT=True`.
        `position_ids` (`Optional[torch.Tensor]`):
            The original positional ids of the input tokens. Only used when `self.APPLY_PE_SHIFT=True`
            and `self.APPLY_PES_INSIDE=True`, to compute the shifted positions inside SepCache.
        `cache_kwargs` (`Dict[str, Any]`, optional):
            Extra arguments. Besides "input_ids", the keys `sin`, `cos`, `sin_q`, `cos_q`,
            `shifted_pos_ids` and `partial_rotation_size` are forwarded to `apply_shifted_pos_emb`
            (only relevant when `self.APPLY_PE_SHIFT=True`).
    Return:
        `(keys, values)` for layer `layer_idx`, or `(keys, values, query_states)` whenever
        `query_states` is not None (the queries are re-rotated when `self.APPLY_PE_SHIFT=True`).
    Raises:
        AssertionError if `input_ids` can be found neither as an argument nor in `cache_kwargs`.
    For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
    """
    APPLY_PE_SHIFT = self.APPLY_PE_SHIFT
    APPLY_PES_INSIDE = self.APPLY_PES_INSIDE
    SEP_ACCUMULATION = self.SEP_ACCUMULATION
    USE_MAX_SEP_CACHE = self.USE_MAX_SEP_CACHE
    SEP_PADDING_IN_BATCH = self.SEP_PADDING_IN_BATCH

    # `cache_kwargs` is optional, so guard against None before falling back to it;
    # otherwise a missing `input_ids` surfaces as an AttributeError instead of the assertion below.
    if input_ids is None and cache_kwargs is not None:
        input_ids = cache_kwargs.get("input_ids", None)
    assert input_ids is not None, f"`input_ids` must be properly provided when calling `update()` in `SepCache`."
    assert (APPLY_PE_SHIFT and (query_states is not None)) or not APPLY_PE_SHIFT, f"If `APPLY_PE_SHIFT=True`, `query_states` should be provided and it will be updated and returned"

    # The seen-token tally is shared by all layers, so only layer 0 advances it.
    if layer_idx == 0:
        assert key_states.shape[-2] == input_ids.shape[-1], f"`key_states.shape[-2]` ({key_states.shape[-2]}) should be equal to `input_ids.shape[-1]` ({input_ids.shape[-1]})."
        self._seen_tokens += input_ids.shape[-1]

    new_kv_pair = (key_states, value_states)
    if (key_states.shape[self.k_seq_dim] + self.get_usable_length(layer_idx) < self.cache_size[layer_idx]) or PREFILLING_FLAG:
        # Pre-filling, or decoding while the cache is still below `cache_size`: append without compressing.
        assert (PREFILLING_FLAG and key_states.shape[self.k_seq_dim] >= 1) or (not PREFILLING_FLAG and key_states.shape[self.k_seq_dim] == 1)
        self.update_kv_cache_and_past_tok_ids(new_kv_pair, input_ids, layer_idx, COMPRESS_KV=False, SEP_ACCUMULATION=SEP_ACCUMULATION, USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE, SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH)
        if APPLY_PE_SHIFT:
            shifted_keys, shifted_queries = self.apply_shifted_pos_emb(layer_idx, APPLY_PES_INSIDE, PREFILLING_FLAG, key_states, query_states, position_ids, cache_kwargs )
            query_states = shifted_queries
            self.set_kv_cache( (shifted_keys, self.value_cache[layer_idx]), layer_idx)
        if PREFILLING_FLAG and layer_idx == 0:
            # Record the padding offset once; compression later requires it to be set.
            self.left_padding_offset = self.get_initial_pos_offset(layer_idx)
        ## Count KV usage
        kv_len_ori = self.get_seq_length(layer_idx)
        kv_len_cmp = self.get_usable_length(layer_idx)
        self._update_kv_ratio(kv_len_cmp=kv_len_cmp, kv_len_ori=kv_len_ori, layer_idx=layer_idx)
    else:
        ## Update the KV cache, count KV usage, and compress the KV cache.
        kv_len_ori = self.get_seq_length(layer_idx)
        self.update_kv_cache_and_past_tok_ids(new_kv_pair, input_ids, layer_idx, COMPRESS_KV=True, SEP_ACCUMULATION=SEP_ACCUMULATION, USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE, SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH)
        kv_len_cmp = self.get_usable_length(layer_idx)
        self._update_kv_ratio(kv_len_cmp=kv_len_cmp, kv_len_ori=kv_len_ori, layer_idx=layer_idx)
        if APPLY_PE_SHIFT:
            shifted_keys, shifted_queries = self.apply_shifted_pos_emb(layer_idx, APPLY_PES_INSIDE, PREFILLING_FLAG, key_states, query_states, position_ids, cache_kwargs )
            query_states = shifted_queries
            self.set_kv_cache( (shifted_keys, self.value_cache[layer_idx]), layer_idx)

    if self.PRINT_KV_RATIO_INSIDE:
        self._print_kv_ratio(layer_idx)

    if query_states is not None:
        return self.key_cache[layer_idx], self.value_cache[layer_idx], query_states
    return self.key_cache[layer_idx], self.value_cache[layer_idx]
def update_kv_cache_and_past_tok_ids(self, new_kv_pair: Tuple[torch.Tensor], input_ids: torch.Tensor, layer_idx: int, COMPRESS_KV=False, SEP_ACCUMULATION:bool=True, USE_MAX_SEP_CACHE:bool=False, SEP_PADDING_IN_BATCH:bool=True) -> Optional[int]:
    """Update the KV cache and past token ids; compress the KV cache if requested.

    Appends `input_ids` to the stored token ids and the (key, value) pair to the
    layer's cache. If the layer has no cache yet, the new tensors become the cache;
    layers must be created in order.

    Returns:
        When `COMPRESS_KV=True`, the `offset_init_size_layer` reported by
        `compress_kv_cache_and_tokids_layer_wise`; otherwise `None`.
    """
    assert layer_idx is not None, f"`layer_idx` must be given"
    assert len(new_kv_pair) == 2, f"The length of `new_kv_pair` must be 2."
    assert len(self.key_cache) == len(self.value_cache), f"The layer numbers of stored `self.key_cache` and `self.value_cache` must be the same."
    self.append_past_tok_ids(input_ids, layer_idx)
    key, value = new_kv_pair
    if len(self.key_cache) <= layer_idx:
        # First time this layer is seen: the new states become the whole cache.
        self.key_cache.append(key)
        self.value_cache.append(value)
        assert len(self.key_cache) - 1 == layer_idx, f"The key_cache should be updated sequentially according to the layer numbering."
        assert len(self.value_cache) - 1 == layer_idx, f"The value_cache should be updated sequentially according to the layer numbering."
    else:
        self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx] , key], dim=self.k_seq_dim)
        self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx] , value], dim=self.v_seq_dim)
    assert len(self.key_cache) == len(self.value_cache), f"The layer numbers of stored key_cache and value_cache must be the same."
    assert self.key_cache[layer_idx].shape[self.k_seq_dim] == self.value_cache[layer_idx].shape[self.v_seq_dim], "The seq length for key_cache and value_cache must be the same."
    if not COMPRESS_KV:
        return None
    cmp_past_kv_pairs, cmp_past_tok_ids, offset_init_size_layer = self.compress_kv_cache_and_tokids_layer_wise((self.key_cache[layer_idx], self.value_cache[layer_idx]), layer_idx ,SEP_ACCUMULATION=SEP_ACCUMULATION, USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE, SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH )
    self.set_kv_cache(cmp_past_kv_pairs, layer_idx)
    self.set_past_tok_ids(cmp_past_tok_ids, layer_idx)
    return offset_init_size_layer
def append_past_tok_ids(self, input_ids: torch.Tensor, layer_idx: int) -> None:
    """Concatenate `input_ids` onto the stored token ids of `layer_idx` (last dim).

    A layer that has no stored ids yet receives `input_ids` as-is; such layers
    must be introduced in order (asserted).
    """
    assert layer_idx is not None, f"`layer_idx` must be given"
    if layer_idx < len(self.past_tok_ids):
        previous_ids = self.past_tok_ids[layer_idx]
        self.past_tok_ids[layer_idx] = torch.cat([previous_ids, input_ids], dim=-1)
        return
    self.past_tok_ids.append(input_ids)
    assert len(self.past_tok_ids) - 1 == layer_idx, f"The past_tok_ids should be updated sequentially according to the layer numbering."
def compress_kv_cache_and_tokids_layer_wise(self, past_kv_pairs, layer_idx:int ,SEP_ACCUMULATION=False, USE_MAX_SEP_CACHE=False, SEP_PADDING_IN_BATCH=True ):
    """
    Compress the KV cache and token ids of one layer.

    The layer's cache is viewed as four consecutive segments along the sequence dim:
        [0, offset_init_size_layer)                          -> initial tokens (kept as-is)
        [offset_init_size_layer, self.sep_exrange[layer_idx]) -> separators kept from earlier compressions
        [self.sep_exrange[layer_idx], seq_len - local_size)   -> past window (compressed to separators only)
        [seq_len - local_size, seq_len)                       -> local window (kept as-is)
    where `offset_init_size_layer = self.init_cache_size[layer_idx] + self.left_padding_offset`.

    `SEP_ACCUMULATION`: If True, try to accumulate the KVs of all separators seen so far. If False, only the `new_sep_kv` compressed from the `past_win_kv` is kept (see `compress_past_win_2_seps`).
    `USE_MAX_SEP_CACHE`: If True, keep at most `self.sep_cache_size` separators' KVs. If the number exceeds this limit, older separators' KVs are discarded, keeping only the most recent `self.sep_cache_size` KVs. In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
    `SEP_PADDING_IN_BATCH`: If True, SepCache pads the separator tokens of other records to align with the record with the most separators in a batch. If False, SepCache truncates older separator tokens of other records to align with the record with the fewest separators in a batch.
    Note: If `SEP_ACCUMULATION=True` and `USE_MAX_SEP_CACHE=False`, as the number of input tokens increases, the number of separators in the KV cache also accumulates without bound,
    and `self.cache_size` grows accordingly (no longer fixed).
    When `SEP_PADDING_IN_BATCH=True` is combined with `USE_MAX_SEP_CACHE=False` and `SEP_ACCUMULATION=True`, the KV cache accumulates indefinitely,
    and since `SEP_PADDING_IN_BATCH=True`, the KVs of all separators are retained (rather than truncated).

    Side effects: recomputes `self.max_sep_exidx` for all layers, updates `self.sep_exrange[layer_idx]`,
    and (when `USE_MAX_SEP_CACHE=False` and the separators outgrow their budget) enlarges
    `self.sep_cache_size[layer_idx]` and `self.cache_size[layer_idx]`.

    Returns:
        (cmp_past_kv_pairs, cmp_past_tok_ids, offset_init_size_layer): the compressed (key, value) pair,
        the matching token ids, and the end index of the initial segment.
    For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
    """
    key, value = past_kv_pairs
    seq_len = key.size(self.k_seq_dim)
    assert seq_len == self.get_usable_length(layer_idx), f"The seq_len of cached past key and value states should be the same as the return of `get_usable_length()`, which is {self.get_usable_length(layer_idx)}"
    # Set by `update` during pre-filling; compression cannot run before that.
    left_padding_offset = self.left_padding_offset
    assert left_padding_offset is not None
    offset_init_size_layer = self.init_cache_size[layer_idx] + left_padding_offset
    # Right boundary (exclusive) for separators, shifted by the padding offset; recomputed for every layer.
    self._set_layer_wise_attribute("max_sep_exidx", self._list_element_add(self.sep_cache_size, self.init_cache_size, bias=left_padding_offset), self.layer_num)
    self._CHECK_PARAMS_VALIDITY(layer_idx, left_padding_offset)
    # A non-positive `sep_exrange` means no compression has happened yet: the separator segment is empty.
    if self.sep_exrange[layer_idx] <=0:
        self.sep_exrange[layer_idx] = offset_init_size_layer
    # The past window must be non-empty for compression to make sense.
    assert seq_len - self.local_size[layer_idx] > self.sep_exrange[layer_idx]
    if offset_init_size_layer > 0:
        initial_kv, initial_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], 0, offset_init_size_layer, seq_len=seq_len, _CHECK_IDX=True )
    Before_First_Time_Compress_Flag = (self.sep_exrange[layer_idx] == offset_init_size_layer) ## If true, no earlier compression has stored separators yet (the separator segment is empty).
    if SEP_ACCUMULATION and not Before_First_Time_Compress_Flag: ## Fetch the previously kept separators' KVs and token ids.
        past_sep_kv, past_sep_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], offset_init_size_layer, self.sep_exrange[layer_idx], seq_len=seq_len, _CHECK_IDX=True )
    past_win_kv, past_win_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], self.sep_exrange[layer_idx], seq_len - self.local_size[layer_idx], seq_len=seq_len, _CHECK_IDX=True )
    local_kv, local_tokids = self.slice_kv_cache_and_tokids( past_kv_pairs, self.past_tok_ids[layer_idx], seq_len - self.local_size[layer_idx], seq_len, seq_len=seq_len, _CHECK_IDX=True )
    new_sep_kv, new_sep_tokids, min_sep_num, max_sep_num = self.compress_past_win_2_seps( past_win_kv, past_win_tokids, SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH ) ## Separators' KVs and token ids just extracted from the past window
    if SEP_ACCUMULATION and not Before_First_Time_Compress_Flag: ## Accumulate: old separators followed by the new ones
        sep_kv, sep_tokids = self.cat_kv_cache_and_tokids( [ past_sep_kv, new_sep_kv ] , [past_sep_tokids, new_sep_tokids ] )
        new_sep_len = new_sep_tokids.shape[-1]
        sep_len = sep_tokids.shape[-1]
    else: ## Only keep the newly obtained KVs (those just compressed from the past window)
        sep_kv, sep_tokids = new_sep_kv, new_sep_tokids
        # new_sep_len = new_sep_tokids.shape[-1]
        sep_len = sep_tokids.shape[-1]
        # Padding aligns every record to the max separator count; truncation aligns to the min.
        assert (SEP_PADDING_IN_BATCH and max_sep_num==sep_len) or ( (not SEP_PADDING_IN_BATCH) and min_sep_num==sep_len)
    if USE_MAX_SEP_CACHE: ## Fixed sep cache size, i.e., only keep max_sep_len seps' KVs in the cache.
        if offset_init_size_layer + sep_len > self.max_sep_exidx[layer_idx]:
            # Keep only the most recent `max_sep_len` separators (drop the oldest ones).
            max_sep_len = self.max_sep_exidx[layer_idx] - offset_init_size_layer
            assert sep_kv[0].shape[-2] == sep_tokids.shape[-1], f"The seq_len for seps' KVs and tok_ids should be the same."
            sep_kv, sep_tokids = self.slice_kv_cache_and_tokids( sep_kv, sep_tokids, sep_len-max_sep_len, sep_len, seq_len = sep_tokids.shape[-1] ,_CHECK_IDX=True )
            self.sep_exrange[layer_idx] = self.max_sep_exidx[layer_idx]
        else:
            self.sep_exrange[layer_idx] = offset_init_size_layer + sep_len
    else: ## Extend the sep cache and the whole cache if USE_MAX_SEP_CACHE is not set
        self.sep_exrange[layer_idx] = offset_init_size_layer + sep_len
        if self.sep_exrange[layer_idx] > self.max_sep_exidx[layer_idx]:
            # Grow the separator budget and the total cache budget by the same amount.
            cache_incremental_gap = self.sep_exrange[layer_idx] - self.max_sep_exidx[layer_idx]
            self.max_sep_exidx[layer_idx] = self.sep_exrange[layer_idx]
            self.sep_cache_size[layer_idx] = self.sep_cache_size[layer_idx] + cache_incremental_gap
            self.cache_size[layer_idx] = self.cache_size[layer_idx] + cache_incremental_gap
    # Reassemble: [initial | separators | local]; the rest of the past window is dropped.
    if offset_init_size_layer > 0:
        cmp_past_kv_pairs, cmp_past_tok_ids = self.cat_kv_cache_and_tokids( [initial_kv, sep_kv, local_kv ] , [initial_tokids, sep_tokids, local_tokids ] )
    else:
        cmp_past_kv_pairs, cmp_past_tok_ids = self.cat_kv_cache_and_tokids( [sep_kv, local_kv ] , [sep_tokids, local_tokids ] )
    return cmp_past_kv_pairs, cmp_past_tok_ids, offset_init_size_layer
def compress_past_win_2_seps(self, past_win_kv: Tuple[torch.Tensor], past_win_tokids: torch.Tensor, MIN_SEP_ALERT: bool=False, SEP_PADDING_IN_BATCH: bool=True ) -> Tuple[Union[Tuple[torch.Tensor], torch.Tensor, int ]]:
    """Compress the KVs in the past window so that only separators' KVs are kept.

    A position counts as a separator when its token id is in `self.separator_token_ids`.
    Records in the batch may contain different numbers of separators, so they are aligned:
      * `SEP_PADDING_IN_BATCH=True`: every record is padded up to the max count, using the
        trailing `padding_num` tokens of that record's past window as filler;
      * `SEP_PADDING_IN_BATCH=False`: every record is truncated to its first `min_sep_num` separators.
    The KV selection itself is delegated to `self.k_bat_dim_select` / `self.v_bat_dim_select`.

    Returns:
        (new_sep_kv, new_sep_tokids, min_sep_num, max_sep_num), where `min_sep_num` and
        `max_sep_num` are the min/max per-record separator counts (0-dim tensors).
    """
    # Mask of separator positions, same shape as `past_win_tokids` (batch x seq_len).
    sep_index_tensor = torch.zeros_like(past_win_tokids).bool()
    for sp_id in self.separator_token_ids:
        sep_index_tensor = sep_index_tensor | (past_win_tokids == sp_id)
    sep_cnt = sep_index_tensor.int().sum(-1)
    min_sep_num = sep_cnt.min()  # fewest separators among the records of the batch
    max_sep_num = sep_cnt.max()  # most separators among the records of the batch
    if MIN_SEP_ALERT and not SEP_PADDING_IN_BATCH:
        assert min_sep_num>0, f"The min sep number for each compressing time in a batch should be at least one if `MIN_SEP_ALERT=True` and `SEP_PADDING_IN_BATCH=False`"

    batch_size = past_win_tokids.shape[0]
    sep_ids_per_record = []
    for b_id in range(batch_size):
        record_sep_ids = past_win_tokids[b_id, sep_index_tensor[b_id]]  # (sep_num,)
        if SEP_PADDING_IN_BATCH:
            padding_num = max_sep_num - record_sep_ids.shape[-1]
            if padding_num > 0:
                assert padding_num <= past_win_tokids.shape[-1], f"padding_num: {padding_num} should be <= past_win_tokids.shape[-1]:{past_win_tokids.shape[-1]}"
                record_pad_ids = past_win_tokids[b_id, -padding_num:]  # (padding_num,)
                record_sep_ids = torch.cat([record_sep_ids, record_pad_ids], dim=-1)  # (max_sep_num,)
        else:
            record_sep_ids = record_sep_ids[..., :min_sep_num]  # (min_sep_num,)
        sep_ids_per_record.append(record_sep_ids)
    # B x max_sep_num when padding, B x min_sep_num when truncating.
    new_sep_tokids = torch.stack(sep_ids_per_record, dim=0)

    key_cache, value_cache = past_win_kv
    assert batch_size==key_cache.shape[0]
    sep_k_per_record = []
    sep_v_per_record = []
    for b_id in range(batch_size):
        sep_k_per_record.append(self.k_bat_dim_select(key_cache, b_id, sep_index_tensor[b_id], min_sep_num, max_sep_num, SEP_PADDING_IN_BATCH))
        sep_v_per_record.append(self.v_bat_dim_select(value_cache, b_id, sep_index_tensor[b_id], min_sep_num, max_sep_num, SEP_PADDING_IN_BATCH))
    sep_k = torch.stack(sep_k_per_record, dim=0)
    sep_v = torch.stack(sep_v_per_record, dim=0)
    new_sep_kv = (sep_k, sep_v)
    return new_sep_kv, new_sep_tokids, min_sep_num, max_sep_num
def apply_shifted_pos_emb(self, layer_idx: int, APPLY_PES_INSIDE: bool, PREFILLING_FLAG: bool, key_states: torch.Tensor, query_states: torch.Tensor, position_ids: torch.Tensor, cache_kwargs: Optional[Dict[str, Any]] = None ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Re-apply rotary positional encoding to the cached keys and the current queries.

    With `APPLY_PES_INSIDE=True`, shifted positions are tracked per layer in
    `self._shifted_position_ids` and cos/sin are computed internally. Otherwise
    `cos`/`sin` (required) and optionally `cos_q`/`sin_q`, `shifted_pos_ids` come from
    `cache_kwargs`. When `cos_q`/`sin_q` are absent, they are sliced from `cos`/`sin`
    for the last `query_states.shape[self.k_seq_dim]` positions.
    `cache_kwargs["partial_rotation_size"]`, if given, restricts rotation to the first
    that many entries of the last dim; the rest pass through unchanged.

    Returns:
        (shifted_keys, shifted_queries)
    """
    seq_len = self.get_usable_length(layer_idx)
    keys_to_shift = self.key_cache[layer_idx]
    queries_to_shift = query_states
    assert keys_to_shift.shape[self.k_seq_dim] == seq_len
    if cache_kwargs is None:
        cache_kwargs = {}
    if APPLY_PES_INSIDE:
        if len(self._shifted_position_ids) <= layer_idx:
            self._shifted_position_ids.append(None)
        if PREFILLING_FLAG: ## for prefilling
            assert position_ids.shape[-1] >= seq_len, f"The length of position_ids should be >= the usable length of kv cache when prefilling."
            self._shifted_position_ids[layer_idx] = position_ids[:, :seq_len].detach()
            shifted_pos_ids = self._shifted_position_ids[layer_idx]
        elif self._shifted_position_ids[layer_idx].shape[-1] >= seq_len: ## for generation: cache did not grow past known positions
            assert position_ids.shape[-1] == 1, f"The length of query and position_ids should be 1 during generation."
            shifted_pos_ids = self._shifted_position_ids[layer_idx][:, :seq_len].detach()
        elif self._shifted_position_ids[layer_idx].shape[-1] < seq_len: ## for generation: extend positions for the new tokens
            assert position_ids.shape[-1] == 1, f"The length of query and position_ids should be 1 during generation."
            increased_gap = seq_len - self._shifted_position_ids[layer_idx].shape[-1]
            assert increased_gap < self._shifted_position_ids[layer_idx].shape[-1], f"Normally, for auto-regressive model, the input length for each step should be 1 during generation."
            new_position_ids = self._shifted_position_ids[layer_idx][:, -increased_gap: ] + increased_gap
            self._shifted_position_ids[layer_idx] = torch.cat([self._shifted_position_ids[layer_idx], new_position_ids.detach()], dim=-1)
            shifted_pos_ids = self._shifted_position_ids[layer_idx]
        else:
            raise RuntimeError
        cos, sin = self._get_naive_shifted_cos_sin(
            key_states, shifted_pos_ids, seq_len
        )
        q_rope_idx = torch.arange( seq_len - query_states.shape[self.k_seq_dim], seq_len).to(cos.device)
        cos_q, sin_q = cos.index_select(self._rope_seq_dim, q_rope_idx), sin.index_select(self._rope_seq_dim, q_rope_idx)
    else:
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        sin_q = cache_kwargs.get("sin_q")
        cos_q = cache_kwargs.get("cos_q")
        shifted_pos_ids = cache_kwargs.get("shifted_pos_ids")  # optional; may be None
        assert (sin is not None) and (cos is not None), f"sin and cos matrices should be be provided"
        if sin_q is None:
            q_rope_idx = torch.arange( seq_len - query_states.shape[self.k_seq_dim], seq_len).to(sin.device)
            sin_q = sin.index_select(self._rope_seq_dim, q_rope_idx)
        if cos_q is None:
            q_rope_idx = torch.arange( seq_len - query_states.shape[self.k_seq_dim], seq_len).to(cos.device)
            cos_q = cos.index_select(self._rope_seq_dim, q_rope_idx)
    partial_rotation_size = cache_kwargs.get("partial_rotation_size")
    # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
    if partial_rotation_size is not None:
        keys_to_shift, keys_pass = (
            keys_to_shift[..., :partial_rotation_size],
            keys_to_shift[..., partial_rotation_size:]
        )
        queries_to_shift, queries_pass = (
            queries_to_shift[..., :partial_rotation_size],
            queries_to_shift[..., partial_rotation_size:]
        )
    # `shifted_pos_ids` is optional when RoPE tensors are supplied externally; do not slice None.
    num_queries = queries_to_shift.shape[self.k_seq_dim]
    q_pos_ids = None if shifted_pos_ids is None else shifted_pos_ids[:, -num_queries:]
    shifted_keys = self._apply_rotary_pos_emb_single(keys_to_shift, cos, sin, shifted_pos_ids, unsqueeze_dim=self._rope_unsqueeze_dim)
    shifted_queries = self._apply_rotary_pos_emb_single(queries_to_shift, cos_q, sin_q, q_pos_ids, unsqueeze_dim=self._rope_unsqueeze_dim)
    if partial_rotation_size is not None:
        shifted_keys = torch.cat( [shifted_keys, keys_pass], dim=-1)
        shifted_queries = torch.cat( [shifted_queries, queries_pass], dim=-1)
    return shifted_keys, shifted_queries
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
    """Return how many tokens the cache has seen so far.

    `layer_idx` is accepted for interface compatibility but not consulted: the
    single counter `self._seen_tokens` is shared across layers.
    """
    seen_so_far = self._seen_tokens
    return seen_so_far
def get_usable_length(self, layer_idx: int = 0) -> int:
    """Return the number of positions actually stored for `layer_idx` (0 if the layer is not cached yet)."""
    if layer_idx >= len(self.key_cache):
        return 0
    stored_len = self.key_cache[layer_idx].shape[self.k_seq_dim]
    assert stored_len == self.value_cache[layer_idx].shape[self.v_seq_dim], f"`self.key_cache` and `self.value_cache` should have the same length."
    return stored_len
def get_initial_pos_offset(self, layer_idx:int = 0) -> int:
    """Return the largest per-record count of `self.PADDING_ID` tokens in the stored ids of `layer_idx`.

    NOTE(review): every occurrence of `PADDING_ID` in a record is counted, not only a
    leading (left-padding) run; confirm that `PADDING_ID` cannot appear inside real content.
    """
    assert isinstance(self.PADDING_ID, int), f"`self.PADDING_ID` should be correctly set."
    assert len(self.past_tok_ids) > layer_idx, f"`self.past_tok_ids` for layer {layer_idx} must have been properly set."
    stored_ids = self.past_tok_ids[layer_idx]
    assert stored_ids is not None, f"`past_tok_ids` for layer {layer_idx} should not be None"
    pads_per_record = stored_ids.eq(self.PADDING_ID).int().sum(dim=-1)
    return pads_per_record.max().item()
def get_batch_size(self) -> int:
    """Return dim 0 of the first layer's cached value states, after checking that keys and values agree."""
    keys, values = self.key_cache, self.value_cache
    assert keys is not None, f"`self.key_cache` should not be None."
    assert values is not None, f"`self.value_cache` should not be None."
    assert len(keys) > 0, f"`self.key_cache` is empty. No batch size is available."
    assert len(values) > 0, f"self.value_cache is empty. No batch size is available."
    assert len(values) == len(keys), f"self.value_cache and self.key_cache should be at the same length."
    batch = values[0].shape[0]
    assert batch == keys[0].shape[0], f"self.value_cache and self.key_cache should have the same batch size."
    return batch
def get_kv_pair(self, layer_idx: int = None) -> Tuple[torch.Tensor]:
    """Return the cached `(key, value)` pair for `layer_idx`.

    Raises:
        RuntimeError: if the key or value cache has no entry for `layer_idx` yet.
    """
    assert layer_idx is not None, f"`layer_idx` must be given."
    # The layer is available only if BOTH caches are long enough to index it
    # (the original check was inverted, indexing exactly when the entry was missing).
    if (len(self.key_cache) > layer_idx) and (len(self.value_cache) > layer_idx):
        key = self.key_cache[layer_idx]
        value = self.value_cache[layer_idx]
    else:
        raise RuntimeError(f"The KV for layer:{layer_idx} have not been set.")
    return (key, value)
def set_kv_cache(self, kv_pair: Tuple , layer_idx: int ) -> None:
    """Overwrite the stored key and value states of an already-existing layer `layer_idx`."""
    assert len(kv_pair) == 2, f"The length of `kv_pair` must be 2."
    new_key, new_value = kv_pair[0], kv_pair[1]
    self.key_cache[layer_idx] = new_key
    self.value_cache[layer_idx] = new_value
def set_past_tok_ids(self, tok_ids: torch.Tensor, layer_idx:int) -> None:
    """Replace the stored token ids of an already-existing layer `layer_idx` with `tok_ids`."""
    stored = self.past_tok_ids
    stored[layer_idx] = tok_ids
def cat_kv_cache_and_tokids(self, kv_pairs_list: List[Tuple[torch.Tensor]] , tok_ids_list:List[torch.Tensor]) -> Tuple[Union[Tuple[torch.Tensor],torch.Tensor]]:
    """Concatenate several KV pairs and their token ids; returns `(merged_kv_pair, merged_tok_ids)`."""
    merged_kv = self.cat_kv_cache(kv_pairs_list)
    merged_ids = self.cat_token_ids(tok_ids_list)
    return merged_kv, merged_ids
def slice_kv_cache_and_tokids(self, kv_pair:Tuple[torch.Tensor], tok_ids_list:torch.Tensor, start:int, end:int, seq_len:int=None, _CHECK_IDX:bool=True, ) -> Tuple[Union[Tuple[torch.Tensor], torch.Tensor]]:
    """Slice `[start, end)` along the sequence dim from a KV pair and its token ids, returning both."""
    kv_part = self._slice_kv(start, end, kv_pair=kv_pair, seq_len=seq_len, _CHECK_IDX=_CHECK_IDX)
    ids_part = self._slice_tok_ids(start, end, tok_ids_list=tok_ids_list, seq_len=seq_len, _CHECK_IDX=_CHECK_IDX)
    return kv_part, ids_part
def cat_kv_cache(self, kv_pairs_list: List[Tuple[torch.Tensor]] ) -> Tuple[torch.Tensor]:
    """Fold `self._cat_kv` left-to-right over `kv_pairs_list`; a single pair is returned unchanged."""
    assert len(kv_pairs_list) >= 1
    merged = kv_pairs_list[0]
    for following in kv_pairs_list[1:]:
        merged = self._cat_kv(merged, following)
    return merged
def cat_token_ids(self, tok_ids_list:List[torch.Tensor] ) -> torch.Tensor :
    """Join token-id tensors along their last dimension."""
    assert len(tok_ids_list) >= 1
    joined = torch.cat(tok_ids_list, dim=-1)
    return joined
def _cat_kv(self, kv_pair_a:Tuple[torch.Tensor], kv_pair_b:Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
k_a, v_a = kv_pair_a
k_b, v_b = kv_pair_b
cat_k = torch.cat([k_a, k_b], dim=self.k_seq_dim)
cat_v = torch.cat([v_a, v_b], dim=self.v_seq_dim)
return (cat_k, cat_v)
def _slice_kv(self, start:int, end:int, kv_pair: Tuple[torch.Tensor], seq_len:int=None, _CHECK_IDX:bool=True) -> Tuple[torch.Tensor] :
assert kv_pair is not None, f"kv_pair must NOT be None when slicing it."
key_cache = kv_pair[0]
value_cache = kv_pair[1]
if _CHECK_IDX:
assert seq_len is not None, f"seq_len must be given for checking the index for slicing"
start, end = self._CHECK_IDX(start, end, seq_len)
sliced_key_cache = self.k_slice(key_cache, start, end)
sliced_value_cache = self.v_slice(value_cache, start, end)
return ( sliced_key_cache, sliced_value_cache)
def _slice_tok_ids(self, start:int, end:int, tok_ids_list:torch.Tensor , seq_len:int=None, _CHECK_IDX:bool=False) -> torch.Tensor:
assert tok_ids_list is not None, f"tok_ids_list must NOT be None when slicing it."
if _CHECK_IDX:
assert seq_len is not None, f"seq_len must be given for checking the index for slicing"
start, end = self._CHECK_IDX(start, end, seq_len)
sliced_tok_ids = tok_ids_list[:, start:end]
return sliced_tok_ids
def _set_layer_wise_attribute(self, name: str, value: Any, layer_num:int ):
"""Set layer-wise attributes"""
if isinstance(value, int):
setattr(self, name, [value] * layer_num)
elif isinstance(value, (list, tuple)):
assert len(value) == layer_num, f"The length of {name}: {len(value)} must be equal to `layer_num`: {layer_num}"
setattr(self, name, list(value))
else:
raise TypeError(f"{name} must be of the type `int` or `list` but got `{type(value)}`")
def _list_element_add(self, list_a: List, list_b: List, bias: int=0, dtype = int, device = 'cpu') -> List:
"""Element-wise addition between two lists."""
assert len(list_a) == len(list_b), f"The length of `list_a` ({len(list_a)}) must be equal to that of `list_b` ({len(list_b)})."
tensor_c = torch.tensor(list_a, dtype=dtype, device=device) + torch.tensor(list_b, dtype=dtype, device=device) + torch.tensor([bias], dtype=dtype, device=device)
return tensor_c.int().tolist()
def _CHECK_IDX(self, start: int = 0, end: int = 100, seq_len: int = 1000):
assert isinstance(start, int) and isinstance(end, int) and isinstance(seq_len, int), f"`start`, `end`, `seq_len` must be `int`."
assert seq_len>0 , f"`seq_len` must > 0"
if start <0 :
start = start % seq_len
if end < 0 :
end = end % seq_len
assert (start >=0) and (start < seq_len) , f"start:{start}, end:{end}, seq_len:{seq_len}"
assert (end >= 0) and (end <= seq_len) , f"start:{start}, end:{end}, seq_len:{seq_len}"
assert start < end, f"start:{start}, end:{end}, seq_len:{seq_len}"
return start,end
def _CHECK_PARAMS_VALIDITY(self, layer_idx:int, left_padding_offset:int):
    """Assert that the per-layer SepCache budgets for `layer_idx` are consistent.

    Checks that every layer-wise setting has an entry for `layer_idx`, that each
    budget is in range, and that a + s + w + left_padding_offset < c
    (init, separator, local sizes plus left padding must fit within `cache_size`).
    Raises AssertionError on the first violated condition.
    """
    # Every layer-wise list must cover this layer.
    for attr in ("cache_size", "init_cache_size", "sep_cache_size", "max_sep_exidx", "local_size"):
        assert len(getattr(self, attr)) > layer_idx

    c = self.cache_size[layer_idx]
    a = self.init_cache_size[layer_idx]
    s = self.sep_cache_size[layer_idx]
    w = self.local_size[layer_idx]
    max_sep = self.max_sep_exidx[layer_idx]

    assert c > 0 , f"`self.cache_size` for layer:{layer_idx} must be greater than 0"
    assert a >= 0 , f"`self.init_cache_size` for layer:{layer_idx} must be greater than (equal to) 0"
    assert w > 0 , f"`self.local_size` for layer:{layer_idx} must be greater than 0"
    assert s > 0 , f"`self.sep_cache_size` for layer:{layer_idx} must be greater than 0"
    assert max_sep > 0 , f"`self.max_sep_exidx` for layer:{layer_idx} must be greater than 0"
    assert a + s + w + left_padding_offset < c, f"`init_cache_size` ({a}) + `sep_cache_size` ({s}) + `local_size` ({w}) + `left_padding_offset` ({left_padding_offset}) for layer {layer_idx} should be less than `cache_size`:({c}) for layer {layer_idx}, i.e., a + s + w + (left_padding_offset) < c. Please increase `cache_size` if applicable."
def _rotate_half(self, x):
    """Rotate the last dimension by half: `[x1, x2] -> [-x2, x1]` (RoPE helper)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def _apply_rotary_pos_emb_single(self, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to a single tensor (used here for key states).

    Args:
        k (`torch.Tensor`): The tensor to rotate (e.g. the key tensor).
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The dimension along which `cos` and `sin` are unsqueezed so that they broadcast
            against `k`. For example, if `cos`/`sin` have shape [batch_size, seq_len, head_dim]
            and `k` has shape [batch_size, heads, seq_len, head_dim], use 1; if `k` has shape
            [batch_size, seq_len, heads, head_dim], use 2.
    Returns:
        `torch.Tensor`: `k * cos + rotate_half(k) * sin`, i.e. `k` rotated with RoPE.
        Only this one tensor is returned (there is no query tensor involved).
    """
    cos = cos.unsqueeze(unsqueeze_dim)  # e.g. batch x seq_len x dim --> batch x 1 x seq_len x dim
    sin = sin.unsqueeze(unsqueeze_dim)
    k_embed = (k * cos) + (self._rotate_half(k) * sin)
    return k_embed
def _get_naive_shifted_cos_sin(self, x: torch.Tensor, position_ids: torch.Tensor=None, seq_len=None):
    """Compute RoPE `cos`/`sin` tables for `position_ids` using `self.inv_freq`.

    `position_ids` is indexed as a 2-D `[batch, seq]` tensor. Only `x.dtype` is read
    from `x` (the output dtype); `seq_len` is unused. The result has shape
    `[batch, seq, 2 * len(self.inv_freq)]`. Both tables are also cached on
    `self._cos_cached` / `self._sin_cached` for backwards compatibility.
    """
    batch = position_ids.shape[0]
    freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1)  # [batch, n_freq, 1]
    pos_row = position_ids[:, None, :].float()                            # [batch, 1, seq]
    angles = (freq_col @ pos_row).transpose(1, 2)                         # [batch, seq, n_freq]
    emb = torch.cat((angles, angles), dim=-1)
    cos, sin = emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
    # backwards compatibility
    self._cos_cached, self._sin_cached = cos, sin
    return cos, sin
def _get_scaled_shifted_cos_sin(self, x, position_ids, seq_len=None):
    """Linear-scaling RoPE: divide position ids by `self.scaling_factor`, then compute cos/sin.

    Differs from the plain variant only in that scaling factor applied to the position ids.
    """
    scaled_positions = position_ids.float() / self.scaling_factor
    return self._get_naive_shifted_cos_sin(x, scaled_positions, seq_len)
def _get_dynamicNTK_scaling_shifted_cos_sin(self, x, position_ids, seq_len=None):
    """Dynamic-NTK RoPE: recompute `inv_freq` once positions exceed `max_position_embeddings`.

    The incoming `seq_len` argument is ignored; it is recomputed as `max(position_ids) + 1`,
    and that value is forwarded to `_get_naive_shifted_cos_sin`.
    """
    seq_len = torch.max(position_ids) + 1
    if seq_len > self.max_position_embeddings:
        ratio = self.scaling_factor * seq_len / self.max_position_embeddings
        base = self.base * (ratio - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO: this may break with compilation
    return self._get_naive_shifted_cos_sin(x, position_ids, seq_len)
def _update_kv_ratio(self, kv_len_cmp:int, kv_len_ori:int, layer_idx: int=0) -> None:
    """Accumulate KV statistics for `layer_idx` (for statistics and debugging).

    `self._kept_kv_ratio[layer_idx]` holds a running `(kept, seen)` pair. `kv_len_cmp`
    is added to the first element and `kv_len_ori` to the second.
    """
    # Pad with zero entries so the counts always land at index `layer_idx`,
    # even if a layer is reported before its predecessors. Appending blindly
    # would store them at the wrong index and later merge them into another layer.
    while len(self._kept_kv_ratio) <= layer_idx:
        self._kept_kv_ratio.append((0, 0))
    old_kv_len_cmp = self._kept_kv_ratio[layer_idx][0]
    old_kv_len_ori = self._kept_kv_ratio[layer_idx][1]
    self._kept_kv_ratio[layer_idx] = (old_kv_len_cmp + kv_len_cmp, old_kv_len_ori + kv_len_ori)
def _print_kv_ratio(self, layer_idx : int, LAYER_WISE: bool = False):
    """Periodically print the accumulated [kept, seen] KV counts of `layer_idx`.

    Each call increments `self._print_kv_ratio_count`. A line is printed every
    `print_KV_inside_per_steps` calls when `LAYER_WISE` is True, otherwise every
    `print_KV_inside_per_steps * layer_num` calls.
    """
    self._print_kv_ratio_count += 1
    if LAYER_WISE:
        period = self.print_KV_inside_per_steps
    else:
        period = self.print_KV_inside_per_steps * self.layer_num
    if self._print_kv_ratio_count % period != 0:
        return
    stats = self._kept_kv_ratio[layer_idx]
    kept, seen = stats[0], stats[1]
    print(f"######################## [Kept Tokens, Seen Tokens] : {stats}, Ratio: {(kept + 1e-6) / (seen + 1e-6):.4f} ########################")
@classmethod ## Deprecated
def from_legacy_cache(cls,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        ## For SepLLM
        init_cache_size: Union[int, List] = 4,
        sep_cache_size: Union[int, List] = 64,
        local_size: Union[int, List]=256,
        cache_size: Union[int, List]=512,
        SEP_ACCUMULATION: bool = True,
        USE_MAX_SEP_CACHE: bool = False,
        SEP_PADDING_IN_BATCH: bool = False,
        separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided. set it to `[-1]` to degrade SepCache to StreamingLLM's SinkCache
        PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.
        ## For inheritance & initialization states
        past_tok_ids: List[torch.Tensor] = None, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
        key_cache: List[torch.Tensor] = None,
        value_cache: List[torch.Tensor] = None,
        ## For debugging
        PRINT_KV_RATIO_INSIDE: bool = False,
        print_KV_inside_per_steps: int = 1000,
        _seen_tokens: int = 0,
        _kept_kv_ratio: List[Tuple[int]] = None,
        ### For positional encoding shifting
        APPLY_PE_SHIFT: bool = False,
        APPLY_PES_INSIDE: bool = True,
        _shifted_position_ids: List[torch.Tensor] = None,
        _rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
        _rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
        pe_scaling_factor:float = 1.0,
        pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
        max_position_embeddings: int = 8192,
        base: int=10000, ## The base for RoPE.
        ## For basic transformer architecture
        k_seq_dim: int=2, ## The dimension for seq_len in key tensors
        v_seq_dim: int=2, ## The dimension for seq_len in value tensors
        layer_num: int = None, ## required for initialization
        model_type: str = None, ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
        device = None
        ) -> "SepCache":
    """Deprecated: Converts a cache in the legacy cache format into `SepCache`.

    Only `past_key_values=None` or an empty `past_key_values` is accepted. In that case this
    is equivalent to calling `cls(...)` with all the remaining keyword arguments, which are
    forwarded unchanged.

    Raises:
        AssertionError: if `past_key_values` is non-empty.
    """
    # A non-empty legacy cache cannot be converted any more. An empty one is treated like `None`,
    # so no legacy key/value/token-id unpacking is needed (that branch was unreachable).
    if past_key_values is not None:
        assert len(past_key_values)==0, f"`from_legacy_cache` function is deprecated. You can only use it when `past_key_values=None` or `past_key_values` is empty, in which case, `from_legacy_cache` is equivalent to the `__init__` function."
    cache = cls(
        ## For SepLLM
        init_cache_size = init_cache_size,
        sep_cache_size = sep_cache_size,
        local_size = local_size,
        cache_size = cache_size,
        SEP_ACCUMULATION = SEP_ACCUMULATION,
        USE_MAX_SEP_CACHE = USE_MAX_SEP_CACHE,
        SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH,
        separator_token_ids = separator_token_ids,
        PADDING_ID = PADDING_ID,
        ## For inheritance & initialization states
        past_tok_ids = past_tok_ids, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache
        key_cache = key_cache,
        value_cache = value_cache,
        ## For debugging
        PRINT_KV_RATIO_INSIDE = PRINT_KV_RATIO_INSIDE,
        print_KV_inside_per_steps = print_KV_inside_per_steps,
        _seen_tokens = _seen_tokens,
        _kept_kv_ratio = _kept_kv_ratio,
        ### For positional encoding shifting
        APPLY_PE_SHIFT = APPLY_PE_SHIFT,
        APPLY_PES_INSIDE = APPLY_PES_INSIDE,
        _shifted_position_ids = _shifted_position_ids,
        _rope_unsqueeze_dim = _rope_unsqueeze_dim,
        _rope_seq_dim = _rope_seq_dim,
        pe_scaling_factor = pe_scaling_factor,
        pe_dim = pe_dim,
        max_position_embeddings = max_position_embeddings,
        base = base,
        ## For basic transformer architecture
        k_seq_dim = k_seq_dim,
        v_seq_dim = v_seq_dim,
        layer_num = layer_num,
        model_type = model_type,
        device = device,
    )
    return cache
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]]: ## Deprecated
    """Deprecated: Converts the `SepCache` instance into the legacy cache format, i.e., tuple.

    Produces one `(key, value, past_tok_ids, seen_tokens)` tuple per layer in
    `self.key_cache`, and prints a warning to stdout first.
    """
    print(">>>>>>>>>>>Warnings: Please try to avoid using this deprecated `to_legacy_cache` function since it will drop many useful parameters or states in SepCache.<<<<<<<<<<<")
    return tuple(
        (self.key_cache[idx], self.value_cache[idx], self.past_tok_ids[idx], self._seen_tokens)
        for idx in range(len(self.key_cache))
    )
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
    """Return the `(key_cache, value_cache)` pair stored for `layer_idx`.

    Raises KeyError when `layer_idx >= len(self)`. Note that, despite the annotation,
    a 2-tuple is returned.
    """
    if layer_idx >= len(self):
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
    return (self.key_cache[layer_idx], self.value_cache[layer_idx])
def __iter__(self):
    """
    Yield `(key_cache[i], value_cache[i])` for every layer index `i` in `range(len(self))`,
    supporting backwards-compatible `for x in past_key_value:` iteration.
    """
    yield from (
        (self.key_cache[idx], self.value_cache[idx])
        for idx in range(len(self))
    )
def __len__(self):
    """
    Number of layers held by the cache (the length of `self.key_cache`), or 0 when
    `self.key_cache` is None. Supports backwards-compatible `len(past_key_value)`.
    """
    return 0 if self.key_cache is None else len(self.key_cache)
@property
def seen_tokens(self):
    """The `_seen_tokens` counter, or None if the instance has no such attribute."""
    return getattr(self, "_seen_tokens", None)
class KVUsageCounter:
    """Aggregate kept/seen KV counts over all layers of SepCache objects across iterations."""

    def __init__(self, PRINT_KV_per_ITERs:int = 100):
        """
        For detailed usage instructions, please refer to sepllm.github.io
        """
        # Running (kept, seen) KV totals summed over all layers and all accumulated caches.
        self._total_kept_kv_ratio = (0, 0)
        # Number of `accumulate_historical_kv_stats` calls; drives `WHETHER_2_PRINT`.
        self._printing_counter = 0
        self.PRINT_KV_per_ITERs = PRINT_KV_per_ITERs

    def accumulate_historical_kv_stats(self, cache: "SepCache" = None) -> None:
        """Add every layer's `(kept, seen)` pair from `cache._kept_kv_ratio` to the running totals.

        Raises:
            AssertionError: if `cache` is missing, lacks `_kept_kv_ratio`/`layer_num`, or the
                number of per-layer entries differs from `cache.layer_num`.
        """
        assert cache is not None, f"The KV cache object (of the class SepCache) must be given."
        assert hasattr(cache, "_kept_kv_ratio"), f"The cache object must have the attribute _kept_kv_ratio."
        assert hasattr(cache, "layer_num"), f"The cache object must have the attribute layer_num."
        # Report the *length* of the per-layer list (the old message dumped the whole list).
        assert len(cache._kept_kv_ratio) == cache.layer_num, f"The length ({len(cache._kept_kv_ratio)}) of cache object's _kept_kv_ratio attribute must be equal to layer_num ({cache.layer_num})"
        kept_total, seen_total = self._total_kept_kv_ratio
        for ly in range(cache.layer_num):
            kept_total += cache._kept_kv_ratio[ly][0]
            seen_total += cache._kept_kv_ratio[ly][1]
        self._total_kept_kv_ratio = (kept_total, seen_total)
        self._printing_counter += 1

    def WHETHER_2_PRINT(self) -> bool:
        """True every `PRINT_KV_per_ITERs` accumulations."""
        return (self._printing_counter % self.PRINT_KV_per_ITERs) == 0

    def print_KV_ratio(self) -> None:
        """Print the accumulated [kept, seen] KV totals and their ratio."""
        print(f"######################## The KVs for ALL layers: (KV number kept for predicting current token)/(Total seen KV number) ########################")
        print(f"########################>>>>>>>>>>> [Kept Tokens, Seen Tokens] : {self._total_kept_kv_ratio}, Ratio: { (self._total_kept_kv_ratio[0]+1e-6) / (self._total_kept_kv_ratio[1]+1e-6):.4f} <<<<<<<<<<<<##########################")
        print(f"######################## -------------------------------------------------------------------------------------------- ########################")
#############################################################End################################################################
##################################################### Monkey Patch Utils #######################################################
def get_full_class_import_path(obj):
    """Return `"<module>.<qualname>"` for the class of `obj`.

    For nested classes, the dots inside the qualified name are replaced with
    underscores (e.g. `Outer.Inner` -> `Outer_Inner`).
    """
    klass = obj.__class__
    flat_qualname = klass.__qualname__.replace('.', '_')
    return f"{klass.__module__}.{flat_qualname}"
def get_importable_class_path(obj):
    """Get the directly importable class path of `obj`'s class.

    Returns:
        - the bare qualname for built-in types (e.g. `"int"`);
        - `"<dynamic class ...>"` when the class has no usable `__module__`;
        - `"<module>.<qualname>"` when that path resolves. For nested classes whose path
          cannot be resolved by import + attribute lookup, the dots in the qualname are
          replaced with underscores instead.
    """
    import importlib
    cls = obj.__class__
    # Read `__module__` defensively. The previous `hasattr` guard ran only after
    # `cls.__module__` had already been accessed, so it could never trigger.
    module = getattr(cls, '__module__', None)
    qualname = cls.__qualname__
    # Handle dynamically generated classes without a module
    if module is None:
        return f"<dynamic class {qualname}>"
    # Handle built-in types
    if module == 'builtins':
        return qualname
    # Handle nested classes: verify the dotted path can actually be followed.
    if '.' in qualname:
        try:
            current = importlib.import_module(module)
            for part in qualname.split('.'):
                current = getattr(current, part)
        except (ImportError, AttributeError):
            # Fallback: use underscore connection
            return f"{module}.{qualname.replace('.', '_')}"
    return f"{module}.{qualname}"
def monkey_patch_by_class_path(model, new_forward):
    """Replace `forward` on the class of `model`, located through its import path.

    On the first patch, the class's existing `forward` is kept as `_original_forward`.
    The instance `model` also gets `forward` re-bound as an instance attribute. Returns a
    status string. Import/lookup failures (ImportError, AttributeError, ValueError) are
    reported in that string rather than raised.
    """
    class_path = get_importable_class_path(model)
    try:
        import importlib
        module_path, class_name = class_path.rsplit('.', 1)
        target_class = getattr(importlib.import_module(module_path), class_name)
        # Keep only the very first original `forward`, so repeated patching stays reversible.
        if not hasattr(target_class, '_original_forward'):
            target_class._original_forward = target_class.forward
        target_class.forward = new_forward
        # Re-bind on the instance too, in case `forward` was already bound there.
        model.forward = types.MethodType(target_class.forward, model)
    except (ImportError, AttributeError, ValueError) as err:
        return f"Patch Failed: {str(err)}"
    return f"Successful Monkey Patch: {class_path}.forward"
def find_inner_attribute(obj, attr_name_list: List[str], default_type = PreTrainedModel ):
    """Locate an inner sub-object of `obj`.

    First returns the attribute named by the first entry of `attr_name_list` that
    `obj` has. Otherwise returns the first attribute (in `dir(obj)` order) that is an
    instance of `default_type`. Raises AttributeError when neither search succeeds.
    """
    matched_name = next((name for name in attr_name_list if hasattr(obj, name)), None)
    if matched_name is not None:
        return getattr(obj, matched_name)
    for name in dir(obj):
        candidate = getattr(obj, name)
        if isinstance(candidate, default_type):
            return candidate
    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any name in {attr_name_list} or whose type is {default_type}.")
def find_attribute_name(obj, name_pattern_list: List[str], exclude_pattern_list: List[str], match_type = nn.Module):
    """Return the first attribute of `obj` (in `dir(obj)` order) that looks like the wanted sub-module.

    An attribute qualifies when it is an instance of `match_type` and either its class
    name or its attribute name (case-insensitively):
      - contains at least one pattern from `name_pattern_list`, and
      - contains NONE of the patterns from `exclude_pattern_list`.
    For example, with excludes ["norm", "layer"], `attention_norm` is rejected.
    (Previously it was accepted because it lacks "layer", and an empty exclude list
    could never match at all.)

    Raises:
        AttributeError: if no attribute qualifies.
    """
    patterns = [p.lower() for p in name_pattern_list]
    excludes = [p.lower() for p in exclude_pattern_list]

    def _matches(text: str) -> bool:
        # True when `text` hits some wanted pattern and avoids every excluded pattern.
        text = text.lower()
        return any(p in text for p in patterns) and not any(ex in text for ex in excludes)

    for attr_name in dir(obj):
        attr_value = getattr(obj, attr_name)
        if not isinstance(attr_value, match_type):
            continue
        if _matches(attr_value.__class__.__name__) or _matches(attr_name):
            return attr_value
    raise AttributeError(f"In the {obj} object, there is no attribute whose name matches any pattern in {name_pattern_list} and excludes any pattern in {exclude_pattern_list}, and whose type is {match_type}.")
def monkey_patching(model_obj,
        model_atten_forward , ## The `forward` function used to patch.
        possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"] , # Candidate names of the inner model attribute in `XXXForCausalLM`.
        possible_layers_names: List[str] = ["layers", "h" ], # Candidate names of the decoder-layer list in `XXXModel`.
        atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"], # Name patterns identifying the self-attention module in `XXXDecoderLayer`.
        atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"], # Name patterns to exclude (e.g. "post_attention_norm" must not be patched).
        verbose = True):
    """
    Patch the `forward` of every decoder layer's self-attention (`XXXAttention`) class with `model_atten_forward`.

    Steps:
      - replace `GenerationMixin._validate_model_kwargs` so extra `sepllm_kwargs` pass the check;
      - find the inner model (a `PreTrainedModel`) and its decoder layers (an `nn.ModuleList`);
      - for each layer, find the attention sub-module by name patterns and patch its class.
    The default lists are only read, never mutated. Returns the decoder-layer container.
    When `verbose` is True, a per-layer status line is printed.
    """
    ## To avoid the argument check failure, i.e., let "sepllm_kwargs" pass the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, PreTrainedModel)
    model_layers = find_inner_attribute(inner_model, possible_layers_names, nn.ModuleList)
    for layer_no, decoder_layer in enumerate(model_layers):
        attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        status = monkey_patch_by_class_path(attn_module, model_atten_forward)
        if verbose:
            print(f"For Layer {layer_no}'s `{get_importable_class_path(decoder_layer)}`: {status}")
    return model_layers
#############################################################End################################################################
def str2list(s : str, sep = ';'):
    """Parse a `sep`-separated string of integers into a list of `int`.

    Non-string inputs (e.g. an existing list, or None) are returned unchanged.
    Empty or whitespace-only pieces, such as those produced by a trailing separator,
    are skipped rather than raising ValueError (`"1;2;"` -> `[1, 2]`).
    """
    if not isinstance(s, str):
        return s
    return [int(piece) for piece in s.split(sep) if piece.strip()]
def generate(model,
        ## For SepCache
        init_cache_size: Union[int, List] = 4,
        sep_cache_size: Union[int, List] = 128,
        local_size: Union[int, List]=256,
        cache_size: Union[int, List]=512,
        SEP_ACCUMULATION: bool = True,
        USE_MAX_SEP_CACHE: bool = False,
        SEP_PADDING_IN_BATCH: bool = False,
        separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided.
        PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.
        ## For inheritance & initialization states
        past_tok_ids: List[torch.Tensor] = None, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
        key_cache: List[torch.Tensor] = None,
        value_cache: List[torch.Tensor] = None,
        ## For debugging
        PRINT_KV_RATIO_INSIDE: bool = False,
        print_KV_inside_per_steps: int = 1000,
        _seen_tokens: int = 0,
        _kept_kv_ratio: List[Tuple[int]] = None,
        ### For positional encoding shifting
        APPLY_PE_SHIFT: bool = False,
        APPLY_PES_INSIDE: bool = False,
        _shifted_position_ids: List[torch.Tensor] = None,
        _rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
        _rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
        pe_scaling_factor:float = 1.0,
        pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
        max_position_embeddings: int = 8192,
        base: int=10000, ## The base for RoPE.
        ## For basic transformer architecture
        k_seq_dim: int=2, ## The dimension for seq_len in key tensors
        v_seq_dim: int=2, ## The dimension for seq_len in value tensors
        layer_num: int = None, ## Ignored: the layer count is taken from the patched model.
        model_type: str = 'llama', ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
        device = None,
        ## For verbosity of monkey patching
        monkey_patch_verbose: bool = False,
        **kwargs
        ):
    """Custom generate function for SepCache.
    A cache as described in the [SepLLM paper - ICML 2025](https://arxiv.org/abs/2412.12094). In the training phase,
    SepLLM condenses the segment information into the KV of the separator that divides the segment. In the inference phase, the
    corresponding SepCache only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.
    It stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    Frequently-Used Parameters:
        `init_cache_size: Union[int, List]`:
            The maximum number of KVs to be stored for initial tokens.
            In the paper, the hyperparameter `a` is an abbreviated alias for `self.init_cache_size`.
        `sep_cache_size: Union[int, List]`:
            The maximum number of KVs to be stored for separator tokens.
            In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
        `local_size: Union[int, List]`:
            The maximum number of KVs to be stored for local tokens (i.e., sliding window).
            In the paper, the hyperparameter `w` is an abbreviated alias for `self.local_size`.
        `cache_size: Union[int, List]`:
            The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache.
            In the paper, the hyperparameter `c` is an abbreviated alias for `self.cache_size`.
        Concerning these four parameters above:
            When a list is passed (its length must be `layer_num`), it represents different values for each layer.
            When an integer is passed, it means the setting is the same for all layers.
        `USE_MAX_SEP_CACHE: bool`:
            If True, it means we only keep at most `self.sep_cache_size` separators' KVs.
            If the number exceeds this limit, older separator's KVs will be discarded, keeping only the most recent `self.sep_cache_size` KVs.
            In the paper, the hyperparameter `s` is an abbreviated alias for `self.sep_cache_size`.
        `separator_token_ids: List[int]`:
            The token ids of the separator tokens for the current model's tokenizer.
            A `;`-separated string of ids (e.g. "1;2;3") is also accepted and is parsed with `str2list`.
            We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you
            to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).
        `PADDING_ID: int`:
            The token id of the padding token. You can just set `PADDING_ID` to the id of "<|endoftext|>" token of the tokenizer for the pretrained model.
        `layer_num: int`:
            Ignored by this function: the number of layers is taken from the patched model (`len(model_layers)`).
        `**kwargs`:
            Forwarded to `model.generate`. The prompt token ids must be available as `input_ids`, or as `inputs`
            (e.g. when the prompt is passed positionally to `model.generate`).
    Important Note:
        When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers), and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
        However, you must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`.
        Here, `left_padding_offset` denotes the number of padding tokens in the record with the largest left paddings within a runtime batch. `left_padding_offset` can only be determined at runtime.
        To guarantee the above inequality always holds during runtime, when setting, you can intentionally create a sufficient margin between both sides of the following inequality:
            `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size`, i.e., `a`+`s`+`w`<`c` in the [SepLLM paper - ICML 2025]
        to leave room for `left_padding_offset`.
    Please refer to the `__init__` function's comments for more details on the parameters.
    Returns:
        Whatever `model.generate` returns when run with the prepared `SepCache`.
    Raises:
        ValueError: if an unsupported generation argument is set, the model is an encoder-decoder model,
            no prompt ids are given (`input_ids`/`inputs`), or `past_key_values` is not a `SepCache`.
    Example:
        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> from .custom_generate.generate import SepCache
        >>> import torch
        >>> from huggingface_hub import login
        >>> login("hf_xxxXXXxxx")
        >>> def to_cuda(a_dict: dict) -> dict:
        >>>    new_dict = {}
        >>>    for k,v in a_dict.items():
        >>>        if isinstance(v, torch.Tensor):
        >>>            new_dict[k] = v.cuda()
        >>>        else:
        >>>            new_dict[k] = v
        >>>    return new_dict
        >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", attn_implementation="flash_attention_2", device_map="cuda:0")
        >>> model.bfloat16().cuda()
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
        >>> inputs = tokenizer(text="My name is Llama 3", return_tensors="pt")
        >>> inputs = to_cuda(inputs)
        >>> # Prepare a cache and pass it to model's forward; `layer_num` is the number of layers for the pretrained model.
        >>> past_key_values = SepCache(init_cache_size=4, sep_cache_size=128, local_size=256, cache_size=512, layer_num=32, USE_MAX_SEP_CACHE=True, model_type='llama')
        >>> # `separator_token_ids` and `PADDING_ID` must also be provided if you are not using `model_type='llama'` like this demo.
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access SepCache filled with keys/values
        SepCache()
        ```
        ```python
        >>> ## When using the `update` function of SepCache to update the keys/values and the past token ids (necessary in SepCache), the current `input_ids` must also be provided.
        >>> key_states, value_states = past_key_values.update(
                key_states = key_states,
                value_states = value_states,
                input_ids = input_ids,
                layer_idx = layer_idx,
                PREFILLING_FLAG = q_len > 1, ## `q_len` is the sequence length of the current `query_states`
                )
        ```
    For detailed usage instructions, please refer to https://github.com/HKUDS/SepLLM
    """
    # 0. Monkey Patching towards the "forward" function of `XXXAttention` class in order to pass `input_ids` to the `update` function of `SepCache` when calling it.
    model_layers = monkey_patching(model, model_atten_forward=llama_atten_forward, verbose=monkey_patch_verbose)
    separator_token_ids = str2list(separator_token_ids)
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            # = and not (match global default or match model-specific default)
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )
    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")
    # 1.c. compatibility with transformers>=4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)
    # 1.d. `input_ids` must be passed to the `update` function of `SepCache` when calling it.
    # A prompt given positionally to `model.generate` arrives here as `inputs`, not `input_ids`,
    # so fall back to it instead of failing with a bare KeyError.
    input_ids = kwargs.get("input_ids")
    if input_ids is None:
        input_ids = kwargs.get("inputs")
    if input_ids is None:
        raise ValueError("This custom generate function requires the prompt token ids, passed as `input_ids` (or `inputs`).")
    sepllm_kwargs = {}
    sepllm_kwargs["input_ids"] = input_ids
    kwargs["sepllm_kwargs"] = sepllm_kwargs
    # 2. Generate with SepCache
    # 2.a. prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SepCache(
            ## For SepCache
            init_cache_size = init_cache_size,
            sep_cache_size = sep_cache_size,
            local_size = local_size,
            cache_size = cache_size,
            SEP_ACCUMULATION = SEP_ACCUMULATION,
            USE_MAX_SEP_CACHE = USE_MAX_SEP_CACHE,
            SEP_PADDING_IN_BATCH = SEP_PADDING_IN_BATCH,
            separator_token_ids = separator_token_ids, ## required for initialization if `model_type` is not provided.
            PADDING_ID = PADDING_ID, ## required for initialization if `model_type` is not provided.
            ## For inheritance & initialization states
            past_tok_ids = past_tok_ids, ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
            key_cache = key_cache,
            value_cache = value_cache,
            ## For debugging
            PRINT_KV_RATIO_INSIDE = PRINT_KV_RATIO_INSIDE,
            print_KV_inside_per_steps = print_KV_inside_per_steps,
            _seen_tokens = _seen_tokens,
            _kept_kv_ratio = _kept_kv_ratio,
            ### For positional encoding shifting
            APPLY_PE_SHIFT = APPLY_PE_SHIFT,
            APPLY_PES_INSIDE = APPLY_PES_INSIDE,
            _shifted_position_ids = _shifted_position_ids,
            _rope_unsqueeze_dim = _rope_unsqueeze_dim, ## The unsqueeze_dim when applying RoPE.
            _rope_seq_dim =_rope_seq_dim, ## The seq_len dimension for the `cos` or `sin` tensors.
            pe_scaling_factor = pe_scaling_factor,
            pe_dim = pe_dim, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this, i.e., model.config.hidden_size // model.config.num_attention_heads
            max_position_embeddings = max_position_embeddings, # i.e., model.config.max_position_embeddings
            base = base, ## The base for RoPE.
            ## For basic transformer architecture
            k_seq_dim = k_seq_dim, ## The dimension for seq_len in key tensors
            v_seq_dim = v_seq_dim, ## The dimension for seq_len in value tensors
            layer_num = len(model_layers), ## required for initialization. model.config.num_hidden_layers
            model_type = model_type, ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
            device = device,
        )
    elif not isinstance(past_key_values, SepCache):
        raise ValueError(f"`past_key_values` must be a `SepCache` instance, got a {type(past_key_values)} instance")
    # 2.b. generate with the cache
    kwargs["use_cache"] = True
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values)
    return generation_outputs
|