Это репост оригинала на сигмовидном штрихе. Соответствующий код доступен на GitHub здесь.

Представьте себе химика, ищущего новое лекарство, которое будет связываться с определенным белком в организме человека. Химик сталкивается с двумя проблемами: 1) количество возможных лекарств огромно и 2) стоимость синтеза большого количества лекарств-кандидатов и оценки их эффективности непомерно высока и требует много времени.

Чтобы помочь ей автоматизировать процесс поиска небольшого набора лекарств-кандидатов, она обучает модель оракула выводить оценку для любой произвольной молекулы, указывающую, насколько хорошо она будет связываться с целевым белком. Это делается контролируемым образом путем обучения модели на парах (x,y), где xпредставляет известные наркотики, а y — эффективность, с которой лекарство связывается с белком. Эта модель, однако, не является генеративной и не поможет ей найти новое соединение, которое она ищет, сама по себе. Именно здесь на сцену выходят генеративные потоковые сети (GFlowNet).

Обзор

Сети GFlowNet были предложены в 2021 году Эммануэлем и Йошуа Бенжио (и сотрудниками) из Mila и пытаются одновременно решить две проблемы: первая проблема — это выборка дискретных и составных объектов x (т. может быть построен путем сложения элементов в последовательности дискретных шагов). Например, графы и молекулы являются примерами таких объектов, поскольку их можно выбирать, добавляя узлы/атомы и ребра/связи к объекту шаг за шагом.

Вторая проблема заключается в выборке таких объектов пропорционально заданному неотрицательному вознаграждению или «энергетической» функции R(x). Ключевая фраза здесь — «пропорционально». Обычно модели обучаются максимизировать заданную функцию вознаграждения, таким образом сходясь вокруг одного или нескольких образцов с высоким вознаграждением. Напротив, GFlowNet обучается таким образом, что вероятность p(x) выборки объекта x ∈ X соответствует нормализованному вознаграждению, т.е.

Это свойство побуждает к исследованию пространства выборки X и, как правило, приводит к более широкому разнообразию выборок, чем те, которые получены из модели, обученной с использованием максимизации вознаграждения. Хотя это свойство не всегда желательно, оно может быть очень полезным для нашего гипотетического химика. Это связано с тем, что оракул является лишь приблизительным и, таким образом, иногда будет присваивать более низкие баллы, чем это оправдано для новых соединений, которые на самом деле превосходно связываются с целевым белком. Тогда возникает вопрос, как GFlowNets достигают этого?

Выборка

Новизна GFlowNet заключается в том, как они обучаются, а не в деталях их архитектуры. Фактически, более или менее любая нейронная сеть может выступать в качестве архитектуры для GFlowNet. Однако, чтобы понять, как обучается архитектура, нам нужно поближе взглянуть на то, как авторы формулируют проблему выборки из модели.

Рассмотрим проблему выборки DAG с тремя узлами. Для трехузлового графа с узлами X, Y и Z существует 25 возможных DAG, многие из которых могут быть получены за несколько способами, начиная с пустого графа (т. е. графа с тремя узлами, но без ребер). Пусть s0​ обозначает пустой граф и начальную точку процесса выборки. Отсюда у нас есть набор разрешенных действий A(s0), которые добавят ребро к пустому графу (или завершат работу, поскольку s0​ – допустимый DAG). В частности, мы можем выбирать между следующими действиями:

каждый из которых переходит либо в s1​,…,s6​, либо завершается в s0​. Поскольку существует 25 возможных DAG, пространство состояний равно s0​,…,s24​. (Вообще, не все состояния должны соответствовать допустимым объектам. Когда si​ представляет недопустимый объект, A(si​) просто не будет содержать действие завершения. Это полезен в тех случаях, когда модели необходимо «пройти» через недопустимый объект, чтобы достичь допустимого.) Мы можем визуализировать небольшую часть пространства состояния-действия следующим образом:

Здесь я ограничил пространство, чтобы включить только состояния и действия, возникающие в результате первоначального выбора X→Y или Y→Z и приводящие к DAG с ребрами (X →Y, Y→Z, X→Z). Из диаграммы видно, что может быть несколько последовательностей действий, ведущих к одному и тому же состоянию. Последовательность действий также можно рассматривать как траекторию τ, такую, что τ=(s0​→ ⋯ → sT ​→ sf​) содержит историю пошагового создания конечного объекта sT​. Вы также заметите особое состояние sf​. Это состояние называется sink и достигается всякий раз, когда выборка завершается. Раковина как таковая не является частью модели, но, как мы увидим ниже, помогает формализовать цель обучения.

