Flax 基础#
Flax NNX 是一个新的简化 API,旨在更轻松地在 JAX 中创建、检查、调试和分析神经网络。它通过添加对 Python 引用语义的头等支持来实现这一点。这允许用户使用常规 Python 对象来表示他们的模型,这些对象被建模为 PyGraph(而不是 pytree),从而实现了引用共享和可变性。这样的 API 设计应该会让 PyTorch 或 Keras 用户感到宾至如归。
首先,使用 pip
安装 Flax 并导入必要的依赖项
# ! pip install -U flax
from flax import nnx
import jax
import jax.numpy as jnp
Flax NNX 模块系统#
nnx.Module
与 Flax Linen 或 Haiku 中其他 Module
系统的主要区别在于,在 NNX 中,一切都是**显式**的。这意味着,除其他事项外,nnx.Module
本身直接持有状态(如参数),PRNG 状态由用户传递,并且所有形状信息必须在初始化时提供(没有形状推断)。
让我们从创建一个 Linear
nnx.Module
开始。如下所示,动态状态通常存储在 nnx.Param
中,而静态状态(NNX 不处理的所有类型)如整数或字符串则直接存储。类型为 jax.Array
和 numpy.ndarray
的属性也被视为动态状态,尽管更推荐将它们存储在 nnx.Variable
中,例如 Param
。此外,nnx.Rngs
对象可用于根据传递给构造函数的根 PRNG 密钥生成新的唯一密钥。
class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
key = rngs.params()
self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.din, self.dout = din, dout
def __call__(self, x: jax.Array):
return x @ self.w + self.b
另请注意,nnx.Variable
的内部值可以使用 value
属性访问,但为了方便,它们实现了所有数值运算符,可以直接用于算术表达式(如上述代码所示)。
要初始化一个 Flax nnx.Module
,你只需调用其构造函数,Module
的所有参数通常都会被即时创建。由于 nnx.Module
拥有自己的状态方法,你可以直接调用它们,而无需一个单独的 apply
方法。这对于调试非常方便,可以让你直接检查模型的整个结构。
model = Linear(2, 5, rngs=nnx.Rngs(params=0))
y = model(x=jnp.ones((1, 2)))
print(y)
nnx.display(model)
[[1.5643291 0.94782424 0.37971854 1.0724319 0.22112393]]
/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py:251: UserWarning: Ignoring error inside wrapper hook <function use_autovisualizer_if_present at 0x75a9605514e0>:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/renderers.py", line 225, in _render_subtree
postprocessed_result = hook(
^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/handlers/autovisualizer_hook.py", line 47, in use_autovisualizer_if_present
result = autoviz(node, path)
^^^^^^^^^^^^^^^^^^^
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/treescope/_internal/api/array_autovisualizer.py", line 306, in __call__
jax.sharding.PositionalSharding
File "/home/docs/checkouts/readthedocs.org/user_builds/flax/envs/latest/lib/python3.11/site-packages/jax/_src/deprecations.py", line 54, in getattr
raise AttributeError(message)
AttributeError: jax.sharding.PositionalSharding was deprecated in JAX v0.6.0 and removed in JAX v0.7.0
warnings.warn(
上述由 nnx.display
生成的可视化是使用出色的 Treescope 库创建的。
有状态计算#
实现像 nnx.BatchNorm
这样的层需要在前向传播期间执行状态更新。在 Flax NNX 中,你只需创建一个 nnx.Variable
并在前向传播期间更新其 .value
。
class Count(nnx.Variable): pass
class Counter(nnx.Module):
def __init__(self):
self.count = Count(jnp.array(0))
def __call__(self):
self.count += 1
counter = Counter()
print(f'{counter.count.value = }')
counter()
print(f'{counter.count.value = }')
counter.count.value = Array(0, dtype=int32, weak_type=True)
counter.count.value = Array(1, dtype=int32, weak_type=True)
在 JAX 中通常会避免使用可变引用。但 Flax NNX 提供了可靠的机制来处理它们,正如本指南后面的部分所演示的。
嵌套模块#
Flax nnx.Module
可以用于以嵌套结构组合其他 Module
。这些模块可以直接作为属性赋值,或者放在任何(嵌套的)pytree 类型属性内部,例如 list
、dict
、tuple
等。
下面的示例展示了如何通过子类化 nnx.Module
来定义一个简单的 MLP
。该模型由两个 Linear
层、一个 nnx.Dropout
层和一个 nnx.BatchNorm
层组成。
class MLP(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
self.linear1 = Linear(din, dmid, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.linear2 = Linear(dmid, dout, rngs=rngs)
def __call__(self, x: jax.Array):
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
return self.linear2(x)
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
在 Flax 中,nnx.Dropout
是一个有状态的模块,它存储一个 nnx.Rngs
对象,这样它就可以在前向传播期间生成新的掩码,而无需用户每次都传递一个新的密钥。
模型修改#
Flax nnx.Module
默认是可变的。这意味着它们的结构可以随时更改,这使得模型修改(model surgery)相当容易,因为任何子Module
属性都可以被替换为任何其他东西,例如新的 Module
、现有的共享 Module
、不同类型的 Module
等等。此外,nnx.Variable
也可以被修改或替换/共享。
以下示例展示了如何将前一个示例中 MLP
模型中的 Linear
层替换为 LoraLinear
层。
class LoraParam(nnx.Param): pass
class LoraLinear(nnx.Module):
def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):
self.linear = linear
self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))
self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))
def __call__(self, x: jax.Array):
return self.linear(x) + x @ self.A @ self.B
rngs = nnx.Rngs(0)
model = MLP(2, 32, 5, rngs=rngs)
# Model surgery.
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)
y = model(x=jnp.ones((3, 2)))
nnx.display(model)
Flax 变换#
Flax NNX 变换 (transforms) 扩展了 JAX 变换,以支持 nnx.Module
和其他对象。它们是其等效 JAX 对应项的超集,增加了感知对象状态并提供额外 API 来转换它的功能。
Flax 变换的主要特性之一是保留引用语义,这意味着在变换内部发生的任何对对象图的修改,只要在变换规则内是合法的,都会被传播到外部。在实践中,这意味着 Flax 程序可以用命令式代码来表达,从而极大地简化了用户体验。
在下面的示例中,您定义了一个 train_step
函数,它接受一个 MLP
模型、一个 nnx.Optimizer
和一个数据批次,并返回该步骤的损失。损失和梯度是使用 nnx.value_and_grad
变换对 loss_fn
计算的。梯度被传递给优化器的 nnx.Optimizer.update
方法,以更新 model
的参数。
import optax
# An MLP containing 2 custom `Linear` layers, 1 `nnx.Dropout` layer, 1 `nnx.BatchNorm` layer.
model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
@nnx.jit # Automatic state management
def train_step(model, optimizer, x, y):
def loss_fn(model: MLP):
y_pred = model(x)
return jnp.mean((y_pred - y) ** 2)
loss, grads = nnx.value_and_grad(loss_fn)(model)
optimizer.update(model, grads) # In place updates.
return loss
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
loss = train_step(model, optimizer, x, y)
print(f'{loss = }')
print(f'{optimizer.step.value = }')
loss = Array(1.0000602, dtype=float32)
optimizer.step.value = Array(1, dtype=uint32)
这个例子中有两点值得一提
对每个
nnx.BatchNorm
和nnx.Dropout
层状态的更新会自动从loss_fn
内部传播到train_step
,一直到外部的model
引用。optimizer
持有对model
的可变引用——这种关系在train_step
函数内部得以保留,使得仅使用优化器就能更新模型的参数成为可能。
注意
对于小型模型,nnx.jit
存在性能开销,请查阅性能考量指南以获取更多信息。
扫描层#
下一个示例使用 Flax nnx.vmap
创建一个由多个 MLP 层组成的堆栈,并使用 nnx.scan
迭代地将堆栈中的每一层应用于输入。
在下面的代码中,请注意以下几点
自定义的
create_model
函数接受一个密钥并返回一个MLP
对象。由于您创建了五个密钥并对create_model
使用nnx.vmap
,因此创建了一个包含 5 个MLP
对象的堆栈。nnx.scan
用于迭代地将堆栈中的每个MLP
应用于输入x
。nnx.scan
(有意地)偏离了jax.lax.scan
,而是模仿了更具表现力的nnx.vmap
。nnx.scan
允许指定多个输入、每个输入/输出的扫描轴以及 carry 的位置。nnx.BatchNorm
和nnx.Dropout
层的State
更新由nnx.scan
自动传播。
@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(model: MLP, x):
x = model(x)
return x
x = jnp.ones((3, 10))
y = forward(model, x)
print(f'{y.shape = }')
nnx.display(model)
y.shape = (3, 10)
Flax NNX 变换是如何实现这一点的?为了理解 Flax NNX 对象如何与 JAX 变换交互,下一节将解释 Flax NNX 函数式 API。
Flax 函数式 API#
Flax NNX 函数式 API 在引用/对象语义和值/pytree 语义之间建立了一个清晰的界限。它也提供了与 Flax Linen 和 Haiku 用户习惯的相同程度的对状态的细粒度控制。Flax NNX 函数式 API 由三个基本方法组成:nnx.split
、nnx.merge
和 nnx.update
。
下面是一个使用函数式 API 的 StatefulLinear
nnx.Module
的示例。它包含
一些
nnx.Param
nnx.Variable
;以及一个自定义的
Count()
nnx.Variable
类型,用于跟踪每次前向传播时都会增加的整数标量状态。
class Count(nnx.Variable): pass
class StatefulLinear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.count = Count(jnp.array(0, dtype=jnp.uint32))
def __call__(self, x: jax.Array):
self.count += 1
return x @ self.w + self.b
model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))
y = model(jnp.ones((1, 3)))
nnx.display(model)
状态与 GraphDef#
一个 Flax nnx.Module
可以使用 nnx.split
函数分解为 nnx.State
和 nnx.GraphDef
。
nnx.State
是一个从字符串到nnx.Variable
或嵌套State
的Mapping
(映射)。nnx.GraphDef
包含了重建一个nnx.Module
图所需的所有静态信息,它类似于 JAX 的PyTreeDef
。
graphdef, state = nnx.split(model)
nnx.display(graphdef, state)
拆分、合并和更新#
Flax 的 nnx.merge
是 nnx.split
的逆操作。它接收 nnx.GraphDef
+ nnx.State
并重构 nnx.Module
。下面的示例演示了这一点:
通过依次使用
nnx.split
和nnx.merge
,任何Module
都可以被提升以在任何 JAX 变换中使用。nnx.update
可以用给定的nnx.State
的内容来就地更新一个对象。此模式用于将状态从变换传播回外部的源对象。
print(f'{model.count.value = }')
# 1. Use `nnx.split` to create a pytree representation of the `nnx.Module`.
graphdef, state = nnx.split(model)
@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:
# 2. Use `nnx.merge` to create a new model inside the JAX transformation.
model = nnx.merge(graphdef, state)
# 3. Call the `nnx.Module`
y = model(x)
# 4. Use `nnx.split` to propagate `nnx.State` updates.
_, state = nnx.split(model)
return y, state
y, state = forward(graphdef, state, x=jnp.ones((1, 3)))
# 5. Update the state of the original `nnx.Module`.
nnx.update(model, state)
print(f'{model.count.value = }')
model.count.value = Array(1, dtype=uint32)
model.count.value = Array(2, dtype=uint32)
这种模式的关键洞见在于,在变换上下文中使用可变引用是可以的(包括基础的即时解释器),但在跨越边界时必须使用函数式 API。
为什么模块不直接是 pytree? 主要原因是,这样很容易意外地丢失共享引用。例如,如果你将两个拥有一个共享 Module
的 nnx.Module
传递通过一个 JAX 边界,你会悄无声息地失去这种共享关系。Flax 的函数式 API 使这种行为变得明确,因此更容易推理。
细粒度状态控制#
经验丰富的 Flax Linen 或 Haiku API 用户可能会认识到,将所有状态放在一个单一结构中并不总是最佳选择,因为在某些情况下,您可能希望以不同方式处理状态的不同子集。这在与 JAX 变换交互时是很常见的。
例如
在与
jax.grad
交互时,并非每个模型状态都可以或应该被微分。或者,有时在使用
jax.lax.scan
时,需要指定模型状态的哪一部分是 carry,哪一部分不是。
为了解决这个问题,Flax NNX API 提供了 nnx.split
,它允许您传递一个或多个 nnx.filterlib.Filter
来将 nnx.Variable
分割成互斥的 nnx.State
。Flax NNX 在 API 中使用 Filter
创建 State
组(例如 nnx.split
、nnx.state()
以及许多 NNX 变换)。
下面的示例展示了最常见的 Filter
# Use `nnx.Variable` type `Filter`s to split into multiple `nnx.State`s.
graphdef, params, counts = nnx.split(model, nnx.Param, Count)
nnx.display(params, counts)
注意: nnx.filterlib.Filter
必须是详尽的,如果一个值没有被匹配,将会引发错误。
不出所料,nnx.merge
和 nnx.update
方法自然会消费多个 State
。
# Merge multiple `State`s
model = nnx.merge(graphdef, params, counts)
# Update with multiple `State`s
nnx.update(model, params, counts)