一个有意思pytorch的简单应用小实验

        通过一个简单的脚本,来学习pytorch的基本应用,比如:前向传播、反向传播、学习率以及预测、模型的基本原理和套路。

        得到结果。。。保存模型。。。输入参数。。。预测。。。像不像?。。。像多少?。。。

        设计目标:一个包含了两个元素的输入张量,经过一个线性模型的运算后输出预测结果,经过前向传播、反向传播、学习调整后,使预测的结果尽量接近目标结果。

        输入张量:in_tensor=[2.0, 9.0]

        线性模型:model(k0 * in_tensor[0] + k1 * in_tensor[1])

        目标结果:100。

        总结来说就是:设计目标:2.0*k0 + 9.0*k1 = 100,通过Pytorch的惯用框架和套路,经过多次学习和迭代优化之后,求出k0和k1的最优值。

基本代码

import torch
import random# 定义常量
TARGET_VALUE = 100
LR = 0.01  # 学习率# 初始化张量和权重
in_tensor = torch.tensor([2.0, 4.0])
k0 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度
k1 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度# 定义模型
def model(in_tensor, k0, k1):return k0 * in_tensor[0] + k1 * in_tensor[1]   # 定义了一个简单的线性模型# 定义损失函数
def loss_fn(y_pred, y_true):return (y_pred - y_true) ** 2   # 均方误差(MSE),计算预测值与真实值之间的平方差。# 训练过程
def train(iterations, in_tensor, k0, k1):for i in range(iterations):# 前向传播y_pred = model(in_tensor, k0, k1)   # 预测结果loss = loss_fn(y_pred, TARGET_VALUE)   # 损失值(可以理解为误差)# 反向传播loss.backward()# 更新权重with torch.no_grad():   # 停止梯度跟踪k0 -= LR * k0.grad  # k0减去它的梯度*学习率,完成一次权重的调整k1 -= LR * k1.grad  # k1减去它的梯度*学习率,完成一次权重的调整# 清零权重的梯度k0.grad.zero_()k1.grad.zero_()print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}")# 开始训练
train(10, in_tensor, k0, k1)

运行结果:

 可以看到,由于模型很简单,收敛很快,经过10次训练,loss已经降到了0.98。

学习率的实验

上面是学习率LR = 0.01得到的训练结果,现改为LR = 0.015:

同样的10次训练,当学习率增加之后,loss已经降到了0.000625,模型的收敛速度加快。 

继续加大学习率,改为LR = 0.03: 

 loss已经降到了2.85e-9,模型的收敛速度更快了。

继续加大,LR = 0.06:

预测值剧烈震荡,模型无法收敛。

知识点:

        加大学习率,可以加快模型收敛速度,但是也不能过大,学习率过大的后果:

        1. 无法收敛

        • 跳过最优解: 学习率过大时,每次参数更新的步长也会很大,这可能导致模型在优化过程中跳过最优解。

        • 震荡: 模型参数可能会在最佳值附近来回震荡,无法稳定地达到收敛。

        • 梯度爆炸: 在极端情况下,学习率过大可能导致梯度值变得非常大,进而使得参数更新步长过大,甚至导致数值溢出(如NaN)。  

        2. 训练不稳定

        • 损失函数波动: 损失函数的值可能会在每次迭代中剧烈波动,而不是逐渐减小。

        • 泛化能力差: 由于模型参数未能稳定收敛,可能导致模型在测试集上的表现不稳定,泛化能力差。  

        3. 过拟合风险增加

        • 在某些情况下,即使模型最终收敛,也可能因为学习率过大而错过最优解,导致过拟合。

再来,将学习率变小,LR = 0.006:

模型也在持续收敛,但是比起LR = 0.01,收敛变慢了。 

LR = 0.004:

收敛更慢了。

