AI学习指南深度学习篇-变分自编码器Python实践

AI学习指南深度学习篇 - 变分自编码器Python实践

引言

变分自编码器(Variational Autoencoder, VAE)是一种生成模型,它结合了变分推断和自编码器的优点。在本文中,我们将详细探讨VAE的基本概念,并通过Python中的深度学习框架(如TensorFlow和PyTorch)实现它。

目录

  1. VAE的基本概念
  2. 使用TensorFlow实现VAE
    • 数据准备
    • 模型构建
    • 训练过程
  3. 使用PyTorch实现VAE
    • 数据准备
    • 模型构建
    • 训练过程
  4. 总结与展望

VAE的基本概念

变分自编码器是一种隐变量模型,通过编码器将输入数据映射到隐变量空间,并使用解码器将隐变量映射回数据空间。VAE的目标是最大化数据的似然性,同时通过KL散度约束隐变量的分布,使其接近一个标准正态分布。

在VAE中,网络有两部分组成:

  1. 编码器(Encoder):将输入数据映射到隐变量分布的参数,通常是均值和方差。
  2. 解码器(Decoder):利用隐变量生成新的数据。

总体来说,VAE的损失函数可以表示为:

Loss = Reconstruction Loss + β ⋅ KL Divergence \text{Loss} = \text{Reconstruction Loss} + \beta \cdot \text{KL Divergence} Loss=Reconstruction Loss+βKL Divergence

其中,重构损失反映了模型生成样本的质量,KL散度则对隐变量分布进行约束。


使用TensorFlow实现VAE

数据准备

我们将使用MNIST数据集作为示例,首先确保你已安装TensorFlow和其他所需库。

pip install tensorflow numpy matplotlib

然后,可以运行以下代码下载和准备数据:

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt# 加载数据集
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype(np.float32) / 255.0  # 归一化
x_test = x_test.astype(np.float32) / 255.0# 将数据 reshape 为 (num_samples, height*width)
x_train = np.reshape(x_train, (len(x_train), -1))
x_test = np.reshape(x_test, (len(x_test), -1))

模型构建

在这里,我们将定义VAE的编码器和解码器。

class Sampling(tf.keras.layers.Layer):"""使用均值和对数方差进行采样"""def call(self, inputs):mean, log_var = inputsbatch = tf.shape(mean)[0]dim = tf.shape(mean)[1]eps = tf.random.normal(shape=(batch, dim))return mean + tf.exp(0.5 * log_var) * epslatent_dim = 2  # 潜在空间维度# 编码器
encoder_inputs = tf.keras.layers.Input(shape=(784,))
x = tf.keras.layers.Dense(256, activation="relu")(encoder_inputs)
x = tf.keras.layers.Dense(128, activation="relu")(x)
# 输出均值和对数方差
z_mean = tf.keras.layers.Dense(latent_dim)(x)
z_log_var = tf.keras.layers.Dense(latent_dim)(x)
z = Sampling()([z_mean, z_log_var])# 解码器
decoder_inputs = tf.keras.layers.Input(shape=(latent_dim,))
x = tf.keras.layers.Dense(128, activation="relu")(decoder_inputs)
x = tf.keras.layers.Dense(256, activation="relu")(x)
decoder_outputs = tf.keras.layers.Dense(784, activation="sigmoid")(x)# 创建VAE模型
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
decoder = tf.keras.Model(decoder_inputs, decoder_outputs, name="decoder")
vae = tf.keras.Model(encoder_inputs, decoder(decoder(encoder(encoder_inputs)[2])), name="vae")

损失函数

定义VAE的损失函数,包括重构损失和KL散度。

def vae_loss(x, x_decoded_mean):reconstruction_loss = tf.keras.losses.binary_crossentropy(x, x_decoded_mean)reconstruction_loss *= 784  # 扩展到784维kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)kl_loss = tf.reduce_sum(kl_loss) * -0.5return tf.reduce_mean(reconstruction_loss + kl_loss)# 编译模型
vae.compile(optimizer="adam", loss=vae_loss)

训练过程

训练VAE模型,使用编译后的模型进行训练。

# 训练模型
vae.fit(x_train, x_train, epochs=50, batch_size=128, validation_data=(x_test, x_test))# 绘制结果
def plot_results(models, data, batch_size=128):encoder, decoder = modelsx_test_encoded = encoder.predict(data)[2]plt.figure(figsize=(6, 6))plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], s=2)plt.xlabel("Latent Dimension 1")plt.ylabel("Latent Dimension 2")plt.title("Latent Space")plt.colorbar()plt.show()plot_results((encoder, decoder), x_test)

使用PyTorch实现VAE

数据准备

确保你已安装PyTorch库。

pip install torch torchvision matplotlib

然后,加载数据:

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 数据预处理与加载
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

模型构建

定义VAE网络的编码器和解码器。

