为何选择 Flax NNX?#

2020 年,Flax 团队发布了 Flax Linen API,以支持 JAX 上的建模研究,重点关注扩展性和性能。从那时起,我们从用户那里学到了很多。团队引入了一些被证明对用户有益的想法,例如

Flax 团队做出的选择之一是,通过参数的延迟初始化,为神经网络编程使用函数式 (compact) 语义。这使得实现代码更加简洁,并使 Flax Linen API 与 Haiku 保持一致。

然而,这也意味着 Flax 中的模块和变量的语义不符合 Python 风格,并且常常出人意料。它还导致了实现的复杂性,并模糊了对神经网络进行变换 (transforms) 的核心思想。

Flax NNX 简介#

快进到 2024 年,Flax 团队开发了 Flax NNX——它试图保留 Flax Linen 对用户有用的特性,同时引入一些新原则。Flax NNX 背后的核心思想是将引用语义引入 JAX。以下是其主要特性:

  • NNX 符合 Python 风格:模块遵循常规的 Python 语义,包括支持可变性和共享引用。

  • NNX 很简单:Flax Linen 中的许多复杂 API 要么使用 Python 惯用法进行了简化,要么被完全移除。

  • 更好的 JAX 集成:自定义的 NNX 变换采用与 JAX 变换相同的 API。并且使用 NNX 可以更轻松地直接使用 JAX 变换(高阶函数)

下面是一个简单的 Flax NNX 程序示例,它阐明了上述许多要点:

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # Eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit  # Automatic state management for JAX transforms.
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # in-place updates

  return loss

Flax NNX 对 Linen 的改进#

本文档的其余部分使用各种示例来演示 Flax NNX 如何改进 Flax Linen。

检查#

第一个改进是 Flax NNX 模块是常规的 Python 对象。这意味着您可以轻松地构建和检查 Module 对象。

另一方面,Flax Linen 模块不容易检查和调试,因为它们是延迟的,这意味着某些属性在构建时不可用,只能在运行时访问。

class Block(nn.Module):
  def setup(self):
    self.linear = nn.Dense(10)

block = Block()

try:
  block.linear  # AttributeError: "Block" object has no attribute "linear".
except AttributeError as e:
  pass





...
class Block(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)

block = Block(nnx.Rngs(0))


block.linear
# Linear(
#   kernel=Param(
#     value=Array(shape=(5, 10), dtype=float32)
#   ),
#   bias=Param(
#     value=Array(shape=(10,), dtype=float32)
#   ),
#   ...

请注意,在上面的 Flax NNX 示例中,没有形状推断——输入和输出形状都必须提供给 Linear nnx.Module。这是一个权衡,它允许更明确和可预测的行为。

运行计算#

在 Flax Linen 中,所有顶层计算都必须通过 flax.linen.Module.initflax.linen.Module.apply 方法完成,并且参数或任何其他类型的状态都作为单独的结构处理。这就造成了一种不对称:1) 在 apply 内部运行的代码可以直接运行方法和其他 Module 对象;2) 在 apply 外部运行的代码必须使用 apply 方法。

在 Flax NNX 中,没有特殊的上下文,因为参数作为属性持有,方法可以直接调用。这意味着您的 NNX 模块的 __init____call__ 方法与其他类方法的处理方式没有区别,而 Flax Linen 模块的 setup()__call__ 方法是特殊的。

Encoder = lambda: nn.Dense(10)
Decoder = lambda: nn.Dense(2)

class AutoEncoder(nn.Module):
  def setup(self):
    self.encoder = Encoder()
    self.decoder = Decoder()

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

x = jnp.ones((1, 2))
model = AutoEncoder()
params = model.init(random.key(0), x)['params']

y = model.apply({'params': params}, x)
z = model.apply({'params': params}, x, method='encode')
y = Decoder().apply({'params': params['decoder']}, z)
Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

x = jnp.ones((1, 2))
model = AutoEncoder(nnx.Rngs(0))


y = model(x)
z = model.encode(x)
y = model.decoder(z)

