激活函数#
- flax.nnx.celu(x, alpha=1.0)[源代码]#
连续可微指数线性单元激活函数。
按元素计算函数
\[\begin{split}\mathrm{celu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0 \end{cases}\end{split}\]更多信息,请参阅连续可微指数线性单元。
- 参数
x – 输入数组
alpha – 数组或标量(默认值:1.0)
- 返回
一个数组。
- flax.nnx.elu(x, alpha=1.0)[源代码]#
指数线性单元激活函数。
按元素计算函数
\[\begin{split}\mathrm{elu}(x) = \begin{cases} x, & x > 0\\ \alpha \left(\exp(x) - 1\right), & x \le 0 \end{cases}\end{split}\]- 参数
x – 输入数组
alpha – alpha 值的标量或数组(默认值:1.0)
- 返回
一个数组。
另请参阅
- flax.nnx.gelu(x, approximate=True)[源代码]#
高斯误差线性单元激活函数。
如果
approximate=False
,按元素计算函数\[\mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( \frac{-x}{\sqrt{2}} \right) \right)\]如果
approximate=True
,则使用 GELU 的近似公式\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]更多信息,请参阅高斯误差线性单元 (GELU),第 2 节。
- 参数
x – 输入数组
approximate – 是否使用近似或精确公式。
- flax.nnx.glu(x, axis=-1)[源代码]#
门控线性单元激活函数。
计算函数
\[\mathrm{glu}(x) = x\left[\ldots, 0:\frac{n}{2}, \ldots\right] \cdot \mathrm{sigmoid} \left( x\left[\ldots, \frac{n}{2}:n, \ldots\right] \right)\]其中数组沿
axis
分为两半。axis
维度的大小必须能被 2 整除。- 参数
x – 输入数组
axis – 进行分割的轴(默认值:-1)
- 返回
一个数组。
另请参阅
- flax.nnx.hard_sigmoid(x)[源代码]#
硬 Sigmoid 激活函数。
按元素计算函数
\[\mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}\]- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
relu6()
- flax.nnx.hard_silu(x)[源代码]#
硬 SiLU (swish) 激活函数
按元素计算函数
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]hard_silu()
和hard_swish()
都是同一函数的别名。- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.hard_swish(x)#
硬 SiLU (swish) 激活函数
按元素计算函数
\[\mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)\]hard_silu()
和hard_swish()
都是同一函数的别名。- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.hard_tanh(x)[源代码]#
硬 \(\mathrm{tanh}\) 激活函数。
按元素计算函数
\[\begin{split}\mathrm{hard\_tanh}(x) = \begin{cases} -1, & x < -1\\ x, & -1 \le x \le 1\\ 1, & 1 < x \end{cases}\end{split}\]- 参数
x – 输入数组
- 返回
一个数组。
- flax.nnx.leaky_relu(x, negative_slope=0.01)[源代码]#
带泄漏修正线性单元激活函数。
按元素计算函数
\[\begin{split}\mathrm{leaky\_relu}(x) = \begin{cases} x, & x \ge 0\\ \alpha x, & x < 0 \end{cases}\end{split}\]其中 \(\alpha\) =
negative_slope
。- 参数
x – 输入数组
negative_slope – 指定负斜率的数组或标量(默认值:0.01)
- 返回
一个数组。
另请参阅
- flax.nnx.log_sigmoid(x)[源代码]#
对数 Sigmoid 激活函数。
按元素计算函数
\[\mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})\]- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.log_softmax(x, axis=-1, where=None)[源代码]#
Log-Softmax 函数。
计算
softmax
函数的对数,该函数将元素重新缩放到 \([-\infty, 0)\) 范围内。\[\mathrm{log\_softmax}(x)_i = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)} \right)\]- 参数
x – 输入数组
axis – 计算
log_softmax
的轴。整数或整数元组。where – 要包含在
log_softmax
中的元素。任何被屏蔽的元素的输出都为负无穷大。
- 返回
一个数组。
注意
如果任何输入值为
+inf
,结果将全部为NaN
:这反映了inf / inf
在浮点数学上下文中没有明确定义的事实。另请参阅
- flax.nnx.logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None)[源代码]#
对数-和-指数规约。
scipy.special.logsumexp()
的 JAX 实现。\[\operatorname{logsumexp} a = \log \sum_i b_i \exp a_i\]其中 \(i\) 索引遍历一个或多个要规约的维度。
- 参数
a – 输入数组
axis – int 或 int 序列,默认值为 None。计算和的轴。如果为 None,则沿所有轴计算和。
b – 指数函数的缩放因子。必须可广播到 a 的形状。
keepdims – 如果为
True
,则被规约的轴将作为大小为 1 的维度保留在输出中。return_sign – 如果为
True
,输出将是一个(result, sign)
对,其中sign
是和的符号,result
包含其绝对值的对数。如果为False
,则只返回result
,如果和为负,则它将包含 NaN 值。where – 要包含在规约中的元素。
- 返回
数组
result
或数组对(result, sign)
,具体取决于return_sign
参数的值。
- flax.nnx.one_hot(x, num_classes, *, dtype=<class 'numpy.float64'>, axis=-1)[源代码]#
对给定索引进行独热编码。
输入
x
中的每个索引都编码为一个长度为num_classes
的零向量,并将index
处的元素设置为 1。>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3) Array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)
超出 [0, num_classes) 范围的索引将被编码为零。
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32)
- 参数
x – 索引张量。
num_classes – 独热维度中的类别数。
dtype – 可选,返回值的浮点数据类型(默认为
jnp.float_
)。axis – 计算函数所沿的轴。
- flax.nnx.relu(x)[源代码]#
修正线性单元激活函数。
按元素计算函数
\[\mathrm{relu}(x) = \max(x, 0)\]但在微分下,我们取
\[\nabla \mathrm{relu}(0) = 0\]更多信息,请参阅ReLU'(0) 对反向传播的数值影响。
- 参数
x – 输入数组
- 返回
一个数组。
示例
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.])) Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)
另请参阅
relu6()
- flax.nnx.selu(x)[源代码]#
缩放指数线性单元激活函数。
按元素计算函数
\[\begin{split}\mathrm{selu}(x) = \lambda \begin{cases} x, & x > 0\\ \alpha e^x - \alpha, & x \le 0 \end{cases}\end{split}\]其中 \(\lambda = 1.0507009873554804934193349852946\) 且 \(\alpha = 1.6732632423543772848170429916717\)。
更多信息,请参阅自归一化神经网络。
- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.sigmoid(x)[源代码]#
Sigmoid 激活函数。
按元素计算函数
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.silu(x)[源代码]#
SiLU (又名 swish) 激活函数。
按元素计算函数
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.soft_sign(x)[源代码]#
Soft-sign 激活函数。
按元素计算函数
\[\mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}\]- 参数
x – 输入数组
- flax.nnx.softmax(x, axis=-1, where=None)[源代码]#
Softmax 函数。
计算该函数,它将元素重新缩放到 \([0, 1]\) 范围内,使得沿
axis
的元素之和为 \(1\)。\[\mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\]- 参数
x – 输入数组
axis – 计算 softmax 所沿的轴。这些维度上的 softmax 输出总和应为 \(1\)。整数或整数元组。
where – 要包含在
softmax
中的元素。任何被屏蔽的元素的输出都为零。
- 返回
一个数组。
注意
如果任何输入值为
+inf
,结果将全部为NaN
:这反映了inf / inf
在浮点数学上下文中没有明确定义的事实。另请参阅
- flax.nnx.softplus(x)[源代码]#
Softplus 激活函数。
按元素计算函数
\[\mathrm{softplus}(x) = \log(1 + e^x)\]- 参数
x – 输入数组
- flax.nnx.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)[源代码]#
将输入标准化为零均值和单位方差。
标准化由以下公式给出
\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]其中 \(\langle x\rangle\) 表示 \(x\) 的均值,\(\epsilon\) 是为避免除以零而引入的一个小修正因子。
- 参数
x – 要标准化的输入数组。
axis – 表示标准化所沿轴的整数或整数元组。默认为最后一个轴(
-1
)。mean – 可选地指定用于标准化的均值。如果未指定,则将使用
x.mean(axis, where=where)
。variance – 可选地指定用于标准化的方差。如果未指定,则将使用
x.var(axis, where=where)
。epsilon – 添加到方差中以避免除以零的修正因子;默认为
1E-5
。where – 可选的布尔掩码,指定在计算均值和方差时使用哪些元素。
- 返回
一个与
x
形状相同的数组,包含标准化的输入。
- flax.nnx.swish(x)#
SiLU (又名 swish) 激活函数。
按元素计算函数
\[\mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}\]- 参数
x – 输入数组
- 返回
一个数组。
另请参阅
- flax.nnx.tanh(x, /)#
按元素计算输入的双曲正切。
numpy.tanh
的 JAX 实现。双曲正切定义为
\[tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}\]- 参数
x – 输入数组或标量。
- 返回
一个包含
x
中每个元素的双曲正切的数组,会提升为不精确的数据类型。
注意
jnp.tanh
等效于计算-1j * jnp.tan(1j * x)
。另请参阅
jax.numpy.sinh()
:按元素计算输入的双曲正弦。jax.numpy.cosh()
:按元素计算输入的双曲余弦。jax.numpy.arctanh()
:按元素计算输入的反双曲正切。
示例
>>> x = jnp.array([[-1, 0, 1], ... [3, -2, 5]]) >>> with jnp.printoptions(precision=3, suppress=True): ... jnp.tanh(x) Array([[-0.762, 0. , 0.762], [ 0.995, -0.964, 1. ]], dtype=float32) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.tan(1j * x) Array([[-0.762+0.j, 0. -0.j, 0.762-0.j], [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True)
对于复数值输入
>>> with jnp.printoptions(precision=3, suppress=True): ... jnp.tanh(2-5j) Array(1.031+0.021j, dtype=complex64, weak_type=True) >>> with jnp.printoptions(precision=3, suppress=True): ... -1j * jnp.tan(1j * (2-5j)) Array(1.031+0.021j, dtype=complex64, weak_type=True)