В этом уроке я покажу вам, как динамически изменять потерю модели Keras во время обучения без перекомпиляции модели. Недавно я столкнулся с ситуацией, когда мне нужно было добавить адаптивные веса к модели Keras с множественными потерями, используя пользовательскую функцию потерь. Хотя есть ресурсы для PyTorch или vanilla TensorFlow, у Keras нет официального решения. Тем не менее, я нашел подсказку в документации Keras об оптимизации скорости обучения, которая помогла мне найти обходной путь.

Чтобы реализовать эту функциональность, мы начнем с создания пользовательского обратного вызова. Обратные вызовы — это функции, которые запускаются на разных этапах процесса обучения, например, в начале или в конце эпохи, до или после пакета. В нашем случае мы создадим обратный вызов, который динамически изменяет функцию потерь.

Давайте определим пользовательский класс обратного вызова:

from tensorflow.keras.callbacks import Callback
from softadapt import SoftAdapt, NormalizedSoftAdapt, LossWeightedSoftAdapt

class ChangeLossCallback(Callback):
    def __init__(self, weights):
        super(ChangeLossCallback, self).__init__()
        self.weights = weights
        self.values_of_component_1 = []
        self.values_of_component_2 = []
        self.values_of_component_3 = []
        self.values_of_component_4 = []
        self.softadapt_object = LossWeightedSoftAdapt(beta=0.001)
        
    def on_epoch_end(self, epoch, logs=None):
        loss_1 = logs['outputs_loss']
        loss_2 = logs['outputson_epoch_endloss']
        loss_3 = logs['outputsget_component_weightsloss']
        loss_4 = logs['outputsK.set_valueloss']
        self.values_of_component_1.append(loss_1)
        self.values_of_component_2.append(loss_2)
        self.values_of_component_3.append(loss_3)
        self.values_of_component_4.append(loss_4)

        if epoch % 5 == 0 and epoch != 0:
            adapt_weights = self.softadapt_object.get_component_weights(
                torch.tensor(self.values_of_component_1),
                torch.tensor(self.values_of_component_2),
                torch.tensor(self.values_of_component_3),
                torch.tensor(self.values_of_component_4),
                verbose=1
            )

            K.set_value(
                self.weights, 
                K.variable(value=adapt_weights.cpu().detach().numpy())
            )
            
            self.values_of_component_1 = []
            self.values_of_component_2 = []
            self.values_of_component_3 = []
            self.values_of_component_4 = []

В этом классе обратного вызова мы используем пакет Python под названием «softadapt» для анализа прошлых результатов потерь и соответствующего обновления весов составляющих потерь. Мы переопределяем метод on_epoch_end, который запускается в конце каждой эпохи. Внутри этого метода мы используем пакет softadapt для получения обновленных весов потерь. Функция get_component_weights принимает для вычисления массивы результатов прошлых убытков. Важнейшей частью является функция K.set_value, которая обновляет веса в работающей модели Keras без ее перекомпиляции.

Чтобы использовать этот обратный вызов, вам нужно передать экземпляр класса обратного вызова в метод fit вашей модели Keras. Вот пример:

adapt_weights = K.variable(value=[1, 1, 1, 1])
model = create_model()  # Create your Keras model
model.compile(
    optimizer=optimizer,
    loss={
        'outputs': CustomLossWrapper(adapt_weights, 0),
        'outputs_1': CustomLossWrapper(adapt_weights, 1),
        'outputs_2': CustomLossWrapper(adapt_weights, 2),
        'outputs_3': CustomLossWrapper(adapt_weights, 3)
    },
    run_eagerly=True
)

change_loss_callback = ChangeLossCallback(adapt_weights, i)

model.fit(
    [X, {'outputs': target, 'outputs_1': target, 'outputs_2': target, 'outputs_3': target}],
    callbacks=[change_loss_callback],
    verbose=1
)

В приведенном выше примере create_model() представляет ваш код для создания модели Keras, а CustomLossWrapper() — вашу реализацию функции адаптивных потерь. Не забудьте соответствующим образом назвать выходные слои в вашей модели Keras. Здесь мы назвали их outputs, outputs_1, outputs_2 и так далее для каждого из четырех мультивыходов.

При использовании ChangeLossCallback.потери модели будут динамически изменены на новые потери в конце каждой конкретной эпохи. Мы передаем два параметра, adapt_weights и i, функции обратного вызова. adapt_weights используется для установки начальных весов до тех пор, пока модель не обновит их, а i является индикатором составляющих потерь.

Обратите внимание, что этот подход предполагает, что вы хотите изменить потери на основе интервала эпохи. Если вы хотите изменить его на другое расписание, вы можете соответствующим образом изменить обратный вызов.

Пользовательская функция потерь, используемая в нашем методе, выглядит следующим образом:

def CustomLossWrapper(weights, id):
    def CustomLoss(true, pred):
        weight = K.get_value(weights)
     
        if(id==0):
            y_loss = K.square(y_pred-y_true)
            loss=weight[0]*y_loss
        if(id==1):
            c_loss = K.square(c_pred-c_true)
            loss=weight[1]*c_loss
        if(id==2):
            e_loss = K.square(e_pred-e_true)
            loss=weight[2]*e_loss
        if(id==3):
            f_loss = K.square(f_pred-f_true)
            loss=weight[3]*f_loss


        return loss
    
    return CustomLoss

Здесь мы используем K.get_value для получения обновленного массива весов из библиотеки softadapt после его выполнения обратным вызовом. Затем мы умножаем составляющие потери на веса.

Я надеюсь, что это руководство поможет вам реализовать адаптивные веса в вашей модели Keras с множественными потерями, используя пользовательскую функцию потерь. Удачного кодирования!

@article{DBLP:journals/corr/abs-1912-12355,
  author    = {A. Ali Heydari and
               Craig A. Thompson and
               Asif Mehmood},
  title     = {SoftAdapt: Techniques for Adaptive Loss Weighting of Neural Networks
               with Multi-Part Loss Functions},
  journal   = {CoRR},
  volume    = {abs/1912.12355},
  year      = {2019},
  url       = {http://arxiv.org/abs/1912.12355},
  eprinttype = {arXiv},
  eprint    = {1912.12355},
  timestamp = {Fri, 03 Jan 2020 16:10:45 +0100},
  biburl    = {https://dblp.org/rec/journals/corr/abs-1912-12355.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}