尝试使用 tf.scan() 实现循环网络
Trying to implement recurrent network with tf.scan()
我正在尝试使用 tf.scan
实现循环状态张量。我现在的代码是这样的:
import tensorflow as tf
import math
import numpy as np
# Width of one input row: second dimension of the `inputs` placeholder and
# first dimension of the weight matrix 'W'.
INPUTS = 10
# Hidden-layer size: second dimension of 'W', length of 'bias', and length of
# the initial state and initial output vectors.
HIDDEN_1 = 20
# Number of rows in each batch fed to the `inputs` placeholder.
BATCH_SIZE = 3
def iterate_state(prev_state_tuple, input):
    """Compute one recurrent step; intended as the ``fn`` argument of ``tf.scan``.

    Args:
        prev_state_tuple: 1-D tensor holding the previous state followed by
            the previous output, concatenated along dimension 0 (total
            length ``2 * HIDDEN_1``).
        input: the element handed over for the current step -- one row of
            the scanned tensor, expected to have length ``INPUTS``. The name
            shadows the builtin but is kept so the signature is unchanged.

    Returns:
        A 1-D tensor ``concat([state, output])`` laid out exactly like
        ``prev_state_tuple`` so it can be fed back in as the accumulator.
    """
    with tf.name_scope('h1'):
        weights = tf.get_variable(
            'W', shape=[INPUTS, HIDDEN_1],
            initializer=tf.truncated_normal_initializer(
                stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable(
            'bias', shape=[HIDDEN_1],
            initializer=tf.constant_initializer(0.0))
        # Use the per-step `input` argument rather than the module-level
        # `inputs` placeholder: the placeholder holds the whole batch, which
        # would make every step see all rows and break the state shape.
        # tf.matmul needs rank-2 operands, so treat the row as a 1 x INPUTS
        # matrix and flatten the result back to a HIDDEN_1 vector.
        row = tf.reshape(input, [1, INPUTS])
        matmuladd = tf.reshape(tf.matmul(row, weights) + biases, [HIDDEN_1])
        # Only the state half of the accumulator feeds the recurrence; the
        # previous output half is discarded.
        prev_state, _ = tf.split(0, 2, prev_state_tuple)
        # Blend 90% of the old state with 10% of the new activation.
        state = 0.9 * prev_state + 0.1 * matmuladd
        output = tf.nn.relu(state)
        return tf.concat(0, [state, output])
def data_iter():
    """Generate random input batches forever.

    Yields:
        A new ``(BATCH_SIZE, INPUTS)`` array from ``np.random.rand`` on each
        iteration. The generator never terminates on its own.
    """
    while True:
        yield np.random.rand(BATCH_SIZE, INPUTS)
# Build the graph in its own tf.Graph, then push two random batches through it.
with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))

    with tf.variable_scope('states'):
        # Zero-filled starting state and output, packed end to end in the
        # same [state, output] layout that iterate_state() returns.
        initial_state = tf.zeros([HIDDEN_1], name='initial_state')
        initial_out = tf.zeros([HIDDEN_1], name='initial_out')
        concat_tensor = tf.concat(0, [initial_state, initial_out])
        # NOTE: unpacking the tf.scan() result into two names is the
        # statement that raises "TypeError: 'Tensor' object is not iterable",
        # because tf.scan() returns a single tensor here.
        states, output = tf.scan(
            iterate_state, inputs, initializer=concat_tensor, name='states')

    sess = tf.Session()
    # Run the Op to initialize the variables.
    sess.run(tf.initialize_all_variables())

    iter_ = data_iter()
    for i in xrange(2):
        print ("iteration: ",i)
        input_data = next(iter_)
        out, st = sess.run([output, states], feed_dict={inputs: input_data})
