Tensorflow-serving очень удобно использовать для обслуживания модели TF. В этом блоге объясняется, как запустить обслуживание tensorflow на докере и протестировать его с помощью REST/gRPC. Тестирование с помощью REST довольно распространено и легко, однако тестирование с помощью gRPC не так просто, как вы думаете. Также не так много примеров тестирования tensorflow-обслуживания с помощью gRPC. Надеюсь, что это может стать для вас «привет миром» обслуживания tensorflow.
Запустите обслуживание tensorflow на докере
docker pull tensorflow/serving git clone https://github.com/tensorflow/serving TESTDATA=”$(pwd)/serving/tensorflow_serving/servables/tensorflow/testdata” docker run -t — rm -p 8501:8501 -p 8500:8500 \ -v “$TESTDATA/saved_model_half_plus_two_cpu:/models/half_plus_two” \ -e MODEL_NAME=half_plus_two \ tensorflow/serving &
Протестируйте с помощью REST API
curl -d ‘{“instances”: [1.0, 2.0, 5.0]}’ \ -X POST http://localhost:8501/v1/models/half_plus_two:predict
Протестируйте с помощью gRPC
Сохраните приведенный ниже скрипт Python как predict.py
.
from __future__ import print_function import argparse import time import grpc from tensorflow import make_tensor_proto from tensorflow_serving.apis import predict_pb2 from tensorflow_serving.apis import prediction_service_pb2_grpc def run(host, port, model, signature_name): channel = grpc.insecure_channel(‘{host}:{port}’.format(host=host, port=port)) stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) start = time.time() # Call classification model to make prediction print(“make request”) request = predict_pb2.PredictRequest() request.model_spec.name = model request.model_spec.signature_name = signature_name request.inputs[‘x’].CopyFrom(make_tensor_proto([2.5, 3.0, 4.5], shape=[1,3])) # inputs variable name ‘x’ is defined inside the model file print(“predict”) result = stub.Predict(request, 10.0) end = time.time() time_diff = end — start # Reference: # How to access nested values # https://stackoverflow.com/questions/44785847/how-to-retrieve-float-val-from-a-predictresponse-object print(result) print(‘time elapased: {}’.format(time_diff)) if __name__ == ‘__main__’: parser = argparse.ArgumentParser() parser.add_argument(‘ — host’, help=’Tensorflow server host name’, default=’localhost’, type=str) parser.add_argument(‘ — port’, help=’Tensorflow server port number’, default=8500, type=int) parser.add_argument(‘ — model’, help=’model name’, type=str) parser.add_argument(‘ — signature_name’, help=’Signature name of saved TF model’, default=’serving_default’, type=str) args = parser.parse_args() run(args.host, args.port, args.model, args.signature_name)
Запустите код
python predict.py — model "half_plus_two" — host "localhost"
Ссылка
https://www.tensorflow.org/tfx/serving/docker
https://github.com/yu-iskw/tensorflow-serving-example/blob/68dc20c9626e479a1c925433477a7dbe85b8162f/python/grpc_iris_client.py