Я думаю, что линейная регрессия — лучший алгоритм для знакомства с Jax. Потому что это просто, и обычно все начинают свое путешествие по машинному обучению именно с этого алгоритма.
Во-первых, что такое JAX?
Jax — это библиотека Python, основанная на Numpy. Это позволяет нам выполнять числовые вычисления на массивах Numpy. В этой статье мы будем использовать Jax для получения производной функции стоимости (для градиентного спуска). Также у Jax есть функция Just-In-Time (JIT), которая обеспечивает более быстрые вычисления.
Давайте приступим к написанию кода!
Во-первых, мы реализуем необходимые библиотеки.
from jax import grad, jit import jax.numpy as jnp import matplotlib.pyplot as plt from sklearn.datasets import make_regression from sklearn.model_selection import train_test_split
Затем давайте создадим наш собственный линейный набор данных.
X,y= make_regression(n_samples = 150, n_features= 2, noise = 5) y=y.reshape((y.shape[0],1)) X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.15) #Splitting data into train and test fig = plt.figure(figsize=(8,6)) plt.scatter(X[:,0], y, c='r') plt.scatter(X[:,1], y, c='g') plt.show()
Вот наша функция среднеквадратичной ошибки (MSE) для расчета потерь.
def loss(w,b,X,y): pred = X.dot(w)+ b return ((pred-y)**2).mean()
Инициализация веса и другие переменные.
Weights = jnp.zeros((X_train.shape[1],1)) # jnp array with zero initialization bias = 0. l_rate = 0.001 n_iter = 3000
Вот наши функции градиента для весов и смещения.
gradW = jit(grad(loss, argnums=0)) #wrapped with jit function for faster processing gradb = jit(grad(loss,argnums=1))
Параметр «argnums» позволяет нумеровать аргументы
И, наконец, тренировочный цикл!
for _ in range(n_iter): dW = gradW(Weights,bias,X_train,y_train) db = gradb(Weights,bias,X_train,y_train) print(loss(Weights,bias,X_train,y_train)) Weights -= dW*l_rate bias-= db*l_rate
Последняя MSE (среднеквадратичная ошибка) — «17,813536». (Начало с «7545,73»)
loss(Weights, bias, X_test, y_test) #Model's Loss on test set
››› DeviceArray(18.755125, dtype=float32)
Посмотрим результаты в таблице.
fig, ax = plt.subplots() ax.set_title("Regression Line for feature 1") ax.set_xlabel("X1 value") ax.set_ylabel("Y value") plt.scatter(X[:,0], y, c='r') plt.plot(X[:,0], X[:,0]*Weights[0]+bias) plt.show()
fig, ax = plt.subplots() ax.set_title("Regression Line for feature 2") ax.set_xlabel("X2 value") ax.set_ylabel("Y value") plt.scatter(X[:,1], y, c='g') plt.plot(X[:,1], X[:,1]*Weights[1]+bias) plt.show()
В заключение
Мы узнали, что такое Jax и как его использовать простым способом. Кроме того, мы получили 17,813536 потерь MSE на тренировочном наборе и 18,755125 потерь MSE на тестовом наборе, что является хорошим результатом для такой простой модели. Возможно, вы сможете использовать JAX для реализации более сложных алгоритмов машинного обучения.
Если у вас есть какие-либо вопросы или предложения, пожалуйста, оставьте свои комментарии.