Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rust moments #46

Merged
merged 24 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r tests/requirements.txt
pip install -e .
pip install .
- name: Run tests
run: |
pytest -v tests/.
Expand Down
19 changes: 19 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Cargo.toml
[package]
name = "anisoap_rust"
version = "0.0.0"
edition = "2021"

[dependencies]
pyo3 = "0.21.0"
numpy = "0.21.0"

[lib]
name = "anisoap_rust_lib" # private module to be nested into Python package,
# needs to match the name of the function with the `[#pymodule]` attribute

path = "src/lib.rs"
crate-type = ["cdylib"] # required for shared library for Python to import from.

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
# See also PyO3 docs on writing Cargo.toml files at https://pyo3.rs
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# MANIFEST.in
include Cargo.toml
recursive-include rust *.rs
13 changes: 8 additions & 5 deletions anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from itertools import product

import numpy as np
from anisoap_rust_lib import compute_moments
rosecers marked this conversation as resolved.
Show resolved Hide resolved
from metatensor import (
Labels,
TensorBlock,
Expand Down Expand Up @@ -30,6 +31,7 @@ def pairwise_ellip_expansion(
sph_to_cart,
radial_basis,
show_progress=False,
rust_moments=True,
):
r"""Computes pairwise expansion

Expand Down Expand Up @@ -136,13 +138,14 @@ def pairwise_ellip_expansion(
constant,
) = radial_basis.compute_gaussian_parameters(r_ij, lengths, rot)

moments = (
np.exp(-0.5 * constant)
* length_norm
* compute_moments_inefficient_implementation(
if rust_moments:
moments = compute_moments(precision, center, maxdeg)
else:
moments = compute_moments_inefficient_implementation(
precision, center, maxdeg=maxdeg
)
)
moments *= np.exp(-0.5 * constant) * length_norm

for l in range(lmax + 1):
deg = l + 2 * (num_ns[l] - 1)
moments_l = moments[: deg + 1, : deg + 1, : deg + 1]
Expand Down
57 changes: 39 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,23 +1,44 @@
[tool.tox]
legacy_tox_ini = """
[tox]
[project]
name = "anisoap"
version = "0.0.0"
requires-python = ">=3.8"
authors = [
{name = "Arthur Lin", email = "[email protected]"},
{name = "Kevin Kazuki Huguenin-Dumittan"},
{name = "Jigyasa Nigam"},
{name = "Yong-Cheol Cho"},
{name = "Lucas Ortengren"},
{name = "Seonwoo Hwang"},
{name = "Rose K. Cersonsky"}
]
description = "A package for computing anisotropic extensions to the SOAP formalism"
readme = "README.md"
license = {file = "LICENSE"}
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
"License :: OSI Approved :: Apache License 2.0",
"Natural Language :: English",

[testenv:tests]
changedir = tests
deps = -rtests/requirements.txt
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
]

commands =
coverage run -m unittest discover -p "*.py"
coverage xml
[build-system]
requires = ["setuptools", "setuptools-rust"]
build-backend = "setuptools.build_meta"

