diff --git a/src/requests/models.py b/src/requests/models.py index 8f56ca7d23..d90da6d4ae 100644 --- a/src/requests/models.py +++ b/src/requests/models.py @@ -65,6 +65,9 @@ super_len, to_key_val_list, ) +from requests.exceptions import ChunkedEncodingError, ConnectionError, ContentDecodingError, SSLError as RequestsSSLError, StreamConsumedError +from requests.cookies import cookiejar_from_dict +from requests.structures import CaseInsensitiveDict #: The set of HTTP status codes that indicate an automatically #: processable redirect. @@ -812,46 +815,36 @@ def iter_content(self, chunk_size=1, decode_unicode=False): If decode_unicode is True, content will be decoded using the best available encoding based on the response. """ - - def generate(): - # Special case for urllib3. - if hasattr(self.raw, "stream"): + if self._content_consumed and isinstance(self._content, bool): + raise StreamConsumedError() + if chunk_size is not None and not isinstance(chunk_size, int): + raise TypeError( + f"chunk_size must be an int, it is instead a {type(chunk_size)}." + ) + + def generate(raw, chunk_size): + if hasattr(raw, "stream"): try: - yield from self.raw.stream(chunk_size, decode_content=True) - except ProtocolError as e: - raise ChunkedEncodingError(e) - except DecodeError as e: - raise ContentDecodingError(e) - except ReadTimeoutError as e: - raise ConnectionError(e) - except SSLError as e: - raise RequestsSSLError(e) + yield from raw.stream(chunk_size, decode_content=True) + except (ProtocolError, DecodeError, ReadTimeoutError, SSLError) as e: + if isinstance(e, ProtocolError): + raise ChunkedEncodingError(e) + elif isinstance(e, DecodeError): + raise ContentDecodingError(e) + elif isinstance(e, ReadTimeoutError): + raise ConnectionError(e) + elif isinstance(e, SSLError): + raise RequestsSSLError(e) else: - # Standard file-like object. while True: - chunk = self.raw.read(chunk_size) + chunk = raw.read(chunk_size) if not chunk: break yield chunk - self._content_consumed = True - - if self._content_consumed and isinstance(self._content, bool): - raise StreamConsumedError() - elif chunk_size is not None and not isinstance(chunk_size, int): - raise TypeError( - f"chunk_size must be an int, it is instead a {type(chunk_size)}." - ) - # simulate reading small chunks of the content - reused_chunks = iter_slices(self._content, chunk_size) - - stream_chunks = generate() - - chunks = reused_chunks if self._content_consumed else stream_chunks - - if decode_unicode: - chunks = stream_decode_response_unicode(chunks, self) - + + reused_chunks = iter_slices(self._content, chunk_size) if self._content_consumed else generate(self.raw, chunk_size) + chunks = stream_decode_response_unicode(reused_chunks, self) if decode_unicode else reused_chunks return chunks def iter_lines( @@ -1035,3 +1028,7 @@ def close(self): release_conn = getattr(self.raw, "release_conn", None) if release_conn is not None: release_conn() + + def close(self): + if self.raw and hasattr(self.raw, "close"): + self.raw.close() diff --git a/src/requests/utils.py b/src/requests/utils.py index 699683e5d9..8409166cef 100644 --- a/src/requests/utils.py +++ b/src/requests/utils.py @@ -58,6 +58,8 @@ UnrewindableBodyError, ) from .structures import CaseInsensitiveDict +from requests.cookies import cookiejar_from_dict +from requests.structures import CaseInsensitiveDict NETRC_FILES = (".netrc", "_netrc") @@ -566,12 +568,12 @@ def get_encoding_from_headers(headers): def stream_decode_response_unicode(iterator, r): """Stream decodes an iterator.""" - - if r.encoding is None: + encoding = r.encoding + if encoding is None: yield from iterator return - decoder = codecs.getincrementaldecoder(r.encoding)(errors="replace") + decoder = codecs.getincrementaldecoder(encoding)(errors="replace") for chunk in iterator: rv = decoder.decode(chunk) if rv: @@ -583,12 +585,9 @@ def stream_decode_response_unicode(iterator, r): def iter_slices(string, slice_length): """Iterate over slices of a string.""" - pos = 0 if slice_length is None or slice_length <= 0: slice_length = len(string) - while pos < len(string): - yield string[pos : pos + slice_length] - pos += slice_length + return (string[pos:pos + slice_length] for pos in range(0, len(string), slice_length)) def get_unicode_from_response(r):