Мотивация

Огромное количество исследований, проведенных в области неконтролируемого машинного обучения и других важных областей, таких как теория информации, предполагает, что в немаркированных данных есть неиспользованный потенциал. Данные такого типа не только многочисленны, но и легко доступны. Меня особенно интересовала область полуконтролируемого обучения, потому что оно мало чем отличается от затруднительного положения, в котором человек оказывается, когда он подвергается воздействию мира [1]. Согласно теории когнитивного развития Пиаже [2], младенцы учатся и адаптируются к миру посредством сочетания как контролируемого, так и неконтролируемого опыта. Младенцы получают некоторые непосредственные указания и инструкции от своих опекунов (обучение под наблюдением), но они также самостоятельно исследуют окружающую среду и взаимодействуют с ней, осмысливая новую информацию и формируя свое понимание (обучение без присмотра).

Точно так же при полуконтролируемом обучении модель получает ограниченный объем помеченных данных, аналогично контролируемому руководству, которое дают младенцам. Эти помеченные данные предоставляют модели явные примеры для обучения и служат формой контроля. С другой стороны, модель также имеет доступ к более значительному количеству неразмеченных данных, представляющих фазу обучения без учителя. На этом этапе модель исследует и учится на неразмеченных данных, извлекая полезные модели и представления из окружающей среды, точно так же, как это делают младенцы посредством своих исследований и взаимодействий.

Поэтому я представляю новую стратегию для парадигмы классификации изображений, которая также использует такие концепции, как дистилляция знаний учитель-ученик [3], ансамблевое обучение [4], увеличение данных [5] и метаобучение [6]! Этот проект был выполнен в рамках моей диссертации бакалавра в Индийском институте научного образования и исследований в Бхопале (IISERB) под руководством доктора Акшая Агарвала. Более глубокий анализ вы найдете в тезисе, обязательно ознакомьтесь!

Пробелы в исследованиях

Несмотря на прогресс, достигнутый в SSL, все еще существуют заметные пробелы в исследованиях, которые необходимо устранить. Во-первых, выбор регуляризатора может сильно повлиять на производительность алгоритмов SSL. Хотя некоторые регуляризаторы, такие как минимизация энтропии[7] или регуляризация согласованности[8], показали многообещающие результаты, исследование новых регуляризаторов, которые могут еще больше улучшить производительность SSL, остается открытым вопросом. Включая изученную интерполяцию меток и метаобучение параметра лямбда, моя модель представляет новый регуляризатор, который побуждает модель ученика эффективно уравновешивать влияние нескольких моделей учителей. Этот подход обеспечивает более гибкую и адаптивную стратегию регуляризации по сравнению с традиционными фиксированными регуляризаторами. Это позволяет модели изучить оптимальный компромисс между разными учителями и повышает общую производительность алгоритма SSL.

Во-вторых, устойчивость к атакам злоумышленников — еще одна область, требующая внимания в SSL[9]. Атаки со стороны злоумышленников представляют собой серьезную угрозу для надежности и обобщения моделей, поэтому разработка методов SSL, которые могут эффективно обрабатывать примеры со стороны злоумышленников и поддерживать надежность, имеет решающее значение. Благодаря использованию метода смешивания псевдометок и включению прогнозов нескольких моделей учителей моя модель повышает надежность алгоритмов SSL. Комбинация различных прогнозов учителей помогает снизить уязвимость модели к враждебным возмущениям и повышает ее общую надежность.

Кроме того, изучение новых способов объединения размеченных и неразмеченных данных является постоянным направлением исследований. В то время как существующие методы, такие как совместное обучение [10] и самообучение [11], показали успех, существует потребность в более сложных методах, которые могут более эффективно использовать дополнительную информацию, представленную в размеченных и неразмеченных данных. Благодаря интерполяции изученных меток моя модель эффективно объединяет информацию из нескольких моделей учителей и создает более надежные псевдометки для немаркированных данных. Эта улучшенная комбинация размеченных и неразмеченных данных помогает модели ученика лучше использовать неразмеченные данные для обучения, что приводит к повышению эффективности классификации.

Цели

(1) Предложить новую структуру SSL, которая использует консенсус двух моделей CNN учителя для генерации псевдометки и смешивает предсказанную метку в определенных пропорциях (смягченные метки), когда между моделями учителя есть расхождение.

(2) Отчет и анализ производительности модели ученика, которая была обучена в соответствии с процедурой обучения, которая объединяет размеченные и псевдоразмеченные данные. Проведенные эксперименты будут проводиться на эталонных наборах данных, таких как CIFAR-10, CIFAR-100 [12], SVHN [13] и наборы данных MNIST [14].

