GAN для спешащих людей

«Генеративные рекламные сети - самая интересная идея за последние десять лет в машинном обучении»

Янн ЛеКун, директор по исследованиям искусственного интеллекта Facebook.

GAN - это потрясающая и новая идея в машинном обучении (ML). В этой статье я объясню, что такое GAN, типы GAN и их преимущества, покажу, как некоторые компании делают удивительные вещи с GAN, и я объясню, как создать такую ​​сеть самостоятельно.

Что такое GAN?

Так что же это за волшебные алгоритмы? GAN (генерирующие рекламные сети) обычно представляют собой две CNN (сверточную нейронную сеть) и состоят из генератора и дискриминатора, которые являются противниками. Самый простой способ объяснить это - аналогия, так что поехали! Вы можете представить себе GAN в виде полицейского и фальшивомонетчика. Фальшивомонетчик делает фальшивые деньги, но сначала он ничего не знает, поэтому делает что-то вроде этого:

Но полицейский говорит фальшивомонетчику и самому себе, что с этим не так, в результате чего и фальшивомонетчику, и полиции становится лучше. На практике мы называем фальшивомонетчика генератором, а полицию - дискриминатором.

Типы GAN

Итак, теперь, когда мы знаем, что такое GAN, какие бывают типы?

База

Любой GAN для запуска требуется одна из этих двух GAN:

OG GAN

Это оригинальный GAN, впервые упомянутый в статье Иэна Гудфеллоу в 2014 году.

DCGAN (состязательная сеть, генерирующая глубокую свертку)

DCGAN - это улучшенная версия обычного GAN, которая обычно используется поверх GAN, поскольку она более стабильна и в целом лучше.

Расширения

После того, как у нас будет базовый набор, мы можем улучшить нашу сеть GAN, специализируя ее под определенные нужды. Есть сотни способов изменить нашу GAN, но я остановлюсь только на наиболее распространенных. Чтобы увидеть больше, загляните в Зоопарк ГАН.

WGAN

WGAN или Wasserstein GAN - это GAN, которая меняет дискриминатор на критик и использует потерю Вассерштейна, где потеря больше не привязана между 0 и 1 и вместо этого может быть любым действительным числом. Это решает проблему исчезающего градиента и проблему коллапса режима, что в основном означает, что GAN не застревает в локальных минимумах, и если создание объекта с несколькими группами не застрянет, генерируя только одну группу.

CGAN

CGAN или условный GAN - это GAN, который предотвращает коллапс режима и позволяет вам выбирать определенные изображения из вашего GAN. CGAN делает это, вводя метки наборов данных в GAN.

Pix2Pix GAN

Pix2Pix GAN - это ваш общий GAN для преобразования изображений в изображения, который может выполнять такие действия, как восстановление старого изображения с недостающими частями, добавление цвета к изображениям без цвета и превращение рисунков в изображения. Pix2Pix GAN работает как отправная точка для более сложной модели преобразования изображений, такой как цикл GAN. Pix2Pix GAN использует условную GAN для сравнения хороших результатов с худшими, чтобы получить хорошие результаты.

Цикл GAN

Цикл GAN позволяет передавать образцы без соответствующих наборов данных, например, если вы хотите сделать лошадь зеброй или фотографию, похожую на картину Ван Гога. Цикл GAN работает, но имеет два GAN. Один GAN берет лошадь (вход 1) и пытается превратить эту лошадь в зебру (выход 1), затем дискриминатор оценивает, насколько зебра похожа на эту лошадь, которая была превращена в зебру (потеря 1). После этого лошадь-зебра (вход 2) теперь помещается в другой генератор, пытаясь заставить ее выглядеть как можно ближе к исходной лошади (выход 2). Наконец, второй дискриминатор вычисляет потери для этой лошади-зебры (потеря 2). Я знаю, что это звучит очень сложно, но в основном лошадь превращается в зебру обратно к лошади.

Прогрессивная GAN

Прогрессивная GAN - это GAN, которая растет во время обучения, переходя от низкого разрешения к высокому. GAN растет за счет добавления слоев к модели во время обучения. Это сделает GAN действительно стабильным, ускоряет обучение и позволит вам добиться лучших результатов.

Информация GAN

Информация GAN - это GAN, который позволяет вам сортировать вывод по таким параметрам, как ширина, поворот и т. Д. Информация GAN делает это, отсортировав вывод по шуму и скрытым кодам.

Компании

Компаний, основанных на GAN, не так много, при этом есть два основных типа компаний, основанных на GAN.

Создание контента

Фотографировать и платить людям за рекламу - дорогое удовольствие. Так что позвольте GAN сделать эту работу за нас. GAN отлично подходят для творческой работы, поэтому любые задачи, такие как дизайн, редактирование изображений, создание изображений / видео, могут быть автоматизированы с помощью GAN. Некоторыми примерами компаний, использующих GAN для творчества, являются RoseBud.ai, DataGrid и другие. Вот видео с RoseBud, показывающее, что они делают.

Оптимизация продукта

Конечно, я говорю о Святом Граале GAN - генеративном дизайне. Генеративный дизайн - фантастическая технология, способная изменить весь мир. Генеративный дизайн - это приложение GAN, которое позволяет нам оптимизировать строительство, транспортировку, мебель и т. Д. С помощью итеративного процесса. Некоторые компании, использующие генеративный дизайн, - это Space Factory, Fusion 360 и другие.

Как создать GAN

Звучит потрясающе, так как же нам это сделать? Я покажу вам, как создать DCGAN с помощью Keras и TensorFlow. Ссылки на другие GAN я сделал в конце.

Импорт библиотек

import tqdm
import os
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from glob import glob
import keras

