经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

1 PixelCNN

  • PixelCNN是DeepMind团队在论文Pixel Recurrent Neural Networks (16.01)提出的一种生成模型,实际上这篇论文共提出了两种架构:PixelRNNPixelCNN,两者的主要区别是前者用LSTM来建模,而PixelCNN是基于CNN的,相比RNN,CNN计算更高效,我们这里只讨论PixelCNN。

  • PixelCNN借用了NLP里的方法来生成图像。对于自然图像,每个像素值的取值范围为0~255,共256个离散值。PixelCNN模型会根据前i - 1个像素输出第i个像素的概率分布。

  • 训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数

  • 采样时(图像生成时),会根据前i - 1个像素直接从预测的概率分布(多项分布)里采样出第i个像素。

1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大体思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只考虑单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

在这里插入图片描述

但是只输出一个像素的概率分布,这样训练效率太低了。

  • 在训练时,我们可以输入一幅图像,同时让模型输出图像每一点像素的概率分布(如下图所示),这样就能通过每个像素的真值和模型预测的概率分布求交叉熵损失函数,进行并行训练。
  • 我们能这么做的原因是:在训练时,整幅训练图像是已知的,因此我们可以在一次前向传播后得到图像每一处的概率分布。
  • 当然,我们需要找到每个像素都忽略后续像素的信息的方法,即论文中提出的掩码卷积机制,我们后面再讲。

在这里插入图片描述

但是在生成图像(采样)时,还是要一个像素一个像素的生成(如下所示)

  • 在采样时,我们会先根据前i - 1个像素输出第i个像素的概率分布。
  • 然后,我们会从第i个像素的概率分布中进行采样(如下面代码所示)
# 假设颜色取值范围为[0, 7],下面为概率分布
prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])# 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
# torch.multinomial会从input这个概率分布中,取num_samples个值
pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])

在这里插入图片描述

我们现在已经知道了训练及采样的大体过程。但是,我们现在还是有一个疑问,如何保证训练时候,每个像素都忽略后续像素的信息?

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。

  • 具体来说,PixelCNN使用了两类掩码卷积:
    • 我们把两类掩码卷积分别称为「A类」和「B类」。
    • 二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。
    • A类和B类的唯一区别在于:卷积核的中心像素是否产生贡献
    • CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积

在这里插入图片描述

我们来分析下这样设计的优点:

  • 对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)
    • 经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。
    • 再经过两次B类掩码卷积后,中心像素能够看到左上角大部分像素的信息(如下图所示,我们发现还是会看漏少部分的信息,后面的Gated PixelCNN对此进行了改进)。
    • 这满足PixelCNN的约束。

在这里插入图片描述

  • 如果一直使用A类掩码卷积,每次卷积后中心像素都会看漏一些信息,最终就会导致看漏很多信息

在这里插入图片描述

  • 如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。

总结如下:

  • 逐像素预测只依赖于前面的像素,因此在选择卷积核时要进行掩码操作避免看到未来的值,因此,在第一层预测时可采用掩码卷积A
  • 由于CNN的逐像素预测是多层卷积,所以当第一层结束后,图像缺失部分已经有了预测值,因此在进行下一次/层卷积操作时可以利用当前像素的预测值,因此采用下列掩码卷积B
  • 需要注意的是,这里只考虑了单通道,如果扩展到RGB三个通道时,该如何进行mask呢?

1.1.2 PixelCNN的网络架构

  • 利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。
  • 我们可以用任意一种CNN架构来实现PixelCNN。
  • 下图红色框所示部分是PixelCNN的网络结构,其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

在这里插入图片描述

1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模型

实现PixelCNN,最重要的是实现掩码卷积。

  • 掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。
  • 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import osclass MaskConv2d(nn.Module):"""掩码卷积的实现思路:在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积"""def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]# 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2] = 1mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1# 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作mask = mask.reshape((1, 1, H, W))# register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中# 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。# 第三个参数表示被注册的变量是否要加入state_dict中以保存下来self.register_buffer(name='mask', tensor=mask, persistent=False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来

