使用AutoDecoder自动解码器实现简单MNIST特征向量提取

AutoDecoder

自动解码器(AD)是论文"DeepSDF: Learning Continuous Signed Distance Functions for Shape Representation" 中使用的一种方法,与传统编码-解码结构不同,AD无编码器,仅有一个解码器。解码器实现特征向量(隐向量)与图片之间的转换

在训练过程中同时优化特征向量与神经网络参数。如果训练集有N张图片,特征向量长度为n,神经网络参数为m,那么待训练参数共有N*n+m个。

训练完成之后,任给一个特征向量,输入解码器,则可得到一张图片。

DeepSDF原文更为复杂,使用AD生成带符号的最小距离场,进而实现3D形状的生成。DeepSDF原文还有一个针对MNIST手写数据集的简单的案例,但是没有给出源代码,难以上手。笔者在Github上找到了一个具体实现代码https://github.com/alexeybokhovkin/DeepSDF,在此基础上作了一些修改完善,撰写了这篇博客,以作记录,希望对读者有所帮助。
在这里插入图片描述

代码

项目共有四个文件,dataset.py用于定义数据集,evaluate.py用于评估训练后的神经网络,network.py用于定义神经网络结构,train.py用于训练神经网络
在这里插入图片描述

1 dataset.py

用于导入MNIST数据集(或FashionMNIST)数据集,如果本地没有则会从互联网下载

import torchvision
from torch.utils.data.dataset import Dataset# A wrapper dataset over MNIST to return images and indices
class DatasetMNIST(Dataset):def __init__(self, root_dir, latent_size, transform=None):mnist = torchvision.datasets.FashionMNIST(root=root_dir, train=True,download=True)self.data = mnist.train_data.float()/255.0def __len__(self):return len(self.data)def __getitem__(self, index):image = self.data[index]return image.flatten(), index

2 network.py

定义了一个全连接神经网络,输入特征向量,输出图片(展平为向量)。为了便于观察结果,输入的特征向量维数为2。需要注意的是,特征向量也作为被训练参数,共N*n个元素。

import torch
import torch.nn as nn
import torch.nn.init as init# Autodecoder structure
class AD(nn.Module):def __init__(self, image_size=784, z_dim=2, data_shape=60000):super(AD, self).__init__()self.decoder = nn.Sequential(nn.Linear(z_dim, 128),nn.ReLU(True),nn.Linear(128, 256),nn.ReLU(True),nn.Linear(256, 512),nn.ReLU(True), nn.Linear(512, 28 * 28), nn.Tanh())self.latent_vectors = nn.Parameter(torch.FloatTensor(data_shape, z_dim))init.xavier_normal(self.latent_vectors)def forward(self, ind):x = self.latent_vectors[ind]return self.decoder(x)def predict(self, x):return self.decoder(x)

3 train.py

使用Adams训练神经网络,batch_size=128num_epochs=250。训练只需几分钟,完成后保存为model.pth文件,以供调用

