摘要#
- flax.nnx.tabulate(obj, *input_args, depth=None, method='__call__', row_filter=<function filter_rng_streams>, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), console_kwargs=mappingproxy({}), **input_kwargs)[源代码]#
以表格形式创建图对象的摘要。
该表格总结了对象的状态和元数据。表格结构如下:
第一列表示对象在图中的路径。
第二列表示对象的类型。
第三列表示传递给对象方法的输入参数。
第四列表示对象方法的输出。
接下来的列提供了有关对象状态的信息,按变量类型分组。
示例
>>> from flax import nnx ... >>> class Block(nnx.Module): ... def __init__(self, din, dout, rngs: nnx.Rngs): ... self.linear = nnx.Linear(din, dout, rngs=rngs) ... self.bn = nnx.BatchNorm(dout, rngs=rngs) ... self.dropout = nnx.Dropout(0.2, rngs=rngs) ... ... def __call__(self, x): ... return nnx.relu(self.dropout(self.bn(self.linear(x)))) ... >>> class Foo(nnx.Module): ... def __init__(self, rngs: nnx.Rngs): ... self.block1 = Block(32, 128, rngs=rngs) ... self.block2 = Block(128, 10, rngs=rngs) ... ... def __call__(self, x): ... return self.block2(self.block1(x)) ... >>> foo = Foo(nnx.Rngs(0)) >>> # print(nnx.tabulate(foo, jnp.ones((1, 32)))) Foo Summary ┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ ┃ path ┃ type ┃ inputs ┃ outputs ┃ BatchStat ┃ Param ┃ RngState ┃ ┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ │ │ Foo │ float32[1,32] │ float32[1,10] │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1 │ Block │ float32[1,32] │ float32[1,128] │ 256 (1.0 KB) │ 4,480 (17.9 KB) │ 2 (12 B) │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1/linear │ Linear │ float32[1,32] │ float32[1,128] │ │ bias: float32[128] │ │ │ │ │ │ │ │ kernel: float32[32,128] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 4,224 (16.9 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1/bn │ BatchNorm │ float32[1,128] │ float32[1,128] │ mean: float32[128] │ bias: float32[128] │ │ │ │ │ │ │ var: float32[128] │ scale: float32[128] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 256 (1.0 KB) │ 256 (1.0 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block1/dropout │ Dropout │ float32[1,128] │ float32[1,128] │ │ │ 2 (12 B) │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2 │ Block │ float32[1,128] │ float32[1,10] │ 20 (80 B) │ 1,310 (5.2 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2/linear │ Linear │ float32[1,128] │ float32[1,10] │ │ bias: float32[10] │ │ │ │ │ │ │ │ kernel: float32[128,10] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 1,290 (5.2 KB) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2/bn │ BatchNorm │ float32[1,10] │ float32[1,10] │ mean: float32[10] │ bias: float32[10] │ │ │ │ │ │ │ var: float32[10] │ scale: float32[10] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 20 (80 B) │ 20 (80 B) │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ block2/dropout │ Dropout │ float32[1,10] │ float32[1,10] │ │ │ │ ├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤ │ │ │ │ Total │ 276 (1.1 KB) │ 5,790 (23.2 KB) │ 2 (12 B) │ └────────────────┴───────────┴────────────────┴────────────────┴────────────────────┴─────────────────────────┴──────────┘ Total Parameters: 6,068 (24.3 KB)
请注意,
block2/dropout
未显示在表格中,因为它与block1/dropout
共享相同的RngState
。- 参数
obj – 要进行摘要的对象。它可以是 pytree 或图对象,例如 nnx.Module 或 nnx.Optimizer。
*input_args – 传递给对象方法的位置参数。
**input_kwargs – 传递给对象方法的关键字参数。
depth – 表格的深度。
method – 要在对象上调用的方法。默认为
'__call__'
。row_filter – 一个可调用对象,用于筛选要在表格中显示的行。默认情况下,它会筛选掉类型为
nnx.RngStream
的行。table_kwargs – 一个可选字典,包含传递给
rich.table.Table
构造函数的附加关键字参数。column_kwargs – 一个可选字典,包含在向表格添加列时传递给
rich.table.Table.add_column
的附加关键字参数。console_kwargs – 一个可选字典,包含在渲染表格时传递给 rich.console.Console 的附加关键字参数。默认参数为
'force_terminal': True
,如果代码在 Jupyter notebook 中运行,则'force_jupyter'
设置为True
,否则设置为False
。
- 返回
一个总结该对象的字符串。