flax.training 软件包#
训练状态#
- class flax.training.train_state.TrainState(step, apply_fn, params, tx, opt_state)[源代码]#
用于具有单个 Optax 优化器的常见情况的简单训练状态。
用法示例
>>> import flax.linen as nn >>> from flax.training.train_state import TrainState >>> import jax, jax.numpy as jnp >>> import optax >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 2)) >>> model = nn.Dense(2) >>> variables = model.init(jax.random.key(0), x) >>> tx = optax.adam(1e-3) >>> state = TrainState.create( ... apply_fn=model.apply, ... params=variables['params'], ... tx=tx) >>> def loss_fn(params, x, y): ... predictions = state.apply_fn({'params': params}, x) ... loss = optax.l2_loss(predictions=predictions, targets=y).mean() ... return loss >>> loss_fn(state.params, x, y) Array(1.8136346, dtype=float32) >>> grads = jax.grad(loss_fn)(state.params, x, y) >>> state = state.apply_gradients(grads=grads) >>> loss_fn(state.params, x, y) Array(1.8079796, dtype=float32)
请注意,您可以通过子类化这个数据类来轻松扩展它,以存储额外的数据(例如,额外的变量集合)。
对于更特殊的用例(例如,多个优化器),最好是 fork 该类并进行修改。
- 参数
step – 计数器从 0 开始,每次调用
.apply_gradients()
时递增。apply_fn – 通常设置为
model.apply()
。为方便起见,将其保留在此数据类中,以便在训练循环中的train_step()
函数具有更短的参数列表。params – 将由
tx
更新并由apply_fn
使用的参数。tx – 一个 Optax 梯度转换。
opt_state –
tx
的状态。
- apply_gradients(*, grads, **kwargs)[源代码]#
在返回值中更新
step
、params
、opt_state
和**kwargs
。请注意,此函数内部会调用
.tx.update()
,然后调用optax.apply_updates()
来更新params
和opt_state
。- 参数
grads – 与
.params
具有相同 pytree 结构的梯度。**kwargs – 应该被
.replace()
的其他数据类属性。
- 返回
self
的一个更新实例,其中step
增加 1,params
和opt_state
通过应用grads
进行更新,并且其他属性按照kwargs
的指定进行替换。