В этом блоге мы создадим генеративно-состязательную сеть (GAN), обученную на наборе данных MNIST (рукописные цифры). Из этого мы сможем генерировать новые рукописные цифры. GAN были представлены Яном Гудфеллоу и др. в 2014 году. С тех пор популярность GAN резко возросла. Этот метод может генерировать фотографии, которые, по крайней мере, внешне выглядят аутентичными для человека-наблюдателя, обладая многими реалистичными характеристиками.
В машинном обучении GAN — это класс алгоритмов искусственного интеллекта, используемых в неконтролируемом машинном обучении. Идея GAN заключается в том, что у вас есть две сети, generator G
и discriminator D
, которые соревнуются друг с другом в рамках игры с нулевой суммой.
Генератор передает поддельные данные дискриминатору, технически генерирующая сеть учится сопоставлять latent space z
с конкретным _4. Дискриминатор также видит реальные данные и предсказывает, являются ли полученные данные real
или fake
, различая экземпляры из real data distribution
и candidates produced by the generator
. .
Генератор обучен обманывать дискриминатор, и он хочет выводить данные, максимально приближенные к реальным данным, путем создания новых синтезированных экземпляров, которые, по-видимому, получены из реального распределения данных. Более того, дискриминатор обучен определять, какие данные настоящие, а какие поддельные. В конечном итоге генератор учится создавать данные, которые неотличимы от реальных данных для дискриминатора.Backpropagation
применяется в обеих сетях, чтобы генератор создавал более качественные изображения, в то время как дискриминатор становится более опытным в пометке синтетических изображений. Генератор обычно представляет собой деконволюционную нейронную сеть, а дискриминатор — сверточную нейронную сеть.
Импорт данных и библиотек
Входные данные модели
Во-первых, нам нужно создать входные данные для нашего графика TensorFlow. Нам нужны два входа: один для генератора и один для дискриминатора. Здесь мы назовем вход генератора inputs_z
и вход дискриминатора inputs_real
.
Генератор
Входными данными для генератора является ряд случайно сгенерированных чисел, называемых latent space
(Скрытые переменные — это переменные, которые не наблюдаются напрямую, а скорее выводятся с помощью математической модели из других переменных, которые наблюдаются).
Генератор представляет собой нейронную сеть, содержащую скрытый слой с Leaky ReLU
активацией и tanh
выходом. Он пытается сопоставить скрытое пространство с реальными изображениями набора данных, используя алгоритм обратного распространения. После обучения генератор может создавать цифровые изображения из скрытых образцов.
Дискриминатор
Дискриминатор — это классификатор, обученный с помощью обучения с учителем. Он классифицирует, является ли изображение real (1)
или Fake (0)
. Мы обучаем дискриминатор, используя как реальные изображения набора данных, так и изображения, сгенерированные генератором.
Если входное изображение взято из набора данных MNIST, дискриминатор должен классифицировать его как real
. Если входное изображение исходит от генератора, дискриминатор должен классифицировать его как fake
. Сеть дискриминатора почти такая же, как сеть генератора, за исключением того, что мы используем выходной слой sigmoid
.
Гиперпараметры
Определить сеть
Чтобы построить сеть из функций, определенных выше, мы подключаем генератор и дискриминатор для создания GAN.
- Во-первых, нужно получить наши входные данные,
input_real, input_z
изmodel_inputs
, используя размеры входных данных и z. - Затем мы создадим генератор. Это создает генератор с соответствующими размерами
input
иoutput
. - Потом дискриминаторы. Мы построим два из них, один для
real data
и один дляfake data
.
Мы установим reuse=True
, так как мы хотим, чтобы веса были одинаковы для данных real
и fake
, нам нужно повторно использовать переменные.
Определить потери
Теперь нам нужно рассчитать потери как для генератора, так и для дискриминатора, что немного сложно.
В дискриминаторе общие потери представляют собой сумму потерь для настоящих и поддельных изображений.
d_loss = d_loss_real + d_loss_fake
- Для логитов реального изображения,
d_logits_real
иlabels
, мы хотим, чтобы они были все, поскольку все это реальные изображения. Чтобы помочь дискриминатору лучше обобщать, метки немного уменьшены с 1,0 до 0,9. - Для поддельных данных аналогичны логиты
d_logits_fake
, эти поддельные логиты используются сlabels
всегоzeros
.
Наконец, генератор потерь использует d_logits_fake
логиты поддельных изображений. Но теперь они labels
все ones
. Генератор пытается обмануть дискриминатор, поэтому он хочет, чтобы дискриминатор выдавал единицы для поддельных изображений.
Оптимизаторы
Мы собираемся создать два оптимизатора, один для генератора и один для дискриминатора. Чтобы обновлять переменные генератора и дискриминатора по отдельности, нам нужен список переменных, характерных для оптимизатора. Чтобы получить все обучаемые переменные, мы используем tf.trainable_variables()
.
Мы использовали область видимости переменных, чтобы все имена переменных нашего генератора начинались с generator
, а все переменные в дискриминаторе начинались с discriminator
. Теперь нам просто нужно пройтись по списку от tf.trainable_variables()
и сохранить переменные, чтобы они начинались с generator
в g_vars
и discriminator
в d_vars
.
Сетевое обучение
Визуализация потерь при обучении
На тренировке мы составили список потерь для generator
и discriminatort t
, чтобы проверить, насколько хорошо обучен наш ГАН. Теперь мы можем визуализировать с помощью train_losses.pkl
.
Создать новый образец
Чтобы сгенерировать новый образец, мы идем load
к сохраненному модулю, initialize
к нашему сеансу и pass random noise
к генератору.
Вывод
Я надеюсь, что в этом блоге вы поняли базовую архитектуру нового метода, называемого генеративно-состязательными сетями. GAN — один из немногих успешных и эффективных методов машинного обучения без учителя. GAN успешно применяются во многих областях, включая генератор музыки, интерактивное редактирование изображений, оценку трехмерных форм, поиск лекарств, создание книг, которые кажутся написанными подлинными авторами, и многое другое. Некоторые из историй успеха включают в себя.
- Adobe Research использует GAN для разработки продуктов, создавая новые изображения с нуля на основе пользовательских каракулей.
- Facebook создал модель передачи в реальном времени, работающую на мобильных устройствах.
Исходный код: https://github.com/llabhishekll/GAN-implementation