[Machine Learning]pytorch手搓一个神经网络模型

因为之前虽然写过一点点关于pytorch的东西,但是用的还是他太少了。

这次从头开始,尝试着搓出一个神经网络模型

(因为没有什么训练数据,所以最后的训练部分使用可能不太好跑起来的代码作为演示,如果有需要自己连上数据集合进行修改捏)

1.先阐述一下什么是神经网络块(block)

一般来说,我们之前遇到的一些神经网络,网络中是这样子的结构

net----> layer ----> neuron

而块的存在,就是给这样一个神经网络的整体做了一个封装操作,让神经网络能复合实现一些功能。

结构就变成了这个样子(图片来自D2l)

这样子,神经网络结构就变成了四层

  block ----》 net ---》layer ---》neuron

这样子自然是可以使用诸如一些奇怪的方法,通过三层索引去进行调用什么的,不过这个我们到后面再说。先看一下如何构建一个块。

我们这里构建了一个类,这个类的计算方法实际就是实现了几个层的输入和输出,相当与封装了一个神经网络。

class MLP(nn.Module):# 用模型参数声明层。这里,我们声明两个全连接的层def __init__(self):# 调用MLP的父类Module的构造函数来执行必要的初始化。# 这样,在类实例化时也可以指定其他函数参数,例如模型参数params(稍后将介绍)super().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.out = nn.Linear(256, 10)  # 输出层# 定义模型的前向传播,即如何根据输入X返回所需的模型输出def forward(self, X):# 注意,这里我们使用ReLU的函数版本,其在nn.functional模块中定义。return self.out(F.relu(self.hidden(X)))# 这个东西就相当与先隐藏层,然后relu,然后最后进行一次输出#创建这个神经网路块,然后开始输出
net = MLP()
net(X)

这段代码没有使用squential容器进行封装,但是可以很清楚地看到我们定义了两层(隐藏层256个神经元,输出层10个神经元,不知道为什么没用softmax函数),并且在返回函数计算的时候,中间还经过了一步‘relu’激活函数的操作

(注意和tf不同,pytorch框架下面是不能把激活函数存入层中的,需要单独作为一个‘层’来进行一个输入和输出的控制)

注意下(在后面自定义层的时候也是这样子)由于继承了nn.Module这个类 , 所以我们必须要实现两个函数,首先是_init_,这个在python中是最终要的构造函数。其次就是forward,我对py不是很了解,不过这应该是通过面向对象实现的集成。forward这个方法就是向前传播,也就是接受参数,内部计算,然后返回值传递下去。

我们直接给net对象传递我们随机生成的两条数据的时候,底层时就调用了这个函数。

其他的一些比如sequential的实现方法,在这里我们就不加以赘述了。

为了更好的解释forward这个函数的作用,在这里我们自己创建一个单层,通过类创建,仍然是获取一个集成nn.Module的类,然后内部设置好初始化(为了创建对象),设置好向前传播(为了用来调用)

# 自定义一个不需要参数的层
class CenteredLayer(nn.Module):def __init__(self):super().__init__()def forward(self, X):    #该层向前传播的方法return X - X.mean()
# 这一层最终也是返回一个张量
# sequential是一个简单的线性封装容器,所以只要是符合输入张量,输出张量
# 并且在内部会调用他们的forward方法layer = CenteredLayer()
print('自定义层,每个元素都 - 平均值2',layer(torch.FloatTensor([1, 2, 3, 4, 5])))

这个单层的效果就是对每个元素,都减去平均值。

并且如果想的话,我们也可以创建一些拥有自己属性的层


#现在创建一个带有权重和偏好
class MyLinear(nn.Module):def __init__(self, in_units, units):super().__init__()self.weight = nn.Parameter(torch.randn(in_units, units)) #这个需要手动输入一下输入特征数目还有神经元数目self.bias = nn.Parameter(torch.randn(units,))            def forward(self, X):linear = torch.matmul(X, self.weight.data) + self.bias.data  #向前传播其实就是接受输入return F.relu(linear)
#创建一个层,这个层可以直接用在sequential之中
linear = MyLinear(5, 3) #五个输入三个神经元

