DP-GAN-生成器代码

首先看一下数据生成:
在这里插入图片描述
在预处理阶段会将label经过ont-hot编码转换为35个通道,即每个通道都是由(0,1)组成。
在这里插入图片描述
在train文件中,对生成器和判别器分别进行更新,根据loss的不同,分别计算对于的损失:

loss_G, losses_G_list = model(image, label, "losses_G", losses_computer)
loss_D, losses_D_list = model(image, label, "losses_D", losses_computer)

在model中:

from models.sync_batchnorm import DataParallelWithCallback
import models.generator as generators
import models.discriminator as discriminators
import os
import copy
import torch
import torch.nn as nn
from torch.nn import init
import models.losses as losses
class DP_GAN_model(nn.Module):def __init__(self, opt):super(DP_GAN_model, self).__init__()self.opt = opt#--- generator and discriminator ---self.netG = generators.DP_GAN_Generator(opt).cuda()if opt.phase == "train" or opt.phase == "eval":self.netD = discriminators.DP_GAN_Discriminator(opt)self.print_parameter_count()self.init_networks()#--- EMA of generator weights ---with torch.no_grad():self.netEMA = copy.deepcopy(self.netG) if not opt.no_EMA else None#--- load previous checkpoints if needed ---self.load_checkpoints()#--- perceptual loss ---#if opt.phase == "train":if opt.add_vgg_loss:self.VGG_loss = losses.VGGLoss(self.opt.gpu_ids)self.GAN_loss = losses.GANLoss()self.MSELoss = nn.MSELoss(reduction='mean')def align_loss(self, feats, feats_ref):loss_align = 0for f, fr in zip(feats, feats_ref):loss_align += self.MSELoss(f, fr)return loss_aligndef forward(self, image, label, mode, losses_computer):# Branching is applied to be compatible with DataParallelif mode == "losses_G":loss_G = 0fake = self.netG(label)output_D, scores, feats = self.netD(fake)_, _, feats_ref = self.netD(image)loss_G_adv = losses_computer.loss(output_D, label, for_real=True)loss_G += loss_G_advloss_ms = self.GAN_loss(scores, True, for_discriminator=False)loss_G += loss_ms.item()loss_align = self.align_loss(feats, feats_ref)loss_G += loss_alignif self.opt.add_vgg_loss:loss_G_vgg = self.opt.lambda_vgg * self.VGG_loss(fake, image)loss_G += loss_G_vggelse:loss_G_vgg = Nonereturn loss_G, [loss_G_adv, loss_G_vgg]if mode == "losses_D":loss_D = 0with torch.no_grad():fake = self.netG(label)output_D_fake, scores_fake, _ = self.netD(fake)loss_D_fake = losses_computer.loss(output_D_fake, label, for_real=False)loss_ms_fake = self.GAN_loss(scores_fake, False, for_discriminator=True)loss_D += loss_D_fake + loss_ms_fake.item()output_D_real, scores_real, _ = self.netD(image)loss_D_real = losses_computer.loss(output_D_real, label, for_real=True)loss_ms_real = self.GAN_loss(scores_real, True, for_discriminator=True)loss_D += loss_D_real + loss_ms_real.item()if not self.opt.no_labelmix:mixed_inp, mask = generate_labelmix(label, fake, image)output_D_mixed, _, _ = self.netD(mixed_inp)loss_D_lm = self.opt.lambda_labelmix * losses_computer.loss_labelmix(mask, output_D_mixed, output_D_fake,output_D_real)loss_D += loss_D_lmelse:loss_D_lm = Nonereturn loss_D, [loss_D_fake, loss_D_real, loss_D_lm]if mode == "generate":with torch.no_grad():if self.opt.no_EMA:fake = self.netG(label)else:fake = self.netEMA(label)return fakeif mode == "eval":with torch.no_grad():pred, _, _ = self.netD(image)return preddef load_checkpoints(self):if self.opt.phase == "test":which_iter = self.opt.ckpt_iterpath = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")if self.opt.no_EMA:self.netG.load_state_dict(torch.load(path + "G.pth"))else:self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))elif self.opt.phase == "eval":which_iter = self.opt.ckpt_iterpath = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")self.netD.load_state_dict(torch.load(path + "D.pth"))elif self.opt.continue_train:which_iter = self.opt.which_iterpath = os.path.join(self.opt.checkpoints_dir, self.opt.name, "models", str(which_iter) + "_")self.netG.load_state_dict(torch.load(path + "G.pth"))self.netD.load_state_dict(torch.load(path + "D.pth"))if not self.opt.no_EMA:self.netEMA.load_state_dict(torch.load(path + "EMA.pth"))def print_parameter_count(self):if self.opt.phase == "train":networks = [self.netG, self.netD]else:networks = [self.netG]for network in networks:param_count = 0for name, module in network.named_modules():if (isinstance(module, nn.Conv2d)or isinstance(module, nn.Linear)or isinstance(module, nn.Embedding)):param_count += sum([p.data.nelement() for p in module.parameters()])print('Created', network.__class__.__name__, "with %d parameters" % param_count)def init_networks(self):def init_weights(m, gain=0.02):classname = m.__class__.__name__if classname.find('BatchNorm2d') != -1:if hasattr(m, 'weight') and m.weight is not None:init.normal_(m.weight.data, 1.0, gain)if hasattr(m, 'bias') and m.bias is not None:init.constant_(m.bias.data, 0.0)elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):init.xavier_normal_(m.weight.data, gain=gain)if hasattr(m, 'bias') and m.bias is not None:init.constant_(m.bias.data, 0.0)if self.opt.phase == "train":networks = [self.netG, self.netD]else:networks = [self.netG]for net in networks:net.apply(init_weights)def put_on_multi_gpus(model, opt):if opt.gpu_ids != "-1":gpus = list(map(int, opt.gpu_ids.split(",")))model = DataParallelWithCallback(model, device_ids=gpus).cuda()else:model.module = modelassert len(opt.gpu_ids.split(",")) == 0 or opt.batch_size % len(opt.gpu_ids.split(",")) == 0return modeldef preprocess_input(opt, data):data['label'] = data['label'].long()if opt.gpu_ids != "-1":data['label'] = data['label'].cuda()data['image'] = data['image'].cuda()label_map = data['label']bs, _, h, w = label_map.size()nc = opt.semantic_ncif opt.gpu_ids != "-1":input_label = torch.cuda.FloatTensor(bs, nc, h, w).zero_()else:input_label = torch.FloatTensor(bs, nc, h, w).zero_()input_semantics = input_label.scatter_(1, label_map, 1.0)return data['image'], input_semanticsdef generate_labelmix(label, fake_image, real_image):target_map = torch.argmax(label, dim = 1, keepdim = True)all_classes = torch.unique(target_map)for c in all_classes:target_map[target_map == c] = torch.randint(0,2,(1,)).cuda()target_map = target_map.float()mixed_image = target_map*real_image+(1-target_map)*fake_imagereturn mixed_image, target_map

