graph#

flax.nnx.split(node, *filters)[源代码]#

将图节点拆分为一个 GraphDef 和一个或多个 StateState 是一个从字符串或整数到 Variables、数组或嵌套 State 的映射。GraphDef 包含重构 Module 图所需的所有静态信息,它类似于 JAX 的 PyTreeDefsplit()merge() 结合使用,以在图的有状态和无状态表示之间无缝切换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> jax.tree.map(jnp.shape, params)
State({
  'batch_norm': {
    'bias': Param(
      value=(2,)
    ),
    'scale': Param(
      value=(2,)
    )
  },
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})
>>> jax.tree.map(jnp.shape, batch_stats)
State({
  'batch_norm': {
    'mean': BatchStat(
      value=(2,)
    ),
    'var': BatchStat(
      value=(2,)
    )
  }
})

split()merge() 主要用于直接与 JAX 转换交互,更多信息请参见函数式 API

参数
  • node – 要拆分的图节点。

  • *filters – 一些可选的过滤器,用于将状态分组到互斥的子状态中。

返回

一个 GraphDef 和一个或多个 States,数量等于传递的过滤器数量。如果未传递过滤器,则返回单个 State

flax.nnx.merge(graphdef, state, /, *states)[源代码]#

flax.nnx.split() 的逆操作。

nnx.merge 接收一个 flax.nnx.GraphDef 和一个或多个 flax.nnx.State,并创建一个与原始节点结构相同的新节点。

回顾:flax.nnx.split() 用于通过以下方式表示 flax.nnx.Module:1) 一个捕获其 Python 静态信息的静态 nnx.GraphDef;以及 2) 一个或多个 flax.nnx.Variable nnx.State,以 JAX PyTree 的形式捕获其 jax.Array

nnx.mergennx.split 结合使用,以在图的有状态和无状态表示之间无缝切换。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...
>>> node = Foo(nnx.Rngs(0))
>>> graphdef, params, batch_stats = nnx.split(node, nnx.Param, nnx.BatchStat)
...
>>> new_node = nnx.merge(graphdef, params, batch_stats)
>>> assert isinstance(new_node, Foo)
>>> assert isinstance(new_node.batch_norm, nnx.BatchNorm)
>>> assert isinstance(new_node.linear, nnx.Linear)

nnx.splitnnx.merge 主要用于直接与 JAX 转换交互(更多信息请参阅函数式 API)。

参数
返回

合并后的 flax.nnx.Module

flax.nnx.update(node, state, /, *states)[源代码]#

使用一个新的或多个状态就地更新给定的图节点。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

>>> def loss_fn(model, x, y):
...   return jnp.mean((y - model(x))**2)
>>> prev_loss = loss_fn(model, x, y)

>>> grads = nnx.grad(loss_fn)(model, x, y)
>>> new_state = jax.tree.map(lambda p, g: p - 0.1*g, nnx.state(model), grads)
>>> nnx.update(model, new_state)
>>> assert loss_fn(model, x, y) < prev_loss
参数
  • node – 要更新的图节点。

  • state – 一个 State 对象。

  • *states – 额外的 State 对象。

flax.nnx.pop(node, *filters)[源代码]#

从图节点中弹出一个或多个 Variable 类型。

用法示例

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')
>>> y = model(x)
>>> assert hasattr(model, 'i')

>>> intermediates = nnx.pop(model, nnx.Intermediate)
>>> assert intermediates['i'].value[0].shape == (1, 3)
>>> assert not hasattr(model, 'i')
参数
  • node – 一个图节点对象。

  • *filters – 用于过滤的一个或多个 Variable 对象。

返回

弹出的 State,包含被过滤出的 Variable 对象。

flax.nnx.state(node, *filters)[源代码]#

split() 类似,但只返回由过滤器指定的 State

用法示例

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
参数
  • node – 一个图节点对象。

  • *filters – 用于过滤的一个或多个 Variable 对象。

返回

一个或多个 State 映射。

flax.nnx.variables(node, *filters)#

split() 类似,但只返回由过滤器指定的 State

用法示例

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batch_norm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batch_norm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> # get the learnable parameters from the batch norm and linear layer
>>> params = nnx.state(model, nnx.Param)
>>> # get the batch statistics from the batch norm layer
>>> batch_stats = nnx.state(model, nnx.BatchStat)
>>> # get them separately
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> # get them together
>>> state = nnx.state(model)
参数
  • node – 一个图节点对象。

  • *filters – 用于过滤的一个或多个 Variable 对象。

返回

一个或多个 State 映射。

flax.nnx.graph()#
flax.nnx.graphdef(node, /)[源代码]#

获取给定图节点的 GraphDef

