Commit

ekfac wip
sangkeun00 committed Nov 22, 2023
1 parent 576f3b9 commit 72e2145
Showing 4 changed files with 142 additions and 27 deletions.
54 changes: 48 additions & 6 deletions analog/analog.py
@@ -111,6 +111,16 @@ def add_lora(
watch: bool = True,
clear: bool = True,
) -> None:
"""
Adds LoRA for gradient compression.
Args:
model: The neural network model.
parameter_sharing (bool, optional): Whether to use parameter sharing or not.
parameter_sharing_groups (list, optional): List of parameter sharing groups.
watch (bool, optional): Whether to watch the model or not.
clear (bool, optional): Whether to clear the internal states or not.
"""
hessian_state = self.hessian_handler.get_hessian_state()
self.lora_handler.add_lora(
model=model,
@@ -159,7 +169,9 @@ def remove_analysis(self, analysis_name: str) -> None:
analysis_name (str): Name of the analysis to be removed.
"""
if analysis_name not in self.analysis_plugins:
print(f"Analysis {analysis_name} does not exist. Nothing to remove.")
get_logger().warning(
f"Analysis {analysis_name} does not exist. Nothing to remove."
)
return None
del self.analysis_plugins[analysis_name]
delattr(self, analysis_name)
@@ -171,6 +183,7 @@ def __call__(
hessian: bool = True,
save: bool = False,
test: bool = False,
strategy: Optional[str] = None,
):
"""
Args:
@@ -183,11 +196,14 @@
Returns:
self: Returns the instance of the AnaLog object.
"""
self.data_id = data_id
self.log = log
self.hessian = hessian if not test else False
self.save = save if not test else False
self.test = test
if strategy is None:
self.data_id = data_id
self.log = log
self.hessian = hessian if not test else False
self.save = save if not test else False
self.test = test
else:
self.parse_strategy(strategy)

self.sanity_check(self.data_id, self.log, self.test)

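For context, a minimal usage sketch of the new strategy argument added to __call__ above. The constructor arguments, the watch() call, and the context-manager usage are assumptions based on the docstring (which says the call returns the AnaLog instance), not something this diff confirms.

    # Hypothetical usage sketch; names outside this diff are illustrative.
    from analog import AnaLog  # assumed import path

    analog = AnaLog(project="example")               # assumed constructor
    analog.watch(model)                              # assumed API for registering hooks
    with analog(data_id=batch_ids, strategy="train"):
        loss = model(inputs).sum()
        loss.backward()
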
@@ -342,6 +358,22 @@ def finalize(
self.hessian_handler.clear()
self.storage_handler.clear()

def parse_strategy(self, strategy: str) -> None:
"""
Parses the strategy string to set the internal states.
Args:
strategy (str): The strategy string.
"""
strategy = strategy.lower()
if strategy == "train":
self.log = [FORWARD, BACKWARD]
self.hessian = True
self.save = False
self.test = False
else:
raise ValueError(f"Unknown strategy: {strategy}")

def sanity_check(
self, data_id: Iterable[Any], log: Iterable[str], test: bool
) -> None:
@@ -355,6 +387,16 @@ def sanity_check(
if GRAD in log and len(log) > 1:
raise ValueError("Cannot log 'grad' with other log types.")

def ekfac(self, on: bool = True) -> None:
"""
Enable or disable the EKFAC approximation of the Hessian.
"""
assert self.hessian_handler.config.get("type", "kfac") == "kfac"
if on:
self.hessian_handler.ekfac = True
else:
self.hessian_handler.ekfac = False

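A plausible two-pass workflow for the new toggle, inferred from this diff: the first pass accumulates K-FAC covariances through the hooks, and after ekfac() is switched on, a second pass fits the corrected eigenvalues in on_exit. The surrounding loop and call signature are assumptions, not taken from this diff.

    # Illustrative only: the data loop and call arguments are assumed.
    for data_id, batch in loader:                    # pass 1: accumulate K-FAC covariances
        with analog(data_id=data_id, hessian=True):
            model(batch).sum().backward()
    analog.finalize()

    analog.ekfac()                                   # switch the KFAC handler to EKFAC mode
    for data_id, batch in loader:                    # pass 2: fit EKFAC eigenvalues on exit
        with analog(data_id=data_id, hessian=True):
            model(batch).sum().backward()
    analog.finalize()
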
def reset(self) -> None:
"""
Reset the internal states.
108 changes: 90 additions & 18 deletions analog/hessian/kfac.py
@@ -17,6 +17,9 @@ class KFACHessianHandler(HessianHandlerBase):
"""
Compute the Hessian via the K-FAC method.
"""
def __init__(self, config: dict) -> None:
super().__init__(config)
self.ekfac = False

def parse_config(self) -> None:
self.damping = self.config.get("damping", 1e-2)
@@ -27,6 +30,10 @@ def on_exit(self, current_log=None) -> None:
if self.reduce:
raise NotImplementedError

if self.ekfac:
for module_name, module_grad in self.modules_to_hook.items():
self.update_ekfac(module_name, module_grad)

@torch.no_grad()
def update_hessian(
self,
@@ -36,28 +43,57 @@ def update_hessian(
data: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> None:
if not self.reduce:
# extract activations
activation = self.extract_activations(module, mode, data, mask)
if self.reduce or self.ekfac:
return
# extract activations
activation = self.extract_activations(module, mode, data, mask)

# compute covariance
covariance = torch.matmul(torch.t(activation), activation).cpu().detach()
# compute covariance
covariance = torch.matmul(torch.t(activation), activation).cpu().detach()

# update covariance
if deep_get(self.hessian_state, [module_name, mode]) is None:
self.hessian_state[module_name][mode] = torch.zeros_like(covariance)
self.sample_counter[module_name][mode] = 0
self.hessian_state[module_name][mode].add_(covariance)
self.sample_counter[module_name][mode] += self.get_sample_size(data, mask)
# update covariance
if deep_get(self.hessian_state, [module_name, mode]) is None:
self.hessian_state[module_name][mode] = torch.zeros_like(covariance)
self.sample_counter[module_name][mode] = 0
self.hessian_state[module_name][mode].add_(covariance)
self.sample_counter[module_name][mode] += self.get_sample_size(data, mask)

@torch.no_grad()
def update_ekfac(
self,
module_name: str,
data: torch.Tensor,
) -> None:
if not hasattr(self, "hessian_eigval_state"):
self.hessian_svd(set_attr=True)
if not hasattr(self, "ekfac_eigval_state"):
self.ekfac_eigval_state = nested_dict()
self.ekfac_counter = nested_dict()

if module_name not in self.ekfac_eigval_state:
self.ekfac_eigval_state[module_name] = torch.zeros(0, 0)

self.ekfac_counter[module_name] += len(data)
rotated_grads = torch.matmul(data, self.hessian_eigvec_state[module_name][FORWARD])
for rotated_grad in rotated_grads:
weight = torch.matmul(
self.hessian_eigvec_state[module_name][BACKWARD], rotated_grad
)
self.ekfac_eigval_state[module_name].add_(torch.square(weight))

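For reference, a self-contained numeric sketch of the diagonal rescaling that update_ekfac implements: EKFAC replaces the Kronecker-factored eigenvalues with second moments of per-sample gradients projected into the K-FAC eigenbasis. The shapes and the transposition convention below are assumptions, not taken from this diff.

    import torch

    # Toy stand-ins for a Linear layer's per-sample weight gradients G_i (out_dim x in_dim).
    out_dim, in_dim, n_samples = 4, 3, 32
    per_sample_grads = torch.randn(n_samples, out_dim, in_dim)

    # Placeholder eigenvector bases of the two K-FAC factors (activation / gradient covariances).
    q_fwd = torch.linalg.qr(torch.randn(in_dim, in_dim)).Q
    q_bwd = torch.linalg.qr(torch.randn(out_dim, out_dim)).Q

    # EKFAC eigenvalues: lambda_ij = E[ (Q_bwd^T G Q_fwd)_ij ** 2 ], matching the
    # square-and-accumulate loop in update_ekfac above (up to the averaging done in finalize).
    rotated = q_bwd.T @ per_sample_grads @ q_fwd    # (n_samples, out_dim, in_dim)
    ekfac_eigvals = rotated.pow(2).mean(dim=0)
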
def finalize(self) -> None:
for module_name, module_state in self.hessian_state.items():
for mode, covariance in module_state.items():
covariance.div_(self.sample_counter[module_name][mode])
if self.ekfac:
for module_name, ekfac_eigval in self.ekfac_eigval_state.items():
ekfac_eigval.div_(self.ekfac_counter[module_name])
else:
for module_name, module_state in self.hessian_state.items():
for mode, covariance in module_state.items():
covariance.div_(self.sample_counter[module_name][mode])

self.synchronize()

@torch.no_grad()
def hessian_inverse(self):
def hessian_inverse(self, set_attr: bool = False):
"""
Compute the inverse of the covariance.
"""
@@ -71,10 +107,14 @@ def hessian_inverse(self):
* torch.eye(covariance.size(0))
/ covariance.size(0)
)

if set_attr:
self.hessian_inverse_state = hessian_inverse_state

return hessian_inverse_state

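The hunk above only shows the tail of the damped inversion; below is a sketch of the full expression it appears to compute. The exact damping formula is an assumption based on the visible torch.eye(...) / covariance.size(0) fragment and the damping value parsed in parse_config, not confirmed by this diff.

    import torch

    def damped_inverse(covariance: torch.Tensor, damping: float = 1e-2) -> torch.Tensor:
        # Add a scaled identity (damping times the mean eigenvalue, i.e. trace / dim)
        # before inverting, keeping the inverse well-conditioned near singularity.
        dim = covariance.size(0)
        regularizer = damping * torch.trace(covariance) * torch.eye(dim) / dim
        return torch.inverse(covariance + regularizer)
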
@torch.no_grad()
def hessian_svd(self):
def hessian_svd(self, set_attr: bool = False):
"""
Compute the SVD of the covariance.
"""
@@ -86,19 +126,51 @@ def hessian_svd(self):
hessian_eigval_state[module_name][mode] = eigvals
hessian_eigvec_state[module_name][mode] = eigvecs

if set_attr:
self.hessian_eigval_state = hessian_eigval_state
self.hessian_eigvec_state = hessian_eigvec_state

return hessian_eigval_state, hessian_eigvec_state

def get_hessian_inverse_state(self):
if not hasattr(self, "hessian_inverse_state"):
self.hessian_inverse(set_attr=True)
return self.hessian_inverse_state

def get_hessian_svd_state(self):
if not hasattr(self, "hessian_eigval_state"):
self.hessian_svd(set_attr=True)
return self.hessian_eigval_state, self.hessian_eigvec_state

def synchronize(self) -> None:
"""
Synchronize the covariance across all processes.
"""
world_size = get_world_size()
if world_size > 1:
if get_world_size() <= 1:
return

if self.ekfac:
for _, ekfac_eigval in self.ekfac_eigval_state.items():
ekfac_eigval.div_(world_size)
dist.all_reduce(ekfac_eigval, op=dist.ReduceOp.SUM)
else:
for _, module_state in self.hessian_state.items():
for _, covariance in module_state.items():
covariance.div_(world_size)
dist.all_reduce(covariance, op=dist.ReduceOp.SUM)

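The divide-then-SUM pattern in synchronize computes a cross-rank average; a minimal standalone sketch of the same pattern, assuming an already-initialized process group:

    import torch
    import torch.distributed as dist

    def all_reduce_mean_(tensor: torch.Tensor) -> torch.Tensor:
        # Divide locally, then SUM across ranks: every rank ends up with the mean.
        world_size = dist.get_world_size()
        tensor.div_(world_size)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return tensor
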
def clear(self) -> None:
"""
Clear the Hessian state.
"""
super().clear()
if hasattr(self, "hessian_eigval_state"):
del self.hessian_eigval_state
del self.hessian_eigvec_state
if hasattr(self, "ekfac_eigval_state"):
del self.ekfac_eigval_state
del self.ekfac_counter

def extract_activations(
self,
module: nn.Module,
5 changes: 3 additions & 2 deletions analog/logging.py
@@ -40,6 +40,7 @@ def __init__(

self.storage_handler = storage_handler
self.hessian_handler = hessian_handler
self.hessian_type = hessian_handler.config.get("type", "kfac")

# Internal states
self.log = None
@@ -69,7 +70,7 @@ def _forward_hook_fn(
"""
assert len(inputs) == 1

if self.hessian:
if self.hessian and self.hessian_type == "kfac":
self.hessian_handler.update_hessian(module, module_name, FORWARD, inputs[0])

if FORWARD in self.log:
@@ -98,7 +99,7 @@ def _backward_hook_fn(
"""
assert len(grad_outputs) == 1

if self.hessian:
if self.hessian and self.hessian_type == "kfac":
self.hessian_handler.update_hessian(
module, module_name, BACKWARD, grad_outputs[0]
)
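To illustrate why the new hessian_type == "kfac" guard matters (EKFAC fits its eigenvalues in on_exit rather than per hook call), here is a self-contained sketch of a forward hook that accumulates an activation covariance only in K-FAC mode; this is not the library's actual hook code.

    import torch
    import torch.nn as nn

    hessian_type = "kfac"            # stand-in for hessian_handler.config.get("type", "kfac")
    forward_cov = torch.zeros(3, 3)

    def forward_hook(module, inputs, output):
        if hessian_type == "kfac":   # mirrors the guard added in _forward_hook_fn above
            activation = inputs[0].detach()
            forward_cov.add_(activation.t() @ activation)

    layer = nn.Linear(3, 2)
    layer.register_forward_hook(forward_hook)
    layer(torch.randn(5, 3))
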
2 changes: 1 addition & 1 deletion analog/lora/modules.py
@@ -31,7 +31,7 @@ def __init__(self, rank: int, linear: nn.Linear, shared_module: nn.Linear = None

self._linear = linear

def forward(self, input) -> torch.Tensor:
def forward(self, input: torch.Tensor) -> torch.Tensor:
result = self._linear(input)
result += self.analog_lora_C(self.analog_lora_B(self.analog_lora_A(input)))

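For reference, a self-contained sketch of the three-factor low-rank module whose forward pass is annotated above. The layer names mirror analog_lora_A/B/C, but the constructor shown here (factor shapes, zero-init of the last factor) is an assumption rather than the module's actual __init__.

    import torch
    import torch.nn as nn

    class LoraLinearSketch(nn.Module):
        def __init__(self, rank: int, linear: nn.Linear):
            super().__init__()
            self._linear = linear                                          # frozen base layer
            self.analog_lora_A = nn.Linear(linear.in_features, rank, bias=False)
            self.analog_lora_B = nn.Linear(rank, rank, bias=False)
            self.analog_lora_C = nn.Linear(rank, linear.out_features, bias=False)
            nn.init.zeros_(self.analog_lora_C.weight)                      # start as a no-op branch

        def forward(self, input: torch.Tensor) -> torch.Tensor:
            result = self._linear(input)
            result += self.analog_lora_C(self.analog_lora_B(self.analog_lora_A(input)))
            return result
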
