TLDR;
Retentive Network (RetNet) имеет производительность, сравнимую с Transformer того же размера, может обучаться параллельно, но поддерживает режим повторения, который обеспечивает сложность вывода O (1) на токен.
Неофициальную, но полную реализацию можно найти в моем репозитории ниже:
«Невозможный треугольник» для моделей генеративной последовательности
Для моделей последовательности, особенно. генеративные, у нас есть три вышеупомянутых требования: быстрый вывод, параллельное обучение и высокая производительность. (На мой взгляд, есть еще одно измерение: экстраполяция длины последовательности. Это может поддерживаться RetNet, но явных экспериментов с ним нет.)
RNN имеет быстрый вывод, но медленное обучение, линейные преобразователи имеют более слабые характеристики, а преобразователи имеют O (n) на вывод токена. RetNet удовлетворяет всем трем требованиям: параллельное обучение, логический вывод O(1) и преобразователи битов.
Краткая история
Было несколько подходов к смягчению дорогостоящего вывода генеративных трансформаторов. Известные работы включают Linear Transformers, Attention-Free Transformers (AFT; от Apple) и RWKV (от BlinkDL, на основе AFT). .
Они заслуживают отдельного поста, поэтому я не буду вдаваться в подробности: но, на мой взгляд, все они математически очень элегантны, особенно вывод того, как можно распараллелить RNN. в то время как я нахожу RetNet немного более интересным, так как он также имеет фрагментарное представление и некоторые изящные трюки, такие как xpos.
Так как же это работает?
RetNet — это своего рода plug-and-play замена «внимания» на «retention» в той же архитектуре Transformer.
Я пройдусь по ним сверху вниз.
1. Каждый блок RetNet
На самом высоком уровне RetNet состоит из нескольких стеков одинаковых блоков, каждый из которых содержит MultiScaleRetention (MSR) и FeedForwardNetwork (FFN). У них также есть норма слоя и пропускное соединение, как и у трансформеров. FFN также почти идентичен Transformers, который представляет собой двухслойный MLP, размер скрытого затемнения = 2 x размер встраивания и с активацией gelu.
Если мы заменим MSR на MultiHeadAttention, это будет просто Transformer. Поэтому все различия можно найти в MSR.
2. Закрытое мультимасштабное удержание
Multi-Scale аналогичен Multi-Head. В приведенном выше уравнении γ — это некоторый гиперпараметр, который следует использовать при удержании, и он определяется отдельно для каждой головки. До групповой нормы это обычное многоголовое внимание, но с удержанием.
Gated MSR добавляет групповую норму, поворотный затвор и выходную проекцию на выходе, что можно рассматривать как вспомогательный вариант дизайна. (group-norm допускает масштабирование скалярного произведения, но пока это не так важно.)Самое важное отличие (модуль удержания) еще впереди.
3. Удержание
Наконец, давайте посмотрим, что такое удержание. У удержания есть 3 парадигмы: параллельная, рекуррентная и поблочно-рекуррентная. Давайте посмотрим на них один за другим.
Параллельное хранение
Сосредоточьтесь на последней строке. Игнорируя D, опять же, это точечный продукт внимания без softmax. Итак, важная деталь снова находится в D и Theta.
- Theta (и bar(Theta), комплексное сопряжение) — это комплексное представление “ кодирование xpos — основано на ротационном встраивании, чтобы модель могла лучше экстраполировать длину последовательности. Точно такое же представление есть и в некомплексном пространстве, которое представляет собой xpos, построенный на RoPE.
См. бумагу xpos. Я также нашел эту конспект лекции полезной для понимания этого.
- D — причинная маскировка + матрица распада.
Если вы нарисуете D, D будет выглядеть следующим образом:
gamma = 0.9 exponent = [[0, 0, 0, 0], [1, 0, 0, 0], [2, 1, 0, 0], [3, 2, 1, 0]] D = tril(gamma**exponent) # [[1., 0., 0., 0.], # [0.9000, 1., 0., 0.], # [0.8100, 0.9000, 1., 0.], # [0.7290, 0.8100, 0.9000, 1.]])
- Верхний треугольник — 0 → причинная маскировка.
- Показатель степени = количество раз, когда предыдущее представление токена было разрушено. Это станет более ясным, когда мы увидим рекуррентное представление.
Повторяющееся удержание
Sn аналогичен КВ-кэшу в трансформаторах. Вместо того, чтобы последовательно объединять их все, RetNet объединяет их в единую матрицу с повторением в первой строке. Затем это умножается на запрос текущего шага.
Это то же самое, что и параллельное удержание.
Неофициальный пробный эскиз:
Пусть S_0 = 0.Если мы решим повторение S_n,
Вспомните последнюю строку матрицы показателей для D в параллельном представлении, которая была [3, 2, 1, 0]. Обратите внимание, что n=4. Когда мы вычисляем удержание для 4-го токена по сравнению с 1-м токеном, мы уменьшаем его 3 раза, что эквивалентно n — i = 3 в приведенном выше уравнении! Поскольку все остальное одинаково, параллельное и рекуррентное представления идентичны друг другу.
Удержание по частям
Это выглядит сложно, но на самом деле это параллельное вычисление для каждого фрагмента + рекуррентное соединение фрагментов. Единственная важная вещь — это снова количество примененных затухания.
Ошибка в документе
На самом деле, фрагментарное представление (уравнение выше) для Ri в статье неверно! На самом деле, это должно быть
где оператор X — это векторное произведение, а D_B — последняя строка матрицы D. Интуитивно это следует из убывающего умножения параллельного представления и рекуррентного представления.
Схематическая диаграмма
Вот и все! Выше приведена сводная диаграмма двух представлений.
Почему Распад?
Итак, в основном, самая важная деталь заключается в том, что он использует то, что называется затуханием, и применение затухания нужное количество раз позволяет распараллелить. Но мы должны понять, что стоит за таким распадом. Вывод (на высоком уровне) довольно прост.
- мы определяем повторяющееся состояние (s_n) как kv_cache. Затем рекуррентное отношение находится в первой строке на рисунке выше.
- Затем мы определяем вывод в момент времени n как Q_n * s_n. Вторая строка выше записывает это и решает повторение, чтобы развернуть полную зависимость. Обратите внимание, что матрица A применяется несколько раз.
3. Теперь мы диагонализируем матрицу A следующим образом.
4. Затем символы Λ можно включить в другие обучаемые параметры (Q_n = X * W_k, поэтому Λ можно включить в W_k!). Поэтому у нас осталась только средняя часть.
Средняя часть — это в точностиγ(распад)и тета, которые мы наблюдали ранее.
Интуитивно они работают как «позиционное кодирование в закрытой форме», которое также имеет рекуррентную форму, так что кодирование в момент времени n можно вычислить заранее, что позволяет распараллелить.
Эмпирические данные
- RetNet побеждает Transformer, поскольку становится больше. (Критик: не уверен, что эта тенденция сохранится)
- RetNet превосходит другие линейные преобразователи времени по производительности.
- Ретнет работает быстро. (Критик: это очевидно на основе архитектуры. Показывать 3 рисунка, чтобы подчеркнуть это, бессмысленно. TBH, нет необходимости даже проводить эксперименты, чтобы нарисовать эти графики…)
Критики
- В документе есть несколько недостающих деталей, которые не будут прояснены до тех пор, пока не будет опубликован официальный код.
- RWKV также поддерживает распараллеливание обучения, но в документе это неверно представлено как невозможное.
- Своего рода хвастовство тем, что RetNet работает быстро, а 3 цифры говорят одно и то же. :-)
- Любопытно, распространится ли эта тенденция на более крупные модели.
- Не уверен, что они отпустят предварительно тренированный вес.
- Не уверен, что они превзойдут такие модели, как LLaMA.
Плюсы
- БЫСТРЫЙ! (Я раскритиковал их хвастовство, но это действительно быстро, что хорошо)
- Сопоставимая производительность. Если эта тенденция сохранится и в более крупных моделях не будет падения производительности, это может стать де-факто LLM, поскольку они намного дешевле.
Кому интересно, пожалуйста, взгляните на мою реализацию RetNet: