GAN生成漫画脸

最近对对抗生成网络GAN比较感兴趣,相关知识点文章还在编辑中,以下这个是一个练手的小项目~

 (在原模型上做了,为了减少计算量让其好训练一些。)

一、导入工具包

import tensorflow as tf
from tensorflow.keras import layersimport numpy as np
import os
import time
import glob
import matplotlib.pyplot as plt
from IPython.display import clear_output
from IPython import display

1.1 设置GPU

gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")
gpus 
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

二、导入训练数据

链接: 点这里

fileList = glob.glob('./ani_face/*.jpg')
len(fileList)
41621

2.1 数据可视化 

# 随机显示几张图
for index,i in enumerate(fileList[:3]):display.display(display.Image(fileList[index]))

2.2 数据预处理

# 文件名列表
path_ds = tf.data.Dataset.from_tensor_slices(fileList)# 预处理,归一化,缩放
def load_and_preprocess_image(path):image = tf.io.read_file(path)image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize(image, [64, 64])image /= 255.0  # normalize to [0,1] rangeimage = tf.reshape(image, [1, 64,64,3])return imageimage_ds = path_ds.map(load_and_preprocess_image)
image_ds
<MapDataset shapes: (1, 64, 64, 3), types: tf.float32>
# 查看一张图片
for x in image_ds:plt.axis("off")plt.imshow((x.numpy() * 255).astype("int32")[0])break

三、网络构建

3.1 D网络

discriminator = keras.Sequential([keras.Input(shape=(64, 64, 3)),layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Flatten(),layers.Dropout(0.2),layers.Dense(1, activation="sigmoid"),],name="discriminator",
)
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 32, 32, 64)        3136      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       131200    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 128)         262272    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 8192)              0         
_________________________________________________________________
dropout (Dropout)            (None, 8192)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 8193      
=================================================================
Total params: 404,801
Trainable params: 404,801
Non-trainable params: 0

3.2 G网络

latent_dim = 128generator = keras.Sequential([keras.Input(shape=(latent_dim,)),layers.Dense(8 * 8 * 128),layers.Reshape((8, 8, 128)),layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),],name="generator",
)
generator.summary()

3.3 重写 train_step

class GAN(keras.Model):def __init__(self, discriminator, generator, latent_dim):super(GAN, self).__init__()self.discriminator = discriminatorself.generator = generatorself.latent_dim = latent_dimdef compile(self, d_optimizer, g_optimizer, loss_fn):super(GAN, self).compile()self.d_optimizer = d_optimizerself.g_optimizer = g_optimizerself.loss_fn = loss_fnself.d_loss_metric = keras.metrics.Mean(name="d_loss")self.g_loss_metric = keras.metrics.Mean(name="g_loss")@propertydef metrics(self):return [self.d_loss_metric, self.g_loss_metric]def train_step(self, real_images):# 生成噪音batch_size = tf.shape(real_images)[0]random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))# 生成的图片generated_images = self.generator(random_latent_vectors)# Combine them with real imagescombined_images = tf.concat([generated_images, real_images], axis=0)# Assemble labels discriminating real from fake imageslabels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)# Add random noise to the labels - important trick!labels += 0.05 * tf.random.uniform(tf.shape(labels))# 训练判别器,生成的当成0,真实的当成1 with tf.GradientTape() as tape:predictions = self.discriminator(combined_images)d_loss = self.loss_fn(labels, predictions)grads = tape.gradient(d_loss, self.discriminator.trainable_weights)self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))# Sample random points in the latent spacerandom_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))# Assemble labels that say "all real images"misleading_labels = tf.zeros((batch_size, 1))# Train the generator (note that we should *not* update the weights# of the discriminator)!with tf.GradientTape() as tape:predictions = self.discriminator(self.generator(random_latent_vectors))g_loss = self.loss_fn(misleading_labels, predictions)grads = tape.gradient(g_loss, self.generator.trainable_weights)self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))# Update metricsself.d_loss_metric.update_state(d_loss)self.g_loss_metric.update_state(g_loss)return {"d_loss": self.d_loss_metric.result(),"g_loss": self.g_loss_metric.result(),}

3.4 设置回调函数

class GANMonitor(keras.callbacks.Callback):def __init__(self, num_img=3, latent_dim=128):self.num_img = num_imgself.latent_dim = latent_dimdef on_epoch_end(self, epoch, logs=None):random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))generated_images = self.model.generator(random_latent_vectors)generated_images *= 255generated_images.numpy()for i in range(self.num_img):img = keras.preprocessing.image.array_to_img(generated_images[i])display.display(img)img.save("gen_ani/generated_img_%03d_%d.png" % (epoch, i))

四、训练模型

