Легко устраняйте дисбаланс классов с помощью Pytorch
Что вы можете сделать, если ваша модель переоценивает ваши данные? Эта проблема часто возникает, когда мы имеем дело с несбалансированным набором данных. Если ваш набор данных представляет несколько классов, один из которых представлен гораздо меньше, чем другие, тогда трудно узнать истинное базовое распределение, представляющее такой второстепенный класс.
Как объясняется в этой статье, которую необходимо прочитать, метод устранения дисбаланса классов, который стал доминирующим почти во всех проанализированных сценариях, - это передискретизация. Передискретизация должна применяться к уровню, который полностью устраняет дисбаланс, тогда как Оптимальный коэффициент недостаточной дискретизации зависит от степени дисбаланса. В отличие от некоторых классических моделей машинного обучения, передискретизация не вызывает переобучения CNN.
На практике при обучении модели машинного обучения нужно выполнить несколько ключевых шагов:
- Разделите данные на набор для обучения / тестирования (80%, 20%).
- Обучите модель машинного обучения, сопоставив ее с данными обучения.
- Оцените производительность на тестовом наборе.
При использовании архитектур глубокого обучения принято разделять обучающие данные на пакеты, которые мы загружаем в нашу нейронную сеть во время обучения. Чтобы создать такие пакеты, мы обычно случайным образом выбираем из обучающего набора, следуя равномерному распределению по набору наблюдений.
Теперь нужна простая статистика. Предположим, у нас есть набор данных с 2 классами class_1 и class_2. Какова вероятность случайной выборки точки, скажем, из class_1?
Следуя равномерному распределению по множеству точек, такую вероятность легко выразить:
На практике дисбаланс классов в бинарной задаче возникает, когда у нас гораздо больше наблюдений из одного класса, чем из другого:
В результате имеем:
Другими словами, более вероятно, что точка будет нарисована из class_1, чем из class_2. Поскольку модель видит намного меньше class_2, неудивительно, что она не способна изучать полезные функции из такого класса…
Теперь, прежде чем погрузиться в кодирование, нам нужно понять ключевую идею искусственного увеличения данных. Мы хотим убедиться, что, искусственно увеличивая второстепенный класс, мы имеем:
Время пришло! Давайте создадим код для решения этой проблемы с помощью WeightedRandomSampler от Pytorch.
Набор данных: мы создаем набор данных из 900 наблюдений из class_major с меткой 0 и 100 наблюдений из class_minor с меткой 1. ( 90%, 10%)
Предположив, что мы построим 10 пакетов по 100 предложений в каждом, мы получим в среднем 10 предложений класса 1 и 90 предложений класса 0.
Как я могу легко сбалансировать вышеуказанное? Давайте напишем несколько строк кода, используя библиотеку Pytorch.
Из приведенного выше видно, что WeightedRandomSampler использует массив example_weights, который соответствует весам, присвоенным каждому классу. Цель состоит в том, чтобы присвоить более высокий вес второстепенному классу. Это повлияет на вероятность получения точки из каждого класса путем перехода от равномерного распределения к полиномиальному распределению.
Теперь мы можем подробно рассмотреть пакеты, содержащиеся в arr_batch, на практике каждый из них должен содержать 100 предложений. Для наглядности мы сосредоточимся только на этикетках.
Как видно из рисунка выше, теперь у нас есть сбалансированные пакеты данных. В результате во время обучения наша модель не будет видеть значительно больше одного класса над другим, и, следовательно, риски переобучения снижаются.
В итоге мы увидели, что:
- Избыточная выборка - ключевая стратегия устранения классового дисбаланса и, следовательно, снижения рисков переобучения.
- Случайная выборка из вашего набора данных - плохая идея, когда есть дисбаланс классов.
- Взвешенная случайная выборка с помощью WeightedRandomSampler перебалансирует наши классы обучающих данных путем передискретизации второстепенного класса.
В следующей статье мы погрузимся в реализацию WeightedRandomSampler, чтобы лучше понять схему взвешивания. Мы также применим передискретизацию в простом сценарии машинного обучения и проанализируем ее влияние на общую производительность.
Спасибо за чтение, пожалуйста, оставьте комментарий ниже, если у вас есть какие-либо отзывы! 🤗