object#
- flax.nnx.data(value, /)[源代码]#
将一个属性标注为 pytree 数据。
data 的返回值必须直接赋给一个 Object 属性,该属性将被注册为 pytree 数据属性。
示例
from flax import nnx import jax class Foo(nnx.Object): def __init__(self): self.data_attr = nnx.data(42) # pytree data self.static_attr = "hello" # static attribute foo = Foo() assert jax.tree.leaves(foo) == [42]
- 参数
value – 要标注为数据的值。
- 返回
一个在赋值时将属性注册为数据的值。
- flax.nnx.Data#
Data 使用类型注解将类的属性标记为 pytree 数据。
Data 注解必须在类级别使用,并将应用于所有实例。当类型注解已经存在或被要求(例如,对于 dataclass)时,推荐使用 Data。
示例
from flax import nnx import jax import dataclasses @dataclasses.dataclass class Foo(nnx.Object): a: nnx.Data[int] # Annotates `a` as pytree data b: str # `b` is not pytree data foo = Foo(a=42, b='hello') assert jax.tree.leaves(foo) == [42]
A
[A
] 的别名
- flax.nnx.is_data_type(value, /)[源代码]#
检查一个值是否为已注册的数据类型。
此函数检查该值是否为已注册的数据类型,这意味着当它被赋给一个 Object 属性时,它会被自动识别为 pytree 数据。
数据类型包括: - jax.Arrays - np.ndarrays - MutableArrays - Variables (Param, BatchStat, RngState, 等) - 所有图节点 (Object, Module, Rngs, 等) - 任何通过 nnx.register_data_type 注册的类型
示例
from flax import nnx import jax.numpy as jnp module = nnx.Linear(1, 1, rngs=nnx.Rngs(0)) blocks = [module, module, module] assert nnx.is_data_type(jnp.array(42)) # Arrays are data assert nnx.is_data_type(nnx.Param(1)) # Variables are data assert nnx.is_data_type(nnx.Rngs(0)) # Objects are data assert nnx.is_data_type(module) # Objects are data assert not nnx.is_data_type(0.) # float is not data assert not nnx.is_data_type(1) # int is not data assert not nnx.is_data_type("hello") # str is not data assert not nnx.is_data_type(blocks) # list is not data
- 参数
value – 要检查的值。
- 返回
如果该值是已注册的数据类型,则为 True,否则为 False。
- flax.nnx.register_data_type(type_, /)[源代码]#
将一个类型注册为 Object 识别的 pytree 数据类型。
注册为数据的自定义类型在赋给 Object 属性时,将被自动识别为数据属性。这意味着该类型的值不需要用 nnx.data(…) 包装,Object 就能将其被赋值的属性标记为数据。
示例
from flax import nnx from dataclasses import dataclass @dataclass(frozen=True) class MyType: value: int nnx.register_data_type(MyType) class Foo(nnx.Object): def __init__(self, a): self.a = MyType(a) # Automatically registered as data self.b = "hello" # str not registered as data foo = Foo(42) assert nnx.is_data_type(foo.a) # True assert jax.tree.leaves(foo) == [MyType(value=42)]