GAN原理 代码解读

模型架构

在这里插入图片描述

代码

数据准备

import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
transform = transforms.Compose([transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',train=True,transform=transform,download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

'''输入:正态分布随机数噪声(长度为100)输出:生成的图片,(1,28,28)中间过程:linear1: 100 -> 256linear2: 256 -> 512linear3: 512 -> 28*28reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):def __init__(self):super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),nn.Linear(256,512),nn.ReLU(),# 最后一层用tanh激活,将数据压缩到-1到1nn.Linear(512,28*28),nn.Tanh())def forward(self,x):img = self.model(x)img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)return img

定义判别器

'''判别器输入:(1,28,28)的图片输出:二分类的概率值 用sigmoid压缩到0-1之间内容:判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):def __init__(self):super(Discriminator,self).__init__()self.model = nn.Sequential(nn.Linear(28*28,512),nn.LeakyReLU(),nn.Linear(512,256),nn.LeakyReLU(),nn.Linear(256,1),nn.Sigmoid(),)def forward(self,x):x = x.view(-1,28*28)x = self.model(x)return x

初始化模型,优化器及损失计算函数

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

def gen_img_plot(model,epoch,test_input):prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpyprediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

显示图片的函数

def img_plot(img):img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度fig = plt.figure(figsize=(4,4))for i in range(16): # 迭代这n张图片plt.subplot(4,4,i+1)plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]plt.axis('off')plt.show()

定义训练函数


def train(num_epoch,test_input):D_loss = []G_loss = []# 训练循环for epoch in range(num_epoch):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader) # 返回批次数for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)# print(f'step={step},img.shape={img.shape}')# img_plot(img)img = img.to(device)size = img.size(0) # 得到一个批次的图片random_noise = torch.randn(size,100,device=device) # 生成器的输入'''一. 训练判别器''''''用真实图片训练判别器'''dis_optim.zero_grad()real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果# 判别器在真实图像上的损失d_real_loss = bce_loss(real_output,# torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。torch.ones_like(real_output))d_real_loss.backward()'''用生成的图片训练判别器'''gen_img = gen(random_noise)# 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensorfake_output = dis(gen_img.detach())d_fake_loss = bce_loss(fake_output,torch.zeros_like(fake_output))d_fake_loss.backward()d_loss = d_real_loss+d_fake_lossdis_optim.step() # 对参数进行优化'''二.训练生成器'''gen_optim.zero_grad()# 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别fake_output = dis(gen_img)# 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。# 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。g_loss = bce_loss(fake_output,torch.ones_like(fake_output))g_loss.backward()gen_optim.step()# 计算一个epoch的损失with torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss +=d_lossg_epoch_loss +=g_loss# 计算整体loss每个epoch的平均Losswith torch.no_grad(): #  禁止梯度计算和参数更新d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)print('Epoch:', epoch+1)print(f'd_epoch_loss={d_epoch_loss}')print(f'g_epoch_loss={g_epoch_loss}')# 将16个长度为100的噪音输入到生成器并画图gen_img_plot(gen,test_input)

开始训练

'''开始计时'''
start_time = time.time()'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:print(f'{round(run_time,2)}s')
else:print(f'{round(run_time/60,2)}minutes')

结果可视化

在这里插入图片描述

加载训练好的参数

gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

在这里插入图片描述
GAN的生成是随机的,不同的噪声,生成不同的数字

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/115333.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

I IntelliJ IDEA 2023.2 最新解锁方式,支持java20

在 IntelliJ IDEA 2023.1 中&#xff0c;我们根据用户的宝贵反馈对新 UI 做出了大量改进。 我们还实现了性能增强&#xff0c;从而更快导入 Maven&#xff0c;以及在打开项目时更早提供 IDE 功能。 新版本通过后台提交检查提供了简化的提交流程。 IntelliJ IDEA Ultimate 现在支…

C语言每日一练------------Day(7)

本专栏为c语言练习专栏&#xff0c;适合刚刚学完c语言的初学者。本专栏每天会不定时更新&#xff0c;通过每天练习&#xff0c;进一步对c语言的重难点知识进行更深入的学习。 今日练习题关键字&#xff1a;两个数组的交集     双指针 &#x1f493;博主csdn个人主页&#xf…

【MySQL】存储引擎

1.MySQL体系结构 1). 连接层 最上层是一些客户端和链接服务&#xff0c;包含本地 sock 通信和大多数基于客户端 / 服务端工具实现的类似 于TCP/IP的通信。主要完成一些类似于连接处理、授权认证、及相关的安全方案。在该层上引入 了线程池的概念&#xff0c;为通过认证安全接…

Ubuntu 启动出现grub rescue

​ 一&#xff0c;原因 原因&#xff1a;出现 “grub rescue” 错误通常表示您的计算机无法正常引导到操作系统&#xff0c;而是进入了 GRUB&#xff08;Grand Unified Bootloader&#xff09;紧急模式。这可能是由于引导加载程序配置错误、硬盘驱动器损坏或其他引导问题引起…

js定位到元素底部

文字的一行一行添加的&#xff0c;每次添加要滚动条自动定位到元素底部 <div class"An">//要父元素包裹&#xff0c;父元素设置max-height&#xff0c;overflow啥的<div class"friendly_pW"></div></div>//添加文字时找子元素的高…

Windows右键添加用 IDEA 打开

1.安装IDEA时 安装时会有个选项来添加&#xff0c;如下&#xff1a; 勾选即可 2.修改注册表 安装时未勾选&#xff0c;可以把下面代码中程序路径改为自己的&#xff0c;保存为对应的 idea.reg文件&#xff0c;双击即可 Windows Registry Editor Version 5.00[HKEY_CLASSES…

