• 开始日期:2021-02-08

  • FLIP PR:#1011

  • FLIP Issue:#1009

目录

摘要#

本 FLIP 提议用 DeepMind 的优化器库 Optax 替换我们当前的 flax.optim API(本文档中称为旧版 API)。

动机#

我们当前的 API(本文档中称为旧版 API)使用一种模式,即从 target 变量的 pytree 和定义如何更新优化器状态、超参数和目标变量的 OptimizerDef 创建一个 Optimizer 数据类。这种模式对于实现简单的优化器来说相对复杂,同时在典型的 Linen 训练步骤中(尤其是在使用可变状态集合时)相当冗长。

这个包 flax.optim 包含一些优化器,但这个列表远非详尽,理想情况下我们应该从一个专门的 PyPi 包中使用 JAX 优化器。

DeepMind 已经有一个专门的库 — Optax — 它实现了各种有趣的优化器,并提供了一个框架,可以从可重用的梯度变换中组合新的优化器。

使用 Optax#

梯度变换#

虽然 Optax 确实提供了预定义的优化器(例如 optax.adam 或带有动量的 optax.sgd),但它实际上是一个*梯度变换*库,实例化优化器的惯用方法是提供这些梯度变换的组合。为了在使用旧版 API时模拟示例中的动量优化器,我们将编写

import optax

tx = optax.chain(
    optax.trace(decay=0.9, nesterov=False),
    optax.scale_by_schedule(lambda step: -get_learning_rate(step)),
)

备注

  • 上述梯度变换将等同于Optimizer 和 OptimizerDef下的示例,其中我们定义了一个没有 Nesterov 动量的 Momentum 优化器(请注意,beta 参数对应于 optax.trace() 变换的 decay 参数,学习率在第二个链式变换中应用)。

  • 请注意,像 decaynesterov 这样的超参数仅存在于返回 GradientTransformation 的高阶函数的内部作用域中。这样的梯度变换当前被定义为 init()update() 函数的 NamedTuple。原则上,这种模式也可以扩展到存储超参数,这可能是在 Optax 仓库上讨论的一个点。

  • 我们可以在定义 Optax 梯度更新变换时,使用一个根据步数返回学习率的 get_learning_rate() 函数。上述代码展示了这如何作为我们旧版训练步骤中也使用的函数的直接替代品,该更新函数已经存在(请注意,我们需要反转符号,因为我们将梯度更新添加到参数中)。此外,您可以使用 inject_hyperparams() 来使用 Optax 调度任意超参数。

Optax 训练步骤#

@functools.partial(jax.jit, static_argnums=(4, 5))
def train_step(opt_state, variables, inputs, labels, apply_fn, tx_update_fn):

  def loss_fn(params):
    logits, new_model_state = apply_fn(
        {**variables, 'params': params}, inputs, mutable=['batch_stats'])
    loss = xent_loss(logits, labels)
    return loss, new_model_state

  variables, params = variables.pop('params')
  (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      params)
  updates, new_opt_state = tx_update_fn(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)
  new_variables = {**variables, **new_model_state, 'params': new_params}
  return new_opt_state, new_variables, loss


opt_state = tx.init(variables['params'])
for batch in ds.as_numpy_iterator():
  opt_state, variables, loss = train_step(
      opt_state, variables, batch['image'], batch['label'], model.apply,
      tx.update)
  print(loss)

备注

  • 由于 tx.update() 仅转换梯度,我们仍然需要调用 optax.apply_updates() 来将这些转换后的梯度应用于参数。

  • 旧版 API相比,我们现在可以将包括 params 在内的整个 variables 作为 train_step() 的输入和输出。

  • 在训练步骤中,仍然需要将 paramsvariables 中分离出来,因为我们只想计算相对于 params 而不是整个 variables 的梯度。

  • 只要 Optax 变换在其各自的状态中暴露这些信息,我们仍然可以记录内部优化器状态,例如学习率。例如,optax.scale_by_schedule() 当前仅暴露 opt_state.count,但可以轻松扩展以暴露 step_size。对于随时间变化的内部优化器状态也是如此。

多优化器#

旧版 API 定义了 flax.optim.MultiOptimizer 用于使用不同的优化器处理参数树的不同部分

biases_traversal = flax.optim.ModelParamTraversal(
    lambda path, _: path.endswith('/bias'))
not_biases_traversal = flax.optim.ModelParamTraversal(
    lambda path, _: not path.endswith('/bias'))

optimizer_def = flax.optim.MultiOptimizer(
    (biases_traversal, flax.optim.GradientDescent(learning_rate=0.1)),
    (not_biases_traversal, flax.optim.GradientDescent(learning_rate=0.05)),
)

