深度学习中降维的几种方法

在这里插入图片描述

笔者在搞网络的时候碰到个问题,就是将特征维度从1024降维到268,那么可以通过哪些深度学习方法来实现呢?

文章目录

  • 1. 卷积层降维
  • 2. 全连接层降维
  • 3. 使用注意力机制
  • 4. 使用自编码器

1. 卷积层降维

可以使用1x1卷积层(也叫pointwise卷积)来减少通道数。这种方法保留了特征图的空间维度(宽度和高度),同时减少了通道数。

import torch
import torch.nn as nnclass ReduceDim(nn.Module):def __init__(self, in_channels, out_channels):super(ReduceDim, self).__init__()self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)def forward(self, x):return self.conv1x1(x)# 假设输入的特征图为 (bs, 1024, 28, 28)
x = torch.randn(56, 1024, 28, 28)
model = ReduceDim(1024, 268)
output = model(x)
print(output.shape)  # 输出形状应为 (56, 268, 28, 28)

2. 全连接层降维

可以将特征图展平为一个向量,然后使用全连接层(线性层)来降维。这种方法适用于特征图的全局降维。

class ReduceDimFC(nn.Module):def __init__(self, in_channels, out_channels, width, height):super(ReduceDimFC, self).__init__()self.fc = nn.Linear(in_channels * width * height, out_channels * width * height)self.width = widthself.height = heightdef forward(self, x):bs, c, w, h = x.shapex = x.view(bs, -1)x = self.fc(x)x = x.view(bs, out_channels, self.width, self.height)return x# 假设输入的特征图为 (bs, 1024, 28, 28)
x = torch.randn(56, 1024, 28, 28)
model = ReduceDimFC(1024, 268, 28, 28)
output = model(x)
print(output.shape)  # 输出形状应为 (56, 268, 28, 28)

3. 使用注意力机制

可以使用基于注意力机制的方法来降维。例如,可以使用Transformer编码器或自注意力机制来实现降维。

import torch
import torch.nn as nnclass ReduceDimAttention(nn.Module):def __init__(self, in_channels, out_channels):super(ReduceDimAttention, self).__init__()self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=8)self.fc = nn.Linear(in_channels, out_channels)def forward(self, x):bs, c, w, h = x.shapex = x.view(bs, c, -1).permute(2, 0, 1)  # (w*h, bs, c)x, _ = self.attention(x, x, x)x = x.permute(1, 2, 0).view(bs, c, w, h)x = self.fc(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)return x# 假设输入的特征图为 (bs, 1024, 28, 28)
x = torch.randn(56, 1024, 28, 28)
model = ReduceDimAttention(1024, 268)
output = model(x)
print(output.shape)  # 输出形状应为 (56, 268, 28, 28)

4. 使用自编码器

可以训练一个自编码器网络来学习降维。自编码器由编码器和解码器组成,通过最小化重建误差来学习紧凑的表示。


class Encoder(nn.Module):def __init__(self, in_channels, out_channels):super(Encoder, self).__init__()self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(512, out_channels, kernel_size=3, padding=1)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))return xclass Decoder(nn.Module):def __init__(self, in_channels, out_channels):super(Decoder, self).__init__()self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(512, out_channels, kernel_size=3, padding=1)def forward(self, x):x = torch.relu(self.conv1(x))x = torch.relu(self.conv2(x))return xclass Autoencoder(nn.Module):def __init__(self, in_channels, bottleneck_channels, out_channels):super(Autoencoder, self).__init__()self.encoder = Encoder(in_channels, bottleneck_channels)self.decoder = Decoder(bottleneck_channels, out_channels)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x# 假设输入的特征图为 (bs, 1024, 28, 28)
x = torch.randn(56, 1024, 28, 28)
model = Autoencoder(1024, 268, 1024)
encoded = model.encoder(x)
print(encoded.shape)  # 输出形状应为 (56, 268, 28, 28)

