Оригинальная статья: Фотореалистичное сверхвысокое разрешение одиночного изображения с использованием генеративно-состязательной сети

Введение

В приведенной выше статье предлагается нейронная сеть на основе остаточных блоков для суперразрешения изображений, потеря VGG для улучшения потери MSE, которая часто не обеспечивает точное создание изображений SR. Методы SRGAN из статьи также включают в себя обучение модели с потерей противника вместе с потерей контекста для дальнейшего улучшения качества реконструкции изображения.

Мы резюмировали концепции и методы статьи в предыдущем посте[2]. В этом посте мы реализуем сетевую архитектуру, потери и процедуру обучения методов, предложенных в этой статье. Полный код, использованный в этом посте, можно посмотреть здесь.

Загрузка данных

В статье их сети обучались с помощью культур из известного набора данных распознавания изображений ImageNet. Хотя обучать модели на больших объемах данных полезно, набор данных оказался слишком тяжелым, и я решил использовать набор данных tf_flowers, состоящий из 3670 изображений, которые могут показаться слишком маленькими, но их вполне достаточно для игрушечного набора данных для оценки. и сравните производительность каждого метода обучения из статьи.

data=tfds.load('tf_flowers')
train_data=data['train'].skip(600)
test_data=data['train'].take(600)

Мы используем модуль tensorflow_datasets для загрузки набора данных tf_flowers и берем первые 600 изображений в качестве набора данных проверки.

@tf.function
def build_data(data):
  cropped=tf.dtypes.cast(tf.image.random_crop(data['image'] / 255,(128,128,3)),tf.float32)
  lr=tf.image.resize(cropped,(32,32))
  return (lr,cropped * 2 - 1)
train_dataset_mapped = train_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE)
for x in train_dataset_mapped.take(1):
  plt.imshow(x[0].numpy())
 plt.show()
  plt.imshow(bicubic_interpolate(x[0].numpy(),(128,128)))
  plt.show()
  plt.imshow(x[1].numpy())
  plt.show()

Затем мы определяем функцию для сопоставления каждого изображения из набора данных с культурами (128, 128) и его копией с низким разрешением (32, 32). Мы можем применить эту функцию к нашему набору данных с помощью train_data.map(build_data, …) . Это будет выполняться перед каждой эпохой обучения.

Определение модели

def residual_block_gen(ch=64,k_s=3,st=1):
  model=tf.keras.Sequential([
  tf.keras.layers.Conv2D(ch,k_s,strides=(st,st),padding='same'),
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.LeakyReLU(),
  tf.keras.layers.Conv2D(ch,k_s,strides=(st,st),padding='same'),
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.LeakyReLU(),
  ])
  return model
def Upsample_block(x, ch=256, k_s=3, st=1):
  x = tf.keras.layers.Conv2D(ch,k_s, strides=(st,st),padding='same')(x)
  x = tf.nn.depth_to_space(x, 2) # Subpixel pixelshuffler
  x = tf.keras.layers.LeakyReLU()(x)
 return x
input_lr=tf.keras.layers.Input(shape=(None,None,3))
input_conv=tf.keras.layers.Conv2D(64,9,padding='same')(input_lr)
input_conv=tf.keras.layers.LeakyReLU()(input_conv)
SRRes=input_conv
for x in range(5):
  res_output=residual_block_gen()(SRRes)
  SRRes=tf.keras.layers.Add()([SRRes,res_output])
SRRes=tf.keras.layers.Conv2D(64,9,padding='same')(SRRes)
SRRes=tf.keras.layers.BatchNormalization()(SRRes)
SRRes=tf.keras.layers.Add()([SRRes,input_conv])
SRRes=Upsample_block(SRRes)
SRRes=Upsample_block(SRRes)
output_sr=tf.keras.layers.Conv2D(3,9,activation='tanh',padding='same')(SRRes)
SRResnet=tf.keras.models.Model(input_lr,output_sr)

Мы определяем архитектуру генератора остатка, используя Tensorflow. Были определены функции для построения всего остаточного блока, а также реализованы поэлементные соединения с пропуском суммы. Соединяются 5 остаточных блоков, и финальное изображение подвергается апсэмплингу с помощью метода перемешивания пикселей, реализованного в функции Upsample_block. Модель построена, как показано ниже. Поскольку эта сеть полностью состоит из свертки, нам не нужно определять входную форму, и поэтому модель также может обрабатывать изображения любого размера.

def residual_block_disc(ch=64,k_s=3,st=1):
  model=tf.keras.Sequential([
  tf.keras.layers.Conv2D(ch,k_s,strides=(st,st),padding='same'),
  tf.keras.layers.BatchNormalization(),
  tf.keras.layers.LeakyReLU(),
  ])
  return model
