Как воспроизвести значение потерь TensorFlow за пределами Tensorflow, используя окончательный прогноз?

Я пытаюсь изучить Tensorflow. Я собрал игрушечную демонстрацию простой модели. Я заметил, что не могу получить то же значение потерь, что и Tensorflow, если возьму окончательный прогноз и сравню его с фактическими значениями.

Игрушечная модель представляет собой регрессию временных рядов. Попытка предсказать следующее значение во временном ряду. В модели нет ничего необычного, я не ищу советов по моделированию. Я просто хочу понять, что делает Tensorflow, и этот код вычисляет значение потерь, которое я не могу воспроизвести.

Я делаю это в R, используя пакет tensorflow, который является прямой привязкой к библиотеке python. Так что код довольно легко портируется.

Значения временных рядов, которые я использую:

prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194,
       101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094,
       113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375,
       122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144,
       123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694,
       121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481,
       119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449,
       121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811,
       123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250,
       121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765,
       126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390,
       123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979,
       122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941,
       118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342,
       119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776,
       111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324,
       102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,
       105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056,
       111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697,
       107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201,
       107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842,
       117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526,
       110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136,
       113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336,
       107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279,
       104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245,
       94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732,
       91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522,
       93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004,
       100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789,
       103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300,
       107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242,
       104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167,
       91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,
       92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282,
       96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149,
       94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088,
       94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974,
       96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880,
       106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567,
       107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514,
       106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749,
       112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021,
       111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296,
       115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287,
       109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942,
       106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963,
       109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840,
       114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740,
       115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963,
       118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,
       120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537,
       132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934,
       136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682,
       138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835,
       143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502,
       141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919,
       143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527,
       152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000,
       153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000,
       154.990000,148.980000,145.320000,146.590000)

и код, который я использую,

#split into train and test
v.train = prices[1:500]
v.test = prices[-(1:500)]

sampleSize = 30

getJoinedLaggedData = function(dataVector){
  data.len = length(dataVector)
  result = NULL
  for (i in seq(1, data.len - sampleSize + 1, by = 1)) {
    result = rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])
  }
  list(input=result[,-sampleSize], output = result[,sampleSize])
}

m.train = getJoinedLaggedData(v.train)

#build the model
indata = tf$placeholder(tf$float32, shape(471, sampleSize - 1))
outData = tf$placeholder(tf$float32, shape(471))

w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300),stddev = 1.0 / sqrt(sampleSize)))
b1 = tf$Variable(tf$zeros(shape(300)))
l1 <- tf$matmul(indata, w1) + b1

w2 <- tf$Variable(tf$truncated_normal(shape(300, 50),stddev = 1.0 / sqrt(300)))
b2 <- tf$Variable(tf$zeros(shape(50)))
l2 <- tf$matmul(l1, w2) + b2

w3 <- tf$Variable(tf$truncated_normal(shape(50, 1),stddev = 1.0 / sqrt(50)))
b3 <- tf$Variable(tf$zeros(shape(1)))
pred <- tf$matmul(l2, w3) + b3

#loss function
loss <- tf$reduce_mean(tf$abs(tf$subtract(x = pred, y = outData)))

#trainer
optimizer <- tf$train$GradientDescentOptimizer(0.00003)
train.op <- optimizer$minimize(loss)

#run the model
init = tf$global_variables_initializer()

sess = tf$Session()
sess$run(init)

for (i in 1:1000){
  values = sess$run(list(train.op, loss, pred), feed_dict = dict(indata = m.train$input, outData = m.train$output))
  loss_value = values[[2]]
  pred_value = values[[3]]
  print(loss_value)
}

myLoss = mean(abs(pred_value - m.train$output))
print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))

когда я запускаю это, Tensorflow вычисляет потерю около 11,71, а я вычисляю потерю около 5,03.

Это также относится и к другим функциям потерь, которые я пробовал.

Любой совет будет принят с благодарностью!


person Chechy Levas    schedule 24.06.2017    source источник


Ответы (1)


Разница возникла из-за вывода m.train$, tensorflow считывает его как вектор, а не как матрицу.

Я изменил код по мере необходимости. пожалуйста, проверьте