在 Flax Linen 中,直接调用子模块是不可能的,因为它们没有被初始化。因此,您必须做的是构建一个新实例,然后提供一个合适的参数结构。

但在 Flax NNX 中,您可以直接调用子模块而没有任何问题。

状态处理#

Flax Linen 在状态处理方面是出了名的复杂。当您使用 Dropout 层、BatchNorm 层或两者都使用时,您突然必须处理新的状态,并使用它来配置 flax.linen.Module.apply 方法。

在 Flax NNX 中,状态保存在一个 nnx.Module 内部并且是可变的,这意味着它可以被直接调用。

class Block(nn.Module):
  train: bool

  def setup(self):
    self.linear = nn.Dense(10)
    self.bn = nn.BatchNorm(use_running_average=not self.train)
    self.dropout = nn.Dropout(0.1, deterministic=not self.train)

  def __call__(self, x):
    return nn.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(train=True)
vs = model.init(random.key(0), x)
params, batch_stats = vs['params'], vs['batch_stats']

y, updates = model.apply(
  {'params': params, 'batch_stats': batch_stats},
  x,
  rngs={'dropout': random.key(1)},
  mutable=['batch_stats'],
)
batch_stats = updates['batch_stats']
class Block(nnx.Module):


  def __init__(self, rngs):
    self.linear = nnx.Linear(5, 10, rngs=rngs)
    self.bn = nnx.BatchNorm(10, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.dropout(self.bn(self.linear(x))))

x = jnp.ones((1, 5))
model = Block(nnx.Rngs(0))



y = model(x)





...

Flax NNX 状态处理的主要好处是,当您添加新的有状态层时,不必更改训练代码。

此外,在 Flax NNX 中,处理状态的层也非常容易实现。下面是一个 BatchNorm 层的简化版本,它在每次被调用时都会更新均值和方差。

class BatchNorm(nnx.Module):
  def __init__(self, features: int, mu: float = 0.95):
    # Variables
    self.scale = nnx.Param(jax.numpy.ones((features,)))
    self.bias = nnx.Param(jax.numpy.zeros((features,)))
    self.mean = nnx.BatchStat(jax.numpy.zeros((features,)))
    self.var = nnx.BatchStat(jax.numpy.ones((features,)))
    self.mu = mu  # Static

def __call__(self, x):
  mean = jax.numpy.mean(x, axis=-1)
  var = jax.numpy.var(x, axis=-1)
  # ema updates
  self.mean.value = self.mu * self.mean + (1 - self.mu) * mean
  self.var.value = self.mu * self.var + (1 - self.mu) * var
  # normalize and scale
  x = (x - mean) / jax.numpy.sqrt(var + 1e-5)
  return x * self.scale + self.bias

模型修改#

在 Flax Linen 中,模型修改 (model surgery) 历来都具有挑战性,原因有二:

  1. 由于延迟初始化,不能保证您可以用一个新的子Module替换现有的子Module

  2. 参数结构与 flax.linen.Module 结构是分开的,这意味着您必须手动保持它们同步。

在 Flax NNX 中,您可以根据 Python 语义直接替换子模块。由于参数是 nnx.Module 结构的一部分,它们永远不会失同步。下面是一个示例,展示了如何实现一个 LoRA 层,然后用它来替换现有模型中的 Linear 层。

class LoraLinear(nn.Module):
  linear: nn.Dense
  rank: int

  @nn.compact
  def __call__(self, x: jax.Array):
    A = self.param(random.normal, (x.shape[-1], self.rank))
    B = self.param(random.normal, (self.rank, self.linear.features))

    return self.linear(x) + x @ A @ B

try:
  model = Block(train=True)
  model.linear = LoraLinear(model.linear, rank=5) # <-- ERROR

  lora_params = model.linear.init(random.key(1), x)
  lora_params['linear'] = params['linear']
  params['linear'] = lora_params

except AttributeError as e:
  pass
