diff --git a/go/extractor/project/BUILD.bazel b/go/extractor/project/BUILD.bazel index b4401187ce2d..9170ff95be35 100644 --- a/go/extractor/project/BUILD.bazel +++ b/go/extractor/project/BUILD.bazel @@ -19,5 +19,8 @@ go_test( name = "project_test", srcs = ["project_test.go"], embed = [":project"], - deps = ["//go/extractor/vendor/golang.org/x/mod/modfile"], + deps = [ + "//go/extractor/util", + "//go/extractor/vendor/golang.org/x/mod/modfile", + ], ) diff --git a/go/extractor/project/project.go b/go/extractor/project/project.go index 149dfe274bda..853e871d62bb 100644 --- a/go/extractor/project/project.go +++ b/go/extractor/project/project.go @@ -36,6 +36,17 @@ type GoModule struct { Module *modfile.File // The parsed contents of the `go.mod` file } +// Tries to find the Go toolchain version required for this module. +func (module *GoModule) RequiredGoVersion() util.SemVer { + if module.Module != nil && module.Module.Toolchain != nil { + return util.NewSemVer(module.Module.Toolchain.Name) + } else if module.Module != nil && module.Module.Go != nil { + return util.NewSemVer(module.Module.Go.Version) + } else { + return tryReadGoDirective(module.Path) + } +} + // Represents information about a Go project workspace: this may either be a folder containing // a `go.work` file or a collection of `go.mod` files. type GoWorkspace struct { @@ -54,24 +65,23 @@ type GoVersionInfo = util.SemVer // 1. The Go version specified in the `go.work` file, if any. // 2. The greatest Go version specified in any `go.mod` file, if any. func (workspace *GoWorkspace) RequiredGoVersion() util.SemVer { - if workspace.WorkspaceFile != nil && workspace.WorkspaceFile.Go != nil { - // If we have parsed a `go.work` file, return the version number from it. + // If we have parsed a `go.work` file, we prioritise versions from it over those in individual `go.mod` + // files. We are interested in toolchain versions, so if there is an explicit toolchain declaration in + // a `go.work` file, we use that. Otherwise, we fall back to the language version in the `go.work` file + // and use that as toolchain version. If we didn't parse a `go.work` file, then we try to find the + // greatest version contained in `go.mod` files. + if workspace.WorkspaceFile != nil && workspace.WorkspaceFile.Toolchain != nil { + return util.NewSemVer(workspace.WorkspaceFile.Toolchain.Name) + } else if workspace.WorkspaceFile != nil && workspace.WorkspaceFile.Go != nil { return util.NewSemVer(workspace.WorkspaceFile.Go.Version) } else if workspace.Modules != nil && len(workspace.Modules) > 0 { // Otherwise, if we have `go.work` files, find the greatest Go version in those. var greatestVersion util.SemVer = nil for _, module := range workspace.Modules { - if module.Module != nil && module.Module.Go != nil { - // If we have parsed the file, retrieve the version number we have already obtained. - modVersion := util.NewSemVer(module.Module.Go.Version) - if greatestVersion == nil || modVersion.IsNewerThan(greatestVersion) { - greatestVersion = modVersion - } - } else { - modVersion := tryReadGoDirective(module.Path) - if modVersion != nil && (greatestVersion == nil || modVersion.IsNewerThan(greatestVersion)) { - greatestVersion = modVersion - } + modVersion := module.RequiredGoVersion() + + if modVersion != nil && (greatestVersion == nil || modVersion.IsNewerThan(greatestVersion)) { + greatestVersion = modVersion } } diff --git a/go/extractor/project/project_test.go b/go/extractor/project/project_test.go index b7485960b5fd..149a9723ec29 100644 --- a/go/extractor/project/project_test.go +++ b/go/extractor/project/project_test.go @@ -4,6 +4,7 @@ import ( "path/filepath" "testing" + "github.com/github/codeql-go/extractor/util" "golang.org/x/mod/modfile" ) @@ -28,14 +29,18 @@ func TestStartsWithAnyOf(t *testing.T) { testStartsWithAnyOf(t, filepath.Join("foo", "bar"), filepath.Join("foo", "baz"), false) } -func testHasInvalidToolchainVersion(t *testing.T, contents string) bool { - modFile, err := modfile.Parse("test.go", []byte(contents), nil) +func parseModFile(t *testing.T, contents string) *modfile.File { + modFile, err := modfile.Parse("go.mod", []byte(contents), nil) if err != nil { t.Errorf("Unable to parse %s: %s.\n", contents, err.Error()) } - return hasInvalidToolchainVersion(modFile) + return modFile +} + +func testHasInvalidToolchainVersion(t *testing.T, contents string) bool { + return hasInvalidToolchainVersion(parseModFile(t, contents)) } func TestHasInvalidToolchainVersion(t *testing.T) { @@ -62,3 +67,74 @@ func TestHasInvalidToolchainVersion(t *testing.T) { } } } + +func parseWorkFile(t *testing.T, contents string) *modfile.WorkFile { + workFile, err := modfile.ParseWork("go.work", []byte(contents), nil) + + if err != nil { + t.Errorf("Unable to parse %s: %s.\n", contents, err.Error()) + } + + return workFile +} + +func TestRequiredGoVersion(t *testing.T) { + type ModVersionPair struct { + FileContents string + ExpectedVersion string + } + + modules := []ModVersionPair{ + {"go 1.20", "v1.20"}, + {"go 1.21.2", "v1.21.2"}, + {"go 1.21rc1", "v1.21.0-rc1"}, + {"go 1.21rc1\ntoolchain go1.22.0", "v1.22.0"}, + {"go 1.21rc1\ntoolchain go1.22rc1", "v1.22.0-rc1"}, + } + + for _, testData := range modules { + // `go.mod` and `go.work` files have mostly the same format + modFile := parseModFile(t, testData.FileContents) + workFile := parseWorkFile(t, testData.FileContents) + mod := GoModule{ + Path: "test", // irrelevant + Module: modFile, + } + work := GoWorkspace{ + WorkspaceFile: workFile, + } + + result := mod.RequiredGoVersion() + if result == nil { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.mod` file, but got nothing:\n%s", + testData.ExpectedVersion, + testData.FileContents, + ) + } else if result != util.NewSemVer(testData.ExpectedVersion) { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.mod` file, but got %s:\n%s", + testData.ExpectedVersion, + result, + testData.FileContents, + ) + } + + result = work.RequiredGoVersion() + if result == nil { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.work` file, but got nothing:\n%s", + testData.ExpectedVersion, + testData.FileContents, + ) + } else if result != util.NewSemVer(testData.ExpectedVersion) { + t.Errorf( + "Expected mod.RequiredGoVersion() to return %s for the below `go.work` file, but got %s:\n%s", + testData.ExpectedVersion, + result, + testData.FileContents, + ) + } + } + +}