Здравствуйте, я снова здесь.
В последнем посте я написал пример использования spaCy для прогнозирования настроений в данных Twitter.
Можно еще много говорить об основных принципах тренировочного процесса spaCy.
Но сегодня речь пойдет о том, как улучшить конвейер для классификации текста.
Резюмируя последний пост:
- Извлеките данные (тематическое исследование было соревнованием Kaggle, поэтому набор данных был готов к работе)
- Ключевые понятия о классификации текста spaCy
- Предварительно обработайте данные, чтобы они соответствовали шаблону spaCy (мы не делали никакой очистки в тексте)
- Преобразование данных в формат DocBin
- Установите config.cfg
- Тренироваться
- Оценивать
Хорошо…. Начнем с части оптимизации.
. Исходная модель:
Мы обучили нашу модель, используя пространственную модель с пустой английской моделью.
И результаты были не так хороши.
ROC AUC: 0,80
Оценка F1: 0,63
. Первая оптимизация:
Давайте потренируемся, используя маленькую модель Spacy. (Я использовал маленькую модель примерно с мощностью моего компьютера, для обучения на большой модели я буду использовать Google Colab)
nlp = spacy.load("en_core_web_sm")
Чтобы изменить это, нам нужно создать новый config.cfg.
Доступ -› https://spacy.io/usage/training
. Вторая оптимизация:
Давайте почистим текст.
Для этого я использую 3 функции (конечно, вы можете улучшить с помощью множества других функций и очистки конвейеров)
def remove_emoji(text):
emoji_pattern = re.compile("["
u"\\U0001F600-\\U0001F64F" # emoticons
u"\\U0001F300-\\U0001F5FF" # symbols & pictographs
u"\\U0001F680-\\U0001F6FF" # transport & map symbols
u"\\U0001F1E0-\\U0001F1FF" # flags (iOS)
u"\\U00002702-\\U000027B0"
u"\\U000024C2-\\U0001F251"
"]+", flags=re.UNICODE)
return emoji_pattern.sub(r'', text)
def remove_url(text):
url_pattern = re.compile('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
return url_pattern.sub(r'', text)
def clean_text(text ):
delete_dict = {sp_character: '' for sp_character in string.punctuation}
delete_dict[' '] = ' '
table = str.maketrans(delete_dict)
text1 = text.translate(table)
textArr= text1.split()
return (' '.join([w for w in textArr if (not w.isdigit() and ( not w.isdigit() and len(w)>3))])).lower()
Есть много подобных проектов, и я нашел эту функцию в этом репо:
Результаты намного лучше.
После обучения в 20 эпох и почти 1 часа обучения, посмотрите на страницу обучения:
Ничего себе, модель намного лучше, оценка F1 для модели составляет 85%.
Будущие оптимизации:
. Используйте большую модель: en_core_web_lg или en_core_web_trf.
. Улучшить конвейер очистки тестов
. Протестируйте другие подходы: сбалансируйте данные, оцените выбросы
И еще многое можно сказать об оптимизации.
Большое спасибо за чтение этой статьи.
Если у вас есть предложения по следующим темам, скажите, пожалуйста.
Вот и все, если вам понравилось, поделились этой статьей.
Спасибо.