Tensorflow:如何在 Estimator 中使用来自生成器的数据集

Tensorflow: How to use dataset from generator in Estimator

试图建立简单的模型只是为了弄清楚如何处理 tf.data.Dataset.from_generator。我不明白如何设置 output_shapes 参数。我尝试了几种组合,包括未指定它,但由于张量的形状不匹配,仍然收到一些错误。这个想法只是产生两个带有 SIZE = 10 和 运行 线性回归的 numpy 数组。这是代码:

SIZE = 10


def _generator():
    feats = np.random.normal(0, 1, SIZE)
    labels = np.random.normal(0, 1, SIZE)
    yield feats, labels


def input_func_gen():
    shapes = (SIZE, SIZE)
    dataset = tf.data.Dataset.from_generator(generator=_generator,
                                             output_types=(tf.float32, tf.float32),
                                             output_shapes=shapes)
    dataset = dataset.batch(10)
    dataset = dataset.repeat(20)
    iterator = dataset.make_one_shot_iterator()
    features_tensors, labels = iterator.get_next()
    features = {'x': features_tensors}
    return features, labels


def train():
    x_col = tf.feature_column.numeric_column(key='x', )
    es = tf.estimator.LinearRegressor(feature_columns=[x_col])
    es = es.train(input_fn=input_func_gen)

另一个问题是是否可以使用此功能为 tf.feature_column.crossed_column 的特征列提供数据?总体目标是在批量训练中使用 Dataset.from_generator 功能,在数据不适合内存的情况下,数据从数据库加载到块上。高度赞赏所有意见和示例。

谢谢!

tf.data.Dataset.from_generator() 的可选 output_shapes 参数允许您指定生成器生成的值的形状。其类型有两个约束,定义了如何指定它:

  • output_shapes参数是一个"nested structure"(例如元组、元组的元组、元组的字典等),必须匹配值的结构(s) 由您的发电机产生。

    在您的程序中,_generator() 包含语句 yield feats, labels。因此 "nested structure" 是两个元素的元组(每个数组一个)。

  • output_shapes 结构的每个组件都应与相应张量的形状相匹配。数组的形状始终是 元组 维度。 (a tf.Tensor 的形状更一般:见 的讨论。)让我们看看 feats 的实际形状:

    >>> SIZE = 10
    >>> feats = np.random.normal(0, 1, SIZE)
    >>> print feats.shape
    (10,)
    

因此 output_shapes 参数应该是一个 2 元素元组,其中每个元素是 (SIZE,):

shapes = ((SIZE,), (SIZE,))
dataset = tf.data.Dataset.from_generator(generator=_generator,
                                         output_types=(tf.float32, tf.float32),
                                         output_shapes=shapes)

最后,您需要向 tf.feature_column.numeric_column() and tf.estimator.LinearRegressor() API 提供更多关于形状的信息:

x_col = tf.feature_column.numeric_column(key='x', shape=(SIZE,))
es = tf.estimator.LinearRegressor(feature_columns=[x_col],
                                  label_dimension=10)