Глубокие CNN, используемые для сегментации, часто страдают от исчезающих градиентов. Можем ли мы бороться с этим, вычисляя потери при различных уровнях выпуска?

Сегментация изображения влечет за собой разделение пикселей изображения на разные классы. Некоторые приложения включают идентификацию областей опухоли на медицинских изображениях, разделение участков земли и воды на изображениях с беспилотников и т. д. В отличие от классификации, где CNN выводят вектор оценки вероятности класса, для сегментации требуется, чтобы CNN выдавали изображение.

Соответственно, традиционные архитектуры CNN настраиваются для получения желаемого результата. Для сегментации изображений доступен ряд архитектур, включая преобразователи. Но помимо улучшения дизайна сети, исследователи постоянно экспериментируют с другими хаками, чтобы улучшить производительность сегментации. Недавно я наткнулся на работу, в которой кратко описывалась идея расчета потерь на множественных выходных уровнях (потеря глубокого наблюдения). В этом посте я хочу поделиться тем же.

Мы реализуем модель, аналогичную UNet, широко используемой архитектуре сегментации, и обучаем ее с потерей контроля, используя подкласс модели Keras. Вы можете обратиться к прикрепленным ссылкам Github/Kaggle для получения кода. Я предполагаю, что вы знакомы с основами Keras. Поскольку нам нужно многое охватить, я свяжу все ресурсы и пропущу некоторые вещи, такие как проигрыш в кости, обучение keras с использованием model.fit, генераторы изображений и т. д.

Давайте сначала начнем с понимания сегментации изображения.

1. Сегментация изображения

Проще говоря, сегментация — это классификация пикселей. Если на изображении есть кошка и собака, мы хотим, чтобы машина определяла пиксели кошки и собаки и помечала их как 1 (кошка) или 2 (собака) на выходе. Каждый второй пиксель (фон, шум и т. д.) равен 0. Для обучения таких моделей мы используем пары изображений и масок.

Допустим, вы хотите идентифицировать опухоли головного мозга на МРТ. Сначала вы создадите обучающий набор положительных (опухоли) и отрицательных (неопухолевых) изображений. Для каждого вы затем создадите соответствующую маску. Как это делается? Сделайте МРТ-сканирование, найдите область опухоли, преобразуйте все значения пикселей в этой области в 0 и установите все остальные пиксели в 1. Естественно, маски без опухолей будут абсолютно черными. Модель, обученная на этих парах (Ввод = МРТ, Выход = Маска), будет работать для выявления опухолей при МРТ-сканировании. Полезно, не так ли?

Теперь давайте углубимся в архитектуру нейронной сети, необходимую для сегментации изображений.

2. Унет

Обычно CNN ловко определяют, что присутствует на изображении. Для сегментации CNN также необходимо научиться точно позиционировать составляющие изображения. UNet оборудован именно для этого. В оригинальной статье UNet она описывается как сеть, разделенная на две части — сжимающуюся (кодировщик) и расширяющуюся (декодер). Начнем с части кодировщика (обратите внимание, я внес небольшие изменения в архитектуру, представленную в документе UNet).

# Important Libraries to import
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.models import Model, load_model, save_model
from tensorflow.keras.layers import Input, Activation, BatchNormalization, Lambda, Conv2D, Conv2DTranspose,MaxPooling2D, concatenate,UpSampling2D,Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint,ReduceLROnPlateau
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import KFold
from tensorflow.keras.losses import BinaryCrossentropy
import random

2.1 Энкодер

Кодер работает как обычный CNN. Он постоянно разбивает входные данные, чтобы различать функции, связанные с объектами на изображении. Этот процесс повторяется в нескольких блоках (блоки кодировщика). Каждый блок состоит из следующего:

  1. Два свёрточных слоя с отступами и ядрами (3,3) последовательно (мы будем называть это свёрточным блоком). Можно также включать слои нормализации/выпадения партий, где это необходимо [в исходной статье использовались свертки без подкладок]. Мы будем использовать relu в качестве функции активации.
  2. Слой максимального объединения с шагом 2, чтобы сжать изображение.
