分类任务实现模型集成代码模版

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifierclasses = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']def get_args_parser(add_help=True):import argparseparser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,help="dataset path")parser.add_argument("--model", default="resnet8", type=str, help="model name")parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")parser.add_argument("-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size")parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)")parser.add_argument("--opt", default="SGD", type=str, help="optimizer")parser.add_argument("--random-seed", default=42, type=int, help="random seed")parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")parser.add_argument("--wd","--weight-decay",default=1e-4,type=float,metavar="W",help="weight decay (default: 1e-4)",dest="weight_decay",)parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")parser.add_argument("--print-freq", default=80, type=int, help="print frequency")parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")parser.add_argument("--resume", default="", type=str, help="path of checkpoint")parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")return parserdef main():args = get_args_parser().parse_args()utils.setup_seed(args.random_seed)args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device = args.devicedata_dir = args.data_pathresult_dir = args.output_dir# ------------------------------------  log ------------------------------------logger, log_dir = utils.make_logger(result_dir)writer = SummaryWriter(log_dir=log_dir)# ------------------------------------ step1: dataset ------------------------------------normMean = [0.4948052, 0.48568845, 0.44682974]normStd = [0.24580306, 0.24236229, 0.2603115]normTransform = transforms.Normalize(normMean, normStd)train_transform = transforms.Compose([transforms.Resize(32),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),normTransform])valid_transform = transforms.Compose([transforms.ToTensor(),normTransform])# root变量下需要存放cifar-10-python.tar.gz 文件# cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)# 构建DataLodertrain_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)# ------------------------------------ tep2: model ------------------------------------model_base = utils.resnet20()# model_base = utils.LeNet5()model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,classes=classes, writer=writer, save_dir=log_dir)model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)class MyEnsemble(VotingClassifier):def __init__(self, **kwargs):# logger, device, args, classes, writersuper(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])self.logger = kwargs["logger"]self.writer = kwargs["writer"]self.device = kwargs["device"]self.args = kwargs["args"]self.classes = kwargs["classes"]self.save_dir = kwargs["save_dir"]@staticmethoddef save(model, save_dir, logger):"""Implement model serialization to the specified directory."""if save_dir is None:save_dir = "./"if not os.path.isdir(save_dir):os.mkdir(save_dir)# Decide the base estimator nameif isinstance(model.base_estimator_, type):base_estimator_name = model.base_estimator_.__name__else:base_estimator_name = model.base_estimator_.__class__.__name__# {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}filename = "{}_{}_{}_ckpt.pth".format(type(model).__name__,base_estimator_name,model.n_estimators,)# The real number of base estimators in some ensembles is not same as# `n_estimators`.state = {"n_estimators": len(model.estimators_),"model": model.state_dict(),"_criterion": model._criterion,}save_dir = os.path.join(save_dir, filename)logger.info("Saving the model to `{}`".format(save_dir))# Savetorch.save(state, save_dir)returndef fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):# 模型、优化器、学习率调整器、评估器 列表创建estimators = []for _ in range(self.n_estimators):estimators.append(self._make_estimator())optimizers = []schedulers = []for i in range(self.n_estimators):optimizers.append(set_module.set_optimizer(estimators[i],self.optimizer_name, **self.optimizer_args))scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],gamma=self.args.lr_gamma)  # 设置学习率下降策略# scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,#                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略schedulers.append(scheduler_)acc_metrics = []for i in range(self.n_estimators):# task类型与任务一致# num_classes与分类任务的类别数一致acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))self._criterion = nn.CrossEntropyLoss()# epoch循环迭代best_acc = 0.for epoch in range(epochs):# trainingfor model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):loss_m_train, acc_m_train, mat_train = \utils.ModelTrainerEnsemble.train_one_epoch(train_loader, estimator, self._criterion, optimizer, scheduler, epoch,self.device, self.args, self.logger, self.classes)# 学习率更新scheduler.step()# 记录self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):loss_m_train.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):acc_m_train.avg}, epoch)self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)# 训练混淆矩阵图conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)# validateloss_valid_meter, acc_valid, top1_group, mat_valid = \utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)# 日志self.writer.add_scalars('Loss_group', {'valid_loss':loss_valid_meter.avg}, epoch)self.writer.add_scalars('Accuracy_group', {'valid_acc':acc_valid * 100}, epoch)# 验证混淆矩阵图conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,verbose=epoch == epochs - 1, save=False)self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)self.logger.info('Epoch: [{:0>3}/{:0>3}]  ''Train Loss avg: {loss_train:>6.4f}  ''Valid Loss avg: {loss_valid:>6.4f}  ''Train Acc@1 avg:  {top1_train:>7.2f}%   ''Valid Acc@1 avg: {top1_valid:>7.2%}    ''LR: {lr}'.format(epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))for model_idx, top1_meter in enumerate(top1_group):self.writer.add_scalars('Accuracy_group',{'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)if acc_valid > best_acc:best_acc = acc_validself.estimators_ = nn.ModuleList()self.estimators_.extend(estimators)if save_model:self.save(self, self.save_dir, self.logger)if __name__ == "__main__":main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/417701.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Datawhale x李宏毅苹果书AI夏令营深度学习详解进阶Task03

