diff --git a/README.md b/README.md index 695d4a8..2aa6817 100644 --- a/README.md +++ b/README.md @@ -119,34 +119,16 @@ resourceTypes := []ResourceType{ ### 4. Create Server ```go -server := Server{ - Config: config, +serverArgs := &ServerArgs{ + ServiceProviderConfig: config, ResourceTypes: resourceTypes, } -``` - -## Logging - -No incoming or outgoing (incl. errors) requests are logged by default. It is up to the user to implement this. This can -either be done through middleware around the server or by implementing the `ResourceHandler` interface. - -### Internal - -The SCIM server uses the standard `slog` package for logging. -There are two moments where the server logs: - -1. When it was not able to marshal the response, it will log the error. This should not happen, since these are - predefined structures, of which most have custom `MarshalJSON` methods. In these cases an `errors.ScimErrorInternal` - error is returned. -2. When the server was not able to `Write` the response. - -This logger can be customized by overwriting the default `slog.Logger`. +serverOpts := []ServerOption{ + WithLogger(logger), // optional, default is no logging +} -```go -var scimLogger slog.Logger -// initialize w/ own implementation -scim.SetLogger(scimLogger) +server := NewServer(serverArgs, serverOpts...) ``` ## String Values for Attributes diff --git a/examples_test.go b/examples_test.go index 85dd419..360764d 100644 --- a/examples_test.go +++ b/examples_test.go @@ -6,17 +6,25 @@ import ( ) func ExampleNewServer() { - server := Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) } logger.Fatal(http.ListenAndServe(":7643", server)) } func ExampleNewServer_basePath() { - server := Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) } // You can host the SCIM server on a custom path, make sure to strip the prefix, so only `/v2/` is left. http.Handle("/scim/", http.StripPrefix("/scim", server)) @@ -33,9 +41,13 @@ func ExampleNewServer_logger() { return http.HandlerFunc(fn) } - server := Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) } logger.Fatal(http.ListenAndServe(":7643", loggingMiddleware(server))) } diff --git a/filter_test.go b/filter_test.go index 6e758e3..3377adc 100644 --- a/filter_test.go +++ b/filter_test.go @@ -14,7 +14,7 @@ import ( ) func Test_Group_Filter(t *testing.T) { - s := newTestServerForFilter() + s := newTestServerForFilter(t) tests := []struct { name string @@ -72,7 +72,7 @@ func Test_Group_Filter(t *testing.T) { } func Test_User_Filter(t *testing.T) { - s := newTestServerForFilter() + s := newTestServerForFilter(t) tests := []struct { name string @@ -129,37 +129,46 @@ func Test_User_Filter(t *testing.T) { } } -func newTestServerForFilter() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{"userName": "testUser"}}, - "0002": {attributes: map[string]interface{}{"userName": "testUser+test"}}, +// newTestServerForFilter creates a new test server with a User and Group resource type +// or fails the test if an error occurs. +func newTestServerForFilter(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{"userName": "testUser"}}, + "0002": {attributes: map[string]interface{}{"userName": "testUser+test"}}, + }, + schema: schema.CoreUserSchema(), }, - schema: schema.CoreUserSchema(), }, - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{"displayName": "testGroup"}}, - "0002": {attributes: map[string]interface{}{"displayName": "testGroup+test"}}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{"displayName": "testGroup"}}, + "0002": {attributes: map[string]interface{}{"displayName": "testGroup+test"}}, + }, + schema: schema.CoreGroupSchema(), }, - schema: schema.CoreGroupSchema(), }, }, }, + ) + if err != nil { + t.Fatal(err) } + return s } diff --git a/handlers.go b/handlers.go index 530ad60..7f207ee 100644 --- a/handlers.go +++ b/handlers.go @@ -8,10 +8,10 @@ import ( "github.com/elimity-com/scim/schema" ) -func errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { +func (s Server) errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { raw, err := json.Marshal(scimErr) if err != nil { - log.Error( + s.log.Error( "failed marshaling scim error", "scimError", scimErr, "error", err, @@ -22,7 +22,7 @@ func errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { w.WriteHeader(scimErr.Status) _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -35,7 +35,7 @@ func (s Server) resourceDeleteHandler(w http.ResponseWriter, r *http.Request, id deleteErr := resourceType.Handler.Delete(r, id) if deleteErr != nil { scimErr := errors.CheckScimError(deleteErr, http.MethodDelete) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -48,14 +48,14 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st resource, getErr := resourceType.Handler.Get(r, id) if getErr != nil { scimErr := errors.CheckScimError(getErr, http.MethodGet) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -69,7 +69,7 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -81,14 +81,14 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id string, resourceType ResourceType) { patch, scimErr := resourceType.validatePatch(r) if scimErr != nil { - errorHandler(w, scimErr) + s.errorHandler(w, scimErr) return } resource, patchErr := resourceType.Handler.Patch(r, id, patch) if patchErr != nil { scimErr := errors.CheckScimError(patchErr, http.MethodPatch) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -99,8 +99,8 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -116,7 +116,7 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -130,21 +130,21 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso attributes, scimErr := resourceType.validate(data) if scimErr != nil { - errorHandler(w, scimErr) + s.errorHandler(w, scimErr) return } resource, postErr := resourceType.Handler.Create(r, attributes) if postErr != nil { scimErr := errors.CheckScimError(postErr, http.MethodPost) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -160,7 +160,7 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -174,21 +174,21 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st attributes, scimErr := resourceType.validate(data) if scimErr != nil { - errorHandler(w, scimErr) + s.errorHandler(w, scimErr) return } resource, putError := resourceType.Handler.Replace(r, id, attributes) if putError != nil { scimErr := errors.CheckScimError(putError, http.MethodPut) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource", "resource", resource, "error", err, @@ -202,7 +202,7 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -213,7 +213,7 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st // resource types name to the /ResourceTypes endpoint. For example: "/ResourceTypes/User". func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name string) { var resourceType ResourceType - for _, r := range s.ResourceTypes { + for _, r := range s.resourceTypes { if r.Name == name { resourceType = r break @@ -222,14 +222,14 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name if resourceType.Name != name { scimErr := errors.ScimErrorResourceNotFound(name) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resourceType.getRaw()) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling resource type", "resourceType", resourceType, "error", err, @@ -238,7 +238,7 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -251,26 +251,26 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { params, paramsErr := s.parseRequestParams(r, schema.ResourceTypeSchema()) if paramsErr != nil { - errorHandler(w, paramsErr) + s.errorHandler(w, paramsErr) return } - start, end := clamp(params.StartIndex-1, params.Count, len(s.ResourceTypes)) + start, end := clamp(params.StartIndex-1, params.Count, len(s.resourceTypes)) var resources []interface{} - for _, v := range s.ResourceTypes[start:end] { + for _, v := range s.resourceTypes[start:end] { resources = append(resources, v.getRaw()) } lr := listResponse{ - TotalResults: len(s.ResourceTypes), + TotalResults: len(s.resourceTypes), ItemsPerPage: params.Count, StartIndex: params.StartIndex, Resources: resources, } raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling list response", "listResponse", lr, "error", err, @@ -280,7 +280,7 @@ func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -292,14 +292,14 @@ func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, resourceType ResourceType) { params, paramsErr := s.parseRequestParams(r, resourceType.Schema, resourceType.getSchemaExtensions()...) if paramsErr != nil { - errorHandler(w, paramsErr) + s.errorHandler(w, paramsErr) return } page, getError := resourceType.Handler.GetAll(r, params) if getError != nil { scimErr := errors.CheckScimError(getError, http.MethodGet) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -311,8 +311,8 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso } raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling list response", "listResponse", lr, "error", err, @@ -322,7 +322,7 @@ func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, reso _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -335,14 +335,14 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) getSchema := s.getSchema(id) if getSchema.ID != id { scimErr := errors.ScimErrorResourceNotFound(id) - errorHandler(w, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(getSchema) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling schema", "schema", getSchema, "error", err, @@ -351,7 +351,7 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -363,7 +363,7 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { params, paramsErr := s.parseRequestParams(r, schema.Definition()) if paramsErr != nil { - errorHandler(w, paramsErr) + s.errorHandler(w, paramsErr) return } @@ -373,7 +373,7 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { ) if validator := params.FilterValidator; validator != nil { if err := validator.Validate(); err != nil { - errorHandler(w, &errors.ScimErrorInvalidFilter) + s.errorHandler(w, &errors.ScimErrorInvalidFilter) return } } @@ -395,8 +395,8 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { } raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling list response", "listResponse", lr, "error", err, @@ -406,7 +406,7 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) @@ -416,12 +416,12 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { // serviceProviderConfigHandler receives an HTTP GET to this endpoint will return a JSON structure that describes the // SCIM specification features available on a service provider. func (s Server) serviceProviderConfigHandler(w http.ResponseWriter, r *http.Request) { - raw, err := json.Marshal(s.Config.getRaw()) + raw, err := json.Marshal(s.config.getRaw()) if err != nil { - errorHandler(w, &errors.ScimErrorInternal) - log.Error( + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( "failed marshaling service provider config", - "serviceProviderConfig", s.Config, + "serviceProviderConfig", s.config, "error", err, ) return @@ -429,7 +429,7 @@ func (s Server) serviceProviderConfigHandler(w http.ResponseWriter, r *http.Requ _, err = w.Write(raw) if err != nil { - log.Error( + s.log.Error( "failed writing response", "error", err, ) diff --git a/handlers_test.go b/handlers_test.go index 4daca06..ca0fb33 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -84,7 +84,7 @@ func TestInvalidRequests(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(test.method, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, test.expectedStatus, rr.Code) }) @@ -94,7 +94,7 @@ func TestInvalidRequests(t *testing.T) { func TestServerMeEndpoint(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Me", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotImplemented, rr.Code) } @@ -102,7 +102,7 @@ func TestServerMeEndpoint(t *testing.T) { func TestServerResourceDeleteHandler(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/Users/0001", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -110,7 +110,7 @@ func TestServerResourceDeleteHandler(t *testing.T) { func TestServerResourceDeleteHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/Users/9999", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -156,7 +156,7 @@ func TestServerResourceGetHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -185,7 +185,7 @@ func TestServerResourceGetHandler(t *testing.T) { func TestServerResourceGetHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users/9999", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -210,7 +210,7 @@ func TestServerResourcePatchHandlerFailOnBadType(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) var resource map[string]interface{} assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) @@ -232,7 +232,7 @@ func TestServerResourcePatchHandlerInvalidPath(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) @@ -252,14 +252,14 @@ func TestServerResourcePatchHandlerInvalidRemoveOp(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) } func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { recorder := httptest.NewRecorder() - newTestServer().ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + newTestServer(t).ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], "Operations":[ { @@ -272,7 +272,7 @@ func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { assertEqualStatusCode(t, http.StatusOK, recorder.Code) recorder2 := httptest.NewRecorder() - newTestServer().ServeHTTP(recorder2, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + newTestServer(t).ServeHTTP(recorder2, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], "Operations":[ { @@ -319,7 +319,7 @@ func TestServerResourcePatchHandlerReturnsNoContent(t *testing.T) { } for _, req := range reqs { rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -358,7 +358,7 @@ func TestServerResourcePatchHandlerValid(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -408,7 +408,7 @@ func TestServerResourcePatchHandlerValidPathHasSubAttributes(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) } @@ -424,7 +424,7 @@ func TestServerResourcePatchHandlerValidRemoveOp(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -468,7 +468,7 @@ func TestServerResourcePostHandlerValid(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, test.target, test.body) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusCreated, rr.Code) @@ -498,7 +498,7 @@ func TestServerResourcePostHandlerValid(t *testing.T) { func TestServerResourcePutHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodPut, "/Users/9999", strings.NewReader(`{"userName": "other"}`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -549,7 +549,7 @@ func TestServerResourcePutHandlerValid(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, test.target, test.body) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -595,7 +595,7 @@ func TestServerResourceTypeHandlerValid(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/ResourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -625,7 +625,7 @@ func TestServerResourceTypesHandler(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -650,7 +650,7 @@ func TestServerResourceTypesHandler(t *testing.T) { func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=-1", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -663,7 +663,7 @@ func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) { func TestServerResourcesGetHandler(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -676,7 +676,7 @@ func TestServerResourcesGetHandler(t *testing.T) { func TestServerResourcesGetHandlerMaxCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=20000", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -688,7 +688,7 @@ func TestServerResourcesGetHandlerMaxCount(t *testing.T) { func TestServerResourcesGetHandlerPagination(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=2&startIndex=2", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -726,7 +726,7 @@ func TestServerSchemaEndpointValid(t *testing.T) { "%s/Schemas/%s", test.versionPrefix, test.schema, ), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -755,7 +755,7 @@ func TestServerSchemasEndpoint(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -791,7 +791,7 @@ func TestServerSchemasEndpointFilter(t *testing.T) { "/Schemas?%s", params.Encode(), ), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -819,7 +819,7 @@ func TestServerServiceProviderConfigHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) }) @@ -930,39 +930,45 @@ func newTestResourceHandler() ResourceHandler { } } -func newTestServer() Server { +func newTestServer(t *testing.T) Server { userSchema := getUserSchema() userSchemaExtension := getUserExtensionSchema() - return Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: []ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: userSchema, - Handler: newTestResourceHandler(), - }, - { - ID: optional.NewString("EnterpriseUser"), - Name: "EnterpriseUser", - Endpoint: "/EnterpriseUsers", - Description: optional.NewString("Enterprise User Account"), - Schema: userSchema, - SchemaExtensions: []SchemaExtension{ - {Schema: userSchemaExtension}, + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: userSchema, + Handler: newTestResourceHandler(), + }, + { + ID: optional.NewString("EnterpriseUser"), + Name: "EnterpriseUser", + Endpoint: "/EnterpriseUsers", + Description: optional.NewString("Enterprise User Account"), + Schema: userSchema, + SchemaExtensions: []SchemaExtension{ + {Schema: userSchemaExtension}, + }, + Handler: newTestResourceHandler(), + }, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: newTestResourceHandler(), }, - Handler: newTestResourceHandler(), - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: newTestResourceHandler(), }, }, + ) + if err != nil { + t.Fatal(err) } + return s } diff --git a/internal/idp_test/azuread_util_test.go b/internal/idp_test/azuread_util_test.go index 5ca108f..e12863b 100644 --- a/internal/idp_test/azuread_util_test.go +++ b/internal/idp_test/azuread_util_test.go @@ -2,6 +2,7 @@ package idp_test import ( "net/http" + "testing" "time" "github.com/elimity-com/scim" @@ -16,34 +17,39 @@ var azureCreatedTime = time.Date( 19, 59, 26, 0, time.UTC, ) -func newAzureADTestServer() scim.Server { - return scim.Server{ - Config: scim.ServiceProviderConfig{ - MaxResults: 20, - }, - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - SchemaExtensions: []scim.SchemaExtension{ - {Schema: schema.ExtensionEnterpriseUser()}, - }, - Handler: azureADUserResourceHandler{}, +func newAzureADTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{ + MaxResults: 20, }, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + SchemaExtensions: []scim.SchemaExtension{ + {Schema: schema.ExtensionEnterpriseUser()}, + }, + Handler: azureADUserResourceHandler{}, + }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: azureADGroupResourceHandler{}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: azureADGroupResourceHandler{}, + }, }, - }, + }) + if err != nil { + t.Fatal(err) } + return s } type azureADGroupResourceHandler struct{} diff --git a/internal/idp_test/idp_test.go b/internal/idp_test/idp_test.go index 881b2ea..f093ec1 100644 --- a/internal/idp_test/idp_test.go +++ b/internal/idp_test/idp_test.go @@ -27,7 +27,7 @@ func TestIdP(t *testing.T) { var test testCase _ = unmarshal(raw, &test) t.Run(strings.TrimSuffix(f.Name(), ".json"), func(t *testing.T) { - if err := testRequest(test, idp.Name()); err != nil { + if err := testRequest(t, test, idp.Name()); err != nil { t.Error(err) } }) @@ -36,23 +36,23 @@ func TestIdP(t *testing.T) { } } -func testRequest(t testCase, idpName string) error { +func testRequest(t *testing.T, tc testCase, idpName string) error { rr := httptest.NewRecorder() - br := bytes.NewReader(t.Request) - getNewServer(idpName).ServeHTTP( + br := bytes.NewReader(tc.Request) + getNewServer(t, idpName).ServeHTTP( rr, - httptest.NewRequest(t.Method, t.Path, br), + httptest.NewRequest(tc.Method, tc.Path, br), ) - if code := rr.Code; code != t.StatusCode { - return fmt.Errorf("expected %d, got %d", t.StatusCode, code) + if code := rr.Code; code != tc.StatusCode { + return fmt.Errorf("expected %d, got %d", tc.StatusCode, code) } - if len(t.Response) != 0 { + if len(tc.Response) != 0 { var response map[string]interface{} if err := unmarshal(rr.Body.Bytes(), &response); err != nil { return err } - if !reflect.DeepEqual(t.Response, response) { - return fmt.Errorf("expected, got:\n%v\n%v", t.Response, response) + if !reflect.DeepEqual(tc.Response, response) { + return fmt.Errorf("expected, got:\n%v\n%v", tc.Response, response) } } return nil diff --git a/internal/idp_test/okta_util_test.go b/internal/idp_test/okta_util_test.go index 5cbf0a1..066edc5 100644 --- a/internal/idp_test/okta_util_test.go +++ b/internal/idp_test/okta_util_test.go @@ -2,6 +2,7 @@ package idp_test import ( "net/http" + "testing" "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" @@ -9,28 +10,36 @@ import ( "github.com/elimity-com/scim/schema" ) -func newOktaTestServer() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: oktaUserResourceHandler{}, - }, +func newOktaTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: oktaUserResourceHandler{}, + }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: oktaGroupResourceHandler{}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: oktaGroupResourceHandler{}, + }, }, }, + ) + if err != nil { + t.Fatal(err) } + + return s } type oktaGroupResourceHandler struct{} diff --git a/internal/idp_test/util_test.go b/internal/idp_test/util_test.go index c1a112b..ae9b1fa 100644 --- a/internal/idp_test/util_test.go +++ b/internal/idp_test/util_test.go @@ -3,16 +3,17 @@ package idp_test import ( "bytes" "encoding/json" + "testing" "github.com/elimity-com/scim" ) -func getNewServer(idpName string) scim.Server { +func getNewServer(t *testing.T, idpName string) scim.Server { switch idpName { case "okta": - return newOktaTestServer() + return newOktaTestServer(t) case "azuread": - return newAzureADTestServer() + return newAzureADTestServer(t) default: panic("unreachable") } diff --git a/logger.go b/logger.go index cd619c0..e8db123 100644 --- a/logger.go +++ b/logger.go @@ -1,10 +1,10 @@ package scim -import "log/slog" +// Logger defines and interface for logging errors. +type Logger interface { + Error(args ...interface{}) +} -var log *slog.Logger = slog.Default().WithGroup("scim") +type noopLogger struct{} -// SetLogger sets the logger for the scim package. -func SetLogger(l *slog.Logger) { - log = l -} +func (noopLogger) Error(...interface{}) {} diff --git a/patch_add_test.go b/patch_add_test.go index 009ef90..c7cca2e 100644 --- a/patch_add_test.go +++ b/patch_add_test.go @@ -18,7 +18,7 @@ func TestPatch_addAttributes(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Users/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } @@ -65,7 +65,7 @@ func TestPatch_addMember(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Groups/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } @@ -121,7 +121,7 @@ func TestPatch_alreadyExists(t *testing.T) { changed: false, }, } { - server := newTestServer() + server := newTestServer(t) raw, err := os.ReadFile(test.jsonFilePath) if err != nil { t.Fatal(err) @@ -158,7 +158,7 @@ func TestPatch_complex(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Users/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } diff --git a/server.go b/server.go index d8eac99..e79a612 100644 --- a/server.go +++ b/server.go @@ -52,11 +52,54 @@ func parseIdentifier(path, endpoint string) (string, error) { return url.PathUnescape(strings.TrimPrefix(path, endpoint+"/")) } -// Server represents a SCIM server which implements the HTTP-based SCIM protocol that makes managing identities in multi- -// domain scenarios easier to support via a standardized service. +// Server represents a SCIM server which implements the HTTP-based SCIM protocol +// that makes managing identities in multi-domain scenarios easier to support via a standardized service. type Server struct { - Config ServiceProviderConfig - ResourceTypes []ResourceType + config ServiceProviderConfig + resourceTypes []ResourceType + log Logger +} + +type ServerArgs struct { + ServiceProviderConfig *ServiceProviderConfig + ResourceTypes []ResourceType +} + +type ServerOption func(*Server) + +// WithLogger sets the logger for the server. +func WithLogger(logger Logger) ServerOption { + return func(s *Server) { + if logger != nil { + s.log = logger + } + } +} + +func NewServer(args *ServerArgs, opts ...ServerOption) (Server, error) { + if args == nil { + return Server{}, fmt.Errorf("arguments not provided") + } + + if args.ServiceProviderConfig == nil { + return Server{}, fmt.Errorf("service provider config not provided") + } + + if args.ResourceTypes == nil { + return Server{}, fmt.Errorf("resource types not provided") + } + + s := &Server{ + config: *args.ServiceProviderConfig, + resourceTypes: args.ResourceTypes, + log: &noopLogger{}, + } + + for _, opt := range opts { + opt(s) + } + + return *s, nil } // ServeHTTP dispatches the request to the handler whose pattern most closely matches the request URL. @@ -67,7 +110,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case path == "/Me": - errorHandler(w, &errors.ScimError{ + s.errorHandler(w, &errors.ScimError{ Status: http.StatusNotImplemented, }) return @@ -88,7 +131,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if path == resourceType.Endpoint { switch r.Method { case http.MethodPost: @@ -123,7 +166,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - errorHandler(w, &errors.ScimError{ + s.errorHandler(w, &errors.ScimError{ Detail: "Specified endpoint does not exist.", Status: http.StatusNotFound, }) @@ -131,7 +174,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // getSchema extracts the schemas from the resources types defined in the server with given id. func (s Server) getSchema(id string) schema.Schema { - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if resourceType.Schema.ID == id { return resourceType.Schema } @@ -148,7 +191,7 @@ func (s Server) getSchema(id string) schema.Schema { func (s Server) getSchemas() []schema.Schema { ids := make([]string, 0) schemas := make([]schema.Schema, 0) - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if !contains(ids, resourceType.Schema.ID) { schemas = append(schemas, resourceType.Schema) } @@ -166,7 +209,7 @@ func (s Server) getSchemas() []schema.Schema { func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, refExtensions ...schema.Schema) (ListRequestParams, *errors.ScimError) { invalidParams := make([]string, 0) - defaultCount := s.Config.getItemsPerPage() + defaultCount := s.config.getItemsPerPage() count, countErr := getIntQueryParam(r, "count", defaultCount) if countErr != nil { invalidParams = append(invalidParams, "count") diff --git a/server_test.go b/server_test.go index 57d5c56..726fea5 100644 --- a/server_test.go +++ b/server_test.go @@ -2,11 +2,13 @@ package scim_test import ( "fmt" - internal "github.com/elimity-com/scim/filter" "io" "net/http" + "testing" "time" + internal "github.com/elimity-com/scim/filter" + "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" "github.com/elimity-com/scim/optional" @@ -40,37 +42,44 @@ func externalID(attributes scim.ResourceAttributes) optional.String { // - Whether a reference to another entity really exists. // e.g. if a member gets added, does this entity exist? -func newTestServer() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{}}, +func newTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{}}, + }, + schema: schema.CoreUserSchema(), }, - schema: schema.CoreUserSchema(), }, - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{}}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{}}, + }, + schema: schema.CoreGroupSchema(), }, - schema: schema.CoreGroupSchema(), }, }, }, + ) + if err != nil { + t.Fatal(err) } + return s } // testData represents a resource entity.