给定类A
、[A() for _ in range(5)]
的实例列表,我想随机选择其中一个(请参阅下面的代码以获取示例(
class A:
def __init__(self, a):
self.a = a
def __call__(self):
return self.a
def f():
a_list = [A(i) for i in range(5)]
a = a_list[random.randint(0, 5)]()
return a
f()
有没有一种方法可以用@tf.function
装饰f
,而不改变f
的功能,也不调用a_list
中的所有项目?
请注意,在不对上述代码进行任何其他更改的情况下,用@tf.function
直接修饰f
是不可行的,因为它总是返回相同的结果。此外,我知道这可以通过首先调用a_list
中的所有元素,然后使用tf.gather_nd
对它们进行索引来实现。但是,如果调用A
类型的对象涉及深度神经网络,这将产生大量开销。
我现在正在做同样的事情。以下是我迄今为止的收获。如果有人知道更好的方法,我也很想听听。当我在一个昂贵的调用上运行它时,它比我计算并返回所有值的速度要快得多。
@tf.function
def f2():
a_list = [A(i) for i in range(5)]
idx = tf.cast(tf.random.uniform(shape=[], maxval=4), tf.int32)
return tf.switch_case(idx, a_list)
为了进行速度比较,我使用了昂贵的矩阵代数的调用方法。然后考虑一个调用每个函数的替代函数:
@tf.function
def f3():
a_list = [A(i) for i in range(40)]
results = [a() for a in a_list]
return results
使用40个元素运行f2:0.42643秒
运行具有40个元素的f3:14.9153秒
因此,对于只选择一个分支的预期40倍加速来说,这似乎是正确的。