为sympy创建自定义打印.导数



假设我有一个函数f(x,y),我希望fw.r.t 的偏导数x显示为partial_{x}^{n} f(x,y)所以我创建了以下类

class D(sp.Derivative):
def _latex(self,printer=None):
func = printer.doprint(self.args[0])
b = self.args[1]
if b[1] == 1 :
return r"partial_{%s}%s"%(printer.doprint(b[0]),func)
else :
return r"partial_{%s}^{%s}%s"%(printer.doprint(b[0]),printer.doprint(b[1]),func)

这工作正常,但是当我使用方法评估导数时doit()回到默认行为。说我有

x,y = sp.symbols('x,y')
f = sp.Function('f')(x,y)

然后sp.print_latex(D(f,x))给出partial_{x}f{left(x,y right)}是正确的,但sp.print_latex(D(x*f,x).doit())给出x frac{partial}{partial x} f{left(x,y right)} + f{left(x,y right)},这是旧的行为。如何解决此问题?

问题是你没有覆盖父类的doit,它返回的是纯Derivative对象而不是子类。与其创建新的Derivative类,我建议创建一个新的打印机类:

from sympy import *
from sympy.printing.latex import LatexPrinter
class MyLatexPrinter(LatexPrinter):
def _print_Derivative(self, expr):
differand, *(wrt_counts) = expr.args
if len(wrt_counts) > 1 or wrt_counts[0][1] != 1:
raise NotImplementedError('More code needed...')
((wrt, count),) = wrt_counts
return 'partial_{%s} %s)' % (self._print(wrt), self._print(differand))
x, y = symbols('x, y')
f = Function('f')
expr = (x*f(x, y)).diff(x)
printer = MyLatexPrinter()
print(printer.doprint(expr))

这给了x partial_{x} f{left(x,y right)}) + f{left(x,y right)}

您可以使用init_printing(latex_printer=printer.doprint)将其设置为默认输出。

最新更新