将 tf.data.Dataset 转换为 jax.numpy 迭代器
Turn a tf.data.Dataset to a jax.numpy iterator
我对使用 JAX 训练神经网络很感兴趣。我查看了 tf.data.Dataset
,但它只提供 tf 张量。我寻找了一种将数据集更改为 JAX numpy 数组的方法,并且发现了很多使用 Dataset.as_numpy_generator()
将 tf 张量转换为 numpy 数组的实现。但是我想知道这是否是一个好习惯,因为 numpy 数组存储在 CPU 内存中,这不是我想要的训练(我使用 GPU)。所以我发现的最后一个想法是通过调用 jnp.array
手动重铸数组,但这并不是很优雅(我担心 GPU 内存中的副本)。有人对此有更好的主意吗?
快速代码说明:
import os
import jax.numpy as jnp
import tensorflow as tf
def generator():
for _ in range(2):
yield tf.random.uniform((1, ))
ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
output_shapes=tf.TensorShape([1]))
ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)
for i, batch in enumerate(ds1):
print(type(batch))
for i, batch in enumerate(ds2):
print(type(jnp.array(batch)))
# returns:
<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant
tensorflow 和 JAX 都可以在不复制内存的情况下将数组转换为 dlpack 张量,因此可以从张量流数组创建 JAX 数组而不复制底层数据缓冲区的一种方法是通过dlpack:
import numpy as np
import tensorflow as tf
import jax.dlpack
tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
np.testing.assert_array_equal(tf_arr, jax_arr)
通过往返 JAX,您可以比较 unsafe_buffer_pointer()
以确保数组指向同一缓冲区,而不是沿途复制缓冲区:
def tf_to_jax(arr):
return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_arr))
def jax_to_tf(arr):
return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))
jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)
print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True
我对使用 JAX 训练神经网络很感兴趣。我查看了 tf.data.Dataset
,但它只提供 tf 张量。我寻找了一种将数据集更改为 JAX numpy 数组的方法,并且发现了很多使用 Dataset.as_numpy_generator()
将 tf 张量转换为 numpy 数组的实现。但是我想知道这是否是一个好习惯,因为 numpy 数组存储在 CPU 内存中,这不是我想要的训练(我使用 GPU)。所以我发现的最后一个想法是通过调用 jnp.array
手动重铸数组,但这并不是很优雅(我担心 GPU 内存中的副本)。有人对此有更好的主意吗?
快速代码说明:
import os
import jax.numpy as jnp
import tensorflow as tf
def generator():
for _ in range(2):
yield tf.random.uniform((1, ))
ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
output_shapes=tf.TensorShape([1]))
ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)
for i, batch in enumerate(ds1):
print(type(batch))
for i, batch in enumerate(ds2):
print(type(jnp.array(batch)))
# returns:
<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant
tensorflow 和 JAX 都可以在不复制内存的情况下将数组转换为 dlpack 张量,因此可以从张量流数组创建 JAX 数组而不复制底层数据缓冲区的一种方法是通过dlpack:
import numpy as np
import tensorflow as tf
import jax.dlpack
tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
np.testing.assert_array_equal(tf_arr, jax_arr)
通过往返 JAX,您可以比较 unsafe_buffer_pointer()
以确保数组指向同一缓冲区,而不是沿途复制缓冲区:
def tf_to_jax(arr):
return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(tf_arr))
def jax_to_tf(arr):
return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))
jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)
print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True