Как это уменьшается в LSTM

В части 1 этой серии мы рассмотрели обратное распространение в модели RNN и объяснили как формулами, так и численно показали проблему исчезающего градиента в RNN. В этой статье мы собираемся объяснить, как мы можем частично решить проблему исчезающего градиента с помощью LSTM, даже если он не исчезает полностью и с очень длинными последовательностями проблема все еще сохраняется.

Мотивация

Как мы видели в части 1 этой серии, ванильная RNN хранит временную информацию в скрытом состоянии, которое обновляется на каждом временном шаге, когда добавляется новая информация, т. Е. Обрабатывается новый токен в последовательности. Поскольку скрытое состояние обновляется на каждом этапе, старая информация перезаписывается, и сеть забывает, что она видела в прошлом. Чтобы этого избежать, нужна отдельная память и механизм, который решает, что в нее записывать, учитывая новую информацию, что удалять из прошлого, что не пригодится в будущем, а что передавать в следующее состояние.
LSTM делает именно это — добавляет ячейку памяти, в которой хранится долговременная информация, и имеет механизм селектора, который используется для принятия решения о том, что следует забыть из прошлого, добавить из текущего ввода и передать вперед.

Прямое распространение

Давайте посмотрим, как прямое распространение во времени выполняется в модели LSTM. Учитывая последовательность из N токенов и предполагая, что мы получили ячейку памяти c(t-1) и скрытое состояние h(t-1) из предыдущей ячейки, на временном шаге t мы вычисляем вентили, чтобы решить, что делать с новая поступающая информация. Во-первых, давайте посчитаем активации:

Помните, что все веса распределяются по временным шагам. Матрица активации затем разбивается на 4 матрицы, каждая из которых имеет размерность H, и применяя сигмовидную функцию активации к первым трем и th к последним, мы вычисляем вентили:

Обратите внимание, что все вентили являются функциями входа и предыдущего скрытого состояния.

Наконец, мы вычисляем текущее состояние ячейки памяти c(t) и скрытое состояние h(t), которые будут переданы на следующий шаг.

Вычисленные значения вентилей имеют следующие функциональные возможности:

  • ворота f:какую информацию из предыдущей ячейки памяти c(t-1) следует забыть. Обратите внимание, что поскольку мы выполняем поэлементное умножение (помните, что c(t-1) и h(t-1) являются векторами), а f содержит значения от 0 до 1 из-за сигмовидной функции активации, оно отменит или уменьшит информацию в c(t-1), когда значения f равны или близки к 0, и сохранит всю или почти всю информацию, когда значения f равны или близки к 1.
  • gate g: можно интерпретировать как вектор обновления ячейки памяти, который объединяется с предыдущей ячейкой памяти c(t-1) для вычисления новой ячейки памяти c(t). В отличие от других ворот, к активации a(g) применяется тангенсная функция, которая выводит значение от -1 до 1. Это позволяет состоянию памяти ячейки как увеличиваться, так и уменьшаться, как если бы у нас была сигмовидная активация, элементы ячейки памяти никогда не могли уменьшаться.
  • ворота i:какая информация записывается из вектора обновления ячейки памяти (ворота g) в предыдущую ячейку памяти c(t-1).
  • gate o: какую информацию включить в новое скрытое состояние h(t)

Затем эти вентили объединяются, как показано на рис. 4, для вычисления новой ячейки памяти c(t) и скрытого состояния h(t). Эти новые ячейки и скрытое состояние затем передаются в следующую ячейку LSTM, которая снова повторяет тот же процесс. Весь этот процесс можно проиллюстрировать на схеме ниже:

После этого для каждого скрытого состояния вычисляем выход и потери:

В коде:

def softmax(x, axis=2):
    p = np.exp(x - np.max(x, axis=axis,keepdims=True))
    return p / np.sum(p, axis=axis, keepdims=True)

