基于深度学习神经网络的AI图片上色DDcolor系统源码

第一步:DDcolor介绍

        DDColor 是最新的 SOTA 图像上色算法,能够对输入的黑白图像生成自然生动的彩色结果,使用 UNet 结构的骨干网络和图像解码器分别实现图像特征提取和特征图上采样,并利用 Transformer 结构的颜色解码器完成基于视觉语义的颜色查询,最终聚合输出彩色通道预测结果。

        它甚至可以对动漫游戏中的风景进行着色/重新着色,将您的动画风景转变为逼真的现实生活风格!(图片来源:原神)

第二步:DDcolor网络结构

        算法整体流程如下图,使用 UNet 结构的骨干网络和图像解码器分别实现图像特征提取和特征图上采样,并利用 Transformer 结构的颜色解码器完成基于视觉语义的颜色查询,最终聚合输出彩色通道预测结果。

第三步:模型代码展示

import os
import torch
from collections import OrderedDict
from os import path as osp
from tqdm import tqdm
import numpy as npfrom basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.img_util import tensor_lab2rgb
from basicsr.utils.dist_util import master_only
from basicsr.utils.registry import MODEL_REGISTRY
from .base_model import BaseModel
from basicsr.metrics.custom_fid import INCEPTION_V3_FID, get_activations, calculate_activation_statistics, calculate_frechet_distance
from basicsr.utils.color_enhance import color_enhacne_blend@MODEL_REGISTRY.register()
class ColorModel(BaseModel):"""Colorization model for single image colorization."""def __init__(self, opt):super(ColorModel, self).__init__(opt)# define network net_gself.net_g = build_network(opt['network_g'])self.net_g = self.model_to_device(self.net_g)self.print_network(self.net_g)# load pretrained model for net_gload_path = self.opt['path'].get('pretrain_network_g', None)if load_path is not None:param_key = self.opt['path'].get('param_key_g', 'params')self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)if self.is_train:self.init_training_settings()def init_training_settings(self):train_opt = self.opt['train']self.ema_decay = train_opt.get('ema_decay', 0)if self.ema_decay > 0:logger = get_root_logger()logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')# define network net_g with Exponential Moving Average (EMA)# net_g_ema is used only for testing on one GPU and saving# There is no need to wrap with DistributedDataParallelself.net_g_ema = build_network(self.opt['network_g']).to(self.device)# load pretrained modelload_path = self.opt['path'].get('pretrain_network_g', None)if load_path is not None:self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')else:self.model_ema(0)  # copy net_g weightself.net_g_ema.eval()# define network net_dself.net_d = build_network(self.opt['network_d'])self.net_d = self.model_to_device(self.net_d)self.print_network(self.net_d)# load pretrained model for net_dload_path = self.opt['path'].get('pretrain_network_d', None)if load_path is not None:param_key = self.opt['path'].get('param_key_d', 'params')self.load_network(self.net_d, load_path, self.opt['path'].get('strict_load_d', True), param_key)self.net_g.train()self.net_d.train()# define lossesif train_opt.get('pixel_opt'):self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)else:self.cri_pix = Noneif train_opt.get('perceptual_opt'):self.cri_perceptual = build_loss(train_opt['perceptual_opt']).to(self.device)else:self.cri_perceptual = Noneif train_opt.get('gan_opt'):self.cri_gan = build_loss(train_opt['gan_opt']).to(self.device)else:self.cri_gan = Noneif self.cri_pix is None and self.cri_perceptual is None:raise ValueError('Both pixel and perceptual losses are None.')if train_opt.get('colorfulness_opt'):self.cri_colorfulness = build_loss(train_opt['colorfulness_opt']).to(self.device)else:self.cri_colorfulness = None# set up optimizers and schedulersself.setup_optimizers()self.setup_schedulers()# set real dataset cache for fid metric computingself.real_mu, self.real_sigma = None, Noneif self.opt['val'].get('metrics') is not None and self.opt['val']['metrics'].get('fid') is not None:self._prepare_inception_model_fid()def setup_optimizers(self):train_opt = self.opt['train']# optim_params_g = []# for k, v in self.net_g.named_parameters():#     if v.requires_grad:#         optim_params_g.append(v)#     else:#         logger = get_root_logger()#         logger.warning(f'Params {k} will not be optimized.')optim_params_g = self.net_g.parameters()# optimizer goptim_type = train_opt['optim_g'].pop('type')self.optimizer_g = self.get_optimizer(optim_type, optim_params_g, **train_opt['optim_g'])self.optimizers.append(self.optimizer_g)# optimizer doptim_type = train_opt['optim_d'].pop('type')self.optimizer_d = self.get_optimizer(optim_type, self.net_d.parameters(), **train_opt['optim_d'])self.optimizers.append(self.optimizer_d)def feed_data(self, data):self.lq = data['lq'].to(self.device)self.lq_rgb = tensor_lab2rgb(torch.cat([self.lq, torch.zeros_like(self.lq), torch.zeros_like(self.lq)], dim=1))if 'gt' in data:self.gt = data['gt'].to(self.device)self.gt_lab = torch.cat([self.lq, self.gt], dim=1)self.gt_rgb = tensor_lab2rgb(self.gt_lab)if self.opt['train'].get('color_enhance', False):for i in range(self.gt_rgb.shape[0]):self.gt_rgb[i] = color_enhacne_blend(self.gt_rgb[i], factor=self.opt['train'].get('color_enhance_factor'))def optimize_parameters(self, current_iter):# optimize net_gfor p in self.net_d.parameters():p.requires_grad = Falseself.optimizer_g.zero_grad()self.output_ab = self.net_g(self.lq_rgb)self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)self.output_rgb = tensor_lab2rgb(self.output_lab)l_g_total = 0loss_dict = OrderedDict()# pixel lossif self.cri_pix:l_g_pix = self.cri_pix(self.output_ab, self.gt)l_g_total += l_g_pixloss_dict['l_g_pix'] = l_g_pix# perceptual lossif self.cri_perceptual:l_g_percep, l_g_style = self.cri_perceptual(self.output_rgb, self.gt_rgb)if l_g_percep is not None:l_g_total += l_g_perceploss_dict['l_g_percep'] = l_g_percepif l_g_style is not None:l_g_total += l_g_styleloss_dict['l_g_style'] = l_g_style# gan lossif self.cri_gan:fake_g_pred = self.net_d(self.output_rgb)l_g_gan = self.cri_gan(fake_g_pred, target_is_real=True, is_disc=False)l_g_total += l_g_ganloss_dict['l_g_gan'] = l_g_gan# colorfulness lossif self.cri_colorfulness:l_g_color = self.cri_colorfulness(self.output_rgb)l_g_total += l_g_colorloss_dict['l_g_color'] = l_g_colorl_g_total.backward()self.optimizer_g.step()# optimize net_dfor p in self.net_d.parameters():p.requires_grad = Trueself.optimizer_d.zero_grad()real_d_pred = self.net_d(self.gt_rgb)fake_d_pred = self.net_d(self.output_rgb.detach())l_d = self.cri_gan(real_d_pred, target_is_real=True, is_disc=True) + self.cri_gan(fake_d_pred, target_is_real=False, is_disc=True)loss_dict['l_d'] = l_dloss_dict['real_score'] = real_d_pred.detach().mean()loss_dict['fake_score'] = fake_d_pred.detach().mean()l_d.backward()self.optimizer_d.step()self.log_dict = self.reduce_loss_dict(loss_dict)if self.ema_decay > 0:self.model_ema(decay=self.ema_decay)def get_current_visuals(self):out_dict = OrderedDict()out_dict['lq'] = self.lq_rgb.detach().cpu()out_dict['result'] = self.output_rgb.detach().cpu()if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verboseself.output_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.output_ab], dim=1)self.output_rgb_chroma = tensor_lab2rgb(self.output_lab_chroma)out_dict['result_chroma'] = self.output_rgb_chroma.detach().cpu()if hasattr(self, 'gt'):out_dict['gt'] = self.gt_rgb.detach().cpu()if self.opt['logger'].get('save_snapshot_verbose', False):  # only for verboseself.gt_lab_chroma = torch.cat([torch.ones_like(self.lq) * 50, self.gt], dim=1)self.gt_rgb_chroma = tensor_lab2rgb(self.gt_lab_chroma)out_dict['gt_chroma'] = self.gt_rgb_chroma.detach().cpu()return out_dictdef test(self):if hasattr(self, 'net_g_ema'):self.net_g_ema.eval()with torch.no_grad():self.output_ab = self.net_g_ema(self.lq_rgb)self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)self.output_rgb = tensor_lab2rgb(self.output_lab)else:self.net_g.eval()with torch.no_grad():self.output_ab = self.net_g(self.lq_rgb)self.output_lab = torch.cat([self.lq, self.output_ab], dim=1)self.output_rgb = tensor_lab2rgb(self.output_lab)self.net_g.train()def dist_validation(self, dataloader, current_iter, tb_logger, save_img):if self.opt['rank'] == 0:self.nondist_validation(dataloader, current_iter, tb_logger, save_img)def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):dataset_name = dataloader.dataset.opt['name']with_metrics = self.opt['val'].get('metrics') is not Noneuse_pbar = self.opt['val'].get('pbar', False)if with_metrics and not hasattr(self, 'metric_results'):  # only execute in the first runself.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}# initialize the best metric results for each dataset_name (supporting multiple validation datasets)if with_metrics:self._initialize_best_metric_results(dataset_name)# zero self.metric_resultsif with_metrics:self.metric_results = {metric: 0 for metric in self.metric_results}metric_data = dict()if use_pbar:pbar = tqdm(total=len(dataloader), unit='image')if self.opt['val']['metrics'].get('fid') is not None:fake_acts_set, acts_set = [], []for idx, val_data in enumerate(dataloader):# if idx == 100:#     breakimg_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]if hasattr(self, 'gt'):del self.gtself.feed_data(val_data)self.test()visuals = self.get_current_visuals()sr_img = tensor2img([visuals['result']])metric_data['img'] = sr_imgif 'gt' in visuals:gt_img = tensor2img([visuals['gt']])metric_data['img2'] = gt_imgtorch.cuda.empty_cache()if save_img:if self.opt['is_train']:save_dir = osp.join(self.opt['path']['visualization'], img_name)for key in visuals:save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))img = tensor2img(visuals[key])imwrite(img, save_path)else:if self.opt['val']['suffix']:save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,f'{img_name}_{self.opt["val"]["suffix"]}.png')else:save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,f'{img_name}_{self.opt["name"]}.png')imwrite(sr_img, save_img_path)if with_metrics:# calculate metricsfor name, opt_ in self.opt['val']['metrics'].items():if name == 'fid':pred, gt = visuals['result'].cuda(), visuals['gt'].cuda()fake_act = get_activations(pred, self.inception_model_fid, 1)fake_acts_set.append(fake_act)if self.real_mu is None:real_act = get_activations(gt, self.inception_model_fid, 1)acts_set.append(real_act)else:self.metric_results[name] += calculate_metric(metric_data, opt_)if use_pbar:pbar.update(1)pbar.set_description(f'Test {img_name}')if use_pbar:pbar.close()if with_metrics:if self.opt['val']['metrics'].get('fid') is not None:if self.real_mu is None:acts_set = np.concatenate(acts_set, 0)self.real_mu, self.real_sigma = calculate_activation_statistics(acts_set)fake_acts_set = np.concatenate(fake_acts_set, 0)fake_mu, fake_sigma = calculate_activation_statistics(fake_acts_set)fid_score = calculate_frechet_distance(self.real_mu, self.real_sigma, fake_mu, fake_sigma)self.metric_results['fid'] = fid_scorefor metric in self.metric_results.keys():if metric != 'fid':self.metric_results[metric] /= (idx + 1)# update the best metric resultself._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)self._log_validation_metric_values(current_iter, dataset_name, tb_logger)def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):log_str = f'Validation {dataset_name}\n'for metric, value in self.metric_results.items():log_str += f'\t # {metric}: {value:.4f}'if hasattr(self, 'best_metric_results'):log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ 'f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')log_str += '\n'logger = get_root_logger()logger.info(log_str)if tb_logger:for metric, value in self.metric_results.items():tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)def _prepare_inception_model_fid(self, path='pretrain/inception_v3_google-1a9a5a14.pth'):incep_state_dict = torch.load(path, map_location='cpu')block_idx = INCEPTION_V3_FID.BLOCK_INDEX_BY_DIM[2048]self.inception_model_fid = INCEPTION_V3_FID(incep_state_dict, [block_idx])self.inception_model_fid.cuda()self.inception_model_fid.eval()@master_onlydef save_training_images(self, current_iter):visuals = self.get_current_visuals()save_dir = osp.join(self.opt['root_path'], 'experiments', self.opt['name'], 'training_images_snapshot')os.makedirs(save_dir, exist_ok=True)for key in visuals:save_path = os.path.join(save_dir, '{}_{}.png'.format(current_iter, key))img = tensor2img(visuals[key])imwrite(img, save_path)def save(self, epoch, current_iter):if hasattr(self, 'net_g_ema'):self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])else:self.save_network(self.net_g, 'net_g', current_iter)self.save_network(self.net_d, 'net_d', current_iter)self.save_training_state(epoch, current_iter)

