文章目录
- 1 代码分析
- 1.1 加载数据集
- 1.2 定义模型
- 1.3 定义损失函数和优化器
- 1.4 定义训练函数
- 1.4.1 定义累加器Accumulator
- 1.4.2 计算准确率accuracy
- 1.4.3 评估函数evaluate_accuracy
- 1.4.4 单轮训练函数train_epoch
- 1.4.5 训练函数train
- 1.2 执行训练
- 2 整体代码
- 3 参考附录
1 代码分析
1.1 加载数据集
Fashion-MNIST 的目的是要成为 MNIST 数据集的一个直接替代品。作为算法作者,你不需要修改任何的代码,就可以直接使用这个数据集。Fashion-MNIST 的图片大小,训练、测试样本数及类别数与经典 MNIST 完全相同。
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from torch import nndef load_data_fashion_mnist(batch_size, loader_num,