Легко устраняйте дисбаланс классов с помощью Pytorch

Что вы можете сделать, если ваша модель переоценивает ваши данные? Эта проблема часто возникает, когда мы имеем дело с несбалансированным набором данных. Если ваш набор данных представляет несколько классов, один из которых представлен гораздо меньше, чем другие, тогда трудно узнать истинное базовое распределение, представляющее такой второстепенный класс.

Как объясняется в этой статье, которую необходимо прочитать, метод устранения дисбаланса классов, который стал доминирующим почти во всех проанализированных сценариях, - это передискретизация. Передискретизация должна применяться к уровню, который полностью устраняет дисбаланс, тогда как Оптимальный коэффициент недостаточной дискретизации зависит от степени дисбаланса. В отличие от некоторых классических моделей машинного обучения, передискретизация не вызывает переобучения CNN.

На практике при обучении модели машинного обучения нужно выполнить несколько ключевых шагов:

  1. Разделите данные на набор для обучения / тестирования (80%, 20%).
  2. Обучите модель машинного обучения, сопоставив ее с данными обучения.
  3. Оцените производительность на тестовом наборе.

При использовании архитектур глубокого обучения принято разделять обучающие данные на пакеты, которые мы загружаем в нашу нейронную сеть во время обучения. Чтобы создать такие пакеты, мы обычно случайным образом выбираем из обучающего набора, следуя равномерному распределению по набору наблюдений.

Теперь нужна простая статистика. Предположим, у нас есть набор данных с 2 классами class_1 и class_2. Какова вероятность случайной выборки точки, скажем, из class_1?

Следуя равномерному распределению по множеству точек, такую ​​вероятность легко выразить:

На практике дисбаланс классов в бинарной задаче возникает, когда у нас гораздо больше наблюдений из одного класса, чем из другого:

В результате имеем:

Другими словами, более вероятно, что точка будет нарисована из class_1, чем из class_2. Поскольку модель видит намного меньше class_2, неудивительно, что она не способна изучать полезные функции из такого класса…

Теперь, прежде чем погрузиться в кодирование, нам нужно понять ключевую идею искусственного увеличения данных. Мы хотим убедиться, что, искусственно увеличивая второстепенный класс, мы имеем:

Время пришло! Давайте создадим код для решения этой проблемы с помощью WeightedRandomSampler от Pytorch.

Набор данных: мы создаем набор данных из 900 наблюдений из class_major с меткой 0 и 100 наблюдений из class_minor с меткой 1. ( 90%, 10%)

Предположив, что мы построим 10 пакетов по 100 предложений в каждом, мы получим в среднем 10 предложений класса 1 и 90 предложений класса 0.

Как я могу легко сбалансировать вышеуказанное? Давайте напишем несколько строк кода, используя библиотеку Pytorch.

Из приведенного выше видно, что WeightedRandomSampler использует массив example_weights, который соответствует весам, присвоенным каждому классу. Цель состоит в том, чтобы присвоить более высокий вес второстепенному классу. Это повлияет на вероятность получения точки из каждого класса путем перехода от равномерного распределения к полиномиальному распределению.

Теперь мы можем подробно рассмотреть пакеты, содержащиеся в arr_batch, на практике каждый из них должен содержать 100 предложений. Для наглядности мы сосредоточимся только на этикетках.

Как видно из рисунка выше, теперь у нас есть сбалансированные пакеты данных. В результате во время обучения наша модель не будет видеть значительно больше одного класса над другим, и, следовательно, риски переобучения снижаются.

В итоге мы увидели, что:

  1. Избыточная выборка - ключевая стратегия устранения классового дисбаланса и, следовательно, снижения рисков переобучения.
  2. Случайная выборка из вашего набора данных - плохая идея, когда есть дисбаланс классов.
  3. Взвешенная случайная выборка с помощью WeightedRandomSampler перебалансирует наши классы обучающих данных путем передискретизации второстепенного класса.

В следующей статье мы погрузимся в реализацию WeightedRandomSampler, чтобы лучше понять схему взвешивания. Мы также применим передискретизацию в простом сценарии машинного обучения и проанализируем ее влияние на общую производительность.

Спасибо за чтение, пожалуйста, оставьте комментарий ниже, если у вас есть какие-либо отзывы! 🤗