在深度学习中,批量归一化(Batch Normalization,BN)技术是一种重要的优化方法,它可以有效地改善模型的训练效果。本文将详细讨论批量归一化的原理、实现方式、在神经网络中的应用,以及如何选择合适的损失函数…

Python-面向对象编程(超详细易懂)

面向对象编程(oop) 面向对象是Python最重要的特性,在Python中一切数据类型都是面向对象的。 面向对象的编程思想:按照真实世界客观事物的自然规律进行分析,客观世界中存在什么样的实体,构建的软件系统就存在…

视频监控管理平台LntonAIServer视频智能分析噪声检测应用场景

在视频监控系统中,噪声问题常常影响到视频画面的清晰度和可用性。噪声可能由多种因素引起,包括但不限于低光环境、摄像机传感器灵敏度过高、编码压缩失真等。LntonAIServer通过引入噪声检测功能,旨在帮助用户及时发现并解决视频流中的噪声问题…

linux 内核代码学习(八)

总体目标:由于fedora10 linux发行版中自带的linux2.6.xx内核源码规模太庞大了,对于想通读内核源码的爱好者来说太困难了,因此选择了linux2.4.20内核来进行测试(最终是希望能够实现linux1.0内核的源码完全编译和测试)。…

了解一下HTTP 与 HTTPS 的区别

介绍: HTTP是超文本传输协议。规定了客户端(通常是浏览器)和服务器之间如何传输超文本,也就是包含链接的文本。通常使用TCP【1】/IP协议来传输数据,默认端口为80。 HTTPS是超文本传输安全协议,具有CA证书。…

【RLHF】浅谈ChatGPT 等大模型中的RLHF算法

本文收录于《深入浅出讲解自然语言处理》专栏,此专栏聚焦于自然语言处理领域的各大经典算法,将持续更新,欢迎大家订阅!​个人主页:有梦想的程序星空​个人介绍:小编是人工智能领域硕士,全栈工程…

TCP的流量控制深入理解

在理解流量控制之前我们先需要理解TCP的发送缓冲区和接收缓冲区,也称为套接字缓冲区。首先我们先知道缓冲区存在于哪个位置? 其中缓冲区存在于Socket Library层。 而我们的发送窗口和接收窗口就存在于缓冲区当中。在实现滑动窗口时则将两个指针指向缓冲区…

STM32F103调试DMA+PWM 实现占空比逐渐增加的软启效果

实现效果:DMAPWM 实现PWM输出时,从低电平到输出占空比逐渐增加再到保持高电平的效果,达到控制 MOS 功率开关软启的效果。 1.配置时钟 2.TIM 的 PWM 功能配置 选择、配置 TIM 注意:选择 TIM 支持 DMA 控制输出 PWM 功能的通道&a…

