Summary
Currently, the oqupy.util.create_delta(tensor, index_scrambling) function returns an axis-scrambled tensor from a given tensor. The shape of the output tensor is determined by the index_scrambling parameter, whose elements are the axis-indices of the original tensor. A way to speed up a special case of this function, where a rank-n tensor is converted to a rank-(n + 1) tensor (equivalent to index_scrambling set to [0, 1, ..., n-2, n-1, n-1]), has been proposed in this issue. I am opening this issue to discuss a more general and NumPy-friendly version of create_delta that eliminates the while loop and the recursive increase_list_of_index function.
Changelog
2024-09-29: Updated symbols and function docs.
2024-09-01: Corrected typos, updated scripts, and added new plots. Removed runtimes for the extension, as those included variable runtimes for the influence matrices (which utilize a least-recently-used cache).
Motivation
The rationale behind this change is not only to speed up the create_delta function but also to make it vectorizable and to support just-in-time (JIT) compilation for certain steps.
Related Theory
From what I understand,
The increase_list_of_index function iterates over the axes of tensor backwards and returns False once all the indices have been covered. This leads to a total of $n_{iters} = m_{n_{in} - 1} \times ... \times m_1 \times m_0$ iterations inside the while loop of create_delta, where $n_{in}$ is the rank of tensor and $m_k$ denotes the number of elements along axis $k$, such that the shape of the tensor is ($m_0$, $m_1$, ..., $m_{n_{in} - 1}$). The while loop and the condition function can therefore be replaced by a for loop over $n_{iters}$ iterations, with precomputed index tuples for the input and output axes at each step.
The ret_ndarray is a tensor of rank $n_{out}$, determined by the total number of elements in index_scrambling. Since the elements of index_scrambling are indices into the axes of tensor, one can construct a 2D array indices_in with shape ($n_{in}$, $n_{iters}$) whose columns enumerate the indices of the individual axes of tensor. Slicing the rows of indices_in with index_scrambling then yields the 2D array indices_out with shape ($n_{out}$, $n_{iters}$) holding the corresponding indices of ret_ndarray. Assigning with these two index arrays at once removes the requirement of a for loop altogether, as illustrated below.
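To make the indexing scheme concrete, here is a minimal sketch for a rank-2 tensor of shape (2, 3) with index_scrambling = [0, 1, 1]; np.unravel_index stands in for the get_indices helper introduced in the implementation below.

import numpy as np

# worked example: scramble a (2, 3) tensor with [0, 1, 1], so that
# output[i, j, j] = input[i, j] and all other entries are zero
tensor = np.arange(6).reshape(2, 3)
n_iters = tensor.size  # 2 x 3 = 6 iterations of the old while loop

# row-major enumeration of all input indices, shape (n_in, n_iters):
# [[0, 0, 0, 1, 1, 1],
#  [0, 1, 2, 0, 1, 2]]
indices_in = np.stack(np.unravel_index(np.arange(n_iters), tensor.shape))

# output indices follow by fancy-indexing the rows with
# index_scrambling, shape (n_out, n_iters)
indices_out = indices_in[[0, 1, 1]]

out = np.zeros((2, 3, 3), dtype=tensor.dtype)
out[tuple(indices_out)] = tensor[tuple(indices_in)]
assert np.array_equal(out[:, 1, 1], tensor[:, 1])  # delta structure holds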
Implementation
The create_delta function is modified to:
def create_delta(
    tensor: ndarray,
    index_scrambling: List[int],
) -> ndarray:
    """Creates deltas in a tensor."""
    # converting to NumPy-array for future-proof implementation
    # see [this issue](https://github.com/google/jax/issues/4564)
    # the shape of the tensor has n_in elements whereas
    # index_scrambling has n_out elements
    _shape = np.array(tensor.shape, dtype=int)
    _idxs = np.array(index_scrambling, dtype=int)
    # obtain the selection indices for each axis
    _indices = get_indices(_shape, np.prod(_shape))
    # scramble output tensor with elements of input tensor
    scrambled_tensor = np.zeros(tuple(_shape[_idxs]),
                                dtype=tensor.dtype)
    scrambled_tensor[tuple(_indices[_idxs])] = tensor[tuple(_indices)]
    return scrambled_tensor

def get_indices(
    shape: ndarray,
    n_iters: int,
) -> ndarray:
    """Obtain index matrix for scrambling."""
    # obtain divisors for each axis as values equal to the
    # number of elements contained up to the preceding axes,
    # e.g., shape [4, 5, 3] will result in [15, 3, 1]
    divisors = np.cumprod(np.concatenate([
        shape[1:],
        np.array([1], dtype=int)
    ])[::-1])[::-1]
    # prepare an iteration matrix of shape (n_iters, n_in)
    # to index each axis, e.g., n_iters = 4 x 5 x 3
    iteration_matrix = np.arange(0, n_iters).reshape(
        (n_iters, 1)).repeat(shape.shape[0], 1)
    # divide each element by the divisors obtained above
    # and take the remainder modulo the size of each axis;
    # return the index matrix with shape (n_in, n_iters)
    return ((iteration_matrix / divisors).astype(int) % shape).T
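As a quick sanity check (a usage sketch, not part of the proposed change), the modified function can be exercised on a small rank-2 tensor; with index_scrambling = [1, 0, 0, 1], the element tensor[i, j] should land at position [j, i, i, j] of the rank-4 output:

import numpy as np

a = np.arange(4, dtype=float).reshape(2, 2)
d = create_delta(a, [1, 0, 0, 1])
assert d.shape == (2, 2, 2, 2)
assert d[1, 0, 0, 1] == a[0, 1]  # scrambled position holds the value
assert d[1, 0, 0, 0] == 0.0      # off-delta positions stay zero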
The idea behind separating out the get_indices function is to make it JIT-compatible; here, n_iters is passed separately to maintain JAX-compatibility with np.arange (a possible JAX port is sketched below). All the corresponding changes can be viewed by comparing the pr/enhancement-create-delta branch.
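For reference, one possible sketch of such a JAX port is shown below; the exact JIT-ted version used for the timings is not listed here. It assumes n_iters is passed as a static argument so that jnp.arange and the output shape are known at trace time:

from functools import partial

import jax
import jax.numpy as jnp

# sketch of a JIT-compatible get_indices; `n_iters` must be static
# so that jnp.arange and the output shape are fixed at trace time
@partial(jax.jit, static_argnums=(1,))
def get_indices_jax(shape: jnp.ndarray, n_iters: int) -> jnp.ndarray:
    # divisors for each axis, e.g., [4, 5, 3] -> [15, 3, 1]
    divisors = jnp.cumprod(jnp.concatenate([
        shape[1:],
        jnp.array([1], dtype=int)
    ])[::-1])[::-1]
    # iteration matrix of shape (n_iters, n_in)
    iteration_matrix = jnp.arange(0, n_iters).reshape(
        (n_iters, 1)).repeat(shape.shape[0], 1)
    # index matrix of shape (n_in, n_iters)
    return ((iteration_matrix / divisors).astype(int) % shape).T

# usage: indices = get_indices_jax(jnp.array([4, 5, 3]), 60)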
Comparison
The following plots illustrate the speedups for four different scenarios: use-cases in oqupy.backends.tempo_backend, oqupy.backends.pt_tempo_backend, and oqupy.process_tensor, plus runtimes for multi-time correlations as demonstrated in Fig. 5 of arXiv:2406.16650.
Reproducing the Comparisons
The following snippet can be used to reproduce the first three plots:
import time

import numpy as np

# create_delta_old is the current implementation;
# create_delta is the modified version above
ms = np.arange(2, 16)**2  # axis dimensions
shapes = [(m, m) for m in ms]
index_scrambling = [1, 0, 0, 1]  # or [0, 1, 1, 0]
# # uncomment for third plot
# chi_ds = np.arange(2, 503, 20)  # bond dimensions
# shapes = [(chi_d, chi_d, 4) for chi_d in chi_ds]
# index_scrambling = [0, 1, 2, 2]
funcs = [create_delta_old, create_delta]
times = []
average_over = 5
for shape in shapes:
    tensor_in = np.ones(shape, dtype=np.complex_)
    ts = []
    for j in range(len(funcs)):
        start = time.time()
        for i in range(average_over):
            _ = funcs[j](
                tensor=tensor_in,
                index_scrambling=index_scrambling
            )
        ts.append((time.time() - start) / average_over)
    times.append(ts)
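As a hypothetical follow-up (not part of the original snippet), the collected times can be reduced to per-shape speedup factors:

# `times` has shape (len(shapes), 2): column 0 is the old
# implementation, column 1 the proposed one
times_arr = np.array(times)
speedups = times_arr[:, 0] / times_arr[:, 1]
print(f"speedup range: {speedups.min():.1f}x to {speedups.max():.1f}x")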
The final plot is obtained with the same methods and parameters as mentioned in the arXiv preprint. The runtimes of the general case proposed here are close to those of the special-case snippet posted by @eoin-dp-oneill in the previous issue. I also observed slightly faster runtimes at higher dimensions with a JIT-compiled JAX-CPU implementation of get_indices. All plots were obtained using an Intel i7-8700K processor throttled at 90% usage.
Extension
The initialize_mps_mpo method of oqupy.backends.tempo_backend.BaseTempoBackend calls create_delta repeatedly inside a for loop running for dkmax_pre_compute - 1 steps, with the influence tensor having the same shape at every step (for i != 0). The same goes for the oqupy.backends.pt_tempo_backend.PtTempoBackend.initialize_mps_mpo method, but with one step fewer at the end. Since the influence tensors share a common shape and are reproducible for each i, a new oqupy.util function create_deltas (with an s) can be implemented to speed up this successive computation as follows:
def create_deltas(
    func_tensors: callable,
    indices: List[int],
    index_scrambling: List[int]) -> List[ndarray]:
    """Creates deltas in multiple tensors."""
    # use a test tensor to obtain the indices
    tensor = func_tensors(indices[0])
    _shape = np.array(tensor.shape, dtype=int)
    _idxs = np.array(index_scrambling, dtype=int)
    _indices = get_indices(_shape, np.prod(_shape))
    indices_in = tuple(_indices)
    indices_out = tuple(_indices[_idxs])
    # accumulate scrambled tensors and return list
    scrambled_tensors = []
    for i in indices:
        array = np.zeros(tuple(_shape[_idxs]),
                         dtype=tensor.dtype)
        array[indices_out] = func_tensors(i)[indices_in]
        scrambled_tensors.append(array)
    return scrambled_tensors
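A minimal usage sketch follows; fake_influence is a hypothetical stand-in for self._influence that merely returns equal-shaped rank-2 tensors:

import numpy as np
from numpy import ndarray

def fake_influence(i: int) -> ndarray:
    # hypothetical stand-in for `self._influence`; returns
    # equal-shaped rank-2 tensors for every step index
    return np.outer(np.arange(3), np.arange(3)) + i

deltas = create_deltas(fake_influence, list(range(1, 5)), [1, 0, 0, 1])
assert len(deltas) == 4
assert deltas[0].shape == (3, 3, 3, 3)
# each output satisfies deltas[k][j, i, i, j] == fake_influence(k + 1)[i, j]
assert deltas[2][1, 2, 2, 1] == fake_influence(3)[2, 1]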
The corresponding for-loop block in initialize_mps_mpo (say, for BaseTempoBackend) can then be modified to:
influences = []
# this block takes care of `i == 0`
infl = self._influence(0)
if self._degeneracy_maps is not None:
    infl_four_legs = np.zeros((tmp_west_deg_num_vals, self._dim**2,
                               tmp_north_deg_num_vals, self._dim**2),
                              dtype=NpDtype)
    # a little bit of optimization is done here by
    # removing the `for` loop and updating slices
    _idxs = np.array(list(range(self._dim**2)))
    indices = (west_degeneracy_map[_idxs], _idxs,
               north_degeneracy_map[_idxs], _idxs)
    infl_four_legs[indices] = infl[indices[2]]
else:
    infl_four_legs = create_delta(infl, [1, 0, 0, 1])
infl_four_legs = np.dot(np.moveaxis(infl_four_legs, 1, -1),
                        self._super_u_dagg)
infl_four_legs = np.moveaxis(infl_four_legs, -1, 1)
infl_four_legs = np.dot(infl_four_legs, self._super_u.T)
influences.append(infl_four_legs)
# this block takes care of `i > 0`
if dkmax_pre_compute > 1:
    indices = list(range(1, dkmax_pre_compute))
    influences += create_deltas(self._influence, indices,
                                [1, 0, 0, 1])
    # # uncomment to test the new function
    # for index in indices:
    #     infl = self._influence(index)
    #     influences.append(create_delta(infl, [1, 0, 0, 1]))
I have tested the above implementations using tox and reproduced the plots of arXiv:2406.16650 for both of the approaches mentioned. I hope that the generalized function will also be useful for implementing more platform-dependent special cases. Kindly share your views at your convenience.