diff --git a/client/scep.go b/client/scep.go index 5b11b9e..3ccc755 100644 --- a/client/scep.go +++ b/client/scep.go @@ -67,16 +67,16 @@ func (c *client) GetCACaps(ctx context.Context) ([]byte, error) { return r.Data, nil } -func (c *client) GetCACert(ctx context.Context) ([]byte, error) { +func (c *client) GetCACert(ctx context.Context) ([]byte, int, error) { request := scepserver.SCEPRequest{ Operation: "GetCACert", } reply, err := c.getRemote(ctx, request) if err != nil { - return nil, err + return nil, 0, err } r := reply.(scepserver.SCEPResponse) - return r.Data, nil + return r.Data, r.CACertNum, nil } func (c *client) PKIOperation(ctx context.Context, data []byte) ([]byte, error) { diff --git a/cmd/scepclient/scepclient.go b/cmd/scepclient/scepclient.go index 03f430c..5908c9e 100644 --- a/cmd/scepclient/scepclient.go +++ b/cmd/scepclient/scepclient.go @@ -76,13 +76,23 @@ func run(cfg runCfg) error { client = scepclient.NewClient(cfg.serverURL) } - resp, err := client.GetCACert(ctx) + resp, certNum, err := client.GetCACert(ctx) if err != nil { return err } - certs, err := scep.CACerts(resp) - if err != nil { - return err + var certs []*x509.Certificate + { + if certNum > 1 { + certs, err = scep.CACerts(resp) + if err != nil { + return err + } + } else { + certs, err = x509.ParseCertificates(resp) + if err != nil { + return err + } + } } var signerCert *x509.Certificate diff --git a/server/endpoint.go b/server/endpoint.go index af02c46..6278a71 100644 --- a/server/endpoint.go +++ b/server/endpoint.go @@ -11,6 +11,7 @@ type SCEPRequest struct { // Business errors will be encoded as a CertRep message // with pkiStatus FAILURE and a failInfo attribute. type SCEPResponse struct { - Data []byte - Err error // response error + CACertNum int //chain + Data []byte + Err error // response error } diff --git a/server/service.go b/server/service.go index 1dd3c57..69290ba 100644 --- a/server/service.go +++ b/server/service.go @@ -24,7 +24,7 @@ type Service interface { // GetCACert returns CA certificate or // a CA certificate chain with intermediates // in a PKCS#7 Degenerate Certificates format - GetCACert(ctx context.Context) ([]byte, error) + GetCACert(ctx context.Context) ([]byte, int, error) // PKIOperation handles incoming SCEP messages such as PKCSReq and // sends back a CertRep PKIMessag. @@ -50,11 +50,15 @@ func (svc service) GetCACaps(ctx context.Context) ([]byte, error) { return defaultCaps, nil } -func (svc service) GetCACert(ctx context.Context) ([]byte, error) { +func (svc service) GetCACert(ctx context.Context) ([]byte, int, error) { if len(svc.ca) == 0 { - return nil, errors.New("missing CA Cert") + return nil, 0, errors.New("missing CA Cert") } - return scep.DegenerateCertificates(svc.ca) + if len(svc.ca) == 1 { + return svc.ca[0].Raw, 1, nil + } + data, err := scep.DegenerateCertificates(svc.ca) + return data, len(svc.ca), err } func (svc service) PKIOperation(ctx context.Context, data []byte) ([]byte, error) { diff --git a/server/service_logging.go b/server/service_logging.go index e19f2d1..3c4a1e8 100644 --- a/server/service_logging.go +++ b/server/service_logging.go @@ -29,7 +29,7 @@ func (mw loggingService) GetCACaps(ctx context.Context) (caps []byte, err error) return } -func (mw loggingService) GetCACert(ctx context.Context) (cert []byte, err error) { +func (mw loggingService) GetCACert(ctx context.Context) (cert []byte, certNum int, err error) { defer func(begin time.Time) { _ = mw.logger.Log( "method", "GetCACert", @@ -37,7 +37,7 @@ func (mw loggingService) GetCACert(ctx context.Context) (cert []byte, err error) "took", time.Since(begin), ) }(time.Now()) - cert, err = mw.Service.GetCACert(ctx) + cert, certNum, err = mw.Service.GetCACert(ctx) return } diff --git a/server/transport.go b/server/transport.go index 8d7a548..a2fb664 100644 --- a/server/transport.go +++ b/server/transport.go @@ -104,7 +104,7 @@ func encodeSCEPResponse(ctx context.Context, w http.ResponseWriter, response int fmt.Println(resp.Err) return resp.Err } - w.Header().Set("Content-Type", contentHeader(ctx)) + w.Header().Set("Content-Type", contentHeader(ctx, resp.CACertNum)) w.Write(resp.Data) return nil } @@ -118,19 +118,29 @@ func DecodeSCEPResponse(ctx context.Context, r *http.Response) (interface{}, err resp := SCEPResponse{ Data: data, } + header := r.Header.Get("Content-Type") + if header == certChainHeader { + // TODO decode the response instead of just passing []byte around + // 0 or 1 + resp.CACertNum = 2 + } return resp, nil } const ( certChainHeader = "application/x-x509-ca-ra-cert" + leafHeader = "application/x-x509-ca-cert" pkiOpHeader = "application/x-pki-message" ) -func contentHeader(ctx context.Context) string { +func contentHeader(ctx context.Context, certNum int) string { op := ctx.Value("operation") switch op { case "GetCACert": - return certChainHeader + if certNum > 1 { + return certChainHeader + } + return leafHeader case "PKIOperation": return pkiOpHeader default: @@ -153,9 +163,9 @@ func makeSCEPEndpoint(svc Service) endpoint.Endpoint { } return SCEPResponse{Data: caps}, nil case "GetCACert": - cert, err := svc.GetCACert(ctx) + cert, certNum, err := svc.GetCACert(ctx) if err != nil { - return SCEPResponse{Err: err}, nil + return SCEPResponse{Err: err, CACertNum: certNum}, nil } return SCEPResponse{Data: cert}, nil case "PKIOperation":