Skip to content

Commit

Permalink
Jean/fix log frequency (#247)
Browse files Browse the repository at this point in the history
* solving #233

* blacking
  • Loading branch information
jeandut authored Nov 15, 2022
1 parent 7a18b82 commit 8963217
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions flamby/strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,10 @@ def _local_train(self, dataloader_with_memory, num_updates):
# Local train
_size = len(dataloader_with_memory)
self.model = self.model.train()
for idx, _batch in enumerate(range(num_updates)):
for _batch in range(num_updates):
X, y = dataloader_with_memory.get_samples()
X, y = X.to(self._device), y.to(self._device)
if idx == 0:
if _batch == 0:
# Initialize the batch-size using the first batch to avoid
# edge cases with drop_last=False
_batch_size = X.shape[0]
Expand All @@ -211,14 +211,6 @@ def _local_train(self, dataloader_with_memory, num_updates):

if self.log:
if _batch % self.log_period == 0:
if _current_epoch > self.current_epoch:
# At each epoch we look at the histograms of all the
# network's parameters
for name, p in self.model.named_parameters():
self.writer.add_histogram(
f"client{self.client_id}/{name}", p, _current_epoch
)

print(
f"loss: {_loss:>7f} after {self.num_batches_seen:>5d}"
f" batches of data amounting to {_current_epoch:>5d}"
Expand All @@ -229,6 +221,15 @@ def _local_train(self, dataloader_with_memory, num_updates):
_loss,
self.num_batches_seen,
)

if _current_epoch > self.current_epoch:
# At each epoch we look at the histograms of all the
# network's parameters
for name, p in self.model.named_parameters():
self.writer.add_histogram(
f"client{self.client_id}/{name}", p, _current_epoch
)

self.current_epoch = _current_epoch

def _prox_local_train(self, dataloader_with_memory, num_updates, mu):
Expand Down

0 comments on commit 8963217

Please sign in to comment.