DL4J - Image becomes too bright

Currently, I have been asked to write CNN code using DL4J and the YOLOv2 architecture. The problem is that after the model was finished, I made a simple GUI for validation testing, and the image shown is too bright; sometimes the image can barely be displayed at all. I am not sure where this problem comes from, whether it is from the earliest stage of training or somewhere else. Here I attach the code I currently have. For the iterator:

public class faceMaskIterator {
private static final Logger log = org.slf4j.LoggerFactory.getLogger(faceMaskIterator.class);
private static final int seed = 123;
private static Random rng = new Random(seed);
private static String dataDir;
private static Path pathDirectory;
private static InputSplit trainData, testData;
private static final String[] allowedFormats = NativeImageLoader.ALLOWED_FORMATS;
private static final double splitRatio = 0.8;
private static final int nChannels = 3;
public static final int gridWidth = 13;
public static final int gridHeight = 13;
public static final int yolowidth = 416;
public static final int yoloheight = 416;

private static RecordReaderDataSetIterator makeIterator(InputSplit split, Path dir, int batchSize) throws Exception {

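    // Read images and their VOC-format bounding-box labels, resized to the YOLOv2 input size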
    ObjectDetectionRecordReader recordReader = new ObjectDetectionRecordReader(yoloheight, yolowidth, nChannels,
            gridHeight, gridWidth, new VocLabelProvider(dir.toString()));
    recordReader.initialize(split);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, 1, true);
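    // Scale pixel values from [0, 255] down to [0, 1]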
    iter.setPreProcessor(new ImagePreProcessingScaler(0, 1));

    return iter;
}

public static RecordReaderDataSetIterator trainIterator(int batchSize) throws Exception {
    return makeIterator(trainData, pathDirectory, batchSize);
}

public static RecordReaderDataSetIterator testIterator(int batchSize) throws Exception {
    return makeIterator(testData, pathDirectory, batchSize);
}

public static void setup() throws IOException {
    log.info("Load data...");
    dataDir = Paths.get(
            System.getProperty("user.home"),
            Helper.getPropValues("dl4j_home.data")
    ).toString();
    pathDirectory = Paths.get(dataDir, "face_mask_dataset");
    FileSplit fileSplit = new FileSplit(new File(pathDirectory.toString()), allowedFormats, rng);
    PathFilter pathFilter = new RandomPathFilter(rng, allowedFormats);
    InputSplit[] sample = fileSplit.sample(pathFilter, splitRatio, 1 - splitRatio);
    trainData = sample[0];
    testData = sample[1];
}
}

For the training:

public class faceMaskPreTrained {
private static final Logger log = LoggerFactory.getLogger(ai.certifai.groupProjek.faceMaskPreTrained.class);
private static int seed = 420;
private static double detectionThreshold = 0.9;
private static int nBoxes = 3;
private static double lambdaNoObj = 0.7;
private static double lambdaCoord = 1.0;
private static double[][] priorBoxes = {{1, 1}, {2, 1}, {1, 2}};

private static int batchSize = 3;
private static int nEpochs = 1;
private static double learningRate = 1e-4;
private static int nClasses = 3;
private static List<String> labels;

private static File modelFilename = new File(System.getProperty("user.dir"), "generated-models/facemask_detector.zip");
private static ComputationGraph model;
private static Frame frame = null;
private static final Scalar GREEN = RGB(0, 255, 0);
private static final Scalar YELLOW = RGB(255, 255, 0);
private static final Scalar RED = RGB(255, 0, 0);
private static Scalar[] colormap = {GREEN, YELLOW, RED};
private static String labeltext = null;

public static void main(String[] args) throws Exception {
    faceMaskIterator.setup();
    RecordReaderDataSetIterator trainIter = faceMaskIterator.trainIterator(batchSize);
    RecordReaderDataSetIterator testIter = faceMaskIterator.testIterator(1);
    labels = trainIter.getLabels();

    if (modelFilename.exists()) {
        Nd4j.getRandom().setSeed(seed);
        log.info("Load model...");
        model = ModelSerializer.restoreComputationGraph(modelFilename);
    } else {
        Nd4j.getRandom().setSeed(seed);
        INDArray priors = Nd4j.create(priorBoxes);

        log.info("Build model...");
        ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained();

        
        FineTuneConfiguration fineTuneConf = getFineTuneConfiguration();
        model = getComputationGraph(pretrained, priors, fineTuneConf);
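        // Note: the third argument to InputType.convolutional is the channel count;
        // passing nClasses only works here because it happens to equal nChannels (3)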
        System.out.println(model.summary(InputType.convolutional(
                faceMaskIterator.yoloheight,
                faceMaskIterator.yolowidth,
                nClasses)));

        log.info("Train model...");
        UIServer server = UIServer.getInstance();
        StatsStorage storage = new InMemoryStatsStorage();
        server.attach(storage);
        model.setListeners(new ScoreIterationListener(5), new StatsListener(storage, 5));

        for (int i = 1; i < nEpochs + 1; i++) {
            trainIter.reset();
            while (trainIter.hasNext()) {
                model.fit(trainIter.next());
            }
            log.info("*** Completed epoch {} ***", i);
        }
        ModelSerializer.writeModel(model, modelFilename, true);
        System.out.println("Model saved.");
    }
    // Evaluate the model's accuracy using the test iterator.
    offlineValidationWithTestDataset(testIter);
    // Run inference: process the webcam stream with the model and make predictions.
    doInference();
}

private static ComputationGraph getComputationGraph(ComputationGraph pretrained, INDArray priors, FineTuneConfiguration fineTuneConf) {

    return new TransferLearning.GraphBuilder(pretrained)
            .fineTuneConfiguration(fineTuneConf)
            .removeVertexKeepConnections("conv2d_23")
            .removeVertexKeepConnections("outputs")
            .addLayer("conv2d_23",
                    new ConvolutionLayer.Builder(1, 1)
                            .nIn(1024)
                            .nOut(nBoxes * (5 + nClasses))
                            .stride(1, 1)
                            .convolutionMode(ConvolutionMode.Same)
                            .weightInit(WeightInit.XAVIER)
                            .activation(Activation.IDENTITY)
                            .build(),
                    "leaky_re_lu_22")
            .addLayer("outputs",
                    new Yolo2OutputLayer.Builder()
                            .lambdaNoObj(lambdaNoObj)
                            .lambdaCoord(lambdaCoord)
                            .boundingBoxPriors(priors.castTo(DataType.FLOAT))
                            .build(),
                    "conv2d_23")
            .setOutputs("outputs")
            .build();
}

private static FineTuneConfiguration getFineTuneConfiguration() {

    return new FineTuneConfiguration.Builder()
            .seed(seed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(1.0)
            .updater(new Adam.Builder().learningRate(learningRate).build())
            .l2(0.00001)
            .activation(Activation.IDENTITY)
            .trainingWorkspaceMode(WorkspaceMode.ENABLED)
            .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
            .build();
}

// Visually evaluate the performance of the trained object detection model on the test set
private static void offlineValidationWithTestDataset(RecordReaderDataSetIterator test) throws InterruptedException {
    NativeImageLoader imageLoader = new NativeImageLoader();
    CanvasFrame canvas = new CanvasFrame("Validate Test Dataset");
    OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
    org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
    Mat convertedMat = new Mat();
    Mat convertedMat_big = new Mat();

    while (test.hasNext() && canvas.isVisible()) {
        org.nd4j.linalg.dataset.DataSet ds = test.next();
        INDArray features = ds.getFeatures();
        INDArray results = model.outputSingle(features);
        List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
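        // Non-maximum suppression (IoU threshold 0.4) discards overlapping detections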
        YoloUtils.nms(objs, 0.4);
        Mat mat = imageLoader.asMat(features);
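        // Convert the normalized [0, 1] floats back to 8-bit [0, 255] for display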
        mat.convertTo(convertedMat, CV_8U, 255, 0);
        int w = mat.cols() * 2;
        int h = mat.rows() * 2;
        resize(convertedMat, convertedMat_big, new Size(w, h));
        convertedMat_big = drawResults(objs, convertedMat_big, w, h);
        canvas.showImage(converter.convert(convertedMat_big));
        canvas.waitKey();
    }
    canvas.dispose();
}

// Stream video frames from Webcam and run them through YOLOv2 model and get predictions
private static void doInference() {

    String cameraPos = "front";
    int cameraNum = 0;
    Thread thread = null;
    NativeImageLoader loader = new NativeImageLoader(
            faceMaskIterator.yolowidth,
            faceMaskIterator.yoloheight,
            3,
            new ColorConversionTransform(COLOR_BGR2RGB));
    ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);

    if (!cameraPos.equals("front") && !cameraPos.equals("back")) {
        try {
            throw new Exception("Unknown argument for camera position. Choose between front and back");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    FrameGrabber grabber = null;
    try {
        grabber = FrameGrabber.createDefault(cameraNum);
    } catch (FrameGrabber.Exception e) {
        e.printStackTrace();
    }
    OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();

    try {
        grabber.start();
    } catch (FrameGrabber.Exception e) {
        e.printStackTrace();
    }

    CanvasFrame canvas = new CanvasFrame("Object Detection");
    int w = grabber.getImageWidth();
    int h = grabber.getImageHeight();
    canvas.setCanvasSize(w, h);

    while (true) {
        try {
            frame = grabber.grab();
        } catch (FrameGrabber.Exception e) {
            e.printStackTrace();
        }

        //if a thread is null, create new thread
        if (thread == null) {
            thread = new Thread(() ->
            {
                while (frame != null) {
                    try {
                        Mat rawImage = new Mat();

                        //Flip the camera if opening front camera
                        if (cameraPos.equals("front")) {
                            Mat inputImage = converter.convert(frame);
                            flip(inputImage, rawImage, 1);
                        } else {
                            rawImage = converter.convert(frame);
                        }

                        Mat resizeImage = new Mat();
                        resize(rawImage, resizeImage, new Size(faceMaskIterator.yolowidth, faceMaskIterator.yoloheight));
                        INDArray inputImage = loader.asMatrix(resizeImage);
                        scaler.transform(inputImage);
                        INDArray outputs = model.outputSingle(inputImage);
                        org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
                        List<DetectedObject> objs = yout.getPredictedObjects(outputs, detectionThreshold);
                        YoloUtils.nms(objs, 0.4);
                        rawImage = drawResults(objs, rawImage, w, h);
                        canvas.showImage(converter.convert(rawImage));
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            });
            thread.start();
        }

        KeyEvent t = null;
        try {
            t = canvas.waitKey(33);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        if ((t != null) && (t.getKeyCode() == KeyEvent.VK_Q)) {
            break;
        }
    }
}

private static Mat drawResults(List<DetectedObject> objects, Mat mat, int w, int h) {
    for (DetectedObject obj : objects) {
        double[] xy1 = obj.getTopLeftXY();
        double[] xy2 = obj.getBottomRightXY();
        String label = labels.get(obj.getPredictedClass());
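        // YOLO predicts box corners in grid-cell units; scale them to pixel coordinates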
        int x1 = (int) Math.round(w * xy1[0] / faceMaskIterator.gridWidth);
        int y1 = (int) Math.round(h * xy1[1] / faceMaskIterator.gridHeight);
        int x2 = (int) Math.round(w * xy2[0] / faceMaskIterator.gridWidth);
        int y2 = (int) Math.round(h * xy2[1] / faceMaskIterator.gridHeight);
        //Draw bounding box
        rectangle(mat, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()], 2, 0, 0);
        //Display label text
        labeltext = label + " " + String.format("%.2f", obj.getConfidence() * 100) + "%";
        int[] baseline = {0};
        Size textSize = getTextSize(labeltext, FONT_HERSHEY_DUPLEX, 1, 1, baseline);
        rectangle(mat, new Point(x1 + 2, y2 - 2), new Point(x1 + 2 + textSize.get(0), y2 - 2 - textSize.get(1)), colormap[obj.getPredictedClass()], FILLED, 0, 0);
        putText(mat, labeltext, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, RGB(0, 0, 0));
    }
    return mat;
}
}

CanvasFrame tries to do gamma correction by default, because cameras made for CV usually need it, but cheap webcams usually output gamma-corrected images already, so make sure to let CanvasFrame know about it this way:

// We should also specify the relative monitor/camera response for proper gamma correction.
CanvasFrame frame = new CanvasFrame("Some Title", CanvasFrame.getDefaultGamma()/grabber.getGamma());
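
Applied to the code above, that means passing a gamma argument to both CanvasFrame constructors. A minimal sketch (the 2.2 in the offline case is an assumption that the dataset images on disk are already gamma-encoded with the usual display gamma of about 2.2):

// In doInference(), after grabber.start():
CanvasFrame canvas = new CanvasFrame("Object Detection",
        CanvasFrame.getDefaultGamma() / grabber.getGamma());

// In offlineValidationWithTestDataset() there is no grabber; images loaded from
// disk are assumed to be gamma-encoded already (gamma ~2.2), so no further
// correction should be applied:
CanvasFrame canvas = new CanvasFrame("Validate Test Dataset",
        CanvasFrame.getDefaultGamma() / 2.2);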