From 3b78a2afc201f6edf9b68daabb8e652187a19948 Mon Sep 17 00:00:00 2001 From: Cosmin Tupangiu Date: Fri, 13 Dec 2024 14:54:08 +0100 Subject: [PATCH] agent/auth: Implement agent authentication Agent authentication uses the same logic as the planner auth. For simplicity, the jwt inserted in the ova image is the user's token. Signed-off-by: Cosmin Tupangiu --- cmd/planner-agent/main.go | 21 ++++++++++----- data/ignition.template | 7 +++++ internal/agent/agent.go | 12 +++++++-- internal/agent/agent_test.go | 3 ++- internal/agent/client/planner.go | 31 ++++++++++++++++++++--- internal/agent/common/common.go | 7 +++++ internal/agent/service/health.go | 10 ++++---- internal/agent/service/health_test.go | 10 ++++---- internal/api_server/agentserver/server.go | 7 +++++ internal/auth/rhsso_authenticator.go | 1 + internal/auth/user.go | 7 ++++- internal/image/ova.go | 6 +++++ internal/service/image.go | 6 +++++ 13 files changed, 105 insertions(+), 23 deletions(-) create mode 100644 internal/agent/common/common.go diff --git a/cmd/planner-agent/main.go b/cmd/planner-agent/main.go index 3801316..3382f5e 100644 --- a/cmd/planner-agent/main.go +++ b/cmd/planner-agent/main.go @@ -19,6 +19,7 @@ import ( const ( agentFilename = "agent_id" + jwtFilename = "jwt.json" ) func main() { @@ -70,23 +71,31 @@ func (a *agentCmd) Execute() error { undo := zap.ReplaceGlobals(logger) defer undo() - agentID, err := a.getAgentID() + agentID, err := a.readFile(agentFilename) if err != nil { zap.S().Fatalf("failed to retreive agent_id: %v", err) } - agentInstance := agent.New(uuid.MustParse(agentID), a.config) + // Try to read jwt from file. + // We're assuming the jwt is valid. + // The agent will not try to validate the jwt. The backend is responsible for validating the token. + jwt, err := a.readFile(jwtFilename) + if err != nil { + zap.S().Errorf("failed to read jwt: %v", err) + } + + agentInstance := agent.New(uuid.MustParse(agentID), jwt, a.config) if err := agentInstance.Run(context.Background()); err != nil { zap.S().Fatalf("running device agent: %v", err) } return nil } -func (a *agentCmd) getAgentID() (string, error) { +func (a *agentCmd) readFile(filename string) (string, error) { // look for it in data dir - dataDirPath := path.Join(a.config.DataDir, agentFilename) - if _, err := os.Stat(dataDirPath); err == nil { - content, err := os.ReadFile(dataDirPath) + confDirPath := path.Join(a.config.DataDir, filename) + if _, err := os.Stat(confDirPath); err == nil { + content, err := os.ReadFile(confDirPath) if err != nil { return "", err } diff --git a/data/ignition.template b/data/ignition.template index 2d9c5bf..df8696f 100644 --- a/data/ignition.template +++ b/data/ignition.template @@ -67,6 +67,13 @@ storage: location = "{{.InsecureRegistry}}" insecure = true {{end}} + {{ if .Token }} + - path: /home/core/.migration-planner/config/jwt.json + mode: 0644 + contents: + inline: | + {{ .Token }} + {{ end }} - path: /var/lib/systemd/linger/core mode: 0644 contents: diff --git a/internal/agent/agent.go b/internal/agent/agent.go index f1e9479..6540c12 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" api "github.com/kubev2v/migration-planner/api/v1alpha1" "github.com/kubev2v/migration-planner/internal/agent/client" + "github.com/kubev2v/migration-planner/internal/agent/common" "github.com/kubev2v/migration-planner/internal/agent/config" "github.com/kubev2v/migration-planner/internal/agent/service" "github.com/lthibault/jitterbug" @@ -31,11 +32,12 @@ const ( var version string // New creates a new agent. -func New(id uuid.UUID, config *config.Config) *Agent { +func New(id uuid.UUID, jwt string, config *config.Config) *Agent { return &Agent{ config: config, healtCheckStopCh: make(chan chan any), id: id, + jwt: jwt, } } @@ -45,6 +47,7 @@ type Agent struct { healtCheckStopCh chan chan any credUrl string id uuid.UUID + jwt string } func (a *Agent) Run(ctx context.Context) error { @@ -96,6 +99,11 @@ func (a *Agent) start(ctx context.Context, plannerClient client.Planner) { inventoryUpdater := service.NewInventoryUpdater(a.id, plannerClient) statusUpdater := service.NewStatusUpdater(a.id, version, a.credUrl, a.config, plannerClient) + // insert jwt into the context if any + if a.jwt != "" { + ctx = context.WithValue(ctx, common.JwtKey, a.jwt) + } + // start server a.server = NewServer(defaultAgentPort, a.config.DataDir, a.config.WwwDir) go a.server.Start(statusUpdater) @@ -114,7 +122,7 @@ func (a *Agent) start(ctx context.Context, plannerClient client.Planner) { } // TODO refactor health checker to call it from the main goroutine - healthChecker.Start(a.healtCheckStopCh) + healthChecker.Start(ctx, a.healtCheckStopCh) collector := service.NewCollector(a.config.DataDir) collector.Collect(ctx) diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index 0be2845..d04dfeb 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -74,7 +74,8 @@ var _ = Describe("Agent", func() { } config.PlannerService.Service.Server = testHttpServer.URL - a := agent.New(agentID, &config) + jwt := "" + a := agent.New(agentID, jwt, &config) ctx, cancel := context.WithTimeout(context.TODO(), 30*time.Second) go func() { err := a.Run(ctx) diff --git a/internal/agent/client/planner.go b/internal/agent/client/planner.go index ae68f70..1c804f4 100644 --- a/internal/agent/client/planner.go +++ b/internal/agent/client/planner.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" api "github.com/kubev2v/migration-planner/api/v1alpha1/agent" + "github.com/kubev2v/migration-planner/internal/agent/common" client "github.com/kubev2v/migration-planner/internal/api/client/agent" ) @@ -29,7 +30,12 @@ type planner struct { } func (p *planner) UpdateSourceStatus(ctx context.Context, id uuid.UUID, params api.SourceStatusUpdate) error { - resp, err := p.client.ReplaceSourceStatusWithResponse(ctx, id, params) + resp, err := p.client.ReplaceSourceStatusWithResponse(ctx, id, params, func(ctx context.Context, req *http.Request) error { + if jwt, found := p.jwtFromContext(ctx); found { + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jwt)) + } + return nil + }) if err != nil { return err } @@ -44,7 +50,13 @@ func (p *planner) UpdateSourceStatus(ctx context.Context, id uuid.UUID, params a } func (p *planner) Health(ctx context.Context) error { - resp, err := p.client.HealthWithResponse(ctx) + resp, err := p.client.HealthWithResponse(ctx, func(ctx context.Context, req *http.Request) error { + if jwt, found := p.jwtFromContext(ctx); found { + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jwt)) + } + return nil + }) + if err != nil { return err } @@ -58,7 +70,12 @@ func (p *planner) Health(ctx context.Context) error { } func (p *planner) UpdateAgentStatus(ctx context.Context, id uuid.UUID, params api.AgentStatusUpdate) error { - resp, err := p.client.UpdateAgentStatusWithResponse(ctx, id, params) + resp, err := p.client.UpdateAgentStatusWithResponse(ctx, id, params, func(ctx context.Context, req *http.Request) error { + if jwt, found := p.jwtFromContext(ctx); found { + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", jwt)) + } + return nil + }) if err != nil { return err } @@ -73,3 +90,11 @@ func (p *planner) UpdateAgentStatus(ctx context.Context, id uuid.UUID, params ap } return nil } + +func (p *planner) jwtFromContext(ctx context.Context) (string, bool) { + val := ctx.Value(common.JwtKey) + if val == nil { + return "", false + } + return val.(string), true +} diff --git a/internal/agent/common/common.go b/internal/agent/common/common.go new file mode 100644 index 0000000..81bc166 --- /dev/null +++ b/internal/agent/common/common.go @@ -0,0 +1,7 @@ +package common + +type jwtContextKeyType struct{} + +var ( + JwtKey jwtContextKeyType +) diff --git a/internal/agent/service/health.go b/internal/agent/service/health.go index 2c290f5..e556e78 100644 --- a/internal/agent/service/health.go +++ b/internal/agent/service/health.go @@ -76,8 +76,8 @@ func NewHealthChecker(client client.Planner, logFolder string, checkInterval tim // initialInterval represents the time after which the check is started. // checkInterval represents the time to wait between checks. // closeCh is the channel used to close the goroutine. -func (h *HealthChecker) Start(closeCh chan chan any) { - h.do() +func (h *HealthChecker) Start(ctx context.Context, closeCh chan chan any) { + h.do(ctx) h.once.Do(func() { go func() { @@ -96,7 +96,7 @@ func (h *HealthChecker) Start(closeCh chan chan any) { close(c) return case <-t.C: - h.do() + h.do(ctx) } } }() @@ -109,8 +109,8 @@ func (h *HealthChecker) State() AgentHealthState { return h.state } -func (h *HealthChecker) do() { - ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout*time.Second) +func (h *HealthChecker) do(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, defaultTimeout*time.Second) defer cancel() err := h.client.Health(ctx) diff --git a/internal/agent/service/health_test.go b/internal/agent/service/health_test.go index b6cae4d..15b607e 100644 --- a/internal/agent/service/health_test.go +++ b/internal/agent/service/health_test.go @@ -54,7 +54,7 @@ var _ = Describe("Health checker", func() { It("should close OK", func() { closeCh := make(chan chan any) - hc.Start(closeCh) + hc.Start(context.TODO(), closeCh) <-time.After(2 * time.Second) c := make(chan any, 1) @@ -81,7 +81,7 @@ var _ = Describe("Health checker", func() { It("should call health endpoint", func() { closeCh := make(chan chan any) - hc.Start(closeCh) + hc.Start(context.TODO(), closeCh) <-time.After(5 * time.Second) c := make(chan any, 1) @@ -136,7 +136,7 @@ var _ = Describe("Health checker", func() { It("should write OK -- only failures", func() { closeCh := make(chan chan any) testClient.ShouldReturnError = true - hc.Start(closeCh) + hc.Start(context.TODO(), closeCh) <-time.After(5 * time.Second) @@ -155,7 +155,7 @@ var _ = Describe("Health checker", func() { It("should write OK -- failures and one OK line", func() { closeCh := make(chan chan any) testClient.ShouldReturnError = true - hc.Start(closeCh) + hc.Start(context.TODO(), closeCh) <-time.After(2 * time.Second) testClient.ShouldReturnError = false @@ -182,7 +182,7 @@ var _ = Describe("Health checker", func() { It("should write OK -- failures and 2 OK lines", func() { closeCh := make(chan chan any) testClient.ShouldReturnError = true - hc.Start(closeCh) + hc.Start(context.TODO(), closeCh) <-time.After(2 * time.Second) testClient.ShouldReturnError = false diff --git a/internal/api_server/agentserver/server.go b/internal/api_server/agentserver/server.go index bed90bc..506cd9d 100644 --- a/internal/api_server/agentserver/server.go +++ b/internal/api_server/agentserver/server.go @@ -12,6 +12,7 @@ import ( "github.com/go-chi/chi/v5/middleware" api "github.com/kubev2v/migration-planner/api/v1alpha1/agent" server "github.com/kubev2v/migration-planner/internal/api/server/agent" + "github.com/kubev2v/migration-planner/internal/auth" "github.com/kubev2v/migration-planner/internal/config" "github.com/kubev2v/migration-planner/internal/events" service "github.com/kubev2v/migration-planner/internal/service/agent" @@ -64,8 +65,14 @@ func (s *AgentServer) Run(ctx context.Context) error { ErrorHandler: oapiErrorHandler, } + authenticator, err := auth.NewAuthenticator(s.cfg.Service.Auth) + if err != nil { + return fmt.Errorf("failed to create authenticator: %w", err) + } + router := chi.NewRouter() router.Use( + authenticator.Authenticator, middleware.RequestID, zapchi.Logger(zap.S(), "router_agent"), middleware.Recoverer, diff --git a/internal/auth/rhsso_authenticator.go b/internal/auth/rhsso_authenticator.go index 9416d08..98ba38e 100644 --- a/internal/auth/rhsso_authenticator.go +++ b/internal/auth/rhsso_authenticator.go @@ -58,6 +58,7 @@ func (rh *RHSSOAuthenticator) parseToken(userToken *jwt.Token) (User, error) { Username: claims["username"].(string), Organization: claims["org_id"].(string), ClientID: claims["client_id"].(string), + Token: userToken, }, nil } diff --git a/internal/auth/user.go b/internal/auth/user.go index fe23603..f90e295 100644 --- a/internal/auth/user.go +++ b/internal/auth/user.go @@ -1,6 +1,10 @@ package auth -import "context" +import ( + "context" + + "github.com/golang-jwt/jwt/v5" +) type usernameKeyType struct{} @@ -24,4 +28,5 @@ type User struct { Username string Organization string ClientID string + Token *jwt.Token } diff --git a/internal/image/ova.go b/internal/image/ova.go index 930ed8d..f973a6c 100644 --- a/internal/image/ova.go +++ b/internal/image/ova.go @@ -11,6 +11,7 @@ import ( "github.com/coreos/butane/config" "github.com/coreos/butane/config/common" + "github.com/golang-jwt/jwt/v5" "github.com/kubev2v/migration-planner/internal/util" "github.com/openshift/assisted-image-service/pkg/isoeditor" "github.com/openshift/assisted-image-service/pkg/overlay" @@ -24,6 +25,7 @@ const ResponseWriterKey Key = 0 type Ova struct { Writer io.Writer SshKey *string + Jwt *jwt.Token } // IgnitionData defines modifiable fields in ignition config @@ -32,6 +34,7 @@ type IgnitionData struct { PlannerService string MigrationPlannerAgentImage string InsecureRegistry string + Token string } type Image interface { @@ -146,6 +149,9 @@ func (o *Ova) generateIgnition() (string, error) { if o.SshKey != nil { ignData.SshKey = *o.SshKey } + if o.Jwt != nil { + ignData.Token = o.Jwt.Raw + } if insecureRegistry := os.Getenv("INSECURE_REGISTRY"); insecureRegistry != "" { ignData.InsecureRegistry = insecureRegistry diff --git a/internal/service/image.go b/internal/service/image.go index ec8fed2..114a9b6 100644 --- a/internal/service/image.go +++ b/internal/service/image.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/kubev2v/migration-planner/internal/api/server" + "github.com/kubev2v/migration-planner/internal/auth" "github.com/kubev2v/migration-planner/internal/image" ) @@ -15,7 +16,12 @@ func (h *ServiceHandler) GetImage(ctx context.Context, request server.GetImageRe if !ok { return server.GetImage500JSONResponse{Message: "error creating the HTTP stream"}, nil } + ova := &image.Ova{SshKey: request.Params.SshKey, Writer: writer} + // get token if any + if user, found := auth.UserFromContext(ctx); found { + ova.Jwt = user.Token + } if err := ova.Generate(); err != nil { return server.GetImage500JSONResponse{Message: fmt.Sprintf("error generating image %s", err)}, nil }