В оценщике Tensorflow, как это работает, когда model_fn вызывается несколько раз?

def model_fn(features, labels, mode, params):
  """Model function for Estimator."""

  # Connect the first hidden layer to input layer
  # (features["x"]) with relu activation
  first_hidden_layer = tf.layers.dense(features["x"], 10, activation=tf.nn.relu)

  # Connect the second hidden layer to first hidden layer with relu
  second_hidden_layer = tf.layers.dense(
      first_hidden_layer, 10, activation=tf.nn.relu)

  # Connect the output layer to second hidden layer (no activation fn)
  output_layer = tf.layers.dense(second_hidden_layer, 1)

  # Reshape output layer to 1-dim Tensor to return predictions
  predictions = tf.reshape(output_layer, [-1])

  # Provide an estimator spec for `ModeKeys.PREDICT`.
  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={"ages": predictions})

  # Calculate loss using mean squared error
  loss = tf.losses.mean_squared_error(labels, predictions)

  # Calculate root mean squared error as additional eval metric
  eval_metric_ops = {
      "rmse": tf.metrics.root_mean_squared_error(
          tf.cast(labels, tf.float64), predictions)
  }

  optimizer = tf.train.GradientDescentOptimizer(
      learning_rate=params["learning_rate"])
  train_op = optimizer.minimize(
      loss=loss, global_step=tf.train.get_global_step())

  # Provide an estimator spec for `ModeKeys.EVAL` and `ModeKeys.TRAIN` modes.
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)

Выше приведен пример model_fn, используемый оценщиком Tensorflow.

Как упоминалось в руководстве, этот model_fn можно вызывать в другом контексте (обучение, прогнозирование, оценка). Однако я немного запутался, потому что каждый раз, когда вызывается model_fn, вместо повторного использования существующего графа создается новый граф (или создается новый узел в графе)

Например, сначала я вызвал model_fn в режиме TRAIN, затем я вызвал model_fn в режиме PREDICT. Как я могу убедиться, что PREDICT повторно использует вес обученных значений?


person Hanfei Sun    schedule 22.10.2017    source источник


Ответы (1)


См. эту тему: https://github.com/tensorflow/tensorflow/issues/13895

График каждый раз перестраивается и данные загружаются из контрольной точки.

person Hanfei Sun    schedule 24.10.2017