知识点:

        学习率过小的后果:

        1. 收敛速度慢

        • 训练时间长: 由于每次参数更新的步长很小,模型需要更多的迭代次数才能达到最优解,导致训练时间显著增加。

        • 陷入局部最优: 在某些情况下,学习率过小可能导致模型陷入局部最优解,而不是全局最优解。  

        2. 过拟合风险增加

        • 过度训练: 由于训练时间过长,模型可能在训练集上过度拟合,导致在测试集上的表现下降。  

        3. 梯度消失

        • 接近零的梯度: 学习率过小,尤其是在深度神经网络中,可能导致梯度值变得非常小,进而使得参数更新几乎停滞,这种现象称为梯度消失。

早停机制

将局部代码改为:

LR = 0.016
train(100, in_tensor, k0, k1)

在训练了16次之后,loss已经为0。所以,就可以停止训练了。

局部代码修改为:

# 训练过程
def train(iterations, in_tensor, k0, k1):for i in range(iterations):# 前向传播y_pred = model(in_tensor, k0, k1)   # 预测结果loss = loss_fn(y_pred, TARGET_VALUE)   # 损失值(可以理解为误差)if loss <= 0.00001:   # 早停机制print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}, ki = {k0}, k1 = {k1}")break# 反向传播
。。。。。。。。。。。

 当偏差足够小时,停止训练,并输出训练结果。

保存模型和使用模型预测

import torch
import random# 定义常量
TARGET_VALUE = 100
LR = 0.016  # 学习率# 初始化张量和权重
in_tensor = torch.tensor([2.0, 4.0])
pre_tensor = torch.tensor([2.2, 4.0])
k0 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度
k1 = torch.tensor(random.random(), requires_grad=True)  # 权重需要计算梯度
model_state = []   # 模型参数# 定义模型
def model(in_tensor, k0, k1):return k0 * in_tensor[0] + k1 * in_tensor[1]# 定义损失函数
def loss_fn(y_pred, y_true):return (y_pred - y_true) ** 2# 训练过程
def train(iterations, in_tensor, k0, k1):for i in range(iterations):# 前向传播y_pred = model(in_tensor, k0, k1)loss = loss_fn(y_pred, TARGET_VALUE)if loss <= 0.00001:print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}, ki = {k0}, k1 = {k1}")return [k0, k1]   # 返回训练后的模型参数break# 反向传播loss.backward()# 更新权重with torch.no_grad():   # 停止梯度跟踪k0 -= LR * k0.gradk1 -= LR * k1.grad# 清零权重的梯度k0.grad.zero_()k1.grad.zero_()# print(f"Iteration {i}: y_pred = {y_pred}, loss = {loss.item()}")# 开始训练
model_state = train(100, in_tensor, k0, k1) # 训练100次# 保存模型
torch.save(model_state, "model_state.pt")# 加载模型
model_state = torch.load("model_state.pt")
print(model_state)# 预测
y_pred = model(pre_tensor, model_state[0], model_state[1])
print(f"y_pred = {y_pred}, loss = {loss_fn(y_pred, TARGET_VALUE)}")

        在上面的代码中,我们保存了一个模型,并且用它预测了一个张量[2.2, 4.0],与我们训练用的数据[2.0, 4.0]相差不多,所以预测结果也相差不多。如果换成不同的数据,那么预测的结果也将会不同。

        推而广之,如果把输入的张量换成一个图像的像素阵列,预测结果换为判断类别,模型换为多层的卷积神经网络,再加上一些层间池化、输出激活函数,那么就是pytorch最常见的图像识别套路了。所以,无论模型和应用框架多么复杂,也是由最简单的结构迭加、衍生而成,将一个复杂的任务分解成一个个简单任务,它就不再复杂。

        以上为一点点初学者的肤浅心得,与大家交流共勉,望多指教!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/484662.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot 分层解耦

从没有分层思想到传统 Web 分层&#xff0c;再到 Spring Boot 分层架构 1. 没有分层思想 在最初的项目开发中&#xff0c;很多开发者并没有明确的分层思想&#xff0c;所有逻辑都堆砌在一个类或一个方法中。这样的开发方式通常会导致以下问题&#xff1a; 代码混乱&#xff1…

