用numpy优化双三次方程求解器



我有一些图像处理软件,工作得很好,但速度很慢。我现在正在使用numpy做一些事情,但我忍不住想,我可以利用这个库的更多功能来获得更好的收益。

另外,如果"双立方"在这里不是合适的术语,请原谅我——我懂数学,但不擅长词汇;)

我的工具求解系数(a, b,…), j)的双三次方程,形式为:

f(x,y) = ax^3 + by^3 + cx^2*y + dy^2*x + ex^2 + fy^2 + gxy + hx + iy + j

我这样做的方法是首先使用源数据集和目标数据集的较小点集生成求解器。这是通过首先为每个形状为(n, 10)的集合设置"矩阵行"来完成的,然后使用最小二乘法求解。从那里,我们遍历源数据中的每个其他项,生成"矩阵行"并执行np。点对"已解"系数矩阵。

import numpy as np
def matrix_row(x, y):
    row = [0] * 10
    row[0] = 1.0
    row[1] = x
    row[2] = y
    row[3] = x * y
    row[4] = y * y
    row[5] = x * x
    row[6] = (y * y) * x
    row[7] = (x * x) * y
    row[8] = y**3
    row[9] = x**3
    return row
def gen_matrix(items, num_items=24):
    mat = []
    for i in xrange(0, num_items):
        mat.append(_matrix_row(items[i, 0], items[i, 1]))
    return np.array(mat)
# Generate source data
n = 24
srcdata = np.random.rand(100, 2)
dstdata = np.random.rand(n, 2)
# Determine the coefficients for the solver for the first n components
# The resultant 'solved' matrix will be of shape (10, 2)
srcmat = gen_matrix(srcdata[:n, :], num_items=n)
solved, residuals, rank, s = np.linalg.lstsq(srcmat, dstdata)
# Apply the solution to all the src data
for item in srcdata:
    mrow = matrix_row(item[0], item[1])
    # Obviously, the print statements aren't important
    print 'orig   ', item
    print 'solved ', np.dot(mrow, solved)
    print 'n'

我有很多数据,所以Python中的for循环真的会降低性能。有没有更愚蠢的方法来优化这个?

这是一个矢量化的方法-

# Form an array of shape (10,100) that represents pairwise implementation
# of all those functions using all rows off srcdata
x,y = srcdata.T
a = np.ones((10,100))
a[1] = x
a[2] = y
a[3] = x * y
a[4] = y * y
a[5] = x * x
a[6] = (y * y) * x
a[7] = (x * x) * y
a[8] = y**3
a[9] = x**3
# Finally use dot product in one go to get solutions across all iterations
out = solved.T.dot(a).T

最新更新