计算编辑距离(feed_dict 错误)
Computing Edit Distance (feed_dict error)
我在 Tensorflow 中编写了一些代码来计算一个字符串和一组字符串之间的编辑距离。我无法找出错误。
import tensorflow as tf
sess = tf.Session()
# Create input data
test_string = ['foo']
ref_strings = ['food', 'bar']
def create_sparse_vec(word_list):
num_words = len(word_list)
indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)]
chars = list(''.join(word_list))
return(tf.SparseTensor(indices, chars, [num_words,1,1]))
test_string_sparse = create_sparse_vec(test_string*len(ref_strings))
ref_string_sparse = create_sparse_vec(ref_strings)
sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True))
此代码有效,当 运行 时,它产生输出:
array([[ 0.25],
[ 1. ]], dtype=float32)
但是当我尝试通过稀疏占位符输入稀疏张量来执行此操作时,出现错误。
test_input = tf.sparse_placeholder(dtype=tf.string)
ref_input = tf.sparse_placeholder(dtype=tf.string)
edit_distances = tf.edit_distance(test_input, ref_input, normalize=True)
feed_dict = {test_input: test_string_sparse,
ref_input: ref_string_sparse}
sess.run(edit_distances, feed_dict=feed_dict)
这是错误回溯:
Traceback (most recent call last):
File "<ipython-input-29-4e06de0b7af3>", line 1, in <module>
sess.run(edit_distances, feed_dict=feed_dict)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run
run_metadata_ptr)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run
for subfeed, subfeed_val in _feed_fn(feed, feed_val):
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn
return feed_fn(feed, feed_val)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda>
[feed.indices, feed.values, feed.shape], feed_val)),
TypeError: zip argument #2 must support iteration
知道这里发生了什么吗?
TL;DR: 对于 create_sparse_vec()
的 return 类型,使用 tf.SparseTensorValue
instead of tf.SparseTensor
.
这里的问题来自create_sparse_vec()
的return类型,即tf.SparseTensor
, and is not understood as a feed value in the call to sess.run()
。
当您输入(密集)tf.Tensor
时,预期值类型是 NumPy 数组(或某些可以转换为数组的对象)。当您输入 tf.SparseTensor
时,期望值类型是 tf.SparseTensorValue
,它与 tf.SparseTensor
类似,但它的 indices
、values
和 [=23] =] 属性是 NumPy 数组(或某些可以转换为数组的对象,例如示例中的列表。
以下代码应该有效:
def create_sparse_vec(word_list):
num_words = len(word_list)
indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)]
chars = list(''.join(word_list))
return tf.SparseTensorValue(indices, chars, [num_words,1,1])
我在 Tensorflow 中编写了一些代码来计算一个字符串和一组字符串之间的编辑距离。我无法找出错误。
import tensorflow as tf
sess = tf.Session()
# Create input data
test_string = ['foo']
ref_strings = ['food', 'bar']
def create_sparse_vec(word_list):
num_words = len(word_list)
indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)]
chars = list(''.join(word_list))
return(tf.SparseTensor(indices, chars, [num_words,1,1]))
test_string_sparse = create_sparse_vec(test_string*len(ref_strings))
ref_string_sparse = create_sparse_vec(ref_strings)
sess.run(tf.edit_distance(test_string_sparse, ref_string_sparse, normalize=True))
此代码有效,当 运行 时,它产生输出:
array([[ 0.25],
[ 1. ]], dtype=float32)
但是当我尝试通过稀疏占位符输入稀疏张量来执行此操作时,出现错误。
test_input = tf.sparse_placeholder(dtype=tf.string)
ref_input = tf.sparse_placeholder(dtype=tf.string)
edit_distances = tf.edit_distance(test_input, ref_input, normalize=True)
feed_dict = {test_input: test_string_sparse,
ref_input: ref_string_sparse}
sess.run(edit_distances, feed_dict=feed_dict)
这是错误回溯:
Traceback (most recent call last):
File "<ipython-input-29-4e06de0b7af3>", line 1, in <module>
sess.run(edit_distances, feed_dict=feed_dict)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 372, in run
run_metadata_ptr)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 597, in _run
for subfeed, subfeed_val in _feed_fn(feed, feed_val):
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 558, in _feed_fn
return feed_fn(feed, feed_val)
File "/usr/local/lib/python3.4/dist-packages/tensorflow/python/client/session.py", line 268, in <lambda>
[feed.indices, feed.values, feed.shape], feed_val)),
TypeError: zip argument #2 must support iteration
知道这里发生了什么吗?
TL;DR: 对于 create_sparse_vec()
的 return 类型,使用 tf.SparseTensorValue
instead of tf.SparseTensor
.
这里的问题来自create_sparse_vec()
的return类型,即tf.SparseTensor
, and is not understood as a feed value in the call to sess.run()
。
当您输入(密集)tf.Tensor
时,预期值类型是 NumPy 数组(或某些可以转换为数组的对象)。当您输入 tf.SparseTensor
时,期望值类型是 tf.SparseTensorValue
,它与 tf.SparseTensor
类似,但它的 indices
、values
和 [=23] =] 属性是 NumPy 数组(或某些可以转换为数组的对象,例如示例中的列表。
以下代码应该有效:
def create_sparse_vec(word_list):
num_words = len(word_list)
indices = [[xi, 0, yi] for xi,x in enumerate(word_list) for yi,y in enumerate(x)]
chars = list(''.join(word_list))
return tf.SparseTensorValue(indices, chars, [num_words,1,1])