# pollux-judge-7b-r / qwen2_regression.py
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import MSELoss

from transformers import Cache, Qwen2ForCausalLM
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, ModelOutput


@dataclass
class CausalLMOutputWithPastAndRegression(ModelOutput):
"""
Class for causal language model (or autoregressive) outputs together with regression ouputs.
Args:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        regr_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `regression_labels` is provided):
            Regression loss (for score prediction).
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when both `labels` and `regression_labels` are provided):
            Combined loss: the sum of the language modeling loss and the regression loss.
        regr_output (`torch.FloatTensor` of shape `(batch_size,)`):
            Regression head output: one predicted score per sequence.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
    logits: Optional[torch.FloatTensor] = None
lm_loss: Optional[torch.FloatTensor] = None
regr_loss: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
regr_output: Optional[torch.FloatTensor] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
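
# Illustrative access sketch (comments only; `logits` and `scores` below are
# assumed tensors, not produced in this file):
#
#     out = CausalLMOutputWithPastAndRegression(logits=logits, regr_output=scores)
#     out.regr_output      # attribute access, shape (batch_size,)
#     out["regr_output"]   # ModelOutput also supports dict-style access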


# Marker class that merges FlashAttention kwargs and loss kwargs for typed **kwargs.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2WithRegressionHead(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Scalar regression head over the decoder's final hidden state.
        self.regression_head = nn.Linear(config.hidden_size, 1)
        self.post_init()

    def r_loss_function(self, outputs, labels, regression_labels):
        """Combine the LM loss from `outputs` with an MSE loss on `outputs.regr_output`."""
        lm_loss = outputs.loss
        regression_loss = None
if regression_labels is not None:
regression_output = outputs.regr_output
regression_loss_fct = MSELoss()
regression_loss = regression_loss_fct(regression_output, regression_labels)
total_loss = None
if lm_loss is not None and regression_loss is not None:
total_loss = lm_loss + regression_loss
return {
"loss": total_loss,
"lm_loss": lm_loss,
"regr_loss": regression_loss,
}
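
    # Shape note (an assumption made explicit here, not enforced above): with
    # batch size B, `outputs.regr_output` is (B,) after the squeeze in forward(),
    # so `regression_labels` should be a float tensor of the same shape;
    # torch's MSELoss warns and broadcasts when target and input shapes differ.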

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
regression_labels: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPastAndRegression:
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
            output_hidden_states=True,  # always on: the regression head pools the last hidden layer
**kwargs,
)
        hidden_states = outputs.hidden_states[-1]  # last layer's hidden states (B x S x D)
        # Pool the final position's hidden state (B x D). This assumes the last
        # position holds the last real token, i.e. left padding or no padding.
        pooled_output = hidden_states[:, -1, :]
        regression_output = self.regression_head(pooled_output).squeeze(-1)  # (B,)
        # Attach the regression output so r_loss_function can read it from `outputs`.
        outputs.regr_output = regression_output
        loss_dict = self.r_loss_function(outputs, labels, regression_labels)
return CausalLMOutputWithPastAndRegression(
loss=loss_dict["loss"],
lm_loss=loss_dict["lm_loss"],
regr_loss=loss_dict["regr_loss"],
logits=outputs.logits,
regr_output=outputs.regr_output,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
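

# --- Minimal usage sketch (illustrative; not part of the uploaded module) ---
# Assumes this file ships as remote code with the "ai-forever/pollux-judge-7b-r"
# checkpoint; the model id and the prompt below are assumptions.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_id = "ai-forever/pollux-judge-7b-r"  # assumed checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = Qwen2WithRegressionHead.from_pretrained(model_id)
    model.eval()

    batch = tokenizer(["Rate this answer: ..."], return_tensors="pt")
    with torch.no_grad():
        out = model(**batch)

    # One scalar score per input sequence, pooled from the last token position.
    print(out.regr_output)  # shape: (batch_size,)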