Python numpy浮点数组精度
Python numpy floating point array precision
我正在尝试使用 Pegasos 小批量算法(如图 2 所示)解决 SVM 优化问题 link:http://www.cs.huji.ac.il/~shais/papers/ShalevSiSrCo10.pdf
#X: m*n matrix with m examples and n features per example (m=4000 and n=784 in my case), Y: m length vector containing 1 or -1 for each example, l: lambda as given in algorithm (l=1 in my code), itr: number of iterations, k: size of batch (100) in my case
def pegasos(X,Y,l,n,m,itr,k):
w = np.zeros((1,n),dtype=np.float32)
print m, n
diff = 0.0
for t in range(1,itr+1):
A = random.sample(range(1,m),k)
total = np.zeros((1,n),dtype=np.float32)
eta = 1/(l*t)
for i in A:
x = X[i]
y = Y[i]
p = y*(np.dot(w,x.T))
if p < 1:
p1 = y*x
total = np.add(total,p1)
#update rule
w = np.add((w*(1-(1/t))) , (eta*total*(1/k)))
return w
我的数据集是这样的,当我的变量 total 被计算时,我得到的大部分是 0,但是有一些值在 10^(-1) 的顺序到 10^(-5)。一旦在更新规则中将总数乘以 (eta/k),所有值都变为 0。因此在每次迭代中,我获得的 w 都是 0。这不应该是这种情况。我已经尝试过提高浮点数精度的方法,但它们似乎根本不起作用。当我使用基本的 Pegasos 算法时(如上面 link 中的图 1 所示),我没有遇到任何问题,因此我的数据集并不完全奇怪。
非常感谢有关此问题的任何帮助:)
如果需要精度,应该使用np.float64
(普通浮点精度double
)。
如果您正在使用 Python 2,则您在 (1/t)
、(1/k)
和 (1/l)
中使用整数除法。写成1.0/
,强制进行浮点除法
我正在尝试使用 Pegasos 小批量算法(如图 2 所示)解决 SVM 优化问题 link:http://www.cs.huji.ac.il/~shais/papers/ShalevSiSrCo10.pdf
#X: m*n matrix with m examples and n features per example (m=4000 and n=784 in my case), Y: m length vector containing 1 or -1 for each example, l: lambda as given in algorithm (l=1 in my code), itr: number of iterations, k: size of batch (100) in my case
def pegasos(X,Y,l,n,m,itr,k):
w = np.zeros((1,n),dtype=np.float32)
print m, n
diff = 0.0
for t in range(1,itr+1):
A = random.sample(range(1,m),k)
total = np.zeros((1,n),dtype=np.float32)
eta = 1/(l*t)
for i in A:
x = X[i]
y = Y[i]
p = y*(np.dot(w,x.T))
if p < 1:
p1 = y*x
total = np.add(total,p1)
#update rule
w = np.add((w*(1-(1/t))) , (eta*total*(1/k)))
return w
我的数据集是这样的,当我的变量 total 被计算时,我得到的大部分是 0,但是有一些值在 10^(-1) 的顺序到 10^(-5)。一旦在更新规则中将总数乘以 (eta/k),所有值都变为 0。因此在每次迭代中,我获得的 w 都是 0。这不应该是这种情况。我已经尝试过提高浮点数精度的方法,但它们似乎根本不起作用。当我使用基本的 Pegasos 算法时(如上面 link 中的图 1 所示),我没有遇到任何问题,因此我的数据集并不完全奇怪。 非常感谢有关此问题的任何帮助:)
如果需要精度,应该使用np.float64
(普通浮点精度double
)。
如果您正在使用 Python 2,则您在 (1/t)
、(1/k)
和 (1/l)
中使用整数除法。写成1.0/
,强制进行浮点除法