from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import MSELoss

from transformers import Cache, Qwen2ForCausalLM
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, ModelOutput

@dataclass
class CausalLMOutputWithPastAndRegression(ModelOutput):
    """
    Class for causal language model (or autoregressive) outputs together with regression outputs.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        regr_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `regression_labels` is provided):
            Regression loss (for score prediction).
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when both `labels` and `regression_labels` are provided):
            Combined loss, i.e. the sum of the language modeling loss and the regression loss.
        regr_output (`torch.FloatTensor` of shape `(batch_size,)`):
            Regression output (one predicted score per sequence in the batch).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    logits: torch.FloatTensor = None
    lm_loss: Optional[torch.FloatTensor] = None
    regr_loss: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    regr_output: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2WithRegressionHead(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Scalar regression head on top of the decoder's final hidden states.
        self.regression_head = nn.Linear(config.hidden_size, 1)
        # Initialize weights and apply final processing (also initializes the new head).
        self.post_init()

    def r_loss_function(self, outputs, labels, regression_labels):
        """Combine the LM loss computed by the parent class with an MSE regression loss."""
        lm_loss = outputs.loss
        regression_loss = None
        if regression_labels is not None:
            regression_output = outputs.regr_output
            regression_loss_fct = MSELoss()
            # Cast both sides to float32 so the loss is well-defined regardless of the model dtype.
            regression_loss = regression_loss_fct(
                regression_output.float(),
                regression_labels.to(regression_output.device).float(),
            )
        total_loss = None
        # The combined loss is only defined when both label types are provided.
        if lm_loss is not None and regression_loss is not None:
            total_loss = lm_loss + regression_loss
        return {
            "loss": total_loss,
            "lm_loss": lm_loss,
            "regr_loss": regression_loss,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        regression_labels: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPastAndRegression:
        # Run the base Qwen2 causal LM. Hidden states are always requested because the
        # regression head needs the last layer's hidden states.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            output_hidden_states=True,
            **kwargs,
        )

        hidden_states = outputs.hidden_states[-1]  # last layer hidden states (B x S x D)
        # Pool the last position's hidden state; this assumes the final position holds the
        # last real token (i.e. left padding or no padding).
        pooled_output = hidden_states[:, -1, :]  # last token's hidden state (B x D)
        regression_output = self.regression_head(pooled_output)
        regression_output = regression_output.squeeze(-1)  # (B, 1) -> (B,)
        # Attach the regression output so the loss helper can read it from `outputs`.
        outputs.regr_output = regression_output

        loss_dict = self.r_loss_function(outputs, labels, regression_labels)
        return CausalLMOutputWithPastAndRegression(
            loss=loss_dict["loss"],
            lm_loss=loss_dict["lm_loss"],
            regr_loss=loss_dict["regr_loss"],
            logits=outputs.logits,
            regr_output=outputs.regr_output,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
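

# --- Usage sketch (illustrative only) ---
# A minimal example of running the model with both LM labels and a regression target.
# The checkpoint name and the target score below are placeholders / assumptions, not part
# of the original code; the regression head is randomly initialized when loading a plain
# Qwen2 checkpoint.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "Qwen/Qwen2-0.5B"  # assumption: any Qwen2 causal-LM checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = Qwen2WithRegressionHead.from_pretrained(model_name)

    batch = tokenizer("The movie was great.", return_tensors="pt")
    labels = batch["input_ids"].clone()          # next-token prediction targets
    regression_labels = torch.tensor([0.9])      # placeholder score, shape (batch_size,)

    out = model(**batch, labels=labels, regression_labels=regression_labels)
    print(out.loss, out.lm_loss, out.regr_loss, out.regr_output)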