在张量上使用 while_loop 在张量流中创建掩码
using while_loop over the tensor for creating a mask in tensorflow
我想创建一个迭代张量的掩码。
我有这个代码:
import tensorflow as tf
out = tf.Variable(tf.zeros_like(alp, dtype=tf.int32))
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
我想遍历 rows_tf
并相应地 columns_tf
在 out
上创建一个掩码。
例如,它将屏蔽 out
张量中 [1,1] [2,1] and [5,1]
处的索引等于 1
.
对于 rows_tf
的第二行,out 张量中 [1,2] [2,2] [5,2]
处的索引将被设置为 1
等等,总共 8 行
到目前为止我已经这样做了,虽然它没有 运行 成功:
body = lambda k, i: (tf.add(out[rows_tf[i][k]][columns_tf[i][i]], 1)) # find the corresponding element in out tensor and add 1 to it (0+1=1)
k = 0
n2, m2 = rows_tf.shape
for i in tf.range(0,n2): # loop through rows in rows_tf
cond = lambda k, _: tf.less(k, m2) #this check to go over the columns in rows_tf
tf.while_loop(cond, body, (k, i))
它引发了这个错误:
TypeError: Cannot iterate over a scalar tensor.
in this while cond(*loop_vars):
我浏览了几个链接,即 以确保我按照说明进行操作,但无法修复这个链接。
感谢帮助
你可以不用循环使用 tf.scatter_nd
像这样:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
out = tf.zeros([10, 4], dtype=tf.int32)
rows_tf = tf.constant(
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]], dtype=tf.int32)
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]], dtype=tf.int32)
# Broadcast columns
columns_bc = tf.broadcast_to(columns_tf, tf.shape(rows_tf))
# Scatter values to indices
scatter_idx = tf.stack([rows_tf, columns_bc], axis=-1)
mask = tf.scatter_nd(scatter_idx, tf.ones_like(rows_tf, dtype=tf.bool), tf.shape(out))
print(sess.run(mask))
输出:
[[False False False False]
[False True True True]
[False True True True]
[False False True True]
[False False True True]
[False True True True]
[False False True True]
[False False True False]
[False False False False]
[False False False False]]
或者,您也可以仅使用布尔运算来执行此操作:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
out = tf.zeros([10, 4], dtype=tf.int32)
rows_tf = tf.constant(
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]], dtype=tf.int32)
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]], dtype=tf.int32)
# Compare indices
row_eq = tf.equal(tf.range(out.shape[0])[:, tf.newaxis],
rows_tf[..., np.newaxis, np.newaxis])
col_eq = tf.equal(tf.range(out.shape[1])[tf.newaxis, :],
columns_tf[..., np.newaxis, np.newaxis])
# Aggregate
mask = tf.reduce_any(row_eq & col_eq, axis=[0, 1])
print(sess.run(mask))
# Same as before
然而,这原则上会占用更多内存。
我想创建一个迭代张量的掩码。 我有这个代码:
import tensorflow as tf
out = tf.Variable(tf.zeros_like(alp, dtype=tf.int32))
rows_tf = tf.constant (
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]])
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]])
我想遍历 rows_tf
并相应地 columns_tf
在 out
上创建一个掩码。
例如,它将屏蔽 out
张量中 [1,1] [2,1] and [5,1]
处的索引等于 1
.
对于 rows_tf
的第二行,out 张量中 [1,2] [2,2] [5,2]
处的索引将被设置为 1
等等,总共 8 行
到目前为止我已经这样做了,虽然它没有 运行 成功:
body = lambda k, i: (tf.add(out[rows_tf[i][k]][columns_tf[i][i]], 1)) # find the corresponding element in out tensor and add 1 to it (0+1=1)
k = 0
n2, m2 = rows_tf.shape
for i in tf.range(0,n2): # loop through rows in rows_tf
cond = lambda k, _: tf.less(k, m2) #this check to go over the columns in rows_tf
tf.while_loop(cond, body, (k, i))
它引发了这个错误:
TypeError: Cannot iterate over a scalar tensor.
in this while cond(*loop_vars):
我浏览了几个链接,即
感谢帮助
你可以不用循环使用 tf.scatter_nd
像这样:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
out = tf.zeros([10, 4], dtype=tf.int32)
rows_tf = tf.constant(
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]], dtype=tf.int32)
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]], dtype=tf.int32)
# Broadcast columns
columns_bc = tf.broadcast_to(columns_tf, tf.shape(rows_tf))
# Scatter values to indices
scatter_idx = tf.stack([rows_tf, columns_bc], axis=-1)
mask = tf.scatter_nd(scatter_idx, tf.ones_like(rows_tf, dtype=tf.bool), tf.shape(out))
print(sess.run(mask))
输出:
[[False False False False]
[False True True True]
[False True True True]
[False False True True]
[False False True True]
[False True True True]
[False False True True]
[False False True False]
[False False False False]
[False False False False]]
或者,您也可以仅使用布尔运算来执行此操作:
import tensorflow as tf
with tf.Graph().as_default(), tf.Session() as sess:
out = tf.zeros([10, 4], dtype=tf.int32)
rows_tf = tf.constant(
[[1, 2, 5],
[1, 2, 5],
[1, 2, 5],
[1, 4, 6],
[1, 4, 6],
[2, 3, 6],
[2, 3, 6],
[2, 4, 7]], dtype=tf.int32)
columns_tf = tf.constant(
[[1],
[2],
[3],
[2],
[3],
[2],
[3],
[2]], dtype=tf.int32)
# Compare indices
row_eq = tf.equal(tf.range(out.shape[0])[:, tf.newaxis],
rows_tf[..., np.newaxis, np.newaxis])
col_eq = tf.equal(tf.range(out.shape[1])[tf.newaxis, :],
columns_tf[..., np.newaxis, np.newaxis])
# Aggregate
mask = tf.reduce_any(row_eq & col_eq, axis=[0, 1])
print(sess.run(mask))
# Same as before
然而,这原则上会占用更多内存。