第四步:运行

第五步:整个工程的内容

代码的下载路径(新窗口打开链接)基于深度学习神经网络的AI图片上色DDcolor系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/315093.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

以生命健康为中心的物联网旅居养老运营平台

随着科技的飞速发展和人口老龄化的日益加剧,养老问题逐渐成为社会关注的焦点。传统的养老模式已经难以满足现代老年人的多元化需求,因此,构建一个以生命健康为中心的物联网旅居养老运营平台显得尤为重要。 以生命健康为中心的物联网旅居养老运…

未来已来:解锁AGI的无限潜能与挑战

未来已来:解锁AGI的无限潜能与挑战 引言 假设你有一天醒来,发现你的智能手机不仅提醒你今天的日程,还把你昨晚做的那个奇怪的梦解释了一番,并建议你可能需要减少咖啡摄入量——这不是科幻电影的情节,而是人工通用智能…

本地认证的密码去哪了?怎么保证安全的?

1. windows登录的明文密码,存储过程是怎么样的?密文存在哪个文件下?该文件是否可以打开,并且查看到密文? 系统将输入的明文密码通过hash算法转为哈希值,且输入的值会在内存中立即删除无法查看。 然后将密文存放在C:…

Vue3+Vite开发的项目进行加密打包

本文主要介绍Vue3+Vite开发的项目如何进行加密打包。 目录 一、vite简介二、混淆工具三、使用方法1. 安装插件:2. 配置插件:3. 运行构建:4. 自定义混淆选项:5. 排除文件:下面是Vue 3+Vite开发的项目进行加密打包的方法。 一、vite简介 Vite 是一个由 Evan You 创造的现代…

