Contenido principal

Entrenar redes generativas antagónicas (GAN)

En este ejemplo se muestra cómo entrenar una red generativa antagónica para generar imágenes.

Una red generativa antagónica (GAN) es un tipo de red de deep learning que puede generar datos con características similares a las de los datos reales de entrada.

La función trainnet no admite el entrenamiento de GAN, por lo que se debe implementar un bucle de entrenamiento personalizado. Para entrenar la GAN con un bucle de entrenamiento personalizado, puede utilizar objetos dlarray y dlnetwork para la diferenciación automática.

Una GAN consta de dos redes que se entrenan juntas:

  1. Generador: dado un vector de valores aleatorios (entradas latentes) como entrada, esta red genera datos con la misma estructura que los datos de entrenamiento.

  2. Discriminador: dados los lotes de datos que contienen observaciones de los datos de entrenamiento y de los datos generados por el generador, esta red intenta clasificar las observaciones como "real" o "generated".

Este diagrama ilustra la red del generador de una GAN que genera imágenes a partir de vectores de entradas aleatorias.

Este diagrama ilustra la estructura de una GAN.

Para entrenar una GAN, entrene las dos redes simultáneamente para maximizar el rendimiento de ambas:

  • Entrene el generador para generar datos que "engañen" al discriminador.

  • Entrene el discriminador para distinguir entre datos reales y generados.

Para optimizar el rendimiento del generador, maximice la pérdida del discriminador cuando se proporcionen datos generados. Es decir, el objetivo del generador es generar datos que el discriminador clasifique como "real".

Para optimizar el rendimiento del discriminador, minimice la pérdida del discriminador cuando se proporcionen lotes de datos reales y generados. Es decir, el objetivo del discriminador es no ser "engañado" por el generador.

Idealmente, estas estrategias dan como resultado un generador que genera datos convincentemente realistas y un discriminador que ha aprendido representaciones de características fuertes que son representativas de los datos de entrenamiento.

En este ejemplo se entrena una red GAN para generar imágenes utilizando el conjunto de datos Flowers [1], que contiene imágenes de flores.

Cargar los datos de entrenamiento

Descargue y extraiga el conjunto de datos Flowers [1].

url = "http://download.tensorflow.org/example_images/flower_photos.tgz";
downloadFolder = tempdir;
filename = fullfile(downloadFolder,"flower_dataset.tgz");

imageFolder = fullfile(downloadFolder,"flower_photos");
if ~datasetExists(imageFolder)
    disp("Downloading Flowers data set (218 MB)...")
    websave(filename,url);
    untar(filename,downloadFolder)
end

Cree un almacén de datos de imágenes que contenga las fotos de las flores.

imds = imageDatastore(imageFolder,IncludeSubfolders=true);

Aumente los datos para incluir volteo horizontal aleatorio y cambie el tamaño de las imágenes para que tengan un tamaño de 64 por 64.

augmenter = imageDataAugmenter(RandXReflection=true);
augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);

Definir redes generativas antagónicas

Una GAN consta de dos redes que se entrenan juntas:

  1. Generador: dado un vector de valores aleatorios (entradas latentes) como entrada, esta red genera datos con la misma estructura que los datos de entrenamiento.

  2. Discriminador: dados los lotes de datos que contienen observaciones de los datos de entrenamiento y de los datos generados por el generador, esta red intenta clasificar las observaciones como "real" o "generated".

Este diagrama ilustra la estructura de una GAN.

Definir la red del generador

Defina la siguiente arquitectura de la red, que genera imágenes a partir de vectores aleatorios.

Esta red:

  • Convierte los vectores aleatorios de tamaño 100 a arreglos de 4 por 4 por 512 utilizando una operación de proyección y remodelación.

  • Mejora los arreglos resultantes para que sean arreglos de 64 por 64 por 3 usando una serie de capas de convolución traspuesta con normalización de lotes y capas ReLU.

