MultiMix: экстремальное многозадачное обучение с умеренным контролем на основе медицинских изображений

В этой статье я расскажу о новом методе полуконтролируемой многозадачной медицинской визуализации под названием MultiMix, разработанном Аяаном Хаком (я), Абдуллой-Аль-Зубайром Имраном, Адамом Вангом и Деметри Терзопулосом. Наша статья была принята на ISBI 2021 в полном составе и представлена ​​на конференции в апреле. Расширение нашей статьи с улучшенными результатами также было опубликовано в MELBA Journal. В этой статье будет рассмотрен обзор методов, результатов и краткий обзор кода. Код доступен здесь.

Обзор:

MultiMix выполняет совместную полууправляемую классификацию и сегментацию, используя стратегию увеличения на основе достоверности и новый модуль моста значимости, который обеспечивает объяснимость для совместных задач. Модели на основе глубокого обучения при полном контроле могут быть эффективными при выполнении сложных задач анализа изображений, но эта производительность в значительной степени зависит от доступности больших наборов размеченных данных. Особенно в области медицинской визуализации метки дороги, требуют много времени и подвержены вариациям наблюдателя. В результате полууправляемое обучение, которое позволяет учиться на ограниченном количестве размеченных данных, было исследовано как альтернатива контролируемым аналогам.

Более того, изучение нескольких задач в рамках одной и той же модели еще больше улучшает обобщаемость модели. Кроме того, многозадачность позволяет изучать совместное представление между задачами, требуя при этом меньше параметров и меньше вычислений, что делает модели более эффективными и менее подверженными переоснащению.

Наши обширные эксперименты с различным количеством размеченных данных и данных из нескольких источников доказывают эффективность нашего метода. Кроме того, мы также представляем как внутридоменные, так и междоменные оценки задач, чтобы продемонстрировать потенциал нашей модели для адаптации к сложным сценариям обобщения, что является сложной, но важной задачей для методов медицинской визуализации.

Предыстория:

Проблема

Медицинская визуализация, основанная на обучении, в последние годы выросла в основном из-за развития глубокого обучения. Однако всегда сохраняется фундаментальная проблема глубокого обучения, заключающаяся в том, что для его эффективности требуется много размеченных данных. К сожалению, это еще более серьезная проблема в области медицинской визуализации, поскольку сбор больших наборов данных и аннотаций может быть трудным, поскольку они требуют экспертных знаний в предметной области, являются дорогостоящими, трудоемкими и трудными для организации в централизованных наборах данных. Более того, обобщение является ключевой проблемой в области медицинской визуализации, поскольку изображения из разных источников могут значительно различаться как качественно, так и количественно, что усложняет процесс построения модели, если мы хотим добиться высокой производительности в нескольких областях. Мы надеемся решить эти фундаментальные проблемы с помощью нескольких ключевых подходов, основанных на полуконтролируемом и многозадачном обучении.

Что такое частично контролируемое обучение?

Чтобы решить проблему ограниченности помеченных данных, полууправляемое обучение (SSL) привлекло большое внимание как многообещающая альтернатива. При полууправляемом обучении немаркированные примеры используются в сочетании с помеченными примерами, чтобы максимизировать получение информации. Было проведено множество исследований в области полуконтролируемого обучения, как общих, так и в области медицины. Я не буду подробно обсуждать эти методы, но вот список известных методов, на которые можно сослаться, если вам интересно [1, 2, 3, 4].

Другим решением для обучения с ограниченной выборкой является использование данных из нескольких источников, поскольку это увеличивает количество выборок в данных, а также разнообразие данных. Тем не менее, это сложно и требует специальных методов обучения, но если все сделано правильно, это может быть очень эффективным.

Что такое многозадачное обучение?

