为何选择 Flax NNX?#
2020 年,Flax 团队发布了 Flax Linen API,以支持 JAX 上的建模研究,重点关注扩展性和性能。从那时起,我们从用户那里学到了很多。团队引入了一些被证明对用户有益的想法,例如
将变量组织成集合。
自动且高效的伪随机数生成器 (PRNG) 管理。
用于单程序多数据 (SPMD) 标注、优化器元数据和其他用例的变量元数据。
Flax 团队做出的选择之一是,通过参数的延迟初始化,为神经网络编程使用函数式 (compact
) 语义。这使得实现代码更加简洁,并使 Flax Linen API 与 Haiku 保持一致。
然而,这也意味着 Flax 中的模块和变量的语义不符合 Python 风格,并且常常出人意料。它还导致了实现的复杂性,并模糊了对神经网络进行变换 (transforms) 的核心思想。
Flax NNX 简介#
快进到 2024 年,Flax 团队开发了 Flax NNX——它试图保留 Flax Linen 对用户有用的特性,同时引入一些新原则。Flax NNX 背后的核心思想是将引用语义引入 JAX。以下是其主要特性:
NNX 符合 Python 风格:模块遵循常规的 Python 语义,包括支持可变性和共享引用。
NNX 很简单:Flax Linen 中的许多复杂 API 要么使用 Python 惯用法进行了简化,要么被完全移除。
更好的 JAX 集成:自定义的 NNX 变换采用与 JAX 变换相同的 API。并且使用 NNX 可以更轻松地直接使用 JAX 变换(高阶函数)。
下面是一个简单的 Flax NNX 程序示例,它阐明了上述许多要点:
from flax import nnx
import optax
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.2, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
@nnx.jit # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
def loss_fn(model):
y_pred = model(x) # call methods directly
return ((y_pred - y) ** 2).mean()
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads) # in-place updates
return loss
Flax NNX 对 Linen 的改进#
本文档的其余部分使用各种示例来演示 Flax NNX 如何改进 Flax Linen。
检查#
第一个改进是 Flax NNX 模块是常规的 Python 对象。这意味着您可以轻松地构建和检查 Module
对象。
另一方面,Flax Linen 模块不容易检查和调试,因为它们是延迟的,这意味着某些属性在构建时不可用,只能在运行时访问。
class Block(nn.Module):
def setup(self):
self.linear = nn.Dense(10)
block = Block()
try:
block.linear # AttributeError: "Block" object has no attribute "linear".
except AttributeError as e:
pass
...
class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(5, 10, rngs=rngs)
block = Block(nnx.Rngs(0))
block.linear
# Linear(
# kernel=Param(
# value=Array(shape=(5, 10), dtype=float32)
# ),
# bias=Param(
# value=Array(shape=(10,), dtype=float32)
# ),
# ...
请注意,在上面的 Flax NNX 示例中,没有形状推断——输入和输出形状都必须提供给 Linear
nnx.Module
。这是一个权衡,它允许更明确和可预测的行为。
运行计算#
在 Flax Linen 中,所有顶层计算都必须通过 flax.linen.Module.init
或 flax.linen.Module.apply
方法完成,并且参数或任何其他类型的状态都作为单独的结构处理。这就造成了一种不对称:1) 在 apply
内部运行的代码可以直接运行方法和其他 Module
对象;2) 在 apply
外部运行的代码必须使用 apply
方法。
在 Flax NNX 中,没有特殊的上下文,因为参数作为属性持有,方法可以直接调用。这意味着您的 NNX 模块的 __init__
和 __call__
方法与其他类方法的处理方式没有区别,而 Flax Linen 模块的 setup()
和 __call__
方法是特殊的。
Encoder = lambda: nn.Dense(10)
Decoder = lambda: nn.Dense(2)
class AutoEncoder(nn.Module):
def setup(self):
self.encoder = Encoder()
self.decoder = Decoder()
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
x = jnp.ones((1, 2))
model = AutoEncoder()
params = model.init(random.key(0), x)['params']
y = model.apply({'params': params}, x)
z = model.apply({'params': params}, x, method='encode')
y = Decoder().apply({'params': params['decoder']}, z)
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)
class AutoEncoder(nnx.Module):
def __init__(self, rngs):
self.encoder = Encoder(rngs)
self.decoder = Decoder(rngs)
def __call__(self, x) -> jax.Array:
return self.decoder(self.encoder(x))
def encode(self, x) -> jax.Array:
return self.encoder(x)
x = jnp.ones((1, 2))
model = AutoEncoder(nnx.Rngs(0))
y = model(x)
z = model.encode(x)
y = model.decoder(z)
在 Flax Linen 中,直接调用子模块是不可能的,因为它们没有被初始化。因此,您必须做的是构建一个新实例,然后提供一个合适的参数结构。
但在 Flax NNX 中,您可以直接调用子模块而没有任何问题。
状态处理#
Flax Linen 在状态处理方面是出了名的复杂。当您使用 Dropout 层、BatchNorm 层或两者都使用时,您突然必须处理新的状态,并使用它来配置 flax.linen.Module.apply
方法。
在 Flax NNX 中,状态保存在一个 nnx.Module
内部并且是可变的,这意味着它可以被直接调用。
class Block(nn.Module):
train: bool
def setup(self):
self.linear = nn.Dense(10)
self.bn = nn.BatchNorm(use_running_average=not self.train)
self.dropout = nn.Dropout(0.1, deterministic=not self.train)
def __call__(self, x):
return nn.relu(self.dropout(self.bn(self.linear(x))))
x = jnp.ones((1, 5))
model = Block(train=True)
vs = model.init(random.key(0), x)
params, batch_stats = vs['params'], vs['batch_stats']
y, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
x,
rngs={'dropout': random.key(1)},
mutable=['batch_stats'],
)
batch_stats = updates['batch_stats']
class Block(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(5, 10, rngs=rngs)
self.bn = nnx.BatchNorm(10, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return nnx.relu(self.dropout(self.bn(self.linear(x))))
x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))
y = model(x)
...
Flax NNX 状态处理的主要好处是,当您添加新的有状态层时,不必更改训练代码。
此外,在 Flax NNX 中,处理状态的层也非常容易实现。下面是一个 BatchNorm
层的简化版本,它在每次被调用时都会更新均值和方差。
class BatchNorm(nnx.Module):
def __init__(self, features: int, mu: float = 0.95):
# Variables
self.scale = nnx.Param(jax.numpy.ones((features,)))
self.bias = nnx.Param(jax.numpy.zeros((features,)))
self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
self.var = nnx.BatchStat(jax.numpy.ones((features,)))
self.mu = mu # Static
def __call__(self, x):
mean = jax.numpy.mean(x, axis=-1)
var = jax.numpy.var(x, axis=-1)
# ema updates
self.mean.value = self.mu * self.mean + (1 - self.mu) * mean
self.var.value = self.mu * self.var + (1 - self.mu) * var
# normalize and scale
x = (x - mean) / jax.numpy.sqrt(var + 1e-5)
return x * self.scale + self.bias
模型修改#
在 Flax Linen 中,模型修改 (model surgery) 历来都具有挑战性,原因有二:
由于延迟初始化,不能保证您可以用一个新的子
Module
替换现有的子Module
。参数结构与
flax.linen.Module
结构是分开的,这意味着您必须手动保持它们同步。
在 Flax NNX 中,您可以根据 Python 语义直接替换子模块。由于参数是 nnx.Module
结构的一部分,它们永远不会失同步。下面是一个示例,展示了如何实现一个 LoRA 层,然后用它来替换现有模型中的 Linear
层。
class LoraLinear(nn.Module):
linear: nn.Dense
rank: int
@nn.compact
def __call__(self, x: jax.Array):
A = self.param(random.normal, (x.shape[-1], self.rank))
B = self.param(random.normal, (self.rank, self.linear.features))
return self.linear(x) + x @ A @ B
try:
model = Block(train=True)
model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR
lora_params = model.linear.init(random.key(1), x)
lora_params['linear'] = params['linear']
params['linear'] = lora_params
except AttributeError as e:
pass
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear, rank, rngs):
self.linear = linear
self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank)))
self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = Block(rngs)
model.linear = LoraLinear(model.linear, rank=5, rngs=rngs)
...
如上所示,在 Flax Linen 中,这在这种情况下并不真正起作用,因为 linear
子Module
是不可用的。然而,代码的其余部分提供了一个关于必须如何手动更新 params
结构的想法。
在 Flax Linen 中执行任意的模型修改并不容易,目前 intercept_methods API 是进行通用方法修补 (patching) 的唯一方法。但这个 API 不太符合人体工程学。
在 Flax NNX 中,要进行通用的模型修改,您只需使用 nnx.iter_graph
,这比在 Linen 中更简单、更容易。下面是一个示例,展示了如何将模型中所有的 nnx.Linear
层替换为自定义的 LoraLinear
NNX 层。
rngs = nnx.Rngs(0)
model = Block(rngs)
for path, module in nnx.iter_graph(model):
if isinstance(module, nnx.Module):
for name, value in vars(module).items():
if isinstance(value, nnx.Linear):
setattr(module, name, LoraLinear(value, rank=5, rngs=rngs))
变换#
Flax Linen 变换非常强大,因为它们能够对模型状态进行细粒度控制。然而,Flax Linen 变换也有缺点,例如:
它们暴露了不属于 JAX 的额外 API,使其行为令人困惑,有时甚至与它们的 JAX 对应物有所不同。这也限制了您与 JAX 变换 交互以及跟上 JAX API 变化的方式。
它们作用于具有非常特定签名的函数,即:
一个
flax.linen.Module
必须是第一个参数。它们接受其他
Module
对象作为参数,但不接受作为返回值。
它们只能在
flax.linen.Module.apply
内部使用。
另一方面,Flax NNX 变换 旨在与其对应的 JAX 变换 等效,但有一个例外——它们可以用于 Flax NNX 模块。这意味着 Flax 变换:
具有与 JAX 变换相同的 API。
可以在任何参数上接受 Flax NNX 模块,并且
nnx.Module
对象可以从它们返回。可以在任何地方使用,包括训练循环。
下面是一个使用 vmap
和 Flax NNX 的示例,它通过变换返回一些 Weights
的 create_weights
函数来创建权重堆栈,并通过变换以 Weights
为第一个参数、一批输入为第二个参数的 vector_dot
函数,将该权重堆栈分别应用于一批输入。
class Weights(nnx.Module):
def __init__(self, kernel: jax.Array, bias: jax.Array):
self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)
def create_weights(seed: jax.Array):
return Weights(
kernel=random.uniform(random.key(seed), (2, 3)),
bias=jnp.zeros((3,)),
)
def vector_dot(weights: Weights, x: jax.Array):
assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ weights.kernel + weights.bias
seeds = jnp.arange(10)
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds)
x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)
与 Flax Linen 变换相反,in_axes
参数和其他 API 会影响 nnx.Module
状态的变换方式。
此外,Flax NNX 变换可以用作方法装饰器,因为 nnx.Module
方法只是将 Module
作为第一个参数的函数。这意味着前面的示例可以重写如下:
class WeightStack(nnx.Module):
@nnx.vmap(in_axes=(0, 0), out_axes=0)
def __init__(self, seed: jax.Array):
self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
self.bias = nnx.Param(jnp.zeros((3,)))
@nnx.vmap(in_axes=(0, 0), out_axes=1)
def __call__(self, x: jax.Array):
assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
assert x.ndim == 1, 'Batch dimensions not allowed'
return x @ self.kernel + self.bias
weights = WeightStack(jnp.arange(10))
x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)