Chapter5.4 Loading and saving model weights in PyTorch

5 Pretraining on Unlabeled Data

5.4 Loading and saving model weights in PyTorch

  • 训练LLM的计算成本很高,因此能够保存和加载LLM的权重至关重要。

  • 在PyTorch中,推荐的方式是通过将torch.save函数应用于.state_dict()方法来保存模型权重,即所谓的state_dict

    torch.save(model.state_dict(),"model.pth")
    

    我们可以将模型权重加载到新的 GPTModel 模型实例中

    model = GPTModel(GPT_CONFIG_124M)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.eval();
    
  • 自适应优化器(如 AdamW)为每个模型权重存储额外的参数。AdamW 使用历史数据动态调整每个模型参数的学习率。如果没有这些参数,优化器会重置,模型可能会学习效果不佳,甚至无法正确收敛,这意味着模型将失去生成连贯文本的能力。使用 torch.save,我们可以保存模型和优化器的 state_dict 内容,如下所示

    torch.save({"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),}, "model_and_optimizer.pth"
    )
    

    然后,我们可以通过以下方式恢复模型和优化器状态:首先通过 torch.load 加载保存的数据,然后使用 load_state_dict 方法:

    checkpoint = torch.load("model_and_optimizer.pth", weights_only=True)model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint["model_state_dict"])optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    model.train();
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/4188.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

运动相机拍视频过程中摔了,导致录视频打不开怎么办

3-11 在使用运动相机拍摄激烈运动的时候,极大的震动会有一定概率使得保存在存储卡中的视频出现打不开的情况,原因是存储卡和相机在极端情况下,可能会出现接触不良的问题,如果遇到这种问题,就不得不进行视频修复了。 本…

Python制作简易PDF查看工具PDFViewerV1.0

PDFViewer PDF浏览工具,Python自制PDF查看工具,可实现基本翻页浏览功能,其它功能在进一步开发完善当中,如果有想一起开发的朋友,可以留言。本软件完全免费,自由使用。 软件界面简洁,有菜单栏、…

SpringBoot实现定时任务,使用自带的定时任务以及调度框架quartz的配置使用

SpringBoot实现定时任务,使用自带的定时任务以及调度框架quartz的配置使用 文章目录 SpringBoot实现定时任务,使用自带的定时任务以及调度框架quartz的配置使用一. 使用SpringBoot自带的定时任务(适用于小型应用)二. 使用调度框架…

Output

AUTOSAR OS模块详解(三) Alarm 本文主要介绍AUTOSAR OS的Alarm,并对基于英飞凌Aurix TC3XX系列芯片的Vector Microsar代码和配置进行部分讲解。 文章目录 AUTOSAR OS模块详解(三) Alarm1 简介2 功能介绍2.1 触发原理2.2 工作类型2.3 Alarm启动方式2.4 Alarm配置2.5…

openharmony应用开发快速入门

开发准备 本文档适用于OpenHarmony应用开发的初学者。通过构建一个简单的具有页面跳转/返回功能的应用(如下图所示),快速了解工程目录的主要文件,熟悉OpenHarmony应用开发流程。 在开始之前,您需要了解有关OpenHarmon…

使用傅里叶变换进行图像边缘检测

