Module#

class flax.nnx.Module(self, /, *args, **kwargs)[source]#

Base class for all neural network Modules.

Layers and models should subclass this class.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.

You can define arbitrary "forward pass" methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
eval(**attributes)[source]#

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set deterministic=True and use_running_average=True on all nested Modules that have these attributes. Its primary use is to control the runtime behavior of the Dropout and BatchNorm Modules.

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters

**attributes – additional attributes passed to set_attributes.
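
For example, extra keyword attributes are forwarded to set_attributes alongside the eval defaults. A minimal sketch: the use_cache flag below is hypothetical and only illustrates the forwarding; this also assumes eval ignores attributes that are missing on a Module, as described above.

>>> class Cached(nnx.Module):
...   def __init__(self):
...     self.use_cache = False
...
>>> m = Cached()
>>> m.eval(use_cache=True)  # use_cache=True is forwarded to set_attributes
>>> m.use_cache
True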

iter_children()[source]#

Iterates over all children Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse deeper.

iter_children creates a generator that yields the key and the Module instance, where the key is a string representing the attribute name of the Module used to access the corresponding child Module.

Example

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_children():
...  print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
iter_modules()[source]#

Recursively iterates over all nested Modules of the current Module, including the current Module itself.

iter_modules creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path from the root Module to the Module.

Example

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in model.iter_modules():
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)[source]#

Adds a zero-valued variable ("perturbation") to an intermediate value.

The gradient of value will be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both the params and the perturbations as standalone arguments, you can get the intermediate gradients of value by running jax.grad on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, the perturbation variable is only created after you run a sample input through the model once.

Note

This creates extra dummy variables of the same size as value, so it uses more memory. Use it only for debugging gradients in training.

Example usage

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.value.shape == (1, 3)   # same as the intermediate value

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=nnx.DiffState(argnum=0, filter=nnx.Any(nnx.Param, nnx.Perturbation)))
... def grad_loss(model, inputs, targets):
...   preds = model(inputs)
...   return jnp.square(preds - targets).mean()

>>> intm_grads = grad_loss(model, x, y)
>>> # `intm_grads.xgrad.value` is the intermediate gradient
>>> assert not jnp.array_equal(intm_grads.xgrad.value, jnp.zeros((1, 3)))
Parameters
  • name – a string denoting the Module attribute name for the perturbation value.

  • value – the value to take the intermediate gradient of.

  • variable_type – the Variable type used to store the perturbation. Defaults to nnx.Perturbation.
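
A custom Variable subclass can also be supplied via variable_type, so the perturbation can be filtered separately from nnx.Perturbation. A minimal sketch, where MyPerturbation is a hypothetical subclass defined only for illustration:

>>> class MyPerturbation(nnx.Variable):
...   pass
...
>>> class Model2(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     # store the perturbation under the custom Variable type
...     return self.perturb('xgrad', self.linear(x), variable_type=MyPerturbation)
...
>>> model2 = Model2(rngs=nnx.Rngs(0))
>>> _ = model2(jnp.ones((1, 2)))
>>> assert isinstance(model2.xgrad, MyPerturbation)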

set_attributes(*filters, raise_if_not_found=True, **attributes)[source]#

Sets the attributes of nested Modules, including the current Module. If an attribute is not found in a Module, it is ignored.

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters
  • *filters – filters used to select the Modules whose attributes will be set.

  • raise_if_not_found – if True (default), raises a ValueError if at least one instance of an attribute is not found in one of the selected Modules.

  • **attributes – the attributes to set.
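
For example, with raise_if_not_found=False an attribute that exists nowhere in the model is silently ignored. The use_cache attribute below is hypothetical and only illustrates this behavior:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> # no submodule has a `use_cache` attribute; with the default
>>> # raise_if_not_found=True this call would raise a ValueError
>>> block.set_attributes(use_cache=True, raise_if_not_found=False)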

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#

sow() can be used to collect intermediate values without the overhead of explicitly passing a container of intermediates to every Module call. sow() stores the value in a new Module attribute denoted by name. The value will be wrapped in a Variable of type variable_type, which is useful to filter for in split(), state(), and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way, all intermediates can be tracked when the same Module is called multiple times.

Example usage

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i.value) == 1 # tuple of length 1
>>> assert model.i.value[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i.value) == 2 # tuple of length 2
>>> assert (model.i.value[0] + 1 == model.i.value[1]).all()
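
The sowed values can then be extracted by filtering on their Variable type. A minimal sketch using the model from above:

>>> # collect all Intermediate variables (such as `i`) from the model
>>> intermediates = nnx.state(model, nnx.Intermediate)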

Alternatively, custom init/reduce functions can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum.value == model.product.value).all()
>>> intermediate = model.sum.value

>>> y = model(x)
>>> assert (model.sum.value == intermediate*2).all()
>>> assert (model.product.value == intermediate**2).all()
Parameters
  • variable_type – the Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name – a string denoting the Module attribute name where the sowed value is stored.

  • value – the value to store.

  • reduce_fn – the function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn – for the first value stored, reduce_fn will be passed the result of init_fn together with the value to store. The default is an empty tuple.

train(**attributes)[source]#

Sets the Module to training mode.

train uses set_attributes to recursively set deterministic=False and use_running_average=False on all nested Modules that have these attributes. Its primary use is to control the runtime behavior of the Dropout and BatchNorm Modules.

Example

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters

**attributes – additional attributes passed to set_attributes.
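
Assuming train simply forwards to set_attributes as described above, the call is roughly equivalent to the following sketch (not the exact implementation):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> # roughly what block.train() does; attributes missing on a
>>> # Module are ignored because raise_if_not_found=False
>>> block.set_attributes(deterministic=False, use_running_average=False,
...                      raise_if_not_found=False)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)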