Я думаю, что линейная регрессия — лучший алгоритм для знакомства с Jax. Потому что это просто, и обычно все начинают свое путешествие по машинному обучению именно с этого алгоритма.

Во-первых, что такое JAX?

Jax — это библиотека Python, основанная на Numpy. Это позволяет нам выполнять числовые вычисления на массивах Numpy. В этой статье мы будем использовать Jax для получения производной функции стоимости (для градиентного спуска). Также у Jax есть функция Just-In-Time (JIT), которая обеспечивает более быстрые вычисления.

Давайте приступим к написанию кода!

Во-первых, мы реализуем необходимые библиотеки.

from jax import grad, jit
import jax.numpy as jnp
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split

Затем давайте создадим наш собственный линейный набор данных.

X,y= make_regression(n_samples = 150, n_features=  2, noise = 5)
y=y.reshape((y.shape[0],1))
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.15) #Splitting data into train and test
fig = plt.figure(figsize=(8,6))
plt.scatter(X[:,0], y, c='r')
plt.scatter(X[:,1], y, c='g')
plt.show()

Вот наша функция среднеквадратичной ошибки (MSE) для расчета потерь.

def loss(w,b,X,y):
  pred = X.dot(w)+ b
  return ((pred-y)**2).mean()

Инициализация веса и другие переменные.

Weights = jnp.zeros((X_train.shape[1],1)) # jnp array with zero initialization
bias = 0.
l_rate = 0.001
n_iter = 3000

Вот наши функции градиента для весов и смещения.

gradW = jit(grad(loss, argnums=0)) #wrapped with jit function for faster processing
gradb = jit(grad(loss,argnums=1))

Параметр «argnums» позволяет нумеровать аргументы

И, наконец, тренировочный цикл!

for _ in range(n_iter):
  dW = gradW(Weights,bias,X_train,y_train)
  db = gradb(Weights,bias,X_train,y_train)
  print(loss(Weights,bias,X_train,y_train))
  Weights -= dW*l_rate
  bias-= db*l_rate

Последняя MSE (среднеквадратичная ошибка) — «17,813536». (Начало с «7545,73»)

loss(Weights, bias, X_test, y_test) #Model's Loss on test set

››› DeviceArray(18.755125, dtype=float32)

Посмотрим результаты в таблице.

fig, ax = plt.subplots()
ax.set_title("Regression Line for feature 1")
ax.set_xlabel("X1 value")
ax.set_ylabel("Y value")
plt.scatter(X[:,0], y, c='r')
plt.plot(X[:,0], X[:,0]*Weights[0]+bias)
plt.show()

fig, ax = plt.subplots()
ax.set_title("Regression Line for feature 2")
ax.set_xlabel("X2 value")
ax.set_ylabel("Y value")
plt.scatter(X[:,1], y, c='g')
plt.plot(X[:,1], X[:,1]*Weights[1]+bias)
plt.show()

В заключение

Мы узнали, что такое Jax и как его использовать простым способом. Кроме того, мы получили 17,813536 потерь MSE на тренировочном наборе и 18,755125 потерь MSE на тестовом наборе, что является хорошим результатом для такой простой модели. Возможно, вы сможете использовать JAX для реализации более сложных алгоритмов машинного обучения.

Если у вас есть какие-либо вопросы или предложения, пожалуйста, оставьте свои комментарии.