目录
动机#
我们当前的 API(本文档中称为旧版 API)使用一种模式,即从 target
变量的 pytree 和定义如何更新优化器状态、超参数和目标变量的 OptimizerDef
创建一个 Optimizer
数据类。这种模式对于实现简单的优化器来说相对复杂,同时在典型的 Linen 训练步骤中(尤其是在使用可变状态集合时)相当冗长。
这个包 flax.optim
包含一些优化器,但这个列表远非详尽,理想情况下我们应该从一个专门的 PyPi 包中使用 JAX 优化器。
DeepMind 已经有一个专门的库 — Optax — 它实现了各种有趣的优化器,并提供了一个框架,可以从可重用的梯度变换中组合新的优化器。
使用 Optax#
梯度变换#
虽然 Optax 确实提供了预定义的优化器(例如 optax.adam
或带有动量的 optax.sgd
),但它实际上是一个*梯度变换*库,实例化优化器的惯用方法是提供这些梯度变换的组合。为了在使用旧版 API时模拟示例中的动量优化器,我们将编写
import optax
tx = optax.chain(
optax.trace(decay=0.9, nesterov=False),
optax.scale_by_schedule(lambda step: -get_learning_rate(step)),
)
备注
上述梯度变换将等同于Optimizer 和 OptimizerDef下的示例,其中我们定义了一个没有 Nesterov 动量的 Momentum 优化器(请注意,
beta
参数对应于optax.trace()
变换的decay
参数,学习率在第二个链式变换中应用)。请注意,像
decay
或nesterov
这样的超参数仅存在于返回GradientTransformation
的高阶函数的内部作用域中。这样的梯度变换当前被定义为init()
和update()
函数的NamedTuple
。原则上,这种模式也可以扩展到存储超参数,这可能是在 Optax 仓库上讨论的一个点。我们可以在定义 Optax 梯度更新变换时,使用一个根据步数返回学习率的
get_learning_rate()
函数。上述代码展示了这如何作为我们旧版训练步骤中也使用的函数的直接替代品,该更新函数已经存在(请注意,我们需要反转符号,因为我们将梯度更新添加到参数中)。此外,您可以使用inject_hyperparams()
来使用 Optax 调度任意超参数。
Optax 训练步骤#
@functools.partial(jax.jit, static_argnums=(4, 5))
def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn):
def loss_fn(params):
logits, new_model_state = apply_fn(
{**variables, 'params': params}, inputs, mutable=['batch_stats'])
loss = xent_loss(logits, labels)
return loss, new_model_state
variables, params = variables.pop('params')
(loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
params)
updates, new_opt_state = tx_update_fn(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
new_variables = {**variables, **new_model_state, 'params': new_params}
return new_opt_state, new_variables, loss
opt_state = tx.init(variables['params'])
for batch in ds.as_numpy_iterator():
opt_state, variables, loss = train_step(
opt_state, variables, batch['image'], batch['label'], model.apply,
tx.update)
print(loss)
备注
由于
tx.update()
仅转换梯度,我们仍然需要调用optax.apply_updates()
来将这些转换后的梯度应用于参数。与旧版 API相比,我们现在可以将包括
params
在内的整个variables
作为train_step()
的输入和输出。在训练步骤中,仍然需要将
params
从variables
中分离出来,因为我们只想计算相对于params
而不是整个variables
的梯度。只要 Optax 变换在其各自的状态中暴露这些信息,我们仍然可以记录内部优化器状态,例如学习率。例如,
optax.scale_by_schedule()
当前仅暴露opt_state.count
,但可以轻松扩展以暴露step_size
。对于随时间变化的内部优化器状态也是如此。
多优化器#
旧版 API 定义了 flax.optim.MultiOptimizer
用于使用不同的优化器处理参数树的不同部分
biases_traversal = flax.optim.ModelParamTraversal(
lambda path, _: path.endswith('/bias'))
not_biases_traversal = flax.optim.ModelParamTraversal(
lambda path, _: not path.endswith('/bias'))
optimizer_def = flax.optim.MultiOptimizer(
(biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)),
(not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)),
)
请注意,我们如何首先定义一个遍历,根据参数的路径(模块作用域和变量名称的串联)选择参数,然后创建一个 MultiOptimizer
,为每个独立的遍历绑定一个不同的优化器。
Optax 最近实现了 optax.masked()
,可用于指定仅应用于梯度子集的梯度变换
def flattened_traversal(fn):
def mask(data):
flat = traverse_util.flatten_dict(data)
return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
return mask
tx = optax.chain(
optax.masked(optax.sgd(learning_rate=0.1),
mask=flattened_traversal(lambda path, _: path[-1] == 'bias')),
optax.masked(optax.sgd(learning_rate=0.05),
mask=flattened_traversal(lambda path, _: path[-1] != 'bias')),
)
训练状态#
在 Flax 中,通常会传递一个 TrainState
对象,该对象可以用于检查点。这通过减少参数数量和去除 static_argnums
,简化了上述Optax 训练步骤。
我们可以定义一个 TrainState
数据类,它封装了通过应用梯度来更新优化器状态和参数的常见模式。
# Small helper class in flax.training
class TrainState(flax.struct.PyTreeNode):
step: int
apply_fn: Callable = flax.struct.field(pytree_node=False)
params: flax.core.FrozenDict[str, Any]
tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
opt_state: optax.OptState
def apply_gradients(self, *, grads, **kwargs):
updates, new_opt_state = self.tx.update(
grads, self.opt_state, self.params)
new_params = optax.apply_updates(self.params, updates)
return self.replace(
step=self.step + 1,
params=new_params,
opt_state=new_opt_state,
**kwargs,
)
@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
opt_state = tx.init(params)
return cls(
step=0,
apply_fn=apply_fn,
params=params,
tx=tx,
opt_state=opt_state,
**kwargs,
)
用户可以从这个数据类派生并添加更多字段,例如可变模型状态
from flax.training import train_state
class TrainState(train_state.TrainState):
batch_stats: flax.core.FrozenDict[str, Any]
这样,Optax 训练步骤就变成了
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params):
outputs, new_model_state = state.apply_fn(
{'params': params, 'batch_stats': state.batch_stats},
inputs,
mutable=['batch_stats'])
loss = xent_loss(outputs, labels)
return loss, new_model_state
(loss, new_model_state), grads = jax.value_and_grad(
loss_fn, has_aux=True)(state.params)
new_state = state.apply_gradients(
grads=grads,
batch_stats=new_model_state['batch_stats'],
)
return new_state, loss
state = TrainState.create(
apply_fn=model.apply,
params=variables['params'],
tx=tx,
batch_stats=variables['batch_stats'],
)
for batch in ds.as_numpy_iterator():
state, loss = train_step(state, batch['image'], batch['label'])
不带可变状态的训练步骤简化为
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params):
outputs = state.apply_fn({'params': params}, inputs)
loss = xent_loss(outputs, labels)
return loss
loss, grads = jax.value_and_grad(loss_fn)(state.params)
new_state = state.update(grads=grads)
return new_state, loss
state = flax.training.TrainState.create(
apply_fn=model.apply,
params=variables['params'],
tx=tx,
)
for batch in ds.as_numpy_iterator():
state, loss = train_step(state, batch['image'], batch['label'])
备注
在 Flax 训练循环中,一个常见的模式是拥有一个
TrainState
数据类,它在每个步骤后都会用新状态进行更新。flax.training.train_state
中提出的简单解决方案可以扩展以包含额外数据,但不支持高级用例(例如,多个不同的模型和/或优化器)。用户应分叉数据类并根据自己的需求重新实现。与旧版 API中的
Optimizer
抽象不同,现在TrainState
直接包含.params
,无需通过.optimizer
旧版 API#
Optimizer 和 OptimizerDef#
优化器本身将通过创建从 OpimizerDef
派生新类来实现
# flax/optim/momentum.py
@flax.struct.dataclass
class _MomentumHyperParams:
learning_rate: jnp.ndarray
beta: jnp.ndarray
@flax.struct.dataclass
class _MomentumParamState:
momentum: np.ndarray
class Momentum(flax.optim.OptimizerDef):
def __init__(self, learning_rate=None, beta=0.9):
super().__init__(
_MomentumHyperParams(learning_rate, beta)
)
def init_param_state(self, param):
return _MomentumParamState(jnp.zeros_like(param))
def apply_param_gradient(self, step, hyper_params, param, state, grad):
del step
assert hyper_params.learning_rate is not None
new_momentum = state.momentum * hyper_params.beta + grad
new_params = param - hyper_params.learning_rate * new_momentum
return new_params, _MomentumParamState(new_momentum)
备注
注意
OptimizerDef
和Optimizer
之间的关系:当用户代码调用Optimizer.apply_gradient()
函数时,它会(除其他外)调用OptimizerDef.apply_gradient()
,后者又会调用OptimizerDef.apply_param_gradient()
(由OptimizerDef
的子类实现)。函数
init_param_state()
和apply_param_gradient()
对 params/grads pytree 中的每个叶子节点都被调用。这使得可以直接编写计算而无需jax.tree_util.tree_map()
。该接口在 Linen 之前定义,没有考虑
params
与variables
中其他集合的区别。最初的 API 很优雅,因为只需要传递优化器,其中包含参数、优化器状态、优化器超参数以及对OptimizerDef
的引用,以执行参数/状态更新。
旧版训练步骤#
优化器首先由其定义和目标参数的 pytree 构建
optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(variables['params'])
然后,目标变量将在训练步骤中进行优化(假设只有一个非参数集合“batch_stats”)
def make_train_step(apply_fn):
@jax.jit
def train_step(optimizer, batch_stats, inputs, labels):
def loss_fn(params):
variables = {'params': params, 'batch_stats': batch_stats}
logits, new_model_state = apply_fn(
variables, inputs, mutable=['batch_stats'])
loss = xent_loss(logits, labels)
return loss, new_model_state['batch_stats']
(loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(
optimizer.target)
lr = get_learning_rate(step)
new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
return new_optimizer, new_batch_stats, loss
return train_step
batch_stats = variables['batch_stats']
train_step = make_train_step(model.apply)
for step, batch in enumerate(ds)
optimizer, batch_stats, loss = train_step(
optimizer, batch_stats, batch['image'], batch['label'])
备注
请注意
optimizer.apply_gradient()
如何接受额外参数以更新超参数,例如本例中来自独立函数get_learning_rate()
的学习率。
更新计划#
完成关于此 FLIP 的讨论
向 Optax 添加 等价性测试,以保证现有的
flax.optim
优化器与相应的optax
优化器返回相同的值。更新示例以使用 Optax 并验证它们以相同的计算成本达到相同的最终性能。
将缺失的优化器移植到 Optax(例如 Adafactor)——并验证上述几点。
更新所有文档(包括 README、Flax 引导教程、HOWTOs 等),使其只讨论 Optax 优化器。
为用户提供从
flax.optim
更新到使用 Optax 的过渡指南。此过渡指南还应指向 Optax 的 等价性测试 以及更新示例的拉取请求。将
flax.optim
中的优化器标记为已弃用。
请注意,所有当前的 Flax 示例都使用 Optax 中已有的优化器
示例 |
亚麻 |
光学 |
评论 |
---|---|---|---|
图像网 |
optim.Momentum |
optax.sgd |
DynamicScale 可保持不变使用。 |
mnist |
optim.Momentum |
optax.sgd |
|
nlp_seq |
optim.Adam |
optax.adamw |
|
pixelcnn |
optim.Adam |
optax.adam |
|
ppo |
optim.Adam |
optax.adam |
|
seq2seq |
optim.Adam |
optax.adam |
|
vae |
optim.Adam |
optax.adam |
|
wmt |
optim.Adam |
optax.adamw |
(Flax 的 Adam 实现有一个可选的权重衰减参数,但在 Optax 中,带和不带权重衰减的 Adam 是两个不同的别名。)
附录#
设置代码#
以下设置代码可用于运行此 FLIP 中的代码片段
import functools
from typing import Callable, Sequence
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds
def pp(features):
return {
'image': tf.cast(features['image'], tf.float32) / 255 - 0.5,
'label': features['label'],
}
class Model(nn.Module):
@nn.compact
def __call__(self, inputs):
x = inputs.reshape([inputs.shape[0], -1])
x = nn.normalization.BatchNorm(True)(x)
x = nn.Dense(10)(x)
x = nn.log_softmax(x)
return x
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
x = (labels[..., None] == jnp.arange(num_classes)[None])
x = jax.lax.select(
x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
return x.astype(jnp.float32)
def xent_loss(logits, labels):
return -jnp.sum(
onehot(labels, num_classes=10) * logits) / labels.size
def get_learning_rate(step):
return 0.1
model = Model()
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
jax.tree_util.tree_map(jnp.shape, variables)