Научитесь кодировать алгоритм быстрого обучения на наборе данных Omniglot

Вступление

Мы, люди, можем распознать класс по нескольким примерам этого класса. Например, ребенку нужно всего два или три изображения кролика, чтобы различить это животное среди других видов. Эта способность учиться на ограниченных данных превосходит любой классический алгоритм машинного обучения. Многие люди думают, что человечество низвергается искусственным интеллектом, но вот правда: чтобы иметь возможность хорошо различать классы, классификатор часто загружается несколькими тысячами изображений на класс… в то время как нам нужно всего два или три!

Прототипные сети - это алгоритм, представленный Snell et al. в 2017 году (в «Прототипах сетей для обучения с использованием нескольких кадров) », в котором рассматривается парадигма обучения с использованием нескольких кадров. Давайте разберемся с этим шаг за шагом на примере. В этой статье наша цель - классифицировать изображения персонажей. Предоставленный код находится в PyTorch, доступен здесь.

Набор данных Omniglot

В обучении с использованием нескольких кадров нам предоставляется набор данных с несколькими изображениями на класс (обычно от 1 до 10). В этой статье мы будем работать с набором данных Omniglot, который содержит 1623 различных рукописных символа, собранных из 50 алфавитов. Этот набор данных можно найти в этом репозитории GitHub. Я использовал файлы images_background.zip и images_evaluation.zip.

Как предлагается в официальном документе, увеличение данных выполняется для увеличения количества классов. На практике все изображения поворачиваются на 90 °, 180 ° и 270 °, каждый поворот приводит к дополнительному классу. Как только это увеличение данных выполнено, у нас есть 1623 * 4 = 6492 класса. Я разделил весь набор данных на обучающий набор (изображения 4200 классов) и тестовый набор (изображения 2292 классов).

Выберите образец

Для создания выборки классы Nc выбираются случайным образом среди всех классов. Для каждого класса у нас есть два набора изображений: набор поддержки размера Ns и набор запросов размера Nq.

Вставить изображения

«Наш подход основан на идее, что существует вложение, в котором точки группируются вокруг одного представления прототипа для каждого класса». утверждают авторы оригинальной статьи.

Другими словами, существует математическое представление изображений, в котором изображения одного класса собираются в группы, называемые кластерами. Основное преимущество работы в этом пространстве для встраивания состоит в том, что два одинаковых изображения будут расположены близко друг к другу, а два совершенно разных изображения будут далеко.

В нашем случае с набором данных Omniglot блок внедрения принимает изображения (28x28x3) в качестве входных данных и возвращает столбцы с 64-мерными точками. Функция image2vector состоит из 4 модулей. Каждый модуль состоит из сверточного слоя, пакетной нормализации, функции активации ReLu и максимального уровня объединения 2x2.

Вычислить прототипы классов

На этом этапе мы вычисляем прототип для каждого кластера. После встраивания опорных изображений векторы усредняются для формирования прототипа класса, своего рода «делегата» для этого класса.

где v (k) - прототип класса k, f_phi - функция встраивания, а xi - опорные изображения.

Вычислить расстояния между запросами и прототипами

Этот шаг состоит в классификации изображений запроса. Для этого мы вычисляем расстояние между каждым изображением без надписи и прототипами. Выбор метрики здесь имеет решающее значение, и изобретатели прототипических сетей должны быть признаны их выбором расстояния: евклидова расстояния.

После вычисления расстояний над ними выполняется softmax, чтобы получить вероятности принадлежности к каждому классу. Чем короче расстояние, тем выше вероятность.

Вычислите потерю и сделайте обратное распространение

Фаза изучения прототипных сетей происходит за счет минимизации отрицательной логарифмической вероятности, также называемой логарифмически-softmax потерями. Основное преимущество использования логарифма - резко увеличить потери, когда модель не может предсказать правильный класс.

Обратное распространение осуществляется через стохастический градиентный спуск (SGD).

Начать обучение

Вся описанная выше последовательность образует эпизод. А тренировочный этап состоит из нескольких эпизодов. Я попытался воспроизвести результаты оригинальной статьи. Вот настройки тренировки:

  • NC: 60 классов
  • Ns: 1 или 5 точек поддержки / класс
  • Nq: 5 точек запроса / класс
  • 5 эпох
  • 2000 серий / эпоха
  • Скорость обучения изначально составляет 0,001 и делится на 2 для каждой эпохи.

Тренировка длилась 30 минут.

Полученные результаты

После обучения ProtoNet мы можем протестировать его с новыми данными. Аналогичным образом отбираем образцы в тестовой выборке. Набор поддержки используется для вычисления прототипов, а затем каждая точка набора запроса помечается в соответствии с меньшим расстоянием до прототипов.

Для тестирования я пробовал 5- и 20-ходовые сценарии. Я набрал такое же количество точек поддержки и запросов, что и на этапе обучения. Испытания проводились на 1000 эпизодах.

Результаты представлены в таблице ниже. «5-ходовой 1 выстрел» означает Nc = 5 и Ns = 1.

Я получил те же результаты, что и в исходной статье, в некоторых случаях немного лучше. Это может быть связано со стратегией выборки, которая не указана в документе. Я использовал случайную выборку для каждого эпизода.

Надеюсь, Прототипные Сети больше не являются для вас секретом! Мы смогли создать классификатор изображений, используя всего несколько примеров для каждого класса.

Полный код доступен на моем GitHub: https://github.com/cnielly/prototypical-networks-omniglot