Использование предварительно обученной модели Inception_v4

https://github.com/tensorflow/models/blob/master/tutorials/image/imagenet/classify_image.py

Это дает ссылку для загрузки контрольных точек для предварительно обученных моделей Inception v1-4. Однако tar.gz содержит только файл .ckpt.

В руководстве по использованию Inception v3 2012 [Эта ссылка] tar.gz содержит файлы .pb и .pbtxt, которые используются для классификации.

Как я могу использовать только файл .ckpt для создания соответствующих файлов .pb и .pbtxt? ИЛИ Есть ли альтернативный способ использования файла .ckpt для классификации?


person megan adams    schedule 22.03.2017    source источник


Ответы (1)


Даже я тоже пробую модель inception_v4. Во время поиска я смог найти файлы контрольных точек, содержащие веса. Итак, чтобы использовать это, необходимо было загрузить граф inception_v4 из inception_v4.py и восстановить сеанс из файла контрольной точки. Следующий код прочитает файл контрольной точки и создаст файл protobuf.

import tensorflow as tf
slim = tf.contrib.slim
import tf_slim.models.slim.nets as net
# inception_v3_arg_scope
import tf_slim
import inception_v4 as net
import cv2


# checkpoint file
checkpoint_file = '/home/.../inception_v4.ckpt' 

# Load Session
sess = tf.Session()
arg_scope = net.inception_v4_arg_scope()
input_tensor = tf.placeholder(tf.float32, [None, 299, 299, 3])
with slim.arg_scope(arg_scope):
    logits, end_points = net.inception_v4(inputs=input_tensor)

saver = tf.train.Saver()
saver.restore(sess, checkpoint_file)
f = tf.gfile.FastGFile('./mynet.pb', "w")
f.write(sess.graph_def.SerializeToString())
f.close()

# reading the graph
#
with tf.gfile.FastGFile('./mynet.pb', 'rb') as fp:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fp.read())

    with tf.Session(graph=tf.import_graph_def(graph_def, name='')) as sess:
    # op = sess.graph.get_operations()
    # with open('./tensors.txt', mode='w') as fp:
    #     for m in op:
    #     #     print m.values()
    #         fp.write('%s \n' % str(m.values()))
    cell_patch = cv2.imread('./car.jpg')
    softmax_tensor = sess.graph.get_tensor_by_name('InceptionV4/Logits/Predictions:0')
    predictions = sess.run(softmax_tensor, {'Placeholder:0': cell_patch})

Но приведенный выше код не дает никаких прогнозов. Потому что я столкнулся с проблемой ввода данных в график. Но это может быть хорошей отправной точкой для работы с файлами контрольных точек.

Контрольная точка загружается по следующей ссылке контрольные точки

person Nirmal Jith    schedule 23.03.2017