fix: assert is None for other kwargs too
- modeling_bert.py: +0 -4
- modeling_for_glue.py: +5 -5
modeling_bert.py (changed)

@@ -379,16 +379,12 @@ class BertModel(BertPreTrainedModel):
         task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
-        head_mask=None,
     ):
         """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
         we only want the output for the masked tokens. This means that we only compute the last
         layer output for these tokens.
         masked_tokens_mask: (batch, seqlen), dtype=torch.bool
         """
-        if head_mask is not None:
-            raise NotImplementedError('Masking heads is not supported')
-
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
         )
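Since head_mask is removed from the signature entirely (rather than being checked at runtime), a caller that still passes it now fails at argument binding with a TypeError instead of the previous NotImplementedError. A minimal sketch of that behavior, using a stand-in forward function rather than the repo's actual method:

# Illustrative stand-in for the trimmed signature; not the repo's BertModel.forward.
def forward(input_ids, position_ids=None, token_type_ids=None,
            attention_mask=None, masked_tokens_mask=None):
    return input_ids

forward([[101, 102]])                      # supported kwargs only: fine
try:
    forward([[101, 102]], head_mask=None)  # no longer a parameter at all
except TypeError as err:
    print(err)                             # "... got an unexpected keyword argument 'head_mask'"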
modeling_for_glue.py (changed)

@@ -51,16 +51,16 @@ class BertForSequenceClassification(BertPreTrainedModel):
             return_dict if return_dict is not None else self.config.use_return_dict
         )
 
+        assert head_mask is None
+        assert inputs_embeds is None
+        assert output_attentions is None
+        assert output_hidden_states is None
+        assert return_dict is None
         outputs = self.bert(
             input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
         )
 
         pooled_output = outputs[1]
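The same pattern as a standalone sketch: Hugging Face-style kwargs that this implementation does not support are asserted to be None before the call into the backbone, so an unsupported argument fails loudly instead of being silently dropped. The class and backbone below are hypothetical placeholders, not the repo's code:

import torch.nn as nn

class SequenceClassifierSketch(nn.Module):
    """Hypothetical wrapper mirroring the guard pattern in modeling_for_glue.py."""

    def __init__(self, backbone: nn.Module, hidden_size: int, num_labels: int):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None,
                output_attentions=None, output_hidden_states=None, return_dict=None):
        # Kwargs the backbone does not implement: reject rather than silently ignore.
        assert head_mask is None
        assert inputs_embeds is None
        assert output_attentions is None
        assert output_hidden_states is None
        assert return_dict is None
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        pooled_output = outputs[1]  # second element is the pooled output, as in the diff above
        return self.classifier(pooled_output)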