def lstm_step_forward(x, prev_h, prev_c, Wx, Wh, b):

    next_h, next_c, cache = None, None, None
    
    h = x @ Wx + prev_h @ Wh + b
    assert h.shape[-1] % 4 == 0
    ai, af, ao, ag = np.array_split(h, 4, axis=-1)
    i = sigmoid(ai)
    f = sigmoid(af)
    o = sigmoid(ao)
    g = np.tanh(ag)

    next_c = f * prev_c + i * g 
    next_h = o * np.tanh(next_c)
    
    cache = (x, next_h, prev_h, prev_c, Wx, Wh, h, np.tanh(next_c), i, f, o ,g)

    return next_h, next_c, cache


np.random.seed(232)

# N - Batch size
# D - Embeddding dimension
# V - Vocabulary size
# H - Hidden dimension
# T - timesteps
N, D, T, H, V = 2, 5, 3, 4, 4

x  = np.random.randn(N, T, D)
h0 = np.random.randn(N, H)
Wx = np.random.randn(D, H)
Wh = np.random.randn(H, H)
Wy = np.random.randn(H, V)
b  = np.random.randn(H)

y = np.random.randint(V, size=(N, T))
mask = np.ones((N, T))


all_cache = []
h = np.zeros((N, T, H))    
next_c = np.zeros((N, H))
    
for t in range(T):
    xt = x[:, t , :]
    if t == 0:
        next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b)
        all_cache.append(cache_s)
    else:
        next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b)   
        all_cache.append(cache_s)

    h[:, t, :] = next_h 

ft = h @ Wy
out = softmax(ft)

Обратное распространение

Формулы для обратного распространения немного сложнее, чем в обычной RNN. В этом уроке мы собираемся получить градиенты относительно Wx, чтобы затем показать, как LSTM обрабатывает исчезающие градиенты. Производные по другим параметрам могут быть получены аналогичным образом, и это предоставляется читателю в качестве упражнения. Однако код содержит производные по всем градиентам, и вы можете проверить свои результаты на основе кода.
Производная Loss по скрытому состоянию остается такой же, как и для RNN, поскольку там ничего не меняется, поскольку Loss принимает только скрытое состояние в качестве входных данных:

Теперь найдем производные по другим одиночным компонентам:

Обратите внимание, что для удобства мы разделили dct/dat и dht/dat, и везде, где у нас есть dht/dct dct/dat, мы пишем его непосредственно как dht/dat. Кроме того, поскольку мы будем выполнять обратное распространение в матричной форме, мы объединяем производные вентилей следующим образом:

Сумма в dht/dat возникает из-за того, что у нас есть два направления (см. рис. 7) — одно ведет в предыдущую ячейку, а другое — в скрытое состояние. С той же логикой градиентного потока производная dct/dc(t-1) выглядит следующим образом:

Теперь давайте получим общий градиент относительно Wx. Это определяется суммой отдельных потерь по отношению к Wx, как описано в части 1 этой серии:

Сосредоточившись на отдельных потерях, например, dL3/dWx,когда мы распространяемся от L3 к Wx, Wx появляется во всех компонентах временных шагов, таким образом , нам нужно будет суммировать все эти компоненты, чтобы получить полный градиент L3 w.r.t. Wx. Слегка злоупотребляя математическими обозначениями, делаем примерно так (помните, что Wx3 = Wx2 = Wx1):

Первый компонент будет таким, как показано ниже. Кроме того, мы заменяем dht/dct dct/dat на dht/dat, чтобы затем напрямую использовать эту производную

Для краткости я пропущу dL3/dWx2 и сразу перейду к третьему компоненту. У нас есть:

Как и ранее, давайте заменим везде, где у нас есть dht/dct dct/dat, на dht/dat, чтобы затем напрямую использовать эту производную:

Суммируя их, мы получаем производную от dL3/dWx. Чтобы получить производную dWx по w.r.t. общие потери, нам нужно добавить к dL3/dWx, dL2/dWx и dL1/dWx.

В коде:

def lstm_forward(x, h0, Wx, Wh, b, next_c=None):
    h, cache = None, None
   
    cache = []
    N, T, _ = x.shape
    H = h0.shape[-1]
    h = np.zeros((N, T, H))
    if next_c is None:
        next_c = np.zeros((N, H))
    for t in range(x.shape[1]):
        xt = x[:, t , :]
        if t == 0:
            next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b)
            cache.append(cache_s)
        else:
            next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b)   
            cache.append(cache_s)
    
        h[:, t, :] = next_h 

    return h, cache

