Skip to content

Commit

Permalink
use initial response for encoding, size and ranges support check (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Jun 23, 2021
1 parent 55bf1a2 commit 9514d4a
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 23 deletions.
39 changes: 21 additions & 18 deletions src/webdav4/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ def __init__(
self.url = url
self._loc: int = 0
self._cm = iter_url(client, self.url, chunk_size=chunk_size)
self.size: Optional[int] = None
self._iterator: Optional[Iterator[bytes]] = None
self._response: Optional["HTTPResponse"] = None
self._initial_response: Optional["HTTPResponse"] = None

@property
def supports_ranges(self) -> bool:
Expand All @@ -128,13 +127,22 @@ def supports_ranges(self) -> bool:
to see if the server supports Range header or not. And, we want to
avoid checking that as much as possible.
"""
response = self._response
response = self._initial_response
if response and response.headers.get("Accept-Ranges") == "bytes":
return True
# consider if checking Accept-Ranges from OPTIONS request on self.url
# would be a better solution than using base url.
return self.client.detected_features.supports_ranges

@property
def size(self) -> Optional[int]:
"""Size of the file object."""
assert self._initial_response
content_length: str = self._initial_response.headers.get(
"Content-Length", ""
)
return int(content_length) if content_length.isdigit() else None

@property
def loc(self) -> int:
"""Keep track of location of the stream/file for callbacks."""
Expand All @@ -148,11 +156,7 @@ def loc(self, value: int) -> None:
def __enter__(self) -> "IterStream":
"""Send a streaming response."""
# pylint: disable=no-member
response, self._iterator = self._cm.__enter__()
# we don't want to get this on Ranged requests or retried ones
content_length: str = response.headers.get("Content-Length", "")
self._response = response
self.size = int(content_length) if content_length.isdigit() else None
self._initial_response, self._iterator = self._cm.__enter__()
return self

def __exit__(self, *args: Any) -> None:
Expand All @@ -162,16 +166,15 @@ def __exit__(self, *args: Any) -> None:
@property
def encoding(self) -> Optional[str]:
"""Encoding of the response."""
assert self._response
return self._response.encoding
assert self._initial_response
return self._initial_response.encoding

def close(self) -> None:
"""Close response if not already."""
if self._iterator or self._response:
if self._iterator:
self._cm.__exit__(None, None, None) # pylint: disable=no-member

self._iterator = None
self._response = None
self.buffer = b""

def seek(self, offset: int, whence: int = 0) -> int: # noqa: C901
Expand Down Expand Up @@ -199,9 +202,9 @@ def seek(self, offset: int, whence: int = 0) -> int: # noqa: C901
self.client, self.url, pos=loc, chunk_size=self.chunk_size
)
# pylint: disable=no-member
self._response, self._iterator = self._cm.__enter__()
_, self._iterator = self._cm.__enter__()
self.loc = loc
return self.loc
return loc

def tell(self) -> int:
"""Return current position of the fileobj."""
Expand All @@ -210,7 +213,7 @@ def tell(self) -> int:
@property
def closed(self) -> bool: # pylint: disable=invalid-overridden-method
"""Check whether the stream was closed or not."""
return not any([self._response, self._iterator])
return self._iterator is None

def readable(self) -> bool:
"""Stream is readable."""
Expand Down Expand Up @@ -250,10 +253,10 @@ def read1(self, num: int = -1) -> bytes:
return b""

if num <= 0:
self.buffer = b""
return chunk
output, self.buffer = chunk, b""
else:
output, self.buffer = chunk[:num], chunk[num:]

output, self.buffer = chunk[:num], chunk[num:]
self.loc += len(output)
return output

Expand Down
10 changes: 7 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from io import DEFAULT_BUFFER_SIZE, BytesIO
from pathlib import Path
from typing import Any, Callable, Dict
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -557,9 +558,12 @@ def test_open_binary(storage_dir: TmpDir, client: Client):
f.seek(-10)
with pytest.raises(ValueError):
f.seek(10, 3)
f.size = None # type: ignore
with pytest.raises(ValueError):
f.seek(-10, 2)
with mock.patch.object(
type(f), "size", new_callable=mock.PropertyMock
) as m:
m.return_value = None
with pytest.raises(ValueError):
f.seek(-10, 2)

assert f.closed

Expand Down
4 changes: 2 additions & 2 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def bad_iter_content(
# when we cannot detect support for ranges, we should just raise error
client.detected_features.supports_ranges = False
with client.open("sample.txt", mode="rb") as fd:
fd._response.headers.clear() # type: ignore
fd._initial_response.headers.clear() # type: ignore
with pytest.raises(HTTPNetworkError):
fd.read()

Expand All @@ -116,7 +116,7 @@ def bad_iter_content(
assert str(exc_info.value) == "server does not support ranges"

with fs.open("sample.txt", mode="rb") as fd:
fd.reader._response.headers.clear() # type: ignore
fd.reader._initial_response.headers.clear() # type: ignore
with pytest.raises(HTTPNetworkError):
fd.read()

Expand Down

0 comments on commit 9514d4a

Please sign in to comment.