李沐深度学习记录4:12.权重衰减/L2正则化

权重衰减从零开始实现

#高维线性回归
%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l#整个流程是,1.生成标准数据集,包括训练数据和测试数据
#          2.定义线性模型训练
#           模型初始化(函数)、包含惩罚项的损失(函数)
#           定义epochs进行训练,每训练5轮评估一次模型在训练集和测试集的损失,画图显示
#           训练结束后分别查看并比较是否添加范数惩罚项损失对应的训练结果w的L2范数
#生成数据集
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5  #训练数据样本数20,测试样本数100,数据维度200,批量大小5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05  #生成w矩阵(200,1),w值0.01,偏置b为0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train) #生成训练数据集X(20,200),y(20,1),y=Xw+b+噪声,train_data接收返回的X,y
train_iter = d2l.load_array(train_data, batch_size)  #传入数据集和批量大小,构造训练数据迭代器
test_data = d2l.synthetic_data(true_w, true_b, n_test) #生成测试数据集
test_iter = d2l.load_array(test_data, batch_size, is_train=False)  #构造测试数据迭代器#初始化模型参数
def init_params():w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)b = torch.zeros(1, requires_grad=True)return [w, b]#定义L2范数惩罚项
def l2_penalty(w):return torch.sum(w.pow(2)) / 2  #L2范数公式需要开平方根,但这里L2范数惩罚项是L2范数的平方,所以不需要开平方根了#训练代码
def train(lambd):  #输入λ超参数w, b = init_params()  #初始化模型参数net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss  #net线性模型torch.matmul(X, w) + b;loss是均方误差num_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):  #进行多次迭代训练for X, y in train_iter:  #每个epoch,取训练数据# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)  #loss计算加上了λ×范数惩罚项l.sum().backward()  #这里计算损失和,下面参数更新时会对梯度求平均再更新参数d2l.sgd([w, b], lr, batch_size)  #进行参数更新操作if (epoch + 1) % 5 == 0:  #每5次epoch训练,评估一次模型的训练损失和测试损失animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())  #训练结束后,计算w的L2范数(没有平方)
#λ为0,无正则化项,训练
train(lambd=0)
d2l.plt.show()

在这里插入图片描述

#λ为10,有正则化项,训练
train(lambd=5)
d2l.plt.show()

在这里插入图片描述

权重衰减的简洁实现

#权重衰减的简洁实现
def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))   #定义模型for param in net.parameters():   #初始化参数param.data.normal_()loss = nn.MSELoss(reduction='none')  #计算loss,这里不包含正则项num_epochs, lr = 100, 0.003# 偏置参数没有衰减#在参数优化部分,计算梯度时加入了权重衰减#所以是计算loss时没计算正则项,只是在计算梯度时加入了权重衰减吗?trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):   #训练100轮for X, y in train_iter:  #对于每轮,取数据训练trainer.zero_grad()   #梯度清零l = loss(net(X), y)  #计算lossl.mean().backward() #反向传播trainer.step()  #更新梯度if (epoch + 1) % 5 == 0:   #每5轮评估一次模型在测试集和训练集的损失animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
#没有进行权重衰减
train_concise(0)

在这里插入图片描述

#进行权重衰减
train_concise(5)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/148717.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Solidity 合约漏洞,价值 38BNB 漏洞分析

Solidity 合约漏洞,价值 38BNB 漏洞分析 1. 漏洞简介 https://twitter.com/NumenAlert/status/1626447469361102850 https://twitter.com/bbbb/status/1626392605264351235 2. 相关地址或交易 攻击交易: https://bscscan.com/tx/0x146586f05a451313…

洛谷题目题解详细解答

洛谷是一个很不错的刷题软件,可是找不到合适的题解是个大麻烦,大家有啥可以私信问我,以下是我已经通过的题目。 你如果有哪一题不会(最好是我通过过的,我没过的也没关系),可以私信我&#xff0…

Ubuntu1804 安装后无法使用root登录解决方法

1. 给root用户设置密码 sudo passwd root2. 确认是否安装ssh服务 (在安装Ubuntu 的时候可以勾选安装ssh 远程服务),没有安装的话执行以下命令(Ubuntu可以连接互联网) sudo apt-get instll openssh-server3. 设置允许root 用户进行远程连接 sudo vim /etc/ssh/sshd_config 在…

1803_ChibiOS网络书籍阅读_嵌入式RTOS介绍

全部学习汇总: GreyZhang/g_ChibiOS: I found a new RTOS called ChibiOS and it seems interesting! (github.com) 1. RTOS指的是实时性操作系统,但是并不是只有嵌入式领域使用RTOS。然而,嵌入式是RTOS的主要使用领域。 2. 一般的RTOS有一组…

stm32 - GPIO

stm32 - GPIO GPIO结构图GPIO原理图输入上拉/下拉/浮空施密特触发器片上外设 输出推挽/开漏/关闭输出方式 GPIO88种模式复用输出 GPIO寄存器端口配置寄存器_CRL端口输入数据寄存器_IDR端口输出数据寄存器_ODR端口位设置/清除寄存器_BSRR端口位清除寄存器_BRR端口配置锁定寄存器…

不容易解的题10.5

31.下一个排列 31. 下一个排列 - 力扣(LeetCode)https://leetcode.cn/problems/next-permutation/?envTypelist&envIdZCa7r67M会做就不算难题,如果没做过不知道思路,这道题将会变得很难。 这道题相当于模拟cpp的next_permu…

