Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cli compatible with remote assets paths (gs and s3) #165

Merged
merged 1 commit on Jun 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 68 additions & 46 deletions modelkit/assets/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from rich.table import Table
from rich.tree import Tree

from modelkit.assets.drivers.gcs import GCSStorageDriver
from modelkit.assets.drivers.s3 import S3StorageDriver
from modelkit.assets.errors import ObjectDoesNotExistError
from modelkit.assets.manager import AssetsManager
from modelkit.assets.remote import StorageProvider
Expand All @@ -24,41 +26,34 @@ def assets_cli():
pass


gcs_fn_re = r"gs://(?P<bucket_name>[\w\-]+)/(?P<object_name>.+)"
storage_url_re = (
r"(?P<storage_prefix>[\w]*)://(?P<bucket_name>[\w\-]+)/(?P<object_name>.+)"
)


def parse_gcs(path):
match = re.match(gcs_fn_re, path)
def parse_remote_url(path):
match = re.match(storage_url_re, path)
if not match:
raise ValueError(f"Could not parse GCS path `{path}`")
raise ValueError(f"Could not parse path `{path}`")
return match.groupdict()


def _download_object_or_prefix(manager, asset_path, destination_dir):
parsed_path = parse_gcs(asset_path)
def _download_object_or_prefix(driver, object_name, destination_dir):
asset_path = os.path.join(destination_dir, "myasset")
ldeflandre marked this conversation as resolved.
Show resolved Hide resolved
try:
manager.storage_provider.driver.download_object(
object_name=parsed_path["object_name"],
destination_path=asset_path,
)
driver.download_object(object_name=object_name, destination_path=asset_path)
except ObjectDoesNotExistError:
# maybe prefix containing objects
paths = [
path
for path in manager.storage_provider.driver.iterate_objects(
prefix=parsed_path["object_name"]
)
]
paths = [path for path in driver.iterate_objects(prefix=object_name)]
if not paths:
raise

os.mkdir(asset_path)
for path in paths:
object_name = path.split("/")[-1]
manager.storage_provider.driver.download_object(
object_name=parsed_path["object_name"] + "/" + object_name,
destination_path=os.path.join(asset_path, object_name),
sub_object_name = path.split("/")[-1]
driver.download_object(
object_name=object_name + "/" + sub_object_name,
destination_path=os.path.join(asset_path, sub_object_name),
)
return asset_path

Expand Down Expand Up @@ -101,12 +96,16 @@ def new(asset_path, asset_spec, storage_prefix, dry_run):

NB: [asset_name] can contain `/` too.
"""
new_(asset_path, asset_spec, storage_prefix, dry_run)


def new_(asset_path, asset_spec, storage_prefix, dry_run):
_check_asset_file_number(asset_path)
manager = StorageProvider(
prefix=storage_prefix,
)
print("Current assets manager:")
print(f" - storage provider = `{manager.driver}`")
destination_provider = StorageProvider(prefix=storage_prefix)

print("Destination assets provider:")
print(f" - storage driver = `{destination_provider.driver}`")
print(f" - driver bucket = `{destination_provider.driver.bucket}`")
print(f" - prefix = `{storage_prefix}`")

print(f"Current asset: `{asset_spec}`")
Expand All @@ -119,13 +118,24 @@ def new(asset_path, asset_spec, storage_prefix, dry_run):
response = click.prompt("[y/N]")
if response == "y":
with tempfile.TemporaryDirectory() as tmp_dir:
if asset_path.startswith("gs://"):
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
ldeflandre marked this conversation as resolved.
Show resolved Hide resolved
if parsed_path["storage_prefix"] == "gs":
driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
elif parsed_path["storage_prefix"] == "s3":
driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
else:
raise ValueError(
f"Unmanaged storage prefix `{parsed_path['storage_prefix']}`"
)
asset_path = _download_object_or_prefix(
manager, asset_path=asset_path, destination_dir=tmp_dir
driver,
object_name=parsed_path["object_name"],
destination_dir=tmp_dir,
)
manager.new(asset_path, spec.name, version, dry_run)
else:
print("Aborting.")
destination_provider.new(asset_path, spec.name, version, dry_run)
return version
print("Aborting.")


@assets_cli.command("update")
Expand Down Expand Up @@ -156,14 +166,20 @@ def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):

Specific documentation depends on the choosen model
"""
try:
update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run)
except ObjectDoesNotExistError:
print("Remote asset not found. Create it first using `new`")
sys.exit(1)