Поток

Основная идея GFlowNets состоит в том, чтобы представить Z = Σ_x R(x) единиц воды, текущей из источника s0 по различным траекториям и выходящей из системы через сток. Важно отметить, что для состояния s, представляющего объект x, объем потока на пути s → sf должен быть равен R(x ), а любой избыточный поток должен переходить в другие состояния. Если мы сможем добиться этого, относительная частота завершения в с для любой единицы потока будет R(x) / Z, и, таким образом, мы будем производить выборку x пропорционально вознаграждению. Трудность состоит в том, чтобы формализовать потерю, при которой оптимальное решение подчиняется этому ограничению. Для этого нам нужно ввести условия согласования потоков.

Согласование потока и баланс траектории

Пусть F(s → s') обозначает поток по пути s s', а Pa(s) и Ch(s) обозначают родительские и дочерние состояния s соответственно. Чтобы поток был действительным, для всех состояний, кроме s0 и sf, должно выполняться следующее условие:

Для удобства мы предполагаем, что sf ∉ Ch(s) для любого s. Это означает, что равенство требует, чтобы конечный поток F(s → sf) был равен R(s), а любой избыточный поток распределялся между дочерними состояниями в Канал(ы). Есть несколько способов сделать это. Например, вы можете ввести оценщик потока F_Φ с параметрами Φ и минимизировать потери согласования потока

Альтернативный (но эквивалентный) подход состоит в том, чтобы сформулировать проблему в терминах прямой политики P_F(s{t+1} | st), выводящей распределение по состояниям, достижимым из состояния в момент времени t (вы также можете просмотреть это как распределение по A(st)), а также оценку Z_θ общего потока, исходящего из s0 . Кроме того, если мы введем обратную политику P_B(st | s{t+1}), мы можем сформулировать то, что известно как потеря баланса траектории:

Чтобы понять потери баланса траектории, рассмотрим условие детального баланса, которое просто утверждает, что F(s → s') может быть выражено как доля общего потока F(s) через s и как часть общего потока F(s') через s':

Если это не интуитивно понятно, представьте, что у Алисы и Боба по 10 литров воды. Алиса отдает 8 литров Чарли, а Боб отдает ему 5 литров. Вода, «текущая» между Алисой и Чарли, равна 0,8*10, но также может быть выражена как 8/13 * 13, поскольку всего Чарли получает 13 литров. Когда P_F и P_B совпадают таким образом, они удовлетворяют условию подробного баланса.

Вдохновленный этим, потеря баланса траектории заменяет поток F(s’) в правой части уравнения выше вознаграждением R(s’). Обратите внимание, что это делается только в отношении последнего состояния траектории (т. е. sf), так как это единственное состояние, в котором поток из sT всегда должен точно равняться награда R(sT). Интересно, что авторы утверждают, что существует уникальный поток, удовлетворяющий приведенному выше условию согласования потоков, независимо от выбора P_B. Таким образом, мы можем безопасно установить для P_B равномерное распределение по родительским состояниям и сосредоточить наше внимание исключительно на обучении Z_θ и P_F.

Реализация GFlowNet

Теперь мы пойдем по стопам Бенжио и др. и внедрить GFlowNet для моделирования среды двумерной сетки, в которой каждая координата имеет соответствующее вознаграждение. s0 будет представлять верхнюю левую координату, и каждое действие будет перемещаться вниз или вправо от текущей координаты. Для размера сетки N=16 вознаграждение среда выглядит следующим образом:

Как видно из изображения, функция вознаграждения имеет четыре режима, разделенных «мертвыми зонами» с очень низким вознаграждением. Чем ниже вознаграждение, тем сложнее модели будет исследовать окружающую среду. В крайнем случае, когда вознаграждение равно 0 вне режимов, модель не сможет исследовать окружающую среду, так как никакой поток не сможет пройти через мертвые зоны. Однако до тех пор, пока вознаграждение положительное, модель сможет исследовать всю окружающую среду, если у нее будет достаточно времени.

Мы собираемся реализовать P_F как маленькую MLP, принимающую в качестве входных данных N²-мерный однократный вектор, указывающий текущую позицию (состояние), и выводящий распределение возможных действий: Down, Верно или Завершить. Когда текущая позиция находится на нижнем или правом краю среды, недопустимое действие Вниз или Вправо будет замаскировано установкой его вероятности на 0. Наконец, мы делаем выборку действие в соответствии с этим распределением, соответствующим образом обновить состояние, ввести новое состояние в P_F и т. д. Это продолжается до тех пор, пока модель не выберет действие Завершить, после чего мы вычисляем потеря баланса траектории и обновление Z_θ и P_F. Промыть и повторить.

Начнем с реализации прямой и обратной политики.

class ForwardPolicy(nn.Module):
    def __init__(self, state_dim, hidden_dim, num_actions):
        super().__init__()
        self.dense1 = nn.Linear(state_dim, hidden_dim)
        self.dense2 = nn.Linear(hidden_dim, num_actions)
    
    def forward(self, s):
        x = self.dense1(s)
        x = relu(x)
        x = self.dense2(x)
        return softmax(x, dim=1)
        
class BackwardPolicy:
    def __init__(self, state_dim, num_actions):
        super().__init__()
        self.num_actions = num_actions
        self.size = int(state_dim**0.5)
    
    def __call__(self, s):
        idx = s.argmax(-1)
        at_top_edge = idx < self.size
        at_left_edge = (idx > 0) & (idx % self.size == 0)
        
        probs = 0.5 * torch.ones(len(s), self.num_actions)
        probs[at_left_edge] = torch.Tensor([1, 0, 0]) # previous action was "down"
        probs[at_top_edge] = torch.Tensor([0, 1, 0]) # previous action was "right"
        probs[:, -1] = 0 # disregard termination action
        
        return probs

Самая интересная часть здесь — BackwardPolicy. Поскольку мы используем фиксированную обратную политику, классу не нужно наследовать от nn.Module. Когда задано состояние, метод __call__ просто по умолчанию назначает вероятность 0,5 для достижения состояния из любого из двух родительских состояний (сверху или слева). Однако в случае, когда имеется только одно родительское состояние (на левом и верхнем краю среды), мы устанавливаем вероятность 1 для единственного родительского состояния.

Мы также определяем класс Grid, отвечающий за маскировку недопустимых действий, получение пар «состояние-действие» и вывод обновленных состояний, а также вычисление вознаграждений за состояние.

class Grid:
    def __init__(self, size):
        self.size = size
        self.state_dim = size**2
        self.num_actions = 3 # down, right, terminate
        
    def update(self, s, actions):
        idx = s.argmax(1)
        down, right = actions == 0, actions == 1
        idx[down] = idx[down] + self.size
        idx[right] = idx[right] + 1
        return one_hot(idx, self.state_dim).float()
    
    def mask(self, s):
        mask = torch.ones(len(s), self.num_actions)
        idx = s.argmax(1) + 1
        at_right_edge = (idx > 0) & (idx % (self.size) == 0)
        at_bottom_edge = idx > self.size*(self.size-1)
        mask[at_right_edge, 1] = 0
        mask[at_bottom_edge, 0] = 0
        return mask
        
    def reward(self, s):
        grid = s.view(len(s), self.size, self.size)
        coord = (grid == 1).nonzero()[:, 1:].view(len(s), 2)
        R0, R1, R2 = 1e-2, 0.5, 2
        norm = torch.abs(coord / (self.size-1) - 0.5)
        R1_term = torch.prod(0.25 < norm, dim=1)
        R2_term = torch.prod((0.3 < norm) & (norm < 0.4), dim=1)
        return (R0 + R1*R1_term + R2*R2_term)

Наконец, мы реализуем настоящий GFlowNet.

class GFlowNet(nn.Module):
    def __init__(self, forward_policy, backward_policy, env):
        super().__init__()
        self.total_flow = Parameter(torch.ones(1))
        self.forward_policy = forward_policy
        self.backward_policy = backward_policy
        self.env = env
    
    def mask_and_normalize(self, s, probs):
        probs = self.env.mask(s) * probs
        return probs / probs.sum(1).unsqueeze(1)
    
    def forward_probs(self, s):
        probs = self.forward_policy(s)
        return self.mask_and_normalize(s, probs)
    
    def sample_states(self, s0, return_log=False):
        s = s0.clone()
        done = torch.BoolTensor([False] * len(s))
        log = Log(s0, self.backward_policy, self.total_flow, self.env) if return_log else None

        while not done.all():
            probs = self.forward_probs(s[~done])
            actions = Categorical(probs).sample()
            s[~done] = self.env.update(s[~done], actions)
            
            if return_log:
                log.log(s, probs, actions, done)
                
            terminated = actions == probs.shape[-1] - 1
            done[~done] = terminated
        
        return (s, log) if return_log else s

Метод sample_states принимает в качестве входных данных начальные состояния s0 и выполняет цикл до тех пор, пока все состояния не завершатся. Класс Log отвечает за регистрацию текущей оценки общего потока, вознаграждения за каждую выборку, а также прямых и обратных вероятностей, встречающихся на траектории для каждой выборки. Они регистрируются, так как они необходимы для расчета потери баланса траектории. Я не буду вдаваться в детали реализации Log здесь, но не стесняйтесь проверить это на GitHub.

Теперь мы готовы обучить GFlowNet.

def train(batch_size, num_epochs):
    env = Grid(size=size)
    forward_policy = ForwardPolicy(env.state_dim, hidden_dim=32, num_actions=env.num_actions)
    backward_policy = BackwardPolicy(env.state_dim, num_actions=env.num_actions)
    model = GFlowNet(forward_policy, backward_policy, env)
    opt = Adam(model.parameters(), lr=5e-3)
    
    for i in (p := tqdm(range(num_epochs))):
        s0 = one_hot(torch.zeros(batch_size).long(), env.state_dim).float()
        s, log = model.sample_states(s0, return_log=True)
        loss = trajectory_balance_loss(log.total_flow,
                                       log.rewards,
                                       log.fwd_probs,
                                       log.back_probs)
        loss.backward()
        opt.step()
        opt.zero_grad()
        if i % 10 == 0: p.set_description(f"{loss.item():.3f}")

После того, как модель будет обучена, мы можем выбрать последнюю партию из 10 000 образцов из метода sample_states, на этот раз без использования зарегистрированных данных.

s0 = one_hot(torch.zeros(10**4).long(), env.state_dim).float()
s = model.sample_states(s0, return_log=False)
plot(s, env)

Как видите, относительная частота каждого образца из обученной модели (слева) примерно пропорциональна соответствующему вознаграждению (справа). Конечно, это простая игрушечная задача, которая служит только доказательством концепции. В качестве примера того, как GFlowNets можно применять к более реальным проблемам, см., например, недавнюю работу Deleu et al. по использованию GFlowNets для изучения байесовской структуры.

Заключение

GFlowNets — это новая область исследований, и время покажет, насколько полезными они окажутся. Мне они кажутся многообещающим инструментом для таких задач, как открытие лекарств и вывод о причинно-следственных связях, когда интересующие объекты дискретны (т. е. молекулы и графы причинно-следственных связей) и где предсказание моделей оракулов или оценка вероятностей могут служить функциями вознаграждения. . Если вам интересно узнать больше о GFlowNets, я настоятельно рекомендую Интервью MLST с Йошуа Бенжио.

Рекомендации

[1] Бенжио, Йошуа и Лахлоу, Салем и Делеу, Тристан и Ху, Эдвард Дж. и Тивари, Мо и Бенжио, Эммануэль (2021). «Основы GFlowNet»

[2] Бенжио, Эммануэль и Джейн, Мокш и Кораблев, Максим и Прекап, Дойна и Бенжио, Йошуа (2021). «Генеративные модели на основе потоковой сети для неитеративной генерации разнообразных кандидатов»

[3] Малкин, Николай и Джайн, Мокш и Бенжио, Эммануэль и Сунь, Чен и Бенжио, Йошуа. «Баланс траектории: улучшенное распределение кредитов в GFlowNets»

[4] Делеу, Тристан и Гойс, Антонио и Эмезуэ, Крис и Ранкават, Манси и Лакост-Жюльен, Саймон и Бауэр, Стефан и Бенхио, Йошуа (2022). “Обучение байесовской структуры с генеративными потоковыми сетями”