diff --git a/server/depot.go b/server/depot.go index d61caae..21d5823 100644 --- a/server/depot.go +++ b/server/depot.go @@ -6,6 +6,7 @@ import ( "encoding/pem" "errors" "io/ioutil" + "math/big" "os" "path/filepath" ) @@ -13,6 +14,8 @@ import ( // Depot is a repository for managing certificates type Depot interface { CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) + Put(name string, cert []byte) error + Serial() (*big.Int, error) } // NewFileDepot returns a new cert depot @@ -44,12 +47,95 @@ func (d fileDepot) CA(pass []byte) ([]*x509.Certificate, *rsa.PrivateKey, error) return []*x509.Certificate{cert}, key, nil } +// file permissions const ( - rootPerm = 0400 - branchPerm = 0440 - leafPerm = 0444 + certPerm = 0444 + serialPerm = 0400 ) +// Put adds a certificate to the depot +func (d fileDepot) Put(name string, data []byte) error { + if data == nil { + return errors.New("data is nil") + } + + if err := os.MkdirAll(d.dirPath, 0755); err != nil { + return err + } + + serial, err := d.Serial() + if err != nil { + return err + } + + name = d.path(name) + "." + serial.String() + ".pem" + file, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, certPerm) + if err != nil { + return err + } + defer file.Close() + + if _, err := file.Write(pemCert(data)); err != nil { + file.Close() + os.Remove(name) + return err + } + + if err := d.incrementSerial(serial); err != nil { + return err + } + + return nil +} + +func (d fileDepot) Serial() (*big.Int, error) { + name := d.path("serial") + s := big.NewInt(2) + if err := d.check("serial"); err != nil { + // assuming it doesnt exist, create + if err := d.writeSerial(s); err != nil { + return nil, err + } + return s, nil + } + data, err := ioutil.ReadFile(name) + if err != nil { + return nil, err + } + serial := s.SetBytes(data) + return serial, nil +} + +func (d fileDepot) writeSerial(serial *big.Int) error { + if err := os.MkdirAll(d.dirPath, 0755); err != nil { + return err + } + name := d.path("serial") + os.Remove(name) + + file, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_EXCL, serialPerm) + if err != nil { + return err + } + defer file.Close() + + if _, err := file.Write(serial.Bytes()); err != nil { + file.Close() + os.Remove(name) + return err + } + return nil +} + +// read serial and increment +func (d fileDepot) incrementSerial(s *big.Int) error { + serial := s.Add(s, big.NewInt(1)) + if err := d.writeSerial(serial); err != nil { + return err + } + return nil +} + type file struct { Info os.FileInfo Data []byte @@ -115,3 +201,13 @@ func loadCert(data []byte) (*x509.Certificate, error) { return x509.ParseCertificate(pemBlock.Bytes) } + +func pemCert(derBytes []byte) []byte { + pemBlock := &pem.Block{ + Type: certificatePEMBlockType, + Headers: nil, + Bytes: derBytes, + } + out := pem.EncodeToMemory(pemBlock) + return out +} diff --git a/server/service.go b/server/service.go index 0096714..92413f9 100644 --- a/server/service.go +++ b/server/service.go @@ -82,9 +82,13 @@ func (svc service) PKIOperation(ctx context.Context, data []byte) ([]byte, error return nil, err } + serial, err := svc.depot.Serial() + if err != nil { + return nil, err + } // create cert template tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(4), + SerialNumber: serial, Subject: csr.Subject, NotBefore: time.Now().Add(-600).UTC(), NotAfter: time.Now().AddDate(1, 0, 0).UTC(), @@ -100,10 +104,23 @@ func (svc service) PKIOperation(ctx context.Context, data []byte) ([]byte, error return nil, err } + crt := certRep.CertRepMessage.Certificate + name := certName(crt) + if err := svc.depot.Put(name, crt.Raw); err != nil { + return nil, err + } + return certRep.Raw, nil } +func certName(crt *x509.Certificate) string { + if crt.Subject.CommonName != "" { + return crt.Subject.CommonName + } + return string(crt.Signature) +} + func (svc service) GetNextCACert(ctx context.Context) ([]byte, error) { panic("not implemented") }