Реализация «Hello World!» сверточных нейронных сетей

Большая часть прогресса, достигнутого в области глубокого обучения за последние годы, связана с концепцией сверточных нейронных сетей или CNN. Эти сети стали де-факто стандартом во всех, кроме самых тривиальных, задачах обработки изображений. Основные концепции CNN берут начало с 1980-х годов, первое приложение для распознавания изображений было опубликовано в 1989 году. Как и многие другие темы в области глубокого обучения, большие успехи пришли с большей вычислительной мощностью, и одним из основных факторов стало использование графических процессоров вместо процессоров для обучения, начиная с середины 2000-х годов.

Мы рассмотрим одну из самых важных статей в этой области, а именно статью 1998 Градиентное обучение в применении к распознаванию документов. Эту статью стоит прочитать более чем через 20 лет после публикации по разным причинам. Известны авторы, в том числе Ян Лекун и Джошуа Бангио, которые вместе с Джеффри Хинтоном считаются крестными отцами глубокого обучения и вместе получили премию Тьюринга 2018 года за свою работу в этой области.

Другая проблема заключается в том, что эта статья не является чисто исследовательской работой, а разработанные в ней решения были коммерчески применены для распознавания рукописных цифр в корпорации NCR. Поскольку в документе всесторонне описываются усилия при сосредоточении внимания на архитектуре нейронной сети, нам нужно прочитать только разделы с 1 по 3 из всего 10 разделов. Длина всей статьи составляет 46 страниц, что само по себе также заслуживает внимания. выше обычного.

Самая важная проблема здесь заключается в том, что сеть LeNet-5, обсуждаемая в статье, обычно считается одной из самых актуальных в истории CNN. Иногда его называют «Hello World!» CNN. В пользу этого есть несколько аргументов:

  • «Реализация LeNet-5» дает вам более 130 000 посещений в Google, довольно много из них актуальны.
  • Использование текущих библиотек, реализующих базовую архитектуру, выполняется с помощью нескольких строк кода, в отличие от больших исследовательских усилий, описанных в статье.
  • Вы быстро получите неплохие результаты по набору данных MNIST. В настоящее время этого недостаточно для награждения за лучшую бумажную премию, но, вероятно, лучше, чем пытаться с нуля.
  • Обучение сети займет от около 100 секунд на CPU и всего несколько секунд на GPU за эпоху, в зависимости от используемых гиперпараметров. Это дает вам разумное время для экспериментов с сетью.

С другой стороны, повторная реализация сети из бумаги была бы действительно сложной задачей, поскольку во время ее реализации авторам приходилось делать все вручную. В 1998 году не было графических процессоров, Python не был языком искусственного интеллекта, и не было установленных библиотек глубокого обучения. Делая все с нуля, они также сделали множество оптимизаций, которые трудно воспроизвести с использованием современных библиотек. Большинство реализаций, которые вы найдете с помощью Google, просто увеличивают сеть, чтобы быстро получить такие же хорошие результаты по MNIST. Так что это не так просто, как с "Hello World!" чтобы получить сопоставимые результаты, но это по-прежнему хорошая основа для начала работы с CNN.

Начало работы с кодированием

Мы даже не будем пытаться воспроизвести статью точно, но в отличие от большинства других реализаций, найденных в Интернете, мы начнем с версии, следующей за статьей, по крайней мере, в некоторых соответствующих параметрах. Код будет на Python и будет использовать Keras на Tensorflow. Используя высокоуровневую библиотеку, такую ​​как Keras, код для сети почти тривиален, что также облегчит эксперименты с ней.

Базовая структура сети подробно описана в документе в разделе II.B. Всего у него семь слоев.

  • Сверточный слой с 6 ядрами 5x5 с заполнением, поэтому мы фактически добавили изображения 32x32 в качестве входных данных, в отличие от исходных изображений MNIST 28x28.
  • Слой объединения 2x2. Здесь воспроизведение статьи уже становится затруднительным, поскольку вы поймете, что читаете ее. Для простоты мы используем средний пул.
  • Сверточный слой с 16 ядрами 5x5 без заполнения. Опять же, это упрощение структуры статьи.
  • Еще один уровень пула 2x2, использующий средний пул, опять же не настоящий.
  • Полностью связанный слой с 120 нейронами.
  • Полностью связанный слой с 84 нейронами.
  • Слой softmax, чтобы наконец получить 10 возможных выходных классов. Хотя это снова отличается от бумаги, полностью связанные слои - настоящая вещь.