使用Unity的准备

下载Unity 下载Unity Hub Unity - 实时内容开发平台 | 3D、2D、VR & AR可视化https://unity.cn/ 创建账号或者登入账号 Unity安装 路径尽量为英文路径 登入账号 点击头像登入账号 这里已经登入 打开偏好 设置中文 添加许可证 获取免费版的即可 安装编辑器 新建项目…

mysql-PXC实现高可用

mysql8.0使用PXC实现高可用 什么是 PXC PXC 是一套 MySQL 高可用集群解决方案,与传统的基于主从复制模式的集群架构相比 PXC 最突出特点就是解决了诟病已久的数据复制延迟问题,基本上可以达到实时同步。而且节点与节点之间,他们相互的关系是…

PHP一站式解决方案高级房产系统小程序源码

一站式解决方案,高级房产系统让房产管理更轻松 🏠【开篇:告别繁琐,迎接高效房产管理新时代】🏠 你是否还在为房产管理的繁琐流程而头疼?从房源录入、客户咨询到合同签订、售后服务,每一个环节…

【CSS】如何写渐变色文字并且有打光效果

效果如上,其实核心除了渐变色文字的设置 background: linear-gradient(270deg, #d2a742 94%, #f6e2a7 25%, #d5ab4a 48%, #f6e2a7 82%, #d1a641 4%);color: #e8bb2c;background-clip: text;color: transparent;还有就是打光效果,原理其实就是两块遮罩&am…

7、关于LoFTR

7、关于LoFTR LoFTR论文链接:LoFTR LoFTR的提出,是将Transformer模型的注意力机制在特征匹配方向的应用,Transformer的提取特征的机制,在自身进行,本文提出可以的两张图像之间进行特征计算,非常适合进行特…

“弹性盒子”一维布局系统(补充)——WEB开发系列31

弹性盒子是一种一维布局方法,用于根据行或列排列元素。元素可以扩展以填补多余的空间,或者缩小以适应较小的空间,为容器中的子元素提供灵活的且一致的布局方式。 一、什么是弹性盒子? CSS 弹性盒子(Flexible Box Layo…

提高开发效率的实用工具库VueUse

VueUse中文网:https://vueuse.nodejs.cn/ 使用方法 安装依赖包 npm i vueuse/core单页面使用(useThrottleFn举例) import { useThrottleFn } from "vueuse/core"; // 表单提交 const handleSubmit useThrottleFn(() > {// 具…

策略模式的小记

策略模式 策略模式支付系统【场景再现】硬编码完成不同的支付策略使用策略模式,对比不同(1)支付策略接口(2)具体的支付策略类(3)上下文(4)客户端(5&#xff0…

python 交互模式怎么切换目录

假如要用交互界面调用一个.py文件: (1)用cmd界面定位到文件位置,如cd Desktop/data/ #进入desktop下data目录。 (2)接着打开python(输入python) 调用os (1&#xff0…

Linux df命令详解,Linux查看磁盘使用情况

《网络安全自学教程》 df 一、字段解释二、显示单位三、汇总显示四、指定目录五、指定显示字段六、du和df结果不一样 df(disk free)命令用来查看系统磁盘空间使用情况。 参数: -h:(可读性)显示单位&#…

Mobile-Agent赛题分析和代码解读笔记(DataWhale AI夏令营)

前言 你好,我是GISer Liu,一名热爱AI技术的GIS开发者,本文是DataWhale 2024 AI夏令营的最后一期——Mobile-Agent赛道,关于赛题分析和代码解读的学习文档总结;这边作者也会分享自己的思路; 本文是对原视频的…

万象奥科参展“2024 STM32全国巡回研讨会”—深圳站、广州站

9月3日-9月5日,万象奥科参展“2024 STM32全国巡回研讨会”— 深圳站、广州站。此次STM32研讨会将会走进全国11个城市,展示STM32在智能工业、无线连接、边缘人工智能、安全、图形用户界面等领域的产品解决方案及多样化应用实例,深入解读最新的…