Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update azureblob.py #3289

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 76 additions & 33 deletions luigi/contrib/azureblob.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import datetime

from azure.storage.blob import blockblobservice
from azure.storage.blob import BlobServiceClient

from luigi.format import get_default_format
from luigi.target import FileAlreadyExists, FileSystem, AtomicLocalFile, FileSystemTarget
Expand Down Expand Up @@ -62,60 +62,101 @@ def __init__(self, account_name=None, account_key=None, sas_token=None, **kwargs
* `custom_domain` - The custom domain to use. This can be set in the Azure Portal. For example, ‘www.mydomain.com’.
* `token_credential` - A token credential used to authenticate HTTPS requests. The token value should be updated before its expiration.
"""
self.options = {"account_name": account_name, "account_key": account_key, "sas_token": sas_token}
if kwargs.get("custom_domain"):
account_url = "{protocol}://{custom_domain}/{account_name}".format(protocol=kwargs.get("protocol", "https"),
custom_domain=kwargs.get("custom_domain"),
account_name=account_name)
else:
account_url = "{protocol}://{account_name}.blob.{endpoint_suffix}".format(protocol=kwargs.get("protocol",
"https"),
account_name=account_name,
endpoint_suffix=kwargs.get(
"endpoint_suffix",
"core.windows.net"))

self.options = {
"account_name": account_name,
"account_key": account_key,
"account_url": account_url,
"sas_token": sas_token}
self.kwargs = kwargs

@property
def connection(self):
return blockblobservice.BlockBlobService(account_name=self.options.get("account_name"),
account_key=self.options.get("account_key"),
sas_token=self.options.get("sas_token"),
protocol=self.kwargs.get("protocol"),
connection_string=self.kwargs.get("connection_string"),
endpoint_suffix=self.kwargs.get("endpoint_suffix"),
custom_domain=self.kwargs.get("custom_domain"),
is_emulated=self.kwargs.get("is_emulated") or False)
if self.kwargs.get("connection_string"):
return BlobServiceClient.from_connection_string(conn_str=self.kwargs.get("connection_string"),
**self.kwargs)
else:
return BlobServiceClient(account_url=self.options.get("account_url"),
credential=self.options.get("account_key") or self.options.get("sas_token"),
**self.kwargs)

def container_client(self, container_name):
return self.connection.get_container_client(container_name)

def blob_client(self, container_name, blob_name):
container_client = self.container_client(container_name)
return container_client.get_blob_client(blob_name)

def upload(self, tmp_path, container, blob, **kwargs):
logging.debug("Uploading file '{tmp_path}' to container '{container}' and blob '{blob}'".format(
tmp_path=tmp_path, container=container, blob=blob))
self.create_container(container)
lease_id = self.connection.acquire_blob_lease(container, blob)\
if self.exists("{container}/{blob}".format(container=container, blob=blob)) else None
lease = None
blob_client = self.blob_client(container, blob)
if blob_client.exists():
lease = blob_client.acquire_lease()
try:
self.connection.create_blob_from_path(container, blob, tmp_path, lease_id=lease_id, progress_callback=kwargs.get("progress_callback"))
with open(tmp_path, 'rb') as data:
blob_client.upload_blob(data,
overwrite=True,
lease=lease,
progress_hook=kwargs.get("progress_callback"))
finally:
if lease_id is not None:
self.connection.release_blob_lease(container, blob, lease_id)
if lease is not None:
lease.release()

def download_as_bytes(self, container, blob, bytes_to_read=None):
start_range, end_range = (0, bytes_to_read-1) if bytes_to_read is not None else (None, None)
logging.debug("Downloading from container '{container}' and blob '{blob}' as bytes".format(
container=container, blob=blob))
return self.connection.get_blob_to_bytes(container, blob, start_range=start_range, end_range=end_range).content
blob_client = self.blob_client(container, blob)
download_stream = blob_client.download_blob(offset=0, length=bytes_to_read) if bytes_to_read \
else blob_client.download_blob()
return download_stream.readall()

def download_as_file(self, container, blob, location):
logging.debug("Downloading from container '{container}' and blob '{blob}' to {location}".format(
container=container, blob=blob, location=location))
return self.connection.get_blob_to_path(container, blob, location)
blob_client = self.blob_client(container, blob)
with open(location, 'wb') as file:
download_stream = blob_client.download_blob()
file.write(download_stream.readall())
return blob_client.get_blob_properties()

def create_container(self, container_name):
return self.connection.create_container(container_name)
if not self.exists(container_name):
return self.connection.create_container(container_name)

def delete_container(self, container_name):
lease_id = self.connection.acquire_container_lease(container_name)
self.connection.delete_container(container_name, lease_id=lease_id)
container_client = self.container_client(container_name)
lease = container_client.acquire_lease()
container_client.delete_container(lease=lease)

def exists(self, path):
container, blob = self.splitfilepath(path)
return self.connection.exists(container, blob)
if blob is None:
return self.container_client(container).exists()
else:
return self.blob_client(container, blob).exists()

def remove(self, path, recursive=True, skip_trash=True):
container, blob = self.splitfilepath(path)
if not self.exists(path):
return False
lease_id = self.connection.acquire_blob_lease(container, blob)
self.connection.delete_blob(container, blob, lease_id=lease_id)

container, blob = self.splitfilepath(path)
blob_client = self.blob_client(container, blob)
lease = blob_client.acquire_lease()
blob_client.delete_blob(lease=lease)
return True

def mkdir(self, path, parents=True, raise_if_exists=False):
Expand Down Expand Up @@ -148,16 +189,18 @@ def copy(self, path, dest):
source_container=source_container, dest_container=dest_container
))

source_lease_id = self.connection.acquire_blob_lease(source_container, source_blob)
destination_lease_id = self.connection.acquire_blob_lease(dest_container, dest_blob) if self.exists(dest) else None
source_blob_client = self.blob_client(source_container, source_blob)
dest_blob_client = self.blob_client(dest_container, dest_blob)
source_lease = source_blob_client.acquire_lease()
destination_lease = dest_blob_client.acquire_lease() if self.exists(dest) else None
try:
return self.connection.copy_blob(source_container, dest_blob, self.connection.make_blob_url(
source_container, source_blob),
destination_lease_id=destination_lease_id, source_lease_id=source_lease_id)
return dest_blob_client.start_copy_from_url(source_url=source_blob_client.url,
source_lease=source_lease,
destination_lease=destination_lease)
finally:
self.connection.release_blob_lease(source_container, source_blob, source_lease_id)
if destination_lease_id is not None:
self.connection.release_blob_lease(dest_container, dest_blob, destination_lease_id)
source_lease.release()
if destination_lease is not None:
destination_lease.release()

def rename_dont_move(self, path, dest):
self.move(path, dest)
Expand Down
9 changes: 5 additions & 4 deletions test/contrib/azureblob_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
account_name = os.environ.get("ACCOUNT_NAME")
account_key = os.environ.get("ACCOUNT_KEY")
sas_token = os.environ.get("SAS_TOKEN")
is_emulated = False if account_name else True
client = AzureBlobClient(account_name, account_key, sas_token, is_emulated=is_emulated)
custom_domain = os.environ.get("CUSTOM_DOMAIN")
protocol = os.environ.get("PROTOCOL")
client = AzureBlobClient(account_name, account_key, sas_token, custom_domain=custom_domain, protocol=protocol)


@pytest.mark.azureblob
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_upload_copy_move_remove_blob(self):
self.assertTrue(self.client.exists(from_path))

# copy
self.assertIn(self.client.copy(from_path, to_path).status, ["success", "pending"])
self.assertIn(self.client.copy(from_path, to_path)["copy_status"], ["success", "pending"])
self.assertTrue(self.client.exists(to_path))

# remove
Expand All @@ -121,7 +122,7 @@ def output(self):
return AzureBlobTarget("luigi-test", "movie-cheesy.txt", client, download_when_reading=False)

def run(self):
client.connection.create_container("luigi-test")
client.create_container("luigi-test")
with self.output().open("w") as op:
op.write("I'm going to make him an offer he can't refuse.\n")
op.write("Toto, I've got a feeling we're not in Kansas anymore.\n")
Expand Down
Loading