File size: 25,804 Bytes
d1c8b5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 |
# yapf: disable
# ruff: noqa: E501
# coding=utf-8
# Copied from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/instella/configuration_instella.py
"""Instella configuration."""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class InstellaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InstellaModel`]. It is used to instantiate an Instella
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the [allenai/Instella-7B-1124-hf](https://huggingface.co/allenai/Instella-7B-1124-hf).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50304):
            Vocabulary size of the Instella model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`InstellaModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 11008):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 1):
            Padding token id.
        bos_token_id (`int`, *optional*):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 50279):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a number (`int` or `float`) greater than 1.
            The expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't
            update `max_position_embeddings` to the expected new maximum. See the following thread for more
            information on how these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
            experimental feature, subject to breaking API changes in future versions.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.

    ```python
    >>> from transformers import InstellaModel, InstellaConfig

    >>> # Initializing a Instella 7B style configuration
    >>> configuration = InstellaConfig()

    >>> # Initializing a model from the Instella 7B style configuration
    >>> model = InstellaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "instella"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50304,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=1,
        bos_token_id=None,
        eos_token_id=50279,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=1e-5,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rms_norm_eps = rms_norm_eps

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        Raises:
            ValueError: If `rope_scaling` is not a two-entry dict, if its `type` is not `"linear"` or `"dynamic"`,
                or if its `factor` is not a real number strictly greater than 1.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type")
        rope_scaling_factor = self.rope_scaling.get("factor")
        # A missing type is `None`, which is not a member of the allowed set either.
        if rope_scaling_type not in ("linear", "dynamic"):
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # Accept integers as well as floats (a JSON config may carry `"factor": 2`), but not booleans, which are
        # `int` subclasses in Python.
        factor_is_number = isinstance(rope_scaling_factor, (int, float)) and not isinstance(rope_scaling_factor, bool)
        if not factor_is_number or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a number > 1, got {rope_scaling_factor}")
from functools import partial
# from typing import Iterable, List, Optional, Tuple, Union
from typing import Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
# from vllm.attention import Attention, AttentionMetadata
from vllm.attention import Attention
from vllm.config import VllmConfig
# from vllm.config import CacheConfig
# from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.distributed.utils import split_tensor_along_last_dim
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.utils import (
is_pp_missing_parameter, make_empty_intermediate_tensors_factory,
make_layers)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
class InstellaAttention(nn.Module):
    """
    Self-attention block with RMS normalization applied to the query and key
    projections before the rotary embedding.

    The forward pass runs: fused QKV projection -> split into q/k/v ->
    q/k RMSNorm -> rotary embedding -> attention -> output projection.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; ``model_config.hf_config``,
            ``quant_config`` and ``cache_config`` are read from it.
        :param prefix: Module-name prefix forwarded to the sub-layers.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # assert isinstance(self.config, InstellaConfig)

        hidden_size = self.config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = self.config.num_attention_heads

        assert hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = (self.config.num_key_value_heads
                                   or self.total_num_heads)
        if self.total_num_kv_heads >= self.tp_size:
            # KV heads are partitioned across the tensor-parallel ranks.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Fewer KV heads than ranks: each rank keeps at least one head.
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = self.config.max_position_embeddings
        self.rope_theta = self.config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.tp_rank = get_tensor_model_parallel_rank()
        # The norms are sized for the full (un-sharded) projections; see
        # `_apply_qk_norm` for how sharded inputs are handled.
        self.k_norm = RMSNorm(
            self.total_num_kv_heads * self.head_dim,
            eps=self.config.rms_norm_eps,
        )
        self.q_norm = RMSNorm(self.config.hidden_size,
                              eps=self.config.rms_norm_eps)

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,  # type: ignore
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            prefix=prefix,
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the q/k RMSNorm layers.

        With tensor parallelism, q and k are first all-gathered so the norms
        run on the gathered tensors; afterwards the result is split again
        along the last dimension and this rank's partition is kept.

        :returns: The normalized ``(q, k)`` pair for this rank.
        """
        # NOTE(review): when tp_size > total_num_kv_heads each rank holds one
        # KV head, so the gathered k is presumably wider than k_norm's
        # weight -- confirm whether that configuration is meant to be
        # supported.
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm.forward_native(q)
        k = self.k_norm.forward_native(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param positions: Position tensor forwarded to the rotary embedding.
        :param hidden_states: Input to the fused QKV projection.
        :returns: Output of the attention output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split by explicit widths: with grouped-/multi-query attention the
        # q slice is wider than the k and v slices, so three equal chunks
        # would cut the tensor at the wrong offsets. For plain multi-head
        # attention (q_size == kv_size) this is identical to chunk(3).
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                            dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
class InstellaMLP(nn.Module):
    """
    Gated feed-forward block: a merged gate/up projection, a ``SiluAndMul``
    activation, and a down projection back to the hidden size.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; the hidden and intermediate
            sizes come from ``model_config.hf_config``.
        :param prefix: Module-name prefix forwarded to the projections.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        # assert isinstance(hf_config, InstellaConfig)
        d_model = hf_config.hidden_size
        d_ff = hf_config.intermediate_size

        # Input side: gate and up projections merged into one linear layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            d_model,
            [d_ff, d_ff],
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Non-linearity applied to the merged gate/up output.
        self.act_fn = SiluAndMul()

        # Output side: project back to the hidden size.
        self.down_proj = RowParallelLinear(
            d_ff,
            d_model,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run ``down_proj(act_fn(gate_up_proj(x)))`` and return the result."""
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected
class InstellaDecoderLayer(nn.Module):
    """
    Pre-norm transformer block.

    Computes ``h = x + Attention(RMSNorm(x))`` followed by
    ``h + MLP(RMSNorm(h))``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration shared with the sub-blocks.
        :param prefix: Module-name prefix; ``.self_attn`` and ``.mlp`` are
            appended for the two sub-blocks.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        # assert isinstance(hf_config, InstellaConfig)

        # Sub-blocks (registered in this order: attention, then MLP).
        self.self_attn = InstellaAttention(
            vllm_config=vllm_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = InstellaMLP(
            vllm_config=vllm_config,
            prefix=f"{prefix}.mlp",
        )

        # One RMSNorm in front of each sub-block.
        norm_width = hf_config.hidden_size
        norm_eps = hf_config.rms_norm_eps
        self.pre_attention_layernorm = RMSNorm(norm_width, eps=norm_eps)
        self.pre_feedforward_layernorm = RMSNorm(norm_width, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param positions: Position tensor passed through to the attention.
        :param hidden_states: Block input.
        :returns: Block output after both residual connections.
        """
        # Attention sub-block with skip connection.
        attn_input = self.pre_attention_layernorm(hidden_states)
        attn_out = self.self_attn(positions, attn_input)
        after_attn = attn_out + hidden_states

        # Feed-forward sub-block with skip connection.
        mlp_input = self.pre_feedforward_layernorm(after_attn)
        mlp_out = self.mlp(mlp_input)
        return after_attn + mlp_out
class InstellaModel(nn.Module):
    """
    Decoder stack: token embedding, a pipeline-parallel slice of
    ``InstellaDecoderLayer`` blocks, and a final RMSNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; sizes are read from
            ``model_config.hf_config``.
        :param prefix: Module-name prefix forwarded to the sub-modules.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        # assert isinstance(self.config, InstellaConfig)
        width = self.config.hidden_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            width,
            prefix=f"{prefix}.embed_tokens",
        )

        def build_layer(prefix: str) -> "InstellaDecoderLayer":
            # Factory handed to make_layers; builds one decoder block.
            return InstellaDecoderLayer(vllm_config=vllm_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(width, eps=self.config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"], width))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Run the decoder layers owned by this pipeline rank.

        :param input_ids: Token ids; embedded on the first pipeline rank
            unless ``inputs_embeds`` is given.
        :param positions: Position tensor forwarded to every layer.
        :param intermediate_tensors: Carries ``"hidden_states"`` from the
            previous pipeline rank; required on non-first ranks.
        :param inputs_embeds: Optional pre-computed embeddings that replace
            the embedding lookup on the first rank.
        :returns: ``IntermediateTensors`` on non-last ranks, otherwise the
            final-norm hidden states.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = inputs_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            assert isinstance(hidden_states, torch.Tensor)

        # Apply this rank's blocks one by one.
        for layer_idx in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[layer_idx](positions, hidden_states)

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Only the last rank applies the final norm.
        return self.norm(hidden_states)
class InstellaForCausalLM(nn.Module, SupportsPP):
    """
    Causal-LM wrapper around ``InstellaModel`` that adds the LM head, logits
    processing, sampling and checkpoint weight loading.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; the HF config and the
            quantization config are read from it.
        :param prefix: Module-name prefix forwarded to the sub-modules.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        # assert isinstance(hf_config, InstellaConfig)
        self.config = hf_config
        self.model = InstellaModel(vllm_config=vllm_config,
                                   prefix=f"{prefix}.model")

        if hf_config.tie_word_embeddings:
            # Tied weights: the input embedding doubles as the LM head.
            self.lm_head = self.model.embed_tokens
        else:
            self.unpadded_vocab_size = hf_config.vocab_size
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                quant_config=vllm_config.quant_config,
                prefix=f"{prefix}.lm_head",
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``InstellaModel`` and return its output."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """
        Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint entries are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id. Every other
        entry is loaded by name.

        :param weights: Iterable of ``(name, tensor)`` checkpoint entries.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Rotary buffers that may appear in a checkpoint but are never
        # loaded (the cached cos/sin entries may come from models trained
        # using ColossalAI).
        ignored_fragments = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if any(fragment in name for fragment in ignored_fragments):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            # With tie_word_embeddings, we can skip lm_head.weight.
            # The weight might appear unnecessarily in the files if the model
            # is processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                target = params_dict[name]
                target.weight_loader(  # type: ignore
                    target, loaded_weight, shard_id)
                break
            else:
                # No fused-parameter entry handled this weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                target = params_dict[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight)
# from modeling_instella import *
# from modeling_instella_vllm import *
# from vllm import ModelRegistry
# ModelRegistry.register_model( "InstellaForCausalLM", InstellaForCausalLM)
# from vllm import LLM
# model = LLM("/localmount/suranjan/OLMo-3B-4T-rmsnorm-QKnorm-dolmino-50B-instella-ultrachat-averaged-10k-sft-smoltalk-openmathinstruct400k-lr1e-5-0108/step30000-unsharded-hf-instella/")
# prompts = [
# "Hello, my name is",
# "The president of the United States is",
# "The capital of France is",
# "The future of AI is",
# ]
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# from vllm import LLM, SamplingParams
# prompts = [
# "Hello, my name is",
# "The president of the United States is",
# "The capital of France is",
# "The future of AI is",
# ]
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# outputs = llm.generate(prompts, sampling_params)
# # Print the outputs.
# for output in outputs:
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# outputs = model.generate(prompts, sampling_params)
# # Print the outputs.
# for output in outputs:
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|