设计加速!11个Adobe XD插件推荐!

你是否一直在寻找可以提升 Adobe XD 工作流程和体验的方法?如果是,一定要试试这些 Adobe XD 插件!本文将介绍 11 款好用的 Adobe XD 插件,这些插件可以为 UI/UX 设计添加很酷的新功能,极大提升你的工作效率和产出。让我…

基于STM32 ZigBee无线远程火灾报警监控系统物联网温度烟雾

实践制作DIY- GC00168---ZigBee无线远程监控系统 一、功能说明: 基于STM32单片机设计---ZigBee无线远程监控系统 二、功能说明: 1个主机:STM32F103C系列单片机LCD1602显示器蜂鸣器 ZigBee无线模块3个按键(设置、加、减&#xff0…

行与走,放慢自己,思考回顾。

国庆一定要出去走走!!! 为什么要出去行与走? 1、出去行与走看到祖国的大美风景,可以更深刻的认识到我们祖国的美好。 2、可以放空心情,排除掉积攒在写字楼内的方格子里面的郁闷和烦恼。 3、可以为自己的…

阿里云服务器地域和可用区查询表_地域可用区大全

阿里云服务器地域和可用区有哪些?阿里云服务器地域节点遍布全球29个地域、88个可用区,包括中国大陆、中国香港、日本、美国、新加坡、孟买、泰国、首尔、迪拜等地域,同一个地域下有多个可用区可以选择,阿里云服务器网分享2023新版…

Vscode爆红Delete `␍`eslintprettier/prettier

一、先看报错 文件中爆红,提示 Delete ␍eslintprettier/prettier 二、解决方案 项目根目录下,.prettierrc.js 文件中: endOfLine: auto,三、重启VsCode 此时不在爆红,问题完美解决

windows11 安装Nodejs

一、介绍 NPM 全称 Node Package Manager,它是 JavaScript 的包管理工具, 并且是 Node.js 平台的默认包管理工具。通 过 NPM 可以安装、共享、分发代码,管理项目依赖关系。 可从NPM服务器下载别人编写的第三方包到本地使用。可从NPM服务器下载并安装别人编写的命令…

走进Spring的世界 —— Spring底层核心原理解析(一)

文章目录 前言一、Spring中是如何创建一个对象二、Bean的创建过程三、推断构造方法四、AOP大致流程五、Spring事务 前言 ClassPathXmlApplicationContext context new ClassPathXmlApplicationContext("config.xml"); UserService userService (UserService) cont…

Cannot resolve MVC view ‘xxx‘

这是在springboot下通过controller访问templates目录下的静态文件&#xff08;Hello.html)报的错误 原因&#xff1a;缺少thymeleaf依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-thymeleaf</ar…

SSM - Springboot - MyBatis-Plus 全栈体系(十八)

第四章 SpringMVC SpringMVC 实战&#xff1a;构建高效表述层框架 一、SpringMVC 简介和体验 1. 介绍 Spring Web MVC 是基于 Servlet API 构建的原始 Web 框架&#xff0c;从一开始就包含在 Spring Framework 中。正式名称“Spring Web MVC”来自其源模块的名称&#xff08…

【计算机组成原理】考研真题攻克与重点知识点剖析 - 第 1 篇:计算机系统概述

前言 本文基础知识部分来自于b站&#xff1a;分享笔记的好人儿的思维导图与王道考研课程&#xff0c;感谢大佬的开源精神&#xff0c;习题来自老师划的重点以及考研真题。此前我尝试了完全使用Python或是结合大语言模型对考研真题进行数据清洗与可视化分析&#xff0c;本人技术…

剑指offer——JZ35 复杂链表的复制 解题思路与具体代码【C++】

一、题目描述与要求 复杂链表的复制_牛客题霸_牛客网 (nowcoder.com) 题目描述 输入一个复杂链表&#xff08;每个节点中有节点值&#xff0c;以及两个指针&#xff0c;一个指向下一个节点&#xff0c;另一个特殊指针random指向一个随机节点&#xff09;&#xff0c;请对此链…

QT商业播放器

QT商业播放器 总体架构图 架构优点&#xff1a;解耦&#xff0c;采用生产者消费者设计模式&#xff0c;各个线程各司其职&#xff0c;通过消息队列高效协作 这个项目是一个基于ijkplayer和ffplayer.c的QT商业播放器, 项目有5部分构成&#xff1a; 前端QT用户界面 后端是集成了…

成都建筑模板批发市场在哪?

成都作为中国西南地区的重要城市&#xff0c;建筑业蓬勃发展&#xff0c;建筑模板作为建筑施工的重要材料之一&#xff0c;在成都也有着广泛的需求。如果您正在寻找成都的建筑模板批发市场&#xff0c;广西贵港市能强优品木业有限公司是一家值得关注的供应商。广西贵港市能强优…

数组(数据结构)

优质博文&#xff1a;IT-BLOG-CN 一、简介 数组Array是一种线性表数据结构&#xff0c;它用一组连续的内存空间&#xff0c;存储一组具有相同类型的数据。 数组因具有连续的内存空间的特点&#xff0c;数据拥有非常高效率的“随机访问”&#xff0c;时间复杂度为O(1)。但因要保…