
polish decomp.py
HydrogenSulfate committed Nov 1, 2024
1 parent 3b27c49 commit 541dae6
Showing 1 changed file with 127 additions and 12 deletions.
139 changes: 127 additions & 12 deletions deepmd/pd/utils/decomp.py
@@ -6,6 +6,9 @@
 # This file will be removed when implemented functions are decomposed into primitive
 # functions in the Paddle framework in the future.

+from __future__ import (
+    annotations,
+)

import paddle

Expand All @@ -23,13 +26,17 @@
def softmax_decomp(x: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
"""Forward decompsition function of softmax.
Args:
x (paddle.Tensor): Input.
axis (int, optional): A dimension along which softmax will be computed. Defaults to -1.
Parameters
----------
x : paddle.Tensor
Input.
axis : int, defaults: -1.
A dimension along which softmax will be computed.
Returns
-------
paddle.Tensor: Computed output.
paddle.Tensor
Computed output.
"""
x_max = paddle.max(x, axis=axis, keepdim=True)
x = x - x_max
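The hunk cuts off before the rest of the body; the collapsed lines presumably exponentiate and renormalize. A minimal sketch of the full stable-softmax decomposition under that assumption (`softmax_sketch` is an illustrative name, not the committed function):

```python
import paddle

def softmax_sketch(x: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
    # Subtract the per-axis max so exp() cannot overflow (the lines shown above).
    x_max = paddle.max(x, axis=axis, keepdim=True)
    shifted = x - x_max
    # Assumed collapsed remainder: exponentiate, then normalize along `axis`.
    exps = paddle.exp(shifted)
    return exps / paddle.sum(exps, axis=axis, keepdim=True)
```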
@@ -39,6 +46,25 @@ def softmax_decomp(x: paddle.Tensor, axis: int = -1) -> paddle.Tensor:
def norm_decomp(
    x: paddle.Tensor, p: float = 2, axis: int = -1, keepdim: bool = False
) -> paddle.Tensor:
"""Forward decompsition function of norm.
Parameters
----------
x : paddle.Tensor
Input
p : float, default: 2
Order of norm
axis : bool, default: -1
Dimensions over which to compute the vector or matrix norm
keepdim : bool, default: False
If set to True, the reduced dimensions are retained in the result as dimensions
with size one
Returns
-------
paddle.Tensor
A real-valued tensor, even when A is complex.
"""
    if p == 2 or p == 2.0:
        # clip for negative indexing, or 1/(0^(k-1)) will cause inf in backward
        return (x * x).sum(axis=axis, keepdim=keepdim).clip(1e-12) ** 0.5
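The `clip(1e-12)` before the square root is what the comment alludes to: the gradient of `sqrt` blows up at zero, so flooring the sum of squares keeps the backward pass finite for all-zero vectors. A quick illustrative check against `paddle.linalg.norm`:

```python
import paddle

x = paddle.to_tensor([[3.0, 4.0], [0.0, 0.0]])
ref = paddle.linalg.norm(x, p=2, axis=-1)      # exact L2 norm: [5., 0.]
dec = (x * x).sum(axis=-1).clip(1e-12) ** 0.5  # decomposed:    [5., 1e-6]
# The floor turns the zero row into 1e-6 rather than 0; a tiny forward
# bias traded for a finite gradient at the origin.
print(ref.numpy(), dec.numpy())
```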
@@ -47,9 +73,27 @@ def norm_decomp(

def take_along_axis_decomp(
    x: paddle.Tensor, indices: paddle.Tensor, axis: int, broadcast: bool = True
-):
-    """Broadcast no used now."""
-    # manually contruct indices for gather_nd(ind_gather_nd.ndim == indices.ndim + 1, the lsat 1 represents the number of dimension(s) of indices)
+) -> paddle.Tensor:
+    """Forward decomposition function of take_along_axis.
+
+    Parameters
+    ----------
+    x : paddle.Tensor
+        The input tensor.
+    indices : paddle.Tensor
+        Indices to take along each 1d slice of array.
+    axis : int
+        The axis to take 1d slices along.
+    broadcast : bool, default: True
+        Whether to broadcast the indices.
+
+    Returns
+    -------
+    paddle.Tensor
+        Computed output.
+    """
+    # manually construct indices for gather_nd (ind_gather_nd.ndim == indices.ndim + 1;
+    # the last 1 represents the number of dimension(s) of indices)
    ind_gather_nd = paddle.stack(
        paddle.meshgrid(*[paddle.arange(v) for v in indices.shape], indexing="ij"),
        axis=-1,
@@ -67,6 +111,27 @@ def scatter_reduce_decomp(
    src: paddle.Tensor,
    reduce: str,
) -> paddle.Tensor:
"""Forward decompsition function of scatter_reduce.
Parameters
----------
input : paddle.Tensor
Input tensor.
axis : int
The axis along which to index.
index : paddle.Tensor
The indices of elements to scatter and reduce.
src : paddle.Tensor
The source elements to scatter and reduce.
reduce : str
The reduction operation to apply for non-unique indices.
Supported modes: ("sum", "prod", "mean", "amax", "amin").
Returns
-------
paddle.Tensor
Computed output.
"""
# reduce: "sum", "prod", "mean", "amax", "amin"
if reduce == "sum":
input.put_along_axis_(indices=index, values=src, axis=axis, reduce="add")
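Only the `"sum"` branch is visible here; it maps the torch-style `reduce="sum"` onto Paddle's in-place `put_along_axis_` with `reduce="add"`. An illustrative use of that branch (unique indices, to sidestep any version-specific duplicate-index behavior):

```python
import paddle

x = paddle.ones([5], dtype="float32")
idx = paddle.to_tensor([0, 2, 4])
src = paddle.to_tensor([10.0, 20.0, 30.0])
# "sum" mode: each addressed slot becomes x[i] + src[j].
x.put_along_axis_(indices=idx, values=src, axis=0, reduce="add")
print(x.numpy())  # [11.  1. 21.  1. 31.]
```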
@@ -86,17 +151,49 @@ def scatter_reduce_decomp(
    return input


-def sec(l: int, size: int) -> list[int]:
-    assert l > 0
+def sec(length: int, size: int) -> list[int]:
+    """Auxiliary function for decomposed functions.
+
+    If length is not divisible by size, the last chunk will be smaller.
+
+    Parameters
+    ----------
+    length : int
+        Length to be chunked.
+    size : int
+        Chunk size.
+
+    Returns
+    -------
+    list[int]
+        Chunked output list.
+    """
+    assert length > 0
    assert size > 0
-    if l % size == 0:
-        return [size] * (l // size)
-    return [size] * (l // size) + [l % size]
+    if length % size == 0:
+        return [size] * (length // size)
+    return [size] * (length // size) + [length % size]
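For example, a length that does not divide evenly yields a smaller tail chunk:

```python
print(sec(10, 3))  # [3, 3, 3, 1]
print(sec(9, 3))   # [3, 3, 3]
```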


def masked_add__decomp(
    x: paddle.Tensor, mask: paddle.Tensor, v: paddle.Tensor
) -> paddle.Tensor:
"""Forward decompsition function of masked_add_(inplace operator).
Parameters
----------
x : paddle.Tensor
Input tensor.
mask : paddle.Tensor
Mask tensor.
v : paddle.Tensor
Value to add.
Returns
-------
paddle.Tensor
Computed output.
"""
    assert mask.dtype == paddle.bool, f"mask must be bool type, but got {mask.dtype}"
    # indices is bool mask
    mask_coord = paddle.concat(
Expand All @@ -120,6 +217,24 @@ def normalize_decomp(
    axis: int = 1,
    epsilon: float = 1e-12,
) -> paddle.Tensor:
"""Forward decompsition function of normalize.
Parameters
----------
x : paddle.Tensor
Input tensor.
p : float, optional
Order of the norm, default: 2
axis : int, optional
Axis on which to perform normalization, default: 1
epsilon : float, optional
Epislon value, default: 1e-12
Returns
-------
paddle.Tensor
Computed output.
"""
    return x / (norm(x, p=p, axis=axis, keepdim=True).clip(min=epsilon))
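The `clip(min=epsilon)` guards the division against all-zero slices. An illustrative check of the same computation, with `paddle.linalg.norm` standing in for the decomposed `norm` above:

```python
import paddle

x = paddle.to_tensor([[3.0, 4.0], [0.0, 0.0]])
# Row-wise L2 normalization; the epsilon floor prevents 0/0 on the zero row.
y = x / paddle.linalg.norm(x, p=2, axis=1, keepdim=True).clip(min=1e-12)
print(y.numpy())  # [[0.6 0.8], [0.  0. ]]
```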