用法示例

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> graphdef, _ = nnx.split(model)
>>> assert graphdef == nnx.graphdef(model)
参数

node – 一个图节点对象。

返回

Module 对象的 GraphDef

flax.nnx.iter_graph(node, /)[源代码]#

遍历给定图节点的所有嵌套节点和叶节点,包括当前节点本身。

iter_graph 创建一个生成器,该生成器产生路径和值的对,其中路径是表示从根到该值的路径的字符串或整数元组。重复的节点只访问一次。叶节点包括静态值。

示例:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Linear(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.din, self.dout = din, dout
...     self.w = nnx.Param(jax.random.uniform(rngs.next(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...
>>> module = Linear(3, 4, rngs=nnx.Rngs(0))
>>> graph = [module, module]
...
>>> for path, value in nnx.iter_graph(graph):
...   print(path, type(value).__name__)
...
(0, '_object__nodes') frozenset
(0, '_object__state') ObjectState
(0, 'b') Param
(0, 'din') int
(0, 'dout') int
(0, 'w') Param
(0,) Linear
() list
flax.nnx.clone(node)[源代码]#

创建给定图节点的深层副本。

用法示例

>>> from flax import nnx

>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> cloned_model = nnx.clone(model)
>>> model.bias.value += 1
>>> assert (model.bias.value != cloned_model.bias.value).all()
参数

node – 一个图节点对象。

返回

Module 对象的深层副本。

flax.nnx.call(graphdef_state, /)[源代码]#

调用由 (GraphDef, State) 对定义的底层图节点的方法。

call 接收一个 (GraphDef, State) 对,并创建一个代理对象,该对象可用于调用底层图节点上的方法。当调用一个方法时,其输出将与代表图节点更新后状态的新的 (GraphDef, State) 对一起返回。call 等效于 merge() > method > split`(),但在纯 JAX 函数中使用起来更方便。

示例

>>> from flax import nnx
>>> import jax
>>> import jax.numpy as jnp
...
>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> linear = StatefulLinear(3, 2, nnx.Rngs(0))
>>> linear_state = nnx.split(linear)
...
>>> @jax.jit
... def forward(x, linear_state):
...   y, linear_state = nnx.call(linear_state)(x)
...   return y, linear_state
...
>>> x = jnp.ones((1, 3))
>>> y, linear_state = forward(x, linear_state)
>>> y, linear_state = forward(x, linear_state)
...
>>> linear = nnx.merge(*linear_state)
>>> linear.count.value
Array(2, dtype=uint32)

call 返回的代理对象支持索引和属性访问,以访问嵌套的方法。在下面的示例中,increment 方法的索引被用来调用 nodes 字典中 b 键对应的 StatefulLinear 模块的 increment 方法。

>>> class StatefulLinear(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (din, dout)))
...     self.b = nnx.Param(jnp.zeros((dout,)))
...     self.count = nnx.Variable(jnp.array(0, dtype=jnp.uint32))
...
...   def increment(self):
...     self.count += 1
...
...   def __call__(self, x):
...     self.increment()
...     return x @ self.w + self.b
...
>>> rngs = nnx.Rngs(0)
>>> nodes = dict(
...   a=StatefulLinear(3, 2, rngs),
...   b=StatefulLinear(2, 1, rngs),
... )
...
>>> node_state = nnx.split(nodes)
>>> # use attribute access
>>> _, node_state = nnx.call(node_state)['b'].increment()
...
>>> nodes = nnx.merge(*node_state)
>>> nodes['a'].count.value
Array(0, dtype=uint32)
>>> nodes['b'].count.value
Array(1, dtype=uint32)
flax.nnx.cached_partial(f, *cached_args)#

从一个经过 NNX 转换的函数以及一些缓存的输入参数创建一个偏函数,并通过缓存 NNX 图节点的遍历来减少 Python 开销。这对于加速那些使用相同输入子集重复调用的函数非常有用,例如带有 modeloptimizertrain_step

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> import optax
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
...
>>> @nnx.jit
... def train_step(model, optimizer, x, y):
...   def loss_fn(model):
...     return jnp.mean((model(x) - y) ** 2)
...
...   loss, grads = nnx.value_and_grad(loss_fn)(model)
...   optimizer.update(model, grads)
...   return loss
...
>>> cached_train_step = nnx.cached_partial(train_step, model, optimizer)
...
>>> for step in range(total_steps:=2):
...   x, y = jnp.ones((10, 2)), jnp.ones((10, 3))
...   # loss = train_step(model, optimizer, x, y)
...   loss = cached_train_step(x, y)
...   print(f'Step {step}: loss={loss:.3f}')
Step 0: loss=2.669
Step 1: loss=2.660

请注意,cached_partial 将克隆所有缓存的图节点以保证缓存的有效性,并且这些克隆将包含对相同 Variable 对象的引用,这保证了状态能正确传播回原始图节点。因此,每次调用缓存函数后,所有图节点的最终结构必须相同,否则将引发错误。允许临时性的修改(例如使用 Module.sow),只要在函数返回前清理掉即可(例如通过 nnx.pop)。

参数
  • f – 要缓存的函数。

  • *cached_args – 包含要缓存的图节点的输入参数子集。

返回

一个偏函数,它期望接收原始函数的其余参数。

class flax.nnx.GraphDef(nodes: 'list[NodeDefType[tp.Any]]', attributes: 'list[tuple[Key, AttrType]]', num_leaves: 'int')[源代码]#
class flax.nnx.UpdateContext(tag, outer_ref_outer_index, outer_index_inner_ref, outer_index_outer_ref, inner_ref_outer_index, static_cache)[源代码]#

用于处理复杂状态更新的上下文管理器。

flax.nnx.update_context(tag)[源代码]#

创建一个 UpdateContext 上下文管理器,它可以用来处理比 nnx.update 更复杂的状态更新,包括对静态属性和图结构的更新。

UpdateContext 暴露了一个 splitmerge API,其签名与 nnx.split / nnx.merge 相同,但它会执行一些簿记工作,以获取必要信息,从而能够根据转换内部所做的更改完美地更新输入对象。UpdateContext 必须总共调用 split 和 merge 4 次,第一次和最后一次调用发生在转换外部,第二次和第三次调用发生在转换内部,如下图所示:

                      idxmap
(2) merge ─────────────────────────────► split (3)
      ▲                                    │
      │               inside               │
      │. . . . . . . . . . . . . . . . . . │ index_mapping
      │               outside              │
      │                                    ▼
(1) split──────────────────────────────► merge (4)
                      refmap

对 split 的第一次调用 (1) 创建了一个 refmap,它跟踪外部引用;对 merge 的第一次调用 (2) 创建了一个 idxmap,它跟踪内部引用。对 split 的第二次调用 (3) 结合了 refmap 和 idxmap 来生成 index_mapping,该映射指示了外部引用如何映射到内部引用。最后,对 merge 的最后一次调用 (4) 使用 index_mapping 和 refmap 来重构转换的输出,同时重用/更新内部引用。为避免内存泄漏,idxmap 在 (3) 之后被清除,refmap 在 (4) 之后被清除,并且两者都在上下文管理器退出后被清除。

这里是一个简单的示例,展示了 update_context 的用法:

>>> from flax import nnx
...
>>> class Foo(nnx.Module): pass
...
>>> m1 = Foo()
>>> with nnx.update_context('example'):
...   with nnx.split_context('example') as ctx:
...     graphdef, state = ctx.split(m1)
...   @jax.jit
...   def f(graphdef, state):
...     with nnx.merge_context('example', inner=True) as ctx:
...       m2 = ctx.merge(graphdef, state)
...     m2.a = 1
...     m2.ref = m2  # create a reference cycle
...     with nnx.split_context('example') as ctx:
...       return ctx.split(m2)
...   graphdef_out, state_out = f(graphdef, state)
...   with nnx.merge_context('example', inner=False) as ctx:
...     m3 = ctx.merge(graphdef_out, state_out)
...
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

请注意,update_context 接受一个 tag 参数,主要用作安全机制,以减少在使用 current_update_context() 访问当前活动上下文时意外使用错误 UpdateContext 的风险。update_context 也可以用作装饰器,在函数执行期间创建/激活一个 UpdateContext 上下文。

>>> from flax import nnx
...
>>> class Foo(nnx.Module): pass
...
>>> m1 = Foo()
>>> @jax.jit
... def f(graphdef, state):
...   with nnx.merge_context('example', inner=True) as ctx:
...     m2 = ctx.merge(graphdef, state)
...   m2.a = 1     # insert static attribute
...   m2.ref = m2  # create a reference cycle
...   with nnx.split_context('example') as ctx:
...     return ctx.split(m2)
...
>>> @nnx.update_context('example')
... def g(m1):
...   with nnx.split_context('example') as ctx:
...     graphdef, state = ctx.split(m1)
...   graphdef_out, state_out = f(graphdef, state)
...   with nnx.merge_context('example', inner=False) as ctx:
...     return ctx.merge(graphdef_out, state_out)
...
>>> m3 = g(m1)
>>> assert m1 is m3
>>> assert m1.a == 1
>>> assert m1.ref is m1

可以使用 current_update_context() 访问该上下文。

参数

tag – 用于标识上下文的字符串标签。

flax.nnx.current_update_context(tag)[源代码]#

返回给定标签的当前活动的 UpdateContext