prices = c(104.285380,101.347495,101.357033,102.778283,106.727258,106.841721,104.209071,105.134314,104.733694,101.891194,
       101.099492,103.703526,104.495229,107.213726,107.766964,107.881427,104.104147,109.989455,113.413808,111.754094,
       113.156266,113.175344,114.043355,114.405821,113.886963,114.643464,116.845936,119.584662,121.097665,121.691375,
       122.409572,123.257045,123.003282,124.003971,127.360347,126.565541,123.328865,124.884959,123.012858,123.616144,
       123.874695,123.089466,121.049785,121.231728,121.748831,119.230352,117.056607,119.172896,118.349363,119.651694,
       121.653071,123.022434,122.088777,120.561411,121.815862,121.317912,118.148267,118.971800,118.023780,121.011481,
       119.153744,118.981376,120.006005,121.949926,120.666746,120.274132,121.193425,121.710527,121.471128,120.944449,
       121.404096,120.819962,119.460175,122.189325,121.528583,123.166074,124.171550,124.755684,127.025188,125.023811,
       123.185225,119.843213,123.482080,123.242681,120.465651,119.709150,119.948549,122.715809,121.465765,121.028250,
       121.167678,123.994700,123.821617,125.187049,125.071660,125.062044,126.340935,127.446743,124.638953,126.970765,
       126.715948,125.273590,125.518791,124.965887,125.119739,124.388944,123.706228,122.888892,122.523495,123.927390,
       123.648534,122.283102,122.042709,122.696577,122.408106,122.965818,121.735006,122.706193,122.148482,123.186979,
       122.600420,121.879241,119.744552,120.605159,121.735006,121.581154,121.158062,120.859975,117.859871,115.455941,
       118.542587,120.831128,120.783049,121.946551,123.571608,124.638953,126.994804,125.725529,120.408036,120.350342,
       119.715705,118.052185,118.638744,118.263731,117.667556,116.638674,113.888579,110.234605,110.965400,110.705776,
       111.582500,115.639343,109.621692,111.312044,111.225111,112.007502,113.166600,112.529097,111.089883,108.810324,
       102.155170,99.605154,100.204021,105.951215,109.071121,109.428509,108.916574,104.048363,108.510890,106.608038,
       105.545531,108.481913,106.395536,108.733051,110.317151,111.379658,112.316595,112.442164,110.037036,109.583056,
       111.283066,109.534760,110.423402,111.080224,110.800109,108.607482,105.342689,106.202353,105.844965,106.617697,
       107.004063,107.515998,107.004063,105.767692,108.298389,107.796113,107.979637,106.453491,108.047251,107.255201,
       107.921682,109.892149,109.882489,111.563182,115.021157,111.350680,110.645562,115.204681,116.421734,115.426842,
       117.049579,118.392201,117.841629,116.798441,117.436526,116.961193,113.274931,112.634686,112.256359,108.977526,
       110.757603,110.287119,113.779367,115.224769,115.729205,114.225599,115.321776,114.497218,114.283803,114.759136,
       113.827870,112.799597,111.751923,115.467287,114.739735,114.691232,112.159352,112.692890,109.792384,109.113336,
       107.182899,108.007458,105.718095,102.856393,104.117482,104.020475,105.359170,104.796530,103.622747,105.485279,
       104.107781,102.109440,102.196746,99.635764,97.685926,93.563134,94.057869,95.580877,96.968075,94.474998,96.541245,
       94.222780,93.766848,93.892957,93.417623,98.384375,96.463639,96.997177,90.623825,91.273771,94.426495,93.543732,
       91.652098,93.466127,93.708644,91.696830,92.662368,92.642862,91.940652,91.384737,91.667571,94.252091,95.695522,
       93.881481,93.666917,94.486161,92.350274,93.725434,94.369126,94.515420,94.300856,98.045972,98.260536,98.992004,
       100.464693,99.352862,98.533617,98.621394,98.670158,99.733225,99.986801,101.995899,103.351553,103.185754,103.302789,
       103.293036,104.083021,103.507600,103.058966,102.590827,105.019300,106.852847,106.296931,107.272222,108.374300,
       107.096670,108.218254,105.858050,105.975085,106.326190,107.711103,109.271568,109.330085,107.135681,104.824242,
       104.268327,104.482891,103.351553,103.068719,102.483545,101.771582,95.402934,92.486815,91.423748,91.326219,92.828167,
       91.862629,90.936103,90.981767,91.050455,91.668644,90.775704,88.646385,88.823011,92.120021,91.737332,92.787273,
       92.434021,93.434899,94.622215,96.064657,97.752412,98.527602,98.468727,97.987913,96.614159,95.888032,96.084282,
       96.780972,97.173473,97.085160,97.781850,96.977222,95.515156,95.632906,95.318905,95.721219,93.542837,93.317149,
       94.111964,93.758713,94.298402,91.649019,90.314515,91.835457,92.630272,93.807775,94.092339,93.209211,93.739088,
       94.141401,94.867529,95.161904,95.593656,95.053967,96.937972,96.928160,97.958475,97.997725,98.086038,97.565974,
       96.810409,95.515156,94.857716,101.019984,102.383926,102.256363,104.061868,102.521301,103.806742,103.885243,106.032880,
       106.910897,107.344972,106.545878,106.476821,106.723455,108.005951,107.907298,107.749452,107.611337,107.887567,
       107.049012,107.384434,106.575474,106.121668,105.500150,105.381766,104.572806,104.671460,105.292978,106.279514,
       106.249917,106.901031,104.099269,101.741448,104.020346,106.496551,110.265119,114.013955,113.372707,112.050749,
       112.040883,112.021153,113.076746,111.192462,111.360173,111.567346,112.415767,110.669598,111.527885,111.005021,
       111.478558,111.527885,112.356575,112.524286,114.487492,114.734126,115.760124,115.404971,116.046219,115.967296,
       115.888373,115.543086,115.483894,115.030087,116.065950,116.657871,114.033686,112.938631,112.188864,112.011287,
       109.988889,110.087542,108.351239,107.931825,109.488725,110.133301,109.954803,106.890586,107.525246,104.827942,
       106.216260,109.072229,109.032563,109.141645,110.797711,110.867126,110.301883,110.857210,110.639046,110.529963,
       109.597807,108.576401,108.982980,108.199572,109.032563,110.103551,111.184456,112.999187,112.354610,114.228840,
       114.228840,114.853583,115.002331,115.666741,115.974154,116.083236,115.319661,115.547742,116.281568,115.785740,
       115.755990,114.853583,115.180830,115.051914,115.636991,116.926144,117.997132,118.116131,118.750791,118.254963,
       118.046715,118.998705,118.988788,118.780540,118.998705,119.078037,118.968955,120.863018,120.922517,120.932434,
       120.615104,120.337440,127.675694,127.457529,128.002940,129.202844,130.432497,130.938241,131.315071,131.581537,
       132.746769,134.469718,134.957721,134.793393,135.166865,136.142871,136.551200,135.973564,136.103034,136.371934,
       136.431689,139.220278,138.393660,139.210318,138.772112,138.951378,138.433497,138.114801,138.572927,138.632682,
       138.423538,139.887547,140.116610,139.419462,140.883471,139.270074,140.843634,140.345672,140.066813,140.305835,
       143.213935,143.532630,143.343405,143.074505,143.114342,144.179981,143.433038,143.074505,142.755809,142.586502,
       141.052778,141.222086,140.475142,141.251963,140.624531,140.106650,141.859477,141.690170,143.054587,143.950919,
       143.065343,143.203975,143.064546,146.002523,146.908814,146.460648,145.932808,148.352905,152.376439,153.332527,
       152.635380,153.322568,156.100000,155.700000,155.470000,150.250000,152.540000,152.960000,153.990000,153.800000,
       153.340000,153.870000,153.610000,153.670000,152.760000,153.180000,155.450000,153.930000,154.450000,155.370000,
       154.990000,148.980000,145.320000,146.590000)

