From 111a0789161e18acfab932eed280a120577acae6 Mon Sep 17 00:00:00 2001 From: Ivo Kubjas Date: Thu, 11 Jul 2024 11:41:45 +0200 Subject: [PATCH] fix: fix OR computation in case one input is constant and other variable (#1181) * test: add regression test * fix: or for constant input --- frontend/cs/scs/api.go | 23 +++++++---------------- frontend/cs/scs/api_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/frontend/cs/scs/api.go b/frontend/cs/scs/api.go index eacb0df6b2..21b6276672 100644 --- a/frontend/cs/scs/api.go +++ b/frontend/cs/scs/api.go @@ -351,9 +351,6 @@ func (builder *builder) Or(a, b frontend.Variable) frontend.Variable { return 0 } - res := builder.newInternalVariable() - builder.MarkBoolean(res) - // if one input is constant, ensure we put it in b if aConstant { a, b = b, a @@ -362,20 +359,14 @@ func (builder *builder) Or(a, b frontend.Variable) frontend.Variable { } if bConstant { - xa := a.(expr.Term) - // b = b - 1 - qL := _b - qL = builder.cs.Sub(qL, builder.tOne) - qL = builder.cs.Mul(qL, xa.Coeff) - // a * (b-1) + res == 0 - builder.addPlonkConstraint(sparseR1C{ - xa: xa.VID, - xc: res.VID, - qL: qL, - qO: builder.tOne, - }) - return res + if builder.cs.IsOne(_b) { + return 1 + } else { + return a + } } + res := builder.newInternalVariable() + builder.MarkBoolean(res) xa := a.(expr.Term) xb := b.(expr.Term) // -a - b + ab + res == 0 diff --git a/frontend/cs/scs/api_test.go b/frontend/cs/scs/api_test.go index 1f81e31c15..10ed705caa 100644 --- a/frontend/cs/scs/api_test.go +++ b/frontend/cs/scs/api_test.go @@ -257,3 +257,36 @@ func TestSubSameNoConstraint(t *testing.T) { t.Fatal("expected 0 constraints") } } + +type regressionOr struct { + A frontend.Variable + constOr int + constCheck int +} + +func (c *regressionOr) Define(api frontend.API) error { + y := api.Or(c.A, c.constOr) + api.AssertIsEqual(y, c.constCheck) + return nil +} + +func TestRegressionOr(t *testing.T) { + assert := test.NewAssert(t) + for _, tc := range []struct{ in, o, c int }{ + {1, 1, 1}, {0, 1, 1}, + {1, 0, 1}, {0, 0, 0}, + } { + ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, ®ressionOr{constOr: tc.o, constCheck: tc.c}) + assert.NoError(err) + w, err := frontend.NewWitness(®ressionOr{ + A: tc.in, + }, ecc.BN254.ScalarField()) + if err != nil { + t.Error("compile", err) + } + _, err = ccs.Solve(w) + if err != nil { + t.Error("solve", err) + } + } +}