首先看生成器流程:
标签输入到生成器中得到fake image,fake image 和 real image 共同输入到判别器中得到中间变量输出,接着分别计算四个损失。我们需要明白生成器和辨别器模型的搭建,损失计算过程。
在这里插入图片描述
首先是生成器的组成:
在这里插入图片描述
在这里插入图片描述
输入标签大小是(b,c,h,w),首先z等于一个正态分布的随机数,大小为(b,64),接着view为(b,64,1,1),再扩张到(b,64,h,w)和(b,c,h,w)沿着通道维度拼接起来。将拼接的结果上采样到W和H大小。
在这里插入图片描述
其中在CityscapesDataset指定了:
在这里插入图片描述
则w=512//2^5=16,h=16/2=8.
在这里插入图片描述
令s等于input label,输入到pyrmid中,生成结果添加到列表中。

self.seg_pyrmid = nn.ModuleList([])if not self.opt.no_3dnoise:self.fc = nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 16 * ch, 3, padding=1)self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc + self.opt.z_dim, 32, 3, stride=1, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True)))else:self.fc = nn.Conv2d(self.opt.semantic_nc, 16 * ch, 3, padding=1)self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(self.opt.semantic_nc, 32, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(32, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))for i in range(len(self.channels)-2):self.seg_pyrmid.append(nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)))         

而pyrmid是一个modulist,便利添加的每一个module,生成一个结果:
首先将标签图和噪声拼接起来经过一个3x3卷积,输出通道变为32,再经过一个1x1卷积,输出通道变为64.再经过经过5个步长为2的3x3卷积,下采样32倍。这样pyrmid列表中就有7个结果。
接着将已经采样的x输入到Fc中,输出通道是1024.这里需要清楚两个变量x,和pyrmid.
1:x是输入下采样到(H,W)大小的label+noise.
2:pyrmid是储存经过七次(五次下采样)卷积之后的label+noise。
接着将pyrmid最后一个值采样到x的大小。然后和pyrmid的第i个值拼接在一起。
在这里插入图片描述
对应于:
在这里插入图片描述
每拼接一次生成的值和经过Fc之后的label+noise共同作为输入:
在这里插入图片描述
输入到SPADE块中:
首先要判断SPAD的两个参数即输入通道是否相等。
在这里插入图片描述
在这里插入图片描述
如果相等就输入到SPADE模块,如果不等令变量等于输入值。
在这里插入图片描述
其中最后一个参数是类别值:在Cityscape数据集设定语义标签是34类。有一类是未知,加上噪声的64个通道。
在这里插入图片描述
SPADE:

