How to use tf.cond for batch processing

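I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Suppose I have two tensors, x and y. Each one is a batch of 0/1 values, and I want to use the comparison x < y as the pred argument of tf.cond: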

pred: A scalar determining whether to return the result of fn1 or fn2.

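But if I am processing batches, it looks like I would need to iterate over the source tensor in the graph, slice out each item in the batch, and apply tf.cond to each item separately. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? What is the correct way to use it with batches?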

tf.where sounds like what you want: vectorized selection between tensors.

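tf.cond is a control-flow modifier: it determines which ops are executed at all, so it is hard to come up with useful batching semantics for it. In a single run, a branch is either executed or not; tf.cond cannot run fn1 for some batch items and fn2 for others.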
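As a rough illustration of what vectorized selection means, here is the same pattern in NumPy, chosen as an assumption so the sketch runs without a TensorFlow session. fn1 and fn2 are hypothetical branch functions; in TensorFlow the analogous call would be tf.where(pred, fn1(x, y), fn2(x, y)). Both branches are computed for the whole batch, and the results are then merged element by element.

```python
import numpy as np

# Hypothetical branch functions standing in for fn1 / fn2 from the question.
def fn1(x, y):
    return x + y

def fn2(x, y):
    return x - y

# Two batches of 0/1 values, as in the question.
x = np.array([0, 1, 0, 1, 1])
y = np.array([1, 1, 0, 0, 1])

# One boolean per batch element rather than a single scalar.
pred = x < y

# Both branches run on the entire batch; np.where then takes fn1's value
# where pred is True and fn2's value where pred is False.
result = np.where(pred, fn1(x, y), fn2(x, y))
print(pred)
print(result)
```

This is cheap when both branches are inexpensive. When a branch is costly, or must only see the items it is responsible for, the slicing approach described next avoids computing both branches for every element.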

We can also combine the two ideas: slice the input according to the condition and pass each slice to the ops of the corresponding branch.

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as `full_input`. Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
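  full_input_flat = nest.flatten(full_input)
  # Coordinates of the True / False entries of `condition`, shape [num, N].
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  # Gather the selected batch entries of every input tensor. gather_nd with
  # [num, N] indices collapses the N batch dimensions into one of size num.
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  # Each branch only runs on the entries it owns.
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)
  def scatter_outputs(true_output, false_output):
    # Full output shape: the batch shape of `condition` followed by the
    # per-item shape of the branch output, i.e. everything after the single
    # collapsed batch dimension produced by gather_nd.
    batch_shape = tf.shape(condition)
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[1:]],
        0)
    # Write each branch's results back to their original positions; every
    # other position is zero-filled. Indices are cast to int32 to match the
    # dtype of `scattered_shape`.
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    # The True and False index sets are disjoint and together cover every
    # position, so adding the two scatters interleaves the outputs. This
    # requires numeric output dtypes.
    return true_scatter + false_scatter
  # Scatter every output tensor and restore the nested structure.
  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result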
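To make the gather/scatter mechanics easier to follow, here is a NumPy analog of slicing_where. It is a sketch under stated assumptions: the name slicing_where_np is hypothetical, and it handles a single array rather than nested structures. Boolean-mask indexing plays the role of tf.where plus tf.gather_nd, and masked assignment plays the role of the two tf.scatter_nd calls.

```python
import numpy as np

def slicing_where_np(condition, full_input, true_branch, false_branch):
    """Hypothetical NumPy analog of slicing_where for one array (no nesting).

    condition has shape [B_1, ..., B_N]; full_input has shape
    [B_1, ..., B_N, ...]. Each branch only sees the entries it owns.
    """
    condition = np.asarray(condition, dtype=bool)
    full_input = np.asarray(full_input)
    # Boolean-mask indexing gathers the selected batch entries into a single
    # leading dimension, like tf.gather_nd with tf.where(condition) indices.
    true_out = np.asarray(true_branch(full_input[condition]))
    false_out = np.asarray(false_branch(full_input[~condition]))
    # Output shape: batch dims of `condition` + per-item dims of the outputs.
    out_shape = condition.shape + true_out.shape[1:]
    result = np.zeros(out_shape, dtype=np.result_type(true_out, false_out))
    # Masked assignment scatters each branch's results back into place.
    # The two masks are disjoint and cover every position.
    result[condition] = true_out
    result[~condition] = false_out
    return result

# Same inputs as the TensorFlow examples.
vec = slicing_where_np(
    condition=np.arange(10) % 2 == 0,
    full_input=np.arange(10, dtype=np.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (np.arange(10, dtype=np.float32)[:, None]
               * np.arange(10, dtype=np.float32)[None, :])
# A rank-1 condition on a rank-2 input selects whole rows.
mat = slicing_where_np(
    condition=np.arange(10) % 3 == 0,
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

print(vec)
print(mat)
```

On these inputs it should produce the same values as the TensorFlow version, up to float32 rounding: even positions of the vector get +0.2, and rows 0, 3, 6 and 9 of the matrix are negated.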

Some examples:

vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

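# TF 1.x graph-mode API. Under TF 2.x, call
# tf.compat.v1.disable_eager_execution() at program start and use
# tf.compat.v1.Session instead.
with tf.Session():
  print(vector_test.eval())
  print(matrix_test.eval())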

Prints:

[ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
  6.19999981  7.0999999   8.19999981  9.10000038]
[[  0.           0.           0.           0.           0.           0.
    0.           0.           0.           0.        ]
 [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
    5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
 [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
   10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
 [  0.          -3.          -6.          -9.         -12.         -15.
  -18.         -21.         -24.         -27.        ]
 [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
   20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
 [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
   25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
 [  0.          -6.         -12.         -18.         -24.         -30.
  -36.         -42.         -48.         -54.        ]
 [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
   35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
 [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
   40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
 [  0.          -9.         -18.         -27.         -36.         -45.
  -54.         -63.         -72.         -81.        ]]