Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Turningl authored Dec 18, 2024
1 parent cd91d79 commit cc5c809
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,82 +26,82 @@ def main(args, dataset, model):
train, test = dataset.train, dataset.test
print('Using differential privacy!\n') if args.dp else print('No differential privacy!\n')

# prepare local dataset follow by dirichlet distribution
# =============================== prepare local dataset follow by dirichlet distribution ===============================
local_indices = molecule_dirichlet_distribution(args, train, args.num_clients, args.alpha, args.null_value, args.seed)

# set privacy preference
# =============================== set privacy preference ===============================
privacy_preferences = prepare_local_differential_privacy(args, args.num_clients)
print('privacy preferences: \n', privacy_preferences, '\n')

# set simulated databases
# =============================== set simulated databases ===============================
simulated_databases = []
for i in range(args.num_clients):
simulated_database = SimulatedDatabase(train=train, indices=local_indices[i], args=args)

# set noise multiplier
# =============================== set noise multiplier ===============================
if args.dp:
epsilon = privacy_preferences[i]
simulated_database.set_local_differential_privacy(epsilon)
print('the %d simulated database noise epsilon is %.4f' % ((i + 1), epsilon))

simulated_databases.append(simulated_database)

# set server
# =============================== set server ===============================
server = GlobalServer(model=model, args=args)

# set open access database
# =============================== set open access database ===============================
server.set_open_access_database(privacy_preferences) if args.dp else None

# init server algorithm
# =============================== init server algorithm ===============================
server.init_alg(alg=args.alg)

# init global model
# =============================== init global model ===============================
server_model = server.init_global_model()

# set communication round
# =============================== set communication round ===============================
communication_round = args.global_round // args.local_round
print('the communication_round is %d' % communication_round)

accuracy_accountant, rmse_accoutant = [], []
model_states, means = None, None

# start communication
# =============================== start communication ===============================
for r in range(communication_round):
print()
print('the %d communication round. \n' % (r + 1))

# local update and aggregate
# =============================== local update and aggregate ===============================
for idx, participant in enumerate(simulated_databases):
print("the %dth participant local update." % (idx + 1))

# delivery model
# =============================== delivery model ===============================
participant.download(copy.deepcopy(server_model))

# update participant open_access model states and means information
# =============================== update participant open_access model states and means information ===============================
if model_states:
participant.update_comm_optimization(model_states=model_states, means=means, participant=(idx not in server.open_access))

# local update
# =============================== local update ===============================
model_state = participant.local_update()

# aggregate
# =============================== aggregate ===============================
server.aggregate(idx, model_state, args.alg)

# load average weight
# =============================== load average weight ===============================
global_model = server.update()

# fetch model states and means information with communication optimization
# =============================== fetch model states and means information with communication optimization ===============================
if args.comm_optimization:
model_states, means = server.fetch_comm_optimization()

# regression
# =============================== regression
if dataset.dataset_name in dataset.dataset_names['regression']:
test_rmse, test_loss = inference_test_regression(args, global_model, test)
print('current global model has test rmse: %.4f test loss: %.4f' % (test_rmse, test_loss))

rmse_accoutant.append(test_rmse)

# classification
# =============================== classification ===============================
elif dataset.dataset_name in dataset.dataset_names['classification']:
test_acc, test_loss = inference_test_classification(args, global_model, test)
print('current global model has test acc: %.4f test loss: %.4f' % (test_acc, test_loss))
Expand All @@ -113,6 +113,7 @@ def main(args, dataset, model):

# torch.save(accuracy_accountant, 'accuracy_accountant.pt')

# =============================== print and save ===============================
if rmse_accoutant:
optimal_result = print_rmse_accoutant(rmse_accoutant)
save_progress(args, rmse_accoutant, optimal_result)
Expand Down

0 comments on commit cc5c809

Please sign in to comment.