PyTorch基础05_模型的保存和加载

目录

一、模型定义组件——重构线性回归

二、模型的加载和保存

2、序列化保存对象和加载

 3、保存模型参数


一、模型定义组件——重构线性回归

回顾之前的手动构建线性回归案例:

1.构建数据集2.加载数据集(数据集转换为迭代器);3.参数初始化4.线性回归(模型函数,前向回归);5.损失函数(均分误差,反向传播对象);6.优化器(梯度更新);7.训练数据集8.预测数据

import math
import torch
import random
from sklearn.datasets import make_regressiondef build_data():"""构建数据集"""# 噪声noise = random.randint(1,5)# 样本数量sample = 1000# 目标y的真实偏置 biasbias = 0.5# coef:真实系数 coef=True 表示希望函数返回生成数据的真实系数x,y,coef = make_regression(n_samples=sample,n_features=4,bias=bias,noise=noise,coef=True,random_state=666)# 数据转换成张量x = torch.tensor(x,dtype=torch.float32)y = torch.tensor(y,dtype=torch.float32)coef = torch.tensor(coef,dtype=torch.float32)return x,y,coef,biasdef load_data(x,y):"""加载数据集将数据集转换为迭代器,以便在训练过程中进行批量处理。"""# 单批次数量batch_size = 16# 样本总数量n_samples = x.shape[0]# 一轮训练的次数n_batches = math.ceil(n_samples/batch_size)# 构建数据索引indices = list(range(n_samples)# 打乱索引random.shuffle(indices)# 从每批次中取出的数据for i in range(0,n_batches):start = i*batch_sizeend = min((i+1)*batch_size,n_samples)# 数据下标切片index = indices[start,end]# 返回数据return x[index],y[index]def initialize(n_feature):"""参数初始化随机初始化权重w, 并将偏置b初始化为1"""torch.manual_seed(66)# 权重 正态分布w = torch.randn(n_feature,required_grad=True,dtype=torch.float32)# 偏置b = torch.tensor(0.0,required_grad=True,dtype=torch.float32)return w,bdef regressor(x,w,b):"""线性回归模型函数 "前向传播""""return x@w + bdef MSE(y_pred,y_true):"""损失函数均分误差 反向传播的对象"""return torch.mean((y_pred-y_true)**2)def optim_step(w,b,dw,db,lr):"""优化器梯度更新 向梯度下降的方向更新"""# 修改的不是原tenser而是tensor的dataw.data -= lr*dw.datab.data -= lr*db.datadef train():"""训练数据集"""# 创建数据x,y,coef,bias = build_data()# 初始化参数w,b = initialize(x.shape[0])# 设置训练参数lr = 0.1 # 学习率epoch = 500 # 迭代次数# 训练数据# 迭代循环for i in range(epoch):total_loss = 0 # 误差总和count = 0 # 训练次数# 批次循环for batch_x,batch_y_true in load_data(x,y):count += 1# 代入线性回归得出预测值batch_y_pred = regressor(x,w,b)# 计算损失函数loss = MSE(batch_y_pred,btach_y_true)tatol_loss += loss# 梯度清零if w.grad is not None:w.data.zero_()if b.grad is not None:b.data.zero_()# 反向传播 计算梯度loss.backward()# 梯度更新 得出预测w和bw,b = optim_step(w,b,w.grad,b.grad,lr)# 打印数据print(f'epoch:{i},loss:{total_loss/count}')return w.data,b.data,coef,biasdef detect(x,w,b):"""预测数据"""return torch.matmul(x.type(torch.float32),w) + bif __name__ == "__main__":w,b,coef,bias = train()print(f'真实系数:{coef},真实偏置:{bias}')print(f'预测系数:{w},预测偏置:{b}')y_pred = detect(torch.tensor([[4,5,6,6],[7,8,8,9]]),w,b)print(f'y_pred:{y_pred}')

 这个手动实现的过程对深度学习的思维很有帮助,现在结合上一篇的官方数据加载器,我们将它重构:

import torch
from sklearn.datasets import make_regression
from torch.utils.data import DataLoader,TensorDatasetdef build_dataset():"""构建数据集"""noise = random.randint(1,5)bias = 14.5X,y,coef = make_regression(n_samples=1000,n_features=4,coef=True,bias=bias,noise=noise,random_state=66)X = torch.tensor(X,dtype=torch.float32)y = torch.tensor(y,dtype=torch.float32)return X,y,coef,biasdef train():"""训练数据集"""# 01 加载数据X,y,coef,bias = build_dataset()# 02 构建模型"""torch.nn.Linear(in_features,out_features)in_features 输入的特征数量——w数量out_features 输出的数量——y数量"""model = torch.nn.Linear(X.shape[1],1)# 03 初始化参数# 若不手动初始化则会自动初始化 这里选择自动初始化# 04 构建损失函数loss_fn = torch.nn.MSELoss() # 均方误差# 05 构建优化器sgd = torch.optim.SGD(model.parameter(),lr) # 传入模型参数和学习率# 06 训练epoch = 500# 06.1 循环次数for i in range(epoch):# 06.2 计算损失data_loader = DataLoader(data,batch_size=16,shuffle=True) # 按小批次划分并随机打乱total_loss = 0count = 0for x,y in data_loader:count += 1y_pred = model(x) # 模型预测的输出值loss = loss_fn(y_pred,y)total_loss += loss# 06.3 梯度清零sgd.zero_grad()# 06.4 反向传播loss.backward()# 06.5 更新参数sgd.step()print(f'epoch:{epoch},loss:{total_loss/count}') # 打印每一批次的结果# 07 保存模型参数print(f'weight:{model.weight},bias:{model.bias}')print(f'true_weight:{coef},true_bias:{bias}')if __name__ == '__main__':train()

