Pytorch笔记之回归

文章目录

  • 前言
  • 一、导入库
  • 二、数据处理
  • 三、构建模型
  • 四、迭代训练
  • 五、结果预测
  • 总结


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法
2、实例化模型,定义损失函数与优化器

# 继承模型
class LinearRegression(nn.Module):def __init__(self):super().__init__()self.fc = nn.Linear(1, 1)def forward(self, x):out = self.fc(x)return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()
2、反向传播计算梯度值:loss.backward()
3、执行参数更新:optimizer.step()
循环迭代,定期输出损失值

print('损失值')
for i in range(1001):out = model.forward(inputs)loss = mse_loss(out, target)optimizer.zero_grad()loss.backward()optimizer.step()if i % 200 == 0:print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()


总结

使用Pytorch进行训练主要的三步:
(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;
(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;
(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/152668.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

「专题速递」AR协作、智能NPC、数字人的应用与未来

元宇宙是一个融合了虚拟现实、增强现实、人工智能和云计算等技术的综合概念。它旨在创造一个高度沉浸式的虚拟环境,允许用户在其中交互、创造和共享内容。在元宇宙中,人们可以建立虚拟身份、参与虚拟社交,并享受无限的虚拟体验。 作为互联网大…

Prompt-Tuning(一)

一、预训练语言模型的发展过程 第一阶段的模型主要是基于自监督学习的训练目标,其中常见的目标包括掩码语言模型(MLM)和下一句预测(NSP)。这些模型采用了Transformer架构,并遵循了Pre-training和Fine-tuni…

上班第一天同事让我下载个小乌龟,我就去百度小乌龟。。。。

记得那会儿是刚毕业,去上班第一天,管我的那个上级说让我下载个小乌龟,等下把代码拉一下,我那是一脸懵逼啊,我在学校只学过git啊,然后开始磨磨蹭蹭吭吭哧哧的不知所措,之后我想也许百度能救我&am…

C语言内存函数

目录 memcpy(Copy block of memory)使用和模拟实现memcpy的模拟实现 memmove(Move block of memory)使用和模拟实现memmove的模拟实现: memset(Fill block of memory)函数的使用扩展 memcmp(Compare two blocks of memory)函数的使用 感谢各位大佬对我的支持,如果我的文章对你有…

PSN 两步验证解除2023.10.9经验贴

背景 本人10月1号收到Sony邮件,说是不规律登录,需修改密码后登录,然后我10月8日登录PS4的时候,提示两步验证。当时就想坏了,然后找B站相关经验贴,10月9号电话香港客服,解除了两步验证&#xff0…

Windows10打开应用总是会弹出提示窗口的解决方法

用户们在Windows10电脑中打开应用程序,遇到了总是会弹出提示窗口的烦人问题。这样的情况会干扰到用户的正常操作,给用户带来不好的操作体验,接下来小编给大家详细介绍关闭这个提示窗口的方法,让大家可以在Windows10电脑中舒心操作…

计算机网络八股

1、请你说说TCP和UDP的区别 TCP提供面向连接的可靠传输,UDP提供面向无连接的不可靠传输。UDP在很多实时性要求高的场景有很好的表现,而TCP在要求数据准确、对速度没有硬件要求的场景有很好的表现。TCP和UDP都是传输层协议,都是为应用层程序服…

大数据——Spark Streaming

是什么 Spark Streaming是一个可扩展、高吞吐、具有容错性的流式计算框架。 之前我们接触的spark-core和spark-sql都是离线批处理任务,每天定时处理数据,对于数据的实时性要求不高,一般都是T1的。但在企业任务中存在很多的实时性的任务需求&…

超大视频如何优雅切片

背景 有一次录屏产生了一个大小为33G的文件, 我想把他上传到B站, 但是B站最大只支持4G. 无法上传, 因此做了一个简单的探索. 质疑与思考 a. 有没有一个工具或一个程序协助我做分片呢? 尝试 a. 必剪 > 有大小限制, 添加素材加不进去(而且报错信息也提示的不对) b. PR &…

C++设计模式_07_Bridge 桥模式

文章目录 1. 动机(Motivation)2. 代码演示Bridge 桥模式2.1 基于继承的常规思维处理2.2 基于组合关系的重构优化2.3 采用Bridge 桥模式的实现 3. 模式定义4. 结构(Structure)5. 要点总结 与上篇介绍的Decorator 装饰模式一样&…

从零开始读懂相对论:探索爱因斯坦的科学奇迹

💂 个人网站:【工具大全】【游戏大全】【神级源码资源网】🤟 前端学习课程:👉【28个案例趣学前端】【400个JS面试题】💅 寻找学习交流、摸鱼划水的小伙伴,请点击【摸鱼学习交流群】 引言 阿尔伯特爱因斯坦…

竞赛 机器视觉 opencv 深度学习 驾驶人脸疲劳检测系统 -python

文章目录 0 前言1 课题背景2 Dlib人脸识别2.1 简介2.2 Dlib优点2.3 相关代码2.4 人脸数据库2.5 人脸录入加识别效果 3 疲劳检测算法3.1 眼睛检测算法3.2 打哈欠检测算法3.3 点头检测算法 4 PyQt54.1 简介4.2相关界面代码 5 最后 0 前言 🔥 优质竞赛项目系列&#x…

Transformer预测 | Pytorch实现基于Transformer 的锂电池寿命预测(CALCE数据集)

文章目录 效果一览文章概述模型描述程序设计参考资料效果一览 文章概述 Pytorch实现基于Transformer 的锂电池寿命预测,环境为pytorch 1.8.0,pandas 0.24.2 随着充放电次数的增加,锂电池的性能逐渐下降。电池的性能可以用容量来表示,故寿命预测 (RUL) 可以定义如下: SOH(t…

flutter开发实战-video_player插件播放抖音直播实现(仅限Android端)

flutter开发实战-video_player插件播放抖音直播实现(仅限Android端) 在之前的开发过程中,遇到video_player播放视频,通过查看video_player插件描述,可以看到video_player在Android端使用exoplayer,在iOS端…

workerman的基本用法(示例详解)

workerman是什么? Workerman是一个异步事件驱动的PHP框架,具有高性能,可轻松构建快速,可扩展的网络应用程序。支持HTTP,Websocket,SSL和其他自定义协议。支持libevent,HHVM,ReactPH…

el-table 设置最大高度且能刚好撑满

max-height"calc(90vh - 120px)"90vh视口高度的90%自行调整即可

解决: 使用html2canvas和print-js打印组件时, 超出高度出现空白页

如果所示:当我利用html2canvas转换成图片后, 然后使用print-js打印多张图片, 第一张会出现空白页 打印组件可参考这个: Vue-使用html2canvas和print-js打印组件 解决: 因为是使用html2canvas转换成图片后才打印的, 而图片是行内块级元素, 会有间隙, 所以被挤下去了…

真香!Jenkins 主从模式解决问题So Easy~

01.Jenkins 能干什么 Jenkins 是一个开源软件项目,是基于 Java 开发的一种持续集成工具,用于监控持续重复的工作,旨在提供一个开放易用的软件平台,使软件项目可以进行持续集成。 中文官网:https://jenkins.io/zh/ 0…

Docker基础(CentOS 7)

参考资料 hub.docker.com 查看docker官方仓库,需要梯子 Docker命令大全 黑马程序员docker实操教程 (黑马讲的真的不错 容器与虚拟机 安装 yum install -y docker Docker服务命令 启动服务 systemctl start docker停止服务 systemctl stop docker重启…

Redis AOF重写原原理

重写aof之前 appendonly.aof.1.base.aof appendonly.aof.1.incr.aof appendonly.aof.manifest 重写aof 一次 appendonly.aof.2.base.aof 大小变化 appendonly.aof.2.incr.aof 大小o appendonly.aof.manifest 大小不变 AOF文件重写并不是对原文件进行重新整理,而是直…