经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

1 PixelCNN

  • PixelCNN是DeepMind团队在论文Pixel Recurrent Neural Networks (16.01)提出的一种生成模型,实际上这篇论文共提出了两种架构:PixelRNNPixelCNN,两者的主要区别是前者用LSTM来建模,而PixelCNN是基于CNN的,相比RNN,CNN计算更高效,我们这里只讨论PixelCNN。

  • PixelCNN借用了NLP里的方法来生成图像。对于自然图像,每个像素值的取值范围为0~255,共256个离散值。PixelCNN模型会根据前i - 1个像素输出第i个像素的概率分布。

  • 训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数

  • 采样时(图像生成时),会根据前i - 1个像素直接从预测的概率分布(多项分布)里采样出第i个像素。

1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大体思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只考虑单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

在这里插入图片描述

但是只输出一个像素的概率分布,这样训练效率太低了。

  • 在训练时,我们可以输入一幅图像,同时让模型输出图像每一点像素的概率分布(如下图所示),这样就能通过每个像素的真值和模型预测的概率分布求交叉熵损失函数,进行并行训练。
  • 我们能这么做的原因是:在训练时,整幅训练图像是已知的,因此我们可以在一次前向传播后得到图像每一处的概率分布。
  • 当然,我们需要找到每个像素都忽略后续像素的信息的方法,即论文中提出的掩码卷积机制,我们后面再讲。

在这里插入图片描述

但是在生成图像(采样)时,还是要一个像素一个像素的生成(如下所示)

  • 在采样时,我们会先根据前i - 1个像素输出第i个像素的概率分布。
  • 然后,我们会从第i个像素的概率分布中进行采样(如下面代码所示)
# 假设颜色取值范围为[0, 7],下面为概率分布
prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])

# 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
# torch.multinomial会从input这个概率分布中,取num_samples个值
pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])

在这里插入图片描述

我们现在已经知道了训练及采样的大体过程。但是,我们现在还是有一个疑问,如何保证训练时候,每个像素都忽略后续像素的信息?

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。

  • 具体来说,PixelCNN使用了两类掩码卷积:
    • 我们把两类掩码卷积分别称为「A类」和「B类」。
    • 二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。
    • A类和B类的唯一区别在于:卷积核的中心像素是否产生贡献
    • CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积

在这里插入图片描述

我们来分析下这样设计的优点:

  • 对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)
    • 经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。
    • 再经过两次B类掩码卷积后,中心像素能够看到左上角大部分像素的信息(如下图所示,我们发现还是会看漏少部分的信息,后面的Gated PixelCNN对此进行了改进)。
    • 这满足PixelCNN的约束。

在这里插入图片描述

  • 如果一直使用A类掩码卷积,每次卷积后中心像素都会看漏一些信息,最终就会导致看漏很多信息

在这里插入图片描述

  • 如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。

总结如下:

  • 逐像素预测只依赖于前面的像素,因此在选择卷积核时要进行掩码操作避免看到未来的值,因此,在第一层预测时可采用掩码卷积A
  • 由于CNN的逐像素预测是多层卷积,所以当第一层结束后,图像缺失部分已经有了预测值,因此在进行下一次/层卷积操作时可以利用当前像素的预测值,因此采用下列掩码卷积B
  • 需要注意的是,这里只考虑了单通道,如果扩展到RGB三个通道时,该如何进行mask呢?

1.1.2 PixelCNN的网络架构

  • 利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。
  • 我们可以用任意一种CNN架构来实现PixelCNN。
  • 下图红色框所示部分是PixelCNN的网络结构,其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

在这里插入图片描述

1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模型

实现PixelCNN,最重要的是实现掩码卷积。

  • 掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。
  • 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import os


