pytorch实现单层线性回归模型

文章目录

    • 简述
      • 代码重构要点
    • 数学模型、运行结果
    • 数据构建与分批
    • 模型封装
    • 运行测试

简述

python使用 数值微分法 求梯度,实现单层线性回归-CSDN博客
python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客
数值微分求梯度、计算图求梯度,实现单层线性回归 模型速度差异及损失率比对-CSDN博客

上述文章都是使用python来实现求梯度的,是为了学习原理,实际使用上,pytorch实现了自动求导,原理也是(基于计算图的)链式求导,本文还就 “单层线性回归” 问题用pytorch实现。

代码重构要点

1.nn.Moudle

torch.nn.Module的继承、nn.Sequentialnn.Linear
torch.nn — PyTorch 2.4 documentation

对于nn.Sequential的理解可以看python使用 计算图(forward与backward) 求梯度,实现单层线性回归-CSDN博客一文代码的模型初始化与计算部分,如图:

在这里插入图片描述

nn.Sequential可以说是把图中标注的代码封装起来了,并且可以放多层。

2.torch.optim优化器

本例中使用随机梯度下降torch.optim.SGD()
torch.optim — PyTorch 2.4 documentation
SGD — PyTorch 2.4 documentation

3.数据构建与数据加载

data.TensorDatasetdata.DataLoader,之前为了实现数据分批,手动实现了data_iter,现在可以直接调用pytorch的data.DataLoader

对于data.DataLoader的参数num_workers,默认值为0,即在主线程中处理,但设置其它值时存在反而速度变慢的情况,以后再讨论。

数学模型、运行结果

y = X W + b y = XW + b y=XW+b

y为标量,X列数为2. 损失函数使用均方误差。

运行结果:

在这里插入图片描述

在这里插入图片描述

数据构建与分批

def build_data(weights, bias, num_examples):  x = torch.randn(num_examples, len(weights))  y = x.matmul(weights) + bias  # 给y加个噪声  y += torch.randn(1)  return x, y  def load_array(data_arrays, batch_size, num_workers=0, is_train=True):  """构造一个PyTorch数据迭代器"""  dataset = data.TensorDataset(*data_arrays)  return data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=is_train)

模型封装

class TorchLinearNet(torch.nn.Module):  def __init__(self):  super(TorchLinearNet, self).__init__()  model = nn.Sequential(Linear(in_features=2, out_features=1))  self.model = model  self.criterion = nn.MSELoss()  def predict(self, x):  return self.model(x)  def loss(self, y_predict, y):  return self.criterion(y_predict, y)

运行测试

if __name__ == '__main__':  start = time.perf_counter()  true_w1 = torch.rand(2, 1)  true_b1 = torch.rand(1)  x_train, y_train = build_data(true_w1, true_b1, 5000)  net = TorchLinearNet()  print(net)  init_loss = net.loss(net.predict(x_train), y_train)  loss_history = list()  loss_history.append(init_loss.item())  num_epochs = 3  batch_size = 50  learning_rate = 0.01  dataloader_workers = 6  data_loader = load_array((x_train, y_train), batch_size=batch_size, is_train=True)  optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)  for epoch in range(num_epochs):  # running_loss = 0.0  for x, y in data_loader:  y_pred = net.predict(x)  loss = net.loss(y_pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  # running_loss = running_loss + loss.item()  loss_history.append(loss.item())  end = time.perf_counter()  print(f"运行时间(不含绘图时间):{(end - start) * 1000}毫秒\n")  plt.title("pytorch实现单层线性回归模型", fontproperties="STSong")  plt.xlabel("epoch")  plt.ylabel("loss")  plt.plot(loss_history, linestyle='dotted')  plt.show()  print(f'初始损失值:{init_loss}')  print(f'最后一次损失值:{loss_history[-1]}\n')  print(f'正确参数: true_w1={true_w1}, true_b1={true_b1}')  print(f'预测参数:{net.model.state_dict()}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/402697.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

