过滤器#

Flax NNX 广泛使用Filter(过滤器)作为在 API 中创建 nnx.State 组的方式,例如 nnx.splitnnx.state() 以及许多 Flax NNX 变换(transforms)

在本指南中,您将学习如何

在下面的示例中,nnx.Paramnnx.BatchStat 被用作 Filter,将模型分成两组:一组包含参数,另一组包含批处理统计信息。

from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()

graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': Param(
    value=0
  )
})
batch_stats = State({
  'b': BatchStat(
    value=True
  )
})

让我们更深入地了解 Filter

Filter 协议#

通常,Flax Filter 是形式如下的谓词函数


(path: tuple[Key, ...], value: Any) -> bool

其中

  • Key 是一个可哈希且可比较的类型;

  • path 是一个由 Key 组成的元组,表示嵌套结构中值的路径;以及

  • value 是该路径上的值。

如果值应该被包含在组中,则函数返回 True,否则返回 False

类型并非这种形式的函数。它们被视为 Filter 是因为,正如您将在下一节中学到的,类型和一些其他字面量会被转换为*谓词*。例如,nnx.Param 大致会转换为像这样的谓词

def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
is_param((), nnx.Param(0)) = True

这样的函数会匹配任何 nnx.Param 的实例。在内部,Flax NNX 使用 OfType,它为给定类型定义了这种形式的可调用对象

is_param = nnx.OfType(nnx.Param)

print(f'{is_param((), nnx.Param(0)) = }')
is_param((), nnx.Param(0)) = True

Filter DSL#

Flax NNX 提供了一个小型的领域特定语言(DSL),形式化为 nnx.filterlib.Filter 类型。这意味着用户不必像上一节中那样创建函数。

以下是 Flax NNX 中包含的所有可调用的 Filter,以及它们对应的 DSL 字面量(如果可用)

字面量

可调用对象

描述

...True

Everything()

匹配所有值

NoneFalse

Nothing()

不匹配任何值

type

OfType(type)

匹配 type 实例的值,或其 type 属性是 type 实例的值

PathContains(key)

匹配其关联 path 包含给定 key 的值

'{filter}' str

WithTag('{filter}')

匹配其字符串 tag 属性等于 '{filter}' 的值。由 RngKeyRngCount 使用。

(*filters) tuple[*filters] list

Any(*filters)

匹配与内部任一 filters 匹配的值

All(*filters)

匹配与内部所有 filters 匹配的值

Not(filter)

匹配与内部 filter 不匹配的值

让我们通过一个使用 nnx.vmap 的例子来看看 DSL 的实际应用。考虑以下情况:

  1. 您想要向量化所有参数;

  2. 在第 0 轴上应用 'dropout' Rng(Keys|Counts);以及

  3. 广播其余部分。

为此,您可以使用以下 Filter 来定义一个 nnx.StateAxes 对象,然后将其传递给 nnx.vmapin_axes,以指定 model 的各个子状态应如何向量化:

state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})

@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
  ...

这里,(nnx.Param, 'dropout') 展开为 Any(OfType(nnx.Param), WithTag('dropout')),而 ... 展开为 Everything()

如果您希望手动将字面量转换为谓词,可以使用 nnx.filterlib.to_predicate

is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))

print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))

State 进行分组#

掌握了前几节中关于 Filter 的知识后,让我们来学习如何大致实现 nnx.split。以下是关键思想:

  • 使用 nnx.graph.flatten 获取节点的 GraphDefnnx.State 表示。

  • 将所有 Filter 转换为谓词。

  • 使用 State.flat_state 获取状态的扁平化表示。

  • 遍历扁平化状态中的所有 (path, value) 对,并根据谓词对它们进行分组。

  • 使用 State.from_flat_state 将扁平化状态转换为嵌套的 nnx.State

from typing import Any
KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  graphdef, state = nnx.graph.flatten(node)
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  flat_states: list[dict[KeyPath, Any]] = [{} for p in predicates]

  for path, value in state:
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        flat_states[i][path] = value
        break
    else:
      raise ValueError(f'No filter matched {path = } {value = }')

  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states

# Let's test it.
foo = Foo()

graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': Param(
    value=0
  )
})
batch_stats = State({
  'b': BatchStat(
    value=True
  )
})

注意:* 了解过滤是依赖于顺序的这一点非常重要。第一个匹配到某个值的 Filter 会保留该值,因此您应该将更具体的 Filter 放在更通用的 Filter 之前。

例如,如下所示,如果您:

  1. 创建一个 SpecialParam 类型,它是 nnx.Param 的子类,并创建一个包含两种类型参数的 Bar 对象(子类化自 nnx.Module);并且

  2. 尝试在分割 SpecialParam 之前分割 nnx.Param

那么所有值都将被放入 nnx.Param 组中,而 SpecialParam 组将为空,因为所有 SpecialParam 同时也是 nnx.Param

class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()

graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': Param(
    value=0
  ),
  'b': SpecialParam(
    value=0
  )
})
special_params = State({})

而颠倒顺序将确保 SpecialParam 首先被捕获

graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': Param(
    value=0
  )
})
special_params = State({
  'b': SpecialParam(
    value=0
  )
})