numpy对角函数速度较慢



我想在python中实现connect 4游戏,作为一个业余项目,我不知道为什么在对角线上搜索匹配项如此缓慢。在用psstats分析我的代码时,我发现这是瓶颈。我想建立一个计算机敌人,分析游戏中数千个未来步骤,因此性能是个问题。

有人知道如何提高以下代码的性能吗?我选择numpy来做这件事,因为我认为这会加快速度。问题是,我找不到避免for循环的方法。

import numpy as np
# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array
def findseq(sm,seq=2,redyellow=1):
matches=0
# search in the diagonals
# diags stores all the diagonals and off diagonals as rows of a matrix
diags=np.zeros((1,6),dtype=np.int8)
for k in range(-5,7):   
t=np.zeros(6,dtype=np.int8)
a=np.diag(sm,k=k).copy()
t[:len(a)] += a
s=np.zeros(6,dtype=np.int8)
a=np.diag(np.fliplr(sm),k=k).copy()
s[:len(a)] += a
diags=np.concatenate(( diags,t[None,:],s[None,:]),axis=0)
diags=np.delete(diags,0,0)
# print(diags)
# now, search for sequences
Na=np.size(diags,axis=1)
n=np.arange(Na-seq+1)[:,None]+np.arange(seq)
seqmat=np.all(diags[:,n]==redyellow,axis=2)
matches+=seqmat.sum()
return matches
def randomdebug():
# sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
sm=np.random.randint(0,3,size=(6,7))
return sm
# in my main program, I need to do this thousands of times
matches=[]
for i in range(1000):
sm=randomdebug()
matches.append(findseq(sm,seq=3,redyellow=1))
matches.append(findseq(sm,seq=3,redyellow=2))
# print(sm)
# print(findseq(sm,seq=3))

这是psstats

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
2000    1.965    0.001    4.887    0.002 Frage zu diag.py:4(findseq)
151002/103002    0.722    0.000    1.979    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
48000    0.264    0.000    0.264    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
48072    0.251    0.000    0.251    0.000 {method 'copy' of 'numpy.ndarray' objects}
48000    0.209    0.000    0.985    0.000 twodim_base.py:240(diag)
48000    0.179    0.000    1.334    0.000 <__array_function__ internals>:177(diag)
50000    0.165    0.000    0.165    0.000 {built-in method numpy.zeros}

我是python的新手,所以请想象一个标签"无望的角落";-(

正如Andrey在评论中所说,该代码调用了许多需要额外内存分配的np函数。我认为这就是瓶颈。

我建议预先计算所有对角线的索引,因为在你的情况下它们不会有太大变化(矩阵形状保持不变,我想序列可能会改变(。然后你可以使用它们快速寻址对角线:

import numpy as np

known_diagonals = dict()
def diagonal_indices(h: int, w: int, length: int = 3) -> np.array:
'''
Returns array (shape diagonal_count x length) of diagonal indices
of a flatten array
'''
# one of many ways to store precomputed function output
# cleaner way would probably be to do this outside this function
diagonal_indices_key = (h, w, length)
if diagonal_indices_key in known_diagonals:
return known_diagonals[diagonal_indices_key]

diagonals_count = (h + 1 - length) * (w + 1 - length) * 2
# default value is meant to ease process with cumsum:
# adding h + 1 selects an index 1 down and 1 right, h - 1 index 1 down 1 left
# firts half dedicated to right down diagonals
diagonals = np.full((diagonals_count, length), w + 1, dtype=np.int32)
# second half dedicated to left down diagonals
diagonals[diagonals_count//2::] = w - 1
# this could have been calculated mathematicaly
flat_indices = np.arange(w * h).reshape((h, w))
# print(flat_indices)
# selects rectangle offseted by l - 1 from right and down edges
diagonal_starts_rd = flat_indices[:h + 1 - length, :w + 1 - length]
# selects rectangle offseted by l - 1 from left and down edges
diagonal_starts_ld = flat_indices[:h + 1 - length, -(w + 1 - length):]

# sets starts
diagonals[:diagonals_count//2, 0] = diagonal_starts_rd.flatten()
diagonals[diagonals_count//2::, 0] = diagonal_starts_ld.flatten()
# sum triplets left to right
# diagonals contains triplets (or vector of other length) of (start, h+-1, h+-1). cumsum makes diagonal indices
diagonals = diagonals.cumsum(axis=1)
# save ouput
known_diagonals[diagonal_indices_key] = diagonals
return diagonals
# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array
def findseq(sm: np.array, seq: int = 2, redyellow: int = 1) -> int:
matches = 0
diagonals = diagonal_indices(*sm.shape, seq)
seqmat = np.all(sm.flatten()[diagonals] == redyellow, axis=1)
matches += seqmat.sum()
return matches
def randomdebug():
# sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
sm=np.random.randint(0,3,size=(6,7))
return sm
# in my main program, I need to do this thousands of times
matches=[]
for i in range(1000):
sm=randomdebug()
matches.append(findseq(sm,seq=3,redyellow=1))
matches.append(findseq(sm,seq=3,redyellow=2))
# print(sm)
# print(findseq(sm,seq=3))

相关内容

  • 没有找到相关文章

最新更新