Skip to content

Commit

Permalink
♻️ Final cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Pascal Marco Caversaccio <[email protected]>
  • Loading branch information
pcaversaccio committed Nov 6, 2024
1 parent 8b51eba commit 8de3c5d
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions src/snekmate/utils/math.vy
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ def _mul_div(x: uint256, y: uint256, denominator: uint256, roundup: bool) -> uin
assert denominator != empty(uint256), "math: mul_div division by zero"

# 512-bit multiplication "[prod1 prod0] = x * y".
# Compute the product "mod 2 ** 256" and "mod 2 ** 256 - 1".
# Compute the product "mod 2**256" and "mod 2**256 - 1".
# Then use the Chinese Remainder theorem to reconstruct
# the 512-bit result. The result is stored in two 256-bit
# variables, where: "product = prod1 * 2 ** 256 + prod0".
# variables, where: "product = prod1 * 2**256 + prod0".
mm: uint256 = uint256_mulmod(x, y, max_value(uint256))
# The least significant 256 bits of the product.
prod0: uint256 = unsafe_mul(x, y)
Expand All @@ -147,12 +147,12 @@ def _mul_div(x: uint256, y: uint256, denominator: uint256, roundup: bool) -> uin
# Calculate "ceil((x * y) / denominator)". The following
# line cannot overflow because we have the previous check
# "(x * y) % denominator != 0", which accordingly rules out
# the possibility of "x * y = 2 ** 256 - 1" and `denominator == 1`.
# the possibility of "x * y = 2**256 - 1" and `denominator == 1`.
return unsafe_add(unsafe_div(prod0, denominator), 1)

return unsafe_div(prod0, denominator)

# Ensure that the result is less than "2 ** 256". Also,
# Ensure that the result is less than "2**256". Also,
# prevents that `denominator == 0`.
assert denominator > prod1, "math: mul_div overflow"

Expand Down Expand Up @@ -181,34 +181,34 @@ def _mul_div(x: uint256, y: uint256, denominator: uint256, roundup: bool) -> uin
denominator_div: uint256 = unsafe_div(denominator, twos)
# Divide "[prod1 prod0]" by `twos`.
prod0 = unsafe_div(prod0, twos)
# Flip `twos` such that it is "2 ** 256 / twos". If `twos` is zero,
# Flip `twos` such that it is "2**256 / twos". If `twos` is zero,
# it becomes one.
twos = unsafe_add(unsafe_div(unsafe_sub(empty(uint256), twos), twos), 1)

# Shift bits from `prod1` to `prod0`.
prod0 |= unsafe_mul(prod1, twos)

# Invert the denominator "mod 2 ** 256". Since the denominator is
# now an odd number, it has an inverse modulo "2 ** 256", so we have:
# "denominator * inverse = 1 mod 2 ** 256". Calculate the inverse by
# Invert the denominator "mod 2**256". Since the denominator is
# now an odd number, it has an inverse modulo "2**256", so we have:
# "denominator * inverse = 1 mod 2**256". Calculate the inverse by
# starting with a seed that is correct for four bits. That is,
# "denominator * inverse = 1 mod 2 ** 4".
# "denominator * inverse = 1 mod 2**4".
inverse: uint256 = unsafe_mul(3, denominator_div) ^ 2

# Use Newton-Raphson iteration to improve accuracy. Thanks to Hensel's
# lifting lemma, this also works in modular arithmetic by doubling the
# correct bits in each step.
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2 ** 8".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2 ** 16".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2 ** 32".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2 ** 64".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2 ** 128".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2 ** 256".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2**8".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2**16".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2**32".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2**64".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2**128".
inverse = unsafe_mul(inverse, unsafe_sub(2, unsafe_mul(denominator_div, inverse))) # Inverse "mod 2**256".

# Since the division is now exact, we can divide by multiplying
# with the modular inverse of the denominator. This returns the
# correct result modulo "2 ** 256". Since the preconditions guarantee
# that the result is less than "2 ** 256", this is the final result.
# correct result modulo "2**256". Since the preconditions guarantee
# that the result is less than "2**256", this is the final result.
# We do not need to calculate the high bits of the result and
# `prod1` is no longer necessary.
result: uint256 = unsafe_mul(prod0, inverse)
Expand Down Expand Up @@ -395,13 +395,13 @@ def _wad_ln(x: int256) -> int256:
if x == empty(int256):
return empty(int256)

