Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unicastILP, compress operator, encrypt operator #912

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
data/
dist/
.pytype/
.history/
docs/_build

# Byte-compiled / optimized / DLL files
Expand Down
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ cvxpy = { version = ">=1.1.0", optional = true }
graphviz = { version = ">=0.15", optional = true }
matplotlib = { version = ">=3.0.0", optional = true }
numpy = { version = ">=1.19.0", optional = true }
networkx = { version = ">=2.5", optional = true }

# gateway dependencies
flask = { version = "^2.1.2", optional = true }
Expand All @@ -70,7 +71,7 @@ gcp = ["google-api-python-client", "google-auth", "google-cloud-compute", "googl
ibm = ["ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
all = ["boto3", "azure-identity", "azure-mgmt-authorization", "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", "azure-mgmt-storage", "azure-mgmt-subscription", "azure-storage-blob", "google-api-python-client", "google-auth", "google-cloud-compute", "google-cloud-storage", "ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
gateway = ["flask", "lz4", "pynacl", "pyopenssl", "werkzeug"]
solver = ["cvxpy", "graphviz", "matplotlib", "numpy"]
solver = ["networkx", "cvxpy", "graphviz", "matplotlib", "numpy"]

[tool.poetry.dev-dependencies]
pytest = ">=6.0.0"
Expand Down
1 change: 1 addition & 0 deletions scripts/requirements-gateway.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ numpy
pandas
pyarrow
typer
networkx
11 changes: 7 additions & 4 deletions skyplane/api/dataplane.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import nacl.secret
import nacl.utils
import typer
import urllib3
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional
Expand Down Expand Up @@ -89,7 +90,6 @@ def _start_gateway(
gateway_server: compute.Server,
gateway_log_dir: Optional[PathLike],
authorize_ssh_pub_key: Optional[str] = None,
e2ee_key_bytes: Optional[str] = None,
):
# map outgoing ports
setup_args = {}
Expand Down Expand Up @@ -119,9 +119,7 @@ def _start_gateway(
gateway_docker_image=gateway_docker_image,
gateway_program_path=str(gateway_program_filename),
gateway_info_path=f"{gateway_log_dir}/gateway_info.json",
e2ee_key_bytes=e2ee_key_bytes, # TODO: remove
use_bbr=self.transfer_config.use_bbr, # TODO: remove
use_compression=self.transfer_config.use_compression,
use_socket_tls=self.transfer_config.use_socket_tls,
)

Expand Down Expand Up @@ -202,6 +200,11 @@ def provision(
# todo: move server.py:start_gateway here
logger.fs.info(f"Using docker image {gateway_docker_image}")
e2ee_key_bytes = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)
# save E2EE keys
e2ee_key_file = "e2ee_key"
if not os.path.exists(f"/tmp/{e2ee_key_file}"):
with open(f"/tmp/{e2ee_key_file}", 'wb') as f:
f.write(e2ee_key_bytes)

# create gateway logging dir
gateway_program_dir = f"{self.log_dir}/programs"
Expand All @@ -218,7 +221,7 @@ def provision(
jobs = []
for node, server in gateway_bound_nodes.items():
jobs.append(
partial(self._start_gateway, gateway_docker_image, node, server, gateway_program_dir, authorize_ssh_pub_key, e2ee_key_bytes)
partial(self._start_gateway, gateway_docker_image, node, server, gateway_program_dir, authorize_ssh_pub_key)
)
logger.fs.debug(f"[Dataplane.provision] Starting gateways on {len(jobs)} servers")
try:
Expand Down
16 changes: 12 additions & 4 deletions skyplane/api/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
from skyplane.planner.planner import (
MulticastDirectPlanner,
DirectPlannerSourceOneSided,
DirectPlannerDestOneSided,
UnicastILPPlanner,
)
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
from skyplane.utils.definitions import tmp_log_dir
Expand Down Expand Up @@ -62,11 +67,13 @@ def __init__(
# planner
self.planning_algorithm = planning_algorithm
if self.planning_algorithm == "direct":
self.planner = MulticastDirectPlanner(self.max_instances, self.n_connections, self.transfer_config)
self.planner = MulticastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "src_one_sided":
self.planner = DirectPlannerSourceOneSided(self.max_instances, self.n_connections, self.transfer_config)
self.planner = DirectPlannerSourceOneSided(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "dst_one_sided":
self.planner = DirectPlannerDestOneSided(self.max_instances, self.n_connections, self.transfer_config)
self.planner = DirectPlannerDestOneSided(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "uni_ilp":
self.planning_algorithm = UnicastILPPlanner(self.transfer_config, self.max_instances, self.n_connections)
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

Expand Down Expand Up @@ -185,3 +192,4 @@ def estimate_total_cost(self):

# return size
return total_size * topo.cost_per_gb

12 changes: 8 additions & 4 deletions skyplane/api/transfer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,10 @@ def transfer_pair_generator(
dest_provider, dest_region = dst_iface.region_tag().split(":")
try:
dest_key = self.map_object_key_prefix(src_prefix, obj.key, dst_prefix, recursive=recursive)
assert (
dest_key[: len(dst_prefix)] == dst_prefix
), f"Destination key {dest_key} does not start with destination prefix {dst_prefix}"
dest_keys.append(dest_key[len(dst_prefix) :])
# TODO: why is it changed here?
# dest_keys.append(dest_key[len(dst_prefix) :])

dest_keys.append(dest_key)
except exceptions.MissingObjectException as e:
logger.fs.exception(e)
raise e from None
Expand Down Expand Up @@ -508,8 +508,12 @@ def dst_prefixes(self) -> List[str]:
if not hasattr(self, "_dst_prefix"):
if self.transfer_type == "unicast":
self._dst_prefix = [str(parse_path(self.dst_paths[0])[2])]
print("return dst_prefixes for unicast", self._dst_prefix)
else:
for path in self.dst_paths:
print("Parsing result for multicast", parse_path(path))
self._dst_prefix = [str(parse_path(path)[2]) for path in self.dst_paths]
print("return dst_prefixes for multicast", self._dst_prefix)
return self._dst_prefix

@property
Expand Down
11 changes: 3 additions & 8 deletions skyplane/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ class Chunk:
part_number: Optional[int] = None
upload_id: Optional[str] = None # TODO: for broadcast, this is not used

def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int, is_compressed: bool = False):
def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int):
return WireProtocolHeader(
chunk_id=self.chunk_id,
data_len=wire_length,
raw_data_len=raw_wire_length,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

Expand Down Expand Up @@ -99,7 +98,6 @@ class WireProtocolHeader:
chunk_id: str # 128bit UUID
data_len: int # long
raw_data_len: int # long (uncompressed, unencrypted)
is_compressed: bool # char
n_chunks_left_on_socket: int # long

@staticmethod
Expand All @@ -115,8 +113,8 @@ def protocol_version():

@staticmethod
def length_bytes():
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + is_compressed (1) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 1 + 8
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 8

@staticmethod
def from_bytes(data: bytes):
Expand All @@ -130,13 +128,11 @@ def from_bytes(data: bytes):
chunk_id = data[12:28].hex()
chunk_len = int.from_bytes(data[28:36], byteorder="big")
raw_chunk_len = int.from_bytes(data[36:44], byteorder="big")
is_compressed = bool(int.from_bytes(data[44:45], byteorder="big"))
n_chunks_left_on_socket = int.from_bytes(data[45:53], byteorder="big")
return WireProtocolHeader(
chunk_id=chunk_id,
data_len=chunk_len,
raw_data_len=raw_chunk_len,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

Expand All @@ -149,7 +145,6 @@ def to_bytes(self):
out_bytes += chunk_id_bytes
out_bytes += self.data_len.to_bytes(8, byteorder="big")
out_bytes += self.raw_data_len.to_bytes(8, byteorder="big")
out_bytes += self.is_compressed.to_bytes(1, byteorder="big")
out_bytes += self.n_chunks_left_on_socket.to_bytes(8, byteorder="big")
assert len(out_bytes) == WireProtocolHeader.length_bytes(), f"{len(out_bytes)} != {WireProtocolHeader.length_bytes()}"
return out_bytes
Expand Down
14 changes: 2 additions & 12 deletions skyplane/compute/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,6 @@ def start_gateway(
gateway_info_path: str,
log_viewer_port=8888,
use_bbr=False,
use_compression=False,
e2ee_key_bytes=None,
use_socket_tls=False,
):
def check_stderr(tup):
Expand Down Expand Up @@ -338,13 +336,6 @@ def check_stderr(tup):
if self.provider == "aws":
docker_envs["AWS_DEFAULT_REGION"] = self.region_tag.split(":")[1]

# copy E2EE keys
if e2ee_key_bytes is not None:
e2ee_key_file = "e2ee_key"
self.write_file(e2ee_key_bytes, f"/tmp/{e2ee_key_file}")
docker_envs["E2EE_KEY_FILE"] = f"/pkg/data/{e2ee_key_file}"
docker_run_flags += f" -v /tmp/{e2ee_key_file}:/pkg/data/{e2ee_key_file}"

# upload gateway programs and gateway info
gateway_program_file = os.path.basename(gateway_program_path).replace(":", "_")
gateway_info_file = os.path.basename(gateway_info_path).replace(":", "_")
Expand All @@ -359,8 +350,7 @@ def check_stderr(tup):
# update docker flags
docker_run_flags += " " + " ".join(f"--env {k}={v}" for k, v in docker_envs.items())

gateway_daemon_cmd += f" --region {self.region_tag} {'--use-compression' if use_compression else ''}"
gateway_daemon_cmd += f" {'--disable-e2ee' if e2ee_key_bytes is None else ''}"
gateway_daemon_cmd += f" --region {self.region_tag}"
gateway_daemon_cmd += f" {'--disable-tls' if not use_socket_tls else ''}"
escaped_gateway_daemon_cmd = gateway_daemon_cmd.replace('"', '\\"')
docker_launch_cmd = (
Expand All @@ -378,7 +368,7 @@ def check_stderr(tup):
logger.fs.debug(f"{self.uuid()} gateway_api_url = {self.gateway_api_url}")

# wait for gateways to start (check status API)
http_pool = urllib3.PoolManager()
http_pool = urllib3.PoolManager(retries=urllib3.Retry(total=10))

def is_api_ready():
try:
Expand Down
79 changes: 64 additions & 15 deletions skyplane/gateway/gateway_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@
GatewaySender,
GatewayRandomDataGen,
GatewayWriteLocal,
GatewayDeleteLocal,
GatewayObjStoreReadOperator,
GatewayObjStoreWriteOperator,
GatewayWaitReceiver,
GatewayCompressor,
GatewayDecompressor,
GatewayEncrypter,
GatewayDecrypter,
)
from skyplane.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.utils import logger
Expand All @@ -38,8 +43,6 @@ def __init__(
chunk_dir: PathLike,
max_incoming_ports=64,
use_tls=True,
use_e2ee=True, # TODO: read from operator field
use_compression=True, # TODO: read from operator field
):
# read gateway program
gateway_program_path = Path(os.environ["GATEWAY_PROGRAM_FILE"]).expanduser()
Expand Down Expand Up @@ -68,13 +71,6 @@ def __init__(

self.error_event = Event()
self.error_queue = Queue()
if use_e2ee:
e2ee_key_path = Path(os.environ["E2EE_KEY_FILE"]).expanduser()
with open(e2ee_key_path, "rb") as f:
self.e2ee_key_bytes = f.read()
print("Server side E2EE key loaded: ", self.e2ee_key_bytes)
else:
self.e2ee_key_bytes = None

# create gateway operators
self.terminal_operators = defaultdict(list) # track terminal operators per partition
Expand All @@ -90,8 +86,6 @@ def __init__(
error_queue=self.error_queue,
max_pending_chunks=max_incoming_ports,
use_tls=self.use_tls,
use_compression=use_compression,
e2ee_key_bytes=self.e2ee_key_bytes,
)

# API server
Expand Down Expand Up @@ -232,8 +226,6 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_tls=self.use_tls,
use_compression=op["compress"],
e2ee_key_bytes=self.e2ee_key_bytes,
n_processes=op["num_connections"],
)
total_p += op["num_connections"]
Expand Down Expand Up @@ -264,6 +256,65 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_
chunk_store=self.chunk_store,
)
total_p += 1
elif op["op_type"] == "delete_local":
operators[handle] = GatewayDeleteLocal(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_queue=self.error_queue,
error_event=self.error_event,
chunk_store=self.chunk_store,
)
total_p += 1
elif op["op_type"] == "compress":
operators[handle] = GatewayCompressor(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_compression=op["compress"],
)
total_p += 1
elif op["op_type"] == "decompress":
operators[handle] = GatewayDecompressor(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_compression=op["compress"],
)
total_p += 1
elif op["op_type"] == "encrypt":
operators[handle] = GatewayEncrypter(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
e2ee_key_bytes=op["e2ee_key_bytes"],
)
total_p += 1
elif op["op_type"] == "decrypt":
operators[handle] = GatewayDecrypter(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
e2ee_key_bytes=op["e2ee_key_bytes"],
)
total_p += 1
else:
raise ValueError(f"Unsupported op_type {op['op_type']}")
# recursively create for child operators
Expand Down Expand Up @@ -346,8 +397,6 @@ def exit_handler(signum, frame):
parser.add_argument("--region", type=str, required=True, help="Region tag (provider:region")
parser.add_argument("--chunk-dir", type=Path, default="/tmp/skyplane/chunks", help="Directory to store chunks")
parser.add_argument("--disable-tls", action="store_true")
parser.add_argument("--use-compression", action="store_true") # TODO: remove
parser.add_argument("--disable-e2ee", action="store_true") # TODO: remove
args = parser.parse_args()

os.makedirs(args.chunk_dir)
Expand Down
Loading