Вышел TorchMetrics v0.10, значительно изменивший весь пакет классификации. В этом сообщении блога будут рассмотрены причины, по которым пакет классификации нуждается в рефакторинге, что это означает для наших конечных пользователей и, наконец, какие преимущества он дает. Руководство по обновлению кода до последних изменений можно найти внизу.

Почему показатели классификации должны измениться

Нам давно известно, что изначально были проблемы с тем, как мы структурировали классификационный пакет. По сути, задачи классификации можно разделить на binary, multiclass или multilabel, и определить, для какой задачи пользователь пытается запустить данную метрику, сложно только на основе входных данных. Причина, по которой такой пакет, как sklearn, может это сделать, заключается в том, что он поддерживает ввод только в очень специфических форматах (без многомерных массивов и без поддержки целых форматов и форматов вероятности/логита).

Это означало, что некоторые метрики, особенно для бинарных задач, могли рассчитывать что-то отличное от ожидаемого, если бы пользователь предоставил другую форму, но не ожидаемую. Это противоречит основной ценности TorchMetrics, состоящей в том, что наши пользователи, конечно же, должны верить в то, что оцениваемая ими метрика дает исключительный результат.

Кроме того, метрикам классификации не хватало согласованности. Для одних показателей num_classes=2 означало двоичное, а для других num_classes=1 означало двоичное. Подробнее о причинах этого рефакторинга вы можете прочитать в этой и этой проблеме.

Решение

Мы решили разделить каждую метрику классификации на три отдельные метрики с префиксами binary_*, multiclass_* и multilabel_*. Это решает многие из вышеперечисленных проблем из коробки, потому что нам становится легче соответствовать ожиданиям наших пользователей для любой заданной входной формы. Кроме того, он имеет некоторые другие преимущества как для нас, как для разработчиков, так и для конечных пользователей.

  • Удобство сопровождения: разделяя код на три отдельные функции, мы (надеемся) снижаем сложность кода, облегчая поддержку кодовой базы в долгосрочной перспективе.
  • Скорость: полностью убрав автоматическое определение задачи во время выполнения, мы можем значительно увеличить скорость вычислений (подробнее об этом позже).
  • Аргументы, зависящие от задачи: путем разделения на три функции мы также уточняем, какие входные аргументы влияют на вычисляемый результат. Возьмите Accuracy в качестве примера: оба num_classes , top_k , average являются аргументами, которые влияют, если вы выполняете multiclass классификацию, но ничего не делаете для binary классификации, и наоборот с аргументом thresholds. Версии для конкретных задач содержат только те аргументы, которые влияют на данную задачу.

В рефакторинге скрыто много небольших улучшений качества жизни, однако вот наши лучшие 3:

Стандартизированные аргументы

Входные аргументы для пакета классификации теперь гораздо более стандартизированы. Вот несколько примеров:

  • Каждая метрика теперь поддерживает только аргументы, влияющие на конечный результат. Это означает, что num_classes удалено из всех binary_* метрик, теперь требуется для всех multiclass_* метрик и переименовано в num_labels для всех multilabel_* метрик.
  • Аргумент ignore_index теперь поддерживается ВСЕМИ метриками классификации и поддерживает любое значение, а не только значения в диапазоне [0,num_classes] (аналогично функциям потери факела). Ниже показан пример:

  • Мы добавили новый validate_args ко всем метрикам классификации, чтобы пользователи могли пропускать проверку входных данных, что полностью ускоряет вычисления. По умолчанию мы по-прежнему будем выполнять проверку ввода, потому что это самый безопасный вариант для пользователя. Тем не менее, если вы уверены, что входные данные для метрики верны, теперь вы можете отключить это, проверив потенциальное ускорение (подробнее об этом позже).

Реализации постоянной памяти

Одними из наиболее полезных показателей для оценки проблем классификации являются такие показатели, как ROC, AUROC, AveragePrecision и т. д., поскольку они не только оценивают вашу модель для одного порогового значения, но и для целого диапазона пороговых значений, по сути, давая вам возможность увидеть торговлю. между Ошибками типа I и типа II. Однако большая проблема со стандартной формулировкой этих показателей (которую мы использовали) заключается в том, что для их расчета требуется доступ ко всем данным. Наша реализация была чрезвычайно требовательна к памяти для таких метрик.

В версии 0.10 TorchMetrics все эти показатели теперь имеют аргумент с именем thresholds. По умолчанию это None, и метрика по-прежнему будет сохранять все цели и прогнозы в памяти, как вы привыкли. Однако, если этот аргумент вместо этого установлен на тензор - torch.linspace(0,1,100), он вместо этого будет использовать приближение с постоянной памятью, оценивая метрику при указанных пороговых значениях.

