我的数据生成器没有获取新数据。每次迭代只获取相同的数据
My data generator is not getting new data. Just gets same data each interation
我在 python/Keras 中创建了一个数据生成器,用于以 batchesize=5 提取文件名和标签。每次迭代都会获得相同的文件名和标签。我希望它每次迭代都能获得新的(成功的)文件名和标签。
我查看了很多示例并阅读了文档,但无法弄明白。
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
while True:
images = []
labels = []
cnt=0
while len(images) < batchsize:
images.append(imgfns[cnt])
labels.append(imglabels[cnt])
cnt=cnt+1
#for ii in range(batchsize):
# #img = np.load(imgfns[ii])
# #images.append(img)
# images.append(imgfns[ii])
# labels.append(imglabels[ii])
#for image, label in zip(imgfns, imglabels):
# #img = np.load(image)
# #images.append(img)
# images.append(image)
# labels.append(label)
print(images)
print(labels)
print('********** cnt = ', cnt)
yield images, labels
train_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS)
valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS)
# train the network
H = model.fit_generator(
train_gen,
steps_per_epoch=NUM_TRAIN_IMAGES // BS,
validation_data=valid_gen,
validation_steps=NUM_TEST_IMAGES // BS,
epochs=NUM_EPOCHS)
这是我得到的输出示例。您可以看到,每次它通过生成器时,它都会获取相同的数据。 "Epoch 1/10"后的第一行有5个文件名。下一行有 5 个标签(对应 batchsize=5)。例如,您可以在每个输出中看到第一个文件名是“... 508.npy”等。并且每次迭代的标签都相同。
Epoch 1/10
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#508.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1218.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#71.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#551.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#843.npy']
[1, 0, 0, 0, 1]
********** cnt = 5
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#508.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1218.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#71.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#551.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#843.npy']
[1, 0, 0, 0, 1]
********** cnt = 5
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#508.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1218.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#71.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#551.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#843.npy']
[1, 0, 0, 0, 1]
********** cnt = 5
问题是您每次迭代都设置 cnt=0
。你抓取 5 个文件名,产生它们,然后重复确切的事情,所以你总是抓取前 5 个。你想要更改
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
while True:
images = []
labels = []
cnt=0
到
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
cnt=0
while True:
images = []
labels = []
您还需要确保 cnt
保持在列表的范围内。所以像
while len(images) < batchsize and cnt < len(imgfns):
# blah
我在 python/Keras 中创建了一个数据生成器,用于以 batchesize=5 提取文件名和标签。每次迭代都会获得相同的文件名和标签。我希望它每次迭代都能获得新的(成功的)文件名和标签。
我查看了很多示例并阅读了文档,但无法弄明白。
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
while True:
images = []
labels = []
cnt=0
while len(images) < batchsize:
images.append(imgfns[cnt])
labels.append(imglabels[cnt])
cnt=cnt+1
#for ii in range(batchsize):
# #img = np.load(imgfns[ii])
# #images.append(img)
# images.append(imgfns[ii])
# labels.append(imglabels[ii])
#for image, label in zip(imgfns, imglabels):
# #img = np.load(image)
# #images.append(img)
# images.append(image)
# labels.append(label)
print(images)
print(labels)
print('********** cnt = ', cnt)
yield images, labels
train_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS)
valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS)
# train the network
H = model.fit_generator(
train_gen,
steps_per_epoch=NUM_TRAIN_IMAGES // BS,
validation_data=valid_gen,
validation_steps=NUM_TEST_IMAGES // BS,
epochs=NUM_EPOCHS)
这是我得到的输出示例。您可以看到,每次它通过生成器时,它都会获取相同的数据。 "Epoch 1/10"后的第一行有5个文件名。下一行有 5 个标签(对应 batchsize=5)。例如,您可以在每个输出中看到第一个文件名是“... 508.npy”等。并且每次迭代的标签都相同。
Epoch 1/10
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#508.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1218.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#71.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#551.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#843.npy']
[1, 0, 0, 0, 1]
********** cnt = 5
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#508.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1218.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#71.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#551.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#843.npy']
[1, 0, 0, 0, 1]
********** cnt = 5
['C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#508.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#1218.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#71.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\scrap_48-81\JBCC_Norm_Formatted_48-81_#551.npy', 'C:\Users\jfhauris\Documents\xtemp\ML GEO\MLGeoCode\FormattedDataStore\uxo_48-81\JBCC_Norm_Formatted_48-81_#843.npy']
[1, 0, 0, 0, 1]
********** cnt = 5
问题是您每次迭代都设置 cnt=0
。你抓取 5 个文件名,产生它们,然后重复确切的事情,所以你总是抓取前 5 个。你想要更改
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
while True:
images = []
labels = []
cnt=0
到
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
cnt=0
while True:
images = []
labels = []
您还需要确保 cnt
保持在列表的范围内。所以像
while len(images) < batchsize and cnt < len(imgfns):
# blah