可见得方便了许多。

 

二、模型的加载和保存

训练一个模型通常需要大量的数据、时间和计算资源。通过保存训练好的模型,可以满足后续的模型部署、模型更新、迁移学习、训练恢复等各种业务需要求。

1、标准网络模型构建

class MyModle(nn.Module):"""标准网络模型构建"""def __init__(self, input_size, output_size):super(MyModle, self).__init__()self.fc1 = nn.Linear(input_size, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, output_size)def forward(self, x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return output

2、序列化保存对象和加载

import torch
import torch.nn as nnclass MyModle(nn.Module):"""标准网络模型构建"""def __init__(self, input_size, output_size):super(MyModle, self).__init__()self.fc1 = nn.Linear(input_size, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, output_size)def forward(self, x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return outputdef train01():"""保存"""model = MyModle(10,5)# 序列化方式保存模型对象torch.save(model, "./data/model.pkl")def detect01():"""加载"""# 注意设备问题model = torch.load("./data/model.pkl", map_location="cpu")print(model)if __name__ == "__main__":test01()test02()

 3、保存模型参数

更常用的保存和加载方式,只需要保存权重、偏执、准确率等相关参数,都可以在加载后打印观察。

import torch
import torch.nn as nn
import torch.optim as optimclass MyModle(nn.Module):"""标准网络模型构建"""def __init__(self, input_size, output_size):super(MyModle, self).__init__()self.fc1 = nn.Linear(input_size, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, output_size)def forward(self, x):x = self.fc1(x)x = self.fc2(x)output = self.fc3(x)return outputdef train02():model = MyModle(input_size=128, output_size=32)optimizer = optim.SGD(model.parameters(), lr=0.01)# 自己构建要存储的模型参数save_dict = {"init_params": {"input_size": 128,  # 输入特征数"output_size": 32,  # 输出特征数},"accuracy": 0.99,  # 模型准确率"model_state_dict": model.state_dict(),"optimizer_state_dict": optimizer.state_dict(),}torch.save(save_dict, "model_dict.pth")def detect02():save_dict = torch.load("model_dict.pth")model = MyModle(input_size=save_dict["init_params"]["input_size"],output_size=save_dict["init_params"]["output_size"],)# 初始化模型参数model.load_state_dict(save_dict["model_state_dict"])optimizer = optim.SGD(model.parameters(), lr=0.01)# 初始化优化器参数optimizer.load_state_dict(save_dict["optimizer_state_dict"])# 打印模型信息print(save_dict["accuracy"])print(model)if __name__ == "__main__":train02()detect02()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/479349.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript核心语法(3)

前两篇文章大概把JavaScript的基础语法讲了一下,这篇文章主要讲讲ES6的核心语法。ES6的核心语法说实话其实有点多,我重点挑一些经常在项目中用到的来讲,其他一些我没怎么见过的就不讲了。 目录 1.变量和常量 变量(let 和 var&a…

爬虫开发(5)如何写一个CSDN热门榜爬虫小程序

笔者 綦枫Maple 的其他作品,欢迎点击查阅哦~: 📚Jmeter性能测试大全:Jmeter性能测试大全系列教程!持续更新中! 📚UI自动化测试系列: SeleniumJava自动化测试系列教程❤ &#x1f4da…

NIO三大组件

现在互联网环境下,分布式系统大相径庭,而分布式系统的根基在于网络编程,而netty恰恰是java领域的网络编程的王者,如果要致力于并发高性能的服务器程序、高性能的客户端程序,必须掌握netty网络编程。 NIO基础 NIO是从ja…

34 基于单片机的指纹打卡系统

目录 一、主要功能 二、硬件资源 三、程序编程 四、实现现象 一、主要功能 基于STC89C52RC,采用两个按键替代指纹,一个按键按下,LCD12864显示比对成功,则 采用ULN2003驱动步进电机转动,表示开门,另一个…

李宏毅机器学习课程知识点摘要(14-18集)

线性回归,逻辑回归(线性回归sigmoid),神经网络 linear regression , logistic regression , neutral network 里面的偏导的相量有几百万维,这就是neutral network的不同,他是…

文件上传upload-labs-docker通关

(图片加载不出,说明被和谐了) 项目一: sqlsec/ggctf-upload - Docker Image | Docker Hub 学习过程中,可以对照源码进行白盒分析. 补充:环境搭建在Linux虚拟机上的同时,以另一台Windows虚拟机进行测试最…

【Android】静态广播接收不到问题分析思路

参考资料: Android 静态广播注册流程(广播2)-CSDN博客 Android广播发送流程(广播3)_android 发送广播-CSDN博客 https://zhuanlan.zhihu.com/p/347227068 在Android中,静态广播如果静态广播不能接收,我们可以从整个流程中去分析&#xff…

2024 APMCM亚太数学建模C题 - 宠物行业及相关产业的发展分析和策略(详细解题思路)

在当下, 日益发展的时代,宠物的数量应该均为稳步上升,在美国出现了下降的趋势, 中国 2019-2020 年也下降,这部分变化可能与疫情相关。需要对该部分进行必要的解释说明。 问题 1: 基于附件 1 中的数据及您的团队收集的额…

Git简单介绍

一、 Git介绍与安装 1.1 Git简介 Git是一个开源的分布式版本控制系统,可以有效、高速地处理从很小到非常大的项目版本管理。 1.2集中式(SVN) VS 分布式(git) 集中式版本控制系统,版本库是集中存放在中央服务器的,工作时要先从中央…

CSS之3D转换

三维坐标系 三维坐标系其实就是指立体空间,立体空间是由3个轴共同组成的。 x轴:水平向右注意:x右边是正值,左边是负值 y轴:垂直向下注意:y下面是正值,上面是负值 z轴:垂直屏幕注意:往外面是正值,往里面是负值 3D移动 translat…

kafka生产者和消费者命令的使用

kafka-console-producer.sh 生产数据 # 发送信息 指定topic即可 kafka-console-producer.sh \ --bootstrap-server bigdata01:9092 \ --topic topicA # 主题# 进程 29124 ConsoleProducer kafka-console-consumer.sh 消费数据 # 消费数据 kafka-console-consumer.sh \ --boo…

基于Springboot的心灵治愈交流平台系统的设计与实现

基于Springboot的心灵治愈交流平台系统 介绍 基于Springboot的心灵治愈交流平台系统,后端框架使用Springboot和mybatis,前端框架使用Vuehrml,数据库使用mysql,使用B/S架构实现前台用户系统和后台管理员系统,和不同级别…

【人工智能】Python常用库-Scikit-learn常用方法教程

Scikit-learn 是一个功能强大的机器学习库,支持数据预处理、分类、回归、聚类、降维等功能,广泛用于模型开发与评估。以下是 Scikit-learn 的常用方法及详细说明。 1. 安装与导入 安装 Scikit-learn: pip install scikit-learn导入基本模块…

Tcon技术和Tconless技术介绍

文章目录 TCON技术(传统时序控制器)定义:主要功能:优点:缺点: TCONless技术(无独立时序控制器)定义:工作原理:优点:缺点: TCON与TCONl…

计算机基础(下)

内存管理 内存管理主要做了什么? 操作系统的内存管理非常重要,主要负责下面这些事情: 内存的分配与回收:对进程所需的内存进行分配和释放,malloc 函数:申请内存,free 函数:释放内存…

【青牛科技】TS223 单触摸键检测IC

概 述 : TS223是 触 摸 键 检 测 IC, 提 供 1个 触 摸 键 。 触 摸 检 测 IC是 为 了用 可 变 面 积 的 键 取 代 传 统 的 按 钮 键 而 设 计 的 。低 功 耗 和 宽 工 作 电压是 触 摸 键 的 DC和 AC特 点 。TS223采 用 SSOP16、 SOT23-6的 封 装 形 式…

CUDA补充笔记

文章目录 一、不同核函数前缀二、指定kernel要执行的线程数量三、线程需要两个内置坐标变量来唯一标识线程四、不是blocksize越大越好,上限一般是1024个blocksize 一、不同核函数前缀 二、指定kernel要执行的线程数量 总共需要线程数是: 1 * N N个线程…

“华为杯”研究生数学建模比赛历年赛题汇总(2004-2024)

文章目录 赛题链接历年赛题2004年赛题2005年赛题2006年赛题2007年赛题2008年赛题2009年赛题2010年赛题2011年赛题2012年赛题2013年赛题2014年赛题2015年赛题2016年赛题2017年赛题2018年赛题2019年赛题2020年赛题2020年赛题2021年赛题2022年赛题2023年赛题2024年赛题 赛题链接 部…

Python学习指南 + 谷歌浏览器如何安装插件

找往期文章包括但不限于本期文章中不懂的知识点: 个人主页:我要学编程(ಥ_ಥ)-CSDN博客 所属专栏: Python 目录 前言 Python 官方文档的使用 谷歌浏览器中如何安装插件 前言 在学习Python时,我们可能会出现这样的困惑&#x…

java写一个石头剪刀布小游戏

石头剪刀布是一款经典的手势游戏,通常由两人参与,玩法简单且充满趣味。玩家通过出示手势代表“石头”、“剪刀”或“布”,并根据规则比较手势决定胜负。它广泛用于休闲娱乐、决策或解压活动。 一、功能简介 用户与计算机对战。 用户输入选择:石头、剪刀或布。 计算机随机生…