Why do I receive this error training an autoencoder with keras in R?

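I have a directory of 8-bit RGB (3-channel) images of different sizes. I am trying to use them to train an autoencoder with keras 2.2.5.0 and tensorflow 2.0.0 in R 3.6.3 on a Linux Mint 19 machine. The dataset is here (zipped): https://github.com/hrj21/processing-imagestream-images/blob/master/ciliated_cells.zip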

The images are split into two classes, but I don't care about this class structure.

When I run the fit_generator() function, I get the error:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
IndexError: list index out of range

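I'm sure it's something I'm doing wrong, but I'm not experienced enough with keras to work out what. Any help you can offer would be greatly appreciated. Here is the code: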

# Load package ------------------------------------------------------------

library(keras)

# Defining the file paths -------------------------------------------------

base_dir <- "ciliated_cells"
train_dir <- file.path(base_dir, "train")
validation_dir <- file.path(base_dir, "validation")
test_dir <- file.path(base_dir, "test")

# Define data generators --------------------------------------------------
# To scale and resize images 

datagen <- image_data_generator(rescale = 1/255)

train_generator <- flow_images_from_directory(
  train_dir,
  datagen,
  target_size = c(150, 150),
  batch_size = 88,
  class_mode = NULL
)

validation_generator <- flow_images_from_directory(
  validation_dir,
  datagen,
  target_size = c(150, 150),
  batch_size = 36,
  class_mode = NULL
)

test_generator <- flow_images_from_directory(
  test_dir,
  datagen,
  target_size = c(150, 150),
  batch_size = 30,
  class_mode = NULL,  
  shuffle = FALSE)  # keep data in same order as labels

# Defining the model architecture from scratch ----------------------------

input <- layer_input(shape = c(150, 150, 3))

output <- input %>%
  layer_flatten() %>%  # 150 x 150 x 3 image -> vector of 67500 values
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 16, name = "code") %>%
  layer_dense(units = 32, activation = "relu") %>%
  layer_dense(units = 150 * 150 * 3) %>%
  layer_reshape(c(150, 150, 3))

model <- keras_model(input, output)

# Compiling and fitting the model -----------------------------------------

model %>% compile(
  loss = "mse",
  optimizer = optimizer_rmsprop(lr = 2e-5)
)

history <- model %>% fit_generator(
  train_generator,
  steps_per_epoch = 1,
  epochs = 100,
  validation_data = validation_generator,
  validation_steps = 1
)

Here is the output of sessionInfo():

R version 3.6.3 (2020-02-29)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Linux Mint 19

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1

locale:
 [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C               LC_TIME=en_GB.UTF-8       
 [4] LC_COLLATE=en_GB.UTF-8     LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
 [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] forcats_0.5.0   stringr_1.4.0   dplyr_0.8.5     purrr_0.3.3     readr_1.3.1     tidyr_1.0.2    
 [7] tibble_3.0.0    ggplot2_3.3.0   tidyverse_1.3.0 keras_2.2.5.0  

loaded via a namespace (and not attached):
 [1] reticulate_1.15-9000 tidyselect_1.0.0     haven_2.2.0          lattice_0.20-41     
 [5] colorspace_1.4-1     vctrs_0.2.4          generics_0.0.2       base64enc_0.1-3     
 [9] rlang_0.4.5          pillar_1.4.3         withr_2.1.2          glue_1.4.0          
[13] DBI_1.1.0            rappdirs_0.3.1       dbplyr_1.4.2         modelr_0.1.6        
[17] readxl_1.3.1         lifecycle_0.2.0      tensorflow_2.0.0     munsell_0.5.0       
[21] gtable_0.3.0         cellranger_1.1.0     rvest_0.3.5          tfruns_1.4          
[25] fansi_0.4.1          broom_0.5.5          Rcpp_1.0.4.6         backports_1.1.6     
[29] scales_1.1.0         jsonlite_1.6.1       fs_1.4.1             hms_0.5.3           
[33] packrat_0.5.0        stringi_1.4.6        grid_3.6.3           cli_2.0.2           
[37] tools_3.6.3          magrittr_1.5         crayon_1.3.4         whisker_0.4         
[41] pkgconfig_2.0.3      zeallot_0.1.0        ellipsis_0.3.0       Matrix_1.2-18       
[45] xml2_1.3.1           reprex_0.3.0         lubridate_1.7.4      assertthat_0.2.1    
[49] httr_1.4.1           rstudioapi_0.11      R6_2.4.1             nlme_3.1-145        
[53] compiler_3.6.3 

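So I realised my mistake. My data generators were yielding only the input images, not the target images (which, for an autoencoder, are the same images) for the model to learn from. With class_mode = NULL each batch contains images but no targets, so there is nothing for the "mse" loss to compare the model's output against, which is presumably why fit_generator() failed with the IndexError. The solution is to change the class_mode argument in every flow_images_from_directory() call to "input", which makes each batch return the images as both the input and the target. After that change, fit_generator() runs without problems. Without it, the autoencoder has no way to "know" that it should try to reproduce the input image in its output layer.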
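To illustrate the difference between the two settings, here is a base-R sketch that does not use keras at all. `simulate_batch()` and `mse()` are hypothetical helpers that only mimic the shape of what a generator yields: with class_mode = NULL a batch holds just the images, while with class_mode = "input" it holds (images, images). Linking the IndexError to the missing target is my reading of the situation, not something confirmed from the Python traceback.

```r
# Hypothetical stand-in for one batch from flow_images_from_directory().
# It only mimics the *structure* of a batch; it is not the keras API.
simulate_batch <- function(class_mode, n = 4, size = 150, channels = 3) {
  # Random pixel values already in [0, 1], as if rescale = 1/255 was applied
  x <- array(runif(n * size * size * channels),
             dim = c(n, size, size, channels))
  if (is.null(class_mode)) {
    list(x)        # class_mode = NULL: images only, no targets
  } else if (identical(class_mode, "input")) {
    list(x, x)     # class_mode = "input": the targets are the inputs themselves
  } else {
    stop("only NULL and \"input\" are simulated in this sketch")
  }
}

# Mean squared error, the loss chosen with compile(loss = "mse")
mse <- function(y_true, y_pred) mean((y_true - y_pred)^2)

batch_null  <- simulate_batch(NULL)
batch_input <- simulate_batch("input")

length(batch_null)    # 1: nothing to compare the model's output with
length(batch_input)   # 2: (inputs, targets)

# Asking for the missing target fails, loosely analogous to the Python IndexError
tryCatch(batch_null[[2]], error = function(e) conditionMessage(e))

# With class_mode = "input", a perfect reconstruction of the batch has zero loss
mse(batch_input[[2]], batch_input[[1]])
```

In the real code the only change needed is `class_mode = "input"` in each flow_images_from_directory() call; the sketch just makes visible why the target has to be present in every batch.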