1. Что такое обратный вызов в Python и языках программирования

Обратный вызов — это вызываемый объект, который может быть принят и запущен другой логикой. Поскольку функции являются гражданами первого класса, обратный вызов в Python относится к функции, которая передается в качестве аргумента другой функции. Основное использование обратного вызова — динамическая настройка поведения функции, получившей обратный вызов.

Рассмотрим следующий пример, когда человек всегда завтракает, а затем едет в школу на разных транспортных средствах. Мы могли бы использовать go_to_school(take_a_bus) или go_to_school(drive_a_car) для динамического описания процедуры, повторно используя ту же функцию go_to_school(), которая определяет шаблон. Позже, если есть третий способ передвижения, например, город строит железную дорогу между вашим домом и школой, нам нужно только реализовать новую функцию take_a_train() без каких-либо других изменений.

def take_a_bus():
    print("take a bus")
def drive_a_car():
    print("drive a car")

def go_to_school(transportation):
    print("wake up and have breakfast")
    transportation()
    print("arrive at school")

Можно понять, что обратные вызовы имеют два ограничения:

  • Обратный вызов должен соответствовать строгому формату/сигнатуре, которая должна быть согласована между вызывающей стороной и провайдером, поскольку логика запуска обратного вызова вызывающей стороной предопределена до определения конкретного обратного вызова.
  • Второй также связан с предыдущим. Трудно предоставить сложный контекст или зависимость от провайдера, например, как сделать take_a_bus() динамичным, например, Джек едет на автобусе 501, а Люси едет на автобусе 300, без утомительного определения одного для каждого человека.

Ожидается первое ограничение, поскольку обратные вызовы, по сути, являются интерфейсами между двумя (распределенными) системами, а строгий контракт, как правило, необходим, чтобы ничего не нарушать.

Второе ограничение в целом можно преодолеть, расширив обратные вызовы от функций до классов, что позволяет нам использовать атрибуты класса для определения контекстов/зависимостей. Мы знаем, что функция в Python — это просто класс. Так что ничто не может запретить нам использовать класс в качестве обратного вызова в его общем определении. Кроме того, определение обратных вызовов через классы также позволяет нам унифицированно инкапсулировать несколько функций. Например, мы могли бы расширить приведенный выше обратный вызов функции до следующего примера, где и коммутирующий, и первый класс могут быть настроены с помощью переданного объекта my_pattern. Например, мы могли бы использовать go_to_school(pattern("bus", "math", "Jack")), чтобы указать, что Джек едет в школу на автобусе и изучает математику в качестве первого предмета.

class pattern:
    def __init__(self, transportation, first_class, person, bus):
        self.transportation = transportation
        self.first_class = first_class
        self.person = person
    def commute_to_school(self):
        print(f"{self.person} takes a {self.transportation}")
    def first_class(self):
        print(f"{self.person}'s first class is {self.first_class}")

def go_to_school(my_pattern):
    print("wake up and have breakfast")
    my_pattern.commute_to_school()
    print("arrive at school")
    my_pattern.first_class()

2. Обратные вызовы в PyTorch и TensorFlow

Поскольку процедура обучения модели очень динамична, обратные вызовы широко используются в TF/Keras и Pytorch, чтобы пользователи могли разрабатывать новые алгоритмы обучения, максимально повторно используя существующие API-интерфейсы Python.

2.1 Обратные вызовы/перехватчики в PyTorch

В PyTorch функции обратного вызова называются хуками. (Название перехватчика должно подчеркнуть, что обратный вызов предназначен для запуска в определенной точке, которая перехвачена). Перехватчики PyTorch могут быть зарегистрированы и активированы в проходах forward и backward объектов Tensor и nn.Module. Эти ловушки чрезвычайно полезны, поскольку ученые машинного обучения редко реализуют прямые или обратные тензорные операторы, которые обычно реализуются в CPP / CUDA и требуют разных наборов навыков, но часто нуждаются в применении дополнительных операций в прямом / обратном вычислении в их изобретенных методах обучения модели. . Например, можно зарегистрировать обратный хук для реализации отсечения градиента, не касаясь утомительной процедуры обратного прохода. Другой пример: когда человек видит странные результаты обучения модели или логического вывода, он может зарегистрировать перехватчик вперед или назад (с возможным изменением кода в одну строку), чтобы удобно распечатать сводку задействованных тензоров или градаций для отладки. PyTorch ожидает функции/хуки с заранее заданными сигнатурами (в соответствии с его вариантами использования). См. этот официальный документ ссылка для получения списка поддерживаемых хуков и их регистрационных API для nn.Module.

