RNN Flip#
开始日期:2022-08-18
FLIP PR:#2604
FLIP Issue:#2396
作者:Jasmijn Bastings (@bastings) 和 Cristian Garcia (@cgarciae)
摘要#
此 FLIP 增加了对更高级别循环层(RNN、GRU、LSTM)的支持,这些层可以帮助用户使用 Flax 中已有的循环单元来处理输入序列。
动机#
实现众所周知的循环架构很棘手且容易出错,即使是一个简单的 LSTM 层也涉及手动创建和处理 carry/memory,并正确设置 nn.scan
。
@nn.compact
def __call__(self, x):
LSTM = nn.scan(
nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False}
)
carry = LSTM.initialize_carry(
jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size
)
carry, x = LSTM()(carry, x)
return x
涉及填充(padding)的稍微复杂的情况,例如在 seq2seq 示例中,需要更多的工作,但通过正确的抽象,有可能简化为几行代码。我们提议为用户提供清晰、正确且高效的抽象来使用循环单元。
要求#
掩码(Masking):我们需要支持一批序列,其中每个序列的末尾都包含填充。
出于性能原因,我们不打算支持非连续填充,即填充不在序列末尾的情况,除非是在打包(见下文)的情况下。
双向性(Bidirectionality):能够以正向和反向两种方式处理序列,并尊重填充(即,反向处理应从实际输入开始,而不是从填充值开始)。
性能(Performance):应对提议的类进行基准测试,以在步长时间和/或内存使用方面提供最佳性能。
循环 Dropout(Recurrent Dropout):支持单元内的循环 dropout(例如,对单元状态应用 dropout)。
实现#
高层结构#
我们提议采用以下三个层次的抽象:
单元(Cells,不变):所有 RNNCellBase 的子类,如 LSTMCell 和 GRUCell,它们实现单步逻辑。这些在 Flax 中已经存在。
层(Layers,新增):一个类(RNN),它接受一个单元并扫描一个序列,同时尊重可能的填充值,并可选地允许打包序列。
双向(Bidirectional,新增):一个单一的类,它接受一个前向和一个后向 RNN 实例,并正确地双向处理输入序列,然后合并结果。
提议的 API 示例#
我们首先提供一个代码示例,展示使用提议的 API 可以实现什么功能,然后在下面详细讨论 API。
cell = nn.LSTMCell()
# Encodes a batch of input sequences.
carry, outputs = nn.RNN(cell, cell_size)(inputs, seq_lengths)
一个双向层,其前向和后向分别使用 LSTM RNNs,将如下所示:
forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)
接下来,我们将讨论 RNN
、Bidirectional
以及对 RNNCellBase
的提议更改。
RNNBase#
RNNBase
类作为 RNN
类的基类,它指定了所有 RNN 层为与 Bidirectional
兼容而应实现的 API。RNNBase
包含 __call__
和 flip_sequences
方法。
class RNNBase(Protocol):
def __call__(
self,
inputs: jax.Array,
*,
initial_carry: Optional[Carry] = None,
init_key: Optional[random.KeyArray] = None,
seq_lengths: Optional[Array] = None,
return_carry: Optional[bool] = None,
time_major: Optional[bool] = None,
reverse: Optional[bool] = None,
keep_order: Optional[bool] = None,
) -> Union[Output, Tuple[Carry, Output]]:
...
其中
inputs
:输入序列。initial_carry
:初始 carry,如果未提供,将使用单元的RNNCellBase.initialize_carry
方法进行初始化。init_key
:用于初始化 carry 的 PRNG 密钥,如果未提供,将使用jax.random.key(0)
。大多数单元会忽略此参数。seq_lengths
:一个可选的整数数组,形状为(*batch)
,指示每个序列的长度,时间维度上索引大于相应长度的元素将被视为填充并被忽略。return_carry
:如果return_carry=False
(默认),则仅返回输出序列,否则将返回一个包含最终 carry 和输出序列的元组。time_major
:如果time_major=False
(默认),则期望输入的形状为(*batch, time, *features)
,否则期望输入的形状为(time, *batch, *features)
。reverse
:如果reverse=False
(默认),则序列从左到右处理并按原始顺序返回,否则将从右到左处理并按相反顺序返回。如果传递了seq_lengths
,填充将始终保留在序列的末尾。keep_order
:如果keep_order=True
,当reverse=True
时,输出将在处理后被反转回原始顺序,这对于在双向 RNN 中对齐序列很有用。如果keep_order=False
(默认),输出将保持由reverse
指定的顺序。Returns
:如果return_carry=False
(默认),则仅返回输出序列,否则将返回一个包含最终 carry 和输出序列的元组。
RNN#
RNN
模块继承自 RNNBase
,其主要功能是在一批输入序列上应用一个 RNNCellBase
实例,它可以与任何类型的单元(例如 GRUCell
、LSTMCell
等)一起使用。它接受以下参数:
class RNN(RNNBase):
cell: RNNCellBase,
cell_size: int | Tuple[int, ...]
time_axis: int = -2,
variable_axes = FrozenDict(),
variable_broadcast: CollectionFilter = 'params'
variable_carry: CollectionFilter = False
split_rngs = FrozenDict({'params': False})
# implement RNNBase
...
像 variable_axes
、variable_broadcast
、variable_carry
和 split_rngs
这样的属性会直接传递给 nn.scan
,它们的默认值设置得使常见的单元如 LSTMCell
和 GRUCell
可以开箱即用。
掩码#
seq_lengths
被定义为一个形状为 (*batch,)
的整数数组,表示每个序列的长度。
讨论
在其他框架中有多种掩码格式,以下是一些最常见的:
二进制掩码:为每个样本和时间步指定该数据点是否应包含在计算中,它可以是非连续的(例如 [1, 1, 0, 1])。Keras 使用这种方式。
序列长度掩码:为每个样本指定序列中包含的非填充样本的数量,序列中包含的任何填充都应堆叠在末尾。FlaxFormer 使用这种方式。
分段掩码:指定数据点属于哪个样本的行和时间步,这种格式允许每行有多个样本,从而可能减少所需的总填充量(例如 [1, 1, 1, 2, 2, 0, 0])。PyTorch 使用这种表示(参见 pack_padded_sequence)。
虽然序列打包(参见 LM1B 示例)功能更强大,但其实现更复杂,是否值得投入精力尚不清楚。最简单的格式是序列长度掩码,我们提议使用这种格式。
双向#
双向处理可以通过一个模块来实现,该模块接受一个 forward_rnn
模块和一个 backward_rnn
模块,两者都应是 RNN
实例,以便双向处理输入序列。下面是一些实现的伪代码:
def __call__(self, inputs, seq_lengths):
# Encode in the forward direction.
carry_forward, outputs_forward = self.forward_rnn(
inputs, seq_lengths=seq_lengths,
return_carry=True, reverse=False,
)
# Encode in the reverse order.
carry_backward, outputs_backward = self.backward_rnn(
inputs, seq_lengths=seq_lengths,
return_carry=True, reverse=True, # process in reverse order
keep_order=True, # but return the sequence in the original order
)
# Merge both sequences.
outputs = jax.tree.map(self.merge_fn, outputs_forward, outputs_backward)
return (carry_forward, carry_backward), outputs
这里的 merge_fn
是一个函数,它接受两个输出并将它们融合(默认为 concat
)。如本文开头所示,用法如下:
forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)
循环 Dropout#
在 RNN 中,dropout 有两种主要用法:
输入 dropout:应用于输入的常规 dropout,每一步都不同。
循环 dropout:将 dropout 应用于循环输入/输出,每一步都相同。
Flax 的 nn.scan
可以通过 split_rns
轻松表达这两种 dropout,输入 dropout 会分割 rngs,而循环 dropout 则不会。#2540 的引入使得 nn.Dropout
中的 rng_name
现在可以由用户定义,这样 Cells 就可以定义两种类型的 dropout,例如:
self.dropout = nn.Dropout(...) # input dropout
self.recurrent_dropout = nn.Dropout(..., rng_collection='recurrent_dropout')
基于此,nn.scan
/ nn.RNN
现在可以相应地指定 split_rngs
,例如:
nn.scan(scan_fn, ..., split_rngs={'dropout': True, 'recurrent_dropout': False})
未来构想#
显示
序列打包#
允许打包多个序列以有效利用空间/内存。这可能会导致一种权衡,即步长时间更长(因为在每一步都需要检查是否开始一个新序列并重置 carry/初始状态),但使用的填充更少,从而整体上提高了效率。
RNNCell 重新设计#
将 initialize_state 设为实例方法#
第一个替代方案是使 initialize_carry
成为一个实例方法。通过这一更改,超参数可以直接传递给单元,其签名将如下所示:
def initialize_carry(self, sample_input) -> Carry:
...
用法如下:
LSTM = nn.scan(
nn.LSTMCell, variable_broadcast='params',
split_rngs={'dropout': True})
lstm = LSTM(features=32)
carry = lstm.initialize_carry(x[:, 0])
carry, y = lstm(carry, x)
移除 initialize_carry#
一个替代方案是完全移除 initialize_carry
,并将 carry 状态作为 carry 集合来处理。这将大大简化用法:
LSTM = nn.scan(
nn.LSTMCell, variable_broadcast='params',
split_rngs={'dropout': True})
y = LSTM(features=32)(carry, x)
然而,这需要 nn.scan
支持 carry 集合的初始化,这目前是无法做到的。此外,用户将必须指定一个集合是可变的,例如 mutable=['carry']
,即使他们对输出的 carry 状态不感兴趣。