Многозадачное обучение (MTL) было исследовано для улучшения обобщаемости многих моделей. Многозадачное обучение определяется как оптимизация более чем одной потери в одной модели, так что несколько связанных задач выполняются посредством обучения с общим представлением. Совместное обучение нескольких задач в рамках модели улучшает обобщаемость модели, поскольку каждая из задач упорядочивает друг друга. Кроме того, если предположить, что обучающие данные поступают из разных дистрибутивов для разных задач с ограниченными аннотациями, многозадачность может быть полезна в таких сценариях для обучения практически без контроля. Сочетание многозадачности с полуконтролируемым обучением может повысить производительность и добиться успеха в этих двух задачах. Выполнение этих двух задач одновременно может быть чрезвычайно полезным, поскольку вместо специалистов с медицинским образованием одна модель глубокого обучения может выполнять обе задачи с удивительной точностью.

Что касается смежной работы в медицинской сфере, то я не буду слишком подробно останавливаться на методах, но вот список: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]. Однако основные ограничения этих работ заключаются в том, что они не используют данные из нескольких источников, что ограничивает их обобщаемость, а также в том, что большинство методов являются только однозадачными.

Таким образом, мы предлагаем новую, более обобщенную многозадачную модель под названием MultiMix, включающую в себя достоверную аугментацию и модуль моста значимости, чтобы совместно изучать диагностическую классификацию и сегментацию анатомических структур на основе данных из нескольких источников. Карта заметности позволяет анализировать предсказания модели посредством визуализации значимых визуальных признаков. Карта заметности может быть создана несколькими способами, в первую очередь путем вычисления градиента оценки класса на основе входного изображения. Хотя любую модель глубокого обучения можно исследовать на предмет лучшей объяснимости с помощью карты заметности, насколько нам известно, мост значимости между двумя общими задачами в одной модели еще не исследовался.

Алгоритм:

Начнем с определения нашей проблемы. Мы используем два набора данных для обучения, один для сегментации и один для классификации. Для данных сегментации мы можем использовать обозначения Xs и Y, которые представляют собой изображения и маски сегментации соответственно. Для данных классификации мы можем использовать обозначения Xc и C, которые являются изображениями и метками классов.

Что касается архитектуры нашей модели, мы используем базовую архитектуру U-Net, которая является широко используемой архитектурой сегментации с использованием структуры кодировщик-декодер. Кодер работает аналогично стандартной CNN. Чтобы выполнить многозадачность с U-Net, мы отходим от кодировщика с пулом и полностью подключенными слоями, чтобы получить окончательный результат классификации.

Классификация:

Для предлагаемого нами метода классификации мы используем увеличение данных и псевдомаркировку. Вдохновленные [1], мы берем неразмеченное изображение и выполняем две отдельные аугментации. Во-первых, неразмеченное изображение слабо дополнено, и из этой слабо дополненной версии изображения предполагается псевдометка на основе предсказания текущего состояния модели. Вот почему этот метод является полуконтролируемым, но о процессе псевдомаркировки мы поговорим чуть позже.

Во-вторых, то же самое точное неразмеченное изображение сильно дополнено, а потеря вычисляется с помощью псевдометки из слабо дополненного изображения и самого сильно дополненного изображения. По сути, мы учим модель сопоставлять слабо дополненное изображение с сильно дополненным изображением, и это заставляет модель изучать фундаментальные базовые функции, необходимые для диагностической классификации. Двойное увеличение изображений также максимизирует потенциальные знания, получаемые от одного изображения. Это также помогает улучшить обобщение, поскольку, если модель вынуждена изучать только самые важные части изображения, она сможет преодолеть различия, которые появляются на изображениях из-за разных доменов.

Что касается аугментаций, мы используем обычные аугментации для слабо дополненных изображений, таких как горизонтальное переворачивание и небольшие повороты. Стратегия сильного увеличения гораздо интереснее. Мы создаем пул нетрадиционных, сильных аугментаций и применяем случайное количество аугментаций к любому заданному изображению. Эти дополнения довольно искажают, например, включают обрезку, автоконтраст, яркость, контрастность, выравнивание, идентичность, вращение, резкость, сдвиг и многое другое. Применяя любое их количество, мы создаем чрезвычайно широкий спектр изображений, что особенно важно при работе с наборами данных с низкой выборкой. В конечном итоге мы обнаружили, что эта стратегия увеличения очень важна для высокой производительности.