def dc_da(h, prev_c, next_c_t, i, f, o, g):
    dgrad_c = np.zeros((h.shape[0], 4 * h.shape[1]))
    dgrad_h = np.zeros((h.shape[0], 4 * h.shape[1]))
    # assert dgrad.shape[1] % 4 == 0
    H = dgrad.shape[1] // 4
    
    # compute gradients wrt ai, af, ao and ag from two flows - next_h and next_c
    dnextc_dai = (i * (1-i)) * g
    dnextc_daf = (f * (1-f)) * prev_c
    dnextc_dao = 0
    dnextc_dag = (1 - g**2) * i
    
    dh_dc = o * (1 - next_c_t**2)

    dnexth_dai = dh_dc * dnextc_dai
    dnexth_daf = dh_dc * dnextc_daf
    dnexth_dao = (o * (1-o) * next_c_t)
    dnexth_dag = dh_dc * dnextc_dag

    # join them together in a matrix at this point to conveniently compute
    # downstream gradients 
    dgrad_c[:, 0:H] = dnextc_dai 
    dgrad_c[:, H:2*H] = dnextc_daf 
    dgrad_c[:, 2*H:3*H] = dnextc_dao 
    dgrad_c[:, 3*H:4*H] = dnextc_dag 
    
    dgrad_h[:, 0:H] =  dnexth_dai
    dgrad_h[:, H:2*H] = dnexth_daf
    dgrad_h[:, 2*H:3*H] = dnexth_dao
    dgrad_h[:, 3*H:4*H] = dnexth_dag
    return dgrad_c, dgrad_h

np.random.seed(1)

N, D, T, H = 1, 3, 3, 1

x = np.random.randn(N, T, D)
h0 = np.random.randn(N, H)
Wx = np.random.randn(D, 4 * H)
Wh = np.random.randn(H, 4 * H)
b = np.random.randn(4 * H)

out, cache = lstm_forward(x, h0, Wx, Wh, b)

# let's define the dout instead of deriving them for simplicity
dout = np.random.randn(*out.shape)
    
# dL3/dWvx
dnext_c2 = np.zeros((h0.shape))
dnext_h2 = dout[:, -1, :]
(x2, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t2, i2, f2, o2 ,g2) = cache[2]
dgrad_c2, dgrad_h2 = dc_da(h0, cache[2][3], cache[2][-5], cache[2][-4],  cache[2][-3], cache[2][-2], cache[2][-1]) 

dL3_dWx2 = x2.T @ (dgrad_h2 * dnext_h2 + dgrad_c2 * dnext_c2)
print(dL3_dWx2)

dnext_c1 = dnext_c2 * f2 + dnext_h2 * o2 * (1 - next_c_t2**2) * f2
dnext_h1 = (dnext_h2 * dgrad_h2 +  dnext_c2 * dgrad_c2) @ Wh.T

(x1, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t1, i1, f1, o1 ,g1) = cache[1]
dgrad_c1, dgrad_h1 = dc_da(h0, cache[1][3], cache[1][-5], cache[1][-4],  cache[1][-3], cache[1][-2], cache[1][-1])    

dL3_dWx1 = x1.T @ (dnext_c1 * dgrad_c1 + dnext_h1 * dgrad_h1)

print(dL3_dWx1)

dnext_c0 = dnext_c1 * f1 + dnext_h1 * o1 * (1 - next_c_t1**2) * f1
dnext_h0 = (dnext_h1 * dgrad_h1 + dnext_c1 * dgrad_c1) @ Wh.T

(x0, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t0, i0, f0, o0 ,g0) = cache[0]
dgrad_c0, dgrad_h0 = dc_da(h0, cache[0][3], cache[0][-5], cache[0][-4],  cache[0][-3], cache[0][-2], cache[0][-1])    

