一、InfoGAN 解析 

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

https://arxiv.org/abs/1606.03657


1、GAN存在的问题

GAN通过生成器与判别器的对抗学习,最终可以得到一个与real data分布一致的fake data,但是由于生成器的输入z是一个连续的噪声信号,并且没有任何约束,导致GAN无法利用这个z,并且无法将z的具体维度与数据的语义特征对应起来,并不是一个可解释的表示。

vanilla GAN

2、InfoGAN:通过最大化生成对抗网络的信息进行可解释的表示学习

InfoGAN是对GAN的一种改进,曾被OPENAI称为2016年的五大突破之一

InfoGAN以此为出发点,试图利用z寻找一个可解释的表达,于是将z进行拆解,可分为如下两部分:

  1. z:不可压缩的噪声
  2. c:可解释的隐变量(latent code)

作者希望通过约束隐变量c与生成数据之间的关系,可以使得c里面包括有对数据的可解释信息。如文中提到的MNIST数据集,c可以分为categorial latent code指向数字种类信息(0-9),continous latent code指向倾斜度、粗细。

3、模型的设计

整个框架由3部分组成:生成器G,判别器D和潜变量判别器Q,其中D和Q共享一套参数,除了最后一层全连接层不一样之外。

infoGAN相较于GAN就是加了一个loss,loss(c, c')

为什么通过添加一个c就可以进行解纠缠?

InfoGAN是通过无监督学习来得到一些潜在的特征表示,这些潜在的特征就包括数据的类别。

对于一个生成的一个数据集,虽然里面的数据没有标签信息,但仍存在潜在的类别差异,此时InfoGAN就可以提供一种无监督的方法,来辨别出数据中潜在的类别差异,并且可以通过控制潜在编码 latent code c来生成指定类别的数据。

这就很容易理解了,c是很多随机数值(连续或离散)与z一起输入进G,得到了很多生成图像,这些生成的图像其实是有差别的,例如1就代表是minist 1、5就代表是minist 5。如何用无监督的方式来特定的生成这些种类的图像呢?那我们就在生成过程中随机输入标签c,然后将生成的图像通过Q网络预测这个标签并与真实c求loss,loss很低就可以保证生成的图像含有这个信息c,也就是c与以c为条件的生成的图像互信息很强。那么我们在推理的时候就可以直接通过输入这个伪标签c+z就可以得到某个类别的图像了。

但是需要注意的是,你这里的c假如是1,但是你以1为c得到的真实图像可能是6,所以需要人为的后期观察不同c与生成图像之间的关系。

后面很多的文章都用了infoGAN的思想进行设计解纠缠模型:

  1. Unsupervised discovery of interpretable directions in the GAN latent space
  2. Navigating the GAN parameter space for semantic image editing
  3. Closed-Loop Unsupervised Representation Disentanglement with β-VAE Distillation and Diffusion Probabilistic Feedback


代码实现示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST  # 使用MNIST作为示例数据集
from torchvision.transforms import ToTensor
import torch.nn.functional as F

# 定义超参数
latent_dim = 64  # 随机噪声维度
code_dim = 10  # 可解释隐变量维度(例如对于MNIST,可解释为数字类别)
batch_size = 64
epochs = 10
lr = 0.0002
lambda_info = 1.0  # 控制互信息最大化的权重

# 加载数据集
train_dataset = MNIST(root='./data', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


# 定义生成器G
class Generator(nn.Module):
    def __init__(self, latent_dim, code_dim):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_dim + code_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()  # 输出范围(-1, 1)
        )

    def forward(self, z, c):
        input_code = torch.cat([z, c], dim=1)
        img = self.fc(input_code)
        return img.view(-1, 1, 28, 28)  # 重塑为图像尺寸


# 定义判别器D
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Sequential(
            nn.Linear(256 * 3 * 3, 1),
            nn.Sigmoid()  # 输出范围(0, 1)
        )

    def forward(self, img):
        features = self.conv(img)
        features = features.view(features.size(0), -1)
        validity = self.fc(features)
        return validity


# 定义辅助网络Q(c|x)
class QNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Linear(256 * 3 * 3, code_dim)

    def forward(self, img):
        features = self.conv(img)
        features = features.view(features.size(0), -1)
        c_pred = self.fc(features)
        return c_pred


# 初始化模型
G = Generator(latent_dim, code_dim)
D = Discriminator()
Q = QNet()

# 定义优化器
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
Q_optimizer = optim.Adam(Q.parameters(), lr=lr, betas=(0.5, 0.999))


device = 'cpu'
# 训练循环
for epoch in range(epochs):
    for real_images, _ in train_loader:
        real_images = real_images.to(device)

        ## 生成器更新
        # random noise
        z = torch.randn(batch_size, latent_dim).to(device)
        # 对于MNIST,一部分是0~9的离散数据编码,控制G生成的数字。另一部分是2个连续性编码,可以认为是控制字体。这里只写0-9
        c = torch.randint(0, 10, (batch_size, code_dim)).to(device)
        # 1)给定条件c+z,G预测出fake image
        fake_images = G(z, c)
        # 2)D鉴别生成的fake image
        D_fake_pred = D(fake_images)
        # 3)Q预测fake image的c
        Q_fake_pred = Q(fake_images.detach())  # detach避免反向传播到G
        G_loss = -torch.mean(D_fake_pred) - lambda_info * torch.mean(torch.sum(F.log_softmax(Q_fake_pred, dim=1) * c, dim=1))

        G.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # 判别器和辅助网络Q更新
        D_real_pred = D(real_images)
        D_real_loss = -torch.mean(D_real_pred)

        z = torch.randn(batch_size, latent_dim).to(device)
        c = torch.randint(0, 10, (batch_size, code_dim)).to(device)
        fake_images = G(z, c)
        D_fake_pred = D(fake_images.detach())
        Q_fake_pred = Q(fake_images)
        D_fake_loss = torch.mean(D_fake_pred)
        Q_loss = -torch.mean(torch.sum(F.log_softmax(Q_fake_pred, dim=1) * c, dim=1))

        D_loss = D_real_loss + D_fake_loss
        Q_loss = Q_loss

        D.zero_grad()
        Q.zero_grad()
        D_loss.backward()
        Q_loss.backward()
        D_optimizer.step()
        Q_optimizer.step()

    print(f"Epoch {epoch + 1}: G loss={G_loss.item():.4f}, D loss={D_loss.item():.4f}, Q loss={Q_loss.item():.4f}")

二、InfoGAN-CR 解析

https://arxiv.org/abs/1906.06034#

理解了infoGAN,infoGAN-CR就很好理解了:

infoGAN也是通过无监督的方式来进行的解纠缠,除了求infoGAN的这些loss,还需要再额外求一个通过固定c的某一个维度,变化剩余所有维度ci生成的一对图像(x',x'')的loss【注意变化的是c而不是z】

InfoGAN - 简书

深度探索:机器学习中的信息最大化GAN(InfoGAN)原理及其应用-CSDN博客

infoGAN解析 - 简书

点赞(0) 打赏

评论列表 共有 0 条评论

暂无评论

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部