Skip to content

Commit

Permalink
Merge pull request #150 from whdalsrnt/master
Browse files Browse the repository at this point in the history
refactor: refactor code for grpc compatability
  • Loading branch information
whdalsrnt authored Jul 14, 2024
2 parents 7dc1980 + 8cfa68e commit 239631e
Showing 1 changed file with 169 additions and 80 deletions.
249 changes: 169 additions & 80 deletions src/spaceone/core/pygrpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import types
import grpc
from google.protobuf.json_format import ParseDict
from google.protobuf.message_factory import MessageFactory
from google.protobuf.message_factory import MessageFactory, GetMessageClass
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.descriptor import ServiceDescriptor, MethodDescriptor
from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ProtoReflectionDescriptorDatabase
from grpc_reflection.v1alpha.proto_reflection_descriptor_database import (
ProtoReflectionDescriptorDatabase,
)
from spaceone.core.error import *

_MAX_RETRIES = 2
Expand All @@ -14,24 +16,32 @@


class _ClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):

grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
def __init__(self, options: dict, channel_key: str, request_map: dict):
self._request_map = request_map
self._channel_key = channel_key
self.metadata = options.get('metadata', {})
self.metadata = options.get("metadata", {})

def _check_message(self, client_call_details, request_or_iterator, is_stream):
if client_call_details.method in self._request_map:
if is_stream:
if not isinstance(request_or_iterator, types.GeneratorType):
raise Exception("Stream method must be specified as a generator type.")
raise Exception(
"Stream method must be specified as a generator type."
)

return self._generate_message(request_or_iterator, client_call_details.method)
return self._generate_message(
request_or_iterator, client_call_details.method
)

else:
return self._make_message(request_or_iterator, client_call_details.method)
return self._make_message(
request_or_iterator, client_call_details.method
)

return request_or_iterator

Expand All @@ -50,17 +60,17 @@ def _check_error(self, response):
if isinstance(response, Exception):
details = response.details()
status_code = response.code().name
if details.startswith('ERROR_'):
details_split = details.split(':', 1)
if details.startswith("ERROR_"):
details_split = details.split(":", 1)
if len(details_split) == 2:
error_code, error_message = details_split
else:
error_code = details_split[0]
error_message = details

if status_code == 'PERMISSION_DENIED':
if status_code == "PERMISSION_DENIED":
raise ERROR_PERMISSION_DENIED()
elif status_code == 'UNAUTHENTICATED':
elif status_code == "UNAUTHENTICATED":
raise ERROR_AUTHENTICATE_FAILURE(message=error_message)
else:
e = ERROR_INTERNAL_API(message=error_message)
Expand All @@ -70,13 +80,15 @@ def _check_error(self, response):

else:
error_message = response.details()
if status_code == 'PERMISSION_DENIED':
if status_code == "PERMISSION_DENIED":
raise ERROR_PERMISSION_DENIED()
elif status_code == 'PERMISSION_DENIED':
elif status_code == "PERMISSION_DENIED":
raise ERROR_AUTHENTICATE_FAILURE(message=error_message)
elif status_code == 'UNAVAILABLE':
e = ERROR_GRPC_CONNECTION(channel=self._channel_key, message=error_message)
e.meta['channel'] = self._channel_key
elif status_code == "UNAVAILABLE":
e = ERROR_GRPC_CONNECTION(
channel=self._channel_key, message=error_message
)
e.meta["channel"] = self._channel_key
raise e
else:
e = ERROR_INTERNAL_API(message=error_message)
Expand All @@ -92,12 +104,16 @@ def _generate_response(self, response_iterator):
except Exception as e:
self._check_error(e)

def _retry_call(self, continuation, client_call_details, request_or_iterator, is_stream):
def _retry_call(
self, continuation, client_call_details, request_or_iterator, is_stream
):
retries = 0

while True:
try:
response_or_iterator = continuation(client_call_details, request_or_iterator)
response_or_iterator = continuation(
client_call_details, request_or_iterator
)

