pt: Add parallel implementation for LKF #3436

Merged: 11 commits, Mar 11, 2024
deepmd/pt/optimizer/LKF.py: 167 changes (132 additions, 35 deletions)
@@ -3,11 +3,25 @@
import math

import torch
import torch.distributed as dist
from torch.optim.optimizer import (
    Optimizer,
)

log = logging.getLogger(__name__)

def distribute_indices(total_length, num_workers):
    indices_per_worker = total_length // num_workers
    remainder = total_length % num_workers

    indices = []
    start = 0

    for i in range(num_workers):
        end = start + indices_per_worker + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end

    return indices, remainder
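
For reference, distribute_indices splits a given number of Kalman blocks as evenly as possible across workers, handing one extra block to each of the first remainder workers. A minimal sketch of the expected behaviour (hypothetical numbers, not part of this diff):

# 7 blocks over 3 workers: the first worker gets the one leftover block.
indices, remainder = distribute_indices(7, 3)
assert indices == [(0, 3), (3, 5), (5, 7)]
assert remainder == 1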


class LKFOptimizer(Optimizer):
@@ -18,11 +32,8 @@
        kalman_nue=0.9987,
        block_size=5120,
    ):
        defaults = {"lr": 0.1, "kalman_nue": kalman_nue, "block_size": block_size}

        super().__init__(params, defaults)

        self._params = self.param_groups[0]["params"]
@@ -36,7 +47,10 @@
        # the first param, because this helps with casting in load_state_dict
        self._state = self.state[self._params[0]]
        self._state.setdefault("kalman_lambda", kalman_lambda)

        self.dist_init = dist.is_initialized()
        self.rank = dist.get_rank() if self.dist_init else 0
        self.dindex = []
        self.remainder = 0
        self.__init_P()

    def __init_P(self):
@@ -61,32 +75,84 @@

        P = []
        params_packed_index = []
        logging.info("LKF parameter nums: %s" % param_nums)
        if self.dist_init:
            block_num = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num += math.ceil(param_num / block_size)
                else:
                    block_num += 1
            num_workers = dist.get_world_size()
            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
            index = 0
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        device_id = self.get_device_id(index)
                        index += 1
                        dist_device = torch.device("cuda:" + str(device_id))
                        if i != block_num - 1:
                            params_packed_index.append(block_size)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        block_size,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                        else:
                            params_packed_index.append(param_num - block_size * i)
                            if self.rank == device_id:
                                P.append(
                                    torch.eye(
                                        param_num - block_size * i,
                                        dtype=data_type,
                                        device=dist_device,
                                    )
                                )
                            else:
                                continue
                else:
                    device_id = self.get_device_id(index)
                    index += 1
                    params_packed_index.append(param_num)
                    if self.rank == device_id:
                        dist_device = torch.device("cuda:" + str(device_id))
                        P.append(
                            torch.eye(param_num, dtype=data_type, device=dist_device)
                        )
        else:
            for param_num in param_nums:
                if param_num >= block_size:
                    block_num = math.ceil(param_num / block_size)
                    for i in range(block_num):
                        if i != block_num - 1:
                            P.append(
                                torch.eye(
                                    block_size,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(block_size)
                        else:
                            P.append(
                                torch.eye(
                                    param_num - block_size * i,
                                    dtype=data_type,
                                    device=device,
                                )
                            )
                            params_packed_index.append(param_num - block_size * i)
                else:
                    P.append(torch.eye(param_num, dtype=data_type, device=device))
                    params_packed_index.append(param_num)

        self._state.setdefault("P", P)
        self._state.setdefault("weights_num", len(P))
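
In the distributed branch above, every rank computes the same params_packed_index, but allocates only the P blocks whose indices fall inside its own slice of dindex. A rough sketch of the resulting layout, assuming hypothetical packed parameter sizes and 2 ranks (not part of this diff):

import math

block_size = 5120
param_nums = [12000, 300, 7000]      # hypothetical packed parameter sizes
block_num = sum(
    math.ceil(n / block_size) if n >= block_size else 1 for n in param_nums
)                                    # 3 + 1 + 2 = 6 Kalman blocks in total
indices, _ = distribute_indices(block_num, 2)
# indices == [(0, 3), (3, 6)]: rank 0 allocates torch.eye(...) for blocks 0-2,
# rank 1 for blocks 3-5; params_packed_index still lists all 6 block sizes on both ranks.
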
@@ -125,16 +191,35 @@
        tmp = 0
        for i in range(weights_num):
            tmp = tmp + (kalman_lambda + torch.matmul(torch.matmul(H[i].T, P[i]), H[i]))

        if self.dist_init:
            dist.all_reduce(tmp, op=dist.ReduceOp.SUM)

        A = 1 / tmp

        for i in range(weights_num):
            K = torch.matmul(P[i], H[i])

            weights[i] = weights[i] + A * error * K

            P[i] = (1 / kalman_lambda) * (P[i] - A * torch.matmul(K, K.T))

        if self.dist_init:
            device = torch.device("cuda:" + str(self.rank))
            local_shape = [tensor.shape[0] for tensor in weights]
            shape_list = [
                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
                for _ in range(dist.get_world_size())
            ]
            dist.all_gather_object(shape_list, local_shape)
            weight_tensor = torch.cat(weights)
            world_shape = [sum(inner_list) for inner_list in shape_list]
            weight_list = [None] * len(world_shape)
            for i in range(len(world_shape)):
                weight_list[i] = torch.zeros(
                    world_shape[i], dtype=torch.float64, device=device
                )
            dist.all_gather(weight_list, weight_tensor)
            result = []
            for i in range(dist.get_world_size()):
                result = result + list(torch.split(weight_list[i], shape_list[i]))
            weights = result

        kalman_lambda = kalman_nue * kalman_lambda + 1 - kalman_nue
        self._state.update({"kalman_lambda": kalman_lambda})
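
The gather step in this hunk operates on flattened, variable-length weight shards: each rank concatenates its local slices, per-slice lengths are exchanged with all_gather_object, the flat tensors are gathered, and every flat buffer is split back into its original slices. A single-process sketch of that split/concat bookkeeping (hypothetical sizes, no actual torch.distributed calls):

import torch

# Per-rank slice lengths, as they would arrive via dist.all_gather_object.
shape_list = [[3, 2], [4]]                     # rank 0 owns slices of length 3 and 2, rank 1 one of length 4
flat = [torch.arange(5.0), torch.arange(4.0)]  # stand-ins for the gathered flat weight tensors

result = []
for rank, lengths in enumerate(shape_list):
    # torch.split recreates the original per-block slices from each flat buffer.
    result += list(torch.split(flat[rank], lengths))
print([t.shape[0] for t in result])            # [3, 2, 4]: the global block order is restored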

@@ -215,9 +300,21 @@
            param_sum += nelement

            if param_sum == params_packed_index[param_index]:
                param_sum = 0
                if self.dist_init:
                    device_id = self.get_device_id(param_index)
                    if self.rank == device_id:
                        weights.append(res)
                        H.append(res_grad)
                else:
                    weights.append(res)
                    H.append(res_grad)
                param_index += 1

        self.__update(H, error, weights)

    def get_device_id(self, index):
        for i, (start, end) in enumerate(self.dindex):
            if start <= index < end:
                return i
        return None
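
get_device_id looks up which contiguous range returned by distribute_indices contains a given block index, i.e. which rank owns that block. A small sketch of the expected mapping, assuming a hypothetical optimizer instance opt and 6 blocks over 2 ranks (not part of this diff):

opt.dindex = [(0, 3), (3, 6)]        # as returned by distribute_indices(6, 2)
assert opt.get_device_id(0) == 0     # blocks 0-2 live on rank 0
assert opt.get_device_id(4) == 1     # blocks 3-5 live on rank 1
assert opt.get_device_id(6) is None  # out-of-range index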