# We want to convert `x` from "10 ** 18" fixed point to "2 ** 96"
# fixed point. We do this by multiplying by "2 ** 96 / 10 ** 18".
# But since "ln(x * C) = ln(x) + ln(C)" holds, we can just do nothing
# here and add "ln(2 ** 96 / 10 ** 18)" at the end.
# We want to convert `x` from "10**18" fixed point to "2**96"
# fixed point. We do this by multiplying by "2**96 / 10**18".
# But since "ln(x * C) = ln(x) + ln(C)" holds, we can just do
# nothing here and add "ln(2**96 / 10**18)" at the end.

# Reduce the range of `x` to "(1, 2) * 2 ** 96".
# Also remember that "ln(2 ** k * x) = k * ln(2) + ln(x)" holds.
# Reduce the range of `x` to "(1, 2) * 2**96".
# Also remember that "ln(2**k * x) = k * ln(2) + ln(x)" holds.
k: int256 = unsafe_sub(convert(self._log2(convert(x, uint256), False), int256), 96)
# Note that to circumvent Vyper's safecast feature for the potentially
# negative expression `x <<= uint256(159 - k)`, we first convert the
Expand All @@ -422,7 +422,7 @@ def _wad_ln(x: int256) -> int256:
p = unsafe_sub(unsafe_mul(p, x) >> 96, 14_706_773_417_378_608_786_704_636_184_526)
p = unsafe_sub(unsafe_mul(p, x), 795_164_235_651_350_426_258_249_787_498 << 96)