Теперь вернемся назад и обсудим процесс псевдомаркировки. Поэтому, как только слабые дополнения преобразуются в псевдометки, мы используем только их. Обратите внимание, что эта метка изображения, если достоверность, с которой модель генерирует псевдометку, выше настроенного порога, что предотвращает обучение модели от неправильных и плохих этикетки. Это приводит к бесплатному эффекту учебной программы, поскольку, когда прогнозы вначале менее надежны, модель учится в основном на размеченных данных. Модель становится более уверенной с созданием меток для немаркированных изображений, и в результате модель становится более эффективной. Это также очень важная функция с точки зрения повышения производительности.

Теперь давайте посмотрим на функцию потерь. Потеря классификации может быть смоделирована следующим уравнением:

где L-sub-l — контролируемая потеря, c-hat-l — прогноз классификации, c-l — метка, лямбда — неконтролируемый вес классификации, L-sub-u — неконтролируемая потеря, c-hat-s — это предсказания на сильно расширенных изображениях, argmax(c-hat-w) — псевдометки слабо дополненных изображений, t — порог псевдомаркировки.

Это, по сути, резюмирует метод классификации, поэтому теперь давайте перейдем к методу сегментации.

Сегментация:

Для сегментации прогнозы выполняются через архитектуру кодер-декодер с пропуском соединений, что довольно просто. Нашим основным вкладом в сегментацию является включение модуля моста значимости для соединения двух задач, как показано на рисунке выше. Мы генерируем карты значимости на основе классов, предсказанных моделью, используя градиенты от кодировщика, расширенного до ветви классификации. Весь процесс показан выше, но, по сути, карта заметности показывает, какие части изображения модель использует для классификации изображения пневмонии. При визуализации они становятся похожими на карту сегментации, что делает ее идеальным дополнением для моста сегментации.

Хотя мы не знаем, представляют ли изображения сегментации пневмонию, сгенерированные карты выделяют легкие, создавая изображения с окончательным разрешением сегментации. Таким образом, когда предсказание класса изображения создается и визуализируется с помощью карты заметности, оно чем-то напоминает маску легких. Мы предполагаем, что эти карты заметности можно использовать для управления сегментацией на этапе декодера, обеспечивая улучшенную сегментацию при обучении на ограниченных размеченных данных.

В MultiMix сгенерированные карты значимости объединяются с входными изображениями, подвергаются субдискретизации и добавляются к картам объектов, входным на первом этапе декодера. Объединение с входным изображением обеспечивает более сильную связь между двумя задачами и повышает эффективность модуля моста из-за предоставляемого им контекста. Добавление как входного изображения, так и карты заметности предоставляет декодеру больше контекста и информации, которые могут быть действительно важны при работе с данными с малой выборкой.

Теперь поговорим о тренировках и потерях. Для меченых образцов мы обычно рассчитываем потери при сегментации, используя потери в кости между эталонной маской легкого и прогнозируемой сегментацией.

Поскольку у нас нет масок сегментации для немаркированных образцов сегментации, мы не можем напрямую рассчитать для них потери сегментации. Итак, для этого мы вычисляем расхождение KL между предсказаниями сегментации для помеченных и немеченых примеров. Это наказывает модель за то, что она делает прогнозы, которые все больше отличаются от прогнозов помеченных данных, что помогает модели лучше соответствовать неразмеченным данным. Хотя это косвенный метод вычисления потерь, он все же позволяет модели многому научиться на немаркированных данных сегментации.

Что касается потерь, наши потери сегментации можно записать как:

Где альфа — это вес потерь сегментации по сравнению с классификацией, y-hat-l — это предсказания помеченной сегментации, y-l — соответствующие маски, бета — это вес неконтролируемой сегментации, а y-hat-u — предсказания неразмеченной сегментации.

Наша модель обучена объединенной цели классификации и потери сегментации. Теперь, когда мы обсудили потери, это завершает как методы сегментации, так и весь раздел методов.