# Functions to build the encoder path
def conv_block(inp, filters, padding='same', activation='relu'):
    """
    Convolution block of a UNet encoder
    """
    x = Conv2D(filters, (3, 3), padding=padding, activation=activation)(inp)
    x = Conv2D(filters, (3, 3), padding=padding)(x)
    x = BatchNormalization(axis=3)(x)
    x = Activation(activation)(x)
    return x


def encoder_block(inp, filters, padding='same', pool_stride=2,
                  activation='relu'):
    """
    Encoder block of a UNet passes the result from the convolution block
    above to a max pooling layer
    """
    x = conv_block(inp, filters, padding, activation)
    p = MaxPooling2D(pool_size=(2, 2), strides=pool_stride)(x)
    return x, p

Если вы заметили, encoder_block возвращает два значения — изображение до и после максимального объединения. UNet передает последний (p) в следующий блок и сохраняет первый (x) в памяти. Почему? Мы объясним это позже.

Выше приведена копия первого блока энкодера, описанного в статье UNet. Он состоит из двух сверточных слоев с последовательно примененными 64 фильтрами, за которыми следует слой максимального объединения (обозначен зеленой стрелкой вниз). Мы повторим это с обсуждаемыми выше модификациями. Спецификация входной формы не является обязательной для таких сетей, как UNet, которые не включают плотный (плоский) слой, тем не менее, мы определим входную форму как (256,256,1).

# Building the first block
inputs = Input((256,256,1))
d1,p1 = encoder_block(inputs,64)

За этим первым блоком следуют еще три подобных блока с фильтрами = 128 256 512. Вместе четыре блока образуют путь сокращения/кодер.

#Building the other four blocks
d2,p2 = encoder_block(p1,128)
d3,p3 = encoder_block(p2,256)
d4,p4 = encoder_block(p3,512)

1.2 Средняя часть

Да, мы обсуждали, что сеть состоит из двух частей, но я предпочитаю рассматривать средний компонент отдельно.

Он берет максимально объединенный вывод из предыдущего блока кодировщика и пропускает его через две последовательные (3,3) свертки 1024 фильтров. Так же, как блок свертки, спросите вы? Да, только в этот раз вывод не проходит за max pooling. Следовательно, мы будем использовать conv_block вместо encoder_block для создания среднего раздела.

# Middle convolution block (no max pooling)
mid = conv_block(p4,1024) #Midsection

Этот окончательный результат теперь будет повышенной дискретизации.

1.3 Декодер

Превратив изображение в многочисленные карты объектов, UNet имеет четкое представление о том, что находится на входном изображении. Он знает, какие классы (объекты) содержит изображение. Теперь ему нужно предсказать правильное расположение всех этих классов и соответствующим образом пометить их пиксели в конечном результате. Для этого UNet использует две ключевые идеи — пропуск соединений и повышение частоты дискретизации.

1.3.1 Транспонированная свертка

После обхода энкодера и прохождения среднего блока входные данные преобразуются в форму (16,16,1024) [Вы можете проверить это, используя API model.summary() keras]. Затем UNet применяет транспонированную свертку для повышения дискретизации вывода. Так что же такое транспонированная свертка? Посмотрите, отвечает ли изображение ниже на ваш вопрос.

По сути, мы умножаем веса ядра на каждую запись во входных данных и объединяем все (2,2) выходные данные, чтобы получить окончательный результат. При взаимных индексах числа складываются. Как и ядра свертки, веса в транспонированных ядрах свертки также поддаются обучению. Сеть изучает их во время обратного распространения, чтобы точно повышать дискретизацию карт объектов. Обратитесь к этой статье, чтобы узнать больше о транспонированных извилинах.

1.3.2 Пропустить соединения

Для каждого блока кодировщика UNet также имеет сопряженный блок декодера. Мы уже обсуждали, что блоки декодера учатся повышать дискретизацию изображений. Чтобы улучшить свое обучение и обеспечить правильное расположение пикселей в конечном выводе, декодеры обращаются за помощью к соответствующим кодировщикам. Они используют это в виде пропуска подключений.

Пропустить соединения — суть UNet. Если вы ранее работали с реснетами, вы должны быть знакомы с этой концепцией. В 1.1 мы обсуждали, что UNet хранит вывод (x) блока свертки в памяти. Эти выходные данные объединяются с изображениями с повышенной дискретизацией из каждого блока декодера. Горизонтальные стрелки на изображении ниже обозначают пропущенные соединения.