这里可以看到只要重写了forward方法,那么这个类就能变成一个能用来计算的类,甚至是一个层可以单独计算。并且这样子写好以后是可以放在sequential容器中,作为一个统一训练的。

因此,如果我们有多个块的话,也是可以自己去写一个容器,进行组合。

#创建一个新类
class MySequential(nn.Module): #()就是py中的继承语法def __init__(self, *args):super().__init__()for idx, module in enumerate(args):# 这里,module是Module子类的一个实例。我们把它保存在'Module'类的成员# 变量_modules中。_module的类型是OrderedDictself._modules[str(idx)] = moduledef forward(self, X):# OrderedDict保证了按照成员添加的顺序遍历它们for block in self._modules.values():X = block(X)return X#这个类实现的效果就类似原声的sequential
net = MySequential(net1,net2,net3)
print(net(X))

这个就大概是在拼接块,层的时候,内部所做的底层原理。

当然直接用sequential容器是更加省力气的方法,对吧

2.关于参数如何进行检查

假设现在有一个单独的神经网络

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

众所周知,这个神经网络是两层(中间的一层是激活函数我们不做讨论)

我们可以通过索引来调用和获取某个层的属性

#返回结果是这个全链接层的weight和bias,正好对应八个神经元
print(net[0].state_dict())
#检查参数
print(net[2].bias) #还会返回一些具体的属性
print(net[2].bias.data) #单纯的数据

对于block组成的神经网路社区中(我也不知道很多块组在一起应该叫什么了),仍然是一个嵌套的结构,我们可以创建这样一个社区

#这段代码其实也能看出来,sequential也是一个能容纳block的东西
#     容器 --》 block --》 layer --》 神经元   这三层架构(或者说四层)
def block1():return nn.Sequential(nn.Linear(4, 8), nn.ReLU(),nn.Linear(8, 4), nn.ReLU())
def block2():net = nn.Sequential()for i in range(4):# 在这里嵌套net.add_module(f'block {i}', block1())return netrgnet = nn.Sequential(block2(), nn.Linear(4, 1))

然后我们对这个rgnet进行打印,可以直接看到工作状态

#这样子打印会展示整个网络的状态
print('检查工作状态',rgnet)

可以很清晰地看到,这样一个嵌套结构

所以比如说我们想要访问第一个社区中,第2个块,中的第一个层中的参数,我们可以直接这样子读取

rgnet[0][1][0].bias.data

另外如果想要对已经形成的模型做初始化,这里还有一个例子

def init_normal(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, mean=0, std=0.01)   #平均值0,标准差为0.01nn.init.zeros_(m.bias)                        #偏移直接设置为0
net.apply(init_normal)
print('手动初始化的效果为',net[0].weight.data[0],'手动初始化bias:', net[0].bias.data[0])

函数实现的功能是先检测传进来的是不是正常的线性层,然后分别初始化。

补充一下,apply函数和js里的用法差不多,对内部的每个单元进行遍历,然后做一些操作。

(当然这不是唯一一种方法,自然还有别的)。

3.关于张量的保存和获取

在pytorch中,张量的保存主要有两种形式,第一种是保存数据,用于其他模型的训练

#===========读写张量===========#
x = torch.arange(4)       #[0,1,2,3],创建了一个张量
torch.save(x, 'x-file')   #这是保存在x-file这个文件下面的
loaded_x = torch.load('x-file')  #反过来加载
print(loaded_x)                  #输出
#这样子读取列表和读出,也可以使用字典{x:x,Y:y}或者列表[x,y],反正是变成文件形式了

另一种是保存模型的参数,可以直接套在其他模型上

#=====读写参数并且保存在内存=====#class MLP(nn.Module):   #手动创建多层感知机def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)self.output = nn.Linear(256, 10)def forward(self, x):return self.output(F.relu(self.hidden(x)))net = MLP()           #构建对象
print('MLP的参数',net.state_dict()) #这里输出一下参数#保存这个模型的参数
torch.save(net.state_dict(), 'mlp.params')#然后对一个新模型使用这个参数
clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))  #内置函数加载参数clone.eval()#设置为评估模式,禁止训练什么的,这应该是module中附带的功能print('clone的参数',clone.state_dict())#可以看到参数被完全复制了