在这里插入图片描述

  • 我们先实现残差块上图右部分的ResidualBlock,这里添加归一化
class ResidualBlock(nn.Module):"""残差块ResidualBlock"""def __init__(self, h, bn=True):super().__init__()self.relu = nn.ReLU()self.conv1 = nn.Conv2d(2 * h, h, 1)self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()self.conv3 = nn.Conv2d(h, 2 * h, 1)self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()def forward(self, x):# 1、ReLU + 1×1 Conv + bny = self.relu(x)y = self.conv1(y)y = self.bn1(y)# 2、ReLU + 3×3 Conv(mask B) + bny = self.relu(y)y = self.conv2(y)y = self.bn2(y)# 3、ReLU + 1×1 Conv + bny = self.relu(y)y = self.conv3(y)y = self.bn3(y)# 4、残差连接y = y + xreturn y
  • 有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。
  • 注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数color_level。
class PixelCNN(nn.Module):def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):super().__init__()self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()self.residual_blocks = nn.ModuleList()for _ in range(n_blocks):self.residual_blocks.append(ResidualBlock(h, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):# 1、7 × 7 conv(mask A)x = self.conv1(x)x = self.bn1(x)# 2、Multiple residual blocksfor block in self.residual_blocks:x = block(x)x = self.relu(x)# 3、1 × 1 convx = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x
1.1.3.2 数据集及训练

准备好了模型代码,我们可以编写训练脚本了:

  • PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32
def get_dataloader(batch_size: int):dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',train=True,transform=ToTensor())return DataLoader(dataset, batch_size=batch_size, shuffle=True)def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):"""训练过程"""dataloader = get_dataloader(batch_size)model = model.to(device)optimizer = torch.optim.Adam(model.parameters(), 1e-3)loss_fn = nn.CrossEntropyLoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x, _ in dataloader:current_batch_size = x.shape[0]x = x.to(device)# 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签y = torch.ceil(x * (color_level - 1)).long()y = y.squeeze(1)predict_y = model(x)loss = loss_fn(predict_y, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), model_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')if __name__ == '__main__':os.makedirs('work_dirs', exist_ok=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'# 需要注意的是:MNIST数据集的大部分像素都是0和255color_level = 8  # or 256# 1、创建PixelCNN模型model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)# 2、模型训练model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'train(model, device, model_path)# 3、采样sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
1.1.3.3 采样
  • 在采样时,我们把x初始化成一个0张量。
  • 之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.
def sample(model, device, model_path, output_path, n_sample=1):"""把x初始化成一个0张量。循环遍历每一个像素,输入x,把预测出的下一个像素填入x"""model.eval()model.load_state_dict(torch.load(model_path))model = model.to(device)C, H, W = get_img_shape()  # (1, 28, 28)x = torch.zeros((n_sample, C, H, W)).to(device)with torch.no_grad():for i in range(H):for j in range(W):# 我们先获取模型的输出,再用softmax转换成概率分布output = model(x)prob_dist = F.softmax(output[:, :, i, j], -1)# 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值# 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)# 最后把新像素填入到生成图像中x[:, :, i, j] = pixel# 乘255变成一个用8位字节表示的图像imgs = x * 255imgs = imgs.clamp(0, 255)imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))imgs = imgs.detach().cpu().numpy().astype(np.uint8)cv2.imwrite(output_path, imgs)

1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响

  • 其中红色预测不受蓝色和绿色通道的影响,只受上下文影响
  • 绿色红色通道和上下文影响,但不受蓝色通道影响;
  • 蓝色通道受上下文、红色通道、绿色通道影响

在这里插入图片描述

更具体地,我们规定一个子像素只由它之前的子像素决定,生成图像时,我们一个子像素一个子像素地生成

  • 如下图所示,对于RGB图像,R子像素由它之前所有像素决定
  • G子像素由它的R子像素和之前所有像素决定,
  • B子像素由它的R、G子像素和它之前所有像素决定。

在这里插入图片描述

如下图所示,由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量

  • 即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。
  • 同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

在这里插入图片描述

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  • 对于通道间的约束,我们要在o, i两个维度上设置掩码,如下图左边所示。
    • 设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3
      • o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:o
      • i1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i
      • 序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。
    • 我们对输入通道组和输出通道组之间进行约束。
    • 对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3
    • 对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3
  • 如下图右边所示,对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

在这里插入图片描述

  • 下面给出三维掩码示意图方便理解:

在这里插入图片描述

2 Gated PixelCNN

2.1 Gated PixelCNN简述

  • 可以参考大神讲解:Gated PixelCNN (sergeiturukin.com)

  • PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,中心像素看不到右上角三个本应该能看到的像素。

在这里插入图片描述

  • 为此,PixelCNN论文的作者又发表了Conditional Image Generation with PixelCNN Decoders(16.06)。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。

在这里插入图片描述

  • 如下图所示,Gated PixelCNN使用了两种卷积,即垂直卷积和水平卷积,来分别维护一个像素上侧的信息和左侧的信息
    • 垂直卷积的结果只是一些临时量
    • 而水平卷积的结果最终会被网络输出
    • 使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

在这里插入图片描述

  • Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
  • 模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。
  • 垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。
  • 两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。
  • 此外,模块右侧还有一个残差连接。

在这里插入图片描述

2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模型

  • 首先,实现垂直卷积和水平卷积
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import osclass VerticalMaskConv2d(nn.Module):"""垂直卷积"""def __init__(self, *args, **kwags):super().__init__()self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[0:H // 2 + 1] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_resclass HorizontalMaskConv2d(nn.Module):"""水平卷积"""def __init__(self, conv_type, *args, **kwags):super().__init__()assert conv_type in ('A', 'B')self.conv = nn.Conv2d(*args, **kwags)H, W = self.conv.weight.shape[-2:]mask = torch.zeros((H, W), dtype=torch.float32)mask[H // 2, 0:W // 2] = 1if conv_type == 'B':mask[H // 2, W // 2] = 1mask = mask.reshape((1, 1, H, W))self.register_buffer('mask', mask, False)def forward(self, x):self.conv.weight.data *= self.maskconv_res = self.conv(x)return conv_res
# 垂直卷积
tensor([[[[1., 1., 1.],[1., 1., 1.],[0., 0., 0.]]]])
# A类水平卷积
tensor([[[[0., 0., 0.],[1., 0., 0.],[0., 0., 0.]]]])
# B类水平卷积
tensor([[[[0., 0., 0.],[1., 1., 0.],[0., 0., 0.]]]])
  • 我们现在搭建Gated Block模块,这也是最难理解的一部分。
  • 可以参考的解释:https://segmentfault.com/a/1190000041189859?utm_source=sf-similar-article

在这里插入图片描述

  • # 这里比较难理解,通过对图像进行零填充并裁剪图像底部,可以确保垂直和水平堆栈之间的因果关系
    v_to_h = v[:, :, 0:-1]
    v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
    # 注意到,v和i相加的位置只差了一个单位。
    # 为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。
    # 这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。
    

在这里插入图片描述

  • 维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。
    • v会经过一个垂直掩码卷积和一个门激活函数。
    • h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。
class GatedBlock(nn.Module):def __init__(self, conv_type, in_channels, p, bn=True):super().__init__()self.conv_type = conv_typeself.p = pself.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,1)self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()self.h_output_conv = nn.Conv2d(p, p, 1)self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()def forward(self, v_input, h_input):# v代表垂直卷积部分的结果v = self.v_conv(v_input)v = self.bn1(v)# Note: 重点代码# 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位# 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)v_to_h = v[:, :, 0:-1]v_to_h = F.pad(v_to_h, (0, 0, 1, 0))# 和h相加前,先经过 1×1 convv_to_h = self.v_to_h_conv(v_to_h)v_to_h = self.bn2(v_to_h)# 分为两份,经过tanh 和 sigmoidv1, v2 = v[:, :self.p], v[:, self.p:]v1 = torch.tanh(v1)v2 = torch.sigmoid(v2)v = v1 * v2# h代表水平卷积部分的结果h = self.h_conv(h_input)h = self.bn3(h)h = h + v_to_h# 分为两份,经过tanh 和 sigmoidh1, h2 = h[:, :self.p], h[:, self.p:]h1 = torch.tanh(h1)h2 = torch.sigmoid(h2)h = h1 * h2h = self.h_output_conv(h)h = self.bn4(h)# 在网络的第一层,每个数据是不能看到自己的。# 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。if self.conv_type == 'B':h = h + h_inputreturn v, h
  • 最后,我们来用GatedBlock搭出Gated PixelCNN
  • Gated PixelCNN和PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。
class GatedPixelCNN(nn.Module):def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):super().__init__()self.block1 = GatedBlock('A', 1, p, bn)self.blocks = nn.ModuleList()for _ in range(n_blocks):self.blocks.append(GatedBlock('B', p, p, bn))self.relu = nn.ReLU()self.linear1 = nn.Conv2d(p, linear_dim, 1)self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)self.out = nn.Conv2d(linear_dim, color_level, 1)def forward(self, x):v, h = self.block1(x, x)for block in self.blocks:v, h = block(v, h)x = self.relu(h)x = self.linear1(x)x = self.relu(x)x = self.linear2(x)x = self.out(x)return x