def update_(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
_check_asset_file_number(asset_path)
manager = StorageProvider(
prefix=storage_prefix,
)
destination_provider = StorageProvider(prefix=storage_prefix)

print("Current assets manager:")
print(f" - storage provider = `{manager.driver}`")
print("Destination assets provider:")
print(f" - storage driver = `{destination_provider.driver}`")
print(f" - driver bucket = `{destination_provider.driver.bucket}`")
print(f" - prefix = `{storage_prefix}`")

print(f"Current asset: `{asset_spec}`")
Expand All @@ -175,11 +191,7 @@ def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):
print(f" - name = `{spec.name}`")
print(f" - version = `{spec.version}`")

try:
version_list = manager.get_versions_info(spec.name)
except ObjectDoesNotExistError:
print("Remote asset not found. Create it first using `new`")
sys.exit(1)
version_list = destination_provider.get_versions_info(spec.name)

update_params = spec.versioning.get_update_cli_params(
version=spec.version,
Expand All @@ -196,21 +208,31 @@ def update(asset_path, asset_spec, storage_prefix, bump_major, dry_run):

response = click.prompt("[y/N]")
if response == "y":

with tempfile.TemporaryDirectory() as tmp_dir:
if asset_path.startswith("gs://"):
if not os.path.exists(asset_path):
parsed_path = parse_remote_url(asset_path)
if parsed_path["storage_prefix"] == "gs":
driver = GCSStorageDriver(bucket=parsed_path["bucket_name"])
elif parsed_path["storage_prefix"] == "s3":
driver = S3StorageDriver(bucket=parsed_path["bucket_name"])
else:
raise ValueError(
f"Unmanaged storage prefix `{parsed_path['storage_prefix']}`"
)
asset_path = _download_object_or_prefix(
manager, asset_path=asset_path, destination_dir=tmp_dir
driver,
object_name=parsed_path["object_name"],
destination_dir=tmp_dir,
)

manager.update(
destination_provider.update(
asset_path,
name=spec.name,
version=new_version,
dry_run=dry_run,
)
else:
print("Aborting.")
return new_version
print("Aborting.")


@assets_cli.command("list")
Expand Down
31 changes: 12 additions & 19 deletions tests/assets/test_assetsmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,46 +101,39 @@ def test_az_assetsmanager(az_assetsmanager):
@skip_unless("ENABLE_GCS_TEST", "True")
def test_download_object_or_prefix_cli(gcs_assetsmanager):
original_asset_path = os.path.join(test_path, "testdata", "some_data.json")
gcs_asset_dir = (
f"gs://{gcs_assetsmanager.storage_provider.driver.bucket}/"
f"{gcs_assetsmanager.storage_provider.prefix}"
"/category-test/some-data.ext"
)
gcs_asset_path = gcs_asset_dir + "/1.0"

gcs_assetsmanager.storage_provider.push(
original_asset_path, "category-test/some-data.ext", "1.0"
)
provider = gcs_assetsmanager.storage_provider

object_dir = f"{provider.prefix}/category-test/some-data.ext"
object_name = object_dir + "/1.0"

provider.push(original_asset_path, "category-test/some-data.ext", "1.0")

with tempfile.TemporaryDirectory() as tmp_dir:
asset_path = modelkit.assets.cli._download_object_or_prefix(
gcs_assetsmanager, asset_path=gcs_asset_path, destination_dir=tmp_dir
provider.driver, object_name=object_name, destination_dir=tmp_dir
)
assert filecmp.cmp(original_asset_path, asset_path)

with tempfile.TemporaryDirectory() as tmp_dir:
asset_dir = modelkit.assets.cli._download_object_or_prefix(
gcs_assetsmanager, asset_path=gcs_asset_dir, destination_dir=tmp_dir
provider.driver, object_name=object_dir, destination_dir=tmp_dir
)
assert filecmp.cmp(original_asset_path, os.path.join(asset_dir, "1.0"))

with tempfile.TemporaryDirectory() as tmp_dir:
with pytest.raises(modelkit.assets.errors.ObjectDoesNotExistError):
modelkit.assets.cli._download_object_or_prefix(
gcs_assetsmanager,
asset_path=gcs_asset_dir + "file-not-found",
provider.driver,
object_name=object_name + "file-not-found",
destination_dir=tmp_dir,
)

with pytest.raises(modelkit.assets.errors.ObjectDoesNotExistError):
# fail because dir contains subdir
modelkit.assets.cli._download_object_or_prefix(
gcs_assetsmanager,
asset_path=(
f"gs://{gcs_assetsmanager.storage_provider.driver.bucket}/"
f"{gcs_assetsmanager.storage_provider.prefix}/"
"category-test"
),
provider.driver,
object_name=f"{provider.prefix}/category-test",
destination_dir=tmp_dir,
)

Expand Down