快速计算大量输入值的数学表达式(函数)



以下问题

  • 计算字符串
    中的数学表达式
  • Python
    中的公式解析
  • 在 Python
    中解析用户提供的数学公式的安全方法
  • 从 Python 中不安全的用户输入中评估数学方程式

他们各自的答案让我思考如何有效地解析(或多或少受信任的(用户给出的单个数学表达式(大致与本答案 https://stackoverflow.com/a/594294/1672565 相同(,用于来自数据库的 20k 到 30k 输入值。我实施了一个快速而肮脏的基准测试,以便我可以比较不同的解决方案。

# Runs with Python 3(.4)
import pprint
import time
# This is what I have
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
print_results = False
# Some database, represented by an array of dicts (simplified for this example)
database_xy = []
for a in range(1, demo_len, 1):
    database_xy.append({
        'x':float(a),
        'y_eval':0,
        'y_sympya':0,
        'y_sympyb':0,
        'y_sympyc':0,
        'y_aevala':0,
        'y_aevalb':0,
        'y_aevalc':0,
        'y_numexpr': 0,
        'y_simpleeval':0
        })
#

解决方案#1:评估 [是的,完全不安全]

time_start = time.time()
func = eval("lambda x: " + userinput_function)
for item in database_xy:
    item['y_eval'] = func(item['x'])
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('1 eval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #2a:sympy - evalf (http://www.sympy.org(

import sympy
time_start = time.time()
x = sympy.symbols('x')
sympy_function = sympy.sympify(userinput_function)
for item in database_xy:
    item['y_sympya'] = float(sympy_function.evalf(subs={x:item['x']}))
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('2a sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #2b:sympy - lambdify (http://www.sympy.org(

from sympy.utilities.lambdify import lambdify
import sympy
import numpy
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numpy') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
    xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
    item['y_sympyb'] = yy[index]
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('2b sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #2c:sympy - lambdify with numexpr [and numpy] (http://www.sympy.org(

from sympy.utilities.lambdify import lambdify
import sympy
import numpy
import numexpr
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numexpr') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
    xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
    item['y_sympyc'] = yy[index]
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('2c sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #3a:星体 [基于 ast] - 带弦魔法 (http://newville.github.io/asteval/index.html(

from asteval import Interpreter
aevala = Interpreter()
time_start = time.time()
aevala('def func(x):ntreturn ' + userinput_function)
for item in database_xy:
    item['y_aevala'] = aevala('func(' + str(item['x']) + ')')
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('3a aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# Solution #3b (M Newville(: asteval [based of ast] - parse & run (http://newville.github.io/asteval/index.html(

from asteval import Interpreter
aevalb = Interpreter()
time_start = time.time()
exprb = aevalb.parse(userinput_function)
for item in database_xy:
    aevalb.symtable['x'] = item['x']
    item['y_aevalb'] = aevalb.run(exprb)
time_end = time.time()
print('3b aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# Solution #3c (M Newville(: asteval [based of ast] - parse & run with numpy (http://newville.github.io/asteval/index.html(

from asteval import Interpreter
import numpy
aevalc = Interpreter()
time_start = time.time()
exprc = aevalc.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aevalc.symtable['x'] = x
y = aevalc.run(exprc)
for index, item in enumerate(database_xy):
    item['y_aevalc'] = y[index]
time_end = time.time()
print('3c aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
#

解决方案 #4:简单 [基于 ast] (https://github.com/danthedeckie/simpleeval(

from simpleeval import simple_eval
time_start = time.time()
for item in database_xy:
    item['y_simpleeval'] = simple_eval(userinput_function, names={'x': item['x']})
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('4 simpleeval: ' + str(round(time_end - time_start, 4)) + ' seconds')
#

解决方案 #5 numexpr [和 numpy] (https://github.com/pydata/numexpr(

import numpy
import numexpr
time_start = time.time()
x = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
    x[index] = item['x']
y = numexpr.evaluate(userinput_function)
for index, item in enumerate(database_xy):
    item['y_numexpr'] = y[index]
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('5 numexpr: ' + str(round(time_end - time_start, 4)) + ' seconds')

在我的旧测试机器(Python 3.4,Linux 3.11 x86_64,两个内核,1.8GHz(上,我得到以下结果:

1 eval: 0.0185 seconds
2a sympy: 10.671 seconds
2b sympy: 0.0315 seconds
2c sympy: 0.0348 seconds
3a aeval: 2.8368 seconds
3b aeval: 0.5827 seconds
3c aeval: 0.0246 seconds
4 simpleeval: 1.2363 seconds
5 numexpr: 0.0312 seconds

突出的是eval令人难以置信的速度,尽管我不想在现实生活中使用它。第二个最好的解决方案似乎是 numexpr,它取决于 numpy - 我想避免的依赖关系,尽管这不是硬性要求。下一个最好的东西是simpleeval,它是围绕ast构建的。aeval 是另一个基于 AST 的解决方案,它遭受了这样一个事实,即我必须先将每个浮点输入值转换为字符串,我找不到方法。Sympy 最初是我最喜欢的,因为它提供了最灵活、显然最安全的解决方案,但它最终与倒数第二个解决方案的距离令人印象深刻。

更新 1:使用 sympy 有一种更快的方法。请参阅解决方案 2b。它几乎和 numexpr 一样好,尽管我不确定 sympy 是否真的在内部使用它。

更新 2sympy 实现现在使用 sympify 而不是简化(正如其首席开发人员 asmeurer 所推荐的那样 - 谢谢(。除非明确要求它这样做,否则它不会使用 numexpr(请参阅解决方案 2c(。我还添加了两个基于asteval的明显更快的解决方案(感谢M Newville(。


我有哪些选择可以进一步加快任何相对安全的解决方案?例如,是否有其他直接使用 ast 的安全(-ish(方法?

我过去使用过C++ ExprTK库,并取得了巨大的成功。这是其他C++解析器(例如Muparser,MathExpr,ATMSP等(的基准速度测试,ExprTK名列前茅。

ExprTK有一个叫做cexprtk的Python包装器,我已经使用过它,发现它非常快。您只能编译一次数学表达式,然后根据需要多次计算此序列化表达式。下面是一个将 cexprtkuserinput_function 一起使用的简单示例代码:

import cexprtk
import time
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
time_start = time.time()
x = 1
st = cexprtk.Symbol_Table({"x":x}, add_constants = True) # Setup the symbol table
Expr = cexprtk.Expression(userinput_function, st) # Apply the symbol table to the userinput_function
for x in range(0,demo_len,1):
    st.variables['x'] = x # Update the symbol table with the new x value
    Expr() # evaluate expression
time_end = time.time()
print('1 cexprtk: ' + str(round(time_end - time_start, 4)) + ' seconds')

在我的机器(Linux,双核,2.5GHz(上,对于20000的演示长度,这在0.0202秒内完成。

对于 2,000,000 的演示长度cexprtk在 1.23 秒内完成。

既然你问了星号,有一种方法可以使用它并获得更快的结果:

aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
for item in database_xy:
    aeval.symtable['x'] = item['x']
    item['y_aeval'] = aeval.run(expr)
time_end = time.time()

也就是说,您可以先解析("预编译"(用户输入函数,然后将每个新值x插入到符号表中,并使用Interpreter.run()来计算该值的已编译表达式。 在你的规模上,我认为这会让你接近 0.5 秒。

如果您愿意使用 numpy ,混合解决方案:

aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aeval.symtable['x'] = x
y = aeval.run(expr)
time_end = time.time()

应该快得多,并且在运行时与使用 numexpr 相当。

CPython(和pypy(使用非常简单的堆栈语言来执行函数,并且使用ast模块自己编写字节码相当容易。

import sys
PY3 = sys.version_info.major > 2
import ast
from ast import parse
import types
from dis import opmap
ops = {
    ast.Mult: opmap['BINARY_MULTIPLY'],
    ast.Add: opmap['BINARY_ADD'],
    ast.Sub: opmap['BINARY_SUBTRACT'],
    ast.Div: opmap['BINARY_TRUE_DIVIDE'],
    ast.Pow: opmap['BINARY_POWER'],
}
LOAD_CONST = opmap['LOAD_CONST']
RETURN_VALUE = opmap['RETURN_VALUE']
LOAD_FAST = opmap['LOAD_FAST']
def process(consts, bytecode, p, stackSize=0):
    if isinstance(p, ast.Expr):
        return process(consts, bytecode, p.value, stackSize)
    if isinstance(p, ast.BinOp):
        szl = process(consts, bytecode, p.left, stackSize)
        szr = process(consts, bytecode, p.right, stackSize)
        if type(p.op) in ops:
            bytecode.append(ops[type(p.op)])
        else:
            print(p.op)
            raise Exception("unspported opcode")
        return max(szl, szr) + stackSize + 1
    if isinstance(p, ast.Num):
        if p.n not in consts:
            consts.append(p.n)
        idx = consts.index(p.n)
        bytecode.append(LOAD_CONST)
        bytecode.append(idx % 256)
        bytecode.append(idx // 256)
        return stackSize + 1
    if isinstance(p, ast.Name):
        bytecode.append(LOAD_FAST)
        bytecode.append(0)
        bytecode.append(0)
        return stackSize + 1
    raise Exception("unsupported token")
def makefunction(inp):
    def f(x):
        pass
    if PY3:
        oldcode = f.__code__
        kwonly = oldcode.co_kwonlyargcount
    else:
        oldcode = f.func_code
    stack_size = 0
    consts = [None]
    bytecode = []
    p = ast.parse(inp).body[0]
    stack_size = process(consts, bytecode, p, stack_size)
    bytecode.append(RETURN_VALUE)
    bytecode = bytes(bytearray(bytecode))
    consts = tuple(consts)
    if PY3:
        code = types.CodeType(oldcode.co_argcount, oldcode.co_kwonlyargcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, b'')
        f.__code__ = code
    else:
        code = types.CodeType(oldcode.co_argcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, '')
        f.func_code = code
    return f

这有一个明显的优势,那就是生成与eval基本相同的函数,并且它的扩展几乎与compile + eval一样好(compile步骤比eval稍慢,并且eval将预先计算任何它可以计算的内容(1+1+x被编译为2+x(。

相比之下,eval在 0.0125 秒内完成 20k 测试,makefunction在 0.014 秒内完成。 将迭代次数增加到 2,000,000,eval 在 1.23 秒内完成,makefunction在 1.32 秒内完成。

有趣的是,pypy 认识到 evalmakefunction 产生基本相同的功能,因此第一个的 JIT 预热加速了第二个。

如果要将字符串传递给sympy.simplify(不建议使用;建议显式使用sympify(,则将使用sympy.sympify将其转换为内部使用eval的 SymPy 表达式。

我不是 Python 程序员,所以我不能提供 Python 代码。 但我认为我可以提供一个简单的方案来模拟您的依赖项并且仍然运行得非常快。

这里的关键是建造一些接近评估的东西,而不是被评估。因此,您要做的是将用户方程式"编译"为可以快速评估的内容。 OP已经展示了许多解决方案。

这是另一个基于将方程评估为反向抛光的方法。

为了便于讨论,假设您可以将方程转换为 RPN(反向抛光表示法(。 这意味着操作数先于运算符,例如,对于用户公式:

        sqrt(x**2 + y**2)

您会从左到右获得 RPN 等效读数:

          x 2 ** y 2 ** + sqrt

事实上,我们可以将"操作数"(例如,变量和常量(视为采用零操作数的运算符。 现在RPN中的每一个都是一个运算符。

如果我们将每个运算符元素视为一个标记(假设每个运算符元素下面写成">RPNelement"的唯一小整数(并将它们存储在数组"RPN"中,我们可以非常快速地使用下推堆栈评估这样的公式:

       stack = {};  // make the stack empty
       do i=1,len(RPN),1
          case RPN[i]:
              "0":  push(stack,0);
              "1": push(stack,1);
              "+":  push(stack,pop(stack)+pop(stack));break;
               "-": push(stack,pop(stack)-pop(stack));break;
               "**": push(stack,power(pop(stack),pop(stack)));break;
               "x": push(stack,x);break;
               "y": push(stack,y);break;
               "K1": push(stack,K1);break;
                ... // as many K1s as you have typical constants in a formula
           endcase
       enddo
       answer=pop(stack);

您可以内联推送和弹出的操作以加快速度。如果提供的 RPN 格式正确,则此代码是完全安全的。

现在,如何获取RPN? 答:构建一个小的递归下降解析器,其操作将 RPN 运算符附加到 RPN 数组中。 请参阅我的 SO 答案,了解如何轻松为典型方程构建递归下降解析器。

您必须组织以将解析中遇到的常量放入 K1、K2、...如果它们不是特殊的,经常出现的值(如我为"0"和"1"显示的那样;如果有帮助,您可以添加更多(。

此解决方案最多应该有几百行,并且对其他包的依赖性为零。

(Python专家:随意编辑代码以使其成为Python式(。

相关内容

  • 没有找到相关文章

最新更新