TLDR;

Retentive Network (RetNet) имеет производительность, сравнимую с Transformer того же размера, может обучаться параллельно, но поддерживает режим повторения, который обеспечивает сложность вывода O (1) на токен.

Неофициальную, но полную реализацию можно найти в моем репозитории ниже:



«Невозможный треугольник» для моделей генеративной последовательности

Для моделей последовательности, особенно. генеративные, у нас есть три вышеупомянутых требования: быстрый вывод, параллельное обучение и высокая производительность. (На мой взгляд, есть еще одно измерение: экстраполяция длины последовательности. Это может поддерживаться RetNet, но явных экспериментов с ним нет.)

RNN имеет быстрый вывод, но медленное обучение, линейные преобразователи имеют более слабые характеристики, а преобразователи имеют O (n) на вывод токена. RetNet удовлетворяет всем трем требованиям: параллельное обучение, логический вывод O(1) и преобразователи битов.

Краткая история

Было несколько подходов к смягчению дорогостоящего вывода генеративных трансформаторов. Известные работы включают Linear Transformers, Attention-Free Transformers (AFT; от Apple) и RWKV (от BlinkDL, на основе AFT). .

Они заслуживают отдельного поста, поэтому я не буду вдаваться в подробности: но, на мой взгляд, все они математически очень элегантны, особенно вывод того, как можно распараллелить RNN. в то время как я нахожу RetNet немного более интересным, так как он также имеет фрагментарное представление и некоторые изящные трюки, такие как xpos.

Так как же это работает?

RetNet — это своего рода plug-and-play замена «внимания» на «retention» в той же архитектуре Transformer.

Я пройдусь по ним сверху вниз.

1. Каждый блок RetNet

На самом высоком уровне RetNet состоит из нескольких стеков одинаковых блоков, каждый из которых содержит MultiScaleRetention (MSR) и FeedForwardNetwork (FFN). У них также есть норма слоя и пропускное соединение, как и у трансформеров. FFN также почти идентичен Transformers, который представляет собой двухслойный MLP, размер скрытого затемнения = 2 x размер встраивания и с активацией gelu.

Если мы заменим MSR на MultiHeadAttention, это будет просто Transformer. Поэтому все различия можно найти в MSR.

2. Закрытое мультимасштабное удержание

Multi-Scale аналогичен Multi-Head. В приведенном выше уравнении γ — это некоторый гиперпараметр, который следует использовать при удержании, и он определяется отдельно для каждой головки. До групповой нормы это обычное многоголовое внимание, но с удержанием.

Gated MSR добавляет групповую норму, поворотный затвор и выходную проекцию на выходе, что можно рассматривать как вспомогательный вариант дизайна. (group-norm допускает масштабирование скалярного произведения, но пока это не так важно.)Самое важное отличие (модуль удержания) еще впереди.

3. Удержание

Наконец, давайте посмотрим, что такое удержание. У удержания есть 3 парадигмы: параллельная, рекуррентная и поблочно-рекуррентная. Давайте посмотрим на них один за другим.

Параллельное хранение

Сосредоточьтесь на последней строке. Игнорируя D, опять же, это точечный продукт внимания без softmax. Итак, важная деталь снова находится в D и Theta.

  • Thetabar(Theta), комплексное сопряжение) — это комплексное представление кодирование xpos основано на ротационном встраивании, чтобы модель могла лучше экстраполировать длину последовательности. Точно такое же представление есть и в некомплексном пространстве, которое представляет собой xpos, построенный на RoPE.

См. бумагу xpos. Я также нашел эту конспект лекции полезной для понимания этого.

  • D — причинная маскировка + матрица распада.

Если вы нарисуете D, D будет выглядеть следующим образом:

gamma = 0.9
exponent = [[0, 0, 0, 0],
            [1, 0, 0, 0],
            [2, 1, 0, 0],
            [3, 2, 1, 0]]

D = tril(gamma**exponent)
# [[1.,     0.,     0.,     0.],
#  [0.9000, 1.,     0.,     0.],
#  [0.8100, 0.9000, 1.,     0.],
#  [0.7290, 0.8100, 0.9000, 1.]])
  • Верхний треугольник — 0 → причинная маскировка.
  • Показатель степени = количество раз, когда предыдущее представление токена было разрушено. Это станет более ясным, когда мы увидим рекуррентное представление.

