动手学深度学习-3.2 线性回归的从0开始

以下是代码的逐段解析及其实际作用:


1. 环境设置与库导入

%matplotlib inline
import random
import torch
from d2l import torch as d2l
  • 作用
    • %matplotlib inline:在 Jupyter Notebook 中内嵌显示 matplotlib 图形。
    • random:生成随机索引用于数据打乱。
    • torch:PyTorch 深度学习框架。
    • d2l:《动手学深度学习》提供的工具函数库(如绘图工具)。

2. 生成合成数据

假设真实权重向量为 w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrueRn,偏置为 b true b_{\text{true}} btrue,噪声为高斯分布 ϵ ∼ N ( 0 , σ 2 ) \epsilon \sim \mathcal{N}(0, \sigma^2) ϵN(0,σ2),则合成数据生成公式为:
y = X w true + b true + ϵ \mathbf{y} = \mathbf{X} \mathbf{w}_{\text{true}} + b_{\text{true}} + \epsilon y=Xwtrue+btrue+ϵ
其中:

  • X ∈ R m × n \mathbf{X} \in \mathbb{R}^{m \times n} XRm×n:输入特征矩阵( m m m 个样本, n n n 个特征)。
  • w true ∈ R n \mathbf{w}_{\text{true}} \in \mathbb{R}^n wtrueRn:真实权重向量。
  • ϵ ∈ R m \epsilon \in \mathbb{R}^m ϵRm:噪声向量。
def synthetic_data(w, b, num_examples):  #@save"""生成y=Xw+b+噪声"""X = torch.normal(0, 1, (num_examples, len(w)))  # 生成标准正态分布的输入特征 num_examples行,len(w)列y = torch.matmul(X, w) + b                      # 计算线性输出 y = Xw + by += torch.normal(0, 0.01, y.shape)             # 添加高斯噪声return X, y.reshape((-1, 1))                    # y行数不定(值为-1,列数为1)true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

生成的函数是一个二维线性回归模型,其数学表达式为:

y = w 1 x 1 + w 2 x 2 + b + ϵ y = w_1 x_1 + w_2 x_2 + b + \epsilon y=w1x1+w2x2+b+ϵ

其中:

  • 权重 w = [ w 1 , w 2 ] = [ 2 , − 3.4 ] \mathbf{w} = [w_1, w_2] = [2, -3.4] w=[w1,w2]=[2,3.4],由 true_w 定义。
  • 偏置 b = 4.2 b = 4.2 b=4.2,由 true_b 定义。
  • 噪声 ϵ ∼ N ( 0 , 0.0 1 2 ) \epsilon \sim \mathcal{N}(0, 0.01^2) ϵN(0,0.012),即均值为 0、标准差为 0.01 的高斯噪声。

展开为标量形式:
y i = 2 ⋅ x i 1 − 3.4 ⋅ x i 2 + 4.2 + ϵ i ( i = 1 , 2 , … , 1000 ) y_i = 2 \cdot x_{i1} - 3.4 \cdot x_{i2} + 4.2 + \epsilon_i \quad (i = 1, 2, \dots, 1000) yi=2xi13.4xi2+4.2+ϵi(i=1,2,,1000)


3. 数据可视化

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(), labels.detach().numpy(), 1);
  • 绘制第二个特征(features[:,1] => n行第1列)与标签 labels 的散点图。

4. 定义数据迭代器

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)  # 打乱索引顺序for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]  # 生成小批量数据
  • 作用
    • 将数据集按 batch_size 划分为小批量,并随机打乱顺序。
    • 使用生成器 (yield) 逐批返回数据,避免一次性加载全部数据到内存。

5. 初始化模型参数

w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
  • 初始化w和b的值
    • w:从均值为 0、标准差为 0.01 的正态分布中初始化权重,启用梯度追踪。
    • b:初始化为 0 的偏置,启用梯度追踪。
    • 参数需梯度追踪以支持反向传播。

6. 定义模型、损失函数和优化器

