use iter based evaluation instead of async thread in sync/async exper… #47

Open · wants to merge 1 commit into base: main
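Summary: this PR removes the background evaluation thread (the threading.Event / threading.Thread pair plus the weights lock in evaluator_state) from the sync and async experiments and instead evaluates inline every 10 parameter updates. It also renames the Fashion-MNIST helpers (ConvNet → FashionMNISTConvNet, get_data_loader → fashion_mnist_get_data_loader) and imports evaluate from models.model_common.

A minimal sketch of the evaluation pattern the PR adopts (a fragment, not the exact code; names such as ps, model, test_loader, and evaluate mirror the diff below and are assumed to be set up as in the experiments):

    for i in range(iterations * epochs):
        # Workers compute gradients against the latest weights reference.
        gradients = [compute_gradients.remote(current_weights, metric_exporter=metric_exporter)
                     for _ in range(num_workers)]
        current_weights = ps.apply_gradients.remote(gradients)
        if i % 10 == 0:
            # Evaluate inline on the driver: no eval thread, no shared-weights lock.
            model.set_weights(ray.get(current_weights))
            accuracy = evaluate(model, test_loader)
            metric_exporter.set_accuracy.remote(accuracy)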
34 changes: 12 additions & 22 deletions experiments/asynch.py
@@ -2,11 +2,9 @@
from parameter_servers.server_actor import ParameterServer
from workers.worker_task import compute_gradients
from metrics.metric_exporter import MetricExporter
from models.fashion_mnist import get_data_loader, evaluate
import threading
import copy
from evaluation.evaluator import async_eval
from evaluation.evaluator_state import evaluator_state
# from models.test_model import get_data_loader, evaluate
from models.fashion_mnist import fashion_mnist_get_data_loader
from models.model_common import evaluate

iterations = 200
num_workers = 2
@@ -15,22 +13,15 @@ def run_async(model, num_workers=1, epochs=5, server_kill_timeout=10, server_rec
metric_exporter = MetricExporter.remote("async control")
ps = ParameterServer.remote(1e-2)

test_loader = get_data_loader()[1]

# Start eval thread
model_copy = copy.deepcopy(model)
timer_runs = threading.Event()
timer_runs.set()
eval_thread = threading.Thread(target=async_eval, args=(timer_runs, model_copy, test_loader, metric_exporter, evaluate))
eval_thread.start()
test_loader = fashion_mnist_get_data_loader()[1]

print("Running Asynchronous Parameter Server Training.")
current_weights = ps.get_weights.remote()
gradients = []
for _ in range(num_workers):
gradients.append(compute_gradients.remote(current_weights))

for _ in range(iterations * num_workers * epochs):
for i in range(iterations * num_workers * epochs):
ready_gradient_list, _ = ray.wait(gradients)
ready_gradient_id = ready_gradient_list[0]
gradients.remove(ready_gradient_id)
@@ -39,12 +30,11 @@ def run_async(model, num_workers=1, epochs=5, server_kill_timeout=10, server_rec
current_weights = ps.apply_gradients.remote([ready_gradient_id])
gradients.append(compute_gradients.remote(current_weights, metric_exporter=metric_exporter))

evaluator_state.weights_lock.acquire()
evaluator_state.CURRENT_WEIGHTS = ray.get(current_weights)
evaluator_state.weights_lock.release()

timer_runs.clear()
eval_thread.join() # Ensure the eval thread has finished
if i % 10 == 0:
# Evaluate the current model after every 10 updates.
model.set_weights(ray.get(current_weights))
accuracy = evaluate(model, test_loader)
print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))
metric_exporter.set_accuracy.remote(accuracy)

# Clean up Ray resources and processes before the next example.
ray.shutdown()
print("Final accuracy is {:.1f}.".format(accuracy))
49 changes: 21 additions & 28 deletions experiments/sync.py
@@ -2,41 +2,34 @@
from parameter_servers.server_actor import ParameterServer
from workers.worker_task import compute_gradients
from metrics.metric_exporter import MetricExporter
from models.fashion_mnist import get_data_loader, evaluate
import copy
import threading
from evaluation.evaluator import async_eval
from evaluation.evaluator_state import evaluator_state
# from models.test_model import get_data_loader, evaluate
from models.fashion_mnist import fashion_mnist_get_data_loader
from models.model_common import evaluate

iterations = 200
num_workers = 2

def run_sync(model, num_workers=1, epochs=5, server_kill_timeout=10, server_recovery_timeout=5):
metric_exporter = MetricExporter.remote("sync control")
test_loader = get_data_loader()[1]
ps = ParameterServer.remote(1e-2)
metric_exporter = MetricExporter.remote("sync control")
ps = ParameterServer.remote(1e-2)

# Start eval thread
model_copy = copy.deepcopy(model)
timer_runs = threading.Event()
timer_runs.set()
eval_thread = threading.Thread(target=async_eval, args=(timer_runs, model_copy, test_loader, metric_exporter, evaluate))
eval_thread.start()
test_loader = fashion_mnist_get_data_loader()[1]

print("Running synchronous parameter server training.")
weights_ref = ps.get_weights.remote()
for _ in range(iterations * epochs):
gradients = [compute_gradients.remote(weights_ref, metric_exporter=metric_exporter) for _ in range(num_workers)]
# Calculate update after all gradients are available.
weights_ref = ps.apply_gradients.remote(gradients)
print("Running synchronous parameter server training.")
current_weights = ps.get_weights.remote()
for i in range(iterations * epochs):
gradients = [compute_gradients.remote(current_weights, metric_exporter=metric_exporter) for _ in range(num_workers)]
# Calculate update after all gradients are available.
current_weights = ps.apply_gradients.remote(gradients)

evaluator_state.weights_lock.acquire()
evaluator_state.CURRENT_WEIGHTS = ray.get(weights_ref)
model.set_weights(evaluator_state.CURRENT_WEIGHTS)
evaluator_state.weights_lock.release()
if i % 10 == 0:
# Evaluate the current model.
model.set_weights(ray.get(current_weights))
accuracy = evaluate(model, test_loader)
print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))
metric_exporter.set_accuracy.remote(accuracy)

timer_runs.clear()
eval_thread.join() # Ensure the eval thread has finished
print("Final accuracy is {:.1f}.".format(accuracy))

# Clean up Ray resources and processes before the next example.
ray.shutdown()
# Clean up Ray resources and processes before the next example.
ray.shutdown()
6 changes: 3 additions & 3 deletions main.py
@@ -17,7 +17,7 @@
from experiments.debug_disk_checkpointing import run_debug_disk_checkpointing
from experiments.debug_object_store_checkpointing import run_debug_object_store_checkpointing
# from models.test_model import ConvNet
from models.fashion_mnist import ConvNet
from models.fashion_mnist import FashionMNISTConvNet

num_workers = 2

@@ -35,7 +35,7 @@

MODEL_MAP = {
"IMAGENET": None,
"DEBUG": ConvNet()
"DEBUG": FashionMNISTConvNet()
}

# TODO: This doesn't seem to make the randomness consistent
@@ -56,7 +56,7 @@ def main():
print(ray.init(ignore_reinit_error=True, _metrics_export_port=8081))

# Ensure consistency across experiments when it comes to randomness
init_random_seeds()
# init_random_seeds()

# Use flags for argument parsing
parser = argparse.ArgumentParser()
35 changes: 7 additions & 28 deletions models/fashion_mnist.py
@@ -7,18 +7,18 @@
from torchvision.transforms import Normalize, ToTensor
from filelock import FileLock

def get_data_loader():
def fashion_mnist_get_data_loader():
batch_size = 32
# Transform to normalize the input images
transform = transforms.Compose([ToTensor(), Normalize((0.5,), (0.5,))])

with FileLock(os.path.expanduser("~/data.lock")):
# Download training data from open datasets
training_data = datasets.FashionMNIST(
root="~/data",
train=True,
download=True,
transform=transform,
root="~/data",
train=True,
download=True,
transform=transform,
)

# Download test data from open datasets
@@ -35,30 +35,9 @@ def get_data_loader():

return train_dataloader, test_dataloader


def evaluate(model, test_loader):
"""Evaluates the accuracy of the model on a validation dataset."""
model.eval()
correct = 0
total = 0
test_loss, num_correct, num_total = 0, 0, 0
loss_fn = nn.CrossEntropyLoss()
with torch.no_grad():
for batch, (X, y) in enumerate(test_loader):
pred = model(X)
loss = loss_fn(pred, y)

test_loss += loss.item()
num_total += y.shape[0]
num_correct += (pred.argmax(1) == y).sum().item()

test_loss /= len(test_loader)
accuracy = num_correct / num_total
return accuracy

class ConvNet(nn.Module):
class FashionMNISTConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
super(FashionMNISTConvNet, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
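The evaluate helper deleted above presumably now lives in models/model_common.py, which every file in this PR imports via from models.model_common import evaluate. That module is not part of the diff, so the following is a reconstruction from the deleted body (with the unused correct/total counters and the unused batch index dropped):

    # models/model_common.py (assumed location; reconstructed from the lines deleted above)
    import torch
    import torch.nn as nn

    def evaluate(model, test_loader):
        """Evaluates the accuracy of the model on a validation dataset."""
        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        loss_fn = nn.CrossEntropyLoss()
        with torch.no_grad():
            for X, y in test_loader:
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()
        test_loss /= len(test_loader)  # mean loss per batch (computed but not returned)
        return num_correct / num_total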
10 changes: 5 additions & 5 deletions parameter_servers/server_actor.py
@@ -6,8 +6,9 @@
import os
from workers.worker_task import compute_gradients
# from models.test_model import ConvNet, get_data_loader, evaluate
# from models.fashion_mnist import ConvNet, get_data_loader, evaluate
from models.cifar10 import ResNet, get_data_loader, evaluate
from models.fashion_mnist import FashionMNISTConvNet, fashion_mnist_get_data_loader
from models.model_common import evaluate
# from models.cifar10 import ResNet, get_data_loader, evaluate
from zookeeper.zoo import KazooChainNode

# TODO (Change to training epochs)
@@ -20,7 +21,7 @@
class ParameterServer(object):
def __init__(self, lr, node_id=None, metric_exporter=None):
#self.model = ConvNet()
self.model = ResNet()
self.model = FashionMNISTConvNet()
self.start_iteration = 0
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
self.start_iteration = 0
@@ -71,8 +72,7 @@ def retrieve_weights_from_zookeeper(self, event):
self.chain_node.zk.exists("/exp3/"+str(node_id), watch=self.chain_node.handle_delete_or_change_event)

def run_synch_chain_node_experiment(self, num_workers):
# test_loader = get_data_loader()[0]
test_loader = get_data_loader()[1]
test_loader = fashion_mnist_get_data_loader()[1]

print("Running synchronous parameter server training.")
current_weights = self.get_weights()
8 changes: 4 additions & 4 deletions parameter_servers/server_actor_disk_ckpoint.py
@@ -7,8 +7,8 @@
from workers.worker_task import compute_gradients
# from models.test_model import ConvNet
# from models.test_model import ConvNet, get_data_loader, evaluate
from models.fashion_mnist import ConvNet
from models.fashion_mnist import get_data_loader, evaluate
from models.fashion_mnist import FashionMNISTConvNet, fashion_mnist_get_data_loader
from models.model_common import evaluate

iterations = 200
num_workers = 2
@@ -67,7 +67,7 @@ def run_training(self, synchronous=True):


def run_synch_training(self):
test_loader = get_data_loader()[1]
test_loader = fashion_mnist_get_data_loader()[1]

print("Running synchronous parameter server training.")
current_weights = self.get_weights()
@@ -85,7 +85,7 @@ def run_synch_training(self):
print("Final accuracy is {:.1f}.".format(accuracy))

def run_asynch_training(self):
test_loader = get_data_loader()[1]
test_loader = fashion_mnist_get_data_loader()[1]

print("Running Asynchronous Parameter Server Training.")
current_weights = self.get_weights()
5 changes: 3 additions & 2 deletions parameter_servers/server_task.py
@@ -3,7 +3,8 @@
import torch
import ray
# from models.test_model import get_data_loader, evaluate
from models.fashion_mnist import get_data_loader, evaluate
from models.fashion_mnist import fashion_mnist_get_data_loader
from models.model_common import evaluate
from kazoo.client import KazooClient
from kazoo.exceptions import NodeExistsError, NoNodeError

@@ -33,7 +34,7 @@ def _load_weights_for_optimizer(self, zk, model, lr):
def run_parameter_server_task(self, model, num_workers, lr, weight_saver, metric_exporter):
print("Parameter Server is starting")
then = time.time()
test_loader = get_data_loader()[1]
test_loader = fashion_mnist_get_data_loader()[1]

zk = self._start_zk()
model, optimizer = self._load_weights_for_optimizer(zk, model, lr)
17 changes: 7 additions & 10 deletions workers/worker_task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import ray.cloudpickle
import torch.nn.functional as F
import torch.nn as nn
from models.fashion_mnist import ConvNet, get_data_loader
from models.fashion_mnist import FashionMNISTConvNet, fashion_mnist_get_data_loader
# from models.test_model import ConvNet, get_data_loader
# from models.cifar10 import ResNet, get_data_loader
from kazoo.client import KazooClient
from kazoo.exceptions import NodeExistsError
import ray
@@ -12,22 +11,20 @@

@ray.remote
def compute_gradients(weights, metric_exporter=None):
# model = ConvNet()
model = ResNet()
data_iterator = iter(get_data_loader()[0])
model = FashionMNISTConvNet()
data_iterator = iter(fashion_mnist_get_data_loader()[0])

model.train()
model.set_weights(weights)
try:
data, target = next(data_iterator)
except StopIteration: # When the epoch ends, start a new epoch.
data_iterator = iter(get_data_loader()[0])
data_iterator = iter(fashion_mnist_get_data_loader()[0])
data, target = next(data_iterator)
model.zero_grad()
output = model(data)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(output, target)
print("training loss is", loss.item())
if metric_exporter is not None:
metric_exporter.set_loss.remote(loss.item())
loss.backward()
@@ -37,7 +34,7 @@ def compute_gradients(weights, metric_exporter=None):
def compute_gradients_relaxed_consistency(model, worker_index, epochs=5, metric_exporter=None):
curr_epoch = 0
print(f"Worker {worker_index} is starting at Epoch {curr_epoch}")
data_iterator = iter(get_data_loader()[0])
data_iterator = iter(fashion_mnist_get_data_loader()[0])
zk = KazooClient(hosts='127.0.0.1:2181')
zk.start()

@@ -87,7 +84,7 @@ def compute_grads(data, target):
try:
data, target = next(data_iterator)
except StopIteration: # When the epoch ends, start a new epoch.
data_iterator = iter(get_data_loader()[0])
data_iterator = iter(fashion_mnist_get_data_loader()[0])
data, target = next(data_iterator)
model.zero_grad()
output = model(data)
@@ -105,7 +102,7 @@ def has_next_data():
return True, d, t
except StopIteration:
if curr_epoch < epochs:
data_iterator = iter(get_data_loader()[0])
data_iterator = iter(fashion_mnist_get_data_loader()[0])
d, t = next(data_iterator)
curr_epoch += 1
print(f"Starting Epoch {curr_epoch}")