Using keras.layers.Normalization for preprocessing, the adapt call hangs

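I am using keras.layers.Normalization to preprocess a CSV dataset returned by make_csv_dataset. Execution hangs at the adapt(ds) call. There is no error output; adapt simply never returns. I tried standardizing the same data with pandas, and it finishes in a few seconds.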

Code to reproduce:

import tensorflow as tf
from tensorflow import keras

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
features = ["sepal-length", "sepal-width", "petal-length", "petal-width"]
label = ["class"]
class_names = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]

def get_data():
    columns = features + label
    fpath = keras.utils.get_file("iris.csv", origin=url)
    ds = tf.data.experimental.make_csv_dataset(
        fpath, header=False, label_name=label[0], column_names=columns,
        batch_size=10, shuffle=True, ignore_errors=True)
    return ds


ds = get_data()
ds_features = ds.map(lambda x, y: tf.stack([x.pop(feature) for feature in features], axis=-1))

norm = keras.layers.Normalization(axis=-1)
norm.adapt(ds_features)

print("adapt completed")

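You must set the num_epochs argument of make_csv_dataset to 1. Its default is None, which makes the dataset repeat indefinitely, so adapt never reaches the end of the data. The documentation for num_epochs states: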

An int specifying the number of times this dataset is repeated. If None, cycles through the dataset forever.
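To see why an endlessly repeating dataset makes adapt hang, here is a minimal pure-Python sketch. It is not Keras's implementation: running_mean_var is a hypothetical stand-in for any one-pass statistics computation, and the toy lists stand in for batches of a dataset. The point is only that such a computation cannot return until its input is exhausted.

```python
import itertools

def running_mean_var(batches):
    """Hypothetical stand-in for adapt: read every batch once,
    then return (mean, population variance)."""
    count, total, total_sq = 0, 0.0, 0.0
    for batch in batches:  # the loop ends only when `batches` is exhausted
        for value in batch:
            count += 1
            total += value
            total_sq += value * value
    mean = total / count
    return mean, total_sq / count - mean * mean

# Toy data: two batches, standing in for one pass (num_epochs=1).
one_epoch = [[1.0, 2.0], [3.0, 4.0]]
mean, var = running_mean_var(iter(one_epoch))
print(mean, var)

# Standing in for num_epochs=None: the same batches repeated forever.
endless = itertools.cycle(one_epoch)
# running_mean_var(endless) would never return.

# Bounding the endless stream to one pass makes it terminate again.
bounded = itertools.islice(endless, len(one_epoch))
bounded_mean, bounded_var = running_mean_var(bounded)
print(bounded_mean, bounded_var)
```

The finite and the bounded streams yield the same statistics; only the unbounded one never finishes.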

Working example:

import tensorflow as tf

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
features = ["sepal-length", "sepal-width", "petal-length", "petal-width"]
label = ["class"]
class_names = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]

def get_data():
    columns = features + label
    fpath = tf.keras.utils.get_file("iris.csv", origin=url)
    # num_epochs=1 makes the dataset finite, so adapt can reach its end.
    ds = tf.data.experimental.make_csv_dataset(
        fpath, header=False, label_name=label[0], column_names=columns,
        num_epochs=1, batch_size=10, shuffle=True, ignore_errors=True)
    return ds


ds = get_data()
ds_feature = ds.map(lambda x, y: tf.stack([x.pop(feature) for feature in features], axis=-1))

norm = tf.keras.layers.Normalization(axis=-1)
norm.adapt(ds_feature)

print("adapt completed")

Output:

adapt completed