Mutable Arrays (experimental)#

from flax import nnx
import jax
import jax.numpy as jnp
import jax.experimental
import optax

Basics#

Mutable Arrays 101#

m_array = jax.experimental.mutable_array(jnp.array([1, 2, 3]))

@jax.jit
def increment(m_array: jax.experimental.MutableArray):  # no return!
  array: jax.Array = m_array[...]  # access
  m_array[...] = array + 1         # update

print("[1] =", m_array); increment(m_array); print("[2] =", m_array)
[1] = MutableArray([1, 2, 3], dtype=int32)
[2] = MutableArray([2, 3, 4], dtype=int32)
@jax.jit
def inc(x):
  x[...] += 1  # augmented assignment performs the same in-place update

print(increment.lower(m_array).as_text())
module @jit_increment attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {tf.aliasing_output = 0 : i32}) -> (tensor<3xi32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<3xi32>
    %1 = stablehlo.add %arg0, %0 : tensor<3xi32>
    return %1 : tensor<3xi32>
  }
}
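
Because updates happen in place, a jitted function can accumulate state across calls without returning anything. Below is a minimal sketch (the counter/bump names are ours, not part of the API) of an accumulator built on mutable_array:

counter = jax.experimental.mutable_array(jnp.zeros((), jnp.int32))

@jax.jit
def bump(counter, amount):
  counter[...] += amount  # in-place update; nothing is returned

for _ in range(3):
  bump(counter, 2)

print(counter[...])  # 6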

Mutable Variables#

variable = nnx.Variable(jnp.array([1, 2, 3]), mutable=True)
print(f"{variable.mutable = }\n")

print("[1] =", variable); increment(variable); print("[2] =", variable)
variable.mutable = True

[1] = Variable( # 3 (12 B)
  value=MutableArray([1, 2, 3], dtype=int32)
)
[2] = Variable( # 3 (12 B)
  value=MutableArray([2, 3, 4], dtype=int32)
)
with nnx.use_mutable_arrays(True):
  variable = nnx.Variable(jnp.array([1, 2, 3]))  # created as mutable inside the context

print(f"{variable.mutable = }")
variable.mutable = True
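
Since Variable forwards [...] reads and writes to its underlying MutableArray (which is what lets increment above accept a Variable), state can also be mutated directly. A small sketch assuming that same indexing behavior:

with nnx.use_mutable_arrays(True):
  count = nnx.Variable(jnp.array(0))

@jax.jit
def add(count, n):
  count[...] += n  # write through the Variable into the underlying MutableArray

add(count, 5)
print(count[...])  # 5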

Changing state#

class Linear(nnx.Module):
  def __init__(self, in_features, out_features, rngs: nnx.Rngs):
    self.kernel = nnx.Param(jax.random.normal(rngs(), (in_features, out_features)))
    self.bias = nnx.Param(jnp.zeros(out_features))

  def __call__(self, x):
    return x @ self.kernel + self.bias[None]

model = Linear(1, 3, rngs=nnx.Rngs(0)) # without mutable arrays
mutable_model = nnx.mutable(model) # convert to mutable arrays
frozen_model = nnx.freeze(mutable_model) # freeze mutable arrays again

print("nnx.mutable(model) =", mutable_model)
print("nnx.freeze(mutable_model) =", frozen_model)
nnx.mutable(model) = Linear( # Param: 6 (24 B)
  bias=Param( # 3 (12 B)
    value=MutableArray(shape=(3,), dtype=dtype('float32'))
  ),
  kernel=Param( # 3 (12 B)
    value=MutableArray(shape=(1, 3), dtype=dtype('float32'))
  )
)
nnx.freeze(mutable_model) = Linear( # Param: 6 (24 B)
  bias=Param( # 3 (12 B)
    value=Array(shape=(3,), dtype=dtype('float32'))
  ),
  kernel=Param( # 3 (12 B)
    value=Array(shape=(1, 3), dtype=dtype('float32'))
  )
)
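
A frozen model holds regular jax.Array leaves (as the output above shows), so it behaves like any other pytree under JAX transforms. A quick usage sketch, with shapes following the Linear(1, 3) constructor above:

@jax.jit
def apply(model, x):
  return model(x)

y = apply(frozen_model, jnp.ones((4, 1)))
print(y.shape)  # (4, 3)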

Examples#

class Block(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)
    self.linear_out = Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.gelu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)
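
As a quick smoke test (the batch size here is chosen arbitrarily), the block maps (batch, din) inputs to (batch, dout) outputs:

with nnx.use_mutable_arrays(True):
  block = Block(2, 64, 3, rngs=nnx.Rngs(0))

print(block(jnp.ones((10, 2))).shape)  # (10, 3)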

Training loop#

with nnx.use_mutable_arrays(True):
  model = Block(2, 64, 3, rngs=nnx.Rngs(0))
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

@jax.jit
def train_step(model, optimizer, x, y):
  graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
  def loss_fn(params):
    model = nnx.merge(graphdef, params, nondiff)
    return ((model(x) - y) ** 2).mean()

  loss, grads = jax.value_and_grad(loss_fn)(nnx.freeze(params))  # freeze MutableArrays for jax.grad
  optimizer.update(model, grads)

  return loss

train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))
Array(1.0002378, dtype=float32)
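
Because train_step updates the model and optimizer in place, a full loop never needs to rebind them; the sketch below simply repeats the call on the same placeholder batch used above:

for step in range(100):
  loss = train_step(model, optimizer, x=jnp.ones((10, 2)), y=jnp.ones((10, 3)))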

Scan over layers#

@nnx.vmap
def create_stack(rngs):
  return Block(2, 64, 2, rngs=rngs)

with nnx.use_mutable_arrays(True):
  block_stack = create_stack(nnx.Rngs(0).fork(split=8))

def scan_fn(x, block):
  x = block(x)
  return x, None

x = jax.random.uniform(jax.random.key(0), (3, 2))
y, _ = jax.lax.scan(scan_fn, x, block_stack)

print("y = ", y)
y =  [[ 0.82836264 -0.25364825]
 [ 4.955331    4.9364624 ]
 [-7.672193   -3.4669733 ]]
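
The vmapped constructor stacks each parameter along a new leading axis, which is what lets jax.lax.scan iterate over the layers. A sketch of checking this (assuming Param forwards .shape to its value):

print(block_stack.linear.kernel.shape)  # (8, 2, 64): 8 stacked Linear(2, 64) kernels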

Limitations#

MutableArray outputs#

@jax.jit
def create_model(rngs):
  return Block(2, 64, 3, rngs=rngs)  # jax.jit cannot return MutableArray outputs

try:
  with nnx.use_mutable_arrays(True):
    model = create_model(nnx.Rngs(0))
except Exception as e:
  print(f"Error:", e)
with nnx.use_mutable_arrays(False): # <-- disable mutable arrays
  model = create_model(nnx.Rngs(0))

model = nnx.mutable(model) # convert to mutable after creation

print("model.linear =", model.linear)
model.linear = Linear( # Param: 192 (768 B)
  bias=Param( # 64 (256 B)
    value=MutableArray(shape=(64,), dtype=dtype('float32'))
  ),
  kernel=Param( # 128 (512 B)
    value=MutableArray(shape=(2, 64), dtype=dtype('float32'))
  )
)
@nnx.jit
def create_model(rngs):
  return Block(2, 64, 3, rngs=rngs)  # nnx.jit handles MutableArray outputs

with nnx.use_mutable_arrays(True):
  model = create_model(nnx.Rngs(0))

print("model.linear =", model.linear)
model.linear = Linear( # Param: 192 (768 B)
  bias=Param( # 64 (256 B)
    value=MutableArray(shape=(64,), dtype=dtype('float32'))
  ),
  kernel=Param( # 128 (512 B)
    value=MutableArray(shape=(2, 64), dtype=dtype('float32'))
  )
)

Reference sharing (aliasing)#

def get_error(f, *args):
  try:
    return f(*args)
  except Exception as e:
    return f"{type(e).__name__}: {e}"
  
x = jax.experimental.mutable_array(jnp.array(0))

@jax.jit
def f(a, b):
  ...  # body intentionally empty; the point is passing the same MutableArray twice

print(get_error(f, x, x))
None
class SharedVariables(nnx.Object):
  def __init__(self):
    self.a = nnx.Variable(jnp.array(0))
    self.b = self.a

class SharedModules(nnx.Object):
  def __init__(self):
    self.a = Linear(1, 1, rngs=nnx.Rngs(0))
    self.b = self.a

@jax.jit
def g(pytree):
  ...  # again, only the argument structure matters here

with nnx.use_mutable_arrays(True):
  shared_variables = SharedVariables()
  shared_modules = SharedModules()

print("SharedVariables", get_error(g, shared_variables))
print("SharedModules", get_error(g, shared_modules))
SharedVariables None
SharedModules None
@jax.jit
def h(graphdef, state):
  obj = nnx.merge(graphdef, state)
  obj.a[...] += 10

graphdef, state = nnx.split(shared_variables)
print(state) # split deduplicates the state

h(graphdef, state)

print("updated", shared_variables)
State({
  'a': Variable( # 1 (4 B)
    value=MutableArray(0, dtype=int32, weak_type=True)
  )
})
updated SharedVariables( # Variable: 1 (4 B)
  a=Variable( # 1 (4 B)
    value=MutableArray(10, dtype=int32)
  ),
  b=Variable( # 1 (4 B)
    value=MutableArray(10, dtype=int32)
  )
)
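
Since b is literally the same Variable object as a (self.b = self.a above), the in-place update is visible through both names. A one-line sanity check:

assert shared_variables.a is shared_variables.b
print(shared_variables.a[...], shared_variables.b[...])  # 10 10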