Skip to content

Commit

Permalink
APP-6559 cuda + jetpack platform tags (viamrobotics#4492)
Browse files Browse the repository at this point in the history
  • Loading branch information
abe-winter authored Oct 29, 2024
1 parent 7a3be1d commit 752c9e8
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 65 deletions.
108 changes: 108 additions & 0 deletions config/platform.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package config

import (
"bufio"
"context"
"os"
"os/exec"
"regexp"
"runtime"
"strings"
"time"

"go.viam.com/rdk/logging"
)

// Package-level patterns and cache used by the platform-tag helpers in this file.
var (
// cudaRegex captures the digits before the first '.' of the release number in
// `nvcc --version` output, e.g. "11" from "Cuda compilation tools, release 11.5".
cudaRegex = regexp.MustCompile(`Cuda compilation tools, release (\d+)\.`)
// aptCacheVersionRegex captures the leading digits of the "Version:" field in
// `apt-cache show` output, e.g. "5" from "Version: 5.1.1-b56". The pattern begins
// with a newline, so a "Version:" field on the very first line will not match.
aptCacheVersionRegex = regexp.MustCompile(`\nVersion: (\d+)\D`)
// savedPlatformTags holds the tag list stored by readExtendedPlatformTags when it
// is called with cache=true; it is nil until that first happens.
savedPlatformTags []string
)

// readGPUTags appends platform tags describing GPU-related system libraries to
// tags and returns the extended slice.
//
// Two probes are run, each only when its binary is found on PATH:
//   - `nvcc --version`: on success appends "cuda:true" and "cuda_version:<major>".
//   - `apt-cache show nvidia-jetpack`: on success appends "jetpack:<major>".
//
// Probe failures never abort the function: nvcc failures are logged through
// logger, apt-cache failures are ignored.
func readGPUTags(logger logging.Logger, tags []string) []string {
	// this timeout is for all steps in this function.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()
	if _, err := exec.LookPath("nvcc"); err == nil {
		out, err := exec.CommandContext(ctx, "nvcc", "--version").Output()
		if err != nil {
			// Report the failure once. The output of a failed command is not
			// trustworthy, and parsing it would only log a second error for
			// the same root cause.
			logger.Errorw("error getting Cuda version from nvcc. Cuda-specific modules may not load", "err", err)
		} else if match := cudaRegex.FindSubmatch(out); match != nil {
			tags = append(tags, "cuda:true", "cuda_version:"+string(match[1]))
		} else {
			logger.Errorw("error parsing `nvcc --version` output. Cuda-specific modules may not load")
		}
	}
	if _, err := exec.LookPath("apt-cache"); err == nil {
		out, err := exec.CommandContext(ctx, "apt-cache", "show", "nvidia-jetpack").Output()
		// note: the error case here will usually mean 'package missing', we don't analyze it.
		if err == nil {
			if match := aptCacheVersionRegex.FindSubmatch(out); match != nil {
				tags = append(tags, "jetpack:"+string(match[1]))
			}
		}
	}
	return tags
}

// parseOsRelease parses the KEY=VALUE lines of an /etc/os-release style file
// read from body and returns them as a map.
//
// Lines that start with '#' or contain no '=' (for example blank lines) are
// skipped. Surrounding double or single quotes are stripped from values. The
// last line is parsed even when the input does not end with a newline. Read
// errors are not reported; whatever was parsed before the error is returned.
func parseOsRelease(body *bufio.Reader) map[string]string {
	ret := make(map[string]string)
	for {
		line, err := body.ReadString('\n')
		// ReadString hands back the data read so far together with the error,
		// so the line must be processed before the error ends the loop;
		// otherwise a final line without a trailing newline is lost.
		if !strings.HasPrefix(line, "#") {
			if key, value, found := strings.Cut(line, "="); found {
				// note: we trim `value` rather than `line` because os_version value is quoted sometimes.
				ret[key] = strings.Trim(value, "\n\"'")
			}
		}
		if err != nil {
			return ret
		}
	}
}

// appendPairIfNonempty returns orig with "key:value" appended, or orig
// unchanged when value is the empty string.
func appendPairIfNonempty(orig []string, key, value string) []string {
	if value == "" {
		return orig
	}
	pair := key + ":" + value
	return append(orig, pair)
}

// readLinuxTags turns selected fields of /etc/os-release into platform tags,
// appends them to tags and returns the result. A missing file is tolerated
// silently; any other failure to open it is logged and tags is returned as-is.
func readLinuxTags(logger logging.Logger, tags []string) []string {
	file, err := os.Open("/etc/os-release")
	if err != nil {
		if !os.IsNotExist(err) {
			logger.Errorw("can't open /etc/os-release, modules may not load correctly", "err", err)
		}
		return tags
	}
	defer file.Close() //nolint:errcheck

	fields := parseOsRelease(bufio.NewReader(file))
	// tag name -> os-release field it is read from; empty fields produce no tag.
	for _, mapping := range [...]struct{ tag, field string }{
		{"distro", "ID"},
		{"os_version", "VERSION_ID"},
		{"codename", "VERSION_CODENAME"},
	} {
		tags = appendPairIfNonempty(tags, mapping.tag, fields[mapping.field])
	}
	return tags
}

// readExtendedPlatformTags collects the granular platform constraints (os
// version, distro, GPU libraries, etc). These refine the basic
// runtime.GOOS/GOARCH information in getAgentInfo so module authors can publish
// builds with ABI or SDK dependencies. The list of tags is expected to grow.
//
// With cache set, a previously saved result is returned when one exists;
// otherwise the freshly computed tags are saved and logged before returning.
func readExtendedPlatformTags(logger logging.Logger, cache bool) []string {
	// TODO(APP-6696): CI in multiple environments (alpine + mac), darwin support.
	if cache && savedPlatformTags != nil {
		return savedPlatformTags
	}
	tags := make([]string, 0, 3)
	if runtime.GOOS == "linux" {
		tags = readGPUTags(logger, readLinuxTags(logger, tags))
	}
	if !cache {
		return tags
	}
	savedPlatformTags = tags
	// note: we only log in the cache condition because it would be annoying to log this in a loop.
	logger.Infow("platform tags", "tags", strings.Join(tags, ","))
	return tags
}
69 changes: 69 additions & 0 deletions config/platform_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package config

import (
"runtime"
"testing"

"go.viam.com/test"

"go.viam.com/rdk/logging"
)

func TestReadExtendedPlatformTags(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skip("skipping platform tags test on non-linux")
}
logger := logging.NewTestLogger(t)
tags := readExtendedPlatformTags(logger, true)
test.That(t, len(tags), test.ShouldBeGreaterThanOrEqualTo, 2)
}

