深度学习:基于MindSpore实现CycleGAN壁画修复

关于CycleGAN的基础知识可参考:

深度学习:CycleGAN图像风格迁移转换-CSDN博客

以及MindSpore官方的教学视频:

CycleGAN图像风格迁移转换_哔哩哔哩_bilibili

本案例将基于CycleGAN实现破损草图到线稿图的转换

数据集

本案例使用的数据集里面的图片为经图线稿图数据。图像被统一缩放为256×256像素大小,其中用于训练的线稿图片25654张、草图图片25654张,用于测试的线稿图片100张、草图图片116张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理。

DexiNed

 DexiNed(Dense Extreme Inception Network for Edge Detection)是一个为边缘检测任务设计的深度卷积神经网络模型。它由两个主要部分组成:Dexi和上采样网络(USNet)。

Dexi

这是模型的主要部分,包含六个编码块,每个块由多个子块组成,子块中包含卷积层、批量归一化和ReLU激活函数。从第二个块开始,引入了跳跃连接(skip connections),以保留不同层次的边缘特征。这些块的输出特征图被送入上采样网络以生成中间边缘图。

上采样网络(USNet)

这个部分由多个上采样块组成,每个块包括卷积层和反卷积层(也称为转置卷积层)。USNet的作用是将Dexi输出的低分辨率特征图上采样到更高的分辨率,以生成清晰的边缘图。

卷积层用于提取特征,反卷积层(或转置卷积层)用于将特征图的空间尺寸增大。

损失函数

 DexiNed模型使用的损失函数是专门为边缘检测任务设计的,它在一定程度上受到了BDCN(Bi-directional Cascade Network)损失函数的启发,并进行了一些修改和优化。这个损失函数的目的是在训练过程中平衡正面(正样本)和负面(负样本)的边缘样本比例,从而提高边缘检测的准确性。

损失函数定义为:

其中,

Li 是第 i 个输出的损失,λi 是对应的权重,用于平衡正负样本的比例。具体的 Li 计算方式为:

 DexiNed数据集

DexiNed模型的训练数据集主要是为边缘检测任务设计的高质量数据集。在论文中提到了两个主要的数据集:

  1. BIPED (Barcelona Images for Perceptual Edge Detection):这是一个特别为边缘检测设计的大规模数据集,包含详细的边缘标注信息。它由250张真实世界的图像组成,图像分辨率为1280×720像素,主要描绘城市环境场景。这些图像的边缘通过手动标注生成,以确保边缘检测的准确性。

  2. MDBD (Multicue Dataset for Boundary Detection):这是一个用于边界检测的数据集,也适用于边缘检测任务。它由100个高清图像组成,每个图像有多个参与者的标注,适用于训练和评估边缘检测算法。

DexiNed模型需要成对的数据来进行训练,即每张输入图像都需要有一个对应的标注图像(Ground Truth, GT)。这些标注图像详细地标出了图像中边缘的位置,模型通过比较预测边缘和这些标注来学习如何准确地检测边缘。

DexiNed在本例中主要用于将彩色图片转化为线稿图,随后将线稿图输入CycleGAN,得到输出。

基于MindSpore的壁画修复

加载数据集

