变换#

flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#

能够处理模块/图节点作为参数的 jax.grad 的对象感知版本。

示例

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'bias': Param(
    value=(3,)
  ),
  'kernel': Param(
    value=(2, 3)
  )
})

默认情况下,NNX 对象相对于其所有的 nnx.Param 变量进行微分。您可以通过将 DiffState 对象传递给 argnums 参数来指定哪些子状态是可微分的。例如,如果您只想对 Linear 类的 kernel 属性进行微分,您可以使用 PathContains 过滤器。

>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
...
>>> kernel_attribute = nnx.PathContains('kernel')
>>> diff_state = nnx.DiffState(0, kernel_attribute)
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn, argnums=diff_state)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'kernel': Param(
    value=(2, 3)
  )
})

有关如何创建自定义过滤器的更多信息,请参阅使用过滤器指南。

参数
  • fun – 要被微分的函数。其在 argnums 指定位置上的参数应该是数组、标量、图节点或标准的 Python 容器。argnums 指定位置上的参数数组必须是非精确(即浮点或复数)类型。它应该返回一个标量(包括形状为 () 的数组,但不包括形状为 (1,) 的数组等)。

  • argnums – 可选,整数或整数序列。指定对哪个位置参数进行微分(默认为 0)。

  • has_aux – 可选,布尔值。指示 fun 是否返回一个二元组,其中第一个元素被认为是待微分数学函数的输出,第二个元素是辅助数据。默认为 False。

  • holomorphic – 可选,布尔值。指示 fun 是否保证是全纯的。如果为 True,输入和输出必须是复数。默认为 False。

  • allow_int – 可选,布尔值。是否允许对整数值输入进行微分。整数输入的梯度将具有一个平凡的向量空间 dtype (float0)。默认为 False。

flax.nnx.jit(fun=<class 'flax.typing.Missing'>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, abstracted_axes=None)[源代码]#

jax.jit 的提升版本,可以处理模块/图节点作为参数。