// TestAppendPairIfNonempty checks that a pair with a value is appended and a
// pair with an empty value is dropped.
func TestAppendPairIfNonempty(t *testing.T) {
	pairs := appendPairIfNonempty(make([]string, 0, 1), "x", "y")
	pairs = appendPairIfNonempty(pairs, "a", "")
	test.That(t, pairs, test.ShouldResemble, []string{"x:y"})
}

// TestCudaRegexes runs the package-level version regexes against captured
// command output and checks that only the major version is extracted.
func TestCudaRegexes(t *testing.T) {
	t.Run("cuda", func(t *testing.T) {
		nvccOutput := []byte(`nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Nov_18_09:45:30_PST_2021
Cuda compilation tools, release 11.5, V11.5.119
Build cuda_11.5.r11.5/compiler.30672275_0
`)
		groups := cudaRegex.FindSubmatch(nvccOutput)
		test.That(t, groups, test.ShouldNotBeNil)
		test.That(t, string(groups[1]), test.ShouldResemble, "11")
	})

	t.Run("apt-cache", func(t *testing.T) {
		// `apt-cache show nvidia-jetpack` samples for JetPack 5 and JetPack 6.
		fixtures := []struct {
			aptOutput string
			wantMajor string
		}{
			{
				aptOutput: `Package: nvidia-jetpack
Version: 5.1.1-b56
Architecture: arm64
Maintainer: NVIDIA Corporation
Installed-Size: 194
Depends: nvidia-jetpack-runtime (= 5.1.1-b56), nvidia-jetpack-dev (= 5.1.1-b56)
Homepage: http://developer.nvidia.com/jetson
Priority: standard
Section: metapackages`,
				wantMajor: "5",
			},
			{
				aptOutput: `Package: nvidia-jetpack
Source: nvidia-jetpack (6.1)
Version: 6.1+b123
Architecture: arm64
Maintainer: NVIDIA Corporation
Installed-Size: 194
Depends: nvidia-jetpack-runtime (= 6.1+b123), nvidia-jetpack-dev (= 6.1+b123)
Homepage: http://developer.nvidia.com/jetson
Priority: standard
Section: metapackages`,
				wantMajor: "6",
			},
		}
		for _, fixture := range fixtures {
			groups := aptCacheVersionRegex.FindSubmatch([]byte(fixture.aptOutput))
			test.That(t, groups, test.ShouldNotBeNil)
			test.That(t, string(groups[1]), test.ShouldResemble, fixture.wantMajor)
		}
	})
}
52 changes: 3 additions & 49 deletions config/reader.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package config

