从 Flax Linen 迁移到 Flax NNX#
本指南通过并排展示示例代码来演示 Flax Linen 和 Flax NNX 模型之间的差异,以帮助您从 Flax Linen 迁移到 Flax NNX API。
本文档主要讲解如何将任意 Flax Linen 代码转换为 Flax NNX。如果您希望“安全”地迭代转换您的代码库,请查阅通过 nnx.bridge 结合使用 Flax NNX 和 Linen 指南。
为了充分利用本指南,强烈建议您先阅读Flax NNX 基础文档,其中涵盖了 nnx.Module
系统、Flax 变换以及带示例的函数式 API。
基础 Module
定义#
Flax Linen 和 Flax NNX 都使用 Module
类作为表达神经网络库层的默认单元。在下面的示例中,您首先创建一个 Block
(通过子类化 Module
),它由一个带 dropout 和 ReLU 激活函数的线性层组成;然后,在创建 Model
(也通过子类化 Module
)时,您将其用作子 Module
,该 Model
由 Block
和一个线性层构成。
Flax Linen 和 Flax NNX 的 Module
对象之间有两个根本区别
无状态 vs. 有状态:一个
flax.linen.Module
(nn.Module
) 实例是无状态的——变量由一个纯函数式的Module.init()
调用返回,并被分开管理。而一个flax.nnx.Module
则将其变量作为该 Python 对象的属性来拥有。惰性 vs. 即时:一个
flax.linen.Module
仅在实际看到其输入时才分配空间来创建变量(惰性)。而一个flax.nnx.Module
实例在实例化时就会创建变量,而无需看到样本输入(即时)。Flax Linen 可以使用
@nn.compact
装饰器在单个方法中定义模型,并利用输入样本进行形状推断。而一个 Flax NNXModule
通常需要额外的形状信息以在__init__
期间创建所有参数,并在__call__
方法中单独定义计算。
import flax.linen as nn
class Block(nn.Module):
features: int
@nn.compact
def __call__(self, x, training: bool):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5, deterministic=not training)(x)
x = jax.nn.relu(x)
return x
class Model(nn.Module):
dmid: int
dout: int
@nn.compact
def __call__(self, x, training: bool):
x = Block(self.dmid)(x, training)
x = nn.Dense(self.dout)(x)
return x
from flax import nnx
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x):
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x
class Model(nnx.Module):
def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
self.block = Block(din, dmid, rngs=rngs)
self.linear = nnx.Linear(dmid, dout, rngs=rngs)
def __call__(self, x):
x = self.block(x)
x = self.linear(x)
return x
变量创建#
接下来,让我们讨论实例化模型和初始化其参数
要为一个 Flax Linen 模型生成模型参数,您需要使用一个
jax.random.key
(文档)以及一些模型将接受的样本输入来调用flax.linen.Module.init
(nn.Module.init
) 方法。这将产生一个需要单独携带和维护的 JAX 数组(jax.Array
数据类型)的嵌套字典。在 Flax NNX 中,当您实例化模型时,模型参数会自动初始化,并且变量(
nnx.Variable
对象)作为属性存储在nnx.Module
(或其子Module
)内部。您仍然需要为其提供一个伪随机数生成器 (PRNG) 密钥,但该密钥将被包装在nnx.Rngs
类中并存储在内部,以便在需要时生成更多 PRNG 密钥。
如果您希望以无状态、类似字典的方式访问 Flax NNX 模型参数以进行检查点保存或模型修改,请查阅 Flax NNX 拆分/合并 API (nnx.split
/ nnx.merge
)。
model = Model(256, 10)
sample_x = jnp.ones((1, 784))
variables = model.init(jax.random.key(0), sample_x, training=False)
params = variables["params"]
assert params['Dense_0']['bias'].shape == (10,)
assert params['Block_0']['Dense_0']['kernel'].shape == (784, 256)
model = Model(784, 256, 10, rngs=nnx.Rngs(0))
# Parameters were already initialized during model instantiation.
assert model.linear.bias.value.shape == (10,)
assert model.block.linear.kernel.value.shape == (784, 256)
训练步骤和编译#
现在,让我们继续编写一个训练步骤,并使用 JAX 即时编译对其进行编译。下面是 Flax Linen 和 Flax NNX 方法之间的一些差异。
编译训练步骤
Flax Linen 使用
@jax.jit
——一个 JAX 变换——来编译训练步骤。Flax NNX 使用
@nnx.jit
——一个 Flax NNX 变换(是多个行为类似于 JAX 变换但能很好地与 Flax NNX 对象配合使用的变换 API 之一)。因此,虽然jax.jit
只接受纯无状态参数的函数,但nnx.jit
允许参数是有状态的 NNX Module。这大大减少了训练步骤所需的代码行数。
计算梯度
同样,Flax Linen 使用
jax.grad
(一个用于自动微分的 JAX 变换)来返回一个原始的梯度字典。Flax NNX 使用
nnx.grad
(一个 Flax NNX 变换)将 NNX Module 的梯度作为nnx.State
字典返回。如果您想在 Flax NNX 中使用常规的jax.grad
,则需要使用 Flax NNX 拆分/合并 API。
优化器
如果您已经在使用Optax优化器,例如
optax.adamw
(而不是此处显示的原始jax.tree.map
计算)与 Flax Linen 一起使用,请查阅 Flax NNX 基础指南中的nnx.Optimizer
示例,了解一种更简洁的训练和更新模型的方法。
每个训练步骤中的模型更新
Flax Linen 训练步骤需要返回一个参数的 pytree 作为下一步的输入。
Flax NNX 训练步骤不需要返回任何东西,因为
model
已经在nnx.jit
内部就地更新了。此外,
nnx.Module
对象是有状态的,并且Module
会自动跟踪其内部的几件事物,例如 PRNG 密钥和BatchNorm
统计信息。这就是为什么您不需要在每一步都显式传递 PRNG 密钥。另请注意,您可以使用nnx.reseed
来重置其底层的 PRNG 状态。
Dropout 行为
在 Flax Linen 中,您需要显式定义并传入
training
参数来控制flax.linen.Dropout
(nn.Dropout
) 的行为,即其deterministic
标志,这意味着只有当training=True
时才会发生随机 dropout。在 Flax NNX 中,您可以调用
model.train()
(flax.nnx.Module.train()
) 来自动将nnx.Dropout
切换到训练模式。反之,您可以调用model.eval()
(flax.nnx.Module.eval()
) 来关闭训练模式。您可以在其API 参考中了解更多关于nnx.Module.train
的功能。
...
@jax.jit
def train_step(key, params, inputs, labels):
def loss_fn(params):
logits = model.apply(
{'params': params},
inputs, training=True, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(params)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
return params
model.train() # Sets ``deterministic=False` under the hood for nnx.Dropout
@nnx.jit
def train_step(model, inputs, labels):
def loss_fn(model):
logits = model(inputs)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.merge_state(params, rest))
集合和变量类型#
Flax Linen 和 NNX API 之间的一个关键区别在于它们如何将变量分组。Flax Linen 使用不同的集合,而 Flax NNX 中,由于所有变量都应为顶层 Python 属性,您可以使用不同的变量类型。
在 Flax NNX 中,您可以自由地创建自己的变量类型,作为 nnx.Variable
的子类。
对于所有内置的 Flax Linen 层和集合,Flax NNX 已经创建了相应的层和变量类型。例如:
flax.linen.Dense
(nn.Dense
) 创建params
->nnx.Linear
创建nnx.Param
。flax.linen.BatchNorm
(nn.BatchNorm
) 创建batch_stats
->nnx.BatchNorm
创建nnx.BatchStats
。flax.linen.Module.sow()
创建intermediates
->nnx.Module.sow()
创建nnx.Intermediates
。在 Flax NNX 中,您也可以简单地通过将其分配给一个
nnx.Module
属性来获取中间变量——例如,self.sowed = nnx.Intermediates(x)
。这类似于 Flax Linen 的self.variable('intermediates', 'sowed', lambda: x)
。
class Block(nn.Module):
features: int
def setup(self):
self.dense = nn.Dense(self.features)
self.batchnorm = nn.BatchNorm(momentum=0.99)
self.count = self.variable('counter', 'count',
lambda: jnp.zeros((), jnp.int32))
@nn.compact
def __call__(self, x, training: bool):
x = self.dense(x)
x = self.batchnorm(x, use_running_average=not training)
self.count.value += 1
x = jax.nn.relu(x)
return x
x = jax.random.normal(jax.random.key(0), (2, 4))
model = Block(4)
variables = model.init(jax.random.key(0), x, training=True)
variables['params']['dense']['kernel'].shape # (4, 4)
variables['batch_stats']['batchnorm']['mean'].shape # (4, )
variables['counter']['count'] # 1
class Counter(nnx.Variable): pass
class Block(nnx.Module):
def __init__(self, in_features: int , out_features: int, rngs: nnx.Rngs):
self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
self.batchnorm = nnx.BatchNorm(
num_features=out_features, momentum=0.99, rngs=rngs
)
self.count = Counter(jnp.array(0))
def __call__(self, x):
x = self.linear(x)
x = self.batchnorm(x)
self.count += 1
x = jax.nn.relu(x)
return x
model = Block(4, 4, rngs=nnx.Rngs(0))
model.linear.kernel # Param(value=...)
model.batchnorm.mean # BatchStat(value=...)
model.count # Counter(value=...)
如果您想从变量的 pytree 中提取某些数组:
在 Flax Linen 中,您可以访问特定的字典路径。
在 Flax NNX 中,您可以使用
nnx.split
来区分 Flax NNX 中的类型。以下代码是一个简单的示例,它按变量类型拆分变量——请查阅 Flax NNX 过滤器指南以了解更复杂的过滤表达式。
params, batch_stats, counter = (
variables['params'], variables['batch_stats'], variables['counter'])
params.keys() # ['dense', 'batchnorm']
batch_stats.keys() # ['batchnorm']
counter.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with raw dict to carry on:
variables = {'params': params, 'batch_stats': batch_stats, 'counter': counter}
graphdef, params, batch_stats, count = nnx.split(
model, nnx.Param, nnx.BatchStat, Counter)
params.keys() # ['batchnorm', 'linear']
batch_stats.keys() # ['batchnorm']
count.keys() # ['count']
# ... make arbitrary modifications ...
# Merge back with ``nnx.merge`` to carry on:
model = nnx.merge(graphdef, params, batch_stats, count)
使用多个方法#
在本节中,您将学习如何在 Flax Linen 和 Flax NNX 中使用多个方法。作为示例,您将实现一个具有三个方法的自编码器模型:encode
、decode
和 __call__
。
定义编码器和解码器层
在 Flax Linen 中,和之前一样,定义层时无需传入输入形状,因为
flax.linen.Module
的参数将使用形状推断进行惰性初始化。在 Flax NNX 中,您必须传入输入形状,因为
nnx.Module
的参数将在没有形状推断的情况下即时初始化。
class AutoEncoder(nn.Module):
embed_dim: int
output_dim: int
def setup(self):
self.encoder = nn.Dense(self.embed_dim)
self.decoder = nn.Dense(self.output_dim)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(256, 784)
variables = model.init(jax.random.key(0), x=jnp.ones((1, 784)))
class AutoEncoder(nnx.Module):
def __init__(self, in_dim: int, embed_dim: int, output_dim: int, rngs):
self.encoder = nnx.Linear(in_dim, embed_dim, rngs=rngs)
self.decoder = nnx.Linear(embed_dim, output_dim, rngs=rngs)
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
def __call__(self, x):
x = self.encode(x)
x = self.decode(x)
return x
model = AutoEncoder(784, 256, 784, rngs=nnx.Rngs(0))
变量结构如下
# variables['params']
{
decoder: {
bias: (784,),
kernel: (256, 784),
},
encoder: {
bias: (256,),
kernel: (784, 256),
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
{
'decoder': {
'bias': Param(value=(784,)),
'kernel': Param(value=(256, 784))
},
'encoder': {
'bias': Param(value=(256,)),
'kernel': Param(value=(784, 256))
}
}
调用除 __call__
之外的方法
在 Flax Linen 中,您仍然需要使用
apply
API。在 Flax NNX 中,您可以直接调用该方法。
z = model.apply(variables, x=jnp.ones((1, 784)), method="encode")
z = model.encode(jnp.ones((1, 784)))
转换#
Flax Linen 和Flax NNX 变换都提供了各自的一套变换,它们包装了JAX 变换,使其可以与 Module
对象一起使用。
Flax Linen 中的大多数变换,例如 grad
或 jit
,在 Flax NNX 中变化不大。但是,例如,如果您尝试对层进行 scan
,如下一节所述,代码会有很大不同。
让我们从一个例子开始
首先,定义一个
RNNCell
Module
,它将包含 RNN 单个步骤的逻辑。定义一个
initial_state
方法,它将用于初始化 RNN 的状态(即carry
)。与jax.lax.scan
(API 文档)类似,RNNCell.__call__
方法将是一个函数,它接受 carry 和输入,并返回新的 carry 和输出。在这种情况下,carry 和输出是相同的。
class RNNCell(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
class RNNCell(nnx.Module):
def __init__(self, input_size, hidden_size, rngs):
self.linear = nnx.Linear(hidden_size + input_size, hidden_size, rngs=rngs)
self.hidden_size = hidden_size
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = self.linear(x)
x = jax.nn.relu(x)
return x, x
def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
接下来,定义一个 RNN
Module
,它将包含整个 RNN 的逻辑。
在 Flax Linen 中
您将使用
flax.linen.scan
(nn.scan
) 定义一个包装RNNCell
的新临时类型。在此过程中,您还将:1) 指示nn.scan
广播params
集合(所有步骤共享相同的参数)并且不分割params
PRNG 流(以便所有步骤都使用相同的参数进行初始化);最后,2) 指定您希望 scan 在输入的第二个轴上运行,并沿第二个轴堆叠输出。然后,您将立即使用此临时类型创建一个“提升的”
RNNCell
实例,并用它来创建carry
,然后运行__call__
方法,该方法将对序列进行scan
。
在 Flax NNX 中
您将创建一个
scan
函数 (scan_fn
),它将使用在__init__
中定义的RNNCell
来扫描序列,并显式设置in_axes=(nnx.Carry, None, 1)
。nnx.Carry
意味着carry
参数将是 carry,None
意味着cell
将被广播到所有步骤,而1
意味着x
将在轴 1 上进行扫描。
class RNN(nn.Module):
hidden_size: int
@nn.compact
def __call__(self, x):
rnn = nn.scan(
RNNCell, variable_broadcast='params',
split_rngs={'params': False}, in_axes=1, out_axes=1
)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(64)
variables = model.init(jax.random.key(0), x=jnp.ones((3, 12, 32)))
y = model.apply(variables, x=jnp.ones((3, 12, 32)))
class RNN(nnx.Module):
def __init__(self, input_size: int, hidden_size: int, rngs: nnx.Rngs):
self.hidden_size = hidden_size
self.cell = RNNCell(input_size, self.hidden_size, rngs=rngs)
def __call__(self, x):
scan_fn = lambda carry, cell, x: cell(carry, x)
carry = self.cell.initial_state(x.shape[0])
carry, y = nnx.scan(
scan_fn, in_axes=(nnx.Carry, None, 1), out_axes=(nnx.Carry, 1)
)(carry, self.cell, x)
return y
x = jnp.ones((3, 12, 32))
model = RNN(x.shape[2], 64, rngs=nnx.Rngs(0))
y = model(x)
扫描层#
总的来说,Flax Linen 和 Flax NNX 的变换应该看起来是相同的。然而,Flax NNX 变换被设计得更接近其底层的JAX 对应物,因此我们在某些 Linen 提升变换中摒弃了一些假设。这个跨层扫描的用例将是一个很好的例子来展示这一点。
跨层扫描是一种技术,您将一个输入通过一个包含 N 个重复层的序列,将每层的输出作为下一层的输入。这种模式可以显著减少大型模型的编译时间。在下面的示例中,您将在顶层 MLP
Module
中重复 Block
Module
5 次。
在 Flax Linen 中,您将
flax.linen.scan
(nn.scan
) 变换应用于Block
nn.Module
,以创建一个更大的ScanBlock
nn.Module
,其中包含 5 个Block
nn.Module
对象。它将在初始化时自动创建一个形状为(5, 64, 64)
的大参数,并在调用时对每个(64, 64)
切片进行迭代,总共 5 次,就像jax.lax.scan
(API 文档) 那样。仔细看,在这个模型的逻辑中,初始化时实际上并不需要
jax.lax.scan
操作。那里发生的事情更像是一个jax.vmap
操作——您有一个接受(in_dim, out_dim)
的Block
子Module
,然后您对其“vmap”num_layers
次以创建一个更大的数组。在 Flax NNX 中,您利用了模型初始化和运行代码完全解耦的优势,转而使用
nnx.vmap
变换来初始化底层的Block
参数,并使用nnx.scan
变换来将模型输入通过它们。
有关 Flax NNX 变换的更多信息,请查阅变换指南。
class Block(nn.Module):
features: int
training: bool
@nn.compact
def __call__(self, x, _):
x = nn.Dense(self.features)(x)
x = nn.Dropout(0.5)(x, deterministic=not self.training)
x = jax.nn.relu(x)
return x, None
class MLP(nn.Module):
features: int
num_layers: int
@nn.compact
def __call__(self, x, training: bool):
ScanBlock = nn.scan(
Block, variable_axes={'params': 0}, split_rngs={'params': True},
length=self.num_layers)
y, _ = ScanBlock(self.features, training)(x, None)
return y
model = MLP(64, num_layers=5)
class Block(nnx.Module):
def __init__(self, input_dim, features, rngs):
self.linear = nnx.Linear(input_dim, features, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)
def __call__(self, x: jax.Array): # No need to require a second input!
x = self.linear(x)
x = self.dropout(x)
x = jax.nn.relu(x)
return x # No need to return a second output!
class MLP(nnx.Module):
def __init__(self, features, num_layers, rngs):
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=(0,), out_axes=0)
def create_block(rngs: nnx.Rngs):
return Block(features, features, rngs=rngs)
self.blocks = create_block(rngs)
self.num_layers = num_layers
def __call__(self, x):
@nnx.split_rngs(splits=self.num_layers)
@nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
def forward(x, model):
x = model(x)
return x
return forward(x, self.blocks)
model = MLP(64, num_layers=5, rngs=nnx.Rngs(0))
上述 Flax NNX 示例中还有一些其他细节需要解释
`@nnx.split_rngs` 装饰器:Flax NNX 变换完全不感知 PRNG 状态,这使得它们的行为更像 JAX 变换,但与处理 PRNG 状态的 Flax Linen 变换有所不同。为了恢复此功能,
nnx.split_rngs
装饰器允许您在将nnx.Rngs
传递给被装饰的函数之前对其进行分割,并在之后将其“降级”,以便在外部使用。在这里,您分割了 PRNG 密钥,因为
jax.vmap
和jax.lax.scan
要求,如果其每个内部操作都需要自己的密钥,则需要一个 PRNG 密钥列表。因此,对于MLP
内部的 5 个层,您在进入 JAX 变换之前,从其参数中分割并提供了 5 个不同的 PRNG 密钥。请注意,实际上
create_block()
知道它需要创建 5 个层,*正是因为*它看到了 5 个 PRNG 密钥,因为in_axes=(0,)
表明vmap
将查看第一个参数的第一个维度来知道它将映射的大小。forward()
也是如此,它会查看第一个参数(即model
)内部的变量来确定需要扫描多少次。nnx.split_rngs
在这里实际上是分割了model
内部的 PRNG 状态。(如果Block
Module
没有 dropout,您就不需要nnx.split_rngs
这一行,因为它无论如何都不会消耗任何 PRNG 密钥。)
为什么 Flax NNX 中的 Block Module 不需要接受和返回那个额外的虚拟值:这是
jax.lax.scan
的一个要求(API 文档)。Flax NNX 简化了这一点,因此如果您将out_axes
设置为nnx.Carry
而不是默认的(nnx.Carry, 0)
,现在可以选择忽略第二个输出。这是 Flax NNX 变换与JAX 变换 API 偏离的罕见情况之一。
上面的 Flax NNX 示例中有更多的代码行,但它们更精确地表达了每次发生的情况。由于 Flax NNX 变换比 JAX 变换 API 更接近,建议在使用其Flax NNX 等价物之前,对底层的JAX 变换有很好的理解。
现在检查双方的变量 pytree
# variables = model.init(key, x=jnp.ones((1, 64)), training=True)
# variables['params']
{
ScanBlock_0: {
Dense_0: {
bias: (5, 64),
kernel: (5, 64, 64),
},
},
}
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
{
'blocks': {
'linear': {
'bias': Param(value=(5, 64)),
'kernel': Param(value=(5, 64, 64))
}
}
}
在 Flax NNX 中使用 TrainState
#
Flax Linen 有一个方便的 TrainState
数据类,用于捆绑模型、参数和优化器。在 Flax NNX 中,这并非真正必要。在本节中,您将学习如何围绕 TrainState
构建您的 Flax NNX 代码,以满足任何向后兼容性的需求。
在 Flax NNX 中
您必须首先在模型上调用
nnx.split
以获取独立的nnx.GraphDef
和nnx.State
对象。您还需要子类化
TrainState
以添加一个用于其他变量的字段。然后,您可以将
nnx.GraphDef.apply
作为apply
函数,将nnx.State
作为参数和其他变量,以及一个优化器作为参数传递给TrainState
构造函数。
请注意,nnx.GraphDef.apply
将接受 nnx.State
对象作为参数并返回一个可调用函数。该函数可以在输入上调用,以输出模型的 logits,以及更新后的 nnx.GraphDef
和 nnx.State
对象。请注意下面使用 @jax.jit
,因为您没有将 Flax NNX Module 传递到 train_step
中。
from flax.training import train_state
sample_x = jnp.ones((1, 784))
model = nn.Dense(features=10)
params = model.init(jax.random.key(0), sample_x)['params']
state = train_state.TrainState.create(
apply_fn=model.apply,
params=params,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(key, state, inputs, labels):
def loss_fn(params):
logits = state.apply_fn(
{'params': params},
inputs, # <== inputs
rngs={'dropout': key}
)
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params)
state = state.apply_gradients(grads=grads)
return state
from flax.training import train_state
model = nnx.Linear(784, 10, rngs=nnx.Rngs(0))
model.train() # set deterministic=False
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
class TrainState(train_state.TrainState):
other_variables: nnx.State
state = TrainState.create(
apply_fn=graphdef.apply,
params=params,
other_variables=other_variables,
tx=optax.adam(1e-3)
)
@jax.jit
def train_step(state, inputs, labels):
def loss_fn(params, other_variables):
logits, (graphdef, new_state) = state.apply_fn(
params,
other_variables
)(inputs) # <== inputs
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
grads = jax.grad(loss_fn)(state.params, state.other_variables)
state = state.apply_gradients(grads=grads)
return state