Pytorch Tutorial【Chapter 3. Simple Neural Network】

Pytorch Tutorial【Chapter 3. Simple Neural Network】

文章目录

  • Pytorch Tutorial【Chapter 3. Simple Neural Network】
    • Chapter 3. Simple Neural Network
      • 3.1 Train Neural Network Procedure训练神经网络流程
      • 3.2 Build Neural Network Procedure 搭建神经网络
      • 3.3 Use Loss Function to Backward 利用损失函数进行反向传播
      • 3.4 Update Parameter of NN 更新神经网络参数
        • 3.4.1 Update Manually手动更新参数
        • 3.4.2 Update Automatically自动更新参数
    • Reference

Chapter 3. Simple Neural Network

3.1 Train Neural Network Procedure训练神经网络流程

一个典型的神经网络训练过程包括以下几点:

  • 定义一个包含可训练参数的神经网络

  • 迭代整个输入

  • 通过神经网络处理输入

  • 计算损失(loss)

  • 反向传播梯度到神经网络的参数

  • 更新网络的参数,典型的用一个简单的更新方法:weight = weight - learning_rate *gradient

3.2 Build Neural Network Procedure 搭建神经网络

  • 定义一个类并继承torch.nn.Moulde
  • 使用类torch.nn中的组件和torch.nn.functional中的组件来搭建网络结构
  • 改写该类的forward方法,在此方法中,进一步完善网络结构(如激活函数,池化层等),并且获得返回值的输出

先简要介绍一下需要使用到的torch.nn组件

  • torch.nn.Conv2d(in_channels, out_channels, kernel_size)进行卷积操作,输入是 ( N , C i n , H , W ) (N, C_{in},H,W) (N,Cin,H,W),输出是 ( N , C o u t , H , W ) (N,C_{out},H,W) (N,Cout,H,W)(详见Conv2d),卷积对图片尺寸的影响如下
Image
  • torch.nn.Linear(in_features, out_features) 进行放射变换操作(affine mapping),即 y = x A T + b y=xA^T+b y=xAT+b,(详见Linear)
  • torch.nn.Flatten(start_dim=1, end_dim=- 1),将连续的范围展平为张量(详见Flatten)

再介绍一下需要使用到的torch.nn.functional组件

  • torch.nn.functional.relu(),对输入使用Relu激活函数,具体操作是对每个元素都进行 ReLU ( x ) = max ⁡ ( 0 , x ) \text{ReLU}(x)=\max(0,x) ReLU(x)=max(0,x)的计算,(详见ReLu)

  • torch.nn.functional.softmax(),对输入使用Softmax激活函数,具体操作是对每个元素都进行 Softmax ( x i ) = exp ⁡ ( x i ) ∑ j exp ⁡ ( x j ) \text{Softmax}(x_i)=\frac{\exp(x_i)}{\sum_{j}\exp(x_j)} Softmax(xi)=jexp(xj)exp(xi)(详见Softmax)

  • torch.nn.functional.max_pool2d(input, kernel_size, stride) , 进行最大池化操作,输入是KaTeX parse error: Expected 'EOF', got '_' at position 28: …atch}, \text{in_̲channels}, iH,i…,详见Max_Pool2D)

例如下述代码,我们要搭建一个下图的神经网络结构

Image

​ 输入是 [ batch , channels , height , weight ] [\text{batch},\text{channels},\text{height},\text{weight}] [batch,channels,height,weight]的图片,分别代表批量大小、通道数、图像的高度、图像的宽度。在我们取批量大小为 1 1 1,然后这个例子变成 [ 1 , 1 , 32 , 32 ] [1,1,32,32] [1,1,32,32],张量的变化过程如下所示
KaTeX parse error: Expected 'EOF', got '_' at position 74: …arrow{\text{max_̲pool}}[1,6,14,1…

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net,self).__init__()# 1 is input_channel 6 is output_channelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# affine function y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):# Max pooling over a (2,2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2,2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), (2,2))# flat the feature as a vectorx = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xdef num_flat_features(self,x):size = x.size()[1:] # all dimensions except the batch dimensionnum_feature = 1for s in size:num_feature *= sreturn num_featurenet = Net()
print(net)

