我有一个脚本代码,其中1x68x8x8
的x1
和x2
大小
tmp_batch, tmp_channel, tmp_height, tmp_width = x1.size()
x1 = x1.view(tmp_batch*tmp_channel, -1)
max_ids = torch.argmax(x1, 1)
max_ids = max_ids.view(-1, 1)
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = torch.gather(x2, 1, max_ids) # size of 68 x 1
至于上面的代码,当我使用旧的onnx
时,我遇到了torch.gather
的问题。因此,我想找到一种替代解决方案,用其他运算符替换toch.gather
,但用上面的代码给出相同的输出。你能给我一些建议吗?
一个解决方法是使用等效的numpy方法。如果在某个地方包含import numpy as np
语句,则可以执行以下操作。
outputs_x_select = torch.Tensor(np.take_along_axis(x2,max_ids,1))
如果这会给你一个与毕业生相关的错误,请尝试
outputs_x_select = torch.Tensor(np.take_along_axis(x2.detach(),max_ids,1))
一种没有numpy的方法:在这种情况下,似乎max_ids
每行只包含一个条目。因此,我相信以下内容将起作用:
max_ids = torch.argmax(x1, 1) # do not reshape
x2 = x2.view(tmp_batch*tmp_channel, -1)
outputs_x_select = x2[torch.arange(tmp_batch*tmp_channel),max_ids]