#下载数据集
from download import downloadurl = "https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/Cyc_line.zip"download(url, "./localdata", kind="zip", replace=True)
from __future__ import division
import math
import numpy as npimport os
import multiprocessingimport mindspore.dataset as de
import mindspore.dataset.vision as vision"""数据集分布式采样器"""
class DistributedSampler:"""Distributed sampler."""def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):if num_replicas is None:print("***********Setting world_size to 1 since it is not passed in ******************")num_replicas = 1if rank is None:print("***********Setting rank to 0 since it is not passed in ******************")rank = 0self.dataset_size = dataset_sizeself.num_replicas = num_replicasself.rank = rankself.epoch = 0self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))self.total_size = self.num_samples * self.num_replicasself.shuffle = shuffledef __iter__(self):# deterministically shuffle based on epochif self.shuffle:indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)# np.array type. number from 0 to len(dataset_size)-1, used as index of datasetindices = indices.tolist()self.epoch += 1# change to list typeelse:indices = list(range(self.dataset_size))# add extra samples to make it evenly divisibleindices += indices[:(self.total_size - len(indices))]assert len(indices) == self.total_size# subsampleindices = indices[self.rank:self.total_size:self.num_replicas]assert len(indices) == self.num_samplesreturn iter(indices)def __len__(self):return self.num_samples
# 加载CycleGAN数据集
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']# 判断当前文件是否为图片
def is_image_file(filename):return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)# 定义一个函数用于从指定目录中创建数据集列表
def make_dataset(dir_path, max_dataset_size=float("inf")):# 初始化一个空列表用来存储图片路径images = []# 确保提供的dir_path是一个有效的目录assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path# 遍历目录下的所有文件,将图片的文件路径存入images列表for root, _, fnames in sorted(os.walk(dir_path)):for fname in fnames:if is_image_file(fname):path = os.path.join(root, fname)images.append(path)# 返回指定长度的图片路径列表return images[:min(max_dataset_size, len(images))]# CycleGAN中没有成对出现,但是分属两个领域的图片数据
class UnalignedDataset:'''此数据集类能够加载未对齐或未配对的数据集。需要两个目录来存放来自领域A和B的训练图片。可以使用'--dataroot /path/to/data'这样的标志来训练模型。同样,在测试时也需要准备两个目录。返回:两个领域的图片路径列表。'''def __init__(self, dataroot, max_dataset_size=float("inf"), use_random=True):# 根据指定根路径生成A\B领域数据的文件夹路径self.dir_A = os.path.join(dataroot, 'trainA')self.dir_B = os.path.join(dataroot, 'trainB')# 领域A图片数据的路径self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))# 领域B图片数据的路径self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))# 领域A的数据长度self.A_size = len(self.A_paths)# 领域B的数据长度self.B_size = len(self.B_paths)# 根据参数决定是否随机化self.use_random = use_random# 从数据集中根据给定的索引 index 获取一个样本对(分别来自领域A和领域B的一张图片)def __getitem__(self, index):# 数据A的索引index_A = index % self.A_size# 数据B的索引index_B = index % self.B_size# 每遍历完所有图片后会重新随机排序领域A中的图片路径列表,并且从领域B中随机选取图片。if index % max(self.A_size, self.B_size) == 0 and self.use_random:random.shuffle(self.A_paths)index_B = random.randint(0, self.B_size - 1)# 获取指定下标的图片路径A_path = self.A_paths[index_A]B_path = self.B_paths[index_B]# 获取图片对象A_img = np.array(Image.open(A_path).convert('RGB'))B_img = np.array(Image.open(B_path).convert('RGB'))# 返回领域A和B的图片return A_img, B_imgdef __len__(self):return max(self.A_size, self.B_size)
def create_dataset(dataroot, batch_size=1, use_random=True, device_num=1, rank=0, max_dataset_size=float("inf"), image_size=256):"""创建数据集该数据集类可以加载用于训练或测试的图像。参数:dataroot (str): 图像根目录。batch_size (int): 批处理大小,默认为1。use_random (bool): 是否使用随机化,默认为True。device_num (int): 设备数量,默认为1。rank (int): 当前设备的排名,默认为0。max_dataset_size (float): 数据集的最大大小,默认为无穷大。image_size (int): 图像的尺寸,默认为256x256。返回:RGB图像列表。"""shuffle = use_random # 是否打乱数据集# 获取系统可用的CPU核心数cores = multiprocessing.cpu_count()# 计算并行工作的线程数,根据设备数量分配num_parallel_workers = min(1, int(cores / device_num))# 定义归一化时使用的均值和标准差# 三个通道的均值和房擦汗都是127.5mean = [0.5 * 255] * 3std = [0.5 * 255] * 3# 创建数据集(未对齐)dataset = UnalignedDataset(dataroot, max_dataset_size=max_dataset_size, use_random=use_random)# 使用DistributedSampler来实现分布式采样distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)# 创建GeneratorDataset,指定列名,并使用之前创建的sampler和并行工作线程数ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)# 指定数据增强操作if use_random:trans = [# 图片随机裁剪变比例vision.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),# 水平翻转,概率为0.5vision.RandomHorizontalFlip(prob=0.5),# 图片数据归一化vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:  # 如果不启用随机化,则只进行简单的缩放和归一化trans = [C.Resize((image_size, image_size)),  # 固定大小缩放C.Normalize(mean=mean, std=std),  # 归一化C.HWC2CHW()  # 将HWC格式转换为CHW格式]# 将数据增强操作映射到数据中ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)# 设置批处理大小,并且丢弃不足一批的数据ds = ds.batch(batch_size, drop_remainder=True)return ds
#根据设备情况调整训练参数dataroot = "./localdata"
batch_size = 12
device_num = 1
rank = 0
use_random = True
max_dataset_size = 24000
image_size = 256cyclegan_ds = create_dataset(dataroot=dataroot,max_dataset_size=max_dataset_size,batch_size=batch_size,device_num=device_num,rank = rank,use_random=use_random,image_size=image_size)
datasize = cyclegan_ds.get_dataset_size()
print("Datasize: ", datasize)

