使用numba加速odeint,在尝试传递字典和自定义对象时会发出问题



我正在研究一个个人项目,以在Python中编码四边形模拟(和控制(,作为一个学习项目。我正在使用scipy Integrator odeint,在漫长的计算时间内,我感到非常失望。因此,我希望使用numba加速我的集成。我调用odeint每个时间段,因为我必须在每个模拟时间段之后创建命令。

首先,当我的功能集成(state_dot(是Quadcopter类的方法时,我遇到了问题。因此,我使其成为一个单独的函数,但是现在我用@jit装饰功能时在定义正确的类型方面遇到问题。state_dot函数具有字典(params(作为输入参数(我已经读到Numba支持字典(,但也是自定义类(wind(,因为我的风模型是该类的方法。如果我现在排除wind,则使用numba.typed.Dict似乎无法用来导入字典。

要在函数中导入wind对象,我已经看到了使用NUMBA类型object_,但是Python在NUMBA中找不到object_

我使用的是Numba版本0.45.0和Python 3.7。

import numpy as np
from scipy.integrate import odeint
from numba import jit, void, float_, int_
import numba
class Quadcopter:
    def __init__(self):
        # Quad Params
        # ---------------------------
        mB  = 1.2       # mass (kg)
        params = {}
        params["mB"]   = mB
        self.params = params

        # Initial State
        # ---------------------------
        self.state = np.zeros(3)
    def update(self, t, Ts, cmd, wind):
        self.state = odeint(state_dot, self.state, [t,t+Ts], args = (cmd, self.params, wind))[1]

@jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
def state_dot(state, t, cmd, params, wind):
    # Import Params
    # ---------------------------    
    mB   = params["mB"]
    # Import State Vector
    # ---------------------------  
    x      = state[0]
    y      = state[1]
    z      = state[2]
    # Motor Dynamics and Rotor forces (Second Order System: https://apmonitor.com/pdc/index.php/Main/SecondOrderSystems)
    # ---------------------------
    print(cmd)
    # Wind Model
    # ---------------------------
    [velW, qW1, qW2] = wind.randomWind(t)
    print(velW)
    # State Derivative Vector
    # ---------------------------
    sdot     = np.zeros(3)
    sdot[0]  = x*t + 0.1
    sdot[1]  = y*t + 0.1
    sdot[2]  = z*t + 0.1

    return sdot

class Wind:
    def __init__(self):
        # Normally, average wind would be randomly set here
        self.velW_med = 5.0
        self.qW1_med  = 0.2
        self.qW2_med  = 0.1
    def randomWind(self, t):
        # Normally, wind values would be a sine function dependant of current time
        velW = self.velW_med
        qW1  = self.qW1_med
        qW2  = self.qW2_med
        return velW, qW1, qW2
# Set time
Ti = 0
Ts = 0.005
Tf = 10
# Initialize quadcopter and wind
quad = Quadcopter()
wind = Wind()
# Simulation
t = Ti
while round(t,3) < Tf:
    cmd = np.array([1,2,1,3])
    quad.update(t, Ts, cmd, wind)
    print(quad.state)
    t += Ts

收到的错误是

Traceback (most recent call last):
  File "c:/Users/JOHN-Laptop/Documents/Code Dev/Test/question_quad.py", line 29, in <module>
    @jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
  File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbadecorators.py", line 186, in wrapper
    disp.compile(sig)
  File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbacompiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbadispatcher.py", line 676, in compile
    args, return_type = sigutils.normalize_signature(sig)
  File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbasigutils.py", line 48, in normalize_signature
    check_type(ty)
  File "C:UsersJOHN-LaptopAppDataLocalProgramsPythonPython37libsite-packagesnumbasigutils.py", line 43, in check_type
    "instance, got %r" % (ty,))
TypeError: invalid type in signature: expected a type instance, got <class 'numba.typed.typeddict.Dict'>

可以在此处查看完整的代码:https://github.com/bobzwik/quadcopter_simcon/blob/dev_numba/simulation/quadfiles/quadfiles/quad.py.py

如果我缺少任何信息,请随时提出。

编辑:更改完整代码的链接,链接到另一个分支。

我注意到的第一件事是 - 至少在您在此处显示的代码中 - 您的JIT签名具有四种类型,但是您所装饰的功能有五个参数:

@jit(void(float_[:], float_, float_[:], numba.typed.Dict))
def state_dot(state, t, cmd, params, wind):

很明显,您需要解决这个问题。最简单的尝试就是删除签名,然后让Numba弄清楚:

@jit
def state_dot(state, t, cmd, params, wind):

当然,即使您这样做,Numba仍然抱怨它不知道如何键入所有内容,并指向说mB = params["mB"]的行。它仍然可以进行" loop升起",这意味着它能够编译某些东西,但不会尽可能快。

因此,第二件事要注意的是,虽然Numba表示它支持dict s,但随后引发了许多警告。基本上,使用dict仍然不是一个好主意。我也看不到您使用dict的任何充分理由。为什么不像self.mB = mB一样使mB成为班级的成员呢?我知道您在完整的Quadcopter课程中您会有更多复杂的事情,但是您可以有很多成员。

现在,要注意的第三件事是,自从我写了您在其他地方指出的要点以来,Numba变得更好了,现在可以处理课堂,因此您可能需要研究numba.jitclass。通常,当您将jitclass对象传递给您要jit的函数时,numba将知道如何处理。

,但也许比所有这些都重要的是,您的update方法为每个步骤调用odeint。我想这是代码中最慢的部分。该功能本来可以调用一次,以便可以从头到尾解决您的整个问题,因此它具有许多(相对较慢(的开销,与理解您传递的参数,分配内存,初始化事物等有关。一种更好的方法是构造一个scipy.integrate.ode对象,以使所有步骤之间设置所有设置,并将其保持在周围,以便您可以在步骤之间使用相同的步骤。较新的接口solve_ivpRK45(以及类似(大致相当于odeintode,只是ode具有我的首选求解器dop853。如果您只需要OdeSolver子类之一,则可能更喜欢这些接口。另请注意,如果您实际上在状态中实际上任何内容步骤之间,则可能需要再次致电set_initial_value,否则情况可能会出错而不注意。

更一般而言,如果您担心速度,那么您能做的最好的事情就是介绍您的代码。这里的第一步是仅在ipython中使用%prun

最新更新