diff --git a/analog/analog.py b/analog/analog.py
index b7e05973..cc4c0b7a 100644
--- a/analog/analog.py
+++ b/analog/analog.py
@@ -17,6 +17,7 @@
 from analog.lora import LoRAHandler
 from analog.lora.utils import is_lora
 from analog.state import AnaLogState
+from analog.monitor_util.timer import DeviceFunctionTimer
 from analog.utils import (
     get_logger,
     get_rank,
@@ -237,6 +238,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
         """
         self.logger.update()
 
+    @DeviceFunctionTimer.timer
     def build_log_dataset(self):
         """
         Constructs the log dataset from the stored logs. This dataset can then be used
@@ -249,6 +251,7 @@ def build_log_dataset(self):
         log_dataset = LogDataset(log_dir=self.log_dir, config=self.influence_config)
         return log_dataset
 
+    @DeviceFunctionTimer.timer
     def build_log_dataloader(
         self, batch_size: int = 16, num_workers: int = 0, pin_memory: bool = False
     ):
diff --git a/analog/analysis/influence_function.py b/analog/analysis/influence_function.py
index 64df43aa..2834cbc1 100644
--- a/analog/analysis/influence_function.py
+++ b/analog/analysis/influence_function.py
@@ -5,6 +5,7 @@
 from einops import einsum, rearrange, reduce
 from analog.config import InfluenceConfig
 from analog.state import AnaLogState
+from analog.monitor_util.timer import DeviceFunctionTimer
 from analog.utils import get_logger, nested_dict
 from analog.analysis.utils import synchronize_device
 
@@ -24,6 +25,7 @@ def __init__(self, config: InfluenceConfig, state: AnaLogState):
         self.influence_scores = pd.DataFrame()
         self.flatten = config.flatten
 
+    @DeviceFunctionTimer.timer
     @torch.no_grad()
     def precondition(
         self,
@@ -212,6 +214,7 @@ def flatten_log(self, src):
             to_cat.append(log.view(bsz, -1))
         return torch.cat(to_cat, dim=1)
 
+    @DeviceFunctionTimer.timer
     def compute_influence_all(
         self,
         src_log: Tuple[str, Dict[str, Dict[str, torch.Tensor]]],
diff --git a/analog/logging/log_saver.py b/analog/logging/log_saver.py
index 26b25408..e95d5ced 100644
--- a/analog/logging/log_saver.py
+++ b/analog/logging/log_saver.py
@@ -1,6 +1,7 @@
 from concurrent.futures import ThreadPoolExecutor
 
 import torch
+from analog.monitor_util.timer import DeviceFunctionTimer
 from analog.utils import nested_dict, to_numpy, get_rank
 from analog.logging.mmap import MemoryMapHandler
 
@@ -21,6 +22,7 @@ def __init__(self, config, state):
         self.buffer = nested_dict()
         self.buffer_size = 0
 
+    @DeviceFunctionTimer.timer
     def buffer_write(self, binfo):
         """
         Add log state on exit.
@@ -85,6 +87,7 @@ def _flush_serialized(self, log_dir) -> str:
         del buffer_list
         return log_dir
 
+    @DeviceFunctionTimer.timer
     def flush(self) -> None:
         """
         For the DefaultHandler, there's no batch operation needed since each add operation writes to the file.
diff --git a/analog/logging/logger.py b/analog/logging/logger.py
index 684bb156..a3d6b3eb 100644
--- a/analog/logging/logger.py
+++ b/analog/logging/logger.py
@@ -10,6 +10,7 @@
 from analog.logging.option import LogOption
 from analog.logging.log_saver import LogSaver
 from analog.logging.utils import compute_per_sample_gradient
+from analog.monitor_util.timer import DeviceFunctionTimer
 from analog.utils import get_logger
 
 
@@ -42,6 +43,7 @@ def __init__(
         self.grad_hooks = []
         self.tensor_hooks = []
 
+    @DeviceFunctionTimer.timer
     def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
         """
         Add log state on exit.
@@ -59,6 +61,7 @@ def log(self, data_id: Any, mask: Optional[torch.Tensor] = None):
 
         return log
 
+    @DeviceFunctionTimer.timer
     def update(self):
         # Update statistics
         for stat in self.opt.statistic["grad"]:
@@ -82,6 +85,7 @@ def update(self):
         self.log_saver.buffer_write(binfo=self.binfo)
         self.log_saver.flush()
 
+    @DeviceFunctionTimer.timer
     def _forward_hook_fn(
         self, module: nn.Module, inputs: Tuple[torch.Tensor], module_name: str
     ) -> None:
@@ -131,6 +135,7 @@ def _forward_hook_fn(
             cpu_offload=self.cpu_offload,
         )
 
+    @DeviceFunctionTimer.timer
     def _backward_hook_fn(
         self,
         module: nn.Module,
@@ -172,6 +177,7 @@ def _backward_hook_fn(
             cpu_offload=self.cpu_offload,
         )
 
+    @DeviceFunctionTimer.timer
     def _grad_hook_fn(
         self,
         module: nn.Module,
@@ -270,6 +276,7 @@ def _tensor_backward_hook_fn(self, grad: torch.Tensor, tensor_name: str) -> None:
             cpu_offload=self.cpu_offload,
         )
 
+    @DeviceFunctionTimer.timer
     def register_all_module_hooks(self) -> None:
         """
         Register all module hooks.
diff --git a/analog/monitor_util/__init__.py b/analog/monitor_util/__init__.py
new file mode 100644
index 00000000..45761224
--- /dev/null
+++ b/analog/monitor_util/__init__.py
@@ -0,0 +1,2 @@
+from .timer import FunctionTimer, Timer
+from .profiler import memory_profiler
diff --git a/analog/monitor_util/profiler.py b/analog/monitor_util/profiler.py
new file mode 100644
index 00000000..30066d16
--- /dev/null
+++ b/analog/monitor_util/profiler.py
@@ -0,0 +1,28 @@
+import torch
+import functools
+from torch.profiler import profile, ProfilerActivity
+
+
+def memory_profiler(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        activities = [ProfilerActivity.CPU]
+        if device.type == "cuda":
+            activities.append(ProfilerActivity.CUDA)
+
+        with profile(activities=activities, profile_memory=True) as prof:
+            result = func(*args, **kwargs)
+
+        print(
+            prof.key_averages().table(
+                sort_by=(
+                    "self_cuda_memory_usage"
+                    if device.type == "cuda"
+                    else "self_cpu_memory_usage"
+                )
+            )
+        )
+        return result
+
+    return wrapper
diff --git a/analog/monitor_util/timer.py b/analog/monitor_util/timer.py
new file mode 100644
index 00000000..fed99c20
--- /dev/null
+++ b/analog/monitor_util/timer.py
@@ -0,0 +1,178 @@
+import logging
+import time
+import functools
+
+import torch
+
+
+def get_gpu_memory(device_index=None):
+    return torch.cuda.memory_allocated(device_index)
+
+
+def get_gpu_max_memory(device_index=None):
+    return torch.cuda.max_memory_allocated(device_index)
+
+
+class FunctionTimer:
+    log = {}
+
+    @classmethod
+    def _wrap_function(cls, func, label, host_timer):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            if host_timer:
+                return cls._host_timer_wrapper(func, label, *args, **kwargs)
+            else:
+                return cls._device_timer_wrapper(func, label, *args, **kwargs)
+
+        return wrapper
+
+    @classmethod
+    def _host_timer_wrapper(cls, func, label, *args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        if label not in cls.log:
+            cls.log[label] = [
+                {
+                    "time_delta": end_time - start_time,
+                }
+            ]
+        else:
+            cls.log[label].append(
+                {
+                    "time_delta": end_time - start_time,
+                }
+            )
+        return result
+
+    @classmethod
+    def _device_timer_wrapper(cls, func, label, *args, **kwargs):
+        start_event = torch.cuda.Event(enable_timing=True)
+        start_event.record()
+        result = func(*args, **kwargs)
+        end_event = torch.cuda.Event(enable_timing=True)
+        end_event.record()
+        torch.cuda.current_stream().wait_event(end_event)
+        torch.cuda.synchronize()
+        if label not in cls.log:
+            cls.log[label] = [
+                {
+                    "time_delta": start_event.elapsed_time(end_event)
+                    / 1000,  # convert to seconds
+                }
+            ]
+        else:
+            cls.log[label].append(
+                {
+                    "time_delta": start_event.elapsed_time(end_event)
+                    / 1000,  # convert to seconds
+                }
+            )
+        return result
+
+    @classmethod
+    def timer(cls, label_or_func=None):
+        host_timer = getattr(
+            cls, "host_timer", False
+        )  # Fallback to False if not defined
+
+        def decorator(func):
+            label = label_or_func if isinstance(label_or_func, str) else func.__name__
+            return cls._wrap_function(func, label, host_timer)
+
+        if callable(label_or_func):
+            return decorator(label_or_func)
+        return decorator
+
+    @classmethod
+    def get_log(cls):
+        return cls.log
+
+    @classmethod
+    def print_log(cls):
+        print(
+            "###########################################################################"
+        )
+        print(
+            "################################ TIMER LOG ################################"
+        )
+        header = f"{'Label':<50} | {'Total Time (sec)':>20}"
+        print(header)
+        print("-" * len(header))
+        for label, details in cls.log.items():
+            sum_time = 0
+            for log_entry in details:
+                time_delta = log_entry.get("time_delta", 0)
+                sum_time += time_delta
+            # Truncate to 47 characters if the label is longer than 50.
+            display_label = (label[:47] + "...") if len(label) > 50 else label
+            row = f"{display_label:<50} | {sum_time:>20.4f}"
+            print(row)
+
+
+class HostFunctionTimer(FunctionTimer):
+    host_timer = True
+
+
+class DeviceFunctionTimer(FunctionTimer):
+    if torch.cuda.is_available():
+        host_timer = False
+    else:
+        logging.warning(
+            "CUDA is not available; falling back to the host (CPU) timer."
+        )
+        host_timer = True
+
+
+class Timer:
+    def __init__(self):
+        self.timers = {
+            "cpu": {},
+            "gpu": {},
+        }
+        self.timer_info = {}  # elapsed times, filled in after synchronization
+        self.is_synchronized = False
+
+    def start_timer(self, name, host_timer=False):
+        if host_timer:
+            if name in self.timers["cpu"]:
+                logging.warning(f"Timer for {name} already exists")
+                return
+            start_time = time.time()
+            self.timers["cpu"][name] = [start_time]
+        else:
+            if name in self.timers["gpu"]:
+                logging.warning(f"Timer for {name} already exists")
+                return
+            self.is_synchronized = False
+            start_event = torch.cuda.Event(enable_timing=True)
+            start_event.record()
+            self.timers["gpu"][name] = [start_event]
+
+    def stop_timer(self, name):
+        if name in self.timers["cpu"]:
+            end_time = time.time()
+            self.timers["cpu"][name].append(end_time)
+        if name in self.timers["gpu"]:
+            self.is_synchronized = False
+            end_event = torch.cuda.Event(enable_timing=True)
+            end_event.record()
+            self.timers["gpu"][name].append(end_event)
+
+    def _calculate_elapse_time(self):
+        for name, timer in self.timers["cpu"].items():
+            assert len(timer) == 2
+            self.timer_info[name] = (timer[1] - timer[0]) * 1000
+        if not self.is_synchronized:
+            for name, events in self.timers["gpu"].items():
+                assert len(events) == 2
+                torch.cuda.current_stream().wait_event(events[1])
+                torch.cuda.synchronize()
+                self.timer_info[name] = events[0].elapsed_time(events[1])
+            self.is_synchronized = True
+
+    def get_info(self):
+        if not self.is_synchronized:
+            self._calculate_elapse_time()
+        return self.timer_info
diff --git a/examples/mnist_influence/compute_influences.py b/examples/mnist_influence/compute_influences.py
index b84b8b11..bf974cff 100644
--- a/examples/mnist_influence/compute_influences.py
+++ b/examples/mnist_influence/compute_influences.py
@@ -11,6 +11,8 @@
     construct_mlp,
 )
 
+from analog.monitor_util import FunctionTimer
+
 parser = argparse.ArgumentParser("MNIST Influence Analysis")
 parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
 parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
@@ -79,6 +81,7 @@
 )
 _, top_influential_data = torch.topk(if_scores, k=10)
 
+FunctionTimer.print_log()
 # Save
 if_scores = if_scores.cpu().numpy().tolist()[0]
 torch.save(if_scores, "if_analog.pt")
diff --git a/requirements.txt b/requirements.txt
index 8a703764..89411e3a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,3 @@
 torch
 einops
 pyyaml
-
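For reference, a minimal usage sketch of the utilities introduced above, assuming the package is importable as `analog.monitor_util`. The `heavy_matmul` function, its tensor size, and the `"matmul_region"` label are hypothetical and only illustrate the intended call pattern.

```python
import torch

from analog.monitor_util import FunctionTimer, Timer, memory_profiler
from analog.monitor_util.timer import DeviceFunctionTimer


# Decorator-style timing: DeviceFunctionTimer records CUDA events when a GPU is
# present and falls back to time.time() otherwise; memory_profiler prints a
# torch.profiler memory table for the wrapped call.
@DeviceFunctionTimer.timer
@memory_profiler
def heavy_matmul(n: int = 1024) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    return a @ b


# Explicit start/stop timing for an arbitrary code region.
timer = Timer()
timer.start_timer("matmul_region", host_timer=not torch.cuda.is_available())
heavy_matmul()
timer.stop_timer("matmul_region")

print(timer.get_info())    # {"matmul_region": elapsed time in milliseconds}
FunctionTimer.print_log()  # per-label totals (seconds) from the decorators
```

Because `log` is a class attribute defined on `FunctionTimer`, timings recorded through `DeviceFunctionTimer.timer` are visible to `FunctionTimer.print_log()`, which is why the example script only imports `FunctionTimer`.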