从 Tensorflow 中的张量中读取文件名

Reading in file names from a tensor in Tensorflow

上下文:我正在尝试制作一个 GAN 以从大型数据集生成图像,并且在加载训练数据时 运行 遇到了 OOM 问题。为了解决这个问题,我试图传入一个文件目录列表,并仅在需要时将它们作为图像读入。

问题:我不知道如何从张量本身解析出文件名。如果有人对如何将张量转换回列表或以某种方式遍历张量有任何见解。或者,如果这是解决此问题的糟糕方法,请告诉我

相关代码片段:

正在生成数据: 注意:make_file_list() returns 我想读入的所有图像的文件名列表

data = make_file_list(base_dir)
train_dataset = tf.data.Dataset.from_tensor_slices(data).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
train(train_dataset, EPOCHS)

训练函数:

def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)

训练步骤:

@tf.function
def train_step(image_files):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    images = [load_img(filepath) for filepath in image_files]

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

错误:

line 250, in train_step  *
        images = [load_img(filepath) for filepath in image_files]

    OperatorNotAllowedInGraphError: Iterating over a symbolic `tf.Tensor` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature

删除 train_step 上的 @tf.function 装饰器。如果你用@tf.function修饰你的train_step,Tensorflow会尝试将train_step里面的Python代码转换成一个执行图,而不是在eager模式下运行。执行图提供了加速,但也对可以执行的运算符施加了一些限制(如错误所述)。

要将 @tf.function 保持在 train_step,您可以先在 train 函数中执行迭代和加载步骤,然后将已加载的图像作为参数传递给 train_step 而不是尝试直接在 train_step

中加载图像
def train(dataset, epochs):
    plot_iteration = []
    gen_loss_l = []
    disc_loss_l = []

    for epoch in range(epochs):
        start = time.time()

    for image_batch in dataset:
        images = [load_img(filepath) for filepath in image_batch ]
        gen_loss, disc_loss = train_step(images)

@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        ....