epochs = 100  # In practice, use ~100 epochsgan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),loss_fn=keras.losses.BinaryCrossentropy(),
)gan.fit(image_ds, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
)

五、保存模型

#保存模型
gan.generator.save('./data/ani_G_model')

生成模型文件:点这里

六、生成漫画脸

G_model =  tf.keras.models.load_model('./data/ani_G_model/',compile=False)def randomGenerate():noise_seed = tf.random.normal([16, 128])predictions = G_model(noise_seed, training=False)fig = plt.figure(figsize=(8, 8))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)img = (predictions[i].numpy() * 255 ).astype('int')plt.imshow(img )plt.axis('off')plt.show()
count = 0
while True:randomGenerate()clear_output(wait=True)time.sleep(0.1)if count > 100:breakcount+=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/22880.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

拥抱毒瘤 DDD!

点击关注公众号&#xff0c;Java干货及时送达&#x1f447; 来源&#xff1a;小姐姐味道 牛B的人物&#xff0c;早已经厌倦了中英文混杂&#xff0c;他们更进一步&#xff0c;使用中英文缩写&#xff0c;对普通人进行降维打击。更厉害的&#xff0c;造就新的名词&#xff0c;并…

技术人的618种草书单,这些好书值得收藏

虽然才刚刚进入 6 月&#xff0c;但各大网店的 618 活动都已经悄悄拉开帷幕&#xff0c;最近图灵君收到读者们的催更&#xff0c;希望我们推荐一些值得购买的书&#xff0c;想趁着 618 期间下手&#xff0c;于是火速响应大家的需求&#xff0c;集结了各方数据&#xff0c;整理出…

人工智能的黎明:从信息动力学的角度看ChatGPT| 观点

David S. Soriano, CC BY-SA 4.0 via Wikimedia Commons 导读&#xff1a; 以ChatGPT为代表的新的人工智能语言模型的出现与更迭&#xff0c;引发人们极大的兴奋和关注。 物理学家祁晓亮从信息动力学的角度分析&#xff0c;认为AI革命的标志是信息处理复杂度的临界点。AI还没有…

安卓集成腾讯即时通信IM完成聊天室功能

安卓集成腾讯即时通信IM完成聊天室功能 没有效果图的文章都是扯淡**请将下面的MainActivity的代码复制到源码里面&#xff0c;替换掉源码的MainActivity.class**话不多说&#xff0c;下来上代码&#xff1a;以上就是所有的代码附上demo源码。源码&#xff1a;[源码请点这里](ht…

「GPT虚拟直播」实战篇二|无人直播间如何接入虚拟人

摘要 虚拟人和数字人是人工智能技术在现实生活中的具体应用&#xff0c;它们可以为人们的生活和工作带来便利和创新。在直播间场景里&#xff0c;虚拟人和数字人可用于直播主播、智能客服、营销推广等。接入GPT的虚拟人像是加了超强buff&#xff0c;具备更强大的自然语言处理能…

从腾讯实时音视频发家史,看爆发中的 RTC 将何去何从

作者 | 夕颜 头图 | 下载于视觉中国 出品 | AI 科技大本营&#xff08;ID:rgznai100&#xff09; 早在2015年左右&#xff0c;直播和短视频的兴起渗透进普通人的日常生活&#xff0c;人们信息消费的内容已经开始从文字向语音、视频信息转变。而疫情期间全民“家里蹲”的窘境&am…

腾讯云html5直播开发,腾讯云IM开发 直播 聊天室

2019年6月工作总结 总结时间&#xff1a;2019年6月30日 总结人&#xff1a;韩放 工作内容&#xff1a; 1.哟呵直播开发 2.一乙农场客服商品对接 项目总结&#xff1a; 这个月主要是又做了一个直播类项目&#xff0c;这次主要的不同是根据客户的定制要求完全使用了腾讯IM加腾讯云…

微信团队分享:微信直播聊天室单房间1500万在线的消息架构演进之路

本文由微信开发团队工程师“ kellyliang”原创发表于“微信后台团队”公众号&#xff0c;收录时有修订和改动。 1、引言 随着直播和类直播场景在微信内的增长&#xff0c;这些业务对临时消息&#xff08;在线状态时的实时消息&#xff09;通道的需求日益增长&#xff0c;直播…

有哪些值得关注的AIGC细分方向?

&#xff08;以下内容&#xff0c;来自咱们社群“AI产品经理大本营” 1月12日的团员内部讨论&#xff1b;本文不求大而全&#xff0c;但会有一些大家“没听说过的一手信息input”&#xff09;‍‍‍‍ 【参与团员】 A&#xff1a;某司 负责 AIGC B&#xff1a;某司 负责 医疗AI…

