示例:使用预训练的 Gemma 和 Flax NNX 进行推理

示例:使用预训练的 Gemma 进行 Flax NNX 推理#

本示例展示了如何使用 Flax NNX 加载 Gemma 开放模型文件,并使用它们进行采样/推理以生成文本。您将使用由 Flax 和 JAX 编写的 Flax NNX gemma 模块 进行模型参数配置和推理。

Gemma 是一个基于 Google DeepMind 的 Gemini 的轻量级、最先进的开放模型系列。阅读更多关于 GemmaGemma 2 的信息。

建议您使用具有 A100 GPU 加速访问权限的 Google Colab 来运行代码。

安装#

安装必要的依赖项,包括 kagglehub

! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope

下载模型#

要使用 Gemma 模型,您需要一个 Kaggle 账户和 API 密钥。

  1. 要创建账户,请访问 Kaggle 并点击“注册”。

  2. 如果您已有账户,您需要登录,进入您的“设置”,并在“API”下点击“创建新令牌”以生成并下载您的 Kaggle API 密钥。

  3. Google Colab 中,在“Secrets”下添加您的 Kaggle 用户名和 API 密钥,将用户名存储为 KAGGLE_USERNAME,将密钥存储为 KAGGLE_KEY。如果您正在使用 Kaggle Notebook 以获得免费的 TPU 或其他硬件加速,它在“Add-ons” > “Secrets”下有一个密钥存储功能,并附有访问已存储密钥的说明。

然后运行下面的单元格。

import kagglehub
kagglehub.login()

如果一切顺利,应该会显示 Kaggle credentials set. Kaggle credentials successfully validated.

注意:在 Google Colab 中,您可以在执行完上述可选的第 3 步后,改用以下代码进行 Kaggle 身份验证。

import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

现在,加载您想尝试的 Gemma 模型。下一个单元格中的代码利用 kagglehub.model_download 来下载模型文件。

注意:对于较大的模型,例如 gemma 7bgemma 7b-it (instruct),您可能需要一个具有足够内存的硬件加速器,例如 NVIDIA A100。

注意:为了在下载模型时避免出现 403 错误,您需要在 Kaggle 上同意 Gemma 模型的许可协议。为此,请在浏览器中打开 https://www.kaggle.com/models/google/gemma/flax/,然后点击“Download”按钮,选择任意版本的 Gemma 模型。在下一个窗口中,系统会提示您同意 Gemma 模型的使用许可协议。完成此步骤后,您就可以使用下面的代码下载模型了。

from IPython.display import clear_output

VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'

Python 导入#

from flax import nnx
import sentencepiece as spm

为了与 Gemma 模型交互,您将使用来自 google/flax GitHub 示例的 Flax NNX gemma 代码。由于它没有作为包发布,您需要使用以下变通方法从 GitHub 上的 Flax NNX examples/gemma 进行导入。

import sys
import tempfile
with tempfile.TemporaryDirectory() as tmp:
  # Create a temporary directory and clone the `flax` repo.
  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
  ! git clone https://github.com/google/flax.git {tmp}/flax
  sys.path.append(f"{tmp}/flax/examples/gemma")
  import params as params_lib
  import sampler as sampler_lib
  import transformer as transformer_lib
  sys.path.pop();
Cloning into '/tmp/tmp_68d13pv/flax'...
remote: Enumerating objects: 31912, done.
remote: Counting objects: 100% (605/605), done.
remote: Compressing objects: 100% (250/250), done.
remote: Total 31912 (delta 406), reused 503 (delta 352), pack-reused 31307 (from 1)
Receiving objects: 100% (31912/31912), 23.92 MiB | 18.17 MiB/s, done.
Resolving deltas: 100% (23869/23869), done.

加载并准备 Gemma 模型#

首先,加载 Gemma 模型参数以供 Flax 使用。

params = params_lib.load_and_format_params(ckpt_path)

接下来,加载使用 SentencePiece 库构建的分词器文件。

vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
True

然后,使用 Flax NNX gemma.transformer.TransformerConfig.from_params 函数从检查点自动加载正确的配置。

注意:由于此版本中存在未使用的词元,词汇表大小小于输入嵌入的数量。

transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)

执行采样/推理#

在您的模型和分词器之上,使用正确的参数形状构建一个 Flax NNX gemma.Sampler

sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
)

您已准备好开始采样!

注意:这个 Flax NNX gemma.Sampler 使用 JAX 的即时(JIT)编译,因此更改输入形状会触发重新编译,这可能会降低速度。为了获得最快、最高效的结果,请保持您的批处理大小一致。

input_batch 中编写一个提示并执行推理。您可以随意调整 total_generation_steps(生成响应时执行的步数)。

input_batch = [
    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
  ]

out_data = sampler(
    input_strings=input_batch,
    total_generation_steps=300,  # The number of steps performed when generating a response.
  )

for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}")
  print()
  print(10*'#')
Prompt:

# Python program for implementation of Bubble Sort

def bubbleSort(arr):
Output:

    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(arr, i, j):
    temp = arr[i]
    arr[i] = arr[j]
    arr[j] = temp


# Driver code
arr = [5, 2, 8, 3, 1, 9]
print("Unsorted array:")
print(arr)
bubbleSort(arr)
print("Sorted array:")
print(arr)


# Time complexity of Bubble sort O(n^2)
# where n is the length of the array


# Space complexity of Bubble sort O(1)
# as it only requires constant extra space for the swap operation


# This program uses the bubble sort algorithm to sort the given array in ascending order.

```python
# This program uses the bubble sort algorithm to sort the given array in ascending order.

def bubbleSort(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - i - 1):
            if arr[j] > arr[j + 1]:
                swap(arr, j, j + 1)


def swap(

##########

您应该会得到一个冒泡排序算法的 Python 实现。