线形回归与小批量梯度下降实例

1、准备数据集

import numpy as np
import matplotlib.pyplot as pltfrom torch.utils.data import DataLoader
from torch.utils.data import TensorDataset#########################################################################
#################准备若干个随机的x和y#####################################
#########################################################################
np.random.seed(100)     #使用random.seed,设置一个固定的随机种子,
data_size = 150         # 数据集大小
x_range = 5             # x的范围
iteration_count = 100   # 迭代次数# np.random.rand 是 NumPy 库中的一个函数,用于生成一个给定形状的数组,
# 数组中的元素是从一个均匀分布的样本中抽取的,这个均匀分布是在半开区间 [0, 1) 上。
# 这意味着产生的随机数将大于等于0且小于1。#随机生成data_size个横坐标x,范围在0到x_range之间
x=x_range * np.random.rand(data_size,1)#生成带有噪音y数据,基本分布在y=2x+6的附近
y=2*x + 6 + np.random.randn(data_size,1)*0.3plt.scatter(x,y,marker='x',color='green')#########################################################################
#################将训练数据转为张量#######################################
#########################################################################
#将训练数据转为张量
tensorX = torch.from_numpy(x).float()
tensorY = torch.from_numpy(y).float()#使用TensorDataset,将tensorX和tensorY组成训练集
dataset = TensorDataset(tensorX,tensorY)#使用DataLoader,构造随机的小批量数据
dataloader=DataLoader(dataset,batch_size = 20, #每一个小批量的数据规模是20shuffle =True )  #随机打乱数据的顺序
print("dataloader len =%d" %(len(dataloader)))for index,(data,label) in enumerate(dataloader):print("index=%d num = %d"%(index,len(data)))

2、线性回归模型的训练思路

2.1 初始化参数

设置是模型参数:权重w和偏置b

初始化为随机值

设置 requires_grad=True,PyTorch 将记录这些张量的操作历史,用于后续的自动求导

2.2 循环训练(epoch

epoch 变量用于控制整个训练过程的迭代轮数

在机器学习和深度学习中,“epoch” 是一个常用的术语,指的是在整个数据集上完整地运行一次(即正向传播和反向传播)训练算法的过程。

定义:一个 epoch 是指训练过程中,训练集中每个样本都被使用过一次来更新模型的权重。
训练过程:在训练一个模型时,通常会将数据集分成多个批次(batches)。每个批次包含一定数量的样本。一个 epoch 完成意味着所有批次都已经过模型处理。
迭代与epoch:在一个 epoch 内,模型可能会多次迭代,每次迭代处理一个批次的数据。因此,一个 epoch 包含多个迭代(iterations)。
目的:通过多个 epochs 的训练,模型可以逐渐学习数据集中的模式,从而提高其性能。
数量:训练一个模型所需的 epochs 数量取决于多种因素,包括数据集的大小、模型的复杂度以及问题的难度。有时可能只需要几个 epochs,而有时可能需要数百甚至数千个 epochs。
监控:在训练过程中,通常会监控每个 epoch 的性能指标(如损失函数的值或准确率),以评估模型的学习进度。
过拟合与欠拟合:如果训练过多的 epochs,模型可能会过拟合(即模型学习到了数据中的噪声而非潜在的模式),而训练不足的 epochs 则可能导致欠拟合(即模型未能捕捉到数据中的关键模式)。

2.3 数据加载

内层循环通过 dataloader 遍历训练数据集的小批量数据。dataloader 是一个数据加载器,通常由 DataLoader 类创建,用于批量加载数据。

2.4 前向传播

假设 tensorX 是当前批次的数据

tensorY 是对应的真实标签

使用当前参数 w 和 b 计算预测值 h=w*tensorX +b。

2.5 计算损失

计算预测值 h 和真实值 tensorY 之间的均方误差(MSE),并保存到 loss

loss = torch.mean((h - tensorY) ** 2)

2.6 反向传播

调用 loss.backward() 进行反向传播,计算损失关于参数 w 和 b 的梯度

设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导

2.7 更新参数

使用梯度下降算法更新参数 w 和 b。学习率设置为0.01

w.data -= 0.01 * w.grad.data

b.data -= 0.01 * b.grad.data

沿着当前小批量计算的得到的梯度(导数)更新w和b

如果导数为0,则w、b保存不变

2.8 梯度清零

在每次迭代后,需要清空参数的梯度信息,以便下一次迭代计算

3、线性回归模型的实现

# 待送代的参数为w和b
w = torch.randn(1,requires_grad=True)
b = torch.randn(1,requires_grad=True)#进入模型的循环迭代
for epoch in range(1,iteration_count):#代表了整个训练数据集的迭代轮数# 在一个迭代轮次中,以小批量的方式,使用dataloader对数据# batch_index表示当前遍历的批次# data和label表示这个批次的训练数据和标记for batch_index,(data, label)in enumerate(dataloader):h = tensorX * w + b #计算当前直线的预测值,保存到h#计算预测值h和真实值y之间的均方误差,保存到loss中loss=torch.mean((h-tensorY)**2)#计算代价1oss关于参数w和b的偏导数,设置了 requires_grad=True,PyTorch 将记录这些张量的操作历史并自动求导loss.backward()#进行梯度下降,沿着梯度的反方向,更新w和b的值#沿着当前小批量计算的得到的梯度(导数)更新w和b#如果导数为0(Δw,Δb为0),则w、b保存不变w.data -=0.01 * w.grad.datab.data -=0.01 * b.grad.dataprint("epoch(%d) batch(%d) lossΔw,Δb,w,b, = %.3lf,%.3lf,%.3lf,%.3lf,%.3lf," %(epoch,batch_index,loss.item(),w.grad.data,b.grad.data,w.data,b.data))#清空张量w和b中的梯度信息,为下一次迭代做准备w.grad.zero_()b.grad.zero_()#每次迭代,都打印当前迭代的轮数epoch#数据的批次batch idx和loss损失值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/936.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Unity3D日常开发】Unity3D中打开Window文件对话框打开文件(PC版)

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享QQ群:398291828小红书小破站 大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 这篇文章继续讲如何使用Unity3D打开Window文…

iOS 逆向学习 - Inter-Process Communication:进程间通信

iOS 逆向学习 - Inter-Process Communication:进程间通信 一、进程间通信概要二、iOS 进程间通信机制详解1. URL Schemes2. Pasteboard3. App Groups 和 Shared Containers4. XPC Services 三、不同进程间通信机制的差异四、总结 一、进程间通信概要 进程间通信&am…

零基础 监控数据可视化 Spring Boot 2.x(Actuator + Prometheus + Grafana手把手) (上)

一、安装Prometheus Releases prometheus/prometheus GitHubhttps://github.com/prometheus/prometheus/releases 或 https://prometheus.io/download/https://prometheus.io/download/ 1. 下载适用于 Windows 的二进制文件: 找到最新版本的发布页面&#xf…

【API】免费调用Qwen-vl2对图像打标

首次调用通义千问API_大模型服务平台百炼(Model Studio)-阿里云帮助中心https://help.aliyun.com/zh/model-studio/getting-started/first-api-call-to-qwen?spma2c4g.11186623.help-menu-2400256.d_0_1_0.8c693048HxtUzZ&scm20140722.H_2840915._.OR_help-T_cn~zh-V_1 一…

CF 371A.K-Periodic Array(Java实现)

题目分析 这里的意思是一共n个值每k个一组循环,最少改变多少个值就能让循环相同 思路分析 我在这里首先想的是二维数组方便观察循环,依据题目即为每一竖列比较,哪一个值出现的最少那么那就是需要更改的次数,(此题在这儿不考虑需要…

信息科技伦理与道德3:智能决策

1 概述 1.1 发展历史 1950s-1980s:人工智能的诞生与早期发展热潮 1950年:图灵发表了一篇划时代的论文,并提出了著名的“图灵测试”;1956年:达特茅斯会议首次提出“人工智能”概念;1956年-20世纪70年代&a…

一路相伴,非凸科技助力第49届ICPC亚洲区决赛

2024年12月27日-29日,第49届国际大学生程序设计竞赛亚洲区决赛在西北工业大学圆满举行。非凸科技再次作为EC Final的主要赞助方,鼎力支持这群心怀梦想的青年才俊,激励他们勇攀科技高峰,实现创新突破。 EC Final参赛名额主要由当…

MPLS原理及配置

赶时间可以只看实验部分 由来:90年代中期,互联网流量的快速增长。传统IP报文依赖路由器查询路由表转发,但由于硬件技术存在限制导致转发性能低,查表转发成为了网络数据转发的瓶颈。 因此,旨在提高路由器转发速度的MPL…

《机器学习》——TF-IDF(关键词提取)

文章目录 TF-IDF简介TF-IDF应用场景TF-IDF模型模型参数主要参数 TF-IDF实例实例步骤导入数据和模块处理数据处理文章开头和分卷处理将各卷内容存储到数据帧jieba分词和去停用词处理 计算 TF-IDF 并找出核心关键词 TF-IDF简介 TF - IDF(Term Frequency - Inverse Do…

【计算机网络】窥探计网全貌:说说计算机网络体系结构?

标签难度考察频率综合题⭐⭐⭐60% 这个问题在计算机网络知识体系中是一个比较重要的问题,只有完整地了解计算机网络的体系结构才能清晰地认识网络的运行原理。 在回答这个问题时,笔者认为有几个比较重要的点: 首先一定要分清楚前置条件&am…

【前端】【CSS3】基础入门知识

目录 如何学习CSS 1.1什么是CSS​编辑 1.2发展史 1.三种导入方式 1.1、行内样式 1.2、外部样式 1.3、嵌入方式 2.选择器 2.1、基本选择器 (1)元素选择器 (2)类选择器 (3)id选择器:必…

【解决】okhttp的java.lang.IllegalStateException: closed错误

问题 Android 使用OKHttp进行后端通信,后端处理结果,反馈给前端的responseBody中其实有值,但是一直报异常,后来才发现主要是OkHttp请求回调中response.body().string()只能有效调用一次,而我使用了两次: 解…

从硬件设备看Linux

一、介绍 DM3730通过各种连接方式连接了各种设备,输入输出设备根据不同的类型大体可 以分为电源管理、用户输人、显示输出、图像采集、存储以及无线设备等。我们可以将DM 3730与这些设备的数据接口分为总线和单一的数据接口总线。总线的显著特点是单个总线上可以连…

【优选算法】DC-Quicksort-Mysteries:分治-快排的算法之迷

文章目录 1.概念解析2.颜色分类3.排序数组4.数组中的第k个最大元素5.库存管理Ⅲ希望读者们多多三连支持小编会继续更新你们的鼓励就是我前进的动力! 本篇是优选算法之分治-快排,快排可以在更短的时间内完成相同规模数据的排序任务,大大提升了…

浅谈云计算09 | 服务器虚拟化

服务器虚拟化基础 一、虚拟化的定义二、系统虚拟化三、服务器虚拟化的核心要义四、典型实现:探索不同路径五、全虚拟化与半虚拟化六、主流服务器虚拟化技术 一、虚拟化的定义 虚拟化是一种将物理资源抽象为逻辑资源的技术,通过在物理硬件与操作系统、应…

解析OVN架构及其在OpenStack中的集成

引言 随着云计算技术的发展,虚拟化网络成为云平台不可或缺的一部分。为了更好地管理和控制虚拟网络,Open Virtual Network (OVN) 应运而生。作为Open vSwitch (OVS) 的扩展,OVN 提供了对虚拟网络抽象的支持,使得大规模部署和管理…

第三十六章 Spring之假如让你来写MVC——拦截器篇

Spring源码阅读目录 第一部分——IOC篇 第一章 Spring之最熟悉的陌生人——IOC 第二章 Spring之假如让你来写IOC容器——加载资源篇 第三章 Spring之假如让你来写IOC容器——解析配置文件篇 第四章 Spring之假如让你来写IOC容器——XML配置文件篇 第五章 Spring之假如让你来写…

CSS 盒模型

盒模型 CSS盒模型是网页布局的核心概念之一,它描述了网页元素的物理结构和元素内容与周围元素之间的关系。根据W3C规范,每个HTML元素都被视为一个矩形盒子,这个盒子由以下四个部分组成: 内容区(Content area&#xff…

JVM:ZGC详解(染色指针,内存管理,算法流程,分代ZGC)

1,ZGC(JDK21之前) ZGC 的核心是一个并发垃圾收集器,所有繁重的工作都在Java 线程继续执行的同时完成。这极大地降低了垃圾收集对应用程序响应时间的影响。 ZGC为了支持太字节(TB)级内存,设计了基…

OpenCV基于均值漂移算法(pyrMeanShiftFiltering)的水彩画特效

1、均值漂移算法原理 pyrMeanShiftFiltering算法结合了均值迁移(Mean Shift)算法和图像金字塔(Image Pyramid)的概念,用于图像分割和平滑处理。以下是该算法的详细原理: 1.1 、均值迁移(Mean …