Наборы данных:

Модели были обучены и протестированы для задач классификации и сегментации, и данные для каждой задачи получены из двух разных источников: набора данных для обнаружения пневмонии, который мы назовем CheX [11], и Японского общества радиологических технологий, или JSRT [ 12] для классификации и сегментации соответственно. Когда мы говорим о наборах данных в домене, это два набора данных.

Важно проверить модели на двух внешних наборах данных, по одному для каждой задачи. Мы использовали рентгенографию грудной клетки округа Монтгомери или MCU [13] и подмножество набора данных рентгенографии грудной клетки NIH, которое мы будем называть NIHX [14]. Разнообразие источников представляет собой серьезную проблему для нашей модели, поскольку качество изображения, размер, соотношение нормальных и аномальных изображений и несоответствие в распределении интенсивности четырех наборов данных весьма различаются. На рисунках ниже показаны различия в распределении интенсивности вместе с примерами изображений из каждого набора данных. Все 4 набора данных имеют лицензию CC BY 4.0.

Результаты:

Мы провели множество экспериментов с различным количеством размеченных данных на нескольких наборах данных, как внутри, так и между доменами.

Чтобы предварить результаты, мы использовали несколько базовых показателей в наших тестах, поскольку у нас есть базовый уровень для каждого дополнения к нашей модели. Мы начнем с базовой U-Net и стандартного классификатора (enc), который представляет собой экстрактор функций кодировщика с плотными слоями. Затем мы объединили их для нашей базовой многозадачной модели (UMTL). Мы также использовали кодировщик с полууправляемым методом (EncSSL), многозадачную модель с мостом значимости (UMTLS) и многозадачную модель с мостом значимости и предлагаемым полууправляемым методом (UMTLS-SSL). , который в основном представляет собой MultiMix без расхождения KL для полуконтролируемой сегментации. Тогда у нас, конечно, есть MultiMix.

Что касается обучения, мы тренировались на нескольких уровнях помеченных наборов данных. Для классификации мы использовали 100, 1000 и все метки, а для сегментации использовали 10, 50 и все метки. Для наших результатов мы будем использовать обозначение: модель-разделители-метки классов (например, MultiMix-10–100). Для оценки мы использовали показатели точности (Acc) и F1 (F1-N и F1-P) для классификации, а для сегментации мы использовали сходство в кости (DS), показатель сходства Жаккара (JS), индекс структурного сходства (SSIM). , среднее расстояние Хаусдорфа (HD), точность (P) и отзыв (R).

На следующем рисунке представлена ​​таблица производительности MultiMix по сравнению с несколькими базовыми уровнями. Наилучшие результаты, полученные при полном контроле, подчеркнуты, а лучшие результаты, полученные при частичном контроле, выделены жирным шрифтом.

В таблице показано, как производительность модели улучшается при последующем включении каждого из новых компонентов. Для задачи классификации наш основанный на достоверности подход к обучению с полуучителем значительно улучшил производительность по сравнению с базовыми моделями. Даже при минимальном количестве размеченных данных для каждой задачи наш MultiMix-10–100 превосходит полностью контролируемый базовый кодировщик с точки зрения точности. Что касается сегментации, включение модуля моста значимости дает значительные улучшения по сравнению с базовыми моделями U-Net и UMTL. Даже с минимальными метками сегментации мы видим прирост производительности на 30% по сравнению с аналогами, что доказывает эффективность предложенной нами модели MultiMix.

Мы уделяем большое внимание важности обобщения, и наши результаты показывают, что наша модель способна достаточно хорошо обобщать. MultiMix стабильно работает в доменах с улучшенной обобщаемостью в обеих задачах. Как видно из таблицы, производительность MultiMix столь же многообещающая, как и у внутридоменных. MultiMix добился лучших результатов в задаче классификации по сравнению со всеми базовыми моделями. Из-за существенных различий в наборах данных NIHX и CheX, как обсуждалось ранее, оценки не так хороши, как результаты внутри предметной области. Однако он действительно работает лучше, чем другие модели.

