Tiny NNX
A pedagogical implementation of the NNX core APIs.
Core API
import dataclasses
import hashlib
import typing as tp

import jax
import jax.numpy as jnp
from jax import random

A = tp.TypeVar("A")
M = tp.TypeVar("M", bound="Module")
Sharding = tp.Tuple[tp.Optional[str], ...]
Array = jax.Array


class Variable(tp.Generic[A]):
    def __init__(
        self,
        value: A,
        *,
        sharding: tp.Optional[Sharding] = None,
    ):
        self.value = value
        self.sharding = sharding

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(value={self.value}, sharding={self.sharding})"
        )

    def __init_subclass__(cls):
        super().__init_subclass__()
        jax.tree_util.register_pytree_node(
            cls,
            lambda x: ((x.value,), (x.sharding,)),
            lambda metadata, value: cls(value[0], sharding=metadata[0]),
        )


class State(dict[str, Variable[tp.Any]]):
    def extract(self, variable_type: tp.Type[Variable]) -> "State":
        return State(
            {
                path: variable
                for path, variable in self.items()
                if isinstance(variable, variable_type)
            }
        )

    def __repr__(self) -> str:
        elems = ",\n ".join(
            f"'{path}': {variable}".replace("\n", "\n ")
            for path, variable in self.items()
        )
        return f"State({{\n {elems}\n}})"


jax.tree_util.register_pytree_node(
    State,
    # in reality, values and paths should be sorted by path
    lambda x: (tuple(x.values()), tuple(x.keys())),
    lambda paths, values: State(dict(zip(paths, values))),
)
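Because every Variable subclass registers itself as a pytree node, a State holding such variables is transparent to jax.tree.map and the other tree utilities. A quick sanity check (my own illustration; the Count subclass is a throwaway defined only for this snippet):

class Count(Variable[jax.Array]):
    pass

counts = State({"a/count": Count(jnp.array(1)), "b/count": Count(jnp.array(2))})
# tree.map reaches through State and Count down to the raw arrays
doubled = jax.tree.map(lambda x: x * 2, counts)
print(doubled)  # Count values are now 2 and 4, paths are preserved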
@dataclasses.dataclass
class GraphDef(tp.Generic[M]):
    type: tp.Type[M]
    index: int
    submodules: dict[str, tp.Union["GraphDef[Module]", int]]
    static_fields: dict[str, tp.Any]

    def merge(self, state: State) -> M:
        module = GraphDef._build_module_recursive(self, {})
        module.update(state)
        return module

    @staticmethod
    def _build_module_recursive(
        graphdef: tp.Union["GraphDef[M]", int],
        index_to_module: dict[int, "Module"],
    ) -> M:
        if isinstance(graphdef, int):
            return index_to_module[graphdef]  # type: ignore

        assert graphdef.index not in index_to_module

        # add a dummy module to the index to avoid infinite recursion
        module = object.__new__(graphdef.type)
        index_to_module[graphdef.index] = module

        submodules = {
            name: GraphDef._build_module_recursive(submodule, index_to_module)
            for name, submodule in graphdef.submodules.items()
        }

        vars(module).update(graphdef.static_fields)
        vars(module).update(submodules)
        return module

    def apply(
        self, state: State
    ) -> tp.Callable[..., tuple[tp.Any, tuple[State, "GraphDef[M]"]]]:
        def _apply(*args, **kwargs):
            module = self.merge(state)
            out = module(*args, **kwargs)  # type: ignore
            return out, module.split()

        return _apply
class Module:
    def split(self: M) -> tp.Tuple[State, GraphDef[M]]:
        state = State()
        graphdef = Module._partition_recursive(
            module=self, module_id_to_index={}, path_parts=(), state=state
        )
        assert isinstance(graphdef, GraphDef)
        return state, graphdef

    @staticmethod
    def _partition_recursive(
        module: M,
        module_id_to_index: dict[int, int],
        path_parts: tp.Tuple[str, ...],
        state: State,
    ) -> tp.Union[GraphDef[M], int]:
        if id(module) in module_id_to_index:
            return module_id_to_index[id(module)]

        index = len(module_id_to_index)
        module_id_to_index[id(module)] = index

        submodules = {}
        static_fields = {}

        # iterate fields sorted by name to ensure deterministic order
        for name, value in sorted(vars(module).items(), key=lambda x: x[0]):
            value_path = (*path_parts, name)
            # if value is a Module, recurse
            if isinstance(value, Module):
                submoduledef = Module._partition_recursive(
                    value, module_id_to_index, value_path, state
                )
                submodules[name] = submoduledef
            # if value is a Variable, add to state
            elif isinstance(value, Variable):
                state["/".join(value_path)] = value
            else:  # otherwise, add to graphdef fields
                static_fields[name] = value

        return GraphDef(
            type=type(module),
            index=index,
            submodules=submodules,
            static_fields=static_fields,
        )

    def update(self, state: State) -> None:
        for path, value in state.items():
            path_parts = path.split("/")
            Module._set_value_at_path(self, path_parts, value)

    @staticmethod
    def _set_value_at_path(
        module: "Module", path_parts: tp.Sequence[str], value: Variable[tp.Any]
    ) -> None:
        if len(path_parts) == 1:
            setattr(module, path_parts[0], value)
        else:
            Module._set_value_at_path(
                getattr(module, path_parts[0]), path_parts[1:], value
            )
@dataclasses.dataclass
class Rngs:
    key: jax.Array
    count: int = 0
    count_path: tuple[int, ...] = ()

    def fork(self) -> "Rngs":
        """Forks the context, guaranteeing that all the random numbers generated
        will be different from the ones generated in the original context. Fork is
        used to create a new Rngs that can be passed to a JAX transform."""
        count_path = self.count_path + (self.count,)
        self.count += 1
        return Rngs(self.key, count_path=count_path)

    def make_rng(self) -> jax.Array:
        fold_data = self._stable_hash(self.count_path + (self.count,))
        self.count += 1
        return random.fold_in(self.key, fold_data)  # type: ignore

    @staticmethod
    def _stable_hash(data: tuple[int, ...]) -> int:
        hash_str = " ".join(str(x) for x in data)
        _hash = hashlib.blake2s(hash_str.encode())
        hash_bytes = _hash.digest()
        # uint32 is represented as 4 bytes in big endian
        return int.from_bytes(hash_bytes[:4], byteorder="big")


# in the real NNX Rngs is not a pytree, instead
# it has a split/merge API similar to Module
# but for simplicity we use a pytree here
jax.tree_util.register_pytree_node(
    Rngs,
    lambda x: ((x.key,), (x.count, x.count_path)),
    lambda metadata, value: Rngs(value[0], *metadata),
)
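Even though Rngs mutates a Python counter, key generation stays deterministic: make_rng folds a stable hash of (count_path, count) into the base key, and fork starts a fresh count_path so keys drawn inside a transform never collide with keys drawn outside it. A small check (my own illustration, not part of the original listing):

rngs = Rngs(random.key(0))
k1 = rngs.make_rng()  # folds the hash of (0,) into the base key
k2 = rngs.make_rng()  # folds the hash of (1,) -> a different key
assert not jnp.array_equal(random.key_data(k1), random.key_data(k2))

forked = rngs.fork()  # count_path=(2,), safe to pass into vmap/scan
assert not jnp.array_equal(
    random.key_data(forked.make_rng()), random.key_data(rngs.make_rng())
)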
Basic Layers
class Param(Variable[A]):
    pass


class BatchStat(Variable[A]):
    pass


class Linear(Module):
    def __init__(self, din: int, dout: int, *, rngs: Rngs):
        self.din = din
        self.dout = dout
        key = rngs.make_rng()
        self.w = Param(random.uniform(key, (din, dout)))
        self.b = Param(jnp.zeros((dout,)))

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.w.value + self.b.value


class BatchNorm(Module):
    def __init__(self, din: int, mu: float = 0.95):
        self.mu = mu
        self.scale = Param(jax.numpy.ones((din,)))
        self.bias = Param(jax.numpy.zeros((din,)))
        self.mean = BatchStat(jax.numpy.zeros((din,)))
        self.var = BatchStat(jax.numpy.ones((din,)))

    def __call__(self, x, train: bool) -> jax.Array:
        if train:
            axis = tuple(range(x.ndim - 1))
            mean = jax.numpy.mean(x, axis=axis)
            var = jax.numpy.var(x, axis=axis)
            # ema update
            self.mean.value = self.mu * self.mean.value + (1 - self.mu) * mean
            self.var.value = self.mu * self.var.value + (1 - self.mu) * var
        else:
            mean, var = self.mean.value, self.var.value

        scale, bias = self.scale.value, self.bias.value
        x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias
        return x


class Dropout(Module):
    def __init__(self, rate: float):
        self.rate = rate

    def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:
        if train:
            mask = random.bernoulli(rngs.make_rng(), (1 - self.rate), x.shape)
            x = x * mask / (1 - self.rate)
        return x
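These layers already work with the core API on their own. A minimal round trip (my own illustration; the shapes are arbitrary): construct a Linear, split it into State plus GraphDef, and rebuild an equivalent module with merge.

linear = Linear(2, 3, rngs=Rngs(random.key(0)))
state, graphdef = linear.split()
print(jax.tree.map(jnp.shape, state))  # shapes: 'b' -> (3,), 'w' -> (2, 3)

rebuilt = graphdef.merge(state)  # a new Linear sharing the same Variable objects
y = rebuilt(jnp.ones((1, 2)))    # shape (1, 3)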
Scan Over Layers Example
class Block(Module):
    def __init__(self, din: int, dout: int, *, rngs: Rngs):
        self.linear = Linear(din, dout, rngs=rngs)
        self.bn = BatchNorm(dout)
        self.dropout = Dropout(0.1)

    def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:
        x = self.linear(x)
        x = self.bn(x, train=train)
        x = jax.nn.gelu(x)
        x = self.dropout(x, train=train, rngs=rngs)
        return x


class ScanMLP(Module):
    def __init__(self, hidden_size: int, n_layers: int, *, rngs: Rngs):
        self.n_layers = n_layers
        # lift init
        key = random.split(rngs.make_rng(), n_layers - 1)
        graphdef: GraphDef[Block] = None  # type: ignore

        def init_fn(key):
            nonlocal graphdef
            state, graphdef = Block(
                hidden_size, hidden_size, rngs=Rngs(key)
            ).split()
            return state

        state = jax.vmap(init_fn)(key)
        self.layers = graphdef.merge(state)
        self.linear = Linear(hidden_size, hidden_size, rngs=rngs)

    def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:
        # lift call
        key: jax.Array = random.split(rngs.make_rng(), self.n_layers - 1)  # type: ignore
        state, graphdef = self.layers.split()

        def scan_fn(x, inputs: tuple[jax.Array, State]):
            key, state = inputs
            x, (state, _) = graphdef.apply(state)(x, train=train, rngs=Rngs(key))
            return x, state

        x, state = jax.lax.scan(scan_fn, x, (key, state))
        self.layers.update(state)

        x = self.linear(x)
        return x
module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.key(0)))
x = jax.random.normal(random.key(0), (2, 10))
y = module(x, train=True, rngs=Rngs(random.key(1)))
state, graphdef = module.split()
print("state =", jax.tree.map(jnp.shape, state))
print("graphdef =", graphdef)
state = State({
'layers/bn/bias': Param(value=(4, 10), sharding=None),
'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),
'layers/bn/scale': Param(value=(4, 10), sharding=None),
'layers/bn/var': BatchStat(value=(4, 10), sharding=None),
'layers/linear/b': Param(value=(4, 10), sharding=None),
'layers/linear/w': Param(value=(4, 10, 10), sharding=None),
'linear/b': Param(value=(10,), sharding=None),
'linear/w': Param(value=(10, 10), sharding=None)
})
graphdef = GraphDef(type=<class '__main__.ScanMLP'>, index=0, submodules={'layers': GraphDef(type=<class '__main__.Block'>, index=1, submodules={'bn': GraphDef(type=<class '__main__.BatchNorm'>, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': GraphDef(type=<class '__main__.Dropout'>, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': GraphDef(type=<class '__main__.Linear'>, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': GraphDef(type=<class '__main__.Linear'>, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})
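The same split/apply pattern extends to other transforms. Below is a sketch of a jitted forward pass (my own illustration, not part of the original example): the functional State/GraphDef pair crosses the jax.jit boundary, and the updated state is synced back into the stateful module afterwards.

state, graphdef = module.split()

@jax.jit
def forward(state: State, x: jax.Array, rngs: Rngs):
    y, (state, _) = graphdef.apply(state)(x, train=True, rngs=rngs)
    return y, state

y, state = forward(state, x, Rngs(random.key(2)))
module.update(state)  # push the new batch stats back into the mutable module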
State Filtering
# split
params = state.extract(Param)
batch_stats = state.extract(BatchStat)
# merge
state = State({**params, **batch_stats})
print("params =", jax.tree.map(jnp.shape, params))
print("batch_stats =", jax.tree.map(jnp.shape, batch_stats))
params = State({
'layers/bn/bias': Param(value=(4, 10), sharding=None),
'layers/bn/scale': Param(value=(4, 10), sharding=None),
'layers/linear/b': Param(value=(4, 10), sharding=None),
'layers/linear/w': Param(value=(4, 10, 10), sharding=None),
'linear/b': Param(value=(10,), sharding=None),
'linear/w': Param(value=(10, 10), sharding=None)
})
batch_stats = State({
'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),
'layers/bn/var': BatchStat(value=(4, 10), sharding=None)
})
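Filtering makes it easy to treat the two collections differently under a transform. The sketch below (my own illustration; the mean-squared-error loss and the all-ones target are placeholders) differentiates only with respect to the Param state, while the updated BatchStat state is threaded through as auxiliary output:

def loss_fn(params: State, batch_stats: State, x: jax.Array, y_target: jax.Array):
    full_state = State({**params, **batch_stats})
    y, (new_state, _) = graphdef.apply(full_state)(
        x, train=True, rngs=Rngs(random.key(3))
    )
    loss = jnp.mean((y - y_target) ** 2)
    return loss, new_state.extract(BatchStat)

grads, batch_stats = jax.grad(loss_fn, has_aux=True)(
    params, batch_stats, x, jnp.ones((2, 10))
)
print("grads =", jax.tree.map(jnp.shape, grads))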