tensorflow 自定义循环不会在第一个纪元结束并且进度条运行到无限

tensorflow custom loop does not end in first epoch and progress bar runs to infinite

我正在尝试编写一个 tensorflow 自定义训练循环并包含一些 tensorboard 实用程序。

完整代码如下:

import tensorflow as tf
from pathlib import Path
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers
import cv2
from tqdm import tqdm
from os import listdir
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from tqdm import tqdm
from random import shuffle, choice, uniform

from os.path import isdir, dirname, abspath, join
from os import makedirs
from tensorflow.keras.callbacks import (ModelCheckpoint, TensorBoard,
                                        EarlyStopping, LearningRateScheduler)

import io
from natsort import natsorted
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential,Model

from tensorflow.keras.applications import (DenseNet201, InceptionV3, MobileNetV2,
                                           ResNet101, Xception, EfficientNetB7,VGG19, NASNetLarge)
from tensorflow.keras.applications import (densenet, inception_v3, mobilenet_v2,
                                           resnet, xception, efficientnet, vgg19, nasnet)

from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.layers.experimental.preprocessing import Rescaling, Resizing
from tensorflow.keras.utils import Progbar


ROOT = '/content/drive/MyDrive'
data_path = 'cropped/'
train_path = data_path + 'train'
val_path = data_path + 'val'

labels = {v:k for k, v in enumerate(listdir(train_path))}

models = {
    'densenet': DenseNet201,
    'xception': Xception,
    'inceptionv3': InceptionV3,
    'effecientnetb7': EfficientNetB7,
    'vgg19': VGG19,
    'nasnetlarge': NASNetLarge,
    'mobilenetv2': MobileNetV2,
    'resnet': ResNet101
}

# models['densenet']()

preprocess_pipeline = {
    'densenet': densenet.preprocess_input,
    'xception': xception.preprocess_input,
    'inceptionv3': inception_v3.preprocess_input,
    'effecientnetb7': efficientnet.preprocess_input,
    'vgg19': vgg19.preprocess_input,
    'nasnetlarge': nasnet.preprocess_input,
    'mobilenetv2': mobilenet_v2.preprocess_input,
    'resnet': resnet.preprocess_input
}


def configure_for_performance(ds, buffer_size, batch_size):
    """
    Configures caching and prefetching
    """
    ds = ds.cache()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(buffer_size=buffer_size)
    return ds


