首页 > 其他 > 详细

自己动手读取MNIST数据集并存入四个np.array

时间:2020-01-10 21:21:37      阅读:71      评论:0      收藏:0      [点我收藏+]

从下载http://yann.lecun.com/exdb/mnist/四个.gz压缩包:
技术分享图片
他们分别是训练用数据、训练用标签、测试用数据、测试用标签。

然后将他们放入一个名为dataPath的文件夹中,我放入的是/home/zzz/intern/data:
技术分享图片

然后是读取数据的代码,readData()函数返回的就是四个np.array

import gzip
import numpy as np
def read_idx3(filename):
    with gzip.open(filename, 'rb') as fo:
        buf = fo.read()
        index = 0
        header = np.frombuffer(buf, '>i', 4, index)
        index += header.size * header.itemsize
        data = np.frombuffer(buf, '>B', header[1]*header[2]*header[3], index).reshape(header[1],-1)
        return data

def read_idx1(filename):
    with gzip.open(filename, 'rb') as fo:
        buf = fo.read()
        index = 0
        header = np.frombuffer(buf, '>i', 2, index)
        index += header.size * header.itemsize
        data = np.frombuffer(buf, '>B', header[1], index)
        return data

def readData(dataPath):
    X_train = read_idx3(dataPath + '/train-images-idx3-ubyte.gz')  # 训练数据集的样本特征
    y_train = read_idx1(dataPath + '/train-labels-idx1-ubyte.gz')  # 训练数据集的标签
    X_test = read_idx3(dataPath + '/t10k-images-idx3-ubyte.gz')  # 测试数据集的样本特征
    y_test = read_idx1(dataPath + '/t10k-labels-idx1-ubyte.gz')  # 测试数据集的标签
    return X_train, y_train, X_test, y_test

可以输出一下他们的维度:

if __name__=="__main__":
    dataPath = "/home/zzz/intern/data"
    X_train, y_train, X_test, y_test = readData(dataPath)
    print(X_train.shape, y_train.shape)
    print(X_test.shape, y_test.shape)

如果结果如下图所示即为正确:
技术分享图片

自己动手读取MNIST数据集并存入四个np.array

原文:https://www.cnblogs.com/huangming-zzz/p/12177905.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!