create_lmdb.py 文件的错误
Bug with the create_lmdb.py file
我正在使用 this github, following this 教程中的代码。
我所做的更改很少,因为我正在用我的数据训练 CNN。但是,我在 'create_lmdb.py'
文件中执行的更改可能存在问题。两个数据库的区别是:
首先:我正在使用 32x32 图像训练我的网络。
第二:我的数据库仅包含 灰度 图像。
但是 - 我也训练我的网络进行二元分类。
修改后,这是我的文件:
import os
import glob
import random
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
import lmdb
#Size of images
IMAGE_WIDTH = 32
IMAGE_HEIGHT = 32
def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
#Histogram Equalization
img = cv2.equalizeHist(img)
#img[:, :, 1] = cv2.equalizeHist(img[:, :, 1]) not a RGB
#img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
#Image Resizing
img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC) # make sure all the images are at the same size
return img
def make_datum(img, label):
#image is numpy.ndarray format. BGR instead of RGB
return caffe_pb2.Datum(
channels=1, #not an RGB image
width=IMAGE_WIDTH,
height=IMAGE_HEIGHT,
label=label,
data=img.tostring())
train_lmdb = '/home/roishik/Desktop/Thesis/Code/cafe_cnn/first/input/train_lmdb'
validation_lmdb = '/home/roishik/Desktop/Thesis/Code/cafe_cnn/first/input/validation_lmdb'
os.system('rm -rf ' + train_lmdb)
os.system('rm -rf ' + validation_lmdb)
train_data = [img for img in glob.glob("../input/train/*png")]
test_data = [img for img in glob.glob("../input/test1/*png")]
#Shuffle train_data
random.shuffle(train_data)
print 'Creating train_lmdb'
in_db = lmdb.open(train_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 == 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path:
label = 0
else:
label = 1
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nCreating validation_lmdb'
in_db = lmdb.open(validation_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 != 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
prec=int(img_path[(img_path.index('prec_')+5):(img_path.index('prec_')+8)])
if prec>50:
label = 1
else:
label = 0
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nFinished processing all images'
但我认为根据训练结果:.mdb 输出文件已损坏(可能为空或其他内容 - 甚至其权重为 47MB)。
有人能看出这个文件有什么问题吗?或者,给我 link 一个关于构建 lmdb 文件的好教程?
非常感谢您的帮助!
谢谢
如果您想创建一个 'lmdb'
图像数据集来训练分类网络,别着急! Caffe 已经有专门用于此目的的工具!
您正在寻找 $CAFFE_ROOT/build/tools/convert_imageset
工具,您可以找到非常详细的(如果我可以这么说的话;)教程 。
好的,我解决了!
在深入研究代码后,我注意到我只更新了验证数据集的标签(并跳过了训练数据):P
在这段代码中可以看到:
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path:
label = 0
else:
label = 1
属于原教程。
结论:如果您无法访问您的 lmdb 文件,可能是因为创建它的函数被破坏了。
我正在使用 this github, following this 教程中的代码。
我所做的更改很少,因为我正在用我的数据训练 CNN。但是,我在 'create_lmdb.py'
文件中执行的更改可能存在问题。两个数据库的区别是:
首先:我正在使用 32x32 图像训练我的网络。 第二:我的数据库仅包含 灰度 图像。 但是 - 我也训练我的网络进行二元分类。
修改后,这是我的文件:
import os
import glob
import random
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
import lmdb
#Size of images
IMAGE_WIDTH = 32
IMAGE_HEIGHT = 32
def transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT):
#Histogram Equalization
img = cv2.equalizeHist(img)
#img[:, :, 1] = cv2.equalizeHist(img[:, :, 1]) not a RGB
#img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
#Image Resizing
img = cv2.resize(img, (img_width, img_height), interpolation = cv2.INTER_CUBIC) # make sure all the images are at the same size
return img
def make_datum(img, label):
#image is numpy.ndarray format. BGR instead of RGB
return caffe_pb2.Datum(
channels=1, #not an RGB image
width=IMAGE_WIDTH,
height=IMAGE_HEIGHT,
label=label,
data=img.tostring())
train_lmdb = '/home/roishik/Desktop/Thesis/Code/cafe_cnn/first/input/train_lmdb'
validation_lmdb = '/home/roishik/Desktop/Thesis/Code/cafe_cnn/first/input/validation_lmdb'
os.system('rm -rf ' + train_lmdb)
os.system('rm -rf ' + validation_lmdb)
train_data = [img for img in glob.glob("../input/train/*png")]
test_data = [img for img in glob.glob("../input/test1/*png")]
#Shuffle train_data
random.shuffle(train_data)
print 'Creating train_lmdb'
in_db = lmdb.open(train_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 == 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path:
label = 0
else:
label = 1
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nCreating validation_lmdb'
in_db = lmdb.open(validation_lmdb, map_size=int(1e12))
with in_db.begin(write=True) as in_txn:
for in_idx, img_path in enumerate(train_data):
if in_idx % 6 != 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
prec=int(img_path[(img_path.index('prec_')+5):(img_path.index('prec_')+8)])
if prec>50:
label = 1
else:
label = 0
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(in_idx), datum.SerializeToString())
print '{:0>5d}'.format(in_idx) + ':' + img_path
in_db.close()
print '\nFinished processing all images'
但我认为根据训练结果:.mdb 输出文件已损坏(可能为空或其他内容 - 甚至其权重为 47MB)。
有人能看出这个文件有什么问题吗?或者,给我 link 一个关于构建 lmdb 文件的好教程?
非常感谢您的帮助! 谢谢
如果您想创建一个 'lmdb'
图像数据集来训练分类网络,别着急! Caffe 已经有专门用于此目的的工具!
您正在寻找 $CAFFE_ROOT/build/tools/convert_imageset
工具,您可以找到非常详细的(如果我可以这么说的话;)教程
好的,我解决了! 在深入研究代码后,我注意到我只更新了验证数据集的标签(并跳过了训练数据):P 在这段代码中可以看到:
img = transform_img(img, img_width=IMAGE_WIDTH, img_height=IMAGE_HEIGHT)
if 'cat' in img_path:
label = 0
else:
label = 1
属于原教程。
结论:如果您无法访问您的 lmdb 文件,可能是因为创建它的函数被破坏了。