深度学习之超分辨率算法——FRCNN

– 对之前SRCNN算法的改进

    1. 输出层采用转置卷积层放大尺寸,这样可以直接将低分辨率图片输入模型中,解决了输入尺度问题。
    2. 改变特征维数,使用更小的卷积核和使用更多的映射层。卷积核更小,加入了更多的激活层。
    3. 共享其中的映射层,如果需要训练不同上采样倍率的模型,只需要修改最后的反卷积层大小,就可以训练出不同尺寸的图片。
  • 模型实现
  • 在这里插入图片描述
import math
from torch import nnclass FSRCNN(nn.Module):def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):super(FSRCNN, self).__init__()self.first_part = nn.Sequential(nn.Conv2d(num_channels, d, kernel_size=5, padding=5//2),nn.PReLU(d))# 添加入多个激活层和小卷积核self.mid_part = [nn.Conv2d(d, s, kernel_size=1), nn.PReLU(s)]for _ in range(m):self.mid_part.extend([nn.Conv2d(s, s, kernel_size=3, padding=3//2), nn.PReLU(s)])self.mid_part.extend([nn.Conv2d(s, d, kernel_size=1), nn.PReLU(d)])self.mid_part = nn.Sequential(*self.mid_part)# 最后输出self.last_part = nn.ConvTranspose2d(d, num_channels, kernel_size=9, stride=scale_factor, padding=9//2,output_padding=scale_factor-1)self._initialize_weights()def _initialize_weights(self):# 初始化for m in self.first_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)for m in self.mid_part:if isinstance(m, nn.Conv2d):nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))nn.init.zeros_(m.bias.data)nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)nn.init.zeros_(self.last_part.bias.data)def forward(self, x):x = self.first_part(x)x = self.mid_part(x)x = self.last_part(x)return x

以上代码中,如起初所说,将SRCNN中给的输出修改为转置卷积,并且在中间添加了多个11卷积核和多个线性激活层。且应用了权重初始化,解决协变量偏移问题。
备注:1
1卷积核虽然在通道的像素层面上,针对一个像素进行卷积,貌似没有什么作用,但是卷积神经网络的特性,我们在利用多个卷积核对特征图进行扫描时,单个卷积核扫描后的为sum©,那么就是尽管在像素层面上无用,但是在通道层面上进行了融合,并且进一步加深了层数,使网络层数增加,网络能力增强。

  • 上代码
  • train.py

训练脚本

import argparse
import os
import copyimport torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdmfrom models import FSRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()# 训练文件parser.add_argument('--train-file', type=str,help="the dir of train data",default="./Train/91-image_x4.h5")# 测试集文件parser.add_argument('--eval-file', type=str,help="thr dir of test data ",default="./Test/Set5_x4.h5")# 输出的文件夹parser.add_argument('--outputs-dir',help="the output dir", type=str,default="./outputs")parser.add_argument('--weights-file', type=str)parser.add_argument('--scale', type=int, default=2)parser.add_argument('--lr', type=float, default=1e-3)parser.add_argument('--batch-size', type=int, default=16)parser.add_argument('--num-epochs', type=int, default=20)parser.add_argument('--num-workers', type=int, default=8)parser.add_argument('--seed', type=int, default=123)args = parser.parse_args()args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))if not os.path.exists(args.outputs_dir):os.makedirs(args.outputs_dir)cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')torch.manual_seed(args.seed)model = FSRCNN(scale_factor=args.scale).to(device)criterion = nn.MSELoss()optimizer = optim.Adam([{'params': model.first_part.parameters()},{'params': model.mid_part.parameters()},{'params': model.last_part.parameters(), 'lr': args.lr * 0.1}], lr=args.lr)train_dataset = TrainDataset(args.train_file)train_dataloader = DataLoader(dataset=train_dataset,batch_size=args.batch_size,shuffle=True,num_workers=args.num_workers,pin_memory=True)eval_dataset = EvalDataset(args.eval_file)eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)best_weights = copy.deepcopy(model.state_dict())best_epoch = 0best_psnr = 0.0for epoch in range(args.num_epochs):model.train()epoch_losses = AverageMeter()with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size), ncols=80) as t:t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))for data in train_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)preds = model(inputs)loss = criterion(preds, labels)epoch_losses.update(loss.item(), len(inputs))optimizer.zero_grad()loss.backward()optimizer.step()t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))t.update(len(inputs))torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))model.eval()epoch_psnr = AverageMeter()for data in eval_dataloader:inputs, labels = datainputs = inputs.to(device)labels = labels.to(device)with torch.no_grad():preds = model(inputs).clamp(0.0, 1.0)epoch_psnr.update(calc_psnr(preds, labels), len(inputs))print('eval psnr: {:.2f}'.format(epoch_psnr.avg))if epoch_psnr.avg > best_psnr:best_epoch = epochbest_psnr = epoch_psnr.avgbest_weights = copy.deepcopy(model.state_dict())print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

test.py 测试脚本

import argparseimport torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_imagefrom models import FSRCNN
from utils import convert_ycbcr_to_rgb, preprocess, calc_psnrif __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--weights-file', type=str, required=True)parser.add_argument('--image-file', type=str, required=True)parser.add_argument('--scale', type=int, default=3)args = parser.parse_args()cudnn.benchmark = Truedevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')model = FSRCNN(scale_factor=args.scale).to(device)state_dict = model.state_dict()for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():if n in state_dict.keys():state_dict[n].copy_(p)else:raise KeyError(n)model.eval()image = pil_image.open(args.image_file).convert('RGB')image_width = (image.width // args.scale) * args.scaleimage_height = (image.height // args.scale) * args.scalehr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)bicubic.save(args.image_file.replace('.', '_bicubic_x{}.'.format(args.scale)))lr, _ = preprocess(lr, device)hr, _ = preprocess(hr, device)_, ycbcr = preprocess(bicubic, device)with torch.no_grad():preds = model(lr).clamp(0.0, 1.0)psnr = calc_psnr(hr, preds)print('PSNR: {:.2f}'.format(psnr))preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)output = pil_image.fromarray(output)# 保存图片output.save(args.image_file.replace('.', '_fsrcnn_x{}.'.format(args.scale)))

datasets.py

数据集的读取

import h5py
import numpy as np
from torch.utils.data import Datasetclass TrainDataset(Dataset):def __init__(self, h5_file):super(TrainDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])class EvalDataset(Dataset):def __init__(self, h5_file):super(EvalDataset, self).__init__()self.h5_file = h5_filedef __getitem__(self, idx):with h5py.File(self.h5_file, 'r') as f:return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)def __len__(self):with h5py.File(self.h5_file, 'r') as f:return len(f['lr'])

