可在Python模块Sympy中以矩阵形式使用的微分算子



我们需要两个矩阵的微分算子[B][C],如:

B = sympy.Matrix([[ D(x), D(y) ],
                  [ D(y), D(x) ]])
C = sympy.Matrix([[ D(x), D(y) ]])
ans = B * sympy.Matrix([[x*y**2],
                        [x**2*y]])
print ans
[x**2 + y**2]
[      4*x*y]
ans2 = ans * C
print ans2
[2*x, 2*y]
[4*y, 4*x]

这也可以应用于计算向量场的旋度,如:

culr  = sympy.Matrix([[ D(x), D(y), D(z) ]])
field = sympy.Matrix([[ x**2*y, x*y*z, -x**2*y**2 ]])
要使用Sympy解决这个问题,必须创建以下Python类:
import sympy
class D( sympy.Derivative ):
    def __init__( self, var ):
        super( D, self ).__init__()
        self.var = var
    def __mul__(self, other):
        return sympy.diff( other, self.var )

这个类单独解决了微分算子矩阵在左边相乘的问题。这里diff只在求导函数已知的情况下执行。

为了解决微分算子矩阵在右侧相乘的问题,核心类Expr中的__mul__方法必须按照以下方式进行更改:

class Expr(Basic, EvalfMixin):
    # ...
    def __mul__(self, other):
        import sympy
        if other.__class__.__name__ == 'D':
            return sympy.diff( self, other.var )
        else:
            return Mul(self, other)
    #...

它工作得很好,但是Sympy应该有一个更好的本地解决方案来处理这个问题。有人知道是什么吗?

这个解决方案应用了其他答案和这里的提示。D操作符可以定义如下:

  • 仅在从左乘时考虑,因此D(t)*2*t**3 = 6*t**22*t**3*D(t)不做任何事情
  • 所有与D一起使用的表达式和符号必须有is_commutative = False
  • 在给定表达式的上下文中使用evaluateExpr()求值
    • 沿着表达式从右向左查找D运算符并将mydiff() *应用于相应的右部分

*: mydiff用来代替diff,以允许创建一个更高阶的D,如mydiff(D(t), t) = D(t,t)

D__mul__()中的diff仅供参考,因为在当前解决方案中,evaluateExpr()实际上执行微分工作。创建一个python模块并保存为d.py

import sympy
from sympy.core.decorators import call_highest_priority
from sympy import Expr, Matrix, Mul, Add, diff
from sympy.core.numbers import Zero
class D(Expr):
    _op_priority = 11.
    is_commutative = False
    def __init__(self, *variables, **assumptions):
        super(D, self).__init__()
        self.evaluate = False
        self.variables = variables
    def __repr__(self):
        return 'D%s' % str(self.variables)
    def __str__(self):
        return self.__repr__()
    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        return Mul(other, self)
    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        if isinstance(other, D):
            variables = self.variables + other.variables
            return D(*variables)
        if isinstance(other, Matrix):
            other_copy = other.copy()
            for i, elem in enumerate(other):
                other_copy[i] = self * elem
            return other_copy
        if self.evaluate:
            return diff(other, *self.variables)
        else:
            return Mul(self, other)
    def __pow__(self, other):
        variables = self.variables
        for i in range(other-1):
            variables += self.variables
        return D(*variables)
def mydiff(expr, *variables):
    if isinstance(expr, D):
        expr.variables += variables
        return D(*expr.variables)
    if isinstance(expr, Matrix):
        expr_copy = expr.copy()
        for i, elem in enumerate(expr):
            expr_copy[i] = diff(elem, *variables)
        return expr_copy
    return diff(expr, *variables)
def evaluateMul(expr):
    end = 0
    if expr.args:
        if isinstance(expr.args[-1], D):
            if len(expr.args[:-1])==1:
                cte = expr.args[0]
                return Zero()
            end = -1
    for i in range(len(expr.args)-1+end, -1, -1):
        arg = expr.args[i]
        if isinstance(arg, Add):
            arg = evaluateAdd(arg)
        if isinstance(arg, Mul):
            arg = evaluateMul(arg)
        if isinstance(arg, D):
            left = Mul(*expr.args[:i])
            right = Mul(*expr.args[i+1:])
            right = mydiff(right, *arg.variables)
            ans = left * right
            return evaluateMul(ans)
    return expr
def evaluateAdd(expr):
    newargs = []
    for arg in expr.args:
        if isinstance(arg, Mul):
            arg = evaluateMul(arg)
        if isinstance(arg, Add):
            arg = evaluateAdd(arg)
        if isinstance(arg, D):
            arg = Zero()
        newargs.append(arg)
    return Add(*newargs)
#courtesy: https://stackoverflow.com/a/48291478/1429450
def disableNonCommutivity(expr):
    replacements = {s: sympy.Dummy(s.name) for s in expr.free_symbols}
    return expr.xreplace(replacements)
