Tiny NNX#

A pedagogical implementation of the NNX core APIs.

Core API#

import dataclasses
import hashlib
import typing as tp

import jax
import jax.numpy as jnp
from jax import random

A = tp.TypeVar("A")
M = tp.TypeVar("M", bound="Module")
Sharding = tp.Tuple[tp.Optional[str], ...]
Array = jax.Array


class Variable(tp.Generic[A]):

  def __init__(
      self,
      value: A,
      *,
      sharding: tp.Optional[Sharding] = None,
  ):
    self.value = value
    self.sharding = sharding

  def __repr__(self) -> str:
    return (
        f"{type(self).__name__}(value={self.value}, sharding={self.sharding})"
    )

  def __init_subclass__(cls):
    super().__init_subclass__()
    jax.tree_util.register_pytree_node(
        cls,
        lambda x: ((x.value,), (x.sharding,)),
        lambda metadata, value: cls(value[0], sharding=metadata[0]),
    )
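
A quick check (the Count subclass here is hypothetical, for illustration only): because __init_subclass__ registers every Variable subclass as a pytree with sharding carried as static metadata, jax.tree utilities preserve both the subclass type and the sharding.

class Count(Variable[jax.Array]):  # hypothetical subclass for illustration
  pass


v = Count(jnp.array(1.0), sharding=("data",))
doubled = jax.tree.map(lambda x: 2 * x, v)
assert isinstance(doubled, Count) and doubled.sharding == ("data",)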


class State(dict[str, Variable[tp.Any]]):

  def extract(self, variable_type: tp.Type[Variable]) -> "State":
    return State(
        {
            path: variable
            for path, variable in self.items()
            if isinstance(variable, variable_type)
        }
    )

  def __repr__(self) -> str:
    elems = ",\n  ".join(
        f"'{path}': {variable}".replace("\n", "\n    ")
        for path, variable in self.items()
    )
    return f"State({{\n  {elems}\n}})"


jax.tree_util.register_pytree_node(
    State,
    # in reality, values and paths should be sorted by path
    lambda x: (tuple(x.values()), tuple(x.keys())),
    lambda paths, values: State(dict(zip(paths, values))),
)
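
As a small sketch (the Leaf subclass is ad-hoc, for this demo only): since the paths are stored as static metadata, mapping over a State transforms only the variable values, while the keys and variable types survive the round trip.

class Leaf(Variable[jax.Array]):  # ad-hoc Variable subclass for this sketch
  pass


shapes = jax.tree.map(jnp.shape, State({"layer/w": Leaf(jnp.ones((2, 3)))}))
print(shapes)
# State({
#   'layer/w': Leaf(value=(2, 3), sharding=None)
# })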


@dataclasses.dataclass
class GraphDef(tp.Generic[M]):
  type: tp.Type[M]
  index: int
  submodules: dict[str, tp.Union["GraphDef[Module]", int]]
  static_fields: dict[str, tp.Any]

  def merge(self, state: State) -> M:
    module = GraphDef._build_module_recursive(self, {})
    module.update(state)
    return module

  @staticmethod
  def _build_module_recursive(
      graphdef: tp.Union["GraphDef[M]", int],
      index_to_module: dict[int, "Module"],
  ) -> M:
    if isinstance(graphdef, int):
      return index_to_module[graphdef]  # type: ignore

    assert graphdef.index not in index_to_module

    # add a dummy module to the index to avoid infinite recursion
    module = object.__new__(graphdef.type)
    index_to_module[graphdef.index] = module

    submodules = {
        name: GraphDef._build_module_recursive(submodule, index_to_module)
        for name, submodule in graphdef.submodules.items()
    }
    vars(module).update(graphdef.static_fields)
    vars(module).update(submodules)
    return module

  def apply(
      self, state: State
  ) -> tp.Callable[..., tuple[tp.Any, tuple[State, "GraphDef[M]"]]]:
    def _apply(*args, **kwargs):
      module = self.merge(state)
      out = module(*args, **kwargs)  # type: ignore
      return out, module.split()

    return _apply


class Module:

  def split(self: M) -> tp.Tuple[State, GraphDef[M]]:
    state = State()
    graphdef = Module._partition_recursive(
        module=self, module_id_to_index={}, path_parts=(), state=state
    )
    assert isinstance(graphdef, GraphDef)
    return state, graphdef

  @staticmethod
  def _partition_recursive(
      module: M,
      module_id_to_index: dict[int, int],
      path_parts: tp.Tuple[str, ...],
      state: State,
  ) -> tp.Union[GraphDef[M], int]:
    if id(module) in module_id_to_index:
      return module_id_to_index[id(module)]

    index = len(module_id_to_index)
    module_id_to_index[id(module)] = index

    submodules = {}
    static_fields = {}

    # iterate fields sorted by name to ensure deterministic order
    for name, value in sorted(vars(module).items(), key=lambda x: x[0]):
      value_path = (*path_parts, name)
      # if value is a Module, recurse
      if isinstance(value, Module):
        submoduledef = Module._partition_recursive(
            value, module_id_to_index, value_path, state
        )
        submodules[name] = submoduledef
      # if value is a Variable, add to state
      elif isinstance(value, Variable):
        state["/".join(value_path)] = value
      else:  # otherwise, add to graphdef fields
        static_fields[name] = value

    return GraphDef(
        type=type(module),
        index=index,
        submodules=submodules,
        static_fields=static_fields,
    )

  def update(self, state: State) -> None:
    for path, value in state.items():
      path_parts = path.split("/")
      Module._set_value_at_path(self, path_parts, value)

  @staticmethod
  def _set_value_at_path(
      module: "Module", path_parts: tp.Sequence[str], value: Variable[tp.Any]
  ) -> None:
    if len(path_parts) == 1:
      setattr(module, path_parts[0], value)
    else:
      Module._set_value_at_path(
          getattr(module, path_parts[0]), path_parts[1:], value
      )
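
The integer indices in GraphDef are what make shared (aliased) submodules work: the first visit stores a full GraphDef, and every later visit stores only the index. A sketch of this behavior (the Weight, Child, and Parent names are made up for the demo):

class Weight(Variable[jax.Array]):  # ad-hoc Variable subclass for the demo
  pass


class Child(Module):

  def __init__(self):
    self.w = Weight(jnp.ones((3,)))


class Parent(Module):

  def __init__(self):
    shared = Child()
    self.left = shared
    self.right = shared  # second reference to the same object


state, graphdef = Parent().split()
print(list(state))  # ['left/w'] -- the shared variable is stored only once
module = graphdef.merge(state)
assert module.left is module.right  # aliasing is restored from the index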


@dataclasses.dataclass
class Rngs:
  key: jax.Array
  count: int = 0
  count_path: tuple[int, ...] = ()

  def fork(self) -> "Rngs":
    """Forks the context, guaranteeing that all the random numbers generated
    will be different from the ones generated in the original context. Fork is
    used to create a new Rngs that can be passed to a JAX transform"""
    count_path = self.count_path + (self.count,)
    self.count += 1
    return Rngs(self.key, count_path=count_path)

  def make_rng(self) -> jax.Array:
    fold_data = self._stable_hash(self.count_path + (self.count,))
    self.count += 1
    return random.fold_in(self.key, fold_data)  # type: ignore

  @staticmethod
  def _stable_hash(data: tuple[int, ...]) -> int:
    hash_str = " ".join(str(x) for x in data)
    _hash = hashlib.blake2s(hash_str.encode())
    hash_bytes = _hash.digest()
    # uint32 is represented as 4 bytes in big endian
    return int.from_bytes(hash_bytes[:4], byteorder="big")


# in the real NNX Rngs is not a pytree, instead
# it has a split/merge API similar to Module
# but for simplicity we use a pytree here
jax.tree_util.register_pytree_node(
    Rngs,
    lambda x: ((x.key,), (x.count, x.count_path)),
    lambda metadata, value: Rngs(value[0], *metadata),
)
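
A quick trace of how fork keeps streams disjoint (the comments show the count paths folded in, derived from the code above): the parent only ever folds in length-1 paths here, while every forked child extends the path, so distinct streams fold in distinct tuples.

rngs = Rngs(random.key(0))
k0 = rngs.make_rng()   # folds in the hash of count path (0,)
child = rngs.fork()    # the child carries count_path == (1,)
k1 = child.make_rng()  # folds in the hash of (1, 0)
k2 = rngs.make_rng()   # folds in the hash of (2,)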

Basic Layers#

class Param(Variable[A]):
  pass


class BatchStat(Variable[A]):
  pass


class Linear(Module):

  def __init__(self, din: int, dout: int, *, rngs: Rngs):
    self.din = din
    self.dout = dout
    key = rngs.make_rng()
    self.w = Param(random.uniform(key, (din, dout)))
    self.b = Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array) -> jax.Array:
    return x @ self.w.value + self.b.value


class BatchNorm(Module):

  def __init__(self, din: int, mu: float = 0.95):
    self.mu = mu
    self.scale = Param(jnp.ones((din,)))
    self.bias = Param(jnp.zeros((din,)))
    self.mean = BatchStat(jnp.zeros((din,)))
    self.var = BatchStat(jnp.ones((din,)))

  def __call__(self, x: jax.Array, train: bool) -> jax.Array:
    if train:
      axis = tuple(range(x.ndim - 1))
      mean = jnp.mean(x, axis=axis)
      var = jnp.var(x, axis=axis)
      # exponential moving average update of the running statistics
      self.mean.value = self.mu * self.mean.value + (1 - self.mu) * mean
      self.var.value = self.mu * self.var.value + (1 - self.mu) * var
    else:
      mean, var = self.mean.value, self.var.value

    scale, bias = self.scale.value, self.bias.value
    x = (x - mean) / jnp.sqrt(var + 1e-5) * scale + bias
    return x


class Dropout(Module):

  def __init__(self, rate: float):
    self.rate = rate

  def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:
    if train:
      mask = random.bernoulli(rngs.make_rng(), (1 - self.rate), x.shape)
      x = x * mask / (1 - self.rate)
    return x
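
A small usage sketch (shapes picked arbitrarily): constructing a layer consumes keys from Rngs, and split routes each attribute either into the State (variables) or into the GraphDef (static fields such as din and dout).

layer = Linear(2, 3, rngs=Rngs(random.key(0)))
state, graphdef = layer.split()
print(jax.tree.map(jnp.shape, state))
# State({
#   'b': Param(value=(3,), sharding=None),
#   'w': Param(value=(2, 3), sharding=None)
# })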

Scan Over Layers Example#

class Block(Module):

  def __init__(self, din: int, dout: int, *, rngs: Rngs):
    self.linear = Linear(din, dout, rngs=rngs)
    self.bn = BatchNorm(dout)
    self.dropout = Dropout(0.1)

  def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:
    x = self.linear(x)
    x = self.bn(x, train=train)
    x = jax.nn.gelu(x)
    x = self.dropout(x, train=train, rngs=rngs)
    return x


class ScanMLP(Module):

  def __init__(self, hidden_size: int, n_layers: int, *, rngs: Rngs):
    self.n_layers = n_layers

    # lift init: one key per scanned Block (the trailing Linear below
    # makes up the n-th layer)
    key = random.split(rngs.make_rng(), n_layers - 1)
    graphdef: GraphDef[Block] = None  # type: ignore

    def init_fn(key):
      nonlocal graphdef
      state, graphdef = Block(
          hidden_size, hidden_size, rngs=Rngs(key)
      ).split()
      return state

    state = jax.vmap(init_fn)(key)
    self.layers = graphdef.merge(state)
    self.linear = Linear(hidden_size, hidden_size, rngs=rngs)

  def __call__(self, x: jax.Array, *, train: bool, rngs: Rngs) -> jax.Array:
    # lift call
    key: jax.Array = random.split(rngs.make_rng(), self.n_layers - 1)  # type: ignore
    state, graphdef = self.layers.split()

    def scan_fn(x, inputs: tuple[jax.Array, State]):
      key, state = inputs
      x, (state, _) = graphdef.apply(state)(x, train=train, rngs=Rngs(key))
      return x, state

    x, state = jax.lax.scan(scan_fn, x, (key, state))
    self.layers.update(state)
    x = self.linear(x)
    return x

module = ScanMLP(hidden_size=10, n_layers=5, rngs=Rngs(random.key(0)))
x = jax.random.normal(random.key(0), (2, 10))
y = module(x, train=True, rngs=Rngs(random.key(1)))

state, graphdef = module.split()
print("state =", jax.tree.map(jnp.shape, state))
print("graphdef =", graphdef)
state = State({
  'layers/bn/bias': Param(value=(4, 10), sharding=None),
  'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),
  'layers/bn/scale': Param(value=(4, 10), sharding=None),
  'layers/bn/var': BatchStat(value=(4, 10), sharding=None),
  'layers/linear/b': Param(value=(4, 10), sharding=None),
  'layers/linear/w': Param(value=(4, 10, 10), sharding=None),
  'linear/b': Param(value=(10,), sharding=None),
  'linear/w': Param(value=(10, 10), sharding=None)
})
graphdef = GraphDef(type=<class '__main__.ScanMLP'>, index=0, submodules={'layers': GraphDef(type=<class '__main__.Block'>, index=1, submodules={'bn': GraphDef(type=<class '__main__.BatchNorm'>, index=2, submodules={}, static_fields={'mu': 0.95}), 'dropout': GraphDef(type=<class '__main__.Dropout'>, index=3, submodules={}, static_fields={'rate': 0.1}), 'linear': GraphDef(type=<class '__main__.Linear'>, index=4, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={}), 'linear': GraphDef(type=<class '__main__.Linear'>, index=5, submodules={}, static_fields={'din': 10, 'dout': 10})}, static_fields={'n_layers': 5})

State Filtering#

# split
params = state.extract(Param)
batch_stats = state.extract(BatchStat)
# merge
state = State({**params, **batch_stats})

print("params =", jax.tree.map(jnp.shape, params))
print("batch_stats =", jax.tree.map(jnp.shape, batch_stats))
params = State({
  'layers/bn/bias': Param(value=(4, 10), sharding=None),
  'layers/bn/scale': Param(value=(4, 10), sharding=None),
  'layers/linear/b': Param(value=(4, 10), sharding=None),
  'layers/linear/w': Param(value=(4, 10, 10), sharding=None),
  'linear/b': Param(value=(10,), sharding=None),
  'linear/w': Param(value=(10, 10), sharding=None)
})
batch_stats = State({
  'layers/bn/mean': BatchStat(value=(4, 10), sharding=None),
  'layers/bn/var': BatchStat(value=(4, 10), sharding=None)
})