(3) Реализовать стратегию «обучения своего собственного учителя», при которой многоголовой сети удается не только минимизировать ожидаемые потери на псевдометках и помеченных данных, но и изучать параметр смешения меток «лямбда», который выполняется вторым руководителем той же модели, где сеть мета-узнает, к какому учителю следует стремиться.

(4) Проведите анализ модели, сравнив ее с различными базовыми уровнями, используя различные комбинации нейронных сетей, и визуализируйте пояснительные графики, такие как t-SNE[15], Grad-CAM[16] и т. д.

Методология

Дистилляция ансамбля учителей с помощью псевдомаркировки и смешивания меток

Предлагаемая структура выглядит следующим образом: -
(1) Первоначально инициализируются две модели учителя (CNN), которые обучаются на размеченных данных. Эти учителя являются слабыми учениками, которые перегоняют свои знания в модель ученика.

(2) Обученные веса хранятся в файлах путей и доступны во время обучения студенческой модели CNN.

(3) Пакет данных выбирается из немаркированного набора данных, и выходные логиты оцениваются из обеих моделей учителя.

(4) Жесткие псевдометки назначаются для прогнозов, и если между учителями существует соглашение о метке входа, пара вход-псевдометка отправляется для обучения модели ученика с использованием перекрестной потери энтропии.

(5) Для экземпляров данных в пакете, где есть разногласия по поводу предсказанной метки (скажем, учитель 1 предсказывает кошку, а учитель 2 — собаку), метки перепутаны в определенных пропорциях. Интерполированная псевдометка и входная пара отправляются для обучения модели ученика с использованием перекрестной потери энтропии.

Извлекая знания из ансамбля моделей учителей, модель ученика выигрывает от более полного и точного понимания данных, даже если отдельные учителя могут быть слабыми по отдельности. Кроме того, смешивание пространства меток может улучшить производительность обобщения нейронной сети. Путем объединения двух выборок и соответствующих им меток модель поощряется к изучению линейных отношений между входными объектами и выходными метками, что может помочь ей лучше обобщать невидимые данные. Это особенно полезно в ситуациях, когда размеченные данные ограничены или зашумлены, а модель также должна учиться на неразмеченных данных. Коэффициент смешивания можно выбрать из бета-распределения.

Учитесь у своего собственного учителя: схема интерполяции меток на основе данных

Как мы видели, сводя к минимуму перекрестную энтропийную потерю перепутанных меток, модель поощряется к изучению большего количества отличительных признаков, которые могут лучше различать разные классы в наборе данных. Однако стоит отметить, что смешивание меток следует использовать с осторожностью, так как это не всегда может улучшить производительность, а в некоторых случаях может даже снизить ее, особенно если набор данных мал или если стратегия смешивания не продумана. Важно тщательно оценить производительность модели на проверочных или тестовых данных и поэкспериментировать с различными стратегиями смешивания, чтобы найти то, что лучше всего подходит для данной задачи и набора данных.

В отличие от исходной стратегии, в которой используется предопределенное распределение лямбда-коэффициента интерполяции и уникальное значение в каждом мини-пакете, я рассматриваю возможность использования разных коэффициентов интерполяции в каждом мини-пакете, чтобы улучшить разнообразие и сделать их все обучаемыми. Процедура, которую я использовал, выглядит следующим образом:

(1) Полностью подключенная линейная головка выдает одно значение от 0 до 1, которое является лямбдой, используемой для смешивания меток для стратегии, описанной ранее.

(2) Производительность предыдущей политики интерполяции меток проверяется на протянутом наборе валидации с метками. Я инициализирую две головки, давайте назовем одну головкой классификации, которая возьмет на себя потерю перекрестной энтропии в вышеупомянутой процедуре псевдомаркировки.

(3) Другая головка («лямбда-головка») принимает в качестве входных данных два решения учителя по немаркированному экземпляру и производит дробь от 0 до 1. По сути, я изучаю целевое распределение, для которого ожидаемая перекрестная энтропия ученика должна быть сведен к минимуму.

Более того, путем объединения логитов softmax учителя модель может мета-узнать, к каким прогнозам модели учителя должна стремиться модель.

Детали реализации

Ниже приведен класс, который определяет некоторые функции полезности, такие как этап обучения, когда учителя соглашаются и когда они не согласны, шаг метаобучения лямбда и некоторые другие необходимые функции. Их можно довольно легко вызвать в тренировочном цикле благодаря аккуратному ООП!

