From aa547ad2b9d5ab18e295b7d71b875744916c8de1 Mon Sep 17 00:00:00 2001 From: ganglv <88995770+ganglyu@users.noreply.github.com> Date: Fri, 10 Jan 2025 09:10:51 +0800 Subject: [PATCH] Improve GNMI service to limit API access by role (#335) Microsoft ADO: 30073317 #### Why I did it GNMI/GNOI services need to support role in API operation. RO user is not allowed to call write API. #### How I did it We have supported common name to role mapping, and I have updated API authentication to verify role. #### How to verify it Run unit test and end to end test. --- gnmi_server/debug.go | 2 +- gnmi_server/gnoi.go | 36 ++++++++++++++++++------------------ gnmi_server/server.go | 18 +++++++++++++----- gnmi_server/server_test.go | 37 +++++++++++++++++++++++++++++++++---- 4 files changed, 65 insertions(+), 28 deletions(-) diff --git a/gnmi_server/debug.go b/gnmi_server/debug.go index 6099630e..11ec93dc 100644 --- a/gnmi_server/debug.go +++ b/gnmi_server/debug.go @@ -35,7 +35,7 @@ import ( func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error { ctx := stream.Context() - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, false) if err != nil { return err } diff --git a/gnmi_server/gnoi.go b/gnmi_server/gnoi.go index 4c3aad6c..14e1fd37 100644 --- a/gnmi_server/gnoi.go +++ b/gnmi_server/gnoi.go @@ -71,7 +71,7 @@ func ReadFileStat(path string) (*gnoi_file_pb.StatInfo, error) { } func (srv *FileServer) Stat(ctx context.Context, req *gnoi_file_pb.StatRequest) (*gnoi_file_pb.StatResponse, error) { - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, false) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func KillOrRestartProcess(restart bool, serviceName string) error { } func (srv *SystemServer) KillProcess(ctx context.Context, req *gnoi_system_pb.KillProcessRequest) (*gnoi_system_pb.KillProcessResponse, error) { - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -166,7 +166,7 @@ func RebootSystem(fileName string) error { func (srv *SystemServer) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) { fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp" - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -202,7 +202,7 @@ func (srv *SystemServer) Reboot(ctx context.Context, req *gnoi_system_pb.RebootR // TODO: Support GNOI RebootStatus func (srv *SystemServer) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) { - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, false) if err != nil { return nil, err } @@ -212,7 +212,7 @@ func (srv *SystemServer) RebootStatus(ctx context.Context, req *gnoi_system_pb.R // TODO: Support GNOI CancelReboot func (srv *SystemServer) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) { - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -221,7 +221,7 @@ func (srv *SystemServer) CancelReboot(ctx context.Context, req *gnoi_system_pb.C } func (srv *SystemServer) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error { ctx := rs.Context() - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return err } @@ -230,7 +230,7 @@ func (srv *SystemServer) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb } func (srv *SystemServer) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error { ctx := rs.Context() - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return err } @@ -239,7 +239,7 @@ func (srv *SystemServer) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gn } func (srv *SystemServer) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error { ctx := rs.Context() - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return err } @@ -247,7 +247,7 @@ func (srv *SystemServer) SetPackage(rs gnoi_system_pb.System_SetPackageServer) e return status.Errorf(codes.Unimplemented, "") } func (srv *SystemServer) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) { - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -255,7 +255,7 @@ func (srv *SystemServer) SwitchControlProcessor(ctx context.Context, req *gnoi_s return nil, status.Errorf(codes.Unimplemented, "") } func (srv *SystemServer) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) { - _, err := authenticate(srv.config, ctx) + _, err := authenticate(srv.config, ctx, false) if err != nil { return nil, err } @@ -267,7 +267,7 @@ func (srv *SystemServer) Time(ctx context.Context, req *gnoi_system_pb.TimeReque func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRequest) (*spb_jwt.AuthenticateResponse, error) { // Can't enforce normal authentication here.. maybe only enforce client cert auth if enabled? - // ctx,err := authenticate(srv.config, ctx) + // ctx,err := authenticate(srv.config, ctx, false) // if err != nil { // return nil, err // } @@ -292,7 +292,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe } func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -320,7 +320,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s } func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -352,7 +352,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe } func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -383,7 +383,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) ( } func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, false) if err != nil { return nil, err } @@ -415,7 +415,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ } func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -447,7 +447,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques } func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } @@ -477,7 +477,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) } func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) { - ctx, err := authenticate(srv.config, ctx) + ctx, err := authenticate(srv.config, ctx, true) if err != nil { return nil, err } diff --git a/gnmi_server/server.go b/gnmi_server/server.go index 23dd817c..9c002baa 100644 --- a/gnmi_server/server.go +++ b/gnmi_server/server.go @@ -91,6 +91,7 @@ type Config struct { var AuthLock sync.Mutex var maMu sync.Mutex +const WriteAccessMode = "readwrite" func (i AuthTypes) String() string { if i["none"] { @@ -240,7 +241,7 @@ func (srv *Server) Port() int64 { return srv.config.Port } -func authenticate(config *Config, ctx context.Context) (context.Context, error) { +func authenticate(config *Config, ctx context.Context, writeAccess bool) (context.Context, error) { var err error success := false rc, ctx := common_utils.GetContext(ctx) @@ -268,6 +269,13 @@ func authenticate(config *Config, ctx context.Context) (context.Context, error) if err == nil { success = true } + // role must be readwrite to support write access + if writeAccess && config.ConfigTableName != "" { + role := rc.Auth.Roles[0] + if role != WriteAccessMode { + return ctx, fmt.Errorf("%s does not have write access, %s", rc.Auth.User, role) + } + } } //Allow for future authentication mechanisms here... @@ -283,7 +291,7 @@ func authenticate(config *Config, ctx context.Context) (context.Context, error) // Subscribe implements the gNMI Subscribe RPC. func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error { ctx := stream.Context() - ctx, err := authenticate(s.config, ctx) + ctx, err := authenticate(s.config, ctx, false) if err != nil { return err } @@ -368,7 +376,7 @@ func IsNativeOrigin(origin string) bool { // Get implements the Get RPC in gNMI spec. func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) { common_utils.IncCounter(common_utils.GNMI_GET) - ctx, err := authenticate(s.config, ctx) + ctx, err := authenticate(s.config, ctx, false) if err != nil { common_utils.IncCounter(common_utils.GNMI_GET_FAIL) return nil, err @@ -474,7 +482,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode") } - ctx, err := authenticate(s.config, ctx) + ctx, err := authenticate(s.config, ctx, true) if err != nil { common_utils.IncCounter(common_utils.GNMI_SET_FAIL) return nil, err @@ -576,7 +584,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe } func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) { - ctx, err := authenticate(s.config, ctx) + ctx, err := authenticate(s.config, ctx, false) if err != nil { return nil, err } diff --git a/gnmi_server/server_test.go b/gnmi_server/server_test.go index 40a90abc..4b4da0e3 100644 --- a/gnmi_server/server_test.go +++ b/gnmi_server/server_test.go @@ -4520,7 +4520,7 @@ func TestClientCertAuthenAndAuthor(t *testing.T) { // check get 1 cert name ctx, cancel = CreateAuthorizationCtx() configDb.Flushdb() - gnmiTable.Hset("certname1", "role", "role1") + gnmiTable.Hset("certname1", "role", "readwrite") ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT", false) if err != nil { t.Errorf("CommonNameMatch with correct cert name should success: %v", err) @@ -4531,8 +4531,8 @@ func TestClientCertAuthenAndAuthor(t *testing.T) { // check get multiple cert names ctx, cancel = CreateAuthorizationCtx() configDb.Flushdb() - gnmiTable.Hset("certname1", "role", "role1") - gnmiTable.Hset("certname2", "role", "role2") + gnmiTable.Hset("certname1", "role", "readwrite") + gnmiTable.Hset("certname2", "role", "readonly") ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT", false) if err != nil { t.Errorf("CommonNameMatch with correct cert name should success: %v", err) @@ -4543,7 +4543,7 @@ func TestClientCertAuthenAndAuthor(t *testing.T) { // check a invalid cert cname ctx, cancel = CreateAuthorizationCtx() configDb.Flushdb() - gnmiTable.Hset("certname2", "role", "role2") + gnmiTable.Hset("certname2", "role", "readonly") ctx, err = ClientCertAuthenAndAuthor(ctx, "GNMI_CLIENT_CERT", false) if err == nil { t.Errorf("CommonNameMatch with invalid cert name should fail: %v", err) @@ -4555,6 +4555,35 @@ func TestClientCertAuthenAndAuthor(t *testing.T) { swsscommon.DeleteDBConnector(configDb) } +func TestAuthenticate(t *testing.T) { + if !swsscommon.SonicDBConfigIsInit() { + swsscommon.SonicDBConfigInitialize() + } + + var tableName = "GNMI_CLIENT_CERT" + var configDb = swsscommon.NewDBConnector("CONFIG_DB", uint(0), true) + var gnmiTable = swsscommon.NewTable(configDb, tableName) + defer swsscommon.DeleteTable(gnmiTable) + defer swsscommon.DeleteDBConnector(configDb) + configDb.Flushdb() + + // initialize err variable + err := status.Error(codes.Unauthenticated, "") + + // check a invalid role + cfg := &Config{ConfigTableName: tableName, UserAuth: AuthTypes{"password": false, "cert": true, "jwt": false}} + ctx, cancel := CreateAuthorizationCtx() + configDb.Flushdb() + gnmiTable.Hset("certname1", "role", "readonly") + // Call authenticate to verify the user's role. This should fail if the role is "readonly". + _, err = authenticate(cfg, ctx, true) + if err == nil { + t.Errorf("authenticate with readonly role should fail: %v", err) + } + + cancel() +} + type MockServerStream struct { grpc.ServerStream }