昇思25天学习打卡营第5天|GAN图像生成

文章目录

      • 昇思MindSpore应用实践
        • 基于MindSpore的生成对抗网络图像生成
          • 1、生成对抗网络简介
            • 零和博弈 vs 极大极小博弈
            • GAN的生成对抗损失:
          • 2、基于MindSpore的 Vanilla GAN
          • 3、基于MindSpore的手写数字图像生成
            • 导入数据
            • 数据可视化
            • 模型训练
      • Reference

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

基于MindSpore的生成对抗网络图像生成
1、生成对抗网络简介
零和博弈 vs 极大极小博弈

生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)
这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)

零和博弈(Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。

极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。

假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。

在这里插入图片描述
从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布。
  2. 判别器通过求取梯度和损失函数对网络进行优化,将靠近真实数据分布的数据判定为1,将靠近生成器生成出来数据分布的数据判定为0。
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据。
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失:

min ⁡ G max ⁡ D V ( G , D ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}_{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。

近十年来著名的GAN网络结构:
在这里插入图片描述

2、基于MindSpore的 Vanilla GAN

生成器部分:生成器 Generator 的功能是将隐码映射到数据空间。通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内,并返回一张28x28的图像作为生成结果。

from mindspore import nn
import mindspore.ops as opsimg_size = 28  # 训练图像长(宽)28x28class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,通过第一层线性变换将其映射到128维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# 通过第二层线性变换将其映射到256维# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器部分:判别器 Discriminator 是一个二分类网络模型,在训练时,判别器接收生成器的生成图像与对应的真实数据相对比,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。

 # 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256))  # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')
3、基于MindSpore的手写数字图像生成
导入数据
import numpy as np
import mindspore.dataset as dsbatch_size = 128
latent_size = 100  # 潜在编码的长度train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')def data_load(dataset):dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)# 数据增强mnist_ds = dataset1.map(  # 通过map方法给每张图像映射一个潜在编码# 将图像数据转换为 float32 类型# 生成一个长度为 latent_size 的服从正态分布的随机向量,并将其转换为 float32 类型operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])mnist_ds = mnist_ds.project(["image", "latent_code"])# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)return mnist_dsmnist_ds = data_load(train_dataset)iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
数据可视化
import matplotlib.pyplot as pltdata_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):image = data_iter['image'][idx]figure.add_subplot(rows, cols, idx)plt.axis("off")plt.imshow(image.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述
潜在编码(latent code)的构造:
为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过这组固定的潜在编码(也叫隐码)所生成的图像效果来评估生成器的生成质量。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型训练

定义损失函数和优化器:

lr = 0.0002  # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

训练分为两个主要部分,也就是要训练两个网络:生成与对抗网络。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1D(G(z)) 的值。

第二部分是训练生成器。如论文所述,最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z))) 来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将固定隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 24  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 0:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'Wayn_Fan-sail')

使用cpu进行12个epoch的生成效果如下:
在这里插入图片描述

Reference

昇思官方文档-GAN图像生成
昇思大模型平台

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/364739.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring Cloud LoadBalancer基础入门与应用实践

官网地址:https://docs.spring.io/spring-cloud-commons/reference/spring-cloud-commons/loadbalancer.html 【1】概述 Spring Cloud LoadBalancer是由SpringCloud官方提供的一个开源的、简单易用的客户端负载均衡器,它包含在SpringCloud-commons中用…

json文件 增删查改

默认收藏夹 qt操作json格式文件... 这个人的 写的很好 我的demo全是抄他的 抄了就能用 —————————— 下次有空把我的demo 传上来 在E盘的demo文件夹 json什么名字

追觅科技25届校招校招24年社招科技北森题库商业推理综合测评答题攻略、通关技巧

一、追觅科技这家公司怎么样? 追觅科技是一家在智能清洁家电领域表现出色的企业。 二、追觅科技待遇怎么样 追觅科技的待遇在业内具有竞争力,具体信息如下: 1. **薪酬结构**:根据对外经济贸易大学招生就业处发布的2023届校园招…

Ubuntu挂载window的网络共享文件夹爱

1.进入win10创建一个用户smb密码也是smb 2.右键进入文件夹共享 3.进入Ubuntu安装支持cifs-utils sudo apt update sudo apt install cifs-utils 4.sudo mkdir /mnt/shared 5.挂载: sudo mount -t cifs -o usernamesm bpasswordsmb //172.16.11.37(windowsIP)/s…

杨幂跨界学术圈:内容营销专家刘鑫炜带你了解核心期刊的学术奥秘

近日&#xff0c;知名艺人杨幂在权威期刊《中国广播电视学刊》上发表了一篇名为《浅谈影视剧中演员创作习惯——以电视剧<哈尔滨一九四四>为例》的学术论文&#xff0c;此举在学术界和娱乐圈均引起了广泛关注。该期刊不仅享有极高的声誉&#xff0c;还同时被北大中文核心…

sheng的学习笔记-AI-密度聚类

AI目录&#xff1a;sheng的学习笔记-AI目录-CSDN博客 需要学习的前置知识&#xff1a;聚类&#xff0c;可参考&#xff1a;sheng的学习笔记-AI-聚类(Clustering)-CSDN博客 什么是密度聚类 密度聚类亦称“基于密度的聚类”(density-based clustering)&#xff0c;此类算法假设…

Linux-笔记 OverlayFS文件系统小应用 恢复功能