参数
  • fun

    要进行 JIT 编译的函数。fun 应该是一个纯函数,因为副作用可能只执行一次。

    fun 的参数和返回值应该是数组、标量或它们的(嵌套)标准 Python 容器(元组/列表/字典)。由 static_argnums 指定的位置参数可以是任何东西,只要它们是可哈希的并且定义了相等操作。静态参数作为编译缓存键的一部分,这就是为什么必须定义哈希和相等操作符的原因。

    JAX 对 fun 保持弱引用,用作编译缓存键,因此对象 fun 必须是可弱引用的。大多数 Callable 对象已经满足此要求。

  • in_shardings

    一个 Pytree,其结构与 fun 的参数结构相匹配,所有实际参数都替换为资源分配规范。也可以指定一个 pytree 前缀(例如,用一个值代替整个子树),在这种情况下,叶子节点会广播到该子树中的所有值。

    in_shardings 参数是可选的。JAX 将从输入的 jax.Array 推断分片方式,如果无法推断分片,则默认为复制输入。

    有效的资源分配规范是:
    • Sharding,它将决定值如何被分区。

      使用它,网格上下文管理器不是必需的。

    • None,将给 JAX 自由选择它想要的任何分片方式。对于 in_shardings,JAX 会将其标记为已复制,但此行为将来可能会改变。对于 out_shardings,我们将依赖 XLA GSPMD 分区器来确定输出分片。

    每个维度的大小必须是分配给它的总资源数量的倍数。这类似于 pjit 的 in_shardings。

  • out_shardings

    in_shardings 类似,但指定函数输出的资源分配。这类似于 pjit 的 out_shardings。

    out_shardings 参数是可选的。如果未指定,jax.jit() 将使用 GSPMD 的分片传播来确定输出的分片应该是什么。

  • static_argnums

    一个可选的整数或整数集合,用于指定哪些位置参数被视为静态(编译时常量)。仅依赖于静态参数的操作将在 Python 中(在跟踪期间)进行常量折叠,因此相应的参数值可以是任何 Python 对象。

    静态参数应该是可哈希的,意味着 __hash____eq__ 都已实现,并且是不可变的。使用不同的常量值调用 JIT 编译的函数将触发重新编译。非数组或其容器的参数必须标记为静态。

    如果既未提供 static_argnums 也未提供 static_argnames,则没有参数被视为静态。如果未提供 static_argnums 但提供了 static_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找与 static_argnames 对应的任何位置参数(反之亦然)。如果同时提供了 static_argnumsstatic_argnames,则不使用 inspect.signature,只有在 static_argnumsstatic_argnames 中列出的实际参数才会被视为静态。

  • static_argnames – 一个可选的字符串或字符串集合,用于指定哪些命名参数被视为静态(编译时常量)。有关详细信息,请参见关于 static_argnums 的注释。如果未提供但设置了 static_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • donate_argnums

    指定哪些位置参数缓冲区被“捐赠”给计算。如果您在计算完成后不再需要参数缓冲区,那么捐赠它们是安全的。在某些情况下,XLA 可以利用捐赠的缓冲区来减少执行计算所需的内存量,例如回收您的一个输入缓冲区来存储结果。您不应重用您捐赠给计算的缓冲区,如果您尝试这样做,JAX 将会引发错误。默认情况下,不捐赠任何参数缓冲区。

    如果既未提供 donate_argnums 也未提供 donate_argnames,则不捐赠任何参数。如果未提供 donate_argnums 但提供了 donate_argnames,反之亦然,JAX 会使用 inspect.signature(fun) 来查找与 donate_argnames 对应的任何位置参数(反之亦然)。如果同时提供了 donate_argnumsdonate_argnames,则不使用 inspect.signature,只有在 donate_argnumsdonate_argnames 中列出的实际参数才会被捐赠。

    有关缓冲区捐赠的更多详细信息,请参阅 常见问题解答

  • donate_argnames – 一个可选的字符串或字符串集合,用于指定哪些命名参数被捐赠给计算。有关详细信息,请参见关于 donate_argnums 的注释。如果未提供但设置了 donate_argnums,则默认值基于调用 inspect.signature(fun) 来查找相应的命名参数。

  • keep_unused – 如果为 False(默认值),JAX 确定未被 fun 使用的参数可能会从生成的已编译 XLA 可执行文件中删除。此类参数将不会传输到设备,也不会提供给底层可执行文件。如果为 True,则不会修剪未使用的参数。

  • device – 这是一个实验性功能,API 可能会发生变化。可选,JIT 编译的函数将在其上运行的设备。(可用设备可通过 jax.devices() 检索。)默认值继承自 XLA 的 DeviceAssignment 逻辑,通常是使用 jax.devices()[0]

  • backend – 这是一个实验性功能,API 可能会发生变化。可选,一个表示 XLA 后端的字符串:'cpu''gpu''tpu'

  • inline – 指定此函数是否应内联到封闭的 jaxprs 中(而不是表示为带有其自己的 subjaxpr 的 xla_call 原语的应用)。默认为 False。

返回

一个 fun 的包装版本,已设置为即时编译。

flax.nnx.shard_map(f=<class 'flax.typing.Missing'>, *, mesh, in_specs, out_specs, check_rep=True, auto=frozenset({}))[源代码]#

jax.experimental.shard_map.shard_map 的提升版本,可以处理模块/图节点作为参数。

简单数据并行示例

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))

m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

@nnx.shard_map(
  mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
)
def f(m, x):
  return m(x)

y = f(m, x)

jax.debug.visualize_array_sharding(y)

请注意,这里我们只是用一些 PartitionSpec 来定义整个模型和数据的规范。这适用于简单情况,但如果我们需要为模型的不同部分分配不同的 PartitionSpec,我们需要使用 StateSharding 并创建一些过滤器,使我们能够针对模型的特定部分。以下是如何使用 StateSharding 和过滤器为简单的 MLP 块实现张量并行的示例。

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

