Pyspark — мощная библиотека, предлагающая множество возможностей для манипулирования и потоковой передачи данных в больших масштабах. Несмотря на то, что библиотека поддерживает возможности машинного обучения, в библиотеке нет реализации кодирования One Hot.

В этой статье мы создадим простой кодер One Hot, который сделает всю работу за нас.

Набор данных

Для проверки реализации нам понадобится всего несколько строк с синтетическими данными:

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType


# Define the schema for the DataFrame
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("country", StringType(), False),
    StructField("religion", StringType(), False)
])

# Sample data for the DataFrame
data = [
    (1, "United States", "Christianity"),
    (2, "Canada", "Atheism"),
    (3, "United Kingdom", "Islam"),
    (4, "Australia", "Hinduism"),
    (5, "Germany", "Buddhism"),
    # ... Add more rows with different country and religion values
]

# Create a DataFrame using the provided schema and data
df = spark.createDataFrame(data, schema=schema)

Создание кодировщика One Hot

Для кодирования One Hot нам нужно сохранить всего несколько точек данных:

  • имена столбцов
  • категории
  • что мы хотим сделать, если во время вывода была обнаружена неизвестная категория

Поэтому наша единственная цель — сохранить эту информацию и применить некоторые преобразования к кадру данных Pyspark на ее основе. Мы храним нашу информацию во вложенных словарях:

self.conditions: Dict[str, Dict[str, int]] = {}
self.categories: Dict[str, Dict[str, int]] = {}

Самое сложное здесь — это вывод. Мы не хотим жестко запрограммировать, какие столбцы добавлять, а добавляем столбцы в цикле. К счастью, Pyspark позволяет нам создавать выражения SQL, которые помогают нам достичь этой цели:

# Build the dynamic expression using F.when and F.otherwise
expression = F.expr(f"CASE WHEN {col} == '{cat}' THEN 1 ELSE 0 END")

# Apply the expression to the DataFrame
df = df.withColumn(str(col) + "_" + str(cat), F.lit(0))
df = df.withColumn(str(col) + "_" + str(cat), expression)

Мы имитируем API sklearn, чтобы сделать использование кодировщика более доступным. Вот последний класс:

from typing import Dict, List
from pyspark.sql import DataFrame
from pyspark.sql import functions as F


class PysparkOnehotEncoder:
    def __init__(self):
        self.conditions: Dict[str, Dict[str, int]] = {}
        self.categories: Dict[str, Dict[str, int]] = {}

    def fit(self, df: DataFrame, cat_cols: List[str]) -> None:
        for col in cat_cols:
            cats = df.select(col).distinct().collect()
            length = len(cats)
            cats: List[str] = [df.select(col).distinct().collect()[cat][0] for cat  in range(length)]

            self.conditions[col] = {}
            self.categories[col] = {}
            for idx, cat in enumerate(cats):
                self.conditions[col][f"{col} == '{cat}'"] = idx
                self.categories[col][cat] = idx # useful for reverse_transform at some point

    def predict(self, df: DataFrame) -> DataFrame:
        for col, cats_dict in self.categories.items():
            for cat, _idx in cats_dict.items():
                # Build the dynamic expression using F.when and F.otherwise
                expression = F.expr(f"CASE WHEN {col} == '{cat}' THEN 1 ELSE 0 END")

                # Apply the expression to the DataFrame
                df = df.withColumn(str(col) + "_" + str(cat), F.lit(0))
                df = df.withColumn(str(col) + "_" + str(cat), expression)
            df = df.drop(col)
        return df

    def fit_predict(self, df: DataFrame, cat_cols: List[str]):
        self.fit(df, cat_cols)
        df = self.predict(df)
        return df

Тестирование кодировщика

Мы создаем экземпляр класса и запускаем метод fit:

onehot_encoder = PysparkOnehotEncoder()

onehot_encoder.fit(df, ["country", "religion"])

Теперь мы создадим некоторые новые данные специально для вывода:

# Sample data for the DataFrame
data = [
    (1, "United States", "Christianity"),
    (2, "Canada", "Atheism"),
    (3, "United Kingdom", "Islam"),
    (4, "Australia", "Hinduism"),
    (5, "Germany", "Buddhism"),
    (6, "Italy", "Unknown"), # add unseen categories
]

# Create a DataFrame using the provided schema and data
pred_df = spark.createDataFrame(data, schema=schema)

Наконец, мы трансформируем наш новый фрейм данных:

pred_df = onehot_encoder.predict(pred_df)

Результат выглядит хорошо:

И вот оно: наш собственный кодер One Hot, похожий на sklearn.

Заключительные слова

Несмотря на то, что за последние десятилетия кодирование One Hot столкнулось с жесткой конкуренцией из-за его потенциальных недостатков, по-прежнему удивительно, что в Pyspark не наблюдается никакой реализации.

Я надеюсь, что эта реализация может стать хорошей отправной точкой для того, чтобы помочь другим, нуждающимся в таком кодировщике.

Если вам понравилась статья, похлопайте в ладоши или даже подпишитесь, чтобы не пропустить будущий контент.

ПИСАТЕЛЬ на MLearning.ai / Code LLM Wonders / AI art Copyright