Skip to content

Commit

Permalink
Add import fallbacks to play nice if Tensorflow already loaded
Browse files Browse the repository at this point in the history
Protobufs register themselves to a global namespace. Because of this, we
need to be defensive in loading the protobufs. Change adds try/excepts
for the importing of the protobufs, allowing the library to fall back on
the main tensorflow protobufs. Downside is that importing tensorflow
after this library will still throw the protobuf registration error
  • Loading branch information
quantumfusion committed Feb 24, 2021
1 parent 3b10879 commit 4847f16
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 13 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ tensor_serving_client/tensorflow_serving
tensor_serving_client/min_tfs_client/tensorflow
tensor_serving_client/min_tfs_client/tensorflow_serving

# VIM
*~
*.swp
*.swo
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from contextlib import contextmanager
from distutils.cmd import Command
from itertools import chain
from pathlib import Path
from shutil import copy2, rmtree
from subprocess import check_output
Expand Down Expand Up @@ -94,7 +95,7 @@ def run(self):
if os.path.isdir('min_tfs_client/tensorflow_serving'):
rmtree('min_tfs_client/tensorflow_serving')
os.rename('tensorflow_serving', 'min_tfs_client/tensorflow_serving')
for file_path in OUTPUT_PATH.rglob('*.py'):
for file_path in chain((OUTPUT_PATH / 'min_tfs_client/tensorflow').rglob('*.py'), (OUTPUT_PATH / 'min_tfs_client/tensorflow_serving').rglob('*.py')):
filename = str(file_path)
with open(filename, 'r') as f:
new_text = f.read().replace('from tensorflow', 'from min_tfs_client.tensorflow')
Expand Down
5 changes: 4 additions & 1 deletion tensor_serving_client/min_tfs_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

import numpy as np

from tensorflow.core.framework import types_pb2
try:
from min_tfs_client.tensorflow.core.framework import types_pb2
except TypeError: # protobuf registration errors
from tensorflow.core.framework import types_pb2


class TFType(NamedTuple):
Expand Down
29 changes: 20 additions & 9 deletions tensor_serving_client/min_tfs_client/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,26 @@
import grpc
import numpy as np

from tensorflow_serving.apis.classification_pb2 import ClassificationRequest, ClassificationResponse
from tensorflow_serving.apis.predict_pb2 import PredictRequest, PredictResponse
from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub
from tensorflow_serving.apis.regression_pb2 import RegressionRequest, RegressionResponse
from tensorflow_serving.apis.get_model_status_pb2 import (
GetModelStatusRequest,
GetModelStatusResponse,
)
from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub
try:
from min_tfs_client.tensorflow_serving.apis.classification_pb2 import ClassificationRequest, ClassificationResponse
from min_tfs_client.tensorflow_serving.apis.predict_pb2 import PredictRequest, PredictResponse
from min_tfs_client.tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub
from min_tfs_client.tensorflow_serving.apis.regression_pb2 import RegressionRequest, RegressionResponse
from min_tfs_client.tensorflow_serving.apis.get_model_status_pb2 import (
GetModelStatusRequest,
GetModelStatusResponse,
)
from min_tfs_client.tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub
except TypeError: # protobuf registration errors
from tensorflow_serving.apis.classification_pb2 import ClassificationRequest, ClassificationResponse
from tensorflow_serving.apis.predict_pb2 import PredictRequest, PredictResponse
from tensorflow_serving.apis.prediction_service_pb2_grpc import PredictionServiceStub
from tensorflow_serving.apis.regression_pb2 import RegressionRequest, RegressionResponse
from tensorflow_serving.apis.get_model_status_pb2 import (
GetModelStatusRequest,
GetModelStatusResponse,
)
from tensorflow_serving.apis.model_service_pb2_grpc import ModelServiceStub

from .tensors import ndarray_to_tensor_proto

Expand Down
8 changes: 6 additions & 2 deletions tensor_serving_client/min_tfs_client/tensors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from .types import DataType
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
try:
from min_tfs_client.tensorflow.core.framework.tensor_pb2 import TensorProto
from min_tfs_client.tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
except TypeError: # protobuf registration errors
from tensorflow.core.framework.tensor_pb2 import TensorProto
from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
import numpy as np


Expand Down

0 comments on commit 4847f16

Please sign in to comment.