PyTorch深度学习模型训练流程的python实现:回归

回归的流程与分类基本一致,只需要把评估指标改动一下就行。回归输出的是损失曲线、R^2曲线、训练集预测值与真实值折线图、测试集预测值散点图与真实值折线图。输出效果如下:

 注意:预测值与真实值图像处理为按真实值排序,图中呈现的升序与数据集趋势无关。

代码如下:

from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_scoreimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdomfrom typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizerdef regress(data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],model: nn.Module,optimizer: Optimizer,criterion: nn.Module,scaler: Optional[TransformerMixin] = None,batch_size: int = 64,epochs: int = 10,device: Optional[torch.device] = None
) -> nn.Module:"""回归任务的训练函数。:param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型:param model: 回归模型:param optimizer: 优化器:param criterion: 损失函数:param scaler: 数据标准化器:param batch_size: 批大小:param epochs: 训练轮数:param device: 训练设备:return: 训练好的回归模型"""if isinstance(data[0], np.ndarray):X, y = data# 分离训练集和测试集,指定随机种子以便复现X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化if scaler is not None:X_train = scaler.fit_transform(X_train)X_test = scaler.transform(X_test)# 转换为tensorX_train = torch.from_numpy(X_train.astype(np.float32))X_test = torch.from_numpy(X_test.astype(np.float32))y_train = torch.from_numpy(y_train.astype(np.float32))y_test = torch.from_numpy(y_test.astype(np.float32))# 将X和y封装成TensorDatasettrain_dataset = TensorDataset(X_train, y_train)test_dataset = TensorDataset(X_test, y_test)elif isinstance(data[0], Dataset):train_dataset, test_dataset = dataelse:raise ValueError('Unsupported data type')train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)model.to(device)vis = Visdom()# 训练模型for epoch in range(epochs):for step, (batch_x_train, batch_y_train) in enumerate(train_loader):batch_x_train = batch_x_train.to(device)batch_y_train = batch_y_train.to(device)# 前向传播output = model(batch_x_train)loss = criterion(output, batch_y_train)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()niter = epoch * len(train_loader) + step + 1  # 计算迭代次数if niter % 100 == 0:# 评估模型model.eval()with torch.no_grad():eval_dict = {'test_loss': [],'test_r2': [],'y_test': [],'y_pred': [],}for batch_x_test, batch_y_test in test_loader:batch_x_test = batch_x_test.to(device)batch_y_test = batch_y_test.to(device)test_output = model(batch_x_test)test_predicted_tuple = (batch_y_test.numpy(), test_output.numpy())# 计算并记录损失、R^2、真实值、预测值eval_dict['test_loss'].append(criterion(test_output, batch_y_test))eval_dict['test_r2'].append(r2_score(*test_predicted_tuple))eval_dict['y_test'].append(batch_y_test)eval_dict['y_pred'].append(test_output)# 画出损失曲线vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),win='loss',update='append',opts=dict(title='Loss', legend=['train_loss', 'test_loss']),)# 画出R^2曲线train_r2 = r2_score(batch_y_train.numpy(), output.numpy())vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.tensor((train_r2, np.mean(eval_dict['test_r2']))).unsqueeze(0),win='R^2',update='append',opts=dict(title='R^2', legend=['train_R^2', 'test_R^2'], ytickmin=0, ytickmax=1),)# 画出训练集预测值和真实值折线图sorted_train_idx = torch.argsort(batch_y_train)  # 按真实值排序vis.line(X=torch.arange(batch_size).repeat(2, 1).t(),Y=torch.stack((batch_y_train[sorted_train_idx], output[sorted_train_idx]), dim=1),win='batch_train_line',opts=dict(title='Predicted vs. Actual (Train Set)', legend=['Actual', 'Predicted']),)# 画出测试集预测值散点图和真实值折线图x = list(range(len(y_test)))y_test = torch.cat(eval_dict['y_test'])y_pred = torch.cat(eval_dict['y_pred'])sorted_test_idx = torch.argsort(y_test)vis._send({'data': [{'x': x, 'y': y_test[sorted_test_idx].tolist(), 'type': 'custom', 'mode': 'lines', 'name': 'Actual'},{'x': x, 'y': y_pred[sorted_test_idx].tolist(), 'type': 'custom', 'mode': 'markers', 'name': 'Predicted', 'marker': {'size': 3}}],'win': 'test_line','layout': {'title': 'Predicted vs. Actual (Test Set)'},})return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/407961.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OCR识别行驶证(阿里云和百度云)

