使用tf.滚动时增加移位值和堆叠



我有以下情况。我有一个i坐标(x, y, z)的列表,并且必须计算一个截止球内的所有三元组,使得r_ijr_ik小于一个截止值。因此,我正在计算包含所有距离的矩阵r_ij。为了计算三元组,我的想法是构造一个r_ijk矩阵。

我通过循环遍历元素i作为

的数量来做到这一点
import tensorflow as tf
r_ij = tf.reshape(tf.range(4*4), (4, 4))
r_ijk = []
for i in range(len(x)):
r_ijk.append(tf.roll(r_ij, shift=-i, axis=1))
tf.stack(r_ijk)

我想改进这段代码,因为有两个问题。主要是因为我假设,它可以完全矢量化。但是为了在我的模型中使用它,我需要修改它:

@tf.function
def get_triplets(full_r_ij, r_cut):
r_ij = tf.norm(full_r_ij, axis=-1)  # Shape of full_r_ij is (n_timesteps, n_atoms, n_atoms, 3)
n_atoms = tf.shape(r_ij)[1]
r_ijk = r_ij[None]
for atom in range(1, n_atoms):
tf.autograph.experimental.set_loop_options(
shape_invariants=[(r_ijk, tf.TensorShape([None, None, None, None]))]
)
tmp = tf.roll(r_ij, shift=-atom, axis=2)
r_ijk = tf.concat([r_ijk, tmp[None]], axis=0)  # shape is (n_atoms, n_timesteps, n_atoms, n_atoms)
r_ijk = tf.transpose(r_ijk, perm=(1, 0, 2, 3))
r_ijk = tf.where(r_ijk == 0, tf.ones_like(r_ijk) * r_cut, r_ijk)
intermediate_indices = tf.where(
tf.math.logical_and(r_ijk[:, 0, None] == 3.0, r_ijk[:, 1:] == 3.0)
)
n_atoms = tf.cast(n_atoms, dtype=tf.int64)
t, n, i, j = tf.unstack(intermediate_indices, axis=1)
k = j + n + 1
k = tf.where(k >= n_atoms, k - n_atoms, k)
triples = tf.stack([t, i, j, k], axis=1)
return triples

并使用tf.autograph.experimental.set_loop_options因为我要在r_ij张量上循环。是否有一种方法可以改进第一个代码示例(或第二个代码示例)?

我使用tf.vectorized_madtf.map_fn测试了两个进一步的实现,它们的性能都比我编写的初始函数差。所有试验均采用r_ij = tf.random.normal((32, 150, 150))

@tf.function
def roll_loop(r_ij, n_atoms):
r_ijk = r_ij[None]
for atom in range(1, n_atoms):
tf.autograph.experimental.set_loop_options(
shape_invariants=[(r_ijk, tf.TensorShape([None, None, None, None]))]
)
tmp = tf.roll(r_ij, shift=-atom, axis=2)
r_ijk = tf.concat([r_ijk, tmp[None]], axis=0)  # shape is (n_atoms, n_timesteps, n_atoms, n_atoms)
return r_ijk

129 ms ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

@tf.function
def roll_vect(r_ij, n_atoms):
r_ijk = tf.repeat(r_ij[None], repeats=n_atoms, axis=0)
def roll(args):
x, shift = args
return tf.roll(x, shift=shift, axis=2)
return tf.vectorized_map(roll, [r_ijk, tf.range(n_atoms)])

225 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@tf.function
def roll_map(r_ij, n_atoms):
r_ijk = tf.repeat(r_ij[None], repeats=n_atoms, axis=0)
def roll(args):
x, shift = args
return tf.roll(x, shift=shift, axis=2)
return tf.map_fn(roll, (r_ijk, tf.range(n_atoms)), fn_output_signature=tf.float32)

327 ms ± 18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

因此,使用python的tf.functionfor循环似乎是最快的(到目前为止)。所有函数在测试前都已编译。

编辑:使用tf.TensorArray似乎是完成此任务的最佳方法。我用几个不同的输入测试了它,它的表现与tf.autograph.experimental.set_loop_options

一样好,甚至更好一点。
@tf.function
def roll_loop(r_ij, n_atoms):
r_ijk = tf.TensorArray(tf.float32, size=n_atoms)
for atom in range(0, n_atoms):
tmp = tf.roll(r_ij, shift=-atom, axis=2)
r_ijk = r_ijk.write(atom, tmp)
return r_ijk.stack()

128 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

相关内容

  • 没有找到相关文章

最新更新