请注意,我们如何首先定义一个遍历,根据参数的路径(模块作用域和变量名称的串联)选择参数,然后创建一个 MultiOptimizer,为每个独立的遍历绑定一个不同的优化器。

Optax 最近实现了 optax.masked(),可用于指定仅应用于梯度子集的梯度变换

def flattened_traversal(fn):
  def mask(data):
    flat = traverse_util.flatten_dict(data)
    return traverse_util.unflatten_dict({k: fn(k, v) for k, v in flat.items()})
  return mask

tx = optax.chain(
    optax.masked(optax.sgd(learning_rate=0.1),
                 mask=flattened_traversal(lambda path, _: path[-1] == 'bias')),
    optax.masked(optax.sgd(learning_rate=0.05),
                 mask=flattened_traversal(lambda path, _: path[-1] != 'bias')),
)

训练状态#

在 Flax 中,通常会传递一个 TrainState 对象,该对象可以用于检查点。这通过减少参数数量和去除 static_argnums,简化了上述Optax 训练步骤

我们可以定义一个 TrainState 数据类,它封装了通过应用梯度来更新优化器状态和参数的常见模式。

# Small helper class in flax.training
class TrainState(flax.struct.PyTreeNode):
  step: int
  apply_fn: Callable = flax.struct.field(pytree_node=False)
  params: flax.core.FrozenDict[str, Any]
  tx: optax.GradientTransformation = flax.struct.field(pytree_node=False)
  opt_state: optax.OptState

  def apply_gradients(self, *, grads, **kwargs):
    updates, new_opt_state = self.tx.update(
        grads, self.opt_state, self.params)
    new_params = optax.apply_updates(self.params, updates)
    return self.replace(
        step=self.step + 1,
        params=new_params,
        opt_state=new_opt_state,
        **kwargs,
    )

  @classmethod
  def create(cls, *, apply_fn, params, tx, **kwargs):
    opt_state = tx.init(params)
    return cls(
        step=0,
        apply_fn=apply_fn,
        params=params,
        tx=tx,
        opt_state=opt_state,
        **kwargs,
    )

用户可以从这个数据类派生并添加更多字段,例如可变模型状态

from flax.training import train_state

class TrainState(train_state.TrainState):
  batch_stats: flax.core.FrozenDict[str, Any]

这样,Optax 训练步骤就变成了

@jax.jit
def train_step(state, inputs, labels):

  def loss_fn(params):
    outputs, new_model_state = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        inputs,
        mutable=['batch_stats'])
    loss = xent_loss(outputs, labels)
    return loss, new_model_state

  (loss, new_model_state), grads = jax.value_and_grad(
      loss_fn, has_aux=True)(state.params)
  new_state = state.apply_gradients(
      grads=grads,
      batch_stats=new_model_state['batch_stats'],
  )

  return new_state, loss


state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
    batch_stats=variables['batch_stats'],
)
for batch in ds.as_numpy_iterator():
  state, loss = train_step(state, batch['image'], batch['label'])

不带可变状态的训练步骤简化为

@jax.jit
def train_step(state, inputs, labels):

  def loss_fn(params):
    outputs = state.apply_fn({'params': params}, inputs)
    loss = xent_loss(outputs, labels)
    return loss

  loss, grads = jax.value_and_grad(loss_fn)(state.params)
  new_state = state.update(grads=grads)

  return new_state, loss


state = flax.training.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
)
for batch in ds.as_numpy_iterator():
  state, loss = train_step(state, batch['image'], batch['label'])

备注

  • 在 Flax 训练循环中,一个常见的模式是拥有一个 TrainState 数据类,它在每个步骤后都会用新状态进行更新。

  • flax.training.train_state 中提出的简单解决方案可以扩展以包含额外数据,但不支持高级用例(例如,多个不同的模型和/或优化器)。用户应分叉数据类并根据自己的需求重新实现。

  • 旧版 API中的 Optimizer 抽象不同,现在 TrainState 直接包含 .params,无需通过 .optimizer

旧版 API#

Optimizer 和 OptimizerDef#

优化器本身将通过创建从 OpimizerDef 派生新类来实现

# flax/optim/momentum.py

@flax.struct.dataclass
class _MomentumHyperParams:
  learning_rate: jnp.ndarray
  beta: jnp.ndarray


@flax.struct.dataclass
class _MomentumParamState:
  momentum: np.ndarray


class Momentum(flax.optim.OptimizerDef):

  def __init__(self, learning_rate=None, beta=0.9):
    super().__init__(
      _MomentumHyperParams(learning_rate, beta)
    )

  def init_param_state(self, param):
    return _MomentumParamState(jnp.zeros_like(param))

  def apply_param_gradient(self, step, hyper_params, param, state, grad):
    del step
    assert hyper_params.learning_rate is not None
    new_momentum = state.momentum * hyper_params.beta + grad
    new_params = param - hyper_params.learning_rate * new_momentum
    return new_params, _MomentumParamState(new_momentum)

