一、InfoGAN 解析
InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
https://arxiv.org/abs/1606.03657
1、GAN存在的问题
GAN通过生成器与判别器的对抗学习,最终可以得到一个与real data分布一致的fake data,但是由于生成器的输入z是一个连续的噪声信号,并且没有任何约束,导致GAN无法利用这个z,并且无法将z的具体维度与数据的语义特征对应起来,并不是一个可解释的表示。
2、InfoGAN:通过最大化生成对抗网络的信息进行可解释的表示学习
InfoGAN是对GAN的一种改进,曾被OPENAI称为2016年的五大突破之一。
InfoGAN以此为出发点,试图利用z寻找一个可解释的表达,于是将z进行拆解,可分为如下两部分:
- z:不可压缩的噪声
- c:可解释的隐变量(latent code)
作者希望通过约束隐变量c与生成数据之间的关系,可以使得c里面包括有对数据的可解释信息。如文中提到的MNIST数据集,c可以分为categorial latent code指向数字种类信息(0-9),continous latent code指向倾斜度、粗细。
3、模型的设计
整个框架由3部分组成:生成器G,判别器D和潜变量判别器Q,其中D和Q共享一套参数,除了最后一层全连接层不一样之外。
infoGAN相较于GAN就是加了一个loss,loss(c, c')
为什么通过添加一个c就可以进行解纠缠?
InfoGAN是通过无监督学习来得到一些潜在的特征表示,这些潜在的特征就包括数据的类别。
对于一个生成的一个数据集,虽然里面的数据没有标签信息,但仍存在潜在的类别差异,此时InfoGAN就可以提供一种无监督的方法,来辨别出数据中潜在的类别差异,并且可以通过控制潜在编码 latent code c来生成指定类别的数据。
这就很容易理解了,c是很多随机数值(连续或离散)与z一起输入进G,得到了很多生成图像,这些生成的图像其实是有差别的,例如1就代表是minist 1、5就代表是minist 5。如何用无监督的方式来特定的生成这些种类的图像呢?那我们就在生成过程中随机输入标签c,然后将生成的图像通过Q网络预测这个标签并与真实c求loss,loss很低就可以保证生成的图像含有这个信息c,也就是c与以c为条件的生成的图像互信息很强。那么我们在推理的时候就可以直接通过输入这个伪标签c+z就可以得到某个类别的图像了。
但是需要注意的是,你这里的c假如是1,但是你以1为c得到的真实图像可能是6,所以需要人为的后期观察不同c与生成图像之间的关系。
后面很多的文章都用了infoGAN的思想进行设计解纠缠模型:
- Unsupervised discovery of interpretable directions in the GAN latent space
- Navigating the GAN parameter space for semantic image editing
- Closed-Loop Unsupervised Representation Disentanglement with β-VAE Distillation and Diffusion Probabilistic Feedback
代码实现示例:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST # 使用MNIST作为示例数据集
from torchvision.transforms import ToTensor
import torch.nn.functional as F
# 定义超参数
latent_dim = 64 # 随机噪声维度
code_dim = 10 # 可解释隐变量维度(例如对于MNIST,可解释为数字类别)
batch_size = 64
epochs = 10
lr = 0.0002
lambda_info = 1.0 # 控制互信息最大化的权重
# 加载数据集
train_dataset = MNIST(root='./data', train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 定义生成器G
class Generator(nn.Module):
def __init__(self, latent_dim, code_dim):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(latent_dim + code_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 1024),
nn.ReLU(),
nn.Linear(1024, 784),
nn.Tanh() # 输出范围(-1, 1)
)
def forward(self, z, c):
input_code = torch.cat([z, c], dim=1)
img = self.fc(input_code)
return img.view(-1, 1, 28, 28) # 重塑为图像尺寸
# 定义判别器D
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2)
)
self.fc = nn.Sequential(
nn.Linear(256 * 3 * 3, 1),
nn.Sigmoid() # 输出范围(0, 1)
)
def forward(self, img):
features = self.conv(img)
features = features.view(features.size(0), -1)
validity = self.fc(features)
return validity
# 定义辅助网络Q(c|x)
class QNet(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2)
)
self.fc = nn.Linear(256 * 3 * 3, code_dim)
def forward(self, img):
features = self.conv(img)
features = features.view(features.size(0), -1)
c_pred = self.fc(features)
return c_pred
# 初始化模型
G = Generator(latent_dim, code_dim)
D = Discriminator()
Q = QNet()
# 定义优化器
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
Q_optimizer = optim.Adam(Q.parameters(), lr=lr, betas=(0.5, 0.999))
device = 'cpu'
# 训练循环
for epoch in range(epochs):
for real_images, _ in train_loader:
real_images = real_images.to(device)
## 生成器更新
# random noise
z = torch.randn(batch_size, latent_dim).to(device)
# 对于MNIST,一部分是0~9的离散数据编码,控制G生成的数字。另一部分是2个连续性编码,可以认为是控制字体。这里只写0-9
c = torch.randint(0, 10, (batch_size, code_dim)).to(device)
# 1)给定条件c+z,G预测出fake image
fake_images = G(z, c)
# 2)D鉴别生成的fake image
D_fake_pred = D(fake_images)
# 3)Q预测fake image的c
Q_fake_pred = Q(fake_images.detach()) # detach避免反向传播到G
G_loss = -torch.mean(D_fake_pred) - lambda_info * torch.mean(torch.sum(F.log_softmax(Q_fake_pred, dim=1) * c, dim=1))
G.zero_grad()
G_loss.backward()
G_optimizer.step()
# 判别器和辅助网络Q更新
D_real_pred = D(real_images)
D_real_loss = -torch.mean(D_real_pred)
z = torch.randn(batch_size, latent_dim).to(device)
c = torch.randint(0, 10, (batch_size, code_dim)).to(device)
fake_images = G(z, c)
D_fake_pred = D(fake_images.detach())
Q_fake_pred = Q(fake_images)
D_fake_loss = torch.mean(D_fake_pred)
Q_loss = -torch.mean(torch.sum(F.log_softmax(Q_fake_pred, dim=1) * c, dim=1))
D_loss = D_real_loss + D_fake_loss
Q_loss = Q_loss
D.zero_grad()
Q.zero_grad()
D_loss.backward()
Q_loss.backward()
D_optimizer.step()
Q_optimizer.step()
print(f"Epoch {epoch + 1}: G loss={G_loss.item():.4f}, D loss={D_loss.item():.4f}, Q loss={Q_loss.item():.4f}")
二、InfoGAN-CR 解析
https://arxiv.org/abs/1906.06034#
理解了infoGAN,infoGAN-CR就很好理解了:
infoGAN也是通过无监督的方式来进行的解纠缠,除了求infoGAN的这些loss,还需要再额外求一个通过固定c的某一个维度,变化剩余所有维度ci生成的一对图像(x',x'')的loss【注意变化的是c而不是z】
本站资源均来自互联网,仅供研究学习,禁止违法使用和商用,产生法律纠纷本站概不负责!如果侵犯了您的权益请与我们联系!
转载请注明出处: 免费源码网-免费的源码资源网站 » 【NeurIPS 2016】InfoGAN + InfoGAN-CR 解析
发表评论 取消回复