TypeError with scikit-learn's BallTree
I have this toy example using scipy's cKDTree, which works very well, and I want to write a similar piece of code with scikit-learn's BallTree:
import numpy as np
from scipy import spatial

min_neighbors = 3
x, y = np.mgrid[0:5, 0:5]
grid_x, grid_y = np.mgrid[1:6, 1:6]
points = np.c_[x.ravel(), y.ravel()]
grid_points = np.c_[grid_x.ravel(), grid_y.ravel()]
tree = spatial.cKDTree(points)
indices = tree.query_ball_point(grid_points, r=1)
for idx, (matches, grid) in enumerate(zip(indices, grid_points)):
    if len(matches) >= min_neighbors:
        x1, y1 = tree.data[matches].T
When I write a similar toy example with BallTree, as follows:
import numpy as np
from sklearn.neighbors.ball_tree import BallTree
from sklearn.neighbors import NearestNeighbors
import sys

def main():
    min_neighbors = 3
    x,y = np.mgrid[0:5,0:5]
    grid_x,grid_y = np.mgrid[1:6,1:6]
    points = np.c_[x.ravel(),y.ravel()]
    grid_points = np.c_[grid_x.ravel(),grid_y.ravel()]
    bt = BallTree(points,leaf_size=1, metric='haversine')
    indices = bt.query_radius(grid_points,1)
    for idx,(matches,grid) in enumerate(zip(indices,grid_points)):
        #print(matches)
        if len(matches) >= min_neighbors:
            x1,y1 = bt.data[matches].T

main()
I get the following error:
Traceback (most recent call last):
  File "testballtree.py", line 25, in <module>
    main()
  File "testballtree.py", line 23, in main
    x1,y1 = bt.data[matches].T
  File "stringsource", line 406, in View.MemoryView.memoryview.__getitem__
  File "stringsource", line 746, in View.MemoryView.memview_slice
TypeError: only integer scalar arrays can be converted to a scalar index
What is the correct way to access the data attribute of scikit-learn's BallTree in the same way as with scipy?
scikit-learn version is 0.19.2
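It looks like the documentation is wrong: bt.data is a memoryview, not a numpy array. It should probably either be a numpy array or be private. In the meantime, you can use points instead to fix your snippet. I have opened https://github.com/scikit-learn/scikit-learn/issues/11728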
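A minimal sketch of that fix, assuming the goal is simply to look up the coordinates of the matched training points: index the original points array with the indices returned by query_radius. Two choices here are my own assumptions rather than part of the question: the import comes from sklearn.neighbors directly, and metric='haversine' is dropped in favour of the default metric so that the result is comparable to the cKDTree version (haversine treats the two columns as latitude and longitude in radians, which does not seem to be what the grid coordinates mean). The neighbors dict is a hypothetical container added only to keep the results.

```python
import numpy as np
from sklearn.neighbors import BallTree

min_neighbors = 3
x, y = np.mgrid[0:5, 0:5]
grid_x, grid_y = np.mgrid[1:6, 1:6]
points = np.c_[x.ravel(), y.ravel()]
grid_points = np.c_[grid_x.ravel(), grid_y.ravel()]

# Assumption: default metric instead of 'haversine', to mirror cKDTree.
bt = BallTree(points, leaf_size=1)

# query_radius returns one integer index array per query point.
indices = bt.query_radius(grid_points, r=1)

# Hypothetical container: query index -> coordinates of its neighbours.
neighbors = {}
for idx, matches in enumerate(indices):
    if len(matches) >= min_neighbors:
        # Index the original array rather than bt.data (a memoryview).
        x1, y1 = points[matches].T
        neighbors[idx] = (x1, y1)
```

This works because the indices returned by the tree refer to rows of the array the tree was built from, so points and the tree's stored data can be indexed interchangeably.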
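BallTree.data is a view or a copy of the training data, so in the example above you can use the points array directly. The docstring of the data attribute is indeed incorrect: it is a memoryview, not an array. You can convert it back to a numpy array with numpy.asarray(bt.data).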
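As a rough illustration of the conversion, the sketch below wraps bt.data with numpy.asarray and then uses an integer index array on the result, which is the operation that fails on the raw memoryview. The matches array is a made-up set of indices used only for demonstration, and the default metric is assumed.

```python
import numpy as np
from sklearn.neighbors import BallTree

x, y = np.mgrid[0:5, 0:5]
points = np.c_[x.ravel(), y.ravel()]
bt = BallTree(points, leaf_size=1)

# Wrap the memoryview exposed by the tree in a regular numpy array.
data = np.asarray(bt.data)

# Hypothetical indices, standing in for one entry of query_radius output.
matches = np.array([0, 1, 5])

# Fancy indexing now works as it does with cKDTree's tree.data.
x1, y1 = data[matches].T
```

Note that the tree stores its data as floating point, so data holds floats even though points was built from integers.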