class SPADE(nn.Module):def __init__(self, opt, norm_nc, label_nc):super().__init__()self.first_norm = get_norm_layer(opt, norm_nc)ks = opt.spade_ksnhidden = 128pw = ks // 2#self.mlp_shared = nn.Sequential(#    nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),#    nn.ReLU()#)self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)def forward(self, x, segmap):normalized = self.first_norm(x)#segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')#actv = self.mlp_shared(segmap)actv = segmapgamma = self.mlp_gamma(actv)beta = self.mlp_beta(actv)out = normalized * (1 + gamma) + betareturn out

公式:
在这里插入图片描述
首先X经过一个norm层,即为分布式BN。
在这里插入图片描述
在这里插入图片描述
接着使用卷积学习β和γ。
在这里插入图片描述
在这里插入图片描述
卷积核大小都为3,padding为1。
接着经过bn之后的变量和γ相乘在和β相加,再和经过归一化之后的x相加。
在这里插入图片描述
接着:x和seg经过相同的norm操作。再进过一个LeakyReLU,再进行一个卷积层。中间有个midlayer过渡。
在这里插入图片描述
在这里插入图片描述
输出的结果经过一个跳连接得到最后输出。
在这里插入图片描述
经过SPADE之后的输出上采样两倍作为输入输入到下一个SPADE中。
最终输出一个通道为3的RGB图片。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/75447.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“ARTS挑战:探索技术,分享思考“

文章目录 前言一、学习的内容二、遇到的困难及解决办法三、学习打卡成果展示四、学习技巧的总结五、未来学习打卡计划后记 关于 ARTS 的释义 ● Algorithm: 每周至少做一个 LeetCode 的算法题 ● Review: 阅读并点评至少一篇英文技术文章 ● Tips: 学习至少一个技术技巧 ● Sha…

Linux(三):Linux服务器下日常实操命令 (常年更新)

基础命令 cd命令:切换目录 cd :切换当前目录百至其它目录,比如进入/etc目录,则执行 cd /etccd / :在Linux 系统中斜杠“/”表示的是根目录。cd / ,即进入根目录.cd ~:进入用户在该系统的home目录&#…

【eNSP】静态路由

【eNSP】静态路由 原理网关路由表 实验根据图片连接模块配置路由器设备R1R2R3R4 配置PC的IP地址、掩码、网关PC1PC2PC3 配置静态路由查看路由表R1R2R3R4测试能否通信 原理 网关 网关与路由器地址相同,一般路由地址为.1或.254。 网关是当电脑发送的数据的目标IP不在…

mac录屏怎么打开?很简单,让我来教你!

mac电脑作为一款广受欢迎的电脑系统,提供了多种方式来满足用户录屏的需求。无论您是要录制教学视频、制作演示文稿,还是记录游戏精彩瞬间,mac电脑都能帮助您实现这些目标。本文将为您介绍两种mac录屏的方法。通过本文的指导,您将能…

数据包在网络中传输的过程

ref: 【先把这个视频看完了】:数据包的传输过程【网络常识10】_哔哩哔哩_bilibili 常识都看看 》Ref: 1. 这个写的嘎嘎好,解释了为啥4层7层5层,还有数据包封装的问题:数据包在网络中的传输过程详解_数据包传输_张孟浩_jay的博客…

CSS调色网有哪些

本文章转载于湖南五车教育,仅用于学习和讨论,如有侵权请联系 1、https://webgradients.com/ Wbgradients 是一个在线调整渐变色的网站 ,可以根据你想要的调整效果,同时支持复制 CSS 代码,可以更好的与开发对接。 Wbg…

LeetCode--剑指Offer75(2)

目录 题目描述:剑指 Offer 58 - II. 左旋转字符串(简单)题目接口解题思路1代码解题思路2代码 PS: 题目描述:剑指 Offer 58 - II. 左旋转字符串(简单) 字符串的左旋转操作是把字符串前面的若干个字符转移到…

智慧防汛,数字科技的力量

