是否由其他操作员更换火炬收集



我有一个脚本代码,其中1x68x8x8x1x2大小

tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)        
max_ids = torch.argmax(x1, 1)            
max_ids = max_ids.view(-1, 1)

x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1

至于上面的代码,当我使用旧的onnx时,我遇到了torch.gather的问题。因此,我想找到一种替代解决方案,用其他运算符替换toch.gather,但用上面的代码给出相同的输出。你能给我一些建议吗?

一个解决方法是使用等效的numpy方法。如果在某个地方包含import numpy as np语句,则可以执行以下操作。

outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))

如果这会给你一个与毕业生相关的错误,请尝试

outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))

一种没有numpy的方法:在这种情况下,似乎max_ids每行只包含一个条目。因此,我相信以下内容将起作用:

max_ids = torch.argmax(x1, 1) # do not reshape

x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids]

最新更新