Введение, краткое объяснение и подробная реализация Octave Convolution в PyTorch.

Сверточные нейронные сети (CNN) доминируют в области компьютерного зрения. В этом посте мы рассмотрим недавно предложенную свертку Octave из этой статьи: Drop an Octave: Уменьшение пространственной избыточности в сверточных нейронных сетях с помощью октавной свертки.

Октавная свертка может использоваться как замена ванильной свертки. Авторы продемонстрировали, что аналогичная (иногда лучшая) точность может быть достигнута с помощью октавной свертки при сохранении огромного количества требуемых флопов. Размер модели в случае октавной и ванильной сверток одинаков.

Ванильная свертка выполняет высокочастотную свертку по всем входным каналам. С другой стороны, октавная свертка разделяет все каналы на две части: высокочастотную и низкочастотную. Низкочастотные каналы на одну октаву меньше (высота и ширина) по сравнению с высокочастотными свертками. Более того, высокочастотные и низкочастотные каналы объединяются друг с другом перед отправкой выходных сигналов.

Как видно на рисунке, каждый модуль октавной свертки может иметь внутри до 4 ветвей, каждая из которых выполняет ванильную свертку. Пути зеленого цвета не изменяют пространственные размеры от входа к выходу. Однако пути с красным цветом либо увеличивают (от низкого к высокому), либо уменьшают (от высокого к низкому) пространственные размеры, идущие от входа к выходу.

При переходе от высокочастотного входа к низкочастотному выходу (путь HtoL) выполняется операция объединения 2x2, чтобы получить уменьшенный вход для свертки. Итак, путь HtoL - conv_vanilla (pool (in_high))

Точно так же при переходе от низкочастотного входа к высокочастотному выходу (путь LtoH) обычная свертка увенчивается билинейной интерполяцией для повышения дискретизации вывода свертки низкого разрешения. Итак, путь LtoH - это bilenear_interpolation (vanilla_convolution (in_low)).

В основе свертки Octave лежит концепция α (отношение общего числа каналов, которые используются низкочастотными свертками). Для первого сверточного слоя нет низкочастотного входного канала, поэтому α_ {in} = 0. Аналогично для последнего сверточного слоя нет низкочастотного выходного канала, α_ {out} = 0. Для всех других слоев, авторы приняли α_ {in} = α_ {out} = 0,5.

В статье авторы показали много результатов. Тот, который я считаю наиболее интересным, показан ниже. Как видите, с небольшой долей низкочастотной составляющей (0,125 или 0,25) сети работают лучше, чем базовые модели со всеми высокочастотными каналами.

Ниже приводится реализация Pytorch.

Полную реализацию можно найти в моем репозитории git. Для тестирования этой реализации я обучил двухслойную ванильную CNN на CIFAR10 примерно на 20 эпох. Затем я заменил все свертки на октавную свертку. сеть работает немного лучше (2–3%). Я чувствую, что для более крупных сетей разница может быть даже лучше.