Select 使用 tensorflow 时任意类型对象列表中的一项 2.x
Select an item from a list of object of any type when using tensorflow 2.x
给定 class A
、[A() for _ in range(5)]
的实例列表,我想随机 select 其中一个(示例见以下代码)
class A:
def __init__(self, a):
self.a = a
def __call__(self):
return self.a
def f():
a_list = [A(i) for i in range(5)]
a = a_list[random.randint(0, 5)]()
return a
f()
有没有一种方法可以用 @tf.function
装饰 f
而无需更改 f
的功能并且无需调用 a_list
中的所有项目?
请注意,直接用 @tf.function
修饰 f
而不对上述代码进行任何其他更改是不可行的,因为它总是 return 相同的结果。另外,我知道这可以通过首先调用 a_list
中的所有元素然后使用 tf.gather_nd
对它们进行索引来实现。但是如果调用A
类型的对象涉及深度神经网络,这将产生大量开销。
我目前正在做同样的事情。这是我到目前为止所得到的。如果有人知道更好的方法,我也有兴趣听听。当我 运行 它进行昂贵的调用时,它比我计算和 return 所有值要快得多。
@tf.function
def f2():
a_list = [A(i) for i in range(5)]
idx = tf.cast(tf.random.uniform(shape=[], maxval=4), tf.int32)
return tf.switch_case(idx, a_list)
为了速度比较我做了A expensive matrix algebra的调用方法。然后考虑调用每个函数的备用函数:
@tf.function
def f3():
a_list = [A(i) for i in range(40)]
results = [a() for a in a_list]
return results
运行 f2 有 40 个元素:0.42643 秒
运行 f3 有 40 个元素:14.9153 秒
因此,对于仅选择一个分支的预期 40 倍加速,这看起来是正确的。
给定 class A
、[A() for _ in range(5)]
的实例列表,我想随机 select 其中一个(示例见以下代码)
class A:
def __init__(self, a):
self.a = a
def __call__(self):
return self.a
def f():
a_list = [A(i) for i in range(5)]
a = a_list[random.randint(0, 5)]()
return a
f()
有没有一种方法可以用 @tf.function
装饰 f
而无需更改 f
的功能并且无需调用 a_list
中的所有项目?
请注意,直接用 @tf.function
修饰 f
而不对上述代码进行任何其他更改是不可行的,因为它总是 return 相同的结果。另外,我知道这可以通过首先调用 a_list
中的所有元素然后使用 tf.gather_nd
对它们进行索引来实现。但是如果调用A
类型的对象涉及深度神经网络,这将产生大量开销。
我目前正在做同样的事情。这是我到目前为止所得到的。如果有人知道更好的方法,我也有兴趣听听。当我 运行 它进行昂贵的调用时,它比我计算和 return 所有值要快得多。
@tf.function
def f2():
a_list = [A(i) for i in range(5)]
idx = tf.cast(tf.random.uniform(shape=[], maxval=4), tf.int32)
return tf.switch_case(idx, a_list)
为了速度比较我做了A expensive matrix algebra的调用方法。然后考虑调用每个函数的备用函数:
@tf.function
def f3():
a_list = [A(i) for i in range(40)]
results = [a() for a in a_list]
return results
运行 f2 有 40 个元素:0.42643 秒
运行 f3 有 40 个元素:14.9153 秒
因此,对于仅选择一个分支的预期 40 倍加速,这看起来是正确的。