Datawhale X 李宏毅苹果书 AI夏令营 Task2笔记

       Datawhale X 李宏毅苹果书 向李宏毅学深度学习(进阶) 是 Datawhale 2024 年 AI 夏令营第五期的学习活动(“深度学习 进阶”方向

       往期task1链接:深度学习进阶-Task1

       我做的task1的笔记博客:传送门

       Datawhale官方的task2链接:深度学习进阶-Task2

       Github-《深度学习详解》开源地址:传送门

《深度学习详解》主要内容源于《机器学习》(2021年春),选取了《机器学习》(2017年春) 的部分内容,在这些基础上进行了一定的原创,补充了不少除这门公开课之外的深度学习相关知识。为了尽可能地降低阅读门槛,笔者对这门公开课的精华内容进行选取并优化,对所涉及的公式都给出详细的推导过程,对较难理解的知识点进行了重点讲解和强化,以方便读者较为轻松地入门。

       在理论严谨的基础上,本书保留了公开课中大量生动有趣的例子,帮助读者从生活化的角度理解深度学习的概念、建模过程和核心算法细节,包括——

  • 卷积神经网络、Transformer、生成模型、自监督学习(包括 BERT 和 GPT)等深度学习常见算法,

  • 对抗攻击、领域自适应、强化学习、元学习、终身学习、网络压缩等深度学习相关进阶算法。


目录

1. 深度学习优化器的演变

2. AdaGrad

AdaGrad的优点:

AdaGrad的缺点:

3. RMSProp

RMSProp的优点:

RMSProp的缺点:

4. Adam

Adam的优点:

Adam的缺点:

5. 学习率调度

6. 分类问题的损失函数

7. 拓展-  RAdam

RAdam的优点:

RAdam的缺点:

8. 知识点总结

9.(实践任务):HW3(CNN)卷积神经网络-图像分类

       获取数据集和代码文件命令:

       训练模型代码

       运行结果

       十分钟跑通baseline视频(跑baseline过程中出现问题的可以对照着看看):


       在第五期的进阶方向的学习内容中,Task2在Task1的基础上继续叙述,主要学习自适应学习率、学习率调度、优化和分类问题的知识点,对应《深度学习详解》一书中的3.3&4&5及3.6的内容。

       在深度学习模型训练中,优化算法起着至关重要的作用。它决定了模型参数更新的方向和速度,进而影响模型的性能和泛化能力。本笔记将结合图文,深入浅出地解释深度学习中的优化概念和算法,包括局部极小值、鞍点、批量梯度下降、随机梯度下降、动量法以及自适应学习率。

1. 深度学习优化器的演变

       深度学习模型的训练过程本质上是一个优化问题,目标是最小化损失函数。传统的梯度下降法存在着一些局限性,例如:

  • 学习率固定: 无法适应不同参数和不同阶段的训练需求。
  • 梯度消失/爆炸: 对于深层网络,梯度在反向传播过程中会逐渐减小或增大,导致训练困难。
  • 陷入局部最优解: 梯度下降法容易陷入局部最优解,无法找到全局最优解。

 不同学习率对训练的影响

       最原始的梯度下降连简单的误差表面都做不好,因此需要更好的梯度下降的版本。在梯 度下降里面,所有的参数都是设同样的学习率,这显然是不够的,应该要为每一个参数定制 化学习率,即引入自适应学习率(adaptive learning rate)的方法,给每一个参数不同的学习率。

       为了克服这些局限性,研究者们提出了许多改进的优化器,其中最常用的包括AdaGrad、RMSProp和Adam。

2. AdaGrad

        AdaGrad是最早提出的自适应学习率优化器,其核心思想是根据参数的历史梯度信息动态调整学习率。具体来说,AdaGrad会对每个参数维护一个累加的平方梯度,并将其用于更新学习率。梯度较大的参数对应的学习率会逐渐减小,梯度较小的参数对应的学习率会逐渐增大。

def sgd_adagrad(parameters, sqrs, lr):eps = 1e-10for param, sqr in zip(parameters, sqrs):sqr[:] = sqr + param.grad.data ** 2div = lr / torch.sqrt(sqr + eps) * param.grad.dataparam.data = param.data - div
AdaGrad的优点
  • 避免了梯度消失/爆炸问题,更适合处理稀疏数据。
  • 无需手动调整学习率,可以自动适应不同参数和不同阶段的训练需求。
AdaGrad的缺点
  • 学习率逐渐减小,可能导致训练速度变慢,甚至停止。
  • 会导致参数更新步长越来越小,难以跳出局部最优解。

3. RMSProp

        RMSProp是AdaGrad的改进版本,它引入了超参数α来控制历史梯度的权重,使学习率更具动态性。具体来说,RMSProp会对每个参数维护一个指数衰减的平均平方梯度,并将其用于更新学习率。α值越小,历史梯度的影响越大;α值越大,历史梯度的影响越小。

RMSProp的优点
  • 解决了AdaGrad学习率过快衰减的问题,提高了训练速度。
  • 可以更好地处理非平稳目标函数。
RMSProp的缺点
  • 学习率调整不够平滑,可能导致训练过程不稳定。
  • 需要手动设置超参数α,选择不当会影响训练效果。

4. Adam

        Adam是近年来最常用的优化器之一,它结合了AdaGrad和RMSProp的优点,并引入了动量项,使参数更新更加平滑。具体来说,Adam会对每个参数维护两个状态:一个是指数衰减的平均梯度,用于更新学习率;另一个是指数衰减的平均梯度平方,用于更新动量。

Adam的优点
  • 具有自适应学习率和动量的特性,训练速度快,效果稳定。
  • 无需手动调整学习率,可以自动适应不同参数和不同阶段的训练需求。
  • 避免了梯度消失/爆炸问题,更适合处理深层网络。
Adam的缺点
  • 需要设置多个超参数,选择不当会影响训练效果。
  • 对于某些问题,Adam的效果可能不如专门的优化器。

5. 学习率调度

        学习率调度是指在训练过程中动态调整学习率,以提高训练速度和效果。常见的学习率调度方法包括:

  • 学习率退火: 随着训练次数的增加,逐渐减小学习率,使模型更加精细地调整参数。
  • 学习率预热: 训练初期先增大学习率,快速探索误差空间,然后逐渐减小学习率,进行精细调整。
  • 周期性调整: 将学习率设置为周期性变化的函数,例如余弦退火。

6. 分类问题的损失函数

分类问题常用的损失函数包括:

  • 均方误差: 计算预测值与真实值之间的平方差,适用于回归问题。
  • 交叉熵: 计算预测概率分布与真实概率分布之间的距离,更适合分类问题。
  • Hinge Loss: 计算预测值与真实标签的夹角,适用于支持向量机。
  • Log Loss: 计算预测概率与真实概率的对数差,适用于概率预测问题。

均方误差与交叉熵在分类问题上有什么不同呢?

均方误差与交叉熵的区别

  • 均方误差: 适用于回归问题,但不适合分类问题,因为它没有考虑到类别之间的差异。
  • 交叉熵: 适用于分类问题,因为它可以有效地衡量预测概率分布与真实概率分布之间的差异。

7. 拓展-  RAdam

       RAdam是Adam的改进版本,它引入了阶跃下降的概念,使学习率调整更加平滑。具体来说,RAdam会根据梯度变化情况,动态调整学习率的更新步长。当梯度变化较大时,增加学习率的更新步长;当梯度变化较小时,减小学习率的更新步长。

RAdam对不同的学习率具有鲁棒性,同时仍能快速收敛并获得更高的精度(CIFAR数据集)

       正如你所看到的,RAdam提供了一个动态启发式方法来提供自动化的方差衰减,从而消除了在训练期间热身所涉及手动调优的需要。此外,RAdam对学习速率变化(最重要的超参数)具有更强的鲁棒性,并在各种数据集和各种AI体系结构中提供更好的训练精度和泛化。

       PyTorch的官方github提供了RAdam的实现:https://github.com/LiyuanLucasLiu/RAdam。

RAdam的优点
  • 解决了Adam在训练初期学习率过小的问题,提高了训练速度。
  • 可以更好地处理非平稳目标函数。
RAdam的缺点
  • 需要设置额外的超参数,选择不当会影响训练效果。
  • 对于某些问题,RAdam的效果可能不如Adam。

8. 知识点总结

        选择合适的优化器和学习率调度方法对深度学习模型的训练至关重要。AdaGrad、RMSProp和Adam各有优缺点,需要根据具体问题选择。学习率退火和预热可以有效提高训练速度和效果。RAdam作为Adam的改进版本,也值得尝试。

9.(实践任务):HW3(CNN)卷积神经网络-图像分类

       Homework3的内容是通过利用卷积神经网络架构,通过一个较小的10种食物的图像的数据集训练一个模型完成图像分类的任务。       

       获取数据集和代码文件命令:
git clone https://www.modelscope.cn/datasets/Datawhale/LeeDL-HW3-CNN.git
       训练模型代码
  1. 初始化追踪器:stale 和 best_acc 用于追踪训练过程中的损失和准确率。stale 表示连续没有改进的轮数,当 stale 大于设定的阈值 patience 时,提前停止训练。

  2. 训练阶段:在训练阶段,首先确保模型处于训练模式,然后遍历训练数据加载器 train_loader 中的每个批次。对于每个批次,将图像数据 imgs 和对应的标签 labels 传递给模型,计算输出 logits。然后计算交叉熵损失 loss,并清除上一步中参数中存储的梯度。计算参数的梯度,并进行梯度裁剪以稳定训练。最后更新模型参数。

  3. 验证阶段:在验证阶段,首先确保模型处于评估模式,然后遍历验证数据加载器 valid_loader 中的每个批次。对于每个批次,将图像数据 imgs 和对应的标签 labels 传递给模型,计算输出 logits。计算损失 loss 和准确率 acc

  4. 打印训练和验证信息:在训练和验证阶段,打印当前轮次的损失和准确率。在验证阶段,如果当前轮次的准确率高于最佳准确率 best_acc,则更新 best_acc 和保存模型。

  5. 保存模型:在训练过程中,如果找到更好的模型,则保存模型参数。

 初始化追踪器,这些不是参数,不应该被更改
stale = 0
best_acc = 0for epoch in range(n_epochs):# ---------- 训练阶段 ----------# 确保模型处于训练模式model.train()# 这些用于记录训练过程中的信息train_loss = []train_accs = []for batch in tqdm(train_loader):# 每个批次包含图像数据及其对应的标签imgs, labels = batch# imgs = imgs.half()# print(imgs.shape,labels.shape)# 前向传播数据。(确保数据和模型位于同一设备上)logits = model(imgs.to(device))# 计算交叉熵损失。# 在计算交叉熵之前不需要应用softmax,因为它会自动完成。loss = criterion(logits, labels.to(device))# 清除上一步中参数中存储的梯度optimizer.zero_grad()# 计算参数的梯度loss.backward()# 为了稳定训练,限制梯度范数grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)# 使用计算出的梯度更新参数optimizer.step()# 计算当前批次的准确率acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()# 记录损失和准确率train_loss.append(loss.item())train_accs.append(acc)train_loss = sum(train_loss) / len(train_loss)train_acc = sum(train_accs) / len(train_accs)# 打印信息print(f"[ 训练 | {epoch + 1:03d}/{n_epochs:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")# ---------- 验证阶段 ----------# 确保模型处于评估模式,以便某些模块如dropout能够正常工作model.eval()# 这些用于记录验证过程中的信息valid_loss = []valid_accs = []# 按批次迭代验证集for batch in tqdm(valid_loader):# 每个批次包含图像数据及其对应的标签imgs, labels = batch# imgs = imgs.half()# 我们在验证阶段不需要梯度。# 使用 torch.no_grad() 加速前向传播过程。with torch.no_grad():logits = model(imgs.to(device))# 我们仍然可以计算损失(但不计算梯度)。loss = criterion(logits, labels.to(device))# 计算当前批次的准确率acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()# 记录损失和准确率valid_loss.append(loss.item())valid_accs.append(acc)# break# 整个验证集的平均损失和准确率是所记录值的平均valid_loss = sum(valid_loss) / len(valid_loss)valid_acc = sum(valid_accs) / len(valid_accs)# 打印信息print(f"[ 验证 | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")# 更新日志if valid_acc > best_acc:with open(f"./{_exp_name}_log.txt", "a"):print(f"[ 验证 | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f} -> 最佳")else:with open(f"./{_exp_name}_log.txt", "a"):print(f"[ 验证 | {epoch + 1:03d}/{n_epochs:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")# 保存模型if valid_acc > best_acc:print(f"在第 {epoch} 轮找到最佳模型,正在保存模型")torch.save(model.state_dict(), f"{_exp_name}_best.ckpt")  # 只保存最佳模型以防止输出内存超出错误best_acc = valid_accstale = 0else:stale += 1if stale > patience:print(f"连续 {patience} 轮没有改进,提前停止")break
       运行结果

       简单的 baseline 不过多赘述,以下是运行结果:

         十分钟跑通baseline视频(跑baseline过程中出现问题的可以对照着看看):

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/411599.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】宏定义详解

目录 C语言宏定义详解1. 宏定义关键词总览2. #define3. #undef4. #ifdef5. #ifndef6. #if7. #else8. #elif9. #endif10. #include11. #error12. #pragma12.1 #pragma once12.2 #pragma pack12.3 #pragma warning12.4 #pragma GCC 13. #line14. 字符串化和标识符连接14.1 字符串…

C# 对桌面快捷方式的操作设置开机启动项

首先在项目中引入Windows Script Host Object Model,引入方式如下图。 对于桌面快捷方式的修改无非就是将现有的快捷方式修改和添加新的快捷方式。 1、遍历桌面快捷方式,代码如下。 string desktopPath Environment.GetFolderPath(Environment.Special…

LLM 应用开发入门 - 实现 langchain.js ChatModel 接入火山引擎大模型和实现一个 CLI 聊天机器人(上)

前言 Langchain 是一个大语言模型(LLM)应用开发的框架,提供了 LLM 开发中各个阶段很多非常强大的辅助工具支持。对于进行 LLM 开发是必不可少的工具库。 本文将通过一个实际的开发例子来入门 LLM 开发基础工具链,并实现 langchain.js ChatModel 接入火山引擎大模型和基于…

【亲测有效】linux抓包http协议分析,分析header和body

linux抓包http协议分析,分析header和body 安装: 执行抓包命令,这里ip要换成你想抓包的目标ip: ngrep -q -W byline -d any "^Host:|^GET|^POST|^HTTP/" tcp and host 183.2.172.42 and port 80 触发抓包,…

FPGA实现多功能SDI视频采集卡,基于GTX+RIFFA架构,提供2套工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐我已有的PCIE方案本博已有的 SDI 编解码方案 3、详细设计方案设计框图SDI 输入设备Gv8601a 均衡器GTX 解串与串化SMPTE SD/HD/3G SDI IP核BT1120转RGBFDMA图像缓存RIFFA用户数据控制RIFFA架构详解Xilinx 7 Series Integrated Bloc…

94522

springboot 广州应用科技学院的教室管理系统 摘 要 科技进步的飞速发展引起人们日常生活的巨大变化,电子信息技术的飞速发展使得电子信息技术的各个领域的应用水平得到普及和应用。信息时代的到来已成为不可阻挡的时尚潮流,人类发展的历史正进入一个新时…

详细分析Ubuntu中的ufw基本知识

目录 前言1. 基本知识2. 基本使用 前言 由于命令行比较简单,此处主要以表格的形式呈现,还有实战中遇到的一个注意点 1. 基本知识 Ubuntu 中一种用户友好的防火墙配置工具,简化 iptables 的使用,适合那些不熟悉复杂防火墙配置的…

页面内容---复制粘贴【收藏版】【H5 web端亲测有效】

js中的复制粘贴 . 页面内容—复制粘贴【收藏版】【H5 web端亲测有效】 navigator.clipboard.writeText(copyText) 是 Web API 中的一个方法,用于将指定的文本内容复制到用户的剪贴板。这个方法属于 Clipboard API,它使得网页能够读取和写入剪贴板的内容…

开放式耳机哪种好用又实用?优质开放式耳机种草测评

在开放式耳机领域,目前有几款表现尤为突出的产品。作为一名专业的音乐制作人和评测专家,我深知一款出色的耳机对于音乐创作和鉴赏的重要性。 最近,我亲自评测了市面上一些颇受欢迎的开放式耳机,发现它们不仅在音质上有着令人满意…

winXP下构建python开发环境

近期车间有个动平衡检测仪数采的需求,工控机是xp系统,原理也很简单,监控文件变化,发现有新的检测数据就调用远程接口传输到服务器上去。 通常python监控文件变化会用watchdog这个库, 可是xp太老了,测试了一…

Jenkins服务安装配置

目录 Jenkins 配置环境 配置 中文插件 配置 Maven 插件 配置 JDK 配置 Git 配置 SSH 远程服务器 Jenkins 配置项目 构建 maven 项目 构建 pipeline 流水线项目 什么是 Jenkins Jenkins 是一个开源的自动化服务器,主要用于持续集成(CI&#xff…

代码随想录刷题day15丨110.平衡二叉树,257. 二叉树的所有路径, 404.左叶子之和 ,222.完全二叉树的节点个数

代码随想录刷题day15丨110.平衡二叉树,257. 二叉树的所有路径, 404.左叶子之和 ,222.完全二叉树的节点个数 1.题目 1.1平衡二叉树(优先掌握递归) 题目链接:110. 平衡二叉树 - 力扣(LeetCode&a…

以小搏大:Salesforce 十亿参数模型表现超过ChatGPT

小模型的强势崛:轻量化AI如何以高效表现撼动大型模型的统治! ©作者|DWT 来源|神州问学 导读 近年来,人工智能领域的迅猛发展使得大型语言模型(LLM)成为了焦点。这些模型,如OpenAI的GPT-4和Google的…

讲透一个强大的算法模型,Transformer

Transformer 模型是一种基于注意力机制的深度学习模型,广泛应用于自然语言处理(NLP)任务,如机器翻译、文本生成和语义理解。 它最初由 Vaswani 等人在2017年的论文《Attention is All You Need》中提出。它突破了传统序列模型&am…

CSRF 概念及防护机制

概述 CSRF(Cross-Site Request Forgery),即跨站请求伪造,是一种网络攻击方式。在这种攻击中,恶意用户诱导受害者在不知情的情况下执行某些操作,通常是利用受害者已经登录的身份,向受害者信任的…

微纳芯:如何利用CRM实现渠道分销管理的数字化转型

MINCHIP由联想控股投资,是一家专注于快速体外诊断产品的研发、生产、销售、服务的高科技企业,拥有多项自主知识产权及技术专利。致力于用专业的微流控临床检验产品,为全球大众提供触手可及、负担得起的健康服务。其系列全自动生化分析仪持续为医师、兽医师的机构运营提供解决方…

C++对C的扩充(8.28)

1.使用C手动封装一个顺序表&#xff0c;包括成员数组1个&#xff0c;成员变量n个 代码&#xff1a; #include <iostream>using namespace std;//类型重命名 using datatype int; #define MAX 30struct seqList { private: //私有权限datatype *data; //相当于 …

Java中的java.lang.ArithmeticException: null问题详解与解决方案

个人名片 &#x1f393;作者简介&#xff1a;java领域优质创作者 &#x1f310;个人主页&#xff1a;码农阿豪 &#x1f4de;工作室&#xff1a;新空间代码工作室&#xff08;提供各种软件服务&#xff09; &#x1f48c;个人邮箱&#xff1a;[2435024119qq.com] &#x1f4f1…

fixed、absolute 和 relative 布局

https://andi.cn/page/621716.html

0.0 C语言被我遗忘的知识点

文章目录 位移运算(>>和<<)函数指针函数指针的应用场景 strcmp的返回值合法的c语言实数表示sizeof 数组字符串的储存 —— 字符数组与字符指针字符串可能缺少 \0 的情况 用二维数组储存字符串数组其他储存字符串数组的方法 位移运算(>>和<<) 右移(>…