flax.training 软件包#

训练状态#

class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[源代码]#

用于具有单个 Optax 优化器的常见情况的简单训练状态。

用法示例

>>> import flax.linen as nn
>>> from flax.training.train_state import TrainState
>>> import jax, jax.numpy as jnp
>>> import optax

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 2))
>>> model = nn.Dense(2)
>>> variables = model.init(jax.random.key(0), x)
>>> tx = optax.adam(1e-3)

>>> state = TrainState.create(
...     apply_fn=model.apply,
...     params=variables['params'],
...     tx=tx)

>>> def loss_fn(params, x, y):
...   predictions = state.apply_fn({'params': params}, x)
...   loss = optax.l2_loss(predictions=predictions, targets=y).mean()
...   return loss
>>> loss_fn(state.params, x, y)
Array(1.8136346, dtype=float32)

>>> grads = jax.grad(loss_fn)(state.params, x, y)
>>> state = state.apply_gradients(grads=grads)
>>> loss_fn(state.params, x, y)
Array(1.8079796, dtype=float32)

请注意,您可以通过子类化这个数据类来轻松扩展它,以存储额外的数据(例如,额外的变量集合)。

对于更特殊的用例(例如,多个优化器),最好是 fork 该类并进行修改。

参数
  • step – 计数器从 0 开始,每次调用 .apply_gradients() 时递增。

  • apply_fn – 通常设置为 model.apply()。为方便起见,将其保留在此数据类中,以便在训练循环中的 train_step() 函数具有更短的参数列表。

  • params – 将由 tx 更新并由 apply_fn 使用的参数。

  • tx – 一个 Optax 梯度转换。

  • opt_statetx 的状态。

apply_gradients(*, grads, **kwargs)[源代码]#

在返回值中更新 stepparamsopt_state**kwargs

请注意,此函数内部会调用 .tx.update(),然后调用 optax.apply_updates() 来更新 paramsopt_state

参数
  • grads – 与 .params 具有相同 pytree 结构的梯度。

  • **kwargs – 应该被 .replace() 的其他数据类属性。

返回

self 的一个更新实例,其中 step 增加 1,paramsopt_state 通过应用 grads 进行更新,并且其他属性按照 kwargs 的指定进行替换。

classmethod create(*, apply_fn, params, tx, **kwargs)[源代码]#

创建一个新实例,其中 step=0 并且 opt_state 已初始化。