数据集中张量的 ref() 不相等。为什么?

ref() of tensor not equal in dataset. Why?

我对以下行为感到很困惑。参加这个节目:

import tensorflow_datasets as tfds

# %% Train dataset
(ds_train_original, ds_test_original), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

iterator = iter(ds_train_original)
el = iterator.get_next()[0]
el[0].ref() == el[0].ref()   # <- this should be True

最后一行 IMO 应该 return True。然而,这是False。 我不明白为什么。

根据 ref 文档:

Returns a hashable reference object to this Tensor. The primary use case for this API is to put tensors in a set/dictionary.

我的理解是,您应该能够使用 ref() 来检查 Tensor 之间的相等性。 一旦我提取了 ref,问题就不会再发生了。 例如,这是真的:

a_ref = el[0].ref()
a_deref = a_ref.deref()
another_ref = a_deref.ref()
a_ref == another_ref

所以“问题”似乎仅限于从 iterator.

中提取 ref()

任何人都可以向我解释发生了什么以及为什么 el[0].ref() == el[0].ref()False 吗?

发布后 issue on Github, it seems like the only viable solution is to compare the samples values, since only weakrefs 被创建。

因此解决方案是:

import tensorflow_datasets as tfds

# %% Train dataset
(ds_train_original, ds_test_original), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

iterator = iter(ds_train_original)
el = iterator.get_next()[0]
(el[0].numpy() == el[0].numpy()).all()