如何使用Tensorflow功能API沿着批次维度进行广播



在一些应用程序中,如slot attention(在Pytorch中实现(,有必要沿批维度进行广播。然而,我不知道如何使用功能性的API来实现这一点。例如,

import tensorflow as tf
const = tf.ones((1,4))
input = tf.keras.layers.Input((4))
const = tf.broadcast_to(const, input.shape)

引发以下错误:

ValueError: Cannot convert a partially known TensorShape to a Tensor: (None, 4)

因此,我求助于对tf.keras.Model进行子类化,但我希望将代码保留在功能性的API中。有人知道如何做到这一点吗?

使用tf.keras.backend.shape:终于找到了答案

const = tf.ones((1,4))
input = tf.keras.layers.Input((4))
const = tf.broadcast_to(const, [tf.keras.backend.shape(input)[0], 4] )
# Shape of const is now (None, 4)

最新更新