Mean_iou in tensorflow not updating/resulting in correct value

I implemented a version of U-Net in TensorFlow to identify buildings in satellite images. The implementation works and is producing promising classification results. All metrics appear to behave correctly except mean_iou: regardless of the hyperparameters or the images selected from the dataset, mean_iou is always the same value, identical to about 15 decimal places after every epoch.

Precision and recall are much higher than mean_iou, which is expected, but the fact that mean_iou never changes suggests something is not working as intended.

Since I am fairly new to TensorFlow, the mistake could be something else entirely, but I am here to learn. Any feedback is much appreciated.

Here is the relevant training code, followed by the printed output.

import numpy as np
import tensorflow as tf
from unet_model import build_unet
from data import load_dataset, tf_dataset
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping

model_types = ['segnet-master', 'unet-master', 'simpler', 'even-simpler']

if __name__ == "__main__":
    """ Hyperparamaters """
    dataset_path = "building-segmentation"
    input_shape = (64, 64, 3)
    batch_size = 20
    model_type = 3  # index into model_types
    epochs = 5
    res = 64
    lr = 1e-3
    model_path = f"unet_models/unet_{epochs}_epochs_{res}.h5"
    csv_path = f"csv/data_unet_{epochs}_{res}.csv"

    """ Load the dataset """
    (train_images, train_masks), (val_images, val_masks) = load_dataset(dataset_path)

    train_dataset = tf_dataset(train_images, train_masks, batch=batch_size)
    val_dataset = tf_dataset(val_images, val_masks, batch=batch_size)

    model = build_unet(input_shape)

    model.compile(
        loss="binary_crossentropy",
        optimizer=tf.keras.optimizers.Adam(lr),
        metrics=[
            tf.keras.metrics.MeanIoU(num_classes=2),
            tf.keras.metrics.IoU(num_classes=2, target_class_ids=[0]),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.Precision()
        ]
    )

    callbacks = [
        ModelCheckpoint(model_path, monitor="val_loss", verbose=1),
        ReduceLROnPlateau(monitor="val_loss", patience=10, factor=0.1, verbose=1),
        CSVLogger(csv_path),
        EarlyStopping(monitor="val_loss", patience=10)
    ]

    train_steps = len(train_images)//batch_size
    if len(train_images) % batch_size != 0:
        train_steps += 1

    test_steps = len(val_images)//batch_size
    if len(val_images) % batch_size != 0:
        test_steps += 1

    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        steps_per_epoch=train_steps,
        validation_steps=test_steps,
        callbacks=callbacks
    )

epoch loss lr mean_io_u precision recall val_loss val_mean_io_u val_precision val_recall
0 0.41137945652008057 0.001 0.37184661626815796 0.695444643497467 0.5243006944656372 0.87176513671875 0.37157535552978516 0.38247567415237427 0.9118495583534241
1 0.3461640477180481 0.001 0.37182655930519104 0.7579150795936584 0.6075601577758789 0.3907579183578491 0.37157535552978516 0.8406943082809448 0.5024654865264893
2 0.3203786611557007 0.001 0.37182655930519104 0.7694798707962036 0.6599727272987366 0.3412915766239166 0.37157535552978516 0.6986522674560547 0.7543279528617859
3 0.2999393939971924 0.001 0.37182655930519104 0.7859976887702942 0.6890525221824646 0.40518054366111755 0.37157535552978516 0.6738141775131226 0.6654454469680786
4 0.28737708926200867 0.001 0.37182655930519104 0.793653130531311 0.7092126607894897 0.37544798851013184 0.37157535552978516 0.621263325214386 0.768422544002533
5 0.27629318833351135 0.001 0.37182655930519104 0.8028419613838196 0.72260981798172 0.4055494964122772 0.37157535552978516 0.8477562665939331 0.5473824143409729
6 0.2665417492389679 0.001 0.37182655930519104 0.809609055519104 0.7353982329368591 0.33294594287872314 0.37157535552978516 0.7307689785957336 0.6933897733688354
7 0.25887876749038696 0.001 0.37182655930519104 0.8132126927375793 0.744954526424408 0.28797024488449097 0.37157535552978516 0.7534120082855225 0.7735632061958313
8 0.25271594524383545 0.001 0.37182655930519104 0.8179733753204346 0.7538670897483826 0.30249008536338806 0.37157535552978516 0.8644329905509949 0.6237345337867737
9 0.24556593596935272 0.001 0.37182655930519104 0.8207928538322449 0.7622584104537964 0.3576349914073944 0.37157535552978516 0.6576451063156128 0.8346141576766968
10 0.23954670131206512 0.001 0.37182655930519104 0.8256030082702637 0.769091010093689 0.2541409134864807 0.37157535552978516 0.8100516200065613 0.7633218765258789
11 0.2349284589290619 0.001 0.37182655930519104 0.8274455070495605 0.7762861847877502 0.24383187294006348 0.37157535552978516 0.795067310333252 0.8124401569366455
12 0.22480393946170807 0.001 0.37182655930519104 0.8354562520980835 0.787416398525238 0.3778316378593445 0.37157535552978516 0.6533672213554382 0.8588836789131165
13 0.22573505342006683 0.001 0.37182655930519104 0.8342418670654297 0.7852107882499695 0.3342073857784271 0.37157535552978516 0.6768029928207397 0.7917631268501282
14 0.21639415621757507 0.001 0.37182655930519104 0.8411555886268616 0.7972605228424072 0.2792396545410156 0.37157535552978516 0.7611830234527588 0.7955203652381897
15 0.21154287457466125 0.001 0.37182655930519104 0.8441442251205444 0.8019176125526428 0.27426305413246155 0.37157535552978516 0.8764772415161133 0.6708933115005493
16 0.20740143954753876 0.001 0.37182655930519104 0.8469985127449036 0.8068550825119019 0.367437481880188 0.37157535552978516 0.646026611328125 0.8527452945709229
17 0.2005360722541809 0.001 0.37182655930519104 0.8522992134094238 0.8129924535751343 0.22591133415699005 0.37157535552978516 0.8203750252723694 0.8089460730552673
18 0.1976771354675293 0.001 0.37182655930519104 0.853760302066803 0.8163849115371704 0.2331937551498413 0.37157535552978516 0.807687520980835 0.8157453536987305
19 0.19583451747894287 0.001 0.37182655930519104 0.8560215830802917 0.8190248012542725 0.2519392669200897 0.37157535552978516 0.7935053110122681 0.8000433444976807
20 0.1872621327638626 0.001 0.37182655930519104 0.8615736365318298 0.8263705372810364 0.22855037450790405 0.37157535552978516 0.7948822975158691 0.8500961065292358
21 0.1852150857448578 0.001 0.37182655930519104 0.8620718717575073 0.8289932012557983 0.2352440059185028 0.37157535552978516 0.7972174286842346 0.8323403000831604
22 0.17845036089420319 0.001 0.37182655930519104 0.8677510023117065 0.8351714611053467 0.21090157330036163 0.37157535552978516 0.8470866084098816 0.8098670244216919
23 0.1732502579689026 0.001 0.37182655930519104 0.8711428046226501 0.8414102792739868 0.32612740993499756 0.37157535552978516 0.8412857055664062 0.695543646812439
24 0.17396509647369385 0.001 0.37182655930519104 0.8704758882522583 0.840953528881073 0.2149643898010254 0.37157535552978516 0.8315027952194214 0.8180400729179382
25 0.1740695685148239 0.001 0.37182655930519104 0.8702647089958191 0.8410759568214417 0.2138184905052185 0.37157535552978516 0.8604387044906616 0.7878146171569824
26 0.16104143857955933 0.001 0.37182655930519104 0.8794053196907043 0.8530260324478149 0.23256370425224304 0.37157535552978516 0.8179659843444824 0.8145195841789246
27 0.15866029262542725 0.001 0.37182655930519104 0.8813797831535339 0.8556373119354248 0.21111807227134705 0.37157535552978516 0.8566364049911499 0.805817723274231
28 0.15867507457733154 0.001 0.37182655930519104 0.8811318874359131 0.8551875352859497 0.2091868668794632 0.37157535552978516 0.8498891592025757 0.8088852763175964
29 0.15372247993946075 0.001 0.37182655930519104 0.884833574295044 0.8602938055992126 0.2100905030965805 0.37157535552978516 0.8543928265571594 0.8121073246002197
30 0.1550114005804062 0.001 0.37182655930519104 0.8840479850769043 0.85946124792099 0.21207265555858612 0.37157535552978516 0.8512551784515381 0.814805269241333
31 0.14192143082618713 0.001 0.37182655930519104 0.8927850127220154 0.8717316389083862 0.21726688742637634 0.37157535552978516 0.8147332072257996 0.8602878451347351
32 0.1401694267988205 0.001 0.37182655930519104 0.8940809965133667 0.8732201457023621 0.21714988350868225 0.37157535552978516 0.8370103240013123 0.8307888507843018
33 0.13880570232868195 0.001 0.37182655930519104 0.8950505256652832 0.8743049502372742 0.23316830396652222 0.37157535552978516 0.8291308283805847 0.8264546990394592
34 0.14308543503284454 0.001 0.37182655930519104 0.892676830291748 0.8704872131347656 0.2735193967819214 0.37157535552978516 0.7545790076255798 0.8698106408119202
35 0.14015090465545654 0.001 0.37182655930519104 0.8939213752746582 0.8743175864219666 0.20235474407672882 0.37157535552978516 0.8535885810852051 0.8286886215209961
36 0.1288939267396927 0.001 0.37182655930519104 0.9015076756477356 0.8844809532165527 0.22387968003749847 0.37157535552978516 0.8760555982589722 0.7937673926353455
37 0.12568938732147217 0.001 0.37182655930519104 0.9041174054145813 0.8872519731521606 0.21494744718074799 0.37157535552978516 0.8468613028526306 0.8249993324279785
38 0.12176792323589325 0.001 0.37182655930519104 0.9065613746643066 0.8911336064338684 0.23827765882015228 0.37157535552978516 0.8391880989074707 0.8176671862602234
39 0.11993639171123505 0.001 0.37182655930519104 0.9084023237228394 0.8925207257270813 0.22297391295433044 0.37157535552978516 0.8404833674430847 0.8346469402313232
40 0.11878598481416702 0.001 0.37182655930519104 0.9090615510940552 0.8941413164138794 0.22415445744991302 0.37157535552978516 0.8580552339553833 0.8152300715446472
41 0.1256236732006073 0.001 0.37182655930519104 0.9046309590339661 0.8880045413970947 0.20100584626197815 0.37157535552978516 0.8520526885986328 0.8423823714256287
42 0.10843898355960846 0.001 0.37182655930519104 0.9163806438446045 0.903978168964386 0.21887923777103424 0.37157535552978516 0.86836838722229 0.8237167596817017
43 0.10670299828052521 0.001 0.37182655930519104 0.9178842902183533 0.9054436683654785 0.21005834639072418 0.37157535552978516 0.8679876327514648 0.8253417611122131
44 0.10276217758655548 0.001 0.37182655930519104 0.9207708239555359 0.909300684928894 0.2151617556810379 0.37157535552978516 0.8735089302062988 0.8225894570350647
45 0.10141195356845856 0.001 0.3718271255493164 0.9218501448631287 0.9108821749687195 0.22106514871120453 0.37157535552978516 0.8555923700332642 0.8328163623809814
46 0.09918847680091858 0.001 0.37182655930519104 0.9235833883285522 0.9129346609115601 0.23230132460594177 0.37157535552978516 0.8555824756622314 0.8224022388458252
47 0.10588783025741577 0.001 0.37182655930519104 0.9191931486129761 0.9068878293037415 0.22423967719078064 0.37157535552978516 0.8427634239196777 0.825032114982605
48 0.103585384786129 0.001 0.37182655930519104 0.9209527969360352 0.9087461233139038 0.2110774666070938 0.37157535552978516 0.8639764785766602 0.8252225518226624
49 0.09157560020685196 0.001 0.37182655930519104 0.9292182922363281 0.9203035831451416 0.22161123156547546 0.37157535552978516 0.8649827837944031 0.8406093120574951
50 0.08616402745246887 0.001 0.37182655930519104 0.9334553480148315 0.9252204298973083 0.2387685328722 0.37157535552978516 0.8806527256965637 0.811405599117279
51 0.0846954956650734 0.001 0.37182655930519104 0.9345796704292297 0.9265674352645874 0.22581790387630463 0.37157535552978516 0.8756505846977234 0.8313769698143005

For binary problems there is also an IoU metric called tf.keras.metrics.BinaryIoU(name='IoU'). This may solve the problem.
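
The likely reason this helps: MeanIoU treats y_pred as integer class IDs when it builds its confusion matrix, so every sigmoid probability below 1.0 is cast to class 0 and the metric collapses to a constant that depends only on the label statistics, which matches the frozen value in the log above. BinaryIoU thresholds the probabilities first. A minimal sketch of the adjusted compile call, reusing model and lr from the question's script and assuming TensorFlow >= 2.6 (where BinaryIoU was added); the metric names here are illustrative:

import tensorflow as tf

model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(lr),
    metrics=[
        # Thresholds the sigmoid output at 0.5 before computing the
        # confusion matrix, then averages the IoU of both classes.
        tf.keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5, name="mean_iou"),
        # IoU of the building (foreground) class alone.
        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5, name="building_iou"),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.Precision()
    ]
)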

I solved the same problem for multi-class segmentation by moving from tf.keras.metrics.MeanIoU to tf.keras.metrics.OneHotMeanIoU, since I am using one-hot encoded labels.
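
For reference, a minimal sketch of that swap, assuming TensorFlow >= 2.6 (the same release that added IoU and BinaryIoU), a softmax output over the class axis, one-hot encoded masks, and a hypothetical three-class model in place of the binary one from the question:

import tensorflow as tf

num_classes = 3  # hypothetical multi-class setup; the question itself is binary

# OneHotMeanIoU argmaxes the class axis of the one-hot labels and the
# softmax predictions before building the confusion matrix, instead of
# casting raw probabilities to integer class IDs the way MeanIoU does.
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(lr),
    metrics=[tf.keras.metrics.OneHotMeanIoU(num_classes=num_classes)]
)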