pt: Add parallel implementation for LKF #3436

Merged
merged 11 commits on Mar 11, 2024
173 changes: 137 additions & 36 deletions deepmd/pt/optimizer/LKF.py
@@ -3,11 +3,25 @@
import math

import torch
import torch.distributed as dist
from torch.optim.optimizer import (
    Optimizer,
)

log = logging.getLogger(__name__)

def distribute_indices(total_length, num_workers):
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers

    indices = []
    start = 0

    for i in range(num_workers):
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end

    return indices, remainder
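For reference, a standalone sketch (not part of the diff) of what the new helper returns; the import path is assumed from this file's location:

from deepmd.pt.optimizer.LKF import distribute_indices

# 10 Kalman blocks over 3 workers: contiguous (start, end) ranges per rank,
# with the remainder spread over the first ranks.
indices, remainder = distribute_indices(10, 3)
print(indices)    # [(0, 4), (4, 7), (7, 10)]
print(remainder)  # 1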


class LKFOptimizer(Optimizer):
@@ -18,12 +32,13 @@
        kalman_nue=0.9987,
        block_size=5120,
    ):
        defaults = {
            "lr": 0.1,
            "kalman_nue": kalman_nue,
            "block_size": block_size,
        }
        super().__init__(params, defaults)
        defaults = dict(
            lr=0.1,
            kalman_nue=kalman_nue,
            block_size=block_size,
        )

        super(LKFOptimizer, self).__init__(params, defaults)

        self._params = self.param_groups[0]["params"]

@@ -36,7 +51,10 @@
        # the first param, because this helps with casting in load_state_dict
        self._state = self.state[self._params[0]]
        self._state.setdefault("kalman_lambda", kalman_lambda)

        self.dist_init = dist.is_initialized()
        self.rank = dist.get_rank() if self.dist_init else 0
        self.dindex = []
        self.remainder = 0
        self.__init_P()

    def __init_P(self):
@@ -61,32 +79,84 @@

        P = []
        params_packed_index = []
        log.info("LKF parameter nums: %s" % param_nums)
        for param_num in param_nums:
            if param_num >= block_size:
                block_num = math.ceil(param_num / block_size)
                for i in range(block_num):
                    if i != block_num - 1:
        logging.info("LKF parameter nums: %s" % param_nums)
        if self.dist_init:
            block_num = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num += math.ceil(param_num / block_size)
                else:
                    block_num += 1
            num_workers = dist.get_world_size()
            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
            index = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        device_id = self.get_device_id(index)
                        index += 1
                        dist_device = torch.device("cuda:" + str(device_id))
                        if i != block_num - 1:
                            params_packed_index.append(block_size)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        block_size,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                        else:
                            params_packed_index.append(param_num - block_size * i)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        param_num - block_size * i,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                else:
                    device_id = self.get_device_id(index)
                    index += 1
                    params_packed_index.append(param_num)
                    if self.rank == device_id:
                        dist_device = torch.device("cuda:" + str(device_id))
                        P.append(
                            torch.eye(param_num, dtype=data_type, device=dist_device)
                        )
        else:
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        if i != block_num - 1:
                            P.append(
                                torch.eye(
                                    block_size,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(block_size)
                        else:
                            P.append(
                                torch.eye(
                                    param_num - block_size * i,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(param_num - block_size * i)
                else:
                    P.append(torch.eye(param_num, dtype=data_type, device=device))
                    params_packed_index.append(param_num)

        self._state.setdefault("P", P)
        self._state.setdefault("weights_num", len(P))
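To illustrate the distributed branch above with made-up sizes (param_nums and block_size below are hypothetical): each parameter pack contributes ceil(param_num / block_size) blocks when it is at least block_size and one block otherwise, and distribute_indices assigns each global block index to a rank.

import math

param_nums = [30000, 4000]   # hypothetical packed parameter sizes
block_size = 5120
block_num = sum(
    math.ceil(n / block_size) if n >= block_size else 1 for n in param_nums
)
# block_num == 7; with two workers, distribute_indices(7, 2) returns [(0, 4), (4, 7)],
# so rank 0 allocates the P matrices for blocks 0-3 and rank 1 for blocks 4-6,
# while params_packed_index still records every block's size on every rank.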
@@ -125,16 +195,35 @@
        tmp = 0
        for i in range(weights_num):
            tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i]))

        if self.dist_init:
            dist.all_reduce(tmp, op=dist.ReduceOp.SUM)
        A = 1 / tmp

        for i in range(weights_num):
            K = torch.matmul(P[i], H[i])

            weights[i] = weights[i] + A * error * K

            P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T))
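For orientation, the loop above is the usual layer-wise Kalman update written per block i held by this rank (notation mirrors the code): A = 1 / sum_i(kalman_lambda + H_i^T P_i H_i), K_i = P_i H_i, weights_i <- weights_i + A * error * K_i, and P_i <- (1 / kalman_lambda) * (P_i - A * K_i K_i^T). With the P blocks sharded across ranks, the only collective step so far is the all_reduce that completes the sum inside A; everything else stays block-local.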

        if self.dist_init:
            device = torch.device("cuda:" + str(self.rank))
            local_shape = [tensor.shape[0] for tensor in weights]
            shape_list = [
                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
                for _ in range(dist.get_world_size())
            ]
            dist.all_gather_object(shape_list, local_shape)
            weight_tensor = torch.cat(weights)
            world_shape = [sum(inner_list) for inner_list in shape_list]
            weight_list = [None] * len(world_shape)
            for i in range(len(world_shape)):
                weight_list[i] = torch.zeros(
                    world_shape[i], dtype=torch.float64, device=device
                )
            dist.all_gather(weight_list, weight_tensor)
            result = []
            for i in range(dist.get_world_size()):
                result = result + list(torch.split(weight_list[i], shape_list[i]))
            weights = result

        kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
        self._state.update({"kalman_lambda": kalman_lambda})
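The distributed block above re-assembles the full weight list on every rank: slice lengths are exchanged with all_gather_object, the flattened local slices with all_gather, and each gathered buffer is split back into per-block tensors. A hedged standalone sketch of that pattern (function and variable names are illustrative; the float64/CUDA choices mirror the code above):

import torch
import torch.distributed as dist

def regather_weights(local_slices, device):
    local_shape = [t.shape[0] for t in local_slices]       # lengths of this rank's slices
    shape_list = [None] * dist.get_world_size()
    dist.all_gather_object(shape_list, local_shape)        # exchange slice lengths
    flat = torch.cat(local_slices)
    buffers = [
        torch.zeros(sum(s), dtype=torch.float64, device=device) for s in shape_list
    ]
    dist.all_gather(buffers, flat)                         # exchange payloads
    result = []
    for buf, shapes in zip(buffers, shape_list):
        result += list(torch.split(buf, shapes))           # back to per-block tensors
    return result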

@@ -215,9 +304,21 @@
            param_sum += nelement

            if param_sum == params_packed_index[param_index]:
                H.append(res_grad)
                weights.append(res)
                param_sum = 0
                if self.dist_init:
                    device_id = self.get_device_id(param_index)
                    if self.rank == device_id:
                        weights.append(res)
                        H.append(res_grad)
                else:
                    weights.append(res)
                    H.append(res_grad)
                param_index += 1

        self.__update(H, error, weights)

    def get_device_id(self, index):
        for i, (start, end) in enumerate(self.dindex):
            if start <= index < end:
                return i
        return None
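A quick standalone check of the ownership lookup (values below are illustrative, not from the diff):

dindex = [(0, 4), (4, 7), (7, 10)]   # e.g. from distribute_indices(10, 3)

def owner(index):
    # mirrors get_device_id: walk the (start, end) ranges and return the owning rank
    for rank, (start, end) in enumerate(dindex):
        if start <= index < end:
            return rank
    return None

assert owner(5) == 1 and owner(9) == 2 and owner(10) is None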