class MaskConv2d(nn.Module):
    """
        掩码卷积的实现思路:
            在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积
    """
    def __init__(self, conv_type, *args, **kwags):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        # 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2] = 1
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        # 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作
        mask = mask.reshape((1, 1, H, W))
        # register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中
        # 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。
        # 第三个参数表示被注册的变量是否要加入state_dict中以保存下来
        self.register_buffer(name='mask', tensor=mask, persistent=False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来

在这里插入图片描述

  • 我们先实现残差块上图右部分的ResidualBlock,这里添加归一化
class ResidualBlock(nn.Module):
    """
        残差块ResidualBlock
    """
    def __init__(self, h, bn=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(2 * h, h, 1)
        self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
        self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
        self.conv3 = nn.Conv2d(h, 2 * h, 1)
        self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

    def forward(self, x):
        # 1、ReLU + 1×1 Conv + bn
        y = self.relu(x)
        y = self.conv1(y)
        y = self.bn1(y)
        # 2、ReLU + 3×3 Conv(mask B) + bn
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        # 3、ReLU + 1×1 Conv + bn
        y = self.relu(y)
        y = self.conv3(y)
        y = self.bn3(y)
        # 4、残差连接
        y = y + x
        return y
  • 有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。
  • 注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数color_level。
class PixelCNN(nn.Module):
    def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
        self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
        self.residual_blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.residual_blocks.append(ResidualBlock(h, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        # 1、7 × 7 conv(mask A)
        x = self.conv1(x)
        x = self.bn1(x)
        # 2、Multiple residual blocks
        for block in self.residual_blocks:
            x = block(x)
        x = self.relu(x)
        # 3、1 × 1 conv
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x
1.1.3.2 数据集及训练

准备好了模型代码,我们可以编写训练脚本了:

  • PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32
def get_dataloader(batch_size: int):
    dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
                                         train=True,
                                         transform=ToTensor()
                                         )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
    """训练过程"""
    dataloader = get_dataloader(batch_size)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    loss_fn = nn.CrossEntropyLoss()

    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            # 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签
            y = torch.ceil(x * (color_level - 1)).long()
            y = y.squeeze(1)
            predict_y = model(x)
            loss = loss_fn(predict_y, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), model_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')

if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 需要注意的是:MNIST数据集的大部分像素都是0和255
    color_level = 8  # or 256
    # 1、创建PixelCNN模型
    model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)
    # 2、模型训练
    model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'
    train(model, device, model_path)
    # 3、采样
    sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
1.1.3.3 采样
  • 在采样时,我们把x初始化成一个0张量。
  • 之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.
def sample(model, device, model_path, output_path, n_sample=1):
    """
        把x初始化成一个0张量。
        循环遍历每一个像素,输入x,把预测出的下一个像素填入x
    """
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    C, H, W = get_img_shape()  # (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                # 我们先获取模型的输出,再用softmax转换成概率分布
                output = model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值
                # 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]
                pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
                # 最后把新像素填入到生成图像中
                x[:, :, i, j] = pixel
    # 乘255变成一个用8位字节表示的图像
    imgs = x * 255
    imgs = imgs.clamp(0, 255)
    imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(output_path, imgs)

1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响

  • 其中红色预测不受蓝色和绿色通道的影响,只受上下文影响
  • 绿色红色通道和上下文影响,但不受蓝色通道影响;
  • 蓝色通道受上下文、红色通道、绿色通道影响

在这里插入图片描述

更具体地,我们规定一个子像素只由它之前的子像素决定,生成图像时,我们一个子像素一个子像素地生成

  • 如下图所示,对于RGB图像,R子像素由它之前所有像素决定
  • G子像素由它的R子像素和之前所有像素决定,
  • B子像素由它的R、G子像素和它之前所有像素决定。

在这里插入图片描述

如下图所示,由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量

  • 即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。
  • 同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

在这里插入图片描述

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  • 对于通道间的约束,我们要在o, i两个维度上设置掩码,如下图左边所示。
    • 设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3
      • o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:o
      • i1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i
      • 序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。
    • 我们对输入通道组和输出通道组之间进行约束。
    • 对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3
    • 对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3
  • 如下图右边所示,对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

在这里插入图片描述

  • 下面给出三维掩码示意图方便理解:

在这里插入图片描述

2 Gated PixelCNN

2.1 Gated PixelCNN简述

  • 可以参考大神讲解:Gated PixelCNN (sergeiturukin.com)

  • PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,中心像素看不到右上角三个本应该能看到的像素。

在这里插入图片描述

在这里插入图片描述

  • 如下图所示,Gated PixelCNN使用了两种卷积,即垂直卷积和水平卷积,来分别维护一个像素上侧的信息和左侧的信息
    • 垂直卷积的结果只是一些临时量
    • 而水平卷积的结果最终会被网络输出
    • 使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

在这里插入图片描述

  • Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
  • 模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。
  • 垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。
  • 两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。
  • 此外,模块右侧还有一个残差连接。

在这里插入图片描述

2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模型

  • 首先,实现垂直卷积和水平卷积
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import os


class VerticalMaskConv2d(nn.Module):
    """
        垂直卷积
    """
    def __init__(self, *args, **kwags):
        super().__init__()
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2 + 1] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class HorizontalMaskConv2d(nn.Module):
    """
        水平卷积
    """
    def __init__(self, conv_type, *args, **kwags):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res
# 垂直卷积
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 0., 0.]]]])
# A类水平卷积
tensor([[[[0., 0., 0.],
          [1., 0., 0.],
          [0., 0., 0.]]]])