Функция активации, используемая в статье в tanh, и функция потерь аналогична MSE или среднеквадратической ошибке. Оба были оптимизированы для работы с бумагами, поскольку их нужно было кодировать вручную, а не просто использовать их из библиотеки высокого уровня, как это делаем мы. В качестве оптимизатора мы выбрали простой SGD или стохастический градиентный спуск. Сеть обучалась в течение 20 эпох со скоростью обучения «0,0005 для первых двух проходов, 0,0002 для следующих трех, 0,00005 для следующих 4 и 0,00001 после», как явно указано в документе.

Для создания модели в Керасе достаточно нескольких строк кода:

model = keras.Sequential(
  [
    keras.Input(shape=input_shape),
    layers.Conv2D(6, kernel_size=(5, 5), padding=’same’, activation=’tanh’),
    layers.AveragePooling2D(pool_size=(2, 2)),
    layers.Conv2D(16, kernel_size=(5, 5), padding=’valid’, activation=”tanh”),
    layers.AveragePooling2D(pool_size=(2, 2)),
    layers.Flatten(),
    layers.Dense(120, activation=”tanh”),
    layers.Dense(84, activation=”tanh”),
    layers.Dense(num_classes, activation=”softmax”)
  ]
)

Если вы посмотрите на сводку модели и сравните ее с подробным описанием сети, вы увидите, что только слои C1, C5 и F6 соответствуют бумаге в отношении количества обучаемых параметров. Это потому, что мы не пытались воспроизвести оптимизацию, сделанную в документе.

