|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""BERT encoder, optionally loading pre-trained checkpoints."""
|
|
|
|
import dataclasses
|
|
from typing import Optional
|
|
|
|
from absl import logging
|
|
from big_vision import utils
|
|
from big_vision.models import common
|
|
import flax
|
|
import flax.linen as nn
|
|
import jax.numpy as jnp
|
|
from tensorflow.io import gfile
|
|
|
|
from flaxformer.architectures.bert import bert
|
|
from flaxformer.architectures.bert import bert_checkpoint_converter
|
|
from flaxformer.architectures.bert import configs
|
|
|
|
|
|
class Model(nn.Module):
  """BERT encoder with linear projection on last layer CLS token."""

  config: str  # Selects the flaxformer BERT config: "base" or "large".
  num_classes: Optional[int] = None  # If set, adds a linear "head" on CLS.
  head_zeroinit: bool = True  # Zero-initialize the head's kernel.

  @nn.compact
  def __call__(self, text, *, train=False):
    """Encodes integer token ids `text` of shape [batch, len].

    Returns:
      A tuple `(x, out)` where `x` is the final feature (logits when
      `num_classes` is set, otherwise the CLS embedding) and `out` maps
      intermediate names ("transformed", "pre_logits", "logits") to arrays.
    """
    out = {}

    n, seq_len = text.shape
    # Look the config class up first, then instantiate only the one we need.
    config_cls = {
        "base": configs.BertBaseConfig,
        "large": configs.BertLargeConfig,
    }[self.config]
    encoder = bert.BertEncoder(**dataclasses.asdict(config_cls()))

    positions = jnp.tile(jnp.arange(0, seq_len, dtype=jnp.int32), [n, 1])
    segments = jnp.zeros([n, seq_len], dtype=jnp.int32)
    # Non-zero token ids count as real tokens; zeros are treated as padding.
    mask = text.astype(jnp.bool_).astype(jnp.int32)

    x = out["transformed"] = encoder(
        token_ids=text,
        position_ids=positions,
        segment_ids=segments,
        input_mask=mask,
        enable_dropout=train,
    )

    # The last-layer CLS token embedding is the pooled representation.
    x = out["pre_logits"] = x[:, 0]

    if self.num_classes:
      kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
      x = out["logits"] = nn.Dense(self.num_classes, name="head", **kw)(x)

    return x, out
|
|
|
|
|
|
def load(params, path, model_cfg=None, dont_load=()):
  """Returns `params` with BERT weights replaced from checkpoint at `path`.

  Args:
    params: PyTree of model parameters used as the merge target; must contain
      the "BertEncoder_0" subtree produced by `Model`.
    path: Directory holding an original TF BERT checkpoint
      (`bert_model.ckpt.*`), or otherwise a big_vision checkpoint path.
    model_cfg: Unused; kept for a uniform `load(params, path, ...)` interface.
    dont_load: Patterns of parameter names to keep from `params` instead of
      overwriting with checkpoint values (see `common.merge_params`).

  Returns:
    `params` with the checkpoint weights merged in.
  """
  del model_cfg

  checkpoint_path = f"{path}/bert_model.ckpt"
  if gfile.exists(f"{checkpoint_path}.index"):
    logging.info("Loading original BERT checkpoint from '%s'", checkpoint_path)
    params = flax.core.FrozenDict(params).unfreeze()
    # The model may use a shorter max sequence length than the checkpoint's
    # position-embedding table; remember the target length so we can truncate.
    max_len = (
        params["BertEncoder_0"]["embedder"]["embedders_position_ids"]
        ["embedding"].shape[0])
    bert_params, pooler_params = (
        bert_checkpoint_converter.load_params_from_tf_checkpoint(
            checkpoint_path=checkpoint_path))
    del pooler_params  # The pooler head is not part of this model.
    if isinstance(bert_params, flax.core.FrozenDict):
      bert_params = bert_params.unfreeze()
    bert_params["embedder"]["embedders_position_ids"]["embedding"] = (
        bert_params["embedder"]["embedders_position_ids"]["embedding"][:max_len]
    )
    return common.merge_params(
        {"BertEncoder_0": bert_params}, params, dont_load)

  logging.info(
      "Could not find original BERT checkpoint path '%s', "
      "loading big_vision checkpoint '%s'", checkpoint_path, path)
  restored_params = utils.load_params(path)
  return common.merge_params(restored_params, params, dont_load)
|
|
|