是否可以将 sklearn.neighbors.KNeighborsClassifier 用于 TensorFlow 会话,即使用 Tensor?

Is possible to use sklearn.neighbors.KNeighborsClassifier into a tensorflow Session i.e with Tensor?

我正在尝试在 Tensorflow 会话中使用 KNN 分类器。

但我收到以下错误:

NotImplementedError: Cannot convert a symbolic Tensor (Const:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported

在会话之外,代码工作正常:

import tensorflow as tf
from sklearn.neighbors import KNeighborsClassifier
features= tf.constant([[1., 1.], [2., 2.],[2., 2.],[2., 2.],[2., 2.],[2., 2.]])
label= tf.constant([[1], [2], [2], [2], [2], [2]])
model = KNeighborsClassifier(n_neighbors=3)

# Train the model using the training sets
model.fit(features,label)

teste = tf.constant([[1., 1.], [2., 2.]])
#Predict Output
predicted= model.predict(teste) # 0:Overcast, 2:Mild
print(predicted)

但我在会话中需要它,这里是一个错误示例代码:

import tensorflow as tf
from sklearn.neighbors import KNeighborsClassifier
@tf.function
def add():
    model = KNeighborsClassifier(n_neighbors=3)


    features= tf.constant([[1., 1.], [2., 2.],[2., 2.],[2., 2.],[2., 2.],[2., 2.]])
    label= tf.constant([[1], [2], [2], [2], [2], [2]])
    # Train the model using the training sets
    model.fit(features,label)
    
    return model

add()

版本:

tf.version.VERSION
'2.6.0'
sklearn.__version__
1.0.1

此代码可能会帮助您解决问题。

import tensorflow as tf
from sklearn.neighbors import KNeighborsClassifier
tf.config.run_functions_eagerly(True)
@tf.function
def add():
    
    model = KNeighborsClassifier(n_neighbors=3)


    features= tf.constant([[1., 1.], [2., 2.],[2., 2.],[2., 2.],[2., 2.],[2., 2.]])
    label= tf.constant([[1], [2], [2], [2], [2], [2]])
    features = features.numpy()
    label = label.numpy()
    # Train the model using the training sets
    model.fit(features,label)
    
    return model

add()

我有 运行 Google Colab 中的代码。将 NumPy 降级为 1.19.5

注:

  • .numpy() 将张量更改为 numpy 数组。
  • Tensorflow 2 有一个配置选项 运行 功能“急切”,这将允许通过 .numpy() 方法获取 Tensor 值。 3rd line of the code [如果没有该行,.numpy() 将无法工作,因为 @tf.function 装饰器出于性能原因禁止执行 tensor.numpy() 等函数。