Исследователи считают, что по мере того, как входные изображения проникают глубже в сеть, более мелкие детали, такие как местонахождение различных объектов/классов в изображении, теряются. Пропускные соединения передают неуместную информацию из начальных слоев, что позволяет UNet создавать более качественные карты сегментации.

Объединенные изображения с повышенной дискретизацией затем передаются в блоки свертки (2 последовательных слоя свертки). Следовательно, мы используем следующую функцию для создания блоков декодера.

# Functions to build the decoder block
def decoder_block(inp,filters,concat_layer,padding='same'):
    #Upsample the feature maps
    x=Conv2DTranspose(filters,(2,2),strides=(2,2),padding=padding)(inp)
    x=concatenate([x,concat_layer])#Concatenation/Skip conncetion with conjuagte encoder
    x=conv_block(x,filters)#Passed into the convolution block above
    return x

1.4 Окончательная сеть UNet

Ниже показана наша последняя сеть UNet. Выход после e5 имеет вид (256,256,64). Чтобы сопоставить его с входом (256,256,1), мы будем использовать сверточный слой (1,1) с 1 фильтром. Посмотрите это видео Эндрю Нг, если вам интересны свертки 1,1.

# Bulding the Unet model using the above functions
inputs=Input((256,256,1))
d1,p1=encoder_block(inputs,64)
d2,p2=encoder_block(p1,128)
d3,p3=encoder_block(p2,256)
d4,p4=encoder_block(p3,512)
mid=conv_block(p4,1024) #Midsection
e2=decoder_block(mid,512,d4) #Conjugate of encoder 4
e3=decoder_block(e2,256,d3) #Conjugate of encoder 3
e4=decoder_block(e3,128,d2) #Conjugate of encoder 2 
e5=decoder_block(e4,64,d1) #Conjugate of encoder 1
outputs = Conv2D(1, (1,1),activation=None)(e5) #Final Output
ml=Model(inputs=[inputs],outputs=[outputs,o1],name='Unet')

2. Сегментация изображения и глубокий контроль

Ладно, пора применить то, что мы узнали.

Мы выполним сегментацию изображения в этой базе данных рентген грудной клетки covid-19 (основной набор данных). Он включает в себя четыре класса изображений — Covid, Normal, Light Opacity и Viral Pneumonia. В этом посте я просто хочу поделиться тем, как можно использовать потерю контроля и подкласс модели Keras для сегментации изображений. Производительность здесь не при чем.

Поэтому я сформулирую простую задачу. Мы возьмем изображения из класса Covid и разделим их пиксели на легкие и не легкие. (Обратите внимание, что мы не пытаемся научить модель выявлять пораженные коронавирусом области, а отображать пространство, занимаемое легкими). Однако, как правило, медицинская визуализация связана с чрезвычайно сложными случаями, такими как поиск органов, пораженных опухолью, и т. д. Хотя мы не будем их здесь рассматривать, вы можете использовать/модифицировать прилагаемый код для таких привлекательных приложений.

Мы используем следующий блок кода для получения путей изображения/маски из каталога

# Block to read image paths, will be used in image data generator
df = pd.DataFrame(columns=['img_path','msk_path','img_shape','msk_shape','class'])
for cat in ['COVID']:
    dir_ = f"../input/covid19-radiography-database/COVID-19_Radiography_Dataset/{cat}"
    for f in os.listdir(f"{dir_}/images"):
        s1 = cv2.imread(f"{dir_}/images/{f}",config.img_type_num).shape
        s2 = cv2.imread(f"{dir_}/masks/{f}",config.msk_type_num).shape
        dic={'img_path':f"{dir_}/images/{f}",'msk_path':f"{dir_}/masks/{f}",'img_shape':s1,
            'msk_shape':s2}
        df = df.append(dic,ignore_index=True)

Ниже приведены несколькоизображений и соответствующие им маски из набора данных. Маски, как обсуждалось, имеют два класса:

0: легкие

1: не легкие

2.1 Функция потери и глубокая потеря контроля

Обучающие маски имеют только два значения: 0 и 1. Следовательно, мы можем использовать бинарную кросс-энтропию для расчета потерь между ними и нашими конечными результатами. Теперь обратимся к слону в комнате — потеря надзора.