import os
import numpy as np
import torch
import torch.nn as nn
from torchvision.utils import save_image
from dataset import DatasetMNIST
from network import AD
from tqdm import tqdm# Hyper-parameters
image_width = 28
image_size = image_width*image_width
h_dim = 512
num_epochs = 250
batch_size = 128
learning_rate = 1e-3
latent_size = 2if __name__ == "__main__":os.environ['CUDA_VISIBLE_DEVICES'] = '0'# Create a directory if not existssample_dir = 'samples'os.makedirs(sample_dir, exist_ok=True)dataset = DatasetMNIST(root_dir='./data', latent_size=latent_size)# Data loaderdata_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)model = AD(image_size=image_size, z_dim=latent_size, data_shape=60000).cuda()# recusntruction losscriterion = nn.MSELoss()optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)model.train()for epoch in range(num_epochs):tq = tqdm(total=len(data_loader))tq.set_description('Epoch {}'.format(epoch))for i, (x, ind) in enumerate(data_loader):# Forward passx = x.cuda()x_reconst = model(ind)loss = criterion(x_reconst, x)optimizer.zero_grad()loss.backward()optimizer.step()tq.update()tq.set_postfix(loss='{:.3f}'.format(loss.item()))if epoch%5 == 0:with torch.no_grad():# Visualize 2D latent spacesteps = 50bound = 0.8size = image_widthout = torch.zeros(size=(steps * size, steps * size))for i, l1 in enumerate(np.linspace(-bound, bound, steps)):for j, l2 in enumerate(np.linspace(-bound, bound, steps)):vector = torch.tensor([l1, l2]).to(dtype=torch.float32).cuda()out_ = model.predict(vector)out[i * size:(i + 1) * size, j * size:(j + 1)* size] = out_.view(size, size)save_image(out, os.path.join(sample_dir, 'latent_space-{}.png'.format(epoch + 1)))# save modeltorch.save(model, 'model.pth')

4 evaluate.py

加载训练好的模型,遍历特征向量[l1,l2],使用解码器生成对应的图像,保存为latent_space-eval.png

将所有训练集的特征向量[l1,l2]绘制在一张图上,保存为latent_space-distribution.png

最后将两张图片拼接得到latent_space-merged.png

import os
import torch
import numpy as np
from torchvision.utils import save_image
from train import image_widthif __name__ == "__main__":# Create a directory if not existssample_dir = 'samples'os.makedirs(sample_dir, exist_ok=True)# 选择GPU或CPUdevice = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")# 从文件加载已经训练完成的模型model = torch.load('model.pth', map_location=device)model.eval()  # 设置模型为evaluation状态print(model)# Visualize 2D latent spacesteps = 50bound = 0.8size = image_widthout_grid = torch.zeros(size=(steps * size, steps * size))for i, l1 in enumerate(np.linspace(-bound, bound, steps)):for j, l2 in enumerate(np.linspace(-bound, bound, steps)):vector = torch.tensor([l1, l2]).to(dtype=torch.float32).cuda()out_ = model.predict(vector)out_grid[i * size:(i + 1) * size, j * size:(j + 1)* size] = out_.view(size, size)save_image(out_grid, os.path.join(sample_dir, 'latent_space-eval.png'))out_dist = torch.zeros(size=(steps * size, steps * size))latent_vectors_scaled=model.latent_vectors.cpu().detach().numpy()latent_vectors_scaled=np.clip(latent_vectors_scaled, -bound+0.005, bound-0.005)latent_vectors_scaled = ((latent_vectors_scaled+bound)/(2.0*bound)*steps*size*1.0)for i in range(len(latent_vectors_scaled)):l1=round(latent_vectors_scaled[i][0])l2=round(latent_vectors_scaled[i][1])out_dist[l1, l2]=1.0out_dist[l1-1, l2]=1.0out_dist[l1+1, l2]=1.0out_dist[l1, l2-1]=1.0out_dist[l1-1, l2-1]=1.0out_dist[l1+1, l2-1]=1.0out_dist[l1, l2+1]=1.0out_dist[l1-1, l2+1]=1.0out_dist[l1+1, l2+1]=1.0save_image(out_dist, os.path.join(sample_dir, 'latent_space-distribution.png'))out_merged = torch.cat((out_grid, out_dist), dim=1)save_image(out_merged, os.path.join(sample_dir, 'latent_space-merged.png'))print(model.latent_vectors.max(), model.latent_vectors.min())

结果

训练前遍历特征向量绘制得到解码后的图片:
在这里插入图片描述
epoch=16, 遍历特征向量绘制得到解码后的图片:

在这里插入图片描述
epoch=101, 遍历特征向量绘制得到解码后的图片:
在这里插入图片描述epoch=250, 遍历特征向量绘制得到解码后的图片:
在这里插入图片描述训练样本分布如下:
在这里插入图片描述

