
Train Fast Style Transfer Network

This example shows how to train a network to transfer the style of an image to a second image. It is based on the architecture defined in [1].

This example is similar to Neural Style Transfer Using Deep Learning, but it works faster once you have trained the network on a style image S. This is because, to obtain the stylized image Y, you only need to do a forward pass of the input image X through the network.

The high-level diagram below shows the training algorithm, which uses three images to calculate the loss: the input image X, the transformed image Y, and the style image S.

Note that the loss function uses the pretrained network VGG-16 to extract features from the images. You can find its implementation and mathematical definition in the Style Transfer Loss section of this example.

Load Training Data

Download the COCO 2014 training images and captions from http://cocodataset.org/#download by clicking "2014 Train images". Extract the images into the folder specified by imageFolder. The COCO 2014 data set was collected by the COCO Consortium.

Create a directory to store the COCO data set.

imageFolder = fullfile(tempdir,"coco");
if ~exist(imageFolder,'dir')
    mkdir(imageFolder);
end

Create an image datastore containing the COCO images.

imds = imageDatastore(imageFolder,'IncludeSubfolders',true);

Training can take a long time to run. To decrease the training time at the cost of a less accurate network, select a subset of the image datastore by setting fraction to a smaller value.

fraction = 1;
numObservations = numel(imds.Files);
imds = subset(imds,1:floor(numObservations*fraction));
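The subset call above keeps the first files in datastore order. If you prefer a random subset, one option is to draw the indices with randperm. This is a minimal sketch that uses a hypothetical observation count so that it runs without the COCO data; in the example you would pass idx to subset instead.

```matlab
% Hypothetical sizes so that the sketch runs without the COCO data.
numObservations = 1000;
fraction = 0.1;

rng(0)   % make the random selection reproducible
numSelected = floor(numObservations*fraction);

% Draw numSelected distinct indices from 1..numObservations.
idx = randperm(numObservations,numSelected);

% In the example, you would then call:
%   imds = subset(imds,idx);
```

Because randperm draws without replacement, no image appears twice in the subset.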

To resize the images and convert them all to RGB, create an augmented image datastore.

augimds = augmentedImageDatastore([256 256],imds,'ColorPreprocessing',"gray2rgb");

Read the style image.

styleImage = imread('starryNight.jpg');
styleImage = imresize(styleImage,[256 256]);

Display the chosen style image.

figure
imshow(styleImage)
title("Style Image")

Define Image Transformer Network

Define the image transformer network. This is an image-to-image network. The network consists of three parts:

  1. The first part of the network takes as input an RGB image of size [256x256x3] and downsamples it to a feature map of size [64x64x128].

  2. The second part of the network consists of five identical residual blocks defined in the supporting function residualBlock.

  3. The third and final part of the network upsamples the feature map to the original size of the image and returns the transformed image. This last part uses the upsampleLayer, which is a custom layer attached to this example as a supporting file.
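As a quick check of the sizes listed above, the following sketch traces the height, width, and number of channels through the three parts, using the strides and filter counts from the layer array that follows. It assumes that each upsampleLayer doubles the height and width. That is what two such layers need to do to return from 64x64 to 256x256, but the custom layer's code is not shown in this example.

```matlab
% Trace the feature map size through the image transformer network.
% Assumption: each upsampleLayer doubles the height and width.
inputSize = 256;             % height and width of the RGB input

% Part 1: conv1 (stride 1), conv2 (stride 2), conv3 (stride 2).
% With 'Padding','same', each output size is ceil(inputSize/stride).
downStrides = [1 2 2];
downFilters = [32 64 128];   % number of filters in conv1, conv2, conv3
sz = inputSize;
for k = 1:numel(downStrides)
    sz = ceil(sz/downStrides(k));
end
downSize = sz;
downChannels = downFilters(end);

% Part 2: the residual blocks use stride 1, 'same' padding, and 128 filters,
% so the feature map size does not change.

% Part 3: up1 and up2 each double the size; conv_out has 3 filters.
numUpsample = 2;
outSize = downSize*2^numUpsample;
outChannels = 3;

fprintf("After part 1: %dx%dx%d\n",downSize,downSize,downChannels)
fprintf("Network output: %dx%dx%d\n",outSize,outSize,outChannels)
```

The residual blocks can add their input to their output only because part 2 preserves both the 64x64 spatial size and the 128 channels.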

layers = [
    
    % First part.
    imageInputLayer([256 256 3], 'Name', 'input', 'Normalization','none')
    
    convolution2dLayer([9 9], 32, 'Padding','same','Name', 'conv1')
    groupNormalizationLayer('channel-wise','Name','norm1')
    reluLayer('Name', 'relu1')
    
    convolution2dLayer([3 3], 64, 'Stride', 2,'Padding','same','Name', 'conv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm2')
    reluLayer('Name', 'relu2')
    
    convolution2dLayer([3 3], 128, 'Stride', 2, 'Padding','same','Name', 'conv3')
    groupNormalizationLayer('channel-wise' ,'Name','norm3')
    reluLayer('Name', 'relu3')
    
    % Second part. 
    residualBlock("1")
    residualBlock("2")
    residualBlock("3")
    residualBlock("4")
    residualBlock("5")
    
    % Third part.
    upsampleLayer('up1')
    convolution2dLayer([3 3], 64, 'Stride', 1, 'Padding','same','Name', 'upconv1')
    groupNormalizationLayer('channel-wise' ,'Name','norm6')
    reluLayer('Name', 'relu5')
    
    upsampleLayer('up2')
    convolution2dLayer([3 3], 32, 'Stride', 1, 'Padding','same','Name', 'upconv2')
    groupNormalizationLayer('channel-wise' ,'Name','norm7')
    reluLayer('Name', 'relu6')
    
    convolution2dLayer(9,3,'Padding','same','Name','conv_out')];

lgraph = layerGraph(layers);

Add missing connections in residual blocks.

lgraph = connectLayers(lgraph,"relu3","add1/in2");
lgraph = connectLayers(lgraph,"add1","add2/in2");
lgraph = connectLayers(lgraph,"add2","add3/in2");
lgraph = connectLayers(lgraph,"add3","add4/in2");
lgraph = connectLayers(lgraph,"add4","add5/in2");
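The five calls above follow a regular pattern: the second input of the addition layer in residual block k receives the input of that block, which is relu3 for the first block and the previous addition layer for the others. The sketch below generates the same source and destination names; the variable names skipSources and skipDestinations are hypothetical.

```matlab
numBlocks = 5;

% The input to residual block k is relu3 for k = 1 and add(k-1) otherwise.
skipSources = ["relu3", "add" + (1:numBlocks-1)];

% Each skip connection ends at the second input of that block's addition layer.
skipDestinations = "add" + (1:numBlocks) + "/in2";

% Display the pairs. With lgraph from this example you could instead run
%   for k = 1:numBlocks
%       lgraph = connectLayers(lgraph,skipSources(k),skipDestinations(k));
%   end
disp([skipSources; skipDestinations].')
```

A loop like the one in the comment could replace the five explicit connectLayers calls, which keeps the connections consistent if you change the number of residual blocks in the layer array.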

Visualize the image transformer network in a plot.

figure
plot(lgraph)
title("Transform Network")

Create a dlnetwork object from the layer graph.

dlnetTransform = dlnetwork(lgraph);

Style Loss Network

This example uses a pretrained VGG-16 deep neural network to extract the features of the content and style images at different layers. The content and style losses are computed from these multilayer features.

To get a pretrained VGG-16 network, use the vgg16 function. If you do not have the required support packages installed, then the software provides a download link.

netLoss = vgg16;

To extract the features necessary to calculate the loss, you need only the first 24 layers, up to and including relu4_3. Extract these layers and convert them to a layer graph.

lossLayers = netLoss.Layers(1:24);
lgraph = layerGraph(lossLayers);
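Rather than hardcoding 24, you could look up the index of relu4_3, the deepest layer that the loss uses, by its name. This is a minimal sketch, assuming the VGG-16 support package is installed and that the layer names match the ones used elsewhere in this example.

```matlab
% Requires the VGG-16 support package (vgg16 provides a download link if missing).
netLoss = vgg16;

% Find the deepest layer needed by the loss instead of hardcoding its index.
allLayers = netLoss.Layers;
layerNames = string({allLayers.Name});
lastIdx = find(layerNames == "relu4_3");

% Keep the layers up to and including relu4_3.
lossLayers = allLayers(1:lastIdx);
lgraph = layerGraph(lossLayers);
```

If the lookup returns 24, the result is the same layer array as netLoss.Layers(1:24).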

Convert to a dlnetwork.

dlnetLoss = dlnetwork(lgraph);

Define the Loss Function and Gram Matrix

Create the styleTransferLoss function defined in the Style Transfer Loss section of this example.

The function styleTransferLoss takes as input the loss network dlnetLoss, a mini-batch of input images dlX, a mini-batch of transformed images dlY, a cell array containing the Gram matrices of the style image dlSGram, the weight associated with the content loss weightContent, and the weight associated with the style loss weightStyle. The function returns the total loss loss and the individual components: the content loss lossContent and the style loss lossStyle.

The styleTransferLoss function uses the supporting function createGramMatrix in the computation of the style loss.

The createGramMatrix function takes as input the features extracted by the loss network and returns a stylistic representation for each image in a mini-batch. You can find the implementation and mathematical definition of the Gram matrix in the Gram Matrix section of this example.

Define the Model Gradients Function

Create the function modelGradients, listed in the Model Gradients Function section of the example. This function takes as input the loss network dlnetLoss, the image transformer network dlnetTransform, a mini-batch of input images dlX, a cell array containing the Gram matrices of the style image dlSGram, the weight associated with the content loss contentWeight, and the weight associated with the style loss styleWeight. The function returns the gradients of the loss with respect to the learnable parameters of the image transformer, the state of the image transformer network, the transformed images dlY, the total loss loss, the loss associated with the content lossContent, and the loss associated with the style lossStyle.

Specify Training Options

Train with a mini-batch size of 4 for 2 epochs as in [1].

numEpochs = 2;
miniBatchSize = 4;

Set the read size of the augmented image datastore to the mini-batch size.

augimds.MiniBatchSize = miniBatchSize;

Specify the options for ADAM optimization. Specify a learning rate of 0.001 with a gradient decay factor of 0.9 and a squared gradient decay factor of 0.999.

learnRate = 0.001;
gradientDecayFactor = 0.9;
squaredGradientDecayFactor = 0.999;
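To see what the two decay factors control, the sketch below applies one step of the standard Adam update rule to a single scalar parameter with the same learning rate and decay factors. The parameter value, the gradient, and epsilon = 1e-8 are assumptions for illustration; adamupdate performs this kind of update for every learnable parameter of the network.

```matlab
learnRate = 0.001;
gradientDecayFactor = 0.9;          % decay of the moving average of the gradient
squaredGradientDecayFactor = 0.999; % decay of the moving average of the squared gradient
epsilon = 1e-8;                     % assumed small constant for numerical stability

% Hypothetical scalar parameter, gradient, and zero-initialized moving averages.
w = 0.5;
g = 0.2;
avgGrad = 0;
avgSqGrad = 0;
t = 1;                              % iteration number, as passed to adamupdate

% Update the moving averages.
avgGrad = gradientDecayFactor*avgGrad + (1 - gradientDecayFactor)*g;
avgSqGrad = squaredGradientDecayFactor*avgSqGrad + (1 - squaredGradientDecayFactor)*g^2;

% Correct the bias introduced by initializing the averages at zero.
mHat = avgGrad/(1 - gradientDecayFactor^t);
vHat = avgSqGrad/(1 - squaredGradientDecayFactor^t);

% Take the step.
w = w - learnRate*mHat/(sqrt(vHat) + epsilon);
fprintf("Updated parameter: %.6f\n",w)
```

On the first iteration the bias correction makes the step approximately learnRate times the sign of the gradient, so the parameter moves from 0.5 to about 0.499.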

Train on a GPU if one is available. Using a GPU requires Parallel Computing Toolbox™ and a CUDA®-enabled NVIDIA® GPU with compute capability 3.0 or higher.

executionEnvironment = "auto";

Specify the weights of the content loss and the style loss in the calculation of the total loss.

To find a good balance between the content and style losses, you might need to experiment with different combinations of weights.

weightContent = 1e-4;
weightStyle = 3e-8; 

Specify how often to plot the training progress, as the number of iterations between plot updates.

plotFrequency = 10;

Train Model

To compute the loss during training, first calculate the Gram matrices of the style image.

Convert the style image to dlarray.

dlS = dlarray(single(styleImage),'SSC');

To calculate the Gram matrices, pass the style image through the VGG-16 network and extract the activations at four different layers.

[dlSActivations1,dlSActivations2,dlSActivations3,dlSActivations4] = forward(dlnetLoss,dlS, ...
    'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

Calculate the Gram matrix for each set of activations using the supporting function createGramMatrix.

dlSGram{1} = createGramMatrix(dlSActivations1);
dlSGram{2} = createGramMatrix(dlSActivations2);
dlSGram{3} = createGramMatrix(dlSActivations3);
dlSGram{4} = createGramMatrix(dlSActivations4);

The training plots consist of two figures:

  1. A figure showing a plot of the losses during training

  2. A figure containing an input and an output image of the image transformer network

Initialize the training plots. You can find the details of the initialization in the supporting function initializeStyleTransferPlots. This function returns the axes ax1, where you plot the losses; the axes ax2, where you plot the validation images; and the animated lines lineLossContent, lineLossStyle, and lineLossTotal, which contain the content loss, the style loss, and the total loss, respectively.

[ax1,ax2,lineLossContent,lineLossStyle,lineLossTotal]=initializeStyleTransferPlots();

Initialize the average gradient and the average squared gradient that the ADAM optimizer maintains between iterations.

averageGrad = [];
averageSqGrad = [];

Calculate the total number of training iterations, counting only the full mini-batches in each epoch.

numIterations = numEpochs*floor(augimds.NumObservations/miniBatchSize);
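The training loop skips the last partial mini-batch of each epoch, so the number of parameter updates is numEpochs*floor(NumObservations/miniBatchSize). The arithmetic sketch below, with a hypothetical data-set size, shows that this can be smaller than floor(NumObservations*numEpochs/miniBatchSize), which pools the leftover images of all epochs into extra mini-batches that never run.

```matlab
numObservations = 1003;   % hypothetical data-set size
miniBatchSize = 4;
numEpochs = 2;

% Full mini-batches per epoch (the partial last mini-batch is skipped).
batchesPerEpoch = floor(numObservations/miniBatchSize);
numUpdates = numEpochs*batchesPerEpoch;

% Count obtained by pooling the observations of all epochs.
pooledCount = floor(numObservations*numEpochs/miniBatchSize);

fprintf("Parameter updates: %d, pooled count: %d\n",numUpdates,pooledCount)
```

Here each epoch has 250 full mini-batches, giving 500 updates, while the pooled count is floor(501.5) = 501.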

Initialize the iteration counter and the timer before training.

iteration = 0;
start = tic;

Train the model. This could take a long time to run.

% Loop over epochs.
for i = 1:numEpochs
    
    % Reset and shuffle datastore.
    reset(augimds);
    augimds = shuffle(augimds);
    
    % Loop over mini-batches.
    while hasdata(augimds)
        
        % Read mini-batch of data.
        data = read(augimds);
        
        % Ignore last partial mini-batch of epoch.
        if size(data,1) < miniBatchSize
            continue
        end
        
        % Count only iterations that update the network.
        iteration = iteration + 1;
        
        % Extract the images from data store into a cell array.
        images = data{:,1};
        
        % Concatenate the images along the 4th dimension.
        X = cat(4,images{:});
        X = single(X);
        
        % Convert mini-batch of data to dlarray and specify the dimension labels
        % 'SSCB' (spatial, spatial, channel, batch).
        dlX = dlarray(X, 'SSCB');
        
        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            dlX = gpuArray(dlX);
        end
        
        % Evaluate the model gradients and the network state using
        % dlfeval and the modelGradients function listed at the end of the
        % example.
        [gradients,state,dlY,loss,lossContent,lossStyle] = dlfeval(@modelGradients, ...
            dlnetLoss,dlnetTransform,dlX,dlSGram,weightContent,weightStyle);
        
        dlnetTransform.State = state;
        
        % Update the network parameters.
        [dlnetTransform,averageGrad,averageSqGrad] = ...
            adamupdate(dlnetTransform,gradients,averageGrad,averageSqGrad,iteration,...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);
              
        
        % Every plotFrequency iterations, plot the training progress.
        if mod(iteration,plotFrequency) == 0
            addpoints(lineLossTotal,iteration,double(gather(extractdata(loss))))
            addpoints(lineLossContent,iteration,double(gather(extractdata(lossContent))))
            addpoints(lineLossStyle,iteration,double(gather(extractdata(lossStyle))))
            
            % Use the first image of the mini-batch as a validation image.
            dlV = dlX(:,:,:,1);
            % Use the transformed validation image computed previously.
            dlVY = dlY(:,:,:,1);
            
            % To use the function imshow, convert to uint8.
            validationImage = uint8(gather(extractdata(dlV)));
            transformedValidationImage = uint8(gather(extractdata(dlVY)));
            
            % Plot the input image next to the output image.
            imshow(imtile({validationImage,transformedValidationImage}),'Parent',ax2);
        end
        
        % Display time elapsed since start of training and training completion percentage.
        D = duration(0,0,toc(start),'Format','hh:mm:ss');
        completionPercentage = round(iteration/numIterations*100,2);
        title(ax1,"Epoch: " + i + ", Iteration: " + iteration + " of " + numIterations + " (" + completionPercentage + "%), Elapsed: " + string(D))
        drawnow
        
    end
end

Stylize an Image

Once training has finished, you can use the image transformer on any image of your choice.

Load the image you would like to transform.

imFilename = 'peppers.png';
im = imread(imFilename);

Resize the input image to the input dimensions of the image transformer.

im = imresize(im,[256,256]);

Convert it to dlarray.

dlX = dlarray(single(im),'SSCB');

If a GPU is available, convert the data to gpuArray to run the network on the GPU.

if canUseGPU
    dlX = gpuArray(dlX);
end

To apply the style to the image, pass it forward through the image transformer using the predict function.

dlY = predict(dlnetTransform,dlX);

Rescale the image into the range [0 255]. First, use the tanh function to map dlY to the range [-1 1]. Then, shift and scale the output into the range [0 255].

Y = 255*(tanh(dlY)+1)/2;
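The following sketch applies the same rescaling to a few arbitrary sample values to show the mapping: large negative outputs approach 0, zero maps to the midpoint 127.5, and large positive outputs approach 255.

```matlab
% Sample raw network outputs, from strongly negative to strongly positive.
z = [-10 -1 0 1 10];

% tanh maps to (-1,1); adding 1 and halving maps to (0,1); scaling by 255 maps to (0,255).
y = 255*(tanh(z) + 1)/2;

% uint8 rounds to the nearest integer, so the midpoint 127.5 becomes 128.
disp(uint8(y))
```

Because tanh never reaches -1 or 1 exactly, the rescaled values always lie strictly inside (0,255), so the later conversion to uint8 never saturates by overflow.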

Prepare Y for plotting. Use the extractdata function to extract the data from the dlarray, and use the gather function to transfer Y from the GPU to the local workspace.

Y = uint8(gather(extractdata(Y)));

Show the input image (left) next to the stylized image (right).

figure
m = imtile({im,Y});
imshow(m)

Model Gradients Function

The function modelGradients takes as input the loss network dlnetLoss, the image transformer network dlnetTransform, a mini-batch of input images dlX, a cell array containing the Gram matrices of the style image dlSGram, the weight associated with the content loss contentWeight, and the weight associated with the style loss styleWeight. It returns the gradients of the loss with respect to the learnable parameters of the image transformer, the state of the image transformer network, the transformed images dlY, the total loss loss, the loss associated with the content lossContent, and the loss associated with the style lossStyle.

function [gradients,state,dlY,loss,lossContent,lossStyle] = ...
    modelGradients(dlnetLoss,dlnetTransform,dlX,dlSGram,contentWeight,styleWeight)

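% Forward pass through the image transformer, also returning the updated state.
[dlY,state] = forward(dlnetTransform,dlX);

% Map the network output to the pixel range [0 255].
dlY = 255*(tanh(dlY)+1)/2;

% Compute the weighted content and style losses and their sum.
[loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX,dlSGram,contentWeight,styleWeight);

% Differentiate the total loss with respect to the transformer's learnable parameters.
gradients = dlgradient(loss,dlnetTransform.Learnables);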

end

Style Transfer Loss

The function styleTransferLoss takes as input the loss network dlnetLoss, a mini-batch of input images dlX, a mini-batch of transformed images dlY, a cell array containing the Gram matrices of the style image dlSGram, and the weights associated with the content and style losses, weightContent and weightStyle, respectively. It returns the total loss loss and the individual components: the content loss lossContent and the style loss lossStyle.

The content loss measures how much the spatial structure of the input images X differs from that of the transformed images Y.

The style loss measures how much the stylistic appearance of the transformed images Y differs from that of the style image S.

The diagram below shows the algorithm that styleTransferLoss implements to calculate the total loss.

First, the function passes the input images X, the transformed images Y, and the style image S to the pretrained network VGG-16. This pretrained network extracts several features from these images. The algorithm then calculates the content loss by using the spatial features of the input images X and of the transformed images Y. It also calculates the style loss by using the stylistic features of the transformed images Y and of the style image S. Finally, it obtains the total loss by adding the weighted content and style losses.

Content Loss

For each image in the mini-batch, the content loss function compares the features of the original image and of the transformed image output by the layer relu3_3. In particular, it calculates the mean squared error between the activations and returns the average loss for the mini-batch:

$$\text{lossContent} = \frac{1}{N}\sum_{n=1}^{N}\operatorname{mean}\left(\left[\phi(X_n) - \phi(Y_n)\right]^2\right),$$

where X contains the input images, Y contains the transformed images, N is the mini-batch size, and $\phi(\cdot)$ represents the activations extracted at layer relu3_3.
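Because every image in the mini-batch contributes the same number of activations, averaging the per-image means gives the same value as one mean over the whole array, which is what styleTransferLoss computes with mean(...,'all'). The sketch below checks this on random arrays whose sizes are arbitrary stand-ins for the relu3_3 activations.

```matlab
rng(1)
% Stand-in activations, H x W x C x N (sizes are arbitrary).
phiX = rand(8,8,16,4);
phiY = rand(8,8,16,4);
N = size(phiX,4);

% Formula: average over the mini-batch of the per-image mean squared error.
perImage = zeros(1,N);
for n = 1:N
    d = phiX(:,:,:,n) - phiY(:,:,:,n);
    perImage(n) = mean(d.^2,'all');
end
lossFormula = mean(perImage);

% Single call, as in styleTransferLoss (before applying weightContent).
lossVectorized = mean((phiY - phiX).^2,'all');

fprintf("Difference: %g\n",abs(lossFormula - lossVectorized))
```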

Style Loss

To calculate the style loss, for each image in the mini-batch:

  1. Extract the activations at the layers relu1_2, relu2_2, relu3_3 and relu4_3.

  2. For each of the four activations $\phi_j$, compute the Gram matrix $G(\phi_j)$.

  3. Calculate the squared differences between the Gram matrices of the transformed image and the corresponding Gram matrices of the style image, and sum them over all entries.

  4. Add up the results for the four layers.

To obtain the style loss for the whole mini-batch, compute the average of the style loss for each image n in the mini-batch:

$$\text{lossStyle} = \frac{1}{N}\sum_{n=1}^{N}\sum_{j=1}^{4}\left\| G\big(\phi_j(Y_n)\big) - G\big(\phi_j(S)\big) \right\|_F^2,$$

where j is the index of the layer, $G(\cdot)$ is the Gram matrix, and $\|\cdot\|_F^2$ is the sum of the squared entries of a matrix.
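In styleTransferLoss, the Gram matrices of the whole mini-batch form a C-by-C-by-N array, while the style image has a single C-by-C Gram matrix per layer. Subtracting them relies on implicit expansion, which repeats the style Gram matrix for every image. The sketch below checks, for one layer and with random stand-in matrices, that this vectorized form equals the per-image sum of squared differences averaged over the mini-batch.

```matlab
rng(2)
C = 3;                 % number of channels at one layer (arbitrary)
N = 4;                 % mini-batch size
GY = rand(C,C,N);      % stand-in Gram matrices of the transformed images
GS = rand(C,C);        % stand-in Gram matrix of the style image

% Per-image form for one layer: (1/N) * sum over n of the squared Frobenius norm.
lossFormula = 0;
for n = 1:N
    lossFormula = lossFormula + norm(GY(:,:,n) - GS,'fro')^2;
end
lossFormula = lossFormula/N;

% Vectorized form: implicit expansion subtracts GS from every page of GY.
lossVectorized = sum((GY - GS).^2,'all')/N;

fprintf("Difference: %g\n",abs(lossFormula - lossVectorized))
```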

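Total Loss

The styleTransferLoss function below scales the content loss by weightContent and the style loss by weightStyle, and returns the sum of the two weighted losses as the total loss. The lossContent and lossStyle outputs are the weighted components.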

function [loss,lossContent,lossStyle] = styleTransferLoss(dlnetLoss,dlY,dlX, ...
    dlSGram,weightContent,weightStyle)

% Extract activations.
dlYActivations = cell(1,4);
[dlYActivations{1},dlYActivations{2},dlYActivations{3},dlYActivations{4}] = ...
    forward(dlnetLoss,dlY,'Outputs',["relu1_2" "relu2_2" "relu3_3" "relu4_3"]);

dlXActivations = forward(dlnetLoss,dlX,'Outputs','relu3_3');

% Calculate the mean square error between activations.
lossContent = mean((dlYActivations{3} - dlXActivations).^2,'all');

% Add up the losses for all the four activations.
lossStyle = 0;
for j = 1:4
    G = createGramMatrix(dlYActivations{j});
    lossStyle = lossStyle + sum((G - dlSGram{j}).^2,'all');
end

% Average the loss over the mini-batch.
miniBatchSize = size(dlX,4);
lossStyle = lossStyle/miniBatchSize;

% Apply weights.
lossContent = weightContent * lossContent;
lossStyle = weightStyle * lossStyle;

% Calculate the total loss.
loss = lossContent + lossStyle;

end

Residual Block

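The residualBlock function returns an array of six layers: two convolution layers, two instance normalization layers, a ReLU layer, and an addition layer. The second input of the addition layer is connected to the input of the block by the connectLayers calls in the Define Image Transformer Network section. Note that groupNormalizationLayer('channel-wise') is equivalent to an instance normalization layer.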

function layers = residualBlock(name)

layers = [    
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same','Name', "convRes"+name+"_1")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_1")
    reluLayer('Name', "reluRes"+name+"_1")
    convolution2dLayer([3 3], 128, 'Stride', 1,'Padding','same', 'Name', "convRes"+name+"_2")
    groupNormalizationLayer('channel-wise','Name',"normRes"+name+"_2")
    additionLayer(2,'Name',"add"+name)];

end
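To make this equivalence concrete, the sketch below normalizes a plain array the way instance normalization does: each channel of each image is normalized with the mean and variance over its own spatial positions. It omits the learnable scale and offset, and epsilon = 1e-5 is an assumed value rather than a documented default of the layer.

```matlab
rng(3)
X = rand(4,4,2,3);        % H x W x C x N, arbitrary sizes
epsilon = 1e-5;           % assumed small stabilizing constant

% Statistics over the spatial dimensions only: one mean and variance per channel per image.
mu = mean(X,[1 2]);           % 1 x 1 x C x N
sigma2 = var(X,1,[1 2]);      % population variance, 1 x 1 x C x N

% Normalize; implicit expansion broadcasts mu and sigma2 over H and W.
Xn = (X - mu)./sqrt(sigma2 + epsilon);

% Each channel of each image now has (approximately) zero mean.
disp(squeeze(mean(Xn,[1 2])))
```

Because the statistics never mix images, the normalization of one image does not depend on the other images in the mini-batch.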

Gram Matrix

The function createGramMatrix takes as input the activations of a single layer and returns a stylistic representation for each image in a mini-batch. The input is a feature map of size [H, W, C, N], where H is the height, W is the width, C is the number of channels, and N is the mini-batch size. The function outputs an array G of size [C, C, N]. Each subarray G(:,:,k) is the Gram matrix corresponding to the k-th image in the mini-batch. Each entry G(i,j,k) of the Gram matrix represents the correlation between channels $c_i$ and $c_j$, because each entry in channel $c_i$ is multiplied by the entry at the corresponding position in channel $c_j$:

$$G(i,j,k) = \frac{1}{C \times H \times W}\sum_{h=1}^{H}\sum_{w=1}^{W}\phi_k(h,w,c_i)\,\phi_k(h,w,c_j),$$

where $\phi_k$ are the activations for the k-th image in the mini-batch.

The Gram matrix contains information about which features activate together but has no information about where the features occur in the image. This is because the summation over height and width loses the information about the spatial structure. The loss function uses this matrix as a stylistic representation of the image.

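function G = createGramMatrix(activations)

% Size of the feature map: height, width, and number of channels.
[h,w,numChannels] = size(activations,1:3);

% Flatten the spatial dimensions: each column holds one channel, each page one image.
features = reshape(activations,h*w,numChannels,[]);
featuresT = permute(features,[2 1 3]);

% Batched matrix product gives a numChannels-by-numChannels Gram matrix per image.
G = dlmtimes(featuresT,features) / (h*w*numChannels);

end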
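The sketch below checks, on a small random array, that the reshape-and-multiply approach in createGramMatrix matches the double-sum definition, and that shuffling the spatial positions leaves the result unchanged, which is the sense in which the Gram matrix discards spatial structure. It uses ordinary arrays and pagemtimes (available in R2020b and later) in place of dlarray and dlmtimes; the array sizes are arbitrary.

```matlab
rng(0)
H = 4; W = 5; C = 3; N = 2;          % arbitrary sizes
A = rand(H,W,C,N);                   % stand-in activations

% Direct evaluation of the double sum for every entry G(i,j,k).
Gloop = zeros(C,C,N);
for k = 1:N
    for ii = 1:C
        for jj = 1:C
            Gloop(ii,jj,k) = sum(A(:,:,ii,k).*A(:,:,jj,k),'all')/(C*H*W);
        end
    end
end

% Vectorized form following createGramMatrix, with pagemtimes instead of dlmtimes.
F = reshape(A,H*W,C,N);
G = pagemtimes(permute(F,[2 1 3]),F)/(H*W*C);

% Shuffling the spatial positions (rows of F) leaves the Gram matrix unchanged.
perm = randperm(H*W);
Fshuffled = F(perm,:,:);
Gshuffled = pagemtimes(permute(Fshuffled,[2 1 3]),Fshuffled)/(H*W*C);

fprintf("Loop vs vectorized: %g\n",max(abs(Gloop - G),[],'all'))
fprintf("Original vs shuffled: %g\n",max(abs(G - Gshuffled),[],'all'))
```

Each Gram matrix is also symmetric, since swapping channels $c_i$ and $c_j$ in the sum gives the same products.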

References

  1. Johnson, Justin, Alexandre Alahi, and Li Fei-Fei. "Perceptual Losses for Real-Time Style Transfer and Super-Resolution." European Conference on Computer Vision. Springer, Cham, 2016.