import (
"bufio"
"bytes"
"context"
"encoding/json"
Expand All @@ -12,7 +11,6 @@ import (
"os"
"path/filepath"
"runtime"
"strings"
"time"

"github.com/a8m/envsubst"
Expand Down Expand Up @@ -44,51 +42,7 @@ const (
LocalPackagesSuffix = "-local"
)

// parseOsRelease parses the KEY=VALUE lines of an /etc/os-release style file
// read from body and returns them as a map.
//
// Lines that start with '#' or contain no '=' (for example blank lines) are
// skipped. Surrounding double or single quotes are stripped from values. The
// last line is parsed even when the input does not end with a newline. Read
// errors are not reported; whatever was parsed before the error is returned.
func parseOsRelease(body *bufio.Reader) map[string]string {
	ret := make(map[string]string)
	for {
		line, err := body.ReadString('\n')
		// ReadString hands back the data read so far together with the error,
		// so the line must be processed before the error ends the loop;
		// otherwise a final line without a trailing newline is lost.
		if !strings.HasPrefix(line, "#") {
			if key, value, found := strings.Cut(line, "="); found {
				// note: we trim `value` rather than `line` because os_version value is quoted sometimes.
				ret[key] = strings.Trim(value, "\n\"'")
			}
		}
		if err != nil {
			return ret
		}
	}
}

// appendPairIfNonempty returns orig with "key:value" appended, or orig
// unchanged when value is the empty string.
func appendPairIfNonempty(orig []string, key, value string) []string {
	if value == "" {
		return orig
	}
	pair := key + ":" + value
	return append(orig, pair)
}

// readExtendedPlatformTags collects the granular platform constraints (os
// version, distro, etc). These refine the basic runtime.GOOS/GOARCH
// information in getAgentInfo so module authors can publish builds with ABI or
// SDK dependencies. The list of tags is expected to grow.
//
// On non-linux systems, or when /etc/os-release cannot be opened, the returned
// slice is empty.
func readExtendedPlatformTags() []string {
	// TODO(APP-6696): CI in multiple environments (alpine + mac), darwin support.
	tags := make([]string, 0, 3)
	if runtime.GOOS != "linux" {
		return tags
	}
	file, err := os.Open("/etc/os-release")
	if err != nil {
		// a missing file is tolerated silently; anything else is logged.
		if !os.IsNotExist(err) {
			logging.Global().Errorw("can't open /etc/os-release, modules may not load correctly", "err", err)
		}
		return tags
	}
	defer file.Close() //nolint:errcheck

	fields := parseOsRelease(bufio.NewReader(file))
	tags = appendPairIfNonempty(tags, "distro", fields["ID"])
	tags = appendPairIfNonempty(tags, "os_version", fields["VERSION_ID"])
	tags = appendPairIfNonempty(tags, "codename", fields["VERSION_CODENAME"])
	return tags
}

func getAgentInfo() (*apppb.AgentInfo, error) {
func getAgentInfo(logger logging.Logger) (*apppb.AgentInfo, error) {
hostname, err := os.Hostname()
if err != nil {
return nil, err
Expand Down Expand Up @@ -122,7 +76,7 @@ func getAgentInfo() (*apppb.AgentInfo, error) {
Version: Version,
GitRevision: GitRevision,
Platform: &platform,
PlatformTags: readExtendedPlatformTags(),
PlatformTags: readExtendedPlatformTags(logger, true),
}, nil
}

Expand Down Expand Up @@ -701,7 +655,7 @@ func getFromCloudGRPC(ctx context.Context, cloudCfg *Cloud, logger logging.Logge
}
defer utils.UncheckedErrorFunc(conn.Close)

agentInfo, err := getAgentInfo()
agentInfo, err := getAgentInfo(logger)
if err != nil {
return nil, shouldCheckCacheOnFailure, err
}
Expand Down
16 changes: 0 additions & 16 deletions config/reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io/fs"
"os"
"runtime"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -395,18 +394,3 @@ func TestReadTLSFromCache(t *testing.T) {
test.That(t, err, test.ShouldBeNil)
})
}

// TestReadExtendedPlatformTags checks that at least two platform tags are
// produced on linux; it is skipped on every other OS.
func TestReadExtendedPlatformTags(t *testing.T) {
	if runtime.GOOS != "linux" {
		t.Skip("skipping platform tags test on non-linux")
	}
	tagCount := len(readExtendedPlatformTags())
	test.That(t, tagCount, test.ShouldBeGreaterThanOrEqualTo, 2)
}

// TestAppendPairIfNonempty checks that a pair with a value is appended and a
// pair with an empty value is dropped.
func TestAppendPairIfNonempty(t *testing.T) {
	pairs := appendPairIfNonempty(make([]string, 0, 1), "x", "y")
	pairs = appendPairIfNonempty(pairs, "a", "")
	test.That(t, pairs, test.ShouldResemble, []string{"x:y"})
}

0 comments on commit 752c9e8

Please sign in to comment.