LoRA#
NNX LoRA 类。
- class flax.nnx.LoRA(self, in_features, lora_rank, out_features, *, base_module=None, dtype=None, param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs)[源代码]#
一个独立的 LoRA 层。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> layer = nnx.LoRA(3, 2, 4, rngs=nnx.Rngs(0)) >>> layer.lora_a.value.shape (3, 2) >>> layer.lora_b.value.shape (2, 4) >>> # Wrap around existing layer >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> wrapper = nnx.LoRA(3, 2, 4, base_module=linear, rngs=nnx.Rngs(1)) >>> assert wrapper.base_module == linear >>> wrapper.lora_a.value.shape (3, 2) >>> layer.lora_b.value.shape (2, 4) >>> y = layer(jnp.ones((16, 3))) >>> y.shape (16, 4)
- 参数
in_features – 输入特征的数量。
lora_rank – LoRA 维度的秩。
out_features – 输出特征的数量。
base_module – 一个基础模块,如果可能的话,用于调用和替换。
dtype – 计算的数据类型(默认:从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
precision – 计算的数值精度,详情请参阅 jax.lax.Precision。
a_initializer – 扇入矩阵的初始化函数。默认为 he_uniform。
b_initializer – 扇出矩阵的初始化函数。默认为零初始化器。
lora_param_type – LoRA 参数的类型。
方法
- class flax.nnx.LoRALinear(self, in_features, out_features, *, lora_rank, lora_dtype=None, lora_param_dtype=<class 'jax.numpy.float32'>, a_initializer=<function variance_scaling.<locals>.init>, b_initializer=<function zeros>, lora_param_type=<class 'flax.nnx.nn.lora.LoRAParam'>, rngs, **kwargs)[源代码]#
一个 nnx.Linear 层,其输出将被 LoRA 化。
模型状态结构将与 Linear 的状态结构兼容。
用法示例
>>> from flax import nnx >>> import jax, jax.numpy as jnp >>> linear = nnx.Linear(3, 4, rngs=nnx.Rngs(0)) >>> lora_linear = nnx.LoRALinear(3, 4, lora_rank=2, rngs=nnx.Rngs(0)) >>> linear.kernel.value.shape (3, 4) >>> lora_linear.kernel.value.shape (3, 4) >>> lora_linear.lora.lora_a.value.shape (3, 2) >>> jnp.allclose(linear.kernel.value, lora_linear.kernel.value) Array(True, dtype=bool) >>> y = lora_linear(jnp.ones((16, 3))) >>> y.shape (16, 4)
- 参数
in_features – 输入特征的数量。
out_features – 输出特征的数量。
lora_rank – LoRA 维度的秩。
base_module – 一个基础模块,如果可能的话,用于调用和替换。
dtype – 计算的数据类型(默认:从输入和参数推断)。
param_dtype – 传递给参数初始化器的数据类型(默认:float32)。
precision – 计算的数值精度,详情请参阅 jax.lax.Precision。
a_initializer – 扇入矩阵的初始化函数。默认为 he_uniform。
b_initializer – 扇出矩阵的初始化函数。默认为零初始化器。
lora_param_type – LoRA 参数的类型。
方法