从 NNX 0.10 到 NNX 0.11#

在本指南中,我们将介绍将 Flax NNX 代码从 Flax 版本 0.10.x 更新到 0.11.x 时所需的代码更改。

在 NNX 变换中使用 Rng#

现在,使用 RNG 的 NNX 层(如 Dropout 或 MultiHeadAttention)会持有一个在构造时提供的 Rngs 对象的 fork 副本,而不是对原始 Rngs 对象的共享引用。这有两个后果: * 它改变了检查点的结构,因为每个层都将拥有唯一的 RNG 状态。 * 它改变了 nnx.split_rngsnnx.vmapnnx.scan 等变换的交互方式,

因为生成的 RNG 状态现在将不会以标量形式存储。

以下是新版本中“对层进行扫描 (scan over layers)”的示例:

import flax.nnx as nnx

class MLP(nnx.Module):
  @nnx.split_rngs(splits=5)
  @nnx.vmap(in_axes=(0, 0))
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(3, 3, rngs=rngs)
    self.bn = nnx.BatchNorm(3, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)
    self.node = nnx.Param(jnp.ones((2,)))


  @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
  def __call__(self, x: jax.Array):
    return nnx.gelu(self.dropout(self.bn(self.linear(x))))
import flax.nnx as nnx

class MLP(nnx.Module):
  @nnx.split_rngs(splits=5)
  @nnx.vmap(in_axes=(0, 0))
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(3, 3, rngs=rngs)
    self.bn = nnx.BatchNorm(3, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)
    self.node = nnx.Param(jnp.ones((2,)))

  @nnx.split_rngs(splits=5)
  @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
  def __call__(self, x: jax.Array):
    return nnx.gelu(self.dropout(self.bn(self.linear(x))))

需要注意的主要事项是,不再需要在 scan 上使用 nnx.split_rngs,因为 __init__ 生成的 RNG 不再是标量形式(它们保留了额外的维度),因此可以直接在 scan 中使用,无需再次分割它们。另外,甚至可以从 __init__ 方法中移除 nnx.split_rngs 装饰器,并在将 RNG 传递给模块之前直接使用 Rngs.fork

class MLP(nnx.Module):
  @nnx.vmap(in_axes=(0, 0))
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(3, 3, rngs=rngs)
    self.bn = nnx.BatchNorm(3, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)
    self.node = nnx.Param(jnp.ones((2,)))

  @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
  def __call__(self, x: jax.Array):
    return nnx.gelu(self.dropout(self.bn(self.linear(x))))

rngs = nnx.Rngs(0)
mlp = MLP(rngs=rngs.fork(splits=5))

加载带 RNG 的检查点#

在新版本中加载检查点时,您需要丢弃旧的 RNG 结构,并使用新的 RNG 对模型进行部分重新初始化。为此,您可以使用 nnx.jit 来:

  1. 从检查点中移除 RNG。

  2. 使用新的 RNG 对模型进行部分初始化。

# load checkpoint
checkpointer = ocp.StandardCheckpointer()
checkpoint = checkpointer.restore(path / "state")

@jax.jit
def fix_checkpoint(checkpoint, rngs: nnx.Rngs):
  # drop rngs keys
  flat_paths = nnx.traversals.flatten_mapping(checkpoint)
  flat_paths = {
      path[:-1] if path[-1] == "value" else path: value  # remove "value" suffix
      for path, value in flat_paths.items()
      if "rngs" not in path  # remove rngs paths
  }
  checkpoint = nnx.traversals.unflatten_mapping(flat_paths)

  # initialize new model with given rngs
  model = MyModel(rngs=rngs)
  # overwrite model parameters with checkpoint
  nnx.update(model, checkpoint)
  # get full checkpoint with new rngs
  new_checkpoint = nnx.state(model)

  return new_checkpoint

checkpoint = fix_checkpoint(checkpoint, rngs=nnx.Rngs(params=0, dropout=1))
checkpointer.save(path.with_name(path.name + "_new"), checkpoint)

之前的代码是高效的,因为 jit 会执行死代码消除 (DCE),所以它实际上不会在内存中初始化现有的模型参数。

优化器更新#

优化器已更新,不再持有对模型的引用。相反,它现在在 update 方法中接收模型和梯度作为参数。具体来说,以下是新的更改:

  1. wrt 构造函数参数现在是必需的。

  2. model 属性已被移除。

  3. update 方法现在接收 (model, grads) 而不是仅接收 (grads)

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)

  return loss
from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

@nnx.jit
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)

  return loss

包含 NNX 对象的 Pytree#

在新版本中,NNX 模块现在是 Pytree。这意味着您可以直接将它们与 jax.vmapjax.jit 等 JAX 变换一起使用(关于此的更多文档将很快提供)。然而,这也意味着在包含 NNX 模块的结构上使用 jax.tree.* 函数的代码需要考虑到这一点,以保持当前的行为。在这些情况下,解决方案是使用 is_leaf 参数来指定应将 NNX 模块和其他 NNX 对象视为叶节点。

modules = [nnx.Linear(3, 3, rngs=nnx.Rngs(0)), nnx.BatchNorm(3, rngs=nnx.Rngs(1))]

type_names = jax.tree.map(
    lambda x: type(x).__name__,
    modules,
    is_leaf=lambda x: isinstance(x, nnx.Object)  # <-- specify that NNX objects are leaves
)