使用 StratifiedKFold 在 Auto-Keras 中训练时出现的 NumPy 数组值错误
NumPy array value error from training in Auto-Keras with StratifiedKFold
背景
我的情绪分析研究涉及各种数据集。最近我遇到了一个数据集,不知何故我无法成功训练。我主要使用 .CSV
文件格式的开放数据,因此 Pandas
和 NumPy
被大量使用。
在我的研究过程中,其中一种方法是尝试集成自动化机器学习 (AutoML
),我选择使用的库是 Auto-Keras
,主要使用其 TextClassifier()
包装函数实现AutoML
.
主要问题
我已经通过官方文档验证,TextClassifier()
以 NumPy 数组的格式获取数据。但是,当我将数据加载到 Pandas DataFrame
并在我需要训练的列上使用 .to_numpy()
时,以下错误不断显示:
ValueError Traceback (most recent call last)
<ipython-input-13-1444bf2a605c> in <module>()
16 clf = ak.TextClassifier(overwrite=True, max_trials=2)
17
---> 18 clf.fit(x_train, y_train, epochs=3, callbacks=cbs)
19
20
ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type float).
错误相关代码扇区
我使用 .drop()
删除不需要的 Pandas DataFrame
列的扇区,并使用 Pandas
的 to_numpy()
函数将需要的列转换为 NumPy
数组] 已提供。
df_src = pd.read_csv(get_data)
df_src = df_src.drop(columns=["Name", "Cast", "Plot", "Direction",
"Soundtrack", "Acting", "Cinematography"])
df_src = df_src.reset_index(drop=True)
X = df_src["Review"].to_numpy()
Y = df_src["Overall Sentiment"].to_numpy()
print(X, "\n")
print("\n", Y)
main错误代码部分,我在这里执行StratifedKFold()
,同时使用TextClassifier()
训练和测试模型。
fold = 0
for train, test in skf.split(X, Y):
fold += 1
print(f"Fold #{fold}\n")
x_train = X[train]
y_train = Y[train]
x_test = X[test]
y_test = Y[test]
cbs = [tf.keras.callbacks.EarlyStopping(patience=3)]
clf = ak.TextClassifier(overwrite=True, max_trials=2)
# The line where it indicated the error.
clf.fit(x_train, y_train, epochs=3, callbacks=cbs)
pred = clf.predict(x_test) # result data type is in lists of `string`
ceval = clf.evaluate(x_test, y_test)
metrics_test = metrics.classification_report(y_test, np.array(list(pred), dtype=int))
print(metrics_test, "\n")
print(f"Fold #{fold} finished\n")
补充
我正在通过 Google Colab
分享与错误相关的完整代码,你可以帮助我 diagnose here。
编辑笔记
我已经尝试过可能的解决方案,例如:
x_train = np.asarray(x_train).astype(np.float32)
y_train = np.asarray(y_train).astype(np.float32)
或
x_train = tf.data.Dataset.from_tensor_slices((x_train,))
y_train = tf.data.Dataset.from_tensor_slices((y_train,))
然而,问题依然存在。
其中一个字符串等于 nan
。只需删除此条目和相应的标签即可。
背景
我的情绪分析研究涉及各种数据集。最近我遇到了一个数据集,不知何故我无法成功训练。我主要使用 .CSV
文件格式的开放数据,因此 Pandas
和 NumPy
被大量使用。
在我的研究过程中,其中一种方法是尝试集成自动化机器学习 (AutoML
),我选择使用的库是 Auto-Keras
,主要使用其 TextClassifier()
包装函数实现AutoML
.
主要问题
我已经通过官方文档验证,TextClassifier()
以 NumPy 数组的格式获取数据。但是,当我将数据加载到 Pandas DataFrame
并在我需要训练的列上使用 .to_numpy()
时,以下错误不断显示:
ValueError Traceback (most recent call last)
<ipython-input-13-1444bf2a605c> in <module>()
16 clf = ak.TextClassifier(overwrite=True, max_trials=2)
17
---> 18 clf.fit(x_train, y_train, epochs=3, callbacks=cbs)
19
20
ValueError: Failed to convert a NumPy array to a Tensor (Unsupported object type float).
错误相关代码扇区
我使用 .drop()
删除不需要的 Pandas DataFrame
列的扇区,并使用 Pandas
的 to_numpy()
函数将需要的列转换为 NumPy
数组] 已提供。
df_src = pd.read_csv(get_data)
df_src = df_src.drop(columns=["Name", "Cast", "Plot", "Direction",
"Soundtrack", "Acting", "Cinematography"])
df_src = df_src.reset_index(drop=True)
X = df_src["Review"].to_numpy()
Y = df_src["Overall Sentiment"].to_numpy()
print(X, "\n")
print("\n", Y)
main错误代码部分,我在这里执行StratifedKFold()
,同时使用TextClassifier()
训练和测试模型。
fold = 0
for train, test in skf.split(X, Y):
fold += 1
print(f"Fold #{fold}\n")
x_train = X[train]
y_train = Y[train]
x_test = X[test]
y_test = Y[test]
cbs = [tf.keras.callbacks.EarlyStopping(patience=3)]
clf = ak.TextClassifier(overwrite=True, max_trials=2)
# The line where it indicated the error.
clf.fit(x_train, y_train, epochs=3, callbacks=cbs)
pred = clf.predict(x_test) # result data type is in lists of `string`
ceval = clf.evaluate(x_test, y_test)
metrics_test = metrics.classification_report(y_test, np.array(list(pred), dtype=int))
print(metrics_test, "\n")
print(f"Fold #{fold} finished\n")
补充
我正在通过 Google Colab
分享与错误相关的完整代码,你可以帮助我 diagnose here。
编辑笔记
我已经尝试过可能的解决方案,例如:
x_train = np.asarray(x_train).astype(np.float32)
y_train = np.asarray(y_train).astype(np.float32)
或
x_train = tf.data.Dataset.from_tensor_slices((x_train,))
y_train = tf.data.Dataset.from_tensor_slices((y_train,))
然而,问题依然存在。
其中一个字符串等于 nan
。只需删除此条目和相应的标签即可。