PyTorch深度学习实战(26)——卷积自编码器(Convolutional Autoencoder)

PyTorch深度学习实战(26)——卷积自编码器

    • 0. 前言
    • 1. 卷积自编码器
    • 2. 使用 t-SNE 对相似图像进行分组
    • 小结
    • 系列链接

0. 前言

我们已经学习了自编码器 (AutoEncoder) 的原理,并使用 PyTorch 搭建了全连接自编码器,但我们使用的数据集较为简单,每张图像只有一个通道(每张图像都为黑白图像)且图像相对较小 (28 x 28)。但在现实场景中,图像数据通常为彩色图像( 3 个通道)且图像尺寸通常较大。在本节中,我们将实现能够处理多维输入图像的卷积自编码器,为了与普通自编码器进行对比,同样使用 MNIST 数据集。

1. 卷积自编码器

与传统的全连接自编码器不同,卷积自编码器 (Convolutional Autoencoder) 利用卷积层和池化层替代了全连接层,以处理具有高维空间结构的图像数据。这样的设计使得卷积自编码器能够在较少的参数量下对输入数据进行降维和压缩,同时保留重要的空间特征。卷积自编码器架构如下所示:

卷积自编码器

从上图中可以看出,输入图像被表示为瓶颈层中的潜空间变量,用于重建图像。图像经过多次卷积(编码器)得到低维潜空间表示,然后在解码器中,将潜空间变量还原为原始尺寸,使解码器的输出能够近似恢复原始输入。
本质上,卷积自编码器在其网络中使用卷积、池化操作来代替原始自编码器的全连接操作,并使用反卷积操作 (Conv2DTranspose) 对特征图进行上采样。了解卷积自编码器的原理后,使用 PyTorch 实现此架构。

