pt: Add parallel implementation for LKF (#3436)
Add a parallel implementation for LKF, enabling distributed storage of the
P matrix across multiple GPUs and reducing per-GPU memory overhead.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <[email protected]>
3 people authored Mar 11, 2024
1 parent 619fd1c commit 24d02b7
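
To make the distribution scheme concrete, the sketch below (not part of the commit; the parameter counts and worker count are made up) reproduces the block counting from __init_P and the distribute_indices helper added in this diff, showing which contiguous range of P blocks each rank would own.

import math


def distribute_indices(total_length, num_workers):
    # same logic as the helper added below: contiguous (start, end) block
    # ranges, one per worker, with the first `remainder` workers getting one
    # extra block
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers
    indices = []
    start = 0
    for i in range(num_workers):
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end
    return indices, remainder


block_size = 5120
param_nums = [20000, 3000, 9000]  # hypothetical per-layer parameter counts
# count the P blocks exactly as the counting loop in __init_P does
block_num = sum(
    math.ceil(n / block_size) if n >= block_size else 1 for n in param_nums
)
indices, remainder = distribute_indices(block_num, num_workers=3)
print(block_num, indices, remainder)  # 7 [(0, 3), (3, 5), (5, 7)] 1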
167 changes: 132 additions & 35 deletions deepmd/pt/optimizer/LKF.py
@@ -3,11 +3,25 @@
import math

import torch
import torch.distributed as dist
from torch.optim.optimizer import (
    Optimizer,
)

log = logging.getLogger(__name__)


def distribute_indices(total_length, num_workers):
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers

    indices = []
    start = 0

    for i in range(num_workers):
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end

    return indices, remainder


class LKFOptimizer(Optimizer):
@@ -18,11 +32,8 @@ def __init__(
        kalman_nue=0.9987,
        block_size=5120,
    ):
        defaults = {"lr": 0.1, "kalman_nue": kalman_nue, "block_size": block_size}

        super().__init__(params, defaults)

        self._params = self.param_groups[0]["params"]
Expand All @@ -36,7 +47,10 @@ def __init__(
        # the first param, because this helps with casting in load_state_dict
        self._state = self.state[self._params[0]]
        self._state.setdefault("kalman_lambda", kalman_lambda)
        self.dist_init = dist.is_initialized()
        self.rank = dist.get_rank() if self.dist_init else 0
        self.dindex = []
        self.remainder = 0
        self.__init_P()

    def __init_P(self):
Expand All @@ -61,32 +75,84 @@ def __init_P(self):

        P = []
        params_packed_index = []
        log.info("LKF parameter nums: %s" % param_nums)
        if self.dist_init:
            block_num = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num += math.ceil(param_num / block_size)
                else:
                    block_num += 1
            num_workers = dist.get_world_size()
            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
            index = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        device_id = self.get_device_id(index)
                        index += 1
                        dist_device = torch.device("cuda:" + str(device_id))
                        if i != block_num - 1:
                            params_packed_index.append(block_size)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        block_size,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                        else:
                            params_packed_index.append(param_num - block_size * i)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        param_num - block_size * i,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue

                else:
                    device_id = self.get_device_id(index)
                    index += 1
                    params_packed_index.append(param_num)
                    if self.rank == device_id:
                        dist_device = torch.device("cuda:" + str(device_id))
                        P.append(
                            torch.eye(param_num, dtype=data_type, device=dist_device)
                        )
        else:
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        if i != block_num - 1:
                            P.append(
                                torch.eye(
                                    block_size,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(block_size)
                        else:
                            P.append(
                                torch.eye(
                                    param_num - block_size * i,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(param_num - block_size * i)
                else:
                    P.append(torch.eye(param_num, dtype=data_type, device=device))
                    params_packed_index.append(param_num)

self._state.setdefault("P", P)
self._state.setdefault("weights_num", len(P))
@@ -125,16 +191,35 @@ def __update(self, H, error, weights):
        tmp = 0
        for i in range(weights_num):
            tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i]))
        if self.dist_init:
            dist.all_reduce(tmp, op=dist.ReduceOp.SUM)
        A = 1 / tmp

        for i in range(weights_num):
            K = torch.matmul(P[i], H[i])

            weights[i] = weights[i] + A * error * K

            P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T))

        if self.dist_init:
            device = torch.device("cuda:" + str(self.rank))
            local_shape = [tensor.shape[0] for tensor in weights]
            shape_list = [
                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
                for _ in range(dist.get_world_size())
            ]
            dist.all_gather_object(shape_list, local_shape)
            weight_tensor = torch.cat(weights)
            world_shape = [sum(inner_list) for inner_list in shape_list]
            weight_list = [None] * len(world_shape)
            for i in range(len(world_shape)):
                weight_list[i] = torch.zeros(
                    world_shape[i], dtype=torch.float64, device=device
                )
            dist.all_gather(weight_list, weight_tensor)
            result = []
            for i in range(dist.get_world_size()):
                result = result + list(torch.split(weight_list[i], shape_list[i]))
            weights = result
        kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
        self._state.update({"kalman_lambda": kalman_lambda})
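
For reference, a minimal single-block, single-process sketch (sizes and constants are illustrative only, not taken from the commit) of the update rule implemented above; the real optimizer performs one such update per stored P block and, in the distributed path, all-reduces the scalar tmp across ranks before inverting it.

import torch

torch.manual_seed(0)
n = 4
kalman_lambda, kalman_nue = 98.0, 0.9987  # illustrative values only
P = torch.eye(n, dtype=torch.float64)  # one P block
H = torch.randn(n, 1, dtype=torch.float64)  # packed gradient column for this block
w = torch.randn(n, 1, dtype=torch.float64)  # packed weights for this block
error = torch.tensor(0.1, dtype=torch.float64)

tmp = kalman_lambda + H.T @ P @ H  # all-reduced over ranks in the distributed path
A = 1 / tmp
K = P @ H  # Kalman gain direction
w = w + A * error * K
P = (1 / kalman_lambda) * (P - A * (K @ K.T))
kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
print(w.squeeze(), kalman_lambda)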

@@ -215,9 +300,21 @@ def step(self, error):
                param_sum += nelement

                if param_sum == params_packed_index[param_index]:
                    param_sum = 0
                    if self.dist_init:
                        device_id = self.get_device_id(param_index)
                        if self.rank == device_id:
                            weights.append(res)
                            H.append(res_grad)
                    else:
                        weights.append(res)
                        H.append(res_grad)
                    param_index += 1

        self.__update(H, error, weights)

    def get_device_id(self, index):
        for i, (start, end) in enumerate(self.dindex):
            if start <= index < end:
                return i
        return None
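
The distributed branch of __update re-assembles the per-rank weight slices with all_gather_object (slice lengths) followed by all_gather (flattened weights) and a final torch.split. The standalone sketch below mimics that pattern on CPU with the gloo backend and a world size of 1, so the gather is trivial but the mechanics are the same; the shapes are made up.

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# this rank's weight slices and their lengths
weights = [torch.arange(2, dtype=torch.float64), torch.arange(3, dtype=torch.float64)]
local_shape = [t.shape[0] for t in weights]

# gather every rank's slice lengths as Python objects
shape_list = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(shape_list, local_shape)

# flatten local slices, gather one tensor per rank, then re-split per slice
weight_tensor = torch.cat(weights)
world_shape = [sum(s) for s in shape_list]
weight_list = [torch.zeros(n, dtype=torch.float64) for n in world_shape]
dist.all_gather(weight_list, weight_tensor)

result = []
for i in range(dist.get_world_size()):
    result += list(torch.split(weight_list[i], shape_list[i]))
print(result)  # the full, ordered list of weight slices

dist.destroy_process_group()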
