«Самая интересная идея за последние 10 лет в ML»

[GANs], и предлагаемые сейчас варианты - это, на мой взгляд, самая интересная идея в области машинного обучения за последние 10 лет. -Янн ЛеКун

Генеративные состязательные сети, или сокращенно GAN, являются одними из самых мощных инструментов, которые могут использовать практики машинного обучения.

GAN способны делать все: от изображений сверхвысокого разрешения до перевода изображений, манипуляций с лицом, создания лекарств и многого другого.

Содержание

В этой статье мы рассмотрим как интуицию, лежащую в основе работы GAN, так и то, как мы можем реализовать их в Keras. В частности, мы сначала реализуем полностью подключенную GAN (FCGAN) для MNIST, а затем превратим ее в глубокую сверточную GAN (DCGAN) для класса CIFAR- 10.

Не стесняйтесь переходить к коду, если у вас уже есть понимание интуиции, лежащей в основе GAN.

Завершенный код, который мы будем создавать в этом руководстве, доступен на моем GitHub здесь.

Как работают GAN

GAN работает, сражаясь между двумя нейронными сетями, генератором и дискриминатором, в попытке узнать распределение вероятностей набора данных.

Генератор, часто обозначаемый буквой G, пытается создавать реалистичные изображения. Для этого он берет некоторый вектор шума z и применяет к нему ряд вычислений; эти вычисления обычно выполняются в форме нейронной сети. В результате получается изображение G (z), которое является попыткой генератора обмануть дискриминатор.

С другой стороны, дискриминатор D пытается классифицировать реальные и поддельные изображения. Изображения считаются «поддельными», если их создает генератор, и «настоящими», если они были выбраны из набора данных. x представляет входное изображение, а D (x) представляет вероятность того, что дискриминатор считает, что x реально.

По мере того, как дискриминатор улучшает классификацию реальных и поддельных изображений, он заставляет генератор улучшаться при создании изображений.

Цель состоит в том, чтобы генератор стал настолько искусным в создании изображений, чтобы сгенерированные изображения были неотличимы от реальности, а это означает, что даже идеальный дискриминатор никогда не будет уверен в достоверности изображений, то есть D (x) = 0,5.

Если вам нужно гораздо более подробное объяснение математики, лежащей в основе GAN, я бы рекомендовал прочитать исходную статью Гудфеллоу и др..

Применение на практике

Обучить GAN намного сложнее, чем понять, как это работает. Хотя я буду проходить через код Keras для создания простой GAN, я рекомендую примерно следовать тому, что я делаю, вместо того, чтобы копировать его дословно. Поиск правильной архитектуры для сетей GAN - сложная задача, и лучший способ получить представление о построении моделей - это метод проб и ошибок.

Настройка всего

Начнем с определения нескольких гиперпараметров:

  • Мы используем np.random.seed для получения стабильных результатов.
  • noise_dim определяет длину вектора случайного шума, z.
  • Мы выполняем 3750 шагов за эпоху как 60000, количество примеров в MNIST, разделенное на 16, размер пакета составляет 3750.
  • save_path указывает, где мы будем тренировать снимки изображений, сгенерированных во время обучения. Мы не будем использовать это позже.

После установки гиперпараметров мы можем загрузить набор данных по нашему выбору.

Выше мы также создаем каталог для сохранения изображений, сгенерированных во время обучения, если он еще не существует.

Когда мы реализуем FCGAN, мы будем загружать mnist в плоскую форму; однако мы заменим это при внедрении DCGAN.

Если бы мы создавали, скажем, лица, мы бы заменили этот раздел кодом для загрузки этих лиц в тот же формат.

Создание генератора

Теперь пора определить функцию для создания генератора. Это самая творческая часть модели, и я настоятельно рекомендую вам изобретать свои собственные архитектуры.

Этот генератор принимает на входе размер noise_dim и через несколько наборов Dense слоев с последующими LeakyReLU активациями выводит плоскую версию изображения.

Мы должны использовать tanh активацию для последнего слоя, чтобы оставаться согласованным с нашей предыдущей нормализацией изображений и для лучших результатов.

Оптимизатор, который мы используем, также очень важен. Если бы скорость обучения была слишком высокой, модель перешла бы в режим свертывания, в котором она больше не могла бы улучшаться, и создала бы изображения мусора .

Создание дискриминатора

Далее дискриминатор. Опять же, я рекомендую поиграть с архитектурой и посмотреть, какие результаты вы получите.

