Недавно мне пришлось выступить с несколькими докладами о машинном обучении, глубоком обучении и компьютерном зрении. Я начал работать над набором слайдов и посмотрел, что делают другие в Интернете, чтобы передать такие концепции, как контролируемое обучение, CNN и т. Д.

Я нашла действительно красивые картинки и схемы. Они имели смысл для меня как для специалиста по данным, но по большей части они казались слишком абстрактными, чтобы передать их непосредственно нетехнической аудитории.

Кроме того, люди лучше учатся, делая, чем слушая. Поэтому я всегда стараюсь иметь какое-то взаимодействие с аудиторией, чтобы сделать мои презентации более интересными.

Что, если бы я мог провести весь процесс машинного обучения (сбор данных, обучение, тестирование) вживую на сцене?

Это сделало бы идею более осязаемой и конкретной. Так я получил вдохновение для разработки Deep Hive.

Полный исходный код доступен по адресу:
https://github.com/wouterdewinter/deep-hive

Глубокий улей

Основная концепция довольно проста:

  1. Люди в аудитории маркируют изображения и одновременно;
  2. Модель обучается с этими аннотациями и;
  3. Результаты теста отображаются на главном экране.

Задача - изучить модель машинного обучения для классификации изображений по ряду классов. Набор данных Dogs and Cats по умолчанию состоит из двух классов. С пользовательскими наборами данных возможно больше классов, но для большего количества классов потребуется больше данных для достижения разумной производительности.

Тренировочные и тестовые наборы

Чтобы предотвратить утечку информации, набор данных разделен на обучающий и тестовый набор. Изображения в обучающем наборе показываются пользователям в аудитории для маркировки. Изображения в тестовом наборе используются для оценки точности. Это изображения, отображаемые на приборной панели.

Для каждой метки изображения, отправленной пользователем, модель оценивает одно тестовое изображение. В качестве заявленной точности принимается средний балл последних 64 тестовых изображений.

Архитектура

Давайте погрузимся в мельчайшие подробности приложения.

Рабочий

Рабочий - это буквально рабочая лошадка приложения. Этот скрипт python работает в фоновом режиме и содержит модель Keras. Он получает сообщения из очереди сообщений, такие как новые аннотации. Он взаимодействует с моделью и отправляет статистику точности обратно в очередь сообщений.

Скрипт выполняется в одном потоке. У вас не может быть нескольких рабочих процессов, работающих с балансировкой нагрузки, потому что вам нужно, чтобы модель училась на каждой новой аннотации, а не только на ее части. Есть способы обучать модели распределенным способом, но это более сложный и не необходимый для небольшого количества данных, используемых здесь.

Модель

Для задач классификации изображений сверточные нейронные сети (CNN) являются золотым стандартом. Обычно для хорошей работы CNN требуется много обучающих данных. У нас нет такой роскоши, потому что у нас очень ограниченное время, чтобы позволить аудитории маркировать изображения. К счастью, мы можем значительно сократить объем необходимых данных с помощью «Transfer Leaning .

Используя эту технику, мы можем повторно использовать уже обученные слои из другой сети, поместить наши собственные слои поверх и обучать только эти последние слои.

Модель Deep Hive добавляет три уровня к стандартной модели VGG-16:

Layer (type)                 Output Shape              Param #
=================================================================
vgg16 (Model)                (None, 4, 4, 512)         14714688
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328
_________________________________________________________________
dropout_1 (Dropout)          (None, 256)               0
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 514
=================================================================
Total params: 14,846,530
Trainable params: 131,842
Non-trainable params: 14,714,688
_________________________________________________________________

Всего 14 846 530 параметров, но нам нужно обучить только 131 842 из них. Милая!

Поскольку модели необходимо учиться на новых ярлыках изображений одну за другой, мы, по сути, проводим «онлайн-обучение». По этой причине я выбрал старомодный оптимизатор SGD (стохастический градиентный спуск) вместо Adam или другого оптимизатора с динамической скоростью обучения.

Очередь сообщений

Для простоты я выбрал Redis в качестве очереди сообщений. Redis - популярное хранилище ключей / значений в памяти, но имеет очень удобный механизм pubsub. Он действует как реле между рабочим и веб-сервером.

Веб сервер

Требования к веб-серверу довольно скромные. Обслуживайте несколько статических файлов, предоставляйте небольшой API-интерфейс для внешнего интерфейса и передавайте сообщения в Redis. Flask легкий и идеально подходит для этой задачи. Поскольку модель разделена через worker и Redis, вы можете без проблем запустить Flask в многопоточном режиме.

Внешний интерфейс

Панель управления довольно динамична и требует обновления нескольких частей при поступлении новых данных несколько раз в секунду. Простое приложение React выполнит эту работу. Приложение опрашивает новые данные каждые 300 мс. Это не очень хорошо масштабируется, но обычно одновременно работает только одна панель управления. Использование веб-сокетов было бы более масштабируемой альтернативой.

Аннотации

Экран аннотаций создан для мобильных устройств и показывает изображение члену аудитории и кнопку для каждого класса. Щелчок по одному из классов отправляет метку и немедленно запрашивает новое изображение для аннотации.

Щиток приборов

Панель приборов состоит из 4 частей:

  1. Сетка изображений показывает полный тестовый набор из 40 изображений. Он отображает изображение, прогнозируемую метку и цвет, указывающий на правильный или неправильный прогноз.
  2. Точность показывает точность набора тестов, а график на заднем плане показывает тенденцию. График визуализируется превосходной библиотекой d3.js.
  3. Счетчик аннотаций, ну… показывает количество аннотированных изображений.
  4. Короткий URL предназначен для входа аудитории со своего мобильного телефона. В своих презентациях я показываю QR-код на слайде перед демонстрацией Deep Hive, чтобы уберечь нескольких пользователей от набора текста.

Под коротким URL-адресом есть две кнопки:

  • Сброс приведет к сбросу модели и повторной оценке тестового набора;
  • Simulate имитирует аудиторию, помещая метки изображений с полным обучающим набором в очередь сообщений.

Наборы данных

Я попробовал приложение с двумя разными наборами данных.

Основным из них является Набор данных Kaggle Dogs vs.Cats. Благодаря этому я получил около 90% точности примерно 250 пользовательских аннотаций. Если в вашей аудитории 25 человек, это обычно занимает меньше минуты!

Для клиента я работаю над распознавателем цветов. Используя изображения из двух классов в их частном наборе данных, я получил более 90% примерно в 160 аннотациях. Вроде бы для модели это более легкая задача.

Заключение

Надеюсь, вам понравилось приложение и возможность заглянуть под капот. Если людям интересно, я могу добавить еще несколько функций в ближайшее время. Например, ползунок скорости обучения или некоторые другие элементы управления гиперпараметрами. Пожалуйста, дайте мне знать в комментариях, что бы вы хотели от следующей версии.

Не стесняйтесь использовать Deep Hive для своих презентаций. Мне любопытно узнать, как это происходит!