是否有Numpy或pyTorch函数用于此代码?



基本上有一个Numpy或PyTorch函数做到这一点:

vp_sa_s=mdp_data['sa_s'].detach().clone()
dims = vp_sa_s.size()
for i in range(dims[0]):
for j in range(dims[1]):
for k in range(dims[2]):
# to mimic matlab functionality: vp(mdp_data.sa_s)
try:
vp_sa_s[i,j,k] = vp[mdp_data['sa_s'][i,j,k]]
except:
pass

给定vp_sa_s的大小为(10,5,5),并且每个值都是有效的索引vp,即在0-9的范围内。vp是大小为(10,1)的一堆随机值。

Matlab用vp(mdp_data.sa_s)优雅而快速地完成它,这将形成一个新的(10,5,5)矩阵。如果mdp_data.sa_s中的所有值都是1,则结果将是(10,5,5)张量,每个值都是vp中的第一个值。

是否存在一个函数或方法可以在少于O(N^3)的时间内实现此目标,因为上面的代码非常低效。

谢谢!

怎么了

result = vp[vp_sa_s, 0]

注意,因为你的vp的形状是(10, 1)(它有一个末尾的单例维度),你需要在赋值中添加, 0]索引来去掉这个额外的维度。

相关内容

  • 没有找到相关文章

最新更新