TensorFlow 中与 Theano.function 等价的写法

Theano function equivalent in TensorFlow

我想知道 TensorFlow 中是否有与 Theano.function 等价的东西。

我想用下面这个惰性的 TensorFlow 构造来解决 Theano.function 中的 updates 问题:

class TensorFlowTheanoFunction(object):
    """Minimal ``theano.function``-like wrapper around a TensorFlow session.

    Positional call arguments are fed, in order, to the tensors in
    ``inputs``, and ``outputs`` is evaluated with ``session.run``.  This
    version has no support for Theano-style ``updates``.
    """

    def __init__(self, inputs, outputs, session):
        """Store the graph endpoints and the session used to evaluate them.

        Args:
            inputs: Indexable sequence of tensors (typically placeholders)
                fed from the positional arguments of ``__call__``.
            outputs: Passed unchanged as the fetches of ``session.run``
                (e.g. a single tensor or a list of tensors).
            session: Object with a ``run(fetches, feed_dict)`` method,
                e.g. a ``tf.Session``.
        """
        self._inputs = inputs
        self._outputs = outputs
        self.session = session

    def __call__(self, *args, **kwargs):
        """Feed ``args`` to ``inputs`` positionally and return the run result.

        Keyword arguments are accepted but ignored.

        Raises:
            IndexError: if more positional arguments are given than there
                are inputs (for list/tuple ``inputs``).
        """
        # Pair the i-th positional argument with the i-th input tensor.
        feeds = {self._inputs[pos]: value for pos, value in enumerate(args)}
        return self.session.run(self._outputs, feeds)

如果我想像在 Theano 中那样传入一个 updates 参数,应该如何修改这个惰性调用?我只是希望下面这行在 TensorFlow 中也能工作:

# Theano: a function with no inputs and no outputs whose only effect is to
# apply the (shared_variable, new_value) pairs in `updates`, i.e. copy each
# element of `params` into the corresponding element of `old_params`.
self.new = theano.function([], [], updates=zip(old_params, params))

只需修改该帖子中 Yaroslav 的代码,改用 tf.assign,并通过控制依赖(control dependencies)确保在赋值发生之前先计算输出:

import tensorflow as tf

class TensorFlowTheanoFunction(object):
  """``theano.function``-style callable built on a TF 1.x graph.

  Calling the object feeds its positional arguments, in order, to the
  tensors in ``inputs``, evaluates ``outputs`` in the default session and
  applies every ``(variable, replacement)`` pair in ``updates`` with
  ``tf.assign``.  The assign ops carry control dependencies on the outputs,
  so the returned values are computed before any variable is updated.

  Args:
    inputs: Indexable sequence of tensors (typically placeholders).
    outputs: A single tensor, or a list/tuple of tensors.  A list/tuple
      yields a list of results; anything else is treated as one tensor and
      its value is returned directly.
    updates: Iterable of ``(variable, replacement)`` pairs.  It is copied
      into a list, so one-shot iterators such as Python 3's ``zip(...)``
      keep working after the first call.
  """

  def __init__(self, inputs, outputs, updates=()):
    self._inputs = inputs
    self._outputs = outputs
    # Materialize once: an exhausted zip/generator would otherwise make every
    # call after the first silently skip the updates.
    self._updates = list(updates)
    # Decide list-vs-single from the container type.  Whether a Tensor is
    # iterable depends on the TF version, and catching TypeError would also
    # hide unrelated errors raised by tf.identity.
    self._output_is_list = isinstance(outputs, (list, tuple))
    # Built lazily on the first call, then reused: creating new identity and
    # assign ops on every call would grow the graph without bound.
    self._fetches = None
    self._num_outputs = 0

  def _build_fetches(self):
    """Create the identity and assign ops once and cache the fetch list."""
    if self._output_is_list:
      outputs = self._outputs
    else:
      outputs = [self._outputs]
    outputs_identity = [tf.identity(output) for output in outputs]
    # Assignments may only run after every output has been computed.
    with tf.control_dependencies(outputs_identity):
      assign_ops = [tf.assign(variable, replacement)
                    for variable, replacement in self._updates]
    self._num_outputs = len(outputs_identity)
    self._fetches = outputs_identity + assign_ops

  def __call__(self, *args, **kwargs):
    """Run the outputs and updates for one set of input values.

    Args:
      *args: Values fed positionally to ``inputs``.
      **kwargs: Accepted for signature compatibility; ignored.

    Returns:
      A list of output values when ``outputs`` was a list/tuple, otherwise
      the single output value.
    """
    feeds = {}
    for (argpos, arg) in enumerate(args):
      feeds[self._inputs[argpos]] = arg
    if self._fetches is None:
      self._build_fetches()
    # The assign ops' results follow the outputs in the fetch list; drop them.
    outputs_list = tf.get_default_session().run(
        self._fetches, feeds)[:self._num_outputs]
    if self._output_is_list:
      return outputs_list
    assert len(outputs_list) == 1
    return outputs_list[0]

# Two int32 placeholders fed positionally, plus a scalar int32 variable that
# starts at 0 and is incremented by 1 by the update on every call.
a = tf.placeholder(dtype=tf.int32)
b = tf.placeholder(dtype=tf.int32)
variable = tf.get_variable(
    "variable", shape=[], dtype=tf.int32, initializer=tf.zeros_initializer)
c = a + b + variable
d = a - b
# InteractiveSession installs itself as the default session, which is what
# TensorFlowTheanoFunction.__call__ fetches via tf.get_default_session().
sess = tf.InteractiveSession()
# NOTE: later TF 1.x releases deprecate this in favour of
# tf.global_variables_initializer().
sess.run(tf.initialize_all_variables())
# print(...) with a single argument prints the same on Python 2 and 3.
# List outputs: each call returns [c, d], computed before the increment.
f = TensorFlowTheanoFunction([a, b], [c, d], updates=[(variable, variable + 1)])
print(f(1, 2))
print(f(1, 2))
print(f(0, 2))
# Single output: each call returns the value of c directly.
f = TensorFlowTheanoFunction([a, b], c, updates=[(variable, variable + 1)])
print(f(1, 2))
print(f(1, 2))
print(f(0, 2))

这样每次调用都会更新变量(输出是在赋值之前计算的):

[3, -1]
[4, -1]
[4, -2]
6
7
7