Defina esta arquitectura de la red como una gráfica de capas y especifique las siguientes propiedades de red.

  • Para las capas de convolución traspuesta, especifique filtros de 5 por 5 con un número descendente de filtros para cada capa, un tramo de 2 y recorte de la salida en cada borde.

  • Para la última capa de convolución traspuesta, especifique tres filtros de 5 por 5 que correspondan a los tres canales RGB de las imágenes generadas y el tamaño de salida de la capa anterior.

  • Al final de la red, incluya una capa tanh.

Para proyectar y remodelar la entrada de ruido, use la capa personalizada projectAndReshapeLayer, que se adjunta a este ejemplo como archivo de apoyo. Para acceder a esta capa, abra el ejemplo como un script en vivo.

filterSize = 5;
numFilters = 64;
numLatentInputs = 100;

projectionSize = [4 4 512];

layersGenerator = [
    featureInputLayer(numLatentInputs)
    projectAndReshapeLayer(projectionSize)
    transposedConv2dLayer(filterSize,4*numFilters)
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
    batchNormalizationLayer
    reluLayer
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
    tanhLayer];

Para entrenar la red con un bucle de entrenamiento personalizado y habilitar la diferenciación automática, convierta la gráfica de capas a un objeto dlnetwork.

netG = dlnetwork(layersGenerator);

Definir la red del discriminador

Defina la siguiente red, que clasifica imágenes de 64 por 64 reales y generadas.

Cree una red que tome imágenes de 64 por 64 por 3 y devuelva una puntuación de predicción escalar usando una serie de capas de convolución traspuesta con normalización de lotes y capas ReLU con fugas. Añada ruido a las imágenes de entrada mediante abandono.

  • Para la capa de abandono, especifique una probabilidad de abandono de 0.5.

  • Para las capas de convolución, especifique filtros de 5 por 5 con un número ascendente de filtros para cada capa. Especifique también un tramo de 2 y el relleno de la salida.

  • Para las capas ReLU con fugas, especifique una escala de 0.2.

  • Para generar las probabilidades en el intervalo [0,1], especifique una capa convolucional con un filtro de 4 por 4 seguido por una capa sigmoide.

dropoutProb = 0.5;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")
    dropoutLayer(dropoutProb)
    convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
    batchNormalizationLayer
    leakyReluLayer(scale)
    convolution2dLayer(4,1)
    sigmoidLayer];

Para entrenar la red con un bucle de entrenamiento personalizado y habilitar la diferenciación automática, convierta la gráfica de capas a un objeto dlnetwork.

netD = dlnetwork(layersDiscriminator);

Definir las funciones de pérdida del modelo

Cree la función modelLoss, enumerada en la sección Función de pérdida del modelo del ejemplo, que toma como entrada las redes del generador y el discriminador, un minilote de datos de entrada, un arreglo de valores aleatorios y el factor de volteo, y devuelve los valores de pérdida, los gradientes de los valores de pérdida con respecto a los parámetros que se pueden aprender en las redes, el estado del generador y las puntuaciones de las dos redes.

Especificar las opciones de entrenamiento

Entrene con un tamaño de minilote de 128 durante 500 épocas. Para conjuntos de datos más grandes, puede que no sea necesario entrenar durante tantas épocas.

numEpochs = 500;
miniBatchSize = 128;

Especifique las opciones para la optimización de Adam. Para ambas redes, especifique:

  • Una tasa de aprendizaje de 0.0002

  • Un factor de decaimiento de gradiente de 0.5

  • Un factor de decaimiento de gradiente cuadrado de 0.999

learnRate = 0.0002;
gradientDecayFactor = 0.5;
squaredGradientDecayFactor = 0.999;

Si el discriminador aprende a discriminar entre imágenes reales y generadas demasiado rápidamente, es posible que el generador no sea capaz de entrenarse. Para equilibrar mejor el aprendizaje del discriminador y el generador, añada ruido a los datos reales volteando aleatoriamente las etiquetas asignadas a las imágenes reales.

Especifique que las etiquetas reales se volteen con una probabilidad de 0.35. Observe que esto no afecta al generador, ya que todas las imágenes generadas siguen estando correctamente etiquetadas.

flipProb = 0.35;

Muestre las imágenes de validación generadas cada 100 iteraciones.

