Классификация изображений романтических пар с помощью PyTorch

Если и есть корневая область недавнего бурного развития глубокого обучения, то это, безусловно, компьютерное зрение, анализ изображений и видеоданных. Поэтому неудивительно, что вы испытываете удачу с некоторыми методами компьютерного зрения во время изучения глубокого обучения. Короче говоря, мы с моим партнером (Максимилиан Улих) решили применить эту форму глубокого обучения к изображениям романтических пар, потому что Максимилиан - исследователь отношений и терапевт. В частности, мы хотели выяснить, можем ли мы точно сказать, счастлива ли какая-либо пара, изображенная на изображении или видео, в своих отношениях или нет? Оказывается, можем! С точностью классификации почти 97 процентов наша окончательная модель (которую мы назвали DeepConnection) смогла четко различить несчастные и счастливые пары. Полную историю вы можете прочитать в нашем препринте, а дальше - набросок того, что мы сделали.

Для нашего набора данных мы провели сканирование (используя этот удобный скрипт Python) для изображений счастливых и несчастливых пар. В итоге у нас остался обучающий набор из около 1000 изображений. Это не так уж и много, поэтому мы назвали возможности увеличения данных и передачи обучения нам на помощь. Увеличение данных, незначительные изменения ориентации изображения, цветового оттенка и интенсивности и многое другое не позволяет вашей модели усвоить некоторые несущественные связи. Например, если изображения счастливых пар в среднем ярче, чем изображения несчастливых пар, мы не хотим, чтобы наша модель отображала эту ассоциацию. Мы использовали (отличную) библиотеку ImgAug и занялись довольно большим увеличением объема данных, чтобы убедиться, что наша модель является надежной. По сути, к каждому изображению для каждой партии применяется по крайней мере часть методов увеличения. Ниже вы можете увидеть примерный пакет данных с одним и тем же изображением 48 раз с репрезентативными профилями увеличения данных.

Все идет нормально. Поскольку здесь речь идет об изображениях, мы решили использовать модель типа ResNet в качестве основы для DeepConnection, предварительно обученную на огромном наборе данных ImageNet. Предварительно обученная работе со всеми типами изображений, эта модель уже выучила множество полезных форм и форм и получила преимущество благодаря этому переносному обучению. Между прочим, все наши модели находятся в PyTorch, и мы использовали бесплатные ресурсы графического процессора в Google Colab для обучения и вывода. Эта базовая модель сама по себе уже была хорошим началом для классификации, но мы решили пойти дальше и заменить последний слой адаптивного пула в нашей базовой модели ResNet-34 слоем пространственного пула пирамиды (SPP). Здесь обработанные данные изображения разбиваются на разное количество квадратов, и для дальнейшего анализа передаются только максимальные значения. Это позволяет модели сосредоточиться на важных функциях, делает ее устойчивой к разным размерам изображений и невосприимчивой к искажениям изображения. После этого мы разместили слой преобразования среднего значения (PMT) для преобразования данных с помощью нескольких математических функций, чтобы ввести нелинейности и позволить DeepConnection захватывать более сложные отношения из данных. Оба эти дополнения повысили точность нашей классификации, и в итоге мы получили около 97 процентов в нашем отдельном наборе для проверки. Вы можете проверить код для SPP / PMT и последующих уровней классификации ниже.

class SPP(nn.Module):
  def __init__(self):
    super(SPP, self).__init__()
    
    ## features incoming from ResNet-34 (after SPP/PMT)
    self.lin1 = nn.Linear(2*43520, 100)
    
    self.relu = nn.ReLU()
    self.bn1 = nn.BatchNorm1d(100)
    self.dp1 = nn.Dropout(0.5)
    self.lin2 = nn.Linear(100, 2)
    
  def forward(self, x):
    # SPP
    x = spatial_pyramid_pool(x, x.shape[0], [x.shape[2], x.shape[3]], [8, 4, 2, 1])
    
    # PMT
    x_1 = torch.sign(x)*torch.log(1 + abs(x))
    x_2 = torch.sign(x)*(torch.log(1 + abs(x)))**2
    x = torch.cat((x_1, x_2), dim = 1)
    
    # fully connected classification part
    x = self.lin1(x)
    x = self.bn1(self.relu(x))
    
    #1
    x1 = self.lin2(self.dp1(x))
    #2
    x2 = self.lin2(self.dp1(x))
    #3
    x3 = self.lin2(self.dp1(x))
    #4
    x4 = self.lin2(self.dp1(x))
    #5
    x5 = self.lin2(self.dp1(x))
    #6
    x6 = self.lin2(self.dp1(x))
    #7
    x7 = self.lin2(self.dp1(x))
    #8
    x8 = self.lin2(self.dp1(x))
    
    x = torch.mean(torch.stack([x1, x2, x3, x4, x5, x6, x7, x8]), dim = 0)
    
    return x

