MNIST手写数字辨识-cnn网路 (机器学习中的hello world,加油)

用PyTorch实现MNIST手写数字识别(非常详细) - 知乎 (zhihu.com)

参考来源(这篇文章非常适合入门来看,每个细节都讲解得很到位)

一、模块函数用法-查漏补缺:

1.关于torch.nn.functional.max_pool2d()的用法:

上述示例中,输入张量 input 经过最大池化操作后,使用了 kernel_size=2stride=2,所以输出张量 output 的高度和宽度均为输入的一半(32/2=16)。

2.pytorch中的view函数的用法:

http://t.csdn.cn/AAhdH

这一篇文章写得非常好

3.关于f.log_softmax(x,dim = -1)这个先进行softmax,再取log的函数的讲解:

http://t.csdn.cn/GIJ7g

这篇文章讲解得非常好,补充一点,dim的default值和softmax一样,都是-1,也就是计算最里面那个维度的softmax的结果

4.原来loss和counter计数器数组有这个作用:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]

5.关于F.nll_loss这个损失函数:

http://t.csdn.cn/ZoruZ

总的来说就是一句话“损失函数 nn.CrossEntropyLoss() 与 NLLLoss() 相同, 唯一的不同是它为我们去做 log_softmax.”

这篇文章讲述得非常清楚

6.关于loss.item()的作用:

http://t.csdn.cn/AvrnJ

这篇文章讲得非常清楚:

就是输出loss这个数值,但是呢,是用非常高的精度进行输出的,一般我们进行一各batch的训练后,就会得到这一次的loss单个数值,需要输出的话,最好就用item()

7.with torch.no_grad()的用法:

http://t.csdn.cn/STaKp

这篇文章讲述得非常清楚,就是不会进行gradient_descend操作,极大的节省了运算开销

8.data.max()函数的用法:

http://t.csdn.cn/aBmin

上面那里讲得不太好,还是chatGPT比较优秀

9.data.view_as()的用法:

10.torch.eq的用法:

http://t.csdn.cn/Tb0kY

这篇文章讲述得非常清楚,也就是对张量中的数值逐个进行比较,

返回的是同样形状的数据,每个位置要么True要么False,可以用.sum()求和得到True的总数

顺便提一下torch.sum的用法,

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(x.sum())

输出的结果是21

二、各个部分的代码和注释:

#设置环境
import torch
import torchvision
from torch.utils.data import DataLoader
#准备数据集
#1.设置必要的参数
n_epochs = 3
batch_size_train = 64 #所以呢,这个64其实就是下面train时候的batch_size大小
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10 #这个就是后面用来输出的间隔
random_seed = 1
torch.manual_seed(random_seed)

 

