Для мини-проекта я решил реализовать в коде некоторые алгоритмы из статьи Развитие алгоритмов обучения с подкреплением. Основная идея статьи состоит в том, чтобы разработать новые алгоритмы обучения с подкреплением (RL), представляя алгоритм в виде графа, допуская различные эволюции и выбирая наиболее эффективные из них.

Некоторые моменты из статьи для меня:

  • Разработанные алгоритмы могут загружаться из известного алгоритма (например, Deep Q Network (DQN)) или с нуля.
  • Алгоритмы могут обобщаться на новые среды, в которых они не обучались.
  • Авторы связывают наиболее эффективные новые алгоритмы, DQNClipped и DQNReg, с аналогичными методами, используемыми в Conservative Q Learning и Munchausen-DQN.
  • В документе используются некоторые уловки, чтобы заставить эволюцию работать лучше, такие как регуляризованная эволюция (удаление самых старых членов из популяции) и использование быстрых испытаний CartPole в качестве метода быстрой оценки (оценка популяции часто является узким местом в эволюционных подходах).

DQNClipped и DQNReg

Остальная часть этого поста посвящена внедрению DQNClipped и DQNReg в существующую библиотеку RL Stable Baselines 3 (SB). Два алгоритма являются простыми модификациями DQN. Выбор существующей библиотеки RL может быть сложнее, чем кодирование для DQNReg. Есть десятки библиотек RL на выбор. В этой статье есть таблица, в которой упоминаются 24 из них, и эта таблица даже близко не исчерпывающая. Я выбрал SB, так как он хорошо зарекомендовал себя, в нем есть активные разработчики и в планах сделать запрос на включение в будущем.

Вот основы реализации алгоритмов в SB. Если я в конечном итоге выполню запрос на включение, я напишу об этом еще один пост.

DQNClipped и DQNReg используют следующие функции потерь:

В основном это означает замену этой строки кода DQN другими функциями потерь:

loss = F.smooth_l1_loss(current_q_values, target_q_values)

Функция потери DQNReg:

import torch as th
def dqn_reg_loss(current_q, target_q, weight=0.1):
    """
    In DQN, replaces Huber/MSE loss between train and target network
    :param current_q: Q(st, at) of training network
    :param target_q: Max Q value from the target network, including the reward and gamma. r + gamma * Q_target(st+1,a)
    :param weight: scalar. weighted term that regularizes Q value. Paper defaults to 0.1 but theorizes that tuning this
    per env to some positive value may be beneficial.
    """
    # weight * Q(st, at) + delta^2
    delta = current_q - target_q
    loss = th.mean(weight * current_q + th.pow(delta, 2))

    return loss

Функция потерь DQNClipped:

def dqn_clipped_loss(current_q, target_q, target_q_a_max, gamma):
    """
    :param current_q: Q(st, at) of training network
    :param target_q: Max Q value from the target network, including the reward and gamma. r + gamma * Q_target(st+1,a)
    :param target_q_a_max: Q_target(st+1, a) where a is the action of the max Q value
    :param gamma: the discount factor
    """
    # max[current_q, delta^2 + target_q] + max[delta,gamma*(target_q_a_max)^2]
    delta = current_q - target_q
    left_max = th.max(current_q, th.pow(delta, 2) + target_q)
    right_max = th.max(delta, gamma * th.pow(target_q_a_max, 2))
    loss = th.mean(left_max + right_max)

    return loss

Вот репозиторий моей реализации. Файлы utils.py и dqn.py — это модификации файлов SB, которые позволяют указать два алгоритма. Пример использования и гиперпараметры находятся в этом файле.

В документе сообщаются следующие результаты для четырех классических сред управления: CartPole, LunarLandar, MountainCar и Acrobot.

Чтобы проверить, работает ли моя реализация, я запустил 5 случайных начальных значений в этих средах для трех функций потерь (стандартный DQN, DQNClipped и DQNReg) и получил аналогичные результаты (см. графики ниже). Я скопировал гиперпараметры, упомянутые в статье. Что касается гиперпараметров, не упомянутых в документе, я сделал все возможное, чтобы предположить. Эти гиперпараметры были успешными для всех сред, кроме MountainCar. DQN SB не работал с моими настройками гиперпараметров или настройками по умолчанию, поэтому я использовал настройки гиперпараметров из RL Zoo (репозиторий, который содержит настроенные и обученные сети от агентов SB).

Чтобы внести свой вклад в SB, мне еще нужно сделать несколько вещей. В SB contrib repo есть свои шаги, которые мне нужно будет выполнить. Я постараюсь сделать еще один пост об этом. Я, вероятно, также захочу провести еще несколько испытаний, чтобы убедиться, что мои результаты соответствуют результатам статьи в более сложных условиях. В следующем посте я, скорее всего, расскажу о процессе вклада и использовании облачного сервиса для запуска пробных версий.

Вещи, в которых я не был уверен

Как обычно, были некоторые вещи, в которых я не был уверен, когда пытался поместить статью в код. Такие как:

  • Полные гиперпараметры для DQN
  • Использует ли дельта-формула текущее состояние или следующее состояние для целевой сети Q. В формуле четко указано текущее состояние (которое будет отличаться от следующего состояния, используемого DQN), но формулировка в нескольких разных частях документа заставляет меня думать, что это следующее состояние (например, называть delta² «нормальной потерей DQN»).