Эта часть не требует пояснений, но мы импортируем любые библиотеки, которые собирались использовать в коде. Как правило, это те, которые я бы рекомендовал, но вам может потребоваться импортировать другие библиотеки.

Определить переменные

IMG_WIDTH = 100
IMG_HEIGHT = 100
random_dim = 100 
np.random.seed(512) 
datadir = "" #path goes here

Определение переменных для нормализации изображения (ширина и высота), скрытого пространства (random_dim и seed) и определение пути к набору данных.

Нормализация изображения / настройка конвейера

from tqdm import tqdm
data = []
def create_data():
    path = datadir
    for img in tqdm(os.listdir(path)):
      img_array = cv2.imread(os.path.join(path,img))
      new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE), interpolation = cv2.INTER_NEAREST)
      drawings_data.append(new_array)
    plt.imshow(drawings_data[0])
    plt.show()
    print(new_array)
create_data()

Здесь мы выполняем некоторую базовую нормализацию изображения, такую ​​как изменение размера изображений, а затем преобразование их в массив, чтобы модель могла использовать данные.

def load_data():
  x_train = data
  x_train = (np.asarray(x_train).astype(np.float32) - 127.5)/127.5
  return x_train

Здесь мы меняем форму массива для модели.

Определение модели

def get_optimizer():
  return Adam(lr=0.0002, beta_1=0.9)

Здесь мы определяем скорость обучения и бета-версию для модели, как правило, я рекомендую 0,0002 для скорости обучения для большинства моделей и 0,9 бета для большинства моделей.

def get_generator(optimizer):
  generator = Sequential()
generator.add(Dense(13*13*40, input_dim = random_dim,kernel_initializer=initializers.RandomNormal(stddev=0.02)))
  generator.add(BatchNormalization())
  generator.add(LeakyReLU())
generator.add(Reshape((13, 13, 40)))
generator.add(Conv2DTranspose(256, (3, 3), strides=(1, 1), padding='same'))
  generator.add(BatchNormalization())
  generator.add(LeakyReLU())
generator.add(Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same'))
  generator.add(BatchNormalization())
  generator.add(LeakyReLU())
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same'))
  generator.add(BatchNormalization())
  generator.add(LeakyReLU())
generator.add(Conv2DTranspose(3, (3, 3), strides=(2, 2), padding='same', activation='tanh'))
generator.add(Cropping2D(cropping=((2,2),(2,2))))
return generator

Здесь мы определяем модель генератора.

def get_discriminator(optimizer):
discriminator = Sequential()
discriminator.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=(100,100, 3)))
  discriminator.add(LeakyReLU())
  discriminator.add(Dropout(0.02))
  
  discriminator.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
  discriminator.add(LeakyReLU())
  discriminator.add(Dropout(0.02))
  
  discriminator.add(Flatten())
  discriminator.add(Dense(1, activation= "sigmoid",))
  discriminator.compile(loss="BinaryCrossentropy", optimizer=optimizer)
return discriminator

Здесь мы определяем модель дискриминатора.

def get_gan_network(discriminator, random_dim, generator, optimizer):
  discriminator.trainable = False
  gan_input = Input(shape=(random_dim,))
  x = generator(gan_input)
  gan_output = discriminator(x)
  gan = Model(inputs=gan_input, outputs=gan_output)
  gan.compile(loss="binary_crossentropy", optimizer=optimizer)
  return gan

Наконец, здесь мы определяем GAN.

Обучение модели

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(5, 5)):
  noise = np.random.normal(0, 1, size=[examples, random_dim])
  generated_images = generator.predict(noise)
  generated_images = generated_images.reshape(examples, 100, 100, 3)
  generated_images += 1
  generated_images /= 2
plt.figure(figsize=figsize)
  for i in range(generated_images.shape[0]):
    plt.subplot(dim[0], dim[1], i+1)
    plt.imshow(generated_images[i], interpolation="nearest")
    plt.axis("off")
  plt.tight_layout()
  plt.savefig("gan_generated_image_epoch_%d.png" % epoch)

Здесь мы настраиваем функцию для построения изображений для нас.

def train(epochs=1, batch_size=50):
  x_train = load_data()
  batch_count = x_train.shape[0] / batch_size
  adam = get_optimizer()
  generator = get_generator(adam)
  discriminator = get_discriminator(adam)
  gan = get_gan_network(discriminator, random_dim, generator, adam)
  #random noise and images
  for e in range(1, epochs+1):
    print("-"*15,"Epoch %d" % e, "-"*15)
    for _ in tqdm(range(int(batch_count))):
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
#generate fake images
      generated_images = generator.predict(noise)
      X = np.concatenate([image_batch, generated_images])
#labels for fake + real
      y_dis = np.zeros(2*batch_size)
      y_dis[:batch_size] = 0.9
discriminator.trainable = True 
      discriminator.train_on_batch(X, y_dis)
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      y_gen = np.ones(batch_size)
      discriminator.trainable = False
      gan.train_on_batch(noise, y_gen)
if e == 1 or e % 5 == 0:
      plot_generated_images(e, generator)

Теперь мы собираемся определить функцию обучения, а затем, наконец, мы можем обучить нашу модель.

train(100,50)

Здесь он настроен на работу в течение 100 эпох с размером пакета 50, но это можно легко изменить.

Теперь вы знаете, как создать GAN и различные типы GAN и чем они полезны.

Вот некоторые из результатов, которые я получил от своих GAN.

Вот ссылка на Github, чтобы увидеть код для всех моих GAN.

Если вам понравилась эта статья, вам, вероятно, понравятся и другие мои, так что подумайте о том, чтобы подписаться на меня в Medium, и хорошо, что вы делаете это, подписывайтесь на меня в Twitter, Linkedin и подпишитесь на мою рассылку новостей.