def evaluateExpr(expr):
    if isinstance(expr, Matrix):
        for i, elem in enumerate(expr):
            elem = elem.expand()
            expr[i] = evaluateExpr(elem)
        return disableNonCommutivity(expr)
    expr = expr.expand()
    if isinstance(expr, Mul):
        expr = evaluateMul(expr)
    elif isinstance(expr, Add):
        expr = evaluateAdd(expr)
    elif isinstance(expr, D):
        expr = Zero()
    return disableNonCommutivity(expr)

例1:向量场的旋度。请注意,使用commutative=False定义变量很重要,因为它们在Mul().args中的顺序将影响结果,参见另一个问题。

from d import D, evaluateExpr
from sympy import Matrix
sympy.var('x', commutative=False)
sympy.var('y', commutative=False)
sympy.var('z', commutative=False)
curl  = Matrix( [[ D(x), D(y), D(z) ]] )
field = Matrix( [[ x**2*y, x*y*z, -x**2*y**2 ]] )       
evaluateExpr( curl.cross( field ) )
# [-x*y - 2*x**2*y, 2*x*y**2, -x**2 + y*z]

例2:结构分析中典型的里兹近似。

from d import D, evaluateExpr
from sympy import sin, cos, Matrix
sin.is_commutative = False
cos.is_commutative = False
g1 = []
g2 = []
g3 = []
sympy.var('x', commutative=False)
sympy.var('t', commutative=False)
sympy.var('r', commutative=False)
sympy.var('A', commutative=False)
m=5
n=5
for j in xrange(1,n+1):
    for i in xrange(1,m+1):
        g1 += [sin(i*x)*sin(j*t),                 0,                 0]
        g2 += [                0, cos(i*x)*sin(j*t),                 0]
        g3 += [                0,                 0, sin(i*x)*cos(j*t)]
g = Matrix( [g1, g2, g3] )
B = Matrix(
    [[     D(x),        0,        0],
     [    1/r*A,        0,        0],
     [ 1/r*D(t),        0,        0],
     [        0,     D(x),        0],
     [        0,    1/r*A, 1/r*D(t)],
     [        0, 1/r*D(t), D(x)-1/x],
     [        0,        0,        1],
     [        0,        1,        0]])
ans = evaluateExpr(B*g)

创建了一个print_to_file()函数来快速检查大表达式。

import sympy
import subprocess
def print_to_file( guy, append=False ):
    flag = 'w'
    if append: flag = 'a'
    outfile = open(r'print.txt', flag)
    outfile.write('n')
    outfile.write( sympy.pretty(guy, wrap_line=False) )
    outfile.write('n')
    outfile.close()
    subprocess.Popen( [r'notepad.exe', r'print.txt'] )
print_to_file( B*g )
print_to_file( ans, append=True )

微分运算符在SymPy核心中不存在,即使存在"运算符乘法"而不是"运算符的应用"也是对SymPy不支持的符号的滥用。

[1]另一个问题是SymPy表达式只能从sympy.Basic的子类构建,因此很可能您的class D只是在输入sympy_expr+D(z)时引发错误。这就是(expression*D(z)) * (another_expr)失败的原因。(expression*D(z))不能建立

另外,如果D的参数不是单个Symbol,则不清楚您期望从该操作符中得到什么。

最后,diff(f(x), x)(其中f是一个符号未知函数)返回一个未求值的表达式,因为当f是未知函数时,没有其他东西可以合理地返回。稍后,当您替换expr.subs(f(x), sin(x))时,将计算导数(最坏的情况下您可能需要调用expr.doit())。

[2]没有优雅的简短的解决方案来解决你的问题。我建议解决问题的一种方法是覆盖Expr__mul__方法:而不是仅仅乘以表达式树,它将检查左表达式树是否包含D的实例,并将应用它们。显然,如果你想要添加新对象,这将无法扩展。这是sympy设计中一个长期存在的已知问题。

EDIT:[1]只是为了允许创建包含D的表达式。[2]对于包含不止一个D的表达式是必须的。

如果您想要正确的乘法工作,您需要从object子类化。这将导致x*D回落到D.__rmul__。但我无法想象这是高优先级,因为操作符从不从右侧应用。

使操作符始终自动工作目前是不可能的。要真正完全工作,您需要http://code.google.com/p/sympy/issues/detail?id=1941。参见https://github.com/sympy/sympy/wiki/Canonicalization(可以随意编辑该页)。

但是,您可以使用stackoverflow问题中的想法创建一个大多数时候都能工作的类,对于它不能处理的情况,编写一个简单的函数,遍历表达式并在尚未应用操作符的地方应用操作符。

顺便说一下,作为"乘法"的微分运算符需要考虑的一件事是它是非结合律的。即(D*f)*g = g*Df,而D*(f*g) = g*Df + f*Dg。所以当你做一些事情的时候,你需要小心,不要"吃掉"一个表达的一部分,而不是整个表达。例如,D*2*x会因此给出0。SymPy到处都假定乘法是关联的,所以在某些时候很可能做得不正确。

如果这成为一个问题,我建议转储自动应用程序,并且只使用一个遍历并应用它的函数(正如我上面提到的,无论如何您都需要它)。

相关内容

  • 没有找到相关文章

最新更新