使用 TF 数据集和 Eager 创建有状态计数器
Create stateful counter with TF Dataset and Eager
我正在尝试在我的 Tensorflow 数据集管道中添加一个累加器。基本上,我有这个:
def _filter_bcc_labels(self, labels, labels_table, bcc_count):
bg_counter = tf.zeros(shape=(), dtype=tf.int32)
def _add_to_counter():
tf.add(bg_counter, 1)
# Here the bg_counter is always equal to 0
tf.Print(bg_counter, [bg_counter])
return tf.constant(True)
return tf.cond(tf.greater_equal(bg_counter, tf.constant(bcc_count, dtype=tf.int32)),
true_fn=lambda: tf.constant(False),
false_fn=_add_to_counter)
ds = ds.filter(lambda file, position, img, lbls: self._filter_bcc_labels(lbls, {"BCC": 0, "BACKGROUND": 1}, 10))
我的目标是在达到 tf.cond
false_fn
时递增 bg_counter
但我的变量始终具有值 0,它实际上从未递增。
有人可以向我解释发生了什么吗?
请记住,我正在使用 TF eager,我不能使用 ds.make_initializable_iterator()
然后输入我的 bg_counter
初始值。
谢谢
我认为您要执行的操作需要 assign_add() 方法而不是添加方法。请注意,参数必须是一个变量。
还请注意 tf.cond 在 eager 之外的一般用途。 是关于相同的讨论。
您可能希望将计数器包装在 class 中,因为 Eager 中的变量在 运行 超出范围时会被删除。
代码:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow.contrib.eager as tfe
dataset = tf.data.Dataset.from_tensor_slices(([1,2,3,4,5], [-1,-2,-3,-4,-5]))
class My(object):
def __init__(self):
self.x = tf.get_variable("mycounter", initializer=lambda: tf.zeros(shape=[], dtype=tf.float32), dtype=tf.float32
, trainable=False)
v = My()
print(v.x)
tf.assign(v.x,tf.add(v.x,1.0))
print(v.x)
def map_fn(x,v):
tf.cond(tf.greater_equal(v.x, tf.constant(5.0))
,lambda: tf.constant(0.0)
,lambda: tf.assign(v.x,tf.add(v.x,1.0))
)
return x
dataset = dataset.map(lambda x,y: map_fn(x,v)).batch(1)
for batch in tfe.Iterator(dataset):
print("{} | {}".format(batch, v.x))
日志:
<tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=0.0>
<tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=1.0>
[1] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=2.0>
[2] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=3.0>
[3] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=4.0>
[4] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=5.0>
[5] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=5.0>
工作示例:
https://www.kaggle.com/mpekalski/tfe-conditional-stateful-counter
感谢@MPękalski 为我指明了正确的方向,我实际上找到了问题的答案。
代码现在看起来像这样:
def _filter_bcc_labels(self, bg_counter, labels, labels_table, bcc_count):
bg_counter = tf.zeros(shape=(), dtype=tf.int32)
def _add_to_counter():
nonlocal bg_counter
bg_counter.assign_add(1)
# Prints the counter value
tf.Print(bg_counter, [bg_counter])
return tf.constant(True)
return tf.cond(tf.greater_equal(bg_counter, tf.constant(bcc_count, dtype=tf.int32)),
true_fn=lambda: tf.constant(False),
false_fn=_add_to_counter)
bg_counter = tf.get_variable("bg_counter_" + step, initializer=lambda: tf.zeros(shape=[], dtype=tf.int32), dtype=tf.int32, trainable=False)
ds = ds.filter(lambda file, position, img, lbls: self._filter_bcc_labels(bg_counter, lbls, {"BCC": 0, "BACKGROUND": 1}, 10))
请记住,如果您在数据集上迭代两次,此解决方案将不起作用,因为在这种情况下计数器不会重新初始化。如果你将 bg_counter = tf.get_variable("bg_counter_" + step, initializer=lambda: tf.zeros(shape=[], dtype=tf.int32), dtype=tf.int32, trainable=False)
移到 ds.filter
中,那么你会得到一个 'Tensor' object has no attribute 'assign_add'
因为急切模式。
如果您真的想以正确的方式做到这一点,那么您必须在迭代数据集管道之外的批次时创建一个计数器。
我正在尝试在我的 Tensorflow 数据集管道中添加一个累加器。基本上,我有这个:
def _filter_bcc_labels(self, labels, labels_table, bcc_count):
bg_counter = tf.zeros(shape=(), dtype=tf.int32)
def _add_to_counter():
tf.add(bg_counter, 1)
# Here the bg_counter is always equal to 0
tf.Print(bg_counter, [bg_counter])
return tf.constant(True)
return tf.cond(tf.greater_equal(bg_counter, tf.constant(bcc_count, dtype=tf.int32)),
true_fn=lambda: tf.constant(False),
false_fn=_add_to_counter)
ds = ds.filter(lambda file, position, img, lbls: self._filter_bcc_labels(lbls, {"BCC": 0, "BACKGROUND": 1}, 10))
我的目标是在达到 tf.cond
false_fn
时递增 bg_counter
但我的变量始终具有值 0,它实际上从未递增。
有人可以向我解释发生了什么吗?
请记住,我正在使用 TF eager,我不能使用 ds.make_initializable_iterator()
然后输入我的 bg_counter
初始值。
谢谢
我认为您要执行的操作需要 assign_add() 方法而不是添加方法。请注意,参数必须是一个变量。
还请注意 tf.cond 在 eager 之外的一般用途。
您可能希望将计数器包装在 class 中,因为 Eager 中的变量在 运行 超出范围时会被删除。
代码:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
tf.enable_eager_execution()
import tensorflow.contrib.eager as tfe
dataset = tf.data.Dataset.from_tensor_slices(([1,2,3,4,5], [-1,-2,-3,-4,-5]))
class My(object):
def __init__(self):
self.x = tf.get_variable("mycounter", initializer=lambda: tf.zeros(shape=[], dtype=tf.float32), dtype=tf.float32
, trainable=False)
v = My()
print(v.x)
tf.assign(v.x,tf.add(v.x,1.0))
print(v.x)
def map_fn(x,v):
tf.cond(tf.greater_equal(v.x, tf.constant(5.0))
,lambda: tf.constant(0.0)
,lambda: tf.assign(v.x,tf.add(v.x,1.0))
)
return x
dataset = dataset.map(lambda x,y: map_fn(x,v)).batch(1)
for batch in tfe.Iterator(dataset):
print("{} | {}".format(batch, v.x))
日志:
<tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=0.0>
<tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=1.0>
[1] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=2.0>
[2] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=3.0>
[3] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=4.0>
[4] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=5.0>
[5] | <tf.Variable 'mycounter:0' shape=() dtype=float32, numpy=5.0>
工作示例: https://www.kaggle.com/mpekalski/tfe-conditional-stateful-counter
感谢@MPękalski 为我指明了正确的方向,我实际上找到了问题的答案。 代码现在看起来像这样:
def _filter_bcc_labels(self, bg_counter, labels, labels_table, bcc_count):
bg_counter = tf.zeros(shape=(), dtype=tf.int32)
def _add_to_counter():
nonlocal bg_counter
bg_counter.assign_add(1)
# Prints the counter value
tf.Print(bg_counter, [bg_counter])
return tf.constant(True)
return tf.cond(tf.greater_equal(bg_counter, tf.constant(bcc_count, dtype=tf.int32)),
true_fn=lambda: tf.constant(False),
false_fn=_add_to_counter)
bg_counter = tf.get_variable("bg_counter_" + step, initializer=lambda: tf.zeros(shape=[], dtype=tf.int32), dtype=tf.int32, trainable=False)
ds = ds.filter(lambda file, position, img, lbls: self._filter_bcc_labels(bg_counter, lbls, {"BCC": 0, "BACKGROUND": 1}, 10))
请记住,如果您在数据集上迭代两次,此解决方案将不起作用,因为在这种情况下计数器不会重新初始化。如果你将 bg_counter = tf.get_variable("bg_counter_" + step, initializer=lambda: tf.zeros(shape=[], dtype=tf.int32), dtype=tf.int32, trainable=False)
移到 ds.filter
中,那么你会得到一个 'Tensor' object has no attribute 'assign_add'
因为急切模式。
如果您真的想以正确的方式做到这一点,那么您必须在迭代数据集管道之外的批次时创建一个计数器。