# We leave `p` in the "2 ** 192" base so that we do not have to scale it up
# We leave `p` in the "2**192" base so that we do not have to scale it up
# again for the division. Note that `q` is monic by convention.
q: int256 = unsafe_add(
unsafe_mul(unsafe_add(x, 5_573_035_233_440_673_466_300_451_813_936), x) >> 96,
Expand All @@ -435,15 +435,15 @@ def _wad_ln(x: int256) -> int256:
q = unsafe_add(unsafe_mul(q, x) >> 96, 909_429_971_244_387_300_277_376_558_375)

# It is known that the polynomial `q` has no zeros in the domain.
# No scaling is required, as `p` is already "2 ** 96" too large. Also,
# `r` is in the range "(0, 0.125) * 2 ** 96" after the division.
# No scaling is required, as `p` is already "2**96" too large. Also,
# `r` is in the range "(0, 0.125) * 2**96" after the division.
r: int256 = unsafe_div(p, q)

# To finalise the calculation, we have to proceed with the following steps:
# - multiply by the scaling factor "s = 5.549...",
# - add "ln(2 ** 96 / 10 ** 18)",
# - add "ln(2**96 / 10**18)",
# - add "k * ln(2)", and
# - multiply by "10 ** 18 / 2 ** 96 = 5 ** 18 >> 78".
# - multiply by "10**18 / 2**96 = 5**18 >> 78".
# In order to perform the most gas-efficient calculation, we carry out all
# these steps in one expression.
return (
Expand Down Expand Up @@ -477,17 +477,17 @@ def _wad_exp(x: int256) -> int256:
if x <= -41_446_531_673_892_822_313:
return empty(int256)

# When the result is "> (2 ** 255 - 1) / 1e18" we cannot represent it as a signed integer.
# This happens when "x >= floor(log((2 ** 255 - 1) / 1e18) * 1e18) ~ 135".
# When the result is "> (2**255 - 1) / 1e18" we cannot represent it as a signed integer.
# This happens when "x >= floor(log((2**255 - 1) / 1e18) * 1e18) ~ 135".
assert x < 135_305_999_368_893_231_589, "math: wad_exp overflow"

# `x` is now in the range "(-42, 136) * 1e18". Convert to "(-42, 136) * 2 ** 96" for higher
# `x` is now in the range "(-42, 136) * 1e18". Convert to "(-42, 136) * 2**96" for higher
# intermediate precision and a binary base. This base conversion is a multiplication with
# "1e18 / 2 ** 96 = 5 ** 18 / 2 ** 78".
# "1e18 / 2**96 = 5**18 / 2**78".
x = unsafe_div(x << 78, 5 ** 18)

# Reduce the range of `x` to "(-½ ln 2, ½ ln 2) * 2 ** 96" by factoring out powers of two
# so that "exp(x) = exp(x') * 2 ** k", where `k` is a signer integer. Solving this gives
# Reduce the range of `x` to "(-½ ln 2, ½ ln 2) * 2**96" by factoring out powers of two
# so that "exp(x) = exp(x') * 2**k", where `k` is a signer integer. Solving this gives
# "k = round(x / log(2))" and "x' = x - k * log(2)". Thus, `k` is in the range "[-61, 195]".
k: int256 = unsafe_add(unsafe_div(x << 96, 54_916_777_467_707_473_351_141_471_128), 2 ** 95) >> 96
x = unsafe_sub(x, unsafe_mul(k, 54_916_777_467_707_473_351_141_471_128))
Expand All @@ -509,7 +509,7 @@ def _wad_exp(x: int256) -> int256:
4_385_272_521_454_847_904_659_076_985_693_276 << 96,
)

# We leave `p` in the "2 ** 192" base so that we do not have to scale it up
# We leave `p` in the "2**192" base so that we do not have to scale it up
# again for the division.
q: int256 = unsafe_add(
unsafe_mul(unsafe_sub(x, 2_855_989_394_907_223_263_936_484_059_900), x) >> 96,
Expand All @@ -521,15 +521,15 @@ def _wad_exp(x: int256) -> int256:
q = unsafe_add(unsafe_mul(q, x) >> 96, 26_449_188_498_355_588_339_934_803_723_976_023)

# The polynomial `q` has no zeros in the range because all its roots are complex.
# No scaling is required, as `p` is already "2 ** 96" too large. Also,
# `r` is in the range "(0.09, 0.25) * 2 ** 96" after the division.
# No scaling is required, as `p` is already "2**96" too large. Also,
# `r` is in the range "(0.09, 0.25) * 2**96" after the division.
r: int256 = unsafe_div(p, q)

# To finalise the calculation, we have to multiply `r` by:
# - the scale factor "s = ~6.031367120",
# - the factor "2 ** k" from the range reduction, and
# - the factor "1e18 / 2 ** 96" for the base conversion.
# We do this all at once, with an intermediate result in "2 ** 213" base,
# - the factor "2**k" from the range reduction, and
# - the factor "1e18 / 2**96" for the base conversion.
# We do this all at once, with an intermediate result in "2**213" base,
# so that the final right shift always gives a positive value.

# Note that to circumvent Vyper's safecast feature for the potentially
Expand Down Expand Up @@ -605,15 +605,15 @@ def _wad_cbrt(x: uint256) -> uint256:
log2x: uint256 = self._log2(value, False)

# If we divide log2x by 3, the remainder is "log2x % 3". So if we simply
# multiply "2 ** (log2x/3)" and discard the remainder to calculate our guess,
# multiply "2**(log2x/3)" and discard the remainder to calculate our guess,
# the Newton-Raphson method takes more iterations to converge to a solution
# because it lacks this precision. A few more calculations now in order to
# do fewer calculations later:
# - "pow = log2(x) // 3" (the operator `//` means integer division),
# - "remainder = log2(x) % 3",
# - "initial_guess = 2 ** pow * cbrt(2) ** remainder".
# - "initial_guess = 2**pow * cbrt(2)**remainder".
# Now substituting "2 = 1.26 ≈ 1,260 / 1,000", we get:
# - "initial_guess = 2 ** pow * 1,260 ** remainder // 1,000 ** remainder".
# - "initial_guess = 2**pow * 1,260**remainder // 1,000**remainder".
remainder: uint256 = log2x % 3
y: uint256 = unsafe_div(
unsafe_mul(pow_mod256(2, unsafe_div(log2x, 3)), pow_mod256(1_260, remainder)), pow_mod256(1_000, remainder)
Expand Down

0 comments on commit 8de3c5d

Please sign in to comment.