Проект Data Scientist Capstone
Цель этого поста — поделиться своим опытом работы над проектом Sparkify и показать вам его этапы. В качестве завершающего проекта моей степени Data Scientist в области нанотехнологий в Udacity было несколько интересных тем для заключительного проекта, и я совершенно уверен, что вернусь к ним позже. Однако я решил работать над проектом, который обрабатывает журналы взаимодействия с пользователем воображаемой платформы потоковой передачи музыки под названием Sparkify. Итак, приступим…
Цель проекта
… Sparkify предназначен для анализа заданных журналов взаимодействия с пользователем платформы потоковой передачи песен, аналогичной Spotify или Pandora, и прогнозирования будущих действий пользователей (обновление, понижение версии, отмена их службы). Сделанные прогнозы могут быть полезны для предложения более персонализированных услуг и обеспечения удовлетворенности пользователей платформой. Для этого мы используем pyspark и такие библиотеки, как sql и ml.
Основными этапами проекта являются загрузка, очистка, исследование данных, создание дополнительных функций, а также создание и прогнозирование модели.
Записная книжка начинается с импорта необходимых библиотек. Затем мы создаем сеанс искры и загружаем данные из файла JSON. Мы обращаемся к схеме данных, их описанию и нескольким строкам, чтобы получить представление об обработанных данных.
Уборка
Как обычно, эта часть занимает много времени, чтобы просмотреть все столбцы, очистить, преобразовать данные и заполнить все недостающие значения или удалить связанные строки/столбцы.
Итак, я начал с проверки всех строк с пустыми userId и sessionId. Я также отфильтровал строки с userId, имеющим в качестве значения пустую строку, поскольку нас интересует поведение зарегистрированных пользователей.
Разработка функций
Чтобы предсказать будущие действия пользователя, нам нужно рассмотреть функции, которые было бы интересно добавить в наш фрейм данных, и другие, которые нам на самом деле не нужны, и поэтому мы можем их удалить.
Исходный набор данных содержит столбцы:
root |-- artist: string (nullable = true) |-- auth: string (nullable = true) |-- firstName: string (nullable = true) |-- gender: string (nullable = true) |-- itemInSession: long (nullable = true) |-- lastName: string (nullable = true) |-- length: double (nullable = true) |-- level: string (nullable = true) |-- location: string (nullable = true) |-- method: string (nullable = true) |-- page: string (nullable = true) |-- registration: long (nullable = true) |-- sessionId: long (nullable = true) |-- song: string (nullable = true) |-- status: long (nullable = true) |-- ts: long (nullable = true) |-- userAgent: string (nullable = true) |-- userId: string (nullable = true)
Новые возможности
Мы добавляем следующие функции для получения дополнительной информации.
- Churn, чтобы пометить запись журнала с пользователем, который выполнил отмену.
user_log = flag_row(user_log, “page”,‘Cancellation Confirmation’, “Churn”)
- Понижен, чтобы пометить записи журнала с пользователем, понизившим версию служб.
user_log = flag_row(user_log, “page”, ‘Submit Downgrade’, “Downgraded”)
- hour час записи журнала, поэтому нам не нужен столбец ts типа timestamp.
- registration_year год регистрации пользователя на платформе.
- registration_month месяц регистрации пользователя на платформе.
- фаза, котораяидентифицирует запись журнала при использовании для перехода на более раннюю версию.
Технические столбцы интересны для нашего прогноза. Потому что пользователи, испытывающие технические проблемы, более склонны не быть довольными платформой и уходить. Поэтому мы сохраняем их в нашем последнем кадре данных.
В итоге получаем:
numerical_columns = [ ‘itemInSession’, ‘length’]
и
categorical_columns =[ “auth”, “artist”, “gender”, “level”, “page”, “status”, “userId”, “method” “Churn”, “Downgraded”, “phase”, “registration_year”, “registration_month”, “hour”]
Мы меняем тип numerical_columns на IntegerType и собираем их вместе как разреженный вектор. Обработка числовых признаков завершается преобразованием StandardScaler:
assembler = VectorAssembler(inputCols=numerical_columns, outputCol=”num_features”) registreted_user_log = assembler.transform(registreted_user_log) standard_scaler = StandardScaler(inputCol=”num_features”, outputCol=”SScaler_num_features”, withMean=True, withStd=True) standard_scaler_model = standard_scaler.fit(registreted_user_log) registreted_user_log = standard_scaler_model.transform(registreted_user_log)
Обработку categorical_columns выполняет StringIndexer, поэтому мы передаем в нашу модель индексы, а не строковые значения. Вот функции, которые мы используем для обучения нашей модели:
['Idx_auth', 'Idx_artist', 'Idx_gender', 'Idx_level', 'Idx_page', 'Idx_status', 'Idx_userId', 'Idx_phase', 'Idx_registration_year', 'Idx_registration_month', 'Idx_hour', 'SScaler_num_features']
Обучение модели
Чтобы предсказать будущие действия пользователей на основе журналов, сгенерированных их взаимодействием с платформой, нам необходимо использовать контролируемые алгоритмы машинного обучения. Потому что мы разработали функции Churn и Downgrade, которые мы хотели бы предсказать с помощью нашей модели. Итак, у нас есть проблема с классификацией. Я использовал классификатор LogisticRegression для этой проблемы.
Вот churn_pipelineслогистической регрессией, используемой для прогнозирования ярлыкаChurn.
assembler = VectorAssembler(inputCols = features, outputCol=’features’) pca = PCA(k=10, inputCol=’features’, outputCol=’pcaFeature’) churn_indexer = StringIndexer(inputCol=”Idx_Churn”, outputCol=”label”) lr = LogisticRegression(maxIter=10, regParam=0.3) churn_pipeline = Pipeline (stages=[assembler, pca, churn_indexer, lr])
Обучайте, тестируйте, проверяйте
Мы продолжаем делить набор данных на две части:
rest, validation = processed_user_log.randomSplit([0.7, 0.3], seed = 42)
Я использовал CrossValidator, чтобы получить лучшую модель для нашего набора данных:
CrossValidator
— это Оценщик для настройки модели, т. е. поиска лучшей модели для заданных параметров и набора данных.CrossValidator
разбивает набор данных на набор непересекающихся произвольно разделенных numFolds пары наборов данных для обучения и проверки.CrossValidator
генерируетCrossValidatorModel
для хранения лучшей модели и средних показателей перекрестной проверки.
Итак, реализация:
churn_crossval = CrossValidator(estimator = churn_pipeline, estimatorParamMaps = paramGrid, evaluator = MulticlassClassificationEvaluator(), numFolds=3)
После обучения модели мы проверяем средние показатели перекрестной проверки с помощью avgMetrics:
[0.9997074462177573, 0.9997074462177573, 0.9997074462177573, 0.9997074462177573]
Кажется, все в порядке с прогнозом.
Прогноз
Результатом good_predictions/number_of_total_predictions, выполненного моделью в проверочном наборе, является:
# good_predictions/number_of_total_predictions 83277/83291
Вывод
Как видно из результата, прогнозы достаточно хорошие. Использование LogisticRegression, CrossValidator позволяет нам с высокой точностью прогнозировать будущие действия пользователей. Основными проблемами этого анализа были части разработки функций и создания конвейера, поскольку они немного различаются между pyspark&ml и python& Научись учиться.
Будущие улучшения
Есть несколько улучшений, которые следует рассмотреть на будущее:
- Было бы интересно протестировать несколько алгоритмов классификации (вроде RandomForest, SVM, Naive Bayes) и сравнить результаты/метрики.
- Улучшите модель, протестировав модель на большом наборе данных и адаптировав ее на основе результатов/показателей.
- Создайте сценарий, чтобы регулярно обрабатывать вновь созданные пользовательские журналы, выполнять прогнозирование и продолжать обучать модель на основе вновь созданных пользовательских журналов.
- Автоматизируйте процесс обнаружения пользователей, «желающих» отказаться от услуг платформы или понизить их рейтинг, и предложить им какие-либо акции.
- И, конечно же, для этих точек в реальной среде нам нужно отслеживать эксплуатационные расходы и проверять результат с помощью тестов (например, A/B-тестирования…)
Этот анализ был выполнен в рабочей области, предоставленной Udacity, с использованием набора данных «mini_sparkify_event_data.json» небольшого размера, вдохновленного материалами, изученными на уроках. Следующим шагом будет продолжение работы над точками улучшения (здесь ниже), используя полный набор данных (размером 12 ГБ) в облаке…
Полную тетрадь для этого проекта вы можете найти здесь. Мне, как начинающему специалисту по данным, будет приятно вас услышать!
Бьентот :)