随着夏日的脚步临近,台风季节即将降临。对于那些居住在沿海地区的人们来说,台风是一种常见的自然灾害,其带来的风雨可能对生命和财产造成严重威胁。然而,随着数字科技的飞速发展,可视化技术为防汛抗台工作带来了全新的…

基于canvas画布的实用类Fabric.js的使用

目录 前言 一、Fabric.js简介 二、开始 1、引入Fabric.js 2、在main.js中使用 3、初始化画布 三、方法 四、事件 1、常用事件 2、事件绑定 3、事件解绑 五、canvas常用属性 六、对象属性 1、基本属性 2、扩展属性 七、图层层级操作 八、复制和粘贴 1、复制 2…

Collections工具类(java)

文章目录 7.1 常用方法 参考操作数组的工具类:Arrays,Collections 是一个操作 Set、List 和 Map 等集合的工具类。 7.1 常用方法 Collections 中提供了一系列静态的方法对集合元素进行排序、查询和修改等操作,还提供了对集合对象设置不可变、…

C#控制台程序+Window增加右键菜单

有时候我们可能会想定制一些自己的右键菜单功能,帮我们减少重复的操作。那么使用控制台程序加自定义右键菜单,就可以很好地满足我们的需求。 1 编写控制台程序 因为我只用到了在文件夹中空白处的右键菜单,所以这里提供了一个对应的模板&…

计算机视觉实验:图像处理综合-路沿检测

目录 实验步骤与过程 1. 路沿检测方法设计 2. 路沿检测方法实现 2.1 视频图像提取 2.2 图像预处理 2.3 兴趣区域提取 2.4 边缘检测 ​​​​​​​2.5 Hough变换 ​​​​​​​2.6 线条过滤与图像输出 3. 路沿检测结果展示 4. 其他路沿检测方法 实验结论或体会 实…

推荐一款非常简单实用的数据库连接工具Navicat Premium

Navicat Premium是一款非常实用的数据库连接工具,别再用HeidiSQL和idea自带的数据库连接了,看完这篇文章,赶紧把Navicat Premium用起来吧。 首先,需要获取Navicat Premium的安装包,可以通过以下网盘链接下载&#xff0…

计算机二级Python基本操作题-序号45

1. 键盘输入一组水果名称并以空格分隔,共一行。 示例格式如下: 苹果 芒果 草莓 芒果 苹果 草莓 芒果 香蕉 芒果 草莓 统计各类型的数量,从数量多到少的顺序输出类型及对应数量,以英文冒号分隔,每个类型行。输出结果保存…

复亚智能打造全新云平台:让无人机任务管理更智能、更简单

复亚智能全新升级的MindView云平台,对航线规划、任务管理、自动飞行、数据管理等各个环节开展可视化、数字化、智能化监管,从任务到结果的“看得清”、“管得住”、“查得准”,带来更轻松的操作,改善作业效率、安全保障和用户体验…

使用Gunicorn+Nginx部署Flask项目

部署-开发机上的准备工作 确认项目没有bug。用pip freeze > requirements.txt将当前环境的包导出到requirements.txt文件中,方便部署的时候安装。将项目上传到服务器上的/srv目录下。这里以git为例。使用git比其他上传方式(比如使用pycharm&#xff…

网络安全进阶学习第九课——SQL注入介绍

文章目录 一、什么是注入二、什么是SQL注入三、SQL注入产生的原因四、SQL注入的危害五、SQL注入在渗透中的利用1、绕过登录验证:使用万能密码登录网站后台等。2、获取敏感数据3、文件系统操作4、注册表操作5、执行系统命令 六、如何挖掘SQL注入1、SQL注入漏洞分类按…

微信小程序wx.getlocation接口权限申请总结

先附上申请通过截图 插播内容:可代开通,保证通过。wx.getLocation接口(获取当前的地址位置) qq: 308205428 如何申请 当申请微信小程序的wx.getLocation接口权限时,你可以…

8.4 day05软件学习

文章目录 微服务的概念微服务的原则微服务的特征:集群介绍 spring aop 在家学习效率真不高,下午好兄弟喊出去玩,一直到晚上才回来,赶紧总结一下早上学习的内容。 继续看java基础进阶的思想,之前学的很多都忘了。 微服…

SQL基础复习与进阶

SQL进阶 文章目录 SQL进阶关键字复习ALLANYEXISTS 内置函数ROUND(四舍五入)TRUNCATE(截断函数)SEILING(向上取整)FLOOR(向下取整)ABS(获取绝对值)RAND&#x…