使用PyTorch進行圖像風格轉換
譯者:bdqfork
作者: Alexis Jacq
簡介
本教程主要講解如何實現由Leon A. Gatys,Alexander S. Ecker和Matthias Bethge提出的 Neural-Style 演算法。Neural-Style或者叫Neural-Transfer,可以讓你使用一種新的風格將指定的圖片進行重構。<!--more-->這個演算法使用三張圖片,一張輸入圖片,一張內容圖片和一張風格圖片,並將輸入的圖片變得與內容圖片相似,且擁有風格圖片的優美風格。
基本原理
原理很簡單:我們定義兩個間距,一個用於內容D_C
,另一個用於風格D_S
。D_C
測量兩張圖片內容的不同,而D_S
用來測量兩張圖片風格的不同。然後,我們輸入第三張圖片,並改變這張圖片,使其與內容圖片的內容間距和風格圖片的風格間距最小化。現在,我們可以導入必要的包,開始圖像風格轉換。
導包並選擇設備
下面是一張實現圖像風格轉換所需包的清單。
torch
,torch.nn
,numpy
(使用PyTorch進行風格轉換必不可少的包)torch.optim
(高效的梯度下降)PIL
,PIL.Image
,matplotlib.pyplot
(載入和展示圖片)torchvision.transforms
(將PIL圖片轉換成張量)torchvision.models
(訓練或載入預訓練模型)copy
(對模型進行深度拷貝;系統包)
from __future__ import print_function
?
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
?
from PIL import Image
import matplotlib.pyplot as plt
?
import torchvision.transforms as transforms
import torchvision.models as models
?
import copy
?
下一步,我們選擇用哪一個設備來運行神經網路,導入內容和風格圖片。在大量圖片上運行圖像風格演算法需要很長時間,在GPU上運行可以加速。我們可以使用torch.cuda.is_available()
來判斷是否有可用的GPU。下一步,我們在整個教程中使用 torch.device
。 .to(device)
方法也被用來將張量或者模型移動到指定設備。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
載入圖片
現在我們將導入風格和內容圖片。原始的PIL圖片的值介於0到255之間,但是當轉換成torch張量時,它們的值被轉換成0到1之間。圖片也需要被重設成相同的維度。一個重要的細節是,注意torch庫中的神經網路用來訓練的張量的值為0到1之間。如果你嘗試將0到255的張量圖片載入到神經網路,然後激活的特徵映射將不能偵測到目標內容和風格。然而,Caffe庫中的預訓練網路用來訓練的張量值為0到255之間的圖片。
注意
這是一個下載本教程需要用到的圖片的鏈接: picasso.jpg 和 dancing.jpg。下載這兩張圖片並且將它們添加到你當前工作目錄中的 images
文件夾。
# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
?
loader = transforms.Compose([
transforms.Resize(imsize), # scale imported image
transforms.ToTensor()]) # transform it into a torch tensor
?
def image_loader(image_name):
image = Image.open(image_name)
# fake batch dimension required to fit networks input dimensions
image = loader(image).unsqueeze(0)
return image.to(device, torch.float)
?
style_img = image_loader("./data/images/neural-style/picasso.jpg")
content_img = image_loader("./data/images/neural-style/dancing.jpg")
?
assert style_img.size() == content_img.size(),
"we need to import style and content images of the same size"
?
現在,讓我們創建一個方法,通過重新將圖片轉換成PIL格式來展示,並使用plt.imshow
展示它的拷貝。我們將嘗試展示內容和風格圖片來確保它們被正確的導入。
unloader = transforms.ToPILImage() # reconvert into PIL image
?
plt.ion()
?
def imshow(tensor, title=None):
image = tensor.cpu().clone() # we clone the tensor to not do changes on it
image = image.squeeze(0) # remove the fake batch dimension
image = unloader(image)
plt.imshow(image)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
?
plt.figure()
imshow(style_img, title=Style Image)
?
plt.figure()
imshow(content_img, title=Content Image)
?
損失函數
內容損失
內容損失是一個表示一層內容間距的加權版本。這個方法使用網路中的L層的特徵映射F_XL
,該網路處理輸入X並返回在圖片X和內容圖片C之間的加權內容間距W_CL*D_C^L(X,C)
。該方法必須知道內容圖片(F_CL
)的特徵映射來計算內容間距。我們使用一個以F_CL
作為構造參數輸入的torch模型來實現這個方法。間距||F_XL-F_CL||^2
是兩個特徵映射集合之間的平均方差,可以使用nn.MSELoss
來計算。
我們將直接添加這個內容損失模型到被用來計算內容間距的卷積層之後。這樣每一次輸入圖片到網路中時,內容損失都會在目標層被計算。而且因為自動求導的緣故,所有的梯度都會被計算。現在,為了使內容損失層透明化,我們必須定義一個forward
方法來計算內容損失,同時返回該層的輸入。計算的損失作為模型的參數被保存。
class ContentLoss(nn.Module):
?
def __init__(self, target,):
super(ContentLoss, self).__init__()
# we detach the target content from the tree used
# to dynamically compute the gradient: this is a stated value,
# not a variable. Otherwise the forward method of the criterion
# will throw an error.
self.target = target.detach()
?
def forward(self, input):
self.loss = F.mse_loss(input, self.target)
return input
?
注意
重要細節:儘管這個模型的名稱被命名為 ContentLoss
, 它不是一個真實的PyTorch損失方法。如果你想要定義你的內容損失為PyTorch Loss方法,你必須創建一個PyTorch自動求導方法來手動的在backward
方法中重計算/實現梯度.
風格損失
風格損失模型與內容損失模型的實現方法類似。它要作為一個網路中的透明層,來計算相應層的風格損失。為了計算風格損失,我們需要計算Gram矩陣G_XL
。Gram矩陣是將給定矩陣和它的轉置矩陣的乘積。在這個應用中,給定的矩陣是L層特徵映射F_XL
的重塑版本。F_XL
被重塑成F?_XL
,一個KxN的矩陣,其中K是L層特徵映射的數量,N是任何向量化特徵映射F_XL^K
的長度。例如,第一行的F?_XL
與第一個向量化的F_XL^1
。
最後,Gram矩陣必須通過將每一個元素除以矩陣中所有元素的數量進行標準化。標準化是為了消除擁有很大的N維度F?_XL
在Gram矩陣中產生的很大的值。這些很大的值將在梯度下降的時候,對第一層(在池化層之前)產生很大的影響。風格特徵往往在網路中更深的層,所以標準化步驟是很重要的。
def gram_matrix(input):
a, b, c, d = input.size() # a=batch size(=1)
# b=number of feature maps
# (c,d)=dimensions of a f. map (N=c*d)
?
features = input.view(a * b, c * d) # resise F_XL into hat F_XL
?
G = torch.mm(features, features.t()) # compute the gram product
?
# we normalize the values of the gram matrix
# by dividing by the number of element in each feature maps.
return G.div(a * b * c * d)
?
現在風格損失模型看起來和內容損失模型很像。風格間距也用G_XL
和G_SL
之間的均方差來計算。
class StyleLoss(nn.Module):
?
def __init__(self, target_feature):
super(StyleLoss, self).__init__()
self.target = gram_matrix(target_feature).detach()
?
def forward(self, input):
G = gram_matrix(input)
self.loss = F.mse_loss(G, self.target)
return input
?
導入模型
現在我們需要導入預訓練的神經網路。我們將使用19層的VGG網路,就像論文中使用的一樣。
PyTorch的VGG模型實現被分為了兩個字Sequential
模型:features
(包含卷積層和池化層)和classifier
(包含全連接層)。我們將使用features
模型,因為我們需要每一層卷積層的輸出來計算內容和風格損失。在訓練的時候有些層會有和評估不一樣的行為,所以我們必須用.eval()
將網路設置成評估模式。
cnn = models.vgg19(pretrained=True).features.to(device).eval()
?
此外,VGG網路通過使用mean=[0.485, 0.456, 0.406]和std=[0.229, 0.224, 0.225]參數來標準化圖片的每一個通道,並在圖片上進行訓練。因此,我們將在把圖片輸入神經網路之前,先使用這些參數對圖片進行標準化。
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
?
# create a module to normalize input image so we can easily put it in a
# nn.Sequential
class Normalization(nn.Module):
def __init__(self, mean, std):
super(Normalization, self).__init__()
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
self.mean = torch.tensor(mean).view(-1, 1, 1)
self.std = torch.tensor(std).view(-1, 1, 1)
?
def forward(self, img):
# normalize img
return (img - self.mean) / self.std
?
一個Sequential
模型包含一個順序排列的子模型序列。例如,vff19.features
包含一個以正確的深度順序排列的序列(Conv2d, ReLU, MaxPool2d, Conv2d, ReLU…)。我們需要將我們自己的內容損失和風格損失層在感知到卷積層之後立即添加進去。因此,我們必須創建一個新的Sequential
模型,並正確的插入內容損失和風格損失模型。
# desired depth layers to compute style/content losses :
content_layers_default = [conv_4]
style_layers_default = [conv_1, conv_2, conv_3, conv_4, conv_5]
?
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
style_img, content_img,
content_layers=content_layers_default,
style_layers=style_layers_default):
cnn = copy.deepcopy(cnn)
?
# normalization module
normalization = Normalization(normalization_mean, normalization_std).to(device)
?
# just in order to have an iterable access to or list of content/syle
# losses
content_losses = []
style_losses = []
?
# assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
# to put in modules that are supposed to be activated sequentially
model = nn.Sequential(normalization)
?
i = 0 # increment every time we see a conv
for layer in cnn.children():
if isinstance(layer, nn.Conv2d):
i += 1
name = conv_{}.format(i)
elif isinstance(layer, nn.ReLU):
name = relu_{}.format(i)
# The in-place version doesnt play very nicely with the ContentLoss
# and StyleLoss we insert below. So we replace with out-of-place
# ones here.
layer = nn.ReLU(inplace=False)
elif isinstance(layer, nn.MaxPool2d):
name = pool_{}.format(i)
elif isinstance(layer, nn.BatchNorm2d):
name = bn_{}.format(i)
else:
raise RuntimeError(Unrecognized layer: {}.format(layer.__class__.__name__))
?
model.add_module(name, layer)
?
if name in content_layers:
# add content loss:
target = model(content_img).detach()
content_loss = ContentLoss(target)
model.add_module("content_loss_{}".format(i), content_loss)
content_losses.append(content_loss)
?
if name in style_layers:
# add style loss:
target_feature = model(style_img).detach()
style_loss = StyleLoss(target_feature)
model.add_module("style_loss_{}".format(i), style_loss)
style_losses.append(style_loss)
?
# now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break
?
model = model[:(i + 1)]
?
return model, style_losses, content_losses
?
下一步,我們選擇輸入圖片。你可以使用內容圖片的副本或者白雜訊。
input_img = content_img.clone()
# if you want to use white noise instead uncomment the below line:
# input_img = torch.randn(content_img.data.size(), device=device)
?
# add the original input image to the figure:
plt.figure()
imshow(input_img, title=Input Image)
?
梯度下降
和演算法的作者Leon Gatys的在 這裡建議的一樣,我們將使用L-BFGS演算法來進行我們的梯度下降。與訓練一般網路不同,我們訓練輸入圖片是為了最小化內容/風格損失。我們要創建一個PyTorch的L-BFGS優化器optim.LBFGS
,並傳入我們的圖片到其中,作為張量去優化。
def get_input_optimizer(input_img):
# this line to show that input is a parameter that requires a gradient
optimizer = optim.LBFGS([input_img.requires_grad_()])
return optimizer
?
最後,我們必須定義一個方法來展示圖像風格轉換。對於每一次的網路迭代,都將更新過的輸入傳入其中並計算損失。我們要運行每一個損失模型的backward
方法來計算它們的梯度。優化器需要一個「關閉」方法,它重新估計模型並且返回損失。
我們還有最後一個問題要解決。神經網路可能會嘗試使張量圖片的值超過0到1之間來優化輸入。我們可以通過在每次網路運行的時候將輸入的值矯正到0到1之間來解決這個問題。
def run_style_transfer(cnn, normalization_mean, normalization_std,
content_img, style_img, input_img, num_steps=300,
style_weight=1000000, content_weight=1):
"""Run the style transfer."""
print(Building the style transfer model..)
model, style_losses, content_losses = get_style_model_and_losses(cnn,
normalization_mean, normalization_std, style_img, content_img)
optimizer = get_input_optimizer(input_img)
?
print(Optimizing..)
run = [0]
while run[0] <= num_steps:
?
def closure():
# correct the values of updated input image
input_img.data.clamp_(0, 1)
?
optimizer.zero_grad()
model(input_img)
style_score = 0
content_score = 0
?
for sl in style_losses:
style_score += sl.loss
for cl in content_losses:
content_score += cl.loss
?
style_score *= style_weight
content_score *= content_weight
?
loss = style_score + content_score
loss.backward()
?
run[0] += 1
if run[0] % 50 == 0:
print("run {}:".format(run))
print(Style Loss : {:4f} Content Loss: {:4f}.format(
style_score.item(), content_score.item()))
print()
?
return style_score + content_score
?
optimizer.step(closure)
?
# a last correction...
input_img.data.clamp_(0, 1)
?
return input_img
?
最後,我們可以運行這個演算法。
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
content_img, style_img, input_img)
?
plt.figure()
imshow(output, title=Output Image)
?
# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()
?
輸出:
Building the style transfer model..
Optimizing..
run [50]:
Style Loss : 4.169304 Content Loss: 4.235329
?
run [100]:
Style Loss : 1.145476 Content Loss: 3.039176
?
run [150]:
Style Loss : 0.716769 Content Loss: 2.663749
?
run [200]:
Style Loss : 0.476047 Content Loss: 2.500893
?
run [250]:
Style Loss : 0.347092 Content Loss: 2.410895
?
run [300]:
Style Loss : 0.263698 Content Loss: 2.358449
?
推薦閱讀: