diff --git a/main.go b/main.go index e351c06..660b08c 100644 --- a/main.go +++ b/main.go @@ -4,10 +4,10 @@ import ( "flag" "fmt" "io" - "io/ioutil" "net/http" "os" "path" + "path/filepath" "time" "github.com/pin/tftp" @@ -18,9 +18,10 @@ var dir string // readHandler is called when client starts file download from server func readHandler(filename string, rf io.ReaderFrom) error { + file_path := filepath.Clean(path.Join(dir, filename)) - if _, err := os.Stat(path.Join(dir, filename)); err == nil { - file, err := os.Open(path.Join(dir, filename)) + if _, err := os.Stat(file_path); err == nil { + file, err := os.Open(file_path) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return err @@ -50,8 +51,8 @@ func readHandler(filename string, rf io.ReaderFrom) error { defer resp.Body.Close() if resp.StatusCode != 200 { - io.Copy(ioutil.Discard, resp.Body) - return fmt.Errorf("Received status code: %d", resp.StatusCode) + io.Copy(io.Discard, resp.Body) + return fmt.Errorf("received status code: %d", resp.StatusCode) } rf.(tftp.OutgoingTransfer).SetSize(resp.ContentLength)