diff --git a/pkg/varchiver/tar.go b/pkg/varchiver/tar.go index dfe3f802..42af89f9 100644 --- a/pkg/varchiver/tar.go +++ b/pkg/varchiver/tar.go @@ -2,16 +2,82 @@ package varchiver import ( "archive/tar" + "archive/zip" "compress/gzip" + "errors" "fmt" "io" "os" "path" + "path/filepath" ) +var ( + ErrMustBeLocal = errors.New("security: src must be a local file") +) + +func Unzip(src string, dest string) error { + if !filepath.IsLocal(src) || !filepath.IsLocal(dest) { + return ErrMustBeLocal + } + + reader, err := zip.OpenReader(src) + if err != nil { + return err + } + + for _, header := range reader.File { + p := path.Join(dest, header.Name) + + if !filepath.IsLocal(header.Name) { + return ErrMustBeLocal + } + + if header.FileInfo().IsDir() { + err = os.MkdirAll(p, os.ModePerm) + if err != nil { + return err + } + } else { + err = os.MkdirAll(path.Dir(p), os.ModePerm) + if err != nil { + return err + } + + file, err := os.Create(p) + if err != nil { + return err + } + + content, err := header.Open() + if err != nil { + return err + } + + _, err = io.Copy(file, content) + if err != nil { + return err + } + + err = os.Chmod(p, 0755) + if err != nil { + return err + } + + file.Close() + } + } + + return nil +} + // Untar a tarball to a destination. src is the path to // the tarball, and dest is the path to the destination directory. func Untar(src string, dest string) error { + if !filepath.IsLocal(src) || !filepath.IsLocal(dest) { + return ErrMustBeLocal + } + archive, err := os.Open(src) if err != nil { return err @@ -36,21 +102,25 @@ func Untar(src string, dest string) error { return err } - filepath := path.Join(dest, header.Name) + p := path.Join(dest, header.Name) + + if !filepath.IsLocal(header.Name) { + return ErrMustBeLocal + } switch header.Typeflag { case tar.TypeDir: - err = os.MkdirAll(filepath, os.ModePerm) + err = os.MkdirAll(p, os.ModePerm) if err != nil { return err } case tar.TypeReg: - err := os.MkdirAll(path.Dir(filepath), os.ModePerm) + err := os.MkdirAll(path.Dir(p), os.ModePerm) if err != nil { return err } - file, err := os.Create(filepath) + file, err := os.Create(p) if err != nil { return err } @@ -60,7 +130,7 @@ func Untar(src string, dest string) error { return err } - err = os.Chmod(filepath, 0755) + err = os.Chmod(p, 0755) if err != nil { return err } diff --git a/services/dependencies.go b/services/dependencies.go index 583d00cf..66d9cef0 100644 --- a/services/dependencies.go +++ b/services/dependencies.go @@ -1,7 +1,6 @@ package services import ( - "archive/zip" "context" "errors" "fmt" @@ -17,6 +16,7 @@ import ( "github.com/vertex-center/vertex/config" "github.com/vertex-center/vertex/pkg/log" "github.com/vertex-center/vertex/pkg/storage" + "github.com/vertex-center/vertex/pkg/varchiver" "github.com/vertex-center/vertex/pkg/vdocker" "github.com/vertex-center/vertex/types" "github.com/vertex-center/vlog" @@ -245,12 +245,14 @@ func (d *clientUpdater) InstallUpdate() error { return err } - err = download(dir, *asset.BrowserDownloadURL) + tempZipPath := path.Join(dir, "temp.zip") + + err = download(tempZipPath, *asset.BrowserDownloadURL) if err != nil { return err } - err = unarchive(dir) + err = varchiver.Unzip(tempZipPath, dir) if err != nil { return err } @@ -293,7 +295,7 @@ func download(dir string, url string) error { } defer res.Body.Close() - file, err := os.Create(path.Join(dir, "temp.zip")) + file, err := os.Create(dir) if err != nil { return err } @@ -303,53 +305,6 @@ func download(dir string, url string) error { return err } -func unarchive(dir string) error { - reader, err := zip.OpenReader(path.Join(dir, "temp.zip")) - if err != nil { - return err - } - - for _, header := range reader.File { - filepath := path.Join(dir, header.Name) - - if header.FileInfo().IsDir() { - err = os.MkdirAll(filepath, os.ModePerm) - if err != nil { - return err - } - } else { - err = os.MkdirAll(path.Dir(filepath), os.ModePerm) - if err != nil { - return err - } - - file, err := os.Create(filepath) - if err != nil { - return err - } - - content, err := header.Open() - if err != nil { - return err - } - - _, err = io.Copy(file, content) - if err != nil { - return err - } - - err = os.Chmod(filepath, 0755) - if err != nil { - return err - } - - file.Close() - } - } - - return nil -} - type gitHubUpdater struct { dir string name string