Loading a dataset in a Jupyter notebook (Python)

I am running the following Python code in a Jupyter notebook:

# Run some setup code for this notebook.

import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

# This is a bit of magic to make matplotlib figures appear inline in the notebook
# rather than in a new window.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# Some more magic so that the notebook will reload external python modules;
# see 
%load_ext autoreload
%autoreload 2

followed by these statements:

# Load the raw CIFAR-10 data.
cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

# As a sanity check, we print out the size of the training and test data.
print ('Training data shape: ', X_train.shape)
print ('Training labels shape: ', y_train.shape)
print ('Test data shape: ', X_test.shape)
print ('Test labels shape: ', y_test.shape)

When I run the second part, I get the following error:

---------------------------------------------------------------------------
UnicodeDecodeError                        Traceback (most recent call last)
<ipython-input-5-9506c06e646a> in <module>()
      1 # Load the raw CIFAR-10 data.
      2 cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
----> 3 X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
      4 
      5 # As a sanity check, we print out the size of the training and test data.

C:\Users\lenovo\assignment1\cs231n\data_utils.py in load_CIFAR10(ROOT)
     20   for b in range(1,6):
     21     f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
---> 22     X, Y = load_CIFAR_batch(f)
     23     xs.append(X)
     24     ys.append(Y)

C:\Users\lenovo\assignment1\cs231n\data_utils.py in load_CIFAR_batch(filename)
      7   """ load single batch of cifar """
      8   with open(filename, 'rb') as f:
----> 9     datadict = pickle.load(f)
     10     X = datadict['data']
     11     Y = datadict['labels']

UnicodeDecodeError: 'ascii' codec can't decode byte 0x8b in position 6: ordinal not in range(128)

How can I fix this error? I am using Anaconda3 to run this code, and it looks like the code above was written for Anaconda2. Any suggestions for resolving this error?

More details:

I am trying to solve the assignment from this link: http://cs231n.github.io/assignments2016/assignment1/

Edit:

Adding data_utils.py, which contains the definition of load_CIFAR10:
import _pickle as pickle
import numpy as np
import os
from scipy.misc import imread

def load_CIFAR_batch(filename):
  """ load single batch of cifar """
  with open(filename, 'rb') as f:
    datadict = pickle.load(f)
    X = datadict['data']
    Y = datadict['labels']
    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    Y = np.array(Y)
    return X, Y

def load_CIFAR10(ROOT):
  """ load all of cifar """
  xs = []
  ys = []
  for b in range(1,6):
    f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
    X, Y = load_CIFAR_batch(f)
    xs.append(X)
    ys.append(Y)    
  Xtr = np.concatenate(xs)
  Ytr = np.concatenate(ys)
  del X, Y
  Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
  return Xtr, Ytr, Xte, Yte

The pickle file you are loading was most likely generated with Python 2.

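Because pickle handles strings fundamentally differently in Python 2 and Python 3, you can try loading the file with the latin-1 encoding, which maps the byte values 0-255 directly to characters. In Python 3 that means passing encoding='latin1' to pickle.load.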
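A minimal sketch of that idea, using only the standard library. The byte string below is a hand-written Python 2 style pickle that stands in for a CIFAR batch file, and `load_py2_pickle` is a hypothetical helper name, not part of the assignment code.

```python
import io
import pickle

# Hand-written Python 2 style pickle (protocol 2) holding a single 8-bit str
# made of the bytes 0x8b 0x00. Illustrative data only, not a CIFAR file.
PY2_PICKLE = b"\x80\x02U\x02\x8b\x00q\x00."


def load_py2_pickle(fileobj):
    """Load a pickle written by Python 2, decoding its 8-bit strings as latin-1."""
    return pickle.load(fileobj, encoding="latin1")


# The default decoding for Python 2 strings is ASCII, which rejects any byte
# above 0x7f. This is the failure mode seen in the traceback.
try:
    pickle.load(io.BytesIO(PY2_PICKLE))
    default_failed = False
except UnicodeDecodeError:
    default_failed = True

# latin-1 maps every byte 0-255 to the code point with the same value,
# so the decode step itself cannot fail.
value = load_py2_pickle(io.BytesIO(PY2_PICKLE))
print(default_failed, [hex(ord(c)) for c in value])
```

In data_utils.py, the corresponding change would presumably be on the line shown in the traceback: datadict = pickle.load(f, encoding='latin1').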

This approach calls for some sanity checks afterwards, since it is not guaranteed to produce coherent data.
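One way such a sanity check might look for a single CIFAR-10 batch is sketched below. `check_cifar_batch` is a hypothetical helper, and the expected values (10000 images per batch, 3 * 32 * 32 = 3072 values per image, labels from 0 to 9) are assumptions based on the reshape in load_CIFAR_batch and the CIFAR-10 layout. The stand-in object exists only so the sketch runs without the dataset; in the notebook you would pass the real datadict.

```python
from types import SimpleNamespace


def check_cifar_batch(datadict):
    """Raise ValueError if a loaded batch does not look like CIFAR-10."""
    data, labels = datadict["data"], datadict["labels"]
    # Each batch should hold 10000 images of 3 * 32 * 32 = 3072 values.
    if tuple(data.shape) != (10000, 3072):
        raise ValueError("unexpected data shape: %r" % (data.shape,))
    if len(labels) != 10000:
        raise ValueError("unexpected number of labels: %d" % len(labels))
    # CIFAR-10 has ten classes, labelled 0 through 9.
    if not all(0 <= int(y) <= 9 for y in labels):
        raise ValueError("label outside the range 0-9")
    return True


# Stand-in for a loaded batch: anything with a .shape attribute works here.
fake_batch = {
    "data": SimpleNamespace(shape=(10000, 3072)),
    "labels": [i % 10 for i in range(10000)],
}
print(check_cifar_batch(fake_batch))
```

Displaying a few images with plt.imshow after loading is another quick check that the pixel data survived the decoding.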