Параметр thresholds=None имеет приблизительный объем памяти, равный O(num_samples), тогда как использование thresholds=torch.linspace(0,1,100) имеет приблизительный объем памяти, равный O(num_thresholds). В этом конкретном случае пользователи будут экономить память, когда метрика вычисляется для более чем 100 выборок. Эта функция может сэкономить память, сравнивая ее с современным машинным обучением, где оценка часто выполняется на тысячах или миллионах точек данных.

Это также означает, что метрики Binned*, которые в настоящее время существуют в TorchMetrics, устарели, поскольку их функциональность теперь определяется этим аргументом.

Все метрики быстрее (да)

Разбивая каждую метрику на 3 отдельные метрики, мы уменьшаем количество необходимых вычислений. Поэтому мы изначально ожидали, что наши новые реализации будут быстрее. В таблице ниже показано время различных метрик в старой и новой реализациях (с проверкой ввода и без нее). Цифры в скобках обозначают ускорение по сравнению со старыми реализациями.

Можно сделать следующие наблюдения:

  • Некоторые метрики стали немного быстрее (в 1,3 раза), а другие намного быстрее (в 4,6 раза) после рефакторинга!
  • Отключение проверки ввода может действительно ускорить работу. Например, multiclass_confusion_matrix ускоряется с 3,36x до 4,81, когда проверка ввода отключена. Явное преимущество для пользователей, знакомых с метриками и не нуждающихся в проверке своих данных при каждом обновлении.
  • Если сравнить binaryс multiclass, то наибольшее ускорение можно увидеть для мультиклассовых задач.
  • Все показатели быстрее, за исключением кривой точности-отзыва, даже новый метод аппроксимативного биннинга. Это немного странно, так как неаппроксимация должна быть такой же быстрой (это тот же код). Мы активно изучаем это.

Как обновить код?

Ни одно из вышеперечисленных изменений не приведет к каким-либо критическим изменениям в версии 0.10, и пользователи должны ожидать, что их код будет работать нормально. Единственным изменением является следующее предупреждение

Этим предупреждением мы хотим сообщить пользователям, что начиная с версии 0.11 будут внесены существенные критические изменения. Начиная с версии 0.11, все текущие метрики будут обертками вокруг своих binary_*, multiclass_* и multilabel_* аналогов, где task аргумент должен быть установлен. Обновление вашего кода должно быть довольно простым. Давайте рассмотрим конкретный пример:

from torchmetrics import Accuracy
metric = Accuracy(num_classes=10)

Вот как будет инициализирована метрика мультиклассовой точности до версии 0.10. Начиная с версии 0.10 и выше рекомендуемый способ инициализации одного и того же примера — импорт метрики MulticlassAccuracy.

from torchmetrics.classification import MulticlassAccuracy
metric = MulticlassAccuracy(num_classes=10)

Важно отметить, что все специализированные метрики классификации необходимо импортировать на один уровень выше, чем их неспециализированные, например, вместо from torchmetrics import use from torchmetrics.classification import . Равным способом обновления кода является инициализация Accuracy класса с новым аргументом task.

from torchmetrics import Accuracy
metric = Accuracy(task='multiclass', num_classes=10)

это вернет равную копию MulticlassAccuracy . Начиная с версии 0.11, все метрики базовой классификации будут просто оболочками для своих специализированных аналогов.

Мы надеемся, что эти критические изменения не станут слишком большими проблемами для сообщества.

Наконец, мы хотим прояснить, что этот рефакторинг имеет значительный размер:

Если есть какая-то функция, которую мы упустили в рефакторинге, или у вас есть ошибки, о которых нужно сообщить, не стесняйтесь создавать задачу в репозитории.

Спасибо!

Большое спасибо всем членам нашего сообщества за их вклад и отзывы. Если у вас есть какие-либо рекомендации для следующей метрики, которую мы должны попытаться решить, пожалуйста, откройте вопрос в репозитории.

Мы рады видеть продолжающуюся адаптацию TorchMetrics в более чем 2400+ проектах, и этот выпуск также отмечает точку, в которой мы пересекли 1000+ звезд GitHub.

Наконец, если вы хотите попробовать открытый исходный код, у нас есть каналы #new_contributors и #metrics на общем слабом канале PyTorch-lightning, где вы можете задать вопрос и получить рекомендации.

🔥 Смотри! 🚀