Optimizer#
- class flax.nnx.optimizer.Optimizer(self, model, tx, *, wrt)#
适用于单个 Optax 优化器的常见情况的简单训练状态。
用法示例
>>> import jax, jax.numpy as jnp >>> from flax import nnx >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... return self.linear2(self.linear1(x)) ... >>> x = jax.random.normal(jax.random.key(0), (1, 2)) >>> y = jnp.ones((1, 4)) ... >>> model = Model(nnx.Rngs(0)) >>> tx = optax.adam(1e-3) >>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) ... >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() >>> loss_fn(model) Array(2.3359997, dtype=float32) >>> grads = nnx.grad(loss_fn)(model) >>> optimizer.update(model, grads) >>> loss_fn(model) Array(2.310461, dtype=float32)
- step#
一个用于跟踪步数计数的
OptState
Variable
。
- tx#
一个 Optax 梯度转换。
- opt_state#
Optax 优化器状态。
- __init__(model, tx, *, wrt)#
实例化该类并包装
Module
和 Optax 梯度转换。实例化优化器状态以跟踪wrt
中指定的Variable
类型。将步数计数设置为 0。- 参数
model – 一个 NNX 模块。
tx – 一个 Optax 梯度转换。
wrt – 可选参数,用于筛选在优化器状态中跟踪哪些
Variable
。这些应该是您计划更新的Variable
;也就是说,此参数值应与传递给nnx.grad
调用的wrt
参数匹配,该调用将生成梯度,这些梯度将被传递到update()
方法的grads
参数中。
- update(model, grads, /, **kwargs)#
根据给定的梯度更新优化器状态和模型参数。
示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> import optax ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear = nnx.Linear(2, 3, rngs=rngs) ... self.count = nnx.Variable(jnp.array(0)) ... ... def __call__(self, x): ... self.count[...] += 1 ... return self.linear(x) ... >>> model = Model(rngs=nnx.Rngs(0)) ... >>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean() >>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) >>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))( ... model, jnp.ones((1, 2)), jnp.ones((1, 3)) ... ) >>> optimizer.update(model, grads)
请注意,此函数内部会调用
.tx.update()
,然后调用optax.apply_updates()
来更新params
和opt_state
。- 参数
grads – 从
nnx.grad
派生的梯度。**kwargs – 传递给 tx.update 的额外关键字参数,以支持
GradientTransformationExtraArgs –
(例如)optax.scale_by_backtracking_linesearch。