Как реализовать сеть с использованием Bert в качестве кодировщика абзацев в классификации длинного текста в keras?

Я выполняю задачу классификации длинного текста, в которой содержится более 10000 слов в документе, я планирую использовать Bert в качестве кодировщика абзацев, а затем пошагово подавать вложения абзаца в BiLSTM. Сеть выглядит следующим образом:

Входные данные: (размер_пакета, макс_параграф_лен, макс_токен_пер_пара, размер_встраивания)

слой берта: (max_paragraph_len, paragraph_embedding_size)

слой lstm: ???

выходной слой: (размер_пакции, размер_классификации)

Как это реализовать с помощью кераса? Я использую keras load_trained_model_from_checkpoint для загрузки модели берта

bert_model = load_trained_model_from_checkpoint(
        config_path,
        model_path,
        training=False,
        use_adapter=True,
        trainable=['Encoder-{}-MultiHeadSelfAttention-Adapter'.format(i + 1) for i in range(layer_num)] +
            ['Encoder-{}-FeedForward-Adapter'.format(i + 1) for i in range(layer_num)] +
            ['Encoder-{}-MultiHeadSelfAttention-Norm'.format(i + 1) for i in range(layer_num)] +
            ['Encoder-{}-FeedForward-Norm'.format(i + 1) for i in range(layer_num)],
        )

person user1337896    schedule 05.11.2019    source источник
comment
Возможно, вас заинтересует Bert-as-service, пакет, специально предназначенный для использования Bert для встраивания документов. Если по какой-то причине вы хотите реализовать его самостоятельно, прочтите, как работает пакет (объединение предпоследнего уровня)   -  person Sam H.    schedule 06.11.2019


Ответы (1)


Я считаю, что вы можете проверить следующее статья. Автор показывает, как загрузить предварительно обученную модель BERT, встроить ее в слой Keras и использовать в настроенной глубокой нейронной сети. Сначала установите TensorFlow 2.0 Keras реализацию google-research / bert:

pip install bert-for-tf2

Затем запустите:

import bert
import os

def createBertLayer():
    global bert_layer

    bertDir = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")

    bert_params = bert.params_from_pretrained_ckpt(bertDir)

    bert_layer = bert.BertModelLayer.from_params(bert_params, name="bert")

    bert_layer.apply_adapter_freeze()

def loadBertCheckpoint():

    modelsFolder = os.path.join(modelBertDir, "multi_cased_L-12_H-768_A-12")
    checkpointName = os.path.join(modelsFolder, "bert_model.ckpt")

    bert.load_stock_weights(bert_layer, checkpointName)
person SvGA    schedule 01.05.2020