基于知识蒸馏的两阶段去雨去雪去雾模型学习记录(三)之知识测试阶段与评估模块

去雨去雾去雪算法分为两个阶段,分别是知识收集阶段与知识测试阶段,前面我们已经学习了知识收集阶段,了解到知识阶段的特征迁移模块(CKT)与软损失(SCRLoss),那么在知识收集阶段的主要重点便是HCRLoss(硬损失),事实上,知识测试阶段要比知识收集阶段简单,因为这个模块只需要训练学生网络即可。

模型创新点

在进行知识测试阶段的代码学习之前,我们来回顾一下去雨去雪去雾网络的创新点:
首先是提出两阶段的知识蒸馏网络,即构建三个教师网络与一个学生网络,设置总训练次数为250,其中前125个epoch教师网络与学生网络一同训练,这里的训练是指将图像输入教师网络,随后将教师网络的输出结果与中间特征图保留,将其作为真值指导学生网络进行训练。
其次便是提出知识迁移模块(CKT)该模块的作用是将教师网络的特征迁移到学生网络。
随后便是软损失与硬损失计算了,这个其实是知识蒸馏中的概念。
总体来看去雨去雾去雪网络的设计虽然较为新颖,但事实上就是知识蒸馏网络的架构,本着这一点,程序理解起来也就容易多了。

在这里插入图片描述

接下来开始代码的学习:

小插曲(算力不足)

首先需要指出,前面将batch-size设置为4,但却会报错:

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

开始时博主以为是cuDNN与CUDA版本不匹配导致的,但后来一想不对呀,先前已经运行过呀,那么问题很可能便是batch出问题了,果然将batch改为3后就正常了,这是由于算力不足导致的,注意算力不足和显存不足还是有区别的。
将batch-size改为3后重新运行,开始知识测试阶段的探索。

知识测试阶段

事实上,知识测试阶段的实现与知识收集阶段几乎相同,并且要比知识收集阶段简单,其只是训练学生网络,并计算一个硬损失而已。
由于知识测试阶段与知识收集阶段几乎相同,因此有许多地方是重复的,这里博主便会简要介绍。
首先相同的是使用train_loader进行训练集的加载,并使用tqdm进行封装。
随后便是遍历过程,这个过程就要简单很多了,没有使用到教师网络,直接将图像输入学生网络进行预测即可,这里的学生网络与教师网络的构造是完全相同的,将结果分别计算L1损失与HCR_loss即可。不过需要注意的是由于该阶段不需要与教师网络进行特征迁移,因此就不需要返回中间特征图了,即设置return_feat=False

for target_images, input_images in pBar:if target_images is None: continuetarget_images = target_images.cuda()input_images = torch.cat(input_images).cuda()preds = model(input_images, return_feat=False)G_loss = criterion_l1(preds, target_images)HCR_loss = 0.2 * criterion_hcr(preds, target_images, input_images)total_loss = G_loss + HCR_loss

至于其他的基本就相同了,需要注意的是这里的batch设置为3。接下来记录一下数据的变化情况:

input_images:输入图像,torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度
target_images:目标图像(真值),torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度
preds:预测图像(去噪后的图像),torch.Size([3, 3, 224, 224])第一个3是指图像数量,第二个3是指通道维度

在这里插入图片描述

随后计算L1损失与HCRLoss,由于在学生网络中使用的事实上是混合数据集,即不区分去噪类型,因此输入图像等都是直接使用tesnor格式,而非list格式。

G_loss:tensor(0.5621, device='cuda:0', grad_fn=<L1LossBackward>)

HCRLoss

SCRLoss相同,HCRLoss也是先将图像进行特征转换后再计算损失的

HCRLoss((vgg): Vgg19((slice1): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True))(slice2): Sequential((2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True))(slice3): Sequential((7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True))(slice4): Sequential((12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace=True)(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True))(slice5): Sequential((21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(24): ReLU(inplace=True)(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(26): ReLU(inplace=True)(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)))(l1): L1Loss()
)
HCRLoss:tensor(0.3274, device='cuda:0', grad_fn=<MulBackward0>)

评估模块

至此,知识测试阶段便完成了,随后便是模型评估了。这里默认设置评估时的batch-size为1,即每次输入一张图像。
所谓的评估指的是对学生网络的评估,该模块其实与知识测试阶段类似,不同之处在于这里是需要计算SSIMPSNR的。至于其他则是完全相同,核心代码如下:

for target, image in pBar:if torch.cuda.is_available():image = image.cuda()target = target.cuda()pred = model(image)   		psnr_list.append(torchPSNR(pred, target).item())ssim_list.append(pytorch_ssim.ssim(pred, target).item())