但是注意一个问题,如果使用另一个模型初始化自身的时候,要保证两个模型的结构一致

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/151757.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java常见设计模式

单例模式:程序自始至终只创建一个对象。 应用场景:1.整个程序运行中只允许一个类的实例时 2.需要频繁实例化然后销毁的对象 3.创建对象时耗时过多但又经常用到的对象 4.方便资源相互通信的环境 懒汉式线程不安全问题解决方案: 双重检查加锁机…

手机切换ip地址的几种方法详解

在某些情况下,我们可能需要切换手机的IP地址来实现一些特定的需求,如解决某些应用程序的限制、绕过IP封禁等。本文将为大家分享几种切换手机IP地址的方法,让您能够轻松应对各种需求。 一、使用动态服务器 使用动态服务器是一种常见的切换手机…

STM32 CubeMX ADC采集(HAL库)

STM32 CubeMX ADC采集(HAL库) STM32 CubeMX STM32 CubeMX ADC采集(HAL库)ADC介绍ADC主要特征最小识别电压值:2.4/4096≈0.6mv(不考虑误差)一、STM32 CubeMX设置二、代码部分三,单通道…

【Leetcode】 51. N 皇后

按照国际象棋的规则,皇后可以攻击与之处在同一行或同一列或同一斜线上的棋子。 n 皇后问题 研究的是如何将 n 个皇后放置在 nn 的棋盘上,并且使皇后彼此之间不能相互攻击。 给你一个整数 n ,返回所有不同的 n 皇后问题 的解决方案。 每一种…

webserver项目

利用无锁工作队列的Web服务器设计 项目地址https://github.com/whitehat32/webserver_no_lock 基本流程与牛客版的一致,下面放一个牛客版的流程框图 引言 在Web服务器的设计与实现中,性能优化是永远不会过时的话题。一般来说,Web服务器需…

iview 的table表格组件使单元格可编辑和输入

表格的列定义中&#xff0c;在需要编辑的字段下使用render函数 template表格组件 <Table border :data"data" :columns"tableColumns" :loading"loading"></Table>data中定义table对象 table: {tableColumns: [{title: 商品序号,k…

EdgeView 4 for Mac:重新定义您的图像查看体验

您是否厌倦了那些功能繁杂、操作复杂的图像查看器&#xff1f;您是否渴望一款简单、快速且高效的工具&#xff0c;以便更轻松地浏览和管理您的图像库&#xff1f;如果答案是肯定的&#xff0c;那么EdgeView 4 for Mac将是您的理想之选&#xff01; EdgeView 4是一款专为Mac用户…

最短路径专题8 交通枢纽 (Floyd求最短路 )

题目&#xff1a; 样例&#xff1a; 输入 4 5 2 0 1 1 0 2 5 0 3 3 1 2 2 2 3 4 0 2 输出 0 7 思路&#xff1a; 由题意&#xff0c;绘制了该城市的地图之后&#xff0c;由给出的 k 个编号作为起点&#xff0c;求该点到各个点之间的最短距离之和最小的点是哪个&#xff0c;并…

Linux目录和文件查看命令

一、Linux 的目录结构 Linux 的目录结构是一个树状结构&#xff0c;以根目录&#xff08;/&#xff09;为起点&#xff0c;以下是常见的 Linux 目录结构的主要内容&#xff1a; / 根路径 ├── bin: 存放系统指令&#xff08;命令&#xff09;&#xff0c;如ls、cp、mv等&…

如何部署一个高可用高并发的电商平台

假设我们已经有了一个特别大的电商平台&#xff0c;这个平台应该部署在哪里呢&#xff1f;假设我们用公有云&#xff0c;一般公有云会有多个位置&#xff0c;比如在华东、华北、华南都有。毕竟咱们的电商是要服务全国的&#xff0c;当然到处都要部署了。我们把主站点放在华东。…