#利用pytorch直接加载对应的train_data集 和 test_Data集
train_loader = torch.utils.data.DataLoader( #这里调用的是torch.utils.data.DataLoader的对象,实例化出train_dataloader#限免设置各个参数,比如,第一个就是Dataset参数,这里是引用MNIST作为参数,并且设置MNIST中的各个参数torchvision.datasets.MNIST('./data/', train=True, download=True, #设为train数据+下载transform=torchvision.transforms.Compose([ #对数据进行transform变换torchvision.transforms.ToTensor(), #先变tensor后进行Normlizetorchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_train, shuffle=True) #这个loader的后两个参数batch_size和shuffle
#同样的道理设置test_data_loader
test_loader = torch.utils.data.DataLoader(torchvision.datasets.MNIST('./data/', train=False, download=True,transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size_test, shuffle=True)
#查看一条数据:
examples = enumerate(test_loader) #enumerate返回一个(index,data)的元组,本身是一个迭代器,可以用于遍历test_loader
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets) #输出测试(这里的test是有answer作为label的)的1000各answer
print(len(example_targets)) #总共1000各target
print(example_data.shape) #一共有1000张28*28的黑白灰度图
#利用matplotlib进行绘制得到某些数据的可视化结果
import matplotlib.pyplot as plt
fig = plt.figure() #创建一个fig对象
for i in range(6):plt.subplot(2,3,i+1) #按照2行3列绘制6张图片plt.tight_layout() #设置紧密相连plt.imshow(example_data[i][0], cmap='gray', interpolation='none') #利用imshow在下方直接输出图像plt.title("Ground Truth: {}".format(example_targets[i]))#设置标题,就是label的数值plt.xticks([])plt.yticks([])
plt.show()
#定义neural network的结构
import torch.nn as nn #引入neural network的库
import torch.nn.functional as F #引入nn总的常用Func
import torch.optim as optim #引入torch中的optimizerclass Net(nn.Module): #继承nn中的moduledef __init__(self):  #定义这个网络结构的构造函数super(Net, self).__init__() #继承nn.Module的初始化构造self.conv1 = nn.Conv2d(1, 10, kernel_size=5)#参数:输入channel、输出channel、卷积核5*5(filters),strdie(default =1),padding(default=0)#所以1*28*28的图像通过后,10*21*21(10是filters的数量)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d() #这个函数http://t.csdn.cn/xK6og这篇文章讲得挺好的,就是让部分filters在某一层不工作,效果是有效防止overfitself.fc1 = nn.Linear(320, 50) #定义一个320 -->50 的Linear层函数self.fc2 = nn.Linear(50, 10)  #定义一个50 -->10  的Linear层函数def forward(self, x): #下面就是直接进行整个network的作用过程定义了 , 输入1*28*28的灰度图x = F.relu(F.max_pool2d(self.conv1(x), 2)) #经过一个conv1卷积层后,经过1次2*2窗口的pooling得到,默认??padding=1,之后再算好了x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) #再通过conv2之后->通过conv2_drop->通过max_pool2dx = x.view(-1, 320) #第二维是320,并自动计算第一维x = F.relu(self.fc1(x)) #通过一个linear层之后,又通过一个relu的激活函数,最后输出的是第二维是50的结果x = F.dropout(x, training=self.training) #只有在training模式下才会调用dropout(让某些神经元“熄火”喵)x = self.fc2(x) #再让x通过一个linear层,输出的结果是2维的数据,第二维(共10列)return F.log_softmax(x) #最后通过对最里面那一层softmax层后,取log对数
#创建model对象+设置optimizer优化器
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,momentum=momentum) #lr和momentum都是上面设置好的

