Идея GAN заключается в том, что у вас есть две сети, генератор G и дискриминатор D, конкурирующие друг с другом. Генератор передает «фальшивые» данные на дискриминатор. Дискриминатор также видит реальные обучающие данные и предсказывает, являются ли полученные данные настоящими или фальшивыми.

Генератор обучен обмануть дискриминатор, он хочет выводить данные, которые выглядят как можно ближе к реальным обучающим данным.

Дискриминатор - это классификатор, который обучен определять, какие данные настоящие, а какие - поддельные.

В конечном итоге происходит то, что генератор учится создавать данные, неотличимые от реальных данных для дискриминатора.

Общая структура GAN показана на диаграмме выше с использованием изображений MNIST в качестве данных. Скрытый образец - это случайный вектор, который генератор использует для создания своих поддельных изображений. Его часто называют скрытым вектором, а это векторное пространство - скрытым пространством. По мере обучения генератор выясняет, как сопоставить скрытые векторы с распознаваемыми изображениями, которые могут обмануть дискриминатор.

Если вас интересует генерация только новых изображений, вы можете выбросить дискриминатор после обучения. В этой записной книжке я покажу вам, как определять и обучать эти противоборствующие сети в PyTorch и генерировать новые изображения!

визуализировать данные:

Определите модель

GAN состоит из двух противоборствующих сетей, дискриминатора и генератора.

Дискриминатор

Дискриминаторная сеть будет довольно типичным линейным классификатором. Чтобы сделать эту сеть универсальным аппроксиматором функций, нам понадобится хотя бы один скрытый слой, и эти скрытые слои должны иметь один ключевой атрибут:

Все скрытые слои будут иметь функцию активации Leaky ReLu, примененную к их выходам.

Дырявый ReLu

Мы должны использовать дырявый ReLU, чтобы градиенты беспрепятственно текли обратно через слой. Негерметичный ReLU похож на обычный ReLU, за исключением того, что для отрицательных входных значений есть небольшой ненулевой выход.

Сигмовидный выход

Мы также воспользуемся подходом к использованию более устойчивой численно функции потерь на выходах. Напомним, что мы хотим, чтобы дискриминатор выдавал значение 0–1, указывающее, является ли изображение реальным или поддельным.

В конечном итоге мы будем использовать BCEWithLogitsLoss, который объединяет sigmoid функцию активации и и двоичную перекрестную потерю энтропии в одной функции.

Итак, к нашему окончательному выходному слою не должно быть применено никаких функций активации.

Генератор

Сеть генератора будет почти такой же, как и сеть дискриминатора, за исключением того, что мы применяем функцию активации tanh к нашему выходному слою.

tanh Выход

Было обнаружено, что генератор лучше всего работает с tanh для выхода генератора, который масштабирует выход в диапазоне от -1 до 1 вместо 0 и 1.

Напомним, что мы также хотим, чтобы эти выходные данные были сопоставимы с реальными входными значениями пикселей, которые считываются как нормализованные значения от 0 до 1.

Таким образом, при обучении дискриминатора нам также потребуется масштабировать наши реальные входные изображения, чтобы иметь значения пикселей от -1 до 1.

Я сделаю это позже в цикле обучения.

Потери дискриминатора и генератора

Теперь нам нужно посчитать убытки.

Потери дискриминатора

Для дискриминатора общие потери - это сумма потерь для реальных и поддельных изображений d_loss = d_real_loss + d_fake_loss.

Помните, что мы хотим, чтобы дискриминатор выдавал 1 для реальных изображений и 0 для поддельных изображений, поэтому нам нужно настроить потери, чтобы отразить это.

Потери будут за счет потери двоичной кросс-энтропии с логитами, которую мы можем получить с помощью BCEWithLogitsLoss. Это объединяет sigmoid функцию активации и и двоичную перекрестную потерю энтропии в одной функции.

Для реальных изображений нам нужно D(real_images) = 1. То есть мы хотим, чтобы дискриминатор классифицировал реальные изображения с меткой = 1, указывающей, что они настоящие. Чтобы дискриминатор лучше обобщал, метки немного уменьшены с 1,0 до 0,9. Для этого воспользуемся параметром smooth; если True, то мы должны сгладить наши метки. В PyTorch это выглядит как labels = torch.ones(size) * 0.9

Потери дискриминатора для поддельных данных аналогичны. Нам нужен D(fake_images) = 0, где поддельные изображения являются выводом генератора, fake_images = G(z).

Потери в генераторе

Потери генератора будут похожи только с перевернутыми этикетками. Цель генератора - получить D(fake_images) = 1. В этом случае метки перевернуты, чтобы показать, что генератор пытается обмануть дискриминатор, заставив его думать, что генерируемые им изображения (фальшивки) настоящие!

Обучение

Обучение будет включать в себя попеременное обучение дискриминатора и генератора. Мы будем использовать наши функции real_loss и fake_loss, чтобы помочь нам рассчитать потери дискриминатора во всех следующих случаях.

Обучение дискриминатору

  1. Вычислите потерю дискриминатора на реальных обучающих изображениях
  2. Создавайте поддельные изображения
  3. Вычислить потерю дискриминатора на поддельных, сгенерированных изображениях
  4. Сложите реальную и фальшивую потерю
  5. Выполнение обратного распространения ошибки + шаг оптимизации для обновления весов дискриминатора

Генератор обучение

  1. Создавайте поддельные изображения
  2. Вычислите потерю дискриминатора на поддельных изображениях, используя перевернутые метки!
  3. Выполнение обратного распространения ошибки + шаг оптимизации для обновления весов генератора

Сохранение образцов

Во время обучения мы также распечатаем некоторую статистику потерь и сохраним сгенерированные «поддельные» образцы.

Образцы генератора из обучения

Здесь мы можем просмотреть образцы изображений из генератора. Сначала посмотрим на изображения, которые мы сохранили во время обучения.

Это образцы из последней тренировочной эпохи. Вы можете видеть, что генератор может воспроизводить числа, такие как 1, 7, 3, 2. Поскольку это всего лишь образец, он не является репрезентативным для всего диапазона изображений, которые может создать этот генератор.

Надеюсь, вы сочтете эту статью полезной, чтобы запачкать руки с gan.