如何使用 JAX 进行打印
How to print with JAX
我有一个 JAX 布尔数组,想打印一个语句结合 Trues:
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_print
@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
id_print(jnp.sum(mask_cp))
overlaps_jax()
mask_cp
中有5个True;我想打印为:
With jax accelerator
There are 5 true bools
因为这个函数是jitted, I tried to print this by using id_print
,但是我做不到。 id_print(jnp.sum(mask_cp))
将打印 5
,但我无法将其用于字符串。我尝试了以下方法:
id_print(jnp.sum(mask_cp))
# print:
# 5
id_print("\nWith jax accelerator\nThere are " + jnp.sum(mask_cp) + " true bools\n")
# error:
# TypeError: can only concatenate str (not "DynamicJaxprTracer") to str
print("\nWith jax accelerator\nThere are {} true bools\n".format(jnp.sum(mask_cp)))
# print:
# With jax accelerator
# There are Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> true bools
如何在这段代码中打印这样的语句?
请注意 id_print
是实验性的,其 API 和功能可能会发生变化。也就是说,我不相信 id_print
有能力添加这样的文本,但你可以通过更通用的 host_callback.call
:
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call
@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
call(lambda x: print(f"There are {x} true bools"), jnp.sum(mask_cp))
overlaps_jax()
输出为
There are 5 true bools
我有一个 JAX 布尔数组,想打印一个语句结合 Trues:
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_print
@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
id_print(jnp.sum(mask_cp))
overlaps_jax()
mask_cp
中有5个True;我想打印为:
With jax accelerator
There are 5 true bools
因为这个函数是jitted, I tried to print this by using id_print
,但是我做不到。 id_print(jnp.sum(mask_cp))
将打印 5
,但我无法将其用于字符串。我尝试了以下方法:
id_print(jnp.sum(mask_cp))
# print:
# 5
id_print("\nWith jax accelerator\nThere are " + jnp.sum(mask_cp) + " true bools\n")
# error:
# TypeError: can only concatenate str (not "DynamicJaxprTracer") to str
print("\nWith jax accelerator\nThere are {} true bools\n".format(jnp.sum(mask_cp)))
# print:
# With jax accelerator
# There are Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> true bools
如何在这段代码中打印这样的语句?
请注意 id_print
是实验性的,其 API 和功能可能会发生变化。也就是说,我不相信 id_print
有能力添加这样的文本,但你可以通过更通用的 host_callback.call
:
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call
@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
call(lambda x: print(f"There are {x} true bools"), jnp.sum(mask_cp))
overlaps_jax()
输出为
There are 5 true bools