Как это уменьшается в LSTM
В части 1 этой серии мы рассмотрели обратное распространение в модели RNN и объяснили как формулами, так и численно показали проблему исчезающего градиента в RNN. В этой статье мы собираемся объяснить, как мы можем частично решить проблему исчезающего градиента с помощью LSTM, даже если он не исчезает полностью и с очень длинными последовательностями проблема все еще сохраняется.
Мотивация
Как мы видели в части 1 этой серии, ванильная RNN хранит временную информацию в скрытом состоянии, которое обновляется на каждом временном шаге, когда добавляется новая информация, т. Е. Обрабатывается новый токен в последовательности. Поскольку скрытое состояние обновляется на каждом этапе, старая информация перезаписывается, и сеть забывает, что она видела в прошлом. Чтобы этого избежать, нужна отдельная память и механизм, который решает, что в нее записывать, учитывая новую информацию, что удалять из прошлого, что не пригодится в будущем, а что передавать в следующее состояние.
LSTM делает именно это — добавляет ячейку памяти, в которой хранится долговременная информация, и имеет механизм селектора, который используется для принятия решения о том, что следует забыть из прошлого, добавить из текущего ввода и передать вперед.
Прямое распространение
Давайте посмотрим, как прямое распространение во времени выполняется в модели LSTM. Учитывая последовательность из N токенов и предполагая, что мы получили ячейку памяти c(t-1) и скрытое состояние h(t-1) из предыдущей ячейки, на временном шаге t мы вычисляем вентили, чтобы решить, что делать с новая поступающая информация. Во-первых, давайте посчитаем активации:
Помните, что все веса распределяются по временным шагам. Матрица активации затем разбивается на 4 матрицы, каждая из которых имеет размерность H, и применяя сигмовидную функцию активации к первым трем и th к последним, мы вычисляем вентили:
Обратите внимание, что все вентили являются функциями входа и предыдущего скрытого состояния.
Наконец, мы вычисляем текущее состояние ячейки памяти c(t) и скрытое состояние h(t), которые будут переданы на следующий шаг.
Вычисленные значения вентилей имеют следующие функциональные возможности:
- ворота f:какую информацию из предыдущей ячейки памяти c(t-1) следует забыть. Обратите внимание, что поскольку мы выполняем поэлементное умножение (помните, что c(t-1) и h(t-1) являются векторами), а f содержит значения от 0 до 1 из-за сигмовидной функции активации, оно отменит или уменьшит информацию в c(t-1), когда значения f равны или близки к 0, и сохранит всю или почти всю информацию, когда значения f равны или близки к 1.
- gate g: можно интерпретировать как вектор обновления ячейки памяти, который объединяется с предыдущей ячейкой памяти c(t-1) для вычисления новой ячейки памяти c(t). В отличие от других ворот, к активации a(g) применяется тангенсная функция, которая выводит значение от -1 до 1. Это позволяет состоянию памяти ячейки как увеличиваться, так и уменьшаться, как если бы у нас была сигмовидная активация, элементы ячейки памяти никогда не могли уменьшаться.
- ворота i:какая информация записывается из вектора обновления ячейки памяти (ворота g) в предыдущую ячейку памяти c(t-1).
- gate o: какую информацию включить в новое скрытое состояние h(t)
Затем эти вентили объединяются, как показано на рис. 4, для вычисления новой ячейки памяти c(t) и скрытого состояния h(t). Эти новые ячейки и скрытое состояние затем передаются в следующую ячейку LSTM, которая снова повторяет тот же процесс. Весь этот процесс можно проиллюстрировать на схеме ниже:
После этого для каждого скрытого состояния вычисляем выход и потери:
В коде:
def softmax(x, axis=2): p = np.exp(x - np.max(x, axis=axis,keepdims=True)) return p / np.sum(p, axis=axis, keepdims=True) def lstm_step_forward(x, prev_h, prev_c, Wx, Wh, b): next_h, next_c, cache = None, None, None h = x @ Wx + prev_h @ Wh + b assert h.shape[-1] % 4 == 0 ai, af, ao, ag = np.array_split(h, 4, axis=-1) i = sigmoid(ai) f = sigmoid(af) o = sigmoid(ao) g = np.tanh(ag) next_c = f * prev_c + i * g next_h = o * np.tanh(next_c) cache = (x, next_h, prev_h, prev_c, Wx, Wh, h, np.tanh(next_c), i, f, o ,g) return next_h, next_c, cache np.random.seed(232) # N - Batch size # D - Embeddding dimension # V - Vocabulary size # H - Hidden dimension # T - timesteps N, D, T, H, V = 2, 5, 3, 4, 4 x = np.random.randn(N, T, D) h0 = np.random.randn(N, H) Wx = np.random.randn(D, H) Wh = np.random.randn(H, H) Wy = np.random.randn(H, V) b = np.random.randn(H) y = np.random.randint(V, size=(N, T)) mask = np.ones((N, T)) all_cache = [] h = np.zeros((N, T, H)) next_c = np.zeros((N, H)) for t in range(T): xt = x[:, t , :] if t == 0: next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b) all_cache.append(cache_s) else: next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b) all_cache.append(cache_s) h[:, t, :] = next_h ft = h @ Wy out = softmax(ft)
Обратное распространение
Формулы для обратного распространения немного сложнее, чем в обычной RNN. В этом уроке мы собираемся получить градиенты относительно Wx, чтобы затем показать, как LSTM обрабатывает исчезающие градиенты. Производные по другим параметрам могут быть получены аналогичным образом, и это предоставляется читателю в качестве упражнения. Однако код содержит производные по всем градиентам, и вы можете проверить свои результаты на основе кода.
Производная Loss по скрытому состоянию остается такой же, как и для RNN, поскольку там ничего не меняется, поскольку Loss принимает только скрытое состояние в качестве входных данных:
Теперь найдем производные по другим одиночным компонентам:
Обратите внимание, что для удобства мы разделили dct/dat и dht/dat, и везде, где у нас есть dht/dct dct/dat, мы пишем его непосредственно как dht/dat. Кроме того, поскольку мы будем выполнять обратное распространение в матричной форме, мы объединяем производные вентилей следующим образом:
Сумма в dht/dat возникает из-за того, что у нас есть два направления (см. рис. 7) — одно ведет в предыдущую ячейку, а другое — в скрытое состояние. С той же логикой градиентного потока производная dct/dc(t-1) выглядит следующим образом:
Теперь давайте получим общий градиент относительно Wx. Это определяется суммой отдельных потерь по отношению к Wx, как описано в части 1 этой серии:
Сосредоточившись на отдельных потерях, например, dL3/dWx,когда мы распространяемся от L3 к Wx, Wx появляется во всех компонентах временных шагов, таким образом , нам нужно будет суммировать все эти компоненты, чтобы получить полный градиент L3 w.r.t. Wx. Слегка злоупотребляя математическими обозначениями, делаем примерно так (помните, что Wx3 = Wx2 = Wx1):
Первый компонент будет таким, как показано ниже. Кроме того, мы заменяем dht/dct dct/dat на dht/dat, чтобы затем напрямую использовать эту производную
Для краткости я пропущу dL3/dWx2 и сразу перейду к третьему компоненту. У нас есть:
Как и ранее, давайте заменим везде, где у нас есть dht/dct dct/dat, на dht/dat, чтобы затем напрямую использовать эту производную:
Суммируя их, мы получаем производную от dL3/dWx. Чтобы получить производную dWx по w.r.t. общие потери, нам нужно добавить к dL3/dWx, dL2/dWx и dL1/dWx.
В коде:
def lstm_forward(x, h0, Wx, Wh, b, next_c=None): h, cache = None, None cache = [] N, T, _ = x.shape H = h0.shape[-1] h = np.zeros((N, T, H)) if next_c is None: next_c = np.zeros((N, H)) for t in range(x.shape[1]): xt = x[:, t , :] if t == 0: next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b) cache.append(cache_s) else: next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b) cache.append(cache_s) h[:, t, :] = next_h return h, cache def dc_da(h, prev_c, next_c_t, i, f, o, g): dgrad_c = np.zeros((h.shape[0], 4 * h.shape[1])) dgrad_h = np.zeros((h.shape[0], 4 * h.shape[1])) # assert dgrad.shape[1] % 4 == 0 H = dgrad.shape[1] // 4 # compute gradients wrt ai, af, ao and ag from two flows - next_h and next_c dnextc_dai = (i * (1-i)) * g dnextc_daf = (f * (1-f)) * prev_c dnextc_dao = 0 dnextc_dag = (1 - g**2) * i dh_dc = o * (1 - next_c_t**2) dnexth_dai = dh_dc * dnextc_dai dnexth_daf = dh_dc * dnextc_daf dnexth_dao = (o * (1-o) * next_c_t) dnexth_dag = dh_dc * dnextc_dag # join them together in a matrix at this point to conveniently compute # downstream gradients dgrad_c[:, 0:H] = dnextc_dai dgrad_c[:, H:2*H] = dnextc_daf dgrad_c[:, 2*H:3*H] = dnextc_dao dgrad_c[:, 3*H:4*H] = dnextc_dag dgrad_h[:, 0:H] = dnexth_dai dgrad_h[:, H:2*H] = dnexth_daf dgrad_h[:, 2*H:3*H] = dnexth_dao dgrad_h[:, 3*H:4*H] = dnexth_dag return dgrad_c, dgrad_h np.random.seed(1) N, D, T, H = 1, 3, 3, 1 x = np.random.randn(N, T, D) h0 = np.random.randn(N, H) Wx = np.random.randn(D, 4 * H) Wh = np.random.randn(H, 4 * H) b = np.random.randn(4 * H) out, cache = lstm_forward(x, h0, Wx, Wh, b) # let's define the dout instead of deriving them for simplicity dout = np.random.randn(*out.shape) # dL3/dWvx dnext_c2 = np.zeros((h0.shape)) dnext_h2 = dout[:, -1, :] (x2, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t2, i2, f2, o2 ,g2) = cache[2] dgrad_c2, dgrad_h2 = dc_da(h0, cache[2][3], cache[2][-5], cache[2][-4], cache[2][-3], cache[2][-2], cache[2][-1]) dL3_dWx2 = x2.T @ (dgrad_h2 * dnext_h2 + dgrad_c2 * dnext_c2) print(dL3_dWx2) dnext_c1 = dnext_c2 * f2 + dnext_h2 * o2 * (1 - next_c_t2**2) * f2 dnext_h1 = (dnext_h2 * dgrad_h2 + dnext_c2 * dgrad_c2) @ Wh.T (x1, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t1, i1, f1, o1 ,g1) = cache[1] dgrad_c1, dgrad_h1 = dc_da(h0, cache[1][3], cache[1][-5], cache[1][-4], cache[1][-3], cache[1][-2], cache[1][-1]) dL3_dWx1 = x1.T @ (dnext_c1 * dgrad_c1 + dnext_h1 * dgrad_h1) print(dL3_dWx1) dnext_c0 = dnext_c1 * f1 + dnext_h1 * o1 * (1 - next_c_t1**2) * f1 dnext_h0 = (dnext_h1 * dgrad_h1 + dnext_c1 * dgrad_c1) @ Wh.T (x0, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t0, i0, f0, o0 ,g0) = cache[0] dgrad_c0, dgrad_h0 = dc_da(h0, cache[0][3], cache[0][-5], cache[0][-4], cache[0][-3], cache[0][-2], cache[0][-1]) dL3_dWx0 = x0.T @ (dnext_c0 * dgrad_c0 + dnext_h0 * dgrad_h0) print(dL3_dWx0)
Выходы:
[[-0.02349287 0.00135057 -0.11156069 -0.05284914] [ 0.01024921 -0.00058921 0.04867045 0.02305643] [-0.00429567 0.00024695 -0.02039889 -0.00966347]] [[-9.83990139e-03 6.78775168e-05 -1.10660923e-03 4.20773125e-04] [ 7.93641636e-03 -5.47469140e-05 8.92540613e-04 -3.39376441e-04] [-2.11067811e-02 1.45598602e-04 -2.37369846e-03 9.02566589e-04]] [[-1.95768961e-05 0.00000000e+00 2.77411349e-05 -9.76467796e-03] [ 7.37299593e-06 0.00000000e+00 -1.04477887e-05 3.67754574e-03] [ 6.36561888e-06 0.00000000e+00 -9.02030083e-06 3.17508036e-03]] losses_dWx = {i : {x_comp : 0 for x_comp in range(i)} for i in range(T)} dWx = np.zeros((D, 4 * H)) dWh = np.zeros((H, 4 * H)) db = np.zeros((4 * H, )) for idx in range(T-1, -1, -1): print(f"Loss {idx + 1}") dnext_c = np.zeros((h0.shape)) dnext_h = dout[:, idx, :] for j in range(idx, -1, -1): (x, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t, i, f, o ,g) = cache[j] dgrad_c, dgrad_h = dc_da(h0, prev_c, next_c_t, i, f, o, g) dgrad = dnext_c * dgrad_c + dnext_h * dgrad_h losses_dWx[idx][j] = x.T @ dgrad dnext_c = dnext_c * f + dnext_h * o * (1 - next_c_t**2) * f dnext_h = (dnext_h * dgrad_h + dnext_c * dgrad_c) @ Wh.T dnext_h = dgrad @ Wh.T # accumulate gradient of dWx and other params for each loss dWx += x.T @ dgrad dWh += prev_h.T @ dgrad db += dgrad.sum(0) print(f"component {j} - ", np.linalg.norm(losses_dWx[idx][j]))
Исчезающий градиент в LSTM
Как и в части 1 для RNN, давайте посмотрим на градиенты Loss L3 для каждого компонента:
Loss 3 component 0 - 0.010906688399113558 component 1 - 0.02478099846737857 component 2 - 0.13901933055672275
Из приведенного выше видно, что X3 (компонент 2), ближайший к L3, по-прежнему имеет самое большое обновление, в то время как X1 и X2вносит меньший вклад в Wx1обновление. Однако для RNN эта разница намного больше.
Действительно, градиент, который проходит через скрытое состояние, будет страдать от исчезающего градиента по той же причине, что и RNN — Whterms (dat/dh(t-1)) по-прежнему появляются в обратном распространении, например здесь в dL3/dWx1 (рис. 15):
Однако градиент, протекающий через ячейку, которая по-прежнему является функцией входных данных и скрытого состояния, не имеет членов Wh, а вместо этого имеет сигмовидные члены (см. формулу для забыть ворота ft на рисунке 3):
Напомним, что dct/dc(t-1) = ft.
Таким образом, если забыть ворот высокое, т. е. близкое к 1, происходит исчезающий градиент гораздо медленнее, чем в ванильном RNN, но это все равно произойдет, если только все вентили забывания не равны ровно 1, чего на практике не происходит.
Выводы
Основная цель этой статьи заключалась в том, чтобы понять, получив обратное распространение, что LSTM все еще страдает от исчезающего градиента на практике, однако с гораздо меньшей скоростью, чем ванильный RNN, благодаря состоянию ячейки, которое заставляет градиент затухать при забудьте скорость прохода, а не скорость Wx.
Если вы обнаружите какие-либо ошибки, пожалуйста, сообщите мне об этом в комментариях.