构建和训练CycleGAN的代码不再重复贴出,可参考上述提到的博文

构建DexiNed

import os
import cv2
import numpy as np
import timeimport mindspore as ms
from mindspore import nn, ops
from mindspore import dataset as ds
from mindspore.amp import auto_mixed_precision
from mindspore.common import initializer as init
# DexiNed边缘检测数据集
class Test_Dataset():def __init__(self, data_root, mean_bgr, image_size):self.data = []# 列出根路径下的所有文件名imgs_ = os.listdir(data_root)# 初始化两个空列表,分别用于存储图像路径和对应的文件名self.names = []self.filenames = []for img in imgs_:if img.endswith(".png") or img.endswith(".jpg"):# 构建文件的完整路径dir = os.path.join(data_root, img)self.names.append(dir)self.filenames.append(img)self.mean_bgr = mean_bgrself.image_size = image_sizedef __len__(self):return len(self.names)def __getitem__(self, idx):# 使用OpenCV读取指定索引位置的图像,读取模式为彩色image = cv2.imread(self.names[idx], cv2.IMREAD_COLOR)# 读取图片的长宽im_shape = (image.shape[0], image.shape[1])# 对图像进行变换处理image = self.transform(img=image)# 返回图像、图像名、图像形状return image, self.filenames[idx], im_shapedef transform(self, img):# 裁切img = cv2.resize(img, (self.image_size, self.image_size))# 将图像转换为浮点数类型img = np.array(img, dtype=np.float32)# 归一化img -= self.mean_bgr# 将图像从(H, W, C)格式转换为(C, H, W)格式img = img.transpose((2, 0, 1))return img
# DexiNed网络结构# 初始化权重函数
def weight_init(net):for name, param in net.parameters_and_names():# 使用Xavier分布初始化权重if 'weight' in name:param.set_data(init.initializer(init.XavierNormal(),param.shape,param.dtype))# 偏置初始化为0if 'bias' in name:param.set_data(init.initializer('zeros', param.shape, param.dtype))
# 表示DexiNed中的一个基础的密集连接层,实现具有批量归一化和ReLU激活的卷积层
class _DenseLayer(nn.Cell):def __init__(self, input_features, out_features):super(_DenseLayer, self).__init__()# 两个ConvNormReLU块self.conv1 = nn.Conv2d(input_features, out_features, kernel_size=3,stride=1, padding=2, pad_mode="pad",has_bias=True, weight_init=init.XavierNormal())self.norm1 = nn.BatchNorm2d(out_features)self.relu1 = nn.ReLU()self.conv2 = nn.Conv2d(out_features, out_features, kernel_size=3,stride=1, pad_mode="pad", has_bias=True,weight_init=init.XavierNormal())self.norm2 = nn.BatchNorm2d(out_features)self.relu = ops.ReLU()def construct(self, x):x1, x2 = xx1 = self.conv1(self.relu(x1))x1 = self.norm1(x1)x1 = self.relu1(x1)x1 = self.conv2(x1)new_features = self.norm2(x1)# 每一层的输出都会被传递给所有后续层作为输入return 0.5 * (new_features + x2), x2
# 基于DenseLayer定义DenseBlock
class _DenseBlock(nn.Cell):def __init__(self, num_layers, input_features, out_features):super(_DenseBlock, self).__init__()self.denselayer1 = _DenseLayer(input_features, out_features)input_features = out_featuresself.denselayer2 = _DenseLayer(input_features, out_features)if num_layers == 3:self.denselayer3 = _DenseLayer(input_features, out_features)self.layers = nn.SequentialCell([self.denselayer1, self.denselayer2, self.denselayer3])else:self.layers = nn.SequentialCell([self.denselayer1, self.denselayer2])def construct(self, x):x = self.layers(x)return x
# 表示上采样块,这是USNet的一部分,用于将特征图的尺寸增大。
class UpConvBlock(nn.Cell):def __init__(self, in_features, up_scale):super(UpConvBlock, self).__init__()# 定义上采样的因子,默认为2self.up_factor = 2# 定义一个常量特征数,通常用于控制输出通道的数量self.constant_features = 16layers = self.make_deconv_layers(in_features, up_scale)assert layers is not None, layersself.features = nn.SequentialCell(*layers)# # 构建上采样层def make_deconv_layers(self, in_features, up_scale):layers = []all_pads = [0, 0, 1, 3, 7]# 根据up_scale循环创建相应的层# 逐步放大for i in range(up_scale):# 定义卷积核大小kernel_size = 2 ** up_scale# 获取填充大小pad = all_pads[up_scale]# 计算输出维度out_features = self.compute_out_features(i, up_scale)# 创建反卷积层layers.append(nn.Conv2d(in_features, out_features, 1, has_bias=True))layers.append(nn.ReLU())layers.append(nn.Conv2dTranspose(out_features, out_features, kernel_size,stride=2, padding=pad, pad_mode="pad",has_bias=True, weight_init=init.XavierNormal()))# 更新通道数in_features = out_featuresreturn layers# 计算当前输出通道数def compute_out_features(self, idx, up_scale):return 1 if idx == up_scale - 1 else self.constant_featuresdef construct(self, x):return self.features(x)
# 单个卷积块,包含Conv和BatchNorm
class SingleConvBlock(nn.Cell):def __init__(self, in_features, out_features, stride, use_bs=True):super().__init__()self.use_batch_norm = use_bsself.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, pad_mode='pad', has_bias=True, weight_init=init.XavierNormal())self.bn = nn.BatchNorm2d(out_features)def construct(self, x):x = self.conv(x)if self.use_batch_norm:x = self.bn(x)return x
# 双卷积块
class DoubleConvBlock(nn.Cell):def __init__(self, in_features, mid_features,out_features=None,stride=1,use_act=True):super(DoubleConvBlock, self).__init__()self.use_act = use_actif out_features is None:out_features = mid_featuresself.conv1 = nn.Conv2d(in_features,mid_features,3,padding=1,stride=stride,pad_mode="pad",has_bias=True,weight_init=init.XavierNormal())self.bn1 = nn.BatchNorm2d(mid_features)self.conv2 = nn.Conv2d(mid_features,out_features,3,padding=1,pad_mode="pad",has_bias=True,weight_init=init.XavierNormal())self.bn2 = nn.BatchNorm2d(out_features)self.relu = nn.ReLU()def construct(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)if self.use_act:x = self.relu(x)return x
# 自定义最大汇聚层
class maxpooling(nn.Cell):def __init__(self):super(maxpooling, self).__init__()self.pad = nn.Pad(((0,0),(0,0),(1,1),(1,1)), mode="SYMMETRIC")self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')def construct(self, x):x = self.pad(x)x = self.maxpool(x)return x
# 组件DexiNed网络
class DexiNed(nn.Cell):def __init__(self):super(DexiNed, self).__init__()#  DoubleConvBlock(双卷积块)实例,用于构建DexiNed的编码部分。# 处理输入图像和第一层的输出。self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)self.block_2 = DoubleConvBlock(64, 128, use_act=False)# 于实现DexiNed中的密集连接层# 用于提取多尺度特征self.dblock_3 = _DenseBlock(2, 128, 256)  # [128,256,100,100]self.dblock_4 = _DenseBlock(3, 256, 512)self.dblock_5 = _DenseBlock(3, 512, 512)self.dblock_6 = _DenseBlock(3, 512, 256)# 最大池化层,用于下采样操作。self.maxpool = maxpooling()# 用于从不同层次提取特征。self.side_1 = SingleConvBlock(64, 128, 2)self.side_2 = SingleConvBlock(128, 256, 2)self.side_3 = SingleConvBlock(256, 512, 2)self.side_4 = SingleConvBlock(512, 512, 1)self.side_5 = SingleConvBlock(512, 256, 1)  # 用于准备输入到密集块的数据。# right skip connections, figure in Journal paperself.pre_dense_2 = SingleConvBlock(128, 256, 2)self.pre_dense_3 = SingleConvBlock(128, 256, 1)self.pre_dense_4 = SingleConvBlock(256, 512, 1)self.pre_dense_5 = SingleConvBlock(512, 512, 1)self.pre_dense_6 = SingleConvBlock(512, 256, 1)# 上采样块,用于恢复特征图的尺寸。self.up_block_1 = UpConvBlock(64, 1)self.up_block_2 = UpConvBlock(128, 1)self.up_block_3 = UpConvBlock(256, 2)self.up_block_4 = UpConvBlock(512, 3)self.up_block_5 = UpConvBlock(512, 4)self.up_block_6 = UpConvBlock(256, 4)# 单卷积块,用于将多个尺度的特征图融合成一个单一的输出。self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False)# 如果张量的形状与目标形状不匹配,则使用双线性插值进行调整;否则直接返回张量。def slice(self, tensor, slice_shape):t_shape = tensor.shapeheight, width = slice_shapeif t_shape[-1] != slice_shape[-1]:new_tensor = ops.interpolate(tensor,sizes=(height, width),mode='bilinear',coordinate_transformation_mode="half_pixel")else:new_tensor = tensorreturn new_tensordef construct(self, x):# 确保输入张量 x 是四维的。assert x.ndim == 4, x.shape# 通过 block_1 处理输入 x,并通过 side_1 提取特征。# Block 1block_1 = self.block_1(x)block_1_side = self.side_1(block_1)# 通过 block_2 处理 block_1 的输出,然后通过 maxpool 降采样,并与 block_1_side 相加。再通过 side_2 提取特征。# Block 2block_2 = self.block_2(block_1)block_2_down = self.maxpool(block_2)block_2_add = block_2_down + block_1_sideblock_2_side = self.side_2(block_2_add)# 通过 pre_dense_3 处理 block_2_down,并将其与 block_2_add 一起传递给 dblock_3。# 然后通过 maxpool 降采样,并与 block_2_side 相加。再通过 side_3 提取特征。# Block 3block_3_pre_dense = self.pre_dense_3(block_2_down)block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])block_3_down = self.maxpool(block_3)  # [128,256,50,50]block_3_add = block_3_down + block_2_sideblock_3_side = self.side_3(block_3_add)# 通过 pre_dense_2 和 pre_dense_4 处理 block_2_down 和 block_3_down,并将它们相加后传递给 dblock_4。# 然后通过 maxpool 降采样,并与 block_3_side 相加。再通过 side_4 提取特征。# Block 4block_2_resize_half = self.pre_dense_2(block_2_down)block_4_pre_dense = self.pre_dense_4(block_3_down + block_2_resize_half)block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])block_4_down = self.maxpool(block_4)block_4_add = block_4_down + block_3_sideblock_4_side = self.side_4(block_4_add)# 通过 pre_dense_5 处理 block_4_down,并将其与 block_4_add 一起传递给 dblock_5。然后与 block_4_side 相加。# Block 5block_5_pre_dense = self.pre_dense_5(block_4_down)  # block_5_pre_dense_512 +block_4_downblock_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])block_5_add = block_5 + block_4_side# 通过 pre_dense_6 处理 block_5,并将其与 block_5_add 一起传递给 dblock_6。# Block 6block_6_pre_dense = self.pre_dense_6(block_5)block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])# upsampling blocks# 对每个块的输出进行上采样,恢复特征图的尺寸。out_1 = self.up_block_1(block_1)out_2 = self.up_block_2(block_2)out_3 = self.up_block_3(block_3)out_4 = self.up_block_4(block_4)out_5 = self.up_block_5(block_5)out_6 = self.up_block_6(block_6)results = [out_1, out_2, out_3, out_4, out_5, out_6]# 将所有上采样的输出拼接在一起,并通过 block_cat 进行最后的融合,生成最终的边缘图。# concatenate multiscale outputsop = ops.Concat(1)block_cat = op(results)block_cat = self.block_cat(block_cat)  # Bx1xHxWresults.append(block_cat)# 返回包含多个尺度的边缘图和最终融合后的边缘图的结果列表。return results