class ImageClassificationBaseMixup(nn.Module):
    
    # the loss for the lambda meta-learning head
    def training_step_NeuralNetwork(self,model,val_loader,teacher_inputs, device):
      model.train()
      val_loss = []
      for images,labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        (out,_) = model(images,teacher_inputs)
        loss = F.cross_entropy(out, labels)
        val_loss.append(loss) 
      
      loss_alpha = torch.stack(val_loss).mean()
      
      return loss_alpha

    # The loss when both the teachers agree on the same label for an instance
    def training_step_normal(self,images,labels,teacher_inputs,device):
        
        (out,_) = self(images,torch.zeros(teacher_inputs.shape).to(device))                  # Generate predictions
        loss_normal = F.cross_entropy(out, labels) # Calculate loss
        return loss_normal
    
    # The loss when teachers don't agree on the same label( labels are mixed up)
    def training_step_label_mixup(self,images,label_1,label_2,teacher_inputs,device):
        
        (preds,lam) = self(images,teacher_inputs)
        lam = lam.float()
        lam = Beta(lam,lam +0.001).sample()
        loss_mixup = lam*F.cross_entropy(preds,label_1)+(1-lam)*F.cross_entropy(preds,label_2)
        return (loss_mixup,lam)


    def validation_step(self,images,labels,teacher_inputs,device):
        
        (out,_) = self(images,torch.zeros(teacher_inputs.shape).to(device))                    # Generate predictions
        loss = F.cross_entropy(out, labels)   # Calculate loss
        acc = accuracy(out, labels)           # Calculate accuracy
        return {'val_loss': loss.detach(), 'val_acc': acc}
        
    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   # Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      # Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}
    
    def epoch_end(self, epoch, result):
        print("Epoch [{}], last_lr: {:.5f}, train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(
            epoch, result['lrs'][-1], result['train_loss'], result['val_loss'], result['val_acc']))

Ниже приведен код для пользовательского класса PyTorch, который определяет базу ResNet-9 и имеет 2 выхода: вывод классификации и лямбда. Он наследуется от класса ImageClassificationBaseMixup, определенного выше. В моем репозитории GitHub такие пользовательские классы созданы для разных типов CNN, это только один пример.

class ResNet9(ImageClassificationBaseMixup):
    def __init__(self, in_channels, num_classes,meta_input_size,meta_tensor):
        super().__init__()
        
        self.conv1 = conv_block(in_channels, 64)
        self.conv2 = conv_block(64, 128, pool=True)
        self.res1 = nn.Sequential(conv_block(128, 128), conv_block(128, 128))
        
        self.conv3 = conv_block(128, 256, pool=True)
        self.conv4 = conv_block(256, 512, pool=True)
        self.res2 = nn.Sequential(conv_block(512, 512), conv_block(512, 512))
        self.classifier = nn.Sequential(nn.MaxPool2d(3), 
                                        nn.Flatten(), 
                                        nn.Dropout(0.2),
                                        nn.Linear(512, num_classes))
        self.lam_head = LambdaHead(meta_input_size)

        
    def forward(self, xb, meta_tensor):
        out = self.conv1(xb)
        out = self.conv2(out)
        out = self.res1(out) + out
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.res2(out) + out
        out = self.classifier(out)
        lam = self.lam_head(meta_tensor)
        return (out,lam)

Для CIFAR-10 4 000 помеченных примеров хранятся как помеченные данные, а 41 000 примеров используются как немаркированные данные. Набор тестов для CIFAR-10 стандартный и состоит из 10 000 примеров. Для SVHN 4000 примеров используются как размеченные данные, тогда как около 73 257 примеров используются как неразмеченные данные. Набор тестов для SVHN также стандартный и содержит 26 032 примера. Размер пакета устанавливается равным 64 при обучении учителей и 128 при обучении ученика. Я использовал задержку проверки, разделенную на эти 4000 помеченных примеров из 1000 изображений, и обучение для моделей учителя выполняется с помощью оптимизатора Адама [17], а для ученика — с оптимизатором стохастического градиентного спуска. Я использовал оптимизаторы Adam и SGD с планировщиками скорости обучения с косинусным отжигом, со скоростью обучения 0,001 для MNIST и 0,01 для наборов данных CIFAR и снижением веса 0,0001, и обучил модель на 30–50 эпох.