input_lr=tf.keras.layers.Input(shape=(128,128,3))
input_conv=tf.keras.layers.Conv2D(64,3,padding='same')(input_lr)
input_conv=tf.keras.layers.LeakyReLU()(input_conv)
channel_nums=[64,128,128,256,256,512,512]
stride_sizes=[2,1,2,1,2,1,2]
disc=input_conv
for x in range(7):
  disc=residual_block_disc(ch=channel_nums[x],st=stride_sizes[x])(disc)
disc=tf.keras.layers.Flatten()(disc)
disc=tf.keras.layers.Dense(1024)(disc)
disc=tf.keras.layers.LeakyReLU()(disc)disc_output=tf.keras.layers.Dense(1,activation='sigmoid')(disc)
discriminator=tf.keras.models.Model(input_lr,disc_output)

Архитектура дискриминатора также реализована на основе спецификаций бумаг. Сеть представляет собой обычную CNN, которая вводит изображение и определяет его подлинность.

Реализация потерь

def PSNR(y_true,y_pred):
  mse=tf.reduce_mean( (y_true - y_pred) ** 2 )
  return 20 * log10(1 / (mse ** 0.5))
def log10(x):
  numerator = tf.math.log(x)
  denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype))
  return numerator / denominator
def pixel_MSE(y_true,y_pred):
  return tf.reduce_mean( (y_true - y_pred) ** 2 )

Мы определяем потерю MSE по пикселям и показатель PSNR для обучения и оценки. Эти формулировки потерь более подробно объясняются в предыдущем посте о концепциях этой статьи.

VGG19=tf.keras.applications.VGG19(weights='imagenet',include_top=False,input_shape=(128,128,3))

VGG_i,VGG_j=2,2
def VGG_loss(y_hr,y_sr,i_m=2,j_m=2):
  i,j=0,0
  accumulated_loss=0.0
  for l in VGG19.layers:\
    cl_name=l.__class__.__name__
    if cl_name=='Conv2D':
      j+=1
    if cl_name=='MaxPooling2D':
      i+=1
      j=0
    if i==i_m and j==j_m:
      break
    y_hr=l(y_hr)
    y_sr=l(y_sr)
    if cl_name=='Conv2D':
      accumulated_loss+=tf.reduce_mean((y_hr-y_sr)**2) * 0.006
  return accumulated_loss
def VGG_loss_old(y_true,y_pred):
  accumulated_loss=0.0
  for l in VGG19.layers:
    y_true=l(y_true)
    y_pred=l(y_pred)
    accumulated_loss+=tf.reduce_mean((y_true-y_pred)**2) * 0.006
  return accumulated_loss

Потеря VGG, предложенная в статье, сравнивает промежуточную активацию предварительно обученной сети VGG-19 при прогнозировании изображений. Мы последовательно проходим через каждый уровень модели VGG и сравниваем каждый промежуточный результат. Мы определяем интуитивную потерю VGG как VGG_loss_old, а точную потерю как VGG_loss.

cross_entropy = tf.keras.losses.BinaryCrossentropy()
def discriminator_loss(real_output, fake_output):
  real_loss = cross_entropy(tf.ones_like(real_output), real_output)
  fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
  total_loss = real_loss + fake_loss
  return total_loss
def generator_loss(fake_output):
  return cross_entropy(tf.ones_like(fake_output), fake_output)

Потери противника определяются, как указано выше. Код, связанный с процедурой состязательного обучения, в основном упоминается в руководстве по Tensorflow DCGAN[3].

Обучение

generator_optimizer=tf.keras.optimizers.Adam(0.001)
discriminator_optimizer=tf.keras.optimizers.Adam(0.001)
adv_ratio=0.001
evaluate=['PSNR']
# MSE
loss_func,adv_learning = pixel_MSE,False
# VGG2.2
loss_func,adv_learning = lambda y_hr,h_sr:VGG_loss(y_hr,y_sr,i_m=2,j_m=2),False
# VGG 5.4
loss_func,adv_learning = lambda y_hr,h_sr:VGG_loss(y_hr,y_sr,i_m=5,j_m=4),False
# SRGAN-MSE
loss_func,adv_learning = pixel_MSE,True
# SRGAN-VGG 2.2
loss_func,adv_learning = lambda y_hr,h_sr:VGG_loss(y_hr,y_sr,i_m=2,j_m=2),True
# SRGAN-VGG 5.4
loss_func,adv_learning = lambda y_hr,h_sr:VGG_loss(y_hr,y_sr,i_m=5,j_m=4),True
loss_func,adv_learning = lambda y_hr,h_sr:VGG_loss(y_hr,y_sr,i_m=5,j_m=4),True
#Real loss
loss_func,adv_learning = pixel_MSE,False

