如何将低对角元素的Numpy数组转换为全矩阵数组



我有一个由单个矩阵的下对角线元素组成的数组。我可以按照如何在NumPy中将三角形矩阵转换为正方形的方法将其转换为完整矩阵。对于单个矩阵,示例如下:

# Create the lower diagonal elements of a 6x6 matrix.
ld = np.arange(21)
# Create full 6x6 matrix
x = np.zeros((6,6))
# Stuff lower triangular values into it
x[np.tril_indices(6)] = ld
# Populate upper triangular elements
x = x + x.T
# Fix diagonals (they got doubled)
diag_idx = [0, 2, 5, 9, 14, 20]
np.fill_diagonal(x, ld[diag_idx])
print(x)
我们得到了期望的完整矩阵
[[ 0.  1.  3.  6. 10. 15.]
[ 1.  2.  4.  7. 11. 16.]
[ 3.  4.  5.  8. 12. 17.]
[ 6.  7.  8.  9. 13. 18.]
[10. 11. 12. 13. 14. 19.]
[15. 16. 17. 18. 19. 20.]]

现在我想把它扩展到一个包含N个低对角元素集合的数组并且想要得到一个包含N个满矩阵的数组。前者的形状为(N, 21),后者的形状为(N, 6, 6)。I将单矩阵的例子展开为包含2个矩阵的例子

# Two sets of lower diagonal elements
ld = np.arange(2*21).reshape(2, 21)
# Two sets of full 6x6 matrices
x = np.zeros((ld.shape[0], 6, 6))
# Find the lower triangular indices of each row and stuff them with the 
# values from the corresponding row in the lower diagonal array
x[:, np.tril_indices(6)] = ld[:]
# Populate upper triangular elements
x[:] = x[:] + x[:].T
# Fix diagonals (they got doubled)
diag_idx = [0, 2, 5, 9, 14, 20]
np.fill_diagonal(x[:], ld[:][diag_idx])

但是我在x[:, np.tril_indices(6)] = ld[:]

这行得到了形状不匹配
ValueError: shape mismatch: value array of shape (2,21) could not be broadcast to indexing result of shape (2,2,21,6)

我可以在N个对角线值较低的集合上做一个普通的Python循环,但我试图通过Numpy来完成。我的索引哪里出错了,有什么建议吗?

X中的期望值为:

[[[ 0.  1.  3.  6. 10. 15.]
[ 1.  2.  4.  7. 11. 16.]
[ 3.  4.  5.  8. 12. 17.]
[ 6.  7.  8.  9. 13. 18.]
[10. 11. 12. 13. 14. 19.]
[15. 16. 17. 18. 19. 20.]],
[[21., 22., 24., 27., 31., 36.],
[22., 23., 25., 28., 32., 37.],
[24., 25., 26., 29., 33., 38.],
[27., 28., 29., 30., 34., 39.],
[31., 32., 33., 34., 35., 40.],
[36., 37., 38., 39., 40., 41.]]]

您可以这样做,适用于任何N

import numpy as np
N = 3
ld = np.arange(N*21).reshape(N, 21)
x = np.zeros((ld.shape[0], 6, 6))
tril_ind = np.tril_indices(6)
x[:, tril_ind[0], tril_ind[1]] = ld
x += np.transpose(x, (0, 2, 1))
diag_ind = np.diag_indices(6)
x[:, diag_ind[0], diag_ind[1]] /= 2
print(x)

这个打印

[[[ 0.  1.  3.  6. 10. 15.]
[ 1.  2.  4.  7. 11. 16.]
[ 3.  4.  5.  8. 12. 17.]
[ 6.  7.  8.  9. 13. 18.]
[10. 11. 12. 13. 14. 19.]
[15. 16. 17. 18. 19. 20.]]
[[21. 22. 24. 27. 31. 36.]
[22. 23. 25. 28. 32. 37.]
[24. 25. 26. 29. 33. 38.]
[27. 28. 29. 30. 34. 39.]
[31. 32. 33. 34. 35. 40.]
[36. 37. 38. 39. 40. 41.]]
[[42. 43. 45. 48. 52. 57.]
[43. 44. 46. 49. 53. 58.]
[45. 46. 47. 50. 54. 59.]
[48. 49. 50. 51. 55. 60.]
[52. 53. 54. 55. 56. 61.]
[57. 58. 59. 60. 61. 62.]]]

最新更新