Как упоминалось ранее, потери для мета-обучаемого (лямбда-голова) рассчитываются на наборе с проверочной маркировкой. Следовательно, в сети будет два обратных прохода и нам придется сохранить вычислительный граф для первого прохода. Второй обратный проход, по общему признанию, требует значительных вычислительных ресурсов, но, тем не менее, является полезной стратегией, потому что при регулярном смешивании меток много времени и вычислений должно быть посвящено выбору правильного гиперпараметра альфа, который, если его выбрать неразумно, может привести к недооценке. -подгонка и ухудшение производительности классификатора.

Результаты и обсуждения

Мои результаты показывают, что предложенный подход показал лучшую производительность, чем базовая модель CNN, обученная только на псевдоразмеченных данных одним учителем. В наборе данных CIFAR-10 исходная модель достигла точности 71,4% с 4000 помеченными примерами в тестовом наборе и 50000 немаркированными примерами по сравнению с точностью 63,2% базовой модели. Это огромная разница, учитывая, насколько маленьким был контролируемый набор. В наборе данных CIFAR-100 модель достигла точности 44,8 % с 4000 помеченными примерами и 50 000 немаркированных примеров по сравнению с точностью 44,2 % базовой модели. В наборе данных MNIST модель достигла точности 91,1 процента с 400 помеченными примерами и 59 900 неразмеченными примерами по сравнению с точностью 84,1 базовой модели с обучением псевдометкам с одним учителем.

Я также выполнил визуализацию Grad-CAM, чтобы визуализировать части изображения, на которых фокусируется модель ученика для каждого класса. Результаты показывают, что модель ученика может сосредоточиться в основном на правильных областях изображения для каждого класса, что указывает на то, что алгоритм изучает значимые представления. Для дальнейшего анализа изученных представлений я выполнил визуализацию t-SNE пространства признаков моделей учеников и моделей учителей. Результаты показывают, что кластерное пространство лучше разделяет модели учеников по сравнению с моделями учителей.

Единственная проблема, с которой я столкнулся, — это выравнивание значения лямбда до 1. Это означало бы, что модель переобучает лямбда одному из учителей. Это может означать, что термин регуляризации неэффективен в контроле компромисса между двумя учителями. В этом случае мы можем попробовать добавить регуляризацию к самой лямбда-головке.

Последние мысли!

Мои результаты показывают, что предложенный подход обеспечивает лучшую производительность, чем базовая модель CNN, обученная только на размеченных данных. Тем не менее, мои результаты все еще далеки от современных результатов, о которых сообщалось в недавних исследованиях. Одной из возможных причин этого являются вычислительные ограничения, с которыми я столкнулся. Наши эксперименты были ограничены одним графическим процессором (стандартным графическим процессором Colab) с ограниченной памятью и вычислительной мощностью. Поэтому я не мог исследовать более крупные модели и более обширную настройку гиперпараметров, что, возможно, привело к лучшим результатам.

Мои результаты показывают, что модель достигла лучшей производительности с большим количеством помеченных примеров, особенно в наборах данных с более высокой сложностью, таких как CIFAR-100 и SVHN. Следовательно, можно изучить более выгодные способы более разумного использования контролируемых данных, например, предложенные активным обучением.[17]

Наконец, я обнаружил, что лямбда-головка играет решающую роль в работе предложенного подхода. Прогнозируя качество псевдометок, лямбда-головка помогла модели различать надежные и ненадежные псевдометки и сосредоточиться на более важных. Но я не смог получить результаты для этой стратегии метаобучения из-за тяжелой вычислительной природы алгоритма. Выполнение двух обратных проходов в алгоритме метаобучения может быть дорогостоящим в вычислительном отношении, поскольку оно включает в себя повторное вычисление прямого прохода и обратное распространение градиентов по модели дважды. Это может быть особенно дорогостоящим, если модель большая или если набор данных, используемый для метаобучения, большой.

В целом, это была особенно интригующая попытка теоретизировать и реализовать оригинальную стратегию. Кривая обучения была крутой, и моя настойчивость принесла мне высокую оценку совета, когда я защищал диссертацию. Теперь я предлагаю сообществу Medium критически оценить мой подход, чтобы мы могли продолжить исследования в этой, несомненно, важной области. И последнее, но не менее важное: я хотел бы отослать всех вас к работе, проделанной Mai, Z. et al. «MetaMixUp: изучение политики адаптивной интерполяции MixUp с помощью метаобучения», о которой я не знал, но которая расширяет исследование в том же направлении, но с другой политикой интерполяции.[18]

Мерси и слава Богу!

Полную реализацию проекта вы можете найти в этом репозитории GitHub.

Рекомендации

