Flax 基础#

Flax NNX 是一个新的简化 API,旨在更轻松地在 JAX 中创建、检查、调试和分析神经网络。它通过添加对 Python 引用语义的头等支持来实现这一点。这允许用户使用常规 Python 对象来表示他们的模型,这些对象被建模为 PyGraph(而不是 pytree),从而实现了引用共享和可变性。这样的 API 设计应该会让 PyTorch 或 Keras 用户感到宾至如归。

首先,使用 pip 安装 Flax 并导入必要的依赖项

# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp

Flax NNX 模块系统#

nnx.ModuleFlax LinenHaiku 中其他 Module 系统的主要区别在于,在 NNX 中,一切都是**显式**的。这意味着,除其他事项外,nnx.Module 本身直接持有状态(如参数),PRNG 状态由用户传递,并且所有形状信息必须在初始化时提供(没有形状推断)。

让我们从创建一个 Linear nnx.Module 开始。如下所示,动态状态通常存储在 nnx.Param 中,而静态状态(NNX 不处理的所有类型)如整数或字符串则直接存储。类型为 jax.Arraynumpy.ndarray 的属性也被视为动态状态,尽管更推荐将它们存储在 nnx.Variable 中,例如 Param。此外,nnx.Rngs 对象可用于根据传递给构造函数的根 PRNG 密钥生成新的唯一密钥。

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return x @ self.w + self.b

另请注意,nnx.Variable 的内部值可以使用 value 属性访问,但为了方便,它们实现了所有数值运算符,可以直接用于算术表达式(如上述代码所示)。

要初始化一个 Flax nnx.Module,你只需调用其构造函数,Module 的所有参数通常都会被即时创建。由于 nnx.Module 拥有自己的状态方法,你可以直接调用它们,而无需一个单独的 apply 方法。这对于调试非常方便,可以让你直接检查模型的整个结构。

model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))

print(y)
nnx.display(model)
[[1.5643291  0.94782424 0.37971854 1.0724319  0.22112393]]
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook <function use_autovisualizer_if_present at 0x75a9605514e0>:
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py", line 225, in _render_subtree
    postprocessed_result = hook(
                           ^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
    result = autoviz(node, path)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
    jax.sharding.PositionalSharding
  File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0

  warnings.warn(

上述由 nnx.display 生成的可视化是使用出色的 Treescope 库创建的。

有状态计算#

实现像 nnx.BatchNorm 这样的层需要在前向传播期间执行状态更新。在 Flax NNX 中,你只需创建一个 nnx.Variable 并在前向传播期间更新其 .value

class Count(nnx.Variable): pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count += 1

counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)

在 JAX 中通常会避免使用可变引用。但 Flax NNX 提供了可靠的机制来处理它们,正如本指南后面的部分所演示的。

嵌套模块#

Flax nnx.Module 可以用于以嵌套结构组合其他 Module。这些模块可以直接作为属性赋值,或者放在任何(嵌套的)pytree 类型属性内部,例如 listdicttuple 等。

下面的示例展示了如何通过子类化 nnx.Module 来定义一个简单的 MLP。该模型由两个 Linear 层、一个 nnx.Dropout 层和一个 nnx.BatchNorm 层组成。

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

在 Flax 中,nnx.Dropout 是一个有状态的模块,它存储一个 nnx.Rngs 对象,这样它就可以在前向传播期间生成新的掩码,而无需用户每次都传递一个新的密钥。

模型修改#

Flax nnx.Module 默认是可变的。这意味着它们的结构可以随时更改,这使得模型修改(model surgery)相当容易,因为任何子Module属性都可以被替换为任何其他东西,例如新的 Module、现有的共享 Module、不同类型的 Module 等等。此外,nnx.Variable 也可以被修改或替换/共享。

以下示例展示了如何将前一个示例中 MLP 模型中的 Linear 层替换为 LoraLinear 层。

class LoraParam(nnx.Param): pass

class LoraLinear(nnx.Module):
  def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
    self.linear = linear
    self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
    self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))

  def __call__(self, x: jax.Array):
    return self.linear(x) + x @ self.A @ self.B

rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)

# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)

y = model(x=jnp.ones((3, 2)))

nnx.display(model)

Flax 变换#

Flax NNX 变换 (transforms) 扩展了 JAX 变换,以支持 nnx.Module 和其他对象。它们是其等效 JAX 对应项的超集,增加了感知对象状态并提供额外 API 来转换它的功能。

Flax 变换的主要特性之一是保留引用语义,这意味着在变换内部发生的任何对对象图的修改,只要在变换规则内是合法的,都会被传播到外部。在实践中,这意味着 Flax 程序可以用命令式代码来表达,从而极大地简化了用户体验。

在下面的示例中,您定义了一个 train_step 函数,它接受一个 MLP 模型、一个 nnx.Optimizer 和一个数据批次,并返回该步骤的损失。损失和梯度是使用 nnx.value_and_grad 变换对 loss_fn 计算的。梯度被传递给优化器的 nnx.Optimizer.update 方法,以更新 model 的参数。