В этом дискриминаторе нет ничего особенного. Он принимает сплющенное изображение, пропускает его через несколько наборов слоев Dense и LeakyReLU, затем выдает единственную вероятность D (x), сжатую до значений от 0 до 1 с активацией сигмоида.

Создание GAN

Имея дискриминатор и генератор, мы можем создать комбинированную модель GAN.

Во-первых, мы можем инициализировать дискриминатор и генератор с помощью:

discriminator = create_discriminator()
generator = create_generator()

Затем мы можем использовать небольшой трюк; мы можем установить discriminator.trainable в False.

Зачем нам это нужно?
Что ж, мы не собираемся обучать модель генератора напрямую - мы собираемся объединить генератор и дискриминатор в одну модель, а затем обучить ее. Это позволяет генератору понять дискриминатор, чтобы он мог обновляться более эффективно.

Установка discriminator.trainable на False повлияет только на копию дискриминатора в комбинированной модели. Это хорошо! Если бы копию дискриминатора в комбинированной модели можно было обучить, она обновилась бы, чтобы быть хуже при классификации изображений. Мы рассмотрим это подробнее при обучении модели.

Чтобы объединить генератор и дискриминатор, мы будем вызывать дискриминатор на выходе генератора.

Это дает нам модель, которая принимает в качестве входных данных некоторый случайный шум z и возвращает, насколько убежден дискриминатор в том, что изображения генератора реальны, D (G (z) ).

В частности, он имеет входную форму (None, 100) и выходную форму (None, 1). 100 в форме ввода происходит от noise_dim.

Напомним, вот полный код для настройки GAN для обучения:

Обучение GAN

Пришло время приступить к обучению GAN.

Поскольку мы обучаем сразу две модели, дискриминатор и генератор, мы не можем полагаться на .fit функцию Кераса. Вместо этого мы должны вручную перебирать каждую эпоху и подбирать модели по партиям.

for epoch in range(epochs):
    for batch in range(steps_per_epoch):
        ...

Внутри вложенного цикла мы можем создать наши fake_x данные:

Мы делаем это, сначала отбирая некоторый случайный шум из нормального распределения, а затем получая предсказания генератора относительно шума. Переменная noise является кодовым эквивалентом переменной z, которую мы обсуждали ранее.

Мы можем создать наши real_x данные, выбирая случайные элементы из нашего x_train, которое мы загрузили ранее.

После создания как реальных, так и поддельных входных данных для нашего дискриминатора мы можем объединить их в одну переменную:

x = np.concatenate((real_x, fake_x))

Затем мы можем настроить наш y_data для дискриминатора:

Мы устанавливаем половину disc_y на 0,9, чтобы коррелировать с данными real_x.

Почему 0,9 вместо 1?
Сглаживание меток, процесс замены жестких значений (например, 1 или 0) на мягкие значения (например, 0,9 или 0,1) для меток, часто помогает обучению дискриминатора, уменьшая разреженные градиенты. Этот метод был предложен для GAN в Salimans et al. 2016 . Сглаживание меток обычно наиболее эффективно, когда применяется только к единицам y-данных, что в таком случае называется односторонним сглаживанием меток.

Наконец, мы можем обучить дискриминатор на партии:

d_loss = discriminator.train_on_batch(x, disc_y)

Обязательно сохраните потерю, полученную при обучении партии, в переменной, чтобы мы могли распечатать ее позже.

Затем мы можем обучить генератор всего двумя строками:

По сути, эти две строки говорят генератору сделать так, чтобы дискриминатор решил, что создаваемое вами изображение является реальным, при наличии некоторого случайного шума.

Подумайте на секунду, что здесь происходит под капотом.

Поскольку мы соединили две модели в одну, генератор понимает, как дискриминатор классифицирует свои изображения, и знает, как соответствующим образом обновить свою технику.
По умолчанию это также обновит веса дискриминатора, чтобы помочь генератору, но поскольку мы устанавливаем discriminator.trainable на False, это не так - заставляя генератор создавать более реалистичные изображения.

Мне эта идея кажется невероятной.

Внутри цикла эпохи, но вне цикла steps_per_epoch, вы можете распечатать потери генератора и дискриминатора. Вы можете сделать это примерно так:

Подведем итоги: вот полный цикл обучения, который мы только что реализовали:

Полученные результаты

