CSV文件即可。不会被转换成浮点数

CSV file can. not be converted into a float

我目前正在尝试在数据 csv 文件上测试 ANN。当我训练 ANN 时,这段代码第一次没有问题,但是当我为测试做它时出现了这个错误。如果我再次 运行 该程序,它会吐出答案,但显然我不希望每次都必须不断重新 运行 它

代码

   # load the mnist test data CSV file into a list 
    test_data_file = open("mnist_test.csv", 'r') 
    test_data_list = test_data_file.readlines() 
    test_data_file.close() 
 # go through all the records in the test data set
 for record in test_data_list:
     # split the record by the ',' commas
all_values = record.split(',')
# correct answer is first value
correct_label = int(all_values[0])
# scale and shift the inputs
inputs = (numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
# query the network
outputs = n.query(inputs)
# the index of the highest value corresponds to the label
label = numpy.argmax(outputs)
# append correct or incorrect to list
if (label == correct_label):
    # network's answer matches correct answer, add 1 to scorecard
    scorecard.append(1)
else:
    # network's answer doesn't match correct answer, add 0 to scorecard
    scorecard.append(0)
  

错误

 ---------------------------------------------------------------------------
 ValueError                                Traceback (most recent call last)
 <ipython-input-7-026bfaa95a53> in <module>
 11     correct_label = int(all_values[0])
 12     # scale and shift the inputs
 ---> 13     inputs = (numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
 14     # query the network
 15     outputs = n.query(inputs)

 <__array_function__ internals> in asfarray(*args, **kwargs)

 ~/opt/anaconda3/lib/python3.8/site-packages/numpy/lib/type_check.py in asfarray(a, dtype)
113     if not _nx.issubdtype(dtype, _nx.inexact):
114         dtype = _nx.float_
  --> 115     return asarray(a, dtype=dtype)
116 
117 

 ~/opt/anaconda3/lib/python3.8/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
 83 
 84     """
   ---> 85     return array(a, dtype, copy=False, order=order)
 86 
 87 

  ValueError: could not convert string to float: ''

您收到错误消息是因为 numpy.asfarray(all_values[1:]) 在列表中发现了空字符串。这些来自 record.split(','),这意味着您的原始 CSV 文件的某些行中有一些空列。这就引出了一个问题,这对您的数据是否合法以及在这种情况下什么是好的默认设置。

由于您的数据是一个整数列 0,其余列是浮点数,您可以跳过自己读取文件并使用 numpy.genfromtxt() 一次调用从整个文件构建一个数组。然后,您可以拉出 correct_label 并将数组作为一个整体进行处理。我选择 np.nan 来填充空单元格 - 这只是一个猜测,您可能需要使用不同的值。

我不太清楚你的脚本结尾处发生了什么。但我认为构建正确的数组然后使用 for 循环来枚举它的行,我正在做同样的事情......不管它是什么!

import numpy as np

# read in entire data set
arr = np.genfromtxt("test.csv", dtype='f8', delimiter=',',
    missing_values=np.nan, filling_values=np.nan)

# column 0 is really integer correct labels, so extract to 1d array
correct_labels = arr[0].astype('i8')

# scale and shift remaining columns
arr = (np.delete(arr, 0, axis=1)*255.0*0.99)+0.01

for inputs in arr:
    outputs = n.query(inputs)
    # the index of the highest value corresponds to the label
    label = numpy.argmax(outputs)
    # append correct or incorrect to list
    if (label == correct_label):
        # network's answer matches correct answer, add 1 to scorecard
        scorecard.append(1)
    else:
        # network's answer doesn't match correct answer, add 0 to scorecard
        scorecard.append(0)