diff --git a/impala/_thrift_api.py b/impala/_thrift_api.py index 38cafdaf8..697f71f4c 100644 --- a/impala/_thrift_api.py +++ b/impala/_thrift_api.py @@ -68,12 +68,14 @@ class ImpalaHttpClient(TTransportBase): MIN_REQUEST_SIZE_FOR_EXPECT = 1024 def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None, - key_file=None, ssl_context=None, http_cookie_names=None): + key_file=None, ssl_context=None, http_cookie_names=None, + get_user_custom_headers_func=None): """ImpalaHttpClient supports two different types of construction: ImpalaHttpClient(host, port, path) - deprecated ImpalaHttpClient(uri, [port=, path=, cafile=, cert_file=, - key_file=, ssl_context=, http_cookie_names=]) + key_file=, ssl_context=, http_cookie_names=], + get_user_custom_headers_func=) Only the second supports https. To properly authenticate against the server, provide the client's identity by specifying cert_file and key_file. To properly @@ -85,6 +87,10 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non one of these names is returned in an http response by the server or an intermediate proxy then it will be included in each subsequent request for the same connection. If it is set as wildcards, all cookies in an http response will be preserved. + The optional get_user_custom_headers_func parameter can be used to add http headers + to outgoing http messages when using hs2-http protocol. The parameter should be a + function returning a list of tuples, each tuple containing a key-value pair + representing the header name and value. """ if port is not None: warnings.warn( @@ -158,6 +164,12 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non # new request. self.__custom_headers = None self.__get_custom_headers_func = None + # __user_custom_headers is a list of tuples, each tuple contains a key-value pair. + self.__user_custom_headers = None + if get_user_custom_headers_func: + self.__get_user_custom_headers_func = get_user_custom_headers_func + else: + self.__get_user_custom_headers_func = None # the default user agent if none is provied self.__custom_user_agent = 'Python/ImpylaHttpClient' @@ -220,12 +232,24 @@ def setCustomUserAgent(self, user_agent): def setGetCustomHeadersFunc(self, func): self.__get_custom_headers_func = func - # Update HTTP headers based on the saved cookies and auth mechanism. + # Set the user-defined callback function which adds custom HTTP headers to outgoing + # messages. + def setGetUserDefinedCustomHeadersFunc(self, func): + self.__get_user_custom_headers_func = func + + # Update outgoing HTTP headers. + # This is done by two callback functions, if present + # __get_custom_headers_func adds headers based on the saved cookies and auth + # mechanism. + # __get_user_custom_headers_func adds custom user-supplied http headers. def refreshCustomHeaders(self): if self.__get_custom_headers_func: cookie_header, has_auth_cookie = self.getHttpCookieHeaderForRequest() self.__custom_headers = \ self.__get_custom_headers_func(cookie_header, has_auth_cookie) + if self.__get_user_custom_headers_func: + self.__user_custom_headers = \ + self.__get_user_custom_headers_func() # Return first value as a cookie list for Cookie header. It's a list of name-value # pairs in the form of =. Pairs in the list are separated by @@ -329,6 +353,9 @@ def sendRequestRecvResp(data): if self.__custom_headers: for key, val in six.iteritems(self.__custom_headers): self.__http.putheader(key, val) + if self.__user_custom_headers: + for key, val in self.__user_custom_headers: + self.__http.putheader(key, val) self.__http.endheaders() @@ -393,7 +420,8 @@ def isOpen(self): def get_http_transport(host, port, http_path, timeout=None, use_ssl=False, ca_cert=None, auth_mechanism='NOSASL', user=None, password=None, kerberos_host=None, kerberos_service_name=None, - http_cookie_names=None, jwt=None, user_agent=None): + http_cookie_names=None, jwt=None, user_agent=None, + get_user_custom_headers_func=None): # TODO: support timeout if timeout is not None: log.error('get_http_transport does not support a timeout') @@ -408,12 +436,16 @@ def get_http_transport(host, port, http_path, timeout=None, use_ssl=False, url = 'https://%s:%s/%s' % (host, port, http_path) log.debug('get_http_transport url=%s', url) # TODO(#362): Add server authentication with thrift 0.12. - transport = ImpalaHttpClient(url, ssl_context=ssl_ctx, - http_cookie_names=http_cookie_names) + transport = ImpalaHttpClient( + url, ssl_context=ssl_ctx, + http_cookie_names=http_cookie_names, + get_user_custom_headers_func=get_user_custom_headers_func) else: url = 'http://%s:%s/%s' % (host, port, http_path) log.debug('get_http_transport url=%s', url) - transport = ImpalaHttpClient(url, http_cookie_names=http_cookie_names) + transport = ImpalaHttpClient( + url, http_cookie_names=http_cookie_names, + get_user_custom_headers_func=get_user_custom_headers_func) # set custom user agent if provided by user if user_agent: diff --git a/impala/dbapi.py b/impala/dbapi.py index 8bfa596b9..c39ecaecd 100644 --- a/impala/dbapi.py +++ b/impala/dbapi.py @@ -44,7 +44,8 @@ def connect(host='localhost', port=21050, database=None, timeout=None, ldap_user=None, ldap_password=None, use_kerberos=None, protocol=None, krb_host=None, use_http_transport=False, http_path='', auth_cookie_names=None, http_cookie_names=None, - retries=3, jwt=None, user_agent=None): + retries=3, jwt=None, user_agent=None, + get_user_custom_headers_func=None): """Get a connection to HiveServer2 (HS2). These options are largely compatible with the impala-shell command line @@ -105,6 +106,10 @@ def connect(host='localhost', port=21050, database=None, timeout=None, 'Python/ImpylaHttpClient' is used use_ldap : bool, optional Specify `auth_mechanism='LDAP'` instead. + get_user_custom_headers_func : function, optional + Used to add custom headers to the http messages when using hs2-http protocol. + This is a function returning a list of tuples, each tuple contains a key-value + pair. This allows duplicate headers to be set. .. deprecated:: 0.18.0 auth_cookie_names : list of str or str, optional @@ -203,7 +208,8 @@ def connect(host='localhost', port=21050, database=None, timeout=None, http_path=http_path, http_cookie_names=http_cookie_names, retries=retries, - jwt=jwt, user_agent=user_agent) + jwt=jwt, user_agent=user_agent, + get_user_custom_headers_func=get_user_custom_headers_func) return hs2.HiveServer2Connection(service, default_db=database) diff --git a/impala/hiveserver2.py b/impala/hiveserver2.py index 591691683..7b4fcb021 100644 --- a/impala/hiveserver2.py +++ b/impala/hiveserver2.py @@ -915,7 +915,7 @@ def connect(host, port, timeout=None, use_ssl=False, ca_cert=None, user=None, password=None, kerberos_service_name='impala', auth_mechanism=None, krb_host=None, use_http_transport=False, http_path='', http_cookie_names=None, retries=3, jwt=None, - user_agent=None): + user_agent=None, get_user_custom_headers_func=None): log.debug('Connecting to HiveServer2 %s:%s with %s authentication ' 'mechanism', host, port, auth_mechanism) @@ -930,14 +930,16 @@ def connect(host, port, timeout=None, use_ssl=False, ca_cert=None, raise NotSupportedError("Server authentication is not supported " + "with HTTP endpoints") - transport = get_http_transport(host, port, http_path=http_path, - use_ssl=use_ssl, ca_cert=ca_cert, - auth_mechanism=auth_mechanism, - user=user, password=password, - kerberos_host=kerberos_host, - kerberos_service_name=kerberos_service_name, - http_cookie_names=http_cookie_names, - jwt=jwt, user_agent=user_agent) + transport = get_http_transport( + host, port, http_path=http_path, + use_ssl=use_ssl, ca_cert=ca_cert, + auth_mechanism=auth_mechanism, + user=user, password=password, + kerberos_host=kerberos_host, + kerberos_service_name=kerberos_service_name, + http_cookie_names=http_cookie_names, + jwt=jwt, user_agent=user_agent, + get_user_custom_headers_func=get_user_custom_headers_func) else: sock = get_socket(host, port, use_ssl, ca_cert) diff --git a/impala/tests/test_http_connect.py b/impala/tests/test_http_connect.py index 496b1ac0d..0947b758e 100644 --- a/impala/tests/test_http_connect.py +++ b/impala/tests/test_http_connect.py @@ -17,6 +17,7 @@ from contextlib import closing import pytest +import requests import six from six.moves import SimpleHTTPServer from six.moves import http_client @@ -66,10 +67,83 @@ def __init__(self): yield server # Cleanup after test. - if server.httpd is not None: - server.httpd.shutdown() - if server.http_server_thread is not None: - server.http_server_thread.join() + shutdown_server(server) + + +@pytest.yield_fixture +def http_proxy_server(): + """A fixture that creates a reverse http proxy.""" + server = TestHTTPServerProxy(RequestHandlerProxy) + yield server + + # Cleanup after test. + shutdown_server(server) + +class RequestHandlerProxy(SimpleHTTPServer.SimpleHTTPRequestHandler): + """A custom http handler that acts as a reverse http proxy. This proxy will forward http + messages to Impala, and copy the responses back to the client. In addition, it will save + the outgoing http message headers in a class variable so tha they can be accessed by + test code.""" + + # This class variable is used to store the most recently seen outgoing http + # message headers. + saved_headers=None + + def __init__(self, request, client_address, server): + SimpleHTTPServer.SimpleHTTPRequestHandler.__init__(self, request, client_address, + server) + + def do_POST(self): + # Read the body of the incoming http post message. + data_string = self.rfile.read(int(self.headers['Content-Length'])) + # Save the http headers from the message in a class variable. + RequestHandlerProxy.saved_headers = self.decode_raw_headers() + # Forward the http post message to Impala and get a response message. + response = requests.post(url="http://localhost:28000/cliservice", + headers=self.headers, data=data_string) + # Send the response message back to the client. + self.send_response(code=response.status_code) + # Send the http headers. + # In python3 response.headers is a CaseInsensitiveDict + # In python2 response.headers is a dict + for key, value in response.headers.items(): + self.send_header(keyword=key, value=value) + self.end_headers() + # Send the message body. + self.wfile.write(response.content) + self.wfile.close() + + def decode_raw_headers(self): + """Decode a list of header strings into a list of tuples, each tuple containing a + key-value pair. The details of how to get the headers are differs between Python2 and + Python3""" + if six.PY2: + header_list = [] + # In Python2 self.headers is an instance of mimetools.Message and + # self.headers.headers is a list of raw header strings. + # An example header string: 'Accept-Encoding: identity\\r\\n' + for header in self.headers.headers: + stripped = header.strip() + key, value = stripped.split(':', 1) + header_list.append((key.strip(), value.strip())) + return header_list + if six.PY3: + # In Python 3 self.headers._headers is what we need + return self.headers._headers + + +class TestHTTPServerProxy(object): + def __init__(self, clazz): + self.clazz = clazz + self.HOST = "localhost" + self.PORT = get_unused_port() + self.httpd = socketserver.TCPServer((self.HOST, self.PORT), clazz) + self.http_server_thread = threading.Thread(target=self.httpd.serve_forever) + self.http_server_thread.start() + + def get_headers(self): + """Return the most recently seen outgoing http message headers.""" + return self.clazz.saved_headers from impala.dbapi import connect @@ -93,6 +167,34 @@ def test_http_interactions(self, http_503_server): assert e.code == http_client.SERVICE_UNAVAILABLE assert e.body.decode("utf-8") == "extra text" + def test_duplicate_headers(self, http_proxy_server): + """Test that we can use 'connect' with the get_user_custom_headers_func parameter + to add duplicate http message headers to outgoing messages.""" + con = connect("localhost", http_proxy_server.PORT, use_http_transport=True, + get_user_custom_headers_func=get_user_custom_headers_func) + cur = con.cursor() + cur.execute('select 1') + rows = cur.fetchall() + assert rows == [(1,)] + + # Get the outgoing message headers from the last outgoing http message. + headers = http_proxy_server.get_headers() + # For sanity test the count of a few simple expected headers. + assert count_tuples_with_key(headers, "Host") == 1 + assert count_tuples_with_key(headers, "User-Agent") == 1 + # Check that the custom headers are present. + assert count_tuples_with_key(headers, "key1") == 2 + assert count_tuples_with_key(headers, "key2") == 1 + assert count_tuples_with_key(headers, "key3") == 0 + +def get_user_custom_headers_func(): + """Insert some custom http headers, including a duplicate.""" + headers = [] + headers.append(('key1', 'value1')) + headers.append(('key1', 'value2')) + headers.append(('key2', 'value3')) + return headers + def get_unused_port(): """ Find an unused port http://stackoverflow.com/questions/1365265 """ @@ -100,3 +202,24 @@ def get_unused_port(): s.bind(('', 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] + +def shutdown_server(server): + """Helper method to shutdown a http server.""" + if server.httpd is not None: + server.httpd.shutdown() + if server.http_server_thread is not None: + server.http_server_thread.join() + +def count_tuples_with_key(tuple_list, key_to_count): + """Counts the number of tuples in a list that have a specific key. + Args: + tuple_list: A list of key-value tuples. + key_to_count: The key to count occurrences of. + Returns: + The number of tuples with the specified key. + """ + count = 0 + for key, _ in tuple_list: + if key == key_to_count: + count += 1 + return count