一、前言
2017年,NVIDIA Research 网站发布了一篇颇为震撼的GAN论文:Progressive Growing of GANs for Improved Quality, Stability, and Variation(简称PGGAN),通过使用渐增型GAN网络和精心处理的CelebA-HQ数据集,实现了效果令人惊叹的生成图像,分辨率达到1024x1024。
论文地址
Progressive Growing of GANs for Improved Quality, Stability, and Variation
代码地址
我的代码:https://github.com/coolEphemeroptera/CELEBA_PGGAN
官方代码: https://github.com/tkarras/progressive_growing_of_gans
参考代码:https://github.com/zhangqianhui/progressive_growing_of_gans_tensorflow
二、生成样本展示
通俗的阅读完PGGAN论文后,大致摸清了套路后,决定说GAN就GAN。考虑到设备条件有限,所以并不选用Celeb-HQ高清数据集生产超高清样本 ,决定采用Celeb-A数据集生产128x128分辨率人脸小试牛刀。果然功夫不负有心人,经过多日试验修改,最终获取了较为可观的生成模型并附上训练过程。
128x128 生成器
64x64 训练过程
128x128 训练过程
三、关键方法解读
3.1.基于 ‘批标准差’ 增加多样性(INCREASING VARIATION USING MINIBATCH STANDARD DEVIATION)
由于GAN网络倾向于学习数据集的子分部,由此2016年Salimans提出‘minibatch discrimination’即‘批判别’作为解决方案。通过计算训练批数据的特征图的统计特性来驱动生成样本的特征图满足相似的统计特性。做法是在判别器尾端加入minibatch层,该层处理特征图的统计特性。PGGAN在此基础上做出简化操作来提升样本的多样性。
原文
实现
# 添加多样性特征
def MinibatchstateConcat(nhwf, averaging='all'):# input:[N,H,W,fmaps]s = nhwf.shape# 获取批大小group_size = get_N(nhwf)"""计算方法:(1)先计算N个特征图的标准差得到特征图fmap1:[1,H,W,fmaps](2)对fmap1求均值 得到值M1:[1,1,1,1](3)复制扩张M2得到N个特征图fmap2:[N,H,W,1](4)将fmap2添加至每个样本的特征图中"""adjusted_std = lambda x, **kwargs: tf.sqrt(tf.reduce_mean((x - tf.reduce_mean(x, **kwargs)) **2, **kwargs) + 1e-8)vals = adjusted_std(nhwf, axis=0, keep_dims=True)# 求均值vals = tf.reduce_mean(vals, keep_dims=True)# 复制扩张vals = tf.tile(vals, multiples=(group_size, s[1].value, s[2].value, 1))# 将统计特征拼接到每个样本特征图中return tf.concat([nhwf, vals], axis=3)
3.2 生成器和判别器的归一化
PGGAN使用两种不同的方式来限制梯度和不健康博弈,而且方法均采用非训练的处理方式
3.2.1 平衡学习率(EQUALIZED LEARNING RATE)
原文
He的初始化方法能够确保网络初始化的时候,随机初始化的参数不会大幅度地改变输入信号的强度。然而PGGAN中不仅限初始状态scale而是实时scale,其中公式如下:
实现
# 获取归一化权值(equalized learning rate)
def get_weight(shape, gain=np.sqrt(2), use_wscale=False, fan_in=None):"""HE公式:0.5*n*var(w)=1 , so:std(w)=sqrt(2)/sqrt(n)=gain/sqrt(fan_in)"""# 某卷积核参数个数(h*w*fmaps1)或dense层输入节点数目fmaps1# conv_w:[H,W,fmaps1,fmaps2] or mlp_w:[fmaps1,fmaps2]if fan_in is None: fan_in = np.prod(shape[:-1])# He initstd = gain / np.sqrt(fan_in)# 归一化if use_wscale:wscale = tf.constant(np.float32(std), name='wscale')return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal())*wscaleelse:return tf.get_variable('weight', shape=shape, initializer=tf.initializers.random_normal(0,std))
3.2.1 生成器像素归一化(pixel nromalization)
原文
为了避免生成器梯度爆炸,引入像素归一化,介绍如下
Pixel norm,它是local response normalization的变种。Pixel norm沿着channel维度做归一化,这样归一化的一个好处在于,feature map的每个位置都具有单位长度。这个归一化策略与作者设计的Generator输出有较大关系,注意到Generator的输出层并没有Tanh或者Sigmoid激活函数
实现
# 定义像素归一化操作(pixel normalization)
def PN(nd):if len(nd.shape) > 2:axis_ = 3else:axis_ = 1epsilon = 1e-8with tf.variable_scope('PixelNorm'):return nd * tf.rsqrt(tf.reduce_mean(tf.square(nd), axis=axis_, keep_dims=True) + epsilon)
四、构造渐增型网络(progressive network)
在递增的训练阶段,生成器和判别器的型号也是在逐步拓展的,比如训练128x128图像,我们从4x4开始训练,训练阶段有:
stage 1 4x4 稳定 level2-net
stage 2 8x8 过度 level3-net
stage 3 8x8 稳定 level3-net
stage 4 16x16 过渡 level4-net
stage 5 16x16 稳定 level4-net
stage 6 32x32 过渡 level5-net
stage 7 32x32 稳定 level5-net
stage 8 64x64 过渡 level6-net
stage 9 64x64 稳定 level6-net
stage 10 128x128 过渡 level7-net
stage 11 128x128 稳定 level7-net
在代码中体现为:
PGGAN(0,latents_size,batch_size, lowest, highest, level=2, isTransit=False,epochs=epochs,data_size=data_size)
PGGAN(1,latents_size, batch_size, lowest, highest, level=3, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(2,latents_size, batch_size, lowest, highest, level=3, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(3,latents_size, batch_size, lowest, highest, level=4, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(4,latents_size, batch_size, lowest, highest, level=4, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(5,latents_size, batch_size, lowest, highest, level=5, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(6,latents_size, batch_size, lowest, highest, level=5, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(7,latents_size, batch_size, lowest, highest, level=6, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(8,latents_size, batch_size, lowest, highest, level=6, isTransit=False, epochs=epochs, data_size=data_size)
PGGAN(9,latents_size, batch_size, lowest, highest, level=7, isTransit=True, epochs=epochs, data_size=data_size)
PGGAN(10,latents_size, batch_size, lowest, highest, level=7, isTransit=False, epochs=epochs, data_size=data_size)
4.1 上采样和下采样
论文中上采样由近邻插值方法,下采样由平均池化方法实现。
同时在卷积过程中,PG-GAN移除了deconv 方式,改用了conv + upsample。checkboard提到了deconv会让生成模型遭受checkerboard效应,关于什么时候是checkerboard,可以参考链接的介绍。
以下论文给出的生成器和判别器中的卷积块:
生成器卷积块:
判别器卷积块:
有点类似于高斯金字塔的上下采样过程(高斯金字塔和拉普拉斯金字塔 https://blog.csdn.net/poem_qianmo/article/details/26157633))
实现
# 上采样
def upsampling2d(nhwf):_, h, w, _ = int_shape(nhwf)return tf.image.resize_nearest_neighbor(nhwf, (2 * h, 2 * w))# 下采样
def downsampling2d(nhwf):return tf.nn.avg_pool(nhwf, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
4.2 设计不同level的生成器和判别器 ( level = log2(res), res:当前分辨率)
以生成一张5级(32x32)图片为例,GAN网络从最低分辨率4x4慢慢向最高分辨率32x32学习,其中G/D网络也是逐阶段递增的。 接下来将以生成器为例,解释生成器的不同阶段的搭建方式
(1)建立初级(level=2)卷积层
原文
实现
with tf.variable_scope('generator',reuse=reuse):# ******** 构造二级初始架构 ******************with tf.variable_scope('scale_%d'%(2)):nf = PN(latents)# 论文:CONV4x4+CONV3x3,这里CONV4x4采用FC替代(参考论文源码)with tf.variable_scope('Dense0' ):nf = dense(nf,fmaps=fn(2)*4*4,gain=np.sqrt(2)/4,use_wscale=True)# Dense0:[N,512] to [N,4*4*512}nhwf = tf.reshape(nf,[-1, 4, 4,fn(2)])# reshape:[N,4*4*512} to [N,4,4,512]nhwf = PN(lrelu(add_bias(nhwf)))with tf.variable_scope('CONV1'):nhwf = PN(lrelu(add_bias(conv2d(nhwf,fmaps=fn(2), kernel=3, use_wscale=True))))
(2)建立拓扑卷积层
在4.1介绍了卷积块,这里我们通过这些卷积块来拼接成更高级网络,当然每个卷积块的特征图数量是指定的,PGGAN里指定为:
feats_map_num = [512,512,512,512,256,128,64,32,16]
拓扑结构如下:
实现:这里我们只要生成128x128即可(硬件资源有限)
首先定义生成器卷积块
def G_CONV_BLOCK(nhwf, level, use_wscale=False):"""上采样+CONV0 = pyrUp"""# 上采样with tf.variable_scope('upscale2d'):nhwf = upsampling2d(nhwf)# CONV0with tf.variable_scope('CONV0'):nhwf = PN(lrelu(add_bias(conv2d(nhwf, fmaps=fn(level), kernel=3, use_wscale=use_wscale))))# CONV1with tf.variable_scope('CONV1'):nhwf = PN(lrelu(add_bias(conv2d(nhwf, fmaps=fn(level), kernel=3, use_wscale=use_wscale))))return nhwf
再拓展生成器就可以了,注意到如果是训练阶段,还需保存上一阶段输出(toRGB)
for scale in range(3,level+1):if scale == level and isTransit: # 在最后卷积层新建之前,获取当前输出图片并上采样RGB0 = upsampling2d(nhwf) # 上采样RGB0 = toRGB(RGB0,scale-1,use_wscale=True)# toRGBwith tf.variable_scope('scale_%d'%scale):nhwf = G_CONV_BLOCK(nhwf,scale,use_wscale=True)# 卷积层拓展
(3)生成器输出(整合特征图:toRGB)
经过多层卷积之后,我们获得了特征图,输出端我们需要将这些特征图整合为3通道的RGB图像
首先定义toRGB函数
实现
# 定义toRGB
def toRGB(nhwf, level, use_wscale=False): with tf.variable_scope('level_%d_toRGB' % level):return add_bias(conv2d(nhwf, fmaps=3, kernel=1, gain=1, use_wscale=use_wscale))
然后需要考虑该阶段是否为过渡阶段,如果是过渡阶段还需将上一阶段输出过渡至本阶段
实现
RGB1 = toRGB(nhwf, level,use_wscale=True) # 获取最后卷积层输出图像
# 判断是否为过渡阶段
if isTransit:nhw3 = trans_alpha * RGB1 + (1 - trans_alpha) * RGB0 # 由RGB0平滑过渡到RGB1
else:nhw3 = RGB1return nhw3
其中过渡系数 0<= trans_alpha <=1,且随着训练进度线性递增
(4) 完整生成器定义
下面给出完整的生成器定义,判别器原理和生成器类似,相关代码请参考我的github
def Generator_PG(latents,level,reuse = False,isTransit = False,trans_alpha = 0.0):""":param latents: 输入分布:param level: 网络等级(阶段):param reuse: 变量复用:param isTransit: 是否fade_in:param trans_alpha: 过度系数:return: 生成图片""""""说明:(1)Generator构成:scale_2 + scale_3~level + toRGB , 其中toRGB层将全部特征图合成RGB(2) 过渡阶段: ① 本阶段RGB将融合上一阶段RGB输出。对于上一阶段RGB处理层而言,通过特征图上采样匹配大小,再toRGB再融合。② 上一阶段toRGB的卷积核参数对于上采样后的特征图依然有效"""# ******************************* 构造PG生成器 ************************************with tf.variable_scope('generator',reuse=reuse):# ******** 构造二级初始架构 ******************with tf.variable_scope('scale_%d'%(2)):nf = PN(latents)# 论文:CONV4x4+CONV3x3,这里CONV4x4采用FC替代(参考论文源码)with tf.variable_scope('Dense0' ):nf = dense(nf,fmaps=fn(2)*4*4,gain=np.sqrt(2)/4,use_wscale=True)# Dense0:[N,512] to [N,4*4*512}nhwf = tf.reshape(nf,[-1, 4, 4,fn(2)])# reshape:[N,4*4*512} to [N,4,4,512]nhwf = PN(lrelu(add_bias(nhwf)))with tf.variable_scope('CONV1'):nhwf = PN(lrelu(add_bias(conv2d(nhwf,fmaps=fn(2), kernel=3, use_wscale=True))))# ********* 构造拓扑架构(3~level) *********************for scale in range(3,level+1):if scale == level and isTransit: # 在最后卷积层新建之前,获取当前输出图片并上采样RGB0 = upsampling2d(nhwf) # 上采样RGB0 = toRGB(RGB0,scale-1,use_wscale=True)# toRGBwith tf.variable_scope('scale_%d'%scale):nhwf = G_CONV_BLOCK(nhwf,scale,use_wscale=True)# 卷积层拓展# ******************* toRGB *****************************RGB1 = toRGB(nhwf, level,use_wscale=True) # 获取最后卷积层输出图像# 判断是否为过渡阶段if isTransit:nhw3 = trans_alpha * RGB1 + (1 - trans_alpha) * RGB0 # 由RGB0平滑过渡到RGB1else:nhw3 = RGB1return nhw3
五、生成图片质量评价—— sliced wasserstein distance
原文
六、tensorflow上实现多阶段训练
PGGAN论文同时也给出了训练Celeb-HQ的一些trick(在论文的A.1节),这里我们参考其trick在tensorflow上实现
由于tf的计算图为静态图,因此需要训练完一个阶段,再保存其参数,再重新编写计算图再读取上一阶段参数。这里注意的是模型读取参数需要匹配正确,下面给出训练过程代码,其中结束每一阶段注意清除图(tf.reset_default_graph())
import time
import os
from ops import *
import utils as us
import tfr_tools as tfr
import sliced_wasserstein_distance as swd
os.environ['CUDA_VISIBLE_DEVICES']='0'def PGGAN( id , # PG模型序号latents_size, # 噪声型号batch_size, # 批型号lowest,# 最低网络级数highest,#最高网络级数level,# 目标网络等级isTransit, # 是否过渡epochs, # 训练循环次数data_size, # 数据集大小):#-------------------- 超参 --------------------------#learning_rate = 0.001lam_gp = 10lam_eps = 0.001beta1 = 0.0beta2 = 0.99max_iters = int(epochs * data_size / batch_size)n_critic = 1 # 判别器训练次数#---------- (1)创建目录和指定模型路径 -------------## 当前模型路径model_path = './ckpt/PG%d_level%d_%s' % (id,level, isTransit)us.MKDIR(model_path)# 上一级网络模型路径if isTransit:old_model_path = r'./ckpt/PG%d_level%d_%s/' % (id-1,level - 1, not isTransit) # 上一阶段稳定模型else:old_model_path = r'./ckpt/PG%d_level%d_%s/' % (id-1,level, not isTransit) # 该阶段过度模型#--------------------- (2)定义输入输出 --------------## 图像分辨率res = int(2 ** level)# 定义噪声输入latents = tf.placeholder(name='latents', shape=[None, latents_size], dtype=tf.float32)# 定义数据输入real_images = tf.placeholder(name='real_images', shape=[None, res, res, 3], dtype=tf.float32)# 训练步数train_steps = tf.Variable(0, trainable=False, name='train_steps', dtype=tf.float32) # 等于生成器训练次数# 生成器和判别器输出fake_images = Generator_PG(latents=latents, level=level, reuse=False, isTransit=isTransit,trans_alpha=train_steps / max_iters)d_real_logits = Discriminator_PG(RGB=real_images, level=level, reuse=False, isTransit=isTransit,trans_alpha=train_steps / max_iters)d_fake_logits = Discriminator_PG(RGB=fake_images, level=level, reuse=True, isTransit=isTransit,trans_alpha=train_steps / max_iters)#------------ (3)Wasserstein距离和损失函数 --------------## 定义wasserstein距离wass_dist = tf.reduce_mean(d_real_logits-d_fake_logits)# 定义G,D损失函数d_loss = -wass_dist # 判别器损失函数g_loss = -tf.reduce_mean(d_fake_logits) # 生成器损失函数# 基于‘WGAN-GP’的梯度惩罚alpha_dist = tf.contrib.distributions.Uniform(low=0., high=1.) # 获取[0,1]之间正态分布alpha = alpha_dist.sample((batch_size, 1, 1, 1))interpolated = real_images + alpha * (fake_images - real_images) # 对真实样本和生成样本之间插值inte_logit = Discriminator_PG(RGB=interpolated, level=level, reuse=True, isTransit=isTransit,trans_alpha=train_steps / max_iters) # 求得对应判别器输出# 求得判别器梯度gradients = tf.gradients(inte_logit, [interpolated, ])[0]slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]))slopes_m = tf.reduce_mean(slopes)# 定义惩罚项gradient_penalty = tf.reduce_mean((slopes - 1) ** 2)# d_loss加入惩罚项d_loss += gradient_penalty * lam_gp# 零点偏移修正d_loss += tf.reduce_mean(tf.square(d_real_logits)) * lam_eps# ------------ (4)模型可训练参数提取 --------------## 获取G,D 所有可训练参数train_vars = tf.trainable_variables()g_vars = [var for var in train_vars if var.name.startswith("generator")]d_vars = [var for var in train_vars if var.name.startswith("discriminator")]ShowParasList(d_vars, g_vars, level, isTransit)# 记录参数# 提取本阶段各级网络层参数(不含RGB处理层)d_vars_c = [var for var in d_vars if 'fromRGB' not in var.name] # discriminator/scale_(0~level)/g_vars_c = [var for var in g_vars if 'toRGB' not in var.name] # generator/scale_(0~level)/# 提取上一阶段各级网络层参数(不含RGB处理层)d_vars_old = [var for var in d_vars_c if 'scale_%d' % level not in var.name] # discriminator/scale_(0~level-1)/g_vars_old = [var for var in g_vars_c if 'scale_%d' % level not in var.name] # generator/scale_(0~level-1)/# 提取本次和上次阶段RGB处理层参数d_vars_rgb = [var for var in d_vars if 'fromRGB' in var.name] # discriminator/level_*_fromRGB/g_vars_rgb = [var for var in g_vars if 'toRGB' in var.name] # generator/level_*_toRGB/# 提取上一阶段RGB处理层参数d_vars_rgb_old = [var for var in d_vars_rgb if'level_%d_fromRGB' % level not in var.name] # discriminator/level_level-1_fromRGB/g_vars_rgb_old = [var for var in g_vars_rgb if'level_%d_toRGB' % level not in var.name] # generator/level_level-1_fromRGB/# 提取上一阶段全部变量old_vars = d_vars_old + g_vars_old + d_vars_rgb_old + g_vars_rgb_old# ------------ (5)梯度下降 --------------## G,D梯度下降方式d_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate,beta1=beta1,beta2=beta2).minimize(d_loss, var_list=d_vars)g_train_opt = tf.train.AdamOptimizer(learning_rate=learning_rate,beta1=beta1,beta2=beta2).minimize(g_loss, var_list=g_vars, global_step=train_steps)# 为保持全局平稳学习,我们将保存adam参数的更新状态all_vars = tf.all_variables()adam_vars = [var for var in all_vars if 'Adam' in var.name]adam_vars_old = [var for var in adam_vars if 'level_%d' % level not in var.name and 'scale_%d' % level not in var.name]# ------------ (6)模型保存与恢复 ------------------## 保存本阶段所有变量saver = tf.train.Saver(d_vars + g_vars + adam_vars,max_to_keep=3)# 提取上一阶段所有变量if level > lowest:VARS_MATCH(old_model_path, old_vars) # 核对old_saver = tf.train.Saver(old_vars + adam_vars_old)# ------------ (7)数据集读取(TFR) --------------## read TFR[num, data, label] = tfr.Reading_TFR(sameName=r'./TFR/celeba_%dx%d-*'%(res,res) ,isShuffle=False, datatype=tf.float32, labeltype=tf.int8)# # get batch[num_batch, data_batch, label_batch] = tfr.Reading_Batch_TFR(num, data, label, data_size=res*res*3,label_size=1, isShuffle=False, batchSize=batch_size)# ------------------ (8)迭代 ---------------------## GPU配置config = tf.ConfigProto()config.gpu_options.allow_growth = True# 保存记录losses = []Genlog = []Wass = []SWD = []# 加载数据集的descriptors集合if res>=16:# 加载训练数据的特征集DESC = us.PICKLE_LOADING(r'./DESC.desc')# 开启会话with tf.Session(config=config) as sess:# 全局和局部变量初始化init = (tf.global_variables_initializer(), tf.local_variables_initializer())sess.run(init)# 开启协调器coord = tf.train.Coordinator()# 启动线程threads = tf.train.start_queue_runners(sess=sess, coord=coord)# 加载上一阶段参数if level>lowest:if isTransit: # 如果处于过渡阶段old_saver.restore(sess, tf.train.latest_checkpoint(old_model_path)) # 恢复历史模型print('成功读取上一阶段参数...')else: # 如果处于稳定阶段saver.restore(sess, tf.train.latest_checkpoint(old_model_path)) # 继续训练该架构# 迭代time_start = time.time() # 开始计时for steps in range(1,max_iters+1):# 获取trans_alphatrans_alpha = steps / max_iters# 输入标准正态分布z = np.random.normal(size=(batch_size, latents_size))# 获取数据集minibatch = sess.run(data_batch)# 格式修正minibatch = np.reshape(minibatch,[-1,res,res,3]).astype(np.float32)# 数据集显示# us.CV2_IMSHOW_NHWC_RAMDOM(minibatch, 1, 9, 3, 3, 'minibatch', 0)# 数据集过度处理if isTransit:# minibatch_low = us.lpf_nhwc(minibatch)# minibatch_input = trans_alpha * minibatch + (1 - trans_alpha) * minibatch_low # 数据集过渡处理trans_res = int(0.5*res+0.5*trans_alpha*res)minibatch_input = us.upsize_nhwc(us.downsize_nhwc(minibatch,(trans_res,trans_res)),(res,res))else:minibatch_input = minibatch# 规格化【-1,1】minibatch_input = minibatch_input*2-1# 训练判别器for i in range(n_critic):sess.run(d_train_opt, feed_dict={real_images: minibatch_input, latents: z})# 训练生成器sess.run(g_train_opt, feed_dict={latents: z})# recording training info[d_loss2,g_loss2,wass_dist2,slopes2] = sess.run([d_loss,g_loss,wass_dist,slopes_m], feed_dict={real_images: minibatch_input, latents: z})# recording training_productsz = np.random.normal(size=[9, latents_size])gen_samples = sess.run(fake_images, feed_dict={latents: z})us.CV2_IMSHOW_NHWC_RAMDOM((gen_samples+1)/2, 1, 9, 3, 3, 'GEN', 10)# 打印print('level:%d(%dx%d)..' % (level, res, res),'isTrans:%s..' % isTransit,'step:%d/%d..' % (sess.run(train_steps), max_iters),'Discriminator Loss: %.4f..' % (d_loss2),'Generator Loss: %.4f..' % (g_loss2),'Wasserstein:%.3f..'% wass_dist2,'Slopes:%.3f..'%slopes2)# 记录训练信息if steps % 10 == 0:# (1)记录损失函数losses.append([steps, d_loss2, g_loss2])Wass.append([steps,wass_dist2])# if steps % 50 == 0:# (2)记录生成样本# GenLog.append(gen_samples[0:9])# 计算swd模块if steps % 1000 == 0 and res>=16:# 获取2^13个fake 样本FAKES = []for i in range(64):z = np.random.normal(size=[128, latents_size])fakes = sess.run(fake_images, feed_dict={latents: z})FAKES.append(fakes)FAKES = np.concatenate(FAKES, axis=0)FAKES = (FAKES + 1) / 2# 计算与数据集拉式金字塔指定层的swdif res >16:FAKES = us.hpf_nhwc(FAKES) # 获取高频信号d_desc = swd.get_descriptors_for_minibatch(FAKES, 7, 64)# 提取特征del FAKESd_desc = swd.finalize_descriptors(d_desc)swd2 = swd.sliced_wasserstein_distance(d_desc, DESC[str(res)], 4, 128) * 1e3 # 计算swd*1e3SWD.append([steps,swd2])print('当前生成样本swd(x1e3):', swd2, '...')del d_desc# 保存生成模型if steps % 1000 == 0:saver.save(sess, model_path + '/network.ckpt', global_step=steps) # 保存模型# 关闭线程coord.request_stop()coord.join(threads)# 计时结束:us.CV2_ALL_CLOSE()time_end = time.time()print('迭代结束,耗时:%.2f秒' % (time_end - time_start))# 保存信息us.PICKLE_SAVING(np.array(losses),'./trainlog/losses_%dx%d_trans_%s'%(res,res,isTransit))us.PICKLE_SAVING(np.array(Wass), './trainlog/Wass_%dx%d_trans_%s' % (res, res, isTransit))# us.PICKLE_SAVING(Genlog, './trainlog/Genlog_%dx%d_trans_%s' % (res, res, isTransit))if res>=16:us.PICKLE_SAVING(np.array(SWD),'./trainlog/SWD_%dx%d_trans_%s'%(res,res,isTransit))# 清理图tf.reset_default_graph()#********************************************************* main *******************************************************#
if __name__ == '__main__':# 超参latents_size = 512batch_size = 16lowest = 2highest = 7epochs = 10data_size = 30000us.MKDIR('ckpt')us.MKDIR('structure')us.MKDIR('trainlog')# progressive growingtime0 = time.time() # 开始计时PGGAN(0,latents_size,batch_size, lowest, highest, level=2, isTransit=False,epochs=epochs,data_size=data_size)PGGAN(1,latents_size, batch_size, lowest, highest, level=3, isTransit=True, epochs=epochs, data_size=data_size)PGGAN(2,latents_size, batch_size, lowest, highest, level=3, isTransit=False, epochs=epochs, data_size=data_size)PGGAN(3,latents_size, batch_size, lowest, highest, level=4, isTransit=True, epochs=epochs, data_size=data_size)PGGAN(4,latents_size, batch_size, lowest, highest, level=4, isTransit=False, epochs=epochs, data_size=data_size)PGGAN(5,latents_size, batch_size, lowest, highest, level=5, isTransit=True, epochs=epochs, data_size=data_size)PGGAN(6,latents_size, batch_size, lowest, highest, level=5, isTransit=False, epochs=epochs, data_size=data_size)PGGAN(7,latents_size, batch_size, lowest, highest, level=6, isTransit=True, epochs=epochs, data_size=data_size)PGGAN(8,latents_size, batch_size, lowest, highest, level=6, isTransit=False, epochs=epochs, data_size=data_size)PGGAN(9,latents_size, batch_size, lowest, highest, level=7, isTransit=True, epochs=epochs, data_size=data_size)PGGAN(10,latents_size, batch_size, lowest, highest, level=7, isTransit=False, epochs=epochs, data_size=data_size)time1 = time.time() # 开始计时print('全部训练耗费时间:%.2f..'%(time1-time0))