图像的深度学习异常检测

DeepLearning Anomaly Detection for images

我对深度学习的世界还比较陌生。我想创建一个用于图像异常检测的深度学习模型(最好使用 Tensorflow/Keras)。通过异常检测,我的意思是,本质上是 OneClassSVM

我已经使用图像中的 HOG 特征尝试了 sklearn 的 OneClassSVM。我想知道是否有一些例子可以说明我如何在深度学习中做到这一点。我查了一下,但找不到一个代码片段来处理这种情况。

在 Keras 中执行此操作的方法是使用该模块的 KerasRegressor wrapper module (they wrap sci-kit learn's regressor interface). Useful information can also be found in the source 代码。基本上你首先必须定义你的网络模型,例如:

def simple_model():
    #Input layer
    data_in = Input(shape=(13,)) 
    #First layer, fully connected, ReLU activation
    layer_1 = Dense(13,activation='relu',kernel_initializer='normal')(data_in)   
    #second layer...etc
    layer_2 = Dense(6,activation='relu',kernel_initializer='normal')(layer_1)  
    #Output, single node without activation
    data_out = Dense(1, kernel_initializer='normal')(layer_2)     
    #Save and Compile model
    model = Model(inputs=data_in, outputs=data_out)   
    #you may choose any loss or optimizer function, be careful which you chose 
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

然后,将其传递给 KerasRegressor 生成器,并 fit 连同您的数据:

from keras.wrappers.scikit_learn import KerasRegressor
#chose your epochs and batches 
regressor = KerasRegressor(build_fn=simple_model, nb_epoch=100, batch_size=64)
#fit with your data
regressor.fit(data, labels, epochs=100)

您现在可以对其进行预测或获取其分数:

p = regressor.predict(data_test) #obtain predicted value
score = regressor.score(data_test, labels_test) #obtain test score

在您的情况下,由于您需要从正常图像中检测异常图像,您可以采取的一种方法是通过传递标记为 1 和标记为 0 的图像。

这将使您的模型在输入异常图像时 return 接近 1 的值,使您能够 达到阈值 所需的结果。您可以将此输出视为您训练为 1(完美匹配)的 "Anomalous Model" 的 R^2 系数。

此外,正如您提到的,自动编码器是另一种进行异常检测的方法。为此,我建议您查看 Keras 博客 post Building Autoencoders in Keras,他们在其中详细解释了如何使用 Keras 库实现它们。


值得注意的是,Single-class classification 是 Regression 的另一种说法。

分类试图在 N 种可能的 class 中找到 概率分布 ,您通常会选择最有可能的 class 作为输出(这就是为什么大多数分类网络在其输出标签上使用 Sigmoid 激活,因为它的范围是 [0, 1])。它的 输出是 discrete/categorical.

类似地,回归试图通过最小化误差或一些其他指标(如众所周知的 R^2 指标,或 Coefficient of Determination).它的 输出是真实的 number/continuous (这也是大多数回归网络不在其输出上使用激活的原因)。希望对您有所帮助,祝您编码顺利。