Optimizer#

class flax.nnx.optimizer.Optimizer(self, model, tx, *, wrt)#

适用于单个 Optax 优化器的常见情况的简单训练状态。

用法示例

>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(2.3359997, dtype=float32)
>>> grads = nnx.grad(loss_fn)(model)
>>> optimizer.update(model, grads)
>>> loss_fn(model)
Array(2.310461, dtype=float32)
step#

一个用于跟踪步数计数的 OptState Variable

tx#

一个 Optax 梯度转换。

opt_state#

Optax 优化器状态。

__init__(model, tx, *, wrt)#

实例化该类并包装 Module 和 Optax 梯度转换。实例化优化器状态以跟踪 wrt 中指定的 Variable 类型。将步数计数设置为 0。

参数
  • model – 一个 NNX 模块。

  • tx – 一个 Optax 梯度转换。

  • wrt – 可选参数,用于筛选在优化器状态中跟踪哪些 Variable。这些应该是您计划更新的 Variable;也就是说,此参数值应与传递给 nnx.grad 调用的 wrt 参数匹配,该调用将生成梯度,这些梯度将被传递到 update() 方法的 grads 参数中。

update(model, grads, /, **kwargs)#

根据给定的梯度更新优化器状态和模型参数。

示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.count = nnx.Variable(jnp.array(0))
...
...   def __call__(self, x):
...     self.count[...] += 1
...     return self.linear(x)
...
>>> model = Model(rngs=nnx.Rngs(0))
...
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))(
...   model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
>>> optimizer.update(model, grads)

请注意,此函数内部会调用 .tx.update(),然后调用 optax.apply_updates() 来更新 paramsopt_state

参数
  • grads – 从 nnx.grad 派生的梯度。

  • **kwargs – 传递给 tx.update 的额外关键字参数,以支持

  • GradientTransformationExtraArgs

  • 例如optax.scale_by_backtracking_linesearch。