Повторяющееся удержание

Sn аналогичен КВ-кэшу в трансформаторах. Вместо того, чтобы последовательно объединять их все, RetNet объединяет их в единую матрицу с повторением в первой строке. Затем это умножается на запрос текущего шага.

Это то же самое, что и параллельное удержание.

Неофициальный пробный эскиз:

Пусть S_0 = 0.Если мы решим повторение S_n,

Вспомните последнюю строку матрицы показателей для D в параллельном представлении, которая была [3, 2, 1, 0]. Обратите внимание, что n=4. Когда мы вычисляем удержание для 4-го токена по сравнению с 1-м токеном, мы уменьшаем его 3 раза, что эквивалентно n — i = 3 в приведенном выше уравнении! Поскольку все остальное одинаково, параллельное и рекуррентное представления идентичны друг другу.

Удержание по частям

Это выглядит сложно, но на самом деле это параллельное вычисление для каждого фрагмента + рекуррентное соединение фрагментов. Единственная важная вещь — это снова количество примененных затухания.

Ошибка в документе

На самом деле, фрагментарное представление (уравнение выше) для Ri в статье неверно! На самом деле, это должно быть

где оператор X — это векторное произведение, а D_B — последняя строка матрицы D. Интуитивно это следует из убывающего умножения параллельного представления и рекуррентного представления.

Схематическая диаграмма

Вот и все! Выше приведена сводная диаграмма двух представлений.

Почему Распад?

Итак, в основном, самая важная деталь заключается в том, что он использует то, что называется затуханием, и применение затухания нужное количество раз позволяет распараллелить. Но мы должны понять, что стоит за таким распадом. Вывод (на высоком уровне) довольно прост.

  1. мы определяем повторяющееся состояние (s_n) как kv_cache. Затем рекуррентное отношение находится в первой строке на рисунке выше.
  2. Затем мы определяем вывод в момент времени n как Q_n * s_n. Вторая строка выше записывает это и решает повторение, чтобы развернуть полную зависимость. Обратите внимание, что матрица A применяется несколько раз.

3. Теперь мы диагонализируем матрицу A следующим образом.

4. Затем символы Λ можно включить в другие обучаемые параметры (Q_n = X * W_k, поэтому Λ можно включить в W_k!). Поэтому у нас осталась только средняя часть.

Средняя часть — это в точностиγ(распад)и тета, которые мы наблюдали ранее.

Интуитивно они работают как «позиционное кодирование в закрытой форме», которое также имеет рекуррентную форму, так что кодирование в момент времени n можно вычислить заранее, что позволяет распараллелить.

Эмпирические данные

  • RetNet побеждает Transformer, поскольку становится больше. (Критик: не уверен, что эта тенденция сохранится)

  • RetNet превосходит другие линейные преобразователи времени по производительности.

  • Ретнет работает быстро. (Критик: это очевидно на основе архитектуры. Показывать 3 рисунка, чтобы подчеркнуть это, бессмысленно. TBH, нет необходимости даже проводить эксперименты, чтобы нарисовать эти графики…)

Критики

  • В документе есть несколько недостающих деталей, которые не будут прояснены до тех пор, пока не будет опубликован официальный код.
  • RWKV также поддерживает распараллеливание обучения, но в документе это неверно представлено как невозможное.
  • Своего рода хвастовство тем, что RetNet работает быстро, а 3 цифры говорят одно и то же. :-)
  • Любопытно, распространится ли эта тенденция на более крупные модели.
  • Не уверен, что они отпустят предварительно тренированный вес.
  • Не уверен, что они превзойдут такие модели, как LLaMA.

Плюсы

  • БЫСТРЫЙ! (Я раскритиковал их хвастовство, но это действительно быстро, что хорошо)
  • Сопоставимая производительность. Если эта тенденция сохранится и в более крупных моделях не будет падения производительности, это может стать де-факто LLM, поскольку они намного дешевле.

Кому интересно, пожалуйста, взгляните на мою реализацию RetNet: