
Pinning a source of randomness in server-side aggregation. #278

Merged · 13 commits · Nov 11, 2024

Conversation

emersodb
Collaborator

@emersodb emersodb commented Nov 6, 2024

Adds sorting of client results on the server side prior to aggregation. This reduces randomness in results, since NumPy floating-point addition is not exactly associative.

PR Type

"Fix" (not really but reducing randomness is good)

Short Description

Clickup Ticket(s): N/A

A source of randomness appeared in some of our experiments when using more than 2 clients, despite pinning random seeds for Python, NumPy, and PyTorch. One culprit for this non-determinism was the server-side weight aggregation. It turns out that NumPy floating-point addition is not associative (https://stackoverflow.com/questions/69616727/why-does-computing-mean-with-numpy-meana-axis-10-differs-from-computing-mea). That is, the order in which you add up floats affects the rounding error in the result. Because client results are not strictly ordered (they arrive in the order their messages happen to be processed by the server), the order in which weights are added together can change from run to run.
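The non-associativity is easy to demonstrate in a few lines (a standalone illustration, not code from this PR):

```python
# Floating-point addition is not associative: each addition rounds to the
# nearest representable double, so grouping the same three values
# differently can yield results that differ in the last bits.
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # 0.1 + 0.2 rounds to 0.30000000000000004
right = a + (b + c)  # 0.2 + 0.3 rounds to 0.5 exactly

print(left == right)  # False
```

This is exactly why summing the same client weights in a different order produces slightly different aggregates.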

My initial approach was to sort the client results seen by the server by client ID (CID). However, these CIDs are generated deep within Flower via uuid and are, therefore, very hard to pin. They fluctuate each run, which makes sorting by them useless for preserving summation order. Instead, I introduced a pseudo-sorting approach that should work in deterministic settings. It's not my favorite idea, because it introduces some computational overhead, but it works.
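For illustration, one way to get a run-stable ordering without relying on CIDs is to derive the sort key from the weight values themselves: in a deterministic run, each client produces bitwise-identical weights every time, so a content hash is reproducible even when the uuid-based CIDs are not. The sketch below is hypothetical (the names `deterministic_sort_key` and `aggregate` are made up, not the code in this PR):

```python
import hashlib
from typing import List, Tuple

import numpy as np


def deterministic_sort_key(weights: List[np.ndarray]) -> str:
    # Hash the raw bytes of a client's weight arrays. In a deterministic
    # run each client yields identical bytes every round, so this key is
    # stable across runs, unlike Flower's uuid-based CIDs.
    digest = hashlib.sha256()
    for layer in weights:
        digest.update(np.ascontiguousarray(layer).tobytes())
    return digest.hexdigest()


def aggregate(
    client_results: List[Tuple[List[np.ndarray], int]]
) -> List[np.ndarray]:
    # Sort client (weights, num_examples) pairs by content hash before
    # summing, fixing the addition order regardless of message arrival order.
    ordered = sorted(client_results, key=lambda cr: deterministic_sort_key(cr[0]))
    total_examples = sum(n for _, n in ordered)
    return [
        np.sum([w[i] * n / total_examples for w, n in ordered], axis=0)
        for i in range(len(ordered[0][0]))
    ]
```

The trade-off mentioned above is visible here: hashing every client's full weight payload each round is the extra computation the sort key costs.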

If someone has a less invasive/better idea, I'm very much open to it.

Tests Added

Tested by hand on some of our trajectories to verify that the code here eliminates the numerical fluctuations. This does not, however, guarantee that there are no other sources of randomness.

@sanaAyrml
Collaborator

sanaAyrml commented Nov 8, 2024

On second thought should we also add

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

to set_all_random_seeds in random.py?
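For context, a seed-pinning helper incorporating those cuDNN flags might look like the sketch below (hypothetical; the actual `set_all_random_seeds` in random.py may differ):

```python
import os
import random

import numpy as np


def set_all_random_seeds(seed: int = 42) -> None:
    # Pin the seed sources we control: Python's random, NumPy, hash
    # randomization, and (if installed) PyTorch on CPU and GPU.
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import torch

        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Proposed additions: force cuDNN to pick deterministic kernels
        # and disable its benchmarking autotuner, which can otherwise
        # select different (non-deterministic) algorithms per run.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    except ImportError:
        pass  # torch not installed; Python/NumPy seeds are still pinned
```

Note that even with all of this pinned, the aggregation-order issue this PR addresses remains a separate source of non-determinism.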

@emersodb
Collaborator Author

On second thought should we also add

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

to set_all_random_seeds in random.py?

It's actually included in a different PR: See https://github.com/VectorInstitute/FL4Health/pull/251/files (random.py there). Sorry for the confusion.

@emersodb emersodb merged commit 23b4224 into main Nov 11, 2024
6 checks passed
@emersodb emersodb deleted the dbe/fixing_some_numpy_randomness_in_aggregation branch November 11, 2024 17:07