Генеративно-состязательные сети (GAN) — захватывающая область исследований, которая произвела революцию в машинном обучении. Благодаря своей способности создавать новые данные, изучая существующие данные, GAN добились впечатляющих результатов в различных областях, таких как синтез изображений и речи, создание музыки и даже открытие лекарств. В этом среднем посте я покажу вам, как создать GAN с нуля, используя Python и PyTorch.

Для начала мы импортируем необходимые модули и создадим гиперпараметры.

# Import the required modules
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# setup of the main parameters and hyperparameters
epochs         = 500
cur_step       = 0
info_step      = 300
mean_gen_loss  = 0
mean_disc_loss = 0
z_dimension    = 128
lr             = 0.00001
loss_func      = nn.BCEWithLogitsLoss()
batch_size     = 128
device         = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataloader = DataLoader(datasets.MNIST('.', download=True, transform=transforms.ToTensor()),shuffle=True, batch_size=batch_size)

Далее мы создадим сети генератора и дискриминатора. Сеть генератора принимает случайный шум в качестве входных данных и генерирует новую выборку, которая имитирует распределение исходного набора данных. Сеть дискриминатора принимает как сгенерированные данные, так и реальные данные и пытается различать их. Начнем с генератора.

# Generator network layers
def generator_Layer(inp, out):
    return nn.Sequential(
      nn.Linear(inp, out),
      nn.BatchNorm1d(out),
      nn.ReLU(inplace=True)
    )
