Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix q-fedavg model loading error | Fix test cases #246

Merged
merged 3 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions fedscale/cloud/aggregation/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,9 @@ def update_round_gradient(
update_weights = result["update_weight"]
if type(update_weights) is dict:
update_weights = [x for x in update_weights.values()]

weights = [
torch.from_numpy(np.asarray(x, dtype=np.float32)).to(
device=self.device
)
for x in update_weights
torch.tensor(x).to(device=self.device) for x in update_weights
]
grads = [
(u - v) * 1.0 / learning_rate for u, v in zip(last_model, weights)
Expand All @@ -100,8 +98,10 @@ def update_round_gradient(
) + (1.0 / learning_rate) * np.float_power(loss + 1e-10, qfedq)

# update global model
for idx, param in enumerate(target_model.parameters()):
param.data = last_model[idx] - Deltas[idx] / (hs + 1e-10)
new_state_dict = {
name: last_model[idx] - Deltas[idx] / (hs + 1e-10) for idx, name in enumerate(target_model.state_dict().keys())
}
target_model.load_state_dict(new_state_dict)

else:
# The default optimizer, FedAvg, has been applied in aggregator.py on the fly
Expand Down
7 changes: 6 additions & 1 deletion fedscale/cloud/internal/model_adapter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ class ModelAdapterBase(abc.ABC):
"""
Represents an adapter that operates on a framework-specific model.
"""

@abc.abstractmethod
def set_weights(self, weights: np.ndarray):
def set_weights(
self, weights: np.ndarray, is_aggregator=True, client_training_results=None
):
"""
Set the model's weights to the numpy weights array.
:param weights: numpy weights array
:param is_aggregator: boolean indicating whether the caller is the aggregator
:param client_training_results: list of gradients from every client, for q-fedavg
"""
pass

Expand Down
19 changes: 17 additions & 2 deletions fedscale/cloud/internal/tensorflow_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,28 @@ class TensorflowModelAdapter(ModelAdapterBase):
def __init__(self, model: tf.keras.Model):
self.model = model

def set_weights(self, weights: List[np.ndarray]):
def set_weights(
self,
weights: List[np.ndarray],
is_aggregator=True,
client_training_results=None,
):
"""
Set the model's weights to the numpy weights array.
:param weights: numpy weights array
:param is_aggregator: boolean indicating whether the caller is the aggregator
:param client_training_results: list of gradients from every client, for q-fedavg
"""
for i, layer in enumerate(self.model.layers):
if layer.trainable:
layer.set_weights(weights[i])

def get_weights(self) -> List[np.ndarray]:
return [np.asarray(layer.get_weights()) for layer in self.model.layers if layer.trainable]
return [
np.asarray(layer.get_weights())
for layer in self.model.layers
if layer.trainable
]

def get_model(self):
return self.model
14 changes: 10 additions & 4 deletions fedscale/tests/cloud/aggregation/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, model_wrapper):
self.model_in_update = 1
self.tasks_round = 3
self.model_wrapper = model_wrapper
self.client_training_results = None


def multiply_weights(weights, factor):
Expand All @@ -23,8 +24,9 @@ def multiply_weights(weights, factor):
class TestAggregator:
def test_update_weight_aggregation_for_keras_model(self):
x = tf.keras.Input(shape=(2,))
y = tf.keras.layers.Dense(2, activation='softmax')(
tf.keras.layers.Dense(4, activation='softmax')(x))
y = tf.keras.layers.Dense(2, activation="softmax")(
tf.keras.layers.Dense(4, activation="softmax")(x)
)
model = tf.keras.Model(x, y)
model_adapter = TensorflowModelAdapter(model)
aggregator = MockAggregator(model_adapter)
Expand All @@ -34,7 +36,9 @@ def test_update_weight_aggregation_for_keras_model(self):
aggregator.update_weight_aggregation(multiply_weights(weights, 2))
aggregator.model_in_update += 1
aggregator.update_weight_aggregation(multiply_weights(weights, 5))
np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3))
np.array_equal(
aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)
)

def test_update_weight_aggregation_for_torch_model(self):
model = torch.nn.Linear(3, 2)
Expand All @@ -46,4 +50,6 @@ def test_update_weight_aggregation_for_torch_model(self):
aggregator.update_weight_aggregation(multiply_weights(weights, 2))
aggregator.model_in_update += 1
aggregator.update_weight_aggregation(multiply_weights(weights, 5))
np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3))
np.array_equal(
aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)
)
Loading