Fix read from multiple s3 regions #1453

Open
wants to merge 13 commits into base: main
22 changes: 13 additions & 9 deletions pyiceberg/io/pyarrow.py
@@ -352,7 +352,7 @@ def parse_location(location: str) -> Tuple[str, str, str]:

def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem:
if scheme in {"s3", "s3a", "s3n", "oss"}:
from pyarrow.fs import S3FileSystem
from pyarrow.fs import S3FileSystem, resolve_s3_region

client_kwargs: Dict[str, Any] = {
"endpoint_override": self.properties.get(S3_ENDPOINT),
@@ -377,6 +377,12 @@ def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem:
if force_virtual_addressing := self.properties.get(S3_FORCE_VIRTUAL_ADDRESSING):
client_kwargs["force_virtual_addressing"] = property_as_bool(self.properties, force_virtual_addressing, False)

# Override the default s3.region if netloc(bucket) resolves to a different region
try:
client_kwargs["region"] = resolve_s3_region(netloc)
except (OSError, TypeError):
pass

return S3FileSystem(**client_kwargs)
elif scheme in ("hdfs", "viewfs"):
from pyarrow.fs import HadoopFileSystem
@@ -1326,13 +1332,14 @@ def _task_to_table(
return None


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
def _read_all_delete_files(io: FileIO, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
deletes_per_file: Dict[str, List[ChunkedArray]] = {}
unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
if len(unique_deletes) > 0:
executor = ExecutorFactory.get_or_create()
deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map(
lambda args: _read_deletes(*args), [(fs, delete) for delete in unique_deletes]
lambda args: _read_deletes(*args),
[(_fs_from_file_path(delete_file.file_path, io), delete_file) for delete_file in unique_deletes],
)
for delete in deletes_per_files:
for file, arr in delete.items():
@@ -1366,7 +1373,6 @@ def _fs_from_file_path(file_path: str, io: FileIO) -> FileSystem:
class ArrowScan:
_table_metadata: TableMetadata
_io: FileIO
_fs: FileSystem
_projected_schema: Schema
_bound_row_filter: BooleanExpression
_case_sensitive: bool
@@ -1376,7 +1382,6 @@ class ArrowScan:
Attributes:
_table_metadata: Current table metadata of the Iceberg table
_io: PyIceberg FileIO implementation from which to fetch the io properties
_fs: PyArrow FileSystem to use to read the files
_projected_schema: Iceberg Schema to project onto the data files
_bound_row_filter: Schema bound row expression to filter the data with
_case_sensitive: Case sensitivity when looking up column names
@@ -1394,7 +1399,6 @@ def __init__(
) -> None:
self._table_metadata = table_metadata
self._io = io
self._fs = _fs_from_file_path(table_metadata.location, io) # TODO: use different FileSystem per file
self._projected_schema = projected_schema
self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
self._case_sensitive = case_sensitive
@@ -1434,7 +1438,7 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
ResolveError: When a required field cannot be found in the file
ValueError: When a field type in the file cannot be projected to the schema type
"""
deletes_per_file = _read_all_delete_files(self._fs, tasks)
deletes_per_file = _read_all_delete_files(self._io, tasks)
executor = ExecutorFactory.get_or_create()

def _table_from_scan_task(task: FileScanTask) -> pa.Table:
@@ -1497,7 +1501,7 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
ResolveError: When a required field cannot be found in the file
ValueError: When a field type in the file cannot be projected to the schema type
"""
deletes_per_file = _read_all_delete_files(self._fs, tasks)
deletes_per_file = _read_all_delete_files(self._io, tasks)
return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)

def _record_batches_from_scan_tasks_and_deletes(
@@ -1508,7 +1512,7 @@ def _record_batches_from_scan_tasks_and_deletes(
if self._limit is not None and total_row_count >= self._limit:
break
batches = _task_to_record_batches(
self._fs,
_fs_from_file_path(task.file.file_path, self._io),
task,
self._bound_row_filter,
self._projected_schema,
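
The pyarrow.py changes above drop the single FileSystem that ArrowScan used to pin at construction time: each data file and delete file path is now mapped to its own filesystem through `_fs_from_file_path`, and `_initialize_fs` asks PyArrow for the bucket's actual region, keeping the configured `s3.region` as a fallback when resolution fails. A minimal standalone sketch of that resolution step, assuming a pyarrow release that ships `pyarrow.fs.resolve_s3_region` (the helper name and fallback handling are illustrative, not PyIceberg's exact code):

```python
# Minimal sketch of per-bucket S3 region resolution, assuming a pyarrow
# release that ships pyarrow.fs.resolve_s3_region; the helper name and
# fallback behaviour are illustrative, not PyIceberg's exact code.
from typing import Any, Dict, Optional

from pyarrow.fs import S3FileSystem, resolve_s3_region


def s3_filesystem_for_bucket(bucket: str, configured_region: Optional[str] = None) -> S3FileSystem:
    client_kwargs: Dict[str, Any] = {}
    if configured_region:
        client_kwargs["region"] = configured_region
    try:
        # Ask AWS which region the bucket actually lives in and prefer that
        # over whatever region was configured for the catalog.
        client_kwargs["region"] = resolve_s3_region(bucket)
    except (OSError, TypeError):
        # Resolution can fail for custom endpoints (e.g. MinIO), missing
        # network access, or a None bucket; keep the configured region.
        pass
    return S3FileSystem(**client_kwargs)


# Buckets in different regions each get a correctly-regioned filesystem:
# fs_us = s3_filesystem_for_bucket("us-east-1-bucket")
# fs_ap = s3_filesystem_for_bucket("ap-southeast-2-bucket")
```

Because the override swallows OSError and TypeError, buckets that cannot be resolved (for example behind a custom endpoint) simply keep the configured region, so existing single-region setups behave as before.
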
37 changes: 35 additions & 2 deletions tests/io/test_pyarrow.py
@@ -360,10 +360,11 @@ def test_pyarrow_s3_session_properties() -> None:
**UNIFIED_AWS_SESSION_PROPERTIES,
}

with patch("pyarrow.fs.S3FileSystem") as mock_s3fs:
with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver:
s3_fileio = PyArrowFileIO(properties=session_properties)
filename = str(uuid.uuid4())

mock_s3_region_resolver.return_value = "us-east-1"
s3_fileio.new_input(location=f"s3://warehouse/{filename}")

mock_s3fs.assert_called_with(
@@ -381,10 +382,11 @@ def test_pyarrow_unified_session_properties() -> None:
**UNIFIED_AWS_SESSION_PROPERTIES,
}

with patch("pyarrow.fs.S3FileSystem") as mock_s3fs:
with patch("pyarrow.fs.S3FileSystem") as mock_s3fs, patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver:
s3_fileio = PyArrowFileIO(properties=session_properties)
filename = str(uuid.uuid4())

mock_s3_region_resolver.return_value = "client.region"
s3_fileio.new_input(location=f"s3://warehouse/{filename}")

mock_s3fs.assert_called_with(
@@ -2074,3 +2076,34 @@ def test__to_requested_schema_timestamps_without_downcast_raises_exception(
_to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False)

assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value)


def test_pyarrow_file_io_fs_by_scheme_cache() -> None:
pyarrow_file_io = PyArrowFileIO()
us_east_1_region = "us-east-1"
ap_southeast_2_region = "ap-southeast-2"

with patch("pyarrow.fs.resolve_s3_region") as mock_s3_region_resolver:
# Call with new argument resolves region automatically
mock_s3_region_resolver.return_value = us_east_1_region
filesystem_us = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket")
assert filesystem_us.region == us_east_1_region
assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 1 # type: ignore
assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 1 # type: ignore

# Call with different argument also resolves region automatically
mock_s3_region_resolver.return_value = ap_southeast_2_region
filesystem_ap_southeast_2 = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket")
assert filesystem_ap_southeast_2.region == ap_southeast_2_region
assert pyarrow_file_io.fs_by_scheme.cache_info().misses == 2 # type: ignore
assert pyarrow_file_io.fs_by_scheme.cache_info().currsize == 2 # type: ignore

# Call with same argument hits cache
filesystem_us_cached = pyarrow_file_io.fs_by_scheme("s3", "us-east-1-bucket")
assert filesystem_us_cached.region == us_east_1_region
assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 1 # type: ignore

# Call with same argument hits cache
filesystem_ap_southeast_2_cached = pyarrow_file_io.fs_by_scheme("s3", "ap-southeast-2-bucket")
assert filesystem_ap_southeast_2_cached.region == ap_southeast_2_region
assert pyarrow_file_io.fs_by_scheme.cache_info().hits == 2 # type: ignore
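
The new cache test checks that `fs_by_scheme` is keyed by both scheme and bucket, so each bucket resolves its region once and then reuses its S3FileSystem, while different buckets get separate, correctly-regioned instances. A rough sketch of that caching pattern (class and method names are illustrative stand-ins, not the exact PyIceberg internals):

```python
# Rough sketch of a (scheme, netloc)-keyed filesystem cache in the spirit of
# PyArrowFileIO.fs_by_scheme; class and method names here are illustrative.
from functools import lru_cache
from typing import Any, Dict, Optional

from pyarrow.fs import FileSystem, S3FileSystem, resolve_s3_region


class CachingFileIO:
    def __init__(self) -> None:
        # lru_cache around the bound method gives one FileSystem per
        # (scheme, netloc) pair, reused for every file in that bucket.
        self.fs_by_scheme = lru_cache(self._initialize_fs)

    def _initialize_fs(self, scheme: str, netloc: Optional[str] = None) -> FileSystem:
        if scheme in {"s3", "s3a", "s3n"}:
            client_kwargs: Dict[str, Any] = {}
            try:
                client_kwargs["region"] = resolve_s3_region(netloc)
            except (OSError, TypeError):
                pass
            return S3FileSystem(**client_kwargs)
        raise ValueError(f"Unsupported scheme: {scheme}")


# io = CachingFileIO()
# io.fs_by_scheme("s3", "us-east-1-bucket")       # miss: region resolved, fs built
# io.fs_by_scheme("s3", "us-east-1-bucket")       # hit: cached S3FileSystem reused
# io.fs_by_scheme("s3", "ap-southeast-2-bucket")  # miss: separate fs for that bucket
# io.fs_by_scheme.cache_info()                    # hits=1, misses=2
```

Caching per (scheme, netloc) keeps the extra `resolve_s3_region` round trip to one request per bucket rather than one per file.
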