Skip to content

Commit

Permalink
Updated gates
Browse files Browse the repository at this point in the history
  • Loading branch information
Eagle941 committed Mar 7, 2024
1 parent 25ee86f commit d26686e
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 60 deletions.
24 changes: 24 additions & 0 deletions ProvenZk/Binary.lean
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ lemma is_vector_binary_iff_exists_bool_vec {N n : ℕ} {v : Vector (ZMod N) n}:
simp [Vector.toList, htl]
rfl

def recover_binary_nat {d} (rep : Vector Bool d): Nat := match d with
| 0 => 0
| Nat.succ _ => rep.head.toNat + 2 * recover_binary_nat rep.tail

def recover_binary_zmod' {d n} (rep : Vector (ZMod n) d) : ZMod n := match d with
| 0 => 0
| Nat.succ _ => rep.head + 2 * recover_binary_zmod' rep.tail
Expand All @@ -115,6 +119,26 @@ protected theorem Nat.add_lt_add_of_le_of_lt {a b c d : Nat} (hle : a ≤ b) (hl
a + c < b + d :=
Nat.lt_of_le_of_lt (Nat.add_le_add_right hle _) (Nat.add_lt_add_left hlt _)

def binary_length (n : Nat) : Nat := (Nat.log 2 n).succ

def bit_mod_two (inp : Nat) : Bool := match h:inp%2 with
| 0 => false
| 1 => true
| x + 2 => False.elim (by
have := Nat.mod_lt inp (y := 2)
rw [h] at this
simp at this
contradiction
)

def nat_to_bits_le_full_n (l : Nat): Nat → Vector Bool l := match l with
| 0 => fun _ => Vector.nil
| Nat.succ l => fun i =>
let x := i / 2
let y := bit_mod_two i
let xs := nat_to_bits_le_full_n l x
y ::ᵥ xs

namespace Fin

def msb {d:ℕ} (v : Fin (2^d.succ)): Bool := v.val ≥ 2^d
Expand Down
124 changes: 103 additions & 21 deletions ProvenZk/Gates.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,116 @@ import ProvenZk.Binary

open BigOperators

namespace Gates
namespace GatesDef
variable {N : Nat}
def is_bool (a : ZMod N): Prop := a = 0 ∨ a = 1
-- variable [Fact (Nat.Prime N)]
def is_bool (a : ZMod N): Prop := (1-a)*a = 0
def add (a b : ZMod N): ZMod N := a + b
def mul_acc (a b c : ZMod N): ZMod N := a + (b * c)
def neg (a : ZMod N): ZMod N := a * (-1)
def sub (a b : ZMod N): ZMod N := a - b
def mul (a b : ZMod N): ZMod N := a * b
def div_unchecked [Fact (Nat.Prime N)] (a b out : ZMod N): Prop := (b ≠ 0 ∧ out = a * (1 / b)) ∨ (a = 0 ∧ b = 0 ∧ out = 0)
def div [Fact (Nat.Prime N)] (a b out : ZMod N): Prop := b ≠ 0 ∧ out = a * (1 / b)
def inv [Fact (Nat.Prime N)] (a out : ZMod N): Prop := a ≠ 0 ∧ out = 1 / a
def xor (a b out : ZMod N): Prop := is_bool a ∧ is_bool b ∧ out = a*(1-2*b)+b
def div_unchecked [Fact (Nat.Prime N)] (a b out : ZMod N): Prop := (b ≠ 0 ∧ out*b = a) ∨ (a = 0 ∧ b = 0 ∧ out = 0)
def div [Fact (Nat.Prime N)] (a b out : ZMod N): Prop := b ≠ 0 ∧ out*b = a
def inv (a out : ZMod N): Prop := a ≠ 0 ∧ out*a = 1
def xor (a b out : ZMod N): Prop := is_bool a ∧ is_bool b ∧ out = a+b-a*b-a*b
def or (a b out : ZMod N): Prop := is_bool a ∧ is_bool b ∧ out = a+b-a*b
def and (a b out : ZMod N): Prop := is_bool a ∧ is_bool b ∧ out = a*b
def select (b i1 i2 out : ZMod N): Prop := is_bool b ∧ ((b = 1 ∧ out = i1) ∨ (b = 0 ∧ out = i2))
def lookup (b0 b1 i0 i1 i2 i3 out : ZMod N): Prop := is_bool b0 ∧ is_bool b1 ∧ (
(b0 = 0 ∧ b1 = 0 ∧ out = i0) ∨
(b0 = 1 ∧ b1 = 0 ∧ out = i1) ∨
(b0 = 0 ∧ b1 = 1 ∧ out = i2) ∨
(b0 = 1 ∧ b1 = 1 ∧ out = i3)
)
def cmp (a b : ZMod N) (out : ZMod N): Prop := (a = b ∧ out = 0) ∨
(ZMod.val a < ZMod.val b ∧ out = -1) ∨
(ZMod.val a > ZMod.val b ∧ out = 1)
def is_zero (a out: ZMod N): Prop := (a = 0 ∧ out = 1) ∨ (a != 0 ∧ out = 0)
def select (b i1 i2 out : ZMod N): Prop := is_bool b ∧ out = i2 - b*(i2-i1)
def lookup (b0 b1 i0 i1 i2 i3 out : ZMod N): Prop :=
is_bool b0 ∧ is_bool b1 ∧
out = (i2 - i0) * b1 + i0 + (((i3 - i2 - i1 + i0) * b1 + i1 - i0) * b0)

-- In gnark 8 the number is decomposed in a binary vector with the length of the field order
-- however this doesn't guarantee that the number is unique.
def cmp_8 (a b out : ZMod N): Prop :=
((recover_binary_nat (nat_to_bits_le_full_n (binary_length N) a.val)) % N = a.val) ∧
((recover_binary_nat (nat_to_bits_le_full_n (binary_length N) b.val)) % N = b.val) ∧
((a = b ∧ out = 0) ∨
(a.val < b.val ∧ out = -1) ∨
(a.val > b.val ∧ out = 1))

-- In gnark 9 the number is reduced to the smallest representation, ensuring it is unique.
def cmp_9 (a b out : ZMod N): Prop :=
((recover_binary_nat (nat_to_bits_le_full_n (binary_length N) a.val)) = a.val) ∧
((recover_binary_nat (nat_to_bits_le_full_n (binary_length N) b.val)) = b.val) ∧
((a = b ∧ out = 0) ∨
(a.val < b.val ∧ out = -1) ∨
(a.val > b.val ∧ out = 1))

-- Inverse is calculated using a Hint at circuit execution
def is_zero (a out: ZMod N): Prop := (a ≠ 0 ∧ out = 0) ∨ (a = 0 ∧ out = 1)
def eq (a b : ZMod N): Prop := a = b
def ne (a b : ZMod N): Prop := a ≠ b
def le (a b : ZMod N): Prop := ZMod.val a <= ZMod.val b
def to_binary (a : ZMod N) (n : Nat) (out : Vector (ZMod N) n): Prop := recover_binary_zmod' out = a ∧ is_vector_binary out
def from_binary {d} (a : Vector (ZMod N) d) (out : ZMod N): Prop := (recover_binary_zmod' a : ZMod N) = out
end Gates

def le_8 (a b : ZMod N): Prop :=
(recover_binary_nat (nat_to_bits_le_full_n (binary_length N) a.val)) % N = a.val ∧
(recover_binary_nat (nat_to_bits_le_full_n (binary_length N) b.val)) % N = b.val ∧
a.val <= b.val

def le_9 (a b : ZMod N): Prop :=
(recover_binary_nat (nat_to_bits_le_full_n (binary_length N) a.val)) = a.val ∧
(recover_binary_nat (nat_to_bits_le_full_n (binary_length N) b.val)) = b.val ∧
a.val <= b.val

-- `a(.val)` is always less than `N` because it's `ZMod`.
-- If `a` doesn't fit in `n`, then the result of `recover_binary_zmod'` is `a % 2^n`
-- If `a` fits `n`, the result is exact
def to_binary (a : ZMod N) (d : Nat) (out : Vector (ZMod N) d): Prop :=
@recover_binary_zmod' d N out = a ∧ is_vector_binary out
def from_binary {d} (a : Vector (ZMod N) d) (out : ZMod N) : Prop :=
@recover_binary_zmod' d N a = out ∧ is_vector_binary a
end GatesDef

structure Gates_base (α : Type) : Type where
is_bool : α → Prop
add : α → α → α
mul_acc : α → α → α → α
neg : α → α
sub : α → α → α
mul : α → α → α
div_unchecked : α → α → α → Prop
div : α → α → α → Prop
inv : α → α → Prop
xor : α → α → α → Prop
or : α → α → α → Prop
and : α → α → α → Prop
select : α → α → α → α → Prop
lookup : α → α → α → α → α → α → α → Prop
cmp : α → α → α → Prop
is_zero : α → α → Prop
eq : α → α → Prop
ne : α → α → Prop
le : α → α → Prop
to_binary : α → (n : Nat) → Vector α n → Prop
from_binary : Vector α d → α → Prop

def GatesGnark_8 (N : Nat) [Fact (Nat.Prime N)] : Gates_base (ZMod N) := {
is_bool := GatesDef.is_bool,
add := GatesDef.add,
mul_acc := GatesDef.mul_acc,
neg := GatesDef.neg,
sub := GatesDef.sub,
mul := GatesDef.mul,
div_unchecked := GatesDef.div_unchecked,
div := GatesDef.div,
inv := GatesDef.inv,
xor := GatesDef.xor,
or := GatesDef.or,
and := GatesDef.and,
select := GatesDef.select,
lookup := GatesDef.lookup,
cmp := GatesDef.cmp_8,
is_zero := GatesDef.is_zero,
eq := GatesDef.eq,
ne := GatesDef.ne,
le := GatesDef.le_8,
to_binary := GatesDef.to_binary,
from_binary := GatesDef.from_binary
}

def GatesGnark_9 (N : Nat) [Fact (Nat.Prime N)] : Gates_base (ZMod N) := {
GatesGnark_8 N with
cmp := GatesDef.cmp_9
le := GatesDef.le_9
}
81 changes: 42 additions & 39 deletions ProvenZk/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,95 +26,98 @@ theorem exists_eq_left₂ {pred : α → β → Prop}:
simp [and_assoc]

@[simp]
theorem is_bool_is_bit (a : ZMod n) [Fact (Nat.Prime n)]: Gates.is_bool a = is_bit a := by rfl
theorem is_bool_is_bit (a : ZMod n) [Fact (Nat.Prime n)]: GatesDef.is_bool a = is_bit a := by
simp [is_bit, GatesDef.is_bool, sub_eq_zero]
tauto

@[simp]
theorem Gates.eq_def : Gates.eq a b ↔ a = b := by simp [Gates.eq]
theorem Gates.eq_def : GatesDef.eq a b ↔ a = b := by simp [GatesDef.eq]

@[simp]
theorem Gates.sub_def {N} {a b : ZMod N} : Gates.sub a b = a - b := by simp [Gates.sub]
theorem Gates.sub_def {N} {a b : ZMod N} : GatesDef.sub a b = a - b := by simp [GatesDef.sub]

@[simp]
theorem Gates.is_zero_def {N} {a out : ZMod N} : Gates.is_zero a out ↔ out = Bool.toZMod (a = 0) := by
simp [Gates.is_zero]
theorem Gates.is_zero_def {N} {a out : ZMod N} : GatesDef.is_zero a out ↔ out = Bool.toZMod (a = 0) := by
simp [GatesDef.is_zero]
apply Iff.intro
. rintro (_ | _) <;> simp [*]
. rintro ⟨_⟩
simp [Bool.toZMod, Bool.toNat]
tauto

@[simp]
theorem Gates.select_zero {a b r : ZMod N}: Gates.select 0 a b r = (r = b) := by
simp [Gates.select]
theorem Gates.select_zero {a b r : ZMod N}: GatesDef.select 0 a b r = (r = b) := by
simp [GatesDef.select]

@[simp]
theorem Gates.select_one {a b r : ZMod N}: Gates.select 1 a b r = (r = a) := by
simp [Gates.select]
theorem Gates.select_one {a b r : ZMod N}: GatesDef.select 1 a b r = (r = a) := by
simp [GatesDef.select]

@[simp]
theorem Gates.or_zero { a r : ZMod N}: Gates.or a 0 r = (is_bit a ∧ r = a) := by
simp [Gates.or]
theorem Gates.or_zero { a r : ZMod N}: GatesDef.or a 0 r = (is_bit a ∧ r = a) := by
simp [GatesDef.or]

@[simp]
theorem Gates.zero_or { a r : ZMod N}: Gates.or 0 a r = (is_bit a ∧ r = a) := by
simp [Gates.or]
theorem Gates.zero_or { a r : ZMod N}: GatesDef.or 0 a r = (is_bit a ∧ r = a) := by
simp [GatesDef.or]

@[simp]
theorem Gates.one_or { a r : ZMod N}: Gates.or 1 a r = (is_bit a ∧ r = 1) := by
simp [Gates.or]
theorem Gates.one_or { a r : ZMod N}: GatesDef.or 1 a r = (is_bit a ∧ r = 1) := by
simp [GatesDef.or]

@[simp]
theorem Gates.or_one { a r : ZMod N}: Gates.or a 1 r = (is_bit a ∧ r = 1) := by
simp [Gates.or]
theorem Gates.or_one { a r : ZMod N}: GatesDef.or a 1 r = (is_bit a ∧ r = 1) := by
simp [GatesDef.or]

@[simp]
theorem Gates.is_bit_one_sub {a : ZMod N}: is_bit (Gates.sub 1 a) ↔ is_bit a := by
simp [Gates.sub, is_bit, sub_eq_zero]
theorem Gates.is_bit_one_sub {a : ZMod N}: is_bit (GatesDef.sub 1 a) ↔ is_bit a := by
simp [GatesDef.sub, is_bit, sub_eq_zero]
tauto

@[simp]
theorem Gates.xor_bool {N} [Fact (N>1)] {a b : Bool} {c : ZMod N} : Gates.xor a.toZMod b.toZMod c ↔ c = (a != b).toZMod := by
unfold xor
theorem Gates.xor_bool {N} [Fact (N>1)] {a b : Bool} {c : ZMod N} : GatesDef.xor a.toZMod b.toZMod c ↔ c = (a != b).toZMod := by
unfold GatesDef.xor
cases a <;> cases b <;> {
simp [is_bool, Bool.toZMod, Bool.toNat, bne]
simp [GatesDef.is_bool, Bool.toZMod, Bool.toNat, bne]
try ring_nf
}

@[simp]
theorem Gates.and_bool {N} [Fact (N>1)] {a b : Bool} {c : ZMod N} : Gates.and a.toZMod b.toZMod c ↔ c = (a && b).toZMod := by
unfold and
theorem Gates.and_bool {N} [Fact (N>1)] {a b : Bool} {c : ZMod N} : GatesDef.and a.toZMod b.toZMod c ↔ c = (a && b).toZMod := by
unfold GatesDef.and
cases a <;> cases b <;> {
simp [is_bool, Bool.toZMod, Bool.toNat]
simp [GatesDef.is_bool, Bool.toZMod, Bool.toNat]
}

@[simp]
theorem Gates.or_bool {N} [Fact (N>1)] {a b : Bool} {c : ZMod N} : Gates.or a.toZMod b.toZMod c ↔ c = (a || b).toZMod := by
unfold or
theorem Gates.or_bool {N} [Fact (N>1)] {a b : Bool} {c : ZMod N} : GatesDef.or a.toZMod b.toZMod c ↔ c = (a || b).toZMod := by
unfold GatesDef.or
cases a <;> cases b <;> {
simp [is_bool, Bool.toZMod, Bool.toNat]
simp [GatesDef.is_bool, Bool.toZMod, Bool.toNat]
}

@[simp]
theorem Gates.not_bool {N} [Fact (N>1)] {a : Bool} : (1 : ZMod N) - a.toZMod = (!a).toZMod := by
cases a <;> simp [sub]
cases a <;> simp [GatesDef.sub]

@[simp]
lemma Gates.select_bool {N} [Fact (N > 1)] {c : Bool} {t f r : ZMod N}: Gates.select (c.toZMod (N:=N)) t f r ↔ r = if c then t else f := by
cases c <;> simp [select, is_bool]
lemma Gates.select_bool {N} [Fact (N > 1)] {c : Bool} {t f r : ZMod N}: GatesDef.select (c.toZMod (N:=N)) t f r ↔ r = if c then t else f := by
cases c <;> simp [GatesDef.select, GatesDef.is_bool]

@[simp]
lemma Gates.eq_1_toZMod {N} [Fact (N>1)] {b : Bool}: Gates.eq (b.toZMod (N:=N)) 1 ↔ b := by
cases b <;> simp [eq, is_bool]
lemma Gates.eq_1_toZMod {N} [Fact (N>1)] {b : Bool}: GatesDef.eq (b.toZMod (N:=N)) 1 ↔ b := by
cases b <;> simp [GatesDef.eq, GatesDef.is_bool]

@[simp]
lemma Gates.ite_0_toZMod {N} [Fact (N>1)] {b f: Bool}: (if b then (0:ZMod N) else f.toZMod (N:=N)) = (if b then false else f).toZMod := by
cases b <;> simp

theorem Gates.to_binary_rangecheck {a : ZMod N} {n out} (h: to_binary a n out): a.val < 2^n := by
theorem Gates.to_binary_rangecheck {a : ZMod N} {n out} (h: GatesDef.to_binary a n out): a.val < 2^n := by
rcases h with ⟨hrec, hbin⟩
replace hbin := is_vector_binary_iff_exists_bool_vec.mp hbin
rcases hbin with ⟨x, ⟨_⟩⟩
rw [recover_binary_zmod'_map_toZMod_eq_Fin_ofBitsLE] at hrec

cases Nat.lt_or_ge (2^n) N with
| inl hp =>
cases hrec
Expand All @@ -127,8 +130,8 @@ theorem Gates.to_binary_rangecheck {a : ZMod N} {n out} (h: to_binary a n out):
. simp [*]

lemma Gates.to_binary_iff_eq_Fin_ofBitsLE {l : ℕ} {a : ZMod N} {v : Vector (ZMod N) l}:
Gates.to_binary a l v ↔ ∃v', v = v'.map Bool.toZMod ∧ a = (Fin.ofBitsLE v').val := by
unfold to_binary
GatesDef.to_binary a l v ↔ ∃v', v = v'.map Bool.toZMod ∧ a = (Fin.ofBitsLE v').val := by
unfold GatesDef.to_binary
rw [is_vector_binary_iff_exists_bool_vec]
apply Iff.intro
. rintro ⟨⟨_⟩, ⟨x, ⟨_⟩⟩⟩
Expand All @@ -152,7 +155,7 @@ lemma map_toZMod_ofZMod_eq_self_of_is_vector_binary {n : ℕ} {v : Vector (ZMod
simp [*]

lemma Gates.to_binary_iff_eq_fin_to_bits_le_of_pow_length_lt {l : ℕ} {a : ZMod N} {v : Vector (ZMod N) l} (pow_lt : 2 ^ l < N):
Gates.to_binary a l v ↔ ∃(ha : a.val < 2^l), v = (Fin.toBitsLE ⟨a.val, ha⟩).map Bool.toZMod := by
GatesDef.to_binary a l v ↔ ∃(ha : a.val < 2^l), v = (Fin.toBitsLE ⟨a.val, ha⟩).map Bool.toZMod := by
apply Iff.intro
. intro to_bin
have := Gates.to_binary_rangecheck to_bin
Expand All @@ -171,5 +174,5 @@ lemma Gates.to_binary_iff_eq_fin_to_bits_le_of_pow_length_lt {l : ℕ} {a : ZMod
simp [*]

lemma Gates.from_binary_iff_eq_ofBitsLE_mod_order {l : ℕ} {a : Vector Bool l} {out : ZMod N}:
Gates.from_binary (a.map Bool.toZMod) out ↔ out = (Fin.ofBitsLE a).val := by
simp [from_binary, recover_binary_zmod'_map_toZMod_eq_Fin_ofBitsLE, eq_comm]
GatesDef.from_binary (a.map Bool.toZMod) out ↔ out = (Fin.ofBitsLE a).val := by
simp [GatesDef.from_binary, recover_binary_zmod'_map_toZMod_eq_Fin_ofBitsLE, eq_comm]

0 comments on commit d26686e

Please sign in to comment.