
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Mar 8, 2024
1 parent 4077515 commit 3a355f7
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions deepmd/pt/optimizer/LKF.py
@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import logging
-import torch
-import numpy as np
-from torch.optim.optimizer import Optimizer
 import math
+
+import torch
 import torch.distributed as dist
-from torch.profiler import profile, record_function, ProfilerActivity
+
+from torch.optim.optimizer import (
+    Optimizer,
+)
 
 
 def distribute_indices(total_length, num_workers):
@@ -19,7 +21,9 @@ def distribute_indices(total_length, num_workers):
         indices.append((start, end))
         start = end
 
-    return indices,remainder
+    return indices, remainder

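For readers following along: only the tail of distribute_indices is visible in this hunk. Below is a minimal sketch of a helper with this signature and return shape, assuming an even split in which the first remainder workers each take one extra block; the divmod-based body is a hypothetical reconstruction, not the repository's code.

def distribute_indices(total_length, num_workers):
    # Split total_length block indices as evenly as possible across workers;
    # the first `remainder` workers each receive one extra block.
    base, remainder = divmod(total_length, num_workers)
    indices = []
    start = 0
    for i in range(num_workers):
        end = start + base + (1 if i < remainder else 0)
        indices.append((start, end))
        start = end
    return indices, remainder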


 class LKFOptimizer(Optimizer):
     def __init__(
         self,
@@ -28,7 +32,6 @@ def __init__(
         kalman_nue=0.9987,
         block_size=5120,
     ):
-
         defaults = dict(
             lr=0.1,
             kalman_nue=kalman_nue,
@@ -55,7 +58,6 @@ def __init__(
         self.__init_P()
 
     def __init_P(self):
-
         param_nums = []
         param_sum = 0
         block_size = self.__get_blocksize()
@@ -79,21 +81,21 @@ def __init_P(self):
         params_packed_index = []
         logging.info("LKF parameter nums: %s" % param_nums)
         if self.dist_init:
-            block_num = 0;
+            block_num = 0
             for param_num in param_nums:
                 if param_num >= block_size:
                     block_num += math.ceil(param_num / block_size)
                 else:
-                    block_num +=1
+                    block_num += 1
             num_workers = dist.get_world_size()
-            self.dindex,self.remainder = distribute_indices(block_num,num_workers)
+            self.dindex, self.remainder = distribute_indices(block_num, num_workers)
             index = 0
             for param_num in param_nums:
                 if param_num >= block_size:
                     block_num = math.ceil(param_num / block_size)
                     for i in range(block_num):
                         device_id = self.get_device_id(index)
-                        index +=1
+                        index += 1
                         dist_device = torch.device("cuda:" + str(device_id))
                         if i != block_num - 1:
                             params_packed_index.append(block_size)
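
To make the block-counting rule in the hunk above concrete: every tensor with at least block_size elements contributes ceil(param_num / block_size) Kalman blocks, and every smaller tensor contributes exactly one. A toy calculation with invented sizes:

import math

block_size = 5120
param_nums = [13000, 300, 5120]  # hypothetical tensor sizes
block_num = 0
for param_num in param_nums:
    if param_num >= block_size:
        block_num += math.ceil(param_num / block_size)  # 13000 -> 3, 5120 -> 1
    else:
        block_num += 1  # small tensors occupy one block each
print(block_num)  # 3 + 1 + 1 = 5
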
@@ -106,7 +108,7 @@ def __init_P(self):
                                     )
                                 )
                             else:
-                                    continue
+                                continue
 
                         else:
                             params_packed_index.append(param_num - block_size * i)
                             if self.rank == device_id:
@@ -120,13 +122,15 @@ def __init_P(self):
                             else:
                                 continue

                 else:
                     device_id = self.get_device_id(index)
-                    index +=1
+                    index += 1
                     params_packed_index.append(param_num)
                     if self.rank == device_id:
                         dist_device = torch.device("cuda:" + str(device_id))
-                        P.append(torch.eye(param_num, dtype=data_type, device=dist_device))
+                        P.append(
+                            torch.eye(param_num, dtype=data_type, device=dist_device)
+                        )
             device_id = self.rank

         else:
             for param_num in param_nums:
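
Aside: self.get_device_id(index) is called in the hunks above but not shown in this diff. A plausible sketch consistent with the (start, end) ranges that distribute_indices returns, mapping a block index to the rank that owns it; this is a hypothetical reconstruction, not the repository's code:

def get_device_id(self, index):
    # Return the rank whose half-open [start, end) range contains `index`.
    for device_id, (start, end) in enumerate(self.dindex):
        if start <= index < end:
            return device_id
    return None
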
@@ -151,7 +155,7 @@ def __init_P(self):
                             )
                         )
                         params_packed_index.append(param_num - block_size * i)

-            else:
+                else:
                     P.append(torch.eye(param_num, dtype=data_type, device=device))
                     params_packed_index.append(param_num)


@@ -206,13 +210,18 @@ def __update(self, H, error, weights):
         if self.dist_init:
             device = torch.device("cuda:" + str(self.rank))
             local_shape = [tensor.shape[0] for tensor in weights]
-            shape_list = [torch.zeros_like(torch.empty(1),dtype=torch.float64,device=device) for _ in range(dist.get_world_size())]
+            shape_list = [

+                torch.zeros_like(torch.empty(1), dtype=torch.float64, device=device)
+                for _ in range(dist.get_world_size())
+            ]
             dist.all_gather_object(shape_list, local_shape)
             weight_tensor = torch.cat(weights)
             world_shape = [sum(inner_list) for inner_list in shape_list]
             weight_list = [None] * len(world_shape)
             for i in range(len(world_shape)):
-                weight_list[i] = torch.zeros(world_shape[i],dtype=torch.float64,device=device)
+                weight_list[i] = torch.zeros(

+                    world_shape[i], dtype=torch.float64, device=device
+                )
             dist.all_gather(weight_list, weight_tensor)
             result = []
             for i in range(dist.get_world_size()):
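
The hunk above performs an uneven all-gather: each rank first shares the length of its local weight vector with all_gather_object, so every rank can allocate correctly sized receive buffers before the tensor all_gather. A minimal standalone sketch of the same pattern, assuming torch.distributed is already initialized (the function name is illustrative):

import torch
import torch.distributed as dist


def gather_uneven(local_tensor, device):
    # Share each rank's length so all ranks can size their receive buffers.
    world_size = dist.get_world_size()
    lengths = [None] * world_size
    dist.all_gather_object(lengths, local_tensor.shape[0])
    # One buffer per rank, sized to that rank's payload, then gather.
    buffers = [
        torch.zeros(n, dtype=local_tensor.dtype, device=device) for n in lengths
    ]
    dist.all_gather(buffers, local_tensor)
    return buffers
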
@@ -255,7 +264,6 @@ def set_grad_prefactor(self, grad_prefactor):
         self.grad_prefactor = grad_prefactor
 
     def step(self, error):
-
         params_packed_index = self._state.get("params_packed_index")
 
         weights = []