备注

  • 注意 OptimizerDefOptimizer 之间的关系:当用户代码调用 Optimizer.apply_gradient() 函数时,它会(除其他外)调用 OptimizerDef.apply_gradient(),后者又会调用 OptimizerDef.apply_param_gradient()(由 OptimizerDef 的子类实现)。

  • 函数 init_param_state()apply_param_gradient() 对 params/grads pytree 中的每个叶子节点都被调用。这使得可以直接编写计算而无需 jax.tree_util.tree_map()

  • 该接口在 Linen 之前定义,没有考虑 paramsvariables 中其他集合的区别。最初的 API 很优雅,因为只需要传递优化器,其中包含参数、优化器状态、优化器超参数以及对 OptimizerDef 的引用,以执行参数/状态更新。

旧版训练步骤#

优化器首先由其定义和目标参数的 pytree 构建

optimizer_def = flax.optim.Momentum(learning_rate=0.1, beta=0.9)
optimizer = optimizer_def.create(variables['params'])

然后,目标变量将在训练步骤中进行优化(假设只有一个非参数集合“batch_stats”)

def make_train_step(apply_fn):
  @jax.jit
  def train_step(optimizer, batch_stats, inputs, labels):

    def loss_fn(params):
      variables = {'params': params, 'batch_stats': batch_stats}
      logits, new_model_state = apply_fn(
          variables, inputs, mutable=['batch_stats'])
      loss = xent_loss(logits, labels)
      return loss, new_model_state['batch_stats']

    (loss, new_batch_stats), grad = jax.value_and_grad(loss_fn, has_aux=True)(
        optimizer.target)
    lr = get_learning_rate(step)
    new_optimizer = optimizer.apply_gradient(grad, learning_rate=lr)
    return new_optimizer, new_batch_stats, loss

  return train_step


batch_stats = variables['batch_stats']
train_step = make_train_step(model.apply)
for step, batch in enumerate(ds)
  optimizer, batch_stats, loss = train_step(
      optimizer, batch_stats, batch['image'], batch['label'])

备注

  • 请注意 optimizer.apply_gradient() 如何接受额外参数以更新超参数,例如本例中来自独立函数 get_learning_rate() 的学习率。

更新计划#

  1. 完成关于此 FLIP 的讨论

  2. 向 Optax 添加 等价性测试,以保证现有的 flax.optim 优化器与相应的 optax 优化器返回相同的值。

  3. 更新示例以使用 Optax 并验证它们以相同的计算成本达到相同的最终性能。

  4. 将缺失的优化器移植到 Optax(例如 Adafactor)——并验证上述几点。

  5. 更新所有文档(包括 README、Flax 引导教程、HOWTOs 等),使其只讨论 Optax 优化器。

  6. 为用户提供从 flax.optim 更新到使用 Optax 的过渡指南。此过渡指南还应指向 Optax 的 等价性测试 以及更新示例的拉取请求。

  7. flax.optim 中的优化器标记为已弃用。

请注意,所有当前的 Flax 示例都使用 Optax 中已有的优化器

示例

亚麻

光学

评论

图像网

optim.Momentum

optax.sgd

DynamicScale 可保持不变使用。

mnist

optim.Momentum

optax.sgd

nlp_seq

optim.Adam

optax.adamw

pixelcnn

optim.Adam

optax.adam

ppo

optim.Adam

optax.adam

seq2seq

optim.Adam

optax.adam

vae

optim.Adam

optax.adam

wmt

optim.Adam

optax.adamw

(Flax 的 Adam 实现有一个可选的权重衰减参数,但在 Optax 中,带和不带权重衰减的 Adam 是两个不同的别名。)

附录#

设置代码#

以下设置代码可用于运行此 FLIP 中的代码片段

import functools
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds


def pp(features):
  return {
      'image': tf.cast(features['image'], tf.float32) / 255 - 0.5,
      'label': features['label'],
  }


class Model(nn.Module):

  @nn.compact
  def __call__(self, inputs):
    x = inputs.reshape([inputs.shape[0], -1])
    x = nn.normalization.BatchNorm(True)(x)
    x = nn.Dense(10)(x)
    x = nn.log_softmax(x)
    return x


def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
  x = (labels[..., None] == jnp.arange(num_classes)[None])
  x = jax.lax.select(
      x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
  return x.astype(jnp.float32)


def xent_loss(logits, labels):
  return -jnp.sum(
      onehot(labels, num_classes=10) * logits) / labels.size


def get_learning_rate(step):
  return 0.1


model = Model()
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
jax.tree_util.tree_map(jnp.shape, variables)