Flax#

适用于 JAX络 (Neural Networks)


Flax 为使用 JAX 进行神经网络研究的研究人员和开发人员提供灵活的端到端用户体验。Flax 使您能够充分利用 JAX 的强大功能。

Flax 的核心是 NNX——一个简化的 API,使在 JAX 中创建、检查、调试和分析神经网络变得更加容易。Flax NNX 对 Python 引用语义提供了一流的支持,使用户能够使用常规 Python 对象来表达他们的模型。Flax NNX 是对先前 Flax Linen API 的演进,我们凭借多年的经验,才带来了这个更简单、更友好的 API。

注意

Flax Linen API 在短期内不会被弃用,因为大多数 Flax 用户仍然依赖此 API。但是,我们鼓励新用户使用 Flax NNX。请查看 为什么选择 Flax NNX,了解 Flax NNX 和 Linen 之间的比较,以及我们开发新 API 的原因。

要将您的 Flax Linen 代码库迁移到 Flax NNX,请先在 NNX 基础中熟悉该 API,然后按照演进指南开始迁移。

特性#

Pythonic 风格

Flax NNX 支持使用常规 Python 对象,提供直观且可预测的开发体验。

简单

Flax NNX 依赖于 Python 的对象模型,这为用户带来了简单性并提高了开发速度。

表现力强

Flax NNX 允许通过其过滤器系统对模型状态进行细粒度控制。

熟悉

Flax NNX 通过函数式 API,可以非常轻松地将对象与常规 JAX 代码集成。

基本用法#

from flax import nnx
import optax


class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@nnx.jit  # automatic state management for JAX transforms
def train_step(model, optimizer, x, y):
  def loss_fn(model):
    y_pred = model(x)  # call methods directly
    return ((y_pred - y) ** 2).mean()

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(model, grads)  # in-place updates

  return loss

安装#

通过 pip 安装

pip install flax

或从代码仓库安装最新版本

pip install git+https://github.com/google/flax.git

了解更多#

Flax NNX 基础
nnx_basics.html
MNIST 教程
mnist_tutorial.html
从 Flax Linen 迁移到 Flax NNX
guides/linen_to_nnx.html