validationFrequency = 100;

Entrenar un modelo

Para entrenar una GAN, entrene las dos redes simultáneamente para maximizar el rendimiento de ambas:

  • Entrene el generador para generar datos que "engañen" al discriminador.

  • Entrene el discriminador para distinguir entre datos reales y generados.

Para optimizar el rendimiento del generador, maximice la pérdida del discriminador cuando se proporcionen datos generados. Es decir, el objetivo del generador es generar datos que el discriminador clasifique como "real".

Para optimizar el rendimiento del discriminador, minimice la pérdida del discriminador cuando se proporcionen lotes de datos reales y generados. Es decir, el objetivo del discriminador es no ser "engañado" por el generador.

Idealmente, estas estrategias dan como resultado un generador que genera datos convincentemente realistas y un discriminador que ha aprendido representaciones de características fuertes que son representativas de los datos de entrenamiento.

Utilice minibatchqueue para procesar y gestionar los minilotes de imágenes. Para cada minilote:

  • Utilice la función de preprocesamiento de minilotes personalizada preprocessMiniBatch (definida al final de este ejemplo) para volver a escalar las imágenes en el intervalo [-1,1].

  • Descarte cualquier minilote parcial con menos observaciones que el tamaño de minilote especificado.

  • Dé formato a los datos de imagen con el formato "SSCB" (espacial, espacial, canal, lote). De forma predeterminada, el objeto minibatchqueue convierte los datos en objetos dlarray con el tipo subyacente single.

  • Entrene en una GPU, si se dispone de ella. De forma predeterminada, el objeto minibatchqueue convierte cada salida en gpuArray si hay una GPU disponible. Utilizar una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox).

augimds.MiniBatchSize = miniBatchSize;

mbq = minibatchqueue(augimds, ...
    MiniBatchSize=miniBatchSize, ...
    PartialMiniBatch="discard", ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat="SSCB");

Entrene el modelo con un bucle de entrenamiento personalizado. Pase en bucle por los datos de entrenamiento y actualice los parámetros de la red en cada iteración. Para monitorizar el progreso del entrenamiento, muestre un lote de imágenes generadas usando un arreglo de retención de valores aleatorios para introducir en el generador, así como una gráfica de las puntuaciones.

Inicialice los parámetros para la optimización de Adam.

trailingAvgG = [];
trailingAvgSqG = [];
trailingAvg = [];
trailingAvgSqD = [];

Para monitorizar el progreso del entrenamiento, muestre un lote de imágenes generadas usando un lote de retención de vectores aleatorios fijos introducidos en el generador y represente las puntuaciones de la red.

Cree un arreglo de valores aleatorios de retención.

numValidationImages = 25;
ZValidation = randn(numLatentInputs,numValidationImages,"single");

Convierta los datos a objetos dlarray y especifique el formato "CB" (canal, lote).

ZValidation = dlarray(ZValidation,"CB");

Para el entrenamiento en GPU, convierta los datos a objetos gpuArray.

if canUseGPU
    ZValidation = gpuArray(ZValidation);
end

Para realizar un seguimiento de las puntuaciones del generador y del discriminador, use un objeto trainingProgressMonitor. Calcule el número total de iteraciones para la monitorización.

numObservationsTrain = numel(imds.Files);
numIterationsPerEpoch = floor(numObservationsTrain/miniBatchSize);
numIterations = numEpochs*numIterationsPerEpoch;

Inicialice el objeto TrainingProgressMonitor. Dado que el cronómetro empieza cuando crea el objeto de monitorización, asegúrese de crear el objeto cerca del bucle de entrenamiento.

monitor = trainingProgressMonitor( ...
    Metrics=["GeneratorScore","DiscriminatorScore"], ...
    Info=["Epoch","Iteration"], ...
    XLabel="Iteration");

groupSubPlot(monitor,Score=["GeneratorScore","DiscriminatorScore"])

Entrene la GAN. Para cada época, cambie el orden del almacén de datos y pase en bucle por minilotes de datos.

