Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

continue advanced planner #854

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions skyplane/api/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
import json
import time
import os
import threading
from collections import defaultdict, Counter
from datetime import datetime
from functools import partial
from datetime import datetime

import nacl.secret
import nacl.utils
import urllib3
from typing import TYPE_CHECKING, Dict, List, Optional

from skyplane import compute
from skyplane.api.tracker import TransferProgressTracker, TransferHook
from skyplane.api.tracker import TransferProgressTracker
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.planner.planner import MulticastDirectPlanner
from skyplane.planner.planner import (
MulticastDirectPlanner,
UnicastDirectPlanner,
UnicastILPPlanner,
MulticastILPPlanner,
MulticastMDSTPlanner,
)
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
from skyplane.utils.definitions import gateway_docker_image, tmp_log_dir
from skyplane.utils.fn import PathLike, do_parallel
from skyplane.utils.definitions import tmp_log_dir

from skyplane.api.dataplane import Dataplane

Expand All @@ -39,6 +36,7 @@ def __init__(
transfer_config: TransferConfig,
# cloud_regions: dict,
max_instances: Optional[int] = 1,
num_connections: Optional[int] = 32,
planning_algorithm: Optional[str] = "direct",
debug: Optional[bool] = False,
):
Expand Down Expand Up @@ -67,8 +65,18 @@ def __init__(

# planner
self.planning_algorithm = planning_algorithm

if self.planning_algorithm == "direct":
self.planner = MulticastDirectPlanner(self.max_instances, 64)
# TODO: should find some ways to merge direct / Ndirect
self.planner = UnicastDirectPlanner(self.max_instances, num_connections)
elif self.planning_algorithm == "Ndirect":
self.planner = MulticastDirectPlanner(self.max_instances, num_connections)
elif self.planning_algorithm == "MDST":
self.planner = MulticastMDSTPlanner(self.max_instances, num_connections)
elif self.planning_algorithm == "ILP":
self.planning_algorithm = MulticastILPPlanner(self.max_instances, num_connections)
elif self.planning_algorithm == "UnicastILP":
self.planning_algorithm = UnicastILPPlanner(self.max_instances, num_connections)
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

Expand Down Expand Up @@ -112,7 +120,7 @@ def start(self, debug=False, progress=False):
# copy gateway logs
if debug:
dp.copy_gateway_logs()
except Exception as e:
except Exception:
dp.copy_gateway_logs()
dp.deprovision(spinner=True)
return dp
Expand Down
20 changes: 12 additions & 8 deletions skyplane/api/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def log_exception(
exception: Exception,
args: Optional[Dict] = None,
src_region_tag: Optional[str] = None,
dest_region_tag: Optional[str] = None, # TODO: fix this for mult-dest
dest_region_tags: Optional[List[str]] = None, # TODO: fix this for mult-dest
session_start_timestamp_ms: Optional[int] = None,
):
if cls.enabled():
Expand All @@ -146,7 +146,7 @@ def log_exception(
error_dict=error_dict,
arguments_dict=args,
src_region_tag=src_region_tag,
dest_region_tag=dest_region_tag,
dest_region_tags=dest_region_tags,
session_start_timestamp_ms=session_start_timestamp_ms,
)
destination = client.write_usage_data(stats)
Expand All @@ -158,7 +158,7 @@ def log_transfer(
transfer_stats: Optional[Dict],
args: Optional[Dict] = None,
src_region_tag: Optional[str] = None,
dest_region_tags: Optional[str] = None,
dest_region_tags: Optional[List[str]] = None,
session_start_timestamp_ms: Optional[int] = None,
):
if cls.enabled():
Expand Down Expand Up @@ -250,7 +250,7 @@ def make_stat(
arguments_dict: Optional[Dict] = None,
transfer_stats: Optional[Dict] = None,
src_region_tag: Optional[str] = None,
dest_region_tags: Optional[str] = None,
dest_region_tags: Optional[List[str]] = None,
session_start_timestamp_ms: Optional[int] = None,
):
if src_region_tag is None:
Expand All @@ -261,7 +261,9 @@ def make_stat(
dest_provider, dest_region = None, None
else:
# TODO: have usage stats view for multiple destinations
dest_provider, dest_region = dest_region_tags[0].split(":")
dest_region_tag = [dest_region_tag.split(":") for dest_region_tag in dest_region_tags]
dest_provider, dest_region = list(zip(*dest_region_tag))
dest_provider, dest_region = ','.join(dest_provider), ','.join(dest_region)

return UsageStatsToReport(
skyplane_version=skyplane.__version__,
Expand All @@ -284,18 +286,20 @@ def make_error(
error_dict: Dict,
arguments_dict: Optional[Dict] = None,
src_region_tag: Optional[str] = None,
dest_region_tag: Optional[str] = None,
dest_region_tags: Optional[List[str]] = None,
session_start_timestamp_ms: Optional[int] = None,
):
if src_region_tag is None:
src_provider, src_region = None, None
else:
src_provider, src_region = src_region_tag.split(":")

if dest_region_tag is None:
if dest_region_tags is None:
dest_provider, dest_region = None, None
else:
dest_provider, dest_region = dest_region_tag.split(":")
dest_region_tag = [dest_region_tag.split(":") for dest_region_tag in dest_region_tags]
dest_provider, dest_region = list(zip(*dest_region_tag))
dest_provider, dest_region = ','.join(dest_provider), ','.join(dest_region)

return UsageStatsToReport(
skyplane_version=skyplane.__version__,
Expand Down
39 changes: 20 additions & 19 deletions skyplane/cli/cli_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
from skyplane.api.usage import UsageClient
from skyplane.config import SkyplaneConfig
from skyplane.config_paths import cloud_config, config_path
from skyplane.obj_store.object_store_interface import ObjectStoreInterface, StorageInterface
from skyplane.obj_store.file_system_interface import FileSystemInterface
from skyplane.obj_store.object_store_interface import StorageInterface
from skyplane.cli.impl.progress_bar import ProgressBarTransferHook
from skyplane.utils import logger
from skyplane.utils.definitions import GB, format_bytes
from skyplane.utils.path import parse_path
from skyplane.utils.path import parse_path, parse_multi_paths


@dataclass
Expand All @@ -50,8 +49,8 @@ def to_dict(self) -> Dict[str, Optional[Any]]:


class SkyplaneCLI:
def __init__(self, src_region_tag: str, dst_region_tag: str, args: Dict[str, Any], skyplane_config: Optional[SkyplaneConfig] = None):
self.src_region_tag, self.dst_region_tag = src_region_tag, dst_region_tag
def __init__(self, src_region_tag: str, dst_region_tags: List[str], args: Dict[str, Any], skyplane_config: Optional[SkyplaneConfig] = None):
self.src_region_tag, self.dst_region_tags = src_region_tag, dst_region_tags
self.args = args
self.aws_config, self.azure_config, self.gcp_config, self.ibmcloud_config = self.to_api_config(skyplane_config or cloud_config)

Expand Down Expand Up @@ -103,7 +102,8 @@ def to_api_config(self, config: SkyplaneConfig):
return aws_config, azure_config, gcp_config, ibmcloud_config

def make_transfer_config(self, config: SkyplaneConfig) -> TransferConfig:
intraregion = self.src_region_tag == self.dst_region_tag
# intraregion = self.src_region_tag == self.dst_region_tag
intraregion = self.src_region_tag
return TransferConfig(
autoterminate_minutes=config.get_flag("autoshutdown_minutes"),
requester_pays=config.get_flag("requester_pays"),
Expand Down Expand Up @@ -131,7 +131,7 @@ def check_config(self) -> bool:
return True
except skyplane.exceptions.BadConfigException as e:
logger.exception(e)
UsageClient.log_exception("cli_check_config", e, self.args, self.src_region_tag, self.dst_region_tag)
UsageClient.log_exception("cli_check_config", e, self.args, self.src_region_tag, self.dst_region_tags)
return False

def transfer_cp_onprem(self, src: str, dst: str, recursive: bool) -> bool:
Expand All @@ -144,7 +144,7 @@ def transfer_cp_onprem(self, src: str, dst: str, recursive: bool) -> bool:
if rc == 0:
print_stats_completed(request_time, None)
transfer_stats = TransferStats(monitor_status="completed", total_runtime_s=request_time, throughput_gbits=0)
UsageClient.log_transfer(transfer_stats.to_dict(), self.args, self.src_region_tag, self.dst_region_tag)
UsageClient.log_transfer(transfer_stats.to_dict(), self.args, self.src_region_tag, self.dst_region_tags)
return True
else:
typer.secho("Transfer not supported", fg="red")
Expand All @@ -160,7 +160,7 @@ def transfer_sync_onprem(self, src: str, dst: str) -> bool:
if rc == 0:
print_stats_completed(request_time, None)
transfer_stats = TransferStats(monitor_status="completed", total_runtime_s=request_time, throughput_gbits=0)
UsageClient.log_transfer(transfer_stats.to_dict(), self.args, self.src_region_tag, self.dst_region_tag)
UsageClient.log_transfer(transfer_stats.to_dict(), self.args, self.src_region_tag, self.dst_region_tags)
return True
else:
typer.secho("Transfer not supported", fg="red")
Expand Down Expand Up @@ -299,7 +299,7 @@ def force_deprovision(dp: skyplane.Dataplane):

def run_transfer(
src: str,
dst: str,
dst: List[str],
recursive: bool,
debug: bool,
multipart: bool,
Expand All @@ -315,9 +315,10 @@ def run_transfer(
print_header()

provider_src, bucket_src, path_src = parse_path(src)
provider_dst, bucket_dst, path_dst = parse_path(dst)
provider_dsts, bucket_dsts, path_dsts = parse_multi_paths(dst)
src_region_tag = StorageInterface.create(f"{provider_src}:infer", bucket_src).region_tag()
dst_region_tag = StorageInterface.create(f"{provider_dst}:infer", bucket_dst).region_tag()
dst_region_tags = StorageInterface.create_region_tags(provider_dsts, bucket_dsts)

args = {
"cmd": cmd,
"recursive": True,
Expand All @@ -330,7 +331,7 @@ def run_transfer(
}

# create CLI object
cli = SkyplaneCLI(src_region_tag=src_region_tag, dst_region_tag=dst_region_tag, args=args)
cli = SkyplaneCLI(src_region_tag=src_region_tag, dst_region_tags=dst_region_tags, args=args)
if not cli.check_config():
typer.secho(
f"Skyplane configuration file is not valid. Please reset your config by running `rm {config_path}` and then rerunning `skyplane init` to fix.",
Expand All @@ -346,15 +347,15 @@ def run_transfer(
pipeline.queue_sync(src, dst)

# confirm transfer
if not cli.confirm_transfer(pipeline, src_region_tag, [dst_region_tag], 5, ask_to_confirm_transfer=not confirm):
if not cli.confirm_transfer(pipeline, src_region_tag, dst_region_tags, 5, ask_to_confirm_transfer=not confirm):
return 1

# local->local transfers not supported (yet)
if provider_src == "local" and provider_dst == "local":
if provider_src == "local" and dst_region_tags[0] == "local":
raise NotImplementedError("Local->local transfers not supported (yet)")

# fall back options: local->cloud, cloud->local, small cloud->cloud transfers
if provider_src == "local" or provider_dst == "local":
if provider_src == "local" or provider_dsts[0] == "local":
if cli.args["cmd"] == "cp":
return 0 if cli.transfer_cp_onprem(src, dst, recursive) else 1
else:
Expand Down Expand Up @@ -388,20 +389,20 @@ def run_transfer(
logger.fs.exception(e)
console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
console.print(e)
UsageClient.log_exception("cli_cp", e, args, cli.src_region_tag, cli.dst_region_tag)
UsageClient.log_exception("cli_cp", e, args, cli.src_region_tag, cli.dst_region_tags)
console.print("[bold red]Deprovisioning was interrupted! VMs may still be running which will incur charges.[/bold red]")
console.print("[bold red]Please manually deprovision the VMs by running `skyplane deprovision`.[/bold red]")
return 1
except skyplane.exceptions.SkyplaneException as e:
console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
console.print(e.pretty_print_str())
UsageClient.log_exception("cli_query_objstore", e, args, cli.src_region_tag, cli.dst_region_tag)
UsageClient.log_exception("cli_query_objstore", e, args, cli.src_region_tag, cli.dst_region_tags)
force_deprovision(dp)
except Exception as e:
logger.fs.exception(e)
console.print(f"[bright_black]{traceback.format_exc()}[/bright_black]")
console.print(e)
UsageClient.log_exception("cli_query_objstore", e, args, cli.src_region_tag, cli.dst_region_tag)
UsageClient.log_exception("cli_query_objstore", e, args, cli.src_region_tag, cli.dst_region_tags)
force_deprovision(dp)


Expand Down
11 changes: 11 additions & 0 deletions skyplane/obj_store/storage_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,14 @@ def create(region_tag: str, bucket: str):
return POSIXInterface(bucket)
else:
raise ValueError(f"Invalid region_tag {region_tag} - could not create interface")

@staticmethod
def create_region_tags(provider_dsts, bucket_dsts):
if isinstance(provider_dsts, str):
provider_dsts = [provider_dsts]

dst_region_tags = []
for provider_dst, bucket_dst in zip(provider_dsts, bucket_dsts):
tag = StorageInterface.create(f"{provider_dst}:infer", bucket_dst).region_tag()
dst_region_tags.append(tag)
return dst_region_tags
Loading