I have a fairly large numpy array…
power = ...
print(power.shape)
>>> (3, 10, 10, 19, 75, 10, 10)
It is symmetric in its 10x10 parts, i.e. the following 2-D matrices are symmetric:
power[i, :, :, j, k, l, m]
power[i, j, k, l, m, :, :]
for all values of i, j, k, l, m.
Can I take advantage of this factor-of-4 gain? For example, when saving the matrix to a file (50 MB with savez_compressed).
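As a rough check on that factor, here is a minimal sketch that counts how many values remain per 10x10x10x10 block when only the lower triangles (diagonal included) of both symmetric pairs are kept. The variable names are illustrative only.

```python
import numpy as np

n = 10
# Number of entries on or below the diagonal of an n x n matrix
n_tril = len(np.tril_indices(n)[0])

full_block = n * n * n * n        # values stored for both 10x10 pairs
packed_block = n_tril * n_tril    # values left after packing both pairs
ratio = full_block / packed_block

print(n_tril, packed_block, round(ratio, 2))
```

Because the diagonal has to be kept, the saving for n = 10 comes out somewhat below 4; it approaches 4 only as n grows.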
My attempt:
size = 10
row_idx, col_idx = np.tril_indices(size)
zip_idx = list(zip(row_idx, col_idx))  # list() so it can be reused and sliced on Python 3
print(len(zip_idx), zip_idx[:5])
>>> 55 [(0, 0), (1, 0), (1, 1), (2, 0), (2, 1)]
all_idx = [(r0, c0, r1, c1) for (r0, c0) in zip_idx for (r1, c1) in zip_idx]
print(len(all_idx), all_idx[:5])
>>> 3025 [(0, 0, 0, 0), (0, 0, 1, 0), (0, 0, 1, 1), (0, 0, 2, 0), (0, 0, 2, 1)]
a, b, c, d = zip(*all_idx)
tril_part = np.transpose(power, (0, 3, 4, 1, 2, 5, 6))[:, :, :, a, b, c, d]
print(tril_part.shape)
>>> (3, 19, 75, 3025)
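This looks ugly, but it "works"… provided I can also recover power from tril_part…
I guess this raises two questions:
- Is there a better way to get from power to tril_part?
- How do I get from tril_part back to power?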
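Edit: the "size" comment is clearly valid, but please ignore it :-) IMHO the indexing part of the question stands on its own. I keep finding myself wanting to do similar indexing for smaller matrices.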
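You are on the right track. With np.tril_indices you can indeed index these lower triangles neatly. What needs improving is the actual indexing/slicing of the data.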
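To make the indexing idea concrete, here is a small sketch on a 3x3 case. The arrays m and t are made up for illustration; only np.tril_indices and np.ix_ come from the discussion.

```python
import numpy as np

ix, iy = np.tril_indices(3)
print(ix)  # [0 1 1 2 2 2]
print(iy)  # [0 0 1 0 1 2]

# Picking the lower triangle (diagonal included) of a single matrix
m = np.arange(9).reshape(3, 3)
lower = m[ix, iy]
print(lower)  # [0 3 4 6 7 8]

# np.ix_ reshapes the index vectors to (6, 1) and (1, 6) so that they
# broadcast against each other and cover every pair of triangle entries
ix1, ix2 = np.ix_(ix, ix)
iy1, iy2 = np.ix_(iy, iy)
print(ix1.shape, ix2.shape)  # (6, 1) (1, 6)

# Two 3x3 pairs packed at once: entry [a, b] is t[ix[a], iy[a], ix[b], iy[b]]
t = np.arange(81).reshape(3, 3, 3, 3)
packed = t[ix1, iy1, ix2, iy2]
print(packed.shape)  # (6, 6)
```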
Please try this (copy-paste):
import numpy as np
shape = (3, 10, 10, 19, 75, 10, 10)
p = np.arange(np.prod(shape)).reshape(shape) # this is not symmetric, but not important
ix, iy = np.tril_indices(10)
# In order to index properly, we need to add axes. This can be done by hand or with this
ix1, ix2 = np.ix_(ix, ix)
iy1, iy2 = np.ix_(iy, iy)
p_ltriag = p[:, ix1, iy1, :, :, ix2, iy2]
print(p_ltriag.shape)  # yields (55, 55, 3, 19, 75), axis order can be changed if needed
q = np.zeros_like(p)
q[:, ix1, iy1, :, :, ix2, iy2] = p_ltriag # fills the lower triangles on both sides
q[:, ix1, iy1, :, :, iy2, ix2] = p_ltriag # fills the lower on left, upper on right
q[:, iy1, ix1, :, :, ix2, iy2] = p_ltriag # fills the upper on left, lower on right
q[:, iy1, ix1, :, :, iy2, ix2] = p_ltriag # fills the upper triangles on both sides
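The array q now contains a symmetrized version of p (where the upper triangles have been replaced by the contents of the lower triangles). Note that the last line uses the iy and ix indices in swapped order, which in effect writes the transpose of the lower triangular part.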
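The two directions can be wrapped into a pair of helpers. This is only a sketch under the assumption that the symmetric pairs sit on axes (1, 2) and (5, 6) as in the question; pack_tril and unpack_tril are hypothetical names, and the test array is deliberately small and made exactly symmetric so the round trip can be checked.

```python
import numpy as np

def pack_tril(power):
    """Keep only the lower-triangle entries of axis pairs (1, 2) and (5, 6)."""
    n = power.shape[1]
    ix, iy = np.tril_indices(n)
    ix1, ix2 = np.ix_(ix, ix)
    iy1, iy2 = np.ix_(iy, iy)
    # Advanced indices separated by slices put the broadcast axes first:
    # result shape is (T, T, d0, d3, d4) with T = n * (n + 1) // 2
    return power[:, ix1, iy1, :, :, ix2, iy2]

def unpack_tril(packed, n):
    """Rebuild the full array, assuming both axis pairs are symmetric."""
    ix, iy = np.tril_indices(n)
    ix1, ix2 = np.ix_(ix, ix)
    iy1, iy2 = np.ix_(iy, iy)
    _, _, d0, d3, d4 = packed.shape
    full = np.zeros((d0, n, n, d3, d4, n, n), dtype=packed.dtype)
    full[:, ix1, iy1, :, :, ix2, iy2] = packed  # lower / lower
    full[:, ix1, iy1, :, :, iy2, ix2] = packed  # lower / upper
    full[:, iy1, ix1, :, :, ix2, iy2] = packed  # upper / lower
    full[:, iy1, ix1, :, :, iy2, ix2] = packed  # upper / upper
    return full

# Small test array, symmetrized over both axis pairs
rng = np.random.default_rng(0)
sym = rng.random((2, 4, 4, 3, 5, 4, 4))
sym = sym + sym.swapaxes(1, 2)
sym = sym + sym.swapaxes(5, 6)

packed = pack_tril(sym)
restored = unpack_tril(packed, n=4)
print(packed.shape)                   # (10, 10, 2, 3, 5)
print(np.array_equal(restored, sym))  # True
```

The packed array is what one would hand to np.savez_compressed; the side length n has to be stored or known separately so that the full array can be rebuilt.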
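Comparing the lower triangles: for the comparison, we set all upper triangles (diagonal included, since np.triu_indices is used with its default offset) to 0.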
ux, uy = np.triu_indices(10)
p[:, ux, uy] = 0
q[:, ux, uy] = 0
p[:, :, :, :, :, ux, uy] = 0
q[:, :, :, :, :, ux, uy] = 0
print(((p - q) ** 2).sum())  # squared Euclidean distance is 0, so p and q are equal
print((p ** 2).sum(), (q ** 2).sum())  # prove that not all entries are 0 ;) - This has a negative result due to an integer overflow
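The overflow mentioned in the last comment can be sidestepped. The sketch below uses a tiny made-up int64 array (a and b are illustrative, not the arrays above) to show an exact comparison that needs no arithmetic, and a float accumulation for the sum of squares.

```python
import numpy as np

a = np.arange(6, dtype=np.int64) * 10**9
b = a.copy()

# Exact elementwise comparison; no arithmetic, so nothing can overflow
same = np.array_equal(a, b)

# Squaring in int64 wraps around silently once values exceed the int64 range
wrapped = (a ** 2).sum()

# Converting to float64 first trades exactness for range
approx = (a.astype(np.float64) ** 2).sum()

print(same, wrapped, approx)
```

For the check in the answer, np.array_equal(p, q) after zeroing the upper triangles would play the role of the squared-distance test.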