Например, прямой хук для module, который будет вызываться каждый раз после того, как модуль forward() вычислит вывод, может быть зарегистрирован register_forward_hook(hook, *, prepend=False, with_kwargs=False) и должен иметь следующую подпись:

hook(module, args, output) -> None or modified output

когда with_kwargs=False или

hook(module, args, kwargs, output) -> None or modified output

когда with_kwargs=True, где args — входные тензоры, а output — выходной тензор. Обратный хук будет вызываться каждый раз, когда вычисляются градиенты по отношению к модулю, и может быть зарегистрирован register_full_backward_hook(hook, prepend=False), и должен иметь другую сигнатуру, ориентированную на градиенты:

hook(module, grad_input, grad_output) -> tuple(Tensor) or None

потому что обратное, в отличие от прямого вычисления выходных данных NN, заключается в вычислении градиентов. Согласно онлайн-документу от 20.06.2023, grad_input и grad_output — это кортежи, содержащие градиенты по отношению к входам и выходам соответственно. (Однако мой эксперимент с использованием PyTorch 2.0.1 с nn.Linear показал, что grad_input содержит больше, чем градиент по отношению к входным данным. На самом деле, если nn.Linear имеет bias=Ture, то grad_input представляет собой тройку, первый элемент которой является градиентом относительно смещения линейный модуль, 2-й элемент — это градация по отношению к входу, а 3-й элемент — это градация по отношению к весам линейного модуля; если nn.Linear имеет bias=Fale, grad_input представляет собой двойку с двумя последними элементами случая bias=Ture. эту ссылку на форум PyTorch для связанного обсуждения.)

2.2 Обратные вызовы в TensorFlow/Keras

TensorFlow/Keras также в значительной степени полагаются на обратные вызовы для настройки поведения в методах Keras fit, evaluate и predict. В отличие от Pytorch, где обычно есть явный контроль над процедурой обучения, Keras, по сути, делает процедуру обучения более похожей на черный ящик, предоставляя fit() API, хотя в некотором смысле все еще можно настраивать поведение fit() путем перезаписи некоторых фиксированных функций, таких как train_step(). (См., например, эту ссылку о том, как настроить обучение с помощью fit() в Keras. На мой взгляд, настройка с помощью перезаписи train_step() в fit() на самом деле разделяет ту же идею настройки с помощью обратных вызовов, но имеет более тяжелый вес.)

Поскольку такие методы, как fit() в Keras, являются монолитными, и неудобно обеспечивать точный контроль на низких уровнях forward/backward, как в PyTorch nn.Module, Keras/TF полагается на классы обратных вызовов, которые могут гибко вводить контексты/зависимости, как мы объясняли ранее. Базовый класс tf.keras.callbacks.Callback определен в Keras и реализует ожидаемые интерфейсы функций обратного вызова, такие как on_epoch_begin(self, epoch, logs=None), on_epoch_end(self, epoch, logs=None), on_train_batch_begin(self, batch, logs=None), on_train_batch_end(self, batch, logs=None) и т. д. Можно легко сделать вывод, где эти функции обратного вызова будут запускаться по их именам. Метод fit() всегда будет запускать функциональные интерфейсы, соответствующие соответствующим местам, например, он всегда запускает on_train_batch_begin в начале каждого пакета в цикле обучения. На самом деле метод fit() ожидает список объектов обратного вызова и запускает функции-члены с одинаковым именем (соответствующим месту выполнения) последовательно от всех объектов обратного вызова в этом списке. Таким образом, мы можем легко расширить поведение обратного вызова до необходимого количества.

Когда я впервые изучал классы обратного вызова в Keras, меня больше всего смутило то, как можно заставить поведение обратного вызова зависеть от запускающего модуля NN, учитывая, что вышеупомянутые интерфейсы функций включают только простые аргументы, такие как epoch или пакет. Напомним, что хуки forward/backward в PyTorch должны следовать фиксированным сигнатурам функций, чтобы мы знали, что эти определенные хуки будут работать с задействованными тензорами или их градиентами. Как класс обратного вызова, определенный вне модели Keras, получает информацию о том, с каким весом модели его on_batch_end() должен выполнять операции? Оказывается, Keras полагается на неявное правило: базовый класс tf.keras.callbacks.Callback имеет атрибут self.model и определяет метод set_model(self, model) для изменения self.model. По умолчанию .fit(), получающий список объектов обратного вызова, запускает метод set_model() и указывает self.model внутри этих экземпляров обратного вызова на модель Keras, которая получает обратный вызов. В соответствии с этим неявным правилом можно определить ее класс обратного вызова, предполагая, что ему известна точная модель, которая должна получить класс обратного вызова. Хотя это разумная стратегия, я бы хотел, чтобы официальный документ Кераса где-то ясно объяснил ее. Понимание этого неявного правила должно помочь вам решить некоторые сложные сценарии, когда вам нужно, чтобы ваши обратные вызовы изменяли или зависели от низких уровней моделей Keras.

