Invalid 'dimnames' given for data frame when using lime on a keras CNN model in R

While using the lime package in R, I ran into the same error as in "How to fix "Invalid 'dimnames' given for data frame"?":

Error in `dimnames<-.data.frame`(`*tmp*`, value = list(n)) : invalid 'dimnames' given for data frame

I am trying to apply the lime explain() function to the following keras CNN network:

Model
_____________________________________________________________________________________________________________________________________________________
Layer (type)                                                       Output Shape                                               Param #                
=====================================================================================================================================================
conv1d_8 (Conv1D)                                                  (None, 1896, 64)                                           384                    
_____________________________________________________________________________________________________________________________________________________
batch_normalization_8 (BatchNormalization)                         (None, 1896, 64)                                           256                    
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)                                          (None, 1896, 64)                                           0                      
_____________________________________________________________________________________________________________________________________________________
dropout_10 (Dropout)                                               (None, 1896, 64)                                           0                      
_____________________________________________________________________________________________________________________________________________________
conv1d_9 (Conv1D)                                                  (None, 1886, 32)                                           22560                  
_____________________________________________________________________________________________________________________________________________________
batch_normalization_9 (BatchNormalization)                         (None, 1886, 32)                                           128                    
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)                                          (None, 1886, 32)                                           0                      
_____________________________________________________________________________________________________________________________________________________
dropout_11 (Dropout)                                               (None, 1886, 32)                                           0                      
_____________________________________________________________________________________________________________________________________________________
conv1d_10 (Conv1D)                                                 (None, 1866, 16)                                           10768                  
_____________________________________________________________________________________________________________________________________________________
batch_normalization_10 (BatchNormalization)                        (None, 1866, 16)                                           64                     
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)                                         (None, 1866, 16)                                           0                      
_____________________________________________________________________________________________________________________________________________________
dropout_12 (Dropout)                                               (None, 1866, 16)                                           0                      
_____________________________________________________________________________________________________________________________________________________
conv1d_11 (Conv1D)                                                 (None, 1826, 8)                                            5256                   
_____________________________________________________________________________________________________________________________________________________
batch_normalization_11 (BatchNormalization)                        (None, 1826, 8)                                            32                     
_____________________________________________________________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)                                         (None, 1826, 8)                                            0                      
_____________________________________________________________________________________________________________________________________________________
dropout_13 (Dropout)                                               (None, 1826, 8)                                            0                      
_____________________________________________________________________________________________________________________________________________________
flatten_2 (Flatten)                                                (None, 14608)                                              0                      
_____________________________________________________________________________________________________________________________________________________
dense_4 (Dense)                                                    (None, 100)                                                1460900                
_____________________________________________________________________________________________________________________________________________________
dropout_14 (Dropout)                                               (None, 100)                                                0                      
_____________________________________________________________________________________________________________________________________________________
dense_5 (Dense)                                                    (None, 5)                                                  505                    
=====================================================================================================================================================
Total params: 1,500,853
Trainable params: 1,500,613
Non-trainable params: 240
_____________________________________________________________________________________________________________________________________________________
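For reference, the architecture above could be defined in R roughly as follows. This is a sketch reconstructed from the summary: the kernel sizes (5, 11, 21, 41) are inferred from the output lengths and parameter counts, while the dropout rates and activations are assumptions not stated in the summary.

```r
library(keras)

model <- keras_model_sequential() %>%
  # 384 params = 64 * (5 * 1 + 1)  ->  kernel_size 5, input shape (1900, 1)
  layer_conv_1d(filters = 64, kernel_size = 5, input_shape = c(1900, 1)) %>%
  layer_batch_normalization() %>%
  layer_activation_leaky_relu() %>%
  layer_dropout(rate = 0.2) %>%       # rate is an assumption
  # 1896 -> 1886  ->  kernel_size 11
  layer_conv_1d(filters = 32, kernel_size = 11) %>%
  layer_batch_normalization() %>%
  layer_activation_leaky_relu() %>%
  layer_dropout(rate = 0.2) %>%
  # 1886 -> 1866  ->  kernel_size 21
  layer_conv_1d(filters = 16, kernel_size = 21) %>%
  layer_batch_normalization() %>%
  layer_activation_leaky_relu() %>%
  layer_dropout(rate = 0.2) %>%
  # 1866 -> 1826  ->  kernel_size 41
  layer_conv_1d(filters = 8, kernel_size = 41) %>%
  layer_batch_normalization() %>%
  layer_activation_leaky_relu() %>%
  layer_dropout(rate = 0.2) %>%
  layer_flatten() %>%                 # 1826 * 8 = 14608
  layer_dense(units = 100, activation = "relu") %>%
  layer_dropout(rate = 0.3) %>%
  layer_dense(units = 5, activation = "softmax")  # 5-class output
```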
my_explainer <- lime(x = train, model = model, bin_continuous = FALSE)

explanation <- explain(test, explainer = my_explainer, n_labels = 1, n_features = 2, kernel_width = 0.5)

My training and test sets contain 1900 features; only the first 9 are shown here for simplicity:

str(train[,1:9])
'data.frame':   77 obs. of  9 variables:
 $ X1: num  0.005598 0.009835 0.005365 0.000725 0.000992 ...
 $ X2: num  0 0.00156 0 0.00172 0.00261 ...
 $ X3: num  0 0.00752 0 0 0.00556 ...
 $ X4: num  0 0.00191 0.00479 0.00193 0.005 ...
 $ X5: num  0.0028 0.0033 0 0.00115 0.00503 ...
 $ X6: num  0 0 0 0.000453 0.00258 ...
 $ X7: num  0 0.00121 0 0.00127 0.00185 ...
 $ X8: num  0.00646 0 0.0097 0.00435 0.00278 ...
 $ X9: num  0 0.00301 0.00183 0.0045 0.00241 ...

str(test[,1:9])
'data.frame':   3 obs. of  9 variables:
 $ X1: num  0.00651 0.00286 0.00511
 $ X2: num  0.00229 0.00592 0.0031
 $ X3: num  0.00343 0.00338 0.0094
 $ X4: num  0.00464 0.00532 0.01073
 $ X5: num  0.00163 0.00203 0.01841
 $ X6: num  0.00277 0.0041 0.00865
 $ X7: num  0.00169 0.00257 0.01793
 $ X8: num  0.00669 0.00213 0.01202
 $ X9: num  0.0038 0.01023 0.00843
dimnames(train[,1:9])
[[1]]
 [1] "1"   "2"   "3"   "5"   "6"   "7"   "8"   "9"   "10"  "12"  "15"  "16"  "18"  "19"  "20"  "21"  "25"  "26"  "28"  "29"  "30"  "31"  "33"  "34" 
[25] "35"  "36"  "38"  "39"  "40"  "42"  "43"  "44"  "46"  "48"  "50"  "51"  "52"  "53"  "55"  "59"  "60"  "61"  "64"  "65"  "66"  "67"  "70"  "71" 
[49] "73"  "74"  "76"  "77"  "78"  "79"  "80"  "83"  "84"  "85"  "86"  "87"  "88"  "90"  "92"  "94"  "97"  "102" "103" "104" "105" "106" "108" "109"
[73] "112" "114" "115" "116" "117"

[[2]]
[1] "X1" "X2" "X3" "X4" "X5" "X6" "X7" "X8" "X9"

dimnames(test[,1:9])
[[1]]
[1] "23" "27" "32"

[[2]]
[1] "X1" "X2" "X3" "X4" "X5" "X6" "X7" "X8" "X9"

I think the problem is that the keras model expects a matrix as input, or at least that was the case for my classification problem. Here is what worked for me (I assume your model is sequential). My train data is also a matrix:

# Tell lime that the keras sequential model is a classifier
model_type.keras.engine.sequential.Sequential <- function(x, ...) {
  "classification"
}

# Set up lime::predict_model()
predict_model.keras.engine.sequential.Sequential <- function(x, newdata, type, ...) {

  ## Here you have to write a function that takes a data.frame
  ## and transforms it into the shape that keras expects.

  ## For example, if you flattened your array before training the CNN,
  ## you can simply use as.matrix().

  ## If the keras model expects a 3-d dataset, write something like
  ## array(as.matrix(newdata), dim = c(nrow(newdata), 12, 12))

  your_function <- function(data) {
    ## return your data in the proper shape here
  }

  pred <- predict(object = x, x = your_function(newdata))
  data.frame(pred)
}

x  <- as.data.frame(train_data)
x2 <- as.data.frame(test_data)

explainer <- lime(x = x, model = model)

explanation <- lime::explain(
  x          = x2[1:10, ],
  explainer  = explainer,
  n_features = 5,
  n_labels   = 1)  ## change for your problem

plot_features(explanation) +
  labs(title = "LIME: Feature Importance Visualization CNN")

Edit: revised the answer to take the comments into account.