diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index b937935..3ef0f31 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -3,28 +3,28 @@ jobs: arrange: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: '1.16' - - run: go get github.com/jdeflander/goarrange + - run: go install github.com/jdeflander/goarrange@v1.0.0 working-directory: ${{ runner.temp }} - run: test -z "$(goarrange run -r -d)" lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: golangci/golangci-lint-action@v2 + - uses: actions/checkout@v4 + - uses: golangci/golangci-lint-action@v4 with: - version: v1.39 + version: 'v1.56.2' args: -E misspell,godot,whitespace test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: '1.16' - run: go test -v ./... @@ -32,8 +32,8 @@ jobs: tidy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 with: go-version: '1.16' - run: go mod tidy diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d4c1fa7 --- /dev/null +++ b/Makefile @@ -0,0 +1,21 @@ +.PHONY: all arrange tidy lint test + +all: arrange tidy lint test + +arrange: + @echo "Arranging files..." + @go fmt ./... + @goarrange run -r + +tidy: + @echo "Tidying up..." + @go mod tidy + +lint: + @echo "Linting files..." + @go vet ./... + @golangci-lint run ./... -E misspell,godot,whitespace + +test: + @echo "Running tests..." + @go test ./... -cover diff --git a/README.md b/README.md index 6f3b65a..6a2b5f0 100644 --- a/README.md +++ b/README.md @@ -3,45 +3,56 @@ [![GoVersion](https://img.shields.io/github/go-mod/go-version/elimity-com/scim.svg)](https://github.com/elimity-com/scim) [![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](https://pkg.go.dev/github.com/elimity-com/scim) - [![Tag](https://img.shields.io/github/tag/elimity-com/scim.svg)](https://gitHub.com/elimity-com/scim/releases) -This is an open source implementation of the [SCIM v2.0](http://www.simplecloud.info/#Specification) specification for use in Golang. +This is an open source implementation of the [SCIM v2.0](http://www.simplecloud.info/#Specification) specification for +use in Golang. SCIM defines a flexible schema mechanism and REST API for managing identity data. -The goal is to reduce the complexity of user management operations by providing patterns for exchanging schemas using HTTP. +The goal is to reduce the complexity of user management operations by providing patterns for exchanging schemas using +HTTP. In this implementation it is easy to add *custom* schemas and extensions with the provided structures. Incoming resources will be *validated* by their corresponding schemas before being passed on to their callbacks. The following features are supported: + - GET for `/Schemas`, `/ServiceProviderConfig` and `/ResourceTypes` - CRUD (POST/GET/PUT/DELETE and PATCH) for your own resource types (i.e. `/Users`, `/Groups`, `/Employees`, ...) Other optional features such as sorting, bulk, etc. are **not** supported in this version. ## Installation + Assuming you already have a (recent) version of Go installed, you can get the code with go get: + ```bash $ go get github.com/elimity-com/scim ``` ## Usage + **!** errors are ignored for simplicity. + ### 1. Create a service provider configuration. + [RFC Config](https://tools.ietf.org/html/rfc7643#section-5) | [Example Config](https://tools.ietf.org/html/rfc7643#section-8.5) + ```go config := scim.ServiceProviderConfig{ DocumentationURI: optional.NewString("www.example.com/scim"), } ``` + **!** no additional features/operations are supported in this version. ### 2. Create all supported schemas and extensions. + [RFC Schema](https://tools.ietf.org/html/rfc7643#section-2) | [User Schema](https://tools.ietf.org/html/rfc7643#section-4.1) | [Group Schema](https://tools.ietf.org/html/rfc7643#section-4.2) | [Extension Schema](https://tools.ietf.org/html/rfc7643#section-4.3) + ```go schema := schema.Schema{ ID: "urn:ietf:params:scim:schemas:core:2.0:User", @@ -72,18 +83,23 @@ extension := schema.Schema{ ``` ### 3. Create all resource types and their callbacks. + [RFC Resource Type](https://tools.ietf.org/html/rfc7643#section-6) | [Example Resource Type](https://tools.ietf.org/html/rfc7643#section-8.6) #### 3.1 Callback (implementation of `ResourceHandler`) + [Simple In Memory Example](resource_handler_test.go) + ```go var userResourceHandler scim.ResourceHandler // initialize w/ own implementation ``` + **!** each resource type should have its own resource handler. #### 3.2 Resource Type + ```go resourceTypes := []ResourceType{ { @@ -101,30 +117,63 @@ resourceTypes := []ResourceType{ ``` ### 4. Create Server + ```go -server := Server{ - Config: config, +serverArgs := &ServerArgs{ + ServiceProviderConfig: config, ResourceTypes: resourceTypes, } + +serverOpts := []ServerOption{ + WithLogger(logger), // optional, default is no logging +} + +server, err := NewServer(serverArgs, serverOpts...) +``` + +## Backwards Compatibility + +Even though the SCIM package has been running in some production environments, it is still in an early stage, and not +all features are supported. So be aware that a change in the minor version could break your implementation. We will not +make any breaking changes that takes hours to fix, but some functions might change name or signature. + +This was the case for `v0.1` to `v0.2.0`. + +## String Values for Attributes + +By default, the SCIM server will NOT use the `string` type for all attributes, since this is NOT compliant with the +SCIM specification. It is still possible to enable this behavior by toggling a flag within the `schema` package. + +```go +import "github.com/elimity-com/scim/schema" + +schema.SetAllowStringValues(true) ``` ## Addition Checks/Tests + Not everything can be checked by the SCIM server itself. Below are some things listed that we expect that the implementation covers. **!** this list is currently incomplete! -We want to keep this list as short as possible. +We want to keep this list as short as possible. If you have ideas how we could enforce these rules in the server itself do not hesitate to open [an issue](https://github.com/elimity-com/scim/issues/new) or a PR. + ### Mutability + #### Immutable Attributes + *PUT Handler*: If one or more values are already set for the attribute, the input value(s) MUST match. + #### WriteOnly Attributes + *ALL Handlers*: Attribute values SHALL NOT be returned. \ Note: These attributes usually also has a returned setting of "never". ## Contributing + [![Contributors](https://img.shields.io/github/contributors/elimity-com/scim.svg)](https://gitHub.com/elimity-com/scim/contributors/) We are happy to review pull requests, @@ -132,13 +181,15 @@ but please first discuss the change you wish to make via issue, email, or any other method with the owners of this repository before making a change. If you would like to propose a change please ensure the following: -- All checks of GitHub Actions are passing ([GolangCI-Lint](https://github.com/golangci/golangci-lint): `misspell`, `godot` and `whitespace`) + +- All checks of GitHub Actions are + passing ([GolangCI-Lint](https://github.com/golangci/golangci-lint): `misspell`, `godot` and `whitespace`) - All already existing tests are passing. - You have written tests that cover the code you are making, make sure to include edge cases. - There is documentation for at least all public functions you have added. - New public functions and structures are kept to a minimum. - The same practices are applied (such as the anatomy of methods, names, etc.) - Your changes are compliant with SCIM v2.0 (released as -[RFC7642](https://tools.ietf.org/html/rfc7642), -[RFC7643](https://tools.ietf.org/html/rfc7643) and -[RFC7644](https://tools.ietf.org/html/rfc7644) under [IETF](https://ietf.org/)). + [RFC7642](https://tools.ietf.org/html/rfc7642), + [RFC7643](https://tools.ietf.org/html/rfc7643) and + [RFC7644](https://tools.ietf.org/html/rfc7644) under [IETF](https://ietf.org/)). diff --git a/examples_test.go b/examples_test.go index 8dee9e8..360764d 100644 --- a/examples_test.go +++ b/examples_test.go @@ -1,21 +1,53 @@ package scim import ( - "log" + logger "log" "net/http" ) func ExampleNewServer() { - log.Fatal(http.ListenAndServe(":7643", Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, - })) + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) + } + logger.Fatal(http.ListenAndServe(":7643", server)) } func ExampleNewServer_basePath() { - http.Handle("/scim/", http.StripPrefix("/scim", Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: nil, - })) - log.Fatal(http.ListenAndServe(":7643", nil)) + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) + } + // You can host the SCIM server on a custom path, make sure to strip the prefix, so only `/v2/` is left. + http.Handle("/scim/", http.StripPrefix("/scim", server)) + logger.Fatal(http.ListenAndServe(":7643", nil)) +} + +func ExampleNewServer_logger() { + loggingMiddleware := func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + logger.Println(r.Method, r.URL.Path) + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } + args := &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{}, + } + server, err := NewServer(args) + if err != nil { + logger.Fatal(err) + } + logger.Fatal(http.ListenAndServe(":7643", loggingMiddleware(server))) } diff --git a/internal/filter/filter.go b/filter/filter.go similarity index 100% rename from internal/filter/filter.go rename to filter/filter.go diff --git a/internal/filter/filter_test.go b/filter/filter_test.go similarity index 93% rename from internal/filter/filter_test.go rename to filter/filter_test.go index c1c47ef..62293f6 100644 --- a/internal/filter/filter_test.go +++ b/filter/filter_test.go @@ -1,7 +1,7 @@ package filter_test import ( - internal "github.com/elimity-com/scim/internal/filter" + "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "testing" ) @@ -25,7 +25,7 @@ func TestPathValidator_Validate(t *testing.T) { `urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:employeeNumber`, `urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager.displayName`, } { - validator, err := internal.NewPathValidator(f, schema.CoreUserSchema(), schema.ExtensionEnterpriseUser()) + validator, err := filter.NewPathValidator(f, schema.CoreUserSchema(), schema.ExtensionEnterpriseUser()) if err != nil { t.Fatal(err) } @@ -47,7 +47,7 @@ func TestPathValidator_Validate(t *testing.T) { `urn:ietf:params:scim:schemas:core:2.0:User:employeeNumber`, `urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:userName`, } { - validator, err := internal.NewPathValidator(f, schema.CoreUserSchema(), schema.ExtensionEnterpriseUser()) + validator, err := filter.NewPathValidator(f, schema.CoreUserSchema(), schema.ExtensionEnterpriseUser()) if err != nil { t.Fatal(err) } @@ -92,7 +92,7 @@ func TestValidator_PassesFilter(t *testing.T) { }, }, } { - validator, err := internal.NewValidator(test.filter, schema.CoreUserSchema()) + validator, err := filter.NewValidator(test.filter, schema.CoreUserSchema()) if err != nil { t.Fatal(err) } @@ -137,7 +137,7 @@ func TestValidator_PassesFilter(t *testing.T) { userSchema := schema.CoreUserSchema() userSchema.Attributes = append(userSchema.Attributes, schema.SchemasAttributes()) userSchema.Attributes = append(userSchema.Attributes, schema.CommonAttributes()...) - validator, err := internal.NewValidator(test.filter, userSchema) + validator, err := filter.NewValidator(test.filter, userSchema) if err != nil { t.Fatal(err) } @@ -168,7 +168,7 @@ func TestValidator_PassesFilter(t *testing.T) { filter: `urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:organization eq "Elimity"`, }, } { - validator, err := internal.NewValidator(test.filter, schema.ExtensionEnterpriseUser()) + validator, err := filter.NewValidator(test.filter, schema.ExtensionEnterpriseUser()) if err != nil { t.Fatal(err) } @@ -212,7 +212,7 @@ func TestValidator_Validate(t *testing.T) { `userType eq "Employee" and emails[type eq "work" and value co "@example.com"]`, `emails[type eq "work" and value co "@example.com"] or ims[type eq "xmpp" and value co "@foo.com"]`, } { - validator, err := internal.NewValidator(f, userSchema) + validator, err := filter.NewValidator(f, userSchema) if err != nil { t.Fatal(err) } diff --git a/internal/filter/op_binary.go b/filter/op_binary.go similarity index 100% rename from internal/filter/op_binary.go rename to filter/op_binary.go diff --git a/internal/filter/op_boolean.go b/filter/op_boolean.go similarity index 100% rename from internal/filter/op_boolean.go rename to filter/op_boolean.go diff --git a/internal/filter/op_boolean_test.go b/filter/op_boolean_test.go similarity index 95% rename from internal/filter/op_boolean_test.go rename to filter/op_boolean_test.go index dfe1c5d..b46eda9 100644 --- a/internal/filter/op_boolean_test.go +++ b/filter/op_boolean_test.go @@ -2,7 +2,7 @@ package filter_test import ( "fmt" - internal "github.com/elimity-com/scim/internal/filter" + internal "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" "testing" diff --git a/internal/filter/op_datetime.go b/filter/op_datetime.go similarity index 100% rename from internal/filter/op_datetime.go rename to filter/op_datetime.go diff --git a/internal/filter/op_datetime_test.go b/filter/op_datetime_test.go similarity index 96% rename from internal/filter/op_datetime_test.go rename to filter/op_datetime_test.go index 1f85ba9..749937a 100644 --- a/internal/filter/op_datetime_test.go +++ b/filter/op_datetime_test.go @@ -2,7 +2,7 @@ package filter_test import ( "fmt" - internal "github.com/elimity-com/scim/internal/filter" + internal "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" "testing" diff --git a/internal/filter/op_decimal.go b/filter/op_decimal.go similarity index 100% rename from internal/filter/op_decimal.go rename to filter/op_decimal.go diff --git a/internal/filter/op_decimal_test.go b/filter/op_decimal_test.go similarity index 96% rename from internal/filter/op_decimal_test.go rename to filter/op_decimal_test.go index aea9d71..8eb023e 100644 --- a/internal/filter/op_decimal_test.go +++ b/filter/op_decimal_test.go @@ -2,7 +2,7 @@ package filter_test import ( "fmt" - internal "github.com/elimity-com/scim/internal/filter" + internal "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" "testing" diff --git a/internal/filter/op_integer.go b/filter/op_integer.go similarity index 100% rename from internal/filter/op_integer.go rename to filter/op_integer.go diff --git a/internal/filter/op_integer_test.go b/filter/op_integer_test.go similarity index 96% rename from internal/filter/op_integer_test.go rename to filter/op_integer_test.go index 386dabe..3a72386 100644 --- a/internal/filter/op_integer_test.go +++ b/filter/op_integer_test.go @@ -2,7 +2,7 @@ package filter_test import ( "fmt" - internal "github.com/elimity-com/scim/internal/filter" + internal "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" "testing" diff --git a/internal/filter/op_string.go b/filter/op_string.go similarity index 100% rename from internal/filter/op_string.go rename to filter/op_string.go diff --git a/internal/filter/op_string_test.go b/filter/op_string_test.go similarity index 97% rename from internal/filter/op_string_test.go rename to filter/op_string_test.go index a327185..e8b7d5e 100644 --- a/internal/filter/op_string_test.go +++ b/filter/op_string_test.go @@ -2,7 +2,7 @@ package filter_test import ( "fmt" - internal "github.com/elimity-com/scim/internal/filter" + internal "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" "testing" diff --git a/internal/filter/operators.go b/filter/operators.go similarity index 100% rename from internal/filter/operators.go rename to filter/operators.go diff --git a/internal/filter/operators_test.go b/filter/operators_test.go similarity index 97% rename from internal/filter/operators_test.go rename to filter/operators_test.go index cb3cd58..4c2b99f 100644 --- a/internal/filter/operators_test.go +++ b/filter/operators_test.go @@ -1,7 +1,7 @@ package filter_test import ( - internal "github.com/elimity-com/scim/internal/filter" + internal "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" "testing" ) diff --git a/internal/filter/path.go b/filter/path.go similarity index 100% rename from internal/filter/path.go rename to filter/path.go diff --git a/filter_test.go b/filter_test.go index 57709f2..3377adc 100644 --- a/filter_test.go +++ b/filter_test.go @@ -2,7 +2,7 @@ package scim_test import ( "encoding/json" - "io/ioutil" + "io" "net/http" "net/http/httptest" "net/url" @@ -14,7 +14,7 @@ import ( ) func Test_Group_Filter(t *testing.T) { - s := newTestServerForFilter() + s := newTestServerForFilter(t) tests := []struct { name string @@ -31,7 +31,7 @@ func Test_Group_Filter(t *testing.T) { w := httptest.NewRecorder() s.ServeHTTP(w, r) - bytes, err := ioutil.ReadAll(w.Result().Body) + bytes, err := io.ReadAll(w.Result().Body) if err != nil { t.Fatal(err) } @@ -72,7 +72,7 @@ func Test_Group_Filter(t *testing.T) { } func Test_User_Filter(t *testing.T) { - s := newTestServerForFilter() + s := newTestServerForFilter(t) tests := []struct { name string @@ -89,7 +89,7 @@ func Test_User_Filter(t *testing.T) { w := httptest.NewRecorder() s.ServeHTTP(w, r) - bytes, err := ioutil.ReadAll(w.Result().Body) + bytes, err := io.ReadAll(w.Result().Body) if err != nil { t.Fatal(err) } @@ -129,37 +129,46 @@ func Test_User_Filter(t *testing.T) { } } -func newTestServerForFilter() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{"userName": "testUser"}}, - "0002": {attributes: map[string]interface{}{"userName": "testUser+test"}}, +// newTestServerForFilter creates a new test server with a User and Group resource type +// or fails the test if an error occurs. +func newTestServerForFilter(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{"userName": "testUser"}}, + "0002": {attributes: map[string]interface{}{"userName": "testUser+test"}}, + }, + schema: schema.CoreUserSchema(), }, - schema: schema.CoreUserSchema(), }, - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{"displayName": "testGroup"}}, - "0002": {attributes: map[string]interface{}{"displayName": "testGroup+test"}}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{"displayName": "testGroup"}}, + "0002": {attributes: map[string]interface{}{"displayName": "testGroup+test"}}, + }, + schema: schema.CoreGroupSchema(), }, - schema: schema.CoreGroupSchema(), }, }, }, + ) + if err != nil { + t.Fatal(err) } + return s } diff --git a/handlers.go b/handlers.go index 548ed55..7f207ee 100644 --- a/handlers.go +++ b/handlers.go @@ -2,24 +2,30 @@ package scim import ( "encoding/json" - "log" "net/http" "github.com/elimity-com/scim/errors" - f "github.com/elimity-com/scim/internal/filter" "github.com/elimity-com/scim/schema" ) -func errorHandler(w http.ResponseWriter, _ *http.Request, scimErr *errors.ScimError) { +func (s Server) errorHandler(w http.ResponseWriter, scimErr *errors.ScimError) { raw, err := json.Marshal(scimErr) if err != nil { - log.Fatalf("failed marshaling scim error: %v", err) + s.log.Error( + "failed marshaling scim error", + "scimError", scimErr, + "error", err, + ) + return } w.WriteHeader(scimErr.Status) _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -29,7 +35,7 @@ func (s Server) resourceDeleteHandler(w http.ResponseWriter, r *http.Request, id deleteErr := resourceType.Handler.Delete(r, id) if deleteErr != nil { scimErr := errors.CheckScimError(deleteErr, http.MethodDelete) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -42,14 +48,18 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st resource, getErr := resourceType.Handler.Get(r, id) if getErr != nil { scimErr := errors.CheckScimError(getErr, http.MethodGet) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling resource: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling resource", + "resource", resource, + "error", err, + ) return } @@ -59,7 +69,10 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -68,14 +81,14 @@ func (s Server) resourceGetHandler(w http.ResponseWriter, r *http.Request, id st func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id string, resourceType ResourceType) { patch, scimErr := resourceType.validatePatch(r) if scimErr != nil { - errorHandler(w, r, scimErr) + s.errorHandler(w, scimErr) return } resource, patchErr := resourceType.Handler.Patch(r, id, patch) if patchErr != nil { scimErr := errors.CheckScimError(patchErr, http.MethodPatch) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } @@ -86,8 +99,12 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling resource: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling resource", + "resource", resource, + "error", err, + ) return } @@ -99,7 +116,10 @@ func (s Server) resourcePatchHandler(w http.ResponseWriter, r *http.Request, id _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -110,21 +130,25 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso attributes, scimErr := resourceType.validate(data) if scimErr != nil { - errorHandler(w, r, scimErr) + s.errorHandler(w, scimErr) return } resource, postErr := resourceType.Handler.Create(r, attributes) if postErr != nil { scimErr := errors.CheckScimError(postErr, http.MethodPost) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling resource: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling resource", + "resource", resource, + "error", err, + ) return } @@ -136,7 +160,10 @@ func (s Server) resourcePostHandler(w http.ResponseWriter, r *http.Request, reso _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -147,21 +174,25 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st attributes, scimErr := resourceType.validate(data) if scimErr != nil { - errorHandler(w, r, scimErr) + s.errorHandler(w, scimErr) return } resource, putError := resourceType.Handler.Replace(r, id, attributes) if putError != nil { scimErr := errors.CheckScimError(putError, http.MethodPut) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resource.response(resourceType)) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling resource: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling resource", + "resource", resource, + "error", err, + ) return } @@ -171,7 +202,10 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -179,7 +213,7 @@ func (s Server) resourcePutHandler(w http.ResponseWriter, r *http.Request, id st // resource types name to the /ResourceTypes endpoint. For example: "/ResourceTypes/User". func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name string) { var resourceType ResourceType - for _, r := range s.ResourceTypes { + for _, r := range s.resourceTypes { if r.Name == name { resourceType = r break @@ -188,20 +222,26 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name if resourceType.Name != name { scimErr := errors.ScimErrorResourceNotFound(name) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(resourceType.getRaw()) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling resource type: %v", err) - return + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling resource type", + "resourceType", resourceType, + "error", err, + ) } _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -211,31 +251,39 @@ func (s Server) resourceTypeHandler(w http.ResponseWriter, r *http.Request, name func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { params, paramsErr := s.parseRequestParams(r, schema.ResourceTypeSchema()) if paramsErr != nil { - errorHandler(w, r, paramsErr) + s.errorHandler(w, paramsErr) return } - start, end := clamp(params.StartIndex-1, params.Count, len(s.ResourceTypes)) + start, end := clamp(params.StartIndex-1, params.Count, len(s.resourceTypes)) var resources []interface{} - for _, v := range s.ResourceTypes[start:end] { + for _, v := range s.resourceTypes[start:end] { resources = append(resources, v.getRaw()) } - raw, err := json.Marshal(listResponse{ - TotalResults: len(s.ResourceTypes), + lr := listResponse{ + TotalResults: len(s.resourceTypes), ItemsPerPage: params.Count, StartIndex: params.StartIndex, Resources: resources, - }) + } + raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling list response: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling list response", + "listResponse", lr, + "error", err, + ) return } _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -244,32 +292,40 @@ func (s Server) resourceTypesHandler(w http.ResponseWriter, r *http.Request) { func (s Server) resourcesGetHandler(w http.ResponseWriter, r *http.Request, resourceType ResourceType) { params, paramsErr := s.parseRequestParams(r, resourceType.Schema, resourceType.getSchemaExtensions()...) if paramsErr != nil { - errorHandler(w, r, paramsErr) + s.errorHandler(w, paramsErr) return } page, getError := resourceType.Handler.GetAll(r, params) if getError != nil { scimErr := errors.CheckScimError(getError, http.MethodGet) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } - raw, err := json.Marshal(listResponse{ + lr := listResponse{ TotalResults: page.TotalResults, Resources: page.resources(resourceType), StartIndex: params.StartIndex, ItemsPerPage: params.Count, - }) + } + raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshalling list response: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling list response", + "listResponse", lr, + "error", err, + ) return } _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -279,20 +335,26 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) getSchema := s.getSchema(id) if getSchema.ID != id { scimErr := errors.ScimErrorResourceNotFound(id) - errorHandler(w, r, &scimErr) + s.errorHandler(w, &scimErr) return } raw, err := json.Marshal(getSchema) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling schema: %v", err) - return + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling schema", + "schema", getSchema, + "error", err, + ) } _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } @@ -301,24 +363,23 @@ func (s Server) schemaHandler(w http.ResponseWriter, r *http.Request, id string) func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { params, paramsErr := s.parseRequestParams(r, schema.Definition()) if paramsErr != nil { - errorHandler(w, r, paramsErr) + s.errorHandler(w, paramsErr) return } var ( - validator = f.NewFilterValidator(params.Filter, schema.Definition()) start, end = clamp(params.StartIndex-1, params.Count, len(s.getSchemas())) resources []interface{} ) - if params.Filter != nil { + if validator := params.FilterValidator; validator != nil { if err := validator.Validate(); err != nil { - errorHandler(w, r, &errors.ScimErrorInvalidFilter) + s.errorHandler(w, &errors.ScimErrorInvalidFilter) return } } for _, v := range s.getSchemas()[start:end] { resource := v.ToMap() - if params.Filter != nil { + if validator := params.FilterValidator; validator != nil { if err := validator.PassesFilter(resource); err != nil { continue } @@ -326,36 +387,51 @@ func (s Server) schemasHandler(w http.ResponseWriter, r *http.Request) { resources = append(resources, resource) } - raw, err := json.Marshal(listResponse{ + lr := listResponse{ TotalResults: len(s.getSchemas()), ItemsPerPage: params.Count, StartIndex: params.StartIndex, Resources: resources, - }) + } + raw, err := json.Marshal(lr) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling list response: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling list response", + "listResponse", lr, + "error", err, + ) return } _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } // serviceProviderConfigHandler receives an HTTP GET to this endpoint will return a JSON structure that describes the // SCIM specification features available on a service provider. func (s Server) serviceProviderConfigHandler(w http.ResponseWriter, r *http.Request) { - raw, err := json.Marshal(s.Config.getRaw()) + raw, err := json.Marshal(s.config.getRaw()) if err != nil { - errorHandler(w, r, &errors.ScimErrorInternal) - log.Fatalf("failed marshaling service provider config: %v", err) + s.errorHandler(w, &errors.ScimErrorInternal) + s.log.Error( + "failed marshaling service provider config", + "serviceProviderConfig", s.config, + "error", err, + ) return } _, err = w.Write(raw) if err != nil { - log.Printf("failed writing response: %v", err) + s.log.Error( + "failed writing response", + "error", err, + ) } } diff --git a/handlers_test.go b/handlers_test.go index 0dd83ee..9d7e6c2 100644 --- a/handlers_test.go +++ b/handlers_test.go @@ -84,7 +84,7 @@ func TestInvalidRequests(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(test.method, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, test.expectedStatus, rr.Code) }) @@ -94,7 +94,7 @@ func TestInvalidRequests(t *testing.T) { func TestServerMeEndpoint(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Me", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotImplemented, rr.Code) } @@ -102,7 +102,7 @@ func TestServerMeEndpoint(t *testing.T) { func TestServerResourceDeleteHandler(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/Users/0001", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -110,7 +110,7 @@ func TestServerResourceDeleteHandler(t *testing.T) { func TestServerResourceDeleteHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodDelete, "/Users/9999", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -156,7 +156,7 @@ func TestServerResourceGetHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -185,7 +185,7 @@ func TestServerResourceGetHandler(t *testing.T) { func TestServerResourceGetHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users/9999", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -210,7 +210,7 @@ func TestServerResourcePatchHandlerFailOnBadType(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) var resource map[string]interface{} assertUnmarshalNoError(t, json.Unmarshal(rr.Body.Bytes(), &resource)) @@ -232,7 +232,7 @@ func TestServerResourcePatchHandlerInvalidPath(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) @@ -252,14 +252,14 @@ func TestServerResourcePatchHandlerInvalidRemoveOp(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) } func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { recorder := httptest.NewRecorder() - newTestServer().ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + newTestServer(t).ServeHTTP(recorder, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], "Operations":[ { @@ -272,7 +272,7 @@ func TestServerResourcePatchHandlerMapTypeSubAttribute(t *testing.T) { assertEqualStatusCode(t, http.StatusOK, recorder.Code) recorder2 := httptest.NewRecorder() - newTestServer().ServeHTTP(recorder2, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ + newTestServer(t).ServeHTTP(recorder2, httptest.NewRequest(http.MethodPatch, "/Users/0001", strings.NewReader(`{ "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"], "Operations":[ { @@ -319,7 +319,7 @@ func TestServerResourcePatchHandlerReturnsNoContent(t *testing.T) { } for _, req := range reqs { rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -358,7 +358,7 @@ func TestServerResourcePatchHandlerValid(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -408,7 +408,7 @@ func TestServerResourcePatchHandlerValidPathHasSubAttributes(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) } @@ -424,7 +424,7 @@ func TestServerResourcePatchHandlerValidRemoveOp(t *testing.T) { ] }`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNoContent, rr.Code) } @@ -468,7 +468,7 @@ func TestServerResourcePostHandlerValid(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, test.target, test.body) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusCreated, rr.Code) @@ -498,7 +498,7 @@ func TestServerResourcePostHandlerValid(t *testing.T) { func TestServerResourcePutHandlerNotFound(t *testing.T) { req := httptest.NewRequest(http.MethodPut, "/Users/9999", strings.NewReader(`{"userName": "other"}`)) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusNotFound, rr.Code) @@ -549,7 +549,7 @@ func TestServerResourcePutHandlerValid(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPut, test.target, test.body) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -595,7 +595,7 @@ func TestServerResourceTypeHandlerValid(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/ResourceTypes/%s", tt.versionPrefix, tt.resourceType), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -625,7 +625,7 @@ func TestServerResourceTypesHandler(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -650,7 +650,7 @@ func TestServerResourceTypesHandler(t *testing.T) { func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=-1", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -663,7 +663,7 @@ func TestServerResourcesGetAllHandlerNegativeCount(t *testing.T) { func TestServerResourcesGetAllHandlerNonIntCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=BadBanana", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) @@ -676,7 +676,7 @@ func TestServerResourcesGetAllHandlerNonIntCount(t *testing.T) { func TestServerResourcesGetAllHandlerNonIntStartIndex(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?startIndex=BadBanana", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusBadRequest, rr.Code) @@ -689,7 +689,7 @@ func TestServerResourcesGetAllHandlerNonIntStartIndex(t *testing.T) { func TestServerResourcesGetHandler(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -702,7 +702,7 @@ func TestServerResourcesGetHandler(t *testing.T) { func TestServerResourcesGetHandlerMaxCount(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=20000", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -714,7 +714,7 @@ func TestServerResourcesGetHandlerMaxCount(t *testing.T) { func TestServerResourcesGetHandlerPagination(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/Users?count=2&startIndex=2", nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -752,7 +752,7 @@ func TestServerSchemaEndpointValid(t *testing.T) { "%s/Schemas/%s", test.versionPrefix, test.schema, ), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -781,7 +781,7 @@ func TestServerSchemasEndpoint(t *testing.T) { t.Run(test.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, test.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -817,7 +817,7 @@ func TestServerSchemasEndpointFilter(t *testing.T) { "/Schemas?%s", params.Encode(), ), nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) @@ -845,7 +845,7 @@ func TestServerServiceProviderConfigHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tt.target, nil) rr := httptest.NewRecorder() - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) assertEqualStatusCode(t, http.StatusOK, rr.Code) }) @@ -956,39 +956,45 @@ func newTestResourceHandler() ResourceHandler { } } -func newTestServer() Server { +func newTestServer(t *testing.T) Server { userSchema := getUserSchema() userSchemaExtension := getUserExtensionSchema() - return Server{ - Config: ServiceProviderConfig{}, - ResourceTypes: []ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: userSchema, - Handler: newTestResourceHandler(), - }, - { - ID: optional.NewString("EnterpriseUser"), - Name: "EnterpriseUser", - Endpoint: "/EnterpriseUsers", - Description: optional.NewString("Enterprise User Account"), - Schema: userSchema, - SchemaExtensions: []SchemaExtension{ - {Schema: userSchemaExtension}, + s, err := NewServer( + &ServerArgs{ + ServiceProviderConfig: &ServiceProviderConfig{}, + ResourceTypes: []ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: userSchema, + Handler: newTestResourceHandler(), + }, + { + ID: optional.NewString("EnterpriseUser"), + Name: "EnterpriseUser", + Endpoint: "/EnterpriseUsers", + Description: optional.NewString("Enterprise User Account"), + Schema: userSchema, + SchemaExtensions: []SchemaExtension{ + {Schema: userSchemaExtension}, + }, + Handler: newTestResourceHandler(), + }, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: newTestResourceHandler(), }, - Handler: newTestResourceHandler(), - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: newTestResourceHandler(), }, }, + ) + if err != nil { + t.Fatal(err) } + return s } diff --git a/internal/idp_test/azuread_util_test.go b/internal/idp_test/azuread_util_test.go index de122ce..e12863b 100644 --- a/internal/idp_test/azuread_util_test.go +++ b/internal/idp_test/azuread_util_test.go @@ -2,6 +2,7 @@ package idp_test import ( "net/http" + "testing" "time" "github.com/elimity-com/scim" @@ -16,34 +17,39 @@ var azureCreatedTime = time.Date( 19, 59, 26, 0, time.UTC, ) -func newAzureADTestServer() scim.Server { - return scim.Server{ - Config: scim.ServiceProviderConfig{ - MaxResults: 20, - }, - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - SchemaExtensions: []scim.SchemaExtension{ - {Schema: schema.ExtensionEnterpriseUser()}, - }, - Handler: azureADUserResourceHandler{}, +func newAzureADTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{ + MaxResults: 20, }, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + SchemaExtensions: []scim.SchemaExtension{ + {Schema: schema.ExtensionEnterpriseUser()}, + }, + Handler: azureADUserResourceHandler{}, + }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: azureADGroupResourceHandler{}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: azureADGroupResourceHandler{}, + }, }, - }, + }) + if err != nil { + t.Fatal(err) } + return s } type azureADGroupResourceHandler struct{} @@ -171,7 +177,7 @@ func (a azureADUserResourceHandler) Get(r *http.Request, id string) (scim.Resour } func (a azureADUserResourceHandler) GetAll(r *http.Request, params scim.ListRequestParams) (scim.Page, error) { - f := params.Filter.(*filter.AttributeExpression) + f := (params.FilterValidator.GetFilter()).(*filter.AttributeExpression) if f.CompareValue.(string) == "non-existent user" { return scim.Page{}, nil } diff --git a/internal/idp_test/idp_test.go b/internal/idp_test/idp_test.go index 881b2ea..f093ec1 100644 --- a/internal/idp_test/idp_test.go +++ b/internal/idp_test/idp_test.go @@ -27,7 +27,7 @@ func TestIdP(t *testing.T) { var test testCase _ = unmarshal(raw, &test) t.Run(strings.TrimSuffix(f.Name(), ".json"), func(t *testing.T) { - if err := testRequest(test, idp.Name()); err != nil { + if err := testRequest(t, test, idp.Name()); err != nil { t.Error(err) } }) @@ -36,23 +36,23 @@ func TestIdP(t *testing.T) { } } -func testRequest(t testCase, idpName string) error { +func testRequest(t *testing.T, tc testCase, idpName string) error { rr := httptest.NewRecorder() - br := bytes.NewReader(t.Request) - getNewServer(idpName).ServeHTTP( + br := bytes.NewReader(tc.Request) + getNewServer(t, idpName).ServeHTTP( rr, - httptest.NewRequest(t.Method, t.Path, br), + httptest.NewRequest(tc.Method, tc.Path, br), ) - if code := rr.Code; code != t.StatusCode { - return fmt.Errorf("expected %d, got %d", t.StatusCode, code) + if code := rr.Code; code != tc.StatusCode { + return fmt.Errorf("expected %d, got %d", tc.StatusCode, code) } - if len(t.Response) != 0 { + if len(tc.Response) != 0 { var response map[string]interface{} if err := unmarshal(rr.Body.Bytes(), &response); err != nil { return err } - if !reflect.DeepEqual(t.Response, response) { - return fmt.Errorf("expected, got:\n%v\n%v", t.Response, response) + if !reflect.DeepEqual(tc.Response, response) { + return fmt.Errorf("expected, got:\n%v\n%v", tc.Response, response) } } return nil diff --git a/internal/idp_test/okta_util_test.go b/internal/idp_test/okta_util_test.go index 5cbf0a1..066edc5 100644 --- a/internal/idp_test/okta_util_test.go +++ b/internal/idp_test/okta_util_test.go @@ -2,6 +2,7 @@ package idp_test import ( "net/http" + "testing" "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" @@ -9,28 +10,36 @@ import ( "github.com/elimity-com/scim/schema" ) -func newOktaTestServer() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: oktaUserResourceHandler{}, - }, +func newOktaTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: oktaUserResourceHandler{}, + }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: oktaGroupResourceHandler{}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: oktaGroupResourceHandler{}, + }, }, }, + ) + if err != nil { + t.Fatal(err) } + + return s } type oktaGroupResourceHandler struct{} diff --git a/internal/idp_test/util_test.go b/internal/idp_test/util_test.go index c1a112b..ae9b1fa 100644 --- a/internal/idp_test/util_test.go +++ b/internal/idp_test/util_test.go @@ -3,16 +3,17 @@ package idp_test import ( "bytes" "encoding/json" + "testing" "github.com/elimity-com/scim" ) -func getNewServer(idpName string) scim.Server { +func getNewServer(t *testing.T, idpName string) scim.Server { switch idpName { case "okta": - return newOktaTestServer() + return newOktaTestServer(t) case "azuread": - return newAzureADTestServer() + return newAzureADTestServer(t) default: panic("unreachable") } diff --git a/internal/patch/add_test.go b/internal/patch/add_test.go index 56dcd6e..fddd1b8 100644 --- a/internal/patch/add_test.go +++ b/internal/patch/add_test.go @@ -1,21 +1,22 @@ package patch import ( + "encoding/json" "fmt" "github.com/elimity-com/scim/schema" ) // The following example shows how to add a member to a group. func Example_addMemberToGroup() { - operation := `{ - "op": "add", - "path": "members", - "value": { - "display": "di-wu", - "$ref": "https://example.com/v2/Users/0001", - "value": "0001" - } -}` + operation, _ := json.Marshal(map[string]interface{}{ + "op": "add", + "path": "members", + "value": map[string]interface{}{ + "display": "di-wu", + "$ref": "https://example.com/v2/Users/0001", + "value": "0001", + }, + }) validator, _ := NewValidator(operation, schema.CoreGroupSchema()) fmt.Println(validator.Validate()) // Output: @@ -24,18 +25,18 @@ func Example_addMemberToGroup() { // The following example shows how to add one or more attributes to a User resource without using a "path" attribute. func Example_addWithoutPath() { - operation := `{ - "op": "add", - "value": { - "emails": [ - { - "value": "quint@elimity.com", - "type": "work" - } - ], - "nickname": "di-wu" - } -}` + operation, _ := json.Marshal(map[string]interface{}{ + "op": "add", + "value": map[string]interface{}{ + "emails": []map[string]interface{}{ + { + "value": "quint@elimity.com", + "type": "work", + }, + }, + "nickname": "di-wu", + }, + }) validator, _ := NewValidator(operation, schema.CoreUserSchema()) fmt.Println(validator.Validate()) // Output: diff --git a/internal/patch/patch.go b/internal/patch/patch.go index e540784..70beef0 100644 --- a/internal/patch/patch.go +++ b/internal/patch/patch.go @@ -1,11 +1,12 @@ package patch import ( + "bytes" "encoding/json" "fmt" + f "github.com/elimity-com/scim/filter" "strings" - f "github.com/elimity-com/scim/internal/filter" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" ) @@ -35,13 +36,16 @@ type OperationValidator struct { // NewValidator creates an OperationValidator based on the given JSON string and reference schemas. // Returns an error if patchReq is not valid. -func NewValidator(patchReq string, s schema.Schema, extensions ...schema.Schema) (OperationValidator, error) { +func NewValidator(patchReq []byte, s schema.Schema, extensions ...schema.Schema) (OperationValidator, error) { var operation struct { Op string Path string Value interface{} } - if err := json.Unmarshal([]byte(patchReq), &operation); err != nil { + + d := json.NewDecoder(bytes.NewReader(patchReq)) + d.UseNumber() + if err := d.Decode(&operation); err != nil { return OperationValidator{}, err } diff --git a/internal/patch/patch_test.go b/internal/patch/patch_test.go index 991f20b..6f0de67 100644 --- a/internal/patch/patch_test.go +++ b/internal/patch/patch_test.go @@ -1,23 +1,103 @@ package patch import ( - "fmt" + "encoding/json" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" "testing" ) func TestNewPathValidator(t *testing.T) { - t.Run("Invalid JSON", func(t *testing.T) { - // The quotes in the value filter are not escaped. - op := `{"op":"add","path":"complexMultiValued[attr1 eq "value"].attr1","value":"value"}` - if _, err := NewValidator(op, patchSchema); err == nil { - t.Error("expected JSON error, got none") + t.Run("Valid Integer", func(t *testing.T) { + for _, op := range []map[string]interface{}{ + {"op": "add", "path": "attr2", "value": 1234}, + {"op": "add", "path": "attr2", "value": "1234"}, + } { + operation, _ := json.Marshal(op) + validator, err := NewValidator(operation, patchSchema) + if err != nil { + t.Fatalf("unexpected error, got %v", err) + } + schema.SetAllowStringValues(true) + defer schema.SetAllowStringValues(false) + v, err := validator.Validate() + if err != nil { + t.Fatalf("unexpected error, got %v", err) + } + n, ok := v.(int64) + if !ok { + t.Fatalf("unexpected type, got %T", v) + } + if n != 1234 { + t.Fatalf("unexpected integer, got %d", n) + } + } + }) + + t.Run("Valid Float", func(t *testing.T) { + for _, op := range []map[string]interface{}{ + {"op": "add", "path": "attr3", "value": 12.34}, + {"op": "add", "path": "attr3", "value": "12.34"}, + } { + operation, _ := json.Marshal(op) + validator, err := NewValidator(operation, patchSchema) + if err != nil { + t.Fatalf("unexpected error, got %v", err) + } + schema.SetAllowStringValues(true) + defer schema.SetAllowStringValues(false) + v, err := validator.Validate() + if err != nil { + t.Fatalf("unexpected error, got %v", err) + } + n, ok := v.(float64) + if !ok { + t.Fatalf("unexpected type, got %T", v) + } + if n != 12.34 { + t.Fatalf("unexpected integer, got %f", n) + } + } + }) + + t.Run("Valid Booleans", func(t *testing.T) { + tests := []struct { + op map[string]interface{} + expected bool + }{ + {map[string]interface{}{"op": "add", "path": "attr4", "value": true}, true}, + {map[string]interface{}{"op": "add", "path": "attr4", "value": "True"}, true}, + {map[string]interface{}{"op": "add", "path": "attr4", "value": false}, false}, + {map[string]interface{}{"op": "add", "path": "attr4", "value": "False"}, false}, + } + for _, tc := range tests { + operation, _ := json.Marshal(tc.op) + validator, err := NewValidator(operation, patchSchema) + if err != nil { + t.Fatalf("unexpected error, got %v", err) + } + schema.SetAllowStringValues(true) + defer schema.SetAllowStringValues(false) + v, err := validator.Validate() + if err != nil { + t.Fatalf("unexpected error, got %v", err) + } + b, ok := v.(bool) + if !ok { + t.Fatalf("unexpected type, got %T", v) + } + if b != tc.expected { + t.Fatalf("unexpected integer, got %v", b) + } } }) t.Run("Invalid Op", func(t *testing.T) { // "op" must be one of "add", "remove", or "replace". - op := `{"op":"invalid","path":"attr1","value":"value"}` + op, _ := json.Marshal(map[string]interface{}{ + "op": "invalid", + "path": "attr1", + "value": "value", + }) validator, _ := NewValidator(op, patchSchema) if _, err := validator.Validate(); err == nil { t.Errorf("expected error, got none") @@ -26,7 +106,11 @@ func TestNewPathValidator(t *testing.T) { t.Run("Invalid Attribute", func(t *testing.T) { // "invalid pr" is not a valid path filter. // This error will be caught by the path filter validator. - op := `{"op":"add","path":"invalid pr","value":"value"}` + op, _ := json.Marshal(map[string]interface{}{ + "op": "add", + "path": "invalid pr", + "value": "value", + }) if _, err := NewValidator(op, patchSchema); err == nil { t.Error("expected JSON error, got none") } @@ -42,9 +126,15 @@ func TestOperationValidator_getRefAttribute(t *testing.T) { {`name.givenName`, `givenName`}, {`urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:employeeNumber`, `employeeNumber`}, } { + op, _ := json.Marshal(map[string]interface{}{ + "op": "add", + "path": test.pathFilter, + "value": "value", + }) validator, err := NewValidator( - fmt.Sprintf(`{"op":"invalid","path":%q,"value":"value"}`, test.pathFilter), - schema.CoreUserSchema(), schema.ExtensionEnterpriseUser(), + op, + schema.CoreUserSchema(), + schema.ExtensionEnterpriseUser(), ) if err != nil { t.Fatal(err) @@ -58,9 +148,15 @@ func TestOperationValidator_getRefAttribute(t *testing.T) { } } + op, _ := json.Marshal(map[string]interface{}{ + "op": "invalid", + "path": "complex", + "value": "value", + }) validator, _ := NewValidator( - `{"op":"invalid","path":"complex","value":"value"}`, - schema.CoreUserSchema(), schema.ExtensionEnterpriseUser(), + op, + schema.CoreUserSchema(), + schema.ExtensionEnterpriseUser(), ) if _, err := validator.getRefAttribute(filter.AttributePath{ AttributeName: "invalid", @@ -77,9 +173,15 @@ func TestOperationValidator_getRefSubAttribute(t *testing.T) { {`name`, `givenName`}, {`groups`, `display`}, } { + op, _ := json.Marshal(map[string]interface{}{ + "op": "invalid", + "path": test.attributeName, + "value": "value", + }) validator, err := NewValidator( - fmt.Sprintf(`{"op":"invalid","path":%q,"value":"value"}`, test.attributeName), - schema.CoreUserSchema(), schema.ExtensionEnterpriseUser(), + op, + schema.CoreUserSchema(), + schema.ExtensionEnterpriseUser(), ) if err != nil { t.Fatal(err) diff --git a/internal/patch/remove.go b/internal/patch/remove.go index 576ac4f..ebd88fe 100644 --- a/internal/patch/remove.go +++ b/internal/patch/remove.go @@ -1,10 +1,10 @@ package patch import ( + "github.com/elimity-com/scim/filter" "net/http" "github.com/elimity-com/scim/errors" - f "github.com/elimity-com/scim/internal/filter" "github.com/elimity-com/scim/schema" ) @@ -24,8 +24,8 @@ func (v OperationValidator) validateRemove() (interface{}, error) { return nil, err } if v.Path.ValueExpression != nil { - if err := f.NewFilterValidator(v.Path.ValueExpression, schema.Schema{ - Attributes: f.MultiValuedFilterAttributes(*refAttr), + if err := filter.NewFilterValidator(v.Path.ValueExpression, schema.Schema{ + Attributes: filter.MultiValuedFilterAttributes(*refAttr), }).Validate(); err != nil { return nil, err } diff --git a/internal/patch/remove_test.go b/internal/patch/remove_test.go index a3e9816..f382ac2 100644 --- a/internal/patch/remove_test.go +++ b/internal/patch/remove_test.go @@ -1,6 +1,7 @@ package patch import ( + "encoding/json" "fmt" "testing" @@ -9,10 +10,10 @@ import ( // The following example shows how remove all members of a group. func Example_removeAllMembers() { - operation := `{ - "op": "remove", - "path": "members" -}` + operation, _ := json.Marshal(map[string]interface{}{ + "op": "remove", + "path": "members", + }) validator, _ := NewValidator(operation, schema.CoreGroupSchema()) fmt.Println(validator.Validate()) // Output: @@ -21,76 +22,82 @@ func Example_removeAllMembers() { // The following example shows how remove a value from a complex multi-valued attribute. func Example_removeComplexMultiValuedAttributeValue() { - operation := `{ - "op": "remove", - "path": "emails[type eq \"work\" and value ew \"elimity.com\"]" -}` + operation, _ := json.Marshal(map[string]interface{}{ + "op": "remove", + "path": `emails[type eq "work" and value eq "elimity.com"]`, + }) validator, _ := NewValidator(operation, schema.CoreUserSchema()) fmt.Println(validator.Validate()) // Output: // } -// The following example shows how remove a single member from a group. -func Example_removeSingleMember() { - operation := `{ - "op": "remove", - "path": "members[value eq \"0001\"]" -}` - validator, _ := NewValidator(operation, schema.CoreGroupSchema()) - fmt.Println(validator.Validate()) - // Output: - // -} - // The following example shows how remove a single group from a user. func Example_removeSingleGroup() { - operation := `{ - "op": "remove", + operation, _ := json.Marshal(map[string]interface{}{ + "op": "remove", "path": "groups", - "value": [{ - "$ref": null, - "value": "f648f8d5ea4e4cd38e9c" - }] - }` + "value": []interface{}{ + map[string]interface{}{ + "$ref": nil, + "value": "f648f8d5ea4e4cd38e9c", + }, + }, + }) validator, _ := NewValidator(operation, schema.CoreUserSchema()) fmt.Println(validator.Validate()) // Output: // [map[]] } +// The following example shows how remove a single member from a group. +func Example_removeSingleMember() { + operation, _ := json.Marshal(map[string]interface{}{ + "op": "remove", + "path": `members[value eq "0001"]`, + }) + validator, _ := NewValidator(operation, schema.CoreGroupSchema()) + fmt.Println(validator.Validate()) + // Output: + // +} + // The following example shows how to replace all of the members of a group with a different members list. func Example_replaceAllMembers() { - operations := []string{`{ - "op": "remove", - "path": "members" -}`, - `{ - "op": "remove", - "path": "members", - "value": [{ - "value": "f648f8d5ea4e4cd38e9c" - }] -}`, - `{ - "op": "add", - "path": "members", - "value": [ + operations := []map[string]interface{}{ { - "display": "di-wu", - "$ref": "https://example.com/v2/Users/0001", - "value": "0001" + "op": "remove", + "path": "members", }, { - "display": "example", - "$ref": "https://example.com/v2/Users/0002", - "value": "0002" - } - ] -}`, + "op": "remove", + "path": "members", + "value": []interface{}{ + map[string]interface{}{ + "value": "f648f8d5ea4e4cd38e9c", + }, + }, + }, + { + "op": "add", + "path": "members", + "value": []interface{}{ + map[string]interface{}{ + "display": "di-wu", + "$ref": "https://example.com/v2/Users/0001", + "value": "0001", + }, + map[string]interface{}{ + "display": "example", + "$ref": "https://example.com/v2/Users/0002", + "value": "0002", + }, + }, + }, } for _, op := range operations { - validator, _ := NewValidator(op, schema.CoreGroupSchema()) + operation, _ := json.Marshal(op) + validator, _ := NewValidator(operation, schema.CoreGroupSchema()) fmt.Println(validator.Validate()) } // Output: @@ -113,27 +120,28 @@ func TestOperationValidator_ValidateRemove(t *testing.T) { // attribute's sub-attributes, the matching records are removed. for i, test := range []struct { - valid string - invalid string + valid map[string]interface{} + invalid map[string]interface{} }{ // If "path" is unspecified, the operation fails. - {invalid: `{"op":"remove"}`}, + {invalid: map[string]interface{}{"op": "remove"}}, // If the target location is a single-value attribute. - {valid: `{"op":"remove","path":"attr1"}`}, + {valid: map[string]interface{}{"op": "remove", "path": "attr1"}}, // If the target location is a multi-valued attribute and no filter is specified. - {valid: `{"op":"remove","path":"multiValued"}`}, + {valid: map[string]interface{}{"op": "remove", "path": "multiValued"}}, // If the target location is a multi-valued attribute and a complex filter is specified comparing a "value". - {valid: `{"op":"remove","path":"multivalued[value eq \"value\"]"}`}, + {valid: map[string]interface{}{"op": "remove", "path": `multivalued[value eq "value"]`}}, // If the target location is a complex multi-valued attribute and a complex filter is specified based on the // attribute's sub-attributes - {valid: `{"op":"remove","path":"complexMultiValued[attr1 eq \"value\"]"}`}, - {valid: `{"op":"remove","path":"complexMultiValued[attr1 eq \"value\"].attr1"}`}, + {valid: map[string]interface{}{"op": "remove", "path": `complexMultiValued[attr1 eq "value"]`}}, + {valid: map[string]interface{}{"op": "remove", "path": `complexMultiValued[attr1 eq "value"].attr1`}}, } { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { // valid - if op := test.valid; op != "" { - validator, err := NewValidator(op, patchSchema, patchSchemaExtension) + if op := test.valid; op != nil { + operation, _ := json.Marshal(op) + validator, err := NewValidator(operation, patchSchema, patchSchemaExtension) if err != nil { t.Fatal(err) } @@ -142,8 +150,9 @@ func TestOperationValidator_ValidateRemove(t *testing.T) { } } // invalid - if op := test.invalid; op != "" { - validator, err := NewValidator(op, patchSchema, patchSchemaExtension) + if op := test.invalid; op != nil { + operation, _ := json.Marshal(op) + validator, err := NewValidator(operation, patchSchema, patchSchemaExtension) if err != nil { t.Fatal(err) } diff --git a/internal/patch/replace_test.go b/internal/patch/replace_test.go index 13d407b..34347ec 100644 --- a/internal/patch/replace_test.go +++ b/internal/patch/replace_test.go @@ -1,29 +1,30 @@ package patch import ( + "encoding/json" "fmt" "github.com/elimity-com/scim/schema" ) // The following example shows how to replace all values of one or more specific attributes. func Example_replaceAnyAttribute() { - operation := `{ - "op": "replace", - "value": { - "emails": [ - { - "value": "quint", - "type": "work", - "primary": true + operation, _ := json.Marshal(map[string]interface{}{ + "op": "replace", + "value": map[string]interface{}{ + "emails": []map[string]interface{}{ + { + "value": "quint", + "type": "work", + "primary": true, + }, + { + "value": "me@di-wu.be", + "type": "home", + }, }, - { - "value": "me@di-wu.be", - "type": "home" - } - ], - "nickname": "di-wu" - } -}` + "nickname": "di-wu", + }, + }) validator, _ := NewValidator(operation, schema.CoreUserSchema()) fmt.Println(validator.Validate()) // Output: @@ -33,25 +34,27 @@ func Example_replaceAnyAttribute() { // The following example shows how to replace all of the members of a group with a different members list in a single // replace operation. func Example_replaceMembers() { - operations := []string{`{ - "op": "replace", - "path": "members", - "value": [ + operations := []map[string]interface{}{ { - "display": "di-wu", - "$ref": "https://example.com/v2/Users/0001", - "value": "0001" + "op": "replace", + "path": "members", + "value": []interface{}{ + map[string]interface{}{ + "display": "di-wu", + "$ref": "https://example.com/v2/Users/0001", + "value": "0001", + }, + map[string]interface{}{ + "display": "example", + "$ref": "https://example.com/v2/Users/0002", + "value": "0002", + }, + }, }, - { - "display": "example", - "$ref": "https://example.com/v2/Users/0002", - "value": "0002" - } - ] -}`, } for _, op := range operations { - validator, _ := NewValidator(op, schema.CoreGroupSchema()) + operation, _ := json.Marshal(op) + validator, _ := NewValidator(operation, schema.CoreGroupSchema()) fmt.Println(validator.Validate()) } // Output: @@ -61,11 +64,11 @@ func Example_replaceMembers() { // The following example shows how to change a specific sub-attribute "streetAddress" of complex attribute "emails" // selected by a "valuePath" filter. func Example_replaceSpecificSubAttribute() { - operation := `{ - "op": "replace", - "path": "addresses[type eq \"work\"].streetAddress", - "value": "ExampleStreet 100" -}` + operation, _ := json.Marshal(map[string]interface{}{ + "op": "replace", + "path": `addresses[type eq "work"].streetAddress`, + "value": "ExampleStreet 100", + }) validator, _ := NewValidator(operation, schema.CoreUserSchema()) fmt.Println(validator.Validate()) // Output: @@ -74,18 +77,18 @@ func Example_replaceSpecificSubAttribute() { // The following example shows how to change a User's entire "work" address, using a "valuePath" filter. func Example_replaceWorkAddress() { - operation := `{ - "op": "replace", - "path": "addresses[type eq \"work\"]", - "value": { - "type": "work", - "streetAddress": "ExampleStreet 1", - "locality": "ExampleCity", - "postalCode": "0001", - "country": "BE", - "primary": true - } -}` + operation, _ := json.Marshal(map[string]interface{}{ + "op": "replace", + "path": `addresses[type eq "work"]`, + "value": map[string]interface{}{ + "type": "work", + "streetAddress": "ExampleStreet 1", + "locality": "ExampleCity", + "postalCode": "0001", + "country": "BE", + "primary": true, + }, + }) validator, _ := NewValidator(operation, schema.CoreUserSchema()) fmt.Println(validator.Validate()) // Output: diff --git a/internal/patch/schema_test.go b/internal/patch/schema_test.go index 711864e..c7c6c6a 100644 --- a/internal/patch/schema_test.go +++ b/internal/patch/schema_test.go @@ -14,6 +14,17 @@ var ( schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{ Name: "attr1", })), + schema.SimpleCoreAttribute(schema.SimpleNumberParams(schema.NumberParams{ + Name: "attr2", + Type: schema.AttributeTypeInteger(), + })), + schema.SimpleCoreAttribute(schema.SimpleNumberParams(schema.NumberParams{ + Name: "attr3", + Type: schema.AttributeTypeDecimal(), + })), + schema.SimpleCoreAttribute(schema.SimpleBooleanParams(schema.BooleanParams{ + Name: "attr4", + })), schema.SimpleCoreAttribute(schema.SimpleStringParams(schema.StringParams{ Name: "multiValued", MultiValued: true, diff --git a/internal/patch/update.go b/internal/patch/update.go index 9ff4168..8c0f599 100644 --- a/internal/patch/update.go +++ b/internal/patch/update.go @@ -2,7 +2,7 @@ package patch import ( "fmt" - f "github.com/elimity-com/scim/internal/filter" + f "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" ) diff --git a/internal/patch/update_test.go b/internal/patch/update_test.go index 3789b94..4c18ee3 100644 --- a/internal/patch/update_test.go +++ b/internal/patch/update_test.go @@ -1,6 +1,7 @@ package patch import ( + "encoding/json" "fmt" "testing" ) @@ -25,21 +26,21 @@ func TestOperationValidator_ValidateUpdate(t *testing.T) { // - Unless other operations change the resource, this operation shall not change the modify timestamp of the // resource. for i, test := range []struct { - valid string - invalid string + valid map[string]interface{} + invalid map[string]interface{} }{ // The operation must contain a "value" member whose content specifies the value to be added. { - valid: `{"op":"add","path":"attr1","value":"value"}`, - invalid: `{"op":"add","path":"attr1"}`, + valid: map[string]interface{}{"op": "add", "path": "attr1", "value": "value"}, + invalid: map[string]interface{}{"op": "add", "path": "attr1"}, }, // A URI prefix in the path. { - valid: `{"op":"add","path":"test:PatchEntity:attr1","value":"value"}`, - invalid: `{"op":"add","path":"invalid:attr1","value":"value"}`, + valid: map[string]interface{}{"op": "add", "path": "test:PatchEntity:attr1", "value": "value"}, + invalid: map[string]interface{}{"op": "add", "path": "invalid:attr1", "value": "value"}, }, - {valid: `{"op":"add","path":"test:PatchExtension:attr1","value":"value"}`}, + {valid: map[string]interface{}{"op": "add", "path": "test:PatchExtension:attr1", "value": "value"}}, // The value MAY be a quoted value, or it may be a JSON object containing the sub-attributes of the complex // attribute specified in the operation's "path". @@ -49,66 +50,66 @@ func TestOperationValidator_ValidateUpdate(t *testing.T) { // The idea is that path can be either fine-grained or point to a whole object. // Thus value of "value" depends on what path points to. { - valid: `{"op":"add","path":"complex.attr1","value":"value"}`, - invalid: `{"op":"add","path":"complex.attr1","value":{"attr1":"value"}}`, + valid: map[string]interface{}{"op": "add", "path": "complex.attr1", "value": "value"}, + invalid: map[string]interface{}{"op": "add", "path": "complex.attr1", "value": map[string]interface{}{"attr1": "value"}}, }, { - valid: `{"op":"add","path":"complex","value":{"attr1":"value"}}`, - invalid: `{"op":"add","path":"complex","value":"value"}`, + valid: map[string]interface{}{"op": "add", "path": "complex", "value": map[string]interface{}{"attr1": "value"}}, + invalid: map[string]interface{}{"op": "add", "path": "complex", "value": "value"}, }, // If omitted, the target location is assumed to be the resource itself. The "value" parameter contains a // set of attributes to be added to the resource. { - valid: `{"op":"add","value":{"attr1":"value"}}`, - invalid: `{"op":"add","value":"value"}`, + valid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"attr1": "value"}}, + invalid: map[string]interface{}{"op": "add", "value": "value"}, }, - {invalid: `{"op":"add","value":{"invalid":"value"}}`}, - {invalid: `{"op":"add","value":{"invalid:attr1":"value"}}`}, + {invalid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"invalid": "value"}}}, + {invalid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"invalid:attr1": "value"}}}, // If the target location specifies a multi-valued attribute, a new value is added to the attribute. - {valid: `{"op":"add","value":{"multiValued":"value"}}`}, + {valid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"multiValued": "value"}}}, // Example on page 36 (RFC7644, Section 3.5.2.1). - {valid: `{"op":"add","path":"complexMultiValued","value":[{"attr1":"value"}]}`}, - {valid: `{"op":"add","path":"complexMultiValued","value":{"attr1":"value"}}`}, + {valid: map[string]interface{}{"op": "add", "path": "complexMultiValued", "value": []interface{}{map[string]interface{}{"attr1": "value"}}}}, + {valid: map[string]interface{}{"op": "add", "path": "complexMultiValued", "value": map[string]interface{}{"attr1": "value"}}}, // Example on page 37 (RFC7644, Section 3.5.2.1). - {valid: `{"op":"add","value":{"attr1":"value","complexMultiValued":[{"attr1":"value"}]}}`}, - {valid: `{"op":"add","value":{"attr1":"value","complexMultiValued":[{"attr1":"value"}]}}`}, + {valid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"attr1": "value", "complexMultiValued": []interface{}{map[string]interface{}{"attr1": "value"}}}}}, { - valid: `{"op":"add","path":"complexMultiValued[attr1 eq \"value\"].attr1","value":"value"}`, - invalid: `{"op":"add","path":"complexMultiValued[attr1 eq \"value\"].attr2","value":"value"}`, + valid: map[string]interface{}{"op": "add", "path": `complexMultiValued[attr1 eq "value"].attr1`, "value": "value"}, + invalid: map[string]interface{}{"op": "add", "path": `complexMultiValued[attr1 eq "value"].attr2`, "value": "value"}, }, { - valid: `{"op":"add","path":"test:PatchEntity:complexMultiValued[attr1 eq \"value\"].attr1","value":"value"}`, - invalid: `{"op":"add","path":"test:PatchEntity:complexMultiValued[attr2 eq \"value\"].attr1","value":"value"}`, + valid: map[string]interface{}{"op": "add", "path": `test:PatchEntity:complexMultiValued[attr1 eq "value"].attr1`, "value": "value"}, + invalid: map[string]interface{}{"op": "add", "path": `test:PatchEntity:complexMultiValued[attr2 eq "value"].attr1`, "value": "value"}, }, // Valid path, attribute not found. - {invalid: `{"op":"add","path":"invalid","value":"value"}`}, - {invalid: `{"op":"add","path":"complex.invalid","value":"value"}`}, + {invalid: map[string]interface{}{"op": "add", "path": "invalid", "value": "value"}}, + {invalid: map[string]interface{}{"op": "add", "path": "complex.invalid", "value": "value"}}, // Sub-attributes in complex assignments. - {valid: `{"op":"add","value":{"complex.attr1":"value"}}`}, + {valid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"complex.attr1": "value"}}}, // Has no sub-attributes. - {invalid: `{"op":"add","path":"attr1.invalid","value":"value"}`}, + {invalid: map[string]interface{}{"op": "add", "path": "attr1.invalid", "value": "value"}}, // Invalid types. - {invalid: `{"op":"add","path":"attr1","value":1}`}, - {invalid: `{"op":"add","path":"multiValued","value":1}`}, - {invalid: `{"op":"add","path":"multiValued","value":[1]}`}, - {invalid: `{"op":"add","path":"complex.attr1","value":1}`}, - {invalid: `{"op":"add","value":{"attr1":1}}`}, - {invalid: `{"op":"add","value":{"multiValued":1}}`}, - {invalid: `{"op":"add","value":{"multiValued":[1]}}`}, + {invalid: map[string]interface{}{"op": "add", "path": "attr1", "value": 1}}, + {invalid: map[string]interface{}{"op": "add", "path": "multiValued", "value": 1}}, + {invalid: map[string]interface{}{"op": "add", "path": "multiValued", "value": []interface{}{1}}}, + {invalid: map[string]interface{}{"op": "add", "path": "complex.attr1", "value": 1}}, + {invalid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"attr1": 1}}}, + {invalid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"multiValued": 1}}}, + {invalid: map[string]interface{}{"op": "add", "value": map[string]interface{}{"multiValued": []interface{}{1}}}}, } { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { // valid - if op := test.valid; op != "" { - validator, err := NewValidator(op, patchSchema, patchSchemaExtension) + if op := test.valid; op != nil { + operation, _ := json.Marshal(op) + validator, err := NewValidator(operation, patchSchema, patchSchemaExtension) if err != nil { t.Fatal(err) } @@ -117,8 +118,9 @@ func TestOperationValidator_ValidateUpdate(t *testing.T) { } } // invalid - if op := test.invalid; op != "" { - validator, err := NewValidator(op, patchSchema, patchSchemaExtension) + if op := test.invalid; op != nil { + operation, _ := json.Marshal(op) + validator, err := NewValidator(operation, patchSchema, patchSchemaExtension) if err != nil { return } diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..180fe62 --- /dev/null +++ b/logger.go @@ -0,0 +1,10 @@ +package scim + +// Logger defines an interface for logging errors. +type Logger interface { + Error(args ...interface{}) +} + +type noopLogger struct{} + +func (noopLogger) Error(...interface{}) {} diff --git a/patch_add_test.go b/patch_add_test.go index 9a84587..c7cca2e 100644 --- a/patch_add_test.go +++ b/patch_add_test.go @@ -3,14 +3,14 @@ package scim_test import ( "bytes" "encoding/json" - "io/ioutil" "net/http" "net/http/httptest" + "os" "testing" ) func TestPatch_addAttributes(t *testing.T) { - raw, err := ioutil.ReadFile("testdata/patch/add/attributes.json") + raw, err := os.ReadFile("testdata/patch/add/attributes.json") if err != nil { t.Fatal(err) } @@ -18,7 +18,7 @@ func TestPatch_addAttributes(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Users/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } @@ -57,7 +57,7 @@ func TestPatch_addAttributes(t *testing.T) { } func TestPatch_addMember(t *testing.T) { - raw, err := ioutil.ReadFile("testdata/patch/add/member.json") + raw, err := os.ReadFile("testdata/patch/add/member.json") if err != nil { t.Fatal(err) } @@ -65,7 +65,7 @@ func TestPatch_addMember(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Groups/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } @@ -121,8 +121,8 @@ func TestPatch_alreadyExists(t *testing.T) { changed: false, }, } { - server := newTestServer() - raw, err := ioutil.ReadFile(test.jsonFilePath) + server := newTestServer(t) + raw, err := os.ReadFile(test.jsonFilePath) if err != nil { t.Fatal(err) } @@ -150,7 +150,7 @@ func TestPatch_alreadyExists(t *testing.T) { } func TestPatch_complex(t *testing.T) { - raw, err := ioutil.ReadFile("testdata/patch/add/complex.json") + raw, err := os.ReadFile("testdata/patch/add/complex.json") if err != nil { t.Fatal(err) } @@ -158,7 +158,7 @@ func TestPatch_complex(t *testing.T) { req = httptest.NewRequest(http.MethodPatch, "/Users/0001", bytes.NewReader(raw)) rr = httptest.NewRecorder() ) - newTestServer().ServeHTTP(rr, req) + newTestServer(t).ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatal(rr.Code, rr.Body.String()) } diff --git a/resource_handler.go b/resource_handler.go index f0bd3d9..1d0cd15 100644 --- a/resource_handler.go +++ b/resource_handler.go @@ -6,9 +6,9 @@ import ( "net/url" "time" + "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/optional" "github.com/elimity-com/scim/schema" - "github.com/scim2/filter-parser/v2" ) // ListRequestParams request parameters sent to the API via a "GetAll" route. @@ -19,7 +19,7 @@ type ListRequestParams struct { // Filter represents the parsed and tokenized filter query parameter. // It is an optional parameter and thus will be nil when the parameter is not present. - Filter filter.Expression + FilterValidator *filter.Validator // StartIndex The 1-based index of the first query result. A value less than 1 SHALL be interpreted as 1. StartIndex int diff --git a/resource_handler_test.go b/resource_handler_test.go index b2989dd..17e557d 100644 --- a/resource_handler_test.go +++ b/resource_handler_test.go @@ -30,8 +30,8 @@ type testResourceHandler struct { func (h testResourceHandler) Create(r *http.Request, attributes ResourceAttributes) (Resource, error) { // create unique identifier - rand.Seed(time.Now().UnixNano()) - id := fmt.Sprintf("%04d", rand.Intn(9999)) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + id := fmt.Sprintf("%04d", rng.Intn(9999)) // store resource h.data[id] = testData{ diff --git a/resource_type.go b/resource_type.go index 992093a..8761d7d 100644 --- a/resource_type.go +++ b/resource_type.go @@ -148,7 +148,7 @@ func (t ResourceType) validatePatch(r *http.Request) ([]PatchOperation, *errors. var operations []PatchOperation for _, v := range req.Operations { validator, err := patch.NewValidator( - string(v), + v, t.schemaWithCommon(), t.getSchemaExtensions()..., ) diff --git a/schema/characteristics.go b/schema/characteristics.go index 69a93da..09091d7 100644 --- a/schema/characteristics.go +++ b/schema/characteristics.go @@ -136,15 +136,19 @@ const ( ) func (a attributeMutability) MarshalJSON() ([]byte, error) { + return json.Marshal(a.String()) +} + +func (a attributeMutability) String() string { switch a { case attributeMutabilityImmutable: - return json.Marshal("immutable") + return "immutable" case attributeMutabilityReadOnly: - return json.Marshal("readOnly") + return "readOnly" case attributeMutabilityWriteOnly: - return json.Marshal("writeOnly") + return "writeOnly" default: - return json.Marshal("readWrite") + return "readWrite" } } @@ -158,15 +162,19 @@ const ( ) func (a attributeReturned) MarshalJSON() ([]byte, error) { + return json.Marshal(a.String()) +} + +func (a attributeReturned) String() string { switch a { case attributeReturnedAlways: - return json.Marshal("always") + return "always" case attributeReturnedNever: - return json.Marshal("never") + return "never" case attributeReturnedRequest: - return json.Marshal("request") + return "request" default: - return json.Marshal("default") + return "default" } } @@ -218,12 +226,16 @@ const ( ) func (a attributeUniqueness) MarshalJSON() ([]byte, error) { + return json.Marshal(a.String()) +} + +func (a attributeUniqueness) String() string { switch a { case attributeUniquenessGlobal: - return json.Marshal("global") + return "global" case attributeUniquenessServer: - return json.Marshal("server") + return "server" default: - return json.Marshal("none") + return "none" } } diff --git a/schema/core.go b/schema/core.go index 18399de..37ceaee 100644 --- a/schema/core.go +++ b/schema/core.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "regexp" + "strconv" "strings" datetime "github.com/di-wu/xsd-datetime" @@ -11,6 +12,18 @@ import ( "github.com/elimity-com/scim/optional" ) +var ( + schemaAllowStringValues = false +) + +// SetAllowStringValues sets whether string values are allowed. +// If enabled, string values are allowed for booleans, integer and decimal attributes. +// NOTE: This is NOT a standard SCIM behaviour, and should only be used for compatibility with non-compliant SCIM +// clients, such as the one provided by Microsoft Azure. +func SetAllowStringValues(enabled bool) { + schemaAllowStringValues = enabled +} + // CoreAttribute represents those attributes that sit at the top level of the JSON object together with the common // attributes (such as the resource "id"). type CoreAttribute struct { @@ -122,8 +135,7 @@ func (a CoreAttribute) MultiValued() bool { // Mutability returns the mutability of the attribute. func (a CoreAttribute) Mutability() string { - raw, _ := a.mutability.MarshalJSON() - return string(raw) + return a.mutability.String() } // Name returns the case insensitive name of the attribute. @@ -143,8 +155,7 @@ func (a CoreAttribute) Required() bool { // Returned returns when the attribute need to be returned. func (a CoreAttribute) Returned() string { - raw, _ := a.returned.MarshalJSON() - return string(raw) + return a.returned.String() } // SubAttributes returns the sub attributes. @@ -154,8 +165,7 @@ func (a CoreAttribute) SubAttributes() Attributes { // Uniqueness returns the attributes uniqueness. func (a CoreAttribute) Uniqueness() string { - raw, _ := a.uniqueness.MarshalJSON() - return string(raw) + return a.uniqueness.String() } // ValidateSingular checks whether the given singular value matches the attribute data type. Unknown attributes in @@ -181,6 +191,13 @@ func (a CoreAttribute) ValidateSingular(attribute interface{}) (interface{}, *er case attributeDataTypeBoolean: b, ok := attribute.(bool) if !ok { + if b, ok := attribute.(string); ok && schemaAllowStringValues { + b, err := strconv.ParseBool(b) + if err != nil { + return nil, &errors.ScimErrorInvalidValue + } + return b, nil + } return nil, &errors.ScimErrorInvalidValue } @@ -233,10 +250,14 @@ func (a CoreAttribute) ValidateSingular(attribute interface{}) (interface{}, *er if err != nil { return nil, &errors.ScimErrorInvalidValue } - return f, nil case float64: return n, nil + case string: + if f, err := strconv.ParseFloat(n, 64); err == nil && schemaAllowStringValues { + return f, nil + } + return nil, &errors.ScimErrorInvalidValue default: return nil, &errors.ScimErrorInvalidValue } @@ -247,10 +268,14 @@ func (a CoreAttribute) ValidateSingular(attribute interface{}) (interface{}, *er if err != nil { return nil, &errors.ScimErrorInvalidValue } - return i, nil case int, int8, int16, int32, int64: return n, nil + case string: + if i, err := strconv.ParseInt(n, 10, 64); err == nil && schemaAllowStringValues { + return i, nil + } + return nil, &errors.ScimErrorInvalidValue default: return nil, &errors.ScimErrorInvalidValue } diff --git a/schema/schema_test.go b/schema/schema_test.go index 91bd9ed..c9faacc 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -2,7 +2,7 @@ package schema import ( "encoding/json" - "io/ioutil" + "os" "testing" "github.com/elimity-com/scim/optional" @@ -76,16 +76,14 @@ func TestInvalidAttributeName(t *testing.T) { } func TestJSONMarshalling(t *testing.T) { - expectedJSON, err := ioutil.ReadFile("./testdata/schema_test.json") + expectedJSON, err := os.ReadFile("./testdata/schema_test.json") if err != nil { - t.Errorf("failed to acquire test data") - return + t.Fatal("failed to acquire test data") } actualJSON, err := testSchema.MarshalJSON() if err != nil { - t.Errorf("failed to marshal schema into JSON") - return + t.Fatal("failed to marshal schema into JSON") } normalizedActual, err := normalizeJSON(actualJSON) @@ -94,7 +92,6 @@ func TestJSONMarshalling(t *testing.T) { t.Errorf("failed to normalize test JSON") return } - if normalizedActual != normalizedExpected { t.Errorf("schema output by MarshalJSON did not match the expected output. want %s, got %s", normalizedExpected, normalizedActual) } diff --git a/schema/schemas_test.go b/schema/schemas_test.go index 525238c..eae9adb 100644 --- a/schema/schemas_test.go +++ b/schema/schemas_test.go @@ -2,7 +2,7 @@ package schema import ( "fmt" - "io/ioutil" + "os" "testing" ) @@ -24,7 +24,7 @@ func TestDefaultSchemas(t *testing.T) { schema: ExtensionEnterpriseUser(), }, } { - expectedJSON, err := ioutil.ReadFile(fmt.Sprintf("./testdata/%s", test.file)) + expectedJSON, err := os.ReadFile(fmt.Sprintf("./testdata/%s", test.file)) if err != nil { t.Errorf("Failed to acquire test data") return diff --git a/server.go b/server.go index 564ac31..4981793 100644 --- a/server.go +++ b/server.go @@ -2,14 +2,13 @@ package scim import ( "fmt" - f "github.com/elimity-com/scim/internal/filter" - "github.com/scim2/filter-parser/v2" "net/http" "net/url" "strconv" "strings" "github.com/elimity-com/scim/errors" + "github.com/elimity-com/scim/filter" "github.com/elimity-com/scim/schema" ) @@ -19,20 +18,20 @@ const ( ) // getFilter returns a validated filter if present in the url query, nil otherwise. -func getFilter(r *http.Request, s schema.Schema, extensions ...schema.Schema) (filter.Expression, error) { - filter := strings.TrimSpace(r.URL.Query().Get("filter")) - if filter == "" { +func getFilterValidator(r *http.Request, s schema.Schema, extensions ...schema.Schema) (*filter.Validator, error) { + f := strings.TrimSpace(r.URL.Query().Get("filter")) + if f == "" { return nil, nil // No filter present. } - validator, err := f.NewValidator(filter, s, extensions...) + validator, err := filter.NewValidator(f, s, extensions...) if err != nil { return nil, err } if err := validator.Validate(); err != nil { return nil, err } - return validator.GetFilter(), nil + return &validator, nil } func getIntQueryParam(r *http.Request, key string, def int) (int, error) { @@ -53,11 +52,38 @@ func parseIdentifier(path, endpoint string) (string, error) { return url.PathUnescape(strings.TrimPrefix(path, endpoint+"/")) } -// Server represents a SCIM server which implements the HTTP-based SCIM protocol that makes managing identities in multi- -// domain scenarios easier to support via a standardized service. +// Server represents a SCIM server which implements the HTTP-based SCIM protocol +// that makes managing identities in multi-domain scenarios easier to support via a standardized service. type Server struct { - Config ServiceProviderConfig - ResourceTypes []ResourceType + config ServiceProviderConfig + resourceTypes []ResourceType + log Logger +} + +func NewServer(args *ServerArgs, opts ...ServerOption) (Server, error) { + if args == nil { + return Server{}, fmt.Errorf("arguments not provided") + } + + if args.ServiceProviderConfig == nil { + return Server{}, fmt.Errorf("service provider config not provided") + } + + if args.ResourceTypes == nil { + return Server{}, fmt.Errorf("resource types not provided") + } + + s := &Server{ + config: *args.ServiceProviderConfig, + resourceTypes: args.ResourceTypes, + log: &noopLogger{}, + } + + for _, opt := range opts { + opt(s) + } + + return *s, nil } // ServeHTTP dispatches the request to the handler whose pattern most closely matches the request URL. @@ -68,7 +94,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case path == "/Me": - errorHandler(w, r, &errors.ScimError{ + s.errorHandler(w, &errors.ScimError{ Status: http.StatusNotImplemented, }) return @@ -89,7 +115,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if path == resourceType.Endpoint { switch r.Method { case http.MethodPost: @@ -124,7 +150,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - errorHandler(w, r, &errors.ScimError{ + s.errorHandler(w, &errors.ScimError{ Detail: "Specified endpoint does not exist.", Status: http.StatusNotFound, }) @@ -132,7 +158,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // getSchema extracts the schemas from the resources types defined in the server with given id. func (s Server) getSchema(id string) schema.Schema { - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if resourceType.Schema.ID == id { return resourceType.Schema } @@ -149,7 +175,7 @@ func (s Server) getSchema(id string) schema.Schema { func (s Server) getSchemas() []schema.Schema { ids := make([]string, 0) schemas := make([]schema.Schema, 0) - for _, resourceType := range s.ResourceTypes { + for _, resourceType := range s.resourceTypes { if !contains(ids, resourceType.Schema.ID) { schemas = append(schemas, resourceType.Schema) } @@ -167,7 +193,7 @@ func (s Server) getSchemas() []schema.Schema { func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, refExtensions ...schema.Schema) (ListRequestParams, *errors.ScimError) { invalidParams := make([]string, 0) - defaultCount := s.Config.getItemsPerPage() + defaultCount := s.config.getItemsPerPage() count, countErr := getIntQueryParam(r, "count", defaultCount) if countErr != nil { invalidParams = append(invalidParams, "count") @@ -194,14 +220,30 @@ func (s Server) parseRequestParams(r *http.Request, refSchema schema.Schema, ref return ListRequestParams{}, &scimErr } - reqFilter, err := getFilter(r, refSchema, refExtensions...) + validator, err := getFilterValidator(r, refSchema, refExtensions...) if err != nil { return ListRequestParams{}, &errors.ScimErrorInvalidFilter } return ListRequestParams{ - Count: count, - Filter: reqFilter, - StartIndex: startIndex, + Count: count, + FilterValidator: validator, + StartIndex: startIndex, }, nil } + +type ServerArgs struct { + ServiceProviderConfig *ServiceProviderConfig + ResourceTypes []ResourceType +} + +type ServerOption func(*Server) + +// WithLogger sets the logger for the server. +func WithLogger(logger Logger) ServerOption { + return func(s *Server) { + if logger != nil { + s.log = logger + } + } +} diff --git a/server_test.go b/server_test.go index 813af0f..726fea5 100644 --- a/server_test.go +++ b/server_test.go @@ -2,13 +2,15 @@ package scim_test import ( "fmt" - "io/ioutil" + "io" "net/http" + "testing" "time" + internal "github.com/elimity-com/scim/filter" + "github.com/elimity-com/scim" "github.com/elimity-com/scim/errors" - internal "github.com/elimity-com/scim/internal/filter" "github.com/elimity-com/scim/optional" "github.com/elimity-com/scim/schema" "github.com/scim2/filter-parser/v2" @@ -16,7 +18,7 @@ import ( func checkBodyNotEmpty(r *http.Request) error { // Check whether the request body is empty. - data, err := ioutil.ReadAll(r.Body) + data, err := io.ReadAll(r.Body) if err != nil { return err } @@ -40,37 +42,44 @@ func externalID(attributes scim.ResourceAttributes) optional.String { // - Whether a reference to another entity really exists. // e.g. if a member gets added, does this entity exist? -func newTestServer() scim.Server { - return scim.Server{ - ResourceTypes: []scim.ResourceType{ - { - ID: optional.NewString("User"), - Name: "User", - Endpoint: "/Users", - Description: optional.NewString("User Account"), - Schema: schema.CoreUserSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{}}, +func newTestServer(t *testing.T) scim.Server { + s, err := scim.NewServer( + &scim.ServerArgs{ + ServiceProviderConfig: &scim.ServiceProviderConfig{}, + ResourceTypes: []scim.ResourceType{ + { + ID: optional.NewString("User"), + Name: "User", + Endpoint: "/Users", + Description: optional.NewString("User Account"), + Schema: schema.CoreUserSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{}}, + }, + schema: schema.CoreUserSchema(), }, - schema: schema.CoreUserSchema(), }, - }, - { - ID: optional.NewString("Group"), - Name: "Group", - Endpoint: "/Groups", - Description: optional.NewString("Group"), - Schema: schema.CoreGroupSchema(), - Handler: &testResourceHandler{ - data: map[string]testData{ - "0001": {attributes: map[string]interface{}{}}, + { + ID: optional.NewString("Group"), + Name: "Group", + Endpoint: "/Groups", + Description: optional.NewString("Group"), + Schema: schema.CoreGroupSchema(), + Handler: &testResourceHandler{ + data: map[string]testData{ + "0001": {attributes: map[string]interface{}{}}, + }, + schema: schema.CoreGroupSchema(), }, - schema: schema.CoreGroupSchema(), }, }, }, + ) + if err != nil { + t.Fatal(err) } + return s } // testData represents a resource entity. @@ -156,7 +165,7 @@ func (h testResourceHandler) GetAll(r *http.Request, params scim.ListRequestPara break } - validator := internal.NewFilterValidator(params.Filter, h.schema) + validator := internal.NewFilterValidator(params.FilterValidator.GetFilter(), h.schema) if err := validator.PassesFilter(v.attributes); err != nil { continue } diff --git a/utils.go b/utils.go index 5890365..326653c 100644 --- a/utils.go +++ b/utils.go @@ -2,7 +2,7 @@ package scim import ( "bytes" - "io/ioutil" + "io" "net/http" ) @@ -29,10 +29,10 @@ func contains(arr []string, el string) bool { } func readBody(r *http.Request) ([]byte, error) { - data, err := ioutil.ReadAll(r.Body) + data, err := io.ReadAll(r.Body) if err != nil { return nil, err } - r.Body = ioutil.NopCloser(bytes.NewBuffer(data)) + r.Body = io.NopCloser(bytes.NewBuffer(data)) return data, nil } diff --git a/utils_test.go b/utils_test.go index 920cad9..24ddd1d 100644 --- a/utils_test.go +++ b/utils_test.go @@ -30,7 +30,7 @@ func assertEqualStrings(t *testing.T, expected, actual []string) { assertLen(t, actual, len(expected)) for i, id := range expected { if rID := actual[i]; rID != id { - t.Errorf("%s is not equal to %sd", rID, id) + t.Errorf("%s is not equal to %s", rID, id) } } }