"""
[tool.setuptools.packages]
# Pure Python packages/modules
find = { where = ["."] }

[tool.coverage.run]
branch = true
data_file = "tests/.coverage"
[[tool.setuptools-rust.ext-modules]]
# Private Rust extension module to be nested into the Python package
target = "anisoap_rust_lib" # The last part of the name (e.g. "_lib") has to match lib.name in Cargo.toml,
# but you can add a prefix to nest it inside of a Python package.
path = "Cargo.toml" # Default value, can be omitted
binding = "PyO3" # Default value, can be omitted

[tool.coverage.report]
include = ["anisoap/*"]

[tool.coverage.xml]
output = "tests/coverage.xml"
17 changes: 17 additions & 0 deletions rust/lib.rs
rosecers marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use pyo3::prelude::*;

/// Formats the sum of two numbers as string.
#[pyfunction]
fn fib(n: u64) -> u64 {
if n <= 1 {
return n;
}
fib(n - 1) + fib(n - 2)
}

/// A Python module implemented in Rust.
#[pymodule]
fn fibbers(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(fib, m)?)?;
Ok(())
}
27 changes: 0 additions & 27 deletions setup.cfg

This file was deleted.

10 changes: 0 additions & 10 deletions setup.py

This file was deleted.

175 changes: 175 additions & 0 deletions src/ellip_expansion/compute_moments.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
use numpy::ndarray::{Array3, ArrayView1, ArrayView2};
use pyo3::exceptions::PyAssertionError;
use pyo3::prelude::*;

/// Compute all moments <x^n0 y^n1 z^n2> for a general dilation matrix.
/// Since this computes moments for all n0, n1, and n2, and stores 0 for some
/// impossible configurations, it may not be memory-efficient.
/// However, this implementation allows simple access to all moments with
/// [n0, n1, n2] indexing like normal arrays.
///
/// # Arguments
/// * `dil_mat` - A symmetric, 3x3 matrix, given by np.ndarray from python side.
/// This function will return Err (exception on Python side) if
/// the matrix is not of size 3x3, not symmetric, or not invertible.
/// * `gau_cen` - A 3-dimensional vector for center of tri-variate Gaussian.
/// * `max_deg` - An integer that represents the maximum degree for which moments
/// must be computed. The given number must be positive; otherwise,
/// it will return Err (exception on Python side).
pub fn compute_moments_rust(
dil_mat: ArrayView2<'_, f64>,
gau_cen: ArrayView1<'_, f64>,
max_deg: i32,
) -> PyResult<Array3<f64>> {
// Check if the dilation matrix is a 3x3 matrix.
if dil_mat.shape() != &[3, 3] {
return Err(PyErr::new::<PyAssertionError, _>(
"Dilation matrix needs to be 3x3",
));
}

// Check if the dilation matrix is symmetric
for i in 0..3 {
for j in 0..3 {
if (dil_mat[[i, j]] - dil_mat[[j, i]]).powi(2) >= 1e-14 {
return Err(PyErr::new::<PyAssertionError, _>(
"Dilation matrix needs to be symmetric",
));
}
}
}

if gau_cen.shape() != &[3] {
return Err(PyErr::new::<PyAssertionError, _>(
"Center of Gaussian has to be given by a 3-dim. vector.",
));
}

if max_deg <= 0 {
return Err(PyErr::new::<PyAssertionError, _>(
"The maximum degree needs to be at least 1.",
));
}

// Unpack three values of Gaussian centers, as they will be frequently
// accessed while calculating moments.
let (a0, a1, a2) = (gau_cen[0], gau_cen[1], gau_cen[2]);

// [a, b, c] <- This is how general symmetric 3x3 matrix look like
// [b, d, e] and we only need 6 out of 9 values to compute entire
// [c, e, f] determinant and inverse.
// These values are cached on stack to remove frequent address
// lookups required for indexing
let (a, b, c, d, e, f) = (
dil_mat[[0, 0]],
dil_mat[[0, 1]],
dil_mat[[0, 2]],
dil_mat[[1, 1]],
dil_mat[[1, 2]],
dil_mat[[2, 2]],
);

// cofNM is determinant of resulting matrix after removing N-th row and
// M-th column, with appropriate sign of (-1)^(row + col)
// (i.e. (N, M) co-factor matrix)
let (cof00, cof01, cof02) = (d * f - e * e, c * e - b * f, b * e - c * d);

// Determinant of entire dilation matrix
let det = a * cof00 + b * cof01 + c * cof02;
if det.abs() < 1e-14 {
return Err(PyErr::new::<PyAssertionError, _>(
"The given dilation matrix is singular.",
));
}

// Compute inverse; but since each we use coefficients a lot for moments
// calculation, each elements will be stored as individual variables.
let (cov00, cov01, cov02, cov11, cov12, cov22) = (
cof00 / det, // Use pre-computed co-factors
cof01 / det,
cof02 / det,
(a * f - c * c) / det, // Computed with co-factors
(b * c - a * e) / det,
(a * d - b * b) / det,
);

// Compute global_factor, a number that must be multiplied by before returning.
// global_factor = (2 PI)^1.5 / SQRT(det|dil_mat|)
// = SQRT(8 PI^3 / det|dil_mat|)
let global_factor = (8.0 * (std::f64::consts::PI).powi(3) / det).sqrt();

// Prepare an empty array to store answers
let max_deg = max_deg as usize;
let mut moments = Array3::<f64>::zeros((max_deg + 1, max_deg + 1, max_deg + 1));

// Initialize degree-1 elements
moments[[0, 0, 0]] = 1.0;
moments[[1, 0, 0]] = a0;
moments[[0, 1, 0]] = a1;
moments[[0, 0, 1]] = a2;

if max_deg > 1 {
// Initialize degree-2 elements
moments[[2, 0, 0]] = cov00 + a0 * a0;
moments[[0, 2, 0]] = cov11 + a1 * a1;
moments[[0, 0, 2]] = cov22 + a2 * a2;
moments[[1, 1, 0]] = cov01 + a0 * a1;
moments[[0, 1, 1]] = cov12 + a1 * a2;
moments[[1, 0, 1]] = cov02 + a0 * a2;
}

if max_deg > 2 {
for deg in 2..max_deg {
for n0 in 0..=deg {
for n1 in 0..=(deg - n0) {
let n2 = deg - n0 - n1; // Forces n0 + n1 + n2 = deg
let (n0_pos, n1_pos, n2_pos) = (n0 > 0, n1 > 0, n2 > 0);
let x_iter_add =
0.0 + if n0_pos {
cov00 * n0 as f64 * moments[[n0 - 1, n1, n2]]
} else {
0.0
} + if n1_pos {
cov01 * n1 as f64 * moments[[n0, n1 - 1, n2]]
} else {
0.0
} + if n2_pos {
cov02 * n2 as f64 * moments[[n0, n1, n2 - 1]]
} else {
0.0
};

// Run the x-iteration
moments[[n0 + 1, n1, n2]] = a0 * moments[[n0, n1, n2]] + x_iter_add;

// Run y-iteration if n0 is 0.
if !n0_pos {
let y_iter_add =
0.0 + if n1_pos {
cov11 * n1 as f64 * moments[[n0, n1 - 1, n2]]
} else {
0.0
} + if n2_pos {
cov12 * n2 as f64 * moments[[n0, n1, n2 - 1]]
} else {
0.0
};
moments[[n0, n1 + 1, n2]] = a1 * moments[[n0, n1, n2]] + y_iter_add;

// Run z-iteration if both n0 and n1 are 0.
if !n1_pos {
moments[[n0, n1, n2 + 1]] = a2 * moments[[n0, n1, n2]]
+ if n2_pos {
cov22 * n2 as f64 * moments[[n0, n1, n2 - 1]]
} else {
0.0
}
}
}
}
}
}
}

Ok(moments * global_factor)
}
1 change: 1 addition & 0 deletions src/ellip_expansion/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod compute_moments;
Loading
Loading