Screaming Frog SEO Spider,为您的网站提供全方位的优化解决方案

Screaming Frog SEO Spider是一款适用于Mac的软件&#xff0c;它可以帮助用户分析网站的优化信息。该软件可以模拟蜘蛛爬行的方式&#xff0c;抓取网站的各种信息&#xff0c;并将这些信息整理成易于理解的报告。这些报告可以帮助用户评估网站的优化情况&#xff0c;发现链接的…

pyqt5-快捷键QShortcut

import sys from PyQt5.QtWidgets import * from PyQt5.QtCore import * from PyQt5.QtGui import *""" 下面示例揭示了&#xff0c;当关键字绑定的控件出现的时候&#xff0c;快捷键才管用&#xff0c; 绑定的控件没有出现的时候快捷键无效 """…

微前沿 | 第1期:强可控视频生成;定制化样本检索器;用脑电重建视觉感知;大模型鲁棒性评测

欢迎阅读我们的新栏目——“微前沿”&#xff01; “微前沿”汇聚了微软亚洲研究院最新的创新成果与科研动态。在这里&#xff0c;你可以快速浏览研究院的亮点资讯&#xff0c;保持对前沿领域的敏锐嗅觉&#xff0c;同时也能找到先进实用的开源工具。 本期内容速览 01. 强可…

Vue2里监听localstorage里值的变化

有的时候,我们需要根据本地缓存在localstorage里值的变化做出相应的操作,这就需要我们监听localstorage: 首先,我们在src下的libs文件夹下新建一个stroage.js用于重写setItem事件,当使用setItem的时候,触发,window.dispatchEvent派发事件 const Stroage = {// 重写set…

python基础之miniConda管理器

一、介绍 MiniConda 是一个轻量级的 Conda 版本&#xff0c;它是 Conda 的精简版&#xff0c;专注于提供基本的环境管理功能。Conda 是一个流行的开源包管理系统和环境管理器&#xff0c;用于在不同的操作系统上安装、管理和运行软件包。 与完整版的 Anaconda 相比&#xff0c…

安卓版yolo-fastest

安卓版本yolofastest效果测试 安卓配置OPENCV4ANDROID&#xff0c;见我的博客一篇文章opencv4dandroid配置 这个不需要使用JNI&#xff0c;十分简单的配置 说真的&#xff0c;其实只调用OPENCV的函数&#xff0c;自己写的代码不多&#xff0c;使用OPENCV4ANDROID和JNI的时间差…

SPSS教程:如何绘制带误差的折线图

SPSS教程&#xff1a;如何绘制带误差的折线图 1、问题与数据 研究者想研究45-65岁健康男性中&#xff0c;静坐时长和血胆固醇水平的关系&#xff0c;故招募100名研究对象询问其每天静坐时长&#xff08;time&#xff09;&#xff0c;并检测其血液中胆固醇水平&#xff08;cho…

2023开学礼《乡村振兴战略下传统村落文化旅游设计》北农馆藏许少辉八一新书

2023开学礼《乡村振兴战略下传统村落文化旅游设计》北京农学院图书馆许少辉八一新书

ModaHub魔搭社区——决胜大模型时代,算力、网络、向量数据库缺一不可

大模型应用场景日趋多样,需求也随着增加,进而倒逼着多元算力方面的创新,为满足AI工作负载的需求,采用GPU、FPGA、ASIC等加速卡的服务器越来越多。 根据IDC数据统计,2022年,中国加速服务器市场相比2019年增长44.0亿美元,服务器市场增量的一半更是来自加速服务器。 这意味…

已解决 Python FileNotFoundError 的报错问题

本文摘要&#xff1a;本文已解决 Python FileNotFoundError 的相关报错问题&#xff0c;并总结提出了几种可用解决方案。同时结合人工智能GPT排除可能得隐患及错误。 &#x1f60e; 作者介绍&#xff1a;我是程序员洲洲&#xff0c;一个热爱写作的非著名程序员。CSDN全栈优质领…

SQL Server开启变更数据捕获(CDC)

一、CDC简介 变更数据捕获&#xff08;Change Data Capture &#xff0c;简称 CDC&#xff09;&#xff1a;记录 SQL Server 表的插入、更新和删除操作。开启cdc的源表在插入、更新和删除操作时会插入数据到日志表中。cdc通过捕获进程将变更数据捕获到变更表中&#xff0c;通过…

即时物流进入盈利期,为什么说顺丰同城才是“头雁”?

从餐饮店、便利店老板们扮演跑腿角色给顾客送商品算起&#xff0c;即时配送&#xff08;简称“即配”&#xff09;行业跌跌撞撞好几年&#xff0c;规模壮大、秩序提升&#xff0c;但盈亏平衡的及格线&#xff0c;始终让人望洋兴叹。直到这个夏天&#xff0c;平均分终于被拉上去…

服务器部署前后端项目-SQL Father为例

hello~大家好哇&#xff0c;好久没更新博客了。现在来更新一波hhh 现在更新一下部署上的一些东西&#xff0c;因为其实有很多小伙伴跟我之前一样&#xff0c;很多时候只是开发了&#xff0c;本地前后端都能调通&#xff0c;也能用&#xff0c;但是没有部署到服务器试过&#x…

GraphQL渗透测试案例及防御办法

什么是GraphQL GraphQL 是一种 API 查询语言&#xff0c;旨在促进客户端和服务器之间的高效通信。它使用户能够准确指定他们在响应中所需的数据&#xff0c;从而有助于避免有时使用 REST API 看到的大型响应对象和多个调用。 GraphQL 服务定义了一个合约&#xff0c;客户端可…