根据DCT特征训练CNN

记录一次改代码的挣扎经历:
        看了几篇关于DCT频域的深度模型文献,尤其是21年FcaNet:基于DCT 的attention model,咱就是说想试试将我模型的输入改为分组的DCT系数,然后就开始下面的波折了。

第一次尝试:

        我直接调用了库函数,然后出现问题了:这个库函数是应用在numpy数组上,得在CPU上处理。

from scipy.fftpack import dct, idct
...
dct_block = dct(dct(block, axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
...
block = idct(idct(dct_block, axis=2, norm='ortho'), axis=3, norm='ortho')    # [B,C,k,k]

第二次尝试:
        好吧,我先把数据调回CPU,处理后,再调回GPU,又有新问题了:这样做(将block从GPU转移至CPU)torch类型张量转换为numpy数组时,torch张量的梯度无法保存。

# 图像分块
...
# 将块转移到 CPU
block_cpu = block.cpu()        # [B,C,k,k]
# 在 CPU 上对块应用 DCT
dct_block_np = dct(dct(block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
# 将结果传输回 GPU
dct_block = torch.from_numpy(dct_block_np).to(image.device)     # [B,C,k,k]...# 将块转移到 CPU
dct_block_cpu = dct_block.cpu()
# 在 CPU 上对块应用逆 DCT
block_np = idct(idct(dct_block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')
# 将结果传输回 GPU
block = torch.from_numpy(block_np).to(dct_block.device)    # [B,C,k,k]

 第三次尝试:

        根据报错提醒,我进行以下改进,将block_cpu.numpy -> block_cpu.detach.numpy(),即忽略掉torch类型张量带着的梯度信息,哈哈,这样一改,梯度就丢失了,模型就不能反向传播进行更新训练了。

# 图像分块
...
# 将块转移到 CPU
block_cpu = block.cpu()        # [B,C,k,k]
# 在 CPU 上对块应用 DCT
dct_block_np = dct(dct(block_cpu.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')   # [B,C,k,k]
# 将结果传输回 GPU
dct_block = torch.from_numpy(dct_block_np).to(image.device)     # [B,C,k,k]...# 将块转移到 CPU
dct_block_cpu = dct_block.cpu()
# 在 CPU 上对块应用逆 DCT
block_np = idct(idct(dct_block_cpu.detach.numpy(), axis=2, norm='ortho'), axis=3, norm='ortho')
# 将结果传输回 GPU
block = torch.from_numpy(block_np).to(dct_block.device)    # [B,C,k,k]

第四次尝试:
        CPU上库函数不好用,那我自己写(借鉴)DCT变换的函数嘛,DCT就是输入k*k图像关于k*k个余弦基函数的加权和嘛:

 别人写的的8 x 8d的DCT和IDCT的实现:


class DCT8X8(nn.Module):""" Discrete Cosine TransformationInput:image(tensor): batch x height x widthOutput:dcp(tensor): batch x height x width"""def __init__(self):super(DCT8X8, self).__init__()tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)for x, y, u, v in itertools.product(range(8), repeat=4):tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)alpha = np.array([1. / np.sqrt(2)] + [1] * 7)self.tensor = nn.Parameter(torch.from_numpy(tensor).float())self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())def forward(self, image):image = image - 128result = self.scale * torch.tensordot(image, self.tensor, dims=2)result.view(image.shape)return resultclass IDCT8X8(nn.Module):""" Inverse discrete Cosine TransformationInput:dcp(tensor): batch x height x widthOutput:image(tensor): batch x height x width"""def __init__(self):super(IDCT8X8, self).__init__()alpha = np.array([1. / np.sqrt(2)] + [1] * 7)self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)for x, y, u, v in itertools.product(range(8), repeat=4):tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)self.tensor = nn.Parameter(torch.from_numpy(tensor).float())def forward(self, image):image = image * self.alpharesult = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128result.view(image.shape)return result

我根据上述改的任意block_size的DCT和IDCT:

class DCTCustom(nn.Module):"""Customizable Discrete Cosine TransformationInput:image(tensor): batch x height x widthOutput:dct(tensor): batch x height x width"""def __init__(self, input_size=8):super(DCTCustom, self).__init__()self.input_size = input_sizetensor = np.zeros((input_size, input_size, input_size, input_size), dtype=np.float32)for x, y, u, v in itertools.product(range(input_size), repeat=4):tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / (2 * input_size)) * np.cos((2 * y + 1) * v * np.pi / (2 * input_size))alpha = np.array([1. / np.sqrt(2)] + [1] * (input_size - 1))self.tensor = nn.Parameter(torch.from_numpy(tensor).float())self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())def forward(self, image):image = image - 128result = self.scale * torch.tensordot(image, self.tensor, dims=2)result = result.view(image.shape)  # Corrected linereturn resultclass IDCTCustom(nn.Module):""" Inverse discrete Cosine TransformationInput:dcp(tensor): batch x height x widthOutput:image(tensor): batch x height x width"""def __init__(self, block_size=8):super(IDCTCustom, self).__init__()self.block_size = block_size# Compute alpha coefficientsalpha = np.array([1. / np.sqrt(2)] + [1] * (block_size - 1))self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())# Compute tensor for IDCTtensor = np.zeros((block_size, block_size, block_size, block_size), dtype=np.float32)for x, y, u, v in itertools.product(range(block_size), repeat=4):tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / (2 * block_size)) * np.cos((2 * v + 1) * y * np.pi / (2 * block_size))self.tensor = nn.Parameter(torch.from_numpy(tensor).float())def forward(self, image):if image.shape[-2] % self.block_size != 0 or image.shape[-1] % self.block_size != 0:raise ValueError("Input dimensions must be divisible by the block size.")# Apply IDCTimage = image * self.alpharesult = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128result = result.view(image.shape)return result

        不出意外的话,问题又出现了,我对一个torch.ones((2,3,k,k))的张量进行DCT,再IDCT恢复。当k=8时(即block_size=8x8)时,能够完全恢复,但当k!=8(=16、32)时,经IDCT后无法恢复原始输入,懵。

第五次尝试(hh):
        突然!我发现了torch内置的DCT函数!可以再GPU上实现DCT。

torch-dct · PyPI

import torch_dct as dct# 图像分块    # [B,C,H,W]...        # [B,C,k,k]# dctblock = dct.dct_2d(block)     # [B,C,k,k]...# idctblock = dct.idct_2d(block)        # [B,C,k,k]

 然后又有问题了:
        我的模型开始训练后,我发现我的每个epoch的loss都为NAN...

        然后我打印了DCT输出,发现DCT系数长这个样子,CNN不高兴好好训练吧。

        我们再想想办法将输入数据归一化到范围[0, 1]或[-1, 1]之间,再喂给CNN吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/225172.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s面试之——简述网络模型

kubernetes网络模型是kubernetes集群中管理容器网络通信的一种机制,用于实现pod间、pod与外部网络间的通信和互联,并提供了多种网络插件和配置选项来满足不同应用场景下的需求。kubernetes网络模型可以分为一下几个部分: 1. pod网络模型 在…

傻瓜式教学Docker 使用docker compose部署 php nginx mysql

首先你可以准备这个三个服务,也可以在docker compose 文件中 直接拉去指定镜像,这里演示的是镜像服务已经在本地安装好了,提供如下: PHP # 设置基础镜像 FROM php:8.2-fpm# install dependencies RUN apt-get update && apt-get install -y \vim \libzip-dev \libpng…

【docker笔记】docker理论及安装

前言 本笔记来源于尚硅谷docker教学视频 视频地址:https://www.bilibili.com/video/BV1gr4y1U7CY/?spm_id_from333.337.search-card.all.click 纯手打笔记,来之不易,感谢支持~ Docker简介 docker为什么会出现 想象一下:一个应用…

【数据结构初阶】二叉树(2)

二叉树顺序结构 1.二叉树的顺序结构及实现1.1二叉树的顺序结构 1.2 堆的概念及结构1.3 堆的实现1.3.1向上调整1.3.2向下调整1.3.3交换函数1.3.4打印1.3.5初始化1.3.6销毁1.3.7插入1.3.8删除1.3.9获得堆顶元素1.3.10判断是否为空1.3.6 堆的代码实现 1.3.2堆的创建1.3.3 建堆时间…

Tg5032smn:高稳定性105℃高温

TG5032SMN是一款频率范围10MHz ~ 54MHz,具有高稳定的TCXO晶振,可与CMOS或限幅正弦输出。外部尺寸5.0 3.2 1.45mm,超小型,质地轻。该系列晶振的额定工作范围-40℃~﹢105C内可高稳定性工作,使得信号频率的误差很小。TG5032SMN与其他…

在k8s中将gitlab-runner的运行pod调度到指定节点

本篇和前面的 基于helm的方式在k8s集群中部署gitlab 具有很强的关联性,因此如果有不明白的地方可以查看往期分享: 基于helm的方式在k8s集群中部署gitlab - 部署基于helm的方式在k8s集群中部署gitlab - 备份恢复基于helm的方式在k8s集群中部署gitlab - 升…

云原生Kubernetes:K8S集群版本升级(v1.22.14 - v1.23.14)

目录 一、理论 1.K8S集群升级 2.环境 3.升级集群(v1.23.14) 4.验证集群(v1.23.14) 二、实验 1. 环境 2.升级集群(v1.23.14) 2.验证集群(v1.23.14) 一、理论 1.K8S集群升级 …

链表常见题|删除链表、合并链表、环形链表、相交链表、反转链表、回文链表

链表常见题|删除链表、合并链表、环形链表、相交链表、反转链表、回文链表 文章目录 链表常见题|删除链表、合并链表、环形链表、相交链表、反转链表、回文链表2.两数相加19.删除链表的倒数第 N 个结点21.合并两个有序链表141.环形链表142.环形链表 II160.相交链表206.反转链表…

C++-类和对象(1)

1.面向过程和面向对象初步认识 C语言是面向过程的,关注的是过程,分析出求解问题的步骤,通过函数调用逐步解决问题。 C是基于面向对象的,关注的是对象,将一件事情拆分成不同的对象,靠对象之间的交互完 成。…

SpringBoot 3.2.0 基于SpringDoc接入OpenAPI实现接口文档

依赖版本 JDK 17 Spring Boot 3.2.0 SpringDoc 2.3.0 工程源码&#xff1a;Gitee 导入依赖 <properties><maven.compiler.source>17</maven.compiler.source><maven.compiler.target>17</maven.compiler.target><project.build.sourceEnco…

深入剖析LinkedList:揭秘底层原理

文章目录 一、 概述LinkedList1.1 LinkedList简介1.2 LinkedList的优点和缺点 二、 LinkedList数据结构分析2.1 Node节点结构体解析2.2 LinkedList实现了双向链表的原因2.3 LinkedList如何实现了链表的基本操作&#xff08;增删改查&#xff09;2.4 LinkedList的遍历方式 三、 …

SQL server 数据库练习题及答案(练习2)

使用你的名字创建一个数据库 创建表&#xff1a; 数据库中有三张表&#xff0c;分别为student,course,SC&#xff08;即学生表&#xff0c;课程表&#xff0c;选课表&#xff09; 问题&#xff1a; --1.分别查询学生表和学生修课表中的全部数据。--2.查询成绩在70到80分之间…

dhcp的配置

原理 就是服务器&#xff08;路由器或者交换机分配网段&#xff09; 动态分配 接口资源池 全局资源池 静态分配 实验 ar1 ip地址 r1-r3 dhcp en 打开dhcp en 第三步 配置地址池 接口 进入端口 dhcp select interface 设置接口资源池 dhcp server dns-…

5G NR无线蜂窝系统的信道估计器设计

文章目录 DMRS简介DMRS类型DMRS频域密度 信道估计实验仿真实验参数实验实验结论 DMRS简介 DMRS类型 类型A&#xff1a;DMRS位于时隙的第二个或第三个OFDM符号&#xff0c;由14个OFDM符号组成&#xff0c;当数据占据大部分时隙时使用A型映射。 类型B&#xff1a;用在URLLC中&a…

【Mybatis】深入学习MyBatis:概述、主要特性以及配置与映射

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a; Mybatis ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 一、概述 MyBatis简介 主要特性 1. 动态SQL 2.结果映射 3 .插件机制 二、MyBatis配置文件 1.配置文件结构 数据库连…

15 款Python编辑器的优缺点,别再问我“选什么编辑器”

本文介绍了多个 Python IDE&#xff0c;并评价其优缺点。读者可以参考此文列举的 Python IDE 列表&#xff0c;选择适合自己的编辑器。 写 Python 代码最好的方式莫过于使用集成开发环境&#xff08;IDE&#xff09;了。它们不仅能使你的工作更加简单、更具逻辑性&#xff0c;…

Spring Boot整合MyBatis-Plus框架快速上手

最开始&#xff0c;我们要在Java中使用数据库时&#xff0c;需要使用JDBC&#xff0c;创建Connection、ResultSet等&#xff0c;然后我们又对JDBC的操作进行了封装&#xff0c;创建了许多类似于DBUtil等工具类。再慢慢的&#xff0c;出现了一系列持久层的框架&#xff1a;Hiber…

GBase南大通用-GBase 8a资源管理功能试用

环境&#xff1a;centos7.9&#xff1b;GBase 8a V9.5.3.27 资源管理功能简介 GBase南大通用的GBase 8a MPP Cluster 资源管理功能可以对 SELECT 和 DML 等受控 SQL 在运 行过程中使用的 CPU、内存、I/O 和磁盘空间等资源进行合理管控&#xff0c;以达到资 源合理利用&#x…

【toolschain algorithm cpp ros】cpp工厂模式实现--后续填充具体规划算法,控制器版的已填充了算法接入了仿真器

写在前面 现在局势危机&#xff0c;于是想复习一下之前写的设计模式&#xff0c;之前提到&#xff0c;做过一个闭环仿真器&#xff08;借用ros&#xff09;&#xff0c;见https://blog.csdn.net/weixin_46479223/article/details/134864123我的控制器的建立遵循了工厂模式&…

【低照度图像增强系列(2)】Retinex(SSR/MSR/MSRCR)算法详解与代码实现

前言 ☀️ 在低照度场景下进行目标检测任务&#xff0c;常存在图像RGB特征信息少、提取特征困难、目标识别和定位精度低等问题&#xff0c;给检测带来一定的难度。 &#x1f33b;使用图像增强模块对原始图像进行画质提升&#xff0c;恢复各类图像信息&#xff0c;再使用目标检…