def path_ends_with(*path_suffix): # custom filter
  return lambda path, value: path[-len(path_suffix):] == path_suffix

model_spec = nnx.StateSharding({
  path_ends_with('linear1', 'kernel'): P(None, 'model'),
  path_ends_with('linear2', 'kernel'): P('model', None),
})

@nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
def f(m, x):
  y = m(x)
  return jax.lax.psum(y, 'model')

y = f(m, x)

jax.debug.visualize_array_sharding(m.linear1.kernel.value)
jax.debug.visualize_array_sharding(m.linear2.kernel.value)

或者,可以将一个为每个状态具有精确 PartitionSpec 的 State 对象传递给 StateSharding

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

model_spec = nnx.State(
  {
    'linear1': {'kernel': P(None, 'model')},
    'linear2': {'kernel': P('model', None)},
  }
)

@nnx.shard_map(
  mesh=mesh,
  in_specs=(nnx.StateSharding(model_spec), P(None)),
  out_specs=P(None),
)
def f(m, x):
  y = m(x)
  return jax.lax.psum(y, 'model')

y = f(m, x)

jax.debug.visualize_array_sharding(m.linear1.kernel.value)
jax.debug.visualize_array_sharding(m.linear2.kernel.value)

这里的 model_spec 是手动创建的,但您也可以通过使用 nnx.get_partition_spec 来自动创建它,从而自动化此过程(请参阅在多个设备上扩展)。

参数
  • f – 要映射的可调用对象。 f 的每次应用,或 f 的“实例”,都将映射参数的一个分片作为输入,并产生输出的一个分片。

  • mesh – 一个 jax.sharding.Mesh,表示要在其上分片数据和执行 f 实例的设备阵列。 Mesh 的名称可用于 f 中的集合通信操作。这通常由 jax.experimental.mesh_utils.create_device_mesh() 等实用函数创建。

  • in_specs – 一个 pytree,叶子节点为 jax.sharding.PartitionSpecnnx.StateSharding(将子状态映射到 PartitionSpec)实例,其树结构是要映射的 args 元组的树前缀。类似于 jax.sharding.NamedSharding,每个 PartitionSpec 表示相应的参数(或参数子树)应如何沿 mesh 的命名轴进行分片。在每个 PartitionSpec 中,在某个位置提及 mesh 轴名称表示沿该位置轴对相应的参数数组轴进行分片;不提及轴名称表示复制。如果一个参数或参数子树具有相应的 None 规范,则该参数不被分片。

  • out_specs – 一个 pytree,叶子节点为 jax.sharding.PartitionSpecnnx.StateSharding(将子状态映射到 PartitionSpec)实例,其树结构是 f 输出的树前缀。每个 PartitionSpec 表示相应的输出分片应如何连接。在每个 PartitionSpec 中,在某个位置提及 mesh 轴名称表示沿相应的位置轴连接该网格轴的分片。不提及 mesh 轴名称表示承诺输出值在该网格轴上是相等的,并且不应连接,只应产生单个值。

  • check_rep – 如果为 True(默认),则启用额外的有效性检查和自动微分优化。有效性检查涉及 out_specs 中未提及的任何网格轴名称是否与 f 的输出复制方式一致。如果在 f 中使用 Pallas 内核,则必须设置为 False。

  • auto – (实验性)一个可选的来自 mesh 的轴名称集合,我们不在此轴上分片数据或映射函数,而是允许编译器控制分片。这些名称不能在 in_specsout_specsf 中的通信集合中使用。

返回

一个可调用对象,它根据 meshin_specs 在分片的数据上应用输入函数 f

flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None)[源代码]#

jax.checkpoint(也称为 jax.remat)的“提升”版本。

flax.nnx.rematjax.checkpoint 类似,可以提供控制,例如,

