feat: implement task_type_ids
Browse files- modeling_bert.py +10 -0
modeling_bert.py
CHANGED
|
@@ -340,14 +340,21 @@ class BertModel(BertPreTrainedModel):
|
|
| 340 |
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 341 |
self.encoder = BertEncoder(config)
|
| 342 |
self.pooler = BertPooler(config) if add_pooling_layer else None
|
|
|
|
| 343 |
|
| 344 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
def forward(
|
| 347 |
self,
|
| 348 |
input_ids,
|
| 349 |
position_ids=None,
|
| 350 |
token_type_ids=None,
|
|
|
|
| 351 |
attention_mask=None,
|
| 352 |
masked_tokens_mask=None,
|
| 353 |
):
|
|
@@ -359,6 +366,9 @@ class BertModel(BertPreTrainedModel):
|
|
| 359 |
hidden_states = self.embeddings(
|
| 360 |
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
| 361 |
)
|
|
|
|
|
|
|
|
|
|
| 362 |
# TD [2022-12-18]: Don't need to force residual in fp32
|
| 363 |
# BERT puts embedding LayerNorm before embedding dropout.
|
| 364 |
if not self.fused_dropout_add_ln:
|
|
|
|
| 340 |
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 341 |
self.encoder = BertEncoder(config)
|
| 342 |
self.pooler = BertPooler(config) if add_pooling_layer else None
|
| 343 |
+
self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
|
| 344 |
|
| 345 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
| 346 |
+
# We now initialize the task embeddings to 0; We do not use task types during
|
| 347 |
+
# pretraining. When we start using task types during embedding training,
|
| 348 |
+
# we want the model to behave exactly as in pretraining (i.e. task types
|
| 349 |
+
# have no effect).
|
| 350 |
+
self.task_type_embeddings.weight.data.fill_(0)
|
| 351 |
|
| 352 |
def forward(
|
| 353 |
self,
|
| 354 |
input_ids,
|
| 355 |
position_ids=None,
|
| 356 |
token_type_ids=None,
|
| 357 |
+
task_type_ids=None,
|
| 358 |
attention_mask=None,
|
| 359 |
masked_tokens_mask=None,
|
| 360 |
):
|
|
|
|
| 366 |
hidden_states = self.embeddings(
|
| 367 |
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
| 368 |
)
|
| 369 |
+
if task_type_ids is not None:
|
| 370 |
+
hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
|
| 371 |
+
|
| 372 |
# TD [2022-12-18]: Don't need to force residual in fp32
|
| 373 |
# BERT puts embedding LayerNorm before embedding dropout.
|
| 374 |
if not self.fused_dropout_add_ln:
|