数据类型#

flax.nnx.nn.dtypes.canonicalize_dtype(*args, dtype=None, inexact=True)[源代码]#

将可选的数据类型(dtype)规范化为最终的数据类型。

如果 dtype 为 None,此函数将推断数据类型。如果它不为 None,将原样返回,或者在数据类型无效时引发异常。它会使用 jnp.result_type 从输入参数中推断。

参数
  • *args – JAX 数组兼容值。None 值将被忽略。

  • dtype – 可选的数据类型覆盖。如果指定,参数将被转换为指定的数据类型,并且禁用数据类型推断。

  • inexact – 当为 True 时,输出数据类型必须是 jnp.inexact 的子类型。

  • 不精确的数据类型是实数或复数浮点数。)–

  • 当您想应用不直接适用于整数的操作时,这很有用,)–

  • 例如求平均值。

返回

*args 应被转换成的数据类型。

flax.nnx.nn.dtypes.promote_dtype(args, /, *, dtype=None, inexact=True)[源代码]#

“将输入参数提升为指定或推断的数据类型。

所有参数都被转换为相同的数据类型。关于如何确定此数据类型,请参见 canonicalize_dtype

promote_dtype 的行为主要是一个围绕 jax.numpy.promote_types 的便捷包装器。不同之处在于它会自动将所有输入转换为推断的数据类型,允许通过强制指定的数据类型来覆盖推断,并有一个可选的检查以保证结果数据类型是不精确的。

参数
  • *args – JAX 数组兼容值。None 值将原样返回。

  • dtype – 可选的数据类型覆盖。如果指定,参数将被转换为指定的数据类型,并且禁用数据类型推断。

  • inexact – 当为 True 时,输出数据类型必须是 jnp.inexact 的子类型。

  • 不精确的数据类型是实数或复数浮点数。)–

  • 当您想应用不直接适用于整数的操作时,这很有用,)–

  • 例如求平均值。

返回

被转换为相同数据类型数组的参数。