使用tensorflow 2.x时,从任何类型的对象列表中选择一个项目



给定类A[A() for _ in range(5)]的实例列表,我想随机选择其中一个(请参阅下面的代码以获取示例(

class A:
def __init__(self, a):
self.a = a
def __call__(self):
return self.a
def f():
a_list = [A(i) for i in range(5)]
a = a_list[random.randint(0, 5)]()
return a
f()

有没有一种方法可以用@tf.function装饰f,而不改变f的功能,也不调用a_list中的所有项目?

请注意,在不对上述代码进行任何其他更改的情况下,用@tf.function直接修饰f是不可行的,因为它总是返回相同的结果。此外,我知道这可以通过首先调用a_list中的所有元素,然后使用tf.gather_nd对它们进行索引来实现。但是,如果调用A类型的对象涉及深度神经网络,这将产生大量开销。

我现在正在做同样的事情。以下是我迄今为止的收获。如果有人知道更好的方法,我也很想听听。当我在一个昂贵的调用上运行它时,它比我计算并返回所有值的速度要快得多。

@tf.function
def f2():
a_list = [A(i) for i in range(5)]
idx = tf.cast(tf.random.uniform(shape=[], maxval=4), tf.int32)
return tf.switch_case(idx, a_list)

为了进行速度比较,我使用了昂贵的矩阵代数的调用方法。然后考虑一个调用每个函数的替代函数:

@tf.function
def f3():
a_list = [A(i) for i in range(40)]
results = [a() for a in a_list]
return results

使用40个元素运行f2:0.42643秒

运行具有40个元素的f3:14.9153秒

因此,对于只选择一个分支的预期40倍加速来说,这似乎是正确的。

相关内容

  • 没有找到相关文章

最新更新