Neural Style Transfer

What is Neural Style Transfer?

Neural style transfer is the process of applying the artistic style of one image to the content of another, existing image. It uses a pre-trained neural network, usually a convolutional neural network (CNN) such as AlexNet, VGG19 or SqueezeNet.

How does it work?

  1. Take two input images, a content image and a style image, and resize them to a common size.
  2. Load a pre-trained CNN and freeze all its parameters.
  3. Extract features related to content (the objects present) from the first image, and features related to style and texture from the second.
  4. Following the original neural style transfer paper, style information is captured by the correlations between the different feature maps within a layer, computed at several layers.
  5. These correlations are given by the Gram matrix G of a layer, where each cell (i, j) is the inner product between the vectorised feature maps i and j of that layer.
  6. Content loss = difference between the features of the content image and of the output (target) image at a chosen layer.
  7. Style loss = weighted sum, over the style layers, of the differences between the Gram matrices of the style image and of the output image.
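The steps above can be written compactly as follows. The notation is introduced here for illustration and is not taken from the original paper. It is chosen to match the normalisation used by the code later in this post; the original paper uses different scaling constants.

In this notation, $F^l(x)$ is the activation of layer $l$ for image $x$, reshaped to $C_l \times H_l W_l$. $G^l(x)$ is the Gram matrix computed from it, and $T$, $c$ and $s$ are the target, content and style images. Finally, $w_l$ are the per-layer style weights.

```latex
G^l_{ij} = \frac{1}{C_l H_l W_l} \sum_{k=1}^{H_l W_l} F^l_{ik}\, F^l_{jk}

\mathcal{L}_{\text{content}} = \frac{1}{C_m H_m W_m} \sum_{i,k} \left( F^m_{ik}(T) - F^m_{ik}(c) \right)^2, \quad m = \text{conv7\_1}

\mathcal{L}_{\text{style}} = \sum_{l} w_l \, \frac{1}{C_l^2} \sum_{i,j} \left( G^l_{ij}(T) - G^l_{ij}(s) \right)^2

\mathcal{L}_{\text{total}} = \alpha\, \mathcal{L}_{\text{content}} + \beta\, \mathcal{L}_{\text{style}}, \qquad \alpha = 1,\ \beta = 10^6
```

The averaging factors reflect `F.mse_loss` with its default mean reduction, and $\alpha = 1$, $\beta = 10^6$ correspond to `1000000 * style_loss + content_loss` in the training loop.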

Different combinations of layers can be used for feature extraction; here I use a combination that worked well in practice.

Style transfer using SqueezeNet in PyTorch

Importing Libraries

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torchvision import models
from torchvision import transforms as tf
import torch.nn.functional as F

Loading pre-trained SqueezeNet model

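# torchvision >= 0.13 selects pretrained weights with `weights=`;
# older versions use models.squeezenet1_0(pretrained=True) instead.
squeezenet = models.squeezenet1_0(weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1).features

for param in squeezenet.parameters():
    param.requires_grad_(False)  # freeze all weights: only the target image will be optimized

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

squeezenet.to(device).eval()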

Loading and transforming images

# Per-channel mean and std of ImageNet, the dataset the model was trained on
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

def transformation(img):
    # Resize (shorter side to 256 px), convert to a [0, 1] tensor, normalize with ImageNet statistics
    tasks = tf.Compose([tf.Resize(256),
                        tf.ToTensor(),
                        tf.Normalize(mean, std)])

    img = tasks(img)
    img = img.unsqueeze(0)  # add a batch dimension: (1, C, H, W)

    return img

content_img = Image.open("image.jpg").convert('RGB')
style_img   = Image.open("style.jpg").convert('RGB')
content_img = transformation(content_img).to(device)
style_img   = transformation(style_img).to(device)

Converting tensors to images and plotting

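def tensor_to_image(tensor):
    # Detach from the graph and move to a NumPy array of shape (C, H, W)
    image = tensor.clone().detach()
    image = image.cpu().numpy().squeeze()

    # (C, H, W) -> (H, W, C) for matplotlib
    image = image.transpose(1, 2, 0)

    # Undo tf.Normalize: multiply by std first, then add mean
    image = image * np.array(std) + np.array(mean)
    image = image.clip(0, 1)

    return image
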
img = tensor_to_image(content_img)
fig = plt.figure()
fig.suptitle('Content Image')
plt.imshow(img)

img = tensor_to_image(style_img)
fig = plt.figure()
fig.suptitle('Style Image')
plt.imshow(img)

NOTE: Images may look a bit different when plotted in notebooks.

Content image (image.jpg): [figure]

Style image: [figure, pexels-steve-johnson-1570779.jpg]

Extracting features from images

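For style information we use the output of the first convolutional layer and of several modules that follow it (up to index 9). For content information we use the layer labeled conv7_1.

The labels conv1_1 to conv8_1 borrow VGG-style naming. In torchvision's SqueezeNet 1.0, index 0 is a plain convolution, and indices 3, 4, 5, 7, 8 and 9 are Fire modules. Index 6 is a max-pooling layer.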

LAYERS_OF_INTEREST = {'0': 'conv1_1', 
                      '3': 'conv2_1',  
                      '4': 'conv3_1',
                      '5': 'conv4_1',
                      '6': 'conv5_1',
                      '7': 'conv6_1',  
                      '8': 'conv7_1',
                      '9': 'conv8_1'}

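def apply_model_and_extract_features(image, model):
    x = image

    features = {}

    # Run the image through the network one module at a time, keeping the selected outputs.
    # Module 1 is ReLU(inplace=True), so the tensor stored for 'conv1_1' ends up holding
    # its post-ReLU activation.
    for name, layer in model.named_children():
        x = layer(x)

        if name in LAYERS_OF_INTEREST:
            features[LAYERS_OF_INTEREST[name]] = x

    return features
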
content_img_features = apply_model_and_extract_features(content_img, squeezenet)
style_img_features   = apply_model_and_extract_features(style_img, squeezenet)

Using the Gram matrix to obtain correlations between the feature maps of each layer

def calculate_gram_matrix(tensor):

    _, channels, height, width = tensor.size()

    tensor = tensor.view(channels, height * width)    

    gram_matrix = torch.mm(tensor, tensor.t())

    gram_matrix = gram_matrix.div(channels * height * width) 

    return gram_matrix
style_features_gram_matrix = {layer: calculate_gram_matrix(style_img_features[layer])
                              for layer in style_img_features}

style_features_gram_matrix
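As a sanity check on the definition of the Gram matrix, the sketch below re-declares `calculate_gram_matrix` as in the post. It then compares the result against explicit inner products between pairs of vectorised feature maps, using a hypothetical random feature map.

The result is C × C whatever the spatial size. This is why the Gram matrix captures which features occur together rather than where they occur.

```python
import torch

def calculate_gram_matrix(tensor):
    # Same computation as in the post: (C x HW) @ (HW x C), normalised by C*H*W
    _, channels, height, width = tensor.size()
    flat = tensor.view(channels, height * width)
    return torch.mm(flat, flat.t()).div(channels * height * width)

# Hypothetical feature map: batch 1, 4 channels, 5x6 spatial grid
torch.manual_seed(0)
feats = torch.randn(1, 4, 5, 6)
gram = calculate_gram_matrix(feats)

# Cell (i, j) should equal the inner product of vectorised maps i and j, scaled by 1/(C*H*W)
flat = feats[0].reshape(4, -1)
manual = torch.empty(4, 4)
for i in range(4):
    for j in range(4):
        manual[i, j] = torch.dot(flat[i], flat[j]) / (4 * 5 * 6)

print(gram.shape)                               # torch.Size([4, 4])
print(torch.allclose(gram, manual, atol=1e-6))  # True
print(torch.allclose(gram, gram.t()))           # True: the Gram matrix is symmetric
```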

Style transfer

  • To transfer the style, we assign a weight to every layer used to obtain style features (the layers in LAYERS_OF_INTEREST); here the weights generally decrease with depth.
  • Define the optimizer and the target image, which starts as a copy of the content image and is the only tensor the optimizer updates.
weights = {'conv1_1': 1.0, 'conv2_1': 0.8, 'conv3_1': 0.65,
           'conv4_1': 0.5, 'conv5_1': 0.45, 'conv6_1': 0.3, 
           'conv7_1': 0.1, 'conv8_1': 0.15}

target = content_img.clone().requires_grad_(True)  # start from the content image (already on `device`); only these pixels are optimized

optimizer = optim.Adam([target], lr=0.003)

Content and Style Loss minimization

style_weight = 1e6  # weight of the style loss relative to the content loss

for i in range(1, 5000):

    target_features = apply_model_and_extract_features(target, squeezenet)

    # Content loss: compare target and content features at conv7_1
    content_loss = F.mse_loss(target_features['conv7_1'], content_img_features['conv7_1'])

    # Style loss: weighted sum of Gram-matrix differences over the style layers
    style_loss = 0

    for layer in weights:

        target_feature = target_features[layer]

        target_gram_matrix = calculate_gram_matrix(target_feature)
        style_gram_matrix = style_features_gram_matrix[layer]

        layer_loss = F.mse_loss(target_gram_matrix, style_gram_matrix)
        style_loss += weights[layer] * layer_loss

    total_loss = style_weight * style_loss + content_loss

    if i % 100 == 0:
        # Style losses are tiny before weighting, so print them in scientific notation
        print(f"Iteration {i}: Style Loss: {style_loss.item():.4e}, Content Loss: {content_loss.item():.4f}")

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
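The key idea of this loop is that the network stays frozen and only the pixels of the target image are updated. The sketch below reproduces that structure at toy scale so it runs in seconds on a CPU, with these stand-ins:

  • a hypothetical two-layer random CNN in place of SqueezeNet
  • random tensors in place of the photos
  • equal style weights for its two layers

It is meant to illustrate the mechanics, not to produce an artistic result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny frozen "feature extractor" standing in for SqueezeNet
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())
for p in net.parameters():
    p.requires_grad_(False)

def features(x):
    # Collect the output of each ReLU, playing the role of LAYERS_OF_INTEREST
    out, feats = x, []
    for layer in net:
        out = layer(out)
        if isinstance(layer, nn.ReLU):
            feats.append(out)
    return feats

def gram(t):
    # Gram matrix normalised by C*H*W, as in the post
    _, c, h, w = t.shape
    flat = t.view(c, h * w)
    return flat @ flat.t() / (c * h * w)

# Random stand-ins for the content and style images
content = torch.rand(1, 3, 16, 16)
style = torch.rand(1, 3, 16, 16)
content_feats = features(content)
style_grams = [gram(f) for f in features(style)]

# Only the pixels of `target` are optimised; the network stays fixed
target = content.clone().requires_grad_(True)
optimizer = torch.optim.Adam([target], lr=0.01)
style_weight = 1e6  # same style weight as in the post

history = []
for step in range(200):
    feats = features(target)
    content_loss = F.mse_loss(feats[-1], content_feats[-1])
    style_loss = sum(F.mse_loss(gram(f), g) for f, g in zip(feats, style_grams))
    total = content_loss + style_weight * style_loss
    history.append(total.item())
    optimizer.zero_grad()
    total.backward()
    optimizer.step()

print(f"total loss: {history[0]:.6f} -> {history[-1]:.6f}")
```

The recorded total loss should fall over the iterations while the network's weights never change. With real images and a pretrained network, the same loop gradually pulls the content image toward the style image's textures.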

Plotting results

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
ax1.imshow(tensor_to_image(content_img))
ax2.imshow(tensor_to_image(target))

Output:

[Figure: content image (left) and stylized target image (right)]

Remarks

This shows one use of CNNs: style transfer. You can experiment with the hyperparameters (layer weights, style weight, learning rate, number of iterations). You can also use other networks such as VGG or AlexNet to produce similar, or even better, results.

References: