Stop Keras when it fails to converge after a specified time

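I'm using a straightforward application of the Keras API in R. Depending on the set.seed(value) value, it sometimes converges and sometimes doesn't. I assume this is because the seed sets the initially randomized weights. If it doesn't converge at first, I can usually get it to converge on a different run by changing the seed value, but I have to monitor/stop it manually. How can I stop Keras if the model still hasn't converged after a specified time (for example, stop it after 600 seconds and restart it with a different seed value)?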

  library(keras)

  set.seed(42)
  x <- as.matrix(train_data)
  y <- as.matrix(train_targets)

  model <- keras_model_sequential() %>%
    layer_dense(units = 64, kernel_regularizer = regularizer_l2(0.001),
                activation = "relu", input_shape = dim(train_data)[[2]]) %>%
    layer_dense(units = 32, kernel_regularizer = regularizer_l2(0.001),
                activation = "relu") %>%
    layer_dense(units = 1, activation = "linear")

  model %>% compile(
    loss = "mse",
    optimizer = "rmsprop",
    metrics = list("mae")
  )

  model %>% fit(x, y, epochs = 50, verbose = 0)

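One option is to define a function that calls itself, perhaps setting the seed (and doing any other setup) before each call. Here is an example, borrowed from the keras guides, that takes a few seconds to run.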

library(keras)
d <- dataset_mnist()
x_train <- d$train$x
y_train <- d$train$y
x_test <- d$test$x
y_test <- d$test$y

x_train <- array_reshape(x_train, c(nrow(x_train), 784))
x_test <- array_reshape(x_test, c(nrow(x_test), 784))
x_train <- x_train / 255
x_test <- x_test / 255
y_train <- to_categorical(y_train, 10)
y_test <- to_categorical(y_test, 10)

model <- keras_model_sequential() 
model %>% 
  layer_dense(units = 256, activation = 'relu', input_shape = c(784)) %>% 
  layer_dropout(rate = 0.4) %>% 
  layer_dense(units = 128, activation = 'relu') %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 10, activation = 'softmax')

model %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(),
  metrics = c('accuracy')
)

We can write a recursive function that calls itself again after a timeout.

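timed_fit <- function(t = 5) {
  # short pause so that a timed-out attempt is interrupted cleanly
  Sys.sleep(1)
  # t serves both as the seed and as the time limit (seconds) for this attempt
  set.seed(t)
  message("seed set to ", t)
  # abort this attempt after t seconds of CPU or elapsed time
  setTimeLimit(cpu = t, elapsed = t, transient = TRUE)
  # remove the limit when this call exits, however it exits
  on.exit(setTimeLimit(cpu = Inf, elapsed = Inf, transient = FALSE))
  tryCatch({
    # note: `model` is built once outside this function, so a retry continues
    # from its current weights rather than from a fresh initialisation
    model %>% fit(
      x_train, y_train,
      epochs = 4, batch_size = 128,
      validation_split = 0.2
    )
  }, error = function(e) {
    if (grepl("reached elapsed time limit|reached CPU time limit", e$message)) {
      message("\n timed out!\n") # or set another seed and continue
      # retry with a new seed and a 10-second longer time limit
      timed_fit(t + 10)
    } else {
      # error not related to the timeout: re-raise it
      stop(e)
    }
  })
}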

timed_fit()

The Sys.sleep(1) is inserted to avoid cases where the timeout error is raised but the process is not properly interrupted.
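The recursive timed_fit() nests one call per retry and ties the seed to the time limit. A minimal sketch of an iterative alternative, assuming the same setTimeLimit()/tryCatch() mechanism, loops over a vector of seeds and gives every attempt the same elapsed-time budget. The helpers run_with_timeout() and fit_with_retries() are hypothetical names, and slow_fit() is a stand-in that busy-waits instead of calling Keras, so the retry logic can be tried without a model.

```r
# Hypothetical helper: run fn() with an elapsed-time limit of `timeout` seconds.
run_with_timeout <- function(fn, timeout) {
  # transient = TRUE: the limit applies only to the rest of this computation
  setTimeLimit(elapsed = timeout, transient = TRUE)
  # clear the limit however fn() exits (normal return or error)
  on.exit(setTimeLimit(cpu = Inf, elapsed = Inf, transient = FALSE))
  fn()
}

# Hypothetical helper: try each seed in turn and return the first attempt
# that finishes within `timeout` seconds, together with its seed.
fit_with_retries <- function(train_fn, seeds, timeout) {
  for (seed in seeds) {
    set.seed(seed)
    out <- tryCatch(
      list(value = run_with_timeout(function() train_fn(seed), timeout),
           seed = seed),
      error = function(e) {
        # matches R's English message; it may be translated in other locales
        if (grepl("reached elapsed time limit", conditionMessage(e))) {
          message("seed ", seed, " timed out after ", timeout, " s")
          NULL
        } else {
          stop(e)  # unrelated error: propagate it unchanged
        }
      }
    )
    if (!is.null(out)) return(out)
  }
  stop("every seed timed out")
}

# Stand-in for a Keras fit() call (hypothetical): seed 1 "hangs" in a
# CPU-bound loop for 10 seconds, any other seed finishes immediately.
slow_fit <- function(seed) {
  if (seed == 1) {
    deadline <- Sys.time() + 10
    while (Sys.time() < deadline) {}
  }
  "converged"
}

res <- fit_with_retries(slow_fit, seeds = c(1, 2), timeout = 1)
print(res$seed)
```

In real use, train_fn should build and compile a fresh model inside the call, because calling fit() on an already-trained model continues from its current weights. Also, set.seed() seeds only R's random number generator; Keras weight initialisation is driven by TensorFlow's own seed (recent versions of the tensorflow R package provide set_random_seed() for this, which is worth checking against your installed version).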