使用DexiNed进行推理

'''将输入图像规格化到指定范围'''
def image_normalization(img, img_min=0, img_max=255, epsilon=1e-12):img = np.float32(img)img = (img - np.min(img)) * (img_max - img_min) / \((np.max(img) - np.min(img)) + epsilon) + img_minreturn img
'''对DexiNed模型的输出数据进行后处理'''
# DexiNed 模型会输出多个尺度的边缘图(results列表中包含了多个上采样的结果),
# 这个函数将这些边缘图进行融合,并应用一些图像处理技术来生成最终的边缘检测结果。
def fuse_DNoutput(img):edge_maps = []tensor = imgfor i in tensor:sigmoid = ops.Sigmoid()output = sigmoid(i).numpy()edge_maps.append(output)tensor = np.array(edge_maps)idx = 0tmp = tensor[:, idx, ...]tmp = np.squeeze(tmp)preds = []for i in range(tmp.shape[0]):tmp_img = tmp[i]tmp_img = np.uint8(image_normalization(tmp_img))tmp_img = cv2.bitwise_not(tmp_img)preds.append(tmp_img)if i == 6:fuse = tmp_imgfuse = fuse.astype(np.uint8)idx += 1return fuse
"""DexiNed 检测."""def test(imgs,dexined_ckpt):if not os.path.isfile(dexined_ckpt):raise FileNotFoundError(f"Checkpoint file not found: {dexined_ckpt}")print(f"DexiNed ckpt path : {dexined_ckpt}")# os.makedirs(dexined_output_dir, exist_ok=True)model = DexiNed()# model = auto_mixed_precision(model, 'O2')ms.load_checkpoint(dexined_ckpt, model)model.set_train(False)preds = []origin = []total_duration = []print('Start dexined testing....')for img in imgs.create_dict_iterator():filename = str(img["names"])[2:-2]# print(filename)# output_dir_f = os.path.join(dexined_output_dir, filename)image = img["data"]origin.append(filename)end = time.perf_counter()# 使用DexiNed进行预测pred = model(image)# 获取图片宽高img_h = img["img_shape"][0, 0]img_w = img["img_shape"][0, 1]# 调用 fuse_DNoutput 函数对模型的输出进行后处理。pred = fuse_DNoutput(pred)# 将处理后的边缘图调整为原始图像的尺寸。dexi_img = cv2.resize(pred, (int(img_w.asnumpy()), int(img_h.asnumpy())))# cv2.imwrite("output.jpg", dexi_img)tmp_duration = time.perf_counter() - endtotal_duration.append(tmp_duration)preds.append(pred)total_duration_f = np.sum(np.array(total_duration))print("FPS: %f.4" % (len(total_duration) / total_duration_f))# 返回处理后的边缘图列表 preds 和原始图像文件名列表 origin。return preds,origin

