Обучение с учителем: методы классификации и регрессии

Это новый день в Werkroom, и мы здесь, чтобы наконец применить на практике некоторые алгоритмы машинного обучения! Добро пожаловать обратно в гонку машинного обучения RuPaul, где мы должны научить компьютеры уметь отличать хороший наряд для подиума от плохого. По сути, мы пытаемся выяснить, сможем ли мы уволить организаторов Fashion Photo RuView. В предыдущей статье мы говорили о настройке наших данных для обучения с учителем, а также о том, как проверить, что наша модель действительно работает. Теперь мы собираемся использовать эти стратегии для настройки и оценки различных типов алгоритмов классификации и регрессии.

Теперь, прежде чем мы начнем: я никогда не смогу объяснить математику, лежащую в основе какой-либо статистической техники, лучше, чем видео StatQuest. В мире машинного обучения и статистики Джош Стармер — мама, и зачем мне тратить время на попытки получить эту корону? Вместо этого я дам обзор каждой техники и того, как она работает на практике, но если вы хотите потрясающее, упрощенное, визуально понятное статистическое объяснение для каждого метода, просто перейдите по ссылкам в заголовках техник. Я также ссылаюсь на блокнот Colab, где я обрабатываю изображения и запускаю каждый из этих алгоритмов, если вы хотите продолжить.

Классификация

Цель алгоритмов классификации — использовать определенную стратегию для поиска наилучшего способа распределения данных по разным категориям. В этом случае мы пытаемся классифицировать изображение взлетно-посадочной полосы как «Toot» или «Boot».

Машины опорных векторов

Первый тип алгоритма классификации, который мы собираемся использовать, — это Машина опорных векторов (SVM). По сути, эта королева математически находит наилучшую «границу» между Toot и Boot. Представьте что-то вроде рисунка 1.

Таким образом, если бы у нас был новый наряд, который находился по ту же сторону от этой пунктирной линии (границамежду Тутс и Ботинками), что и красные наряды, мы бы решили, что этот новый наряд, который мы знаем, не о чем — это Boot.

Но как мы узнаем, куда в первую очередь должен идти наряд? Ну, во-первых, эта маленькая штучка, которую я нарисовал здесь, немного неправда. Вы знаете, как этот график имеет ось X и ось Y? Чтобы проиллюстрировать, что происходит за кулисами в SVM, нам понадобится ось для каждого пикселя изображения. К сожалению, я не очень хорошо рисую графики за пределами третьего измерения, так как я всего лишь человек. Но у наших компьютеров нет проблем с перемещением за пределы сотого измерения, поэтому они могут это понять. Только представьте, что это происходит так: ей удается провести некую 100-мерную разделительную «линию» между двумя категориями.

Опять же, если вы хотите разобраться в математике всего этого или посмотреть, как мы на самом деле вычисляем эту граничную линию, посмотрите видео об этом от Mother Starmer, ссылка на которое приведена в начале этого раздела. Но мы здесь не для этого, мы здесь для того, чтобы проникнуть в код!

Я использую только данные 14-го сезона, чтобы мы могли попрактиковаться на меньшем подмножестве всех доступных данных, которые почти бесконечны. Давайте посмотрим на общее распределение рейтингов, которые Раджа и Готтмик дали этим взлетно-посадочным полосам, на рисунке 2. В качестве быстрого напоминания я создал эти графики в этом блокноте Colab, если вы все еще хотите следить за ними.

В целом, мы получаем гораздо больше Тутов, чем Ботинков. Это может нарушить наш алгоритм: она работает намного лучше, когда в каждой категории примерно равное количество данных. Посмотрим, что из этого выйдет через секунду, но сначала нам нужно решить, что делать с категориями «Лучшие кадры» и «Лучшие кадры». Они лучше, чем обычные Toots, но для простоты я просто сгруппирую их с Toots (рис. 3). В этом случае намного проще работать всего с двумя категориями.

Я проверил нашу модель на рейтингах Раджи и Готтмика и вот наши результаты:

Raja:
Accuracy:  0.7391304347826086
Precision:  0.7391304347826086
Recall:  1.0

Gottmik:
Accuracy:  0.6956521739130435
Precision:  0.6956521739130435
Recall:  1.0

Теперь 73 % и 67 % точности и достоверности кажутся довольно хорошими, но 100 % отзыва — это смехотворно хорошо. Что здесь происходит — как он идеально идентифицировал всех Тутов? Давайте проверим матрицы путаницы на рисунке 4a-b.

Итак, похоже, что наши модели просто… догадались, что все это Тут. И можно ли их винить? Если бы я был машиной, ничего не понимающей, и я видел бы, что у меня будет более 50% шансов быть правильным, если бы я угадал, что наряд был Тутом, я бы догадался, что каждый раз . В этом суть машинного обучения: цель состоит только в том, чтобы создать модель, которая предсказывает лучше, чем случайный случай. Но хотя она может быть довольно точной в техническом смысле, она не так полезна для нас. Нам нужно найти еще несколько полезных алгоритмов.

"Логистическая регрессия"

Теперь мы попробуем кое-что под названием Логистическая регрессия. Технически это метод регрессии, о котором мы поговорим в следующем разделе, но на самом деле мы используем его как технику классификации. Это еще один метод, который создает «линию», чтобы показать разницу между Toots и Boots, но она отличается от SVM. Здесь происходит следующее: основываясь на пикселях, из которых состоят наши изображения одежды, мы вычисляем вероятность того, что что-то является Тутом или Ботинком. Сначала говорим ей несколько Тутов и Ботинков (с вероятностью оказаться Тутом как 1 и 0 соответственно). Затем она рисует кривую линию, как на рис. 5, чтобы определить, в какой момент мы переходим от Boot-worthy к Toot-worthy.

