Tensorflow:奇怪的广播行为



我不太明白广播机制在Tensorflow中是如何工作的。假设我们有以下代码:

W1_shape = [5, 5, 1, 32]
b1_shape = [32]
x = tf.placeholder(tf.float32)
initial_W1 = tf.truncated_normal(shape=W1_shape, stddev=0.1)
W1 = tf.Variable(initial_W1)
initial_b1 = tf.constant(0.1, shape=b1_shape)
b1 = tf.Variable(initial_b1)
conv1 = tf.nn.conv2d(x, W1, strides=[1, 1, 1, 1], padding='SAME')
conv1_sum = conv1 + b1
y = tf.placeholder(tf.float32)
z = conv1 + y
sess = tf.Session()
# Run init ops
init = tf.global_variables_initializer()
sess.run(init)
while True:
samples, labels, indices = dataset.get_next_batch(batch_size=1000)
samples = samples.reshape((1000, MnistDataSet.MNIST_SIZE, MnistDataSet.MNIST_SIZE, 1))
y_data = np.ones(shape=(1000, 32))
conv1_res, conv1_sum_res, b1_res, z_res=
sess.run([conv1, conv1_sum, b1, z], feed_dict={x: samples, y: y_data})
if dataset.isNewEpoch:
break

因此,我加载了 MNIST 数据集,该数据集由 28x28 大小的图像组成。卷积运算符使用 32 个大小为 5x5 的过滤器。我使用的批大小为 1000,因此数据张量x的形状为 (1000,28,28,1(。tf.nn.conv2d运算输出形状为 (1000,28,28,32( 的张量。y是一个占位符,我添加一个变量,通过将其添加到(1000,28,28,32(形状的conv1张量来检查Tensorflow的广播机制。在y_data = np.ones(shape=(1000, 32))行中,我试验了各种张量形状的y。形状 (28,28(、(1000,28( 和 (1000,32( 不会添加到conv1中,但类型错误如下:

无效参数错误(有关回溯,请参见上文(:不兼容的形状:[1000,28,28,32] 与 [28,28]

形状 (28,32( 和 (28,28,32( 工作和广播正确。但是根据 https://www.tensorflow.org/performance/xla/broadcasting 解释的广播语义,前三个形状也必须工作,因为它们通过匹配维度与4Dconv1张量来具有正确的顺序。例如,(28,28( 匹配维度 1 和 2 中的 (1000,28,28,32(,维度 0 和 3 中的 (1000,32( 匹配项,就像链接中提到的。我在这里错过或误解了什么吗?在这种情况下,Tensorflow 的正确广播行为是什么?

确实,文档似乎在暗示您所说的内容。但看起来它遵循numpy broadcsting rules

在两个阵列上操作时,NumPy 会比较它们的形状 元素方面。它从尾随维度开始,并工作 前进的方式。在以下情况下,两个维度兼容:

  1. 它们是相等的,或者
  2. 其中之一是 1

所以根据上面的定义:

(28, 28(
  • 不能广播到 (1000, 28, 28, 32(,但 (28, 28, 1( 可以。
  • (1000,28( 不能但 (1000, 1, 28, 1(
  • 或 (1000, 28, 1, 1( 可以

  • (28, 32( 有效,因为尾随维度匹配。

最新更新