object#

class flax.nnx.Object(*args, **kwargs)[源代码]#

所有 NNX 对象的基类。

flax.nnx.data(value, /)[源代码]#

将一个属性标注为 pytree 数据。

data 的返回值必须直接赋给一个 Object 属性,该属性将被注册为 pytree 数据属性。

示例

from flax import nnx
import jax

class Foo(nnx.Object):
  def __init__(self):
    self.data_attr = nnx.data(42)  # pytree data
    self.static_attr = "hello"     # static attribute

foo = Foo()

assert jax.tree.leaves(foo) == [42]
参数

value – 要标注为数据的值。

返回

一个在赋值时将属性注册为数据的值。

flax.nnx.Data#

Data 使用类型注解将类的属性标记为 pytree 数据。

Data 注解必须在类级别使用,并将应用于所有实例。当类型注解已经存在或被要求(例如,对于 dataclass)时,推荐使用 Data。

示例

from flax import nnx
import jax
import dataclasses

@dataclasses.dataclass
class Foo(nnx.Object):
  a: nnx.Data[int]  # Annotates `a` as pytree data
  b: str            # `b` is not pytree data

foo = Foo(a=42, b='hello')

assert jax.tree.leaves(foo) == [42]

A[A] 的别名

flax.nnx.is_data_type(value, /)[源代码]#

检查一个值是否为已注册的数据类型。

此函数检查该值是否为已注册的数据类型,这意味着当它被赋给一个 Object 属性时,它会被自动识别为 pytree 数据。

数据类型包括: - jax.Arrays - np.ndarrays - MutableArrays - Variables (Param, BatchStat, RngState, 等) - 所有图节点 (Object, Module, Rngs, 等) - 任何通过 nnx.register_data_type 注册的类型

示例

from flax import nnx
import jax.numpy as jnp

module = nnx.Linear(1, 1, rngs=nnx.Rngs(0))
blocks = [module, module, module]

assert nnx.is_data_type(jnp.array(42))  # Arrays are data
assert nnx.is_data_type(nnx.Param(1))   # Variables are data
assert nnx.is_data_type(nnx.Rngs(0))    # Objects are data
assert nnx.is_data_type(module)         # Objects are data

assert not nnx.is_data_type(0.)         # float is not data
assert not nnx.is_data_type(1)          # int is not data
assert not nnx.is_data_type("hello")    # str is not data
assert not nnx.is_data_type(blocks)     # list is not data
参数

value – 要检查的值。

返回

如果该值是已注册的数据类型,则为 True,否则为 False。

flax.nnx.register_data_type(type_, /)[源代码]#

将一个类型注册为 Object 识别的 pytree 数据类型。

注册为数据的自定义类型在赋给 Object 属性时,将被自动识别为数据属性。这意味着该类型的值不需要用 nnx.data(…) 包装,Object 就能将其被赋值的属性标记为数据。

示例

from flax import nnx
from dataclasses import dataclass

@dataclass(frozen=True)
class MyType:
  value: int

nnx.register_data_type(MyType)

class Foo(nnx.Object):
  def __init__(self, a):
    self.a = MyType(a)  # Automatically registered as data
    self.b = "hello"     # str not registered as data

foo = Foo(42)

assert nnx.is_data_type(foo.a)  # True
assert jax.tree.leaves(foo) == [MyType(value=42)]