def linreg(X, w, b):  #@save"""线性回归模型"""return torch.matmul(X, w) + bdef squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2  # 除以2便于梯度计算def sgd(params, lr, batch_size):  #@save"""小批量随机梯度下降"""with torch.no_grad():  # 禁用梯度计算for param in params:param -= lr * param.grad / batch_size  # 参数更新param.grad.zero_()                     # 梯度清零
  • linreg:模型预测值 y ^ \hat{\mathbf{y}} y^ 的矩阵形式为:
    y ^ = X w + b \hat{\mathbf{y}} = \mathbf{X} \mathbf{w} + b y^=Xw+b
    其中:

    • w ∈ R n \mathbf{w} \in \mathbb{R}^n wRn:待学习的权重向量。
    • b ∈ R b \in \mathbb{R} bR:待学习的偏置。
  • squared_loss:损失函数的矩阵形式为:
    L = 1 2 ∥ y ^ − y ∥ 2 L = \frac{1}{2} \| \hat{\mathbf{y}} - \mathbf{y} \|^2 L=21y^y2

    L ( w , b ) = 1 2 m ∥ X w + b − y ∥ 2 L(\mathbf{w}, b) = \frac{1}{2m} \| \mathbf{X} \mathbf{w} + b - \mathbf{y} \|^2 L(w,b)=2m1Xw+by2
    展开后:
    L ( w , b ) = 1 2 m ( X w + b 1 − y ) ⊤ ( X w + b 1 − y ) L(\mathbf{w}, b) = \frac{1}{2m} (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y})^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) L(w,b)=2m1(Xw+b1y)(Xw+b1y)

  • sgd:小批量随机梯度下降优化器,

    • 对权重 w \mathbf{w} w 的梯度
      ∇ w L = 1 m X ⊤ ( X w + b 1 − y ) \nabla_{\mathbf{w}} L = \frac{1}{m} \mathbf{X}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}) wL=m1X(Xw+b1y)

    • 对偏置 b b b 的梯度
      ∇ b L = 1 m 1 ⊤ ( X w + b 1 − y ) , 1 为单位列向量 \nabla_{b} L = \frac{1}{m} \mathbf{1}^\top (\mathbf{X} \mathbf{w} + b \mathbf{1} - \mathbf{y}),\mathbf{1} 为单位列向量 bL=m11(Xw+b1y)1为单位列向量

    • 使用学习率 η \eta η,参数更新公式为:
      w ← w − η ∇ w L b ← b − η ∇ b L \mathbf{w} \leftarrow \mathbf{w} - \eta \nabla_{\mathbf{w}} L\\ b \leftarrow b - \eta \nabla_{b} L wwηwLbbηbL


7. 训练循环

lr = 0.03
num_epochs = 3
batch_size = 10  # 需补充定义(原代码未显式定义)for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)  # 计算小批量损失l.sum().backward()         # 反向传播计算梯度sgd([w, b], lr, batch_size) # 更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
  • 作用

    • 外层循环:遍历训练轮次 (num_epochs)。
    • 内层循环:按小批量遍历数据,计算损失并反向传播。
    • l.sum().backward():将小批量损失求和后反向传播,计算梯度。
    • sgd:根据梯度更新参数,梯度需除以 batch_size 以保持学习率一致性。
    • 每个 epoch 结束后,计算并打印整体训练损失。
    • mean()函数计算平均值
  • 梯度下降

  l.sum().backward()  # 反向传播计算梯度sgd([w, b], lr, batch_size)  # 更新参数
  • 小批量梯度计算公式:
    ∇ w L batch = 1 batch_size X batch ⊤ ( X batch w + b − y batch ) \nabla_{\mathbf{w}} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{X}_{\text{batch}}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) wLbatch=batch_size1Xbatch(Xbatchw+bybatch)
    ∇ b L batch = 1 batch_size 1 ⊤ ( X batch w + b − y batch ) \nabla_{b} L_{\text{batch}} = \frac{1}{\text{batch\_size}} \mathbf{1}^\top (\mathbf{X}_{\text{batch}} \mathbf{w} + b - \mathbf{y}_{\text{batch}}) bLbatch=batch_size11(Xbatchw+bybatch)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/12166.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP HCM 回溯分析

最近总有人问回溯问题,今天把12年总结的笔记在这共享下: 12年开这个图的时候总是不明白是什么原理,教程看N次,网上资料找一大堆,就是不明白原理,后来为搞明白逻辑,按照教材的数据一样做&#xf…

gitea - fatal: Authentication failed

文章目录 gitea - fatal: Authentication failed概述run_gitea_on_my_pkm.bat 笔记删除windows凭证管理器中对应的url认证凭证启动gitea服务端的命令行正常用 TortoiseGit 提交代码备注END gitea - fatal: Authentication failed 概述 本地的git归档服务端使用gitea. 原来的用…

X Window System 架构概述

X Window System 架构概述 1. X Server 与 X Client ​ 这里引入一张维基百科的图,在Linux系统中,若用户需要图形化界面,则可以使用X Window System,其使用**Client-Server**架构,并通过网络传输相关信息。 ​ ​ X…

Linux防火墙基础

一、Linux防火墙的状态机制 1.iptables是可以配置有状态的防火墙,其有状态的特点是能够指定并记住发送或者接收信息包所建立的连接状态,其一共有四种状态,分别为established invalid new related。 established:该信息包已建立连接&#x…

[论文学习]Adaptively Perturbed Mirror Descent for Learning in Games

[论文学习]Adaptively Perturbed Mirror Descent for Learning in Games 前言概述前置知识和问题约定单调博弈(monotone game)Nash均衡和Gap函数文章问题定义Mirror Descent 方法评价 前言 文章链接 我们称集合是紧的,则集合满足&#xff1…

Go学习:类型转换需注意的点 以及 类型别名

