"""Big vision sharding utilities."""
|
|
|
|
from absl import logging
|
|
|
|
from big_vision.pp.registry import Registry
|
|
import big_vision.utils as u
|
|
import flax.linen as nn
|
|
import jax
|
|
import numpy as np
|
|
|
|
|
|
NamedSharding = jax.sharding.NamedSharding
|
|
P = jax.sharding.PartitionSpec
|
|
|
|
|
|
def _replicated(mesh):
|
|
return NamedSharding(mesh, P())
|
|
|
|
|
|
def _shard_along_axis(mesh, i, axis_name):
|
|
return NamedSharding(mesh, P(*((None,) * i + (axis_name,))))
|
|
|
|
|
|


def infer_sharding(params, strategy, mesh):
  """Infers `params` sharding based on strategy.

  Args:
    params: a pytree of arrays.
    strategy: a sequence of `(pattern, tactic)` pairs, where `pattern` is a
      regex matched against parameter names and `tactic` is a "|"-separated
      chain of registered `shardings.*` rules.
    mesh: jax device mesh.

  Returns:
    A pytree of shardings with the same structure as the `params` argument.
  """
  patterns, tactics = zip(*strategy)

  x_with_names, tree_def = u.tree_flatten_with_names(params)
  names = tree_def.unflatten(list(zip(*x_with_names))[0])

  # One boolean mask tree per pattern, marking which leaves each tactic
  # applies to.
  mask_trees = u.make_mask_trees(params, patterns)

  # Start from fully-replicated specs and let the rules refine them.
  specs = jax.tree.map(lambda x: (None,) * x.ndim, params)

  for mask_tree, tactic in zip(mask_trees, tactics):
    for op_str in tactic.split("|"):
      op = Registry.lookup(f"shardings.{op_str}")()
      specs = jax.tree.map(
          lambda x, n, match, spec, op=op: op(spec, mesh, n, x)
          if match else spec,
          params, names, mask_tree, specs,
          is_leaf=lambda v: isinstance(v, nn.Partitioned))

  # Turn the per-dimension spec tuples into PartitionSpecs and wrap them into
  # NamedShardings on the given mesh.
  specs = jax.tree.map(lambda _, spec: P(*spec), nn.unbox(params), specs)
  return jax.tree.map(lambda spec: NamedSharding(mesh, spec), specs)
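

# Illustrative usage (a sketch, not part of this module's API): a strategy
# pairs parameter-name regexes with registered "shardings.*" rules, using the
# registry's string syntax for rule arguments, e.g.
#
#   mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ("data",))
#   strategy = [(".*", "fsdp(axis='data')")]
#   shardings = infer_sharding(params, strategy, mesh)
#   params = jax.device_put(params, shardings)
#
# Multiple rules can be chained within a single tactic string using "|".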
@Registry.register("shardings.replicate")
|
|
def replicate():
|
|
"""Full replication sharding rule.
|
|
|
|
Note full replication is deafult, so this can be skipped and useful to
|
|
explicitly state in the config that certrain parameters are replicated.
|
|
TODO: can be generalized to support replication over a sub-mesh.
|
|
|
|
Returns:
|
|
A function that updates the sharding spec.
|
|
"""
|
|
def _update_spec(cur_spec, mesh, name, x):
|
|
del x, mesh
|
|
if not all(axis is None for axis in cur_spec):
|
|
raise ValueError(f"Inconsistent sharding instructions: "
|
|
f"parameter {name} has spec {cur_spec}, "
|
|
f"so it can't be fully replicated.")
|
|
return cur_spec
|
|
return _update_spec
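

# Illustrative use (a sketch): explicitly replicate norm/bias parameters while
# sharding everything else, e.g.
#
#   strategy = [(".*/(scale|bias)", "replicate"), (".*", "fsdp(axis='data')")]
#
# If an earlier rule in the same tactic chain has already sharded a matching
# parameter, `replicate` raises a ValueError rather than silently overriding.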
@Registry.register("shardings.fsdp")
|
|
def fsdp(axis, min_size_to_shard_mb=4):
|
|
"""FSDP sharding rule.
|
|
|
|
Shards the largest dimension that is not sharded already and is divisible
|
|
by the total device count.
|
|
|
|
Args:
|
|
axis: mesh axis name for FSDP, or a collection of names.
|
|
min_size_to_shard_mb: minimal tensor size to bother with sharding.
|
|
|
|
Returns:
|
|
A function that updates the sharding spec.
|
|
"""
|
|
axis = axis if isinstance(axis, str) else tuple(axis)
|
|
axis_tuple = axis if isinstance(axis, tuple) else (axis,)
|
|
def _update_spec(cur_spec, mesh, name, x):
|
|
shape = x.shape
|
|
axis_size = np.prod([mesh.shape[a] for a in axis_tuple])
|
|
|
|
if np.prod(shape) * x.dtype.itemsize <= min_size_to_shard_mb * (2 ** 20):
|
|
return cur_spec
|
|
|
|
|
|
idx = np.argsort(shape)[::-1]
|
|
for i in idx:
|
|
if shape[i] % axis_size == 0:
|
|
if cur_spec[i] is None:
|
|
return cur_spec[:i] + (axis,) + cur_spec[i+1:]
|
|
|
|
logging.info("Failed to apply `fsdp` rule to the parameter %s:%s, as all "
|
|
"its dimensions are not divisible by the requested axis: "
|
|
"%s:%i, or already occupied by other sharding rules: %s",
|
|
name, shape, axis, axis_size, cur_spec)
|
|
return cur_spec
|
|
return _update_spec
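

# Illustrative effect (a sketch): with a mesh axis "data" of size 8 and a
# float32 parameter of shape (4096, 1024), the all-replicated spec
# (None, None) becomes ("data", None), since 4096 is the largest dimension
# divisible by 8. Tensors at or below `min_size_to_shard_mb` are left
# untouched.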
@Registry.register("shardings.logical_partitioning")
|
|
def logical_partitioning():
|
|
"""Manual sharding based on Flax's logical partitioning annotations.
|
|
|
|
Uses logical sharding annotations added in model code with
|
|
`nn.with_logical_partitioning`. Respects logical to mesh name mapping rules
|
|
(typically defined in the dynamic context using
|
|
`with nn.logical_axis_rules(rules): ...`).
|
|
|
|
Returns:
|
|
A function that outputs the sharding spec of `nn.LogicallyPartitioned` boxed
|
|
specs.
|
|
"""
|
|
def _update_spec(cur_spec, mesh, name, x):
|
|
del x, name, mesh
|
|
if isinstance(cur_spec, nn.LogicallyPartitioned):
|
|
return nn.logical_to_mesh_axes(cur_spec.names)
|
|
return cur_spec
|
|
return _update_spec
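

# Illustrative use (a sketch): in model code, annotate parameters via
# nn.with_logical_partitioning(init_fn, ("embed", "mlp")), then evaluate the
# strategy under rules mapping logical names to mesh axes, e.g.
#
#   with nn.logical_axis_rules([("embed", None), ("mlp", "data")]):
#     shardings = infer_sharding(
#         params, [(".*", "logical_partitioning")], mesh)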
@Registry.register("shardings.shard_dim")
|
|
def shard_dim(axis, dim, ignore_ndim_error=False):
|
|
"""Shards the given dimension along the given axis.
|
|
|
|
Args:
|
|
axis: mesh axis name for sharding.
|
|
dim: dimension to shard (can be negative).
|
|
ignore_ndim_error: if True, a warning error is logged instead of raising an
|
|
exception when the given dimension is not compatible with the number of
|
|
dimensions of the array.
|
|
|
|
Returns:
|
|
A function that updates the sharding spec.
|
|
"""
|
|
def _update_spec(cur_spec, mesh, name, x):
|
|
del mesh, x
|
|
if np.abs(dim) >= len(cur_spec):
|
|
msg = f"Cannot shard_dim({axis}, {dim}): name={name} cur_spec={cur_spec}"
|
|
if ignore_ndim_error:
|
|
logging.warning(msg)
|
|
return cur_spec
|
|
else:
|
|
raise ValueError(msg)
|
|
pos_dim = dim
|
|
if pos_dim < 0:
|
|
pos_dim += len(cur_spec)
|
|
if cur_spec[pos_dim] is not None:
|
|
raise ValueError(
|
|
f"Already sharded: shard_dim({axis}, {dim}):"
|
|
f" name={name} cur_spec={cur_spec}"
|
|
)
|
|
new_spec = cur_spec[:pos_dim] + (axis,) + cur_spec[pos_dim + 1 :]
|
|
return new_spec
|
|
|
|
return _update_spec
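

# Illustrative use (a sketch): shard the last dimension of every parameter
# along a hypothetical "model" mesh axis, tolerating 1-D parameters such as
# biases:
#
#   strategy = [(".*", "shard_dim('model', -1, ignore_ndim_error=True)")]
#
# For a 2-D kernel with spec (None, None) this yields (None, "model"); for a
# 1-D bias the rule only logs a warning and leaves the spec unchanged.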