2.3 Обратные вызовы в PyTorch Lightning

Pytorch Lightening — это относительно новая среда глубокого обучения, которая предоставляет высокоуровневый интерфейс для Pytorch. Если у вас есть опыт работы как с PyTorch, так и с TF/Keras, вы можете грубо рассматривать PyTorch Lightning как более похожую на Keras версию PyTorch. Коды PyTorch должны явно иметь дело с логикой эпох/итераций, которые являются необходимыми деталями для исследователей/ученых для разработки/отладки новых методов обучения, но, вероятно, могут напугать инженеров-программистов, которые не так сильно разбираются в такой сути и скорее рассматривают общее обучение как черный ящик. На самом деле, некоторые инженеры-программисты без опыта машинного обучения, например, те, кто в прошлом писал только Java, с которыми я работал в прошлом, на самом деле категорически против процессно-ориентированного стиля кодирования PyTorch и изо всех сил пытались преобразовать коды PyTorch в быть более объектно-ориентированным. (Я представлю некоторые из своих личных мыслей о процессно-ориентированных и объектно-ориентированных кодах машинного обучения позже в этом документе). Я согласен с их комментариями о том, что коды обучения PyTorch часто кажутся беспорядочными и могут быть подвержены ошибкам при их изменении. Однако они, вероятно, не видели, насколько это необходимо или, по крайней мере, удобно исследователям/ученым для разработки новой процедуры обучения. PyTorch Lightening или Keras — идеальные фреймворки для таких инженеров-программистов. Или, как утверждается в этом коротком видео о PyTorch Lightning на его официальном веб-сайте, наиболее идеальный рабочий процесс, вероятно, заключается в том, что исследователи / ученые / MLE разрабатывают методы обучения в PyTorch, а затем конвертируют коды в PyTorch Lightning перед передачей кодов в производство. Еще одним большим преимуществом PyTorch Lightning является то, что он абстрагирует логику распределенного обучения низкого уровня и скрывает ее от логики высокого уровня. Это на самом деле желательно даже исследователям / ученым, которые обычно не имеют опыта работы с распределенными системами, поскольку, к сожалению, глубокое обучение в эпоху больших моделей все больше и больше полагается на распределенные системы.

Обратный вызов в PyTorch Lightening почти такой же, как и в Keras, например, с использованием класса вместо функции, поскольку они предназначены для использования в хорошо инкапсулированной функции, такой как .fit().

3. Обратные вызовы vs. Подклассы

Поскольку в основном обратные вызовы используются для настройки шаблонной логики, может возникнуть вопрос, чем они отличаются от подклассов. Теоретически то, что может быть обеспечено обратными вызовами, должно быть достигнуто путем создания подклассов. Например, в примере go_to_school() в начале этого документа мы могли бы определить один подкласс, который использует автобус, и другой подкласс, который использует поезд. Однако для каждого подкласса потребуется дублировать функцию go_to_school(), которая в действительности может быть очень длинной и в конечном итоге заменить только один код строки, указывающий, как человек добирается до школы. То есть создание подклассов — это излишество, когда нужно настроить только второстепенное место в длинном логическом модуле. На самом деле создание подклассов в этом случае может фактически отклонить будущую эволюцию, например, когда мы позже введем дополнительный шаг take_a_bath() перед поездкой в ​​школу, нам нужно изменить как базовый класс, так и подкласс. (Чтобы избежать этого, может потребоваться разбить логику go_to_school() на подмодули и определить общие подмодули в базовом классе.) Кроме того, определение слишком большого количества подклассов с незначительными отклонениями может в конечном итоге сделать кодовую базу слишком сложной для управления, особенно когда настройка предназначена исключительно для отслеживание и отладка. Учитывая, что глубокое обучение по-прежнему является быстро развивающейся областью, схема обучения которой никогда не стабилизируется, обратные вызовы предпочтительнее. Фактически, обратные вызовы совместимы с подклассами, можно инкапсулировать настраиваемую логику, полученную с помощью обратных вызовов, в подкласс, чтобы скрыть детали настройки.