Сначала мы определяем гиперпараметры и функцию потерь для оптимизации модели. Во фрагменте приведены некоторые конфигурации потерь, предложенные в статье.

Этап обучения основан на учебном пособии Tensorflow DCGAN, цикл обучения может быть обобщен для всех возможных потерь. Противоборствующее обучение проводится только в том случае, если adv_learning=True . Мы суперразрешаем изображение, используя модель генератора, измеряем потери с заданной метрикой и фиксируем градиенты. Если следующий код кажется слишком сложным, я настоятельно рекомендую ознакомиться с руководством по DCGAN.

@tf.function()
def train_step(data,loss_func=pixel_MSE,adv_learning=True,evaluate=['PSNR'],adv_ratio=0.001):
  logs={}
  gen_loss,disc_loss=0,0
  low_resolution,high_resolution=data
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    super_resolution = SRResnet(low_resolution, training=True)
    gen_loss=loss_func(high_resolution,super_resolution)
    logs['reconstruction']=gen_loss
    if adv_learning:
      real_output = discriminator(high_resolution, training=True)
      fake_output = discriminator(super_resolution, training=True)
      
      adv_loss_g = generator_loss(fake_output) * adv_ratio
      gen_loss += adv_loss_g
      
      disc_loss = discriminator_loss(real_output, fake_output)
      logs['adv_g']=adv_loss_g
      logs['adv_d']=disc_loss
  gradients_of_generator = gen_tape.gradient(gen_loss, SRResnet.trainable_variables)
  generator_optimizer.apply_gradients(zip(gradients_of_generator, SRResnet.trainable_variables))
  
  if adv_learning:
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
  
  for x in evaluate:
    if x=='PSNR':
      logs[x]=PSNR(high_resolution,super_resolution)
  return logs

Обучение выполняется с помощью следующего кода, который зацикливает набор данных и вызывает предопределенную функцию train_step для каждого пакета. Как упоминалось выше, изображения снова обрезаются перед каждой эпохой.

for x in range(50):
  train_dataset_mapped = train_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE).batch(128)
  val_dataset_mapped = test_data.map(build_data,num_parallel_calls=tf.data.AUTOTUNE).batch(128)
  for image_batch in tqdm.tqdm(train_dataset_mapped, position=0, leave=True):
    logs=train_step(image_batch,loss_func,adv_learning,evaluate,adv_ratio)
  for k in logs.keys():
    print(k,':',logs[k],end='  ')
    print()

Оценка

Мы визуализируем несколько примеров изображений с суперразрешением с помощью обученных моделей. Первое изображение представляет собой исходное изображение HR, второе изображение представляет собой изображение SR, а третье и четвертое изображения представляют собой изображения с низким разрешением и бикубической интерполяцией. Хотя каждая модель не обучалась в течение достаточного времени, мы могли сравнить производительность каждой модели. Изображения, созданные с помощью моделей, обученных с помощью VGG и противоборствующих потерь, по-видимому, имеют лучшее качество. Посмотрите внимательно на реконструированную текстуру дерева на первом снимке.

Я не проверял все предложенные потери. Было бы здорово, если бы вы могли поделиться результатами после обучения других методов и оценить производительность с помощью кода, предоставленного в моей ссылке на COLAB, и попробовать обучить модель на больших наборах данных, таких как набор данных ImageNet. Кроме того, я уверен, что модель будет работать лучше с большим количеством периодов обучения. на текущем этапе обучения мы можем видеть искусственные фильтры в реконструированном изображении из-за незрелых слоев реконструкции ESPCN. Эту проблему можно решить путем большего количества итераций обучения, хотя модель по-прежнему превосходит модель, основанную на MSE, в восприятии.

  • SRResNet + MSE

  • SRResNet + ВГГ 2.2

  • SRResNet + ВГГ 5.4
  • СРГАН 0,001 + СКО

  • СРГАН 0,001 + ВГГ 2,2
  • СРГАН 0,001 + ВГГ 5,4

использованная литература

Моя реализация SRResnet/SRGAN в COLAB: https://colab.research.google.com/drive/15MGvc5h_zkB9i97JJRoy_-qLtoPEU2sp?usp=sharing

[1] Ледиг, Кристиан и др. «Фотореалистичное сверхвысокое разрешение одиночного изображения с использованием генеративной состязательной сети». Материалы конференции IEEE по компьютерному зрению и распознаванию образов. 2017.

[2] Суперразрешение с SRResnet, SRGAN. https://medium.com/analytics-vidhya/super-resolution-with-srresnet-srgan-2859b87c9c7f

[3] Учебное пособие по Tensorflow DCGAN: https://www.tensorflow.org/tutorials/generative/dcgan