Tensorflow 计算二进制矩阵乘法
Tensorflow compute multiplication by binary matrix
我有一个形状为 [batch_size,512]
的数据张量,我有一个只有 0 和 1 的常量矩阵,其形状为 [256,512]
.
我想为每个批次有效地计算我的向量(数据张量的第二维)的乘积之和,仅针对 1 而不是 0 的条目。
解释示例:
假设我有 1 个大小的批处理:数据张量的值为 [5,4,3,7,8,2]
,我的常量矩阵的值为:
[0,1,1,0,0,0]
[1,0,0,0,0,0]
[1,1,1,0,0,1]
这意味着我想计算第一行 4*3
、第二行 5
和第三行 5*4*3*2
。
对于这批,我得到 4*3+5+5*4*3*2
等于 137。
目前,我通过遍历所有行,逐元素计算我的数据和常量矩阵行的乘积,然后求和,这运行得非常慢。
这样的事情怎么样:
import tensorflow as tf
# Two-element batch
data = [[5, 4, 3, 7, 8, 2],
[4, 2, 6, 1, 6, 8]]
mask = [[0, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
# Data as tensors
d = tf.constant(data, dtype=tf.int32)
m = tf.constant(mask, dtype=tf.int32)
# Tile data as needed
dd = tf.tile(d[:, tf.newaxis], (1, tf.shape(m)[0], 1))
mm = tf.tile(m[tf.newaxis, :], (tf.shape(d)[0], 1, 1))
# Replace values with 1 wherever the mask is 0
w = tf.where(tf.cast(mm, tf.bool), dd, tf.ones_like(dd))
# Multiply row-wise and sum
result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
print(sess.run(result))
输出:
[137 400]
编辑:
如果您输入的数据是单个向量,那么您只需:
import tensorflow as tf
# Two-element batch
data = [5, 4, 3, 7, 8, 2]
mask = [[0, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
# Data as tensors
d = tf.constant(data, dtype=tf.int32)
m = tf.constant(mask, dtype=tf.int32)
# Tile data as needed
dd = tf.tile(d[tf.newaxis], (tf.shape(m)[0], 1))
# Replace values with 1 wherever the mask is 0
w = tf.where(tf.cast(m, tf.bool), dd, tf.ones_like(dd))
# Multiply row-wise and sum
result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
print(sess.run(result))
输出:
137
我有一个形状为 [batch_size,512]
的数据张量,我有一个只有 0 和 1 的常量矩阵,其形状为 [256,512]
.
我想为每个批次有效地计算我的向量(数据张量的第二维)的乘积之和,仅针对 1 而不是 0 的条目。
解释示例:
假设我有 1 个大小的批处理:数据张量的值为 [5,4,3,7,8,2]
,我的常量矩阵的值为:
[0,1,1,0,0,0]
[1,0,0,0,0,0]
[1,1,1,0,0,1]
这意味着我想计算第一行 4*3
、第二行 5
和第三行 5*4*3*2
。
对于这批,我得到 4*3+5+5*4*3*2
等于 137。
目前,我通过遍历所有行,逐元素计算我的数据和常量矩阵行的乘积,然后求和,这运行得非常慢。
这样的事情怎么样:
import tensorflow as tf
# Two-element batch
data = [[5, 4, 3, 7, 8, 2],
[4, 2, 6, 1, 6, 8]]
mask = [[0, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
# Data as tensors
d = tf.constant(data, dtype=tf.int32)
m = tf.constant(mask, dtype=tf.int32)
# Tile data as needed
dd = tf.tile(d[:, tf.newaxis], (1, tf.shape(m)[0], 1))
mm = tf.tile(m[tf.newaxis, :], (tf.shape(d)[0], 1, 1))
# Replace values with 1 wherever the mask is 0
w = tf.where(tf.cast(mm, tf.bool), dd, tf.ones_like(dd))
# Multiply row-wise and sum
result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
print(sess.run(result))
输出:
[137 400]
编辑:
如果您输入的数据是单个向量,那么您只需:
import tensorflow as tf
# Two-element batch
data = [5, 4, 3, 7, 8, 2]
mask = [[0, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 1]]
with tf.Graph().as_default(), tf.Session() as sess:
# Data as tensors
d = tf.constant(data, dtype=tf.int32)
m = tf.constant(mask, dtype=tf.int32)
# Tile data as needed
dd = tf.tile(d[tf.newaxis], (tf.shape(m)[0], 1))
# Replace values with 1 wherever the mask is 0
w = tf.where(tf.cast(m, tf.bool), dd, tf.ones_like(dd))
# Multiply row-wise and sum
result = tf.reduce_sum(tf.reduce_prod(w, axis=-1), axis=-1)
print(sess.run(result))
输出:
137