OCR识别行驶证(阿里云和百度云) 一、使用场景 1、通过识别行驶证,获取相关汽车信息,替代手输 二、效果图 三、代码部分: 1、阿里云OCR 1.1、控制层 PostMapping("/ocrCard") public JSONObject ocrCard(RequestPart("fi…

快速入门:使用Python构建学生成绩管理应用

前言 诸位观众,本学期我有幸学习了Python编程课程。随着课程的结束,授课教师布置了一项任务,要求我们开发一个学生信息管理系统。基于老师的要求,我个人独立完成了这项任务。今天,我希望将这个简易的程序分享给大家&a…

【数字三角形】

题目 代码 #include <bits/stdc.h> using namespace std;const int N 510; int f[N][N]; int a[N][N]; int main() {int n;cin >> n;for(int i 1; i < n; i){for(int j 1; j < i; j){cin >> a[i][j];if(i 1 && j 1) f[i][j] a[i][j];el…

ORCAD Capture CIS 打开原理图总是卡住

原因&#xff1a;ORCAD自动进行了DRC检查。要打开的原理图中footprint未指定footprint路径。 修改&#xff1a;1、第一种方法&#xff1a;指定footprint路径 2、第二种方法&#xff1a;关闭在线DRC检查

钢包智慧管理平台

钢包智慧管理平台基于海康、大华视频监控&#xff0c;实现对钢包的全动态管理&#xff0c;实时检测钢包的温度数据变化&#xff0c;也可以随时查询时间区间内的钢包温度数据变化。 平台基于springboot vue前后台分离技术开发&#xff0c;视频基于zlmedia的转码拉流。实现了视频…

STM32————SPI硬件外设实现读写

首先是理论知识&#xff1a; 常用8位数据帧、高位先行 SPI的时钟由PCLK内部时钟分频得来&#xff0c;最大可到36MHz 精简为半双工就是去掉一根数据线后&#xff0c;用剩下的一根作为发送/接收数据&#xff1b;单工就是去掉接收线&#xff0c;只用发送线进行发送数据&#xf…

揭秘CAAC、AOPA、ALPA、ASFC和UTC无人机执照的差别及实用价值

CAAC、AOPA、ALPA、ASFC和UTC无人机执照各有其独特的差别及实用价值&#xff0c;以下是针对这些执照的详细解析&#xff1a; 一、CAAC无人机执照 颁发机构&#xff1a;中国民用航空局&#xff08;CAAC&#xff09; 差别&#xff1a; - 权威性&#xff1a;CAAC无人机执照是目…

Java面试题--JVM大厂篇之JVM 大厂面试题及答案解析(2)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到我的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而我的博客&…

Leetcode 1108. IP地址无效化

Leetcode 1108. IP 地址无效化 问题&#xff1a;给你一个有效的 IPv4 地址address&#xff0c;返回这个 IP 地址的无效化版本。 所谓无效化 IP 地址&#xff0c;其实就是用 "[.]" 代替了每个 "."。 方法1&#xff1a;对字符串挨个进行判断&#xff0c;如…

http连接未释放导致生产故障

凌晨4点运维老大收到NAT网关连接数打满报警&#xff08;官网页面接口超时&#xff09;&#xff0c;运维自己先看了看服务器相关配置&#xff0c;先后还联系了阿里云的客服&#xff0c;客服建议升级NAT网络连接阈值&#xff0c;之前是1w升级到了5w&#xff0c;但后来还是给研发打…

安装torchvision==0.5.0

安装pytorch 1.4 但是在当前配置的镜像源中找不到 torchvision0.5.0 这个版本的包。 直接找资源下载 网址添加链接描述 直接运行该命令&#xff0c;成功。 然后重复运行上面的命令就可以了 # CUDA 9.2 conda install pytorch1.4.0 torchvision0.5.0 cudatoolkit9.2 -c pyto…

Python编码系列—Python单元测试的艺术:深入探索unittest与pytest

&#x1f31f;&#x1f31f; 欢迎来到我的技术小筑&#xff0c;一个专为技术探索者打造的交流空间。在这里&#xff0c;我们不仅分享代码的智慧&#xff0c;还探讨技术的深度与广度。无论您是资深开发者还是技术新手&#xff0c;这里都有一片属于您的天空。让我们在知识的海洋中…

CS1.5快捷键

《黑神话悟空》玩不起&#xff0c;玩起了23年前的cs1.5 B11&#xff1a;USP(警察自带手枪&#xff09; B12&#xff1a;Glock18(匪徒自带手枪) B13&#xff1a;Desert Eagle&#xff08;沙漠之鹰&#xff09; B14&#xff1a;P-228 B15&#xff1a;Dual Berettas&#xff08;匪…

linux中对.jar文件的配置文件进行修改

linux中对.jar文件的配置文件进行修改 第一步&#xff0c;进入你的.jar的当前文件夹 第二步 &#xff0c;编辑你指定的 .jar 文件 编辑之前请先备份 cp xxx.jar xxx-1.2.jar 输入编辑命令 vim xxx.jar第三步&#xff0c;找到你要编辑的文件 输入命令进入vi模式&#xff08;…

金蝶云星空开发简单账表《物料年采购入库报表》

文章目录 业务背景业务需求方案设计详细设计测试业务背景 系统现有功能不支持查询过去一年内所有物料的入库数,需要人工导出,然后再汇总。 业务需求 可以查询所有物料的入库数,多个物料,单个物料,多个组织,单个组织的入库数,以及支持查询入库数大于某个阈值。 方案设…

Unity教程(十一)使用Cinemachine添加并调整相机

Unity开发2D类银河恶魔城游戏学习笔记 Unity教程&#xff08;零&#xff09;Unity和VS的使用相关内容 Unity教程&#xff08;一&#xff09;开始学习状态机 Unity教程&#xff08;二&#xff09;角色移动的实现 Unity教程&#xff08;三&#xff09;角色跳跃的实现 Unity教程&…

一文彻底搞懂CNN - 模型架构(Model Architecture)

CNN Model Architecture CNN&#xff08;卷积神经网络&#xff09;的模型架构由输入层、卷积层、池化层以及全连接层组成&#xff0c;通过卷积操作提取图像特征&#xff0c;并通过池化减少参数数量&#xff0c;最终通过全连接层进行分类或回归。 输入层&#xff1a;接收原始图…

【奇某信-注册/登录安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…

Android Jitpack制作远程仓库aar流程

开发高效提速系列目录 软件多语言文案脚本自动化方案Android Jitpack制作远程仓库aar流程 Android Jitpack制作远程仓库aar流程 背景aar制作与使用1. aar制作2. aar使用 异常解决总结 博客创建时间&#xff1a;2023.08.24 博客更新时间&#xff1a;2023.08.24 以Android stud…

【闪送-注册安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞…