diff --git a/cmd/agent/main.go b/cmd/agent/main.go index ceab078..ce19d68 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -4,7 +4,6 @@ import ( "fmt" "os" "os/signal" - "sync" "syscall" "time" @@ -40,13 +39,14 @@ func main() { Data: &dataExchange, Ticker: time.NewTicker(time.Duration(conf.ReportInterval) * time.Second), Done: done, - SendAdapter: &adapters.HTTPReportAdapter{ - Logger: logger, - ServerAddr: conf.Addr, - UpdateURL: conf.UpdateURL, - Retrier: retrier, - Lock: &sync.Mutex{}, - }, + SendAdapter: adapters.NewHTTPReportAdapter( + logger, + conf.Addr, + conf.UpdateURL, + retrier, + []byte(conf.HMACKey), + conf.RateLimit, + ), } poller := poll.Poller{ Logger: logger, diff --git a/cmd/server/main.go b/cmd/server/main.go index cdb6f6e..d8e68ef 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,6 +10,7 @@ import ( "syscall" "time" + "github.com/matthiasBT/monitoring/internal/infra/hashcheck" "github.com/matthiasBT/monitoring/internal/infra/utils" "github.com/matthiasBT/monitoring/internal/server/entities" @@ -21,11 +22,13 @@ import ( "github.com/matthiasBT/monitoring/internal/server/usecases" ) -func setupServer(logger logging.ILogger, controller *usecases.BaseController) *chi.Mux { +func setupServer(logger logging.ILogger, controller *usecases.BaseController, hmacKey string) *chi.Mux { r := chi.NewRouter() r.Use(logging.Middleware(logger)) - r.Use(compression.MiddlewareReader) - r.Use(compression.MiddlewareWriter) + r.Use(compression.MiddlewareReader, compression.MiddlewareWriter) + if hmacKey != "" { + r.Use(hashcheck.MiddlewareReader(hmacKey), hashcheck.MiddlewareWriter(hmacKey)) + } r.Mount("/", controller.Route()) return r } @@ -100,7 +103,7 @@ func main() { } controller := usecases.NewBaseController(logger, storage, conf.TemplatePath) - r := setupServer(logger, controller) + r := setupServer(logger, controller, conf.HMACKey) srv := http.Server{Addr: conf.Addr, Handler: r} go func() { logger.Infof("Launching the server at %s\n", conf.Addr) diff --git a/go.mod b/go.mod index 5b8b016..bf03e7a 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa github.com/jackc/pgx/v5 v5.4.3 github.com/pressly/goose/v3 v3.15.0 + github.com/shirou/gopsutil/v3 v3.23.8 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 @@ -15,11 +16,17 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-ole/go-ole v1.2.6 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/kr/text v0.2.0 // indirect + github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect + github.com/tklauser/go-sysconf v0.3.12 // indirect + github.com/tklauser/numcpus v0.6.1 // indirect + github.com/yusufpapurcu/wmi v1.2.3 // indirect golang.org/x/crypto v0.10.0 // indirect golang.org/x/sys v0.11.0 // indirect golang.org/x/text v0.11.0 // indirect diff --git a/go.sum b/go.sum index 03ea3a5..31e3e30 100644 --- a/go.sum +++ b/go.sum @@ -7,6 +7,11 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/go-chi/chi/v5 v5.0.10 h1:rLz5avzKpjqxrYwXNfmjkrYYXOyLJd37pz53UFHC6vk= github.com/go-chi/chi/v5 v5.0.10/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= @@ -20,32 +25,54 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4= +github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= +github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/pressly/goose/v3 v3.15.0 h1:6tY5aDqFknY6VZkorFGgZtWygodZQxfmmEF4rqyJW9k= github.com/pressly/goose/v3 v3.15.0/go.mod h1:LlIo3zGccjb/YUgG+Svdb9Er14vefRdlDI7URCDrwYo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/shirou/gopsutil/v3 v3.23.8 h1:xnATPiybo6GgdRoC4YoGnxXZFRc3dqQTGi73oLvvBrE= +github.com/shirou/gopsutil/v3 v3.23.8/go.mod h1:7hmCaBn+2ZwaZOr6jmPBZDfawwMGuo1id3C6aM8EDqQ= +github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= +github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= +github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= +github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= +github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= +github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 h1:m64FZMko/V45gv0bNmrNYoDEq8U5YUhetc9cBWKS1TQ= golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63/go.mod h1:0v4NqG35kSWCMzLaMeX+IQrlSnVE/bqGSyC2cz/9Le8= golang.org/x/mod v0.12.0 h1:rmsUpXtvNzj340zd98LZ4KntptpfRHwpFOHG188oHXc= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/tools v0.12.1-0.20230815132531-74c255bcf846 h1:Vve/L0v7CXXuxUmaMGIEK/dEeq7uiqb5qBgQrZzIE7E= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/agent/adapters/http.go b/internal/agent/adapters/http.go index 17248dc..ac8b1a8 100644 --- a/internal/agent/adapters/http.go +++ b/internal/agent/adapters/http.go @@ -3,11 +3,14 @@ package adapters import ( "bytes" "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" + "errors" "io" "net/http" "net/url" - "sync" common "github.com/matthiasBT/monitoring/internal/infra/entities" "github.com/matthiasBT/monitoring/internal/infra/logging" @@ -19,59 +22,86 @@ type HTTPReportAdapter struct { ServerAddr string UpdateURL string Retrier utils.Retrier - Lock *sync.Mutex + HMACKey []byte + jobs chan []byte +} + +var ErrResponseNotOK = errors.New("response not OK") + +func NewHTTPReportAdapter( + logger logging.ILogger, + serverAddr string, + updateURL string, + retrier utils.Retrier, + hmacKey []byte, + workerNum uint, +) *HTTPReportAdapter { + jobs := make(chan []byte, workerNum) + adapter := HTTPReportAdapter{ + Logger: logger, + ServerAddr: serverAddr, + UpdateURL: updateURL, + Retrier: retrier, + HMACKey: hmacKey, + jobs: jobs, + } + var i uint + for i = 0; i < workerNum; i++ { + go func() { + for { + data := <-jobs + adapter.report(&data) + } + }() + } + return &adapter } func (r *HTTPReportAdapter) Report(metrics *common.Metrics) error { - r.Lock.Lock() - defer r.Lock.Unlock() payload, err := json.Marshal(metrics) if err != nil { r.Logger.Errorf("Failed to marshal a metric: %v", metrics) return err } - - u := url.URL{Scheme: "http", Host: r.ServerAddr, Path: r.UpdateURL} - f := func() (any, error) { - resp, err := http.Post(u.String(), "application/json", bytes.NewReader(payload)) - if err != nil { - r.Logger.Errorf("Request failed: %v\n", err.Error()) - return nil, err - } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - return body, nil - } - bodyAny, err := r.Retrier.RetryChecked(context.Background(), f, utils.CheckConnectionError) - if err != nil { - return err - } - body := bodyAny.([]byte) - r.Logger.Infof("Success. Server response: %v", string(body)) + r.jobs <- payload return nil } func (r *HTTPReportAdapter) ReportBatch(batch []*common.Metrics) error { - r.Lock.Lock() - defer r.Lock.Unlock() payload, err := json.Marshal(batch) if err != nil { r.Logger.Errorf("Failed to marshal a batch of metrics: %v\n", err.Error()) return err } + r.jobs <- payload + return nil +} +func (r *HTTPReportAdapter) report(payload *[]byte) error { + var ( + req *http.Request + err error + ) u := url.URL{Scheme: "http", Host: r.ServerAddr, Path: r.UpdateURL} + if req, err = r.createRequest(u, payload); err != nil { + return err + } + if err := r.addHMACHeader(req, payload); err != nil { + return err + } + f := func() (any, error) { - resp, err := http.Post(u.String(), "application/json", bytes.NewReader(payload)) + client := &http.Client{} + resp, err := client.Do(req) if err != nil { r.Logger.Errorf("Request failed: %v\n", err.Error()) return nil, err } + if resp.StatusCode != http.StatusOK { + r.Logger.Errorf("Request failed with code: %d\n", resp.StatusCode) + return nil, ErrResponseNotOK + } defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) if err != nil { return nil, err @@ -86,3 +116,37 @@ func (r *HTTPReportAdapter) ReportBatch(batch []*common.Metrics) error { r.Logger.Infof("Success. Server response: %v", string(body)) return nil } + +func (r *HTTPReportAdapter) createRequest(path url.URL, payload *[]byte) (*http.Request, error) { + req, err := http.NewRequest("POST", path.String(), bytes.NewReader(*payload)) + if err != nil { + r.Logger.Errorf("Failed to create a request: %v\n", err.Error()) + return nil, err + } + req.Header.Add("Content-Type", "application/json") + return req, nil +} + +func (r *HTTPReportAdapter) addHMACHeader(req *http.Request, payload *[]byte) error { + if hash, err := r.hashData(payload); err != nil { + return err + } else if hash != "" { + req.Header.Add("HashSHA256", hash) + } + return nil +} + +func (r *HTTPReportAdapter) hashData(payload *[]byte) (string, error) { + if bytes.Equal(r.HMACKey, []byte{}) { + return "", nil + } + mac := hmac.New(sha256.New, r.HMACKey) + if _, err := mac.Write(*payload); err != nil { + r.Logger.Errorf("Failed to calculate hash: %v", err.Error()) + return "", err + } + hash := mac.Sum(nil) + result := hex.EncodeToString(hash) + r.Logger.Infof("HMAC-SHA256 hash: %s\n", result) + return result, nil +} diff --git a/internal/agent/usecases/poll/poll.go b/internal/agent/usecases/poll/poll.go index bb866df..66c1d3f 100644 --- a/internal/agent/usecases/poll/poll.go +++ b/internal/agent/usecases/poll/poll.go @@ -1,17 +1,17 @@ package poll import ( + "fmt" "math/rand" "runtime" "time" "github.com/matthiasBT/monitoring/internal/agent/entities" "github.com/matthiasBT/monitoring/internal/infra/logging" + "github.com/shirou/gopsutil/v3/cpu" + "github.com/shirou/gopsutil/v3/mem" ) -type PollerInfra struct { -} - type Poller struct { Logger logging.ILogger PollCount int64 @@ -72,6 +72,22 @@ func (p *Poller) currentSnapshot() { "PollCount": p.PollCount, }, } + if memstat, err := mem.VirtualMemory(); err != nil { + p.Logger.Errorf("Failed to get memory statistics: %v\n", err.Error()) + return + } else { + snapshot.Gauges["TotalMemory"] = float64(memstat.Total) + snapshot.Gauges["FreeMemory"] = float64(memstat.Free) + } + if cpuUtilStat, err := cpu.Percent(0, true); err != nil { + p.Logger.Errorf("Failed to get CPU statistics: %v\n", err.Error()) + return + } else { + for idx, utilStat := range cpuUtilStat { + name := fmt.Sprintf("CPUutilization%d", idx+1) + snapshot.Gauges[name] = utilStat + } + } p.Data.CurrSnapshot = snapshot p.Logger.Infoln("Created another metrics snapshot") } diff --git a/internal/agent/usecases/poll/poll_test.go b/internal/agent/usecases/poll/poll_test.go index 2b3d903..9ea279a 100644 --- a/internal/agent/usecases/poll/poll_test.go +++ b/internal/agent/usecases/poll/poll_test.go @@ -1,11 +1,13 @@ package poll import ( + "fmt" "sort" "testing" "github.com/matthiasBT/monitoring/internal/agent/entities" "github.com/matthiasBT/monitoring/internal/infra/logging" + "github.com/shirou/gopsutil/v3/cpu" "github.com/stretchr/testify/assert" ) @@ -25,10 +27,10 @@ func TestCollect(t *testing.T) { gauges = append(gauges, key) } sort.Strings(gauges) - expectedGauges := []string{ "Alloc", "BuckHashSys", + "FreeMemory", "Frees", "GCCPUFraction", "GCSys", @@ -55,6 +57,16 @@ func TestCollect(t *testing.T) { "StackSys", "Sys", "TotalAlloc", + "TotalMemory", + } + if cpuCount, err := cpu.Counts(true); err != nil { + t.Fatalf("Failed to get the number of CPUs: %v", err) + } else { + for i := 1; i <= cpuCount; i++ { + name := fmt.Sprintf("CPUutilization%d", i) + expectedGauges = append(expectedGauges, name) + } + sort.Strings(expectedGauges) } assert.EqualValues(t, expectedGauges, gauges) } diff --git a/internal/infra/config/agent/agent.go b/internal/infra/config/agent/agent.go index 90c9bd5..280adf2 100644 --- a/internal/infra/config/agent/agent.go +++ b/internal/infra/config/agent/agent.go @@ -20,8 +20,10 @@ const ( type Config struct { Addr string `env:"ADDRESS"` UpdateURL string - ReportInterval uint `env:"REPORT_INTERVAL"` - PollInterval uint `env:"POLL_INTERVAL"` + ReportInterval uint `env:"REPORT_INTERVAL"` + PollInterval uint `env:"POLL_INTERVAL"` + HMACKey string `env:"KEY"` + RateLimit uint `env:"RATE_LIMIT"` RetryAttempts int RetryIntervalInitial time.Duration RetryIntervalBackoff time.Duration @@ -38,6 +40,8 @@ func InitConfig() (*Config, error) { "r", DefReportInterval, "How often to send metrics to the server, seconds", ) pollInterval := flag.Uint("p", DefPollInterval, "How often to query metrics, seconds") + hmacKey := flag.String("k", "", "HMAC key for integrity checks") + rateLimit := flag.Uint("l", 1, "Max number of active workers") flag.Parse() if conf.Addr == "" { conf.Addr = *addr @@ -48,6 +52,12 @@ func InitConfig() (*Config, error) { if conf.PollInterval == 0 { conf.PollInterval = *pollInterval } + if conf.HMACKey == "" { + conf.HMACKey = *hmacKey + } + if conf.RateLimit == 0 { + conf.RateLimit = *rateLimit + } conf.UpdateURL = updateURL conf.RetryAttempts = DefRetryAttempts conf.RetryIntervalInitial = DefRetryIntervalInitial diff --git a/internal/infra/config/server/server.go b/internal/infra/config/server/server.go index 23db9c6..d6e3b9a 100644 --- a/internal/infra/config/server/server.go +++ b/internal/infra/config/server/server.go @@ -25,6 +25,7 @@ type Config struct { FileStoragePath string `env:"FILE_STORAGE_PATH"` Restore *bool `env:"RESTORE"` DatabaseDSN string `env:"DATABASE_DSN"` + HMACKey string `env:"KEY"` RetryAttempts int RetryIntervalInitial time.Duration RetryIntervalBackoff time.Duration @@ -52,7 +53,7 @@ func InitConfig() (*Config, error) { flagRestore := flag.Bool("r", DefRestore, "Restore init state from the file (see -f flag)") flagStoreInterval := flag.Uint("i", DefStoreInterval, "How often to store data in the file") - + hmacKey := flag.String("k", "", "HMAC key for integrity checks") flag.Parse() if conf.Addr == "" { @@ -70,6 +71,9 @@ func InitConfig() (*Config, error) { if conf.StoreInterval == nil { conf.StoreInterval = flagStoreInterval } + if conf.HMACKey == "" { + conf.HMACKey = *hmacKey + } return conf, nil } diff --git a/internal/infra/hashcheck/middleware.go b/internal/infra/hashcheck/middleware.go new file mode 100644 index 0000000..5f4d77b --- /dev/null +++ b/internal/infra/hashcheck/middleware.go @@ -0,0 +1,92 @@ +package hashcheck + +import ( + "bytes" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "io" + "net/http" +) + +type responseMetadata struct { + data []byte +} + +type extendedWriter struct { + http.ResponseWriter + response *responseMetadata + hmacKey string +} + +func (w *extendedWriter) Write(b []byte) (int, error) { + w.response.data = append(w.response.data, b...) + serverHash, err := hashData([]byte(w.hmacKey), &w.response.data) // "{"id":"SD11","type":"counter","delta":1}" + if err != nil { + w.Write([]byte(err.Error())) + w.WriteHeader(http.StatusInternalServerError) + return len(err.Error()), err + } + w.Header().Set("HashSHA256", serverHash) + size, err := w.ResponseWriter.Write(b) + return size, err +} + +func MiddlewareReader(key string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + checkHashFn := func(w http.ResponseWriter, r *http.Request) { + var clientHash string + if clientHash = r.Header.Get("HashSHA256"); clientHash == "" { + next.ServeHTTP(w, r) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + serverHash, err := hashData([]byte(key), &body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if clientHash != serverHash { + w.WriteHeader(http.StatusBadRequest) + return + } + + r.Body = io.NopCloser(bytes.NewBuffer(body)) + next.ServeHTTP(w, r) + } + return http.HandlerFunc(checkHashFn) + } +} + +func MiddlewareWriter(key string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + addHashFn := func(w http.ResponseWriter, r *http.Request) { + extWriter := &extendedWriter{ + ResponseWriter: w, + response: &responseMetadata{ + data: []byte{}, + }, + hmacKey: key, + } + next.ServeHTTP(extWriter, r) + } + return http.HandlerFunc(addHashFn) + } +} + +func hashData(key []byte, payload *[]byte) (string, error) { + mac := hmac.New(sha256.New, key) + if _, err := mac.Write(*payload); err != nil { + return "", err + } + hash := mac.Sum(nil) + result := hex.EncodeToString(hash) + return result, nil +}