Add 4 pt descriptor compression #4227

Merged
merged 102 commits into from
Nov 1, 2024
6bb8795
yan devel
cherryWangY Jun 9, 2024
7ecd122
tabulate_fusion_se_t
cherryWangY Jun 10, 2024
ab670ed
tabulate_fusion_all_op_basic_verion
cherryWangY Jun 10, 2024
9fc3fb0
compile safe version
cherryWangY Jun 10, 2024
cab50c9
compile safe version
cherryWangY Jun 10, 2024
e9ccb98
Merge branch 'devel' of https://github.com/cherryWangY/deepmd-kit int…
cherryWangY Jun 13, 2024
87909e2
se_a & se_atten
cherryWangY Jun 13, 2024
2225975
se_r
cherryWangY Jun 13, 2024
c09a7a7
remove print
cherryWangY Jun 13, 2024
ee5b64e
move pt op test
cherryWangY Jun 13, 2024
c7efbce
Merge remote-tracking branch 'upstream/devel' into devel
cherryWangY Jun 13, 2024
763c7b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
3fe4b64
remove print
cherryWangY Jun 13, 2024
6f76ccf
fixed for commit
cherryWangY Jun 13, 2024
b89ceed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
5a8b77e
fix pull request warning
cherryWangY Jun 13, 2024
25ca8c1
fix pr warning
cherryWangY Jun 13, 2024
f1c43f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 13, 2024
179e175
gpu test debug
cherryWangY Jun 16, 2024
34c664c
merge
cherryWangY Jun 16, 2024
4cc1478
merge
cherryWangY Jun 16, 2024
5921a60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 16, 2024
b63209c
table_info set cpu
cherryWangY Jun 17, 2024
a824a80
remove print
cherryWangY Jun 17, 2024
8527819
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
e47dcba
add dtype=float64
cherryWangY Jun 17, 2024
95a9566
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
114f7a6
add dtype=float64
cherryWangY Jun 17, 2024
9e677f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
22ae3b7
test both float64 and float32
njzjz Jun 17, 2024
9920e57
skip tests if customized ops are not enables
njzjz Jun 17, 2024
7dd7f6a
Merge branch 'devel' into devel
cherryWangY Jun 18, 2024
82a5035
Merge branch 'devel' into devel
cherryWangY Jun 20, 2024
e6bc120
Merge branch 'devel' into tabulate_op
njzjz Jun 21, 2024
1523196
reduce test size from 192 atoms to 4 atoms
njzjz Jun 21, 2024
4c6f69b
Merge branch 'deepmodeling:devel' into devel
cherryWangY Jun 29, 2024
ff57db2
basic descriptor se_a
cherryWangY Jul 7, 2024
fdb13bb
Merge branch 'devel' of https://github.com/cherryWangY/deepmd-kit int…
cherryWangY Jul 7, 2024
882297c
basic descriptor se_a
cherryWangY Jul 7, 2024
4afd634
basic descriptor se_a
cherryWangY Jul 7, 2024
6880e8b
basic torch version
cherryWangY Jul 27, 2024
48a7508
compressed se_a test
cherryWangY Aug 6, 2024
1af34ce
se_a debug
cherryWangY Aug 28, 2024
e9915e3
four descriptors compression
cherryWangY Oct 17, 2024
322ad02
remove redundant code for pt descriptor compression
cherryWangY Oct 21, 2024
d7813d9
align to latest version
cherryWangY Oct 21, 2024
0331c43
align to latest version
cherryWangY Oct 21, 2024
d6cea9b
Merge branch 'devel' into devel
cherryWangY Oct 21, 2024
8984eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 21, 2024
c38472a
remove extra indent
cherryWangY Oct 22, 2024
7d9d1e2
fix pre-commit
cherryWangY Oct 22, 2024
ed1d598
Merge branch 'devel' into devel
cherryWangY Oct 22, 2024
f12fa1b
enhance code robustness
cherryWangY Oct 23, 2024
95e4f9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
776816f
fix pre-commit
cherryWangY Oct 23, 2024
a0d7403
solve some problems
cherryWangY Oct 24, 2024
0d46b21
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
2617b6f
fix annotation error
cherryWangY Oct 24, 2024
7b9332b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
e2106af
fix self.table and remove locals
cherryWangY Oct 25, 2024
6e1f787
merge latest
cherryWangY Oct 25, 2024
2802947
fix se_atten
cherryWangY Oct 25, 2024
b572034
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
8e86382
fix gg undefine
cherryWangY Oct 25, 2024
3eb8d3f
Merge remote-tracking branch 'new-fork/devel' into devel
cherryWangY Oct 25, 2024
7d84da6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
37b1644
make torchscript happy
njzjz Oct 25, 2024
c0ab006
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
c924750
add fake op for freezing
cherryWangY Oct 26, 2024
78f20fd
add explicit device
cherryWangY Oct 26, 2024
755401e
add parametized tests
cherryWangY Oct 26, 2024
976c8df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2024
126a170
fix parameterized
cherryWangY Oct 26, 2024
706d614
remove useless exception
cherryWangY Oct 26, 2024
b92c1f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2024
4d3c1b8
fix prec bug
cherryWangY Oct 26, 2024
619e5fb
Merge branch 'devel' of https://github.com/cherryWangY/deepmd-kit int…
cherryWangY Oct 26, 2024
1b1f0e9
set se_t precesion as 1e-6 when float64
cherryWangY Oct 27, 2024
b3737b9
avoid using env.DEVICE in the forward
cherryWangY Oct 27, 2024
90c5ac3
change enable_compression
cherryWangY Oct 27, 2024
5d8c96a
solve coderabbitai conversations
cherryWangY Oct 27, 2024
feda81b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2024
7231776
change error info
cherryWangY Oct 29, 2024
05c9145
avoid calling serialize() multiple times
cherryWangY Oct 29, 2024
dbe92ef
add enable_compression() to BaseBaseModel
cherryWangY Oct 29, 2024
a7b9f66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
f5d1757
remove enable_compression() in base_model
cherryWangY Oct 30, 2024
3fb094a
add enable_compression() in base_descriptor
cherryWangY Oct 30, 2024
621b45f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
99f9f6b
simplified code
cherryWangY Oct 30, 2024
4df6df0
update branch
cherryWangY Oct 30, 2024
30c722c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 30, 2024
7c86385
merge pt and tf similar implementation of tabulate
cherryWangY Oct 31, 2024
2ce4356
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
39d6d17
Merge branch 'devel' into devel
cherryWangY Oct 31, 2024
cb15335
set is_pt
cherryWangY Oct 31, 2024
15bf3e2
fix for loop; fix codeql warnings
njzjz Oct 31, 2024
689748e
add comment at descrpt SeT build()
cherryWangY Nov 1, 2024
9c7534e
Refactor duplicate code in _get_bias and _get_matrix methods
cherryWangY Nov 1, 2024
5794248
fix device inconsistency
cherryWangY Nov 1, 2024
c06f54a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2024
006c159
Merge branch 'devel' into devel
cherryWangY Nov 1, 2024
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 16, 2024
commit 5921a60a4fde54b28ebdfc74b0723b7aa4ffa96f
65 changes: 37 additions & 28 deletions source/op/pt/tabulate_multi_device.cc
@@ -21,29 +21,33 @@
 #include <cuda_runtime.h>

 void checkPointerLocation(const void* ptr, const std::string& name) {
-cudaPointerAttributes attributes;
-cudaError_t err = cudaPointerGetAttributes(&attributes, ptr);
+cudaPointerAttributes attributes;
+cudaError_t err = cudaPointerGetAttributes(&attributes, ptr);

-if (err != cudaSuccess) {
-std::cerr << "Error checking pointer " << name << ": " << cudaGetErrorString(err) << std::endl;
-return;
-}
+if (err != cudaSuccess) {
+std::cerr << "Error checking pointer " << name << ": "
+<< cudaGetErrorString(err) << std::endl;
+return;
+}

-if (attributes.type == cudaMemoryTypeDevice) {
-std::cout << "Pointer " << name << " is located in device memory." << std::endl;
-} else if (attributes.type == cudaMemoryTypeHost) {
-std::cout << "Pointer " << name << " is located in host memory." << std::endl;
-} else {
-std::cout << "Pointer " << name << " is of unknown memory type." << std::endl;
-}
+if (attributes.type == cudaMemoryTypeDevice) {
+std::cout << "Pointer " << name << " is located in device memory."
+<< std::endl;
+} else if (attributes.type == cudaMemoryTypeHost) {
+std::cout << "Pointer " << name << " is located in host memory."
+<< std::endl;
+} else {
+std::cout << "Pointer " << name << " is of unknown memory type."
+<< std::endl;
+}
 }

 void check_contiguity(const torch::Tensor& tensor, const std::string& name) {
-if (tensor.is_contiguous()) {
-std::cout << name << " is contiguous" << std::endl;
-} else {
-std::cout << name << " is not contiguous" << std::endl;
-}
+if (tensor.is_contiguous()) {
+std::cout << name << " is contiguous" << std::endl;
+} else {
+std::cout << name << " is not contiguous" << std::endl;
+}
 }

 template <typename FPTYPE>
@@ -71,11 +75,15 @@
 std::string device;
 GetTensorDevice(table_tensor, device);
 // debug
-std::cout << "table_tensor device: " << table_tensor.device().type() << std::endl;
-std::cout << "table_info_tensor device: " << table_info_tensor.device().type() << std::endl;
-std::cout << "em_x_tensor device: " << em_x_tensor.device().type() << std::endl;
+std::cout << "table_tensor device: " << table_tensor.device().type()
+<< std::endl;
+std::cout << "table_info_tensor device: " << table_info_tensor.device().type()
+<< std::endl;
+std::cout << "em_x_tensor device: " << em_x_tensor.device().type()
+<< std::endl;
 std::cout << "em_tensor device: " << em_tensor.device().type() << std::endl;
-std::cout << "descriptor_tensor device before computation: " << descriptor_tensor.device().type() << std::endl;
+std::cout << "descriptor_tensor device before computation: "
+<< descriptor_tensor.device().type() << std::endl;

 check_contiguity(table_tensor, "table_tensor");
 check_contiguity(table_info_tensor, "table_info_tensor");
@@ -503,12 +511,13 @@
 torch::Tensor descriptor_tensor =
 torch::empty({em_tensor.size(0), 4, last_layer_size}, options);
 // test device
-// std::cout << "table_tensor device: " << table_tensor.device().type() << std::endl;
-// std::cout << "table_info_tensor device: " << table_info_tensor.device().type() << std::endl;
-// std::cout << "em_x_tensor device: " << em_x_tensor.device().type() << std::endl;
-// std::cout << "em_tensor device: " << em_tensor.device().type() << std::endl;
-// std::cout << "descriptor_tensor device: " << descriptor_tensor.device().type() << std::endl;
-// compute
+// std::cout << "table_tensor device: " << table_tensor.device().type() <<
+// std::endl; std::cout << "table_info_tensor device: " <<
+// table_info_tensor.device().type() << std::endl; std::cout << "em_x_tensor
+// device: " << em_x_tensor.device().type() << std::endl; std::cout <<
+// "em_tensor device: " << em_tensor.device().type() << std::endl; std::cout
+// << "descriptor_tensor device: " << descriptor_tensor.device().type() <<
+// std::endl; compute
 TabulateFusionSeAForward<FPTYPE>(table_tensor, table_info_tensor,
 em_x_tensor, em_tensor, at::Tensor(),
 last_layer_size, descriptor_tensor);
@@ -997,15 +1006,15 @@
 last_layer_size);
 }

 TORCH_LIBRARY_FRAGMENT(deepmd, m) {
[CodeQL notice, unused static function: static function TORCH_LIBRARY_FRAGMENT_init_deepmd_2 is unreachable (TORCH_LIBRARY_FRAGMENT_static_init_deepmd_2 must be removed at the same time)]
 m.def("tabulate_fusion_se_a", tabulate_fusion_se_a);
 }
 TORCH_LIBRARY_FRAGMENT(deepmd, m) {
[CodeQL notice, unused static function: static function TORCH_LIBRARY_FRAGMENT_init_deepmd_3 is unreachable (TORCH_LIBRARY_FRAGMENT_static_init_deepmd_3 must be removed at the same time)]
 m.def("tabulate_fusion_se_atten", tabulate_fusion_se_atten);
 }
 TORCH_LIBRARY_FRAGMENT(deepmd, m) {
[CodeQL notice, unused static function: static function TORCH_LIBRARY_FRAGMENT_init_deepmd_4 is unreachable (TORCH_LIBRARY_FRAGMENT_static_init_deepmd_4 must be removed at the same time)]
 m.def("tabulate_fusion_se_t", tabulate_fusion_se_t);
 }
 TORCH_LIBRARY_FRAGMENT(deepmd, m) {
[CodeQL notice, unused static function: static function TORCH_LIBRARY_FRAGMENT_init_deepmd_5 is unreachable (TORCH_LIBRARY_FRAGMENT_static_init_deepmd_5 must be removed at the same time)]
 m.def("tabulate_fusion_se_r", tabulate_fusion_se_r);
 }
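The `check_contiguity` debug helper earlier in this file only prints the result of `tensor.is_contiguous()`. As a rough illustration of what that flag means, the sketch below checks whether a shape/stride pair describes a dense row-major layout. `is_row_major_contiguous` is a hypothetical helper written for this note, strides are counted in elements rather than bytes, and the shapes are chosen only to resemble the `(4, 4, 8)` descriptor used in the tests; it is a simplification, not PyTorch's actual implementation.

```python
def is_row_major_contiguous(shape, strides):
    """Return True if `strides` (in elements) match a dense row-major layout."""
    expected = 1
    # Walk from the innermost dimension outwards.
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size == 1:
            # The stride of a length-1 dimension never affects addressing.
            continue
        if stride != expected:
            return False
        expected *= size
    return True


# A freshly allocated (4, 4, 8) tensor has strides (32, 8, 1).
print(is_row_major_contiguous((4, 4, 8), (32, 8, 1)))
# Swapping the first and last dimensions gives shape (8, 4, 4), strides (1, 8, 32).
print(is_row_major_contiguous((8, 4, 4), (1, 8, 32)))
```

Custom kernels that index raw data pointers generally assume the first layout, which is presumably why the debug output checks it before the computation.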
28 changes: 17 additions & 11 deletions source/tests/pt/test_tabulate_fusion_se_a.py
@@ -4,7 +4,7 @@
 import torch

 from deepmd.pt.utils import (
-env
+env,
 )


@@ -1069,9 +1069,11 @@ def setUp(self):
 -1.664931178025436733e-05,
 -4.312450972708557703e-06,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(8, 132)
-self.table_info_tensor = torch.tensor([0, 0.2, 0.4, 0.01, 0.1, -1], device=env.DEVICE)
+self.table_info_tensor = torch.tensor(
+[0, 0.2, 0.4, 0.01, 0.1, -1], device=env.DEVICE
+)
 self.em_x_tensor = torch.tensor(
 [
 0.0343909,
@@ -1091,7 +1093,7 @@ def setUp(self):
 0.17527857,
 0.04249097,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4)
 self.em_tensor = torch.tensor(
 [
@@ -1160,7 +1162,7 @@ def setUp(self):
 0.18275348,
 0.02921504,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 4)
 self.table_info_tensor.requires_grad = True
 self.table_tensor.requires_grad = True
@@ -1301,7 +1303,7 @@ def setUp(self):
 -0.09300045,
 -0.50528542,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 8)
 # backward test
 self.expected_dy_dem_x = torch.tensor(
@@ -1323,7 +1325,7 @@ def setUp(self):
 -0.02917727,
 -0.04478649,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4)
 self.expected_dy_dem = torch.tensor(
 [
@@ -1392,7 +1394,7 @@ def setUp(self):
 -3.33051143,
 -3.33051143,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 4)

 def test_forward(self):
@@ -1453,11 +1455,15 @@ def test_backward(self):
 self.assertEqual(self.em_tensor.grad.shape, self.expected_dy_dem.shape)

 # Check the values of the gradients
-torch.testing.assert_close(self.em_x_tensor.grad, self.expected_dy_dem_x, atol=1e-5, rtol=1e-5)
+torch.testing.assert_close(
+self.em_x_tensor.grad, self.expected_dy_dem_x, atol=1e-5, rtol=1e-5
+)

-torch.testing.assert_close(self.em_tensor.grad, self.expected_dy_dem, atol=1e-5, rtol=1e-5)
+torch.testing.assert_close(
+self.em_tensor.grad, self.expected_dy_dem, atol=1e-5, rtol=1e-5
+)


 if __name__ == "__main__":
-env.DEVICE = 'cpu'
+env.DEVICE = "cpu"
 unittest.main()
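The backward test above compares gradients with `torch.testing.assert_close(..., atol=1e-5, rtol=1e-5)`. As a rough illustration of what those two tolerances mean, the sketch below re-implements the element-wise criterion `|actual - expected| <= atol + rtol * |expected|` in plain Python. `within_tolerance` is a hypothetical helper, the two expected values are copied from `expected_dy_dem_x`, and the perturbed values are made up; the real `assert_close` also compares shape, dtype and device by default.

```python
def within_tolerance(actual, expected, atol=1e-5, rtol=1e-5):
    """Element-wise closeness: |a - e| <= atol + rtol * |e| for every pair."""
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# Two reference values from expected_dy_dem_x.
expected = [-0.02917727, -0.04478649]
# Made-up "computed" gradients: one set off by under 1e-6, one off by 1e-4.
close = [-0.02917800, -0.04478700]
far = [-0.02927727, -0.04478649]

print(within_tolerance(close, expected))  # True
print(within_tolerance(far, expected))    # False
```

With both tolerances at `1e-5`, the absolute term dominates for gradients of this magnitude, so the check is effectively a five-decimal comparison.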
28 changes: 17 additions & 11 deletions source/tests/pt/test_tabulate_fusion_se_atten.py
@@ -4,7 +4,7 @@
 import torch

 from deepmd.pt.utils import (
-env
+env,
 )


@@ -1069,9 +1069,11 @@ def setUp(self):
 -1.664931178025436733e-05,
 -4.312450972708557703e-06,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(8, 132)
-self.table_info_tensor = torch.tensor([0, 0.2, 0.4, 0.01, 0.1, -1], device=env.DEVICE)
+self.table_info_tensor = torch.tensor(
+[0, 0.2, 0.4, 0.01, 0.1, -1], device=env.DEVICE
+)
 self.em_x_tensor = torch.tensor(
 [
 0.0343909,
@@ -1091,7 +1093,7 @@ def setUp(self):
 0.17527857,
 0.04249097,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4)
 self.em_tensor = torch.tensor(
 [
@@ -1160,7 +1162,7 @@ def setUp(self):
 0.18275348,
 0.02921504,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 4)
 self.two_embed_tensor = torch.tensor(
 [
@@ -1293,7 +1295,7 @@ def setUp(self):
 0.5194672674960213,
 0.04635102497306032,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(8, 16)
 self.table_info_tensor.requires_grad = False
 self.table_tensor.requires_grad = False
@@ -1436,7 +1438,7 @@ def setUp(self):
 -0.162872,
 -0.723229,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 8)
 # backward test
 self.expected_dy_dem_x = torch.tensor(
@@ -1526,7 +1528,7 @@ def setUp(self):
 -3.90654,
 -3.90654,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 4)

 def test_forward(self):
@@ -1583,11 +1585,15 @@ def test_backward(self):
 self.assertEqual(self.em_tensor.grad.shape, self.expected_dy_dem.shape)

 # Check the values of the gradients
-torch.testing.assert_close(self.em_x_tensor.grad, self.expected_dy_dem_x, atol=1e-5, rtol=1e-5)
+torch.testing.assert_close(
+self.em_x_tensor.grad, self.expected_dy_dem_x, atol=1e-5, rtol=1e-5
+)

-torch.testing.assert_close(self.em_tensor.grad, self.expected_dy_dem, atol=1e-5, rtol=1e-5)
+torch.testing.assert_close(
+self.em_tensor.grad, self.expected_dy_dem, atol=1e-5, rtol=1e-5
+)


 if __name__ == "__main__":
-env.DEVICE = 'cpu'
+env.DEVICE = "cpu"
 unittest.main()
20 changes: 12 additions & 8 deletions source/tests/pt/test_tabulate_fusion_se_r.py
@@ -4,7 +4,7 @@
 import torch

 from deepmd.pt.utils import (
-env
+env,
 )


@@ -1069,9 +1069,11 @@ def setUp(self):
 -1.664931178025436733e-05,
 -4.312450972708557703e-06,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(8, 132)
-self.table_info_tensor = torch.tensor([0, 0.2, 0.4, 0.01, 0.1, -1], device=env.DEVICE)
+self.table_info_tensor = torch.tensor(
+[0, 0.2, 0.4, 0.01, 0.1, -1], device=env.DEVICE
+)
 self.em_tensor = torch.tensor(
 [
 0.0343909,
@@ -1091,7 +1093,7 @@ def setUp(self):
 0.17527857,
 0.04249097,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4)
 self.table_info_tensor.requires_grad = True
 self.table_tensor.requires_grad = True
@@ -1231,7 +1233,7 @@ def setUp(self):
 -0.281368,
 -1.471135,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4, 8)
 # backward test
 self.expected_dy_dem = torch.tensor(
@@ -1253,7 +1255,7 @@ def setUp(self):
 -0.095974,
 -0.105310,
 ],
-device=env.DEVICE
+device=env.DEVICE,
 ).reshape(4, 4)

 def test_forward(self):
@@ -1302,9 +1304,11 @@ def test_backward(self):
 self.assertEqual(self.em_tensor.grad.shape, self.expected_dy_dem.shape)

 # Check the values of the gradients
-torch.testing.assert_close(self.em_tensor.grad, self.expected_dy_dem, atol=1e-5, rtol=1e-5)
+torch.testing.assert_close(
+self.em_tensor.grad, self.expected_dy_dem, atol=1e-5, rtol=1e-5
+)


 if __name__ == "__main__":
-env.DEVICE = 'cuda:0'
+env.DEVICE = "cuda:0"
 unittest.main()
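The `__main__` guard in this file pins `env.DEVICE` to `"cuda:0"`, while the `se_a` and `se_atten` test files pin it to `"cpu"`. One possible way to avoid hard-coding either, sketched here under assumptions, is to choose the device from an availability flag. `pick_device` is a hypothetical helper and not part of deepmd-kit; in a real test the flag would typically come from `torch.cuda.is_available()`.

```python
def pick_device(cuda_available: bool, preferred: str = "cuda:0") -> str:
    """Return `preferred` unless it is a CUDA device and CUDA is unavailable."""
    if preferred.startswith("cuda") and not cuda_available:
        # Fall back so the test can still run on CPU-only machines.
        return "cpu"
    return preferred


print(pick_device(cuda_available=False))
print(pick_device(cuda_available=True))
```

This keeps a direct `python test_tabulate_fusion_se_r.py` run usable on machines without a GPU, at the cost of not exercising the CUDA kernels there.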