from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import MSELoss
from transformers import Cache, Qwen2ForCausalLM
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, ModelOutput


@dataclass
class CausalLMOutputWithPastAndRegression(ModelOutput):
    """
    Class for causal language model (or autoregressive) outputs together with regression ouputs.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        regr_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Regression loss (for score prediction).
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Combined loss from language modelling loss and regression loss.
        regr_output (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Regression output.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    logits: torch.FloatTensor = None
    lm_loss: Optional[torch.FloatTensor] = None
    regr_loss: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    regr_output: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None


class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...


class Qwen2WithRegressionHead(Qwen2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.regression_head = nn.Linear(config.hidden_size, 1)  # one scalar score per sequence
        self.post_init()

    def r_loss_function(self, outputs, labels, regression_labels):
        """Combine the base model's LM loss with an MSE loss on the regression output (unweighted sum)."""
        lm_loss = outputs.loss

        regression_loss = None
        if regression_labels is not None:
            regression_output = outputs.regr_output
            regression_loss_fct = MSELoss()
            regression_loss = regression_loss_fct(regression_output, regression_labels)

        total_loss = None
        if lm_loss is not None and regression_loss is not None:
            total_loss = lm_loss + regression_loss

        return {
            "loss": total_loss,
            "lm_loss": lm_loss,
            "regr_loss": regression_loss,
        }

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[Cache] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            regression_labels: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            cache_position: Optional[torch.LongTensor] = None,
            logits_to_keep: Union[int, torch.Tensor] = 0,
            **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPastAndRegression:
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            output_hidden_states=True,  # hidden states are always needed for the regression head
            **kwargs,
        )
        hidden_states = outputs.hidden_states[-1]  # last layer hidden states (B x S x D)
        pooled_output = hidden_states[:, -1, :]  # last token's hidden state (B x D); assumes the final position is a real token (e.g. left padding)
        regression_output = self.regression_head(pooled_output)
        regression_output = regression_output.squeeze(-1)
        outputs.regr_output = regression_output

        loss_dict = self.r_loss_function(outputs, labels, regression_labels)

        return CausalLMOutputWithPastAndRegression(
            loss=loss_dict["loss"],
            lm_loss=loss_dict["lm_loss"],
            regr_loss=loss_dict["regr_loss"],
            logits=outputs.logits,
            regr_output=outputs.regr_output,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
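

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model definition).
# Assumptions not taken from the class itself: "Qwen/Qwen2-0.5B" is a
# placeholder checkpoint name, the batch is a single short sequence, and the
# regression target is one float per sequence. The newly added regression_head
# is randomly initialised, so its loss is only meaningful after fine-tuning.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "Qwen/Qwen2-0.5B"  # placeholder; substitute your own weights
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = Qwen2WithRegressionHead.from_pretrained(checkpoint)

    batch = tokenizer(["The quick brown fox jumps over the lazy dog."], return_tensors="pt")
    labels = batch["input_ids"].clone()        # next-token prediction targets
    regression_labels = torch.tensor([0.7])    # one scalar score per sequence

    outputs = model(**batch, labels=labels, regression_labels=regression_labels)
    print("lm_loss:", outputs.lm_loss.item())
    print("regr_loss:", outputs.regr_loss.item())
    print("combined loss:", outputs.loss.item())
    print("regr_output:", outputs.regr_output)

    # The combined loss backpropagates through both the LM head and the regression head.
    outputs.loss.backward()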