Если у вас зоркий глаз, вы могли заметить восемь вариантов на последнем уровне классификации. То, что может показаться расточительным расчетом, на самом деле противоположное. Эта концепция была недавно выдвинута как множественное выпадение выборки и значительно ускоряет сходимость во время обучения. По сути, это компромисс между предотвращением изучения модели ложными отношениями (переобучением) и попыткой не выбросить информацию из маски исключения.

У нас была пара других настроек в проекте, лучше всего проверить препринт или связанный код на GitHub для получения дополнительной информации. Просто кратко упомяну несколько: мы обучили модель со смешанной точностью (используя библиотеку Apex) для значительно меньшего использования памяти, использовали раннюю остановку для предотвращения переобучения и отжиг скорости обучения в соответствии с функцией косинуса.

После достижения удовлетворительной точности классификации (с соответственно большой отзывчивостью и точностью) мы хотели выяснить, можем ли мы чему-то научиться из классификации, выполненной DeepConnection. Поэтому мы занялись интерпретацией модели и использовали технику, известную как Gradient-weighted Class Activation Mapping (Grad-CAM). По сути, Grad-CAM принимает входящие градиенты последнего сверточного слоя для определения заметных областей изображения, которые могут быть визуализированы как тепловая карта с повышенной дискретизацией поверх исходных изображений. Если вы хотите увидеть, как это выглядит, просто взгляните на рисунок ниже с кодом Grad-CAM после него.

Мы обсудим это далее в статье и включим его в существующие психологические исследования, но DeepConnection, похоже, в основном сосредоточен на области лица. С точки зрения исследования, это имеет большой смысл, поскольку общение и эмоции в значительной степени передаются с помощью мимики. Мы также хотели посмотреть, сможем ли мы получить реальные характеристики с помощью интерпретации модели в дополнение к визуальному восприятию, полученному с помощью Grad-CAM. Для этого мы создали графики состояния активации, чтобы визуализировать, какие из нейронов последнего слоя классификации были активированы данным изображением.

В отличие от других моделей, DeepConnection также обнаружил особенно неприятные особенности и не просто использовал отсутствие положительных черт для классификации как несчастные. Но необходимы дальнейшие исследования, чтобы действительно сопоставить эти особенности с интерпретируемыми человеком аспектами. Мы также попробовали DeepConnection на видеокадрах пар, которых модель раньше не видела, и это сработало довольно хорошо.

В целом надежность модели - одна из ее сильных сторон. Точная классификация работала с гомосексуальными парами, цветными парами, парами с дополнительными людьми на изображении, парами, лица которых не были полностью видны, и т. Д. В случае присутствия дополнительных людей на изображении DeepConnection даже распознает, кажутся ли дополнительные люди счастливыми или нет, но все же фокусирует свое предсказание на самой паре.

Следующими шагами, помимо дальнейшей интерпретации модели, будет использование большего набора обучающих данных с потенциалом для обучения более сложных моделей. Было бы также интересно использовать DeepConnection в качестве помощника для парных терапевтов, с обратной связью в реальном времени о текущем состоянии отношений (снятой) пары во время сеанса или после него для анализа. Кроме того, я бы посоветовал вам ввести свой собственный образ с вашим партнером и посмотреть, что DeepConnection может сказать о ваших отношениях! Я рада сообщить, что это, по крайней мере, благословило наши отношения, так что я думаю, это хорошее начало!

Дайте мне знать, если у вас есть идеи, как продвинуться вперед или что бы вы хотели добавить / улучшить!

Вот ссылки на препринт и репозиторий GitHub, если вам интересно.