diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py index ebc9242d49..06b341d987 100644 --- a/deepmd/pt/optimizer/LKF.py +++ b/deepmd/pt/optimizer/LKF.py @@ -3,11 +3,25 @@ import math import torch +import torch.distributed as dist from torch.optim.optimizer import ( Optimizer, ) -log = logging.getLogger(__name__) + +def distribute_indices(total_length, num_workers): + indices_per_worker = total_length // num_workers + remainder = total_length % num_workers + + indices = [] + start = 0 + + for i in range(num_workers): + end = start + indices_per_worker + (1 if i < remainder else 0) + indices.append((start, end)) + start = end + + return indices, remainder class LKFOptimizer(Optimizer): @@ -18,11 +32,8 @@ def __init__( kalman_nue=0.9987, block_size=5120, ): - defaults = { - "lr": 0.1, - "kalman_nue": kalman_nue, - "block_size": block_size, - } + defaults = {"lr": 0.1, "kalman_nue": kalman_nue, "block_size": block_size} + super().__init__(params, defaults) self._params = self.param_groups[0]["params"] @@ -36,7 +47,10 @@ def __init__( # the first param, because this helps with casting in load_state_dict self._state = self.state[self._params[0]] self._state.setdefault("kalman_lambda", kalman_lambda) - + self.dist_init = dist.is_initialized() + self.rank = dist.get_rank() if self.dist_init else 0 + self.dindex = [] + self.remainder = 0 self.__init_P() def __init_P(self): @@ -61,32 +75,84 @@ def __init_P(self): P = [] params_packed_index = [] - log.info("LKF parameter nums: %s" % param_nums) - for param_num in param_nums: - if param_num >= block_size: - block_num = math.ceil(param_num / block_size) - for i in range(block_num): - if i != block_num - 1: + logging.info("LKF parameter nums: %s" % param_nums) + if self.dist_init: + block_num = 0 + for param_num in param_nums: + if param_num >= block_size: + block_num += math.ceil(param_num / block_size) + else: + block_num += 1 + num_workers = dist.get_world_size() + self.dindex, self.remainder = distribute_indices(block_num, num_workers) + index = 0 + for param_num in param_nums: + if param_num >= block_size: + block_num = math.ceil(param_num / block_size) + for i in range(block_num): + device_id = self.get_device_id(index) + index += 1 + dist_device = torch.device("cuda:" + str(device_id)) + if i != block_num - 1: + params_packed_index.append(block_size) + if self.rank == device_id: + P.append( + torch.eye( + block_size, + dtype=data_type, + device=dist_device, + ) + ) + else: + continue + else: + params_packed_index.append(param_num - block_size * i) + if self.rank == device_id: + P.append( + torch.eye( + param_num - block_size * i, + dtype=data_type, + device=dist_device, + ) + ) + else: + continue + + else: + device_id = self.get_device_id(index) + index += 1 + params_packed_index.append(param_num) + if self.rank == device_id: + dist_device = torch.device("cuda:" + str(device_id)) P.append( - torch.eye( - block_size, - dtype=data_type, - device=device, - ) + torch.eye(param_num, dtype=data_type, device=dist_device) ) - params_packed_index.append(block_size) - else: - P.append( - torch.eye( - param_num - block_size * i, - dtype=data_type, - device=device, + else: + for param_num in param_nums: + if param_num >= block_size: + block_num = math.ceil(param_num / block_size) + for i in range(block_num): + if i != block_num - 1: + P.append( + torch.eye( + block_size, + dtype=data_type, + device=device, + ) ) - ) - params_packed_index.append(param_num - block_size * i) - else: - P.append(torch.eye(param_num, dtype=data_type, device=device)) - params_packed_index.append(param_num) + params_packed_index.append(block_size) + else: + P.append( + torch.eye( + param_num - block_size * i, + dtype=data_type, + device=device, + ) + ) + params_packed_index.append(param_num - block_size * i) + else: + P.append(torch.eye(param_num, dtype=data_type, device=device)) + params_packed_index.append(param_num) self._state.setdefault("P", P) self._state.setdefault("weights_num", len(P)) @@ -125,16 +191,35 @@ def __update(self, H, error, weights): tmp = 0 for i in range(weights_num): tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i])) - + if self.dist_init: + dist.all_reduce(tmp, op=dist.ReduceOp.SUM) A = 1 / tmp - for i in range(weights_num): K = torch.matmul(P[i], H[i]) weights[i] = weights[i] + A * error * K P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T)) - + if self.dist_init: + device = torch.device("cuda:" + str(self.rank)) + local_shape = [tensor.shape[0] for tensor in weights] + shape_list = [ + torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather_object(shape_list, local_shape) + weight_tensor = torch.cat(weights) + world_shape = [sum(inner_list) for inner_list in shape_list] + weight_list = [None] * len(world_shape) + for i in range(len(world_shape)): + weight_list[i] = torch.zeros( + world_shape[i], dtype=torch.float64, device=device + ) + dist.all_gather(weight_list, weight_tensor) + result = [] + for i in range(dist.get_world_size()): + result = result + list(torch.split(weight_list[i], shape_list[i])) + weights = result kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue self._state.update({"kalman_lambda": kalman_lambda}) @@ -215,9 +300,21 @@ def step(self, error): param_sum += nelement if param_sum == params_packed_index[param_index]: - H.append(res_grad) - weights.append(res) param_sum = 0 + if self.dist_init: + device_id = self.get_device_id(param_index) + if self.rank == device_id: + weights.append(res) + H.append(res_grad) + else: + weights.append(res) + H.append(res_grad) param_index += 1 self.__update(H, error, weights) + + def get_device_id(self, index): + for i, (start, end) in enumerate(self.dindex): + if start <= index < end: + return i + return None