Skip to content

Commit

Permalink
Included tests for GitHub attestations (#61)
Browse files Browse the repository at this point in the history
* Included tests for GitHub attestations

- Included tests for GitHub attestations and some simple clean up.

Signed-off-by: naveensrinivasan <[email protected]>

* Fixed review comments

Signed-off-by: naveensrinivasan <[email protected]>

---------

Signed-off-by: naveensrinivasan <[email protected]>
Signed-off-by: Tom Meadows <[email protected]>
Co-authored-by: Tom Meadows <[email protected]>
  • Loading branch information
naveensrinivasan and ChaosInTheCRD authored Jan 22, 2024
1 parent 3e7ddcc commit 07128d2
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 18 deletions.
41 changes: 27 additions & 14 deletions attestation/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"os"
"strings"

"github.com/davecgh/go-spew/spew"
"github.com/in-toto/go-witness/attestation"
"github.com/in-toto/go-witness/attestation/jwt"
"github.com/in-toto/go-witness/cryptoutil"
Expand All @@ -51,18 +50,22 @@ var (
_ attestation.BackReffer = &Attestor{}
)

// init registers the github attestor.
func init() {
attestation.RegisterAttestation(Name, Type, RunType, func() attestation.Attestor {
return New()
})
}

// ErrNotGitlab is an error type that indicates the environment is not a github ci job.
type ErrNotGitlab struct{}

// Error returns the error message for ErrNotGitlab.
func (e ErrNotGitlab) Error() string {
return "not in a github ci job"
}

// Attestor is a struct that holds the necessary information for github attestation.
type Attestor struct {
JWT *jwt.Attestor `json:"jwt,omitempty"`
CIConfigPath string `json:"ciconfigpath"`
Expand All @@ -81,6 +84,7 @@ type Attestor struct {
aud string
}

// New creates and returns a new github attestor.
func New() *Attestor {
return &Attestor{
aud: tokenAudience,
Expand All @@ -89,35 +93,39 @@ func New() *Attestor {
}
}

// Name returns the name of the attestor.
func (a *Attestor) Name() string {
return Name
}

// Type returns the type of the attestor.
func (a *Attestor) Type() string {
return Type
}

// RunType returns the run type of the attestor.
func (a *Attestor) RunType() attestation.RunType {
return RunType
}

// Attest performs the attestation for the github environment.
func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {
if os.Getenv("GITHUB_ACTIONS") != "true" {
return ErrNotGitlab{}
}

jwtString, err := fetchToken(a.tokenURL, os.Getenv("ACTIONS_ID_TOKEN_REQUEST_TOKEN"), "witness")
if err != nil {
return err
return fmt.Errorf("error on fetching token %w", err)
}

spew.Dump(jwtString)
if jwtString == "" {
return fmt.Errorf("empty JWT string")
}

if jwtString != "" {
a.JWT = jwt.New(jwt.WithToken(jwtString), jwt.WithJWKSUrl(a.jwksURL))
if err := a.JWT.Attest(ctx); err != nil {
return err
}
a.JWT = jwt.New(jwt.WithToken(jwtString), jwt.WithJWKSUrl(a.jwksURL))
if err := a.JWT.Attest(ctx); err != nil {
return fmt.Errorf("failed to attest github jwt: %w", err)
}

a.CIServerUrl = os.Getenv("GITHUB_SERVER_URL")
Expand All @@ -134,6 +142,7 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {
return nil
}

// Subjects returns a map of subjects and their corresponding digest sets.
func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet {
subjects := make(map[string]cryptoutil.DigestSet)
hashes := []crypto.Hash{crypto.SHA256}
Expand All @@ -152,6 +161,7 @@ func (a *Attestor) Subjects() map[string]cryptoutil.DigestSet {
return subjects
}

// BackRefs returns a map of back references and their corresponding digest sets.
func (a *Attestor) BackRefs() map[string]cryptoutil.DigestSet {
backRefs := make(map[string]cryptoutil.DigestSet)
for subj, ds := range a.Subjects() {
Expand All @@ -164,13 +174,14 @@ func (a *Attestor) BackRefs() map[string]cryptoutil.DigestSet {
return backRefs
}

// fetchToken fetches the token from the given URL.
func fetchToken(tokenURL string, bearer string, audience string) (string, error) {
client := &http.Client{}

//add audient "&audience=witness" to the end of the tokenURL, parse it, and then add it to the query
// add audience "&audience=witness" to the end of the tokenURL, parse it, and then add it to the query
u, err := url.Parse(tokenURL)
if err != nil {
return "", err
return "", fmt.Errorf("error on parsing token url %w", err)
}

q := u.Query()
Expand All @@ -181,33 +192,35 @@ func fetchToken(tokenURL string, bearer string, audience string) (string, error)

req, err := http.NewRequest("GET", reqURL, nil)
if err != nil {
return "", err
return "", fmt.Errorf("error on creating request %w", err)
}
req.Header.Add("Authorization", "bearer "+bearer)
resp, err := client.Do(req)
if err != nil {
return "", err
return "", fmt.Errorf("error on request %w", err)
}
defer resp.Body.Close()
body, err := readResponseBody(resp.Body)
if err != nil {
return "", err
return "", fmt.Errorf("error on reading response body %w", err)
}

var tokenResponse GithubTokenResponse
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return "", err
return "", fmt.Errorf("error on unmarshaling token response %w", err)
}

return tokenResponse.Value, nil
}

// GithubTokenResponse is a struct that holds the response from the github token request.
type GithubTokenResponse struct {
Count int `json:"count"`
Value string `json:"value"`
}

// readResponseBody reads the response body and returns it as a byte slice.
func readResponseBody(body io.Reader) ([]byte, error) {
var buf bytes.Buffer
_, err := buf.ReadFrom(body)
Expand Down
122 changes: 122 additions & 0 deletions attestation/github/github_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2021 The Witness Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package github

import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func createMockServer() *httptest.Server {
type Response struct {
Count int `json:"count"`
Value string `json:"value"`
}
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/valid" && r.Header.Get("Authorization") == "bearer validBearer" {
resp, _ := json.Marshal(Response{Count: 1, Value: "validJWTToken"})
_, _ = w.Write(resp)
} else {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
}
}))
}

func TestFetchToken(t *testing.T) {
testCases := []struct {
name string
tokenURL string
bearer string
audience string
wantToken string
wantErr bool
}{
{
name: "valid token",
tokenURL: "/valid",
bearer: "validBearer",
audience: "validAudience",
wantToken: "validJWTToken",
wantErr: false,
},
{
name: "invalid token url",
tokenURL: "/invalid",
bearer: "validBearer",
audience: "validAudience",
wantToken: "",
wantErr: true,
},
{
name: "invalid bearer",
tokenURL: "/valid",
bearer: "invalidBearer",
audience: "validAudience",
wantToken: "",
wantErr: true,
},
{
name: "invalid url",
tokenURL: "invalidURL",
bearer: "validBearer",
audience: "validAudience",
wantToken: "",
wantErr: true,
},
}

server := createMockServer()
defer server.Close()

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
gotToken, err := fetchToken(server.URL+testCase.tokenURL, testCase.bearer, testCase.audience)
if testCase.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, testCase.wantToken, gotToken)
}
})
}
}

