指标 (Metrics)#

class flax.nnx.metrics.Metric(self)#

指标的基类。任何继承 Metric 的类都应实现 computeresetupdate 方法。

__init__()#
compute()#

计算并返回 Metric 的值。

reset()#

原地重置 Metric

update(**kwargs)#

原地更新 Metric

class flax.nnx.metrics.Average(self, argname='values')#

平均值指标。

用法示例

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Average()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
__init__(argname='values')#

传入一个字符串,表示 update() 将用于获取新值的关键字参数。例如,将指标构造为 avg = Average('test') 将允许您使用 avg.update(test=new_value) 进行更新。

参数

argname – 一个可选字符串,表示 update() 将用于获取新值的关键字参数。默认为 'values'

compute()#

计算并返回平均值。

reset()#

重置此 Metric

update(**kwargs)#

原地更新此 Metric。此方法将使用 kwargs[self.argname] 的值来更新指标,其中 self.argname 在构造时定义。

参数

**kwargs – 关键字参数,包含一个 self.argname 条目,该条目映射到我们想要用来更新此指标的值。

class flax.nnx.metrics.Accuracy(self, threshold=None, *args, **kwargs)#

准确率指标。此指标继承自 Average,因此它们共享相同的 resetcompute 方法实现。与 Average 不同,在构造 Accuracy 时无需传入字符串。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([0, 1, 1, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> metrics = nnx.metrics.Accuracy()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.4, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)

>>> logits3 = jax.random.normal(jax.random.key(2), (5,))
>>> labels3 = jnp.array([0, 1, 0, 1, 1])
>>> accuracy = nnx.metrics.Accuracy(threshold=0.5)
>>> accuracy.update(logits=logits3, labels=labels3)
>>> accuracy.compute()
Array(0.8, dtype=float32)
update(*, logits, labels, **_)#

原地更新此 Metric

参数
  • logits – 输出的预测激活值。对于多类别分类,这些值在与标签比较之前会进行 argmax 操作(在最后一个维度上)。对于二元分类,这些值直接与标签进行比较。

  • labels – 真实的整数标签。

class flax.nnx.metrics.Welford(self, argname='values')#

使用 Welford 算法计算数据流的均值和方差。

用法示例

>>> import jax.numpy as jnp
>>> from flax import nnx

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics = nnx.metrics.Welford()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
__init__(argname='values')#

传入一个字符串,表示 update() 将用于获取新值的关键字参数。例如,将指标构造为 wf = Welford('test') 将允许您使用 wf.update(test=new_value) 进行更新。

参数

argname – 一个可选字符串,表示 update() 将用于获取新值的关键字参数。默认为 'values'

compute()#

计算并返回一个 Statistics 数据类对象中的均值和方差统计数据。

reset()#

重置此 Metric

update(**kwargs)#

原地更新此 Metric。此方法将使用 kwargs[self.argname] 的值来更新指标,其中 self.argname 在构造时定义。

参数

**kwargs – 关键字参数,包含一个 self.argname 条目,该条目映射到我们想要用来更新此指标的值。

class flax.nnx.metrics.MultiMetric(self, **metrics)#

MultiMetric 类用于存储多个指标并通过一次调用更新它们。

用法示例

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> metrics = nnx.MultiMetric(
...   accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )

>>> metrics
MultiMetric( # MetricState: 4 (16 B)
  accuracy=Accuracy( # MetricState: 2 (8 B)
    threshold=None,
    argname='values',
    total=MetricState( # 1 (4 B)
      value=Array(0., dtype=float32)
    ),
    count=MetricState( # 1 (4 B)
      value=Array(0, dtype=int32)
    )
  ),
  loss=Average( # MetricState: 2 (8 B)
    argname='values',
    total=MetricState( # 1 (4 B)
      value=Array(0., dtype=float32)
    ),
    count=MetricState( # 1 (4 B)
      value=Array(0, dtype=int32)
    )
  )
)

>>> metrics.accuracy
Accuracy( # MetricState: 2 (8 B)
  threshold=None,
  argname='values',
  total=MetricState( # 1 (4 B)
    value=Array(0., dtype=float32)
  ),
  count=MetricState( # 1 (4 B)
    value=Array(0, dtype=int32)
  )
)

>>> metrics.loss
Average( # MetricState: 2 (8 B)
  argname='values',
  total=MetricState( # 1 (4 B)
    value=Array(0., dtype=float32)
  ),
  count=MetricState( # 1 (4 B)
    value=Array(0, dtype=int32)
  )
)

>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([0, 1, 1, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])

>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])

>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
>>> metrics.compute()
{'accuracy': Array(0.6, dtype=float32), 'loss': Array(2.5, dtype=float32)}
>>> metrics.update(logits=logits2, labels=labels2, values=batch_loss2)
>>> metrics.compute()
{'accuracy': Array(0.4, dtype=float32), 'loss': Array(2., dtype=float32)}
>>> metrics.reset()
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
__init__(**metrics)#

向构造函数传入关键字参数,例如 MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)

参数

**metrics – 将用于访问相应 Metric 的关键字参数。

compute()#

计算并返回所有底层 Metric 的值。此方法将返回一个字典,将字符串(由传递给构造函数的关键字参数 **metrics 定义)映射到相应的指标值。

reset()#

重置所有底层的 Metric

update(**updates)#

原地更新此 MultiMetric 中的所有底层 Metric。所有 **updates 将被传递给所有底层 Metricupdate 方法。

参数

**updates – 将传递给底层 Metricupdate 方法的关键字参数。