Scipy最小平方:將正方形網格擬合到2D的實驗點上。

[英]Scipy leastsq: fitting a square grid to experimental points in 2D


I'm trying to use Scipy leastsq to find the best fit of a "square" grid for a set of measured points coordinates in 2-D (the experimental points are approximately on a square grid).

我試着用Scipy最小平方來找到一個“正方形”網格的最適合的坐標,在二維(實驗點大約在一個正方形網格上)。

The parameters of the grid are pitch (equal for x and y), the center position (center_x and center_y) and rotation (in degree).

網格的參數為間距(x和y相等)、中心位置(center_x和center_y)和旋轉(度)。

I defined an error function calculating the euclidean distance for each pairs of points (experimental vs ideal grid) and taking the mean. I want to minimize this function thorugh leastsq but I get an error.

我定義了一個誤差函數,計算每一對點的歐幾里得距離(實驗vs理想網格)並取平均值。我想把這個函數最小化,但我有一個錯誤。

Here are the function definitions:

下面是函數定義:

import numpy as np
from scipy.optimize import leastsq

def get_spot_grid(shape, pitch, center_x, center_y, rotation=0):
    x_spots, y_spots = np.meshgrid(
             (np.arange(shape[1]) - (shape[1]-1)/2.)*pitch, 
             (np.arange(shape[0]) - (shape[0]-1)/2.)*pitch)
    theta = rotation/180.*np.pi
    x_spots = x_spots*np.cos(theta) - y_spots*np.sin(theta) + center_x
    y_spads = x_spots*np.sin(theta) + y_spots*np.cos(theta) + center_y
    return x_spots, y_spots

def get_mean_distance(x1, y1, x2, y2):
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2).mean()

def err_func(params, xe, ye):
    pitch, center_x, center_y, rotation = params
    x_grid, y_grid = get_spot_grid(xe.shape, pitch, center_x, center_y, rotation)
    return get_mean_distance(x_grid, y_grid, xe, ye)

This are the experimental coordinates:

這是實驗坐標

xe = np.array([ -23.31,  -4.01,  15.44,  34.71, -23.39,  -4.10,  15.28,  34.60, -23.75,  -4.38,  15.07,  34.34, -23.91,  -4.53,  14.82,  34.15]).reshape(4, 4)
ye = np.array([-16.00, -15.81, -15.72, -15.49,   3.29,   3.51,   3.90,   4.02,  22.75,  22.93,  23.18,  23.43,  42.19,  42.35,  42.69,  42.87]).reshape(4, 4)

I try to use leastsq in this way:

我試着用這種方式使用最小平方數:

leastsq(err_func, x0=(19, 12, 5, 0), args=(xe, ye))

but I get the following error:

但我有以下錯誤:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-19-ee91cf6ce7d6> in <module>()
----> 1 leastsq(err_func, x0=(19, 12, 5, 0), args=(xe, ye))

C:\Anaconda\lib\site-packages\scipy\optimize\minpack.pyc in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    369     m = shape[0]
    370     if n > m:
--> 371         raise TypeError('Improper input: N=%s must not exceed M=%s' % (n, m))
    372     if epsfcn is None:
    373         epsfcn = finfo(dtype).eps

TypeError: Improper input: N=4 must not exceed M=1

I can't figure out what's the problem here :(

我不明白這里的問題是什么

2 个解决方案

#1


1  

Since the leastsq function assumes that the err_function return an array of residuals docs and it is a little difficult to write the err_function in this manner why not use another scipy's function - minimize. Then you add your metric - the error function you already have and it works. However, I think there is one more typo in get_spot_grid function (y_spots vs y_spads). The complete code:

因為最小平方函數假定err_function返回一個剩余的文檔數組,因此用這種方式編寫err_function是有點困難的,為什么不使用另一個scipy函數——最小化。然后你再加上你的度規——你已經擁有的誤差函數,它是有效的。但是,我認為get_spot_grid函數中還有一個類型的錯誤(y_spot vs y_spads)。完整的代碼:

import numpy as np
from scipy.optimize import leastsq, minimize

def get_spot_grid(shape, pitch, center_x, center_y, rotation=0):
    x_spots, y_spots = np.meshgrid(
             (np.arange(shape[1]) - (shape[1]-1)/2.)*pitch, 
             (np.arange(shape[0]) - (shape[0]-1)/2.)*pitch)
    theta = rotation/180.*np.pi
    x_spots = x_spots*np.cos(theta) - y_spots*np.sin(theta) + center_x
    y_spots = x_spots*np.sin(theta) + y_spots*np.cos(theta) + center_y
    return x_spots, y_spots


def get_mean_distance(x1, y1, x2, y2):
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2).mean()


def err_func(params, xe, ye):
    pitch, center_x, center_y, rotation = params
    x_grid, y_grid = get_spot_grid(xe.shape, pitch, center_x, center_y, rotation)
    return get_mean_distance(x_grid, y_grid, xe, ye)

xe = np.array([-23.31,  -4.01,  15.44,  34.71, -23.39,  -4.10,  15.28,  34.60, -23.75,  -4.38,  15.07,  34.34, -23.91,  -4.53,  14.82,  34.15]).reshape(4, 4)
ye = np.array([-16.00, -15.81, -15.72, -15.49,   3.29,   3.51,   3.90,   4.02,  22.75,  22.93,  23.18,  23.43,  42.19,  42.35,  42.69,  42.87]).reshape(4, 4)

# leastsq(err_func, x0=(19, 12, 5, 0), args=(xe, ye))
minimize(err_func, x0=(19, 12, 5, 0), args=(xe, ye))

#2


0  

The function passed to leastsq (e.g. err_func) should return an array of values of the same shape as xeand ye -- that is, one residual for each value of xe and ye.

傳遞給最小sq(例如err_func)的函數應該返回與xeand ye相同形狀的數組,即xe和ye的每個值的一個剩余值。

def err_func(params, xe, ye):
    pitch, center_x, center_y, rotation = params
    x_grid, y_grid = get_spot_grid(xe.shape, pitch, center_x, center_y, rotation)
    return get_mean_distance(x_grid, y_grid, xe, ye)

The call to mean() in get_mean_distance is reducing the return value to a single scalar. Keep in mind that the xe and ye passed to err_func are arrays not scalars.

get_mean_distance中的mean()調用將返回值降低到單個標量。請記住,xe和ye傳遞到err_func是數組而不是標量。

The error message

錯誤消息

TypeError: Improper input: N=4 must not exceed M=1

is saying the number of the parameters, 4, should not exceed the number of residuals returned by err_func, 1.

表示參數的個數,4,不應該超過err_func返回的剩余的剩余數。


The program can be made runnable by changing the call to mean() to mean(axis=0) (i.e. take the mean of each column) or mean(axis=1) (i.e. take the mean of each row):

這個程序可以通過改變調用mean()來實現runnable()來表示(axis=0)(即取每一列的平均值)或均值(即取每行的平均值):

def get_mean_distance(x1, y1, x2, y2):
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2).mean(axis=1)

I don't really understand your code well enough to know which it should be. But the idea is that there should be one value for each "point" in xe and ye.

我真的不太理解你的代碼,知道它應該是什么。但是,在xe和ye中,每個“點”應該有一個值。


注意!

本站翻译的文章,版权归属于本站,未经许可禁止转摘,转摘请注明本文地址:https://www.itdaan.com/blog/2014/02/06/729ff571967c5071b132851dfc9cca70.html



 
粤ICP备14056181号  © 2014-2020 ITdaan.com