if is_stream:
response_or_iterator = self._generate_response(response_or_iterator)
Expand All @@ -107,84 +123,142 @@ def _retry_call(self, continuation, client_call_details, request_or_iterator, is
return response_or_iterator

except Exception as e:
if e.error_code == 'ERROR_GRPC_CONNECTION':
if e.error_code == "ERROR_GRPC_CONNECTION":
if retries >= _MAX_RETRIES:
channel = e.meta.get('channel')
channel = e.meta.get("channel")
if channel in _GRPC_CHANNEL:
_LOGGER.error(f'Disconnect gRPC Endpoint. (channel = {channel})')
_LOGGER.error(
f"Disconnect gRPC Endpoint. (channel = {channel})"
)
del _GRPC_CHANNEL[channel]
raise e
else:
_LOGGER.debug(f'Retry gRPC Call: reason = {e.message}, retry = {retries + 1}')
_LOGGER.debug(
f"Retry gRPC Call: reason = {e.message}, retry = {retries + 1}"
)
else:
raise e

retries += 1

def _intercept_call(self, continuation, client_call_details,
request_or_iterator, is_request_stream, is_response_stream):
new_request_or_iterator = self. _check_message(
client_call_details, request_or_iterator, is_request_stream)

return self._retry_call(continuation, client_call_details,
new_request_or_iterator, is_response_stream)
def _intercept_call(
self,
continuation,
client_call_details,
request_or_iterator,
is_request_stream,
is_response_stream,
):
new_request_or_iterator = self._check_message(
client_call_details, request_or_iterator, is_request_stream
)

return self._retry_call(
continuation,
client_call_details,
new_request_or_iterator,
is_response_stream,
)

def intercept_unary_unary(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request, False, False)
return self._intercept_call(
continuation, client_call_details, request, False, False
)

def intercept_unary_stream(self, continuation, client_call_details, request):
return self._intercept_call(continuation, client_call_details, request, False, True)
return self._intercept_call(
continuation, client_call_details, request, False, True
)

def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
return self._intercept_call(continuation, client_call_details, request_iterator, True, False)
def intercept_stream_unary(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(
continuation, client_call_details, request_iterator, True, False
)

def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
return self._intercept_call(continuation, client_call_details, request_iterator, True, True)
def intercept_stream_stream(
self, continuation, client_call_details, request_iterator
):
return self._intercept_call(
continuation, client_call_details, request_iterator, True, True
)


class _GRPCStub(object):

def __init__(self, desc_pool: DescriptorPool, service_desc: ServiceDescriptor, channel: grpc.Channel):
def __init__(
self,
desc_pool: DescriptorPool,
service_desc: ServiceDescriptor,
channel: grpc.Channel,
):
self._desc_pool = desc_pool
for method_desc in service_desc.methods:
self._bind_grpc_method(service_desc, method_desc, channel)

def _bind_grpc_method(self, service_desc: ServiceDescriptor, method_desc: MethodDescriptor, channel: grpc.Channel):
def _bind_grpc_method(
self,
service_desc: ServiceDescriptor,
method_desc: MethodDescriptor,
channel: grpc.Channel,
):
method_name = method_desc.name
method_key = f'/{service_desc.full_name}/{method_name}'
request_desc = self._desc_pool.FindMessageTypeByName(method_desc.input_type.full_name)
request_message_desc = MessageFactory(self._desc_pool).GetPrototype(request_desc)
response_desc = self._desc_pool.FindMessageTypeByName(method_desc.output_type.full_name)
response_message_desc = MessageFactory(self._desc_pool).GetPrototype(response_desc)
method_key = f"/{service_desc.full_name}/{method_name}"
request_desc = self._desc_pool.FindMessageTypeByName(
method_desc.input_type.full_name
)
# request_message_desc = MessageFactory(self._desc_pool).GetPrototype(request_desc)
request_message_desc = GetMessageClass(request_desc)

response_desc = self._desc_pool.FindMessageTypeByName(
method_desc.output_type.full_name
)
# response_message_desc = MessageFactory(self._desc_pool).GetPrototype(response_desc)
response_message_desc = GetMessageClass(response_desc)

if method_desc.client_streaming and method_desc.server_streaming:
setattr(self, method_name, channel.stream_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.stream_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)
elif method_desc.client_streaming and not method_desc.server_streaming:
setattr(self, method_name, channel.stream_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.stream_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)
elif not method_desc.client_streaming and method_desc.server_streaming:
setattr(self, method_name, channel.unary_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.unary_stream(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)
else:
setattr(self, method_name, channel.unary_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString
))
setattr(
self,
method_name,
channel.unary_unary(
method_key,
request_serializer=request_message_desc.SerializeToString,
response_deserializer=response_message_desc.FromString,
),
)


class GRPCClient(object):

def __init__(self, channel, options, channel_key):
self._request_map = {}
self._api_resources = {}
Expand All @@ -193,7 +267,9 @@ def __init__(self, channel, options, channel_key):
self._desc_pool = DescriptorPool(self._reflection_db)
self._init_grpc_reflection()

_client_interceptor = _ClientInterceptor(options, channel_key, self._request_map)
_client_interceptor = _ClientInterceptor(
options, channel_key, self._request_map
)
_intercept_channel = grpc.intercept_channel(channel, _client_interceptor)
self._bind_grpc_stub(_intercept_channel)

Expand All @@ -206,9 +282,12 @@ def _init_grpc_reflection(self):
service_desc: ServiceDescriptor = self._desc_pool.FindServiceByName(service)
service_name = service_desc.name
for method_desc in service_desc.methods:
method_key = f'/{service}/{method_desc.name}'
request_desc = self._desc_pool.FindMessageTypeByName(method_desc.input_type.full_name)
self._request_map[method_key] = MessageFactory(self._desc_pool).GetPrototype(request_desc)
method_key = f"/{service}/{method_desc.name}"
request_desc = self._desc_pool.FindMessageTypeByName(
method_desc.input_type.full_name
)
# self._request_map[method_key] = MessageFactory(self._desc_pool).GetPrototype(request_desc)
self._request_map[method_key] = GetMessageClass(request_desc)

if service_desc.name not in self._api_resources:
self._api_resources[service_name] = []
Expand All @@ -219,7 +298,11 @@ def _bind_grpc_stub(self, intercept_channel: grpc.Channel):
for service in self._reflection_db.get_services():
service_desc: ServiceDescriptor = self._desc_pool.FindServiceByName(service)

setattr(self, service_desc.name, _GRPCStub(self._desc_pool, service_desc, intercept_channel))
setattr(
self,
service_desc.name,
_GRPCStub(self._desc_pool, service_desc, intercept_channel),
)


def _create_secure_channel(endpoint, options):
Expand All @@ -245,8 +328,8 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o
options = []

if max_message_length:
options.append(('grpc.max_send_message_length', max_message_length))
options.append(('grpc.max_receive_message_length', max_message_length))
options.append(("grpc.max_send_message_length", max_message_length))
options.append(("grpc.max_receive_message_length", max_message_length))

if ssl_enabled:
channel = _create_secure_channel(endpoint, options)
Expand All @@ -256,12 +339,14 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o
try:
grpc.channel_ready_future(channel).result(timeout=3)
except Exception as e:
raise ERROR_GRPC_CONNECTION(channel=endpoint, message='Channel is not ready.')
raise ERROR_GRPC_CONNECTION(
channel=endpoint, message="Channel is not ready."
)

try:
_GRPC_CHANNEL[endpoint] = GRPCClient(channel, client_opts, endpoint)
except Exception as e:
if hasattr(e, 'details'):
if hasattr(e, "details"):
raise ERROR_GRPC_CONNECTION(channel=endpoint, message=e.details())
else:
raise ERROR_GRPC_CONNECTION(channel=endpoint, message=str(e))
Expand All @@ -271,12 +356,16 @@ def client(endpoint=None, ssl_enabled=False, max_message_length=None, **client_o

def get_grpc_method(uri_info):
try:
conn = client(endpoint=uri_info['endpoint'], ssl_enabled=uri_info['ssl_enabled'])
return getattr(getattr(conn, uri_info['service']), uri_info['method'])
conn = client(
endpoint=uri_info["endpoint"], ssl_enabled=uri_info["ssl_enabled"]
)
return getattr(getattr(conn, uri_info["service"]), uri_info["method"])

except ERROR_BASE as e:
raise e
except Exception as e:
raise ERROR_GRPC_CONFIGURATION(endpoint=uri_info.get('endpoint'),
service=uri_info.get('service'),
method=uri_info.get('method'))
raise ERROR_GRPC_CONFIGURATION(
endpoint=uri_info.get("endpoint"),
service=uri_info.get("service"),
method=uri_info.get("method"),
)

0 comments on commit 239631e

Please sign in to comment.