#设置用于存储的数组结构:
train_losses = []
train_counter = [] #估计就是一个计数器的作用
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
#print(test_counter)输出[0,60000,120000,180000]不知道再干啥,反正上面的n——epoch==3
#定义这个train函数:
def train(epoch): #这里的epoch是传递进来的参数network.train() #开启train模式for batch_idx, (data, target) in enumerate(train_loader):#迭代器:以batch为单位逐个从train_loader中获取 索引、data图像数据、label作为target数据optimizer.zero_grad() #因为torch中的grad是累加的,所以需要在每个batch训练之前利用optimizer.zero_grad()清零output = network(data) #将data图像数据通过network网络得到output输出结果loss = F.nll_loss(output, target) #这个loss_func只是比cross_entropy少一个对输入数据的log_softmax操作loss.backward()optimizer.step() #loss.backward + optimizer.step()常规更新模型参数的操作if batch_idx % log_interval == 0: #下面都是没啥用的间隔输出操作,上面设置的log_interval =10#每经过10各batch处理输出一次:#第几个epoch,第几个图像,总共的train有多少图像,已经完成了百分之几的batch,这个batch的loss值print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item())) #将这个batch的loss值添加到train_losses数组中(注意,这里好像是每隔10个batch记录一次loss)train_losses.append(loss.item())train_counter.append(  #在counter中记录这个batch在考虑epoch情况下的位置(batch_idx*64) + ((epoch-1)*len(train_loader.dataset))) #这个64是train时候的batch_size上面写了#将当前的network的参数状态state_dice存储到对应的路径下, 同时optimizer的状态也要存储?why感觉optim没啥用torch.save(network.state_dict(), './model.pth')torch.save(optimizer.state_dict(), './optimizer.pth')#train(1) #传递参数epoch=1进行train一次
#这里的train有个地方很有意思,它只是输出loss,没有利用argmax计算出对应的one-hot vec,从而没法和label进行比较得到acc
#定义test函数,并且进行test测试 (不用想,大概率和train的内容没有太大的区别,不过是少了backward和step的更新)
def test():network.eval() #开启model的eval模式test_loss = 0 #设置loss和acc初值correct = 0with torch.no_grad(): #不计算SGDfor data, target in test_loader: #非enumerate,非迭代器版本,不会返回索引,获取data图像batch和target的labels数值output = network(data) #调用network获取output结果test_loss += F.nll_loss(output, target, size_average=False).item() #这里计算出这一次的 output和target之间的losspred = output.data.max(1, keepdim=True)[1] #通过data.max函数获取对应的索引,这是一个索引的数组,因为是一个batch一起预测的correct += pred.eq(target.data.view_as(pred)).sum() #如果pred和target数组对应位置比较,计算总共相等的位置的数量test_loss /= len(test_loader.dataset) #计算平均的losstest_losses.append(test_loss) #将这一次的平均loss加入到test_losses数组中#输出:#这一次的平均loss,总数中正确预测的数目,正确率print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))#test()#调用上述定义的test函数
#再调用一次test()
test()
for epoch in range(1, n_epochs + 1): #调用n_epochs个数的train和测试结果train(epoch)test()
#下面对上述获取到的数据进行图像的绘制
#绘制图像一开始出错了,我怀疑是我多进行了一次test(),导致x和y的大小不对应
import matplotlib.pyplot as plt
fig = plt.figure()       #创建figure对象 
plt.plot(train_counter, train_losses, color='blue') #绘制曲线图,x是train计数,y是trainloss
#plt.scatter(test_counter, test_losses, color='red') #绘制散点图,x是test_counter计数,y是test_losses数据
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen') #x轴标题
plt.ylabel('negative log likelihood loss')     #y轴标题
plt.show() #绘制结果
#抽取几个直观的例子进行测试:examples = enumerate(test_loader) #获取test_loader的迭代器
batch_idx, (example_data, example_targets) = next(examples) #获取第一个test_loader中的batch
with torch.no_grad():output = network(example_data) #将example_data数据通过network得到output
fig = plt.figure() #创建figure对象
for i in range(6): #构建2行3列的图像排列plt.subplot(2,3,i+1)plt.tight_layout() #紧密排列plt.imshow(example_data[i][0], cmap='gray', interpolation='none') #利用imshow输出example图像plt.title("Prediction: {}".format(output.data.max(1, keepdim=True)[1][i].item())) #输出预测结果,结果非常美妙plt.xticks([])plt.yticks([])
plt.show() #绘制-这个似乎可以不用
#为了能够持续训练,这里考虑 获取 上一次的 model_dict 和 optim_dict
continued_network = Net()
continued_optimizer = optim.SGD(network.parameters(), lr=learning_rate,momentum=momentum)network_state_dict = torch.load('model.pth')
continued_network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load('optimizer.pth')
continued_optimizer.load_state_dict(optimizer_state_dict)#再接着上面练上6次
for i in range(4, 9):test_counter.append(i*len(train_loader.dataset))train(i)test()
#同样进行图像的绘制
fig = plt.figure()
plt.plot(train_counter, train_losses, color='blue')
plt.scatter(test_counter, test_losses, color='red') #因为之前多test了一次,所以这里应该还是会出错
plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
plt.xlabel('number of training examples seen')
plt.ylabel('negative log likelihood loss')
plt.show()

第一个再vscode完成的神经网络训练, 撒花庆祝!!🎉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/125783.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

