在 FLIP 中重构 RNNCellBase#
作者:Cristian Garcia、Marcus Chiam、Jasmijn Bastings
开始日期:2023 年 5 月 1 日
FLIP 问题:[待定]
FLIP PR:#3053
状态:已实施
摘要#
本提案旨在通过重构 initialize_carry
方法及其他相关组件,来提升 RNNCellBase
类的可用性。
动机#
目前,initialize_carry
不仅用于初始化 carry,还用于传递关键元数据,例如特征数量。这个 API 可能不直观,因为它要求用户手动计算一些通常可以由模块自行推断的信息,例如批次维度的形状和特征维度的形状。
示例:ConvLSTM#
在诸如 ConvLSTM
的情况下,当前的 API 可能不直观,因为其 size
参数同时包含了输入图像形状和输出特征维度。
x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)
# image shape: vvvvvvv
carry = nn.ConvLSTMCell.initialize_carry(key1, (16,), (64, 64, 16))
# batch size: ^^ ^^ :output features
lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
(carry, y), initial_params = lstm.init_with_output(key2, carry, x)
此 FLIP 将提议对 initialize_carry
进行一些更改,以便将前面的示例简化为
x = jnp.ones((2, 4, 4, 3)) # (batch, *image_shape, channels)
lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3))
carry = lstm.initialize_carry(key1, input_shape=x.shape)
(carry, y), initial_params = lstm.init_with_output(key2, carry, x)
实现#
本提案建议进行以下更改
initialize_carry#
initialize_carry
应重构为一个实例方法,其签名如下
def initialize_carry(self, key, sample_input):
sample_input
应该是一个数组,其形状与单元将要处理的输入形状相同,但不包括时间轴。
重构 RNNCellBase 子类#
RNNCellBase
应该进行重构,以包含初始化单元和执行其前向传播所需的元数据。对于 LSTMCell
和 GRUCell
,这意味着添加一个 features
属性,该属性应由用户在构造时提供。这一更改与其他大多数 Module
的结构保持一致,让用户感觉更熟悉。
x = jnp.ones((2, 100, 10)) # (batch, time, features)
cell = nn.LSTMCell(features=32)
carry = cell.initialize_carry(PRNGKey(0), x[:, 0]) # sample input
(carry, y), variables = cell.init_with_output(PRNGKey(1), carry, x)
num_feature_dims#
为了简化在 RNN
等抽象中处理 RNNCellBase
实例的方式,每个单元都应实现 num_feature_dims
属性。对于大多数单元,如 LSTMCell
和 GRUCell
,该值始终为 1。对于像 ConvLSTM
这样的单元,该值取决于其 kernel_size
。
讨论#
替代方法#
为了消除对
num_feature_dims
的需求,RNN
可以只支持单个批次维度,即输入形式为(batch, time, *features)
。目前,它同时支持多个批次维度和多个特征维度。另一种方法可能是完全重新设计 Flax 处理循环状态的方式。例如,可以将一个
memory
集合作为变量的一部分来处理。然而,这会带来一些挑战,例如在训练期间处理无状态单元、将状态从一层传递到另一层,以及在scan
内部执行初始化。
重构成本#
最初的 TGP 结果显示有 761 个测试中断和 110 个测试失败。然而,在修复一个测试后,TGP 结果显示有 231 个测试中断和 13 个测试失败,因此中断的测试之间似乎存在大量重叠。
为了最大限度地降低重构成本,当前的实现将以一个已弃用的名称为 Google 内部用户保留。这将允许用户按照自己的节奏迁移到新的 API。对于开源用户,我们应将 Flax 版本提升到 0.7.0
,以便现有用户可以继续依赖 0.6.x
版本。