Skip to content

Commit

Permalink
fix mul_even_const
Browse files Browse the repository at this point in the history
  • Loading branch information
dakk committed Apr 16, 2024
1 parent 48cf912 commit 77dd9cd
Showing 1 changed file with 32 additions and 26 deletions.
58 changes: 32 additions & 26 deletions qlasskit/types/qint.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,10 @@ def add(cls, tleft: TExp, tright: TExp) -> TExp:
return (cls if cls.BIT_SIZE > tleft_e[0].BIT_SIZE else tleft_e[0], sums)

@staticmethod
def mul_even_const(t_num: TExp, t_const: TExp) -> TExp:
def mul_even_const(t_num: TExp, const: int, result_type: "QintImp") -> TExp:
"""Multiply by an even const using shift and add
(x << 3) + (x << 1) # Here 10*x is computed as x*2^3 + x*2
"""
const = t_const[0].from_bool(t_const[1])

# Multiply t_num by the nearest n | 2**n < t_const
n = 1
Expand All @@ -178,40 +177,61 @@ def mul_even_const(t_num: TExp, t_const: TExp) -> TExp:
if 2**n > const:
n -= 1

t_num_r = t_num[0].crop(t_num[0].shift_left(t_num, n))
t_num_r = result_type.shift_left((result_type, t_num[1]), n)

# Shift t_const by t_const - 2**n
r = const - 2**n
if r > 0:
# Add the shift result to t_num
res = t_num_r[0].add(
t_num_r, t_num[0].crop(t_num[0].shift_left(t_num, int(r / 2)))
res = result_type.add(
(result_type, t_num_r[1]),
result_type.shift_left((result_type, t_num[1]), int(r / 2)),
)
else:
res = t_num_r
res = (result_type, t_num_r[1])

return res

@classmethod
def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
def __mul_sizing(n, m):
if (n + m) <= 2:
return Qint2
elif (n + m) > 2 and (n + m) <= 4:
return Qint4
elif (n + m) > 4 and (n + m) <= 6:
return Qint6
elif (n + m) > 6 and (n + m) <= 8:
return Qint8
elif (n + m) > 8 and (n + m) <= 12:
return Qint12
elif (n + m) > 12 and (n + m) <= 16:
return Qint16
elif (n + m) > 16:
return Qint16
else:
raise Exception(f"Mul result size is too big ({n+m})")

# Fill constants so explicit typecast is not needed
if cls.is_const(tleft):
tleft = tright[0].fill(tleft)

if cls.is_const(tright):
tright = tleft[0].fill(tright)

n = len(tleft[1])
m = len(tright[1])

# If one operand is an even constant, use mul_even_const
if cls.is_const(tleft) or cls.is_const(tright):
t_num = tleft if cls.is_const(tright) else tright
t_const = tleft if cls.is_const(tleft) else tright
const = t_const[0].from_bool(t_const[1])

if const % 2 == 0:
return cls.mul_even_const(t_num, t_const)

n = len(tleft[1])
m = len(tright[1])
t = __mul_sizing(n, m)
res = cls.mul_even_const(t_num, const, t)
return t.crop(t.fill(res))

if n != m:
raise Exception(f"Mul works only on same size Qint: {n} != {m}")
Expand All @@ -233,22 +253,8 @@ def mul(cls, tleft: TExp, tright: TExp) -> TExp: # noqa: C901
if i + m < n + m:
product[i + m] = carry

if (n + m) <= 2:
return Qint2, product
elif (n + m) > 2 and (n + m) <= 4:
return Qint4, product
elif (n + m) > 4 and (n + m) <= 6:
return Qint6, product
elif (n + m) > 6 and (n + m) <= 8:
return Qint8, product
elif (n + m) > 8 and (n + m) <= 12:
return Qint12, product
elif (n + m) > 12 and (n + m) <= 16:
return Qint16, product
elif (n + m) > 16:
return Qint16.crop((Qint16, product))

raise Exception(f"Mul result size is too big ({n+m})")
t = __mul_sizing(n, m)
return t.crop(t.fill((t, product)))

@classmethod
def sub(cls, tleft: TExp, tright: TExp) -> TExp:
Expand Down

0 comments on commit 77dd9cd

Please sign in to comment.