Pytorch Advanced(三) Neural Style Transfer

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本。

这里用pytorch重新实现一次,原理图如下:


from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as npdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载图像

def load_image(image_path, transform=None, max_size=None, shape=None):"""Load an image and convert it to a torch tensor."""image = Image.open(image_path)if max_size:scale = max_size / max(image.size)size = np.array(image.size) * scaleimage = image.resize(size.astype(int), Image.ANTIALIAS)if shape:image = image.resize(shape, Image.LANCZOS)if transform:image = transform(image).unsqueeze(0)return image.to(device)

这里用的模型是 VGG-19,所要用的是网络中的5个卷积层

class VGGNet(nn.Module):def __init__(self):"""Select conv1_1 ~ conv5_1 activation maps."""super(VGGNet, self).__init__()self.select = ['0', '5', '10', '19', '28'] self.vgg = models.vgg19(pretrained=True).featuresdef forward(self, x):"""Extract multiple convolutional feature maps."""features = []for name, layer in self.vgg._modules.items():x = layer(x)if name in self.select:features.append(x)return features

 模型结构如下,可以看到使用序列模型来写的VGG-NET,所以标号即层号,我们要保存的是['0', '5', '10', '19', '28'] 的输出结果。

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace)(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(17): ReLU(inplace)(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace)(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(24): ReLU(inplace)(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(26): ReLU(inplace)(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace)(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(31): ReLU(inplace)(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(33): ReLU(inplace)(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(35): ReLU(inplace)(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace)(2): Dropout(p=0.5)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace)(5): Dropout(p=0.5)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

 训练:

接下来对训练过程进行解释:

1、加载风格图像和内容图像,我们在之前的博客中使用的一幅加噪图进行训练,这里是用的内容图像的拷贝。

2、我们需要优化的就是作为目标的内容图像拷贝,可以看到target需要求导。

3、VGGnet参数是不需要优化的,所以设置为验证状态。

4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

6、风格损失:同样计算格拉姆风格矩阵,将每一层的风格损失叠加,得到总的风格损失,计算公式同样和keras版本有所不一样

7、反向传播

def main(config):# Image preprocessing# VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].# We use the same normalization statistics here.transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])# Load content and style images# Make the style image same size as the content imagecontent = load_image(config.content, transform, max_size=config.max_size)style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])# Initialize a target image with the content imagetarget = content.clone().requires_grad_(True)optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])vgg = VGGNet().to(device).eval()for step in range(config.total_step):# Extract multiple(5) conv feature vectorstarget_features = vgg(target)content_features = vgg(content)style_features = vgg(style)style_loss = 0content_loss = 0for f1, f2, f3 in zip(target_features, content_features, style_features):# Compute content loss with target and content imagescontent_loss += torch.mean((f1 - f2)**2)# Reshape convolutional feature maps_, c, h, w = f1.size()f1 = f1.view(c, h * w)f3 = f3.view(c, h * w)# Compute gram matrixf1 = torch.mm(f1, f1.t())f3 = torch.mm(f3, f3.t())# Compute style loss with target and style imagesstyle_loss += torch.mean((f1 - f3)**2) / (c * h * w) # Compute total loss, backprop and optimizeloss = content_loss + config.style_weight * style_loss optimizer.zero_grad()loss.backward()optimizer.step()if (step+1) % config.log_step == 0:print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' .format(step+1, config.total_step, content_loss.item(), style_loss.item()))if (step+1) % config.sample_step == 0:# Save the generated imagedenorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))img = target.clone().squeeze()img = denorm(img).clamp_(0, 1)torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

写在if __name__=="__main__"后面的语句只会在本脚本中才能被执行,被调用时是不会被执行的。 

python的命令行工具:argparse,很优雅的添加参数

但是由于jupyter不支持添加外部参数,所以使用了外部博客的方法来支持(记住更改读取图片的位置)

import sys
if __name__ == "__main__":#解决方案来自于博客if '-f' in sys.argv:sys.argv.remove('-f')parser = argparse.ArgumentParser()parser.add_argument('--content', type=str, default='png/content.png')parser.add_argument('--style', type=str, default='png/style.png')parser.add_argument('--max_size', type=int, default=400)parser.add_argument('--total_step', type=int, default=2000)parser.add_argument('--log_step', type=int, default=10)parser.add_argument('--sample_step', type=int, default=500)parser.add_argument('--style_weight', type=float, default=100)parser.add_argument('--lr', type=float, default=0.003)#config = parser.parse_args()config = parser.parse_known_args()[0]   #参考博客 https://blog.csdn.net/ken_for_learning/article/details/89675904print(config)main(config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/130085.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024年java面试--mysql(2)

系列文章目录 2024年java面试(一)–spring篇2024年java面试(二)–spring篇2024年java面试(三)–spring篇2024年java面试(四)–spring篇2024年java面试–集合篇2024年java面试–redi…

FPGA实战小项目3

基于FPGA的波形发生器 基于FPGA的波形发生器 基于FPGA的beep音乐播放器设计 基于FPGA的beep音乐播放器设计 基于FPGA的cordic算法实现DDS sin和cosine波形的产生 基于FPGA的cordic算法实现DDS sin和cosine波形的产生

iframe 实现跨域,两页面之间的通信

一、 背景 一个项目为vue2,一个项目为vue3,两个不同的项目实现iframe嵌入,并实现通信 二、方案 iframe跨域时,iframe组件之间常用的通信,主要是H5的possmessage方法 三、案例代码 父页面-vue2(端口号为…

JWT 使用教程 授权 认证

JWT 1.什么是JWT JSON Web Token (JWT) is an open standard (RFC 7519) that defines a compact and self-contained way for securely transmitting information between parties as a JSON object. This information can be verified and trusted because it is digitally s…

轻松搭建本地知识库的ChatGLM2-6B

近期发现了一个项目,它的前身是ChatGLM,在我之前的博客中有关于ChatGLM的部署过程,本项目在前者基础上进行了优化,可以基于当前主流的LLM模型和庞大的知识库,实现本地部署自己的ChatGPT,并可结合自己的知识…

Zabbix登录绕过漏洞复现(CVE-2022-23131)

0x00 前言 最近在复现zabbix的漏洞(CVE-2022-23131),偶然间拿到了国外某公司zabbix服务器。Zabbix Sia Zabbix是拉脱维亚Zabbix SIA(Zabbix Sia)公司的一套开源的监控系统。该系统支持网络监控、服务器监控、云监控和…

C# Winform 简单排期实现(DevExpress TreeList)

排期的需求在很多任务安排的系统中都有相应的需求,原生的Winform控件并未提供相应的控件,一般都是利用DataGridViewTreeView组合完成相应的需求,实现起来比较麻烦。用过DevExpress控件集的开发者应该知道,DevExpress WinForm提供了…

每日一博 - CRUD system VS Event sourcing design

文章目录 概念Arch Overview小结 概念 CRUD 系统和事件溯源设计是两种不同的软件架构方法,用于处理数据和应用程序的状态。以下是它们的区别以及各自适用的场景: CRUD 系统: CRUD 代表 Create(创建)、Read&#xff08…

Android Studio实机同WIFI调试

1.点击Pair using Wi-Fi 2.手机扫描跳出来的二维码 小米手机可搜索无线调试进行adb 调试

模态分析的概念。C++减振器设计。

一、说明 模态分析是工程和物理学中用于研究系统或结构动态特性的技术。它涉及分析系统的振动或振荡的自然模式以及相应的频率、阻尼系数和振型。 在模态分析中,所研究的系统通常表示为一组质量、刚度和阻尼元件(在下面的文章中忽略了阻尼)。…

单链表(Single Link Table)——单文件实现

一、单链表前言 上篇文章我们讲述了顺序表,认真学习我们会发现顺序表优缺点。 缺点1:头部和中部的插入删除效率都不行,时间和空间复杂度都为O(N); 缺点2:空间不够了扩容有一定的消耗(尤其是realloc的异地扩容); 缺…

Docker-namespace

Docker-namespace namespace基础命令dd 命令mkfsdfmountunshare pid 隔离试验mount 隔离 namespace namespace 是 Linux 内核用来隔离内核资源的方式。通过 namespace 可以让一些进程只能看到与自己相关的一部分资源,而另外一些进程也只能看到与它们自己相关的资源…

【Unity编辑器扩展】| 顶部菜单栏扩展 MenuItem

前言【Unity编辑器扩展】 | 顶部菜单栏扩展 MenuItem一、创建多级菜单二、创建可使用快捷键的菜单项三、调节菜单显示顺序和可选择性四、创建可被勾选的菜单项五、右键菜单扩展5.1 Hierarchy 右键菜单5.2 Project 右键菜单5.3 Inspector 组件右键菜单六、AddComponentMenu 特性…

MediaBox助力企业一站式获取音视频能力

以一只音视频百宝箱,应对「千行千面」。 洪炳峰、楚佩斯|作者 大家好,今天我分享的主题是MediaBox——行业音视频数字化再加速。 根据权威数据表明,65%的行业数字化信息来自视频,基于此,音视频技术对于行…

群晖NAS教程(二十五)、利用web station安装nextcloud

群晖NAS教程(二十五)、利用web station安装nextcloud 一、下载离线安装包文件 下载地址https://download.nextcloud.com/server/releases/,我们选择zip格式的,下载这个latest-27.zip的最新版的。 把它加压缩到群辉web/hepnextcloud路径下,并…

CSS:隐藏移动端的滚动条的方式

目录 方式一:-webkit-scrollbar方式二:overflow方式三:clip-path方式四:mask 遮罩总结参考 移动端开发中,有一个横向滚动元素,产品告诉我不需要滚动条,我说这个简单,隐藏一下不就行了…

Ubuntu使用命令行界面配置静态IP地址

参考地址:https://www.zhihu.com/tardis/sogou/art/46544606 方法一:配置/etc/network/interfaces文件 首先查看网卡接口名称:ip a 知道网卡接口名称之后,在 /etc/network/interfaces 文件中配置: auto enp0s31f6 …

keep-alive缓存三级及三级以上路由

需求需要缓存这个出入记录,当tab切换时不重新加载,当刷新页面时,或把这个关闭在重新打开时重新加载如图: (我这里用的是芋道源码的前端框架) keep-alive 1、include 包含页面组件name的这些组件页面,会被…

【算法与数据结构】236、LeetCode二叉树的最近公共祖先

文章目录 一、题目二、解法三、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、解法 思路分析: 根据定义,最近祖先节点需要遍历节点的左右子树,然后才能知道是否为最近祖先节…

Kubernetes 部署发布镜像(cubefile:0.4.0)

目录 实验:部署发布镜像(cubefile:0.4.0) 需求分析: 1、部署Kubenetes环境: 2、撰写 cubefile-deployment.yaml 文件 代码解释: 遇到的问题: 问题解决 : 3、撰写 cubefile-se…