因为图片识别很多代码都包装在d2l库里了,直接调用就行了
完整代码:
import torch
from torch import nn
from d2l import torch as d2l"获取训练集&获取检测集"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) # nn.Flatten()将28*28展平成784"初始化w,b后者不操作默认初始化"
def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std = 0.01)net.apply(init_weights) # 给到所有模型loss = nn.CrossEntropyLoss()trainer = torch.optim.SGD(net.parameters(), lr=0.1) # net.parameters()将net中数据整合w,b给SGDif __name__ == '__main__':num_epochs = 10cnt = 1for i in range(num_epochs):X, Y = d2l.train_epoch_ch3(net, train_iter, loss, trainer)print("训练次数: " + str(cnt))cnt += 1print("训练损失: {:.4f}".format(X))print("训练精度: {:.4f}".format(Y))print(".................................")
画图功能不兼容pycharm,所以还是朴素的用输出函数吧