工具文件utils.py

  • 主要用来测试psnr指数,图片的格式转换(悄悄说一句,opencv有直接实现~~~)
import torch
import numpy as npdef calc_patch_size(func):def wrapper(args):if args.scale == 2:args.patch_size = 10elif args.scale == 3:args.patch_size = 7elif args.scale == 4:args.patch_size = 6else:raise Exception('Scale Error', args.scale)return func(args)return wrapperdef convert_rgb_to_y(img, dim_order='hwc'):if dim_order == 'hwc':return 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.else:return 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.def convert_rgb_to_ycbcr(img, dim_order='hwc'):if dim_order == 'hwc':y = 16. + (64.738 * img[..., 0] + 129.057 * img[..., 1] + 25.064 * img[..., 2]) / 256.cb = 128. + (-37.945 * img[..., 0] - 74.494 * img[..., 1] + 112.439 * img[..., 2]) / 256.cr = 128. + (112.439 * img[..., 0] - 94.154 * img[..., 1] - 18.285 * img[..., 2]) / 256.else:y = 16. + (64.738 * img[0] + 129.057 * img[1] + 25.064 * img[2]) / 256.cb = 128. + (-37.945 * img[0] - 74.494 * img[1] + 112.439 * img[2]) / 256.cr = 128. + (112.439 * img[0] - 94.154 * img[1] - 18.285 * img[2]) / 256.return np.array([y, cb, cr]).transpose([1, 2, 0])def convert_ycbcr_to_rgb(img, dim_order='hwc'):if dim_order == 'hwc':r = 298.082 * img[..., 0] / 256. + 408.583 * img[..., 2] / 256. - 222.921g = 298.082 * img[..., 0] / 256. - 100.291 * img[..., 1] / 256. - 208.120 * img[..., 2] / 256. + 135.576b = 298.082 * img[..., 0] / 256. + 516.412 * img[..., 1] / 256. - 276.836else:r = 298.082 * img[0] / 256. + 408.583 * img[2] / 256. - 222.921g = 298.082 * img[0] / 256. - 100.291 * img[1] / 256. - 208.120 * img[2] / 256. + 135.576b = 298.082 * img[0] / 256. + 516.412 * img[1] / 256. - 276.836return np.array([r, g, b]).transpose([1, 2, 0])def preprocess(img, device):img = np.array(img).astype(np.float32)ycbcr = convert_rgb_to_ycbcr(img)x = ycbcr[..., 0]x /= 255.x = torch.from_numpy(x).to(device)x = x.unsqueeze(0).unsqueeze(0)return x, ycbcrdef calc_psnr(img1, img2):return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))class AverageMeter(object):def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count