MultiHeadAttention在Tensorflow中的实现原理

前言 通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。 一、MultiHeadAttention的本质内涵 1.Self_Atention机制 MultiHeadAttention是Self_Atention的多头堆嵌,有必要对Self_Atention机制进行一次深入浅出的理解,这…

AJAX——案例

1.商品分类 需求&#xff1a;尽可能同时展示所有商品分类到页面上 步骤&#xff1a; 获取所有的一级分类数据遍历id&#xff0c;创建获取二级分类请求合并所有二级分类Promise对象等待同时成功后&#xff0c;渲染页面 index.html代码 <!DOCTYPE html> <html lang&qu…

ssh 文件传输:你应该掌握的几种命令行工具

这篇文章主要分享一下我使用过的 ssh 传输文件的进阶路程&#xff0c;从 scp -> lrzsz -> trzsz&#xff0c;希望能给你带来一些帮助&#xff5e; scp scp 命令可以用于在 linux 系统之间复制文件&#xff0c;具体的语法可以参考下图 其实使用起来也还比较方便&#x…

【Docker】Docker 实践(三):使用 Dockerfile 文件构建镜像

Docker 实践&#xff08;三&#xff09;&#xff1a;使用 Dockerfile 文件构建镜像 1.使用 Dockerfile 文件构建镜像2.Dockerfile 文件详解 1.使用 Dockerfile 文件构建镜像 Dockerfile 是一个文本文件&#xff0c;其中包含了一条条的指令&#xff0c;每一条指令都用于构建镜像…

