feat: add custom Azure host/port to support custom blob endpoint #164

Merged · 1 commit · Jan 3, 2024
57 changes: 36 additions & 21 deletions rohmu/object_storage/azure.py
@@ -24,6 +24,7 @@
SourceStorageModelT,
)
from rohmu.object_storage.config import ( # pylint: disable=unused-import
AZURE_ENDPOINT_SUFFIXES,
AZURE_MAX_BLOCK_SIZE as MAX_BLOCK_SIZE,
AzureObjectStorageConfig as Config,
calculate_azure_max_block_size as calculate_max_block_size,
@@ -42,14 +43,6 @@
from azure.storage.blob._models import BlobPrefix, BlobType # type: ignore


ENDPOINT_SUFFIXES = {
None: "core.windows.net",
"germany": "core.cloudapi.de", # Azure Germany is a completely separate cloud from the regular Azure Public cloud
"china": "core.chinacloudapi.cn",
"public": "core.windows.net",
}


# Reduce Azure logging verbosity of http requests and responses
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)

@@ -64,6 +57,9 @@ def __init__(
account_key: Optional[str] = None,
sas_token: Optional[str] = None,
prefix: Optional[str] = None,
is_secure: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
azure_cloud: Optional[str] = None,
proxy_info: Optional[dict[str, Union[str, int]]] = None,
notifier: Optional[Notifier] = None,
@@ -78,16 +74,13 @@
self.account_key = account_key
self.container_name = bucket_name
self.sas_token = sas_token
try:
endpoint_suffix = ENDPOINT_SUFFIXES[azure_cloud]
except KeyError:
raise InvalidConfigurationError(f"Unknown azure cloud {repr(azure_cloud)}")

conn_str = (
"DefaultEndpointsProtocol=https;"
f"AccountName={self.account_name};"
f"AccountKey={self.account_key};"
f"EndpointSuffix={endpoint_suffix}"
conn_str = self.conn_string(
account_name=account_name,
account_key=account_key,
azure_cloud=azure_cloud,
host=host,
port=port,
is_secure=is_secure,
)
config: dict[str, Any] = {"max_block_size": MAX_BLOCK_SIZE}
if proxy_info:
@@ -97,13 +90,13 @@
auth = f"{username}:{password}@"
else:
auth = ""
host = proxy_info["host"]
port = proxy_info["port"]
proxy_host = proxy_info["host"]
proxy_port = proxy_info["port"]
if proxy_info.get("type") == "socks5":
schema = "socks5"
else:
schema = "http"
config["proxies"] = {"https": f"{schema}://{auth}{host}:{port}"}
config["proxies"] = {"https": f"{schema}://{auth}{proxy_host}:{proxy_port}"}

self.conn: BlobServiceClient = BlobServiceClient.from_connection_string(
conn_str=conn_str,
@@ -113,6 +106,28 @@
self.container = self.get_or_create_container(self.container_name)
self.log.debug("AzureTransfer initialized, %r", self.container_name)

@staticmethod
def conn_string(
account_name: str,
account_key: Optional[str],
azure_cloud: Optional[str],
host: Optional[str],
port: Optional[int],
is_secure: bool,
) -> str:
protocol = "https" if is_secure else "http"
conn = [
f"DefaultEndpointsProtocol={protocol}",
f"AccountName={account_name}",
f"AccountKey={account_key}",
]
if not host and not port:
endpoint_suffix = AZURE_ENDPOINT_SUFFIXES[azure_cloud]
conn.append(f"EndpointSuffix={endpoint_suffix}")
else:
conn.append(f"BlobEndpoint={protocol}://{host}:{port}/{account_name}")
Comment on lines +124 to +128
Contributor
nit: I guess we could do something like:

Suggested change
if not host and not port:
endpoint_suffix = AZURE_ENDPOINT_SUFFIXES[azure_cloud]
conn.append(f"EndpointSuffix={endpoint_suffix}")
else:
conn.append(f"BlobEndpoint={protocol}://{host}:{port}/{account_name}")
if not host and not port:
endpoint_suffix = AZURE_ENDPOINT_SUFFIXES[azure_cloud]
conn.append(f"EndpointSuffix={endpoint_suffix}")
elif host and port:
conn.append(f"BlobEndpoint={protocol}://{host}:{port}/{account_name}")
else:
raise ValueError("You must either specify both host and port or neither of them")

In most cases the AzureTransfer will be built using the get_transfer facade, where we have the pydantic validation, but I believe people can still manually create the transfer, so maybe validating this again here is better.

Contributor

So it is possible to create transfers bypassing the pydantic config models?

Contributor

I'm just saying that we are not indicating that AzureTransfer is private, so users could instantiate it directly instead of going through the factory methods that use from_model to create the instance from the pydantic configuration.

I don't think this is a big deal, hence why I class this as nitpicking.

return ";".join(conn)

def copy_file(
self, *, source_key: str, destination_key: str, metadata: Optional[Metadata] = None, **kwargs: Any
) -> None:
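As the review thread above notes, transfers are normally created through the get_transfer facade, where the pydantic AzureObjectStorageConfig validation runs, but AzureTransfer can also be instantiated directly and bypass that validation. Below is a minimal sketch of both paths against a local blob endpoint; the get_transfer call, the config dict keys, and the placeholder account values are illustrative assumptions and are not part of this diff. Either path will attempt to reach the endpoint, since __init__ fetches or creates the container.

# Illustrative sketch only: get_transfer and the config keys below are assumptions,
# not part of this PR.
from rohmu import get_transfer
from rohmu.object_storage.azure import AzureTransfer

# Path 1: via the factory, so AzureObjectStorageConfig (pydantic) checks that
# host and port are either both set or both unset.
transfer = get_transfer(
    {
        "storage_type": "azure",
        "account_name": "devstoreaccount1",
        "account_key": "placeholder-key",
        "bucket_name": "backups",
        "host": "localhost",
        "port": 10000,
        "is_secure": False,
    }
)

# Path 2: direct construction, which skips the pydantic model entirely; this is
# why the reviewer suggests also validating host/port inside conn_string().
transfer = AzureTransfer(
    bucket_name="backups",
    account_name="devstoreaccount1",
    account_key="placeholder-key",
    host="localhost",
    port=10000,
    is_secure=False,
)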
25 changes: 24 additions & 1 deletion rohmu/object_storage/config.py
@@ -10,7 +10,7 @@

from enum import Enum, unique
from pathlib import Path
from pydantic import Field, root_validator
from pydantic import Field, root_validator, validator
from rohmu.common.models import ProxyInfo, StorageDriver, StorageModel
from typing import Any, Dict, Final, Literal, Optional, TypeVar

@@ -42,6 +42,12 @@ def calculate_azure_max_block_size() -> int:
return max(min(int(total_mem_mib / 1000), 100), 4) * 1024 * 1024


AZURE_ENDPOINT_SUFFIXES = {
None: "core.windows.net",
"germany": "core.cloudapi.de", # Azure Germany is a completely separate cloud from the regular Azure Public cloud
"china": "core.chinacloudapi.cn",
"public": "core.windows.net",
}
# Increase block size based on host memory. Azure supports up to 50k blocks and up to 5 TiB individual
# files. Default block size is set to 4 MiB so only ~200 GB files can be uploaded. In order to get close
# to that 5 TiB increase the block size based on host memory; we don't want to use the max 100 for all
@@ -83,10 +89,27 @@ class AzureObjectStorageConfig(StorageModel):
account_key: Optional[str] = Field(None, repr=False)
sas_token: Optional[str] = Field(None, repr=False)
prefix: Optional[str] = None
is_secure: bool = True
host: Optional[str] = None
port: Optional[int] = None
azure_cloud: Optional[str] = None
proxy_info: Optional[ProxyInfo] = None
storage_type: Literal[StorageDriver.azure] = StorageDriver.azure

@root_validator
@classmethod
def host_and_port_must_be_set_together(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if (values["host"] is None) != (values["port"] is None):
raise ValueError("host and port must be set together")
return values

@validator("azure_cloud")
@classmethod
def valid_azure_cloud_endpoint(cls, v: str) -> str:
if v not in AZURE_ENDPOINT_SUFFIXES:
raise ValueError(f"azure_cloud must be one of {AZURE_ENDPOINT_SUFFIXES.keys()}")
return v


class GoogleObjectStorageConfig(StorageModel):
project_id: str
100 changes: 99 additions & 1 deletion test/object_storage/test_azure.py
@@ -2,9 +2,10 @@
from datetime import datetime
from io import BytesIO
from rohmu.errors import InvalidByteRangeError
from rohmu.object_storage.config import AzureObjectStorageConfig
from tempfile import NamedTemporaryFile
from types import ModuleType
from typing import Any, Tuple
from typing import Any, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest
@@ -103,3 +104,100 @@ def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module
fileobj_to_store_to=BytesIO(),
byte_range=(100, 10),
)


def test_minimal_config() -> None:
config = AzureObjectStorageConfig(account_name="test")
assert config.account_name == "test"


def test_azure_config_host_port_set_together() -> None:
with pytest.raises(ValueError):
AzureObjectStorageConfig(account_name="test", host="localhost")
with pytest.raises(ValueError):
AzureObjectStorageConfig(account_name="test", port=10000)
config = AzureObjectStorageConfig(account_name="test", host="localhost", port=10000)
assert config.host == "localhost"
assert config.port == 10000


def test_valid_azure_cloud_endpoint() -> None:
with pytest.raises(ValueError):
AzureObjectStorageConfig(account_name="test", azure_cloud="invalid")
config = AzureObjectStorageConfig(account_name="test", azure_cloud="public")
assert config.azure_cloud == "public"


@pytest.mark.parametrize(
"host,port,is_secured,expected",
[
(
None,
None,
True,
";".join(
[
"DefaultEndpointsProtocol=https",
"AccountName=test_name",
"AccountKey=test_key",
"EndpointSuffix=core.windows.net",
]
),
),
(
None,
None,
False,
";".join(
[
"DefaultEndpointsProtocol=http",
"AccountName=test_name",
"AccountKey=test_key",
"EndpointSuffix=core.windows.net",
]
),
),
(
"localhost",
10000,
True,
";".join(
[
"DefaultEndpointsProtocol=https",
"AccountName=test_name",
"AccountKey=test_key",
"BlobEndpoint=https://localhost:10000/test_name",
]
),
),
(
"localhost",
10000,
False,
";".join(
[
"DefaultEndpointsProtocol=http",
"AccountName=test_name",
"AccountKey=test_key",
"BlobEndpoint=http://localhost:10000/test_name",
]
),
),
],
)
def test_conn_string(host: Optional[str], port: Optional[int], is_secured: bool, expected: str) -> None:
get_blob_client_mock = MagicMock()
blob_client = MagicMock(get_blob_client=get_blob_client_mock)
service_client = MagicMock(from_connection_string=MagicMock(return_value=blob_client))
module_patches = {
"azure.common": MagicMock(),
"azure.core.exceptions": MagicMock(),
"azure.storage.blob": MagicMock(BlobServiceClient=service_client),
}
with patch.dict(sys.modules, module_patches):
from rohmu.object_storage.azure import AzureTransfer

conn_string = AzureTransfer.conn_string(
account_name="test_name", account_key="test_key", azure_cloud=None, host=host, port=port, is_secure=is_secured
)
assert expected == conn_string