音视频技术开发周刊 | 282

每周一期&#xff0c;纵览音视频技术领域的干货。 新闻投稿&#xff1a;contributelivevideostack.com。 畅谈音视频未来无限可能&#xff0c;2022音视频技术大会北京站 最新日程公布 2023年3月31日-4月1日&#xff0c;LiveVideoStackCon 2022音视频技术大会北京站&#xff0c;…

chatgpt赋能Python-python_cdo

Python-CDO: 数据处理的超棒工具 Python-CDO是一种极为实用的工具&#xff0c;用于在Python中使用CDO&#xff08;Climate Data Operators&#xff09;命令。CDO是一个功能强大的工具&#xff0c;用于处理气候和气象大型数据集&#xff0c;如Satellite and Reanalysis数据。而…

时间序列分析——基于R | 第2章 时间序列的预处理习题代码

时间序列分析——基于R | 第2章 时间序列的预处理习题 1.考虑序列{1,2,3,4,5,…,20} 1.1判断该序列是否平稳 x <- seq(1,20);x ## [1] 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 1.2样本自相关系数 max_lag <- 6 acf_x <- acf(x, lag.max max_l…

Google 人机验证(reCaptcha)无法显示解决方案

Google 人机验证无法显示解决方案 第一步 安装插件Chrome/Edge 电脑版Firefox 电脑版 第二步 配置插件原理参考文章 前言&#xff1a;为了防止机器人攻击&#xff0c;国外很多网站都使用了 Google reCaptcha 验证码。reCaptcha 对于国外用户非常的友好&#xff0c;但是… 对于国…

解决谷歌人机验证(Captcha)显示问题

文章目录 前言一、Header Editor 下载、安装与配置1. 插件下载2. 插件安装3. 插件配置 前言 由于谷歌服务在国内不可用&#xff0c;所以正常访问时某些网址时&#xff0c;经常会出现需要人机验证的问题&#xff0c;影响正常使用。在不使用科学上网的情况下&#xff0c;我们可以…

利用Python解决掉谷歌人机验证,全自动识别真的牛啊

一、接触前感受 第一次带我领略yolov5风骚的是这个视频&#xff1a;【亦】警惕AI外挂&#xff01;我写了一个枪枪爆头的视觉AI&#xff0c;又亲手“杀死”了它。 这样一来&#xff0c;我对人工智能打游戏产生了浓厚的兴趣&#xff0c;于是在B站查找人工智能基础&#xff0c;随便…

【开源项目】ChatGPT智能聊天系统后台管理解析

ChatGPT是likeshop近期新研发出来的一款AI智能聊天对话的产品&#xff0c;此系统是基于likeadmin-PHP开发的智能对话系统&#xff0c;ChatGPT是一种基于人工智能技术的聊天机器人&#xff0c;它可以与用户进行自然语言对话&#xff0c;提供各种服务和答案。ChatGPT的核心技术是…

checkra1n越狱工具下载地址

https://checkra.in/releases/ 虚拟机 checkra1n -26和-31错误 不支持虚拟机,需要在黑苹果&#xff0c;Ra1nUSB&#xff0c;Linux下越狱 AMD的黑苹果&#xff0c;错误&#xff0d;31&#xff0c; AMD的CPU使用checkra1n越狱黑苹果会报错-31 适用系统iOS13-13.3.1基本都是…

RabbitMQ快速实战以及核心概念详解

RabbitMQ快速实战以及核心概念详解 一、MQ介绍 1、什么是MQ&#xff1f;为什么要用MQ&#xff1f; ChatGPT中对于消息队列的介绍是这样的&#xff1a; 什么是消息队列 消息队列是一种在应用程序之间传递消息的技术。它提供了一种异步通信模式&#xff0c;允许应用程序在不同…

闰秒终于要取消了!一文详解其来源及影响

导读 | 第27届国际计量大会宣布最迟不晚于2035年取消引入闰秒&#xff0c;这一消息引起轰动。上一次闰秒产生&#xff0c;对Reddit、Mozilla、FourSquare等都产生了一定的问题&#xff0c;其中Reddit宕机时间超过1个半小时&#xff01;本栏目特邀腾讯后台开发工程师陶松桥&…

GPT-5暂时来不了 OpenAI悄然布局移动端

OpenAI彻底用GPT-4带火自然语言大模型后&#xff0c;互联网科技行业的大头、小头都在推出自家的大模型或产品。一时间&#xff0c;生成式AI竞速赛上演&#xff0c;“吃瓜群众”也等着看谁能跑赢OpenAI。 坊间预测&#xff0c;干掉GPT-4的还得是GPT-5。结果&#xff0c;OpenAI的…