由于batch-size设置为1,因此targettorch.Size([1, 3, 480, 640])image也为torch.Size([1, 3, 480, 640]),这里需要注意的是,在训练阶段(包含知识收集与知识测试阶段),数据集中的图像都要转换为224x224的大小,而在评估阶段则不需要进行转换了,即使用的是原图像的大小。
直接将输入图输入模型,获的去噪后的图像pred大小为torch.Size([1, 3, 480, 640])

pred = model(image)

在这里插入图片描述

随后将预测图像与真值图像进行计算PSNR与SSIM

psnr_list.append(torchPSNR(pred, target).item())
ssim_list.append(pytorch_ssim.ssim(pred, target).item())

PSNR计算

@torch.no_grad()
def torchPSNR(prd_img, tar_img):if not isinstance(prd_img, torch.Tensor):prd_img = torch.from_numpy(prd_img)tar_img = torch.from_numpy(tar_img)imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)rmse = (imdff**2).mean().sqrt()ps = 20 * torch.log10(1/rmse)return ps

SSIM计算

class SSIM(torch.nn.Module):def __init__(self, window_size = 11, size_average = True):super(SSIM, self).__init__()self.window_size = window_sizeself.size_average = size_averageself.channel = 1self.window = create_window(window_size, self.channel)def forward(self, img1, img2):(_, channel, _, _) = img1.size()if channel == self.channel and self.window.data.type() == img1.data.type():window = self.windowelse:window = create_window(self.window_size, channel)            if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)        self.window = windowself.channel = channelreturn _ssim(img1, img2, window, self.window_size, channel, self.size_average)
def ssim(img1, img2, window_size = 11, size_average = True):(_, channel, _, _) = img1.size()window = create_window(window_size, channel)  if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)return _ssim(img1, img2, window, window_size, channel, size_average)

将每个循环得到的psnrssim加入列表

在这里插入图片描述
最后的PSNRSSIM是对list中的所有值求平均:

print("PSNR: {:.3f}".format(np.mean(psnr_list)))
print("SSIM: {:.3f}".format(np.mean(ssim_list)))

至此,知识测试阶段与评估模块就讲解完成了,接下来博主将对该模型进行改进。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/150536.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity可视化Shader工具ASE介绍——2、ASE的Shader创建和输入输出

大家好&#xff0c;我是阿赵&#xff0c;这里继续介绍Unity可视化写Shader的ASE插件的用法。上一篇介绍了ASE的安装和编辑器界面分布&#xff0c;这一篇主要是通过一个简单的例子介绍shader的创建和输入输出。 一、ASE的Shader创建 这里先选择Surface类型的Shader&#xff0c;…

Java高级之反射

关于反射的举例&#xff1a; 示例代码&#xff1a;Fan.java package testFanShe;/*** author: Arbicoral* Description: 测试反射&#xff1a;* 成员变量&#xff1a;2个public&#xff0c;2个private* 构造器&#xff1a;4个public&#x…

Glide源码分析

一&#xff0c;Glide一次完整的加载流程 下面的流程图是一次完整的使用Glide加载图片流程,时序图 二&#xff0c;Glide重要的类图 三&#xff0c;Glide加载图片 流程图

【软件工程_UML—StartUML作图工具】startUML怎么画interface接口

StartUML作图工具怎么画interface接口 初试为圆形 &#xff0c;点击该接口在右下角的设置中->Format->Stereotype Display->Label&#xff0c;即可切换到想要的样式 其他方式 在class diagram下&#xff0c;左侧有interface图标&#xff0c;先鼠标左键选择&#xff0…

氟化钙光学窗口保护镜片 光学元件红外测温窗口保护片

氟化钙光学窗口保护镜片 光学元件红外测温窗口保护片 常见镜片材料 特此记录 anlog 2023年10月7日

大语言模型之十四-PEFT的LoRA

在《大语言模型之七- Llama-2单GPU微调SFT》和《大语言模型之十三 LLama2中文推理》中我们都提到了LoRA&#xff08;低秩分解&#xff09;方法&#xff0c;之所以用低秩分解进行参数的优化的原因是为了减少计算资源。 我们以《大语言模型之四-LlaMA-2从模型到应用》一文中的图…

windows内核编程(2021年出版)笔记

1. Windows内部概览 1.1 进程 进程包含以下内容&#xff1a; 可执行程序&#xff0c;代码和数据私有的虚拟地址空间&#xff0c;分配内存时从这里分配主令牌&#xff0c;保存进程默认安全上下文&#xff0c;进程中的线程执行代码时会用到它私有句柄表&#xff0c;保存进程运…

