diff --git a/go/extractor/project/BUILD.bazel b/go/extractor/project/BUILD.bazel index b4401187ce2d5..9170ff95be357 100644 --- a/go/extractor/project/BUILD.bazel +++ b/go/extractor/project/BUILD.bazel @@ -19,5 +19,8 @@ go_test( name = "project_test", srcs = ["project_test.go"], embed = [":project"], - deps = ["//go/extractor/vendor/golang.org/x/mod/modfile"], + deps = [ + "//go/extractor/util", + "//go/extractor/vendor/golang.org/x/mod/modfile", + ], ) diff --git a/go/extractor/project/project_test.go b/go/extractor/project/project_test.go index b7485960b5fd6..149a9723ec29c 100644 --- a/go/extractor/project/project_test.go +++ b/go/extractor/project/project_test.go @@ -4,6 +4,7 @@ import ( "path/filepath" "testing" + "github.com/github/codeql-go/extractor/util" "golang.org/x/mod/modfile" ) @@ -28,14 +29,18 @@ func TestStartsWithAnyOf(t *testing.T) { testStartsWithAnyOf(t, filepath.Join("foo", "bar"), filepath.Join("foo", "baz"), false) } -func testHasInvalidToolchainVersion(t *testing.T, contents string) bool { - modFile, err := modfile.Parse("test.go", []byte(contents), nil) +func parseModFile(t *testing.T, contents string) *modfile.File { + modFile, err := modfile.Parse("go.mod", []byte(contents), nil) if err != nil { t.Errorf("Unable to parse %s: %s.\n", contents, err.Error()) } - return hasInvalidToolchainVersion(modFile) + return modFile +} + +func testHasInvalidToolchainVersion(t *testing.T, contents string) bool { + return hasInvalidToolchainVersion(parseModFile(t, contents)) } func TestHasInvalidToolchainVersion(t *testing.T) { @@ -62,3 +67,74 @@ func TestHasInvalidToolchainVersion(t *testing.T) { } } } + +func parseWorkFile(t *testing.T, contents string) *modfile.WorkFile { + workFile, err := modfile.ParseWork("go.work", []byte(contents), nil) + + if err != nil { + t.Errorf("Unable to parse %s: %s.\n", contents, err.Error()) + } + + return workFile +} + +func TestRequiredGoVersion(t *testing.T) { + type ModVersionPair struct { + FileContents string + ExpectedVersion string + } + + modules := []ModVersionPair{ + {"go 1.20", "v1.20"}, + {"go 1.21.2", "v1.21.2"}, + {"go 1.21rc1", "v1.21.0-rc1"}, + {"go 1.21rc1\ntoolchain go1.22.0", "v1.22.0"}, + {"go 1.21rc1\ntoolchain go1.22rc1", "v1.22.0-rc1"}, + } + + for _, testData := range modules { + // `go.mod` and `go.work` files have mostly the same format + modFile := parseModFile(t, testData.FileContents) + workFile := parseWorkFile(t, testData.FileContents) + mod := GoModule{ + Path: "test", // irrelevant + Module: modFile, + } + work := GoWorkspace{ + WorkspaceFile: workFile, + } + + result := mod.RequiredGoVersion() + if result == nil { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.mod` file, but got nothing:\n%s", + testData.ExpectedVersion, + testData.FileContents, + ) + } else if result != util.NewSemVer(testData.ExpectedVersion) { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.mod` file, but got %s:\n%s", + testData.ExpectedVersion, + result, + testData.FileContents, + ) + } + + result = work.RequiredGoVersion() + if result == nil { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.work` file, but got nothing:\n%s", + testData.ExpectedVersion, + testData.FileContents, + ) + } else if result != util.NewSemVer(testData.ExpectedVersion) { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.work` file, but got %s:\n%s", + testData.ExpectedVersion, + result, + testData.FileContents, + ) + } + } + +}