重启Oracle数据库命令列表逐步操作

&#x1f495;欢迎来到 Oracle 数据库重启教程&#x1f495; &#x1f3af;第一步&#xff1a;以 oracle 身份登录数据库&#x1f3af; su - oracle &#xff08;如果是WINdows系统的CMD窗口&#xff09;直接从第二步开始&#xff01; &#x1f3af;第二步&#xff1a;进入…

【进阶C语言】数组笔试题解析

本节内容以刷题为主&#xff0c;大致目录&#xff1a; 1.一维数组 2.字符数组 3.二维数组 学完后&#xff0c;你将对数组有了更全面的认识 在刷关于数组的题目前&#xff0c;我们先认识一下数组名&#xff1a; 数组名的意义&#xff1a;表示数组首元素的地址 但是有两个例外…

Js基础——事件流

引入 当浏览器发展到第四代时&#xff08; IE4 及 Netscape Communicator 4 &#xff09;&#xff0c;浏览器开发团队遇到了一个很有意思 的问题&#xff1a;页面的哪一部分会拥有某个特定的事件&#xff1f;要明白这个问题问的是什么&#xff0c;可以想象画在一张纸上的一组…

vue3简易文字验证码

大神勿喷&#xff0c;简易版本&#xff0c;demo中可以用一下。 需要几个文字自己codelen 赋值 灵活点直接父组件传过去&#xff0c;可以自己改造 首先创建一个生成数字的js **mathcode.js**function MathCode(num){let str "寻寻觅觅冷冷清清凄凄惨惨戚戚乍暖还寒时候…

千兆以太网传输层 UDP 协议原理与 FPGA 实现(UDP接收)

文章目录 前言心得体会一、 UDP 协议简单回顾二、UDP接收实现三、完整代码展示四、仿真测试(1)模拟电脑数据发送,(2)测试顶层文件编写(3)仿真文件(4)仿真波形前言 在前面我们对以太网 UDP 帧格式做了讲解,UDP 帧格式包括前导码+帧界定符、以太网头部数据、IP 头部数…

光伏发电预测(LSTM、CNN_LSTM和XGBoost回归模型,Python代码)

运行效果&#xff1a;光伏发电预测&#xff08;LSTM、CNN_LSTM和XGBoost回归模型&#xff0c;Python代码&#xff09;_哔哩哔哩_bilibili 运行环境库的版本 光伏太阳能电池通过互连形成光伏模块&#xff0c;以捕捉太阳光并将太阳能转化为电能。因此&#xff0c;当光伏模块暴露…

windows 任务计划自动提交 笔记到github 、gitee

一、必须有个git仓库托管到git上。 这个就不用说了&#xff0c;自己在github或者码云上新建一个仓库就行了。 二、创建自动提交脚本 这个bat脚本是在windows环境下使用的。 注意&#xff1a;windows定时任务下 调用自动提交git前&#xff0c;必须先进入该git仓库目录&#x…

【Linux】线程控制

&#x1f525;&#x1f525; 欢迎来到小林的博客&#xff01;&#xff01;       &#x1f6f0;️博客主页&#xff1a;✈️林 子       &#x1f6f0;️博客专栏&#xff1a;✈️ Linux       &#x1f6f0;️社区 :✈️ 进步学堂       &#x1f6f0…

DependsOn注解失效问题排查

文章目录 前言一、现象描述1.1.背景描述1.2.第一次修改&#xff0c;使用DependsOn注解1.3.第二次修改&#xff0c;设置方法入参 二、看看源码2.1.Spring实例化的源码2.2.调试2.3.验证 总结 前言 最近几天遇到一个比较有意思的问题&#xff0c;发现Spring的DependsOn注解失效&a…

CSS3与HTML5

box-sizing content-box&#xff1a;默认&#xff0c;宽高包不含边框和内边距 border-box&#xff1a;也叫怪异盒子&#xff0c;宽高包含边框和内边距 动画&#xff1a;移动translate&#xff0c;旋转、transform等等 走马灯&#xff1a;利用动画实现animation&#xff1a;from…