v.train = prices[1:500]
v.test = prices[-(1:500)]

sampleSize = 30

getJoinedLaggedData = function(dataVector){
  data.len = length(dataVector)
  result = NULL
  for (i in seq(1, data.len - sampleSize + 1, by = 1)) {
    result = rbind(result, dataVector[seq(i, i + sampleSize - 1, by = 1)])
  }
  list(input=matrix(result[,-sampleSize], ncol = sampleSize - 1), output = matrix(result[,sampleSize]))
}

m.train = getJoinedLaggedData(v.train)

#build the model
indata = tf$placeholder(tf$float32, shape(471, sampleSize - 1))
outData = tf$placeholder(tf$float32, shape(471, 1))

w1 <- tf$Variable(tf$truncated_normal(shape(sampleSize - 1, 300),stddev = 1.0 / sqrt(sampleSize)))
b1 = tf$Variable(tf$zeros(shape(300)))
l1 <- tf$matmul(indata, w1) + b1

w2 <- tf$Variable(tf$truncated_normal(shape(300, 50),stddev = 1.0 / sqrt(300)))
b2 <- tf$Variable(tf$zeros(shape(50)))
l2 <- tf$matmul(l1, w2) + b2

w3 <- tf$Variable(tf$truncated_normal(shape(50, 1),stddev = 1.0 / sqrt(50)))
b3 <- tf$Variable(tf$zeros(shape(1)))
pred <- tf$matmul(l2, w3) + b3

#loss function
loss <- tf$reduce_mean(tf$abs(tf$sub(x = pred, y = outData)))

#trainer
optimizer <- tf$train$GradientDescentOptimizer(0.00003)
train.op <- optimizer$minimize(loss)

#run the model
init = tf$initialize_all_variables()

sess = tf$Session()
sess$run(init)

for (i in 1:1000){
  values = sess$run(list(train.op, loss, pred), feed_dict = dict(indata = m.train$input, outData = m.train$output))
  loss_value = values[[2]]
  pred_value = values[[3]]
  print(loss_value)
}

myLoss = mean(abs(pred_value - m.train$output))
print(sprintf("Tensorflow Loss: %f, My Loss: %f", loss_value, myLoss))
person user3256363    schedule 24.06.2017
comment
Превосходно! Для бонусных баллов есть ли способ изменить функцию потерь, чтобы перевести ее из вектора в матрицу? Очевидно, что здесь это на самом деле не нужно, но полезно знать, каковы все углы. - person Chechy Levas; 24.06.2017
comment
Функция потерь всегда должна возвращать одно значение. Извините, я не понимаю, что вы имели в виду, изменяя функцию потерь, чтобы преобразовать ее из вектора в матрицу. - person user3256363; 24.06.2017
comment
Я имею в виду изменить его так, чтобы outData выводился в матрицу, а не в матрицу в файле feed_dict. Первый использует Tensorflow для кастинга. Последний использует R для выполнения приведения. - person Chechy Levas; 24.06.2017
comment
Я не знаю, можете ли вы внести некоторые изменения в функцию потерь. На данный момент вы можете обновить функцию getJoinedLaggedData, чтобы преобразовать вывод m.train$ в матрицу. Я обновил код. - person user3256363; 24.06.2017