diff --git a/smart_open/http.py b/smart_open/http.py
index 7bbbe6f4..438ae0f4 100644
--- a/smart_open/http.py
+++ b/smart_open/http.py
@@ -50,7 +50,7 @@ def open_uri(uri, mode, transport_params):
 
 
 def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
-         headers=None, timeout=None, buffer_size=DEFAULT_BUFFER_SIZE):
+         headers=None, timeout=None, session=None, buffer_size=DEFAULT_BUFFER_SIZE):
     """Implement streamed reader from a web site.
 
     Supports Kerberos and Basic HTTP authentication.
@@ -73,6 +73,9 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
         Any headers to send in the request. If ``None``, the default headers are sent:
         ``{'Accept-Encoding': 'identity'}``. To use no headers at all,
         set this variable to an empty dict, ``{}``.
+    session: requests.Session, optional
+        The session to issue HTTP GET requests with, e.g. an OAuth2 client session.
+        If ``None``, the module-level ``requests`` API is used.
     buffer_size: int, optional
         The buffer size to use when performing I/O.
 
@@ -86,7 +89,7 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
         fobj = SeekableBufferedInputBase(
             uri, mode, buffer_size=buffer_size, kerberos=kerberos,
             user=user, password=password, cert=cert,
-            headers=headers, timeout=timeout,
+            headers=headers, session=session, timeout=timeout,
         )
         fobj.name = os.path.basename(urllib.parse.urlparse(uri).path)
         return fobj
@@ -97,7 +100,10 @@ def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
 class BufferedInputBase(io.BufferedIOBase):
     def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
                  kerberos=False, user=None, password=None, cert=None,
-                 headers=None, timeout=None):
+                 headers=None, session=None, timeout=None):
+
+        self.session = session or requests
+
         if kerberos:
             import requests_kerberos
             auth = requests_kerberos.HTTPKerberosAuth()
@@ -116,7 +122,7 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
 
         self.timeout = timeout
 
-        self.response = requests.get(
+        self.response = self.session.get(
             url,
             auth=auth,
             cert=cert,
@@ -217,7 +223,7 @@ class SeekableBufferedInputBase(BufferedInputBase):
 
     def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
                  kerberos=False, user=None, password=None, cert=None,
-                 headers=None, timeout=None):
+                 headers=None, session=None, timeout=None):
         """
         If Kerberos is True, will attempt to use the local Kerberos credentials.
         If cert is set, will try to use a client certificate
@@ -227,6 +233,8 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
         """
         self.url = url
 
+        self.session = session or requests
+
         if kerberos:
             import requests_kerberos
             self.auth = requests_kerberos.HTTPKerberosAuth()
@@ -332,7 +340,7 @@ def _partial_request(self, start_pos=None):
         if start_pos is not None:
             self.headers.update({"range": smart_open.utils.make_range_string(start_pos)})
 
-        response = requests.get(
+        response = self.session.get(
             self.url,
             auth=self.auth,
             stream=True,
diff --git a/smart_open/tests/test_http.py b/smart_open/tests/test_http.py
index 9433043f..a70f86d7 100644
--- a/smart_open/tests/test_http.py
+++ b/smart_open/tests/test_http.py
@@ -15,7 +15,7 @@
 import smart_open.http
 import smart_open.s3
 import smart_open.constants
-
+import requests
 
 BYTES = b'i tried so hard and got so far but in the end it doesn\'t even matter'
 URL = 'http://localhost'
@@ -159,6 +159,15 @@ def test_timeout_attribute(self):
         assert hasattr(reader, 'timeout')
         assert reader.timeout == timeout
 
+    @responses.activate
+    def test_session_attribute(self):
+        session = requests.Session()
+        responses.add_callback(responses.GET, URL, callback=request_callback)
+        reader = smart_open.open(URL, "rb", transport_params={'session': session})
+        assert hasattr(reader, 'session')
+        assert reader.session == session
+        assert reader.read() == BYTES
+
 
 @responses.activate
 def test_seek_implicitly_enabled(numbytes=10):