【PyTorch】神经风格迁移项目

神经风格迁移中,取一个内容图像和一个风格图像,综合内容图像的内容和风格图像的艺术风格生成新的图像。

 

目录

准备数据

处理数据 

神经风格迁移模型

加载预训练模型 

定义损失函数

定义优化器

运行模型 


准备数据

创建data文件夹,放入一张内容图片(左),一张风格图片(右),分别命名为content和style

from PIL import Image
path2content= "./data/content.jpg"
path2style= "./data/style.jpg"
content_img = Image.open(path2content)
style_img = Image.open(path2style)

 

 

 

处理数据 

调用torchvision.transforms包中Resize、ToTensor和Normalize对图像进行预处理

import torchvision.transforms as transformsh, w = 256, 384 
mean_rgb = (0.485, 0.456, 0.406)
std_rgb = (0.229, 0.224, 0.225)
transformer = transforms.Compose([# 将图像缩放到指定大小transforms.Resize((h,w)),  # 将图像转换为张量transforms.ToTensor(),# 对图像进行标准化处理transforms.Normalize(mean_rgb, std_rgb)])  content_tensor = transformer(content_img)
print(content_tensor.shape, content_tensor.requires_grad)style_tensor = transformer(style_img)
print(style_tensor.shape, style_tensor.requires_grad)

 

# 克隆content_tensor作为输入图像,并设置requires_grad为True,表示需要计算梯度
input_tensor = content_tensor.clone().requires_grad_(True)
print(input_tensor.shape, input_tensor.requires_grad)

import torch
from torchvision.transforms.functional import to_pil_image
# 将图像张量转换为所需PIL图像
def imgtensor2pil(img_tensor):# 克隆并分离图像张量img_tensor_c = img_tensor.clone().detach()# 将图像张量乘以标准RGB值img_tensor_c*=torch.tensor(std_rgb).view(3,1,1)# 将图像张量加上均值RGB值img_tensor_c+=torch.tensor(mean_rgb).view(3,1,1)# 将图像张量限制在0到1之间img_tensor_c = img_tensor_c.clamp(0,1)# 将图像张量转换为PIL图像img_pil=to_pil_image(img_tensor_c)# 返回PIL图像return img_pilimport matplotlib.pylab as plt
%matplotlib inlineplt.imshow(imgtensor2pil(content_tensor))
plt.title("content image");
plt.imshow(imgtensor2pil(style_tensor))
plt.title("style image");

 

 

神经风格迁移模型

 保持模型参数不变,更新模型的输入

加载预训练模型 

import torchvision.models as models
# 检查是否有可用的GPU,如果没有则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 加载预训练的VGG19模型,并将其特征提取部分移动到指定的设备上,并将其设置为评估模式
model_vgg = models.vgg19(pretrained=True).features.to(device).eval()
# 将模型的所有参数设置为不需要梯度,即不进行反向传播
for param in model_vgg.parameters():param.requires_grad_(False)   
print(model_vgg)

 

定义损失函数

# 定义函数,获取模型中指定层的特征
def get_features(x, model, layers):# 创建一个空字典,用于存储特征features = {}# 遍历模型的所有子层for name, layer in enumerate(model.children()):# 将输入数据传入子层,得到输出数据x = layer(x)# 如果子层的名称在指定的层列表中if str(name) in layers:# 将输出数据存储到字典中,键为子层的名称features[layers[str(name)]] = x# 返回字典return features# 定义函数,于计算gram矩阵
def gram_matrix(x):# 获取输入张量的维度n, c, h, w = x.size()# 将输入张量展平x = x.view(n*c, h * w)# 计算gram矩阵gram = torch.mm(x, x.t())return gramimport torch.nn.functional as F# 定义函数,获取内容损失
def get_content_loss(pred_features, target_features, layer):# 获取目标特征target= target_features[layer]# 获取预测特征pred = pred_features [layer]# 计算均方误差损失loss = F.mse_loss(pred, target)return loss# 定义函数,获取风格损失
def get_style_loss(pred_features, target_features, style_layers_dict):  # 初始化损失为0loss = 0# 遍历style_layers_dict中的每一层for layer in style_layers_dict:# 获取预测特征pred_fea = pred_features[layer]# 计算预测特征的gram矩阵pred_gram = gram_matrix(pred_fea)# 获取预测特征的shapen, c, h, w = pred_fea.shape# 获取目标特征的gram矩阵target_gram = gram_matrix (target_features[layer])# 计算当前层的损失layer_loss = style_layers_dict[layer] *  F.mse_loss(pred_gram, target_gram)# 将当前层的损失加到总损失中loss += layer_loss/ (n* c * h * w)# 返回总损失return loss
# 定义特征层字典,用于存储不同层的特征
feature_layers = {'0': 'conv1_1','5': 'conv2_1','10': 'conv3_1','19': 'conv4_1','21': 'conv4_2',  '28': 'conv5_1'}# 将内容张量增加一个维度,并将其移动到指定设备上
con_tensor = content_tensor.unsqueeze(0).to(device)sty_tensor = style_tensor.unsqueeze(0).to(device)# 获取内容张量的特征
content_features = get_features(con_tensor, model_vgg, feature_layers)style_features = get_features(sty_tensor, model_vgg, feature_layers)
# 遍历content_features字典中的所有key
for key in content_features.keys():# 打印每个key对应的值的形状print(content_features[key].shape)

 

定义优化器

from torch import optim# 克隆con_tensor,并设置requires_grad_为True,表示需要计算梯度
input_tensor = con_tensor.clone().requires_grad_(True)
# 使用Adam优化器,优化input_tensor,学习率为0.01
optimizer = optim.Adam([input_tensor], lr=0.01)

运行模型 

# 定义训练的轮数
num_epochs = 300
# 定义内容损失的权重
content_weight = 1e1
# 定义风格损失的权重
style_weight = 1e4
# 定义内容层
content_layer = "conv5_1"
# 定义风格层及其权重
style_layers_dict = { 'conv1_1': 0.75,'conv2_1': 0.5,'conv3_1': 0.25,'conv4_1': 0.25,'conv5_1': 0.25}# 遍历每一轮
for epoch in range(num_epochs+1):# 梯度清零optimizer.zero_grad()# 获取输入特征input_features = get_features(input_tensor, model_vgg, feature_layers)# 获取内容损失content_loss = get_content_loss (input_features, content_features, content_layer)# 获取风格损失style_loss = get_style_loss(input_features, style_features, style_layers_dict)# 计算神经损失neural_loss = content_weight * content_loss + style_weight * style_loss# 反向传播neural_loss.backward(retain_graph=True)# 更新参数optimizer.step()# 每隔100轮打印一次损失if epoch % 100 == 0:print('epoch {}, content loss: {:.2}, style loss {:.2}'.format(epoch,content_loss, style_loss))

打印输出图片(左),对比原始内容图片(右)

plt.imshow(imgtensor2pil(input_tensor[0].cpu()));

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/390916.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据恢复软件:电脑丢失文件,及时使用数据恢复软件恢复!

数据恢复软件什么时候会用到? 答:如果真的不小心删除文件,清空回收站,电脑重装系统等情况发生,我们要懂的及时停止使用电子设备,使用可靠的数据恢复软件,帮助我们恢复这些电子设备的数据&#…

二进制搭建 Kubernetes v1.20(上)

目录 一、操作系统初始化配置 二、升级Liunx内核 三、部署docker引擎 四、部署etcd集群 五、部署Master组件 六、部署Worker Node组件 hostnameip需要部署k8s集群master0120.0.0.100kube-apiserver kube-controller-manager kube-scheduler etcdk8s集群master0220.0.0.1…

CookieMaker工作室合作开发C++项目十一:拟态病毒

(注:本文章使用了“无标题技术”) 一天,我和几个同事,平台出了点BUG,居然给我刷出了千年杀,同事看得瑕疵欲裂,发誓要将我挫骨扬灰—— (游戏入口:和平精英31.…

【数据脱敏】数据交换平台数据脱敏建设方案

1 概述 1.1 数据脱敏定义 1.2 数据脱敏原则 1.2.1基本原则 1.2.2技术原则 1.2.3管理原则 1.3 数据脱敏常用方法 3.1.1泛化技术 3.1.2抑制技术 3.1.3扰乱技术 3.1.4有损技术 1.4 数据脱敏全生命周期 2 制定数据脱敏规程 3 发现敏感数据 4 定义脱敏规则 5 执…

[Unity] ShaderGraph实现DeBuff污染 溶解叠加效果

本篇是在之前的基础上,继续做的功能衍生。 [Unity] ShaderGraph实现Sprite消散及受击变色 完整连连看如下所示:

TypeError: ‘float’ object is not iterable 深度解析

TypeError: ‘float’ object is not iterable 深度解析与实战指南 在Python编程中,TypeError: float object is not iterable是一个常见的错误,通常发生在尝试对浮点数(float)进行迭代操作时。这个错误表明代码中存在类型使用不…

C基础项目(学生成绩管理系统)

目录 一、项目要求 二、完整代码实例 三、分文件编写代码实例 一、项目要求 1.系统运行,打开如下界面。列出系统帮助菜单(即命令菜单),提示输入命令 2.开始时还没有录入成绩,所以输入命令 L 也无法列出成绩。应提…

嵌入式Linux系统中pinictrl框架基本实现

1. 回顾Pinctrl的三大作用 记住pinctrl的三大作用,有助于理解所涉及的数据结构: * 引脚枚举与命名(Enumerating and naming) * 单个引脚 * 各组引脚 * 引脚复用(Multiplexing):比如用作GPIO、I2C或其他功能 * 引脚配置(Configuration):比如上拉、下拉、open drain、驱…

从零入门 AI for Science(AI+药物) 笔记 #Datawhale AI 夏令营

💖使用平台 我的Notebook 魔搭社区 https://modelscope.cn/my/mynotebook/preset . 魔搭高峰期打不开Task3又换回飞桨了 吧torch 架构换成了 飞桨的paddle 飞桨AI Studio星河社区-人工智能学习与实训社区 https://aistudio.baidu.com/projectdetail/8191835?cont…

Python数据分析案例58——热门游戏数据分析及其可视化

案例背景 有哪个男生不喜欢玩游戏呢?就算上了班儿也要研究一下游戏以及热门的游戏。正好这里有个热门的游戏数据集,全球热门游戏数据集来做一下一些可视化的分析。 数据介绍 该文件包含一个数据集,详细说明了多个平台上的各种流行游戏。每个…

基于ThinkPHP开发的校园跑腿社区小程序系统源码,包含前后端代码

基于ThinkPHP开发的校园跑腿社区小程序系统源码,包含前后端代码 最新独立版校园跑腿校园社区小程序系统源码 | 附教程 测试环境:NginxPHP7.2MySQL5.6 多校版本,多模块,适合跑腿,外卖,表白,二…

Java中的5种线程池类型

Java中的5种线程池类型 1. CachedThreadPool (有缓冲的线程池)2. FixedThreadPool (固定大小的线程池)3. ScheduledThreadPool(计划线程池)4. SingleThreadExecutor (单线程线程池)…

使用 Streamlit 和 Python 构建 Web 应用程序

一.介绍 在本文中,我们将探讨如何使用 Streamlit 构建一个简单的 Web 应用程序。Streamlit 是一个功能强大的 Python 库,允许开发人员快速轻松地创建交互式 Web 应用程序。Streamlit 旨在让 Python 开发人员尽可能轻松地创建 Web 应用程序。以下是一些主…

萱仔大模型学习记录5-langchain实战

前面我的bertlora微调已经跑出了不错的结果,我也学会了如何在bert上使用Lora进行微调,我后续会补充一个医疗意图识别的项目于这个系列,现在这个医疗意图识别代码还暂时不准备公开。我就继续按照我的计划学习一番LangChain。 LangChain是一个用…

【软件测试】--接口测试

1. 接口用例设计 接口测试的测试点 功能测试 单接口功能: 手工测试中的单个业务模块,一般对应一个接口 登陆业务 --> 登陆接口加入购物车业务 --> 加入购物车接口订单业务 --> 订单接口支付业务 --> 支付接口 借助工具、代码。绕开前端界面…

AI大模型技术的四大核心架构分析

AI大模型技术的四大核心架构演进之路 随着人工智能技术的飞速发展,大模型技术已经成为AI领域的重要分支。 深度剖析四大大模型技术架构:纯粹的Prompt提示词法、Agent Function Calling机制,RAG(检索增强生成)及Fine-…

NSSCTF-Web题目27(Nginx漏洞、php伪协议、php解析绕过)

目录 [HNCTF 2022 WEEK2]easy_include 1、题目 2、知识点 3、思路 [NSSRound#8 Basic]MyDoor 4、题目 5、知识点 6、思路 [HNCTF 2022 WEEK2]easy_include 1、题目 2、知识点 nginx日志漏洞执行系统命令 3、思路 打开题目,出现源码 题目要我们上传一个fi…

web浏览器播放rtsp视频流,海康监控API

概述 这里记录一下如何让前端播放rtsp协议的视频流 ​ 项目中调用海康API,生成的视频流(hls、ws、rtmp等)通过PotPlayer播放器都无法播放,说明视频流有问题,唯独rtsp视频流可以播放。 但是浏览器本身是无法播放rtsp视频的,即使…

C++——异常

前言:本篇文章我们来分享C的一个全新内容——异常。 目录 一.异常概念 二.异常的使用 1.异常的抛出和匹配原则 2.在函数调用链中异常栈展开匹配原则 3.异常的重新抛出 三.异常的优缺点 1.优点 2.缺点 结语 一.异常概念 异常是一种处理错误的方式&#xff…

完成QT上位机(八)

一. 正式开始设计界面 这一章节我们将完成QT上位机的设计,如果有同学对QtCreater的使用不太熟悉的,可以参考下面的链接 Qt 快速入门系列教程 Qt 快速入门系列教程 (gitbooks.io)https://wizardforcel.gitbooks.io/qt-beginning/content/ 二. 数据库处…