diff --git a/Cargo.toml b/Cargo.toml index c13a4cb..8d725e1 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,4 +19,3 @@ rustc-hash = "1.1" codegen-units = 1 lto = true opt-level = 3 -panic = "abort" diff --git a/pyproject.toml b/pyproject.toml index a5ae6bf..c2d241e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,22 @@ -[build-system] -requires = ["maturin>=0.15,<0.16"] -build-backend = "maturin" - [project] name = "cotengrust" -requires-python = ">=3.7" +version = "0.1.0" +description = "Fast contraction ordering primitives for tensor networks." +requires-python = ">=3.8" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] +license = { file = "LICENSE" } +authors = [ + {name = "Johnnie Gray", email = "johnniemcgray@gmail.com"} +] + +[build-system] +requires = ["maturin>=0.15,<0.16"] +build-backend = "maturin" [tool.maturin] features = ["pyo3/extension-module"] diff --git a/src/lib.rs b/src/lib.rs index f34173f..eb8bfe7 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ type SubContraction = (Legs, Score, BitPath); /// helper struct to build contractions from bottom up struct ContractionProcessor { nodes: Dict, - edges: Dict>, + edges: Dict>, appearances: Vec, sizes: Vec, ssa: Node, @@ -133,7 +133,7 @@ impl ContractionProcessor { size_dict: Dict, ) -> ContractionProcessor { let mut nodes: Dict = Dict::default(); - let mut edges: Dict> = Dict::default(); + let mut edges: Dict> = Dict::default(); let mut indmap: Dict = Dict::default(); let mut sizes: Vec = Vec::with_capacity(size_dict.len()); let mut appearances: Vec = Vec::with_capacity(size_dict.len()); @@ -147,7 +147,7 @@ impl ContractionProcessor { None => { // index not parsed yet indmap.insert(ind, c); - edges.insert(c, vec![i as Node]); + edges.insert(c, std::iter::once(i as Node).collect()); appearances.push(1); sizes.push(f32::log(size_dict[&ind] as f32, 2.0)); legs.push((c, 1)); @@ -156,7 +156,7 @@ impl ContractionProcessor { Some(&ix) => { // index already present appearances[ix as usize] += 1; - edges.get_mut(&ix).unwrap().push(i as Node); + edges.get_mut(&ix).unwrap().insert(i as Node); legs.push((ix, 1)); } }; @@ -204,11 +204,15 @@ impl ContractionProcessor { fn pop_node(&mut self, i: Node) -> Legs { let legs = self.nodes.remove(&i).unwrap(); for (ix, _) in legs.iter() { - let nodes = self.edges.get_mut(&ix).unwrap(); - if nodes.len() == 1 { + let enodes = match self.edges.get_mut(&ix) { + Some(enodes) => enodes, + // if repeated index, might have already been removed + None => continue, + }; + enodes.remove(&i); + if enodes.len() == 0 { + // last node with this index -> remove from map self.edges.remove(&ix); - } else { - nodes.retain(|&j| j != i); } } legs @@ -221,8 +225,8 @@ impl ContractionProcessor { for (ix, _) in &legs { self.edges .entry(*ix) - .and_modify(|nodes| nodes.push(i)) - .or_insert(vec![i]); + .and_modify(|nodes| {nodes.insert(i);}) + .or_insert(std::iter::once(i as Node).collect()); } self.nodes.insert(i, legs); i @@ -393,6 +397,8 @@ impl ContractionProcessor { // get the initial candidate contractions for ix_nodes in self.edges.values() { + // convert to vector for combinational indexing + let ix_nodes: Vec = ix_nodes.iter().cloned().collect(); // for all combinations of nodes with a connected edge for ip in 0..ix_nodes.len() { let i = ix_nodes[ip]; @@ -617,7 +623,7 @@ impl ContractionProcessor { subgraph: Vec, minimize: Option, cost_cap: Option, - allow_outer: Option, + search_outer: Option, ) { // parse the minimize argument let minimize = minimize.unwrap_or("flops".to_string()); @@ -642,7 +648,7 @@ impl ContractionProcessor { minimize ), }; - let allow_outer = allow_outer.unwrap_or(false); + let search_outer = search_outer.unwrap_or(false); // storage for each possible contraction to reach subgraph of size m let mut contractions: Vec> = @@ -691,8 +697,8 @@ impl ContractionProcessor { let mut temp_legs: Legs = Vec::with_capacity(ilegs.len() + jlegs.len()); ip = 0; jp = 0; - // if allow_outer -> we will never skip - skip_because_outer = !allow_outer; + // if search_outer -> we will never skip + skip_because_outer = !search_outer; while ip < ilegs.len() && jp < jlegs.len() { if ilegs[ip].0 < jlegs[jp].0 { // index only appears in ilegs @@ -784,16 +790,44 @@ impl ContractionProcessor { &mut self, minimize: Option, cost_cap: Option, - allow_outer: Option, + search_outer: Option, ) { for subgraph in self.subgraphs() { - self.optimize_optimal_connected(subgraph, minimize.clone(), cost_cap, allow_outer); + self.optimize_optimal_connected(subgraph, minimize.clone(), cost_cap, search_outer); } } } // --------------------------- PYTHON FUNCTIONS ---------------------------- // +#[pyfunction] +#[pyo3()] +fn ssa_to_linear(ssa_path: SSAPath, n: Option) -> SSAPath { + let n = match n { + Some(n) => n, + None => ssa_path.iter().map(|v| v.len()).sum::() + ssa_path.len() + 1, + }; + let mut ids: Vec = (0..n).map(|i| i as Node).collect(); + let mut path: SSAPath = Vec::with_capacity(2 * n - 1); + let mut ssa = n as Node; + for scon in ssa_path { + // find the locations of the ssa ids in the list of ids + let mut con: Vec = scon + .iter() + .map(|s| ids.binary_search(s).unwrap() as Node) + .collect(); + // remove the ssa ids from the list + con.sort(); + for j in con.iter().rev() { + ids.remove(*j as usize); + } + path.push(con); + ids.push(ssa); + ssa += 1; + } + path +} + #[pyfunction] #[pyo3()] fn find_subgraphs( @@ -811,10 +845,16 @@ fn optimize_simplify( inputs: Vec>, output: Vec, size_dict: Dict, + use_ssa: Option, ) -> SSAPath { + let n = inputs.len(); let mut cp = ContractionProcessor::new(inputs, output, size_dict); cp.simplify(); - cp.ssa_path + if use_ssa.unwrap_or(false) { + cp.ssa_path + } else { + ssa_to_linear(cp.ssa_path, Some(n)) + } } #[pyfunction] @@ -826,15 +866,23 @@ fn optimize_greedy( costmod: Option, temperature: Option, simplify: Option, + use_ssa: Option, ) -> Vec> { + let n = inputs.len(); let mut cp = ContractionProcessor::new(inputs, output, size_dict); if simplify.unwrap_or(true) { + // perform simplifications cp.simplify(); } + // greddily contract each connected subgraph cp.optimize_greedy(costmod, temperature); // optimize any remaining disconnected terms cp.optimize_remaining_by_size(); - cp.ssa_path + if use_ssa.unwrap_or(false) { + cp.ssa_path + } else { + ssa_to_linear(cp.ssa_path, Some(n)) + } } #[pyfunction] @@ -845,22 +893,31 @@ fn optimize_optimal( size_dict: Dict, minimize: Option, cost_cap: Option, - allow_outer: Option, + search_outer: Option, simplify: Option, + use_ssa: Option, ) -> Vec> { + let n = inputs.len(); let mut cp = ContractionProcessor::new(inputs, output, size_dict); if simplify.unwrap_or(true) { + // perform simplifications cp.simplify(); } - cp.optimize_optimal(minimize, cost_cap, allow_outer); + // optimally contract each connected subgraph + cp.optimize_optimal(minimize, cost_cap, search_outer); // optimize any remaining disconnected terms cp.optimize_remaining_by_size(); - cp.ssa_path + if use_ssa.unwrap_or(false) { + cp.ssa_path + } else { + ssa_to_linear(cp.ssa_path, Some(n)) + } } /// A Python module implemented in Rust. #[pymodule] fn cotengrust(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(ssa_to_linear, m)?)?; m.add_function(wrap_pyfunction!(find_subgraphs, m)?)?; m.add_function(wrap_pyfunction!(optimize_simplify, m)?)?; m.add_function(wrap_pyfunction!(optimize_greedy, m)?)?; diff --git a/tests/test_cotengrust.py b/tests/test_cotengrust.py new file mode 100644 index 0000000..f2f38ab --- /dev/null +++ b/tests/test_cotengrust.py @@ -0,0 +1,185 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose +import cotengra as ctg +import cotengrust as ctgr + + +def find_output_str(lhs): + tmp_lhs = lhs.replace(",", "") + return "".join(s for s in sorted(set(tmp_lhs)) if tmp_lhs.count(s) == 1) + + +def eq_to_inputs_output(eq): + if "->" not in eq: + eq += "->" + find_output_str(eq) + inputs, output = eq.split("->") + inputs = inputs.split(",") + inputs = [list(s) for s in inputs] + output = list(output) + return inputs, output + + +def get_rand_size_dict(inputs, d_min=2, d_max=3): + size_dict = {} + for term in inputs: + for ix in term: + if ix not in size_dict: + size_dict[ix] = np.random.randint(d_min, d_max + 1) + return size_dict + + +def build_arrays(inputs, size_dict): + return [ + np.random.randn(*[size_dict[ix] for ix in term]) for term in inputs + ] + + +# these are taken from opt_einsum +test_case_eqs = [ + # Test scalar-like operations + "a,->a", + "ab,->ab", + ",ab,->ab", + ",,->", + # Test hadamard-like products + "a,ab,abc->abc", + "a,b,ab->ab", + # Test index-transformations + "ea,fb,gc,hd,abcd->efgh", + "ea,fb,abcd,gc,hd->efgh", + "abcd,ea,fb,gc,hd->efgh", + # Test complex contractions + "acdf,jbje,gihb,hfac,gfac,gifabc,hfac", + "cd,bdhe,aidb,hgca,gc,hgibcd,hgac", + "abhe,hidj,jgba,hiab,gab", + "bde,cdh,agdb,hica,ibd,hgicd,hiac", + "chd,bde,agbc,hiad,hgc,hgi,hiad", + "chd,bde,agbc,hiad,bdi,cgh,agdb", + "bdhe,acad,hiab,agac,hibd", + # Test collapse + "ab,ab,c->", + "ab,ab,c->c", + "ab,ab,cd,cd->", + "ab,ab,cd,cd->ac", + "ab,ab,cd,cd->cd", + "ab,ab,cd,cd,ef,ef->", + # Test outer prodcuts + "ab,cd,ef->abcdef", + "ab,cd,ef->acdf", + "ab,cd,de->abcde", + "ab,cd,de->be", + "ab,bcd,cd->abcd", + "ab,bcd,cd->abd", + # Random test cases that have previously failed + "eb,cb,fb->cef", + "dd,fb,be,cdb->cef", + "bca,cdb,dbf,afc->", + "dcc,fce,ea,dbf->ab", + "fdf,cdd,ccd,afe->ae", + "abcd,ad", + "ed,fcd,ff,bcf->be", + "baa,dcf,af,cde->be", + "bd,db,eac->ace", + "fff,fae,bef,def->abd", + "efc,dbc,acf,fd->abe", + # Inner products + "ab,ab", + "ab,ba", + "abc,abc", + "abc,bac", + "abc,cba", + # GEMM test cases + "ab,bc", + "ab,cb", + "ba,bc", + "ba,cb", + "abcd,cd", + "abcd,ab", + "abcd,cdef", + "abcd,cdef->feba", + "abcd,efdc", + # Inner than dot + "aab,bc->ac", + "ab,bcc->ac", + "aab,bcc->ac", + "baa,bcc->ac", + "aab,ccb->ac", + # Randomly built test caes + "aab,fa,df,ecc->bde", + "ecb,fef,bad,ed->ac", + "bcf,bbb,fbf,fc->", + "bb,ff,be->e", + "bcb,bb,fc,fff->", + "fbb,dfd,fc,fc->", + "afd,ba,cc,dc->bf", + "adb,bc,fa,cfc->d", + "bbd,bda,fc,db->acf", + "dba,ead,cad->bce", + "aef,fbc,dca->bde", +] + + +@pytest.mark.parametrize("eq", test_case_eqs) +@pytest.mark.parametrize("which", ["greedy", "optimal"]) +def test_manual_cases(eq, which): + inputs, output = eq_to_inputs_output(eq) + size_dict = get_rand_size_dict(inputs) + arrays = build_arrays(inputs, size_dict) + expected = np.einsum(eq, *arrays, optimize=True) + + path = { + "greedy": ctgr.optimize_greedy, + "optimal": ctgr.optimize_optimal, + }[ + which + ](inputs, output, size_dict) + tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path) + assert_allclose(tree.contract(arrays), expected) + + +@pytest.mark.parametrize("seed", range(10)) +@pytest.mark.parametrize("which", ["greedy", "optimal"]) +def test_basic_rand(seed, which): + inputs, output, shapes, size_dict = ctg.utils.rand_equation( + n=10, + reg=4, + n_out=2, + n_hyper_in=1, + n_hyper_out=1, + d_min=2, + d_max=3, + seed=seed, + ) + eq = ",".join(map("".join, inputs)) + "->" + "".join(output) + + path = { + "greedy": ctgr.optimize_greedy, + "optimal": ctgr.optimize_optimal, + }[ + which + ](inputs, output, size_dict) + + tree = ctg.ContractionTree.from_path(inputs, output, size_dict, path=path) + arrays = [np.random.randn(*s) for s in shapes] + assert_allclose( + tree.contract(arrays), np.einsum(eq, *arrays, optimize=True) + ) + + +def test_optimal_lattice_eq(): + inputs, output, _, size_dict = ctg.utils.lattice_equation( + [4, 5], d_max=3, seed=42 + ) + + path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='flops') + tree = ctg.ContractionTree.from_path( + inputs, output, size_dict, path=path + ) + assert tree.contraction_cost() == 3628 + + path = ctgr.optimize_optimal(inputs, output, size_dict, minimize='size') + tree = ctg.ContractionTree.from_path( + inputs, output, size_dict, path=path + ) + assert tree.contraction_width() == pytest.approx(6.754887502163468)