Skip to content

Commit

Permalink
Update cli compatible with remote assets paths (gs and s3)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgenin committed Jun 28, 2022
1 parent 77723ac commit 6e6e321
Showing 1 changed file with 68 additions and 46 deletions.
114 changes: 68 additions & 46 deletions modelkit/assets/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rich.table import Table
from rich.tree import Tree

from modelkit.assets.drivers.gcs import GCSStorageDriver
from modelkit.assets.drivers.s3 import S3StorageDriver
from modelkit.assets.errors import ObjectDoesNotExistError
from modelkit.assets.manager import AssetsManager
from modelkit.assets.remote import StorageProvider
Expand All @@ -24,41 +26,34 @@ def assets_cli():
pass


gcs_fn_re = r"gs://(?P<bucket_name>[\w\-]+)/(?P<object_name>.+)"
storage_url_re = (
r"(?P<storage_prefix>[\w]*)://(?P<bucket_name>[\w\-]+)/(?P<object_name>.+)"
)


def parse_gcs(path):
match = re.match(gcs_fn_re, path)
def parse_remote_url(path):
match = re.match(storage_url_re, path)
if not match:
raise ValueError(f"Could not parse GCS path `{path}`")
raise ValueError(f"Could not parse path `{path}`")
return match.groupdict()


def _download_object_or_prefix(manager, asset_path, destination_dir):
parsed_path = parse_gcs(asset_path)
def _download_object_or_prefix(driver, object_name, destination_dir):
asset_path = os.path.join(destination_dir, "myasset")
try:
manager.storage_provider.driver.download_object(
object_name=parsed_path["object_name"],
destination_path=asset_path,
)
driver.download_object(object_name=object_name, destination_path=asset_path)
except ObjectDoesNotExistError:
# maybe prefix containing objects
paths = [
path
for path in manager.storage_provider.driver.iterate_objects(
prefix=parsed_path["object_name"]
)
]
paths = [path for path in driver.iterate_objects(prefix=object_name)]
if not paths:
raise

os.mkdir(asset_path)
for path in paths:
object_name = path.split("/")[-1]
manager.storage_provider.driver.download_object(
object_name=parsed_path["object_name"] + "/" + object_name,
destination_path=os.path.join(asset_path, object_name),
sub_object_name = path.split("/")[-1]
driver.download_object(
object_name=object_name + "/" + sub_object_name,
destination_path=os.path.join(asset_path, sub_object_name),
)
return asset_path

Expand Down Expand Up @@ -101,12 +96,16 @@ def new(asset_path, asset_spec, storage_prefix, dry_run):
NB: [asset_name] can contain `/` too.
"""
new_(asset_path, asset_spec, storage_prefix, dry_run)


def new_(asset_path, asset_spec, storage_prefix, dry_run):
_check_asset_file_number(asset_path)
manager = StorageProvider(
prefix=storage_prefix,
)
print("Current assets manager:")
print(f" - storage provider = `{manager.driver}`")
destination_provider = StorageProvider(prefix=storage_prefix)

print("Destination assets provider:")
print(f" - storage driver = `{destination_provider.driver}`")
print(f" - driver bucket = `{destination_provider.driver.bucket}`")
print(f" - prefix = `{storage_prefix}`")

print(f"Current asset: `{asset_spec}`")
Expand All @@ -119,13 +118,24 @@ def new(asset_path, asset_spec, storage_prefix, dry_run):
response = click.prompt("[y/N]")
if response == "y":
with tempfile.TemporaryDirectory() as tmp_dir:
if asset_path.startswith("gs://"):
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
if parsed_path["storage_prefix"] == "gs":
driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
elif parsed_path["storage_prefix"] == "s3":
driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
else:
raise ValueError(
f"Unmanaged storage prefix `{parsed_path['storage_prefix']}`"
)
asset_path = _download_object_or_prefix(
manager, asset_path=asset_path, destination_dir=tmp_dir
driver,
object_name=parsed_path["object_name"],
destination_dir=tmp_dir,
)
manager.new(asset_path, spec.name, version, dry_run)
else:
print("Aborting.")
destination_provider.new(asset_path, spec.name, version, dry_run)
return version
print("Aborting.")


@assets_cli.command("update")
Expand Down Expand Up @@ -156,14 +166,20 @@ def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
Specific documentation depends on the choosen model
"""
try:
update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run)
except ObjectDoesNotExistError:
print("Remote asset not found. Create it first using `new`")
sys.exit(1)


def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
_check_asset_file_number(asset_path)
manager = StorageProvider(
prefix=storage_prefix,
)
destination_provider = StorageProvider(prefix=storage_prefix)

print("Current assets manager:")
print(f" - storage provider = `{manager.driver}`")
print("Destination assets provider:")
print(f" - storage driver = `{destination_provider.driver}`")
print(f" - driver bucket = `{destination_provider.driver.bucket}`")
print(f" - prefix = `{storage_prefix}`")

print(f"Current asset: `{asset_spec}`")
Expand All @@ -175,11 +191,7 @@ def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
print(f" - name = `{spec.name}`")
print(f" - version = `{spec.version}`")

try:
version_list = manager.get_versions_info(spec.name)
except ObjectDoesNotExistError:
print("Remote asset not found. Create it first using `new`")
sys.exit(1)
version_list = destination_provider.get_versions_info(spec.name)

update_params = spec.versioning.get_update_cli_params(
version=spec.version,
Expand All @@ -196,21 +208,31 @@ def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):

response = click.prompt("[y/N]")
if response == "y":

with tempfile.TemporaryDirectory() as tmp_dir:
if asset_path.startswith("gs://"):
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
if parsed_path["storage_prefix"] == "gs":
driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
elif parsed_path["storage_prefix"] == "s3":
driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
else:
raise ValueError(
f"Unmanaged storage prefix `{parsed_path['storage_prefix']}`"
)
asset_path = _download_object_or_prefix(
manager, asset_path=asset_path, destination_dir=tmp_dir
driver,
object_name=parsed_path["object_name"],
destination_dir=tmp_dir,
)

manager.update(
destination_provider.update(
asset_path,
name=spec.name,
version=new_version,
dry_run=dry_run,
)
else:
print("Aborting.")
return new_version
print("Aborting.")


@assets_cli.command("list")
Expand Down

0 comments on commit 6e6e321

Please sign in to comment.