2024 数学建模国一经验分享

2024 数学建模国一经验分享 背景&#xff1a;武汉某211&#xff0c;专业&#xff1a;计算机科学 心血来潮&#xff0c;就从学习和组队两个方面指点下后来者&#xff0c;帮新人避坑吧 2024年我在数学建模比赛中获得了国一&#xff08;教练说论文的分数是湖北省B组第一&#xff0…

Linux 35.6 + JetPack v5.1.4之RTP实时视频Python框架

Linux 35.6 JetPack v5.1.4之RTP实时视频Python框架 1. 源由2. 思路3. 方法论3.1 扩展思考 - 慎谋而后定3.2 扩展思考 - 拒绝拖延或犹豫3.3 扩展思考 - 哲学思考3.4 逻辑实操 - 方法论 4 准备5. 分析5.1 gst-launch-1.05.1.1 xvimagesink5.1.2 nv3dsink5.1.3 nv3dsink sync05…

渤海证券基于互联网环境的漏洞主动防护方案探索与实践

来源&#xff1a;中国金融电脑 作者&#xff1a;渤海证券股份有限公司信息技术总部 刘洋 伴随互联网业务的蓬勃发展&#xff0c;证券行业成为黑客进行网络攻击的重要目标之一&#xff0c;网络攻击的形式也变得愈发多样且复杂。网络攻击如同悬于行业之上的达摩克利斯之剑&…

隐私安全大考,Facebook 如何应对?

随着数字时代的到来和全球互联网用户的快速增长&#xff0c;隐私安全问题已上升为网络世界的重要议题。社交媒体巨头Facebook因其庞大的用户群体和大量的数据处理活动&#xff0c;成为隐私问题的聚焦点。面对隐私安全的大考&#xff0c;Facebook采取了一系列策略来应对这些挑战…

04 创建一个属于爬虫的主虚拟环境

文章目录 回顾conda常用指令创建一个爬虫虚拟主环境Win R 调出终端查看当前conda的虚拟环境创建 spider_base 的虚拟环境安装完成查看环境是否存在 为 pycharm 配置创建的爬虫主虚拟环境选一个盘符来存储之后学习所写的爬虫文件用 pycharm 打开创建的文件夹pycharm 配置解释器…

旅游管理系统的设计与实现

文末获取源码和万字论文&#xff0c;制作不易&#xff0c;感谢点赞支持。 毕 业 设 计&#xff08;论 文&#xff09; 题目&#xff1a;旅游管理系统的设计与实现 摘 要 如今社会上各行各业&#xff0c;都喜欢用自己行业的专属软件工作&#xff0c;互联网发展到这个时候&#…

QT 中 sqlite 数据库使用

一、前提 --pro文件添加sql模块QT core gui sql二、使用 说明 --用于与数据库建立连接QSqlDatabase--执行各种sql语句QSqlQuery--提供数据库特定的错误信息QSqlError查看qt支持的驱动 QStringList list QSqlDatabase::drivers();qDebug()<<list;连接 sqlite3 数据库 …

HENU祖传课堂测试第三弹:Java的文件输入输出

题目&#xff1a;设定文件file1内容&#xff1a;年级,班级&#xff0c;学号&#xff0c;姓名分为四行。 读取文件file1中的内容&#xff0c;若其字符<3个将其转入file2,如若是字符&#xff1e;3个转入file3 代码如下 import java.io.*; import java.nio.file.*; import j…

React Native 速度提升 550%

React Native 爱好者们!🌟 您准备好听一些激动人心的消息了吗?React Native 刚刚发布了其最大的更新之一:一种全新的架构,彻底改变了我们构建移动应用程序的方式。如果您想知道这对您的项目和开发体验意味着什么,请继续关注!我们正在深入探讨这个改变游戏规则的事物;您…

Qt中的 tableView 设置 二进制 十六进制 序号表头