结论

(1) AD事实上是一种特征提取方法,本文从数据集中提取了一个2D特征,在2D平面内重构出了原始数据集。在实际使用中,特征向量是一个高维向量,效果会更好。
(2) 特征向量的分布近似于一个正态分布,但是不同类别之间存在鸿沟
(3)AD的本质是对图片数据进行压缩,图片公有信息蕴含于神经网络参数中,个体信息蕴含于特征向量。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/234264.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

K-means 聚类算法分析

算法简述 K-means 算法原理 我们假定给定数据样本 X ,包含了 n 个对象 ,其中每一个对象都具有 m 个维度的属性。而 K-means 算法的目标就是将 n 个对象依据对象间的相似性聚集到指定的 k 个类簇中,每个对象属于且仅属于一个其到类簇中心距离…

Python中调用matplotlib库三维可视化图像像素曲面分布

为了更直观的从3D视角观察一副图像的像素分布,且拖动观察没一个像素细节,可以使用下面代码实现。 目录 一、代码二、效果展示 一、代码 使用代码修改修改的地方如下: 具体实现代码如下: import numpy as np import matplotlib.pyplot as …

windows系统如何查看扇区?

windows系统如何查看扇区? 首先,我们按WindowsR 弹出"运行"对话框,打开文本框输入"MSINFO32.EXE"命令 展开左侧"组件"节点 接下来,我们选择"组件|存储|磁盘"文件夹 在其里面即可查看硬盘…

欧科云链研究院:奔赴2024,Web3与AI共振引爆数字时代潘多拉魔盒

出品|欧科云链研究院 2024年,Web3与AI两个数字科技的巅峰碰撞,欧科云链研究院探索AI与Web3的技术融合,与澎湃科技联合发布2024年展望,原标题为《2024年展望:Web3与AI共振引爆可信数字社会》,共…

Linux系统下gitee使用git提交代码

Linux系统下gitee使用git提交代码 一、安装配置git1.1 在 Linux 中安装 git,并生成授信证书1.2 将SSH key 添加到 ssh-agent1.2 将SSH key 添加到你的gitee账户 二、gitee 的使用2.1 下载项目到本地 三、上传gitee三步走3.1 三板斧第一招:git add3.2 三板…

编译原理期末大题步骤——例题

一、预测分析方法步骤 提取左公因子,消除左递归判断文法是否为LL(1)文法若是,构造预测分析表;否则,不能进行分析。根据预测分析表对输入串进行分析 例子: 文法G[E]: …

【Python】不一样的Ansible(一)

不一样的Ansible——进阶学习 前言正文概念Ansible CorePlugins和Modules 插件插件类型编写自定义插件基本要求插件选项文档标准编写插件 添加一个本地插件注册为内置插件指定插件目录 其他一些技巧更改Strategy 结语 前言 Ansible 是一个极其简单的 IT 自动化引擎&#xff0c…

ros gazebo机械臂仿真,手动控制与MoveIt自动控制

本文总结归纳古月居胡春旭ros机械臂教程,给出了一些error的解决方法,补充了通过python运行moveit。十分建议去看github huchunxu源代码的repository。 创建机械臂的xacro模型 首先创建一个工作空间,在工作空间中创建arm_description功能包。…

GitHub 一周热点汇总 第4期 (2024/01/01-01/06)

GitHub一周热点汇总第四期 (2023/12/24-12/30),梳理每周热门的GitHub项目,了解热点技术趋势,掌握前沿科技方向,发掘更多商机。2024年到了,希望所有的朋友们都能万事顺遂。 说明一下,有时候本周的热点项目会…

【HarmonyOS4.0】第三篇-类web开发模式