智慧码头港口:施工作业安全生产AI视频监管与风险预警平台方案

一、建设思路 随着全球贸易的快速发展&#xff0c;港口作为连接海洋与内陆的关键节点&#xff0c;其运营效率和安全性越来越受到人们的关注。为了提升港口的运营效率和安全性&#xff0c;智慧港口视频智能监控系统的建设显得尤为重要。 1&#xff09;系统架构设计 系统应该采…

针对icon报错

针对上篇文章生成图标链接中图标报错 C# winfrom应用程序添加图标-CSDN博客 问题&#xff1a;参数“picture”必须是可用作Icon的参数 原因&#xff1a;生成的ico图标类型不匹配 解决方法&#xff1a; 更改导出的ico类型

下载学浪视频,小浪助手一键搞定

小浪助手可以一键获取课程&#xff0c;一键根据课程获取视频列表&#xff0c;而且内置了2大下载器&#xff0c;N_m3u8和逍遥一仙下载器 小浪助手我已经打包好了&#xff0c;有需要的自己取一下 学浪下载工具链接&#xff1a;https://pan.baidu.com/s/1_Sg-EGGXKc4bMW-NPqUqvg…

第55篇:创建Nios II工程之Hello_World<一>

Q&#xff1a;本期我们开始介绍创建Platform Designer系统&#xff0c;并设计基于Nios II Processor的Hello_world工程。 A&#xff1a;设计流程和实验原理&#xff1a;需要用到的IP组件有Clock Source、Nios II Processor、On-Chip Memory、JTAG UART和System ID外设。Nios I…

