TensorFlow 中的函数是什么,相当于 PyTorch 中的 expand()


假设我有一个 2 x 3 的矩阵,我想创建一个 6 x 2 x 3 的矩阵,

其中第一维中的每个元素都是原始的 2 x 3 矩阵。

在 PyTorch 中,我可以这样做:

import torch
from torch.autograd import Variable
import numpy as np
x = np.array([[1, 2, 3], [4, 5, 6]])
x = Variable(torch.from_numpy(x))
# y is the desired result
y = x.unsqueeze(0).expand(6, 2, 3)

在TensorFlow中执行此操作的等效方法是什么?我知道unsqueeze()等同于tf.expand_dims()但我没有TensorFlow等同于expand()。我正在考虑在 1 x 2 x 3 张量列表中使用 tf.concat,但不确定这是否是最好的方法。

pytorch expand 的等效函数是 tensorflow tf.broadcast_to

文档:https://www.tensorflow.org/api_docs/python/tf/broadcast_to

Tensorflow 会自动广播,所以一般来说你不需要做任何这些。假设您有一个形状为 6x2x3 的y',并且您的x形状为 2x3 ,那么您已经可以执行y'*xy'+x已经表现得好像您已经扩展了它一样。但是,如果出于其他原因您确实需要这样做,那么 tensorflow 中的命令tile

y = tf.tile(tf.reshape(x, (1,2,3)), multiples=(6,1,1))

文档:https://www.tensorflow.org/api_docs/python/tf/tile

最新更新