На следующем рисунке показана диаграмма, показывающая согласованность результатов нашей сегментации как для внутридоменных, так и для междоменных оценок. Мы отображаем результаты игральных костей наших моделей для каждого изображения в наборе данных друг против друга. Из графика видно, что MultiMix является самой сильной моделью по сравнению с базовыми показателями.

Последний рисунок, который мы обсудим, — это визуализация прогнозов сегментации нашей модели. Мы показываем прогнозируемую границу для каждого из предложенных дополнений к задаче сегментации в сравнении с основной истиной для различных помеченных данных как для внутри-, так и для междоменных областей. На рисунке показано сильное совпадение с предсказаниями границ MultiMix относительно границы правды, особенно по сравнению с базовой линией. Особенно для междоменных задач MultiMix является лучшим со значительным отрывом, демонстрируя нашу сильную способность к обобщению.

Просмотр кода:

Теперь, когда мы рассмотрели методы и результаты, давайте приступим к коду. В основном я расскажу об архитектуре модели и цикле обучения, так как это основные области моего вклада. Обратите внимание, что код написан на PyTorch и Python.

Давайте начнем с проверки наших сверточных блоков.

Каждый блок является блоком двойной свертки. Мы начинаем с 2-мерного сверточного слоя с размером ядра 3, а затем используем слой нормализации экземпляра и функцию активации LeakyReLU с отрицательным наклоном 0,2. Затем мы повторяем эту последовательность снова, чтобы закончить блок свертки.

Теперь давайте взглянем на мост значимости.

Этот код используется только для создания карты значимости. Сначала мы передаем входные данные, кодировщик и оптимизатор. Затем мы создаем копию изображений, чтобы убедиться, что градиенты изображения не изменены. Затем мы устанавливаем для ввода require_grad значение true и устанавливаем кодировщик в режим eval. Затем мы получаем карты функций и выходные данные кодировщика, чтобы мы могли генерировать карты значимости. Сначала мы получаем максимальный индекс результатов классификации, а затем используем функцию .backward() для сбора градиентов. Затем мы получаем карту значимости, собирая градиенты с помощью функции .abs(). Важно отметить, что мы должны обнулить градиенты оптимизатора, потому что при обратном вычислении градиенты могут быть проблематичными при обновлении параметров модели.

Теперь, когда мы рассмотрели компоненты архитектуры, давайте соберем все вместе и проверим всю архитектуру.

Мы разделяем модель на отдельные модули кодировщика и декодера и объединяем их в классе MultiMix. Для кодировщика мы используем масштабирование блоков double_conv каждый раз в 2 раза. Глядя на прямую функцию, мы сохраняем карты объектов после каждого блока свертки, которые используются для пропуска соединений между кодировщиком и декодером, и мы используем слои с максимальным объединением для деконструкции изображения. Затем мы добавляем ветвь классификации для многозадачности, используя средний объединяющий слой и плотный слой, чтобы получить окончательный результат классификации (outC). Мы возвращаем все карты объектов, а также прогноз классификации для использования декодером.

Затем в декодере мы используем слои свертки, которые уменьшают карты объектов, и мы используем слои повышения дискретизации для восстановления изображения. Функция переадресации — это место, где происходит все волшебство. Мы начинаем с объединения и сложения карты заметности с исходным изображением. Затем мы уменьшаем входные данные, чтобы их можно было объединить в первом сверточном блоке вместе с пропускным соединением. Для следующих сверточных блоков мы просто выполняем стандартную деконволюцию и пропускаем соединения, чтобы получить окончательный результат (out).

После того, как мы построили модель, мы можем построить наш цикл обучения. Это довольно длинный и пугающий блок кода, так что не волнуйтесь, мы его разберем.

Прежде чем мы обсудим цикл, обратите внимание, что мы пропустили многие методы и цикл обучения, чтобы упростить его.

