Skip to content

Commit

Permalink
Merge pull request #180 from Aiven-Open/joelynch/str-enum
Browse files Browse the repository at this point in the history
azure: allow StrEnum use in AzureTransfer.get_or_create_container
  • Loading branch information
giacomo-alzetta-aiven authored Apr 17, 2024
2 parents 6081a0c + 764d021 commit 34a9cbe
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 31 deletions.
5 changes: 5 additions & 0 deletions rohmu/object_storage/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from typing import Any, BinaryIO, Collection, Iterator, Optional, Tuple, Union

import azure.common
import enum
import logging
import time

Expand Down Expand Up @@ -393,6 +394,10 @@ def progress_callback(pipeline_response: Any) -> None:
delattr(fd, "tell")

def get_or_create_container(self, container_name: str) -> str:
if isinstance(container_name, enum.Enum):
# ensure that the enum value is used rather than the enum name
# https://github.com/Azure/azure-sdk-for-python/blob/azure-storage-blob_12.8.1/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py#L667
container_name = container_name.value
start_time = time.monotonic()
try:
self.conn.create_container(container_name)
Expand Down
81 changes: 50 additions & 31 deletions test/object_storage/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,33 @@
# Copyright (c) 2022 Aiven, Helsinki, Finland. https://aiven.io/
from datetime import datetime
from io import BytesIO
from pytest_mock import MockerFixture
from rohmu.common.strenum import StrEnum
from rohmu.errors import InvalidByteRangeError
from rohmu.object_storage.azure import AzureTransfer
from rohmu.object_storage.config import AzureObjectStorageConfig
from tempfile import NamedTemporaryFile
from types import ModuleType
from typing import Any, Optional, Tuple
from typing import Any, Optional
from unittest.mock import MagicMock, patch

import azure.storage.blob
import pytest
import rohmu.object_storage.azure
import sys


@pytest.fixture(scope="module", name="mock_azure_module")
def fixture_mock_azure_module() -> Tuple[ModuleType, MagicMock]:
@pytest.fixture(name="mock_get_blob_client")
def fixture_mock_get_blob_client(mocker: MockerFixture) -> MagicMock:
get_blob_client_mock = MagicMock()
blob_client = MagicMock(get_blob_client=get_blob_client_mock)
service_client = MagicMock(from_connection_string=MagicMock(return_value=blob_client))
module_patches = {
"azure.common": MagicMock(),
"azure.core.exceptions": MagicMock(),
"azure.storage.blob": MagicMock(BlobServiceClient=service_client),
}
with patch.dict(sys.modules, module_patches):
import rohmu.object_storage.azure

return rohmu.object_storage.azure, get_blob_client_mock
mocker.patch.object(rohmu.object_storage.azure, "BlobServiceClient", service_client)
return get_blob_client_mock


@pytest.fixture(name="azure_module")
def fixture_azure_module(mock_azure_module: Tuple[ModuleType, MagicMock]) -> ModuleType:
return mock_azure_module[0]


@pytest.fixture(name="get_blob_client")
def fixture_get_blob_client(mock_azure_module: Tuple[ModuleType, MagicMock]) -> MagicMock:
return mock_azure_module[1]


def test_store_file_from_disk(azure_module: ModuleType, get_blob_client: MagicMock) -> None:
def test_store_file_from_disk(mock_get_blob_client: MagicMock) -> None:
notifier = MagicMock()
transfer = azure_module.AzureTransfer(
transfer = AzureTransfer(
bucket_name="test_bucket",
account_name="test_account",
account_key="test_key1",
Expand All @@ -49,7 +36,7 @@ def test_store_file_from_disk(azure_module: ModuleType, get_blob_client: MagicMo
test_data = b"test-data"
metadata = {"Content-Length": len(test_data), "some-date": datetime(2022, 11, 15, 18, 30, 58, 486644)}
upload_blob = MagicMock()
get_blob_client.return_value = MagicMock(upload_blob=upload_blob)
mock_get_blob_client.return_value = MagicMock(upload_blob=upload_blob)

with NamedTemporaryFile() as tmpfile:
tmpfile.write(test_data)
Expand All @@ -62,9 +49,9 @@ def test_store_file_from_disk(azure_module: ModuleType, get_blob_client: MagicMo
)


def test_store_file_object(azure_module: ModuleType, get_blob_client: MagicMock) -> None:
def test_store_file_object(mock_get_blob_client: MagicMock) -> None:
notifier = MagicMock()
transfer = azure_module.AzureTransfer(
transfer = AzureTransfer(
bucket_name="test_bucket",
account_name="test_account",
account_key="test_key2",
Expand All @@ -80,7 +67,7 @@ def upload_side_effect(*args: Any, **kwargs: Any) -> None:

# Size reporting relies on the progress callback from azure client
upload_blob = MagicMock(wraps=upload_side_effect)
get_blob_client.return_value = MagicMock(upload_blob=upload_blob)
mock_get_blob_client.return_value = MagicMock(upload_blob=upload_blob)

transfer.store_file_object(key="test_key2", fd=file_object, metadata=metadata)

Expand All @@ -90,9 +77,9 @@ def upload_side_effect(*args: Any, **kwargs: Any) -> None:
)


def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(azure_module: ModuleType) -> None:
def test_get_contents_to_fileobj_raises_error_on_invalid_byte_range(mock_get_blob_client: MagicMock) -> None:
notifier = MagicMock()
transfer = azure_module.AzureTransfer(
transfer = AzureTransfer(
bucket_name="test_bucket",
account_name="test_account",
account_key="test_key2",
Expand Down Expand Up @@ -201,3 +188,35 @@ def test_conn_string(host: Optional[str], port: Optional[int], is_secured: bool,
account_name="test_name", account_key="test_key", azure_cloud=None, host=host, port=port, is_secure=is_secured
)
assert expected == conn_string


class MockBucketName(StrEnum):
bucket_enum_key = "bucket_enum_value"


def test_create_container_enum(mocker: MockerFixture) -> None:
container_client_mock = MagicMock(spec=azure.storage.blob.ContainerClient)
mocker.patch.object(azure.storage.blob._blob_service_client, "ContainerClient", container_client_mock)
notifier = MagicMock()
AzureTransfer(
bucket_name=MockBucketName.bucket_enum_key,
account_name="test_account",
account_key="test_key",
notifier=notifier,
)
container_name = container_client_mock.call_args.kwargs["container_name"]
assert container_name == "bucket_enum_value"


def test_create_container_str(mocker: MockerFixture) -> None:
container_client_mock = MagicMock(spec=azure.storage.blob.ContainerClient)
mocker.patch.object(azure.storage.blob._blob_service_client, "ContainerClient", container_client_mock)
notifier = MagicMock()
AzureTransfer(
bucket_name="bucket_name",
account_name="test_account",
account_key="test_key",
notifier=notifier,
)
container_name = container_client_mock.call_args.kwargs["container_name"]
assert container_name == "bucket_name"

0 comments on commit 34a9cbe

Please sign in to comment.