sympy lambdify with numexpr and sqrt



我正在尝试加速lambdify使用numexpr生成的一些数字代码。遗憾的是,基于 numexpr 的函数在使用 sqrt 函数时会中断,即使它是受支持的函数之一。

这为我重现了这个问题:

import sympy
import numpy as np
import numexpr
from sympy.utilities.lambdify import lambdify
expr = sympy.S('b*sqrt(a) - a**2')
a, b = sorted(expr.free_symbols, key=lambda s: s.name)
func_numpy = lambdify((a,b), expr, modules=[np], dummify=False)
func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False)
foo, bar = np.random.random((2, 4))
print sympy.__version__
print func_numpy(foo, bar)
print func_numexpr(foo, bar)

当我运行它时,输出是:

0.7.6
[-0.02062061  0.08648306 -0.57868128  0.27598245]
Traceback (most recent call last):
  File "sympy_test.py", line 17, in <module>
    print func_numexpr(foo, bar)
  File "<string>", line 1, in <lambda>
NameError: global name 'sqrt' is not defined

作为健全性检查,我还尝试直接致电numexpr

numexpr.evaluate('b*sqrt(a) - a**2', local_dict=dict(a=foo, b=bar))

按预期工作,产生与func_numpy相同的结果。


编辑:当我使用该行时它有效:

func_numexpr = lambdify((a,b), expr, modules=['numexpr'], dummify=False)

这是一个符号错误吗?

您可以将np.sqrt(9)更改为numexpr.evaluate('9**0.5')

相关内容

  • 没有找到相关文章

最新更新