Skip to content

Commit

Permalink
Allow users to add custom http headers when using hs2-http (#557)
Browse files Browse the repository at this point in the history
In a modern Impala deployment hs2-http protocol is used in a system
where http messages pass through one or more http proxies. Some of
these proxies add their own http message headers to messages as they
are forwarded. It would be useful to test Impala with some of the
message headers that are added by http proxies. In particular the case
where there are multiple http headers with the same name is hard to
simulate with clients such as Impyla or Impala Shell. This is partly
because these clients store http headers in a Python dict which does
not allow duplicate keys.

Extend the Impyla connect() method to add
a 'get_user_custom_headers_func' parameter. This specifies a function
that is called as http message headers are being written. The function
should return a list of tuples, each tuple containing a key-value pair.
This allows duplicate headers to be set on outgoing messages.

TESTING
Add test code which implements a reverse http proxy, which allows test
code to access the outgoing http message headers generated by Impyla.
Add a test using this proxy which validates the new feature.

The new test code requires a new python package 'requests'. I think
there is not away to add this requirement automatically so I added a
note to README.md

All tests pass on Python2 and Python3.

Fix TestHS2FaultInjection to  use setup_method() and teardown_method()
so as to work in Python3
  • Loading branch information
bartash committed Nov 5, 2024
1 parent e4c7616 commit 1141dde
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 25 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Optional:

* `sqlalchemy` for the SQLAlchemy engine

* `pytest` for running tests; `unittest2` for testing on Python 2.6
* `pytest` and `requests` for running tests; `unittest2` for testing on Python 2.6


#### System Kerberos
Expand Down
46 changes: 39 additions & 7 deletions impala/_thrift_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,14 @@ class ImpalaHttpClient(TTransportBase):
MIN_REQUEST_SIZE_FOR_EXPECT = 1024

def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=None,
key_file=None, ssl_context=None, http_cookie_names=None):
key_file=None, ssl_context=None, http_cookie_names=None,
get_user_custom_headers_func=None):
"""ImpalaHttpClient supports two different types of construction:
ImpalaHttpClient(host, port, path) - deprecated
ImpalaHttpClient(uri, [port=<n>, path=<s>, cafile=<filename>, cert_file=<filename>,
key_file=<filename>, ssl_context=<context>, http_cookie_names=<cookienamelist>])
key_file=<filename>, ssl_context=<context>, http_cookie_names=<cookienamelist>],
get_user_custom_headers_func=<function_setting_http_headers>)
Only the second supports https. To properly authenticate against the server,
provide the client's identity by specifying cert_file and key_file. To properly
Expand All @@ -85,6 +87,10 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non
one of these names is returned in an http response by the server or an intermediate
proxy then it will be included in each subsequent request for the same connection. If
it is set as wildcards, all cookies in an http response will be preserved.
The optional get_user_custom_headers_func parameter can be used to add http headers
to outgoing http messages when using hs2-http protocol. The parameter should be a
function returning a list of tuples, each tuple containing a key-value pair
representing the header name and value.
"""
if port is not None:
warnings.warn(
Expand Down Expand Up @@ -158,6 +164,12 @@ def __init__(self, uri_or_host, port=None, path=None, cafile=None, cert_file=Non
# new request.
self.__custom_headers = None
self.__get_custom_headers_func = None
# __user_custom_headers is a list of tuples, each tuple contains a key-value pair.
self.__user_custom_headers = None
if get_user_custom_headers_func:
self.__get_user_custom_headers_func = get_user_custom_headers_func
else:
self.__get_user_custom_headers_func = None
# the default user agent if none is provied
self.__custom_user_agent = 'Python/ImpylaHttpClient'

Expand Down Expand Up @@ -220,12 +232,24 @@ def setCustomUserAgent(self, user_agent):
def setGetCustomHeadersFunc(self, func):
self.__get_custom_headers_func = func

# Update HTTP headers based on the saved cookies and auth mechanism.
# Set the user-defined callback function which adds custom HTTP headers to outgoing
# messages.
def setGetUserDefinedCustomHeadersFunc(self, func):
self.__get_user_custom_headers_func = func

# Update outgoing HTTP headers.
# This is done by two callback functions, if present
# __get_custom_headers_func adds headers based on the saved cookies and auth
# mechanism.
# __get_user_custom_headers_func adds custom user-supplied http headers.
def refreshCustomHeaders(self):
if self.__get_custom_headers_func:
cookie_header, has_auth_cookie = self.getHttpCookieHeaderForRequest()
self.__custom_headers = \
self.__get_custom_headers_func(cookie_header, has_auth_cookie)
if self.__get_user_custom_headers_func:
self.__user_custom_headers = \
self.__get_user_custom_headers_func()

# Return first value as a cookie list for Cookie header. It's a list of name-value
# pairs in the form of <cookie-name>=<cookie-value>. Pairs in the list are separated by
Expand Down Expand Up @@ -329,6 +353,9 @@ def sendRequestRecvResp(data):
if self.__custom_headers:
for key, val in six.iteritems(self.__custom_headers):
self.__http.putheader(key, val)
if self.__user_custom_headers:
for key, val in self.__user_custom_headers:
self.__http.putheader(key, val)

self.__http.endheaders()

Expand Down Expand Up @@ -393,7 +420,8 @@ def isOpen(self):
def get_http_transport(host, port, http_path, timeout=None, use_ssl=False,
ca_cert=None, auth_mechanism='NOSASL', user=None,
password=None, kerberos_host=None, kerberos_service_name=None,
http_cookie_names=None, jwt=None, user_agent=None):
http_cookie_names=None, jwt=None, user_agent=None,
get_user_custom_headers_func=None):
# TODO: support timeout
if timeout is not None:
log.error('get_http_transport does not support a timeout')
Expand All @@ -408,12 +436,16 @@ def get_http_transport(host, port, http_path, timeout=None, use_ssl=False,
url = 'https://%s:%s/%s' % (host, port, http_path)
log.debug('get_http_transport url=%s', url)
# TODO(#362): Add server authentication with thrift 0.12.
transport = ImpalaHttpClient(url, ssl_context=ssl_ctx,
http_cookie_names=http_cookie_names)
transport = ImpalaHttpClient(
url, ssl_context=ssl_ctx,
http_cookie_names=http_cookie_names,
get_user_custom_headers_func=get_user_custom_headers_func)
else:
url = 'http://%s:%s/%s' % (host, port, http_path)
log.debug('get_http_transport url=%s', url)
transport = ImpalaHttpClient(url, http_cookie_names=http_cookie_names)
transport = ImpalaHttpClient(
url, http_cookie_names=http_cookie_names,
get_user_custom_headers_func=get_user_custom_headers_func)

# set custom user agent if provided by user
if user_agent:
Expand Down
10 changes: 8 additions & 2 deletions impala/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def connect(host='localhost', port=21050, database=None, timeout=None,
ldap_user=None, ldap_password=None, use_kerberos=None,
protocol=None, krb_host=None, use_http_transport=False,
http_path='', auth_cookie_names=None, http_cookie_names=None,
retries=3, jwt=None, user_agent=None):
retries=3, jwt=None, user_agent=None,
get_user_custom_headers_func=None):
"""Get a connection to HiveServer2 (HS2).
These options are largely compatible with the impala-shell command line
Expand Down Expand Up @@ -105,6 +106,10 @@ def connect(host='localhost', port=21050, database=None, timeout=None,
'Python/ImpylaHttpClient' is used
use_ldap : bool, optional
Specify `auth_mechanism='LDAP'` instead.
get_user_custom_headers_func : function, optional
Used to add custom headers to the http messages when using hs2-http protocol.
This is a function returning a list of tuples, each tuple contains a key-value
pair. This allows duplicate headers to be set.
.. deprecated:: 0.18.0
auth_cookie_names : list of str or str, optional
Expand Down Expand Up @@ -203,7 +208,8 @@ def connect(host='localhost', port=21050, database=None, timeout=None,
http_path=http_path,
http_cookie_names=http_cookie_names,
retries=retries,
jwt=jwt, user_agent=user_agent)
jwt=jwt, user_agent=user_agent,
get_user_custom_headers_func=get_user_custom_headers_func)
return hs2.HiveServer2Connection(service, default_db=database)


Expand Down
20 changes: 11 additions & 9 deletions impala/hiveserver2.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ def connect(host, port, timeout=None, use_ssl=False, ca_cert=None,
user=None, password=None, kerberos_service_name='impala',
auth_mechanism=None, krb_host=None, use_http_transport=False,
http_path='', http_cookie_names=None, retries=3, jwt=None,
user_agent=None):
user_agent=None, get_user_custom_headers_func=None):
log.debug('Connecting to HiveServer2 %s:%s with %s authentication '
'mechanism', host, port, auth_mechanism)

Expand All @@ -930,14 +930,16 @@ def connect(host, port, timeout=None, use_ssl=False, ca_cert=None,
raise NotSupportedError("Server authentication is not supported " +
"with HTTP endpoints")

transport = get_http_transport(host, port, http_path=http_path,
use_ssl=use_ssl, ca_cert=ca_cert,
auth_mechanism=auth_mechanism,
user=user, password=password,
kerberos_host=kerberos_host,
kerberos_service_name=kerberos_service_name,
http_cookie_names=http_cookie_names,
jwt=jwt, user_agent=user_agent)
transport = get_http_transport(
host, port, http_path=http_path,
use_ssl=use_ssl, ca_cert=ca_cert,
auth_mechanism=auth_mechanism,
user=user, password=password,
kerberos_host=kerberos_host,
kerberos_service_name=kerberos_service_name,
http_cookie_names=http_cookie_names,
jwt=jwt, user_agent=user_agent,
get_user_custom_headers_func=get_user_custom_headers_func)
else:
sock = get_socket(host, port, use_ssl, ca_cert)

Expand Down
4 changes: 2 additions & 2 deletions impala/tests/test_hs2_fault_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ def _read(self, sz):
class TestHS2FaultInjection(object):
"""Class for testing the http fault injection in various rpcs used by Impyla"""

def setup(self):
def setup_method(self):
url = 'http://%s:%s/%s' % (ENV.host, ENV.http_port, "cliservice")
self.transport = FaultInjectingHttpClient(url)
self.configuration = {'idle_session_timeout': '30'}

def teardown(self):
def teardown_method(self):
self.transport.disable_fault()

def connect(self):
Expand Down
133 changes: 129 additions & 4 deletions impala/tests/test_http_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from contextlib import closing

import pytest
import requests
import six
from six.moves import SimpleHTTPServer
from six.moves import http_client
Expand Down Expand Up @@ -66,10 +67,85 @@ def __init__(self):
yield server

# Cleanup after test.
if server.httpd is not None:
server.httpd.shutdown()
if server.http_server_thread is not None:
server.http_server_thread.join()
shutdown_server(server)


@pytest.yield_fixture
def http_proxy_server():
"""A fixture that creates a reverse http proxy."""

class RequestHandlerProxy(SimpleHTTPServer.SimpleHTTPRequestHandler):
"""A custom http handler that acts as a reverse http proxy. This proxy will forward
http messages to Impala, and copy the responses back to the client. In addition, it
will save the outgoing http message headers in a class variable so that they can be
accessed by test code."""

# This class variable is used to store the most recently seen outgoing http
# message headers.
saved_headers=None

def __init__(self, request, client_address, server):
SimpleHTTPServer.SimpleHTTPRequestHandler.__init__(self, request, client_address,
server)

def do_POST(self):
# Read the body of the incoming http post message.
data_string = self.rfile.read(int(self.headers['Content-Length']))
# Save the http headers from the message in a class variable.
RequestHandlerProxy.saved_headers = self.decode_raw_headers()
# Forward the http post message to Impala and get a response message.
response = requests.post(url="http://localhost:28000/cliservice",
headers=self.headers, data=data_string)
# Send the response message back to the client.
self.send_response(code=response.status_code)
# Send the http headers.
# In python3 response.headers is a CaseInsensitiveDict
# In python2 response.headers is a dict
for key, value in response.headers.items():
self.send_header(keyword=key, value=value)
self.end_headers()
# Send the message body.
self.wfile.write(response.content)
self.wfile.close()

def decode_raw_headers(self):
"""Decode a list of header strings into a list of tuples, each tuple containing a
key-value pair. The details of how to get the headers are differs between Python2
and Python3"""
if six.PY2:
header_list = []
# In Python2 self.headers is an instance of mimetools.Message and
# self.headers.headers is a list of raw header strings.
# An example header string: 'Accept-Encoding: identity\\r\\n'
for header in self.headers.headers:
stripped = header.strip()
key, value = stripped.split(':', 1)
header_list.append((key.strip(), value.strip()))
return header_list
if six.PY3:
# In Python 3 self.headers._headers is what we need
return self.headers._headers


class TestHTTPServerProxy(object):
def __init__(self, clazz):
self.clazz = clazz
self.HOST = "localhost"
self.PORT = get_unused_port()
self.httpd = socketserver.TCPServer((self.HOST, self.PORT), clazz)
self.http_server_thread = threading.Thread(target=self.httpd.serve_forever)
self.http_server_thread.start()

def get_headers(self):
"""Return the most recently seen outgoing http message headers."""
return self.clazz.saved_headers

server = TestHTTPServerProxy(RequestHandlerProxy)
yield server

# Cleanup after test.
shutdown_server(server)


from impala.dbapi import connect

Expand All @@ -93,10 +169,59 @@ def test_http_interactions(self, http_503_server):
assert e.code == http_client.SERVICE_UNAVAILABLE
assert e.body.decode("utf-8") == "extra text"

def test_duplicate_headers(self, http_proxy_server):
"""Test that we can use 'connect' with the get_user_custom_headers_func parameter
to add duplicate http message headers to outgoing messages."""
con = connect("localhost", http_proxy_server.PORT, use_http_transport=True,
get_user_custom_headers_func=get_user_custom_headers_func)
cur = con.cursor()
cur.execute('select 1')
rows = cur.fetchall()
assert rows == [(1,)]

# Get the outgoing message headers from the last outgoing http message.
headers = http_proxy_server.get_headers()
# For sanity test the count of a few simple expected headers.
assert count_tuples_with_key(headers, "Host") == 1
assert count_tuples_with_key(headers, "User-Agent") == 1
# Check that the custom headers are present.
assert count_tuples_with_key(headers, "key1") == 2
assert count_tuples_with_key(headers, "key2") == 1
assert count_tuples_with_key(headers, "key3") == 0

def get_user_custom_headers_func():
"""Insert some custom http headers, including a duplicate."""
headers = []
headers.append(('key1', 'value1'))
headers.append(('key1', 'value2'))
headers.append(('key2', 'value3'))
return headers


def get_unused_port():
""" Find an unused port http://stackoverflow.com/questions/1365265 """
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]

def shutdown_server(server):
"""Helper method to shutdown a http server."""
if server.httpd is not None:
server.httpd.shutdown()
if server.http_server_thread is not None:
server.http_server_thread.join()

def count_tuples_with_key(tuple_list, key_to_count):
"""Counts the number of tuples in a list that have a specific key.
Args:
tuple_list: A list of key-value tuples.
key_to_count: The key to count occurrences of.
Returns:
The number of tuples with the specified key.
"""
count = 0
for key, _ in tuple_list:
if key == key_to_count:
count += 1
return count

0 comments on commit 1141dde

Please sign in to comment.