[1] Оуали Ю., Худелот К. и Тами М. (2020) Обзор глубокого полуконтролируемого обучения, arXiv [cs.LG]. Доступно по адресу: http://arxiv.org/abs/2006.05278 (дата обращения: 2 июня 2023 г.).

[2]Рабиндран и Маданагопал, Д. (2020) «Теория Пиаже и этапы когнитивного развития — обзор», Журнал ученых прикладных медицинских наук, 8(9), стр. 2152–2157. дои: 10.36347/sjams.2020.v08i09.034.

[3] Аббаси, С. и др. (2019) Моделирование методов учитель-ученик в глубоких нейронных сетях для извлечения знаний, arXiv [cs.CV]. Доступно по адресу: http://arxiv.org/abs/1912.13179 (дата обращения: 2 июня 2023 г.).

[4] Вэнь, Л. и др. (2022) «Новая ансамблевая сверточная нейронная сеть с регуляризацией разнообразия для диагностики неисправностей», Журнал производственных систем, 62, стр. 964–971. doi: 10.1016/j.jmsy.2020.12.002.

[5] Шортен, К. и Хошгофтаар, Т. М. (2019) «Обзор увеличения данных изображений для глубокого обучения», Journal of big data, 6(1). doi: 10.1186/s40537–019–0197–0.

[6] Hospedales, T. et al. (2020) Метаобучение в нейронных сетях: обзор, arXiv [cs.LG]. Доступно по адресу: http://arxiv.org/abs/2004.05439 (дата обращения: 2 июня 2023 г.).

[7]Grandvalet, Y. and Heudiasyc, Y.B. (без даты) Semi-supervised learning by antropy minim>, Neurips.cc. Доступно по адресу: https://proceedings.neurips.cc/paper_files/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf (дата обращения: 2 июня 2023 г.).

[8]Энглессон, Э. и Азизпур, Х. (2021) Регуляризация согласованности может повысить устойчивость к маркировке шума, arXiv [cs.LG]. Доступно по адресу: http://arxiv.org/abs/2110.01242 (дата обращения: 2 июня 2023 г.).

[9]Чакраборти, А. и др. (2018) Состязательные атаки и защита: обзор, arXiv [cs.LG]. Доступно по адресу: http://arxiv.org/abs/1810.00069 (дата обращения: 2 июня 2023 г.).

[10] Ду, Дж. и др. (без даты) Когда совместное обучение работает на реальных данных?, Uwo.ca. Доступно по адресу: https://www.csd.uwo.ca/~xling/papers/TKDE2010_co-training.pdf (дата обращения: 2 июня 2023 г.).

[11] Амини, М.-Р. et al. (2022) Самообучение: обзор, arXiv [cs.LG]. Доступно по адресу: http://arxiv.org/abs/2202.12040 (дата обращения: 2 июня 2023 г.).

[12]Наборы данных CIFAR-10 и CIFAR-100 (без даты) Toronto.edu. Режим доступа: https://www.cs.toronto.edu/~kriz/cifar.html (дата обращения: 2 июня 2023 г.).

[13]Набор данных о номерах домов для просмотра улиц (SVHN) (без даты) Stanford.edu. Доступно по адресу: http://ufldl.stanford.edu/housenumbers/ (дата обращения: 2 июня 2023 г.).

[14]Берджес, CJC (без даты) База данных рукописных цифр MNIST, Янн ЛеКун, Коринна Кортес и Крис Берджес, Lecun.com. Доступно по адресу: http://yann.lecun.com/exdb/mnist/ (дата обращения: 2 июня 2023 г.).

[15]Gmail, Л. и Хинтон, Г. (2008) Визуализация данных с помощью t-SNE, Jmlr.org. Доступно по адресу: https://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf (дата обращения: 2 июня 2023 г.).

[16] Selvaraju, R. R. et al. (2016) Grad-CAM: визуальные объяснения из глубоких сетей с помощью локализации на основе градиента, arXiv [cs.CV]. Доступно по адресу: http://arxiv.org/abs/1610.02391 (дата обращения: 2 июня 2023 г.).

[17] Рен, П. и др. (2022) «Обзор глубокого активного обучения», Исследования вычислений ACM, 54(9), стр. 1–40. . дои: 10.1145/3472291.

[18]Mai, Z. et al. (2019) MetaMixUp: Изучение политики адаптивной интерполяции MixUp с метаобучением, arXiv [cs.CV]. Доступно по адресу: http://arxiv.org/abs/1908.10059 (дата обращения: 2 июня 2023 г.).