文章目录
- 一. MNIST数据集
- 1.1 什么是MNIST数据集
- 1.2MNIST数据集文件格式
- 1.3使用python访问MNIST数据集文件内容
- 附录
- 程序源码
一. MNIST数据集
1.1 什么是MNIST数据集
MNIST数据集是入门机器学习/识别模式的最经典数据集之一。最早于1998年Yan Lecun在论文:[Gradient-based learning applied to document recognition]中提出。该数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图。每张图片中的像素大小在0-255之间,其中0是黑色,255是白色。如下图所示:
MNIST共包含70000张手写数字图片,其中有60000张用作训练集,10000张用作测试集。元数据集可以在MNIST官网下载。下载之后得到4个压缩文件:
train-images-idx3-ubyte.gz #60000张训练集图片
train-labels-idx1-ubyte.gz #60000张训练集图片对应的标签
t10k-images-idx3-ubyte.gz #10000张测试集图片
t10k-labels-idx1-ubyte.gz #10000张测试集图片对应的标签
将上面的4个压缩文件分别解压,得到:
train-images-idx3-ubyte #60000张训练集图片的idx3-ubyte格式文件
train-labels-idx1-ubyte #60000张训练集图片对应的标签的idx3-ubyte格式文件
t10k-images-idx3-ubyte #10000张测试集图片的idx3-ubyte格式文件
t10k-labels-idx1-ubyte #10000张测试集图片对应的标签的idx3-ubyte格式文件
1.2MNIST数据集文件格式
解压得到的4个文件都是二进制格式的文件,为了获取其中的信息,需要先了解MNIST二进制文件的存储格式。格式描述如下:
- 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;
- 第5-8个byte存的是number of images,即图像数量60000;
- 第9-12个byte存的是每张图片行数/高度,即28;
- 第13-16个byte存的是每张图片的列数/宽度,即28。
- 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。
1.3使用python访问MNIST数据集文件内容
知道了MNIST二进制文件的存储方式,下面介绍如何使用python访问文件内容。同样以训练集标签文件train-labels-idx1-ubyte
和训练集图像文件train-images-idx3-ubyte
为例:
import numpy as np
from PIL import ImageMNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte' # 下载的MNIST数据集文件地址
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte' # 下载的MNIST数据集文件地址with open(MNIST_labels_path, 'rb') as f:file_labels = f.read() # 读入标签二进制文件
with open(MNIST_images_path, 'rb') as f:file_images = f.read() # 读入照片二进制文件magic_number_labels = int.from_bytes(file_labels[0:4], 'big') # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_items = int.from_bytes(file_labels[4:8], 'big') # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)magic_number = int.from_bytes(file_images[0:4], 'big') # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_images = int.from_bytes(file_images[4:8], 'big') # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
number_rows = int.from_bytes(file_images[8:12], 'big') # 读取二进制文件的第9-12个byte( 1byte = 8bit ),即number of rows,并转换成10进制
number_columns = int.from_bytes(file_images[12:16], 'big') # 读取二进制文件的第13-16个byte( 1byte = 8bit ),即number of columns,并转换成10进制
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)
使用with open() as 函数读取文件,并使用int.from_bytes()方法将文件的magic number, number of items, number of images, number of rows, number of columns,
等数据读入,将字节数据转换成整数数据,从而查看图像数量、图像高度和图像宽度信息。
运行结果:
通过以下程序,可以将MNIST数据集二进制文件中的照片提取出来并以.png格式保存在文件夹中:
# 将二进制的图像文件中的图像提取出来并保存在文件夹中
for i in range(1, 60001):image = [item for item in file_images[16 + 28 * 28 * (i - 1):16 + 28 * 28 * i]]image_np = np.array(image, dtype=np.uint8).reshape(28, 28)im = Image.fromarray(image_np)im.save("G:\\mnist_dataset\\train-images" + "\\" + str(i) + ".png")
输出的部分照片如下所示:
通过以下程序,将二进制标签文件中的部分标签信息打印出来,可以发现,标签中的数据正对应于图像中的手写数字信息。
# 将二进制的标签文件中的部分标签信息打印出来
for i in range(40, 53):labels = int.from_bytes(file_labels[8 + i - 1:8 + i], 'big')print('labels' + str(i) + '=' + str(labels))
附录
程序源码
import numpy as np
from PIL import ImageMNIST_labels_path = 'G:\\mnist_dataset\\train-labels-idx1-ubyte\\train-labels.idx1-ubyte' # 下载的MNIST数据集文件地址
MNIST_images_path = 'G:\\mnist_dataset\\train-images-idx3-ubyte\\train-images.idx3-ubyte' # 下载的MNIST数据集文件地址with open(MNIST_labels_path, 'rb') as f:file_labels = f.read() # 读入标签二进制文件
with open(MNIST_images_path, 'rb') as f:file_images = f.read() # 读入照片二进制文件magic_number_labels = int.from_bytes(file_labels[0:4], 'big') # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_items = int.from_bytes(file_labels[4:8], 'big') # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
print('labels: magic number = ', magic_number_labels)
print('labels: number of items = ', number_items)magic_number = int.from_bytes(file_images[0:4], 'big') # 读取二进制文件的第1-4个byte( 1byte = 8bit )即magic number,并转换成10进制
number_images = int.from_bytes(file_images[4:8], 'big') # 读取二进制文件的第5-8个byte( 1byte = 8bit ),即number of images,并转换成10进制
number_rows = int.from_bytes(file_images[8:12], 'big') # 读取二进制文件的第9-12个byte( 1byte = 8bit ),即number of rows,并转换成10进制
number_columns = int.from_bytes(file_images[12:16], 'big') # 读取二进制文件的第13-16个byte( 1byte = 8bit ),即number of columns,并转换成10进制
print('images: magic number = ', magic_number)
print('images: number of images = ', number_images)
print('images: number of rows = ', number_rows)
print('images: number of columns = ', number_columns)# 将二进制的图像文件中的图像提取出来并保存在文件夹中
for i in range(1, 60001):image = [item for item in file_images[16 + 28 * 28 * (i - 1):16 + 28 * 28 * i]]image_np = np.array(image, dtype=np.uint8).reshape(28, 28)im = Image.fromarray(image_np)im.save("G:\\mnist_dataset\\train-images" + "\\" + str(i) + ".png")# 将二进制的标签文件中的部分标签信息打印出来
for i in range(40, 53):labels = int.from_bytes(file_labels[8 + i - 1:8 + i], 'big')print('labels' + str(i) + '=' + str(labels))