Skip to content

Commit

Permalink
Merge pull request #3 from fkautz/protect-against-snooping
Browse files Browse the repository at this point in the history
Prevent snooping around filesystem outside of bounds passed in Add().
  • Loading branch information
fkautz authored May 28, 2024
2 parents b02b506 + 8e03827 commit ee1b305
Show file tree
Hide file tree
Showing 20 changed files with 383 additions and 20 deletions.
1 change: 1 addition & 0 deletions definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ type Plugin interface {
Store(envelope *Envelope) error
Sha1ADG(map[string]string)
Sha256ADG(map[string]string)
SetAllowList([]string)
}
31 changes: 30 additions & 1 deletion directory_plugin.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package omnitrail

import (
"github.com/omnibor/omnibor-go"
"fmt"
"os"
"path/filepath"
"sort"
"strings"

"github.com/omnibor/omnibor-go"
)

type DirectoryPlugin struct {
Expand All @@ -13,6 +16,16 @@ type DirectoryPlugin struct {
directories map[string]bool
sha1adgs map[string]omnibor.ArtifactTree
sha256adgs map[string]omnibor.ArtifactTree
AllowList []string
}

func (plug *DirectoryPlugin) isAllowedDirectory(path string) bool {
for _, allowedPath := range plug.AllowList {
if strings.HasPrefix(path, allowedPath) {
return true
}
}
return false
}

func (plug *DirectoryPlugin) Sha1ADG(m map[string]string) {
Expand All @@ -28,20 +41,32 @@ func (plug *DirectoryPlugin) Sha256ADG(m map[string]string) {
}

func (plug *DirectoryPlugin) Add(path string) error {

// if this is a broken symlink, ignore
fileInfo, err := os.Lstat(path)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
return err
}
if fileInfo.Mode()&os.ModeSymlink != 0 {
// path is a symlink
targetPath, err := os.Readlink(path)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
return err
}
if !filepath.IsAbs(targetPath) {
targetPath = filepath.Join(filepath.Dir(path), targetPath)
}
if !plug.isAllowedDirectory(targetPath) {
return fmt.Errorf("path %s is not in the allow list", path)
}
if _, err := os.Stat(targetPath); err != nil {
return nil
}
Expand Down Expand Up @@ -168,6 +193,10 @@ func (plug *DirectoryPlugin) addKeysToTree(keys []string, tree map[string]omnibo
return nil
}

func (plug *DirectoryPlugin) SetAllowList(allowList []string) {
plug.AllowList = allowList
}

func NewDirectoryPlugin() Plugin {
algorithms := []string{"gitoid:sha1", "gitoid:sha256"}
sort.Strings(algorithms)
Expand Down
18 changes: 14 additions & 4 deletions factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,27 @@ import (
)

type factoryImpl struct {
Options *Options
envelope *Envelope
Plugins []Plugin
Options *Options
envelope *Envelope
Plugins []Plugin
AllowList []string
}

func (factory *factoryImpl) Add(originalPath string) error {
originalPath, err := filepath.Abs(originalPath)
// Convert the path to an absolute path
absPath, err := filepath.Abs(originalPath)
if err != nil {
return err
}

// Add the absolute path to the allow list
factory.AllowList = append(factory.AllowList, absPath)

// For each plugin, add the allow list
for _, plugin := range factory.Plugins {
plugin.SetAllowList(factory.AllowList)
}

// check if path already exists in the envelope, if so, return
if _, ok := factory.envelope.Mapping[originalPath]; ok {
return nil
Expand Down
28 changes: 27 additions & 1 deletion file_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ import (
type FilePlugin struct {
algorithms []string
files map[string]map[string]string
AllowList []string
}

func (plug *FilePlugin) isAllowedDirectory(path string) bool {
for _, allowedPath := range plug.AllowList {
if strings.HasPrefix(path, allowedPath) {
return true
}
}
return false
}

func (plug *FilePlugin) Sha1ADG(m map[string]string) {
Expand All @@ -38,6 +48,10 @@ func (plug *FilePlugin) Sha256ADG(m map[string]string) {
}
}

func (plug *FilePlugin) SetAllowList(allowList []string) {
plug.AllowList = allowList
}

func NewFilePlugin() Plugin {
algorithms := []string{"sha1", "sha256", "gitoid:sha1", "gitoid:sha256"}
sort.Strings(algorithms)
Expand All @@ -52,20 +66,32 @@ func NewFilePlugin() Plugin {
}

func (plug *FilePlugin) Add(filePath string) error {

// ignore broken symlink
localFileInfo, err := os.Lstat(filePath)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
return err
}
if localFileInfo.Mode()&os.ModeSymlink != 0 {
targetPath, err := os.Readlink(filePath)
if err != nil {
// if it's a symlink and the symlink is bad, ignore and return
if os.IsNotExist(err) {
return nil
}
fmt.Println("returning err: ", err)
return err
}
if !filepath.IsAbs(targetPath) {

targetPath = filepath.Join(filepath.Dir(filePath), targetPath)
}
if !plug.isAllowedDirectory(targetPath) {
return fmt.Errorf("path %s is not in the allow list", filePath)
}
if _, err = os.Stat(targetPath); err != nil {
return nil
}
Expand Down
7 changes: 6 additions & 1 deletion omnitrail.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ func NewTrail(option ...Option) Factory {
if o.Sha1Enabled == false && o.Sha256Enabled == false {
o.Sha1Enabled = true
}
allowList := []string{}
plugins := make([]Plugin, 0)
plugins = append(plugins, NewFilePlugin())
plugins = append(plugins, NewDirectoryPlugin())
plugins = append(plugins, NewPosixPlugin())
return &factoryImpl{

factory := &factoryImpl{
Options: o,
Plugins: plugins,
envelope: &Envelope{
Expand All @@ -26,7 +28,10 @@ func NewTrail(option ...Option) Factory {
},
Mapping: make(map[string]*Element),
},
AllowList: allowList,
}

return factory
}

func FormatADGString(mapping Factory) string {
Expand Down
90 changes: 78 additions & 12 deletions omnitrail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package omnitrail

import (
"encoding/json"
"github.com/stretchr/testify/assert"
"fmt"
"os"
"os/user"
"reflect"
"sort"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestEmpty(t *testing.T) {
Expand All @@ -17,36 +20,88 @@ func TestEmpty(t *testing.T) {
assert.NoError(t, err)
}
name := "empty"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestEmpty failed: %v", err)
}
}

func TestOneFiles(t *testing.T) {
name := "one-file"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestOneFiles failed: %v", err)
}
}

func TestTwoFiles(t *testing.T) {
name := "two-files"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestTwoFiles failed: %v", err)
}
}

func TestDeepStructure(t *testing.T) {
name := "deep"
testAdd(t, name)
if err := testAdd(t, name); err != nil {
t.Fatalf("TestDeepStructure failed: %v", err)
}
}

func TestSymlinkGood(t *testing.T) {
name := "symlink-good"
if err := testAdd(t, name); err != nil {
t.Fatalf("TestSymlinkGood failed: %v", err)
}
}

func TestSymlinkBroken(t *testing.T) {
name := "symlink-broken"
if err := testAdd(t, name); err != nil {
t.Fatalf("should ignore a bad symlink: %v", err)
}
}

func TestSymlinkOutOfBounds(t *testing.T) {
name := "symlink-out-of-bounds"
err := os.WriteFile("/tmp/omnitrail-well-known-file", []byte("hello"), 0644)
if err != nil {
t.Fatalf("unable to write temporary file: %v", err)
}
defer os.Remove("/tmp/omnitrail-well-known-file")
err = testAdd(t, name)
if !strings.Contains(err.Error(), "not in the allow list") {
t.Fatalf("unexpected error: %v", err)

}
if err == nil {
t.Fatalf("TestSymlinkOutOfBounds failed: should report a symlik out of bounds")
}
}

func testAdd(t *testing.T, name string) {
func testAdd(t *testing.T, name string) error {
mapping := NewTrail()

err := mapping.Add("./test/" + name)
assert.NoError(t, err)
if err != nil {
return err
}

// WARNING: these are only for generating new test cases easily
// file, err := json.MarshalIndent(mapping.Envelope(), "", " ")
// os.WriteFile("./test/"+name+".json", file, 0644)
// res := FormatADGString(mapping)
// os.WriteFile("./test/"+name+".adg", []byte(res), 0644)
// END WARNING

expectedBytes, err := os.ReadFile("./test/" + name + ".json")
assert.NoError(t, err)
if err != nil {
return err
}

var expectedEnvelope Envelope
err = json.Unmarshal(expectedBytes, &expectedEnvelope)
assert.NoError(t, err)
if err != nil {
return err
}

shortestExpectedKey := getShortestKey(&expectedEnvelope)
shortestActualKey := getShortestKey(mapping.Envelope())
Expand All @@ -59,7 +114,9 @@ func testAdd(t *testing.T, name string) {

// get current username
currentUser, err := user.Current()
assert.NoError(t, err)
if err != nil {
return err
}
uid := currentUser.Uid
gid := currentUser.Gid

Expand All @@ -70,11 +127,20 @@ func testAdd(t *testing.T, name string) {

assert.Equal(t, &expectedEnvelope, mapping.Envelope())

if !reflect.DeepEqual(&expectedEnvelope, mapping.Envelope()) {
return fmt.Errorf("expected envelope does not match actual envelope")
}

res := FormatADGString(mapping)

expectedBytes, err = os.ReadFile("./test/" + name + ".adg")
assert.NoError(t, err)
assert.Equal(t, string(expectedBytes), res)
if err != nil {
return err
}
if string(expectedBytes) != res {
return fmt.Errorf("expected ADG string does not match actual ADG string")
}
return nil
}

func getShortestKey(expectedEnvelope *Envelope) string {
Expand Down
Loading

0 comments on commit ee1b305

Please sign in to comment.