|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Load and run the PaliGemma model."""
|
|
import functools
|
|
import sys
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from absl import logging
|
|
|
|
|
|
import jax
|
|
from jax.sharding import Mesh, NamedSharding, PartitionSpec
|
|
import ml_collections
|
|
import numpy as np
|
|
|
|
import big_vision.models.proj.paligemma.gemma_bv
|
|
import big_vision.models.proj.paligemma.paligemma as model_mod
|
|
import big_vision.models.vit
|
|
import big_vision.pp.builder
|
|
import big_vision.pp.tokenizer
|
|
import big_vision.pp.ops_image
|
|
import big_vision.pp.ops_general
|
|
import big_vision.pp.ops_text
|
|
import big_vision.pp.proj.paligemma.ops
|
|
import big_vision.sharding
|
|
import big_vision.trainers.proj.paligemma.predict_fns
|
|
import big_vision.utils as u
|
|
|
|
|
|
|
|
# With the transfer guard set to "disallow", JAX raises on implicit
# host<->device transfers; explicit ones (jax.device_get, resharding helpers)
# remain allowed, so every data movement in this script must be deliberate.
jax.config.update("jax_transfer_guard", "disallow")

# Command-line flags. Positional arguments are (name, default, help).
CKPT = flags.DEFINE_string("ckpt", None, "Path to checkpoint.")
IMAGE = flags.DEFINE_string("image", None, "Path to input image.")
SAMPLER = flags.DEFINE_string(
    "sampler", "greedy", "Decoding strategy. Try `nucleus(0.1)`")
RES = flags.DEFINE_integer("res", 224, "Image resolution (224, 448, 896).")
MAX_DECODE_LEN = flags.DEFINE_integer(
    "max_decode_len", 128, "Max total generation steps.")
PREFILL_LEN = flags.DEFINE_integer(
    "prefill_len", 32,
    "Size of prefill (prompt). "
    "Shorter is faster, but too short will cut off your prompt.")

# Tokenizer spec string. It is used both to build the tokenizer and inside
# the `tok(...)` preprocessing ops, so both see the same vocabulary.
TOKENIZER = "gemma(tokensets=['loc', 'seg'])"
|
|
|
|
|
|
def load_model(ckpt):
  """Builds the PaliGemma model and loads its parameters.

  Args:
    ckpt: Checkpoint path, passed unchanged to `model_mod.load`.

  Returns:
    A `(model, params)` pair: the `model_mod.Model` instance built from the
    fixed config below, and whatever `model_mod.load` returns for `ckpt`.
  """
  vision_cfg = dict(variant="So400m/14", pool_type="none", scan=True)
  # 256_000 base entries plus 1024 + 128 extra entries. The extras presumably
  # back the 'loc' and 'seg' token sets named in TOKENIZER -- confirm.
  text_cfg = dict(vocab_size=256_000 + 1024 + 128)
  cfg = ml_collections.FrozenConfigDict({"img": vision_cfg, "llm": text_cfg})

  model = model_mod.Model(**cfg)
  return model, model_mod.load(None, ckpt, cfg)
|
|
|
|
|
|
def info(s, *a):
  """Logs `s` at INFO level behind a yellow "NOTE" tag, then flushes the log.

  Args:
    s: Log message; used as a %-style format string when `a` is non-empty.
    *a: Optional arguments for the format string, formatted lazily by logging.
  """
  # ANSI escape codes: switch to yellow for "NOTE", then reset.
  tag = "\u001b[33mNOTE\u001b[0m: "
  logging.info(tag + s, *a)
  # Flush so progress notes show up before long-running steps start.
  logging.flush()
|
|
|
|
|
|
def main(argv):
  """Loads PaliGemma and answers stdin prompts about a single image.

  Steps:
    1. Load the checkpoint given by --ckpt.
    2. FSDP-shard the parameters over all devices.
    3. Precompile decoding with a "caption en" prompt.
    4. Treat every line of stdin as a prompt for the image given by --image.
       The decoded text is printed to stderr.

  Args:
    argv: Command-line arguments forwarded by `app.run`. They are only
      logged, otherwise unused.
  """
  info(f"{argv=}")
  info("Loading model...")
  model, params = load_model(CKPT.value)
  predict_fns = big_vision.trainers.proj.paligemma.predict_fns.get_all(model)

  info("Loading tokenizer...")
  tokzr = big_vision.pp.tokenizer.get_tokenizer(TOKENIZER)

  info("Creating mesh and sharding params...")
  # 1-D mesh over all devices. The trailing comma matters: ("data") is just
  # the string "data", not a one-element tuple of axis names.
  mesh = Mesh(jax.devices(), ("data",))
  repl_sharding = NamedSharding(mesh, PartitionSpec())
  params_sharding = big_vision.sharding.infer_sharding(
      params, strategy=[(".*", "fsdp(axis='data')")], mesh=mesh)
  # Move every parameter leaf onto its inferred sharding.
  params = jax.tree.map(u.reshard, params, params_sharding)

  # Pipeline steps:
  #   - Decode and resize the image, and scale pixels to [-1, 1].
  #   - Tokenize the prompt, with BOS, followed by a "\n" separator.
  #   - Fit text, mask_ar and mask_input to PREFILL_LEN. They are presumably
  #     padded or truncated -- confirm against `tolen`.
  pp_fn = big_vision.pp.builder.get_preprocess_fn("|".join([
      f"decode|resize({RES.value})|value_range(-1, 1)",
      f"tok(key='prefix', bos='yes', model={repr(TOKENIZER)})",
      f"tok(key='septok', text='\\n', model={repr(TOKENIZER)})",
      'masked_concat(["prefix", "septok"], mask_ar=[0, 0], mask_input=[1, 1])',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="text")',
      f'tolen({PREFILL_LEN.value}, pad_value=1, key="mask_ar")',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)

  decode = functools.partial(
      predict_fns["decode"], devices=jax.devices(),
      eos_token=tokzr.eos_token, max_decode_len=MAX_DECODE_LEN.value,
      sampler=SAMPLER.value)

  def make_batch(fname, prompt):
    """Reads `fname`, preprocesses it with `prompt`, returns a batch of one."""
    # Use a context manager so the file handle is closed after every prompt.
    with open(fname, "rb") as f:
      image = f.read()
    example = pp_fn({"image": image, "prefix": np.array(prompt)})
    # Presumably flags this example as real (non-padding) for predict_fns.
    # TODO confirm.
    example["_mask"] = np.array(True)
    # Prepend a batch axis of size 1 to every leaf.
    batch = jax.tree.map(lambda x: x[None], example)
    # Replicate the single example across all devices.
    return u.reshard(batch, repl_sharding)

  info("Precompiling inference function...")
  decode({"params": params}, batch=make_batch(IMAGE.value, "caption en"))

  info("Type a prompt and press enter, for example 'caption en': ")
  for line in map(str.strip, sys.stdin):
    tokens = decode({"params": params}, batch=make_batch(IMAGE.value, line))
    # Explicit device->host copy; allowed under the "disallow" transfer guard.
    tokens = jax.device_get(tokens)[0]
    print(tokzr.to_str(tokens), file=sys.stderr, flush=True)
|
|
|
|
|
|
if __name__ == "__main__":
  # Neither --ckpt nor --image has a usable default, so absl rejects the run
  # at startup if either is missing.
  flags.mark_flags_as_required(["ckpt", "image"])
  app.run(main)
|
|
|