DexiNed结合CycleGAN对壁画进行修复

import os
import numpy as np
from PIL import Image
import mindspore.dataset as ds
import matplotlib.pyplot as plt
import mindspore.dataset.vision as vision
from mindspore.dataset import transforms
from mindspore import load_checkpoint, load_param_into_net# 加载权重文件
def load_ckpt(net, ckpt_dir):param_GA = load_checkpoint(ckpt_dir)load_param_into_net(net, param_GA)#模型参数地址
g_a_ckpt = './ckpt/G_A_120.ckpt'
dexined_ckpt = "./ckpt/dexined.ckpt"#图片输入地址
img_path='./ckpt/jt'
#输出地址
save_path='./result'load_ckpt(net_rg_a, g_a_ckpt)os.makedirs(save_path, exist_ok=True)
# 图片推理
fig = plt.figure(figsize=(16, 4), dpi=64)
def eval_data(dir_path, net, a):my_dataset = Test_Dataset(dir_path, mean_bgr=[167.15, 146.07, 124.62], image_size=512)dataset = ds.GeneratorDataset(my_dataset, column_names=["data", "names", "img_shape"])dataset = dataset.batch(1, drop_remainder=True)# 使用DexiNed将原图转为线稿图preds ,origin= test(dataset,dexined_ckpt)for i, data in enumerate(preds):# 读取线稿图img =ms.Tensor((np.array([data,data,data])/255-0.5)*2).unsqueeze(0)# 将线稿图放入生成器,生成假图fake = net(img.to(ms.float32))fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))# 找到线稿图对应的原图,用作后续输出img = (Image.open(os.path.join(img_path,origin[i])).convert('RGB'))# 保存CycleGAN生成的结果fake_pil=Image.fromarray(fake.asnumpy())fake_pil.save(f"{save_path}/{i}.jpg")# 展示推理结果if i<8:fig.add_subplot(2, 8, min(i+1+a, 16))plt.axis("off")plt.imshow(np.array(img))fig.add_subplot(2, 8, min(i+9+a, 16))plt.axis("off")plt.imshow(fake.asnumpy())eval_data(img_path,net_rg_a, 0)plt.show()