先跑他个几十轮~
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/495329.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3项目history路由模式部署上线405、刷新404问题(包括部分页面刷新404问题)

一、找不到js模块 解决方法:配置Nginx配置文件: // root /your/program/path/dist root /www/wwwroot/my_manage_backend_v1/dist;二、刷新页面导致404问题(Not found) 经过一系列配置后发现进入页面一切正常,包括路由前进和回退&#xff0…

微服务篇-深入了解 XXL-JOB 分布式任务调度的具体使用(XXL-JOB 的工作流程、框架搭建)

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 XXL-JOB 调度中心概述 1.2 XXL-JOB 工作流程 1.3 Cron 表达式调度 2.0 XXL-JOB 框架搭建 2.1 XXL-JOB 调度中心的搭建 2.2 XXL-JOB 执行器的搭建 2.3 使用调度中心…

JS中若干相似特性的区别

Object.is与的区别? 其他时候都相等 字符串concat()和号有什么区别? 数组at和直接索引区别 at里是负值,计算方法是:数组的长度加这个负值,得到的数作为索引值 substring与slice的区别 substring是负值,则视为0,等于全部复制 slice是负值,则从后往前复制,-2就是复制最后2个字…

Fuel库实战:下载失败时的异常处理策略

Fuel库作为一个轻量级的Kotlin HTTP客户端库,因其简洁的API和强大的功能而受到开发者的青睐。然而,网络请求总是伴随着失败的风险,比如网络不稳定、服务器错误、资源不存在等。因此,合理地处理这些异常情况对于提升用户体验和应用…

vscode插件更新特别慢的问题

点击插件标题去网页查看 命令行安装 D:\Software\VSCode\Code.exe --extensions-dir "D:\Software\VSCode\extendions" --install-extension Vue.volar-2.2.0.vsix安装完成之后重启vs code即可 参考 https://www.cnblogs.com/yiquanfeng/p/18218722

2.利用docker进行gitlab服务器迁移

一、Docker安装 安装Ubuntu 22.04.3 LTS \n \l 1、旧版本安装包清理 sudo apt-get remove docker docker-engine docker.io containerd runc当你卸载Docker时,存储在/var/lib/docker/中的图像、容器、卷和网络不会自动删除。如果你想从一个干净的安装开始&#x…

大型语言模型(LLMs)演化树 Large Language Models

大型语言模型(LLMs)演化树 Large Language Models flyfish 下面的图来自论文地址 Transformer 模型(如 BERT 和 GPT-3)已经给自然语言处理(NLP)领域带来了革命性的变化。这得益于它们具备并行化能力&…

springboot477基于vue技术的农业设备租赁系统(论文+源码)_kaic

摘 要 使用旧方法对农业设备租赁系统的信息进行系统化管理已经不再让人们信赖了,把现在的网络信息技术运用在农业设备租赁系统的管理上面可以解决许多信息管理上面的难题,比如处理数据时间很长,数据存在错误不能及时纠正等问题。这次开发的农…

如何在 Ubuntu 22.04 上安装和使用 Composer

简介 如果你是一名 PHP 开发者,想要简化你的项目依赖管理,那么 Composer 是一个必不可少的工具。Composer 可以简化包管理,并允许你轻松地将外部库集成到你的项目中。 本教程将向你展示如何在 Ubuntu 22.04 操作系统上安装 Composer&#x…

16_HTML5 语义元素 --[HTML5 API 学习之旅]