dL3_dWx0 = x0.T @ (dnext_c0 * dgrad_c0 + dnext_h0 * dgrad_h0)
print(dL3_dWx0)

Выходы:

[[-0.02349287  0.00135057 -0.11156069 -0.05284914]
 [ 0.01024921 -0.00058921  0.04867045  0.02305643]
 [-0.00429567  0.00024695 -0.02039889 -0.00966347]]
[[-9.83990139e-03  6.78775168e-05 -1.10660923e-03  4.20773125e-04]
 [ 7.93641636e-03 -5.47469140e-05  8.92540613e-04 -3.39376441e-04]
 [-2.11067811e-02  1.45598602e-04 -2.37369846e-03  9.02566589e-04]]
[[-1.95768961e-05  0.00000000e+00  2.77411349e-05 -9.76467796e-03]
 [ 7.37299593e-06  0.00000000e+00 -1.04477887e-05  3.67754574e-03]
 [ 6.36561888e-06  0.00000000e+00 -9.02030083e-06  3.17508036e-03]]
losses_dWx = {i : {x_comp : 0 for x_comp in range(i)} for i in range(T)}
dWx = np.zeros((D, 4 * H))
dWh = np.zeros((H, 4 * H))
db = np.zeros((4 * H, ))
for idx in range(T-1, -1, -1):
    print(f"Loss {idx + 1}")
    dnext_c = np.zeros((h0.shape))
    dnext_h =  dout[:, idx, :]
    for j in range(idx, -1, -1):
        (x, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t, i, f, o ,g) = cache[j]
        
        dgrad_c, dgrad_h = dc_da(h0, prev_c, next_c_t, i, f, o, g) 
        
        dgrad = dnext_c * dgrad_c + dnext_h * dgrad_h

        losses_dWx[idx][j] = x.T @ dgrad
        
        dnext_c = dnext_c * f + dnext_h * o * (1 - next_c_t**2) * f
        dnext_h = (dnext_h * dgrad_h +  dnext_c * dgrad_c) @ Wh.T
        dnext_h = dgrad @ Wh.T  

        # accumulate gradient of dWx and other params for each loss
        dWx += x.T @ dgrad
        dWh += prev_h.T @ dgrad
        db += dgrad.sum(0)

        print(f"component {j} - ", np.linalg.norm(losses_dWx[idx][j]))

Исчезающий градиент в LSTM

Как и в части 1 для RNN, давайте посмотрим на градиенты Loss L3 для каждого компонента:

Loss 3
component 0 - 0.010906688399113558
component 1 - 0.02478099846737857
component 2 - 0.13901933055672275

Из приведенного выше видно, что X3 (компонент 2), ближайший к L3, по-прежнему имеет самое большое обновление, в то время как X1 и X2вносит меньший вклад в Wx1обновление. Однако для RNN эта разница намного больше.
Действительно, градиент, который проходит через скрытое состояние, будет страдать от исчезающего градиента по той же причине, что и RNN — Whterms (dat/dh(t-1)) по-прежнему появляются в обратном распространении, например здесь в dL3/dWx1 (рис. 15):

Однако градиент, протекающий через ячейку, которая по-прежнему является функцией входных данных и скрытого состояния, не имеет членов Wh, а вместо этого имеет сигмовидные члены (см. формулу для забыть ворота ft на рисунке 3):

Напомним, что dct/dc(t-1) = ft.
Таким образом, если забыть ворот высокое, т. е. близкое к 1, происходит исчезающий градиент гораздо медленнее, чем в ванильном RNN, но это все равно произойдет, если только все вентили забывания не равны ровно 1, чего на практике не происходит.

Выводы

Основная цель этой статьи заключалась в том, чтобы понять, получив обратное распространение, что LSTM все еще страдает от исчезающего градиента на практике, однако с гораздо меньшей скоростью, чем ванильный RNN, благодаря состоянию ячейки, которое заставляет градиент затухать при забудьте скорость прохода, а не скорость Wx.
Если вы обнаружите какие-либо ошибки, пожалуйста, сообщите мне об этом в комментариях.

Рекомендации