Чтобы визуализировать наши результаты, мы можем реализовать быструю функцию, которая будет визуализировать график 10x10 сгенерированных изображений:

Затем мы можем показать изображения с помощью:

После 10 эпох обучения с размером пакета 16 и шагом на эпоху 3750, вот результаты, которые я получил:

Ничего страшного, но есть что улучшить.
Кроме того, если вы попытаетесь запустить это на CIFAR-10, вы, скорее всего, получите непригодные для использования результаты. FCGAN не справится.

Чтобы исправить это, мы можем использовать DCGAN.

DCGAN

Единственное различие между FCGAN и DCGAN объясняется их названиями. Полностью подключенная GAN использует полносвязные слои, также известные как плотные слои, для генерации и классификации изображений, в то время как глубокая сверточная GAN использует сверточные слои для выполнения той же задачи.

Предварительная обработка данных

Чтобы учесть новую архитектуру DCGAN, мы также должны изменить способ перестройки x_train.
Итак, мы можем изменить:
x_train = x_train.reshape(-1, img_rows*img_cols*channels)
на:
x_train = x_train.reshape((-1, img_rows, img_cols, channels))

Я также меняю:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
на:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

И, конечно, img_rows, img_cols, channels = 28, 28, 1 кому: img_rows, img_cols, channels = 32, 32, 3

Ради экономии времени я бы рекомендовал использовать только один класс CIFAR-10. Давайте выберем кошек, третий класс, так как это то, чего нам нужно больше в Интернете.

Мы можем сделать это с помощью:

x_train = x_train[np.where(y_train == 3)[0]]

np.where возвращает индексы массива, в котором оценка, например y_train == 3, верно.

Генератор

Для DCGAN у нас гораздо более продвинутая архитектура генератора.

Эта модель начинается с небольшого изображения, которое постепенно увеличивается до Conv2DTranspose слоев. Увеличение изображений GAN - это невероятная технология, разработанная NVIDIA, и архитектура генератора, представленная выше, имитирует значительно упрощенную ее версию.

Вы можете прочитать полную статью о прогрессивном развитии GAN здесь.

В этой модели мы также инициализируем веса с помощью инициализатора RandomNormal. Это дает лучшие результаты и никоим образом не ограничивается DCGAN.

Дискриминатор

Лучшие архитектуры классификаторов изображений также часто являются лучшими архитектурами дискриминаторов.

Из-за этого этот дискриминатор выглядит почти как любой стандартный классификатор изображений Keras, который вы видели.

Полученные результаты

Запуск этого класса CIFAR-10 дает гораздо лучшие результаты, чем когда-либо мог бы простой FCGAN.

Улучшенная архитектура позволит получать изображения более высокого качества. Популярным усовершенствованием стандартного DCGAN является включение остаточных блоков из ResNet paper. Это видно в генераторе U-Net, который мы реализуем в моей статье о продвинутой технике CycleGAN.

Вы также можете визуализировать процесс обучения, что означает, что вы можете вызывать функцию визуализации, которую мы реализовали ранее внутри цикла обучения, для каждых 10 эпох или около того. Если вы сделаете это, я бы порекомендовал сначала создать некоторый постоянный шум, чтобы вы могли видеть, как GAN изменяется с течением времени.

Ваш новый цикл обучения будет выглядеть примерно так:

Кроме того, внутри show_images вы можете сохранить фигуру с помощью plt.figsave(file), превратив функцию show_images в:

Если вы хотите, вы можете превратить каталог, в котором были сохранены эти изображения, в GIF с помощью:

Мы сортируем изображения по номеру эпохи, используя:

Это удаляет все нечисловые символы из пути к файлу, преобразует его в целое число и после этого сортирует его в порядке возрастания.

Если вам нужно подробное объяснение создания GIF-файлов на Python, я бы порекомендовал эту статью.

Заключение

В этой статье мы рассмотрели основы реализации GAN в Керасе. Я надеюсь создать целую серию статей об огромных возможностях GAN и о том, как мы можем реализовать их в простом коде Keras (и, возможно, немного в TensorFlow).

Также имейте в виду, что для обучения GAN требуется много вычислительной мощности. Не бойтесь запускать модель часами или даже днями. Если у вас недостаточно вычислительной мощности для обучения GAN, я бы рекомендовал использовать сеанс Kaggle Kernel или Google Colab.

Полный исходный код доступен на моем GitHub здесь.

Удачного кодирования!

Дальнейшее чтение