从 Haiku 迁移到 Flax NNX#
本指南通过并排展示示例代码来演示 Haiku 和 Flax NNX 模型之间的差异,帮助您从 Haiku 迁移到 Flax NNX API。
如果您是 Flax NNX 的新手,请确保您已熟悉 Flax NNX 基础知识,其中涵盖了 nnx.Module
系统、Flax 转换以及带示例的函数式 API。
让我们先导入一些库。
基本模块定义#
Haiku 和 Flax 都使用 Module
类作为表示神经网络库层的默认单元。例如,要创建一个带 dropout 和 ReLU 激活函数的单层网络,您需要:
首先,创建一个
Block
(通过子类化Module
),它由一个带有 dropout 和 ReLU 激活函数的线性层组成。然后,在创建
Model
(同样通过子类化Module
)时,将Block
用作子Module
。Model
由Block
和一个线性层构成。
Haiku 和 Flax 的 Module
对象之间有两个根本区别:
无状态与有状态:
一个
haiku.Module
实例是无状态的。这意味着,变量是从一个纯函数的Module.init()
调用中返回并单独管理的。然而,一个
flax.nnx.Module
将其变量作为该 Python 对象的属性来拥有。
惰性与即时:
一个
haiku.Module
只有在用户调用模型并实际看到输入时才会分配空间来创建变量(惰性)。一个
flax.nnx.Module
实例在实例化时就会创建变量,在看到样本输入之前(即时)。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
变量创建#
本节介绍如何实例化模型并初始化其参数。
要为 Haiku 模型生成模型参数,您需要将其放入一个前向函数中,并使用
haiku.transform
使其成为纯函数。这将产生一个 JAX 数组(jax.Array
数据类型)的嵌套字典,需要单独携带和维护。在 Flax NNX 中,当您实例化模型时,模型参数会自动初始化,并且变量(
nnx.Variable
对象)作为属性存储在nnx.Module
(或其子模块)内部。您仍然需要为其提供一个伪随机数生成器 (PRNG) 密钥,但该密钥将被包装在nnx.Rngs
类中并存储在内部,在需要时生成更多的 PRNG 密钥。
如果您想以无状态、类似字典的方式访问 Flax 模型参数以进行检查点保存或模型修改,请查阅 Flax NNX 拆分/合并 API (nnx.split
/ nnx.merge
)。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 784))
params = model.init(jax.random.key(0), sample_x, training=False)
assert params['model/linear']['b'].shape == (10,)
assert params['model/block/linear']['w'].shape == (784, 256)
...
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
训练步骤和编译#
本节介绍如何编写训练步骤并使用 JAX 即时编译对其进行编译。
编译训练步骤时:
Haiku 使用
@jax.jit
——一个 JAX 转换——来编译一个纯函数的训练步骤。Flax NNX 使用
@nnx.jit
——一个 Flax NNX 转换(其行为类似于 JAX 转换,但也能很好地与 Flax 对象协同工作的几个转换 API 之一)。jax.jit
只接受带有纯无状态参数的函数,而flax.nnx.jit
允许参数是有状态的模块。这极大地减少了训练步骤所需的代码行数。
计算梯度时:
同样,Haiku 使用
jax.grad
(一个用于自动微分的 JAX 转换)来返回一个原始的梯度字典。与此同时,Flax NNX 使用
flax.nnx.grad
(一个 Flax NNX 转换)来返回 Flax NNX 模块的梯度,形式为flax.nnx.State
字典。如果您想在 Flax NNX 中使用常规的jax.grad
,则需要使用拆分/合并 API。
对于优化器:
如果您已经在使用 Optax 优化器(如
optax.adamw
)与 Haiku(而不是此处显示的原始jax.tree.map
计算),请查看 Flax 基础知识指南中的flax.nnx.Optimizer
示例,了解一种更简洁的训练和更新模型的方法。
每个训练步骤中的模型更新
Haiku 训练步骤需要返回一个参数的 JAX pytree,作为下一步的输入。
Flax NNX 训练步骤不需要返回任何东西,因为
model
已经在nnx.jit
内部就地更新了。此外,
nnx.Module
对象是有状态的,并且Module
会自动跟踪其内部的几项内容,例如 PRNG 密钥和flax.nnx.BatchNorm
统计信息。这就是为什么您不需要在每一步都显式传入 PRNG 密钥。另请注意,您可以使用flax.nnx.reseed
来重置其底层的 PRNG 状态。
dropout 行为
在 Haiku 中,您需要显式定义并传入
training
参数来切换haiku.dropout
,并确保只有在training=True
时才会发生随机 dropout。在 Flax NNX 中,您可以调用
model.train()
(flax.nnx.Module.train()
) 来自动将flax.nnx.Dropout
切换到训练模式。相反,您可以调用model.eval()
(flax.nnx.Module.eval()
) 来关闭训练模式。您可以在flax.nnx.Module.train
的 API 参考中了解更多关于它的作用。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params, key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # set deterministic=False
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(
inputs, # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.merge_state(params, rest))
处理非参数状态#
Haiku 对可训练参数和所有其他模型跟踪的数据(“状态”)进行了区分。例如,批归一化中使用的批次统计信息被视为一种状态。带有状态的模型需要使用 hk.transform_with_state
进行转换,以便它们的 .init()
同时返回参数和状态。
在 Flax 中,没有这样严格的区分——它们都是 nnx.Variable
的子类,并被模块视为其属性。参数是名为 nnx.Param
的子类的实例,而批次统计信息可以是另一个名为 nnx.BatchStat
的子类的实例。您可以使用 nnx.split
快速提取特定变量类型的所有数据。
让我们通过一个例子来看看这一点,我们采用上面的 Block
定义,但用 BatchNorm
替换 dropout。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
sample_x = jnp.ones((1, 784))
params, batch_stats = model.init(jax.random.key(0), sample_x, training=True)
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
Flax 考虑了可训练参数和其他数据之间的差异。nnx.grad
将只对 nnx.Param
变量求梯度,从而自动跳过 batchnorm
数组。因此,对于这个模型,Flax NNX 的训练步骤看起来是一样的。
使用多种方法#
在本节中,您将学习如何在 Haiku 和 Flax 中使用多种方法。作为示例,您将实现一个具有三种方法的自动编码器模型:encode
、decode
和 __call__
。
在 Haiku 中,您需要使用 hk.multi_transform
来显式定义模型应如何初始化以及它可以调用哪些方法(这里是 encode
和 decode
)。请注意,您仍然需要定义一个 __call__
,它会激活两个层,以便对所有模型参数进行惰性初始化。
在 Flax 中,这更简单,因为您在 __init__
中初始化参数,并且 nnx.Module
的方法 encode
和 decode
可以直接使用。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
params = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
...
参数结构如下:
...
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
_, params, _ = nnx.split(model, nnx.Param, ...)
params
{
'decoder': {
'bias': Param(value=(784,)),
'kernel': Param(value=(256, 784))
},
'encoder': {
'bias': Param(value=(256,)),
'kernel': Param(value=(784, 256))
}
}
要调用这些自定义方法:
在 Haiku 中,您需要解耦 .apply 函数以提取您的方法,然后再调用它。
在 Flax 中,您可以直接调用该方法。
encode, decode = model.apply
z = encode(params, None, x=jnp.ones((1, 784)))
...
z = model.encode(jnp.ones((1, 784)))
转换#
Haiku 和 Flax 转换都提供了各自的转换集,它们包装了 JAX 转换,使得它们可以与 Module
对象一起使用。
有关 Flax 转换的更多信息,请查阅转换指南。
让我们从一个例子开始:
首先,定义一个
RNNCell
Module
,它将包含 RNN 单个步骤的逻辑。定义一个
initial_state
方法,它将用于初始化 RNN 的状态(也称为carry
)。与jax.lax.scan
(API 文档)类似,RNNCell.__call__
方法将是一个函数,它接受进位 (carry) 和输入,并返回新的进位和输出。在这种情况下,进位和输出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个 RNN
模块,它将包含整个 RNN 的逻辑。在这两种情况下,我们都使用库的 scan
调用来在输入序列上运行 RNNCell
。
唯一的区别是,Flax 的 nnx.scan
允许您在参数 in_axes
和 out_axes
中指定要在哪个轴上重复,这些参数将被转发到底层的 `jax.lax.scan<https://jax.net.cn/en/latest/_autosummary/jax.lax.scan.html>`__,而在 Haiku 中,您需要显式地转置输入和输出。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(
cell, carry,
jnp.swapaxes(x, 1, 0)
)
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
扫描层#
大多数 Haiku 转换应该与 Flax 类似,因为它们都包装了它们的 JAX 对应项,但跨层扫描(scan-over-layers)的用例是一个例外。
跨层扫描是一种技术,您将输入通过一个由 N 个重复层组成的序列,将每个层的输出作为下一层的输入。这种模式可以显著减少大型模型的编译时间。在下面的示例中,您将在顶层 MLP
Module
中重复 Block
Module
5 次。
在 Haiku 中,我们像往常一样定义 Block
模块,然后在 MLP
内部使用 hk.experimental.layer_stack
对一个 stack_block
函数进行操作,以创建一个 Block
模块的堆栈。同样的代码将在初始化时创建 5 个层的参数,并在调用时将输入通过它们运行。
在 Flax 中,模型初始化和调用代码是完全解耦的,因此我们使用 nnx.vmap
转换来初始化底层的 Block
参数,并使用 nnx.scan
转换来通过它们运行模型输入。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jnp.ones((1, 64))
params = model.init(jax.random.key(0), sample_x, training=False)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
在上面的 Flax 示例中还有一些其他细节需要解释:
`@nnx.split_rngs` 装饰器: Flax 转换,就像它们的 JAX 对应项一样,完全不关心 PRNG 状态,而是依赖输入来获取 PRNG 密钥。
nnx.split_rngs
装饰器允许您在将nnx.Rngs
传递给被装饰的函数之前对其进行分割,并在之后将其“降低”,以便它们可以在外部使用。在这里,您分割 PRNG 密钥是因为
jax.vmap
和jax.lax.scan
如果其每个内部操作都需要自己的密钥,则需要一个 PRNG 密钥列表。因此,对于MLP
内部的 5 个层,您在进入 JAX 转换之前,从其参数中分割并提供 5 个不同的 PRNG 密钥。请注意,实际上
create_block()
知道它需要创建 5 个层,*正是因为*它看到了 5 个 PRNG 密钥,因为in_axes=(0,)
表示vmap
将查看第一个参数的第一个维度来了解它将映射的大小。forward()
也是如此,它查看第一个参数(即model
)内部的变量来确定它需要扫描多少次。nnx.split_rngs
在这里实际上分割了model
内部的 PRNG 状态。(如果Block
Module
没有 dropout,您就不需要nnx.split_rngs
这一行,因为它无论如何都不会消耗任何 PRNG 密钥。)
为什么 Flax 中的 Block 模块不需要接收和返回那个额外的虚拟值:
jax.lax.scan
(API 文档) 要求其函数返回两个输入——进位 (carry) 和堆叠的输出。在这种情况下,我们没有使用后者。Flax 简化了这一点,因此如果您将out_axes
设置为nnx.Carry
而不是默认的(nnx.Carry, 0)
,现在就可以忽略第二个输出。这是 Flax NNX 转换偏离 JAX 转换 API 的罕见情况之一。
上面的 Flax 示例中有更多的代码行,但它们更精确地表达了每个时间点发生的事情。由于 Flax 转换变得更接近 JAX 转换 API,建议在使用其 Flax NNX 等价物之前,对底层的 JAX 转换有一个很好的理解。
现在检查两侧的变量 pytree
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
_, params, _ = nnx.split(model, nnx.Param, ...)
params
{
'blocks': {
'linear': {
'bias': Param(value=(5, 64)),
'kernel': Param(value=(5, 64, 64))
}
}
}
顶层 Haiku 函数与顶层 Flax 模块#
在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state}
来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶层“模块”写成函数是很常见的做法。
Flax 团队推荐一种更以模块为中心的方法,即使用 __call__
来定义前向函数。在 Flax 模块中,可以使用常规的 Python 类语义来正常设置和访问参数和变量。
...
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter(
'multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones
)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(jax.random.key(0), jnp.ones((1, 64)))
class Counter(nnx.Variable):
pass
class FooModule(nnx.Module):
def __init__(self, rngs):
self.counter = Counter(jnp.ones((), jnp.int32))
self.multiplier = nnx.Param(
nnx.initializers.ones(rngs.params(), [1,], jnp.float32)
)
def __call__(self, x):
output = x + self.multiplier * self.counter.value
self.counter.value += 1
return output
model = FooModule(rngs=nnx.Rngs(0))
_, params, counter = nnx.split(model, nnx.Param, Counter)