import torch.nn as nn
import torch.nn.functional as Fclass VAE(nn.Module):def __init__(self):super(VAE, self).__init__()self.fc1 = nn.Linear(784, 256)self.fc21 = nn.Linear(256, 20)  # 均值self.fc22 = nn.Linear(256, 20)  # 对数方差self.fc3 = nn.Linear(20, 256)self.fc4 = nn.Linear(256, 784)def encode(self, x):h1 = F.relu(self.fc1(x))return self.fc21(h1), self.fc22(h1)def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h3 = F.relu(self.fc3(z))return torch.sigmoid(self.fc4(h3))def forward(self, x):mu, logvar = self.encode(x.view(-1, 784))z = self.reparameterize(mu, logvar)return self.decode(z), mu, logvar

损失函数

定义VAE的损失函数。

def loss_function(recon_x, x, mu, logvar):BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return BCE + KLD

训练过程

接下来,我们将训练VAE模型并展示结果。

# 初始化模型和优化器
model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# 训练模型
model.train()
for epoch in range(10):train_loss = 0for batch_idx, (data, _) in enumerate(train_loader):data = data.to(torch.float32).view(-1, 784)optimizer.zero_grad()recon_batch, mu, logvar = model(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()print(f"Epoch {epoch+1}, Loss: {train_loss / len(train_loader.dataset)}")# 绘制潜在空间
def plot_latent_space(model, data_loader):model.eval()with torch.no_grad():mu_list = []for data, _ in data_loader:data = data.view(-1, 784).to(torch.float32)mu, _ = model.encode(data)mu_list.append(mu)mu_tensor = torch.cat(mu_list).cpu().numpy()plt.figure(figsize=(6, 6))plt.scatter(mu_tensor[:, 0], mu_tensor[:, 1], s=2)plt.xlabel("Latent Dimension 1")plt.ylabel("Latent Dimension 2")plt.title("Latent Space")plt.colorbar()plt.show()plot_latent_space(model, test_loader)

总结与展望

在本文中,我们通过TensorFlow和PyTorch两个深度学习库实现了变分自编码器(VAE),并介绍了其基本概念、模型构建、损失函数及训练过程。

变分自编码器不仅在生成模型中具有重要应用,还能够用于数据降维、特征学习等领域。在未来的研究中,读者可以尝试改进此模型,比如添加更多的层、使用更复杂的损失函数,或将其应用于其他类型的数据集(如图像、文本等)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/450687.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

顺序表的查找

. GetElem(L,i):按位查找。获取L中的第i个位置元素的值。 静态查找: #define MaxSzie 10 typedef struct{ElemType data[MaxSize];int length; }Sqlist;ElemType GetElem(Sqlist L,int i) {return L.data[i-1]; }动态分配: #define InitSzie 10 type…

公司新来一个同事,把枚举运用得炉火纯青...

1.概览 在本文中,我们将看到什么是 Java 枚举,它们解决了哪些问题以及如何在实践中使用 Java 枚举实现一些设计模式。 enum关键字在 java5 中引入,表示一种特殊类型的类,其总是继承java.lang.Enum类,更多内容可以自行…

SpringBoot驱动的车辆信息管理平台

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

如何使用C#实现Padim算法的训练和推理

目录 说明 项目背景 算法实现 预处理模块——图像预处理 主要模块——训练:Resnet层信息提取 主要模块——信息处理,计算Anomaly Map 主要模块——评估 主要模块——评估:门限值的确定 主要模块——推理 写在最后 项目下载链接 说…

【即见未来,为何不拜】聊聊分布式系统中的故障监测机制——Phi Accrual failure detector

前言 昨天在看tcp拥塞控制中的BBR(Bottleneck Bandwidth and Round-trip propagation time)算法时,发现了这一特点: 在BBR以前的拥塞控制算法中(如Reno、Cubic、Vegas),都依赖于丢包事件的发生,在高并发时则会看到网络波动的现象…

【含开题报告+文档+PPT+源码】基于SSM的景行天下旅游网站的设计与实现

开题报告 随着互联网的快速发展,旅游业也逐渐进入了数字化时代。作为一个旅游目的地,云浮市意识到了互联网在促进旅游业发展方面的巨大潜力。为了更好地推广云浮的旅游资源,提高旅游服务质量,云浮市决定开发一个专门的旅游网站。…

深入理解计算机系统--计算机系统漫游

对于一段最基础代码的文件hello.c&#xff0c;解释程序的运行 #include <stdio.h>int main() {printf ( "Hello, world\n") ;return 0; }1.1、信息就是位上下文 源程序是由值 0 和 1 组成的位&#xff08;比特&#xff09;序列&#xff0c;8 个位被组织成一组…

梯度下降算法优化—随机梯度下降、小批次、动量、Adagrad等方法pytorch实现

现有不足 现有调整网络的方法是借助成本函数的梯度下降方法&#xff0c;也就是给函数作切线&#xff0c;不断逼近最优点&#xff0c;即成本函数为零的点。 梯度下降的一般公式为&#xff1a; 即根据每个节点成本函数的梯度进行更新&#xff0c;使用该方法有一些问题&#xff…

探索OpenCV的人脸检测:用Haar特征分类器识别图片中的人脸

目录 简介 OpenCV和Haar特征分类器 实现人脸检测 1. 导入所需库 2. 加载图片和Haar特征分类器 3. 检测人脸 4. 标注人脸 5. 显示 6、结果展示 结论 简介 在计算机视觉和图像处理领域&#xff0c;人脸识别是一项重要的技术。它不仅应用于安全监控、人机交互&#xff0…

10秒钟用Midjourney画出国风味的变形金刚

上魔咒 Optimus Prime comes from the movie Transformers, Chinese style, Wu ShanMing, Ink Painting Halo Dyeing, Conceptual of the Digita Art, MasterComposition, Romantic Ancient Style, Inspired by traditional patterns and symbols, Minimalism, do not con…

day01 -- MybatisPlus

1. MybatisPlus简介 有基础的同学可结合资源中的代码一起看 MyBatis 的增强工具&#xff0c;在 MyBatis 的基础上只做增强不做改变&#xff0c;为简化开发、提高效率而生 特性 通用的 CRUD 操作&#xff1a;内置通用 Mapper、通用 Service&#xff0c;仅仅通过少量配置即可实…

私有化部署大模型最佳解决方案 Ollama (8B)模型

私有化部署大模型Ollama 为什么需要私有化部署大模型一、Ollama本地部署Llama3大模型二、Langchain4j调用Ollama本地部署模型API三、Ollama本地部署nomic向量模型四、Spring AI调用Ollama本地部署模型API 为什么需要私有化部署大模型 企业考虑成本和数据隐私问题&#xff0c;会…

021_Thermal_Transient_in_Matlab统一偏微分框架之热传导问题

Matlab求解有限元专题系列 固体热传导方程 固体热传导的方程为&#xff1a; ρ C p ( ∂ T ∂ t u t r a n s ⋅ ∇ T ) ∇ ⋅ ( q q r ) − α T d S d t Q \rho C_p \left( \frac{\partial T}{\partial t} \mathbf{u}_{\mathtt{trans}} \cdot \nabla T \right) \nab…

BM算法(手算版)

BM 算法 BM 算法是一种字符串匹配的算法。 与 KMP 相比&#xff0c;BM 算法不扫描全部输入字符&#xff0c;平均匹配时间 c・n, 常量 c <1 (随机或真实文本), 但最坏情况是 O (n・m). 可以将 BM 算法的最坏情况改进到 O (n)&#xff1a;通过记录文本后缀中最…

计算机系统简介

一、计算机的软硬件概念 1.硬件&#xff1a;计算机的实体&#xff0c;如主机、外设、硬盘、显卡等。 2.软件&#xff1a;由具有各类特殊功能的信息&#xff08;程序&#xff09;组成。 系统软件&#xff1a;用来管理整个计算机系统&#xff0c;如语言处理程序、操作系统、服…

群晖前面加了雷池社区版,安装失败,然后无法识别出用户真实访问IP

有nas的相信对公网都不模式&#xff0c;在现在基础上传带宽能有100兆的时代&#xff0c;有公网代表着家里有一个小服务器&#xff0c;像百度网盘&#xff0c;优酷这种在线服务都能部署为私有化服务。但现在运营商几乎不可能提供公网ip&#xff0c;要么自己买个云服务器做内网穿…

通过github创建自己网页链接的方法

文章目录 要使用GitHub创建静态网页链接&#xff0c;可以按照以下详细步骤进行操作&#xff1a;一、准备阶段二、创建仓库并配置三、准备并上传静态网站文件四、配置GitHub Pages五、访问和更新你的静态网页 要使用GitHub创建静态网页链接&#xff0c;可以按照以下详细步骤进行…

uniapp微信小程序调用百度OCR

uniapp编写微信小程序调用百度OCR 公司有一个识别行驶证需求&#xff0c;调用百度ocr识别 使用了image-tools这个插件&#xff0c;因为百度ocr接口用图片的base64 这里只是简单演示&#xff0c;accesstoken获取接口还是要放在服务器端&#xff0c;不然就暴露了自己的百度项目k…

Xshell使用密钥远程登录Ubuntu 22.04报错:所选的用户密钥未在远程主机上注册。请再试一次

报错截图如下&#xff1a; 问题原因&#xff1a; Ubuntu 22.04 不支持 Xshell使用的私钥。 查看系统支持的私钥&#xff1a;sudo sshd -T | egrep "pubkey" ~$ sudo sshd -T | egrep "pubkey" pubkeyauthentication yes pubkeyacceptedalgorithms ssh-ed…

基于SpringBoot+Vue的旅游服务平台【提供源码+答辩PPT+参考文档+项目部署】

&#x1f4a5; ① 前言&#xff1a;这两年毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的JavaWeb项目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff01; ❗② 如何解决这类问题&#xff1f; 让我们能够顺利通过毕业&#xff0c;我也一直在不断思考、…