但是,当我运行这段代码时,出现了以下错误:
Traceback (most recent call last):
File "cycles_in_graphs_with_scan.py", line 37, in <module>
initializer=concat_tensor, name='states')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
(tensorflow)charlesq@Leviathan ~/projects/stuff $ python cycles_in_graphs_with_scan.py
Traceback (most recent call last):
File "cycles_in_graphs_with_scan.py", line 37, in <module>
initializer=concat_tensor, name='states')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
我已经尝试使用 pack/unpack
和 concat/split
,但我遇到了同样的错误。
有什么解决这个问题的想法吗?
你会得到这个错误,是因为 tf.scan()
返回的是单个 tf.Tensor
,所以下面这一行:
states, output = tf.scan(...)
...无法将 tf.scan()
返回的张量解构(解包)为两个值(states
和 output
)。实际上,这行代码试图把 tf.scan()
的结果当作长度为 2 的列表,把第一个元素赋给 states
,把第二个元素赋给 output
;但与 Python 列表或元组不同,tf.Tensor
不支持这种解包。
相反,您需要手动从 tf.scan()
的结果中提取值。例如,使用 tf.split()
:
scan_result = tf.scan(...)
# Assumes values are packed together along `split_dim`.
states, output = tf.split(split_dim, 2, scan_result)
或者,您可以使用 tf.slice()
或 tf.unpack()
提取相关的 states
和 output
值。
我正在尝试使用 tf.scan
实现循环状态张量。我现在的代码是这样的:
import tensorflow as tf
import math
import numpy as np
# Width of one input row: second dimension of the `inputs` placeholder and
# first dimension of the weight matrix 'W'.
INPUTS = 10
# Hidden-layer size: second dimension of 'W', length of 'bias', and length of
# the initial state and initial output vectors.
HIDDEN_1 = 20
# Number of rows in each batch fed to the `inputs` placeholder.
BATCH_SIZE = 3
def iterate_state(prev_state_tuple, input):
    """Compute one recurrent step; intended as the ``fn`` argument of ``tf.scan``.

    Args:
        prev_state_tuple: 1-D tensor holding the previous state followed by
            the previous output, concatenated along dimension 0 (total
            length ``2 * HIDDEN_1``).
        input: the element handed over for the current step -- one row of
            the scanned tensor, expected to have length ``INPUTS``. The name
            shadows the builtin but is kept so the signature is unchanged.

    Returns:
        A 1-D tensor ``concat([state, output])`` laid out exactly like
        ``prev_state_tuple`` so it can be fed back in as the accumulator.
    """
    with tf.name_scope('h1'):
        weights = tf.get_variable(
            'W', shape=[INPUTS, HIDDEN_1],
            initializer=tf.truncated_normal_initializer(
                stddev=1.0 / math.sqrt(float(INPUTS))))
        biases = tf.get_variable(
            'bias', shape=[HIDDEN_1],
            initializer=tf.constant_initializer(0.0))
        # Use the per-step `input` argument rather than the module-level
        # `inputs` placeholder: the placeholder holds the whole batch, which
        # would make every step see all rows and break the state shape.
        # tf.matmul needs rank-2 operands, so treat the row as a 1 x INPUTS
        # matrix and flatten the result back to a HIDDEN_1 vector.
        row = tf.reshape(input, [1, INPUTS])
        matmuladd = tf.reshape(tf.matmul(row, weights) + biases, [HIDDEN_1])
        # Only the state half of the accumulator feeds the recurrence; the
        # previous output half is discarded.
        prev_state, _ = tf.split(0, 2, prev_state_tuple)
        # Blend 90% of the old state with 10% of the new activation.
        state = 0.9 * prev_state + 0.1 * matmuladd
        output = tf.nn.relu(state)
        return tf.concat(0, [state, output])
def data_iter():
    """Generate random input batches forever.

    Yields:
        A new ``(BATCH_SIZE, INPUTS)`` array from ``np.random.rand`` on each
        iteration. The generator never terminates on its own.
    """
    while True:
        yield np.random.rand(BATCH_SIZE, INPUTS)
# Build the graph in its own tf.Graph, then push two random batches through it.
with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))

    with tf.variable_scope('states'):
        # Zero-filled starting state and output, packed end to end in the
        # same [state, output] layout that iterate_state() returns.
        initial_state = tf.zeros([HIDDEN_1], name='initial_state')
        initial_out = tf.zeros([HIDDEN_1], name='initial_out')
        concat_tensor = tf.concat(0, [initial_state, initial_out])
        # NOTE: unpacking the tf.scan() result into two names is the
        # statement that raises "TypeError: 'Tensor' object is not iterable",
        # because tf.scan() returns a single tensor here.
        states, output = tf.scan(
            iterate_state, inputs, initializer=concat_tensor, name='states')

    sess = tf.Session()
    # Run the Op to initialize the variables.
    sess.run(tf.initialize_all_variables())

    iter_ = data_iter()
    for i in xrange(2):
        print ("iteration: ",i)
        input_data = next(iter_)
        out, st = sess.run([output, states], feed_dict={inputs: input_data})
但是,当我运行这段代码时,出现了以下错误:
Traceback (most recent call last):
File "cycles_in_graphs_with_scan.py", line 37, in <module>
initializer=concat_tensor, name='states')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
(tensorflow)charlesq@Leviathan ~/projects/stuff $ python cycles_in_graphs_with_scan.py
Traceback (most recent call last):
File "cycles_in_graphs_with_scan.py", line 37, in <module>
initializer=concat_tensor, name='states')
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 442, in __iter__
raise TypeError("'Tensor' object is not iterable.")
TypeError: 'Tensor' object is not iterable.
我已经尝试使用 pack/unpack
和 concat/split
,但我遇到了同样的错误。
有什么解决这个问题的想法吗?
你会得到这个错误,是因为 tf.scan()
返回的是单个 tf.Tensor
,所以下面这一行:
states, output = tf.scan(...)
...无法将 tf.scan()
返回的张量解构(解包)为两个值(states
和 output
)。实际上,这行代码试图把 tf.scan()
的结果当作长度为 2 的列表,把第一个元素赋给 states
,把第二个元素赋给 output
;但与 Python 列表或元组不同,tf.Tensor
不支持这种解包。
相反,您需要手动从 tf.scan()
的结果中提取值。例如,使用 tf.split()
:
scan_result = tf.scan(...)
# Assumes values are packed together along `split_dim`.
states, output = tf.split(split_dim, 2, scan_result)
或者,您可以使用 tf.slice()
或 tf.unpack()
提取相关的 states
和 output
值。