from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import MSELoss

from transformers import Cache, Qwen2ForCausalLM
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, ModelOutput


@dataclass
class CausalLMOutputWithPastAndRegression(ModelOutput):
    """
    Class for causal language model (or autoregressive) outputs together with regression outputs.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        regr_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `regression_labels` is provided):
            Regression loss (for score prediction).
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when both `labels` and `regression_labels` are provided):
            Combined language modeling and regression loss.
        regr_output (`torch.FloatTensor` of shape `(batch_size,)`):
            Regression head output (one predicted score per sequence).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """
    logits: torch.FloatTensor = None
    lm_loss: Optional[torch.FloatTensor] = None
    regr_loss: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    regr_output: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2WithRegressionHead(Qwen2ForCausalLM):
    """Qwen2 causal LM with an additional scalar regression head on the last hidden state."""

    def __init__(self, config):
        super().__init__(config)
        # Single-output linear head applied to the pooled hidden state.
        self.regression_head = nn.Linear(config.hidden_size, 1)
        self.post_init()

    def r_loss_function(self, outputs, labels, regression_labels):
        # The language modeling loss is already computed by the parent forward when `labels` is given.
        lm_loss = outputs.loss

        regression_loss = None
        if regression_labels is not None:
            regression_output = outputs.regr_output
            regression_loss_fct = MSELoss()
            regression_loss = regression_loss_fct(regression_output, regression_labels)

        # The combined loss is only defined when both objectives have targets.
        total_loss = None
        if lm_loss is not None and regression_loss is not None:
            total_loss = lm_loss + regression_loss

        return {
            "loss": total_loss,
            "lm_loss": lm_loss,
            "regr_loss": regression_loss,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        regression_labels: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPastAndRegression:
        # Hidden states are always requested because the regression head needs the last
        # layer's output, regardless of the `output_hidden_states` argument.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            output_hidden_states=True,
            **kwargs,
        )

        # Pool the last layer's hidden state at the final position. This assumes the token
        # of interest is the last one in the sequence (e.g. left padding for batched input).
        hidden_states = outputs.hidden_states[-1]
        pooled_output = hidden_states[:, -1, :]
        regression_output = self.regression_head(pooled_output).squeeze(-1)
        outputs.regr_output = regression_output

        loss_dict = self.r_loss_function(outputs, labels, regression_labels)

        return CausalLMOutputWithPastAndRegression(
            loss=loss_dict["loss"],
            lm_loss=loss_dict["lm_loss"],
            regr_loss=loss_dict["regr_loss"],
            logits=outputs.logits,
            regr_output=regression_output,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
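

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the checkpoint name, example texts and
    # regression targets below are placeholder assumptions, not part of the model code.
    from transformers import AutoTokenizer

    checkpoint = "Qwen/Qwen2-0.5B"  # assumed checkpoint; any Qwen2 causal LM config should work
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding keeps the real last token at position -1, matching the pooling in forward().
    tokenizer.padding_side = "left"

    model = Qwen2WithRegressionHead.from_pretrained(checkpoint)

    texts = ["The movie was great.", "The movie was terrible."]
    batch = tokenizer(texts, return_tensors="pt", padding=True)

    # Joint training step: next-token labels for the LM loss plus one scalar target
    # per sequence for the regression head.
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100  # ignore padding in the LM loss
    regression_labels = torch.tensor([0.9, 0.1])

    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        labels=labels,
        regression_labels=regression_labels,
    )
    print(out.loss, out.lm_loss, out.regr_loss, out.regr_output.shape)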