工作和生活中,如何用项目管理思维解决复杂的事情?

在工作和生活中,许多事情都可以采用项目思维方式来解决。当我们逐渐将工作和生活中的各种事务以项目的方式来处理和推进时,我们可能并没有意识到,实际上我们正在运用项目管理思维。 项目管理思维能帮助我们在面对繁杂事务时,理清…

OpenCV 06(图像的基本变换)

一、图像的基本变换 1.1 图像的放大与缩小 - resize(src, dsize, dst, fx, fy, interpolation) - src: 要缩放的图片 - dsize: 缩放之后的图片大小, 元组和列表表示均可. - dst: 可选参数, 缩放之后的输出图片 - fx, fy: x轴和y轴的缩放比, 即宽度和高度的缩放比. - …

flask bootstrap页面json格式化

html <!DOCTYPE html> <html lang"en"> <head><!-- 新 Bootstrap5 核心 CSS 文件 --> <link rel"stylesheet" href"static/bootstrap-5.0.0-beta1-dist/css/bootstrap.min.css"><!-- 最新的 Bootstrap5 核心 …

【设计模式】组合模式实现部门树实践

1.前言 几乎在每一个系统的开发过程中&#xff0c;都会遇到一些树状结构的开发需求&#xff0c;例如&#xff1a;组织机构树&#xff0c;部门树&#xff0c;菜单树等。只要是需要开发这种树状结构的需求&#xff0c;我们都可以使用组合模式来完成。 本篇将结合组合模式与Mysq…

PaddleOCR学习笔记2-初步识别服务

今天初步实现了网页&#xff0c;上传图片&#xff0c;识别显示结果到页面的服务。后续再完善。 采用flask paddleocr bootstrap快速搭建OCR识别服务。 代码结构如下&#xff1a; 模板页面代码文件如下&#xff1a; upload.html : <!DOCTYPE html> <html> <…

合宙Air724UG LuatOS-Air LVGL API控件-窗口 (Window)