24/8/17算法笔记 策略梯度reinforce算法

import gym from matplotlib import pyplot as plt %matplotlib inline#创建环境 env gym.make(CartPole-v0) env.reset()#打印游戏 def show():plt.imshow(env.render(mode rgb_array))plt.show() show()定义网络模型 import torch #定义模型 model torch.nn.Sequential(t…

希亦、洁盟、苏泊尔眼镜清洗机哪款好用?热门眼镜清洗机测评总结

随着科学技术的发展,电子设备的升级,越来越多的人开始戴眼镜,而眼镜由于长时间的佩戴,镜框以及镜面都积累了一些灰尘以及人们肉眼所看不见的细菌,但是如果你使用普通的清洁方式去清洗的话肯定是清洗不干净的&#xff0…

【protobuf】ProtoBuf——proto3语法详解、字段规则、消息类型的定义与使用、通讯录的写入和读取功能实现

文章目录 ProtoBuf5. proto3语法详解5.1 字段规则5.2 消息类型的定义与使用 ProtoBuf 5. proto3语法详解 在语法详解部分,依旧通过项目推进的方式开展教学。此部分会对通讯录多次升级,用 2.x 表示升级的版本,最终将完成以下内容的升级&#x…

海康VisionMaster使用学习笔记4-快速匹配模块

快速匹配模块 快速匹配包括基本参数,特征模板,运行参数,结果显示 基本参数 可以修改图像源和模块的ROI区域. 特征模版 可以配置管理所有的模版,点击创建可以新增模版,也可以通过载入加载本地的模型 建立新模版 点击创建,可以选择当前图像或本地图像进行建模 模版存图按…

使用docker compose一键部署 Portainer

使用docker compose一键部署 Portainer Portainer 是一款轻量级的应用,它提供了图形化界面,用于方便地管理Docker环境,包括单机环境和集群环境。 1、创建安装目录 mkdir /data/partainer/ -p && cd /data/partainer2、创建docker…

uni-app 使用九宫格(uni-grid)布局组件

1、运行环境 开发工具为 HBuilder X 4.23, 操作系统为 Windows 11。Vue.js 版本为 3. 2、操作步骤 首先,登录 HBuilder X。然后用桌面浏览器,访问官网组件网址。 https://ext.dcloud.net.cn/plugin?nameuni-grid 在组件网址右上角、点击“下载插…

每日一题-贪心算法

122. 买卖股票的最佳时机 II - 力扣(LeetCode) 55. 跳跃游戏 - 力扣(LeetCode) 这个题目一开始肯定是会懵,就比如说一开始先跳几步,之后再怎么跳,其实我们就可以用最大范围来算就行了&#xff0…

开发笔记:uniapp+vue+微信小程序 picker +后端 省市区三级联动

写在前面 未采用: 前端放置js 或者 json文件进行 省市区三级联动 采用: 前端组件 后端接口实现三级联动 原因:首先微信小程序有大小限制,能省则省,其次:方便后台维护省市区数据,完整省市区每年更新好像…

SQL基础教程(八)SQL高级处理

※食用指南:文章内容为《SQL基础教程》系列学习笔记,该书对新手入门非常友好,循序渐进,浅显易懂,本人主要用来补全学习MySQL中未涉及的部分,便于刷题和做项目。 官方电子书:《SQL基础教程》第2…

Web安全:SqlMap工具

一、简介 sqlmap 是一款开源的渗透测试工具,可以自动化进行SQL注入的检测、利用,并能接管数据库服务器。它具有功能强大的检测引擎,为渗透测试人员提供了许多专业的功能并且可以进行组合,其中包括数据库指纹识别、数据读取和访问底层文件系统…

柔性超级电容器咋储能?生物聚合物在其中起啥作用?有啥挑战?