Если мы начнем со строки 52, мы начнем с объединения всех обучающих наборов данных, включая обучающий набор контролируемой сегментации, обучающий набор немаркированной сегментации, обучающий набор контролируемой классификации, слабо расширенный набор классификации и сильно расширенный набор классификации. Последние два имеют одинаковые точные изображения, но просто дополнены на разных уровнях. Следующий набор строк — это просто базовое объединение и объединение данных, чтобы все данные проходили через модель единообразно.

Как только мы передаем все входные данные в модель, мы передаем их все функции calc_loss. В функции calc_loss мы начинаем с получения базовой контролируемой классификации и потери сегментации (dice и lossClassifier). Мы используем потери в костях для сегментации и перекрестную энтропию для классификации.

Для полуконтролируемой классификации мы начинаем с передачи слабо дополненных предсказаний изображения через функцию softmax, чтобы получить вероятности, и мы используем функцию torch.max, чтобы получить метку. Затем мы используем функцию .ge, чтобы сохранить только те прогнозы, которые превышают доверительный порог, что является важным фактором, обсуждаемым в методах. Затем мы вычисляем неконтролируемую потерю классификации (lossUnlabeled).

Наконец, мы вычисляем расхождение KL, используя прогнозы сегментации с пометкой и без пометки (kl_seg). После завершения всех вычислений мы объединяем их в один расчет потерь, суммируя все потери после того, как они умножаются на соответствующие веса (лямбда, альфа, бета). Как только это передается обратно в основной цикл обучения, мы просто вычисляем градиенты с помощью loss.backward() и обновляем параметры модели с помощью optimizer.step().

На этом раздел обзора кода заканчивается. Мы не стали рассматривать часть аугментации и обработки данных, так как это довольно утомительно. Если вам интересно, ознакомьтесь с полным кодом в следующем репозитории: https://github.com/ayaanzhaque/MultiMix.

Заключение и мысли:

В этом сообщении блога мы объясняем MultiMix, новую модель многозадачного обучения с умеренным наблюдением для совместного изучения задач классификации и сегментации. Благодаря включению расширения согласованности и нового модуля моста значимости для лучшей объяснимости MultiMix выполняет улучшенное и последовательное обнаружение пневмонии и сегментацию легких даже при обучении на данных с ограниченным числом меток и данных из нескольких источников. Наши обширные эксперименты с использованием четырех различных наборов данных рентгенографии грудной клетки действительно демонстрируют эффективность MultiMix как при внутридоменных, так и при междоменных оценках для любой задачи. Наша будущая работа будет сосредоточена на дальнейшем улучшении междоменной производительности MultiMix, особенно для классификации. В настоящее время мы готовим полную запись в журнал с дополнительными результатами и расширениями работы.

Выполнение этой работы было для меня очень увлекательным. Будучи старшеклассником, я благодарен за возможность работать с квалифицированными и опытными исследователями для проведения передовых исследований. Весь процесс был для меня довольно сложным, так как у меня было мало опыта в том, как писать формальные исследовательские работы и проводить правильные и убедительные эксперименты. Даже кодирование и создание реальных дополнений заняло довольно много времени. Я все еще знакомился с PyTorch, но работать над этим проектом было очень весело и увлекательно, и я так много узнал о глубоком обучении и медицинской визуализации. Я в восторге от конференции, поскольку у меня есть возможность встретиться с коллегами-исследователями и узнать о новых исследованиях в этой области, и я уверен, что наша будущая работа будет иметь такой же успех, как и этот проект. Спасибо за чтение.

Если какая-либо часть этого блога или документа показалась вам интересной, рассмотрите возможность ссылки на:

@article{melba:2021:011:haque,
    title = "Generalized Multi-Task Learning from Substantially Unlabeled Multi-Source Medical Image Data",
    authors = "Haque, Ayaan and Imran, Abdullah-Al-Zubaer and Wang, Adam and Terzopoulos, Demetri",
    journal = "Machine Learning for Biomedical Imaging",
    volume = "1",
    issue = "October 2021 issue",
    year = "2021"
}