Implementing text2draw with the CLIP model

Reference papers

CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

Practice

Code with data augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation
        # note: this call shadows the method of the same name with its tensor result
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(
            dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the vector scene to a raster image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # HWC -> NCHW
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(
                os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # control points of a closed Bezier path approximating a circle
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        # print("match_loss:", match_loss)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])
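The heart of this version is the loss inside forward: the rendered canvas is replicated NUM_AUGS times, each copy is passed through a random perspective warp plus a random resized crop, and the loss sums the negative cosine similarity between the CLIP text embedding and every augmented view. Distilled into a standalone helper (a sketch only — the function name is mine, and it relies on the same imports as the script above):

def augmented_clip_loss(img, text_features, model, augment_trans, num_augs=4):
    # img: (1, 3, 224, 224) rendered canvas; text_features: (1, 512) CLIP text embedding
    views = torch.cat([augment_trans(img) for _ in range(num_augs)])  # (num_augs, 3, 224, 224)
    image_features = model.encode_image(views)                        # (num_augs, 512)
    loss = 0
    for n in range(num_augs):
        # maximising similarity to the text minimises the loss
        loss -= torch.cosine_similarity(text_features, image_features[n:n + 1], dim=1)
    return loss

Because every augmented view must stay close to the text embedding at once, the gradient favours shapes that read as "a triangle" from many viewpoints rather than pixel patterns that fool CLIP from a single one.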

Result after 1000 iterations:
[figure]

Code without image augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F


class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load('ViT-B/32', self.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation
        # note: this call shadows the method of the same name with its tensor result
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(
            dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle",
        #                            "A picture of pentagon", "A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups = [pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]),
                                                 fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                                 stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]
        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])

    def forward(self, t, canvas_width, canvas_height, shapes):
        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # render the vector scene to a raster image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)
        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # HWC -> NCHW
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)  # built but never used: only the raw render is encoded below
        image_features = self.model.encode_image(img)
        self.targets_features: torch.tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)
        return loss

    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(
                os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(
                i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(torch.float32) / 255.0
    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm

    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        # control points of a closed Bezier path approximating a circle
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4
    points1 = get_bezier_circle()
    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        match_loss = matchLoss(t, 224, 224, shapes)
        # print("match_loss:", match_loss)
        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])
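Compared with the augmented version, the only substantive change is in forward: the augmented batch im_batch is still built but never encoded; CLIP sees only the single raw render, and the loss is a single cosine similarity. Reduced to a helper (a sketch, with names of my choosing):

def single_view_clip_loss(img, text_features, model):
    # img: (1, 3, 224, 224) raw render; text_features: normalised (1, 512) text embedding
    feats = model.encode_image(img)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return -torch.cosine_similarity(text_features, feats, dim=1)

With a single view, nothing forces the match to survive a change of viewpoint, which is exactly the weakness the augmented loss addresses.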

Result after 1000 iterations:
[figure]
Result after 2000 iterations:
[figure]
Result after 4000 iterations:
[figure]
Result after 8000 iterations:
[figure]

Why results are poor without image augmentation

Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders

[figure: excerpt from the CLIPDraw paper]

Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

[figure: excerpt from the StyleCLIPDraw paper]

My own understanding

Many different images can match the same piece of text well, and to a human some of them are fundamentally unrelated to it; without image augmentation, the optimisation is very likely to settle into such a spurious local optimum. Augmenting the image before the loss is computed guards against this: under perspective and similar transforms, an image that genuinely matches the text keeps essentially the same similarity no matter how it is transformed, while an unrelated image generally loses similarity once transformed. Averaging the loss over several augmented views therefore suppresses the influence of these unrelated images on the result.
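This intuition is easy to probe directly: encode a fixed prompt, then compare its CLIP similarity with an image before and after a handful of random perspective/crop augmentations. An image that genuinely depicts the prompt should keep a high mean similarity across views, while a spurious pixel-level match tends to collapse. A rough sketch under the same setup as the scripts above (the image path "candidate.png" is a placeholder):

import torch
from torchvision import transforms
from PIL import Image
import CLIP_.clip as clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device, jit=False)
model.eval()

# same augmentation pipeline as the training scripts
augment = transforms.Compose([
    transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
    transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
])

text_features = model.encode_text(clip.tokenize(["A picture of triangle"]).to(device)).detach()

img = preprocess(Image.open("candidate.png")).unsqueeze(0).to(device)  # placeholder path
with torch.no_grad():
    raw_sim = torch.cosine_similarity(text_features, model.encode_image(img), dim=1)
    views = torch.cat([augment(img) for _ in range(8)])
    aug_sim = torch.cosine_similarity(text_features, model.encode_image(views), dim=1).mean()
print(f"raw view: {raw_sim.item():.3f}, mean over 8 augmented views: {aug_sim.item():.3f}")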