Para cada minilote:

  • Se detiene si la propiedad Stop del objeto TrainingProgressMonitor es true. La propiedad Stop cambia a true cuando hace clic en el botón Stop.

  • Evalúe los gradientes de la pérdida con respecto a los parámetros que se pueden aprender, el estado del generador y las puntuaciones de la red usando dlfeval y la función modelLoss.

  • Actualice los parámetros de red con la función adamupdate.

  • Represente las puntuaciones de las dos redes.

  • Después de cada validationFrequency iteraciones, muestre un lote de imágenes generadas para una entrada de generador de retención fija.

La ejecución del entrenamiento puede tardar algún tiempo.

epoch = 0;
iteration = 0;

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Reset and shuffle datastore.
    shuffle(mbq);

    % Loop over mini-batches.
    while hasdata(mbq) && ~monitor.Stop
        iteration = iteration + 1;

        % Read mini-batch of data.
        X = next(mbq);

        % Generate latent inputs for the generator network. Convert to
        % dlarray and specify the format "CB" (channel, batch). If a GPU is
        % available, then convert latent inputs to gpuArray.
        Z = randn(numLatentInputs,miniBatchSize,"single");
        Z = dlarray(Z,"CB");

        if canUseGPU
            Z = gpuArray(Z);
        end

        % Evaluate the gradients of the loss with respect to the learnable
        % parameters, the generator state, and the network scores using
        % dlfeval and the modelLoss function.
        [~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
            dlfeval(@modelLoss,netG,netD,X,Z,flipProb);
        netG.State = stateG;

        % Update the discriminator network parameters.
        [netD,trailingAvg,trailingAvgSqD] = adamupdate(netD, gradientsD, ...
            trailingAvg, trailingAvgSqD, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Update the generator network parameters.
        [netG,trailingAvgG,trailingAvgSqG] = adamupdate(netG, gradientsG, ...
            trailingAvgG, trailingAvgSqG, iteration, ...
            learnRate, gradientDecayFactor, squaredGradientDecayFactor);

        % Every validationFrequency iterations, display batch of generated
        % images using the held-out generator input.
        if mod(iteration,validationFrequency) == 0 || iteration == 1
            % Generate images using the held-out generator input.
            XGeneratedValidation = predict(netG,ZValidation);

            % Tile and rescale the images in the range [0 1].
            I = imtile(extractdata(XGeneratedValidation));
            I = rescale(I);

            % Display the images.
            image(I)
            xticklabels([]);
            yticklabels([]);
            title("Generated Images");
        end

        % Update the training progress monitor.
        recordMetrics(monitor,iteration, ...
            GeneratorScore=scoreG, ...
            DiscriminatorScore=scoreD);

        updateInfo(monitor,Epoch=epoch,Iteration=iteration);
        monitor.Progress = 100*iteration/numIterations;
    end
end

En este caso, el discriminador ha aprendido una representación fuerte que identifica imágenes reales entre las imágenes generadas. A su vez, el generador ha aprendido una representación de características de similar fuerza que permite generar imágenes parecidas a los datos de entrenamiento.

La gráfica de entrenamiento muestra las puntuaciones de las redes del generador y el discriminador. Para obtener más información sobre cómo interpretar las puntuaciones de las redes, consulte Monitor GAN Training Progress and Identify Common Failure Modes.

Generar imágenes nuevas

Para generar imágenes nuevas, use la función predict en el generador con un objeto dlarray que contenga un lote de vectores aleatorios. Para mostrar las imágenes juntas, use la función imtile y vuelva a escalar las imágenes con la función rescale.

Cree un objeto dlarray que contenga un lote de 25 vectores aleatorios para introducir a la red del generador.

numObservations = 25;
ZNew = randn(numLatentInputs,numObservations,"single");
ZNew = dlarray(ZNew,"CB");

Si hay una GPU disponible, convierta los vectores latentes a gpuArray.

if canUseGPU
    ZNew = gpuArray(ZNew);
end

Genere imágenes nuevas usando la función predict con el generador y los datos de entrada.

XGeneratedNew = predict(netG,ZNew);

Muestre las imágenes.

I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
figure
image(I)
axis off
title("Generated Images")

Función de pérdida del modelo

La función modelLoss toma como entrada los objetos dlnetwork del generador y el discriminador (netG y netD), un minilote de datos de entrada X, un arreglo de valores aleatorios Z y la probabilidad de voltear las etiquetas reales flipProb, y devuelve los valores de pérdida, los gradientes de los valores de pérdida con respecto a los parámetros que se pueden aprender en las redes, el estado del generador y las puntuaciones de las dos redes.

function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    modelLoss(netG,netD,X,Z,flipProb)

% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X);

% Calculate the predictions for generated data with the discriminator
% network.
[XGenerated,stateG] = forward(netG,Z);
YGenerated = forward(netD,XGenerated);

% Calculate the score of the discriminator.
scoreD = (mean(YReal) + mean(1-YGenerated)) / 2;

% Calculate the score of the generator.
scoreG = mean(YGenerated);

% Randomly flip the labels of the real images.
numObservations = size(YReal,4);
idx = rand(1,numObservations) < flipProb;
YReal(:,:,:,idx) = 1 - YReal(:,:,:,idx);

% Calculate the GAN loss.
[lossG, lossD] = ganLoss(YReal,YGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);

end

Función de pérdida GAN y puntuaciones

El objetivo del generador es generar datos que el discriminador clasifique como "real". Para maximizar la probabilidad de que las imágenes del generador sean clasificadas como reales por el discriminador, minimice la función de verosimilitud logarítmica negativa.

Dada la salida Y del discriminador:

  • Y es la probabilidad de que la imagen de entrada pertenezca a la clase "real".

  • 1-Y es la probabilidad de que la imagen de entrada pertenezca a la clase "generated".

La función de pérdida para el generador viene dada por

lossGenerator=-mean(log(YGenerated)),

donde YGenerated contiene las probabilidades de salida del discriminador para las imágenes generadas.

El objetivo del discriminador es no ser "engañado" por el generador. Para maximizar la probabilidad de que el discriminador discrimine correctamente entre las imágenes reales y generadas, minimice la suma de las correspondientes funciones de verosimilitud logarítmica negativa.

La función de pérdida para el discriminador viene dada por

lossDiscriminator=-mean(log(YReal))-mean(log(1-YGenerated)),

donde YReal contiene las probabilidades de salida del discriminador para las imágenes reales.

Para medir, en una escala de 0 a 1, el grado de éxito de los objetivos del generador y el discriminador, puede utilizar el concepto de puntuación.

La puntuación del generador es el promedio de probabilidades correspondiente a la salida del discriminador para las imágenes generadas:

scoreGenerator=mean(YGenerated).

La puntuación del discriminador es el promedio de probabilidades correspondiente a la salida del discriminador para las imágenes reales y generadas:

scoreDiscriminator=12mean(YReal)+12mean(1-YGenerated).

La puntuación es inversamente proporcional a la pérdida, pero contiene, de forma efectiva, la misma información.

function [lossG,lossD] = ganLoss(YReal,YGenerated)

% Calculate the loss for the discriminator network.
lossD = -mean(log(YReal)) - mean(log(1-YGenerated));

% Calculate the loss for the generator network.
lossG = -mean(log(YGenerated));

end

Función de preprocesamiento de minilotes

La función preprocessMiniBatch preprocesa los datos dando los siguientes pasos:

  1. Extraer los datos de imagen del arreglo de celdas de entrada y concatenarlos en un arreglo numérico.

  2. Volver a escalar las imágenes para que estén en el intervalo [-1,1].

function X = preprocessMiniBatch(data)

% Concatenate mini-batch
X = cat(4,data{:});

% Rescale the images in the range [-1 1].
X = rescale(X,-1,1,InputMin=0,InputMax=255);

end

Referencias

  1. The TensorFlow Team. Flowers http://download.tensorflow.org/example_images/flower_photos.tgz

  2. Radford, Alec, Luke Metz y Soumith Chintala. "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks". Preimpresión, presentado el 19 de noviembre de 2015. http://arxiv.org/abs/1511.06434.

Consulte también

| | | | | | |

Temas