Глубокая адаптация домена с использованием PyTorch Adapt

Рассмотрим сценарий, в котором вы обучали модель цифрам MNIST:

И вам сказали, что ваша модель также должна работать с набором данных случайно окрашенных цифр:

Оказывается, у вас нет меток для этого набора данных, поэтому вы не можете использовать обучение с учителем для переобучения своей модели. Но вы можете использовать адаптацию предметной области, которая представляет собой тип алгоритма перепрофилирования существующих моделей для работы в других предметных областях.

Введите PyTorch Adapt, новую модульную библиотеку для адаптации к предметной области. Вы можете использовать его с ванильным PyTorch или с предоставленной оболочкой фреймворка. Давайте посмотрим, как это работает на задаче MNIST → MNIST-M.

Использование PyTorch Adapt для задачи MNIST → MNIST-M

Следующие фрагменты взяты из этой записной книжки Jupyter.

  1. Загрузите наборы данных и инициализируйте создателя загрузчика данных.

2. Настройте модели (G и C), алгоритм адаптации домена (DANN) и валидатор (IMValidator). Мы перемещаем большую часть модели классификатора (C) в ствол (G), потому что это лучше работает для DANN. Чтобы упростить наш код, мы будем использовать оболочку PyTorch Ignite.

3. Настройте хук визуализации. Для этой демонстрации я написал простую функцию для визуализации функций во время обучения. (Определение функции см. в блокноте.) Поскольку мы используем оболочку PyTorch Ignite, мы можем использовать систему обработки событий Ignite, чтобы добавить хук визуализатора.

4. Обучите модель. Здесь мы тренируем только 4 эпохи, хотя обычно для лучшей производительности требуется гораздо больше.

5. Сравните визуализации функций. Перед обучением функции MNIST (синие) хорошо сгруппированы, но мало перекрываются с функциями MNIST-M (оранжевые). После 4 периодов обучения перекрытие между двумя доменами увеличилось, что указывает на то, что модель адаптируется к новому домену.

6. Точность вычислений на МНИСТ-М… или нет? В реальном приложении вы не сможете вычислить точность, потому что целевые данные не имеют меток. Но для целей этой демонстрации мы все равно сжульничаем и проверим точность.

Наилучшая точность после 4 эпох составляет 65,6% по сравнению с начальной точностью 57,4%. Так что тренировки идут в правильном направлении.

Продолжение следует

В этом посте представлен краткий обзор PyTorch Adapt. В следующем посте я объясню, как легко настраивать алгоритмы с помощью модуля pytorch_adapt.hooks .

Для получения дополнительной информации, проверьте эти ссылки: