Почему tf.keras BatchNormalization заставляет GAN приводить к бессмысленным потерям и точности?

Фон:

Я получаю необычные потери и точность при обучении GAN со слоями пакетной нормализации в дискриминаторе с использованием tf.keras. GAN имеют оптимальное значение целевой функции log (4), которое возникает, когда дискриминатор полностью не может отличить реальные образцы от подделок и, следовательно, прогнозирует 0,5 для всех образцов. Когда я включаю слои BatchNormalization в свой дискриминатор, и генератор, и дискриминатор достигают почти идеальных результатов (высокая точность, низкие потери), что невозможно в условиях состязания.

Без BatchNorm:

На этом рисунке показаны потери (y) за эпоху (x), когда BN не используется . Обратите внимание, что случайные значения ниже теоретического минимума связаны с повторяющимся процессом обучения. На этом рисунке показана точность, когда BN не используется, которая составляет около 50% каждая. . Обе эти цифры показывают разумные значения.

С BatchNorm:

На этом рисунке показаны потери (y) за эпоху (x) при использовании BN. Посмотрите, как цель GAN, которая не должна опускаться ниже log (4), приближается к нулю. Эта цифра показывает точность при использовании BN, причем оба показателя близки к 100%. GAN состязательны; и генератор, и дискриминатор не могут иметь 100% точность одновременно.

Вопрос:

Код для создания и обучения GAN можно найти здесь. Я что-то упустил, допустил ли я ошибку в своей реализации, или есть ошибка в tf.keras? Я почти уверен, что это техническая проблема, а не теоретическая проблема, которую могут решить «взломы GAN». Обратите внимание, что это включает только использование слоев BatchNormalization в дискриминаторе; использование их в генераторе не вызывает этой проблемы.




Ответы (2)


Существует проблема с уровнем BatchNormalization Tensorflow в TF 2.0 и 2.1; переход на TF 1.15 решает проблему. Причина проблемы пока не установлена.

Вот соответствующая проблема GitHub: https://github.com/tensorflow/tensorflow/issues/37673

person Conor    schedule 19.03.2020

Причина проблемы очевидна. Дискриминатор учится различать этапы обучения и тестирования слоя BatchNormalization вместо обучения различать данные.

На этапе подготовки фактическое среднее значение партии и дисперсия используются в BN, в отличие от фазы тестирования, где используются скользящее среднее и скользящая дисперсия, хранящиеся в BN.

person Radek    schedule 19.07.2021