控制 flax.nnx.grad 值在正向传播期间如何计算和保存,以及在反向传播期间如何重新计算,从而在内存和 FLOPs 之间进行权衡。

Flax NNX 与 JAX 变换中了解更多信息。

要了解 jax.remat,请参阅 JAX 的

jax.checkpoint 基础知识实践说明

flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}))[源代码]#
flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[源代码]#
flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}))[源代码]#

jax.vmap 的引用感知版本。

参数
  • f – 要在额外轴上映射的函数。

  • in_axes – 一个整数、None 或值序列,指定要映射的输入数组轴(参见 jax.vmap)。除了整数和 None,还可以使用 StateAxes 来控制图节点(如模块)如何向量化,方法是给定一个过滤器,为图节点的子状态指定要应用的轴。

  • out_axes – 一个整数、None 或 pytree,指示映射轴应出现在输出中的位置(参见 jax.vmap)。

  • axis_name – 可选,一个可哈希的 Python 对象,用于标识映射的轴,以便可以应用并行集合操作。

  • axis_size – 可选,一个整数,指示要映射的轴的大小。如果未提供,则从参数推断映射轴的大小。

返回

f 的批处理/向量化版本,其参数与 f 的参数相对应,但在 in_axes 指示的位置有额外的数组轴,并且返回值与 f 的返回值相对应,但在 out_axes 指示的位置有额外的数组轴。

示例

>>> from flax import nnx
>>> from jax import random, numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
...
>>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
... def forward(model, x):
...   return model(x)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)
>>> class LinearEnsemble(nnx.Module):
...   def __init__(self, num, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
...
>>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2,))
...
>>> @nnx.vmap(in_axes=(0, None), out_axes=0)
... def forward(model, x):
...   return jnp.dot(x, model.w.value)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)

要控制图节点子状态如何向量化,可以将 StateAxes 传递给 in_axesout_axes,以指定给定过滤器要应用于每个子状态的轴。以下示例展示了如何在保持不同批次统计数据和 dropout 随机状态的同时,在集成成员之间共享参数。

>>> class Foo(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.arange(4))
...     self.b = nnx.BatchStat(jnp.arange(4))
...
>>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
... def mul(foo):
...   return foo.a * foo.b
...
>>> foo = Foo()
>>> y = mul(foo)
>>> y
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
flax.nnx.eval_shape(f, *args, **kwargs)[源代码]#
jax.eval_shape 的“提升”版本,

可以处理 flax.nnx.Module / 图节点作为参数。

jax.eval_shape 类似,它计算函数 f 的形状/数据类型,而

不执行任何浮点运算(FLOPs),这可能很昂贵。这对于执行形状推断等操作很有用。

flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=())[源代码]#

jax.custom_vjp 的引用感知版本。

nnx.custom_vjp 接受模块和其他 Flax NNX 对象作为参数。与 JAX 版本的主要区别在于,由于模块遵循引用语义,它们将输入的状态更新作为辅助输出传播。这意味着 bwd 函数中传入的梯度将具有 (input_updates_g, out_g) 的形式,其中 input_updates_g 是输入相对于输入的梯度更新状态。输入中的所有模块项在 input_updates_g 中都有一个关联的 State 项,而所有非模块项将显示为 None。切线的形状应与输入的形状相同,其中模块项的位置是 State 项。

示例

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
...
>>> class Foo(nnx.Module):
...   def __init__(self, x, y):
...     self.x = nnx.Param(x)
...     self.y = nnx.Param(y)
...
>>> @nnx.custom_vjp
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, sin_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g['x'].value = cos_x * out_g * m.y
...   m_g['y'].value = sin_x * out_g
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grads = nnx.grad(f)(m)
...
>>> jax.tree.map(jnp.shape, grads)
State({
  'x': Param(
    value=()
  ),
  'y': Param(
    value=()
  )
})

