从 Haiku 迁移到 Flax#
本指南将逐步介绍将 Haiku 模型迁移到 Flax 的过程,并重点介绍这两个库之间的差异。
基本示例#
要创建自定义模块,您需要在 Haiku 和 Flax 中都从 Module 基类进行子类化。但是,Haiku 类使用常规的 __init__ 方法,而 Flax 类是 dataclasses,这意味着您定义了一些用于自动生成构造函数的类属性。此外,所有 Flax 模块都接受一个 name 参数,无需定义它,而在 Haiku 中,name 必须在构造函数签名中显式定义并传递给超类构造函数。
import haiku as hk
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class Model(hk.Module):
def __init__(self, dmid: int, dout: int, name=None):
super().__init__(name=name)
self.dmid = dmid
self.dout = dout
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = hk.Linear(self.dout)(x)
return x
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
__call__ 方法在这两个库中看起来非常相似,但是,在 Flax 中,您必须使用 @nn.compact 装饰器才能在内联定义子模块。在 Haiku 中,这是默认行为。
现在,Haiku 和 Flax 在构建模型方面存在很大差异。在 Haiku 中,您使用 hk.transform 对调用模块的函数进行转换,transform 将返回一个具有 init 和 apply 方法的对象。在 Flax 中,您只需实例化您的模块即可。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform(forward)
...
model = Model(256, 10)
要获取这两个库中的模型参数,请使用 init 方法,并使用 random.key 加上一些输入来运行模型。这里的主要区别是 Flax 返回一个从集合名称到嵌套数组字典的映射,params 只是这些可能的集合之一。在 Haiku 中,您可以直接获得 params 结构。
sample_x = jax.numpy.ones((1, 784))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables["params"]
需要注意的一件非常重要的事情是,在 Flax 中,参数结构是分层的,每个嵌套模块有一层,最后一层用于参数名称。在 Haiku 中,参数结构是一个 Python 字典,具有两级层次结构:完全限定的模块名称映射到参数名称。模块名称由 / 分隔的字符串路径组成,该路径包含所有嵌套模块。
...
{
'model/block/linear': {
'b': (256,),
'w': (784, 256),
},
'model/linear': {
'b': (10,),
'w': (256, 10),
}
}
...
FrozenDict({
Block_0: {
Dense_0: {
bias: (256,),
kernel: (784, 256),
},
},
Dense_0: {
bias: (10,),
kernel: (256, 10),
},
})
在两个框架中的训练过程中,您将参数结构传递给 apply 方法以运行正向传递。由于我们使用的是丢弃,因此在这两种情况下,我们都必须向 apply 提供一个 key 以便生成随机丢弃掩码。
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
params,
key,
inputs, training=True # <== inputs
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params
最显著的差异是,在 Flax 中,您必须将参数放在带有 params 键的字典中,并将键放在带有 dropout 键的字典中。这是因为在 Flax 中,您可以拥有多种类型的模型状态和随机状态。在 Haiku 中,您只需直接传递参数和键。
处理状态#
现在让我们看看这两个库是如何处理可变状态的。我们将使用与之前相同的模型,但现在我们将用批量归一化替换丢弃。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.BatchNorm(
create_scale=True, create_offset=True, decay_rate=0.99
)(x, is_training=training)
x = jax.nn.relu(x)
return x
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.BatchNorm(
momentum=0.99
)(x, use_running_average=not training)
x = jax.nn.relu(x)
return x
在这种情况下,代码非常相似,因为两个库都提供了一个批量归一化层。最显著的差异是 Haiku 使用 is_training 来控制是否更新运行统计信息,而 Flax 使用 use_running_average 来实现相同目的。
要在 Haiku 中实例化有状态模型,您需要使用 hk.transform_with_state,它会更改 init 和 apply 的签名以接受和返回状态。与之前一样,在 Flax 中,您直接构建模块。
def forward(x, training: bool):
return Model(256, 10)(x, training)
model = hk.transform_with_state(forward)
...
model = Model(256, 10)
要初始化参数和状态,您只需像以前一样调用 init 方法即可。但是,在 Haiku 中,您现在会获得 state 作为第二个返回值,而在 Flax 中,您会在 variables 字典中获得一个新的 batch_stats 集合。请注意,由于 hk.BatchNorm 仅在 is_training=True 时初始化批量统计信息,因此我们在初始化具有 hk.BatchNorm 层的 Haiku 模型的参数时必须设置 training=True。在 Flax 中,我们可以像往常一样设置 training=False。
sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
random.key(0),
sample_x, training=True # <== inputs
)
...
sample_x = jax.numpy.ones((1, 784))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
通常,在 Flax 中,您可能会在 variables 字典中找到其他状态集合,例如 cache(用于自回归 Transformer 模型)、intermediates(用于使用 Module.sow 添加的中间值)或其他由自定义层定义的集合名称。Haiku 仅区分 params(在运行 apply 时不会改变的变量)和 state(在运行 apply 时可能会改变的变量),并使用 hk.transform 或 hk.transform_with_state。
现在,训练在这两个框架中看起来非常相似,因为您使用相同的 apply 方法来运行正向传递。在 Haiku 中,现在将 state 作为第二个参数传递给 apply,并将新状态作为第二个返回值获得。在 Flax 中,您改为将 batch_stats 作为新键添加到输入字典中,并将 updates 变量字典作为第二个返回值获得。
def train_step(params, state, inputs, labels):
def loss_fn(params):
logits, new_state = model.apply(
params, state,
None, # <== rng
inputs, training=True # <== inputs
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, new_state
grads, new_state = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, new_state
def train_step(params, batch_stats, inputs, labels):
def loss_fn(params):
logits, updates = model.apply(
{'params': params, 'batch_stats': batch_stats},
inputs, training=True, # <== inputs
mutable='batch_stats',
)
loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
return loss, updates["batch_stats"]
grads, batch_stats = jax.grad(loss_fn, has_aux=True)(params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
return params, batch_stats
一个主要区别是,在 Flax 中,状态集合可以是可变的,也可以是不可变的。在 init 期间,默认情况下所有集合都是可变的,但是,在 apply 期间,您必须显式指定哪些集合是可变的。在本例中,我们指定 batch_stats 是可变的。这里传递的是单个字符串,但如果存在更多可变集合,则也可以传递列表。如果未执行此操作,则在尝试更改 batch_stats 时,将在运行时引发错误。此外,当 mutable 不是 False 时,updates 字典将作为 apply 的第二个返回值返回,否则仅返回模型输出。Haiku 通过使用 params(不可变)和 state(可变)以及使用 hk.transform 或 hk.transform_with_state 来区分可变/不可变。
使用多个方法#
在本节中,我们将了解如何在 Haiku 和 Flax 中使用多个方法。例如,我们将实现一个具有三个方法的自编码器模型:encode、decode 和 __call__。
在 Haiku 中,我们只需直接在 __init__ 中定义 encode 和 decode 所需的子模块,在本例中,每个模块都将使用 Linear 层。在 Flax 中,我们将在 setup 中提前定义 encoder 和 decoder 模块,并在 encode 和 decode 中分别使用它们。
class AutoEncoder(hk.Module):
def __init__(self, embed_dim: int, output_dim: int, name=None):
super().__init__(name=name)
self.encoder = hk.Linear(embed_dim, name="encoder")
self.decoder = hk.Linear(output_dim, name="decoder")
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
请注意,在 Flax 中,setup 不会在 __init__ 之后运行,而是在调用 init 或 apply 时运行。
现在,我们希望能够从我们的 AutoEncoder 模型中调用任何方法。在 Haiku 中,我们可以通过 hk.multi_transform 为模块定义多个 apply 方法。传递给 multi_transform 的函数定义了如何初始化模块以及要生成哪些不同的应用方法。
def forward():
module = AutoEncoder(256, 784)
init = lambda x: module(x)
return init, (module.encode, module.decode)
model = hk.multi_transform(forward)
...
model = AutoEncoder(256, 784)
为了初始化模型的参数,可以使用 init 触发 __call__ 方法,该方法同时使用 encode 和 decode 方法。这将创建模型所需的所有参数。
params = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
...
variables = model.init(
random.key(0),
x=jax.numpy.ones((1, 784)),
)
params = variables["params"]
这将生成以下参数结构。
{
'auto_encoder/~/decoder': {
'b': (784,),
'w': (256, 784)
},
'auto_encoder/~/encoder': {
'b': (256,),
'w': (784, 256)
}
}
FrozenDict({
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
})
最后,让我们探索如何使用 apply 函数调用 encode 方法。
encode, decode = model.apply
z = encode(
params,
None, # <== rng
x=jax.numpy.ones((1, 784)),
)
...
z = model.apply(
{"params": params},
x=jax.numpy.ones((1, 784)),
method="encode",
)
由于 Haiku apply 函数是通过 hk.multi_transform 生成的,它是一个包含两个函数的元组,我们可以将其解包为一个 encode 函数和一个 decode 函数,它们对应于 AutoEncoder 模块上的方法。在 Flax 中,我们通过将方法名称作为字符串传递来调用 encode 方法。这里另一个值得注意的区别是,在 Haiku 中,即使模块在 apply 期间没有使用任何随机操作,也需要显式地传递 rng。在 Flax 中,这不是必需的(请查看 Flax 中的随机性和 PRNG)。这里的 Haiku rng 设置为 None,但你也可以在 apply 函数上使用 hk.without_apply_rng 来删除 rng 参数。
提升的转换#
Flax 和 Haiku 都提供了一组转换,我们将它们称为提升的转换,这些转换以一种可以与模块一起使用的方式包装 JAX 转换,并且有时会提供额外的功能。在本节中,我们将了解如何在 Flax 和 Haiku 中使用 scan 的提升版本来实现一个简单的 RNN 层。
首先,我们将定义一个 RNNCell 模块,该模块将包含 RNN 单步的逻辑。我们还将定义一个 initial_state 方法,该方法将用于初始化 RNN 的状态(也称为 carry)。与 jax.lax.scan 一样,RNNCell.__call__ 方法将是一个函数,它接收传递和输入,并返回新的传递和输出。在这种情况下,传递和输出是相同的。
class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,我们将定义一个 RNN 模块,该模块将包含整个 RNN 的逻辑。在 Haiku 中,我们将首先初始化 RNNCell,然后使用它来构建 carry,最后使用 hk.scan 在输入序列上运行 RNNCell。在 Flax 中,它的做法略有不同,我们将使用 nn.scan 来定义一个新的临时类型,该类型包装 RNNCell。在此过程中,我们还将指定指示 nn.scan 广播 params 集合(所有步骤共享相同的参数)并且不拆分 params rng 流(以便所有步骤使用相同的参数进行初始化),最后,我们将指定我们希望 scan 在输入的第二个轴上运行,并将输出也沿着第二个轴堆叠起来。然后,我们将立即使用此临时类型来创建提升的 RNNCell 的实例,并使用它来创建 carry 并运行 __call__ 方法,该方法将在序列上进行 scan。
class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size
def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
y = jnp.swapaxes(y, 0, 1)
return y
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
in_axes=1, out_axes=1)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
总的来说,Flax 和 Haiku 之间提升的转换的主要区别在于,在 Haiku 中,提升的转换不会对状态进行操作,也就是说,Haiku 将以一种在转换内外保持相同形状的方式处理 params 和 state。在 Flax 中,提升的转换可以对变量集合和 rng 流进行操作,用户必须根据转换的语义定义每个转换如何处理不同的集合。
最后,让我们快速查看如何在 Haiku 和 Flax 中使用 RNN 模块。
def forward(x):
return RNN(64)(x)
model = hk.without_apply_rng(hk.transform(forward))
params = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
y = model.apply(
params,
x=jax.numpy.ones((3, 12, 32)),
)
...
model = RNN(64)
variables = model.init(
random.key(0),
x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
{'params': params},
x=jax.numpy.ones((3, 12, 32)),
)
与前面部分中的示例相比,唯一值得注意的变化是这次我们在 Haiku 中使用了 hk.without_apply_rng,因此我们不必将 rng 参数作为 None 传递给 apply 方法。
在层上进行扫描#
scan 的一个非常重要的应用是,迭代地在输入上应用一系列层,将每个层的输出作为下一个层的输入传递。这对于减少大型模型的编译时间非常有用。例如,我们将创建一个简单的 Block 模块,然后将其用在 MLP 模块中,该模块将应用 num_layers 次的 Block 模块。
在 Haiku 中,我们像往常一样定义 Block 模块,然后在 MLP 中,我们将在 stack_block 函数上使用 hk.experimental.layer_stack 来创建一个 Block 模块堆栈。在 Flax 中,Block 的定义略有不同,__call__ 将接收和返回一个第二个虚拟输入/输出,它们在两种情况下都将为 None。在 MLP 中,我们将像在前面的示例中一样使用 nn.scan,但通过设置 split_rngs={'params': True} 和 variable_axes={'params': 0},我们告诉 nn.scan 为每个步骤创建不同的参数并沿着第一个轴切片 params 集合,从而有效地实现一个 Block 模块堆栈,就像在 Haiku 中一样。
class Block(hk.Module):
def __init__(self, features: int, name=None):
super().__init__(name=name)
self.features = features
def __call__(self, x, training: bool):
x = hk.Linear(self.features)(x)
x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
x = jax.nn.relu(x)
return x
class MLP(hk.Module):
def __init__(self, features: int, num_layers: int, name=None):
super().__init__(name=name)
self.features = features
self.num_layers = num_layers
def __call__(self, x, training: bool):
@hk.experimental.layer_stack(self.num_layers)
def stack_block(x):
return Block(self.features)(x, training)
stack = hk.experimental.layer_stack(self.num_layers)
return stack_block(x)
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
请注意,在 Flax 中,我们如何将 None 作为第二个参数传递给 ScanBlock 并忽略其第二个输出。这些表示每个步骤的输入/输出,但它们是 None,因为在这种情况下,我们没有任何输入/输出。
初始化每个模型与前面的示例相同。在这种情况下,我们将指定我们希望使用 5 个层,每个层具有 64 个特征。
def forward(x, training: bool):
return MLP(64, num_layers=5)(x, training)
model = hk.transform(forward)
sample_x = jax.numpy.ones((1, 64))
params = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
...
...
model = MLP(64, num_layers=5)
sample_x = jax.numpy.ones((1, 64))
variables = model.init(
random.key(0),
sample_x, training=False # <== inputs
)
params = variables['params']
当在层上使用 scan 时,你应该注意的是,所有层都融合成一个单独的层,其参数在第一个轴上有一个额外的“层”维度。在这种情况下,所有参数的形状都将以 (5, ...) 开头,因为我们使用的是 5 个层。
...
{
'mlp/__layer_stack_no_per_layer/block/linear': {
'b': (5, 64),
'w': (5, 64, 64)
}
}
...
FrozenDict({
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
})
顶级 Haiku 函数与顶级 Flax 模块#
在 Haiku 中,可以通过使用原始的 hk.{get,set}_{parameter,state} 来定义/访问模型参数和状态,从而将整个模型编写为单个函数。将顶级“模块”编写为函数非常常见。
Flax 团队建议使用更以模块为中心的 approach,该 approach 使用 __call__ 来定义前向函数。相应的访问器将是 nn.module.param 和 nn.module.variable(有关集合的解释,请转到 处理状态)。
def forward(x):
counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)
return output
model = hk.transform_with_state(forward)
params, state = model.init(random.key(0), jax.numpy.ones((1, 64)))
class FooModule(nn.Module):
@nn.compact
def __call__(self, x):
counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
output = x + multiplier * counter.value
if not self.is_initializing(): # otherwise model.init() also increases it
counter.value += 1
return output
model = FooModule()
variables = model.init(random.key(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']