func TestSubjects(t *testing.T) {
tokenServer := createMockServer()
defer tokenServer.Close()
attestor := &Attestor{
aud: "projecturl",
jwksURL: tokenServer.URL,
tokenURL: os.Getenv("ACTIONS_ID_TOKEN_REQUEST_URL"),
}

subjects := attestor.Subjects()
assert.NotNil(t, subjects)
assert.Equal(t, 2, len(subjects))

expectedSubjects := []string{"pipelineurl:" + attestor.PipelineUrl, "projecturl:" + attestor.ProjectUrl}
for _, expectedSubject := range expectedSubjects {
_, ok := subjects[expectedSubject]
assert.True(t, ok, "Expected subject not found: %s", expectedSubject)
}
m := attestor.BackRefs()
assert.NotNil(t, m)
assert.Equal(t, 1, len(m))
}
8 changes: 4 additions & 4 deletions attestation/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ func (a *Attestor) Attest(ctx *attestation.AttestationContext) error {

parsed, err := jwt.ParseSigned(a.token)
if err != nil {
return err
return fmt.Errorf("error parsing token: %w", err)
}

resp, err := http.Get(a.jwksUrl)
if err != nil {
return err
return fmt.Errorf("error fetching jwks: %w", err)
}

defer resp.Body.Close()
jwks := jose.JSONWebKeySet{}
decoder := json.NewDecoder(resp.Body)
if err := decoder.Decode(&jwks); err != nil {
return err
return fmt.Errorf("error decoding jwks: %w", err)
}

if err := parsed.Claims(jwks, &a.Claims); err != nil {
return err
return fmt.Errorf("error parsing claims: %w", err)
}

keyID := ""
Expand Down

0 comments on commit 07128d2

Please sign in to comment.