请注意,代表 input_updates_g 上模块项的 State 对象与输出切线中预期的 State 对象具有相同的形状。这意味着您通常可以直接从 input_updates_g 复制它们,并用它们相应的梯度值进行更新。

您可以通过向 nondiff_argnums 传递一个 DiffState 来选择模块和其他图节点的哪些子状态是可微分的(具有切线)。例如,如果您只想对 Foo 类的 x 属性进行微分,您可以执行以下操作:

>>> x_attribute = nnx.PathContains('x')
>>> diff_state = nnx.DiffState(0, x_attribute)
...
>>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y  # type: ignore

>>> def f_fwd(m: Foo):
...   y = f(m)
...   res = (jnp.cos(m.x), m)  # type: ignore
...   return y, res
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
...
...   m_g.x.value = cos_x * out_g * m.y
...   del m_g['y'] # y is not differentiable
...   return (m_g,)

>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
...
>>> jax.tree.map(jnp.shape, grad)
State({
  'x': Param(
    value=()
  )
})

请注意,grad 无法为没有由 custom_vjp 定义切线的状态计算梯度,在上面的例子中,我们重用了相同的 x_attribute 过滤器来保持 custom_vjpgrad 的同步。

参数
  • fun – 可调用的基础函数。

  • nondiff_argnums – 整数元组或 DiffState 对象,指定不被微分的参数索引。默认情况下,所有参数都被微分。整数不能用于将模块等图节点标记为不可微分,在这种情况下,请使用 DiffState 对象。DiffState 对象定义可微分子状态的集合,与此参数的名称相反,这样做是为了与 grad 兼容。

flax.nnx.cond(pred, true_fun, false_fun, *operands, **kwargs)[源代码]#
flax.nnx.switch(index, branches, *operands)[源代码]#
flax.nnx.while_loop(cond_fun, body_fun, init_val)[源代码]#

jax.lax.while_loop 的 Flax NNX 变换。

注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在 body_fun 中更改 init_val 的变量引用结构。

示例

>>> import jax
>>> from flax import nnx
>>> def fwd_fn(input):
...   module, x, count = input
...   return module, module(x), count - 1.0

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> # `module` will be called three times
>>> _, y, _ = nnx.while_loop(
...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
参数
  • cond_fun – 用于 while 循环继续条件的函数,接受一个 T 类型的输入并输出一个布尔值。

  • body_fun – 一个函数,接受一个 T 类型的输入并输出一个 T。请注意,T 的数据和模块在输入和输出之间必须具有相同的引用结构。

  • init_valcond_funbody_fun 的初始输入。必须是 T 类型。

flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None)[源代码]#

jax.lax.fori_loop 的 Flax NNX 变换。

注意:为了使 NNX 内部引用跟踪机制正常工作,您不能在 body_fun 中更改 init_val 的变量引用结构。

示例

>>> import jax
>>> from flax import nnx

>>> def fwd_fn(i, input):
...   m, x = input
...   m.kernel.value = jnp.identity(10) * i
...   return m, m(x)

>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
>>> np.testing.assert_array_equal(y, x * 2 * 3)
参数
  • lower – 一个整数,表示循环索引的下界(包含)。

  • upper – 一个整数,表示循环索引的上界(不包含)。

  • body_fun – 一个函数,接受一个 T 类型的输入并输出一个 T。请注意,T 的数据和模块在输入和输出之间必须具有相同的引用结构。

  • init_val – body_fun 的初始输入。必须是 T 类型。

  • unroll – 一个可选的整数或布尔值,用于确定循环展开的程度。如果提供一个整数,它确定在循环的单个滚动迭代中运行多少个展开的循环迭代。如果提供一个布尔值,它将确定循环是完全展开(即 unroll=True)还是完全不展开(即 unroll=False)。此参数仅在循环边界是静态已知时适用。

返回

来自最后一次迭代的循环值,类型为 T