Skip to content

Commit

Permalink
[202311] Add cert authorization with common name support. (#323)
Browse files Browse the repository at this point in the history
Add cert authorization with common name support.

#### Why I did it
Support cert authorization with common name.

#### How I did it
Load trusted cert common name from config DB and check cert common name. 

#### How to verify it
Manually test.
Add new UT.

#### Work item tracking
Microsoft ADO (number only): 25226269

#### Which release branch to backport (provide reason below if selected)

<!--
- Note we only backport fixes to a release branch, *not* features!
- Please also provide a reason for the backporting below.
- e.g.
- [x] 202006
-->

- [ ] 201811
- [ ] 201911
- [ ] 202006
- [ ] 202012
- [ ] 202106
- [ ] 202111

#### Description for the changelog
Add cert authorization with common name support.

#### Link to config_db schema for YANG module changes
<!--
Provide a link to config_db schema for the table for which YANG model
is defined
Link should point to correct section on https://github.com/Azure/SONiC/wiki/Configuration.
-->

#### A picture of a cute animal (not mandatory but encouraged)
  • Loading branch information
liuh-80 authored Nov 11, 2024
1 parent 7fdd5b5 commit 9c63257
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 30 deletions.
43 changes: 39 additions & 4 deletions gnmi_server/clientCertAuth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gnmi

import (
"github.com/sonic-net/sonic-gnmi/common_utils"
"github.com/sonic-net/sonic-gnmi/swsscommon"
"github.com/golang/glog"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
Expand All @@ -10,7 +11,7 @@ import (
"google.golang.org/grpc/status"
)

func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
func ClientCertAuthenAndAuthor(ctx context.Context, serviceConfigTableName string) (context.Context, error) {
rc, ctx := common_utils.GetContext(ctx)
p, ok := peer.FromContext(ctx)
if !ok {
Expand All @@ -32,10 +33,44 @@ func ClientCertAuthenAndAuthor(ctx context.Context) (context.Context, error) {
return ctx, status.Error(codes.Unauthenticated, "invalid username in certificate common name.")
}

if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
if serviceConfigTableName != "" {
if err := PopulateAuthStructByCommonName(username, &rc.Auth, serviceConfigTableName); err != nil {
return ctx, err
}
} else {
if err := PopulateAuthStruct(username, &rc.Auth, nil); err != nil {
glog.Infof("[%s] Failed to retrieve authentication information; %v", rc.ID, err)
return ctx, status.Errorf(codes.Unauthenticated, "")
}
}

return ctx, nil
}

func PopulateAuthStructByCommonName(certCommonName string, auth *common_utils.AuthInfo, serviceConfigTableName string) error {
if serviceConfigTableName == "" {
return status.Errorf(codes.Unauthenticated, "Service config table name should not be empty")
}

var configDbConnector = swsscommon.NewConfigDBConnector_Native()
defer swsscommon.DeleteConfigDBConnector_Native(configDbConnector)
configDbConnector.Connect(false)

var fieldValuePairs = configDbConnector.Get_entry(serviceConfigTableName, certCommonName)
if fieldValuePairs.Size() > 0 {
if fieldValuePairs.Has_key("role") {
var role = fieldValuePairs.Get("role")
auth.Roles = []string{role}
}
} else {
glog.Warningf("Failed to retrieve cert common name mapping; %s", certCommonName)
}

swsscommon.DeleteFieldValueMap(fieldValuePairs)

if len(auth.Roles) == 0 {
return status.Errorf(codes.Unauthenticated, "Invalid cert cname:'%s', not a trusted cert common name.", certCommonName)
} else {
return nil
}
}
2 changes: 1 addition & 1 deletion gnmi_server/debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (

func (srv *Server) GetSubscribePreferences(req *spb_gnoi.SubscribePreferencesReq, stream spb_gnoi.Debug_GetSubscribePreferencesServer) error {
ctx := stream.Context()
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand Down
30 changes: 15 additions & 15 deletions gnmi_server/gnoi.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func RebootSystem(fileName string) error {
func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest) (*gnoi_system_pb.RebootResponse, error) {
fileName := common_utils.GNMI_WORK_PATH + "/config_db.json.tmp"

_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -57,7 +57,7 @@ func (srv *Server) Reboot(ctx context.Context, req *gnoi_system_pb.RebootRequest

// TODO: Support GNOI RebootStatus
func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootStatusRequest) (*gnoi_system_pb.RebootStatusResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -67,7 +67,7 @@ func (srv *Server) RebootStatus(ctx context.Context, req *gnoi_system_pb.RebootS

// TODO: Support GNOI CancelReboot
func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelRebootRequest) (*gnoi_system_pb.CancelRebootResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand All @@ -76,7 +76,7 @@ func (srv *Server) CancelReboot(ctx context.Context, req *gnoi_system_pb.CancelR
}
func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.System_PingServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -85,7 +85,7 @@ func (srv *Server) Ping(req *gnoi_system_pb.PingRequest, rs gnoi_system_pb.Syste
}
func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_system_pb.System_TracerouteServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
Expand All @@ -94,23 +94,23 @@ func (srv *Server) Traceroute(req *gnoi_system_pb.TracerouteRequest, rs gnoi_sys
}
func (srv *Server) SetPackage(rs gnoi_system_pb.System_SetPackageServer) error {
ctx := rs.Context()
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return err
}
log.V(1).Info("gNOI: SetPackage")
return status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) SwitchControlProcessor(ctx context.Context, req *gnoi_system_pb.SwitchControlProcessorRequest) (*gnoi_system_pb.SwitchControlProcessorResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
log.V(1).Info("gNOI: SwitchControlProcessor")
return nil, status.Errorf(codes.Unimplemented, "")
}
func (srv *Server) Time(ctx context.Context, req *gnoi_system_pb.TimeRequest) (*gnoi_system_pb.TimeResponse, error) {
_, err := authenticate(srv.config.UserAuth, ctx)
_, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -147,7 +147,7 @@ func (srv *Server) Authenticate(ctx context.Context, req *spb_jwt.AuthenticateRe

}
func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*spb_jwt.RefreshResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -175,7 +175,7 @@ func (srv *Server) Refresh(ctx context.Context, req *spb_jwt.RefreshRequest) (*s
}

func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRequest) (*spb.ClearNeighborsResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -207,7 +207,7 @@ func (srv *Server) ClearNeighbors(ctx context.Context, req *spb.ClearNeighborsRe
}

func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (*spb.CopyConfigResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -238,7 +238,7 @@ func (srv *Server) CopyConfig(ctx context.Context, req *spb.CopyConfigRequest) (
}

func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequest) (*spb.TechsupportResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -270,7 +270,7 @@ func (srv *Server) ShowTechsupport(ctx context.Context, req *spb.TechsupportRequ
}

func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallRequest) (*spb.ImageInstallResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -302,7 +302,7 @@ func (srv *Server) ImageInstall(ctx context.Context, req *spb.ImageInstallReques
}

func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest) (*spb.ImageRemoveResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -334,7 +334,7 @@ func (srv *Server) ImageRemove(ctx context.Context, req *spb.ImageRemoveRequest)
}

func (srv *Server) ImageDefault(ctx context.Context, req *spb.ImageDefaultRequest) (*spb.ImageDefaultResponse, error) {
ctx, err := authenticate(srv.config.UserAuth, ctx)
ctx, err := authenticate(srv.config, ctx)
if err != nil {
return nil, err
}
Expand Down
21 changes: 11 additions & 10 deletions gnmi_server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Config struct {
EnableNativeWrite bool
ZmqAddress string
IdleConnDuration int
ConfigTableName string
}

var AuthLock sync.Mutex
Expand Down Expand Up @@ -188,30 +189,30 @@ func (srv *Server) Port() int64 {
return srv.config.Port
}

func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, error) {
func authenticate(config *Config, ctx context.Context) (context.Context, error) {
var err error
success := false
rc, ctx := common_utils.GetContext(ctx)
if !UserAuth.Any() {
if !config.UserAuth.Any() {
//No Auth enabled
rc.Auth.AuthEnabled = false
return ctx, nil
}
rc.Auth.AuthEnabled = true
if UserAuth.Enabled("password") {
if config.UserAuth.Enabled("password") {
ctx, err = BasicAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("jwt") {
if !success && config.UserAuth.Enabled("jwt") {
_, ctx, err = JwtAuthenAndAuthor(ctx)
if err == nil {
success = true
}
}
if !success && UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx)
if !success && config.UserAuth.Enabled("cert") {
ctx, err = ClientCertAuthenAndAuthor(ctx, config.ConfigTableName)
if err == nil {
success = true
}
Expand All @@ -230,7 +231,7 @@ func authenticate(UserAuth AuthTypes, ctx context.Context) (context.Context, err
// Subscribe implements the gNMI Subscribe RPC.
func (s *Server) Subscribe(stream gnmipb.GNMI_SubscribeServer) error {
ctx := stream.Context()
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -315,7 +316,7 @@ func IsNativeOrigin(origin string) bool {
// Get implements the Get RPC in gNMI spec.
func (s *Server) Get(ctx context.Context, req *gnmipb.GetRequest) (*gnmipb.GetResponse, error) {
common_utils.IncCounter(common_utils.GNMI_GET)
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_GET_FAIL)
return nil, err
Expand Down Expand Up @@ -402,7 +403,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, grpc.Errorf(codes.Unimplemented, "GNMI is in read-only mode")
}
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
common_utils.IncCounter(common_utils.GNMI_SET_FAIL)
return nil, err
Expand Down Expand Up @@ -502,7 +503,7 @@ func (s *Server) Set(ctx context.Context, req *gnmipb.SetRequest) (*gnmipb.SetRe
}

func (s *Server) Capabilities(ctx context.Context, req *gnmipb.CapabilityRequest) (*gnmipb.CapabilityResponse, error) {
ctx, err := authenticate(s.config.UserAuth, ctx)
ctx, err := authenticate(s.config, ctx)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 9c63257

Please sign in to comment.