如何使用 tf.data.Dataset.from_generator 进行批处理?我需要修改生成器吗

how to batch with tf.data.Dataset.from_generator? Do i needto modify generator

我正在使用 batch(8) 函数,它修改形状并添加批次维度,但每批次只获取一张图像。下面是我的代码:-

import cv2
import numpy as np
import os
import tensorflow as tf
import random

folder_path = "./real/"
files = os.listdir(folder_path)

def get_image():
    index = random.randint(0,len(files)-1)
    img = cv2.imread(folder_path+files[index])
    img = cv2.resize(img,(128,128))
    img = img/255.
    #More complex transformation
    yield img

dset = tf.data.Dataset.from_generator(get_image,(tf.float32)).batch(8)

for img in dset:
    print(img.shape)
    break

即使使用 batch(8),输出仍然是 (1, 128, 128, 3)。我是否需要修改生成器以手动创建批处理?还有,如何在tensorflow中包裹在生成器中,使其运行得更快?

因为你的 yield 只需要一张你应该循环 yield 的图片,下面是一个例子

def get_image():
   for file in files:
      img = cv2.imread(folder_path + file)
      img = cv2.resize(img, (128, 128))
      img = img / 255.

      yield img # Your supposed to yield in a loop

dataset = tf.data.Dataset.from_generator(get_image, output_shapes=(128, 128), output_types=(tf.float32))

next(iter(dataset.batch(8))).shape

# TensorShape([8, 128, 128])