目录 1. 类型转换 2. 类型别名 1. 类型转换 在从前的学习中,知道布尔bool类型变量只有两种值true或false,C/C、Python、JAVA等编程语言中,如果将布尔类型bool变量转换为整型int变量,通常采用 “0为假,非0为真”的方…

使用Pygame制作“吃豆人”游戏

本篇博客展示如何使用 Python Pygame 编写一个简易版的“吃豆人(Pac-Man)” 风格游戏。这里我们暂且命名为 Py-Man。玩家需要控制主角在一个网格地图里移动、吃掉散布在各处的豆子,并躲避在地图中巡逻的幽灵。此示例可帮助你理解网格地图、角…

ubuntu磁盘扩容

ubuntu磁盘扩容 描述先在虚拟机设置里面扩容进入Ubuntu 配置使用命令行工具parted进行分区输出如下完成 描述 执行命令,查看 fs 类型是什么 lsblk -o NAME,FSTYPE,MOUNTPOINT将60G扩容到100G,其中有些操作我也不知道什么意思,反正就是成功了&#xff0…

redis底层数据结构

底层数据结构 了解下这些咱常用的数据其底层实现是啥 在提到使用哪类数据结构之前,先来了解下redis底层到底有多少种数据结构 1,sds动态字符串 概念与由来 redis是一种使用C语言编写的nosql,redis存储的key数据均为string结构&#xff0…

ChatGPT怎么回事?

纯属发现,调侃一下~ 这段时间deepseek不是特别火吗,尤其是它的推理功能,突发奇想,想用deepseek回答一些问题,回答一个问题之后就回复服务器繁忙(估计还在被攻击吧~_~) 然后就转向了GPT&#xf…

趣味Python100例初学者练习01

1. 1 抓交通肇事犯 一辆卡车违反交通规则,撞人后逃跑。现场有三人目击该事件,但都没有记住车号,只记下了车号的一些特征。甲说:牌照的前两位数字是相同的;乙说:牌照的后两位数字是相同的,但与前…

2024-我的学习成长之路

因为热爱,无畏山海

蓝桥杯备考:高精度算法之除法

我们除法的高精度其实也不完全是高精度,而是一个高精度作被除数除以一个低精度 模拟我们的小学除法 由于题目中我们的除数最大是1e9,当它真正是1e9的时候,t是有可能超过1e9的,所以要用long long

Maven jar 包下载失败问题处理

Maven jar 包下载失败问题处理 1.配置好国内的Maven源2.重新下载3. 其他问题 1.配置好国内的Maven源 打开⾃⼰的 Idea 检测 Maven 的配置是否正确,正确的配置如下图所示: 检查项⼀共有两个: 确认右边的两个勾已经选中,如果没有请…

【前端】ES6模块化

文章目录 1. 模块化概述1.1 什么是模块化?1.2 为什么需要模块化? 2. 有哪些模块化规范3. CommonJs3.1 导出数据3.2 导入数据3.3 扩展理解3.4 在浏览器端运行 4.ES6模块化4.1 浏览器运行4.2 在node服务端运行4.3 导出4.3.1 分别导出4.3.2 统一导出4.3.3 默认导出4.3.4 混用 4.…

强化学习笔记(5)——PPO

PPO视频课程来源 首先理解采样期望的转换 变量x在p(x)分布下,函数f(x)的期望 等于f(x)乘以对应出现概率p(x)的累加 经过转换后变成 x在q(x)分布下,f(x)*p(x)/q(x) 的期望。 起因是:求最大化回报的期望,所以对ceta求梯度 具体举例…

20-30 五子棋游戏

20-分析五子棋的实现思路_哔哩哔哩_bilibili20-分析五子棋的实现思路是一次性学会 Canvas 动画绘图(核心精讲50个案例)2023最新教程的第21集视频,该合集共计53集,视频收藏或关注UP主,及时了解更多相关视频内容。https:…

【HTML入门】Sublime Text 4与 Phpstorm

文章目录 前言一、环境基础1.Sublime Text 42.Phpstorm(1)安装(2)启动Phpstorm(3)“启动”码 二、HTML1.HTML简介(1)什么是HTML(2)HTML版本及历史(3)HTML基本结构 2.HTML简单语法(1)HTML标签语法(2)HTML常用标签(3)表格(4)特殊字符 总结 前言 在当今的软件开发领域&#xff0c…

Kubernetes学习之包管理工具(Helm)

一、基础知识 1.如果我们需要开发微服务架构的应用,组成应用的服务可能很多,使用原始的组织和管理方式就会非常臃肿和繁琐以及较难管理,此时我们需要一个更高层次的工具将这些配置组织起来。 2.helm架构: chart:一个应用的信息集合…

Kamailio 不通过 dmq 实现注册复制功能

春节期间找到一篇文章,需要 fg 才能看到: https://medium.com/tumalevich/kamailio-registration-replication-without-dmq-65e225f9a8a7 kamailio1 192.168.56.115 kamailio2 192.168.56.116 kamailio3 192.168.56.117 route[HANDLE_REPLICATION] {i…