我正在研究一个个人项目,以在Python中编码四边形模拟(和控制(,作为一个学习项目。我正在使用scipy
Integrator odeint
,在漫长的计算时间内,我感到非常失望。因此,我希望使用numba
加速我的集成。我调用odeint
每个时间段,因为我必须在每个模拟时间段之后创建命令。
首先,当我的功能集成(state_dot
(是Quadcopter
类的方法时,我遇到了问题。因此,我使其成为一个单独的函数,但是现在我用@jit
装饰功能时在定义正确的类型方面遇到问题。state_dot
函数具有字典(params
(作为输入参数(我已经读到Numba支持字典(,但也是自定义类(wind
(,因为我的风模型是该类的方法。如果我现在排除wind
,则使用numba.typed.Dict
似乎无法用来导入字典。
要在函数中导入wind
对象,我已经看到了使用NUMBA类型object_
,但是Python在NUMBA中找不到object_
。
我使用的是Numba版本0.45.0和Python 3.7。
import numpy as np
from scipy.integrate import odeint
from numba import jit, void, float_, int_
import numba
class Quadcopter:
def __init__(self):
# Quad Params
# ---------------------------
mB = 1.2 # mass (kg)
params = {}
params["mB"] = mB
self.params = params
# Initial State
# ---------------------------
self.state = np.zeros(3)
def update(self, t, Ts, cmd, wind):
self.state = odeint(state_dot, self.state, [t,t+Ts], args = (cmd, self.params, wind))[1]
@jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
def state_dot(state, t, cmd, params, wind):
# Import Params
# ---------------------------
mB = params["mB"]
# Import State Vector
# ---------------------------
x = state[0]
y = state[1]
z = state[2]
# Motor Dynamics and Rotor forces (Second Order System: https://apmonitor.com/pdc/index.php/Main/SecondOrderSystems)
# ---------------------------
print(cmd)
# Wind Model
# ---------------------------
[velW, qW1, qW2] = wind.randomWind(t)
print(velW)
# State Derivative Vector
# ---------------------------
sdot = np.zeros(3)
sdot[0] = x*t + 0.1
sdot[1] = y*t + 0.1
sdot[2] = z*t + 0.1
return sdot
class Wind:
def __init__(self):
# Normally, average wind would be randomly set here
self.velW_med = 5.0
self.qW1_med = 0.2
self.qW2_med = 0.1
def randomWind(self, t):
# Normally, wind values would be a sine function dependant of current time
velW = self.velW_med
qW1 = self.qW1_med
qW2 = self.qW2_med
return velW, qW1, qW2
# Set time
Ti = 0
Ts = 0.005
Tf = 10
# Initialize quadcopter and wind
quad = Quadcopter()
wind = Wind()
# Simulation
t = Ti
while round(t,3) < Tf:
cmd = np.array([1,2,1,3])
quad.update(t, Ts, cmd, wind)
print(quad.state)
t += Ts
收到的错误是
Traceback (most recent call last):
File "c:/Users/JOHN-Laptop/Documents/Code Dev/Test/question_quad.py", line 29, in <module>
@jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbadecorators.py", line 186, in wrapper
disp.compile(sig)
File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbacompiler_lock.py", line 32, in _acquire_compile_lock
return func(*args, **kwargs)
File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbadispatcher.py", line 676, in compile
args, return_type = sigutils.normalize_signature(sig)
File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbasigutils.py", line 48, in normalize_signature
check_type(ty)
File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbasigutils.py", line 43, in check_type
"instance, got %r" % (ty,))
TypeError: invalid type in signature: expected a type instance, got <class 'numba.typed.typeddict.Dict'>
可以在此处查看完整的代码:https://github.com/bobzwik/quadcopter_simcon/blob/dev_numba/simulation/quadfiles/quadfiles/quad.py.py
如果我缺少任何信息,请随时提出。
编辑:更改完整代码的链接,链接到另一个分支。
我注意到的第一件事是 - 至少在您在此处显示的代码中 - 您的JIT签名具有四种类型,但是您所装饰的功能有五个参数:
@jit(void(float_[:], float_, float_[:], numba.typed.Dict))
def state_dot(state, t, cmd, params, wind):
很明显,您需要解决这个问题。最简单的尝试就是删除签名,然后让Numba弄清楚:
@jit
def state_dot(state, t, cmd, params, wind):
当然,即使您这样做,Numba仍然抱怨它不知道如何键入所有内容,并指向说mB = params["mB"]
的行。它仍然可以进行" loop升起",这意味着它能够编译某些东西,但不会尽可能快。
因此,第二件事要注意的是,虽然Numba表示它支持dict
s,但随后引发了许多警告。基本上,使用dict仍然不是一个好主意。我也看不到您使用dict的任何充分理由。为什么不像self.mB = mB
一样使mB
成为班级的成员呢?我知道您在完整的Quadcopter
课程中您会有更多复杂的事情,但是您可以有很多成员。
现在,要注意的第三件事是,自从我写了您在其他地方指出的要点以来,Numba变得更好了,现在可以处理课堂,因此您可能需要研究numba.jitclass
。通常,当您将jitclass对象传递给您要jit的函数时,numba将知道如何处理。
,但也许比所有这些都重要的是,您的update
方法为每个步骤调用odeint
。我想这是代码中最慢的部分。该功能本来可以调用一次,以便可以从头到尾解决您的整个问题,因此它具有许多(相对较慢(的开销,与理解您传递的参数,分配内存,初始化事物等有关。一种更好的方法是构造一个scipy.integrate.ode
对象,以使所有步骤之间设置所有设置,并将其保持在周围,以便您可以在步骤之间使用相同的步骤。较新的接口solve_ivp
和RK45
(以及类似(大致相当于odeint
和ode
,只是ode
具有我的首选求解器dop853
。如果您只需要OdeSolver
子类之一,则可能更喜欢这些接口。另请注意,如果您实际上在状态中实际上任何内容步骤之间,则可能需要再次致电set_initial_value
,否则情况可能会出错而不注意。
更一般而言,如果您担心速度,那么您能做的最好的事情就是介绍您的代码。这里的第一步是仅在ipython中使用%prun
。