输出结果如下:

详细可以参考MindSpore官方教学视频:

基于MindSpore实现CycleGAN壁画修复_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/438786.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【含文档】基于Springboot+Vue的护肤品推荐系统(含源码+数据库+lw)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统定…

企望制造ERP系统存在RCE漏洞

漏洞描述 企望制造纸箱业erp系统由深知纸箱行业特点和业务流程的多位IT专家打造&#xff0c;具有国际先进的管理方式&#xff0c;将现代化的管理方式融入erp软件中&#xff0c;让企业分分钟就拥有科学的管理经验。erp的功能包括成本核算、报价定价、订单下达、生产下单、现场管…

五子棋双人对战项目(3)——匹配模块

目录 一、分析需求 二、约定前后端交互接口 匹配请求&#xff1a; 匹配响应&#xff1a; 三、实现游戏大厅页面&#xff08;前端代码&#xff09; game_hall.html&#xff1a; common.css&#xff1a; game_hall.css&#xff1a; 四、实现后端代码 WebSocketConfig …

vue3 环境配置vue-i8n国际化

一.依赖和插件的安装 主要是vue-i18n和 vscode的自动化插件i18n Ally https://vue-i18n.intlify.dev/ npm install vue-i18n10 pnpm add vue-i18n10 yarn add vue-i18n10 vscode在应用商城中搜索i18n Ally&#xff1a;如图 二.实操 安装完以后在对应项目中的跟package.jso…