以上方法都是有效的深度学习降维技术,可以根据具体的需求和应用场景选择合适的方法。Enjoy~

∼ O n e p e r s o n g o f a s t e r , a g r o u p o f p e o p l e c a n g o f u r t h e r ∼ \sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim One person go faster, a group of people can go further

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/393167.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《大道平渊》· 拾柒 —— 个人的心理定位决定市场

《大道平渊》 拾柒 个人的心理定位决定市场。 对于个人定位来说,个人的心理定位影响你的行为。 比如我的心理定位是经营者,那我的行为则是满足市场需求和解决问题。 因为心理定位的不同,会影响你思考问题的角度。 . 以上皆为个人思考&am…

【为什么不要买运营商的机顶盒?解锁智能电视新体验,从一台刷机机顶盒开始】

【置顶:机顶盒刷机步骤请跳转此链接】 在这个数字化飞速发展的时代,电视早已不再是单一的播放工具,它正逐步演变成为家庭娱乐与信息获取的综合中心。然而,许多家庭在选择机顶盒时,往往会因为惯性或便利而直接选择运营商提供的机顶…

常见中间件漏洞(三、Jboss合集)

目录 三、Jboss Jboss介绍 3.1 CVE-2015-7501 漏洞介绍 影响范围 环境搭建 漏洞复现 3.2 CVE-2017-7504 漏洞介绍 影响范围 环境搭建 漏洞复现 3.3 CVE-2017-12149 漏洞简述 漏洞范围 漏洞复现 3.4 Administration Console弱囗令 漏洞描述 影响版本 环境搭建…

【多线程-从零开始-伍】volatile关键字和内存可见性问题

