Spark - перехват линейной регрессии MLlib и вес NaN

Я пытаюсь построить модель регрессии в Spark, используя некоторые пользовательские данные, а перехват и вес всегда равны nan. Это мои данные:

data = [LabeledPoint(0.0, [27022.0]), LabeledPoint(1.0, [27077.0]), LabeledPoint(2.0, [27327.0]), LabeledPoint(3.0, [27127.0])]

Выход:

(weights=[nan], intercept=nan)  

Однако, если я использую этот набор данных (взятый из примеров Spark), он возвращает не nan вес и перехват.

data = [LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]), LabeledPoint(3.0, [2.0]),LabeledPoint(2.0, [3.0])]

Выход:

(weights=[0.798729902914], intercept=0.3027117101297481) 

Это мой текущий код

model = LinearRegressionWithSGD.train(sc.parallelize(data), intercept=True)

Я что-то упускаю? Это потому, что цифры в моих данных такие большие? Я впервые использую MLlib, поэтому могу упустить некоторые детали.

Спасибо


person user3276768    schedule 20.04.2015    source источник


Ответы (1)


Линейная регрессия MLlib основана на SGD, поэтому вам необходимо настроить итерации и размер шага, см. https://spark.apache.org/docs/latest/mllib-optimization.html.

Я попробовал ваши пользовательские данные, подобные этому, и получил некоторые результаты (в scala):

val numIterations = 20
val model = LinearRegressionWithSGD.train(sc.parallelize(data), numIterations)
person selvinsource    schedule 21.04.2015
comment
на Питоне model = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=20) - person selvinsource; 21.04.2015