Skip to content

Commit

Permalink
Add unit test for PyArrowFileIO.fs_by_scheme cache behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
jiakai-li committed Dec 21, 2024
1 parent eb5e491 commit 0c61ac8
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
15 changes: 7 additions & 8 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,19 +354,12 @@ def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSyste
if scheme in {"s3", "s3a", "s3n", "oss"}:
from pyarrow.fs import S3FileSystem, resolve_s3_region

bucket_region = None
if netloc:
try:
bucket_region = resolve_s3_region(netloc)
except OSError:
pass

client_kwargs: Dict[str, Any] = {
"endpoint_override": self.properties.get(S3_ENDPOINT),
"access_key": get_first_property_value(self.properties, S3_ACCESS_KEY_ID, AWS_ACCESS_KEY_ID),
"secret_key": get_first_property_value(self.properties, S3_SECRET_ACCESS_KEY, AWS_SECRET_ACCESS_KEY),
"session_token": get_first_property_value(self.properties, S3_SESSION_TOKEN, AWS_SESSION_TOKEN),
"region": bucket_region or get_first_property_value(self.properties, S3_REGION, AWS_REGION),
"region": get_first_property_value(self.properties, S3_REGION, AWS_REGION),
}

if proxy_uri := self.properties.get(S3_PROXY_URI):
Expand All @@ -384,6 +377,12 @@ def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSyste
if force_virtual_addressing := self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING):
client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, force_virtual_addressing, False)

# Override the default s3.region if netloc(bucket) resolves to a different region
try:
client_kwargs["region"] = resolve_s3_region(netloc)
except OSError:
pass

return S3FileSystem(**client_kwargs)
elif scheme in ("hdfs", "viewfs"):
from pyarrow.fs import HadoopFileSystem
Expand Down
33 changes: 32 additions & 1 deletion tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def test_pyarrow_unified_session_properties() -> None:
s3_fileio = PyArrowFileIO(properties=session_properties)
filename = str(uuid.uuid4())

mock_s3_region_resolver.return_value = None
mock_s3_region_resolver.return_value = "client.region"
s3_fileio.new_input(location=f"s3://warehouse/{filename}")

mock_s3fs.assert_called_with(
Expand Down Expand Up @@ -2076,3 +2076,34 @@ def test__to_requested_schema_timestamps_without_downcast_raises_exception(
_to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False)

assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value)


def test_pyarrow_file_io_fs_by_scheme_cache() -> None:
    """Verify that ``PyArrowFileIO.fs_by_scheme`` caches filesystems per (scheme, netloc).

    Each distinct (scheme, netloc) pair should resolve the bucket region and
    construct a filesystem exactly once (a cache miss); repeated calls with
    the same pair must return the cached filesystem without re-resolving the
    region (a cache hit).
    """
    pyarrow_file_io = PyArrowFileIO()
    # NOTE: fixed typo — was "us-eas1-1", contradicting both the variable name
    # and the "us-east-1-bucket" netloc used below.
    us_east_1_region = "us-east-1"
    ap_southeast_2_region = "ap-southeast-2"

    with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver:
        # Call with a new argument resolves region automatically -> cache miss
        mock_s3_region_resolver.return_value = us_east_1_region
        filesystem_us = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket")
        assert filesystem_us.region == us_east_1_region
        assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 1  # type: ignore
        assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 1  # type: ignore

        # Call with a different argument also resolves region automatically -> second miss
        mock_s3_region_resolver.return_value = ap_southeast_2_region
        filesystem_ap_southeast_2 = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket")
        assert filesystem_ap_southeast_2.region == ap_southeast_2_region
        assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 2  # type: ignore
        assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 2  # type: ignore

        # Call with the first argument again hits the cache
        filesystem_us_cached = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket")
        assert filesystem_us_cached.region == us_east_1_region
        assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 1  # type: ignore

        # Call with the second argument again also hits the cache
        filesystem_ap_southeast_2_cached = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket")
        assert filesystem_ap_southeast_2_cached.region == ap_southeast_2_region
        assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 2  # type: ignore

        # The region resolver must have run once per distinct netloc only —
        # cached calls must not trigger a re-resolution.
        assert mock_s3_region_resolver.call_count == 2

0 comments on commit 0c61ac8

Please sign in to comment.