摘要

目录

摘要#

flax.nnx.tabulate(obj, *input_args, depth=None, method='__call__', row_filter=<function filter_rng_streams>, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), console_kwargs=mappingproxy({}), **input_kwargs)[源代码]#

以表格形式创建图对象的摘要。

该表格总结了对象的状态和元数据。表格结构如下:

  • 第一列表示对象在图中的路径。

  • 第二列表示对象的类型。

  • 第三列表示传递给对象方法的输入参数。

  • 第四列表示对象方法的输出。

  • 接下来的列提供了有关对象状态的信息,按变量类型分组。

示例

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.bn = nnx.BatchNorm(dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.2, rngs=rngs)
...
...   def __call__(self, x):
...     return nnx.relu(self.dropout(self.bn(self.linear(x))))
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs: nnx.Rngs):
...     self.block1 = Block(32, 128, rngs=rngs)
...     self.block2 = Block(128, 10, rngs=rngs)
...
...   def __call__(self, x):
...     return self.block2(self.block1(x))
...
>>> foo = Foo(nnx.Rngs(0))
>>> # print(nnx.tabulate(foo, jnp.ones((1, 32))))

                                                      Foo Summary
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path           ┃ type      ┃ inputs         ┃ outputs        ┃ BatchStat          ┃ Param                   ┃ RngState ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│                │ Foo       │ float32[1,32]  │ float32[1,10]  │ 276 (1.1 KB)       │ 5,790 (23.2 KB)         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1         │ Block     │ float32[1,32]  │ float32[1,128] │ 256 (1.0 KB)       │ 4,480 (17.9 KB)         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/linear  │ Linear    │ float32[1,32]  │ float32[1,128] │                    │ bias: float32[128]      │          │
│                │           │                │                │                    │ kernel: float32[32,128] │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │                    │ 4,224 (16.9 KB)         │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/bn      │ BatchNorm │ float32[1,128] │ float32[1,128] │ mean: float32[128] │ bias: float32[128]      │          │
│                │           │                │                │ var: float32[128]  │ scale: float32[128]     │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │ 256 (1.0 KB)       │ 256 (1.0 KB)            │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/dropout │ Dropout   │ float32[1,128] │ float32[1,128] │                    │                         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2         │ Block     │ float32[1,128] │ float32[1,10]  │ 20 (80 B)          │ 1,310 (5.2 KB)          │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/linear  │ Linear    │ float32[1,128] │ float32[1,10]  │                    │ bias: float32[10]       │          │
│                │           │                │                │                    │ kernel: float32[128,10] │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │                    │ 1,290 (5.2 KB)          │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/bn      │ BatchNorm │ float32[1,10]  │ float32[1,10]  │ mean: float32[10]  │ bias: float32[10]       │          │
│                │           │                │                │ var: float32[10]   │ scale: float32[10]      │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │ 20 (80 B)          │ 20 (80 B)               │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/dropout │ Dropout   │ float32[1,10]  │ float32[1,10]  │                    │                         │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│                │           │                │          Total │ 276 (1.1 KB)       │ 5,790 (23.2 KB)         │ 2 (12 B) │
└────────────────┴───────────┴────────────────┴────────────────┴────────────────────┴─────────────────────────┴──────────┘

                                            Total Parameters: 6,068 (24.3 KB)

请注意,block2/dropout 未显示在表格中,因为它与 block1/dropout 共享相同的 RngState

参数
  • obj – 要进行摘要的对象。它可以是 pytree 或图对象,例如 nnx.Module 或 nnx.Optimizer。

  • *input_args – 传递给对象方法的位置参数。

  • **input_kwargs – 传递给对象方法的关键字参数。

  • depth – 表格的深度。

  • method – 要在对象上调用的方法。默认为 '__call__'

  • row_filter – 一个可调用对象,用于筛选要在表格中显示的行。默认情况下,它会筛选掉类型为 nnx.RngStream 的行。

  • table_kwargs – 一个可选字典,包含传递给 rich.table.Table 构造函数的附加关键字参数。

  • column_kwargs – 一个可选字典,包含在向表格添加列时传递给 rich.table.Table.add_column 的附加关键字参数。

  • console_kwargs – 一个可选字典,包含在渲染表格时传递给 rich.console.Console 的附加关键字参数。默认参数为 'force_terminal': True,如果代码在 Jupyter notebook 中运行,则 'force_jupyter' 设置为 True,否则设置为 False

返回

一个总结该对象的字符串。