Генеративно-состязательные сети (GAN) — захватывающая область исследований, которая произвела революцию в машинном обучении. Благодаря своей способности создавать новые данные, изучая существующие данные, GAN добились впечатляющих результатов в различных областях, таких как синтез изображений и речи, создание музыки и даже открытие лекарств. В этом среднем посте я покажу вам, как создать GAN с нуля, используя Python и PyTorch.
Для начала мы импортируем необходимые модули и создадим гиперпараметры.
# Import the required modules import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader import torchvision.datasets as datasets from torchvision import transforms from torchvision.utils import make_grid import matplotlib.pyplot as plt # setup of the main parameters and hyperparameters epochs = 500 cur_step = 0 info_step = 300 mean_gen_loss = 0 mean_disc_loss = 0 z_dimension = 128 lr = 0.00001 loss_func = nn.BCEWithLogitsLoss() batch_size = 128 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') dataloader = DataLoader(datasets.MNIST('.', download=True, transform=transforms.ToTensor()),shuffle=True, batch_size=batch_size)
Далее мы создадим сети генератора и дискриминатора. Сеть генератора принимает случайный шум в качестве входных данных и генерирует новую выборку, которая имитирует распределение исходного набора данных. Сеть дискриминатора принимает как сгенерированные данные, так и реальные данные и пытается различать их. Начнем с генератора.
# Generator network layers def generator_Layer(inp, out): return nn.Sequential( nn.Linear(inp, out), nn.BatchNorm1d(out), nn.ReLU(inplace=True) ) # Generator network class Generator(nn.Module): def __init__(self, z_dim=64, output_dim=784, hidden_dim=128): super(Generator, self).__init__() self.generator = nn.Sequential( generator_Layer(z_dim, hidden_dim), generator_Layer(hidden_dim, hidden_dim*2), generator_Layer(hidden_dim*2, hidden_dim*4), generator_Layer(hidden_dim*4, hidden_dim*8), nn.Linear(hidden_dim*8, output_dim), nn.Sigmoid() ) def forward(self, noise): return self.generator(noise)
Сеть генератора отвечает за создание новых данных, которые имитируют распределение исходного набора данных. Функция generator_Layer
создает последовательный слой, состоящий из линейного преобразования, пакетной нормализации и функции активации ReLU. Затем сеть генератора определяется как последовательность этих слоев, где вход представляет собой случайный шум размерности z_dim
, а выход представляет собой сгенерированные данные размерности output_dim
.
Функция forward
принимает случайный шум и пропускает его через сеть генератора для создания сгенерированных данных. Наконец, к выходным данным применяется сигмовидная функция активации для нормализации значений от 0 до 1, что является общим для данных изображения.
Теперь давайте посмотрим на код дискриминатора.
def discriminator_layer(inp, out): return nn.Sequential( nn.Linear(inp, out), nn.LeakyReLU(0.2) ) # Discriminator network class Discriminator(nn.Module): def __init__(self, input_dim, hidden_dim=256): super(Discriminator, self).__init__() self.discriminator = nn.Sequential( discriminator_layer(input_dim, hidden_dim*4), discriminator_layer(hidden_dim*4, hidden_dim*2), discriminator_layer(hidden_dim*2, hidden_dim), nn.Linear(hidden_dim, 1), nn.Sigmoid() ) def forward(self, img): return self.discriminator(img)
Сеть дискриминатора отвечает за различение сгенерированных данных и реальных данных. Функция discriminator_layer
создает последовательный слой, состоящий из линейного преобразования и функции активации LeakyReLU. Затем сеть дискриминатора определяется как последовательность этих слоев, где входные данные представляют собой данные размерности input_dim
, а выходные данные представляют собой скалярное значение, указывающее, являются ли входные данные реальными или сгенерированными.
Функция forward
принимает данные и передает их через сеть дискриминатора для получения выходного скалярного значения. Наконец, к выходным данным применяется сигмовидная функция активации для нормализации значений от 0 до 1, что представляет вероятность того, что входные данные реальны.
Чтобы начать обучение нашей GAN, нам нужно создать экземпляры сетей генератора и дискриминатора. Мы также определим функцию потерь, оптимизатор и количество эпох для обучения. Для нашей функции потерь мы будем использовать двоичную кросс-энтропию с потерей логитов (BCEWithLogitsLoss), которая представляет собой единый класс, сочетающий сигмовидный слой и BCELoss.
Мы установим скорость обучения на 0,0002 для оптимизатора Адама. Мы будем обучать GAN в течение 200 эпох, чтобы генератор научился создавать образцы, очень похожие на исходный набор данных.
# Instantiate generator and discriminator G = G.to(device) D = D.to(device) # Define loss function, optimizer, and epochs optimizer_G = optim.Adam(G.parameters(), lr=0.0002) optimizer_D = optim.Adam(D.parameters(), lr=0.0002) criterion = nn.BCEWithLogitsLoss() num_epochs = 200
Обучение GAN — сложная задача, которая включает в себя обучение двух нейронных сетей в минимаксной игре. Сеть генератора пытается создать поддельные образцы, которые могут обмануть сеть дискриминатора, в то время как сеть дискриминатора пытается отличить настоящие образцы от поддельных, созданных генератором.
# Train GAN for epoch in range(num_epochs): for batch_idx, (real_images, _) in enumerate(dataloader): """ Discriminator """ optimizer_D.zero_grad() noise = torch.randn((real_images.shape[0], z_dimension)) real_images = real_images.view(len(real_images), -1) fake_images = G(noise) disc_real = D(real_images) disc_fake = D(fake_images.detach()) disc_real_targets = torch.ones((real_images.shape[0], 1)) disc_fake_targets = torch.zeros((real_images.shape[0], 1)) loss_D_real = loss_func(disc_real, disc_real_targets) loss_D_fake = loss_func(disc_fake, disc_fake_targets) loss_D = (loss_D_fake+loss_D_real)/2 loss_D.backward(retain_graph=True) optimizer_D.step() """ Generator """ optimizer_G.zero_grad() noise = torch.randn((real_images.shape[0], z_dimension)) fake = G(noise) pred = D(fake) targets = torch.ones_like(pred) # print(targets.shape) loss_G = loss_func(pred, targets) loss_G.backward(retain_graph=True) optimizer_G.step() if batch_idx % 100 == 0: print(f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(dataloader)} Loss D: {loss_D:.4f}, Loss G: {loss_G:.4f}")
В приведенном выше коде мы начинаем с перебора количества эпох, определенных num_epochs
, и внутри каждой эпохи мы перебираем пакеты в загрузчике данных. Сначала мы обучаем сеть дискриминатора, передавая ей настоящие изображения из загрузчика данных и поддельные изображения, сгенерированные сетью-генератором. Мы рассчитываем потери для сети дискриминатора, используя бинарную кросс-энтропию с логитами потерь BCEWithLogisticLoss()
. Затем мы оптимизируем параметры сети дискриминатора с помощью оптимизатора Адама со скоростью обучения 0,0002.
Затем мы обучаем сеть генератора, генерируя поддельные изображения, используя случайный шум в качестве входных данных, и передавая их в сеть дискриминатора. Сеть генератора пытается создать поддельные изображения, которые могут обмануть сеть дискриминатора, и мы вычисляем потери для сети генератора, используя ту же функцию BCEWithLogisticLoss()
. Затем мы оптимизируем параметры сети генератора с помощью оптимизатора Адама с той же скоростью обучения.
После каждой партии мы распечатываем потери для обеих сетей и отображаем часть поддельных и реальных изображений с помощью функции show()
. Делая это, мы можем наблюдать за прогрессом сети генератора с течением времени и видеть, насколько хорошо она может генерировать изображения, напоминающие исходный набор данных.
По результатам нашей модели GAN мы можем наблюдать, что производительность генератора улучшается по мере обучения. В начале процесса обучения генератор производит в основном шум, что приводит к некачественным изображениям. Однако по мере прохождения эпох мы можем наблюдать постепенное улучшение качества генерируемых цифр. На рисунке 1 левое изображение представляет выходные данные генератора во время первой итерации модели, а правое изображение показывает выходные данные генератора в эпоху 139. Здесь мы можем видеть хорошо нарисованные цифры, которые очень похожи на предполагаемые числа. .
Еще одно интересное наблюдение из результатов заключается в том, что производительность генератора также улучшается по мере того, как генератор создает более качественные изображения. Об этом свидетельствует возрастающая потеря дискриминатора, указывающая на то, что дискриминатору становится все труднее различать реальные и синтетические изображения.
Обучение GAN требует большого терпения и экспериментов с различными гиперпараметрами. Эта захватывающая область имеет практическое применение в создании реалистичных изображений, аудио и текста. Я надеюсь, что это руководство помогло вам разобраться с основами GAN, и теперь вы готовы экспериментировать и создавать свои собственные модели GAN для исследования и развлечения.