对于 Tensorflow 中的张量对象,具有多个参数和 map_fn 中的一个 return 值的自定义函数

Custom function with multiple argument and one return value in map_fn for tensor object in Tensorflow

我有两个张量 t1 和 t2 (shape=(64,64,3), dtype=tf.float64)。我想执行一个自定义函数“func”,它接受两个张量作为输入和 returns 一个新张量。

@tf.function
def func(a):
  t1 = a[0]
  t2 = a[1]
  
  return tf.add(t1, t2)

我正在使用 map_fn 的 tensorflow 为输入的每个元素执行函数。

t = tf.map_fn(fn=func, elems=(t1, t2), dtype=(tf.float64, tf.float64))
tf.print(t)

用于测试目的的样本输入张量是,

t1 = tf.constant([[1.1, 2.2, 3.3],
                  [4.4, 5.5, 6.6]])
t2 = tf.constant([[7.7, 8.8, 9.9],
                  [10.1, 11.11, 12.12]])

我不能使用带有两个参数的 map_fn。 [尝试使用 tf.stack,也取消堆叠,但这也不起作用。]知道如何做到这一点吗?

“map_fn”的“elems”参数解压沿轴 0 传递给它的参数。因此,为了在自定义函数中传递多个张量,

  1. 我们必须将它们堆叠在一起。
  2. 沿轴 0 添加额外维度。
# t1 and t2 has shape [2, 3]
val = tf.stack([t1, t2]) # shape is now [2, 2, 3]
val = tf.expand_dims(val, axis=0) # shape is now [1, 2, 2, 3]
t = tf.map_fn(fn=func, elems=val, dtype=tf.float64)

另外,“map_fn”的“dtype”应该是函数的 return 类型。例如,在本例中它应该是 tf.float64。如果该函数 return 一个元组,则 dtype 也将是一个元组。

@tf.function
def func(a): # a has shape [2, 2, 3]
  t1 = a[0] # shape [2, 3]
  t2 = a[1] # shape [2, 3]

  return tf.add(t1, t2)