DL4J: image becomes too bright
目前,我被要求使用 DL4J 和 YOLOv2 架构编写 CNN 代码。但问题是,在模型训练完成后,我做了一个简单的 GUI 来进行验证测试,结果显示出来的图像过亮,只有有时图像才能正常显示。我不确定这个问题是从哪里来的,是在训练的最早阶段还是其他阶段。在这里,我附上了我现在拥有的代码。对于迭代器:
/**
 * Builds the training and test RecordReaderDataSetIterators for the face-mask object-detection
 * dataset; all state is held in static fields.
 * NOTE(review): closing braces and some statements are missing from this excerpt.
 */
public class faceMaskIterator {
private static final Logger log = org.slf4j.LoggerFactory.getLogger(faceMaskIterator.class);
// Seed for the shared Random that setup() hands to FileSplit and RandomPathFilter.
private static final int seed = 123;
private static Random rng = new Random(seed);
private static String dataDir;
private static Path pathDirectory;
// Assigned in setup() from the two parts of the sampled split.
private static InputSplit trainData, testData;
private static final String[] allowedFormats = NativeImageLoader.ALLOWED_FORMATS;
// Weight of the first (training) part; setup() passes 1 - splitRatio for the second part.
private static final double splitRatio = 0.8;
private static final int nChannels = 3;
// Grid size and image size passed to ObjectDetectionRecordReader in makeIterator().
public static final int gridWidth = 13;
public static final int gridHeight = 13;
public static final int yolowidth = 416;
public static final int yoloheight = 416;
/**
 * Creates an iterator backed by an ObjectDetectionRecordReader with a VocLabelProvider, and
 * attaches an ImagePreProcessingScaler(0, 1) as its pre-processor.
 *
 * @param split     train or test split supplied by the caller.
 *                  NOTE(review): not referenced in the lines shown; a recordReader.initialize(split)
 *                  call is presumably missing from this excerpt - confirm.
 * @param dir       directory whose string form is handed to VocLabelProvider
 * @param batchSize batch size passed to the iterator
 * @return the configured iterator
 */
private static RecordReaderDataSetIterator makeIterator(InputSplit split, Path dir, int batchSize) throws Exception {
ObjectDetectionRecordReader recordReader = new ObjectDetectionRecordReader(yoloheight, yolowidth, nChannels,
gridHeight, gridWidth, new VocLabelProvider(dir.toString()));
RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, 1,true);
// Scaler configured with min 0 / max 1; the validation display in faceMaskPreTrained multiplies
// the features by 255 before showing them, presumably to undo this scaling.
iter.setPreProcessor(new ImagePreProcessingScaler(0, 1));
return iter;
/** Returns an iterator over trainData with the given batch size. */
public static RecordReaderDataSetIterator trainIterator(int batchSize) throws Exception {
return makeIterator(trainData, pathDirectory, batchSize);
/** Returns an iterator over testData with the given batch size. */
public static RecordReaderDataSetIterator testIterator(int batchSize) throws Exception {
return makeIterator(testData, pathDirectory, batchSize);
/**
 * Resolves the dataset directory and samples its files into trainData and testData.
 * NOTE(review): trainData, testData and pathDirectory are only assigned here, so this presumably
 * has to run before trainIterator()/testIterator(); no call to setup() is visible in this excerpt.
 */
public static void setup() throws IOException {
log.info("Load data...");
// NOTE(review): the Paths.get( call on the next line is truncated in this excerpt.
dataDir = Paths.get(
pathDirectory = Paths.get(dataDir,"face_mask_dataset");
FileSplit fileSplit = new FileSplit(new File(pathDirectory.toString()),allowedFormats,rng);
PathFilter pathFilter = new RandomPathFilter(rng,allowedFormats);
// Two weights are passed, so two splits come back: index 0 is used for training, index 1 for testing.
InputSplit[] sample = fileSplit.sample(pathFilter, splitRatio,1-splitRatio);
trainData = sample[0];
testData = sample[1];
/**
 * Loads a saved face-mask detector or fine-tunes a pretrained YOLO2 graph, and contains helpers
 * that draw predictions on test images and webcam frames.
 * NOTE(review): closing braces and some statements are missing from this excerpt.
 */
public class faceMaskPreTrained {
private static final Logger log = LoggerFactory.getLogger(ai.certifai.groupProjek.faceMaskPreTrained.class);
private static int seed = 420;
// Threshold passed to Yolo2OutputLayer.getPredictedObjects() when extracting detections.
private static double detectionThreshold = 0.9;
// Used to size the final convolution output (nBoxes * (5 + nClasses)); priorBoxes has the same number of rows.
private static int nBoxes = 3;
// NOTE(review): lambdaNoObj and lambdaCoord are not referenced in the lines shown; presumably they
// are arguments of the Yolo2OutputLayer builder chain that is cut off in this excerpt.
private static double lambdaNoObj = 0.7;
private static double lambdaCoord = 1.0;
// Converted to an INDArray in main() and passed to getComputationGraph() as the priors.
private static double[][] priorBoxes = {{1, 1}, {2, 1}, {1, 2}};
private static int batchSize = 3;
private static int nEpochs = 1;
private static double learningRate = 1e-4;
private static int nClasses = 3;
// Filled in main() from the training iterator; indexed by predicted class in drawResults().
private static List<String> labels;
// Model file under the working directory; main() restores it when it exists, otherwise trains and writes it.
private static File modelFilename = new File(System.getProperty("user.dir"), "generated-models/facemask_detector.zip");
private static ComputationGraph model;
// Latest frame grabbed in doInference(); also read inside the worker-thread lambda there.
private static Frame frame = null;
private static final Scalar GREEN = RGB(0, 255.0, 0);
private static final Scalar YELLOW = RGB(255, 255, 0);
private static final Scalar RED = RGB(255, 0, 0);
// Box colour looked up by predicted class index in drawResults().
private static Scalar[] colormap = {GREEN, YELLOW, RED};
// Caption text, overwritten for every detected object in drawResults().
private static String labeltext = null;
/**
 * Entry point: creates the train/test iterators, then either restores the saved model or builds a
 * transfer-learning model from pretrained YOLO2, trains it for nEpochs and writes it to modelFilename.
 * NOTE(review): faceMaskIterator.setup() is not called in the lines shown - confirm it runs before
 * the iterators are created.
 */
public static void main(String[] args) throws Exception {
RecordReaderDataSetIterator trainIter = faceMaskIterator.trainIterator(batchSize);
// The test iterator is created with a batch size of 1.
RecordReaderDataSetIterator testIter = faceMaskIterator.testIterator(1);
labels = trainIter.getLabels();
if (modelFilename.exists()) {
log.info("Load model...");
model = ModelSerializer.restoreComputationGraph(modelFilename);
} else {
INDArray priors = Nd4j.create(priorBoxes);
log.info("Build model...");
ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained();
FineTuneConfiguration fineTuneConf = getFineTuneConfiguration();
model = getComputationGraph(pretrained, priors, fineTuneConf);
log.info("Train model...");
// NOTE(review): server is obtained and storage is created, but no statement connecting the two
// is visible in this excerpt.
UIServer server = UIServer.getInstance();
StatsStorage storage = new InMemoryStatsStorage();
model.setListeners(new ScoreIterationListener(5), new StatsListener(storage,5));
// Epochs are counted from 1 to nEpochs inclusive so the log message shows a 1-based epoch number.
for (int i = 1; i < nEpochs + 1; i++) {
// NOTE(review): the body of this while loop (presumably the fit call on the next batch) and any
// iterator reset between epochs are missing from this excerpt.
while (trainIter.hasNext()) {
log.info("*** Completed epoch {} ***", i);
ModelSerializer.writeModel(model, modelFilename, true);
System.out.println("Model saved.");
// Evaluate the model's accuracy by using the test iterator.
// Inference the model and process the webcam stream and make predictions.
// NOTE(review): the calls that the two comments above describe are not present in this excerpt.
/**
 * Builds the fine-tuned graph on top of the pretrained network with TransferLearning.GraphBuilder.
 * NOTE(review): most of the builder chain is missing from this excerpt, so priors and fineTuneConf
 * are not referenced in the lines shown.
 *
 * @param pretrained   pretrained graph handed to the builder
 * @param priors       prior boxes (main() creates them from priorBoxes)
 * @param fineTuneConf fine-tune configuration (main() obtains it from getFineTuneConfiguration())
 */
private static ComputationGraph getComputationGraph(ComputationGraph pretrained, INDArray priors, FineTuneConfiguration fineTuneConf) {
return new TransferLearning.GraphBuilder(pretrained)
// 1x1 convolution with stride 1 whose output depth is nBoxes * (5 + nClasses);
// presumably 5 = four box values plus one confidence per box - confirm.
new ConvolutionLayer.Builder(1, 1)
.nOut(nBoxes * (5 + nClasses))
.stride(1, 1)
new Yolo2OutputLayer.Builder()
/**
 * Creates the fine-tune configuration; the visible part sets an Adam updater using learningRate.
 * NOTE(review): the rest of the builder chain, including the final build call, is missing from this excerpt.
 */
private static FineTuneConfiguration getFineTuneConfiguration() {
return new FineTuneConfiguration.Builder()
.updater(new Adam.Builder().learningRate(learningRate).build())
// Evaluate visually the performance of the trained object detection model
/**
 * For every batch of the test iterator: runs the model, extracts and filters the detections,
 * converts the features back to an 8-bit image, doubles its size and draws the detections on it.
 * NOTE(review): the statement that shows the image on the canvas and the closing braces are
 * missing from this excerpt; converter is not referenced in the lines shown.
 *
 * @param test iterator over the test data
 */
private static void OfflineValidationWithTestDataset(RecordReaderDataSetIterator test) throws InterruptedException {
NativeImageLoader imageLoader = new NativeImageLoader();
// NOTE(review): the canvas is created without a gamma argument; the answer text further down in
// this file attributes the over-bright display to CanvasFrame's default gamma correction.
CanvasFrame canvas = new CanvasFrame("Validate Test Dataset");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
Mat convertedMat = new Mat();
Mat convertedMat_big = new Mat();
// Stops when the test data is exhausted or the canvas is no longer visible.
while (test.hasNext() && canvas.isVisible()) {
org.nd4j.linalg.dataset.DataSet ds = test.next();
INDArray features = ds.getFeatures();
INDArray results = model.outputSingle(features);
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
YoloUtils.nms(objs, 0.4);
Mat mat = imageLoader.asMat(features);
// Converts to 8-bit with scale factor 255 and offset 0; the iterator's pre-processor is an
// ImagePreProcessingScaler(0, 1), so this presumably maps the features back to 0..255.
mat.convertTo(convertedMat, CV_8U, 255, 0);
// The image is shown at twice its original width and height.
int w = mat.cols() * 2;
int h = mat.rows() * 2;
resize(convertedMat, convertedMat_big, new Size(w, h));
convertedMat_big = drawResults(objs, convertedMat_big, w, h);
// Stream video frames from Webcam and run them through YOLOv2 model and get predictions
/**
 * Grabs webcam frames in a loop and, inside a worker-thread lambda, optionally mirrors each frame,
 * resizes a copy to the YOLO input size, runs the model and draws the detections on the frame.
 * NOTE(review): many statements are missing from this excerpt (catch bodies, the body of one try
 * block, the code that starts the thread and displays the image, and all closing braces).
 */
private static void doInference() {
String cameraPos = "front";
int cameraNum = 0;
Thread thread = null;
// Loader constructed with a BGR-to-RGB colour conversion transform.
NativeImageLoader loader = new NativeImageLoader(
new ColorConversionTransform(COLOR_BGR2RGB));
// NOTE(review): scaler is never applied to inputImage in the lines shown, although the training
// iterator attaches an ImagePreProcessingScaler(0, 1) - confirm that the call applying it was
// only lost from this excerpt.
ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
// cameraPos is the constant "front" above, so this validation branch cannot be taken as written.
if (!cameraPos.equals("front") && !cameraPos.equals("back")) {
try {
throw new Exception("Unknown argument for camera position. Choose between front and back");
} catch (Exception e) {
FrameGrabber grabber = null;
try {
grabber = FrameGrabber.createDefault(cameraNum);
} catch (FrameGrabber.Exception e) {
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
// NOTE(review): the body of the following try block is missing from this excerpt.
try {
} catch (FrameGrabber.Exception e) {
// NOTE(review): the canvas is created without a gamma argument; the answer text further down in
// this file recommends passing CanvasFrame.getDefaultGamma()/grabber.getGamma().
CanvasFrame canvas = new CanvasFrame("Object Detection");
int w = grabber.getImageWidth();
int h = grabber.getImageHeight();
canvas.setCanvasSize(w, h);
while (true) {
try {
frame = grabber.grab();
} catch (FrameGrabber.Exception e) {
//if a thread is null, create new thread
if (thread == null) {
thread = new Thread(() ->
while (frame != null) {
try {
Mat rawImage = new Mat();
//Flip the camera if opening front camera
if (cameraPos.equals("front")) {
Mat inputImage = converter.convert(frame);
flip(inputImage, rawImage, 1);
} else {
rawImage = converter.convert(frame);
// Only the copy fed to the network is resized to yolowidth x yoloheight; rawImage keeps its size.
Mat resizeImage = new Mat();
resize(rawImage, resizeImage, new Size(faceMaskIterator.yolowidth, faceMaskIterator.yoloheight));
INDArray inputImage = loader.asMatrix(resizeImage);
INDArray outputs = model.outputSingle(inputImage);
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
List<DetectedObject> objs = yout.getPredictedObjects(outputs, detectionThreshold);
YoloUtils.nms(objs, 0.4);
// Boxes are drawn on the unresized frame, so the grabber's width and height are used for scaling.
rawImage = drawResults(objs, rawImage, w, h);
} catch (Exception e) {
throw new RuntimeException(e);
KeyEvent t = null;
try {
t = canvas.waitKey(33);
} catch (InterruptedException e) {
// Checks for the Q key; NOTE(review): the action taken on Q is missing from this excerpt.
if ((t != null) && (t.getKeyCode() == KeyEvent.VK_Q)) {
/**
 * Draws a bounding box and a caption of the form "label confidence%" for every detected object.
 * NOTE(review): the closing braces of the loop and of the method are missing from this excerpt.
 *
 * @param objects detections to draw
 * @param mat     image passed to the drawing calls and returned
 * @param w       width used to scale x coordinates
 * @param h       height used to scale y coordinates
 * @return the mat that was passed in
 */
private static Mat drawResults(List<DetectedObject> objects, Mat mat, int w, int h) {
for (DetectedObject obj : objects) {
double[] xy1 = obj.getTopLeftXY();
double[] xy2 = obj.getBottomRightXY();
String label = labels.get(obj.getPredictedClass());
// Corner coordinates are divided by the grid size and multiplied by w / h, i.e. they are
// treated as grid-cell units and converted to positions in the target image.
int x1 = (int) Math.round(w * xy1[0] / faceMaskIterator.gridWidth);
int y1 = (int) Math.round(h * xy1[1] / faceMaskIterator.gridHeight);
int x2 = (int) Math.round(w * xy2[0] / faceMaskIterator.gridWidth);
int y2 = (int) Math.round(h * xy2[1] / faceMaskIterator.gridHeight);
//Draw bounding box
rectangle(mat, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()], 2, 0, 0);
//Display label text
// Confidence is shown as a percentage with two decimals; the result is stored in the static labeltext field.
labeltext = label + " " + String.format("%.2f", obj.getConfidence() * 100) + "%";
int[] baseline = {0};
Size textSize = getTextSize(labeltext, FONT_HERSHEY_DUPLEX, 1, 1, baseline);
// Filled background in the class colour, anchored at (x1 + 2, y2 - 2) and sized from textSize,
// followed by the caption text in RGB(0, 0, 0) at the same anchor.
rectangle(mat, new Point(x1 + 2, y2 - 2), new Point(x1 + 2 + textSize.get(0), y2 - 2 - textSize.get(1)), colormap[obj.getPredictedClass()], FILLED, 0, 0);
putText(mat, labeltext, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, RGB(0, 0, 0));
return mat;
CanvasFrame 默认会尝试进行伽玛校正,因为用于 CV 的相机通常需要它;但廉价的网络摄像头通常输出的已经是伽玛校正后的图像,因此请务必像下面这样把这一点告知 CanvasFrame:
// We should also specify the relative monitor/camera response for proper gamma correction.
// The second constructor argument is CanvasFrame's default gamma divided by the gamma reported by the grabber.
CanvasFrame frame = new CanvasFrame("Some Title", CanvasFrame.getDefaultGamma()/grabber.getGamma());
目前,我被要求使用 DL4J 和 YOLOv2 架构编写 CNN 代码。但问题是,在模型训练完成后,我做了一个简单的 GUI 来进行验证测试,结果显示出来的图像过亮,只有有时图像才能正常显示。我不确定这个问题是从哪里来的,是在训练的最早阶段还是其他阶段。在这里,我附上了我现在拥有的代码。对于迭代器:
/**
 * Builds the training and test RecordReaderDataSetIterators for the face-mask object-detection
 * dataset; all state is held in static fields.
 * NOTE(review): closing braces and some statements are missing from this excerpt.
 */
public class faceMaskIterator {
private static final Logger log = org.slf4j.LoggerFactory.getLogger(faceMaskIterator.class);
// Seed for the shared Random that setup() hands to FileSplit and RandomPathFilter.
private static final int seed = 123;
private static Random rng = new Random(seed);
private static String dataDir;
private static Path pathDirectory;
// Assigned in setup() from the two parts of the sampled split.
private static InputSplit trainData, testData;
private static final String[] allowedFormats = NativeImageLoader.ALLOWED_FORMATS;
// Weight of the first (training) part; setup() passes 1 - splitRatio for the second part.
private static final double splitRatio = 0.8;
private static final int nChannels = 3;
// Grid size and image size passed to ObjectDetectionRecordReader in makeIterator().
public static final int gridWidth = 13;
public static final int gridHeight = 13;
public static final int yolowidth = 416;
public static final int yoloheight = 416;
/**
 * Creates an iterator backed by an ObjectDetectionRecordReader with a VocLabelProvider, and
 * attaches an ImagePreProcessingScaler(0, 1) as its pre-processor.
 *
 * @param split     train or test split supplied by the caller.
 *                  NOTE(review): not referenced in the lines shown; a recordReader.initialize(split)
 *                  call is presumably missing from this excerpt - confirm.
 * @param dir       directory whose string form is handed to VocLabelProvider
 * @param batchSize batch size passed to the iterator
 * @return the configured iterator
 */
private static RecordReaderDataSetIterator makeIterator(InputSplit split, Path dir, int batchSize) throws Exception {
ObjectDetectionRecordReader recordReader = new ObjectDetectionRecordReader(yoloheight, yolowidth, nChannels,
gridHeight, gridWidth, new VocLabelProvider(dir.toString()));
RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(recordReader, batchSize, 1, 1,true);
// Scaler configured with min 0 / max 1; the validation display in faceMaskPreTrained multiplies
// the features by 255 before showing them, presumably to undo this scaling.
iter.setPreProcessor(new ImagePreProcessingScaler(0, 1));
return iter;
/** Returns an iterator over trainData with the given batch size. */
public static RecordReaderDataSetIterator trainIterator(int batchSize) throws Exception {
return makeIterator(trainData, pathDirectory, batchSize);
/** Returns an iterator over testData with the given batch size. */
public static RecordReaderDataSetIterator testIterator(int batchSize) throws Exception {
return makeIterator(testData, pathDirectory, batchSize);
/**
 * Resolves the dataset directory and samples its files into trainData and testData.
 * NOTE(review): trainData, testData and pathDirectory are only assigned here, so this presumably
 * has to run before trainIterator()/testIterator(); no call to setup() is visible in this excerpt.
 */
public static void setup() throws IOException {
log.info("Load data...");
// NOTE(review): the Paths.get( call on the next line is truncated in this excerpt.
dataDir = Paths.get(
pathDirectory = Paths.get(dataDir,"face_mask_dataset");
FileSplit fileSplit = new FileSplit(new File(pathDirectory.toString()),allowedFormats,rng);
PathFilter pathFilter = new RandomPathFilter(rng,allowedFormats);
// Two weights are passed, so two splits come back: index 0 is used for training, index 1 for testing.
InputSplit[] sample = fileSplit.sample(pathFilter, splitRatio,1-splitRatio);
trainData = sample[0];
testData = sample[1];
/**
 * Loads a saved face-mask detector or fine-tunes a pretrained YOLO2 graph, and contains helpers
 * that draw predictions on test images and webcam frames.
 * NOTE(review): closing braces and some statements are missing from this excerpt.
 */
public class faceMaskPreTrained {
private static final Logger log = LoggerFactory.getLogger(ai.certifai.groupProjek.faceMaskPreTrained.class);
private static int seed = 420;
// Threshold passed to Yolo2OutputLayer.getPredictedObjects() when extracting detections.
private static double detectionThreshold = 0.9;
// Used to size the final convolution output (nBoxes * (5 + nClasses)); priorBoxes has the same number of rows.
private static int nBoxes = 3;
// NOTE(review): lambdaNoObj and lambdaCoord are not referenced in the lines shown; presumably they
// are arguments of the Yolo2OutputLayer builder chain that is cut off in this excerpt.
private static double lambdaNoObj = 0.7;
private static double lambdaCoord = 1.0;
// Converted to an INDArray in main() and passed to getComputationGraph() as the priors.
private static double[][] priorBoxes = {{1, 1}, {2, 1}, {1, 2}};
private static int batchSize = 3;
private static int nEpochs = 1;
private static double learningRate = 1e-4;
private static int nClasses = 3;
// Filled in main() from the training iterator; indexed by predicted class in drawResults().
private static List<String> labels;
// Model file under the working directory; main() restores it when it exists, otherwise trains and writes it.
private static File modelFilename = new File(System.getProperty("user.dir"), "generated-models/facemask_detector.zip");
private static ComputationGraph model;
// Latest frame grabbed in doInference(); also read inside the worker-thread lambda there.
private static Frame frame = null;
private static final Scalar GREEN = RGB(0, 255.0, 0);
private static final Scalar YELLOW = RGB(255, 255, 0);
private static final Scalar RED = RGB(255, 0, 0);
// Box colour looked up by predicted class index in drawResults().
private static Scalar[] colormap = {GREEN, YELLOW, RED};
// Caption text, overwritten for every detected object in drawResults().
private static String labeltext = null;
/**
 * Entry point: creates the train/test iterators, then either restores the saved model or builds a
 * transfer-learning model from pretrained YOLO2, trains it for nEpochs and writes it to modelFilename.
 * NOTE(review): faceMaskIterator.setup() is not called in the lines shown - confirm it runs before
 * the iterators are created.
 */
public static void main(String[] args) throws Exception {
RecordReaderDataSetIterator trainIter = faceMaskIterator.trainIterator(batchSize);
// The test iterator is created with a batch size of 1.
RecordReaderDataSetIterator testIter = faceMaskIterator.testIterator(1);
labels = trainIter.getLabels();
if (modelFilename.exists()) {
log.info("Load model...");
model = ModelSerializer.restoreComputationGraph(modelFilename);
} else {
INDArray priors = Nd4j.create(priorBoxes);
log.info("Build model...");
ComputationGraph pretrained = (ComputationGraph) YOLO2.builder().build().initPretrained();
FineTuneConfiguration fineTuneConf = getFineTuneConfiguration();
model = getComputationGraph(pretrained, priors, fineTuneConf);
log.info("Train model...");
// NOTE(review): server is obtained and storage is created, but no statement connecting the two
// is visible in this excerpt.
UIServer server = UIServer.getInstance();
StatsStorage storage = new InMemoryStatsStorage();
model.setListeners(new ScoreIterationListener(5), new StatsListener(storage,5));
// Epochs are counted from 1 to nEpochs inclusive so the log message shows a 1-based epoch number.
for (int i = 1; i < nEpochs + 1; i++) {
// NOTE(review): the body of this while loop (presumably the fit call on the next batch) and any
// iterator reset between epochs are missing from this excerpt.
while (trainIter.hasNext()) {
log.info("*** Completed epoch {} ***", i);
ModelSerializer.writeModel(model, modelFilename, true);
System.out.println("Model saved.");
// Evaluate the model's accuracy by using the test iterator.
// Inference the model and process the webcam stream and make predictions.
// NOTE(review): the calls that the two comments above describe are not present in this excerpt.
/**
 * Builds the fine-tuned graph on top of the pretrained network with TransferLearning.GraphBuilder.
 * NOTE(review): most of the builder chain is missing from this excerpt, so priors and fineTuneConf
 * are not referenced in the lines shown.
 *
 * @param pretrained   pretrained graph handed to the builder
 * @param priors       prior boxes (main() creates them from priorBoxes)
 * @param fineTuneConf fine-tune configuration (main() obtains it from getFineTuneConfiguration())
 */
private static ComputationGraph getComputationGraph(ComputationGraph pretrained, INDArray priors, FineTuneConfiguration fineTuneConf) {
return new TransferLearning.GraphBuilder(pretrained)
// 1x1 convolution with stride 1 whose output depth is nBoxes * (5 + nClasses);
// presumably 5 = four box values plus one confidence per box - confirm.
new ConvolutionLayer.Builder(1, 1)
.nOut(nBoxes * (5 + nClasses))
.stride(1, 1)
new Yolo2OutputLayer.Builder()
/**
 * Creates the fine-tune configuration; the visible part sets an Adam updater using learningRate.
 * NOTE(review): the rest of the builder chain, including the final build call, is missing from this excerpt.
 */
private static FineTuneConfiguration getFineTuneConfiguration() {
return new FineTuneConfiguration.Builder()
.updater(new Adam.Builder().learningRate(learningRate).build())
// Evaluate visually the performance of the trained object detection model
/**
 * For every batch of the test iterator: runs the model, extracts and filters the detections,
 * converts the features back to an 8-bit image, doubles its size and draws the detections on it.
 * NOTE(review): the statement that shows the image on the canvas and the closing braces are
 * missing from this excerpt; converter is not referenced in the lines shown.
 *
 * @param test iterator over the test data
 */
private static void OfflineValidationWithTestDataset(RecordReaderDataSetIterator test) throws InterruptedException {
NativeImageLoader imageLoader = new NativeImageLoader();
// NOTE(review): the canvas is created without a gamma argument; the answer text further down in
// this file attributes the over-bright display to CanvasFrame's default gamma correction.
CanvasFrame canvas = new CanvasFrame("Validate Test Dataset");
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
Mat convertedMat = new Mat();
Mat convertedMat_big = new Mat();
// Stops when the test data is exhausted or the canvas is no longer visible.
while (test.hasNext() && canvas.isVisible()) {
org.nd4j.linalg.dataset.DataSet ds = test.next();
INDArray features = ds.getFeatures();
INDArray results = model.outputSingle(features);
List<DetectedObject> objs = yout.getPredictedObjects(results, detectionThreshold);
YoloUtils.nms(objs, 0.4);
Mat mat = imageLoader.asMat(features);
// Converts to 8-bit with scale factor 255 and offset 0; the iterator's pre-processor is an
// ImagePreProcessingScaler(0, 1), so this presumably maps the features back to 0..255.
mat.convertTo(convertedMat, CV_8U, 255, 0);
// The image is shown at twice its original width and height.
int w = mat.cols() * 2;
int h = mat.rows() * 2;
resize(convertedMat, convertedMat_big, new Size(w, h));
convertedMat_big = drawResults(objs, convertedMat_big, w, h);
// Stream video frames from Webcam and run them through YOLOv2 model and get predictions
/**
 * Grabs webcam frames in a loop and, inside a worker-thread lambda, optionally mirrors each frame,
 * resizes a copy to the YOLO input size, runs the model and draws the detections on the frame.
 * NOTE(review): many statements are missing from this excerpt (catch bodies, the body of one try
 * block, the code that starts the thread and displays the image, and all closing braces).
 */
private static void doInference() {
String cameraPos = "front";
int cameraNum = 0;
Thread thread = null;
// Loader constructed with a BGR-to-RGB colour conversion transform.
NativeImageLoader loader = new NativeImageLoader(
new ColorConversionTransform(COLOR_BGR2RGB));
// NOTE(review): scaler is never applied to inputImage in the lines shown, although the training
// iterator attaches an ImagePreProcessingScaler(0, 1) - confirm that the call applying it was
// only lost from this excerpt.
ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);
// cameraPos is the constant "front" above, so this validation branch cannot be taken as written.
if (!cameraPos.equals("front") && !cameraPos.equals("back")) {
try {
throw new Exception("Unknown argument for camera position. Choose between front and back");
} catch (Exception e) {
FrameGrabber grabber = null;
try {
grabber = FrameGrabber.createDefault(cameraNum);
} catch (FrameGrabber.Exception e) {
OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat();
// NOTE(review): the body of the following try block is missing from this excerpt.
try {
} catch (FrameGrabber.Exception e) {
// NOTE(review): the canvas is created without a gamma argument; the answer text further down in
// this file recommends passing CanvasFrame.getDefaultGamma()/grabber.getGamma().
CanvasFrame canvas = new CanvasFrame("Object Detection");
int w = grabber.getImageWidth();
int h = grabber.getImageHeight();
canvas.setCanvasSize(w, h);
while (true) {
try {
frame = grabber.grab();
} catch (FrameGrabber.Exception e) {
//if a thread is null, create new thread
if (thread == null) {
thread = new Thread(() ->
while (frame != null) {
try {
Mat rawImage = new Mat();
//Flip the camera if opening front camera
if (cameraPos.equals("front")) {
Mat inputImage = converter.convert(frame);
flip(inputImage, rawImage, 1);
} else {
rawImage = converter.convert(frame);
// Only the copy fed to the network is resized to yolowidth x yoloheight; rawImage keeps its size.
Mat resizeImage = new Mat();
resize(rawImage, resizeImage, new Size(faceMaskIterator.yolowidth, faceMaskIterator.yoloheight));
INDArray inputImage = loader.asMatrix(resizeImage);
INDArray outputs = model.outputSingle(inputImage);
org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) model.getOutputLayer(0);
List<DetectedObject> objs = yout.getPredictedObjects(outputs, detectionThreshold);
YoloUtils.nms(objs, 0.4);
// Boxes are drawn on the unresized frame, so the grabber's width and height are used for scaling.
rawImage = drawResults(objs, rawImage, w, h);
} catch (Exception e) {
throw new RuntimeException(e);
KeyEvent t = null;
try {
t = canvas.waitKey(33);
} catch (InterruptedException e) {
// Checks for the Q key; NOTE(review): the action taken on Q is missing from this excerpt.
if ((t != null) && (t.getKeyCode() == KeyEvent.VK_Q)) {
/**
 * Draws a bounding box and a caption of the form "label confidence%" for every detected object.
 * NOTE(review): the closing braces of the loop and of the method are missing from this excerpt.
 *
 * @param objects detections to draw
 * @param mat     image passed to the drawing calls and returned
 * @param w       width used to scale x coordinates
 * @param h       height used to scale y coordinates
 * @return the mat that was passed in
 */
private static Mat drawResults(List<DetectedObject> objects, Mat mat, int w, int h) {
for (DetectedObject obj : objects) {
double[] xy1 = obj.getTopLeftXY();
double[] xy2 = obj.getBottomRightXY();
String label = labels.get(obj.getPredictedClass());
// Corner coordinates are divided by the grid size and multiplied by w / h, i.e. they are
// treated as grid-cell units and converted to positions in the target image.
int x1 = (int) Math.round(w * xy1[0] / faceMaskIterator.gridWidth);
int y1 = (int) Math.round(h * xy1[1] / faceMaskIterator.gridHeight);
int x2 = (int) Math.round(w * xy2[0] / faceMaskIterator.gridWidth);
int y2 = (int) Math.round(h * xy2[1] / faceMaskIterator.gridHeight);
//Draw bounding box
rectangle(mat, new Point(x1, y1), new Point(x2, y2), colormap[obj.getPredictedClass()], 2, 0, 0);
//Display label text
// Confidence is shown as a percentage with two decimals; the result is stored in the static labeltext field.
labeltext = label + " " + String.format("%.2f", obj.getConfidence() * 100) + "%";
int[] baseline = {0};
Size textSize = getTextSize(labeltext, FONT_HERSHEY_DUPLEX, 1, 1, baseline);
// Filled background in the class colour, anchored at (x1 + 2, y2 - 2) and sized from textSize,
// followed by the caption text in RGB(0, 0, 0) at the same anchor.
rectangle(mat, new Point(x1 + 2, y2 - 2), new Point(x1 + 2 + textSize.get(0), y2 - 2 - textSize.get(1)), colormap[obj.getPredictedClass()], FILLED, 0, 0);
putText(mat, labeltext, new Point(x1 + 2, y2 - 2), FONT_HERSHEY_DUPLEX, 1, RGB(0, 0, 0));
return mat;
CanvasFrame 默认会尝试进行伽玛校正,因为用于 CV 的相机通常需要它;但廉价的网络摄像头通常输出的已经是伽玛校正后的图像,因此请务必像下面这样把这一点告知 CanvasFrame:
// We should also specify the relative monitor/camera response for proper gamma correction.
// The second constructor argument is CanvasFrame's default gamma divided by the gamma reported by the grabber.
CanvasFrame frame = new CanvasFrame("Some Title", CanvasFrame.getDefaultGamma()/grabber.getGamma());