class Net(nn.Module):def __init__(self):super(Net,self).__init__()# 1 is input_channel 6 is output_channelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# affine function y = Wx + bself.flat = nn.Flatten(1,-1)    # flat the featureself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self,x):# Max pooling over a (2,2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2,2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), (2,2))# flat the feature as a vectorx = self.flat(x)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.softmax(self.fc3(x), dim=1)return xnet = Net()
print(net)

结果如下,可以查看网络结构

Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
)

一个模型可训练的参数可以通过调用 net.parameters() 返回,

params = list(net.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

结果如下

10
torch.Size([6, 1, 5, 5])

然后我们可以自定义一些输入,来获得网络的输出

input = torch.randn(3, 1, 32, 32)
out = net(input)
print(out.data)
sum = 0
for i in out.data[0]:sum += i.item()
print(sum)

输出如下

tensor([[0.0996, 0.1000, 0.0907, 0.0946, 0.0930, 0.1067, 0.1044, 0.1119, 0.1073,0.0918],[0.0975, 0.1005, 0.0913, 0.0935, 0.0939, 0.1070, 0.1045, 0.1121, 0.1065,0.0933],[0.0968, 0.1009, 0.0896, 0.0978, 0.0903, 0.1107, 0.1055, 0.1130, 0.1043,0.0913]])
0.9999999925494194

torch.nn.Module.zero_grad()把所有参数梯度缓存器置零,

net.zero_grad() #把所有参数梯度缓存器置零,用随机的梯度来反向传播
out.backward(torch.randn_like(out))
net.conv1.bias.grad  #查看Conv卷积层的偏置和权重一些参数的梯度
net.conv1.weight.grad

结果如下

tensor([-0.0064,  0.0136, -0.0046,  0.0008, -0.0044,  0.0021])
...

3.3 Use Loss Function to Backward 利用损失函数进行反向传播

现在我们已经完成了

  • 定义一个神经网络

  • 处理输入以及调用反向传播

还剩下:

  • 计算损失值

  • 更新网络中的权重

损失函数

一个损失函数需要一对输入:模型输出和目标,然后计算一个值来评估输出距离目标有多远。

有一些不同的损失函数在 nn 包(详细请查看loss-functions)。一个简单的损失函数就是 nn.MSELoss ,这计算了均方误差

input = torch.randn(3, 1, 32, 32)
target = torch.randn(3, 10)
predict = net(input)
criterion = nn.MSELoss()
loss = criterion(predict, target)print(loss)
print(loss.grad_fn.next_functions[0][0])

我们计算了MSE损失函数,并且能够跟踪其计算图

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear -> sfotmax-> MSELoss-> loss

当我们使用loss.backward()时,整个计算图都会进行微分,为了实现反向传播损失,我们所有需要做的事情仅仅是使用 loss.backward()你需要清空现存的梯度,要不然现在计算的梯度都将会和历史保存的梯度累计到一起

net.zero_grad()     # zeroes the gradient buffers of all parametersprint('conv1.bias.grad before backward')
print(net.conv1.bias.grad)loss.backward()print('conv1.bias.grad after backward')
print(net.conv1.bias.grad)

结果如下

conv1.bias.grad before backward
None
conv1.bias.grad after backward
tensor([-3.4418e-04, -1.0766e-04,  1.0913e-04, -5.5018e-05,  1.7342e-04,-5.3316e-04])

3.4 Update Parameter of NN 更新神经网络参数

3.4.1 Update Manually手动更新参数

我们可以手动实现随机梯度下降(stochastic gradient descent)来更新参数(详细请参考Chapter 6. Stochastic Approximation)

w k + 1 = w k − a k ∇ w k f ( w k , x k ) \textcolor{red}{w_{k+1} = w_k - a_k \nabla_{w_k} f(w_k,x_k)} wk+1=wkakwkf(wk,xk)

learning_rate = 0.01
for f in net.parameters():f.data.sub_(f.grad.data * learning_rate)    #sub_() is in-place minus

3.4.2 Update Automatically自动更新参数

尽管如此,如果你是用神经网络,你想使用不同的更新规则,类似于 SGD, Nesterov-SGD, Adam, RMSProp, 等。为了让这可行,我们建立了一个小包:torch.optim 实现了所有的方法。我们可以使用optim.step()来替代上述手动实现的代码,使用它非常的简单。

import torch.optim as optim# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)# in your training loop:
optimizer.zero_grad()   # zero the gradient buffers
output = net(input)
loss = criterion(output, target)	
loss.backward()		# 计算梯度
optimizer.step()    # Does the update

Reference

参考教程1
参考教程2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/76772.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

海外应用商店优化实用指南之关键词

和SEO一样,关键词是ASO中的一个重要因素。就像应用程序标题一样,在Apple App Store和Google Play中处理应用程序关键字的方式也有所不同。 关键词研究。 对于Apple,我们的所有关键词只能获得100个字符,Google Play没有特定的关键…

【新版系统架构补充】-传输介质、子网划分

传输介质 双绞线:无屏蔽双绞线UTP和屏蔽双绞线STP,传输距离在100m内 网线安装标准: 光纤:由纤芯和包层组成,分多模光纤MMF、单模光纤SMF 无线信道:分为无线电波和红外光波 通信方式和交换方式 单工…

做测试8年,33岁前只想追求大厂高薪,今年只求稳定收入

疫情3年,每一个行业的危机,每一个企业的倒下,背后都是无数人的降薪、降职和失业。这也暴露了人生的残酷真相:人活一辈子,总有“丰年”和“荒年” 优秀的测试既过得了丰年,也受得住荒年 一个测试宝妈&…

数据结构: 线性表(带头双向循环链表实现)

之前一章学习了单链表的相关操作, 但是单链表的限制却很多, 比如不能倒序扫描链表, 解决方法是在数据结构上附加一个域, 使它包含指向前一个单元的指针即可. 那么怎么定义数据结构呢? 首先我们先了解以下链表的分类 1. 链表的分类 链表的结构非常多样, 以下情况组合起来就有…

爬虫---练习源码

选取的是网上对一些球员的评价,来评选谁更加伟大一点 import csv import requests import re import timedef main(page):url fhttps://tieba.baidu.com/p/7882177660?pn{page}headers {User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/53…

AMEYA360:瑞萨电子MCU和MPU产品线将支持Microsoft Visual Studio Code

全球半导体解决方案供应商瑞萨电子宣布其客户现可以使用Microsoft Visual Studio Code(VS Code)开发瑞萨全系列微控制器(MCU)和微处理器(MPU)。瑞萨已为其所有嵌入式处理器开发了工具扩展,并将其…

分布式开源监控Zabbix实战

Zabbix作为一个分布式开源监控软件,在传统的监控领域有着先天的优势,具备灵活的数据采集、自定义的告警策略、丰富的图表展示以及高可用性和扩展性。本文简要介绍Zabbix的特性、整体架构和工作流程,以及安装部署的过程,并结合实战…

分布式异步任务处理组件(七)

分布式异步任务处理组件底层网络通信模型的设计--如图: 使用Java原生NIO来实现TCP通信模型普通节点维护一个网络IO线程,负责和主节点的网络数据通信连接--这里的网络数据是指组件通信协议之下的直接面对字节流的数据读写,上层会有另一个线程负…

Linux下安装VMware虚拟机

目录 1. 简介 2. 工具/原料 2.1. 下载VMware 2.2. 安装 1. 简介 ​ VMware Workstation(中文名“威睿工作站”)是一款功能强大的桌面虚拟计算机软件,提供用户可在单一的桌面上同时运行不同的操作系统,和进行开发、测试 …

为什么list.sort()比Stream().sorted()更快?

真的更好吗&#xff1f; 先简单写个demo List<Integer> userList new ArrayList<>();Random rand new Random();for (int i 0; i < 10000 ; i) {userList.add(rand.nextInt(1000));}List<Integer> userList2 new ArrayList<>();userList2.add…

从零开始:手把手搭建 RocketMQ 单节点、集群节点实例

&#x1f52d; 嗨&#xff0c;您好 &#x1f44b; 我是 vnjohn&#xff0c;在互联网企业担任 Java 开发&#xff0c;CSDN 优质创作者 &#x1f4d6; 推荐专栏&#xff1a;Spring、MySQL、Nacos、Java&#xff0c;后续其他专栏会持续优化更新迭代 &#x1f332;文章所在专栏&…

240. 搜索二维矩阵 II

240. 搜索二维矩阵 II 原题链接&#xff1a;完成情况&#xff1a;解题思路&#xff1a;参考代码&#xff1a; 原题链接&#xff1a; 240. 搜索二维矩阵 II https://leetcode.cn/problems/search-a-2d-matrix-ii/description/ 完成情况&#xff1a; 解题思路&#xff1a; 从…

配置root账户ssh免密登录并使用docker-machine构建docker服务

简介 Docker Machine是一种可以在多种平台上快速安装和维护docker运行环境&#xff0c;并支持多种平台&#xff0c;让用户可以在很短时间内在本地或云环境中搭建一套docker主机集群的工具。 使用docker-machine命令&#xff0c;可以启动、审查、停止、重启托管的docker 也可以…

《向量数据库指南》——腾讯云向量数据库Tencent Cloud Vector DB正式上线公测!提供10亿级向量检索能力

8月1日,腾讯云向量数据库(Tencent Cloud Vector DB)已正式上线公测。在腾讯云官网上搜索“向量数据库”,就可以正式体验该产品。 腾讯云向量数据库不仅能为大模型提供外部知识库,提高大模型回答的准确性,还可广泛应用于推荐系统、文本图像检索、自然语言处理等 AI 领域。…

ChatGPT已打破图灵测试,新的测试方法在路上

生信麻瓜的 ChatGPT 4.0 初体验 偷个懒&#xff0c;用ChatGPT 帮我写段生物信息代码 代码看不懂&#xff1f;ChatGPT 帮你解释&#xff0c;详细到爆&#xff01; 如果 ChatGPT 给出的的代码不太完善&#xff0c;如何请他一步步改好&#xff1f; 全球最佳的人工智能系统可以通过…

K8S kubeadm搭建

kubeadm搭建整体步骤 1&#xff09;所有节点进行初始化&#xff0c;安装docker引擎和kubeadm kubelet kubectl 2&#xff09;生成集群初始化配置文件并进行修改 3&#xff09;使用kubeadm init根据初始化配置文件生成K8S的master控制管理节点 4&#xff09;安装CNI网络插件&am…

【Ubuntu 18.04 搭建 DHCP 服务】

参考Ubuntu官方文档&#xff1a;https://ubuntu.com/server/docs/how-to-install-and-configure-isc-dhcp-server dhcpd.conf 手册页 配置&#xff1a;https://maas.io/docs/about-dhcp 实验环境规划 Ubuntu 18.04&#xff08;172.16.65.128/24&#xff09;dhcp服务端Ubuntu…

GD32F103VE点灯

GD32F103VE点灯主要用来学习端口引脚的输出配置。它由LED.c&#xff0c;LED.h&#xff0c;SoftDelay.c和main.c组成。 #include "gd32f10x.h" //使能uint8_t,uint16_t,uint32_t,uint64_t,int8_t,int16_t,int32_t,int64_t #include "SoftDelay.h"#include …

后端整理(集合框架、IO流、多线程)

1. 集合框架 Java集合类主要有两个根接口Collection和Map派生出来 Collection派生两个子接口 List List代表了有序可重复集合&#xff0c;可以直接根据元素的索引进行访问Set Set代表无序不可重复集合&#xff0c;只能根据元素本身进行访问 Map接口派生 Map代表的是存储key…

数据结构 二叉树(C语言实现)

绪论 雄关漫道真如铁&#xff0c;而今迈步从头越。 本章将开始学习二叉树&#xff08;全文共一万两千字&#xff09;&#xff0c;二叉树相较于前面的数据结构来说难度会有许多的攀升&#xff0c;但只要跟着本篇博客深入的学习也可以基本的掌握基础二叉树。 话不多说安全带系好&…