Линейная регрессия за 2 минуты (с использованием PyTorch)

Это вторая часть серии Primer Series PyTorch.

Линейная регрессия - это линейный подход к моделированию взаимосвязи между входными данными и прогнозами.

Мы находим "линейное соответствие" данным.

Подгонка: мы пытаемся предсказать переменную y, подгоняя кривую (здесь линия) к данным. Кривая линейной регрессии следует линейной зависимости между скалярной (x) и зависимой переменной.

Создание моделей в PyTorch

  1. Создать класс
  2. Объявите свой форвардный пас
  3. Настройте гиперпараметры
class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__() 
        # Calling Super Class's constructor
        self.linear = nn.Linear(input_dim, output_dim)
        # nn.linear is defined in nn.Module
    def forward(self, x):
        # Here the forward pass is simply a linear function
        out = self.linear(x)
        return out
input_dim = 1
output_dim = 1

Шаги

  1. Создать экземпляр модели
  2. Выберите критерий убытка
  3. Выберите гиперпараметры
model = LinearRegressionModel(input_dim,output_dim)
criterion = nn.MSELoss()# Mean Squared Loss
l_rate = 0.01
optimiser = torch.optim.SGD(model.parameters(), lr = l_rate) #Stochastic Gradient Descent
epochs = 2000

Обучение модели

for epoch in range(epochs):
    epoch +=1
    #increase the number of epochs by 1 every time
inputs = Variable(torch.from_numpy(x_train))
    labels = Variable(torch.from_numpy(y_correct))
    #clear grads as discussed in prev post
optimiser.zero_grad()
#forward to get predicted values
outputs = model.forward(inputs)
    loss = criterion(outputs, labels)
    loss.backward()# back props
    optimiser.step()# update the parameters
    print('epoch {}, loss {}'.format(epoch,loss.data[0]))

Наконец, распечатайте прогнозируемые значения

predicted =model.forward(Variable(torch.from_numpy(x_train))).data.numpy()
plt.plot(x_train, y_correct, 'go', label = 'from data', alpha = .5)
plt.plot(x_train, predicted, label = 'prediction', alpha = 0.5)
plt.legend()
plt.show()
print(model.state_dict())

Если вы хотите прочитать о неделе 2 в моем путешествии по самостоятельному вождению, вот запись в блоге

В следующей части серии будет обсуждаться линейная регрессия.

Вы можете найти меня в Twitter @ bhutanisanyam1, свяжитесь со мной в Linkedin здесь

Подпишитесь на мою рассылку, чтобы получать еженедельный список статей, посвященных глубокому обучению и компьютерному зрению

Помогите мне работать в позднюю ночную смену, купив мне чашку кофе