图像风格迁移(Neural Style Transfer, NST)是深度学习中一个令人着迷的应用,它能够将一张图像的风格应用到另一张图像上。例如,能够将梵高的画风应用到一张普通照片上。本文将详细解释如何使用PyTorch进行风格迁移,逐步分析代码,并讲解其中的关键技术。

1. 环境准备

在开始之前,确保安装了必要的库:

pip install torch torchvision pillow

2. 模型缓存目录设置

为了加速模型的加载,我们可以通过设置环境变量TORCH_HOME来指定模型缓存目录,避免每次运行代码时重新下载模型:

os.environ['TORCH_HOME'] = './model_directory'  # 你可以根据需要自定义目录

3. 加载图像

加载图像并进行预处理是风格迁移中的重要步骤。我们需要将图像转换为张量并进行归一化处理,以便与预训练的VGG19模型匹配:

def load_image(image_path, max_size=400):
    image = Image.open(image_path).convert('RGB')
    size = min(max_size, max(image.size))

    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    return transform(image).unsqueeze(0)

在这里,我们将图像调整为不大于400像素的正方形,并将其转换为适合VGG19模型输入的格式。

4. VGG19模型的特征提取

风格迁移的核心思想是将内容图像的高层次特征与风格图像的低层次特征结合。我们使用VGG19模型的前21层来提取图像的特征:

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.features = vgg19(pretrained=True).features[:21].eval()

    def forward(self, x):
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            if i in {0, 5, 10, 19, 21}:
                features.append(x)
        return features

5. 内容与风格损失

内容损失衡量生成图像与内容图像的特征差异,而风格损失则是基于Gram矩阵来衡量生成图像与风格图像的差异。

  • 内容损失:
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()

    def forward(self, input):
        return nn.functional.mse_loss(input, self.target)
  • 风格损失:
class StyleLoss(nn.Module):
    def __init__(self, target):
        super(StyleLoss, self).__init__()
        self.target = self.gram_matrix(target).detach()

    def gram_matrix(self, input):
        batch_size, channels, height, width = input.size()
        features = input.view(batch_size * channels, height * width)
        G = torch.mm(features, features.t())
        return G.div(batch_size * channels * height * width)

    def forward(self, input):
        G = self.gram_matrix(input)
        return nn.functional.mse_loss(G, self.target)

6. 图像风格迁移算法

核心算法将内容图像初始化为输入图像,并通过多次迭代优化,使其逐步接近目标风格图像,同时保持内容的完整性。我们使用LBFGS优化器来实现这一过程:

def style_transfer(content_img, style_img, num_steps=1000, style_weight=1e9, content_weight=1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    content_img = content_img.to(device)
    style_img = style_img.to(device)

    model = VGG().to(device)

    style_features = model(style_img)
    content_features = model(content_img)

    input_img = content_img.clone().requires_grad_(True).to(device)

    optimizer = optim.LBFGS([input_img])

    style_losses = []
    content_losses = []

    for sf, cf in zip(style_features, content_features):
        content_losses.append(ContentLoss(cf))
        style_losses.append(StyleLoss(sf))

    run = [0]
    while run[0] <= num_steps:

        def closure():
            optimizer.zero_grad()

            input_features = model(input_img)
            content_loss = 0
            style_loss = 0

            for cl, input_f in zip(content_losses, input_features):
                content_loss += content_weight * cl(input_f)

            for sl, input_f in zip(style_losses, input_features):
                style_loss += style_weight * sl(input_f)

            loss = content_loss + style_loss
            loss.backward()

            run[0] += 1
            if run[0] % 50 == 0:
                print(f'Step {run[0]}, Content Loss: {content_loss.item():4f}, Style Loss: {style_loss.item():4f}')

            return loss

        optimizer.step(closure)

    return input_img

7. 结果保存

生成的图像需要去除归一化并保存为常规图片格式:

def save_image(tensor, path):
    image = tensor.clone().detach()
    image = image.squeeze(0)
    image = transforms.ToPILImage()(image)
    image.save(path)

8. 主函数执行

整个过程可以通过主函数来执行,加载图像、进行风格迁移并保存结果:

if __name__ == '__main__':
    content_image_path = 'content_image.png'
    style_image_path = 'style_image.png'
    output_image_path = 'output_image.jpg'

    content_img = load_image(content_image_path)
    style_img = load_image(style_image_path)

    result = style_transfer(content_img, style_img)

    save_image(result, output_image_path)
    print(f"风格迁移完成,图像已保存为 {output_image_path}")

总结

本文展示了如何使用PyTorch和VGG19模型实现图像风格迁移。通过合理设置内容和风格损失的权重,我们可以生成既保留内容图像结构又具有风格图像艺术风格的全新图像。

完整代码

github:https://github.com/Yolumia/Image_style_transfer_base_vgg19/

点赞(0) 打赏

评论列表 共有 0 条评论

暂无评论

微信公众账号

微信扫一扫加关注

发表
评论
返回
顶部