如何获取tf.data.dataset的形状?

How to acquire tf.data.dataset's shape?

我知道数据集有 output_shapes,但它显示如下:

data_set: DatasetV1Adapter shapes: {item_id_hist: (?, ?), tags: (?, ?), client_platform: (?,), entrance: (?,), item_id: (?,), lable: (?,), mode: (?,), time: (?,), user_id: (?,)}, types: {item_id_hist: tf.int64, tags: tf.int64, client_platform: tf.string, entrance: tf.string, item_id: tf.int64, lable: tf.int64, mode: tf.int64, time: tf.int64, user_id: tf.int64}

如何获取我的数据总数?

知道长度的地方可以调用:

tf.data.experimental.cardinality(dataset)

但是如果这失败了,重要的是要知道 TensorFlow Dataset 是(通常)惰性评估的,所以这意味着在一般情况下我们可能需要遍历每条记录才能找到数据集的长度。

例如,假设您启用了急切执行并且它是一个适合内存的小型 'toy' 数据集,您可以 enumerate 将它放入一个新列表并获取最后一个索引(然后添加1 因为列表是 zero-indexed):

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1

当然,这充其量是低效的,对于大型数据集,将完全失败,因为所有内容都需要适合列表的内存。在这种情况下,除了遍历保持手动计数的记录外,我看不到任何其他选择。

代码如下:

dataset_to_numpy = list(dataset.as_numpy_iterator())
shape = tf.shape(dataset_to_numpy)
print(shape)

它产生这样的输出:

tf.Tensor([1080   64   64    3], shape=(4,), dtype=int32)

写代码很简单,但是迭代数据集还是很费时间的。 有关 tf.data.Dataset 的更多信息,请查看此 link

要查看元素形状和类型,直接打印数据集元素而不是使用 as_numpy_iterator。 - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
  print(element)

打破 for 循环以查看任何张量的形状

dataset = tf.data.Dataset.from_tensor_slices((X_s, y_s))
for element in dataset:
  print(element)
  break

此处输出两个 numpy 数组并打印每个数组的形状

(<tf.Tensor: shape=(13,), dtype=float32, numpy=
array([ 0.9521966 ,  0.68100524,  1.973123  ,  0.7639558 , -0.2563337 ,
        2.394438  , -1.0058318 ,  0.01544279, -0.69663054,  1.0873381 ,
       -2.2745786 , -0.71442884, -2.1488726 ], dtype=float32)>, <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 1.], dtype=float32)>)

自 2022 年 4 月 15 日起使用 TF v2.8,您可以使用

获取结果

dataset.cardinality().numpy()

参考:https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cardinality