Теперь, если бы мы взяли новый наряд, наша модель увидела бы, где она ложится на изгиб. Если вероятность того, что она тут, больше 0,5 (50%), то мы предполагаем, что она тут!

Я запустил модели для каждого из наших хостов RuView, и вот результаты:

Raja:
Accuracy:  0.6521739130434783
Precision:  0.6666666666666666
Recall:  0.9333333333333333

Gottmik:
Accuracy:  0.6739130434782609
Precision:  0.6976744186046512
Recall:  0.9375

На этот раз мы не получаем 100% отзыва, так что у нас хорошее начало! Я также хотел бы отметить точность 66,6% и 69%, которую мы здесь получаем — не по какой-то конкретной математической причине, а просто потому, что это забавно.

Давайте проверим матрицы путаницы на рисунке 6a-b.

Итак... она предсказала, что некоторые из них были Ботинками! Она даже была права насчет некоторых из них! Мы можем сказать, что здесь происходит какое-то размышление — она просто немного глупа, так как ей все еще не хватало кучи сапог и даже несколько Тутов. Если бы у нас было больше данных о тренировках, возможно, она могла бы выступать немного лучше.

Регрессия

Мы пытались определить, является ли наряд Тутом или Сапогом, но в этих категориях есть немало вариаций. Пятерка сильно отличается от девятки, и технически я бы загудел каждую из них. Давайте посмотрим, можем ли мы использовать регрессию для прогнозирования числового рейтинга нарядов и посмотрим, работает ли она лучше, чем классификация.

"Линейная регрессия"

Линейная регрессия — это самый простой тип регрессии: поиск линии, которая показывает взаимосвязь между фотографиями одежды и их рейтингом.

Прежде чем мы начнем, мы посмотрим, как выглядит моя средняя система ранжирования. Я построил гистограмму числа каждого значения ранжирования на рисунке 7:

В том, как я ранжировал их в числовом выражении, гораздо больше разнообразия, чем в системе Toot/Boot, которую мы видели ранее. В любом случае, большинство из них попадают в диапазон от 5 до 8, и все они в конечном итоге являются Toots, хотя, как я уже говорил ранее, 5 НАМНОГО отличается от 8. Возможно, наша модель будет работать лучше с некоторой более подробной информацией.

Давайте посмотрим, насколько хорошо работает наша модель! Вот наши результаты:

Mean absolute error: 1.1609901663678173
Mean squared error: 2.311689216514142
Root mean squared error: 1.5204240252357701
R^2 score: -0.12132137731400894

Давайте снова разберем их:

  • MAE: в среднем разница между фактическими значениями и предсказанными моделью составляет 1,16 балла. Не супер, но могло быть намного хуже. Было бы неплохо, если бы у нас была разница менее 1 очка, так что это может стать целью на будущее.
  • MSE/RMSE: существует значительная разница между ошибками. Таким образом, хотя некоторые пункты могут быть более точными (разница в пределах 1 балла), есть и такие, которые намного превосходят это значение.
  • R^2: Отрицательное значение R^2 на самом деле довольно ужасно — по сути, это означает, что мы могли бы получать более точные прогнозы, если бы угадывали случайным образом.

Мы можем построить остатки, чтобы увидеть, как это выглядит. На рисунке 8 горизонтальная линия в центре представляет нулевую ошибку, а все, что выше и ниже ее, представляет величину ошибки.

Мы видим, что есть довольно много точек, которые находятся близко к линии нулевой ошибки или находятся на ней. Это хорошо! Однако есть некоторые прогнозы, отличающиеся почти на 4 балла, что не является фантастикой. Есть над чем поработать, но мы добились определенного прогресса!

Логистическая регрессия, опять же

Давайте попробуем логистическую регрессию на этих числовых рейтингах! Предполагается, что эта стратегия обычно используется для категорийных данных, но мы можем рассматривать каждый рейтинг как «категорию», и она отлично работает.

С точки зрения метрик мы видим, что это немного хуже, чем линейная регрессия:

Mean absolute error: 1.2222222222222223
Mean squared error: 2.8518518518518516
Root mean squared error: 1.6887426837300736
R^2 score: -0.15828896867648878

Давайте посмотрим на график остатков на рисунке 9:

Разброс довольно похож на график линейной регрессии, но мы видим, что прогнозы всегда представляют собой целые числа, поскольку они размещают их в фиксированном категориальном ранжировании каждого числа.

Контролируемое обучение: Toot или Boot?

В целом, ни одна из рассмотренных нами моделей не добилась особого успеха. Решение проблемы с ошибкой модели не всегда состоит в том, чтобы всегда предоставить ей больше данных, но в данном случае это может помочь. Теперь, когда вы освоились, попробуйте собрать больше данных и добавить их к тому, что мы уже скомпилировали здесь. Улучшаются ли модели? Или нам нужно изучить что-то совершенно другое?

В следующей статье мы отойдем от обучения нашей модели пониманию моды и вместо этого позволим ей делать все, что она хочет. Пришло время обучения без присмотра!

Нажмите здесь для Часть 4!