Проблема с глубокими нейронными архитектурами — потеря градиента. Некоторые архитектуры настолько глубоки, что градиенты исчезают при обратном распространении к начальным уровням, что приводит к минимальному смещению веса в начальных слоях, что приводит к недостаточному обучению. Кроме того, это делает модель непропорционально зависимой от более глубоких уровней производительности.

Для ускорения градиентного потока в этой статье предлагается вычислять потери на разных уровнях декодера. Как именно, спросите вы? Во-первых, как показано ниже, мы извлечем из сети дополнительный вывод o1. Мы возьмем результат от предпоследнего декодера (e4), который имеет форму (128,128,128), и сократим его до (128,128,1) с помощью сверточного фильтра (1,1).

# Adding output from 2nd last decoder block
inputs=Input((256,256,1))
d1,p1=encoder_block(inputs,64)
d2,p2=encoder_block(p1,128)
d3,p3=encoder_block(p2,256)
d4,p4=encoder_block(p3,512)
mid=conv_block(p4,1024) #Midsection
e2=decoder_block(mid,512,d4) #Conjugate of encoder 4
e3=decoder_block(e2,256,d3) #Conjugate of encoder 3
e4=decoder_block(e3,128,d2) #Conjugate of encoder 2 
o1 = Conv2D(1,(1,1),activation=None)(e4) # Output from 2nd last decoder
e5=decoder_block(e4,64,d1) #Conjugate of encoder 1
outputs = Conv2D(1, (1,1),activation=None)(e5) #Final Output

Затем мы добавим o1 в качестве вывода модели, добавив его в список вывода в API модели Keras.

# Adding output to output list in keras model API
ml=Model(inputs=[inputs],outputs=[outputs,o1],name='Unet')

Далее, чтобы рассчитать потери от этого уровня, нам также нужно изменить размер копии ввода до (128,128,1). Теперь окончательный проигрыш составит:

Мы также можем взять взвешенную комбинацию двух потерь. Вычисление потерь на разных уровнях также позволяет им получить лучшее приближение к конечному результату.

2.2 Обучение с использованием подкласса модели Keras

Ладно, осталось только тренироваться. Хотя технически вы можете передать список ввода и ввода с измененным размером в известном «ml.fit», я предпочитаю использовать подкласс модели keras. Это позволяет нам больше играть с функцией потерь.

Создаем сеть классов с наследованием от tf.keras.Model. Мы будем передавать используемую модель (ml), функцию потерь (бинарную кросс-энтропию), метрику (потери в кости) и веса потерь (к потерям веса с двух уровней декодера) при инициализации объекта класса network.

# Defining network class which inherits keras model class
class network(tf.keras.Model):
    
    def __init__(self,model,loss,metric,loss_weights):
        super().__init__()
        self.loss = loss
        self.metric = metric
        self.model = model
        self.loss_weights = loss_weights

Затем мы переопределим то, что происходит в ml.fit, используя функцию train_step. Весь поток входного изображения через сеть и вычисление потерь выполняются в рамках tf.GradientTape, который вычисляет градиенты всего в двух строках.

# Overriding model.fit using def train_step
def call(self,inputs,training):
        out = self.model(inputs)
        if training==True:
            return out
        else:
            if type(out) == list:
                return out[0]
            else:
                return out
    
    def calc_supervision_loss(self,y_true,y_preds):
        loss = 0
        for i,pred in enumerate(y_preds):
            y_resized = tf.image.resize(y_true,[*pred.shape[1:3]])
            loss+= self.loss_weights[i+1] * self.loss(y_resized,pred)
            return loss
    
    def train_step(self,data):
        x,y = data
        with tf.GradientTape() as tape:
            y_preds = self(x,training=True)
            if type(y_preds) == list:
                loss = self.loss_weights[0] * self.loss(y,y_preds[0])
                acc = self.metric(y,y_preds[0])
                loss += self.calc_supervision_loss(y,y_preds[1:])
            else:
                loss = self.loss(y,y_preds)
                acc = self.metric(y,y_preds)
        trainable_vars = self.trainable_variables #Network trainable parameters
        gradients = tape.gradient(loss, trainable_vars) #Calculating gradients 
        #Applying gradients to optimizer
        self.optimizer.apply_gradients(zip(gradients, trainable_vars)) 
        return loss,acc