2.2.2 数据集、训练及采样

  • 数据集、训练及采样和PixelCNN一模一样,不再赘述。
def get_dataloader(batch_size: int):dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',train=True,transform=ToTensor())return DataLoader(dataset, batch_size=batch_size, shuffle=True)def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):"""训练过程"""dataloader = get_dataloader(batch_size)model = model.to(device)optimizer = torch.optim.Adam(model.parameters(), 1e-3)loss_fn = nn.CrossEntropyLoss()tic = time.time()for e in range(n_epochs):total_loss = 0for x, _ in dataloader:current_batch_size = x.shape[0]x = x.to(device)# 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的y = torch.ceil(x * (color_level - 1)).long()y = y.squeeze(1)predict_y = model(x)loss = loss_fn(predict_y, y)optimizer.zero_grad()loss.backward()optimizer.step()total_loss += loss.item() * current_batch_sizetotal_loss /= len(dataloader.dataset)toc = time.time()torch.save(model.state_dict(), model_path)print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')def get_img_shape():return (1, 28, 28)def sample(model, device, model_path, output_path, n_sample=1):"""把x初始化成一个0张量。循环遍历每一个像素,输入x,把预测出的下一个像素填入x"""model.eval()model.load_state_dict(torch.load(model_path))model = model.to(device)C, H, W = get_img_shape()  # (1, 28, 28)x = torch.zeros((n_sample, C, H, W)).to(device)with torch.no_grad():for i in range(H):for j in range(W):# 我们先获取模型的输出,再用softmax转换成概率分布output = model(x)prob_dist = F.softmax(output[:, :, i, j], -1)# 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值# 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)# 最后把新像素填入生成图像x[:, :, i, j] = pixelimgs = x * 255imgs = imgs.clamp(0, 255)imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))imgs = imgs.detach().cpu().numpy().astype(np.uint8)cv2.imwrite(output_path, imgs)if __name__ == '__main__':os.makedirs('work_dirs', exist_ok=True)device = 'cuda' if torch.cuda.is_available() else 'cpu'color_level = 8  # or 256# 1、创建GatedPixelCNN模型model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)# 2、模型训练model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'train(model, device, model_path, batch_size=1)# 3、采样sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/343211.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【排序算法】快速排序