Model: “sequential”
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 6) 156
_________________________________________________________________
average_pooling2d (AveragePo (None, 14, 14, 6) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 10, 10, 16) 2416
_________________________________________________________________
average_pooling2d_1 (Average (None, 5, 5, 16) 0
_________________________________________________________________
flatten (Flatten) (None, 400) 0
_________________________________________________________________
dense (Dense) (None, 120) 48120
_________________________________________________________________
dense_1 (Dense) (None, 84) 10164
_________________________________________________________________
dense_2 (Dense) (None, 10) 850
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0

Обучение и подсчет очков

Существенная проблема воспроизведения статьи не ограничивается сетевой архитектурой. Следующий вопрос, на который нужно ответить, - как структурировать тренировочный процесс. С положительной стороны, набор данных MNIST, используемый в документе, по-прежнему является набором данных MNIST, используемым сегодня. В разделе III.A статьи есть интересная и краткая история MNIST, которую стоит прочитать.

Помимо функции потерь, оптимизатора и скорости обучения мы должны принять решение о размере пакета, а также о том, следует ли и как использовать перекрестную проверку во время обучения. Низкие скорости обучения, приведенные в документе, в сочетании с обучением только 20 эпохам, имеют смысл только в том случае, если пакеты не использовались, то есть размер пакета равен 1. Обратите внимание, что использование графического процессора дает соответствующее ускорение только при использовании микропакетов, поэтому обучение займет какое-то время с этими параметрами.

В документе также исследуется влияние количества обучающих данных на производительность сети и прямо указывается, что было использовано 60 000 обучающих изображений из MNIST. Все эти обучающие изображения доступны в стандартном MNIST. Таким образом, мы не используем перекрестную проверку, то есть разделение проверки на 0.

В результате получается следующий код для обучения сети:

lr_list = [0.0005, 0.0005, 0.0002, 0.0002, 0.0002, 0.00005,
           0.00005, 0.00005, 0.00005, 0.00001, 0.00001, 0.00001,
           0.00001, 0.00001, 0.00001, 0.00001, 0.00001, 0.00001,
           0.00001, 0.00001]
def calc_lr():
  elem = lr_list[0]
  del lr_list[0]
  return(elem)
optimizer = keras.optimizers.SGD(learning_rate=calc_lr)
model.compile(loss=”mean_squared_error”, optimizer=optimizer,
              metrics=[“accuracy”])
model.fit(x_train, y_train, batch_size=1, epochs=20,
          validation_split=0.0)

Обучение сети выводит точность на обучающем наборе, и, если используется перекрестная проверка, точность на проверочном наборе. Единственная важная метрика для сети - это производительность на тестовом наборе, поскольку она дает разумное представление о том, как сеть будет работать с реальными данными. В моем случае обучение сети с этими параметрами дало точность 0,9257 на данных обучения и 0,9292 на данных тестирования. Это хороший признак того, что точность на тестовом наборе выше, и намекает на хорошее обобщение сети.

Но является ли 93% хорошей точностью в целом? На самом деле нет. В статье сравнивается несколько методов и моделей для классификации MNIST в разделе C, и худший из упомянутых там методов имеет точность около 95%. Важным результатом статьи является то, что LeNet-5, как предлагается там, имеет точность 99% или более по MNIST. Так что даже в 1998 году у нас есть еще немало возможностей для получения награды за лучшую работу.

Получение лучшего результата

Как упоминалось ранее, в статье упоминается множество оптимизаций, и, вероятно, стоит потрудиться, по крайней мере, на докторскую диссертацию, чтобы попытаться воспроизвести их. Так что мы не пойдем по этому пути. Более простой способ добиться лучших результатов - это взять базовую структуру LeNet-5 и использовать современные передовые практики в CNN, в отличие от 20 с лишним лет назад. Еще один путь, который стоит рассмотреть, - это увеличение размера сети, поскольку обучение в любом случае занимает очень мало времени с использованием современного оборудования.

Чтобы увидеть, что возможно, мы воспользуемся сокращением, используя одну из многих реализаций, найденных в Интернете, например, это первое попадание в мой поиск Google:

Https://hackmd.io/@bouteille/S1WvJyqmI

Код немного сложнее нашего, вероятно, потому, что использовалась более старая версия Keras. В нашем выборе не так уж много различий, на самом деле сеть выглядит точно так же. Потенциально значимые различия:

  • Использование категориальной перекрестной энтропии вместо среднеквадратичной ошибки в качестве функции потерь.
  • Использование скорости обучения по умолчанию 0,01 в отличие от скорости обучения из статьи.
  • Использование размера пакета по умолчанию, равного 32, вместо отсутствия пакетов, то есть размера пакета 1.
  • Использование перекрестной проверки.

Эти простые адаптации дают значительно улучшенный результат: точность 0,9890 для данных обучения и 0,9870 для данных теста в моем случае. Возвращаясь к первому размеру пакета и отсутствию перекрестной проверки, я получил точность обучающих данных 1,0000, что предполагает переобучение. Но поскольку на тестовых данных я также получил 0,9888, мне было все равно. Таким образом, просто изменив функцию потерь и увеличив скорость обучения, мы немного приблизимся к награде за лучшую работу, если бы мы уже вручали нашу работу только в 1998 году.

Резюме

Реализовать базовую структуру LeNet-5 легко, но вы не сможете воспроизвести работу, описанную в статье. Однако производительность набора данных MNIST этой сети, которой более 20 лет, по-прежнему впечатляет. Очень важно использовать хорошие гиперпараметры! Хотя использование кросс-энтропии для классификации может считаться лучшей практикой, а не MSE для регрессии, есть и другие варианты, по которым следует принять решение. Использование CNN на MNIST было сделано так часто, что максимальная точность хорошо известна.

Альтернативой является Fashion MNIST, который может заменить классический MNIST без каких-либо изменений в сети, поскольку разрешение изображения такое же. Он также имеет 10 классов, но цель состоит в том, чтобы классифицировать статьи о моде, а не рукописные цифры. Для дальнейших экспериментов с «Hello World!» из CNN это более интересный набор данных.