Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eatpk committed Jan 28, 2024
1 parent 79d94c2 commit 5145589
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 9 deletions.
2 changes: 0 additions & 2 deletions analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

from analog.analysis import InfluenceFunction

from analog.analysis import InfluenceFunction

from analog.batch_info import BatchInfo
from analog.config import init_config_from_yaml
from analog.logging import HookLogger
Expand Down
2 changes: 1 addition & 1 deletion analog/analysis/influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def compute_influence(

if self.flatten:
src = self.flatten_log(src)
synchronize_device_flatten(src, tgt)
synchronize_device(src, tgt)
total_influence += self._dot_product_logs(src, tgt)

if not self.flatten:
Expand Down
17 changes: 15 additions & 2 deletions analog/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Dict, Optional, List, Tuple
from typing import Dict, Optional, List, Tuple, Union
import torch


def synchronize_device(
def synchronize_device_unflatten(
src: Dict[str, Dict[str, torch.Tensor]],
tgt: Dict[str, Dict[str, torch.Tensor]],
device: Optional[torch.device] = None,
Expand Down Expand Up @@ -40,3 +40,16 @@ def synchronize_device_flatten(
if device is None:
src_device = src.device
tgt.to(device=src_device)


def synchronize_device(
src: Union[Dict[str, Dict[str, torch.Tensor]], torch.Tensor],
tgt: Union[Dict[str, Dict[str, torch.Tensor]], torch.Tensor],
device: Optional[torch.device] = None,
):
if isinstance(src, dict):
assert isinstance(tgt, dict)
synchronize_device_unflatten(src, tgt, device)
else:
assert isinstance(src, torch.Tensor) and isinstance(tgt, torch.Tensor)
synchronize_device_flatten(src, tgt, device)
1 change: 0 additions & 1 deletion analog/logging/log_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(self, config, state):
self.flush_count = 0

self.buffer = nested_dict()

self.buffer_size = 0

def buffer_write(self, binfo):
Expand Down
2 changes: 1 addition & 1 deletion analog/test/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ logging:
log_dtype: none
num_workers: 1
lora:
init: pca
init: random
parameter_sharing: false
parameter_sharing_groups: null
rank: 64
Expand Down
2 changes: 1 addition & 1 deletion analog/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import logging as default_logging
from typing import Any, List, Optional
from collections import defaultdict, OrderedDict
from collections import defaultdict

import hashlib

Expand Down
2 changes: 1 addition & 1 deletion examples/mnist_influence/compute_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
batch_size=1, split="valid", shuffle=False, indices=args.eval_idxs
)

analog = AnaLog(project="test") # switched.
analog = AnaLog(project="test")

# Gradient & Hessian logging
analog.watch(model)
Expand Down

0 comments on commit 5145589

Please sign in to comment.