def generator(tfrecord_file, batch_size, n_data, validation_ratio, reshuffle_each_iteration=False):
    """
    Returns training and validation generators with infinite repeat.
    """
    reader = tf.data.TFRecordDataset(filenames=[tfrecord_file])
    reader.shuffle(n_data, reshuffle_each_iteration=reshuffle_each_iteration)
    AUTOTUNE = tf.data.experimental.AUTOTUNE

    val_size = int(n_data * validation_ratio)
    train_ds = reader.skip(val_size)
    val_ds = reader.take(val_size)

    # Parsing data from tfrecord format.
    train_ds = train_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
    
    # Some data augmentation.
    train_ds = train_ds.map(_augment_function, num_parallel_calls=AUTOTUNE)
    train_ds = configure_for_performance(train_ds, AUTOTUNE, batch_size).repeat()

    val_ds = val_ds.map(_parse_function, num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.map(_augment_function, num_parallel_calls=AUTOTUNE)
    val_ds = configure_for_performance(val_ds, AUTOTUNE, batch_size).repeat() # Is this repeat function the reason behind the issue 
    return train_ds, val_ds

def create_model(optimizer, name='densenet', include_compile=True):
    base_model = models[name](include_top=False, weights='imagenet')
    x = GlobalAveragePooling2D()(base_model.layers[-1].output)
    x = Dense(1024, activation='relu')(x)
    output = Dense(12, activation='softmax')(x)
    model = Model(base_model.inputs, output)

    if include_compile:
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

    return model

现在让我们创建一个模型并初始化:

n_data = len(list(Path(data_path).rglob('*.jpg'))) # Find out how many images are there
validation_ratio = 0.2
val_size = int(n_data * validation_ratio) # Find out validation image size.
train_size = n_data - val_size # And train images size
batch_size = 64
n_epochs = 5

# Tfrecord of images
filename = '/content/drive/MyDrive/cropped_data.tfrecord'

train_ds, val_ds = generator(filename,
                            batch_size=batch_size,
                            n_data=n_data,
                            validation_ratio=validation_ratio,
                            reshuffle_each_iteration=True)

# Tensorboard initialization
model_name = 'xception'

path_to_run = "runs/run_1"
tb_train_path = join(path_to_run, 'logs','train')
tb_test_path = join(path_to_run, 'logs', 'test')

train_writer = tf.summary.create_file_writer(tb_train_path)
test_writer = tf.summary.create_file_writer(tb_test_path)
train_step = test_step = 0

blocks_to_train = []
lr = 1e-4

optimizer = SGD(lr=lr, decay=1e-6,momentum=0.9,nesterov=True)
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
acc_metric = tf.keras.metrics.CategoricalCrossentropy()

# Create the xception model
model = create_model(optimizer, name=model_name, include_compile=False)

metrics = {'acc': 0.0, 'loss': 0.0, 'val_acc': 0.0, 'val_loss': 0.0, 'lr': lr}

这是训练和测试循环:

for epoch in range(n_epochs):
    # Iterate through the training set
    progress_bar = Progbar(train_size, stateful_metrics=list(metrics.keys()))

    for batch_idx, (x, y) in enumerate(train_ds):
        with tf.GradientTape() as tape:
            y_pred = model(x, training=True)
            loss = loss_fn(y, y_pred)

        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        acc_metric.update_state(y, y_pred)
        train_step += 1
        progress_bar.update(batch_idx*batch_size, values=[('acc',acc_metric.result()),
                                       ('loss', loss)])

    with train_writer.as_default():
        tf.summary.scalar("Loss", loss, step=epoch)
        tf.summary.scalar(
            "Accuracy", acc_metric.result(), step=epoch
        )

    # reset accuracy between epochs (and for testing and test)

    acc_metric.reset_states()


    for batch_idx, (x,y) in enumerate(val_ds):
        y_pred = model(x, training=False)
        loss = loss_fn(y, y_pred)
        acc_metric.update_state(y,
                                y_pred)
        confusion += get_confusion_matrix(y, y_pred, class_names=list(labels.keys()))

    with test_writer.as_default():
        tf.summary.scalar("Loss", loss, step=epoch)
        tf.summary.scalar("Accuracy", acc_metric.result(), step=epoch)

    progress_bar.update(train_size, values=[('val_acc', acc_metric.result()), ('val_loss', loss)])

    # reset accuracy between epochs (and for testing and test)
    acc_metric.reset_states()

我修改了代码并删除了一些 tensorboard 实用程序。代码开始训练,但不会在预定义时期结束时停止。我看到进度条一直在不停地显示验证指标。

你们能帮我做一个和 keras.fit 功能完全一样的进度条吗?

谢谢

我发现了长时间训练 epoch 背后的(愚蠢的)原因:

数据由 train_size 训练数据和 val_size 验证数据组成,不考虑批次。例如,训练数据包含 4886 个数据样本,这将是 76 个数据批次(batch_size=64)。

当我使用for batch_idx, (x, y) in enumerate(train_gen):时,我总共有76个批次,但我在循环中错误地循环了4886个批次。

我重写了以下几行:

for epoch in range(n_epochs):
# Iterate through the training set
progress_bar = Progbar(train_size, stateful_metrics=list(metrics.keys()))

train_gen = train_ds.take(train_size//batch_size) # This line

for batch_idx, (x, y) in enumerate(train_gen):

.....


val_gen = val_ds.take(val_size//batch_size)

for batch_idx, (x,y) in enumerate(val_gen):