one_hot 批处理的编码将是不完整的 tensorflow

one_hot encoding for batches will be incomplete tensorflow

如您所知,tf.one_hot 可以进行一次性编码。但是,当我的数据集非常大时,我需要进行批量训练。这样,当我使用 for 循环遍历所有批次时,在每次迭代中,当我执行 tf.one_hot 时,一个热矩阵的维度将小于我的预期。

例如,对于列 'a',我们有 47 个类别,但在一个批次中它们可能只显示 20 个,当我对这批次执行 one_hot 时,它将创建一个矩阵行的维度 * 20 而不是行的维度 * 47。

如何在每个批次中得到一个行数 * 47 的一个热矩阵?

谢谢!

tf.one_hot() 将参数 depth 作为其第二个参数,它决定了 one-hot 向量应该有多长。如果你运行你的操作是这样的:

b = tf.one_hot( a, 47 )

它应该给你最后一个维度 47。

没有代码很难说,但有些人不会硬编码 one_hot 大小,而是尝试从标签张量中获取它,比如

max_class = tf.reduce_max( a )
b = tf.one_hot( a, max_class )

如果您的代码是这种情况,那么可能一批只增加到 class 20。

不然需要看你的代码才能说点什么

如果 TensorFlow 运行内存不足,它会因错误而停止,不会默默地吞掉一半的数据。 :)