diff --git a/s2example/demo.go b/s2example/demo.go index d9f818c..ed97c1b 100644 --- a/s2example/demo.go +++ b/s2example/demo.go @@ -15,13 +15,15 @@ package main import ( + "crypto/rsa" "crypto/x509" "fmt" + "io" "net/http" - - "io/ioutil" + "os" "encoding/base64" + "encoding/pem" "encoding/xml" saml2 "github.com/russellhaering/gosaml2" @@ -35,7 +37,7 @@ func main() { panic(err) } - rawMetadata, err := ioutil.ReadAll(res.Body) + rawMetadata, err := io.ReadAll(res.Body) if err != nil { panic(err) } @@ -82,16 +84,36 @@ func main() { AudienceURI: "http://example.com/saml/acs/example", IDPCertificateStore: &certStore, SPKeyStore: randomKeyStore, + IdentityProviderSLOURL: metadata.IDPSSODescriptor.SingleLogoutServices[0].Location, + ServiceProviderSLOURL: "http://localhost:8080/v1/_logout", + } + + +// generate sp private key, certificate this is used to sign the slo request for logout + keystore, err := loadKeystore("","") + if err!=nil { + fmt.Printf("Error loading keystore") } + sp.SetSPSigningKeyStore(keystore) + + sessionIndex := "" + nameID := "" + http.HandleFunc("/v1/_saml_callback", func(rw http.ResponseWriter, req *http.Request) { err := req.ParseForm() + rw.Header().Add("Content-Type", "text/html") + if err != nil { rw.WriteHeader(http.StatusBadRequest) return } assertionInfo, err := sp.RetrieveAssertionInfo(req.FormValue("SAMLResponse")) + + sessionIndex = assertionInfo.SessionIndex + nameID = assertionInfo.NameID + if err != nil { rw.WriteHeader(http.StatusForbidden) return @@ -119,6 +141,36 @@ func main() { fmt.Fprintf(rw, "Warnings:\n") fmt.Fprintf(rw, "%+v\n", assertionInfo.WarningInfo) + + logoutRequest, err := sp.BuildLogoutRequestDocument(nameID, sessionIndex) + + logouturl , _ := sp.BuildLogoutURLRedirect("", logoutRequest) + + fmt.Fprintf(rw, " Click to logout Logout \n", logouturl) + }) + + http.HandleFunc("/v1/_logout", func(rw http.ResponseWriter, req *http.Request) { + + err := req.ParseForm() + if err != nil { + rw.WriteHeader(http.StatusBadRequest) + return + } + + + response , err := sp.ValidateEncodedLogoutResponsePOST(req.FormValue("SAMLResponse")) + fmt.Printf("the logout response %v\n", response.Status.StatusCode.Value) + + + if err!=nil { + fmt.Printf("Failed to logout %s \n", err) + fmt.Fprintf(rw, "Couldn't log out %s due to some internal error",nameID) + }else if response.Status.StatusCode.Value != "urn:oasis:names:tc:SAML:2.0:status:Success"{ + fmt.Fprintf(rw, "Couldn't log out %s due to invalid request", nameID) + }else{ + fmt.Fprintf(rw, "%s logged out successfully",nameID) + } + }) println("Visit this URL To Authenticate:") @@ -137,3 +189,41 @@ func main() { panic(err) } } + + +func loadKeystore(privateKeyPath, publicCertificate string) (*saml2.KeyStore, error){ + privateKeyBytes, err := os.ReadFile(privateKeyPath) + + if err != nil { + fmt.Printf("There is some error reading the private key %v \n", err) + return nil, err + } + + certBytes, err := os.ReadFile(publicCertificate) + + if err != nil { + fmt.Printf("There is some error reading the cert key %v \n", err) + return nil, err + } + + block, _ := pem.Decode(privateKeyBytes) + + if block == nil { + fmt.Printf("Invalid Pem private key") + panic("Invalid Pem private key") + } + + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + + if err!=nil { + fmt.Printf("There is some error signing parsing the private key: %v\n", err) + return nil, err + } + + keystore := &saml2.KeyStore{ + Signer: privateKey.(*rsa.PrivateKey), + Cert: certBytes, + } + + return keystore, nil +} \ No newline at end of file