Skip to content

Commit

Permalink
Print pairwise
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 19, 2024
1 parent 0ccec85 commit 83f3763
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 35 deletions.
2 changes: 0 additions & 2 deletions kronfluence/module/tracked_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,9 @@ def synchronize_preconditioned_gradient(self, num_processes: int) -> None:
"""Stacks preconditioned gradient across multiple devices or nodes in a distributed setting."""
if dist.is_initialized() and torch.cuda.is_available() and self._preconditioned_gradient_available():
if isinstance(self._storage[PRECONDITIONED_GRADIENT_NAME], list):
print("1")
assert len(self._storage[PRECONDITIONED_GRADIENT_NAME]) == 2
for i in range(len(self._storage[PRECONDITIONED_GRADIENT_NAME])):
size = self._storage[PRECONDITIONED_GRADIENT_NAME][i].size()
print(f"Size: {size}")
stacked_matrix = torch.empty(
size=(num_processes, size[0], size[1], size[2]),
dtype=self._storage[PRECONDITIONED_GRADIENT_NAME][i].dtype,
Expand Down
70 changes: 37 additions & 33 deletions tests/gpu_tests/ddp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,39 +124,43 @@ def setUpClass(cls) -> None:
# rtol=1e-1,
# )
#
# def test_pairwise_scores(self) -> None:
# pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME)
#
# score_args = ScoreArguments(
# score_dtype=torch.float64,
# per_sample_gradient_dtype=torch.float64,
# precondition_dtype=torch.float64,
# )
# self.analyzer.compute_pairwise_scores(
# scores_name=NEW_SCORE_NAME,
# factors_name=OLD_FACTOR_NAME,
# query_dataset=self.eval_dataset,
# train_dataset=self.train_dataset,
# train_indices=list(range(TRAIN_INDICES)),
# query_indices=list(range(QUERY_INDICES)),
# per_device_query_batch_size=12,
# per_device_train_batch_size=512,
# score_args=score_args,
# overwrite_output_dir=True,
# )
# new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)
#
# if LOCAL_RANK == 0:
# print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
# print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
# print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
# print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
# assert check_tensor_dict_equivalence(
# pairwise_scores,
# new_pairwise_scores,
# atol=1e-5,
# rtol=1e-3,
# )
def test_pairwise_scores(self) -> None:
pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=OLD_SCORE_NAME)

score_args = ScoreArguments(
score_dtype=torch.float64,
per_sample_gradient_dtype=torch.float64,
precondition_dtype=torch.float64,
)
self.analyzer.compute_pairwise_scores(
scores_name=NEW_SCORE_NAME,
factors_name=OLD_FACTOR_NAME,
query_dataset=self.eval_dataset,
train_dataset=self.train_dataset,
train_indices=list(range(TRAIN_INDICES)),
query_indices=list(range(QUERY_INDICES)),
per_device_query_batch_size=12,
per_device_train_batch_size=512,
score_args=score_args,
overwrite_output_dir=True,
)
new_pairwise_scores = self.analyzer.load_pairwise_scores(scores_name=NEW_SCORE_NAME)

if LOCAL_RANK == 0:
print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][0]}")
print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][0]}")
print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
print(f"Previous score: {pairwise_scores[ALL_MODULE_NAME][50]}")
print(f"Previous shape: {pairwise_scores[ALL_MODULE_NAME].shape}")
print(f"New score: {new_pairwise_scores[ALL_MODULE_NAME][50]}")
print(f"New shape: {new_pairwise_scores[ALL_MODULE_NAME].shape}")
assert check_tensor_dict_equivalence(
pairwise_scores,
new_pairwise_scores,
atol=1e-5,
rtol=1e-3,
)
#
# def test_self_scores(self) -> None:
# self_scores = self.analyzer.load_self_scores(scores_name=OLD_SCORE_NAME)
Expand Down

0 comments on commit 83f3763

Please sign in to comment.