Skip to content

Commit

Permalink
Merge pull request #25 from matthiasBT/iter23
Browse files Browse the repository at this point in the history
Sprint 9, increment 23
  • Loading branch information
matthiasBT authored Jan 7, 2024
2 parents af81264 + 6cee785 commit 17bc252
Show file tree
Hide file tree
Showing 18 changed files with 711 additions and 196 deletions.
7 changes: 6 additions & 1 deletion cmd/agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ func main() {
done := make(<-chan bool)
dataExchange := entities.SnapshotWrapper{CurrSnapshot: nil}
retrier := setupRetrier(conf, logger)
publicKey, err := conf.ReadServerPublicKey()
if err != nil {
panic(err)
}
reporter := report.Reporter{
Logger: logger,
Data: &dataExchange,
Expand All @@ -62,6 +66,7 @@ func main() {
conf.UpdateURL,
retrier,
[]byte(conf.HMACKey),
publicKey,
conf.RateLimit,
),
}
Expand All @@ -75,7 +80,7 @@ func main() {
go reporter.Report()
go poller.Poll()
quitChannel := make(chan os.Signal, 1)
signal.Notify(quitChannel, syscall.SIGINT, syscall.SIGTERM)
signal.Notify(quitChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
<-quitChannel
fmt.Println("Stopping the agent")
}
29 changes: 20 additions & 9 deletions cmd/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package main

import (
"context"
"crypto/rsa"
"errors"
"log"
"net/http"
Expand All @@ -16,8 +17,8 @@ import (
"github.com/go-chi/chi/v5"
"github.com/matthiasBT/monitoring/internal/infra/compression"
"github.com/matthiasBT/monitoring/internal/infra/config/server"
"github.com/matthiasBT/monitoring/internal/infra/hashcheck"
"github.com/matthiasBT/monitoring/internal/infra/logging"
"github.com/matthiasBT/monitoring/internal/infra/secure"
"github.com/matthiasBT/monitoring/internal/infra/utils"
"github.com/matthiasBT/monitoring/internal/server/adapters"
"github.com/matthiasBT/monitoring/internal/server/entities"
Expand All @@ -32,12 +33,17 @@ var (

// setupServer configures and returns a new HTTP router with middleware and routes.
// It includes logging, compression, optional HMAC checking, and controller routes.
func setupServer(logger logging.ILogger, controller *usecases.BaseController, hmacKey string) *chi.Mux {
func setupServer(
logger logging.ILogger, controller *usecases.BaseController, hmacKey string, key *rsa.PrivateKey,
) *chi.Mux {
r := chi.NewRouter()
r.Use(logging.Middleware(logger))
r.Use(compression.MiddlewareReader, compression.MiddlewareWriter)
if hmacKey != "" {
r.Use(hashcheck.MiddlewareReader(hmacKey), hashcheck.MiddlewareWriter(hmacKey))
r.Use(secure.MiddlewareHashReader(hmacKey), secure.MiddlewareHashWriter(hmacKey))
}
if key != nil {
r.Use(secure.MiddlewareCryptoReader(key))
}
r.Mount("/", controller.Route())
return r
Expand All @@ -47,11 +53,11 @@ func setupServer(logger logging.ILogger, controller *usecases.BaseController, hm
// It listens for system signals and shuts down the server after processing ongoing requests.
func gracefulShutdown(srv *http.Server, done chan struct{}, logger logging.ILogger) {
quitChannel := make(chan os.Signal, 1)
signal.Notify(quitChannel, syscall.SIGINT, syscall.SIGTERM)
signal.Notify(quitChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
sig := <-quitChannel
logger.Infof("Received signal: %v\n", sig)
done <- struct{}{}
time.Sleep(2 * time.Second)
time.Sleep(5 * time.Second)

if err := srv.Shutdown(context.Background()); err != nil {
log.Fatalf("Server shutdown failed: %v\n", err.Error())
Expand All @@ -74,7 +80,8 @@ func setupRetrier(conf *server.Config, logger logging.ILogger) utils.Retrier {
func setupKeeper(conf *server.Config, logger logging.ILogger, retrier utils.Retrier) entities.Keeper {
if conf.Flushes() {
if conf.DatabaseDSN != "" {
return adapters.NewDBKeeper(conf, logger, retrier)
db := adapters.OpenDB(conf.DatabaseDSN)
return adapters.NewDBKeeper(db, logger, retrier)
} else {
return adapters.NewFileKeeper(conf, logger, retrier)
}
Expand All @@ -88,7 +95,7 @@ func setupTicker(conf *server.Config) <-chan time.Time {
if conf.FlushesSync() {
return make(chan time.Time) // will never be used
} else {
ticker := time.NewTicker(time.Duration(*conf.StoreInterval) * time.Second)
ticker := time.NewTicker(time.Duration(conf.StoreInterval) * time.Second)
return ticker.C
}
}
Expand Down Expand Up @@ -116,7 +123,7 @@ func main() {
storage := adapters.NewMemStorage(done, tickerChan, logger, keeper)

if conf.Flushes() {
if *conf.Restore {
if conf.Restore {
state := keeper.Restore()
storage.Init(state)
}
Expand All @@ -126,7 +133,11 @@ func main() {
}

controller := usecases.NewBaseController(logger, storage, conf.TemplatePath)
r := setupServer(logger, controller, conf.HMACKey)
key, err := conf.ReadPrivateKey()
if err != nil {
panic(err)
}
r := setupServer(logger, controller, conf.HMACKey, key)
srv := http.Server{Addr: conf.Addr, Handler: r}
go func() {
logger.Infof("Launching the server at %s\n", conf.Addr)
Expand Down
35 changes: 29 additions & 6 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func Example() {
storage := adapters.NewMemStorage(nil, nil, logger, nil)
controller := usecases.NewBaseController(logger, storage, "/")

ping(controller)
updateCounter(100500, controller)
updateGauge(5.5, controller)
getCounter(controller)
Expand All @@ -30,10 +31,11 @@ func Example() {
getGauge(controller)

// Output:
// {"id":"FooBar","type":"counter","delta":100500}
// {"id":"BarFoo","type":"gauge","value":5.5}
// {"id":"FooBar","type":"counter","delta":100511}
// {"id":"BarFoo","type":"gauge","value":1.5}
// 200
// {id:FooBar, type:counter, delta:100500}
// {id:BarFoo, type:gauge, value:5.5}
// {id:FooBar, type:counter, delta:100511}
// {id:BarFoo, type:gauge, value:1.5}
}

func updateCounter(value int64, controller *usecases.BaseController) {
Expand Down Expand Up @@ -76,7 +78,7 @@ func getCounter(controller *usecases.BaseController) {
getCounterReq := httptest.NewRequest(http.MethodGet, "/value/", bytes.NewReader(body))
getCounterReq.Header.Set("Content-Type", "application/json")
controller.GetMetric(w, getCounterReq)
fmt.Printf("%v\n", w.Body)
printSorted(w.Body.Bytes())
}

func getGauge(controller *usecases.BaseController) {
Expand All @@ -91,5 +93,26 @@ func getGauge(controller *usecases.BaseController) {
getGaugeReq := httptest.NewRequest(http.MethodGet, "/value/", bytes.NewReader(body))
getGaugeReq.Header.Set("Content-Type", "application/json")
controller.GetMetric(w, getGaugeReq)
fmt.Printf("%v\n", w.Body)
printSorted(w.Body.Bytes())
}

func printSorted(body []byte) {
var result common.Metrics
json.Unmarshal(body, &result)
if result.MType == common.TypeGauge {
fmt.Printf("{%s:%v, %s:%v, %s:%v}\n", "id", result.ID, "type", result.MType, "value", *result.Value)
} else {
fmt.Printf("{%s:%v, %s:%v, %s:%v}\n", "id", result.ID, "type", result.MType, "delta", *result.Delta)
}
}

func ping(controller *usecases.BaseController) {
w := httptest.NewRecorder()
getGaugeReq := httptest.NewRequest(http.MethodGet, "/ping", nil)
controller.Ping(w, getGaugeReq)
resp := w.Result()
if resp.Body != nil {
defer resp.Body.Close()
}
fmt.Println(resp.StatusCode)
}
99 changes: 85 additions & 14 deletions internal/agent/adapters/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ package adapters

import (
"bytes"
"compress/gzip"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/hex"
"encoding/json"
Expand All @@ -27,8 +32,8 @@ type HTTPReportAdapter struct {
// Logger is used for logging messages related to HTTP reporting activities.
Logger logging.ILogger

// jobs is an internal channel used to queue payloads for reporting.
jobs chan []byte
// Jobs is an internal channel used to queue payloads for reporting.
Jobs chan []byte

// ServerAddr specifies the HTTP server address where reports are sent.
ServerAddr string
Expand All @@ -39,6 +44,9 @@ type HTTPReportAdapter struct {
// HMACKey is the key used for HMAC-SHA256 hashing to ensure data integrity.
HMACKey []byte

// CryptoKey is the key used for payload encryption
CryptoKey *rsa.PublicKey

// Retrier is used to handle retries for HTTP requests in case of failures.
Retrier utils.Retrier
}
Expand All @@ -55,24 +63,25 @@ func NewHTTPReportAdapter(
updateURL string,
retrier utils.Retrier,
hmacKey []byte,
cryptoKey *rsa.PublicKey,
workerNum uint,
) *HTTPReportAdapter {
jobs := make(chan []byte, workerNum)
adapter := HTTPReportAdapter{
Logger: logger,
ServerAddr: serverAddr,
UpdateURL: updateURL,
Retrier: retrier,
HMACKey: hmacKey,
jobs: jobs,
CryptoKey: cryptoKey,
Jobs: make(chan []byte, workerNum),
}
var i uint
for i = 0; i < workerNum; i++ {
go func() {
for {
data := <-jobs
data := <-adapter.Jobs
//nolint:errcheck
adapter.report(&data)
adapter.report(data)
}
}()
}
Expand All @@ -87,7 +96,7 @@ func (r *HTTPReportAdapter) Report(metrics *common.Metrics) error {
r.Logger.Errorf("Failed to marshal a metric: %v", metrics)
return err
}
r.jobs <- payload
r.Jobs <- payload
return nil
}

Expand All @@ -99,16 +108,22 @@ func (r *HTTPReportAdapter) ReportBatch(batch []*common.Metrics) error {
r.Logger.Errorf("Failed to marshal a batch of metrics: %v\n", err.Error())
return err
}
r.jobs <- payload
r.Jobs <- payload
return nil
}

func (r *HTTPReportAdapter) report(payload *[]byte) error {
func (r *HTTPReportAdapter) report(payload []byte) error {
var (
req *http.Request
err error
)
u := url.URL{Scheme: "http", Host: r.ServerAddr, Path: r.UpdateURL}
if r.CryptoKey != nil {
payload, err = r.encryptData(payload)
if err != nil {
return err
}
}
if req, err = r.createRequest(u, payload); err != nil {
return err
}
Expand Down Expand Up @@ -145,17 +160,23 @@ func (r *HTTPReportAdapter) report(payload *[]byte) error {
return nil
}

func (r *HTTPReportAdapter) createRequest(path url.URL, payload *[]byte) (*http.Request, error) {
req, err := http.NewRequest("POST", path.String(), bytes.NewReader(*payload))
func (r *HTTPReportAdapter) createRequest(path url.URL, payload []byte) (*http.Request, error) {
var compressed bytes.Buffer
compressed, err := r.compress(payload)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", path.String(), &compressed)
if err != nil {
r.Logger.Errorf("Failed to create a request: %v\n", err.Error())
return nil, err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Content-Encoding", "gzip")
return req, nil
}

func (r *HTTPReportAdapter) addHMACHeader(req *http.Request, payload *[]byte) error {
func (r *HTTPReportAdapter) addHMACHeader(req *http.Request, payload []byte) error {
if hash, err := r.hashData(payload); err != nil {
return err
} else if hash != "" {
Expand All @@ -164,12 +185,12 @@ func (r *HTTPReportAdapter) addHMACHeader(req *http.Request, payload *[]byte) er
return nil
}

func (r *HTTPReportAdapter) hashData(payload *[]byte) (string, error) {
func (r *HTTPReportAdapter) hashData(payload []byte) (string, error) {
if bytes.Equal(r.HMACKey, []byte{}) {
return "", nil
}
mac := hmac.New(sha256.New, r.HMACKey)
if _, err := mac.Write(*payload); err != nil {
if _, err := mac.Write(payload); err != nil {
r.Logger.Errorf("Failed to calculate hash: %v", err.Error())
return "", err
}
Expand All @@ -178,3 +199,53 @@ func (r *HTTPReportAdapter) hashData(payload *[]byte) (string, error) {
r.Logger.Infof("HMAC-SHA256 hash: %s\n", result)
return result, nil
}

func (r *HTTPReportAdapter) compress(payload []byte) (bytes.Buffer, error) {
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write(payload)
if err != nil {
return buf, err
}
if err := gz.Close(); err != nil {
return buf, err
}
return buf, nil
}

func (r *HTTPReportAdapter) encryptData(payload []byte) ([]byte, error) {
key, encryptedPayload, err := encryptAES(payload)
if err != nil {
r.Logger.Errorf("Error encrypting message: %v", err)
return nil, err
}
encryptedKey, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, r.CryptoKey, key, nil)
if err != nil {
r.Logger.Errorf("Error encrypting AES key: %v", err)
return nil, err
}
return append(encryptedKey, encryptedPayload...), nil
}

func encryptAES(plaintext []byte) ([]byte, []byte, error) {
key := make([]byte, 32) // AES-256
if _, err := rand.Read(key); err != nil {
return nil, nil, err
}

block, err := aes.NewCipher(key)
if err != nil {
return nil, nil, err
}

ciphertext := make([]byte, aes.BlockSize+len(plaintext))
iv := ciphertext[:aes.BlockSize]
if _, err := rand.Read(iv); err != nil {
return nil, nil, err
}

stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], plaintext)

return key, ciphertext, nil
}
Loading

0 comments on commit 17bc252

Please sign in to comment.