diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 685844070d..3e0c1420e1 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -1,10 +1,12 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
 import logging
-import torch
-import numpy as np
-from torch.optim.optimizer import Optimizer
 import math
+
+import torch
 import torch.distributed as dist
-from torch.profiler import profile, record_function, ProfilerActivity
+from torch.optim.optimizer import (
+    Optimizer,
+)
 
 
 def distribute_indices(total_length, num_workers):
@@ -19,7 +21,9 @@ def distribute_indices(total_length, num_workers):
         indices.append((start, end))
         start = end
 
-    return indices,remainder
+    return indices, remainder
+
+
 class LKFOptimizer(Optimizer):
     def __init__(
         self,
@@ -28,7 +32,6 @@ def __init__(
         kalman_nue=0.9987,
         block_size=5120,
     ):
-
         defaults = dict(
             lr=0.1,
             kalman_nue=kalman_nue,
@@ -55,7 +58,6 @@ def __init__(
         self.__init_P()
 
     def __init_P(self):
-
         param_nums = []
         param_sum = 0
         block_size = self.__get_blocksize()
@@ -79,21 +81,21 @@ def __init_P(self):
         params_packed_index = []
         logging.info("LKF parameter nums: %s" % param_nums)
         if self.dist_init:
-            block_num = 0;
+            block_num = 0
             for param_num in param_nums:
                 if param_num >= block_size:
                     block_num += math.ceil(param_num / block_size)
-                else:
-                    block_num +=1
+                else:
+                    block_num += 1
             num_workers = dist.get_world_size()
-            self.dindex,self.remainder = distribute_indices(block_num,num_workers)
+            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
             index = 0
             for param_num in param_nums:
                 if param_num >= block_size:
                     block_num = math.ceil(param_num / block_size)
                     for i in range(block_num):
                         device_id = self.get_device_id(index)
-                        index +=1
+                        index += 1
                         dist_device = torch.device("cuda:" + str(device_id))
                         if i != block_num - 1:
                             params_packed_index.append(block_size)
@@ -106,7 +108,7 @@ def __init_P(self):
                                     )
                                 )
                             else:
-                                continue
+                                continue
                         else:
                             params_packed_index.append(param_num - block_size * i)
                             if self.rank == device_id:
@@ -120,13 +122,15 @@ def __init_P(self):
                             else:
                                 continue
-                else:
+                else:
                     device_id = self.get_device_id(index)
-                    index +=1
+                    index += 1
                     params_packed_index.append(param_num)
                     if self.rank == device_id:
                         dist_device = torch.device("cuda:" + str(device_id))
-                        P.append(torch.eye(param_num, dtype=data_type, device=dist_device))
+                        P.append(
+                            torch.eye(param_num, dtype=data_type, device=dist_device)
+                        )
 
             device_id = self.rank
         else:
             for param_num in param_nums:
@@ -151,7 +155,7 @@ def __init_P(self):
                                 )
                             )
                             params_packed_index.append(param_num - block_size * i)
-                else:
+                else:
                     P.append(torch.eye(param_num, dtype=data_type, device=device))
                     params_packed_index.append(param_num)
 
@@ -206,13 +210,18 @@ def __update(self, H, error, weights):
         if self.dist_init:
             device = torch.device("cuda:" + str(self.rank))
             local_shape = [tensor.shape[0] for tensor in weights]
-            shape_list = [torch.zeros_like(torch.empty(1),dtype=torch.float64,device=device) for _ in range(dist.get_world_size())]
+            shape_list = [
+                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
+                for _ in range(dist.get_world_size())
+            ]
             dist.all_gather_object(shape_list, local_shape)
             weight_tensor = torch.cat(weights)
             world_shape = [sum(inner_list) for inner_list in shape_list]
             weight_list = [None] * len(world_shape)
             for i in range(len(world_shape)):
-                weight_list[i] = torch.zeros(world_shape[i],dtype=torch.float64,device=device)
+                weight_list[i] = torch.zeros(
+                    world_shape[i], dtype=torch.float64, device=device
+                )
             dist.all_gather(weight_list, weight_tensor)
             result = []
             for i in range(dist.get_world_size()):
@@ -255,7 +264,6 @@ def set_grad_prefactor(self, grad_prefactor):
         self.grad_prefactor = grad_prefactor
 
     def step(self, error):
-
        params_packed_index = self._state.get("params_packed_index")

        weights = []
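
Reviewer note: only the tail of `distribute_indices` appears in the hunks above (`indices.append((start, end))`, `start = end`, `return indices, remainder`). For context, here is a minimal sketch of the even-split-with-remainder pattern those lines imply; everything before the loop is an assumption reconstructed from the visible tail, not code taken from this patch.

```python
def distribute_indices(total_length, num_workers):
    # Assumed setup: split `total_length` covariance blocks as evenly as
    # possible across `num_workers` ranks (only the lines inside the loop
    # and the return statement are visible in the diff).
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers
    indices = []
    start = 0
    for i in range(num_workers):
        # The first `remainder` workers each take one extra block.
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end
    return indices, remainder

# 10 P-blocks over 4 workers -> [(0, 3), (3, 6), (6, 8), (8, 10)], remainder 2
print(distribute_indices(10, 4))
```

This matches how `__init_P` consumes the result: `self.dindex` holds one contiguous `(start, end)` range of block indices per rank, and `get_device_id(index)` presumably maps a block index to the rank whose range contains it, so each rank only allocates the `P` blocks it owns.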