文章目录 1.什么是快速排序2.快速排序的步骤3.时间复杂度 1.什么是快速排序 快速排序算法是一种高效的排序方法,它的基本思想是“分而治之”,通过一趟排序将待排记录分隔成独立的两部分,其中一部分记录的关键字均比另一部分的关键字小&#x…

从零开始手把手Vue3+TypeScript+ElementPlus管理后台项目实战五(引入vue-router,并给注册功能加上美丽的外衣el-form)

安装vue-router pnpm install vue-router创建router src下新增router目录,ruoter目录中新增index.ts import { createRouter, createWebHashHistory } from "vue-router"; const routes [{path: "/",name: "Home",component: () …

个人笔记-python生成gif

使用文件的修改时间戳进行排序 import os import re import imageio# 设置图片所在的文件夹路径 folder_path /home/czy/ACode/AMAW_20240219/9.3.x(Discrete_time_marching)/9.3.17.11.1(Disc_concessive_CH_ZJ)/current_figures # 文件夹路径;linux…

网络编程: 高级IO与多路转接select,poll,epoll的使用与介绍

网络编程: 高级IO与多路转接select,poll,epoll的使用与介绍 前言一.五种IO模型1.IO的本质2.五种IO模型1.五种IO模型2.同步IO与异步IO3.IO效率 二.非阻塞IO1.系统调用介绍2.验证代码 三.select多路转接1.系统调用接口2.写代码 : 基于select的TCP服务器1.封装的Socket接口2.开始写…

前端 CSS 经典:水波进度样式

前言&#xff1a;简单实现水波进度样式&#xff0c;简单好看。 效果图&#xff1a; 代码实现&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"utf-8" /><meta http-equiv"X-UA-Compatible" cont…

《数学学习与研究》杂志是什么级别?知网收录吗?评职认可吗?

《数学学习与研究》杂志是什么级别&#xff1f;知网收录吗&#xff1f;评职认可吗&#xff1f; 《数学学习与研究》是由东北师范大学主管&#xff0c;吉林省数学会与东北师范大学出版社联合主办的省级优秀数学类期刊杂志。该杂志为半月刊&#xff0c;国际标准刊号为 ISSN1007-…

SkyWalking之P0业务场景输出调用链路应用

延伸扩展&#xff1a;XX业务场景 路由标签打标、传播、检索 链路标签染色与传播 SW: SkyWalking的简写 用户请求携带HTTP头信息X-sw8-correlation “X-sw8-correlation: key1value1,key2value2,key3value3” 网关侧读取解析HTTP头信息X-sw8-correlation&#xff0c;然后通过SW…

【Linux】系统优化:一键切换软件源与安装Docker

引言 在Linux系统安装完成后&#xff0c;进行一些必要的初始化设置是提升系统性能和用户体验的关键。本文将重点介绍两个实用的一键脚本&#xff1a;LinuxMirrors提供的软件源切换脚本和Docker安装脚本。这两个脚本将帮助我们简化配置安装过程。 一键切换软件源脚本 在Linux…

AI绘画如何打造高质量数据集?

遇到难题不要怕&#xff01;厚德提问大佬答&#xff01; 厚德提问大佬答11 你是否对AI绘画感兴趣却无从下手&#xff1f;是否有很多疑问却苦于没有大佬解答带你飞&#xff1f;从此刻开始这些问题都将迎刃而解&#xff01;你感兴趣的话题&#xff0c;厚德云替你问&#xff0c;你…

数据动态变化时实现多选及回显

<template><el-dialog title"设置权限" :visible.sync"showDialog" :close-on-click-modal"false" :append-to-body"true" width"800px"><div v-loading"loading"><el-radio-group v-model&…

【论文阅读】MODELING AND SOLVING THE TRAVELING SALESMAN PROBLEM WITH PRIORITY PRIZES

文章目录 论文基本信息摘要1.引言2. INTEGER QUADRATIC PROGRAM FOR TSPPP3. MIXED INTEGER LINEAR PROGRAMS FOR TSPPP4. TABU SEARCH ALGORITHM FOR TSPPP5. COMPUTATIONAL RESULTS6. CONCLUDING REMARKS补充 论文基本信息 《MODELING AND SOLVING THE TRAVELING SALESMAN P…

SQL语句练习每日5题(四)

题目1——查找GPA最高值 想要知道复旦大学学生gpa最高值是多少&#xff0c;请你取出相应数据 题解&#xff1a; 1、使用MAX select MAX(gpa) FROM user_profile WHERE university 复旦大学 2、使用降序排序组合limit select gpa FROM user_profile WHERE university 复…

【vuex小试牛刀】

了解vuex核心概念请移步 https://vuex.vuejs.org/zh/ # 一、初始vuex # 1.1 vuex是什么 就是把需要共享的变量全部存储在一个对象里面&#xff0c;然后将这个对象放在顶层组件中供其他组件使用 父子组件通信时&#xff0c;我们通常会采用 props emit 这种方式。但当通信双方不…

如何搭建一台永久运行的个人服务器?

一、前言 由于本人在这段时候&#xff0c;看到了一个叫做树莓派的东东&#xff0c;初步了解之后觉得很有意思&#xff0c;于是想把整个过程记录下来。 二、树莓派是什么&#xff1f; Raspberry Pi(中文名为树莓派,简写为RPi&#xff0c;(或者RasPi / RPI) 是为学习计算机编程…

38页 | 工商银行大数据平台助力全行数字化转型之路(免费下载)

【1】关注本公众号&#xff0c;转发当前文章到微信朋友圈 【2】私信发送 工商银行大数据平台 【3】获取本方案PDF下载链接&#xff0c;直接下载即可。 如需下载本方案PPT/WORD原格式&#xff0c;请加入微信扫描以下方案驿站知识星球&#xff0c;获取上万份PPT/WORD解决方案&a…

LlamIndex二 RAG应用开发

在AutoGen)系列后&#xff0c;我又开始了LlamIndex 系列。欢迎查询LlamaIndex 一 简单文档查询 - 掘金 (juejin.cn)了解LlamIndex&#xff0c;今天我们来看看LlamIndex的拿手戏&#xff0c;RAG应用开发。 何为RAG&#xff1f; RAG全称"Retrieval-Augmented Generation&q…

Linux C语言:指针和指针变量

一、指针的作用 使程序简洁、紧凑、高效有效地表示复杂的数据结构动态分配内存能直接访问硬件能够方便的处理字符串得到多于一个的函数返回值 二、内存、地址和变量 1、内存地址 2、变量和地址 1&#xff09;变量用来在程序中保存数据 比如: int k 58; //声明一个int变…

jupyter notebook默认工作目录修改

jupyter notebook默认工作目录修改 1、问题2、如何修改jupyter notebook默认工作目录 1、问题 anaconda安装好之后&#xff0c;我们启动jupyter notebook会发现其默认工作目录是在C盘&#xff0c;将工作目录放在C盘会让C盘很快被撑爆&#xff0c;我们应该将jupyter notebook默…

QT系列教程(8) QT 布局学习

简介 Qt 中的布局有三种方式&#xff0c;水平布局&#xff0c;垂直布局&#xff0c;栅格布局。 通过ui设置布局 我们先创建一个窗口应用程序&#xff0c;程序名叫layout&#xff0c;基类选择QMainWindow。但我们不使用这个mainwindow&#xff0c;我们创建一个Qt应用程序类Log…

MYTED | TED100篇打卡总结 辅助学习网站使用说明

文章目录 &#x1f4da;背景&#x1f407;timeline&#x1f407;版本记录&#x1f407;产出小结 &#x1f4da;功能说明&#x1f407;左侧&#x1f407;中间&#x1f407;右侧 &#x1f4da;背景 &#x1f407;timeline 在一个平常的下午&#xff0c;一次平常的桌面整理&#…