From 77dd9cdfd685cc493a9d4d080d75ba10e8e0d640 Mon Sep 17 00:00:00 2001 From: "Davide Gessa (dakk)" Date: Tue, 16 Apr 2024 16:40:58 +0200 Subject: [PATCH] fix mul_even_const --- qlasskit/types/qint.py | 58 +++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/qlasskit/types/qint.py b/qlasskit/types/qint.py index 1d19edef..9c45ecac 100644 --- a/qlasskit/types/qint.py +++ b/qlasskit/types/qint.py @@ -165,11 +165,10 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp: return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums) @staticmethod - def mul_even_const(t_num: TExp, t_const: TExp) -> TExp: + def mul_even_const(t_num: TExp, const: int, result_type: "QintImp") -> TExp: """Multiply by an even const using shift and add (x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2 """ - const = t_const[0].from_bool(t_const[1]) # Multiply t_num by the nearest n | 2**n < t_const n = 1 @@ -178,22 +177,41 @@ def mul_even_const(t_num: TExp, t_const: TExp) -> TExp: if 2**n > const: n -= 1 - t_num_r = t_num[0].crop(t_num[0].shift_left(t_num, n)) + t_num_r = result_type.shift_left((result_type, t_num[1]), n) # Shift t_const by t_const - 2**n r = const - 2**n if r > 0: # Add the shift result to t_num - res = t_num_r[0].add( - t_num_r, t_num[0].crop(t_num[0].shift_left(t_num, int(r / 2))) + res = result_type.add( + (result_type, t_num_r[1]), + result_type.shift_left((result_type, t_num[1]), int(r / 2)), ) else: - res = t_num_r + res = (result_type, t_num_r[1]) return res @classmethod def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 + def __mul_sizing(n, m): + if (n + m) <= 2: + return Qint2 + elif (n + m) > 2 and (n + m) <= 4: + return Qint4 + elif (n + m) > 4 and (n + m) <= 6: + return Qint6 + elif (n + m) > 6 and (n + m) <= 8: + return Qint8 + elif (n + m) > 8 and (n + m) <= 12: + return Qint12 + elif (n + m) > 12 and (n + m) <= 16: + return Qint16 + elif (n + m) > 16: + return Qint16 + else: + raise Exception(f"Mul result size is too big ({n+m})") + # Fill constants so explicit typecast is not needed if cls.is_const(tleft): tleft = tright[0].fill(tleft) @@ -201,6 +219,9 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 if cls.is_const(tright): tright = tleft[0].fill(tright) + n = len(tleft[1]) + m = len(tright[1]) + # If one operand is an even constant, use mul_even_const if cls.is_const(tleft) or cls.is_const(tright): t_num = tleft if cls.is_const(tright) else tright @@ -208,10 +229,9 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 const = t_const[0].from_bool(t_const[1]) if const % 2 == 0: - return cls.mul_even_const(t_num, t_const) - - n = len(tleft[1]) - m = len(tright[1]) + t = __mul_sizing(n, m) + res = cls.mul_even_const(t_num, const, t) + return t.crop(t.fill(res)) if n != m: raise Exception(f"Mul works only on same size Qint: {n} != {m}") @@ -233,22 +253,8 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901 if i + m < n + m: product[i + m] = carry - if (n + m) <= 2: - return Qint2, product - elif (n + m) > 2 and (n + m) <= 4: - return Qint4, product - elif (n + m) > 4 and (n + m) <= 6: - return Qint6, product - elif (n + m) > 6 and (n + m) <= 8: - return Qint8, product - elif (n + m) > 8 and (n + m) <= 12: - return Qint12, product - elif (n + m) > 12 and (n + m) <= 16: - return Qint16, product - elif (n + m) > 16: - return Qint16.crop((Qint16, product)) - - raise Exception(f"Mul result size is too big ({n+m})") + t = __mul_sizing(n, m) + return t.crop(t.fill((t, product))) @classmethod def sub(cls, tleft: TExp, tright: TExp) -> TExp: