TensorFlow: filter out batch elements that contain a zero
I have batched tensors X and Y like this:
X = tf.constant([[[1,-2], [2,0], [-2,2], [4,-1]],
                 [[3,1], [4,1], [0,1], [-5,3]],   # <- zero in the first column
                 [[5,-4], [6,-2], [-2,1], [-2,2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])
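X actually has shape TensorShape([512, 30, 57]).
I want to drop every element along dimension 0 in which the first entry along dimension 2 is zero at any position along dimension 1 (see the zero at X[1, 2, 0] above). The expected result is: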
X = tf.constant([[[1,-2], [2,0], [-2,2], [4,-1]],
[[5,-4], [6,-2], [-2,1], [-2,2]]], dtype=tf.float16)
Y = tf.constant([[1], [2]])
Currently, I have the following code:
idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
X_clean = [X[x, :, :] for x in idx]
X_clean = tf.stack(X_clean)
Y_clean = tf.stack([Y[x] for x in idx])
This is too slow: each iteration takes about 2 seconds. How can I make it faster?
You can use tf.where, tf.reduce_all and tf.gather for a more efficient solution:
# Get the batch indices of the valid elements.
# X[..., 0] != 0 checks that the first element along the last dimension is not 0.
# reduce_all checks that this holds for every element along dimension 1.
# where returns the indices of the valid elements, with shape [n_valid, 1].
# squeeze(axis=-1) flattens them to shape [n_valid]; a bare squeeze would also
# drop the batch axis when only one element is valid.
valid_element_idxs = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
X_clean = tf.gather(X, valid_element_idxs)
Y_clean = tf.gather(Y, valid_element_idxs)
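As an alternative sketch (not part of the original answer), the same boolean mask can be passed directly to tf.boolean_mask, which skips the explicit index computation. The variable name keep is only illustrative, and the tensors are the small example from the question:

```python
import tensorflow as tf

X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1], [4, 1], [0, 1], [-5, 3]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

# One boolean per batch element: True when no row has a zero in column 0.
keep = tf.reduce_all(X[..., 0] != 0, axis=-1)

# boolean_mask keeps the entries along axis 0 where keep is True.
X_clean = tf.boolean_mask(X, keep)
Y_clean = tf.boolean_mask(Y, keep)
```

Because the mask is applied along the batch axis only, the result keeps its batch dimension even when a single element (or none) survives the filter.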
Comparing your approach with this one, using %timeit on the two small tensors you gave as an example:
>>> %timeit list_comp(X,Y)
2.82 ms ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> %timeit tf_native(X,Y)
263 µs ± 2.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
You can squeeze out a little more performance with tf.function:
>>> %timeit tf_native_decorated(X,Y)
206 µs ± 6.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
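When timing a tf.function, keep in mind that the first call traces the Python function into a graph, so it is worth calling it once before measuring. The sketch below is only an illustration: the name filter_batch is hypothetical, and the input_signature assumes float16 inputs and int32 labels of shape [batch, 1], as in the small example. Fixing the signature lets one traced graph serve batches of any size.

```python
import tensorflow as tf

@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None, None], dtype=tf.float16),  # assumed dtype of X
    tf.TensorSpec(shape=[None, 1], dtype=tf.int32),             # assumed dtype of Y
])
def filter_batch(X, Y):
    # True for batch elements with no zero in the first column.
    keep = tf.reduce_all(X[..., 0] != 0, axis=-1)
    return tf.boolean_mask(X, keep), tf.boolean_mask(Y, keep)

X = tf.constant([[[1, -2], [2, 0], [-2, 2], [4, -1]],
                 [[3, 1], [4, 1], [0, 1], [-5, 3]],
                 [[5, -4], [6, -2], [-2, 1], [-2, 2]]], dtype=tf.float16)
Y = tf.constant([[1], [43], [2]])

# The first call triggers tracing; later calls reuse the traced graph.
X_clean, Y_clean = filter_batch(X, Y)
```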
Function definitions for reference:
def list_comp(X, Y):
    # Original approach: Python-level loop over the batch.
    idx = [k for k, v in enumerate(X) if 0 not in v[:, 0]]
    X_clean = [X[x, :, :] for x in idx]
    X_clean = tf.stack(X_clean)
    Y_clean = tf.stack([Y[x] for x in idx])
    return X_clean, Y_clean

def tf_native(X, Y):
    # Vectorized approach using TensorFlow ops only.
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean

@tf.function
def tf_native_decorated(X, Y):
    # Same as tf_native, compiled into a graph by tf.function.
    valid_elements_idx = tf.squeeze(tf.where(tf.reduce_all(X[..., 0] != 0, axis=-1)), axis=-1)
    X_clean = tf.gather(X, valid_elements_idx)
    Y_clean = tf.gather(Y, valid_elements_idx)
    return X_clean, Y_clean