diff --git a/analog/analog.py b/analog/analog.py
index aa2034db..e07f6ccb 100644
--- a/analog/analog.py
+++ b/analog/analog.py
@@ -10,8 +10,6 @@
 from analog.analysis import InfluenceFunction
 
-from analog.analysis import InfluenceFunction
-
 from analog.batch_info import BatchInfo
 from analog.config import init_config_from_yaml
 from analog.logging import HookLogger
 
diff --git a/analog/analysis/influence_function.py b/analog/analysis/influence_function.py
index 7d7635be..5f4f57df 100644
--- a/analog/analysis/influence_function.py
+++ b/analog/analysis/influence_function.py
@@ -113,7 +113,7 @@ def compute_influence(
 
         if self.flatten:
             src = self.flatten_log(src)
-        synchronize_device_flatten(src, tgt)
+        synchronize_device(src, tgt)
         total_influence += self._dot_product_logs(src, tgt)
 
         if not self.flatten:
diff --git a/analog/analysis/utils.py b/analog/analysis/utils.py
index 437397d8..c9c7240c 100644
--- a/analog/analysis/utils.py
+++ b/analog/analysis/utils.py
@@ -1,7 +1,7 @@
-from typing import Dict, Optional, List, Tuple
+from typing import Dict, Optional, List, Tuple, Union
 
 import torch
 
-def synchronize_device(
+def synchronize_device_unflatten(
     src: Dict[str, Dict[str, torch.Tensor]],
     tgt: Dict[str, Dict[str, torch.Tensor]],
     device: Optional[torch.device] = None,
@@ -40,3 +40,16 @@
     if device is None:
         src_device = src.device
     tgt.to(device=src_device)
+
+
+def synchronize_device(
+    src: Union[Dict[str, Dict[str, torch.Tensor]], torch.Tensor],
+    tgt: Union[Dict[str, Dict[str, torch.Tensor]], torch.Tensor],
+    device: Optional[torch.device] = None,
+):
+    if isinstance(src, dict):
+        assert isinstance(tgt, dict)
+        synchronize_device_unflatten(src, tgt, device)
+    else:
+        assert isinstance(src, torch.Tensor) and isinstance(tgt, torch.Tensor)
+        synchronize_device_flatten(src, tgt, device)
diff --git a/analog/logging/log_saver.py b/analog/logging/log_saver.py
index 61db13b2..44d89427 100644
--- a/analog/logging/log_saver.py
+++ b/analog/logging/log_saver.py
@@ -19,7 +19,6 @@ def __init__(self, config, state):
 
         self.flush_count = 0
 
         self.buffer = nested_dict()
-        self.buffer_size = 0
 
     def buffer_write(self, binfo):
diff --git a/analog/test/config.yaml b/analog/test/config.yaml
index d1b00bbb..f8857d45 100644
--- a/analog/test/config.yaml
+++ b/analog/test/config.yaml
@@ -11,7 +11,7 @@ logging:
   log_dtype: none
   num_workers: 1
 lora:
-  init: pca
+  init: random
   parameter_sharing: false
   parameter_sharing_groups: null
   rank: 64
diff --git a/analog/utils.py b/analog/utils.py
index 13a81c49..4030334c 100644
--- a/analog/utils.py
+++ b/analog/utils.py
@@ -1,7 +1,7 @@
 import sys
 import logging as default_logging
 from typing import Any, List, Optional
-from collections import defaultdict, OrderedDict
+from collections import defaultdict
 
 import hashlib
 
diff --git a/examples/mnist_influence/compute_influences.py b/examples/mnist_influence/compute_influences.py
index 46d7b010..b84b8b11 100644
--- a/examples/mnist_influence/compute_influences.py
+++ b/examples/mnist_influence/compute_influences.py
@@ -34,7 +34,7 @@
     batch_size=1, split="valid", shuffle=False, indices=args.eval_idxs
 )
 
-analog = AnaLog(project="test") # switched.
+analog = AnaLog(project="test")
 
 # Gradient & Hessian logging
 analog.watch(model)