Когда мы работаем с потерями контроля, сеть возвращает выходные данные в виде списка, и мы вызываем функцию calc_supervision_loss для вычисления окончательных потерь.

Точно так же мы можем переопределить шаг проверки

# Overriding validation step
def test_step(self,data):
        x,y=data
        y_pred = self(x,training=False)
        loss = self.loss(y,y_pred)
        acc = self.metric(y,y_pred)
        return loss,acc

С этого момента все обыденно. Мы будем использовать Keras ImageDataGenerator для передачи пар изображение-маска для обучения.

# Keras Image data generator
def img_dataset(df_inp,path_img,path_mask,batch):
    img_gen=ImageDataGenerator(rescale=1./255.)
    df_img = img_gen.flow_from_dataframe(dataframe=df_inp,
                                     x_col=path_img,
                                     class_mode=None,
                                     batch_size=batch,
                                    color_mode=config.img_mode,
                                         seed=config.seed,
                                     target_size=config.img_size)
    df_mask=img_gen.flow_from_dataframe(dataframe=df_inp,
                                     x_col=path_mask,
                                     class_mode=None,
                                     batch_size=batch,
                                    color_mode=config.msk_mode,
                                        seed=config.seed,
                                     target_size=config.img_size)
    data_gen = zip(df_img,df_mask)
    return data_gen

Затем мы создаем наборы для обучения и проверки, устанавливаем оптимизатор, создаем экземпляр класса сети, который мы создали выше, и компилируем его. (Поскольку мы наследуем сеть классов от keras, мы можем напрямую использовать функциональность .compile)

train=img_dataset(train_ds,'img_path','msk_path',config.batch_size)
        val=img_dataset(val_ds,'img_path','msk_path',config.batch_size)
        opt = Adam(learning_rate=config.lr, epsilon=None, amsgrad=False,beta_1=0.9,beta_2=0.99)
        
        model = network(ml,BinaryCrossentropy(),dice_coef,[1,0.5])
        model.compile(optimizer=opt,loss=BinaryCrossentropy(),metrics=[dice_coef])

Переходим к тренировочному циклу

# Custom training loop
best_val = np.inf
        for epoch in range(config.train_epochs):
            epoch_train_loss = 0.0
            epoch_train_acc=0.0
            epoch_val_acc=0.0
            epoch_val_loss=0.0
            num_batches = 0
            for x in train:
                if num_batches > (len(train_ds)//config.batch_size):
                    break
                a,b = model.train_step(x)
                epoch_train_loss+=a
                epoch_train_acc+=b
                num_batches+=1
            epoch_train_loss = epoch_train_loss/num_batches
            epoch_train_acc = epoch_train_acc/num_batches
            num_batches_v=0
            for x in val:
                if num_batches_v > (len(val_ds)//config.batch_size):
                    break
                a,b = model.test_step(x)
                epoch_val_loss+=a
                epoch_val_acc+=b
                num_batches_v+=1
            epoch_val_loss=epoch_val_loss/num_batches_v
            if epoch_val_loss < best_val:
                best_val = epoch_val_loss
                print('---Validation Loss improved,saving model---')
                model.model.save('./weights',save_format='tf')
            epoch_val_acc=epoch_val_acc/num_batches_v
            template = ("Epoch: {}, TrainLoss: {}, TainAcc: {}, ValLoss: {}, ValAcc {}")
            print(template.format(epoch,epoch_train_loss,epoch_train_acc,
                                  epoch_val_loss,epoch_val_acc))

2.3 Результаты

Предсказанные маски довольно точны. Модель имеет оценку кости проверки 0,96 и потерю проверки 0,55. Однако, как уже говорилось, нам не следует слишком много вникать в эти значения, поскольку решаемая задача была грубой. Цель состояла в том, чтобы показать, как можно использовать потерю контроля. В упомянутой выше бумаге авторы использовали выходные данные трех декодеров для расчета окончательных потерь.

Спасибо, что дочитали до конца. Я надеюсь, что это позволит вам использовать потерю контроля в будущем. Проверьте мой Github/Kaggle, если вам понравилась работа.

Использованная литература:

https://arxiv.org/abs/1505.04597









https://linkinghub.elsevier.com/retrieve/pii/S001048252100113X