Reading in file names from a tensor in TensorFlow

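Context: I am trying to build a GAN that generates images from a large dataset, and I ran into OOM issues when loading the training data. To work around this, I am trying to pass in a list of file paths and read them in as images only when they are needed.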



Generating the data (note: make_file_list() returns a list of the file names of all the images I want to read in):

data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)


def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)


def train_step(image_files):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = [load_img(filepath) for filepath in image_files]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)


line 250, in train_step  *
        images = [load_img(filepath) for filepath in image_files]

    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature

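Remove the @tf.function decorator from train_step. If you decorate train_step with @tf.function, TensorFlow tries to convert the Python code inside train_step into an execution graph instead of running it in eager mode. Execution graphs give a speedup, but they also restrict which operations can be used, as the error says. Here, image_files is a symbolic tensor while the graph is being built, so the list comprehension cannot iterate over it.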
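A minimal sketch of that restriction, assuming TensorFlow 2.x. It does not use the question's load_img; the functions `lengths_eager` and `lengths_graph` and the file names are hypothetical stand-ins that only measure string lengths, so the behaviour can be seen without any image files:

```python
import tensorflow as tf

# Placeholder file names; nothing is read from disk in this sketch.
paths = tf.constant(["a.png", "b.png", "c.png"])

def lengths_eager(batch):
    # Eager mode: `batch` holds concrete values, so plain Python
    # iteration over its first dimension works.
    return [int(tf.strings.length(p)) for p in batch]

@tf.function
def lengths_graph(batch):
    # Graph mode: `batch` is a symbolic tensor while the function is traced.
    # The list comprehension runs as ordinary Python (as in the question's
    # traceback), so iterating over `batch` is rejected.
    return [tf.strings.length(p) for p in batch]

eager_result = lengths_eager(paths)

try:
    lengths_graph(paths)
    graph_failed = False
except tf.errors.OperatorNotAllowedInGraphError:
    graph_failed = True
```

The same comprehension succeeds in the undecorated function and fails in the decorated one, which is why either removing the decorator or moving the iteration out of the decorated function avoids the error.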

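To keep @tf.function on train_step, you can do the iteration and loading in your train function first, and then pass the loaded images to train_step as an argument instead of trying to load them inside train_step: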

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            # Runs eagerly, so iterating over the batch of file names is allowed
            images = [load_img(filepath) for filepath in image_batch]
            gen_loss, disc_loss = train_step(images)

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
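The answer above loads images in the Python training loop. A different, commonly used approach is to let the tf.data pipeline decode the files with Dataset.map, so each batch reaching train_step already contains image tensors and the decorator can stay. The following is only a sketch under assumptions that are not in the original post: the files are PNGs, `IMG_SIZE` and `BATCH_SIZE` are placeholder values, `load_image` is a hypothetical helper, and the tiny generated files and the stub `train_step` stand in for the real dataset and GAN step.

```python
import os
import tempfile

import tensorflow as tf

IMG_SIZE = 8     # placeholder image side length (assumption)
BATCH_SIZE = 2   # placeholder batch size (assumption)

# Write a few tiny PNG files so the sketch is self-contained.
tmp_dir = tempfile.mkdtemp()
file_list = []
for i in range(4):
    path = os.path.join(tmp_dir, f"img_{i}.png")
    pixels = tf.cast(tf.fill([IMG_SIZE, IMG_SIZE, 3], i * 10), tf.uint8)
    tf.io.write_file(path, tf.io.encode_png(pixels))
    file_list.append(path)

def load_image(path):
    # Hypothetical helper: runs inside the tf.data graph on one file name.
    raw = tf.io.read_file(path)
    image = tf.io.decode_png(raw, channels=3)
    image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
    return (image - 127.5) / 127.5  # scale pixel values to [-1, 1]

# Only file names are held in memory; files are decoded as batches are drawn.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(file_list)
    .shuffle(len(file_list))
    .map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

@tf.function
def train_step(images):
    # Stand-in for the real GAN step: it only consumes the decoded batch.
    return tf.reduce_mean(images)

batch_shapes = []
for image_batch in train_dataset:
    train_step(image_batch)
    batch_shapes.append(image_batch.shape.as_list())
```

Shuffling before the map step means the shuffle buffer holds short strings rather than decoded images, which keeps memory use low for a large dataset.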