(1) 数据集的加载和构建方式与全连接自编码器完全相同:

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import numpy as np
from matplotlib import pyplot as plt
device = 'cuda' if torch.cuda.is_available() else 'cpu'img_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5], [0.5]),transforms.Lambda(lambda x: x.to(device))
])trn_ds = MNIST('MNIST/', transform=img_transform, train=True, download=True)
val_ds = MNIST('MNIST/', transform=img_transform, train=False, download=True)batch_size = 256
trn_dl = DataLoader(trn_ds, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

(2) 定义神经网络类 ConvAutoEncoder

定义 __init__ 方法:

class ConvAutoEncoder(nn.Module):def __init__(self):super().__init__()

定义编码器架构:

        self.encoder = nn.Sequential(nn.Conv2d(1, 32, 3, stride=3, padding=1), nn.ReLU(True),nn.MaxPool2d(2, stride=2),nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(True),nn.MaxPool2d(2, stride=1))

在以上代码中,通道数最初由 1 开始,逐渐增加到 64,同时通过 nn.MaxPool2dnn.Conv2d 操作减小输入图像尺寸。

定义解码器架构:

        self.decoder = nn.Sequential(nn.ConvTranspose2d(64, 32, 3, stride=2), nn.ReLU(True),nn.ConvTranspose2d(32, 16, 5, stride=3, padding=1), nn.ReLU(True),nn.ConvTranspose2d(16, 1, 2, stride=2, padding=1), nn.Tanh())

定义前向传播方法 forward

    def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x

(3) 使用 summary 方法获取模型摘要信息:

model = ConvAutoEncoder().to(device)
from torchsummary import summary
summary(model, (1,28,28))
输出结果如下所示:
```shell
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1           [-1, 32, 10, 10]             320ReLU-2           [-1, 32, 10, 10]               0MaxPool2d-3             [-1, 32, 5, 5]               0Conv2d-4             [-1, 64, 3, 3]          18,496ReLU-5             [-1, 64, 3, 3]               0MaxPool2d-6             [-1, 64, 2, 2]               0ConvTranspose2d-7             [-1, 32, 5, 5]          18,464ReLU-8             [-1, 32, 5, 5]               0ConvTranspose2d-9           [-1, 16, 15, 15]          12,816ReLU-10           [-1, 16, 15, 15]               0ConvTranspose2d-11            [-1, 1, 28, 28]              65Tanh-12            [-1, 1, 28, 28]               0
================================================================
Total params: 50,161
Trainable params: 50,161
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.14
Params size (MB): 0.19
Estimated Total Size (MB): 0.34
----------------------------------------------------------------

从以上模型架构信息可以看出,使用尺寸为 batch size x 64 x 2 x 2MaxPool2d-6 层作为瓶颈层。

模型训练过程,训练和验证损失随时间的变化以及对输入图像的重建结果如下:

def train_batch(input, model, criterion, optimizer):model.train()optimizer.zero_grad()output = model(input)loss = criterion(output, input)loss.backward()optimizer.step()return loss@torch.no_grad()
def validate_batch(input, model, criterion):model.eval()output = model(input)loss = criterion(output, input)return lossmodel = ConvAutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)num_epochs = 20
train_loss_epochs = []
val_loss_epochs = []
for epoch in range(num_epochs):N = len(trn_dl)trn_loss = []val_loss = []for ix, (data, _) in enumerate(trn_dl):loss = train_batch(data, model, criterion, optimizer)pos = (epoch + (ix+1)/N)trn_loss.append(loss.item())train_loss_epochs.append(np.average(trn_loss))N = len(val_dl)for ix, (data, _) in enumerate(val_dl):loss = validate_batch(data, model, criterion)pos = epoch + (1+ix)/Nval_loss.append(loss.item())val_loss_epochs.append(np.average(val_loss))epochs = np.arange(num_epochs)+1
plt.plot(epochs, train_loss_epochs, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs, 'r-', label='Test loss')
plt.title('Training and Test loss over increasing epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.show()for _ in range(5):ix = np.random.randint(len(val_ds))im, _ = val_ds[ix]_im = model(im[None])[0]plt.subplot(121)# fig, ax = plt.subplots(1,2,figsize=(3,3)) plt.imshow(im[0].detach().cpu(), cmap='gray')plt.title('input')plt.subplot(122)plt.imshow(_im[0].detach().cpu(), cmap='gray')plt.title('prediction')plt.show()

模型性能监测图像重建结果图像重建结果

从上图中,我们可以看到卷积自编码器重建后的图像比全连接自编码器更清晰,可以通过改变编码器和解码器中的通道数,观察模型训练结果。在下一节中,我们将根据瓶颈层潜变量对相似图像进行分组

2. 使用 t-SNE 对相似图像进行分组

假设相似的图像具有相似的潜变量(也称嵌入),而不相似的图像具有不同的潜变量,使用自编码器,可以在低维空间中表示图像。接下来,我们继续学习图像的相似度度量,在二维空间中绘制潜变量,使用 t-SNE 技术将卷积自编码器的 64 维向量缩减至到 2 维空间。
2 维空间中,我们可以方便的可视化潜变量,以观察相似图像是否具有相似的潜变量,相似图像在二维平面中应该聚集在一起。接下里,我们在二维平面中表示所有测试图像的潜变量。

(1) 初始化列表,以便存储潜变量 (latent_vectors) 和相应的图像类别(存储每个图像的类别只是为了验证同一类别的图像是否具有较高的相似性,并不会在训练过程使用):

latent_vectors = []
classes = []

(2) 遍历验证数据加载器 (val_dl) 中的图像,并存储编码器的输出 (model.encoder(im).view(len(im),-1)) 和每个图像 (im) 对应的类别 (clss):

for im,clss in val_dl:latent_vectors.append(model.encoder(im).view(len(im),-1))classes.extend(clss)

(3) 连接潜变量 (latent_vectors) NumPy 数组:

latent_vectors = torch.cat(latent_vectors).cpu().detach().numpy()

(4) 导入 t-SNE 库 (TSNE),并将潜变量转换为二维向量 (TSNE(2)) ,以便进行绘制:

from sklearn.manifold import TSNE
tsne = TSNE(2)

(5) 通过在图像潜变量 (latent_vectors) 上运行 fit_transform 方法来拟合 t-SNE

clustered = tsne.fit_transform(latent_vectors)

(6) 拟合 t-SNE 后绘制数据点:

fig = plt.figure(figsize=(12,10))
cmap = plt.get_cmap('Spectral', 10)
plt.scatter(*zip(*clustered), c=classes, cmap=cmap)
plt.colorbar(drawedges=True)
plt.show()

聚类结果

可以看到同一类别的图像能够聚集在一起,即相似的图像将具有相似的潜变量值。

小结

卷积自编码器是一种基于卷积神经网络结构的自编码器,适用于处理图像数据。卷积自编码器在图像处理领域有广泛的应用,包括图像去噪、图像压缩、图像生成等任务。通过训练卷积自编码器,可以提取出输入图像的关键特征,并实现对图像数据的降维和压缩,同时保留重要的空间信息。在本节中,我们介绍了卷积自编码器的模型架构,使用 PyTorch 从零开始实现在 MNIST 数据集上训练了一个简单的卷积自编码器,并使用 t-SNE 技术在二维平面中表示了所有测试图像的潜变量。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——使用U-Net架构进行图像分割
PyTorch深度学习实战(24)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(25)——自编码器(Autoencoder)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/221638.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

node.js mongoose middleware

目录 官方文档 简介 定义模型 注册中间件 创建doc实例,并进行增删改查 方法名和注册的中间件名相匹配 执行结果 分析 错误处理中间件 手动抛出错误 注意点 官方文档 Mongoose v8.0.3: Middleware 简介 在mongoose中,中间件是一种允许在执…

vue的自定义指令注册使用

目录 分类 局部注册 全局注册 使用例子 分类 自定义指令的注册分为局部注册和全局注册 局部注册是在某个组件内注册的指令,只能在这个组件内使用 全局注册是在main.js中注册的指令在任何组件内都可以使用,指令在使用时不论是全局还是局部注册的&am…

机器学习 | 贝叶斯方法

不同于KNN最近邻算法的空间思维,线性算法的线性思维,决策树算法的树状思维,神经网络的网状思维,SVM的升维思维。 贝叶斯方法强调的是 先后的因果思维。 监督式模型分为判别式模型和生成式模型。 判别模型和生成模型的区别&#xf…

Spring MVC框架支持RESTful,设计URL时可以使用{自定义名称}的占位符@Get(“/{id:[0-9]+}/delete“)

背景:在开发实践中,如果没有明确的规定URL,可以参考: 传统接口 获取数据列表,固定接口路径:/数据类型的复数 例如:/albums/select RESTful接口 - 根据ID获取某条数据:/数据类型的复数/{id} - 例…

智能化物联网(IoT):发展、问题与未来前景

导言 智能化物联网(IoT)作为信息技术领域的一项核心技术,正在深刻改变人们的生活和工作方式。本文将深入研究IoT的发展过程、遇到的问题及解决过程、未来的可用范围,以及在各国的应用和未来的研究趋势。探讨在哪些方面能够取得胜利…

将Abp默认事件总线改造为分布式事件总线

文章目录 原理创建分布式事件总线实现自动订阅和事件转发 使用启动Redis服务配置传递Abp默认事件传递自定义事件 项目地址 原理 本地事件总线是通过Ioc容器来实现的。 IEventBus接口定义了事件总线的基本功能,如注册事件、取消注册事件、触发事件等。 Abp.Events…

Keil编译STM32工程,提示__align(4)处语法错误

好久没有用Keil编程,因为别人的代码是用Keil写的,所以又得安装起来,编译时遇到__align(4)的错误提示。 这个问题主要是编译器版本的问题,默认使用的是v6.19版本的编译器,而工程原来使用的是v5版本的,两个编…

B039-SpringMVC基础

目录 SpringMVC简介复习servletSpringMVC入门导包配置前端控制器编写处理器实现Contoller接口普通类加注解(常用) 路径问题获取参数的方式过滤器简介自定义过滤器配置框架提供的过滤器 springMVC向页面传值的三种方式视图解析器springMVC的转发和重定向 SpringMVC简介 1.Sprin…

Windows如何安装使用TortoiseSVN客户端并实现公网访问本地SVN Server

文章目录 前言1. TortoiseSVN 客户端下载安装2. 创建检出文件夹3. 创建与提交文件4. 公网访问测试 前言 TortoiseSVN是一个开源的版本控制系统,它与Apache Subversion(SVN)集成在一起,提供了一个用户友好的界面,方便用…

接口测试和测试用例分析

只要有软件产品的公司百分之九十以上都会做接口测试,要做接口测试的公司那是少不了接口测试工程师的,接口测试工程师相对于其他的职位又比较轻松并且容易胜任。如果你想从事接口测试的工作那就少不了对接口进行分析,同时也会对测试用例进行研…

解决 Hive 外部表分隔符问题的实用指南

简介: 在使用 Hive 外部表时,分隔符设置不当可能导致数据导入和查询过程中的问题。本文将详细介绍如何解决在 Hive 外部表中正确设置分隔符的步骤。 问题描述: 在使用Hive外部表时,可能会遇到分隔符问题。这主要是因为Hive在读…

同一个数组中对象去重

封装方法 fn1 (tempArr) {this.echartList.map(item > {for (let i 0; i < item.data.length; i) {for (let j i 1; j < item.data.length; j) {if (item.data[i].deviceId item.data[j].deviceId && item.data[i].time item.data[j].time && it…

解决win10下强制设置web浏览器为microsoft edge的方法

目录 问题场景实现方法禁止edge默认选项设置默认浏览器 反思 问题场景 因为一些特殊的原因&#xff0c;我需要第二个浏览器&#xff0c;我的第一个浏览器是google的chrome浏览器&#xff0c;所以我选择的是windows的默认浏览器&#xff0c;就是microsoft edge浏览器&#xff0…

easyExcel生成excel并导出自定义样式------添加复杂表头

easyExcel生成excel并导出自定义样式------添加复杂表头 设置合并竖行单元格&#xff0c;表头设置 OutputStream outputStream ExcelUtils.getResponseOutputStream(response, fileName);//根据数据组装需要合并的单元格Map<String, List<String>> strategyMap …

React学习计划-React16--React基础(二)组件与组件的3大核心属性state、props、ref和事件处理

1. 组件 函数式组件&#xff08;适用于【简单组件】的定义&#xff09; 示例&#xff1a; 执行了ReactDOM.render(<MyComponent/>, ...)之后执行了什么&#xff1f; React解析组件标签&#xff0c;找到了MyComponent组件发现组件是使用函数定义的&#xff0c;随后调用该…

kafka文件存储机制

Topic分为好几个partition分区&#xff0c;每个分区对应于一个log文件&#xff0c;log文件其实是虚的&#xff0c;Kafka采取了分片和索引机制&#xff0c; 将每个partition分为多个segment&#xff08;大小为1G&#xff09;。每个segment包括&#xff1a;“.index”文件、“.lo…

java SSM健身跑步爱好者社区系统myeclipse开发mysql数据库springMVC模式java编程计算机网页设计

一、源码特点 java SSM健身跑步爱好者社区系统是一套完善的web设计系统&#xff08;系统采用SSM框架进行设计开发&#xff0c;springspringMVCmybatis&#xff09;&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整 的源代码和数据库&#xff0c;系统…

Linux 操作系统(Vim)

vim 编译器&#xff08;相当于windows中记事本&#xff09; 当在终端窗口直接运行vim命令&#xff0c;会出现以下截图&#xff08;类似手册对vim编译器简单的介绍&#xff09;&#xff1a; vim提供三种基本工作模式&#xff1a; 命令模式(默认模式) 插入模式 末行模式 创建文本…

Qt前端技术:2.QSS

border-style&#xff1a;后边是两个参数的话第一个参数改变上下的style 第二个参数改变左右的style 如果后边是三个参数的话第一个参数改变上边的style第二个参数改变左右的style&#xff0c;第三个参数改变的下边的style 如果后边是四个参数的话对应的顺序为上&#xff0c;右…

C# WPF上位机开发(进度条操作)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 软件上面如果一个操作比较缓慢&#xff0c;或者说需要很长的时间&#xff0c;那么这个时候最好添加一个进度条&#xff0c;提示一下当前任务的进展…