import optax

# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit  # Automatic state management
def train_step(model, optimizer, x, y):
  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y_pred - y) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # In place updates.

  return loss

x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)

print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000602, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)

这个例子中有两点值得一提

  1. 对每个 nnx.BatchNormnnx.Dropout 层状态的更新会自动从 loss_fn 内部传播到 train_step,一直到外部的 model 引用。

  2. optimizer 持有对 model 的可变引用——这种关系在 train_step 函数内部得以保留,使得仅使用优化器就能更新模型的参数成为可能。

注意
对于小型模型,nnx.jit 存在性能开销,请查阅性能考量指南以获取更多信息。

扫描层#

下一个示例使用 Flax nnx.vmap 创建一个由多个 MLP 层组成的堆栈,并使用 nnx.scan 迭代地将堆栈中的每一层应用于输入。

在下面的代码中,请注意以下几点

  1. 自定义的 create_model 函数接受一个密钥并返回一个 MLP 对象。由于您创建了五个密钥并对 create_model 使用 nnx.vmap,因此创建了一个包含 5 个 MLP 对象的堆栈。

  2. nnx.scan 用于迭代地将堆栈中的每个 MLP 应用于输入 x

  3. nnx.scan(有意地)偏离了 jax.lax.scan,而是模仿了更具表现力的 nnx.vmapnnx.scan 允许指定多个输入、每个输入/输出的扫描轴以及 carry 的位置。

  4. nnx.BatchNormnnx.Dropout 层的 State 更新由 nnx.scan 自动传播。

@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)

@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
  x = model(x)
  return x

x = jnp.ones((3, 10))
y = forward(model, x)

print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)

Flax NNX 变换是如何实现这一点的?为了理解 Flax NNX 对象如何与 JAX 变换交互,下一节将解释 Flax NNX 函数式 API。

Flax 函数式 API#

Flax NNX 函数式 API 在引用/对象语义和值/pytree 语义之间建立了一个清晰的界限。它也提供了与 Flax Linen 和 Haiku 用户习惯的相同程度的对状态的细粒度控制。Flax NNX 函数式 API 由三个基本方法组成:nnx.splitnnx.mergennx.update

下面是一个使用函数式 API 的 StatefulLinear nnx.Module 的示例。它包含

class Count(nnx.Variable): pass

class StatefulLinear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def __call__(self, x: jax.Array):
    self.count += 1
    return x @ self.w + self.b

model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))

nnx.display(model)

状态与 GraphDef#

一个 Flax nnx.Module 可以使用 nnx.split 函数分解为 nnx.Statennx.GraphDef

graphdef, state = nnx.split(model)

nnx.display(graphdef, state)

拆分、合并和更新#

Flax 的 nnx.mergennx.split 的逆操作。它接收 nnx.GraphDef + nnx.State 并重构 nnx.Module。下面的示例演示了这一点:

  • 通过依次使用 nnx.splitnnx.merge,任何 Module 都可以被提升以在任何 JAX 变换中使用。

  • nnx.update 可以用给定的 nnx.State 的内容来就地更新一个对象。

  • 此模式用于将状态从变换传播回外部的源对象。

print(f'{model.count.value = }')

# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
  # 2. Use `nnx.merge` to create a new model inside the JAX transformation.
  model = nnx.merge(graphdef, state)
  # 3. Call the `nnx.Module`
  y = model(x)
  # 4. Use `nnx.split` to propagate `nnx.State` updates.
  _, state = nnx.split(model)
  return y, state

y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)

print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)

这种模式的关键洞见在于,在变换上下文中使用可变引用是可以的(包括基础的即时解释器),但在跨越边界时必须使用函数式 API。

为什么模块不直接是 pytree? 主要原因是,这样很容易意外地丢失共享引用。例如,如果你将两个拥有一个共享 Modulennx.Module 传递通过一个 JAX 边界,你会悄无声息地失去这种共享关系。Flax 的函数式 API 使这种行为变得明确,因此更容易推理。

细粒度状态控制#

经验丰富的 Flax LinenHaiku API 用户可能会认识到,将所有状态放在一个单一结构中并不总是最佳选择,因为在某些情况下,您可能希望以不同方式处理状态的不同子集。这在与 JAX 变换交互时是很常见的。

例如

  • 在与 jax.grad 交互时,并非每个模型状态都可以或应该被微分。

  • 或者,有时在使用 jax.lax.scan 时,需要指定模型状态的哪一部分是 carry,哪一部分不是。

为了解决这个问题,Flax NNX API 提供了 nnx.split,它允许您传递一个或多个 nnx.filterlib.Filter 来将 nnx.Variable 分割成互斥的 nnx.State。Flax NNX 在 API 中使用 Filter 创建 State 组(例如 nnx.splitnnx.state() 以及许多 NNX 变换)。

下面的示例展示了最常见的 Filter

# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)

nnx.display(params, counts)

注意: nnx.filterlib.Filter 必须是详尽的,如果一个值没有被匹配,将会引发错误。

不出所料,nnx.mergennx.update 方法自然会消费多个 State

# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)