Abhaykoul committed · verified
Commit 0674972 · 1 Parent(s): dd12bab

Update modeling_helpingai.py

Files changed (1)
  1. modeling_helpingai.py +1020 -969
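
Since modeling_helpingai.py ships as custom code on the Hub, the usual way to exercise a change like this is to load the checkpoint with remote code enabled. A minimal sketch, assuming the HelpingAI/HelpingAI-8B repo id taken from the docstring example further below and a transformers install that supports `trust_remote_code`:

```python
# Minimal smoke test for the updated modeling file.
# The repo id is taken from the forward() docstring example below and is an assumption here.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HelpingAI/HelpingAI-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# Same prompt and max_length as the docstring example in this file.
inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")
generate_ids = model.generate(**inputs, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```

The HelpingAIForCausalLM.forward docstring in the diff below shows the same flow, plus a `return_reasoning_metadata=True` call for the structured-reasoning outputs.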
modeling_helpingai.py CHANGED
@@ -1,969 +1,1020 @@
1
- from typing import Callable, Optional, Union
2
-
3
- import torch
4
- from torch import nn
5
-
6
- from transformers.activations import ACT2FN
7
- from transformers.cache_utils import Cache, DynamicCache
8
- from transformers.generation import GenerationMixin
9
- from transformers.integrations import use_kernel_forward_from_hub
10
- from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
11
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
12
- from transformers.modeling_layers import (
13
- GenericForQuestionAnswering,
14
- GenericForSequenceClassification,
15
- GenericForTokenClassification,
16
- GradientCheckpointingLayer,
17
- )
18
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
- from transformers.processing_utils import Unpack
22
- from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
23
- from transformers.utils.deprecation import deprecate_kwarg
24
- from transformers.utils.generic import check_model_inputs
25
- from .configuration_helpingai import HelpingAIConfig
26
-
27
-
28
- @use_kernel_forward_from_hub("RMSNorm")
29
- class HelpingAIRMSNorm(nn.Module):
30
- def __init__(self, hidden_size, eps=1e-6):
31
- """
32
- HelpingAIRMSNorm is equivalent to T5LayerNorm
33
- """
34
- super().__init__()
35
- self.weight = nn.Parameter(torch.ones(hidden_size))
36
- self.variance_epsilon = eps
37
-
38
- def forward(self, hidden_states):
39
- input_dtype = hidden_states.dtype
40
- hidden_states = hidden_states.to(torch.float32)
41
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
42
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
43
- return self.weight * hidden_states.to(input_dtype)
44
-
45
- def extra_repr(self):
46
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
47
-
48
-
49
- class HelpingAISemanticEmotionReasoning(nn.Module):
50
- """
51
- Structured Emotional Reasoning (SER) layer for emotional understanding and processing.
52
- Maps emotions to semantic representations and provides contextual emotion analysis.
53
- """
54
- def __init__(self, config: HelpingAIConfig):
55
- super().__init__()
56
- self.config = config
57
- self.emotion_hidden_size = config.emotion_hidden_size
58
- self.hidden_size = config.hidden_size
59
-
60
- # Emotion detection and mapping
61
- self.emotion_detector = nn.Linear(self.hidden_size, self.emotion_hidden_size)
62
- self.emotion_mapper = nn.Linear(self.emotion_hidden_size, self.emotion_hidden_size)
63
-
64
- # Contextual emotion analysis
65
- self.emotion_context = nn.MultiheadAttention(
66
- embed_dim=self.emotion_hidden_size,
67
- num_heads=min(8, self.emotion_hidden_size // 64),
68
- batch_first=True
69
- )
70
-
71
- # Emotion classification heads
72
- self.primary_emotion = nn.Linear(self.emotion_hidden_size, 32) # Primary emotions
73
- self.emotion_intensity = nn.Linear(self.emotion_hidden_size, 1) # Intensity score
74
- self.emotion_valence = nn.Linear(self.emotion_hidden_size, 1) # Positive/negative
75
-
76
- # Output projection
77
- self.emotion_output = nn.Linear(self.emotion_hidden_size, self.hidden_size)
78
- self.emotion_norm = HelpingAIRMSNorm(self.emotion_hidden_size, eps=config.rms_norm_eps)
79
-
80
- # Activation
81
- self.act_fn = ACT2FN[config.hidden_act]
82
-
83
- def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict]:
84
- # Detect emotional content
85
- emotion_features = self.act_fn(self.emotion_detector(hidden_states))
86
- emotion_mapped = self.emotion_mapper(emotion_features)
87
- emotion_mapped = self.emotion_norm(emotion_mapped)
88
-
89
- # Contextual emotion analysis
90
- emotion_context, attention_weights = self.emotion_context(
91
- emotion_mapped, emotion_mapped, emotion_mapped
92
- )
93
-
94
- # Emotion analysis outputs
95
- primary_emotions = self.primary_emotion(emotion_context)
96
- emotion_intensity = torch.sigmoid(self.emotion_intensity(emotion_context))
97
- emotion_valence = torch.tanh(self.emotion_valence(emotion_context))
98
-
99
- # Project back to hidden size
100
- emotion_output = self.emotion_output(emotion_context)
101
-
102
- # Emotion metadata
103
- emotion_metadata = {
104
- "primary_emotions": primary_emotions,
105
- "intensity": emotion_intensity,
106
- "valence": emotion_valence,
107
- "attention_weights": attention_weights
108
- }
109
-
110
- return emotion_output, emotion_metadata
111
-
112
-
113
- class HelpingAIPerspectiveEmotionThreading(nn.Module):
114
- """
115
- Parallel Empathic Threads (PET) layer for multi-threaded emotional reasoning.
116
- Processes multiple perspective threads: relatable, supportive, motivational, analytical.
117
- """
118
- def __init__(self, config: HelpingAIConfig):
119
- super().__init__()
120
- self.config = config
121
- self.hidden_size = config.hidden_size
122
- self.perspective_threads = config.perspective_threads
123
- self.thread_hidden_size = config.emotion_hidden_size
124
-
125
- # Thread-specific processors
126
- self.thread_projections = nn.ModuleList([
127
- nn.Linear(self.hidden_size, self.thread_hidden_size)
128
- for _ in range(self.perspective_threads)
129
- ])
130
-
131
- # Thread names for interpretability
132
- self.thread_names = ["relatable", "supportive", "motivational", "analytical"][:self.perspective_threads]
133
-
134
- # Cross-thread attention for perspective integration
135
- self.cross_thread_attention = nn.MultiheadAttention(
136
- embed_dim=self.thread_hidden_size,
137
- num_heads=min(4, self.thread_hidden_size // 64),
138
- batch_first=True
139
- )
140
-
141
- # Thread-specific processing layers
142
- self.thread_processors = nn.ModuleList([
143
- nn.Sequential(
144
- nn.Linear(self.thread_hidden_size, self.thread_hidden_size * 2),
145
- nn.GELU(),
146
- nn.Linear(self.thread_hidden_size * 2, self.thread_hidden_size),
147
- HelpingAIRMSNorm(self.thread_hidden_size, eps=config.rms_norm_eps)
148
- )
149
- for _ in range(self.perspective_threads)
150
- ])
151
-
152
- # Output integration
153
- self.thread_combiner = nn.Linear(
154
- self.thread_hidden_size * self.perspective_threads,
155
- self.hidden_size
156
- )
157
-
158
- # Thread importance weighting
159
- self.thread_weights = nn.Parameter(torch.ones(self.perspective_threads))
160
-
161
- def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict]:
162
- batch_size, seq_len, _ = hidden_states.shape
163
-
164
- # Process each perspective thread
165
- thread_outputs = []
166
- thread_metadata = {}
167
-
168
- for i, (projection, processor, thread_name) in enumerate(
169
- zip(self.thread_projections, self.thread_processors, self.thread_names)
170
- ):
171
- # Project to thread space
172
- thread_input = projection(hidden_states)
173
-
174
- # Process thread-specific perspective
175
- thread_output = processor(thread_input)
176
- thread_outputs.append(thread_output)
177
-
178
- # Store thread metadata
179
- thread_metadata[f"{thread_name}_activation"] = torch.mean(torch.abs(thread_output))
180
-
181
- # Stack threads for cross-thread attention
182
- stacked_threads = torch.stack(thread_outputs, dim=2) # [batch, seq_len, num_threads, hidden]
183
- stacked_threads = stacked_threads.reshape(batch_size * seq_len, self.perspective_threads, self.thread_hidden_size)
184
-
185
- # Cross-thread attention for perspective integration
186
- integrated_threads, cross_attention = self.cross_thread_attention(
187
- stacked_threads, stacked_threads, stacked_threads
188
- )
189
-
190
- # Apply thread importance weighting
191
- thread_weights_normalized = torch.softmax(self.thread_weights, dim=0)
192
- weighted_threads = integrated_threads * thread_weights_normalized.unsqueeze(0).unsqueeze(-1)
193
-
194
- # Combine threads - use reshape instead of view for memory layout compatibility
195
- combined_threads = weighted_threads.reshape(batch_size, seq_len, -1)
196
- final_output = self.thread_combiner(combined_threads)
197
-
198
- # Thread metadata
199
- thread_metadata.update({
200
- "thread_weights": thread_weights_normalized,
201
- "cross_attention": cross_attention,
202
- "thread_activations": {
203
- name: torch.mean(output) for name, output in zip(self.thread_names, thread_outputs)
204
- }
205
- })
206
-
207
- return final_output, thread_metadata
208
-
209
-
210
- class HelpingAIMultiStageThinking(nn.Module):
211
- """
212
- Multi-stage thinking module for internal reasoning and reflection processes.
213
- Implements cascaded thinking stages with simplified feedback loops.
214
- """
215
- def __init__(self, config: HelpingAIConfig):
216
- super().__init__()
217
- self.config = config
218
- self.hidden_size = config.hidden_size
219
- self.thinking_stages = config.num_thinking_stages
220
- self.thinking_depth = config.thinking_depth
221
-
222
- # Thinking stage processors
223
- self.thinking_layers = nn.ModuleList([
224
- nn.Sequential(
225
- nn.Linear(self.hidden_size, self.hidden_size),
226
- nn.GELU(),
227
- nn.Linear(self.hidden_size, self.hidden_size),
228
- HelpingAIRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
229
- )
230
- for _ in range(self.thinking_stages)
231
- ])
232
-
233
- # Simple reflection mechanism without complex attention
234
- self.reflection_layers = nn.ModuleList([
235
- nn.Linear(self.hidden_size, self.hidden_size)
236
- for _ in range(self.thinking_stages - 1)
237
- ])
238
-
239
- # Stage transition gates
240
- self.stage_gates = nn.ModuleList([
241
- nn.Linear(self.hidden_size, 1) for _ in range(self.thinking_stages - 1)
242
- ])
243
-
244
- # Thinking combination weights
245
- self.stage_combiner = nn.Linear(self.thinking_stages * self.hidden_size, self.hidden_size)
246
-
247
- def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict]:
248
- batch_size, seq_len, _ = hidden_states.shape
249
- thinking_outputs = []
250
- thinking_metadata = {}
251
-
252
- current_thought = hidden_states
253
-
254
- # Multi-stage thinking process
255
- for stage_idx, stage_processor in enumerate(self.thinking_layers):
256
- # Process current thinking stage
257
- current_thought = stage_processor(current_thought)
258
-
259
- # Store stage output
260
- thinking_outputs.append(current_thought)
261
- thinking_metadata[f"stage_{stage_idx}_activation"] = torch.mean(torch.abs(current_thought)).item()
262
-
263
- # Apply reflection if not the last stage
264
- if stage_idx < self.thinking_stages - 1:
265
- # Simple reflection mechanism
266
- reflection = self.reflection_layers[stage_idx](current_thought)
267
- current_thought = current_thought + 0.1 * reflection # Small reflection influence
268
-
269
- # Stage transition gating
270
- gate_weight = torch.sigmoid(self.stage_gates[stage_idx](current_thought))
271
- current_thought = gate_weight * current_thought + (1 - gate_weight) * hidden_states
272
-
273
- # Combine all thinking stages
274
- all_thoughts = torch.cat(thinking_outputs, dim=-1) # Concatenate along hidden dimension
275
- final_thought = self.stage_combiner(all_thoughts)
276
-
277
- thinking_metadata["stage_contributions"] = [
278
- torch.mean(torch.abs(output)).item() for output in thinking_outputs
279
- ]
280
-
281
- return final_thought, thinking_metadata
282
-
283
-
284
- class HelpingAIMLP(nn.Module):
285
- def __init__(self, config):
286
- super().__init__()
287
- self.config = config
288
- self.hidden_size = config.hidden_size
289
- self.intermediate_size = config.intermediate_size
290
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
291
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
292
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
293
- self.act_fn = ACT2FN[config.hidden_act]
294
-
295
- # Enhanced MLP with thinking modules
296
- if hasattr(config, 'use_emotional_reasoning') and config.use_emotional_reasoning:
297
- self.thinking_module = HelpingAIMultiStageThinking(config)
298
- self.use_thinking = True
299
- else:
300
- self.use_thinking = False
301
-
302
- # Reasoning temperature for controlled generation
303
- self.reasoning_temperature = getattr(config, 'reasoning_temperature', 1.0)
304
-
305
- def forward(self, x):
306
- # Standard MLP forward pass
307
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
308
-
309
- # Apply multi-stage thinking if enabled
310
- if self.use_thinking:
311
- thinking_output, thinking_metadata = self.thinking_module(down_proj)
312
- # Apply reasoning temperature
313
- down_proj = down_proj + (thinking_output * self.reasoning_temperature)
314
-
315
- return down_proj
316
-
317
-
318
- def rotate_half(x):
319
- """Rotates half the hidden dims of the input."""
320
- x1 = x[..., : x.shape[-1] // 2]
321
- x2 = x[..., x.shape[-1] // 2 :]
322
- return torch.cat((-x2, x1), dim=-1)
323
-
324
-
325
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
326
- """Applies Rotary Position Embedding to the query and key tensors.
327
-
328
- Args:
329
- q (`torch.Tensor`): The query tensor.
330
- k (`torch.Tensor`): The key tensor.
331
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
332
- sin (`torch.Tensor`): The sine part of the rotary embedding.
333
- position_ids (`torch.Tensor`, *optional*):
334
- Deprecated and unused.
335
- unsqueeze_dim (`int`, *optional*, defaults to 1):
336
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
337
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
338
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
339
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
340
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
341
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
342
- Returns:
343
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
344
- """
345
- cos = cos.unsqueeze(unsqueeze_dim)
346
- sin = sin.unsqueeze(unsqueeze_dim)
347
- q_embed = (q * cos) + (rotate_half(q) * sin)
348
- k_embed = (k * cos) + (rotate_half(k) * sin)
349
- return q_embed, k_embed
350
-
351
-
352
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
353
- """
354
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
355
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
356
- """
357
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
358
- if n_rep == 1:
359
- return hidden_states
360
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
361
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
362
-
363
-
364
- def eager_attention_forward(
365
- module: nn.Module,
366
- query: torch.Tensor,
367
- key: torch.Tensor,
368
- value: torch.Tensor,
369
- attention_mask: Optional[torch.Tensor],
370
- scaling: float,
371
- dropout: float = 0.0,
372
- **kwargs: Unpack[TransformersKwargs],
373
- ):
374
- key_states = repeat_kv(key, module.num_key_value_groups)
375
- value_states = repeat_kv(value, module.num_key_value_groups)
376
-
377
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
378
- if attention_mask is not None:
379
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
380
- attn_weights = attn_weights + causal_mask
381
-
382
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
383
- attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
384
- attn_output = torch.matmul(attn_weights, value_states)
385
- attn_output = attn_output.transpose(1, 2).contiguous()
386
-
387
- return attn_output, attn_weights
388
-
389
-
390
- class HelpingAIAttention(nn.Module):
391
- """Multi-headed attention with specialized emotional and empathetic reasoning capabilities"""
392
-
393
- def __init__(self, config: HelpingAIConfig, layer_idx: int):
394
- super().__init__()
395
- self.config = config
396
- self.layer_idx = layer_idx
397
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
398
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
399
- self.scaling = self.head_dim**-0.5
400
- self.attention_dropout = config.attention_dropout
401
- self.is_causal = True
402
-
403
- self.q_proj = nn.Linear(
404
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
405
- )
406
- self.k_proj = nn.Linear(
407
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
408
- )
409
- self.v_proj = nn.Linear(
410
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
411
- )
412
- self.o_proj = nn.Linear(
413
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
414
- )
415
- self.q_norm = HelpingAIRMSNorm(self.head_dim, eps=config.rms_norm_eps)
416
- self.k_norm = HelpingAIRMSNorm(self.head_dim, eps=config.rms_norm_eps)
417
- self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
418
-
419
- # Enhanced emotional and empathetic attention
420
- if hasattr(config, 'use_emotional_reasoning') and config.use_emotional_reasoning:
421
- self.num_emotion_heads = getattr(config, 'num_emotion_heads', 4)
422
- self.empathy_scaling_factor = getattr(config, 'empathy_scaling_factor', 1.2)
423
-
424
- # Specialized emotion attention projections
425
- self.emotion_q_proj = nn.Linear(config.hidden_size, self.num_emotion_heads * self.head_dim, bias=False)
426
- self.emotion_k_proj = nn.Linear(config.hidden_size, self.num_emotion_heads * self.head_dim, bias=False)
427
- self.emotion_v_proj = nn.Linear(config.hidden_size, self.num_emotion_heads * self.head_dim, bias=False)
428
-
429
- # Empathy enhancement layer
430
- self.empathy_enhancer = nn.Sequential(
431
- nn.Linear(config.hidden_size, config.hidden_size // 2),
432
- nn.GELU(),
433
- nn.Linear(config.hidden_size // 2, config.num_attention_heads),
434
- nn.Softmax(dim=-1)
435
- )
436
-
437
- self.use_emotional_attention = True
438
- else:
439
- self.use_emotional_attention = False
440
-
441
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
442
- def forward(
443
- self,
444
- hidden_states: torch.Tensor,
445
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
446
- attention_mask: Optional[torch.Tensor],
447
- past_key_values: Optional[Cache] = None,
448
- cache_position: Optional[torch.LongTensor] = None,
449
- **kwargs: Unpack[FlashAttentionKwargs],
450
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
451
- input_shape = hidden_states.shape[:-1]
452
- hidden_shape = (*input_shape, -1, self.head_dim)
453
-
454
- # Standard attention processing
455
- query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
456
- key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
457
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
458
-
459
- cos, sin = position_embeddings
460
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
461
-
462
- if past_key_values is not None:
463
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
464
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
465
-
466
- # Enhanced emotional attention processing
467
- if self.use_emotional_attention:
468
- # Compute empathy weights
469
- empathy_weights = self.empathy_enhancer(hidden_states.mean(dim=1)) # [batch, num_heads]
470
-
471
- # Emotional query, key, value computation
472
- emotion_query = self.emotion_q_proj(hidden_states).view(*input_shape, self.num_emotion_heads, self.head_dim).transpose(1, 2)
473
- emotion_key = self.emotion_k_proj(hidden_states).view(*input_shape, self.num_emotion_heads, self.head_dim).transpose(1, 2)
474
- emotion_value = self.emotion_v_proj(hidden_states).view(*input_shape, self.num_emotion_heads, self.head_dim).transpose(1, 2)
475
-
476
- # Apply rotary embeddings to emotional attention
477
- emotion_query, emotion_key = apply_rotary_pos_emb(emotion_query, emotion_key, cos, sin)
478
-
479
- # Emotional attention computation
480
- emotion_scaling = (self.head_dim ** -0.5) * self.empathy_scaling_factor
481
- emotion_attn_weights = torch.matmul(emotion_query, emotion_key.transpose(2, 3)) * emotion_scaling
482
-
483
- if attention_mask is not None:
484
- emotion_causal_mask = attention_mask[:, :, :, :emotion_key.shape[-2]]
485
- emotion_attn_weights = emotion_attn_weights + emotion_causal_mask
486
-
487
- emotion_attn_weights = nn.functional.softmax(emotion_attn_weights, dim=-1, dtype=torch.float32).to(emotion_query.dtype)
488
- emotion_output = torch.matmul(emotion_attn_weights, emotion_value)
489
-
490
- # Integrate emotional attention with standard attention
491
- # Pad or truncate emotional attention to match standard attention heads
492
- if self.num_emotion_heads < self.config.num_attention_heads:
493
- padding_heads = self.config.num_attention_heads - self.num_emotion_heads
494
- emotion_padding = torch.zeros(
495
- *emotion_output.shape[:-3], padding_heads, *emotion_output.shape[-2:],
496
- device=emotion_output.device, dtype=emotion_output.dtype
497
- )
498
- emotion_output = torch.cat([emotion_output, emotion_padding], dim=1)
499
-
500
- # Standard attention computation
501
- attention_interface: Callable = eager_attention_forward
502
- if self.config._attn_implementation != "eager":
503
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
504
-
505
- attn_output, attn_weights = attention_interface(
506
- self,
507
- query_states,
508
- key_states,
509
- value_states,
510
- attention_mask,
511
- dropout=0.0 if not self.training else self.attention_dropout,
512
- scaling=self.scaling,
513
- sliding_window=self.sliding_window,
514
- **kwargs,
515
- )
516
-
517
- # Blend standard and emotional attention if emotional reasoning is enabled
518
- if self.use_emotional_attention:
519
- # For now, use a simplified approach - just apply empathy scaling
520
- # This avoids the complex tensor dimension matching issues
521
- batch_size, num_heads, seq_len, head_dim = attn_output.shape
522
-
523
- # Get average empathy weight per batch
524
- empathy_scale = torch.mean(empathy_weights, dim=1, keepdim=True) # [batch, 1]
525
- empathy_scale = empathy_scale.view(batch_size, 1, 1, 1) # [batch, 1, 1, 1]
526
- empathy_scale = empathy_scale.expand(batch_size, num_heads, seq_len, head_dim)
527
-
528
- # Apply empathy scaling to attention output
529
- attn_output = attn_output * (1.0 + empathy_scale * 0.1) # Small empathy influence
530
-
531
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
532
- attn_output = self.o_proj(attn_output)
533
- return attn_output, attn_weights
534
-
535
-
536
- class HelpingAIDecoderLayer(GradientCheckpointingLayer):
537
- def __init__(self, config: HelpingAIConfig, layer_idx: int):
538
- super().__init__()
539
- self.hidden_size = config.hidden_size
540
- self.layer_idx = layer_idx
541
-
542
- self.self_attn = HelpingAIAttention(config=config, layer_idx=layer_idx)
543
- self.mlp = HelpingAIMLP(config)
544
- self.input_layernorm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
545
- self.post_attention_layernorm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
546
- self.attention_type = config.layer_types[layer_idx]
547
-
548
- # Enhanced reasoning layers
549
- if hasattr(config, 'use_emotional_reasoning') and config.use_emotional_reasoning:
550
- self.ser_layer = HelpingAISemanticEmotionReasoning(config)
551
- self.use_ser = True
552
- else:
553
- self.use_ser = False
554
-
555
- if hasattr(config, 'use_perspective_threading') and config.use_perspective_threading:
556
- self.pet_layer = HelpingAIPerspectiveEmotionThreading(config)
557
- self.use_pet = True
558
- else:
559
- self.use_pet = False
560
-
561
- # Reasoning integration layers
562
- if self.use_ser or self.use_pet:
563
- self.reasoning_norm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
564
- self.reasoning_gate = nn.Linear(config.hidden_size, 1)
565
-
566
- @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
567
- def forward(
568
- self,
569
- hidden_states: torch.Tensor,
570
- attention_mask: Optional[torch.Tensor] = None,
571
- position_ids: Optional[torch.LongTensor] = None,
572
- past_key_values: Optional[Cache] = None,
573
- use_cache: Optional[bool] = False,
574
- cache_position: Optional[torch.LongTensor] = None,
575
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
576
- **kwargs: Unpack[TransformersKwargs],
577
- ) -> torch.Tensor:
578
- residual = hidden_states
579
- hidden_states = self.input_layernorm(hidden_states)
580
-
581
- # Self Attention
582
- hidden_states, attention_weights = self.self_attn(
583
- hidden_states=hidden_states,
584
- attention_mask=attention_mask,
585
- position_ids=position_ids,
586
- past_key_values=past_key_values,
587
- use_cache=use_cache,
588
- cache_position=cache_position,
589
- position_embeddings=position_embeddings,
590
- **kwargs,
591
- )
592
- hidden_states = residual + hidden_states
593
-
594
- # Enhanced reasoning processing
595
- reasoning_outputs = []
596
- reasoning_metadata = {}
597
-
598
- if self.use_ser:
599
- # Semantic Emotion Reasoning
600
- ser_output, ser_meta = self.ser_layer(hidden_states)
601
- reasoning_outputs.append(ser_output)
602
- reasoning_metadata['ser'] = ser_meta
603
-
604
- if self.use_pet:
605
- # Perspective Emotion Threading
606
- pet_output, pet_meta = self.pet_layer(hidden_states)
607
- reasoning_outputs.append(pet_output)
608
- reasoning_metadata['pet'] = pet_meta
609
-
610
- # Integrate reasoning outputs if any
611
- if reasoning_outputs:
612
- # Combine reasoning outputs
613
- combined_reasoning = torch.stack(reasoning_outputs, dim=0).mean(dim=0)
614
- combined_reasoning = self.reasoning_norm(combined_reasoning)
615
-
616
- # Apply gating to control reasoning influence
617
- reasoning_gate = torch.sigmoid(self.reasoning_gate(hidden_states))
618
- hidden_states = hidden_states + (reasoning_gate * combined_reasoning)
619
-
620
- # Fully Connected (MLP)
621
- residual = hidden_states
622
- hidden_states = self.post_attention_layernorm(hidden_states)
623
- hidden_states = self.mlp(hidden_states)
624
- hidden_states = residual + hidden_states
625
-
626
- # Store reasoning metadata for analysis (optional)
627
- if hasattr(hidden_states, '_reasoning_metadata'):
628
- hidden_states._reasoning_metadata = reasoning_metadata
629
-
630
- return hidden_states
631
-
632
-
633
- @auto_docstring
634
- class HelpingAIPreTrainedModel(PreTrainedModel):
635
- config: HelpingAIConfig
636
- base_model_prefix = "model"
637
- supports_gradient_checkpointing = True
638
- _no_split_modules = ["HelpingAIDecoderLayer"]
639
- _skip_keys_device_placement = ["past_key_values"]
640
- _supports_flash_attn = True
641
- _supports_sdpa = True
642
- _supports_flex_attn = True
643
-
644
- _can_compile_fullgraph = True
645
- _supports_attention_backend = True
646
- _can_record_outputs = {
647
- "hidden_states": HelpingAIDecoderLayer,
648
- "attentions": HelpingAIAttention,
649
- }
650
-
651
-
652
- class HelpingAIRotaryEmbedding(nn.Module):
653
- inv_freq: torch.Tensor # fix linting for `register_buffer`
654
-
655
- def __init__(self, config: HelpingAIConfig, device=None):
656
- super().__init__()
657
- # BC: "rope_type" was originally "type"
658
- if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
659
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
660
- else:
661
- self.rope_type = "default"
662
- self.max_seq_len_cached = config.max_position_embeddings
663
- self.original_max_seq_len = config.max_position_embeddings
664
-
665
- self.config = config
666
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
667
-
668
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
669
- self.register_buffer("inv_freq", inv_freq, persistent=False)
670
- self.original_inv_freq = self.inv_freq
671
-
672
- @torch.no_grad()
673
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
674
- def forward(self, x, position_ids):
675
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
676
- position_ids_expanded = position_ids[:, None, :].float()
677
-
678
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
679
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
680
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
681
- emb = torch.cat((freqs, freqs), dim=-1)
682
- cos = emb.cos() * self.attention_scaling
683
- sin = emb.sin() * self.attention_scaling
684
-
685
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
686
-
687
-
688
- @auto_docstring
689
- class HelpingAIModel(HelpingAIPreTrainedModel):
690
- def __init__(self, config: HelpingAIConfig):
691
- super().__init__(config)
692
- self.padding_idx = config.pad_token_id
693
- self.vocab_size = config.vocab_size
694
-
695
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
696
- self.layers = nn.ModuleList(
697
- [HelpingAIDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
698
- )
699
- self.norm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
700
- self.rotary_emb = HelpingAIRotaryEmbedding(config=config)
701
- self.gradient_checkpointing = False
702
- self.has_sliding_layers = "sliding_attention" in self.config.layer_types
703
-
704
- # Initialize weights and apply final processing
705
- self.post_init()
706
-
707
- @check_model_inputs
708
- @auto_docstring
709
- def forward(
710
- self,
711
- input_ids: Optional[torch.LongTensor] = None,
712
- attention_mask: Optional[torch.Tensor] = None,
713
- position_ids: Optional[torch.LongTensor] = None,
714
- past_key_values: Optional[Cache] = None,
715
- inputs_embeds: Optional[torch.FloatTensor] = None,
716
- use_cache: Optional[bool] = None,
717
- cache_position: Optional[torch.LongTensor] = None,
718
- **kwargs: Unpack[TransformersKwargs],
719
- ) -> BaseModelOutputWithPast:
720
- if (input_ids is None) ^ (inputs_embeds is not None):
721
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
722
-
723
- if inputs_embeds is None:
724
- inputs_embeds = self.embed_tokens(input_ids)
725
-
726
- if use_cache and past_key_values is None:
727
- past_key_values = DynamicCache()
728
-
729
- if cache_position is None:
730
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
731
- cache_position = torch.arange(
732
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
733
- )
734
-
735
- if position_ids is None:
736
- position_ids = cache_position.unsqueeze(0)
737
-
738
- # It may already have been prepared by e.g. `generate`
739
- if not isinstance(causal_mask_mapping := attention_mask, dict):
740
- # Prepare mask arguments
741
- mask_kwargs = {
742
- "config": self.config,
743
- "input_embeds": inputs_embeds,
744
- "attention_mask": attention_mask,
745
- "cache_position": cache_position,
746
- "past_key_values": past_key_values,
747
- "position_ids": position_ids,
748
- }
749
- # Create the masks
750
- causal_mask_mapping = {
751
- "full_attention": create_causal_mask(**mask_kwargs),
752
- }
753
- # The sliding window alternating layers are not always activated depending on the config
754
- if self.has_sliding_layers:
755
- causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
756
-
757
- hidden_states = inputs_embeds
758
-
759
- # create position embeddings to be shared across the decoder layers
760
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
761
-
762
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
763
- hidden_states = decoder_layer(
764
- hidden_states,
765
- attention_mask=causal_mask_mapping[decoder_layer.attention_type],
766
- position_ids=position_ids,
767
- past_key_values=past_key_values,
768
- use_cache=use_cache,
769
- cache_position=cache_position,
770
- position_embeddings=position_embeddings,
771
- **kwargs,
772
- )
773
-
774
- hidden_states = self.norm(hidden_states)
775
- return BaseModelOutputWithPast(
776
- last_hidden_state=hidden_states,
777
- past_key_values=past_key_values if use_cache else None,
778
- )
779
-
780
-
781
- @auto_docstring
782
- class HelpingAIForCausalLM(HelpingAIPreTrainedModel, GenerationMixin):
783
- _tied_weights_keys = ["lm_head.weight"]
784
- _tp_plan = {"lm_head": "colwise_rep"}
785
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
786
-
787
- def __init__(self, config):
788
- super().__init__(config)
789
- self.model = HelpingAIModel(config)
790
- self.vocab_size = config.vocab_size
791
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
792
-
793
- # Enhanced structured output support
794
- if hasattr(config, 'structured_output_vocab_size') and config.structured_output_vocab_size > 0:
795
- self.structured_vocab_size = config.structured_output_vocab_size
796
- self.structured_lm_head = nn.Linear(config.hidden_size, self.structured_vocab_size, bias=False)
797
- self.use_structured_output = True
798
-
799
- # Special token embeddings for structured reasoning
800
- self.structured_token_embeddings = nn.Embedding(self.structured_vocab_size, config.hidden_size)
801
-
802
- # Reasoning mode classifier
803
- self.reasoning_mode_classifier = nn.Sequential(
804
- nn.Linear(config.hidden_size, config.hidden_size // 2),
805
- nn.GELU(),
806
- nn.Linear(config.hidden_size // 2, 4), # think, ser, pet, normal
807
- nn.Softmax(dim=-1)
808
- )
809
- else:
810
- self.use_structured_output = False
811
-
812
- # Initialize weights and apply final processing
813
- self.post_init()
814
-
815
- def set_decoder(self, decoder):
816
- self.model = decoder
817
-
818
- def get_decoder(self):
819
- return self.model
820
-
821
- def get_reasoning_mode_probabilities(self, hidden_states: torch.Tensor) -> torch.Tensor:
822
- """Get probabilities for different reasoning modes: think, ser, pet, normal"""
823
- if self.use_structured_output:
824
- # Use the last token's hidden state for mode classification
825
- last_hidden = hidden_states[:, -1, :] # [batch_size, hidden_size]
826
- mode_probs = self.reasoning_mode_classifier(last_hidden)
827
- return mode_probs
828
- return None
829
-
830
- @can_return_tuple
831
- @auto_docstring
832
- def forward(
833
- self,
834
- input_ids: Optional[torch.LongTensor] = None,
835
- attention_mask: Optional[torch.Tensor] = None,
836
- position_ids: Optional[torch.LongTensor] = None,
837
- past_key_values: Optional[Cache] = None,
838
- inputs_embeds: Optional[torch.FloatTensor] = None,
839
- labels: Optional[torch.LongTensor] = None,
840
- use_cache: Optional[bool] = None,
841
- cache_position: Optional[torch.LongTensor] = None,
842
- logits_to_keep: Union[int, torch.Tensor] = 0,
843
- return_reasoning_metadata: Optional[bool] = False,
844
- **kwargs: Unpack[TransformersKwargs],
845
- ) -> CausalLMOutputWithPast:
846
- r"""
847
- Enhanced HelpingAI forward pass with structured reasoning support.
848
-
849
- Args:
850
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
851
- Indices of input sequence tokens in the vocabulary.
852
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
853
- Mask to avoid performing attention on padding token indices.
854
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
855
- Indices of positions of each input sequence tokens in the position embeddings.
856
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
857
- Pre-computed hidden-states that can be used to speed up autoregressive decoding.
858
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
859
- Embedded representation of the input tokens. Can be used instead of `input_ids`.
860
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
861
- Labels for computing the masked language modeling loss.
862
- use_cache (`bool`, *optional*):
863
- If set to `True`, past key values are returned and can be used to speed up decoding.
864
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
865
- Indices depicting the position of the input tokens in the sequence.
866
- logits_to_keep (`Union[int, torch.Tensor]`, *optional*, defaults to 0):
867
- Number of logits to keep from the end of the sequence.
868
- return_reasoning_metadata (`bool`, *optional*, defaults to `False`):
869
- Whether to return reasoning metadata including SER and PET analysis for structured reasoning.
870
-
871
- Returns:
872
- `CausalLMOutputWithPast`: Model output containing logits, past key values, and optional reasoning metadata.
873
-
874
- Example:
875
-
876
- ```python
877
- >>> from transformers import AutoTokenizer, HelpingAIForCausalLM
878
-
879
- >>> model = HelpingAIForCausalLM.from_pretrained("HelpingAI/HelpingAI-8B")
880
- >>> tokenizer = AutoTokenizer.from_pretrained("HelpingAI/HelpingAI-8B")
881
-
882
- >>> # Standard generation
883
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
884
- >>> inputs = tokenizer(prompt, return_tensors="pt")
885
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
886
- >>> response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
887
-
888
- >>> # Structured reasoning generation
889
- >>> outputs = model(inputs.input_ids, return_reasoning_metadata=True)
890
- >>> reasoning_modes = model.get_reasoning_mode_probabilities(outputs.hidden_states)
891
- ```"""
892
- outputs: BaseModelOutputWithPast = self.model(
893
- input_ids=input_ids,
894
- attention_mask=attention_mask,
895
- position_ids=position_ids,
896
- past_key_values=past_key_values,
897
- inputs_embeds=inputs_embeds,
898
- use_cache=use_cache,
899
- cache_position=cache_position,
900
- **kwargs,
901
- )
902
-
903
- hidden_states = outputs.last_hidden_state
904
-
905
- # Standard language modeling head
906
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
907
- logits = self.lm_head(hidden_states[:, slice_indices, :])
908
-
909
- # Enhanced structured output logits
910
- structured_logits = None
911
- reasoning_mode_probs = None
912
- if self.use_structured_output:
913
- structured_logits = self.structured_lm_head(hidden_states[:, slice_indices, :])
914
- reasoning_mode_probs = self.get_reasoning_mode_probabilities(hidden_states)
915
-
916
- loss = None
917
- if labels is not None:
918
- # Standard loss computation
919
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
920
-
921
- # Add structured output loss if applicable
922
- if self.use_structured_output and structured_logits is not None:
923
- # Additional loss term for structured reasoning (if labels include structured tokens)
924
- structured_loss_weight = 0.1 # Weight for structured output loss
925
- structured_loss = self.loss_function(
926
- logits=structured_logits,
927
- labels=labels,
928
- vocab_size=self.structured_vocab_size,
929
- **kwargs
930
- )
931
- loss = loss + (structured_loss_weight * structured_loss)
932
-
933
- # Prepare output with enhanced reasoning metadata
934
- output = CausalLMOutputWithPast(
935
- loss=loss,
936
- logits=logits,
937
- past_key_values=outputs.past_key_values,
938
- hidden_states=outputs.hidden_states,
939
- attentions=outputs.attentions,
940
- )
941
-
942
- # Add custom attributes for reasoning
943
- if return_reasoning_metadata and self.use_structured_output:
944
- output.structured_logits = structured_logits
945
- output.reasoning_mode_probabilities = reasoning_mode_probs
946
-
947
- return output
948
-
949
-
950
- class HelpingAIForSequenceClassification(GenericForSequenceClassification, HelpingAIPreTrainedModel):
951
- pass
952
-
953
-
954
- class HelpingAIForTokenClassification(GenericForTokenClassification, HelpingAIPreTrainedModel):
955
- pass
956
-
957
-
958
- class HelpingAIForQuestionAnswering(GenericForQuestionAnswering, HelpingAIPreTrainedModel):
959
- base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
960
-
961
-
962
- __all__ = [
963
- "HelpingAIForCausalLM",
964
- "HelpingAIForQuestionAnswering",
965
- "HelpingAIPreTrainedModel",
966
- "HelpingAIModel",
967
- "HelpingAIForSequenceClassification",
968
- "HelpingAIForTokenClassification",
969
- ]
1
+ from typing import Callable, Optional, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from transformers.activations import ACT2FN
7
+ from transformers.cache_utils import Cache, DynamicCache
8
+ from transformers.generation import GenerationMixin
9
+ from transformers.integrations import use_kernel_forward_from_hub
10
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
11
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
12
+ from transformers.modeling_layers import (
13
+ GenericForQuestionAnswering,
14
+ GenericForSequenceClassification,
15
+ GenericForTokenClassification,
16
+ GradientCheckpointingLayer,
17
+ )
18
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
19
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
20
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
21
+ from transformers.processing_utils import Unpack
22
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
23
+ from transformers.utils.deprecation import deprecate_kwarg
24
+ from transformers.utils.generic import check_model_inputs
25
+ from .configuration_helpingai import HelpingAIConfig
26
+
27
+
28
+ @use_kernel_forward_from_hub("RMSNorm")
29
+ class HelpingAIRMSNorm(nn.Module):
30
+ def __init__(self, hidden_size, eps=1e-6):
31
+ """
32
+ HelpingAIRMSNorm is equivalent to T5LayerNorm
33
+ """
34
+ super().__init__()
35
+ self.weight = nn.Parameter(torch.ones(hidden_size))
36
+ self.variance_epsilon = eps
37
+
38
+ def forward(self, hidden_states):
39
+ input_dtype = hidden_states.dtype
40
+ hidden_states = hidden_states.to(torch.float32)
41
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
42
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
43
+ return self.weight * hidden_states.to(input_dtype)
44
+
45
+ def extra_repr(self):
46
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
47
+
48
+
49
+ class HelpingAISemanticEmotionReasoning(nn.Module):
50
+ """
51
+ Structured Emotional Reasoning (SER) layer for emotional understanding and processing.
52
+ Maps emotions to semantic representations and provides contextual emotion analysis.
53
+ """
54
+ def __init__(self, config: HelpingAIConfig):
55
+ super().__init__()
56
+ self.config = config
57
+ self.emotion_hidden_size = config.emotion_hidden_size
58
+ self.hidden_size = config.hidden_size
59
+
60
+ # Emotion detection and mapping
61
+ self.emotion_detector = nn.Linear(self.hidden_size, self.emotion_hidden_size)
62
+ self.emotion_mapper = nn.Linear(self.emotion_hidden_size, self.emotion_hidden_size)
63
+
64
+ # Contextual emotion analysis
65
+ self.emotion_context = nn.MultiheadAttention(
66
+ embed_dim=self.emotion_hidden_size,
67
+ num_heads=min(8, self.emotion_hidden_size // 64),
68
+ batch_first=True
69
+ )
70
+
71
+ # Emotion classification heads
72
+ self.primary_emotion = nn.Linear(self.emotion_hidden_size, 32) # Primary emotions
73
+ self.emotion_intensity = nn.Linear(self.emotion_hidden_size, 1) # Intensity score
74
+ self.emotion_valence = nn.Linear(self.emotion_hidden_size, 1) # Positive/negative
75
+
76
+ # Output projection
77
+ self.emotion_output = nn.Linear(self.emotion_hidden_size, self.hidden_size)
78
+ self.emotion_norm = HelpingAIRMSNorm(self.emotion_hidden_size, eps=config.rms_norm_eps)
79
+
80
+ # Activation
81
+ self.act_fn = ACT2FN[config.hidden_act]
82
+
83
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict]:
84
+ # Detect emotional content
85
+ emotion_features = self.act_fn(self.emotion_detector(hidden_states))
86
+ emotion_mapped = self.emotion_mapper(emotion_features)
87
+ emotion_mapped = self.emotion_norm(emotion_mapped)
88
+
89
+ # Contextual emotion analysis
90
+ emotion_context, attention_weights = self.emotion_context(
91
+ emotion_mapped, emotion_mapped, emotion_mapped
92
+ )
93
+
94
+ # Emotion analysis outputs
95
+ primary_emotions = self.primary_emotion(emotion_context)
96
+ emotion_intensity = torch.sigmoid(self.emotion_intensity(emotion_context))
97
+ emotion_valence = torch.tanh(self.emotion_valence(emotion_context))
98
+
99
+ # Project back to hidden size
100
+ emotion_output = self.emotion_output(emotion_context)
101
+
102
+ # Emotion metadata
103
+ emotion_metadata = {
104
+ "primary_emotions": primary_emotions,
105
+ "intensity": emotion_intensity,
106
+ "valence": emotion_valence,
107
+ "attention_weights": attention_weights
108
+ }
109
+
110
+ return emotion_output, emotion_metadata
111
+
112
+
113
+ class HelpingAIPerspectiveEmotionThreading(nn.Module):
114
+ """
115
+ Parallel Empathic Threads (PET) layer for multi-threaded emotional reasoning.
116
+ Processes multiple perspective threads: relatable, supportive, motivational, analytical.
117
+ """
118
+ def __init__(self, config: HelpingAIConfig):
119
+ super().__init__()
120
+ self.config = config
121
+ self.hidden_size = config.hidden_size
122
+ self.perspective_threads = config.perspective_threads
123
+ self.thread_hidden_size = config.emotion_hidden_size
124
+
125
+ # Thread-specific processors
126
+ self.thread_projections = nn.ModuleList([
127
+ nn.Linear(self.hidden_size, self.thread_hidden_size)
128
+ for _ in range(self.perspective_threads)
129
+ ])
130
+
131
+ # Thread names for interpretability
132
+ self.thread_names = ["relatable", "supportive", "motivational", "analytical"][:self.perspective_threads]
133
+
134
+ # Cross-thread attention for perspective integration
135
+ self.cross_thread_attention = nn.MultiheadAttention(
136
+ embed_dim=self.thread_hidden_size,
137
+ num_heads=min(4, self.thread_hidden_size // 64),
138
+ batch_first=True
139
+ )
140
+
141
+ # Thread-specific processing layers
142
+ self.thread_processors = nn.ModuleList([
143
+ nn.Sequential(
144
+ nn.Linear(self.thread_hidden_size, self.thread_hidden_size * 2),
145
+ nn.GELU(),
146
+ nn.Linear(self.thread_hidden_size * 2, self.thread_hidden_size),
147
+ HelpingAIRMSNorm(self.thread_hidden_size, eps=config.rms_norm_eps)
148
+ )
149
+ for _ in range(self.perspective_threads)
150
+ ])
151
+
152
+ # Output integration
153
+ self.thread_combiner = nn.Linear(
154
+ self.thread_hidden_size * self.perspective_threads,
155
+ self.hidden_size
156
+ )
157
+
158
+ # Thread importance weighting
159
+ self.thread_weights = nn.Parameter(torch.ones(self.perspective_threads))
160
+
161
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict]:
162
+ batch_size, seq_len, _ = hidden_states.shape
163
+
164
+ # Process each perspective thread
165
+ thread_outputs = []
166
+ thread_metadata = {}
167
+
168
+ for i, (projection, processor, thread_name) in enumerate(
169
+ zip(self.thread_projections, self.thread_processors, self.thread_names)
170
+ ):
171
+ # Project to thread space
172
+ thread_input = projection(hidden_states)
173
+
174
+ # Process thread-specific perspective
175
+ thread_output = processor(thread_input)
176
+ thread_outputs.append(thread_output)
177
+
178
+ # Store thread metadata
179
+ thread_metadata[f"{thread_name}_activation"] = torch.mean(torch.abs(thread_output))
180
+
181
+ # Stack threads for cross-thread attention
182
+ stacked_threads = torch.stack(thread_outputs, dim=2) # [batch, seq_len, num_threads, hidden]
183
+ stacked_threads = stacked_threads.reshape(batch_size * seq_len, self.perspective_threads, self.thread_hidden_size)
184
+
185
+ # Cross-thread attention for perspective integration
186
+ integrated_threads, cross_attention = self.cross_thread_attention(
187
+ stacked_threads, stacked_threads, stacked_threads
188
+ )
189
+
190
+ # Apply thread importance weighting
191
+ thread_weights_normalized = torch.softmax(self.thread_weights, dim=0)
192
+ weighted_threads = integrated_threads * thread_weights_normalized.unsqueeze(0).unsqueeze(-1)
193
+
194
+ # Combine threads - use reshape instead of view for memory layout compatibility
195
+ combined_threads = weighted_threads.reshape(batch_size, seq_len, -1)
196
+ final_output = self.thread_combiner(combined_threads)
197
+
198
+ # Thread metadata
199
+ thread_metadata.update({
200
+ "thread_weights": thread_weights_normalized,
201
+ "cross_attention": cross_attention,
202
+ "thread_activations": {
203
+ name: torch.mean(output) for name, output in zip(self.thread_names, thread_outputs)
204
+ }
205
+ })
206
+
207
+ return final_output, thread_metadata
208
+
209
+
210
+ class HelpingAIMultiStageThinking(nn.Module):
211
+ """
212
+ Multi-stage thinking module for internal reasoning and reflection processes.
213
+ Implements cascaded thinking stages with simplified feedback loops.
214
+ """
215
+ def __init__(self, config: HelpingAIConfig):
216
+ super().__init__()
217
+ self.config = config
218
+ self.hidden_size = config.hidden_size
219
+ self.thinking_stages = config.num_thinking_stages
220
+ self.thinking_depth = config.thinking_depth
221
+
222
+ # Thinking stage processors
223
+ self.thinking_layers = nn.ModuleList([
224
+ nn.Sequential(
225
+ nn.Linear(self.hidden_size, self.hidden_size),
226
+ nn.GELU(),
227
+ nn.Linear(self.hidden_size, self.hidden_size),
228
+ HelpingAIRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
229
+ )
230
+ for _ in range(self.thinking_stages)
231
+ ])
232
+
233
+ # Simple reflection mechanism without complex attention
234
+ self.reflection_layers = nn.ModuleList([
235
+ nn.Linear(self.hidden_size, self.hidden_size)
236
+ for _ in range(self.thinking_stages - 1)
237
+ ])
238
+
239
+ # Stage transition gates
240
+ self.stage_gates = nn.ModuleList([
241
+ nn.Linear(self.hidden_size, 1) for _ in range(self.thinking_stages - 1)
242
+ ])
243
+
244
+ # Thinking combination weights
245
+ self.stage_combiner = nn.Linear(self.thinking_stages * self.hidden_size, self.hidden_size)
246
+
247
+ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, dict]:
248
+ batch_size, seq_len, _ = hidden_states.shape
249
+ thinking_outputs = []
250
+ thinking_metadata = {}
251
+
252
+ current_thought = hidden_states
253
+
254
+ # Multi-stage thinking process
255
+ for stage_idx, stage_processor in enumerate(self.thinking_layers):
256
+ # Process current thinking stage
257
+ current_thought = stage_processor(current_thought)
258
+
259
+ # Store stage output
260
+ thinking_outputs.append(current_thought)
261
+ thinking_metadata[f"stage_{stage_idx}_activation"] = torch.mean(torch.abs(current_thought)).item()
262
+
263
+ # Apply reflection if not the last stage
264
+ if stage_idx < self.thinking_stages - 1:
265
+ # Simple reflection mechanism
266
+ reflection = self.reflection_layers[stage_idx](current_thought)
267
+ current_thought = current_thought + 0.1 * reflection # Small reflection influence
268
+
269
+ # Stage transition gating
270
+ gate_weight = torch.sigmoid(self.stage_gates[stage_idx](current_thought))
271
+ current_thought = gate_weight * current_thought + (1 - gate_weight) * hidden_states
272
+
273
+ # Combine all thinking stages
274
+ all_thoughts = torch.cat(thinking_outputs, dim=-1) # Concatenate along hidden dimension
275
+ final_thought = self.stage_combiner(all_thoughts)
276
+
277
+ thinking_metadata["stage_contributions"] = [
278
+ torch.mean(torch.abs(output)).item() for output in thinking_outputs
279
+ ]
280
+
281
+ return final_thought, thinking_metadata
282
+
283
+
284
+ class HelpingAIMLP(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.config = config
288
+ self.hidden_size = config.hidden_size
289
+ self.intermediate_size = config.intermediate_size
290
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
291
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
292
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
293
+ self.act_fn = ACT2FN[config.hidden_act]
294
+
295
+ # Enhanced MLP with thinking modules
296
+ if hasattr(config, 'use_emotional_reasoning') and config.use_emotional_reasoning:
297
+ self.thinking_module = HelpingAIMultiStageThinking(config)
298
+ self.use_thinking = True
299
+ else:
300
+ self.use_thinking = False
301
+
302
+ # Reasoning temperature for controlled generation
303
+ self.reasoning_temperature = getattr(config, 'reasoning_temperature', 1.0)
304
+
305
+ def forward(self, x):
306
+ # Standard MLP forward pass
307
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
308
+
309
+ # Apply multi-stage thinking if enabled
310
+ if self.use_thinking:
311
+ thinking_output, thinking_metadata = self.thinking_module(down_proj)
312
+ # Apply reasoning temperature
313
+ down_proj = down_proj + (thinking_output * self.reasoning_temperature)
314
+
315
+ return down_proj
316
+
317
+
318
+ def rotate_half(x):
319
+ """Rotates half the hidden dims of the input."""
320
+ x1 = x[..., : x.shape[-1] // 2]
321
+ x2 = x[..., x.shape[-1] // 2 :]
322
+ return torch.cat((-x2, x1), dim=-1)
323
+
324
+
325
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
326
+ """Applies Rotary Position Embedding to the query and key tensors.
327
+
328
+ Args:
329
+ q (`torch.Tensor`): The query tensor.
330
+ k (`torch.Tensor`): The key tensor.
331
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
332
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
333
+ position_ids (`torch.Tensor`, *optional*):
334
+ Deprecated and unused.
335
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
336
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
337
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
338
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
339
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
340
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
341
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
342
+ Returns:
343
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
344
+ """
345
+ cos = cos.unsqueeze(unsqueeze_dim)
346
+ sin = sin.unsqueeze(unsqueeze_dim)
347
+ q_embed = (q * cos) + (rotate_half(q) * sin)
348
+ k_embed = (k * cos) + (rotate_half(k) * sin)
349
+ return q_embed, k_embed
350
+
351
+
352
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
353
+ """
354
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
355
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
356
+ """
357
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
358
+ if n_rep == 1:
359
+ return hidden_states
360
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
361
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
362
+
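A quick check of the equivalence claimed in the docstring above, using the repeat_kv just defined.

import torch

kv = torch.randn(2, 2, 5, 64)                      # [batch, num_key_value_heads, seq_len, head_dim]
expanded = repeat_kv(kv, 4)                        # -> [batch, 8, seq_len, head_dim]
reference = torch.repeat_interleave(kv, repeats=4, dim=1)
print(expanded.shape, torch.equal(expanded, reference))  # torch.Size([2, 8, 5, 64]) True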
363
+
364
+ def eager_attention_forward(
365
+ module: nn.Module,
366
+ query: torch.Tensor,
367
+ key: torch.Tensor,
368
+ value: torch.Tensor,
369
+ attention_mask: Optional[torch.Tensor],
370
+ scaling: float,
371
+ dropout: float = 0.0,
372
+ **kwargs: Unpack[TransformersKwargs],
373
+ ):
374
+ key_states = repeat_kv(key, module.num_key_value_groups)
375
+ value_states = repeat_kv(value, module.num_key_value_groups)
376
+
377
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
378
+ if attention_mask is not None:
379
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
380
+ attn_weights = attn_weights + causal_mask
381
+
382
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
383
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
384
+ attn_output = torch.matmul(attn_weights, value_states)
385
+ attn_output = attn_output.transpose(1, 2).contiguous()
386
+
387
+ return attn_output, attn_weights
388
+
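A minimal invocation sketch for the eager path above: the module argument only needs num_key_value_groups and training, so a bare nn.Module stand-in (a hypothetical _Stub) is enough, and the mask is a plain additive causal mask.

import torch
from torch import nn

class _Stub(nn.Module):          # hypothetical stand-in exposing the attributes read above
    num_key_value_groups = 4

batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 5, 64
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# Additive causal mask: 0 where attention is allowed, -inf strictly above the diagonal.
mask = torch.full((seq_len, seq_len), float("-inf")).triu(1)[None, None, :, :]

out, weights = eager_attention_forward(_Stub(), q, k, v, mask, scaling=head_dim ** -0.5)
print(out.shape, weights.shape)  # torch.Size([2, 5, 8, 64]) torch.Size([2, 8, 5, 5])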
389
+
390
+ class HelpingAIAttention(nn.Module):
391
+ """Multi-headed attention with specialized emotional and empathetic reasoning capabilities"""
392
+
393
+ def __init__(self, config: HelpingAIConfig, layer_idx: int):
394
+ super().__init__()
395
+ self.config = config
396
+ self.layer_idx = layer_idx
397
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
398
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
399
+ self.scaling = self.head_dim**-0.5
400
+ self.attention_dropout = config.attention_dropout
401
+ self.is_causal = True
402
+
403
+ self.q_proj = nn.Linear(
404
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
405
+ )
406
+ self.k_proj = nn.Linear(
407
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
408
+ )
409
+ self.v_proj = nn.Linear(
410
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
411
+ )
412
+ self.o_proj = nn.Linear(
413
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
414
+ )
415
+ self.q_norm = HelpingAIRMSNorm(self.head_dim, eps=config.rms_norm_eps)
416
+ self.k_norm = HelpingAIRMSNorm(self.head_dim, eps=config.rms_norm_eps)
417
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
418
+
419
+ # Enhanced emotional and empathetic attention
420
+ if hasattr(config, 'use_emotional_reasoning') and config.use_emotional_reasoning:
421
+ self.num_emotion_heads = getattr(config, 'num_emotion_heads', 4)
422
+ self.empathy_scaling_factor = getattr(config, 'empathy_scaling_factor', 1.2)
423
+
424
+ # Specialized emotion attention projections
425
+ self.emotion_q_proj = nn.Linear(config.hidden_size, self.num_emotion_heads * self.head_dim, bias=False)
426
+ self.emotion_k_proj = nn.Linear(config.hidden_size, self.num_emotion_heads * self.head_dim, bias=False)
427
+ self.emotion_v_proj = nn.Linear(config.hidden_size, self.num_emotion_heads * self.head_dim, bias=False)
428
+
429
+ # Empathy enhancement layer
430
+ self.empathy_enhancer = nn.Sequential(
431
+ nn.Linear(config.hidden_size, config.hidden_size // 2),
432
+ nn.GELU(),
433
+ nn.Linear(config.hidden_size // 2, config.num_attention_heads),
434
+ nn.Softmax(dim=-1)
435
+ )
436
+
437
+ self.use_emotional_attention = True
438
+ else:
439
+ self.use_emotional_attention = False
440
+
441
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
442
+ def forward(
443
+ self,
444
+ hidden_states: torch.Tensor,
445
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
446
+ attention_mask: Optional[torch.Tensor],
447
+ past_key_values: Optional[Cache] = None,
448
+ cache_position: Optional[torch.LongTensor] = None,
449
+ **kwargs: Unpack[FlashAttentionKwargs],
450
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
451
+ input_shape = hidden_states.shape[:-1]
452
+ hidden_shape = (*input_shape, -1, self.head_dim)
453
+
454
+ # Standard attention processing
455
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
456
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
457
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
458
+
459
+ cos, sin = position_embeddings
460
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
461
+
462
+ if past_key_values is not None:
463
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
464
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
465
+
466
+ # Enhanced emotional attention processing
467
+ if self.use_emotional_attention:
468
+ # Compute empathy weights
469
+ empathy_weights = self.empathy_enhancer(hidden_states.mean(dim=1)) # [batch, num_heads]
470
+
471
+ # Emotional query, key, value computation
472
+ emotion_query = self.emotion_q_proj(hidden_states).view(*input_shape, self.num_emotion_heads, self.head_dim).transpose(1, 2)
473
+ emotion_key = self.emotion_k_proj(hidden_states).view(*input_shape, self.num_emotion_heads, self.head_dim).transpose(1, 2)
474
+ emotion_value = self.emotion_v_proj(hidden_states).view(*input_shape, self.num_emotion_heads, self.head_dim).transpose(1, 2)
475
+
476
+ # Apply rotary embeddings to emotional attention
477
+ emotion_query, emotion_key = apply_rotary_pos_emb(emotion_query, emotion_key, cos, sin)
478
+
479
+ # Emotional attention computation
480
+ emotion_scaling = (self.head_dim ** -0.5) * self.empathy_scaling_factor
481
+ emotion_attn_weights = torch.matmul(emotion_query, emotion_key.transpose(2, 3)) * emotion_scaling
482
+
483
+ if attention_mask is not None:
484
+ emotion_causal_mask = attention_mask[:, :, :, :emotion_key.shape[-2]]
485
+ emotion_attn_weights = emotion_attn_weights + emotion_causal_mask
486
+
487
+ emotion_attn_weights = nn.functional.softmax(emotion_attn_weights, dim=-1, dtype=torch.float32).to(emotion_query.dtype)
488
+ emotion_output = torch.matmul(emotion_attn_weights, emotion_value)
489
+
490
+ # Integrate emotional attention with standard attention
491
+ # Pad or truncate emotional attention to match standard attention heads
492
+ if self.num_emotion_heads < self.config.num_attention_heads:
493
+ padding_heads = self.config.num_attention_heads - self.num_emotion_heads
494
+ emotion_padding = torch.zeros(
495
+ *emotion_output.shape[:-3], padding_heads, *emotion_output.shape[-2:],
496
+ device=emotion_output.device, dtype=emotion_output.dtype
497
+ )
498
+ emotion_output = torch.cat([emotion_output, emotion_padding], dim=1)
499
+
500
+ # Standard attention computation
501
+ attention_interface: Callable = eager_attention_forward
502
+ if self.config._attn_implementation != "eager":
503
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
504
+
505
+ attn_output, attn_weights = attention_interface(
506
+ self,
507
+ query_states,
508
+ key_states,
509
+ value_states,
510
+ attention_mask,
511
+ dropout=0.0 if not self.training else self.attention_dropout,
512
+ scaling=self.scaling,
513
+ sliding_window=self.sliding_window,
514
+ **kwargs,
515
+ )
516
+
517
+ # Blend standard and emotional attention if emotional reasoning is enabled
518
+ if self.use_emotional_attention:
519
+ # Simplified blending: apply an empathy-derived scale to the standard attention output.
520
+ # Note: the emotion_output computed above is not merged in here; doing so would require matching head layouts.
521
+ batch_size, seq_len, num_heads, head_dim = attn_output.shape  # the attention interface returns [batch, seq_len, num_heads, head_dim]
522
+
523
+ # Get average empathy weight per batch
524
+ empathy_scale = torch.mean(empathy_weights, dim=1, keepdim=True)  # [batch, 1]; since empathy_weights is a softmax over heads, this mean is 1 / num_attention_heads
525
+ empathy_scale = empathy_scale.view(batch_size, 1, 1, 1) # [batch, 1, 1, 1]
526
+ empathy_scale = empathy_scale.expand(batch_size, seq_len, num_heads, head_dim)
527
+
528
+ # Apply empathy scaling to attention output
529
+ attn_output = attn_output * (1.0 + empathy_scale * 0.1) # Small empathy influence
530
+
531
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
532
+ attn_output = self.o_proj(attn_output)
533
+ return attn_output, attn_weights
534
+
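The empathy blend above reduces to a per-batch scalar gate on the attention output. A toy numeric sketch of just that step, with made-up shapes and untrained weights (this is not the full HelpingAIAttention forward):

import torch
from torch import nn

batch, seq_len, num_heads, head_dim, hidden = 2, 5, 8, 64, 512
hidden_states = torch.randn(batch, seq_len, hidden)
attn_output = torch.randn(batch, seq_len, num_heads, head_dim)   # layout after the attention interface

empathy_enhancer = nn.Sequential(
    nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Linear(hidden // 2, num_heads), nn.Softmax(dim=-1)
)
empathy_weights = empathy_enhancer(hidden_states.mean(dim=1))            # [batch, num_heads]
empathy_scale = empathy_weights.mean(dim=1, keepdim=True).view(batch, 1, 1, 1)
blended = attn_output * (1.0 + empathy_scale * 0.1)                      # small empathy influence
print(blended.shape)  # torch.Size([2, 5, 8, 64])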
535
+
536
+ class HelpingAIDecoderLayer(GradientCheckpointingLayer):
537
+ def __init__(self, config: HelpingAIConfig, layer_idx: int):
538
+ super().__init__()
539
+ self.hidden_size = config.hidden_size
540
+ self.layer_idx = layer_idx
541
+
542
+ self.self_attn = HelpingAIAttention(config=config, layer_idx=layer_idx)
543
+ self.mlp = HelpingAIMLP(config)
544
+ self.input_layernorm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
545
+ self.post_attention_layernorm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
546
+ self.attention_type = config.layer_types[layer_idx]
547
+
548
+ # Enhanced reasoning layers
549
+ if hasattr(config, 'use_emotional_reasoning') and config.use_emotional_reasoning:
550
+ self.ser_layer = HelpingAISemanticEmotionReasoning(config)
551
+ self.use_ser = True
552
+ else:
553
+ self.use_ser = False
554
+
555
+ if hasattr(config, 'use_perspective_threading') and config.use_perspective_threading:
556
+ self.pet_layer = HelpingAIPerspectiveEmotionThreading(config)
557
+ self.use_pet = True
558
+ else:
559
+ self.use_pet = False
560
+
561
+ # Reasoning integration layers
562
+ if self.use_ser or self.use_pet:
563
+ self.reasoning_norm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
564
+ self.reasoning_gate = nn.Linear(config.hidden_size, 1)
565
+
566
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
567
+ def forward(
568
+ self,
569
+ hidden_states: torch.Tensor,
570
+ attention_mask: Optional[torch.Tensor] = None,
571
+ position_ids: Optional[torch.LongTensor] = None,
572
+ past_key_values: Optional[Cache] = None,
573
+ use_cache: Optional[bool] = False,
574
+ cache_position: Optional[torch.LongTensor] = None,
575
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
576
+ **kwargs: Unpack[TransformersKwargs],
577
+ ) -> torch.Tensor:
578
+ residual = hidden_states
579
+ hidden_states = self.input_layernorm(hidden_states)
580
+
581
+ # Self Attention
582
+ hidden_states, attention_weights = self.self_attn(
583
+ hidden_states=hidden_states,
584
+ attention_mask=attention_mask,
585
+ position_ids=position_ids,
586
+ past_key_values=past_key_values,
587
+ use_cache=use_cache,
588
+ cache_position=cache_position,
589
+ position_embeddings=position_embeddings,
590
+ **kwargs,
591
+ )
592
+ hidden_states = residual + hidden_states
593
+
594
+ # Enhanced reasoning processing
595
+ reasoning_outputs = []
596
+ reasoning_metadata = {}
597
+
598
+ if self.use_ser:
599
+ # Semantic Emotion Reasoning
600
+ ser_output, ser_meta = self.ser_layer(hidden_states)
601
+ reasoning_outputs.append(ser_output)
602
+ reasoning_metadata['ser'] = ser_meta
603
+
604
+ if self.use_pet:
605
+ # Perspective Emotion Threading
606
+ pet_output, pet_meta = self.pet_layer(hidden_states)
607
+ reasoning_outputs.append(pet_output)
608
+ reasoning_metadata['pet'] = pet_meta
609
+
610
+ # Integrate reasoning outputs if any
611
+ if reasoning_outputs:
612
+ # Combine reasoning outputs
613
+ combined_reasoning = torch.stack(reasoning_outputs, dim=0).mean(dim=0)
614
+ combined_reasoning = self.reasoning_norm(combined_reasoning)
615
+
616
+ # Apply gating to control reasoning influence
617
+ reasoning_gate = torch.sigmoid(self.reasoning_gate(hidden_states))
618
+ hidden_states = hidden_states + (reasoning_gate * combined_reasoning)
619
+
620
+ # Fully Connected (MLP)
621
+ residual = hidden_states
622
+ hidden_states = self.post_attention_layernorm(hidden_states)
623
+ hidden_states = self.mlp(hidden_states)
624
+ hidden_states = residual + hidden_states
625
+
626
+ # Store reasoning metadata for analysis (optional); freshly computed tensors never carry this attribute, so this branch is effectively a placeholder
627
+ if hasattr(hidden_states, '_reasoning_metadata'):
628
+ hidden_states._reasoning_metadata = reasoning_metadata
629
+
630
+ return hidden_states
631
+
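A compact sketch of the gating used above to fold reasoning outputs back into the residual stream: a per-position sigmoid gate controls how much of the averaged, normalized SER/PET signal is added. LayerNorm here is a plain stand-in for HelpingAIRMSNorm, and the reasoning outputs are random placeholders.

import torch
from torch import nn

batch, seq_len, hidden = 2, 5, 512
hidden_states = torch.randn(batch, seq_len, hidden)
ser_out = torch.randn(batch, seq_len, hidden)     # placeholder for the SER layer output
pet_out = torch.randn(batch, seq_len, hidden)     # placeholder for the PET layer output

reasoning_norm = nn.LayerNorm(hidden)             # stand-in for HelpingAIRMSNorm in this sketch
reasoning_gate = nn.Linear(hidden, 1)

combined = torch.stack([ser_out, pet_out], dim=0).mean(dim=0)
gate = torch.sigmoid(reasoning_gate(hidden_states))        # [batch, seq, 1], broadcast over hidden
hidden_states = hidden_states + gate * reasoning_norm(combined)
print(hidden_states.shape)  # torch.Size([2, 5, 512])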
632
+
633
+ @auto_docstring
634
+ class HelpingAIPreTrainedModel(PreTrainedModel):
635
+ config: HelpingAIConfig
636
+ base_model_prefix = "model"
637
+ supports_gradient_checkpointing = True
638
+ _no_split_modules = ["HelpingAIDecoderLayer"]
639
+ _skip_keys_device_placement = ["past_key_values"]
640
+ _supports_flash_attn = True
641
+ _supports_sdpa = True
642
+ _supports_flex_attn = True
643
+
644
+ _can_compile_fullgraph = True
645
+ _supports_attention_backend = True
646
+ _can_record_outputs = {
647
+ "hidden_states": HelpingAIDecoderLayer,
648
+ "attentions": HelpingAIAttention,
649
+ }
650
+
651
+
652
+ class HelpingAIRotaryEmbedding(nn.Module):
653
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
654
+
655
+ def __init__(self, config: HelpingAIConfig, device=None):
656
+ super().__init__()
657
+ # BC: "rope_type" was originally "type"
658
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
659
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
660
+ else:
661
+ self.rope_type = "default"
662
+ self.max_seq_len_cached = config.max_position_embeddings
663
+ self.original_max_seq_len = config.max_position_embeddings
664
+
665
+ self.config = config
666
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
667
+
668
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
669
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
670
+ self.original_inv_freq = self.inv_freq
671
+
672
+ @torch.no_grad()
673
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
674
+ def forward(self, x, position_ids):
675
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
676
+ position_ids_expanded = position_ids[:, None, :].float()
677
+
678
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
679
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
680
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
681
+ emb = torch.cat((freqs, freqs), dim=-1)
682
+ cos = emb.cos() * self.attention_scaling
683
+ sin = emb.sin() * self.attention_scaling
684
+
685
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
686
+
687
+
688
+ @auto_docstring
689
+ class HelpingAIModel(HelpingAIPreTrainedModel):
690
+ def __init__(self, config: HelpingAIConfig):
691
+ super().__init__(config)
692
+ self.padding_idx = config.pad_token_id
693
+ self.vocab_size = config.vocab_size
694
+
695
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
696
+ self.layers = nn.ModuleList(
697
+ [HelpingAIDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
698
+ )
699
+ self.norm = HelpingAIRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
700
+ self.rotary_emb = HelpingAIRotaryEmbedding(config=config)
701
+ self.gradient_checkpointing = False
702
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
703
+
704
+ # Initialize weights and apply final processing
705
+ self.post_init()
706
+
707
+ @check_model_inputs
708
+ @auto_docstring
709
+ def forward(
710
+ self,
711
+ input_ids: Optional[torch.LongTensor] = None,
712
+ attention_mask: Optional[torch.Tensor] = None,
713
+ position_ids: Optional[torch.LongTensor] = None,
714
+ past_key_values: Optional[Cache] = None,
715
+ inputs_embeds: Optional[torch.FloatTensor] = None,
716
+ use_cache: Optional[bool] = None,
717
+ cache_position: Optional[torch.LongTensor] = None,
718
+ **kwargs: Unpack[TransformersKwargs],
719
+ ) -> BaseModelOutputWithPast:
720
+ if (input_ids is None) ^ (inputs_embeds is not None):
721
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
722
+
723
+ if inputs_embeds is None:
724
+ inputs_embeds = self.embed_tokens(input_ids)
725
+
726
+ if use_cache and past_key_values is None:
727
+ past_key_values = DynamicCache()
728
+
729
+ if cache_position is None:
730
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
731
+ cache_position = torch.arange(
732
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
733
+ )
734
+
735
+ if position_ids is None:
736
+ position_ids = cache_position.unsqueeze(0)
737
+
738
+ # It may already have been prepared by e.g. `generate`
739
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
740
+ # Prepare mask arguments
741
+ mask_kwargs = {
742
+ "config": self.config,
743
+ "input_embeds": inputs_embeds,
744
+ "attention_mask": attention_mask,
745
+ "cache_position": cache_position,
746
+ "past_key_values": past_key_values,
747
+ "position_ids": position_ids,
748
+ }
749
+ # Create the masks
750
+ causal_mask_mapping = {
751
+ "full_attention": create_causal_mask(**mask_kwargs),
752
+ }
753
+ # The sliding window alternating layers are not always activated depending on the config
754
+ if self.has_sliding_layers:
755
+ causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
756
+
757
+ hidden_states = inputs_embeds
758
+
759
+ # create position embeddings to be shared across the decoder layers
760
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
761
+
762
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
763
+ hidden_states = decoder_layer(
764
+ hidden_states,
765
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
766
+ position_ids=position_ids,
767
+ past_key_values=past_key_values,
768
+ use_cache=use_cache,
769
+ cache_position=cache_position,
770
+ position_embeddings=position_embeddings,
771
+ **kwargs,
772
+ )
773
+
774
+ hidden_states = self.norm(hidden_states)
775
+ return BaseModelOutputWithPast(
776
+ last_hidden_state=hidden_states,
777
+ past_key_values=past_key_values if use_cache else None,
778
+ )
779
+
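A schematic of the per-layer mask selection above, with strings standing in for the real mask objects produced by create_causal_mask / create_sliding_window_causal_mask; each decoder layer picks its mask via its entry in config.layer_types.

layer_types = ["full_attention", "sliding_attention", "full_attention", "sliding_attention"]  # example config

causal_mask_mapping = {"full_attention": "<dense causal mask>"}
if "sliding_attention" in layer_types:
    causal_mask_mapping["sliding_attention"] = "<sliding-window causal mask>"

for layer_idx, attention_type in enumerate(layer_types):
    print(layer_idx, attention_type, "->", causal_mask_mapping[attention_type])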
780
+
781
+ @auto_docstring
782
+ class HelpingAIForCausalLM(HelpingAIPreTrainedModel, GenerationMixin):
783
+ _tied_weights_keys = ["lm_head.weight"]
784
+ _tp_plan = {"lm_head": "colwise_rep"}
785
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.model = HelpingAIModel(config)
790
+ self.vocab_size = config.vocab_size
791
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
792
+
793
+ # Enhanced structured output support
794
+ if hasattr(config, 'structured_output_vocab_size') and config.structured_output_vocab_size > 0:
795
+ self.structured_vocab_size = config.structured_output_vocab_size
796
+ self.structured_lm_head = nn.Linear(config.hidden_size, self.structured_vocab_size, bias=False)
797
+ self.use_structured_output = True
798
+
799
+ # Special token embeddings for structured reasoning
800
+ self.structured_token_embeddings = nn.Embedding(self.structured_vocab_size, config.hidden_size)
801
+
802
+ # Reasoning mode classifier
803
+ self.reasoning_mode_classifier = nn.Sequential(
804
+ nn.Linear(config.hidden_size, config.hidden_size // 2),
805
+ nn.GELU(),
806
+ nn.Linear(config.hidden_size // 2, 4), # think, ser, pet, normal
807
+ nn.Softmax(dim=-1)
808
+ )
809
+ else:
810
+ self.use_structured_output = False
811
+
812
+ # Optional speech output head (predict mel-spectrogram frames)
813
+ self.use_speech_output = getattr(config, "use_speech_output", False)
814
+ if self.use_speech_output:
815
+ self.speech_num_mels = getattr(config, "speech_num_mels", 80)
816
+ self.speech_upsample_factor = getattr(config, "speech_upsample_factor", 1)
817
+ hidden_dim = getattr(config, "speech_head_hidden_dim", None)
818
+ if hidden_dim is None:
819
+ hidden_dim = config.hidden_size // 2
820
+ # Projector from hidden_size -> hidden_dim -> mel bins
821
+ self.speech_proj = nn.Sequential(
822
+ nn.Linear(config.hidden_size, hidden_dim),
823
+ nn.GELU(),
824
+ nn.Linear(hidden_dim, self.speech_num_mels),
825
+ )
826
+ self.speech_loss_type = getattr(config, "speech_loss_type", "l1")
827
+
828
+ # Initialize weights and apply final processing
829
+ self.post_init()
830
+
831
+ def set_decoder(self, decoder):
832
+ self.model = decoder
833
+
834
+ def get_decoder(self):
835
+ return self.model
836
+
837
+ def get_reasoning_mode_probabilities(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
838
+ """Return probabilities over the reasoning modes (think, ser, pet, normal), or None when structured output is disabled."""
839
+ if self.use_structured_output:
840
+ # Use the last token's hidden state for mode classification
841
+ last_hidden = hidden_states[:, -1, :] # [batch_size, hidden_size]
842
+ mode_probs = self.reasoning_mode_classifier(last_hidden)
843
+ return mode_probs
844
+ return None
845
+
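A toy version of the mode classification above: the last token's hidden state is mapped to a 4-way softmax over (think, ser, pet, normal). Shapes and weights here are arbitrary and untrained.

import torch
from torch import nn

batch, seq_len, hidden = 2, 5, 512
hidden_states = torch.randn(batch, seq_len, hidden)

reasoning_mode_classifier = nn.Sequential(
    nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Linear(hidden // 2, 4), nn.Softmax(dim=-1)
)
mode_probs = reasoning_mode_classifier(hidden_states[:, -1, :])   # [batch, 4]
modes = ["think", "ser", "pet", "normal"]
print({name: round(float(p), 3) for name, p in zip(modes, mode_probs[0])})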
846
+ @can_return_tuple
847
+ @auto_docstring
848
+ def forward(
849
+ self,
850
+ input_ids: Optional[torch.LongTensor] = None,
851
+ attention_mask: Optional[torch.Tensor] = None,
852
+ position_ids: Optional[torch.LongTensor] = None,
853
+ past_key_values: Optional[Cache] = None,
854
+ inputs_embeds: Optional[torch.FloatTensor] = None,
855
+ labels: Optional[torch.LongTensor] = None,
856
+ # Optional supervision for speech frames: float tensor [B, T_frames, n_mels]
857
+ speech_targets: Optional[torch.FloatTensor] = None,
858
+ use_cache: Optional[bool] = None,
859
+ cache_position: Optional[torch.LongTensor] = None,
860
+ logits_to_keep: Union[int, torch.Tensor] = 0,
861
+ return_reasoning_metadata: Optional[bool] = False,
862
+ **kwargs: Unpack[TransformersKwargs],
863
+ ) -> CausalLMOutputWithPast:
864
+ r"""
865
+ Enhanced HelpingAI forward pass with structured reasoning support.
866
+
867
+ Args:
868
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
869
+ Indices of input sequence tokens in the vocabulary.
870
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
871
+ Mask to avoid performing attention on padding token indices.
872
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
873
+ Indices of positions of each input sequence tokens in the position embeddings.
874
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
875
+ Pre-computed hidden-states that can be used to speed up autoregressive decoding.
876
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
877
+ Embedded representation of the input tokens. Can be used instead of `input_ids`.
878
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
879
+ Labels for computing the masked language modeling loss.
+ speech_targets (`torch.FloatTensor` of shape `(batch_size, num_frames, n_mels)`, *optional*):
+ Target mel-spectrogram frames used to supervise the optional speech output head.
880
+ use_cache (`bool`, *optional*):
881
+ If set to `True`, past key values are returned and can be used to speed up decoding.
882
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
883
+ Indices depicting the position of the input tokens in the sequence.
884
+ logits_to_keep (`Union[int, torch.Tensor]`, *optional*, defaults to 0):
885
+ If an `int`, keep only the logits for the last `logits_to_keep` tokens (0 keeps all); if a `torch.Tensor`, the indices of the positions to keep.
886
+ return_reasoning_metadata (`bool`, *optional*, defaults to `False`):
887
+ Whether to return reasoning metadata including SER and PET analysis for structured reasoning.
888
+
889
+ Returns:
890
+ `CausalLMOutputWithPast`: Model output containing logits, past key values, and optional reasoning metadata.
891
+
892
+ Example:
893
+
894
+ ```python
895
+ >>> from transformers import AutoTokenizer, HelpingAIForCausalLM
896
+
897
+ >>> model = HelpingAIForCausalLM.from_pretrained("HelpingAI/HelpingAI-8B")
898
+ >>> tokenizer = AutoTokenizer.from_pretrained("HelpingAI/HelpingAI-8B")
899
+
900
+ >>> # Standard generation
901
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
902
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
903
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
904
+ >>> response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
905
+
906
+ >>> # Structured reasoning generation
907
+ >>> outputs = model(inputs.input_ids, output_hidden_states=True, return_reasoning_metadata=True)
908
+ >>> reasoning_modes = model.get_reasoning_mode_probabilities(outputs.hidden_states[-1])
909
+ ```"""
910
+ outputs: BaseModelOutputWithPast = self.model(
911
+ input_ids=input_ids,
912
+ attention_mask=attention_mask,
913
+ position_ids=position_ids,
914
+ past_key_values=past_key_values,
915
+ inputs_embeds=inputs_embeds,
916
+ use_cache=use_cache,
917
+ cache_position=cache_position,
918
+ **kwargs,
919
+ )
920
+
921
+ hidden_states = outputs.last_hidden_state
922
+
923
+ # Standard language modeling head
924
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
925
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
926
+
927
+ # Enhanced structured output logits
928
+ structured_logits = None
929
+ reasoning_mode_probs = None
930
+ if self.use_structured_output:
931
+ structured_logits = self.structured_lm_head(hidden_states[:, slice_indices, :])
932
+ reasoning_mode_probs = self.get_reasoning_mode_probabilities(hidden_states)
933
+
934
+ # Speech output prediction
935
+ speech_mels = None
936
+ if self.use_speech_output:
937
+ token_level = hidden_states # [B, T_tok, H]
938
+ # Simple temporal upsampling by repetition to approximate frame rate
939
+ if getattr(self, "speech_upsample_factor", 1) > 1:
940
+ token_level = token_level.repeat_interleave(self.speech_upsample_factor, dim=1)
941
+ # Project to mel bins per (upsampled) time-step
942
+ speech_mels = self.speech_proj(token_level) # [B, T_frames, n_mels]
943
+
944
+ loss = None
945
+ if labels is not None:
946
+ # Standard loss computation
947
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
948
+
949
+ # Add structured output loss if applicable
950
+ if self.use_structured_output and structured_logits is not None:
951
+ # Auxiliary structured-reasoning loss; this assumes the labels are valid ids in the structured vocabulary (< structured_output_vocab_size)
952
+ structured_loss_weight = 0.1 # Weight for structured output loss
953
+ structured_loss = self.loss_function(
954
+ logits=structured_logits,
955
+ labels=labels,
956
+ vocab_size=self.structured_vocab_size,
957
+ **kwargs
958
+ )
959
+ loss = loss + (structured_loss_weight * structured_loss)
960
+
961
+ # Add speech supervision if provided
962
+ if self.use_speech_output and speech_targets is not None:
963
+ # Ensure time dimension alignment by trimming or padding speech_mels to targets
964
+ B, T_pred, M = speech_mels.shape
965
+ B2, T_tgt, M2 = speech_targets.shape
966
+ if B != B2 or M != M2:
967
+ raise ValueError("speech_targets shape mismatch. Expected [B, T, n_mels] with same B and n_mels as model output.")
968
+ if T_pred > T_tgt:
969
+ speech_mels_aligned = speech_mels[:, :T_tgt, :]
970
+ elif T_pred < T_tgt:
971
+ pad = torch.zeros(B, T_tgt - T_pred, M, device=speech_mels.device, dtype=speech_mels.dtype)
972
+ speech_mels_aligned = torch.cat([speech_mels, pad], dim=1)
973
+ else:
974
+ speech_mels_aligned = speech_mels
975
+
976
+ if self.speech_loss_type == "mse":
977
+ speech_loss = nn.functional.mse_loss(speech_mels_aligned, speech_targets)
978
+ else:
979
+ speech_loss = nn.functional.l1_loss(speech_mels_aligned, speech_targets)
980
+ loss = speech_loss if loss is None else (loss + speech_loss)
981
+
982
+ # Prepare output with enhanced reasoning metadata
983
+ output = CausalLMOutputWithPast(
984
+ loss=loss,
985
+ logits=logits,
986
+ past_key_values=outputs.past_key_values,
987
+ hidden_states=outputs.hidden_states,
988
+ attentions=outputs.attentions,
989
+ )
990
+
991
+ # Add custom attributes for reasoning
992
+ if return_reasoning_metadata and self.use_structured_output:
993
+ output.structured_logits = structured_logits
994
+ output.reasoning_mode_probabilities = reasoning_mode_probs
995
+ if self.use_speech_output:
996
+ output.speech_mels = speech_mels
997
+
998
+ return output
999
+
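A standalone sketch of the speech path used above, with made-up shapes: token hidden states are upsampled by repetition, projected to mel bins, trimmed or zero-padded to the target length, and supervised with an L1 loss.

import torch
from torch import nn

batch, seq_tok, hidden, n_mels, upsample = 2, 10, 512, 80, 4
hidden_states = torch.randn(batch, seq_tok, hidden)
speech_targets = torch.randn(batch, 37, n_mels)       # the target frame count need not match exactly

speech_proj = nn.Sequential(nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Linear(hidden // 2, n_mels))

token_level = hidden_states.repeat_interleave(upsample, dim=1)    # [2, 40, 512]
speech_mels = speech_proj(token_level)                            # [2, 40, 80]

T_pred, T_tgt = speech_mels.shape[1], speech_targets.shape[1]
if T_pred > T_tgt:
    speech_mels = speech_mels[:, :T_tgt, :]
elif T_pred < T_tgt:
    pad = torch.zeros(batch, T_tgt - T_pred, n_mels)
    speech_mels = torch.cat([speech_mels, pad], dim=1)

speech_loss = nn.functional.l1_loss(speech_mels, speech_targets)
print(speech_mels.shape, float(speech_loss) > 0)   # torch.Size([2, 37, 80]) True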
1000
+
1001
+ class HelpingAIForSequenceClassification(GenericForSequenceClassification, HelpingAIPreTrainedModel):
1002
+ pass
1003
+
1004
+
1005
+ class HelpingAIForTokenClassification(GenericForTokenClassification, HelpingAIPreTrainedModel):
1006
+ pass
1007
+
1008
+
1009
+ class HelpingAIForQuestionAnswering(GenericForQuestionAnswering, HelpingAIPreTrainedModel):
1010
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
1011
+
1012
+
1013
+ __all__ = [
1014
+ "HelpingAIForCausalLM",
1015
+ "HelpingAIForQuestionAnswering",
1016
+ "HelpingAIPreTrainedModel",
1017
+ "HelpingAIModel",
1018
+ "HelpingAIForSequenceClassification",
1019
+ "HelpingAIForTokenClassification",
1020
+ ]