我想用 numba 实现 RK4(四阶龙格-库塔法)以加速计算。我是 numba 的初学者。为什么 numba 无法编译我的代码?
简化后的代码如下。
在 swing.py 中:
@numba.jit(nopython=True)
def RK4(func, t_end, X0, dt):
    """Integrate ``dX/dt = func(t, X)`` with the classic 4th-order Runge-Kutta method.

    Parameters
    ----------
    func : numba-jitted callable ``func(t, X) -> ndarray``
        Right-hand side. It must return a float NumPy array with the same
        length as ``X``. A Python list does not work: in nopython mode
        ``hdt * k1`` would be ``float64 * list``, which numba cannot type.
    t_end : float
        End of the time grid; the grid is ``np.arange(0, t_end, dt)``.
    X0 : 1-D ndarray
        Initial state, stored as row 0 of the result.
    dt : float
        Fixed step size. Assumes ``t_end > 0`` and ``dt > 0`` so that the
        grid has at least one point (``X[0]`` must exist).

    Returns
    -------
    X : float64 ndarray, shape ``(len(t), len(X0))``
        ``X[i]`` is the state at time ``t[i]``.
    """
    t = np.arange(0, t_end, dt, dtype=np.float64)
    X = np.zeros((t.shape[0], X0.shape[0]))
    X[0] = X0
    hdt = dt * 0.5
    for i in range(t.shape[0] - 1):
        ti = t[i]
        xi = X[i]
        k1 = func(ti, xi)
        # Both midpoint stages are evaluated at t + dt/2.
        k2 = func(ti + hdt, xi + hdt * k1)
        k3 = func(ti + hdt, xi + hdt * k2)
        k4 = func(ti + dt, xi + dt * k3)
        X[i + 1] = xi + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return X
# Dummy right-hand side, used only to test RK4.
@numba.jit(nopython=True)
def fff(t, X):
    """Return the constant derivative ``[0.0, 3.0]``, ignoring ``t`` and ``X``.

    The result must be a float NumPy array rather than a Python list. RK4
    computes ``hdt * k1`` with the returned value, and numba's nopython mode
    has no implementation of ``float64 * list``. Returning a list is what
    raises ``TypingError: No implementation ... mul(float64, list(int64))``.
    """
    return np.array([0.0, 3.0])
运行的主程序代码:
import numpy as np
import numba

# swing.py (the module defining RK4 and fff) must be imported explicitly.
import swing

swing.RK4(swing.fff, 10, np.array([0, 1]), 0.1)
错误消息如下,但我不明白这段简单的代码哪里不正确:
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Input In [2], in <cell line: 1>()
----> 1 swing.RK4(swing.fff, 10, np.array([0,1]), 0.1)
File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function mul>) found for signature:
>>> mul(float64, list(int64)<iv=[0]>)
There are 14 candidate implementations:
- Of which 12 did not match due to:
Overload of function 'mul': File: <numerous>: Line N/A.
With argument(s): '(float64, list(int64)<iv=None>)':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'mul': File: unknown: Line unknown.
With argument(s): '(float64, list(int64)<iv=None>)':
No match for registered cases:
* (int64, int64) -> int64
* (int64, uint64) -> int64
* (uint64, int64) -> int64
* (uint64, uint64) -> uint64
* (float32, float32) -> float32
* (float64, float64) -> float64
* (complex64, complex64) -> complex64
* (complex128, complex128) -> complex128
During: typing of intrinsic-call at /disk/disk2/youngjin/workspace/workspace/DS/Inference/MCMC/Swing/swing.py (36)
File "swing.py", line 36:
def RK4(func, t_end, X0, dt):
<source elided>
t2 = t[i] + hdt
x2 = X[i] + hdt * k1
^
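你找到原因和解决方案了吗?
解决方案
原因:测试函数 fff 返回的是 Python 列表(list),而 RK4 中的 `hdt * k1` 是 float64 与列表相乘,numba 的 nopython 模式没有这种运算的实现(即报错中的 `mul(float64, list(int64))`)。右端函数应返回 NumPy 数组。下面是修改后的代码。
在 swing.py 中: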
import numpy as np
from scipy import integrate
from typing import Union, List
import numba
def AdjMtoAdjL(adjM: np.ndarray) -> list:
    """Convert an adjacency matrix into an adjacency list.

    Entry ``i`` of the result is an integer array of the row indices ``j``
    with ``adjM[j, i] > 0``, i.e. the positive entries of column ``i``.
    One entry is produced for each of the ``len(adjM)`` rows.
    """
    adj_list = []
    for col in range(len(adjM)):
        positive = adjM[:, col] > 0
        adj_list.append(np.flatnonzero(positive))
    return adj_list
def AdjMtoEdgL(adjM: np.ndarray) -> np.ndarray:
    """Convert an adjacency matrix into an edge list.

    For a 2-D ``adjM`` the result has one row ``(row, col)`` per entry with
    ``adjM[row, col] > 0``, listed in row-major order.
    """
    positive_mask = adjM > 0
    edge_list = np.argwhere(positive_mask)
    return edge_list
@numba.jit(nopython=True)
def swing(t, y, phi, m, gamma, P, K, model):
    r"""Right-hand side of the single-node swing model (or its first-order variant).

    For ``model == "swing"`` the state is ``y = [theta, omega]`` and the code
    evaluates

    .. math::

        \dot{\theta} &= \omega \\
        \dot{\omega} &= \frac{1}{m}\left(P - \gamma\,\omega - K\sin(\theta - \phi)\right)

    For any other ``model`` the state array ``y`` holds ``theta`` only, and the
    code evaluates :math:`\dot{\theta} = P + K\sin(\theta - \phi)`.

    Parameters
    ----------
    t : float
        Time. Not used by the equations, but part of the ``func(t, y, ...)``
        signature that ``RK4`` calls.
    y : 1-D ndarray
        State; exactly two entries ``[theta, omega]`` for the "swing" model.
    phi, m, gamma, P, K : float
        Reference phase, inertia, damping, power injection and coupling strength.
    model : str
        ``"swing"`` for the second-order model; anything else selects the
        first-order branch.

    Returns
    -------
    dydt : 1-D ndarray
        Time derivative of ``y`` (``[dtheta, domega]`` for the "swing" model).
    """
    if model == "swing":
        T, O = y
        # Wrap the scalars as 1-element arrays so both branches produce arrays
        # and the results can be joined with np.concatenate below.
        T = np.array([T])
        O = np.array([O])
    else:
        T = y
    # Coupling term K * sin(theta - phi).
    Interaction = K * np.sin(T - phi)
    if model == "swing":
        dT = O
        dO = 1 / m * (P - gamma * O - Interaction)
        dydt = np.concatenate((dT, dO))
    else:
        # NOTE(review): this branch adds the coupling term, while the swing
        # branch subtracts it; confirm the first-order sign is intended.
        dydt = P + Interaction
    return dydt
@numba.jit(nopython=True)
def RK4(func, t_end, X0, dt, phi, m, gamma, P, K, model):
    """Integrate ``dX/dt = func(t, X, phi, m, gamma, P, K, model)`` with classic RK4.

    Parameters
    ----------
    func : numba-jitted callable
        Right-hand side ``func(t, X, phi, m, gamma, P, K, model) -> ndarray``.
        It must return a float NumPy array with the same length as ``X``.
    t_end : float
        End of the time grid; the grid is ``np.arange(0, t_end, dt)``.
    X0 : 1-D ndarray
        Initial state, stored as row 0 of the result.
    dt : float
        Fixed step size. Assumes ``t_end > 0`` and ``dt > 0`` so that the
        grid has at least one point (``X[0]`` must exist).
    phi, m, gamma, P, K, model
        Model parameters, forwarded unchanged to every ``func`` call.

    Returns
    -------
    X : float64 ndarray, shape ``(len(t), len(X0))``
        ``X[i]`` is the state at time ``t[i]``.
    """
    t = np.arange(0, t_end, dt, dtype=np.float64)
    X = np.zeros((t.shape[0], X0.shape[0]))
    X[0] = X0
    hdt = dt * 0.5
    for i in range(t.shape[0] - 1):
        ti = t[i]
        xi = X[i]
        k1 = func(ti, xi, phi, m, gamma, P, K, model)
        # Both midpoint stages are evaluated at t + dt/2.
        k2 = func(ti + hdt, xi + hdt * k1, phi, m, gamma, P, K, model)
        k3 = func(ti + hdt, xi + hdt * k2, phi, m, gamma, P, K, model)
        k4 = func(ti + dt, xi + dt * k3, phi, m, gamma, P, K, model)
        X[i + 1] = xi + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return X
maincode.ipynb
import networkx as nx
import os
import multiprocessing as mp
from multiprocessing import Pool
import time
import numpy as np
import swing
def multiprocess(Ngrid=101, t_end=30., omega_lim=30, dt=.001, n_cpu=19, sparam=None):
    """Integrate the model from every point of a (theta, omega) grid in parallel.

    Builds an ``Ngrid x Ngrid`` grid of initial states, with theta in
    ``[0, 2*pi]`` and omega in ``[-omega_lim, omega_lim]``. It then maps
    ``solve_func`` over the grid with a process pool.

    Parameters
    ----------
    Ngrid : int
        Number of grid points along each axis.
    t_end, dt : float
        Integration horizon and step, forwarded to ``solve_func``.
    omega_lim : float
        Half-width of the omega range.
    n_cpu : int
        Number of worker processes.
    sparam : dict, optional
        Model parameters forwarded to ``solve_func``. When omitted, the
        module-level ``Swing_Parameters`` is used, as the original code did.

    Returns
    -------
    list
        One ``solve_func`` result per initial state, in theta-major order
        (omega varies fastest).
    """
    if sparam is None:
        sparam = Swing_Parameters
    start = time.perf_counter()
    T_range = np.linspace(0, 2 * np.pi, Ngrid)
    O_range = np.linspace(-omega_lim, omega_lim, Ngrid)
    paramss = [
        {
            'sparam': sparam,
            't_end': t_end,
            'init': np.array([theta, omega]),  # initial state [theta, omega]
            'dt': dt,
        }
        for theta in T_range
        for omega in O_range
    ]
    # The context manager shuts the pool down. map() blocks until every result
    # has been collected, so nothing is lost when the block exits.
    with Pool(processes=n_cpu) as p:
        result = p.map(solve_func, paramss)
    end = time.perf_counter()
    print("***run time(sec) : ", end - start)
    print("Number of Core : " + str(n_cpu))
    return result
def solve_func(params):
    """Integrate one trajectory described by a ``multiprocess`` work item.

    ``params`` is a dict with these keys:

    - ``'sparam'``: parameter dict with keys ``phi``, ``m``, ``gamma``, ``P``,
      ``K`` and ``model``.
    - ``'t_end'``: integration horizon.
    - ``'init'``: initial state.
    - ``'dt'``: step size.

    Returns the state array produced by ``swing.RK4``.
    """
    sp = params['sparam']
    return swing.RK4(
        swing.swing,
        params['t_end'],
        params['init'],
        params['dt'],
        sp["phi"],
        sp["m"],
        sp["gamma"],
        sp["P"],
        sp["K"],
        sp["model"],
    )
# Simulation settings.
Ngrid = 301
t_end = 24.
omega_lim = 30
dt = .05

# Parameters of the single-node swing model; multiprocess() reads this dict.
Swing_Parameters = {
    "phi": np.pi,
    "m": 1.,
    "gamma": 0.3,
    "P": 2.,
    "K": 8.,
    "model": "swing"
}

# Guard the pool launch so worker processes that re-import this module
# (spawn start method) do not start pools of their own.
if __name__ == "__main__":
    res = multiprocess(Ngrid=Ngrid, t_end=t_end, omega_lim=omega_lim, dt=dt, n_cpu=19)