QT实现tcp服务器客户端

服务器.cpp #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);//实例化一个服务器server new QTcpServer(this);// 此时&#xff0c;服务器已经成功进入监听状态…

DirectX12_Windows_GameDevelop_3:Direct3D的初始化

引言 查看龙书时发现&#xff0c;第四章介绍预备知识的代码不太利于学习。因为它不像是LearnOpenGL那样从头开始一步一步教你敲代码&#xff0c;导致你没有一种整体感。如果你把它当作某一块的代码进行学习&#xff0c;你跟着敲会发现&#xff0c;总有几个变量是没有定义的。这…

Linux系统及Docker安装RabbitMq

目录 一、linux系统安装 1、上传文件 2、在线安装依赖环境 3、安装Erlang 4、安装RabbitMQ 5、开启管理界面及配置 6、启动 7、删除mq 二、docker安装 1、上传mq.tar包或使用命令拉取镜像 2、启动并运行 3、访问mq 一、linux系统安装 1、上传文件 2、在线安装依赖环…

3. 无重复字符的最长子串(枚举+滑动窗口)

目录 一、题目 二、代码 一、题目 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 二、代码 class Solution { public:int lengthOfLongestSubstring(string s) {int _MaxLength 0;int left 0, right 0;vector<int>hash(128, 0);//ASCII…

Qt扫盲-QTreeView 理论总结

QTreeView 理论使用总结 一、概述二、快捷键绑定三、提高性能四、简单实例1. 设计与概念2. TreeItem类定义3. TreeItem类的实现4. TreeModel类定义5. TreeModel类实现6. 在模型中设置数据 一、概述 QTreeView实现了 model 中item的树形表示。这个类用于提供标准的层次列表&…

C#上位机——根据命令发送

C#上位机——根据命令发送 第一步&#xff1a;设置窗口的布局 第二步&#xff1a;设置各个属性 第三步&#xff1a;编写各个模块之间的关系

第九课 排序

文章目录 第九课 排序排序算法lc912.排序数组--中等题目描述代码展示 lc1122.数组的相对排序--简单题目描述代码展示 lc56.合并区间--中等题目描述代码展示 lc215.数组中的第k个最大元素--中等题目描述代码展示 acwing104.货仓选址--简单题目描述代码展示 lc493.翻转树--困难题…

OMV6 安装Extras 插件失败的解决方法

# Time: 2023/10/07 #Author: Xiaohong # 运行环境: OS: OMV6 # 功能: 安装Extras 插件失败的解决方法 问题描述&#xff1a;OMV6 安装插件omv-extras&#xff0c;只能按如下提示的命令行&#xff0c;但安装过程中&#xff0c;会提示raw.githubusercontent.com 无法访问插…

抖音账号矩阵系统开发源码----技术研发

一、技术自研框架开发背景&#xff1a; 抖音账号矩阵系统是一种基于数据分析和管理的全新平台&#xff0c;能够帮助用户更好地管理、扩展和营销抖音账号。 抖音账号矩阵系统开发源码 部分源码分享&#xff1a; ic function indexAction() { //面包屑 $breadc…

【QT5-程序控制电源-RS232-SCPI协议-上位机-基础样例【1】】

【QT5-程序控制电源-RS232-SCPI协议-上位机-基础样例【1】】 1、前言2、实验环境3、自我总结1、基础了解仪器控制-熟悉仪器2、连接SCPI协议3、选择控制方式-程控方式-RS2324、代码编写 4、熟悉协议-SCPI协议5、测试实验-测试指令&#xff08;1&#xff09;硬件连接&#xff08;…

再来介绍另一个binlog文件解析的第三方工具my2sql

看腻了文字就来听听视频演示吧&#xff1a;https://www.bilibili.com/video/BV1rp4y1w74B/ github项目&#xff1a;https://github.com/liuhr/my2sql gitee链接&#xff1a;https://gitee.com/mirrors/my2sql my2sql go版MySQL binlog解析工具&#xff0c;通过解析MySQL bin…

8.2 JUC - 4.Semaphore

目录 一、是什么&#xff1f;二、简单使用三、semaphore应用四、Semaphore原理 一、是什么&#xff1f; Semaphore&#xff1a;信号量&#xff0c;用来限制能同时访问共享资源的线程上限 二、简单使用 public class TestSemaphore {public static void main(String[] args) …

H桥级联型五电平三相逆变器Simulink仿真模型

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…