【HarmonyOS4.0】第三篇-类web开发模式 一、鸿蒙介绍 课程核心 为什么我们需要学习鸿蒙? 哪些人适合直接转鸿蒙? 鸿蒙系统优势是什么? 课程内容 (1)为什么要学习鸿蒙 从行情出发: 美国商务部长访问中国,2023年…

【Java并发】深入浅出 synchronized关键词原理-下

上一篇文章,简要介绍了syn的基本用法和monter对象的结构,本篇主要深入理解,偏向锁、轻量级锁、重量级锁的本质。 对象内存布局 Hotspot虚拟机中,对象在内存中存储的布局可以分为三块区域:对象头(Header)、实例数据 (Instance Da…

【Sublime Text】| 02——常用插件安装及配置

系列文章目录 【Sublime Text】| 01——下载软件安装并注册 【Sublime Text】| 02——常用插件安装及配置 失败了也挺可爱,成功了就超帅。 文章目录 1. 汉化2. 更换颜色主题3. 更改编码插件—ConvertToUTF84. 对齐插件—Alignment5. 括号高亮插件—BracketHighligh…

win11修改本地hosts,自定义域名

目录 🧈1.打开指定目录 🥚2.粘贴至桌面 🍳3.添加自己的域名和对应的ip地址 🍿4.替换原来的hosts文件 1.打开指定目录🧂🧂 在C盘下打开 --------C:\Windows\System32\drivers\etc,找到hos…

众和策略:沪指跌0.91%险守2900点,半导体、金融等板块走低

8日早盘,两市股指低开低走,沪指一度失守2900点,深成指、创业板指跌约1%,科创50指数创前史新低。 到午间收盘,沪指跌0.91%报2902.4点,深成指跌1.17%,创业板指跌0.99%,科创50指数跌超…

vue3中使用elementplus中的el-tree-select,自定义显示名称label

<el-tree-select v-model"addPval" node-key"id" :data"menulists" :render-after-expand"false" :props"menuProps" /> <el-divider />let menuProps {//自定义labellabel: (data: { name: any; }) > {ret…

程序媛的mac修炼手册-- 终端(terminal)常用命令

「终端&#xff08;terminal&#xff09;」相当于macOS的一个 App &#xff0c;它的特殊之处是&#xff0c;它是管理其它App的App&#xff0c;操作主要通过命令行界面 (CLI) 。 相比于我们日常熟悉的用户界面&#xff08;User Interface&#xff0c;UI&#xff09;&#xff0c…

vue3 封裝一个常用固定按钮组件(添加、上传、下载、删除)

效果图 这个组件只有四个按钮&#xff0c;添加&#xff0c;上传、下载、删除&#xff0c;其中删除按钮的颜色默认是灰色&#xff0c;当表格有数据选中时再变成红色 实现 组件代码 <script lang"ts" setup> import { Icon } from /components/Icon/index im…

Qt应用-实现图像截取功能类似QQ上传头像截取功能

本文演示利用Qt实现图像截取功能类似QQ上传头像截取功能。 效果如下,通过移动中间的裁剪区域可以获得一张裁剪后的图片。 目录

OpenAI ChatGPT-4开发笔记2024-02:Chat之text generation之completions

API而已 大模型封装在库里&#xff0c;库放在服务器上&#xff0c;服务器放在微软的云上。我们能做的&#xff0c;仅仅是通过API这个小小的缝隙&#xff0c;窥探ai的奥妙。从程序员的角度而言&#xff0c;水平的高低&#xff0c;就体现在对openai的这几个api的理解程度上。 申…

【unity】基于Obi的绳长动态修改(ObiRopeCursor)

文章目录 一、在运行时改变绳子长度:ObiRopeCursor1.1 Cursor Mu&#xff08;光标μ&#xff09;1.2 Source Mu&#xff08;源μ&#xff09;1.3 Direction&#xff08;方向&#xff09; 一、在运行时改变绳子长度:ObiRopeCursor Obi提供了一个非常通用的组件来在运行时修改绳…