Линейная регрессия за 2 минуты (с использованием PyTorch)
Это вторая часть серии Primer Series PyTorch.
Линейная регрессия - это линейный подход к моделированию взаимосвязи между входными данными и прогнозами.
Мы находим "линейное соответствие" данным.
Подгонка: мы пытаемся предсказать переменную y, подгоняя кривую (здесь линия) к данным. Кривая линейной регрессии следует линейной зависимости между скалярной (x) и зависимой переменной.
Создание моделей в PyTorch
- Создать класс
- Объявите свой форвардный пас
- Настройте гиперпараметры
class LinearRegressionModel(nn.Module): def __init__(self, input_dim, output_dim): super(LinearRegressionModel, self).__init__() # Calling Super Class's constructor self.linear = nn.Linear(input_dim, output_dim) # nn.linear is defined in nn.Module def forward(self, x): # Here the forward pass is simply a linear function out = self.linear(x) return out input_dim = 1 output_dim = 1
Шаги
- Создать экземпляр модели
- Выберите критерий убытка
- Выберите гиперпараметры
model = LinearRegressionModel(input_dim,output_dim) criterion = nn.MSELoss()# Mean Squared Loss l_rate = 0.01 optimiser = torch.optim.SGD(model.parameters(), lr = l_rate) #Stochastic Gradient Descent epochs = 2000
Обучение модели
for epoch in range(epochs): epoch +=1 #increase the number of epochs by 1 every time inputs = Variable(torch.from_numpy(x_train)) labels = Variable(torch.from_numpy(y_correct)) #clear grads as discussed in prev post optimiser.zero_grad() #forward to get predicted values outputs = model.forward(inputs) loss = criterion(outputs, labels) loss.backward()# back props optimiser.step()# update the parameters print('epoch {}, loss {}'.format(epoch,loss.data[0]))
Наконец, распечатайте прогнозируемые значения
predicted =model.forward(Variable(torch.from_numpy(x_train))).data.numpy() plt.plot(x_train, y_correct, 'go', label = 'from data', alpha = .5) plt.plot(x_train, predicted, label = 'prediction', alpha = 0.5) plt.legend() plt.show() print(model.state_dict())
В следующей части серии будет обсуждаться линейная регрессия.
Вы можете найти меня в Twitter @ bhutanisanyam1, свяжитесь со мной в Linkedin здесь
Помогите мне работать в позднюю ночную смену, купив мне чашку кофе