Обучайте плотные слои с помощью tensorflow GradientTape вместо Keras model.fit()

Интересно посмотреть, можно ли заменить встроенный метод Keras model.fit() во время обучения на tensorflow GradientTape. Я покажу, как tf.GradientTape может заменить model.fit() практически без изменений, точно воспроизводя обучение Keras и вычисления градиента. Есть много ситуаций, когда это может пригодиться, особенно при оценке пользовательских функций потерь. Это может стать хорошей отправной точкой для разработки специально обученных моделей нейронных сетей и лучшего понимания того, как работает поезд Keras.

В этом тесте я буду использовать небольшой набор данных из примера Iris и сравню результаты обучения обоих методов — model.fit() и tf.GradientTape.

import numpy as np
import datetime
from tensorflow.keras.models import Sequential 
from tensorflow.keras.layers import Dense 
from tensorflow.keras import optimizers
import tensorflow as tf
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

Набор данных поезда ириса (ограниченный 150 образцами) был взят из здесь, данные предварительно обработаны с помощью X_rand.dat, содержащего геометрические характеристики четырех цветков ириса и Y_rand. dat, содержащий горячие метки типа, обе случайным образом зашифрованные. Полное описание набора данных Iris см. см..

X = np.loadtxt('X_rand.dat', dtype=float)
Y = np.loadtxt('Y_rand.dat', dtype=float)
print('X.shape = , Y.shape =', X.shape, Y.shape)

Теперь давайте создадим две модели Keras — одну для обучения с помощью model.fit(), а другую с помощью tf.GradientTape(). Для простоты я буду использовать оптимизатор SGD с альфа=0,01 и нулевым импульсом. Что касается стандартной проблемы с несколькими классами, я буду использовать активацию softmax и расчет потерь с помощью категориальной кроссэнтропии Keras.

Также для начальных смещений и весов слоев Dense я использую инициализацию ones — мне нужно, чтобы иметь возможность сравнивать точные результаты вычислений двух методов обучения, начиная с одного и того же начального условия и удаляя неопределенность, возникающая из-за инициализации весов по умолчанию в Keras, которая использует случайные значения. Однако обратите внимание, что инициализации с весами единиц, даже если она работает нормально для этого примера, следует избегать, поскольку она вызывает числовую нестабильность.

opt_fit = optimizers.SGD( learning_rate=0.01, momentum=0.0, nesterov=False, name='SGD')
model_fit = Sequential()
model_fit.add(Dense(8, input_dim=4, activation='relu', kernel_initializer='ones', bias_initializer='ones'))
model_fit.add(Dense(3, activation='softmax', kernel_initializer='ones', bias_initializer='ones'))
model_fit.compile(loss='categorical_crossentropy', optimizer=opt_fit, metrics=tf.keras.metrics.CategoricalAccuracy())

opt_tape = optimizers.SGD( learning_rate=0.01, momentum=0.0, nesterov=False, name='SGD')
model_tape = Sequential()
model_tape.add(Dense(8, input_dim=4, activation='relu', kernel_initializer='ones', bias_initializer='ones'))
model_tape.add(Dense(3, activation='softmax', kernel_initializer='ones', bias_initializer='ones'))
model_tape.compile(loss='categorical_crossentropy', optimizer=opt_tape, metrics=tf.keras.metrics.CategoricalAccuracy())

Обучение с помощью model.fit()

Давайте сначала посмотрим на результаты обычного обучения Keras с использованием model.fit() с заданным размером пакета:

EPOCHS = 200
SAMPLES = 150
BATCH_SIZE = 10

history = model_fit.fit(X, Y, epochs=EPOCHS, batch_size=BATCH_SIZE, verbose=2, shuffle = False)

loss_fit = history.history['loss']
acc_fit  = history.history['categorical_accuracy']

В каждую эпоху Керас печатает потери и точность. За 200 эпох потери уменьшились, а точность обучения достигла 0,9733, что весьма неплохо для такой простой оптимизации и инициализации весов. Значения потерь и точности метода fit() во время обучения сохраняются в массивах с именами loss_fit и acc_fit.

Обучение с tf.GradientTape

Теперь давайте посмотрим, насколько хорошо tf.GradientTape будет воспроизводить предыдущий результат обучения. Я разделю данные на пакеты размером batch_size. Нам нужно рассчитать потери и точность самостоятельно. Процесс включает в себя расчет средних потерь и точности для каждой партии с последующим определением градиентов ошибок и, наконец, обновлением весов модели с использованием метода, описанного ниже:

loss_tape = []
acc_tape = []
batch_num = (int)(SAMPLES / BATCH_SIZE)

for i in range(EPOCHS):  
    avg_loss_tape = 0
    avg_acc_tape = 0
    start = datetime.datetime.now()
    for j in range(batch_num):
        start_idx=BATCH_SIZE*j
        end_idx=BATCH_SIZE*(j+1)    
        X_batch = X[start_idx:end_idx]
        y_batch = Y[start_idx:end_idx]    

        with tf.GradientTape() as tape:
            pred = model_tape(X_batch)
            loss = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y_batch, pred))  
            grads = tape.gradient(loss, model_tape.trainable_weights)
            opt_tape.apply_gradients(zip(grads, model_tape.trainable_variables)) 
            avg_loss_tape += loss.numpy()
            avg_acc_tape += accuracy_score(np.argmax(y_batch,axis=1),np.argmax(pred,axis=1))

    elapsed = datetime.datetime.now() - start   
    avg_loss_tape = avg_loss_tape / batch_num
    avg_acc_tape = avg_acc_tape / batch_num
    print('Epoch:%d'%i, '\nloss: %.4f' % avg_loss_tape, " - categorical_accuracy: %.4f" % avg_acc_tape, " - %d [ms]" % (elapsed.total_seconds() * 1000))
    loss_tape.append(avg_loss_tape)
    acc_tape.append(avg_acc_tape)

Потери и точность из последовательности tf.GradientTape() записываются для каждой эпохи в массивы loss_tape и acc_tape. Теперь посмотрим, насколько хорошо эти результаты перекрываются с рассчитанными ранее значениями loss_fit и acc_fit:

Кривые почти точно совпадают друг с другом. Небольшие различия связаны с точностью вычислений с плавающей запятой. Это демонстрирует, что в этом тесте градиентная лента точно соответствует обучению Keras model.fit()!

Аналогичные результаты могут быть получены и для разных batch_size. Также вы можете удалить инициализацию ones и увидите, что потери и точность начинают различаться, но обе модели в конечном итоге снова достигают одинаковых результатов обучения.

Заключение.

С помощью этого теста я продемонстрировал, что TensorFlow GradientTape может служить точной заменой метода Keras fit(). Аналогичные результаты можно получить и для более сложных моделей, которые я покажу в дополнительном исследовании.

Однако из приведенных выше результатов следует отметить одну вещь: обучение с вычислением градиентов с использованием GradientTape может быть намного медленнее, чем метод fit(). Для этого есть несколько причин, model.fit() имеет различные оптимизации и является рекомендуемым методом обучения в Keras.

Исходный код этого поста доступен на Kaggle.