前言 通过另一章节 OverlayFS文件系统入门 中已经大致了解了原理&#xff0c;这里来实现一个小应用。通过前面介绍我们已经知道lowerdir是只读层&#xff0c;upperdir是可读写层&#xff0c;merged是合并层&#xff08;挂载点&#xff09;&#xff0c;那么我们可以利用这个机…

【网络】计算机网络-基本知识

目录 概念计算机网络功能计算机网络的组成计算机网络的分类 网络地址网络地址的分类 计算机网络相关性能指标速率带宽吞吐量时延时延的种类&#xff1a; 时延带宽积往返时延RTT利用率 概念 计算机网络是指将多台计算机通过通信设备连接起来&#xff0c;实现数据和资源的共享。…

pandas数据分析(1)

pandas&#xff0c;即Python数据分析库&#xff08;Python data analysis library&#xff09; DataFrame和Series DataFrame&#xff08;数据帧&#xff09;和Series&#xff08;序列&#xff09;是pandas的核心数据结构。DataFrame的主要组件包含索引、列、数据。DataFrame和…

架构设计上中的master三种架构,单节点,主从节点,多节点分析

文章目录 背景单节点优点缺点 主从节点优点缺点 多节点优点缺点 多节点&#xff0c;多backup设计优点缺点 总结 背景 在很多分布式系统里会有master,work这种结构。 master 节点负责管理资源&#xff0c;分发任务。下面着重讨论下master 数量不同带来的影响 单节点 优点 1.设…

二叉搜索数的最小绝对差-二叉树

需要用到中序遍历 中序遍历 94. 二叉树的中序遍历 - 力扣&#xff08;LeetCode&#xff09; 递归 class Solution { public:vector<int> inorderTraversal(TreeNode* root) {vector<int> res;inoder(root,res);return res;}void inoder(TreeNode* root , vector…

代码随想录-二叉搜索树(1)

目录 二叉搜索树的定义 700. 二叉搜索树中的搜索 题目描述&#xff1a; 输入输出示例&#xff1a; 思路和想法&#xff1a; 98. 验证二叉搜索树 题目描述&#xff1a; 输入输出示例&#xff1a; 思路和想法&#xff1a; 530. 二叉搜索树的最小绝对差 题目描述&#x…

Python和MATLAB粘性力接触力动态模型半隐式欧拉算法

&#x1f3af;要点 &#x1f3af;运动力模型计算制作过程&#xff1a;&#x1f58a;相机捕捉网球运动图&#xff0c;制定运动数学模型&#xff0c;数值微分运动方程 | &#x1f58a;计算运动&#xff0c;欧拉算法离散积分运动&#xff0c;欧拉-克罗默算法微分运动方程 &#…

linux的CP指令

实现 CP 指令 src 源文件 des 目标文件 执行流程&#xff1a; 打开源文件&#xff08; src &#xff09; open 打开目标文件&#xff08; des &#xff09; open 写入目标文件 write 读取 src 文件到缓存数组 read 关闭目标文件和源文件 close ./a.out src.c de…

【Linux】进程 | 控制块pcb | task_struct | 创建子进程fork

目录 Ⅰ. 进程的概念&#xff08;Process&#xff09; 1. 什么是进程&#xff1f; 2. 多进程管理 3. 进程控制块&#xff08;PCB&#xff09; task_struct 的结构 Ⅱ. 进程查看与管理 1. 使用指令查看进程 ​编辑 2. /proc 查看进程信息 ​编辑 3. 获取进程 ID 4. …

ONLYOFFICE 8.1 版本桌面编辑器测评

在现代办公环境中&#xff0c;办公软件的重要性不言而喻。从文档处理到电子表格分析&#xff0c;再到演示文稿制作&#xff0c;强大且高效的办公软件工具能够极大提升工作效率。ONLYOFFICE 作为一个功能全面且开源的办公软件套件&#xff0c;一直以来都受到广大用户的关注与喜爱…

第三届人工智能、物联网与云计算技术国际会议(AIoTC 2024)

第三届人工智能、物联网与云计算技术国际会议(AIoTC 2024)将于2024年9月13日-15日在中国武汉举行。本次会议由华中师范大学伍伦贡联合研究院与南京大学联合主办、江苏省大数据区块链与智能信息专委会承办、江苏省概率统计学会、江苏省应用统计学会、Sir Forum、南京理工大学、南…

K8S集群进行分布式负载测试

使用K8S集群执行分布式负载测试 本教程介绍如何使用Kubernetes部署分布式负载测试框架&#xff0c;该框架使用分布式部署的locust 产生压测流量&#xff0c;对一个部署到 K8S集群的 Web 应用执行负载测试&#xff0c;该 Web 应用公开了 REST 格式的端点&#xff0c;以响应传入…

固定翼无人机入门(二)

这里讲讲无人机的路径跟踪控制相关知识&#xff0c;路径跟踪需要制导率&#xff08;平面&#xff09;和控制器&#xff0c;在无人机中较为常用的是L1制导率&#xff0c;不过L1制导率是控制无人机在二维平面上的转向&#xff0c;此处还引入总能量控制&#xff0c;控制无人机的高…

uniapp加载打点点效果

uniapp加载打点点效果 背景实现思路代码实现尾巴 背景 为了增加系统的交互性&#xff0c;我们在加载数据时通常会增加一些loading动效&#xff0c;但是在某些场景下只需要一些简单文字提醒。比如说使用【加载中】或者【loading】等字段&#xff0c;但是写静态的字符又显得交互…