scipy.optimize.curve_fit - TypeError('Improper input: N=%s must not exceed M=%s' % (n, m))



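I'm trying to optimize the parameters of a function I wrote that rotates points and then converts them from pixel positions to mm, relative to a specific point. I'm getting TypeError('Improper input: N=%s must not exceed M=%s' % (n, m)), and other questions on SO suggest this means I have more parameters than data points. The data I'm passing in has 27 sets of x, y coordinates. Can anyone see where my code below goes wrong?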

Thanks in advance

from scipy import optimize
import numpy as np
import math
def px_to_mm_v4(coords, cf_x, cf_y, nudge_x, nudge_y, center_x, center_y, rotate_degrees):

    ## set lower left loc
    ll_x = center_x - (127.76/2/cf_x) ## lower left x location in pixels
    ll_y = center_y + (85.47/2/cf_y) ## lower left y location in pixels

    ## unpack coordinates
    x, y = coords

    ## rotate points around center (standard 2-D rotation matrix)
    rotate_radians = math.radians(rotate_degrees)
    x_rotated = center_x + math.cos(rotate_radians) * (x - center_x) - math.sin(rotate_radians) * (y - center_y)
    ## the first term of the y component uses sin, not cos
    y_rotated = center_y + math.sin(rotate_radians) * (x - center_x) + math.cos(rotate_radians) * (y - center_y)

    ## convert px to mm
    x_converted = (x_rotated - ll_x) * cf_x + nudge_x
    y_converted = (ll_y - y_rotated) * cf_y + nudge_y

    ## returns a tuple of two length-27 arrays, not one flat array
    return (x_converted, y_converted)

x_px = np.array([1723,1530,1334,1135,943,747,548,2520,2322,2120,1921,1726,1530,1331,1132,937,741,545,346,349,352,355,358,358,361,361,148])
y_px = np.array([596,791,986,1176,1373,1569,1769,1973,1967,1967,1964,1962,1964,1967,1967,1967,1962,1964,1967,1769,1569,1373,1178,986,791,602,2162])
x_mm = np.array([80,70,60,50,40,30,20,120,110,100,90,80,70,60,50,40,30,20,10,10,10,10,10,10,10,10,0])
y_mm = np.array([80,70,60,50,40,30,20,10,10,10,10,10,10,10,10,10,10,10,10,20,30,40,50,60,70,80,0])
test_coords_tup = (x_px,y_px)
points_to_fit_tup = (x_mm,y_mm)
cf_x_test = 0.05072
cf_y_test = 0.05076
nudge_x_test = -2.2
nudge_y_test = 2.1
center_x_test = 1374
center_y_test = 1290
rotate_degrees_test = 1.4
params0 = [cf_x_test,cf_y_test,nudge_x_test,nudge_y_test,center_x_test,center_y_test,rotate_degrees_test]
popt, pcov = optimize.curve_fit(px_to_mm_v4, test_coords_tup, points_to_fit_tup, p0=params0)
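Side note on the rotation step: it is meant to be the standard 2-D rotation about (center_x, center_y), where the y component is sin·dx + cos·dy. A minimal sketch to sanity-check that formula in isolation (rotate_about is a hypothetical helper, not part of the original code; whether a positive angle looks clockwise or counter-clockwise on screen depends on the image's y-axis direction):

```python
import numpy as np


def rotate_about(x, y, cx, cy, degrees):
    """Rotate points (x, y) about (cx, cy) by `degrees` using the standard
    2-D rotation matrix [[cos, -sin], [sin, cos]] applied to the offsets."""
    t = np.radians(degrees)
    dx = np.asarray(x, dtype=float) - cx
    dy = np.asarray(y, dtype=float) - cy
    x_rot = cx + np.cos(t) * dx - np.sin(t) * dy
    y_rot = cy + np.sin(t) * dx + np.cos(t) * dy
    return x_rot, y_rot


# The center itself should not move.
print(rotate_about([1374.0], [1290.0], 1374.0, 1290.0, 1.4))
# A 90-degree turn should map the offset (1, 0) to (0, 1).
print(rotate_about([2.0], [1.0], 1.0, 1.0, 90.0))

# A true rotation preserves each point's distance from the center
# (a few points taken from x_px / y_px).
x_px = np.array([1723.0, 548.0, 148.0])
y_px = np.array([596.0, 1769.0, 2162.0])
xr, yr = rotate_about(x_px, y_px, 1374.0, 1290.0, 1.4)
print(np.allclose(np.hypot(xr - 1374, yr - 1290), np.hypot(x_px - 1374, y_px - 1290)))
```

Writing cos in both terms of the y component fails the distance check, which is an easy way to catch that kind of slip.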

Edit: running the code returns the following error:

Traceback (most recent call last):
  File "SO_example.py", line 48, in <module>
    popt, pcov = optimize.curve_fit(px_to_mm_v4, test_coords_tup, points_to_fit_tup, p0=params0)
  File "//anaconda3/lib/python3.7/site-packages/scipy/optimize/minpack.py", line 784, in curve_fit
    res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
  File "//anaconda3/lib/python3.7/site-packages/scipy/optimize/minpack.py", line 414, in leastsq
    raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m))
TypeError: Improper input: N=7 must not exceed M=2
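A quick way to see where M=2 comes from. This is a sketch that assumes, as I understand the SciPy version in the traceback to do, that curve_fit converts ydata to an array and that leastsq reads M from the first axis of the residual array rather than from the total number of values:

```python
import numpy as np

x_mm = np.array([80, 70, 60, 50, 40, 30, 20, 120, 110, 100, 90, 80, 70, 60,
                 50, 40, 30, 20, 10, 10, 10, 10, 10, 10, 10, 10, 0])
y_mm = np.array([80, 70, 60, 50, 40, 30, 20, 10, 10, 10, 10, 10, 10, 10,
                 10, 10, 10, 10, 10, 20, 30, 40, 50, 60, 70, 80, 0])

# Passing the tuple (x_mm, y_mm) as ydata yields a 2-D array, not 54 values.
ydata = np.asarray((x_mm, y_mm), dtype=float)
print(ydata.shape)

# The model also returns a tuple of two arrays, so the residuals are (2, 27)
# too, and the first axis gives M = 2 while N = len(p0) = 7.
p0 = [0.05072, 0.05076, -2.2, 2.1, 1374, 1290, 1.4]
print(len(p0), ydata.shape[0])
```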

From the leastsq docs:

func : callable
    Should take at least one (possibly length N vector) argument and
    returns M floating point numbers. It must not return NaNs or fitting
    might fail. M must be greater than or equal to N.
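Read against that description, one way around the error is to make the model return a single 1-D array and pass ydata in the same layout, so that M = 2 × 27 = 54 ≥ N = 7. A minimal sketch, assuming the x values are simply stacked in front of the y values (px_to_mm_flat is a hypothetical name; its body is px_to_mm_v4 with the rotation written as a standard rotation matrix):

```python
import math

import numpy as np
from scipy import optimize


def px_to_mm_flat(coords, cf_x, cf_y, nudge_x, nudge_y, center_x, center_y, rotate_degrees):
    """Same model as px_to_mm_v4, but returns one 1-D array: all x (mm), then all y (mm)."""
    ll_x = center_x - (127.76 / 2 / cf_x)  # lower-left x location in pixels
    ll_y = center_y + (85.47 / 2 / cf_y)   # lower-left y location in pixels
    x, y = coords                          # unpack the (x_px, y_px) pair passed as xdata
    t = math.radians(rotate_degrees)
    x_rot = center_x + math.cos(t) * (x - center_x) - math.sin(t) * (y - center_y)
    y_rot = center_y + math.sin(t) * (x - center_x) + math.cos(t) * (y - center_y)
    x_conv = (x_rot - ll_x) * cf_x + nudge_x
    y_conv = (ll_y - y_rot) * cf_y + nudge_y
    # One flat vector of 2 * 27 = 54 values, so M = 54 >= N = 7.
    return np.concatenate([x_conv, y_conv])


x_px = np.array([1723, 1530, 1334, 1135, 943, 747, 548, 2520, 2322, 2120, 1921, 1726, 1530, 1331,
                 1132, 937, 741, 545, 346, 349, 352, 355, 358, 358, 361, 361, 148])
y_px = np.array([596, 791, 986, 1176, 1373, 1569, 1769, 1973, 1967, 1967, 1964, 1962, 1964, 1967,
                 1967, 1967, 1962, 1964, 1967, 1769, 1569, 1373, 1178, 986, 791, 602, 2162])
x_mm = np.array([80, 70, 60, 50, 40, 30, 20, 120, 110, 100, 90, 80, 70, 60,
                 50, 40, 30, 20, 10, 10, 10, 10, 10, 10, 10, 10, 0])
y_mm = np.array([80, 70, 60, 50, 40, 30, 20, 10, 10, 10, 10, 10, 10, 10,
                 10, 10, 10, 10, 10, 20, 30, 40, 50, 60, 70, 80, 0])

# ydata must use the same layout as the model output: x values first, then y values.
ydata = np.concatenate([x_mm, y_mm])
p0 = [0.05072, 0.05076, -2.2, 2.1, 1374, 1290, 1.4]

popt, pcov = optimize.curve_fit(px_to_mm_flat, (x_px, y_px), ydata, p0=p0)
print(popt)
```

Two things worth checking with this layout. First, the concatenation order has to match between the model and ydata. Second, in this model center_x and center_y only shift the output by a constant that nudge_x and nudge_y can absorb as well, so those four parameters are not separately identifiable and pcov may come back with very large or infinite entries (possibly with an OptimizeWarning).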
