NNX Demo

import jax
from jax import numpy as jnp
from flax import nnx

[1] NNX is Pythonic

class Block(nnx.Module):
  def __init__(self, din, dout, *, rngs):
    self.linear = nnx.Linear(din, dout, rngs=rngs)
    self.bn = nnx.BatchNorm(dout, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.bn(self.linear(x)))


class MLP(nnx.Module):
  def __init__(self, nlayers, dim, *, rngs): # explicit RNG threading
    self.blocks = [
      Block(dim, dim, rngs=rngs) for _ in range(nlayers)
    ]
    self.count = Count(0)  # stateful variables are defined as attributes

  def __call__(self, x):
    self.count.value += 1  # in-place stateful updates
    for block in self.blocks:
      x = block(x)
    return x

class Count(nnx.Variable):   # custom Variable types define the "collections"
  pass

model = MLP(5, 4, rngs=nnx.Rngs(0))  # no special `init` method
model.set_attributes(use_running_average=False)  # set flags
y = model(jnp.ones((2, 4)))  # call methods directly

print(f'{model = }'[:500] + '\n...')
model = MLP(
  blocks=[Block(
      linear=Linear(
            in_features=4,
            out_features=4,
            use_bias=True,
            dtype=None,
            param_dtype=<class 'jax.numpy.float32'>,
            precision=None,
            kernel_init=<function variance_scaling.<locals>.init at 0x28ae86dc0>,
            bias_init=<function zeros at 0x122d39f70>,
            dot_general=<function dot_general at 0x1218459d0>
          ),
      bn=BatchNorm(
            num_features=4,
  
...
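
The same flag mechanism switches behavior for inference. A minimal sketch, flipping the BatchNorm layers back to their running statistics with the same set_attributes call used above:

model.set_attributes(use_running_average=True)  # eval mode: BatchNorm uses running stats
y_eval = model(jnp.ones((2, 4)))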

Because NNX Modules contain their own state, they are very easy to inspect:

print(f'{model.count = }')
print(f'{model.blocks[0].linear.kernel = }')
# print(f'{model.blocks.sdf.kernel = }') # typesafe inspection
model.count = Count(
  raw_value=1
)
model.blocks[0].linear.kernel = Param(
  raw_value=Array([[-0.80345297, -0.34071913, -0.9408296 ,  0.01005968],
         [ 0.26146442,  1.1247735 ,  0.54563737, -0.374164  ],
         [ 1.0281805 , -0.6798804 , -0.1488401 ,  0.05694951],
         [-0.44308168, -0.60587114,  0.434087  , -0.40541083]],      dtype=float32)
)

[2] Model surgery is intuitive

# Module sharing
model.blocks[1] = model.blocks[3]
# Weight tying
model.blocks[0].linear.kernel = model.blocks[-1].linear.kernel
# Monkey patching
def my_optimized_layer(x): return x
model.blocks[2] = my_optimized_layer

y = model(jnp.ones((2, 4)))  # still works
print(f'{y.shape = }')
y.shape = (2, 4)
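
Because surgery is plain attribute assignment, the aliasing it creates can be checked with ordinary Python identity tests on the objects assigned above:

assert model.blocks[1] is model.blocks[3]                               # shared module
assert model.blocks[0].linear.kernel is model.blocks[-1].linear.kernel  # tied weights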

[3] Interacting with JAX is easy

state, graphdef = model.split()

# state is a dictionary-like JAX pytree
print(f'{state = }'[:500] + '\n...')

# graphdef is also a JAX pytree, but just metadata
print(f'\n{graphdef = }'[:300] + '\n...')
state = State({
  'blocks': {
    '0': {
      'linear': {
        'kernel': Param(
          raw_value=Array([[-0.33095378,  0.67149884,  0.33700302,  0.30972847],
                 [ 0.8662822 , -0.11225506, -1.0820619 , -0.9906892 ],
                 [ 0.88298297, -0.2143851 ,  0.48143268,  0.6474548 ],
                 [-0.7710582 ,  0.3372276 ,  0.15487202,  0.6219269 ]],      dtype=float32)
        ),
        'bias': Param(
          raw_value=Array([0., 0., 0., 0.], dtype=float32)
        
...

graphdef = GraphDef(
  type=MLP,
  index=0,
  attributes=('blocks', 'count'),
  subgraphs={
    'blocks': GraphDef(
      type=list,
      index=1,
      attributes=('0', '1', '2', '3', '4'),
      subgraphs={
        '0': GraphDef(
          type=Block,
          index=2,
          attributes=('line
...
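
Since state is a pytree, the standard JAX tree utilities apply to it directly. A small sketch, summing the sizes of all leaves (note this counts every leaf in the State, including the step counter, not just weights):

num_values = sum(x.size for x in jax.tree_util.tree_leaves(state))
print(f'{num_values = }')
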
state, graphdef = model.split()

@jax.jit
def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):
  model = graphdef.merge(state)
  y = model(x)
  state, _ = model.split()
  return y, state

x = jnp.ones((2, 4))
y, state = forward(graphdef, state, x)

model.update(state)

print(f'{y.shape = }')
print(f'{model.count.value = }')
y.shape = (2, 4)
model.count.value = Array(3, dtype=int32, weak_type=True)
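
split also accepts Variable types as filters, returning one State per collection plus the GraphDef:
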
params, batch_stats, counts, graphdef = model.split(nnx.Param, nnx.BatchStat, Count)

@jax.jit
def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):
  model = graphdef.merge(params, batch_stats, counts)
  y = model(x)
  params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)
  return y, params, batch_stats, counts

x = jnp.ones((2, 4))
y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)

model.update(params, batch_stats, counts)

print(f'{y.shape = }')
print(f'{model.count.value = }')
y.shape = (2, 4)
model.count.value = Array(4, dtype=int32, weak_type=True)
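
The filtered States are ordinary pytrees, so gradients can be taken with respect to the Param collection alone. A minimal SGD sketch built on the same split/merge pattern (the squared-error loss and the 0.01 learning rate are illustrative assumptions, not part of the demo above):

@jax.jit
def train_step(graphdef, params, batch_stats, counts, x, y_true):
  def loss_fn(params):
    model = graphdef.merge(params, batch_stats, counts)
    y = model(x)
    return jnp.mean((y - y_true) ** 2)

  grads = jax.grad(loss_fn)(params)  # differentiate w.r.t. Param state only
  new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
  return new_params

The same pattern also composes with nesting: a Module can split and merge itself inside another Module's methods.
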
class Parent(nnx.Module):
  def __init__(self, model: MLP):
    self.model = model

  def __call__(self, x):
    params, batch_stats, counts, graphdef = self.model.split(nnx.Param, nnx.BatchStat, Count)

    @jax.jit
    def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):
      model = graphdef.merge(params, batch_stats, counts)
      y = model(x)
      params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)
      return y, params, batch_stats, counts

    y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)

    self.model.update(params, batch_stats, counts)
    return y

parent = Parent(model)

y = parent(jnp.ones((2, 4)))

print(f'{y.shape = }')
print(f'{parent.model.count.value = }')
y.shape = (2, 4)
parent.model.count.value = Array(4, dtype=int32, weak_type=True)