JAX grad函数:为什么我得到的是一个零列表而不是梯度

我试图找到一个具有多个变量(500+)的Python函数的全局最大值。为此,我尝试使用JAX grad()来计算这个MyFunction的梯度函数。



from jax import grad
import jax.numpy as jnp
import numpy as np
import json
# Example data - in real case I have 150,000+ rows.
# Each row has 57 columns, laid out the way MyFunction indexes them:
#   0       -> result
#   1, 2    -> fOdds, dOdds
#   3..16   -> 14 values read as `s`
#   17..30  -> 14 values read as `h`
#   31..42  -> 12 values read as `hSign`
#   43..56  -> 14 values read as `r`
data = jnp.array([
    [
        1.0, 1.06, 9.77,
        5.0, 3.0, 2.0, 6.0, 12.0, 4.0, 10.0, 1.0, 7.0, 1.0, 12.0, 10.0, 12.0, 4.0,
        10.0, 8.0, 7.0, 11.0, 5.0, 9.0, 3.0, 6.0, 12.0, 6.0, 5.0, 3.0, 5.0, 9.0,
        8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    [
        1.0, 1.33, 3.33,
        5.0, 3.0, 2.0, 6.0, 12.0, 4.0, 10.0, 1.0, 7.0, 1.0, 12.0, 10.0, 12.0, 4.0,
        10.0, 8.0, 7.0, 11.0, 5.0, 9.0, 3.0, 6.0, 12.0, 6.0, 5.0, 3.0, 5.0, 9.0,
        8.0, 9.0, 10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
    [
        2.0, 1.65, 2.07,
        5.0, 3.0, 2.0, 6.0, 12.0, 4.0, 10.0, 1.0, 7.0, 1.0, 12.0, 10.0, 12.0, 4.0,
        8.0, 6.0, 5.0, 9.0, 3.0, 7.0, 1.0, 4.0, 10.0, 4.0, 3.0, 1.0, 3.0, 7.0,
        10.0, 11.0, 12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ],
])

def MyFunction(coefs, data):
    """Return the total balance obtained by scoring every row of ``data`` with ``coefs``.

    Layout of ``coefs`` (532 entries are read):
        0..13     per-slot base coefficient            (``coefs[p]``)
        14..181   per-slot, per-``s`` coefficient      (``14 + p*12 + s``)
        182..349  per-slot, per-``h`` coefficient      (``182 + p*12 + h``)
        350..363  per-slot coefficient used when ``r == 1`` (``350 + p``)
        364..507  per-group, per-``hSign`` multiplier  (``364 + h*12 + hSign``)
        508..519  per-group weights that build ``fPoints``
        520..531  per-group weights that build ``dPoints``

    Layout of each row of ``data`` (57 columns):
        0 result, 1 fOdds, 2 dOdds, 3..16 ``s`` (1-based), 17..30 ``h``
        (1-based), 31..42 ``hSign`` (1-based), 43..56 ``r`` flag.

    Args:
        coefs: indexable sequence/array of coefficients (see layout above).
        data: sequence of rows (see layout above).

    Returns:
        ``-1000 * len(data)`` plus ``fOdds * 1000`` for every row with
        ``result == 1`` and ``fPoints >= dPoints``, plus ``dOdds * 1000`` for
        every row with ``result == 2`` and ``dPoints > fPoints``.

    Note:
        ``coefs`` only decides which comparison branch is taken; the amounts
        added to the balance come solely from ``data``. The result is
        therefore piecewise constant in ``coefs``, which is why ``jax.grad``
        of this function evaluates to all zeros.
    """
    balance = float(len(data) * -1000)

    for row in data:
        result = row[0]
        fOdds = row[1]
        dOdds = row[2]

        # One accumulator per group (replaces the former h1P..h12P variables).
        groupPoints = [0.0] * 12

        for p in range(14):
            s = int(row[3 + p] - 1)
            # Fix: columns 17..30 hold 1-based values (1..12 in the sample
            # data), like the `s` and `hSign` columns. Without the -1, a value
            # of 12 matched no accumulator and its hCoef index spilled into
            # the next slot's coefficients.
            h = int(row[17 + p] - 1)
            r = int(row[43 + p])

            if not 0 <= h < 12:
                # No accumulator exists for such a value; the original
                # if-chain ignored these contributions as well.
                continue

            bCoef = coefs[p]
            sCoef = coefs[14 + (p * 12) + s]
            hCoef = coefs[182 + (p * 12) + h]
            # Fix: the original unconditionally overwrote rCoef with 1.0
            # (missing `else`), so coefs[350..363] never had any effect.
            rCoef = coefs[350 + p] if r == 1 else 1.0

            groupPoints[h] += bCoef * sCoef * hCoef * rCoef

        # Scale each group's total by the multiplier selected by its hSign.
        for h in range(12):
            hSign = int(row[31 + h] - 1)
            groupPoints[h] *= coefs[364 + (h * 12) + hSign]

        # Two weighted sums over the same group totals.
        fPoints = 0.0
        dPoints = 0.0
        for h in range(12):
            fPoints += groupPoints[h] * coefs[508 + h]
            dPoints += groupPoints[h] * coefs[520 + h]

        if result == 1:
            if fPoints >= dPoints:
                balance += fOdds * 1000
        elif result == 2:
            if dPoints > fPoints:
                balance += dOdds * 1000

    return balance

# Gradient of MyFunction with respect to its first argument (coefs).
derivFunction = grad(MyFunction)

# Two independent random coefficient vectors; the author reports that each
# gradient comes back as 532 zeros instead of the derivatives.
for _ in range(2):
    coefs = np.random.sample(532)
    print(derivFunction(coefs, data))


# Evaluate the function itself at the last draw and at shifted copies of it.
# The author reports -3000.0 for all three calls.
for candidate in (coefs, coefs + 0.1, coefs - 0.1):
    print(MyFunction(candidate, data))


  • 没有找到相关文章