窗口 (Window) 分 享导出pdf 示例代码 win lvgl.win_create(lvgl.scr_act(), nil) lvgl.win_set_title(win, "Window title") -- close_btn lvgl.win_add_btn_right(win, "\xef\x80\x8d") -- --lvgl.obj_set_event_cb(cl…

【运维 Pro】时序场景实践与原理 - 1. 分布与分区

【运维 Pro】: 是由 YMatrix 售前和售后团队负责的栏目。除了介绍日常的数据库运维和使用知识&#xff0c;我们更希望能够通过介绍这些知识背后的原理&#xff0c;让大家和我们一起感知数据库的美妙。 摘要 有别于其它场景&#xff0c;时序场景中的数据、查询都有着更为明显的…

echarts饼图点击区块事件

效果图&#xff1a; 代码&#xff1a; let option {color: pieColors,series: [{name: Access From,type: pie,radius: [36%, 56%],avoidLabelOverlap: false,label: {formatter: params > {// console.log(params)return {color${params.dataIndex}|${params.name}(${par…

EXCEL 中find,if and,if or

接上一篇sql中find函数的作用&#xff0c;由于工作需求是用帆软做报表&#xff0c;他的一些代码不仅有js&#xff0c;sql中的还有一些excel的相关知识&#xff0c;故作整理。 FIND() excel中的find原理和sql中相似&#xff0c;具体可查看 SQL函数 $FIND_Yangshiwei....的博客…

2023开学礼《乡村振兴战略下传统村落文化旅游设计》许少辉八一新书对外经济贸易大学图书馆

2023开学礼《乡村振兴战略下传统村落文化旅游设计》许少辉八一新书对外经济贸易大学图书馆

Splunk Enterprise for Mac:卓越的数据分析与管理工具

在当今的数字化时代&#xff0c;数据已经成为企业成功的核心驱动力。然而&#xff0c;如何有效地管理和分析这些数据&#xff0c;却常常让企业感到困惑。Splunk Enterprise for Mac 是一款领先的数据分析和管理工具&#xff0c;可以帮助你解决这一难题。 Splunk Enterprise fo…

火山引擎边缘云助力智能科技赋予生活更多新意

当下&#xff0c;先进的科学技术使得我们的日常生活变得快捷、舒适。大到上百层智能大厦、高端公共场所、社会智能基础设施&#xff0c;小到智能家居监控、指纹密码锁等&#xff0c;在这个充满想象力的时代&#xff0c;科技以更加智能化的方式改变和守护我们的生活。 引入智能…

[小尾巴 UI 组件库] 全屏响应式轮播背景图(基于 Vue 3 与 Element Plus)

文章归档于&#xff1a;https://www.yuque.com/u27599042/row3c6 组件库地址 npm&#xff1a;https://www.npmjs.com/package/xwb-ui?activeTabreadme小尾巴 UI 组件库源码 gitee&#xff1a;https://gitee.com/tongchaowei/xwb-ui小尾巴 UI 组件库测试代码 gitee&#xff1a…

【已更新代码图表】2023数学建模国赛E题python代码--黄河水沙监测数据分析

E 题 黄河水沙监测数据分析 黄河是中华民族的母亲河。研究黄河水沙通量的变化规律对沿黄流域的环境治理、气候变 化和人民生活的影响&#xff0c;以及对优化黄河流域水资源分配、协调人地关系、调水调沙、防洪减灾 等方面都具有重要的理论指导意义。 附件 1 给出了位于小浪底水…

docker安装opengauss数据库

opengauss官网&#xff1a;https://opengauss.org/ opengauss镜像&#xff1a;https://hub.docker.com/r/enmotech/opengauss 一&#xff1a;镜像拉取并运行 docker run --name opengauss --privilegedtrue -d -e GS_USERNAMEgaussdb -e GS_PASSWORDopenGauss123 -p 5432:54…

thinkPhp5返回某些指定字段

//去除掉密码$db new UserModel();$result $db->field(password,true)->where("username{$params[username]} AND password{$params[password]}")->find(); 或者指定要的字段的数组 $db new UserModel();$result $db->field([username,create_time…

【进阶篇】Redis内存淘汰详解

文章目录 Redis内存淘汰详解0. 前言大纲Redis内存淘汰策略 1. 什么是Redis内存淘汰策略&#xff1f;1.1.Redis 内存不足的迹象 2. Redis内存淘汰策略3. 原理4. 主动和被动1. 主动淘汰1.1 键的生存周期1.2 过期键删除策略 2. 被动淘汰2.2 被动淘汰策略的实现 5. 项目实践优化策略…

阿里云ubuntu服务器搭建ftp服务器

阿里云ubuntu服务器搭建ftp服务器 服务器环境安装步骤一.创建用户二.安装 vsftp三 配置vsftp四.配置阿里云安全组 服务器环境 阿里云上的云服务器&#xff0c;操作系统为 ubuntu20.04。 安装步骤 一.创建用户 为什么需要创建用户&#xff1f; 这里的用户&#xff0c;指的是…

数据结构与算法-树论基础二叉树

大家来看以下几个结构&#xff1a;下图中的结构除了一颗不是树其余的都是&#xff0c;我们可以发现这个跟我们现实生活的树是不是非常相似. 在树形结构里面有几个重要的术语&#xff1a; 1.结点&#xff1a;树里面的元素。 2.父子关系&#xff1a;结点之间相连的边 3.子树&…

云原生Kubernetes:Kubeadm部署K8S单Master架构

目录 一、理论 1.kubeadm 2.Kubeadm部署K8S单Master架构 3.环境部署 4.所有节点安装docker 5.所有节点安装kubeadm&#xff0c;kubelet和kubectl 6.部署K8S集群 7.安装dashboard 8.安装Harbor私有仓库 9.内核参数优化方案 二、实验 1.Kubeadm部署K8S单Master架构 …