Tensorflow-serving очень удобно использовать для обслуживания модели TF. В этом блоге объясняется, как запустить обслуживание tensorflow на докере и протестировать его с помощью REST/gRPC. Тестирование с помощью REST довольно распространено и легко, однако тестирование с помощью gRPC не так просто, как вы думаете. Также не так много примеров тестирования tensorflow-обслуживания с помощью gRPC. Надеюсь, что это может стать для вас «привет миром» обслуживания tensorflow.

Запустите обслуживание tensorflow на докере

docker pull tensorflow/serving
git clone https://github.com/tensorflow/serving
TESTDATA=”$(pwd)/serving/tensorflow_serving/servables/tensorflow/testdata”
docker run -t — rm -p 8501:8501 -p 8500:8500 \
 -v “$TESTDATA/saved_model_half_plus_two_cpu:/models/half_plus_two” \
 -e MODEL_NAME=half_plus_two \
 tensorflow/serving &

Протестируйте с помощью REST API

curl -d ‘{“instances”: [1.0, 2.0, 5.0]}’ \
 -X POST http://localhost:8501/v1/models/half_plus_two:predict

Протестируйте с помощью gRPC

Сохраните приведенный ниже скрипт Python как predict.py.

from __future__ import print_function
import argparse
import time
import grpc
from tensorflow import make_tensor_proto
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
def run(host, port, model, signature_name):
channel = grpc.insecure_channel(‘{host}:{port}’.format(host=host, port=port))
 stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
start = time.time()
# Call classification model to make prediction
 print(“make request”)
 request = predict_pb2.PredictRequest()
 request.model_spec.name = model
 request.model_spec.signature_name = signature_name
 request.inputs[‘x’].CopyFrom(make_tensor_proto([2.5, 3.0, 4.5], shape=[1,3])) # inputs variable name ‘x’ is defined inside the model file
print(“predict”)
 result = stub.Predict(request, 10.0)
end = time.time()
 time_diff = end — start
# Reference:
 # How to access nested values
 # https://stackoverflow.com/questions/44785847/how-to-retrieve-float-val-from-a-predictresponse-object
 print(result)
 print(‘time elapased: {}’.format(time_diff))
if __name__ == ‘__main__’:
 parser = argparse.ArgumentParser()
 parser.add_argument(‘ — host’, help=’Tensorflow server host name’, default=’localhost’, type=str)
 parser.add_argument(‘ — port’, help=’Tensorflow server port number’, default=8500, type=int)
 parser.add_argument(‘ — model’, help=’model name’, type=str)
 parser.add_argument(‘ — signature_name’, help=’Signature name of saved TF model’,
 default=’serving_default’, type=str)
args = parser.parse_args()
 run(args.host, args.port, args.model, args.signature_name)

Запустите код

python predict.py — model "half_plus_two" — host "localhost"

Ссылка

https://www.tensorflow.org/tfx/serving/docker
https://github.com/yu-iskw/tensorflow-serving-example/blob/68dc20c9626e479a1c925433477a7dbe85b8162f/python/grpc_iris_client.py