HTML5 引入了许多新的语义元素,这些元素有助于创建结构更清晰、更具描述性的网页。语义化 HTML 不仅改善了代码的可读性,还增强了搜索引擎优化(SEO),提高了无障碍访问性,并使得开发者更容易理解和维护代码。…

国标GB28181视频监控平台与Liveweb视频监控汇聚平台对接方案

应急管理部门以“以信息化推动应急管理能力现代化”为总体目标,加快现代信息技术与应急管理业务深度融合,全面支持现代应急管理体系建设,这不仅是国家加强和改进应急管理工作的关键举措,也是应对日益严峻的应急管理形势和满足公众…

内部知识库的未来展望:技术融合与用户体验的双重升级

在当今数字化飞速发展的时代,企业内部知识库作为知识管理的关键载体,正站在变革的十字路口,即将迎来技术融合与用户体验双重升级的崭新时代,这一系列变化将深度重塑企业知识管理的格局。 一、技术融合:开启知识管理新…

EasyGBS国标GB28181公网平台P2P远程访问故障诊断:云端服务端排查指南

随着信息技术的飞速发展,视频监控领域正经历从传统安防向智能化、网络化安防的深刻转变。EasyGBS平台,作为基于国标GB28181协议的视频流媒体平台,为用户提供了强大的视频监控直播功能。然而,在实际应用中,P2P远程访问可…

HW护网分析研判思路,流量告警分析技巧

《网络安全自学教程》 这篇文章,写给每一个「护网黑奴」,为初次护网的小伙伴普及一下护网工作内容,提供一些简单的分析思路。 护网分析研判思路 1、护网组织架构和责任划分1.1、安全监控1.2、分析研判1.3、应急处置 2、分析研判2.1、判断告警…

springBoot发布https服务及调用

一、服务端发布https服务 1、准备SSL证书 (1)自签名证书:如果你只是用于开发或测试环境,可以生成一个自签名证书。 (2)CA 签名证书:对于生产环境,应该使用由受信任的证书颁发机构 …

Web 第一次作业 初探html 使用VSCode工具开发

目录 初探html? 代码展示&#xff1a; 初探html 大多数代码都比较冗长 不是很简洁 还有许多标签功能不会使用 记录一下成长过程 哈哈哈哈哈&#xff01;<–_–> 代码展示&#xff1a; 12.10首次确定书写对象 牢9门 <!DOCTYPE html> <html lang"en&quo…

分别查询 user 表中 avatar 和 nickname 列为空的用户数量

文章目录 1、要查询 user 表中 avatar 列为空的用户数量2、要查询 user 表中 nickname 列为空的用户数量 1、要查询 user 表中 avatar 列为空的用户数量 好的&#xff0c;要查询 user 表中 avatar 列为空的用户数量&#xff0c;你可以使用以下 SQL 查询语句&#xff1a; SELE…

【批量生成WORD和PDF文件】根据表格内容和模板文件批量创建word文件,一次性生成多个word文档和批量创建PDF文件

如何按照Word模板和表格的数据快速制作5000个word文档 &#xff1f; 在与客户的合作的中需要创建大量的合同&#xff0c;这些合同的模板大概都是一致的&#xff0c;是不是每次我们都需要填充不一样的数据来完成&#xff1f; 今天用表格数据完成合同模板的填充&#xff0c;批量…

DX12 快速教程(2) —— 渲染天蓝色窗口

快速导航 新建项目 "002-DrawSkyblueWindow"DirectX 12 入门1. COM 技术&#xff1a;DirectX 的中流砥柱什么是 COM 技术COM 智能指针 2.创建 D3D12 调试层设备&#xff1a;CreateDebugDevice什么是调试层如何创建并使用调试层 3.创建 D3D12 设备&#xff1a;CreateD…

【MySQL】7.0 入门学习(七)——MySQL基本指令:帮助、清除输入、查询等

1.0 help &#xff1f; 帮助指令&#xff0c;查询某个指令的解释、用法、说明等。详情参考博文&#xff1a; 【数据库】6.0 MySQL入门学习&#xff08;六&#xff09;——MySQL启动与停止、官方手册、文档查询 https://www.cnblogs.com/xiaofu007/p/10301005.html 2.0 在cmd命…