# B类水平卷积
tensor([[[[0., 0., 0.],
          [1., 1., 0.],
          [0., 0., 0.]]]])
  • 我们现在搭建Gated Block模块,这也是最难理解的一部分。
  • 可以参考的解释:https://segmentfault.com/a/1190000041189859?utm_source=sf-similar-article

在这里插入图片描述

  • # 这里比较难理解,通过对图像进行零填充并裁剪图像底部,可以确保垂直和水平堆栈之间的因果关系
    v_to_h = v[:, :, 0:-1]
    v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
    # 注意到,v和i相加的位置只差了一个单位。
    # 为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。
    # 这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。
    

在这里插入图片描述

  • 维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。
    • v会经过一个垂直掩码卷积和一个门激活函数。
    • h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。
class GatedBlock(nn.Module):

    def __init__(self, conv_type, in_channels, p, bn=True):
        super().__init__()
        self.conv_type = conv_type
        self.p = p
        self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
                                           1)
        self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_output_conv = nn.Conv2d(p, p, 1)
        self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

    def forward(self, v_input, h_input):
        # v代表垂直卷积部分的结果
        v = self.v_conv(v_input)
        v = self.bn1(v)
        # Note: 重点代码
        # 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位
        # 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)
        v_to_h = v[:, :, 0:-1]
        v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
        # 和h相加前,先经过 1×1 conv
        v_to_h = self.v_to_h_conv(v_to_h)
        v_to_h = self.bn2(v_to_h)
        # 分为两份,经过tanh 和 sigmoid
        v1, v2 = v[:, :self.p], v[:, self.p:]
        v1 = torch.tanh(v1)
        v2 = torch.sigmoid(v2)
        v = v1 * v2

        # h代表水平卷积部分的结果
        h = self.h_conv(h_input)
        h = self.bn3(h)
        h = h + v_to_h
        # 分为两份,经过tanh 和 sigmoid
        h1, h2 = h[:, :self.p], h[:, self.p:]
        h1 = torch.tanh(h1)
        h2 = torch.sigmoid(h2)
        h = h1 * h2
        h = self.h_output_conv(h)
        h = self.bn4(h)
        # 在网络的第一层,每个数据是不能看到自己的。
        # 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。
        if self.conv_type == 'B':
            h = h + h_input
        return v, h
  • 最后,我们来用GatedBlock搭出Gated PixelCNN
  • Gated PixelCNN和PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。
class GatedPixelCNN(nn.Module):

    def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.block1 = GatedBlock('A', 1, p, bn)
        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(GatedBlock('B', p, p, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(p, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        v, h = self.block1(x, x)
        for block in self.blocks:
            v, h = block(v, h)
        x = self.relu(h)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x

2.2.2 数据集、训练及采样

  • 数据集、训练及采样和PixelCNN一模一样,不再赘述。
def get_dataloader(batch_size: int):
    dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
                                         train=True,
                                         transform=ToTensor()
                                         )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
    """训练过程"""
    dataloader = get_dataloader(batch_size)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    loss_fn = nn.CrossEntropyLoss()

    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            # 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的
            y = torch.ceil(x * (color_level - 1)).long()
            y = y.squeeze(1)
            predict_y = model(x)
            loss = loss_fn(predict_y, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), model_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')


def get_img_shape():
    return (1, 28, 28)


def sample(model, device, model_path, output_path, n_sample=1):
    """
        把x初始化成一个0张量。
        循环遍历每一个像素,输入x,把预测出的下一个像素填入x
    """
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    C, H, W = get_img_shape()  # (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                # 我们先获取模型的输出,再用softmax转换成概率分布
                output = model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值
                # 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)
                pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
                # 最后把新像素填入生成图像
                x[:, :, i, j] = pixel

    imgs = x * 255
    imgs = imgs.clamp(0, 255)
    imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(output_path, imgs)


if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    color_level = 8  # or 256
    # 1、创建GatedPixelCNN模型
    model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)
    # 2、模型训练
    model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'
    train(model, device, model_path, batch_size=1)
    # 3、采样
    sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')

点赞(0) 打赏

评论列表 共有 0 条评论

暂无评论

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部