volatile 关键字 import java.util.Scanner; public class Demo2 { private static int n 0; public static void main(String[] args) { Thread t1 new Thread(() -> { while(n 0){ //啥都不写 } System.out.println("t1 线程结束循环"); }, "…

C++类和对象——中

1. 类的默认成员函数 默认成员函数就是⽤⼾没有显式实现,编译器会⾃动⽣成的成员函数称为默认成员函数。⼀个类,我们不写的情况下编译器会默认⽣成以下6个默认成员函数,需要注意的是这6个中最重要的是前4个,最后两个取地址重载不…

差分专题的练习

神经&#xff0c;树状数组做多了一开始还想着用树状数组来查询差分数组&#xff0c;但是我们要进行所有元素的查询&#xff0c;直接过一遍就好啦 class Solution { public:int numberOfPoints(vector<vector<int>>& nums) {vector<int> c(105, 0);for (i…

Leetcode—233. 数字 1 的个数【困难】

2024每日刷题&#xff08;152&#xff09; Leetcode—233. 数字 1 的个数 算法思想 参考自k神 实现代码 class Solution { public:int countDigitOne(int n) {long digit 1;long high n / 10;long low 0;long cur n % 10;long ans 0;while(high ! 0 || cur ! 0) {if(cu…

多线程用不用ArrayList?

​ 博客主页: 南来_北往 系列专栏&#xff1a;Spring Boot实战 引言 多线程使用是指在单个程序中同时运行多个线程来完成不同的工作。 多线程是计算机领域的一个重要概念&#xff0c;它允许一个程序中的多个代码片段&#xff08;称为线程&#xff09;同时运行&#xff0…

Cache结构

Cache cache的一般设计 超标量处理器每周期需要从Cache中同时读取多条指令&#xff0c;同时每周期也可能有多条load/store指令会访问Cache&#xff0c;因此需要多端口的Cache L1 Cache&#xff1a;最靠近处理器&#xff0c;是流水线的一部分&#xff0c;包含两个物理存在 指…

鲜花销售小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;商家管理&#xff0c;鲜花信息管理&#xff0c;鲜花分类管理&#xff0c;管理员管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;购物车&#xff0…

Linux 操作系统速通

一、安装虚拟机 1. VmWare 安装下载 vmware workstation pro 16 下载 win R 输入 ncpa.cpl 确保网卡正常 2. CentOS 系统下载 CentOS 系统下载 将 CentOS 系统安装到虚拟机 3. 查看虚拟机 IP 命令 ifconfig 4. finalShell 安装下载 finalShell 下载 输入用户名一般是 ro…

Html实现全国省市区三级联动

目录 前言 1.全国省市区的Json数据 2.找到Json数据文件(在此博文绑定资源)之后&#xff0c;放到resource目录下。 3.通过类加载器加载资源文件&#xff0c;读取Json文件 3.1 创建JsonLoader类 3.2 注入JsonLoader实体&#xff0c;解析Json文件 4.构建前端Html页面 5.通过…

如何在国外市场推广中国游戏

在国外市场推广中国游戏需要一种考虑文化差异、市场偏好和有效营销渠道的战略方法。以下是成功向国际观众介绍和推广中国游戏的关键步骤和策略&#xff1a; 进行市场调研 了解目标市场&#xff1a;首先确定哪些外国市场对你的游戏最具潜力。考虑类似游戏类型的受欢迎程度、玩…

通过数组中元素或者key将数组拆分归类成新的二维数组

处理前的数组: 处理后的数组: 你希望根据 riqi 字段将这个数组拆分成多个二维数组,每个二维数组包含相同日期的项。在ThinkPHP中,你可以使用PHP的数组操作来实现这一拆分操作。以下是如何按照 riqi 字段拆分成新的二维数组的示例代码: $splitArrays = [];foreach ($list…

YOLOv6训练自己的数据集

文章目录 前言一、YOLOv6简介二、环境搭建三、构建数据集四、修改配置文件①数据集文件配置②权重下载③模型文件配置 五、模型训练和测试模型训练模型测试 总结 前言 提示&#xff1a;本文是YOLOv6训练自己数据集的记录教程&#xff0c;需要大家在本地已配置好CUDA,cuDNN等环…

opencascade TopoDS、TopoDS_Vertex、TopoDS_Edge、TopoDS_Wire、源码学习

前言 opencascade TopoDS转TopoDS_Vertex opencascade TopoDS转TopoDS_Edge opencascade TopoDS转TopoDS_Wire opencascade TopoDS转TopoDS_Face opencascade TopoDS转TopoDS_Shell opencascade TopoDS转TopoDS_Solid opencascade TopoDS转TopoDS_Compound 提供方法将 TopoDS_…

Spring快速学习

目录 IOC控制反转 引言 IOC案例 Bean的作用范围 Bean的实例化 bean生命周期 DI 依赖注入 setter注入 构造器注入 自动装配 自动装配的方式 注意事项; 集合注入 核心容器 容器的创建方式 Bean的三种获取方式 Bean和依赖注入相关总结 IOC/DI注解开发 注解开发…

抽象代数精解【8】

文章目录 希尔密码矩阵矩阵基本概念行列式基本概念特殊矩阵关于乘法运算构成群 加解密原理密钥加密函数解密函数 Z 26 上的运算&#xff08; Z 256 与此类似&#xff09; Z_{26}上的运算&#xff08;Z_{256}与此类似&#xff09; Z26​上的运算&#xff08;Z256​与此类似&…

sql注入知识整理

sql注入知识整理 一、SQL注入概念 SQL注入就是用户输入的一些语句没有被过滤&#xff0c;输入后诸如这得到了数据库的信息SQL 注入是一种攻击方式&#xff0c;在这种攻击方式中&#xff0c;在字符串中插入恶意代码&#xff0c;然后将该字符串传递到 SQL Server 数据库引擎的实…

递归.python

目录 一、认识递归 二、阶乘问题 三、经典例题&#xff1a;汉诺塔问题 一、认识递归 递归&#xff1a;即方法&#xff08;函数&#xff09;自己调用自己的一种特殊编程写法。 函数调用自己&#xff0c;即称之为递归调用。 def func(): If ....: func() return ..... 递归…