从 NNX 0.10 到 NNX 0.11#
在本指南中,我们将介绍将 Flax NNX 代码从 Flax 版本 0.10.x
更新到 0.11.x
时所需的代码更改。
在 NNX 变换中使用 Rng#
现在,使用 RNG 的 NNX 层(如 Dropout 或 MultiHeadAttention)会持有一个在构造时提供的 Rngs
对象的 fork
副本,而不是对原始 Rngs
对象的共享引用。这有两个后果: * 它改变了检查点的结构,因为每个层都将拥有唯一的 RNG 状态。 * 它改变了 nnx.split_rngs
与 nnx.vmap
和 nnx.scan
等变换的交互方式,
因为生成的 RNG 状态现在将不会以标量形式存储。
以下是新版本中“对层进行扫描 (scan over layers)”的示例:
import flax.nnx as nnx
class MLP(nnx.Module):
@nnx.split_rngs(splits=5)
@nnx.vmap(in_axes=(0, 0))
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(3, 3, rngs=rngs)
self.bn = nnx.BatchNorm(3, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
self.node = nnx.Param(jnp.ones((2,)))
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def __call__(self, x: jax.Array):
return nnx.gelu(self.dropout(self.bn(self.linear(x))))
import flax.nnx as nnx
class MLP(nnx.Module):
@nnx.split_rngs(splits=5)
@nnx.vmap(in_axes=(0, 0))
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(3, 3, rngs=rngs)
self.bn = nnx.BatchNorm(3, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
self.node = nnx.Param(jnp.ones((2,)))
@nnx.split_rngs(splits=5)
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def __call__(self, x: jax.Array):
return nnx.gelu(self.dropout(self.bn(self.linear(x))))
需要注意的主要事项是,不再需要在 scan
上使用 nnx.split_rngs
,因为 __init__
生成的 RNG 不再是标量形式(它们保留了额外的维度),因此可以直接在 scan
中使用,无需再次分割它们。另外,甚至可以从 __init__
方法中移除 nnx.split_rngs
装饰器,并在将 RNG 传递给模块之前直接使用 Rngs.fork
。
class MLP(nnx.Module):
@nnx.vmap(in_axes=(0, 0))
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(3, 3, rngs=rngs)
self.bn = nnx.BatchNorm(3, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
self.node = nnx.Param(jnp.ones((2,)))
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def __call__(self, x: jax.Array):
return nnx.gelu(self.dropout(self.bn(self.linear(x))))
rngs = nnx.Rngs(0)
mlp = MLP(rngs=rngs.fork(splits=5))
加载带 RNG 的检查点#
在新版本中加载检查点时,您需要丢弃旧的 RNG 结构,并使用新的 RNG 对模型进行部分重新初始化。为此,您可以使用 nnx.jit
来:
从检查点中移除 RNG。
使用新的 RNG 对模型进行部分初始化。
# load checkpoint
checkpointer = ocp.StandardCheckpointer()
checkpoint = checkpointer.restore(path / "state")
@jax.jit
def fix_checkpoint(checkpoint, rngs: nnx.Rngs):
# drop rngs keys
flat_paths = nnx.traversals.flatten_mapping(checkpoint)
flat_paths = {
path[:-1] if path[-1] == "value" else path: value # remove "value" suffix
for path, value in flat_paths.items()
if "rngs" not in path # remove rngs paths
}
checkpoint = nnx.traversals.unflatten_mapping(flat_paths)
# initialize new model with given rngs
model = MyModel(rngs=rngs)
# overwrite model parameters with checkpoint
nnx.update(model, checkpoint)
# get full checkpoint with new rngs
new_checkpoint = nnx.state(model)
return new_checkpoint
checkpoint = fix_checkpoint(checkpoint, rngs=nnx.Rngs(params=0, dropout=1))
checkpointer.save(path.with_name(path.name + "_new"), checkpoint)
之前的代码是高效的,因为 jit
会执行死代码消除 (DCE),所以它实际上不会在内存中初始化现有的模型参数。
优化器更新#
优化器已更新,不再持有对模型的引用。相反,它现在在 update
方法中接收模型和梯度作为参数。具体来说,以下是新的更改:
wrt
构造函数参数现在是必需的。model
属性已被移除。update
方法现在接收(model, grads)
而不是仅接收(grads)
。
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x)
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads)
return loss
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
@nnx.jit
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x)
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grads)
return loss
包含 NNX 对象的 Pytree#
在新版本中,NNX 模块现在是 Pytree。这意味着您可以直接将它们与 jax.vmap
和 jax.jit
等 JAX 变换一起使用(关于此的更多文档将很快提供)。然而,这也意味着在包含 NNX 模块的结构上使用 jax.tree.*
函数的代码需要考虑到这一点,以保持当前的行为。在这些情况下,解决方案是使用 is_leaf
参数来指定应将 NNX 模块和其他 NNX 对象视为叶节点。
modules = [nnx.Linear(3, 3, rngs=nnx.Rngs(0)), nnx.BatchNorm(3, rngs=nnx.Rngs(1))]
type_names = jax.tree.map(
lambda x: type(x).__name__,
modules,
is_leaf=lambda x: isinstance(x, nnx.Object) # <-- specify that NNX objects are leaves
)