Генеративные состязательные сети в Керасе работают не так, как ожидалось

Я новичок в машинном обучении Keras. Я пытаюсь разобраться в генеративных состязательных сетях (GAN). Для этого я пытаюсь запрограммировать простой пример. Я генерирую данные с помощью следующей функции:

def genReal(l):
    realX = []
    for i in range(l):
        x = []
        y = []
        for i in np.arange(0.0, 1.0, 0.02):
            x.append(i + np.random.normal(0,0.01))
            y.append(-abs(i-0.5)+0.5+ np.random.normal(0,0.01))

        data = np.array(list(zip(x, y)))
        data = np.reshape(data, (100))
        data.clip(0,1)
        realX.append(data)

    realX = np.array(realX)
    return realX

Данные, созданные с помощью этой функции, похожи на эти примеры:

введите здесь описание изображения Теперь цель должна состоять в том, чтобы обучить нейронную сеть генерировать похожие данные. Для GAN нам нужна сеть Генератор, которую я смоделировал следующим образом:

generator = Sequential()
generator.add(Dense(128, input_shape=(100,), activation='relu'))
generator.add(Dropout(rate=0.2))
generator.add(Dense(128, activation='relu'))
generator.add(Dropout(rate=0.2))
generator.add(Dense(100, activation='sigmoid'))
generator.compile(loss='mean_squared_error', optimizer='adam')

дискриминатор, который выглядит так:

discriminator = Sequential()
discriminator.add(Dense(128, input_shape=(100,), activation='relu'))
discriminator.add(Dropout(rate=0.2))
discriminator.add(Dense(128, activation='relu'))
discriminator.add(Dropout(rate=0.2))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='mean_squared_error', optimizer='adam')

комбинированная модель:

ganInput = Input(shape=(100,))
x = generator(ganInput)
ganOutput = discriminator(x)

GAN = Model(inputs=ganInput, outputs=ganOutput)
GAN.compile(loss='binary_crossentropy', optimizer='adam')

У меня есть функция, которая генерирует шум (случайный массив)

def noise(l):
   noise = np.array([np.random.uniform(0, 1, size=[l, ])])
   return noise

А потом тренирую модель:

for i in range(1000000):
    fake = generator.predict(noise(100))
    print(i, "==>", discriminator.predict(fake))
    discriminator.train_on_batch(genReal(1), np.array([1]))
    discriminator.train_on_batch(fake, np.array([0]))

    discriminator.trainable = False
    GAN.train_on_batch(noise(100), np.array([1]))
    discriminator.trainable = True

Как видите, я уже пробовал обучить модель для 1. Итераций Mio. Но генератор выводит данные, которые потом выглядят следующим образом (несмотря на разные входы):

введите здесь описание изображения

Определенно не то, что я хотел. Итак, мой вопрос: 1. Mio Iterations недостаточно, или что-то не так в концепции моей программы?

редактировать:

Это функция, с помощью которой я рисую свои данные:

def plotData(data):
    x = np.reshape(data,(50,2))
    x = x.tolist()
    plt.scatter(list(zip(*x))[0],list(zip(*x))[1], c=col)

person man zet    schedule 29.06.2018    source источник


Ответы (1)


Проблема с вашей реализацией заключается в том, что discriminator.trainable = False не имеет никакого эффекта после компиляции discriminator. Следовательно, все веса (как от дискриминатора, так и от генераторной сети) можно обучить, когда вы выполняете GAN.train_on_batch.

Решение этой проблемы - установить discriminator.trainable = False сразу после компиляции discriminator и перед компиляцией GAN:

discriminator.compile(loss='mean_squared_error', optimizer='adam')    
discriminator.trainable = False

ganInput = Input(shape=(100,))
x = generator(ganInput)
ganOutput = discriminator(x)

GAN = Model(inputs=ganInput, outputs=ganOutput)
GAN.compile(loss='binary_crossentropy', optimizer='adam')

ПРИМЕЧАНИЕ. Я построил ваши данные, и они выглядят примерно так:  Созданные данные

person rvinas    schedule 30.06.2018
comment
спасибо, теперь это работает намного лучше. Я добавил свою сюжетную функцию для лучшего понимания ... - person man zet; 01.07.2018