*本文只作阅读笔记分享* 一、引言 随着对化石燃料影响的日益关注,开发用于先进电化学能量存储设备的绿色和可再生材料变得至关重要。超级电容器因其出色的寿命、安全性和宽温度操作范围等优势而成为有前途的储能候选者。柔性超级电容器特别适合为轻质可穿戴电子设…

我常用的几个傻瓜式爬虫工具,收藏!

爬虫类工具主要两种,一种是编程语言第三方库,比如Python的scrapy、selenium等,需要有一定的代码基础,一种是图形化的web或桌面应用,比如Web Scraper、后羿采集器、八爪鱼采集器、WebHarvy等,接近于傻瓜式操…

qt生成一幅纯马赛克图像

由于项目需要&#xff0c;需生成一幅纯马赛克的图像作为背景&#xff0c;经过多次测试成功&#xff0c;记录下来。 方法一&#xff1a;未优化方法 1、代码&#xff1a; #include <QImage> #include <QDebug> #include <QElapsedTimer>QImage generateMosa…

MyBatis全解

目录 一&#xff0c; MyBatis 概述 1.1-介绍 MyBatis 的历史和发展 1.2-MyBatis 的特点和优势 1.3-MyBatis 与 JDBC 的对比 1.4-MyBatis 与其他 ORM 框架的对比 二&#xff0c; 快速入门 2.1-环境搭建 2.2-第一个 MyBatis 应用程序 2.3-配置文件详解 (mybatis-config.…

Pikachu-XSS漏洞之cookie值获取、钓鱼结果和键盘记录实战记录

目录 Pikachu-XSS漏洞之cookie值获取、钓鱼结果和键盘记录实战记录 一、XSS&#xff08;get型&#xff09;之cookie值获取&#xff1a; 二、xss&#xff08;post型&#xff09;之cookie值获取 三、Xss之钓鱼攻击 四、XSS获取键盘记 Pikachu-XSS漏洞之cookie值获取、钓鱼结果…

坐牢第二十七天(聊天室)

基于UDP的网络聊天室 一.项目需求&#xff1a; 1.如果有用户登录&#xff0c;其他用户可以收到这个人的登录信息 2.如果有人发送信息&#xff0c;其他用户可以收到这个人的群聊信息 3.如果有人下线&#xff0c;其他用户可以收到这个人的下线信息 4.服务器可以发送系统信息…

算法工程师第四十天(647. 回文子串 516.最长回文子序列 动态规划总结篇 )

参考文献 代码随想录 一、回文子串 给你一个字符串 s &#xff0c;请你统计并返回这个字符串中 回文子串 的数目。 回文字符串 是正着读和倒过来读一样的字符串。 子字符串 是字符串中的由连续字符组成的一个序列。 示例 1&#xff1a; 输入&#xff1a;s "abc"…

【stm32项目】多功能智能家居室内灯光控制系统设计与实现(完整工程资料源码)

多功能智能家居室内灯光控制系统设计与实现 目录&#xff1a; 目录&#xff1a; 前言&#xff1a; 一、项目背景与目标 二、国内外研究现状&#xff1a; 2.1 国内研究现状&#xff1a; 2.2 国外研究现状&#xff1a; 2.3 发展趋势 三、硬件电路设计 3.1 总体概述 3.2 硬件连接总…

图像压缩算法

8.1 JPEG压缩 (JPEG Compression) 介绍 JPEG&#xff08;Joint Photographic Experts Group&#xff09;压缩是最常用的有损图像压缩算法之一。它通过减少图像中的冗余数据来实现高效压缩&#xff0c;特别适用于自然图像。 原理 JPEG压缩的基本步骤包括颜色空间转换、离散余…

WPF篇(18)-DataGrid数据表格控件+ComboBox下拉框控件

DataGrid数据表格控件 DataGrid是一个可以多选的数据表格控件。所以&#xff0c;它继承一个支持多选的父类——MultiSelector。 public abstract class MultiSelector : Selector {protected MultiSelector();public IList SelectedItems { get; }protected bool CanSelectMu…