Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Zaptoss committed Jun 9, 2024
1 parent b94fe18 commit 09b72b3
Showing 1 changed file with 32 additions and 84 deletions.
116 changes: 32 additions & 84 deletions requests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ const (
deuCode = "4474197"

genesisCode = "kPRQYQUcWzW"

balancesEndpoint = "public/balances"
eventsEndpoint = "public/events"
)

var baseProof = zkptypes.ZKProof{
Expand All @@ -41,12 +44,10 @@ var baseProof = zkptypes.ZKProof{
}

func TestCreateBalance(t *testing.T) {
endpoint := "public/balances"

t.Run("SimpleBalance", func(t *testing.T) {
nullifier := "0x0000000000000000000000000000000000000000000000000000000000000001"
body := createBalanceBody(nullifier, genesisCode)
_, respCode := postRequest(t, endpoint, body, nullifier)
_, respCode := postPatchRequest(t, balancesEndpoint, body, nullifier, false)
if respCode != http.StatusOK {
t.Errorf("failed to create simple balance: want %d got %d", http.StatusOK, respCode)
}
Expand All @@ -55,7 +56,7 @@ func TestCreateBalance(t *testing.T) {
t.Run("SameBalance", func(t *testing.T) {
nullifier := "0x0000000000000000000000000000000000000000000000000000000000000001"
body := createBalanceBody(nullifier, genesisCode)
_, respCode := postRequest(t, endpoint, body, nullifier)
_, respCode := postPatchRequest(t, balancesEndpoint, body, nullifier, false)
if respCode != http.StatusConflict {
t.Errorf("want %d got %d", http.StatusConflict, respCode)
}
Expand All @@ -64,7 +65,7 @@ func TestCreateBalance(t *testing.T) {
t.Run("Unauthorized", func(t *testing.T) {
nullifier := "0x0000000000000000000000000000000000000000000000000000000000000002"
body := createBalanceBody(nullifier, genesisCode)
_, respCode := postRequest(t, endpoint, body, "0x1"+nullifier[3:])
_, respCode := postPatchRequest(t, balancesEndpoint, body, "0x1"+nullifier[3:], false)
if respCode != http.StatusUnauthorized {
t.Errorf("want %d got %d", http.StatusUnauthorized, respCode)
}
Expand All @@ -73,15 +74,14 @@ func TestCreateBalance(t *testing.T) {
t.Run("IncorrectCode", func(t *testing.T) {
nullifier := "0x0000000000000000000000000000000000000000000000000000000000000002"
body := createBalanceBody(nullifier, "someAntoherCode")
_, respCode := postRequest(t, endpoint, body, nullifier)
_, respCode := postPatchRequest(t, balancesEndpoint, body, nullifier, false)
if respCode != http.StatusNotFound {
t.Errorf("want %d got %d", http.StatusNotFound, respCode)
}
})
}

func TestVerifyPassport(t *testing.T) {
endpoint := "public/balances"
nullifier := "0x0000000000000000000000000000000000000000000000000000000000000002"
referrer := "0x0000000000000000000000000000000000000000000000000000000000000001"

Expand All @@ -97,14 +97,14 @@ func TestVerifyPassport(t *testing.T) {
body := verifyPassportBody(nullifier, proof)

t.Run("VerifyPassport", func(t *testing.T) {
_, respCode := postRequest(t, endpoint+"/"+nullifier+"/verifypassport", body, nullifier)
_, respCode := postPatchRequest(t, balancesEndpoint+"/"+nullifier+"/verifypassport", body, nullifier, false)
if respCode != http.StatusNoContent {
t.Errorf("failed to verify passport: want %d got %d", http.StatusNoContent, respCode)
}
})

t.Run("VerifyOneMore", func(t *testing.T) {
_, respCode := postRequest(t, endpoint+"/"+nullifier+"/verifypassport", body, nullifier)
_, respCode := postPatchRequest(t, balancesEndpoint+"/"+nullifier+"/verifypassport", body, nullifier, false)
if respCode != http.StatusTooManyRequests {
t.Errorf("want %d got %d", http.StatusTooManyRequests, respCode)
}
Expand All @@ -113,15 +113,14 @@ func TestVerifyPassport(t *testing.T) {
t.Run("IncorrectCoutnryCode", func(t *testing.T) {
proof.PubSignals[zk.Citizenship] = "6974819"
body := verifyPassportBody(referrer, proof)
_, respCode := postRequest(t, endpoint+"/"+referrer+"/verifypassport", body, referrer)
_, respCode := postPatchRequest(t, balancesEndpoint+"/"+referrer+"/verifypassport", body, referrer, false)
if respCode != http.StatusInternalServerError {
t.Errorf("want %d got %d", http.StatusInternalServerError, respCode)
}
})
}

func TestClaimEvent(t *testing.T) {
endpoint := "public/events"
nullifier1 := "0x0000000000000000000000000000000000000000000000000000000000000010"
nullifier2 := "0x0000000000000000000000000000000000000000000000000000000000000020"

Expand All @@ -138,7 +137,7 @@ func TestClaimEvent(t *testing.T) {

t.Run("TryClaimOpenEvent", func(t *testing.T) {
body := claimEventBody(passportScanEventID)
_, respCode := patchRequest(t, endpoint+"/"+passportScanEventID, body, nullifier1)
_, respCode := postPatchRequest(t, eventsEndpoint+"/"+passportScanEventID, body, nullifier1, true)
if respCode != http.StatusNotFound {
t.Errorf("want %d got %d", http.StatusNotFound, respCode)
}
Expand All @@ -154,7 +153,7 @@ func TestClaimEvent(t *testing.T) {

t.Run("TryClaimEventWithoutPassport", func(t *testing.T) {
body := claimEventBody(refSpecEventID)
_, respCode := patchRequest(t, endpoint+"/"+refSpecEventID, body, nullifier1)
_, respCode := postPatchRequest(t, eventsEndpoint+"/"+refSpecEventID, body, nullifier1, true)
if respCode != http.StatusForbidden {
t.Errorf("want %d got %d", http.StatusForbidden, respCode)
}
Expand All @@ -167,7 +166,7 @@ func TestClaimEvent(t *testing.T) {

t.Run("ClaimEvent", func(t *testing.T) {
body := claimEventBody(passportScanEventID)
_, respCode := patchRequest(t, endpoint+"/"+passportScanEventID, body, nullifier2)
_, respCode := postPatchRequest(t, eventsEndpoint+"/"+passportScanEventID, body, nullifier2, true)
if respCode != http.StatusOK {
t.Errorf("want %d got %d", http.StatusOK, respCode)
}
Expand Down Expand Up @@ -231,15 +230,13 @@ func TestCountryPools(t *testing.T) {
})

t.Run("OverLimit", func(t *testing.T) {
endpoint := "public/events"

freeWeeklyEventID, _ := getEventFromList(getEvents(t, nullifier), evtypes.TypeFreeWeekly)
if freeWeeklyEventID == "" {
t.Fatalf("free weekly event absent for %s", nullifier)
}

body := claimEventBody(freeWeeklyEventID)
_, respCode := patchRequest(t, endpoint+"/"+freeWeeklyEventID, body, nullifier)
_, respCode := postPatchRequest(t, eventsEndpoint+"/"+freeWeeklyEventID, body, nullifier, true)
if respCode != http.StatusForbidden {
t.Errorf("want %d got %d", http.StatusForbidden, respCode)
}
Expand All @@ -251,15 +248,13 @@ func TestCountryPools(t *testing.T) {
verifyPassport(t, nullifier, gbrCode)

t.Run("NotAllowedReserve", func(t *testing.T) {
endpoint := "public/events"

freeWeeklyEventID, _ := getEventFromList(getEvents(t, nullifier), evtypes.TypeFreeWeekly)
if freeWeeklyEventID == "" {
t.Fatalf("free weekly event absent for %s", nullifier)
}

body := claimEventBody(freeWeeklyEventID)
_, respCode := patchRequest(t, endpoint+"/"+freeWeeklyEventID, body, nullifier)
_, respCode := postPatchRequest(t, eventsEndpoint+"/"+freeWeeklyEventID, body, nullifier, true)
if respCode != http.StatusForbidden {
t.Errorf("want %d got %d", http.StatusForbidden, respCode)
}
Expand All @@ -280,15 +275,13 @@ func TestCountryPools(t *testing.T) {
})

t.Run("DefaultOverLimit", func(t *testing.T) {
endpoint := "public/events"

freeWeeklyEventID, _ := getEventFromList(getEvents(t, nullifier), evtypes.TypeFreeWeekly)
if freeWeeklyEventID == "" {
t.Fatalf("free weekly event absent for %s", nullifier)
}

body := claimEventBody(freeWeeklyEventID)
_, respCode := patchRequest(t, endpoint+"/"+freeWeeklyEventID, body, nullifier)
_, respCode := postPatchRequest(t, eventsEndpoint+"/"+freeWeeklyEventID, body, nullifier, true)
if respCode != http.StatusForbidden {
t.Errorf("want %d got %d", http.StatusForbidden, respCode)
}
Expand All @@ -305,10 +298,8 @@ func getEventFromList(events resources.EventListResponse, evtype string) (id, st
}

func claimEvent(t *testing.T, id, nullifier string) resources.EventResponse {
endpoint := "public/events"

body := claimEventBody(id)
respBody, respCode := patchRequest(t, endpoint+"/"+id, body, nullifier)
respBody, respCode := postPatchRequest(t, eventsEndpoint+"/"+id, body, nullifier, true)
if respCode != http.StatusOK {
t.Errorf("want %d got %d", http.StatusOK, respCode)
}
Expand All @@ -327,18 +318,15 @@ func verifyPassport(t *testing.T, nullifier, country string) {
proof.PubSignals[zk.Citizenship] = country
body := verifyPassportBody(nullifier, proof)

endpoint := "public/balances"
_, respCode := postRequest(t, endpoint+"/"+nullifier+"/verifypassport", body, nullifier)
_, respCode := postPatchRequest(t, balancesEndpoint+"/"+nullifier+"/verifypassport", body, nullifier, false)
if respCode != http.StatusNoContent {
t.Errorf("failed to verify passport: want %d got %d", http.StatusNoContent, respCode)
}
}

func getEvents(t *testing.T, nullifier string) resources.EventListResponse {
endpoint := "public/events"

respBody, respCode := getRequest(t,
endpoint, func() url.Values {
eventsEndpoint, func() url.Values {
query := url.Values{}
query.Add("filter[nullifier]", nullifier)
return query
Expand All @@ -357,10 +345,8 @@ func getEvents(t *testing.T, nullifier string) resources.EventListResponse {
}

func createBalance(t *testing.T, nullifier, code string) resources.BalanceResponse {
endpoint := "public/balances"

body := createBalanceBody(nullifier, code)
respBody, respCode := postRequest(t, endpoint, body, nullifier)
respBody, respCode := postPatchRequest(t, balancesEndpoint, body, nullifier, false)
if respCode != http.StatusOK {
t.Fatalf("failed to create simple balance: want %d got %d", http.StatusOK, respCode)
}
Expand All @@ -375,10 +361,8 @@ func createBalance(t *testing.T, nullifier, code string) resources.BalanceRespon
}

func getBalance(t *testing.T, nullifier string) resources.BalanceResponse {
endpoint := "public/balances"

respBody, respCode := getRequest(t,
endpoint+"/"+nullifier,
balancesEndpoint+"/"+nullifier,
func() url.Values {
query := url.Values{}
query.Add("referral_codes", "true")
Expand Down Expand Up @@ -435,7 +419,7 @@ func claimEventBody(id string) resources.Relation {
}
}

func patchRequest(t *testing.T, endpoint string, body any, user string) ([]byte, int) {
func postPatchRequest(t *testing.T, endpoint string, body any, user string, patch bool) ([]byte, int) {
if body == nil {
t.Fatal("request body not provided")
}
Expand All @@ -447,42 +431,13 @@ func patchRequest(t *testing.T, endpoint string, body any, user string) ([]byte,
log.Printf(" endpoint=/%s body=%s", endpoint, body)

reqBody := strings.NewReader(string(bodyJSON))
req, err := http.NewRequest("PATCH", apiURL+endpoint, reqBody)
if err != nil {
t.Fatalf("failed to create patch request: %v", err)
}

if user != "" {
req.Header.Set("nullifier", user)
}

resp, err := (&http.Client{Timeout: requestTimeout}).Do(req)
if err != nil {
t.Fatalf("failed to perform patch request: %v", err)
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read resp body: %v", err)
}

log.Printf(" endpoint=/%s body=%s", endpoint, respBody)

return respBody, resp.StatusCode
}

func postRequest(t *testing.T, endpoint string, body any, user string) ([]byte, int) {
if body == nil {
t.Fatal("request body not provided")
}
bodyJSON, err := json.Marshal(body)
if err != nil {
t.Fatalf("failed to marshal request bode: %v", err)
reqType := "POST"
if patch {
reqType = "PATCH"
}

log.Printf(" endpoint=/%s body=%s", endpoint, body)

reqBody := strings.NewReader(string(bodyJSON))
req, err := http.NewRequest("POST", apiURL+endpoint, reqBody)
req, err := http.NewRequest(reqType, apiURL+endpoint, reqBody)
if err != nil {
t.Fatalf("failed to create post request: %v", err)
}
Expand All @@ -495,6 +450,9 @@ func postRequest(t *testing.T, endpoint string, body any, user string) ([]byte,
if err != nil {
t.Fatalf("failed to perform post request: %v", err)
}
defer func() {
resp.Body.Close()
}()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("failed to read resp body: %v", err)
Expand Down Expand Up @@ -527,25 +485,15 @@ func getRequest(t *testing.T, endpoint string, query url.Values, user string) ([
if err != nil {
t.Fatalf("failed to read resp body: %v", err)
}
defer func() {
resp.Body.Close()
}()

log.Printf(" endpoint=/%s body=%s", endpoint, respBody)

return respBody, resp.StatusCode
}

func checkResponseStatus(t *testing.T, got int, expectedCodes ...int) {
// 200 OK code is the most common
if len(expectedCodes) == 0 {
expectedCodes = []int{http.StatusOK}
}
for _, exp := range expectedCodes {
if exp == got {
return
}
}
t.Fatalf("expected status one of %v, got status=%d", expectedCodes, got)
}

var apiURL = func() string {
var cfg struct {
Addr string `fig:"addr,required"`
Expand Down

0 comments on commit 09b72b3

Please sign in to comment.