使用傅里叶变换进行图像边缘检测 今天我们介绍通过傅里叶变换求得图像的边缘 什么是傅立叶变换? 简单来说,傅里叶变换是将输入的信号分解成指定样式的构造块。例如,首先通过叠加具有不同频率的两个或更多个正弦函数而生成信号f(x…

用户中心项目教程(四)---Vue脚手架完成前端初始化

目录 1.项目的创建 2.使用开发工具打开 3.项目运行方法 4.使用按钮组件 5.全局注册 6.如何进行组件的测试 7.使用组件的效果展示 8.关于这个vue项目内容的说明 1.项目的创建 这个前提你是你完成了我的教程(三)里面的相关配置,不然你可…

《自动驾驶与机器人中的SLAM技术》ch4:基于预积分和图优化的 GINS

前言:预积分图优化的结构 1 预积分的图优化顶点 这里使用 《自动驾驶与机器人中的SLAM技术》ch4:预积分学 中提到的散装的形式来实现预积分的顶点部分,所以每个状态被分为位姿()、速度、陀螺零偏、加计零偏四种顶点&am…

二叉搜索树(TreeMapTreeSet)

文章目录 1.概念2.二叉搜索树的底层代码实现(1)首先构建二叉树(2)实现插入功能;(3)实现查找(4)删除(重点) 3.TreeMap 1.概念 TreeMap&TreeSet都是有序的集合都是基于二叉搜索树来实现的 二叉搜索树:是一种特殊的二叉树 若左子…

【QT用户登录与界面跳转】

【QT用户登录与界面跳转】 1.前言2. 项目设置3.设计登录界面3.1 login.pro参数3.2 界面设置3.2.1 登录界面3.2.2 串口主界面 4. 实现登录逻辑5.串口界面6.测试功能7.总结 1.前言 在Qt应用程序开发中,实现用户登录及界面跳转功能是构建交互式应用的重要步骤之一。下…

基于springboot的口腔管理平台

作者:学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等 文末获取“源码数据库万字文档PPT”,支持远程部署调试、运行安装。 项目包含: 完整源码数据库功能演示视频万字文档PPT 项目编码&#xff1…

4 AXI USER IP

前言 使用AXI Interface封装IP,并使用AXI Interface实现对IP内部寄存器进行读写实现控制LED的demo,这个demo是非常必要的,因为在前面的笔记中基本都需哟PS端与PL端就行通信互相交互,在PL端可以通过中断的形式来告知PS端一些事情&…

实力认证 | 海云安入选《信创安全产品及服务购买决策参考》

近日,国内知名安全调研机构GoUpSec发布了2024年中国网络安全行业《信创安全产品及服务购买决策参考》,报告从产品特点、产品优势、成功案例、安全策略等维度对各厂商信创安全产品及服务进行调研了解。 海云安凭借AI大模型技术在信创安全领域中的创新应用…

二、点灯基础实验

嵌入式基础实验第一个就是点灯,地位相当于编程界的hello world。 如下为LED原理图,要让相应LED发光,需要给I/O口设置输出引脚,低电平,二极管才会导通 2.1 打开初始工程,编写代码 以下会实现BLINKY常亮&…

Amazon MSK 开启 Public 访问 SASL 配置的方法

1. 开启 MSK Public 1.1 配置 MSK 参数 进入 MSK 控制台页面,点击左侧菜单 Cluster configuration。选择已有配置,或者创建新配置。在配置中添加参数 allow.everyone.if.no.acl.foundfalse修改集群配置,选择到新添加的配置。 1.2 开启 Pu…

大模型UI:Gradio全解11——Chatbot:融合大模型的聊天机器人(4)

大模型UI:Gradio全解11——Chatbot:融合大模型的聊天机器人(4) 前言本篇摘要11. Chatbot:融合大模型的多模态聊天机器人11.4 使用Blocks创建自定义聊天机器人11.4.1 简单聊天机器人演示11.4.2 立即响应和流式传输11.4.…

流量分析复现(第十八届信息安全大赛 第二届长城杯 )

zeroshell_1 题目:从数据包中找出攻击者利用漏洞开展攻击的会话(攻击者执行了一条命令),写出该会话中设置的flag, 结果提交形式:flag{xxxxxxxxx} 这里大致的思路还是先看看,流量协议的分级 主要还是以TCP流…

ImportError: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.32‘ not found

问题描述:安装MMYOLO或者MMROTATE时,出现的问题: (base) rootautodl-container-78fc438fda-4132d99a:~/autodl-tmp/MMROTATE_PROJECT/mmrotate-1.x# python demo/image_demo.py demo/demo.jpg oriented-rcnn-le90_r50_fpn_1x_dota.py orient…

2024年博客之星年度评选—创作影响力评审入围名单公布

2024年博客之星活动地址https://www.csdn.net/blogstar2024 TOP 300 榜单排名 用户昵称博客主页 身份 认证 评分 原创 博文 评分 平均 质量分评分 互动数据评分 总分排名三掌柜666三掌柜666-CSDN博客1001002001005001wkd_007wkd_007-CSDN博客1001002001005002栗筝ihttps:/…

25/1/15 嵌入式笔记 初学STM32F108

GPIO初始化函数 GPIO_Ini:初始化GPIO引脚的模式,速度和引脚号 GPIO_Init(GPIOA, &GPIO_InitStruct); // 初始化GPIOA的引脚0 GPIO输出控制函数 GPIO_SetBits:将指定的GPIO引脚设置为高电平 GPIO_SetBits(GPIOA, GPIO_Pin_0); // 将GPIO…