计算机毕业设计 基于协同过滤算法的个性化音乐推荐系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

Charles+socksdroid手机抓包配置

证书配置 保存一个证书 使用abd将证书推送到手机 找手机的加密与凭据 点击从存储设备安装 选择刚刚导入手机的证书 证书按照成功 手机安装socksdroid 端口对应 ip对应 开启 点击allow 成功手机抓包 将用户证书移动到系统证书 系统证书路径&#xff1a;/etc/security/cacerts…

【springboot】整合LoadBalancer

目录 问题产生背景解决方案&#xff1a;实现LoadBalancer1. 添加依赖2. 配置文件3. 使用LoadBalancer4. 使用 RestTemplate 进行服务调用5. 测试 问题产生背景 以下是一个购物车项目&#xff0c;通过调用外部接口获取商品信息&#xff0c;并添加到购物车中&#xff0c;这段代码…

【Android 14源码分析】WMS-窗口显示-第二步:relayoutWindow -1

忽然有一天&#xff0c;我想要做一件事&#xff1a;去代码中去验证那些曾经被“灌输”的理论。                                                                                  – 服装…

【JAVA开源】基于Vue和SpringBoot的宠物咖啡馆平台

本文项目编号 T 064 &#xff0c;文末自助获取源码 \color{red}{T064&#xff0c;文末自助获取源码} T064&#xff0c;文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析 六、核心代码6.1 查…

使用CSS实现酷炫加载

