如何使用sympy找到快速sigmoid函数的逆函数?



我想为变量x求解以下方程(灵感来自:快速sigmoid算法(:

0 = lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) - y

如果我用于此符号,则会出现错误:

from sympy.solvers import solve
from sympy import Symbol
x = Symbol('x', real=True)
y = Symbol('y', real=True)
lower = Symbol('lower', real=True)
upper = Symbol('upper', real=True)
solve(lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) -y, x)

错误:

File "/home/user/venv/numba/lib/python3.6/site-packages/sympy/core/function.py", line 3082, in nfloat
return type(expr)([nfloat(a, n, exponent) for a in expr])
File "/home/user/venv/numba/lib/python3.6/site-packages/sympy/core/function.py", line 3082, in <listcomp>
return type(expr)([nfloat(a, n, exponent) for a in expr])
File "/home/user/venv/numba/lib/python3.6/site-packages/sympy/core/function.py", line 3082, in nfloat
return type(expr)([nfloat(a, n, exponent) for a in expr])
TypeError: __new__() missing 1 required positional argument: 'cond'

我怎样才能用符号求解这个方程?

(或者如果有人能够手动求解x方程:无论如何,函数的反演会是什么样子?

这似乎是 SymPy 版本 1.4 中的一个错误。在主人身上,我没有收到异常,而是得到:

In [2]: solve(lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) -y, x)                                                                                                     
Out[2]: 
⎡⎧0.5⋅lower + 0.5⋅upper - y      0.5⋅(lower + upper - 2.0⋅y)      ⎧-0.5⋅lower - 0.5⋅upper + y      0.5⋅(-lower - upper + 2.0⋅y)    ⎤
⎢⎪─────────────────────────  for ─────────────────────────── < 0  ⎪──────────────────────────  for ──────────────────────────── ≥ 0⎥
⎢⎨        lower - y                       lower - y             , ⎨        upper - y                        upper - y              ⎥
⎢⎪                                                                ⎪                                                                ⎥
⎣⎩           nan                          otherwise               ⎩           nan                           otherwise              ⎦

这将返回两个分段解,对应于负 x 和正 x 的情况(我认为(。

不过,我对上面的结果并不满意。我认为正确的结果应该是这样的:

In [46]: eqn = lower + (upper - lower) * (0.5 + 0.5 * x / (1 + abs(x))) - y                                                                                                       
In [47]: eqn = piecewise_fold(eqn.rewrite(Piecewise))                                                                                                                             
In [48]: eqn                                                                                                                                                                      
Out[48]: 
⎧                             ⎛0.5⋅x      ⎞           
⎪lower - y + (-lower + upper)⋅⎜───── + 0.5⎟  for x ≥ 0
⎪                             ⎝x + 1      ⎠           
⎨                                                     
⎪                             ⎛0.5⋅x      ⎞           
⎪lower - y + (-lower + upper)⋅⎜───── + 0.5⎟  otherwise
⎩                             ⎝1 - x      ⎠           
In [49]: sx1, = solve(eqn.args[0][0], x)                                                                                                                                          
In [50]: sx2, = solve(eqn.args[1][0], x)                                                                                                                                          
In [51]: cx1 = eqn.args[0][1].subs(x, sx1)                                                                                                                                        
In [52]: sol = Piecewise((sx1, cx1), (sx2, True))                                                                                                                                 
In [53]: sol                                                                                                                                                                      
Out[53]: 
⎧-0.5⋅lower - 0.5⋅upper + y      -0.5⋅lower - 0.5⋅upper + y    
⎪──────────────────────────  for ────────────────────────── ≥ 0
⎪        upper - y                       upper - y             
⎨                                                              
⎪0.5⋅lower + 0.5⋅upper - y                                     
⎪─────────────────────────               otherwise             
⎩        lower - y 

顺便说一句,我为普通的快速sigmoid手工导出了一个python函数:

from math import copysign
def inverse_fast_sigmoid(x):
assert -1.0 < x < 1.0
return copysign(
1 / (
1 - abs(x)
) - 1, 
x
)

也许您可以根据您的版本调整它。

相关内容

  • 没有找到相关文章

最新更新