Tensorflow相当于火炬.Tensor.index_copy



我正在实现这里最初使用pytorch实现的模型的tensorflow等价物。一切都很顺利,直到我遇到了这一行代码。

batch_current = Variable(torch.zeros(size, self.embedding_dim))
# self.embedding and self.W_c are pytorch network layers I have created
batch_current = self.W_c(batch_current.index_copy(0, Variable(torch.LongTensor(index)),
self.embedding(Variable(self.th.LongTensor(current_node)))))

如果搜索index_copy的文档,它所做的似乎只是在某个索引和公共轴上复制一组元素,并将其分配给另一个张量。但我真的不想写一些有缺陷的代码,所以在尝试任何自我实现之前,我想知道你们是否知道我该如何实现它

该模型来自本文,是的,我已经搜索了其他tensorflow实现,但它们对我来说似乎没有太大意义

您需要的是tensorflow中的tf.tensor_scatter_nd_update,以获得类似pytorch的tensor.index_copy的等效操作。下面是一个演示。

在pytorch中,你有

import torch 
tensor = torch.zeros(5, 3)
indices = torch.tensor([0, 4, 2])
updates= torch.tensor([[1, 2, 3], 
[4, 5, 6], 
[7, 8, 9]], dtype=torch.float)
tensor.index_copy_(0, indices, updates)
tensor([[1., 2., 3.],
[0., 0., 0.],
[7., 8., 9.],
[0., 0., 0.],
[4., 5., 6.]])

在tensorflow中,你可以进行

import tensorflow as tf
tensor = tf.zeros([5,3])
indices = tf.constant([[0], [4], [2]])
updates  = tf.constant([[1, 2, 3], 
[4, 5, 6], 
[7, 8, 9]], dtype=tf.float32)
tensor = tf.tensor_scatter_nd_update(tensor, indices, updates)
tensor.numpy()
array([[1., 2., 3.],
[0., 0., 0.],
[7., 8., 9.],
[0., 0., 0.],
[4., 5., 6.]], dtype=float32)

最新更新