二 进制序号 因为QTableView的垂直表头并不支持使用委托来自定义。 相反&#xff0c;可以通过将自定义的QWidget作为QHeaderView的标签来实现这一目标。 代码&#xff1a; #include <QApplication> #include <QMainWindow> #include <QVBoxLayout> #include …

中国移动量子云平台:算力并网590量子比特!

在技术革新的浪潮中&#xff0c;量子计算以其独特的并行处理能力和指数级增长的计算潜力&#xff0c;有望成为未来技术范式变革和颠覆式创新应用的新源泉。中国移动作为通信行业的领军企业&#xff0c;致力于量子计算技术研究&#xff0c;推动量子计算产业的跨越式发展。 量子云…

D614 PHP+MYSQL +失物招领系统网站的设计与现 源代码 配置 文档

失物招领系统 1.摘要2. 系统开发的背景和意义3.功能结构图4.界面展示5.源码获取 1.摘要 随着互联网的迅速发展&#xff0c;人们的生产生活方式逐渐发生改变&#xff0c;传统的失物招领也可以通过网络处理。本网站是基PHP技术的一款综合性较强的西南民族大学PHP失物招领系统。 …

YOLOv8实战道路裂缝缺陷识别

本文采用YOLOv8作为核心算法框架&#xff0c;结合PyQt5构建用户界面&#xff0c;使用Python3进行开发。YOLOv8以其高效的实时检测能力&#xff0c;在多个目标检测任务中展现出卓越性能。本研究针对道路裂缝数据集进行训练和优化&#xff0c;该数据集包含丰富的道路裂缝图像样本…

并发编程(15)——基于同步方式的线程安全的栈和队列

文章目录 十四、day141. 线程安全的栈1.1 存在隐患的栈容器1.2 优化后的栈容器 2. 线程安全的队列2.1 基于智能指针的线程安全的队列2.2 不同互斥量管理队首、队尾的队列 十四、day14 在并发编程&#xff08;1&#xff09;并发编程&#xff08;5&#xff09;中&#xff0c;我们…

容器第五天(day042)

1.安装 yum install -y docker-compose 2.配置 配置文件名字&#xff1a;docker-compose.yaml或docker-compose.yml 3.启动 docker-compose up -d

离散数学重点复习

第一章.集合论 概念 1.集合是不能精确定义的基本数学概念.通常是由指定范围内的满足给定条件的所有对象聚集在一起构成的 2.制定范围内的每一个对象称为这个集合的元素 3.固定符号如下: N:自然数集合 Z:整数集合 Q:有理数集合 R:实数集合 C:复数集合 4.集合中的元素是…

docker学习笔记(四)--DockerFile

文章目录 一、什么是Dockerfile二、docker build命令三、dockerfile指令3.1 FROM3.2 ENV3.3 WORKDIR3.4 RUN3.5 CMD3.6 ENTRYPOINT3.7 EXPOSE3.8 ARG3.9 ADD3.10 COPY3.11 VOLUME 四、dockerfile示例 一、什么是Dockerfile Dockerfile 是用于构建 Docker 镜像的脚本文件&#…

动手学深度学习-线性神经网络-1线性回归

目录 线性回归的基本元素 线性模型 损失函数 解析解 随机梯度下降 用模型进行预测 矢量化加速 正态分布与平方损失 从线性回归到深度网络 神经网络图 生物学 小结 回归&#xff08;regression&#xff09;是能为一个或多个自变量与因变量之间关系建模的一类方法。…

BERT模型的输出格式探究以及提取出BERT 模型的CLS表示,last_hidden_state[:, 0, :]用于提取每个句子的CLS向量表示

说在前面 最近使用自己的数据集对bert-base-uncased进行了二次预训练&#xff0c;只使用了MLM任务&#xff0c;发现在加载训练好的模型进行输出CLS表示用于下游任务时&#xff0c;同一个句子的输出CLS表示都不一样&#xff0c;并且控制台输出以下警告信息。说是没有这些权重。…