Scipy 的 leastsq 与复数
Scipy's leastsq with complex numbers
我正在尝试对复数使用 scipy.optimize.leastsq。我知道已经有一些关于此的问题,但我仍然无法让我的简单示例正常工作,它抱怨从复数转换为实数。
如果我做对了下面的解决方案应该是 x=[1+1j,2j]
:
import numpy as np
from scipy.optimize import leastsq
def cost_cpl(x,A,b):
return np.abs(np.dot(A,x)-b)
A=np.array([[1,1],[2j,0]],dtype=np.complex128)
b=np.array([1+3j,-2+2j],dtype=np.complex128)
x,r=leastsq(cost_cpl,np.array([0+0j,0+0j]),args=(A,b))
print x
print r
但我得到了
TypeError: Cannot cast array data from dtype('complex128') to dtype('float64') according to the rule 'safe'
编辑:如果我将第一个猜测从 np.array([0+0j,0+0j])
更改为 np.array([0,0])
,该函数运行但我得到错误的答案(真实答案)。
由于leastsq()
只能接受实数,需要用.view()
方法进行实数组和复数组的转换
import numpy as np
from scipy.optimize import leastsq
def cost_cpl(x, A, b):
return (np.dot(A, x.view(np.complex128)) - b).view(np.double)
A = np.array([[1,1],[2j,0]],dtype=np.complex128)
b = np.array([1+3j,-2+2j],dtype=np.complex128)
init = np.array([0.0, 0.0, 0.0, 0.0])
x, r = leastsq(cost_cpl, init, args=(A, b))
print(x.view(np.complex128))
输出:
array([ 1.00000000e+00+1.j, 4.96506831e-16+2.j])
我正在尝试对复数使用 scipy.optimize.leastsq。我知道已经有一些关于此的问题,但我仍然无法让我的简单示例正常工作,它抱怨从复数转换为实数。
如果我做对了下面的解决方案应该是 x=[1+1j,2j]
:
import numpy as np
from scipy.optimize import leastsq
def cost_cpl(x,A,b):
return np.abs(np.dot(A,x)-b)
A=np.array([[1,1],[2j,0]],dtype=np.complex128)
b=np.array([1+3j,-2+2j],dtype=np.complex128)
x,r=leastsq(cost_cpl,np.array([0+0j,0+0j]),args=(A,b))
print x
print r
但我得到了
TypeError: Cannot cast array data from dtype('complex128') to dtype('float64') according to the rule 'safe'
编辑:如果我将第一个猜测从 np.array([0+0j,0+0j])
更改为 np.array([0,0])
,该函数运行但我得到错误的答案(真实答案)。
由于leastsq()
只能接受实数,需要用.view()
方法进行实数组和复数组的转换
import numpy as np
from scipy.optimize import leastsq
def cost_cpl(x, A, b):
return (np.dot(A, x.view(np.complex128)) - b).view(np.double)
A = np.array([[1,1],[2j,0]],dtype=np.complex128)
b = np.array([1+3j,-2+2j],dtype=np.complex128)
init = np.array([0.0, 0.0, 0.0, 0.0])
x, r = leastsq(cost_cpl, init, args=(A, b))
print(x.view(np.complex128))
输出:
array([ 1.00000000e+00+1.j, 4.96506831e-16+2.j])