class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear, rank, rngs):
    self.linear = linear
    self.A = LoraParam(random.normal(rngs(), (linear.in_features, rank)))
    self.B = LoraParam(random.normal(rngs(), (rank, linear.out_features)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = Block(rngs)
model.linear = LoraLinear(model.linear, rank=5, rngs=rngs)






...

如上所示,在 Flax Linen 中,这在这种情况下并不真正起作用,因为 linearModule 是不可用的。然而,代码的其余部分提供了一个关于必须如何手动更新 params 结构的想法。

在 Flax Linen 中执行任意的模型修改并不容易,目前 intercept_methods API 是进行通用方法修补 (patching) 的唯一方法。但这个 API 不太符合人体工程学。

在 Flax NNX 中,要进行通用的模型修改,您只需使用 nnx.iter_graph,这比在 Linen 中更简单、更容易。下面是一个示例,展示了如何将模型中所有的 nnx.Linear 层替换为自定义的 LoraLinear NNX 层。

rngs = nnx.Rngs(0)
model = Block(rngs)

for path, module in nnx.iter_graph(model):
  if isinstance(module, nnx.Module):
    for name, value in vars(module).items():
      if isinstance(value, nnx.Linear):
        setattr(module, name, LoraLinear(value, rank=5, rngs=rngs))

变换#

Flax Linen 变换非常强大,因为它们能够对模型状态进行细粒度控制。然而,Flax Linen 变换也有缺点,例如:

  1. 它们暴露了不属于 JAX 的额外 API,使其行为令人困惑,有时甚至与它们的 JAX 对应物有所不同。这也限制了您与 JAX 变换 交互以及跟上 JAX API 变化的方式。

  2. 它们作用于具有非常特定签名的函数,即:

  • 一个 flax.linen.Module 必须是第一个参数。

  • 它们接受其他 Module 对象作为参数,但不接受作为返回值。

  1. 它们只能在 flax.linen.Module.apply 内部使用。

另一方面,Flax NNX 变换 旨在与其对应的 JAX 变换 等效,但有一个例外——它们可以用于 Flax NNX 模块。这意味着 Flax 变换:

  1. 具有与 JAX 变换相同的 API。

  2. 可以在任何参数上接受 Flax NNX 模块,并且 nnx.Module 对象可以从它们返回。

  3. 可以在任何地方使用,包括训练循环。

下面是一个使用 vmap 和 Flax NNX 的示例,它通过变换返回一些 Weightscreate_weights 函数来创建权重堆栈,并通过变换以 Weights 为第一个参数、一批输入为第二个参数的 vector_dot 函数,将该权重堆栈分别应用于一批输入。

class Weights(nnx.Module):
  def __init__(self, kernel: jax.Array, bias: jax.Array):
    self.kernel, self.bias = nnx.Param(kernel), nnx.Param(bias)

def create_weights(seed: jax.Array):
  return Weights(
    kernel=random.uniform(random.key(seed), (2, 3)),
    bias=jnp.zeros((3,)),
  )

def vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ weights.kernel + weights.bias

seeds = jnp.arange(10)
weights = nnx.vmap(create_weights, in_axes=0, out_axes=0)(seeds)

x = jax.random.normal(random.key(1), (10, 2))
y = nnx.vmap(vector_dot, in_axes=(0, 0), out_axes=1)(weights, x)

与 Flax Linen 变换相反,in_axes 参数和其他 API 会影响 nnx.Module 状态的变换方式。

此外,Flax NNX 变换可以用作方法装饰器,因为 nnx.Module 方法只是将 Module 作为第一个参数的函数。这意味着前面的示例可以重写如下:

class WeightStack(nnx.Module):
  @nnx.vmap(in_axes=(0, 0), out_axes=0)
  def __init__(self, seed: jax.Array):
    self.kernel = nnx.Param(random.uniform(random.key(seed), (2, 3)))
    self.bias = nnx.Param(jnp.zeros((3,)))

  @nnx.vmap(in_axes=(0, 0), out_axes=1)
  def __call__(self, x: jax.Array):
    assert self.kernel.ndim == 2, 'Batch dimensions not allowed'
    assert x.ndim == 1, 'Batch dimensions not allowed'
    return x @ self.kernel + self.bias

weights = WeightStack(jnp.arange(10))

x = jax.random.normal(random.key(1), (10, 2))
y = weights(x)