使用CSS实现酷炫加载 效果展示 整体页面布局 <div class"container"></div>使用JavaScript添加loading加载动画的元素 document.addEventListener("DOMContentLoaded", () > {let container document.querySelector(".container&q…

初识Linux · 自主Shell编写

目录 前言&#xff1a; 1 命令行解释器部分 2 获取用户命令行参数 3 命令行参数进行分割 4 执行命令 5 判断命令是否为内建命令 前言&#xff1a; 本文介绍是自主Shell编写&#xff0c;对于shell&#xff0c;即外壳解释程序&#xff0c;我们目前接触到的命令行解释器&am…

数据提取之JSON与JsonPATH

第一章 json 一、json简介 json简单说就是javascript中的对象和数组&#xff0c;所以这两种结构就是对象和数组两种结构&#xff0c;通过这两种结构可以表示各种复杂的结构 > 1. 对象&#xff1a;对象在js中表示为{ }括起来的内容&#xff0c;数据结构为 { key&#xff1…

区块链+Web3学习笔记(METAMASHK、密码学知识)

学习资料来源于B站&#xff1a; 17小时最全Web3教程&#xff1a;ERC20&#xff0c;NFT&#xff0c;Hardhat&#xff0c;CCIP跨链_哔哩哔哩_bilibili 该课程提供的Github代码地址&#xff0c;相关资料详见README.md&#xff1a; Web3_tutorial_Chinese/README.md at main sm…

银河麒麟系统内存清理

银河麒麟系统内存清理 1、操作步骤2、注意事项 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 当银河麒麟系统运行较长时间&#xff0c;内存中的缓存可能会积累过多&#xff0c;影响系统性能。此时&#xff0c;你可以通过简单的命令来清理这…

JS | 如何解决ajax无法后退的问题?

Ajax请求通常不支持浏览器的后退按钮&#xff0c;因为它们是异步的&#xff0c;不会导致页面重新加载(刷新)。但如果你想要用户能够通过浏览器的后退按钮回到之前的页面状态&#xff0c;你可以通过几种方法来解决这个问题&#xff1a; 1、使用pushState和replaceState方法 hi…

【Android】数据存储

本章介绍Android五种主要存储方式的用法&#xff0c;包括共享参数SharedPreferences、数据库SQLite、SD卡文件、App的全局内存&#xff0c;另外介绍重要组件之一的应用Application的基本概念与常见用法&#xff0c;以及四大组件之一的内容提供器ContentProvider的基本概念与常见…

五.海量数据实时分析-FlinkCDC+DorisConnector实现数据的全量增量同步

前言 前面四篇文字都在学习Doris的理论知识&#xff0c;也是比较枯燥&#xff0c;当然Doris的理论知识还很多&#xff0c;我们后面慢慢学&#xff0c;本篇文章我们尝试使用SpringBoot来整合Doris完成基本的CRUD。 由于 Doris 高度兼容 Mysql 协议&#xff0c;两者在 SQL 语法…

Redis数据库与GO(二):list,set

一、list&#xff08;列表&#xff09; list&#xff08;列表&#xff09;是简单的字符串列表&#xff0c;按照插入顺序排序。你可以添加一个元素到列表的头部(左边)或者尾部(右边)。List本质是个链表&#xff0c; list是一个双向链表&#xff0c;其元素是有序的&#xff0c;元…

GS-SLAM论文阅读笔记-CaRtGS

前言 这篇文章看起来有点像Photo-slam的续作&#xff0c;行文格式和图片类型很接近&#xff0c;而且貌似是出自同一所学校的&#xff0c;所以推测可能是Photo-slam的优化与改进方法&#xff0c;接下来具体看看改进了哪些地方。 文章目录 前言1.背景介绍GS-SLAM方法总结 2.关键…

uniapp+Android面向网络学习的时间管理工具软件 微信小程序

目录 项目介绍支持以下技术栈&#xff1a;具体实现截图HBuilderXuniappmysql数据库与主流编程语言java类核心代码部分展示登录的业务流程的顺序是&#xff1a;数据库设计性能分析操作可行性技术可行性系统安全性数据完整性软件测试详细视频演示源码获取方式 项目介绍 用户功能…