Optuna — это мощная и гибкая библиотека Python для оптимизации гиперпараметров. Он предоставляет интуитивно понятную и эффективную основу для автоматического поиска оптимальных гиперпараметров моделей машинного обучения. В этом сообщении блога мы рассмотрим практическое применение Optuna и продемонстрируем, как он может упростить процесс настройки гиперпараметров.

Монтаж

Прежде чем мы углубимся в примеры, давайте установим Optuna с помощью pip, установщика пакетов Python.

pip install optuna

Основное использование Optuna

Optuna следует трехэтапному процессу: определение пространства поиска, указание целевой функции и запуск оптимизации. Рассмотрим пример использования Optuna для оптимизации гиперпараметров классификатора машины опорных векторов (SVM).

import optuna
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Load the dataset
X, y = datasets.load_iris(return_X_y=True)

# Split the dataset into training and validation sets
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the objective function to be optimized
def objective(trial):
    # Define the hyperparameters to search
    C = trial.suggest_loguniform('C', 1e-4, 1e4)
    kernel = trial.suggest_categorical('kernel', ['linear', 'rbf', 'poly'])
    
    # Create the SVM classifier with the suggested hyperparameters
    clf = SVC(C=C, kernel=kernel)
    
    # Train the classifier and evaluate on the validation set
    clf.fit(X_train, y_train)
    accuracy = clf.score(X_valid, y_valid)
    
    return accuracy

# Run the optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

# Print the best hyperparameters and the corresponding accuracy
best_params = study.best_params
best_accuracy = study.best_value
print('Best Hyperparameters:', best_params)
print('Best Accuracy:', best_accuracy)

Расширенные возможности Оптуны

Optuna предоставляет расширенные функции, такие как обрезка, распараллеливание и интеграция со средами машинного обучения. Эти функции повышают эффективность и масштабируемость процесса оптимизации гиперпараметров. Рассмотрим пример использования Optuna с отсечением для оптимизации модели нейронной сети.

import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

# Generate a synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)

# Split the dataset into training and validation sets
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=42)

# Define the PyTorch model
class NeuralNetwork(nn.Module):
    def __init__(self, input_dim):
        super(NeuralNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 10)
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the objective function to be optimized
def objective(trial):
    # Define the hyperparameters to search
    lr = trial.suggest_loguniform('lr', 1e-5, 1e-1)
    momentum = trial.suggest_uniform('momentum', 0.0, 1.0)
    
    # Create the PyTorch model with the suggested hyperparameters
    model = NeuralNetwork(input_dim=X_train.shape[1])
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    
    # Convert the data into PyTorch tensors and create data loaders
    train_dataset = TensorDataset(torch.tensor(X_train).float(), torch.tensor(y_train).long())
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    
    # Training loop
    for epoch in range(10):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        
        # Perform early stopping based on intermediate validation accuracy
        accuracy = evaluate_model(model, X_valid, y_valid)
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    
    # Return the final validation accuracy
    final_accuracy = evaluate_model(model, X_valid, y_valid)
    return final_accuracy

# Evaluate the model on the validation set
def evaluate_model(model, X, y):
    inputs = torch.tensor(X).float()
    targets = torch.tensor(y).long()
    outputs = model(inputs)
    _, predicted = torch.max(outputs, 1)
    accuracy = (predicted == targets).sum().item() / len(targets)
    return accuracy

# Run the optimization with pruning
study = optuna.create_study(direction='maximize', pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=100)

# Print the best hyperparameters and the corresponding accuracy
best_params = study.best_params
best_accuracy = study.best_value
print('Best Hyperparameters:', best_params)
print('Best Accuracy:', best_accuracy)

Optuna — мощная библиотека для оптимизации гиперпараметров в Python. В этом сообщении блога мы рассмотрели практические применения Optuna, в том числе определение пространства поиска, указание целевой функции и запуск процесса оптимизации. Мы также продемонстрировали расширенные функции, такие как обрезка и интеграция с платформами машинного обучения.

С Optuna вы можете упростить процесс настройки гиперпараметров и найти оптимальный набор гиперпараметров для ваших моделей машинного обучения. Автоматизировав этот процесс, вы сможете сэкономить время и усилия и добиться более высокой производительности и точности своих моделей.

Начните использовать Optuna, чтобы оптимизировать свои модели машинного обучения и полностью раскрыть их потенциал, найдя лучшие гиперпараметры. Экспериментируйте с различными алгоритмами, пространствами поиска и расширенными функциями для достижения оптимальной производительности в ваших проектах.

Связаться с автором: https://linktr.ee/harshita_aswani

Ссылка: