
Implement generalized NumPy-friendly oqupy.util.create_delta function #140

Open
Sampreet opened this issue Aug 27, 2024 · 1 comment


Sampreet commented Aug 27, 2024

Summary

Currently, the oqupy.util.create_delta(tensor, index_scrambling) function returns an axis-scrambled tensor from a given tensor. The shape of the output tensor is determined by the index_scrambling parameter, whose elements are axis indices of the original tensor. A way to speed up a special case of this function, where a rank-n tensor is converted to a rank-(n + 1) tensor (equivalent to setting index_scrambling to [0, 1, ..., n-2, n-1, n-1]), has been proposed in a previous issue. I am opening this issue to discuss a more general, NumPy-friendly version of create_delta that eliminates the while loop and the recursive increase_list_of_index function.
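To make these semantics concrete, here is a minimal loop-based sketch of what create_delta computes. It was written for this discussion and is not the oqupy source; the name create_delta_reference is hypothetical. Every element tensor[idx] is copied to the output position formed by picking the entries of idx listed in index_scrambling, and all other entries stay zero.

```python
import numpy as np


def create_delta_reference(tensor: np.ndarray, index_scrambling: list) -> np.ndarray:
    """Loop-based sketch of the create_delta semantics (hypothetical helper).

    For every multi-index idx of `tensor`, tensor[idx] is written to
    ret[tuple(idx[k] for k in index_scrambling)]; all other entries are 0.
    """
    out_shape = tuple(tensor.shape[k] for k in index_scrambling)
    ret = np.zeros(out_shape, dtype=tensor.dtype)
    # np.ndindex visits all prod(tensor.shape) multi-indices in C order
    for idx in np.ndindex(*tensor.shape):
        ret[tuple(idx[k] for k in index_scrambling)] = tensor[idx]
    return ret


# Rank-1 -> rank-2: index_scrambling [0, 0] places a vector on a diagonal
vec = np.array([1.0, 2.0, 3.0])
diag = create_delta_reference(vec, [0, 0])

# Rank-2 -> rank-4, as used with [1, 0, 0, 1] in the tempo backends
mat = np.arange(6.0).reshape(2, 3)
four_legs = create_delta_reference(mat, [1, 0, 0, 1])
```

With index_scrambling = [0, 0] the vector ends up on the diagonal of a matrix, which is the rank-n to rank-(n + 1) special case for n = 1.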

Changelog

  • 2024-09-29: Updated symbols and function docs.
  • 2024-09-01: Corrected typos, updated scripts and added new plots. Removed the runtimes for the extension, as those included variable runtimes for the influence matrices (which use a least-recently-used cache).

Motivation

The rationale behind this change is not only to speed up the create_delta function but also to make it vectorizable and to support just-in-time (JIT) compilation for certain steps.

Related Theory

From what I understand,

  • The increase_list_of_index function iterates over the axes of tensor backwards and returns False once all indices have been covered. This leads to a total of $n_{iters} = m_{n_{in} - 1} \times \cdots \times m_1 \times m_0$ iterations inside the while loop of create_delta, where $n_{in}$ is the rank of tensor and $m_k$ is the number of elements along axis $k$, such that the shape of tensor is ($m_0$, $m_1$, ..., $m_{n_{in} - 1}$). The while loop and its condition function can therefore be replaced by a for loop over $n_{iters}$ iterations, each using precomputed indices for the input and output axes.
  • The ret_ndarray is a tensor of rank $n_{out}$, given by the number of elements in index_scrambling. Since the elements of index_scrambling are axis indices of tensor, one can construct a 2D array indices_in of shape ($n_{in}, n_{iters}$) whose column $j$ holds the index along each axis of tensor at iteration $j$. Selecting the rows of indices_in with index_scrambling then yields the 2D array indices_out of shape ($n_{out}, n_{iters}$) holding the corresponding indices of ret_ndarray. A single advanced-indexing assignment with these two arrays removes the need for the for loop altogether.
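The vectorized idea can be illustrated on a small example. This sketch uses np.indices as a stand-in for the proposed get_indices (both enumerate multi-indices in C, i.e. row-major, order); the variable names mirror the text above, but the example itself is illustrative only.

```python
import numpy as np

# Example input: rank n_in = 2 tensor of shape (m0, m1) = (2, 3)
tensor = np.arange(6.0).reshape(2, 3)
index_scrambling = [1, 0, 0, 1]          # n_out = 4

# indices_in has shape (n_in, n_iters) with n_iters = m0 * m1 = 6;
# column j holds the multi-index of the j-th element in C order
indices_in = np.indices(tensor.shape).reshape(tensor.ndim, -1)

# indices_out selects rows of indices_in, giving shape (n_out, n_iters)
indices_out = indices_in[index_scrambling]

out_shape = tuple(np.array(tensor.shape)[index_scrambling])
ret = np.zeros(out_shape, dtype=tensor.dtype)
# a single advanced-indexing assignment covers all n_iters columns at once
ret[tuple(indices_out)] = tensor[tuple(indices_in)]
```

Each column of indices_out is a distinct output position here, so the assignment writes every input element exactly once.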

Implementation

The create_delta function is modified to:

import numpy as np
from numpy import ndarray
from typing import List

def create_delta(
        tensor: ndarray,
        index_scrambling: List[int],
    ) -> ndarray:
    """Creates deltas in a tensor."""
    # convert to NumPy arrays for a future-proof implementation
    # (see [this issue](https://github.com/google/jax/issues/4564));
    # the shape of the tensor has n_in elements whereas
    # index_scrambling has n_out elements
    _shape = np.array(tensor.shape, dtype=int)
    _idxs = np.array(index_scrambling, dtype=int)

    # obtain the selection indices for each axis, shape (n_in, n_iters)
    _indices = get_indices(_shape, np.prod(_shape))

    # fill the output tensor with the elements of the input tensor
    scrambled_tensor = np.zeros(tuple(_shape[_idxs]), dtype=tensor.dtype)
    scrambled_tensor[tuple(_indices[_idxs])] = tensor[tuple(_indices)]
    return scrambled_tensor

def get_indices(
        shape: ndarray,
        n_iters: int,
    ) -> ndarray:
    """Obtain index matrix for scrambling."""
    # obtain the divisor for each axis as the number of elements
    # contained in all of the subsequent axes,
    # e.g., shape [4, 5, 3] results in divisors [15, 3, 1]
    divisors = np.cumprod(np.concatenate([
        shape[1:],
        np.array([1], dtype=int)
    ])[::-1])[::-1]
    # prepare an iteration matrix of shape (n_iters, n_in)
    # to index each axis, e.g., n_iters = 4 x 5 x 3 = 60
    iteration_matrix = np.arange(0, n_iters).reshape(
        (n_iters, 1)).repeat(shape.shape[0], 1)
    # floor-divide each element by the divisors obtained above
    # (integer division avoids float rounding for large n_iters)
    # and take the remainder modulo the size of each axis;
    # return the index matrix with shape (n_in, n_iters)
    return ((iteration_matrix // divisors) % shape).T

The idea behind separating out the get_indices function is to make it JIT-compatible. Here, n_iters is passed as a separate argument because, under JIT, JAX's arange needs a concrete (static) size that cannot be derived from a traced array. All the corresponding changes can be viewed by comparing the pr/enhancement-create-delta branch.
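As a sanity check on get_indices, the matrix it returns appears to coincide with NumPy's np.unravel_index applied to np.arange(n_iters), which converts flat C-order positions into per-axis coordinates. The sketch below restates get_indices with integer floor division (//) and compares the two on the shape [4, 5, 3] from the code comments. JAX also provides jax.numpy.unravel_index, so the same substitution might carry over to a JIT-compiled version; that is an untested assumption.

```python
import numpy as np


def get_indices(shape: np.ndarray, n_iters: int) -> np.ndarray:
    """Index matrix of shape (n_in, n_iters), restated from the proposal."""
    # divisors: number of elements in all subsequent axes, e.g. [15, 3, 1]
    divisors = np.cumprod(
        np.concatenate([shape[1:], np.array([1], dtype=int)])[::-1])[::-1]
    # each row is one flat position repeated once per axis
    iteration_matrix = np.arange(n_iters).reshape((n_iters, 1)).repeat(shape.shape[0], 1)
    return ((iteration_matrix // divisors) % shape).T


shape = np.array([4, 5, 3], dtype=int)
n_iters = int(np.prod(shape))
proposed = get_indices(shape, n_iters)

# np.unravel_index returns one coordinate array per axis (C order by default)
reference = np.stack(np.unravel_index(np.arange(n_iters), tuple(shape)))
```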

Comparison

The following plots illustrate the speedups for four different scenarios (use-cases in oqupy.backends.tempo_backend, oqupy.backends.pt_tempo_backend and oqupy.process_tensor, and runtimes for multi-time correlations as demonstrated in Fig. 5 of arXiv:2406.16650):

[Plots: create_delta_m_1001_new, create_delta_m_0110_new, create_delta_chi_d_0122_new, 5d_compare]

Reproducing the Comparisons

The following snippet can be used to reproduce the first three plots:

import time
import numpy as np

# create_delta_old: the existing oqupy.util.create_delta, kept for comparison
ms = np.arange(2, 16)**2        # axis dimensions
shapes = [(m, m) for m in ms]
index_scrambling = [1, 0, 0, 1] # or [0, 1, 1, 0]
# # uncomment for third plot
# chi_ds = np.arange(2, 503, 20)  # bond dimensions
# shapes = [(chi_d, chi_d, 4) for chi_d in chi_ds]
# index_scrambling = [0, 1, 2, 2]
funcs = [create_delta_old, create_delta]
times = []
average_over = 5
for shape in shapes:
    # np.complex_ was removed in NumPy 2.0; np.complex128 is equivalent
    tensor_in = np.ones(shape, dtype=np.complex128)
    ts = []
    for func in funcs:
        start = time.perf_counter()
        for _ in range(average_over):
            func(
                tensor=tensor_in,
                index_scrambling=index_scrambling
            )
        ts.append((time.perf_counter() - start) / average_over)
    times.append(ts)

The final plot is obtained with the same methods and parameters as in the arXiv preprint. The runtimes of the general case proposed here are close to those of the special-case snippet posted by @eoin-dp-oneill in the previous issue. I also observed slightly faster runtimes at higher dimensions with a JIT-compiled JAX (CPU) implementation of get_indices. All plots were obtained on an Intel Core i7-8700K processor throttled at 90% usage.

Extension

The initialize_mps_mpo method of oqupy.backends.tempo_backend.BaseTempoBackend calls create_delta successively within a for loop that runs for dkmax_pre_compute - 1 steps, with influence tensors of equal shape (for i != 0). The same holds for the oqupy.backends.pt_tempo_backend.PtTempoBackend.initialize_mps_mpo method, but with one fewer step at the end. Since the influence tensor has the same shape for each i, the index matrix only needs to be computed once, and a new oqupy.util function create_deltas (with an s) can reuse it to speed up this successive computation:

from typing import Callable, List

def create_deltas(
        func_tensors: Callable[[int], ndarray],
        indices: List[int],
        index_scrambling: List[int]) -> List[ndarray]:
    """Creates deltas in multiple tensors of equal shape."""
    # use a test tensor to obtain the indices once for all tensors
    tensor = func_tensors(indices[0])
    _shape = np.array(tensor.shape, dtype=int)
    _idxs = np.array(index_scrambling, dtype=int)
    _indices = get_indices(_shape, np.prod(_shape))
    indices_in = tuple(_indices)
    indices_out = tuple(_indices[_idxs])

    # accumulate scrambled tensors and return list
    scrambled_tensors = []
    for i in indices:
        scrambled = np.zeros(tuple(_shape[_idxs]), dtype=tensor.dtype)
        scrambled[indices_out] = func_tensors(i)[indices_in]
        scrambled_tensors.append(scrambled)
    return scrambled_tensors
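A self-contained usage sketch of the create_deltas idea follows. To keep it standalone, the index matrix is built with np.unravel_index instead of get_indices, create_deltas_sketch is a hypothetical name, and toy_influence is a hypothetical stand-in for self._influence that returns tensors of equal shape for every i. The point is that the index tuples are computed once and reused for every tensor.

```python
from typing import Callable, List

import numpy as np


def _index_matrix(shape) -> np.ndarray:
    # stand-in for get_indices: (n_in, n_iters) multi-indices in C order
    return np.stack(np.unravel_index(np.arange(int(np.prod(shape))), shape))


def create_deltas_sketch(func_tensors: Callable[[int], np.ndarray],
                         indices: List[int],
                         index_scrambling: List[int]) -> List[np.ndarray]:
    first = func_tensors(indices[0])
    idx = _index_matrix(first.shape)          # computed once for all tensors
    idx_in = tuple(idx)
    idx_out = tuple(idx[index_scrambling])
    out_shape = tuple(first.shape[k] for k in index_scrambling)
    results = []
    for i in indices:
        scrambled = np.zeros(out_shape, dtype=first.dtype)
        scrambled[idx_out] = func_tensors(i)[idx_in]
        results.append(scrambled)
    return results


# Hypothetical influence-like tensors with the same shape for every i
def toy_influence(i: int) -> np.ndarray:
    return (i + 1) * np.arange(1.0, 5.0).reshape(2, 2)


deltas = create_deltas_sketch(toy_influence, [1, 2, 3], [1, 0, 0, 1])
```

Reusing one pair of index tuples only works because every tensor has the same shape, which is the assumption stated for i != 0 above.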

The corresponding for loop in initialize_mps_mpo (e.g., for BaseTempoBackend) can then be modified to:

        influences = []

        # this block takes care of `i == 0`
        infl = self._influence(0)
        if self._degeneracy_maps is not None:
            infl_four_legs = np.zeros((tmp_west_deg_num_vals, self._dim**2,
                                       tmp_north_deg_num_vals, self._dim**2),
                                      dtype=NpDtype)
            # a little bit of optimization is done here by
            # removing the `for` loop and assigning all
            # entries at once via advanced indexing
            _idxs = np.arange(self._dim**2)
            deg_indices = (west_degeneracy_map[_idxs], _idxs,
                           north_degeneracy_map[_idxs], _idxs)
            infl_four_legs[deg_indices] = infl[deg_indices[2]]
        else:
            infl_four_legs = create_delta(infl, [1, 0, 0, 1])
        infl_four_legs = np.dot(np.moveaxis(infl_four_legs, 1, -1),
                                self._super_u_dagg)
        infl_four_legs = np.moveaxis(infl_four_legs, -1, 1)
        infl_four_legs = np.dot(infl_four_legs, self._super_u.T)
        influences.append(infl_four_legs)

        # this block takes care of `i > 0`
        if dkmax_pre_compute > 1:
            indices = list(range(1, dkmax_pre_compute))
            influences += create_deltas(self._influence, indices,
                                        [1, 0, 0, 1])
            # # equivalent per-index version (replace the call above
            # # with this to compare against the new function):
            # for index in indices:
            #     infl = self._influence(index)
            #     influences.append(create_delta(infl, [1, 0, 0, 1]))
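The vectorized assignment in the `i == 0` degeneracy branch writes, for every k in range(self._dim**2), the value infl[north_degeneracy_map[k]] into position (west_degeneracy_map[k], k, north_degeneracy_map[k], k). Because the second and fourth indices are k itself, no position is written twice, so it matches an element-wise loop. The check below uses small made-up maps and assumes infl is one-dimensional, as indexing it by north-degeneracy values suggests; none of these values come from oqupy.

```python
import numpy as np

# Hypothetical small sizes and maps (not taken from oqupy)
dim2 = 4                                   # plays the role of self._dim**2
west_map = np.array([0, 1, 1, 0])          # stand-in for west_degeneracy_map
north_map = np.array([0, 0, 1, 1])         # stand-in for north_degeneracy_map
infl = np.array([10.0, 20.0])              # indexed by north degeneracy values

# stand-ins for tmp_west_deg_num_vals and tmp_north_deg_num_vals
shape = (west_map.max() + 1, dim2, north_map.max() + 1, dim2)

# Vectorized version: one advanced-indexing assignment
idxs = np.arange(dim2)
deg_indices = (west_map[idxs], idxs, north_map[idxs], idxs)
vectorized = np.zeros(shape, dtype=infl.dtype)
vectorized[deg_indices] = infl[deg_indices[2]]

# Equivalent element-by-element loop
looped = np.zeros(shape, dtype=infl.dtype)
for k in range(dim2):
    looped[west_map[k], k, north_map[k], k] = infl[north_map[k]]
```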

I have tested the above implementations using tox and reproduced the plots of arXiv:2406.16650 with both of the approaches mentioned. I hope that the generalized function will be useful for implementing more platform-dependent special cases. Kindly share your views at your convenience.

Sampreet changed the title from "Implementing a general NumPy-friendly oqupy.util.create_delta function" to "Implement generalized NumPy-friendly oqupy.util.create_delta function" on Sep 2, 2024.