From 4077515b33351462e8393c11b98dc6aa88c042f7 Mon Sep 17 00:00:00 2001
From: CaRoLZhangxy
Date: Fri, 8 Mar 2024 07:36:48 +0000
Subject: [PATCH 1/7] add distributed form of lkf

---
 deepmd/pt/optimizer/LKF.py | 180 ++++++++++++++++++++++++++++---------
 1 file changed, 138 insertions(+), 42 deletions(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index ebc9242d49..685844070d 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -1,15 +1,25 @@
-# SPDX-License-Identifier: LGPL-3.0-or-later
 import logging
+import torch
+import numpy as np
+from torch.optim.optimizer import Optimizer
 import math
+import torch.distributed as dist
+from torch.profiler import profile, record_function, ProfilerActivity
 
-import torch
-from torch.optim.optimizer import (
-    Optimizer,
-)
 
-log = logging.getLogger(__name__)
+def distribute_indices(total_length, num_workers):
+    indices_per_worker = total_length // num_workers
+    remainder = total_length % num_workers
+
+    indices = []
+    start = 0
+    for i in range(num_workers):
+        end = start + indices_per_worker + (1 if i < remainder else 0)
+        indices.append((start, end))
+        start = end
 
+    return indices,remainder
 
 class LKFOptimizer(Optimizer):
     def __init__(
@@ -18,12 +28,14 @@ def __init__(
         kalman_nue=0.9987,
         block_size=5120,
     ):
-        defaults = {
-            "lr": 0.1,
-            "kalman_nue": kalman_nue,
-            "block_size": block_size,
-        }
-        super().__init__(params, defaults)
+
+        defaults = dict(
+            lr=0.1,
+            kalman_nue=kalman_nue,
+            block_size=block_size,
+        )
+
+        super(LKFOptimizer, self).__init__(params, defaults)
 
         self._params = self.param_groups[0]["params"]
 
@@ -36,10 +48,14 @@ def __init__(
         # the first param, because this helps with casting in load_state_dict
         self._state = self.state[self._params[0]]
         self._state.setdefault("kalman_lambda", kalman_lambda)
-
+        self.dist_init = dist.is_initialized()
+        self.rank = dist.get_rank() if self.dist_init else 0
+        self.dindex = []
+        self.remainder = 0
         self.__init_P()
 
     def __init_P(self):
+
         param_nums = []
         param_sum = 0
         block_size = self.__get_blocksize()
@@ -61,32 +77,83 @@ def __init_P(self):
 
         P = []
         params_packed_index = []
-        log.info("LKF parameter nums: %s" % param_nums)
-        for param_num in param_nums:
-            if param_num >= block_size:
-                block_num = math.ceil(param_num / block_size)
-                for i in range(block_num):
-                    if i != block_num - 1:
-                        P.append(
-                            torch.eye(
-                                block_size,
-                                dtype=data_type,
-                                device=device,
+        logging.info("LKF parameter nums: %s" % param_nums)
+        if self.dist_init:
+            block_num = 0;
+            for param_num in param_nums:
+                if param_num >= block_size:
+                    block_num += math.ceil(param_num / block_size)
+                else:
+                    block_num +=1
+            num_workers = dist.get_world_size()
+            self.dindex,self.remainder = distribute_indices(block_num,num_workers)
+            index = 0
+            for param_num in param_nums:
+                if param_num >= block_size:
+                    block_num = math.ceil(param_num / block_size)
+                    for i in range(block_num):
+                        device_id = self.get_device_id(index)
+                        index +=1
+                        dist_device = torch.device("cuda:" + str(device_id))
+                        if i != block_num - 1:
+                            params_packed_index.append(block_size)
+                            if self.rank == device_id:
+                                P.append(
+                                    torch.eye(
+                                        block_size,
+                                        dtype=data_type,
+                                        device=dist_device,
+                                    )
+                                )
+                            else:
+                                continue
+                        else:
+                            params_packed_index.append(param_num - block_size * i)
+                            if self.rank == device_id:
+                                P.append(
+                                    torch.eye(
+                                        param_num - block_size * i,
+                                        dtype=data_type,
+                                        device=dist_device,
+                                    )
+                                )
+                            else:
+                                continue
+
+                else:
+                    device_id = self.get_device_id(index)
+                    index +=1
+                    params_packed_index.append(param_num)
+                    if self.rank == device_id:
+                        dist_device = torch.device("cuda:" + str(device_id))
+                        P.append(torch.eye(param_num, dtype=data_type, device=dist_device))
+                        device_id = self.rank
+        else:
+            for param_num in param_nums:
+                if param_num >= block_size:
+                    block_num = math.ceil(param_num / block_size)
+                    for i in range(block_num):
+                        if i != block_num - 1:
+                            P.append(
+                                torch.eye(
+                                    block_size,
+                                    dtype=data_type,
+                                    device=device,
+                                )
                             )
-                        )
-                        params_packed_index.append(block_size)
-                    else:
-                        P.append(
-                            torch.eye(
-                                param_num - block_size * i,
-                                dtype=data_type,
-                                device=device,
+                            params_packed_index.append(block_size)
+                        else:
+                            P.append(
+                                torch.eye(
+                                    param_num - block_size * i,
+                                    dtype=data_type,
+                                    device=device,
+                                )
                             )
-                        )
-                        params_packed_index.append(param_num - block_size * i)
-            else:
-                P.append(torch.eye(param_num, dtype=data_type, device=device))
-                params_packed_index.append(param_num)
+                            params_packed_index.append(param_num - block_size * i)
+                else:
+                    P.append(torch.eye(param_num, dtype=data_type, device=device))
+                    params_packed_index.append(param_num)
 
         self._state.setdefault("P", P)
         self._state.setdefault("weights_num", len(P))
@@ -115,6 +182,8 @@ def __split_weights(self, weight):
 
     def __update(self, H, error, weights):
         P = self._state.get("P")
+        # for item in P:
+        #     print(self.rank," size ",item.shape)
         kalman_lambda = self._state.get("kalman_lambda")
         weights_num = self._state.get("weights_num")
         params_packed_index = self._state.get("params_packed_index")
@@ -125,16 +194,30 @@ def __update(self, H, error, weights):
         tmp = 0
         for i in range(weights_num):
             tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i]))
-
+        if self.dist_init:
+            dist.all_reduce(tmp, op=dist.ReduceOp.SUM)
         A = 1 / tmp
-
         for i in range(weights_num):
             K = torch.matmul(P[i], H[i])
 
             weights[i] = weights[i] + A * error * K
 
             P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T))
-
+        if self.dist_init:
+            device = torch.device("cuda:" + str(self.rank))
+            local_shape = [tensor.shape[0] for tensor in weights]
+            shape_list = [torch.zeros_like(torch.empty(1),dtype=torch.float64,device=device) for _ in range(dist.get_world_size())]
+            dist.all_gather_object(shape_list, local_shape)
+            weight_tensor = torch.cat(weights)
+            world_shape = [sum(inner_list) for inner_list in shape_list]
+            weight_list = [None] * len(world_shape)
+            for i in range(len(world_shape)):
+                weight_list[i] = torch.zeros(world_shape[i],dtype=torch.float64,device=device)
+            dist.all_gather(weight_list, weight_tensor)
+            result = []
+            for i in range(dist.get_world_size()):
+                result = result + list(torch.split(weight_list[i], shape_list[i]))
+            weights = result
         kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
         self._state.update({"kalman_lambda": kalman_lambda})
 
@@ -172,6 +255,7 @@ def set_grad_prefactor(self, grad_prefactor):
         self.grad_prefactor = grad_prefactor
 
     def step(self, error):
+
         params_packed_index = self._state.get("params_packed_index")
 
         weights = []
@@ -215,9 +299,21 @@ def step(self, error):
             param_sum += nelement
 
             if param_sum == params_packed_index[param_index]:
-                H.append(res_grad)
-                weights.append(res)
                 param_sum = 0
+                if self.dist_init:
+                    device_id = self.get_device_id(param_index)
+                    if self.rank == device_id:
+                        weights.append(res)
+                        H.append(res_grad)
+                else:
+                    weights.append(res)
+                    H.append(res_grad)
                 param_index += 1
 
         self.__update(H, error, weights)
+
+    def get_device_id(self, index):
+        for i, (start, end) in enumerate(self.dindex):
+            if start <= index < end:
+                return i
+        return None
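
The two helpers added in PATCH 1/7 fully determine which rank owns which block of P: distribute_indices deals out contiguous ranges of block indices, giving each of the first `remainder` workers one extra block, and get_device_id inverts that mapping. The standalone sketch below is not part of the patch series; distribute_indices is copied from the patch, while get_device_id is rewritten as a free function that takes dindex explicitly instead of reading self.dindex.

def distribute_indices(total_length, num_workers):
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers

    indices = []
    start = 0
    for i in range(num_workers):
        # the first `remainder` workers take one extra block
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end

    return indices, remainder


def get_device_id(dindex, index):
    # find the rank whose contiguous range owns block `index`
    for i, (start, end) in enumerate(dindex):
        if start <= index < end:
            return i
    return None


dindex, remainder = distribute_indices(10, 4)
print(dindex, remainder)         # [(0, 3), (3, 6), (6, 8), (8, 10)] 2
print(get_device_id(dindex, 7))  # 2: block 7 lives on rank 2
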
From 3a355f7f0fb32fb063b07a7c85f1ddbdb3adcd89 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 8 Mar 2024 07:40:51 +0000
Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/pt/optimizer/LKF.py | 48 ++++++++++++++++++++++----------------
 1 file changed, 28 insertions(+), 20 deletions(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 685844070d..3e0c1420e1 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -1,10 +1,12 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
 import logging
-import torch
-import numpy as np
-from torch.optim.optimizer import Optimizer
 import math
+
+import torch
 import torch.distributed as dist
-from torch.profiler import profile, record_function, ProfilerActivity
+from torch.optim.optimizer import (
+    Optimizer,
+)
 
 
 def distribute_indices(total_length, num_workers):
@@ -19,7 +21,9 @@ def distribute_indices(total_length, num_workers):
         indices.append((start, end))
         start = end
 
-    return indices,remainder
+    return indices, remainder
+
+
 class LKFOptimizer(Optimizer):
     def __init__(
@@ -28,7 +32,6 @@ def __init__(
         kalman_nue=0.9987,
         block_size=5120,
     ):
-
         defaults = dict(
             lr=0.1,
             kalman_nue=kalman_nue,
@@ -55,7 +58,6 @@ def __init__(
         self.__init_P()
 
     def __init_P(self):
-
         param_nums = []
         param_sum = 0
         block_size = self.__get_blocksize()
@@ -79,21 +81,21 @@ def __init_P(self):
         params_packed_index = []
         logging.info("LKF parameter nums: %s" % param_nums)
         if self.dist_init:
-            block_num = 0;
+            block_num = 0
             for param_num in param_nums:
                 if param_num >= block_size:
                     block_num += math.ceil(param_num / block_size)
-                else:
-                    block_num +=1
+                else:
+                    block_num += 1
             num_workers = dist.get_world_size()
-            self.dindex,self.remainder = distribute_indices(block_num,num_workers)
+            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
             index = 0
             for param_num in param_nums:
                 if param_num >= block_size:
                     block_num = math.ceil(param_num / block_size)
                     for i in range(block_num):
                         device_id = self.get_device_id(index)
-                        index +=1
+                        index += 1
                         dist_device = torch.device("cuda:" + str(device_id))
                         if i != block_num - 1:
                             params_packed_index.append(block_size)
@@ -106,7 +108,7 @@ def __init_P(self):
                                     )
                                 )
                             else:
-                                continue
+                                continue
                         else:
                             params_packed_index.append(param_num - block_size * i)
                             if self.rank == device_id:
@@ -120,13 +122,15 @@ def __init_P(self):
                             else:
                                 continue
 
-                else:
+                else:
                     device_id = self.get_device_id(index)
-                    index +=1
+                    index += 1
                     params_packed_index.append(param_num)
                     if self.rank == device_id:
                         dist_device = torch.device("cuda:" + str(device_id))
-                        P.append(torch.eye(param_num, dtype=data_type, device=dist_device))
+                        P.append(
+                            torch.eye(param_num, dtype=data_type, device=dist_device)
+                        )
                         device_id = self.rank
         else:
             for param_num in param_nums:
@@ -151,7 +155,7 @@ def __init_P(self):
                                 )
                             )
                             params_packed_index.append(param_num - block_size * i)
-                else:
+                else:
                     P.append(torch.eye(param_num, dtype=data_type, device=device))
                     params_packed_index.append(param_num)
 
@@ -206,13 +210,18 @@ def __update(self, H, error, weights):
         if self.dist_init:
             device = torch.device("cuda:" + str(self.rank))
             local_shape = [tensor.shape[0] for tensor in weights]
-            shape_list = [torch.zeros_like(torch.empty(1),dtype=torch.float64,device=device) for _ in range(dist.get_world_size())]
+            shape_list = [
+                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
+                for _ in range(dist.get_world_size())
+            ]
             dist.all_gather_object(shape_list, local_shape)
             weight_tensor = torch.cat(weights)
             world_shape = [sum(inner_list) for inner_list in shape_list]
             weight_list = [None] * len(world_shape)
             for i in range(len(world_shape)):
-                weight_list[i] = torch.zeros(world_shape[i],dtype=torch.float64,device=device)
+                weight_list[i] = torch.zeros(
+                    world_shape[i], dtype=torch.float64, device=device
+                )
             dist.all_gather(weight_list, weight_tensor)
             result = []
             for i in range(dist.get_world_size()):
@@ -255,7 +264,6 @@ def set_grad_prefactor(self, grad_prefactor):
         self.grad_prefactor = grad_prefactor
 
     def step(self, error):
-
         params_packed_index = self._state.get("params_packed_index")
 
         weights = []
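
For context on the numbers flowing through __init_P (unchanged by the fixes above): a parameter tensor with param_num elements becomes ceil(param_num / block_size) blocks of the covariance matrix P, all of size block_size except a smaller final block, while tensors below block_size keep a single block of their own size. A worked example of just that arithmetic, assuming the default block_size of 5120:

import math

block_size = 5120
for param_num in (12000, 300):
    if param_num >= block_size:
        block_num = math.ceil(param_num / block_size)
        # every block is full-sized except the last, which takes the remainder
        sizes = [block_size] * (block_num - 1) + [param_num - block_size * (block_num - 1)]
    else:
        sizes = [param_num]
    print(param_num, "->", sizes)

# 12000 -> [5120, 5120, 1760]
# 300 -> [300]

In the distributed branch, each such block index is passed through get_device_id, so only the owning rank allocates the corresponding identity matrix.
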
From a1795d50444afaf53dfcb9a51bd61e206bdcaa43 Mon Sep 17 00:00:00 2001
From: CaRoLZhangxy
Date: Fri, 8 Mar 2024 08:05:11 +0000
Subject: [PATCH 3/7] remove unused device_id

---
 deepmd/pt/optimizer/LKF.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 685844070d..5f38f00622 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -127,7 +127,6 @@ def __init_P(self):
                     if self.rank == device_id:
                         dist_device = torch.device("cuda:" + str(device_id))
                         P.append(torch.eye(param_num, dtype=data_type, device=dist_device))
-                        device_id = self.rank
         else:
             for param_num in param_nums:
                 if param_num >= block_size:

From d9608a3f3e2905fd59e52582348f27fe8641f6d3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 8 Mar 2024 08:06:52 +0000
Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/pt/optimizer/LKF.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index cf58d3a695..35db5b1da3 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -128,7 +128,9 @@ def __init_P(self):
                     params_packed_index.append(param_num)
                     if self.rank == device_id:
                         dist_device = torch.device("cuda:" + str(device_id))
-                        P.append(torch.eye(param_num, dtype=data_type, device=dist_device))
+                        P.append(
+                            torch.eye(param_num, dtype=data_type, device=dist_device)
+                        )
         else:
             for param_num in param_nums:
                 if param_num >= block_size:
From 4b3a7c6ea2f8eaf55490d5350ccbaa8fa3afd5df Mon Sep 17 00:00:00 2001
From: CaRoLZhangxy
Date: Fri, 8 Mar 2024 08:33:22 +0000
Subject: [PATCH 5/7] remove print

---
 deepmd/pt/optimizer/LKF.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 35db5b1da3..069dcea1fc 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -185,8 +185,6 @@ def __split_weights(self, weight):
 
     def __update(self, H, error, weights):
         P = self._state.get("P")
-        # for item in P:
-        #     print(self.rank," size ",item.shape)
         kalman_lambda = self._state.get("kalman_lambda")
         weights_num = self._state.get("weights_num")
         params_packed_index = self._state.get("params_packed_index")
From 84fad566cc699a5ac2ad4f647444ef67eb93d60b Mon Sep 17 00:00:00 2001
From: CaRoLZhangxy
Date: Mon, 11 Mar 2024 07:46:16 +0000
Subject: [PATCH 6/7] change code style

---
 deepmd/pt/optimizer/LKF.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 069dcea1fc..2dc934263c 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -32,13 +32,13 @@ def __init__(
         kalman_nue=0.9987,
         block_size=5120,
     ):
-        defaults = dict(
-            lr=0.1,
-            kalman_nue=kalman_nue,
-            block_size=block_size,
-        )
+        defaults = {
+            "lr":0.1,
+            "kalman_nue":kalman_nue,
+            "block_size":block_size
+        }
 
-        super(LKFOptimizer, self).__init__(params, defaults)
+        super().__init__(params, defaults)
 
         self._params = self.param_groups[0]["params"]

From 99d2e2fb47febf63f8ee8083051d50706093a397 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 11 Mar 2024 07:46:45 +0000
Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 deepmd/pt/optimizer/LKF.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/deepmd/pt/optimizer/LKF.py b/deepmd/pt/optimizer/LKF.py
index 2dc934263c..06b341d987 100644
--- a/deepmd/pt/optimizer/LKF.py
+++ b/deepmd/pt/optimizer/LKF.py
@@ -32,11 +32,7 @@ def __init__(
         kalman_nue=0.9987,
         block_size=5120,
     ):
-        defaults = {
-            "lr":0.1,
-            "kalman_nue":kalman_nue,
-            "block_size":block_size
-        }
+        defaults = {"lr": 0.1, "kalman_nue": kalman_nue, "block_size": block_size}
 
         super().__init__(params, defaults)
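
For reference, the synchronization at the end of __update — exchange per-rank block lengths with all_gather_object, all_gather the concatenated weight columns, then split them back into per-block tensors — can be exercised in isolation. A minimal sketch, not part of the patches, assuming a single-process gloo group so it runs on CPU without a launcher; it also seeds shape_list with None placeholders where the patch uses dummy tensors:

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# this rank owns two weight blocks of different lengths, stored as columns
weights = [
    torch.ones(3, 1, dtype=torch.float64),
    torch.full((2, 1), 2.0, dtype=torch.float64),
]

local_shape = [t.shape[0] for t in weights]
shape_list = [None] * dist.get_world_size()
dist.all_gather_object(shape_list, local_shape)  # block lengths from every rank

weight_tensor = torch.cat(weights)          # flatten this rank's blocks
world_shape = [sum(s) for s in shape_list]  # total length held by each rank
weight_list = [torch.zeros(n, 1, dtype=torch.float64) for n in world_shape]
dist.all_gather(weight_list, weight_tensor)  # every rank receives every rank's blocks

result = []
for i in range(dist.get_world_size()):
    result += list(torch.split(weight_list[i], shape_list[i]))
print([t.flatten().tolist() for t in result])  # [[1.0, 1.0, 1.0], [2.0, 2.0]]
dist.destroy_process_group()
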