# Generator network
class Generator(nn.Module):
    def __init__(self, z_dim=64, output_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        self.generator = nn.Sequential(
            generator_Layer(z_dim,        hidden_dim),
            generator_Layer(hidden_dim,   hidden_dim*2),
            generator_Layer(hidden_dim*2, hidden_dim*4), 
            generator_Layer(hidden_dim*4, hidden_dim*8), 
            nn.Linear(hidden_dim*8, output_dim), 
            nn.Sigmoid()
        )
        
    def forward(self, noise):
        return self.generator(noise)

Сеть генератора отвечает за создание новых данных, которые имитируют распределение исходного набора данных. Функция generator_Layer создает последовательный слой, состоящий из линейного преобразования, пакетной нормализации и функции активации ReLU. Затем сеть генератора определяется как последовательность этих слоев, где вход представляет собой случайный шум размерности z_dim, а выход представляет собой сгенерированные данные размерности output_dim.

Функция forward принимает случайный шум и пропускает его через сеть генератора для создания сгенерированных данных. Наконец, к выходным данным применяется сигмовидная функция активации для нормализации значений от 0 до 1, что является общим для данных изображения.

Теперь давайте посмотрим на код дискриминатора.

def discriminator_layer(inp, out):
    return nn.Sequential(
      nn.Linear(inp, out),
      nn.LeakyReLU(0.2)
    )

# Discriminator network
class Discriminator(nn.Module):
    
    def __init__(self, input_dim, hidden_dim=256):
        super(Discriminator, self).__init__()
        self.discriminator = nn.Sequential(
            discriminator_layer(input_dim,    hidden_dim*4),
            discriminator_layer(hidden_dim*4, hidden_dim*2), 
            discriminator_layer(hidden_dim*2, hidden_dim), 
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.discriminator(img)

Сеть дискриминатора отвечает за различение сгенерированных данных и реальных данных. Функция discriminator_layer создает последовательный слой, состоящий из линейного преобразования и функции активации LeakyReLU. Затем сеть дискриминатора определяется как последовательность этих слоев, где входные данные представляют собой данные размерности input_dim, а выходные данные представляют собой скалярное значение, указывающее, являются ли входные данные реальными или сгенерированными.

Функция forward принимает данные и передает их через сеть дискриминатора для получения выходного скалярного значения. Наконец, к выходным данным применяется сигмовидная функция активации для нормализации значений от 0 до 1, что представляет вероятность того, что входные данные реальны.

Чтобы начать обучение нашей GAN, нам нужно создать экземпляры сетей генератора и дискриминатора. Мы также определим функцию потерь, оптимизатор и количество эпох для обучения. Для нашей функции потерь мы будем использовать двоичную кросс-энтропию с потерей логитов (BCEWithLogitsLoss), которая представляет собой единый класс, сочетающий сигмовидный слой и BCELoss.

Мы установим скорость обучения на 0,0002 для оптимизатора Адама. Мы будем обучать GAN в течение 200 эпох, чтобы генератор научился создавать образцы, очень похожие на исходный набор данных.

# Instantiate generator and discriminator
G = G.to(device)
D = D.to(device)

# Define loss function, optimizer, and epochs
optimizer_G = optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = optim.Adam(D.parameters(), lr=0.0002)
criterion = nn.BCEWithLogitsLoss()
num_epochs = 200

Обучение GAN — сложная задача, которая включает в себя обучение двух нейронных сетей в минимаксной игре. Сеть генератора пытается создать поддельные образцы, которые могут обмануть сеть дискриминатора, в то время как сеть дискриминатора пытается отличить настоящие образцы от поддельных, созданных генератором.

# Train GAN
for epoch in range(num_epochs):
    for batch_idx, (real_images, _) in enumerate(dataloader):
        """
        Discriminator
        """
        optimizer_D.zero_grad()
        
        noise = torch.randn((real_images.shape[0], z_dimension))
        
        real_images = real_images.view(len(real_images), -1)
        fake_images = G(noise)
        
        disc_real = D(real_images)
        disc_fake = D(fake_images.detach())
        
        
        disc_real_targets = torch.ones((real_images.shape[0], 1))
        disc_fake_targets = torch.zeros((real_images.shape[0], 1))
        
        
        loss_D_real = loss_func(disc_real, disc_real_targets)
        loss_D_fake = loss_func(disc_fake, disc_fake_targets)
        
        loss_D = (loss_D_fake+loss_D_real)/2
        loss_D.backward(retain_graph=True)
        optimizer_D.step()
        
        """
        Generator
        """ 
        optimizer_G.zero_grad()
        noise = torch.randn((real_images.shape[0], z_dimension))
        fake = G(noise)
        pred = D(fake)
        targets = torch.ones_like(pred)
#         print(targets.shape)
        loss_G = loss_func(pred, targets)
        
        loss_G.backward(retain_graph=True)
        optimizer_G.step()
        
        if batch_idx % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} Loss D: {loss_D:.4f}, Loss G: {loss_G:.4f}")
          

В приведенном выше коде мы начинаем с перебора количества эпох, определенных num_epochs, и внутри каждой эпохи мы перебираем пакеты в загрузчике данных. Сначала мы обучаем сеть дискриминатора, передавая ей настоящие изображения из загрузчика данных и поддельные изображения, сгенерированные сетью-генератором. Мы рассчитываем потери для сети дискриминатора, используя бинарную кросс-энтропию с логитами потерь BCEWithLogisticLoss(). Затем мы оптимизируем параметры сети дискриминатора с помощью оптимизатора Адама со скоростью обучения 0,0002.

Затем мы обучаем сеть генератора, генерируя поддельные изображения, используя случайный шум в качестве входных данных, и передавая их в сеть дискриминатора. Сеть генератора пытается создать поддельные изображения, которые могут обмануть сеть дискриминатора, и мы вычисляем потери для сети генератора, используя ту же функцию BCEWithLogisticLoss(). Затем мы оптимизируем параметры сети генератора с помощью оптимизатора Адама с той же скоростью обучения.

После каждой партии мы распечатываем потери для обеих сетей и отображаем часть поддельных и реальных изображений с помощью функции show(). Делая это, мы можем наблюдать за прогрессом сети генератора с течением времени и видеть, насколько хорошо она может генерировать изображения, напоминающие исходный набор данных.

По результатам нашей модели GAN мы можем наблюдать, что производительность генератора улучшается по мере обучения. В начале процесса обучения генератор производит в основном шум, что приводит к некачественным изображениям. Однако по мере прохождения эпох мы можем наблюдать постепенное улучшение качества генерируемых цифр. На рисунке 1 левое изображение представляет выходные данные генератора во время первой итерации модели, а правое изображение показывает выходные данные генератора в эпоху 139. Здесь мы можем видеть хорошо нарисованные цифры, которые очень похожи на предполагаемые числа. .

Еще одно интересное наблюдение из результатов заключается в том, что производительность генератора также улучшается по мере того, как генератор создает более качественные изображения. Об этом свидетельствует возрастающая потеря дискриминатора, указывающая на то, что дискриминатору становится все труднее различать реальные и синтетические изображения.

Обучение GAN требует большого терпения и экспериментов с различными гиперпараметрами. Эта захватывающая область имеет практическое применение в создании реалистичных изображений, аудио и текста. Я надеюсь, что это руководство помогло вам разобраться с основами GAN, и теперь вы готовы экспериментировать и создавать свои собственные модели GAN для исследования и развлечения.