Maven多模块快速升级超好用Idea插件-MPVP

功能&#xff1a;多模块maven项目快速升级指定版本插件&#xff0c;并提供预览和相关升级模块日志能力。 可快速进行版本升级&#xff0c;进行部署到Maven仓库。 安装&#xff1a; 可在idea插件中心进行安装 / 下载资源拖动安装 MPVP(Maven) - IntelliJ IDEs Plugin | Marke…

构建安全高效的前端权限控制系统

✨✨谢谢大家捧场&#xff0c;祝屏幕前的小伙伴们每天都有好运相伴左右&#xff0c;一定要天天开心哦&#xff01;✨✨ &#x1f388;&#x1f388;作者主页&#xff1a; 喔的嘛呀&#x1f388;&#x1f388; ✨✨ 帅哥美女们&#xff0c;我们共同加油&#xff01;一起进步&am…

数据库变更时,OceanBase如何自动生成回滚 SQL

背景 在开发中&#xff0c;数据的变更与维护工作一般较频繁。当我们执行数据库的DML操作时&#xff0c;必须谨慎考虑变更对数据可能产生的后果&#xff0c;以及变更是否能够顺利执行。若出现意外数据丢失、操作失误或语法错误等情况&#xff0c;我们必须迅速将数据库恢复到变更…

Bayes判别示例数据:鸢尾花数据集

使用Bayes判别的R语言实例通常涉及使用朴素贝叶斯分类器。朴素贝叶斯分类器是一种简单的概率分类器&#xff0c;基于贝叶斯定理和特征之间的独立性假设。在R中&#xff0c;我们可以使用e1071包中的naiveBayes函数来实现这一算法。下面&#xff0c;我将通过一个简单的示例展示如…

npm、yarn与pnpm详解

&#x1f525; npm、yarn与pnpm详解 &#x1f516; 一、npm &#x1f50d; 简介&#xff1a; npm是随Node.js一起安装的官方包管理工具&#xff0c;它为开发者搭建了一个庞大的资源库&#xff0c;允许他们在这个平台上搜索、安装和管理项目所必需的各种代码库或模块。 &#…

Intelij Idea Push失败,出现git Authentication failed(验证失败)

目录 1、出现问题的原因 2、解决之法 1、出现问题的原因 能出现这种问题&#xff0c;最主要的原因是链接对上了&#xff0c;但用户验证失败了&#xff0c;即登录失败。 因为服务器转移或者换了git项目链接&#xff0c;导致你忘记了用户名密码&#xff0c;随意输入之后&…

求三个字符数组最大者(C语言)

一、N-S流程图&#xff1b; 二、运行结果&#xff1b; 三、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h> # include <string.h>int main() {//初始化变量值&#xff1b;int i 0;char str[3][20];char string[20];//循环输入3个字符…

AIGC - SD(中英文本生成图片) + PaddleHub/HuggingFace + stable-diffusion-webui

功能 stable-diffusion(文本生成图片)webui-win搭建&#xff08;开启api界面汉化&#xff09;PaddleHubHuggingFace: SD2&#xff0c;中文-alibaba/EasyNLP stable-diffusion-webui 下载与安装 环境相关下载 python&#xff08;文档推荐&#xff1a;Install Python 3.10.6 …