diff --git a/benchmark/configs/auxo/auxo.yml b/benchmark/configs/auxo/auxo.yml new file mode 100644 index 00000000..aca16936 --- /dev/null +++ b/benchmark/configs/auxo/auxo.yml @@ -0,0 +1,52 @@ +# Configuration file of fed_hetero experiment + +# ========== Cluster configuration ========== +# ip address of the parameter server (need 1 GPU process) +ps_ip: localhost +ps_port: 12345 + +# ip address of each worker:# of available gpus process on each gpu in this node +# Note that if we collocate ps and worker on same GPU, then we need to decrease this number of available processes on that GPU by 1 +# E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 +worker_ips: + - localhost:[7,7,0,0] # worker_ip: [(# processes on gpu) for gpu in available_gpus] eg. 10.0.0.2:[4,4,4,4] This node has 4 gpus, each gpu has 4 processes. + +exp_path: $FEDSCALE_HOME/examples/auxo + +# Entry function of executor and aggregator under $exp_path +executor_entry: executor.py + +aggregator_entry: aggregator.py + +auth: + ssh_user: "" + ssh_private_key: ~/.ssh/id_rsa + +# cmd to run before we can indeed run FAR (in order) +setup_commands: + - source $HOME/anaconda3/bin/activate fedscale + +# ========== Additional job configuration ========== +# Default parameters are specified in config_parser.py, wherein more description of the parameter can be found + +job_conf: + - job_name: auxo_femnist # Generate logs under this folder: log_path/job_name/time_stamp + - log_path: $FEDSCALE_HOME/benchmark # Path of log files + - num_participants: 200 # Number of participants per round, we use K=100 in our paper, large K will be much slower + - data_set: femnist # Dataset: openImg, google_speech, stackoverflow + - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/ # Path of the dataset + - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided + - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace + - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace + - model: resnet18 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs +# - model_zoo: fedscale-torch-zoo + - eval_interval: 20 # How many rounds to run a testing on the testing set + - rounds: 1000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds + - filter_less: 0 # Remove clients w/ less than 21 samples + - num_loaders: 2 + - local_steps: 10 + - learning_rate: 0.05 + - batch_size: 20 + - test_bsz: 20 + - use_cuda: True + - save_checkpoint: False diff --git a/examples/auxo/Dockerfile b/examples/auxo/Dockerfile new file mode 100644 index 00000000..8e5bb3e3 --- /dev/null +++ b/examples/auxo/Dockerfile @@ -0,0 +1,29 @@ +# Use an official CUDA image as a parent image +FROM nvidia/cuda:11.0-base-ubuntu20.04 + +# Set the working directory inside the container +WORKDIR /app + +# Install necessary system packages +RUN apt-get update && apt-get install -y python3.7 python3-pip + +# Create a virtual environment and activate it +RUN python3.7 -m pip install virtualenv +RUN python3.7 -m virtualenv venv +RUN /bin/bash -c "source venv/bin/activate" + +# Copy the requirements file into the container +COPY requirements.txt . + +# Install the Python dependencies +RUN pip install --upgrade pip && pip install -r requirements.txt + +# Copy the project files into the container (assuming your project is in the current directory) +COPY . . + +# Install your project using pip +RUN pip install -e . + +# Command to run when the container starts +CMD ["bash"] + diff --git a/examples/auxo/README.md b/examples/auxo/README.md new file mode 100644 index 00000000..4ac86fdf --- /dev/null +++ b/examples/auxo/README.md @@ -0,0 +1,72 @@ + + +
+ + + + + + +
+ + \ No newline at end of file diff --git a/examples/auxo/aggregator.py b/examples/auxo/aggregator.py new file mode 100644 index 00000000..829bf837 --- /dev/null +++ b/examples/auxo/aggregator.py @@ -0,0 +1,600 @@ +# -*- coding: utf-8 -*- +import logging + +from fedscale.cloud.aggregation.aggregator import * +from client_manager import HeterClientManager +from resource_manager import AuxoResourceManager +from utils.helper import * + +class AuxoAggregator(Aggregator): + def __init__(self, args): + super().__init__(args) + + self.sampled_participants = [[]] + self.sampled_executors = [[]] + self.round_stragglers = [[]] + self.stats_util_accumulator = [[]] + self.loss_accumulator = [[]] + self.client_training_results = [[]] + self.test_result_accumulator = [[]] + self.virtual_client_clock = [[]] + self.testing_history = [{'data_set': args.data_set, 'model': args.model, 'sample_mode': args.sample_mode, + 'gradient_policy': args.gradient_policy, 'task': args.task, + 'perf': collections.OrderedDict()}] + + self.model_in_update = [0] + # self.last_saved_round = [0] + self.tasks_round = [0] + self.global_virtual_clock = [0.] + self.round_duration = [0.] + self.model_update_size = [0.] + self.round = [0] + + self.stop_cluster = 0 + self.split_cluster = 1 + self.num_split = 2 + self.resource_manager = AuxoResourceManager(self.experiment_mode) + + def init_model(self): + """Initialize the model""" + if self.args.engine == commons.TENSORFLOW: + self.model_wrapper = [TensorflowModelAdapter(init_model())] + elif self.args.engine == commons.PYTORCH: + self.model_wrapper = [TorchModelAdapter( + init_model(), + optimizer=TorchServerOptimizer( + self.args.gradient_policy, self.args, self.device))] + else: + raise ValueError(f"{self.args.engine} is not a supported engine.") + self.model_weights = [self.model_wrapper[0].get_weights()] + + def init_client_manager(self, args): + """ Initialize client sampler + + Args: + args (dictionary): Variable arguments for fedscale runtime config. defaults to the setup in arg_parser.py + + Returns: + ClientManager: The client manager class + + Currently we implement two client managers: + + 1. Random client sampler - it selects participants randomly in each round + [Ref]: https://arxiv.org/abs/1902.01046 + + 2. Oort sampler + Oort prioritizes the use of those clients who have both data that offers the greatest utility + in improving model accuracy and the capability to run training quickly. + [Ref]: https://www.usenix.org/conference/osdi21/presentation/lai + + 3. Auxo: Client Heterogeneity Manager + [Ref]: https://arxiv.org/abs/2210.16656 + """ + + # sample_mode: random or oort + client_manager = HeterClientManager(args.sample_mode, args=args) + + return client_manager + + def event_monitor(self): + """Activate event handler according to the received new message + """ + logging.info("Start monitoring events ...") + + while True: + # Broadcast events to clients + if len(self.broadcast_events_queue) > 0: + current_event = self.broadcast_events_queue.popleft() + logging.info(f"Event {current_event} is broadcasted to clients") + event_type, cohort_id = decode_msg(current_event) + + if event_type in (commons.UPDATE_MODEL, commons.MODEL_TEST): + self.dispatch_client_events(current_event) + + elif event_type == commons.START_ROUND: + self.dispatch_client_events(generate_msg(commons.CLIENT_TRAIN, cohort_id)) + + elif event_type == 'split': + self.dispatch_client_events(current_event) + + elif event_type == commons.SHUT_DOWN: + self.dispatch_client_events(current_event) + break + + # Handle events queued on the aggregator + elif len(self.sever_events_queue) > 0: + + client_id, current_event, meta, data = self.sever_events_queue.popleft() + logging.info(f"Event {current_event} is received from client {client_id}") + event_type, cohort_id = decode_msg(current_event) + + if event_type == commons.UPLOAD_MODEL: + self.client_completion_handler( + self.deserialize_response(data), cohort_id) + logging.info(f"[Cohort {cohort_id}] Client {client_id} has completed the task. {len(self.stats_util_accumulator[cohort_id])} v.s. {self.tasks_round[cohort_id]}") + if len(self.stats_util_accumulator[cohort_id]) == self.tasks_round[cohort_id]: + self.round_completion_handler(cohort_id) + + elif event_type == commons.MODEL_TEST: + self.testing_completion_handler( + client_id, self.deserialize_response(data), cohort_id) + else: + logging.error(f"Event {current_event} is not defined") + + else: + # execute every 100 ms + time.sleep(0.1) + + def CLIENT_REGISTER(self, request, context): + """FL TorchClient register to the aggregator + + Args: + request (RegisterRequest): Registeration request info from executor. + + Returns: + ServerResponse: Server response to registeration request + + """ + + # NOTE: client_id = executor_id in deployment, + # while multiple client_id uses the same executor_id (VMs) in simulations + executor_id = request.executor_id + executor_info = self.deserialize_response(request.executor_info) + if executor_id not in self.individual_client_events: + # logging.info(f"Detect new client: {executor_id}, executor info: {executor_info}") + self.individual_client_events[executor_id] = collections.deque() + else: + logging.info(f"Previous client: {executor_id} resumes connecting") + + # We can customize whether to admit the clients here + self.executor_info_handler(executor_id, executor_info) + dummy_data = self.serialize_response(generate_msg(commons.DUMMY_RESPONSE, 0)) + + return job_api_pb2.ServerResponse(event=generate_msg(commons.DUMMY_EVENT, 0), + meta=dummy_data, data=dummy_data) + + def get_test_config(self, client_id, cohort_id=0): + """FL model testing on clients, developers can further define personalized client config here. + + Args: + client_id (int): The client id. + + Returns: + dictionary: The testing config for new task. + + """ + num_client = self.client_manager.schedule_plan() + client_list = self.select_participants(num_client, overcommitment = 1, cohort_id = cohort_id, test=True) + return {'client_id': client_list} + + + def CLIENT_PING(self, request, context): + """Handle client ping requests + + Args: + request (PingRequest): Ping request info from executor. + + Returns: + ServerResponse: Server response to ping request + + """ + # NOTE: client_id = executor_id in deployment, + # while multiple client_id may use the same executor_id (VMs) in simulations + executor_id, client_id = request.executor_id, request.client_id + response_data = response_msg = generate_msg(commons.DUMMY_RESPONSE, 0) + + if len(self.individual_client_events[executor_id]) == 0: + # send dummy response + current_event = generate_msg(commons.DUMMY_EVENT, 0) + response_data = response_msg = current_event + else: + current_event = self.individual_client_events[executor_id].popleft() + event_type, cohort_id = decode_msg(current_event) + if event_type == commons.CLIENT_TRAIN: + response_msg, response_data = self.create_client_task( + executor_id, cohort_id) + if response_msg is None: + current_event = generate_msg(commons.DUMMY_EVENT, 0) + if self.experiment_mode != commons.SIMULATION_MODE: + self.individual_client_events[executor_id].append( + commons.CLIENT_TRAIN) + elif event_type == commons.MODEL_TEST: + # TODO: remove fedscale test and add individual client testing + response_msg = self.get_test_config(client_id, cohort_id) + elif event_type == commons.UPDATE_MODEL: + response_data = self.model_wrapper[cohort_id].get_weights() + elif event_type == commons.SHUT_DOWN: + response_msg = self.get_shutdown_config(executor_id) + + response_msg, response_data = self.serialize_response( + response_msg), self.serialize_response(response_data) + # NOTE: in simulation mode, response data is pickle for faster (de)serialization + response = job_api_pb2.ServerResponse(event=current_event, + meta=response_msg, data=response_data) + if decode_msg(current_event)[0] != commons.DUMMY_EVENT: + logging.info(f"Issue EVENT ({current_event}) to EXECUTOR ({executor_id})") + + return response + + def CLIENT_EXECUTE_COMPLETION(self, request, context): + """FL clients complete the execution task. + + Args: + request (CompleteRequest): Complete request info from executor. + + Returns: + ServerResponse: Server response to job completion request + + """ + executor_id, client_id, event = request.executor_id, request.client_id, request.event + execution_status, execution_msg = request.status, request.msg + meta_result, data_result = request.meta_result, request.data_result + event_type, cohort_id = decode_msg(event) + if event_type == commons.CLIENT_TRAIN: + # Training results may be uploaded in CLIENT_EXECUTE_RESULT request later, + # so we need to specify whether to ask client to do so (in case of straggler/timeout in real FL). + if execution_status is False: + logging.error(f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}") + + if self.resource_manager.has_next_task(executor_id, cohort_id): + # NOTE: we do not pop the train immediately in simulation mode, + # since the executor may run multiple clients + if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: + self.individual_client_events[executor_id].append(generate_msg(commons.CLIENT_TRAIN, cohort_id)) + + elif event_type in (commons.MODEL_TEST, commons.UPLOAD_MODEL): + self.add_event_handler( + executor_id, event, meta_result, data_result) + else: + logging.error(f"Received undefined event {event} from client {client_id}") + + return self.CLIENT_PING(request, context) + + + def create_client_task(self, executor_id, cohort_id): + """Issue a new client training task to specific executor + + Args: + executorId (int): Executor Id. + cohort_id (int): Cohort Id. + + Returns: + tuple: Training config for new task. (dictionary, PyTorch or TensorFlow module) + + """ + next_client_id = self.resource_manager.get_next_task(executor_id, cohort_id) + train_config = None + # NOTE: model = None then the executor will load the global model broadcasted in UPDATE_MODEL + if next_client_id is not None: + config = self.get_client_conf(next_client_id) + train_config = {'client_id': next_client_id, 'task_config': config, 'cohort_id': cohort_id} + + return train_config, self.model_wrapper[cohort_id].get_weights() + + + def client_completion_handler(self, results, cohort_id): + """We may need to keep all updates from clients, + if so, we need to append results to the cache + + Args: + results (dictionary): client's training result + + """ + # Format: + # -results = {'client_id':client_id, 'update_weight': model_param, 'moving_loss': round_train_loss, + # 'trained_size': count, 'wall_duration': time_cost, 'success': is_success 'utility': utility} + + if self.args.gradient_policy in ['q-fedavg']: + self.client_training_results[cohort_id].append(results) + # Feed metrics to client sampler + self.stats_util_accumulator[cohort_id].append(results['utility']) + self.loss_accumulator[cohort_id].append(results['moving_loss']) + + self.client_manager.register_feedback(results['client_id'], results['utility'], + auxi=math.sqrt( + results['moving_loss']), + time_stamp=self.round[cohort_id], + duration=self.virtual_client_clock[cohort_id][results['client_id']]['computation'] + + self.virtual_client_clock[cohort_id][results['client_id']]['communication'], + w_new = results['update_weight'], + w_old = self.model_wrapper[cohort_id].get_weights(), + cohort_id=cohort_id + ) + + # ================== Aggregate weights ====================== + self.update_lock.acquire() + + self.model_in_update[cohort_id] += 1 + self.update_weight_aggregation(results, cohort_id) + + self.update_lock.release() + + def tictak_client_tasks(self, sampled_clients, num_clients_to_collect, cohort_id): + """Record sampled client execution information in last round. In the SIMULATION_MODE, + further filter the sampled_client and pick the top num_clients_to_collect clients. + + Args: + sampled_clients (list of int): Sampled clients from client manager + num_clients_to_collect (int): The number of clients actually needed for next round. + + Returns: + Tuple: (the List of clients to run, the List of stragglers in the round, a Dict of the virtual clock of each + client, the duration of the aggregation round, and the durations of each client's task). + + """ + if self.experiment_mode == commons.SIMULATION_MODE: + # NOTE: We try to remove dummy events as much as possible in simulations, + # by removing the stragglers/offline clients in overcommitment""" + sampledClientsReal = [] + completionTimes = [] + completed_client_clock = {} + # 1. remove dummy clients that are not available to the end of training + for client_to_run in sampled_clients: + client_cfg = self.client_conf.get(client_to_run, self.args) + + exe_cost = self.client_manager.get_completion_time(client_to_run, + batch_size=client_cfg.batch_size, + local_steps=client_cfg.local_steps, + upload_size=self.model_update_size, + download_size=self.model_update_size) + + roundDuration = exe_cost['computation'] + \ + exe_cost['communication'] + # if the client is not active by the time of collection, we consider it is lost in this round + if self.client_manager.isClientActive(client_to_run, roundDuration + self.global_virtual_clock[cohort_id]): + sampledClientsReal.append(client_to_run) + completionTimes.append(roundDuration) + completed_client_clock[client_to_run] = exe_cost + + num_clients_to_collect = min( + num_clients_to_collect, len(completionTimes)) + # 2. get the top-k completions to remove stragglers + workers_sorted_by_completion_time = sorted( + range(len(completionTimes)), key=lambda k: completionTimes[k]) + top_k_index = workers_sorted_by_completion_time[:num_clients_to_collect] + clients_to_run = [sampledClientsReal[k] for k in top_k_index] + + stragglers = [sampledClientsReal[k] + for k in workers_sorted_by_completion_time[num_clients_to_collect:]] + round_duration = completionTimes[top_k_index[-1]] + completionTimes.sort() + + return (clients_to_run, stragglers, + completed_client_clock, round_duration, + completionTimes[:num_clients_to_collect]) + else: + completed_client_clock = { + client: {'computation': 1, 'communication': 1} for client in sampled_clients} + completionTimes = [1 for c in sampled_clients] + return (sampled_clients, sampled_clients, completed_client_clock, + 1, completionTimes) + + def update_default_task_config(self, cohort_id): + """Update the default task configuration after each round + """ + # TODO: fix the lr update + if self.round[cohort_id] % self.args.decay_round == 0: + self.args.learning_rate = max( + self.args.learning_rate * self.args.decay_factor, self.args.min_learning_rate) + + def select_participants(self, select_num_participants, overcommitment=1.3, cohort_id=0, test=False): + """Select clients for next round. + + Args: + select_num_participants (int): Number of clients to select. + overcommitment (float): Overcommit ration for next round. + + Returns: + list of int: The list of sampled clients id. + + """ + return sorted(self.client_manager.select_participants( + int(select_num_participants * overcommitment), + cur_time=self.global_virtual_clock[cohort_id], + cohort_id=cohort_id, + test=test) + ) + + + + def _init_split_config(self, cohort_id): + def increment_config( in_list ): + in_list.append([]) + in_list[cohort_id ] = [] + + self.model_wrapper.append(copy.deepcopy(self.model_wrapper[cohort_id])) + self.global_virtual_clock.append(copy.deepcopy(self.global_virtual_clock[cohort_id])) + self.model_in_update.append(0) + increment_config(self.round_stragglers) + increment_config(self.virtual_client_clock ) + increment_config(self.round_duration) + increment_config(self.test_result_accumulator ) + increment_config(self.stats_util_accumulator) + increment_config(self.client_training_results) + increment_config(self.tasks_round) + increment_config(self.loss_accumulator) + self.model_weights.append(copy.deepcopy(self.model_weights[cohort_id])) + self.testing_history.append(copy.deepcopy(self.testing_history[cohort_id])) + self.round.append(copy.deepcopy(self.round[cohort_id])) + increment_config(self.sampled_participants) + + def _split_participant_list(self, cohort_id, num_split = 2): + + for s in range(num_split-1): + self._init_split_config(cohort_id) + cohort_id_list = [cohort_id, self.split_cluster - 1] if num_split == 2 else [*range(num_split)] + + for cid in range(num_split): + + # num_client_per_round = self.args.num_participants * self.client_manager.get_cohort_size(cohort_id_list[cid]) // self.total_clients + num_client_per_round = self.client_manager.schedule_plan(self.round[cohort_id_list[cid]] , cid ) + num_client_per_round = max(num_client_per_round,1 ) + + self.sampled_participants[cohort_id_list[cid]] = self.select_participants(select_num_participants = num_client_per_round, \ + overcommitment=self.args.overcommitment, cohort_id = cohort_id_list[cid]) + clients_to_run, round_stragglers, virtual_client_clock, round_duration, _ = \ + self.tictak_client_tasks(self.sampled_participants[cohort_id_list[cid]], num_client_per_round, cohort_id_list[cid]) + self.round_stragglers[cohort_id_list[cid]] = round_stragglers + self.resource_manager.register_tasks(clients_to_run, cohort_id_list[cid]) + self.tasks_round[cohort_id_list[cid]] = len(clients_to_run) + self.virtual_client_clock[cohort_id_list[cid]] = virtual_client_clock + self.round_duration[cohort_id_list[cid]] = round_duration + self.model_in_update[cohort_id_list[cid]] = 0 + self.test_result_accumulator[cohort_id_list[cid]] = [] + self.stats_util_accumulator[cohort_id_list[cid]] = [] + self.client_training_results[cohort_id_list[cid]] = [] + self.loss_accumulator[cohort_id_list[cid]] = [] + + return cohort_id_list + + + def _is_first_result_in_round(self, cohort_id): + return self.model_in_update[cohort_id] == 1 + + def _is_last_result_in_round(self, cohort_id): + return self.model_in_update[cohort_id] == self.tasks_round[cohort_id] + + + def update_weight_aggregation(self, results, cohort_id = 0): + """Updates the aggregation with the new results. + + :param results: the results collected from the client. + """ + update_weights = results['update_weight'] + if type(update_weights) is dict: + update_weights = [x for x in update_weights.values()] + if self._is_first_result_in_round(cohort_id): + self.model_weights[cohort_id] = update_weights + else: + self.model_weights[cohort_id] = [weight + update_weights[i] for i, weight in enumerate(self.model_weights[cohort_id])] + if self._is_last_result_in_round(cohort_id): + self.model_weights[cohort_id] = [np.divide(weight, self.tasks_round[cohort_id]) for weight in self.model_weights[cohort_id]] + self.model_wrapper[cohort_id].set_weights(copy.deepcopy(self.model_weights[cohort_id])) + + + def round_completion_handler(self, cohort_id = 0): + """Triggered upon the round completion, it registers the last round execution info, + broadcast new tasks for executors and select clients for next round. + """ + self.global_virtual_clock[cohort_id] += self.round_duration[cohort_id] + self.round[cohort_id] += 1 + last_round_avg_util = sum(self.stats_util_accumulator[cohort_id]) / max(1, len(self.stats_util_accumulator[cohort_id])) + # assign avg reward to explored, but not ran workers + for client_id in self.round_stragglers[cohort_id]: + self.client_manager.register_feedback(client_id, last_round_avg_util, + time_stamp=self.round[cohort_id], + duration=self.virtual_client_clock[cohort_id][client_id]['computation'] + + self.virtual_client_clock[cohort_id][client_id]['communication'], + success=False) + + avg_loss = sum(self.loss_accumulator[cohort_id]) / max(1, len(self.loss_accumulator[cohort_id])) + logging.info(f"[Cohort {cohort_id}] Wall clock: {round(self.global_virtual_clock[cohort_id])} s, round: {self.round[cohort_id]}, Planned participants: " + + f"{len(self.sampled_participants[cohort_id])}, Succeed participants: {len(self.stats_util_accumulator[cohort_id])}, Training loss: {avg_loss}") + + at_split = False + if self.round[cohort_id] > 1: # TODO: replace with clustering start round + at_split = self.client_manager.cohort_clustering(self.round[cohort_id], cohort_id ) + + # TODO: add split and non-split logic: update the stats and select participants + if at_split: + self.split_cluster = len(self.client_manager.feasibleClients) + self.resource_manager.split(cohort_id) + new_cohort_id_list = self._split_participant_list(cohort_id, self.num_split) + + else: + num_client_per_round = self.client_manager.schedule_plan(self.round[cohort_id], cohort_id) + # update select participants + self.sampled_participants[cohort_id] = self.select_participants( + select_num_participants=num_client_per_round, overcommitment=self.args.overcommitment, cohort_id=cohort_id) + (clients_to_run, round_stragglers, virtual_client_clock, round_duration, + flatten_client_duration) = self.tictak_client_tasks( + self.sampled_participants[cohort_id], num_client_per_round, cohort_id) + + logging.info(f"Selected participants to run: {clients_to_run}") + + # Issue requests to the resource manager; Tasks ordered by the completion time + self.resource_manager.register_tasks(clients_to_run, cohort_id) + self.tasks_round[cohort_id] = len(clients_to_run) + + # Update executors and participants + if self.experiment_mode == commons.SIMULATION_MODE: + self.sampled_executors = list( + self.individual_client_events.keys()) + else: + self.sampled_executors = [str(c_id) + for c_id in self.sampled_participants] + self.round_stragglers[cohort_id] = round_stragglers + self.virtual_client_clock[cohort_id] = virtual_client_clock + self.flatten_client_duration = np.array(flatten_client_duration) + self.round_duration[cohort_id] = round_duration + self.model_in_update[cohort_id] = 0 + self.test_result_accumulator[cohort_id] = [] + self.stats_util_accumulator[cohort_id] = [] + self.client_training_results[cohort_id] = [] + self.loss_accumulator[cohort_id] = [] + self.update_default_task_config(cohort_id) + + if self.round[cohort_id] >= self.args.rounds: + self.broadcast_aggregator_events(generate_msg(commons.SHUT_DOWN)) + elif at_split: + self.broadcast_aggregator_events(generate_msg('split', cohort_id=cohort_id)) + for cid in new_cohort_id_list: + self.broadcast_aggregator_events(generate_msg(commons.UPDATE_MODEL, cohort_id=cid)) + self.broadcast_aggregator_events(generate_msg(commons.START_ROUND, cohort_id=cid)) + elif self.round[cohort_id] % self.args.eval_interval == 0 or self.round[cohort_id] == 1: + self.broadcast_aggregator_events(generate_msg(commons.UPDATE_MODEL, cohort_id=cohort_id)) + self.broadcast_aggregator_events(generate_msg(commons.MODEL_TEST, cohort_id=cohort_id)) + else: + self.broadcast_aggregator_events(generate_msg(commons.UPDATE_MODEL, cohort_id=cohort_id)) + self.broadcast_aggregator_events(generate_msg(commons.START_ROUND, cohort_id=cohort_id)) + + def aggregate_test_result(self, cohort_id): + accumulator = self.test_result_accumulator[cohort_id][0] + for i in range(1, len(self.test_result_accumulator[cohort_id])): + if self.args.task == "detection": + for key in accumulator: + if key == "boxes": + for j in range(596): + accumulator[key][j] = accumulator[key][j] + \ + self.test_result_accumulator[cohort_id][i][key][j] + else: + accumulator[key] += self.test_result_accumulator[cohort_id][i][key] + else: + for key in accumulator: + accumulator[key] += self.test_result_accumulator[cohort_id][i][key] + self.testing_history[cohort_id]['perf'][self.round[cohort_id]] = {'round': self.round[cohort_id], 'clock': self.global_virtual_clock[cohort_id]} + for metric_name in accumulator.keys(): + if metric_name == 'test_loss': + self.testing_history[cohort_id]['perf'][self.round[cohort_id]]['loss'] = accumulator['test_loss'] \ + if self.args.task == "detection" else accumulator['test_loss'] / accumulator['test_len'] + elif metric_name not in ['test_len']: + self.testing_history[cohort_id]['perf'][self.round[cohort_id]][metric_name] \ + = accumulator[metric_name] / accumulator['test_len'] + + round_perf = self.testing_history[cohort_id]['perf'][self.round[cohort_id]] + logging.info( + "FL Testing in round: {}, virtual_clock: {}, results: {}" + .format(self.round[cohort_id], self.global_virtual_clock[cohort_id], round_perf)) + + def testing_completion_handler(self, client_id, results, cohort_id): + """Each executor will handle a subset of testing dataset + + Args: + client_id (int): The client id. + results (dictionary): The client test results. + + """ + + results = results['results'] + # List append is thread-safe + self.test_result_accumulator[cohort_id].append(results) + + # Have collected all testing results + if len(self.test_result_accumulator[cohort_id]) == len(self.executors): + self.aggregate_test_result(cohort_id) + self.broadcast_aggregator_events(generate_msg(commons.START_ROUND, cohort_id=cohort_id)) + + +if __name__ == "__main__": + aggregator = AuxoAggregator(parser.args) + aggregator.run() diff --git a/examples/auxo/client_manager.py b/examples/auxo/client_manager.py new file mode 100644 index 00000000..86409efd --- /dev/null +++ b/examples/auxo/client_manager.py @@ -0,0 +1,268 @@ + +from fedscale.cloud.client_manager import * +from collections import defaultdict +from client_metadata import AuxoClientMetadata +import logging +import numpy as np +from sklearn import preprocessing +from collections import defaultdict +import copy +from clustering import QTable +from sklearn.manifold import TSNE +from config import auxo_config + + +class HeterClientManager(ClientManager): + def __init__(self, mode, args, sample_seed=233): + ''' + Manage cohort membership; + Manane client selection; + Manage cohort training resources usage + ''' + super().__init__(mode, args, sample_seed) + + self.round_clt = defaultdict(list) + self.grad_div_dict = defaultdict(list) + # self.split_round = defaultdict(bool) + self.stop_cluster = False + self.gradient_list = defaultdict(list) + self.feasibleClients = [[]] + + self.total_res = args.num_participants + self.round_acc = defaultdict(dict) + self.num_cluster = 1 + self.latest_acc_list = {0:1} + + logging.info(f'Client manager initialized with auxo config: {auxo_config}') + + + def register_client(self, host_id: int, client_id: int, size: int, speed: Dict[str, float], + duration: float = 1) -> None: + """Register client information to the client manager. + + Args: + host_id (int): executor Id. + client_id (int): client Id. + size (int): number of samples on this client. + speed (Dict[str, float]): device speed (e.g., compuutation and communication). + duration (float): execution latency. + + """ + uniqueId = self.getUniqueId(host_id, client_id) + user_trace = None if self.user_trace is None else self.user_trace[self.user_trace_keys[int( + client_id) % len(self.user_trace)]] + + self.client_metadata[uniqueId] = AuxoClientMetadata(host_id, client_id, speed, user_trace) + # remove clients + # if size >= self.filter_less and size <= self.filter_more: + self.feasibleClients[0].append(client_id) + self.feasible_samples += size + + if self.mode == "oort": + feedbacks = {'reward': min(size, self.args.local_steps * self.args.batch_size), + 'duration': duration, + } + self.ucb_sampler.register_client(client_id, feedbacks=feedbacks) + # else: + # del self.client_metadata[uniqueId] + + + def getUniqueId(self, host_id, client_id): + return int(client_id) + # return (str(host_id) + '_' + str(client_id)) + + def getFeasibleClients(self, cur_time: float, cohort_id: int = 0): + if self.user_trace is None: + clients_online = self.feasibleClients[cohort_id] + else: + clients_online = [client_id for client_id in self.feasibleClients[cohort_id] if self.client_metadata[self.getUniqueId( + 0, client_id)].is_active(cur_time)] + + logging.info(f"Wall clock time: {cur_time}, {len(clients_online)} clients online, " + + f"{len(self.feasibleClients[cohort_id]) - len(clients_online)} clients offline") + + return clients_online + + + def select_participants(self, num_of_clients: int, cur_time: float = 0, cohort_id: int = 0, test: bool = False) -> List[int]: + """Select participating clients for current execution task. + + Args: + num_of_clients (int): number of participants to select. + cur_time (float): current wall clock time. + + Returns: + List[int]: indices of selected clients. + + """ + + clients_online = self.getFeasibleClients(cur_time, cohort_id) + + if len(clients_online) <= num_of_clients: + return clients_online + + self.gradient_list[cohort_id] = [] + pickled_clients = None + clients_online_set = set(clients_online) + + if test: + pivot_client = self.reward_qtable.return_pivot_client(cohort_id) + pivot_client = list(set(pivot_client) & set(self.feasibleClients[cohort_id])) + self.rng.shuffle(pivot_client) + extra_clt = list(set(clients_online) - set(pivot_client)) + self.rng.shuffle(extra_clt) + + pickled_clients = pivot_client + extra_clt + client_len = min(num_of_clients, len(pickled_clients) - 1) + pickled_clients = pickled_clients[:client_len] + + elif self.mode == "oort": + pickled_clients = self.ucb_sampler.select_participant( + num_of_clients, feasible_clients=clients_online_set) + else: + self.rng.shuffle(clients_online) + client_len = min(num_of_clients, len(clients_online) - 1) + pickled_clients = clients_online[:client_len] + return pickled_clients + + def getDataInfo(self): + train_ratio = self.args.num_participants / len(self.feasibleClients[0]) + avg_train_times = self.args.rounds * self.args.num_participants / len( + self.feasibleClients[0]) + known_clt = len(self.feasibleClients[0]) // 20 + + split_round = auxo_config['split_round'] + exploredecay = auxo_config['exploredecay'] + explorerate = auxo_config['explorerate'] + self.reduction = auxo_config['reduction'] + metric = auxo_config['metric'] + + self.reward_qtable = QTable(1 + len(self.feasibleClients[0]), known_clt = known_clt, + split_round = split_round,\ + elbow_constant=0.97, train_ratio=train_ratio, avg_train_times=avg_train_times, \ + epsilon=explorerate, epsilon_decay = exploredecay,\ + metric=metric) + + return {'total_feasible_clients': len(self.feasibleClients[0]), 'total_num_samples': self.feasible_samples} + + def get_cohort_size(self, cohort_id): + return len(self.feasibleClients[cohort_id]) + + def register_feedback(self, client_id: int, reward: float, auxi: float = 1.0, time_stamp: float = 0, + duration: float = 1., success: bool = True, w_new = None, w_old = None, cohort_id = 0) -> None: + """Collect client execution feedbacks of last round. + + Args: + client_id (int): client Id. + reward (float): execution utilities (processed feedbacks). + auxi (float): unprocessed feedbacks. + time_stamp (float): current wall clock time. + duration (float): system execution duration. + success (bool): whether this client runs successfully. + + """ + # currently, we only use distance as reward + if self.mode == "oort": + feedbacks = { + 'reward': reward, + 'duration': duration, + 'status': True, + 'time_stamp': time_stamp + } + + self.ucb_sampler.update_client_util(client_id, feedbacks=feedbacks) + + if w_new is not None: + grad_norm = self._register_client_grad(client_id, w_new, w_old, cohort_id) + return + + def _register_client_grad(self, client_id, w_new, w_old, cohort_id): + if not self.stop_cluster: + grad_norm = self.client_metadata[client_id].register_gradient(w_new, w_old) + self.round_clt[cohort_id].append(client_id) + self.gradient_list[cohort_id].append(grad_norm) + return grad_norm + return 0 + + def cohort_clustering(self, round, cohort_id=0): + '''Clustering cohort results for current rounds; Update reward table + Args: + round: current round + cohort_id: current cohort + Returns: + whether to split the cohort + ''' + if round < auxo_config['start_round']: + return False + global_index = [int(gid) for gid in self.round_clt[cohort_id]] # [1,2,3,5,7] + logging.info(f'Clustering clients global index {global_index}') + if len(global_index) < 5: + return False + gradient_list = [self.client_metadata[clt].gradient for clt in self.round_clt[cohort_id]] + + # if self.reduction: + # norm_grad = TSNE(n_components=3, init='random').fit_transform(np.array(gradient_list)) + # elif self.distance == 'kl': + # norm_grad = centered_grad = np.array(gradient_list) + # else: + avg_grad = np.mean(gradient_list, axis=0) + centered_grad = gradient_list - avg_grad + norm_grad = preprocessing.normalize(centered_grad) + + # update sub reward + logging.info(f'Update intra-cluster relation ') + self.reward_qtable.knn_update_subR(cohort_id, norm_grad, global_index, False if round > 500 else True) + logging.info(f'Update inter-cluster relation ') + split = self.reward_qtable.update_mainR(cohort_id, centered_grad, global_index, False if round > 500 else True) + + if split: + self._split(len(self.feasibleClients)) + logging.info(f'SPLIT at round {round} ') + self.feasibleClients.append([]) + for i in range(len(self.feasibleClients)): + self.feasibleClients[i] = list(np.argwhere(self.reward_qtable.y_kmeans == i).reshape(-1)) + + # update feasible clients + if len(self.feasibleClients) > 1 and split == False: + for clt in self.round_clt[cohort_id]: # str + new_label = int(self.reward_qtable.y_kmeans[int(clt)]) + if int(clt) in self.feasibleClients[cohort_id] and new_label != cohort_id: + self.feasibleClients[cohort_id].remove(int(clt)) + self.feasibleClients[new_label].append(int(clt)) + elif int(clt) not in self.feasibleClients[cohort_id] and new_label == cohort_id: + for c in range(len(self.feasibleClients)): + if int(clt) in self.feasibleClients[c]: + self.feasibleClients[c].remove(int(clt)) + self.feasibleClients[cohort_id].append(int(clt)) + + self.round_clt[cohort_id] = [] + + self._print_cohort(round) + return split + + def _print_cohort(self, round): + size_ratio = [len(cluster) for cluster in self.feasibleClients] + logging.info(f'Round {round} FeasibleClients client number : {size_ratio}') + + def schedule_plan(self, round=0, cohort_id=0) -> int: + """ Schedule the training resources for each cohort + + Args: + round: + cohort_id: + + Returns: + int: number of training resources for each cohort + """ + # TODO: schedule based on accuracy + return self.total_res // self.num_cluster + + def _split(self, cohort_id): + self.num_cluster += 1 + if cohort_id not in self.latest_acc_list: + self.latest_acc_list[cohort_id] = 1 + + def update_eval(self, r, acc, clusterID): + self.round_acc[r][clusterID] = acc + if len(self.round_acc[r]) == self.num_cluster: + self.latest_acc_list = self.round_acc[r] \ No newline at end of file diff --git a/examples/auxo/client_metadata.py b/examples/auxo/client_metadata.py new file mode 100644 index 00000000..87c961e5 --- /dev/null +++ b/examples/auxo/client_metadata.py @@ -0,0 +1,17 @@ +import logging + +from fedscale.cloud.internal.client_metadata import * + +class AuxoClientMetadata(ClientMetadata): + def __init__(self, host_id, client_id, speed, traces=None): + super().__init__(host_id, client_id, speed, traces) + self.grad_ratio = 10 + + def register_gradient(self, W, W_old): + W_old = [dt.cpu().numpy() for dt in W_old] + W = [W[k] for k in W] + gradient = [pb - pa for pa, pb in zip(W, W_old)] + gradient = np.concatenate([v.flatten() for v in gradient]) + val_len = len(gradient) // self.grad_ratio # change grad size + self.gradient = np.float16(gradient[-val_len:]) + return self.gradient diff --git a/examples/auxo/clustering.py b/examples/auxo/clustering.py new file mode 100644 index 00000000..007124be --- /dev/null +++ b/examples/auxo/clustering.py @@ -0,0 +1,359 @@ +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import silhouette_score +from random import Random +from collections import defaultdict +from sklearn.cluster import KMeans +from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier +from sklearn.cluster import MiniBatchKMeans +import logging, time +from utils.klkmeans import KLKmeans + + +class QEntry(): + def __init__(self, init_R=1, r0=1, train_times_R=0): + self.R = init_R + self.sub_R = [init_R / 2, init_R / 2] + self.lr = 0.9 + self.r0 = r0 # lambda : np.random.normal(r0, r0/20) + self.train_times_R = train_times_R + + def update_R(self, new_reward): + old_R = self.R + # self.R = self.lr * new_reward + self.R * (1 - self.lr) + self.R = new_reward + self.R + delta_R = self.R - old_R + self.sub_R[0] += delta_R / 2 + self.sub_R[1] += delta_R / 2 + + def update_sub(self, sub_id): + self.sub_R[sub_id] += self.r0 # () + self.sub_R[1 - sub_id] -= self.r0 # () + + def reset_sub(self): + self.sub_R = [self.R / 2, self.R / 2] + self.train_times_R = 0 + + +class QTable(): + def __init__(self, num_client, train_ratio=0.1, base_reward=5, sample_seed=233, epsilon=0.01, epsilon_decay=0.99, \ + known_clt=50, elbow_constant=0.45, avg_train_times=4, split_round=None, merge=False, metric='cosine'): + # list dict + self.Qtable = [{0: QEntry()} for row in range(num_client)] # init with one model + self.num_client = num_client + self.y_kmeans = np.zeros(num_client) + self.base_reward = base_reward + self.known_clt = known_clt + self.num_model = 1 + self.init_round = {0: False} + self.epsilon = epsilon + self.epsilon_decay = epsilon_decay + self.rng = Random() + self.rng.seed(sample_seed) + self.sub_num = {} + self.split_counter = defaultdict(int) + self.epoch = defaultdict(int) + self.min_round = 1 + self.elbow_constant = elbow_constant + self.train_ratio = train_ratio + self.min_clt = 100 / train_ratio # minimal # clt per round --> 50 + self.avg_train_times = avg_train_times + self.pivot_clients = defaultdict(list) + self.pivot_clients[0] = [*range(self.num_client)] + self.min_cluster_size = num_client // 10 + self.split_round = split_round + self.merge_action = merge + self.metric = metric + if metric == 'kl': + self._initialize_kl() + + def _initialize_kl(self): + """Initialize KL divergence metric if applicable.""" + def KL(a, b): + epsilon = 0.00001 + a += epsilon + b += epsilon + return np.sum(np.where(a != 0, a * np.log(a / b), 0)) + self.kl = KL + + def update_R(self, cid, mid, new_reward): + self.Qtable[cid][mid].update_R(new_reward) + + if new_reward < 0: + self.Qtable[cid][mid].reset_sub() + + def update_R_batch(self, cid_list, mid, reward_list, remain_round): + for cid, reward in zip(cid_list, reward_list): + self.update_R(cid, mid, reward) + self.Qtable[cid][mid].train_times_R += 1 + + def update_subR(self, cid, mid, sub_id): + self.Qtable[cid][mid].update_sub(sub_id) + self.Qtable[cid][mid].train_times_R += 1 + + def update_subR_batch(self, cid_list, mid, sub_id): + for cid in cid_list: + self.update_subR(cid, mid, sub_id) + + def get_subid(self, cid, mid): + # ( cid ) prefer which subcluster in mid + if self.Qtable[cid][mid].sub_R[0] == self.Qtable[cid][mid].sub_R[1] or self.Qtable[cid][mid].R < 0: + return 2 + else: + return int(self.Qtable[cid][mid].sub_R[0] < self.Qtable[cid][mid].sub_R[1]) + + def subcluster_policy(self, mid, clt_list): + sub_num = [0, 0, 0] + sub_label_list = [] + for cid in clt_list: + sub_id = self.get_subid(cid, mid) + sub_num[sub_id] += 1 + sub_label_list.append(sub_id) + return sub_num, sub_label_list + + def grow_table(self, mid): + self.split_counter[mid] = 0 + new_mid = self.num_model + self.init_round[new_mid] = False + self.init_round[mid] = False + self.epoch[new_mid] = self.epoch[mid] + for cid in range(self.num_client): + sub_a = self.Qtable[cid][mid].sub_R[0] + sub_b = self.Qtable[cid][mid].sub_R[1] + a_R = 1 if sub_a > sub_b else 0 + b_R = 1 - a_R if sub_a != sub_b else 0 + + reward = self.Qtable[cid][mid].R + times = self.Qtable[cid][mid].train_times_R + self.Qtable[cid][mid] = QEntry(init_R=reward + a_R, train_times_R=times * a_R) + self.Qtable[cid][new_mid] = QEntry(init_R=reward + b_R, train_times_R=times * b_R) + self.num_model += 1 + + def shrink_table(self, mid): + self.num_model -= 1 + for cid in range(self.num_client): + self.Qtable[cid].pop(mid, None) + tmp_dict = self.Qtable[cid] + for m in range(self.num_model): + self.Qtable[cid][m] = list(tmp_dict.values())[m] + + def model_policy(self, cid): + + train_times = [self.Qtable[cid][m].train_times_R for m in self.Qtable[cid]] + new_client = False if np.sum(train_times) > 0 else True + + self.epsilon *= self.epsilon_decay + if new_client or self.rng.random() < self.epsilon: + return self.rng.randint(0, self.num_model - 1) + else: + reward_list = self.get_model_reward(cid) + return np.argmax(reward_list) + + def model_policy_batch(self, client_list): + for cid in client_list: + self.y_kmeans[cid] = self.model_policy(cid) + + self.y_kmeans[0] = -1 + + def get_model_reward(self, cid): + return [self.Qtable[cid][m].R for m in self.Qtable[cid]] + + def dist_to_reward(self, dist_list, known_ratio): + + avg_dist = np.mean(dist_list) + std_dist = np.std(dist_list) + + R0 = self.base_reward + slope = self.base_reward / (avg_dist + std_dist) + return [(R0 - slope * d) * known_ratio for d in dist_list] + + def count_known_main(self, global_index, mid): + known_index_list = [i for i, cid in enumerate(global_index) \ + if self.Qtable[cid][mid].train_times_R > 0 and self.Qtable[cid][mid].R > 1] + # local clt index + return known_index_list + + def split(self, mid): + # whether to split: enough client subreward info + minimal size satisfy + elbow test + if self.split_round is not None and self.epoch[mid] in self.split_round: + return True + elif self.split_round is None: + # whether to split: enough client subreward info + minimal size satisfy + elbow test + mid_client = np.argwhere(self.y_kmeans == mid) # global idx + mid_client = [i[0] for i in mid_client] # [1,3,5,7,9] + if len(mid_client) < self.min_clt: # ensure around 50 participate every round + return False + + sub_num = [0, 0, 0] + for cid in mid_client: + sub_num[self.get_subid(cid, mid)] += 1 + logging.info(f'sub_num : {sub_num} ') + min_size = self.known_clt * 0.9 ** (self.num_model - 1) + ratio = sub_num[0] / max(sub_num[1], 1) + + within_ratio = True if (ratio > 0.5 and ratio < 2) or self.num_model == 1 else False + return sub_num[0] > min_size and sub_num[1] > min_size and within_ratio + else: + return False + + def knn_update_subR(self, mid, X_sub, global_index, keep_split=True): + # update cluster membership for each clusters + self.epoch[mid] += 1 + if self.split_round is not None: + if self.epoch[mid] > max(self.split_round) + 1: + keep_split = False + else: + keep_split = True + + if self.init_round[mid] == False: + # First round of clustering + if self.num_model > 1: + known_clt_list = self.count_known_main(global_index, mid) + if len(known_clt_list) <= 1: + return + X_sub = X_sub[known_clt_list] + global_index = known_clt_list + + if self.metric == 'kl': + clustering = KLKmeans(n_clusters=2, init_center = X_sub[:2] ) + clustering.fit(X_sub) + labels = clustering.labels_ + + else: + clustering = MiniBatchKMeans(n_clusters=2, + random_state=0, + batch_size=10).fit(X_sub) + labels = clustering.labels_ # return labels + + for clt in range(len(global_index)): + self.update_subR(global_index[clt], mid, int(labels[clt])) + self.init_round[mid] = True + + elif keep_split: + # Continuous clustering of subsequent rounds + sub_num, sub_label_list = self.subcluster_policy(mid, global_index) + + if sub_num[0] == sub_num[1] == 0: + return + + elif sub_num[0] == 0 or sub_num[1] == 0: + sub_size = len(global_index) // 4 + if self.metric == 'kl': + neigh = NearestNeighbors(n_neighbors=sub_size, metric=lambda a, b: self.kl(a, b)).fit(X_sub) + else: + neigh = NearestNeighbors(n_neighbors=sub_size).fit(X_sub) + center_solo = np.argwhere(np.array(sub_label_list) != 2)[0][0] + near_clt_id = neigh.kneighbors([X_sub[center_solo]])[1][0] + for clt in near_clt_id: + self.update_subR(global_index[clt], mid, sub_label_list[center_solo]) + return + + else: + labeled_id = np.argwhere(np.array(sub_label_list) != 2).reshape(-1) + knn_label = list(filter(lambda score: score != 2, sub_label_list)) + knn_data = X_sub[labeled_id] + if self.metric == 'kl': + neigh = KNeighborsClassifier(n_neighbors=1, metric=lambda a, b: self.kl(a, b)).fit(knn_data, + knn_label) + else: + neigh = KNeighborsClassifier(n_neighbors=1).fit(knn_data, knn_label) + # labels = neigh.predict(X_sub) + unseen_id = np.argwhere(np.array(sub_label_list) == 2).reshape(-1) + if len(unseen_id) < 1: + return + labels = neigh.predict(X_sub[unseen_id]) + pred_prob = neigh.predict_proba(X_sub[unseen_id]) + for i, unseen_clt in enumerate(unseen_id): + conf = pred_prob[i][0] / max(pred_prob[i][1], 0.1) + if conf > 1.5 or conf < 0.67: + self.update_subR(global_index[unseen_clt], mid, int(labels[i])) + return + + def update_mainR(self, mid, X_sub, global_index, remain_round=True): + # update cluster membership for each clusters + if self.split_round is not None: + if self.epoch[mid] > max(self.split_round) + 1: + remain_round = False + else: + remain_round = True + + if self.num_model > 1: + known_clt_list = self.count_known_main(global_index, mid) + known_ratio = len(known_clt_list) / len(global_index) + + if len(known_clt_list) < 2 and self.train_ratio < 1: + # X_subcen = np.mean(X_sub, axis=0) + # print("know too less") + return False + else: + X_known = X_sub[known_clt_list] + X_subcen = np.mean(X_known, axis=0) + + X_ = X_sub - X_subcen + square_dist = np.sum(X_ ** 2, axis=1) + dist_list = np.sqrt(square_dist) + reward_list = self.dist_to_reward(dist_list, known_ratio) + self.update_R_batch(global_index, mid, reward_list, remain_round) + + split = False + self.split_counter[mid] += 1 + if self.split(mid): + # if self.split_counter[mid] > self.min_round : + split = True + self.grow_table(mid) + + if self.num_model > 2 and self.split_counter[mid] > 5 and self.merge(mid): + self.shrink_table(mid) + print(f"<<< Merge cluster {mid}") + logging.info(f"<<< Merge cluster {mid}") + self.model_policy_batch([*range(self.num_client)]) # update kmeans + + return split + + def update_pivot_client(self, mid): + trained_clt = set() + c = set() + for cid in range(1, self.num_client): + if self.Qtable[cid][mid].train_times_R > 0 and self.Qtable[cid][mid].R > 1: + c.add(cid) + # tmp_dict = self.Qtable[cid] + # if np.argsort(tmp_dict.values())[-1] == mid: + if self.Qtable[cid][mid].R > 1: + trained_clt.add(cid) + print(len(c), len(trained_clt)) + self.pivot_clients[mid] = list(trained_clt) + # TODO: can have many overlap clients, instead choose clients with highest score + + def return_pivot_client(self, mid): + # Return the clients that belong to the cohort + self.update_pivot_client(mid) + return self.pivot_clients[mid] + + def merge(self, mid): + if self.merge_action == False: + return False + + mid_client = np.argwhere(self.y_kmeans == mid) + if len(mid_client) < self.min_cluster_size: + return True + return False + + def count_trained_clt(self): + trained_clt = [] + for mid in range(self.num_model): + self.update_pivot_client(mid) + trained_clt += self.pivot_clients[mid] + print("Trained clients :", len(set(trained_clt))) + return list(set(trained_clt)) + + def plot(self, X, epoch): + # Visualize the clustering result + trained_clt = self.count_trained_clt() + plt.scatter(X[trained_clt, 0], X[trained_clt, 1], c=self.y_kmeans[trained_clt], s=30, cmap='viridis') + plt.title(f"Epoch {epoch}: {len(np.unique(self.y_kmeans[trained_clt], axis=0))} clusters") + plt.savefig(f"epoch_{epoch}.png") + plt.show() + + if len(np.unique(self.y_kmeans[trained_clt], axis=0)) > 1: + silhouette_avg = silhouette_score(X[trained_clt], self.y_kmeans[trained_clt]) + print("Silhouette score is ", silhouette_avg) diff --git a/examples/auxo/config.py b/examples/auxo/config.py new file mode 100644 index 00000000..e539f31e --- /dev/null +++ b/examples/auxo/config.py @@ -0,0 +1,16 @@ +import yaml, os +global auxo_config + + +# Get the value of the FEDSCALE_HOME environment variable +fedscale_home = os.environ.get('FEDSCALE_HOME') + +# Check if the environment variable is set +if fedscale_home is not None: + config_path = os.path.join(fedscale_home, 'examples', 'auxo', 'config.yml') + + # Now, open the file using the constructed path + with open(config_path, 'r') as f: + auxo_config = yaml.load(f, Loader=yaml.FullLoader) +else: + print("FEDSCALE_HOME environment variable is not set.") \ No newline at end of file diff --git a/examples/auxo/config.yml b/examples/auxo/config.yml new file mode 100644 index 00000000..4946c84b --- /dev/null +++ b/examples/auxo/config.yml @@ -0,0 +1,8 @@ +train_data_map_file: examples/auxo/train_by_clt.csv +test_data_map_file: examples/auxo/test_by_clt.csv +split_round: [50,100] # can be removed +exploredecay: 0.99 +explorerate: 0.01 +reduction: False +metric: 'cosine' +start_round: 0 \ No newline at end of file diff --git a/examples/auxo/executor.py b/examples/auxo/executor.py new file mode 100644 index 00000000..4e10f0ec --- /dev/null +++ b/examples/auxo/executor.py @@ -0,0 +1,253 @@ +import logging + +from fedscale.cloud.execution.executor import * +from utils.helper import * +import copy +from config import auxo_config + +class AuxoExecutor(Executor): + def __init__(self, args): + super().__init__(args) + self.round = [0] + self.model_adapter = [self.get_client_trainer(args).get_model_adapter(init_model())] + + def UpdateModel(self, model_weights, cohort_id): + """Receive the broadcasted global model for current round + + Args: + config (PyTorch or TensorFlow model): The broadcasted global model config + cohort_id (int): The cohort id + """ + self.round[cohort_id] += 1 + self.model_adapter[cohort_id].set_weights(model_weights) + + def training_handler(self, client_id, conf, model, cohort_id): + """Train model given client id + + Args: + client_id (int): The client id. + conf (dictionary): The client runtime config. + cohort_id (int): The cohort id. + + Returns: + dictionary: The train result + + """ + self.model_adapter[cohort_id].set_weights(model) + conf.client_id = client_id + conf.tokenizer = tokenizer + client_data = self.training_sets if self.args.task == "rl" else \ + select_dataset(client_id, self.training_sets, + batch_size=conf.batch_size, args=self.args, + collate_fn=self.collate_fn + ) + client = self.get_client_trainer(self.args) + if len(client_data) == 0: + state_dicts = self.model_adapter[cohort_id].get_model().state_dict() + logging.info(f"Client {client_id} has no data, return empty result") + return {'client_id': client_id, 'moving_loss': 0, + 'trained_size': 0, 'utility': 0, 'wall_duration': 0, + 'update_weight': {p: state_dicts[p].data.cpu().numpy() + for p in state_dicts}, + 'success': 1} + train_res = client.train( + client_data=client_data, model=self.model_adapter[cohort_id].get_model(), conf=conf) + + return train_res + + def Train(self, config): + """Load train config and data to start training on that client + + Args: + config (dictionary): The client training config. + + Returns: + tuple (int, dictionary): The client id and train result + + """ + client_id, train_config, cohort_id = config['client_id'], config['task_config'], config['cohort_id'] + + if 'model' not in config or not config['model']: + raise "The 'model' object must be a non-null value in the training config." + client_conf = self.override_conf(train_config) + train_res = self.training_handler( + client_id=client_id, conf=client_conf, model=config['model'], cohort_id=cohort_id) + + # Report execution completion meta information + response = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION( + job_api_pb2.CompleteRequest( + client_id=str(client_id), executor_id=self.executor_id, + event=generate_msg(commons.CLIENT_TRAIN, cohort_id), status=True, msg=None, + meta_result=None, data_result=None + ) + ) + self.dispatch_worker_events(response) + logging.info(f"[Cohort {cohort_id}] Client {client_id} finished training. ") + + return client_id, train_res + + def _init_train_test_data(self): + + if self.args.data_set == 'femnist': + from utils.openimg import OpenImage + train_transform, test_transform = get_data_transform('mnist') + train_dataset = OpenImage(self.args.data_dir, dataset='femnist', transform=train_transform, client_mapping_file = auxo_config['train_data_map_file'] ) + test_dataset = OpenImage(self.args.data_dir, dataset='femnist', transform=test_transform, client_mapping_file = auxo_config['test_data_map_file'] ) + else: + raise NotImplementedError + return train_dataset, test_dataset + + + def init_data(self): + """Return the training and testing dataset + + Returns: + Tuple of DataPartitioner class: The partioned dataset class for training and testing + + """ + train_dataset, test_dataset = self._init_train_test_data() + if self.args.task == "rl": + return train_dataset, test_dataset + if self.args.task == 'nlp': + self.collate_fn = collate + elif self.args.task == 'voice': + self.collate_fn = voice_collate_fn + # load data partitionxr (entire_train_data) + logging.info("Data partitioner starts ...") + + training_sets = DataPartitioner( + data=train_dataset, args=self.args, numOfClass=self.args.num_class) + training_sets.partition_data_helper( + num_clients=self.args.num_participants, data_map_file=auxo_config['train_data_map_file']) + + testing_sets = DataPartitioner( + data=test_dataset, args=self.args, numOfClass=self.args.num_class) + testing_sets.partition_data_helper( + num_clients=self.args.num_participants, data_map_file=auxo_config['test_data_map_file']) + + logging.info("Data partitioner completes ...") + + return training_sets, testing_sets + + + def testing_handler(self, client_list, cohort_id=0): + """Test model + + Args: + args (dictionary): Variable arguments for fedscale runtime config. defaults to the setup in arg_parser.py + config (dictionary): Variable arguments from coordinator. + Returns: + dictionary: The test result + + """ + + test_num_clt = max(len(client_list) // self.num_executors, 1) + test_client_id_list = client_list[(self.this_rank - 1) * test_num_clt: self.this_rank * test_num_clt] + logging.info(f"[Cohort {cohort_id}] Test client ID: {test_client_id_list}") + testResults_accum = {'top_1': 0, 'top_5': 0, 'test_loss': 0, 'test_len': 0} + + test_config = self.override_conf({ + 'rank': self.this_rank, + 'memory_capacity': self.args.memory_capacity, + 'tokenizer': tokenizer + }) + for clt in test_client_id_list: + client = self.get_client_trainer(test_config) + data_loader = select_dataset(clt, self.testing_sets, + batch_size=self.args.test_bsz, args=self.args, + isTest=False, collate_fn=self.collate_fn) + if len(data_loader) > 0: + test_results = client.test(data_loader, self.model_adapter[cohort_id].get_model(), test_config) + testResults_accum['top_1'] += test_results['top_1'] + testResults_accum['top_5'] += test_results['top_5'] + testResults_accum['test_loss'] += test_results['test_loss'] + testResults_accum['test_len'] += test_results['test_len'] + + # testRes = {'top_1': correct, 'top_5': top_5, + # 'test_loss': sum_loss, 'test_len': test_len} + + gc.collect() + + return testResults_accum + + def Test(self, config, cohort_id): + """Model Testing. By default, we test the accuracy on all data of clients in the test group + + Args: + config (dictionary): The client testing config. + + """ + test_res = self.testing_handler(config['client_id'], cohort_id) + test_res = {'executorId': self.this_rank, 'results': test_res} + + # Report execution completion information + response = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION( + job_api_pb2.CompleteRequest( + client_id=self.executor_id, executor_id=self.executor_id, + event=generate_msg(commons.MODEL_TEST, cohort_id), status=True, msg=None, + meta_result=None, data_result=self.serialize_response(test_res) + ) + ) + self.dispatch_worker_events(response) + + def _init_split(self, cohort_id, new_cohort_id): + if len(self.model_adapter) <= new_cohort_id: + self.model_adapter.append(copy.deepcopy(self.model_adapter[cohort_id])) + self.round.append(copy.deepcopy(self.round[cohort_id])) + + def event_monitor(self): + """Activate event handler once receiving new message + """ + logging.info("Start monitoring events ...") + self.client_register() + + while not self.received_stop_request: + if len(self.event_queue) > 0: + request = self.event_queue.popleft() + current_event = request.event + event_type, cohort_id = decode_msg(current_event) + if event_type != commons.DUMMY_EVENT: + logging.info("Received message: {}".format(current_event)) + + if event_type == commons.CLIENT_TRAIN: + train_config = self.deserialize_response(request.meta) + train_model = self.deserialize_response(request.data) + train_config['model'] = train_model + train_config['client_id'] = int(train_config['client_id']) + client_id, train_res = self.Train(train_config) + + # Upload model updates + future_call = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION.future( + job_api_pb2.CompleteRequest(client_id=str(client_id), executor_id=self.executor_id, + event=generate_msg(commons.UPLOAD_MODEL, cohort_id), status=True, msg=None, + meta_result=None, data_result=self.serialize_response(train_res) + )) + future_call.add_done_callback(lambda _response: self.dispatch_worker_events(_response.result())) + + elif event_type == commons.MODEL_TEST: + self.Test(self.deserialize_response(request.meta), cohort_id) + + elif event_type == commons.UPDATE_MODEL: + model_weights = self.deserialize_response(request.data) + self.UpdateModel(model_weights, cohort_id) + + elif event_type == 'split': + new_cohort_id = len(self.model_adapter) + self._init_split(cohort_id, new_cohort_id) + + elif event_type == commons.SHUT_DOWN: + self.Stop() + + elif event_type == commons.DUMMY_EVENT: + pass + else: + time.sleep(1) + try: + self.client_ping() + except Exception as e: + logging.info(f"Caught exception {e} from aggregator, terminating executor {self.this_rank} ...") + self.Stop() + +if __name__ == "__main__": + executor = AuxoExecutor(parser.args) + executor.run() diff --git a/examples/auxo/fig/auxo.png b/examples/auxo/fig/auxo.png new file mode 100644 index 00000000..71458bb5 Binary files /dev/null and b/examples/auxo/fig/auxo.png differ diff --git a/examples/auxo/fig/epoch_100.png b/examples/auxo/fig/epoch_100.png new file mode 100644 index 00000000..305f0e61 Binary files /dev/null and b/examples/auxo/fig/epoch_100.png differ diff --git a/examples/auxo/fig/epoch_14.png b/examples/auxo/fig/epoch_14.png new file mode 100644 index 00000000..1323c1d4 Binary files /dev/null and b/examples/auxo/fig/epoch_14.png differ diff --git a/examples/auxo/fig/epoch_224.png b/examples/auxo/fig/epoch_224.png new file mode 100644 index 00000000..fa7cc029 Binary files /dev/null and b/examples/auxo/fig/epoch_224.png differ diff --git a/examples/auxo/fig/epoch_300.png b/examples/auxo/fig/epoch_300.png new file mode 100644 index 00000000..da20135d Binary files /dev/null and b/examples/auxo/fig/epoch_300.png differ diff --git a/examples/auxo/fig/epoch_500.png b/examples/auxo/fig/epoch_500.png new file mode 100644 index 00000000..e8096cd4 Binary files /dev/null and b/examples/auxo/fig/epoch_500.png differ diff --git a/examples/auxo/fig/epoch_700.png b/examples/auxo/fig/epoch_700.png new file mode 100644 index 00000000..61f11aa3 Binary files /dev/null and b/examples/auxo/fig/epoch_700.png differ diff --git a/examples/auxo/playground.py b/examples/auxo/playground.py new file mode 100644 index 00000000..b8dad2b3 --- /dev/null +++ b/examples/auxo/playground.py @@ -0,0 +1,64 @@ +import time +import numpy as np +import random +import sys +from sklearn.datasets import make_blobs +from sklearn.cluster import KMeans +from clustering import QTable + +DEFAULT_SAMPLE_SIZE = 50 +DEFAULT_TOTAL_EPOCH = 800 +DEFAULT_TOTAL_SAMPLE = 1000 +DEFAULT_KNOWN_CLT = 150 +NUM_CENTERS = 4 +random.seed(100) + +def bipartition_cluster(total_epoch, total_sample, sample_size, known_clt): + start_time = time.time() + + learning_rate = 0.1 + avg_train_times = 4 if len(sys.argv) <= 3 else total_epoch * sample_size / total_sample + print(f"{sample_size}/{total_sample} of samples cluster for {total_epoch} epochs") + + X, y_true = make_blobs(n_samples=total_sample, centers=NUM_CENTERS, cluster_std=2, random_state=100) + + Q_table = QTable( + total_sample, + train_ratio=0.9, + elbow_constant=0.8, + merge=False, + known_clt=known_clt, + avg_train_times=avg_train_times, + # split_round=[200] + ) + + for epoch in range(total_epoch): + num_list = random.sample(range(total_sample), sample_size) + + for mid in range(Q_table.num_model): + client_id = np.argwhere(Q_table.y_kmeans[num_list] == mid) + global_index = [num_list[i[0]] for i in client_id] + + if len(client_id) < 5: + continue + + X_sub = X[global_index] + Q_table.knn_update_subR(mid, X_sub, global_index) + split = Q_table.update_mainR(mid, X_sub, global_index, epoch <= total_epoch * 0.8) + + if split or epoch % 200 == 100 and mid == 0: + if split: + print(f"SPLIT at round {epoch}") + print(f'Epoch: {epoch}') + Q_table.plot(X, epoch) + + print(f"Split to {Q_table.num_model} models") + print(f"Time usage: {time.time() - start_time}") + +if __name__ == "__main__": + total_epoch = int(sys.argv[3]) if len(sys.argv) > 3 else DEFAULT_TOTAL_EPOCH + total_sample = int(sys.argv[1]) if len(sys.argv) > 3 else DEFAULT_TOTAL_SAMPLE + sample_size = int(sys.argv[2]) if len(sys.argv) > 3 else DEFAULT_SAMPLE_SIZE + known_clt = DEFAULT_KNOWN_CLT if len(sys.argv) <= 3 else 1000 + + bipartition_cluster(total_epoch, total_sample, sample_size, known_clt) diff --git a/examples/auxo/requirements.txt b/examples/auxo/requirements.txt new file mode 100644 index 00000000..a66a9f9a --- /dev/null +++ b/examples/auxo/requirements.txt @@ -0,0 +1,32 @@ +tensorboard +numba==0.48.0 +pip==20.0.2 +torch_optimizer +torch +tensorflow +torchvision +transformers +scipy==1.4.1 +matplotlib==3.1.3 +torch_baidu_ctc==0.3.0 +tensorboardX==2.1 +overrides==3.1.0 +python-levenshtein==0.12.0 +pandas==1.1.0 +PyYAML +pytest +sox==1.3.7 +grpcio==1.40.0 +gym +jupyter +pillow==9 +sentencepiece +gdown +h5py +librosa==0.7.2 +SoundFile +resampy==0.3.1 +kubernetes +wandb +nltk + diff --git a/examples/auxo/resource_manager.py b/examples/auxo/resource_manager.py new file mode 100644 index 00000000..12055df7 --- /dev/null +++ b/examples/auxo/resource_manager.py @@ -0,0 +1,58 @@ +from fedscale.cloud.resource_manager import * + + +class AuxoResourceManager(ResourceManager): + def __init__(self, experiment_mode): + self.client_run_queue = [[]] + self.client_run_queue_idx = [0] + self.experiment_mode = experiment_mode + self.update_lock = threading.Lock() + + + def register_tasks(self, clientsToRun, cohort_id): + self.client_run_queue[cohort_id] = clientsToRun.copy() + self.client_run_queue_idx[cohort_id] = 0 + + def split(self, cohort_id): + self.client_run_queue.append( self.client_run_queue[cohort_id].copy()) + self.client_run_queue_idx.append(0) + + def get_task_length(self, cohort_id) -> int: + """Number of tasks left in the queue + + Returns: + int: Number of tasks left in the queue + """ + self.update_lock.acquire() + remaining_task_num: int = len(self.client_run_queue[cohort_id]) - self.client_run_queue_idx[cohort_id] + self.update_lock.release() + return remaining_task_num + + def remove_client_task(self, client_id, cohort_id): + assert(client_id in self.client_run_queue[cohort_id], + f"client task {client_id} is not in task queue") + + def has_next_task(self, client_id=None, cohort_id=0): + exist_next_task = False + if self.experiment_mode == commons.SIMULATION_MODE: + exist_next_task = self.client_run_queue_idx[cohort_id] < len( + self.client_run_queue[cohort_id]) + else: + exist_next_task = client_id in self.client_run_queue[cohort_id] + return exist_next_task + + def get_next_task(self, client_id=None, cohort_id=0): + # TODO: remove client id when finish + next_task_id = None + self.update_lock.acquire() + if self.experiment_mode == commons.SIMULATION_MODE: + if self.has_next_task(client_id, cohort_id): + next_task_id = self.client_run_queue[cohort_id][self.client_run_queue_idx[cohort_id]] + self.client_run_queue_idx[cohort_id] += 1 + else: + if client_id in self.client_run_queue[cohort_id]: + next_task_id = client_id + self.client_run_queue[cohort_id].remove(next_task_id) + + self.update_lock.release() + return next_task_id \ No newline at end of file diff --git a/examples/auxo/utils/grad_monitor.py b/examples/auxo/utils/grad_monitor.py new file mode 100644 index 00000000..8475dcf5 --- /dev/null +++ b/examples/auxo/utils/grad_monitor.py @@ -0,0 +1,73 @@ +import logging +import numpy as np +import math + +class Gradient_Monitor(): + def __init__(self, rank): + self.client_grad = {} + self.client_data = {} + self.grad_stability = [] + self.rank = rank + + + def register_client(self, client_id, new_w, old_w, client_data): + if self.rank > 1: + return + + old_w = [dt.cpu().numpy() for dt in old_w] + new_w = [new_w[k] for k in new_w] + gradient = [pb - pa for pa, pb in zip(new_w, old_w)] + self.client_grad[client_id] = np.concatenate([n.ravel() for n in gradient]) + clt_data = None + for data_pair in client_data: + (data, target) = data_pair + if clt_data is None: + clt_data = [np.asarray(data.ravel())] + else: + clt_data.append(np.asarray(data.ravel())) + + self.client_data[client_id] = np.mean(clt_data, axis=0) + logging.info(f'client_data[{client_id}] registered') + + def _cal_data_similarity(self, client_1, client_2): + '''Calculate the cosine similarity between the data of two clients''' + data_1 = self.client_data[client_1].ravel() + data_2 = self.client_data[client_2].ravel() + cosine_similarity = np.dot(data_1, data_2) / (np.linalg.norm(data_1) * np.linalg.norm(data_2)) + return cosine_similarity + + def _cal_grad_similarity(self, client_1, client_2): + '''Calculate the cosine similarity between the gradient of two clients''' + grad_1 = self.client_grad[client_1] + grad_2 = self.client_grad[client_2] + + # Calculate the magnitude (Euclidean norm) of A and B + magnitude_A = np.linalg.norm(grad_1) + magnitude_B = np.linalg.norm(grad_2) + # Calculate the cosine similarity + cosine_similarity = np.dot(grad_1, grad_2) / (magnitude_A * magnitude_B) + + return cosine_similarity + + def cal_pairwise_grad_stability(self): + '''Calculate the pairwise gradient similarity and data similarity''' + if len(self.client_grad) > 0: + grad_sim = [] + data_sim = [] + for client_id1 in self.client_grad: + for client_id2 in self.client_grad: + if client_id1 > client_id2: + data_similarity = self._cal_data_similarity(client_id1, client_id2) + grad_similarity = self._cal_grad_similarity(client_id1, client_id2) + # logging.info(f"Gradient similarity between {client_id1} and {client_id2}: {grad_similarity}") + # logging.info(f"Data similarity between {client_id1} and {client_id2}: {data_similarity}") + data_sim.append(data_similarity) + grad_sim.append(grad_similarity) + correlation_coefficient = np.corrcoef(grad_sim, data_sim)[0, 1] + + self.grad_stability.append(correlation_coefficient) + logging.info(f"Gradient stability: {self.grad_stability}") + # Reset for the next round + self.client_grad = {} + self.client_data = {} + diff --git a/examples/auxo/utils/helper.py b/examples/auxo/utils/helper.py new file mode 100644 index 00000000..4c61353d --- /dev/null +++ b/examples/auxo/utils/helper.py @@ -0,0 +1,15 @@ + + +def decode_msg(msg): + """Decode message into event type and cohort id + + Args: + msg (string): message from client + """ + return msg.split('-')[0], int(msg.split('-')[1]) + + +def generate_msg( msg_type, cohort_id=0): + return f'{msg_type}-{cohort_id}' + + diff --git a/examples/auxo/utils/klkmeans.py b/examples/auxo/utils/klkmeans.py new file mode 100644 index 00000000..7545e308 --- /dev/null +++ b/examples/auxo/utils/klkmeans.py @@ -0,0 +1,37 @@ +from nltk.cluster import KMeansClusterer, euclidean_distance +import numpy as np + + +class KLKmeans(object): + def __init__(self,n_clusters, init_center =None ): + self.labels_ = None + + def _processNegVals(x): + x = np.array(x) + minx = np.min(x) + if minx < 0: + x = x + abs(minx) + """ 0.000001 is used here to avoid 0. """ + x = x + 0.000001 + # px = x / np.sum(x) + return x + + def _KL(P, Q): + epsilon = 0.00001 + P = _processNegVals(P) + Q = _processNegVals(Q) + # You may want to instead make copies to avoid changing the np arrays. + divergence = np.sum(P * np.log(P / Q)) + return divergence + + self.klkmeans = KMeansClusterer(n_clusters, _KL, initial_means = init_center) + + def fit(self, x): + print(x) + self.klkmeans.cluster(x) + self.cluster_centers_ = self.klkmeans.means() + self.labels_ = self.predict(x) + + def predict(self, x): + return [ self.klkmeans.classify(i) for i in x] + diff --git a/examples/auxo/utils/openimg.py b/examples/auxo/utils/openimg.py new file mode 100644 index 00000000..498a7b05 --- /dev/null +++ b/examples/auxo/utils/openimg.py @@ -0,0 +1,138 @@ +from __future__ import print_function +import warnings +from PIL import Image +import os +import os.path +import csv + + +class OpenImage(): + """ + Args: + root (string): Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train (bool, optional): If True, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + classes = [] + + @property + def train_labels(self): + warnings.warn("train_labels has been renamed targets") + return self.targets + + @property + def test_labels(self): + warnings.warn("test_labels has been renamed targets") + return self.targets + + @property + def train_data(self): + warnings.warn("train_data has been renamed data") + return self.data + + @property + def test_data(self): + warnings.warn("test_data has been renamed data") + return self.data + + def __init__(self, root, dataset='train', transform=None, target_transform=None, + imgview=False, client_mapping_file=None, num_clt=1e10, noniid=0): + + self.root = root + self.transform = transform + self.target_transform = target_transform + self.data_file = dataset # 'train', 'test', 'validation' + self.client_mapping_file = client_mapping_file + self.num_clt = num_clt + if not self._check_exists(): + raise RuntimeError('Dataset not found. You have to download it') + + self.path = os.path.join(self.processed_folder, self.data_file) + # load data and targets + self.data_to_clientID = {} + self.data, self.targets = self.load_file(self.path) + self.imgview = imgview + self.noniid = noniid + + def __getitem__(self, index): + """ + Args: + id_clt (int): Index, client ID + + Returns: + tuple: (image, target) where target is index of the target class. + """ + # index , clientID = id_clt + imgName, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.open(os.path.join(self.path, imgName)) + # avoid channel error + if img.mode != 'RGB': + img = img.convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + @property + def raw_folder(self): + return self.root + + @property + def processed_folder(self): + return self.root + + def _check_exists(self): + + print("Checking data path:", os.path.join(self.processed_folder, self.data_file)) + return (os.path.exists(os.path.join(self.processed_folder, + self.data_file))) + + def load_meta_data(self, path): + datas, labels = [], [] + unique_clientIds = set() + with open(path) as csv_file: + csv_reader = csv.reader(csv_file, delimiter=',') + line_count = 0 + for row in csv_reader: + if line_count != 0: + unique_clientIds.add(row[0]) + self.data_to_clientID[len(datas)] = row[0] + if len(unique_clientIds) > self.num_clt: + break + datas.append(row[1]) + labels.append(int(row[-1])) + + line_count += 1 + return datas, labels + + def load_file(self, path): + # load meta file to get labels + # datas, labels = self.load_meta_data(os.path.join(self.processed_folder, 'client_data_mapping', self.data_file+'.csv')) + if self.client_mapping_file is not None: + datas, labels = self.load_meta_data(self.client_mapping_file) + else: + datas, labels = self.load_meta_data( + os.path.join(self.processed_folder, 'client_data_mapping', self.data_file + '.csv')) + + return datas, labels + + diff --git a/examples/auxo/utils/prepare_train_test.py b/examples/auxo/utils/prepare_train_test.py new file mode 100644 index 00000000..eb528075 --- /dev/null +++ b/examples/auxo/utils/prepare_train_test.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import csv +import sys +from collections import defaultdict +from os import path +import pandas as pd + + +def partition_by_client_unsort(file_path='train.csv'): + print(f"Processing {file_path} into train and test sets.") + + clt_smp = defaultdict(list) + title = None + + # Read CSV and partition by client + with open(file_path, 'r') as fin: + csv_reader = csv.reader(fin) + for ind, row in enumerate(csv_reader): + if ind == 0: + title = row + else: + clt_smp[row[0]].append(ind - 1) + + # Check if output files already exist + no_title = path.exists('train_by_clt.csv') or path.exists('test_by_clt.csv') + + # Initialize CSV writers for train and test sets + with open('train_by_clt.csv', 'a') as write_partition_train_file, open('test_by_clt.csv', + 'a') as write_partition_test_file: + writer_train = csv.writer(write_partition_train_file) + writer_test = csv.writer(write_partition_test_file) + + if not no_title: + writer_train.writerow(title) + writer_test.writerow(title) + + # Read original CSV into a Pandas DataFrame + print("Reading into pandas DataFrame.") + df = pd.read_csv(file_path) + + cnt = 0 + clt_num = 0 + + # Partition data and write to train and test CSV files + for clt, samples in clt_smp.items(): + sample_num = len(samples) + cnt += sample_num + clt_num += 1 + + for i, sample_idx in enumerate(samples): + if i < sample_num * 0.8: + writer_train.writerow(list(df.loc[sample_idx].values)) + else: + writer_test.writerow(list(df.loc[sample_idx].values)) + + if cnt % 10000 == 0: + print(f"Wrote {cnt} samples.") + print(f"Running average sample: {cnt / clt_num}") + + +if __name__ == "__main__": + file_path = sys.argv[1] + partition_by_client_unsort(file_path)